use std::collections::HashMap;
use std::sync::Arc;
use morok_ir::{ConstValue, Op, prelude::*};
use super::types::{lconst, ldt};
/// A reduction whose accumulator has been registered but not yet finalized.
///
/// Both fields are pre-rendered strings supplied by the caller of
/// `RenderContext::register_reduce_pending`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PendingReduce {
    /// Rendered accumulator pointer operand (presumably an LLVM value name — confirm with caller).
    pub acc_ptr: String,
    /// Rendered type string of the accumulator (presumably produced by `ldt` — confirm).
    pub dtype: String,
}

/// Per-kernel rendering state: UOp → operand-name bindings, range bookkeeping,
/// and reductions awaiting finalization.
#[derive(Debug)]
pub struct RenderContext {
    /// Rendered operand text keyed by UOp id.
    names: HashMap<u64, String>,
    /// Rendered value names keyed by range axis id.
    range_values: HashMap<usize, String>,
    /// Monotonic counter used to mint fresh `%regN` / `%vN` names.
    counter: usize,
    /// Reductions registered via `register_reduce_pending`, keyed by reduce UOp id.
    pending_reduces: HashMap<u64, PendingReduce>,
    /// Stack of currently open range axis ids (push on entry, pop on exit).
    range_stack: Vec<usize>,
}
impl RenderContext {
    /// Creates an empty context: no names, no ranges, no pending reductions,
    /// and the fresh-name counter at zero.
    pub fn new() -> Self {
        Self {
            names: HashMap::new(),
            range_values: HashMap::new(),
            counter: 0,
            pending_reduces: HashMap::new(),
            range_stack: Vec::new(),
        }
    }

    /// Returns the operand text for `uop`, assigning and caching it on first use.
    ///
    /// - `Const` renders as its literal (via `lconst`).
    /// - `VConst` renders as a `<ty v, ty v, ...>` vector literal.
    /// - `Param` without a device → `%data{slot}`.
    /// - `DefineLocal` → `%local{id}`.
    /// - `DefineVar` → `%{name}`.
    /// - `Range` → `%r{axis}`.
    /// - `DefineReg` → fresh `%regN`.
    /// - Anything else (including `Param` with a device) → fresh `%vN`.
    ///
    /// Subsequent calls with the same UOp id return the cached string.
    pub fn name(&mut self, uop: &Arc<UOp>) -> String {
        if let Some(name) = self.names.get(&uop.id) {
            return name.clone();
        }
        let name = match uop.op() {
            Op::Const(cv) => lconst(&cv.0, &uop.dtype()),
            Op::VConst { values } => self.render_vconst(values, uop),
            Op::Param { slot, device: None, .. } => format!("%data{slot}"),
            Op::DefineLocal(id) => format!("%local{id}"),
            Op::DefineVar { name, .. } => format!("%{name}"),
            Op::DefineReg { .. } => self.fresh_name("%reg"),
            // NOTE(review): the name depends only on the axis id, so distinct Range
            // UOps sharing an axis id render to the same name — confirm intended.
            Op::Range { axis_id, .. } => format!("%r{}", axis_id.value()),
            _ => self.fresh_name("%v"),
        };
        self.names.insert(uop.id, name.clone());
        name
    }

    /// Mints `"{prefix}{counter}"` and advances the shared counter.
    fn fresh_name(&mut self, prefix: &str) -> String {
        let name = format!("{prefix}{}", self.counter);
        self.counter += 1;
        name
    }

    /// Renders a vector constant as `<ty v0, ty v1, ...>`.
    ///
    /// Each element is a scalar, so it is rendered with the *scalar* dtype, the
    /// same one used for its type prefix, rather than the vector dtype.
    fn render_vconst(&self, values: &[ConstValue], uop: &Arc<UOp>) -> String {
        let scalar = uop.dtype().scalar_dtype();
        let scalar_type = ldt(&scalar);
        let elements: Vec<String> = values
            .iter()
            .map(|v| format!("{scalar_type} {}", lconst(v, &scalar)))
            .collect();
        format!("<{}>", elements.join(", "))
    }

    /// Returns the cached name for `uop`.
    ///
    /// # Panics
    ///
    /// Panics if `uop` has not been named or registered in this context.
    pub fn get(&self, uop: &Arc<UOp>) -> &str {
        self.names
            .get(&uop.id)
            .map(|s| s.as_str())
            .unwrap_or_else(|| panic!("UOp {} ({:?}) not in context", uop.id, uop.op()))
    }

    /// Returns the cached name for `uop`, or `None` if it has none yet.
    pub fn try_get(&self, uop: &Arc<UOp>) -> Option<&str> {
        self.names.get(&uop.id).map(|s| s.as_str())
    }

    /// Reports whether a name is bound for UOp `id`.
    pub fn contains(&self, id: u64) -> bool {
        self.names.contains_key(&id)
    }

    /// Binds UOp `id` to an existing operand name, overwriting any prior binding.
    ///
    /// Behaves exactly like [`RenderContext::register`]; the separate name
    /// documents intent at call sites.
    pub fn alias(&mut self, id: u64, name: String) {
        self.names.insert(id, name);
    }

    /// Binds UOp `id` to `name`, overwriting any prior binding.
    pub fn register(&mut self, id: u64, name: String) {
        self.names.insert(id, name);
    }

    /// Current value of the fresh-name counter (the next `N` to be minted).
    pub fn counter(&self) -> usize {
        self.counter
    }

    /// Records the rendered value name for range `axis_id`, overwriting any prior entry.
    pub fn register_range(&mut self, axis_id: usize, name: String) {
        self.range_values.insert(axis_id, name);
    }

    /// Looks up the rendered value name recorded for range `axis_id`.
    pub fn get_range(&self, axis_id: usize) -> Option<&str> {
        self.range_values.get(&axis_id).map(|s| s.as_str())
    }

    /// Pushes `axis_id` onto the open-range stack.
    pub fn push_range(&mut self, axis_id: usize) {
        self.range_stack.push(axis_id);
    }

    /// Pops the most recently pushed range axis id, or `None` if the stack is empty.
    pub fn pop_range(&mut self) -> Option<usize> {
        self.range_stack.pop()
    }

    /// Records a reduction awaiting finalization, replacing any prior entry for `reduce_id`.
    pub fn register_reduce_pending(&mut self, reduce_id: u64, acc_ptr: String, dtype: String) {
        self.pending_reduces.insert(reduce_id, PendingReduce { acc_ptr, dtype });
    }

    /// Removes and returns all pending reductions, leaving the set empty.
    pub fn take_pending_reduces(&mut self) -> HashMap<u64, PendingReduce> {
        std::mem::take(&mut self.pending_reduces)
    }

    /// Reports whether any reductions are awaiting finalization.
    pub fn has_pending_reduces(&self) -> bool {
        !self.pending_reduces.is_empty()
    }
}
impl Default for RenderContext {
fn default() -> Self {
Self::new()
}
}