morok_schedule/linearize/cfg_context.rs
1//! Control flow graph context for linearization.
2//!
3//! CFGContext analyzes the control flow structure of a kernel AST and computes
4//! ordering edges between sibling RANGE operations at the same nesting level.
5//!
6//! This implementation matches Tinygrad's CFGContext (linearizer.py:59-91).
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use morok_ir::UOp;
12use morok_ir::op::Op;
13use morok_ir::uop::core::UOpKey;
14
/// Control flow graph context for linearization.
///
/// Tracks ordering edges between sibling RANGE operations to ensure
/// proper linearization order when loops are at the same nesting level.
///
/// Based on Tinygrad's CFGContext which tracks three relationships between ranges:
/// - **nested**: END y is a dependency of END x AND RANGE x is a dependency of END y
/// - **dependent**: END y is a dependency of END x AND RANGE x is NOT a dependency of END y
/// - **independent**: END y is NOT a dependency of END x
///
/// # Control Flow Edges
///
/// When multiple ENDs exist at the same nesting level (siblings), they need to be
/// ordered consistently. CFGContext computes edges where:
/// - Each RANGE points to its predecessor (the parent's RANGE, another END, or,
///   for a RANGE listed next to a STORE in an AFTER's deps, that STORE)
/// - Edges ensure sequential execution of sibling loops
///
/// # Example
///
/// ```text
/// RANGE(i) → ... → END(i) // first loop
/// RANGE(j) → ... → END(j) // second loop (sibling)
/// RANGE(k) → ... → END(k) // third loop (sibling)
///
/// CFGContext edges:
/// RANGE(j) → END(i) // j's RANGE depends on i's END
/// RANGE(k) → END(j) // k's RANGE depends on j's END
/// ```
#[derive(Debug, Default)]
pub struct CFGContext {
    /// Maps RANGE → predecessor.
    ///
    /// The predecessor is the operation that must complete before
    /// this RANGE can begin execution. The constructor records three kinds
    /// of predecessor: the previous sibling END, the enclosing END's RANGE
    /// (for the first loop inside a parent loop), or a STORE that shares an
    /// AFTER's deps with the RANGE.
    // `UOpKey` wraps an `Arc<UOp>`, which trips clippy's interior-mutability
    // check for map keys. NOTE(review): presumably `UOpKey` hashes on a stable
    // identity, making the lint a false positive here — confirm against `UOpKey`.
    #[allow(clippy::mutable_key_type)]
    pub edges: HashMap<UOpKey, Arc<UOp>>,
}
52
53impl CFGContext {
54 /// Build a control flow context from a kernel AST.
55 ///
56 /// Analyzes the graph to find sibling ENDs at the same nesting level
57 /// and creates ordering edges between their RANGEs.
58 ///
59 /// # Algorithm (from Tinygrad linearizer.py:59-91)
60 ///
61 /// 1. Build transitive deps map (RANGE/END add themselves to deps)
62 /// 2. Build nesting map: which END/SINK nests each END
63 /// 3. Group siblings by parent
64 /// 4. Order siblings by dependency count (fewer deps = earlier)
65 /// 5. Create edges: RANGE of later sibling → predecessor (END or parent's RANGE)
66 pub fn new(sink: &Arc<UOp>) -> Self {
67 let mut ctx = Self::default();
68
69 // Collect all nodes via toposort
70 let nodes = sink.toposort();
71
72 // Step 1: Build dependency sets for each node
73 // RANGE and END add themselves to deps
74 // deps[u] = set of RANGE/END UOps that u transitively depends on
75 #[allow(clippy::mutable_key_type)]
76 let mut deps: HashMap<UOpKey, HashMap<UOpKey, ()>> = HashMap::new();
77
78 for node in &nodes {
79 // Get deps from sources
80 #[allow(clippy::mutable_key_type)]
81 let mut node_deps: HashMap<UOpKey, ()> = HashMap::new();
82 node.op().map_child(|src| {
83 if let Some(src_deps) = deps.get(&UOpKey(src.clone())) {
84 node_deps.extend(src_deps.iter().map(|(k, v)| (k.clone(), *v)));
85 }
86 });
87
88 // RANGE and END add themselves
89 if matches!(node.op(), Op::Range { .. } | Op::End { .. }) {
90 node_deps.insert(UOpKey(node.clone()), ());
91 }
92
93 deps.insert(UOpKey(node.clone()), node_deps);
94 }
95
96 // Step 2: Build nesting map
97 // For each END, find which END/SINK it is nested inside
98 // END x is nested in END/SINK u if:
99 // - u depends on x (x is in deps[u])
100 // - u is SINK, OR u's RANGE (u.src[1]) is in deps[x]
101 // - x hasn't been assigned a nesting parent yet
102 #[allow(clippy::mutable_key_type)]
103 let mut nesting: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
104
105 for node in &nodes {
106 if matches!(node.op(), Op::End { .. } | Op::Sink { .. })
107 && let Some(node_deps) = deps.get(&UOpKey(node.clone()))
108 {
109 for dep_key in node_deps.keys() {
110 // Only consider END nodes
111 if !matches!(dep_key.0.op(), Op::End { .. }) {
112 continue;
113 }
114
115 // Skip self-references (END cannot be nested inside itself)
116 if dep_key.0.id == node.id {
117 continue;
118 }
119
120 // Skip if already assigned
121 if nesting.contains_key(dep_key) {
122 continue;
123 }
124
125 // Check nesting condition
126 let is_nested = if matches!(node.op(), Op::Sink { .. }) {
127 true
128 } else if let Op::End { ranges, .. } = node.op() {
129 // Check if node's RANGE is in dep's dependencies
130 // node.src[1] in Tinygrad is the RANGE - we get it from ranges
131 if let Some(range) = ranges.first() {
132 deps.get(dep_key).is_some_and(|dep_deps| dep_deps.contains_key(&UOpKey(range.clone())))
133 } else {
134 false
135 }
136 } else {
137 false
138 };
139
140 if is_nested {
141 nesting.insert(dep_key.clone(), node.clone());
142 }
143 }
144 }
145 }
146
147 // Step 3: Group siblings by parent
148 #[allow(clippy::mutable_key_type)]
149 let mut siblings: HashMap<UOpKey, Vec<Arc<UOp>>> = HashMap::new();
150 for (end_key, parent) in &nesting {
151 siblings.entry(UOpKey(parent.clone())).or_default().push(end_key.0.clone());
152 }
153
154 // Step 4 & 5: Order siblings and create edges
155 for (parent, sibling_ends) in siblings {
156 if sibling_ends.is_empty() {
157 continue;
158 }
159
160 // Order by dependency count on other siblings (fewer deps = earlier)
161 let mut ordered: Vec<Arc<UOp>> = sibling_ends.clone();
162 ordered.sort_by_key(|end| {
163 if let Some(end_deps) = deps.get(&UOpKey(end.clone())) {
164 sibling_ends.iter().filter(|sib| end_deps.contains_key(&UOpKey((*sib).clone()))).count()
165 } else {
166 0
167 }
168 });
169
170 // Create edges
171 // If parent is SINK: zip(order, order[1:])
172 // If parent is END: zip([parent.src[1]] + order, order)
173 // where parent.src[1] is the parent's RANGE
174 let zipped: Vec<(Arc<UOp>, Arc<UOp>)> = if matches!(parent.0.op(), Op::Sink { .. }) {
175 // Pair consecutive siblings
176 ordered.windows(2).map(|w| (w[0].clone(), w[1].clone())).collect()
177 } else {
178 // Get parent's RANGE
179 if let Op::End { ranges, .. } = parent.0.op() {
180 if let Some(parent_range) = ranges.first() {
181 // Pair: parent_range → first, then consecutive siblings
182 let mut pairs = vec![(parent_range.clone(), ordered[0].clone())];
183 pairs.extend(ordered.windows(2).map(|w| (w[0].clone(), w[1].clone())));
184 pairs
185 } else {
186 ordered.windows(2).map(|w| (w[0].clone(), w[1].clone())).collect()
187 }
188 } else {
189 ordered.windows(2).map(|w| (w[0].clone(), w[1].clone())).collect()
190 }
191 };
192
193 // Create edges: y's RANGE → x (predecessor)
194 for (x, y) in zipped {
195 // y is an END, get its RANGE from y.src[1] (or ranges field)
196 let y_range = if let Op::End { ranges, .. } = y.op() { ranges.first().cloned() } else { None };
197
198 if let Some(range) = y_range {
199 // Tinygrad: assert y.src[1] not in x.backward_slice_with_self
200 // A cycle here indicates a malformed kernel structure.
201 assert!(
202 !x.backward_slice_ids().contains(&range.id),
203 "CFGContext: edge would create cycle (range {} → predecessor {}). \
204 This indicates a malformed kernel — see Tinygrad linearizer.py:81",
205 range.id,
206 x.id
207 );
208 tracing::trace!(range_id = range.id, predecessor_id = x.id, "CFGContext: creating edge");
209 ctx.edges.insert(UOpKey(range), x);
210 }
211 }
212 }
213
214 // Step 6: Create edges for reduce RANGEs to wait for their init STOREs.
215 //
216 // When AFTER(passthrough, [init_store, reduce_range]) appears (from reduce_to_acc),
217 // the heap-based linearization may schedule RANGE before STORE due to priority
218 // tie-breaking (RANGE +5 vs STORE +1). This edge ensures zero-init STORE appears
219 // before the reduce loop, which is required for correctness.
220 //
221 // NOTE: Tinygrad's linearizer.py lacks this fix and has known issues
222 // (see comment at line 85-86: "TODO: this can happen! it causes infinite loop
223 // in shufflenet"). Our explicit edge creation is more robust.
224 for node in &nodes {
225 if let Op::After { deps, .. } = node.op() {
226 let stores: Vec<_> = deps.iter().filter(|d| matches!(d.op(), Op::Store { .. })).collect();
227 let ranges: Vec<_> = deps.iter().filter(|d| matches!(d.op(), Op::Range { .. })).collect();
228
229 for store in &stores {
230 for range in &ranges {
231 let would_cycle = store.backward_slice_ids().contains(&range.id);
232 if !would_cycle {
233 tracing::trace!(range_id = range.id, store_id = store.id, "CFGContext: reduce init edge");
234 ctx.edges.insert(UOpKey((*range).clone()), (*store).clone());
235 }
236 }
237 }
238 }
239 }
240
241 ctx
242 }
243
244 /// Get the predecessor for a given RANGE operation.
245 ///
246 /// Returns `Some(predecessor)` if this RANGE has a sibling that must
247 /// execute before it, `None` if this is the first RANGE at its level.
248 pub fn get_predecessor(&self, range: &Arc<UOp>) -> Option<&Arc<UOp>> {
249 self.edges.get(&UOpKey(range.clone()))
250 }
251
252 /// Check if this context has any control flow edges.
253 ///
254 /// Returns `true` if there are sibling RANGEs that require ordering.
255 pub fn has_edges(&self) -> bool {
256 !self.edges.is_empty()
257 }
258
259 /// Get the number of control flow edges.
260 pub fn edge_count(&self) -> usize {
261 self.edges.len()
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use morok_dtype::DType;
    use morok_ir::types::ConstValue;

    #[test]
    fn test_cfg_context_single_range() {
        // Single RANGE should have no edges
        let end_val = UOp::index_const(10);
        let range = UOp::range(end_val, 0);
        let value = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let end = value.end(smallvec::smallvec![range]);
        let sink = UOp::sink(vec![end]);

        let ctx = CFGContext::new(&sink);
        assert!(!ctx.has_edges());
    }

    #[test]
    fn test_cfg_context_sibling_ranges() {
        // Two RANGEs closed by a single `end(...)` call. Edges are only created
        // between sibling ENDs, so this is not the sibling-loop case; that case
        // is covered by `test_cfg_context_sibling_ends`.
        let end_val = UOp::index_const(10);
        let range1 = UOp::range(end_val.clone(), 0);
        let range2 = UOp::range(end_val, 1);

        let value = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let end = value.end(smallvec::smallvec![range1.clone(), range2.clone()]);
        let sink = UOp::sink(vec![end]);

        let ctx = CFGContext::new(&sink);
        // NOTE(review): only an upper bound is checked. If `end` produces a single
        // END node this should be exactly 0 — tighten once that is confirmed.
        assert!(ctx.edge_count() <= 1);
    }

    #[test]
    fn test_cfg_context_sibling_ends() {
        // Two independent loops, each closed by its own END, both directly under
        // the SINK: exactly one of the two RANGEs must wait for the other loop's END.
        let end_val = UOp::index_const(10);
        let range1 = UOp::range(end_val.clone(), 0);
        let range2 = UOp::range(end_val, 1);

        let value = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let end1 = value.clone().end(smallvec::smallvec![range1.clone()]);
        let end2 = value.end(smallvec::smallvec![range2.clone()]);
        let sink = UOp::sink(vec![end1, end2]);

        let ctx = CFGContext::new(&sink);
        assert_eq!(ctx.edge_count(), 1);

        // Which loop goes first is an ordering detail; the test only pins that
        // exactly one RANGE received a predecessor.
        let first_waits = ctx.get_predecessor(&range1).is_some();
        let second_waits = ctx.get_predecessor(&range2).is_some();
        assert!(first_waits != second_waits);
    }

    #[test]
    fn test_cfg_context_nested_ranges() {
        // Nested RANGEs: inner loop runs inside outer loop.
        // For inner_end to be nested inside outer_end, inner_end must depend on outer_range.
        let end_val = UOp::index_const(10);

        // Outer range first (so inner can depend on it)
        let outer_range = UOp::range(end_val.clone(), 1);

        // Inner range
        let inner_range = UOp::range(end_val, 0);

        // Inner value that depends on outer_range (so it runs inside outer loop)
        // Use outer_range as part of the computation to create the dependency
        let outer_idx = outer_range.cast(DType::Float32);
        let inner_value = UOp::const_(DType::Float32, ConstValue::Float(1.0)).add(&outer_idx);
        let inner_end = inner_value.end(smallvec::smallvec![inner_range.clone()]);

        // Outer END
        let outer_end = inner_end.end(smallvec::smallvec![outer_range.clone()]);

        let sink = UOp::sink(vec![outer_end]);

        let ctx = CFGContext::new(&sink);
        // inner_end is nested inside outer_end (not siblings), so outer_range
        // should have no predecessor edge
        assert!(ctx.get_predecessor(&outer_range).is_none());
    }
}