Skip to main content

atomr_accel_cutlass/
conv.rs

1//! Implicit-GEMM convolution requests.
2//!
3//! CUTLASS exposes three implicit-GEMM kernels per layout:
4//! `Conv2dFprop` (forward), `Conv2dDgrad` (gradient w.r.t. input),
5//! `Conv2dWgrad` (gradient w.r.t. filter). We mirror that surface
6//! here. Layout selection (NHWC / NCHW) and dtype propagate through
7//! the same template render pipeline as GEMM.
8
9use core::marker::PhantomData;
10
11use crate::dtype::{CutlassDtype, GemmSupported, SmArch};
12use crate::kernels;
13use crate::plan_cache::PlanKey;
14
/// Memory layout of the activation tensors handed to the convolution.
///
/// CUTLASS's implicit-GEMM kernels are NHWC-first; NCHW is a translated
/// fallback.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConvLayout {
    /// Channels-last layout (`cutlass::layout::TensorNHWC`).
    Nhwc,
    /// Channels-first layout (`cutlass::layout::TensorNCHW`).
    Nchw,
}

impl ConvLayout {
    /// Single exhaustive table of `(cutlass_type, short_tag)` per layout, so
    /// adding a variant forces both names to be supplied together.
    fn names(self) -> (&'static str, &'static str) {
        match self {
            Self::Nhwc => ("cutlass::layout::TensorNHWC", "nhwc"),
            Self::Nchw => ("cutlass::layout::TensorNCHW", "nchw"),
        }
    }

    /// Fully qualified CUTLASS layout type to splice into rendered templates.
    pub fn cutlass_layout(self) -> &'static str {
        let (cutlass_type, _) = self.names();
        cutlass_type
    }

    /// Lower-case tag for this layout (e.g. `"nhwc"`).
    pub fn short_name(self) -> &'static str {
        let (_, tag) = self.names();
        tag
    }
}
38
/// Convolution problem description.
///
/// Covers the input extents `(N, H, W, C)`, the output-channel count `K`,
/// the filter extents `(R, S)`, and per-dimension padding, stride and
/// dilation. `K` is supplied by the caller. Only the output spatial extents
/// `(P, Q)` are absent from this struct; they are derived inside the kernel
/// template.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConvShape {
    /// Batch size `N`.
    pub n: u32,
    /// Input height `H`.
    pub h: u32,
    /// Input width `W`.
    pub w: u32,
    /// Input channels `C`.
    pub c: u32,
    /// Output channels / filter count `K` (CUTLASS naming convention).
    pub k: u32,
    /// Filter height `R`.
    pub r: u32,
    /// Filter width `S`.
    pub s: u32,
    /// Padding along the height dimension.
    pub pad_h: u32,
    /// Padding along the width dimension.
    pub pad_w: u32,
    /// Stride along the height dimension.
    pub stride_h: u32,
    /// Stride along the width dimension.
    pub stride_w: u32,
    /// Dilation along the height dimension.
    pub dil_h: u32,
    /// Dilation along the width dimension.
    pub dil_w: u32,
}
57
58impl ConvShape {
59    /// Convenience builder: stride / pad / dilation default to 1 / 0 / 1.
60    pub fn nhwc(n: u32, h: u32, w: u32, c: u32, k: u32, r: u32, s: u32) -> Self {
61        Self {
62            n,
63            h,
64            w,
65            c,
66            k,
67            r,
68            s,
69            pad_h: 0,
70            pad_w: 0,
71            stride_h: 1,
72            stride_w: 1,
73            dil_h: 1,
74            dil_w: 1,
75        }
76    }
77}
78
/// Discriminator for which convolution gradient we're emitting.
///
/// CUTLASS does not have one device-level class per gradient. All three
/// operators are driven through the same device adapter,
/// `cutlass::conv::device::ImplicitGemmConvolution`. The gradient is chosen
/// by the kernel-level template that adapter wraps:
/// `cutlass::conv::kernel::DefaultConv2dFprop`, `DefaultConv2dDgrad` or
/// `DefaultConv2dWgrad`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum ConvKind {
    /// Forward propagation (`Conv2dFprop`).
    Fprop,
    /// Gradient w.r.t. the input (`Conv2dDgrad`).
    Dgrad,
    /// Gradient w.r.t. the filter (`Conv2dWgrad`).
    Wgrad,
}

