1use core::ffi::c_void;
13use core::marker::PhantomData;
14
15use baracuda_cutlass::{Error, Result};
16use baracuda_driver::Stream;
17use baracuda_kernels_types::{
18 ArchSku, BackendKind, Element, ElementKind, ImageKind, KernelSku, MathPrecision, OpCategory,
19 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
20};
21
22use super::map_status;
23
24#[derive(Copy, Clone, Debug)]
26pub struct NmsDescriptor {
27 pub num_boxes: i32,
29 pub iou_threshold: f32,
32 pub element: ElementKind,
34}
35
36pub struct NmsArgs<'a, T: Element> {
38 pub boxes: TensorRef<'a, T, 2>,
41 pub keep_mask: TensorMut<'a, u8, 1>,
43 pub count: TensorMut<'a, i32, 1>,
45}
46
47pub struct NmsPlan<T: Element> {
66 desc: NmsDescriptor,
67 sku: KernelSku,
68 _marker: PhantomData<T>,
69}
70
71impl<T: Element> NmsPlan<T> {
72 pub fn select(
74 _stream: &Stream,
75 desc: &NmsDescriptor,
76 _pref: PlanPreference,
77 ) -> Result<Self> {
78 if desc.element != T::KIND {
79 return Err(Error::Unsupported(
80 "baracuda-kernels::NmsPlan: descriptor element != T",
81 ));
82 }
83 if desc.num_boxes < 0 {
84 return Err(Error::InvalidProblem(
85 "baracuda-kernels::NmsPlan: num_boxes must be non-negative",
86 ));
87 }
88 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
89 return Err(Error::Unsupported(
90 "baracuda-kernels::NmsPlan: only `f32`, `f64` wired",
91 ));
92 }
93 let precision_guarantee = PrecisionGuarantee {
94 math_precision: if T::KIND == ElementKind::F64 {
95 MathPrecision::F64
96 } else {
97 MathPrecision::F32
98 },
99 accumulator: T::KIND,
100 bit_stable_on_same_hardware: true,
101 deterministic: true,
102 };
103 let sku = KernelSku {
104 category: OpCategory::Image,
105 op: ImageKind::Nms as u16,
106 element: T::KIND,
107 aux_element: Some(ElementKind::U8),
108 layout: None,
109 epilogue: None,
110 arch: ArchSku::Sm80,
111 backend: BackendKind::Bespoke,
112 precision_guarantee,
113 };
114 Ok(Self {
115 desc: *desc,
116 sku,
117 _marker: PhantomData,
118 })
119 }
120
121 pub fn can_implement(&self, args: &NmsArgs<'_, T>) -> Result<()> {
123 if args.boxes.shape != [self.desc.num_boxes, 4] {
124 return Err(Error::InvalidProblem(
125 "baracuda-kernels::NmsPlan: boxes must be [num_boxes, 4]",
126 ));
127 }
128 if args.keep_mask.shape != [self.desc.num_boxes] {
129 return Err(Error::InvalidProblem(
130 "baracuda-kernels::NmsPlan: keep_mask must be [num_boxes]",
131 ));
132 }
133 if args.count.shape != [1] {
134 return Err(Error::InvalidProblem(
135 "baracuda-kernels::NmsPlan: count must be [1]",
136 ));
137 }
138 Ok(())
139 }
140
141 #[inline]
144 pub fn workspace_size(&self) -> usize {
145 0
146 }
147
148 #[inline]
150 pub fn sku(&self) -> KernelSku {
151 self.sku
152 }
153
154 #[inline]
156 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
157 self.sku.precision_guarantee
158 }
159
160 pub fn run(
162 &self,
163 stream: &Stream,
164 _workspace: Workspace<'_>,
165 args: NmsArgs<'_, T>,
166 ) -> Result<()> {
167 self.can_implement(&args)?;
168 let boxes_ptr = args.boxes.data.as_raw().0 as *const c_void;
169 let mask_ptr = args.keep_mask.data.as_raw().0 as *mut c_void;
170 let count_ptr = args.count.data.as_raw().0 as *mut c_void;
171 let stream_ptr = stream.as_raw() as *mut c_void;
172 let status = match T::KIND {
173 ElementKind::F32 => unsafe {
174 baracuda_kernels_sys::baracuda_kernels_nms_f32_run(
175 self.desc.num_boxes,
176 self.desc.iou_threshold,
177 boxes_ptr, mask_ptr, count_ptr,
178 core::ptr::null_mut(), 0, stream_ptr,
179 )
180 },
181 ElementKind::F64 => unsafe {
182 baracuda_kernels_sys::baracuda_kernels_nms_f64_run(
183 self.desc.num_boxes,
184 self.desc.iou_threshold,
185 boxes_ptr, mask_ptr, count_ptr,
186 core::ptr::null_mut(), 0, stream_ptr,
187 )
188 },
189 _ => {
190 return Err(Error::Unsupported(
191 "baracuda-kernels::NmsPlan::run reached unimplemented dtype",
192 ));
193 }
194 };
195 map_status(status)
196 }
197}