1use crate::track::notify::{ChangeNotifier, NoopNotifier};
2use crate::Errors;
3use anyhow::Result;
4use itertools::Itertools;
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::marker::PhantomData;
8use std::mem::take;
9use ultraviolet::f32x8;
10
11pub mod builder;
12pub mod notify;
13pub mod store;
14pub mod utils;
15pub mod voting;
16
17#[derive(Debug, Clone)]
20pub struct ObservationMetricOk<OA>
21where
22 OA: ObservationAttributes,
23{
24 pub from: u64,
26 pub to: u64,
28 pub attribute_metric: Option<OA::MetricObject>,
30 pub feature_distance: Option<f32>,
32}
33
34impl<OA> ObservationMetricOk<OA>
35where
36 OA: ObservationAttributes,
37{
38 pub fn new(
39 from: u64,
40 to: u64,
41 attribute_metric: Option<OA::MetricObject>,
42 feature_distance: Option<f32>,
43 ) -> Self {
44 Self {
45 from,
46 to,
47 attribute_metric,
48 feature_distance,
49 }
50 }
51}
52
53pub type Feature = Vec<f32x8>;
56
/// Number of `f32` lanes held by one `f32x8` chunk of a [`Feature`].
const FEATURE_LANES_SIZE: usize = 8;
59
60#[derive(Default, Clone)]
67pub struct Observation<T>(pub(crate) Option<T>, pub(crate) Option<Feature>)
68where
69 T: Send + Sync + Clone + 'static;
70
71impl<T> Observation<T>
72where
73 T: Send + Sync + Clone + 'static,
74{
75 pub fn new(attrs: Option<T>, feature: Option<Feature>) -> Self {
76 Self(attrs, feature)
77 }
78
79 pub fn attr(&self) -> &Option<T> {
82 &self.0
83 }
84
85 pub fn attr_mut(&mut self) -> &mut Option<T> {
88 &mut self.0
89 }
90
91 pub fn feature(&self) -> &Option<Feature> {
94 &self.1
95 }
96
97 pub fn feature_mut(&mut self) -> &mut Option<Feature> {
100 &mut self.1
101 }
102}
103
104pub type ObservationsDb<T> = HashMap<u64, Vec<Observation<T>>>;
109
/// Custom attributes attached to an observation.
pub trait ObservationAttributes: Send + Sync + Clone + 'static {
    /// Value produced when two sets of observation attributes are compared.
    type MetricObject: Debug + Send + Sync + Clone + 'static;

    /// Compares two (possibly absent) attribute sets. Returns `None` when no
    /// metric can be produced.
    fn calculate_metric_object(l: &Option<&Self>, r: &Option<&Self>) -> Option<Self::MetricObject>;
}
116
/// Return type of `ObservationMetric::metric`: `None` means the pair yields no
/// result; otherwise an optional attribute metric and an optional feature distance.
pub type MetricOutput<T> = Option<(Option<T>, Option<f32>)>;
126
127pub struct MetricQuery<'a, TA, OA: ObservationAttributes> {
135 pub feature_class: u64,
137 pub candidate_attrs: &'a TA,
139 pub candidate_observation: &'a Observation<OA>,
141 pub track_attrs: &'a TA,
143 pub track_observation: &'a Observation<OA>,
145}
146
147pub trait ObservationMetric<TA, OA: ObservationAttributes>: Send + Sync + Clone + 'static {
153 fn metric(&self, mq: &'_ MetricQuery<'_, TA, OA>) -> MetricOutput<OA::MetricObject>;
159
160 fn optimize(
177 &mut self,
178 feature_class: u64,
179 merge_history: &[u64],
180 attributes: &mut TA,
181 observations: &mut Vec<Observation<OA>>,
182 prev_length: usize,
183 is_merge: bool,
184 ) -> Result<()>;
185
186 fn postprocess_distances(
192 &self,
193 unfiltered: Vec<ObservationMetricOk<OA>>,
194 ) -> Vec<ObservationMetricOk<OA>> {
195 unfiltered
196 }
197}
198
/// Status reported by `TrackAttributes::baked` for a track.
#[derive(Debug, Clone)]
pub enum TrackStatus {
    /// The track reports itself as ready.
    Ready,
    /// The track reports itself as not ready yet.
    Pending,
    /// The track reports itself as wasted (presumably to be discarded — confirm with the store).
    Wasted,
}
212
213pub trait LookupRequest<TA, OA>: Send + Sync + Clone + 'static
216where
217 TA: TrackAttributes<TA, OA>,
218 OA: ObservationAttributes,
219{
220 fn lookup(
221 &self,
222 _attributes: &TA,
223 _observations: &ObservationsDb<OA>,
224 _merge_history: &[u64],
225 ) -> bool {
226 false
227 }
228}
229
230pub struct NoopLookup<TA, OA, const RES: bool = false>
237where
238 TA: TrackAttributes<TA, OA>,
239 OA: ObservationAttributes,
240{
241 _ta: PhantomData<TA>,
242 _oa: PhantomData<OA>,
243}
244
245impl<TA, OA, const RES: bool> Clone for NoopLookup<TA, OA, RES>
246where
247 TA: TrackAttributes<TA, OA>,
248 OA: ObservationAttributes,
249{
250 fn clone(&self) -> Self {
251 NoopLookup {
252 _ta: PhantomData,
253 _oa: PhantomData,
254 }
255 }
256}
257
258impl<TA, OA, const RES: bool> Default for NoopLookup<TA, OA, RES>
259where
260 TA: TrackAttributes<TA, OA>,
261 OA: ObservationAttributes,
262{
263 fn default() -> Self {
264 NoopLookup {
265 _ta: PhantomData,
266 _oa: PhantomData,
267 }
268 }
269}
270
271impl<TA, OA, const RES: bool> LookupRequest<TA, OA> for NoopLookup<TA, OA, RES>
272where
273 TA: TrackAttributes<TA, OA>,
274 OA: ObservationAttributes,
275{
276 fn lookup(
277 &self,
278 _attributes: &TA,
279 _observations: &ObservationsDb<OA>,
280 _merge_history: &[u64],
281 ) -> bool {
282 RES
283 }
284}
285
286pub trait TrackAttributes<TA: TrackAttributes<TA, OA>, OA: ObservationAttributes>:
292 Send + Sync + Clone + 'static
293{
294 type Update: TrackAttributesUpdate<TA>;
295 type Lookup: LookupRequest<TA, OA>;
296 fn compatible(&self, other: &TA) -> bool;
305
306 fn merge(&mut self, other: &TA) -> Result<()>;
313
314 fn baked(&self, observations: &ObservationsDb<OA>) -> Result<TrackStatus>;
323}
324
325pub trait TrackAttributesUpdate<TA>: Clone + Send + Sync + 'static {
330 fn apply(&self, attrs: &mut TA) -> Result<()>;
333}
334
335#[derive(Default, Clone)]
344pub struct Track<TA, M, OA, N = NoopNotifier>
345where
346 TA: TrackAttributes<TA, OA>,
347 M: ObservationMetric<TA, OA>,
348 OA: ObservationAttributes,
349 N: ChangeNotifier,
350{
351 attributes: TA,
352 track_id: u64,
353 observations: ObservationsDb<OA>,
354 metric: M,
355 merge_history: Vec<u64>,
356 notifier: N,
357}
358
359impl<TA, M, OA, N> Track<TA, M, OA, N>
362where
363 TA: TrackAttributes<TA, OA>,
364 M: ObservationMetric<TA, OA>,
365 OA: ObservationAttributes,
366 N: ChangeNotifier,
367{
368 pub fn new(track_id: u64, metric: M, attributes: TA, notifier: N) -> Self {
373 let mut v = Self {
374 notifier,
375 attributes,
376 track_id,
377 metric,
378 observations: ObservationsDb::default(),
379 merge_history: vec![track_id],
380 };
381 v.notifier.send(track_id);
382 v
383 }
384
385 pub fn get_track_id(&self) -> u64 {
388 self.track_id
389 }
390
391 pub fn set_track_id(&mut self, track_id: u64) -> u64 {
394 let old = self.track_id;
395 self.track_id = track_id;
396 old
397 }
398
399 pub fn get_attributes(&self) -> &TA {
402 &self.attributes
403 }
404
405 pub fn get_observations(&self, feature_class: u64) -> Option<&Vec<Observation<OA>>> {
406 self.observations.get(&feature_class)
407 }
408
409 pub fn get_mut_observations(
410 &mut self,
411 feature_class: u64,
412 ) -> Option<&mut Vec<Observation<OA>>> {
413 self.observations.get_mut(&feature_class)
414 }
415
416 pub fn get_merge_history(&self) -> &Vec<u64> {
419 &self.merge_history
420 }
421
422 pub fn get_feature_classes(&self) -> Vec<u64> {
425 self.observations.keys().cloned().collect()
426 }
427
428 fn update_attributes(&mut self, update: &TA::Update) -> Result<()> {
429 update.apply(&mut self.attributes)
430 }
431
432 pub fn add_observation(
448 &mut self,
449 feature_class: u64,
450 feature_attributes: Option<OA>,
451 feature: Option<Feature>,
452 track_attributes_update: Option<TA::Update>,
453 ) -> Result<()> {
454 let last_attributes = self.attributes.clone();
455 let last_observations = self.observations.clone();
456 let last_metric = self.metric.clone();
457
458 if let Some(track_attributes_update) = &track_attributes_update {
459 let res = self.update_attributes(track_attributes_update);
460 if res.is_err() {
461 self.attributes = last_attributes;
462 res?;
463 unreachable!();
464 }
465 }
466
467 if feature.is_none() && feature_attributes.is_none() {
468 self.notifier.send(self.track_id);
469 return Ok(());
470 }
471
472 match self.observations.get_mut(&feature_class) {
473 None => {
474 self.observations.insert(
475 feature_class,
476 vec![Observation(feature_attributes, feature)],
477 );
478 }
479 Some(observations) => {
480 observations.push(Observation(feature_attributes, feature));
481 }
482 }
483 let observations = self.observations.get_mut(&feature_class).unwrap();
484 let prev_length = observations.len() - 1;
485
486 let res = self.metric.optimize(
487 feature_class,
488 &self.merge_history,
489 &mut self.attributes,
490 observations,
491 prev_length,
492 false,
493 );
494 if res.is_err() {
495 self.attributes = last_attributes;
496 self.observations = last_observations;
497 self.metric = last_metric;
498 res?;
499 unreachable!();
500 }
501 self.notifier.send(self.track_id);
502 Ok(())
503 }
504
505 pub fn merge(&mut self, other: &Self, classes: &[u64], merge_history: bool) -> Result<()> {
523 let last_attributes = self.attributes.clone();
524 let res = self.attributes.merge(&other.attributes);
525 if res.is_err() {
526 self.attributes = last_attributes;
527 res?;
528 unreachable!();
529 }
530
531 let last_observations = self.observations.clone();
532 let last_metric = self.metric.clone();
533
534 for cls in classes {
535 let dest = self.observations.get_mut(cls);
536 let src = other.observations.get(cls);
537 let prev_length = match (dest, src) {
538 (Some(dest_observations), Some(src_observations)) => {
539 let prev_length = dest_observations.len();
540 dest_observations.extend(src_observations.iter().cloned());
541 Some(prev_length)
542 }
543 (None, Some(src_observations)) => {
544 self.observations.insert(*cls, src_observations.clone());
545 Some(0)
546 }
547
548 (Some(dest_observations), None) => {
549 let prev_length = dest_observations.len();
550 Some(prev_length)
551 }
552
553 _ => None,
554 };
555 let merge_history = if merge_history {
556 self.merge_history
557 .iter()
558 .chain(other.merge_history.iter())
559 .cloned()
560 .collect::<Vec<_>>()
561 } else {
562 take(&mut self.merge_history)
563 };
564
565 if let Some(prev_length) = prev_length {
566 let res = self.metric.optimize(
567 *cls,
568 &merge_history,
569 &mut self.attributes,
570 self.observations.get_mut(cls).unwrap(),
571 prev_length,
572 true,
573 );
574
575 if res.is_err() {
576 self.attributes = last_attributes;
577 self.observations = last_observations;
578 self.metric = last_metric;
579 res?;
580 unreachable!();
581 }
582 self.merge_history = merge_history;
583 }
584 }
585
586 self.notifier.send(self.track_id);
587 Ok(())
588 }
589
590 pub fn distances(
605 &self,
606 other: &Self,
607 feature_class: u64,
608 ) -> Result<Vec<ObservationMetricOk<OA>>> {
609 if !self.attributes.compatible(&other.attributes) {
610 Err(Errors::IncompatibleAttributes.into())
611 } else {
612 match (
613 self.observations.get(&feature_class),
614 other.observations.get(&feature_class),
615 ) {
616 (Some(left), Some(right)) => Ok(left
617 .iter()
618 .cartesian_product(right.iter())
619 .flat_map(|(l, r)| {
620 let mq = MetricQuery {
621 feature_class,
622 candidate_attrs: self.get_attributes(),
623 candidate_observation: l,
624 track_attrs: other.get_attributes(),
625 track_observation: r,
626 };
627
628 let (attribute_metric, feature_distance) = self.metric.metric(
630 &mq, )?;
636 Some(ObservationMetricOk {
637 from: self.track_id,
638 to: other.track_id,
639 attribute_metric,
640 feature_distance,
641 })
642 })
643 .collect()),
644 _ => Err(Errors::ObservationForClassNotFound(
645 self.track_id,
646 other.track_id,
647 feature_class,
648 )
649 .into()),
650 }
651 }
652 }
653
654 pub fn lookup(&self, query: &TA::Lookup) -> bool {
655 query.lookup(&self.attributes, &self.observations, &self.merge_history)
656 }
657}
658
// Unit tests for `Track`: distances, merging, rollback on errors, merge history,
// feature classes and lookups.
#[cfg(test)]
mod tests {
    use crate::distance::euclidean;
    use crate::examples::current_time_sec;
    use crate::prelude::{NoopNotifier, TrackBuilder};
    use crate::track::utils::{feature_attributes_sort_dec, FromVec};
    use crate::track::{
        Feature, LookupRequest, MetricOutput, MetricQuery, NoopLookup, Observation,
        ObservationAttributes, ObservationMetric, ObservationsDb, Track, TrackAttributes,
        TrackAttributesUpdate, TrackStatus,
    };
    use crate::EPS;
    use anyhow::Result;

