use crate::Errors;
use anyhow::Result;
use itertools::Itertools;
use nalgebra::{Dynamic, OMatrix};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::marker::PhantomData;
pub mod store;
pub mod voting;
/// One entry produced by [`Track::distances`]: the id of the *other* track paired with the
/// distance returned by [`Metric::distance`] (or the error it produced).
pub type TrackDistance = (u64, Result<f32>);
/// A feature stored as a dynamically sized `f32` nalgebra matrix.
pub type Feature = OMatrix<f32, Dynamic, Dynamic>;
/// A feature paired with its quality/confidence score in `.0`; [`feat_confidence_cmp`]
/// orders observations by that score.
pub type FeatureSpec = (f32, Feature);
/// A track's observations grouped by feature class id.
pub type FeatureObservationsGroups = HashMap<u64, Vec<FeatureSpec>>;
/// Distance function and observation-list maintenance policy used by [`Track`].
pub trait Metric {
    /// Computes the distance between two observations of the same `feature_class`.
    ///
    /// [`Track::distances`] calls this for every pair in the cartesian product of the two
    /// tracks' observations of that class.
    fn distance(feature_class: u64, e1: &FeatureSpec, e2: &FeatureSpec) -> Result<f32>;
    /// Post-processes the observation list of `feature_class` after it may have grown.
    ///
    /// Called by [`Track::add_observation`] and [`Track::merge`]. `observations` may be
    /// rewritten in place (e.g. sorted or truncated). `prev_length` is the list length before
    /// new observations were appended; it equals the current length when nothing was
    /// appended. `merge_history` is the owning track's merge history (seeded with the track
    /// id by [`Track::new`]).
    ///
    /// Returning `Err` makes the calling `Track` method restore its attributes, observations
    /// and metric to their state before that call.
    fn optimize(
        &mut self,
        feature_class: &u64,
        merge_history: &[u64],
        observations: &mut Vec<FeatureSpec>,
        prev_length: usize,
    ) -> Result<()>;
}
/// Readiness verdict returned by [`AttributeMatch::baked`].
///
/// What each variant means is decided by the `AttributeMatch` implementation. The enum is a
/// plain fieldless tag, so it is `Copy` and comparable, which lets callers write
/// `status == TrackBakingStatus::Ready` and print it in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrackBakingStatus {
    /// The implementation considers the track ready.
    Ready,
    /// The implementation considers the track not ready yet.
    Pending,
    /// The implementation considers the track wasted.
    Wasted,
}
/// Attribute behaviour a [`Track`] relies on for compatibility checks and merging.
pub trait AttributeMatch<A> {
    /// Returns whether `self` may be compared with `other`.
    ///
    /// [`Track::distances`] returns `Errors::IncompatibleAttributes` when this is `false`.
    fn compatible(&self, other: &A) -> bool;
    /// Merges `other` into `self`.
    ///
    /// Called by [`Track::merge`]; on `Err` the track restores the attributes it had before
    /// the call, so a failing implementation may leave `self` partially modified.
    fn merge(&mut self, other: &A) -> Result<()>;
    /// Reports a [`TrackBakingStatus`] for the track given its grouped observations.
    ///
    /// `Track` itself never calls this; the meaning of each status is up to the implementation.
    fn baked(&self, observations: &FeatureObservationsGroups) -> Result<TrackBakingStatus>;
}
/// An update that can be applied to attributes of type `A`.
pub trait AttributeUpdate<A> {
    /// Applies the update to `attrs` in place.
    ///
    /// Used by [`Track::add_observation`]; on `Err` the track restores the attributes it had
    /// before the call, so a failing implementation may leave `attrs` partially modified.
    fn apply(&self, attrs: &mut A) -> Result<()>;
}
/// Orders observations by confidence, highest first; intended for `sort_by`.
///
/// NaN confidences compare equal to each other and after every real value. Without this, a
/// single NaN score would panic in `partial_cmp(..).unwrap()`. The comparator also stays a
/// consistent total order, as `sort_by` requires. Real values order exactly as before.
pub fn feat_confidence_cmp(e1: &FeatureSpec, e2: &FeatureSpec) -> Ordering {
    e2.0
        .partial_cmp(&e1.0)
        // `None` only when at least one side is NaN: push NaN to the end.
        .unwrap_or_else(|| e1.0.is_nan().cmp(&e2.0.is_nan()))
}
#[derive(Default, Clone)]
pub struct Track<A, M, U>
where
A: Default + AttributeMatch<A> + Send + Sync + Clone,
M: Metric + Default + Send + Sync + Clone,
U: AttributeUpdate<A> + Send + Sync,
{
attributes: A,
track_id: u64,
observations: FeatureObservationsGroups,
metric: M,
phantom_attribute_update: PhantomData<U>,
merge_history: Vec<u64>,
}
impl<A, M, U> Track<A, M, U>
where
A: Default + AttributeMatch<A> + Send + Sync + Clone,
M: Metric + Default + Send + Sync + Clone,
U: AttributeUpdate<A> + Send + Sync,
{
pub fn new(track_id: u64, metric: Option<M>, attributes: Option<A>) -> Self {
Self {
attributes: if let Some(attributes) = attributes {
attributes
} else {
A::default()
},
track_id,
observations: Default::default(),
metric: if let Some(m) = metric {
m
} else {
M::default()
},
phantom_attribute_update: Default::default(),
merge_history: vec![track_id],
}
}
pub fn get_track_id(&self) -> u64 {
self.track_id
}
pub fn get_attributes(&self) -> &A {
&self.attributes
}
pub fn get_feature_classes(&self) -> Vec<u64> {
self.observations.keys().cloned().collect()
}
fn update_attributes(&mut self, update: U) -> Result<()> {
update.apply(&mut self.attributes)
}
pub fn add_observation(
&mut self,
feature_class: u64,
feature_q: f32,
feature: Feature,
attribute_update: U,
) -> Result<()> {
let last_attributes = self.attributes.clone();
let last_observations = self.observations.clone();
let last_metric = self.metric.clone();
let res = self.update_attributes(attribute_update);
if res.is_err() {
self.attributes = last_attributes;
res?;
unreachable!();
}
match self.observations.get_mut(&feature_class) {
None => {
self.observations
.insert(feature_class, vec![(feature_q, feature)]);
}
Some(observations) => {
observations.push((feature_q, feature));
}
}
let observations = self.observations.get_mut(&feature_class).unwrap();
let prev_length = observations.len() - 1;
let res = self.metric.optimize(
&feature_class,
&self.merge_history,
observations,
prev_length,
);
if res.is_err() {
self.attributes = last_attributes;
self.observations = last_observations;
self.metric = last_metric;
res?;
unreachable!();
}
Ok(())
}
pub fn merge(&mut self, other: &Self, classes: &[u64]) -> Result<()> {
let last_attributes = self.attributes.clone();
let res = self.attributes.merge(&other.attributes);
if res.is_err() {
self.attributes = last_attributes;
res?;
unreachable!();
}
let last_observations = self.observations.clone();
let last_metric = self.metric.clone();
for cls in classes {
let dest = self.observations.get_mut(cls);
let src = other.observations.get(cls);
let prev_length = match (dest, src) {
(Some(dest_observations), Some(src_observations)) => {
let prev_length = dest_observations.len();
dest_observations.extend(src_observations.iter().cloned());
Some(prev_length)
}
(None, Some(src_observations)) => {
self.observations.insert(*cls, src_observations.clone());
Some(0)
}
(Some(dest_observations), None) => {
let prev_length = dest_observations.len();
Some(prev_length)
}
_ => None,
};
if let Some(prev_length) = prev_length {
let res = self.metric.optimize(
cls,
&self.merge_history,
self.observations.get_mut(cls).unwrap(),
prev_length,
);
if res.is_err() {
self.attributes = last_attributes;
self.observations = last_observations;
self.metric = last_metric;
res?;
unreachable!();
}
}
}
Ok(())
}
pub fn distances(&self, other: &Self, feature_class: u64) -> Result<Vec<TrackDistance>> {
if !self.attributes.compatible(&other.attributes) {
Err(Errors::IncompatibleAttributes.into())
} else {
match (
self.observations.get(&feature_class),
other.observations.get(&feature_class),
) {
(Some(left), Some(right)) => Ok(left
.iter()
.cartesian_product(right.iter())
.map(|(l, r)| (other.track_id, M::distance(feature_class, l, r)))
.collect()),
_ => Err(Errors::ObservationForClassNotFound(
self.track_id,
other.track_id,
feature_class,
)
.into()),
}
}
}
}
#[cfg(test)]
mod tests {
    use crate::distance::euclidean;
    use crate::track::{
        feat_confidence_cmp, AttributeMatch, AttributeUpdate, Feature, FeatureObservationsGroups,
        FeatureSpec, Metric, Track, TrackBakingStatus,
    };
    use crate::EPS;
    use anyhow::Result;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Attributes that are always compatible and always merge successfully.
    #[derive(Default, Clone)]
    pub struct DefaultAttrs;

