use alloc::{borrow::Cow, rc::Rc, string::String, string::ToString, vec::Vec};
use core::{any::TypeId, cell::RefCell, fmt::Display};
use enumset::EnumSet;
use hashbrown::{HashMap, HashSet};
use crate::{
BarrierLevel, CubeFnSource, DeviceProperties, FastMath, ManagedVariable, Matrix, Processor,
SemanticType, SourceLoc, StorageType, TargetProperties, TypeHash,
};
use super::{
Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
};
/// Shared, interior-mutable map from a Rust type's [`TypeId`] to the [`StorageType`]
/// registered for it (see `Scope::register_type` / `Scope::resolve_type`).
pub type TypeMap = Rc<RefCell<HashMap<TypeId, StorageType>>>;
/// Shared, interior-mutable map from a Rust type's [`TypeId`] to the `usize` size
/// registered for it (see `Scope::register_size` / `Scope::resolve_size`).
pub type SizeMap = Rc<RefCell<HashMap<TypeId, usize>>>;
/// A scope holding a list of instructions together with the variables declared for it.
///
/// A root scope is built with `Scope::root`; nested scopes are built with `Scope::child`,
/// which starts with empty instruction/variable lists but shares the reference-counted
/// state (validation errors, type/size maps, modes, runtime and device properties).
///
/// Note that the hand-written `Hash` impl only covers `depth`, `instructions` and the
/// variable lists, whereas the derived `PartialEq` compares every field.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
#[allow(missing_docs)]
pub struct Scope {
// Validation error messages; the underlying list is shared with scopes made by `child`.
validation_errors: ValidationErrors,
// Nesting depth: 0 for a root scope, parent depth + 1 for a child.
pub depth: u8,
// Instructions added through `register`, in registration order.
pub instructions: Vec<Instruction>,
// Variables added through `add_local_mut` (duplicates are skipped there).
pub locals: Vec<Variable>,
// Matrix variables recorded through `add_matrix`.
matrices: Vec<Variable>,
// Pipeline variables recorded through `add_pipeline`.
pipelines: Vec<Variable>,
// Variables created by `create_shared` and `create_shared_array`.
shared: Vec<Variable>,
// Constant arrays, each paired with the variables passed as its data.
pub const_arrays: Vec<(Variable, Vec<Variable>)>,
// Local array variables recorded through `add_local_array`.
local_arrays: Vec<Variable>,
// NOTE(review): in the code visible here this is only initialized empty and hashed;
// no writer is visible — confirm whether it is still needed.
index_offset_with_output_layout_position: Vec<usize>,
// Allocator used to create variables; children receive a clone of it.
pub allocator: Allocator,
// Debug/source-location tracking state.
pub debug: DebugInfo,
// Registered storage types keyed by `TypeId`; excluded from type hashing and serde.
#[type_hash(skip)]
#[cfg_attr(feature = "serde", serde(skip))]
pub typemap: TypeMap,
// Registered sizes keyed by `TypeId`; excluded from type hashing and serde.
#[type_hash(skip)]
#[cfg_attr(feature = "serde", serde(skip))]
pub sizemap: SizeMap,
// Target properties; default-initialized by `root` and shared with children.
pub runtime_properties: Rc<TargetProperties>,
// Instruction modes copied onto every instruction passed to `register`.
pub modes: Rc<RefCell<InstructionModes>>,
// Device properties set through `device_properties`; skipped by serde.
#[cfg_attr(feature = "serde", serde(skip))]
pub properties: Option<Rc<DeviceProperties>>,
}
/// Collection of validation error messages.
///
/// The list lives behind an `Rc<RefCell<..>>`, so cloning this value yields another
/// handle to the same list rather than an independent copy.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct ValidationErrors {
// Messages pushed by `Scope::push_error` and drained by `Scope::pop_errors`.
errors: Rc<RefCell<Vec<String>>>,
}
/// Debug bookkeeping attached to a [`Scope`].
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct DebugInfo {
/// When `false`, `Scope::update_source` and `Scope::update_variable_name` do nothing.
pub enabled: bool,
/// Every source passed to `Scope::update_source` while debugging is enabled.
pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
/// Names recorded for variables through `Scope::update_variable_name`.
pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
/// Current source location; cloned onto each instruction by `Scope::register`.
pub source_loc: Option<SourceLoc>,
/// The first source location recorded by `Scope::update_source`.
pub entry_loc: Option<SourceLoc>,
}
/// Modes attached to an instruction; `Scope::register` copies the scope's current value
/// into each instruction it records.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)]
pub struct InstructionModes {
/// Set of [`FastMath`] flags; empty by default.
pub fp_math_mode: EnumSet<FastMath>,
}
/// Hashes a scope from its depth, its instructions and its variable lists.
///
/// Shared state (validation errors, allocator, debug info, type/size maps, properties and
/// modes) is deliberately left out, so two scopes that differ only in those fields hash to
/// the same value.
impl core::hash::Hash for Scope {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        use core::hash::Hash;

