cubecl_core/ir/
allocator.rs

1use std::{
2    cell::RefCell,
3    collections::HashMap,
4    rc::Rc,
5    sync::atomic::{AtomicU32, Ordering},
6};
7
8use crate::prelude::ExpandElement;
9
10use super::{Item, Matrix, Variable, VariableKind};
11
12/// An allocator for local variables of a kernel.
13///
14/// A local variable is unique to a unit. That is, each unit have their own copy of a local variable.
15/// There are three types of local variables based on their capabilities.
16///     - An immutable local variable is obtained by calling [Allocator::create_local].
17///     - A mutable local variable is obtained by calling [Allocator::create_local_mut]. The allocator will reuse
18///       previously defined mutable variables if possible.
19///     - A restricted mutable local variable is obtained by calling [Allocator::create_local_restricted]. This a is
20///       mutable variable that cannot be reused. This is mostly used for loop indices.
21///
22/// # Performance tips
23///
24/// In order, prefer immutable local variables, then mutable, then restricted.
25///
26/// To enable many compiler optimizations, it is prefered to use the [static single-assignment] strategy for immutable variables.
27/// That is, each variable must be declared and used exactly once.
28///
29/// [static single-assignment](https://en.wikipedia.org/wiki/Static_single-assignment_form)
30#[derive(Clone, Debug, Default)]
31pub struct Allocator {
32    local_mut_pool: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
33    next_id: Rc<AtomicU32>,
34}
35
36impl PartialEq for Allocator {
37    fn eq(&self, other: &Self) -> bool {
38        Rc::ptr_eq(&self.local_mut_pool, &other.local_mut_pool)
39            && Rc::ptr_eq(&self.next_id, &other.next_id)
40    }
41}
42
43impl Allocator {
44    /// Create a new immutable local variable of type specified by `item`.
45    pub fn create_local(&self, item: Item) -> ExpandElement {
46        let id = self.new_local_index();
47        let local = VariableKind::LocalConst { id };
48        ExpandElement::Plain(Variable::new(local, item))
49    }
50
51    /// Create a new mutable local variable of type specified by `item`.
52    /// Try to reuse a previously defined but unused mutable variable if possible.
53    /// Else, this define a new variable.
54    pub fn create_local_mut(&self, item: Item) -> ExpandElement {
55        if item.elem.is_atomic() {
56            self.create_local_restricted(item)
57        } else {
58            self.reuse_local_mut(item)
59                .unwrap_or_else(|| ExpandElement::Managed(self.add_local_mut(item)))
60        }
61    }
62
63    /// Create a new mutable restricted local variable of type specified by `item`.
64    pub fn create_local_restricted(&self, item: Item) -> ExpandElement {
65        let id = self.new_local_index();
66        let local = VariableKind::LocalMut { id };
67        ExpandElement::Plain(Variable::new(local, item))
68    }
69
70    pub fn create_local_array(&self, item: Item, array_size: u32) -> ExpandElement {
71        let id = self.new_local_index();
72        let local_array = Variable::new(
73            VariableKind::LocalArray {
74                id,
75                length: array_size,
76            },
77            item,
78        );
79        ExpandElement::Plain(local_array)
80    }
81
82    /// Create a slice variable
83    pub fn create_slice(&self, item: Item) -> ExpandElement {
84        let id = self.new_local_index();
85        let variable = Variable::new(VariableKind::Slice { id }, item);
86        ExpandElement::Plain(variable)
87    }
88
89    /// Create a matrix variable
90    pub fn create_matrix(&self, matrix: Matrix) -> ExpandElement {
91        let id = self.new_local_index();
92        let variable = Variable::new(
93            VariableKind::Matrix { id, mat: matrix },
94            Item::new(matrix.elem),
95        );
96        ExpandElement::Plain(variable)
97    }
98
99    // Try to return a reusable mutable variable for the given `item` or `None` otherwise.
100    fn reuse_local_mut(&self, item: Item) -> Option<ExpandElement> {
101        // Among the candidates, take a variable if it's only referenced by the pool.
102        // Arbitrarily takes the first it finds in reversed order.
103        self.local_mut_pool.borrow().get(&item).and_then(|vars| {
104            vars.iter()
105                .rev()
106                .find(|var| matches!(var, ExpandElement::Managed(v) if Rc::strong_count(v) == 1))
107                .cloned()
108        })
109    }
110
111    /// Add a new variable to the pool with type specified by `item` for the given `scope`.
112    pub fn add_local_mut(&self, item: Item) -> Rc<Variable> {
113        let id = self.new_local_index();
114        let local = Variable::new(VariableKind::LocalMut { id }, item);
115        let var = Rc::new(local);
116        let expand = ExpandElement::Managed(var.clone());
117        let mut pool = self.local_mut_pool.borrow_mut();
118        let variables = pool.entry(item).or_default();
119        variables.push(expand);
120        var
121    }
122
123    pub fn new_local_index(&self) -> u32 {
124        self.next_id.fetch_add(1, Ordering::Release)
125    }
126}