1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, unexpanded};
5use cubecl_std::tensor::r#virtual::{
6 ReadWrite, VirtualTensor, VirtualTensorOperations, VirtualTensorOperationsExpand,
7};
8
/// Pairs the input and output element types of a reduction so a single
/// generic parameter can carry both through the kernel plumbing.
pub trait ReduceDType {
    /// Element type of the tensor being reduced.
    type In: Numeric;
    /// Element type of the reduction result.
    type Out: Numeric;
}
13
/// Blanket implementation: any `(In, Out)` pair of numeric types is a valid
/// input/output combination (used as `(In, Out)` in `init_tensors`).
impl<In: Numeric, Out: Numeric> ReduceDType for (In, Out) {
    type In = In;
    type Out = Out;
}
18
/// Abstraction over how a reduce kernel receives its input tensor and writes
/// its output. Implementors (e.g. [`TensorArgs`]) choose the launch-argument
/// types and a kernel-side `State` through which all accesses go.
#[cube]
#[allow(dead_code)]
pub trait ReduceArgs: Send + Sync + 'static + Clone {
    /// Launch-argument type of the input tensor.
    type Input<E: Numeric>: LaunchArg + CubeType;

    /// Launch-argument type of the output tensor.
    type Output<E: Numeric>: LaunchArg + CubeType;

    /// Kernel-side handle giving access to both tensors.
    type State<P: ReduceDType>: CubeType;

    /// Build the shared state from the launch arguments.
    fn init_state<P: ReduceDType>(
        input: &Self::Input<P::In>,
        output: &mut Self::Output<P::Out>,
    ) -> Self::State<P>;

    /// Read one line from the input at `index`.
    fn read_input<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::In>;
    /// Read one line from the output at `index`.
    fn read_output<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::Out>;

    /// Write one line to the output at `index`.
    fn write_output<P: ReduceDType>(state: &mut Self::State<P>, index: u32, value: Line<P::Out>);

    /// Logical length (in lines) of the input / output tensor.
    fn len_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
    fn len_output<P: ReduceDType>(state: &Self::State<P>) -> u32;

    /// Length of the underlying buffer for the input / output tensor.
    // NOTE(review): presumably this can exceed the logical `len_*` when the
    // buffer is padded — confirm against `Tensor::buffer_len` semantics.
    fn buffer_len_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
    fn buffer_len_output<P: ReduceDType>(state: &Self::State<P>) -> u32;

    /// Rank (number of dimensions) of the input / output tensor.
    fn rank_input<P: ReduceDType>(state: &Self::State<P>) -> u32;
    fn rank_output<P: ReduceDType>(state: &Self::State<P>) -> u32;

    /// Extent of dimension `dim` of the input / output tensor.
    fn shape_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
    fn shape_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;

    /// Stride of dimension `dim` of the input / output tensor.
    fn stride_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
    fn stride_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32;
}
53
/// Wrap a `ReduceArgs` input/output pair into `VirtualTensor`s so the reduce
/// kernels can be written against a single tensor interface regardless of the
/// concrete `ReduceArgs` implementation.
#[cube]
pub fn init_tensors<RA: ReduceArgs, In: Numeric, Out: Numeric>(
    input: &RA::Input<In>,
    output: &mut RA::Output<Out>,
) -> (VirtualTensor<In>, VirtualTensor<Out, ReadWrite>) {
    let mut state = RA::init_state::<(In, Out)>(input, output);

    // Tag the shared state once per side; the `Input`/`Output` tag decides
    // which half of the state each `TensorArg` forwards to.
    let input = TensorArg::new_input(&state);
    let mut output = TensorArg::new_output(&mut state);

    let input = VirtualTensor::<In>::new::<TensorArg<(In, Out), RA, Input>>(&input);
    let output =
        VirtualTensor::<Out, ReadWrite>::new::<TensorArg<(In, Out), RA, Output>>(&mut output);

    (input, output)
}
70
/// [`ReduceArgs`] implementation backed by plain `Tensor` launch arguments.
///
/// Stateless marker type; derives `Copy` and `Debug` in addition to the
/// `Clone` required by the `ReduceArgs` bound, since it carries no data.
#[derive(Clone, Copy, Debug)]
pub struct TensorArgs;
73
/// Plain-tensor implementation of [`ReduceArgs`]: input and output are
/// ordinary `Tensor<Line<_>>` launch arguments, accessed through raw
/// pointers stored in the state.
#[cube]
impl ReduceArgs for TensorArgs {
    type Input<EG: Numeric> = Tensor<Line<EG>>;
    type Output<EG: Numeric> = Tensor<Line<EG>>;
    // Raw pointers let one state refer to both tensors at once.
    // NOTE(review): soundness relies on the state never outliving the
    // references passed to `init_state` — confirm against `init_tensors`.
    type State<P: ReduceDType> = (*const Tensor<Line<P::In>>, *mut Tensor<Line<P::Out>>);

