use std::sync::Arc;
use crate::model::{SharedEmbeddingModel, SharedModel};
use crate::model::types::{ChatRequest, Message};
pub enum Scorer {
ExactMatch(String),
Contains(String),
Regex(String),
Custom(Arc<dyn Fn(&str) -> bool + Send + Sync>),
SemanticSimilarity {
expected: String,
embedding_model: SharedEmbeddingModel,
threshold: f64,
},
LlmJudge {
rubric: String,
model: SharedModel,
},
}
impl Scorer {
pub fn semantic(
expected: impl Into<String>,
embedding_model: SharedEmbeddingModel,
threshold: f64,
) -> Self {
Scorer::SemanticSimilarity {
expected: expected.into(),
embedding_model,
threshold,
}
}
pub fn llm_judge(rubric: impl Into<String>, model: SharedModel) -> Self {
Scorer::LlmJudge {
rubric: rubric.into(),
model,
}
}
pub async fn evaluate(&self, output: &str) -> bool {
match self {
Scorer::ExactMatch(expected) => output == expected,
Scorer::Contains(substring) => output.contains(substring.as_str()),
Scorer::Regex(pattern) => regex_lite::Regex::new(pattern)
.map(|re| re.is_match(output))
.unwrap_or(false),
Scorer::Custom(f) => f(output),
Scorer::SemanticSimilarity {
expected,
embedding_model,
threshold,
} => {
let texts = [expected.as_str(), output];
match embedding_model.embed_erased(&texts).await {
Ok(embeddings) if embeddings.len() == 2 => {
let sim = cosine_similarity(&embeddings[0], &embeddings[1]);
sim as f64 >= *threshold
}
_ => false,
}
}
Scorer::LlmJudge { rubric, model } => {
let prompt = format!(
"You are an evaluation judge. Grade the following output against the rubric.\n\n\
Rubric: {rubric}\n\n\
Output to evaluate:\n{output}\n\n\
Respond with EXACTLY one word: PASS or FAIL."
);
let request = ChatRequest::new(vec![Message::user(&prompt)]);
match model.generate_erased(&request).await {
Ok(response) => response.text().trim().to_uppercase().starts_with("PASS"),
Err(_) => false,
}
}
}
}
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() || a.is_empty() {
return 0.0;
}
let mut dot = 0.0f32;
let mut norm_a = 0.0f32;
let mut norm_b = 0.0f32;
for (x, y) in a.iter().zip(b.iter()) {
dot += x * y;
norm_a += x * x;
norm_b += y * y;
}
let denom = norm_a.sqrt() * norm_b.sqrt();
if denom == 0.0 { 0.0 } else { dot / denom }
}
impl std::fmt::Debug for Scorer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Scorer::ExactMatch(s) => write!(f, "ExactMatch({s:?})"),
Scorer::Contains(s) => write!(f, "Contains({s:?})"),
Scorer::Regex(s) => write!(f, "Regex({s:?})"),
Scorer::Custom(_) => write!(f, "Custom(...)"),
Scorer::SemanticSimilarity { expected, threshold, .. } => {
write!(f, "SemanticSimilarity({expected:?}, threshold={threshold})")
}
Scorer::LlmJudge { rubric, .. } => write!(f, "LlmJudge({rubric:?})"),
}
}
}
impl Clone for Scorer {
fn clone(&self) -> Self {
match self {
Scorer::ExactMatch(s) => Scorer::ExactMatch(s.clone()),
Scorer::Contains(s) => Scorer::Contains(s.clone()),
Scorer::Regex(s) => Scorer::Regex(s.clone()),
Scorer::Custom(f) => Scorer::Custom(Arc::clone(f)),
Scorer::SemanticSimilarity { expected, embedding_model, threshold } => {
Scorer::SemanticSimilarity {
expected: expected.clone(),
embedding_model: Arc::clone(embedding_model),
threshold: *threshold,
}
}
Scorer::LlmJudge { rubric, model } => Scorer::LlmJudge {
rubric: rubric.clone(),
model: Arc::clone(model),
},
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_exact_match() {
assert!(Scorer::ExactMatch("hello".into()).evaluate("hello").await);
assert!(!Scorer::ExactMatch("hello".into()).evaluate("Hello").await);
}
#[tokio::test]
async fn test_contains() {
assert!(Scorer::Contains("world".into()).evaluate("hello world").await);
assert!(!Scorer::Contains("xyz".into()).evaluate("hello world").await);
}
#[tokio::test]
async fn test_regex() {
assert!(Scorer::Regex(r"\d+".into()).evaluate("answer is 42").await);
assert!(!Scorer::Regex(r"^\d+$".into()).evaluate("answer is 42").await);
}
#[tokio::test]
async fn test_custom() {
let scorer = Scorer::Custom(Arc::new(|s| s.len() > 5));
assert!(scorer.evaluate("long enough").await);
assert!(!scorer.evaluate("short").await);
}
}