Skip to main content

baracuda_kernels/image/
roi_pool.rs

//! `roi_pool` forward (FW) plan — Category T.
//!
//! Max-pool variant of [`crate::image::RoiAlignPlan`]. Each output
//! cell is the maximum value over its (integer-rounded) RoI bin in the
//! input plane. The kernel also emits an `argmax` buffer (one i32
//! plane-relative linear index per output cell; `-1` for empty bins)
//! that the backward (BW) pass reads to route gradients.
//!
//! Trailblazer dtype coverage: `f32`, `f64`.
10
11use core::ffi::c_void;
12use core::marker::PhantomData;
13
14use baracuda_cutlass::{Error, Result};
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17    ArchSku, BackendKind, Element, ElementKind, ImageKind, KernelSku, MathPrecision, OpCategory,
18    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
19};
20
21use super::map_status;
22
23/// Descriptor for `roi_pool`.
24#[derive(Copy, Clone, Debug)]
25pub struct RoiPoolDescriptor {
26    /// Batch.
27    pub n: i32,
28    /// Channels.
29    pub c: i32,
30    /// Input height.
31    pub h: i32,
32    /// Input width.
33    pub w: i32,
34    /// Number of RoIs.
35    pub num_rois: i32,
36    /// Output pooled height per RoI.
37    pub pooled_h: i32,
38    /// Output pooled width per RoI.
39    pub pooled_w: i32,
40    /// RoI coord scale.
41    pub spatial_scale: f32,
42    /// Value element type.
43    pub element: ElementKind,
44}
45
46/// Args bundle for `roi_pool`.
47pub struct RoiPoolArgs<'a, T: Element> {
48    /// Input `[N, C, H, W]`.
49    pub input: TensorRef<'a, T, 4>,
50    /// RoIs `[num_rois, 5]`.
51    pub rois: TensorRef<'a, T, 2>,
52    /// Output `[num_rois, C, pooled_h, pooled_w]`.
53    pub output: TensorMut<'a, T, 4>,
54    /// Argmax (i32 plane-relative index per output cell, or -1 for
55    /// empty bin) `[num_rois, C, pooled_h, pooled_w]`. Required for BW.
56    pub argmax: TensorMut<'a, i32, 4>,
57}
58
59/// `roi_pool` plan.
60///
61/// Max-pool variant of [`crate::RoiAlignPlan`]: each output cell is
62/// the max over the integer-rounded RoI bin in the input plane
63/// (torchvision `roi_pool`). FW emits an `argmax` buffer the BW
64/// reads to route gradients.
65///
66/// **When to use**: forward RoIPool (legacy detection nets). For
67/// bilinear-sampled extraction prefer
68/// [`RoiAlignPlan`](crate::RoiAlignPlan). Pair with
69/// [`RoiPoolBackwardPlan`](crate::RoiPoolBackwardPlan).
70///
71/// **Dtypes**: `{f32, f64}`.
72///
73/// **Shape limits**: rank-4 NCHW input; RoIs `[num_rois, 5]`;
74/// outputs (`output` + `argmax`) `[num_rois, C, pooled_h, pooled_w]`.
75///
76/// **Workspace**: none, but the caller supplies the `argmax` tensor
77/// as part of [`RoiPoolArgs`] (BW requires it).
78///
79/// **Precision guarantee**: deterministic, bit-stable.
80pub struct RoiPoolPlan<T: Element> {
81    desc: RoiPoolDescriptor,
82    sku: KernelSku,
83    _marker: PhantomData<T>,
84}
85
86impl<T: Element> RoiPoolPlan<T> {
87    /// Pick a kernel.
88    pub fn select(
89        _stream: &Stream,
90        desc: &RoiPoolDescriptor,
91        _pref: PlanPreference,
92    ) -> Result<Self> {
93        if desc.element != T::KIND {
94            return Err(Error::Unsupported(
95                "baracuda-kernels::RoiPoolPlan: descriptor element != T",
96            ));
97        }
98        if desc.n < 0
99            || desc.c < 0
100            || desc.h < 0
101            || desc.w < 0
102            || desc.num_rois < 0
103            || desc.pooled_h < 0
104            || desc.pooled_w < 0
105        {
106            return Err(Error::InvalidProblem(
107                "baracuda-kernels::RoiPoolPlan: extents must be non-negative",
108            ));
109        }
110        if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
111            return Err(Error::Unsupported(
112                "baracuda-kernels::RoiPoolPlan: only `f32`, `f64` wired",
113            ));
114        }
115        let precision_guarantee = PrecisionGuarantee {
116            math_precision: if T::KIND == ElementKind::F64 {
117                MathPrecision::F64
118            } else {
119                MathPrecision::F32
120            },
121            accumulator: T::KIND,
122            bit_stable_on_same_hardware: true,
123            deterministic: true,
124        };
125        let sku = KernelSku {
126            category: OpCategory::Image,
127            op: ImageKind::RoiPool as u16,
128            element: T::KIND,
129            aux_element: Some(ElementKind::I32),
130            layout: None,
131            epilogue: None,
132            arch: ArchSku::Sm80,
133            backend: BackendKind::Bespoke,
134            precision_guarantee,
135        };
136        Ok(Self {
137            desc: *desc,
138            sku,
139            _marker: PhantomData,
140        })
141    }
142
143    /// Validate args.
144    pub fn can_implement(&self, args: &RoiPoolArgs<'_, T>) -> Result<()> {
145        if args.input.shape != [self.desc.n, self.desc.c, self.desc.h, self.desc.w] {
146            return Err(Error::InvalidProblem(
147                "baracuda-kernels::RoiPoolPlan: input shape mismatch",
148            ));
149        }
150        if args.rois.shape != [self.desc.num_rois, 5] {
151            return Err(Error::InvalidProblem(
152                "baracuda-kernels::RoiPoolPlan: rois must be [num_rois, 5]",
153            ));
154        }
155        let out_shape = [self.desc.num_rois, self.desc.c, self.desc.pooled_h, self.desc.pooled_w];
156        if args.output.shape != out_shape || args.argmax.shape != out_shape {
157            return Err(Error::InvalidProblem(
158                "baracuda-kernels::RoiPoolPlan: output / argmax shape mismatch",
159            ));
160        }
161        Ok(())
162    }
163
164    /// Workspace (zero).
165    #[inline]
166    pub fn workspace_size(&self) -> usize {
167        0
168    }
169
170    /// Identity.
171    #[inline]
172    pub fn sku(&self) -> KernelSku {
173        self.sku
174    }
175
176    /// Numerical guarantees.
177    #[inline]
178    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
179        self.sku.precision_guarantee
180    }
181
182    /// Launch.
183    pub fn run(
184        &self,
185        stream: &Stream,
186        _workspace: Workspace<'_>,
187        args: RoiPoolArgs<'_, T>,
188    ) -> Result<()> {
189        self.can_implement(&args)?;
190        if args.output.numel() == 0 {
191            return Ok(());
192        }
193        let input_ptr = args.input.data.as_raw().0 as *const c_void;
194        let rois_ptr = args.rois.data.as_raw().0 as *const c_void;
195        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
196        let arg_ptr = args.argmax.data.as_raw().0 as *mut c_void;
197        let stream_ptr = stream.as_raw() as *mut c_void;
198        let status = match T::KIND {
199            ElementKind::F32 => unsafe {
200                baracuda_kernels_sys::baracuda_kernels_roi_pool_f32_run(
201                    self.desc.n, self.desc.c, self.desc.h, self.desc.w,
202                    self.desc.num_rois, self.desc.pooled_h, self.desc.pooled_w,
203                    self.desc.spatial_scale,
204                    input_ptr, rois_ptr, out_ptr, arg_ptr,
205                    core::ptr::null_mut(), 0, stream_ptr,
206                )
207            },
208            ElementKind::F64 => unsafe {
209                baracuda_kernels_sys::baracuda_kernels_roi_pool_f64_run(
210                    self.desc.n, self.desc.c, self.desc.h, self.desc.w,
211                    self.desc.num_rois, self.desc.pooled_h, self.desc.pooled_w,
212                    self.desc.spatial_scale,
213                    input_ptr, rois_ptr, out_ptr, arg_ptr,
214                    core::ptr::null_mut(), 0, stream_ptr,
215                )
216            },
217            _ => {
218                return Err(Error::Unsupported(
219                    "baracuda-kernels::RoiPoolPlan::run reached unimplemented dtype",
220                ));
221            }
222        };
223        map_status(status)
224    }
225}