Skip to main content

morok_codegen/c/
ops.rs

1//! C source code rendering for individual UOp operations.
2//!
3//! Generates C expressions/statements for each Op variant.
4//! Uses SSA inlining: single-use values are inlined as expressions,
5//! multi-use values get local variable declarations.
6
7use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9
10use morok_dtype::{DType, ScalarDType};
11use morok_ir::{AxisType, BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
12
13use super::types::{c_cast, c_const, c_dtype, c_math_fn};
14
15/// Context for C code generation, tracking variable names and SSA inlining.
16pub struct CContext {
17    /// UOp ID -> C expression or variable name
18    names: HashMap<u64, String>,
19    /// UOp ID -> reference count (how many times used)
20    ref_counts: HashMap<u64, usize>,
21    /// Variable counter for generating unique names
22    counter: usize,
23    /// Current indentation depth
24    depth: usize,
25    /// Pending reduce accumulator info: reduce_id -> (acc_name, dtype)
26    pending_reduces: HashMap<u64, (String, DType)>,
27    /// UOp IDs that escape their declaration scope — need function-scope declaration.
28    scope_escaping: HashSet<u64>,
29    /// Function-scope declarations for hoisted variables (emitted before kernel body).
30    pub hoisted_declarations: Vec<String>,
31}
32
33impl CContext {
34    pub fn new(ref_counts: HashMap<u64, usize>, scope_escaping: HashSet<u64>) -> Self {
35        Self {
36            names: HashMap::new(),
37            ref_counts,
38            counter: 0,
39            depth: 1,
40            pending_reduces: HashMap::new(),
41            scope_escaping,
42            hoisted_declarations: Vec::new(),
43        }
44    }
45
46    /// Get the C expression for a UOp. Panics if not registered.
47    pub fn get(&self, uop: &Arc<UOp>) -> &str {
48        self.names
49            .get(&uop.id)
50            .map(|s| s.as_str())
51            .unwrap_or_else(|| panic!("UOp {} ({}) not in C context", uop.id, uop.op().as_ref()))
52    }
53
54    /// Register a name/expression for a UOp ID.
55    pub fn register(&mut self, id: u64, expr: String) {
56        self.names.insert(id, expr);
57    }
58
59    /// Check if a value should be inlined (single-use, expression-safe).
60    pub fn should_inline(&self, id: u64) -> bool {
61        self.ref_counts.get(&id).copied().unwrap_or(0) <= 1
62    }
63
64    /// Generate a unique variable name with given prefix.
65    pub fn next_name(&mut self, prefix: &str) -> String {
66        let name = format!("{}{}", prefix, self.counter);
67        self.counter += 1;
68        name
69    }
70
71    /// Get current indentation string.
72    pub fn indent(&self) -> String {
73        "  ".repeat(self.depth)
74    }
75
76    /// Increase indentation depth.
77    pub fn push_indent(&mut self) {
78        self.depth += 1;
79    }
80
81    /// Decrease indentation depth.
82    pub fn pop_indent(&mut self) {
83        self.depth = self.depth.saturating_sub(1);
84    }
85
86    /// Register a pending reduce final load.
87    pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_name: String, dtype: DType) {
88        self.pending_reduces.insert(reduce_id, (acc_name, dtype));
89    }
90
91    /// Take all pending reduces.
92    pub fn take_pending_reduces(&mut self) -> HashMap<u64, (String, DType)> {
93        std::mem::take(&mut self.pending_reduces)
94    }
95
96    /// Emit a C expression, either as an inline expression or a variable declaration.
97    /// Returns the name/expression to reference this value.
98    ///
99    /// Variables that escape their declaration scope are hoisted: declared at function
100    /// scope and assigned at current depth. This prevents "use of undeclared identifier"
101    /// errors when the linearizer places a shared node inside a loop but consumers exist
102    /// outside the loop.
103    pub fn emit_expr(&mut self, uop: &Arc<UOp>, expr: String, prefix: &str, kernel: &mut Vec<String>) -> String {
104        if self.should_inline(uop.id) {
105            self.register(uop.id, expr.clone());
106            expr
107        } else {
108            let name = self.next_name(prefix);
109            let dtype = c_dtype(&uop.dtype());
110            let indent = self.indent();
111            if self.scope_escaping.contains(&uop.id) {
112                // Hoist: declare at function scope, assign at current depth
113                self.hoisted_declarations.push(format!("  {dtype} {name};"));
114                kernel.push(format!("{indent}{name} = {expr};"));
115            } else {
116                kernel.push(format!("{indent}{dtype} {name} = {expr};"));
117            }
118            self.register(uop.id, name.clone());
119            name
120        }
121    }
122}
123
124/// Render a single UOp to C source code.
125///
126/// Returns `Some(())` if code was emitted, `None` for meta-ops.
127pub fn render_uop(uop: &Arc<UOp>, ctx: &mut CContext, kernel: &mut Vec<String>) -> Option<()> {
128    match uop.op() {
129        // Meta-ops: no code emitted
130        Op::Const(_)
131        | Op::VConst { .. }
132        | Op::Param { device: None, .. }
133        | Op::DefineLocal(_)
134        | Op::DefineVar { .. }
135        | Op::Noop
136        | Op::Sink { .. }
137        | Op::Group { .. }
138        | Op::Buffer { .. }
139        | Op::Unique(_)
140        | Op::Device(_)
141        | Op::Kernel { .. }
142        | Op::Barrier { .. } => None,
143
144        Op::DefineReg { .. } => {
145            // Read base type and size from dtype (matching Tinygrad's x.dtype.base/x.dtype.size).
146            // After devectorize's no_vectorized_buf, the dtype is the canonical source of truth:
147            // e.g. Ptr(base=Float32, size=35) instead of the Op's original size field.
148            let (base_dtype, alloc_size) = match uop.dtype() {
149                DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
150                other => (other, 1),
151            };
152            let name = ctx.next_name("reg");
153            let indent = ctx.indent();
154            kernel.push(format!("{indent}{} {name}[{alloc_size}];", c_dtype(&base_dtype)));
155            ctx.register(uop.id, name);
156            Some(())
157        }
158
159        Op::Index { buffer, indices, .. } => {
160            let buf = ctx.get(buffer).to_string();
161
162            if indices.is_empty() {
163                // No index - just alias the buffer pointer
164                ctx.register(uop.id, buf);
165            } else {
166                // Multi-index: linearize at render time using row-major strides
167                let idx = if indices.len() > 1 {
168                    render_linearize_multi_index_c(indices, ctx)
169                } else {
170                    ctx.get(&indices[0]).to_string()
171                };
172                let expr = format!("{buf} + {idx}");
173                ctx.emit_expr(uop, expr, "idx", kernel);
174            }
175            Some(())
176        }
177
178        Op::PointerIndex { ptr, offset } => {
179            let ptr_val = ctx.get(ptr).to_string();
180            let off_val = ctx.get(offset).to_string();
181            let expr = format!("{ptr_val} + {off_val}");
182            ctx.emit_expr(uop, expr, "pidx", kernel);
183            Some(())
184        }
185
186        Op::Load { index, alt, .. } => {
187            let idx = ctx.get(index).to_string();
188            let load_dtype = uop.dtype();
189            // Check if the INDEX source has a gate — render conditional load to avoid null deref.
190            // Tinygrad: LOAD with gated INDEX → (gate ? *(index) : alt_value)
191            let gate_expr = if let Op::Index { gate: Some(gate_uop), .. } = index.op() {
192                Some(ctx.get(gate_uop).to_string())
193            } else {
194                None
195            };
196            let deref_expr = if load_dtype.vcount() > 1 {
197                let cast_type = c_dtype(&load_dtype);
198                format!("*(({cast_type}*)({idx}))")
199            } else {
200                format!("*({idx})")
201            };
202            let expr = if let Some(gate) = gate_expr {
203                // Use the LOAD's alt value if present, otherwise default to zero
204                let alt_expr = if let Some(alt_uop) = alt {
205                    ctx.get(alt_uop).to_string()
206                } else {
207                    c_const(&morok_ir::types::ConstValue::zero(load_dtype.base()), &load_dtype)
208                };
209                format!("({gate} ? {deref_expr} : {alt_expr})")
210            } else {
211                deref_expr
212            };
213            ctx.emit_expr(uop, expr, "val", kernel);
214            Some(())
215        }
216
217        Op::Store { index, value, .. } => {
218            let idx = ctx.get(index).to_string();
219            let val = ctx.get(value).to_string();
220            let indent = ctx.indent();
221            let val_dtype = value.dtype();
222            // Buffer pointers are declared as scalar types (e.g., float*) in C,
223            // so vector stores need an explicit pointer cast.
224            if val_dtype.vcount() > 1 {
225                let cast_type = c_dtype(&val_dtype);
226                kernel.push(format!("{indent}*(({cast_type}*)({idx})) = {val};"));
227            } else {
228                kernel.push(format!("{indent}*({idx}) = {val};"));
229            }
230            Some(())
231        }
232
233        Op::Binary(op, lhs, rhs) => {
234            let l = ctx.get(lhs).to_string();
235            let r = ctx.get(rhs).to_string();
236            let expr = render_binary(*op, &l, &r, &lhs.dtype());
237            ctx.emit_expr(uop, expr, "alu", kernel);
238            Some(())
239        }
240
241        Op::Unary(op, src) => {
242            let s = ctx.get(src).to_string();
243            let expr = render_unary(*op, &s, &src.dtype());
244            ctx.emit_expr(uop, expr, "alu", kernel);
245            Some(())
246        }
247
248        Op::Ternary(TernaryOp::Where, cond, t, f) => {
249            let c = ctx.get(cond).to_string();
250            let tv = ctx.get(t).to_string();
251            let fv = ctx.get(f).to_string();
252            let expr = format!("({c} ? {tv} : {fv})");
253            ctx.emit_expr(uop, expr, "alu", kernel);
254            Some(())
255        }
256
257        Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
258            let av = ctx.get(a).to_string();
259            let bv = ctx.get(b).to_string();
260            let cv = ctx.get(c).to_string();
261            let expr = if a.dtype().is_float() {
262                format!("{}({av}, {bv}, {cv})", c_math_fn("__builtin_fma", &a.dtype()))
263            } else {
264                format!("(({av} * {bv}) + {cv})")
265            };
266            ctx.emit_expr(uop, expr, "alu", kernel);
267            Some(())
268        }
269
270        Op::Cast { src, dtype } => {
271            let s = ctx.get(src).to_string();
272
273            // INDEX to Ptr is a no-op in C (INDEX already produces a pointer)
274            if matches!(src.op(), Op::Index { .. }) && matches!(dtype, DType::Ptr { .. }) {
275                ctx.register(uop.id, s);
276                return Some(());
277            }
278
279            // Vector casts use __builtin_convertvector for element-wise conversion
280            // (a plain C cast would reinterpret bits, not convert values)
281            let expr = if dtype.vcount() > 1 && !matches!(dtype, DType::Ptr { .. }) {
282                format!("__builtin_convertvector({s}, {})", c_dtype(dtype))
283            } else {
284                c_cast(&s, &src.dtype(), dtype)
285            };
286            ctx.emit_expr(uop, expr, "cast", kernel);
287            Some(())
288        }
289
290        Op::BitCast { src, dtype } => {
291            let s = ctx.get(src).to_string();
292            let from_type = c_dtype(&src.dtype());
293            let to_type = c_dtype(dtype);
294            if from_type == to_type {
295                ctx.register(uop.id, s);
296            } else {
297                let expr = format!("__builtin_bit_cast({to_type}, ({from_type})({s}))");
298                ctx.emit_expr(uop, expr, "cast", kernel);
299            }
300            Some(())
301        }
302
303        Op::Range { end, axis_id, axis_type, .. } => {
304            if matches!(axis_type, AxisType::Thread) {
305                return None;
306            }
307            let end_val = ctx.get(end).to_string();
308            let id = axis_id.value();
309            let range_dtype = c_dtype(&uop.dtype());
310            let var_name = format!("ridx{id}");
311            let indent = ctx.indent();
312            kernel.push(format!("{indent}for ({range_dtype} {var_name} = 0; {var_name} < {end_val}; {var_name}++) {{"));
313            ctx.register(uop.id, var_name);
314            ctx.push_indent();
315            Some(())
316        }
317
318        Op::End { ranges, .. } => {
319            for range in ranges.iter() {
320                if let Op::Range { axis_type, .. } = range.op() {
321                    if matches!(axis_type, AxisType::Thread) {
322                        continue;
323                    }
324                    ctx.pop_indent();
325                    let indent = ctx.indent();
326                    kernel.push(format!("{indent}}}"));
327                }
328            }
329
330            // After closing loops, resolve pending reduces.
331            // In C, the accumulator variable already holds the final value
332            // (unlike LLVM where we need to load from alloca).
333            let pending = ctx.take_pending_reduces();
334            for (reduce_id, (acc_name, _dtype)) in pending {
335                // Re-register the reduce with the accumulator name
336                // so downstream users reference the accumulated value.
337                ctx.register(reduce_id, acc_name);
338            }
339            Some(())
340        }
341
342        Op::Reduce { src, ranges, reduce_op } => {
343            let src_val = ctx.get(src).to_string();
344            let dtype = &uop.dtype();
345
346            if ranges.is_empty() {
347                // Passthrough reduce
348                ctx.register(uop.id, src_val);
349            } else {
350                // Accumulator was pre-declared in mod.rs with name acc{uop.id}
351                let acc_name = ctx.get(uop).to_string();
352                let indent = ctx.indent();
353
354                let acc_expr = render_reduce_accumulate(*reduce_op, &acc_name, &src_val, dtype);
355                kernel.push(format!("{indent}{acc_expr}"));
356
357                // Register pending for End to emit the final value
358                ctx.register_reduce_pending(uop.id, acc_name, dtype.clone());
359            }
360            Some(())
361        }
362
363        Op::Gep { vector, indices } => {
364            let vec = ctx.get(vector).to_string();
365            if indices.len() == 1 {
366                // Parenthesize to handle precedence: *((float4*)ptr)[i] → (*((float4*)ptr))[i]
367                let expr = format!("({vec})[{}]", indices[0]);
368                ctx.emit_expr(uop, expr, "gep", kernel);
369            } else {
370                // Multi-element GEP: build a new vector from extracted elements
371                let out_dtype = c_dtype(&uop.dtype());
372                let elements: Vec<String> = indices.iter().map(|&i| format!("({vec})[{i}]")).collect();
373                let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
374                ctx.emit_expr(uop, expr, "gep", kernel);
375            }
376            Some(())
377        }
378
379        Op::Vectorize { elements } => {
380            let vals: Vec<String> = elements.iter().map(|e| ctx.get(e).to_string()).collect();
381            if matches!(uop.dtype(), DType::Ptr { .. }) {
382                // Ptr types can't be vectorized in C (no compound literal for pointers).
383                // All elements should be the same scalar pointer — use the first one.
384                ctx.emit_expr(uop, vals[0].clone(), "vec", kernel);
385            } else {
386                let out_dtype = c_dtype(&uop.dtype());
387                let expr = format!("({out_dtype}){{{}}}", vals.join(", "));
388                ctx.emit_expr(uop, expr, "vec", kernel);
389            }
390            Some(())
391        }
392
393        Op::Cat { sources } => {
394            render_cat(uop, sources, ctx, kernel);
395            Some(())
396        }
397
398        Op::PtrCat { .. } => {
399            panic!(
400                "PtrCat must be eliminated before codegen (devectorize should distribute it into scalar loads/stores)"
401            );
402        }
403
404        Op::Wmma { a, b, c, metadata } => {
405            let a_val = ctx.get(a).to_string();
406            let b_val = ctx.get(b).to_string();
407            let c_val = ctx.get(c).to_string();
408            let expr = format!("__{name}({a_val}, {b_val}, {c_val})", name = metadata.name);
409            ctx.emit_expr(uop, expr, "wmma", kernel);
410            Some(())
411        }
412
413        Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
414            let s = ctx.get(src).to_string();
415            ctx.register(uop.id, s);
416            None
417        }
418
419        Op::After { passthrough, .. } => {
420            assert!(
421                !matches!(passthrough.op(), Op::Group { .. }),
422                "BUG: AFTER passthrough is GROUP (id={}). AFTER tree:\n{}",
423                passthrough.id,
424                uop.tree()
425            );
426            let s = ctx.get(passthrough).to_string();
427            ctx.register(uop.id, s);
428            None
429        }
430
431        Op::Bind { var, value } => {
432            let v = ctx.get(value).to_string();
433            ctx.register(var.id, v);
434            None
435        }
436
437        Op::If { condition, .. } => {
438            let cond = ctx.get(condition).to_string();
439            let indent = ctx.indent();
440            kernel.push(format!("{indent}if ({cond}) {{"));
441            ctx.push_indent();
442            Some(())
443        }
444
445        Op::EndIf { .. } => {
446            ctx.pop_indent();
447            let indent = ctx.indent();
448            kernel.push(format!("{indent}}}"));
449            Some(())
450        }
451
452        _ => {
453            let indent = ctx.indent();
454            kernel.push(format!("{indent}/* UNSUPPORTED: {:?} */", uop.op().as_ref()));
455            None
456        }
457    }
458}
459
460/// Linearize multiple index expressions into a single C expression.
461///
462/// Produces `(idx0*stride0 + idx1*stride1 + ...)`.
463fn render_linearize_multi_index_c(indices: &[Arc<UOp>], ctx: &CContext) -> String {
464    use morok_schedule::passes::linearize_index::{compute_row_major_strides, extract_index_dimension};
465
466    let dims: Vec<i64> = indices
467        .iter()
468        .map(|idx| extract_index_dimension(idx).expect("multi-index dimension must be resolvable at codegen"))
469        .collect();
470    let strides = compute_row_major_strides(&dims);
471
472    let mut terms: Vec<String> = Vec::new();
473    for (idx_uop, &stride) in indices.iter().zip(strides.iter()) {
474        if stride == 0 {
475            continue;
476        }
477        let idx_val = ctx.get(idx_uop);
478        if stride == 1 {
479            terms.push(idx_val.to_string());
480        } else {
481            terms.push(format!("({idx_val} * {stride})"));
482        }
483    }
484
485    if terms.is_empty() { "0".to_string() } else { format!("({})", terms.join(" + ")) }
486}
487
488/// Render a binary operation as a C expression.
489fn render_binary(op: BinaryOp, l: &str, r: &str, dtype: &DType) -> String {
490    match op {
491        BinaryOp::Add => format!("({l} + {r})"),
492        BinaryOp::Sub => format!("({l} - {r})"),
493        BinaryOp::Mul => format!("({l} * {r})"),
494        BinaryOp::Fdiv => format!("({l} / {r})"),
495        BinaryOp::Idiv => format!("({l} / {r})"),
496        BinaryOp::Mod => {
497            if dtype.is_float() {
498                format!("{}({l}, {r})", c_math_fn("__builtin_fmod", dtype))
499            } else {
500                format!("({l} % {r})")
501            }
502        }
503        BinaryOp::Max => {
504            if dtype.is_float() {
505                format!("{}({l}, {r})", c_math_fn("__builtin_fmax", dtype))
506            } else {
507                format!("({l} > {r} ? {l} : {r})")
508            }
509        }
510        BinaryOp::Lt => format!("({l} < {r})"),
511        BinaryOp::Le => format!("({l} <= {r})"),
512        BinaryOp::Gt => format!("({l} > {r})"),
513        BinaryOp::Ge => format!("({l} >= {r})"),
514        BinaryOp::Eq => format!("({l} == {r})"),
515        BinaryOp::Ne => format!("({l} != {r})"),
516        BinaryOp::And => format!("({l} & {r})"),
517        BinaryOp::Or => format!("({l} | {r})"),
518        BinaryOp::Xor => format!("({l} ^ {r})"),
519        BinaryOp::Shl => format!("({l} << {r})"),
520        BinaryOp::Shr => format!("({l} >> {r})"),
521        BinaryOp::Pow => {
522            if dtype.is_float() {
523                format!("{}({l}, {r})", c_math_fn("__builtin_pow", dtype))
524            } else {
525                // Integer pow via cast to double
526                format!("(({})__builtin_pow((double){l}, (double){r}))", c_dtype(&DType::Scalar(dtype.base())))
527            }
528        }
529        BinaryOp::Threefry => format!("({l} ^ {r})"),
530    }
531}
532
533/// Render a unary operation as a C expression.
534fn render_unary(op: UnaryOp, s: &str, dtype: &DType) -> String {
535    match op {
536        UnaryOp::Neg => {
537            format!("(-{s})")
538        }
539        UnaryOp::Not => {
540            if dtype.is_bool() {
541                format!("(!{s})")
542            } else {
543                format!("(~{s})")
544            }
545        }
546        UnaryOp::Abs => {
547            if dtype.is_float() {
548                format!("{}({s})", c_math_fn("__builtin_fabs", dtype))
549            } else {
550                format!("({s} < 0 ? -{s} : {s})")
551            }
552        }
553        UnaryOp::Sqrt => format!("{}({s})", c_math_fn("__builtin_sqrt", dtype)),
554        UnaryOp::Rsqrt => {
555            let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
556            format!("({one} / {}({s}))", c_math_fn("__builtin_sqrt", dtype))
557        }
558        UnaryOp::Reciprocal => {
559            let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
560            format!("({one} / {s})")
561        }
562        UnaryOp::Exp => format!("{}({s})", c_math_fn("__builtin_exp", dtype)),
563        UnaryOp::Exp2 => format!("{}({s})", c_math_fn("__builtin_exp2", dtype)),
564        UnaryOp::Log => format!("{}({s})", c_math_fn("__builtin_log", dtype)),
565        UnaryOp::Log2 => format!("{}({s})", c_math_fn("__builtin_log2", dtype)),
566        UnaryOp::Sin => format!("{}({s})", c_math_fn("__builtin_sin", dtype)),
567        UnaryOp::Cos => format!("{}({s})", c_math_fn("__builtin_cos", dtype)),
568        UnaryOp::Tan => format!("{}({s})", c_math_fn("__builtin_tan", dtype)),
569        UnaryOp::Floor => format!("{}({s})", c_math_fn("__builtin_floor", dtype)),
570        UnaryOp::Ceil => format!("{}({s})", c_math_fn("__builtin_ceil", dtype)),
571        UnaryOp::Trunc => format!("{}({s})", c_math_fn("__builtin_trunc", dtype)),
572        UnaryOp::Round => format!("{}({s})", c_math_fn("__builtin_rint", dtype)),
573        UnaryOp::Erf => format!("{}({s})", c_math_fn("__builtin_erf", dtype)),
574        UnaryOp::Sign => {
575            if dtype.is_float() {
576                let zero = if matches!(dtype.base(), ScalarDType::Float64) { "0.0" } else { "0.0f" };
577                format!("(({s} > {zero}) - ({s} < {zero}))")
578            } else {
579                format!("(({s} > 0) - ({s} < 0))")
580            }
581        }
582        UnaryOp::Square => format!("({s} * {s})"),
583    }
584}
585
586/// Render a reduce accumulation statement.
587fn render_reduce_accumulate(op: ReduceOp, acc: &str, val: &str, dtype: &DType) -> String {
588    match op {
589        ReduceOp::Add => format!("{acc} += {val};"),
590        ReduceOp::Mul => format!("{acc} *= {val};"),
591        ReduceOp::Max => {
592            if dtype.is_float() {
593                format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmax", dtype))
594            } else {
595                format!("{acc} = ({acc} > {val} ? {acc} : {val});")
596            }
597        }
598        ReduceOp::Min => {
599            if dtype.is_float() {
600                format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmin", dtype))
601            } else {
602                format!("{acc} = ({acc} < {val} ? {acc} : {val});")
603            }
604        }
605    }
606}
607
608/// Render a Cat operation (concatenate vectors).
609fn render_cat(uop: &Arc<UOp>, sources: &[Arc<UOp>], ctx: &mut CContext, kernel: &mut Vec<String>) {
610    let out_dtype = c_dtype(&uop.dtype());
611    let mut elements = Vec::new();
612
613    for src in sources {
614        let src_val = ctx.get(src).to_string();
615        let src_vcount = src.dtype().vcount();
616        if src_vcount == 1 {
617            elements.push(src_val);
618        } else {
619            for i in 0..src_vcount {
620                elements.push(format!("{src_val}[{i}]"));
621            }
622        }
623    }
624
625    let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
626    ctx.emit_expr(uop, expr, "cat", kernel);
627}
628
629/// Count references for each UOp ID in the linearized stream.
630/// Used to determine which values should be inlined vs declared.
631pub fn count_references(nodes: &[Arc<UOp>]) -> HashMap<u64, usize> {
632    let mut counts: HashMap<u64, usize> = HashMap::new();
633    for node in nodes {
634        for child in node.op().children() {
635            *counts.entry(child.id).or_insert(0) += 1;
636        }
637    }
638    counts
639}