use std::collections::HashMap;
/// Distance function used to compare two feature vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Metric {
    /// Straight-line (L2) distance between the two vectors.
    Euclidean,
    /// One minus the cosine similarity of the two vectors.
    Cosine,
}
/// Nearest-neighbour distance metric over per-track feature galleries.
///
/// For every track id it keeps a list of previously observed feature vectors
/// and scores a new feature by its smallest distance to any stored sample.
#[derive(Debug, Clone)]
pub struct NearestNeighborDistanceMetric {
    /// Distance function applied between a stored sample and a query feature.
    metric: Metric,
    /// Threshold handed back by `matching_threshold()`; this type only stores it.
    matching_threshold: f32,
    /// Maximum number of samples kept per track; `None` means unlimited.
    budget: Option<usize>,
    /// Stored feature vectors, keyed by track id, oldest first.
    samples: HashMap<u64, Vec<Vec<f32>>>,
}
impl NearestNeighborDistanceMetric {
    /// Creates a metric with an empty sample gallery.
    ///
    /// * `metric` - distance function used between samples and query features.
    /// * `matching_threshold` - value returned by [`Self::matching_threshold`];
    ///   it is stored but never applied by this type.
    /// * `budget` - maximum number of samples kept per track, or `None` for no
    ///   limit.
    pub fn new(metric: Metric, matching_threshold: f32, budget: Option<usize>) -> Self {
        Self {
            metric,
            matching_threshold,
            budget,
            samples: HashMap::new(),
        }
    }

    /// Adds new observations to the gallery and forgets inactive tracks.
    ///
    /// Each `(track_id, feature)` pair whose id appears in `active_targets` is
    /// appended to that track's sample list. When a budget is set, the oldest
    /// samples are dropped so that at most `budget` remain. Afterwards every
    /// track whose id is not in `active_targets` is removed from the gallery.
    pub fn partial_fit(&mut self, features: &[(u64, Vec<f32>)], active_targets: &[u64]) {
        // Build the membership set once: a slice `contains` per stored track
        // would make the pruning step O(tracks * active_targets).
        let active: std::collections::HashSet<u64> = active_targets.iter().copied().collect();

        for (track_id, feature) in features {
            // Samples for inactive tracks would be pruned below anyway, so
            // skip them instead of cloning the feature and creating an entry.
            if !active.contains(track_id) {
                continue;
            }
            let sample_list = self.samples.entry(*track_id).or_default();
            sample_list.push(feature.clone());
            if let Some(budget) = self.budget {
                if sample_list.len() > budget {
                    // Drop from the front so the most recent samples survive.
                    let excess = sample_list.len() - budget;
                    sample_list.drain(..excess);
                }
            }
        }

        self.samples.retain(|track_id, _| active.contains(track_id));
    }

    /// Builds the cost matrix between `targets` (rows) and `features` (columns).
    ///
    /// Entry `[i][j]` is the smallest distance between `features[j]` and any
    /// sample stored for `targets[i]`. A target with no gallery entry yields a
    /// row filled with `f32::MAX`.
    pub fn distance(&self, features: &[Vec<f32>], targets: &[u64]) -> Vec<Vec<f32>> {
        targets
            .iter()
            .map(|track_id| match self.samples.get(track_id) {
                Some(sample_list) => features
                    .iter()
                    .map(|feature| self.compute_min_distance(sample_list, feature))
                    .collect(),
                None => vec![f32::MAX; features.len()],
            })
            .collect()
    }

    /// Returns the smallest distance from `feature` to any of `samples`.
    ///
    /// Yields `f32::MAX` when `samples` is empty. A NaN distance never compares
    /// smaller than the running minimum, so it is ignored.
    fn compute_min_distance(&self, samples: &[Vec<f32>], feature: &[f32]) -> f32 {
        samples
            .iter()
            .map(|sample| match self.metric {
                Metric::Euclidean => euclidean_distance(sample, feature),
                Metric::Cosine => cosine_distance(sample, feature),
            })
            .fold(f32::MAX, |best, dist| if dist < best { dist } else { best })
    }

