morok_codegen/llvm/common/
ctx.rs1use std::collections::HashMap;
7use std::sync::Arc;
8
9use morok_ir::{ConstValue, Op, prelude::*};
10
11use super::types::{lconst, ldt};
12
/// Bookkeeping for a reduction registered through
/// [`RenderContext::register_reduce_pending`].
///
/// Both fields are stored exactly as the caller supplied them. Entries are
/// handed back in bulk by [`RenderContext::take_pending_reduces`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PendingReduce {
    /// Name of the accumulator pointer (presumably an LLVM value name such as
    /// `%acc0` — confirm against callers).
    pub acc_ptr: String,
    /// Rendered type string associated with the accumulator.
    pub dtype: String,
}

/// Per-render naming state used while emitting LLVM IR text for a UOp graph.
#[derive(Debug)]
pub struct RenderContext {
    /// UOp id -> rendered name. For constants this is the literal constant text.
    names: HashMap<u64, String>,
    /// Axis id -> name registered via `register_range`.
    range_values: HashMap<usize, String>,
    /// Counter shared by freshly minted `%reg{N}` and `%v{N}` names.
    counter: usize,
    /// Reductions registered but not yet taken, keyed by reduce UOp id.
    pending_reduces: HashMap<u64, PendingReduce>,
}
27
28impl RenderContext {
29 pub fn new() -> Self {
30 Self { names: HashMap::new(), range_values: HashMap::new(), counter: 0, pending_reduces: HashMap::new() }
31 }
32
33 pub fn name(&mut self, uop: &Arc<UOp>) -> String {
39 if let Some(name) = self.names.get(&uop.id) {
40 return name.clone();
41 }
42
43 let name = match uop.op() {
44 Op::Const(cv) => lconst(&cv.0, &uop.dtype()),
45 Op::VConst { values } => self.render_vconst(values, uop),
46 Op::DefineGlobal(id) => format!("%data{id}"),
47 Op::DefineLocal(id) => format!("%local{id}"),
48 Op::DefineVar { name, .. } => format!("%{name}"),
49 Op::DefineReg { .. } => {
50 let n = format!("%reg{}", self.counter);
51 self.counter += 1;
52 n
53 }
54 Op::Range { axis_id, .. } => {
55 format!("%r{}", axis_id.value())
57 }
58 _ => {
59 let n = format!("%v{}", self.counter);
60 self.counter += 1;
61 n
62 }
63 };
64
65 self.names.insert(uop.id, name.clone());
66 name
67 }
68
69 fn render_vconst(&self, values: &[ConstValue], uop: &Arc<UOp>) -> String {
71 let scalar_type = ldt(&uop.dtype().scalar_dtype());
72
73 let elements: Vec<String> = values
75 .iter()
76 .map(|v| {
77 let val = lconst(v, &uop.dtype());
78 format!("{scalar_type} {val}")
79 })
80 .collect();
81
82 format!("<{}>", elements.join(", "))
83 }
84
85 pub fn get(&self, uop: &Arc<UOp>) -> &str {
87 self.names
88 .get(&uop.id)
89 .map(|s| s.as_str())
90 .unwrap_or_else(|| panic!("UOp {} ({:?}) not in context", uop.id, uop.op()))
91 }
92
93 pub fn try_get(&self, uop: &Arc<UOp>) -> Option<&str> {
95 self.names.get(&uop.id).map(|s| s.as_str())
96 }
97
98 pub fn contains(&self, id: u64) -> bool {
100 self.names.contains_key(&id)
101 }
102
103 pub fn alias(&mut self, id: u64, name: String) {
105 self.names.insert(id, name);
106 }
107
108 pub fn register(&mut self, id: u64, name: String) {
110 self.names.insert(id, name);
111 }
112
113 pub fn counter(&self) -> usize {
115 self.counter
116 }
117
118 pub fn register_range(&mut self, axis_id: usize, name: String) {
120 self.range_values.insert(axis_id, name);
121 }
122
123 pub fn get_range(&self, axis_id: usize) -> Option<&str> {
125 self.range_values.get(&axis_id).map(|s| s.as_str())
126 }
127
128 pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_ptr: String, dtype: String) {
130 self.pending_reduces.insert(reduce_id, PendingReduce { acc_ptr, dtype });
131 }
132
133 pub fn take_pending_reduces(&mut self) -> HashMap<u64, PendingReduce> {
135 std::mem::take(&mut self.pending_reduces)
136 }
137
138 pub fn has_pending_reduces(&self) -> bool {
140 !self.pending_reduces.is_empty()
141 }
142}
143
144impl Default for RenderContext {
145 fn default() -> Self {
146 Self::new()
147 }
148}