1use std::sync::Arc;
4
5use morok_ir::{Op, UOp};
6
7pub fn is_output_buffer(def_global: &Arc<UOp>, nodes: &[Arc<UOp>]) -> bool {
9 let buffer_id = def_global.id;
10
11 for node in nodes {
12 if let Some(buffer) = node.store_buffer() {
13 if buffer.id == buffer_id {
14 return true;
15 }
16 if let Op::Index { buffer: idx_buf, .. } = buffer.op()
17 && idx_buf.id == buffer_id
18 {
19 return true;
20 }
21 }
22 }
23 false
24}
25
26pub fn collect_buffers_and_vars(root: &Arc<UOp>) -> (Vec<Arc<UOp>>, Vec<Arc<UOp>>) {
34 let nodes = root.toposort();
35
36 let mut buffers = Vec::new();
38 for node in &nodes {
39 match node.op() {
40 Op::Buffer { .. } | Op::Param { device: None, .. } | Op::DefineLocal(_) => {
41 buffers.push(node.clone());
42 }
43 _ => {}
44 }
45 }
46
47 buffers.sort_by_key(|b| match b.op() {
49 Op::Param { slot, device: None, .. } => *slot as u64,
50 Op::DefineLocal(id) => (*id as u64) + (1u64 << 32),
51 Op::Buffer { .. } => b.id + (1u64 << 48),
52 _ => b.id,
53 });
54
55 let mut variables = Vec::new();
57 for node in &nodes {
58 if matches!(node.op(), Op::DefineVar { .. }) {
59 variables.push(node.clone());
60 }
61 }
62
63 variables.sort_by_key(|v| if let Op::DefineVar { name, .. } = v.op() { name.clone() } else { String::new() });
65
66 (buffers, variables)
67}