cubecl_std/tensor/view/operations/
array.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, io::read_masked, prelude::barrier::BarrierExpand};
3
4use crate::{
5 CubeOption, CubeOptionExpand,
6 tensor::{
7 ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand,
8 layout::Coords1d,
9 },
10};
11
/// Implements [`ViewOperations`] and [`ViewOperationsMut`] over 1D coordinates ([`Coords1d`])
/// for a linear container type `$ty` and its expand type `$expand`.
///
/// Plain reads/writes delegate to the container's `ListExpand` / `ListMutExpand` impls. The
/// checked and masked variants guard the access with `pos < buffer_len` (see
/// `__expand_is_in_bounds_method`). Tensor-map operations are not supported:
/// `as_tensor_map` returns `None`, and the tensor-map load/store methods panic if expanded.
macro_rules! impl_operations_1d {
    ($ty: ty, $expand: ty) => {
        impl<T: CubePrimitive> ViewOperations<T, Coords1d> for $ty {}
        impl<T: CubePrimitive> ViewOperationsExpand<T, Coords1d> for $expand {
            fn __expand_read_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                // `self` is already `&Self`; pass it directly (no `&&Self` coercion).
                <Self as ListExpand<T>>::__expand_read_method(self, scope, pos)
            }

            fn __expand_read_checked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                // Same guard as `is_in_bounds`, so there is a single definition of "in bounds".
                let in_bounds =
                    <Self as ViewOperationsExpand<T, Coords1d>>::__expand_is_in_bounds_method(
                        self,
                        scope,
                        pos.clone(),
                    );
                let slice = self.clone().__expand_to_slice_method(scope);
                // Fallback value for the masked read: `0` cast to `T`.
                let zero = T::__expand_cast_from(scope, 0.into());
                read_masked::expand::<T>(scope, in_bounds, slice, pos, zero)
            }

            fn __expand_read_masked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                mask_value: <T>::ExpandType,
            ) -> <T>::ExpandType {
                let in_bounds =
                    <Self as ViewOperationsExpand<T, Coords1d>>::__expand_is_in_bounds_method(
                        self,
                        scope,
                        pos.clone(),
                    );
                let slice = self.clone().__expand_to_slice_method(scope);
                read_masked::expand::<T>(scope, in_bounds, slice, pos, mask_value)
            }

            fn __expand_read_unchecked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                <Self as ListExpand<T>>::__expand_read_unchecked_method(self, scope, pos)
            }

            fn __expand_to_linear_slice_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                end: ExpandElementTyped<u32>,
            ) -> SliceExpand<T, ReadOnly> {
                // `end` is inclusive here; convert it to the exclusive bound `slice` expects.
                let end = add::expand(scope, end, 1u32.into());
                // Clamp `start` to `end` so the slice never has negative length (presumably
                // needed for shapes that are zero in some dimension — NOTE(review): confirm).
                let start = Min::__expand_min(scope, pos, end.clone());
                <Self as SliceOperatorExpand<T>>::__expand_slice_method(self, scope, start, end)
            }

            fn __expand_as_tensor_map_method(
                &self,
                scope: &mut Scope,
            ) -> CubeOptionExpand<TensorMap<T>> {
                // Linear containers are never backed by a tensor map.
                CubeOption::__expand_new_None(scope)
            }

            fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
                // The 1D shape is the buffer length.
                self.clone().__expand_buffer_len_method(scope)
            }

            fn __expand_is_in_bounds_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> ExpandElementTyped<bool> {
                // In bounds iff `pos < buffer_len`.
                let len = self.clone().__expand_buffer_len_method(scope);
                lt::expand(scope, pos, len)
            }

            fn __expand_tensor_map_load_method(
                &self,
                _scope: &mut Scope,
                _barrier: BarrierExpand,
                _shared_memory: SliceExpand<T, ReadWrite>,
                _pos: ExpandElementTyped<u32>,
            ) {
                unimplemented!("Not a tensor map");
            }
        }

        impl<T: CubePrimitive> ViewOperationsMut<T, Coords1d> for $ty {}
        impl<T: CubePrimitive> ViewOperationsMutExpand<T, Coords1d> for $expand {
            fn __expand_write_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                value: <T>::ExpandType,
            ) {
                // `self` is already `&Self`; pass it directly (no `&&Self` coercion).
                <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
            }

            fn __expand_write_checked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                value: <T>::ExpandType,
            ) {
                let in_bounds =
                    <Self as ViewOperationsExpand<T, Coords1d>>::__expand_is_in_bounds_method(
                        self,
                        scope,
                        pos.clone(),
                    );
                // The write is emitted only inside the in-bounds branch.
                if_expand(scope, in_bounds.into(), |scope| {
                    <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
                })
            }

            fn __expand_to_linear_slice_mut_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                end: ExpandElementTyped<u32>,
            ) -> SliceExpand<T, ReadWrite> {
                // `end` is inclusive here; convert it to the exclusive bound `slice_mut` expects.
                let end = add::expand(scope, end, 1u32.into());
                // Clamp `start` to `end` so the slice never has negative length (presumably
                // needed for shapes that are zero in some dimension — NOTE(review): confirm).
                let start = Min::__expand_min(scope, pos, end.clone());
                <Self as SliceMutOperatorExpand<T>>::__expand_slice_mut_method(
                    self, scope, start, end,
                )
            }

            fn __expand_tensor_map_store_method(
                &self,
                _scope: &mut Scope,
                _shared_memory: SliceExpand<T, ReadOnly>,
                _pos: <Coords1d as CubeType>::ExpandType,
            ) {
                unimplemented!("Not a tensor map");
            }
        }
    };
}
152
153impl_operations_1d!(Array<T>, ExpandElementTyped<Array<T>>);
154impl_operations_1d!(Tensor<T>, ExpandElementTyped<Tensor<T>>);
155impl_operations_1d!(SharedMemory<T>, ExpandElementTyped<SharedMemory<T>>);