pub(crate) mod average_precision;
pub(crate) mod bpref;
pub(crate) mod f1;
pub(crate) mod hits;
pub(crate) mod ndcg;
pub(crate) mod precision;
pub(crate) mod r_precision;
pub(crate) mod recall;
pub(crate) mod reciprocal_rank;
pub(crate) mod success;
use std::collections::BTreeMap;
use std::fmt::Display;
use std::str::FromStr;
use regex::Regex;
use crate::errors::ElinorError;
use crate::PredRelStore;
use crate::TrueRelStore;
use crate::TrueScore;
/// Relevance level handed to the binary metrics (hits, success, precision, recall, F1,
/// R-precision, AP, RR, bpref) by `compute_metric`.
///
/// NOTE(review): presumably a document counts as relevant when its true score is at
/// least this value; the comparison itself lives in the metric submodules, so confirm there.
pub(crate) const RELEVANT_LEVEL: TrueScore = 1;
/// Evaluation metrics that `compute_metric` can calculate.
///
/// Every metric has a textual form, either `name` or `name@k`, which is what the
/// `Display` implementation prints and the `FromStr` implementation accepts.
///
/// For the variants carrying `k`, `k` is the rank cutoff applied to the predicted
/// ranking. `k == 0` means "no cutoff": the whole ranking is evaluated and the
/// textual form has no `@k` suffix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Metric {
    /// Hits at `k`; textual name `hits`.
    Hits { k: usize },

    /// Success at `k`; textual name `success`.
    Success { k: usize },

    /// Precision at `k`; textual name `precision`.
    Precision { k: usize },

    /// Recall at `k`; textual name `recall`.
    Recall { k: usize },

    /// F1 score at `k`; textual name `f1`.
    F1 { k: usize },

    /// R-precision; textual name `r_precision`. Takes no cutoff.
    RPrecision,

    /// Average precision at `k`; textual name `ap`.
    AP { k: usize },

    /// Reciprocal rank at `k`; textual name `rr`.
    RR { k: usize },

    /// Bpref; textual name `bpref`. Takes no cutoff.
    Bpref,

    /// Discounted cumulative gain at `k` with the Jarvelin weighting; textual name `dcg`.
    DCG { k: usize },

    /// Normalized DCG at `k` with the Jarvelin weighting; textual name `ndcg`.
    NDCG { k: usize },

    /// Discounted cumulative gain at `k` with the Burges weighting; textual name
    /// `dcg_burges`.
    DCGBurges { k: usize },

    /// Normalized DCG at `k` with the Burges weighting; textual name `ndcg_burges`.
    NDCGBurges { k: usize },
}
impl Display for Metric {
    /// Writes the textual form of the metric: `name@k` for metrics with a non-zero
    /// cutoff, and just `name` for a zero cutoff or for metrics without a cutoff.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Resolve the textual name and the optional cutoff first, then emit once.
        let (name, cutoff) = match *self {
            Self::Hits { k } => ("hits", Some(k)),
            Self::Success { k } => ("success", Some(k)),
            Self::Precision { k } => ("precision", Some(k)),
            Self::Recall { k } => ("recall", Some(k)),
            Self::F1 { k } => ("f1", Some(k)),
            Self::RPrecision => ("r_precision", None),
            Self::AP { k } => ("ap", Some(k)),
            Self::RR { k } => ("rr", Some(k)),
            Self::Bpref => ("bpref", None),
            Self::DCG { k } => ("dcg", Some(k)),
            Self::NDCG { k } => ("ndcg", Some(k)),
            Self::DCGBurges { k } => ("dcg_burges", Some(k)),
            Self::NDCGBurges { k } => ("ndcg_burges", Some(k)),
        };
        match cutoff {
            Some(k) => f.write_str(&format_metric(name, k)),
            None => f.write_str(name),
        }
    }
}
/// Builds the textual form of a metric from its name and cutoff.
///
/// Returns `name` unchanged when `k` is 0, and `name@k` otherwise.
fn format_metric(name: &str, k: usize) -> String {
    match k {
        0 => name.to_owned(),
        cutoff => format!("{name}@{cutoff}"),
    }
}
impl FromStr for Metric {
    type Err = ElinorError;

