use std::collections::HashSet;
use jsonschema::Validator;
use crate::evaluator::Evaluator;
use crate::score::Score;
use crate::types::{EvalCase, EvalMetricResult, Invocation};
/// Signature of the per-key scoring callback held by `KeyStrategy::Rubric`.
///
/// `JsonMatchEvaluator::compare` invokes it with the key name, the expected value
/// for that key, and the actual value (`None` when the response is not a JSON
/// object or does not contain the key). The caller clamps the returned score to
/// `[0.0, 1.0]`.
type RubricScorer =
dyn Fn(&str, &serde_json::Value, Option<&serde_json::Value>) -> f64 + Send + Sync;
/// How `JsonMatchEvaluator` scores each expected key and folds the per-key
/// scores into a single value.
#[derive(Clone)]
pub enum KeyStrategy {
/// Each key scores 1.0 on exact equality, else 0.0; the result is the mean
/// over all compared keys.
Average,
/// Each key scores 1.0 on exact equality, else 0.0; the result is 1.0 only
/// when every compared key matched, otherwise 0.0.
All,
/// Each key scores 1.0 on exact equality, else 0.0; the result is 1.0 only
/// when no compared key matched, otherwise 0.0.
None,
/// Each key is scored by a caller-supplied function (clamped to
/// `[0.0, 1.0]`); the result is the mean over all compared keys.
Rubric {
/// Shared scoring callback, invoked as `scorer(key, expected, actual)`.
scorer: std::sync::Arc<RubricScorer>,
},
}
/// Hand-written `Debug` because the `Rubric` variant holds a closure, which has
/// no `Debug` representation of its own.
impl std::fmt::Debug for KeyStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Resolve the field-less variants to their name; the closure-carrying
        // variant is rendered separately with a placeholder for the callback.
        let variant_name = match self {
            Self::Rubric { .. } => {
                return f.debug_struct("Rubric").field("scorer", &"<fn>").finish();
            }
            Self::Average => "Average",
            Self::All => "All",
            Self::None => "None",
        };
        f.debug_tuple(variant_name).finish()
    }
}
/// Evaluator that parses the final response as JSON and compares it against a
/// fixed expected JSON value.
pub struct JsonMatchEvaluator {
// Name returned by `Evaluator::name` and copied into every result.
name: &'static str,
// Reference value; when it is a JSON object the comparison is done key by key.
expected: serde_json::Value,
// How per-key scores are computed and combined.
strategy: KeyStrategy,
// Keys of `expected` that are skipped during comparison.
exclude_keys: HashSet<String>,
}
impl JsonMatchEvaluator {
    /// Creates an evaluator named `"json_match"` that compares responses against
    /// `expected` using [`KeyStrategy::Average`] and no excluded keys.
    #[must_use]
    pub fn new(expected: serde_json::Value) -> Self {
        Self {
            name: "json_match",
            expected,
            strategy: KeyStrategy::Average,
            exclude_keys: HashSet::new(),
        }
    }

    /// Overrides the name reported by this evaluator.
    #[must_use]
    pub const fn with_name(mut self, name: &'static str) -> Self {
        self.name = name;
        self
    }

    /// Selects how per-key scores are computed and combined.
    #[must_use]
    pub fn with_strategy(mut self, strategy: KeyStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets the keys of the expected object to skip during comparison.
    ///
    /// The given keys replace any previously configured set.
    #[must_use]
    pub fn with_exclude_keys<I, S>(mut self, keys: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.exclude_keys = keys.into_iter().map(Into::into).collect();
        self
    }

    /// Scores a single key according to the configured strategy.
    ///
    /// Equality-based strategies yield 1.0 on an exact match and 0.0 otherwise.
    /// Rubric scores are clamped to `[0.0, 1.0]`; a NaN returned by the scorer
    /// counts as 0.0 so that one bad score cannot turn the aggregate into NaN.
    fn score_key(
        &self,
        key: &str,
        expected: &serde_json::Value,
        actual: Option<&serde_json::Value>,
    ) -> f64 {
        match &self.strategy {
            KeyStrategy::Average | KeyStrategy::All | KeyStrategy::None => {
                if actual == Some(expected) {
                    1.0
                } else {
                    0.0
                }
            }
            KeyStrategy::Rubric { scorer } => {
                let raw = scorer(key, expected, actual);
                // `f64::clamp` passes NaN through unchanged, so guard it first.
                if raw.is_nan() {
                    0.0
                } else {
                    raw.clamp(0.0_f64, 1.0_f64)
                }
            }
        }
    }

    /// Compares `actual` against the expected value and returns
    /// `(score, details)`.
    ///
    /// * When the expected value is not a JSON object, the two values are
    ///   compared as a whole: `(1.0, "match")` or `(0.0, "mismatch")`.
    /// * Otherwise every non-excluded key of the expected object is scored.
    ///   If `actual` is not an object, every key is treated as missing.
    /// * When no key remains to compare, the result is
    ///   `(1.0, "no comparable keys")`.
    ///
    /// `details` lists each compared key as `key=score` with two decimals,
    /// separated by `", "`.
    fn compare(&self, actual: &serde_json::Value) -> (f64, String) {
        let Some(expected_obj) = self.expected.as_object() else {
            return if self.expected == *actual {
                (1.0, "match".into())
            } else {
                (0.0, "mismatch".into())
            };
        };
        let actual_obj = actual.as_object();

        // Keys are borrowed from `self.expected`; no per-key allocation needed.
        let mut per_key: Vec<(&str, f64)> = Vec::new();
        for (key, expected_value) in expected_obj {
            if self.exclude_keys.contains(key) {
                continue;
            }
            let actual_value = actual_obj.and_then(|obj| obj.get(key));
            per_key.push((key.as_str(), self.score_key(key, expected_value, actual_value)));
        }

        if per_key.is_empty() {
            return (1.0, "no comparable keys".into());
        }

        let score = match &self.strategy {
            KeyStrategy::Average | KeyStrategy::Rubric { .. } => {
                let sum: f64 = per_key.iter().map(|(_, s)| *s).sum();
                // Key counts are far below 2^52, so the usize -> f64 cast is exact.
                #[allow(clippy::cast_precision_loss)]
                {
                    sum / per_key.len() as f64
                }
            }
            KeyStrategy::All => {
                if per_key.iter().all(|(_, s)| *s >= 1.0) {
                    1.0
                } else {
                    0.0
                }
            }
            KeyStrategy::None => {
                if per_key.iter().all(|(_, s)| *s <= 0.0) {
                    1.0
                } else {
                    0.0
                }
            }
        };

        let details = per_key
            .iter()
            .map(|(k, s)| format!("{k}={s:.2}"))
            .collect::<Vec<_>>()
            .join(", ");
        (score, details)
    }
}
impl Evaluator for JsonMatchEvaluator {
    /// Returns the configured evaluator name.
    fn name(&self) -> &'static str {
        self.name
    }

