cecile_supercool_tracker/
track.rs

1use crate::track::notify::{ChangeNotifier, NoopNotifier};
2use crate::Errors;
3use anyhow::Result;
4use itertools::Itertools;
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::marker::PhantomData;
8use std::mem::take;
9use ultraviolet::f32x8;
10
11pub mod builder;
12pub mod notify;
13pub mod store;
14pub mod utils;
15pub mod voting;
16
17/// Return type for distance between the current track's and other track observation pair
18///
19#[derive(Debug, Clone)]
20pub struct ObservationMetricOk<OA>
21where
22    OA: ObservationAttributes,
23{
24    /// source track ID
25    pub from: u64,
26    /// compared track ID
27    pub to: u64,
28    /// custom feature attribute metric object calculated for pairwise feature attributes
29    pub attribute_metric: Option<OA::MetricObject>,
30    /// distance calculated for pairwise feature vectors
31    pub feature_distance: Option<f32>,
32}
33
34impl<OA> ObservationMetricOk<OA>
35where
36    OA: ObservationAttributes,
37{
38    pub fn new(
39        from: u64,
40        to: u64,
41        attribute_metric: Option<OA::MetricObject>,
42        feature_distance: Option<f32>,
43    ) -> Self {
44        Self {
45            from,
46            to,
47            attribute_metric,
48            feature_distance,
49        }
50    }
51}
52
53/// Internal feature vector representation.
54///
55pub type Feature = Vec<f32x8>;
56
57/// Number of SIMD lanes used to store observation parts internally
58const FEATURE_LANES_SIZE: usize = 8;
59
60/// Observation specification.
61///
62/// It is a tuple struct of optional observation attributes (T) and optional feature vector itself.
63/// Observations are collected from the real world and placed into tracks. Later the observations are used
64/// to calculate the distances between tracks to make merging.
65///
66#[derive(Default, Clone)]
67pub struct Observation<T>(pub(crate) Option<T>, pub(crate) Option<Feature>)
68where
69    T: Send + Sync + Clone + 'static;
70
71impl<T> Observation<T>
72where
73    T: Send + Sync + Clone + 'static,
74{
75    pub fn new(attrs: Option<T>, feature: Option<Feature>) -> Self {
76        Self(attrs, feature)
77    }
78
79    /// Access to observation attributes
80    ///
81    pub fn attr(&self) -> &Option<T> {
82        &self.0
83    }
84
85    /// Access to observation attributes for modification purposes
86    ///
87    pub fn attr_mut(&mut self) -> &mut Option<T> {
88        &mut self.0
89    }
90
91    /// Access to observation feature
92    ///
93    pub fn feature(&self) -> &Option<Feature> {
94        &self.1
95    }
96
97    /// Access to observation feature for modification purposes
98    ///
99    pub fn feature_mut(&mut self) -> &mut Option<Feature> {
100        &mut self.1
101    }
102}
103
104/// HashTable that accumulates observations within the track.
105///
106/// The key is the feature class the value is the vector of observations collected.
107///
108pub type ObservationsDb<T> = HashMap<u64, Vec<Observation<T>>>;
109
/// Custom attributes stored in an observation together with its feature vector.
///
/// Implementors define how a metric object is computed for a pair of attribute values.
///
pub trait ObservationAttributes: Send + Sync + Clone + 'static {
    /// Type produced when two observation attributes are compared.
    type MetricObject: Debug + Send + Sync + Clone + 'static;

    /// Computes the metric object for the left (`l`) and the right (`r`) attributes.
    ///
    /// Either side may be absent; the implementation may return `None` when no metric object
    /// can be produced for the pair.
    fn calculate_metric_object(
        l: &Option<&Self>,
        r: &Option<&Self>,
    ) -> Option<Self::MetricObject>;
}
116
/// Value returned by the metric for a single pair of observations.
///
/// * `None` - there is no metric for the pair, so the pair is dropped from the results
///   (an optimization technique);
/// * `Some((attribute_metric, feature_distance))` - the metric is calculated, where
///   * `attribute_metric: Option<T>` is the metric object computed for the observation attributes;
///   * `feature_distance: Option<f32>` is the distance computed for the feature vectors.
///
pub type MetricOutput<T> = Option<(Option<T>, Option<f32>)>;
126
127/// Query object that is a parameter of the ``ObservationMetric::metric` method.
128///
129/// The query is used to make pairwise comparison of observations for two tracks.
130/// There is a
131///  * `candidate` track - the one, that is selected as a comparison subject
132///  * `track` track - the one, that is iterated over those kept in the store
133///
134pub struct MetricQuery<'a, TA, OA: ObservationAttributes> {
135    /// * `feature_class` - class of currently used feature
136    pub feature_class: u64,
137    /// * `candidate_attrs` - candidate track attributes
138    pub candidate_attrs: &'a TA,
139    /// * `candidate_observation` - candidate track observation
140    pub candidate_observation: &'a Observation<OA>,
141    /// * `track_attrs` - track attributes
142    pub track_attrs: &'a TA,
143    /// * `track_observation` - track observation
144    pub track_observation: &'a Observation<OA>,
145}
146
147/// The trait that implements the methods for observations comparison, optimization and filtering.
148///
149/// This is the one of the most important elements of the track. It defines how track distances are
150/// computed, how track observations are compacted and transformed upon merging.
151///
152pub trait ObservationMetric<TA, OA: ObservationAttributes>: Send + Sync + Clone + 'static {
153    /// calculates the distance between two features.
154    ///
155    /// # Parameters
156    /// * `mq` - query to calculate metric
157    ///
158    fn metric(&self, mq: &'_ MetricQuery<'_, TA, OA>) -> MetricOutput<OA::MetricObject>;
159
160    /// the method is used every time, when a new observation is added to the feature storage as well as when
161    /// two tracks are merged.
162    ///
163    /// # Arguments
164    ///
165    /// * `feature_class` - the feature class
166    /// * `merge_history` - the vector of track identifiers collected upon every merge
167    /// * `attributes` - mutable track attributes that can be updated or read during optimization
168    /// * `observations` - observations to optimize
169    /// * `prev_length` - previous length of observations (before the current observation was added or merge occurred)
170    /// * `is_merge` - true, when the op is for track merging, false when the observation is added to the track
171    ///
172    /// # Returns
173    /// * `Ok(())` if the optimization is successful
174    /// * `Err(e)` if the optimization failed
175    ///
176    fn optimize(
177        &mut self,
178        feature_class: u64,
179        merge_history: &[u64],
180        attributes: &mut TA,
181        observations: &mut Vec<Observation<OA>>,
182        prev_length: usize,
183        is_merge: bool,
184    ) -> Result<()>;
185
186    /// The postprocessing is run just before the executor returns calculated distances.
187    ///
188    /// The postprocessing is aimed to remove non-viable, invalid distances that can be skipped
189    /// to improve the performance or the quality of further track voting process.
190    ///
191    fn postprocess_distances(
192        &self,
193        unfiltered: Vec<ObservationMetricOk<OA>>,
194    ) -> Vec<ObservationMetricOk<OA>> {
195        unfiltered
196    }
197}
198
/// Status of a track kept in the storage.
///
/// While observations are being collected the track is pending; eventually the track must
/// become ready, so that it can be used for database search and later merge operations.
///
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TrackStatus {
    /// The track is ready and can be used to find similar tracks for merge.
    Ready,
    /// The track is not ready and is still being collected.
    Pending,
    /// The track is invalid because it somehow became incorrect or outdated along the way.
    Wasted,
}
212
213/// The trait that must be implemented by a search query object to run searches over the store
214///
215pub trait LookupRequest<TA, OA>: Send + Sync + Clone + 'static
216where
217    TA: TrackAttributes<TA, OA>,
218    OA: ObservationAttributes,
219{
220    fn lookup(
221        &self,
222        _attributes: &TA,
223        _observations: &ObservationsDb<OA>,
224        _merge_history: &[u64],
225    ) -> bool {
226        false
227    }
228}
229
230/// Do nothing lookup implementation that can be put anywhere lookup is required.
231///
232/// It is compatible with all TA, OA. Const parameter defines what lookup returns:
233/// * `false` - all lookup elements are ignored
234/// * `true` - all lookup elements are returned
235///
236pub struct NoopLookup<TA, OA, const RES: bool = false>
237where
238    TA: TrackAttributes<TA, OA>,
239    OA: ObservationAttributes,
240{
241    _ta: PhantomData<TA>,
242    _oa: PhantomData<OA>,
243}
244
245impl<TA, OA, const RES: bool> Clone for NoopLookup<TA, OA, RES>
246where
247    TA: TrackAttributes<TA, OA>,
248    OA: ObservationAttributes,
249{
250    fn clone(&self) -> Self {
251        NoopLookup {
252            _ta: PhantomData,
253            _oa: PhantomData,
254        }
255    }
256}
257
258impl<TA, OA, const RES: bool> Default for NoopLookup<TA, OA, RES>
259where
260    TA: TrackAttributes<TA, OA>,
261    OA: ObservationAttributes,
262{
263    fn default() -> Self {
264        NoopLookup {
265            _ta: PhantomData,
266            _oa: PhantomData,
267        }
268    }
269}
270
271impl<TA, OA, const RES: bool> LookupRequest<TA, OA> for NoopLookup<TA, OA, RES>
272where
273    TA: TrackAttributes<TA, OA>,
274    OA: ObservationAttributes,
275{
276    fn lookup(
277        &self,
278        _attributes: &TA,
279        _observations: &ObservationsDb<OA>,
280        _merge_history: &[u64],
281    ) -> bool {
282        RES
283    }
284}
285
286/// The trait represents user defined Track Attributes. It is used to define custom attributes that
287/// fit a domain field where tracking implemented.
288///
289/// When the user implements track attributes they has to implement this trait to create a valid attributes object.
290///
291pub trait TrackAttributes<TA: TrackAttributes<TA, OA>, OA: ObservationAttributes>:
292    Send + Sync + Clone + 'static
293{
294    type Update: TrackAttributesUpdate<TA>;
295    type Lookup: LookupRequest<TA, OA>;
296    /// The method is used to evaluate attributes of two tracks to determine whether tracks are compatible
297    /// for distance calculation. When the attributes are compatible, the method returns `true`.
298    ///
299    /// E.g.
300    ///     Let's imagine the case when the track includes the attributes for track begin and end timestamps.
301    ///     The tracks are compatible their timeframes don't intersect between each other. The method `compatible`
302    ///     can decide that.
303    ///
304    fn compatible(&self, other: &TA) -> bool;
305
306    /// When the tracks are merged, their attributes are merged as well. The method defines the approach to merge attributes.
307    ///
308    /// E.g.
309    ///     Let's imagine the case when the track includes the attributes for track begin and end timestamps.
310    ///     Merge operation may look like `[b1; e1] + [b2; e2] -> [min(b1, b2); max(e1, e2)]`.
311    ///
312    fn merge(&mut self, other: &TA) -> Result<()>;
313
314    /// The method is used by storage to determine when track is ready/not ready/wasted. Look at [TrackStatus](TrackStatus).
315    ///
316    /// It uses attribute information collected across the track config.toml and features information.
317    ///
318    /// E.g.
319    ///     track is ready when
320    ///          `now - end_timestamp > 30s` (no features collected during the last 30 seconds).
321    ///
322    fn baked(&self, observations: &ObservationsDb<OA>) -> Result<TrackStatus>;
323}
324
325/// The attribute update information that is sent with new features to the track is represented by the trait.
326///
327/// The trait must be implemented for update struct for specific attributes struct implementation.
328///
329pub trait TrackAttributesUpdate<TA>: Clone + Send + Sync + 'static {
330    /// Method is used to update track attributes from update structure.
331    ///
332    fn apply(&self, attrs: &mut TA) -> Result<()>;
333}
334
335/// Represents track of observations - it's a core concept of the library.
336///
337/// The track is created for specific attributes(A), Metric(M) and AttributeUpdate(U).
338/// * Attributes hold track meta information specific for certain domain.
339/// * Metric defines how to compare track features and optimize features when tracks are
340///   merged or collected
341/// * AttributeUpdate specifies how attributes are update from external sources.
342///
343#[derive(Default, Clone)]
344pub struct Track<TA, M, OA, N = NoopNotifier>
345where
346    TA: TrackAttributes<TA, OA>,
347    M: ObservationMetric<TA, OA>,
348    OA: ObservationAttributes,
349    N: ChangeNotifier,
350{
351    attributes: TA,
352    track_id: u64,
353    observations: ObservationsDb<OA>,
354    metric: M,
355    merge_history: Vec<u64>,
356    notifier: N,
357}
358
359/// One and only parametrized track implementation.
360///
361impl<TA, M, OA, N> Track<TA, M, OA, N>
362where
363    TA: TrackAttributes<TA, OA>,
364    M: ObservationMetric<TA, OA>,
365    OA: ObservationAttributes,
366    N: ChangeNotifier,
367{
368    /// Creates a new track with id `track_id` with `metric` initializer object and `attributes` initializer object.
369    ///
370    /// The `metric` and `attributes` are optional, if `None` is specified, then `Default` initializer is used.
371    ///
372    pub fn new(track_id: u64, metric: M, attributes: TA, notifier: N) -> Self {
373        let mut v = Self {
374            notifier,
375            attributes,
376            track_id,
377            metric,
378            observations: ObservationsDb::default(),
379            merge_history: vec![track_id],
380        };
381        v.notifier.send(track_id);
382        v
383    }
384
385    /// Returns track_id.
386    ///
387    pub fn get_track_id(&self) -> u64 {
388        self.track_id
389    }
390
391    /// Sets track_id.
392    ///
393    pub fn set_track_id(&mut self, track_id: u64) -> u64 {
394        let old = self.track_id;
395        self.track_id = track_id;
396        old
397    }
398
399    /// Returns current track attributes.
400    ///
401    pub fn get_attributes(&self) -> &TA {
402        &self.attributes
403    }
404
405    pub fn get_observations(&self, feature_class: u64) -> Option<&Vec<Observation<OA>>> {
406        self.observations.get(&feature_class)
407    }
408
409    pub fn get_mut_observations(
410        &mut self,
411        feature_class: u64,
412    ) -> Option<&mut Vec<Observation<OA>>> {
413        self.observations.get_mut(&feature_class)
414    }
415
416    /// Returns the current track merge history for the track
417    ///
418    pub fn get_merge_history(&self) -> &Vec<u64> {
419        &self.merge_history
420    }
421
422    /// Returns all classes present
423    ///
424    pub fn get_feature_classes(&self) -> Vec<u64> {
425        self.observations.keys().cloned().collect()
426    }
427
428    fn update_attributes(&mut self, update: &TA::Update) -> Result<()> {
429        update.apply(&mut self.attributes)
430    }
431
432    /// Adds new observation to track.
433    ///
434    /// When the method is called, the track attributes are updated according to `update` argument, and the feature
435    /// is placed into features for a specified feature class.
436    ///
437    /// # Arguments
438    /// * `feature_class` - class of observation
439    /// * `feature_attributes` - quality of the feature (confidence, or another parameter that defines how the observation is valuable across the observations).
440    /// * `feature` - observation to add to the track for specified `feature_class`.
441    /// * `track_attributes_update` - attribute update message
442    ///
443    /// # Returns
444    /// Returns `Result<()>` where `Ok(())` if attributes are updated without errors AND observation is added AND observations optimized without errors.
445    ///
446    ///
447    pub fn add_observation(
448        &mut self,
449        feature_class: u64,
450        feature_attributes: Option<OA>,
451        feature: Option<Feature>,
452        track_attributes_update: Option<TA::Update>,
453    ) -> Result<()> {
454        let last_attributes = self.attributes.clone();
455        let last_observations = self.observations.clone();
456        let last_metric = self.metric.clone();
457
458        if let Some(track_attributes_update) = &track_attributes_update {
459            let res = self.update_attributes(track_attributes_update);
460            if res.is_err() {
461                self.attributes = last_attributes;
462                res?;
463                unreachable!();
464            }
465        }
466
467        if feature.is_none() && feature_attributes.is_none() {
468            self.notifier.send(self.track_id);
469            return Ok(());
470        }
471
472        match self.observations.get_mut(&feature_class) {
473            None => {
474                self.observations.insert(
475                    feature_class,
476                    vec![Observation(feature_attributes, feature)],
477                );
478            }
479            Some(observations) => {
480                observations.push(Observation(feature_attributes, feature));
481            }
482        }
483        let observations = self.observations.get_mut(&feature_class).unwrap();
484        let prev_length = observations.len() - 1;
485
486        let res = self.metric.optimize(
487            feature_class,
488            &self.merge_history,
489            &mut self.attributes,
490            observations,
491            prev_length,
492            false,
493        );
494        if res.is_err() {
495            self.attributes = last_attributes;
496            self.observations = last_observations;
497            self.metric = last_metric;
498            res?;
499            unreachable!();
500        }
501        self.notifier.send(self.track_id);
502        Ok(())
503    }
504
505    /// Merges vector into current track across specified feature classes.
506    ///
507    /// The merge works across specified feature classes:
508    /// * step 1: attributes are merged
509    /// * step 2.0: features are merged for classes
510    /// * step 2.1: features are optimized for every class
511    ///
512    /// If feature class doesn't exist any of tracks it's skipped, otherwise:
513    ///
514    /// * both: `{S[class]} U {OTHER[class]}`
515    /// * self: `{S[class]}`
516    /// * other: `{OTHER[class]}`
517    ///
518    /// # Parameters
519    /// * `other` - track to merge into self
520    /// * `merge_history` - defines add merged track id into self merge history or not
521    ///
522    pub fn merge(&mut self, other: &Self, classes: &[u64], merge_history: bool) -> Result<()> {
523        let last_attributes = self.attributes.clone();
524        let res = self.attributes.merge(&other.attributes);
525        if res.is_err() {
526            self.attributes = last_attributes;
527            res?;
528            unreachable!();
529        }
530
531        let last_observations = self.observations.clone();
532        let last_metric = self.metric.clone();
533
534        for cls in classes {
535            let dest = self.observations.get_mut(cls);
536            let src = other.observations.get(cls);
537            let prev_length = match (dest, src) {
538                (Some(dest_observations), Some(src_observations)) => {
539                    let prev_length = dest_observations.len();
540                    dest_observations.extend(src_observations.iter().cloned());
541                    Some(prev_length)
542                }
543                (None, Some(src_observations)) => {
544                    self.observations.insert(*cls, src_observations.clone());
545                    Some(0)
546                }
547
548                (Some(dest_observations), None) => {
549                    let prev_length = dest_observations.len();
550                    Some(prev_length)
551                }
552
553                _ => None,
554            };
555            let merge_history = if merge_history {
556                self.merge_history
557                    .iter()
558                    .chain(other.merge_history.iter())
559                    .cloned()
560                    .collect::<Vec<_>>()
561            } else {
562                take(&mut self.merge_history)
563            };
564
565            if let Some(prev_length) = prev_length {
566                let res = self.metric.optimize(
567                    *cls,
568                    &merge_history,
569                    &mut self.attributes,
570                    self.observations.get_mut(cls).unwrap(),
571                    prev_length,
572                    true,
573                );
574
575                if res.is_err() {
576                    self.attributes = last_attributes;
577                    self.observations = last_observations;
578                    self.metric = last_metric;
579                    res?;
580                    unreachable!();
581                }
582                self.merge_history = merge_history;
583            }
584        }
585
586        self.notifier.send(self.track_id);
587        Ok(())
588    }
589
590    /// Calculates distances between all features for two tracks for a class.
591    ///
592    /// First it calculates cartesian product `S X O` and calculates the distance for every pair.
593    ///
594    /// Before it calculates the distance, it checks that attributes are compatible. If no,
595    /// [`Err(Errors::IncompatibleAttributes)`](Errors::IncompatibleAttributes) returned. Otherwise,
596    /// the vector of distances returned that holds `(other.track_id, Result<f32>)` pairs. `Track_id` is
597    /// the same for all results and used in higher level operations. `Result<f32>` is `Ok(f32)` when
598    /// the distance calculated by `Metric` well, `Err(e)` when `Metric` is unable to calculate the distance.
599    ///
600    /// # Parameters
601    /// * `other` - track to find distances to
602    /// * `feature_class` - what feature class to use to calculate distances
603    /// * `filter` - defines either results are filtered by distance before the output or not
604    pub fn distances(
605        &self,
606        other: &Self,
607        feature_class: u64,
608    ) -> Result<Vec<ObservationMetricOk<OA>>> {
609        if !self.attributes.compatible(&other.attributes) {
610            Err(Errors::IncompatibleAttributes.into())
611        } else {
612            match (
613                self.observations.get(&feature_class),
614                other.observations.get(&feature_class),
615            ) {
616                (Some(left), Some(right)) => Ok(left
617                    .iter()
618                    .cartesian_product(right.iter())
619                    .flat_map(|(l, r)| {
620                        let mq = MetricQuery {
621                            feature_class,
622                            candidate_attrs: self.get_attributes(),
623                            candidate_observation: l,
624                            track_attrs: other.get_attributes(),
625                            track_observation: r,
626                        };
627
628                        // let (attribute_metric, feature_distance) = self.metric.new_metric(&mq)?;
629                        let (attribute_metric, feature_distance) = self.metric.metric(
630                            &mq, // feature_class,
631                                // self.get_attributes(),
632                                // other.get_attributes(),
633                                // l,
634                                // r,
635                        )?;
636                        Some(ObservationMetricOk {
637                            from: self.track_id,
638                            to: other.track_id,
639                            attribute_metric,
640                            feature_distance,
641                        })
642                    })
643                    .collect()),
644                _ => Err(Errors::ObservationForClassNotFound(
645                    self.track_id,
646                    other.track_id,
647                    feature_class,
648                )
649                .into()),
650            }
651        }
652    }
653
654    pub fn lookup(&self, query: &TA::Lookup) -> bool {
655        query.lookup(&self.attributes, &self.observations, &self.merge_history)
656    }
657}
658
659#[cfg(test)]
660mod tests {
661    use crate::distance::euclidean;
662    use crate::examples::current_time_sec;
663    use crate::prelude::{NoopNotifier, TrackBuilder};
664    use crate::track::utils::{feature_attributes_sort_dec, FromVec};
665    use crate::track::{
666        Feature, LookupRequest, MetricOutput, MetricQuery, NoopLookup, Observation,
667        ObservationAttributes, ObservationMetric, ObservationsDb, Track, TrackAttributes,
668        TrackAttributesUpdate, TrackStatus,
669    };
670    use crate::EPS;
671    use anyhow::Result;
672
673    #[derive(Clone)]
674    pub struct DefaultAttrs;
675
676    #[derive(Clone)]
677    pub struct DefaultAttrUpdates;
678
679    impl TrackAttributesUpdate<DefaultAttrs> for DefaultAttrUpdates {
680        fn apply(&self, _attrs: &mut DefaultAttrs) -> Result<()> {
681            Ok(())
682        }
683    }
684
685    impl TrackAttributes<DefaultAttrs, f32> for DefaultAttrs {
686        type Update = DefaultAttrUpdates;
687        type Lookup = NoopLookup<DefaultAttrs, f32>;
688
689        fn compatible(&self, _other: &DefaultAttrs) -> bool {
690            true
691        }
692
693        fn merge(&mut self, _other: &DefaultAttrs) -> Result<()> {
694            Ok(())
695        }
696
697        fn baked(&self, _observations: &ObservationsDb<f32>) -> Result<TrackStatus> {
698            Ok(TrackStatus::Pending)
699        }
700    }
701
702    #[derive(Clone)]
703    struct DefaultMetric;
704    impl ObservationMetric<DefaultAttrs, f32> for DefaultMetric {
705        fn metric(&self, mq: &MetricQuery<'_, DefaultAttrs, f32>) -> MetricOutput<f32> {
706            let (e1, e2) = (mq.candidate_observation, mq.track_observation);
707            Some((
708                f32::calculate_metric_object(&e1.attr().as_ref(), &e2.attr().as_ref()),
709                match (e1.feature().as_ref(), e2.feature().as_ref()) {
710                    (Some(x), Some(y)) => Some(euclidean(x, y)),
711                    _ => None,
712                },
713            ))
714        }
715
716        fn optimize(
717            &mut self,
718            _feature_class: u64,
719            _merge_history: &[u64],
720            _attributes: &mut DefaultAttrs,
721            features: &mut Vec<Observation<f32>>,
722            _prev_length: usize,
723            _is_merge: bool,
724        ) -> Result<()> {
725            features.sort_by(feature_attributes_sort_dec);
726            features.truncate(20);
727            Ok(())
728        }
729    }
730
731    #[test]
732    fn init() {
733        let t1 = Track::new(3, DefaultMetric, DefaultAttrs, NoopNotifier);
734        assert_eq!(t1.get_track_id(), 3);
735    }
736
737    #[test]
738    fn track_distances() -> Result<()> {
739        let mut t1 = Track::new(1, DefaultMetric, DefaultAttrs, NoopNotifier);
740        t1.add_observation(
741            0,
742            Some(0.3),
743            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
744            None,
745        )?;
746
747        let mut t2 = Track::new(2, DefaultMetric, DefaultAttrs, NoopNotifier);
748        t2.add_observation(
749            0,
750            Some(0.3),
751            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
752            None,
753        )?;
754
755        let dists = t1.distances(&t1, 0);
756        let dists = dists.unwrap();
757        assert_eq!(dists.len(), 1);
758        assert!(*dists[0].feature_distance.as_ref().unwrap() < EPS);
759
760        let dists = t1.distances(&t2, 0);
761        let dists = dists.unwrap();
762        assert_eq!(dists.len(), 1);
763        assert!((*dists[0].feature_distance.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);
764
765        t2.add_observation(
766            0,
767            Some(0.2),
768            Some(Feature::from_vec(vec![1f32, 1.0f32, 0.0])),
769            None,
770        )?;
771
772        assert_eq!(t2.observations.get(&0).unwrap().len(), 2);
773
774        let dists = t1.distances(&t2, 0);
775        let dists = dists.unwrap();
776        assert_eq!(dists.len(), 2);
777        assert!((*dists[0].feature_distance.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);
778        assert!((*dists[1].feature_distance.as_ref().unwrap() - 1.0).abs() < EPS);
779        Ok(())
780    }
781
782    #[test]
783    fn merge_same() -> Result<()> {
784        let mut t1 = Track::new(1, DefaultMetric, DefaultAttrs, NoopNotifier);
785        t1.add_observation(
786            0,
787            Some(0.3),
788            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
789            None,
790        )?;
791
792        let mut t2 = Track::new(2, DefaultMetric, DefaultAttrs, NoopNotifier);
793        t2.add_observation(
794            0,
795            Some(0.3),
796            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
797            None,
798        )?;
799        let r = t1.merge(&t2, &[0], true);
800        assert!(r.is_ok());
801        assert_eq!(t1.observations.get(&0).unwrap().len(), 2);
802        Ok(())
803    }
804
805    #[test]
806    fn merge_other_feature_class() -> Result<()> {
807        let mut t1 = Track::new(1, DefaultMetric, DefaultAttrs, NoopNotifier);
808        t1.add_observation(
809            0,
810            Some(0.3),
811            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
812            None,
813        )?;
814
815        let mut t2 = Track::new(2, DefaultMetric, DefaultAttrs, NoopNotifier);
816        t2.add_observation(
817            1,
818            Some(0.3),
819            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
820            None,
821        )?;
822        let r = t1.merge(&t2, &[1], true);
823        assert!(r.is_ok());
824        assert_eq!(t1.observations.get(&0).unwrap().len(), 1);
825        assert_eq!(t1.observations.get(&1).unwrap().len(), 1);
826        Ok(())
827    }
828
829    #[test]
830    fn attribute_compatible_match() -> Result<()> {
831        #[derive(Default, Debug, Clone)]
832        pub struct TimeAttrs {
833            start_time: u64,
834            end_time: u64,
835        }
836
837        #[derive(Default, Clone)]
838        pub struct TimeAttrUpdates {
839            time: u64,
840        }
841
842        impl TrackAttributesUpdate<TimeAttrs> for TimeAttrUpdates {
843            fn apply(&self, attrs: &mut TimeAttrs) -> Result<()> {
844                attrs.end_time = self.time;
845                if attrs.start_time == 0 {
846                    attrs.start_time = self.time;
847                }
848                Ok(())
849            }
850        }
851
852        impl TrackAttributes<TimeAttrs, f32> for TimeAttrs {
853            type Update = TimeAttrUpdates;
854            type Lookup = NoopLookup<TimeAttrs, f32>;
855
856            fn compatible(&self, other: &TimeAttrs) -> bool {
857                self.end_time <= other.start_time
858            }
859
860            fn merge(&mut self, other: &TimeAttrs) -> Result<()> {
861                self.start_time = self.start_time.min(other.start_time);
862                self.end_time = self.end_time.max(other.end_time);
863                Ok(())
864            }
865
866            fn baked(&self, _observations: &ObservationsDb<f32>) -> Result<TrackStatus> {
867                if current_time_sec() - self.end_time > 30 {
868                    Ok(TrackStatus::Ready)
869                } else {
870                    Ok(TrackStatus::Pending)
871                }
872            }
873        }
874
875        #[derive(Default, Clone)]
876        struct TimeMetric;
877        impl ObservationMetric<TimeAttrs, f32> for TimeMetric {
878            fn metric(&self, mq: &MetricQuery<'_, TimeAttrs, f32>) -> MetricOutput<f32> {
879                let (e1, e2) = (mq.candidate_observation, mq.track_observation);
880                Some((
881                    f32::calculate_metric_object(&e1.attr().as_ref(), &e2.attr().as_ref()),
882                    match (e1.feature().as_ref(), e2.feature().as_ref()) {
883                        (Some(x), Some(y)) => Some(euclidean(x, y)),
884                        _ => None,
885                    },
886                ))
887            }
888
889            fn optimize(
890                &mut self,
891                _feature_class: u64,
892                _merge_history: &[u64],
893                _attributes: &mut TimeAttrs,
894                features: &mut Vec<Observation<f32>>,
895                _prev_length: usize,
896                _is_merge: bool,
897            ) -> Result<()> {
898                features.sort_by(feature_attributes_sort_dec);
899                features.truncate(20);
900                Ok(())
901            }
902        }
903
904        let mut t1 = Track::new(1, TimeMetric::default(), TimeAttrs::default(), NoopNotifier);
905        t1.add_observation(
906            0,
907            Some(0.3),
908            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
909            Some(TimeAttrUpdates { time: 2 }),
910        )?;
911
912        let mut t2 = Track::new(2, TimeMetric::default(), TimeAttrs::default(), NoopNotifier);
913        t2.add_observation(
914            0,
915            Some(0.3),
916            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
917            Some(TimeAttrUpdates { time: 3 }),
918        )?;
919
920        let dists = t1.distances(&t2, 0);
921        let dists = dists.unwrap();
922        assert_eq!(dists.len(), 1);
923        assert!((*dists[0].feature_distance.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);
924        assert_eq!(dists[0].to, 2);
925
926        let mut t3 = Track::new(3, TimeMetric::default(), TimeAttrs::default(), NoopNotifier);
927        t3.add_observation(
928            0,
929            Some(0.3),
930            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
931            Some(TimeAttrUpdates { time: 1 }),
932        )?;
933
934        let dists = t1.distances(&t3, 0);
935        assert!(dists.is_err());
936        Ok(())
937    }
938
939    #[test]
940    fn get_classes() -> Result<()> {
941        let mut t1 = Track::new(1, DefaultMetric, DefaultAttrs, NoopNotifier);
942        t1.add_observation(
943            0,
944            Some(0.3),
945            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
946            None,
947        )?;
948
949        t1.add_observation(
950            1,
951            Some(0.3),
952            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
953            None,
954        )?;
955        let mut classes = t1.get_feature_classes();
956        classes.sort();
957
958        assert_eq!(classes, vec![0, 1]);
959
960        Ok(())
961    }
962
963    #[test]
964    fn attr_metric_update_recover() {
965        use thiserror::Error;
966
967        #[derive(Error, Debug)]
968        enum TestError {
969            #[error("Update Error")]
970            Update,
971            #[error("Unable to Merge")]
972            Merge,
973            #[error("Unable to Optimize")]
974            Optimize,
975        }
976
977        #[derive(Default, Clone, PartialEq, Eq, Debug)]
978        pub struct LocalAttrs {
979            pub count: u32,
980        }
981
982        #[derive(Clone)]
983        pub struct LocalAttrsUpdates {
984            ignore: bool,
985        }
986
987        impl TrackAttributesUpdate<LocalAttrs> for LocalAttrsUpdates {
988            fn apply(&self, attrs: &mut LocalAttrs) -> Result<()> {
989                if !self.ignore {
990                    attrs.count += 1;
991                    if attrs.count > 1 {
992                        Err(TestError::Update.into())
993                    } else {
994                        Ok(())
995                    }
996                } else {
997                    Ok(())
998                }
999            }
1000        }
1001
1002        impl TrackAttributes<LocalAttrs, f32> for LocalAttrs {
1003            type Update = LocalAttrsUpdates;
1004            type Lookup = NoopLookup<LocalAttrs, f32>;
1005
1006            fn compatible(&self, _other: &LocalAttrs) -> bool {
1007                true
1008            }
1009
1010            fn merge(&mut self, _other: &LocalAttrs) -> Result<()> {
1011                Err(TestError::Merge.into())
1012            }
1013
1014            fn baked(&self, _observations: &ObservationsDb<f32>) -> Result<TrackStatus> {
1015                Ok(TrackStatus::Pending)
1016            }
1017        }
1018
1019        #[derive(Clone)]
1020        struct LocalMetric;
1021        impl ObservationMetric<LocalAttrs, f32> for LocalMetric {
1022            fn metric(&self, mq: &MetricQuery<LocalAttrs, f32>) -> MetricOutput<f32> {
1023                let (e1, e2) = (mq.candidate_observation, mq.track_observation);
1024                Some((
1025                    f32::calculate_metric_object(&e1.attr().as_ref(), &e2.attr().as_ref()),
1026                    match (e1.feature().as_ref(), e2.feature().as_ref()) {
1027                        (Some(x), Some(y)) => Some(euclidean(x, y)),
1028                        _ => None,
1029                    },
1030                ))
1031            }
1032
1033            fn optimize(
1034                &mut self,
1035                _feature_class: u64,
1036                _merge_history: &[u64],
1037                _attributes: &mut LocalAttrs,
1038                _features: &mut Vec<Observation<f32>>,
1039                prev_length: usize,
1040                _is_merge: bool,
1041            ) -> Result<()> {
1042                if prev_length == 1 {
1043                    Err(TestError::Optimize.into())
1044                } else {
1045                    Ok(())
1046                }
1047            }
1048        }
1049
1050        let mut t1 = Track::new(1, LocalMetric, LocalAttrs::default(), NoopNotifier);
1051        assert_eq!(t1.attributes, LocalAttrs { count: 0 });
1052        let res = t1.add_observation(
1053            0,
1054            Some(0.3),
1055            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
1056            Some(LocalAttrsUpdates { ignore: false }),
1057        );
1058        assert!(res.is_ok());
1059        assert_eq!(t1.attributes, LocalAttrs { count: 1 });
1060
1061        let res = t1.add_observation(
1062            0,
1063            Some(0.3),
1064            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
1065            Some(LocalAttrsUpdates { ignore: true }),
1066        );
1067        assert!(res.is_err());
1068        if let Err(e) = res {
1069            match e.root_cause().downcast_ref::<TestError>().unwrap() {
1070                TestError::Update | TestError::Merge => {
1071                    unreachable!();
1072                }
1073                TestError::Optimize => {}
1074            }
1075        } else {
1076            unreachable!();
1077        }
1078
1079        assert_eq!(t1.attributes, LocalAttrs { count: 1 });
1080
1081        let mut t2 = Track::new(2, LocalMetric, LocalAttrs::default(), NoopNotifier);
1082        assert_eq!(t2.attributes, LocalAttrs { count: 0 });
1083        let res = t2.add_observation(
1084            0,
1085            Some(0.3),
1086            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
1087            Some(LocalAttrsUpdates { ignore: false }),
1088        );
1089        assert!(res.is_ok());
1090        assert_eq!(t2.attributes, LocalAttrs { count: 1 });
1091
1092        let res = t1.merge(&t2, &[0], true);
1093        if let Err(e) = res {
1094            match e.root_cause().downcast_ref::<TestError>().unwrap() {
1095                TestError::Update | TestError::Optimize => {
1096                    unreachable!();
1097                }
1098                TestError::Merge => {}
1099            }
1100        } else {
1101            unreachable!();
1102        }
1103        assert_eq!(t1.attributes, LocalAttrs { count: 1 });
1104    }
1105
1106    #[test]
1107    fn merge_history() -> Result<()> {
1108        let mut t1 = Track::new(0, DefaultMetric, DefaultAttrs, NoopNotifier);
1109        let mut t2 = Track::new(1, DefaultMetric, DefaultAttrs, NoopNotifier);
1110
1111        t1.add_observation(
1112            0,
1113            Some(0.3),
1114            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
1115            None,
1116        )?;
1117
1118        t2.add_observation(
1119            0,
1120            Some(0.3),
1121            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
1122            None,
1123        )?;
1124
1125        let mut track_with_merge_history = t1.clone();
1126        let _r = track_with_merge_history.merge(&t2, &[0], true);
1127        assert_eq!(track_with_merge_history.merge_history, vec![0, 1]);
1128
1129        let _r = t1.merge(&t2, &[0], false);
1130        assert_eq!(t1.merge_history, vec![0]);
1131
1132        Ok(())
1133    }
1134
1135    #[test]
1136    fn unit_track() {
1137        #[derive(Clone)]
1138        pub struct UnitAttrs;
1139
1140        #[derive(Clone)]
1141        pub struct UnitAttrUpdates;
1142
1143        impl TrackAttributesUpdate<UnitAttrs> for UnitAttrUpdates {
1144            fn apply(&self, _attrs: &mut UnitAttrs) -> Result<()> {
1145                Ok(())
1146            }
1147        }
1148
1149        impl TrackAttributes<UnitAttrs, ()> for UnitAttrs {
1150            type Update = UnitAttrUpdates;
1151            type Lookup = NoopLookup<UnitAttrs, ()>;
1152
1153            fn compatible(&self, _other: &UnitAttrs) -> bool {
1154                true
1155            }
1156
1157            fn merge(&mut self, _other: &UnitAttrs) -> Result<()> {
1158                Ok(())
1159            }
1160
1161            fn baked(&self, _observations: &ObservationsDb<()>) -> Result<TrackStatus> {
1162                Ok(TrackStatus::Pending)
1163            }
1164        }
1165
1166        #[derive(Clone)]
1167        struct UnitMetric;
1168        impl ObservationMetric<UnitAttrs, ()> for UnitMetric {
1169            fn metric(&self, mq: &MetricQuery<UnitAttrs, ()>) -> MetricOutput<()> {
1170                let (e1, e2) = (mq.candidate_observation, mq.track_observation);
1171                Some((
1172                    None,
1173                    match (e1.1.as_ref(), e2.1.as_ref()) {
1174                        (Some(x), Some(y)) => Some(euclidean(x, y)),
1175                        _ => None,
1176                    },
1177                ))
1178            }
1179
1180            fn optimize(
1181                &mut self,
1182                _feature_class: u64,
1183                _merge_history: &[u64],
1184                _attributes: &mut UnitAttrs,
1185                features: &mut Vec<Observation<()>>,
1186                _prev_length: usize,
1187                _is_merge: bool,
1188            ) -> Result<()> {
1189                features.sort_by(feature_attributes_sort_dec);
1190                features.truncate(20);
1191                Ok(())
1192            }
1193        }
1194
1195        let _t1 = Track::new(1, UnitMetric, UnitAttrs, NoopNotifier);
1196    }
1197
1198    #[test]
1199    fn lookup() {
1200        #[derive(Default, Clone)]
1201        struct Lookup;
1202        impl LookupRequest<LookupAttrs, f32> for Lookup {
1203            fn lookup(
1204                &self,
1205                _attributes: &LookupAttrs,
1206                _observations: &ObservationsDb<f32>,
1207                _merge_history: &[u64],
1208            ) -> bool {
1209                true
1210            }
1211        }
1212
1213        #[derive(Clone, Default)]
1214        struct LookupAttrs;
1215
1216        #[derive(Clone)]
1217        pub struct LookupAttributeUpdate;
1218
1219        impl TrackAttributesUpdate<LookupAttrs> for LookupAttributeUpdate {
1220            fn apply(&self, _attrs: &mut LookupAttrs) -> Result<()> {
1221                Ok(())
1222            }
1223        }
1224
1225        impl TrackAttributes<LookupAttrs, f32> for LookupAttrs {
1226            type Update = LookupAttributeUpdate;
1227            type Lookup = Lookup;
1228
1229            fn compatible(&self, _other: &LookupAttrs) -> bool {
1230                true
1231            }
1232
1233            fn merge(&mut self, _other: &LookupAttrs) -> Result<()> {
1234                Ok(())
1235            }
1236
1237            fn baked(&self, _observations: &ObservationsDb<f32>) -> Result<TrackStatus> {
1238                Ok(TrackStatus::Ready)
1239            }
1240        }
1241
1242        #[derive(Clone)]
1243        pub struct LookupMetric;
1244
1245        impl ObservationMetric<LookupAttrs, f32> for LookupMetric {
1246            fn metric(&self, mq: &MetricQuery<LookupAttrs, f32>) -> MetricOutput<f32> {
1247                let (e1, e2) = (mq.candidate_observation, mq.track_observation);
1248                Some((
1249                    f32::calculate_metric_object(&e1.attr().as_ref(), &e2.attr().as_ref()),
1250                    match (e1.feature().as_ref(), e2.feature().as_ref()) {
1251                        (Some(x), Some(y)) => Some(euclidean(x, y)),
1252                        _ => None,
1253                    },
1254                ))
1255            }
1256
1257            fn optimize(
1258                &mut self,
1259                _feature_class: u64,
1260                _merge_history: &[u64],
1261                _attrs: &mut LookupAttrs,
1262                _features: &mut Vec<Observation<f32>>,
1263                _prev_length: usize,
1264                _is_merge: bool,
1265            ) -> Result<()> {
1266                Ok(())
1267            }
1268        }
1269
1270        let t: Track<LookupAttrs, LookupMetric, f32> = TrackBuilder::default()
1271            .metric(LookupMetric)
1272            .attributes(LookupAttrs)
1273            .notifier(NoopNotifier)
1274            .build()
1275            .unwrap();
1276        assert!(t.lookup(&Lookup));
1277    }
1278}