1use std::collections::HashMap;
19
20use crate::{Graph, NodeId, Op};
21
/// Coarse execution phase that a graph node can be assigned to.
///
/// The variants are listed in execution order; [`Phase::order`] exposes that
/// order as a number.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum Phase {
    /// First phase. `derive_phases` places input, parameter and constant
    /// nodes here.
    Prologue,
    /// Middle phase. `derive_phases` uses it for every node that is neither
    /// prologue nor epilogue.
    SteadyState,
    /// Final phase. `derive_phases` places sampling nodes here, together with
    /// nodes positioned after the last compute or sampling node.
    Epilogue,
}
33
34impl Phase {
35 pub fn order(self) -> u8 {
36 match self {
37 Self::Prologue => 0,
38 Self::SteadyState => 1,
39 Self::Epilogue => 2,
40 }
41 }
42}
43
44#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
46#[derive(Debug, Clone, Default, PartialEq, Eq)]
47pub struct PhaseSchedule {
48 map: HashMap<NodeId, Phase>,
49}
50
51impl PhaseSchedule {
52 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn set(&mut self, node: NodeId, phase: Phase) {
57 self.map.insert(node, phase);
58 }
59
60 pub fn get(&self, node: NodeId) -> Option<Phase> {
61 self.map.get(&node).copied()
62 }
63
64 pub fn iter(&self) -> impl Iterator<Item = (NodeId, Phase)> + '_ {
65 self.map.iter().map(|(&id, &p)| (id, p))
66 }
67
68 pub fn len(&self) -> usize {
69 self.map.len()
70 }
71
72 pub fn is_empty(&self) -> bool {
73 self.map.is_empty()
74 }
75
76 pub fn nodes_in(&self, phase: Phase) -> Vec<NodeId> {
79 self.nodes_in_ordered(phase, None)
80 }
81
82 pub fn nodes_in_ordered(&self, phase: Phase, schedule: Option<&[NodeId]>) -> Vec<NodeId> {
83 if let Some(order) = schedule {
84 return order
85 .iter()
86 .copied()
87 .filter(|id| self.get(*id) == Some(phase))
88 .collect();
89 }
90 let mut v: Vec<NodeId> = self
91 .map
92 .iter()
93 .filter_map(|(&id, &p)| if p == phase { Some(id) } else { None })
94 .collect();
95 v.sort();
96 v
97 }
98}
99
100pub fn derive_phases(graph: &Graph) -> PhaseSchedule {
102 let mut sched = PhaseSchedule::new();
103 let n = graph.len();
104 if n == 0 {
105 return sched;
106 }
107
108 let mut last_compute_step: Option<usize> = None;
109 let mut last_sample_step: Option<usize> = None;
110 for (step, node) in graph.nodes().iter().enumerate() {
111 match &node.op {
112 Op::Sample { .. } | Op::TopK { .. } => {
113 last_sample_step = Some(step);
114 }
115 Op::MatMul
116 | Op::FusedMatMulBiasAct { .. }
117 | Op::Attention { .. }
118 | Op::FusedAttentionBlock { .. }
119 | Op::FusedTransformerLayer { .. }
120 | Op::DotGeneral { .. }
121 | Op::GroupedMatMul
122 | Op::DequantGroupedMatMul { .. }
123 | Op::DequantMoEWeights { .. }
124 | Op::LoraMatMul { .. }
125 | Op::DequantMatMul { .. }
126 | Op::GatedDeltaNet { .. } => {
127 last_compute_step = Some(step);
128 }
129 _ => {}
130 }
131 }
132
133 for (step, node) in graph.nodes().iter().enumerate() {
134 let phase = match &node.op {
135 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Phase::Prologue,
136 Op::Sample { .. } | Op::TopK { .. } => Phase::Epilogue,
137 _ => {
138 if let Some(last) = last_sample_step {
139 if step > last
140 || (last_compute_step.is_some() && Some(step) > last_compute_step)
141 {
142 Phase::Epilogue
143 } else {
144 Phase::SteadyState
145 }
146 } else if let Some(last) = last_compute_step {
147 if step > last {
148 Phase::Epilogue
149 } else {
150 Phase::SteadyState
151 }
152 } else {
153 Phase::SteadyState
154 }
155 }
156 };
157 sched.set(node.id, phase);
158 }
159 sched
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DType, Shape};

    /// Input and parameter land in the prologue, the matmul in steady state
    /// and the sampler in the epilogue.
    #[test]
    fn derive_phases_classifies_typical_graph() {
        let dtype = DType::F32;
        let mut graph = Graph::new("derive");
        let input = graph.input("x", Shape::new(&[1, 8], dtype));
        let weight = graph.param("w", Shape::new(&[8, 4], dtype));
        let product = graph.matmul(input, weight, Shape::new(&[1, 4], dtype));
        let sampled = graph.sample(product, 0, 1.0, 1.0, 0, Shape::new(&[1], dtype));
        graph.set_outputs(vec![sampled]);

        let sched = derive_phases(&graph);
        let expected = [
            ("x", input, Phase::Prologue),
            ("w", weight, Phase::Prologue),
            ("matmul", product, Phase::SteadyState),
            ("sample", sampled, Phase::Epilogue),
        ];
        for &(label, node, phase) in expected.iter() {
            assert_eq!(sched.get(node), Some(phase), "unexpected phase for {}", label);
        }
    }

    /// `Phase::order` must rank the phases strictly in execution order.
    #[test]
    fn phase_ordering_is_deterministic() {
        let ranks = [
            Phase::Prologue.order(),
            Phase::SteadyState.order(),
            Phase::Epilogue.order(),
        ];
        assert!(ranks[0] < ranks[1]);
        assert!(ranks[1] < ranks[2]);
    }
}