Skip to main content

od_bridge/
lib.rs

1//! od-bridge: C ABI bridge for od_opencv, designed for Go CGO integration.
2//!
3//! Provides opaque model handles and flat C structs for detection results.
4//! Each model is independent: create one for plate detection, another for OCR.
5
6use std::ffi::CStr;
7use std::os::raw::c_char;
8use std::ptr;
9use std::slice;
10
11use ndarray::Array3;
12use od_opencv::model_factory::Model;
13use od_opencv::model_trait::ObjectDetector;
14use od_opencv::BBox;
15
/// One detected object, laid out as a plain C struct.
///
/// Every field is a scalar (no pointers), so Go/CGO can copy values of this type
/// byte-for-byte.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct OdDetection {
    /// X coordinate of the box's top-left corner, in pixels.
    pub bbox_x: i32,
    /// Y coordinate of the box's top-left corner, in pixels.
    pub bbox_y: i32,
    /// Box width, in pixels.
    pub bbox_w: i32,
    /// Box height, in pixels.
    pub bbox_h: i32,
    /// Zero-based index of the predicted class.
    pub class_id: i32,
    /// Detection confidence, in the range [0.0, 1.0].
    pub confidence: f32,
}
33
34/// Detection results batch. Caller must free via `od_detections_free`.
35#[repr(C)]
36pub struct OdDetections {
37    /// Pointer to the first element of the results array.
38    pub data: *mut OdDetection,
39    /// Number of detections in the array.
40    pub len: i32,
41}
42
/// Status code reported by the bridge's non-constructor entry points (e.g. `od_model_detect`).
///
/// Discriminants are fixed so they can be mirrored as C/Go constants. The `od_model_create*`
/// constructors in this file signal failure with a null handle instead of a status code.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(C)]
pub enum OdError {
    /// Success.
    Ok = 0,
    /// A required pointer was null or an image dimension was invalid.
    InvalidArgument = 1,
    /// The ONNX model file could not be loaded.
    ModelLoadFailed = 2,
    /// The backend reported an error while running inference.
    DetectionFailed = 3,
    /// The RGB pixel buffer could not be converted into an `Array3`.
    ImageConvertFailed = 4,
}
58
59/// Backend-specific model variant.
60enum ModelInner {
61    /// ONNX Runtime (CPU, CUDA, or TensorRT execution provider).
62    Ort(od_opencv::backend_ort::ModelUltralyticsOrt),
63    /// Native TensorRT engine.
64    #[cfg(feature = "trt")]
65    Trt(od_opencv::backend_tensorrt::ModelUltralyticsRt),
66    /// Rockchip RKNN NPU.
67    #[cfg(feature = "rknn")]
68    Rknn(od_opencv::backend_rknn::ModelUltralyticsRknn),
69}
70
71/// Opaque model handle. Created by any `od_model_create_*` function.
72pub struct ModelHandle {
73    inner: ModelInner,
74}
75
76impl ModelHandle {
77    /// Run detection on an `ImageBuffer`, dispatching to the active backend.
78    fn detect(
79        &mut self,
80        img: &od_opencv::ImageBuffer,
81        conf: f32,
82        nms: f32,
83    ) -> Result<(Vec<BBox>, Vec<usize>, Vec<f32>), OdError> {
84        match &mut self.inner {
85            ModelInner::Ort(m) => m.detect(img, conf, nms).map_err(|e| {
86                eprintln!("od_model_detect (ort): {e:?}");
87                OdError::DetectionFailed
88            }),
89            #[cfg(feature = "trt")]
90            ModelInner::Trt(m) => m.detect(img, conf, nms).map_err(|e| {
91                eprintln!("od_model_detect (trt): {e:?}");
92                OdError::DetectionFailed
93            }),
94            #[cfg(feature = "rknn")]
95            ModelInner::Rknn(m) => m.detect(img, conf, nms).map_err(|e| {
96                eprintln!("od_model_detect (rknn): {e:?}");
97                OdError::DetectionFailed
98            }),
99        }
100    }
101}
102
/// Helper: borrow a C string pointer as a Rust `&str`.
///
/// Returns `None` if the pointer is null or the bytes are not valid UTF-8.
///
/// The result borrows the caller's memory; it does not copy it. The lifetime `'a` is
/// caller-chosen rather than `'static`, because the C string is not guaranteed to outlive
/// the FFI call.
///
/// # Safety
/// A non-null `p` must point to a NUL-terminated string that stays alive and unmodified
/// for the lifetime `'a`. The compiler cannot check this, so callers should use the result
/// immediately instead of storing it.
unsafe fn parse_cstr<'a>(p: *const c_char) -> Option<&'a str> {
    if p.is_null() {
        return None;
    }
    // SAFETY: `p` is non-null, and the caller upholds the NUL-termination and lifetime
    // contract documented above.
    unsafe { CStr::from_ptr(p) }.to_str().ok()
}
111
112/// Helper: allocate a `ModelHandle` on the heap and return a raw pointer.
113fn into_handle(inner: ModelInner) -> *mut ModelHandle {
114    Box::into_raw(Box::new(ModelHandle { inner }))
115}
116
117/// Create a model from an ONNX file (ORT backend, CPU).
118///
119/// # Parameters
120/// - `model_path`: null-terminated path to `.onnx` file
121/// - `input_w`, `input_h`: model input dimensions (e.g. 416, 416)
122///
123/// # Returns
124/// Opaque pointer, or null on error.
125///
126/// # Safety
127/// `model_path` must be a valid null-terminated C string.
128#[unsafe(no_mangle)]
129pub unsafe extern "C" fn od_model_create(
130    model_path: *const c_char,
131    input_w: u32,
132    input_h: u32,
133) -> *mut ModelHandle {
134    let Some(path) = (unsafe { parse_cstr(model_path) }) else {
135        return ptr::null_mut();
136    };
137    match Model::ort(path, (input_w, input_h)) {
138        Ok(model) => into_handle(ModelInner::Ort(model)),
139        Err(e) => {
140            eprintln!("od_model_create: {e:?}");
141            ptr::null_mut()
142        }
143    }
144}
145
/// Create a model from an ONNX file, running on the CUDA execution provider (ORT).
///
/// # Parameters
/// - `model_path`: null-terminated path to the `.onnx` file
/// - `input_w`, `input_h`: model input dimensions
///
/// # Returns
/// An opaque handle to release with `od_model_free`. Returns null if the path is null or
/// not UTF-8, or if loading fails; a load error is also printed to stderr.
///
/// # Safety
/// `model_path` must be null or a valid null-terminated C string.
#[cfg(feature = "cuda")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn od_model_create_cuda(
    model_path: *const c_char,
    input_w: u32,
    input_h: u32,
) -> *mut ModelHandle {
    // SAFETY: forwarded from this function's `# Safety` contract.
    let parsed = unsafe { parse_cstr(model_path) };
    match parsed {
        None => ptr::null_mut(),
        Some(path) => Model::ort_cuda(path, (input_w, input_h))
            .map(|model| into_handle(ModelInner::Ort(model)))
            .unwrap_or_else(|err| {
                eprintln!("od_model_create_cuda: {err:?}");
                ptr::null_mut()
            }),
    }
}
168
/// Create a model from an ONNX file, running on ORT's TensorRT execution provider.
///
/// # Parameters
/// - `model_path`: null-terminated path to the `.onnx` file
/// - `input_w`, `input_h`: model input dimensions
///
/// # Returns
/// An opaque handle to release with `od_model_free`. Returns null if the path is null or
/// not UTF-8, or if loading fails; a load error is also printed to stderr.
///
/// # Safety
/// `model_path` must be null or a valid null-terminated C string.
#[cfg(feature = "tensorrt")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn od_model_create_tensorrt(
    model_path: *const c_char,
    input_w: u32,
    input_h: u32,
) -> *mut ModelHandle {
    // SAFETY: forwarded from this function's `# Safety` contract.
    let parsed = unsafe { parse_cstr(model_path) };
    match parsed {
        None => ptr::null_mut(),
        Some(path) => Model::ort_tensorrt(path, (input_w, input_h))
            .map(|model| into_handle(ModelInner::Ort(model)))
            .unwrap_or_else(|err| {
                eprintln!("od_model_create_tensorrt: {err:?}");
                ptr::null_mut()
            }),
    }
}
191
/// Create a model from a serialized TensorRT engine file (native TensorRT, no ORT).
///
/// # Parameters
/// - `engine_path`: null-terminated path to the serialized engine
///
/// # Returns
/// An opaque handle to release with `od_model_free`. Returns null if the path is null or
/// not UTF-8, or if loading fails; a load error is also printed to stderr.
///
/// # Safety
/// `engine_path` must be null or a valid null-terminated C string.
#[cfg(feature = "trt")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn od_model_create_trt(engine_path: *const c_char) -> *mut ModelHandle {
    // SAFETY: forwarded from this function's `# Safety` contract.
    let parsed = unsafe { parse_cstr(engine_path) };
    match parsed {
        None => ptr::null_mut(),
        Some(path) => Model::tensorrt(path)
            .map(|model| into_handle(ModelInner::Trt(model)))
            .unwrap_or_else(|err| {
                eprintln!("od_model_create_trt: {err:?}");
                ptr::null_mut()
            }),
    }
}
212
/// Create a model from an RKNN model file (Rockchip NPU).
///
/// # Parameters
/// - `model_path`: null-terminated path to the `.rknn` file
/// - `num_classes`: number of classes the model was trained on
///
/// # Returns
/// An opaque handle to release with `od_model_free`. Returns null if the path is null or
/// not UTF-8, or if loading fails; a load error is also printed to stderr.
///
/// # Safety
/// `model_path` must be null or a valid null-terminated C string.
#[cfg(feature = "rknn")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn od_model_create_rknn(
    model_path: *const c_char,
    num_classes: u32,
) -> *mut ModelHandle {
    // SAFETY: forwarded from this function's `# Safety` contract.
    let parsed = unsafe { parse_cstr(model_path) };
    match parsed {
        None => ptr::null_mut(),
        Some(path) => Model::rknn(path, num_classes as usize)
            .map(|model| into_handle(ModelInner::Rknn(model)))
            .unwrap_or_else(|err| {
                eprintln!("od_model_create_rknn: {err:?}");
                ptr::null_mut()
            }),
    }
}
238
239/// Free a model handle.
240///
241/// # Safety
242/// `handle` must have been returned by `od_model_create*` and not yet freed.
243#[unsafe(no_mangle)]
244pub unsafe extern "C" fn od_model_free(handle: *mut ModelHandle) {
245    if !handle.is_null() {
246        drop(unsafe { Box::from_raw(handle) });
247    }
248}
249
250/// Run detection on an RGB image.
251///
252/// Works with any backend: the handle dispatches to the correct runtime internally.
253///
254/// # Parameters
255/// - `handle`: model handle from any `od_model_create_*` function
256/// - `pixels_rgb`: pointer to `width * height * 3` bytes (RGB, row-major, HWC)
257/// - `img_w`, `img_h`: image dimensions in pixels
258/// - `conf_threshold`: confidence threshold (e.g. 0.3)
259/// - `nms_threshold`: NMS IoU threshold (e.g. 0.4)
260/// - `out`: pointer to `OdDetections` struct, filled on success
261///
262/// # Returns
263/// `OdError::Ok` on success. On error, `out` is zeroed.
264///
265/// # Safety
266/// - `handle` must be valid.
267/// - `pixels_rgb` must point to at least `img_w * img_h * 3` bytes.
268/// - `out` must be a valid pointer.
269#[unsafe(no_mangle)]
270pub unsafe extern "C" fn od_model_detect(
271    handle: *mut ModelHandle,
272    pixels_rgb: *const u8,
273    img_w: i32,
274    img_h: i32,
275    conf_threshold: f32,
276    nms_threshold: f32,
277    out: *mut OdDetections,
278) -> OdError {
279    if handle.is_null() || pixels_rgb.is_null() || out.is_null() {
280        return OdError::InvalidArgument;
281    }
282    if img_w <= 0 || img_h <= 0 {
283        return OdError::InvalidArgument;
284    }
285
286    let model = unsafe { &mut *handle };
287    let h = img_h as usize;
288    let w = img_w as usize;
289    let n_bytes = h * w * 3;
290
291    let rgb_slice = unsafe { slice::from_raw_parts(pixels_rgb, n_bytes) };
292    let arr = match Array3::from_shape_vec((h, w, 3), rgb_slice.to_vec()) {
293        Ok(a) => a,
294        Err(_) => {
295            unsafe {
296                (*out).data = ptr::null_mut();
297                (*out).len = 0;
298            }
299            return OdError::ImageConvertFailed;
300        }
301    };
302
303    let img_buf = od_opencv::ImageBuffer::from_rgb(arr);
304
305    let (bboxes, class_ids, confidences) = match model.detect(&img_buf, conf_threshold, nms_threshold) {
306        Ok(result) => result,
307        Err(e) => {
308            unsafe {
309                (*out).data = ptr::null_mut();
310                (*out).len = 0;
311            }
312            return e;
313        }
314    };
315
316    let count = bboxes.len();
317    if count == 0 {
318        unsafe {
319            (*out).data = ptr::null_mut();
320            (*out).len = 0;
321        }
322        return OdError::Ok;
323    }
324
325    let mut results: Vec<OdDetection> = Vec::with_capacity(count);
326    for i in 0..count {
327        results.push(OdDetection {
328            bbox_x: bboxes[i].x,
329            bbox_y: bboxes[i].y,
330            bbox_w: bboxes[i].width,
331            bbox_h: bboxes[i].height,
332            class_id: class_ids[i] as i32,
333            confidence: confidences[i],
334        });
335    }
336
337    let mut results = results.into_boxed_slice();
338    unsafe {
339        (*out).data = results.as_mut_ptr();
340        (*out).len = count as i32;
341    }
342    std::mem::forget(results);
343
344    OdError::Ok
345}
346
347/// Free detection results.
348///
349/// # Safety
350/// `detections` must point to a valid `OdDetections` returned by `od_model_detect`.
351#[unsafe(no_mangle)]
352pub unsafe extern "C" fn od_detections_free(detections: *mut OdDetections) {
353    if detections.is_null() {
354        return;
355    }
356    let d = unsafe { &mut *detections };
357    if !d.data.is_null() && d.len > 0 {
358        let _ = unsafe {
359            Vec::from_raw_parts(d.data, d.len as usize, d.len as usize)
360        };
361        d.data = ptr::null_mut();
362        d.len = 0;
363    }
364}