use std::collections::HashMap;
use std::collections::VecDeque;

use yscv_tensor::Tensor;

use crate::TrackError;
8
/// Computes appearance embeddings from image crops for track re-identification.
///
/// The trait requires `Send + Sync`, so a single extractor can be shared between threads.
pub trait ReIdExtractor: Send + Sync {
    /// Computes the embedding vector for `crop`.
    ///
    /// NOTE(review): the expected crop layout is not fixed by this trait; the
    /// colour-histogram implementation in this module treats the innermost axis
    /// as channels (channels-last). Confirm for other implementors.
    ///
    /// # Errors
    /// Returns a [`TrackError`] when the embedding cannot be computed.
    fn extract(&self, crop: &Tensor) -> Result<Vec<f32>, TrackError>;

    /// Expected length of the vectors returned by [`ReIdExtractor::extract`].
    fn dim(&self) -> usize;
}
17
18pub struct ColorHistogramReId {
21 bins: usize,
22}
23
24impl ColorHistogramReId {
25 pub fn new(bins: usize) -> Self {
26 Self { bins }
27 }
28}
29
30impl ReIdExtractor for ColorHistogramReId {
31 fn extract(&self, crop: &Tensor) -> Result<Vec<f32>, TrackError> {
32 let shape = crop.shape();
34 let c = *shape.last().unwrap_or(&3);
35 let data = crop.data();
36 let pixels = data.len() / c;
37 let mut hist = vec![0.0f32; self.bins * c];
38 for px in 0..pixels {
39 for ch in 0..c {
40 let val = data[px * c + ch].clamp(0.0, 1.0);
41 let bin = ((val * self.bins as f32) as usize).min(self.bins - 1);
42 hist[ch * self.bins + bin] += 1.0;
43 }
44 }
45 let norm: f32 = hist.iter().map(|v| v * v).sum::<f32>().sqrt();
47 if norm > 0.0 {
48 hist.iter_mut().for_each(|v| *v /= norm);
49 }
50 Ok(hist)
51 }
52
53 fn dim(&self) -> usize {
54 self.bins * 3
55 }
56}
57
58pub struct ReIdGallery {
61 features: HashMap<u64, Vec<Vec<f32>>>,
62 max_features: usize,
63}
64
65impl ReIdGallery {
66 pub fn new(max_features: usize) -> Self {
68 Self {
69 features: HashMap::new(),
70 max_features,
71 }
72 }
73
74 pub fn update(&mut self, track_id: u64, feature: Vec<f32>) {
77 let entry = self.features.entry(track_id).or_default();
78 entry.push(feature);
79 if entry.len() > self.max_features {
80 entry.remove(0);
81 }
82 }
83
84 pub fn remove(&mut self, track_id: u64) {
86 self.features.remove(&track_id);
87 }
88
89 pub fn min_cosine_distance(&self, track_id: u64, feature: &[f32]) -> f32 {
93 match self.features.get(&track_id) {
94 Some(gallery) if !gallery.is_empty() => gallery
95 .iter()
96 .map(|g| cosine_distance(feature, g))
97 .fold(f32::INFINITY, f32::min),
98 _ => 1.0,
99 }
100 }
101
102 pub fn cost_matrix(&self, track_ids: &[u64], features: &[Vec<f32>]) -> Vec<Vec<f32>> {
106 track_ids
107 .iter()
108 .map(|&tid| {
109 features
110 .iter()
111 .map(|f| self.min_cosine_distance(tid, f))
112 .collect()
113 })
114 .collect()
115 }
116}
117
/// Cosine distance `1 - cos(a, b)` between two embeddings, in `[0, 2]`.
///
/// Vectors of different lengths are compared as if the shorter one were
/// zero-padded. The dot product covers the common prefix, but each norm is
/// taken over the whole vector, so extra trailing components still count
/// instead of being silently dropped.
///
/// Returns `1.0` (the distance of orthogonal vectors) when either vector has
/// (near-)zero norm, because no direction can be compared.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    let denom = norm_a * norm_b;
    if denom < 1e-12 {
        return 1.0;
    }
    // Rounding can push the cosine marginally outside [-1, 1]; keep the
    // distance inside its mathematical range.
    (1.0 - dot / denom).clamp(0.0, 2.0)
}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn color_histogram_reid_basic() {
        let extractor = ColorHistogramReId::new(8);
        assert_eq!(extractor.dim(), 24);

        // Four RGB pixels laid out as a 2x2x3 (channels-last) crop.
        let pixels = vec![
            0.1, 0.2, 0.3, //
            0.4, 0.5, 0.6, //
            0.7, 0.8, 0.9, //
            0.0, 0.1, 0.2,
        ];
        let crop = Tensor::from_vec(vec![2, 2, 3], pixels).unwrap();
        let embedding = extractor.extract(&crop).unwrap();
        assert_eq!(embedding.len(), extractor.dim());

        // The histogram must come back L2-normalised.
        let norm = embedding.iter().fold(0.0_f32, |acc, v| acc + v * v).sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Expected unit norm, got {norm}");
    }

    #[test]
    fn reid_gallery_update_and_distance() {
        let x_axis: [f32; 3] = [1.0, 0.0, 0.0];
        let y_axis: [f32; 3] = [0.0, 1.0, 0.0];
        let mut gallery = ReIdGallery::new(10);

        // No history yet: the neutral distance is reported.
        assert!((gallery.min_cosine_distance(1, &x_axis) - 1.0).abs() < 1e-6);

        gallery.update(1, x_axis.to_vec());

        let dist = gallery.min_cosine_distance(1, &x_axis);
        assert!(
            dist < 1e-5,
            "Expected ~0 distance for identical vectors, got {dist}"
        );

        let dist = gallery.min_cosine_distance(1, &y_axis);
        assert!(
            (dist - 1.0).abs() < 1e-5,
            "Expected ~1 distance for orthogonal vectors, got {dist}"
        );

        // Once a matching embedding is stored, the minimum drops to ~0.
        gallery.update(1, y_axis.to_vec());
        let dist = gallery.min_cosine_distance(1, &y_axis);
        assert!(
            dist < 1e-5,
            "Expected ~0 after adding matching feature, got {dist}"
        );

        // Removing the track restores the neutral distance.
        gallery.remove(1);
        assert!((gallery.min_cosine_distance(1, &x_axis) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn reid_gallery_cost_matrix() {
        // Unit basis vectors e0, e1, e2 of R^3.
        let basis: Vec<Vec<f32>> = (0..3)
            .map(|axis| (0..3).map(|i| if i == axis { 1.0 } else { 0.0 }).collect())
            .collect();

        let mut gallery = ReIdGallery::new(10);
        gallery.update(10, basis[0].clone());
        gallery.update(20, basis[1].clone());

        let track_ids = [10, 20];
        let matrix = gallery.cost_matrix(&track_ids, &basis);

        assert_eq!(matrix.len(), 2);
        assert_eq!(matrix[0].len(), 3);
        assert_eq!(matrix[1].len(), 3);

        // Each track matches its own axis and is orthogonal to the other one.
        assert!(matrix[0][0] < 1e-5);
        assert!((matrix[0][1] - 1.0).abs() < 1e-5);
        assert!((matrix[1][0] - 1.0).abs() < 1e-5);
        assert!(matrix[1][1] < 1e-5);
    }
}