use crate::track::{ObservationAttributes, ObservationMetricOk};
use crate::voting::Voting;
use itertools::Itertools;
use std::collections::HashMap;
use std::marker::PhantomData;
/// Voting strategy that reports, for every query track, the `topn`
/// candidate tracks with the largest accumulated vote weight.
///
/// A candidate distance casts a vote only when it is present and does not
/// exceed `max_distance`; a `(query, candidate)` pair is reported only when
/// it collected at least `min_votes` such votes.
pub struct TopNVoting<OA: ObservationAttributes> {
    /// Maximum number of winners kept per query track.
    topn: usize,
    /// Inclusive upper bound on a feature distance that may cast a vote.
    max_distance: f32,
    /// Minimum number of accepted votes a pair needs to be reported.
    min_votes: usize,
    /// Ties the voting engine to one observation-attribute type without storing one.
    _phony: PhantomData<OA>,
}

impl<OA: ObservationAttributes> TopNVoting<OA> {
    /// Creates a voting engine.
    ///
    /// * `topn` - how many winners to keep per query track;
    /// * `max_distance` - largest feature distance still counted as a vote;
    /// * `min_votes` - votes a pair needs before it can win.
    pub fn new(topn: usize, max_distance: f32, min_votes: usize) -> Self {
        TopNVoting {
            _phony: PhantomData,
            min_votes,
            max_distance,
            topn,
        }
    }
}
/// One winner reported by `TopNVoting`: the candidate `winner_track` matched
/// for `query_track`, together with its accumulated vote `weight`.
///
/// `Clone` is derived so callers can duplicate winner records; the type is
/// plain data.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct TopNVotingElt {
    /// Track the distances were measured from (the `from` side of a metric).
    pub query_track: u64,
    /// Candidate track that collected the votes (the `to` side of a metric).
    pub winner_track: u64,
    /// Accumulated vote weight; winners are ranked by it, larger first.
    pub weight: f64,
}

impl TopNVotingElt {
    /// Builds a winner record from its three components.
    pub fn new(query_track: u64, winner_track: u64, weight: f64) -> Self {
        Self {
            query_track,
            winner_track,
            weight,
        }
    }
}
impl<OA> Voting<OA> for TopNVoting<OA>
where
OA: ObservationAttributes,
{
type WinnerObject = TopNVotingElt;
fn winners<T>(&self, distances: T) -> HashMap<u64, Vec<TopNVotingElt>>
where
T: IntoIterator<Item = ObservationMetricOk<OA>>,
{
let mut max_dist = -1.0_f32;
let counts: Vec<_> = distances
.into_iter()
.filter(
|ObservationMetricOk {
from: _,
to: _track,
attribute_metric: _f_attr_dist,
feature_distance: feat_dist,
}| match feat_dist {
Some(e) => {
if max_dist < *e {
max_dist = *e;
}
*e <= self.max_distance
}
_ => false,
},
)
.map(
|ObservationMetricOk {
from: src_track,
to: dest_track,
attribute_metric: _,
feature_distance: dist,
}| { ((src_track, dest_track), dist.unwrap()) },
)
.into_group_map()
.into_iter()
.filter(|(_, count)| count.len() >= self.min_votes)
.map(|((q, w), c)| {
let weight = c.into_iter().map(|d| (max_dist - d) as f64).sum();
TopNVotingElt {
query_track: q,
winner_track: w,
weight,
}
})
.collect::<Vec<_>>();
let mut results: HashMap<u64, Vec<TopNVotingElt>> = HashMap::new();
for c in counts {
let key = c.query_track;
if let Some(val) = results.get_mut(&key) {
val.push(c);
} else {
results.insert(key, vec![c]);
}
}
for counts in results.values_mut() {
counts.sort_by(|l, r| r.weight.partial_cmp(&l.weight).unwrap());
counts.truncate(self.topn);
}
results
}
}
#[cfg(test)]
mod tests {
    use crate::track::voting::topn::{TopNVoting, TopNVotingElt, Voting};
    use crate::track::ObservationMetricOk;
    use std::collections::HashMap;

    /// Sorts the winners reported for `query` by ascending `winner_track`, so
    /// assertions do not depend on the order in which `winners` emitted them.
    /// Panics if `query` has no entry.
    fn sort_by_winner(candidates: &mut HashMap<u64, Vec<TopNVotingElt>>, query: u64) {
        let winners = candidates.get_mut(&query).unwrap();
        winners.sort_by_key(|elt| elt.winner_track);
    }

