cubecl_std/tensor/view/operations/array.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, io::read_masked, prelude::barrier::BarrierExpand};
3
4use crate::tensor::{
5    ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand,
6    layout::Coords1d,
7};
8
/// Implements [`ViewOperations`] / [`ViewOperationsMut`] and their `Expand` counterparts for a
/// container addressed with one-dimensional ([`Coords1d`]) positions.
///
/// * `$ty` is the container type (e.g. `Array<T>`).
/// * `$expand` is its expanded representation (e.g. `ExpandElementTyped<Array<T>>`).
///
/// Both types must mention the generic parameter `T`, which the generated impls bind as
/// `T: CubePrimitive`.
///
/// The generated impls behave as follows:
/// * Plain reads and writes and slicing are forwarded to the container's own
///   `ListExpand` / `ListMutExpand` / `SliceOperatorExpand` / `SliceMutOperatorExpand` impls.
/// * The checked and masked accessors use the container's buffer length as the bound.
/// * The tensor-map load/store operations panic via `unimplemented!`.
macro_rules! impl_operations_1d {
    ($ty: ty, $expand: ty) => {
        impl<T: CubePrimitive> ViewOperations<T, Coords1d> for $ty {}
        impl<T: CubePrimitive> ViewOperationsExpand<T, Coords1d> for $expand {
            /// Reads `pos` without a bounds check of its own; delegates to the container's
            /// `ListExpand::read`.
            fn __expand_read_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                <Self as ListExpand<T>>::__expand_read_method(self, scope, pos)
            }

            /// Reads `pos`, yielding zero when `pos` is outside the buffer.
            fn __expand_read_checked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                // A checked read is exactly a masked read whose mask value is zero, so reuse it
                // instead of duplicating the bounds logic.
                let zero = T::__expand_cast_from(scope, 0.into());
                <Self as ViewOperationsExpand<T, Coords1d>>::__expand_read_masked_method(
                    self, scope, pos, zero,
                )
            }

            /// Reads `pos`, yielding `mask_value` when `pos` is outside the buffer.
            fn __expand_read_masked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                mask_value: <T>::ExpandType,
            ) -> <T>::ExpandType {
                let in_bounds =
                    <Self as ViewOperationsExpand<T, Coords1d>>::__expand_is_in_bounds_method(
                        self,
                        scope,
                        pos.clone(),
                    );
                let slice = self.clone().__expand_to_slice_method(scope);
                read_masked::expand::<T>(scope, in_bounds, slice, pos, mask_value)
            }

            /// Reads `pos` through the container's `ListExpand::read_unchecked`.
            fn __expand_read_unchecked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                <Self as ListExpand<T>>::__expand_read_unchecked_method(self, scope, pos)
            }

            /// Returns the read-only slice covering `pos..=end` (`end` is inclusive).
            fn __expand_to_linear_slice_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                end: ExpandElementTyped<u32>,
            ) -> SliceExpand<T, ReadOnly> {
                // Convert to exclusive end.
                // NOTE(review): `end + 1` presumably wraps if `end == u32::MAX` — confirm callers
                // never pass that.
                let end = add::expand(scope, end, 1u32.into());
                // Handling for shapes that are 0 in at least one dim: clamping the start to the
                // end ensures the slice never has a negative length.
                let start = Min::__expand_min(scope, pos, end.clone());
                <Self as SliceOperatorExpand<T>>::__expand_slice_method(self, scope, start, end)
            }

            /// The one-dimensional shape is the container's buffer length.
            fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
                self.clone().__expand_buffer_len_method(scope)
            }

            /// `pos < buffer_len`. This is the single bound used by all checked and masked
            /// accessors in this impl.
            fn __expand_is_in_bounds_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> ExpandElementTyped<bool> {
                let len = self.clone().__expand_buffer_len_method(scope);
                lt::expand(scope, pos, len)
            }

            /// Unsupported: this container is not a tensor map.
            fn __expand_tensor_map_load_method(
                &self,
                _scope: &mut Scope,
                _barrier: BarrierExpand,
                _shared_memory: SliceExpand<T, ReadWrite>,
                _pos: ExpandElementTyped<u32>,
            ) {
                unimplemented!("Not a tensor map");
            }
        }

        impl<T: CubePrimitive> ViewOperationsMut<T, Coords1d> for $ty {}
        impl<T: CubePrimitive> ViewOperationsMutExpand<T, Coords1d> for $expand {
            /// Writes `value` at `pos` through the container's `ListMutExpand::write`.
            fn __expand_write_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                value: <T>::ExpandType,
            ) {
                <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
            }

            /// Writes `value` at `pos` only when `pos` is inside the buffer; otherwise does
            /// nothing.
            fn __expand_write_checked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                value: <T>::ExpandType,
            ) {
                let in_bounds =
                    <Self as ViewOperationsExpand<T, Coords1d>>::__expand_is_in_bounds_method(
                        self,
                        scope,
                        pos.clone(),
                    );
                if_expand(scope, in_bounds.into(), |scope| {
                    <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
                })
            }

            /// Returns the mutable slice covering `pos..=end` (`end` is inclusive).
            fn __expand_to_linear_slice_mut_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                end: ExpandElementTyped<u32>,
            ) -> SliceExpand<T, ReadWrite> {
                // Convert to exclusive end (same overflow caveat as the read-only variant).
                let end = add::expand(scope, end, 1u32.into());
                // Handling for shapes that are 0 in at least one dim: clamping the start to the
                // end ensures the slice never has a negative length.
                let start = Min::__expand_min(scope, pos, end.clone());
                <Self as SliceMutOperatorExpand<T>>::__expand_slice_mut_method(
                    self, scope, start, end,
                )
            }

            /// Unsupported: this container is not a tensor map.
            fn __expand_tensor_map_store_method(
                &self,
                _scope: &mut Scope,
                _shared_memory: SliceExpand<T, ReadOnly>,
                _pos: <Coords1d as CubeType>::ExpandType,
            ) {
                unimplemented!("Not a tensor map");
            }
        }
    };
}
142
143impl_operations_1d!(Array<T>, ExpandElementTyped<Array<T>>);
144impl_operations_1d!(Tensor<T>, ExpandElementTyped<Tensor<T>>);
145impl_operations_1d!(SharedMemory<T>, ExpandElementTyped<SharedMemory<T>>);