pico_detect/detect/
clusterize.rs

1use crate::geometry::{intersection_over_union, Square, Target};
2use crate::traits::Region;
3
4use super::detection::Detection;
5
6use nalgebra::Point2;
7
8/// Clustering parameters for object detection results.
9#[derive(Debug, Clone, Copy, PartialEq)]
10pub struct Clusterizer {
11    pub intersection_threshold: f32,
12    pub score_threshold: f32,
13}
14
15impl Clusterizer {
16    /// Set the intersection threshold for clustering.
17    #[inline]
18    pub fn intersection_threshold(self, value: f32) -> Self {
19        Self {
20            intersection_threshold: value,
21            ..self
22        }
23    }
24
25    /// Set the score threshold for clustering.
26    #[inline]
27    pub fn score_threshold(self, value: f32) -> Self {
28        Self {
29            score_threshold: value,
30            ..self
31        }
32    }
33
34    /// Run clustering on the provided detection data.
35    #[inline]
36    pub fn clusterize(&self, data: &mut [Detection<Square>], dest: &mut Vec<Detection<Target>>) {
37        clusterize(
38            data,
39            self.intersection_threshold,
40            self.score_threshold,
41            dest,
42        );
43    }
44}
45
46impl Default for Clusterizer {
47    /// Create a default clusterizer with intersection threshold of 0.7
48    /// and score threshold of 0.0.
49    #[inline]
50    fn default() -> Self {
51        Self {
52            intersection_threshold: 0.7,
53            score_threshold: 0.0,
54        }
55    }
56}
57
58/// Clusterize detection results based on intersection and score thresholds.
59///
60/// ### Arguments
61///
62/// * `data` -- mutable slice of detection data to clusterize;
63/// * `intersection_threshold` -- threshold for intersection over union;
64/// * `score_threshold` -- threshold for detection score;
65/// * `dest` -- destination vector to store clustered detections.
66#[inline]
67pub fn clusterize<R: Region + Copy>(
68    data: &mut [Detection<R>],
69    intersection_threshold: f32,
70    score_threshold: f32,
71    dest: &mut Vec<Detection<Target>>,
72) {
73    data.sort_by(|a, b| b.partial_cmp(a).unwrap());
74
75    let mut assignments = vec![false; data.len()];
76
77    for (i, det1) in data.iter().enumerate() {
78        if assignments[i] {
79            continue;
80        } else {
81            assignments[i] = true;
82        }
83
84        let mut point = det1.region.top_left();
85        let mut size = det1.region.width();
86
87        let mut score = det1.score;
88        let mut count: usize = 1;
89
90        for (det2, j) in data[(i + 1)..].iter().zip((i + 1)..) {
91            if let Some(value) = intersection_over_union(det1.region, det2.region) {
92                if value > intersection_threshold {
93                    assignments[j] = true;
94
95                    point += det2.region.top_left().coords;
96                    size += det2.region.width();
97
98                    score += det2.score * value;
99                    count += 1;
100                }
101            }
102        }
103
104        if score > score_threshold {
105            let scale = (count as f32).recip();
106
107            let size = (size as f32) * scale;
108
109            let mut point: Point2<f32> = point.cast();
110
111            point.coords.scale_mut(scale);
112            point.coords.add_scalar_mut(size / 2.0);
113
114            dest.push(Detection {
115                region: Target { point, size },
116                score,
117            });
118        }
119    }
120}