// cubecl-ir 0.10.0-pre.4
//
// Intermediate representation for CubeCL
// Documentation
use alloc::{rc::Rc, vec::Vec};
use core::cell::RefCell;

use hashbrown::HashMap;
use portable_atomic::{AtomicU32, Ordering};

use crate::SemanticType;

use super::{Matrix, Type, Variable, VariableKind};

/// An allocator for local variables of a kernel.
///
/// A local variable is unique to a unit. That is, each unit has its own copy of a local variable.
/// There are three types of local variables based on their capabilities:
///
/// - An immutable local variable is obtained by calling [`Allocator::create_local`].
/// - A mutable local variable is obtained by calling [`Allocator::create_local_mut`]. The allocator will reuse
///   previously defined mutable variables if possible.
/// - A restricted mutable local variable is obtained by calling [`Allocator::create_local_restricted`]. This is a
///   mutable variable that cannot be reused. This is mostly used for loop indices.
///
/// Cloning an `Allocator` clones the `Rc` handles only, so every clone shares the same variable
/// pool and the same id counter as the original.
///
/// # Performance tips
///
/// In order, prefer immutable local variables, then mutable, then restricted.
///
/// To enable many compiler optimizations, it is preferred to use the [static single-assignment] strategy for
/// immutable variables. That is, each variable is assigned exactly once and never written to again.
///
/// [static single-assignment]: https://en.wikipedia.org/wiki/Static_single-assignment_form
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Debug, Default, TypeHash)]
pub struct Allocator {
    // Pool of reusable mutable variables, grouped by their type. Left out of the serde
    // representation (`serde(skip)`).
    #[cfg_attr(feature = "serde", serde(skip))]
    local_mut_pool: Rc<RefCell<HashMap<Type, Vec<ManagedVariable>>>>,
    // Counter handing out the id of the next variable created through this allocator.
    next_id: Rc<AtomicU32>,
}

/// Identity-based equality: two allocators compare equal only when they share the very same
/// pool and id counter allocations (e.g. one is a clone of the other). The contents of the
/// pool and the current counter value are not inspected.
impl PartialEq for Allocator {
    fn eq(&self, other: &Self) -> bool {
        let same_pool = Rc::ptr_eq(&self.local_mut_pool, &other.local_mut_pool);
        let same_counter = Rc::ptr_eq(&self.next_id, &other.next_id);
        same_pool && same_counter
    }
}

impl Eq for Allocator {}

impl Allocator {
    /// Create a new immutable local variable whose type is `item`.
    ///
    /// A fresh id is drawn from the allocator for every call; the variable is returned as
    /// [`ManagedVariable::Plain`] and is never stored in the reuse pool.
    pub fn create_local(&self, item: Type) -> ManagedVariable {
        let kind = VariableKind::LocalConst {
            id: self.new_local_index(),
        };
        ManagedVariable::Plain(Variable::new(kind, item))
    }

    /// Create a new mutable local variable whose type is `item`.
    ///
    /// Atomic types always get a fresh restricted variable (see
    /// [`Allocator::create_local_restricted`]). For any other type, a pooled variable that is
    /// no longer in use is handed out again when one exists; otherwise a new variable is
    /// defined and added to the pool.
    pub fn create_local_mut(&self, item: Type) -> ManagedVariable {
        // Atomics bypass the reuse pool entirely.
        if item.is_atomic() {
            return self.create_local_restricted(item);
        }

        match self.reuse_local_mut(item) {
            Some(reused) => reused,
            None => ManagedVariable::Managed(self.add_local_mut(item)),
        }
    }

    /// Create a new restricted mutable local variable whose type is `item`.
    ///
    /// The variable gets a fresh id and is returned as [`ManagedVariable::Plain`]: it is not
    /// registered in the pool, so it is never handed out again by
    /// [`Allocator::create_local_mut`].
    pub fn create_local_restricted(&self, item: Type) -> ManagedVariable {
        let kind = VariableKind::LocalMut {
            id: self.new_local_index(),
        };
        ManagedVariable::Plain(Variable::new(kind, item))
    }

    /// Create a new local array variable of type `item` with a length of `array_size`.
    ///
    /// The array gets a fresh id and starts with an unroll factor of 1. It is returned as
    /// [`ManagedVariable::Plain`] and is not stored in the reuse pool.
    pub fn create_local_array(&self, item: Type, array_size: usize) -> ManagedVariable {
        let kind = VariableKind::LocalArray {
            id: self.new_local_index(),
            length: array_size,
            unroll_factor: 1,
        };
        ManagedVariable::Plain(Variable::new(kind, item))
    }

    /// Create a new matrix variable described by `matrix`.
    ///
    /// The variable's type is built from the matrix's `storage`. It is returned as
    /// [`ManagedVariable::Plain`] and is not stored in the reuse pool.
    pub fn create_matrix(&self, matrix: Matrix) -> ManagedVariable {
        let kind = VariableKind::Matrix {
            id: self.new_local_index(),
            mat: matrix,
        };
        let ty = Type::new(matrix.storage);
        ManagedVariable::Plain(Variable::new(kind, ty))
    }

    pub fn create_pipeline(&self, num_stages: u8) -> ManagedVariable {
        let id = self.new_local_index();
        let variable = Variable::new(
            VariableKind::Pipeline { id, num_stages },
            SemanticType::Pipeline.into(),
        );
        ManagedVariable::Plain(variable)
    }

