Skip to main content

oximedia_ml/
postprocess.rs

1//! Output tensor post-processing utilities.
2//!
3//! These are small, self-contained helpers that operate on plain
4//! `&[f32]` slices. They keep the pipeline layer free of backend-
5//! specific tensor plumbing.
6//!
7//! A small geometric type — [`BoundingBox`] — also lives here so that
8//! the NMS / IoU helpers are always available regardless of which
9//! pipeline features are enabled.
10//!
11//! ## Helper overview
12//!
13//! | Helper                  | Use case                                             |
14//! |-------------------------|------------------------------------------------------|
15//! | [`softmax`]             | Turn classifier logits into a probability vector.    |
16//! | [`argmax`]              | Top-1 class index (errors on empty).                 |
17//! | [`top_k`]               | Top-k `(index, score)` pairs, descending.            |
18//! | [`sigmoid`]             | Scalar logistic sigmoid.                             |
19//! | [`sigmoid_slice`]       | Element-wise sigmoid for multi-label outputs.        |
20//! | [`iou`]                 | Pairwise IoU of two [`BoundingBox`]es.               |
21//! | [`nms`]                 | Greedy IoU-based Non-Maximum Suppression.            |
22//! | [`l2_normalize`]        | In-place L2 unit-normalisation (safe on zero norm).  |
23//! | [`cosine_similarity`]   | Cosine similarity; zero on mismatched / empty input. |
24//!
25//! ## Example
26//!
27//! ```
28//! use oximedia_ml::postprocess::{argmax, softmax, top_k};
29//!
30//! # fn main() -> oximedia_ml::MlResult<()> {
31//! let logits = [0.1_f32, 5.0, 0.3, 0.2];
32//! let probs = softmax(&logits);
33//! assert_eq!(argmax(&probs)?, 1);
34//!
35//! let ranked = top_k(&probs, 2)?;
36//! assert_eq!(ranked[0].0, 1); // best class
37//! # Ok(())
38//! # }
39//! ```
40
41use crate::error::{MlError, MlResult};
42
/// Axis-aligned rectangle stored as a pair of opposite corners.
///
/// The numbers live in whatever coordinate space the detector used
/// (commonly normalised 0..=1 or pixel-space 0..=W/H); the type itself
/// attaches no meaning to them and performs no validation, so a box with
/// `x1 < x0` or `y1 < y0` can be constructed and simply reports zero
/// extent.
///
/// # Examples
///
/// ```
/// use oximedia_ml::BoundingBox;
///
/// let b = BoundingBox::from_xywh_center(10.0, 20.0, 4.0, 8.0);
/// assert_eq!(b.width(), 4.0);
/// assert_eq!(b.area(), 32.0);
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BoundingBox {
    /// X coordinate of the top-left corner.
    pub x0: f32,
    /// Y coordinate of the top-left corner.
    pub y0: f32,
    /// X coordinate of the bottom-right corner.
    pub x1: f32,
    /// Y coordinate of the bottom-right corner.
    pub y1: f32,
}

impl BoundingBox {
    /// Build a box directly from its two corners; the values are stored as given.
    #[must_use]
    pub const fn new(x0: f32, y0: f32, x1: f32, y1: f32) -> Self {
        Self { x0, y0, x1, y1 }
    }

    /// Horizontal extent `x1 - x0`, never reported below zero.
    #[must_use]
    pub fn width(&self) -> f32 {
        let extent = self.x1 - self.x0;
        extent.max(0.0)
    }

    /// Vertical extent `y1 - y0`, never reported below zero.
    #[must_use]
    pub fn height(&self) -> f32 {
        let extent = self.y1 - self.y0;
        extent.max(0.0)
    }

    /// Product of [`width`](Self::width) and [`height`](Self::height);
    /// zero as soon as either extent is zero or inverted.
    #[must_use]
    pub fn area(&self) -> f32 {
        let (w, h) = (self.width(), self.height());
        w * h
    }

