use serde::{Deserialize, Serialize};
use crate::context::{ContextKey, Fact};
/// Verdict produced by a single [`Eval`] run.
///
/// NOTE(review): `Indeterminate` presumably means the eval could not reach a
/// pass/fail verdict (e.g. its inputs were unavailable) — confirm with eval authors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EvalOutcome {
/// The eval's check succeeded.
Pass,
/// The eval's check failed.
Fail,
/// Neither pass nor fail (see the type-level note).
Indeterminate,
}
impl EvalOutcome {
    /// Returns `true` exactly when this outcome is [`EvalOutcome::Pass`].
    #[must_use]
    pub fn is_pass(&self) -> bool {
        *self == Self::Pass
    }

    /// Returns `true` exactly when this outcome is [`EvalOutcome::Fail`].
    #[must_use]
    pub fn is_fail(&self) -> bool {
        *self == Self::Fail
    }

    /// Returns `true` exactly when this outcome is [`EvalOutcome::Indeterminate`].
    #[must_use]
    pub fn is_indeterminate(&self) -> bool {
        *self == Self::Indeterminate
    }
}
/// Result of running one [`Eval`] against a context.
///
/// The constructors in this file clamp `score` into `[0.0, 1.0]`, but the fields are
/// public, so a value assigned directly is not re-checked.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EvalResult {
/// Name of the eval that produced this result (by convention its `Eval::name()`).
pub eval_name: String,
/// Pass / fail / indeterminate verdict.
pub outcome: EvalOutcome,
/// Score in `[0.0, 1.0]` when built through the constructors.
pub score: f64,
/// Human-readable explanation of the verdict.
pub rationale: String,
/// Fact ids attached to this result; presumably the facts that informed the
/// verdict — confirm with callers. Empty when built with `EvalResult::new`.
pub fact_ids: Vec<String>,
/// Optional free-form extra data; the constructors always leave it `None`.
pub metadata: Option<String>,
}
impl EvalResult {
    /// Creates a result with no attached fact ids and no metadata.
    ///
    /// `score` is normalized into `[0.0, 1.0]`: out-of-range values are clamped and
    /// `NaN` becomes `0.0` (see [`EvalResult::normalize_score`]).
    #[must_use]
    pub fn new(
        eval_name: impl Into<String>,
        outcome: EvalOutcome,
        score: f64,
        rationale: impl Into<String>,
    ) -> Self {
        Self {
            eval_name: eval_name.into(),
            outcome,
            score: Self::normalize_score(score),
            rationale: rationale.into(),
            fact_ids: Vec::new(),
            metadata: None,
        }
    }

    /// Like [`EvalResult::new`], but records `fact_ids` on the result.
    ///
    /// The score is normalized exactly as in [`EvalResult::new`].
    #[must_use]
    pub fn with_facts(
        eval_name: impl Into<String>,
        outcome: EvalOutcome,
        score: f64,
        rationale: impl Into<String>,
        fact_ids: Vec<String>,
    ) -> Self {
        // Delegate so both constructors share one normalization path.
        Self {
            fact_ids,
            ..Self::new(eval_name, outcome, score, rationale)
        }
    }

    /// Converts this result into a [`Fact`] stored under `ContextKey::Evaluations`.
    ///
    /// The fact id is `eval:<eval_name>` or, when `eval_id` is given,
    /// `eval:<eval_name>:<eval_id>`. The content has the form
    /// `Outcome: <outcome> | Score: <score to 2 decimals> | <rationale>`.
    #[must_use]
    pub fn to_fact(&self, eval_id: Option<&str>) -> Fact {
        let id = match eval_id {
            Some(eid) => format!("eval:{}:{}", self.eval_name, eid),
            None => format!("eval:{}", self.eval_name),
        };
        let content = format!(
            "Outcome: {:?} | Score: {:.2} | {}",
            self.outcome, self.score, self.rationale
        );
        Fact {
            key: ContextKey::Evaluations,
            id,
            content,
        }
    }

