use crate::executor::{ExecutionResult, NodeExecutor};
use async_trait::async_trait;
use jamjet_core::node::{EvalOnFail, EvalScorer};
use jamjet_models::{ChatMessage, ModelConfig, ModelRegistry, ModelRequest};
use jamjet_state::backend::WorkItem;
use serde_json::{json, Value};
use std::sync::Arc;
use tracing::{debug, instrument, warn};
/// Outcome of running a single scorer against the eval subject.
#[derive(Debug, serde::Serialize)]
pub struct ScorerResult {
/// Label of the scorer that produced this result: one of `llm_judge`,
/// `assertion`, `latency`, `cost` or `custom`.
pub scorer_type: String,
/// Whether the scorer's pass condition was met.
pub passed: bool,
/// Scorer-specific number: the judge's score for `llm_judge`, 1.0 / 0.0 for
/// `assertion`, the measured milliseconds for `latency`, the measured USD for
/// `cost`; `None` when the scorer produced no number.
pub score: Option<f64>,
/// Human-readable explanation of the verdict.
pub message: String,
}
/// Node executor that runs the scorers configured on an eval node.
pub struct EvalExecutor {
/// Registry through which LLM-judge chat requests are sent.
model_registry: Arc<ModelRegistry>,
}
impl EvalExecutor {
    /// Creates an executor that sends LLM-judge requests through `model_registry`.
    pub fn new(model_registry: Arc<ModelRegistry>) -> Self {
        Self { model_registry }
    }

    /// Asks `model` to grade `subject` against `rubric` and passes when the
    /// returned score is at least `min_score`.
    ///
    /// The reply is expected to contain a JSON object with an integer `score`
    /// and a string `reason`; any text around the object is ignored. A missing
    /// or non-integer `score` counts as 0. A failed model call or a reply that
    /// contains no parseable JSON object yields a failed result with no score.
    async fn run_llm_judge(
        &self,
        model: &str,
        rubric: &str,
        min_score: u8,
        subject: &Value,
    ) -> ScorerResult {
        let prompt = format!(
            "You are an impartial evaluator.\n\n\
             Rubric: {rubric}\n\n\
             Output to evaluate:\n{subject}\n\n\
             Respond with ONLY a JSON object: {{\"score\": <integer 1-5>, \"reason\": \"<brief reason>\"}}"
        );
        let request = ModelRequest::new(vec![ChatMessage::user(prompt)]).with_config(ModelConfig {
            model: Some(model.to_string()),
            max_tokens: Some(256),
            temperature: Some(0.0),
            system_prompt: None,
            stop_sequences: None,
        });

        let resp = match self.model_registry.chat(request).await {
            Ok(resp) => resp,
            Err(e) => {
                return ScorerResult {
                    scorer_type: "llm_judge".into(),
                    passed: false,
                    score: None,
                    message: format!("model call failed: {e}"),
                }
            }
        };

        let content = resp.content.trim();
        // `str::get` returns `None` when the last `}` precedes the first `{`
        // (e.g. a reply of "} {"), where direct slicing would panic.
        let parsed: Option<Value> = content
            .find('{')
            .zip(content.rfind('}'))
            .and_then(|(start, end)| content.get(start..=end))
            .and_then(|json_str| serde_json::from_str(json_str).ok());

        match parsed {
            Some(obj) => {
                let raw_score = obj.get("score").and_then(|s| s.as_u64()).unwrap_or(0);
                // Saturate instead of wrapping so an oversized score (256, 261, ...)
                // cannot alias to a small in-range value and flip the verdict.
                let score = raw_score.min(u64::from(u8::MAX)) as u8;
                let reason = obj
                    .get("reason")
                    .and_then(|r| r.as_str())
                    .unwrap_or("no reason")
                    .to_string();
                ScorerResult {
                    scorer_type: "llm_judge".into(),
                    passed: score >= min_score,
                    score: Some(f64::from(score)),
                    message: format!("score={score}/5 (min={min_score}): {reason}"),
                }
            }
            None => ScorerResult {
                scorer_type: "llm_judge".into(),
                passed: false,
                score: None,
                message: format!("failed to parse judge response: {content}"),
            },
        }
    }

    /// Evaluates every expression in `checks` against `subject` and passes only
    /// when none of them fail; the failing expressions are listed in the message.
    fn run_assertions(&self, checks: &[String], subject: &Value) -> ScorerResult {
        let failures: Vec<String> = checks
            .iter()
            .filter(|check| !eval_assertion(check.as_str(), subject))
            .cloned()
            .collect();

        if failures.is_empty() {
            ScorerResult {
                scorer_type: "assertion".into(),
                passed: true,
                score: Some(1.0),
                message: format!("all {} assertions passed", checks.len()),
            }
        } else {
            ScorerResult {
                scorer_type: "assertion".into(),
                passed: false,
                score: Some(0.0),
                message: format!("failed assertions: {}", failures.join("; ")),
            }
        }
    }

