Skip to main content

morok_codegen/c/
ops.rs

1//! C source code rendering for individual UOp operations.
2//!
3//! Generates C expressions/statements for each Op variant.
4//! Uses SSA inlining: single-use values are inlined as expressions,
5//! multi-use values get local variable declarations.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use morok_dtype::{DType, ScalarDType};
11use morok_ir::{AxisType, BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
12
13use super::types::{c_cast, c_const, c_dtype, c_math_fn};
14
15/// Context for C code generation, tracking variable names and SSA inlining.
16pub struct CContext {
17    /// UOp ID -> C expression or variable name
18    names: HashMap<u64, String>,
19    /// UOp ID -> reference count (how many times used)
20    ref_counts: HashMap<u64, usize>,
21    /// Variable counter for generating unique names
22    counter: usize,
23    /// Current indentation depth
24    depth: usize,
25    /// Pending reduce accumulator info: reduce_id -> (acc_name, dtype)
26    pending_reduces: HashMap<u64, (String, DType)>,
27}
28
29impl CContext {
30    pub fn new(ref_counts: HashMap<u64, usize>) -> Self {
31        Self { names: HashMap::new(), ref_counts, counter: 0, depth: 1, pending_reduces: HashMap::new() }
32    }
33
34    /// Get the C expression for a UOp. Panics if not registered.
35    pub fn get(&self, uop: &Arc<UOp>) -> &str {
36        self.names
37            .get(&uop.id)
38            .map(|s| s.as_str())
39            .unwrap_or_else(|| panic!("UOp {} ({:?}) not in C context", uop.id, uop.op()))
40    }
41
42    /// Register a name/expression for a UOp ID.
43    pub fn register(&mut self, id: u64, expr: String) {
44        self.names.insert(id, expr);
45    }
46
47    /// Check if a value should be inlined (single-use, expression-safe).
48    pub fn should_inline(&self, id: u64) -> bool {
49        self.ref_counts.get(&id).copied().unwrap_or(0) <= 1
50    }
51
52    /// Generate a unique variable name with given prefix.
53    pub fn next_name(&mut self, prefix: &str) -> String {
54        let name = format!("{}{}", prefix, self.counter);
55        self.counter += 1;
56        name
57    }
58
59    /// Get current indentation string.
60    pub fn indent(&self) -> String {
61        "  ".repeat(self.depth)
62    }
63
64    /// Increase indentation depth.
65    pub fn push_indent(&mut self) {
66        self.depth += 1;
67    }
68
69    /// Decrease indentation depth.
70    pub fn pop_indent(&mut self) {
71        self.depth = self.depth.saturating_sub(1);
72    }
73
74    /// Register a pending reduce final load.
75    pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_name: String, dtype: DType) {
76        self.pending_reduces.insert(reduce_id, (acc_name, dtype));
77    }
78
79    /// Take all pending reduces.
80    pub fn take_pending_reduces(&mut self) -> HashMap<u64, (String, DType)> {
81        std::mem::take(&mut self.pending_reduces)
82    }
83
84    /// Emit a C expression, either as an inline expression or a variable declaration.
85    /// Returns the name/expression to reference this value.
86    pub fn emit_expr(&mut self, uop: &Arc<UOp>, expr: String, prefix: &str, kernel: &mut Vec<String>) -> String {
87        if self.should_inline(uop.id) {
88            self.register(uop.id, expr.clone());
89            expr
90        } else {
91            let name = self.next_name(prefix);
92            let dtype = c_dtype(&uop.dtype());
93            let indent = self.indent();
94            kernel.push(format!("{indent}{dtype} {name} = {expr};"));
95            self.register(uop.id, name.clone());
96            name
97        }
98    }
99}
100
101/// Render a single UOp to C source code.
102///
103/// Returns `Some(())` if code was emitted, `None` for meta-ops.
104pub fn render_uop(uop: &Arc<UOp>, ctx: &mut CContext, kernel: &mut Vec<String>) -> Option<()> {
105    match uop.op() {
106        // Meta-ops: no code emitted
107        Op::Const(_)
108        | Op::VConst { .. }
109        | Op::DefineGlobal(_)
110        | Op::DefineLocal(_)
111        | Op::DefineVar { .. }
112        | Op::Noop
113        | Op::Sink { .. }
114        | Op::Group { .. }
115        | Op::Buffer { .. }
116        | Op::Unique(_)
117        | Op::Device(_)
118        | Op::Kernel { .. }
119        | Op::Barrier { .. } => None,
120
121        Op::DefineReg { .. } => {
122            // Read base type and size from dtype (matching Tinygrad's x.dtype.base/x.dtype.size).
123            // After devectorize's no_vectorized_buf, the dtype is the canonical source of truth:
124            // e.g. Ptr(base=Float32, size=35) instead of the Op's original size field.
125            let (base_dtype, alloc_size) = match uop.dtype() {
126                DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
127                other => (other, 1),
128            };
129            let name = ctx.next_name("reg");
130            let indent = ctx.indent();
131            kernel.push(format!("{indent}{} {name}[{alloc_size}];", c_dtype(&base_dtype)));
132            ctx.register(uop.id, name);
133            Some(())
134        }
135
136        Op::Index { buffer, indices, .. } => {
137            let buf = ctx.get(buffer).to_string();
138
139            if indices.is_empty() {
140                // No index - just alias the buffer pointer
141                ctx.register(uop.id, buf);
142            } else {
143                // Multi-index: linearize at render time using row-major strides
144                let idx = if indices.len() > 1 {
145                    render_linearize_multi_index_c(indices, ctx)
146                } else {
147                    ctx.get(&indices[0]).to_string()
148                };
149                let expr = format!("{buf} + {idx}");
150                ctx.emit_expr(uop, expr, "idx", kernel);
151            }
152            Some(())
153        }
154
155        Op::PointerIndex { ptr, offset } => {
156            let ptr_val = ctx.get(ptr).to_string();
157            let off_val = ctx.get(offset).to_string();
158            let expr = format!("{ptr_val} + {off_val}");
159            ctx.emit_expr(uop, expr, "pidx", kernel);
160            Some(())
161        }
162
163        Op::Load { index, .. } => {
164            let idx = ctx.get(index).to_string();
165            let load_dtype = uop.dtype();
166            // Check if the INDEX source has a gate — render conditional load to avoid null deref.
167            // Tinygrad: LOAD with gated INDEX → (gate ? *(index) : alt_value)
168            let gate_expr = if let Op::Index { gate: Some(gate_uop), .. } = index.op() {
169                Some(ctx.get(gate_uop).to_string())
170            } else {
171                None
172            };
173            let deref_expr = if load_dtype.vcount() > 1 {
174                let cast_type = c_dtype(&load_dtype);
175                format!("*(({cast_type}*)({idx}))")
176            } else {
177                format!("*({idx})")
178            };
179            let expr = if let Some(gate) = gate_expr {
180                let zero = c_const(&morok_ir::types::ConstValue::zero(load_dtype.base()), &load_dtype);
181                format!("({gate} ? {deref_expr} : {zero})")
182            } else {
183                deref_expr
184            };
185            ctx.emit_expr(uop, expr, "val", kernel);
186            Some(())
187        }
188
189        Op::Store { index, value, .. } => {
190            let idx = ctx.get(index).to_string();
191            let val = ctx.get(value).to_string();
192            let indent = ctx.indent();
193            let val_dtype = value.dtype();
194            // Buffer pointers are declared as scalar types (e.g., float*) in C,
195            // so vector stores need an explicit pointer cast.
196            if val_dtype.vcount() > 1 {
197                let cast_type = c_dtype(&val_dtype);
198                kernel.push(format!("{indent}*(({cast_type}*)({idx})) = {val};"));
199            } else {
200                kernel.push(format!("{indent}*({idx}) = {val};"));
201            }
202            Some(())
203        }
204
205        Op::Binary(op, lhs, rhs) => {
206            let l = ctx.get(lhs).to_string();
207            let r = ctx.get(rhs).to_string();
208            let expr = render_binary(*op, &l, &r, &lhs.dtype());
209            ctx.emit_expr(uop, expr, "alu", kernel);
210            Some(())
211        }
212
213        Op::Unary(op, src) => {
214            let s = ctx.get(src).to_string();
215            let expr = render_unary(*op, &s, &src.dtype());
216            ctx.emit_expr(uop, expr, "alu", kernel);
217            Some(())
218        }
219
220        Op::Ternary(TernaryOp::Where, cond, t, f) => {
221            let c = ctx.get(cond).to_string();
222            let tv = ctx.get(t).to_string();
223            let fv = ctx.get(f).to_string();
224            let expr = format!("({c} ? {tv} : {fv})");
225            ctx.emit_expr(uop, expr, "alu", kernel);
226            Some(())
227        }
228
229        Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
230            let av = ctx.get(a).to_string();
231            let bv = ctx.get(b).to_string();
232            let cv = ctx.get(c).to_string();
233            let expr = if a.dtype().is_float() {
234                format!("{}({av}, {bv}, {cv})", c_math_fn("__builtin_fma", &a.dtype()))
235            } else {
236                format!("(({av} * {bv}) + {cv})")
237            };
238            ctx.emit_expr(uop, expr, "alu", kernel);
239            Some(())
240        }
241
242        Op::Cast { src, dtype } => {
243            let s = ctx.get(src).to_string();
244
245            // INDEX to Ptr is a no-op in C (INDEX already produces a pointer)
246            if matches!(src.op(), Op::Index { .. }) && matches!(dtype, DType::Ptr { .. }) {
247                ctx.register(uop.id, s);
248                return Some(());
249            }
250
251            // Vector casts use __builtin_convertvector for element-wise conversion
252            // (a plain C cast would reinterpret bits, not convert values)
253            let expr = if dtype.vcount() > 1 && !matches!(dtype, DType::Ptr { .. }) {
254                format!("__builtin_convertvector({s}, {})", c_dtype(dtype))
255            } else {
256                c_cast(&s, &src.dtype(), dtype)
257            };
258            ctx.emit_expr(uop, expr, "cast", kernel);
259            Some(())
260        }
261
262        Op::BitCast { src, dtype } => {
263            let s = ctx.get(src).to_string();
264            let from_type = c_dtype(&src.dtype());
265            let to_type = c_dtype(dtype);
266            if from_type == to_type {
267                ctx.register(uop.id, s);
268            } else {
269                let expr = format!("__builtin_bit_cast({to_type}, ({from_type})({s}))");
270                ctx.emit_expr(uop, expr, "cast", kernel);
271            }
272            Some(())
273        }
274
275        Op::Range { end, axis_id, axis_type, .. } => {
276            if matches!(axis_type, AxisType::Thread) {
277                return None;
278            }
279            let end_val = ctx.get(end).to_string();
280            let id = axis_id.value();
281            let range_dtype = c_dtype(&uop.dtype());
282            let var_name = format!("ridx{id}");
283            let indent = ctx.indent();
284            kernel.push(format!("{indent}for ({range_dtype} {var_name} = 0; {var_name} < {end_val}; {var_name}++) {{"));
285            ctx.register(uop.id, var_name);
286            ctx.push_indent();
287            Some(())
288        }
289
290        Op::End { ranges, .. } => {
291            for range in ranges.iter() {
292                if let Op::Range { axis_type, .. } = range.op() {
293                    if matches!(axis_type, AxisType::Thread) {
294                        continue;
295                    }
296                    ctx.pop_indent();
297                    let indent = ctx.indent();
298                    kernel.push(format!("{indent}}}"));
299                }
300            }
301
302            // After closing loops, resolve pending reduces.
303            // In C, the accumulator variable already holds the final value
304            // (unlike LLVM where we need to load from alloca).
305            let pending = ctx.take_pending_reduces();
306            for (reduce_id, (acc_name, _dtype)) in pending {
307                // Re-register the reduce with the accumulator name
308                // so downstream users reference the accumulated value.
309                ctx.register(reduce_id, acc_name);
310            }
311            Some(())
312        }
313
314        Op::Reduce { src, ranges, reduce_op } => {
315            let src_val = ctx.get(src).to_string();
316            let dtype = &uop.dtype();
317
318            if ranges.is_empty() {
319                // Passthrough reduce
320                ctx.register(uop.id, src_val);
321            } else {
322                // Accumulator was pre-declared in mod.rs with name acc{uop.id}
323                let acc_name = ctx.get(uop).to_string();
324                let indent = ctx.indent();
325
326                let acc_expr = render_reduce_accumulate(*reduce_op, &acc_name, &src_val, dtype);
327                kernel.push(format!("{indent}{acc_expr}"));
328
329                // Register pending for End to emit the final value
330                ctx.register_reduce_pending(uop.id, acc_name, dtype.clone());
331            }
332            Some(())
333        }
334
335        Op::Gep { vector, indices } => {
336            let vec = ctx.get(vector).to_string();
337            if indices.len() == 1 {
338                // Parenthesize to handle precedence: *((float4*)ptr)[i] → (*((float4*)ptr))[i]
339                let expr = format!("({vec})[{}]", indices[0]);
340                ctx.emit_expr(uop, expr, "gep", kernel);
341            } else {
342                // Multi-element GEP: build a new vector from extracted elements
343                let out_dtype = c_dtype(&uop.dtype());
344                let elements: Vec<String> = indices.iter().map(|&i| format!("({vec})[{i}]")).collect();
345                let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
346                ctx.emit_expr(uop, expr, "gep", kernel);
347            }
348            Some(())
349        }
350
351        Op::Vectorize { elements } => {
352            let vals: Vec<String> = elements.iter().map(|e| ctx.get(e).to_string()).collect();
353            if matches!(uop.dtype(), DType::Ptr { .. }) {
354                // Ptr types can't be vectorized in C (no compound literal for pointers).
355                // All elements should be the same scalar pointer — use the first one.
356                ctx.emit_expr(uop, vals[0].clone(), "vec", kernel);
357            } else {
358                let out_dtype = c_dtype(&uop.dtype());
359                let expr = format!("({out_dtype}){{{}}}", vals.join(", "));
360                ctx.emit_expr(uop, expr, "vec", kernel);
361            }
362            Some(())
363        }
364
365        Op::Cat { sources } => {
366            render_cat(uop, sources, ctx, kernel);
367            Some(())
368        }
369
370        Op::PtrCat { sources } => {
371            // PtrCat is not typical in C, render as array
372            render_cat(uop, sources, ctx, kernel);
373            Some(())
374        }
375
376        Op::Wmma { a, b, c, metadata } => {
377            let a_val = ctx.get(a).to_string();
378            let b_val = ctx.get(b).to_string();
379            let c_val = ctx.get(c).to_string();
380            let expr = format!("__{name}({a_val}, {b_val}, {c_val})", name = metadata.name);
381            ctx.emit_expr(uop, expr, "wmma", kernel);
382            Some(())
383        }
384
385        Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
386            let s = ctx.get(src).to_string();
387            ctx.register(uop.id, s);
388            None
389        }
390
391        Op::After { passthrough, .. } => {
392            let s = ctx.get(passthrough).to_string();
393            ctx.register(uop.id, s);
394            None
395        }
396
397        Op::Bind { var, value } => {
398            let v = ctx.get(value).to_string();
399            ctx.register(var.id, v);
400            None
401        }
402
403        Op::If { condition, .. } => {
404            let cond = ctx.get(condition).to_string();
405            let indent = ctx.indent();
406            kernel.push(format!("{indent}if ({cond}) {{"));
407            ctx.push_indent();
408            Some(())
409        }
410
411        Op::EndIf { .. } => {
412            ctx.pop_indent();
413            let indent = ctx.indent();
414            kernel.push(format!("{indent}}}"));
415            Some(())
416        }
417
418        _ => {
419            let indent = ctx.indent();
420            kernel.push(format!("{indent}/* UNSUPPORTED: {:?} */", uop.op().as_ref()));
421            None
422        }
423    }
424}
425
426/// Linearize multiple index expressions into a single C expression.
427///
428/// Produces `(idx0*stride0 + idx1*stride1 + ...)`.
429fn render_linearize_multi_index_c(indices: &[Arc<UOp>], ctx: &CContext) -> String {
430    use morok_schedule::passes::linearize_index::{compute_row_major_strides, extract_index_dimension};
431
432    let dims: Vec<i64> = indices
433        .iter()
434        .map(|idx| extract_index_dimension(idx).expect("multi-index dimension must be resolvable at codegen"))
435        .collect();
436    let strides = compute_row_major_strides(&dims);
437
438    let mut terms: Vec<String> = Vec::new();
439    for (idx_uop, &stride) in indices.iter().zip(strides.iter()) {
440        if stride == 0 {
441            continue;
442        }
443        let idx_val = ctx.get(idx_uop);
444        if stride == 1 {
445            terms.push(idx_val.to_string());
446        } else {
447            terms.push(format!("({idx_val} * {stride})"));
448        }
449    }
450
451    if terms.is_empty() { "0".to_string() } else { format!("({})", terms.join(" + ")) }
452}
453
454/// Render a binary operation as a C expression.
455fn render_binary(op: BinaryOp, l: &str, r: &str, dtype: &DType) -> String {
456    match op {
457        BinaryOp::Add => format!("({l} + {r})"),
458        BinaryOp::Sub => format!("({l} - {r})"),
459        BinaryOp::Mul => format!("({l} * {r})"),
460        BinaryOp::Fdiv => format!("({l} / {r})"),
461        BinaryOp::Idiv => format!("({l} / {r})"),
462        BinaryOp::Mod => {
463            if dtype.is_float() {
464                format!("{}({l}, {r})", c_math_fn("__builtin_fmod", dtype))
465            } else {
466                format!("({l} % {r})")
467            }
468        }
469        BinaryOp::Max => {
470            if dtype.is_float() {
471                format!("{}({l}, {r})", c_math_fn("__builtin_fmax", dtype))
472            } else {
473                format!("({l} > {r} ? {l} : {r})")
474            }
475        }
476        BinaryOp::Lt => format!("({l} < {r})"),
477        BinaryOp::Le => format!("({l} <= {r})"),
478        BinaryOp::Gt => format!("({l} > {r})"),
479        BinaryOp::Ge => format!("({l} >= {r})"),
480        BinaryOp::Eq => format!("({l} == {r})"),
481        BinaryOp::Ne => format!("({l} != {r})"),
482        BinaryOp::And => format!("({l} & {r})"),
483        BinaryOp::Or => format!("({l} | {r})"),
484        BinaryOp::Xor => format!("({l} ^ {r})"),
485        BinaryOp::Shl => format!("({l} << {r})"),
486        BinaryOp::Shr => format!("({l} >> {r})"),
487        BinaryOp::Pow => {
488            if dtype.is_float() {
489                format!("{}({l}, {r})", c_math_fn("__builtin_pow", dtype))
490            } else {
491                // Integer pow via cast to double
492                format!("(({})__builtin_pow((double){l}, (double){r}))", c_dtype(&DType::Scalar(dtype.base())))
493            }
494        }
495        BinaryOp::Threefry => format!("({l} ^ {r})"),
496    }
497}
498
499/// Render a unary operation as a C expression.
500fn render_unary(op: UnaryOp, s: &str, dtype: &DType) -> String {
501    match op {
502        UnaryOp::Neg => {
503            format!("(-{s})")
504        }
505        UnaryOp::Not => {
506            if dtype.is_bool() {
507                format!("(!{s})")
508            } else {
509                format!("(~{s})")
510            }
511        }
512        UnaryOp::Abs => {
513            if dtype.is_float() {
514                format!("{}({s})", c_math_fn("__builtin_fabs", dtype))
515            } else {
516                format!("({s} < 0 ? -{s} : {s})")
517            }
518        }
519        UnaryOp::Sqrt => format!("{}({s})", c_math_fn("__builtin_sqrt", dtype)),
520        UnaryOp::Rsqrt => {
521            let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
522            format!("({one} / {}({s}))", c_math_fn("__builtin_sqrt", dtype))
523        }
524        UnaryOp::Reciprocal => {
525            let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
526            format!("({one} / {s})")
527        }
528        UnaryOp::Exp => format!("{}({s})", c_math_fn("__builtin_exp", dtype)),
529        UnaryOp::Exp2 => format!("{}({s})", c_math_fn("__builtin_exp2", dtype)),
530        UnaryOp::Log => format!("{}({s})", c_math_fn("__builtin_log", dtype)),
531        UnaryOp::Log2 => format!("{}({s})", c_math_fn("__builtin_log2", dtype)),
532        UnaryOp::Sin => format!("{}({s})", c_math_fn("__builtin_sin", dtype)),
533        UnaryOp::Cos => format!("{}({s})", c_math_fn("__builtin_cos", dtype)),
534        UnaryOp::Tan => format!("{}({s})", c_math_fn("__builtin_tan", dtype)),
535        UnaryOp::Floor => format!("{}({s})", c_math_fn("__builtin_floor", dtype)),
536        UnaryOp::Ceil => format!("{}({s})", c_math_fn("__builtin_ceil", dtype)),
537        UnaryOp::Trunc => format!("{}({s})", c_math_fn("__builtin_trunc", dtype)),
538        UnaryOp::Round => format!("{}({s})", c_math_fn("__builtin_rint", dtype)),
539        UnaryOp::Erf => format!("{}({s})", c_math_fn("__builtin_erf", dtype)),
540        UnaryOp::Sign => {
541            if dtype.is_float() {
542                let zero = if matches!(dtype.base(), ScalarDType::Float64) { "0.0" } else { "0.0f" };
543                format!("(({s} > {zero}) - ({s} < {zero}))")
544            } else {
545                format!("(({s} > 0) - ({s} < 0))")
546            }
547        }
548        UnaryOp::Square => format!("({s} * {s})"),
549    }
550}
551
552/// Render a reduce accumulation statement.
553fn render_reduce_accumulate(op: ReduceOp, acc: &str, val: &str, dtype: &DType) -> String {
554    match op {
555        ReduceOp::Add => format!("{acc} += {val};"),
556        ReduceOp::Mul => format!("{acc} *= {val};"),
557        ReduceOp::Max => {
558            if dtype.is_float() {
559                format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmax", dtype))
560            } else {
561                format!("{acc} = ({acc} > {val} ? {acc} : {val});")
562            }
563        }
564        ReduceOp::Min => {
565            if dtype.is_float() {
566                format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmin", dtype))
567            } else {
568                format!("{acc} = ({acc} < {val} ? {acc} : {val});")
569            }
570        }
571    }
572}
573
574/// Render a Cat operation (concatenate vectors).
575fn render_cat(uop: &Arc<UOp>, sources: &[Arc<UOp>], ctx: &mut CContext, kernel: &mut Vec<String>) {
576    let out_dtype = c_dtype(&uop.dtype());
577    let mut elements = Vec::new();
578
579    for src in sources {
580        let src_val = ctx.get(src).to_string();
581        let src_vcount = src.dtype().vcount();
582        if src_vcount == 1 {
583            elements.push(src_val);
584        } else {
585            for i in 0..src_vcount {
586                elements.push(format!("{src_val}[{i}]"));
587            }
588        }
589    }
590
591    let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
592    ctx.emit_expr(uop, expr, "cat", kernel);
593}
594
595/// Count references for each UOp ID in the linearized stream.
596/// Used to determine which values should be inlined vs declared.
597pub fn count_references(nodes: &[Arc<UOp>]) -> HashMap<u64, usize> {
598    let mut counts: HashMap<u64, usize> = HashMap::new();
599    for node in nodes {
600        for child in node.op().children() {
601            *counts.entry(child.id).or_insert(0) += 1;
602        }
603    }
604    counts
605}