cubecl_std/tensor/layout/
chain.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::tensor::layout::{Layout, LayoutExpand};
5
/// Chain of layouts, can be used to launch with multiple layouts.
///
/// The chain composes two layouts into one: incoming coordinates are mapped by `l1`
/// first, and the result is then mapped by `l0`. The bound
/// `L1: Layout<SourceCoordinates = L0::Coordinates>` ensures the output of the first
/// stage is a valid input for the second.
#[derive(CubeType)]
pub struct Chain<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> {
    /// Second stage: maps the output of `l1` to the final source coordinates.
    l0: L0,
    /// First stage: receives the coordinates handed to the chain.
    l1: L1,
}
12
#[cube]
impl<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> Chain<L0, L1> {
    /// Builds a chain from its two stages.
    ///
    /// `l1` is the stage applied first to incoming coordinates; `l0` is applied to the
    /// coordinates produced by `l1`.
    pub fn new(l0: L0, l1: L1) -> Self {
        Chain::<L0, L1> { l0, l1 }
    }
}
19
20#[cube]
21impl<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> Layout for Chain<L0, L1> {
22    type Coordinates = L1::Coordinates;
23    type SourceCoordinates = L0::SourceCoordinates;
24
25    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
26        let pos = self.l1.to_source_pos(pos);
27        self.l0.to_source_pos(pos)
28    }
29
30    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
31        let (pos, l1_in_bounds) = self.l1.to_source_pos_checked(pos);
32        self.l0.is_in_bounds(pos) && l1_in_bounds
33    }
34
35    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
36        let (pos, l1_in_bounds) = self.l1.to_source_pos_checked(pos);
37        let (pos, l0_in_bounds) = self.l0.to_source_pos_checked(pos);
38        (pos, l0_in_bounds && l1_in_bounds)
39    }
40
41    fn shape(&self) -> Self::Coordinates {
42        self.l1.shape()
43    }
44}
45
46pub use launch::*;
47mod launch {
48    use core::marker::PhantomData;
49
50    use super::*;
51
52    pub struct ChainLaunch<
53        'a,
54        L0: Layout + LaunchArg,
55        L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg,
56        R: Runtime,
57    > {
58        _phantom_runtime: PhantomData<R>,
59        _phantom_a: PhantomData<&'a ()>,
60        l0: L0::RuntimeArg<'a, R>,
61        l1: L1::RuntimeArg<'a, R>,
62    }
63    impl<
64        'a,
65        L0: Layout + LaunchArg,
66        L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg,
67        R: Runtime,
68    > ChainLaunch<'a, L0, L1, R>
69    {
70        pub fn new(l0: L0::RuntimeArg<'a, R>, l1: L1::RuntimeArg<'a, R>) -> Self {
71            Self {
72                _phantom_runtime: PhantomData,
73                _phantom_a: PhantomData,
74                l0,
75                l1,
76            }
77        }
78    }
79    impl<
80        'a,
81        L0: Layout + LaunchArg,
82        L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg,
83        R: Runtime,
84    > ArgSettings<R> for ChainLaunch<'a, L0, L1, R>
85    {
86        fn register(&self, launcher: &mut cubecl::prelude::KernelLauncher<R>) {
87            self.l0.register(launcher);
88            self.l1.register(launcher);
89        }
90    }
91
92    pub struct ChainCompilationArg<
93        L0: Layout + LaunchArg,
94        L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg,
95    > {
96        l0: L0::CompilationArg,
97        l1: L1::CompilationArg,
98    }
99    impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg> Clone
100        for ChainCompilationArg<L0, L1>
101    {
102        fn clone(&self) -> Self {
103            Self {
104                l0: self.l0.clone(),
105                l1: self.l1.clone(),
106            }
107        }
108    }
109    impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
110        CompilationArg for ChainCompilationArg<L0, L1>
111    {
112    }
113
114    impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
115        core::hash::Hash for ChainCompilationArg<L0, L1>
116    {
117        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
118            self.l0.hash(state);
119            self.l1.hash(state);
120        }
121    }
122    impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
123        core::cmp::PartialEq for ChainCompilationArg<L0, L1>
124    {
125        fn eq(&self, other: &Self) -> bool {
126            self.l0.eq(&other.l0) && self.l1.eq(&other.l1)
127        }
128    }
129    impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
130        core::fmt::Debug for ChainCompilationArg<L0, L1>
131    {
132        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133            f.write_str(stringify!(Chain))?;
134            f.write_str("{")?;
135            f.write_fmt(format_args!("{}: {:?},", stringify!(l0), &self.l0))?;
136            f.write_fmt(format_args!("{}: {:?},", stringify!(l1), &self.l1))?;
137            f.write_str("}")?;
138            Ok(())
139        }
140    }
141    impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
142        core::cmp::Eq for ChainCompilationArg<L0, L1>
143    {
144    }
145
146    impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
147        LaunchArg for Chain<L0, L1>
148    {
149        type RuntimeArg<'a, R: Runtime> = ChainLaunch<'a, L0, L1, R>;
150        type CompilationArg = ChainCompilationArg<L0, L1>;
151        fn compilation_arg<'a, R: Runtime>(
152            runtime_arg: &Self::RuntimeArg<'a, R>,
153        ) -> Self::CompilationArg {
154            ChainCompilationArg {
155                l0: L0::compilation_arg::<R>(&runtime_arg.l0),
156                l1: L1::compilation_arg::<R>(&runtime_arg.l1),
157            }
158        }
159        fn expand(
160            arg: &Self::CompilationArg,
161            builder: &mut KernelBuilder,
162        ) -> <Self as CubeType>::ExpandType {
163            ChainExpand {
164                l0: L0::expand(&arg.l0, builder),
165                l1: L1::expand(&arg.l1, builder),
166            }
167        }
168        fn expand_output(
169            arg: &Self::CompilationArg,
170            builder: &mut KernelBuilder,
171        ) -> <Self as CubeType>::ExpandType {
172            ChainExpand {
173                l0: L0::expand_output(&arg.l0, builder),
174                l1: L1::expand_output(&arg.l1, builder),
175            }
176        }
177    }
178}