    #[test]
    fn default_voting() {
        let voting: TopNVoting<()> = TopNVoting::new(5, 0.32, 1);

        // A lone vote is also the largest distance seen, so it weighs zero.
        let single = voting.winners([ObservationMetricOk::new(0, 1, None, Some(0.2))]);
        let expected_single = HashMap::from([(0, vec![TopNVotingElt::new(0, 1, 0.0)])]);
        assert_eq!(single, expected_single);

        // Two accepted votes for one pair: (0.3 - 0.2) + (0.3 - 0.3).
        let two_votes = voting.winners([
            ObservationMetricOk::new(0, 1, None, Some(0.2)),
            ObservationMetricOk::new(0, 1, None, Some(0.3)),
        ]);
        let expected_two_votes =
            HashMap::from([(0, vec![TopNVotingElt::new(0, 1, 0.10000000894069672)])]);
        assert_eq!(two_votes, expected_two_votes);

        // 0.4 exceeds max_distance and casts no vote, yet it still raises the
        // baseline: the weight is 0.4 - 0.2.
        let rejected = voting.winners([
            ObservationMetricOk::new(0, 1, None, Some(0.2)),
            ObservationMetricOk::new(0, 1, None, Some(0.4)),
        ]);
        let expected_rejected =
            HashMap::from([(0, vec![TopNVotingElt::new(0, 1, 0.20000000298023224)])]);
        assert_eq!(rejected, expected_rejected);

        // Two different winners for the same query, both at weight zero.
        let mut two_winners = voting.winners([
            ObservationMetricOk::new(0, 1, None, Some(0.2)),
            ObservationMetricOk::new(0, 2, None, Some(0.2)),
        ]);
        sort_by_winner(&mut two_winners, 0);
        let expected_two_winners = HashMap::from([(
            0,
            vec![TopNVotingElt::new(0, 1, 0.0), TopNVotingElt::new(0, 2, 0.0)],
        )]);
        assert_eq!(two_winners, expected_two_winners);

        // Six candidates against topn = 5: track 6 only keeps its 0.25 vote
        // (0.5 is rejected), has the lowest weight and is cut off.
        let mut crowded = voting.winners([
            ObservationMetricOk::new(0, 1, None, Some(0.2)),
            ObservationMetricOk::new(0, 1, None, Some(0.22)),
            ObservationMetricOk::new(0, 2, None, Some(0.21)),
            ObservationMetricOk::new(0, 2, None, Some(0.2)),
            ObservationMetricOk::new(0, 3, None, Some(0.22)),
            ObservationMetricOk::new(0, 3, None, Some(0.2)),
            ObservationMetricOk::new(0, 4, None, Some(0.23)),
            ObservationMetricOk::new(0, 4, None, Some(0.3)),
            ObservationMetricOk::new(0, 5, None, Some(0.24)),
            ObservationMetricOk::new(0, 5, None, Some(0.3)),
            ObservationMetricOk::new(0, 6, None, Some(0.25)),
            ObservationMetricOk::new(0, 6, None, Some(0.5)),
        ]);
        sort_by_winner(&mut crowded, 0);
        let expected_crowded = HashMap::from([(
            0,
            vec![
                TopNVotingElt::new(0, 1, 0.5800000131130219),
                TopNVotingElt::new(0, 2, 0.5900000333786011),
                TopNVotingElt::new(0, 3, 0.5800000131130219),
                TopNVotingElt::new(0, 4, 0.4699999690055847),
                TopNVotingElt::new(0, 5, 0.4599999785423279),
            ],
        )]);
        assert_eq!(crowded, expected_crowded);
    }

    #[test]
    fn two_query_vecs() {
        let voting: TopNVoting<f32> = TopNVoting::new(5, 0.32, 1);

        // The 0.5 baseline comes from query 7 but applies to query 0 as well,
        // so query 0's weights match those in `default_voting`.
        let mut per_query = voting.winners([
            ObservationMetricOk::new(0, 1, None, Some(0.2)),
            ObservationMetricOk::new(0, 1, None, Some(0.22)),
            ObservationMetricOk::new(0, 2, None, Some(0.21)),
            ObservationMetricOk::new(0, 2, None, Some(0.2)),
            ObservationMetricOk::new(0, 3, None, Some(0.22)),
            ObservationMetricOk::new(0, 3, None, Some(0.2)),
            ObservationMetricOk::new(7, 4, None, Some(0.23)),
            ObservationMetricOk::new(7, 4, None, Some(0.3)),
            ObservationMetricOk::new(7, 5, None, Some(0.24)),
            ObservationMetricOk::new(7, 5, None, Some(0.3)),
            ObservationMetricOk::new(7, 6, None, Some(0.25)),
            ObservationMetricOk::new(7, 6, None, Some(0.5)),
        ]);
        for query in [0, 7] {
            sort_by_winner(&mut per_query, query);
        }

        let expected = HashMap::from([
            (
                0,
                vec![
                    TopNVotingElt::new(0, 1, 0.5800000131130219),
                    TopNVotingElt::new(0, 2, 0.5900000333786011),
                    TopNVotingElt::new(0, 3, 0.5800000131130219),
                ],
            ),
            (
                7,
                vec![
                    TopNVotingElt::new(7, 4, 0.4699999690055847),
                    TopNVotingElt::new(7, 5, 0.4599999785423279),
                    TopNVotingElt::new(7, 6, 0.250),
                ],
            ),
        ]);
        assert_eq!(per_query, expected);
    }
}