Skip to main content

baracuda_kernels/image/
interpolate.rs

1//! `interpolate` FW plan — Category T trailblazer.
2//!
3//! Spatial resample of an NCHW input via bilinear interpolation.
4//! Trailblazer mode: `Bilinear2d`. Other modes are reserved on
5//! [`InterpolateMode`] and return `Unsupported`.
6//!
7//! Output shape: `[N, C, OH, OW]` from input `[N, C, IH, IW]`. The
8//! coordinate mapping (per PyTorch ATen `UpSample.h`):
9//!
10//! - `align_corners=false` (PyTorch new-code default):
//!   `ratio_h = scale_h.map(|s| 1.0 / s).unwrap_or(IH / OH)`,
//!   `src_y = (oh + 0.5) * ratio_h - 0.5`
//! - `align_corners=true` (PyTorch `nn.Upsample(align_corners=True)`):
//!   `ratio_h = scale_h.map(|s| 1.0 / s).unwrap_or((IH - 1) / (OH - 1))`,
//!   `src_y = oh * ratio_h`
//!
//! The width axis is analogous (`scale_w`, `IW`, `OW`, `ratio_w`).
//!
//! `scale_h` / `scale_w` (when `Some`) are interpreted as PyTorch-style
//! SCALE values (output_size / input_size); the kernel uses `1/scale` as
//! the source-pixels-per-output-pixel step (`ratio_h` / `ratio_w` above).
18//!
19//! Out-of-range samples are clamped to the input boundary (matches
20//! PyTorch).
21//!
22//! Dtype coverage (Phase 21): `f32, f64, f16, bf16`. Half-precision
23//! paths cast at load, accumulate in `f32`, cast at store.
24//!
25//! # Phase 21 breaking change
26//!
27//! [`InterpolateDescriptor`] gained `align_corners`, `scale_h`, and
28//! `scale_w` fields. Pre-Phase-21 callers constructing the struct must
29//! supply the new fields. The underlying FFI also took on three new
30//! params (`align_corners: i32`, `scale_h_factor: f64`,
31//! `scale_w_factor: f64`) — see `baracuda-kernels-sys` rustdoc.
32
33use core::ffi::c_void;
34use core::marker::PhantomData;
35
36use baracuda_cutlass::{Error, Result};
37use baracuda_driver::Stream;
38use baracuda_kernels_types::{
39    ArchSku, BackendKind, Element, ElementKind, ImageKind, KernelSku, MathPrecision, OpCategory,
40    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
41};
42
43use super::map_status;
44
45/// Interpolation mode for [`InterpolatePlan`]. Only [`Self::Bilinear2d`]
46/// is wired today; the other variants return `Unsupported`.
47///
48/// `#[non_exhaustive]` — additional interpolation modes (cubic
49/// spline, lanczos, mitchell-netravali, …) may land in future
50/// vision-domain phases. Match arms must include a `_ =>` catch-all.
51#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
52#[non_exhaustive]
53pub enum InterpolateMode {
54    /// 2-D bilinear interpolation. Trailblazer.
55    Bilinear2d,
56    /// 2-D nearest-neighbor — reserved.
57    Nearest2d,
58    /// 2-D bicubic — reserved.
59    Bicubic2d,
60    /// 3-D trilinear — reserved.
61    Trilinear3d,
62    /// 1-D linear — reserved.
63    Linear1d,
64    /// 2-D area (adaptive average) — reserved.
65    Area2d,
66}
67
68/// Descriptor for an `interpolate` op.
69///
70/// `#[non_exhaustive]` (Phase 32) — Phase 21 added `align_corners` /
71/// `scale_h` / `scale_w`; future modes may add more fields. Use
72/// [`Self::new`] + the `with_*` setters from downstream code.
73#[derive(Copy, Clone, Debug)]
74#[non_exhaustive]
75pub struct InterpolateDescriptor {
76    /// Batch.
77    pub n: i32,
78    /// Channels.
79    pub c: i32,
80    /// Input height.
81    pub ih: i32,
82    /// Input width.
83    pub iw: i32,
84    /// Output height.
85    pub oh: i32,
86    /// Output width.
87    pub ow: i32,
88    /// Interpolation mode.
89    pub mode: InterpolateMode,
90    /// Value element type. Must match `T::KIND`.
91    pub element: ElementKind,
92    /// Coordinate alignment mode. `false` matches PyTorch
93    /// `F.interpolate` new-code default; `true` matches
94    /// `nn.Upsample(align_corners=True)`.
95    pub align_corners: bool,
96    /// Per-axis SCALE override for height (output_size / input_size).
97    /// `None` derives the scale from `(ih, oh)`; `Some(s)` overrides
98    /// and the kernel uses `1.0 / s` per output coordinate. Matches
99    /// PyTorch's `scale_factor` semantics.
100    pub scale_h: Option<f64>,
101    /// Per-axis SCALE override for width (output_size / input_size).
102    /// `None` derives the scale from `(iw, ow)`; `Some(s)` overrides
103    /// and the kernel uses `1.0 / s` per output coordinate. Matches
104    /// PyTorch's `scale_factor` semantics.
105    pub scale_w: Option<f64>,
106}
107
108impl InterpolateDescriptor {
109    /// Build a descriptor with `align_corners = false` (PyTorch
110    /// `F.interpolate` new-code default) and `scale_h / scale_w = None`
111    /// (derive scale from `(ih, oh)` / `(iw, ow)`). Chain with the
112    /// `with_*` setters to override.
113    #[allow(clippy::too_many_arguments)]
114    pub fn new(
115        n: i32,
116        c: i32,
117        ih: i32,
118        iw: i32,
119        oh: i32,
120        ow: i32,
121        mode: InterpolateMode,
122        element: ElementKind,
123    ) -> Self {
124        Self {
125            n,
126            c,
127            ih,
128            iw,
129            oh,
130            ow,
131            mode,
132            element,
133            align_corners: false,
134            scale_h: None,
135            scale_w: None,
136        }
137    }
138
139    /// Override `align_corners`. Default `false`.
140    #[inline]
141    pub fn with_align_corners(mut self, align_corners: bool) -> Self {
142        self.align_corners = align_corners;
143        self
144    }
145
146    /// Override the per-axis SCALE for height. Default `None` (derive
147    /// from `(ih, oh)`).
148    #[inline]
149    pub fn with_scale_h(mut self, scale_h: Option<f64>) -> Self {
150        self.scale_h = scale_h;
151        self
152    }
153
154    /// Override the per-axis SCALE for width. Default `None` (derive
155    /// from `(iw, ow)`).
156    #[inline]
157    pub fn with_scale_w(mut self, scale_w: Option<f64>) -> Self {
158        self.scale_w = scale_w;
159        self
160    }
161}
162
163/// Args bundle for an `interpolate` launch.
164pub struct InterpolateArgs<'a, T: Element> {
165    /// Input `[N, C, IH, IW]`. NCHW row-major contiguous.
166    pub input: TensorRef<'a, T, 4>,
167    /// Output `[N, C, OH, OW]`. NCHW row-major contiguous.
168    pub output: TensorMut<'a, T, 4>,
169}
170
171/// `interpolate` plan.
172///
173/// Spatial resample of an NCHW input. PyTorch `F.interpolate`.
174/// Coordinate mapping: `src = (dst + 0.5) * (src_size / dst_size) - 0.5`
175/// (`align_corners=false`); corner samples clamp to the input
176/// boundary.
177///
178/// **When to use**: forward 2-D bilinear resample. Pair with
179/// [`InterpolateBackwardPlan`](crate::InterpolateBackwardPlan) for
180/// autograd.
181///
182/// **Dtypes**: `{f32, f64, f16, bf16}`.
183///
184/// **Shape limits**: rank-4 NCHW input `[N, C, IH, IW]`; output
185/// `[N, C, OH, OW]`; all extents non-negative.
186///
187/// **Modes**: only `Bilinear2d` is wired in the trailblazer.
188/// `Nearest2d` / `Bicubic2d` / `Trilinear3d` / `Linear1d` / `Area2d`
189/// are reserved on the enum and return `Unsupported`.
190///
191/// **Workspace**: none.
192///
193/// **Precision guarantee**: deterministic, bit-stable on identical
194/// hardware. No atomics on FW.
195pub struct InterpolatePlan<T: Element> {
196    desc: InterpolateDescriptor,
197    sku: KernelSku,
198    _marker: PhantomData<T>,
199}
200
201impl<T: Element> InterpolatePlan<T> {
202    /// Pick a kernel for `desc`.
203    pub fn select(
204        _stream: &Stream,
205        desc: &InterpolateDescriptor,
206        _pref: PlanPreference,
207    ) -> Result<Self> {
208        if desc.element != T::KIND {
209            return Err(Error::Unsupported(
210                "baracuda-kernels::InterpolatePlan: descriptor element != type parameter T",
211            ));
212        }
213        if !matches!(desc.mode, InterpolateMode::Bilinear2d) {
214            return Err(Error::Unsupported(
215                "baracuda-kernels::InterpolatePlan: only Bilinear2d wired in trailblazer",
216            ));
217        }
218        if desc.n < 0 || desc.c < 0 || desc.ih < 0 || desc.iw < 0 || desc.oh < 0 || desc.ow < 0 {
219            return Err(Error::InvalidProblem(
220                "baracuda-kernels::InterpolatePlan: all extents must be non-negative",
221            ));
222        }
223        if !matches!(
224            T::KIND,
225            ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
226        ) {
227            return Err(Error::Unsupported(
228                "baracuda-kernels::InterpolatePlan: only `f32`, `f64`, `f16`, `bf16` wired",
229            ));
230        }
231        // Validate scale factors (positive, finite) when present.
232        if let Some(s) = desc.scale_h {
233            if !s.is_finite() || s <= 0.0 {
234                return Err(Error::InvalidProblem(
235                    "baracuda-kernels::InterpolatePlan: scale_h must be positive and finite",
236                ));
237            }
238        }
239        if let Some(s) = desc.scale_w {
240            if !s.is_finite() || s <= 0.0 {
241                return Err(Error::InvalidProblem(
242                    "baracuda-kernels::InterpolatePlan: scale_w must be positive and finite",
243                ));
244            }
245        }
246        let precision_guarantee = PrecisionGuarantee {
247            math_precision: if T::KIND == ElementKind::F64 {
248                MathPrecision::F64
249            } else {
250                MathPrecision::F32
251            },
252            // Half-precision paths accumulate in f32, then cast.
253            accumulator: if matches!(T::KIND, ElementKind::F16 | ElementKind::Bf16) {
254                ElementKind::F32
255            } else {
256                T::KIND
257            },
258            bit_stable_on_same_hardware: true,
259            deterministic: true,
260        };
261        let sku = KernelSku {
262            category: OpCategory::Image,
263            op: ImageKind::InterpolateBilinear2d as u16,
264            element: T::KIND,
265            aux_element: None,
266            layout: None,
267            epilogue: None,
268            arch: ArchSku::Sm80,
269            backend: BackendKind::Bespoke,
270            precision_guarantee,
271        };
272        Ok(Self {
273            desc: *desc,
274            sku,
275            _marker: PhantomData,
276        })
277    }
278
279    /// Validate args.
280    pub fn can_implement(&self, args: &InterpolateArgs<'_, T>) -> Result<()> {
281        if args.input.shape != [self.desc.n, self.desc.c, self.desc.ih, self.desc.iw] {
282            return Err(Error::InvalidProblem(
283                "baracuda-kernels::InterpolatePlan: input shape mismatch",
284            ));
285        }
286        if args.output.shape != [self.desc.n, self.desc.c, self.desc.oh, self.desc.ow] {
287            return Err(Error::InvalidProblem(
288                "baracuda-kernels::InterpolatePlan: output shape mismatch",
289            ));
290        }
291        let in_numel = args.input.numel();
292        let out_numel = args.output.numel();
293        if (args.input.data.len() as i64) < in_numel {
294            return Err(Error::BufferTooSmall {
295                needed: in_numel as usize,
296                got: args.input.data.len(),
297            });
298        }
299        if (args.output.data.len() as i64) < out_numel {
300            return Err(Error::BufferTooSmall {
301                needed: out_numel as usize,
302                got: args.output.data.len(),
303            });
304        }
305        Ok(())
306    }
307
308    /// Workspace size (zero).
309    #[inline]
310    pub fn workspace_size(&self) -> usize {
311        0
312    }
313
314    /// Identity of the kernel this plan picked.
315    #[inline]
316    pub fn sku(&self) -> KernelSku {
317        self.sku
318    }
319
320    /// Numerical guarantees for this plan's kernel.
321    #[inline]
322    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
323        self.sku.precision_guarantee
324    }
325
326    /// Launch.
327    pub fn run(
328        &self,
329        stream: &Stream,
330        _workspace: Workspace<'_>,
331        args: InterpolateArgs<'_, T>,
332    ) -> Result<()> {
333        self.can_implement(&args)?;
334        if args.output.numel() == 0 {
335            return Ok(());
336        }
337        let input_ptr = args.input.data.as_raw().0 as *const c_void;
338        let output_ptr = args.output.data.as_raw().0 as *mut c_void;
339        let stream_ptr = stream.as_raw() as *mut c_void;
340        let ac: i32 = if self.desc.align_corners { 1 } else { 0 };
341        // Sentinel: 0.0 = "derive from sizes" on the C side.
342        let sh: f64 = self.desc.scale_h.unwrap_or(0.0);
343        let sw: f64 = self.desc.scale_w.unwrap_or(0.0);
344        let status = match T::KIND {
345            ElementKind::F32 => unsafe {
346                baracuda_kernels_sys::baracuda_kernels_interpolate_bilinear_2d_f32_run(
347                    self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
348                    self.desc.oh, self.desc.ow,
349                    input_ptr, output_ptr,
350                    core::ptr::null_mut(), 0,
351                    ac, sh, sw,
352                    stream_ptr,
353                )
354            },
355            ElementKind::F64 => unsafe {
356                baracuda_kernels_sys::baracuda_kernels_interpolate_bilinear_2d_f64_run(
357                    self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
358                    self.desc.oh, self.desc.ow,
359                    input_ptr, output_ptr,
360                    core::ptr::null_mut(), 0,
361                    ac, sh, sw,
362                    stream_ptr,
363                )
364            },
365            ElementKind::F16 => unsafe {
366                baracuda_kernels_sys::baracuda_kernels_interpolate_bilinear_2d_f16_run(
367                    self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
368                    self.desc.oh, self.desc.ow,
369                    input_ptr, output_ptr,
370                    core::ptr::null_mut(), 0,
371                    ac, sh, sw,
372                    stream_ptr,
373                )
374            },
375            ElementKind::Bf16 => unsafe {
376                baracuda_kernels_sys::baracuda_kernels_interpolate_bilinear_2d_bf16_run(
377                    self.desc.n, self.desc.c, self.desc.ih, self.desc.iw,
378                    self.desc.oh, self.desc.ow,
379                    input_ptr, output_ptr,
380                    core::ptr::null_mut(), 0,
381                    ac, sh, sw,
382                    stream_ptr,
383                )
384            },
385            _ => {
386                return Err(Error::Unsupported(
387                    "baracuda-kernels::InterpolatePlan::run reached an unimplemented dtype",
388                ));
389            }
390        };
391        map_status(status)
392    }
393}