1use alloc::sync::Arc;
2use core::marker::PhantomData;
3use cubecl::prelude::{CubeType, Scope, *};
4use cubecl_core::{self as cubecl, unexpanded};
5use std::ops::{Deref, DerefMut};
6
7use crate::{
8 CubeOption,
9 tensor::{
10 ViewExpand,
11 layout::{
12 Coordinates, Coords1d, Layout, VirtualLayout, VirtualLayoutExpand, simple::SimpleLayout,
13 },
14 view::View,
15 },
16};
17
#[derive(Clone)]
/// Handle for a tensor of `Line<E>` values whose concrete storage is abstracted away.
///
/// The struct stores no data: it only carries type-level markers for the element type `E` and
/// the access marker `IO` (which defaults to `ReadOnly`). Its expand-time counterpart is
/// [`VirtualTensorExpand`].
pub struct VirtualTensor<E: Numeric, IO = ReadOnly> {
    // Ties the handle to its element type without storing a value.
    _e: PhantomData<E>,
    // Records the access marker at the type level.
    _p: PhantomData<IO>,
}
25
// `VirtualTensor` only contains `PhantomData` fields, so it is declared `Copy`.
impl<E: Numeric, IO: Clone> Copy for VirtualTensor<E, IO> {}
27
#[derive(Clone)]
/// Expand-time representation of [`VirtualTensor`].
///
/// Wraps a shared, type-erased `VirtualTensorOperationsExpand` implementation; the methods in
/// this file forward to it. `IO` is carried only as a type-level marker.
pub struct VirtualTensorExpand<E: Numeric, IO> {
    // Type-erased backing implementation, shared through an `Arc`.
    state: Arc<dyn VirtualTensorOperationsExpand<E>>,
    // Access marker; no value is stored.
    _p: PhantomData<IO>,
}
34
35impl<E: Numeric, IO: Clone> List<Line<E>> for VirtualTensor<E, IO> {
36 fn __expand_read(
37 scope: &mut Scope,
38 this: VirtualTensorExpand<E, IO>,
39 index: <usize as CubeType>::ExpandType,
40 ) -> <Line<E> as CubeType>::ExpandType {
41 this.__expand_read_method(scope, index)
42 }
43}
44
/// `Deref` to `[Line<T>]` for any access marker.
///
/// NOTE(review): presumably this exists so kernel source can use slice-style syntax on a
/// `VirtualTensor` — confirm against the `#[cube]` macro documentation.
impl<T: Numeric, IO: Clone> Deref for VirtualTensor<T, IO> {
    type Target = [Line<T>];

    // The body is only the `unexpanded!()` placeholder; no real slice is produced here.
    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}
