// cubecl_std/tensor/layout/virtual.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, intrinsic, ir::Scope, unexpanded};
5
6use crate::tensor::layout::{Coordinates, Layout, LayoutExpand};
7
8/// A virtual layout, to carry a layout without the need for generic parameters everywhere.
9/// `C` represents the coordinate space of the underlying layout.
10#[derive(Clone)]
11pub struct VirtualLayout<C: Coordinates, S: Coordinates> {
12    _coords: PhantomData<(C, S)>,
13}
14
15impl<C: Coordinates, S: Coordinates> Copy for VirtualLayout<C, S> {}
16unsafe impl<C: Coordinates, S: Coordinates> Send for VirtualLayout<C, S> {}
17unsafe impl<C: Coordinates, S: Coordinates> Sync for VirtualLayout<C, S> {}
18
19#[derive(Clone)]
20pub struct VirtualLayoutExpand<C: Coordinates, S: Coordinates> {
21    pub(crate) state: Arc<dyn VirtualLayoutOperationsExpand<C, S>>,
22}
23
// Dynamically-dispatched counterparts of the `Layout` methods. Each body is an `intrinsic!`
// that, at expansion time, forwards to the type-erased `state` held by `VirtualLayoutExpand`.
#[cube]
impl<C: Coordinates, S: Coordinates> VirtualLayout<C, S> {
    /// Virtual version of [`Layout::to_source_pos`]
    // NOTE(review): the `#[allow(unused)]` here presumably silences the unused `pos` warning
    // in the non-expanded body (`shape`, which takes no `pos`, has no such allow) — confirm.
    #[allow(unused)]
    pub fn to_source_pos(&self, pos: C) -> S {
        intrinsic!(|scope| { self.state.__expand_to_source_pos_method(scope, pos) })
    }

    /// Virtual version of [`Layout::to_source_pos_checked`]
    // Returns the source position together with the `bool` produced by the underlying layout.
    #[allow(unused)]
    pub fn to_source_pos_checked(&self, pos: C) -> (S, bool) {
        intrinsic!(|scope| { self.state.__expand_to_source_pos_checked_method(scope, pos) })
    }

    /// Virtual version of [`Layout::shape`]
    pub fn shape(&self) -> C {
        intrinsic!(|scope| { self.state.__expand_shape_method(scope) })
    }

