1use alloc::sync::Arc;
2use core::marker::PhantomData;
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, unexpanded};
5use std::ops::{Deref, DerefMut};
6
7use crate::tensor::{
8 ViewExpand,
9 layout::{
10 Coordinates, Coords1d, Layout, VirtualLayout, VirtualLayoutExpand, simple::SimpleLayout,
11 },
12 view::View,
13};
14
/// Tensor representation that is decoupled from how the tensor is stored.
///
/// `VirtualTensor` holds no data. Its host-side methods are `unexpanded!()` placeholders;
/// during kernel expansion it is represented by [`VirtualTensorExpand`], which forwards every
/// operation to a type-erased `VirtualTensorOperationsExpand` backend (this module provides
/// backends for `Tensor<Line<E>>` and `TensorMap<E, Tiled>`).
///
/// `IO` is an access-mode marker: `ReadOnly` (the default) or `ReadWrite`. Only the latter
/// gets write support and mutable views.
#[derive(Clone)]
pub struct VirtualTensor<E: Numeric, IO = ReadOnly> {
    // Element-type marker; no element data is stored.
    _e: PhantomData<E>,
    // Access-mode marker (`ReadOnly` / `ReadWrite`).
    _p: PhantomData<IO>,
}

// The struct only contains `PhantomData`, so copying is free. `view_mut` relies on this to
// obtain an owned copy from `&self`.
impl<E: Numeric, IO: Clone> Copy for VirtualTensor<E, IO> {}
24
/// Expand type for [`VirtualTensor`], used while a kernel is being expanded.
///
/// Wraps the concrete backend behind a shared, type-erased
/// `VirtualTensorOperationsExpand` trait object, and every operation is forwarded to it.
/// `IO` is the same access-mode marker as on [`VirtualTensor`].
#[derive(Clone)]
pub struct VirtualTensorExpand<E: Numeric, IO> {
    // Backend implementation (e.g. an expanded `Tensor` or `TensorMap`). Cloning the expand
    // type only bumps this `Arc`'s reference count.
    state: Arc<dyn VirtualTensorOperationsExpand<E>>,
    // Access-mode marker; carries no data.
    _p: PhantomData<IO>,
}
31
32impl<E: Numeric, IO: Clone> List<Line<E>> for VirtualTensor<E, IO> {
33 fn __expand_read(
34 scope: &mut Scope,
35 this: VirtualTensorExpand<E, IO>,
36 index: <usize as CubeType>::ExpandType,
37 ) -> <Line<E> as CubeType>::ExpandType {
38 this.__expand_read_method(scope, index)
39 }
40}
41
/// Exposes a [`VirtualTensor`] as a slice of lines in kernel code (presumably so slice-style
/// syntax type-checks — TODO confirm). The body is an `unexpanded!()` placeholder and is not
/// meant to be called outside kernel expansion.
impl<T: Numeric, IO: Clone> Deref for VirtualTensor<T, IO> {
    type Target = [Line<T>];

    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}