52
/// Mutable counterpart of the `Deref` impl, available only for `ReadWrite` tensors.
impl<T: Numeric> DerefMut for VirtualTensor<T, ReadWrite> {
    // The body is only the `unexpanded!()` placeholder; no real slice is produced here.
    fn deref_mut(&mut self) -> &mut Self::Target {
        unexpanded!()
    }
}
58
59impl<E: Numeric, IO: Clone> ListExpand<Line<E>> for VirtualTensorExpand<E, IO> {
60 fn __expand_read_method(
61 &self,
62 scope: &mut Scope,
63 index: <usize as CubeType>::ExpandType,
64 ) -> <Line<E> as CubeType>::ExpandType {
65 self.state.clone().__expand_read_method(scope, index)
66 }
67
68 fn __expand_read_unchecked_method(
69 &self,
70 _scope: &mut Scope,
71 _index: ExpandElementTyped<usize>,
72 ) -> <Line<E> as CubeType>::ExpandType {
73 todo!("VirtualTensor don't support read unchecked yet");
74 }
75
76 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
77 self.state.clone().__expand_len_method(scope)
78 }
79}
80
81impl<E: Numeric, IO: Clone> Lined for VirtualTensor<E, IO> {}
82impl<E: Numeric, IO: Clone> LinedExpand for VirtualTensorExpand<E, IO> {
83 fn line_size(&self) -> LineSize {
84 self.state.clone().line_size()
85 }
86}
87
88impl<E: Numeric, IO: Clone> SliceOperator<Line<E>> for VirtualTensor<E, IO> {}
89impl<E: Numeric, IO: Clone> SliceOperatorExpand<Line<E>> for VirtualTensorExpand<E, IO> {
90 fn __expand_slice_method(
91 &self,
92 scope: &mut Scope,
93 start: ExpandElementTyped<usize>,
94 end: ExpandElementTyped<usize>,
95 ) -> SliceExpand<Line<E>, ReadOnly> {
96 self.state
97 .clone()
98 .__expand_read_window_method(scope, start, end)
99 }
100
101 fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<Line<E>, ReadOnly> {
102 let end = self.clone().__expand_buffer_len_method(scope);
103 self.state
104 .clone()
105 .__expand_read_window_method(scope, 0.into(), end)
106 }
107}
108
109#[allow(unused, clippy::all)]
110impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
111 pub fn as_tensor_map(&self) -> CubeOption<TensorMap<E, Tiled>> {
112 unexpanded!()
113 }
114 pub fn as_slice(&self, start: usize, end: usize) -> Slice<Line<E>> {
115 unexpanded!();
116 }
117 pub fn shape(&self, axis: usize) -> usize {
119 unexpanded!();
120 }
121 pub fn stride(&self, axis: usize) -> usize {
123 unexpanded!();
124 }
125 pub fn rank(&self) -> usize {
127 unexpanded!();
128 }
129
130 pub fn buffer_len(&self) -> usize {
131 unexpanded!();
132 }
133
134 pub fn __expand_as_tensor_map(
135 context: &mut Scope,
136 this: <Self as CubeType>::ExpandType,
137 ) -> <CubeOption<TensorMap<E, Tiled>> as CubeType>::ExpandType {
138 this.__expand_as_tensor_map_method(context)
139 }
140 pub fn __expand_as_slice(
141 context: &mut Scope,
142 this: <Self as CubeType>::ExpandType,
143 start: <usize as CubeType>::ExpandType,
144 end: <usize as CubeType>::ExpandType,
145 ) -> <Slice<Line<E>> as CubeType>::ExpandType {
146 this.__expand_as_slice_method(context, start, end)
147 }
148 pub fn __expand_shape(
149 scope: &mut Scope,
150 this: <Self as CubeType>::ExpandType,
151 axis: <usize as CubeType>::ExpandType,
152 ) -> <usize as CubeType>::ExpandType {
153 this.__expand_shape_method(scope, axis)
154 }
155 pub fn __expand_stride(
156 scope: &mut Scope,
157 this: <Self as CubeType>::ExpandType,
158 axis: <usize as CubeType>::ExpandType,
159 ) -> <usize as CubeType>::ExpandType {
160 this.__expand_stride_method(scope, axis)
161 }
162 pub fn __expand_rank(
163 scope: &mut Scope,
164 this: <Self as CubeType>::ExpandType,
165 ) -> <usize as CubeType>::ExpandType {
166 this.__expand_rank_method(scope)
167 }
168 pub fn __expand_buffer_len(
169 scope: &mut Scope,
170 this: <Self as CubeType>::ExpandType,
171 ) -> <usize as CubeType>::ExpandType {
172 this.__expand_buffer_len_method(scope)
173 }
174}
175
176#[allow(unused, clippy::all)]
177impl<E: Numeric, IO: Clone> VirtualTensorExpand<E, IO> {
178 pub fn __expand_as_tensor_map_method(
179 self,
180 context: &mut Scope,
181 ) -> <CubeOption<TensorMap<E, Tiled>> as CubeType>::ExpandType {
182 self.state.clone().__expand_as_tensor_map_method(context)
183 }
184
185 pub fn __expand_as_slice_method(
186 self,
187 context: &mut Scope,
188 start: <usize as CubeType>::ExpandType,
189 end: <usize as CubeType>::ExpandType,
190 ) -> <Slice<Line<E>> as CubeType>::ExpandType {
191 self.state
192 .clone()
193 .__expand_read_window_method(context, start, end)
194 }
195
196 pub fn __expand_shape_method(
197 self,
198 scope: &mut Scope,
199 axis: <usize as CubeType>::ExpandType,
200 ) -> <usize as CubeType>::ExpandType {
201 let _arg_0 = axis;
202 self.state
203 .clone()
204 .__expand_shape_method(scope, _arg_0.into())
205 }
206
207 pub fn __expand_stride_method(
208 self,
209 scope: &mut Scope,
210 axis: <usize as CubeType>::ExpandType,
211 ) -> <usize as CubeType>::ExpandType {
212 let _arg_0 = axis;
213 self.state
214 .clone()
215 .__expand_stride_method(scope, _arg_0.into())
216 }
217
218 pub fn __expand_rank_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
219 self.state.clone().__expand_rank_method(scope)
220 }
221
222 pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
223 self.state.clone().__expand_buffer_len_method(scope)
224 }
225
226 pub fn __expand_read(
227 scope: &mut Scope,
228 this: Self,
229 index: <usize as CubeType>::ExpandType,
230 ) -> <Line<E> as CubeType>::ExpandType {
231 VirtualTensor::<E, IO>::__expand_read(scope, this, index)
232 }
233
234 pub fn __expand_shape(
235 scope: &mut Scope,
236 this: Self,
237 axis: <usize as CubeType>::ExpandType,
238 ) -> <usize as CubeType>::ExpandType {
239 VirtualTensor::<E, IO>::__expand_shape(scope, this, axis)
240 }
241
242 pub fn __expand_stride(
243 scope: &mut Scope,
244 this: Self,
245 axis: <usize as CubeType>::ExpandType,
246 ) -> <usize as CubeType>::ExpandType {
247 VirtualTensor::<E, IO>::__expand_stride(scope, this, axis)
248 }
249
250 pub fn __expand_rank(scope: &mut Scope, this: Self) -> <usize as CubeType>::ExpandType {
251 VirtualTensor::<E, IO>::__expand_rank(scope, this)
252 }
253}
254
255impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
256 pub fn view<C: Coordinates + 'static>(
259 &self,
260 layout: impl Into<VirtualLayout<C, Coords1d>>,
261 ) -> View<Line<E>, C, ReadOnly> {
262 View::new::<VirtualTensor<E, IO>, Coords1d>(self, layout)
263 }
264}
265
#[cube]
impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
    /// Creates a read-only view over this tensor indexed by a plain `usize`.
    ///
    /// The view uses a `SimpleLayout` built from `self.len() * line_size` and `line_size`.
    pub fn as_view(&self) -> View<Line<E>, usize, ReadOnly> {
        let line_size = self.line_size();
        View::new::<VirtualTensor<E, IO>, usize>(
            self,
            SimpleLayout::new(self.len() * line_size, line_size),
        )
    }
}
277
278impl<E: Numeric, IO: Clone + 'static> VirtualTensorExpand<E, IO> {
279 pub fn __expand_view_method<C: Coordinates + 'static>(
282 &self,
283 scope: &mut Scope,
284 layout: VirtualLayoutExpand<C, Coords1d>,
285 ) -> ViewExpand<Line<E>, C, ReadOnly> {
286 View::__expand_new::<VirtualTensor<E, IO>, Coords1d>(scope, self.clone(), layout)
287 }
288}
289
290impl<E: Numeric> VirtualTensor<E, ReadWrite> {
291 #[doc = " Create a mutable conceptual view over this tensor, allowing for multi-dimensional indexing"]
292 #[doc = " with custom layouts"]
293 pub fn view_mut<C: Coordinates + 'static>(
294 &self,
295 layout: impl Layout<Coordinates = C, SourceCoordinates = Coords1d> + 'static,
296 ) -> View<Line<E>, C, ReadWrite> {
297 let mut this: VirtualTensor<E, ReadWrite> = *self;
298 View::new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(&mut this, layout)
299 }
300 pub fn __expand_view_mut<C: Coordinates + 'static>(
301 scope: &mut Scope,
302 this: VirtualTensorExpand<E, ReadWrite>,
303 layout: VirtualLayoutExpand<C, Coords1d>,
304 ) -> ViewExpand<Line<E>, C, ReadWrite> {
305 this.__expand_view_mut_method::<C>(scope, layout)
306 }
307}
308impl<E: Numeric> VirtualTensorExpand<E, ReadWrite> {
309 pub fn __expand_view_mut_method<C: Coordinates + 'static>(
310 self,
311 scope: &mut Scope,
312 layout: VirtualLayoutExpand<C, Coords1d>,
313 ) -> ViewExpand<Line<E>, C, ReadWrite> {
314 View::__expand_new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(scope, self, layout)
315 }
316}
317
#[cube]
impl<E: Numeric> VirtualTensor<E, ReadWrite> {
    /// Creates a read-write view over this tensor indexed by a plain `usize`.
    ///
    /// The view uses a `SimpleLayout` built from `self.len() * line_size` and `line_size`.
    pub fn as_view_mut(&mut self) -> View<Line<E>, usize, ReadWrite> {
        let line_size = self.line_size();
        View::new_mut::<VirtualTensor<E, ReadWrite>, usize>(
            self,
            SimpleLayout::new(self.len() * line_size, line_size),
        )
    }
}
329
#[cube]
impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
    /// Returns `(index / stride(dim)) % shape(dim)`.
    ///
    /// NOTE(review): presumably `index` is a linear offset and the result is its coordinate
    /// along `dim` — confirm against callers.
    pub fn coordinate(&self, index: usize, dim: usize) -> usize {
        // Number of whole strides of `dim` contained in `index`.
        let num_strides = index / self.stride(dim);
        // Reduce modulo the extent of `dim`.
        num_strides % self.shape(dim)
    }
}
337
338impl<E: Numeric> ListMut<Line<E>> for VirtualTensor<E, ReadWrite> {
339 fn __expand_write(
340 scope: &mut Scope,
341 this: VirtualTensorExpand<E, ReadWrite>,
342 index: <usize as CubeType>::ExpandType,
343 value: <Line<E> as CubeType>::ExpandType,
344 ) -> <() as CubeType>::ExpandType {
345 this.__expand_write_method(scope, index, value)
346 }
347}
348
349impl<E: Numeric> ListMutExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
350 fn __expand_write_method(
351 &self,
352 scope: &mut Scope,
353 index: <usize as CubeType>::ExpandType,
354 value: <Line<E> as CubeType>::ExpandType,
355 ) -> <() as CubeType>::ExpandType {
356 self.state
357 .clone()
358 .__expand_write_method(scope, index, value)
359 }
360}
361
362impl<E: Numeric> SliceMutOperator<Line<E>> for VirtualTensor<E, ReadWrite> {}
363impl<E: Numeric> SliceMutOperatorExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
364 #[allow(unused_variables)]
365 fn __expand_slice_mut_method(
366 &self,
367 scope: &mut Scope,
368 start: ExpandElementTyped<usize>,
369 end: ExpandElementTyped<usize>,
370 ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
371 todo!("VirtualTensor don't support slice mut yet");
372 }
373
374 #[allow(unused_variables)]
375 fn __expand_to_slice_mut_method(
376 &self,
377 scope: &mut Scope,
378 ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
379 todo!("VirtualTensor don't support slice mut yet");
380 }
381}
382
383impl<E: Numeric> VirtualTensor<E, ReadOnly> {
384 pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &V) -> Self {
386 unexpanded!()
387 }
388
389 pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
391 _scope: &mut Scope,
392 v: V::ExpandType,
393 ) -> VirtualTensorExpand<E, ReadOnly> {
394 VirtualTensorExpand {
395 state: Arc::new(v),
396 _p: PhantomData,
397 }
398 }
399}
400
401impl<E: Numeric> VirtualTensor<E, ReadWrite> {
402 pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &mut V) -> Self {
404 unexpanded!()
405 }
406
407 pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
409 _scope: &mut Scope,
410 v: V::ExpandType,
411 ) -> VirtualTensorExpand<E, ReadWrite> {
412 VirtualTensorExpand {
413 state: Arc::new(v),
414 _p: PhantomData,
415 }
416 }
417}
418
#[cube(self_type = "ref", expand_base_traits = "LinedExpand")]
/// Operations a backing store provides so that it can be wrapped in a [`VirtualTensor`].
///
/// Every default body here is the `unexpanded!()` placeholder. The expand-time behaviour
/// lives in the `VirtualTensorOperationsExpand` trait, which this file implements for
/// `Tensor<Line<E>>` and `TensorMap<E, Tiled>`.
pub trait VirtualTensorOperations<E: Numeric>: Lined {
    /// Returns the backing tensor map, if there is one.
    fn as_tensor_map(&self) -> CubeOption<TensorMap<E, Tiled>> {
        unexpanded!()
    }
    /// Reads the line at `_index`.
    fn read(&self, _index: usize) -> Line<E> {
        unexpanded!()
    }
    /// Returns a read-only slice delimited by `_start` and `_end`.
    fn read_window(&self, _start: usize, _end: usize) -> Slice<Line<E>, ReadOnly> {
        unexpanded!()
    }
    /// Writes `_value` at `_index`. Note that the receiver is `&self`, not `&mut self`.
    fn write(&self, _index: usize, _value: Line<E>) {
        unexpanded!()
    }
    /// Shape along `_axis`.
    fn shape(&self, _axis: usize) -> usize {
        unexpanded!()
    }
    /// Stride along `_axis`.
    fn stride(&self, _axis: usize) -> usize {
        unexpanded!()
    }
    /// Rank of the tensor.
    fn rank(&self) -> usize {
        unexpanded!()
    }
    /// Length of the tensor.
    fn len(&self) -> usize {
        unexpanded!()
    }
    /// Length of the underlying buffer.
    fn buffer_len(&self) -> usize {
        unexpanded!()
    }
}
463
// Hand-written `CubeType` plumbing for `VirtualTensor`.
mod __cube_type {
    use super::*;

