Skip to main content

morok_codegen/llvm/common/
ctx.rs

1//! Render context for LLVM IR text generation.
2//!
3//! Maps UOp IDs to LLVM variable names and manages naming.
4//! Shared between CPU and GPU backends.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use morok_ir::{ConstValue, Op, prelude::*};
10
11use super::types::{lconst, ldt};
12
/// Bookkeeping for a reduce whose accumulator still needs its final load.
///
/// Stored in `RenderContext` keyed by the reduce UOp's ID until the caller
/// drains the pending set and emits the loads.
#[derive(Debug, Clone)]
pub struct PendingReduce {
    /// LLVM name of the accumulator pointer to load from.
    pub acc_ptr: String,
    /// Type text of the accumulated value, kept as a string
    /// (presumably an LLVM type as produced by `ldt` — TODO confirm at call site).
    pub dtype: String,
}
18
19/// Maps UOp ID → LLVM variable name.
20pub struct RenderContext {
21    names: HashMap<u64, String>,
22    range_values: HashMap<usize, String>,
23    counter: usize,
24    /// Pending reduce final loads: reduce_id -> (acc_ptr, dtype)
25    pending_reduces: HashMap<u64, PendingReduce>,
26    /// Stack of currently open RANGE axis_ids (for correct END footer ordering).
27    /// Pushed on RANGE emission, popped on END emission.
28    range_stack: Vec<usize>,
29}
30
31impl RenderContext {
32    pub fn new() -> Self {
33        Self {
34            names: HashMap::new(),
35            range_values: HashMap::new(),
36            counter: 0,
37            pending_reduces: HashMap::new(),
38            range_stack: Vec::new(),
39        }
40    }
41
42    /// Get or create variable name for UOp.
43    ///
44    /// For constants, returns literal value.
45    /// For definitions, returns argument name.
46    /// For other ops, returns a generated variable name.
47    pub fn name(&mut self, uop: &Arc<UOp>) -> String {
48        if let Some(name) = self.names.get(&uop.id) {
49            return name.clone();
50        }
51
52        let name = match uop.op() {
53            Op::Const(cv) => lconst(&cv.0, &uop.dtype()),
54            Op::VConst { values } => self.render_vconst(values, uop),
55            Op::Param { slot, device: None, .. } => format!("%data{slot}"),
56            Op::DefineLocal(id) => format!("%local{id}"),
57            Op::DefineVar { name, .. } => format!("%{name}"),
58            Op::DefineReg { .. } => {
59                let n = format!("%reg{}", self.counter);
60                self.counter += 1;
61                n
62            }
63            Op::Range { axis_id, .. } => {
64                // Range variables are named by axis_id
65                format!("%r{}", axis_id.value())
66            }
67            _ => {
68                let n = format!("%v{}", self.counter);
69                self.counter += 1;
70                n
71            }
72        };
73
74        self.names.insert(uop.id, name.clone());
75        name
76    }
77
78    /// Render a vector constant.
79    fn render_vconst(&self, values: &[ConstValue], uop: &Arc<UOp>) -> String {
80        let scalar_type = ldt(&uop.dtype().scalar_dtype());
81
82        // Format as LLVM vector constant: <type val, type val, ...>
83        let elements: Vec<String> = values
84            .iter()
85            .map(|v| {
86                let val = lconst(v, &uop.dtype());
87                format!("{scalar_type} {val}")
88            })
89            .collect();
90
91        format!("<{}>", elements.join(", "))
92    }
93
94    /// Get existing name (panics if not found).
95    pub fn get(&self, uop: &Arc<UOp>) -> &str {
96        self.names
97            .get(&uop.id)
98            .map(|s| s.as_str())
99            .unwrap_or_else(|| panic!("UOp {} ({:?}) not in context", uop.id, uop.op()))
100    }
101
102    /// Try to get existing name.
103    pub fn try_get(&self, uop: &Arc<UOp>) -> Option<&str> {
104        self.names.get(&uop.id).map(|s| s.as_str())
105    }
106
107    /// Check if a UOp is already registered.
108    pub fn contains(&self, id: u64) -> bool {
109        self.names.contains_key(&id)
110    }
111
112    /// Alias one ID to another's name.
113    pub fn alias(&mut self, id: u64, name: String) {
114        self.names.insert(id, name);
115    }
116
117    /// Pre-register a name for a UOp ID.
118    pub fn register(&mut self, id: u64, name: String) {
119        self.names.insert(id, name);
120    }
121
122    /// Get current variable counter.
123    pub fn counter(&self) -> usize {
124        self.counter
125    }
126
127    /// Register a range value by axis_id.
128    pub fn register_range(&mut self, axis_id: usize, name: String) {
129        self.range_values.insert(axis_id, name);
130    }
131
132    /// Get a range value by axis_id.
133    pub fn get_range(&self, axis_id: usize) -> Option<&str> {
134        self.range_values.get(&axis_id).map(|s| s.as_str())
135    }
136
137    /// Push a range axis_id onto the open-range stack (called during RANGE codegen).
138    pub fn push_range(&mut self, axis_id: usize) {
139        self.range_stack.push(axis_id);
140    }
141
142    /// Pop the innermost open range axis_id (called during END codegen).
143    pub fn pop_range(&mut self) -> Option<usize> {
144        self.range_stack.pop()
145    }
146
147    /// Register a pending reduce final load.
148    pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_ptr: String, dtype: String) {
149        self.pending_reduces.insert(reduce_id, PendingReduce { acc_ptr, dtype });
150    }
151
152    /// Take all pending reduces (empties map).
153    pub fn take_pending_reduces(&mut self) -> HashMap<u64, PendingReduce> {
154        std::mem::take(&mut self.pending_reduces)
155    }
156
157    /// Check if there are pending reduces.
158    pub fn has_pending_reduces(&self) -> bool {
159        !self.pending_reduces.is_empty()
160    }
161}
162
163impl Default for RenderContext {
164    fn default() -> Self {
165        Self::new()
166    }
167}