cubecl_std/tensor/layout/
chain.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::tensor::layout::{Layout, LayoutExpand};
5
6#[derive(CubeType)]
8pub struct Chain<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> {
9 l0: L0,
10 l1: L1,
11}
12
13#[cube]
14impl<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> Chain<L0, L1> {
15 pub fn new(l0: L0, l1: L1) -> Self {
16 Chain::<L0, L1> { l0, l1 }
17 }
18}
19
20#[cube]
21impl<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> Layout for Chain<L0, L1> {
22 type Coordinates = L1::Coordinates;
23 type SourceCoordinates = L0::SourceCoordinates;
24
25 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
26 let pos = self.l1.to_source_pos(pos);
27 self.l0.to_source_pos(pos)
28 }
29
30 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
31 let (pos, l1_in_bounds) = self.l1.to_source_pos_checked(pos);
32 self.l0.is_in_bounds(pos) && l1_in_bounds
33 }
34
35 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
36 let (pos, l1_in_bounds) = self.l1.to_source_pos_checked(pos);
37 let (pos, l0_in_bounds) = self.l0.to_source_pos_checked(pos);
38 (pos, l0_in_bounds && l1_in_bounds)
39 }
40
41 fn shape(&self) -> Self::Coordinates {
42 self.l1.shape()
43 }
44}
45
46pub use launch::*;
47mod launch {
48 use core::marker::PhantomData;
49
50 use super::*;
51
52 pub struct ChainLaunch<
53 'a,
54 L0: Layout + LaunchArg,
55 L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg,
56 R: Runtime,
57 > {
58 _phantom_runtime: PhantomData<R>,
59 _phantom_a: PhantomData<&'a ()>,
60 l0: L0::RuntimeArg<'a, R>,
61 l1: L1::RuntimeArg<'a, R>,
62 }
63 impl<
64 'a,
65 L0: Layout + LaunchArg,
66 L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg,
67 R: Runtime,
68 > ChainLaunch<'a, L0, L1, R>
69 {
70 pub fn new(l0: L0::RuntimeArg<'a, R>, l1: L1::RuntimeArg<'a, R>) -> Self {
71 Self {
72 _phantom_runtime: PhantomData,
73 _phantom_a: PhantomData,
74 l0,
75 l1,
76 }
77 }
78 }
79 impl<
80 'a,
81 L0: Layout + LaunchArg,
82 L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg,
83 R: Runtime,
84 > ArgSettings<R> for ChainLaunch<'a, L0, L1, R>
85 {
86 fn register(&self, launcher: &mut cubecl::prelude::KernelLauncher<R>) {
87 self.l0.register(launcher);
88 self.l1.register(launcher);
89 }
90 }
91
92 pub struct ChainCompilationArg<
93 L0: Layout + LaunchArg,
94 L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg,
95 > {
96 l0: L0::CompilationArg,
97 l1: L1::CompilationArg,
98 }
99 impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg> Clone
100 for ChainCompilationArg<L0, L1>
101 {
102 fn clone(&self) -> Self {
103 Self {
104 l0: self.l0.clone(),
105 l1: self.l1.clone(),
106 }
107 }
108 }
109 impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
110 CompilationArg for ChainCompilationArg<L0, L1>
111 {
112 }
113
114 impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
115 core::hash::Hash for ChainCompilationArg<L0, L1>
116 {
117 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
118 self.l0.hash(state);
119 self.l1.hash(state);
120 }
121 }
122 impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
123 core::cmp::PartialEq for ChainCompilationArg<L0, L1>
124 {
125 fn eq(&self, other: &Self) -> bool {
126 self.l0.eq(&other.l0) && self.l1.eq(&other.l1)
127 }
128 }
129 impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
130 core::fmt::Debug for ChainCompilationArg<L0, L1>
131 {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 f.write_str(stringify!(Chain))?;
134 f.write_str("{")?;
135 f.write_fmt(format_args!("{}: {:?},", stringify!(l0), &self.l0))?;
136 f.write_fmt(format_args!("{}: {:?},", stringify!(l1), &self.l1))?;
137 f.write_str("}")?;
138 Ok(())
139 }
140 }
141 impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
142 core::cmp::Eq for ChainCompilationArg<L0, L1>
143 {
144 }
145
146 impl<L0: Layout + LaunchArg, L1: Layout<SourceCoordinates = L0::Coordinates> + LaunchArg>
147 LaunchArg for Chain<L0, L1>
148 {
149 type RuntimeArg<'a, R: Runtime> = ChainLaunch<'a, L0, L1, R>;
150 type CompilationArg = ChainCompilationArg<L0, L1>;
151 fn compilation_arg<'a, R: Runtime>(
152 runtime_arg: &Self::RuntimeArg<'a, R>,
153 ) -> Self::CompilationArg {
154 ChainCompilationArg {
155 l0: L0::compilation_arg::<R>(&runtime_arg.l0),
156 l1: L1::compilation_arg::<R>(&runtime_arg.l1),
157 }
158 }
159 fn expand(
160 arg: &Self::CompilationArg,
161 builder: &mut KernelBuilder,
162 ) -> <Self as CubeType>::ExpandType {
163 ChainExpand {
164 l0: L0::expand(&arg.l0, builder),
165 l1: L1::expand(&arg.l1, builder),
166 }
167 }
168 fn expand_output(
169 arg: &Self::CompilationArg,
170 builder: &mut KernelBuilder,
171 ) -> <Self as CubeType>::ExpandType {
172 ChainExpand {
173 l0: L0::expand_output(&arg.l0, builder),
174 l1: L1::expand_output(&arg.l1, builder),
175 }
176 }
177 }
178}