// cubecl_core/frontend/container/shared_memory.rs

1use std::marker::PhantomData;
2
3use crate::{
4    self as cubecl,
5    prelude::{Lined, LinedExpand},
6    unexpanded,
7};
8use cubecl_ir::{Instruction, Operation, VariableKind};
9use cubecl_macros::{cube, intrinsic};
10
11use crate::{
12    frontend::{CubePrimitive, CubeType, ExpandElementTyped, IntoMut, indexation::Index},
13    ir::{Scope, Type},
14    prelude::{
15        Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
16    },
17};
18
19type SharedMemoryExpand<T> = ExpandElementTyped<SharedMemory<T>>;
20
21#[derive(Clone, Copy)]
22pub struct SharedMemory<T: CubeType> {
23    _val: PhantomData<T>,
24}
25
26impl<T: CubePrimitive> IntoMut for ExpandElementTyped<SharedMemory<T>> {
27    fn into_mut(self, _scope: &mut Scope) -> Self {
28        self
29    }
30}
31
32impl<T: CubePrimitive> CubeType for SharedMemory<T> {
33    type ExpandType = ExpandElementTyped<SharedMemory<T>>;
34}
35
36impl<T: CubePrimitive + Clone> SharedMemory<T> {
37    pub fn new<S: Index>(_size: S) -> Self {
38        SharedMemory { _val: PhantomData }
39    }
40
41    pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
42        SharedMemory { _val: PhantomData }
43    }
44
45    #[allow(clippy::len_without_is_empty)]
46    pub fn len(&self) -> u32 {
47        unexpanded!()
48    }
49
50    pub fn buffer_len(&self) -> u32 {
51        unexpanded!()
52    }
53
54    pub fn __expand_new_lined(
55        scope: &mut Scope,
56        size: ExpandElementTyped<u32>,
57        line_size: u32,
58    ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
59        let size = size
60            .constant()
61            .expect("Shared memory need constant initialization value")
62            .as_u32();
63        let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size), size, None);
64        ExpandElementTyped::new(var)
65    }
66
67    pub fn vectorized<S: Index>(_size: S, _vectorization_factor: u32) -> Self {
68        SharedMemory { _val: PhantomData }
69    }
70
71    pub fn __expand_vectorized(
72        scope: &mut Scope,
73        size: ExpandElementTyped<u32>,
74        line_size: u32,
75    ) -> <Self as CubeType>::ExpandType {
76        let size = size
77            .constant()
78            .expect("Shared memory need constant initialization value")
79            .as_u32();
80        let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size), size, None);
81        ExpandElementTyped::new(var)
82    }
83
84    pub fn __expand_new(
85        scope: &mut Scope,
86        size: ExpandElementTyped<u32>,
87    ) -> <Self as CubeType>::ExpandType {
88        let size = size
89            .constant()
90            .expect("Shared memory need constant initialization value")
91            .as_u32();
92        let var = scope.create_shared(Type::new(T::as_type(scope)), size, None);
93        ExpandElementTyped::new(var)
94    }
95
96    pub fn __expand_len(
97        scope: &mut Scope,
98        this: ExpandElementTyped<Self>,
99    ) -> ExpandElementTyped<u32> {
100        this.__expand_len_method(scope)
101    }
102
103    pub fn __expand_buffer_len(
104        scope: &mut Scope,
105        this: ExpandElementTyped<Self>,
106    ) -> ExpandElementTyped<u32> {
107        this.__expand_buffer_len_method(scope)
108    }
109}
110
111impl<T: CubePrimitive> ExpandElementTyped<SharedMemory<T>> {
112    pub fn __expand_len_method(self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
113        len_static(&self)
114    }
115
116    pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
117        self.__expand_len_method(scope)
118    }
119}
120
121#[cube]
122impl<T: CubePrimitive + Clone> SharedMemory<T> {
123    #[allow(unused_variables)]
124    pub fn new_aligned(
125        #[comptime] size: u32,
126        #[comptime] line_size: u32,
127        #[comptime] alignment: u32,
128    ) -> SharedMemory<Line<T>> {
129        intrinsic!(|scope| {
130            let var = scope.create_shared(
131                Type::new(T::as_type(scope)).line(line_size),
132                size,
133                Some(alignment),
134            );
135            ExpandElementTyped::new(var)
136        })
137    }
138
139    /// Frees the shared memory for reuse, if possible on the target runtime.
140    ///
141    /// # Safety
142    /// *Must* be used in uniform control flow
143    /// *Must not* have any dangling references to this shared memory
144    pub unsafe fn free(self) {
145        intrinsic!(|scope| { scope.register(Instruction::no_out(Operation::Free(*self.expand))) })
146    }
147}
148
149fn len_static<T: CubePrimitive>(
150    shared: &ExpandElementTyped<SharedMemory<T>>,
151) -> ExpandElementTyped<u32> {
152    let VariableKind::SharedMemory { length, .. } = shared.expand.kind else {
153        unreachable!("Kind of shared memory is always shared memory")
154    };
155    length.into()
156}
157
158/// Module that contains the implementation details of the index functions.
159mod indexation {
160    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
161
162    use crate::ir::Instruction;
163
164    use super::*;
165
166    type SharedMemoryExpand<E> = ExpandElementTyped<SharedMemory<E>>;
167
168    #[cube]
169    impl<E: CubePrimitive> SharedMemory<E> {
170        /// Perform an unchecked index into the array
171        ///
172        /// # Safety
173        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
174        /// always in bounds
175        #[allow(unused_variables)]
176        pub unsafe fn index_unchecked(&self, i: u32) -> &E {
177            intrinsic!(|scope| {
178                let out = scope.create_local(self.expand.ty);
179                scope.register(Instruction::new(
180                    Operator::UncheckedIndex(IndexOperator {
181                        list: *self.expand,
182                        index: i.expand.consume(),
183                        line_size: 0,
184                        unroll_factor: 1,
185                    }),
186                    *out,
187                ));
188                out.into()
189            })
190        }
191
192        /// Perform an unchecked index assignment into the array
193        ///
194        /// # Safety
195        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
196        /// always in bounds
197        #[allow(unused_variables)]
198        pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E) {
199            intrinsic!(|scope| {
200                scope.register(Instruction::new(
201                    Operator::UncheckedIndexAssign(IndexAssignOperator {
202                        index: i.expand.consume(),
203                        value: value.expand.consume(),
204                        line_size: 0,
205                        unroll_factor: 1,
206                    }),
207                    *self.expand,
208                ));
209            })
210        }
211    }
212}
213
214impl<T: CubePrimitive> List<T> for SharedMemory<T> {
215    fn __expand_read(
216        scope: &mut Scope,
217        this: ExpandElementTyped<SharedMemory<T>>,
218        idx: ExpandElementTyped<u32>,
219    ) -> ExpandElementTyped<T> {
220        index::expand(scope, this, idx)
221    }
222}
223
224impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
225    fn __expand_read_method(
226        &self,
227        scope: &mut Scope,
228        idx: ExpandElementTyped<u32>,
229    ) -> ExpandElementTyped<T> {
230        index::expand(scope, self.clone(), idx)
231    }
232    fn __expand_read_unchecked_method(
233        &self,
234        scope: &mut Scope,
235        idx: ExpandElementTyped<u32>,
236    ) -> ExpandElementTyped<T> {
237        index_unchecked::expand(scope, self.clone(), idx)
238    }
239
240    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
241        Self::__expand_len_method(self.clone(), scope)
242    }
243}
244
245impl<T: CubePrimitive> Lined for SharedMemory<T> {}
246impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<SharedMemory<T>> {
247    fn line_size(&self) -> u32 {
248        self.expand.ty.line_size()
249    }
250}
251
252impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
253    fn __expand_write(
254        scope: &mut Scope,
255        this: ExpandElementTyped<SharedMemory<T>>,
256        idx: ExpandElementTyped<u32>,
257        value: ExpandElementTyped<T>,
258    ) {
259        index_assign::expand(scope, this, idx, value);
260    }
261}
262
263impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
264    fn __expand_write_method(
265        &self,
266        scope: &mut Scope,
267        idx: ExpandElementTyped<u32>,
268        value: ExpandElementTyped<T>,
269    ) {
270        index_assign::expand(scope, self.clone(), idx, value);
271    }
272}