    /// Attribute update that changes nothing.
    #[derive(Default)]
    pub struct DefaultAttrUpdates;

    impl AttributeUpdate<DefaultAttrs> for DefaultAttrUpdates {
        fn apply(&self, _attrs: &mut DefaultAttrs) -> Result<()> {
            Ok(())
        }
    }

    impl AttributeMatch<DefaultAttrs> for DefaultAttrs {
        fn compatible(&self, _other: &DefaultAttrs) -> bool {
            true
        }

        fn merge(&mut self, _other: &DefaultAttrs) -> Result<()> {
            Ok(())
        }

        fn baked(&self, _observations: &FeatureObservationsGroups) -> Result<TrackBakingStatus> {
            Ok(TrackBakingStatus::Pending)
        }
    }

    /// Euclidean metric that keeps the 20 most confident observations.
    #[derive(Default, Clone)]
    struct DefaultMetric;

    impl Metric for DefaultMetric {
        fn distance(_feature_class: u64, e1: &FeatureSpec, e2: &FeatureSpec) -> Result<f32> {
            Ok(euclidean(&e1.1, &e2.1))
        }

        fn optimize(
            &mut self,
            _feature_class: &u64,
            _merge_history: &[u64],
            features: &mut Vec<FeatureSpec>,
            _prev_length: usize,
        ) -> Result<()> {
            features.sort_by(feat_confidence_cmp);
            features.truncate(20);
            Ok(())
        }
    }