    /// Try to return a reusable mutable variable of type `item`, or `None` if there is none.
    ///
    /// A pooled variable is reusable when the pool holds the only reference to it
    /// (`Rc::strong_count == 1`). Candidates are scanned from the most recently added to the
    /// oldest and the first match wins. The returned handle is a clone of the pooled `Rc`, so
    /// the variable stays in the pool and is considered in use until that handle is dropped.
    pub fn reuse_local_mut(&self, item: Type) -> Option<ManagedVariable> {
        let pool = self.local_mut_pool.borrow();
        let candidates = pool.get(&item)?;

        candidates
            .iter()
            .rev()
            .find(|candidate| match candidate {
                ManagedVariable::Managed(shared) => Rc::strong_count(shared) == 1,
                ManagedVariable::Plain(_) => false,
            })
            .cloned()
    }

    /// Define a new mutable local variable of type `item` and register it in the pool.
    ///
    /// The pool keeps one reference to the variable and the caller receives a second one, so
    /// the variable only becomes eligible for [`Allocator::reuse_local_mut`] once every handle
    /// outside the pool has been dropped.
    pub fn add_local_mut(&self, item: Type) -> Rc<Variable> {
        let kind = VariableKind::LocalMut {
            id: self.new_local_index(),
        };
        let shared = Rc::new(Variable::new(kind, item));

        self.local_mut_pool
            .borrow_mut()
            .entry(item)
            .or_default()
            .push(ManagedVariable::Managed(Rc::clone(&shared)));

        shared
    }

    pub fn new_local_index(&self) -> u32 {
        self.next_id.fetch_add(1, Ordering::Release)
    }

    /// Empty the pool of mutable variables and return a copy of every variable it held.
    ///
    /// Variables of all types are returned in a single list; their order follows the hash
    /// map's iteration order and is therefore unspecified. Handles previously given out
    /// remain valid, but the drained variables are no longer candidates for reuse.
    pub fn take_variables(&self) -> Vec<Variable> {
        let mut pool = self.local_mut_pool.borrow_mut();
        pool.drain()
            .flat_map(|(_, variables)| variables)
            .map(|managed| *managed)
            .collect()
    }
}

use cubecl_macros_internal::TypeHash;
pub use expand_element::*;

mod expand_element {
    use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
    use half::{bf16, f16};

    use super::*;

    /// Reference to a JIT variable.
    ///
    /// Dereferences to the underlying [`Variable`] regardless of the variant.
    #[derive(Clone, Debug, TypeHash)]
    pub enum ManagedVariable {
        /// Variable kept in the variable pool.
        ///
        /// The `Rc` is shared with the pool (see `Allocator::add_local_mut`), and its strong
        /// count is what the allocator inspects to decide whether the variable can be reused.
        Managed(Rc<Variable>),
        /// Variable not kept in the variable pool; the value is owned directly.
        Plain(Variable),
    }

    impl core::ops::Deref for ManagedVariable {
        type Target = Variable;

        fn deref(&self) -> &Self::Target {
            match self {
                ManagedVariable::Managed(var) => var.as_ref(),
                ManagedVariable::Plain(var) => var,
            }
        }
    }

    /// Extracts a copy of the wrapped [`Variable`], dropping the handle (and therefore its
    /// reference on the pooled variable, if any).
    impl From<ManagedVariable> for Variable {
        fn from(value: ManagedVariable) -> Self {
            let variable: &Variable = match &value {
                ManagedVariable::Managed(shared) => shared,
                ManagedVariable::Plain(owned) => owned,
            };
            *variable
        }
    }

    impl ManagedVariable {
        /// Whether the element can be mutated in place, potentially reusing the register.
        ///
        /// Only pooled (`Managed`) variables of kind `LocalMut` qualify, and only while at
        /// most two references to them exist — presumably the pool's own reference plus this
        /// handle (NOTE(review): confirm against callers).
        pub fn can_mut(&self) -> bool {
            match self {
                Self::Managed(shared) => {
                    let is_local_mut = matches!(shared.kind, VariableKind::LocalMut { .. });
                    is_local_mut && Rc::strong_count(shared) <= 2
                }
                Self::Plain(_) => false,
            }
        }

        /// Explicitly consume the element and return a copy of the wrapped [`Variable`].
        ///
        /// Dropping the handle releases its reference on a pooled variable, which frees the
        /// variable for reuse if no other copies exist.
        pub fn consume(self) -> Variable {
            let variable: Variable = *self;
            variable
        }
    }

    macro_rules! impl_into_expand_element {
        ($type:ty) => {
            impl From<$type> for ManagedVariable {
                fn from(value: $type) -> Self {
                    ManagedVariable::Plain(Variable::from(value))
                }
            }
        };
    }

    impl_into_expand_element!(u8);
    impl_into_expand_element!(u16);
    impl_into_expand_element!(u32);
    impl_into_expand_element!(u64);
    impl_into_expand_element!(usize);
    impl_into_expand_element!(isize);
    impl_into_expand_element!(bool);
    impl_into_expand_element!(e2m1);
    impl_into_expand_element!(e2m1x2);
    impl_into_expand_element!(e2m3);
    impl_into_expand_element!(e3m2);
    impl_into_expand_element!(e4m3);
    impl_into_expand_element!(e5m2);
    impl_into_expand_element!(ue8m0);
    impl_into_expand_element!(flex32);
    impl_into_expand_element!(f16);
    impl_into_expand_element!(bf16);
    impl_into_expand_element!(tf32);
    impl_into_expand_element!(f32);
    impl_into_expand_element!(i8);
    impl_into_expand_element!(i16);
    impl_into_expand_element!(i32);
    impl_into_expand_element!(i64);
}