Skip to main content

cubecl_core/frontend/container/shared_memory.rs

1use core::marker::PhantomData;
2use core::ops::{Deref, DerefMut};
3
4use crate::{
5    self as cubecl,
6    prelude::{Vectorized, VectorizedExpand},
7    unexpanded,
8};
9use cubecl_ir::{Marker, VariableKind, VectorSize};
10use cubecl_macros::{cube, intrinsic};
11
12use crate::{
13    frontend::{CubePrimitive, CubeType, IntoMut, NativeExpand},
14    ir::Scope,
15    prelude::*,
16};
17
18pub type SharedMemoryExpand<T> = NativeExpand<SharedMemory<T>>;
19pub type SharedExpand<T> = NativeExpand<Shared<T>>;
20
21#[derive(Clone, Copy)]
22pub struct Shared<E: CubePrimitive> {
23    _val: PhantomData<E>,
24}
25
26#[derive(Clone, Copy)]
27pub struct SharedMemory<E: CubePrimitive> {
28    _val: PhantomData<E>,
29}
30
31impl<T: CubePrimitive> IntoMut for NativeExpand<SharedMemory<T>> {
32    fn into_mut(self, _scope: &mut Scope) -> Self {
33        self
34    }
35}
36
37impl<T: CubePrimitive> CubeType for SharedMemory<T> {
38    type ExpandType = NativeExpand<SharedMemory<T>>;
39}
40
41impl<T: CubePrimitive> IntoMut for NativeExpand<Shared<T>> {
42    fn into_mut(self, _scope: &mut Scope) -> Self {
43        self
44    }
45}
46
47impl<T: CubePrimitive> CubeType for Shared<T> {
48    type ExpandType = NativeExpand<Shared<T>>;
49}
50
51#[cube]
52impl<T: CubePrimitive + Clone> SharedMemory<T> {
53    #[allow(unused_variables)]
54    pub fn new(#[comptime] size: usize) -> Self {
55        intrinsic!(|scope| {
56            scope
57                .create_shared_array(T::as_type(scope), size, None)
58                .into()
59        })
60    }
61
62    #[allow(clippy::len_without_is_empty)]
63    pub fn len(&self) -> usize {
64        intrinsic!(|_| len_static(&self))
65    }
66
67    pub fn buffer_len(&self) -> usize {
68        self.len()
69    }
70}
71
72#[cube]
73impl<T: CubePrimitive> Shared<T> {
74    pub fn new() -> Self {
75        intrinsic!(|scope| {
76            let var = scope.create_shared(T::as_type(scope));
77            NativeExpand::new(var)
78        })
79    }
80}
81
82pub trait AsRefExpand<T: CubeType> {
83    /// Converts this type into a shared reference of the (usually inferred) input type.
84    fn __expand_as_ref_method(self, scope: &mut Scope) -> T::ExpandType;
85}
86impl<T: CubePrimitive> AsRefExpand<T> for NativeExpand<T> {
87    fn __expand_as_ref_method(self, _scope: &mut Scope) -> NativeExpand<T> {
88        self
89    }
90}
91pub trait AsMutExpand<T: CubeType> {
92    /// Converts this type into a shared reference of the (usually inferred) input type.
93    fn __expand_as_mut_method(self, scope: &mut Scope) -> T::ExpandType;
94}
95impl<T: CubePrimitive> AsMutExpand<T> for NativeExpand<T> {
96    fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
97        self
98    }
99}
100
101/// Type inference won't allow things like assign to work normally, so we need to manually call
102/// `as_ref` or `as_mut` for those. Things like barrier ops should take `AsRef` so the conversion
103/// is automatic.
104impl<T: CubePrimitive> AsRef<T> for Shared<T> {
105    fn as_ref(&self) -> &T {
106        unexpanded!()
107    }
108}
109impl<T: CubePrimitive> AsRefExpand<T> for SharedExpand<T> {
110    fn __expand_as_ref_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
111        self.expand.into()
112    }
113}
114
115impl<T: CubePrimitive> AsMut<T> for Shared<T> {
116    fn as_mut(&mut self) -> &mut T {
117        unexpanded!()
118    }
119}
120impl<T: CubePrimitive> AsMutExpand<T> for SharedExpand<T> {
121    fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
122        self.expand.into()
123    }
124}
125
126impl<T: CubePrimitive> Default for Shared<T> {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131impl<T: CubePrimitive> Shared<T> {
132    pub fn __expand_default(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
133        Self::__expand_new(scope)
134    }
135}
136
137#[cube]
138impl<T: CubePrimitive + Clone> SharedMemory<T> {
139    #[allow(unused_variables)]
140    pub fn new_aligned(#[comptime] size: usize, #[comptime] alignment: usize) -> SharedMemory<T> {
141        intrinsic!(|scope| {
142            let var = scope.create_shared_array(T::as_type(scope), size, Some(alignment));
143            NativeExpand::new(var)
144        })
145    }
146
147    /// Frees the shared memory for reuse, if possible on the target runtime.
148    ///
149    /// # Safety
150    /// *Must* be used in uniform control flow
151    /// *Must not* have any dangling references to this shared memory
152    pub unsafe fn free(self) {
153        intrinsic!(|scope| { scope.register(Marker::Free(*self.expand)) })
154    }
155}
156
157fn len_static<T: CubePrimitive>(shared: &NativeExpand<SharedMemory<T>>) -> NativeExpand<usize> {
158    let VariableKind::SharedArray { length, .. } = shared.expand.kind else {
159        unreachable!("Kind of shared memory is always shared memory")
160    };
161    length.into()
162}
163
/// Module that contains the implementation details of the index functions.
mod indexation {
    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};

    use crate::ir::Instruction;

    use super::*;

    // NOTE(review): shadows the identical `pub type SharedMemoryExpand` brought in by
    // `super::*`; presumably kept for the `#[cube]` expansion — confirm before removing.
    type SharedMemoryExpand<E> = NativeExpand<SharedMemory<E>>;

    #[cube]
    impl<E: CubePrimitive> SharedMemory<E> {
        /// Perform an unchecked index into the array.
        ///
        /// At expand time this creates a fresh local with the array variable's IR type and
        /// registers an `Operator::UncheckedIndex` instruction whose output is that local.
        ///
        /// # Safety
        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
        /// always in bounds
        // NOTE(review): `i` is only read inside `intrinsic!`, presumably why the lint is silenced.
        #[allow(unused_variables)]
        pub unsafe fn index_unchecked(&self, i: usize) -> &E {
            intrinsic!(|scope| {
                // Destination local. Its type is `self.expand.ty`, presumably the element type
                // passed to `create_shared_array`.
                let out = scope.create_local(self.expand.ty);
                scope.register(Instruction::new(
                    Operator::UncheckedIndex(IndexOperator {
                        list: *self.expand,
                        index: i.expand.consume(),
                        // NOTE(review): 0 / 1 look like "no override" defaults — confirm in IR docs.
                        vector_size: 0,
                        unroll_factor: 1,
                    }),
                    *out,
                ));
                out.into()
            })
        }

        /// Perform an unchecked index assignment into the array.
        ///
        /// At expand time this registers an `Operator::UncheckedIndexAssign` instruction whose
        /// output is the array variable itself, storing `value` at index `i`.
        ///
        /// # Safety
        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
        /// always in bounds
        // NOTE(review): `i`/`value` are only read inside `intrinsic!`, presumably why the lint
        // is silenced.
        #[allow(unused_variables)]
        pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E) {
            intrinsic!(|scope| {
                scope.register(Instruction::new(
                    Operator::UncheckedIndexAssign(IndexAssignOperator {
                        index: i.expand.consume(),
                        value: value.expand.consume(),
                        // NOTE(review): 0 / 1 look like "no override" defaults — confirm in IR docs.
                        vector_size: 0,
                        unroll_factor: 1,
                    }),
                    // The array is the instruction's output: the store writes into it.
                    *self.expand,
                ));
            })
        }
    }
}
219
220impl<T: CubePrimitive> List<T> for SharedMemory<T> {
221    fn __expand_read(
222        scope: &mut Scope,
223        this: NativeExpand<SharedMemory<T>>,
224        idx: NativeExpand<usize>,
225    ) -> NativeExpand<T> {
226        index::expand(scope, this, idx)
227    }
228}
229
230impl<T: CubePrimitive> Deref for SharedMemory<T> {
231    type Target = [T];
232
233    fn deref(&self) -> &Self::Target {
234        unexpanded!()
235    }
236}
237
238impl<T: CubePrimitive> DerefMut for SharedMemory<T> {
239    fn deref_mut(&mut self) -> &mut Self::Target {
240        unexpanded!()
241    }
242}
243
244impl<T: CubePrimitive> ListExpand<T> for NativeExpand<SharedMemory<T>> {
245    fn __expand_read_method(&self, scope: &mut Scope, idx: NativeExpand<usize>) -> NativeExpand<T> {
246        index::expand(scope, self.clone(), idx)
247    }
248    fn __expand_read_unchecked_method(
249        &self,
250        scope: &mut Scope,
251        idx: NativeExpand<usize>,
252    ) -> NativeExpand<T> {
253        index_unchecked::expand(scope, self.clone(), idx)
254    }
255
256    fn __expand_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
257        Self::__expand_len_method(self.clone(), scope)
258    }
259}
260
261impl<T: CubePrimitive> Vectorized for SharedMemory<T> {}
262impl<T: CubePrimitive> VectorizedExpand for NativeExpand<SharedMemory<T>> {
263    fn vector_size(&self) -> VectorSize {
264        self.expand.ty.vector_size()
265    }
266}
267
268impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
269    fn __expand_write(
270        scope: &mut Scope,
271        this: NativeExpand<SharedMemory<T>>,
272        idx: NativeExpand<usize>,
273        value: NativeExpand<T>,
274    ) {
275        index_assign::expand(scope, this, idx, value);
276    }
277}
278
279impl<T: CubePrimitive> ListMutExpand<T> for NativeExpand<SharedMemory<T>> {
280    fn __expand_write_method(
281        &self,
282        scope: &mut Scope,
283        idx: NativeExpand<usize>,
284        value: NativeExpand<T>,
285    ) {
286        index_assign::expand(scope, self.clone(), idx, value);
287    }
288}