    /// Returns the matching threshold supplied at construction.
    pub fn matching_threshold(&self) -> f32 {
        self.matching_threshold
    }
}
/// Computes the Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length, only
/// the common prefix contributes to the result.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum = a
        .iter()
        .zip(b)
        .fold(0.0_f32, |acc, (p, q)| acc + (p - q).powi(2));
    squared_sum.sqrt()
}
/// Computes the cosine distance `1 - cos(a, b)`, a value in `[0.0, 2.0]`.
///
/// Returns `1.0` (similarity treated as zero) when either vector has a norm of
/// at most `1e-6`, or when the similarity cannot be computed because an input
/// contains NaN or an infinity.
///
/// The dot product pairs components positionally, so for slices of different
/// length it covers only the common prefix, while each norm is taken over its
/// full slice.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate in f64: the square of any finite f32 fits comfortably, so
    // large-magnitude features cannot overflow to infinity and turn the
    // similarity into NaN.
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_a = a.iter().map(|&x| f64::from(x).powi(2)).sum::<f64>().sqrt();
    let norm_b = b.iter().map(|&x| f64::from(x).powi(2)).sum::<f64>().sqrt();

    // A NaN norm fails both comparisons and falls through to the neutral 0.0.
    let raw_similarity = if norm_a > 1e-6 && norm_b > 1e-6 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    };

    // Infinite inputs still produce NaN here (inf / inf). Map that to the
    // neutral similarity rather than letting it collapse into a distance of
    // 0.0, which would read as a perfect match. Rounding can also push the
    // ratio marginally outside [-1, 1], so pin it to the valid range.
    let similarity = if raw_similarity.is_finite() {
        raw_similarity.max(-1.0).min(1.0)
    } else {
        0.0
    };

    // The value lies in [0, 2], so narrowing back to f32 cannot overflow.
    (1.0 - similarity) as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is strictly closer than `tolerance` to `expected`.
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {} within {}, got {}",
            expected,
            tolerance,
            actual
        );
    }

    #[test]
    fn test_euclidean() {
        // 3-4-5 right triangle.
        assert_close(euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]), 5.0, 1e-5);
    }

    #[test]
    fn test_cosine() {
        let unit_x = [1.0_f32, 0.0];
        let unit_y = [0.0_f32, 1.0];
        // Orthogonal vectors are at distance one, identical vectors at zero.
        assert_close(cosine_distance(&unit_x, &unit_y), 1.0, 1e-5);
        assert_close(cosine_distance(&unit_x, &unit_x), 0.0, 1e-5);
    }

    #[test]
    fn test_cosine_parallel() {
        // Same direction, different magnitude.
        assert!(cosine_distance(&[1.0, 1.0], &[2.0, 2.0]) < 0.01);
    }

    #[test]
    fn test_cosine_opposite() {
        assert_close(cosine_distance(&[1.0, 0.0], &[-1.0, 0.0]), 2.0, 0.01);
    }

    #[test]
    fn test_cosine_zero_norm() {
        // A zero vector has no direction; the similarity falls back to zero.
        assert_close(cosine_distance(&[0.0, 0.0], &[1.0, 1.0]), 1.0, 0.01);
    }

    #[test]
    fn test_metric_budget() {
        let mut metric = NearestNeighborDistanceMetric::new(Metric::Euclidean, 0.5, Some(2));
        let observations: Vec<(u64, Vec<f32>)> =
            vec![(1, vec![1.0]), (1, vec![2.0]), (1, vec![3.0])];
        metric.partial_fit(&observations, &[1]);

        // Only the two most recent samples survive the budget of two.
        let stored = metric
            .samples
            .get(&1)
            .expect("track 1 should have samples");
        assert_eq!(stored.len(), 2);
        assert_eq!(stored[0], vec![2.0]);
        assert_eq!(stored[1], vec![3.0]);
    }

    #[test]
    fn test_metric_no_budget() {
        let mut metric = NearestNeighborDistanceMetric::new(Metric::Cosine, 0.3, None);
        let observations: Vec<(u64, Vec<f32>)> = vec![
            (1, vec![1.0, 0.0]),
            (1, vec![0.9, 0.1]),
            (1, vec![0.8, 0.2]),
            (1, vec![0.7, 0.3]),
        ];
        metric.partial_fit(&observations, &[1]);

        let stored = metric
            .samples
            .get(&1)
            .expect("track 1 should have samples");
        assert_eq!(stored.len(), 4);
    }

    #[test]
    fn test_metric_inactive_removal() {
        let mut metric = NearestNeighborDistanceMetric::new(Metric::Euclidean, 0.5, Some(10));

        metric.partial_fit(&[(1, vec![1.0]), (2, vec![2.0])], &[1, 2]);
        assert!(metric.samples.contains_key(&1));
        assert!(metric.samples.contains_key(&2));

        // Track 2 is no longer listed as active, so its samples are dropped.
        metric.partial_fit(&[(1, vec![1.5])], &[1]);
        assert!(metric.samples.contains_key(&1));
        assert!(!metric.samples.contains_key(&2));
    }

    #[test]
    fn test_distance_matrix() {
        let mut metric = NearestNeighborDistanceMetric::new(Metric::Euclidean, 0.5, Some(10));
        metric.partial_fit(&[(1, vec![0.0, 0.0])], &[1]);

        let queries = vec![vec![0.0, 0.0], vec![3.0, 4.0]];
        let cost = metric.distance(&queries, &[1]);

        // One row per target, one column per query feature.
        assert_eq!(cost.len(), 1);
        assert_eq!(cost[0].len(), 2);
        assert!(cost[0][0] < 0.01);
        assert_close(cost[0][1], 5.0, 0.01);
    }

    #[test]
    fn test_distance_no_samples() {
        let metric = NearestNeighborDistanceMetric::new(Metric::Euclidean, 0.5, Some(10));
        let queries = vec![vec![0.0, 0.0]];

        let cost = metric.distance(&queries, &[1]);

        // An unknown track produces the sentinel cost.
        assert_eq!(cost.len(), 1);
        assert_eq!(cost[0][0], f32::MAX);
    }

    #[test]
    fn test_matching_threshold() {
        let metric = NearestNeighborDistanceMetric::new(Metric::Cosine, 0.25, Some(10));
        assert_close(metric.matching_threshold(), 0.25, 1e-6);
    }

    #[test]
    fn test_min_distance_multiple_samples() {
        let mut metric = NearestNeighborDistanceMetric::new(Metric::Euclidean, 0.5, Some(10));
        let observations: Vec<(u64, Vec<f32>)> = vec![
            (1, vec![0.0, 0.0]),
            (1, vec![10.0, 0.0]),
            (1, vec![5.0, 0.0]),
        ];
        metric.partial_fit(&observations, &[1]);

        // The closest stored sample is the origin, one unit away.
        let queries = vec![vec![1.0, 0.0]];
        let cost = metric.distance(&queries, &[1]);
        assert_close(cost[0][0], 1.0, 0.01);
    }
}