cubecl_std/tensor/layout/
as_dyn.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, unexpanded};
3use variadics_please::all_tuples;
4
5use crate::tensor::layout::*;
6
7/// Coordinates that can be converted to a dynamic sequence of signed coordinates.
8/// Can be used to convert any set of coordinates to a comptime-sized sequence for use with TMA.
9#[cube]
10pub trait IntoDyn: Coordinates + LaunchArg {
11    fn into_dyn(self) -> Sequence<i32> {
12        unexpanded!()
13    }
14}
15
16macro_rules! impl_tuple {
17    ($(($T: ident, $t: ident)),*) => {
18        impl<$($T: Coordinates + CubePrimitive + LaunchArg),*> IntoDyn for ($($T),*) {}
19
20        impl<$($T: Coordinates + CubePrimitive + LaunchArg),*> IntoDynExpand for ($(ExpandElementTyped<$T>),*) {
21            fn __expand_into_dyn_method(self, scope: &mut Scope) -> SequenceExpand<i32> {
22                let mut seq = Sequence::__expand_new(scope);
23                let ($($t),*) = self;
24                let ($($t),*) = ($(i32::__expand_cast_from(scope, $t)),*);
25                $(seq.__expand_push_method(scope, $t);)*
26                seq
27            }
28        }
29    };
30}
31
32all_tuples!(impl_tuple, 2, 12, T, t);
33
34#[cube]
35impl IntoDyn for Sequence<i32> {
36    fn into_dyn(self) -> Sequence<i32> {
37        self
38    }
39}
40
41#[cube]
42impl IntoDyn for Sequence<u32> {
43    fn into_dyn(self) -> Sequence<i32> {
44        let mut seq = Sequence::new();
45        for x in self {
46            seq.push(i32::cast_from(x));
47        }
48        seq
49    }
50}
51
52#[derive(CubeType, CubeLaunch)]
53pub struct IntoDynLayout<L: Layout<SourceCoordinates: IntoDyn> + LaunchArg> {
54    layout: L,
55}
56
57#[derive(CubeType, CubeLaunch)]
58pub struct IntoDyn2Layout<L: Layout<SourceCoordinates = (P, O)> + LaunchArg, P: IntoDyn, O: IntoDyn>
59{
60    layout: L,
61}
62
63impl<L: Layout<SourceCoordinates: IntoDyn> + LaunchArg> IntoDynLayout<L> {
64    pub fn new(layout: L) -> Self {
65        IntoDynLayout { layout }
66    }
67}
68
69impl<L: Layout<SourceCoordinates = (P, O)> + LaunchArg, P: IntoDyn, O: IntoDyn + LaunchArg>
70    IntoDyn2Layout<L, P, O>
71{
72    pub fn new(layout: L) -> Self {
73        IntoDyn2Layout { layout }
74    }
75}
76
77#[cube]
78impl<L: Layout<SourceCoordinates: IntoDyn> + LaunchArg> Layout for IntoDynLayout<L> {
79    type Coordinates = L::Coordinates;
80    type SourceCoordinates = Sequence<i32>;
81
82    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
83        let pos = self.layout.to_source_pos(pos);
84        pos.into_dyn()
85    }
86
87    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
88        self.layout.is_in_bounds(pos)
89    }
90
91    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
92        let (pos, in_bounds) = self.layout.to_source_pos_checked(pos);
93        (pos.into_dyn(), in_bounds)
94    }
95
96    fn shape(&self) -> Self::Coordinates {
97        self.layout.shape()
98    }
99}
100
101#[cube]
102impl<L: Layout<SourceCoordinates = (P, O)> + LaunchArg, P: IntoDyn, O: IntoDyn + LaunchArg> Layout
103    for IntoDyn2Layout<L, P, O>
104{
105    type Coordinates = L::Coordinates;
106    type SourceCoordinates = (Sequence<i32>, Sequence<i32>);
107
108    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
109        let pos = self.layout.to_source_pos(pos);
110        (pos.0.into_dyn(), pos.1.into_dyn())
111    }
112
113    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
114        self.layout.is_in_bounds(pos)
115    }
116
117    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
118        let (pos, in_bounds) = self.layout.to_source_pos_checked(pos);
119        ((pos.0.into_dyn(), pos.1.into_dyn()), in_bounds)
120    }
121
122    fn shape(&self) -> Self::Coordinates {
123        self.layout.shape()
124    }
125}