    /// Forces a raw score into `[0.0, 1.0]`.
    ///
    /// `f64::clamp` returns `NaN` unchanged. That would break the range invariant,
    /// render as "NaN" in `to_fact`, and serialize as `null` in JSON, which cannot be
    /// read back as `f64`. A NaN is therefore mapped to the lowest score.
    fn normalize_score(score: f64) -> f64 {
        if score.is_nan() {
            0.0
        } else {
            score.clamp(0.0, 1.0)
        }
    }
}
/// A named check that inspects a context view and produces an [`EvalResult`].
///
/// The `Send + Sync` supertraits require implementors to be safe to move and share
/// across threads.
pub trait Eval: Send + Sync {
/// Short identifier for this eval; the implementations in this file pass it as
/// `EvalResult::eval_name`.
fn name(&self) -> &str;
/// Human-readable description of what the eval checks.
fn description(&self) -> &str;
/// Runs the check against `ctx` and returns its verdict.
fn evaluate(&self, ctx: &dyn crate::ContextView) -> EvalResult;
/// Context keys this eval depends on.
///
/// `EvalRegistry::register` reads this once at registration time to build its
/// dependency index, and `EvalRegistry::evaluate_dependent` re-runs the eval when
/// any of these keys is dirty. The default is empty. Evals with no dependencies are
/// re-run by `evaluate_dependent` on every call.
fn dependencies(&self) -> &[ContextKey] {
&[]
}
}
/// Handle returned by `EvalRegistry::register`.
///
/// The wrapped number is the eval's position in registration order (0, 1, 2, ...),
/// so an id is only meaningful for the registry that issued it. The field is
/// `pub(crate)`, so code outside this crate cannot fabricate ids.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EvalId(pub(crate) u32);
impl std::fmt::Display for EvalId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Eval({})", self.0)
}
}
/// Owns a set of [`Eval`]s and runs them against a context view.
///
/// Each registered eval receives an [`EvalId`] equal to its position in registration
/// order. Declared dependencies are indexed once, at registration, so
/// [`EvalRegistry::evaluate_dependent`] can re-run only the evals affected by a set of
/// changed keys.
#[derive(Default)]
pub struct EvalRegistry {
    /// Registered evals in registration order; index `i` holds `EvalId(i)`.
    evals: Vec<Box<dyn Eval>>,
    /// Reverse index: context key -> evals that declared it as a dependency.
    by_dependency: std::collections::HashMap<ContextKey, Vec<EvalId>>,
    /// Id handed to the next registered eval; always equal to `evals.len()`.
    next_id: u32,
}

