1use core::marker::PhantomData;
10
11use crate::dtype::{CutlassDtype, GemmSupported, SmArch};
12use crate::kernels;
13use crate::plan_cache::PlanKey;
14
/// Tensor layout tag attached to a convolution request.
///
/// Each variant maps to one CUTLASS layout type name (see
/// [`ConvLayout::cutlass_layout`]) and to one lowercase tag (see
/// [`ConvLayout::short_name`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConvLayout {
    /// Rendered as `cutlass::layout::TensorNHWC`.
    Nhwc,
    /// Rendered as `cutlass::layout::TensorNCHW`.
    Nchw,
}

impl ConvLayout {
    /// Returns the fully qualified CUTLASS layout type name for this layout.
    pub fn cutlass_layout(self) -> &'static str {
        // Exhaustive on purpose: adding a variant must fail to compile here.
        match self {
            Self::Nchw => "cutlass::layout::TensorNCHW",
            Self::Nhwc => "cutlass::layout::TensorNHWC",
        }
    }

    /// Returns the lowercase tag for this layout (`"nhwc"` or `"nchw"`).
    pub fn short_name(self) -> &'static str {
        match self {
            Self::Nchw => "nchw",
            Self::Nhwc => "nhwc",
        }
    }
}
38
/// Problem size of a 2-D convolution.
///
/// NOTE(review): the field names look like the CUTLASS `Conv2dProblemSize`
/// convention (`n` batch, `h`/`w` input extent, `c` input channels, `k` output
/// channels, `r`/`s` filter extent) — confirm against the kernel renderer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConvShape {
    pub n: u32,
    pub h: u32,
    pub w: u32,
    pub c: u32,
    pub k: u32,
    pub r: u32,
    pub s: u32,
    /// Padding for the `h` axis.
    pub pad_h: u32,
    /// Padding for the `w` axis.
    pub pad_w: u32,
    /// Stride for the `h` axis.
    pub stride_h: u32,
    /// Stride for the `w` axis.
    pub stride_w: u32,
    /// Dilation for the `h` axis.
    pub dil_h: u32,
    /// Dilation for the `w` axis.
    pub dil_w: u32,
}

impl ConvShape {
    /// Builds a shape from the seven extents, with zero padding, unit stride
    /// and unit dilation on both axes.
    ///
    /// The remaining fields are public, so callers can overwrite them after
    /// construction.
    pub fn nhwc(n: u32, h: u32, w: u32, c: u32, k: u32, r: u32, s: u32) -> Self {
        const NO_PADDING: u32 = 0;
        const UNIT: u32 = 1;

        Self {
            n,
            h,
            w,
            c,
            k,
            r,
            s,
            pad_h: NO_PADDING,
            pad_w: NO_PADDING,
            stride_h: UNIT,
            stride_w: UNIT,
            dil_h: UNIT,
            dil_w: UNIT,
        }
    }
}
78
/// Which convolution pass a request describes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum ConvKind {
    /// Tagged `"fprop"`.
    Fprop,
    /// Tagged `"dgrad"`.
    Dgrad,
    /// Tagged `"wgrad"`.
    Wgrad,
}

