1use core::ffi::c_void;
12use core::marker::PhantomData;
13
14use baracuda_cutlass::{Error, Result};
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17 ArchSku, BackendKind, Element, ElementKind, ImageKind, KernelSku, MathPrecision, OpCategory,
18 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
19};
20
21use super::map_status;
22
/// Problem description for a [`GridSampleBackwardPlan`].
///
/// The extents fix the operand shapes that
/// [`GridSampleBackwardPlan::can_implement`] checks:
/// `input` / `dinput` are `[n, c, ih, iw]`, `dout` is `[n, c, oh, ow]`, and
/// `grid` / `dgrid` are `[n, oh, ow, 2]`. `select` rejects negative extents.
#[derive(Copy, Clone, Debug)]
pub struct GridSampleBackwardDescriptor {
    /// Leading extent (dim 0) shared by every operand; presumably the batch size.
    pub n: i32,
    /// Dim 1 of `input`, `dinput` and `dout` (channel-like extent, by naming).
    pub c: i32,
    /// Dim 2 of `input` / `dinput` (input height, by naming).
    pub ih: i32,
    /// Dim 3 of `input` / `dinput` (input width, by naming).
    pub iw: i32,
    /// Dim 2 of `dout` and dim 1 of `grid` / `dgrid` (output height, by naming).
    pub oh: i32,
    /// Dim 3 of `dout` and dim 2 of `grid` / `dgrid` (output width, by naming).
    pub ow: i32,
    /// Element type of all operands; must equal the plan's `T::KIND`, and only
    /// `F32` / `F64` are accepted by `select`.
    pub element: ElementKind,
}
41
/// Operands for one [`GridSampleBackwardPlan::run`] call.
///
/// Every tensor is rank 4 and borrowed for `'a`; shapes must match the plan's
/// descriptor as checked by [`GridSampleBackwardPlan::can_implement`].
pub struct GridSampleBackwardArgs<'a, T: Element> {
    /// Incoming gradient of the forward output (by naming); shape `[n, c, oh, ow]`.
    pub dout: TensorRef<'a, T, 4>,
    /// Forward-pass input; shape `[n, c, ih, iw]`.
    pub input: TensorRef<'a, T, 4>,
    /// Forward-pass sampling grid; shape `[n, oh, ow, 2]`.
    pub grid: TensorRef<'a, T, 4>,
    /// Written by the kernel: gradient w.r.t. `input` (by naming); shape `[n, c, ih, iw]`.
    pub dinput: TensorMut<'a, T, 4>,
    /// Written by the kernel: gradient w.r.t. `grid` (by naming); shape `[n, oh, ow, 2]`.
    pub dgrid: TensorMut<'a, T, 4>,
}
55
/// Launch plan for the 2-D grid-sample backward kernel with element type `T`.
///
/// Built by [`GridSampleBackwardPlan::select`]; validated and executed via
/// `can_implement` / `run`.
pub struct GridSampleBackwardPlan<T: Element> {
    /// Copy of the descriptor the plan was selected for; `run` forwards its
    /// extents to the kernel and `can_implement` checks shapes against it.
    desc: GridSampleBackwardDescriptor,
    /// Kernel identity and precision metadata, returned by `sku()`.
    sku: KernelSku,
    /// Ties the plan to `T` without storing a value of it.
    _marker: PhantomData<T>,
}
79
80impl<T: Element> GridSampleBackwardPlan<T> {
81 pub fn select(
83 _stream: &Stream,
84 desc: &GridSampleBackwardDescriptor,
85 _pref: PlanPreference,
86 ) -> Result<Self> {
87 if desc.element != T::KIND {
88 return Err(Error::Unsupported(
89 "baracuda-kernels::GridSampleBackwardPlan: descriptor element != T",
90 ));
91 }
92 if desc.n < 0 || desc.c < 0 || desc.ih < 0 || desc.iw < 0 || desc.oh < 0 || desc.ow < 0 {
93 return Err(Error::InvalidProblem(
94 "baracuda-kernels::GridSampleBackwardPlan: all extents must be non-negative",
95 ));
96 }
97 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
98 return Err(Error::Unsupported(
99 "baracuda-kernels::GridSampleBackwardPlan: only `f32`, `f64` wired",
100 ));
101 }
102 let precision_guarantee = PrecisionGuarantee {
103 math_precision: if T::KIND == ElementKind::F64 {
104 MathPrecision::F64
105 } else {
106 MathPrecision::F32
107 },
108 accumulator: T::KIND,
109 bit_stable_on_same_hardware: false,
110 deterministic: false,
111 };
112 let sku = KernelSku {
113 category: OpCategory::Image,
114 op: ImageKind::GridSample2dBackward as u16,
115 element: T::KIND,
116 aux_element: None,
117 layout: None,
118 epilogue: None,
119 arch: ArchSku::Sm80,
120 backend: BackendKind::Bespoke,
121 precision_guarantee,
122 };
123 Ok(Self {
124 desc: *desc,
125 sku,
126 _marker: PhantomData,
127 })
128 }
129
130 pub fn can_implement(&self, args: &GridSampleBackwardArgs<'_, T>) -> Result<()> {
132 if args.dout.shape != [self.desc.n, self.desc.c, self.desc.oh, self.desc.ow]
133 || args.input.shape != [self.desc.n, self.desc.c, self.desc.ih, self.desc.iw]
134 || args.grid.shape != [self.desc.n, self.desc.oh, self.desc.ow, 2]
135 || args.dinput.shape != [self.desc.n, self.desc.c, self.desc.ih, self.desc.iw]
136 || args.dgrid.shape != [self.desc.n, self.desc.oh, self.desc.ow, 2]
137 {
138 return Err(Error::InvalidProblem(
139 "baracuda-kernels::GridSampleBackwardPlan: operand shape mismatch",
140 ));
141 }
142 Ok(())
143 }
144
145 #[inline]
147 pub fn workspace_size(&self) -> usize {
148 0
149 }
150
151 #[inline]
153 pub fn sku(&self) -> KernelSku {
154 self.sku
155 }
156
157 #[inline]
159 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
160 self.sku.precision_guarantee
161 }
162
163 pub fn run(
165 &self,
166 stream: &Stream,
167 _workspace: Workspace<'_>,
168 args: GridSampleBackwardArgs<'_, T>,
169 ) -> Result<()> {
170 self.can_implement(&args)?;
171 if args.dout.numel() == 0 {
172 return Ok(());
173 }
174 let dout_ptr = args.dout.data.as_raw().0 as *const c_void;
175 let input_ptr = args.input.data.as_raw().0 as *const c_void;
176 let grid_ptr = args.grid.data.as_raw().0 as *const c_void;
177 let din_ptr = args.dinput.data.as_raw().0 as *mut c_void;
178 let dgrid_ptr = args.dgrid.data.as_raw().0 as *mut c_void;
179 let stream_ptr = stream.as_raw() as *mut c_void;
180 let status = match T::KIND {
181 ElementKind::F32 => unsafe {
182 baracuda_kernels_sys::baracuda_kernels_grid_sample_2d_backward_f32_run(
183 self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
184 self.desc.oh, self.desc.ow,
185 dout_ptr, input_ptr, grid_ptr,
186 din_ptr, dgrid_ptr,
187 core::ptr::null_mut(), 0, stream_ptr,
188 )
189 },
190 ElementKind::F64 => unsafe {
191 baracuda_kernels_sys::baracuda_kernels_grid_sample_2d_backward_f64_run(
192 self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
193 self.desc.oh, self.desc.ow,
194 dout_ptr, input_ptr, grid_ptr,
195 din_ptr, dgrid_ptr,
196 core::ptr::null_mut(), 0, stream_ptr,
197 )
198 },
199 _ => {
200 return Err(Error::Unsupported(
201 "baracuda-kernels::GridSampleBackwardPlan::run reached unimplemented dtype",
202 ));
203 }
204 };
205 map_status(status)
206 }
207}