Skip to main content

morok_codegen/
common.rs

1//! Common utilities shared between codegen backends.
2
3use std::sync::Arc;
4
5use morok_ir::{Op, UOp};
6
7/// Check whether a buffer (PARAM/DefineGlobal) is used as a STORE target in the graph.
8pub fn is_output_buffer(def_global: &Arc<UOp>, nodes: &[Arc<UOp>]) -> bool {
9    let buffer_id = def_global.id;
10
11    for node in nodes {
12        if let Some(buffer) = node.store_buffer() {
13            if buffer.id == buffer_id {
14                return true;
15            }
16            if let Op::Index { buffer: idx_buf, .. } = buffer.op()
17                && idx_buf.id == buffer_id
18            {
19                return true;
20            }
21        }
22    }
23    false
24}
25
26/// Collect buffer and variable parameters from a UOp graph.
27///
28/// Collects:
29/// - Buffers: PARAM, DEFINE_LOCAL operations
30/// - Variables: DEFINE_VAR operations (passed as i64 kernel params)
31///
32/// Returns (buffers, variables) sorted for deterministic function signatures.
33pub fn collect_buffers_and_vars(root: &Arc<UOp>) -> (Vec<Arc<UOp>>, Vec<Arc<UOp>>) {
34    let nodes = root.toposort();
35
36    // Collect buffers
37    let mut buffers = Vec::new();
38    for node in &nodes {
39        match node.op() {
40            Op::Buffer { .. } | Op::Param { device: None, .. } | Op::DefineLocal(_) => {
41                buffers.push(node.clone());
42            }
43            _ => {}
44        }
45    }
46
47    // Sort buffers by internal ID (matches split_kernel.rs ordering)
48    buffers.sort_by_key(|b| match b.op() {
49        Op::Param { slot, device: None, .. } => *slot as u64,
50        Op::DefineLocal(id) => (*id as u64) + (1u64 << 32),
51        Op::Buffer { .. } => b.id + (1u64 << 48),
52        _ => b.id,
53    });
54
55    // Collect DefineVar nodes - these become i64 kernel parameters
56    let mut variables = Vec::new();
57    for node in &nodes {
58        if matches!(node.op(), Op::DefineVar { .. }) {
59            variables.push(node.clone());
60        }
61    }
62
63    // Sort variables by name for deterministic function signatures
64    variables.sort_by_key(|v| if let Op::DefineVar { name, .. } = v.op() { name.clone() } else { String::new() });
65
66    (buffers, variables)
67}