use std::collections::HashMap;
use melior::ir::{BlockRef, Region, Type, Value};
use super::amx::AmxLoopState;
/// Bookkeeping for one `scf`-style loop that is still being built.
///
/// Instances are pushed/popped via `RenderContext::push_scf_loop` /
/// `RenderContext::pop_scf_loop`.
pub struct ScfLoopInfo<'c, 'a> {
    /// Block that was current before the loop was opened (presumably where the
    /// finished loop op gets inserted — confirm against the caller).
    pub parent_block: BlockRef<'c, 'a>,
    /// Region that holds the loop body.
    pub region: Region<'c>,
    /// Id of the range node this loop was created for.
    pub range_id: u64,
    /// Axis index associated with the range.
    pub axis_id: usize,
    /// MLIR type of the induction variable / bounds (assumed — verify).
    pub range_type: Type<'c>,
    /// Lower bound value (by name; semantics set by the caller).
    pub lb: Value<'c, 'a>,
    /// Upper bound value (by name; semantics set by the caller).
    pub ub: Value<'c, 'a>,
    /// Step value (by name; semantics set by the caller).
    pub step: Value<'c, 'a>,
    /// Initial values of the loop-carried results.
    pub init_values: Vec<Value<'c, 'a>>,
    /// Types of the loop results.
    pub result_types: Vec<Type<'c>>,
    /// Reduce ids carried by this loop; `RenderContext::update_reduce_yield`
    /// uses an id's position here to index into `yield_values`.
    pub reduce_ids: Vec<u64>,
    /// Values to be yielded at the end of the body, parallel to `reduce_ids`.
    pub yield_values: Vec<Value<'c, 'a>>,
}
/// Bookkeeping for one `scf.if`-style conditional that is still being built.
///
/// Stored on `RenderContext` keyed by an id; see `RenderContext::push_scf_if`
/// and `RenderContext::pop_scf_if`.
pub struct ScfIfInfo<'c, 'a> {
    /// Block that was current before the conditional was opened.
    pub parent_block: BlockRef<'c, 'a>,
    /// The `i1`-like condition value (type not enforced here — assumed).
    pub condition: Value<'c, 'a>,
    /// Region that holds the "then" branch.
    pub then_region: Region<'c>,
}
/// Mutable state threaded through MLIR rendering: the UOp-id → value map,
/// the stacks of open loops / conditionals, the insertion block, and AMX flags.
pub struct RenderContext<'c, 'a> {
    /// Rendered MLIR value for each UOp id.
    values: HashMap<u64, Value<'c, 'a>>,
    /// Currently open loops, innermost last.
    scf_loop_stack: Vec<ScfLoopInfo<'c, 'a>>,
    /// Currently open conditionals, each tagged with its id.
    scf_if_stack: Vec<(u64, ScfIfInfo<'c, 'a>)>,
    /// Block new operations are emitted into.
    current_block: BlockRef<'c, 'a>,
    /// Function entry block; fixed at construction.
    entry_block: BlockRef<'c, 'a>,
    /// Optional AMX loop state, if one has been set.
    amx_loop_state: Option<AmxLoopState<'c>>,
    /// Whether the AMX "set" has already been emitted; never reset by this type.
    amx_set_emitted: bool,
}
impl<'c, 'a> RenderContext<'c, 'a> {
    /// Creates an empty context whose current block starts out as `entry_block`.
    pub fn new(entry_block: BlockRef<'c, 'a>) -> Self {
        // Rendering begins in the entry block; later code may redirect it.
        let current_block = entry_block;
        Self {
            values: HashMap::default(),
            scf_loop_stack: Vec::new(),
            scf_if_stack: Vec::new(),
            current_block,
            entry_block,
            amx_loop_state: None,
            amx_set_emitted: false,
        }
    }

    /// The block passed to [`RenderContext::new`].
    pub fn entry_block(&self) -> BlockRef<'c, 'a> {
        self.entry_block
    }

    /// The block operations are currently being emitted into.
    pub fn current_block(&self) -> BlockRef<'c, 'a> {
        self.current_block
    }

    /// Redirects emission to `block`.
    pub fn set_current_block(&mut self, block: BlockRef<'c, 'a>) {
        self.current_block = block;
    }

