1use std::{marker::PhantomData, sync::Arc};
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, intrinsic, ir::Scope, unexpanded};
5
6use crate::tensor::layout::{Coordinates, Layout, LayoutExpand};
7
8#[derive(Clone)]
11pub struct VirtualLayout<C: Coordinates, S: Coordinates> {
12 _coords: PhantomData<(C, S)>,
13}
14
15impl<C: Coordinates, S: Coordinates> Copy for VirtualLayout<C, S> {}
16unsafe impl<C: Coordinates, S: Coordinates> Send for VirtualLayout<C, S> {}
17unsafe impl<C: Coordinates, S: Coordinates> Sync for VirtualLayout<C, S> {}
18
/// Expand-time representation of [`VirtualLayout`] (its `CubeType::ExpandType`).
///
/// The concrete expanded layout is stored behind an `Arc<dyn VirtualLayoutOperationsExpand<C, S>>`,
/// so layouts of different concrete types but identical coordinate types share this one type,
/// and cloning only bumps the reference count.
#[derive(Clone)]
pub struct VirtualLayoutExpand<C: Coordinates, S: Coordinates> {
    // Type-erased wrapped layout. `pub(crate)`: reachable from the whole crate, not only this
    // module. NOTE(review): confirm which callers outside this file rely on that.
    pub(crate) state: Arc<dyn VirtualLayoutOperationsExpand<C, S>>,
}
23
// Kernel-side API of `VirtualLayout`. Every method is an `intrinsic!` whose expand-time body
// forwards to the type-erased layout stored in `VirtualLayoutExpand::state`.
#[cube]
impl<C: Coordinates, S: Coordinates> VirtualLayout<C, S> {
    /// Maps `pos`, given in this layout's coordinates (`C`), to the corresponding position in
    /// the source coordinates (`S`), as defined by the wrapped layout.
    // `pos` is only referenced inside `intrinsic!`; presumably it would otherwise be reported
    // as unused in the non-expanded form of this function.
    #[allow(unused)]
    pub fn to_source_pos(&self, pos: C) -> S {
        intrinsic!(|scope| { self.state.__expand_to_source_pos_method(scope, pos) })
    }

    /// Like [`Self::to_source_pos`], but also returns a `bool` next to the source position —
    /// presumably whether `pos` is in bounds; the exact meaning comes from the wrapped layout.
    // Same reason for `allow(unused)` as above: `pos` is only used inside `intrinsic!`.
    #[allow(unused)]
    pub fn to_source_pos_checked(&self, pos: C) -> (S, bool) {
        intrinsic!(|scope| { self.state.__expand_to_source_pos_checked_method(scope, pos) })
    }

    /// Returns the shape of the wrapped layout, expressed in the layout coordinates `C`.
    pub fn shape(&self) -> C {
        intrinsic!(|scope| { self.state.__expand_shape_method(scope) })
    }

    /// Returns whether `pos` (in layout coordinates `C`) is in bounds for the wrapped layout.
    // Same reason for `allow(unused)` as above: `pos` is only used inside `intrinsic!`.
    #[allow(unused)]
    pub fn is_in_bounds(&self, pos: C) -> bool {
        intrinsic!(|scope| { self.state.__expand_is_in_bounds_method(scope, pos) })
    }
}
49
50impl<C: Coordinates, S: Coordinates> VirtualLayout<C, S> {
51 pub fn new<L: Layout<Coordinates = C, SourceCoordinates = S>>(
53 _layout: L,
54 ) -> VirtualLayout<C, S> {
55 unexpanded!()
56 }
57
58 pub fn __expand_new<L: Layout<Coordinates = C, SourceCoordinates = S> + 'static>(
60 _scope: &mut Scope,
61 layout: L::ExpandType,
62 ) -> VirtualLayoutExpand<C, S> {
63 VirtualLayoutExpand::new::<L::ExpandType>(layout)
64 }
65}
66
67impl<C: Coordinates, S: Coordinates> VirtualLayoutExpand<C, S> {
68 pub fn new<L: VirtualLayoutOperationsExpand<C, S> + 'static>(
70 layout: L,
71 ) -> VirtualLayoutExpand<C, S> {
72 VirtualLayoutExpand::<C, S> {
73 state: Arc::new(layout),
74 }
75 }
76}
77
78impl<C: Coordinates, S: Coordinates> CubeType for VirtualLayout<C, S> {
79 type ExpandType = VirtualLayoutExpand<C, S>;
80}
81
82impl<C: Coordinates, S: Coordinates> IntoMut for VirtualLayoutExpand<C, S> {
83 fn into_mut(self, _scope: &mut Scope) -> Self {
84 self
85 }
86}
87
88impl<C: Coordinates, S: Coordinates> CubeDebug for VirtualLayoutExpand<C, S> {}
89
// Sealing pattern: `Sealed` is `pub` but lives in a private module, so it cannot be named
// outside this module (and its children). That keeps `VirtualLayoutOperationsExpand` from
// being implemented anywhere else.
mod private {
    /// Marker supertrait sealing [`super::VirtualLayoutOperationsExpand`].
    pub trait Sealed {}
}