    /// Parses the final response as JSON and scores it against the expected value.
    ///
    /// Returns `None` when the invocation has no final response. A response
    /// that is not valid JSON yields `Score::fail()` with the parse error in
    /// `details`; otherwise the comparison score is wrapped with
    /// `Score::new(value, 0.5)`.
    fn evaluate(&self, _case: &EvalCase, invocation: &Invocation) -> Option<EvalMetricResult> {
        let response = invocation.final_response.as_ref()?;

        // Both outcomes share the same result shape; only score and details differ.
        let (score, details) = match serde_json::from_str::<serde_json::Value>(response) {
            Ok(actual) => {
                let (value, summary) = self.compare(&actual);
                (Score::new(value, 0.5), summary)
            }
            Err(parse_err) => (
                Score::fail(),
                format!("malformed JSON response: {parse_err}"),
            ),
        };

        Some(EvalMetricResult {
            evaluator_name: self.name.to_string(),
            score,
            details: Some(details),
        })
    }
}
/// Evaluator that parses the final response as JSON and validates it against a
/// JSON Schema.
pub struct JsonSchemaEvaluator {
// Name returned by `Evaluator::name` and copied into every result.
name: &'static str,
// Validator built once from the schema by `jsonschema::validator_for` in `new`.
validator: Validator,
}
impl JsonSchemaEvaluator {
    /// Builds an evaluator named `"json_schema"` from the given schema document.
    ///
    /// # Errors
    ///
    /// Returns the stringified `jsonschema` error when a validator cannot be
    /// built from `schema`.
    pub fn new(schema: &serde_json::Value) -> Result<Self, String> {
        match jsonschema::validator_for(schema) {
            Ok(validator) => Ok(Self {
                name: "json_schema",
                validator,
            }),
            Err(build_err) => Err(build_err.to_string()),
        }
    }

    /// Overrides the name reported by this evaluator.
    #[must_use]
    pub const fn with_name(mut self, name: &'static str) -> Self {
        self.name = name;
        self
    }
}
impl Evaluator for JsonSchemaEvaluator {
    /// Returns the configured evaluator name.
    fn name(&self) -> &'static str {
        self.name
    }

    /// Parses the final response as JSON and validates it against the schema.
    ///
    /// Returns `None` when the invocation has no final response. A response
    /// that is not valid JSON yields `Score::fail()` with the parse error in
    /// `details`. A valid instance yields `Score::pass()` with
    /// `"schema valid"`; otherwise `Score::fail()` with all validation
    /// messages joined by `"; "`.
    fn evaluate(&self, _case: &EvalCase, invocation: &Invocation) -> Option<EvalMetricResult> {
        let response = invocation.final_response.as_ref()?;

        let (score, details) = match serde_json::from_str::<serde_json::Value>(response) {
            Err(parse_err) => (
                Score::fail(),
                format!("malformed JSON response: {parse_err}"),
            ),
            Ok(instance) => {
                let violations: Vec<String> = self
                    .validator
                    .iter_errors(&instance)
                    .map(|violation| violation.to_string())
                    .collect();
                if violations.is_empty() {
                    (Score::pass(), String::from("schema valid"))
                } else {
                    (Score::fail(), violations.join("; "))
                }
            }
        };

        Some(EvalMetricResult {
            evaluator_name: self.name.to_string(),
            score,
            details: Some(details),
        })
    }
}