    /// Records `value` as the rendering of UOp `id`, overwriting any earlier entry.
    pub fn register(&mut self, id: u64, value: Value<'c, 'a>) {
        self.values.insert(id, value);
    }

    /// Looks up the value registered for UOp `id`.
    ///
    /// # Panics
    /// Panics if `id` has not been registered.
    pub fn get(&self, id: u64) -> Value<'c, 'a> {
        match self.try_get(id) {
            Some(value) => value,
            None => panic!("UOp {} not in MLIR context", id),
        }
    }

    /// Looks up the value registered for UOp `id`, returning `None` if absent.
    pub fn try_get(&self, id: u64) -> Option<Value<'c, 'a>> {
        self.values.get(&id).copied()
    }

    /// Whether a value has been registered for UOp `id`.
    pub fn contains(&self, id: u64) -> bool {
        self.try_get(id).is_some()
    }

    /// Opens a new innermost loop.
    pub fn push_scf_loop(&mut self, info: ScfLoopInfo<'c, 'a>) {
        self.scf_loop_stack.push(info);
    }

    /// Closes and returns the innermost loop.
    ///
    /// # Panics
    /// Panics if no loop is open.
    pub fn pop_scf_loop(&mut self) -> ScfLoopInfo<'c, 'a> {
        self.scf_loop_stack.pop().expect("scf loop stack underflow")
    }

    /// Replaces the pending yield value for `reduce_id` with `new_value`.
    ///
    /// Loops are searched innermost-first; only the first loop whose
    /// `reduce_ids` contains `reduce_id` is updated.
    ///
    /// # Panics
    /// Panics if no open loop carries `reduce_id`, or if that loop's
    /// `yield_values` is shorter than its `reduce_ids`.
    pub fn update_reduce_yield(&mut self, reduce_id: u64, new_value: Value<'c, 'a>) {
        let slot = self.scf_loop_stack.iter_mut().rev().find_map(|frame| {
            let pos = frame
                .reduce_ids
                .iter()
                .position(|&candidate| candidate == reduce_id)?;
            Some(&mut frame.yield_values[pos])
        });
        match slot {
            Some(target) => *target = new_value,
            None => panic!("reduce {} not found in any scf loop on stack", reduce_id),
        }
    }

    /// Opens a conditional tagged with `if_id`.
    pub fn push_scf_if(&mut self, if_id: u64, info: ScfIfInfo<'c, 'a>) {
        self.scf_if_stack.push((if_id, info));
    }

    /// Removes and returns the most recently pushed conditional tagged `if_id`.
    ///
    /// The entry need not be on top of the stack; later entries keep their order.
    ///
    /// # Panics
    /// Panics if no open conditional carries `if_id`.
    pub fn pop_scf_if(&mut self, if_id: u64) -> ScfIfInfo<'c, 'a> {
        let found = self.scf_if_stack.iter().rposition(|entry| entry.0 == if_id);
        match found {
            Some(index) => {
                let (_, info) = self.scf_if_stack.remove(index);
                info
            }
            None => panic!("scf.if {} not found on stack", if_id),
        }
    }

    /// Stores `state`, dropping any previously stored AMX loop state.
    pub fn set_amx_loop_state(&mut self, state: AmxLoopState<'c>) {
        self.amx_loop_state = Some(state);
    }

    /// Borrows the stored AMX loop state, if any.
    pub fn amx_loop_state(&self) -> Option<&AmxLoopState<'c>> {
        self.amx_loop_state.as_ref()
    }

    /// Moves the stored AMX loop state out, leaving `None` behind.
    pub fn take_amx_loop_state(&mut self) -> Option<AmxLoopState<'c>> {
        std::mem::take(&mut self.amx_loop_state)
    }

    /// Whether [`RenderContext::mark_amx_set_emitted`] has been called.
    pub fn amx_set_emitted(&self) -> bool {
        self.amx_set_emitted
    }

    /// Sets the AMX-set-emitted flag; there is no way to clear it afterwards.
    pub fn mark_amx_set_emitted(&mut self) {
        self.amx_set_emitted = true;
    }
}