    fn init_state<P: ReduceDType>(
        input: &Self::Input<P::In>,
        output: &mut Self::Output<P::Out>,
    ) -> Self::State<P> {
        // The references coerce to the raw pointers of the state tuple.
        (input, output)
    }

    fn read_input<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::In> {
        // SAFETY: `state.0` originates from the live reference handed to
        // `init_state` (see the note on `State` above).
        unsafe { (*state.0)[index] }
    }

    fn read_output<P: ReduceDType>(state: &Self::State<P>, index: u32) -> Line<P::Out> {
        // SAFETY: as for `read_input`, for the output pointer.
        unsafe { (*state.1)[index] }
    }

    fn write_output<P: ReduceDType>(state: &mut Self::State<P>, index: u32, value: Line<P::Out>) {
        // SAFETY: as above; `&mut` state gives exclusive access here.
        unsafe { (*state.1)[index] = value }
    }

    // SAFETY (all accessors below): same pointer-validity argument as above.
    fn buffer_len_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.0).buffer_len() }
    }

    fn buffer_len_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.1).buffer_len() }
    }

    fn len_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.0).len() }
    }

    fn len_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.1).len() }
    }

    fn rank_input<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.0).rank() }
    }

    fn rank_output<P: ReduceDType>(state: &Self::State<P>) -> u32 {
        unsafe { (*state.1).rank() }
    }

    fn shape_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
        unsafe { (*state.0).shape(dim) }
    }

    fn shape_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
        unsafe { (*state.1).shape(dim) }
    }

    fn stride_input<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
        unsafe { (*state.0).stride(dim) }
    }

    fn stride_output<P: ReduceDType>(state: &Self::State<P>, dim: u32) -> u32 {
        unsafe { (*state.1).stride(dim) }
    }
}
138
/// Type-level tag selecting the input half of a [`TensorArg`] state.
/// Zero-sized and only used as a phantom parameter, so the basic derives
/// are free and make the tag usable in `Debug` output and copies.
#[derive(Clone, Copy, Debug)]
pub struct Input;
/// Type-level tag selecting the output half of a [`TensorArg`] state.
/// Zero-sized and only used as a phantom parameter, so the basic derives
/// are free and make the tag usable in `Debug` output and copies.
#[derive(Clone, Copy, Debug)]
pub struct Output;
141
/// Kernel-facing placeholder tying a `ReduceArgs` state to one of its two
/// tensors, selected by the `Tag` parameter ([`Input`] or [`Output`]).
///
/// Never constructed at runtime — both constructors call `unexpanded!()`;
/// only the expand counterpart [`TensorArgExpand`] carries real data.
pub struct TensorArg<P: ReduceDType, RA: ReduceArgs, Tag> {
    // Placeholder field; never read (construction always panics via
    // `unexpanded!()`).
    _state: *mut RA::State<P>,
    tag: PhantomData<Tag>,
}
146
/// Expand-time counterpart of [`TensorArg`]: holds the expanded `ReduceArgs`
/// state used to emit tensor operations during kernel expansion.
pub struct TensorArgExpand<P: ReduceDType, RA: ReduceArgs, Tag> {
    state: <RA::State<P> as CubeType>::ExpandType,
    tag: PhantomData<Tag>,
}
151
impl<P: ReduceDType, RA: ReduceArgs> TensorArg<P, RA, Input> {
    /// Kernel-facing constructor; never executed — replaced by
    /// [`Self::__expand_new_input`] during expansion.
    pub fn new_input(_state: &RA::State<P>) -> Self {
        unexpanded!()
    }
    /// Expansion of [`Self::new_input`]: captures the expanded state and tags
    /// it as the input side.
    pub fn __expand_new_input(
        _scope: &mut Scope,
        state: <RA::State<P> as CubeType>::ExpandType,
    ) -> TensorArgExpand<P, RA, Input> {
        TensorArgExpand {
            state,
            tag: PhantomData,
        }
    }
}
166
impl<P: ReduceDType, RA: ReduceArgs> TensorArg<P, RA, Output> {
    /// Kernel-facing constructor; never executed — replaced by
    /// [`Self::__expand_new_output`] during expansion.
    pub fn new_output(_state: &mut RA::State<P>) -> Self {
        unexpanded!()
    }
    /// Expansion of [`Self::new_output`]: captures the expanded state and
    /// tags it as the output side.
    pub fn __expand_new_output(
        _scope: &mut Scope,
        state: <RA::State<P> as CubeType>::ExpandType,
    ) -> TensorArgExpand<P, RA, Output> {
        TensorArgExpand {
            state,
            tag: PhantomData,
        }
    }
}
181
// Marker impls only: the actual behavior lives on the expand types
// (`VirtualTensorOperationsExpand` impls below).
impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperations<P::Out> for TensorArg<P, RA, Output> {}
impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperations<P::In> for TensorArg<P, RA, Input> {}
184
/// Tensor operations for the input side: every supported method forwards to
/// the matching `RA::__expand_*_input` function generated by `#[cube]` from
/// the [`ReduceArgs`] trait.
impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperationsExpand<P::In>
    for TensorArgExpand<P, RA, Input>
{
    fn __expand_read_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<Line<P::In>> {
        RA::__expand_read_input(scope, self.state.clone(), index)
    }

    fn __expand_write_method(
        &self,
        _scope: &mut Scope,
        _index: ExpandElementTyped<u32>,
        _value: ExpandElementTyped<Line<P::In>>,
    ) {
        // The input side is read-only by construction.
        unreachable!("Can't write to input")
    }

    fn __expand_shape_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        RA::__expand_shape_input(scope, self.state.clone(), axis)
    }

    fn __expand_stride_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        RA::__expand_stride_input(scope, self.state.clone(), axis)
    }

    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_rank_input(scope, self.state.clone())
    }
    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_len_input(scope, self.state.clone())
    }
    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_buffer_len_input(scope, self.state.clone())
    }

    fn __expand_read_window_method(
        &self,
        _context: &mut Scope,
        _start: ExpandElementTyped<u32>,
        _end: ExpandElementTyped<u32>,
    ) -> SliceExpand<Line<P::In>, ReadOnly> {
        // Windowed reads are not used by the reduce kernels.
        panic!("Unsupported")
    }

    fn __expand_as_tensor_map_method(
        &self,
        _scope: &mut Scope,
    ) -> ExpandElementTyped<TensorMap<P::In>> {
        // Tensor-map access is not implemented for reduce arguments.
        todo!()
    }
}
247
/// Tensor operations for the output side: each supported method forwards to
/// the matching `RA::__expand_*_output` function generated by `#[cube]` from
/// the [`ReduceArgs`] trait. Unlike the input side, writes are allowed.
impl<P: ReduceDType, RA: ReduceArgs> VirtualTensorOperationsExpand<P::Out>
    for TensorArgExpand<P, RA, Output>
{
    fn __expand_read_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<Line<P::Out>> {
        RA::__expand_read_output(scope, self.state.clone(), index)
    }

    fn __expand_write_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
        value: ExpandElementTyped<Line<P::Out>>,
    ) {
        RA::__expand_write_output(scope, self.state.clone(), index, value)
    }

    fn __expand_shape_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        RA::__expand_shape_output(scope, self.state.clone(), axis)
    }

    fn __expand_stride_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        RA::__expand_stride_output(scope, self.state.clone(), axis)
    }

    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_rank_output(scope, self.state.clone())
    }

    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_len_output(scope, self.state.clone())
    }
    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        RA::__expand_buffer_len_output(scope, self.state.clone())
    }

    fn __expand_read_window_method(
        &self,
        _context: &mut Scope,
        _start: ExpandElementTyped<u32>,
        _end: ExpandElementTyped<u32>,
    ) -> SliceExpand<Line<P::Out>, ReadOnly> {
        // Windowed reads are not used by the reduce kernels.
        panic!("Unsupported")
    }

    fn __expand_as_tensor_map_method(
        &self,
        _scope: &mut Scope,
    ) -> ExpandElementTyped<TensorMap<P::Out>> {
        // Tensor-map access is not implemented for reduce arguments.
        todo!()
    }
}
311
/// Hand-written `CubeType` plumbing for [`TensorArg`] / [`TensorArgExpand`].
/// Implemented manually because derives would add unwanted bounds on the
/// phantom `Tag` (and `P`) parameters.
mod __tensor_arg {
    use super::*;

    impl<P: ReduceDType, RA: ReduceArgs, Tag> CubeType for TensorArg<P, RA, Tag> {
        type ExpandType = TensorArgExpand<P, RA, Tag>;
    }

    impl<P: ReduceDType, RA: ReduceArgs, Tag> IntoMut for TensorArgExpand<P, RA, Tag> {
        // No interior state to convert; the expand value is returned as-is.
        fn into_mut(self, _scope: &mut Scope) -> Self {
            self
        }
    }

    impl<P: ReduceDType, RA: ReduceArgs, Tag> CubeDebug for TensorArgExpand<P, RA, Tag> {}
    // Manual `Clone`: `#[derive(Clone)]` would require `P: Clone` and
    // `Tag: Clone`, which the phantom parameters do not need.
    impl<P: ReduceDType, RA: ReduceArgs, Tag> Clone for TensorArgExpand<P, RA, Tag> {
        fn clone(&self) -> Self {
            Self {
                state: self.state.clone(),
                tag: self.tag,
            }
        }
    }
}