use comfy_table::Cell;
use comfy_table::Row;
use comfy_table::Table;
use hdrhistogram::Histogram;
use std::fmt::Display;
use std::future::Future;
use std::ops::RangeInclusive;
use std::sync::OnceLock;
#[cfg(feature = "bert")]
use kalosm_language::prelude::Bert;
#[cfg(feature = "bert")]
use kalosm_language::prelude::Embedder;
/// A scoring metric between two values of type `T`, used to grade test cases.
///
/// Implementors report scores inside `RANGE` (inclusive); the default range is
/// `0.0..=1.0`. `distance` takes `&mut self` so implementations may reuse
/// internal state (e.g. a loaded model) between calls.
pub trait Metric<T> {
    /// The inclusive interval that [`Metric::distance`] returns values in.
    const RANGE: RangeInclusive<f64> = 0.0..=1.0;
    /// Asynchronously compute the score between `first` and `other`.
    fn distance(&mut self, first: &T, other: &T) -> impl Future<Output = f64> + Send;
}
/// A [`Metric`] over string-like values that embeds both inputs with a Bert
/// model and scores them by cosine similarity of the embeddings.
#[cfg(feature = "bert")]
pub struct BertDistance {
    // The Bert embedder used to encode both inputs.
    bert: Bert,
}
#[cfg(feature = "bert")]
impl BertDistance {
    /// Wrap an existing [`Bert`] model as a distance metric.
    pub fn new(model: Bert) -> Self {
        Self { bert: model }
    }
}
#[cfg(feature = "bert")]
impl<S: ToString + Send + Sync> Metric<S> for BertDistance {
    /// Embed both strings in a single Bert batch and return the cosine
    /// similarity between the two resulting embeddings.
    async fn distance(&mut self, first: &S, other: &S) -> f64 {
        let inputs = vec![first.to_string(), other.to_string()];
        let batch = self
            .bert
            .embed_vec(inputs)
            .await
            .expect("Failed to embed text with Bert");
        // A two-element input batch must yield exactly two embeddings.
        let [lhs, rhs]: [_; 2] = batch
            .try_into()
            .expect("Failed to get two embeddings from the batch of two input texts from Bert");
        lhs.cosine_similarity(&rhs).into()
    }
}
/// A named collection of (expected, actual) pairs to be scored by a [`Metric`].
pub struct TestCases<I> {
    // Display name shown when printing results; defaults to the source
    // location where the set was created (see `TestCases::new`).
    name: String,
    // The accumulated test cases.
    tests: Vec<TestCase<I>>,
}
impl<I> Default for TestCases<I> {
#[track_caller]
fn default() -> Self {
Self::new()
}
}
impl<I> TestCases<I> {
    /// Create an empty set of test cases named after the caller's source location.
    #[track_caller]
    pub fn new() -> Self {
        TestCases {
            name: std::panic::Location::caller().to_string(),
            tests: Vec::new(),
        }
    }

    /// Builder-style: replace the display name used when printing results.
    pub fn with_name(mut self, name: impl Display) -> Self {
        self.name = name.to_string();
        self
    }

    /// Builder-style: append a single (expected, actual) pair.
    pub fn with_case(mut self, expected: I, actual: I) -> Self {
        self.tests.push(TestCase { expected, actual });
        self
    }

    /// Append a single (expected, actual) pair in place.
    pub fn push_case(&mut self, expected: I, actual: I) {
        self.tests.push(TestCase { expected, actual });
    }