        let Self {
            depth,
            instructions,
            locals,
            matrices,
            pipelines,
            shared,
            const_arrays,
            local_arrays,
            index_offset_with_output_layout_position,
            ..
        } = self;

        // The feeding order below is part of the hash value; keep it stable.
        Hash::hash(depth, state);
        Hash::hash(instructions, state);
        Hash::hash(locals, state);
        Hash::hash(matrices, state);
        Hash::hash(pipelines, state);
        Hash::hash(shared, state);
        Hash::hash(const_arrays, state);
        Hash::hash(local_arrays, state);
        Hash::hash(index_offset_with_output_layout_position, state);
    }
}
/// Selects how a value is read.
///
/// NOTE(review): nothing in the code visible here uses this enum; the variant notes below
/// are inferred from the names only — confirm against the call sites.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
#[allow(missing_docs)]
pub enum ReadingStrategy {
// Presumably: read following the output layout — TODO confirm.
OutputLayout,
// Presumably: read without any layout adjustment — TODO confirm.
Plain,
}
impl Scope {
/// Attaches device properties to this scope.
///
/// The given properties are cloned into a new `Rc` and stored in `self.properties`,
/// replacing whatever was stored before.
pub fn device_properties(&mut self, properties: &DeviceProperties) {
    let owned = properties.clone();
    self.properties = Some(Rc::new(owned));
}
/// Builds a root scope: depth 0, no instructions or variables, a default allocator,
/// empty type/size maps, default runtime properties and modes, and no device properties.
///
/// `debug_enabled` is stored as `debug.enabled`; the remaining debug state starts empty.
pub fn root(debug_enabled: bool) -> Self {
    let validation_errors = ValidationErrors {
        errors: Default::default(),
    };
    let debug = DebugInfo {
        enabled: debug_enabled,
        sources: Default::default(),
        variable_names: Default::default(),
        source_loc: None,
        entry_loc: None,
    };

    Self {
        validation_errors,
        depth: 0,
        instructions: Vec::new(),
        locals: Vec::new(),
        matrices: Vec::new(),
        pipelines: Vec::new(),
        shared: Vec::new(),
        const_arrays: Vec::new(),
        local_arrays: Vec::new(),
        index_offset_with_output_layout_position: Vec::new(),
        allocator: Allocator::default(),
        debug,
        typemap: Default::default(),
        sizemap: Default::default(),
        runtime_properties: Rc::new(Default::default()),
        modes: Default::default(),
        properties: None,
    }
}
/// Builder-style setter: replaces this scope's allocator and returns the scope.
pub fn with_allocator(mut self, allocator: Allocator) -> Self {
self.allocator = allocator;
self
}
/// Builder-style setter: replaces this scope's type map and returns the scope.
///
/// Only `typemap` is replaced; `sizemap` is left untouched.
pub fn with_types(mut self, typemap: TypeMap) -> Self {
self.typemap = typemap;
self
}
pub fn create_matrix(&mut self, matrix: Matrix) -> ManagedVariable {
let matrix = self.allocator.create_matrix(matrix);
self.add_matrix(*matrix);
matrix
}
/// Appends `variable` to this scope's list of matrices (no de-duplication).
pub fn add_matrix(&mut self, variable: Variable) {
self.matrices.push(variable);
}
pub fn create_pipeline(&mut self, num_stages: u8) -> ManagedVariable {
let pipeline = self.allocator.create_pipeline(num_stages);
self.add_pipeline(*pipeline);
pipeline
}
/// Builds a barrier-token variable for the given barrier `id` and `level`.
///
/// The variable has the semantic type `SemanticType::BarrierToken` and is returned as a
/// plain managed variable. Nothing is recorded on the scope itself.
pub fn create_barrier_token(&mut self, id: Id, level: BarrierLevel) -> ManagedVariable {
    let kind = VariableKind::BarrierToken { id, level };
    let ty = Type::semantic(SemanticType::BarrierToken);
    ManagedVariable::Plain(Variable::new(kind, ty))
}
/// Appends `variable` to this scope's list of pipelines (no de-duplication).
pub fn add_pipeline(&mut self, variable: Variable) {
self.pipelines.push(variable);
}
/// Converts `item` into a [`Type`] and asks the allocator for a mutable local of that type.
///
/// This method does not add the variable to `self.locals`; see `add_local_mut` for that.
pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ManagedVariable {
    let ty: Type = item.into();
    self.allocator.create_local_mut(ty)
}
/// Records `var` in `self.locals` unless an equal variable is already present.
pub fn add_local_mut(&mut self, var: Variable) {
    // Linear scan: `locals` is a plain `Vec`, so duplicates are filtered by comparison.
    if self.locals.contains(&var) {
        return;
    }
    self.locals.push(var);
}
/// Delegates to the allocator's `create_local_restricted` for the given type.
///
/// The resulting variable is not added to `self.locals` by this method.
pub fn create_local_restricted(&mut self, item: Type) -> ManagedVariable {
self.allocator.create_local_restricted(item)
}
/// Delegates to the allocator's `create_local` for the given type.
///
/// The resulting variable is not added to `self.locals` by this method.
pub fn create_local(&mut self, item: Type) -> ManagedVariable {
self.allocator.create_local(item)
}
/// Returns the most recently added entry of `self.locals`, or `None` when it is empty.
///
/// Despite the name, the result is the variable itself rather than a numeric index.
pub fn last_local_index(&self) -> Option<&Variable> {
self.locals.last()
}
/// Converts `instruction` into an [`Instruction`] and appends it to this scope.
///
/// Before it is stored, the instruction's `source_loc` is overwritten with the scope's
/// current debug source location and its `modes` with the scope's current modes.
///
/// # Panics
///
/// Panics if `self.modes` is mutably borrowed at the time of the call.
pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
    let mut registered: Instruction = instruction.into();
    registered.source_loc = self.debug.source_loc.clone();
    let current_modes = *self.modes.borrow();
    registered.modes = current_modes;
    self.instructions.push(registered);
}
/// Looks up the storage type registered for the Rust type `T`.
///
/// Returns `None` when nothing was registered for `T` in the type map.
pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
    let key = TypeId::of::<T>();
    self.typemap.borrow().get(&key).cloned()
}
/// Looks up the size registered for the Rust type `T`.
///
/// Returns `None` when nothing was registered for `T` in the size map.
pub fn resolve_size<T: 'static>(&self) -> Option<usize> {
    let key = TypeId::of::<T>();
    self.sizemap.borrow().get(&key).cloned()
}
/// Registers `elem` as the storage type for the Rust type `T`, replacing any previous
/// registration for `T`.
///
/// The type map is reference-counted, so the registration is visible to every scope
/// holding the same map (for example scopes created through `child`).
pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
    let key = TypeId::of::<T>();
    self.typemap.borrow_mut().insert(key, elem);
}
/// Registers `size` for the Rust type `T`, replacing any previous registration for `T`.
///
/// The size map is reference-counted, so the registration is visible to every scope
/// holding the same map (for example scopes created through `child`).
pub fn register_size<T: 'static>(&mut self, size: usize) {
    let key = TypeId::of::<T>();
    self.sizemap.borrow_mut().insert(key, size);
}
pub fn child(&mut self) -> Self {
Self {
validation_errors: self.validation_errors.clone(),
depth: self.depth + 1,
instructions: Vec::new(),
locals: Vec::new(),
matrices: Vec::new(),
pipelines: Vec::new(),
shared: Vec::new(),
const_arrays: Vec::new(),
local_arrays: Vec::new(),
index_offset_with_output_layout_position: Vec::new(),
allocator: self.allocator.clone(),
debug: self.debug.clone(),
typemap: self.typemap.clone(),
sizemap: self.sizemap.clone(),
runtime_properties: self.runtime_properties.clone(),
modes: self.modes.clone(),
properties: self.properties.clone(),
}
}
/// Appends a validation error message to the shared error list.
///
/// # Panics
///
/// Panics if the error list is already borrowed when this is called.
pub fn push_error(&mut self, msg: impl Into<String>) {
    let mut errors = self.validation_errors.errors.borrow_mut();
    errors.push(msg.into());
}
/// Removes and returns every validation error recorded so far, leaving the shared list
/// empty.
///
/// # Panics
///
/// Panics if the error list is already borrowed when this is called.
pub fn pop_errors(&mut self) -> Vec<String> {
    // `RefCell::take` swaps in `Vec::default()` and hands back the previous contents.
    self.validation_errors.errors.take()
}
/// Moves this scope's instructions and variables into a [`ScopeProcessing`] and runs the
/// given processors over it.
///
/// The variables handed to the processors are, in order: the scope's `locals`, its
/// `matrices`, and whatever `Allocator::take_variables` yields at that point. Afterwards
/// `locals`, `matrices` and `instructions` are empty on `self`; the other variable lists
/// (`pipelines`, `shared`, `const_arrays`, `local_arrays`) are not touched.
///
/// Each processor receives the current processing state plus a clone of the allocator and
/// returns the next state; processors run in iteration order. Once they have all run,
/// `take_variables` is called a second time so variables allocated during processing are
/// appended to the result.
pub fn process<'a>(
    &mut self,
    processors: impl IntoIterator<Item = &'a dyn Processor>,
) -> ScopeProcessing {
    // Move the buffers out instead of copying them element by element.
    let mut variables = core::mem::take(&mut self.locals);
    variables.append(&mut self.matrices);
    variables.extend(self.allocator.take_variables());

    let instructions = core::mem::take(&mut self.instructions);

    let mut processing = ScopeProcessing {
        variables,
        instructions,
        typemap: self.typemap.clone(),
    };

    for processor in processors {
        processing = processor.transform(processing, self.allocator.clone());
    }

    // Pick up variables the processors allocated while transforming.
    processing.variables.extend(self.allocator.take_variables());
    processing
}
/// Returns the value of the allocator's `new_local_index`.
///
/// Within this impl the result is used as the `id` of shared variables, shared arrays and
/// constant arrays.
pub fn new_local_index(&self) -> u32 {
self.allocator.new_local_index()
}
/// Creates a shared array variable and records it in this scope's `shared` list.
///
/// * `item` - converted into the array's [`Type`].
/// * `shared_memory_size` - stored as the array's `length`.
/// * `alignment` - stored unchanged on the variable kind.
///
/// The array's id comes from `new_local_index` and its `unroll_factor` is fixed to 1.
/// Returns the variable wrapped as a plain managed variable.
pub fn create_shared_array<I: Into<Type>>(
    &mut self,
    item: I,
    shared_memory_size: usize,
    alignment: Option<usize>,
) -> ManagedVariable {
    let ty: Type = item.into();
    let kind = VariableKind::SharedArray {
        id: self.new_local_index(),
        length: shared_memory_size,
        unroll_factor: 1,
        alignment,
    };
    let variable = Variable::new(kind, ty);
    self.shared.push(variable);
    ManagedVariable::Plain(variable)
}
/// Creates a shared (non-array) variable of the given type and records it in this scope's
/// `shared` list.
///
/// The variable's id comes from `new_local_index`. Returns it wrapped as a plain managed
/// variable.
pub fn create_shared<I: Into<Type>>(&mut self, item: I) -> ManagedVariable {
    let ty: Type = item.into();
    let kind = VariableKind::Shared {
        id: self.new_local_index(),
    };
    let variable = Variable::new(kind, ty);
    self.shared.push(variable);
    ManagedVariable::Plain(variable)
}
/// Creates a constant array variable holding `data` and records the pair in
/// `self.const_arrays`.
///
/// The array's `length` is `data.len()`, its id comes from `new_local_index` and its
/// `unroll_factor` is fixed to 1. Returns the array variable wrapped as a plain managed
/// variable.
pub fn create_const_array<I: Into<Type>>(
    &mut self,
    item: I,
    data: Vec<Variable>,
) -> ManagedVariable {
    let ty: Type = item.into();
    let kind = VariableKind::ConstantArray {
        id: self.new_local_index(),
        length: data.len(),
        unroll_factor: 1,
    };
    let variable = Variable::new(kind, ty);
    self.const_arrays.push((variable, data));
    ManagedVariable::Plain(variable)
}
/// Builds a variable referring to the global input array `id` with type `item`.
///
/// Nothing is recorded on the scope; the variable is returned as a plain managed variable.
pub fn input(&mut self, id: Id, item: Type) -> ManagedVariable {
    let kind = VariableKind::GlobalInputArray(id);
    let variable = crate::Variable::new(kind, item);
    ManagedVariable::Plain(variable)
}
/// Builds a variable referring to the global output array `id` with type `item`.
///
/// Nothing is recorded on the scope; the variable is returned as a plain managed variable.
pub fn output(&mut self, id: Id, item: Type) -> ManagedVariable {
    let kind = VariableKind::GlobalOutputArray(id);
    ManagedVariable::Plain(crate::Variable::new(kind, item))
}
/// Builds a variable referring to the global scalar `id`, typed from `storage` through
/// `Type::new`.
///
/// Nothing is recorded on the scope; the variable is returned as a plain managed variable.
pub fn scalar(&self, id: Id, storage: StorageType) -> ManagedVariable {
    let kind = VariableKind::GlobalScalar(id);
    let ty = Type::new(storage);
    ManagedVariable::Plain(crate::Variable::new(kind, ty))
}
pub fn create_local_array<I: Into<Type>>(
&mut self,
item: I,
array_size: usize,
) -> ManagedVariable {
let local_array = self.allocator.create_local_array(item.into(), array_size);
self.add_local_array(*local_array);
local_array
}
/// Appends `var` to this scope's list of local arrays (no de-duplication).
pub fn add_local_array(&mut self, var: Variable) {
self.local_arrays.push(var);
}
/// Makes `source` the current debug source. Does nothing when debugging is disabled.
///
/// The source is added to the shared set of known sources, and the current source
/// location is reset to the line and column carried by `source`. The first location ever
/// recorded is also kept as the entry location.
pub fn update_source(&mut self, source: CubeFnSource) {
    if !self.debug.enabled {
        return;
    }

    self.debug.sources.borrow_mut().insert(source.clone());

    let location = SourceLoc {
        line: source.line,
        column: source.column,
        source,
    };
    // Only the very first location becomes the entry location.
    if self.debug.entry_loc.is_none() {
        self.debug.entry_loc = Some(location.clone());
    }
    self.debug.source_loc = Some(location);
}
/// Overwrites the line and column of the current source location.
///
/// Does nothing when no source location has been recorded yet.
pub fn update_span(&mut self, line: u32, col: u32) {
    if let Some(current) = &mut self.debug.source_loc {
        current.line = line;
        current.column = col;
    }
}
/// Records `name` as the debug name of `variable`, replacing any earlier name.
///
/// Does nothing when debugging is disabled. The name map is shared, so it can be updated
/// through a shared reference to the scope.
pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
    if !self.debug.enabled {
        return;
    }
    let mut names = self.debug.variable_names.borrow_mut();
    names.insert(variable, name.into());
}
}
/// Renders the scope as a brace-delimited block with one instruction per line.
///
/// Instruction lines are indented by `depth + 1` units and the closing brace by `depth`
/// units. Instructions whose textual form is empty are skipped. No newline is written
/// after the closing brace.
impl Display for Scope {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        const INDENT_UNIT: &str = " ";

        // Build the indentation once instead of once per instruction; the closing
        // brace's indent is the same string minus one unit.
        let depth = usize::from(self.depth);
        let body_indent = INDENT_UNIT.repeat(depth + 1);
        let closing_indent = &body_indent[..body_indent.len() - INDENT_UNIT.len()];

        writeln!(f, "{{")?;
        for instruction in &self.instructions {
            let rendered = instruction.to_string();
            if !rendered.is_empty() {
                writeln!(f, "{}{}", body_indent, rendered)?;
            }
        }
        write!(f, "{}}}", closing_indent)
    }
}