    // Attributes that are always compatible and always merge successfully.
    #[derive(Clone)]
    pub struct DefaultAttrs;

    // Update for `DefaultAttrs` that does nothing.
    #[derive(Clone)]
    pub struct DefaultAttrUpdates;

    impl TrackAttributesUpdate<DefaultAttrs> for DefaultAttrUpdates {
        fn apply(&self, _attrs: &mut DefaultAttrs) -> Result<()> {
            Ok(())
        }
    }

    impl TrackAttributes<DefaultAttrs, f32> for DefaultAttrs {
        type Update = DefaultAttrUpdates;
        type Lookup = NoopLookup<DefaultAttrs, f32>;

        fn compatible(&self, _other: &DefaultAttrs) -> bool {
            true
        }

        fn merge(&mut self, _other: &DefaultAttrs) -> Result<()> {
            Ok(())
        }

        fn baked(&self, _observations: &ObservationsDb<f32>) -> Result<TrackStatus> {
            Ok(TrackStatus::Pending)
        }
    }

    // Euclidean feature distance; `optimize` sorts with `feature_attributes_sort_dec`
    // and keeps at most 20 observations.
    #[derive(Clone)]
    struct DefaultMetric;
    impl ObservationMetric<DefaultAttrs, f32> for DefaultMetric {
        fn metric(&self, mq: &MetricQuery<'_, DefaultAttrs, f32>) -> MetricOutput<f32> {
            let (e1, e2) = (mq.candidate_observation, mq.track_observation);
            Some((
                f32::calculate_metric_object(&e1.attr().as_ref(), &e2.attr().as_ref()),
                match (e1.feature().as_ref(), e2.feature().as_ref()) {
                    (Some(x), Some(y)) => Some(euclidean(x, y)),
                    _ => None,
                },
            ))
        }

