Skip to main content

morok_codegen/llvm/cpu/
ops.rs

1//! CPU-specific LLVM IR operation rendering.
2//!
3//! Generates LLVM IR strings for individual UOp operations on CPU.
4//! Based on Tinygrad's PatternMatcher templates in `llvmir.py`.
5
6use std::sync::Arc;
7
8use morok_dtype::DType;
9use morok_ir::{AxisType, BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
10
11use crate::llvm::common::{RenderContext, lcast, ldt};
12
13/// Extract a scalar `ptr` from a vectorized `<N x ptr>` via `extractelement ... i32 0`.
14///
15/// When the devectorize pipeline doesn't fully eliminate vectorized PARAM pointers
16/// (see `no_vectorized_buf` / `no_vectorized_index` which only target DEFINE_LOCAL/DEFINE_REG),
17/// the GEP result can be `<N x ptr>`. All elements are identical (broadcast of the same buffer
18/// pointer), so extracting element 0 yields the correct scalar ptr for LLVM load/store.
19fn maybe_extract_scalar_ptr(
20    dst: &str,
21    idx: &str,
22    idx_type: &str,
23    dtype: &DType,
24    kernel: &mut Vec<String>,
25) -> (String, String) {
26    if matches!(dtype, DType::Ptr { vcount, .. } if *vcount > 1) {
27        let extract = format!("{dst}.ptr");
28        kernel.push(format!("  {extract} = extractelement {idx_type} {idx}, i32 0"));
29        (extract, "ptr".to_string())
30    } else {
31        (idx.to_string(), idx_type.to_string())
32    }
33}
34
35/// Render a UOp to LLVM IR string.
36///
37/// Returns None for meta-ops that don't produce instructions.
38pub fn render_uop(uop: &Arc<UOp>, ctx: &mut RenderContext, kernel: &mut Vec<String>) -> Option<()> {
39    let dst = ctx.name(uop);
40
41    match uop.op() {
42        Op::Const(_)
43        | Op::VConst { .. }
44        | Op::Param { device: None, .. }
45        | Op::DefineVar { .. }
46        | Op::Noop
47        | Op::Sink { .. }
48        | Op::Group { .. }
49        | Op::Buffer { .. }
50        | Op::Unique(_)
51        | Op::Device(_)
52        | Op::Kernel { .. }
53        | Op::Barrier { .. } => None,
54
55        Op::DefineLocal(_) | Op::DefineReg { .. } => {
56            // Emit alloca for local/register memory.
57            // Read base type and size from dtype (matching Tinygrad's x.dtype.base/x.dtype.size).
58            // After devectorize's no_vectorized_buf, dtype is the canonical source of truth.
59            let (base_dtype, alloc_size) = match uop.dtype() {
60                DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
61                other => (other, 1),
62            };
63            let base = ldt(&base_dtype);
64            // Tinygrad: DEFINE_LOCAL gets align 16 (for SSE vector loads), DEFINE_REG gets default.
65            let align = if matches!(uop.op(), Op::DefineLocal(_)) { ", align 16" } else { "" };
66            kernel.push(format!("  {dst} = alloca [{alloc_size} x {base}]{align}"));
67            Some(())
68        }
69
70        Op::Index { buffer, indices, .. } => {
71            let buf = ctx.get(buffer);
72            let buf_type = ldt(&buffer.dtype());
73
74            if indices.is_empty() {
75                kernel.push(format!("  {dst} = bitcast {buf_type} {buf} to {}", ldt(&uop.dtype())));
76            } else {
77                // Multi-index: linearize at render time using row-major strides
78                let (final_idx, final_idx_type) = if indices.len() > 1 {
79                    render_linearize_multi_index(&dst, indices, ctx, kernel)
80                } else {
81                    (ctx.get(&indices[0]).to_string(), ldt(&indices[0].dtype()))
82                };
83
84                let elem_type = match uop.dtype() {
85                    morok_dtype::DType::Ptr { ref base, .. } => ldt(base),
86                    other => ldt(&other),
87                };
88
89                // Gate is NOT handled here — matching Tinygrad's approach where INDEX
90                // always emits a plain GEP. The gate is handled at LOAD level (branch+phi)
91                // and at STORE level (IF/ENDIF via line_rewrite_cleanups).
92                kernel.push(format!(
93                    "  {dst} = getelementptr inbounds {elem_type}, {buf_type} {buf}, {final_idx_type} {final_idx}"
94                ));
95            }
96            Some(())
97        }
98
99        Op::PointerIndex { ptr, offset } => {
100            let ptr_val = ctx.get(ptr);
101            let off_val = ctx.get(offset);
102            let elem_type = ldt(&uop.dtype());
103            let ptr_type = ldt(&ptr.dtype());
104            let off_type = ldt(&offset.dtype());
105
106            kernel.push(format!(
107                "  {dst} = getelementptr inbounds {elem_type}, {ptr_type} {ptr_val}, {off_type} {off_val}"
108            ));
109            Some(())
110        }
111
112        Op::Load { index, alt, .. } => {
113            let idx = ctx.get(index);
114            let dtype = ldt(&uop.dtype());
115            let idx_type = ldt(&index.dtype());
116
117            let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
118
119            // Gated LOAD: emit branch+phi to avoid null deref.
120            // Matches Tinygrad's pattern (llvmir.py:123-129) which requires BOTH
121            // a gated INDEX and an alt value on the LOAD. If gate exists without
122            // alt, that's a pipeline bug (line_rewrite_cleanups should provide it).
123            // Unwrap one CAST layer to find the INDEX gate (matches Tinygrad's .or_casted("idx")).
124            // The pipeline CAN produce CAST(INDEX) — devectorize handles this shape explicitly.
125            let actual_index = match index.op() {
126                Op::Cast { src, .. } => src,
127                _ => index,
128            };
129            let gate_info = if let Op::Index { gate: Some(gate_uop), .. } = actual_index.op() {
130                let alt_uop = alt.as_ref().expect(
131                    "gated LOAD without alt value — pipeline bug: \
132                     line_rewrite_cleanups should ensure alt is present for gated loads",
133                );
134                Some((ctx.get(gate_uop).to_string(), ctx.get(alt_uop).to_string()))
135            } else {
136                None
137            };
138
139            if let Some((gate, alt_val)) = gate_info {
140                let label_base = &dst[1..]; // strip leading %
141                let entry_label = format!("{label_base}_entry");
142                let load_label = format!("{label_base}_load");
143                let exit_label = format!("{label_base}_exit");
144                let load_val = format!("{dst}_yes");
145
146                kernel.push(format!("  br label %{entry_label}"));
147                kernel.push(format!("{entry_label}:"));
148                kernel.push(format!("  br i1 {gate}, label %{load_label}, label %{exit_label}"));
149                kernel.push(format!("{load_label}:"));
150                kernel.push(format!("  {load_val} = load {dtype}, {idx_type} {idx}"));
151                kernel.push(format!("  br label %{exit_label}"));
152                kernel.push(format!("{exit_label}:"));
153                kernel.push(format!("  {dst} = phi {dtype} [{load_val}, %{load_label}], [{alt_val}, %{entry_label}]"));
154            } else {
155                kernel.push(format!("  {dst} = load {dtype}, {idx_type} {idx}"));
156            }
157            Some(())
158        }
159
160        Op::Store { index, value, .. } => {
161            let idx = ctx.get(index);
162            let val = ctx.get(value);
163            let val_type = ldt(&value.dtype());
164            let idx_type = ldt(&index.dtype());
165
166            let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
167
168            kernel.push(format!("  store {val_type} {val}, {idx_type} {idx}"));
169            Some(())
170        }
171
172        Op::Binary(op, lhs, rhs) => {
173            let l = ctx.get(lhs);
174            let r = ctx.get(rhs);
175            let ltype = ldt(&lhs.dtype());
176            let rtype = ldt(&rhs.dtype());
177
178            // Debug: detect type mismatch (logged via tracing)
179            if ltype != rtype {
180                tracing::error!(
181                    uop_id = uop.id,
182                    uop_dtype = ?uop.dtype(),
183                    op = ?op,
184                    lhs_id = lhs.id,
185                    rhs_id = rhs.id,
186                    lhs_dtype = ?lhs.dtype(),
187                    rhs_dtype = ?rhs.dtype(),
188                    lhs_op = ?lhs.op().as_ref(),
189                    rhs_op = ?rhs.op().as_ref(),
190                    "Binary op type mismatch - lhs and rhs have different dtypes"
191                );
192            }
193
194            if matches!(op, BinaryOp::Max) {
195                render_binary_max(&dst, lhs, l, r, &ltype, kernel);
196            } else if matches!(op, BinaryOp::Pow) {
197                render_binary_pow(&dst, lhs, l, r, &ltype, kernel);
198            } else {
199                let instr = binary_instr(*op, &lhs.dtype());
200                kernel.push(format!("  {dst} = {instr} {ltype} {l}, {r}"));
201            }
202            Some(())
203        }
204
205        Op::Unary(op, src) => {
206            let s = ctx.get(src);
207            let stype = ldt(&src.dtype());
208
209            match op {
210                UnaryOp::Neg => {
211                    if src.dtype().is_float() {
212                        kernel.push(format!("  {dst} = fneg {stype} {s}"));
213                    } else {
214                        kernel.push(format!("  {dst} = sub {stype} 0, {s}"));
215                    }
216                }
217                UnaryOp::Not => {
218                    let all_ones = if src.dtype().is_bool() { "1".to_string() } else { "-1".to_string() };
219                    kernel.push(format!("  {dst} = xor {stype} {s}, {all_ones}"));
220                }
221                UnaryOp::Floor | UnaryOp::Ceil | UnaryOp::Trunc | UnaryOp::Round if !src.dtype().is_float() => {
222                    // Rounding is identity for integer types (defense-in-depth;
223                    // symbolic_simple folds these away upstream).
224                    kernel.push(format!("  {dst} = bitcast {stype} {s} to {stype}"));
225                }
226                UnaryOp::Sqrt
227                | UnaryOp::Exp
228                | UnaryOp::Exp2
229                | UnaryOp::Log
230                | UnaryOp::Log2
231                | UnaryOp::Sin
232                | UnaryOp::Cos
233                | UnaryOp::Floor
234                | UnaryOp::Ceil
235                | UnaryOp::Trunc
236                | UnaryOp::Round => {
237                    let intrinsic = unary_instr(*op, &src.dtype()).unwrap();
238                    render_intrinsic(&dst, intrinsic, &[(&stype, s)], &stype, kernel);
239                }
240                UnaryOp::Abs => {
241                    if src.dtype().is_float() {
242                        render_intrinsic(&dst, "fabs", &[(&stype, s)], &stype, kernel);
243                    } else {
244                        render_intrinsic(&dst, "abs", &[(&stype, s), ("i1", "1")], &stype, kernel);
245                    }
246                }
247                UnaryOp::Rsqrt => {
248                    let sqrt_dst = format!("{dst}.sqrt");
249                    render_intrinsic(&sqrt_dst, "sqrt", &[(&stype, s)], &stype, kernel);
250                    kernel.push(format!("  {dst} = fdiv nsz arcp contract afn {stype} 1.0, {sqrt_dst}"));
251                }
252                UnaryOp::Reciprocal => {
253                    kernel.push(format!("  {dst} = fdiv nsz arcp contract afn {stype} 1.0, {s}"));
254                }
255                UnaryOp::Tan => {
256                    let sin_dst = format!("{dst}.sin");
257                    let cos_dst = format!("{dst}.cos");
258                    render_intrinsic(&sin_dst, "sin", &[(&stype, s)], &stype, kernel);
259                    render_intrinsic(&cos_dst, "cos", &[(&stype, s)], &stype, kernel);
260                    kernel.push(format!("  {dst} = fdiv nsz arcp contract afn {stype} {sin_dst}, {cos_dst}"));
261                }
262                UnaryOp::Sign => {
263                    if src.dtype().is_float() {
264                        let gt_zero = format!("{dst}.gt");
265                        let lt_zero = format!("{dst}.lt");
266                        let gt_ext = format!("{dst}.gt_ext");
267                        let lt_ext = format!("{dst}.lt_ext");
268                        kernel.push(format!("  {gt_zero} = fcmp nsz arcp contract afn ogt {stype} {s}, 0.0"));
269                        kernel.push(format!("  {lt_zero} = fcmp nsz arcp contract afn olt {stype} {s}, 0.0"));
270                        kernel.push(format!("  {gt_ext} = uitofp i1 {gt_zero} to {stype}"));
271                        kernel.push(format!("  {lt_ext} = uitofp i1 {lt_zero} to {stype}"));
272                        kernel.push(format!("  {dst} = fsub nsz arcp contract afn {stype} {gt_ext}, {lt_ext}"));
273                    } else {
274                        let is_signed = src.dtype().is_signed();
275                        let cmp = if is_signed { "sgt" } else { "ugt" };
276                        let cmp_lt = if is_signed { "slt" } else { "icmp eq" };
277                        let gt_zero = format!("{dst}.gt");
278                        let lt_zero = format!("{dst}.lt");
279                        let gt_ext = format!("{dst}.gt_ext");
280                        let lt_ext = format!("{dst}.lt_ext");
281                        kernel.push(format!("  {gt_zero} = icmp {cmp} {stype} {s}, 0"));
282                        if is_signed {
283                            kernel.push(format!("  {lt_zero} = icmp {cmp_lt} {stype} {s}, 0"));
284                        } else {
285                            kernel.push(format!("  {lt_zero} = icmp eq {stype} {s}, 0"));
286                            kernel.push(format!("  {lt_zero} = xor i1 {lt_zero}, 1"));
287                            kernel.push(format!("  {lt_zero} = and i1 {lt_zero}, 0"));
288                        }
289                        kernel.push(format!("  {gt_ext} = zext i1 {gt_zero} to {stype}"));
290                        kernel.push(format!("  {lt_ext} = zext i1 {lt_zero} to {stype}"));
291                        kernel.push(format!("  {dst} = sub {stype} {gt_ext}, {lt_ext}"));
292                    }
293                }
294                UnaryOp::Erf => {
295                    render_intrinsic(&dst, "erf", &[(&stype, s)], &stype, kernel);
296                }
297                UnaryOp::Square => {
298                    if src.dtype().is_float() {
299                        kernel.push(format!("  {dst} = fmul nsz arcp contract afn {stype} {s}, {s}"));
300                    } else {
301                        kernel.push(format!("  {dst} = mul {stype} {s}, {s}"));
302                    }
303                }
304            }
305            Some(())
306        }
307
308        Op::Ternary(TernaryOp::Where, cond, t, f) => {
309            let c = ctx.get(cond);
310            let tv = ctx.get(t);
311            let fv = ctx.get(f);
312            kernel.push(format!(
313                "  {dst} = select {} {c}, {} {tv}, {} {fv}",
314                ldt(&cond.dtype()),
315                ldt(&t.dtype()),
316                ldt(&f.dtype())
317            ));
318            Some(())
319        }
320
321        Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
322            let av = ctx.get(a);
323            let bv = ctx.get(b);
324            let cv = ctx.get(c);
325            let dtype = ldt(&a.dtype());
326
327            if a.dtype().is_float() {
328                render_intrinsic(&dst, "fmuladd", &[(&dtype, av), (&dtype, bv), (&dtype, cv)], &dtype, kernel);
329            } else {
330                let mul_dst = format!("{dst}.mul");
331                kernel.push(format!("  {mul_dst} = mul {dtype} {av}, {bv}"));
332                kernel.push(format!("  {dst} = add {dtype} {mul_dst}, {cv}"));
333            }
334            Some(())
335        }
336
337        Op::Cast { src, dtype } => {
338            let s = ctx.get(src);
339
340            // INDEX always produces ptr in LLVM (via GEP), regardless of Morok dtype.
341            // When source is INDEX, treat source LLVM type as ptr for cast selection.
342            let is_index_src = matches!(src.op(), Op::Index { .. });
343            let src_llvm_type = if is_index_src { "ptr".to_string() } else { ldt(&src.dtype()) };
344            let dst_llvm_type = ldt(dtype);
345
346            // CAST(INDEX) to Ptr is a no-op - INDEX already produces ptr via GEP.
347            // This matches Tinygrad's approach (llvmir.py:189) where CAST to PtrDType
348            // is register aliasing: r[u] = r[u.src[0]]
349            if is_index_src && matches!(dtype, DType::Ptr { .. }) {
350                // Emit a bitcast as a named no-op to maintain SSA form
351                kernel.push(format!("  {dst} = bitcast ptr {s} to ptr"));
352                return Some(());
353            }
354
355            if dtype.is_bool() && !src.dtype().is_bool() {
356                // Cast to bool: compare != 0 (not trunc, which only takes the low bit).
357                // Matches Tinygrad llvmir.py:99-101.
358                let cmp = if src.dtype().is_float() { "fcmp nsz arcp contract afn une" } else { "icmp ne" };
359                kernel.push(format!("  {dst} = {cmp} {src_llvm_type} {s}, zeroinitializer"));
360            } else if src_llvm_type == dst_llvm_type {
361                kernel.push(format!("  {dst} = bitcast {src_llvm_type} {s} to {dst_llvm_type}"));
362            } else {
363                let cast_instr = lcast(&src.dtype(), dtype);
364                kernel.push(format!("  {dst} = {cast_instr} {src_llvm_type} {s} to {dst_llvm_type}"));
365            }
366            Some(())
367        }
368
369        Op::BitCast { src, dtype } => {
370            let s = ctx.get(src);
371            kernel.push(format!("  {dst} = bitcast {} {s} to {}", ldt(&src.dtype()), ldt(dtype)));
372            Some(())
373        }
374
375        Op::Range { axis_id, end, .. } => {
376            let id = axis_id.value();
377            let dtype = ldt(&uop.dtype());
378            let end_val = ctx.get(end).to_string();
379
380            // Track range nesting for correct END footer ordering.
381            ctx.push_range(id);
382
383            // Matches Tinygrad llvmir.py:156-165 exactly:
384            //   entry → loop_entry (preheader) → loop_latch (phi+incr+cmp) → loop_body / loop_exit
385            //   loop_body contains body instructions
386            //   END branches to loop_footer → loop_latch (back edge)
387            kernel.push(format!("  br label %loop_entry_{id}"));
388            kernel.push(format!("loop_entry_{id}:"));
389            kernel.push(format!("  br label %loop_latch_{id}"));
390            kernel.push(format!("loop_latch_{id}:"));
391            kernel.push(format!("  {dst} = phi {dtype} [ 0, %loop_entry_{id} ], [ {dst}phi, %loop_footer_{id} ]"));
392            kernel.push(format!("  {dst}phi = add {dtype} {dst}, 1"));
393            kernel.push(format!("  {dst}cmp = icmp ult {dtype} {dst}, {end_val}"));
394            kernel.push(format!("  br i1 {dst}cmp, label %loop_body_{id}, label %loop_exit_{id}"));
395            kernel.push(format!("loop_body_{id}:"));
396            Some(())
397        }
398
399        Op::End { ranges, .. } => {
400            // After pm_split_ends, each END has exactly one RANGE.
401            // Use the range_stack to emit footer blocks in correct nesting order
402            // (innermost first = LIFO), regardless of the END's ranges field order.
403            let range_count = ranges
404                .iter()
405                .filter(|r| matches!(r.op(), Op::Range { axis_type, .. } if !matches!(axis_type, AxisType::Thread)))
406                .count();
407            for _ in 0..range_count {
408                if let Some(id) = ctx.pop_range() {
409                    // Matches Tinygrad llvmir.py:166-170 exactly:
410                    //   body → loop_footer → loop_latch (back edge)
411                    //   loop_exit: falls through after loop
412                    kernel.push(format!("  br label %loop_footer_{id}"));
413                    kernel.push(format!("loop_footer_{id}:"));
414                    kernel.push(format!("  br label %loop_latch_{id}"));
415                    kernel.push(format!("loop_exit_{id}:"));
416                }
417            }
418
419            let pending = ctx.take_pending_reduces();
420            for (reduce_id, info) in pending {
421                let result_name = format!("%reduce_{reduce_id}.final");
422                kernel.push(format!("  {result_name} = load {}, ptr {}", info.dtype, info.acc_ptr));
423                ctx.register(reduce_id, result_name);
424            }
425            Some(())
426        }
427
428        Op::Reduce { src, ranges, reduce_op } => {
429            let src_val = ctx.get(src);
430            let dtype = ldt(&uop.dtype());
431
432            if ranges.is_empty() {
433                kernel.push(format!("  {dst} = bitcast {dtype} {src_val} to {dtype}"));
434            } else {
435                let acc_ptr = format!("%reduce_{}", uop.id);
436                let acc_load = format!("{acc_ptr}.load");
437                let acc_new = format!("{acc_ptr}.new");
438                let instr = reduce_instr(*reduce_op, &uop.dtype());
439
440                kernel.push(format!("  {acc_load} = load {dtype}, ptr {acc_ptr}"));
441
442                if matches!(reduce_op, ReduceOp::Max | ReduceOp::Min) {
443                    render_reduce_minmax(&acc_new, *reduce_op, &uop.dtype(), &acc_load, src_val, &dtype, kernel);
444                } else {
445                    kernel.push(format!("  {acc_new} = {instr} {dtype} {acc_load}, {src_val}"));
446                }
447
448                kernel.push(format!("  store {dtype} {acc_new}, ptr {acc_ptr}"));
449                ctx.register_reduce_pending(uop.id, acc_ptr.clone(), dtype.clone());
450            }
451            Some(())
452        }
453
454        Op::Gep { vector, indices } => {
455            let vec = ctx.get(vector);
456            let vec_type = ldt(&vector.dtype());
457            let out_type = ldt(&uop.dtype());
458
459            if indices.len() == 1 {
460                kernel.push(format!("  {dst} = extractelement {vec_type} {vec}, i32 {}", indices[0]));
461            } else {
462                render_multi_gep(&dst, vec, &vector.dtype(), indices, &out_type, kernel);
463            }
464            Some(())
465        }
466
467        Op::Vectorize { elements } => {
468            render_vectorize(&dst, elements, ctx, kernel);
469            Some(())
470        }
471
472        Op::Cat { sources } => {
473            render_cat(&dst, sources, ctx, kernel);
474            Some(())
475        }
476
477        Op::PtrCat { .. } => {
478            panic!(
479                "PtrCat must be eliminated before codegen (devectorize should distribute it into scalar loads/stores)"
480            );
481        }
482
483        Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
484            let s = ctx.get(src);
485            ctx.alias(uop.id, s.to_string());
486            None
487        }
488
489        Op::After { passthrough, .. } => {
490            #[cfg(debug_assertions)]
491            if matches!(passthrough.op(), Op::Range { .. }) {
492                panic!("AFTER passthrough is Range (id={}), this violates Tinygrad semantics", passthrough.id);
493            }
494            let s = ctx.get(passthrough);
495            ctx.alias(uop.id, s.to_string());
496            None
497        }
498
499        Op::Bind { var, value } => {
500            let v = ctx.get(value);
501            ctx.alias(var.id, v.to_string());
502            None
503        }
504
505        Op::If { condition, .. } => {
506            let cond = ctx.get(condition);
507            let if_id = uop.id;
508            kernel.push(format!("  br i1 {cond}, label %if_then_{if_id}, label %if_end_{if_id}"));
509            kernel.push(format!("if_then_{if_id}:"));
510            Some(())
511        }
512
513        Op::EndIf { if_op } => {
514            let if_id = if_op.id;
515            kernel.push(format!("  br label %if_end_{if_id}"));
516            kernel.push(format!("if_end_{if_id}:"));
517            Some(())
518        }
519
520        op if op.is_movement() => {
521            panic!(
522                "movement op {:?} (id={}) reached LLVM codegen — \
523                 should have been eliminated during rangeify. \
524                 This indicates a bug in remove_movement_op or apply_bufferize_transform.",
525                std::mem::discriminant(op),
526                uop.id,
527            );
528        }
529
530        _ => {
531            kernel.push(format!("; UNSUPPORTED: {:?}", uop.op()));
532            None
533        }
534    }
535}
536
537fn binary_instr(op: BinaryOp, dtype: &DType) -> &'static str {
538    assert!(
539        !matches!(dtype.base(), morok_dtype::ScalarDType::Index),
540        "Index dtype reached LLVM codegen binary_instr({op:?}, {dtype:?}) — \
541         pm_lower_index_dtype should have lowered it to i32/i64"
542    );
543    let is_float = dtype.is_float();
544    let is_signed = dtype.is_signed();
545
546    match op {
547        BinaryOp::Add => {
548            if is_float {
549                "fadd nsz arcp contract afn"
550            } else if is_signed {
551                "add nsw"
552            } else {
553                "add"
554            }
555        }
556        BinaryOp::Mul => {
557            if is_float {
558                "fmul nsz arcp contract afn"
559            } else {
560                "mul"
561            }
562        }
563        BinaryOp::Sub => {
564            if is_float {
565                "fsub nsz arcp contract afn"
566            } else {
567                "sub"
568            }
569        }
570        BinaryOp::Fdiv => "fdiv nsz arcp contract afn",
571        BinaryOp::Idiv => {
572            if is_signed {
573                "sdiv"
574            } else {
575                "udiv"
576            }
577        }
578        BinaryOp::Mod => {
579            if is_float {
580                "frem nsz arcp contract afn"
581            } else if is_signed {
582                "srem"
583            } else {
584                "urem"
585            }
586        }
587        BinaryOp::Max => {
588            if is_float {
589                "maxnum"
590            } else if is_signed {
591                "smax"
592            } else {
593                "umax"
594            }
595        }
596        BinaryOp::Lt => {
597            if is_float {
598                "fcmp nsz arcp contract afn ult"
599            } else if is_signed {
600                "icmp slt"
601            } else {
602                "icmp ult"
603            }
604        }
605        BinaryOp::Le => {
606            if is_float {
607                "fcmp nsz arcp contract afn ule"
608            } else if is_signed {
609                "icmp sle"
610            } else {
611                "icmp ule"
612            }
613        }
614        BinaryOp::Gt => {
615            if is_float {
616                "fcmp nsz arcp contract afn ugt"
617            } else if is_signed {
618                "icmp sgt"
619            } else {
620                "icmp ugt"
621            }
622        }
623        BinaryOp::Ge => {
624            if is_float {
625                "fcmp nsz arcp contract afn uge"
626            } else if is_signed {
627                "icmp sge"
628            } else {
629                "icmp uge"
630            }
631        }
632        BinaryOp::Eq => {
633            if is_float {
634                "fcmp nsz arcp contract afn oeq"
635            } else {
636                "icmp eq"
637            }
638        }
639        BinaryOp::Ne => {
640            if is_float {
641                "fcmp nsz arcp contract afn une"
642            } else {
643                "icmp ne"
644            }
645        }
646        BinaryOp::And => "and",
647        BinaryOp::Or => "or",
648        BinaryOp::Xor => "xor",
649        BinaryOp::Shl => "shl",
650        BinaryOp::Shr => {
651            if is_signed {
652                "ashr"
653            } else {
654                "lshr"
655            }
656        }
657        BinaryOp::Pow => "pow",
658        BinaryOp::Threefry => "xor",
659    }
660}
661
662fn unary_instr(op: UnaryOp, dtype: &DType) -> Option<&'static str> {
663    let is_float = dtype.is_float();
664
665    match op {
666        UnaryOp::Neg => Some(if is_float { "fneg" } else { "sub" }),
667        UnaryOp::Not => Some("xor"),
668        UnaryOp::Sqrt => Some("sqrt"),
669        UnaryOp::Rsqrt => None,
670        UnaryOp::Exp => Some("exp"),
671        UnaryOp::Exp2 => Some("exp2"),
672        UnaryOp::Log => Some("log"),
673        UnaryOp::Log2 => Some("log2"),
674        UnaryOp::Sin => Some("sin"),
675        UnaryOp::Cos => Some("cos"),
676        UnaryOp::Abs => Some(if is_float { "fabs" } else { "abs" }),
677        UnaryOp::Floor => Some("floor"),
678        UnaryOp::Ceil => Some("ceil"),
679        UnaryOp::Trunc => Some("trunc"),
680        UnaryOp::Round => Some("rint"),
681        UnaryOp::Reciprocal => None,
682        UnaryOp::Tan => None,
683        UnaryOp::Sign => None,
684        UnaryOp::Erf => None,
685        UnaryOp::Square => None,
686    }
687}
688
689fn reduce_instr(op: ReduceOp, dtype: &DType) -> &'static str {
690    let is_float = dtype.is_float();
691    let is_signed = dtype.is_signed();
692
693    match op {
694        ReduceOp::Add => {
695            if is_float {
696                "fadd nsz arcp contract afn"
697            } else {
698                "add"
699            }
700        }
701        ReduceOp::Mul => {
702            if is_float {
703                "fmul nsz arcp contract afn"
704            } else {
705                "mul"
706            }
707        }
708        ReduceOp::Max => {
709            if is_float {
710                "maxnum"
711            } else if is_signed {
712                "smax"
713            } else {
714                "umax"
715            }
716        }
717        ReduceOp::Min => {
718            if is_float {
719                "minnum"
720            } else if is_signed {
721                "smin"
722            } else {
723                "umin"
724            }
725        }
726    }
727}
728
/// Convert an LLVM type spelling into the suffix used in overloaded
/// intrinsic names (e.g. `float` → `f32`, `<4 x float>` → `v4f32`).
///
/// * Floating-point scalars map to their mangled form: `half` → `f16`,
///   `bfloat` → `bf16`, `float` → `f32`, `double` → `f64`.
/// * A vector `<N x T>` becomes `v{N}{mangle(T)}`.
/// * Anything else (integer types such as `i32`, `ptr`, malformed vector
///   spellings) is returned unchanged.
fn mangle_type(llvm_type: &str) -> String {
    // Scalars whose intrinsic suffix differs from their IR type name.
    let scalar = match llvm_type {
        "float" => Some("f32"),
        "double" => Some("f64"),
        "half" => Some("f16"),
        "bfloat" => Some("bf16"),
        _ => None,
    };
    if let Some(mangled) = scalar {
        return mangled.to_string();
    }

    // `<N x T>`: strip the angle brackets, then split count from element type.
    let vector_parts = llvm_type
        .strip_prefix('<')
        .and_then(|rest| rest.strip_suffix('>'))
        .and_then(|inner| inner.split_once(" x "));

    match vector_parts {
        // Require exactly one " x " separator; anything else is passed through as-is.
        Some((count, base)) if !base.contains(" x ") => {
            format!("v{}{}", count.trim(), mangle_type(base.trim()))
        }
        _ => llvm_type.to_string(),
    }
}
752
753fn render_intrinsic(dst: &str, name: &str, args: &[(&str, &str)], ret_type: &str, kernel: &mut Vec<String>) {
754    let args_str: String = args.iter().map(|(ty, val)| format!("{ty} {val}")).collect::<Vec<_>>().join(", ");
755    let mangled = mangle_type(ret_type);
756    kernel.push(format!("  {dst} = call {ret_type} @llvm.{name}.{mangled}({args_str})"));
757}
758
759fn render_binary_max(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
760    if lhs.dtype().is_float() {
761        render_intrinsic(dst, "maxnum", &[(ltype, l), (ltype, r)], ltype, kernel);
762    } else {
763        let is_signed = lhs.dtype().is_signed();
764        let cmp = if is_signed { "sgt" } else { "ugt" };
765        let cmp_dst = format!("{dst}.cmp");
766        kernel.push(format!("  {cmp_dst} = icmp {cmp} {ltype} {l}, {r}"));
767        kernel.push(format!("  {dst} = select i1 {cmp_dst}, {ltype} {l}, {ltype} {r}"));
768    }
769}
770
771fn render_binary_pow(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
772    if lhs.dtype().is_float() {
773        render_intrinsic(dst, "pow", &[(ltype, l), (ltype, r)], ltype, kernel);
774    } else {
775        let l_float = format!("{dst}.lf");
776        let r_float = format!("{dst}.rf");
777        let pow_float = format!("{dst}.pf");
778        kernel.push(format!("  {l_float} = sitofp {ltype} {l} to double"));
779        kernel.push(format!("  {r_float} = sitofp {ltype} {r} to double"));
780        render_intrinsic(&pow_float, "pow", &[("double", &l_float), ("double", &r_float)], "double", kernel);
781        kernel.push(format!("  {dst} = fptosi double {pow_float} to {ltype}"));
782    }
783}
784
785fn render_reduce_minmax(
786    dst: &str,
787    op: ReduceOp,
788    dtype: &DType,
789    acc: &str,
790    val: &str,
791    ltype: &str,
792    kernel: &mut Vec<String>,
793) {
794    if dtype.is_float() {
795        let intrinsic = match op {
796            ReduceOp::Max => "maxnum",
797            ReduceOp::Min => "minnum",
798            _ => unreachable!(),
799        };
800        render_intrinsic(dst, intrinsic, &[(ltype, acc), (ltype, val)], ltype, kernel);
801    } else {
802        let is_signed = dtype.is_signed();
803        let cmp = match op {
804            ReduceOp::Max => {
805                if is_signed {
806                    "sgt"
807                } else {
808                    "ugt"
809                }
810            }
811            ReduceOp::Min => {
812                if is_signed {
813                    "slt"
814                } else {
815                    "ult"
816                }
817            }
818            _ => unreachable!(),
819        };
820        let cmp_dst = format!("{dst}.cmp");
821        kernel.push(format!("  {cmp_dst} = icmp {cmp} {ltype} {acc}, {val}"));
822        kernel.push(format!("  {dst} = select i1 {cmp_dst}, {ltype} {acc}, {ltype} {val}"));
823    }
824}
825
826fn render_multi_gep(
827    dst: &str,
828    vec: &str,
829    vec_dtype: &DType,
830    indices: &[usize],
831    out_type: &str,
832    kernel: &mut Vec<String>,
833) {
834    let vec_type = ldt(vec_dtype);
835
836    let elem_dtype = match vec_dtype {
837        DType::Ptr { base, addrspace, size, .. } => {
838            DType::Ptr { base: base.clone(), addrspace: *addrspace, size: *size, vcount: 1 }
839        }
840        DType::Vector { scalar, .. } => DType::Scalar(*scalar),
841        _ => DType::Scalar(vec_dtype.base()),
842    };
843    let elem_type = ldt(&elem_dtype);
844
845    for (i, &idx) in indices.iter().enumerate() {
846        let elem = format!("{dst}.e{i}");
847        kernel.push(format!("  {elem} = extractelement {vec_type} {vec}, i32 {idx}"));
848    }
849
850    if indices.len() == 1 {
851        kernel.push(format!("  {dst} = bitcast {elem_type} {dst}.e0 to {out_type}"));
852    } else {
853        let count = indices.len();
854        kernel.push(format!("  {dst}.undef = undef <{count} x {elem_type}>"));
855        let mut prev = format!("{dst}.undef");
856        for i in 0..count {
857            let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
858            kernel.push(format!(
859                "  {next} = insertelement <{count} x {elem_type}> {prev}, {elem_type} {dst}.e{i}, i32 {i}"
860            ));
861            prev = next;
862        }
863    }
864}
865
866fn render_vectorize(dst: &str, elements: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
867    if elements.is_empty() {
868        return;
869    }
870
871    let scalar_type = ldt(&elements[0].dtype());
872    let count = elements.len();
873    let vec_type = format!("<{count} x {scalar_type}>");
874
875    let mut prev = "undef".to_string();
876    for (i, elem) in elements.iter().enumerate() {
877        let val = ctx.get(elem);
878        let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
879        kernel.push(format!("  {next} = insertelement {vec_type} {prev}, {scalar_type} {val}, i32 {i}"));
880        prev = next;
881    }
882}
883
884fn render_cat(dst: &str, sources: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
885    if sources.is_empty() {
886        return;
887    }
888
889    let total_count: usize = sources.iter().map(|s| s.dtype().vcount()).sum();
890    let scalar_type = ldt(&sources[0].dtype().scalar_dtype());
891    let out_type = format!("<{total_count} x {scalar_type}>");
892
893    let mut out_idx = 0;
894    let mut prev = "undef".to_string();
895
896    for src in sources.iter() {
897        let src_val = ctx.get(src);
898        let src_count = src.dtype().vcount();
899
900        if src_count == 1 {
901            let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
902            kernel.push(format!("  {next} = insertelement {out_type} {prev}, {scalar_type} {src_val}, i32 {out_idx}"));
903            prev = next;
904            out_idx += 1;
905        } else {
906            let src_type = ldt(&src.dtype());
907            for i in 0..src_count {
908                let elem = format!("{dst}.e{out_idx}");
909                kernel.push(format!("  {elem} = extractelement {src_type} {src_val}, i32 {i}"));
910
911                let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
912                kernel.push(format!("  {next} = insertelement {out_type} {prev}, {scalar_type} {elem}, i32 {out_idx}"));
913                prev = next;
914                out_idx += 1;
915            }
916        }
917    }
918}
919
920/// Linearize multiple index expressions into a single linear offset at render time.
921///
922/// Emits LLVM IR `mul` + `add` chain for `idx0*stride0 + idx1*stride1 + ...`.
923/// Returns the final SSA name and its LLVM type string.
924fn render_linearize_multi_index(
925    dst: &str,
926    indices: &[Arc<UOp>],
927    ctx: &RenderContext,
928    kernel: &mut Vec<String>,
929) -> (String, String) {
930    use morok_schedule::passes::linearize_index::{compute_row_major_strides, extract_index_dimension};
931
932    // Extract dimensions from index UOps
933    let dims: Vec<i64> = indices
934        .iter()
935        .map(|idx| extract_index_dimension(idx).expect("multi-index dimension must be resolvable at codegen"))
936        .collect();
937    let strides = compute_row_major_strides(&dims);
938    let idx_type = ldt(&indices[0].dtype());
939
940    let mut current = String::new();
941    for (i, (idx_uop, &stride)) in indices.iter().zip(strides.iter()).enumerate() {
942        if stride == 0 {
943            continue;
944        }
945        let idx_val = ctx.get(idx_uop);
946        let term = if stride == 1 {
947            idx_val.to_string()
948        } else {
949            let mul_name = format!("{dst}.lin_mul{i}");
950            kernel.push(format!("  {mul_name} = mul {idx_type} {idx_val}, {stride}"));
951            mul_name
952        };
953
954        if current.is_empty() {
955            current = term;
956        } else {
957            let add_name = format!("{dst}.lin_add{i}");
958            kernel.push(format!("  {add_name} = add {idx_type} {current}, {term}"));
959            current = add_name;
960        }
961    }
962
963    if current.is_empty() {
964        current = "0".to_string();
965    }
966
967    (current, idx_type)
968}
969
970/// Get identity element for reduce operation.
971pub fn reduce_identity(op: ReduceOp, dtype: &DType) -> String {
972    let is_vector = matches!(dtype, DType::Vector { .. });
973
974    match op {
975        ReduceOp::Add => {
976            if is_vector {
977                "zeroinitializer".to_string()
978            } else if dtype.is_float() {
979                "0.0".to_string()
980            } else {
981                "0".to_string()
982            }
983        }
984        ReduceOp::Mul => {
985            if is_vector {
986                "zeroinitializer".to_string()
987            } else if dtype.is_float() {
988                "1.0".to_string()
989            } else {
990                "1".to_string()
991            }
992        }
993        ReduceOp::Max => {
994            if is_vector {
995                "zeroinitializer".to_string()
996            } else if dtype.is_float() {
997                "-0x7FF0000000000000".to_string()
998            } else if dtype.is_signed() {
999                i64::MIN.to_string()
1000            } else {
1001                "0".to_string()
1002            }
1003        }
1004        ReduceOp::Min => {
1005            if is_vector {
1006                "zeroinitializer".to_string() // TODO: proper +inf splat
1007            } else if dtype.is_float() {
1008                "0x7FF0000000000000".to_string() // +inf
1009            } else if dtype.is_signed() {
1010                i64::MAX.to_string()
1011            } else {
1012                u64::MAX.to_string()
1013            }
1014        }
1015    }
1016}