pico_detect/detect/
clusterize.rs1use crate::geometry::{intersection_over_union, Square, Target};
2use crate::traits::Region;
3
4use super::detection::Detection;
5
6use nalgebra::Point2;
7
/// Configuration for grouping overlapping detections into clusters.
///
/// Both thresholds are compared with a strict `>`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Clusterizer {
    /// A detection joins a cluster only when its intersection-over-union with the
    /// cluster's first (highest-ranked) detection is strictly greater than this value.
    pub intersection_threshold: f32,
    /// A cluster is emitted only when its accumulated score is strictly greater than
    /// this value.
    pub score_threshold: f32,
}
14
15impl Clusterizer {
16 #[inline]
18 pub fn intersection_threshold(self, value: f32) -> Self {
19 Self {
20 intersection_threshold: value,
21 ..self
22 }
23 }
24
25 #[inline]
27 pub fn score_threshold(self, value: f32) -> Self {
28 Self {
29 score_threshold: value,
30 ..self
31 }
32 }
33
34 #[inline]
36 pub fn clusterize(&self, data: &mut [Detection<Square>], dest: &mut Vec<Detection<Target>>) {
37 clusterize(
38 data,
39 self.intersection_threshold,
40 self.score_threshold,
41 dest,
42 );
43 }
44}
45
46impl Default for Clusterizer {
47 #[inline]
50 fn default() -> Self {
51 Self {
52 intersection_threshold: 0.7,
53 score_threshold: 0.0,
54 }
55 }
56}
57
58#[inline]
67pub fn clusterize<R: Region + Copy>(
68 data: &mut [Detection<R>],
69 intersection_threshold: f32,
70 score_threshold: f32,
71 dest: &mut Vec<Detection<Target>>,
72) {
73 data.sort_by(|a, b| b.partial_cmp(a).unwrap());
74
75 let mut assignments = vec![false; data.len()];
76
77 for (i, det1) in data.iter().enumerate() {
78 if assignments[i] {
79 continue;
80 } else {
81 assignments[i] = true;
82 }
83
84 let mut point = det1.region.top_left();
85 let mut size = det1.region.width();
86
87 let mut score = det1.score;
88 let mut count: usize = 1;
89
90 for (det2, j) in data[(i + 1)..].iter().zip((i + 1)..) {
91 if let Some(value) = intersection_over_union(det1.region, det2.region) {
92 if value > intersection_threshold {
93 assignments[j] = true;
94
95 point += det2.region.top_left().coords;
96 size += det2.region.width();
97
98 score += det2.score * value;
99 count += 1;
100 }
101 }
102 }
103
104 if score > score_threshold {
105 let scale = (count as f32).recip();
106
107 let size = (size as f32) * scale;
108
109 let mut point: Point2<f32> = point.cast();
110
111 point.coords.scale_mut(scale);
112 point.coords.add_scalar_mut(size / 2.0);
113
114 dest.push(Detection {
115 region: Target { point, size },
116 score,
117 });
118 }
119 }
120}