1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, unexpanded};
5use cubecl_std::{
6 CubeOption, CubeOptionExpand,
7 tensor::r#virtual::{VirtualTensor, VirtualTensorOperations, VirtualTensorOperationsExpand},
8};
9
10pub trait ReduceDType {
11 type In: Numeric;
12 type Out: Numeric;
13}
14
15impl<In: Numeric, Out: Numeric> ReduceDType for (In, Out) {
16 type In = In;
17 type Out = Out;
18}
19
20#[cube]
21#[allow(dead_code)]
22pub trait ReduceArgs: Send + Sync + 'static + Clone {
23 type Input<E: Numeric>: LaunchArg + CubeType;
24
25 type Output<E: Numeric>: LaunchArg + CubeType;
26
27 type State<P: ReduceDType>: CubeType;
28
29 fn init_state<P: ReduceDType>(
30 input: &Self::Input<P::In>,
31 output: &mut Self::Output<P::Out>,
32 ) -> Self::State<P>;
33
34 fn read_input<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::In>;
35 fn read_output<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::Out>;
36
37 fn write_output<P: ReduceDType>(state: &mut Self::State<P>, index: u32, value: Line<P::Out>);
38
39 fn len_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
40 fn len_output<P: ReduceDType>(state: &Self::State<P>) -> u32;
41
42 fn buffer_len_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
43 fn buffer_len_output<P: ReduceDType>(state: &Self::State<P>) -> u32;
44
45 fn rank_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
46 fn rank_output<P: ReduceDType>(state: &Self::State<P>) -> u32;
47
48 fn shape_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
49 fn shape_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
50
51 fn stride_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
52 fn stride_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
53
54 fn line_size_input<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(u32);
55 fn line_size_output<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(u32);
56}
57
58#[cube]
59pub fn init_tensors<RA: ReduceArgs, In: Numeric, Out: Numeric>(
60 input: &RA::Input<In>,
61 output: &mut RA::Output<Out>,
62) -> (VirtualTensor<In>, VirtualTensor<Out, ReadWrite>) {
63 let mut state = RA::init_state::<(In, Out)>(input, output);
64
65 let input = TensorArg::new_input(&state);
66 let mut output = TensorArg::new_output(&mut state);
67
68 let input = VirtualTensor::<In>::new::<TensorArg<(In, Out), RA, Input>>(&input);
69 let output =
70 VirtualTensor::<Out, ReadWrite>::new::<TensorArg<(In, Out), RA, Output>>(&mut output);
71
72 (input, output)
73}
74
/// [`ReduceArgs`] strategy in which the kernel receives the input and output as plain
/// `Tensor<Line<_>>` launch arguments.
///
/// This is a stateless, zero-sized selector type. It derives the common traits (`Copy`, `Debug`,
/// `Default`) in addition to the `Clone` that `ReduceArgs` requires, so it can be printed and
/// constructed generically.
#[derive(Clone, Copy, Debug, Default)]
pub struct TensorArgs;
77
78#[cube]
79impl ReduceArgs for TensorArgs {
80 type Input<EG: Numeric> = Tensor<Line<EG>>;
81 type Output<EG: Numeric> = Tensor<Line<EG>>;
82 type State<P: ReduceDType> = (*const Tensor<Line<P::In>>, *mut Tensor<Line<P::Out>>);
83
84 fn init_state<P: ReduceDType>(
85 input: &Self::Input<P::In>,
86 output: &mut Self::Output<P::Out>,
87 ) -> Self::State<P> {
88 (input, output)
89 }
90
91 fn read_input<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::In> {
92 unsafe { (*state.0)[index] }
93 }
94
95 fn read_output<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::Out> {
96 unsafe { (*state.1)[index] }
97 }
98
99 fn write_output<P: ReduceDType>(state: &mut Self::State<P>, index: u32, value: Line<P::Out>) {
100 unsafe { (*state.1)[index] = value }
101 }
102
103 fn buffer_len_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
104 unsafe { (*state.0).buffer_len() }
105 }
106
107 fn buffer_len_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
108 unsafe { (*state.1).buffer_len() }
109 }
110
111 fn len_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
112 unsafe { (*state.0).len() }
113 }
114
115 fn len_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
116 unsafe { (*state.1).len() }
117 }
118 fn rank_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
119 unsafe { (*state.0).rank() }
120 }
121
122 fn rank_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
123 unsafe { (*state.1).rank() }
124 }
125
126 fn shape_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
127 unsafe { (*state.0).shape(dim) }
128 }
129
130 fn shape_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
131 unsafe { (*state.1).shape(dim) }
132 }
133
134 fn stride_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
135 unsafe { (*state.0).stride(dim) }
136 }
137
138 fn stride_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
139 unsafe { (*state.1).stride(dim) }
140 }
141
142 fn line_size_input<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(u32) {
143 unsafe { (*state.0).line_size() }
144 }
145
146 fn line_size_output<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(u32) {
147 unsafe { (*state.1).line_size() }
148 }
149}
150
/// Type-level tag selecting the input side of a `TensorArg` / `TensorArgExpand`.
pub struct Input;
/// Type-level tag selecting the output side of a `TensorArg` / `TensorArgExpand`.
pub struct Output;
153
/// Kernel-side handle that exposes one side of a `ReduceArgs` state as a virtual tensor.
///
/// `Tag` is `Input` or `Output`. Values of this type are never built on the host:
/// `new_input` / `new_output` only call `unexpanded!()`. The behavior lives on the expand type,
/// `TensorArgExpand`.
pub struct TensorArg<P: ReduceDType, RA: ReduceArgs, Tag> {
    // Ties the type to `RA::State<P>`; not read anywhere in this file.
    _state: *mut RA::State<P>,
    // Zero-sized marker recording which side (`Input` / `Output`) this handle exposes.
    tag: PhantomData<Tag>,
}
158
/// Expand-time counterpart of `TensorArg` (its `CubeType::ExpandType`).
///
/// It holds the expanded reduce-args state. Each virtual-tensor operation is forwarded to the
/// matching `RA::__expand_*_input` or `RA::__expand_*_output` function, selected by `Tag`.
pub struct TensorArgExpand<P: ReduceDType, RA: ReduceArgs, Tag> {
    // Expanded form of `RA::State<P>`; a clone is passed to every forwarded call.
    state: <RA::State<P> as CubeType>::ExpandType,
    // Zero-sized marker selecting the input or output side.
    tag: PhantomData<Tag>,
}
163
164impl<P: ReduceDType, RA: ReduceArgs> TensorArg<P, RA, Input> {
165 pub fn new_input(_state: &RA::State<P>) -> Self {
166 unexpanded!()
167 }
168 pub fn __expand_new_input(
169 _scope: &mut Scope,
170 state: <RA::State<P> as CubeType>::ExpandType,
171 ) -> TensorArgExpand<P, RA, Input> {
172 TensorArgExpand {
173 state,
174 tag: PhantomData,
175 }
176 }
177}
178
179impl<P: ReduceDType, RA: ReduceArgs> TensorArg<P, RA, Output> {
180 pub fn new_output(_state: &mut RA::State<P>) -> Self {
181 unexpanded!()
182 }
183 pub fn __expand_new_output(
184 _scope: &mut Scope,
185 state: <RA::State<P> as CubeType>::ExpandType,
186 ) -> TensorArgExpand<P, RA, Output> {
187 TensorArgExpand {
188 state,
189 tag: PhantomData,
190 }
191 }
192}
193
194impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperations<P::Out> for TensorArg<P, RA, Output> {}
195impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperations<P::In> for TensorArg<P, RA, Input> {}
196
197impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperationsExpand<P::In>
198 for TensorArgExpand<P, RA, Input>
199{
200 fn __expand_read_method(
201 &self,
202 scope: &mut Scope,
203 index: ExpandElementTyped<u32>,
204 ) -> ExpandElementTyped<Line<P::In>> {
205 RA::__expand_read_input(scope, self.state.clone(), index)
206 }
207
208 fn __expand_write_method(
209 &self,
210 _scope: &mut Scope,
211 _index: ExpandElementTyped<u32>,
212 _value: ExpandElementTyped<Line<P::In>>,
213 ) {
214 unreachable!("Can't write to input")
215 }
216
217 fn __expand_shape_method(
218 &self,
219 scope: &mut Scope,
220 axis: ExpandElementTyped<u32>,
221 ) -> ExpandElementTyped<u32> {
222 RA::__expand_shape_input(scope, self.state.clone(), axis)
223 }
224
225 fn __expand_stride_method(
226 &self,
227 scope: &mut Scope,
228 axis: ExpandElementTyped<u32>,
229 ) -> ExpandElementTyped<u32> {
230 RA::__expand_stride_input(scope, self.state.clone(), axis)
231 }
232
233 fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
234 RA::__expand_rank_input(scope, self.state.clone())
235 }
236 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
237 RA::__expand_len_input(scope, self.state.clone())
238 }
239 fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
240 RA::__expand_buffer_len_input(scope, self.state.clone())
241 }
242
243 fn __expand_read_window_method(
244 &self,
245 _context: &mut Scope,
246 _start: ExpandElementTyped<u32>,
247 _end: ExpandElementTyped<u32>,
248 ) -> SliceExpand<Line<P::In>, ReadOnly> {
249 panic!("Unsupported")
250 }
251
252 fn __expand_as_tensor_map_method(
253 &self,
254 scope: &mut Scope,
255 ) -> CubeOptionExpand<TensorMap<P::In>> {
256 CubeOption::__expand_new_None(scope)
257 }
258}
259
260impl<P: ReduceDType, RA: ReduceArgs> Lined for TensorArg<P, RA, Input> {}
261impl<P: ReduceDType, RA: ReduceArgs> LinedExpand for TensorArgExpand<P, RA, Input> {
262 fn line_size(&self) -> u32 {
263 let mut scope = Scope::root(false);
264 RA::__expand_line_size_input(&mut scope, self.state.clone())
265 }
266}
267
268impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperationsExpand<P::Out>
269 for TensorArgExpand<P, RA, Output>
270{
271 fn __expand_read_method(
272 &self,
273 scope: &mut Scope,
274 index: ExpandElementTyped<u32>,
275 ) -> ExpandElementTyped<Line<P::Out>> {
276 RA::__expand_read_output(scope, self.state.clone(), index)
277 }
278
279 fn __expand_write_method(
280 &self,
281 scope: &mut Scope,
282 index: ExpandElementTyped<u32>,
283 value: ExpandElementTyped<Line<P::Out>>,
284 ) {
285 RA::__expand_write_output(scope, self.state.clone(), index, value)
286 }
287
288 fn __expand_shape_method(
289 &self,
290 scope: &mut Scope,
291 axis: ExpandElementTyped<u32>,
292 ) -> ExpandElementTyped<u32> {
293 RA::__expand_shape_output(scope, self.state.clone(), axis)
294 }
295
296 fn __expand_stride_method(
297 &self,
298 scope: &mut Scope,
299 axis: ExpandElementTyped<u32>,
300 ) -> ExpandElementTyped<u32> {
301 RA::__expand_stride_output(scope, self.state.clone(), axis)
302 }
303
304 fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
305 RA::__expand_rank_output(scope, self.state.clone())
306 }
307
308 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
309 RA::__expand_len_output(scope, self.state.clone())
310 }
311 fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
312 RA::__expand_buffer_len_output(scope, self.state.clone())
313 }
314
315 fn __expand_read_window_method(
316 &self,
317 _context: &mut Scope,
318 _start: ExpandElementTyped<u32>,
319 _end: ExpandElementTyped<u32>,
320 ) -> SliceExpand<Line<P::Out>, ReadOnly> {
321 panic!("Unsupported")
322 }
323
324 fn __expand_as_tensor_map_method(
325 &self,
326 scope: &mut Scope,
327 ) -> CubeOptionExpand<TensorMap<P::Out>> {
328 CubeOption::__expand_new_None(scope)
329 }
330}
331
332impl<P: ReduceDType, RA: ReduceArgs> Lined for TensorArg<P, RA, Output> {}
333impl<P: ReduceDType, RA: ReduceArgs> LinedExpand for TensorArgExpand<P, RA, Output> {
334 fn line_size(&self) -> u32 {
335 let mut scope = Scope::root(false);
336 RA::__expand_line_size_output(&mut scope, self.state.clone())
337 }
338}
339
340mod __tensor_arg {
341 use super::*;
342
343 impl<P: ReduceDType, RA: ReduceArgs, Tag> CubeType for TensorArg<P, RA, Tag> {
344 type ExpandType = TensorArgExpand<P, RA, Tag>;
345 }
346
347 impl<P: ReduceDType, RA: ReduceArgs, Tag> IntoMut for TensorArgExpand<P, RA, Tag> {
348 fn into_mut(self, _scope: &mut Scope) -> Self {
349 self
350 }
351 }
352
353 impl<P: ReduceDType, RA: ReduceArgs, Tag> CubeDebug for TensorArgExpand<P, RA, Tag> {}
354 impl<P: ReduceDType, RA: ReduceArgs, Tag> Clone for TensorArgExpand<P, RA, Tag> {
355 fn clone(&self) -> Self {
356 Self {
357 state: self.state.clone(),
358 tag: self.tag,
359 }
360 }
361 }
362}