1use alloc::sync::Arc;
2use core::marker::PhantomData;
3use cubecl::prelude::{CubeType, Scope, *};
4use cubecl_core::{self as cubecl, unexpanded};
5
6use crate::{
7 CubeOption,
8 tensor::{
9 ViewExpand,
10 layout::{
11 Coordinates, Coords1d, Layout, VirtualLayout, VirtualLayoutExpand, simple::SimpleLayout,
12 },
13 view::View,
14 },
15};
16
17#[derive(Clone)]
19pub struct VirtualTensor<E: Numeric, IO = ReadOnly> {
20 _e: PhantomData<E>,
22 _p: PhantomData<IO>,
23}
24
25impl<E: Numeric, IO: Clone> Copy for VirtualTensor<E, IO> {}
26
27#[derive(Clone)]
29pub struct VirtualTensorExpand<E: Numeric, IO> {
30 state: Arc<dyn VirtualTensorOperationsExpand<E>>,
31 _p: PhantomData<IO>,
32}
33
34impl<E: Numeric, IO: Clone> List<Line<E>> for VirtualTensor<E, IO> {
35 fn __expand_read(
36 scope: &mut Scope,
37 this: VirtualTensorExpand<E, IO>,
38 index: <u32 as CubeType>::ExpandType,
39 ) -> <Line<E> as CubeType>::ExpandType {
40 this.__expand_read_method(scope, index)
41 }
42}
43
44impl<E: Numeric, IO: Clone> ListExpand<Line<E>> for VirtualTensorExpand<E, IO> {
45 fn __expand_read_method(
46 &self,
47 scope: &mut Scope,
48 index: <u32 as CubeType>::ExpandType,
49 ) -> <Line<E> as CubeType>::ExpandType {
50 self.state.clone().__expand_read_method(scope, index)
51 }
52
53 fn __expand_read_unchecked_method(
54 &self,
55 _scope: &mut Scope,
56 _index: ExpandElementTyped<u32>,
57 ) -> <Line<E> as CubeType>::ExpandType {
58 todo!("VirtualTensor don't support read unchecked yet");
59 }
60
61 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
62 self.state.clone().__expand_len_method(scope)
63 }
64}
65
66impl<E: Numeric, IO: Clone> Lined for VirtualTensor<E, IO> {}
67impl<E: Numeric, IO: Clone> LinedExpand for VirtualTensorExpand<E, IO> {
68 fn line_size(&self) -> u32 {
69 self.state.clone().line_size()
70 }
71}
72
73impl<E: Numeric, IO: Clone> SliceOperator<Line<E>> for VirtualTensor<E, IO> {}
74impl<E: Numeric, IO: Clone> SliceOperatorExpand<Line<E>> for VirtualTensorExpand<E, IO> {
75 fn __expand_slice_method(
76 &self,
77 scope: &mut Scope,
78 start: ExpandElementTyped<u32>,
79 end: ExpandElementTyped<u32>,
80 ) -> SliceExpand<Line<E>, ReadOnly> {
81 self.state
82 .clone()
83 .__expand_read_window_method(scope, start, end)
84 }
85
86 fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<Line<E>, ReadOnly> {
87 let end = self.clone().__expand_buffer_len_method(scope);
88 self.state
89 .clone()
90 .__expand_read_window_method(scope, 0.into(), end)
91 }
92}
93
94#[allow(unused, clippy::all)]
95impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
96 pub fn as_tensor_map(&self) -> CubeOption<TensorMap<E>> {
97 unexpanded!()
98 }
99 pub fn as_slice(&self, start: u32, end: u32) -> Slice<Line<E>> {
100 unexpanded!();
101 }
102 pub fn shape(&self, axis: u32) -> u32 {
104 unexpanded!();
105 }
106 pub fn stride(&self, axis: u32) -> u32 {
108 unexpanded!();
109 }
110 pub fn rank(&self) -> u32 {
112 unexpanded!();
113 }
114
115 pub fn buffer_len(&self) -> u32 {
116 unexpanded!();
117 }
118
119 pub fn __expand_as_tensor_map(
120 context: &mut Scope,
121 this: <Self as CubeType>::ExpandType,
122 ) -> <CubeOption<TensorMap<E>> as CubeType>::ExpandType {
123 this.__expand_as_tensor_map_method(context)
124 }
125 pub fn __expand_as_slice(
126 context: &mut Scope,
127 this: <Self as CubeType>::ExpandType,
128 start: <u32 as CubeType>::ExpandType,
129 end: <u32 as CubeType>::ExpandType,
130 ) -> <Slice<Line<E>> as CubeType>::ExpandType {
131 this.__expand_as_slice_method(context, start, end)
132 }
133 pub fn __expand_shape(
134 scope: &mut Scope,
135 this: <Self as CubeType>::ExpandType,
136 axis: <u32 as CubeType>::ExpandType,
137 ) -> <u32 as CubeType>::ExpandType {
138 this.__expand_shape_method(scope, axis)
139 }
140 pub fn __expand_stride(
141 scope: &mut Scope,
142 this: <Self as CubeType>::ExpandType,
143 axis: <u32 as CubeType>::ExpandType,
144 ) -> <u32 as CubeType>::ExpandType {
145 this.__expand_stride_method(scope, axis)
146 }
147 pub fn __expand_rank(
148 scope: &mut Scope,
149 this: <Self as CubeType>::ExpandType,
150 ) -> <u32 as CubeType>::ExpandType {
151 this.__expand_rank_method(scope)
152 }
153 pub fn __expand_buffer_len(
154 scope: &mut Scope,
155 this: <Self as CubeType>::ExpandType,
156 ) -> <u32 as CubeType>::ExpandType {
157 this.__expand_buffer_len_method(scope)
158 }
159}
160
161#[allow(unused, clippy::all)]
162impl<E: Numeric, IO: Clone> VirtualTensorExpand<E, IO> {
163 pub fn __expand_as_tensor_map_method(
164 self,
165 context: &mut Scope,
166 ) -> <CubeOption<TensorMap<E>> as CubeType>::ExpandType {
167 self.state.clone().__expand_as_tensor_map_method(context)
168 }
169
170 pub fn __expand_as_slice_method(
171 self,
172 context: &mut Scope,
173 start: <u32 as CubeType>::ExpandType,
174 end: <u32 as CubeType>::ExpandType,
175 ) -> <Slice<Line<E>> as CubeType>::ExpandType {
176 self.state
177 .clone()
178 .__expand_read_window_method(context, start, end)
179 }
180
181 pub fn __expand_shape_method(
182 self,
183 scope: &mut Scope,
184 axis: <u32 as CubeType>::ExpandType,
185 ) -> <u32 as CubeType>::ExpandType {
186 let _arg_0 = axis;
187 self.state
188 .clone()
189 .__expand_shape_method(scope, _arg_0.into())
190 }
191
192 pub fn __expand_stride_method(
193 self,
194 scope: &mut Scope,
195 axis: <u32 as CubeType>::ExpandType,
196 ) -> <u32 as CubeType>::ExpandType {
197 let _arg_0 = axis;
198 self.state
199 .clone()
200 .__expand_stride_method(scope, _arg_0.into())
201 }
202
203 pub fn __expand_rank_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
204 self.state.clone().__expand_rank_method(scope)
205 }
206
207 pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
208 self.state.clone().__expand_buffer_len_method(scope)
209 }
210
211 pub fn __expand_read(
212 scope: &mut Scope,
213 this: Self,
214 index: <u32 as CubeType>::ExpandType,
215 ) -> <Line<E> as CubeType>::ExpandType {
216 VirtualTensor::<E, IO>::__expand_read(scope, this, index)
217 }
218
219 pub fn __expand_shape(
220 scope: &mut Scope,
221 this: Self,
222 axis: <u32 as CubeType>::ExpandType,
223 ) -> <u32 as CubeType>::ExpandType {
224 VirtualTensor::<E, IO>::__expand_shape(scope, this, axis)
225 }
226
227 pub fn __expand_stride(
228 scope: &mut Scope,
229 this: Self,
230 axis: <u32 as CubeType>::ExpandType,
231 ) -> <u32 as CubeType>::ExpandType {
232 VirtualTensor::<E, IO>::__expand_stride(scope, this, axis)
233 }
234
235 pub fn __expand_rank(scope: &mut Scope, this: Self) -> <u32 as CubeType>::ExpandType {
236 VirtualTensor::<E, IO>::__expand_rank(scope, this)
237 }
238}
239
240impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
241 pub fn view<C: Coordinates + 'static>(
244 &self,
245 layout: impl Into<VirtualLayout<C, Coords1d>>,
246 ) -> View<Line<E>, C, ReadOnly> {
247 View::new::<VirtualTensor<E, IO>, Coords1d>(self, layout)
248 }
249}
250
#[cube]
impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
    /// Creates a read-only view over the whole tensor, indexed by a single `u32`.
    ///
    /// Uses a `SimpleLayout` whose length is `len() * line_size`, together with the tensor's
    /// line size.
    // NOTE(review): the multiplication suggests `len()` counts lines while the layout expects
    // elements — confirm against the backend implementations.
    pub fn as_view(&self) -> View<Line<E>, u32, ReadOnly> {
        let line_size = self.line_size();
        View::new::<VirtualTensor<E, IO>, u32>(
            self,
            SimpleLayout::new(self.len() * line_size, line_size),
        )
    }
}
262
263impl<E: Numeric, IO: Clone + 'static> VirtualTensorExpand<E, IO> {
264 pub fn __expand_view_method<C: Coordinates + 'static>(
267 &self,
268 scope: &mut Scope,
269 layout: VirtualLayoutExpand<C, Coords1d>,
270 ) -> ViewExpand<Line<E>, C, ReadOnly> {
271 View::__expand_new::<VirtualTensor<E, IO>, Coords1d>(scope, self.clone(), layout)
272 }
273}
274
275impl<E: Numeric> VirtualTensor<E, ReadWrite> {
276 #[doc = " Create a mutable conceptual view over this tensor, allowing for multi-dimensional indexing"]
277 #[doc = " with custom layouts"]
278 pub fn view_mut<C: Coordinates + 'static>(
279 &self,
280 layout: impl Layout<Coordinates = C, SourceCoordinates = Coords1d> + 'static,
281 ) -> View<Line<E>, C, ReadWrite> {
282 let mut this: VirtualTensor<E, ReadWrite> = *self;
283 View::new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(&mut this, layout)
284 }
285 pub fn __expand_view_mut<C: Coordinates + 'static>(
286 scope: &mut Scope,
287 this: VirtualTensorExpand<E, ReadWrite>,
288 layout: VirtualLayoutExpand<C, Coords1d>,
289 ) -> ViewExpand<Line<E>, C, ReadWrite> {
290 this.__expand_view_mut_method::<C>(scope, layout)
291 }
292}
293impl<E: Numeric> VirtualTensorExpand<E, ReadWrite> {
294 pub fn __expand_view_mut_method<C: Coordinates + 'static>(
295 self,
296 scope: &mut Scope,
297 layout: VirtualLayoutExpand<C, Coords1d>,
298 ) -> ViewExpand<Line<E>, C, ReadWrite> {
299 View::__expand_new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(scope, self, layout)
300 }
301}
302
#[cube]
impl<E: Numeric> VirtualTensor<E, ReadWrite> {
    /// Creates a read-write view over the whole tensor, indexed by a single `u32`.
    ///
    /// Mirrors `as_view`: the layout is a `SimpleLayout` of length `len() * line_size` with the
    /// tensor's line size.
    // NOTE(review): as with `as_view`, this assumes `len()` counts lines — confirm.
    pub fn as_view_mut(&mut self) -> View<Line<E>, u32, ReadWrite> {
        let line_size = self.line_size();
        View::new_mut::<VirtualTensor<E, ReadWrite>, u32>(
            self,
            SimpleLayout::new(self.len() * line_size, line_size),
        )
    }
}
314
#[cube]
impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
    /// Coordinate along axis `dim` of the linear position `index`.
    ///
    /// Computed as `(index / stride(dim)) % shape(dim)`.
    // NOTE(review): assumes `index` is expressed in the same units as the strides — confirm
    // with callers.
    pub fn coordinate(&self, index: u32, dim: u32) -> u32 {
        let num_strides = index / self.stride(dim);
        num_strides % self.shape(dim)
    }
}
322
323impl<E: Numeric> ListMut<Line<E>> for VirtualTensor<E, ReadWrite> {
324 fn __expand_write(
325 scope: &mut Scope,
326 this: VirtualTensorExpand<E, ReadWrite>,
327 index: <u32 as CubeType>::ExpandType,
328 value: <Line<E> as CubeType>::ExpandType,
329 ) -> <() as CubeType>::ExpandType {
330 this.__expand_write_method(scope, index, value)
331 }
332}
333
334impl<E: Numeric> ListMutExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
335 fn __expand_write_method(
336 &self,
337 scope: &mut Scope,
338 index: <u32 as CubeType>::ExpandType,
339 value: <Line<E> as CubeType>::ExpandType,
340 ) -> <() as CubeType>::ExpandType {
341 self.state
342 .clone()
343 .__expand_write_method(scope, index, value)
344 }
345}
346
347impl<E: Numeric> SliceMutOperator<Line<E>> for VirtualTensor<E, ReadWrite> {}
348impl<E: Numeric> SliceMutOperatorExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
349 #[allow(unused_variables)]
350 fn __expand_slice_mut_method(
351 &self,
352 scope: &mut Scope,
353 start: ExpandElementTyped<u32>,
354 end: ExpandElementTyped<u32>,
355 ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
356 todo!("VirtualTensor don't support slice mut yet");
357 }
358
359 #[allow(unused_variables)]
360 fn __expand_to_slice_mut_method(
361 &self,
362 scope: &mut Scope,
363 ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
364 todo!("VirtualTensor don't support slice mut yet");
365 }
366}
367
368impl<E: Numeric> VirtualTensor<E, ReadOnly> {
369 pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &V) -> Self {
371 unexpanded!()
372 }
373
374 pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
376 _scope: &mut Scope,
377 v: V::ExpandType,
378 ) -> VirtualTensorExpand<E, ReadOnly> {
379 VirtualTensorExpand {
380 state: Arc::new(v),
381 _p: PhantomData,
382 }
383 }
384}
385
386impl<E: Numeric> VirtualTensor<E, ReadWrite> {
387 pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &mut V) -> Self {
389 unexpanded!()
390 }
391
392 pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
394 _scope: &mut Scope,
395 v: V::ExpandType,
396 ) -> VirtualTensorExpand<E, ReadWrite> {
397 VirtualTensorExpand {
398 state: Arc::new(v),
399 _p: PhantomData,
400 }
401 }
402}
403
/// Backend operations that a [`VirtualTensor`] forwards to.
///
/// Implementing this trait lets a storage type be used behind the type-erased `VirtualTensor`
/// handle. The default bodies here are `unexpanded!()` placeholders; the working versions are
/// the implementations of the macro-generated `VirtualTensorOperationsExpand` trait.
#[cube(self_type = "ref", expand_base_traits = "LinedExpand")]
pub trait VirtualTensorOperations<E: Numeric>: Lined {
    /// Returns the backend as a tensor map when it is one, `None` otherwise.
    fn as_tensor_map(&self) -> CubeOption<TensorMap<E>> {
        unexpanded!()
    }
    /// Reads the line at `_index`.
    fn read(&self, _index: u32) -> Line<E> {
        unexpanded!()
    }
    /// Returns a read-only slice between `_start` and `_end`.
    fn read_window(&self, _start: u32, _end: u32) -> Slice<Line<E>, ReadOnly> {
        unexpanded!()
    }
    /// Writes `_value` at `_index`.
    fn write(&self, _index: u32, _value: Line<E>) {
        unexpanded!()
    }
    /// Size along `_axis`.
    fn shape(&self, _axis: u32) -> u32 {
        unexpanded!()
    }
    /// Stride along `_axis`.
    fn stride(&self, _axis: u32) -> u32 {
        unexpanded!()
    }
    /// Number of dimensions.
    fn rank(&self) -> u32 {
        unexpanded!()
    }
    /// Length of the tensor.
    // NOTE(review): whether this counts lines or elements is up to the backend — confirm.
    fn len(&self) -> u32 {
        unexpanded!()
    }
    /// Length of the underlying buffer.
    fn buffer_len(&self) -> u32 {
        unexpanded!()
    }
}
448
449mod __cube_type {
451 use super::*;
452
453 impl<E: Numeric, IO: Clone> CubeType for VirtualTensor<E, IO> {
454 type ExpandType = VirtualTensorExpand<E, IO>;
455 }
456
457 impl<E: Numeric, IO> IntoMut for VirtualTensorExpand<E, IO> {
458 fn into_mut(self, _scope: &mut Scope) -> Self {
459 self
460 }
461 }
462
463 impl<E: Numeric, IO> CubeDebug for VirtualTensorExpand<E, IO> {}
464}
465
/// `VirtualTensorOperations` backend for a plain `Tensor<Line<E>>`.
mod __tensor {
    use crate::CubeOptionExpand;

