1use std::collections::HashMap;
8use std::sync::Arc;
9
10use morok_dtype::{DType, ScalarDType};
11use morok_ir::{AxisType, BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
12
13use super::types::{c_cast, c_const, c_dtype, c_math_fn};
14
/// Mutable bookkeeping carried through C rendering of a sequence of UOps.
pub struct CContext {
    /// C spelling (inline expression text or variable name) recorded for each UOp id.
    names: HashMap<u64, String>,
    /// Reference count per UOp id, supplied at construction and consulted by `should_inline`.
    ref_counts: HashMap<u64, usize>,
    /// Numeric suffix that the next `next_name` call will hand out.
    counter: usize,
    /// Current indentation depth; `indent` repeats the indent unit this many times.
    depth: usize,
    /// Reduce id -> (accumulator name, dtype), held until `take_pending_reduces` drains it.
    pending_reduces: HashMap<u64, (String, DType)>,
}
28
29impl CContext {
30 pub fn new(ref_counts: HashMap<u64, usize>) -> Self {
31 Self { names: HashMap::new(), ref_counts, counter: 0, depth: 1, pending_reduces: HashMap::new() }
32 }
33
34 pub fn get(&self, uop: &Arc<UOp>) -> &str {
36 self.names
37 .get(&uop.id)
38 .map(|s| s.as_str())
39 .unwrap_or_else(|| panic!("UOp {} ({:?}) not in C context", uop.id, uop.op()))
40 }
41
42 pub fn register(&mut self, id: u64, expr: String) {
44 self.names.insert(id, expr);
45 }
46
47 pub fn should_inline(&self, id: u64) -> bool {
49 self.ref_counts.get(&id).copied().unwrap_or(0) <= 1
50 }
51
52 pub fn next_name(&mut self, prefix: &str) -> String {
54 let name = format!("{}{}", prefix, self.counter);
55 self.counter += 1;
56 name
57 }
58
59 pub fn indent(&self) -> String {
61 " ".repeat(self.depth)
62 }
63
64 pub fn push_indent(&mut self) {
66 self.depth += 1;
67 }
68
69 pub fn pop_indent(&mut self) {
71 self.depth = self.depth.saturating_sub(1);
72 }
73
74 pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_name: String, dtype: DType) {
76 self.pending_reduces.insert(reduce_id, (acc_name, dtype));
77 }
78
79 pub fn take_pending_reduces(&mut self) -> HashMap<u64, (String, DType)> {
81 std::mem::take(&mut self.pending_reduces)
82 }
83
84 pub fn emit_expr(&mut self, uop: &Arc<UOp>, expr: String, prefix: &str, kernel: &mut Vec<String>) -> String {
87 if self.should_inline(uop.id) {
88 self.register(uop.id, expr.clone());
89 expr
90 } else {
91 let name = self.next_name(prefix);
92 let dtype = c_dtype(&uop.dtype());
93 let indent = self.indent();
94 kernel.push(format!("{indent}{dtype} {name} = {expr};"));
95 self.register(uop.id, name.clone());
96 name
97 }
98 }
99}
100
/// Renders one UOp to C: appends any statements it needs to `kernel` and
/// records the UOp's C spelling in `ctx` so later users can look it up.
///
/// Operands are fetched with `ctx.get`, so every operand must already have a
/// registered spelling when its user is rendered.
///
/// # Returns
///
/// `None` for the ops listed in the first arm, thread-axis `Range`s, the
/// alias-only arms (`Contract`, `Unroll`, `Detach`, `After`, `Bind`) and the
/// unsupported fallback; `Some(())` from every other arm. How callers interpret
/// the distinction is not visible in this function.
///
/// # Panics
///
/// Panics when an operand has no registered spelling (via `ctx.get`), when a
/// multi-index dimension cannot be resolved, or when a pointer-typed
/// `Vectorize` has no elements (`vals[0]`).
pub fn render_uop(uop: &Arc<UOp>, ctx: &mut CContext, kernel: &mut Vec<String>) -> Option<()> {
    match uop.op() {
        // Nothing is emitted or registered for these ops in this function.
        Op::Const(_)
        | Op::VConst { .. }
        | Op::DefineGlobal(_)
        | Op::DefineLocal(_)
        | Op::DefineVar { .. }
        | Op::Noop
        | Op::Sink { .. }
        | Op::Group { .. }
        | Op::Buffer { .. }
        | Op::Unique(_)
        | Op::Device(_)
        | Op::Kernel { .. }
        | Op::Barrier { .. } => None,

        // Declares a local C array `<base> regN[<size>];`.
        Op::DefineReg { .. } => {
            // A pointer dtype supplies element type and length (default 1);
            // any other dtype is declared as a one-element array of itself.
            let (base_dtype, alloc_size) = match uop.dtype() {
                DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
                other => (other, 1),
            };
            let name = ctx.next_name("reg");
            let indent = ctx.indent();
            kernel.push(format!("{indent}{} {name}[{alloc_size}];", c_dtype(&base_dtype)));
            ctx.register(uop.id, name);
            Some(())
        }

        // Renders as pointer arithmetic `buf + idx`.
        Op::Index { buffer, indices, .. } => {
            let buf = ctx.get(buffer).to_string();

            if indices.is_empty() {
                // No indices: the index is just an alias for the buffer.
                ctx.register(uop.id, buf);
            } else {
                // Several indices are first folded into one linear offset.
                let idx = if indices.len() > 1 {
                    render_linearize_multi_index_c(indices, ctx)
                } else {
                    ctx.get(&indices[0]).to_string()
                };
                let expr = format!("{buf} + {idx}");
                ctx.emit_expr(uop, expr, "idx", kernel);
            }
            Some(())
        }

        // Renders as `ptr + offset`.
        Op::PointerIndex { ptr, offset } => {
            let ptr_val = ctx.get(ptr).to_string();
            let off_val = ctx.get(offset).to_string();
            let expr = format!("{ptr_val} + {off_val}");
            ctx.emit_expr(uop, expr, "pidx", kernel);
            Some(())
        }

        // Dereferences the index pointer, optionally guarded by the index's gate.
        Op::Load { index, .. } => {
            let idx = ctx.get(index).to_string();
            let load_dtype = uop.dtype();
            // NOTE(review): the gate is only found when `index` is directly an
            // `Op::Index`; an index wrapped in another op (e.g. the pointer
            // `Cast` aliased below) yields an ungated load. Confirm intended.
            let gate_expr = if let Op::Index { gate: Some(gate_uop), .. } = index.op() {
                Some(ctx.get(gate_uop).to_string())
            } else {
                None
            };
            // Loads with more than one lane reinterpret the pointer as a
            // pointer to the vector type before dereferencing.
            let deref_expr = if load_dtype.vcount() > 1 {
                let cast_type = c_dtype(&load_dtype);
                format!("*(({cast_type}*)({idx}))")
            } else {
                format!("*({idx})")
            };
            // A gated load yields a zero constant of the load dtype when the gate is false.
            let expr = if let Some(gate) = gate_expr {
                let zero = c_const(&morok_ir::types::ConstValue::zero(load_dtype.base()), &load_dtype);
                format!("({gate} ? {deref_expr} : {zero})")
            } else {
                deref_expr
            };
            ctx.emit_expr(uop, expr, "val", kernel);
            Some(())
        }

        // Emits an assignment through the index pointer; nothing is registered
        // for the store itself. Any gate on the index is not consulted here.
        Op::Store { index, value, .. } => {
            let idx = ctx.get(index).to_string();
            let val = ctx.get(value).to_string();
            let indent = ctx.indent();
            let val_dtype = value.dtype();
            if val_dtype.vcount() > 1 {
                // Multi-lane stores write through a vector-typed pointer.
                let cast_type = c_dtype(&val_dtype);
                kernel.push(format!("{indent}*(({cast_type}*)({idx})) = {val};"));
            } else {
                kernel.push(format!("{indent}*({idx}) = {val};"));
            }
            Some(())
        }

        // The operator's spelling is chosen from the left operand's dtype.
        Op::Binary(op, lhs, rhs) => {
            let l = ctx.get(lhs).to_string();
            let r = ctx.get(rhs).to_string();
            let expr = render_binary(*op, &l, &r, &lhs.dtype());
            ctx.emit_expr(uop, expr, "alu", kernel);
            Some(())
        }

        // The operator's spelling is chosen from the operand's dtype.
        Op::Unary(op, src) => {
            let s = ctx.get(src).to_string();
            let expr = render_unary(*op, &s, &src.dtype());
            ctx.emit_expr(uop, expr, "alu", kernel);
            Some(())
        }

        // Select renders as the C conditional operator.
        Op::Ternary(TernaryOp::Where, cond, t, f) => {
            let c = ctx.get(cond).to_string();
            let tv = ctx.get(t).to_string();
            let fv = ctx.get(f).to_string();
            let expr = format!("({c} ? {tv} : {fv})");
            ctx.emit_expr(uop, expr, "alu", kernel);
            Some(())
        }

        // Multiply-accumulate: an fma builtin for floats, `a * b + c` otherwise.
        Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
            let av = ctx.get(a).to_string();
            let bv = ctx.get(b).to_string();
            let cv = ctx.get(c).to_string();
            let expr = if a.dtype().is_float() {
                format!("{}({av}, {bv}, {cv})", c_math_fn("__builtin_fma", &a.dtype()))
            } else {
                format!("(({av} * {bv}) + {cv})")
            };
            ctx.emit_expr(uop, expr, "alu", kernel);
            Some(())
        }

        Op::Cast { src, dtype } => {
            let s = ctx.get(src).to_string();

            // Casting an Index to a pointer dtype emits no C cast: the result
            // is an alias for the index expression.
            if matches!(src.op(), Op::Index { .. }) && matches!(dtype, DType::Ptr { .. }) {
                ctx.register(uop.id, s);
                return Some(());
            }

            // Non-pointer targets with more than one lane use the
            // convertvector builtin; everything else goes through `c_cast`.
            let expr = if dtype.vcount() > 1 && !matches!(dtype, DType::Ptr { .. }) {
                format!("__builtin_convertvector({s}, {})", c_dtype(dtype))
            } else {
                c_cast(&s, &src.dtype(), dtype)
            };
            ctx.emit_expr(uop, expr, "cast", kernel);
            Some(())
        }

        Op::BitCast { src, dtype } => {
            let s = ctx.get(src).to_string();
            let from_type = c_dtype(&src.dtype());
            let to_type = c_dtype(dtype);
            if from_type == to_type {
                // Same C type name on both sides: alias to the source.
                ctx.register(uop.id, s);
            } else {
                let expr = format!("__builtin_bit_cast({to_type}, ({from_type})({s}))");
                ctx.emit_expr(uop, expr, "cast", kernel);
            }
            Some(())
        }

        // Opens a C `for` loop counting from 0 up to (excluding) `end`.
        Op::Range { end, axis_id, axis_type, .. } => {
            // Thread-axis ranges produce no loop here and register nothing.
            if matches!(axis_type, AxisType::Thread) {
                return None;
            }
            let end_val = ctx.get(end).to_string();
            let id = axis_id.value();
            let range_dtype = c_dtype(&uop.dtype());
            let var_name = format!("ridx{id}");
            let indent = ctx.indent();
            kernel.push(format!("{indent}for ({range_dtype} {var_name} = 0; {var_name} < {end_val}; {var_name}++) {{"));
            ctx.register(uop.id, var_name);
            // The loop body is rendered one level deeper.
            ctx.push_indent();
            Some(())
        }

        // Closes one brace per non-thread Range listed in `ranges`.
        Op::End { ranges, .. } => {
            for range in ranges.iter() {
                if let Op::Range { axis_type, .. } = range.op() {
                    // Thread ranges never opened a loop, so there is nothing to close.
                    if matches!(axis_type, AxisType::Thread) {
                        continue;
                    }
                    ctx.pop_indent();
                    let indent = ctx.indent();
                    kernel.push(format!("{indent}}}"));
                }
            }

            // Every reduce recorded since the last drain is re-registered under
            // its accumulator name. Note this drains all pending reduces, not
            // only those tied to the ranges closed above.
            let pending = ctx.take_pending_reduces();
            for (reduce_id, (acc_name, _dtype)) in pending {
                ctx.register(reduce_id, acc_name);
            }
            Some(())
        }

        Op::Reduce { src, ranges, reduce_op } => {
            let src_val = ctx.get(src).to_string();
            let dtype = &uop.dtype();

            if ranges.is_empty() {
                // Nothing to reduce over: the reduce is an alias for its source.
                ctx.register(uop.id, src_val);
            } else {
                // The accumulator name must already be registered under this
                // reduce's own id before this arm runs; `ctx.get` panics otherwise.
                let acc_name = ctx.get(uop).to_string();
                let indent = ctx.indent();

                let acc_expr = render_reduce_accumulate(*reduce_op, &acc_name, &src_val, dtype);
                kernel.push(format!("{indent}{acc_expr}"));

                // Parked until the next `End` re-registers it.
                ctx.register_reduce_pending(uop.id, acc_name, dtype.clone());
            }
            Some(())
        }

        // Extracts lanes from a vector by subscripting.
        Op::Gep { vector, indices } => {
            let vec = ctx.get(vector).to_string();
            if indices.len() == 1 {
                let expr = format!("({vec})[{}]", indices[0]);
                ctx.emit_expr(uop, expr, "gep", kernel);
            } else {
                // Any other index count builds a brace-initialised value of
                // the output type from the selected lanes.
                let out_dtype = c_dtype(&uop.dtype());
                let elements: Vec<String> = indices.iter().map(|&i| format!("({vec})[{i}]")).collect();
                let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
                ctx.emit_expr(uop, expr, "gep", kernel);
            }
            Some(())
        }

        Op::Vectorize { elements } => {
            let vals: Vec<String> = elements.iter().map(|e| ctx.get(e).to_string()).collect();
            if matches!(uop.dtype(), DType::Ptr { .. }) {
                // Pointer-typed result: only the first element is used.
                // Indexing panics when `elements` is empty.
                ctx.emit_expr(uop, vals[0].clone(), "vec", kernel);
            } else {
                let out_dtype = c_dtype(&uop.dtype());
                let expr = format!("({out_dtype}){{{}}}", vals.join(", "));
                ctx.emit_expr(uop, expr, "vec", kernel);
            }
            Some(())
        }

        Op::Cat { sources } => {
            render_cat(uop, sources, ctx, kernel);
            Some(())
        }

        // Rendered exactly like `Cat`.
        Op::PtrCat { sources } => {
            render_cat(uop, sources, ctx, kernel);
            Some(())
        }

        // Rendered as a call to a function named `__<metadata.name>`.
        Op::Wmma { a, b, c, metadata } => {
            let a_val = ctx.get(a).to_string();
            let b_val = ctx.get(b).to_string();
            let c_val = ctx.get(c).to_string();
            let expr = format!("__{name}({a_val}, {b_val}, {c_val})", name = metadata.name);
            ctx.emit_expr(uop, expr, "wmma", kernel);
            Some(())
        }

        // Alias-only: these reuse their source's spelling.
        Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
            let s = ctx.get(src).to_string();
            ctx.register(uop.id, s);
            None
        }

        // Alias-only: reuses the spelling of `passthrough`.
        Op::After { passthrough, .. } => {
            let s = ctx.get(passthrough).to_string();
            ctx.register(uop.id, s);
            None
        }

        // Registers the value's spelling under the variable's id. The Bind
        // UOp's own id is not registered.
        Op::Bind { var, value } => {
            let v = ctx.get(value).to_string();
            ctx.register(var.id, v);
            None
        }

        // Opens an `if` block; the body is rendered one level deeper.
        Op::If { condition, .. } => {
            let cond = ctx.get(condition).to_string();
            let indent = ctx.indent();
            kernel.push(format!("{indent}if ({cond}) {{"));
            ctx.push_indent();
            Some(())
        }

        // Closes the most recently opened block.
        Op::EndIf { .. } => {
            ctx.pop_indent();
            let indent = ctx.indent();
            kernel.push(format!("{indent}}}"));
            Some(())
        }

        // Anything else leaves a marker comment in the generated C.
        _ => {
            let indent = ctx.indent();
            kernel.push(format!("{indent}/* UNSUPPORTED: {:?} */", uop.op().as_ref()));
            None
        }
    }
}
425
426fn render_linearize_multi_index_c(indices: &[Arc<UOp>], ctx: &CContext) -> String {
430 use morok_schedule::passes::linearize_index::{compute_row_major_strides, extract_index_dimension};
431
432 let dims: Vec<i64> = indices
433 .iter()
434 .map(|idx| extract_index_dimension(idx).expect("multi-index dimension must be resolvable at codegen"))
435 .collect();
436 let strides = compute_row_major_strides(&dims);
437
438 let mut terms: Vec<String> = Vec::new();
439 for (idx_uop, &stride) in indices.iter().zip(strides.iter()) {
440 if stride == 0 {
441 continue;
442 }
443 let idx_val = ctx.get(idx_uop);
444 if stride == 1 {
445 terms.push(idx_val.to_string());
446 } else {
447 terms.push(format!("({idx_val} * {stride})"));
448 }
449 }
450
451 if terms.is_empty() { "0".to_string() } else { format!("({})", terms.join(" + ")) }
452}
453
454fn render_binary(op: BinaryOp, l: &str, r: &str, dtype: &DType) -> String {
456 match op {
457 BinaryOp::Add => format!("({l} + {r})"),
458 BinaryOp::Sub => format!("({l} - {r})"),
459 BinaryOp::Mul => format!("({l} * {r})"),
460 BinaryOp::Fdiv => format!("({l} / {r})"),
461 BinaryOp::Idiv => format!("({l} / {r})"),
462 BinaryOp::Mod => {
463 if dtype.is_float() {
464 format!("{}({l}, {r})", c_math_fn("__builtin_fmod", dtype))
465 } else {
466 format!("({l} % {r})")
467 }
468 }
469 BinaryOp::Max => {
470 if dtype.is_float() {
471 format!("{}({l}, {r})", c_math_fn("__builtin_fmax", dtype))
472 } else {
473 format!("({l} > {r} ? {l} : {r})")
474 }
475 }
476 BinaryOp::Lt => format!("({l} < {r})"),
477 BinaryOp::Le => format!("({l} <= {r})"),
478 BinaryOp::Gt => format!("({l} > {r})"),
479 BinaryOp::Ge => format!("({l} >= {r})"),
480 BinaryOp::Eq => format!("({l} == {r})"),
481 BinaryOp::Ne => format!("({l} != {r})"),
482 BinaryOp::And => format!("({l} & {r})"),
483 BinaryOp::Or => format!("({l} | {r})"),
484 BinaryOp::Xor => format!("({l} ^ {r})"),
485 BinaryOp::Shl => format!("({l} << {r})"),
486 BinaryOp::Shr => format!("({l} >> {r})"),
487 BinaryOp::Pow => {
488 if dtype.is_float() {
489 format!("{}({l}, {r})", c_math_fn("__builtin_pow", dtype))
490 } else {
491 format!("(({})__builtin_pow((double){l}, (double){r}))", c_dtype(&DType::Scalar(dtype.base())))
493 }
494 }
495 BinaryOp::Threefry => format!("({l} ^ {r})"),
496 }
497}
498
499fn render_unary(op: UnaryOp, s: &str, dtype: &DType) -> String {
501 match op {
502 UnaryOp::Neg => {
503 format!("(-{s})")
504 }
505 UnaryOp::Not => {
506 if dtype.is_bool() {
507 format!("(!{s})")
508 } else {
509 format!("(~{s})")
510 }
511 }
512 UnaryOp::Abs => {
513 if dtype.is_float() {
514 format!("{}({s})", c_math_fn("__builtin_fabs", dtype))
515 } else {
516 format!("({s} < 0 ? -{s} : {s})")
517 }
518 }
519 UnaryOp::Sqrt => format!("{}({s})", c_math_fn("__builtin_sqrt", dtype)),
520 UnaryOp::Rsqrt => {
521 let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
522 format!("({one} / {}({s}))", c_math_fn("__builtin_sqrt", dtype))
523 }
524 UnaryOp::Reciprocal => {
525 let one = if matches!(dtype.base(), ScalarDType::Float64) { "1.0" } else { "1.0f" };
526 format!("({one} / {s})")
527 }
528 UnaryOp::Exp => format!("{}({s})", c_math_fn("__builtin_exp", dtype)),
529 UnaryOp::Exp2 => format!("{}({s})", c_math_fn("__builtin_exp2", dtype)),
530 UnaryOp::Log => format!("{}({s})", c_math_fn("__builtin_log", dtype)),
531 UnaryOp::Log2 => format!("{}({s})", c_math_fn("__builtin_log2", dtype)),
532 UnaryOp::Sin => format!("{}({s})", c_math_fn("__builtin_sin", dtype)),
533 UnaryOp::Cos => format!("{}({s})", c_math_fn("__builtin_cos", dtype)),
534 UnaryOp::Tan => format!("{}({s})", c_math_fn("__builtin_tan", dtype)),
535 UnaryOp::Floor => format!("{}({s})", c_math_fn("__builtin_floor", dtype)),
536 UnaryOp::Ceil => format!("{}({s})", c_math_fn("__builtin_ceil", dtype)),
537 UnaryOp::Trunc => format!("{}({s})", c_math_fn("__builtin_trunc", dtype)),
538 UnaryOp::Round => format!("{}({s})", c_math_fn("__builtin_rint", dtype)),
539 UnaryOp::Erf => format!("{}({s})", c_math_fn("__builtin_erf", dtype)),
540 UnaryOp::Sign => {
541 if dtype.is_float() {
542 let zero = if matches!(dtype.base(), ScalarDType::Float64) { "0.0" } else { "0.0f" };
543 format!("(({s} > {zero}) - ({s} < {zero}))")
544 } else {
545 format!("(({s} > 0) - ({s} < 0))")
546 }
547 }
548 UnaryOp::Square => format!("({s} * {s})"),
549 }
550}
551
552fn render_reduce_accumulate(op: ReduceOp, acc: &str, val: &str, dtype: &DType) -> String {
554 match op {
555 ReduceOp::Add => format!("{acc} += {val};"),
556 ReduceOp::Mul => format!("{acc} *= {val};"),
557 ReduceOp::Max => {
558 if dtype.is_float() {
559 format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmax", dtype))
560 } else {
561 format!("{acc} = ({acc} > {val} ? {acc} : {val});")
562 }
563 }
564 ReduceOp::Min => {
565 if dtype.is_float() {
566 format!("{acc} = {}({acc}, {val});", c_math_fn("__builtin_fmin", dtype))
567 } else {
568 format!("{acc} = ({acc} < {val} ? {acc} : {val});")
569 }
570 }
571 }
572}
573
574fn render_cat(uop: &Arc<UOp>, sources: &[Arc<UOp>], ctx: &mut CContext, kernel: &mut Vec<String>) {
576 let out_dtype = c_dtype(&uop.dtype());
577 let mut elements = Vec::new();
578
579 for src in sources {
580 let src_val = ctx.get(src).to_string();
581 let src_vcount = src.dtype().vcount();
582 if src_vcount == 1 {
583 elements.push(src_val);
584 } else {
585 for i in 0..src_vcount {
586 elements.push(format!("{src_val}[{i}]"));
587 }
588 }
589 }
590
591 let expr = format!("({out_dtype}){{{}}}", elements.join(", "));
592 ctx.emit_expr(uop, expr, "cat", kernel);
593}
594
595pub fn count_references(nodes: &[Arc<UOp>]) -> HashMap<u64, usize> {
598 let mut counts: HashMap<u64, usize> = HashMap::new();
599 for node in nodes {
600 for child in node.op().children() {
601 *counts.entry(child.id).or_insert(0) += 1;
602 }
603 }
604 counts
605}