1use core::ffi::c_void;
12use core::marker::PhantomData;
13
14use baracuda_cutlass::{Error, Result};
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17 ArchSku, BackendKind, Element, ElementKind, ImageKind, KernelSku, MathPrecision, OpCategory,
18 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
19};
20
21use super::map_status;
22
23#[derive(Copy, Clone, Debug)]
25pub struct RoiPoolDescriptor {
26 pub n: i32,
28 pub c: i32,
30 pub h: i32,
32 pub w: i32,
34 pub num_rois: i32,
36 pub pooled_h: i32,
38 pub pooled_w: i32,
40 pub spatial_scale: f32,
42 pub element: ElementKind,
44}
45
46pub struct RoiPoolArgs<'a, T: Element> {
48 pub input: TensorRef<'a, T, 4>,
50 pub rois: TensorRef<'a, T, 2>,
52 pub output: TensorMut<'a, T, 4>,
54 pub argmax: TensorMut<'a, i32, 4>,
57}
58
59pub struct RoiPoolPlan<T: Element> {
81 desc: RoiPoolDescriptor,
82 sku: KernelSku,
83 _marker: PhantomData<T>,
84}
85
86impl<T: Element> RoiPoolPlan<T> {
87 pub fn select(
89 _stream: &Stream,
90 desc: &RoiPoolDescriptor,
91 _pref: PlanPreference,
92 ) -> Result<Self> {
93 if desc.element != T::KIND {
94 return Err(Error::Unsupported(
95 "baracuda-kernels::RoiPoolPlan: descriptor element != T",
96 ));
97 }
98 if desc.n < 0
99 || desc.c < 0
100 || desc.h < 0
101 || desc.w < 0
102 || desc.num_rois < 0
103 || desc.pooled_h < 0
104 || desc.pooled_w < 0
105 {
106 return Err(Error::InvalidProblem(
107 "baracuda-kernels::RoiPoolPlan: extents must be non-negative",
108 ));
109 }
110 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
111 return Err(Error::Unsupported(
112 "baracuda-kernels::RoiPoolPlan: only `f32`, `f64` wired",
113 ));
114 }
115 let precision_guarantee = PrecisionGuarantee {
116 math_precision: if T::KIND == ElementKind::F64 {
117 MathPrecision::F64
118 } else {
119 MathPrecision::F32
120 },
121 accumulator: T::KIND,
122 bit_stable_on_same_hardware: true,
123 deterministic: true,
124 };
125 let sku = KernelSku {
126 category: OpCategory::Image,
127 op: ImageKind::RoiPool as u16,
128 element: T::KIND,
129 aux_element: Some(ElementKind::I32),
130 layout: None,
131 epilogue: None,
132 arch: ArchSku::Sm80,
133 backend: BackendKind::Bespoke,
134 precision_guarantee,
135 };
136 Ok(Self {
137 desc: *desc,
138 sku,
139 _marker: PhantomData,
140 })
141 }
142
143 pub fn can_implement(&self, args: &RoiPoolArgs<'_, T>) -> Result<()> {
145 if args.input.shape != [self.desc.n, self.desc.c, self.desc.h, self.desc.w] {
146 return Err(Error::InvalidProblem(
147 "baracuda-kernels::RoiPoolPlan: input shape mismatch",
148 ));
149 }
150 if args.rois.shape != [self.desc.num_rois, 5] {
151 return Err(Error::InvalidProblem(
152 "baracuda-kernels::RoiPoolPlan: rois must be [num_rois, 5]",
153 ));
154 }
155 let out_shape = [self.desc.num_rois, self.desc.c, self.desc.pooled_h, self.desc.pooled_w];
156 if args.output.shape != out_shape || args.argmax.shape != out_shape {
157 return Err(Error::InvalidProblem(
158 "baracuda-kernels::RoiPoolPlan: output / argmax shape mismatch",
159 ));
160 }
161 Ok(())
162 }
163
164 #[inline]
166 pub fn workspace_size(&self) -> usize {
167 0
168 }
169
170 #[inline]
172 pub fn sku(&self) -> KernelSku {
173 self.sku
174 }
175
176 #[inline]
178 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
179 self.sku.precision_guarantee
180 }
181
182 pub fn run(
184 &self,
185 stream: &Stream,
186 _workspace: Workspace<'_>,
187 args: RoiPoolArgs<'_, T>,
188 ) -> Result<()> {
189 self.can_implement(&args)?;
190 if args.output.numel() == 0 {
191 return Ok(());
192 }
193 let input_ptr = args.input.data.as_raw().0 as *const c_void;
194 let rois_ptr = args.rois.data.as_raw().0 as *const c_void;
195 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
196 let arg_ptr = args.argmax.data.as_raw().0 as *mut c_void;
197 let stream_ptr = stream.as_raw() as *mut c_void;
198 let status = match T::KIND {
199 ElementKind::F32 => unsafe {
200 baracuda_kernels_sys::baracuda_kernels_roi_pool_f32_run(
201 self.desc.n, self.desc.c, self.desc.h, self.desc.w,
202 self.desc.num_rois, self.desc.pooled_h, self.desc.pooled_w,
203 self.desc.spatial_scale,
204 input_ptr, rois_ptr, out_ptr, arg_ptr,
205 core::ptr::null_mut(), 0, stream_ptr,
206 )
207 },
208 ElementKind::F64 => unsafe {
209 baracuda_kernels_sys::baracuda_kernels_roi_pool_f64_run(
210 self.desc.n, self.desc.c, self.desc.h, self.desc.w,
211 self.desc.num_rois, self.desc.pooled_h, self.desc.pooled_w,
212 self.desc.spatial_scale,
213 input_ptr, rois_ptr, out_ptr, arg_ptr,
214 core::ptr::null_mut(), 0, stream_ptr,
215 )
216 },
217 _ => {
218 return Err(Error::Unsupported(
219 "baracuda-kernels::RoiPoolPlan::run reached unimplemented dtype",
220 ));
221 }
222 };
223 map_status(status)
224 }
225}