// cubecl_std/tensor/layout/as_dyn.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, unexpanded};
3use variadics_please::all_tuples;
4
5use crate::tensor::layout::*;
6
7/// Coordinates that can be converted to a dynamic sequence of signed coordinates.
8/// Can be used to convert any set of coordinates to a comptime-sized sequence for use with TMA.
9#[cube]
10pub trait IntoDyn: Coordinates + LaunchArg {
11    fn into_dyn(self) -> Sequence<i32> {
12        unexpanded!()
13    }
14}
15
// Expands to the type `$target` and discards `$discard`.
//
// Inside a `$(...)*` repetition, `macro_rules!` needs a metavariable from the repetition to
// drive the expansion. Passing the repeated identifier as `$discard` emits one copy of
// `$target` per repeated element.
macro_rules! as_ty {
    ($target: ident, $discard: ident) => {
        $target
    };
}
21
// Implements `IntoDyn` for a homogeneous tuple whose elements all have type `$elem`.
//
// `$name` provides one identifier per tuple element. Each identifier sizes the tuple (via
// `as_ty!`) and also serves as the binding for that element once it is destructured.
macro_rules! impl_tuple {
    ($elem: ident, $($name: ident),*) => {
        // The trait impl keeps the default (`unexpanded!`) body; the real conversion lives in
        // the expand impl below.
        impl IntoDyn for ($(as_ty!($elem, $name)),*) {}

        impl IntoDynExpand for ($(ExpandElementTyped<as_ty!($elem, $name)>),*) {
            fn __expand_into_dyn_method(self, scope: &mut Scope) -> SequenceExpand<i32> {
                let mut seq = Sequence::__expand_new(scope);
                let ($($name),*) = self;
                // First cast every element to `i32`, in tuple order...
                $(let $name = i32::__expand_cast_from(scope, $name);)*
                // ...then append the casted values to the sequence in that same order.
                $(seq.__expand_push_method(scope, $name);)*
                seq
            }
        }
    };
}
37
38macro_rules! impl_tuples {
39    ($($t: ident),*) => {
40        impl_tuple!(u32, $($t),*);
41        impl_tuple!(i32, $($t),*);
42    };
43}
44
45all_tuples!(impl_tuples, 2, 12, t);
46
47#[cube]
48impl IntoDyn for Sequence<i32> {
49    fn into_dyn(self) -> Sequence<i32> {
50        self
51    }
52}
53
54#[cube]
55impl IntoDyn for Sequence<u32> {
56    fn into_dyn(self) -> Sequence<i32> {
57        let mut seq = Sequence::new();
58        for x in self {
59            seq.push(i32::cast_from(x));
60        }
61        seq
62    }
63}
64
65#[derive(CubeType, CubeLaunch)]
66pub struct IntoDynLayout<L: Layout<SourceCoordinates: IntoDyn> + LaunchArg> {
67    layout: L,
68}
69
70#[derive(CubeType, CubeLaunch)]
71pub struct IntoDyn2Layout<L: Layout<SourceCoordinates = (P, O)> + LaunchArg, P: IntoDyn, O: IntoDyn>
72{
73    layout: L,
74}
75
76impl<L: Layout<SourceCoordinates: IntoDyn> + LaunchArg> IntoDynLayout<L> {
77    pub fn new(layout: L) -> Self {
78        IntoDynLayout { layout }
79    }
80}
81
82impl<L: Layout<SourceCoordinates = (P, O)> + LaunchArg, P: IntoDyn, O: IntoDyn + LaunchArg>
83    IntoDyn2Layout<L, P, O>
84{
85    pub fn new(layout: L) -> Self {
86        IntoDyn2Layout { layout }
87    }
88}
89
90#[cube]
91impl<L: Layout<SourceCoordinates: IntoDyn> + LaunchArg> Layout for IntoDynLayout<L> {
92    type Coordinates = L::Coordinates;
93    type SourceCoordinates = Sequence<i32>;
94
95    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
96        let pos = self.layout.to_source_pos(pos);
97        pos.into_dyn()
98    }
99
100    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
101        self.layout.is_in_bounds(pos)
102    }
103
104    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
105        let (pos, in_bounds) = self.layout.to_source_pos_checked(pos);
106        (pos.into_dyn(), in_bounds)
107    }
108
109    fn shape(&self) -> Self::Coordinates {
110        self.layout.shape()
111    }
112}
113
114#[cube]
115impl<L: Layout<SourceCoordinates = (P, O)> + LaunchArg, P: IntoDyn, O: IntoDyn + LaunchArg> Layout
116    for IntoDyn2Layout<L, P, O>
117{
118    type Coordinates = L::Coordinates;
119    type SourceCoordinates = (Sequence<i32>, Sequence<i32>);
120
121    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
122        let pos = self.layout.to_source_pos(pos);
123        (pos.0.into_dyn(), pos.1.into_dyn())
124    }
125
126    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
127        self.layout.is_in_bounds(pos)
128    }
129
130    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
131        let (pos, in_bounds) = self.layout.to_source_pos_checked(pos);
132        ((pos.0.into_dyn(), pos.1.into_dyn()), in_bounds)
133    }
134
135    fn shape(&self) -> Self::Coordinates {
136        self.layout.shape()
137    }
138}