    /// Passes when `actual_ms` does not exceed `threshold_ms`; the measured
    /// value is reported as the score.
    fn run_latency(&self, threshold_ms: u64, actual_ms: u64) -> ScorerResult {
        ScorerResult {
            scorer_type: "latency".into(),
            passed: actual_ms <= threshold_ms,
            score: Some(actual_ms as f64),
            message: format!("{actual_ms}ms (threshold: {threshold_ms}ms)"),
        }
    }

    /// Passes when `actual_usd` does not exceed `threshold_usd`; the measured
    /// value is reported as the score.
    fn run_cost(&self, threshold_usd: f64, actual_usd: f64) -> ScorerResult {
        ScorerResult {
            scorer_type: "cost".into(),
            passed: actual_usd <= threshold_usd,
            score: Some(actual_usd),
            message: format!("${actual_usd:.6} (threshold: ${threshold_usd:.4})"),
        }
    }

    /// Placeholder for custom scorers: logs a warning and reports a pass.
    ///
    /// NOTE(review): nothing is evaluated here, so a custom scorer can never
    /// fail an eval node from this executor — confirm that failing open is the
    /// intended behaviour until the delegation is implemented.
    fn run_custom(&self, module: &str, _kwargs: &Value, _subject: &Value) -> ScorerResult {
        warn!(
            module = %module,
            "Custom scorer: delegating to Python worker process (not yet implemented in Rust executor)"
        );
        ScorerResult {
            scorer_type: "custom".into(),
            passed: true,
            score: None,
            message: format!("custom scorer '{module}' delegated to Python worker"),
        }
    }
}
#[async_trait]
impl NodeExecutor for EvalExecutor {
#[instrument(skip(self, item), fields(node_id = %item.node_id))]
async fn execute(&self, item: &WorkItem) -> Result<ExecutionResult, String> {
let start = std::time::Instant::now();
let scorers: Vec<EvalScorer> = item
.payload
.get("scorers")
.and_then(|v| serde_json::from_value(v.clone()).ok())
.unwrap_or_default();
let on_fail: EvalOnFail = item
.payload
.get("on_fail")
.and_then(|v| serde_json::from_value(v.clone()).ok())
.unwrap_or_default();
let max_retries: u32 = item
.payload
.get("max_retries")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32;
let subject: Value = item
.payload
.get("input")
.or_else(|| item.payload.get("last_output"))
.cloned()
.unwrap_or(Value::Null);
let preceding_ms = item
.payload
.get("preceding_duration_ms")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let preceding_cost_usd = item
.payload
.get("preceding_cost_usd")
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
debug!(scorers = scorers.len(), "Running eval node");
let mut results: Vec<ScorerResult> = Vec::new();
for scorer in &scorers {
let result = match scorer {
EvalScorer::LlmJudge {
model,
rubric,
min_score,
} => {
self.run_llm_judge(model, rubric, *min_score, &subject)
.await
}
EvalScorer::Assertion { checks } => self.run_assertions(checks, &subject),
EvalScorer::Latency { threshold_ms } => {
self.run_latency(*threshold_ms, preceding_ms)
}
EvalScorer::Cost { threshold_usd } => {
self.run_cost(*threshold_usd, preceding_cost_usd)
}
EvalScorer::Custom { module, kwargs } => self.run_custom(module, kwargs, &subject),
};
results.push(result);
}
let overall_passed = results.iter().all(|r| r.passed);
let duration_ms = start.elapsed().as_millis() as u64;
let results_json: Vec<Value> = results
.iter()
.map(|r| {
json!({
"scorer": r.scorer_type,
"passed": r.passed,
"score": r.score,
"message": r.message,
})
})
.collect();
let output = json!({
"passed": overall_passed,
"scorers": results_json,
"on_fail": serde_json::to_value(&on_fail).unwrap_or(Value::Null),
"max_retries": max_retries,
});
let action = if overall_passed {
"continue"
} else {
match &on_fail {
EvalOnFail::RetryWithFeedback => "retry_with_feedback",
EvalOnFail::Escalate => "escalate",
EvalOnFail::Halt => "halt",
EvalOnFail::LogAndContinue => "log_and_continue",
}
};
let state_patch = json!({
"eval_passed": overall_passed,
"eval_action": action,
"eval_results": results_json,
});
if !overall_passed && matches!(on_fail, EvalOnFail::Halt) {
return Err(format!(
"eval node failed — scorers: {}",
results
.iter()
.filter(|r| !r.passed)
.map(|r| r.message.as_str())
.collect::<Vec<_>>()
.join("; ")
));
}
Ok(ExecutionResult {
output,
state_patch,
duration_ms,
gen_ai_system: None,
gen_ai_model: None,
input_tokens: None,
output_tokens: None,
finish_reason: Some(if overall_passed { "pass" } else { action }.to_string()),
})
}
}
/// Evaluates one assertion expression against `subject`.
///
/// Recognised forms:
/// - `'key' in output` / `'key' not in output` — key presence on an object
///   subject (a non-object subject contains no keys);
/// - `len(output.field) <op> <int>` — array length comparison; a missing or
///   non-array field counts as length 0;
/// - `output.field == 'text'` — string equality; a missing or non-string field
///   compares as the empty string;
/// - `output.field != null` / `!= None` — field is present and not JSON null;
/// - `output.field != 'text'` — string inequality against a quoted literal.
///
/// Anything else is logged and treated as passing.
///
/// NOTE(review): `==` only compares string fields, so `output.count == 5`
/// is false even when `count` is the number 5.
fn eval_assertion(check: &str, subject: &Value) -> bool {
    let check = check.trim();

    if let Some(rest) = check.strip_suffix(" in output") {
        let rest = rest.trim();
        // Without this split, `'key' not in output` would be read as an `in`
        // check on the literal key `key' not` and could never pass.
        let (key_expr, negated) = match rest.strip_suffix(" not") {
            Some(key_expr) => (key_expr, true),
            None => (rest, false),
        };
        let key = key_expr.trim().trim_matches('\'').trim_matches('"');
        let present = subject
            .as_object()
            .map_or(false, |obj| obj.contains_key(key));
        return present != negated;
    }

    if let Some(inner) = check.strip_prefix("len(output.") {
        if let Some(paren_end) = inner.find(')') {
            let field = &inner[..paren_end];
            let rest = inner[paren_end + 1..].trim();
            let arr_len = subject
                .get(field)
                .and_then(|v| v.as_array())
                .map_or(0, |a| a.len());
            return eval_numeric_comparison(arr_len as i64, rest);
        }
    }

    if let Some(rest) = check.strip_prefix("output.") {
        if let Some(eq_pos) = rest.find(" == ") {
            let field = rest[..eq_pos].trim();
            let expected = rest[eq_pos + 4..]
                .trim()
                .trim_matches('\'')
                .trim_matches('"');
            let actual = subject.get(field).and_then(|v| v.as_str()).unwrap_or("");
            return actual == expected;
        }
        if let Some(ne_pos) = rest.find(" != ") {
            let field = rest[..ne_pos].trim();
            let expected_raw = rest[ne_pos + 4..].trim();
            if expected_raw == "null" || expected_raw == "None" {
                return !subject.get(field).map(|v| v.is_null()).unwrap_or(true);
            }
            // Only quoted literals are compared here; an unquoted operand (a
            // number, a bare identifier) still falls through to the
            // unknown-pattern path because the field is read as a string.
            if expected_raw.starts_with('\'') || expected_raw.starts_with('"') {
                let expected = expected_raw.trim_matches('\'').trim_matches('"');
                let actual = subject.get(field).and_then(|v| v.as_str()).unwrap_or("");
                return actual != expected;
            }
        }
    }

    warn!(check = %check, "Unknown assertion pattern; delegating to Python worker");
    true
}
/// Evaluates `actual <op> <rhs>`, where `op_and_rhs` holds the operator and an
/// integer right-hand side, e.g. `">= 3"`, `"<10"` or `"!= 0"`.
///
/// Supported operators are `>=`, `<=`, `==`, `!=`, `>` and `<`; whitespace
/// between the operator and the operand is optional. When the operator is not
/// recognised or the operand is not an integer, the comparison is permissive
/// and returns `true`.
fn eval_numeric_comparison(actual: i64, op_and_rhs: &str) -> bool {
    // Two-character operators come first so `>=` is never read as `>` followed by `=`.
    const OPERATORS: [&str; 6] = [">=", "<=", "==", "!=", ">", "<"];

    let expr = op_and_rhs.trim();
    let parsed = OPERATORS.iter().find_map(|&op| {
        let rhs = expr.strip_prefix(op)?.trim().parse::<i64>().ok()?;
        Some((op, rhs))
    });

    match parsed {
        Some((">=", rhs)) => actual >= rhs,
        Some(("<=", rhs)) => actual <= rhs,
        Some(("==", rhs)) => actual == rhs,
        Some(("!=", rhs)) => actual != rhs,
        Some((">", rhs)) => actual > rhs,
        Some(("<", rhs)) => actual < rhs,
        _ => true,
    }
}