    /// Virtual version of [`Layout::is_in_bounds`]
    #[allow(unused)]
    pub fn is_in_bounds(&self, pos: C) -> bool {
        intrinsic!(|scope| { self.state.__expand_is_in_bounds_method(scope, pos) })
    }
}
49
50impl<C: Coordinates, S: Coordinates> VirtualLayout<C, S> {
51    /// Create a new virtual layout from a concrete one
52    pub fn new<L: Layout<Coordinates = C, SourceCoordinates = S>>(
53        _layout: L,
54    ) -> VirtualLayout<C, S> {
55        unexpanded!()
56    }
57
58    /// Expand function of [`VirtualLayout::`__`expand_new`]
59    pub fn __expand_new<L: Layout<Coordinates = C, SourceCoordinates = S> + 'static>(
60        _scope: &mut Scope,
61        layout: L::ExpandType,
62    ) -> VirtualLayoutExpand<C, S> {
63        VirtualLayoutExpand::new::<L::ExpandType>(layout)
64    }
65}
66
67impl<C: Coordinates, S: Coordinates> VirtualLayoutExpand<C, S> {
68    /// Create a new virtual layout from a concrete one
69    pub fn new<L: VirtualLayoutOperationsExpand<C, S> + 'static>(
70        layout: L,
71    ) -> VirtualLayoutExpand<C, S> {
72        VirtualLayoutExpand::<C, S> {
73            state: Arc::new(layout),
74        }
75    }
76}
77
78impl<C: Coordinates, S: Coordinates> CubeType for VirtualLayout<C, S> {
79    type ExpandType = VirtualLayoutExpand<C, S>;
80}
81
82impl<C: Coordinates, S: Coordinates> IntoMut for VirtualLayoutExpand<C, S> {
83    fn into_mut(self, _scope: &mut Scope) -> Self {
84        self
85    }
86}
87
88impl<C: Coordinates, S: Coordinates> CubeDebug for VirtualLayoutExpand<C, S> {}
89
90// We need to seal the trait to allow us to blanket implement `From<L>` below
91mod private {
92    pub trait Sealed {}
93}
94pub trait VirtualLayoutOperationsExpand<C: CubeType, S: CubeType>: private::Sealed {
95    fn __expand_to_source_pos_method(
96        &self,
97        scope: &mut Scope,
98        pos: <C as CubeType>::ExpandType,
99    ) -> <S as CubeType>::ExpandType;
100    fn __expand_to_source_pos_checked_method(
101        &self,
102        scope: &mut Scope,
103        pos: <C as CubeType>::ExpandType,
104    ) -> <(S, bool) as CubeType>::ExpandType;
105    fn __expand_shape_method(&self, scope: &mut Scope) -> <C as CubeType>::ExpandType;
106    fn __expand_is_in_bounds_method(
107        &self,
108        scope: &mut Scope,
109        pos: <C as CubeType>::ExpandType,
110    ) -> ExpandElementTyped<bool>;
111}
112
113impl<L: LayoutExpand> private::Sealed for L {}
114impl<L: LayoutExpand> VirtualLayoutOperationsExpand<L::Coordinates, L::SourceCoordinates> for L {
115    fn __expand_to_source_pos_method(
116        &self,
117        scope: &mut Scope,
118        pos: <L::Coordinates as CubeType>::ExpandType,
119    ) -> <L::SourceCoordinates as CubeType>::ExpandType {
120        <L as LayoutExpand>::__expand_to_source_pos_method(self.clone(), scope, pos)
121    }
122
123    fn __expand_to_source_pos_checked_method(
124        &self,
125        scope: &mut Scope,
126        pos: <L::Coordinates as CubeType>::ExpandType,
127    ) -> <(L::SourceCoordinates, bool) as CubeType>::ExpandType {
128        <L as LayoutExpand>::__expand_to_source_pos_checked_method(self.clone(), scope, pos)
129    }
130
131    fn __expand_shape_method(&self, scope: &mut Scope) -> <L::Coordinates as CubeType>::ExpandType {
132        <L as LayoutExpand>::__expand_shape_method(self.clone(), scope)
133    }
134
135    fn __expand_is_in_bounds_method(
136        &self,
137        scope: &mut Scope,
138        pos: <L::Coordinates as CubeType>::ExpandType,
139    ) -> ExpandElementTyped<bool> {
140        <L as LayoutExpand>::__expand_is_in_bounds_method(self.clone(), scope, pos)
141    }
142}
143
144impl<C: Coordinates, S: Coordinates, L: VirtualLayoutOperationsExpand<C, S> + 'static> From<L>
145    for VirtualLayoutExpand<C, S>
146{
147    fn from(value: L) -> Self {
148        VirtualLayoutExpand::new(value)
149    }
150}
151
152impl<L: Layout + 'static> From<L> for VirtualLayout<L::Coordinates, L::SourceCoordinates> {
153    fn from(_value: L) -> Self {
154        VirtualLayout {
155            _coords: PhantomData,
156        }
157    }
158}
159
160mod launch {
161    use cubecl_core::{
162        format::DebugRaw,
163        hash::{StableHash, StableHasher},
164    };
165    use spin::Mutex;
166
167    use super::*;
168
169    type ExpandFn<C, S> =
170        Arc<Mutex<dyn FnMut(&mut KernelBuilder) -> VirtualLayoutExpand<C, S> + Send>>;
171
172    pub struct VirtualLayoutLaunch<'a, C: Coordinates, S: Coordinates, R: Runtime> {
173        _phantom_runtime: core::marker::PhantomData<R>,
174        _phantom_a: core::marker::PhantomData<&'a ()>,
175        inner: Arc<dyn ArgSettings<R> + 'a>,
176        hashed_arg: VirtualLayoutCompilationArg<C, S>,
177    }
178
179    impl<'a, C: Coordinates, S: Coordinates, R: cubecl::prelude::Runtime>
180        VirtualLayoutLaunch<'a, C, S, R>
181    {
182        pub fn new<L: Layout<Coordinates = C, SourceCoordinates = S> + LaunchArg>(
183            layout: L::RuntimeArg<'a, R>,
184        ) -> Self {
185            let comp_arg = L::compilation_arg(&layout);
186            let comp_arg_2 = comp_arg.clone();
187            let expand = move |builder: &mut KernelBuilder| {
188                let expand = L::expand(&comp_arg_2, builder);
189                VirtualLayoutExpand::new(expand)
190            };
191            let comp_arg_2 = comp_arg.clone();
192            let expand_out = move |builder: &mut KernelBuilder| {
193                let expand = L::expand_output(&comp_arg_2, builder);
194                VirtualLayoutExpand::new(expand)
195            };
196            let hashed_arg = VirtualLayoutCompilationArg::new::<L::CompilationArg>(
197                &comp_arg,
198                Arc::new(Mutex::new(expand)),
199                Arc::new(Mutex::new(expand_out)),
200            );
201
202            Self {
203                _phantom_runtime: PhantomData,
204                _phantom_a: PhantomData,
205                inner: Arc::new(layout),
206                hashed_arg,
207            }
208        }
209    }
210    impl<'a, C: Coordinates, S: Coordinates, R: cubecl::prelude::Runtime> ArgSettings<R>
211        for VirtualLayoutLaunch<'a, C, S, R>
212    {
213        fn register(&self, launcher: &mut cubecl::prelude::KernelLauncher<R>) {
214            self.inner.register(launcher);
215        }
216    }
217
218    #[derive(Clone)]
219    pub struct VirtualLayoutCompilationArg<C: Coordinates, S: Coordinates> {
220        type_name: String,
221        debug_string: String,
222        debug_string_pretty: String,
223        hash: StableHash,
224        expand: ExpandFn<C, S>,
225        expand_output: ExpandFn<C, S>,
226    }
227
228    impl<C: Coordinates, S: Coordinates> VirtualLayoutCompilationArg<C, S> {
229        pub fn new<L: CompilationArg>(
230            arg: &L,
231            expand: ExpandFn<C, S>,
232            expand_output: ExpandFn<C, S>,
233        ) -> Self {
234            // Hash ahead of time so we don't need to store the actual data, which would be far
235            // more complex
236            let hash = StableHasher::hash_one(arg);
237            Self {
238                type_name: core::any::type_name::<L>().to_string(),
239                debug_string: format!("{arg:?}"),
240                debug_string_pretty: format!("{arg:#?}"),
241                hash,
242                expand,
243                expand_output,
244            }
245        }
246    }
247
248    impl<C: Coordinates, S: Coordinates> PartialEq for VirtualLayoutCompilationArg<C, S> {
249        fn eq(&self, other: &Self) -> bool {
250            self.type_name == other.type_name && self.hash == other.hash
251        }
252    }
253    impl<C: Coordinates, S: Coordinates> Eq for VirtualLayoutCompilationArg<C, S> {}
254
255    impl<C: Coordinates + 'static, S: Coordinates + 'static> CompilationArg
256        for VirtualLayoutCompilationArg<C, S>
257    {
258    }
259
260    impl<C: Coordinates, S: Coordinates> core::hash::Hash for VirtualLayoutCompilationArg<C, S> {
261        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
262            self.type_name.hash(state);
263            self.hash.hash(state);
264        }
265    }
266
267    impl<C: Coordinates, S: Coordinates> core::fmt::Debug for VirtualLayoutCompilationArg<C, S> {
268        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269            // `alternate` means `{:#?}`, or pretty printing
270            if f.alternate() {
271                f.debug_struct(stringify!(VirtualLayout))
272                    .field("type", &DebugRaw(&self.type_name))
273                    .field("value", &DebugRaw(&self.debug_string_pretty))
274                    .finish()
275            } else {
276                f.debug_struct(stringify!(VirtualLayout))
277                    .field("type", &DebugRaw(&self.type_name))
278                    .field("value", &DebugRaw(&self.debug_string))
279                    .finish()
280            }
281        }
282    }
283
284    impl<C: Coordinates + 'static, S: Coordinates + 'static> LaunchArg for VirtualLayout<C, S> {
285        type RuntimeArg<'a, R: Runtime> = VirtualLayoutLaunch<'a, C, S, R>;
286        type CompilationArg = VirtualLayoutCompilationArg<C, S>;
287
288        fn compilation_arg<'a, R: Runtime>(
289            runtime_arg: &Self::RuntimeArg<'a, R>,
290        ) -> Self::CompilationArg {
291            runtime_arg.hashed_arg.clone()
292        }
293        fn expand(
294            arg: &Self::CompilationArg,
295            builder: &mut KernelBuilder,
296        ) -> <Self as CubeType>::ExpandType {
297            let mut expand = arg.expand.as_ref().lock();
298            expand(builder)
299        }
300        fn expand_output(
301            arg: &Self::CompilationArg,
302            builder: &mut KernelBuilder,
303        ) -> <Self as CubeType>::ExpandType {
304            let mut expand = arg.expand_output.as_ref().lock();
305            expand(builder)
306        }
307    }
308}
309
310pub use launch::*;