Skip to main content

cubecl_std/tensor/layout/
as_dyn.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, unexpanded};
3use variadics_please::all_tuples;
4
5use crate::tensor::{
6    launch::{BufferArg, ViewLayoutLaunchArg},
7    layout::*,
8};
9
/// Coordinates that can be converted to a dynamic sequence of signed coordinates.
/// Can be used to convert any set of coordinates to a comptime-sized sequence for use with TMA.
///
/// Implemented in this module for tuples of 2 to 12 primitive coordinates (each component
/// cast to `i32`), for `Sequence<i32>` (identity) and for `Sequence<u32>` (element-wise cast).
#[cube]
pub trait IntoDyn: Coordinates + LaunchArg {
    // Converts `self` into a `Sequence<i32>` with one entry per coordinate component.
    // The default body is `unexpanded!()`: the method is only meaningful inside `#[cube]`
    // code, where the generated `IntoDynExpand::__expand_into_dyn_method` is what runs.
    fn into_dyn(self) -> Sequence<i32> {
        unexpanded!()
    }
}
18
// Implements `IntoDyn` for a tuple of primitive coordinates.
//
// The plain `IntoDyn` impl is empty and relies on the trait's `unexpanded!()` default. The
// real conversion lives in the expand-side impl: at kernel-expansion time it destructures the
// tuple, casts every component to `i32`, and pushes the casted components into a new
// `Sequence` in tuple order.
macro_rules! impl_tuple {
    ($(($T: ident, $t: ident)),*) => {
        impl<$($T: Coordinates + CubePrimitive + LaunchArg),*> IntoDyn for ($($T),*) {}

        impl<$($T: Coordinates + CubePrimitive + LaunchArg),*> IntoDynExpand for ($(NativeExpand<$T>),*) {
            fn __expand_into_dyn_method(self, scope: &mut Scope) -> SequenceExpand<i32> {
                let mut seq = Sequence::__expand_new(scope);
                let ($($t),*) = self;
                // All casts are emitted first (shadowing the destructured bindings), then
                // every casted value is pushed in the original component order.
                let ($($t),*) = ($(i32::__expand_cast_from(scope, $t)),*);
                $(seq.__expand_push_method(scope, $t);)*
                seq
            }
        }
    };
}

// Arity starts at 2: with a single element, `($($T),*)` would expand to `(T0)`, which is a
// parenthesized type rather than a 1-tuple.
all_tuples!(impl_tuple, 2, 12, T, t);
36
// Identity conversion: a `Sequence<i32>` is already in the dynamic signed form.
#[cube]
impl IntoDyn for Sequence<i32> {
    fn into_dyn(self) -> Sequence<i32> {
        self
    }
}
43
44#[cube]
45impl IntoDyn for Sequence<u32> {
46    fn into_dyn(self) -> Sequence<i32> {
47        let mut seq = Sequence::new();
48        for x in self {
49            seq.push(i32::cast_from(x));
50        }
51        seq
52    }
53}
54
/// Layout adapter that wraps an inner layout `L` and converts its source coordinates into a
/// dynamic `Sequence<i32>` using [`IntoDyn`].
///
/// The logical coordinates (`Coordinates`) are those of `L`; only the source side is
/// converted.
#[derive(CubeType)]
pub struct IntoDynLayout<L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg> {
    // The wrapped layout that every operation is delegated to.
    layout: L,
}
59
60impl<L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg> ViewLayoutLaunchArg
61    for IntoDynLayout<L>
62{
63    type RuntimeArg<R: Runtime> = L::RuntimeArg<R>;
64    type CompilationArg = L::CompilationArg;
65
66    fn register<R: Runtime, B: BufferArg>(
67        arg: Self::RuntimeArg<R>,
68        buffer: &B,
69        ty: Type,
70        launcher: &mut KernelLauncher<R>,
71    ) -> Self::CompilationArg {
72        L::register::<R, B>(arg, buffer, ty, launcher)
73    }
74    fn expand(
75        arg: &Self::CompilationArg,
76        ty: Type,
77        builder: &mut KernelBuilder,
78    ) -> <Self as CubeType>::ExpandType {
79        IntoDynLayoutExpand {
80            layout: L::expand(arg, ty, builder),
81        }
82    }
83    fn expand_output(
84        arg: &Self::CompilationArg,
85        ty: Type,
86        builder: &mut KernelBuilder,
87    ) -> <Self as CubeType>::ExpandType {
88        IntoDynLayoutExpand {
89            layout: L::expand_output(arg, ty, builder),
90        }
91    }
92}
93
/// Layout adapter for an inner layout `L` whose source coordinates are a pair `(P, O)`.
///
/// Each half of the pair is converted independently into a `Sequence<i32>` via [`IntoDyn`].
/// The logical coordinates (`Coordinates`) are those of `L`.
#[derive(CubeType)]
pub struct IntoDyn2Layout<
    L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg,
    P: IntoDyn,
    O: IntoDyn,