impl ConvKind {
    /// Returns the lowercase tag for this kind (`"fprop"`, `"dgrad"`, `"wgrad"`).
    pub(crate) fn short_name(self) -> &'static str {
        match self {
            Self::Wgrad => "wgrad",
            Self::Dgrad => "dgrad",
            Self::Fprop => "fprop",
        }
    }

    /// Returns the qualified CUTLASS type name associated with this kind.
    ///
    /// NOTE(review): upstream CUTLASS appears to ship a single device-level
    /// `cutlass::conv::device::ImplicitGemmConvolution` template; confirm that
    /// the `...Dgrad` / `...Wgrad` names resolve to something the generated
    /// source actually defines.
    pub(crate) fn cutlass_kernel(self) -> &'static str {
        match self {
            Self::Wgrad => "cutlass::conv::device::ImplicitGemmConvolutionWgrad",
            Self::Dgrad => "cutlass::conv::device::ImplicitGemmConvolutionDgrad",
            Self::Fprop => "cutlass::conv::device::ImplicitGemmConvolution",
        }
    }
}
104
105macro_rules! conv_request {
106 ($name:ident, $kind:expr) => {
107 #[derive(Debug, Clone)]
108 pub struct $name<T: GemmSupported> {
109 pub shape: ConvShape,
110 pub layout: ConvLayout,
111 pub accum_dtype: CutlassDtype,
112 pub output_dtype: CutlassDtype,
113 pub arch: SmArch,
114 _t: PhantomData<fn() -> T>,
115 }
116
117 impl<T: GemmSupported> $name<T> {
118 pub fn new(shape: ConvShape, arch: SmArch) -> Self {
119 Self {
120 shape,
121 layout: ConvLayout::Nhwc,
122 accum_dtype: CutlassDtype::F32,
123 output_dtype: T::DTYPE,
124 arch,
125 _t: PhantomData,
126 }
127 }
128
129 pub fn with_layout(mut self, layout: ConvLayout) -> Self {
130 self.layout = layout;
131 self
132 }
133
134 pub fn with_accum_dtype(mut self, dt: CutlassDtype) -> Self {
135 self.accum_dtype = dt;
136 self
137 }
138
139 pub fn with_output_dtype(mut self, dt: CutlassDtype) -> Self {
140 self.output_dtype = dt;
141 self
142 }
143
144 pub fn plan_key(&self) -> PlanKey {
145 PlanKey::conv::<T>(
146 $kind,
147 self.shape,
148 self.layout,
149 self.accum_dtype,
150 self.output_dtype,
151 self.arch,
152 )
153 }
154
155 pub fn render_cu(&self) -> (String, String) {
156 kernels::render_conv::<T>(
157 $kind,
158 self.shape,
159 self.layout,
160 self.accum_dtype,
161 self.output_dtype,
162 self.arch,
163 )
164 }
165 }
166 };
167}
168
169conv_request!(ConvFwdRequest, ConvKind::Fprop);
170conv_request!(ConvDgradRequest, ConvKind::Dgrad);
171conv_request!(ConvWgradRequest, ConvKind::Wgrad);
172
173pub trait CutlassConvDispatch: Send + 'static {
175 fn plan_key(&self) -> PlanKey;
176 fn render_cu(&self) -> (String, String);
177 fn dtype(&self) -> CutlassDtype;
178 fn arch(&self) -> SmArch;
179 fn shape(&self) -> ConvShape;
180 fn kind_name(&self) -> &'static str;
181}
182
183macro_rules! impl_dispatch {
184 ($name:ident, $kind:expr) => {
185 impl<T: GemmSupported> CutlassConvDispatch for $name<T> {
186 fn plan_key(&self) -> PlanKey {
187 $name::plan_key(self)
188 }
189
190 fn render_cu(&self) -> (String, String) {
191 $name::render_cu(self)
192 }
193
194 fn dtype(&self) -> CutlassDtype {
195 T::DTYPE
196 }
197
198 fn arch(&self) -> SmArch {
199 self.arch
200 }
201
202 fn shape(&self) -> ConvShape {
203 self.shape
204 }
205
206 fn kind_name(&self) -> &'static str {
207 $kind.short_name()
208 }
209 }
210 };
211}
212
213impl_dispatch!(ConvFwdRequest, ConvKind::Fprop);
214impl_dispatch!(ConvDgradRequest, ConvKind::Dgrad);
215impl_dispatch!(ConvWgradRequest, ConvKind::Wgrad);
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::F16;

    /// Builds one request per convolution kind for the same shape and checks
    /// that plan keys, rendered names and kind tags tell them apart.
    #[test]
    fn conv_fwd_dgrad_wgrad_round_trip() {
        let shape = ConvShape::nhwc(8, 56, 56, 64, 128, 3, 3);

        let fwd = ConvFwdRequest::<F16>::new(shape, SmArch::Sm80);
        let dgrad = ConvDgradRequest::<F16>::new(shape, SmArch::Sm80);
        let wgrad = ConvWgradRequest::<F16>::new(shape, SmArch::Sm80);

        // Plan keys must be pairwise distinct across the three kinds.
        let (key_fwd, key_dgrad, key_wgrad) =
            (fwd.plan_key(), dgrad.plan_key(), wgrad.plan_key());
        assert_ne!(key_fwd, key_dgrad);
        assert_ne!(key_dgrad, key_wgrad);
        assert_ne!(key_fwd, key_wgrad);

        // The rendered name carries the kind tag; the forward source mentions
        // the CUTLASS implicit-GEMM type.
        let (fprop_source, fprop_name) = fwd.render_cu();
        assert!(fprop_name.contains("fprop"));
        assert!(fprop_source.contains("ImplicitGemmConvolution"));

        let dgrad_name = dgrad.render_cu().1;
        assert!(dgrad_name.contains("dgrad"));

        let wgrad_name = wgrad.render_cu().1;
        assert!(wgrad_name.contains("wgrad"));

        // `kind_name` exists only on the dispatch trait.
        assert_eq!(CutlassConvDispatch::kind_name(&fwd), "fprop");
        assert_eq!(CutlassConvDispatch::kind_name(&dgrad), "dgrad");
        assert_eq!(CutlassConvDispatch::kind_name(&wgrad), "wgrad");

        // Changing only the layout must change the plan key.
        let nchw_key = ConvFwdRequest::<F16>::new(shape, SmArch::Sm80)
            .with_layout(ConvLayout::Nchw)
            .plan_key();
        assert_ne!(fwd.plan_key(), nchw_key);
    }
}