use crate::track::notify::{ChangeNotifier, NoopNotifier};
use crate::Errors;
use anyhow::Result;
use itertools::Itertools;
use std::collections::HashMap;
use std::fmt::Debug;
use std::mem::take;
use ultraviolet::f32x8;
pub mod notify;
pub mod store;
pub mod utils;
pub mod voting;
/// One pairwise comparison between an observation of track `from` and an observation of
/// track `to` (built, for example, by [`Track::distances`]).
#[derive(Debug, Clone)]
pub struct ObservationMetricResult<M>
where
M: ObservationAttributes,
{
/// Id of the track on the left-hand side of the comparison.
pub from: u64,
/// Id of the track on the right-hand side of the comparison.
pub to: u64,
/// Attribute metric produced by `ObservationMetric::metric`, if any.
pub attribute_metric: Option<M::MetricObject>,
/// Feature distance produced by `ObservationMetric::metric`, if any.
pub feature_distance: Option<f32>,
}
impl<M: ObservationAttributes> ObservationMetricResult<M> {
    /// Builds a comparison record between track `from` and track `to`.
    ///
    /// `attribute_metric` and `feature_distance` are stored exactly as given.
    pub fn new(
        from: u64,
        to: u64,
        attribute_metric: Option<M::MetricObject>,
        feature_distance: Option<f32>,
    ) -> Self {
        Self { from, to, attribute_metric, feature_distance }
    }
}
/// A feature vector stored as a sequence of packed `f32x8` values.
pub type Observation = Vec<f32x8>;
/// Number of `f32` lanes in one `f32x8` element of an [`Observation`].
const FEATURE_LANES_SIZE: usize = 8;
/// A single observation: optional observation attributes (`.0`) and an optional feature
/// vector (`.1`).
#[derive(Default, Clone)]
pub struct ObservationSpec<T>(pub Option<T>, pub Option<Observation>)
where
T: Default + Send + Sync + Clone + 'static + PartialOrd;
/// All observations of a track, grouped by feature class id.
pub type ObservationsDb<T> = HashMap<u64, Vec<ObservationSpec<T>>>;
/// Attributes attached to a single observation (the tests in this file use a plain `f32`).
pub trait ObservationAttributes: Default + Send + Sync + Clone + PartialOrd + 'static {
/// Value produced when comparing the attributes of two observations.
type MetricObject: Debug + Default + Send + Sync + Clone + PartialOrd + 'static;
/// Compares the (optional) attributes of two observations, returning the metric if any.
fn calculate_metric_object(l: &Option<Self>, r: &Option<Self>) -> Option<Self::MetricObject>;
}
/// Output of [`ObservationMetric::metric`].
///
/// `None` excludes the pair from `Track::distances`; otherwise it holds an optional attribute
/// metric and an optional feature distance.
pub type MetricOutput<T> = Option<(Option<T>, Option<f32>)>;
/// Pluggable metric used by [`Track`] to compare observations and to post-process a feature
/// class's observations after they change (the implementations in this file's tests sort and
/// truncate them).
pub trait ObservationMetric<TA, FA: ObservationAttributes>:
Default + Send + Sync + Clone + 'static
{
/// Compares two observations of class `feature_class`, given the attributes of both tracks.
///
/// Returning `None` drops the pair from the results of `Track::distances`.
fn metric(
feature_class: u64,
left_attrs: &TA,
right_attrs: &TA,
left_observation: &ObservationSpec<FA>,
right_observation: &ObservationSpec<FA>,
) -> MetricOutput<FA::MetricObject>;
/// Post-processes the observations of `feature_class` after new ones were appended.
///
/// Called by `Track::add_observation` with `is_merge == false` and by `Track::merge` with
/// `is_merge == true`. `prev_length` is the number of observations that were present before
/// the new ones were appended. The implementation may mutate the track attributes and the
/// observation list. When it returns `Err`, the track restores its attributes, observations and
/// metric to their state before the call.
fn optimize(
&mut self,
feature_class: &u64,
merge_history: &[u64],
attributes: &mut TA,
observations: &mut Vec<ObservationSpec<FA>>,
prev_length: usize,
is_merge: bool,
) -> Result<()>;
/// Hook for filtering or adjusting distance results; by default they are returned unchanged.
fn postprocess_distances(
&self,
unfiltered: Vec<ObservationMetricResult<FA>>,
) -> Vec<ObservationMetricResult<FA>> {
unfiltered
}
}
/// Status of a track as reported by [`TrackAttributes::baked`].
///
/// NOTE(review): the exact meaning of each variant is decided by `TrackAttributes`
/// implementations and their callers; it is not visible in this file.
#[derive(Clone, Debug)]
pub enum TrackStatus {
Ready,
Pending,
Wasted,
}
/// Track-level attributes.
///
/// `Track` requires `TA: TrackAttributes<TA, FA>`, so attributes are checked against and merged
/// with values of their own type.
pub trait TrackAttributes<TA, FA: ObservationAttributes>:
Default + Send + Sync + Clone + 'static
{
/// Incremental update applied by `Track::add_observation`.
type Update: TrackAttributesUpdate<TA>;
/// Whether this track may be compared with `other`.
///
/// `Track::distances` returns `Errors::IncompatibleAttributes` when this is `false`.
fn compatible(&self, other: &TA) -> bool;
/// Merges `other` into `self`. On `Err`, `Track::merge` restores the previous attributes.
fn merge(&mut self, other: &TA) -> Result<()>;
/// Reports the track status given the track's observations.
fn baked(&self, observations: &ObservationsDb<FA>) -> Result<TrackStatus>;
}
/// An update that can be applied to track attributes of type `TA`.
pub trait TrackAttributesUpdate<TA>: Clone + Send + Sync + 'static {
/// Applies the update to `attrs`. On `Err`, `Track::add_observation` restores the previous
/// attributes.
fn apply(&self, attrs: &mut TA) -> Result<()>;
}
/// A track: track-level attributes plus observations grouped by feature class.
///
/// Observations are compared and post-processed by the metric `M`. The track id is sent to the
/// notifier `N` when the track changes.
#[derive(Default, Clone)]
pub struct Track<TA, M, FA, N = NoopNotifier>
where
FA: ObservationAttributes,
N: ChangeNotifier,
TA: TrackAttributes<TA, FA>,
M: ObservationMetric<TA, FA>,
{
/// Track-level attributes.
attributes: TA,
/// Identifier of this track.
track_id: u64,
/// Observations grouped by feature class.
observations: ObservationsDb<FA>,
/// Metric used to compare and post-process observations.
metric: M,
/// Ids of the tracks merged into this one; `Track::new` starts it as `[track_id]`, while
/// `Track::default()` leaves it empty.
merge_history: Vec<u64>,
/// Receives the track id when the track changes.
notifier: N,
}
impl<TA, M, FA, N> Track<TA, M, FA, N>
where
FA: ObservationAttributes,
N: ChangeNotifier,
TA: TrackAttributes<TA, FA>,
M: ObservationMetric<TA, FA>,
{
pub fn new(
track_id: u64,
metric: Option<M>,
attributes: Option<TA>,
notifier: Option<N>,
) -> Self {
let mut v = Self {
notifier: if let Some(notifier) = notifier {
notifier
} else {
N::default()
},
attributes: if let Some(attributes) = attributes {
attributes
} else {
TA::default()
},
track_id,
observations: Default::default(),
metric: if let Some(m) = metric {
m
} else {
M::default()
},
merge_history: vec![track_id],
};
v.notifier.send(track_id);
v
}
pub fn get_track_id(&self) -> u64 {
self.track_id
}
pub fn set_track_id(&mut self, track_id: u64) -> u64 {
let old = self.track_id;
self.track_id = track_id;
old
}
pub fn get_attributes(&self) -> &TA {
&self.attributes
}
pub fn get_observations(&self, feature_class: u64) -> Option<&Vec<ObservationSpec<FA>>> {
self.observations.get(&feature_class)
}
pub fn get_attributes_mut(&mut self) -> &mut TA {
&mut self.attributes
}
pub fn get_merge_history(&self) -> &Vec<u64> {
&self.merge_history
}
pub fn get_feature_classes(&self) -> Vec<u64> {
self.observations.keys().cloned().collect()
}
fn update_attributes(&mut self, update: TA::Update) -> Result<()> {
update.apply(&mut self.attributes)
}
pub fn add_observation(
&mut self,
feature_class: u64,
feature_attributes: Option<FA>,
feature: Option<Observation>,
track_attributes_update: Option<TA::Update>,
) -> Result<()> {
let last_attributes = self.attributes.clone();
let last_observations = self.observations.clone();
let last_metric = self.metric.clone();
if let Some(track_attributes_update) = track_attributes_update {
let res = self.update_attributes(track_attributes_update);
if res.is_err() {
self.attributes = last_attributes;
res?;
unreachable!();
}
}
if feature.is_none() && feature_attributes.is_none() {
return Ok(());
}
match self.observations.get_mut(&feature_class) {
None => {
self.observations.insert(
feature_class,
vec![ObservationSpec(feature_attributes, feature)],
);
}
Some(observations) => {
observations.push(ObservationSpec(feature_attributes, feature));
}
}
let observations = self.observations.get_mut(&feature_class).unwrap();
let prev_length = observations.len() - 1;
let res = self.metric.optimize(
&feature_class,
&self.merge_history,
&mut self.attributes,
observations,
prev_length,
false,
);
if res.is_err() {
self.attributes = last_attributes;
self.observations = last_observations;
self.metric = last_metric;
res?;
unreachable!();
}
self.notifier.send(self.track_id);
Ok(())
}
pub fn merge(&mut self, other: &Self, classes: &[u64], merge_history: bool) -> Result<()> {
let last_attributes = self.attributes.clone();
let res = self.attributes.merge(&other.attributes);
if res.is_err() {
self.attributes = last_attributes;
res?;
unreachable!();
}
let last_observations = self.observations.clone();
let last_metric = self.metric.clone();
for cls in classes {
let dest = self.observations.get_mut(cls);
let src = other.observations.get(cls);
let prev_length = match (dest, src) {
(Some(dest_observations), Some(src_observations)) => {
let prev_length = dest_observations.len();
dest_observations.extend(src_observations.iter().cloned());
Some(prev_length)
}
(None, Some(src_observations)) => {
self.observations.insert(*cls, src_observations.clone());
Some(0)
}
(Some(dest_observations), None) => {
let prev_length = dest_observations.len();
Some(prev_length)
}
_ => None,
};
let merge_history = if merge_history {
vec![self.merge_history.clone(), other.merge_history.clone()]
.into_iter()
.flatten()
.collect::<Vec<_>>()
} else {
take(&mut self.merge_history)
};
if let Some(prev_length) = prev_length {
let res = self.metric.optimize(
cls,
&merge_history,
&mut self.attributes,
self.observations.get_mut(cls).unwrap(),
prev_length,
true,
);
if res.is_err() {
self.attributes = last_attributes;
self.observations = last_observations;
self.metric = last_metric;
res?;
unreachable!();
}
self.merge_history = merge_history;
}
}
self.notifier.send(self.track_id);
Ok(())
}
pub fn distances(
&self,
other: &Self,
feature_class: u64,
) -> Result<Vec<ObservationMetricResult<FA>>> {
if !self.attributes.compatible(&other.attributes) {
Err(Errors::IncompatibleAttributes.into())
} else {
match (
self.observations.get(&feature_class),
other.observations.get(&feature_class),
) {
(Some(left), Some(right)) => Ok(left
.iter()
.cartesian_product(right.iter())
.flat_map(|(l, r)| {
let (attribute_metric, feature_distance) = M::metric(
feature_class,
self.get_attributes(),
other.get_attributes(),
l,
r,
)?;
Some(ObservationMetricResult {
from: self.track_id,
to: other.track_id,
attribute_metric,
feature_distance,
})
})
.collect()),
_ => Err(Errors::ObservationForClassNotFound(
self.track_id,
other.track_id,
feature_class,
)
.into()),
}
}
}
}
#[cfg(test)]
mod tests {
// Unit tests for `Track`:
// - construction and pairwise distances;
// - merging of observations, feature classes and merge history;
// - rollback after failing attribute updates, merges and optimizations;
// - a track whose observation attributes are `()`.
use crate::distance::euclidean;
use crate::test_stuff::current_time_sec;
use crate::track::utils::{feature_attributes_sort_dec, FromVec};
use crate::track::{
MetricOutput, Observation, ObservationAttributes, ObservationMetric, ObservationSpec,
ObservationsDb, Track, TrackAttributes, TrackAttributesUpdate, TrackStatus,
};
use crate::EPS;
use anyhow::Result;
// Track attributes that carry no data.
#[derive(Default, Clone)]
pub struct DefaultAttrs;
// Update for `DefaultAttrs`; applying it changes nothing.
#[derive(Default, Clone)]
pub struct DefaultAttrUpdates;
impl TrackAttributesUpdate<DefaultAttrs> for DefaultAttrUpdates {
fn apply(&self, _attrs: &mut DefaultAttrs) -> Result<()> {
Ok(())
}
}
// Always compatible; merging is a no-op and the status is always `Pending`.
impl TrackAttributes<DefaultAttrs, f32> for DefaultAttrs {
type Update = DefaultAttrUpdates;
fn compatible(&self, _other: &DefaultAttrs) -> bool {
true
}
fn merge(&mut self, _other: &DefaultAttrs) -> Result<()> {
Ok(())
}
fn baked(&self, _observations: &ObservationsDb<f32>) -> Result<TrackStatus> {
Ok(TrackStatus::Pending)
}
}
// Metric used by most tests:
// - the attribute metric comes from `f32::calculate_metric_object`;
// - the feature distance is `euclidean`, computed only when both observations carry a feature;
// - `optimize` sorts with `feature_attributes_sort_dec` and keeps the first 20 observations.
#[derive(Default, Clone)]
struct DefaultMetric;
impl ObservationMetric<DefaultAttrs, f32> for DefaultMetric {
fn metric(
_feature_class: u64,
_attrs1: &DefaultAttrs,
_attrs2: &DefaultAttrs,
e1: &ObservationSpec<f32>,
e2: &ObservationSpec<f32>,
) -> MetricOutput<f32> {
Some((
f32::calculate_metric_object(&e1.0, &e2.0),
match (e1.1.as_ref(), e2.1.as_ref()) {
(Some(x), Some(y)) => Some(euclidean(x, y)),
_ => None,
},
))
}
fn optimize(
&mut self,
_feature_class: &u64,
_merge_history: &[u64],
_attributes: &mut DefaultAttrs,
features: &mut Vec<ObservationSpec<f32>>,
_prev_length: usize,
_is_merge: bool,
) -> Result<()> {
features.sort_by(feature_attributes_sort_dec);
features.truncate(20);
Ok(())
}
}
// `Track::new` stores the given id.
#[test]
fn init() {
let t1: Track<DefaultAttrs, DefaultMetric, f32> = Track::new(3, None, None, None);
assert_eq!(t1.get_track_id(), 3);
}
// Distances between single observations, then the cartesian product once `t2` holds two.
#[test]
fn track_distances() -> Result<()> {
let mut t1: Track<DefaultAttrs, DefaultMetric, f32> = Track::default();
t1.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![1f32, 0.0, 0.0])),
None,
)?;
let mut t2 = Track::default();
t2.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![0f32, 1.0f32, 0.0])),
None,
)?;
// A track compared with itself has zero feature distance.
let dists = t1.distances(&t1, 0);
let dists = dists.unwrap();
assert_eq!(dists.len(), 1);
assert!(*dists[0].feature_distance.as_ref().unwrap() < EPS);
let dists = t1.distances(&t2, 0);
let dists = dists.unwrap();
assert_eq!(dists.len(), 1);
assert!((*dists[0].feature_distance.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);
t2.add_observation(
0,
Some(0.2),
Some(Observation::from_vec(vec![1f32, 1.0f32, 0.0])),
None,
)?;
assert_eq!(t2.observations.get(&0).unwrap().len(), 2);
// The expected order relies on `optimize` (via `feature_attributes_sort_dec`) keeping the
// 0.3 observation ahead of the 0.2 one.
let dists = t1.distances(&t2, 0);
let dists = dists.unwrap();
assert_eq!(dists.len(), 2);
assert!((*dists[0].feature_distance.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);
assert!((*dists[1].feature_distance.as_ref().unwrap() - 1.0).abs() < EPS);
Ok(())
}
// Merging a class both tracks hold appends `t2`'s observations to `t1`'s.
#[test]
fn merge_same() -> Result<()> {
let mut t1: Track<DefaultAttrs, DefaultMetric, f32> = Track::default();
t1.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![1f32, 0.0, 0.0])),
None,
)?;
let mut t2 = Track::default();
t2.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![0f32, 1.0f32, 0.0])),
None,
)?;
let r = t1.merge(&t2, &vec![0], true);
assert!(r.is_ok());
assert_eq!(t1.observations.get(&0).unwrap().len(), 2);
Ok(())
}
// Merging only class 1 copies it into `t1` and leaves class 0 untouched.
#[test]
fn merge_other_feature_class() -> Result<()> {
let mut t1: Track<DefaultAttrs, DefaultMetric, f32> = Track::default();
t1.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![1f32, 0.0, 0.0])),
None,
)?;
let mut t2 = Track::default();
t2.add_observation(
1,
Some(0.3),
Some(Observation::from_vec(vec![0f32, 1.0f32, 0.0])),
None,
)?;
let r = t1.merge(&t2, &vec![1], true);
assert!(r.is_ok());
assert_eq!(t1.observations.get(&0).unwrap().len(), 1);
assert_eq!(t1.observations.get(&1).unwrap().len(), 1);
Ok(())
}
// `distances` succeeds for compatible attributes and fails for incompatible ones.
// Compatibility here means `self` ends no later than `other` starts.
#[test]
fn attribute_compatible_match() -> Result<()> {
#[derive(Default, Debug, Clone)]
pub struct TimeAttrs {
start_time: u64,
end_time: u64,
}
// Sets `end_time`, and also `start_time` while it is still 0.
#[derive(Default, Clone)]
pub struct TimeAttrUpdates {
time: u64,
}
impl TrackAttributesUpdate<TimeAttrs> for TimeAttrUpdates {
fn apply(&self, attrs: &mut TimeAttrs) -> Result<()> {
attrs.end_time = self.time;
if attrs.start_time == 0 {
attrs.start_time = self.time;
}
Ok(())
}
}
impl TrackAttributes<TimeAttrs, f32> for TimeAttrs {
type Update = TimeAttrUpdates;
fn compatible(&self, other: &TimeAttrs) -> bool {
self.end_time <= other.start_time
}
// The merged interval spans both tracks.
fn merge(&mut self, other: &TimeAttrs) -> Result<()> {
self.start_time = self.start_time.min(other.start_time);
self.end_time = self.end_time.max(other.end_time);
Ok(())
}
// `Ready` once more than 30 seconds have passed since `end_time`.
fn baked(&self, _observations: &ObservationsDb<f32>) -> Result<TrackStatus> {
if current_time_sec() - self.end_time > 30 {
Ok(TrackStatus::Ready)
} else {
Ok(TrackStatus::Pending)
}
}
}
// Same behaviour as `DefaultMetric`, for `TimeAttrs`.
#[derive(Default, Clone)]
struct TimeMetric;
impl ObservationMetric<TimeAttrs, f32> for TimeMetric {
fn metric(
_feature_class: u64,
_attrs1: &TimeAttrs,
_attrs2: &TimeAttrs,
e1: &ObservationSpec<f32>,
e2: &ObservationSpec<f32>,
) -> MetricOutput<f32> {
Some((
f32::calculate_metric_object(&e1.0, &e2.0),
match (e1.1.as_ref(), e2.1.as_ref()) {
(Some(x), Some(y)) => Some(euclidean(x, y)),
_ => None,
},
))
}
fn optimize(
&mut self,
_feature_class: &u64,
_merge_history: &[u64],
_attributes: &mut TimeAttrs,
features: &mut Vec<ObservationSpec<f32>>,
_prev_length: usize,
_is_merge: bool,
) -> Result<()> {
features.sort_by(feature_attributes_sort_dec);
features.truncate(20);
Ok(())
}
}
let mut t1: Track<TimeAttrs, TimeMetric, f32> = Track::default();
t1.track_id = 1;
t1.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![1f32, 0.0, 0.0])),
Some(TimeAttrUpdates { time: 2 }),
)?;
let mut t2 = Track::default();
t2.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![0f32, 1.0f32, 0.0])),
Some(TimeAttrUpdates { time: 3 }),
)?;
t2.track_id = 2;
// t1 ends at 2, t2 starts at 3: compatible.
let dists = t1.distances(&t2, 0);
let dists = dists.unwrap();
assert_eq!(dists.len(), 1);
assert!((*dists[0].feature_distance.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);
assert_eq!(dists[0].to, 2);
let mut t3 = Track::default();
t3.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![0f32, 1.0f32, 0.0])),
Some(TimeAttrUpdates { time: 1 }),
)?;
// t1 ends at 2, t3 starts at 1: incompatible.
let dists = t1.distances(&t3, 0);
assert!(dists.is_err());
Ok(())
}
// `get_feature_classes` lists every class with observations; its order is unspecified, hence the sort.
#[test]
fn get_classes() -> Result<()> {
let mut t1: Track<DefaultAttrs, DefaultMetric, f32> = Track::default();
t1.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![1f32, 0.0, 0.0])),
None,
)?;
t1.add_observation(
1,
Some(0.3),
Some(Observation::from_vec(vec![0f32, 1.0f32, 0.0])),
None,
)?;
let mut classes = t1.get_feature_classes();
classes.sort();
assert_eq!(classes, vec![0, 1]);
Ok(())
}
// A failing `optimize` or a failing attribute merge must leave the attributes as they were.
#[test]
fn attr_metric_update_recover() {
use thiserror::Error;
#[derive(Error, Debug)]
enum TestError {
#[error("Update Error")]
UpdateError,
#[error("MergeError")]
MergeError,
#[error("OptimizeError")]
OptimizeError,
}
#[derive(Default, Clone, PartialEq, Debug)]
pub struct DefaultAttrs {
pub count: u32,
}
// Increments `count` unless `ignore` is set; fails once `count` would exceed 1.
#[derive(Default, Clone)]
pub struct DefaultAttrUpdates {
ignore: bool,
}
impl TrackAttributesUpdate<DefaultAttrs> for DefaultAttrUpdates {
fn apply(&self, attrs: &mut DefaultAttrs) -> Result<()> {
if !self.ignore {
attrs.count += 1;
if attrs.count > 1 {
Err(TestError::UpdateError.into())
} else {
Ok(())
}
} else {
Ok(())
}
}
}
// Attribute merging always fails.
impl TrackAttributes<DefaultAttrs, f32> for DefaultAttrs {
type Update = DefaultAttrUpdates;
fn compatible(&self, _other: &DefaultAttrs) -> bool {
true
}
fn merge(&mut self, _other: &DefaultAttrs) -> Result<()> {
Err(TestError::MergeError.into())
}
fn baked(&self, _observations: &ObservationsDb<f32>) -> Result<TrackStatus> {
Ok(TrackStatus::Pending)
}
}
// `optimize` fails whenever exactly one observation existed before the new ones.
#[derive(Default, Clone)]
struct DefaultMetric;
impl ObservationMetric<DefaultAttrs, f32> for DefaultMetric {
fn metric(
_feature_class: u64,
_attrs1: &DefaultAttrs,
_attrs2: &DefaultAttrs,
e1: &ObservationSpec<f32>,
e2: &ObservationSpec<f32>,
) -> MetricOutput<f32> {
Some((
f32::calculate_metric_object(&e1.0, &e2.0),
match (e1.1.as_ref(), e2.1.as_ref()) {
(Some(x), Some(y)) => Some(euclidean(x, y)),
_ => None,
},
))
}
fn optimize(
&mut self,
_feature_class: &u64,
_merge_history: &[u64],
_attributes: &mut DefaultAttrs,
_features: &mut Vec<ObservationSpec<f32>>,
prev_length: usize,
_is_merge: bool,
) -> Result<()> {
if prev_length == 1 {
Err(TestError::OptimizeError.into())
} else {
Ok(())
}
}
}
let mut t1: Track<DefaultAttrs, DefaultMetric, f32> = Track::default();
assert_eq!(t1.attributes, DefaultAttrs { count: 0 });
let res = t1.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![1f32, 0.0, 0.0])),
Some(DefaultAttrUpdates { ignore: false }),
);
assert!(res.is_ok());
assert_eq!(t1.attributes, DefaultAttrs { count: 1 });
// The update is ignored, but the class already holds one observation, so
// `prev_length == 1` and `optimize` fails.
let res = t1.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![1f32, 0.0, 0.0])),
Some(DefaultAttrUpdates { ignore: true }),
);
assert!(res.is_err());
if let Err(e) = res {
match e.root_cause().downcast_ref::<TestError>().unwrap() {
TestError::UpdateError | TestError::MergeError => {
unreachable!();
}
TestError::OptimizeError => {}
}
} else {
unreachable!();
}
assert_eq!(t1.attributes, DefaultAttrs { count: 1 });
let mut t2: Track<DefaultAttrs, DefaultMetric, f32> = Track::default();
assert_eq!(t2.attributes, DefaultAttrs { count: 0 });
let res = t2.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![1f32, 0.0, 0.0])),
Some(DefaultAttrUpdates { ignore: false }),
);
assert!(res.is_ok());
assert_eq!(t2.attributes, DefaultAttrs { count: 1 });
// The attribute merge fails before any observation is touched.
let res = t1.merge(&t2, &vec![0], true);
if let Err(e) = res {
match e.root_cause().downcast_ref::<TestError>().unwrap() {
TestError::UpdateError | TestError::OptimizeError => {
unreachable!();
}
TestError::MergeError => {}
}
} else {
unreachable!();
}
assert_eq!(t1.attributes, DefaultAttrs { count: 1 });
}
// With `merge_history == true` both tracks' ids are recorded; with `false` only the receiving
// track's own history is kept.
#[test]
fn merge_history() -> Result<()> {
let mut t1: Track<DefaultAttrs, DefaultMetric, f32> = Track::new(0, None, None, None);
let mut t2 = Track::new(1, None, None, None);
t1.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![1f32, 0.0, 0.0])),
None,
)?;
t2.add_observation(
0,
Some(0.3),
Some(Observation::from_vec(vec![0f32, 1.0f32, 0.0])),
None,
)?;
let mut track_with_merge_history = t1.clone();
let _r = track_with_merge_history.merge(&t2, &vec![0], true);
assert_eq!(track_with_merge_history.merge_history, vec![0, 1]);
let _r = t1.merge(&t2, &vec![0], false);
assert_eq!(t1.merge_history, vec![0]);
Ok(())
}
// A track can be built with `()` as the observation attribute type.
#[test]
fn unit_track() {
#[derive(Default, Clone)]
pub struct UnitAttrs;
#[derive(Default, Clone)]
pub struct UnitAttrUpdates;
impl TrackAttributesUpdate<UnitAttrs> for UnitAttrUpdates {
fn apply(&self, _attrs: &mut UnitAttrs) -> Result<()> {
Ok(())
}
}
impl TrackAttributes<UnitAttrs, ()> for UnitAttrs {
type Update = UnitAttrUpdates;
fn compatible(&self, _other: &UnitAttrs) -> bool {
true
}
fn merge(&mut self, _other: &UnitAttrs) -> Result<()> {
Ok(())
}
fn baked(&self, _observations: &ObservationsDb<()>) -> Result<TrackStatus> {
Ok(TrackStatus::Pending)
}
}
// No attribute metric is produced for `()` attributes; only the feature distance is computed.
#[derive(Default, Clone)]
struct UnitMetric;
impl ObservationMetric<UnitAttrs, ()> for UnitMetric {
fn metric(
_feature_class: u64,
_attrs1: &UnitAttrs,
_attrs2: &UnitAttrs,
e1: &ObservationSpec<()>,
e2: &ObservationSpec<()>,
) -> MetricOutput<()> {
Some((
None,
match (e1.1.as_ref(), e2.1.as_ref()) {
(Some(x), Some(y)) => Some(euclidean(x, y)),
_ => None,
},
))
}
fn optimize(
&mut self,
_feature_class: &u64,
_merge_history: &[u64],
_attributes: &mut UnitAttrs,
features: &mut Vec<ObservationSpec<()>>,
_prev_length: usize,
_is_merge: bool,
) -> Result<()> {
features.sort_by(feature_attributes_sort_dec);
features.truncate(20);
Ok(())
}
}
let _t1: Track<UnitAttrs, UnitMetric, ()> = Track::new(0, None, None, None);
}
}