impl ConvKind {
    /// Lower-case tag for this kind (`"fprop"`, `"dgrad"` or `"wgrad"`).
    pub(crate) fn short_name(self) -> &'static str {
        match self {
            ConvKind::Fprop => "fprop",
            ConvKind::Dgrad => "dgrad",
            ConvKind::Wgrad => "wgrad",
        }
    }

    /// Fully qualified CUTLASS device-level operator type.
    ///
    /// The same for every kind: CUTLASS has no
    /// `ImplicitGemmConvolutionDgrad` / `ImplicitGemmConvolutionWgrad`
    /// classes, so emitting those names yields uncompilable source. The
    /// gradient must instead be selected via the wrapped
    /// `cutlass::conv::kernel::DefaultConv2d{Fprop,Dgrad,Wgrad}` kernel.
    pub(crate) fn cutlass_kernel(self) -> &'static str {
        // Kept as an exhaustive match so a new variant (e.g. a deconvolution
        // kind) forces this mapping to be revisited.
        match self {
            ConvKind::Fprop | ConvKind::Dgrad | ConvKind::Wgrad => {
                "cutlass::conv::device::ImplicitGemmConvolution"
            }
        }
    }
}
104
105macro_rules! conv_request {
106    ($name:ident, $kind:expr) => {
107        #[derive(Debug, Clone)]
108        pub struct $name<T: GemmSupported> {
109            pub shape: ConvShape,
110            pub layout: ConvLayout,
111            pub accum_dtype: CutlassDtype,
112            pub output_dtype: CutlassDtype,
113            pub arch: SmArch,
114            _t: PhantomData<fn() -> T>,
115        }
116
117        impl<T: GemmSupported> $name<T> {
118            pub fn new(shape: ConvShape, arch: SmArch) -> Self {
119                Self {
120                    shape,
121                    layout: ConvLayout::Nhwc,
122                    accum_dtype: CutlassDtype::F32,
123                    output_dtype: T::DTYPE,
124                    arch,
125                    _t: PhantomData,
126                }
127            }
128
129            pub fn with_layout(mut self, layout: ConvLayout) -> Self {
130                self.layout = layout;
131                self
132            }
133
134            pub fn with_accum_dtype(mut self, dt: CutlassDtype) -> Self {
135                self.accum_dtype = dt;
136                self
137            }
138
139            pub fn with_output_dtype(mut self, dt: CutlassDtype) -> Self {
140                self.output_dtype = dt;
141                self
142            }
143
144            pub fn plan_key(&self) -> PlanKey {
145                PlanKey::conv::<T>(
146                    $kind,
147                    self.shape,
148                    self.layout,
149                    self.accum_dtype,
150                    self.output_dtype,
151                    self.arch,
152                )
153            }
154
155            pub fn render_cu(&self) -> (String, String) {
156                kernels::render_conv::<T>(
157                    $kind,
158                    self.shape,
159                    self.layout,
160                    self.accum_dtype,
161                    self.output_dtype,
162                    self.arch,
163                )
164            }
165        }
166    };
167}
168
169conv_request!(ConvFwdRequest, ConvKind::Fprop);
170conv_request!(ConvDgradRequest, ConvKind::Dgrad);
171conv_request!(ConvWgradRequest, ConvKind::Wgrad);
172
173/// Erased dispatch surface for any convolution gradient.
174pub trait CutlassConvDispatch: Send + 'static {
175    fn plan_key(&self) -> PlanKey;
176    fn render_cu(&self) -> (String, String);
177    fn dtype(&self) -> CutlassDtype;
178    fn arch(&self) -> SmArch;
179    fn shape(&self) -> ConvShape;
180    fn kind_name(&self) -> &'static str;
181}
182
183macro_rules! impl_dispatch {
184    ($name:ident, $kind:expr) => {
185        impl<T: GemmSupported> CutlassConvDispatch for $name<T> {
186            fn plan_key(&self) -> PlanKey {
187                $name::plan_key(self)
188            }
189
190            fn render_cu(&self) -> (String, String) {
191                $name::render_cu(self)
192            }
193
194            fn dtype(&self) -> CutlassDtype {
195                T::DTYPE
196            }
197
198            fn arch(&self) -> SmArch {
199                self.arch
200            }
201
202            fn shape(&self) -> ConvShape {
203                self.shape
204            }
205
206            fn kind_name(&self) -> &'static str {
207                $kind.short_name()
208            }
209        }
210    };
211}
212
213impl_dispatch!(ConvFwdRequest, ConvKind::Fprop);
214impl_dispatch!(ConvDgradRequest, ConvKind::Dgrad);
215impl_dispatch!(ConvWgradRequest, ConvKind::Wgrad);
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::F16;

    #[test]
    fn conv_fwd_dgrad_wgrad_round_trip() {
        let arch = SmArch::Sm80;
        let shape = ConvShape::nhwc(8, 56, 56, 64, 128, 3, 3);

        let fwd = ConvFwdRequest::<F16>::new(shape, arch);
        let dgrad = ConvDgradRequest::<F16>::new(shape, arch);
        let wgrad = ConvWgradRequest::<F16>::new(shape, arch);

        // Each gradient kind must occupy its own plan-cache slot.
        let keys = [fwd.plan_key(), dgrad.plan_key(), wgrad.plan_key()];
        for (i, earlier) in keys.iter().enumerate() {
            for later in &keys[i + 1..] {
                assert_ne!(earlier, later);
            }
        }

        // Rendered kernel names carry the kind tag; fprop source names the
        // CUTLASS device adapter.
        let (fwd_src, fwd_name) = fwd.render_cu();
        assert!(fwd_name.contains("fprop"));
        assert!(fwd_src.contains("ImplicitGemmConvolution"));
        assert!(dgrad.render_cu().1.contains("dgrad"));
        assert!(wgrad.render_cu().1.contains("wgrad"));

        // The type-erased dispatch trait reports the same tags.
        assert_eq!(
            [fwd.kind_name(), dgrad.kind_name(), wgrad.kind_name()],
            ["fprop", "dgrad", "wgrad"]
        );

        // Changing only the layout must change the key.
        let fwd_nchw = ConvFwdRequest::<F16>::new(shape, arch).with_layout(ConvLayout::Nchw);
        assert_ne!(fwd.plan_key(), fwd_nchw.plan_key());
    }
}
258}