Skip to main content

morok_schedule/linearize/
cfg_context.rs

//! Control flow graph context for linearization.
//!
//! `CFGContext` analyzes the control flow structure of a kernel AST and computes
//! ordering edges between sibling RANGE operations at the same nesting level.
//! It additionally orders reduce RANGEs after their init STOREs.
//!
//! The sibling-ordering part follows Tinygrad's `CFGContext` (linearizer.py:59-91).
7
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;

use morok_ir::UOp;
use morok_ir::op::Op;
use morok_ir::uop::core::UOpKey;
14
/// Control flow graph context for linearization.
///
/// Tracks ordering edges for RANGE operations so that loops at the same
/// nesting level are linearized in a well-defined sequence.
///
/// Based on Tinygrad's CFGContext which tracks three relationships between ranges:
/// - **nested**: END y is a dependency of END x AND RANGE x is a dependency of END y
/// - **dependent**: END y is a dependency of END x AND RANGE x is NOT a dependency of END y
/// - **independent**: END y is NOT a dependency of END x
///
/// # Control Flow Edges
///
/// When multiple ENDs exist at the same nesting level (siblings), they need to be
/// ordered consistently. `CFGContext::new` computes edges where:
/// - Each RANGE points to its predecessor: the parent's RANGE (for the first
///   loop nested inside an END), the END of the previous sibling loop, or the
///   init STORE that shares an AFTER node with the RANGE
/// - Edges ensure sequential execution of sibling loops
///
/// # Example
///
/// ```text
/// RANGE(i) → ... → END(i)   // first loop
/// RANGE(j) → ... → END(j)   // second loop (sibling)
/// RANGE(k) → ... → END(k)   // third loop (sibling)
///
/// CFGContext edges:
///   RANGE(j) → END(i)    // j's RANGE depends on i's END
///   RANGE(k) → END(j)    // k's RANGE depends on j's END
/// ```
#[derive(Debug, Default)]
pub struct CFGContext {
    /// Maps RANGE → predecessor (previous sibling END, parent's RANGE, or
    /// init STORE).
    ///
    /// The predecessor is the operation that must complete before
    /// this RANGE can begin execution.
    // NOTE(review): the lint is silenced because the key wraps an `Arc<UOp>`;
    // presumably any interior mutability in `UOp` does not take part in the
    // key's hash/equality — confirm against `UOpKey`.
    #[allow(clippy::mutable_key_type)]
    pub edges: HashMap<UOpKey, Arc<UOp>>,
}
52
53impl CFGContext {
54    /// Build a control flow context from a kernel AST.
55    ///
56    /// Analyzes the graph to find sibling ENDs at the same nesting level
57    /// and creates ordering edges between their RANGEs.
58    ///
59    /// # Algorithm (from Tinygrad linearizer.py:59-91)
60    ///
61    /// 1. Build transitive deps map (RANGE/END add themselves to deps)
62    /// 2. Build nesting map: which END/SINK nests each END
63    /// 3. Group siblings by parent
64    /// 4. Order siblings by dependency count (fewer deps = earlier)
65    /// 5. Create edges: RANGE of later sibling → predecessor (END or parent's RANGE)
66    pub fn new(sink: &Arc<UOp>) -> Self {
67        let mut ctx = Self::default();
68
69        // Collect all nodes via toposort
70        let nodes = sink.toposort();
71
72        // Step 1: Build dependency sets for each node
73        // RANGE and END add themselves to deps
74        // deps[u] = set of RANGE/END UOps that u transitively depends on
75        #[allow(clippy::mutable_key_type)]
76        let mut deps: HashMap<UOpKey, HashMap<UOpKey, ()>> = HashMap::new();
77
78        for node in &nodes {
79            // Get deps from sources
80            #[allow(clippy::mutable_key_type)]
81            let mut node_deps: HashMap<UOpKey, ()> = HashMap::new();
82            node.op().map_child(|src| {
83                if let Some(src_deps) = deps.get(&UOpKey(src.clone())) {
84                    node_deps.extend(src_deps.iter().map(|(k, v)| (k.clone(), *v)));
85                }
86            });
87
88            // RANGE and END add themselves
89            if matches!(node.op(), Op::Range { .. } | Op::End { .. }) {
90                node_deps.insert(UOpKey(node.clone()), ());
91            }
92
93            deps.insert(UOpKey(node.clone()), node_deps);
94        }
95
96        // Step 2: Build nesting map
97        // For each END, find which END/SINK it is nested inside
98        // END x is nested in END/SINK u if:
99        //   - u depends on x (x is in deps[u])
100        //   - u is SINK, OR u's RANGE (u.src[1]) is in deps[x]
101        //   - x hasn't been assigned a nesting parent yet
102        #[allow(clippy::mutable_key_type)]
103        let mut nesting: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
104
105        for node in &nodes {
106            if matches!(node.op(), Op::End { .. } | Op::Sink { .. })
107                && let Some(node_deps) = deps.get(&UOpKey(node.clone()))
108            {
109                for dep_key in node_deps.keys() {
110                    // Only consider END nodes
111                    if !matches!(dep_key.0.op(), Op::End { .. }) {
112                        continue;
113                    }
114
115                    // Skip self-references (END cannot be nested inside itself)
116                    if dep_key.0.id == node.id {
117                        continue;
118                    }
119
120                    // Skip if already assigned
121                    if nesting.contains_key(dep_key) {
122                        continue;
123                    }
124
125                    // Check nesting condition
126                    let is_nested = if matches!(node.op(), Op::Sink { .. }) {
127                        true
128                    } else if let Op::End { ranges, .. } = node.op() {
129                        // Check if node's RANGE is in dep's dependencies
130                        // node.src[1] in Tinygrad is the RANGE - we get it from ranges
131                        if let Some(range) = ranges.first() {
132                            deps.get(dep_key).is_some_and(|dep_deps| dep_deps.contains_key(&UOpKey(range.clone())))
133                        } else {
134                            false
135                        }
136                    } else {
137                        false
138                    };
139
140                    if is_nested {
141                        nesting.insert(dep_key.clone(), node.clone());
142                    }
143                }
144            }
145        }
146
147        // Step 3: Group siblings by parent
148        #[allow(clippy::mutable_key_type)]
149        let mut siblings: HashMap<UOpKey, Vec<Arc<UOp>>> = HashMap::new();
150        for (end_key, parent) in &nesting {
151            siblings.entry(UOpKey(parent.clone())).or_default().push(end_key.0.clone());
152        }
153
154        // Step 4 & 5: Order siblings and create edges
155        for (parent, sibling_ends) in siblings {
156            if sibling_ends.is_empty() {
157                continue;
158            }
159
160            // Order by dependency count on other siblings (fewer deps = earlier)
161            let mut ordered: Vec<Arc<UOp>> = sibling_ends.clone();
162            ordered.sort_by_key(|end| {
163                if let Some(end_deps) = deps.get(&UOpKey(end.clone())) {
164                    sibling_ends.iter().filter(|sib| end_deps.contains_key(&UOpKey((*sib).clone()))).count()
165                } else {
166                    0
167                }
168            });
169
170            // Create edges
171            // If parent is SINK: zip(order, order[1:])
172            // If parent is END: zip([parent.src[1]] + order, order)
173            //   where parent.src[1] is the parent's RANGE
174            let zipped: Vec<(Arc<UOp>, Arc<UOp>)> = if matches!(parent.0.op(), Op::Sink { .. }) {
175                // Pair consecutive siblings
176                ordered.windows(2).map(|w| (w[0].clone(), w[1].clone())).collect()
177            } else {
178                // Get parent's RANGE
179                if let Op::End { ranges, .. } = parent.0.op() {
180                    if let Some(parent_range) = ranges.first() {
181                        // Pair: parent_range → first, then consecutive siblings
182                        let mut pairs = vec![(parent_range.clone(), ordered[0].clone())];
183                        pairs.extend(ordered.windows(2).map(|w| (w[0].clone(), w[1].clone())));
184                        pairs
185                    } else {
186                        ordered.windows(2).map(|w| (w[0].clone(), w[1].clone())).collect()
187                    }
188                } else {
189                    ordered.windows(2).map(|w| (w[0].clone(), w[1].clone())).collect()
190                }
191            };
192
193            // Create edges: y's RANGE → x (predecessor)
194            for (x, y) in zipped {
195                // y is an END, get its RANGE from y.src[1] (or ranges field)
196                let y_range = if let Op::End { ranges, .. } = y.op() { ranges.first().cloned() } else { None };
197
198                if let Some(range) = y_range {
199                    // Tinygrad: assert y.src[1] not in x.backward_slice_with_self
200                    // A cycle here indicates a malformed kernel structure.
201                    assert!(
202                        !x.backward_slice_ids().contains(&range.id),
203                        "CFGContext: edge would create cycle (range {} → predecessor {}). \
204                         This indicates a malformed kernel — see Tinygrad linearizer.py:81",
205                        range.id,
206                        x.id
207                    );
208                    tracing::trace!(range_id = range.id, predecessor_id = x.id, "CFGContext: creating edge");
209                    ctx.edges.insert(UOpKey(range), x);
210                }
211            }
212        }
213
214        // Step 6: Create edges for reduce RANGEs to wait for their init STOREs.
215        //
216        // When AFTER(passthrough, [init_store, reduce_range]) appears (from reduce_to_acc),
217        // the heap-based linearization may schedule RANGE before STORE due to priority
218        // tie-breaking (RANGE +5 vs STORE +1). This edge ensures zero-init STORE appears
219        // before the reduce loop, which is required for correctness.
220        //
221        // NOTE: Tinygrad's linearizer.py lacks this fix and has known issues
222        // (see comment at line 85-86: "TODO: this can happen! it causes infinite loop
223        // in shufflenet"). Our explicit edge creation is more robust.
224        for node in &nodes {
225            if let Op::After { deps, .. } = node.op() {
226                let stores: Vec<_> = deps.iter().filter(|d| matches!(d.op(), Op::Store { .. })).collect();
227                let ranges: Vec<_> = deps.iter().filter(|d| matches!(d.op(), Op::Range { .. })).collect();
228
229                for store in &stores {
230                    for range in &ranges {
231                        let would_cycle = store.backward_slice_ids().contains(&range.id);
232                        if !would_cycle {
233                            tracing::trace!(range_id = range.id, store_id = store.id, "CFGContext: reduce init edge");
234                            ctx.edges.insert(UOpKey((*range).clone()), (*store).clone());
235                        }
236                    }
237                }
238            }
239        }
240
241        ctx
242    }
243
244    /// Get the predecessor for a given RANGE operation.
245    ///
246    /// Returns `Some(predecessor)` if this RANGE has a sibling that must
247    /// execute before it, `None` if this is the first RANGE at its level.
248    pub fn get_predecessor(&self, range: &Arc<UOp>) -> Option<&Arc<UOp>> {
249        self.edges.get(&UOpKey(range.clone()))
250    }
251
252    /// Check if this context has any control flow edges.
253    ///
254    /// Returns `true` if there are sibling RANGEs that require ordering.
255    pub fn has_edges(&self) -> bool {
256        !self.edges.is_empty()
257    }
258
259    /// Get the number of control flow edges.
260    pub fn edge_count(&self) -> usize {
261        self.edges.len()
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use morok_dtype::DType;
    use morok_ir::types::ConstValue;

    #[test]
    fn test_cfg_context_single_range() {
        // Single RANGE should have no edges
        let end_val = UOp::index_const(10);
        let range = UOp::range(end_val, 0);
        let value = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let end = value.end(smallvec::smallvec![range]);
        let sink = UOp::sink(vec![end]);

        let ctx = CFGContext::new(&sink);
        assert!(!ctx.has_edges());
    }

    #[test]
    fn test_cfg_context_multi_range_end() {
        // One END closing two RANGEs: there is only a single END under the
        // SINK, so there are no siblings to order and no edges are created.
        let end_val = UOp::index_const(10);
        let range1 = UOp::range(end_val.clone(), 0);
        let range2 = UOp::range(end_val, 1);

        let value = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let end = value.end(smallvec::smallvec![range1.clone(), range2.clone()]);
        let sink = UOp::sink(vec![end]);

        let ctx = CFGContext::new(&sink);
        assert_eq!(ctx.edge_count(), 0);
        assert!(ctx.get_predecessor(&range1).is_none());
        assert!(ctx.get_predecessor(&range2).is_none());
    }

    #[test]
    fn test_cfg_context_sibling_ranges() {
        // Two independent loops directly under the SINK are siblings:
        // exactly one of them must wait for the other's END.
        let end_val = UOp::index_const(10);
        let range1 = UOp::range(end_val.clone(), 0);
        let range2 = UOp::range(end_val, 1);

        let value = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let end1 = value.end(smallvec::smallvec![range1.clone()]);
        let end2 = value.end(smallvec::smallvec![range2.clone()]);
        let sink = UOp::sink(vec![end1.clone(), end2.clone()]);

        let ctx = CFGContext::new(&sink);
        assert_eq!(ctx.edge_count(), 1);

        // The two loops are independent, so either order is valid; only check
        // that the edge links one loop's RANGE to the other loop's END.
        match (ctx.get_predecessor(&range1), ctx.get_predecessor(&range2)) {
            (Some(pred), None) => assert!(pred.id == end2.id),
            (None, Some(pred)) => assert!(pred.id == end1.id),
            _ => panic!("expected exactly one sibling RANGE to have a predecessor"),
        }
    }

    #[test]
    fn test_cfg_context_nested_ranges() {
        // Nested RANGEs: inner loop runs inside outer loop.
        // For inner_end to be nested inside outer_end, inner_end must depend on outer_range.
        let end_val = UOp::index_const(10);

        // Outer range first (so inner can depend on it)
        let outer_range = UOp::range(end_val.clone(), 1);

        // Inner range
        let inner_range = UOp::range(end_val, 0);

        // Inner value that depends on outer_range (so it runs inside outer loop)
        // Use outer_range as part of the computation to create the dependency
        let outer_idx = outer_range.cast(DType::Float32);
        let inner_value = UOp::const_(DType::Float32, ConstValue::Float(1.0)).add(&outer_idx);
        let inner_end = inner_value.end(smallvec::smallvec![inner_range.clone()]);

        // Outer END
        let outer_end = inner_end.end(smallvec::smallvec![outer_range.clone()]);

        let sink = UOp::sink(vec![outer_end]);

        let ctx = CFGContext::new(&sink);
        // inner_end is nested inside outer_end (not siblings), so outer_range
        // should have no predecessor edge
        assert!(ctx.get_predecessor(&outer_range).is_none());
        // The first loop nested inside an END waits for that END's RANGE.
        assert!(ctx.get_predecessor(&inner_range).is_some_and(|pred| pred.id == outer_range.id));
    }
}
328}