morok_codegen/llvm/common/
ctx.rs1use std::collections::HashMap;
7use std::sync::Arc;
8
9use morok_ir::{ConstValue, Op, prelude::*};
10
11use super::types::{lconst, ldt};
12
/// Bookkeeping for a reduction registered through
/// `RenderContext::register_reduce_pending`.
///
/// Both fields are opaque rendered text supplied by the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PendingReduce {
    /// Rendered name of the accumulator pointer, as passed by the caller.
    pub acc_ptr: String,
    /// Rendered type string of the accumulated value, as passed by the caller.
    pub dtype: String,
}
18
19pub struct RenderContext {
21 names: HashMap<u64, String>,
22 range_values: HashMap<usize, String>,
23 counter: usize,
24 pending_reduces: HashMap<u64, PendingReduce>,
26 range_stack: Vec<usize>,
29}
30
31impl RenderContext {
32 pub fn new() -> Self {
33 Self {
34 names: HashMap::new(),
35 range_values: HashMap::new(),
36 counter: 0,
37 pending_reduces: HashMap::new(),
38 range_stack: Vec::new(),
39 }
40 }
41
42 pub fn name(&mut self, uop: &Arc<UOp>) -> String {
48 if let Some(name) = self.names.get(&uop.id) {
49 return name.clone();
50 }
51
52 let name = match uop.op() {
53 Op::Const(cv) => lconst(&cv.0, &uop.dtype()),
54 Op::VConst { values } => self.render_vconst(values, uop),
55 Op::Param { slot, device: None, .. } => format!("%data{slot}"),
56 Op::DefineLocal(id) => format!("%local{id}"),
57 Op::DefineVar { name, .. } => format!("%{name}"),
58 Op::DefineReg { .. } => {
59 let n = format!("%reg{}", self.counter);
60 self.counter += 1;
61 n
62 }
63 Op::Range { axis_id, .. } => {
64 format!("%r{}", axis_id.value())
66 }
67 _ => {
68 let n = format!("%v{}", self.counter);
69 self.counter += 1;
70 n
71 }
72 };
73
74 self.names.insert(uop.id, name.clone());
75 name
76 }
77
78 fn render_vconst(&self, values: &[ConstValue], uop: &Arc<UOp>) -> String {
80 let scalar_type = ldt(&uop.dtype().scalar_dtype());
81
82 let elements: Vec<String> = values
84 .iter()
85 .map(|v| {
86 let val = lconst(v, &uop.dtype());
87 format!("{scalar_type} {val}")
88 })
89 .collect();
90
91 format!("<{}>", elements.join(", "))
92 }
93
94 pub fn get(&self, uop: &Arc<UOp>) -> &str {
96 self.names
97 .get(&uop.id)
98 .map(|s| s.as_str())
99 .unwrap_or_else(|| panic!("UOp {} ({:?}) not in context", uop.id, uop.op()))
100 }
101
102 pub fn try_get(&self, uop: &Arc<UOp>) -> Option<&str> {
104 self.names.get(&uop.id).map(|s| s.as_str())
105 }
106
107 pub fn contains(&self, id: u64) -> bool {
109 self.names.contains_key(&id)
110 }
111
112 pub fn alias(&mut self, id: u64, name: String) {
114 self.names.insert(id, name);
115 }
116
117 pub fn register(&mut self, id: u64, name: String) {
119 self.names.insert(id, name);
120 }
121
122 pub fn counter(&self) -> usize {
124 self.counter
125 }
126
127 pub fn register_range(&mut self, axis_id: usize, name: String) {
129 self.range_values.insert(axis_id, name);
130 }
131
132 pub fn get_range(&self, axis_id: usize) -> Option<&str> {
134 self.range_values.get(&axis_id).map(|s| s.as_str())
135 }
136
137 pub fn push_range(&mut self, axis_id: usize) {
139 self.range_stack.push(axis_id);
140 }
141
142 pub fn pop_range(&mut self) -> Option<usize> {
144 self.range_stack.pop()
145 }
146
147 pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_ptr: String, dtype: String) {
149 self.pending_reduces.insert(reduce_id, PendingReduce { acc_ptr, dtype });
150 }
151
152 pub fn take_pending_reduces(&mut self) -> HashMap<u64, PendingReduce> {
154 std::mem::take(&mut self.pending_reduces)
155 }
156
157 pub fn has_pending_reduces(&self) -> bool {
159 !self.pending_reduces.is_empty()
160 }
161}
162
163impl Default for RenderContext {
164 fn default() -> Self {
165 Self::new()
166 }
167}