Skip to main content

morok_codegen/llvm/common/
ctx.rs

1//! Render context for LLVM IR text generation.
2//!
3//! Maps UOp IDs to LLVM variable names and manages naming.
4//! Shared between CPU and GPU backends.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use morok_ir::{ConstValue, Op, prelude::*};
10
11use super::types::{lconst, ldt};
12
/// Accumulator info for a reduce whose final value still has to be loaded.
///
/// Instances are stored in [`RenderContext`] keyed by the reduce UOp's ID
/// (see `register_reduce_pending`) until drained with `take_pending_reduces`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PendingReduce {
    /// LLVM name of the accumulator pointer the final value is loaded from.
    pub acc_ptr: String,
    /// Type of the accumulated value (presumably an LLVM type string such as
    /// `float` — confirm against the caller of `register_reduce_pending`).
    pub dtype: String,
}
18
19/// Maps UOp ID → LLVM variable name.
20pub struct RenderContext {
21    names: HashMap<u64, String>,
22    range_values: HashMap<usize, String>,
23    counter: usize,
24    /// Pending reduce final loads: reduce_id -> (acc_ptr, dtype)
25    pending_reduces: HashMap<u64, PendingReduce>,
26}
27
28impl RenderContext {
29    pub fn new() -> Self {
30        Self { names: HashMap::new(), range_values: HashMap::new(), counter: 0, pending_reduces: HashMap::new() }
31    }
32
33    /// Get or create variable name for UOp.
34    ///
35    /// For constants, returns literal value.
36    /// For definitions, returns argument name.
37    /// For other ops, returns a generated variable name.
38    pub fn name(&mut self, uop: &Arc<UOp>) -> String {
39        if let Some(name) = self.names.get(&uop.id) {
40            return name.clone();
41        }
42
43        let name = match uop.op() {
44            Op::Const(cv) => lconst(&cv.0, &uop.dtype()),
45            Op::VConst { values } => self.render_vconst(values, uop),
46            Op::DefineGlobal(id) => format!("%data{id}"),
47            Op::DefineLocal(id) => format!("%local{id}"),
48            Op::DefineVar { name, .. } => format!("%{name}"),
49            Op::DefineReg { .. } => {
50                let n = format!("%reg{}", self.counter);
51                self.counter += 1;
52                n
53            }
54            Op::Range { axis_id, .. } => {
55                // Range variables are named by axis_id
56                format!("%r{}", axis_id.value())
57            }
58            _ => {
59                let n = format!("%v{}", self.counter);
60                self.counter += 1;
61                n
62            }
63        };
64
65        self.names.insert(uop.id, name.clone());
66        name
67    }
68
69    /// Render a vector constant.
70    fn render_vconst(&self, values: &[ConstValue], uop: &Arc<UOp>) -> String {
71        let scalar_type = ldt(&uop.dtype().scalar_dtype());
72
73        // Format as LLVM vector constant: <type val, type val, ...>
74        let elements: Vec<String> = values
75            .iter()
76            .map(|v| {
77                let val = lconst(v, &uop.dtype());
78                format!("{scalar_type} {val}")
79            })
80            .collect();
81
82        format!("<{}>", elements.join(", "))
83    }
84
85    /// Get existing name (panics if not found).
86    pub fn get(&self, uop: &Arc<UOp>) -> &str {
87        self.names
88            .get(&uop.id)
89            .map(|s| s.as_str())
90            .unwrap_or_else(|| panic!("UOp {} ({:?}) not in context", uop.id, uop.op()))
91    }
92
93    /// Try to get existing name.
94    pub fn try_get(&self, uop: &Arc<UOp>) -> Option<&str> {
95        self.names.get(&uop.id).map(|s| s.as_str())
96    }
97
98    /// Check if a UOp is already registered.
99    pub fn contains(&self, id: u64) -> bool {
100        self.names.contains_key(&id)
101    }
102
103    /// Alias one ID to another's name.
104    pub fn alias(&mut self, id: u64, name: String) {
105        self.names.insert(id, name);
106    }
107
108    /// Pre-register a name for a UOp ID.
109    pub fn register(&mut self, id: u64, name: String) {
110        self.names.insert(id, name);
111    }
112
113    /// Get current variable counter.
114    pub fn counter(&self) -> usize {
115        self.counter
116    }
117
118    /// Register a range value by axis_id.
119    pub fn register_range(&mut self, axis_id: usize, name: String) {
120        self.range_values.insert(axis_id, name);
121    }
122
123    /// Get a range value by axis_id.
124    pub fn get_range(&self, axis_id: usize) -> Option<&str> {
125        self.range_values.get(&axis_id).map(|s| s.as_str())
126    }
127
128    /// Register a pending reduce final load.
129    pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_ptr: String, dtype: String) {
130        self.pending_reduces.insert(reduce_id, PendingReduce { acc_ptr, dtype });
131    }
132
133    /// Take all pending reduces (empties map).
134    pub fn take_pending_reduces(&mut self) -> HashMap<u64, PendingReduce> {
135        std::mem::take(&mut self.pending_reduces)
136    }
137
138    /// Check if there are pending reduces.
139    pub fn has_pending_reduces(&self) -> bool {
140        !self.pending_reduces.is_empty()
141    }
142}
143
144impl Default for RenderContext {
145    fn default() -> Self {
146        Self::new()
147    }
148}