baracuda_kernels/image/interpolate.rs
1//! `interpolate` FW plan — Category T trailblazer.
2//!
3//! Spatial resample of an NCHW input via bilinear interpolation.
4//! Trailblazer mode: `Bilinear2d`. Other modes are reserved on
5//! [`InterpolateMode`] and return `Unsupported`.
6//!
7//! Output shape: `[N, C, OH, OW]` from input `[N, C, IH, IW]`. The
8//! coordinate mapping (per PyTorch ATen `UpSample.h`):
9//!
//! - `align_corners=false` (PyTorch new-code default):
//!   `step_h = scale_h.map_or(IH / OH, |s| 1 / s)`,
//!   `src_y = (oh + 0.5) * step_h - 0.5`
//! - `align_corners=true` (PyTorch `nn.Upsample(align_corners=True)`):
//!   `step_h = scale_h.map_or((IH - 1) / (OH - 1), |s| 1 / s)`,
//!   `src_y = oh * step_h`
//!
//! The width axis is analogous (`scale_w`, `IW`, `OW`, `ow`, `src_x`).
14//!
//! `scale_h` / `scale_w` (when `Some`) are interpreted as PyTorch-style
//! SCALE values (output_size / input_size); the kernel uses `1/scale`
//! as the source-coordinate step between adjacent output pixels.
18//!
19//! Out-of-range samples are clamped to the input boundary (matches
20//! PyTorch).
21//!
22//! Dtype coverage (Phase 21): `f32, f64, f16, bf16`. Half-precision
23//! paths cast at load, accumulate in `f32`, cast at store.
24//!
25//! # Phase 21 breaking change
26//!
27//! [`InterpolateDescriptor`] gained `align_corners`, `scale_h`, and
28//! `scale_w` fields. Pre-Phase-21 callers constructing the struct must
29//! supply the new fields. The underlying FFI also took on three new
30//! params (`align_corners: i32`, `scale_h_factor: f64`,
31//! `scale_w_factor: f64`) — see `baracuda-kernels-sys` rustdoc.
32
33use core::ffi::c_void;
34use core::marker::PhantomData;
35
36use baracuda_cutlass::{Error, Result};
37use baracuda_driver::Stream;
38use baracuda_kernels_types::{
39 ArchSku, BackendKind, Element, ElementKind, ImageKind, KernelSku, MathPrecision, OpCategory,
40 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
41};
42
43use super::map_status;
44
45/// Interpolation mode for [`InterpolatePlan`]. Only [`Self::Bilinear2d`]
46/// is wired today; the other variants return `Unsupported`.
47///
48/// `#[non_exhaustive]` — additional interpolation modes (cubic
49/// spline, lanczos, mitchell-netravali, …) may land in future
50/// vision-domain phases. Match arms must include a `_ =>` catch-all.
51#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
52#[non_exhaustive]
53pub enum InterpolateMode {
54 /// 2-D bilinear interpolation. Trailblazer.
55 Bilinear2d,
56 /// 2-D nearest-neighbor — reserved.
57 Nearest2d,
58 /// 2-D bicubic — reserved.
59 Bicubic2d,
60 /// 3-D trilinear — reserved.
61 Trilinear3d,
62 /// 1-D linear — reserved.
63 Linear1d,
64 /// 2-D area (adaptive average) — reserved.
65 Area2d,
66}
67
68/// Descriptor for an `interpolate` op.
69///
70/// `#[non_exhaustive]` (Phase 32) — Phase 21 added `align_corners` /
71/// `scale_h` / `scale_w`; future modes may add more fields. Use
72/// [`Self::new`] + the `with_*` setters from downstream code.
73#[derive(Copy, Clone, Debug)]
74#[non_exhaustive]
75pub struct InterpolateDescriptor {
76 /// Batch.
77 pub n: i32,
78 /// Channels.
79 pub c: i32,
80 /// Input height.
81 pub ih: i32,
82 /// Input width.
83 pub iw: i32,
84 /// Output height.
85 pub oh: i32,
86 /// Output width.
87 pub ow: i32,
88 /// Interpolation mode.
89 pub mode: InterpolateMode,
90 /// Value element type. Must match `T::KIND`.
91 pub element: ElementKind,
92 /// Coordinate alignment mode. `false` matches PyTorch
93 /// `F.interpolate` new-code default; `true` matches
94 /// `nn.Upsample(align_corners=True)`.
95 pub align_corners: bool,
96 /// Per-axis SCALE override for height (output_size / input_size).
97 /// `None` derives the scale from `(ih, oh)`; `Some(s)` overrides
98 /// and the kernel uses `1.0 / s` per output coordinate. Matches
99 /// PyTorch's `scale_factor` semantics.
100 pub scale_h: Option<f64>,
101 /// Per-axis SCALE override for width (output_size / input_size).
102 /// `None` derives the scale from `(iw, ow)`; `Some(s)` overrides
103 /// and the kernel uses `1.0 / s` per output coordinate. Matches
104 /// PyTorch's `scale_factor` semantics.
105 pub scale_w: Option<f64>,
106}
107
108impl InterpolateDescriptor {
109 /// Build a descriptor with `align_corners = false` (PyTorch
110 /// `F.interpolate` new-code default) and `scale_h / scale_w = None`
111 /// (derive scale from `(ih, oh)` / `(iw, ow)`). Chain with the
112 /// `with_*` setters to override.
113 #[allow(clippy::too_many_arguments)]
114 pub fn new(
115 n: i32,
116 c: i32,
117 ih: i32,
118 iw: i32,
119 oh: i32,
120 ow: i32,
121 mode: InterpolateMode,
122 element: ElementKind,
123 ) -> Self {
124 Self {
125 n,
126 c,
127 ih,
128 iw,
129 oh,
130 ow,
131 mode,
132 element,
133 align_corners: false,
134 scale_h: None,
135 scale_w: None,
136 }
137 }
138
139 /// Override `align_corners`. Default `false`.
140 #[inline]
141 pub fn with_align_corners(mut self, align_corners: bool) -> Self {
142 self.align_corners = align_corners;
143 self
144 }
145
146 /// Override the per-axis SCALE for height. Default `None` (derive
147 /// from `(ih, oh)`).
148 #[inline]
149 pub fn with_scale_h(mut self, scale_h: Option<f64>) -> Self {
150 self.scale_h = scale_h;
151 self
152 }
153
154 /// Override the per-axis SCALE for width. Default `None` (derive
155 /// from `(iw, ow)`).
156 #[inline]
157 pub fn with_scale_w(mut self, scale_w: Option<f64>) -> Self {
158 self.scale_w = scale_w;
159 self
160 }
161}
162
163/// Args bundle for an `interpolate` launch.
164pub struct InterpolateArgs<'a, T: Element> {
165 /// Input `[N, C, IH, IW]`. NCHW row-major contiguous.
166 pub input: TensorRef<'a, T, 4>,
167 /// Output `[N, C, OH, OW]`. NCHW row-major contiguous.
168 pub output: TensorMut<'a, T, 4>,
169}
170
171/// `interpolate` plan.
172///
173/// Spatial resample of an NCHW input. PyTorch `F.interpolate`.
174/// Coordinate mapping: `src = (dst + 0.5) * (src_size / dst_size) - 0.5`
175/// (`align_corners=false`); corner samples clamp to the input
176/// boundary.
177///
178/// **When to use**: forward 2-D bilinear resample. Pair with
179/// [`InterpolateBackwardPlan`](crate::InterpolateBackwardPlan) for
180/// autograd.
181///
182/// **Dtypes**: `{f32, f64, f16, bf16}`.
183///
184/// **Shape limits**: rank-4 NCHW input `[N, C, IH, IW]`; output
185/// `[N, C, OH, OW]`; all extents non-negative.
186///
187/// **Modes**: only `Bilinear2d` is wired in the trailblazer.
188/// `Nearest2d` / `Bicubic2d` / `Trilinear3d` / `Linear1d` / `Area2d`
189/// are reserved on the enum and return `Unsupported`.
190///
191/// **Workspace**: none.
192///
193/// **Precision guarantee**: deterministic, bit-stable on identical
194/// hardware. No atomics on FW.
195pub struct InterpolatePlan<T: Element> {
196 desc: InterpolateDescriptor,
197 sku: KernelSku,
198 _marker: PhantomData<T>,
199}
200
201impl<T: Element> InterpolatePlan<T> {
202 /// Pick a kernel for `desc`.
203 pub fn select(
204 _stream: &Stream,
205 desc: &InterpolateDescriptor,
206 _pref: PlanPreference,
207 ) -> Result<Self> {
208 if desc.element != T::KIND {
209 return Err(Error::Unsupported(
210 "baracuda-kernels::InterpolatePlan: descriptor element != type parameter T",
211 ));
212 }
213 if !matches!(desc.mode, InterpolateMode::Bilinear2d) {
214 return Err(Error::Unsupported(
215 "baracuda-kernels::InterpolatePlan: only Bilinear2d wired in trailblazer",
216 ));
217 }
218 if desc.n < 0 || desc.c < 0 || desc.ih < 0 || desc.iw < 0 || desc.oh < 0 || desc.ow < 0 {
219 return Err(Error::InvalidProblem(
220 "baracuda-kernels::InterpolatePlan: all extents must be non-negative",
221 ));
222 }
223 if !matches!(
224 T::KIND,
225 ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
226 ) {
227 return Err(Error::Unsupported(
228 "baracuda-kernels::InterpolatePlan: only `f32`, `f64`, `f16`, `bf16` wired",
229 ));
230 }
231 // Validate scale factors (positive, finite) when present.
232 if let Some(s) = desc.scale_h {
233 if !s.is_finite() || s <= 0.0 {
234 return Err(Error::InvalidProblem(
235 "baracuda-kernels::InterpolatePlan: scale_h must be positive and finite",
236 ));
237 }
238 }
239 if let Some(s) = desc.scale_w {
240 if !s.is_finite() || s <= 0.0 {
241 return Err(Error::InvalidProblem(
242 "baracuda-kernels::InterpolatePlan: scale_w must be positive and finite",
243 ));
244 }
245 }
246 let precision_guarantee = PrecisionGuarantee {
247 math_precision: if T::KIND == ElementKind::F64 {
248 MathPrecision::F64
249 } else {
250 MathPrecision::F32
251 },
252 // Half-precision paths accumulate in f32, then cast.
253 accumulator: if matches!(T::KIND, ElementKind::F16 | ElementKind::Bf16) {
254 ElementKind::F32
255 } else {
256 T::KIND
257 },
258 bit_stable_on_same_hardware: true,
259 deterministic: true,
260 };
261 let sku = KernelSku {
262 category: OpCategory::Image,
263 op: ImageKind::InterpolateBilinear2d as u16,
264 element: T::KIND,
265 aux_element: None,
266 layout: None,
267 epilogue: None,
268 arch: ArchSku::Sm80,
269 backend: BackendKind::Bespoke,
270 precision_guarantee,
271 };
272 Ok(Self {
273 desc: *desc,
274 sku,
275 _marker: PhantomData,
276 })
277 }
278
279 /// Validate args.
280 pub fn can_implement(&self, args: &InterpolateArgs<'_, T>) -> Result<()> {
281 if args.input.shape != [self.desc.n, self.desc.c, self.desc.ih, self.desc.iw] {
282 return Err(Error::InvalidProblem(
283 "baracuda-kernels::InterpolatePlan: input shape mismatch",
284 ));
285 }
286 if args.output.shape != [self.desc.n, self.desc.c, self.desc.oh, self.desc.ow] {
287 return Err(Error::InvalidProblem(
288 "baracuda-kernels::InterpolatePlan: output shape mismatch",
289 ));
290 }
291 let in_numel = args.input.numel();
292 let out_numel = args.output.numel();
293 if (args.input.data.len() as i64) < in_numel {
294 return Err(Error::BufferTooSmall {
295 needed: in_numel as usize,
296 got: args.input.data.len(),
297 });
298 }
299 if (args.output.data.len() as i64) < out_numel {
300 return Err(Error::BufferTooSmall {
301 needed: out_numel as usize,
302 got: args.output.data.len(),
303 });
304 }
305 Ok(())
306 }
307
308 /// Workspace size (zero).
309 #[inline]
310 pub fn workspace_size(&self) -> usize {
311 0
312 }
313
314 /// Identity of the kernel this plan picked.
315 #[inline]
316 pub fn sku(&self) -> KernelSku {
317 self.sku
318 }
319
320 /// Numerical guarantees for this plan's kernel.
321 #[inline]
322 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
323 self.sku.precision_guarantee
324 }
325
326 /// Launch.
327 pub fn run(
328 &self,
329 stream: &Stream,
330 _workspace: Workspace<'_>,
331 args: InterpolateArgs<'_, T>,
332 ) -> Result<()> {
333 self.can_implement(&args)?;
334 if args.output.numel() == 0 {
335 return Ok(());
336 }
337 let input_ptr = args.input.data.as_raw().0 as *const c_void;
338 let output_ptr = args.output.data.as_raw().0 as *mut c_void;
339 let stream_ptr = stream.as_raw() as *mut c_void;
340 let ac: i32 = if self.desc.align_corners { 1 } else { 0 };
341 // Sentinel: 0.0 = "derive from sizes" on the C side.
342 let sh: f64 = self.desc.scale_h.unwrap_or(0.0);
343 let sw: f64 = self.desc.scale_w.unwrap_or(0.0);
344 let status = match T::KIND {
345 ElementKind::F32 => unsafe {
346 baracuda_kernels_sys::baracuda_kernels_interpolate_bilinear_2d_f32_run(
347 self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
348 self.desc.oh, self.desc.ow,
349 input_ptr, output_ptr,
350 core::ptr::null_mut(), 0,
351 ac, sh, sw,
352 stream_ptr,
353 )
354 },
355 ElementKind::F64 => unsafe {
356 baracuda_kernels_sys::baracuda_kernels_interpolate_bilinear_2d_f64_run(
357 self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
358 self.desc.oh, self.desc.ow,
359 input_ptr, output_ptr,
360 core::ptr::null_mut(), 0,
361 ac, sh, sw,
362 stream_ptr,
363 )
364 },
365 ElementKind::F16 => unsafe {
366 baracuda_kernels_sys::baracuda_kernels_interpolate_bilinear_2d_f16_run(
367 self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
368 self.desc.oh, self.desc.ow,
369 input_ptr, output_ptr,
370 core::ptr::null_mut(), 0,
371 ac, sh, sw,
372 stream_ptr,
373 )
374 },
375 ElementKind::Bf16 => unsafe {
376 baracuda_kernels_sys::baracuda_kernels_interpolate_bilinear_2d_bf16_run(
377 self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
378 self.desc.oh, self.desc.ow,
379 input_ptr, output_ptr,
380 core::ptr::null_mut(), 0,
381 ac, sh, sw,
382 stream_ptr,
383 )
384 },
385 _ => {
386 return Err(Error::Unsupported(
387 "baracuda-kernels::InterpolatePlan::run reached an unimplemented dtype",
388 ));
389 }
390 };
391 map_status(status)
392 }
393}