use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use morok_ir::UOp;
use morok_ir::op::Op;
use morok_ir::uop::core::UOpKey;
/// Ordering edges derived from a UOp graph.
///
/// Instances are populated by [`CFGContext::new`]; a default-constructed value
/// holds no edges.
#[derive(Debug, Default)]
pub struct CFGContext {
    /// Maps a node (wrapped in [`UOpKey`]) to the node recorded as its
    /// predecessor. [`CFGContext::new`] only ever inserts range nodes as keys.
    //
    // NOTE(review): the lint is presumably silenced because `UOp` carries
    // interior mutability that `UOpKey`'s `Hash`/`Eq` ignore — confirm.
    #[allow(clippy::mutable_key_type)]
    pub edges: HashMap<UOpKey, Arc<UOp>>,
}
impl CFGContext {
    /// Builds the ordering edges for the graph rooted at `sink`.
    ///
    /// The construction runs in phases over `sink.toposort()`:
    ///
    /// 1. For every node, collect the set of `Range`/`End` nodes found in the
    ///    dependency sets of its children (plus the node itself when it is a
    ///    `Range` or `End`).
    /// 2. Assign every `End` to a parent: the first `End`/`Sink` in toposort
    ///    order that depends on it and — when the parent is an `End` — whose
    ///    first range is itself a dependency of the candidate. A `Sink` accepts
    ///    every still-unassigned `End` it depends on.
    /// 3. Group `End`s by parent and order each group by how many of its
    ///    siblings an `End` depends on (fewest first). Ties keep toposort order,
    ///    so the result does not depend on hash-map iteration order.
    /// 4. Chain each group: the first range of every `End` gets the previous
    ///    `End` in the group as predecessor. Under an `End` parent the chain is
    ///    seeded with the parent's first range; under a `Sink` the first `End`
    ///    of the group gets no edge.
    /// 5. For every `After` node, each `Range` dependency gets each `Store`
    ///    dependency as predecessor unless the range already appears in the
    ///    store's backward slice. These inserts may replace an edge recorded in
    ///    phase 4 for the same range.
    ///
    /// # Panics
    ///
    /// Panics in phase 4 if the range that is about to receive a predecessor
    /// already appears in that predecessor's backward slice, i.e. the edge
    /// would create a cycle.
    pub fn new(sink: &Arc<UOp>) -> Self {
        let mut ctx = Self::default();
        let nodes = sink.toposort();

        // Phase 1: per-node set of RANGE/END nodes it (transitively) depends on.
        // NOTE(review): relies on `toposort` yielding children before their
        // users; a child seen later would contribute nothing — confirm.
        #[allow(clippy::mutable_key_type)]
        let mut deps: HashMap<UOpKey, HashSet<UOpKey>> = HashMap::new();
        for node in &nodes {
            #[allow(clippy::mutable_key_type)]
            let mut node_deps: HashSet<UOpKey> = HashSet::new();
            node.op().map_child(|src| {
                if let Some(src_deps) = deps.get(&UOpKey(src.clone())) {
                    node_deps.extend(src_deps.iter().cloned());
                }
            });
            if matches!(node.op(), Op::Range { .. } | Op::End { .. }) {
                node_deps.insert(UOpKey(node.clone()));
            }
            deps.insert(UOpKey(node.clone()), node_deps);
        }

        // Phase 2: END -> enclosing END/SINK. The first qualifying parent in
        // toposort order wins; later parents skip ENDs that are already taken.
        #[allow(clippy::mutable_key_type)]
        let mut nesting: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
        for node in &nodes {
            let is_sink = matches!(node.op(), Op::Sink { .. });
            if !is_sink && !matches!(node.op(), Op::End { .. }) {
                continue;
            }
            // First range of an END parent; `None` for a SINK or a range-less END.
            let parent_range = match node.op() {
                Op::End { ranges, .. } => ranges.first().map(|range| UOpKey(range.clone())),
                _ => None,
            };
            let Some(node_deps) = deps.get(&UOpKey(node.clone())) else {
                continue;
            };
            for dep_key in node_deps {
                let is_candidate = matches!(dep_key.0.op(), Op::End { .. })
                    && dep_key.0.id != node.id
                    && !nesting.contains_key(dep_key);
                if !is_candidate {
                    continue;
                }
                let is_nested = is_sink
                    || parent_range.as_ref().is_some_and(|range_key| {
                        deps.get(dep_key).is_some_and(|dep_deps| dep_deps.contains(range_key))
                    });
                if is_nested {
                    nesting.insert(dep_key.clone(), node.clone());
                }
            }
        }

        // Phase 3: group ENDs by parent. Walking `nodes` (instead of the
        // `nesting` map) gives every sibling list a reproducible order; every
        // key of `nesting` originates from `nodes`, so nothing is dropped.
        #[allow(clippy::mutable_key_type)]
        let mut siblings: HashMap<UOpKey, Vec<Arc<UOp>>> = HashMap::new();
        for node in &nodes {
            if let Some(parent) = nesting.get(&UOpKey(node.clone())) {
                siblings.entry(UOpKey(parent.clone())).or_default().push(node.clone());
            }
        }

        // Phase 4: order each sibling group and chain it. Parents are visited in
        // toposort order so asserts and traces fire in a reproducible sequence.
        for parent in &nodes {
            let Some(sibling_ends) = siblings.get(&UOpKey(parent.clone())) else {
                continue;
            };

            // An END that depends on more of its siblings must run after them.
            // The sort is stable, so ties keep the toposort order from phase 3.
            let sibling_keys: Vec<UOpKey> = sibling_ends.iter().map(|end| UOpKey(end.clone())).collect();
            let mut ordered: Vec<Arc<UOp>> = sibling_ends.clone();
            ordered.sort_by_cached_key(|end| {
                deps.get(&UOpKey(end.clone()))
                    .map_or(0, |end_deps| sibling_keys.iter().filter(|sib| end_deps.contains(*sib)).count())
            });

            // Running predecessor: seeded with the parent's first range for an
            // END parent, empty for a SINK (its first END gets no edge).
            let mut predecessor: Option<Arc<UOp>> = match parent.op() {
                Op::End { ranges, .. } => ranges.first().cloned(),
                _ => None,
            };
            for end in ordered {
                if let Some(x) = predecessor.take()
                    && let Op::End { ranges, .. } = end.op()
                    && let Some(range) = ranges.first()
                {
                    assert!(
                        !x.backward_slice_ids().contains(&range.id),
                        "CFGContext: edge would create cycle (range {} → predecessor {}). \
                         This indicates a malformed kernel — see Tinygrad linearizer.py:81",
                        range.id,
                        x.id
                    );
                    tracing::trace!(range_id = range.id, predecessor_id = x.id, "CFGContext: creating edge");
                    ctx.edges.insert(UOpKey(range.clone()), x);
                }
                predecessor = Some(end);
            }
        }

        // Phase 5: within one AFTER node, make each RANGE dependency follow each
        // STORE dependency, unless the store already depends on that range.
        for node in &nodes {
            if let Op::After { deps: after_deps, .. } = node.op() {
                let ranges: Vec<_> = after_deps.iter().filter(|d| matches!(d.op(), Op::Range { .. })).collect();
                if ranges.is_empty() {
                    continue;
                }
                for store in after_deps.iter().filter(|d| matches!(d.op(), Op::Store { .. })) {
                    // Computed once per store rather than once per (store, range).
                    let store_slice = store.backward_slice_ids();
                    for range in &ranges {
                        let would_cycle = store_slice.contains(&range.id);
                        if !would_cycle {
                            tracing::trace!(range_id = range.id, store_id = store.id, "CFGContext: reduce init edge");
                            ctx.edges.insert(UOpKey((*range).clone()), store.clone());
                        }
                    }
                }
            }
        }

        ctx
    }