    /// Convert a centre-form box (`cx, cy, w, h`, as emitted by YOLO-style
    /// heads) into corner form.
    #[must_use]
    pub fn from_xywh_center(cx: f32, cy: f32, w: f32, h: f32) -> Self {
        let (dx, dy) = (w * 0.5, h * 0.5);
        Self::new(cx - dx, cy - dy, cx + dx, cy + dy)
    }
}
109
/// Apply the softmax function along the slice.
///
/// The largest logit is subtracted before exponentiating (max-shift trick)
/// so large inputs do not overflow.
///
/// Special cases:
///
/// * Empty input returns an empty vector.
/// * If every entry is `-∞` the result is the uniform distribution.
/// * If one or more entries are `+∞`, the probability mass is split evenly
///   between those entries and every other entry gets `0.0`.
/// * If any entry is NaN, every output is NaN.
///
/// # Examples
///
/// ```
/// use oximedia_ml::postprocess::softmax;
///
/// let probs = softmax(&[1.0, 2.0, 3.0]);
/// let sum: f32 = probs.iter().sum();
/// assert!((sum - 1.0).abs() < 1e-5);
/// ```
#[must_use]
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }
    // A NaN logit poisons the normalising sum, so the whole distribution is
    // undefined; make that explicit instead of relying on arithmetic fallout.
    if logits.iter().any(|v| v.is_nan()) {
        return vec![f32::NAN; logits.len()];
    }

    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    if peak == f32::NEG_INFINITY {
        // Every logit is -inf. `v - peak` would be `-inf - -inf = NaN`, so the
        // uniform fallback has to be taken before any subtraction happens.
        let share = 1.0 / logits.len() as f32;
        return vec![share; logits.len()];
    }
    if peak == f32::INFINITY {
        // `inf - inf` is NaN as well; the limiting distribution puts all the
        // mass, evenly, on the +inf entries.
        let winners = logits.iter().filter(|&&v| v == f32::INFINITY).count();
        let share = 1.0 / winners as f32;
        return logits
            .iter()
            .map(|&v| if v == f32::INFINITY { share } else { 0.0 })
            .collect();
    }

    // `peak` is finite here, so the entry equal to `peak` contributes
    // exp(0) = 1 and the sum is guaranteed to be at least 1.
    let mut probs: Vec<f32> = logits.iter().map(|&v| (v - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}
151
152/// Return the index of the largest value in `scores`.
153///
154/// # Errors
155///
156/// Returns [`MlError::Postprocess`] if `scores` is empty.
157///
158/// # Examples
159///
160/// ```
161/// use oximedia_ml::postprocess::argmax;
162///
163/// # fn main() -> oximedia_ml::MlResult<()> {
164/// assert_eq!(argmax(&[0.1, 0.4, 0.2])?, 1);
165/// # Ok(())
166/// # }
167/// ```
168pub fn argmax(scores: &[f32]) -> MlResult<usize> {
169    if scores.is_empty() {
170        return Err(MlError::postprocess("argmax on empty slice"));
171    }
172    let mut best = 0usize;
173    let mut best_v = scores[0];
174    for (i, &v) in scores.iter().enumerate().skip(1) {
175        if v > best_v {
176            best = i;
177            best_v = v;
178        }
179    }
180    Ok(best)
181}
182
183/// Return the top-`k` `(index, score)` pairs, sorted by descending score.
184///
185/// When `k == 0` an empty `Vec` is returned (no error). When `k` exceeds
186/// `scores.len()` the result is simply truncated to the input length.
187///
188/// # Errors
189///
190/// Returns [`MlError::Postprocess`] if `scores` is empty.
191///
192/// # Examples
193///
194/// ```
195/// use oximedia_ml::postprocess::top_k;
196///
197/// # fn main() -> oximedia_ml::MlResult<()> {
198/// let ranked = top_k(&[0.1, 0.5, 0.3, 0.7, 0.2], 3)?;
199/// assert_eq!(ranked[0].0, 3);
200/// assert_eq!(ranked[1].0, 1);
201/// # Ok(())
202/// # }
203/// ```
204pub fn top_k(scores: &[f32], k: usize) -> MlResult<Vec<(usize, f32)>> {
205    if scores.is_empty() {
206        return Err(MlError::postprocess("top_k on empty slice"));
207    }
208    if k == 0 {
209        return Ok(Vec::new());
210    }
211    let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
212    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
213    indexed.truncate(k);
214    Ok(indexed)
215}
216
/// Logistic sigmoid of a single value: `1 / (1 + e^(-v))`.
#[must_use]
pub fn sigmoid(v: f32) -> f32 {
    let denominator = 1.0 + (-v).exp();
    denominator.recip()
}
222
223/// Apply sigmoid element-wise to a slice.
224#[must_use]
225pub fn sigmoid_slice(values: &[f32]) -> Vec<f32> {
226    values.iter().copied().map(sigmoid).collect()
227}
228
229/// Intersection-over-Union for two bounding boxes in corner form.
230///
231/// Returns `0.0` if either box has zero or negative area, or if the
232/// boxes are disjoint.
233#[must_use]
234pub fn iou(a: &BoundingBox, b: &BoundingBox) -> f32 {
235    let ix0 = a.x0.max(b.x0);
236    let iy0 = a.y0.max(b.y0);
237    let ix1 = a.x1.min(b.x1);
238    let iy1 = a.y1.min(b.y1);
239    let iw = (ix1 - ix0).max(0.0);
240    let ih = (iy1 - iy0).max(0.0);
241    let inter = iw * ih;
242    if inter <= 0.0 {
243        return 0.0;
244    }
245    let area_a = a.area();
246    let area_b = b.area();
247    let union = area_a + area_b - inter;
248    if union <= 0.0 {
249        return 0.0;
250    }
251    (inter / union).clamp(0.0, 1.0)
252}
253
254/// Greedy Non-Maximum Suppression (NMS) over `(boxes, scores)`.
255///
256/// * `boxes` and `scores` must have equal length; otherwise an empty
257///   `Vec` is returned.
258/// * Boxes are processed in descending score order.
259/// * Any box whose IoU with an already-kept box exceeds
260///   `iou_threshold` is suppressed.
261/// * `iou_threshold` is clamped to `0.0..=1.0`.
262///
263/// Returned indices reference positions in the original `boxes` /
264/// `scores` slices, sorted by descending score.
265///
266/// # Examples
267///
268/// ```
269/// use oximedia_ml::{postprocess::nms, BoundingBox};
270///
271/// let a = BoundingBox::new(0.0, 0.0, 10.0, 10.0);
272/// let b = BoundingBox::new(1.0, 1.0, 11.0, 11.0);
273/// let c = BoundingBox::new(50.0, 50.0, 60.0, 60.0);
274/// let kept = nms(&[a, b, c], &[0.9_f32, 0.8, 0.7], 0.5);
275/// // The overlapping box is suppressed; `c` is far away so it survives.
276/// assert_eq!(kept, vec![0, 2]);
277/// ```
278#[must_use]
279pub fn nms(boxes: &[BoundingBox], scores: &[f32], iou_threshold: f32) -> Vec<usize> {
280    if boxes.len() != scores.len() || boxes.is_empty() {
281        return Vec::new();
282    }
283    let threshold = iou_threshold.clamp(0.0, 1.0);
284
285    // Sort indices by descending score.
286    let mut order: Vec<usize> = (0..boxes.len()).collect();
287    order.sort_by(|&a, &b| {
288        scores[b]
289            .partial_cmp(&scores[a])
290            .unwrap_or(std::cmp::Ordering::Equal)
291    });
292
293    let mut kept: Vec<usize> = Vec::with_capacity(order.len());
294    for &idx in &order {
295        let cand = &boxes[idx];
296        if cand.area() <= 0.0 {
297            continue;
298        }
299        let mut suppress = false;
300        for &keep_idx in &kept {
301            if iou(cand, &boxes[keep_idx]) > threshold {
302                suppress = true;
303                break;
304            }
305        }
306        if !suppress {
307            kept.push(idx);
308        }
309    }
310    kept
311}
312
/// In-place L2 normalisation of a float vector.
///
/// The slice is scaled so that its Euclidean norm becomes 1. If the norm is
/// zero or non-finite (empty slice, all zeros, or any NaN / infinite
/// element) the slice is left untouched. Safe to call on any `&mut [f32]`.
///
/// The squared norm is accumulated in `f64`, so vectors with very large or
/// very small finite components are still normalised instead of being
/// skipped because an `f32` square overflowed or underflowed.
pub fn l2_normalize(v: &mut [f32]) {
    let norm_sq: f64 = v
        .iter()
        .map(|&x| {
            let x = f64::from(x);
            x * x
        })
        .sum();
    if !norm_sq.is_finite() || norm_sq <= 0.0 {
        return;
    }
    let norm = norm_sq.sqrt();
    for x in v.iter_mut() {
        // Each quotient has magnitude <= 1, so narrowing back to f32 only
        // rounds; it cannot overflow.
        *x = (f64::from(*x) / norm) as f32;
    }
}
327
/// Cosine similarity for two equal-length slices.
///
/// Returns `0.0` if either input is empty, the lengths mismatch, either
/// vector has zero L2 norm, or either vector contains a NaN / infinite
/// element (the similarity is undefined in those cases). Otherwise the
/// result lies in `-1.0..=1.0`.
///
/// The dot product and norms are accumulated in `f64`, so finite inputs
/// with very large or very small components do not overflow to NaN or
/// underflow to a spurious zero norm.
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    // Zero norm, or NaN / infinity anywhere in the inputs: undefined.
    if !denom.is_finite() || denom <= 0.0 {
        return 0.0;
    }
    // Clamp away rounding excursions just outside [-1, 1] before narrowing.
    (dot / denom).clamp(-1.0, 1.0) as f32
}
355
// Unit tests for the post-processing helpers. Each test pins one documented
// behaviour; edge cases (empty input, mismatched lengths, zero-area boxes,
// zero norms) are covered alongside the happy paths.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn softmax_sums_to_one() {
        let probs = softmax(&[1.0, 2.0, 3.0]);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn softmax_empty_is_empty() {
        assert!(softmax(&[]).is_empty());
    }

    #[test]
    fn softmax_largest_input_is_largest_output() {
        let probs = softmax(&[0.1, 5.0, 0.3, 0.2]);
        assert!(probs[1] > probs[0]);
        assert!(probs[1] > probs[2]);
        assert!(probs[1] > probs[3]);
    }

    #[test]
    fn argmax_picks_max() {
        let idx = argmax(&[0.1, 0.4, 0.2]).expect("ok");
        assert_eq!(idx, 1);
    }

    #[test]
    fn argmax_empty_errors() {
        let err = argmax(&[]).expect_err("must fail");
        assert!(matches!(err, MlError::Postprocess(_)));
    }

    #[test]
    fn top_k_sorted_descending() {
        let r = top_k(&[0.1, 0.5, 0.3, 0.7, 0.2], 3).expect("ok");
        assert_eq!(r.len(), 3);
        assert_eq!(r[0].0, 3);
        assert_eq!(r[1].0, 1);
        assert_eq!(r[2].0, 2);
    }

    #[test]
    fn top_k_zero_returns_empty() {
        let r = top_k(&[1.0, 2.0], 0).expect("ok");
        assert!(r.is_empty());
    }

    #[test]
    fn top_k_larger_than_len_returns_everything() {
        let r = top_k(&[0.2, 0.9], 5).expect("ok");
        assert_eq!(r.len(), 2);
        assert_eq!(r[0].0, 1);
        assert_eq!(r[1].0, 0);
    }

    #[test]
    fn top_k_empty_errors() {
        let err = top_k(&[], 3).expect_err("must fail");
        assert!(matches!(err, MlError::Postprocess(_)));
    }

    #[test]
    fn sigmoid_zero_is_half() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn sigmoid_slice_matches() {
        let v = sigmoid_slice(&[-10.0, 0.0, 10.0]);
        assert!(v[0] < 0.001);
        assert!((v[1] - 0.5).abs() < 1e-6);
        assert!(v[2] > 0.999);
    }

    #[test]
    fn bbox_xywh_center_round_trip() {
        let b = BoundingBox::from_xywh_center(10.0, 20.0, 4.0, 8.0);
        assert!((b.x0 - 8.0).abs() < 1e-5);
        assert!((b.y0 - 16.0).abs() < 1e-5);
        assert!((b.x1 - 12.0).abs() < 1e-5);
        assert!((b.y1 - 24.0).abs() < 1e-5);
        assert!((b.area() - 32.0).abs() < 1e-5);
    }

    #[test]
    fn bbox_negative_extent_has_zero_area() {
        let b = BoundingBox::new(5.0, 5.0, 2.0, 2.0);
        assert_eq!(b.width(), 0.0);
        assert_eq!(b.height(), 0.0);
        assert_eq!(b.area(), 0.0);
    }

    #[test]
    fn iou_identical_boxes_is_one() {
        let b = BoundingBox::new(0.0, 0.0, 10.0, 10.0);
        assert!((iou(&b, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn iou_zero_area_returns_zero() {
        let a = BoundingBox::new(0.0, 0.0, 0.0, 0.0);
        let b = BoundingBox::new(0.0, 0.0, 10.0, 10.0);
        assert_eq!(iou(&a, &b), 0.0);
    }

    #[test]
    fn iou_disjoint_boxes_is_zero() {
        let a = BoundingBox::new(0.0, 0.0, 10.0, 10.0);
        let b = BoundingBox::new(50.0, 50.0, 60.0, 60.0);
        assert_eq!(iou(&a, &b), 0.0);
    }

    #[test]
    fn nms_handles_length_mismatch() {
        let boxes = vec![BoundingBox::new(0.0, 0.0, 1.0, 1.0)];
        let scores = vec![0.9_f32, 0.8];
        assert!(nms(&boxes, &scores, 0.5).is_empty());
    }

    #[test]
    fn nms_empty_input_is_empty() {
        assert!(nms(&[], &[], 0.5).is_empty());
    }

    #[test]
    fn nms_suppresses_overlapping_box() {
        let a = BoundingBox::new(0.0, 0.0, 10.0, 10.0);
        let b = BoundingBox::new(1.0, 1.0, 11.0, 11.0);
        let c = BoundingBox::new(50.0, 50.0, 60.0, 60.0);
        let kept = nms(&[a, b, c], &[0.9_f32, 0.8, 0.7], 0.5);
        assert_eq!(kept, vec![0, 2]);
    }

    #[test]
    fn nms_skips_zero_area_boxes() {
        let flat = BoundingBox::new(0.0, 0.0, 0.0, 0.0);
        let real = BoundingBox::new(0.0, 0.0, 1.0, 1.0);
        let kept = nms(&[flat, real], &[0.9_f32, 0.5], 0.5);
        assert_eq!(kept, vec![1]);
    }

    #[test]
    fn l2_normalize_unit_vector_idempotent() {
        let mut v = vec![3.0_f32, 4.0];
        l2_normalize(&mut v);
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
        // Re-normalising does nothing.
        l2_normalize(&mut v);
        let norm2: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm2 - 1.0).abs() < 1e-5);
    }

    #[test]
    fn l2_normalize_zero_vector_untouched() {
        let mut v = vec![0.0_f32; 3];
        l2_normalize(&mut v);
        assert_eq!(v, vec![0.0_f32; 3]);
    }

    #[test]
    fn cosine_similarity_orthogonal_zero() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_identical_is_one() {
        let a = [1.0_f32, 2.0, 3.0];
        assert!((cosine_similarity(&a, &a) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_zero_norm_is_zero() {
        let a = [0.0_f32, 0.0];
        let b = [1.0_f32, 2.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn cosine_similarity_length_mismatch_zero() {
        let a = [1.0_f32, 2.0];
        let b = [1.0_f32, 2.0, 3.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }
}