// jamjet_worker/executors/eval.rs

//! Executor for `Eval` workflow nodes (3.12–3.17).
//!
//! Runs configurable scorers against the preceding node's output:
//! - `LlmJudge`  — calls an LLM with a rubric and extracts a 1–5 score
//! - `Assertion` — evaluates simple Python-like boolean expressions
//! - `Latency`   — checks the preceding node's elapsed ms against a threshold
//! - `Cost`      — checks the preceding node's estimated cost against a USD threshold
//! - `Custom`    — delegates to a Python scorer via module reference
//!
//! On failure, the configured `on_fail` action determines the next step:
//! `retry_with_feedback`, `escalate`, `halt`, or `log_and_continue`.
12
13use crate::executor::{ExecutionResult, NodeExecutor};
14use async_trait::async_trait;
15use jamjet_core::node::{EvalOnFail, EvalScorer};
16use jamjet_models::{ChatMessage, ModelConfig, ModelRegistry, ModelRequest};
17use jamjet_state::backend::WorkItem;
18use serde_json::{json, Value};
19use std::sync::Arc;
20use tracing::{debug, instrument, warn};
21
22/// Per-scorer result recorded in telemetry and execution output.
23#[derive(Debug, serde::Serialize)]
24pub struct ScorerResult {
25    pub scorer_type: String,
26    pub passed: bool,
27    pub score: Option<f64>,
28    pub message: String,
29}
30
31pub struct EvalExecutor {
32    model_registry: Arc<ModelRegistry>,
33}
34
35impl EvalExecutor {
36    pub fn new(model_registry: Arc<ModelRegistry>) -> Self {
37        Self { model_registry }
38    }
39
40    async fn run_llm_judge(
41        &self,
42        model: &str,
43        rubric: &str,
44        min_score: u8,
45        subject: &Value,
46    ) -> ScorerResult {
47        let prompt = format!(
48            "You are an impartial evaluator.\n\n\
49             Rubric: {rubric}\n\n\
50             Output to evaluate:\n{subject}\n\n\
51             Respond with ONLY a JSON object: {{\"score\": <integer 1-5>, \"reason\": \"<brief reason>\"}}"
52        );
53
54        let request = ModelRequest::new(vec![ChatMessage::user(prompt)]).with_config(ModelConfig {
55            model: Some(model.to_string()),
56            max_tokens: Some(256),
57            temperature: Some(0.0),
58            system_prompt: None,
59            stop_sequences: None,
60        });
61
62        match self.model_registry.chat(request).await {
63            Ok(resp) => {
64                // Extract JSON from the response content.
65                let content = resp.content.trim();
66                // Find the JSON object in the response.
67                let parsed: Option<Value> = content
68                    .find('{')
69                    .and_then(|start| content.rfind('}').map(|end| &content[start..=end]))
70                    .and_then(|json_str| serde_json::from_str(json_str).ok());
71
72                if let Some(obj) = parsed {
73                    let score = obj.get("score").and_then(|s| s.as_u64()).unwrap_or(0) as u8;
74                    let reason = obj
75                        .get("reason")
76                        .and_then(|r| r.as_str())
77                        .unwrap_or("no reason")
78                        .to_string();
79                    let passed = score >= min_score;
80                    ScorerResult {
81                        scorer_type: "llm_judge".into(),
82                        passed,
83                        score: Some(score as f64),
84                        message: format!("score={score}/5 (min={min_score}): {reason}"),
85                    }
86                } else {
87                    ScorerResult {
88                        scorer_type: "llm_judge".into(),
89                        passed: false,
90                        score: None,
91                        message: format!("failed to parse judge response: {content}"),
92                    }
93                }
94            }
95            Err(e) => ScorerResult {
96                scorer_type: "llm_judge".into(),
97                passed: false,
98                score: None,
99                message: format!("model call failed: {e}"),
100            },
101        }
102    }
103
104    fn run_assertions(&self, checks: &[String], subject: &Value) -> ScorerResult {
105        let mut failures = Vec::new();
106
107        for check in checks {
108            // Evaluate simple Python-like assertions against the JSON output.
109            // Full Python eval is delegated to the Python worker process.
110            // Here we support a minimal set of structural checks:
111            // - "'key' in output"  → key present in top-level object
112            // - "len(output.key) >= N" → array length check
113            // - "output.key == 'value'" → string equality check
114            let passed = eval_assertion(check, subject);
115            if !passed {
116                failures.push(check.clone());
117            }
118        }
119
120        if failures.is_empty() {
121            ScorerResult {
122                scorer_type: "assertion".into(),
123                passed: true,
124                score: Some(1.0),
125                message: format!("all {} assertions passed", checks.len()),
126            }
127        } else {
128            ScorerResult {
129                scorer_type: "assertion".into(),
130                passed: false,
131                score: Some(0.0),
132                message: format!("failed assertions: {}", failures.join("; ")),
133            }
134        }
135    }
136
137    fn run_latency(&self, threshold_ms: u64, actual_ms: u64) -> ScorerResult {
138        let passed = actual_ms <= threshold_ms;
139        ScorerResult {
140            scorer_type: "latency".into(),
141            passed,
142            score: Some(actual_ms as f64),
143            message: format!("{actual_ms}ms (threshold: {threshold_ms}ms)"),
144        }
145    }
146
147    fn run_cost(&self, threshold_usd: f64, actual_usd: f64) -> ScorerResult {
148        let passed = actual_usd <= threshold_usd;
149        ScorerResult {
150            scorer_type: "cost".into(),
151            passed,
152            score: Some(actual_usd),
153            message: format!("${actual_usd:.6} (threshold: ${threshold_usd:.4})"),
154        }
155    }
156
157    fn run_custom(&self, module: &str, _kwargs: &Value, _subject: &Value) -> ScorerResult {
158        // Custom scorers are invoked via the Python worker process in production.
159        // In the Rust executor, we emit a marker that the Python layer can intercept.
160        warn!(
161            module = %module,
162            "Custom scorer: delegating to Python worker process (not yet implemented in Rust executor)"
163        );
164        ScorerResult {
165            scorer_type: "custom".into(),
166            passed: true, // optimistic pass — Python worker enforces
167            score: None,
168            message: format!("custom scorer '{module}' delegated to Python worker"),
169        }
170    }
171}
172
173#[async_trait]
174impl NodeExecutor for EvalExecutor {
175    #[instrument(skip(self, item), fields(node_id = %item.node_id))]
176    async fn execute(&self, item: &WorkItem) -> Result<ExecutionResult, String> {
177        let start = std::time::Instant::now();
178
179        // Deserialize scorer configs from payload.
180        let scorers: Vec<EvalScorer> = item
181            .payload
182            .get("scorers")
183            .and_then(|v| serde_json::from_value(v.clone()).ok())
184            .unwrap_or_default();
185
186        let on_fail: EvalOnFail = item
187            .payload
188            .get("on_fail")
189            .and_then(|v| serde_json::from_value(v.clone()).ok())
190            .unwrap_or_default();
191
192        let max_retries: u32 = item
193            .payload
194            .get("max_retries")
195            .and_then(|v| v.as_u64())
196            .unwrap_or(0) as u32;
197
198        // The subject is the output of the preceding node (stored in state).
199        let subject: Value = item
200            .payload
201            .get("input")
202            .or_else(|| item.payload.get("last_output"))
203            .cloned()
204            .unwrap_or(Value::Null);
205
206        // Preceding node latency and cost from payload (set by scheduler).
207        let preceding_ms = item
208            .payload
209            .get("preceding_duration_ms")
210            .and_then(|v| v.as_u64())
211            .unwrap_or(0);
212        let preceding_cost_usd = item
213            .payload
214            .get("preceding_cost_usd")
215            .and_then(|v| v.as_f64())
216            .unwrap_or(0.0);
217
218        debug!(scorers = scorers.len(), "Running eval node");
219
220        // Run all scorers.
221        let mut results: Vec<ScorerResult> = Vec::new();
222        for scorer in &scorers {
223            let result = match scorer {
224                EvalScorer::LlmJudge {
225                    model,
226                    rubric,
227                    min_score,
228                } => {
229                    self.run_llm_judge(model, rubric, *min_score, &subject)
230                        .await
231                }
232                EvalScorer::Assertion { checks } => self.run_assertions(checks, &subject),
233                EvalScorer::Latency { threshold_ms } => {
234                    self.run_latency(*threshold_ms, preceding_ms)
235                }
236                EvalScorer::Cost { threshold_usd } => {
237                    self.run_cost(*threshold_usd, preceding_cost_usd)
238                }
239                EvalScorer::Custom { module, kwargs } => self.run_custom(module, kwargs, &subject),
240            };
241            results.push(result);
242        }
243
244        let overall_passed = results.iter().all(|r| r.passed);
245        let duration_ms = start.elapsed().as_millis() as u64;
246
247        // Build structured output with per-scorer breakdowns.
248        let results_json: Vec<Value> = results
249            .iter()
250            .map(|r| {
251                json!({
252                    "scorer": r.scorer_type,
253                    "passed": r.passed,
254                    "score": r.score,
255                    "message": r.message,
256                })
257            })
258            .collect();
259
260        let output = json!({
261            "passed": overall_passed,
262            "scorers": results_json,
263            "on_fail": serde_json::to_value(&on_fail).unwrap_or(Value::Null),
264            "max_retries": max_retries,
265        });
266
267        // Determine on_fail action: encode it in state_patch so the scheduler can act.
268        let action = if overall_passed {
269            "continue"
270        } else {
271            match &on_fail {
272                EvalOnFail::RetryWithFeedback => "retry_with_feedback",
273                EvalOnFail::Escalate => "escalate",
274                EvalOnFail::Halt => "halt",
275                EvalOnFail::LogAndContinue => "log_and_continue",
276            }
277        };
278
279        let state_patch = json!({
280            "eval_passed": overall_passed,
281            "eval_action": action,
282            "eval_results": results_json,
283        });
284
285        if !overall_passed && matches!(on_fail, EvalOnFail::Halt) {
286            return Err(format!(
287                "eval node failed — scorers: {}",
288                results
289                    .iter()
290                    .filter(|r| !r.passed)
291                    .map(|r| r.message.as_str())
292                    .collect::<Vec<_>>()
293                    .join("; ")
294            ));
295        }
296
297        Ok(ExecutionResult {
298            output,
299            state_patch,
300            duration_ms,
301            gen_ai_system: None,
302            gen_ai_model: None,
303            input_tokens: None,
304            output_tokens: None,
305            finish_reason: Some(if overall_passed { "pass" } else { action }.to_string()),
306        })
307    }
308}
309
// ── Assertion evaluator ───────────────────────────────────────────────────────
311
312/// Evaluate a simple structural assertion against a JSON value.
313///
314/// Supported patterns (Python-compatible syntax):
315/// - `"'key' in output"` — key present in top-level object
316/// - `"len(output.key) >= N"` — array length comparison
317/// - `"output.key == 'value'"` — string equality
318/// - `"output.key != null"` — null check
319///
320/// For arbitrary Python expressions, the Python worker process evaluates them.
321fn eval_assertion(check: &str, subject: &Value) -> bool {
322    let check = check.trim();
323
324    // Pattern: "'key' in output"
325    if let Some(rest) = check.strip_suffix(" in output") {
326        let key = rest.trim().trim_matches('\'').trim_matches('"');
327        if let Some(obj) = subject.as_object() {
328            return obj.contains_key(key);
329        }
330        return false;
331    }
332
333    // Pattern: "len(output.key) >= N" or "len(output.key) > N" etc.
334    if check.starts_with("len(output.") {
335        if let Some(inner) = check.strip_prefix("len(output.") {
336            if let Some(paren_end) = inner.find(')') {
337                let field = &inner[..paren_end];
338                let rest = inner[paren_end + 1..].trim();
339                let arr_len = subject
340                    .get(field)
341                    .and_then(|v| v.as_array())
342                    .map(|a| a.len())
343                    .unwrap_or(0);
344                return eval_numeric_comparison(arr_len as i64, rest);
345            }
346        }
347    }
348
349    // Pattern: "output.key == 'value'" or "output.key != null"
350    if let Some(rest) = check.strip_prefix("output.") {
351        if let Some(eq_pos) = rest.find(" == ") {
352            let field = rest[..eq_pos].trim();
353            let expected = rest[eq_pos + 4..]
354                .trim()
355                .trim_matches('\'')
356                .trim_matches('"');
357            let actual = subject.get(field).and_then(|v| v.as_str()).unwrap_or("");
358            return actual == expected;
359        }
360        if let Some(ne_pos) = rest.find(" != ") {
361            let field = rest[..ne_pos].trim();
362            let expected_raw = rest[ne_pos + 4..].trim();
363            if expected_raw == "null" || expected_raw == "None" {
364                return !subject.get(field).map(|v| v.is_null()).unwrap_or(true);
365            }
366        }
367    }
368
369    // Unknown pattern — log and optimistically pass (Python worker will re-check).
370    warn!(check = %check, "Unknown assertion pattern; delegating to Python worker");
371    true
372}
373
/// Applies the `<op> <integer>` comparison in `op_and_rhs` to `actual`.
///
/// Supported operators are `>=`, `>`, `<=`, `<`, `==` and `!=`. Whitespace around the
/// operator and the integer is optional (`">= 3"` and `">=3"` are equivalent). Returns
/// `true` for anything it cannot parse, deferring such checks to the Python worker.
fn eval_numeric_comparison(actual: i64, op_and_rhs: &str) -> bool {
    let expr = op_and_rhs.trim();
    // Two-character operators come first so ">=3" is not read as ">" followed by "=3".
    let parsed = ["==", "!=", ">=", "<=", ">", "<"]
        .iter()
        .find_map(|op| expr.strip_prefix(*op).map(|rhs| (*op, rhs)));
    let (op, rhs) = match parsed {
        Some(pair) => pair,
        None => return true, // unknown comparison — delegate to Python
    };
    let n = match rhs.trim().parse::<i64>() {
        Ok(n) => n,
        Err(_) => return true, // unparseable right-hand side — delegate to Python
    };
    match op {
        "==" => actual == n,
        "!=" => actual != n,
        ">=" => actual >= n,
        "<=" => actual <= n,
        ">" => actual > n,
        "<" => actual < n,
        _ => unreachable!("operator list and match arms are kept in sync"),
    }
}