    #[test]
    fn init() {
        let t1: Track<DefaultAttrs, DefaultMetric, DefaultAttrUpdates> = Track::new(3, None, None);
        assert_eq!(t1.get_track_id(), 3);
    }

    #[test]
    fn track_distances() -> Result<()> {
        let mut t1: Track<DefaultAttrs, DefaultMetric, DefaultAttrUpdates> = Track::default();
        t1.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![1f32, 0.0, 0.0]),
            DefaultAttrUpdates {},
        )?;

        let mut t2 = Track::default();
        t2.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![0f32, 1.0f32, 0.0]),
            DefaultAttrUpdates {},
        )?;

        let dists = t1.distances(&t1, 0);
        let dists = dists.unwrap();
        assert_eq!(dists.len(), 1);
        assert!(*dists[0].1.as_ref().unwrap() < EPS);

        let dists = t1.distances(&t2, 0);
        let dists = dists.unwrap();
        assert_eq!(dists.len(), 1);
        assert!((*dists[0].1.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);

        t2.add_observation(
            0,
            0.2,
            Feature::from_vec(1, 3, vec![1f32, 1.0f32, 0.0]),
            DefaultAttrUpdates {},
        )?;
        assert_eq!(t2.observations.get(&0).unwrap().len(), 2);

        let dists = t1.distances(&t2, 0);
        let dists = dists.unwrap();
        assert_eq!(dists.len(), 2);
        assert!((*dists[0].1.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);
        assert!((*dists[1].1.as_ref().unwrap() - 1.0).abs() < EPS);
        Ok(())
    }

    #[test]
    fn merge_same() -> Result<()> {
        let mut t1: Track<DefaultAttrs, DefaultMetric, DefaultAttrUpdates> = Track::default();
        t1.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![1f32, 0.0, 0.0]),
            DefaultAttrUpdates {},
        )?;

        let mut t2 = Track::default();
        t2.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![0f32, 1.0f32, 0.0]),
            DefaultAttrUpdates {},
        )?;

        let r = t1.merge(&t2, &[0]);
        assert!(r.is_ok());
        assert_eq!(t1.observations.get(&0).unwrap().len(), 2);
        Ok(())
    }

    #[test]
    fn merge_other_feature_class() -> Result<()> {
        let mut t1: Track<DefaultAttrs, DefaultMetric, DefaultAttrUpdates> = Track::default();
        t1.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![1f32, 0.0, 0.0]),
            DefaultAttrUpdates {},
        )?;

        let mut t2 = Track::default();
        t2.add_observation(
            1,
            0.3,
            Feature::from_vec(1, 3, vec![0f32, 1.0f32, 0.0]),
            DefaultAttrUpdates {},
        )?;

        let r = t1.merge(&t2, &[1]);
        assert!(r.is_ok());
        assert_eq!(t1.observations.get(&0).unwrap().len(), 1);
        assert_eq!(t1.observations.get(&1).unwrap().len(), 1);
        Ok(())
    }

    #[test]
    fn merge_optimize_failure_rolls_back() -> Result<()> {
        /// Metric that rejects any class-1 list holding more than one observation.
        #[derive(Default, Clone)]
        struct ClassOneLimitMetric;

        impl Metric for ClassOneLimitMetric {
            fn distance(_feature_class: u64, e1: &FeatureSpec, e2: &FeatureSpec) -> Result<f32> {
                Ok(euclidean(&e1.1, &e2.1))
            }

            fn optimize(
                &mut self,
                feature_class: &u64,
                _merge_history: &[u64],
                features: &mut Vec<FeatureSpec>,
                _prev_length: usize,
            ) -> Result<()> {
                if *feature_class == 1 && features.len() > 1 {
                    Err(anyhow::anyhow!("too many class 1 observations"))
                } else {
                    Ok(())
                }
            }
        }

        let mut t1: Track<DefaultAttrs, ClassOneLimitMetric, DefaultAttrUpdates> =
            Track::default();
        t1.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![1f32, 0.0, 0.0]),
            DefaultAttrUpdates {},
        )?;
        t1.add_observation(
            1,
            0.3,
            Feature::from_vec(1, 3, vec![1f32, 0.0, 0.0]),
            DefaultAttrUpdates {},
        )?;

        let mut t2 = Track::default();
        t2.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![0f32, 1.0f32, 0.0]),
            DefaultAttrUpdates {},
        )?;
        t2.add_observation(
            1,
            0.3,
            Feature::from_vec(1, 3, vec![0f32, 1.0f32, 0.0]),
            DefaultAttrUpdates {},
        )?;

        // Class 0 merges fine, class 1 fails: both classes must be back to their old contents.
        assert!(t1.merge(&t2, &[0, 1]).is_err());
        assert_eq!(t1.observations.get(&0).unwrap().len(), 1);
        assert_eq!(t1.observations.get(&1).unwrap().len(), 1);
        Ok(())
    }

    #[test]
    fn attribute_compatible_match() -> Result<()> {
        #[derive(Default, Debug, Clone)]
        pub struct TimeAttrs {
            start_time: u64,
            end_time: u64,
        }

        #[derive(Default)]
        pub struct TimeAttrUpdates {
            time: u64,
        }

        impl AttributeUpdate<TimeAttrs> for TimeAttrUpdates {
            fn apply(&self, attrs: &mut TimeAttrs) -> Result<()> {
                attrs.end_time = self.time;
                if attrs.start_time == 0 {
                    attrs.start_time = self.time;
                }
                Ok(())
            }
        }

        impl AttributeMatch<TimeAttrs> for TimeAttrs {
            fn compatible(&self, other: &TimeAttrs) -> bool {
                self.end_time <= other.start_time
            }

            fn merge(&mut self, other: &TimeAttrs) -> Result<()> {
                self.start_time = self.start_time.min(other.start_time);
                self.end_time = self.end_time.max(other.end_time);
                Ok(())
            }

            fn baked(
                &self,
                _observations: &FeatureObservationsGroups,
            ) -> Result<TrackBakingStatus> {
                let now = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_secs();
                // `saturating_sub`: an end time in the future must not underflow the u64.
                if now.saturating_sub(self.end_time) > 30 {
                    Ok(TrackBakingStatus::Ready)
                } else {
                    Ok(TrackBakingStatus::Pending)
                }
            }
        }

        #[derive(Default, Clone)]
        struct TimeMetric;

        impl Metric for TimeMetric {
            fn distance(_feature_class: u64, e1: &FeatureSpec, e2: &FeatureSpec) -> Result<f32> {
                Ok(euclidean(&e1.1, &e2.1))
            }

            fn optimize(
                &mut self,
                _feature_class: &u64,
                _merge_history: &[u64],
                features: &mut Vec<FeatureSpec>,
                _prev_length: usize,
            ) -> Result<()> {
                features.sort_by(feat_confidence_cmp);
                features.truncate(20);
                Ok(())
            }
        }

        let mut t1: Track<TimeAttrs, TimeMetric, TimeAttrUpdates> = Track::default();
        t1.track_id = 1;
        t1.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![1f32, 0.0, 0.0]),
            TimeAttrUpdates { time: 2 },
        )?;

        let mut t2 = Track::default();
        t2.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![0f32, 1.0f32, 0.0]),
            TimeAttrUpdates { time: 3 },
        )?;
        t2.track_id = 2;

        let dists = t1.distances(&t2, 0);
        let dists = dists.unwrap();
        assert_eq!(dists.len(), 1);
        assert!((*dists[0].1.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);
        assert_eq!(dists[0].0, 2);

        let mut t3 = Track::default();
        t3.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![0f32, 1.0f32, 0.0]),
            TimeAttrUpdates { time: 1 },
        )?;

        let dists = t1.distances(&t3, 0);
        assert!(dists.is_err());
        Ok(())
    }

    #[test]
    fn get_classes() -> Result<()> {
        let mut t1: Track<DefaultAttrs, DefaultMetric, DefaultAttrUpdates> = Track::default();
        t1.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![1f32, 0.0, 0.0]),
            DefaultAttrUpdates {},
        )?;
        t1.add_observation(
            1,
            0.3,
            Feature::from_vec(1, 3, vec![0f32, 1.0f32, 0.0]),
            DefaultAttrUpdates {},
        )?;

        let mut classes = t1.get_feature_classes();
        classes.sort();
        assert_eq!(classes, vec![0, 1]);
        Ok(())
    }

    #[test]
    fn attr_metric_update_recover() {
        use thiserror::Error;

        #[derive(Error, Debug)]
        enum TestError {
            #[error("Update Error")]
            UpdateError,
            #[error("MergeError")]
            MergeError,
            #[error("OptimizeError")]
            OptimizeError,
        }

        /// Counts non-ignored updates; the second counted update fails.
        #[derive(Default, Clone, PartialEq, Debug)]
        pub struct DefaultAttrs {
            pub count: u32,
        }

        #[derive(Default)]
        pub struct DefaultAttrUpdates {
            ignore: bool,
        }

        impl AttributeUpdate<DefaultAttrs> for DefaultAttrUpdates {
            fn apply(&self, attrs: &mut DefaultAttrs) -> Result<()> {
                if !self.ignore {
                    attrs.count += 1;
                    if attrs.count > 1 {
                        Err(TestError::UpdateError.into())
                    } else {
                        Ok(())
                    }
                } else {
                    Ok(())
                }
            }
        }

        impl AttributeMatch<DefaultAttrs> for DefaultAttrs {
            fn compatible(&self, _other: &DefaultAttrs) -> bool {
                true
            }

            fn merge(&mut self, _other: &DefaultAttrs) -> Result<()> {
                Err(TestError::MergeError.into())
            }

            fn baked(
                &self,
                _observations: &FeatureObservationsGroups,
            ) -> Result<TrackBakingStatus> {
                Ok(TrackBakingStatus::Pending)
            }
        }

        /// Metric whose optimization fails whenever the list previously held one observation.
        #[derive(Default, Clone)]
        struct DefaultMetric;

        impl Metric for DefaultMetric {
            fn distance(_feature_class: u64, e1: &FeatureSpec, e2: &FeatureSpec) -> Result<f32> {
                Ok(euclidean(&e1.1, &e2.1))
            }

            fn optimize(
                &mut self,
                _feature_class: &u64,
                _merge_history: &[u64],
                _features: &mut Vec<FeatureSpec>,
                prev_length: usize,
            ) -> Result<()> {
                if prev_length == 1 {
                    Err(TestError::OptimizeError.into())
                } else {
                    Ok(())
                }
            }
        }

        let mut t1: Track<DefaultAttrs, DefaultMetric, DefaultAttrUpdates> = Track::default();
        assert_eq!(t1.attributes, DefaultAttrs { count: 0 });
        let res = t1.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![1f32, 0.0, 0.0]),
            DefaultAttrUpdates { ignore: false },
        );
        assert!(res.is_ok());
        assert_eq!(t1.attributes, DefaultAttrs { count: 1 });

        // The update is ignored, but optimize fails (prev_length == 1): full rollback expected.
        let err = t1
            .add_observation(
                0,
                0.3,
                Feature::from_vec(1, 3, vec![1f32, 0.0, 0.0]),
                DefaultAttrUpdates { ignore: true },
            )
            .expect_err("optimize must fail when prev_length == 1");
        assert!(matches!(
            err.root_cause().downcast_ref::<TestError>(),
            Some(TestError::OptimizeError)
        ));
        assert_eq!(t1.attributes, DefaultAttrs { count: 1 });
        assert_eq!(t1.observations.get(&0).unwrap().len(), 1);

        // A second counted update makes `apply` itself fail before anything else is touched.
        let err = t1
            .add_observation(
                0,
                0.3,
                Feature::from_vec(1, 3, vec![1f32, 0.0, 0.0]),
                DefaultAttrUpdates { ignore: false },
            )
            .expect_err("a second counted update must fail");
        assert!(matches!(
            err.root_cause().downcast_ref::<TestError>(),
            Some(TestError::UpdateError)
        ));
        assert_eq!(t1.attributes, DefaultAttrs { count: 1 });
        assert_eq!(t1.observations.get(&0).unwrap().len(), 1);

        let mut t2: Track<DefaultAttrs, DefaultMetric, DefaultAttrUpdates> = Track::default();
        assert_eq!(t2.attributes, DefaultAttrs { count: 0 });
        let res = t2.add_observation(
            0,
            0.3,
            Feature::from_vec(1, 3, vec![1f32, 0.0, 0.0]),
            DefaultAttrUpdates { ignore: false },
        );
        assert!(res.is_ok());
        assert_eq!(t2.attributes, DefaultAttrs { count: 1 });

        // Attribute merging always fails here; nothing in t1 may change.
        let err = t1
            .merge(&t2, &[0])
            .expect_err("attribute merge always fails in this test");
        assert!(matches!(
            err.root_cause().downcast_ref::<TestError>(),
            Some(TestError::MergeError)
        ));
        assert_eq!(t1.attributes, DefaultAttrs { count: 1 });
        assert_eq!(t1.observations.get(&0).unwrap().len(), 1);
    }
}