    use super::*;

    impl<E: Numeric> VirtualTensorOperations<E> for Tensor<Line<E>> {}

    // NOTE(review): every method below calls `self.clone()` before forwarding. An owned
    // receiver presumably makes method lookup pick the tensor's own expand methods rather than
    // recursing into the same-named `&self` methods of this trait — confirm before removing
    // any of these clones.
    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<Tensor<Line<E>>> {
        /// Reads through the tensor's `index_unchecked`.
        fn __expand_read_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<Line<E>> {
            self.clone().__expand_index_unchecked_method(scope, index)
        }
        /// Read-only window, taken with the tensor's `slice`.
        fn __expand_read_window_method(
            &self,
            context: &mut Scope,
            start: ExpandElementTyped<u32>,
            end: ExpandElementTyped<u32>,
        ) -> SliceExpand<Line<E>, ReadOnly> {
            self.clone().__expand_slice_method(context, start, end)
        }

        /// Writes through the tensor's `index_assign_unchecked`.
        fn __expand_write_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<u32>,
            value: ExpandElementTyped<Line<E>>,
        ) {
            self.clone()
                .__expand_index_assign_unchecked_method(scope, index, value)
        }

        /// Forwards to the tensor's `shape`.
        fn __expand_shape_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            self.clone().__expand_shape_method(scope, axis)
        }

        /// Forwards to the tensor's `stride`.
        fn __expand_stride_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            self.clone().__expand_stride_method(scope, axis)
        }

        /// Forwards to the tensor's `rank`.
        fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_rank_method(scope)
        }
        /// Forwards to the tensor's `len`.
        fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_len_method(scope)
        }
        /// Forwards to the tensor's `buffer_len`.
        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_buffer_len_method(scope)
        }

        /// A plain tensor is not a tensor map, so this is always `None`.
        fn __expand_as_tensor_map_method(
            &self,
            scope: &mut Scope,
        ) -> CubeOptionExpand<TensorMap<E>> {
            CubeOption::__expand_new_None(scope)
        }
    }
}
534
535mod __tensor_map {
537 use crate::CubeOptionExpand;
538
539 use super::*;
540
541 impl<E: Numeric> VirtualTensorOperations<E> for TensorMap<E> {}
542 impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<TensorMap<E>> {
543 fn __expand_read_method(
544 &self,
545 _scope: &mut Scope,
546 _index: ExpandElementTyped<u32>,
547 ) -> ExpandElementTyped<Line<E>> {
548 todo!()
549 }
550 fn __expand_read_window_method(
551 &self,
552 _context: &mut Scope,
553 _start: ExpandElementTyped<u32>,
554 _end: ExpandElementTyped<u32>,
555 ) -> SliceExpand<Line<E>, ReadOnly> {
556 todo!()
557 }
558
559 fn __expand_write_method(
560 &self,
561 _scope: &mut Scope,
562 _index: ExpandElementTyped<u32>,
563 _value: ExpandElementTyped<Line<E>>,
564 ) {
565 todo!()
566 }
567
568 fn __expand_shape_method(
569 &self,
570 _scope: &mut Scope,
571 _axis: ExpandElementTyped<u32>,
572 ) -> ExpandElementTyped<u32> {
573 todo!()
574 }
575
576 fn __expand_stride_method(
577 &self,
578 _scope: &mut Scope,
579 _axis: ExpandElementTyped<u32>,
580 ) -> ExpandElementTyped<u32> {
581 todo!()
582 }
583
584 fn __expand_rank_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
585 todo!()
586 }
587 fn __expand_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
588 todo!()
589 }
590 fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
591 todo!()
592 }
593
594 fn __expand_as_tensor_map_method(
595 &self,
596 scope: &mut Scope,
597 ) -> CubeOptionExpand<TensorMap<E>> {
598 CubeOption::__expand_new_Some(scope, self.clone())
599 }
600 }
601}