/// Mutable slice access, only implemented for `ReadWrite` virtual tensors. Like `deref`, the
/// body is an `unexpanded!()` placeholder.
impl<T: Numeric> DerefMut for VirtualTensor<T, ReadWrite> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unexpanded!()
    }
}
55
56impl<E: Numeric, IO: Clone> ListExpand<Line<E>> for VirtualTensorExpand<E, IO> {
57 fn __expand_read_method(
58 &self,
59 scope: &mut Scope,
60 index: <usize as CubeType>::ExpandType,
61 ) -> <Line<E> as CubeType>::ExpandType {
62 self.state.clone().__expand_read_method(scope, index)
63 }
64
65 fn __expand_read_unchecked_method(
66 &self,
67 _scope: &mut Scope,
68 _index: ExpandElementTyped<usize>,
69 ) -> <Line<E> as CubeType>::ExpandType {
70 todo!("VirtualTensor don't support read unchecked yet");
71 }
72
73 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
74 self.state.clone().__expand_len_method(scope)
75 }
76}
77
78impl<E: Numeric, IO: Clone> Lined for VirtualTensor<E, IO> {}
79impl<E: Numeric, IO: Clone> LinedExpand for VirtualTensorExpand<E, IO> {
80 fn line_size(&self) -> LineSize {
81 self.state.clone().line_size()
82 }
83}
84
85impl<E: Numeric, IO: Clone> SliceOperator<Line<E>> for VirtualTensor<E, IO> {}
86impl<E: Numeric, IO: Clone> SliceOperatorExpand<Line<E>> for VirtualTensorExpand<E, IO> {
87 fn __expand_slice_method(
88 &self,
89 scope: &mut Scope,
90 start: ExpandElementTyped<usize>,
91 end: ExpandElementTyped<usize>,
92 ) -> SliceExpand<Line<E>, ReadOnly> {
93 self.state
94 .clone()
95 .__expand_read_window_method(scope, start, end)
96 }
97
98 fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<Line<E>, ReadOnly> {
99 let end = self.clone().__expand_buffer_len_method(scope);
100 self.state
101 .clone()
102 .__expand_read_window_method(scope, 0.into(), end)
103 }
104}
105
106#[allow(unused, clippy::all)]
107impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
108 pub fn as_tensor_map(&self) -> Option<TensorMap<E, Tiled>> {
109 unexpanded!()
110 }
111 pub fn as_slice(&self, start: usize, end: usize) -> Slice<Line<E>> {
112 unexpanded!();
113 }
114 pub fn shape(&self, axis: usize) -> usize {
116 unexpanded!();
117 }
118 pub fn stride(&self, axis: usize) -> usize {
120 unexpanded!();
121 }
122 pub fn rank(&self) -> usize {
124 unexpanded!();
125 }
126
127 pub fn buffer_len(&self) -> usize {
128 unexpanded!();
129 }
130
131 pub fn __expand_as_tensor_map(
132 context: &mut Scope,
133 this: <Self as CubeType>::ExpandType,
134 ) -> <Option<TensorMap<E, Tiled>> as CubeType>::ExpandType {
135 this.__expand_as_tensor_map_method(context)
136 }
137 pub fn __expand_as_slice(
138 context: &mut Scope,
139 this: <Self as CubeType>::ExpandType,
140 start: <usize as CubeType>::ExpandType,
141 end: <usize as CubeType>::ExpandType,
142 ) -> <Slice<Line<E>> as CubeType>::ExpandType {
143 this.__expand_as_slice_method(context, start, end)
144 }
145 pub fn __expand_shape(
146 scope: &mut Scope,
147 this: <Self as CubeType>::ExpandType,
148 axis: <usize as CubeType>::ExpandType,
149 ) -> <usize as CubeType>::ExpandType {
150 this.__expand_shape_method(scope, axis)
151 }
152 pub fn __expand_stride(
153 scope: &mut Scope,
154 this: <Self as CubeType>::ExpandType,
155 axis: <usize as CubeType>::ExpandType,
156 ) -> <usize as CubeType>::ExpandType {
157 this.__expand_stride_method(scope, axis)
158 }
159 pub fn __expand_rank(
160 scope: &mut Scope,
161 this: <Self as CubeType>::ExpandType,
162 ) -> <usize as CubeType>::ExpandType {
163 this.__expand_rank_method(scope)
164 }
165 pub fn __expand_buffer_len(
166 scope: &mut Scope,
167 this: <Self as CubeType>::ExpandType,
168 ) -> <usize as CubeType>::ExpandType {
169 this.__expand_buffer_len_method(scope)
170 }
171}
172
173#[allow(unused, clippy::all)]
174impl<E: Numeric, IO: Clone> VirtualTensorExpand<E, IO> {
175 pub fn __expand_as_tensor_map_method(
176 self,
177 context: &mut Scope,
178 ) -> <Option<TensorMap<E, Tiled>> as CubeType>::ExpandType {
179 self.state.clone().__expand_as_tensor_map_method(context)
180 }
181
182 pub fn __expand_as_slice_method(
183 self,
184 context: &mut Scope,
185 start: <usize as CubeType>::ExpandType,
186 end: <usize as CubeType>::ExpandType,
187 ) -> <Slice<Line<E>> as CubeType>::ExpandType {
188 self.state
189 .clone()
190 .__expand_read_window_method(context, start, end)
191 }
192
193 pub fn __expand_shape_method(
194 self,
195 scope: &mut Scope,
196 axis: <usize as CubeType>::ExpandType,
197 ) -> <usize as CubeType>::ExpandType {
198 let _arg_0 = axis;
199 self.state
200 .clone()
201 .__expand_shape_method(scope, _arg_0.into())
202 }
203
204 pub fn __expand_stride_method(
205 self,
206 scope: &mut Scope,
207 axis: <usize as CubeType>::ExpandType,
208 ) -> <usize as CubeType>::ExpandType {
209 let _arg_0 = axis;
210 self.state
211 .clone()
212 .__expand_stride_method(scope, _arg_0.into())
213 }
214
215 pub fn __expand_rank_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
216 self.state.clone().__expand_rank_method(scope)
217 }
218
219 pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
220 self.state.clone().__expand_buffer_len_method(scope)
221 }
222
223 pub fn __expand_read(
224 scope: &mut Scope,
225 this: Self,
226 index: <usize as CubeType>::ExpandType,
227 ) -> <Line<E> as CubeType>::ExpandType {
228 VirtualTensor::<E, IO>::__expand_read(scope, this, index)
229 }
230
231 pub fn __expand_shape(
232 scope: &mut Scope,
233 this: Self,
234 axis: <usize as CubeType>::ExpandType,
235 ) -> <usize as CubeType>::ExpandType {
236 VirtualTensor::<E, IO>::__expand_shape(scope, this, axis)
237 }
238
239 pub fn __expand_stride(
240 scope: &mut Scope,
241 this: Self,
242 axis: <usize as CubeType>::ExpandType,
243 ) -> <usize as CubeType>::ExpandType {
244 VirtualTensor::<E, IO>::__expand_stride(scope, this, axis)
245 }
246
247 pub fn __expand_rank(scope: &mut Scope, this: Self) -> <usize as CubeType>::ExpandType {
248 VirtualTensor::<E, IO>::__expand_rank(scope, this)
249 }
250}
251
252impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
253 pub fn view<C: Coordinates + 'static>(
256 &self,
257 layout: impl Into<VirtualLayout<C, Coords1d>>,
258 ) -> View<Line<E>, C, ReadOnly> {
259 View::new::<VirtualTensor<E, IO>, Coords1d>(self, layout)
260 }
261}
262
#[cube]
impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
    /// Creates a read-only, linearly indexed (`usize`) view over the whole tensor.
    ///
    /// The view uses a [`SimpleLayout`] built from `self.len() * line_size` and the line size.
    /// NOTE(review): this assumes `len()` counts lines, so the product is an element count —
    /// confirm against `SimpleLayout::new`.
    pub fn as_view(&self) -> View<Line<E>, usize, ReadOnly> {
        let line_size = self.line_size();
        View::new::<VirtualTensor<E, IO>, usize>(
            self,
            SimpleLayout::new(self.len() * line_size, line_size),
        )
    }
}
274
275impl<E: Numeric, IO: Clone + 'static> VirtualTensorExpand<E, IO> {
276 pub fn __expand_view_method<C: Coordinates + 'static>(
279 &self,
280 scope: &mut Scope,
281 layout: VirtualLayoutExpand<C, Coords1d>,
282 ) -> ViewExpand<Line<E>, C, ReadOnly> {
283 View::__expand_new::<VirtualTensor<E, IO>, Coords1d>(scope, self.clone(), layout)
284 }
285}
286
287impl<E: Numeric> VirtualTensor<E, ReadWrite> {
288 #[doc = " Create a mutable conceptual view over this tensor, allowing for multi-dimensional indexing"]
289 #[doc = " with custom layouts"]
290 pub fn view_mut<C: Coordinates + 'static>(
291 &self,
292 layout: impl Layout<Coordinates = C, SourceCoordinates = Coords1d> + 'static,
293 ) -> View<Line<E>, C, ReadWrite> {
294 let mut this: VirtualTensor<E, ReadWrite> = *self;
295 View::new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(&mut this, layout)
296 }
297 pub fn __expand_view_mut<C: Coordinates + 'static>(
298 scope: &mut Scope,
299 this: VirtualTensorExpand<E, ReadWrite>,
300 layout: VirtualLayoutExpand<C, Coords1d>,
301 ) -> ViewExpand<Line<E>, C, ReadWrite> {
302 this.__expand_view_mut_method::<C>(scope, layout)
303 }
304}
305impl<E: Numeric> VirtualTensorExpand<E, ReadWrite> {
306 pub fn __expand_view_mut_method<C: Coordinates + 'static>(
307 self,
308 scope: &mut Scope,
309 layout: VirtualLayoutExpand<C, Coords1d>,
310 ) -> ViewExpand<Line<E>, C, ReadWrite> {
311 View::__expand_new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(scope, self, layout)
312 }
313}
314
#[cube]
impl<E: Numeric> VirtualTensor<E, ReadWrite> {
    /// Creates a mutable, linearly indexed (`usize`) view over the whole tensor.
    ///
    /// Mirrors `as_view`: the [`SimpleLayout`] is built from `self.len() * line_size` and the
    /// line size. NOTE(review): assumes `len()` counts lines — confirm.
    pub fn as_view_mut(&mut self) -> View<Line<E>, usize, ReadWrite> {
        let line_size = self.line_size();
        View::new_mut::<VirtualTensor<E, ReadWrite>, usize>(
            self,
            SimpleLayout::new(self.len() * line_size, line_size),
        )
    }
}
326
327#[cube]
328impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
329 pub fn coordinate(&self, index: usize, dim: usize) -> usize {
330 let num_strides = index / self.stride(dim);
331 num_strides % self.shape(dim)
332 }
333}
334
335impl<E: Numeric> ListMut<Line<E>> for VirtualTensor<E, ReadWrite> {
336 fn __expand_write(
337 scope: &mut Scope,
338 this: VirtualTensorExpand<E, ReadWrite>,
339 index: <usize as CubeType>::ExpandType,
340 value: <Line<E> as CubeType>::ExpandType,
341 ) -> <() as CubeType>::ExpandType {
342 this.__expand_write_method(scope, index, value)
343 }
344}
345
346impl<E: Numeric> ListMutExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
347 fn __expand_write_method(
348 &self,
349 scope: &mut Scope,
350 index: <usize as CubeType>::ExpandType,
351 value: <Line<E> as CubeType>::ExpandType,
352 ) -> <() as CubeType>::ExpandType {
353 self.state
354 .clone()
355 .__expand_write_method(scope, index, value)
356 }
357}
358
359impl<E: Numeric> SliceMutOperator<Line<E>> for VirtualTensor<E, ReadWrite> {}
360impl<E: Numeric> SliceMutOperatorExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
361 #[allow(unused_variables)]
362 fn __expand_slice_mut_method(
363 &self,
364 scope: &mut Scope,
365 start: ExpandElementTyped<usize>,
366 end: ExpandElementTyped<usize>,
367 ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
368 todo!("VirtualTensor don't support slice mut yet");
369 }
370
371 #[allow(unused_variables)]
372 fn __expand_to_slice_mut_method(
373 &self,
374 scope: &mut Scope,
375 ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
376 todo!("VirtualTensor don't support slice mut yet");
377 }
378}
379
380impl<E: Numeric> VirtualTensor<E, ReadOnly> {
381 pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &V) -> Self {
383 unexpanded!()
384 }
385
386 pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
388 _scope: &mut Scope,
389 v: V::ExpandType,
390 ) -> VirtualTensorExpand<E, ReadOnly> {
391 VirtualTensorExpand {
392 state: Arc::new(v),
393 _p: PhantomData,
394 }
395 }
396}
397
398impl<E: Numeric> VirtualTensor<E, ReadWrite> {
399 pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &mut V) -> Self {
401 unexpanded!()
402 }
403
404 pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
406 _scope: &mut Scope,
407 v: V::ExpandType,
408 ) -> VirtualTensorExpand<E, ReadWrite> {
409 VirtualTensorExpand {
410 state: Arc::new(v),
411 _p: PhantomData,
412 }
413 }
414}
415
/// Operations a concrete tensor-like type must provide to back a [`VirtualTensor`].
///
/// The `#[cube]` attribute generates the matching `VirtualTensorOperationsExpand` trait.
/// Its `__expand_*_method` functions take `&self` (`self_type = "ref"`) and have
/// `LinedExpand` as a supertrait. The host-side default bodies here are `unexpanded!()`
/// placeholders.
#[cube(self_type = "ref", expand_base_traits = "LinedExpand")]
pub trait VirtualTensorOperations<E: Numeric>: Lined {
    /// Backing `TensorMap`, if the backend is one.
    fn as_tensor_map(&self) -> Option<TensorMap<E, Tiled>> {
        unexpanded!()
    }
    /// Reads the line at `_index`.
    fn read(&self, _index: usize) -> Line<E> {
        unexpanded!()
    }
    /// Read-only slice of lines between `_start` and `_end`.
    fn read_window(&self, _start: usize, _end: usize) -> Slice<Line<E>, ReadOnly> {
        unexpanded!()
    }
    /// Writes `_value` at `_index`. `VirtualTensor` only forwards writes for `ReadWrite`
    /// tensors.
    fn write(&self, _index: usize, _value: Line<E>) {
        unexpanded!()
    }
    /// Size of dimension `_axis`.
    fn shape(&self, _axis: usize) -> usize {
        unexpanded!()
    }
    /// Stride of dimension `_axis`.
    fn stride(&self, _axis: usize) -> usize {
        unexpanded!()
    }
    /// Number of dimensions.
    fn rank(&self) -> usize {
        unexpanded!()
    }
    /// Length of the tensor as reported by the backend (used by `as_view`).
    fn len(&self) -> usize {
        unexpanded!()
    }
    /// Length of the backing buffer as reported by the backend (used by `to_slice`).
    fn buffer_len(&self) -> usize {
        unexpanded!()
    }
}
460
461mod __cube_type {
463 use super::*;
464
465 impl<E: Numeric, IO: Clone> CubeType for VirtualTensor<E, IO> {
466 type ExpandType = VirtualTensorExpand<E, IO>;
467 }
468
469 impl<E: Numeric, IO> IntoMut for VirtualTensorExpand<E, IO> {
470 fn into_mut(self, _scope: &mut Scope) -> Self {
471 self
472 }
473 }
474
475 impl<E: Numeric, IO> CubeDebug for VirtualTensorExpand<E, IO> {}
476}
477
// `Tensor<Line<E>>` as a virtual-tensor backend: each operation forwards to the tensor's own
// expand methods.
mod __tensor {
    use super::*;

