1use crate::executor::{ExecutionResult, NodeExecutor};
14use async_trait::async_trait;
15use jamjet_core::node::{EvalOnFail, EvalScorer};
16use jamjet_models::{ChatMessage, ModelConfig, ModelRegistry, ModelRequest};
17use jamjet_state::backend::WorkItem;
18use serde_json::{json, Value};
19use std::sync::Arc;
20use tracing::{debug, instrument, warn};
21
22#[derive(Debug, serde::Serialize)]
24pub struct ScorerResult {
25 pub scorer_type: String,
26 pub passed: bool,
27 pub score: Option<f64>,
28 pub message: String,
29}
30
31pub struct EvalExecutor {
32 model_registry: Arc<ModelRegistry>,
33}
34
35impl EvalExecutor {
36 pub fn new(model_registry: Arc<ModelRegistry>) -> Self {
37 Self { model_registry }
38 }
39
40 async fn run_llm_judge(
41 &self,
42 model: &str,
43 rubric: &str,
44 min_score: u8,
45 subject: &Value,
46 ) -> ScorerResult {
47 let prompt = format!(
48 "You are an impartial evaluator.\n\n\
49 Rubric: {rubric}\n\n\
50 Output to evaluate:\n{subject}\n\n\
51 Respond with ONLY a JSON object: {{\"score\": <integer 1-5>, \"reason\": \"<brief reason>\"}}"
52 );
53
54 let request = ModelRequest::new(vec![ChatMessage::user(prompt)]).with_config(ModelConfig {
55 model: Some(model.to_string()),
56 max_tokens: Some(256),
57 temperature: Some(0.0),
58 system_prompt: None,
59 stop_sequences: None,
60 });
61
62 match self.model_registry.chat(request).await {
63 Ok(resp) => {
64 let content = resp.content.trim();
66 let parsed: Option<Value> = content
68 .find('{')
69 .and_then(|start| content.rfind('}').map(|end| &content[start..=end]))
70 .and_then(|json_str| serde_json::from_str(json_str).ok());
71
72 if let Some(obj) = parsed {
73 let score = obj.get("score").and_then(|s| s.as_u64()).unwrap_or(0) as u8;
74 let reason = obj
75 .get("reason")
76 .and_then(|r| r.as_str())
77 .unwrap_or("no reason")
78 .to_string();
79 let passed = score >= min_score;
80 ScorerResult {
81 scorer_type: "llm_judge".into(),
82 passed,
83 score: Some(score as f64),
84 message: format!("score={score}/5 (min={min_score}): {reason}"),
85 }
86 } else {
87 ScorerResult {
88 scorer_type: "llm_judge".into(),
89 passed: false,
90 score: None,
91 message: format!("failed to parse judge response: {content}"),
92 }
93 }
94 }
95 Err(e) => ScorerResult {
96 scorer_type: "llm_judge".into(),
97 passed: false,
98 score: None,
99 message: format!("model call failed: {e}"),
100 },
101 }
102 }
103
104 fn run_assertions(&self, checks: &[String], subject: &Value) -> ScorerResult {
105 let mut failures = Vec::new();
106
107 for check in checks {
108 let passed = eval_assertion(check, subject);
115 if !passed {
116 failures.push(check.clone());
117 }
118 }
119
120 if failures.is_empty() {
121 ScorerResult {
122 scorer_type: "assertion".into(),
123 passed: true,
124 score: Some(1.0),
125 message: format!("all {} assertions passed", checks.len()),
126 }
127 } else {
128 ScorerResult {
129 scorer_type: "assertion".into(),
130 passed: false,
131 score: Some(0.0),
132 message: format!("failed assertions: {}", failures.join("; ")),
133 }
134 }
135 }
136
137 fn run_latency(&self, threshold_ms: u64, actual_ms: u64) -> ScorerResult {
138 let passed = actual_ms <= threshold_ms;
139 ScorerResult {
140 scorer_type: "latency".into(),
141 passed,
142 score: Some(actual_ms as f64),
143 message: format!("{actual_ms}ms (threshold: {threshold_ms}ms)"),
144 }
145 }
146
147 fn run_cost(&self, threshold_usd: f64, actual_usd: f64) -> ScorerResult {
148 let passed = actual_usd <= threshold_usd;
149 ScorerResult {
150 scorer_type: "cost".into(),
151 passed,
152 score: Some(actual_usd),
153 message: format!("${actual_usd:.6} (threshold: ${threshold_usd:.4})"),
154 }
155 }
156
157 fn run_custom(&self, module: &str, _kwargs: &Value, _subject: &Value) -> ScorerResult {
158 warn!(
161 module = %module,
162 "Custom scorer: delegating to Python worker process (not yet implemented in Rust executor)"
163 );
164 ScorerResult {
165 scorer_type: "custom".into(),
166 passed: true, score: None,
168 message: format!("custom scorer '{module}' delegated to Python worker"),
169 }
170 }
171}
172
173#[async_trait]
174impl NodeExecutor for EvalExecutor {
175 #[instrument(skip(self, item), fields(node_id = %item.node_id))]
176 async fn execute(&self, item: &WorkItem) -> Result<ExecutionResult, String> {
177 let start = std::time::Instant::now();
178
179 let scorers: Vec<EvalScorer> = item
181 .payload
182 .get("scorers")
183 .and_then(|v| serde_json::from_value(v.clone()).ok())
184 .unwrap_or_default();
185
186 let on_fail: EvalOnFail = item
187 .payload
188 .get("on_fail")
189 .and_then(|v| serde_json::from_value(v.clone()).ok())
190 .unwrap_or_default();
191
192 let max_retries: u32 = item
193 .payload
194 .get("max_retries")
195 .and_then(|v| v.as_u64())
196 .unwrap_or(0) as u32;
197
198 let subject: Value = item
200 .payload
201 .get("input")
202 .or_else(|| item.payload.get("last_output"))
203 .cloned()
204 .unwrap_or(Value::Null);
205
206 let preceding_ms = item
208 .payload
209 .get("preceding_duration_ms")
210 .and_then(|v| v.as_u64())
211 .unwrap_or(0);
212 let preceding_cost_usd = item
213 .payload
214 .get("preceding_cost_usd")
215 .and_then(|v| v.as_f64())
216 .unwrap_or(0.0);
217
218 debug!(scorers = scorers.len(), "Running eval node");
219
220 let mut results: Vec<ScorerResult> = Vec::new();
222 for scorer in &scorers {
223 let result = match scorer {
224 EvalScorer::LlmJudge {
225 model,
226 rubric,
227 min_score,
228 } => {
229 self.run_llm_judge(model, rubric, *min_score, &subject)
230 .await
231 }
232 EvalScorer::Assertion { checks } => self.run_assertions(checks, &subject),
233 EvalScorer::Latency { threshold_ms } => {
234 self.run_latency(*threshold_ms, preceding_ms)
235 }
236 EvalScorer::Cost { threshold_usd } => {
237 self.run_cost(*threshold_usd, preceding_cost_usd)
238 }
239 EvalScorer::Custom { module, kwargs } => self.run_custom(module, kwargs, &subject),
240 };
241 results.push(result);
242 }
243
244 let overall_passed = results.iter().all(|r| r.passed);
245 let duration_ms = start.elapsed().as_millis() as u64;
246
247 let results_json: Vec<Value> = results
249 .iter()
250 .map(|r| {
251 json!({
252 "scorer": r.scorer_type,
253 "passed": r.passed,
254 "score": r.score,
255 "message": r.message,
256 })
257 })
258 .collect();
259
260 let output = json!({
261 "passed": overall_passed,
262 "scorers": results_json,
263 "on_fail": serde_json::to_value(&on_fail).unwrap_or(Value::Null),
264 "max_retries": max_retries,
265 });
266
267 let action = if overall_passed {
269 "continue"
270 } else {
271 match &on_fail {
272 EvalOnFail::RetryWithFeedback => "retry_with_feedback",
273 EvalOnFail::Escalate => "escalate",
274 EvalOnFail::Halt => "halt",
275 EvalOnFail::LogAndContinue => "log_and_continue",
276 }
277 };
278
279 let state_patch = json!({
280 "eval_passed": overall_passed,
281 "eval_action": action,
282 "eval_results": results_json,
283 });
284
285 if !overall_passed && matches!(on_fail, EvalOnFail::Halt) {
286 return Err(format!(
287 "eval node failed — scorers: {}",
288 results
289 .iter()
290 .filter(|r| !r.passed)
291 .map(|r| r.message.as_str())
292 .collect::<Vec<_>>()
293 .join("; ")
294 ));
295 }
296
297 Ok(ExecutionResult {
298 output,
299 state_patch,
300 duration_ms,
301 gen_ai_system: None,
302 gen_ai_model: None,
303 input_tokens: None,
304 output_tokens: None,
305 finish_reason: Some(if overall_passed { "pass" } else { action }.to_string()),
306 })
307 }
308}
309
310fn eval_assertion(check: &str, subject: &Value) -> bool {
322 let check = check.trim();
323
324 if let Some(rest) = check.strip_suffix(" in output") {
326 let key = rest.trim().trim_matches('\'').trim_matches('"');
327 if let Some(obj) = subject.as_object() {
328 return obj.contains_key(key);
329 }
330 return false;
331 }
332
333 if check.starts_with("len(output.") {
335 if let Some(inner) = check.strip_prefix("len(output.") {
336 if let Some(paren_end) = inner.find(')') {
337 let field = &inner[..paren_end];
338 let rest = inner[paren_end + 1..].trim();
339 let arr_len = subject
340 .get(field)
341 .and_then(|v| v.as_array())
342 .map(|a| a.len())
343 .unwrap_or(0);
344 return eval_numeric_comparison(arr_len as i64, rest);
345 }
346 }
347 }
348
349 if let Some(rest) = check.strip_prefix("output.") {
351 if let Some(eq_pos) = rest.find(" == ") {
352 let field = rest[..eq_pos].trim();
353 let expected = rest[eq_pos + 4..]
354 .trim()
355 .trim_matches('\'')
356 .trim_matches('"');
357 let actual = subject.get(field).and_then(|v| v.as_str()).unwrap_or("");
358 return actual == expected;
359 }
360 if let Some(ne_pos) = rest.find(" != ") {
361 let field = rest[..ne_pos].trim();
362 let expected_raw = rest[ne_pos + 4..].trim();
363 if expected_raw == "null" || expected_raw == "None" {
364 return !subject.get(field).map(|v| v.is_null()).unwrap_or(true);
365 }
366 }
367 }
368
369 warn!(check = %check, "Unknown assertion pattern; delegating to Python worker");
371 true
372}
373
/// Evaluates `actual <op> rhs`, where `op_and_rhs` is the textual operator
/// followed by an integer, e.g. `">= 3"` or `"<10"`.
///
/// Supported operators are `>=`, `<=`, `==`, `!=`, `>` and `<`. Whitespace
/// around the operator and the number is optional.
///
/// Returns `true` when the operator is not recognised or the right-hand side
/// is not a valid `i64`, so an expression this function cannot interpret
/// never fails a check on its own.
fn eval_numeric_comparison(actual: i64, op_and_rhs: &str) -> bool {
    // Two-character operators come first so `>=` is not read as `>` followed
    // by an unparseable `=N`.
    const OPERATORS: [&str; 6] = [">=", "<=", "==", "!=", ">", "<"];

    let expr = op_and_rhs.trim();
    let parsed = OPERATORS.iter().find_map(|&op| {
        let rhs = expr.strip_prefix(op)?.trim().parse::<i64>().ok()?;
        Some((op, rhs))
    });

    match parsed {
        Some((">=", rhs)) => actual >= rhs,
        Some(("<=", rhs)) => actual <= rhs,
        Some(("==", rhs)) => actual == rhs,
        Some(("!=", rhs)) => actual != rhs,
        Some((">", rhs)) => actual > rhs,
        Some(("<", rhs)) => actual < rhs,
        _ => true,
    }
}