1use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9
10use morok_dtype::{DType, ScalarDType};
11use morok_ir::{AxisType, BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
12
13use super::types::{c_cast, c_const, c_dtype, c_math_fn};
14
15pub struct CContext {
17 names: HashMap<u64, String>,
19 ref_counts: HashMap<u64, usize>,
21 counter: usize,
23 depth: usize,
25 pending_reduces: HashMap<u64, (String, DType)>,
27 scope_escaping: HashSet<u64>,
29 pub hoisted_declarations: Vec<String>,
31}
32
33impl CContext {
34 pub fn new(ref_counts: HashMap<u64, usize>, scope_escaping: HashSet<u64>) -> Self {
35 Self {
36 names: HashMap::new(),
37 ref_counts,
38 counter: 0,
39 depth: 1,
40 pending_reduces: HashMap::new(),
41 scope_escaping,
42 hoisted_declarations: Vec::new(),
43 }
44 }
45
46 pub fn get(&self, uop: &Arc<UOp>) -> &str {
48 self.names
49 .get(&uop.id)
50 .map(|s| s.as_str())
51 .unwrap_or_else(|| panic!("UOp {} ({}) not in C context", uop.id, uop.op().as_ref()))
52 }
53
54 pub fn register(&mut self, id: u64, expr: String) {
56 self.names.insert(id, expr);
57 }
58
59 pub fn should_inline(&self, id: u64) -> bool {
61 self.ref_counts.get(&id).copied().unwrap_or(0) <= 1
62 }
63
64 pub fn next_name(&mut self, prefix: &str) -> String {
66 let name = format!("{}{}", prefix, self.counter);
67 self.counter += 1;
68 name
69 }
70
71 pub fn indent(&self) -> String {
73 " ".repeat(self.depth)
74 }
75
76 pub fn push_indent(&mut self) {
78 self.depth += 1;
79 }
80
81 pub fn pop_indent(&mut self) {
83 self.depth = self.depth.saturating_sub(1);
84 }
85
86 pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_name: String, dtype: DType) {
88 self.pending_reduces.insert(reduce_id, (acc_name, dtype));
89 }
90
91 pub fn take_pending_reduces(&mut self) -> HashMap<u64, (String, DType)> {
93 std::mem::take(&mut self.pending_reduces)
94 }
95
96 pub fn emit_expr(&mut self, uop: &Arc<UOp>, expr: String, prefix: &str, kernel: &mut Vec<String>) -> String {
104 if self.should_inline(uop.id) {
105 self.register(uop.id, expr.clone());
106 expr
107 } else {
108 let name = self.next_name(prefix);
109 let dtype = c_dtype(&uop.dtype());
110 let indent = self.indent();
111 if self.scope_escaping.contains(&uop.id) {
112 self.hoisted_declarations.push(format!(" {dtype} {name};"));
114 kernel.push(format!("{indent}{name} = {expr};"));
115 } else {
116 kernel.push(format!("{indent}{dtype} {name} = {expr};"));
117 }
118 self.register(uop.id, name.clone());
119 name
120 }
121 }
122}
123
124pub fn render_uop(uop: &Arc<UOp>, ctx: &mut CContext, kernel: &mut Vec<String>) -> Option<()> {
128 match uop.op() {
129 Op::Const(_)
131 | Op::VConst { .. }
132 | Op::Param { device: None, .. }
133 | Op::DefineLocal(_)
134 | Op::DefineVar { .. }
135 | Op::Noop
136 | Op::Sink { .. }
137 | Op::Group { .. }
138 | Op::Buffer { .. }
139 | Op::Unique(_)
140 | Op::Device(_)
141 | Op::Kernel { .. }
142 | Op::Barrier { .. } => None,
143
144 Op::DefineReg { .. } => {
145 let (base_dtype, alloc_size) = match uop.dtype() {
149 DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
150 other => (other, 1),
151 };
152 let name = ctx.next_name("reg");
153 let indent = ctx.indent();
154 kernel.push(format!("{indent}{} {name}[{alloc_size}];", c_dtype(&base_dtype)));
155 ctx.register(uop.id, name);
156 Some(())
157 }
158
159 Op::Index { buffer, indices, .. } => {
160 let buf = ctx.get(buffer).to_string();
161
162 if indices.is_empty() {
163 ctx.register(uop.id, buf);
165 } else {
166 let idx = if indices.len() > 1 {
168 render_linearize_multi_index_c(indices, ctx)
169 } else {
170 ctx.get(&indices[0]).to_string()
171 };
172 let expr = format!("{buf} + {idx}");
173 ctx.emit_expr(uop, expr, "idx", kernel);
174 }
175 Some(())
176 }
177
178 Op::PointerIndex { ptr, offset } => {
179 let ptr_val = ctx.get(ptr).to_string();
180 let off_val = ctx.get(offset).to_string();
181 let expr = format!("{ptr_val} + {off_val}");
182 ctx.emit_expr(uop, expr, "pidx", kernel);
183 Some(())
184 }
185
186 Op::Load { index, alt, .. } => {
187 let idx = ctx.get(index).to_string();
188 let load_dtype = uop.dtype();
189 let gate_expr = if let Op::Index { gate: Some(gate_uop), .. } = index.op() {
192 Some(ctx.get(gate_uop).to_string())
193 } else {
194 None
195 };
196 let deref_expr = if load_dtype.vcount() > 1 {
197 let cast_type = c_dtype(&load_dtype);
198 format!("*(({cast_type}*)({idx}))")
199 } else {
200 format!("*({idx})")
201 };
202 let expr = if let Some(gate) = gate_expr {
203 let alt_expr = if let Some(alt_uop) = alt {
205 ctx.get(alt_uop).to_string()
206 } else {
207 c_const(&morok_ir::types::ConstValue::zero(load_dtype.base()), &load_dtype)
208 };
209 format!("({gate} ? {deref_expr} : {alt_expr})")
210 } else {
211 deref_expr
212 };
213 ctx.emit_expr(uop, expr, "val", kernel);
214 Some(())
215 }
216
217 Op::Store { index, value, .. } => {
218 let idx = ctx.get(index).to_string();
219 let val = ctx.get(value).to_string();
220 let indent = ctx.indent();
221 let val_dtype = value.dtype();
222 if val_dtype.vcount() > 1 {
225 let cast_type = c_dtype(&val_dtype);
226 kernel.push(format!("{indent}*(({cast_type}*)({idx})) = {val};"));
227 } else {
228 kernel.push(format!("{indent}*({idx}) = {val};"));
229 }
230 Some(())
231 }
232
233 Op::Binary(op, lhs, rhs) => {
234 let l = ctx.get(lhs).to_string();
235 let r = ctx.get(rhs).to_string();
236 let expr = render_binary(*op, &l, &r, &lhs.dtype());
237 ctx.emit_expr(uop, expr, "alu", kernel);
238 Some(())
239 }
240
241 Op::Unary(op, src) => {
242 let s = ctx.get(src).to_string();
243 let expr = render_unary(*op, &s, &src.dtype());
244 ctx.emit_expr(uop, expr, "alu", kernel);
245 Some(())
246 }
247
248 Op::Ternary(TernaryOp::Where, cond, t, f) => {
249 let c = ctx.get(cond).to_string();
250 let tv = ctx.get(t).to_string();
251 let fv = ctx.get(f).to_string();
252 let expr = format!("({c} ? {tv} : {fv})");
253 ctx.emit_expr(uop, expr, "alu", kernel);
254 Some(())
255 }
256
257 Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
258 let av = ctx.get(a).to_string();
259 let bv = ctx.get(b).to_string();
260 let cv = ctx.get(c).to_string();
261 let expr = if a.dtype().is_float() {
262 format!("{}({av}, {bv}, {cv})", c_math_fn("__builtin_fma", &a.dtype()))
263 } else {
264 format!("(({av} * {bv}) + {cv})")
265 };
266 ctx.emit_expr(uop, expr, "alu", kernel);
267 Some(())
268 }
269
270 Op::Cast { src, dtype } => {
271 let s = ctx.get(src).to_string();
272
273 if matches!(src.op(), Op::Index { .. }) && matches!(dtype, DType::Ptr { .. }) {
275 ctx.register(uop.id, s);
276 return Some(());
277 }
278
279 let expr = if dtype.vcount() > 1 && !matches!(dtype, DType::Ptr { .. }) {
282 format!("__builtin_convertvector({s}, {})", c_dtype(dtype))
283 } else {
284 c_cast(&s, &src.dtype(), dtype)
285 };
286 ctx.emit_expr(uop, expr, "cast", kernel);
287 Some(())
288 }
289
290 Op::BitCast { src, dtype } => {
291 let s = ctx.get(src).to_string();
292 let from_type = c_dtype(&src.dtype());
293 let to_type = c_dtype(dtype);
294 if from_type == to_type {
295 ctx.register(uop.id, s);
296 } else {
297 let expr = format!("__builtin_bit_cast({to_type}, ({from_type})({s}))");
298 ctx.emit_expr(uop, expr, "cast", kernel);
299 }
300 Some(())
301 }
302
303 Op::Range { end, axis_id, axis_type, .. } => {
304 if matches!(axis_type, AxisType::Thread) {
305 return None;
306 }
307 let end_val = ctx.get(end).to_string();
308 let id = axis_id.value();
309 let range_dtype = c_dtype(&uop.dtype());
310 let var_name = format!("ridx{id}");
311 let indent = ctx.indent();
312 kernel.push(format!("{indent}for ({range_dtype} {var_name} = 0; {var_name} < {end_val}; {var_name}++) {{"));
313 ctx.register(uop.id, var_name);
314 ctx.push_indent();
315 Some(())
316 }
317
318 Op::End { ranges, .. } => {
319 for range in ranges.iter() {
320 if let Op::Range { axis_type, .. } = range.op() {
321 if matches!(axis_type, AxisType::Thread) {
322 continue;
323 }
324 ctx.pop_indent();
325 let indent = ctx.indent();
326 kernel.push(format!("{indent}}}"));
327 }
328 }
329
330 let pending = ctx.take_pending_reduces();
334 for (reduce_id, (acc_name, _dtype)) in pending {
335 ctx.register(reduce_id, acc_name);
338 }
339 Some(())
340 }
341
342 Op::Reduce { src, ranges, reduce_op } => {
343 let src_val = ctx.get(src).to_string();
344 let dtype = &uop.dtype();
345
346 if ranges.is_empty() {
347 ctx.register(uop.id, src_val);
349 } else {
350 let acc_name = ctx.get(uop).to_string();
352 let indent = ctx.indent();
353
354 let acc_expr = render_reduce_accumulate(*reduce_op, &acc_name, &src_val, dtype);
355 kernel.push(format!("{indent}{acc_expr}"));
356
357 ctx.register_reduce_pending(uop.id, acc_name, dtype.clone());
359 }
360 Some(())
361 }
362
363 Op::Gep { vector, indices } => {
364 let vec = ctx.get(vector).to_string();
365 if indices.len() == 1 {
366 let expr = format!("({vec})[{}]", indices[0]);
368 ctx.emit_expr(uop, expr, "gep", kernel);
369 } else {
370 let out_dtype = c_dtype(&uop.dtype());
372 let elements: Vec<String> = indices.iter().map(|&i| format!("({vec})[{i}]")).collect();
373 let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
374 ctx.emit_expr(uop, expr, "gep", kernel);
375 }
376 Some(())
377 }
378
379 Op::Vectorize { elements } => {
380 let vals: Vec<String> = elements.iter().map(|e| ctx.get(e).to_string()).collect();
381 if matches!(uop.dtype(), DType::Ptr { .. }) {
382 ctx.emit_expr(uop, vals[0].clone(), "vec", kernel);
385 } else {
386 let out_dtype = c_dtype(&uop.dtype());
387 let expr = format!("({out_dtype}){{{}}}", vals.join(", "));
388 ctx.emit_expr(uop, expr, "vec", kernel);
389 }
390 Some(())
391 }
392
393 Op::Cat { sources } => {
394 render_cat(uop, sources, ctx, kernel);
395 Some(())
396 }
397
398 Op::PtrCat { .. } => {
399 panic!(
400 "PtrCat must be eliminated before codegen (devectorize should distribute it into scalar loads/stores)"
401 );
402 }
403
404 Op::Wmma { a, b, c, metadata } => {
405 let a_val = ctx.get(a).to_string();
406 let b_val = ctx.get(b).to_string();
407 let c_val = ctx.get(c).to_string();
408 let expr = format!("__{name}({a_val}, {b_val}, {c_val})", name = metadata.name);
409 ctx.emit_expr(uop, expr, "wmma", kernel);
410 Some(())
411 }
412
413 Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
414 let s = ctx.get(src).to_string();
415 ctx.register(uop.id, s);
416 None
417 }
418
419 Op::After { passthrough, .. } => {
420 assert!(
421 !matches!(passthrough.op(), Op::Group { .. }),
422 "BUG: AFTER passthrough is GROUP (id={}). AFTER tree:\n{}",
423 passthrough.id,
424 uop.tree()
425 );
426 let s = ctx.get(passthrough).to_string();
427 ctx.register(uop.id, s);
428 None
429 }
430
431 Op::Bind { var, value } => {
432 let v = ctx.get(value).to_string();
433 ctx.register(var.id, v);
434 None
435 }
436
437 Op::If { condition, .. } => {
438 let cond = ctx.get(condition).to_string();
439 let indent = ctx.indent();
440 kernel.push(format!("{indent}if ({cond}) {{"));
441 ctx.push_indent();
442 Some(())
443 }
444
445 Op::EndIf { .. } => {
446 ctx.pop_indent();
447 let indent = ctx.indent();
448 kernel.push(format!("{indent}}}"));
449 Some(())
450 }
451
452 _ => {
453 let indent = ctx.indent();
454 kernel.push(format!("{indent}/* UNSUPPORTED: {:?} */", uop.op().as_ref()));
455 None
456 }
457 }
458}
459
460fn render_linearize_multi_index_c(indices: &[Arc<UOp>], ctx: &CContext) -> String {
464 use morok_schedule::passes::linearize_index::{compute_row_major_strides, extract_index_dimension};
465
466 let dims: Vec<i64> = indices
467 .iter()
468 .map(|idx| extract_index_dimension(idx).expect("multi-index dimension must be resolvable at codegen"))
469 .collect();
470 let strides = compute_row_major_strides(&dims);
471
472 let mut terms: Vec<String> = Vec::new();
473 for (idx_uop, &stride) in indices.iter().zip(strides.iter()) {
474 if stride == 0 {
475 continue;
476 }
477 let idx_val = ctx.get(idx_uop);
478 if stride == 1 {
479 terms.push(idx_val.to_string());
480 } else {
481 terms.push(format!("({idx_val} * {stride})"));
482 }
483 }
484
485 if terms.is_empty() { "0".to_string() } else { format!("({})", terms.join(" + ")) }
486}
487
488fn render_binary(op: BinaryOp, l: &str, r: &str, dtype: &DType) -> String {
490 match op {
491 BinaryOp::Add => format!("({l} + {r})"),
492 BinaryOp::Sub => format!("({l} - {r})"),
493 BinaryOp::Mul => format!("({l} * {r})"),
494 BinaryOp::Fdiv => format!("({l} / {r})"),
495 BinaryOp::Idiv => format!("({l} / {r})"),
496 BinaryOp::Mod => {
497 if dtype.is_float() {
498 format!("{}({l}, {r})", c_math_fn("__builtin_fmod", dtype))
499 } else {
500 format!("({l} % {r})")
501 }
502 }
503 BinaryOp::Max => {
504 if dtype.is_float() {
505 format!("{}({l}, {r})", c_math_fn("__builtin_fmax", dtype))
506 } else {
507 format!("({l} > {r} ? {l} : {r})")
508 }
509 }
510 BinaryOp::Lt => format!("({l} < {r})"),
511 BinaryOp::Le => format!("({l} <= {r})"),
512 BinaryOp::Gt => format!("({l} > {r})"),
513 BinaryOp::Ge => format!("({l} >= {r})"),
514 BinaryOp::Eq => format!("({l} == {r})"),
515 BinaryOp::Ne => format!("({l} != {r})"),
516 BinaryOp::And => format!("({l} & {r})"),
517 BinaryOp::Or => format!("({l} | {r})"),
518 BinaryOp::Xor => format!("({l} ^ {r})"),
519 BinaryOp::Shl => format!("({l} << {r})"),
520 BinaryOp::Shr => format!("({l} >> {r})"),
521 BinaryOp::Pow => {
522 if dtype.is_float() {
523 format!("{}({l}, {r})", c_math_fn("__builtin_pow", dtype))
524 } else {
525 format!("(({})__builtin_pow((double){l}, (double){r}))", c_dtype(&DType::Scalar(dtype.base())))
527 }
528 }
529 BinaryOp::Threefry => format!("({l} ^ {r})"),
530 }
531}
532
533fn render_unary(op: UnaryOp, s: &str, dtype: &DType) -> String {
535 match op {
536 UnaryOp::Neg => {
537 format!("(-{s})")
538 }
539 UnaryOp::Not => {
540 if dtype.is_bool() {
541 format!("(!{s})")
542 } else {
543 format!("(~{s})")
544 }
545 }
546 UnaryOp::Abs => {
547 if dtype.is_float() {
548 format!("{}({s})", c_math_fn("__builtin_fabs", dtype))
549 } else {
550 format!("({s} < 0 ? -{s} : {s})")
551 }
552 }
553 UnaryOp::Sqrt => format!("{}({s})", c_math_fn("__builtin_sqrt", dtype)),
554 UnaryOp::Rsqrt => {
555 let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
556 format!("({one} / {}({s}))", c_math_fn("__builtin_sqrt", dtype))
557 }
558 UnaryOp::Reciprocal => {
559 let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
560 format!("({one} / {s})")
561 }
562 UnaryOp::Exp => format!("{}({s})", c_math_fn("__builtin_exp", dtype)),
563 UnaryOp::Exp2 => format!("{}({s})", c_math_fn("__builtin_exp2", dtype)),
564 UnaryOp::Log => format!("{}({s})", c_math_fn("__builtin_log", dtype)),
565 UnaryOp::Log2 => format!("{}({s})", c_math_fn("__builtin_log2", dtype)),
566 UnaryOp::Sin => format!("{}({s})", c_math_fn("__builtin_sin", dtype)),
567 UnaryOp::Cos => format!("{}({s})", c_math_fn("__builtin_cos", dtype)),
568 UnaryOp::Tan => format!("{}({s})", c_math_fn("__builtin_tan", dtype)),
569 UnaryOp::Floor => format!("{}({s})", c_math_fn("__builtin_floor", dtype)),
570 UnaryOp::Ceil => format!("{}({s})", c_math_fn("__builtin_ceil", dtype)),
571 UnaryOp::Trunc => format!("{}({s})", c_math_fn("__builtin_trunc", dtype)),
572 UnaryOp::Round => format!("{}({s})", c_math_fn("__builtin_rint", dtype)),
573 UnaryOp::Erf => format!("{}({s})", c_math_fn("__builtin_erf", dtype)),
574 UnaryOp::Sign => {
575 if dtype.is_float() {
576 let zero = if matches!(dtype.base(), ScalarDType::Float64) { "0.0" } else { "0.0f" };
577 format!("(({s} > {zero}) - ({s} < {zero}))")
578 } else {
579 format!("(({s} > 0) - ({s} < 0))")
580 }
581 }
582 UnaryOp::Square => format!("({s} * {s})"),
583 }
584}
585
586fn render_reduce_accumulate(op: ReduceOp, acc: &str, val: &str, dtype: &DType) -> String {
588 match op {
589 ReduceOp::Add => format!("{acc} += {val};"),
590 ReduceOp::Mul => format!("{acc} *= {val};"),
591 ReduceOp::Max => {
592 if dtype.is_float() {
593 format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmax", dtype))
594 } else {
595 format!("{acc} = ({acc} > {val} ? {acc} : {val});")
596 }
597 }
598 ReduceOp::Min => {
599 if dtype.is_float() {
600 format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmin", dtype))
601 } else {
602 format!("{acc} = ({acc} < {val} ? {acc} : {val});")
603 }
604 }
605 }
606}
607
608fn render_cat(uop: &Arc<UOp>, sources: &[Arc<UOp>], ctx: &mut CContext, kernel: &mut Vec<String>) {
610 let out_dtype = c_dtype(&uop.dtype());
611 let mut elements = Vec::new();
612
613 for src in sources {
614 let src_val = ctx.get(src).to_string();
615 let src_vcount = src.dtype().vcount();
616 if src_vcount == 1 {
617 elements.push(src_val);
618 } else {
619 for i in 0..src_vcount {
620 elements.push(format!("{src_val}[{i}]"));
621 }
622 }
623 }
624
625 let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
626 ctx.emit_expr(uop, expr, "cat", kernel);
627}
628
629pub fn count_references(nodes: &[Arc<UOp>]) -> HashMap<u64, usize> {
632 let mut counts: HashMap<u64, usize> = HashMap::new();
633 for node in nodes {
634 for child in node.op().children() {
635 *counts.entry(child.id).or_insert(0) += 1;
636 }
637 }
638 counts
639}