cubecl_reduce/
args.rs

1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, unexpanded};
5use cubecl_std::tensor::r#virtual::{
6    ReadWrite, VirtualTensor, VirtualTensorOperations, VirtualTensorOperationsExpand,
7};
8
9pub trait ReduceDType {
10    type In: Numeric;
11    type Out: Numeric;
12}
13
14impl<In: Numeric, Out: Numeric> ReduceDType for (In, Out) {
15    type In = In;
16    type Out = Out;
17}
18
19#[cube]
20#[allow(dead_code)]
21pub trait ReduceArgs: Send + Sync + 'static + Clone {
22    type Input<E: Numeric>: LaunchArg + CubeType;
23
24    type Output<E: Numeric>: LaunchArg + CubeType;
25
26    type State<P: ReduceDType>: CubeType;
27
28    fn init_state<P: ReduceDType>(
29        input: &Self::Input<P::In>,
30        output: &mut Self::Output<P::Out>,
31    ) -> Self::State<P>;
32
33    fn read_input<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::In>;
34    fn read_output<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::Out>;
35
36    fn write_output<P: ReduceDType>(state: &mut Self::State<P>, index: u32, value: Line<P::Out>);
37
38    fn len_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
39    fn len_output<P: ReduceDType>(state: &Self::State<P>) -> u32;
40
41    fn buffer_len_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
42    fn buffer_len_output<P: ReduceDType>(state: &Self::State<P>) -> u32;
43
44    fn rank_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
45    fn rank_output<P: ReduceDType>(state: &Self::State<P>) -> u32;
46
47    fn shape_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
48    fn shape_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
49
50    fn stride_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
51    fn stride_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
52}
53
54#[cube]
55pub fn init_tensors<RA: ReduceArgs, In: Numeric, Out: Numeric>(
56    input: &RA::Input<In>,
57    output: &mut RA::Output<Out>,
58) -> (VirtualTensor<In>, VirtualTensor<Out, ReadWrite>) {
59    let mut state = RA::init_state::<(In, Out)>(input, output);
60
61    let input = TensorArg::new_input(&state);
62    let mut output = TensorArg::new_output(&mut state);
63
64    let input = VirtualTensor::<In>::new::<TensorArg<(In, Out), RA, Input>>(&input);
65    let output =
66        VirtualTensor::<Out, ReadWrite>::new::<TensorArg<(In, Out), RA, Output>>(&mut output);
67
68    (input, output)
69}
70
71#[derive(Clone)]
72pub struct TensorArgs;
73
74#[cube]
75impl ReduceArgs for TensorArgs {
76    type Input<EG: Numeric> = Tensor<Line<EG>>;
77    type Output<EG: Numeric> = Tensor<Line<EG>>;
78    type State<P: ReduceDType> = (*const Tensor<Line<P::In>>, *mut Tensor<Line<P::Out>>);
79
80    fn init_state<P: ReduceDType>(
81        input: &Self::Input<P::In>,
82        output: &mut Self::Output<P::Out>,
83    ) -> Self::State<P> {
84        (input, output)
85    }
86
87    fn read_input<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::In> {
88        unsafe { (*state.0)[index] }
89    }
90
91    fn read_output<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::Out> {
92        unsafe { (*state.1)[index] }
93    }
94
95    fn write_output<P: ReduceDType>(state: &mut Self::State<P>, index: u32, value: Line<P::Out>) {
96        unsafe { (*state.1)[index] = value }
97    }
98
99    fn buffer_len_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
100        unsafe { (*state.0).buffer_len() }
101    }
102
103    fn buffer_len_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
104        unsafe { (*state.1).buffer_len() }
105    }
106
107    fn len_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
108        unsafe { (*state.0).len() }
109    }
110
111    fn len_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
112        unsafe { (*state.1).len() }
113    }
114    fn rank_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
115        unsafe { (*state.0).rank() }
116    }
117
118    fn rank_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
119        unsafe { (*state.1).rank() }
120    }
121
122    fn shape_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
123        unsafe { (*state.0).shape(dim) }
124    }
125
126    fn shape_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
127        unsafe { (*state.1).shape(dim) }
128    }
129
130    fn stride_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
131        unsafe { (*state.0).stride(dim) }
132    }
133
134    fn stride_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
135        unsafe { (*state.1).stride(dim) }
136    }
137}
138
/// Type-level tag selecting the input side of a [`TensorArg`].
pub struct Input;