    /// Returns the predecessor recorded for `range`, or `None` if no edge has
    /// `range` as its key.
    pub fn get_predecessor(&self, range: &Arc<UOp>) -> Option<&Arc<UOp>> {
        self.edges.get(&UOpKey(range.clone()))
    }

    /// Returns `true` if at least one edge has been recorded.
    pub fn has_edges(&self) -> bool {
        !self.edges.is_empty()
    }

    /// Returns the number of recorded edges.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use morok_dtype::DType;
    use morok_ir::types::ConstValue;

    /// A single END closing a single range under the sink records no edges.
    #[test]
    fn test_cfg_context_single_range() {
        let bound = UOp::index_const(10);
        let loop_range = UOp::range(bound, 0);
        let body = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let loop_end = body.end(smallvec::smallvec![loop_range]);
        let root = UOp::sink(vec![loop_end]);

        let cfg = CFGContext::new(&root);

        assert!(!cfg.has_edges());
    }

    /// One END closing two ranges records at most one edge.
    #[test]
    fn test_cfg_context_sibling_ranges() {
        let bound = UOp::index_const(10);
        let first_range = UOp::range(bound.clone(), 0);
        let second_range = UOp::range(bound, 1);
        let body = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let loop_end = body.end(smallvec::smallvec![first_range, second_range]);
        let root = UOp::sink(vec![loop_end]);

        let cfg = CFGContext::new(&root);

        assert!(cfg.edge_count() <= 1);
    }

    /// With an inner END wrapped by an outer END, and the inner value reading
    /// the outer range, the outer range itself is given no predecessor.
    #[test]
    fn test_cfg_context_nested_ranges() {
        let bound = UOp::index_const(10);
        let outer = UOp::range(bound.clone(), 1);
        let inner = UOp::range(bound, 0);

        let outer_as_float = outer.cast(DType::Float32);
        let body = UOp::const_(DType::Float32, ConstValue::Float(1.0)).add(&outer_as_float);
        let closed_inner = body.end(smallvec::smallvec![inner]);
        let closed_outer = closed_inner.end(smallvec::smallvec![outer.clone()]);
        let root = UOp::sink(vec![closed_outer]);

        let cfg = CFGContext::new(&root);

        assert!(cfg.get_predecessor(&outer).is_none());
    }
}