cubecl_core/ir/
allocator.rs1use std::{
2 cell::RefCell,
3 collections::HashMap,
4 rc::Rc,
5 sync::atomic::{AtomicU32, Ordering},
6};
7
8use crate::prelude::ExpandElement;
9
10use super::{Item, Matrix, Variable, VariableKind};
11
12#[derive(Clone, Debug, Default)]
31pub struct Allocator {
32 local_mut_pool: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
33 next_id: Rc<AtomicU32>,
34}
35
36impl PartialEq for Allocator {
37 fn eq(&self, other: &Self) -> bool {
38 Rc::ptr_eq(&self.local_mut_pool, &other.local_mut_pool)
39 && Rc::ptr_eq(&self.next_id, &other.next_id)
40 }
41}
42
43impl Allocator {
44 pub fn create_local(&self, item: Item) -> ExpandElement {
46 let id = self.new_local_index();
47 let local = VariableKind::LocalConst { id };
48 ExpandElement::Plain(Variable::new(local, item))
49 }
50
51 pub fn create_local_mut(&self, item: Item) -> ExpandElement {
55 if item.elem.is_atomic() {
56 self.create_local_restricted(item)
57 } else {
58 self.reuse_local_mut(item)
59 .unwrap_or_else(|| ExpandElement::Managed(self.add_local_mut(item)))
60 }
61 }
62
63 pub fn create_local_restricted(&self, item: Item) -> ExpandElement {
65 let id = self.new_local_index();
66 let local = VariableKind::LocalMut { id };
67 ExpandElement::Plain(Variable::new(local, item))
68 }
69
70 pub fn create_local_array(&self, item: Item, array_size: u32) -> ExpandElement {
71 let id = self.new_local_index();
72 let local_array = Variable::new(
73 VariableKind::LocalArray {
74 id,
75 length: array_size,
76 },
77 item,
78 );
79 ExpandElement::Plain(local_array)
80 }
81
82 pub fn create_slice(&self, item: Item) -> ExpandElement {
84 let id = self.new_local_index();
85 let variable = Variable::new(VariableKind::Slice { id }, item);
86 ExpandElement::Plain(variable)
87 }
88
89 pub fn create_matrix(&self, matrix: Matrix) -> ExpandElement {
91 let id = self.new_local_index();
92 let variable = Variable::new(
93 VariableKind::Matrix { id, mat: matrix },
94 Item::new(matrix.elem),
95 );
96 ExpandElement::Plain(variable)
97 }
98
99 fn reuse_local_mut(&self, item: Item) -> Option<ExpandElement> {
101 self.local_mut_pool.borrow().get(&item).and_then(|vars| {
104 vars.iter()
105 .rev()
106 .find(|var| matches!(var, ExpandElement::Managed(v) if Rc::strong_count(v) == 1))
107 .cloned()
108 })
109 }
110
111 pub fn add_local_mut(&self, item: Item) -> Rc<Variable> {
113 let id = self.new_local_index();
114 let local = Variable::new(VariableKind::LocalMut { id }, item);
115 let var = Rc::new(local);
116 let expand = ExpandElement::Managed(var.clone());
117 let mut pool = self.local_mut_pool.borrow_mut();
118 let variables = pool.entry(item).or_default();
119 variables.push(expand);
120 var
121 }
122
123 pub fn new_local_index(&self) -> u32 {
124 self.next_id.fetch_add(1, Ordering::Release)
125 }
126}