cubecl_std/tensor/view/operations/
array.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, io::read_masked, prelude::barrier::BarrierExpand};
3
4use crate::tensor::{
5 ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand,
6 layout::Coords1d,
7};
8
/// Implements the 1D view operation traits (`ViewOperations`, `ViewOperationsExpand`,
/// `ViewOperationsMut`, `ViewOperationsMutExpand`) with `Coords1d` coordinates.
///
/// * `$ty` - the container type that receives the empty marker impls.
/// * `$expand` - the matching expand type that receives the `__expand_*` methods.
///
/// Every generated impl is generic over `T: CubePrimitive`; the type arguments passed to the
/// macro are expected to mention that `T`. Positions are taken as `ExpandElementTyped<u32>`.
macro_rules! impl_operations_1d {
    ($ty: ty, $expand: ty) => {
        impl<T: CubePrimitive> ViewOperations<T, Coords1d> for $ty {}
        impl<T: CubePrimitive> ViewOperationsExpand<T, Coords1d> for $expand {
            /// Reads the element at `pos` by forwarding to `ListExpand`'s read.
            fn __expand_read_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                <Self as ListExpand<T>>::__expand_read_method(self, scope, pos)
            }

            /// Reads the element at `pos` through `read_masked`, using `pos < buffer_len` as the
            /// mask and a zero cast to `T` as the fallback value.
            fn __expand_read_checked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                let in_range =
                    <Self as ViewOperationsExpand<T, Coords1d>>::__expand_is_in_bounds_method(
                        self,
                        scope,
                        pos.clone(),
                    );
                let source = self.clone().__expand_to_slice_method(scope);
                // Built after the mask and the slice so the expansion order stays fixed.
                let fallback = T::__expand_cast_from(scope, 0.into());
                read_masked::expand::<T>(scope, in_range, source, pos, fallback)
            }

            /// Reads the element at `pos` through `read_masked`, using `pos < buffer_len` as the
            /// mask and the caller-provided `mask_value` as the fallback value.
            fn __expand_read_masked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                mask_value: <T>::ExpandType,
            ) -> <T>::ExpandType {
                let in_range =
                    <Self as ViewOperationsExpand<T, Coords1d>>::__expand_is_in_bounds_method(
                        self,
                        scope,
                        pos.clone(),
                    );
                let source = self.clone().__expand_to_slice_method(scope);
                read_masked::expand::<T>(scope, in_range, source, pos, mask_value)
            }

            /// Reads the element at `pos` by forwarding to `ListExpand`'s unchecked read.
            fn __expand_read_unchecked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> <T>::ExpandType {
                <Self as ListExpand<T>>::__expand_read_unchecked_method(self, scope, pos)
            }

            /// Returns a read-only slice from `pos` up to `end + 1`.
            fn __expand_to_linear_slice_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                end: ExpandElementTyped<u32>,
            ) -> SliceExpand<T, ReadOnly> {
                // NOTE(review): `end` is presumably inclusive and bumped to an exclusive bound
                // here - confirm against the trait documentation.
                let stop = add::expand(scope, end, 1u32.into());
                // The start is clamped so that it never exceeds the stop position.
                let begin = Min::__expand_min(scope, pos, stop.clone());
                <Self as SliceOperatorExpand<T>>::__expand_slice_method(self, scope, begin, stop)
            }

            /// Returns the buffer length as the 1D shape.
            fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
                self.clone().__expand_buffer_len_method(scope)
            }

            /// Returns whether `pos` is strictly below the buffer length.
            fn __expand_is_in_bounds_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
            ) -> ExpandElementTyped<bool> {
                let buffer_len = self.clone().__expand_buffer_len_method(scope);
                lt::expand(scope, pos, buffer_len)
            }

            /// Not supported for this implementation: always panics with "Not a tensor map".
            fn __expand_tensor_map_load_method(
                &self,
                _scope: &mut Scope,
                _barrier: BarrierExpand,
                _shared_memory: SliceExpand<T, ReadWrite>,
                _pos: ExpandElementTyped<u32>,
            ) {
                unimplemented!("Not a tensor map");
            }
        }

        impl<T: CubePrimitive> ViewOperationsMut<T, Coords1d> for $ty {}
        impl<T: CubePrimitive> ViewOperationsMutExpand<T, Coords1d> for $expand {
            /// Writes `value` at `pos` by forwarding to `ListMutExpand`'s write.
            fn __expand_write_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                value: <T>::ExpandType,
            ) {
                <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
            }

            /// Writes `value` at `pos` inside an `if_expand` branch conditioned on
            /// `pos < buffer_len`.
            fn __expand_write_checked_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                value: <T>::ExpandType,
            ) {
                let in_range =
                    <Self as ViewOperationsExpand<T, Coords1d>>::__expand_is_in_bounds_method(
                        self,
                        scope,
                        pos.clone(),
                    );
                if_expand(scope, in_range.into(), |scope| {
                    <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
                })
            }

            /// Returns a read-write slice from `pos` up to `end + 1`.
            fn __expand_to_linear_slice_mut_method(
                &self,
                scope: &mut Scope,
                pos: ExpandElementTyped<u32>,
                end: ExpandElementTyped<u32>,
            ) -> SliceExpand<T, ReadWrite> {
                // NOTE(review): `end` is presumably inclusive and bumped to an exclusive bound
                // here - confirm against the trait documentation.
                let stop = add::expand(scope, end, 1u32.into());
                // The start is clamped so that it never exceeds the stop position.
                let begin = Min::__expand_min(scope, pos, stop.clone());
                <Self as SliceMutOperatorExpand<T>>::__expand_slice_mut_method(
                    self, scope, begin, stop,
                )
            }

            /// Not supported for this implementation: always panics with "Not a tensor map".
            fn __expand_tensor_map_store_method(
                &self,
                _scope: &mut Scope,
                _shared_memory: SliceExpand<T, ReadOnly>,
                _pos: <Coords1d as CubeType>::ExpandType,
            ) {
                unimplemented!("Not a tensor map");
            }
        }
    };
}
142
143impl_operations_1d!(Array<T>, ExpandElementTyped<Array<T>>);
144impl_operations_1d!(Tensor<T>, ExpandElementTyped<Tensor<T>>);
145impl_operations_1d!(SharedMemory<T>, ExpandElementTyped<SharedMemory<T>>);