/// Type-level tag selecting the output side of a [`TensorArg`].
pub struct Output;
141
142pub struct TensorArg<P: ReduceDType, RA: ReduceArgs, Tag> {
143    _state: *mut RA::State<P>,
144    tag: PhantomData<Tag>,
145}
146
147pub struct TensorArgExpand<P: ReduceDType, RA: ReduceArgs, Tag> {
148    state: <RA::State<P> as CubeType>::ExpandType,
149    tag: PhantomData<Tag>,
150}
151
152impl<P: ReduceDType, RA: ReduceArgs> TensorArg<P, RA, Input> {
153    pub fn new_input(_state: &RA::State<P>) -> Self {
154        unexpanded!()
155    }
156    pub fn __expand_new_input(
157        _scope: &mut Scope,
158        state: <RA::State<P> as CubeType>::ExpandType,
159    ) -> TensorArgExpand<P, RA, Input> {
160        TensorArgExpand {
161            state,
162            tag: PhantomData,
163        }
164    }
165}
166
167impl<P: ReduceDType, RA: ReduceArgs> TensorArg<P, RA, Output> {
168    pub fn new_output(_state: &mut RA::State<P>) -> Self {
169        unexpanded!()
170    }
171    pub fn __expand_new_output(
172        _scope: &mut Scope,
173        state: <RA::State<P> as CubeType>::ExpandType,
174    ) -> TensorArgExpand<P, RA, Output> {
175        TensorArgExpand {
176            state,
177            tag: PhantomData,
178        }
179    }
180}
181
182impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperations<P::Out> for TensorArg<P, RA, Output> {}
183impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperations<P::In> for TensorArg<P, RA, Input> {}
184
185impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperationsExpand<P::In>
186    for TensorArgExpand<P, RA, Input>
187{
188    fn __expand_read_method(
189        &self,
190        scope: &mut Scope,
191        index: ExpandElementTyped<u32>,
192    ) -> ExpandElementTyped<Line<P::In>> {
193        RA::__expand_read_input(scope, self.state.clone(), index)
194    }
195
196    fn __expand_write_method(
197        &self,
198        _scope: &mut Scope,
199        _index: ExpandElementTyped<u32>,
200        _value: ExpandElementTyped<Line<P::In>>,
201    ) {
202        unreachable!("Can't write to input")
203    }
204
205    fn __expand_shape_method(
206        &self,
207        scope: &mut Scope,
208        axis: ExpandElementTyped<u32>,
209    ) -> ExpandElementTyped<u32> {
210        RA::__expand_shape_input(scope, self.state.clone(), axis)
211    }
212
213    fn __expand_stride_method(
214        &self,
215        scope: &mut Scope,
216        axis: ExpandElementTyped<u32>,
217    ) -> ExpandElementTyped<u32> {
218        RA::__expand_stride_input(scope, self.state.clone(), axis)
219    }
220
221    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
222        RA::__expand_rank_input(scope, self.state.clone())
223    }
224    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
225        RA::__expand_len_input(scope, self.state.clone())
226    }
227    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
228        RA::__expand_buffer_len_input(scope, self.state.clone())
229    }
230
231    fn __expand_read_window_method(
232        &self,
233        _context: &mut Scope,
234        _start: ExpandElementTyped<u32>,
235        _end: ExpandElementTyped<u32>,
236    ) -> SliceExpand<Line<P::In>, ReadOnly> {
237        panic!("Unsupported")
238    }
239
240    fn __expand_as_tensor_map_method(
241        &self,
242        _scope: &mut Scope,
243    ) -> ExpandElementTyped<TensorMap<P::In>> {
244        todo!()
245    }
246}
247
248impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperationsExpand<P::Out>
249    for TensorArgExpand<P, RA, Output>
250{
251    fn __expand_read_method(
252        &self,
253        scope: &mut Scope,
254        index: ExpandElementTyped<u32>,
255    ) -> ExpandElementTyped<Line<P::Out>> {
256        RA::__expand_read_output(scope, self.state.clone(), index)
257    }
258
259    fn __expand_write_method(
260        &self,
261        scope: &mut Scope,
262        index: ExpandElementTyped<u32>,
263        value: ExpandElementTyped<Line<P::Out>>,
264    ) {
265        RA::__expand_write_output(scope, self.state.clone(), index, value)
266    }
267
268    fn __expand_shape_method(
269        &self,
270        scope: &mut Scope,
271        axis: ExpandElementTyped<u32>,
272    ) -> ExpandElementTyped<u32> {
273        RA::__expand_shape_output(scope, self.state.clone(), axis)
274    }
275
276    fn __expand_stride_method(
277        &self,
278        scope: &mut Scope,
279        axis: ExpandElementTyped<u32>,
280    ) -> ExpandElementTyped<u32> {
281        RA::__expand_stride_output(scope, self.state.clone(), axis)
282    }
283
284    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
285        RA::__expand_rank_output(scope, self.state.clone())
286    }
287
288    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
289        RA::__expand_len_output(scope, self.state.clone())
290    }
291    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
292        RA::__expand_buffer_len_output(scope, self.state.clone())
293    }
294
295    fn __expand_read_window_method(
296        &self,
297        _context: &mut Scope,
298        _start: ExpandElementTyped<u32>,
299        _end: ExpandElementTyped<u32>,
300    ) -> SliceExpand<Line<P::Out>, ReadOnly> {
301        panic!("Unsupported")
302    }
303
304    fn __expand_as_tensor_map_method(
305        &self,
306        _scope: &mut Scope,
307    ) -> ExpandElementTyped<TensorMap<P::Out>> {
308        todo!()
309    }
310}
311
312mod __tensor_arg {
313    use super::*;
314
315    impl<P: ReduceDType, RA: ReduceArgs, Tag> CubeType for TensorArg<P, RA, Tag> {
316        type ExpandType = TensorArgExpand<P, RA, Tag>;
317    }
318
319    impl<P: ReduceDType, RA: ReduceArgs, Tag> IntoMut for TensorArgExpand<P, RA, Tag> {
320        fn into_mut(self, _scope: &mut Scope) -> Self {
321            self
322        }
323    }
324
325    impl<P: ReduceDType, RA: ReduceArgs, Tag> CubeDebug for TensorArgExpand<P, RA, Tag> {}
326    impl<P: ReduceDType, RA: ReduceArgs, Tag> Clone for TensorArgExpand<P, RA, Tag> {
327        fn clone(&self) -> Self {
328            Self {
329                state: self.state.clone(),
330                tag: self.tag,
331            }
332        }
333    }
334}