        fn optimize(
            &mut self,
            _feature_class: u64,
            _merge_history: &[u64],
            _attributes: &mut DefaultAttrs,
            features: &mut Vec<Observation<f32>>,
            _prev_length: usize,
            _is_merge: bool,
        ) -> Result<()> {
            features.sort_by(feature_attributes_sort_dec);
            features.truncate(20);
            Ok(())
        }
    }

    #[test]
    fn init() {
        let t1 = Track::new(3, DefaultMetric, DefaultAttrs, NoopNotifier);
        assert_eq!(t1.get_track_id(), 3);
    }

    // Distances between tracks, including a track against itself and a track
    // with two observations of the same class.
    #[test]
    fn track_distances() -> Result<()> {
        let mut t1 = Track::new(1, DefaultMetric, DefaultAttrs, NoopNotifier);
        t1.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
            None,
        )?;

        let mut t2 = Track::new(2, DefaultMetric, DefaultAttrs, NoopNotifier);
        t2.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
            None,
        )?;

        // Identical features -> zero distance.
        let dists = t1.distances(&t1, 0);
        let dists = dists.unwrap();
        assert_eq!(dists.len(), 1);
        assert!(*dists[0].feature_distance.as_ref().unwrap() < EPS);

        // Orthogonal unit vectors -> sqrt(2).
        let dists = t1.distances(&t2, 0);
        let dists = dists.unwrap();
        assert_eq!(dists.len(), 1);
        assert!((*dists[0].feature_distance.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);

        t2.add_observation(
            0,
            Some(0.2),
            Some(Feature::from_vec(vec![1f32, 1.0f32, 0.0])),
            None,
        )?;

        assert_eq!(t2.observations.get(&0).unwrap().len(), 2);

        // One result per pair; the expected order relies on `optimize` keeping
        // the 0.3-attribute observation first (presumably descending sort — see
        // `feature_attributes_sort_dec`).
        let dists = t1.distances(&t2, 0);
        let dists = dists.unwrap();
        assert_eq!(dists.len(), 2);
        assert!((*dists[0].feature_distance.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);
        assert!((*dists[1].feature_distance.as_ref().unwrap() - 1.0).abs() < EPS);
        Ok(())
    }

    // Merging a class present in both tracks concatenates the observations.
    #[test]
    fn merge_same() -> Result<()> {
        let mut t1 = Track::new(1, DefaultMetric, DefaultAttrs, NoopNotifier);
        t1.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
            None,
        )?;

        let mut t2 = Track::new(2, DefaultMetric, DefaultAttrs, NoopNotifier);
        t2.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
            None,
        )?;
        let r = t1.merge(&t2, &[0], true);
        assert!(r.is_ok());
        assert_eq!(t1.observations.get(&0).unwrap().len(), 2);
        Ok(())
    }

    // Merging a class only the other track has copies it over and leaves the
    // existing class untouched.
    #[test]
    fn merge_other_feature_class() -> Result<()> {
        let mut t1 = Track::new(1, DefaultMetric, DefaultAttrs, NoopNotifier);
        t1.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
            None,
        )?;

        let mut t2 = Track::new(2, DefaultMetric, DefaultAttrs, NoopNotifier);
        t2.add_observation(
            1,
            Some(0.3),
            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
            None,
        )?;
        let r = t1.merge(&t2, &[1], true);
        assert!(r.is_ok());
        assert_eq!(t1.observations.get(&0).unwrap().len(), 1);
        assert_eq!(t1.observations.get(&1).unwrap().len(), 1);
        Ok(())
    }

    // `distances` fails when `compatible` returns false (here: the other track
    // starts before this one ends).
    #[test]
    fn attribute_compatible_match() -> Result<()> {
        #[derive(Default, Debug, Clone)]
        pub struct TimeAttrs {
            start_time: u64,
            end_time: u64,
        }

        #[derive(Default, Clone)]
        pub struct TimeAttrUpdates {
            time: u64,
        }

        impl TrackAttributesUpdate<TimeAttrs> for TimeAttrUpdates {
            fn apply(&self, attrs: &mut TimeAttrs) -> Result<()> {
                attrs.end_time = self.time;
                if attrs.start_time == 0 {
                    attrs.start_time = self.time;
                }
                Ok(())
            }
        }

        impl TrackAttributes<TimeAttrs, f32> for TimeAttrs {
            type Update = TimeAttrUpdates;
            type Lookup = NoopLookup<TimeAttrs, f32>;

            fn compatible(&self, other: &TimeAttrs) -> bool {
                self.end_time <= other.start_time
            }

            fn merge(&mut self, other: &TimeAttrs) -> Result<()> {
                self.start_time = self.start_time.min(other.start_time);
                self.end_time = self.end_time.max(other.end_time);
                Ok(())
            }

            fn baked(&self, _observations: &ObservationsDb<f32>) -> Result<TrackStatus> {
                if current_time_sec() - self.end_time > 30 {
                    Ok(TrackStatus::Ready)
                } else {
                    Ok(TrackStatus::Pending)
                }
            }
        }

        #[derive(Default, Clone)]
        struct TimeMetric;
        impl ObservationMetric<TimeAttrs, f32> for TimeMetric {
            fn metric(&self, mq: &MetricQuery<'_, TimeAttrs, f32>) -> MetricOutput<f32> {
                let (e1, e2) = (mq.candidate_observation, mq.track_observation);
                Some((
                    f32::calculate_metric_object(&e1.attr().as_ref(), &e2.attr().as_ref()),
                    match (e1.feature().as_ref(), e2.feature().as_ref()) {
                        (Some(x), Some(y)) => Some(euclidean(x, y)),
                        _ => None,
                    },
                ))
            }

            fn optimize(
                &mut self,
                _feature_class: u64,
                _merge_history: &[u64],
                _attributes: &mut TimeAttrs,
                features: &mut Vec<Observation<f32>>,
                _prev_length: usize,
                _is_merge: bool,
            ) -> Result<()> {
                features.sort_by(feature_attributes_sort_dec);
                features.truncate(20);
                Ok(())
            }
        }

        // t1 spans [2, 2].
        let mut t1 = Track::new(1, TimeMetric::default(), TimeAttrs::default(), NoopNotifier);
        t1.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
            Some(TimeAttrUpdates { time: 2 }),
        )?;

        // t2 spans [3, 3]: starts after t1 ends -> compatible.
        let mut t2 = Track::new(2, TimeMetric::default(), TimeAttrs::default(), NoopNotifier);
        t2.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
            Some(TimeAttrUpdates { time: 3 }),
        )?;

        let dists = t1.distances(&t2, 0);
        let dists = dists.unwrap();
        assert_eq!(dists.len(), 1);
        assert!((*dists[0].feature_distance.as_ref().unwrap() - 2.0_f32.sqrt()).abs() < EPS);
        assert_eq!(dists[0].to, 2);

        // t3 spans [1, 1]: starts before t1 ends -> incompatible.
        let mut t3 = Track::new(3, TimeMetric::default(), TimeAttrs::default(), NoopNotifier);
        t3.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
            Some(TimeAttrUpdates { time: 1 }),
        )?;

        let dists = t1.distances(&t3, 0);
        assert!(dists.is_err());
        Ok(())
    }

    #[test]
    fn get_classes() -> Result<()> {
        let mut t1 = Track::new(1, DefaultMetric, DefaultAttrs, NoopNotifier);
        t1.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
            None,
        )?;

        t1.add_observation(
            1,
            Some(0.3),
            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
            None,
        )?;
        // `get_feature_classes` order is unspecified (HashMap keys), so sort first.
        let mut classes = t1.get_feature_classes();
        classes.sort();

        assert_eq!(classes, vec![0, 1]);

        Ok(())
    }

    // Failing updates, optimizations and attribute merges must leave the track
    // attributes unchanged.
    #[test]
    fn attr_metric_update_recover() {
        use thiserror::Error;

        #[derive(Error, Debug)]
        enum TestError {
            #[error("Update Error")]
            Update,
            #[error("Unable to Merge")]
            Merge,
            #[error("Unable to Optimize")]
            Optimize,
        }

        #[derive(Default, Clone, PartialEq, Eq, Debug)]
        pub struct LocalAttrs {
            pub count: u32,
        }

        #[derive(Clone)]
        pub struct LocalAttrsUpdates {
            ignore: bool,
        }

        // Increments `count` unless `ignore`; fails once `count` exceeds 1.
        impl TrackAttributesUpdate<LocalAttrs> for LocalAttrsUpdates {
            fn apply(&self, attrs: &mut LocalAttrs) -> Result<()> {
                if !self.ignore {
                    attrs.count += 1;
                    if attrs.count > 1 {
                        Err(TestError::Update.into())
                    } else {
                        Ok(())
                    }
                } else {
                    Ok(())
                }
            }
        }

        // Attribute merging always fails.
        impl TrackAttributes<LocalAttrs, f32> for LocalAttrs {
            type Update = LocalAttrsUpdates;
            type Lookup = NoopLookup<LocalAttrs, f32>;

            fn compatible(&self, _other: &LocalAttrs) -> bool {
                true
            }

            fn merge(&mut self, _other: &LocalAttrs) -> Result<()> {
                Err(TestError::Merge.into())
            }

            fn baked(&self, _observations: &ObservationsDb<f32>) -> Result<TrackStatus> {
                Ok(TrackStatus::Pending)
            }
        }

        // `optimize` fails when the class already held exactly one observation.
        #[derive(Clone)]
        struct LocalMetric;
        impl ObservationMetric<LocalAttrs, f32> for LocalMetric {
            fn metric(&self, mq: &MetricQuery<LocalAttrs, f32>) -> MetricOutput<f32> {
                let (e1, e2) = (mq.candidate_observation, mq.track_observation);
                Some((
                    f32::calculate_metric_object(&e1.attr().as_ref(), &e2.attr().as_ref()),
                    match (e1.feature().as_ref(), e2.feature().as_ref()) {
                        (Some(x), Some(y)) => Some(euclidean(x, y)),
                        _ => None,
                    },
                ))
            }

            fn optimize(
                &mut self,
                _feature_class: u64,
                _merge_history: &[u64],
                _attributes: &mut LocalAttrs,
                _features: &mut Vec<Observation<f32>>,
                prev_length: usize,
                _is_merge: bool,
            ) -> Result<()> {
                if prev_length == 1 {
                    Err(TestError::Optimize.into())
                } else {
                    Ok(())
                }
            }
        }

        // First observation: update succeeds (count 0 -> 1), optimize sees prev_length 0.
        let mut t1 = Track::new(1, LocalMetric, LocalAttrs::default(), NoopNotifier);
        assert_eq!(t1.attributes, LocalAttrs { count: 0 });
        let res = t1.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
            Some(LocalAttrsUpdates { ignore: false }),
        );
        assert!(res.is_ok());
        assert_eq!(t1.attributes, LocalAttrs { count: 1 });

        // Second observation: update is ignored, optimize sees prev_length 1 and fails.
        let res = t1.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
            Some(LocalAttrsUpdates { ignore: true }),
        );
        assert!(res.is_err());
        if let Err(e) = res {
            match e.root_cause().downcast_ref::<TestError>().unwrap() {
                TestError::Update | TestError::Merge => {
                    unreachable!();
                }
                TestError::Optimize => {}
            }
        } else {
            unreachable!();
        }

        assert_eq!(t1.attributes, LocalAttrs { count: 1 });

        let mut t2 = Track::new(2, LocalMetric, LocalAttrs::default(), NoopNotifier);
        assert_eq!(t2.attributes, LocalAttrs { count: 0 });
        let res = t2.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
            Some(LocalAttrsUpdates { ignore: false }),
        );
        assert!(res.is_ok());
        assert_eq!(t2.attributes, LocalAttrs { count: 1 });

        // Attribute merge fails; t1's attributes must be restored.
        let res = t1.merge(&t2, &[0], true);
        if let Err(e) = res {
            match e.root_cause().downcast_ref::<TestError>().unwrap() {
                TestError::Update | TestError::Optimize => {
                    unreachable!();
                }
                TestError::Merge => {}
            }
        } else {
            unreachable!();
        }
        assert_eq!(t1.attributes, LocalAttrs { count: 1 });
    }

    // `merge(.., true)` appends the other track's history; `merge(.., false)` keeps it as is.
    #[test]
    fn merge_history() -> Result<()> {
        let mut t1 = Track::new(0, DefaultMetric, DefaultAttrs, NoopNotifier);
        let mut t2 = Track::new(1, DefaultMetric, DefaultAttrs, NoopNotifier);

        t1.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![1f32, 0.0, 0.0])),
            None,
        )?;

        t2.add_observation(
            0,
            Some(0.3),
            Some(Feature::from_vec(vec![0f32, 1.0f32, 0.0])),
            None,
        )?;

        let mut track_with_merge_history = t1.clone();
        let _r = track_with_merge_history.merge(&t2, &[0], true);
        assert_eq!(track_with_merge_history.merge_history, vec![0, 1]);

        let _r = t1.merge(&t2, &[0], false);
        assert_eq!(t1.merge_history, vec![0]);

        Ok(())
    }

    // A track can be built with unit `()` observation attributes.
    #[test]
    fn unit_track() {
        #[derive(Clone)]
        pub struct UnitAttrs;

        #[derive(Clone)]
        pub struct UnitAttrUpdates;

        impl TrackAttributesUpdate<UnitAttrs> for UnitAttrUpdates {
            fn apply(&self, _attrs: &mut UnitAttrs) -> Result<()> {
                Ok(())
            }
        }

        impl TrackAttributes<UnitAttrs, ()> for UnitAttrs {
            type Update = UnitAttrUpdates;
            type Lookup = NoopLookup<UnitAttrs, ()>;

            fn compatible(&self, _other: &UnitAttrs) -> bool {
                true
            }

            fn merge(&mut self, _other: &UnitAttrs) -> Result<()> {
                Ok(())
            }

            fn baked(&self, _observations: &ObservationsDb<()>) -> Result<TrackStatus> {
                Ok(TrackStatus::Pending)
            }
        }

        #[derive(Clone)]
        struct UnitMetric;
        impl ObservationMetric<UnitAttrs, ()> for UnitMetric {
            fn metric(&self, mq: &MetricQuery<UnitAttrs, ()>) -> MetricOutput<()> {
                let (e1, e2) = (mq.candidate_observation, mq.track_observation);
                Some((
                    None,
                    match (e1.1.as_ref(), e2.1.as_ref()) {
                        (Some(x), Some(y)) => Some(euclidean(x, y)),
                        _ => None,
                    },
                ))
            }

            fn optimize(
                &mut self,
                _feature_class: u64,
                _merge_history: &[u64],
                _attributes: &mut UnitAttrs,
                features: &mut Vec<Observation<()>>,
                _prev_length: usize,
                _is_merge: bool,
            ) -> Result<()> {
                features.sort_by(feature_attributes_sort_dec);
                features.truncate(20);
                Ok(())
            }
        }

        let _t1 = Track::new(1, UnitMetric, UnitAttrs, NoopNotifier);
    }

    // `Track::lookup` forwards to the attributes' `Lookup` type (which here always matches).
    #[test]
    fn lookup() {
        #[derive(Default, Clone)]
        struct Lookup;
        impl LookupRequest<LookupAttrs, f32> for Lookup {
            fn lookup(
                &self,
                _attributes: &LookupAttrs,
                _observations: &ObservationsDb<f32>,
                _merge_history: &[u64],
            ) -> bool {
                true
            }
        }

        #[derive(Clone, Default)]
        struct LookupAttrs;

        #[derive(Clone)]
        pub struct LookupAttributeUpdate;

        impl TrackAttributesUpdate<LookupAttrs> for LookupAttributeUpdate {
            fn apply(&self, _attrs: &mut LookupAttrs) -> Result<()> {
                Ok(())
            }
        }

        impl TrackAttributes<LookupAttrs, f32> for LookupAttrs {
            type Update = LookupAttributeUpdate;
            type Lookup = Lookup;

            fn compatible(&self, _other: &LookupAttrs) -> bool {
                true
            }

            fn merge(&mut self, _other: &LookupAttrs) -> Result<()> {
                Ok(())
            }

            fn baked(&self, _observations: &ObservationsDb<f32>) -> Result<TrackStatus> {
                Ok(TrackStatus::Ready)
            }
        }

        #[derive(Clone)]
        pub struct LookupMetric;

        impl ObservationMetric<LookupAttrs, f32> for LookupMetric {
            fn metric(&self, mq: &MetricQuery<LookupAttrs, f32>) -> MetricOutput<f32> {
                let (e1, e2) = (mq.candidate_observation, mq.track_observation);
                Some((
                    f32::calculate_metric_object(&e1.attr().as_ref(), &e2.attr().as_ref()),
                    match (e1.feature().as_ref(), e2.feature().as_ref()) {
                        (Some(x), Some(y)) => Some(euclidean(x, y)),
                        _ => None,
                    },
                ))
            }

            fn optimize(
                &mut self,
                _feature_class: u64,
                _merge_history: &[u64],
                _attrs: &mut LookupAttrs,
                _features: &mut Vec<Observation<f32>>,
                _prev_length: usize,
                _is_merge: bool,
            ) -> Result<()> {
                Ok(())
            }
        }

        let t: Track<LookupAttrs, LookupMetric, f32> = TrackBuilder::default()
            .metric(LookupMetric)
            .attributes(LookupAttrs)
            .notifier(NoopNotifier)
            .build()
            .unwrap();
        assert!(t.lookup(&Lookup));
    }
}