cubecl_std/tensor/layout/
chain.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::tensor::layout::{Layout, LayoutExpand};
5
6#[derive(CubeType)]
8pub struct Chain<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> {
9 l0: L0,
10 l1: L1,
11}
12
13#[cube]
14impl<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> Chain<L0, L1> {
15 pub fn new(l0: L0, l1: L1) -> Self {
16 Chain::<L0, L1> { l0, l1 }
17 }
18}
19
20#[cube]
21impl<L0: Layout, L1: Layout<SourceCoordinates = L0::Coordinates>> Layout for Chain<L0, L1> {
22 type Coordinates = L1::Coordinates;
23 type SourceCoordinates = L0::SourceCoordinates;
24
25 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
26 let pos = self.l1.to_source_pos(pos);
27 self.l0.to_source_pos(pos)
28 }
29
30 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
31 let (pos, l1_in_bounds) = self.l1.to_source_pos_checked(pos);
32 self.l0.is_in_bounds(pos) && l1_in_bounds
33 }
34
35 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
36 let (pos, l1_in_bounds) = self.l1.to_source_pos_checked(pos);
37 let (pos, l0_in_bounds) = self.l0.to_source_pos_checked(pos);
38 (pos, l0_in_bounds && l1_in_bounds)
39 }
40
41 fn shape(&self) -> Self::Coordinates {
42 self.l1.shape()
43 }
44}
45
46pub use launch::*;
47mod launch {
48 use core::marker::PhantomData;
49
50 use crate::tensor::launch::{BufferArg, ViewLayoutLaunchArg};
51
52 use super::*;
53
54 pub struct ChainLaunch<
55 L0: Layout + ViewLayoutLaunchArg,
56 L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
57 R: Runtime,
58 > {
59 _phantom_runtime: PhantomData<R>,
60 l0: L0::RuntimeArg<R>,
61 l1: L1::RuntimeArg<R>,
62 }
63 impl<
64 L0: Layout + ViewLayoutLaunchArg,
65 L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
66 R: Runtime,
67 > ChainLaunch<L0, L1, R>
68 {
69 pub fn new(l0: L0::RuntimeArg<R>, l1: L1::RuntimeArg<R>) -> Self {
70 Self {
71 _phantom_runtime: PhantomData,
72 l0,
73 l1,
74 }
75 }
76 }
77
78 pub struct ChainCompilationArg<
79 L0: Layout + ViewLayoutLaunchArg,
80 L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
81 > {
82 l0: L0::CompilationArg,
83 l1: L1::CompilationArg,
84 }
85 impl<
86 L0: Layout + ViewLayoutLaunchArg,
87 L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
88 > Clone for ChainCompilationArg<L0, L1>
89 {
90 fn clone(&self) -> Self {
91 Self {
92 l0: self.l0.clone(),
93 l1: self.l1.clone(),
94 }
95 }
96 }
97
98 impl<
99 L0: Layout + ViewLayoutLaunchArg,
100 L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
101 > core::hash::Hash for ChainCompilationArg<L0, L1>
102 {
103 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
104 self.l0.hash(state);
105 self.l1.hash(state);
106 }
107 }
108 impl<
109 L0: Layout + ViewLayoutLaunchArg,
110 L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
111 > core::cmp::PartialEq for ChainCompilationArg<L0, L1>
112 {
113 fn eq(&self, other: &Self) -> bool {
114 self.l0.eq(&other.l0) && self.l1.eq(&other.l1)
115 }
116 }
117 impl<
118 L0: Layout + ViewLayoutLaunchArg,
119 L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
120 > core::fmt::Debug for ChainCompilationArg<L0, L1>
121 {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct(stringify!(Chain))
124 .field(stringify!(l0), &self.l0)
125 .field(stringify!(l1), &self.l1)
126 .finish()
127 }
128 }
129 impl<
130 L0: Layout + ViewLayoutLaunchArg,
131 L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
132 > core::cmp::Eq for ChainCompilationArg<L0, L1>
133 {
134 }
135
136 impl<
137 L0: Layout + ViewLayoutLaunchArg,
138 L1: Layout<SourceCoordinates = L0::Coordinates> + ViewLayoutLaunchArg,
139 > ViewLayoutLaunchArg for Chain<L0, L1>
140 {
141 type RuntimeArg<R: Runtime> = ChainLaunch<L0, L1, R>;
142 type CompilationArg = ChainCompilationArg<L0, L1>;
143
144 fn register<R: Runtime, B: BufferArg>(
145 arg: Self::RuntimeArg<R>,
146 buffer: &B,
147 ty: Type,
148 launcher: &mut KernelLauncher<R>,
149 ) -> Self::CompilationArg {
150 ChainCompilationArg {
151 l0: L0::register(arg.l0, buffer, ty, launcher),
152 l1: L1::register(arg.l1, buffer, ty, launcher),
153 }
154 }
155 fn expand(
156 arg: &Self::CompilationArg,
157 ty: Type,
158 builder: &mut KernelBuilder,
159 ) -> <Self as CubeType>::ExpandType {
160 ChainExpand {
161 l0: L0::expand(&arg.l0, ty, builder),
162 l1: L1::expand(&arg.l1, ty, builder),
163 }
164 }
165 fn expand_output(
166 arg: &Self::CompilationArg,
167 ty: Type,
168 builder: &mut KernelBuilder,
169 ) -> <Self as CubeType>::ExpandType {
170 ChainExpand {
171 l0: L0::expand_output(&arg.l0, ty, builder),
172 l1: L1::expand_output(&arg.l1, ty, builder),
173 }
174 }
175 }
176}