Skip to main content

morok_codegen/llvm/cpu/
ops.rs

1//! CPU-specific LLVM IR operation rendering.
2//!
3//! Generates LLVM IR strings for individual UOp operations on CPU.
4//! Based on Tinygrad's PatternMatcher templates in `llvmir.py`.
5
6use std::sync::Arc;
7
8use morok_dtype::DType;
9use morok_ir::{AxisType, BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
10
11use crate::llvm::common::{RenderContext, lcast, ldt};
12
13/// Extract a scalar `ptr` from a vectorized `<N x ptr>` via `extractelement ... i32 0`.
14///
15/// When the devectorize pipeline doesn't fully eliminate vectorized DEFINE_GLOBAL pointers
16/// (see `no_vectorized_buf` / `no_vectorized_index` which only target DEFINE_LOCAL/DEFINE_REG),
17/// the GEP result can be `<N x ptr>`. All elements are identical (broadcast of the same buffer
18/// pointer), so extracting element 0 yields the correct scalar ptr for LLVM load/store.
19fn maybe_extract_scalar_ptr(
20    dst: &str,
21    idx: &str,
22    idx_type: &str,
23    dtype: &DType,
24    kernel: &mut Vec<String>,
25) -> (String, String) {
26    if matches!(dtype, DType::Ptr { vcount, .. } if *vcount > 1) {
27        let extract = format!("{dst}.ptr");
28        kernel.push(format!("  {extract} = extractelement {idx_type} {idx}, i32 0"));
29        (extract, "ptr".to_string())
30    } else {
31        (idx.to_string(), idx_type.to_string())
32    }
33}
34
35/// Render a UOp to LLVM IR string.
36///
37/// Returns None for meta-ops that don't produce instructions.
38pub fn render_uop(uop: &Arc<UOp>, ctx: &mut RenderContext, kernel: &mut Vec<String>) -> Option<()> {
39    let dst = ctx.name(uop);
40
41    match uop.op() {
42        Op::Const(_)
43        | Op::VConst { .. }
44        | Op::DefineGlobal(_)
45        | Op::DefineVar { .. }
46        | Op::Noop
47        | Op::Sink { .. }
48        | Op::Group { .. }
49        | Op::Buffer { .. }
50        | Op::Unique(_)
51        | Op::Device(_)
52        | Op::Kernel { .. }
53        | Op::Barrier { .. } => None,
54
55        Op::DefineLocal(_) | Op::DefineReg { .. } => {
56            // Emit alloca for local/register memory.
57            // Read base type and size from dtype (matching Tinygrad's x.dtype.base/x.dtype.size).
58            // After devectorize's no_vectorized_buf, dtype is the canonical source of truth.
59            let (base_dtype, alloc_size) = match uop.dtype() {
60                DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
61                other => (other, 1),
62            };
63            let base = ldt(&base_dtype);
64            // Tinygrad: DEFINE_LOCAL gets align 16 (for SSE vector loads), DEFINE_REG gets default.
65            let align = if matches!(uop.op(), Op::DefineLocal(_)) { ", align 16" } else { "" };
66            kernel.push(format!("  {dst} = alloca [{alloc_size} x {base}]{align}"));
67            Some(())
68        }
69
70        Op::Index { buffer, indices, .. } => {
71            let buf = ctx.get(buffer);
72            let buf_type = ldt(&buffer.dtype());
73
74            if indices.is_empty() {
75                kernel.push(format!("  {dst} = bitcast {buf_type} {buf} to {}", ldt(&uop.dtype())));
76            } else {
77                // Multi-index: linearize at render time using row-major strides
78                let (final_idx, final_idx_type) = if indices.len() > 1 {
79                    render_linearize_multi_index(&dst, indices, ctx, kernel)
80                } else {
81                    (ctx.get(&indices[0]).to_string(), ldt(&indices[0].dtype()))
82                };
83
84                let elem_type = match uop.dtype() {
85                    morok_dtype::DType::Ptr { ref base, .. } => ldt(base),
86                    other => ldt(&other),
87                };
88
89                // Gate is NOT handled here — matching Tinygrad's approach where INDEX
90                // always emits a plain GEP. The gate is handled at LOAD level (branch+phi)
91                // and at STORE level (IF/ENDIF via line_rewrite_cleanups).
92                kernel.push(format!(
93                    "  {dst} = getelementptr inbounds {elem_type}, {buf_type} {buf}, {final_idx_type} {final_idx}"
94                ));
95            }
96            Some(())
97        }
98
99        Op::PointerIndex { ptr, offset } => {
100            let ptr_val = ctx.get(ptr);
101            let off_val = ctx.get(offset);
102            let elem_type = ldt(&uop.dtype());
103            let ptr_type = ldt(&ptr.dtype());
104            let off_type = ldt(&offset.dtype());
105
106            kernel.push(format!(
107                "  {dst} = getelementptr inbounds {elem_type}, {ptr_type} {ptr_val}, {off_type} {off_val}"
108            ));
109            Some(())
110        }
111
112        Op::Load { index, alt, .. } => {
113            let idx = ctx.get(index);
114            let dtype = ldt(&uop.dtype());
115            let idx_type = ldt(&index.dtype());
116
117            let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
118
119            // Gated LOAD: emit branch+phi to avoid null deref.
120            // Matches Tinygrad's pattern (llvmir.py:123-129) which requires BOTH
121            // a gated INDEX and an alt value on the LOAD. If gate exists without
122            // alt, that's a pipeline bug (line_rewrite_cleanups should provide it).
123            // Unwrap one CAST layer to find the INDEX gate (matches Tinygrad's .or_casted("idx")).
124            // The pipeline CAN produce CAST(INDEX) — devectorize handles this shape explicitly.
125            let actual_index = match index.op() {
126                Op::Cast { src, .. } => src,
127                _ => index,
128            };
129            let gate_info = if let Op::Index { gate: Some(gate_uop), .. } = actual_index.op() {
130                let alt_uop = alt.as_ref().expect(
131                    "gated LOAD without alt value — pipeline bug: \
132                     line_rewrite_cleanups should ensure alt is present for gated loads",
133                );
134                Some((ctx.get(gate_uop).to_string(), ctx.get(alt_uop).to_string()))
135            } else {
136                None
137            };
138
139            if let Some((gate, alt_val)) = gate_info {
140                let label_base = &dst[1..]; // strip leading %
141                let entry_label = format!("{label_base}_entry");
142                let load_label = format!("{label_base}_load");
143                let exit_label = format!("{label_base}_exit");
144                let load_val = format!("{dst}_yes");
145
146                kernel.push(format!("  br label %{entry_label}"));
147                kernel.push(format!("{entry_label}:"));
148                kernel.push(format!("  br i1 {gate}, label %{load_label}, label %{exit_label}"));
149                kernel.push(format!("{load_label}:"));
150                kernel.push(format!("  {load_val} = load {dtype}, {idx_type} {idx}"));
151                kernel.push(format!("  br label %{exit_label}"));
152                kernel.push(format!("{exit_label}:"));
153                kernel.push(format!("  {dst} = phi {dtype} [{load_val}, %{load_label}], [{alt_val}, %{entry_label}]"));
154            } else {
155                kernel.push(format!("  {dst} = load {dtype}, {idx_type} {idx}"));
156            }
157            Some(())
158        }
159
160        Op::Store { index, value, .. } => {
161            let idx = ctx.get(index);
162            let val = ctx.get(value);
163            let val_type = ldt(&value.dtype());
164            let idx_type = ldt(&index.dtype());
165
166            let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
167
168            kernel.push(format!("  store {val_type} {val}, {idx_type} {idx}"));
169            Some(())
170        }
171
172        Op::Binary(op, lhs, rhs) => {
173            let l = ctx.get(lhs);
174            let r = ctx.get(rhs);
175            let ltype = ldt(&lhs.dtype());
176            let rtype = ldt(&rhs.dtype());
177
178            // Debug: detect type mismatch (logged via tracing)
179            if ltype != rtype {
180                tracing::error!(
181                    uop_id = uop.id,
182                    uop_dtype = ?uop.dtype(),
183                    op = ?op,
184                    lhs_id = lhs.id,
185                    rhs_id = rhs.id,
186                    lhs_dtype = ?lhs.dtype(),
187                    rhs_dtype = ?rhs.dtype(),
188                    lhs_op = ?lhs.op().as_ref(),
189                    rhs_op = ?rhs.op().as_ref(),
190                    "Binary op type mismatch - lhs and rhs have different dtypes"
191                );
192            }
193
194            if matches!(op, BinaryOp::Max) {
195                render_binary_max(&dst, lhs, l, r, &ltype, kernel);
196            } else if matches!(op, BinaryOp::Pow) {
197                render_binary_pow(&dst, lhs, l, r, &ltype, kernel);
198            } else {
199                let instr = binary_instr(*op, &lhs.dtype());
200                kernel.push(format!("  {dst} = {instr} {ltype} {l}, {r}"));
201            }
202            Some(())
203        }
204
205        Op::Unary(op, src) => {
206            let s = ctx.get(src);
207            let stype = ldt(&src.dtype());
208
209            match op {
210                UnaryOp::Neg => {
211                    if src.dtype().is_float() {
212                        kernel.push(format!("  {dst} = fneg {stype} {s}"));
213                    } else {
214                        kernel.push(format!("  {dst} = sub {stype} 0, {s}"));
215                    }
216                }
217                UnaryOp::Not => {
218                    let all_ones = if src.dtype().is_bool() { "1".to_string() } else { "-1".to_string() };
219                    kernel.push(format!("  {dst} = xor {stype} {s}, {all_ones}"));
220                }
221                UnaryOp::Floor | UnaryOp::Ceil | UnaryOp::Trunc | UnaryOp::Round if !src.dtype().is_float() => {
222                    // Rounding is identity for integer types (defense-in-depth;
223                    // symbolic_simple folds these away upstream).
224                    kernel.push(format!("  {dst} = bitcast {stype} {s} to {stype}"));
225                }
226                UnaryOp::Sqrt
227                | UnaryOp::Exp
228                | UnaryOp::Exp2
229                | UnaryOp::Log
230                | UnaryOp::Log2
231                | UnaryOp::Sin
232                | UnaryOp::Cos
233                | UnaryOp::Floor
234                | UnaryOp::Ceil
235                | UnaryOp::Trunc
236                | UnaryOp::Round => {
237                    let intrinsic = unary_instr(*op, &src.dtype()).unwrap();
238                    render_intrinsic(&dst, intrinsic, &[(&stype, s)], &stype, kernel);
239                }
240                UnaryOp::Abs => {
241                    if src.dtype().is_float() {
242                        render_intrinsic(&dst, "fabs", &[(&stype, s)], &stype, kernel);
243                    } else {
244                        render_intrinsic(&dst, "abs", &[(&stype, s), ("i1", "1")], &stype, kernel);
245                    }
246                }
247                UnaryOp::Rsqrt => {
248                    let sqrt_dst = format!("{dst}.sqrt");
249                    render_intrinsic(&sqrt_dst, "sqrt", &[(&stype, s)], &stype, kernel);
250                    kernel.push(format!("  {dst} = fdiv nsz arcp contract afn {stype} 1.0, {sqrt_dst}"));
251                }
252                UnaryOp::Reciprocal => {
253                    kernel.push(format!("  {dst} = fdiv nsz arcp contract afn {stype} 1.0, {s}"));
254                }
255                UnaryOp::Tan => {
256                    let sin_dst = format!("{dst}.sin");
257                    let cos_dst = format!("{dst}.cos");
258                    render_intrinsic(&sin_dst, "sin", &[(&stype, s)], &stype, kernel);
259                    render_intrinsic(&cos_dst, "cos", &[(&stype, s)], &stype, kernel);
260                    kernel.push(format!("  {dst} = fdiv nsz arcp contract afn {stype} {sin_dst}, {cos_dst}"));
261                }
262                UnaryOp::Sign => {
263                    if src.dtype().is_float() {
264                        let gt_zero = format!("{dst}.gt");
265                        let lt_zero = format!("{dst}.lt");
266                        let gt_ext = format!("{dst}.gt_ext");
267                        let lt_ext = format!("{dst}.lt_ext");
268                        kernel.push(format!("  {gt_zero} = fcmp nsz arcp contract afn ogt {stype} {s}, 0.0"));
269                        kernel.push(format!("  {lt_zero} = fcmp nsz arcp contract afn olt {stype} {s}, 0.0"));
270                        kernel.push(format!("  {gt_ext} = uitofp i1 {gt_zero} to {stype}"));
271                        kernel.push(format!("  {lt_ext} = uitofp i1 {lt_zero} to {stype}"));
272                        kernel.push(format!("  {dst} = fsub nsz arcp contract afn {stype} {gt_ext}, {lt_ext}"));
273                    } else {
274                        let is_signed = src.dtype().is_signed();
275                        let cmp = if is_signed { "sgt" } else { "ugt" };
276                        let cmp_lt = if is_signed { "slt" } else { "icmp eq" };
277                        let gt_zero = format!("{dst}.gt");
278                        let lt_zero = format!("{dst}.lt");
279                        let gt_ext = format!("{dst}.gt_ext");
280                        let lt_ext = format!("{dst}.lt_ext");
281                        kernel.push(format!("  {gt_zero} = icmp {cmp} {stype} {s}, 0"));
282                        if is_signed {
283                            kernel.push(format!("  {lt_zero} = icmp {cmp_lt} {stype} {s}, 0"));
284                        } else {
285                            kernel.push(format!("  {lt_zero} = icmp eq {stype} {s}, 0"));
286                            kernel.push(format!("  {lt_zero} = xor i1 {lt_zero}, 1"));
287                            kernel.push(format!("  {lt_zero} = and i1 {lt_zero}, 0"));
288                        }
289                        kernel.push(format!("  {gt_ext} = zext i1 {gt_zero} to {stype}"));
290                        kernel.push(format!("  {lt_ext} = zext i1 {lt_zero} to {stype}"));
291                        kernel.push(format!("  {dst} = sub {stype} {gt_ext}, {lt_ext}"));
292                    }
293                }
294                UnaryOp::Erf => {
295                    render_intrinsic(&dst, "erf", &[(&stype, s)], &stype, kernel);
296                }
297                UnaryOp::Square => {
298                    if src.dtype().is_float() {
299                        kernel.push(format!("  {dst} = fmul nsz arcp contract afn {stype} {s}, {s}"));
300                    } else {
301                        kernel.push(format!("  {dst} = mul {stype} {s}, {s}"));
302                    }
303                }
304            }
305            Some(())
306        }
307
308        Op::Ternary(TernaryOp::Where, cond, t, f) => {
309            let c = ctx.get(cond);
310            let tv = ctx.get(t);
311            let fv = ctx.get(f);
312            kernel.push(format!(
313                "  {dst} = select {} {c}, {} {tv}, {} {fv}",
314                ldt(&cond.dtype()),
315                ldt(&t.dtype()),
316                ldt(&f.dtype())
317            ));
318            Some(())
319        }
320
321        Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
322            let av = ctx.get(a);
323            let bv = ctx.get(b);
324            let cv = ctx.get(c);
325            let dtype = ldt(&a.dtype());
326
327            if a.dtype().is_float() {
328                render_intrinsic(&dst, "fmuladd", &[(&dtype, av), (&dtype, bv), (&dtype, cv)], &dtype, kernel);
329            } else {
330                let mul_dst = format!("{dst}.mul");
331                kernel.push(format!("  {mul_dst} = mul {dtype} {av}, {bv}"));
332                kernel.push(format!("  {dst} = add {dtype} {mul_dst}, {cv}"));
333            }
334            Some(())
335        }
336
337        Op::Cast { src, dtype } => {
338            let s = ctx.get(src);
339
340            // INDEX always produces ptr in LLVM (via GEP), regardless of Morok dtype.
341            // When source is INDEX, treat source LLVM type as ptr for cast selection.
342            let is_index_src = matches!(src.op(), Op::Index { .. });
343            let src_llvm_type = if is_index_src { "ptr".to_string() } else { ldt(&src.dtype()) };
344            let dst_llvm_type = ldt(dtype);
345
346            // CAST(INDEX) to Ptr is a no-op - INDEX already produces ptr via GEP.
347            // This matches Tinygrad's approach (llvmir.py:189) where CAST to PtrDType
348            // is register aliasing: r[u] = r[u.src[0]]
349            if is_index_src && matches!(dtype, DType::Ptr { .. }) {
350                // Emit a bitcast as a named no-op to maintain SSA form
351                kernel.push(format!("  {dst} = bitcast ptr {s} to ptr"));
352                return Some(());
353            }
354
355            if src_llvm_type == dst_llvm_type {
356                kernel.push(format!("  {dst} = bitcast {src_llvm_type} {s} to {dst_llvm_type}"));
357            } else {
358                let cast_instr = lcast(&src.dtype(), dtype);
359                kernel.push(format!("  {dst} = {cast_instr} {src_llvm_type} {s} to {dst_llvm_type}"));
360            }
361            Some(())
362        }
363
364        Op::BitCast { src, dtype } => {
365            let s = ctx.get(src);
366            kernel.push(format!("  {dst} = bitcast {} {s} to {}", ldt(&src.dtype()), ldt(dtype)));
367            Some(())
368        }
369
370        Op::Range { end, axis_id, .. } => {
371            let end_val = ctx.get(end);
372            let id = axis_id.value();
373            let dtype = ldt(&uop.dtype());
374
375            kernel.push(format!("  br label %loop_entry_{id}"));
376            kernel.push(format!("loop_entry_{id}:"));
377            kernel.push(format!("  br label %loop_latch_{id}"));
378            kernel.push(format!("loop_latch_{id}:"));
379            kernel.push(format!("  {dst} = phi {dtype} [ 0, %loop_entry_{id} ], [ {dst}phi, %loop_footer_{id} ]"));
380            kernel.push(format!("  {dst}phi = add {dtype} {dst}, 1"));
381            kernel.push(format!("  {dst}cmp = icmp ult {dtype} {dst}, {end_val}"));
382            kernel.push(format!("  br i1 {dst}cmp, label %loop_body_{id}, label %loop_exit_{id}"));
383            kernel.push(format!("loop_body_{id}:"));
384            Some(())
385        }
386
387        Op::End { ranges, .. } => {
388            for range in ranges.iter() {
389                if let Op::Range { axis_id, axis_type, .. } = range.op() {
390                    if matches!(axis_type, AxisType::Thread) {
391                        continue;
392                    }
393                    let id = axis_id.value();
394                    kernel.push(format!("  br label %loop_footer_{id}"));
395                    kernel.push(format!("loop_footer_{id}:"));
396                    kernel.push(format!("  br label %loop_latch_{id}"));
397                    kernel.push(format!("loop_exit_{id}:"));
398                }
399            }
400
401            let pending = ctx.take_pending_reduces();
402            for (reduce_id, info) in pending {
403                let result_name = format!("%reduce_{reduce_id}.final");
404                kernel.push(format!("  {result_name} = load {}, ptr {}", info.dtype, info.acc_ptr));
405                ctx.register(reduce_id, result_name);
406            }
407            Some(())
408        }
409
410        Op::Reduce { src, ranges, reduce_op } => {
411            let src_val = ctx.get(src);
412            let dtype = ldt(&uop.dtype());
413
414            if ranges.is_empty() {
415                kernel.push(format!("  {dst} = bitcast {dtype} {src_val} to {dtype}"));
416            } else {
417                let acc_ptr = format!("%reduce_{}", uop.id);
418                let acc_load = format!("{acc_ptr}.load");
419                let acc_new = format!("{acc_ptr}.new");
420                let instr = reduce_instr(*reduce_op, &uop.dtype());
421
422                kernel.push(format!("  {acc_load} = load {dtype}, ptr {acc_ptr}"));
423
424                if matches!(reduce_op, ReduceOp::Max | ReduceOp::Min) {
425                    render_reduce_minmax(&acc_new, *reduce_op, &uop.dtype(), &acc_load, src_val, &dtype, kernel);
426                } else {
427                    kernel.push(format!("  {acc_new} = {instr} {dtype} {acc_load}, {src_val}"));
428                }
429
430                kernel.push(format!("  store {dtype} {acc_new}, ptr {acc_ptr}"));
431                ctx.register_reduce_pending(uop.id, acc_ptr.clone(), dtype.clone());
432            }
433            Some(())
434        }
435
436        Op::Gep { vector, indices } => {
437            let vec = ctx.get(vector);
438            let vec_type = ldt(&vector.dtype());
439            let out_type = ldt(&uop.dtype());
440
441            if indices.len() == 1 {
442                kernel.push(format!("  {dst} = extractelement {vec_type} {vec}, i32 {}", indices[0]));
443            } else {
444                render_multi_gep(&dst, vec, &vector.dtype(), indices, &out_type, kernel);
445            }
446            Some(())
447        }
448
449        Op::Vectorize { elements } => {
450            render_vectorize(&dst, elements, ctx, kernel);
451            Some(())
452        }
453
454        Op::Cat { sources } => {
455            render_cat(&dst, sources, ctx, kernel);
456            Some(())
457        }
458
459        Op::PtrCat { sources } => {
460            render_ptrcat(&dst, sources, ctx, kernel);
461            Some(())
462        }
463
464        Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
465            let s = ctx.get(src);
466            ctx.alias(uop.id, s.to_string());
467            None
468        }
469
470        Op::After { passthrough, .. } => {
471            #[cfg(debug_assertions)]
472            if matches!(passthrough.op(), Op::Range { .. }) {
473                panic!("AFTER passthrough is Range (id={}), this violates Tinygrad semantics", passthrough.id);
474            }
475            let s = ctx.get(passthrough);
476            ctx.alias(uop.id, s.to_string());
477            None
478        }
479
480        Op::Bind { var, value } => {
481            let v = ctx.get(value);
482            ctx.alias(var.id, v.to_string());
483            None
484        }
485
486        Op::If { condition, .. } => {
487            let cond = ctx.get(condition);
488            let if_id = uop.id;
489            kernel.push(format!("  br i1 {cond}, label %if_then_{if_id}, label %if_end_{if_id}"));
490            kernel.push(format!("if_then_{if_id}:"));
491            Some(())
492        }
493
494        Op::EndIf { if_op } => {
495            let if_id = if_op.id;
496            kernel.push(format!("  br label %if_end_{if_id}"));
497            kernel.push(format!("if_end_{if_id}:"));
498            Some(())
499        }
500
501        op if op.is_movement() => {
502            panic!(
503                "movement op {:?} (id={}) reached LLVM codegen — \
504                 should have been eliminated during rangeify. \
505                 This indicates a bug in remove_movement_op or apply_bufferize_transform.",
506                std::mem::discriminant(op),
507                uop.id,
508            );
509        }
510
511        _ => {
512            kernel.push(format!("; UNSUPPORTED: {:?}", uop.op()));
513            None
514        }
515    }
516}
517
518fn binary_instr(op: BinaryOp, dtype: &DType) -> &'static str {
519    assert!(
520        !matches!(dtype.base(), morok_dtype::ScalarDType::Index),
521        "Index dtype reached LLVM codegen binary_instr({op:?}, {dtype:?}) — \
522         pm_lower_index_dtype should have lowered it to i32/i64"
523    );
524    let is_float = dtype.is_float();
525    let is_signed = dtype.is_signed();
526
527    match op {
528        BinaryOp::Add => {
529            if is_float {
530                "fadd nsz arcp contract afn"
531            } else if is_signed {
532                "add nsw"
533            } else {
534                "add"
535            }
536        }
537        BinaryOp::Mul => {
538            if is_float {
539                "fmul nsz arcp contract afn"
540            } else {
541                "mul"
542            }
543        }
544        BinaryOp::Sub => {
545            if is_float {
546                "fsub nsz arcp contract afn"
547            } else {
548                "sub"
549            }
550        }
551        BinaryOp::Fdiv => "fdiv nsz arcp contract afn",
552        BinaryOp::Idiv => {
553            if is_signed {
554                "sdiv"
555            } else {
556                "udiv"
557            }
558        }
559        BinaryOp::Mod => {
560            if is_float {
561                "frem nsz arcp contract afn"
562            } else if is_signed {
563                "srem"
564            } else {
565                "urem"
566            }
567        }
568        BinaryOp::Max => {
569            if is_float {
570                "maxnum"
571            } else if is_signed {
572                "smax"
573            } else {
574                "umax"
575            }
576        }
577        BinaryOp::Lt => {
578            if is_float {
579                "fcmp nsz arcp contract afn ult"
580            } else if is_signed {
581                "icmp slt"
582            } else {
583                "icmp ult"
584            }
585        }
586        BinaryOp::Le => {
587            if is_float {
588                "fcmp nsz arcp contract afn ule"
589            } else if is_signed {
590                "icmp sle"
591            } else {
592                "icmp ule"
593            }
594        }
595        BinaryOp::Gt => {
596            if is_float {
597                "fcmp nsz arcp contract afn ugt"
598            } else if is_signed {
599                "icmp sgt"
600            } else {
601                "icmp ugt"
602            }
603        }
604        BinaryOp::Ge => {
605            if is_float {
606                "fcmp nsz arcp contract afn uge"
607            } else if is_signed {
608                "icmp sge"
609            } else {
610                "icmp uge"
611            }
612        }
613        BinaryOp::Eq => {
614            if is_float {
615                "fcmp nsz arcp contract afn oeq"
616            } else {
617                "icmp eq"
618            }
619        }
620        BinaryOp::Ne => {
621            if is_float {
622                "fcmp nsz arcp contract afn une"
623            } else {
624                "icmp ne"
625            }
626        }
627        BinaryOp::And => "and",
628        BinaryOp::Or => "or",
629        BinaryOp::Xor => "xor",
630        BinaryOp::Shl => "shl",
631        BinaryOp::Shr => {
632            if is_signed {
633                "ashr"
634            } else {
635                "lshr"
636            }
637        }
638        BinaryOp::Pow => "pow",
639        BinaryOp::Threefry => "xor",
640    }
641}
642
643fn unary_instr(op: UnaryOp, dtype: &DType) -> Option<&'static str> {
644    let is_float = dtype.is_float();
645
646    match op {
647        UnaryOp::Neg => Some(if is_float { "fneg" } else { "sub" }),
648        UnaryOp::Not => Some("xor"),
649        UnaryOp::Sqrt => Some("sqrt"),
650        UnaryOp::Rsqrt => None,
651        UnaryOp::Exp => Some("exp"),
652        UnaryOp::Exp2 => Some("exp2"),
653        UnaryOp::Log => Some("log"),
654        UnaryOp::Log2 => Some("log2"),
655        UnaryOp::Sin => Some("sin"),
656        UnaryOp::Cos => Some("cos"),
657        UnaryOp::Abs => Some(if is_float { "fabs" } else { "abs" }),
658        UnaryOp::Floor => Some("floor"),
659        UnaryOp::Ceil => Some("ceil"),
660        UnaryOp::Trunc => Some("trunc"),
661        UnaryOp::Round => Some("rint"),
662        UnaryOp::Reciprocal => None,
663        UnaryOp::Tan => None,
664        UnaryOp::Sign => None,
665        UnaryOp::Erf => None,
666        UnaryOp::Square => None,
667    }
668}
669
670fn reduce_instr(op: ReduceOp, dtype: &DType) -> &'static str {
671    let is_float = dtype.is_float();
672    let is_signed = dtype.is_signed();
673
674    match op {
675        ReduceOp::Add => {
676            if is_float {
677                "fadd nsz arcp contract afn"
678            } else {
679                "add"
680            }
681        }
682        ReduceOp::Mul => {
683            if is_float {
684                "fmul nsz arcp contract afn"
685            } else {
686                "mul"
687            }
688        }
689        ReduceOp::Max => {
690            if is_float {
691                "maxnum"
692            } else if is_signed {
693                "smax"
694            } else {
695                "umax"
696            }
697        }
698        ReduceOp::Min => {
699            if is_float {
700                "minnum"
701            } else if is_signed {
702                "smin"
703            } else {
704                "umin"
705            }
706        }
707    }
708}
709
/// Mangle an LLVM IR type into the suffix used by overloaded intrinsic names,
/// e.g. `float` -> `f32`, `<4 x float>` -> `v4f32` (as in `@llvm.sqrt.v4f32`).
///
/// Floating-point types map to their `fN`/`bf16` spellings, `ptr` maps to `p0`,
/// and fixed-width vectors `<N x T>` map to `vN` followed by the mangled element
/// type. Integer types (`iN`) and anything unrecognized are returned unchanged.
fn mangle_type(llvm_type: &str) -> String {
    if let Some(inner) = llvm_type.strip_prefix('<').and_then(|t| t.strip_suffix('>')) {
        // Only the simple `<N x T>` form is understood; anything else passes through.
        let parts: Vec<&str> = inner.split(" x ").collect();
        return match parts.as_slice() {
            [count, elem] => format!("v{}{}", count.trim(), mangle_type(elem.trim())),
            _ => llvm_type.to_string(),
        };
    }

    match llvm_type {
        "half" => "f16",
        // LLVM spells the bfloat16 IR type `bfloat` but mangles it as `bf16`.
        "bfloat" => "bf16",
        "float" => "f32",
        "double" => "f64",
        "fp128" => "f128",
        "x86_fp80" => "f80",
        // Opaque pointer in the default address space.
        "ptr" => "p0",
        // `iN` integers already use their mangled spelling.
        other => other,
    }
    .to_string()
}
733
734fn render_intrinsic(dst: &str, name: &str, args: &[(&str, &str)], ret_type: &str, kernel: &mut Vec<String>) {
735    let args_str: String = args.iter().map(|(ty, val)| format!("{ty} {val}")).collect::<Vec<_>>().join(", ");
736    let mangled = mangle_type(ret_type);
737    kernel.push(format!("  {dst} = call {ret_type} @llvm.{name}.{mangled}({args_str})"));
738}
739
740fn render_binary_max(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
741    if lhs.dtype().is_float() {
742        render_intrinsic(dst, "maxnum", &[(ltype, l), (ltype, r)], ltype, kernel);
743    } else {
744        let is_signed = lhs.dtype().is_signed();
745        let cmp = if is_signed { "sgt" } else { "ugt" };
746        let cmp_dst = format!("{dst}.cmp");
747        kernel.push(format!("  {cmp_dst} = icmp {cmp} {ltype} {l}, {r}"));
748        kernel.push(format!("  {dst} = select i1 {cmp_dst}, {ltype} {l}, {ltype} {r}"));
749    }
750}
751
752fn render_binary_pow(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
753    if lhs.dtype().is_float() {
754        render_intrinsic(dst, "pow", &[(ltype, l), (ltype, r)], ltype, kernel);
755    } else {
756        let l_float = format!("{dst}.lf");
757        let r_float = format!("{dst}.rf");
758        let pow_float = format!("{dst}.pf");
759        kernel.push(format!("  {l_float} = sitofp {ltype} {l} to double"));
760        kernel.push(format!("  {r_float} = sitofp {ltype} {r} to double"));
761        render_intrinsic(&pow_float, "pow", &[("double", &l_float), ("double", &r_float)], "double", kernel);
762        kernel.push(format!("  {dst} = fptosi double {pow_float} to {ltype}"));
763    }
764}
765
766fn render_reduce_minmax(
767    dst: &str,
768    op: ReduceOp,
769    dtype: &DType,
770    acc: &str,
771    val: &str,
772    ltype: &str,
773    kernel: &mut Vec<String>,
774) {
775    if dtype.is_float() {
776        let intrinsic = match op {
777            ReduceOp::Max => "maxnum",
778            ReduceOp::Min => "minnum",
779            _ => unreachable!(),
780        };
781        render_intrinsic(dst, intrinsic, &[(ltype, acc), (ltype, val)], ltype, kernel);
782    } else {
783        let is_signed = dtype.is_signed();
784        let cmp = match op {
785            ReduceOp::Max => {
786                if is_signed {
787                    "sgt"
788                } else {
789                    "ugt"
790                }
791            }
792            ReduceOp::Min => {
793                if is_signed {
794                    "slt"
795                } else {
796                    "ult"
797                }
798            }
799            _ => unreachable!(),
800        };
801        let cmp_dst = format!("{dst}.cmp");
802        kernel.push(format!("  {cmp_dst} = icmp {cmp} {ltype} {acc}, {val}"));
803        kernel.push(format!("  {dst} = select i1 {cmp_dst}, {ltype} {acc}, {ltype} {val}"));
804    }
805}
806
807fn render_multi_gep(
808    dst: &str,
809    vec: &str,
810    vec_dtype: &DType,
811    indices: &[usize],
812    out_type: &str,
813    kernel: &mut Vec<String>,
814) {
815    let vec_type = ldt(vec_dtype);
816
817    let elem_dtype = match vec_dtype {
818        DType::Ptr { base, addrspace, size, .. } => {
819            DType::Ptr { base: base.clone(), addrspace: *addrspace, size: *size, vcount: 1 }
820        }
821        DType::Vector { scalar, .. } => DType::Scalar(*scalar),
822        _ => DType::Scalar(vec_dtype.base()),
823    };
824    let elem_type = ldt(&elem_dtype);
825
826    for (i, &idx) in indices.iter().enumerate() {
827        let elem = format!("{dst}.e{i}");
828        kernel.push(format!("  {elem} = extractelement {vec_type} {vec}, i32 {idx}"));
829    }
830
831    if indices.len() == 1 {
832        kernel.push(format!("  {dst} = bitcast {elem_type} {dst}.e0 to {out_type}"));
833    } else {
834        let count = indices.len();
835        kernel.push(format!("  {dst}.undef = undef <{count} x {elem_type}>"));
836        let mut prev = format!("{dst}.undef");
837        for i in 0..count {
838            let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
839            kernel.push(format!(
840                "  {next} = insertelement <{count} x {elem_type}> {prev}, {elem_type} {dst}.e{i}, i32 {i}"
841            ));
842            prev = next;
843        }
844    }
845}
846
847fn render_vectorize(dst: &str, elements: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
848    if elements.is_empty() {
849        return;
850    }
851
852    let scalar_type = ldt(&elements[0].dtype());
853    let count = elements.len();
854    let vec_type = format!("<{count} x {scalar_type}>");
855
856    let mut prev = "undef".to_string();
857    for (i, elem) in elements.iter().enumerate() {
858        let val = ctx.get(elem);
859        let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
860        kernel.push(format!("  {next} = insertelement {vec_type} {prev}, {scalar_type} {val}, i32 {i}"));
861        prev = next;
862    }
863}
864
865fn render_cat(dst: &str, sources: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
866    if sources.is_empty() {
867        return;
868    }
869
870    let total_count: usize = sources.iter().map(|s| s.dtype().vcount()).sum();
871    let scalar_type = ldt(&sources[0].dtype().scalar_dtype());
872    let out_type = format!("<{total_count} x {scalar_type}>");
873
874    let mut out_idx = 0;
875    let mut prev = "undef".to_string();
876
877    for src in sources.iter() {
878        let src_val = ctx.get(src);
879        let src_count = src.dtype().vcount();
880
881        if src_count == 1 {
882            let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
883            kernel.push(format!("  {next} = insertelement {out_type} {prev}, {scalar_type} {src_val}, i32 {out_idx}"));
884            prev = next;
885            out_idx += 1;
886        } else {
887            let src_type = ldt(&src.dtype());
888            for i in 0..src_count {
889                let elem = format!("{dst}.e{out_idx}");
890                kernel.push(format!("  {elem} = extractelement {src_type} {src_val}, i32 {i}"));
891
892                let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
893                kernel.push(format!("  {next} = insertelement {out_type} {prev}, {scalar_type} {elem}, i32 {out_idx}"));
894                prev = next;
895                out_idx += 1;
896            }
897        }
898    }
899}
900
901fn render_ptrcat(dst: &str, sources: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
902    if sources.is_empty() {
903        return;
904    }
905
906    let count = sources.len();
907    let ptr_type = ldt(&sources[0].dtype());
908    let vec_type = format!("<{count} x {ptr_type}>");
909
910    let mut prev = "undef".to_string();
911    for (i, src) in sources.iter().enumerate() {
912        let val = ctx.get(src);
913        let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.p{i}") };
914        kernel.push(format!("  {next} = insertelement {vec_type} {prev}, {ptr_type} {val}, i32 {i}"));
915        prev = next;
916    }
917}
918
919/// Linearize multiple index expressions into a single linear offset at render time.
920///
921/// Emits LLVM IR `mul` + `add` chain for `idx0*stride0 + idx1*stride1 + ...`.
922/// Returns the final SSA name and its LLVM type string.
923fn render_linearize_multi_index(
924    dst: &str,
925    indices: &[Arc<UOp>],
926    ctx: &RenderContext,
927    kernel: &mut Vec<String>,
928) -> (String, String) {
929    use morok_schedule::passes::linearize_index::{compute_row_major_strides, extract_index_dimension};
930
931    // Extract dimensions from index UOps
932    let dims: Vec<i64> = indices
933        .iter()
934        .map(|idx| extract_index_dimension(idx).expect("multi-index dimension must be resolvable at codegen"))
935        .collect();
936    let strides = compute_row_major_strides(&dims);
937    let idx_type = ldt(&indices[0].dtype());
938
939    let mut current = String::new();
940    for (i, (idx_uop, &stride)) in indices.iter().zip(strides.iter()).enumerate() {
941        if stride == 0 {
942            continue;
943        }
944        let idx_val = ctx.get(idx_uop);
945        let term = if stride == 1 {
946            idx_val.to_string()
947        } else {
948            let mul_name = format!("{dst}.lin_mul{i}");
949            kernel.push(format!("  {mul_name} = mul {idx_type} {idx_val}, {stride}"));
950            mul_name
951        };
952
953        if current.is_empty() {
954            current = term;
955        } else {
956            let add_name = format!("{dst}.lin_add{i}");
957            kernel.push(format!("  {add_name} = add {idx_type} {current}, {term}"));
958            current = add_name;
959        }
960    }
961
962    if current.is_empty() {
963        current = "0".to_string();
964    }
965
966    (current, idx_type)
967}
968
969/// Get identity element for reduce operation.
970pub fn reduce_identity(op: ReduceOp, dtype: &DType) -> String {
971    let is_vector = matches!(dtype, DType::Vector { .. });
972
973    match op {
974        ReduceOp::Add => {
975            if is_vector {
976                "zeroinitializer".to_string()
977            } else if dtype.is_float() {
978                "0.0".to_string()
979            } else {
980                "0".to_string()
981            }
982        }
983        ReduceOp::Mul => {
984            if is_vector {
985                "zeroinitializer".to_string()
986            } else if dtype.is_float() {
987                "1.0".to_string()
988            } else {
989                "1".to_string()
990            }
991        }
992        ReduceOp::Max => {
993            if is_vector {
994                "zeroinitializer".to_string()
995            } else if dtype.is_float() {
996                "-0x7FF0000000000000".to_string()
997            } else if dtype.is_signed() {
998                i64::MIN.to_string()
999            } else {
1000                "0".to_string()
1001            }
1002        }
1003        ReduceOp::Min => {
1004            if is_vector {
1005                "zeroinitializer".to_string() // TODO: proper +inf splat
1006            } else if dtype.is_float() {
1007                "0x7FF0000000000000".to_string() // +inf
1008            } else if dtype.is_signed() {
1009                i64::MAX.to_string()
1010            } else {
1011                u64::MAX.to_string()
1012            }
1013        }
1014    }
1015}