cubecl_core/frontend/container/shared_memory.rs

1use std::{marker::PhantomData, num::NonZero};
2
3use crate as cubecl;
4use cubecl_macros::{cube, intrinsic};
5
6use crate::{
7    frontend::{CubePrimitive, CubeType, ExpandElementTyped, IntoMut, indexation::Index},
8    ir::{Item, Scope},
9    prelude::{
10        Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
11    },
12};
13
/// Typed handle to a shared memory variable holding elements of type `T`.
///
/// The struct carries no data: `PhantomData` only records the element type. The IR variable
/// itself is created by the `__expand_*` constructors (via `Scope::create_shared`).
#[derive(Clone, Copy)]
pub struct SharedMemory<T: CubeType> {
    // Type marker only; constructed as `SharedMemory { _val: PhantomData }`.
    _val: PhantomData<T>,
}
18
19impl<T: CubePrimitive> IntoMut for ExpandElementTyped<SharedMemory<T>> {
20    fn into_mut(self, _scope: &mut Scope) -> Self {
21        self
22    }
23}
24
25impl<T: CubePrimitive> CubeType for SharedMemory<T> {
26    type ExpandType = ExpandElementTyped<SharedMemory<T>>;
27}
28
29impl<T: CubePrimitive + Clone> SharedMemory<T> {
30    pub fn new<S: Index>(_size: S) -> Self {
31        SharedMemory { _val: PhantomData }
32    }
33
34    pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
35        SharedMemory { _val: PhantomData }
36    }
37
38    pub fn __expand_new_lined(
39        scope: &mut Scope,
40        size: ExpandElementTyped<u32>,
41        vectorization_factor: u32,
42    ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
43        let size = size
44            .constant()
45            .expect("Shared memory need constant initialization value")
46            .as_u32();
47        let var = scope.create_shared(
48            Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
49            size,
50            None,
51        );
52        ExpandElementTyped::new(var)
53    }
54
55    pub fn vectorized<S: Index>(_size: S, _vectorization_factor: u32) -> Self {
56        SharedMemory { _val: PhantomData }
57    }
58
59    pub fn __expand_vectorized(
60        scope: &mut Scope,
61        size: ExpandElementTyped<u32>,
62        vectorization_factor: u32,
63    ) -> <Self as CubeType>::ExpandType {
64        let size = size
65            .constant()
66            .expect("Shared memory need constant initialization value")
67            .as_u32();
68        let var = scope.create_shared(
69            Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
70            size,
71            None,
72        );
73        ExpandElementTyped::new(var)
74    }
75
76    pub fn __expand_new(
77        scope: &mut Scope,
78        size: ExpandElementTyped<u32>,
79    ) -> <Self as CubeType>::ExpandType {
80        let size = size
81            .constant()
82            .expect("Shared memory need constant initialization value")
83            .as_u32();
84        let var = scope.create_shared(Item::new(T::as_elem(scope)), size, None);
85        ExpandElementTyped::new(var)
86    }
87}
88
89#[cube]
90impl<T: CubePrimitive + Clone> SharedMemory<T> {
91    #[allow(unused_variables)]
92    pub fn new_aligned(
93        #[comptime] size: u32,
94        #[comptime] vectorization_factor: u32,
95        #[comptime] alignment: u32,
96    ) -> SharedMemory<Line<T>> {
97        intrinsic!(|scope| {
98            let var = scope.create_shared(
99                Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
100                size,
101                Some(alignment),
102            );
103            ExpandElementTyped::new(var)
104        })
105    }
106}
107
/// Module that contains the implementation details of the unchecked index functions for
/// shared memory.
mod indexation {
    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};

    use crate::ir::Instruction;

    use super::*;

    // Shorthand for the expand-time handle of a shared memory.
    // NOTE(review): not referenced by the code visible here — possibly used by macro-generated
    // code, otherwise dead; confirm.
    type SharedMemoryExpand<E> = ExpandElementTyped<SharedMemory<E>>;

    #[cube]
    impl<E: CubePrimitive> SharedMemory<E> {
        /// Perform an unchecked index into the shared memory.
        ///
        /// # Safety
        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
        /// always in bounds
        // Parameters are only consumed inside `intrinsic!`, so the compiler sees them as unused.
        #[allow(unused_variables)]
        pub unsafe fn index_unchecked(&self, i: u32) -> &E {
            intrinsic!(|scope| {
                // Fresh local with the same item (element type and line size) as the shared
                // memory; it receives the loaded value.
                let out = scope.create_local(self.expand.item);
                scope.register(Instruction::new(
                    Operator::UncheckedIndex(IndexOperator {
                        list: *self.expand,
                        index: i.expand.consume(),
                        // NOTE(review): `0` presumably means "no explicit line size override";
                        // confirm against the `IndexOperator` definition.
                        line_size: 0,
                    }),
                    *out,
                ));
                out.into()
            })
        }

        /// Perform an unchecked index assignment into the shared memory.
        ///
        /// # Safety
        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
        /// always in bounds
        // Parameters are only consumed inside `intrinsic!`, so the compiler sees them as unused.
        #[allow(unused_variables)]
        pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E) {
            intrinsic!(|scope| {
                scope.register(Instruction::new(
                    Operator::UncheckedIndexAssign(IndexAssignOperator {
                        index: i.expand.consume(),
                        value: value.expand.consume(),
                        // NOTE(review): same `0` line size convention as in `index_unchecked`.
                        line_size: 0,
                    }),
                    // The shared memory variable itself is the instruction's output, since the
                    // write targets it.
                    *self.expand,
                ));
            })
        }
    }
}
161
162impl<T: CubePrimitive> List<T> for SharedMemory<T> {
163    fn __expand_read(
164        scope: &mut Scope,
165        this: ExpandElementTyped<SharedMemory<T>>,
166        idx: ExpandElementTyped<u32>,
167    ) -> ExpandElementTyped<T> {
168        index::expand(scope, this, idx)
169    }
170}
171
172impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
173    fn __expand_read_method(
174        &self,
175        scope: &mut Scope,
176        idx: ExpandElementTyped<u32>,
177    ) -> ExpandElementTyped<T> {
178        index::expand(scope, self.clone(), idx)
179    }
180    fn __expand_read_unchecked_method(
181        &self,
182        scope: &mut Scope,
183        idx: ExpandElementTyped<u32>,
184    ) -> ExpandElementTyped<T> {
185        index_unchecked::expand(scope, self.clone(), idx)
186    }
187}
188
189impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
190    fn __expand_write(
191        scope: &mut Scope,
192        this: ExpandElementTyped<SharedMemory<T>>,
193        idx: ExpandElementTyped<u32>,
194        value: ExpandElementTyped<T>,
195    ) {
196        index_assign::expand(scope, this, idx, value);
197    }
198}
199
200impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
201    fn __expand_write_method(
202        &self,
203        scope: &mut Scope,
204        idx: ExpandElementTyped<u32>,
205        value: ExpandElementTyped<T>,
206    ) {
207        index_assign::expand(scope, self.clone(), idx, value);
208    }
209}