1use std::sync::Arc;
7
8use morok_dtype::DType;
9use morok_ir::{AxisType, BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
10
11use crate::llvm::common::{RenderContext, lcast, ldt};
12
13fn maybe_extract_scalar_ptr(
20 dst: &str,
21 idx: &str,
22 idx_type: &str,
23 dtype: &DType,
24 kernel: &mut Vec<String>,
25) -> (String, String) {
26 if matches!(dtype, DType::Ptr { vcount, .. } if *vcount > 1) {
27 let extract = format!("{dst}.ptr");
28 kernel.push(format!(" {extract} = extractelement {idx_type} {idx}, i32 0"));
29 (extract, "ptr".to_string())
30 } else {
31 (idx.to_string(), idx_type.to_string())
32 }
33}
34
35pub fn render_uop(uop: &Arc<UOp>, ctx: &mut RenderContext, kernel: &mut Vec<String>) -> Option<()> {
39 let dst = ctx.name(uop);
40
41 match uop.op() {
42 Op::Const(_)
43 | Op::VConst { .. }
44 | Op::DefineGlobal(_)
45 | Op::DefineVar { .. }
46 | Op::Noop
47 | Op::Sink { .. }
48 | Op::Group { .. }
49 | Op::Buffer { .. }
50 | Op::Unique(_)
51 | Op::Device(_)
52 | Op::Kernel { .. }
53 | Op::Barrier { .. } => None,
54
55 Op::DefineLocal(_) | Op::DefineReg { .. } => {
56 let (base_dtype, alloc_size) = match uop.dtype() {
60 DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
61 other => (other, 1),
62 };
63 let base = ldt(&base_dtype);
64 let align = if matches!(uop.op(), Op::DefineLocal(_)) { ", align 16" } else { "" };
66 kernel.push(format!(" {dst} = alloca [{alloc_size} x {base}]{align}"));
67 Some(())
68 }
69
70 Op::Index { buffer, indices, .. } => {
71 let buf = ctx.get(buffer);
72 let buf_type = ldt(&buffer.dtype());
73
74 if indices.is_empty() {
75 kernel.push(format!(" {dst} = bitcast {buf_type} {buf} to {}", ldt(&uop.dtype())));
76 } else {
77 let (final_idx, final_idx_type) = if indices.len() > 1 {
79 render_linearize_multi_index(&dst, indices, ctx, kernel)
80 } else {
81 (ctx.get(&indices[0]).to_string(), ldt(&indices[0].dtype()))
82 };
83
84 let elem_type = match uop.dtype() {
85 morok_dtype::DType::Ptr { ref base, .. } => ldt(base),
86 other => ldt(&other),
87 };
88
89 kernel.push(format!(
93 " {dst} = getelementptr inbounds {elem_type}, {buf_type} {buf}, {final_idx_type} {final_idx}"
94 ));
95 }
96 Some(())
97 }
98
99 Op::PointerIndex { ptr, offset } => {
100 let ptr_val = ctx.get(ptr);
101 let off_val = ctx.get(offset);
102 let elem_type = ldt(&uop.dtype());
103 let ptr_type = ldt(&ptr.dtype());
104 let off_type = ldt(&offset.dtype());
105
106 kernel.push(format!(
107 " {dst} = getelementptr inbounds {elem_type}, {ptr_type} {ptr_val}, {off_type} {off_val}"
108 ));
109 Some(())
110 }
111
112 Op::Load { index, alt, .. } => {
113 let idx = ctx.get(index);
114 let dtype = ldt(&uop.dtype());
115 let idx_type = ldt(&index.dtype());
116
117 let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
118
119 let actual_index = match index.op() {
126 Op::Cast { src, .. } => src,
127 _ => index,
128 };
129 let gate_info = if let Op::Index { gate: Some(gate_uop), .. } = actual_index.op() {
130 let alt_uop = alt.as_ref().expect(
131 "gated LOAD without alt value — pipeline bug: \
132 line_rewrite_cleanups should ensure alt is present for gated loads",
133 );
134 Some((ctx.get(gate_uop).to_string(), ctx.get(alt_uop).to_string()))
135 } else {
136 None
137 };
138
139 if let Some((gate, alt_val)) = gate_info {
140 let label_base = &dst[1..]; let entry_label = format!("{label_base}_entry");
142 let load_label = format!("{label_base}_load");
143 let exit_label = format!("{label_base}_exit");
144 let load_val = format!("{dst}_yes");
145
146 kernel.push(format!(" br label %{entry_label}"));
147 kernel.push(format!("{entry_label}:"));
148 kernel.push(format!(" br i1 {gate}, label %{load_label}, label %{exit_label}"));
149 kernel.push(format!("{load_label}:"));
150 kernel.push(format!(" {load_val} = load {dtype}, {idx_type} {idx}"));
151 kernel.push(format!(" br label %{exit_label}"));
152 kernel.push(format!("{exit_label}:"));
153 kernel.push(format!(" {dst} = phi {dtype} [{load_val}, %{load_label}], [{alt_val}, %{entry_label}]"));
154 } else {
155 kernel.push(format!(" {dst} = load {dtype}, {idx_type} {idx}"));
156 }
157 Some(())
158 }
159
160 Op::Store { index, value, .. } => {
161 let idx = ctx.get(index);
162 let val = ctx.get(value);
163 let val_type = ldt(&value.dtype());
164 let idx_type = ldt(&index.dtype());
165
166 let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
167
168 kernel.push(format!(" store {val_type} {val}, {idx_type} {idx}"));
169 Some(())
170 }
171
172 Op::Binary(op, lhs, rhs) => {
173 let l = ctx.get(lhs);
174 let r = ctx.get(rhs);
175 let ltype = ldt(&lhs.dtype());
176 let rtype = ldt(&rhs.dtype());
177
178 if ltype != rtype {
180 tracing::error!(
181 uop_id = uop.id,
182 uop_dtype = ?uop.dtype(),
183 op = ?op,
184 lhs_id = lhs.id,
185 rhs_id = rhs.id,
186 lhs_dtype = ?lhs.dtype(),
187 rhs_dtype = ?rhs.dtype(),
188 lhs_op = ?lhs.op().as_ref(),
189 rhs_op = ?rhs.op().as_ref(),
190 "Binary op type mismatch - lhs and rhs have different dtypes"
191 );
192 }
193
194 if matches!(op, BinaryOp::Max) {
195 render_binary_max(&dst, lhs, l, r, <ype, kernel);
196 } else if matches!(op, BinaryOp::Pow) {
197 render_binary_pow(&dst, lhs, l, r, <ype, kernel);
198 } else {
199 let instr = binary_instr(*op, &lhs.dtype());
200 kernel.push(format!(" {dst} = {instr} {ltype} {l}, {r}"));
201 }
202 Some(())
203 }
204
205 Op::Unary(op, src) => {
206 let s = ctx.get(src);
207 let stype = ldt(&src.dtype());
208
209 match op {
210 UnaryOp::Neg => {
211 if src.dtype().is_float() {
212 kernel.push(format!(" {dst} = fneg {stype} {s}"));
213 } else {
214 kernel.push(format!(" {dst} = sub {stype} 0, {s}"));
215 }
216 }
217 UnaryOp::Not => {
218 let all_ones = if src.dtype().is_bool() { "1".to_string() } else { "-1".to_string() };
219 kernel.push(format!(" {dst} = xor {stype} {s}, {all_ones}"));
220 }
221 UnaryOp::Floor | UnaryOp::Ceil | UnaryOp::Trunc | UnaryOp::Round if !src.dtype().is_float() => {
222 kernel.push(format!(" {dst} = bitcast {stype} {s} to {stype}"));
225 }
226 UnaryOp::Sqrt
227 | UnaryOp::Exp
228 | UnaryOp::Exp2
229 | UnaryOp::Log
230 | UnaryOp::Log2
231 | UnaryOp::Sin
232 | UnaryOp::Cos
233 | UnaryOp::Floor
234 | UnaryOp::Ceil
235 | UnaryOp::Trunc
236 | UnaryOp::Round => {
237 let intrinsic = unary_instr(*op, &src.dtype()).unwrap();
238 render_intrinsic(&dst, intrinsic, &[(&stype, s)], &stype, kernel);
239 }
240 UnaryOp::Abs => {
241 if src.dtype().is_float() {
242 render_intrinsic(&dst, "fabs", &[(&stype, s)], &stype, kernel);
243 } else {
244 render_intrinsic(&dst, "abs", &[(&stype, s), ("i1", "1")], &stype, kernel);
245 }
246 }
247 UnaryOp::Rsqrt => {
248 let sqrt_dst = format!("{dst}.sqrt");
249 render_intrinsic(&sqrt_dst, "sqrt", &[(&stype, s)], &stype, kernel);
250 kernel.push(format!(" {dst} = fdiv nsz arcp contract afn {stype} 1.0, {sqrt_dst}"));
251 }
252 UnaryOp::Reciprocal => {
253 kernel.push(format!(" {dst} = fdiv nsz arcp contract afn {stype} 1.0, {s}"));
254 }
255 UnaryOp::Tan => {
256 let sin_dst = format!("{dst}.sin");
257 let cos_dst = format!("{dst}.cos");
258 render_intrinsic(&sin_dst, "sin", &[(&stype, s)], &stype, kernel);
259 render_intrinsic(&cos_dst, "cos", &[(&stype, s)], &stype, kernel);
260 kernel.push(format!(" {dst} = fdiv nsz arcp contract afn {stype} {sin_dst}, {cos_dst}"));
261 }
262 UnaryOp::Sign => {
263 if src.dtype().is_float() {
264 let gt_zero = format!("{dst}.gt");
265 let lt_zero = format!("{dst}.lt");
266 let gt_ext = format!("{dst}.gt_ext");
267 let lt_ext = format!("{dst}.lt_ext");
268 kernel.push(format!(" {gt_zero} = fcmp nsz arcp contract afn ogt {stype} {s}, 0.0"));
269 kernel.push(format!(" {lt_zero} = fcmp nsz arcp contract afn olt {stype} {s}, 0.0"));
270 kernel.push(format!(" {gt_ext} = uitofp i1 {gt_zero} to {stype}"));
271 kernel.push(format!(" {lt_ext} = uitofp i1 {lt_zero} to {stype}"));
272 kernel.push(format!(" {dst} = fsub nsz arcp contract afn {stype} {gt_ext}, {lt_ext}"));
273 } else {
274 let is_signed = src.dtype().is_signed();
275 let cmp = if is_signed { "sgt" } else { "ugt" };
276 let cmp_lt = if is_signed { "slt" } else { "icmp eq" };
277 let gt_zero = format!("{dst}.gt");
278 let lt_zero = format!("{dst}.lt");
279 let gt_ext = format!("{dst}.gt_ext");
280 let lt_ext = format!("{dst}.lt_ext");
281 kernel.push(format!(" {gt_zero} = icmp {cmp} {stype} {s}, 0"));
282 if is_signed {
283 kernel.push(format!(" {lt_zero} = icmp {cmp_lt} {stype} {s}, 0"));
284 } else {
285 kernel.push(format!(" {lt_zero} = icmp eq {stype} {s}, 0"));
286 kernel.push(format!(" {lt_zero} = xor i1 {lt_zero}, 1"));
287 kernel.push(format!(" {lt_zero} = and i1 {lt_zero}, 0"));
288 }
289 kernel.push(format!(" {gt_ext} = zext i1 {gt_zero} to {stype}"));
290 kernel.push(format!(" {lt_ext} = zext i1 {lt_zero} to {stype}"));
291 kernel.push(format!(" {dst} = sub {stype} {gt_ext}, {lt_ext}"));
292 }
293 }
294 UnaryOp::Erf => {
295 render_intrinsic(&dst, "erf", &[(&stype, s)], &stype, kernel);
296 }
297 UnaryOp::Square => {
298 if src.dtype().is_float() {
299 kernel.push(format!(" {dst} = fmul nsz arcp contract afn {stype} {s}, {s}"));
300 } else {
301 kernel.push(format!(" {dst} = mul {stype} {s}, {s}"));
302 }
303 }
304 }
305 Some(())
306 }
307
308 Op::Ternary(TernaryOp::Where, cond, t, f) => {
309 let c = ctx.get(cond);
310 let tv = ctx.get(t);
311 let fv = ctx.get(f);
312 kernel.push(format!(
313 " {dst} = select {} {c}, {} {tv}, {} {fv}",
314 ldt(&cond.dtype()),
315 ldt(&t.dtype()),
316 ldt(&f.dtype())
317 ));
318 Some(())
319 }
320
321 Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
322 let av = ctx.get(a);
323 let bv = ctx.get(b);
324 let cv = ctx.get(c);
325 let dtype = ldt(&a.dtype());
326
327 if a.dtype().is_float() {
328 render_intrinsic(&dst, "fmuladd", &[(&dtype, av), (&dtype, bv), (&dtype, cv)], &dtype, kernel);
329 } else {
330 let mul_dst = format!("{dst}.mul");
331 kernel.push(format!(" {mul_dst} = mul {dtype} {av}, {bv}"));
332 kernel.push(format!(" {dst} = add {dtype} {mul_dst}, {cv}"));
333 }
334 Some(())
335 }
336
337 Op::Cast { src, dtype } => {
338 let s = ctx.get(src);
339
340 let is_index_src = matches!(src.op(), Op::Index { .. });
343 let src_llvm_type = if is_index_src { "ptr".to_string() } else { ldt(&src.dtype()) };
344 let dst_llvm_type = ldt(dtype);
345
346 if is_index_src && matches!(dtype, DType::Ptr { .. }) {
350 kernel.push(format!(" {dst} = bitcast ptr {s} to ptr"));
352 return Some(());
353 }
354
355 if src_llvm_type == dst_llvm_type {
356 kernel.push(format!(" {dst} = bitcast {src_llvm_type} {s} to {dst_llvm_type}"));
357 } else {
358 let cast_instr = lcast(&src.dtype(), dtype);
359 kernel.push(format!(" {dst} = {cast_instr} {src_llvm_type} {s} to {dst_llvm_type}"));
360 }
361 Some(())
362 }
363
364 Op::BitCast { src, dtype } => {
365 let s = ctx.get(src);
366 kernel.push(format!(" {dst} = bitcast {} {s} to {}", ldt(&src.dtype()), ldt(dtype)));
367 Some(())
368 }
369
370 Op::Range { end, axis_id, .. } => {
371 let end_val = ctx.get(end);
372 let id = axis_id.value();
373 let dtype = ldt(&uop.dtype());
374
375 kernel.push(format!(" br label %loop_entry_{id}"));
376 kernel.push(format!("loop_entry_{id}:"));
377 kernel.push(format!(" br label %loop_latch_{id}"));
378 kernel.push(format!("loop_latch_{id}:"));
379 kernel.push(format!(" {dst} = phi {dtype} [ 0, %loop_entry_{id} ], [ {dst}phi, %loop_footer_{id} ]"));
380 kernel.push(format!(" {dst}phi = add {dtype} {dst}, 1"));
381 kernel.push(format!(" {dst}cmp = icmp ult {dtype} {dst}, {end_val}"));
382 kernel.push(format!(" br i1 {dst}cmp, label %loop_body_{id}, label %loop_exit_{id}"));
383 kernel.push(format!("loop_body_{id}:"));
384 Some(())
385 }
386
387 Op::End { ranges, .. } => {
388 for range in ranges.iter() {
389 if let Op::Range { axis_id, axis_type, .. } = range.op() {
390 if matches!(axis_type, AxisType::Thread) {
391 continue;
392 }
393 let id = axis_id.value();
394 kernel.push(format!(" br label %loop_footer_{id}"));
395 kernel.push(format!("loop_footer_{id}:"));
396 kernel.push(format!(" br label %loop_latch_{id}"));
397 kernel.push(format!("loop_exit_{id}:"));
398 }
399 }
400
401 let pending = ctx.take_pending_reduces();
402 for (reduce_id, info) in pending {
403 let result_name = format!("%reduce_{reduce_id}.final");
404 kernel.push(format!(" {result_name} = load {}, ptr {}", info.dtype, info.acc_ptr));
405 ctx.register(reduce_id, result_name);
406 }
407 Some(())
408 }
409
410 Op::Reduce { src, ranges, reduce_op } => {
411 let src_val = ctx.get(src);
412 let dtype = ldt(&uop.dtype());
413
414 if ranges.is_empty() {
415 kernel.push(format!(" {dst} = bitcast {dtype} {src_val} to {dtype}"));
416 } else {
417 let acc_ptr = format!("%reduce_{}", uop.id);
418 let acc_load = format!("{acc_ptr}.load");
419 let acc_new = format!("{acc_ptr}.new");
420 let instr = reduce_instr(*reduce_op, &uop.dtype());
421
422 kernel.push(format!(" {acc_load} = load {dtype}, ptr {acc_ptr}"));
423
424 if matches!(reduce_op, ReduceOp::Max | ReduceOp::Min) {
425 render_reduce_minmax(&acc_new, *reduce_op, &uop.dtype(), &acc_load, src_val, &dtype, kernel);
426 } else {
427 kernel.push(format!(" {acc_new} = {instr} {dtype} {acc_load}, {src_val}"));
428 }
429
430 kernel.push(format!(" store {dtype} {acc_new}, ptr {acc_ptr}"));
431 ctx.register_reduce_pending(uop.id, acc_ptr.clone(), dtype.clone());
432 }
433 Some(())
434 }
435
436 Op::Gep { vector, indices } => {
437 let vec = ctx.get(vector);
438 let vec_type = ldt(&vector.dtype());
439 let out_type = ldt(&uop.dtype());
440
441 if indices.len() == 1 {
442 kernel.push(format!(" {dst} = extractelement {vec_type} {vec}, i32 {}", indices[0]));
443 } else {
444 render_multi_gep(&dst, vec, &vector.dtype(), indices, &out_type, kernel);
445 }
446 Some(())
447 }
448
449 Op::Vectorize { elements } => {
450 render_vectorize(&dst, elements, ctx, kernel);
451 Some(())
452 }
453
454 Op::Cat { sources } => {
455 render_cat(&dst, sources, ctx, kernel);
456 Some(())
457 }
458
459 Op::PtrCat { sources } => {
460 render_ptrcat(&dst, sources, ctx, kernel);
461 Some(())
462 }
463
464 Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
465 let s = ctx.get(src);
466 ctx.alias(uop.id, s.to_string());
467 None
468 }
469
470 Op::After { passthrough, .. } => {
471 #[cfg(debug_assertions)]
472 if matches!(passthrough.op(), Op::Range { .. }) {
473 panic!("AFTER passthrough is Range (id={}), this violates Tinygrad semantics", passthrough.id);
474 }
475 let s = ctx.get(passthrough);
476 ctx.alias(uop.id, s.to_string());
477 None
478 }
479
480 Op::Bind { var, value } => {
481 let v = ctx.get(value);
482 ctx.alias(var.id, v.to_string());
483 None
484 }
485
486 Op::If { condition, .. } => {
487 let cond = ctx.get(condition);
488 let if_id = uop.id;
489 kernel.push(format!(" br i1 {cond}, label %if_then_{if_id}, label %if_end_{if_id}"));
490 kernel.push(format!("if_then_{if_id}:"));
491 Some(())
492 }
493
494 Op::EndIf { if_op } => {
495 let if_id = if_op.id;
496 kernel.push(format!(" br label %if_end_{if_id}"));
497 kernel.push(format!("if_end_{if_id}:"));
498 Some(())
499 }
500
501 op if op.is_movement() => {
502 panic!(
503 "movement op {:?} (id={}) reached LLVM codegen — \
504 should have been eliminated during rangeify. \
505 This indicates a bug in remove_movement_op or apply_bufferize_transform.",
506 std::mem::discriminant(op),
507 uop.id,
508 );
509 }
510
511 _ => {
512 kernel.push(format!("; UNSUPPORTED: {:?}", uop.op()));
513 None
514 }
515 }
516}
517
518fn binary_instr(op: BinaryOp, dtype: &DType) -> &'static str {
519 assert!(
520 !matches!(dtype.base(), morok_dtype::ScalarDType::Index),
521 "Index dtype reached LLVM codegen binary_instr({op:?}, {dtype:?}) — \
522 pm_lower_index_dtype should have lowered it to i32/i64"
523 );
524 let is_float = dtype.is_float();
525 let is_signed = dtype.is_signed();
526
527 match op {
528 BinaryOp::Add => {
529 if is_float {
530 "fadd nsz arcp contract afn"
531 } else if is_signed {
532 "add nsw"
533 } else {
534 "add"
535 }
536 }
537 BinaryOp::Mul => {
538 if is_float {
539 "fmul nsz arcp contract afn"
540 } else {
541 "mul"
542 }
543 }
544 BinaryOp::Sub => {
545 if is_float {
546 "fsub nsz arcp contract afn"
547 } else {
548 "sub"
549 }
550 }
551 BinaryOp::Fdiv => "fdiv nsz arcp contract afn",
552 BinaryOp::Idiv => {
553 if is_signed {
554 "sdiv"
555 } else {
556 "udiv"
557 }
558 }
559 BinaryOp::Mod => {
560 if is_float {
561 "frem nsz arcp contract afn"
562 } else if is_signed {
563 "srem"
564 } else {
565 "urem"
566 }
567 }
568 BinaryOp::Max => {
569 if is_float {
570 "maxnum"
571 } else if is_signed {
572 "smax"
573 } else {
574 "umax"
575 }
576 }
577 BinaryOp::Lt => {
578 if is_float {
579 "fcmp nsz arcp contract afn ult"
580 } else if is_signed {
581 "icmp slt"
582 } else {
583 "icmp ult"
584 }
585 }
586 BinaryOp::Le => {
587 if is_float {
588 "fcmp nsz arcp contract afn ule"
589 } else if is_signed {
590 "icmp sle"
591 } else {
592 "icmp ule"
593 }
594 }
595 BinaryOp::Gt => {
596 if is_float {
597 "fcmp nsz arcp contract afn ugt"
598 } else if is_signed {
599 "icmp sgt"
600 } else {
601 "icmp ugt"
602 }
603 }
604 BinaryOp::Ge => {
605 if is_float {
606 "fcmp nsz arcp contract afn uge"
607 } else if is_signed {
608 "icmp sge"
609 } else {
610 "icmp uge"
611 }
612 }
613 BinaryOp::Eq => {
614 if is_float {
615 "fcmp nsz arcp contract afn oeq"
616 } else {
617 "icmp eq"
618 }
619 }
620 BinaryOp::Ne => {
621 if is_float {
622 "fcmp nsz arcp contract afn une"
623 } else {
624 "icmp ne"
625 }
626 }
627 BinaryOp::And => "and",
628 BinaryOp::Or => "or",
629 BinaryOp::Xor => "xor",
630 BinaryOp::Shl => "shl",
631 BinaryOp::Shr => {
632 if is_signed {
633 "ashr"
634 } else {
635 "lshr"
636 }
637 }
638 BinaryOp::Pow => "pow",
639 BinaryOp::Threefry => "xor",
640 }
641}
642
643fn unary_instr(op: UnaryOp, dtype: &DType) -> Option<&'static str> {
644 let is_float = dtype.is_float();
645
646 match op {
647 UnaryOp::Neg => Some(if is_float { "fneg" } else { "sub" }),
648 UnaryOp::Not => Some("xor"),
649 UnaryOp::Sqrt => Some("sqrt"),
650 UnaryOp::Rsqrt => None,
651 UnaryOp::Exp => Some("exp"),
652 UnaryOp::Exp2 => Some("exp2"),
653 UnaryOp::Log => Some("log"),
654 UnaryOp::Log2 => Some("log2"),
655 UnaryOp::Sin => Some("sin"),
656 UnaryOp::Cos => Some("cos"),
657 UnaryOp::Abs => Some(if is_float { "fabs" } else { "abs" }),
658 UnaryOp::Floor => Some("floor"),
659 UnaryOp::Ceil => Some("ceil"),
660 UnaryOp::Trunc => Some("trunc"),
661 UnaryOp::Round => Some("rint"),
662 UnaryOp::Reciprocal => None,
663 UnaryOp::Tan => None,
664 UnaryOp::Sign => None,
665 UnaryOp::Erf => None,
666 UnaryOp::Square => None,
667 }
668}
669
670fn reduce_instr(op: ReduceOp, dtype: &DType) -> &'static str {
671 let is_float = dtype.is_float();
672 let is_signed = dtype.is_signed();
673
674 match op {
675 ReduceOp::Add => {
676 if is_float {
677 "fadd nsz arcp contract afn"
678 } else {
679 "add"
680 }
681 }
682 ReduceOp::Mul => {
683 if is_float {
684 "fmul nsz arcp contract afn"
685 } else {
686 "mul"
687 }
688 }
689 ReduceOp::Max => {
690 if is_float {
691 "maxnum"
692 } else if is_signed {
693 "smax"
694 } else {
695 "umax"
696 }
697 }
698 ReduceOp::Min => {
699 if is_float {
700 "minnum"
701 } else if is_signed {
702 "smin"
703 } else {
704 "umin"
705 }
706 }
707 }
708}
709
/// Converts an LLVM type name into the suffix used in overloaded intrinsic
/// names (e.g. `float` -> `f32`, `<4 x float>` -> `v4f32`).
///
/// * Floating-point scalars map to their LLVM mangling: `half` -> `f16`,
///   `bfloat` -> `bf16`, `float` -> `f32`, `double` -> `f64`,
///   `x86_fp80` -> `f80`, `fp128` -> `f128`, `ppc_fp128` -> `ppcf128`.
/// * Vectors written as `<N x T>` become `vN` followed by the mangling of `T`.
/// * Everything else (integer types such as `i32`, or text that is not a
///   well-formed `<N x T>`) is returned unchanged.
fn mangle_type(llvm_type: &str) -> String {
    let float_suffix = match llvm_type {
        "half" => Some("f16"),
        "bfloat" => Some("bf16"),
        "float" => Some("f32"),
        "double" => Some("f64"),
        "x86_fp80" => Some("f80"),
        "fp128" => Some("f128"),
        "ppc_fp128" => Some("ppcf128"),
        _ => None,
    };
    if let Some(suffix) = float_suffix {
        return suffix.to_string();
    }

    // `<N x T>`: only a single ` x ` separator is treated as a vector type.
    let vector_body = llvm_type.strip_prefix('<').and_then(|rest| rest.strip_suffix('>'));
    if let Some((count, element)) = vector_body.and_then(|body| body.split_once(" x ")) {
        if !element.contains(" x ") {
            return format!("v{}{}", count.trim(), mangle_type(element.trim()));
        }
    }

    // Integer types already match their mangling; unknown text passes through.
    llvm_type.to_string()
}
733
734fn render_intrinsic(dst: &str, name: &str, args: &[(&str, &str)], ret_type: &str, kernel: &mut Vec<String>) {
735 let args_str: String = args.iter().map(|(ty, val)| format!("{ty} {val}")).collect::<Vec<_>>().join(", ");
736 let mangled = mangle_type(ret_type);
737 kernel.push(format!(" {dst} = call {ret_type} @llvm.{name}.{mangled}({args_str})"));
738}
739
740fn render_binary_max(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
741 if lhs.dtype().is_float() {
742 render_intrinsic(dst, "maxnum", &[(ltype, l), (ltype, r)], ltype, kernel);
743 } else {
744 let is_signed = lhs.dtype().is_signed();
745 let cmp = if is_signed { "sgt" } else { "ugt" };
746 let cmp_dst = format!("{dst}.cmp");
747 kernel.push(format!(" {cmp_dst} = icmp {cmp} {ltype} {l}, {r}"));
748 kernel.push(format!(" {dst} = select i1 {cmp_dst}, {ltype} {l}, {ltype} {r}"));
749 }
750}
751
752fn render_binary_pow(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
753 if lhs.dtype().is_float() {
754 render_intrinsic(dst, "pow", &[(ltype, l), (ltype, r)], ltype, kernel);
755 } else {
756 let l_float = format!("{dst}.lf");
757 let r_float = format!("{dst}.rf");
758 let pow_float = format!("{dst}.pf");
759 kernel.push(format!(" {l_float} = sitofp {ltype} {l} to double"));
760 kernel.push(format!(" {r_float} = sitofp {ltype} {r} to double"));
761 render_intrinsic(&pow_float, "pow", &[("double", &l_float), ("double", &r_float)], "double", kernel);
762 kernel.push(format!(" {dst} = fptosi double {pow_float} to {ltype}"));
763 }
764}
765
766fn render_reduce_minmax(
767 dst: &str,
768 op: ReduceOp,
769 dtype: &DType,
770 acc: &str,
771 val: &str,
772 ltype: &str,
773 kernel: &mut Vec<String>,
774) {
775 if dtype.is_float() {
776 let intrinsic = match op {
777 ReduceOp::Max => "maxnum",
778 ReduceOp::Min => "minnum",
779 _ => unreachable!(),
780 };
781 render_intrinsic(dst, intrinsic, &[(ltype, acc), (ltype, val)], ltype, kernel);
782 } else {
783 let is_signed = dtype.is_signed();
784 let cmp = match op {
785 ReduceOp::Max => {
786 if is_signed {
787 "sgt"
788 } else {
789 "ugt"
790 }
791 }
792 ReduceOp::Min => {
793 if is_signed {
794 "slt"
795 } else {
796 "ult"
797 }
798 }
799 _ => unreachable!(),
800 };
801 let cmp_dst = format!("{dst}.cmp");
802 kernel.push(format!(" {cmp_dst} = icmp {cmp} {ltype} {acc}, {val}"));
803 kernel.push(format!(" {dst} = select i1 {cmp_dst}, {ltype} {acc}, {ltype} {val}"));
804 }
805}
806
807fn render_multi_gep(
808 dst: &str,
809 vec: &str,
810 vec_dtype: &DType,
811 indices: &[usize],
812 out_type: &str,
813 kernel: &mut Vec<String>,
814) {
815 let vec_type = ldt(vec_dtype);
816
817 let elem_dtype = match vec_dtype {
818 DType::Ptr { base, addrspace, size, .. } => {
819 DType::Ptr { base: base.clone(), addrspace: *addrspace, size: *size, vcount: 1 }
820 }
821 DType::Vector { scalar, .. } => DType::Scalar(*scalar),
822 _ => DType::Scalar(vec_dtype.base()),
823 };
824 let elem_type = ldt(&elem_dtype);
825
826 for (i, &idx) in indices.iter().enumerate() {
827 let elem = format!("{dst}.e{i}");
828 kernel.push(format!(" {elem} = extractelement {vec_type} {vec}, i32 {idx}"));
829 }
830
831 if indices.len() == 1 {
832 kernel.push(format!(" {dst} = bitcast {elem_type} {dst}.e0 to {out_type}"));
833 } else {
834 let count = indices.len();
835 kernel.push(format!(" {dst}.undef = undef <{count} x {elem_type}>"));
836 let mut prev = format!("{dst}.undef");
837 for i in 0..count {
838 let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
839 kernel.push(format!(
840 " {next} = insertelement <{count} x {elem_type}> {prev}, {elem_type} {dst}.e{i}, i32 {i}"
841 ));
842 prev = next;
843 }
844 }
845}
846
847fn render_vectorize(dst: &str, elements: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
848 if elements.is_empty() {
849 return;
850 }
851
852 let scalar_type = ldt(&elements[0].dtype());
853 let count = elements.len();
854 let vec_type = format!("<{count} x {scalar_type}>");
855
856 let mut prev = "undef".to_string();
857 for (i, elem) in elements.iter().enumerate() {
858 let val = ctx.get(elem);
859 let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
860 kernel.push(format!(" {next} = insertelement {vec_type} {prev}, {scalar_type} {val}, i32 {i}"));
861 prev = next;
862 }
863}
864
865fn render_cat(dst: &str, sources: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
866 if sources.is_empty() {
867 return;
868 }
869
870 let total_count: usize = sources.iter().map(|s| s.dtype().vcount()).sum();
871 let scalar_type = ldt(&sources[0].dtype().scalar_dtype());
872 let out_type = format!("<{total_count} x {scalar_type}>");
873
874 let mut out_idx = 0;
875 let mut prev = "undef".to_string();
876
877 for src in sources.iter() {
878 let src_val = ctx.get(src);
879 let src_count = src.dtype().vcount();
880
881 if src_count == 1 {
882 let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
883 kernel.push(format!(" {next} = insertelement {out_type} {prev}, {scalar_type} {src_val}, i32 {out_idx}"));
884 prev = next;
885 out_idx += 1;
886 } else {
887 let src_type = ldt(&src.dtype());
888 for i in 0..src_count {
889 let elem = format!("{dst}.e{out_idx}");
890 kernel.push(format!(" {elem} = extractelement {src_type} {src_val}, i32 {i}"));
891
892 let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
893 kernel.push(format!(" {next} = insertelement {out_type} {prev}, {scalar_type} {elem}, i32 {out_idx}"));
894 prev = next;
895 out_idx += 1;
896 }
897 }
898 }
899}
900
901fn render_ptrcat(dst: &str, sources: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
902 if sources.is_empty() {
903 return;
904 }
905
906 let count = sources.len();
907 let ptr_type = ldt(&sources[0].dtype());
908 let vec_type = format!("<{count} x {ptr_type}>");
909
910 let mut prev = "undef".to_string();
911 for (i, src) in sources.iter().enumerate() {
912 let val = ctx.get(src);
913 let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.p{i}") };
914 kernel.push(format!(" {next} = insertelement {vec_type} {prev}, {ptr_type} {val}, i32 {i}"));
915 prev = next;
916 }
917}
918
919fn render_linearize_multi_index(
924 dst: &str,
925 indices: &[Arc<UOp>],
926 ctx: &RenderContext,
927 kernel: &mut Vec<String>,
928) -> (String, String) {
929 use morok_schedule::passes::linearize_index::{compute_row_major_strides, extract_index_dimension};
930
931 let dims: Vec<i64> = indices
933 .iter()
934 .map(|idx| extract_index_dimension(idx).expect("multi-index dimension must be resolvable at codegen"))
935 .collect();
936 let strides = compute_row_major_strides(&dims);
937 let idx_type = ldt(&indices[0].dtype());
938
939 let mut current = String::new();
940 for (i, (idx_uop, &stride)) in indices.iter().zip(strides.iter()).enumerate() {
941 if stride == 0 {
942 continue;
943 }
944 let idx_val = ctx.get(idx_uop);
945 let term = if stride == 1 {
946 idx_val.to_string()
947 } else {
948 let mul_name = format!("{dst}.lin_mul{i}");
949 kernel.push(format!(" {mul_name} = mul {idx_type} {idx_val}, {stride}"));
950 mul_name
951 };
952
953 if current.is_empty() {
954 current = term;
955 } else {
956 let add_name = format!("{dst}.lin_add{i}");
957 kernel.push(format!(" {add_name} = add {idx_type} {current}, {term}"));
958 current = add_name;
959 }
960 }
961
962 if current.is_empty() {
963 current = "0".to_string();
964 }
965
966 (current, idx_type)
967}
968
969pub fn reduce_identity(op: ReduceOp, dtype: &DType) -> String {
971 let is_vector = matches!(dtype, DType::Vector { .. });
972
973 match op {
974 ReduceOp::Add => {
975 if is_vector {
976 "zeroinitializer".to_string()
977 } else if dtype.is_float() {
978 "0.0".to_string()
979 } else {
980 "0".to_string()
981 }
982 }
983 ReduceOp::Mul => {
984 if is_vector {
985 "zeroinitializer".to_string()
986 } else if dtype.is_float() {
987 "1.0".to_string()
988 } else {
989 "1".to_string()
990 }
991 }
992 ReduceOp::Max => {
993 if is_vector {
994 "zeroinitializer".to_string()
995 } else if dtype.is_float() {
996 "-0x7FF0000000000000".to_string()
997 } else if dtype.is_signed() {
998 i64::MIN.to_string()
999 } else {
1000 "0".to_string()
1001 }
1002 }
1003 ReduceOp::Min => {
1004 if is_vector {
1005 "zeroinitializer".to_string() } else if dtype.is_float() {
1007 "0x7FF0000000000000".to_string() } else if dtype.is_signed() {
1009 i64::MAX.to_string()
1010 } else {
1011 u64::MAX.to_string()
1012 }
1013 }
1014 }
1015}