cubecl_reduce/args.rs

use std::marker::PhantomData;

use cubecl::prelude::*;
use cubecl_core::{self as cubecl, unexpanded};
use cubecl_std::{
    CubeOption, CubeOptionExpand,
    tensor::r#virtual::{VirtualTensor, VirtualTensorOperations, VirtualTensorOperationsExpand},
};

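/// Pairs the numeric element type read from the input with the one written to the output.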
pub trait ReduceDType {
    type In: Numeric;
    type Out: Numeric;
}

impl<In: Numeric, Out: Numeric> ReduceDType for (In, Out) {
    type In = In;
    type Out = Out;
}

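/// Abstraction over how a reduce kernel receives and accesses its tensors: the launchable
/// input/output types plus a state through which the kernel reads, writes, and queries
/// metadata (length, buffer length, rank, shape, stride, and line size) on both sides.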
#[cube]
#[allow(dead_code)]
pub trait ReduceArgs: Send + Sync + 'static + Clone {
    type Input<E: Numeric>: LaunchArg + CubeType;

    type Output<E: Numeric>: LaunchArg + CubeType;

    type State<P: ReduceDType>: CubeType;

    fn init_state<P: ReduceDType>(
        input: &Self::Input<P::In>,
        output: &mut Self::Output<P::Out>,
    ) -> Self::State<P>;

    fn read_input<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::In>;
    fn read_output<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::Out>;

    fn write_output<P: ReduceDType>(state: &mut Self::State<P>, index: u32, value: Line<P::Out>);

    fn len_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
    fn len_output<P: ReduceDType>(state: &Self::State<P>) -> u32;

    fn buffer_len_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
    fn buffer_len_output<P: ReduceDType>(state: &Self::State<P>) -> u32;

    fn rank_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
    fn rank_output<P: ReduceDType>(state: &Self::State<P>) -> u32;

    fn shape_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
    fn shape_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;

    fn stride_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
    fn stride_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;

    fn line_size_input<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(u32);
    fn line_size_output<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(u32);
}

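/// Wraps the launch arguments into a read-only input and a read-write output
/// [`VirtualTensor`], so the rest of the reduce implementation stays agnostic of the
/// concrete [`ReduceArgs`].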
#[cube]
pub fn init_tensors<RA: ReduceArgs, In: Numeric, Out: Numeric>(
    input: &RA::Input<In>,
    output: &mut RA::Output<Out>,
) -> (VirtualTensor<In>, VirtualTensor<Out, ReadWrite>) {
    let mut state = RA::init_state::<(In, Out)>(input, output);

    let input = TensorArg::new_input(&state);
    let mut output = TensorArg::new_output(&mut state);

    let input = VirtualTensor::<In>::new::<TensorArg<(In, Out), RA, Input>>(&input);
    let output =
        VirtualTensor::<Out, ReadWrite>::new::<TensorArg<(In, Out), RA, Output>>(&mut output);

    (input, output)
}

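/// Default [`ReduceArgs`] implementation where the input and output are plain tensors.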
#[derive(Clone)]
pub struct TensorArgs;

#[cube]
impl ReduceArgs for TensorArgs {
    type Input<EG: Numeric> = Tensor<Line<EG>>;
    type Output<EG: Numeric> = Tensor<Line<EG>>;
    type State<P: ReduceDType> = (*const Tensor<Line<P::In>>, *mut Tensor<Line<P::Out>>);

    fn init_state<P: ReduceDType>(
        input: &Self::Input<P::In>,
        output: &mut Self::Output<P::Out>,
    ) -> Self::State<P> {
        (input, output)
    }

    fn read_input<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::In> {
        unsafe { (*state.0)[index] }
    }

    fn read_output<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::Out> {
        unsafe { (*state.1)[index] }
    }

    fn write_output<P: ReduceDType>(state: &mut Self::State<P>, index: u32, value: Line<P::Out>) {
        unsafe { (*state.1)[index] = value }
    }

    fn buffer_len_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.0).buffer_len() }
    }

    fn buffer_len_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.1).buffer_len() }
    }

    fn len_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.0).len() }
    }

    fn len_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.1).len() }
    }

    fn rank_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.0).rank() }
    }

    fn rank_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.1).rank() }
    }

    fn shape_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
        unsafe { (*state.0).shape(dim) }
    }

    fn shape_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
        unsafe { (*state.1).shape(dim) }
    }

    fn stride_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
        unsafe { (*state.0).stride(dim) }
    }

    fn stride_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
        unsafe { (*state.1).stride(dim) }
    }

    fn line_size_input<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(u32) {
        unsafe { (*state.0).line_size() }
    }

    fn line_size_output<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(u32) {
        unsafe { (*state.1).line_size() }
    }
}

/// Tag marking a [`TensorArg`] as the kernel input.
pub struct Input;
/// Tag marking a [`TensorArg`] as the kernel output.
pub struct Output;

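/// Handle over a [`ReduceArgs`] state, tagged as [`Input`] or [`Output`], that exposes
/// the corresponding tensor through the [`VirtualTensorOperations`] interface.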
pub struct TensorArg<P: ReduceDType, RA: ReduceArgs, Tag> {
    _state: *mut RA::State<P>,
    tag: PhantomData<Tag>,
}

pub struct TensorArgExpand<P: ReduceDType, RA: ReduceArgs, Tag> {
    state: <RA::State<P> as CubeType>::ExpandType,
    tag: PhantomData<Tag>,
}

impl<P: ReduceDType, RA: ReduceArgs> TensorArg<P, RA, Input> {
    pub fn new_input(_state: &RA::State<P>) -> Self {
        unexpanded!()
    }
    pub fn __expand_new_input(
        _scope: &mut Scope,
        state: <RA::State<P> as CubeType>::ExpandType,
    ) -> TensorArgExpand<P, RA, Input> {
        TensorArgExpand {
            state,
            tag: PhantomData,
        }
    }
}

impl<P: ReduceDType, RA: ReduceArgs> TensorArg<P, RA, Output> {
    pub fn new_output(_state: &mut RA::State<P>) -> Self {
        unexpanded!()
    }
    pub fn __expand_new_output(
        _scope: &mut Scope,
        state: <RA::State<P> as CubeType>::ExpandType,
    ) -> TensorArgExpand<P, RA, Output> {
        TensorArgExpand {
            state,
            tag: PhantomData,
        }
    }
}

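// Virtual-tensor views over the reduce arguments: the input side is read-only (writes
// are rejected), the output side is read-write. Neither supports windowed reads or
// tensor maps.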
impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperations<P::Out> for TensorArg<P, RA, Output> {}
impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperations<P::In> for TensorArg<P, RA, Input> {}

impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperationsExpand<P::In>
    for TensorArgExpand<P, RA, Input>
{
    fn __expand_read_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<Line<P::In>> {
        RA::__expand_read_input(scope, self.state.clone(), index)
    }

    fn __expand_write_method(
        &self,
        _scope: &mut Scope,
        _index: ExpandElementTyped<u32>,
        _value: ExpandElementTyped<Line<P::In>>,
    ) {
        unreachable!("Can't write to input")
    }

    fn __expand_shape_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        RA::__expand_shape_input(scope, self.state.clone(), axis)
    }

    fn __expand_stride_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        RA::__expand_stride_input(scope, self.state.clone(), axis)
    }

    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_rank_input(scope, self.state.clone())
    }

    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_len_input(scope, self.state.clone())
    }

    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_buffer_len_input(scope, self.state.clone())
    }

    fn __expand_read_window_method(
        &self,
        _context: &mut Scope,
        _start: ExpandElementTyped<u32>,
        _end: ExpandElementTyped<u32>,
    ) -> SliceExpand<Line<P::In>, ReadOnly> {
        panic!("Unsupported")
    }

    fn __expand_as_tensor_map_method(
        &self,
        scope: &mut Scope,
    ) -> CubeOptionExpand<TensorMap<P::In>> {
        CubeOption::__expand_new_None(scope)
    }
}

impl<P: ReduceDType, RA: ReduceArgs> Lined for TensorArg<P, RA, Input> {}
impl<P: ReduceDType, RA: ReduceArgs> LinedExpand for TensorArgExpand<P, RA, Input> {
    fn line_size(&self) -> u32 {
        let mut scope = Scope::root(false);
        RA::__expand_line_size_input(&mut scope, self.state.clone())
    }
}

impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperationsExpand<P::Out>
    for TensorArgExpand<P, RA, Output>
{
    fn __expand_read_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<Line<P::Out>> {
        RA::__expand_read_output(scope, self.state.clone(), index)
    }

    fn __expand_write_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
        value: ExpandElementTyped<Line<P::Out>>,
    ) {
        RA::__expand_write_output(scope, self.state.clone(), index, value)
    }

    fn __expand_shape_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        RA::__expand_shape_output(scope, self.state.clone(), axis)
    }

    fn __expand_stride_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        RA::__expand_stride_output(scope, self.state.clone(), axis)
    }

    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_rank_output(scope, self.state.clone())
    }

    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_len_output(scope, self.state.clone())
    }

    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_buffer_len_output(scope, self.state.clone())
    }

    fn __expand_read_window_method(
        &self,
        _context: &mut Scope,
        _start: ExpandElementTyped<u32>,
        _end: ExpandElementTyped<u32>,
    ) -> SliceExpand<Line<P::Out>, ReadOnly> {
        panic!("Unsupported")
    }

    fn __expand_as_tensor_map_method(
        &self,
        scope: &mut Scope,
    ) -> CubeOptionExpand<TensorMap<P::Out>> {
        CubeOption::__expand_new_None(scope)
    }
}

impl<P: ReduceDType, RA: ReduceArgs> Lined for TensorArg<P, RA, Output> {}
impl<P: ReduceDType, RA: ReduceArgs> LinedExpand for TensorArgExpand<P, RA, Output> {
    fn line_size(&self) -> u32 {
        let mut scope = Scope::root(false);
        RA::__expand_line_size_output(&mut scope, self.state.clone())
    }
}

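// CubeType plumbing for `TensorArg`: expansion type, `IntoMut`, `CubeDebug`, and `Clone`.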
mod __tensor_arg {
    use super::*;

    impl<P: ReduceDType, RA: ReduceArgs, Tag> CubeType for TensorArg<P, RA, Tag> {
        type ExpandType = TensorArgExpand<P, RA, Tag>;
    }

    impl<P: ReduceDType, RA: ReduceArgs, Tag> IntoMut for TensorArgExpand<P, RA, Tag> {
        fn into_mut(self, _scope: &mut Scope) -> Self {
            self
        }
    }

    impl<P: ReduceDType, RA: ReduceArgs, Tag> CubeDebug for TensorArgExpand<P, RA, Tag> {}
    impl<P: ReduceDType, RA: ReduceArgs, Tag> Clone for TensorArgExpand<P, RA, Tag> {
        fn clone(&self) -> Self {
            Self {
                state: self.state.clone(),
                tag: self.tag,
            }
        }
    }
}