use alloc::{rc::Rc, vec::Vec};
use core::cell::RefCell;
use hashbrown::HashMap;
use portable_atomic::{AtomicU32, Ordering};
use crate::SemanticType;
use super::{Matrix, Type, Variable, VariableKind};
/// Hands out ids for local variables and keeps a pool of mutable locals so they can be reused.
///
/// Both fields are behind `Rc`, so clones of an allocator share the same pool and the same
/// id counter.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Debug, Default, TypeHash)]
pub struct Allocator {
    /// Mutable locals created through the pool, grouped by their type.
    /// Skipped by serde, so a deserialized allocator starts with an empty (default) pool.
    #[cfg_attr(feature = "serde", serde(skip))]
    local_mut_pool: Rc<RefCell<HashMap<Type, Vec<ManagedVariable>>>>,
    /// Counter from which every new local id is taken.
    next_id: Rc<AtomicU32>,
}
impl PartialEq for Allocator {
    /// Identity comparison: two allocators are equal only when they share both the same
    /// mutable-local pool and the same id counter (for example, one was cloned from the other).
    fn eq(&self, other: &Self) -> bool {
        let shares_pool = Rc::ptr_eq(&self.local_mut_pool, &other.local_mut_pool);
        shares_pool && Rc::ptr_eq(&self.next_id, &other.next_id)
    }
}

// Pointer identity is reflexive, symmetric and transitive, so full equivalence holds.
impl Eq for Allocator {}
impl Allocator {
pub fn create_local(&self, item: Type) -> ManagedVariable {
let id = self.new_local_index();
let local = VariableKind::LocalConst { id };
ManagedVariable::Plain(Variable::new(local, item))
}
pub fn create_local_mut(&self, item: Type) -> ManagedVariable {
if item.is_atomic() {
self.create_local_restricted(item)
} else {
self.reuse_local_mut(item)
.unwrap_or_else(|| ManagedVariable::Managed(self.add_local_mut(item)))
}
}
pub fn create_local_restricted(&self, item: Type) -> ManagedVariable {
let id = self.new_local_index();
let local = VariableKind::LocalMut { id };
ManagedVariable::Plain(Variable::new(local, item))
}
pub fn create_local_array(&self, item: Type, array_size: usize) -> ManagedVariable {
let id = self.new_local_index();
let local_array = Variable::new(
VariableKind::LocalArray {
id,
length: array_size,
unroll_factor: 1,
},
item,
);
ManagedVariable::Plain(local_array)
}
pub fn create_matrix(&self, matrix: Matrix) -> ManagedVariable {
let id = self.new_local_index();
let variable = Variable::new(
VariableKind::Matrix { id, mat: matrix },
Type::new(matrix.storage),
);
ManagedVariable::Plain(variable)
}
pub fn create_pipeline(&self, num_stages: u8) -> ManagedVariable {
let id = self.new_local_index();
let variable = Variable::new(
VariableKind::Pipeline { id, num_stages },
SemanticType::Pipeline.into(),
);
ManagedVariable::Plain(variable)
}
pub fn reuse_local_mut(&self, item: Type) -> Option<ManagedVariable> {
self.local_mut_pool.borrow().get(&item).and_then(|vars| {
vars.iter()
.rev()
.find(|var| matches!(var, ManagedVariable::Managed(v) if Rc::strong_count(v) == 1))
.cloned()
})
}
pub fn add_local_mut(&self, item: Type) -> Rc<Variable> {
let id = self.new_local_index();
let local = Variable::new(VariableKind::LocalMut { id }, item);
let var = Rc::new(local);
let expand = ManagedVariable::Managed(var.clone());
let mut pool = self.local_mut_pool.borrow_mut();
let variables = pool.entry(item).or_default();
variables.push(expand);
var
}
pub fn new_local_index(&self) -> u32 {
self.next_id.fetch_add(1, Ordering::Release)
}
pub fn take_variables(&self) -> Vec<Variable> {
self.local_mut_pool
.borrow_mut()
.drain()
.flat_map(|it| it.1)
.map(|it| *it)
.collect()
}
}
use cubecl_macros_internal::TypeHash;
pub use expand_element::*;
mod expand_element {
use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
use half::{bf16, f16};
use super::*;
#[derive(Clone, Debug, TypeHash)]
pub enum ManagedVariable {
Managed(Rc<Variable>),
Plain(Variable),
}
impl core::ops::Deref for ManagedVariable {
type Target = Variable;
fn deref(&self) -> &Self::Target {
match self {
ManagedVariable::Managed(var) => var.as_ref(),
ManagedVariable::Plain(var) => var,
}
}
}
impl From<ManagedVariable> for Variable {
fn from(value: ManagedVariable) -> Self {
match value {
ManagedVariable::Managed(var) => *var,
ManagedVariable::Plain(var) => var,
}
}
}
impl ManagedVariable {
pub fn can_mut(&self) -> bool {
match self {
ManagedVariable::Managed(var) => {
if let VariableKind::LocalMut { .. } = var.as_ref().kind {
Rc::strong_count(var) <= 2
} else {
false
}
}
ManagedVariable::Plain(_) => false,
}
}
pub fn consume(self) -> Variable {
*self
}
}
macro_rules! impl_into_expand_element {
($type:ty) => {
impl From<$type> for ManagedVariable {
fn from(value: $type) -> Self {
ManagedVariable::Plain(Variable::from(value))
}
}
};
}
impl_into_expand_element!(u8);
impl_into_expand_element!(u16);
impl_into_expand_element!(u32);
impl_into_expand_element!(u64);
impl_into_expand_element!(usize);
impl_into_expand_element!(isize);
impl_into_expand_element!(bool);
impl_into_expand_element!(e2m1);
impl_into_expand_element!(e2m1x2);
impl_into_expand_element!(e2m3);
impl_into_expand_element!(e3m2);
impl_into_expand_element!(e4m3);
impl_into_expand_element!(e5m2);
impl_into_expand_element!(ue8m0);
impl_into_expand_element!(flex32);
impl_into_expand_element!(f16);
impl_into_expand_element!(bf16);
impl_into_expand_element!(tf32);
impl_into_expand_element!(f32);
impl_into_expand_element!(i8);
impl_into_expand_element!(i16);
impl_into_expand_element!(i32);
impl_into_expand_element!(i64);
}