    // `VirtualTensorExpand` is the expand-time form of `VirtualTensor`.
    impl<E: Numeric, IO: Clone> CubeType for VirtualTensor<E, IO> {
        type ExpandType = VirtualTensorExpand<E, IO>;
    }

    impl<E: Numeric, IO> IntoMut for VirtualTensorExpand<E, IO> {
        // Returned unchanged: the scope is not used and nothing is copied or rebound.
        fn into_mut(self, _scope: &mut Scope) -> Self {
            self
        }
    }

    // Empty impl: only the trait's provided items are used.
    impl<E: Numeric, IO> CubeDebug for VirtualTensorExpand<E, IO> {}
}
480
// Makes a plain `Tensor<Line<E>>` usable as the backing store of a `VirtualTensor`.
mod __tensor {
    use crate::CubeOptionExpand;

    use super::*;

    impl<E: Numeric> VirtualTensorOperations<E> for Tensor<Line<E>> {}
    // NOTE(review): every method below calls through `self.clone()`. Several of the callee
    // names are identical to the trait method being defined, so the owned receiver appears to
    // be what selects the tensor's own method instead of recursing into this impl — confirm
    // before removing any of these clones.
    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<Tensor<Line<E>>> {
        // Reads go through the tensor's unchecked indexing expansion.
        fn __expand_read_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<Line<E>> {
            self.clone().__expand_index_unchecked_method(scope, index)
        }
        // A window is a slice of the tensor delimited by `start` and `end`.
        fn __expand_read_window_method(
            &self,
            context: &mut Scope,
            start: ExpandElementTyped<usize>,
            end: ExpandElementTyped<usize>,
        ) -> SliceExpand<Line<E>, ReadOnly> {
            self.clone().__expand_slice_method(context, start, end)
        }

        // Writes go through the tensor's unchecked index-assign expansion.
        fn __expand_write_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<usize>,
            value: ExpandElementTyped<Line<E>>,
        ) {
            self.clone()
                .__expand_index_assign_unchecked_method(scope, index, value)
        }