    /// Score every case with `metric` and collect the results.
    ///
    /// The returned [`EvaluationResult`] borrows the cases and holds them
    /// sorted by ascending score.
    pub async fn evaluate<M: Metric<I>>(&mut self, metric: &mut M) -> EvaluationResult<'_, I> {
        let mut values = Vec::with_capacity(self.tests.len());
        for case in &self.tests {
            let TestCase { expected, actual } = case;
            let distance = metric.distance(expected, actual).await;
            values.push(TestCaseScored {
                case,
                score: distance,
            });
        }
        // `total_cmp` provides a total order over f64, so a metric that
        // produces NaN no longer panics (partial_cmp(..).unwrap() did).
        values.sort_by(|a, b| a.score.total_cmp(&b.score));
        EvaluationResult {
            name: self.name.clone(),
            histogram: OnceLock::new(),
            tests: values,
            range: M::RANGE,
        }
    }
}
/// The scored outcome of [`TestCases::evaluate`].
#[derive(Clone)]
pub struct EvaluationResult<'a, I> {
    // Display name inherited from the `TestCases` that produced this result.
    name: String,
    // Lazily-built histogram of scores, scaled into integer units
    // (see `EvaluationResult::histogram`).
    histogram: OnceLock<Histogram<u64>>,
    // Scored cases, sorted by ascending score.
    tests: Vec<TestCaseScored<'a, I>>,
    // The inclusive range the metric reports scores in.
    range: RangeInclusive<f64>,
}
impl<I> EvaluationResult<'_, I> {
    /// Scores are stored in an integer histogram after multiplying by this factor.
    const SCALE_FACTOR: f64 = 10000.0;

    /// Factor that maps the metric's range span onto `0..=SCALE_FACTOR`.
    fn histogram_scale_factor(&self) -> f64 {
        Self::SCALE_FACTOR / (self.range.end() - self.range.start())
    }

    /// Map a raw score into histogram (integer) units.
    fn scale_value(&self, value: f64) -> f64 {
        (value - self.range.start()) * self.histogram_scale_factor()
    }

    /// Map a histogram-unit value back into the metric's range.
    fn unscale_value(&self, value: f64) -> f64 {
        value / self.histogram_scale_factor() + self.range.start()
    }

    /// Lazily build (and cache) the histogram of scaled scores.
    fn histogram(&self) -> &Histogram<u64> {
        self.histogram.get_or_init(|| {
            let mut histogram = Histogram::<u64>::new(3).unwrap();
            for test in &self.tests {
                histogram
                    .record(self.scale_value(test.score) as u64)
                    .expect("Failed to record score");
            }
            histogram
        })
    }

    /// Average score across all cases, in the metric's range.
    pub fn mean_score(&self) -> f64 {
        self.unscale_value(self.histogram().mean())
    }

    /// Median (50th-percentile) score, in the metric's range.
    pub fn median_score(&self) -> f64 {
        self.quantile_score(0.5)
    }

    /// Lowest recorded score, in the metric's range.
    pub fn min_score(&self) -> f64 {
        self.unscale_value(self.histogram().min() as f64)
    }

    /// Highest recorded score, in the metric's range.
    pub fn max_score(&self) -> f64 {
        self.unscale_value(self.histogram().max() as f64)
    }

    /// Score at the given quantile (`0.0..=1.0`), in the metric's range.
    pub fn quantile_score(&self, quantile: f64) -> f64 {
        let scaled = self.histogram().value_at_percentile(quantile * 100.0);
        self.unscale_value(scaled as f64)
    }

    /// Map a score from the metric's range onto `0.0..=1.0`.
    pub fn normalize_score(&self, score: f64) -> f64 {
        let (min, max) = (self.range.start(), self.range.end());
        (score - min) / (max - min)
    }

    /// Map a score from `0.0..=1.0` back onto the metric's range.
    pub fn denormalize_score(&self, score: f64) -> f64 {
        let (min, max) = (self.range.start(), self.range.end());
        score * (max - min) + min
    }

    /// Consume this result, rescaling every score onto `0.0..=1.0`.
    pub fn normalized(self) -> Self {
        let min = *self.range.start();
        let span = self.range.end() - min;
        let mut tests = self.tests;
        for test in &mut tests {
            test.score = (test.score - min) / span;
        }
        EvaluationResult {
            name: self.name,
            // The cached histogram no longer matches the rescaled scores,
            // so start with an empty cache.
            histogram: OnceLock::new(),
            tests,
            range: 0.0..=1.0,
        }
    }
}
impl<I: Display> std::fmt::Display for EvaluationResult<'_, I> {
    /// Render a summary-statistics table, a per-case table (grouped and
    /// colored by score bucket), and a ten-bucket ASCII score histogram.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let histogram = self.histogram();

        // Summary statistics table.
        let mut statistics = Table::new();
        statistics.set_header(vec!["Statistic", "Value"]);
        statistics.add_row(vec![
            Cell::new("Mean"),
            Cell::new(format!("{:.2}", self.mean_score())),
        ]);
        statistics.add_row(vec![
            Cell::new("Median"),
            Cell::new(format!("{:.2}", self.median_score())),
        ]);
        statistics.add_row(vec![
            Cell::new("Min"),
            Cell::new(format!("{:.2}", self.min_score())),
        ]);
        statistics.add_row(vec![
            Cell::new("Max"),
            Cell::new(format!("{:.2}", self.max_score())),
        ]);
        statistics.add_row(vec![
            Cell::new("25th Percentile"),
            Cell::new(format!("{:.2}", self.quantile_score(0.25))),
        ]);
        statistics.add_row(vec![
            Cell::new("75th Percentile"),
            Cell::new(format!("{:.2}", self.quantile_score(0.75))),
        ]);
        writeln!(f, "{}", statistics)?;

        // Per-case table: red = bottom third of the metric's range,
        // yellow = bottom half, green = everything else.
        let mut table = Table::new();
        table.set_header(vec!["Expected Output", "Actual Output", "Score"]);
        let bottom_third_of_metric =
            self.range.start() + (self.range.end() - self.range.start()) / 3.0;
        let bottom_half_of_metric =
            self.range.start() + (self.range.end() - self.range.start()) / 2.0;

        // Annotate scores in the lowest/highest 10% of the distribution.
        fn create_cell(score: f64, quantile: f64) -> Cell {
            if quantile <= 0.1 {
                Cell::new(format!("{:.2} (low outlier)", score))
            } else if quantile <= 0.9 {
                Cell::new(format!("{:.2}", score))
            } else {
                Cell::new(format!("{:.2} (high outlier)", score))
            }
        }

        let buckets = [
            (comfy_table::Color::Red, bottom_third_of_metric),
            (comfy_table::Color::Yellow, bottom_half_of_metric),
            (comfy_table::Color::Green, f64::INFINITY),
        ];

        // `self.tests` is sorted ascending, so each color bucket is a
        // contiguous run of this iterator.
        let mut test_iter = self.tests.iter().peekable();
        for (color, max) in buckets {
            let mut count = 0;
            while let Some(test) = test_iter.next_if(|test| test.score <= max) {
                // BUG FIX: the histogram records *scaled* scores (see
                // `EvaluationResult::histogram`), so the percentile lookup must
                // scale the same way. The previous `score * SCALE_FACTOR` was
                // only correct for the default 0.0..=1.0 metric range.
                let quantile =
                    histogram.percentile_below(self.scale_value(test.score) as u64) / 100.0;
                let score_cell = create_cell(test.score, quantile).fg(color);
                let mut row = Row::new();
                row.add_cell(Cell::new(&test.case.expected))
                    .add_cell(Cell::new(&test.case.actual))
                    .add_cell(score_cell);
                table.add_row(row);
                count += 1;
                // After five rows in a bucket, collapse the rest of that bucket
                // into a single "... N more" row showing the average score.
                if count >= 5 {
                    let mut remaining_matching_tests = 0;
                    let mut total_score = 0.0;
                    // BUG FIX: use `next_if` so the first element of the NEXT
                    // bucket is left on the iterator. The previous
                    // `for test in test_iter.by_ref()` consumed and discarded
                    // the first out-of-bucket element, dropping it from the
                    // following bucket's output entirely.
                    while let Some(test) = test_iter.next_if(|test| test.score <= max) {
                        total_score += test.score;
                        remaining_matching_tests += 1;
                    }
                    if remaining_matching_tests > 0 {
                        let mut row = Row::new();
                        row.add_cell(Cell::new(format!("... {} more", remaining_matching_tests)))
                            .add_cell(Cell::new(""))
                            .add_cell(
                                Cell::new(format!(
                                    "{:.2} (average)",
                                    total_score / remaining_matching_tests as f64
                                ))
                                .fg(color),
                            );
                        table.add_row(row);
                    }
                    break;
                }
            }
        }
        writeln!(f, "{}", table)?;

        // ASCII histogram: ten equal-width buckets over the normalized range.
        let mut buckets = [0; 10];
        for test in &self.tests {
            let normalized_score = self.normalize_score(test.score);
            let bucket = (normalized_score * 10.0) as usize;
            // A score exactly at the top of the range maps to index 10; clamp it.
            buckets[bucket.min(9)] += 1;
        }
        let max_width = *buckets.iter().max().unwrap();
        // Cap the widest bar at 50 characters.
        let scale_factor = if max_width > 50 {
            50.0 / max_width as f64
        } else {
            1.0
        };
        // Keep at least 3 columns so the header's `max_width - 3` can't underflow.
        let max_width = ((max_width as f64 * scale_factor) as usize).max(3);
        writeln!(f, "| Score Histogram {} |", " ".repeat(max_width - 3))?;
        for (i, bucket) in buckets.iter().enumerate() {
            let min_bucket = self.denormalize_score(i as f64 / 10.0);
            let max_bucket = self.denormalize_score((i + 1) as f64 / 10.0);
            let bucket = (*bucket as f64 * scale_factor) as usize;
            writeln!(
                f,
                "| {:.2} - {:.2}: {}{} |",
                min_bucket,
                max_bucket,
                "*".repeat(bucket),
                " ".repeat(max_width - bucket)
            )?;
        }
        Ok(())
    }
}
/// A single (expected, actual) pair awaiting scoring.
#[derive(Default, Clone)]
struct TestCase<I> {
    // The reference value the output is compared against.
    expected: I,
    // The value actually produced.
    actual: I,
}
/// A test case paired with the score a metric assigned to it.
#[derive(Clone)]
struct TestCaseScored<'a, I> {
    // The scored case, borrowed from the owning `TestCases`.
    case: &'a TestCase<I>,
    // The metric's score for this case.
    score: f64,
}