cubecl_core/frontend/container/
shared_memory.rs

1use std::{marker::PhantomData, num::NonZero};
2
3use crate::{
4    frontend::{CubePrimitive, CubeType, ExpandElementTyped, Init, indexation::Index},
5    ir::{Item, Scope},
6    prelude::{Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign},
7};
8
9#[derive(Clone, Copy)]
10pub struct SharedMemory<T: CubeType> {
11    _val: PhantomData<T>,
12}
13
14impl<T: CubePrimitive> Init for ExpandElementTyped<SharedMemory<T>> {
15    fn init(self, _scope: &mut Scope) -> Self {
16        self
17    }
18}
19
20impl<T: CubePrimitive> CubeType for SharedMemory<T> {
21    type ExpandType = ExpandElementTyped<SharedMemory<T>>;
22}
23
24impl<T: CubePrimitive + Clone> SharedMemory<T> {
25    pub fn new<S: Index>(_size: S) -> Self {
26        SharedMemory { _val: PhantomData }
27    }
28
29    pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
30        SharedMemory { _val: PhantomData }
31    }
32
33    pub fn new_aligned<S: Index>(
34        _size: S,
35        _vectorization_factor: u32,
36        _alignment: u32,
37    ) -> SharedMemory<Line<T>> {
38        SharedMemory { _val: PhantomData }
39    }
40
41    pub fn __expand_new_lined(
42        scope: &mut Scope,
43        size: ExpandElementTyped<u32>,
44        vectorization_factor: u32,
45    ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
46        let size = size
47            .constant()
48            .expect("Shared memory need constant initialization value")
49            .as_u32();
50        let var = scope.create_shared(
51            Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
52            size,
53            None,
54        );
55        ExpandElementTyped::new(var)
56    }
57
58    pub fn __expand_new_aligned(
59        scope: &mut Scope,
60        size: ExpandElementTyped<u32>,
61        vectorization_factor: u32,
62        alignment: u32,
63    ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
64        let size = size
65            .constant()
66            .expect("Shared memory need constant initialization value")
67            .as_u32();
68        let var = scope.create_shared(
69            Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
70            size,
71            Some(alignment),
72        );
73        ExpandElementTyped::new(var)
74    }
75
76    pub fn vectorized<S: Index>(_size: S, _vectorization_factor: u32) -> Self {
77        SharedMemory { _val: PhantomData }
78    }
79
80    pub fn __expand_vectorized(
81        scope: &mut Scope,
82        size: ExpandElementTyped<u32>,
83        vectorization_factor: u32,
84    ) -> <Self as CubeType>::ExpandType {
85        let size = size
86            .constant()
87            .expect("Shared memory need constant initialization value")
88            .as_u32();
89        let var = scope.create_shared(
90            Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
91            size,
92            None,
93        );
94        ExpandElementTyped::new(var)
95    }
96
97    pub fn __expand_new(
98        scope: &mut Scope,
99        size: ExpandElementTyped<u32>,
100    ) -> <Self as CubeType>::ExpandType {
101        let size = size
102            .constant()
103            .expect("Shared memory need constant initialization value")
104            .as_u32();
105        let var = scope.create_shared(Item::new(T::as_elem(scope)), size, None);
106        ExpandElementTyped::new(var)
107    }
108}
109
110/// Module that contains the implementation details of the index functions.
111mod indexation {
112    use cubecl_ir::Operator;
113
114    use crate::{
115        ir::{BinaryOperator, Instruction},
116        prelude::{CubeIndex, CubeIndexMut},
117        unexpanded,
118    };
119
120    use super::*;
121
122    impl<E: CubePrimitive> SharedMemory<E> {
123        /// Perform an unchecked index into the array
124        ///
125        /// # Safety
126        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
127        /// always in bounds
128        pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
129        where
130            Self: CubeIndex<I>,
131        {
132            unexpanded!()
133        }
134
135        /// Perform an unchecked index assignment into the array
136        ///
137        /// # Safety
138        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
139        /// always in bounds
140        pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
141        where
142            Self: CubeIndexMut<I>,
143        {
144            unexpanded!()
145        }
146    }
147
148    impl<E: CubePrimitive> ExpandElementTyped<SharedMemory<E>> {
149        pub fn __expand_index_unchecked_method(
150            self,
151            scope: &mut Scope,
152            i: ExpandElementTyped<u32>,
153        ) -> ExpandElementTyped<E> {
154            let out = scope.create_local(self.expand.item);
155            scope.register(Instruction::new(
156                Operator::UncheckedIndex(BinaryOperator {
157                    lhs: *self.expand,
158                    rhs: i.expand.consume(),
159                }),
160                *out,
161            ));
162            out.into()
163        }
164
165        pub fn __expand_index_assign_unchecked_method(
166            self,
167            scope: &mut Scope,
168            i: ExpandElementTyped<u32>,
169            value: ExpandElementTyped<E>,
170        ) {
171            scope.register(Instruction::new(
172                Operator::UncheckedIndexAssign(BinaryOperator {
173                    lhs: i.expand.consume(),
174                    rhs: value.expand.consume(),
175                }),
176                *self.expand,
177            ));
178        }
179    }
180}
181
182impl<T: CubePrimitive> List<T> for SharedMemory<T> {
183    fn __expand_read(
184        scope: &mut Scope,
185        this: ExpandElementTyped<SharedMemory<T>>,
186        idx: ExpandElementTyped<u32>,
187    ) -> ExpandElementTyped<T> {
188        index::expand(scope, this, idx)
189    }
190}
191
192impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
193    fn __expand_read_method(
194        self,
195        scope: &mut Scope,
196        idx: ExpandElementTyped<u32>,
197    ) -> ExpandElementTyped<T> {
198        index::expand(scope, self, idx)
199    }
200}
201
202impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
203    fn __expand_write(
204        scope: &mut Scope,
205        this: ExpandElementTyped<SharedMemory<T>>,
206        idx: ExpandElementTyped<u32>,
207        value: ExpandElementTyped<T>,
208    ) {
209        index_assign::expand(scope, this, idx, value);
210    }
211}
212
213impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
214    fn __expand_write_method(
215        self,
216        scope: &mut Scope,
217        idx: ExpandElementTyped<u32>,
218        value: ExpandElementTyped<T>,
219    ) {
220        index_assign::expand(scope, self, idx, value);
221    }
222}