        // Shape, stride, rank, len and buffer_len forward to the tensor's same-named methods.
        fn __expand_shape_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            self.clone().__expand_shape_method(scope, axis)
        }

        fn __expand_stride_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            self.clone().__expand_stride_method(scope, axis)
        }

        fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
            self.clone().__expand_rank_method(scope)
        }
        fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
            self.clone().__expand_len_method(scope)
        }
        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
            self.clone().__expand_buffer_len_method(scope)
        }

        // A plain tensor has no tensor map, so the option is always `None`.
        fn __expand_as_tensor_map_method(
            &self,
            scope: &mut Scope,
        ) -> CubeOptionExpand<TensorMap<E, Tiled>> {
            CubeOption::__expand_new_None(scope)
        }
    }
}
549
// Makes a tiled `TensorMap` usable as the backing store of a `VirtualTensor`.
//
// Only `__expand_as_tensor_map_method` is implemented; every other operation is a `todo!()`
// stub and panics if it is reached.
mod __tensor_map {
    use crate::CubeOptionExpand;

    use super::*;

    impl<E: Numeric> VirtualTensorOperations<E> for TensorMap<E, Tiled> {}
    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<TensorMap<E, Tiled>> {
        // Unimplemented: panics via `todo!()`.
        fn __expand_read_method(
            &self,
            _scope: &mut Scope,
            _index: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<Line<E>> {
            todo!()
        }
        // Unimplemented: panics via `todo!()`.
        fn __expand_read_window_method(
            &self,
            _context: &mut Scope,
            _start: ExpandElementTyped<usize>,
            _end: ExpandElementTyped<usize>,
        ) -> SliceExpand<Line<E>, ReadOnly> {
            todo!()
        }

        // Unimplemented: panics via `todo!()`.
        fn __expand_write_method(
            &self,
            _scope: &mut Scope,
            _index: ExpandElementTyped<usize>,
            _value: ExpandElementTyped<Line<E>>,
        ) {
            todo!()
        }

        // Unimplemented: panics via `todo!()`.
        fn __expand_shape_method(
            &self,
            _scope: &mut Scope,
            _axis: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            todo!()
        }

        // Unimplemented: panics via `todo!()`.
        fn __expand_stride_method(
            &self,
            _scope: &mut Scope,
            _axis: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            todo!()
        }

        // Unimplemented: panics via `todo!()`.
        fn __expand_rank_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
            todo!()
        }
        // Unimplemented: panics via `todo!()`.
        fn __expand_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
            todo!()
        }
        // Unimplemented: panics via `todo!()`.
        fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
            todo!()
        }

        // The tensor map itself is handed back, wrapped in `Some`.
        fn __expand_as_tensor_map_method(
            &self,
            scope: &mut Scope,
        ) -> CubeOptionExpand<TensorMap<E, Tiled>> {
            CubeOption::__expand_new_Some(scope, self.clone())
        }
    }
}