    /// Parses a metric from its textual form, `name` or `name@k`.
    ///
    /// `name` must be one of the identifiers matched below and `k` a decimal integer.
    /// When the `@k` suffix is omitted, `k` is 0. For `r_precision` and `bpref`, which
    /// carry no cutoff, a supplied `@k` is accepted and discarded.
    ///
    /// # Errors
    ///
    /// Returns `ElinorError::InvalidFormat` when the input does not have the
    /// `name[@k]` shape, when `k` does not fit in `usize`, or when `name` is not a
    /// known metric.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // The pattern is a constant, so compiling it cannot fail.
        let pattern = Regex::new(r"^(?<metric>[a-z1-9_]+)(@(?<k>\d+))?$").unwrap();
        let caps = match pattern.captures(s) {
            Some(caps) => caps,
            None => {
                return Err(ElinorError::InvalidFormat(format!(
                    "Unsupported metric: {s}"
                )))
            }
        };

        // The cutoff is optional; a missing `@k` is the same as `@0`.
        let k = match caps.name("k") {
            Some(digits) => digits
                .as_str()
                .parse::<usize>()
                .map_err(|_| ElinorError::InvalidFormat(s.to_string()))?,
            None => 0,
        };

        // The `metric` group always participates in a successful match.
        let metric = match &caps["metric"] {
            "hits" => Self::Hits { k },
            "success" => Self::Success { k },
            "precision" => Self::Precision { k },
            "recall" => Self::Recall { k },
            "f1" => Self::F1 { k },
            "r_precision" => Self::RPrecision,
            "ap" => Self::AP { k },
            "rr" => Self::RR { k },
            "bpref" => Self::Bpref,
            "dcg" => Self::DCG { k },
            "ndcg" => Self::NDCG { k },
            "dcg_burges" => Self::DCGBurges { k },
            "ndcg_burges" => Self::NDCGBurges { k },
            _ => return Err(ElinorError::InvalidFormat(s.to_string())),
        };
        Ok(metric)
    }
}
pub fn compute_metric<K>(
true_rels: &TrueRelStore<K>,
pred_rels: &PredRelStore<K>,
metric: Metric,
) -> Result<BTreeMap<K, f64>, ElinorError>
where
K: Clone + Eq + Ord + std::fmt::Display,
{
for query_id in pred_rels.query_ids() {
if true_rels.get_map(query_id).is_none() {
return Err(ElinorError::MissingEntry(format!(
"The set of queries in true_rels must be a subset of that in pred_rels, but {} is missing",
query_id
)));
}
}
let mut results = BTreeMap::new();
for query_id in pred_rels.query_ids() {
let sorted_preds = pred_rels.get_sorted(query_id).unwrap();
let trues = true_rels.get_map(query_id).unwrap();
let score = match metric {
Metric::Hits { k } => hits::compute_hits(trues, sorted_preds, k, RELEVANT_LEVEL),
Metric::Success { k } => {
success::compute_success(trues, sorted_preds, k, RELEVANT_LEVEL)
}
Metric::Precision { k } => {
precision::compute_precision(trues, sorted_preds, k, RELEVANT_LEVEL)
}
Metric::Recall { k } => recall::compute_recall(trues, sorted_preds, k, RELEVANT_LEVEL),
Metric::F1 { k } => f1::compute_f1(trues, sorted_preds, k, RELEVANT_LEVEL),
Metric::RPrecision => {
r_precision::compute_r_precision(trues, sorted_preds, RELEVANT_LEVEL)
}
Metric::AP { k } => {
average_precision::compute_average_precision(trues, sorted_preds, k, RELEVANT_LEVEL)
}
Metric::RR { k } => {
reciprocal_rank::compute_reciprocal_rank(trues, sorted_preds, k, RELEVANT_LEVEL)
}
Metric::Bpref => bpref::compute_bpref(trues, sorted_preds, RELEVANT_LEVEL),
Metric::DCG { k } => {
ndcg::compute_dcg(trues, sorted_preds, k, ndcg::DcgWeighting::Jarvelin)
}
Metric::NDCG { k } => {
let sorted_trues = true_rels.get_sorted(query_id).unwrap();
ndcg::compute_ndcg(
trues,
sorted_trues,
sorted_preds,
k,
ndcg::DcgWeighting::Jarvelin,
)
}
Metric::DCGBurges { k } => {
ndcg::compute_dcg(trues, sorted_preds, k, ndcg::DcgWeighting::Burges)
}
Metric::NDCGBurges { k } => {
let sorted_trues = true_rels.get_sorted(query_id).unwrap();
ndcg::compute_ndcg(
trues,
sorted_trues,
sorted_preds,
k,
ndcg::DcgWeighting::Burges,
)
}
};
results.insert(query_id.clone(), score);
}
Ok(results)
}
// Unit tests for `compute_metric` and for parsing `Metric` from its textual form.
#[cfg(test)]
mod tests {
use super::*;
use crate::Record;
use approx::assert_relative_eq;
use maplit::btreemap;
use rstest::*;
// Precomputed base-2 logarithms, used as the rank discounts in the expected
// DCG/NDCG values below.
const LOG_2_2: f64 = 1.0;
const LOG_2_3: f64 = 1.584962500721156;
const LOG_2_4: f64 = 2.0;
// Asserts that both maps have the same number of entries and that every value in
// `a` is approximately equal to the value stored under the same key in `b`.
// Panics if a key of `a` is absent from `b`.
fn compare_hashmaps(a: &BTreeMap<char, f64>, b: &BTreeMap<char, f64>) {
assert_eq!(a.len(), b.len());
for (k, v) in a.iter() {
assert_relative_eq!(v, b.get(k).unwrap());
}
}
// Each case pairs a metric with the expected per-query score for the fixture built
// in the test body. The `k: 0` cases expect the same score as evaluating the whole
// four-document ranking.
#[rstest]
#[case::hits_k_0(Metric::Hits { k: 0 }, btreemap! { 'A' => 2.0 })]
#[case::hits_k_1(Metric::Hits { k: 1 }, btreemap! { 'A' => 1.0 })]
#[case::hits_k_2(Metric::Hits { k: 2 }, btreemap! { 'A' => 1.0 })]
#[case::hits_k_3(Metric::Hits { k: 3 }, btreemap! { 'A' => 2.0 })]
#[case::hits_k_4(Metric::Hits { k: 4 }, btreemap! { 'A' => 2.0 })]
#[case::hits_k_5(Metric::Hits { k: 5 }, btreemap! { 'A' => 2.0 })]
#[case::success_k_0(Metric::Success { k: 0 }, btreemap! { 'A' => 1.0 })]
#[case::success_k_1(Metric::Success { k: 1 }, btreemap! { 'A' => 1.0 })]
#[case::success_k_2(Metric::Success { k: 2 }, btreemap! { 'A' => 1.0 })]
#[case::success_k_3(Metric::Success { k: 3 }, btreemap! { 'A' => 1.0 })]
#[case::success_k_4(Metric::Success { k: 4 }, btreemap! { 'A' => 1.0 })]
#[case::success_k_5(Metric::Success { k: 5 }, btreemap! { 'A' => 1.0 })]
#[case::precision_k_0(Metric::Precision { k: 0 }, btreemap! { 'A' => 2.0 / 4.0 })]
#[case::precision_k_1(Metric::Precision { k: 1 }, btreemap! { 'A' => 1.0 / 1.0 })]
#[case::precision_k_2(Metric::Precision { k: 2 }, btreemap! { 'A' => 1.0 / 2.0 })]
#[case::precision_k_3(Metric::Precision { k: 3 }, btreemap! { 'A' => 2.0 / 3.0 })]
#[case::precision_k_4(Metric::Precision { k: 4 }, btreemap! { 'A' => 2.0 / 4.0 })]
#[case::precision_k_5(Metric::Precision { k: 5 }, btreemap! { 'A' => 2.0 / 5.0 })]
#[case::recall_k_0(Metric::Recall { k: 0 }, btreemap! { 'A' => 2.0 / 2.0 })]
#[case::recall_k_1(Metric::Recall { k: 1 }, btreemap! { 'A' => 1.0 / 2.0 })]
#[case::recall_k_2(Metric::Recall { k: 2 }, btreemap! { 'A' => 1.0 / 2.0 })]
#[case::recall_k_3(Metric::Recall { k: 3 }, btreemap! { 'A' => 2.0 / 2.0 })]
#[case::recall_k_4(Metric::Recall { k: 4 }, btreemap! { 'A' => 2.0 / 2.0 })]
#[case::recall_k_5(Metric::Recall { k: 5 }, btreemap! { 'A' => 2.0 / 2.0 })]
#[case::f1_k_0(Metric::F1 { k: 0 }, btreemap! { 'A' => 2.0 * (2.0 / 4.0) * (2.0 / 2.0) / ((2.0 / 4.0) + (2.0 / 2.0)) })]
#[case::f1_k_1(Metric::F1 { k: 1 }, btreemap! { 'A' => 2.0 * (1.0 / 1.0) * (1.0 / 2.0) / ((1.0 / 1.0) + (1.0 / 2.0)) })]
#[case::f1_k_2(Metric::F1 { k: 2 }, btreemap! { 'A' => 2.0 * (1.0 / 2.0) * (1.0 / 2.0) / ((1.0 / 2.0) + (1.0 / 2.0)) })]
#[case::f1_k_3(Metric::F1 { k: 3 }, btreemap! { 'A' => 2.0 * (2.0 / 3.0) * (2.0 / 2.0) / ((2.0 / 3.0) + (2.0 / 2.0)) })]
#[case::f1_k_4(Metric::F1 { k: 4 }, btreemap! { 'A' => 2.0 * (2.0 / 4.0) * (2.0 / 2.0) / ((2.0 / 4.0) + (2.0 / 2.0)) })]
#[case::f1_k_5(Metric::F1 { k: 5 }, btreemap! { 'A' => 2.0 * (2.0 / 5.0) * (2.0 / 2.0) / ((2.0 / 5.0) + (2.0 / 2.0)) })]
#[case::r_precision(Metric::RPrecision, btreemap! { 'A' => 1.0 / 2.0 })]
#[case::average_precision_k_0(Metric::AP { k: 0 }, btreemap! { 'A' => ((1.0 / 1.0) + (2.0 / 3.0)) / 2.0 })]
#[case::average_precision_k_1(Metric::AP { k: 1 }, btreemap! { 'A' => (1.0 / 1.0) / 2.0 })]
#[case::average_precision_k_2(Metric::AP { k: 2 }, btreemap! { 'A' => (1.0 / 1.0) / 2.0 })]
#[case::average_precision_k_3(Metric::AP { k: 3 }, btreemap! { 'A' => ((1.0 / 1.0) + (2.0 / 3.0)) / 2.0 })]
#[case::average_precision_k_4(Metric::AP { k: 4 }, btreemap! { 'A' => ((1.0 / 1.0) + (2.0 / 3.0)) / 2.0 })]
#[case::average_precision_k_5(Metric::AP { k: 5 }, btreemap! { 'A' => ((1.0 / 1.0) + (2.0 / 3.0)) / 2.0 })]
#[case::reciprocal_rank_k_0(Metric::RR { k: 0 }, btreemap! { 'A' => 1.0 / 1.0 })]
#[case::reciprocal_rank_k_1(Metric::RR { k: 1 }, btreemap! { 'A' => 1.0 / 1.0 })]
#[case::reciprocal_rank_k_2(Metric::RR { k: 2 }, btreemap! { 'A' => 1.0 / 1.0 })]
#[case::reciprocal_rank_k_3(Metric::RR { k: 3 }, btreemap! { 'A' => 1.0 / 1.0 })]
#[case::reciprocal_rank_k_4(Metric::RR { k: 4 }, btreemap! { 'A' => 1.0 / 1.0 })]
#[case::reciprocal_rank_k_5(Metric::RR { k: 5 }, btreemap! { 'A' => 1.0 / 1.0 })]
#[case::bpref(Metric::Bpref, btreemap! { 'A' => (1.0 + (1.0 - 1.0 / 1.0)) / 2.0 })]
#[case::dcg_k_0_jarvelin(Metric::DCG { k: 0 }, btreemap! { 'A' => 1.0 / LOG_2_2 + 2.0 / LOG_2_4 })]
#[case::dcg_k_1_jarvelin(Metric::DCG { k: 1 }, btreemap! { 'A' => 1.0 / LOG_2_2 })]
#[case::dcg_k_2_jarvelin(Metric::DCG { k: 2 }, btreemap! { 'A' => 1.0 / LOG_2_2 })]
#[case::dcg_k_3_jarvelin(Metric::DCG { k: 3 }, btreemap! { 'A' => 1.0 / LOG_2_2 + 2.0 / LOG_2_4 })]
#[case::dcg_k_4_jarvelin(Metric::DCG { k: 4 }, btreemap! { 'A' => 1.0 / LOG_2_2 + 2.0 / LOG_2_4 })]
#[case::dcg_k_5_jarvelin(Metric::DCG { k: 5 }, btreemap! { 'A' => 1.0 / LOG_2_2 + 2.0 / LOG_2_4 })]
#[case::ndcg_k_0_jarvelin(Metric::NDCG { k: 0 }, btreemap! { 'A' => (1.0 / LOG_2_2 + 2.0 / LOG_2_4) / (2.0 / LOG_2_2 + 1.0 / LOG_2_3) })]
#[case::ndcg_k_1_jarvelin(Metric::NDCG { k: 1 }, btreemap! { 'A' => (1.0 / LOG_2_2) / (2.0 / LOG_2_2) })]
#[case::ndcg_k_2_jarvelin(Metric::NDCG { k: 2 }, btreemap! { 'A' => (1.0 / LOG_2_2) / (2.0 / LOG_2_2 + 1.0 / LOG_2_3) })]
#[case::ndcg_k_3_jarvelin(Metric::NDCG { k: 3 }, btreemap! { 'A' => (1.0 / LOG_2_2 + 2.0 / LOG_2_4) / (2.0 / LOG_2_2 + 1.0 / LOG_2_3) })]
#[case::ndcg_k_4_jarvelin(Metric::NDCG { k: 4 }, btreemap! { 'A' => (1.0 / LOG_2_2 + 2.0 / LOG_2_4) / (2.0 / LOG_2_2 + 1.0 / LOG_2_3) })]
#[case::ndcg_k_5_jarvelin(Metric::NDCG { k: 5 }, btreemap! { 'A' => (1.0 / LOG_2_2 + 2.0 / LOG_2_4) / (2.0 / LOG_2_2 + 1.0 / LOG_2_3) })]
#[case::dcg_k_0_burges(Metric::DCGBurges { k: 0 }, btreemap! { 'A' => 1.0 / LOG_2_2 + 3.0 / LOG_2_4 })]
#[case::dcg_k_1_burges(Metric::DCGBurges { k: 1 }, btreemap! { 'A' => 1.0 / LOG_2_2 })]
#[case::dcg_k_2_burges(Metric::DCGBurges { k: 2 }, btreemap! { 'A' => 1.0 / LOG_2_2 })]
#[case::dcg_k_3_burges(Metric::DCGBurges { k: 3 }, btreemap! { 'A' => 1.0 / LOG_2_2 + 3.0 / LOG_2_4 })]
#[case::dcg_k_4_burges(Metric::DCGBurges { k: 4 }, btreemap! { 'A' => 1.0 / LOG_2_2 + 3.0 / LOG_2_4 })]
#[case::dcg_k_5_burges(Metric::DCGBurges { k: 5 }, btreemap! { 'A' => 1.0 / LOG_2_2 + 3.0 / LOG_2_4 })]
#[case::ndcg_k_0_burges(Metric::NDCGBurges { k: 0 }, btreemap! { 'A' => (1.0 / LOG_2_2 + 3.0 / LOG_2_4) / (3.0 / LOG_2_2 + 1.0 / LOG_2_3) })]
#[case::ndcg_k_1_burges(Metric::NDCGBurges { k: 1 }, btreemap! { 'A' => (1.0 / LOG_2_2) / (3.0 / LOG_2_2) })]
#[case::ndcg_k_2_burges(Metric::NDCGBurges { k: 2 }, btreemap! { 'A' => (1.0 / LOG_2_2) / (3.0 / LOG_2_2 + 1.0 / LOG_2_3) })]
#[case::ndcg_k_3_burges(Metric::NDCGBurges { k: 3 }, btreemap! { 'A' => (1.0 / LOG_2_2 + 3.0 / LOG_2_4) / (3.0 / LOG_2_2 + 1.0 / LOG_2_3) })]
#[case::ndcg_k_4_burges(Metric::NDCGBurges { k: 4 }, btreemap! { 'A' => (1.0 / LOG_2_2 + 3.0 / LOG_2_4) / (3.0 / LOG_2_2 + 1.0 / LOG_2_3) })]
#[case::ndcg_k_5_burges(Metric::NDCGBurges { k: 5 }, btreemap! { 'A' => (1.0 / LOG_2_2 + 3.0 / LOG_2_4) / (3.0 / LOG_2_2 + 1.0 / LOG_2_3) })]
fn test_compute_metric(#[case] metric: Metric, #[case] expected: BTreeMap<char, f64>) {
// Fixture: a single query 'A' with true scores X = 1, Y = 0, Z = 2.
let true_rels = TrueRelStore::from_records([
Record {
query_id: 'A',
doc_id: 'X',
score: 1,
},
Record {
query_id: 'A',
doc_id: 'Y',
score: 0,
},
Record {
query_id: 'A',
doc_id: 'Z',
score: 2,
},
])
.unwrap();
// Predicted scores X = 0.5, Y = 0.4, Z = 0.3, W = 0.2; 'W' has no true score.
// The expected values above correspond to the ranking X, Y, Z, W.
let pred_rels = PredRelStore::from_records([
Record {
query_id: 'A',
doc_id: 'X',
score: 0.5.into(),
},
Record {
query_id: 'A',
doc_id: 'Y',
score: 0.4.into(),
},
Record {
query_id: 'A',
doc_id: 'Z',
score: 0.3.into(),
},
Record {
query_id: 'A',
doc_id: 'W',
score: 0.2.into(),
},
])
.unwrap();
let results = compute_metric(&true_rels, &pred_rels, metric).unwrap();
compare_hashmaps(&results, &expected);
}
// Each case pairs a textual metric with the variant it must parse to. A missing
// `@k` suffix and an explicit `@0` both map to `k: 0`.
#[rstest]
#[case::hits("hits", Metric::Hits { k: 0 })]
#[case::hits_k0("hits@0", Metric::Hits { k: 0 })]
#[case::hits_k1("hits@1", Metric::Hits { k: 1 })]
#[case::hits_k100("hits@100", Metric::Hits { k: 100 })]
#[case::success("success", Metric::Success { k: 0 })]
#[case::success_k0("success@0", Metric::Success { k: 0 })]
#[case::success_k1("success@1", Metric::Success { k: 1 })]
#[case::success_k100("success@100", Metric::Success { k: 100 })]
#[case::precision("precision", Metric::Precision { k: 0 })]
#[case::precision_k0("precision@0", Metric::Precision { k: 0 })]
#[case::precision_k1("precision@1", Metric::Precision { k: 1 })]
#[case::precision_k100("precision@100", Metric::Precision { k: 100 })]
#[case::recall("recall", Metric::Recall { k: 0 })]
#[case::recall_k0("recall@0", Metric::Recall { k: 0 })]
#[case::recall_k1("recall@1", Metric::Recall { k: 1 })]
#[case::recall_k100("recall@100", Metric::Recall { k: 100 })]
#[case::f1("f1", Metric::F1 { k: 0 })]
#[case::f1_k0("f1@0", Metric::F1 { k: 0 })]
#[case::f1_k1("f1@1", Metric::F1 { k: 1 })]
#[case::f1_k100("f1@100", Metric::F1 { k: 100 })]
#[case::r_precision("r_precision", Metric::RPrecision)]
#[case::average_precision("ap", Metric::AP { k: 0 })]
#[case::average_precision_k0("ap@0", Metric::AP { k: 0 })]
#[case::average_precision_k1("ap@1", Metric::AP { k: 1 })]
#[case::average_precision_k100("ap@100", Metric::AP { k: 100 })]
#[case::reciprocal_rank("rr", Metric::RR { k: 0 })]
#[case::reciprocal_rank_k0("rr@0", Metric::RR { k: 0 })]
#[case::reciprocal_rank_k1("rr@1", Metric::RR { k: 1 })]
#[case::reciprocal_rank_k100("rr@100", Metric::RR { k: 100 })]
#[case::bpref("bpref", Metric::Bpref)]
#[case::dcg("dcg", Metric::DCG { k: 0 })]
#[case::dcg_k0("dcg@0", Metric::DCG { k: 0 })]
#[case::dcg_k1("dcg@1", Metric::DCG { k: 1 })]
#[case::dcg_k100("dcg@100", Metric::DCG { k: 100 })]
#[case::ndcg("ndcg", Metric::NDCG { k: 0 })]
#[case::ndcg_k0("ndcg@0", Metric::NDCG { k: 0 })]
#[case::ndcg_k1("ndcg@1", Metric::NDCG { k: 1 })]
#[case::ndcg_k100("ndcg@100", Metric::NDCG { k: 100 })]
#[case::dcg_burges("dcg_burges", Metric::DCGBurges { k: 0 })]
#[case::dcg_burges_k0("dcg_burges@0", Metric::DCGBurges { k: 0 })]
#[case::dcg_burges_k1("dcg_burges@1", Metric::DCGBurges { k: 1 })]
#[case::dcg_burges_k100("dcg_burges@100", Metric::DCGBurges { k: 100 })]
#[case::ndcg_burges("ndcg_burges", Metric::NDCGBurges { k: 0 })]
#[case::ndcg_burges_k0("ndcg_burges@0", Metric::NDCGBurges { k: 0 })]
#[case::ndcg_burges_k1("ndcg_burges@1", Metric::NDCGBurges { k: 1 })]
#[case::ndcg_burges_k100("ndcg_burges@100", Metric::NDCGBurges { k: 100 })]
fn test_metric_from_str(#[case] input: &str, #[case] expected: Metric) {
let metric = Metric::from_str(input).unwrap();
assert_eq!(metric, expected);
}
}