Skip to main content

baracuda_kernels/image/
grid_sample_backward.rs

1//! `grid_sample` BW plan — Category T (2-D).
2//!
3//! Adjoint of [`crate::image::GridSamplePlan`]: scatters the upstream
4//! gradient back into the input via the same bilinear weights
5//! (atomicAdd into `dinput`), and accumulates the analytical
6//! coordinate derivatives into `dgrid` (also atomic; one (n, oh, ow)
7//! cell aggregates contributions across C).
8//!
9//! Caller MUST pre-zero both `dinput` and `dgrid`.
10
11use core::ffi::c_void;
12use core::marker::PhantomData;
13
14use baracuda_cutlass::{Error, Result};
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17    ArchSku, BackendKind, Element, ElementKind, ImageKind, KernelSku, MathPrecision, OpCategory,
18    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
19};
20
21use super::map_status;
22
23/// Descriptor for `grid_sample_backward`.
24#[derive(Copy, Clone, Debug)]
25pub struct GridSampleBackwardDescriptor {
26    /// Batch.
27    pub n: i32,
28    /// Channels.
29    pub c: i32,
30    /// Input height.
31    pub ih: i32,
32    /// Input width.
33    pub iw: i32,
34    /// Output height.
35    pub oh: i32,
36    /// Output width.
37    pub ow: i32,
38    /// Value element type.
39    pub element: ElementKind,
40}
41
42/// Args bundle for a `grid_sample_backward` launch.
43pub struct GridSampleBackwardArgs<'a, T: Element> {
44    /// Upstream gradient `[N, C, OH, OW]`.
45    pub dout: TensorRef<'a, T, 4>,
46    /// Saved FW input `[N, C, IH, IW]`.
47    pub input: TensorRef<'a, T, 4>,
48    /// Saved FW grid `[N, OH, OW, 2]`.
49    pub grid: TensorRef<'a, T, 4>,
50    /// Gradient w.r.t. input `[N, C, IH, IW]`. Caller pre-zeros.
51    pub dinput: TensorMut<'a, T, 4>,
52    /// Gradient w.r.t. grid `[N, OH, OW, 2]`. Caller pre-zeros.
53    pub dgrid: TensorMut<'a, T, 4>,
54}
55
56/// `grid_sample_backward` plan.
57///
58/// Adjoint of [`crate::GridSamplePlan`]: scatter `dout` into both
59/// `dinput` (4 bilinear weights via atomicAdd) and `dgrid` (analytic
60/// chain rule through the coordinate mapping).
61///
62/// **When to use**: BW for [`GridSamplePlan`](crate::GridSamplePlan).
63/// Caller retains FW `input` and `grid`.
64///
65/// **Dtypes**: `{f32, f64}`.
66///
67/// **Shape limits**: rank-4 NCHW + grid `[N, OH, OW, 2]`.
68///
69/// **Workspace**: none. Caller MUST zero `dinput` and `dgrid` before
70/// launch.
71///
72/// **Precision guarantee**: **non-deterministic** (atomicAdd into
73/// `dinput`).
74pub struct GridSampleBackwardPlan<T: Element> {
75    desc: GridSampleBackwardDescriptor,
76    sku: KernelSku,
77    _marker: PhantomData<T>,
78}
79
80impl<T: Element> GridSampleBackwardPlan<T> {
81    /// Pick a kernel.
82    pub fn select(
83        _stream: &Stream,
84        desc: &GridSampleBackwardDescriptor,
85        _pref: PlanPreference,
86    ) -> Result<Self> {
87        if desc.element != T::KIND {
88            return Err(Error::Unsupported(
89                "baracuda-kernels::GridSampleBackwardPlan: descriptor element != T",
90            ));
91        }
92        if desc.n < 0 || desc.c < 0 || desc.ih < 0 || desc.iw < 0 || desc.oh < 0 || desc.ow < 0 {
93            return Err(Error::InvalidProblem(
94                "baracuda-kernels::GridSampleBackwardPlan: all extents must be non-negative",
95            ));
96        }
97        if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
98            return Err(Error::Unsupported(
99                "baracuda-kernels::GridSampleBackwardPlan: only `f32`, `f64` wired",
100            ));
101        }
102        let precision_guarantee = PrecisionGuarantee {
103            math_precision: if T::KIND == ElementKind::F64 {
104                MathPrecision::F64
105            } else {
106                MathPrecision::F32
107            },
108            accumulator: T::KIND,
109            bit_stable_on_same_hardware: false,
110            deterministic: false,
111        };
112        let sku = KernelSku {
113            category: OpCategory::Image,
114            op: ImageKind::GridSample2dBackward as u16,
115            element: T::KIND,
116            aux_element: None,
117            layout: None,
118            epilogue: None,
119            arch: ArchSku::Sm80,
120            backend: BackendKind::Bespoke,
121            precision_guarantee,
122        };
123        Ok(Self {
124            desc: *desc,
125            sku,
126            _marker: PhantomData,
127        })
128    }
129
130    /// Validate args.
131    pub fn can_implement(&self, args: &GridSampleBackwardArgs<'_, T>) -> Result<()> {
132        if args.dout.shape != [self.desc.n, self.desc.c, self.desc.oh, self.desc.ow]
133            || args.input.shape != [self.desc.n, self.desc.c, self.desc.ih, self.desc.iw]
134            || args.grid.shape != [self.desc.n, self.desc.oh, self.desc.ow, 2]
135            || args.dinput.shape != [self.desc.n, self.desc.c, self.desc.ih, self.desc.iw]
136            || args.dgrid.shape != [self.desc.n, self.desc.oh, self.desc.ow, 2]
137        {
138            return Err(Error::InvalidProblem(
139                "baracuda-kernels::GridSampleBackwardPlan: operand shape mismatch",
140            ));
141        }
142        Ok(())
143    }
144
145    /// Workspace (zero).
146    #[inline]
147    pub fn workspace_size(&self) -> usize {
148        0
149    }
150
151    /// Identity.
152    #[inline]
153    pub fn sku(&self) -> KernelSku {
154        self.sku
155    }
156
157    /// Numerical guarantees.
158    #[inline]
159    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
160        self.sku.precision_guarantee
161    }
162
163    /// Launch.
164    pub fn run(
165        &self,
166        stream: &Stream,
167        _workspace: Workspace<'_>,
168        args: GridSampleBackwardArgs<'_, T>,
169    ) -> Result<()> {
170        self.can_implement(&args)?;
171        if args.dout.numel() == 0 {
172            return Ok(());
173        }
174        let dout_ptr = args.dout.data.as_raw().0 as *const c_void;
175        let input_ptr = args.input.data.as_raw().0 as *const c_void;
176        let grid_ptr = args.grid.data.as_raw().0 as *const c_void;
177        let din_ptr = args.dinput.data.as_raw().0 as *mut c_void;
178        let dgrid_ptr = args.dgrid.data.as_raw().0 as *mut c_void;
179        let stream_ptr = stream.as_raw() as *mut c_void;
180        let status = match T::KIND {
181            ElementKind::F32 => unsafe {
182                baracuda_kernels_sys::baracuda_kernels_grid_sample_2d_backward_f32_run(
183                    self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
184                    self.desc.oh, self.desc.ow,
185                    dout_ptr, input_ptr, grid_ptr,
186                    din_ptr, dgrid_ptr,
187                    core::ptr::null_mut(), 0, stream_ptr,
188                )
189            },
190            ElementKind::F64 => unsafe {
191                baracuda_kernels_sys::baracuda_kernels_grid_sample_2d_backward_f64_run(
192                    self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
193                    self.desc.oh, self.desc.ow,
194                    dout_ptr, input_ptr, grid_ptr,
195                    din_ptr, dgrid_ptr,
196                    core::ptr::null_mut(), 0, stream_ptr,
197                )
198            },
199            _ => {
200                return Err(Error::Unsupported(
201                    "baracuda-kernels::GridSampleBackwardPlan::run reached unimplemented dtype",
202                ));
203            }
204        };
205        map_status(status)
206    }
207}