Skip to main content

morok_codegen/
common.rs

1//! Common utilities shared between codegen backends.
2
3use std::sync::Arc;
4
5use morok_ir::{Op, UOp};
6
7/// Collect buffer and variable parameters from a UOp graph.
8///
9/// Collects:
10/// - Buffers: DEFINE_GLOBAL, DEFINE_LOCAL, BUFFER operations
11/// - Variables: DEFINE_VAR operations (passed as i64 kernel params)
12///
13/// Returns (buffers, variables) sorted for deterministic function signatures.
14pub fn collect_buffers_and_vars(root: &Arc<UOp>) -> (Vec<Arc<UOp>>, Vec<Arc<UOp>>) {
15    let nodes = root.toposort();
16
17    // Collect buffers
18    let mut buffers = Vec::new();
19    for node in &nodes {
20        match node.op() {
21            Op::Buffer { .. } | Op::DefineGlobal(_) | Op::DefineLocal(_) => {
22                buffers.push(node.clone());
23            }
24            _ => {}
25        }
26    }
27
28    // Sort buffers by internal ID (matches split_kernel.rs ordering)
29    buffers.sort_by_key(|b| match b.op() {
30        Op::DefineGlobal(id) => *id as u64,
31        Op::DefineLocal(id) => (*id as u64) + (1u64 << 32),
32        Op::Buffer { .. } => b.id + (1u64 << 48),
33        _ => b.id,
34    });
35
36    // Collect DefineVar nodes - these become i64 kernel parameters
37    let mut variables = Vec::new();
38    for node in &nodes {
39        if matches!(node.op(), Op::DefineVar { .. }) {
40            variables.push(node.clone());
41        }
42    }
43
44    // Sort variables by name for deterministic function signatures
45    variables.sort_by_key(|v| if let Op::DefineVar { name, .. } = v.op() { name.clone() } else { String::new() });
46
47    (buffers, variables)
48}