1use std::sync::Arc;
7
8use morok_dtype::DType;
9use morok_ir::{AxisType, BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
10
11use crate::llvm::common::{RenderContext, lcast, ldt};
12
13fn maybe_extract_scalar_ptr(
20 dst: &str,
21 idx: &str,
22 idx_type: &str,
23 dtype: &DType,
24 kernel: &mut Vec<String>,
25) -> (String, String) {
26 if matches!(dtype, DType::Ptr { vcount, .. } if *vcount > 1) {
27 let extract = format!("{dst}.ptr");
28 kernel.push(format!(" {extract} = extractelement {idx_type} {idx}, i32 0"));
29 (extract, "ptr".to_string())
30 } else {
31 (idx.to_string(), idx_type.to_string())
32 }
33}
34
35pub fn render_uop(uop: &Arc<UOp>, ctx: &mut RenderContext, kernel: &mut Vec<String>) -> Option<()> {
39 let dst = ctx.name(uop);
40
41 match uop.op() {
42 Op::Const(_)
43 | Op::VConst { .. }
44 | Op::Param { device: None, .. }
45 | Op::DefineVar { .. }
46 | Op::Noop
47 | Op::Sink { .. }
48 | Op::Group { .. }
49 | Op::Buffer { .. }
50 | Op::Unique(_)
51 | Op::Device(_)
52 | Op::Kernel { .. }
53 | Op::Barrier { .. } => None,
54
55 Op::DefineLocal(_) | Op::DefineReg { .. } => {
56 let (base_dtype, alloc_size) = match uop.dtype() {
60 DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
61 other => (other, 1),
62 };
63 let base = ldt(&base_dtype);
64 let align = if matches!(uop.op(), Op::DefineLocal(_)) { ", align 16" } else { "" };
66 kernel.push(format!(" {dst} = alloca [{alloc_size} x {base}]{align}"));
67 Some(())
68 }
69
70 Op::Index { buffer, indices, .. } => {
71 let buf = ctx.get(buffer);
72 let buf_type = ldt(&buffer.dtype());
73
74 if indices.is_empty() {
75 kernel.push(format!(" {dst} = bitcast {buf_type} {buf} to {}", ldt(&uop.dtype())));
76 } else {
77 let (final_idx, final_idx_type) = if indices.len() > 1 {
79 render_linearize_multi_index(&dst, indices, ctx, kernel)
80 } else {
81 (ctx.get(&indices[0]).to_string(), ldt(&indices[0].dtype()))
82 };
83
84 let elem_type = match uop.dtype() {
85 morok_dtype::DType::Ptr { ref base, .. } => ldt(base),
86 other => ldt(&other),
87 };
88
89 kernel.push(format!(
93 " {dst} = getelementptr inbounds {elem_type}, {buf_type} {buf}, {final_idx_type} {final_idx}"
94 ));
95 }
96 Some(())
97 }
98
99 Op::PointerIndex { ptr, offset } => {
100 let ptr_val = ctx.get(ptr);
101 let off_val = ctx.get(offset);
102 let elem_type = ldt(&uop.dtype());
103 let ptr_type = ldt(&ptr.dtype());
104 let off_type = ldt(&offset.dtype());
105
106 kernel.push(format!(
107 " {dst} = getelementptr inbounds {elem_type}, {ptr_type} {ptr_val}, {off_type} {off_val}"
108 ));
109 Some(())
110 }
111
112 Op::Load { index, alt, .. } => {
113 let idx = ctx.get(index);
114 let dtype = ldt(&uop.dtype());
115 let idx_type = ldt(&index.dtype());
116
117 let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
118
119 let actual_index = match index.op() {
126 Op::Cast { src, .. } => src,
127 _ => index,
128 };
129 let gate_info = if let Op::Index { gate: Some(gate_uop), .. } = actual_index.op() {
130 let alt_uop = alt.as_ref().expect(
131 "gated LOAD without alt value — pipeline bug: \
132 line_rewrite_cleanups should ensure alt is present for gated loads",
133 );
134 Some((ctx.get(gate_uop).to_string(), ctx.get(alt_uop).to_string()))
135 } else {
136 None
137 };
138
139 if let Some((gate, alt_val)) = gate_info {
140 let label_base = &dst[1..]; let entry_label = format!("{label_base}_entry");
142 let load_label = format!("{label_base}_load");
143 let exit_label = format!("{label_base}_exit");
144 let load_val = format!("{dst}_yes");
145
146 kernel.push(format!(" br label %{entry_label}"));
147 kernel.push(format!("{entry_label}:"));
148 kernel.push(format!(" br i1 {gate}, label %{load_label}, label %{exit_label}"));
149 kernel.push(format!("{load_label}:"));
150 kernel.push(format!(" {load_val} = load {dtype}, {idx_type} {idx}"));
151 kernel.push(format!(" br label %{exit_label}"));
152 kernel.push(format!("{exit_label}:"));
153 kernel.push(format!(" {dst} = phi {dtype} [{load_val}, %{load_label}], [{alt_val}, %{entry_label}]"));
154 } else {
155 kernel.push(format!(" {dst} = load {dtype}, {idx_type} {idx}"));
156 }
157 Some(())
158 }
159
160 Op::Store { index, value, .. } => {
161 let idx = ctx.get(index);
162 let val = ctx.get(value);
163 let val_type = ldt(&value.dtype());
164 let idx_type = ldt(&index.dtype());
165
166 let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
167
168 kernel.push(format!(" store {val_type} {val}, {idx_type} {idx}"));
169 Some(())
170 }
171
172 Op::Binary(op, lhs, rhs) => {
173 let l = ctx.get(lhs);
174 let r = ctx.get(rhs);
175 let ltype = ldt(&lhs.dtype());
176 let rtype = ldt(&rhs.dtype());
177
178 if ltype != rtype {
180 tracing::error!(
181 uop_id = uop.id,
182 uop_dtype = ?uop.dtype(),
183 op = ?op,
184 lhs_id = lhs.id,
185 rhs_id = rhs.id,
186 lhs_dtype = ?lhs.dtype(),
187 rhs_dtype = ?rhs.dtype(),
188 lhs_op = ?lhs.op().as_ref(),
189 rhs_op = ?rhs.op().as_ref(),
190 "Binary op type mismatch - lhs and rhs have different dtypes"
191 );
192 }
193
194 if matches!(op, BinaryOp::Max) {
195 render_binary_max(&dst, lhs, l, r, <ype, kernel);
196 } else if matches!(op, BinaryOp::Pow) {
197 render_binary_pow(&dst, lhs, l, r, <ype, kernel);
198 } else {
199 let instr = binary_instr(*op, &lhs.dtype());
200 kernel.push(format!(" {dst} = {instr} {ltype} {l}, {r}"));
201 }
202 Some(())
203 }
204
205 Op::Unary(op, src) => {
206 let s = ctx.get(src);
207 let stype = ldt(&src.dtype());
208
209 match op {
210 UnaryOp::Neg => {
211 if src.dtype().is_float() {
212 kernel.push(format!(" {dst} = fneg {stype} {s}"));
213 } else {
214 kernel.push(format!(" {dst} = sub {stype} 0, {s}"));
215 }
216 }
217 UnaryOp::Not => {
218 let all_ones = if src.dtype().is_bool() { "1".to_string() } else { "-1".to_string() };
219 kernel.push(format!(" {dst} = xor {stype} {s}, {all_ones}"));
220 }
221 UnaryOp::Floor | UnaryOp::Ceil | UnaryOp::Trunc | UnaryOp::Round if !src.dtype().is_float() => {
222 kernel.push(format!(" {dst} = bitcast {stype} {s} to {stype}"));
225 }
226 UnaryOp::Sqrt
227 | UnaryOp::Exp
228 | UnaryOp::Exp2
229 | UnaryOp::Log
230 | UnaryOp::Log2
231 | UnaryOp::Sin
232 | UnaryOp::Cos
233 | UnaryOp::Floor
234 | UnaryOp::Ceil
235 | UnaryOp::Trunc
236 | UnaryOp::Round => {
237 let intrinsic = unary_instr(*op, &src.dtype()).unwrap();
238 render_intrinsic(&dst, intrinsic, &[(&stype, s)], &stype, kernel);
239 }
240 UnaryOp::Abs => {
241 if src.dtype().is_float() {
242 render_intrinsic(&dst, "fabs", &[(&stype, s)], &stype, kernel);
243 } else {
244 render_intrinsic(&dst, "abs", &[(&stype, s), ("i1", "1")], &stype, kernel);
245 }
246 }
247 UnaryOp::Rsqrt => {
248 let sqrt_dst = format!("{dst}.sqrt");
249 render_intrinsic(&sqrt_dst, "sqrt", &[(&stype, s)], &stype, kernel);
250 kernel.push(format!(" {dst} = fdiv nsz arcp contract afn {stype} 1.0, {sqrt_dst}"));
251 }
252 UnaryOp::Reciprocal => {
253 kernel.push(format!(" {dst} = fdiv nsz arcp contract afn {stype} 1.0, {s}"));
254 }
255 UnaryOp::Tan => {
256 let sin_dst = format!("{dst}.sin");
257 let cos_dst = format!("{dst}.cos");
258 render_intrinsic(&sin_dst, "sin", &[(&stype, s)], &stype, kernel);
259 render_intrinsic(&cos_dst, "cos", &[(&stype, s)], &stype, kernel);
260 kernel.push(format!(" {dst} = fdiv nsz arcp contract afn {stype} {sin_dst}, {cos_dst}"));
261 }
262 UnaryOp::Sign => {
263 if src.dtype().is_float() {
264 let gt_zero = format!("{dst}.gt");
265 let lt_zero = format!("{dst}.lt");
266 let gt_ext = format!("{dst}.gt_ext");
267 let lt_ext = format!("{dst}.lt_ext");
268 kernel.push(format!(" {gt_zero} = fcmp nsz arcp contract afn ogt {stype} {s}, 0.0"));
269 kernel.push(format!(" {lt_zero} = fcmp nsz arcp contract afn olt {stype} {s}, 0.0"));
270 kernel.push(format!(" {gt_ext} = uitofp i1 {gt_zero} to {stype}"));
271 kernel.push(format!(" {lt_ext} = uitofp i1 {lt_zero} to {stype}"));
272 kernel.push(format!(" {dst} = fsub nsz arcp contract afn {stype} {gt_ext}, {lt_ext}"));
273 } else {
274 let is_signed = src.dtype().is_signed();
275 let cmp = if is_signed { "sgt" } else { "ugt" };
276 let cmp_lt = if is_signed { "slt" } else { "icmp eq" };
277 let gt_zero = format!("{dst}.gt");
278 let lt_zero = format!("{dst}.lt");
279 let gt_ext = format!("{dst}.gt_ext");
280 let lt_ext = format!("{dst}.lt_ext");
281 kernel.push(format!(" {gt_zero} = icmp {cmp} {stype} {s}, 0"));
282 if is_signed {
283 kernel.push(format!(" {lt_zero} = icmp {cmp_lt} {stype} {s}, 0"));
284 } else {
285 kernel.push(format!(" {lt_zero} = icmp eq {stype} {s}, 0"));
286 kernel.push(format!(" {lt_zero} = xor i1 {lt_zero}, 1"));
287 kernel.push(format!(" {lt_zero} = and i1 {lt_zero}, 0"));
288 }
289 kernel.push(format!(" {gt_ext} = zext i1 {gt_zero} to {stype}"));
290 kernel.push(format!(" {lt_ext} = zext i1 {lt_zero} to {stype}"));
291 kernel.push(format!(" {dst} = sub {stype} {gt_ext}, {lt_ext}"));
292 }
293 }
294 UnaryOp::Erf => {
295 render_intrinsic(&dst, "erf", &[(&stype, s)], &stype, kernel);
296 }
297 UnaryOp::Square => {
298 if src.dtype().is_float() {
299 kernel.push(format!(" {dst} = fmul nsz arcp contract afn {stype} {s}, {s}"));
300 } else {
301 kernel.push(format!(" {dst} = mul {stype} {s}, {s}"));
302 }
303 }
304 }
305 Some(())
306 }
307
308 Op::Ternary(TernaryOp::Where, cond, t, f) => {
309 let c = ctx.get(cond);
310 let tv = ctx.get(t);
311 let fv = ctx.get(f);
312 kernel.push(format!(
313 " {dst} = select {} {c}, {} {tv}, {} {fv}",
314 ldt(&cond.dtype()),
315 ldt(&t.dtype()),
316 ldt(&f.dtype())
317 ));
318 Some(())
319 }
320
321 Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
322 let av = ctx.get(a);
323 let bv = ctx.get(b);
324 let cv = ctx.get(c);
325 let dtype = ldt(&a.dtype());
326
327 if a.dtype().is_float() {
328 render_intrinsic(&dst, "fmuladd", &[(&dtype, av), (&dtype, bv), (&dtype, cv)], &dtype, kernel);
329 } else {
330 let mul_dst = format!("{dst}.mul");
331 kernel.push(format!(" {mul_dst} = mul {dtype} {av}, {bv}"));
332 kernel.push(format!(" {dst} = add {dtype} {mul_dst}, {cv}"));
333 }
334 Some(())
335 }
336
337 Op::Cast { src, dtype } => {
338 let s = ctx.get(src);
339
340 let is_index_src = matches!(src.op(), Op::Index { .. });
343 let src_llvm_type = if is_index_src { "ptr".to_string() } else { ldt(&src.dtype()) };
344 let dst_llvm_type = ldt(dtype);
345
346 if is_index_src && matches!(dtype, DType::Ptr { .. }) {
350 kernel.push(format!(" {dst} = bitcast ptr {s} to ptr"));
352 return Some(());
353 }
354
355 if dtype.is_bool() && !src.dtype().is_bool() {
356 let cmp = if src.dtype().is_float() { "fcmp nsz arcp contract afn une" } else { "icmp ne" };
359 kernel.push(format!(" {dst} = {cmp} {src_llvm_type} {s}, zeroinitializer"));
360 } else if src_llvm_type == dst_llvm_type {
361 kernel.push(format!(" {dst} = bitcast {src_llvm_type} {s} to {dst_llvm_type}"));
362 } else {
363 let cast_instr = lcast(&src.dtype(), dtype);
364 kernel.push(format!(" {dst} = {cast_instr} {src_llvm_type} {s} to {dst_llvm_type}"));
365 }
366 Some(())
367 }
368
369 Op::BitCast { src, dtype } => {
370 let s = ctx.get(src);
371 kernel.push(format!(" {dst} = bitcast {} {s} to {}", ldt(&src.dtype()), ldt(dtype)));
372 Some(())
373 }
374
375 Op::Range { axis_id, end, .. } => {
376 let id = axis_id.value();
377 let dtype = ldt(&uop.dtype());
378 let end_val = ctx.get(end).to_string();
379
380 ctx.push_range(id);
382
383 kernel.push(format!(" br label %loop_entry_{id}"));
388 kernel.push(format!("loop_entry_{id}:"));
389 kernel.push(format!(" br label %loop_latch_{id}"));
390 kernel.push(format!("loop_latch_{id}:"));
391 kernel.push(format!(" {dst} = phi {dtype} [ 0, %loop_entry_{id} ], [ {dst}phi, %loop_footer_{id} ]"));
392 kernel.push(format!(" {dst}phi = add {dtype} {dst}, 1"));
393 kernel.push(format!(" {dst}cmp = icmp ult {dtype} {dst}, {end_val}"));
394 kernel.push(format!(" br i1 {dst}cmp, label %loop_body_{id}, label %loop_exit_{id}"));
395 kernel.push(format!("loop_body_{id}:"));
396 Some(())
397 }
398
399 Op::End { ranges, .. } => {
400 let range_count = ranges
404 .iter()
405 .filter(|r| matches!(r.op(), Op::Range { axis_type, .. } if !matches!(axis_type, AxisType::Thread)))
406 .count();
407 for _ in 0..range_count {
408 if let Some(id) = ctx.pop_range() {
409 kernel.push(format!(" br label %loop_footer_{id}"));
413 kernel.push(format!("loop_footer_{id}:"));
414 kernel.push(format!(" br label %loop_latch_{id}"));
415 kernel.push(format!("loop_exit_{id}:"));
416 }
417 }
418
419 let pending = ctx.take_pending_reduces();
420 for (reduce_id, info) in pending {
421 let result_name = format!("%reduce_{reduce_id}.final");
422 kernel.push(format!(" {result_name} = load {}, ptr {}", info.dtype, info.acc_ptr));
423 ctx.register(reduce_id, result_name);
424 }
425 Some(())
426 }
427
428 Op::Reduce { src, ranges, reduce_op } => {
429 let src_val = ctx.get(src);
430 let dtype = ldt(&uop.dtype());
431
432 if ranges.is_empty() {
433 kernel.push(format!(" {dst} = bitcast {dtype} {src_val} to {dtype}"));
434 } else {
435 let acc_ptr = format!("%reduce_{}", uop.id);
436 let acc_load = format!("{acc_ptr}.load");
437 let acc_new = format!("{acc_ptr}.new");
438 let instr = reduce_instr(*reduce_op, &uop.dtype());
439
440 kernel.push(format!(" {acc_load} = load {dtype}, ptr {acc_ptr}"));
441
442 if matches!(reduce_op, ReduceOp::Max | ReduceOp::Min) {
443 render_reduce_minmax(&acc_new, *reduce_op, &uop.dtype(), &acc_load, src_val, &dtype, kernel);
444 } else {
445 kernel.push(format!(" {acc_new} = {instr} {dtype} {acc_load}, {src_val}"));
446 }
447
448 kernel.push(format!(" store {dtype} {acc_new}, ptr {acc_ptr}"));
449 ctx.register_reduce_pending(uop.id, acc_ptr.clone(), dtype.clone());
450 }
451 Some(())
452 }
453
454 Op::Gep { vector, indices } => {
455 let vec = ctx.get(vector);
456 let vec_type = ldt(&vector.dtype());
457 let out_type = ldt(&uop.dtype());
458
459 if indices.len() == 1 {
460 kernel.push(format!(" {dst} = extractelement {vec_type} {vec}, i32 {}", indices[0]));
461 } else {
462 render_multi_gep(&dst, vec, &vector.dtype(), indices, &out_type, kernel);
463 }
464 Some(())
465 }
466
467 Op::Vectorize { elements } => {
468 render_vectorize(&dst, elements, ctx, kernel);
469 Some(())
470 }
471
472 Op::Cat { sources } => {
473 render_cat(&dst, sources, ctx, kernel);
474 Some(())
475 }
476
477 Op::PtrCat { .. } => {
478 panic!(
479 "PtrCat must be eliminated before codegen (devectorize should distribute it into scalar loads/stores)"
480 );
481 }
482
483 Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
484 let s = ctx.get(src);
485 ctx.alias(uop.id, s.to_string());
486 None
487 }
488
489 Op::After { passthrough, .. } => {
490 #[cfg(debug_assertions)]
491 if matches!(passthrough.op(), Op::Range { .. }) {
492 panic!("AFTER passthrough is Range (id={}), this violates Tinygrad semantics", passthrough.id);
493 }
494 let s = ctx.get(passthrough);
495 ctx.alias(uop.id, s.to_string());
496 None
497 }
498
499 Op::Bind { var, value } => {
500 let v = ctx.get(value);
501 ctx.alias(var.id, v.to_string());
502 None
503 }
504
505 Op::If { condition, .. } => {
506 let cond = ctx.get(condition);
507 let if_id = uop.id;
508 kernel.push(format!(" br i1 {cond}, label %if_then_{if_id}, label %if_end_{if_id}"));
509 kernel.push(format!("if_then_{if_id}:"));
510 Some(())
511 }
512
513 Op::EndIf { if_op } => {
514 let if_id = if_op.id;
515 kernel.push(format!(" br label %if_end_{if_id}"));
516 kernel.push(format!("if_end_{if_id}:"));
517 Some(())
518 }
519
520 op if op.is_movement() => {
521 panic!(
522 "movement op {:?} (id={}) reached LLVM codegen — \
523 should have been eliminated during rangeify. \
524 This indicates a bug in remove_movement_op or apply_bufferize_transform.",
525 std::mem::discriminant(op),
526 uop.id,
527 );
528 }
529
530 _ => {
531 kernel.push(format!("; UNSUPPORTED: {:?}", uop.op()));
532 None
533 }
534 }
535}
536
537fn binary_instr(op: BinaryOp, dtype: &DType) -> &'static str {
538 assert!(
539 !matches!(dtype.base(), morok_dtype::ScalarDType::Index),
540 "Index dtype reached LLVM codegen binary_instr({op:?}, {dtype:?}) — \
541 pm_lower_index_dtype should have lowered it to i32/i64"
542 );
543 let is_float = dtype.is_float();
544 let is_signed = dtype.is_signed();
545
546 match op {
547 BinaryOp::Add => {
548 if is_float {
549 "fadd nsz arcp contract afn"
550 } else if is_signed {
551 "add nsw"
552 } else {
553 "add"
554 }
555 }
556 BinaryOp::Mul => {
557 if is_float {
558 "fmul nsz arcp contract afn"
559 } else {
560 "mul"
561 }
562 }
563 BinaryOp::Sub => {
564 if is_float {
565 "fsub nsz arcp contract afn"
566 } else {
567 "sub"
568 }
569 }
570 BinaryOp::Fdiv => "fdiv nsz arcp contract afn",
571 BinaryOp::Idiv => {
572 if is_signed {
573 "sdiv"
574 } else {
575 "udiv"
576 }
577 }
578 BinaryOp::Mod => {
579 if is_float {
580 "frem nsz arcp contract afn"
581 } else if is_signed {
582 "srem"
583 } else {
584 "urem"
585 }
586 }
587 BinaryOp::Max => {
588 if is_float {
589 "maxnum"
590 } else if is_signed {
591 "smax"
592 } else {
593 "umax"
594 }
595 }
596 BinaryOp::Lt => {
597 if is_float {
598 "fcmp nsz arcp contract afn ult"
599 } else if is_signed {
600 "icmp slt"
601 } else {
602 "icmp ult"
603 }
604 }
605 BinaryOp::Le => {
606 if is_float {
607 "fcmp nsz arcp contract afn ule"
608 } else if is_signed {
609 "icmp sle"
610 } else {
611 "icmp ule"
612 }
613 }
614 BinaryOp::Gt => {
615 if is_float {
616 "fcmp nsz arcp contract afn ugt"
617 } else if is_signed {
618 "icmp sgt"
619 } else {
620 "icmp ugt"
621 }
622 }
623 BinaryOp::Ge => {
624 if is_float {
625 "fcmp nsz arcp contract afn uge"
626 } else if is_signed {
627 "icmp sge"
628 } else {
629 "icmp uge"
630 }
631 }
632 BinaryOp::Eq => {
633 if is_float {
634 "fcmp nsz arcp contract afn oeq"
635 } else {
636 "icmp eq"
637 }
638 }
639 BinaryOp::Ne => {
640 if is_float {
641 "fcmp nsz arcp contract afn une"
642 } else {
643 "icmp ne"
644 }
645 }
646 BinaryOp::And => "and",
647 BinaryOp::Or => "or",
648 BinaryOp::Xor => "xor",
649 BinaryOp::Shl => "shl",
650 BinaryOp::Shr => {
651 if is_signed {
652 "ashr"
653 } else {
654 "lshr"
655 }
656 }
657 BinaryOp::Pow => "pow",
658 BinaryOp::Threefry => "xor",
659 }
660}
661
662fn unary_instr(op: UnaryOp, dtype: &DType) -> Option<&'static str> {
663 let is_float = dtype.is_float();
664
665 match op {
666 UnaryOp::Neg => Some(if is_float { "fneg" } else { "sub" }),
667 UnaryOp::Not => Some("xor"),
668 UnaryOp::Sqrt => Some("sqrt"),
669 UnaryOp::Rsqrt => None,
670 UnaryOp::Exp => Some("exp"),
671 UnaryOp::Exp2 => Some("exp2"),
672 UnaryOp::Log => Some("log"),
673 UnaryOp::Log2 => Some("log2"),
674 UnaryOp::Sin => Some("sin"),
675 UnaryOp::Cos => Some("cos"),
676 UnaryOp::Abs => Some(if is_float { "fabs" } else { "abs" }),
677 UnaryOp::Floor => Some("floor"),
678 UnaryOp::Ceil => Some("ceil"),
679 UnaryOp::Trunc => Some("trunc"),
680 UnaryOp::Round => Some("rint"),
681 UnaryOp::Reciprocal => None,
682 UnaryOp::Tan => None,
683 UnaryOp::Sign => None,
684 UnaryOp::Erf => None,
685 UnaryOp::Square => None,
686 }
687}
688
689fn reduce_instr(op: ReduceOp, dtype: &DType) -> &'static str {
690 let is_float = dtype.is_float();
691 let is_signed = dtype.is_signed();
692
693 match op {
694 ReduceOp::Add => {
695 if is_float {
696 "fadd nsz arcp contract afn"
697 } else {
698 "add"
699 }
700 }
701 ReduceOp::Mul => {
702 if is_float {
703 "fmul nsz arcp contract afn"
704 } else {
705 "mul"
706 }
707 }
708 ReduceOp::Max => {
709 if is_float {
710 "maxnum"
711 } else if is_signed {
712 "smax"
713 } else {
714 "umax"
715 }
716 }
717 ReduceOp::Min => {
718 if is_float {
719 "minnum"
720 } else if is_signed {
721 "smin"
722 } else {
723 "umin"
724 }
725 }
726 }
727}
728
/// Converts an LLVM type spelling into the suffix used in overloaded
/// intrinsic names (e.g. `float` -> `f32`, `<4 x float>` -> `v4f32`).
///
/// Floating-point scalars are renamed (`half`, `bfloat`, `float`, `double`,
/// `fp128`); vectors of the form `<N x T>` become `vN` followed by the
/// mangled element type. Anything else — including integer types such as
/// `i32`, which are already in mangled form — is returned unchanged.
fn mangle_type(llvm_type: &str) -> String {
    let float_suffix = match llvm_type {
        "half" => Some("f16"),
        "bfloat" => Some("bf16"),
        "float" => Some("f32"),
        "double" => Some("f64"),
        "fp128" => Some("f128"),
        _ => None,
    };
    if let Some(suffix) = float_suffix {
        return suffix.to_string();
    }

    // `<N x T>`: accepted only when the body contains exactly one " x " separator.
    let vector_parts = llvm_type
        .strip_prefix('<')
        .and_then(|rest| rest.strip_suffix('>'))
        .and_then(|inner| inner.split_once(" x "))
        .filter(|(_, elem)| !elem.contains(" x "));

    match vector_parts {
        Some((count, elem)) => format!("v{}{}", count.trim(), mangle_type(elem.trim())),
        None => llvm_type.to_string(),
    }
}
752
753fn render_intrinsic(dst: &str, name: &str, args: &[(&str, &str)], ret_type: &str, kernel: &mut Vec<String>) {
754 let args_str: String = args.iter().map(|(ty, val)| format!("{ty} {val}")).collect::<Vec<_>>().join(", ");
755 let mangled = mangle_type(ret_type);
756 kernel.push(format!(" {dst} = call {ret_type} @llvm.{name}.{mangled}({args_str})"));
757}
758
759fn render_binary_max(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
760 if lhs.dtype().is_float() {
761 render_intrinsic(dst, "maxnum", &[(ltype, l), (ltype, r)], ltype, kernel);
762 } else {
763 let is_signed = lhs.dtype().is_signed();
764 let cmp = if is_signed { "sgt" } else { "ugt" };
765 let cmp_dst = format!("{dst}.cmp");
766 kernel.push(format!(" {cmp_dst} = icmp {cmp} {ltype} {l}, {r}"));
767 kernel.push(format!(" {dst} = select i1 {cmp_dst}, {ltype} {l}, {ltype} {r}"));
768 }
769}
770
771fn render_binary_pow(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
772 if lhs.dtype().is_float() {
773 render_intrinsic(dst, "pow", &[(ltype, l), (ltype, r)], ltype, kernel);
774 } else {
775 let l_float = format!("{dst}.lf");
776 let r_float = format!("{dst}.rf");
777 let pow_float = format!("{dst}.pf");
778 kernel.push(format!(" {l_float} = sitofp {ltype} {l} to double"));
779 kernel.push(format!(" {r_float} = sitofp {ltype} {r} to double"));
780 render_intrinsic(&pow_float, "pow", &[("double", &l_float), ("double", &r_float)], "double", kernel);
781 kernel.push(format!(" {dst} = fptosi double {pow_float} to {ltype}"));
782 }
783}
784
785fn render_reduce_minmax(
786 dst: &str,
787 op: ReduceOp,
788 dtype: &DType,
789 acc: &str,
790 val: &str,
791 ltype: &str,
792 kernel: &mut Vec<String>,
793) {
794 if dtype.is_float() {
795 let intrinsic = match op {
796 ReduceOp::Max => "maxnum",
797 ReduceOp::Min => "minnum",
798 _ => unreachable!(),
799 };
800 render_intrinsic(dst, intrinsic, &[(ltype, acc), (ltype, val)], ltype, kernel);
801 } else {
802 let is_signed = dtype.is_signed();
803 let cmp = match op {
804 ReduceOp::Max => {
805 if is_signed {
806 "sgt"
807 } else {
808 "ugt"
809 }
810 }
811 ReduceOp::Min => {
812 if is_signed {
813 "slt"
814 } else {
815 "ult"
816 }
817 }
818 _ => unreachable!(),
819 };
820 let cmp_dst = format!("{dst}.cmp");
821 kernel.push(format!(" {cmp_dst} = icmp {cmp} {ltype} {acc}, {val}"));
822 kernel.push(format!(" {dst} = select i1 {cmp_dst}, {ltype} {acc}, {ltype} {val}"));
823 }
824}
825
826fn render_multi_gep(
827 dst: &str,
828 vec: &str,
829 vec_dtype: &DType,
830 indices: &[usize],
831 out_type: &str,
832 kernel: &mut Vec<String>,
833) {
834 let vec_type = ldt(vec_dtype);
835
836 let elem_dtype = match vec_dtype {
837 DType::Ptr { base, addrspace, size, .. } => {
838 DType::Ptr { base: base.clone(), addrspace: *addrspace, size: *size, vcount: 1 }
839 }
840 DType::Vector { scalar, .. } => DType::Scalar(*scalar),
841 _ => DType::Scalar(vec_dtype.base()),
842 };
843 let elem_type = ldt(&elem_dtype);
844
845 for (i, &idx) in indices.iter().enumerate() {
846 let elem = format!("{dst}.e{i}");
847 kernel.push(format!(" {elem} = extractelement {vec_type} {vec}, i32 {idx}"));
848 }
849
850 if indices.len() == 1 {
851 kernel.push(format!(" {dst} = bitcast {elem_type} {dst}.e0 to {out_type}"));
852 } else {
853 let count = indices.len();
854 kernel.push(format!(" {dst}.undef = undef <{count} x {elem_type}>"));
855 let mut prev = format!("{dst}.undef");
856 for i in 0..count {
857 let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
858 kernel.push(format!(
859 " {next} = insertelement <{count} x {elem_type}> {prev}, {elem_type} {dst}.e{i}, i32 {i}"
860 ));
861 prev = next;
862 }
863 }
864}
865
866fn render_vectorize(dst: &str, elements: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
867 if elements.is_empty() {
868 return;
869 }
870
871 let scalar_type = ldt(&elements[0].dtype());
872 let count = elements.len();
873 let vec_type = format!("<{count} x {scalar_type}>");
874
875 let mut prev = "undef".to_string();
876 for (i, elem) in elements.iter().enumerate() {
877 let val = ctx.get(elem);
878 let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
879 kernel.push(format!(" {next} = insertelement {vec_type} {prev}, {scalar_type} {val}, i32 {i}"));
880 prev = next;
881 }
882}
883
884fn render_cat(dst: &str, sources: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
885 if sources.is_empty() {
886 return;
887 }
888
889 let total_count: usize = sources.iter().map(|s| s.dtype().vcount()).sum();
890 let scalar_type = ldt(&sources[0].dtype().scalar_dtype());
891 let out_type = format!("<{total_count} x {scalar_type}>");
892
893 let mut out_idx = 0;
894 let mut prev = "undef".to_string();
895
896 for src in sources.iter() {
897 let src_val = ctx.get(src);
898 let src_count = src.dtype().vcount();
899
900 if src_count == 1 {
901 let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
902 kernel.push(format!(" {next} = insertelement {out_type} {prev}, {scalar_type} {src_val}, i32 {out_idx}"));
903 prev = next;
904 out_idx += 1;
905 } else {
906 let src_type = ldt(&src.dtype());
907 for i in 0..src_count {
908 let elem = format!("{dst}.e{out_idx}");
909 kernel.push(format!(" {elem} = extractelement {src_type} {src_val}, i32 {i}"));
910
911 let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
912 kernel.push(format!(" {next} = insertelement {out_type} {prev}, {scalar_type} {elem}, i32 {out_idx}"));
913 prev = next;
914 out_idx += 1;
915 }
916 }
917 }
918}
919
920fn render_linearize_multi_index(
925 dst: &str,
926 indices: &[Arc<UOp>],
927 ctx: &RenderContext,
928 kernel: &mut Vec<String>,
929) -> (String, String) {
930 use morok_schedule::passes::linearize_index::{compute_row_major_strides, extract_index_dimension};
931
932 let dims: Vec<i64> = indices
934 .iter()
935 .map(|idx| extract_index_dimension(idx).expect("multi-index dimension must be resolvable at codegen"))
936 .collect();
937 let strides = compute_row_major_strides(&dims);
938 let idx_type = ldt(&indices[0].dtype());
939
940 let mut current = String::new();
941 for (i, (idx_uop, &stride)) in indices.iter().zip(strides.iter()).enumerate() {
942 if stride == 0 {
943 continue;
944 }
945 let idx_val = ctx.get(idx_uop);
946 let term = if stride == 1 {
947 idx_val.to_string()
948 } else {
949 let mul_name = format!("{dst}.lin_mul{i}");
950 kernel.push(format!(" {mul_name} = mul {idx_type} {idx_val}, {stride}"));
951 mul_name
952 };
953
954 if current.is_empty() {
955 current = term;
956 } else {
957 let add_name = format!("{dst}.lin_add{i}");
958 kernel.push(format!(" {add_name} = add {idx_type} {current}, {term}"));
959 current = add_name;
960 }
961 }
962
963 if current.is_empty() {
964 current = "0".to_string();
965 }
966
967 (current, idx_type)
968}
969
970pub fn reduce_identity(op: ReduceOp, dtype: &DType) -> String {
972 let is_vector = matches!(dtype, DType::Vector { .. });
973
974 match op {
975 ReduceOp::Add => {
976 if is_vector {
977 "zeroinitializer".to_string()
978 } else if dtype.is_float() {
979 "0.0".to_string()
980 } else {
981 "0".to_string()
982 }
983 }
984 ReduceOp::Mul => {
985 if is_vector {
986 "zeroinitializer".to_string()
987 } else if dtype.is_float() {
988 "1.0".to_string()
989 } else {
990 "1".to_string()
991 }
992 }
993 ReduceOp::Max => {
994 if is_vector {
995 "zeroinitializer".to_string()
996 } else if dtype.is_float() {
997 "-0x7FF0000000000000".to_string()
998 } else if dtype.is_signed() {
999 i64::MIN.to_string()
1000 } else {
1001 "0".to_string()
1002 }
1003 }
1004 ReduceOp::Min => {
1005 if is_vector {
1006 "zeroinitializer".to_string() } else if dtype.is_float() {
1008 "0x7FF0000000000000".to_string() } else if dtype.is_signed() {
1010 i64::MAX.to_string()
1011 } else {
1012 u64::MAX.to_string()
1013 }
1014 }
1015 }
1016}