1use core::ffi::c_void;
11use core::marker::PhantomData;
12
13use baracuda_cutlass::{Error, Result};
14use baracuda_driver::Stream;
15use baracuda_kernels_types::{
16 ArchSku, BackendKind, Element, ElementKind, ImageKind, KernelSku, MathPrecision, OpCategory,
17 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
18};
19
20use super::map_status;
21
22#[derive(Copy, Clone, Debug)]
24pub struct GridSampleDescriptor {
25 pub n: i32,
27 pub c: i32,
29 pub ih: i32,
31 pub iw: i32,
33 pub oh: i32,
35 pub ow: i32,
37 pub element: ElementKind,
39}
40
41pub struct GridSampleArgs<'a, T: Element> {
43 pub input: TensorRef<'a, T, 4>,
45 pub grid: TensorRef<'a, T, 4>,
47 pub output: TensorMut<'a, T, 4>,
49}
50
51pub struct GridSamplePlan<T: Element> {
70 desc: GridSampleDescriptor,
71 sku: KernelSku,
72 _marker: PhantomData<T>,
73}
74
75impl<T: Element> GridSamplePlan<T> {
76 pub fn select(
78 _stream: &Stream,
79 desc: &GridSampleDescriptor,
80 _pref: PlanPreference,
81 ) -> Result<Self> {
82 if desc.element != T::KIND {
83 return Err(Error::Unsupported(
84 "baracuda-kernels::GridSamplePlan: descriptor element != T",
85 ));
86 }
87 if desc.n < 0 || desc.c < 0 || desc.ih < 0 || desc.iw < 0 || desc.oh < 0 || desc.ow < 0 {
88 return Err(Error::InvalidProblem(
89 "baracuda-kernels::GridSamplePlan: all extents must be non-negative",
90 ));
91 }
92 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
93 return Err(Error::Unsupported(
94 "baracuda-kernels::GridSamplePlan: only `f32`, `f64` wired",
95 ));
96 }
97 let precision_guarantee = PrecisionGuarantee {
98 math_precision: if T::KIND == ElementKind::F64 {
99 MathPrecision::F64
100 } else {
101 MathPrecision::F32
102 },
103 accumulator: T::KIND,
104 bit_stable_on_same_hardware: true,
105 deterministic: true,
106 };
107 let sku = KernelSku {
108 category: OpCategory::Image,
109 op: ImageKind::GridSample2d as u16,
110 element: T::KIND,
111 aux_element: None,
112 layout: None,
113 epilogue: None,
114 arch: ArchSku::Sm80,
115 backend: BackendKind::Bespoke,
116 precision_guarantee,
117 };
118 Ok(Self {
119 desc: *desc,
120 sku,
121 _marker: PhantomData,
122 })
123 }
124
125 pub fn can_implement(&self, args: &GridSampleArgs<'_, T>) -> Result<()> {
127 if args.input.shape != [self.desc.n, self.desc.c, self.desc.ih, self.desc.iw] {
128 return Err(Error::InvalidProblem(
129 "baracuda-kernels::GridSamplePlan: input shape mismatch",
130 ));
131 }
132 if args.grid.shape != [self.desc.n, self.desc.oh, self.desc.ow, 2] {
133 return Err(Error::InvalidProblem(
134 "baracuda-kernels::GridSamplePlan: grid shape must be [N, OH, OW, 2]",
135 ));
136 }
137 if args.output.shape != [self.desc.n, self.desc.c, self.desc.oh, self.desc.ow] {
138 return Err(Error::InvalidProblem(
139 "baracuda-kernels::GridSamplePlan: output shape mismatch",
140 ));
141 }
142 Ok(())
143 }
144
145 #[inline]
147 pub fn workspace_size(&self) -> usize {
148 0
149 }
150
151 #[inline]
153 pub fn sku(&self) -> KernelSku {
154 self.sku
155 }
156
157 #[inline]
159 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
160 self.sku.precision_guarantee
161 }
162
163 pub fn run(
165 &self,
166 stream: &Stream,
167 _workspace: Workspace<'_>,
168 args: GridSampleArgs<'_, T>,
169 ) -> Result<()> {
170 self.can_implement(&args)?;
171 if args.output.numel() == 0 {
172 return Ok(());
173 }
174 let input_ptr = args.input.data.as_raw().0 as *const c_void;
175 let grid_ptr = args.grid.data.as_raw().0 as *const c_void;
176 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
177 let stream_ptr = stream.as_raw() as *mut c_void;
178 let status = match T::KIND {
179 ElementKind::F32 => unsafe {
180 baracuda_kernels_sys::baracuda_kernels_grid_sample_2d_f32_run(
181 self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
182 self.desc.oh, self.desc.ow,
183 input_ptr, grid_ptr, out_ptr,
184 core::ptr::null_mut(), 0, stream_ptr,
185 )
186 },
187 ElementKind::F64 => unsafe {
188 baracuda_kernels_sys::baracuda_kernels_grid_sample_2d_f64_run(
189 self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
190 self.desc.oh, self.desc.ow,
191 input_ptr, grid_ptr, out_ptr,
192 core::ptr::null_mut(), 0, stream_ptr,
193 )
194 },
195 _ => {
196 return Err(Error::Unsupported(
197 "baracuda-kernels::GridSamplePlan::run reached unimplemented dtype",
198 ));
199 }
200 };
201 map_status(status)
202 }
203}