impl EvalRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `eval` to the registry and returns its id.
    ///
    /// The eval's dependencies are read once here; later changes to what
    /// `dependencies()` returns are not reflected in the index.
    ///
    /// # Panics
    ///
    /// Panics if the registry already holds `u32::MAX` evals. The registry is left
    /// unchanged in that case.
    pub fn register(&mut self, eval: impl Eval + 'static) -> EvalId {
        let id = EvalId(self.next_id);
        // A plain `+= 1` would wrap to 0 in release builds and alias an existing id,
        // making `evaluate_by_id` run the wrong eval. Fail loudly instead.
        self.next_id = self
            .next_id
            .checked_add(1)
            .expect("EvalRegistry cannot hold more than u32::MAX evals");
        for &key in eval.dependencies() {
            self.by_dependency.entry(key).or_default().push(id);
        }
        self.evals.push(Box::new(eval));
        id
    }

    /// Number of registered evals.
    #[must_use]
    pub fn count(&self) -> usize {
        self.evals.len()
    }

    /// Runs every registered eval, in registration order, and returns one result each.
    #[must_use]
    pub fn evaluate_all(&self, ctx: &dyn crate::ContextView) -> Vec<EvalResult> {
        self.evals.iter().map(|eval| eval.evaluate(ctx)).collect()
    }

    /// Runs only the evals that need re-running after `dirty_keys` changed.
    ///
    /// An eval is run when it registered a dependency on any key in `dirty_keys`,
    /// or when it declares no dependencies at all (such evals run on every call).
    /// Results are returned in registration order; each selected eval runs once,
    /// even if several of its dependencies are dirty.
    #[must_use]
    pub fn evaluate_dependent(
        &self,
        ctx: &dyn crate::ContextView,
        dirty_keys: &[ContextKey],
    ) -> Vec<EvalResult> {
        let affected: std::collections::HashSet<EvalId> = dirty_keys
            .iter()
            .filter_map(|key| self.by_dependency.get(key))
            .flatten()
            .copied()
            .collect();
        self.evals
            .iter()
            .enumerate()
            .filter(|(index, eval)| {
                affected.contains(&Self::id_for_index(*index)) || eval.dependencies().is_empty()
            })
            .map(|(_, eval)| eval.evaluate(ctx))
            .collect()
    }

    /// Runs the eval registered under `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` was not issued by this registry (for example, an id from another
    /// registry). Use [`EvalRegistry::try_evaluate_by_id`] for a non-panicking lookup.
    #[must_use]
    pub fn evaluate_by_id(&self, id: EvalId, ctx: &dyn crate::ContextView) -> EvalResult {
        self.try_evaluate_by_id(id, ctx)
            .unwrap_or_else(|| panic!("{} is not registered in this EvalRegistry", id))
    }

    /// Runs the eval registered under `id`, or returns `None` if no such eval exists.
    #[must_use]
    pub fn try_evaluate_by_id(
        &self,
        id: EvalId,
        ctx: &dyn crate::ContextView,
    ) -> Option<EvalResult> {
        let index = usize::try_from(id.0).ok()?;
        self.evals.get(index).map(|eval| eval.evaluate(ctx))
    }

    /// Converts a position in `evals` back into the id `register` assigned to it.
    fn id_for_index(index: usize) -> EvalId {
        // `register` hands out ids 0, 1, 2, ... in push order and refuses to exceed
        // the u32 range, so every index of `evals` fits in a u32.
        EvalId(u32::try_from(index).expect("eval index exceeds the u32 id range"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::Context;

    /// Passes when the context holds at least one seed; depends on `ContextKey::Seeds`.
    struct RequireSeedsEval;

    impl Eval for RequireSeedsEval {
        fn name(&self) -> &'static str {
            "require_seeds"
        }

        fn description(&self) -> &'static str {
            "Checks if at least one seed exists in context"
        }

        fn evaluate(&self, ctx: &dyn crate::ContextView) -> EvalResult {
            let seeds = ctx.get(ContextKey::Seeds);
            let count = seeds.len();
            if count > 0 {
                EvalResult::new(
                    self.name(),
                    EvalOutcome::Pass,
                    1.0,
                    format!("Found {} seeds", count),
                )
            } else {
                EvalResult::new(self.name(), EvalOutcome::Fail, 0.0, "No seeds found")
            }
        }

        fn dependencies(&self) -> &[ContextKey] {
            &[ContextKey::Seeds]
        }
    }

    /// Always passes and declares no dependencies, so `evaluate_dependent` runs it
    /// on every call.
    struct AlwaysPassEval;

    impl Eval for AlwaysPassEval {
        fn name(&self) -> &'static str {
            "always_pass"
        }

        fn description(&self) -> &'static str {
            "Unconditionally passes"
        }

        fn evaluate(&self, _ctx: &dyn crate::ContextView) -> EvalResult {
            EvalResult::new(self.name(), EvalOutcome::Pass, 1.0, "always")
        }
    }

    /// Builds a context containing a single seed fact.
    fn seeded_context() -> Context {
        let mut ctx = Context::new();
        let _ = ctx.add_fact(Fact {
            key: ContextKey::Seeds,
            id: "s1".into(),
            content: "value".into(),
        });
        ctx
    }

    /// Extracts eval names, preserving result order.
    fn names(results: &[EvalResult]) -> Vec<&str> {
        results.iter().map(|r| r.eval_name.as_str()).collect()
    }

    #[test]
    fn registry_registers_evals() {
        let mut registry = EvalRegistry::new();
        let id1 = registry.register(RequireSeedsEval);
        let id2 = registry.register(RequireSeedsEval);
        assert_eq!(registry.count(), 2);
        assert_ne!(id1, id2);
    }

    #[test]
    fn eval_passes_when_seeds_exist() {
        let mut registry = EvalRegistry::new();
        let id = registry.register(RequireSeedsEval);
        let ctx = seeded_context();
        let result = registry.evaluate_by_id(id, &ctx);
        assert_eq!(result.outcome, EvalOutcome::Pass);
        assert!((result.score - 1.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn eval_fails_when_no_seeds() {
        let mut registry = EvalRegistry::new();
        let id = registry.register(RequireSeedsEval);
        let ctx = Context::new();
        let result = registry.evaluate_by_id(id, &ctx);
        assert_eq!(result.outcome, EvalOutcome::Fail);
        assert!((result.score - 0.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn evaluate_all_runs_every_eval_in_registration_order() {
        let mut registry = EvalRegistry::new();
        registry.register(RequireSeedsEval);
        registry.register(AlwaysPassEval);
        let ctx = seeded_context();
        let results = registry.evaluate_all(&ctx);
        assert_eq!(names(&results), ["require_seeds", "always_pass"]);
        assert!(results.iter().all(|r| r.outcome.is_pass()));
    }

    #[test]
    fn evaluate_dependent_filters_by_dirty_keys() {
        let mut registry = EvalRegistry::new();
        registry.register(RequireSeedsEval);
        registry.register(AlwaysPassEval);
        let ctx = Context::new();

        // Dependency-free evals always run, even with nothing dirty.
        let results = registry.evaluate_dependent(&ctx, &[]);
        assert_eq!(names(&results), ["always_pass"]);

        let results = registry.evaluate_dependent(&ctx, &[ContextKey::Evaluations]);
        assert_eq!(names(&results), ["always_pass"]);

        let results = registry.evaluate_dependent(&ctx, &[ContextKey::Seeds]);
        assert_eq!(names(&results), ["require_seeds", "always_pass"]);
    }

    #[test]
    fn eval_result_converts_to_fact() {
        let result = EvalResult::new("test_eval", EvalOutcome::Pass, 0.85, "Test passed");
        let fact = result.to_fact(None);
        assert_eq!(fact.key, ContextKey::Evaluations);
        assert_eq!(fact.id, "eval:test_eval");
        assert_eq!(fact.content, "Outcome: Pass | Score: 0.85 | Test passed");
    }

    #[test]
    fn eval_result_fact_id_includes_eval_id() {
        let result = EvalResult::new("test_eval", EvalOutcome::Fail, 0.0, "nope");
        let fact = result.to_fact(Some("run-1"));
        assert_eq!(fact.id, "eval:test_eval:run-1");
        assert!(fact.content.starts_with("Outcome: Fail | Score: 0.00 |"));
    }

    #[test]
    fn eval_result_score_is_clamped() {
        let result = EvalResult::new("test", EvalOutcome::Pass, 1.5, "test");
        assert!((result.score - 1.0_f64).abs() < f64::EPSILON);
        let result = EvalResult::new("test", EvalOutcome::Pass, -0.5, "test");
        assert!((result.score - 0.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn with_facts_keeps_fact_ids_and_clamps() {
        let result = EvalResult::with_facts(
            "test",
            EvalOutcome::Indeterminate,
            2.0,
            "why",
            vec!["f1".to_string(), "f2".to_string()],
        );
        assert_eq!(result.fact_ids, ["f1", "f2"]);
        assert!((result.score - 1.0_f64).abs() < f64::EPSILON);
        assert_eq!(result.metadata, None);
        assert!(result.outcome.is_indeterminate());
    }

    #[test]
    fn outcome_predicates_are_exclusive() {
        assert!(EvalOutcome::Pass.is_pass());
        assert!(!EvalOutcome::Pass.is_fail());
        assert!(EvalOutcome::Fail.is_fail());
        assert!(!EvalOutcome::Fail.is_indeterminate());
        assert!(EvalOutcome::Indeterminate.is_indeterminate());
        assert!(!EvalOutcome::Indeterminate.is_pass());
    }
}