mod cfg_context;
#[allow(clippy::module_inception)]
mod linearize;
use std::collections::HashMap;
use std::sync::Arc;
use morok_ir::UOp;
use morok_ir::op::Op;
use morok_ir::pattern::TypedPatternMatcher;
use morok_ir::rewrite::{graph_rewrite, graph_rewrite_bottom_up};
use smallvec::{SmallVec, smallvec};
pub use cfg_context::CFGContext;
pub use linearize::linearize;
/// Returns the pattern matcher that rewrites `End` ops via [`split_end`].
///
/// The matcher has a single rule: any `End { computation, ranges }` node is
/// handed to `split_end`, whose `Option` result is used as the rule's result.
///
/// NOTE(review): the `&'static` return type suggests `crate::cached_patterns!`
/// builds the matcher once and caches it — confirm against the macro definition.
pub fn pm_split_ends() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Delegate every END node to `split_end`.
End { computation, ranges } => |computation, ranges| {
split_end(computation, ranges)
},
}
}
/// Splits an END over several ranges into a chain of single-range ENDs.
///
/// The ranges actually closed are not taken from `ranges` directly: a
/// temporary SINK is built over `ranges` and its `ranges()` set is used.
///
/// # Returns
/// * `None` when there is nothing to rewrite: either no range was found, or
///   the END already closes exactly the single range that was found.
/// * `Some(node)` otherwise, where `node` is `computation` wrapped in one END
///   per range. Ranges are closed in descending `(axis_id, axis_type priority)`
///   order, so the range with the largest key becomes the innermost END and
///   the one with the smallest key the outermost.
///
/// # Panics
/// Panics (via `unreachable!`) if `ranges()` yields a node that is not
/// `Op::Range`; presumably that accessor only ever reports RANGE ops.
fn split_end(computation: &Arc<UOp>, ranges: &SmallVec<[Arc<UOp>; 4]>) -> Option<Arc<UOp>> {
    let sink = UOp::sink(ranges.iter().cloned().collect());
    let mut actual_ranges = sink.ranges().clone();

    if actual_ranges.is_empty() {
        return None;
    }

    if actual_ranges.len() == 1 {
        let only = &actual_ranges[0];
        // Decide whether this is a no-op *before* building the replacement so
        // that no unused END node is allocated on the (common) no-op path.
        if ranges.len() == 1 && ranges[0].id == only.id {
            return None;
        }
        return Some(computation.end(SmallVec::from_elem(only.clone(), 1)));
    }

    // Ordering key of a RANGE node: axis id first, axis-type priority second.
    let sort_key = |range: &Arc<UOp>| match range.op() {
        Op::Range { axis_id, axis_type, .. } => (axis_id.value(), axis_type.priority()),
        _ => unreachable!("filtered to RANGEs only"),
    };
    // Descending order: comparing `b` against `a` reverses the natural order.
    actual_ranges.sort_by(|a, b| sort_key(b).cmp(&sort_key(a)));

    // Wrap the computation in one END per range, innermost first.
    let mut result = computation.clone();
    for range in actual_ranges {
        result = result.end(SmallVec::from_elem(range, 1));
    }
    Some(result)
}
/// Builds the pattern matcher that appends a control-flow predecessor to
/// `Range` ops.
///
/// For every matched `Range`, the rule asks the `CFGContext` for that range's
/// predecessor. When there is none, the `?` makes the rule yield `None`.
/// Otherwise the predecessor is pushed onto the end of the range's existing
/// sources and the range is rebuilt with `with_sources`.
///
/// NOTE(review): `ctx` is not declared in this block; it is presumably bound
/// by `crate::patterns!` from the `@context CFGContext;` line — confirm.
fn pm_add_control_flow() -> TypedPatternMatcher<CFGContext> {
crate::patterns! {
@context CFGContext;
range @ Range { end: _, .. } => {
// No predecessor recorded for this range: bail out of the rule.
let pred = ctx.get_predecessor(range)?;
// Keep the existing sources and add the predecessor as the last one.
let mut srcs = range.op().sources().to_vec();
srcs.push(pred.clone());
Some(range.with_sources(srcs))
},
}
}
/// Linearizes the graph rooted at `sink` after preparing its control flow.
///
/// The pipeline has three stages:
/// 1. rewrite the graph with [`pm_split_ends`];
/// 2. build a [`CFGContext`] from the rewritten graph and run a bottom-up
///    rewrite with `pm_add_control_flow` using that context;
/// 3. pass the result to [`linearize`] and return its output.
pub fn linearize_with_cfg(sink: Arc<UOp>) -> Vec<Arc<UOp>> {
    // Stage 1: END splitting (the matcher needs no context, hence `()`).
    let split_graph = graph_rewrite(pm_split_ends(), sink, &mut ());

    // Stage 2: the context is derived from the graph *after* END splitting.
    let mut cfg_context = CFGContext::new(&split_graph);
    let with_control_flow = graph_rewrite_bottom_up(&pm_add_control_flow(), split_graph, &mut cfg_context);

    // Stage 3: produce the final instruction order.
    linearize(with_control_flow)
}
/// Rewrites a linear op list one op at a time, in order.
///
/// Each op first has its sources patched with the replacements recorded so
/// far, then `rewrite_fn` is called with the patched op and the current
/// replacement map:
/// * `Some((replacement, outputs))` — `outputs` are appended to the result and
///   `replacement` is recorded as the stand-in for the original op;
/// * `None` — the patched op is appended as-is and recorded as its own
///   stand-in.
///
/// Replacements are keyed by the `id` of the *original* (unpatched) op.
fn line_rewrite<F>(lst: Vec<Arc<UOp>>, rewrite_fn: F) -> Vec<Arc<UOp>>
where
    F: Fn(&Arc<UOp>, &HashMap<u64, Arc<UOp>>) -> Option<(Arc<UOp>, Vec<Arc<UOp>>)>,
{
    let mut emitted: Vec<Arc<UOp>> = Vec::with_capacity(lst.len() * 2);
    let mut substitutions: HashMap<u64, Arc<UOp>> = HashMap::new();

    for original in lst {
        let patched = replace_sources_from_map(&original, &substitutions);

        let stand_in = if let Some((repl, outs)) = rewrite_fn(&patched, &substitutions) {
            emitted.extend(outs);
            repl
        } else {
            // No rule fired: the patched op is both emitted and recorded.
            emitted.push(Arc::clone(&patched));
            patched
        };

        substitutions.insert(original.id, stand_in);
    }

    emitted
}
/// Returns `uop` with each source swapped for its entry in `replaced`.
///
/// Sources are looked up by `id`; a source without an entry is kept. If the
/// op has no sources, or no source ends up with a different `id`, the original
/// `Arc` is returned (cloned) and no new node is built.
fn replace_sources_from_map(uop: &Arc<UOp>, replaced: &HashMap<u64, Arc<UOp>>) -> Arc<UOp> {
    let old_sources = uop.op().sources();
    if old_sources.is_empty() {
        return uop.clone();
    }

    let mut any_changed = false;
    let mut new_sources: Vec<Arc<UOp>> = Vec::with_capacity(old_sources.len());
    for src in old_sources.iter() {
        let mapped = match replaced.get(&src.id) {
            Some(substitute) => substitute.clone(),
            None => src.clone(),
        };
        // Track whether at least one source really points somewhere new.
        any_changed |= mapped.id != src.id;
        new_sources.push(mapped);
    }

    if !any_changed {
        return uop.clone();
    }
    uop.replace().src(new_sources).call()
}
/// Per-op rewrite used by `line_rewrite_cleanups`: turns a store through a
/// gated index into an explicit IF / STORE / ENDIF sequence.
///
/// Only `Op::Store` whose index is a gated `Op::Index` (optionally wrapped in
/// an `Op::Cast`) is rewritten; every other op yields `None`.
///
/// On a match the function returns `(replacement, outputs)` where
/// `replacement` is the new ungated store and `outputs` is, in order:
/// the IF on the gate, the ungated index, the cast of that index (only when
/// the original index was cast), the ungated store, and the ENDIF.
///
/// # Panics
/// Panics if `uop` is already an `If` or `EndIf`.
fn linearize_cleanup_pattern(uop: &Arc<UOp>, _replaced: &HashMap<u64, Arc<UOp>>) -> Option<(Arc<UOp>, Vec<Arc<UOp>>)> {
    let (index, value, ranges) = match uop.op() {
        Op::If { .. } | Op::EndIf { .. } => {
            panic!("IF/ENDIF not allowed in graph before line_rewrite_cleanups")
        }
        Op::Store { index, value, ranges } => (index, value, ranges),
        _ => return None,
    };

    // Look through an optional cast, remembering the target dtype.
    let (gated_index, cast_to) = if let Op::Cast { src, dtype } = index.op() {
        (src, Some(dtype.clone()))
    } else {
        (index, None)
    };

    let (buffer, indices, gate) = match gated_index.op() {
        Op::Index { buffer, indices, gate: Some(gate) } => (buffer, indices, gate),
        _ => return None,
    };

    // Same index, with the gate removed.
    let bare_index =
        UOp::new(Op::Index { buffer: buffer.clone(), indices: indices.clone(), gate: None }, gated_index.dtype());
    // Re-apply the cast if the original index had one.
    let store_target = match &cast_to {
        Some(dtype) => bare_index.cast(dtype.clone()),
        None => bare_index.clone(),
    };
    let plain_store = store_target.store_with_ranges(value.clone(), ranges.clone());
    let open_if = UOp::if_(gate.clone(), smallvec![bare_index.clone()]);
    let close_if = UOp::endif(open_if.clone());

    let mut emitted = Vec::with_capacity(5);
    emitted.extend([open_if, bare_index]);
    if cast_to.is_some() {
        emitted.push(store_target);
    }
    emitted.extend([plain_store.clone(), close_if]);

    Some((plain_store, emitted))
}
/// Applies the cleanup rewrite to a linear op list.
///
/// Runs `line_rewrite` over `lst` with `linearize_cleanup_pattern` as the
/// per-op rewrite function and returns the resulting list.
///
/// # Panics
/// Panics if `linearize_cleanup_pattern` is handed an `If` or `EndIf` op.
pub fn line_rewrite_cleanups(lst: Vec<Arc<UOp>>) -> Vec<Arc<UOp>> {
line_rewrite(lst, linearize_cleanup_pattern)
}