Skip to main content

cubecl_std/tensor/layout/
chain.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::tensor::layout::{Layout, LayoutExpand};
5
/// Chain of layouts, can be used to launch with multiple layouts
///
/// The bounds require `L1::SourceCoordinates` to be `L0::Coordinates`, so a position
/// produced by `l1` can be handed directly to `l0`.
#[derive(CubeType)]
pub struct Chain<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> {
    // Layout that provides the chain's `SourceCoordinates`.
    l0: L0,
    // Layout that provides the chain's `Coordinates` and `shape`.
    l1: L1,
}
12
#[cube]
impl<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> Chain<L0, L1> {
    /// Builds a chain that stores `l0` and `l1` as given.
    pub fn new(l0: L0, l1: L1) -> Self {
        // The type is spelled out explicitly (rather than `Self`) as in the original code.
        Chain::<L0, L1> { l0, l1 }
    }
}
19
20#[cube]
21impl<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> Layout for Chain<L0, L1> {
22    type Coordinates = L1::Coordinates;
23    type SourceCoordinates = L0::SourceCoordinates;
24
25    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
26        let pos = self.l1.to_source_pos(pos);
27        self.l0.to_source_pos(pos)
28    }
29
30    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
31        let (pos, l1_in_bounds) = self.l1.to_source_pos_checked(pos);
32        self.l0.is_in_bounds(pos) && l1_in_bounds
33    }
34
35    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
36        let (pos, l1_in_bounds) = self.l1.to_source_pos_checked(pos);
37        let (pos, l0_in_bounds) = self.l0.to_source_pos_checked(pos);
38        (pos, l0_in_bounds && l1_in_bounds)
39    }
40
41    fn shape(&self) -> Self::Coordinates {
42        self.l1.shape()
43    }
44}
45
46pub use launch::*;
47mod launch {
48    use core::marker::PhantomData;
49
50    use crate::tensor::launch::{BufferArg, ViewLayoutLaunchArg};
51
52    use super::*;
53
54    pub struct ChainLaunch<
55        L0: Layout + ViewLayoutLaunchArg,
56        L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
57        R: Runtime,
58    > {
59        _phantom_runtime: PhantomData<R>,
60        l0: L0::RuntimeArg<R>,
61        l1: L1::RuntimeArg<R>,
62    }
63    impl<
64        L0: Layout + ViewLayoutLaunchArg,
65        L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
66        R: Runtime,
67    > ChainLaunch<L0, L1, R>
68    {
69        pub fn new(l0: L0::RuntimeArg<R>, l1: L1::RuntimeArg<R>) -> Self {
70            Self {
71                _phantom_runtime: PhantomData,
72                l0,
73                l1,
74            }
75        }
76    }
77
78    pub struct ChainCompilationArg<
79        L0: Layout + ViewLayoutLaunchArg,
80        L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
81    > {
82        l0: L0::CompilationArg,
83        l1: L1::CompilationArg,
84    }
85    impl<
86        L0: Layout + ViewLayoutLaunchArg,
87        L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
88    > Clone for ChainCompilationArg<L0, L1>
89    {
90        fn clone(&self) -> Self {
91            Self {
92                l0: self.l0.clone(),
93                l1: self.l1.clone(),
94            }
95        }
96    }
97
98    impl<
99        L0: Layout + ViewLayoutLaunchArg,
100        L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
101    > core::hash::Hash for ChainCompilationArg<L0, L1>
102    {
103        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
104            self.l0.hash(state);
105            self.l1.hash(state);
106        }
107    }
108    impl<
109        L0: Layout + ViewLayoutLaunchArg,
110        L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
111    > core::cmp::PartialEq for ChainCompilationArg<L0, L1>
112    {
113        fn eq(&self, other: &Self) -> bool {
114            self.l0.eq(&other.l0) && self.l1.eq(&other.l1)
115        }
116    }
117    impl<
118        L0: Layout + ViewLayoutLaunchArg,
119        L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
120    > core::fmt::Debug for ChainCompilationArg<L0, L1>
121    {
122        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123            f.debug_struct(stringify!(Chain))
124                .field(stringify!(l0), &self.l0)
125                .field(stringify!(l1), &self.l1)
126                .finish()
127        }
128    }
129    impl<
130        L0: Layout + ViewLayoutLaunchArg,
131        L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
132    > core::cmp::Eq for ChainCompilationArg<L0, L1>
133    {
134    }
135
136    impl<
137        L0: Layout + ViewLayoutLaunchArg,
138        L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
139    > ViewLayoutLaunchArg for Chain<L0, L1>
140    {
141        type RuntimeArg<R: Runtime> = ChainLaunch<L0, L1, R>;
142        type CompilationArg = ChainCompilationArg<L0, L1>;
143
144        fn register<R: Runtime, B: BufferArg>(
145            arg: Self::RuntimeArg<R>,
146            buffer: &B,
147            ty: Type,
148            launcher: &mut KernelLauncher<R>,
149        ) -> Self::CompilationArg {
150            ChainCompilationArg {
151                l0: L0::register(arg.l0, buffer, ty, launcher),
152                l1: L1::register(arg.l1, buffer, ty, launcher),
153            }
154        }
155        fn expand(
156            arg: &Self::CompilationArg,
157            ty: Type,
158            builder: &mut KernelBuilder,
159        ) -> <Self as CubeType>::ExpandType {
160            ChainExpand {
161                l0: L0::expand(&arg.l0, ty, builder),
162                l1: L1::expand(&arg.l1, ty, builder),
163            }
164        }
165        fn expand_output(
166            arg: &Self::CompilationArg,
167            ty: Type,
168            builder: &mut KernelBuilder,
169        ) -> <Self as CubeType>::ExpandType {
170            ChainExpand {
171                l0: L0::expand_output(&arg.l0, ty, builder),
172                l1: L1::expand_output(&arg.l1, ty, builder),
173            }
174        }
175    }
176}