cubecl_std/tensor/view/operations/
virtual_tensor.rs1use super::*;
2use crate::tensor::{
3 layout::Coords1d,
4 r#virtual::{VirtualTensor, VirtualTensorExpand},
5};
6use cubecl::prelude::*;
7use cubecl_core::{self as cubecl, io::read_masked, prelude::barrier::BarrierExpand};
8
9impl<T: Numeric, IO: Clone> ViewOperations<Line<T>, Coords1d> for VirtualTensor<T, IO> {}
10impl<T: Numeric, IO: Clone> ViewOperationsExpand<Line<T>, Coords1d> for VirtualTensorExpand<T, IO> {
11 fn __expand_read_method(
12 &self,
13 scope: &mut Scope,
14 pos: ExpandElementTyped<u32>,
15 ) -> <Line<T> as CubeType>::ExpandType {
16 <Self as ListExpand<Line<T>>>::__expand_read_method(self, scope, pos)
17 }
18
19 fn __expand_read_checked_method(
20 &self,
21 scope: &mut Scope,
22 pos: ExpandElementTyped<u32>,
23 ) -> <Line<T> as CubeType>::ExpandType {
24 let zero = Line::__expand_cast_from(scope, 0.into());
25 self.__expand_read_masked_method(scope, pos, zero)
26 }
27
28 fn __expand_read_masked_method(
29 &self,
30 scope: &mut Scope,
31 pos: ExpandElementTyped<u32>,
32 mask_value: <Line<T> as CubeType>::ExpandType,
33 ) -> <Line<T> as CubeType>::ExpandType {
34 let in_bounds = self.__expand_is_in_bounds_method(scope, pos.clone());
35 let slice = self.clone().__expand_to_slice_method(scope);
36 read_masked::expand::<Line<T>>(scope, in_bounds, slice, pos, mask_value)
37 }
38
39 fn __expand_read_unchecked_method(
40 &self,
41 scope: &mut Scope,
42 pos: ExpandElementTyped<u32>,
43 ) -> <Line<T> as CubeType>::ExpandType {
44 <Self as ListExpand<Line<T>>>::__expand_read_unchecked_method(self, scope, pos)
45 }
46
47 fn __expand_to_linear_slice_method(
48 &self,
49 scope: &mut Scope,
50 pos: ExpandElementTyped<u32>,
51 end: ExpandElementTyped<u32>,
52 ) -> SliceExpand<Line<T>, ReadOnly> {
53 let end = add::expand(scope, end, 1u32.into());
55 let start = Min::__expand_min(scope, pos, end.clone());
58 <Self as SliceOperatorExpand<Line<T>>>::__expand_slice_method(self, scope, start, end)
59 }
60
61 fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
62 self.clone().__expand_buffer_len_method(scope)
63 }
64
65 fn __expand_is_in_bounds_method(
66 &self,
67 scope: &mut Scope,
68 pos: ExpandElementTyped<u32>,
69 ) -> ExpandElementTyped<bool> {
70 let len = self.clone().__expand_buffer_len_method(scope);
71 lt::expand(scope, pos, len)
72 }
73
74 fn __expand_tensor_map_load_method(
75 &self,
76 _scope: &mut Scope,
77 _barrier: BarrierExpand,
78 _shared_memory: SliceExpand<Line<T>, ReadWrite>,
79 _pos: ExpandElementTyped<u32>,
80 ) {
81 unimplemented!("Not a tensor map");
82 }
83}
84
85impl<T: Numeric> ViewOperationsMut<Line<T>, Coords1d> for VirtualTensor<T, ReadWrite> {}
86impl<T: Numeric> ViewOperationsMutExpand<Line<T>, Coords1d> for VirtualTensorExpand<T, ReadWrite> {
87 fn __expand_write_method(
88 &self,
89 scope: &mut Scope,
90 pos: ExpandElementTyped<u32>,
91 value: <Line<T> as CubeType>::ExpandType,
92 ) {
93 <Self as ListMutExpand<Line<T>>>::__expand_write_method(self, scope, pos, value)
94 }
95
96 fn __expand_write_checked_method(
97 &self,
98 scope: &mut Scope,
99 pos: ExpandElementTyped<u32>,
100 value: <Line<T> as CubeType>::ExpandType,
101 ) {
102 let len = self.clone().__expand_buffer_len_method(scope);
103 let in_bounds = lt::expand(scope, pos.clone(), len);
104 if_expand(scope, in_bounds.into(), |scope| {
105 <Self as ListMutExpand<Line<T>>>::__expand_write_method(self, scope, pos, value)
106 })
107 }
108
109 fn __expand_to_linear_slice_mut_method(
110 &self,
111 scope: &mut Scope,
112 pos: ExpandElementTyped<u32>,
113 end: ExpandElementTyped<u32>,
114 ) -> SliceExpand<Line<T>, ReadWrite> {
115 let end = add::expand(scope, end, 1u32.into());
117 let start = Min::__expand_min(scope, pos, end.clone());
120 <Self as SliceMutOperatorExpand<Line<T>>>::__expand_slice_mut_method(
121 self, scope, start, end,
122 )
123 }
124
125 fn __expand_tensor_map_store_method(
126 &self,
127 _scope: &mut Scope,
128 _shared_memory: SliceExpand<Line<T>, ReadOnly>,
129 _pos: <Coords1d as CubeType>::ExpandType,
130 ) {
131 unimplemented!("Not a tensor map");
132 }
133}