cubecl_std/tensor/layout/
as_dyn.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, unexpanded};
3use variadics_please::all_tuples;
4
5use crate::tensor::{
6 launch::{BufferArg, ViewLayoutLaunchArg},
7 layout::*,
8};
9
10#[cube]
13pub trait IntoDyn: Coordinates + LaunchArg {
14 fn into_dyn(self) -> Sequence<i32> {
15 unexpanded!()
16 }
17}
18
19macro_rules! impl_tuple {
20 ($(($T: ident, $t: ident)),*) => {
21 impl<$($T: Coordinates + CubePrimitive + LaunchArg),*> IntoDyn for ($($T),*) {}
22
23 impl<$($T: Coordinates + CubePrimitive + LaunchArg),*> IntoDynExpand for ($(NativeExpand<$T>),*) {
24 fn __expand_into_dyn_method(self, scope: &mut Scope) -> SequenceExpand<i32> {
25 let mut seq = Sequence::__expand_new(scope);
26 let ($($t),*) = self;
27 let ($($t),*) = ($(i32::__expand_cast_from(scope, $t)),*);
28 $(seq.__expand_push_method(scope, $t);)*
29 seq
30 }
31 }
32 };
33}
34
35all_tuples!(impl_tuple, 2, 12, T, t);
36
37#[cube]
38impl IntoDyn for Sequence<i32> {
39 fn into_dyn(self) -> Sequence<i32> {
40 self
41 }
42}
43
44#[cube]
45impl IntoDyn for Sequence<u32> {
46 fn into_dyn(self) -> Sequence<i32> {
47 let mut seq = Sequence::new();
48 for x in self {
49 seq.push(i32::cast_from(x));
50 }
51 seq
52 }
53}
54
55#[derive(CubeType)]
56pub struct IntoDynLayout<L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg> {
57 layout: L,
58}
59
60impl<L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg> ViewLayoutLaunchArg
61 for IntoDynLayout<L>
62{
63 type RuntimeArg<R: Runtime> = L::RuntimeArg<R>;
64 type CompilationArg = L::CompilationArg;
65
66 fn register<R: Runtime, B: BufferArg>(
67 arg: Self::RuntimeArg<R>,
68 buffer: &B,
69 ty: Type,
70 launcher: &mut KernelLauncher<R>,
71 ) -> Self::CompilationArg {
72 L::register::<R, B>(arg, buffer, ty, launcher)
73 }
74 fn expand(
75 arg: &Self::CompilationArg,
76 ty: Type,
77 builder: &mut KernelBuilder,
78 ) -> <Self as CubeType>::ExpandType {
79 IntoDynLayoutExpand {
80 layout: L::expand(arg, ty, builder),
81 }
82 }
83 fn expand_output(
84 arg: &Self::CompilationArg,
85 ty: Type,
86 builder: &mut KernelBuilder,
87 ) -> <Self as CubeType>::ExpandType {
88 IntoDynLayoutExpand {
89 layout: L::expand_output(arg, ty, builder),
90 }
91 }
92}
93
94#[derive(CubeType)]
95pub struct IntoDyn2Layout<
96 L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg,
97 P: IntoDyn,
98 O: IntoDyn,
99> {
100 layout: L,
101}
102
103impl<L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg, P: IntoDyn, O: IntoDyn>
104 ViewLayoutLaunchArg for IntoDyn2Layout<L, P, O>
105{
106 type RuntimeArg<R: Runtime> = L::RuntimeArg<R>;
107 type CompilationArg = L::CompilationArg;
108
109 fn register<R: Runtime, B: BufferArg>(
110 arg: Self::RuntimeArg<R>,
111 buffer: &B,
112 ty: Type,
113 launcher: &mut KernelLauncher<R>,
114 ) -> Self::CompilationArg {
115 L::register::<R, B>(arg, buffer, ty, launcher)
116 }
117 fn expand(
118 arg: &Self::CompilationArg,
119 ty: Type,
120 builder: &mut KernelBuilder,
121 ) -> <Self as CubeType>::ExpandType {
122 IntoDyn2LayoutExpand {
123 layout: L::expand(arg, ty, builder),
124 }
125 }
126 fn expand_output(
127 arg: &Self::CompilationArg,
128 ty: Type,
129 builder: &mut KernelBuilder,
130 ) -> <Self as CubeType>::ExpandType {
131 IntoDyn2LayoutExpand {
132 layout: L::expand_output(arg, ty, builder),
133 }
134 }
135}
136
137impl<L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg> IntoDynLayout<L> {
138 pub fn new(layout: L) -> Self {
139 IntoDynLayout { layout }
140 }
141}
142
143impl<
144 L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg,
145 P: IntoDyn,
146 O: IntoDyn + ViewLayoutLaunchArg,
147> IntoDyn2Layout<L, P, O>
148{
149 pub fn new(layout: L) -> Self {
150 IntoDyn2Layout { layout }
151 }
152}
153
154#[cube]
155impl<L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg> Layout for IntoDynLayout<L> {
156 type Coordinates = L::Coordinates;
157 type SourceCoordinates = Sequence<i32>;
158
159 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
160 let pos = self.layout.to_source_pos(pos);
161 pos.into_dyn()
162 }
163
164 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
165 self.layout.is_in_bounds(pos)
166 }
167
168 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
169 let (pos, in_bounds) = self.layout.to_source_pos_checked(pos);
170 (pos.into_dyn(), in_bounds)
171 }
172
173 fn shape(&self) -> Self::Coordinates {
174 self.layout.shape()
175 }
176}
177
178#[cube]
179impl<
180 L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg,
181 P: IntoDyn,
182 O: IntoDyn + ViewLayoutLaunchArg,
183> Layout for IntoDyn2Layout<L, P, O>
184{
185 type Coordinates = L::Coordinates;
186 type SourceCoordinates = (Sequence<i32>, Sequence<i32>);
187
188 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
189 let pos = self.layout.to_source_pos(pos);
190 (pos.0.into_dyn(), pos.1.into_dyn())
191 }
192
193 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
194 self.layout.is_in_bounds(pos)
195 }
196
197 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
198 let (pos, in_bounds) = self.layout.to_source_pos_checked(pos);
199 ((pos.0.into_dyn(), pos.1.into_dyn()), in_bounds)
200 }
201
202 fn shape(&self) -> Self::Coordinates {
203 self.layout.shape()
204 }
205}