cubecl_std/tensor/view/operations/
array.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, io::read_masked, prelude::barrier::BarrierExpand};
3
4use crate::{
5    CubeOption, CubeOptionExpand,
6    tensor::{
7        ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand,
8        layout::Coords1d,
9    },
10};
11
/// Implements the one-dimensional view traits for a linear container type.
///
/// Generates `ViewOperations<T, Coords1d>` / `ViewOperationsMut<T, Coords1d>` for `$ty`
/// and the matching `ViewOperationsExpand` / `ViewOperationsMutExpand` impls for `$expand`.
///
/// * `$ty` - the container type, e.g. `Array<T>`.
/// * `$expand` - its expand type, e.g. `ExpandElementTyped<Array<T>>`.
///
/// The generated code requires `$expand` to implement `ListExpand<T>`, `ListMutExpand<T>`,
/// `SliceOperatorExpand<T>` and `SliceMutOperatorExpand<T>`, and to provide
/// `__expand_buffer_len_method` and `__expand_to_slice_method`.
///
/// All bounds checks (and the reported 1D shape) use the underlying buffer length.
/// These containers never expose a tensor map: `as_tensor_map` yields `None`, and the
/// tensor-map load/store entry points panic.
macro_rules! impl_operations_1d {
    ($ty: ty, $expand: ty) => {
        impl<T: CubePrimitive> ViewOperations<T, Coords1d> for $ty {}
        impl<T: CubePrimitive> ViewOperationsExpand<T, Coords1d> for $expand {
            /// Reads `pos` through the container's own `ListExpand::read`.
            fn __expand_read_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                <Self as ListExpand<T>>::__expand_read_method(self, scope, pos)
            }

            /// Reads `pos`, yielding `T` cast from `0` when `pos` is out of bounds.
            fn __expand_read_checked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                let in_bounds =
                    <Self as ViewOperationsExpand<T, Coords1d>>::__expand_is_in_bounds_method(
                        self,
                        scope,
                        pos.clone(),
                    );
                let slice = self.clone().__expand_to_slice_method(scope);
                let zero = T::__expand_cast_from(scope, 0.into());
                read_masked::expand::<T>(scope, in_bounds, slice, pos, zero)
            }

            /// Reads `pos`, yielding `mask_value` when `pos` is out of bounds.
            fn __expand_read_masked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                mask_value: <T>::ExpandType,
            ) -> <T>::ExpandType {
                let in_bounds =
                    <Self as ViewOperationsExpand<T, Coords1d>>::__expand_is_in_bounds_method(
                        self,
                        scope,
                        pos.clone(),
                    );
                let slice = self.clone().__expand_to_slice_method(scope);
                read_masked::expand::<T>(scope, in_bounds, slice, pos, mask_value)
            }

            /// Reads `pos` through the container's `ListExpand::read_unchecked`.
            fn __expand_read_unchecked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                <Self as ListExpand<T>>::__expand_read_unchecked_method(self, scope, pos)
            }

            /// Returns a read-only slice covering `pos..=end`.
            fn __expand_to_linear_slice_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                end: ExpandElementTyped<u32>,
            ) -> SliceExpand<T, ReadOnly> {
                // The incoming `end` is treated as inclusive: one past it becomes the
                // slice's end bound.
                let end = add::expand(scope, end, 1u32.into());
                // Clamp `start` to at most `end` so the slice can never have a
                // negative length.
                let start = Min::__expand_min(scope, pos, end.clone());
                <Self as SliceOperatorExpand<T>>::__expand_slice_method(self, scope, start, end)
            }

            /// Plain 1D containers are never backed by a tensor map.
            fn __expand_as_tensor_map_method(
                &self,
                scope: &mut Scope,
            ) -> CubeOptionExpand<TensorMap<T>> {
                CubeOption::__expand_new_None(scope)
            }

            /// The 1D shape is the length of the underlying buffer.
            fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
                self.clone().__expand_buffer_len_method(scope)
            }

            /// `pos < buffer_len`; the single bounds check shared by the checked
            /// and masked accessors in this macro.
            fn __expand_is_in_bounds_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> ExpandElementTyped<bool> {
                let len = self.clone().__expand_buffer_len_method(scope);
                lt::expand(scope, pos, len)
            }

            /// Not supported: `as_tensor_map` always returns `None` for this type,
            /// so reaching this is a caller error and panics.
            fn __expand_tensor_map_load_method(
                &self,
                _scope: &mut Scope,
                _barrier: BarrierExpand,
                _shared_memory: SliceExpand<T, ReadWrite>,
                _pos: ExpandElementTyped<u32>,
            ) {
                unimplemented!("Not a tensor map");
            }
        }

        impl<T: CubePrimitive> ViewOperationsMut<T, Coords1d> for $ty {}
        impl<T: CubePrimitive> ViewOperationsMutExpand<T, Coords1d> for $expand {
            /// Writes `value` at `pos` through the container's `ListMutExpand::write`.
            fn __expand_write_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                value: <T>::ExpandType,
            ) {
                <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
            }

            /// Writes `value` at `pos` only when `pos` is in bounds; otherwise the
            /// write is skipped.
            fn __expand_write_checked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                value: <T>::ExpandType,
            ) {
                let in_bounds =
                    <Self as ViewOperationsExpand<T, Coords1d>>::__expand_is_in_bounds_method(
                        self,
                        scope,
                        pos.clone(),
                    );
                // The store is emitted inside a runtime branch on the bounds check.
                if_expand(scope, in_bounds.into(), |scope| {
                    <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
                })
            }

            /// Returns a mutable slice covering `pos..=end`.
            fn __expand_to_linear_slice_mut_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                end: ExpandElementTyped<u32>,
            ) -> SliceExpand<T, ReadWrite> {
                // The incoming `end` is treated as inclusive: one past it becomes the
                // slice's end bound.
                let end = add::expand(scope, end, 1u32.into());
                // Clamp `start` to at most `end` so the slice can never have a
                // negative length.
                let start = Min::__expand_min(scope, pos, end.clone());
                <Self as SliceMutOperatorExpand<T>>::__expand_slice_mut_method(
                    self, scope, start, end,
                )
            }

            /// Not supported: `as_tensor_map` always returns `None` for this type,
            /// so reaching this is a caller error and panics.
            fn __expand_tensor_map_store_method(
                &self,
                _scope: &mut Scope,
                _shared_memory: SliceExpand<T, ReadOnly>,
                _pos: <Coords1d as CubeType>::ExpandType,
            ) {
                unimplemented!("Not a tensor map");
            }
        }
    };
}
152
153impl_operations_1d!(Array<T>, ExpandElementTyped<Array<T>>);
154impl_operations_1d!(Tensor<T>, ExpandElementTyped<Tensor<T>>);
155impl_operations_1d!(SharedMemory<T>, ExpandElementTyped<SharedMemory<T>>);