Skip to main content

od_bridge/
lib.rs

1//! od-bridge: C ABI bridge for od_opencv, designed for Go CGO integration.
2//!
3//! Provides opaque model handles and flat C structs for detection results.
4//! Each model is independent: create one for plate detection, another for OCR.
5
6use std::ffi::CStr;
7use std::os::raw::c_char;
8use std::ptr;
9use std::slice;
10
11use ndarray::Array3;
12use od_opencv::model_factory::Model;
13use od_opencv::model_trait::ObjectDetector;
14use od_opencv::BBox;
15use od_opencv::face_pipeline::FacePipeline;
16
/// One detected object, in pixel coordinates of the input image.
///
/// Plain `#[repr(C)]` data without pointers, so Go/CGO can copy it byte-for-byte.
/// Field order is part of the C ABI and must not change.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct OdDetection {
    /// X coordinate of the box's top-left corner, in pixels.
    pub bbox_x: i32,
    /// Y coordinate of the box's top-left corner, in pixels.
    pub bbox_y: i32,
    /// Box width, in pixels.
    pub bbox_w: i32,
    /// Box height, in pixels.
    pub bbox_h: i32,
    /// Zero-based index of the predicted class.
    pub class_id: i32,
    /// Detection confidence, nominally in `[0.0, 1.0]`.
    pub confidence: f32,
}
34
35/// Detection results batch. Caller must free via `od_detections_free`.
36#[repr(C)]
37pub struct OdDetections {
38    /// Pointer to the first element of the results array.
39    pub data: *mut OdDetection,
40    /// Number of detections in the array.
41    pub len: i32,
42}
43
/// Status code returned by the detection and embedding entry points
/// (`od_model_detect`, `face_pipeline_process`, `face_pipeline_embed`).
///
/// The discriminant values are part of the C ABI; `Ok` (0) means success.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub enum OdError {
    /// Success.
    Ok = 0,
    /// A null pointer or an invalid dimension was passed.
    InvalidArgument = 1,
    /// The model file could not be loaded (the `*_create*` constructors in this
    /// file report load failures by returning null instead).
    ModelLoadFailed = 2,
    /// Inference failed at runtime.
    DetectionFailed = 3,
    /// The RGB pixel buffer could not be converted into an `Array3`.
    ImageConvertFailed = 4,
}
59
60/// Backend-specific model variant.
61enum ModelInner {
62    /// ONNX Runtime (CPU, CUDA, or TensorRT execution provider).
63    Ort(od_opencv::backend_ort::ModelUltralyticsOrt),
64    /// Native TensorRT engine.
65    #[cfg(feature = "trt")]
66    Trt(od_opencv::backend_tensorrt::ModelUltralyticsRt),
67    /// Rockchip RKNN NPU.
68    #[cfg(feature = "rknn")]
69    Rknn(od_opencv::backend_rknn::ModelUltralyticsRknn),
70}
71
72/// Opaque model handle. Created by any `od_model_create_*` function.
73pub struct ModelHandle {
74    inner: ModelInner,
75}
76
77impl ModelHandle {
78    /// Run detection on an `ImageBuffer`, dispatching to the active backend.
79    fn detect(
80        &mut self,
81        img: &od_opencv::ImageBuffer,
82        conf: f32,
83        nms: f32,
84    ) -> Result<(Vec<BBox>, Vec<usize>, Vec<f32>), OdError> {
85        match &mut self.inner {
86            ModelInner::Ort(m) => m.detect(img, conf, nms).map_err(|e| {
87                eprintln!("od_model_detect (ort): {e:?}");
88                OdError::DetectionFailed
89            }),
90            #[cfg(feature = "trt")]
91            ModelInner::Trt(m) => m.detect(img, conf, nms).map_err(|e| {
92                eprintln!("od_model_detect (trt): {e:?}");
93                OdError::DetectionFailed
94            }),
95            #[cfg(feature = "rknn")]
96            ModelInner::Rknn(m) => m.detect(img, conf, nms).map_err(|e| {
97                eprintln!("od_model_detect (rknn): {e:?}");
98                OdError::DetectionFailed
99            }),
100        }
101    }
102}
103
/// Helper: borrow a C string pointer as a Rust `&str`.
///
/// Returns `None` if the pointer is null or its bytes are not valid UTF-8.
///
/// The lifetime `'a` is chosen by the caller, mirroring [`CStr::from_ptr`]. The
/// previous signature promised `&'static str`, which the borrowed C buffer cannot
/// guarantee.
///
/// # Safety
/// `p` must be null or point to a valid null-terminated C string that stays alive
/// and unmodified for as long as the returned `&str` is used.
unsafe fn parse_cstr<'a>(p: *const c_char) -> Option<&'a str> {
    if p.is_null() {
        return None;
    }
    // SAFETY: `p` is non-null; validity and lifetime are guaranteed by the caller.
    unsafe { CStr::from_ptr(p) }.to_str().ok()
}
112
113/// Helper: allocate a `ModelHandle` on the heap and return a raw pointer.
114fn into_handle(inner: ModelInner) -> *mut ModelHandle {
115    Box::into_raw(Box::new(ModelHandle { inner }))
116}
117
118/// Create a model from an ONNX file (ORT backend, CPU).
119///
120/// # Parameters
121/// - `model_path`: null-terminated path to `.onnx` file
122/// - `input_w`, `input_h`: model input dimensions (e.g. 416, 416)
123///
124/// # Returns
125/// Opaque pointer, or null on error.
126///
127/// # Safety
128/// `model_path` must be a valid null-terminated C string.
129#[unsafe(no_mangle)]
130pub unsafe extern "C" fn od_model_create(
131    model_path: *const c_char,
132    input_w: u32,
133    input_h: u32,
134) -> *mut ModelHandle {
135    let Some(path) = (unsafe { parse_cstr(model_path) }) else {
136        return ptr::null_mut();
137    };
138    match Model::ort(path, (input_w, input_h)) {
139        Ok(model) => into_handle(ModelInner::Ort(model)),
140        Err(e) => {
141            eprintln!("od_model_create: {e:?}");
142            ptr::null_mut()
143        }
144    }
145}
146
/// Create a detection model from an ONNX file using the ONNX Runtime CUDA execution provider.
///
/// Returns null on a null or non-UTF-8 path, or when loading fails; load errors are
/// printed to stderr. Release the handle with `od_model_free`.
///
/// # Safety
/// `model_path` must be null or a valid null-terminated C string.
#[cfg(feature = "cuda")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn od_model_create_cuda(
    model_path: *const c_char,
    input_w: u32,
    input_h: u32,
) -> *mut ModelHandle {
    // SAFETY: forwarded from this function's contract on `model_path`.
    match unsafe { parse_cstr(model_path) } {
        None => ptr::null_mut(),
        Some(path) => Model::ort_cuda(path, (input_w, input_h))
            .inspect_err(|e| eprintln!("od_model_create_cuda: {e:?}"))
            .map_or(ptr::null_mut(), |model| into_handle(ModelInner::Ort(model))),
    }
}
169
/// Create a detection model from an ONNX file using ONNX Runtime's TensorRT execution provider.
///
/// Returns null on a null or non-UTF-8 path, or when loading fails; load errors are
/// printed to stderr. Release the handle with `od_model_free`.
///
/// # Safety
/// `model_path` must be null or a valid null-terminated C string.
#[cfg(feature = "tensorrt")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn od_model_create_tensorrt(
    model_path: *const c_char,
    input_w: u32,
    input_h: u32,
) -> *mut ModelHandle {
    // SAFETY: forwarded from this function's contract on `model_path`.
    match unsafe { parse_cstr(model_path) } {
        None => ptr::null_mut(),
        Some(path) => Model::ort_tensorrt(path, (input_w, input_h))
            .inspect_err(|e| eprintln!("od_model_create_tensorrt: {e:?}"))
            .map_or(ptr::null_mut(), |model| into_handle(ModelInner::Ort(model))),
    }
}
192
/// Create a detection model from a serialized TensorRT engine (native TensorRT, no ORT).
///
/// Returns null on a null or non-UTF-8 path, or when loading fails; load errors are
/// printed to stderr. Release the handle with `od_model_free`.
///
/// # Safety
/// `engine_path` must be null or a valid null-terminated C string.
#[cfg(feature = "trt")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn od_model_create_trt(
    engine_path: *const c_char,
) -> *mut ModelHandle {
    // SAFETY: forwarded from this function's contract on `engine_path`.
    match unsafe { parse_cstr(engine_path) } {
        None => ptr::null_mut(),
        Some(path) => Model::tensorrt(path)
            .inspect_err(|e| eprintln!("od_model_create_trt: {e:?}"))
            .map_or(ptr::null_mut(), |model| into_handle(ModelInner::Trt(model))),
    }
}
213
/// Create a detection model from an RKNN model file (Rockchip NPU).
///
/// # Parameters
/// - `model_path`: null-terminated path to the `.rknn` file
/// - `num_classes`: number of classes the model was trained on
///
/// # Returns
/// An opaque handle to release with `od_model_free`, or null on a null or
/// non-UTF-8 path or a load failure (printed to stderr).
///
/// # Safety
/// `model_path` must be null or a valid null-terminated C string.
#[cfg(feature = "rknn")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn od_model_create_rknn(
    model_path: *const c_char,
    num_classes: u32,
) -> *mut ModelHandle {
    // SAFETY: forwarded from this function's contract on `model_path`.
    match unsafe { parse_cstr(model_path) } {
        None => ptr::null_mut(),
        Some(path) => Model::rknn(path, num_classes as usize)
            .inspect_err(|e| eprintln!("od_model_create_rknn: {e:?}"))
            .map_or(ptr::null_mut(), |model| into_handle(ModelInner::Rknn(model))),
    }
}
239
240/// Free a model handle.
241///
242/// # Safety
243/// `handle` must have been returned by `od_model_create*` and not yet freed.
244#[unsafe(no_mangle)]
245pub unsafe extern "C" fn od_model_free(handle: *mut ModelHandle) {
246    if !handle.is_null() {
247        drop(unsafe { Box::from_raw(handle) });
248    }
249}
250
251/// Run detection on an RGB image.
252///
253/// Works with any backend: the handle dispatches to the correct runtime internally.
254///
255/// # Parameters
256/// - `handle`: model handle from any `od_model_create_*` function
257/// - `pixels_rgb`: pointer to `width * height * 3` bytes (RGB, row-major, HWC)
258/// - `img_w`, `img_h`: image dimensions in pixels
259/// - `conf_threshold`: confidence threshold (e.g. 0.3)
260/// - `nms_threshold`: NMS IoU threshold (e.g. 0.4)
261/// - `out`: pointer to `OdDetections` struct, filled on success
262///
263/// # Returns
264/// `OdError::Ok` on success. On error, `out` is zeroed.
265///
266/// # Safety
267/// - `handle` must be valid.
268/// - `pixels_rgb` must point to at least `img_w * img_h * 3` bytes.
269/// - `out` must be a valid pointer.
270#[unsafe(no_mangle)]
271pub unsafe extern "C" fn od_model_detect(
272    handle: *mut ModelHandle,
273    pixels_rgb: *const u8,
274    img_w: i32,
275    img_h: i32,
276    conf_threshold: f32,
277    nms_threshold: f32,
278    out: *mut OdDetections,
279) -> OdError {
280    if handle.is_null() || pixels_rgb.is_null() || out.is_null() {
281        return OdError::InvalidArgument;
282    }
283    if img_w <= 0 || img_h <= 0 {
284        return OdError::InvalidArgument;
285    }
286
287    let model = unsafe { &mut *handle };
288    let h = img_h as usize;
289    let w = img_w as usize;
290    let n_bytes = h * w * 3;
291
292    let rgb_slice = unsafe { slice::from_raw_parts(pixels_rgb, n_bytes) };
293    let arr = match Array3::from_shape_vec((h, w, 3), rgb_slice.to_vec()) {
294        Ok(a) => a,
295        Err(_) => {
296            unsafe {
297                (*out).data = ptr::null_mut();
298                (*out).len = 0;
299            }
300            return OdError::ImageConvertFailed;
301        }
302    };
303
304    let img_buf = od_opencv::ImageBuffer::from_rgb(arr);
305
306    let (bboxes, class_ids, confidences) = match model.detect(&img_buf, conf_threshold, nms_threshold) {
307        Ok(result) => result,
308        Err(e) => {
309            unsafe {
310                (*out).data = ptr::null_mut();
311                (*out).len = 0;
312            }
313            return e;
314        }
315    };
316
317    let count = bboxes.len();
318    if count == 0 {
319        unsafe {
320            (*out).data = ptr::null_mut();
321            (*out).len = 0;
322        }
323        return OdError::Ok;
324    }
325
326    let mut results: Vec<OdDetection> = Vec::with_capacity(count);
327    for i in 0..count {
328        results.push(OdDetection {
329            bbox_x: bboxes[i].x,
330            bbox_y: bboxes[i].y,
331            bbox_w: bboxes[i].width,
332            bbox_h: bboxes[i].height,
333            class_id: class_ids[i] as i32,
334            confidence: confidences[i],
335        });
336    }
337
338    let mut results = results.into_boxed_slice();
339    unsafe {
340        (*out).data = results.as_mut_ptr();
341        (*out).len = count as i32;
342    }
343    std::mem::forget(results);
344
345    OdError::Ok
346}
347
348/// Free detection results.
349///
350/// # Safety
351/// `detections` must point to a valid `OdDetections` returned by `od_model_detect`.
352#[unsafe(no_mangle)]
353pub unsafe extern "C" fn od_detections_free(detections: *mut OdDetections) {
354    if detections.is_null() {
355        return;
356    }
357    let d = unsafe { &mut *detections };
358    if !d.data.is_null() && d.len > 0 {
359        let _ = unsafe {
360            Vec::from_raw_parts(d.data, d.len as usize, d.len as usize)
361        };
362        d.data = ptr::null_mut();
363        d.len = 0;
364    }
365}
366
/// One detected face together with its landmarks and recognition embedding.
///
/// Plain `#[repr(C)]` data without pointers, so Go/CGO can copy it byte-for-byte.
/// Field order is part of the C ABI and must not change.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct FaceDetectionResult {
    /// X coordinate of the box's top-left corner, in pixels.
    pub bbox_x: f32,
    /// Y coordinate of the box's top-left corner, in pixels.
    pub bbox_y: f32,
    /// Box width, in pixels.
    pub bbox_w: f32,
    /// Box height, in pixels.
    pub bbox_h: f32,
    /// Detection confidence, nominally in `[0.0, 1.0]`.
    pub confidence: f32,
    /// Five facial landmarks flattened as `[x0, y0, x1, y1, ..., x4, y4]`.
    pub landmarks: [f32; 10],
    /// 512-dimensional L2-normalized face embedding.
    pub embedding: [f32; 512],
}
386
387/// Face detection results batch. Caller must free via `face_pipeline_results_free`.
388#[repr(C)]
389pub struct FaceDetectionResults {
390    /// Pointer to the first element of the results array.
391    pub data: *mut FaceDetectionResult,
392    /// Number of face detections in the array.
393    pub len: i32,
394}
395
396/// Opaque face pipeline handle. Created by `face_pipeline_create*`.
397pub struct FacePipelineHandle {
398    inner: FacePipeline,
399}
400
401/// Create a face pipeline (YuNet detector + ArcFace recognizer, ORT CPU).
402///
403/// # Parameters
404/// - `detector_path`: null-terminated path to YuNet `.onnx` file
405/// - `recognizer_path`: null-terminated path to ArcFace `.onnx` file (e.g. `w600k_mbf.onnx`)
406///
407/// # Returns
408/// Opaque pointer, or null on error.
409///
410/// # Safety
411/// Both paths must be valid null-terminated C strings.
412#[unsafe(no_mangle)]
413pub unsafe extern "C" fn face_pipeline_create(
414    detector_path: *const c_char,
415    recognizer_path: *const c_char,
416) -> *mut FacePipelineHandle {
417    let Some(det_path) = (unsafe { parse_cstr(detector_path) }) else {
418        return ptr::null_mut();
419    };
420    let Some(rec_path) = (unsafe { parse_cstr(recognizer_path) }) else {
421        return ptr::null_mut();
422    };
423    match FacePipeline::new(det_path, rec_path) {
424        Ok(pipeline) => Box::into_raw(Box::new(FacePipelineHandle { inner: pipeline })),
425        Err(e) => {
426            eprintln!("face_pipeline_create: {e:?}");
427            ptr::null_mut()
428        }
429    }
430}
431
/// Build a face pipeline with CUDA acceleration.
///
/// Returns null if either path is null or not valid UTF-8, or if construction fails;
/// construction errors are printed to stderr. Release with `face_pipeline_destroy`.
///
/// # Safety
/// Both paths must be null or valid null-terminated C strings.
#[cfg(feature = "cuda")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn face_pipeline_create_cuda(
    detector_path: *const c_char,
    recognizer_path: *const c_char,
) -> *mut FacePipelineHandle {
    // The recognizer path is only read once the detector path parsed successfully.
    // SAFETY: both pointers are covered by this function's contract.
    let Some((det_path, rec_path)) = (unsafe { parse_cstr(detector_path) })
        .and_then(|det| unsafe { parse_cstr(recognizer_path) }.map(|rec| (det, rec)))
    else {
        return ptr::null_mut();
    };
    FacePipeline::new_cuda(det_path, rec_path)
        .inspect_err(|e| eprintln!("face_pipeline_create_cuda: {e:?}"))
        .map_or(ptr::null_mut(), |inner| {
            Box::into_raw(Box::new(FacePipelineHandle { inner }))
        })
}
456
/// Build a face pipeline with TensorRT acceleration (through ONNX Runtime).
///
/// Returns null if either path is null or not valid UTF-8, or if construction fails;
/// construction errors are printed to stderr. Release with `face_pipeline_destroy`.
///
/// # Safety
/// Both paths must be null or valid null-terminated C strings.
#[cfg(feature = "tensorrt")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn face_pipeline_create_tensorrt(
    detector_path: *const c_char,
    recognizer_path: *const c_char,
) -> *mut FacePipelineHandle {
    // The recognizer path is only read once the detector path parsed successfully.
    // SAFETY: both pointers are covered by this function's contract.
    let Some((det_path, rec_path)) = (unsafe { parse_cstr(detector_path) })
        .and_then(|det| unsafe { parse_cstr(recognizer_path) }.map(|rec| (det, rec)))
    else {
        return ptr::null_mut();
    };
    FacePipeline::new_tensorrt(det_path, rec_path)
        .inspect_err(|e| eprintln!("face_pipeline_create_tensorrt: {e:?}"))
        .map_or(ptr::null_mut(), |inner| {
            Box::into_raw(Box::new(FacePipelineHandle { inner }))
        })
}
481
482/// Returns the expected aligned face size (square side, read from the ONNX model).
483///
484/// E.g. 112 for MobileFaceNet (w600k_mbf.onnx).
485/// Go-side should call this instead of hardcoding a constant.
486///
487/// # Safety
488/// `handle` must be a valid pointer returned by `face_pipeline_create*`.
489#[unsafe(no_mangle)]
490pub unsafe extern "C" fn face_pipeline_aligned_size(
491    handle: *const FacePipelineHandle,
492) -> u32 {
493    if handle.is_null() {
494        return 0;
495    }
496    unsafe { &*handle }.inner.aligned_size()
497}
498
499/// Run face detection + recognition on an RGB image.
500///
501/// # Parameters
502/// - `handle`: face pipeline handle
503/// - `pixels_rgb`: pointer to `width * height * 3` bytes (RGB, row-major, HWC)
504/// - `img_w`, `img_h`: image dimensions in pixels
505/// - `conf_threshold`: detection confidence threshold (e.g. 0.7)
506/// - `nms_threshold`: NMS IoU threshold (e.g. 0.3)
507/// - `out`: pointer to `FaceDetectionResults` struct, filled on success
508///
509/// # Returns
510/// `OdError::Ok` on success. On error, `out` is zeroed.
511///
512/// # Safety
513/// - `handle` must be valid.
514/// - `pixels_rgb` must point to at least `img_w * img_h * 3` bytes.
515/// - `out` must be a valid pointer.
516#[unsafe(no_mangle)]
517pub unsafe extern "C" fn face_pipeline_process(
518    handle: *mut FacePipelineHandle,
519    pixels_rgb: *const u8,
520    img_w: i32,
521    img_h: i32,
522    conf_threshold: f32,
523    nms_threshold: f32,
524    out: *mut FaceDetectionResults,
525) -> OdError {
526    if handle.is_null() || pixels_rgb.is_null() || out.is_null() {
527        return OdError::InvalidArgument;
528    }
529    if img_w <= 0 || img_h <= 0 {
530        return OdError::InvalidArgument;
531    }
532
533    let pipeline = unsafe { &mut *handle };
534    let h = img_h as usize;
535    let w = img_w as usize;
536    let n_bytes = h * w * 3;
537
538    let rgb_slice = unsafe { slice::from_raw_parts(pixels_rgb, n_bytes) };
539    let arr = match Array3::from_shape_vec((h, w, 3), rgb_slice.to_vec()) {
540        Ok(a) => a,
541        Err(_) => {
542            unsafe {
543                (*out).data = ptr::null_mut();
544                (*out).len = 0;
545            }
546            return OdError::ImageConvertFailed;
547        }
548    };
549
550    let img_buf = od_opencv::ImageBuffer::from_rgb(arr);
551
552    let faces = match pipeline.inner.process(&img_buf, conf_threshold, nms_threshold) {
553        Ok(f) => f,
554        Err(e) => {
555            eprintln!("face_pipeline_process: {e:?}");
556            unsafe {
557                (*out).data = ptr::null_mut();
558                (*out).len = 0;
559            }
560            return OdError::DetectionFailed;
561        }
562    };
563
564    let count = faces.len();
565    if count == 0 {
566        unsafe {
567            (*out).data = ptr::null_mut();
568            (*out).len = 0;
569        }
570        return OdError::Ok;
571    }
572
573    let mut results: Vec<FaceDetectionResult> = Vec::with_capacity(count);
574    for face in &faces {
575        let mut landmarks = [0.0f32; 10];
576        for k in 0..5 {
577            landmarks[k * 2] = face.landmarks[k][0];
578            landmarks[k * 2 + 1] = face.landmarks[k][1];
579        }
580        results.push(FaceDetectionResult {
581            bbox_x: face.x,
582            bbox_y: face.y,
583            bbox_w: face.width,
584            bbox_h: face.height,
585            confidence: face.confidence,
586            landmarks,
587            embedding: face.embedding,
588        });
589    }
590
591    let mut results = results.into_boxed_slice();
592    unsafe {
593        (*out).data = results.as_mut_ptr();
594        (*out).len = count as i32;
595    }
596    std::mem::forget(results);
597
598    OdError::Ok
599}
600
601/// Extract embedding from a pre-aligned face image.
602///
603/// The image must be aligned to the size returned by `face_pipeline_aligned_size()`
604/// (typically 112x112).
605///
606/// # Parameters
607/// - `handle`: face pipeline handle
608/// - `pixels_rgb`: pointer to aligned face RGB data (size x size x 3 bytes)
609/// - `size`: aligned face size (e.g. 112)
610/// - `out_embedding`: pointer to caller-allocated `[f32; 512]` buffer
611///
612/// # Returns
613/// `OdError::Ok` on success.
614///
615/// # Safety
616/// - `handle` must be valid.
617/// - `pixels_rgb` must point to at least `size * size * 3` bytes.
618/// - `out_embedding` must point to at least 512 f32 elements.
619#[unsafe(no_mangle)]
620pub unsafe extern "C" fn face_pipeline_embed(
621    handle: *mut FacePipelineHandle,
622    pixels_rgb: *const u8,
623    size: i32,
624    out_embedding: *mut f32,
625) -> OdError {
626    if handle.is_null() || pixels_rgb.is_null() || out_embedding.is_null() {
627        return OdError::InvalidArgument;
628    }
629    if size <= 0 {
630        return OdError::InvalidArgument;
631    }
632
633    let pipeline = unsafe { &mut *handle };
634    let s = size as usize;
635    let n_bytes = s * s * 3;
636
637    let rgb_slice = unsafe { slice::from_raw_parts(pixels_rgb, n_bytes) };
638    let arr = match Array3::from_shape_vec((s, s, 3), rgb_slice.to_vec()) {
639        Ok(a) => a,
640        Err(_) => return OdError::ImageConvertFailed,
641    };
642
643    let img_buf = od_opencv::ImageBuffer::from_rgb(arr);
644
645    match pipeline.inner.embed(&img_buf) {
646        Ok(embedding) => {
647            unsafe {
648                ptr::copy_nonoverlapping(embedding.as_ptr(), out_embedding, 512);
649            }
650            OdError::Ok
651        }
652        Err(e) => {
653            eprintln!("face_pipeline_embed: {e:?}");
654            OdError::DetectionFailed
655        }
656    }
657}
658
659/// Free a face pipeline handle.
660///
661/// # Safety
662/// `handle` must have been returned by `face_pipeline_create*` and not yet freed.
663#[unsafe(no_mangle)]
664pub unsafe extern "C" fn face_pipeline_destroy(handle: *mut FacePipelineHandle) {
665    if !handle.is_null() {
666        drop(unsafe { Box::from_raw(handle) });
667    }
668}
669
670/// Free face detection results.
671///
672/// # Safety
673/// `results` must point to a valid `FaceDetectionResults` returned by `face_pipeline_process`.
674#[unsafe(no_mangle)]
675pub unsafe extern "C" fn face_pipeline_results_free(results: *mut FaceDetectionResults) {
676    if results.is_null() {
677        return;
678    }
679    let r = unsafe { &mut *results };
680    if !r.data.is_null() && r.len > 0 {
681        let _ = unsafe {
682            Vec::from_raw_parts(r.data, r.len as usize, r.len as usize)
683        };
684        r.data = ptr::null_mut();
685        r.len = 0;
686    }
687}