    impl<E: Numeric> VirtualTensorOperations<E> for Tensor<Line<E>> {}
    // NOTE(review): several methods call `self.clone().<same method name>(...)`. The by-value
    // receiver presumably makes method resolution pick the tensor's own by-value expand method
    // rather than recursing into the trait method defined here. Confirm before removing any
    // of these clones.
    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<Tensor<Line<E>>> {
        // Reads use the tensor's *unchecked* index, so this path adds no bounds check.
        fn __expand_read_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<Line<E>> {
            self.clone().__expand_index_unchecked_method(scope, index)
        }
        // Read-only window delegated to the tensor's own slicing.
        fn __expand_read_window_method(
            &self,
            context: &mut Scope,
            start: ExpandElementTyped<usize>,
            end: ExpandElementTyped<usize>,
        ) -> SliceExpand<Line<E>, ReadOnly> {
            self.clone().__expand_slice_method(context, start, end)
        }

        // Writes use the tensor's *unchecked* index assignment.
        fn __expand_write_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<usize>,
            value: ExpandElementTyped<Line<E>>,
        ) {
            self.clone()
                .__expand_index_assign_unchecked_method(scope, index, value)
        }

        fn __expand_shape_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            self.clone().__expand_shape_method(scope, axis)
        }

        fn __expand_stride_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            self.clone().__expand_stride_method(scope, axis)
        }

        fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
            self.clone().__expand_rank_method(scope)
        }
        fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
            self.clone().__expand_len_method(scope)
        }
        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
            self.clone().__expand_buffer_len_method(scope)
        }

        // A plain tensor is never a `TensorMap`, so this always yields `None`.
        fn __expand_as_tensor_map_method(
            &self,
            scope: &mut Scope,
        ) -> OptionExpand<TensorMap<E, Tiled>> {
            Option::__expand_new_None(scope)
        }
    }
}
544
545mod __tensor_map {
547 use super::*;
548
549 impl<E: Numeric> VirtualTensorOperations<E> for TensorMap<E, Tiled> {}
550 impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<TensorMap<E, Tiled>> {
551 fn __expand_read_method(
552 &self,
553 _scope: &mut Scope,
554 _index: ExpandElementTyped<usize>,
555 ) -> ExpandElementTyped<Line<E>> {
556 todo!()
557 }
558 fn __expand_read_window_method(
559 &self,
560 _context: &mut Scope,
561 _start: ExpandElementTyped<usize>,
562 _end: ExpandElementTyped<usize>,
563 ) -> SliceExpand<Line<E>, ReadOnly> {
564 todo!()
565 }
566
567 fn __expand_write_method(
568 &self,
569 _scope: &mut Scope,
570 _index: ExpandElementTyped<usize>,
571 _value: ExpandElementTyped<Line<E>>,
572 ) {
573 todo!()
574 }
575
576 fn __expand_shape_method(
577 &self,
578 _scope: &mut Scope,
579 _axis: ExpandElementTyped<usize>,
580 ) -> ExpandElementTyped<usize> {
581 todo!()
582 }
583
584 fn __expand_stride_method(
585 &self,
586 _scope: &mut Scope,
587 _axis: ExpandElementTyped<usize>,
588 ) -> ExpandElementTyped<usize> {
589 todo!()
590 }
591
592 fn __expand_rank_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
593 todo!()
594 }
595 fn __expand_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
596 todo!()
597 }
598 fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
599 todo!()
600 }
601
602 fn __expand_as_tensor_map_method(
603 &self,
604 scope: &mut Scope,
605 ) -> OptionExpand<TensorMap<E, Tiled>> {
606 Option::__expand_new_Some(scope, self.clone())
607 }
608 }
609}