cubecl_std/tensor/view/operations/virtual_tensor.rs

1use super::*;
2use crate::tensor::{
3    layout::Coords1d,
4    r#virtual::{VirtualTensor, VirtualTensorExpand},
5};
6use cubecl::prelude::*;
7use cubecl_core::{self as cubecl, io::read_masked, prelude::barrier::BarrierExpand};
8
9impl<T: Numeric, IO: Clone> ViewOperations<Line<T>, Coords1d> for VirtualTensor<T, IO> {}
10impl<T: Numeric, IO: Clone> ViewOperationsExpand<Line<T>, Coords1d> for VirtualTensorExpand<T, IO> {
11    fn __expand_read_method(
12        &self,
13        scope: &mut Scope,
14        pos: ExpandElementTyped<u32>,
15    ) -> <Line<T> as CubeType>::ExpandType {
16        <Self as ListExpand<Line<T>>>::__expand_read_method(self, scope, pos)
17    }
18
19    fn __expand_read_checked_method(
20        &self,
21        scope: &mut Scope,
22        pos: ExpandElementTyped<u32>,
23    ) -> <Line<T> as CubeType>::ExpandType {
24        let zero = Line::__expand_cast_from(scope, 0.into());
25        self.__expand_read_masked_method(scope, pos, zero)
26    }
27
28    fn __expand_read_masked_method(
29        &self,
30        scope: &mut Scope,
31        pos: ExpandElementTyped<u32>,
32        mask_value: <Line<T> as CubeType>::ExpandType,
33    ) -> <Line<T> as CubeType>::ExpandType {
34        let in_bounds = self.__expand_is_in_bounds_method(scope, pos.clone());
35        let slice = self.clone().__expand_to_slice_method(scope);
36        read_masked::expand::<Line<T>>(scope, in_bounds, slice, pos, mask_value)
37    }
38
39    fn __expand_read_unchecked_method(
40        &self,
41        scope: &mut Scope,
42        pos: ExpandElementTyped<u32>,
43    ) -> <Line<T> as CubeType>::ExpandType {
44        <Self as ListExpand<Line<T>>>::__expand_read_unchecked_method(self, scope, pos)
45    }
46
47    fn __expand_to_linear_slice_method(
48        &self,
49        scope: &mut Scope,
50        pos: ExpandElementTyped<u32>,
51        end: ExpandElementTyped<u32>,
52    ) -> SliceExpand<Line<T>, ReadOnly> {
53        // Convert to exclusive end
54        let end = add::expand(scope, end, 1u32.into());
55        // Handling for shapes that are 0 in at least one dim, ensures the slice is not
56        // negative length.
57        let start = Min::__expand_min(scope, pos, end.clone());
58        <Self as SliceOperatorExpand<Line<T>>>::__expand_slice_method(self, scope, start, end)
59    }
60
61    fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
62        self.clone().__expand_buffer_len_method(scope)
63    }
64
65    fn __expand_is_in_bounds_method(
66        &self,
67        scope: &mut Scope,
68        pos: ExpandElementTyped<u32>,
69    ) -> ExpandElementTyped<bool> {
70        let len = self.clone().__expand_buffer_len_method(scope);
71        lt::expand(scope, pos, len)
72    }
73
74    fn __expand_tensor_map_load_method(
75        &self,
76        _scope: &mut Scope,
77        _barrier: BarrierExpand,
78        _shared_memory: SliceExpand<Line<T>, ReadWrite>,
79        _pos: ExpandElementTyped<u32>,
80    ) {
81        unimplemented!("Not a tensor map");
82    }
83}
84
85impl<T: Numeric> ViewOperationsMut<Line<T>, Coords1d> for VirtualTensor<T, ReadWrite> {}
86impl<T: Numeric> ViewOperationsMutExpand<Line<T>, Coords1d> for VirtualTensorExpand<T, ReadWrite> {
87    fn __expand_write_method(
88        &self,
89        scope: &mut Scope,
90        pos: ExpandElementTyped<u32>,
91        value: <Line<T> as CubeType>::ExpandType,
92    ) {
93        <Self as ListMutExpand<Line<T>>>::__expand_write_method(self, scope, pos, value)
94    }
95
96    fn __expand_write_checked_method(
97        &self,
98        scope: &mut Scope,
99        pos: ExpandElementTyped<u32>,
100        value: <Line<T> as CubeType>::ExpandType,
101    ) {
102        let len = self.clone().__expand_buffer_len_method(scope);
103        let in_bounds = lt::expand(scope, pos.clone(), len);
104        if_expand(scope, in_bounds.into(), |scope| {
105            <Self as ListMutExpand<Line<T>>>::__expand_write_method(self, scope, pos, value)
106        })
107    }
108
109    fn __expand_to_linear_slice_mut_method(
110        &self,
111        scope: &mut Scope,
112        pos: ExpandElementTyped<u32>,
113        end: ExpandElementTyped<u32>,
114    ) -> SliceExpand<Line<T>, ReadWrite> {
115        // Convert to exclusive end
116        let end = add::expand(scope, end, 1u32.into());
117        // Handling for shapes that are 0 in at least one dim, ensures the slice is not
118        // negative length.
119        let start = Min::__expand_min(scope, pos, end.clone());
120        <Self as SliceMutOperatorExpand<Line<T>>>::__expand_slice_mut_method(
121            self, scope, start, end,
122        )
123    }
124
125    fn __expand_tensor_map_store_method(
126        &self,
127        _scope: &mut Scope,
128        _shared_memory: SliceExpand<Line<T>, ReadOnly>,
129        _pos: <Coords1d as CubeType>::ExpandType,
130    ) {
131        unimplemented!("Not a tensor map");
132    }
133}