/// Dyn-compatible, `&self`-based operations of an expanded layout.
///
/// Used as the trait object inside [`VirtualLayoutExpand`] to type-erase concrete layouts.
/// Each method is the expand-time form of the `VirtualLayout` method with the matching name;
/// `scope` is the `ir::Scope` threaded through expansion.
pub trait VirtualLayoutOperationsExpand<C: CubeType, S: CubeType>: private::Sealed {
    /// Expand-time `to_source_pos`: maps an expanded `C` position to an expanded `S` position.
    fn __expand_to_source_pos_method(
        &self,
        scope: &mut Scope,
        pos: <C as CubeType>::ExpandType,
    ) -> <S as CubeType>::ExpandType;
    /// Expand-time `to_source_pos_checked`: the mapped position paired with a `bool`.
    fn __expand_to_source_pos_checked_method(
        &self,
        scope: &mut Scope,
        pos: <C as CubeType>::ExpandType,
    ) -> <(S, bool) as CubeType>::ExpandType;
    /// Expand-time `shape`: the layout's shape as an expanded `C`.
    fn __expand_shape_method(&self, scope: &mut Scope) -> <C as CubeType>::ExpandType;
    /// Expand-time `is_in_bounds`: an expanded boolean for `pos`.
    fn __expand_is_in_bounds_method(
        &self,
        scope: &mut Scope,
        pos: <C as CubeType>::ExpandType,
    ) -> ExpandElementTyped<bool>;
}
112
113impl<L: LayoutExpand> private::Sealed for L {}
114impl<L: LayoutExpand> VirtualLayoutOperationsExpand<L::Coordinates, L::SourceCoordinates> for L {
115 fn __expand_to_source_pos_method(
116 &self,
117 scope: &mut Scope,
118 pos: <L::Coordinates as CubeType>::ExpandType,
119 ) -> <L::SourceCoordinates as CubeType>::ExpandType {
120 <L as LayoutExpand>::__expand_to_source_pos_method(self.clone(), scope, pos)
121 }
122
123 fn __expand_to_source_pos_checked_method(
124 &self,
125 scope: &mut Scope,
126 pos: <L::Coordinates as CubeType>::ExpandType,
127 ) -> <(L::SourceCoordinates, bool) as CubeType>::ExpandType {
128 <L as LayoutExpand>::__expand_to_source_pos_checked_method(self.clone(), scope, pos)
129 }
130
131 fn __expand_shape_method(&self, scope: &mut Scope) -> <L::Coordinates as CubeType>::ExpandType {
132 <L as LayoutExpand>::__expand_shape_method(self.clone(), scope)
133 }
134
135 fn __expand_is_in_bounds_method(
136 &self,
137 scope: &mut Scope,
138 pos: <L::Coordinates as CubeType>::ExpandType,
139 ) -> ExpandElementTyped<bool> {
140 <L as LayoutExpand>::__expand_is_in_bounds_method(self.clone(), scope, pos)
141 }
142}
143
144impl<C: Coordinates, S: Coordinates, L: VirtualLayoutOperationsExpand<C, S> + 'static> From<L>
145 for VirtualLayoutExpand<C, S>
146{
147 fn from(value: L) -> Self {
148 VirtualLayoutExpand::new(value)
149 }
150}
151
152impl<L: Layout + 'static> From<L> for VirtualLayout<L::Coordinates, L::SourceCoordinates> {
153 fn from(_value: L) -> Self {
154 VirtualLayout {
155 _coords: PhantomData,
156 }
157 }
158}
159
160mod launch {
161 use core::hash::BuildHasher;
162 use cubecl_core::format::DebugRaw;
163 use spin::Mutex;
164
165 use super::*;
166
167 type ExpandFn<C, S> =
168 Arc<Mutex<dyn FnMut(&mut KernelBuilder) -> VirtualLayoutExpand<C, S> + Send>>;
169
170 pub struct VirtualLayoutLaunch<'a, C: Coordinates, S: Coordinates, R: Runtime> {
171 _phantom_runtime: core::marker::PhantomData<R>,
172 _phantom_a: core::marker::PhantomData<&'a ()>,
173 inner: Arc<dyn ArgSettings<R> + 'a>,
174 hashed_arg: VirtualLayoutCompilationArg<C, S>,
175 }
176
177 impl<'a, C: Coordinates, S: Coordinates, R: cubecl::prelude::Runtime>
178 VirtualLayoutLaunch<'a, C, S, R>
179 {
180 pub fn new<L: Layout<Coordinates = C, SourceCoordinates = S> + LaunchArg>(
181 layout: L::RuntimeArg<'a, R>,
182 ) -> Self {
183 let comp_arg = L::compilation_arg(&layout);
184 let comp_arg_2 = comp_arg.clone();
185 let expand = move |builder: &mut KernelBuilder| {
186 let expand = L::expand(&comp_arg_2, builder);
187 VirtualLayoutExpand::new(expand)
188 };
189 let comp_arg_2 = comp_arg.clone();
190 let expand_out = move |builder: &mut KernelBuilder| {
191 let expand = L::expand_output(&comp_arg_2, builder);
192 VirtualLayoutExpand::new(expand)
193 };
194 let hashed_arg = VirtualLayoutCompilationArg::new::<L::CompilationArg>(
195 &comp_arg,
196 Arc::new(Mutex::new(expand)),
197 Arc::new(Mutex::new(expand_out)),
198 );
199
200 Self {
201 _phantom_runtime: PhantomData,
202 _phantom_a: PhantomData,
203 inner: Arc::new(layout),
204 hashed_arg,
205 }
206 }
207 }
208 impl<'a, C: Coordinates, S: Coordinates, R: cubecl::prelude::Runtime> ArgSettings<R>
209 for VirtualLayoutLaunch<'a, C, S, R>
210 {
211 fn register(&self, launcher: &mut cubecl::prelude::KernelLauncher<R>) {
212 self.inner.register(launcher);
213 }
214 }
215
216 #[derive(Clone)]
217 pub struct VirtualLayoutCompilationArg<C: Coordinates, S: Coordinates> {
218 type_name: String,
219 debug_string: String,
220 debug_string_pretty: String,
221 hash: u64,
222 expand: ExpandFn<C, S>,
223 expand_output: ExpandFn<C, S>,
224 }
225
226 impl<C: Coordinates, S: Coordinates> VirtualLayoutCompilationArg<C, S> {
227 pub fn new<L: CompilationArg>(
228 arg: &L,
229 expand: ExpandFn<C, S>,
230 expand_output: ExpandFn<C, S>,
231 ) -> Self {
232 let state = foldhash::fast::FixedState::default();
235 let hash = state.hash_one(arg);
236 Self {
237 type_name: core::any::type_name::<L>().to_string(),
238 debug_string: format!("{arg:?}"),
239 debug_string_pretty: format!("{arg:#?}"),
240 hash,
241 expand,
242 expand_output,
243 }
244 }
245 }
246
247 impl<C: Coordinates, S: Coordinates> PartialEq for VirtualLayoutCompilationArg<C, S> {
248 fn eq(&self, other: &Self) -> bool {
249 self.type_name == other.type_name && self.hash == other.hash
250 }
251 }
252 impl<C: Coordinates, S: Coordinates> Eq for VirtualLayoutCompilationArg<C, S> {}
253
254 impl<C: Coordinates + 'static, S: Coordinates + 'static> CompilationArg
255 for VirtualLayoutCompilationArg<C, S>
256 {
257 }
258
259 impl<C: Coordinates, S: Coordinates> core::hash::Hash for VirtualLayoutCompilationArg<C, S> {
260 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
261 self.type_name.hash(state);
262 self.hash.hash(state);
263 }
264 }
265
266 impl<C: Coordinates, S: Coordinates> core::fmt::Debug for VirtualLayoutCompilationArg<C, S> {
267 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268 if f.alternate() {
270 f.debug_struct(stringify!(VirtualLayout))
271 .field("type", &DebugRaw(&self.type_name))
272 .field("value", &DebugRaw(&self.debug_string_pretty))
273 .finish()
274 } else {
275 f.debug_struct(stringify!(VirtualLayout))
276 .field("type", &DebugRaw(&self.type_name))
277 .field("value", &DebugRaw(&self.debug_string))
278 .finish()
279 }
280 }
281 }
282
283 impl<C: Coordinates + 'static, S: Coordinates + 'static> LaunchArg for VirtualLayout<C, S> {
284 type RuntimeArg<'a, R: Runtime> = VirtualLayoutLaunch<'a, C, S, R>;
285 type CompilationArg = VirtualLayoutCompilationArg<C, S>;
286
287 fn compilation_arg<'a, R: Runtime>(
288 runtime_arg: &Self::RuntimeArg<'a, R>,
289 ) -> Self::CompilationArg {
290 runtime_arg.hashed_arg.clone()
291 }
292 fn expand(
293 arg: &Self::CompilationArg,
294 builder: &mut KernelBuilder,
295 ) -> <Self as CubeType>::ExpandType {
296 let mut expand = arg.expand.as_ref().lock();
297 expand(builder)
298 }
299 fn expand_output(
300 arg: &Self::CompilationArg,
301 builder: &mut KernelBuilder,
302 ) -> <Self as CubeType>::ExpandType {
303 let mut expand = arg.expand_output.as_ref().lock();
304 expand(builder)
305 }
306 }
307}
308
309pub use launch::*;