> {
    // The wrapped layout; `P` and `O` are fixed by its `SourceCoordinates = (P, O)`.
    layout: L,
}
102
103impl<L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg, P: IntoDyn, O: IntoDyn>
104    ViewLayoutLaunchArg for IntoDyn2Layout<L, P, O>
105{
106    type RuntimeArg<R: Runtime> = L::RuntimeArg<R>;
107    type CompilationArg = L::CompilationArg;
108
109    fn register<R: Runtime, B: BufferArg>(
110        arg: Self::RuntimeArg<R>,
111        buffer: &B,
112        ty: Type,
113        launcher: &mut KernelLauncher<R>,
114    ) -> Self::CompilationArg {
115        L::register::<R, B>(arg, buffer, ty, launcher)
116    }
117    fn expand(
118        arg: &Self::CompilationArg,
119        ty: Type,
120        builder: &mut KernelBuilder,
121    ) -> <Self as CubeType>::ExpandType {
122        IntoDyn2LayoutExpand {
123            layout: L::expand(arg, ty, builder),
124        }
125    }
126    fn expand_output(
127        arg: &Self::CompilationArg,
128        ty: Type,
129        builder: &mut KernelBuilder,
130    ) -> <Self as CubeType>::ExpandType {
131        IntoDyn2LayoutExpand {
132            layout: L::expand_output(arg, ty, builder),
133        }
134    }
135}
136
137impl<L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg> IntoDynLayout<L> {
138    pub fn new(layout: L) -> Self {
139        IntoDynLayout { layout }
140    }
141}
142
143impl<
144    L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg,
145    P: IntoDyn,
146    O: IntoDyn + ViewLayoutLaunchArg,
147> IntoDyn2Layout<L, P, O>
148{
149    pub fn new(layout: L) -> Self {
150        IntoDyn2Layout { layout }
151    }
152}
153
154#[cube]
155impl<L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg> Layout for IntoDynLayout<L> {
156    type Coordinates = L::Coordinates;
157    type SourceCoordinates = Sequence<i32>;
158
159    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
160        let pos = self.layout.to_source_pos(pos);
161        pos.into_dyn()
162    }
163
164    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
165        self.layout.is_in_bounds(pos)
166    }
167
168    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
169        let (pos, in_bounds) = self.layout.to_source_pos_checked(pos);
170        (pos.into_dyn(), in_bounds)
171    }
172
173    fn shape(&self) -> Self::Coordinates {
174        self.layout.shape()
175    }
176}
177
178#[cube]
179impl<
180    L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg,
181    P: IntoDyn,
182    O: IntoDyn + ViewLayoutLaunchArg,
183> Layout for IntoDyn2Layout<L, P, O>
184{
185    type Coordinates = L::Coordinates;
186    type SourceCoordinates = (Sequence<i32>, Sequence<i32>);
187
188    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
189        let pos = self.layout.to_source_pos(pos);
190        (pos.0.into_dyn(), pos.1.into_dyn())
191    }
192
193    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
194        self.layout.is_in_bounds(pos)
195    }
196
197    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
198        let (pos, in_bounds) = self.layout.to_source_pos_checked(pos);
199        ((pos.0.into_dyn(), pos.1.into_dyn()), in_bounds)
200    }
201
202    fn shape(&self) -> Self::Coordinates {
203        self.layout.shape()
204    }
205}