use std::cmp::Ordering;
use std::collections::BinaryHeap;
use comp_cat_rs::effect::io::Io;
use crate::score::ScoredCandidate;
use crate::stage::Stage;
/// Newtype around a `usize` item count, used as the item limit of a [`Budget`].
#[derive(Debug, Clone, Copy)]
pub struct MaxItems(usize);

impl MaxItems {
    /// Wraps `n` as an item limit. Every `usize` is accepted, including zero.
    #[must_use]
    pub fn new(n: usize) -> Self {
        MaxItems(n)
    }

    /// Returns the wrapped count.
    #[must_use]
    pub fn value(self) -> usize {
        let MaxItems(limit) = self;
        limit
    }
}
/// Limits handed to a selector; at present the only limit is an item count.
#[derive(Debug, Clone, Copy)]
pub struct Budget {
    max_items: MaxItems,
}

impl Budget {
    /// Creates a budget with the given item limit.
    #[must_use]
    pub fn new(max_items: MaxItems) -> Self {
        Budget { max_items }
    }

    /// Returns a copy of the item limit this budget was created with.
    #[must_use]
    pub fn max_items(&self) -> MaxItems {
        let Budget { max_items } = *self;
        max_items
    }
}
/// A strategy for choosing which scored candidates to keep.
///
/// `I` is the item type carried by each candidate and `E` is the first type
/// parameter of the returned [`Io`].
pub trait Selector<I, E> {
    /// Selects from `candidates` under the given `budget`.
    ///
    /// Takes ownership of the candidate list and returns an [`Io`] that
    /// produces the selected candidates.
    fn select(
        &self,
        candidates: Vec<ScoredCandidate<I>>,
        budget: &Budget,
    ) -> Io<E, Vec<ScoredCandidate<I>>>;
}
/// Wraps a selector and a fixed budget into a [`Stage`].
///
/// The stage's closure takes ownership of `selector` and `budget`; each time it
/// is invoked it passes the incoming candidates to [`Selector::select`] together
/// with a reference to that captured budget.
pub fn selector_stage<S, I, E>(
    selector: S,
    budget: Budget,
) -> Stage<E, Vec<ScoredCandidate<I>>, Vec<ScoredCandidate<I>>>
where
    S: Selector<I, E> + Send + 'static,
    I: Send + 'static,
    E: Send + 'static,
{
    Stage::new(move |candidates| {
        <S as Selector<I, E>>::select(&selector, candidates, &budget)
    })
}
/// Selector that keeps the highest-scoring candidates, up to the budget's item limit.
pub struct TopNSelector;

impl<I, E> Selector<I, E> for TopNSelector
where
    I: Send + 'static,
    E: Send + 'static,
{
    /// Returns the top `budget.max_items()` candidates by score, wrapped with
    /// `Io::pure`.
    fn select(
        &self,
        candidates: Vec<ScoredCandidate<I>>,
        budget: &Budget,
    ) -> Io<E, Vec<ScoredCandidate<I>>> {
        let kept = top_n_by_score(candidates, budget.max_items().value());
        Io::pure(kept)
    }
}
/// Ordering wrapper: compares candidates by their score alone, ignoring the
/// item, so that they can be stored in a `BinaryHeap`.
struct ByScore<I>(ScoredCandidate<I>);

impl<I> Ord for ByScore<I> {
    /// Orders two wrapped candidates by comparing their scores.
    fn cmp(&self, other: &Self) -> Ordering {
        let (Self(lhs), Self(rhs)) = (self, other);
        lhs.score().cmp(&rhs.score())
    }
}

impl<I> PartialOrd for ByScore<I> {
    /// Always `Some`: delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<I> PartialEq for ByScore<I> {
    /// Two wrappers are equal when their scores compare equal with `==`.
    fn eq(&self, other: &Self) -> bool {
        let (Self(lhs), Self(rhs)) = (self, other);
        lhs.score() == rhs.score()
    }
}

impl<I> Eq for ByScore<I> {}
fn top_n_by_score<I>(candidates: Vec<ScoredCandidate<I>>, n: usize) -> Vec<ScoredCandidate<I>> {
let heap: BinaryHeap<ByScore<I>> = candidates.into_iter().map(ByScore).collect();
heap.into_sorted_vec()
.into_iter()
.rev()
.take(n)
.map(|by| by.0)
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::score::Score;

    /// Builds a `Score` for a fixture, turning a `None` from `Score::new` into
    /// a test error message.
    fn score(v: f64) -> Result<Score, &'static str> {
        Score::new(v).ok_or("test score was NaN")
    }

    /// Builds candidates from `(item, raw score)` pairs, preserving their order.
    fn candidates(
        pairs: &[(&'static str, f64)],
    ) -> Result<Vec<ScoredCandidate<&'static str>>, &'static str> {
        pairs
            .iter()
            .map(|&(item, raw)| score(raw).map(|s| ScoredCandidate::new(item, s)))
            .collect()
    }

    #[test]
    fn top_n_selects_highest_scores() -> Result<(), &'static str> {
        let pool = candidates(&[("c", 1.0), ("a", 3.0), ("b", 2.0)])?;

        let picked = top_n_by_score(pool, 2);

        // The two best candidates come back, best first.
        assert_eq!(picked.len(), 2);
        assert_eq!(*picked[0].item(), "a");
        assert_eq!(*picked[1].item(), "b");
        Ok(())
    }

    #[test]
    fn top_n_with_n_greater_than_len_returns_all() -> Result<(), &'static str> {
        let pool = candidates(&[("a", 1.0), ("b", 2.0)])?;

        let picked = top_n_by_score(pool, 10);

        assert_eq!(picked.len(), 2);
        Ok(())
    }

    #[test]
    fn top_n_with_zero_returns_empty() -> Result<(), &'static str> {
        let pool = candidates(&[("a", 1.0)])?;

        let picked = top_n_by_score(pool, 0);

        assert!(picked.is_empty());
        Ok(())
    }

    #[test]
    fn top_n_empty_input() {
        let picked = top_n_by_score(Vec::<ScoredCandidate<&str>>::new(), 5);

        assert!(picked.is_empty());
    }
}