//! assay_core/engine/runner.rs

1use crate::attempts::{classify_attempts, FailureClass};
2use crate::cache::key::cache_key;
3use crate::cache::vcr::VcrCache;
4use crate::errors::try_map_error;
5use crate::metrics_api::Metric;
6use crate::model::{AttemptRow, EvalConfig, LlmResponse, TestCase, TestResultRow, TestStatus};
7use crate::on_error::{ErrorPolicy, ErrorPolicyResult};
8use crate::providers::llm::LlmClient;
9use crate::quarantine::{QuarantineMode, QuarantineService};
10use crate::report::RunArtifacts;
11use crate::storage::store::Store;
12use std::sync::Arc;
13use tokio::sync::Semaphore;
14use tokio::time::{timeout, Duration};
15
/// Policy knobs governing a suite run: rerun behavior for failed attempts,
/// how quarantined tests are surfaced, and strictness of trace replay.
#[derive(Debug, Clone)]
pub struct RunPolicy {
    /// Extra attempts after a failing first attempt (0 = single attempt).
    pub rerun_failures: u32,
    /// How a quarantined test affects its final status (off / warn / strict).
    pub quarantine_mode: QuarantineMode,
    /// When true, replay fails hard if embeddings/judge results are absent
    /// from trace meta instead of computing them live.
    pub replay_strict: bool,
}
22
impl Default for RunPolicy {
    /// Defaults: one rerun on failure, quarantine reported as a warning,
    /// lenient (non-strict) replay.
    fn default() -> Self {
        Self {
            rerun_failures: 1,
            quarantine_mode: QuarantineMode::Warn,
            replay_strict: false,
        }
    }
}
32
/// Orchestrates a suite run: executes tests (with caching, retries and
/// enrichment), evaluates metrics, and persists results to the store.
pub struct Runner {
    /// Result/run persistence plus embedding + quarantine lookups.
    pub store: Store,
    /// VCR-style response cache keyed by model/prompt/fingerprint.
    pub cache: VcrCache,
    /// LLM provider used for completions.
    pub client: Arc<dyn LlmClient>,
    /// Metrics evaluated against each response, in order.
    pub metrics: Vec<Arc<dyn Metric>>,
    /// Rerun/quarantine/replay policy for this run.
    pub policy: RunPolicy,
    /// Optional embedder for semantic-similarity tests.
    pub embedder: Option<Arc<dyn crate::providers::embedder::Embedder>>,
    /// Bypass the embedding cache and re-embed when true.
    pub refresh_embeddings: bool,
    /// Skip tests whose fingerprint matched a previous passing run.
    pub incremental: bool,
    /// Bypass both the response cache and incremental skipping when true.
    pub refresh_cache: bool,
    /// Optional judge service for faithfulness/relevance rubrics.
    pub judge: Option<crate::judge::JudgeService>,
    /// Optional baseline scores for regression checks.
    pub baseline: Option<crate::baseline::Baseline>,
}
46
47impl Runner {
48    pub async fn run_suite(&self, cfg: &EvalConfig) -> anyhow::Result<RunArtifacts> {
49        let run_id = self.store.create_run(cfg)?;
50
51        let parallel = cfg.settings.parallel.unwrap_or(4).max(1);
52        let sem = Arc::new(Semaphore::new(parallel));
53        let mut handles = Vec::new();
54
55        for tc in cfg.tests.iter() {
56            let permit = sem.clone().acquire_owned().await?;
57            let this = self.clone_for_task();
58            let cfg = cfg.clone();
59            let tc = tc.clone();
60            let h = tokio::spawn(async move {
61                let _permit = permit;
62                this.run_test_with_policy(&cfg, &tc, run_id).await
63            });
64            handles.push(h);
65        }
66
67        let mut rows = Vec::new();
68        let mut any_fail = false;
69        for h in handles {
70            let row = match h.await {
71                Ok(Ok(row)) => row,
72                Ok(Err(e)) => TestResultRow {
73                    test_id: "unknown".into(),
74                    status: TestStatus::Error,
75                    score: None,
76                    cached: false,
77                    message: format!("task error: {}", e),
78                    details: serde_json::json!({}),
79                    duration_ms: None,
80                    fingerprint: None,
81                    skip_reason: None,
82                    attempts: None,
83                    error_policy_applied: None,
84                },
85                Err(e) => TestResultRow {
86                    test_id: "unknown".into(),
87                    status: TestStatus::Error,
88                    score: None,
89                    cached: false,
90                    message: format!("join error: {}", e),
91                    details: serde_json::json!({}),
92                    duration_ms: None,
93                    fingerprint: None,
94                    skip_reason: None,
95                    attempts: None,
96                    error_policy_applied: None,
97                },
98            };
99            any_fail = any_fail || matches!(row.status, TestStatus::Fail | TestStatus::Error);
100            rows.push(row);
101        }
102
103        self.store
104            .finalize_run(run_id, if any_fail { "failed" } else { "passed" })?;
105        Ok(RunArtifacts {
106            run_id,
107            suite: cfg.suite.clone(),
108            results: rows,
109        })
110    }
111
112    async fn run_test_with_policy(
113        &self,
114        cfg: &EvalConfig,
115        tc: &TestCase,
116        run_id: i64,
117    ) -> anyhow::Result<TestResultRow> {
118        let quarantine = QuarantineService::new(self.store.clone());
119        let q_reason = quarantine.is_quarantined(&cfg.suite, &tc.id)?;
120        let error_policy = cfg.effective_error_policy(tc);
121
122        let max_attempts = 1 + self.policy.rerun_failures;
123        let mut attempts: Vec<AttemptRow> = Vec::new();
124        let mut last_row: Option<TestResultRow> = None;
125        let mut last_output: Option<LlmResponse> = None;
126
127        for i in 0..max_attempts {
128            // Catch execution errors and convert to ResultRow to leverage retry/reporting logic
129            let (row, output) = match self.run_test_once(cfg, tc).await {
130                Ok(res) => res,
131                Err(e) => {
132                    let msg = if let Some(diag) = try_map_error(&e) {
133                        diag.to_string()
134                    } else {
135                        e.to_string()
136                    };
137
138                    let policy_result = error_policy.apply_to_error(&e);
139                    let (status, final_msg, applied_policy) = match policy_result {
140                        ErrorPolicyResult::Blocked { reason } => {
141                            (TestStatus::Error, reason, ErrorPolicy::Block)
142                        }
143                        ErrorPolicyResult::Allowed { warning } => {
144                            crate::on_error::log_fail_safe(&warning, None);
145                            (TestStatus::AllowedOnError, warning, ErrorPolicy::Allow)
146                        }
147                    };
148
149                    (
150                        TestResultRow {
151                            test_id: tc.id.clone(),
152                            status,
153                            score: None,
154                            cached: false,
155                            message: final_msg,
156                            details: serde_json::json!({
157                                "error": msg,
158                                "policy_applied": applied_policy
159                            }),
160                            duration_ms: None,
161                            fingerprint: None,
162                            skip_reason: None,
163                            attempts: None,
164                            error_policy_applied: Some(applied_policy),
165                        },
166                        LlmResponse {
167                            text: "".into(),
168                            provider: "error".into(),
169                            model: cfg.model.clone(),
170                            cached: false,
171                            meta: serde_json::json!({}),
172                        },
173                    )
174                }
175            };
176            attempts.push(AttemptRow {
177                attempt_no: i + 1,
178                status: row.status.clone(),
179                message: row.message.clone(),
180                duration_ms: row.duration_ms,
181                details: row.details.clone(),
182            });
183            last_row = Some(row.clone());
184            last_output = Some(output.clone());
185
186            match row.status {
187                TestStatus::Pass | TestStatus::Warn | TestStatus::AllowedOnError => break,
188                TestStatus::Skipped => break, // Should not happen in loop
189                TestStatus::Fail | TestStatus::Error | TestStatus::Flaky | TestStatus::Unstable => {
190                    continue
191                }
192            }
193        }
194
195        let class = classify_attempts(&attempts);
196        let mut final_row = last_row.unwrap_or(TestResultRow {
197            test_id: tc.id.clone(),
198            status: TestStatus::Error,
199            score: None,
200            cached: false,
201            message: "no attempts".into(),
202            details: serde_json::json!({}),
203            duration_ms: None,
204            fingerprint: None,
205            skip_reason: None,
206            attempts: None,
207            error_policy_applied: None,
208        });
209
210        // quarantine overlay
211        if let Some(reason) = q_reason {
212            match self.policy.quarantine_mode {
213                QuarantineMode::Off => {}
214                QuarantineMode::Warn => {
215                    final_row.status = TestStatus::Warn;
216                    final_row.message = format!("quarantined: {}", reason);
217                }
218                QuarantineMode::Strict => {
219                    final_row.status = TestStatus::Fail;
220                    final_row.message = format!("quarantined (strict): {}", reason);
221                }
222            }
223        }
224
225        match class {
226            FailureClass::Skipped => {
227                final_row.status = TestStatus::Skipped;
228                // message usually set by run_test_once
229            }
230            FailureClass::Flaky => {
231                final_row.status = TestStatus::Flaky;
232                final_row.message = "flake detected (rerun passed)".into();
233                final_row.details["flake"] = serde_json::json!({ "attempts": attempts.len() });
234            }
235            FailureClass::Unstable => {
236                final_row.status = TestStatus::Unstable;
237                final_row.message = "unstable outcomes detected".into();
238                final_row.details["unstable"] = serde_json::json!({ "attempts": attempts.len() });
239            }
240            FailureClass::Error => final_row.status = TestStatus::Error,
241            FailureClass::DeterministicFail => {
242                // Ensures if last attempt was fail, we keep fail status
243                final_row.status = TestStatus::Fail;
244            }
245            FailureClass::DeterministicPass => {
246                final_row.status = TestStatus::Pass;
247            }
248        }
249
250        let output = last_output.unwrap_or(LlmResponse {
251            text: "".into(),
252            provider: self.client.provider_name().to_string(),
253            model: cfg.model.clone(),
254            cached: false,
255            meta: serde_json::json!({}),
256        });
257
258        final_row.attempts = Some(attempts.clone());
259
260        // PR-4.0.3 Agent Assertions
261        if let Some(assertions) = &tc.assertions {
262            if !assertions.is_empty() {
263                // Verify assertions against DB
264                match crate::agent_assertions::verify_assertions(
265                    &self.store,
266                    run_id,
267                    &tc.id,
268                    assertions,
269                ) {
270                    Ok(diags) => {
271                        if !diags.is_empty() {
272                            // Assertion Failures
273                            final_row.status = TestStatus::Fail;
274
275                            // serialize diagnostics
276                            let diag_json: Vec<serde_json::Value> = diags
277                                .iter()
278                                .map(|d| serde_json::to_value(d).unwrap_or_default())
279                                .collect();
280
281                            final_row.details["assertions"] = serde_json::Value::Array(diag_json);
282
283                            let fail_msg = format!("assertions failed ({})", diags.len());
284                            if final_row.message == "ok" {
285                                final_row.message = fail_msg;
286                            } else {
287                                final_row.message = format!("{}; {}", final_row.message, fail_msg);
288                            }
289                        } else {
290                            // passed
291                            final_row.details["assertions"] = serde_json::json!({ "passed": true });
292                        }
293                    }
294                    Err(e) => {
295                        // Missing or Ambiguous Episode -> Fail
296                        final_row.status = TestStatus::Fail;
297                        final_row.message = format!("assertions error: {}", e);
298                        final_row.details["assertions"] =
299                            serde_json::json!({ "error": e.to_string() });
300                    }
301                }
302            }
303        }
304
305        self.store
306            .insert_result_embedded(run_id, &final_row, &attempts, &output)?;
307
308        Ok(final_row)
309    }
310
    /// Execute a single attempt of one test: fingerprint, optional incremental
    /// skip, cached or live LLM call, semantic/judge enrichment, metric
    /// evaluation and baseline regression check.
    ///
    /// Returns the result row for this attempt together with the (possibly
    /// placeholder) model response used to produce it.
    async fn run_test_once(
        &self,
        cfg: &EvalConfig,
        tc: &TestCase,
    ) -> anyhow::Result<(TestResultRow, LlmResponse)> {
        // Canonical serialization of the expectation feeds the fingerprint, so
        // changing the expected value invalidates incremental skips.
        let expected_json = serde_json::to_string(&tc.expected).unwrap_or_default();
        let metric_versions = [("assay", env!("CARGO_PKG_VERSION"))];

        let policy_hash = if let Some(path) = tc.expected.get_policy_path() {
            // Read policy content to ensure cache invalidation on content change
            match std::fs::read_to_string(path) {
                Ok(content) => Some(crate::fingerprint::sha256_hex(&content)),
                Err(_) => None, // If file missing, finding it later will error.
                                // We don't fail here to allow error reporting in metrics phase or main loop.
            }
        } else {
            None
        };

        // Fingerprint over everything that can change the outcome of this test.
        let fp = crate::fingerprint::compute(crate::fingerprint::Context {
            suite: &cfg.suite,
            model: &cfg.model,
            test_id: &tc.id,
            prompt: &tc.input.prompt,
            context: tc.input.context.as_deref(),
            expected_canonical: &expected_json,
            policy_hash: policy_hash.as_deref(),
            metric_versions: &metric_versions,
        });

        // Incremental Check
        // Note: Global --incremental flag should be checked here.
        // Assuming self.incremental is available.
        if self.incremental && !self.refresh_cache {
            if let Some(prev) = self.store.get_last_passing_by_fingerprint(&fp.hex)? {
                // Return a Skipped result, carrying provenance of the previous
                // passing run in details.skip for reporting.
                let row = TestResultRow {
                    test_id: tc.id.clone(),
                    status: TestStatus::Skipped,
                    score: prev.score,
                    cached: true,
                    message: "skipped: fingerprint match".into(),
                    details: serde_json::json!({
                        "skip": {
                             "reason": "fingerprint_match",
                             "fingerprint": fp.hex,
                             "previous_run_id": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("previous_run_id")).and_then(|v: &serde_json::Value| v.as_i64()),
                             "previous_at": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("previous_at")).and_then(|v: &serde_json::Value| v.as_str()),
                             "origin_run_id": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("origin_run_id")).and_then(|v: &serde_json::Value| v.as_i64()),
                             "previous_score": prev.score
                        }
                    }),
                    duration_ms: Some(0), // Instant
                    fingerprint: Some(fp.hex.clone()),
                    skip_reason: Some("fingerprint_match".into()),
                    attempts: None,
                    error_policy_applied: None,
                };

                // Construct placeholder response for pipeline consistency
                let resp = LlmResponse {
                    text: "".into(),
                    provider: "skipped".into(),
                    model: cfg.model.clone(),
                    cached: true,
                    meta: serde_json::json!({}),
                };
                return Ok((row, resp));
            }
        }

        // Original Execution Logic
        // We use the computed fingerprint for caching key to distinguish config variations
        let key = cache_key(
            &cfg.model,
            &tc.input.prompt,
            &fp.hex,
            self.client.fingerprint().as_deref(),
        );

        let start = std::time::Instant::now();
        let mut cached = false;

        // Serve from the VCR cache unless caching is disabled in settings or a
        // refresh was requested; cache the live response on a miss.
        let mut resp: LlmResponse = if cfg.settings.cache.unwrap_or(true) && !self.refresh_cache {
            if let Some(r) = self.cache.get(&key)? {
                cached = true;
                eprintln!(
                    "  [CACHE HIT] key={} prompt_len={}",
                    key,
                    tc.input.prompt.len()
                );
                r
            } else {
                let r = self.call_llm(cfg, tc).await?;
                self.cache.put(&key, &r)?;
                r
            }
        } else {
            self.call_llm(cfg, tc).await?
        };
        // Preserve a cached flag already set on the response itself.
        resp.cached = resp.cached || cached;

        // Semantic Enrichment: attach embeddings / judge verdicts to resp.meta
        // before metric evaluation.
        self.enrich_semantic(tc, &mut resp).await?;
        self.enrich_judge(tc, &mut resp).await?;

        let mut final_status = TestStatus::Pass;
        let mut final_score: Option<f64> = None;
        let mut msg = String::new();
        let mut details = serde_json::json!({ "metrics": {} });

        // Evaluate metrics in order; the first unstable or failing metric
        // short-circuits the loop (later metrics are not recorded).
        for m in &self.metrics {
            let r = m.evaluate(tc, &tc.expected, &resp).await?;
            details["metrics"][m.name()] = serde_json::json!({
                "score": r.score, "passed": r.passed, "unstable": r.unstable, "details": r.details
            });
            final_score = Some(r.score);

            if r.unstable {
                // gate stability first: treat unstable as warn in MVP
                final_status = TestStatus::Warn;
                msg = format!("unstable metric: {}", m.name());
                break;
            }
            if !r.passed {
                final_status = TestStatus::Fail;
                msg = format!("failed: {}", m.name());
                break;
            }
        }

        // Post-metric baseline check; only Fail/Warn outcomes override the
        // metric-derived status.
        if let Some(baseline) = &self.baseline {
            if let Some((new_status, new_msg)) =
                self.check_baseline_regressions(tc, cfg, &details, &self.metrics, baseline)
            {
                if matches!(new_status, TestStatus::Fail | TestStatus::Warn) {
                    final_status = new_status;
                    msg = new_msg;
                }
            }
        }

        let duration_ms = start.elapsed().as_millis() as u64;
        let mut row = TestResultRow {
            test_id: tc.id.clone(),
            status: final_status,
            score: final_score,
            cached: resp.cached,
            message: if msg.is_empty() { "ok".into() } else { msg },
            details,
            duration_ms: Some(duration_ms),
            fingerprint: Some(fp.hex),
            skip_reason: None,
            attempts: None,
            error_policy_applied: None,
        };

        // Mark rows produced by the trace (replay) provider.
        if self.client.provider_name() == "trace" {
            row.details["assay.replay"] = serde_json::json!(true);
        }

        row.details["prompt"] = serde_json::Value::String(tc.input.prompt.clone());

        Ok((row, resp))
    }
477
478    async fn call_llm(&self, cfg: &EvalConfig, tc: &TestCase) -> anyhow::Result<LlmResponse> {
479        let t = cfg.settings.timeout_seconds.unwrap_or(30);
480        let fut = self
481            .client
482            .complete(&tc.input.prompt, tc.input.context.as_deref());
483        let resp = timeout(Duration::from_secs(t), fut).await??;
484        Ok(resp)
485    }
486
    /// Snapshot this runner's (cheaply cloneable) handles into a `RunnerRef`
    /// that can be moved into a spawned per-test task (see `run_suite`).
    fn clone_for_task(&self) -> RunnerRef {
        RunnerRef {
            store: self.store.clone(),
            cache: self.cache.clone(),
            client: self.client.clone(),
            metrics: self.metrics.clone(),
            policy: self.policy.clone(),
            embedder: self.embedder.clone(),
            refresh_embeddings: self.refresh_embeddings,
            incremental: self.incremental,
            refresh_cache: self.refresh_cache,
            judge: self.judge.clone(),
            baseline: self.baseline.clone(),
        }
    }
502
503    fn check_baseline_regressions(
504        &self,
505        tc: &TestCase,
506        cfg: &EvalConfig,
507        details: &serde_json::Value,
508        metrics: &[Arc<dyn Metric>],
509        baseline: &crate::baseline::Baseline,
510    ) -> Option<(TestStatus, String)> {
511        // Check suite-level defaults
512        let suite_defaults = cfg.settings.thresholding.as_ref();
513
514        for m in metrics {
515            let metric_name = m.name();
516            // Only numeric metrics supported right now
517            let score = details["metrics"][metric_name]["score"].as_f64()?;
518
519            // Determine thresholding config
520            // 1. Metric override (from expected enum - tricky as Metric trait hides this)
521            // Use suite defaults unless specific metric logic overrides
522
523            let (mode, max_drop) = self.resolve_threshold_config(tc, metric_name, suite_defaults);
524
525            if mode == "relative" {
526                if let Some(base_score) = baseline.get_score(&tc.id, metric_name) {
527                    let delta = score - base_score;
528                    if let Some(drop_limit) = max_drop {
529                        if delta < -drop_limit {
530                            return Some((
531                                TestStatus::Fail,
532                                format!(
533                                    "regression: {} dropped {:.3} (limit: {:.3})",
534                                    metric_name, -delta, drop_limit
535                                ),
536                            ));
537                        }
538                    }
539                } else {
540                    // Missing baseline
541                    return Some((
542                        TestStatus::Warn,
543                        format!("missing baseline for {}/{}", tc.id, metric_name),
544                    ));
545                }
546            }
547        }
548        None
549    }
550
551    fn resolve_threshold_config(
552        &self,
553        _tc: &TestCase,
554        _metric_name: &str,
555        suite_defaults: Option<&crate::model::ThresholdingSettings>,
556    ) -> (String, Option<f64>) {
557        // Defaults
558        let mut mode = "absolute".to_string();
559        let mut max_drop = None;
560
561        if let Some(s) = suite_defaults {
562            if let Some(m) = &s.mode {
563                mode = m.clone();
564            }
565            max_drop = s.max_drop;
566        }
567
568        // TODO: Map metric_name to strict Expected variant fields for per-test overrides.
569        // Currently relies on global suite defaults.
570        (mode, max_drop)
571    }
572
573    // Embeddings logic
    /// For `SemanticSimilarityTo` tests, ensure response and reference
    /// embeddings are present under `resp.meta.assay.embeddings`.
    ///
    /// No-op for other `Expected` variants or when both embeddings already
    /// exist in trace meta. In strict replay mode, missing embeddings are a
    /// hard error instead of being computed live.
    async fn enrich_semantic(&self, tc: &TestCase, resp: &mut LlmResponse) -> anyhow::Result<()> {
        use crate::model::Expected;

        // Only semantic-similarity expectations need embeddings.
        let Expected::SemanticSimilarityTo {
            semantic_similarity_to,
            ..
        } = &tc.expected
        else {
            return Ok(());
        };

        // Already present (e.g. precomputed in a trace) — nothing to do.
        if resp.meta.pointer("/assay/embeddings/response").is_some()
            && resp.meta.pointer("/assay/embeddings/reference").is_some()
        {
            return Ok(());
        }

        if self.policy.replay_strict {
            anyhow::bail!("config error: --replay-strict is on, but embeddings are missing in trace. Run 'assay trace precompute-embeddings' or disable strict mode.");
        }

        // Live computation requires a configured embedder.
        let embedder = self.embedder.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "config error: semantic_similarity_to requires an embedder (--embedder) or trace meta embeddings"
            )
        })?;

        let model_id = embedder.model_id();

        // Embed both sides (cache-aware; see embed_text).
        let (resp_vec, src_resp) = self
            .embed_text(&model_id, embedder.as_ref(), &resp.text)
            .await?;
        let (ref_vec, src_ref) = self
            .embed_text(&model_id, embedder.as_ref(), semantic_similarity_to)
            .await?;

        // write into meta.assay.embeddings, creating the assay object if the
        // key is absent or not an object
        if !resp.meta.get("assay").is_some_and(|v| v.is_object()) {
            resp.meta["assay"] = serde_json::json!({});
        }
        resp.meta["assay"]["embeddings"] = serde_json::json!({
            "model": model_id,
            "response": resp_vec,
            "reference": ref_vec,
            "source_response": src_resp,
            "source_reference": src_ref
        });

        Ok(())
    }
624
625    pub async fn embed_text(
626        &self,
627        model_id: &str,
628        embedder: &dyn crate::providers::embedder::Embedder,
629        text: &str,
630    ) -> anyhow::Result<(Vec<f32>, &'static str)> {
631        use crate::embeddings::util::embed_cache_key;
632
633        let key = embed_cache_key(model_id, text);
634
635        if !self.refresh_embeddings {
636            if let Some((_m, vec)) = self.store.get_embedding(&key)? {
637                return Ok((vec, "cache"));
638            }
639        }
640
641        let vec = embedder.embed(text).await?;
642        self.store.put_embedding(&key, model_id, &vec)?;
643        Ok((vec, "live"))
644    }
645
    /// For faithfulness/relevance tests, ensure a judge verdict exists by
    /// delegating to the judge service (which receives `resp.meta` mutably).
    ///
    /// No-op for other `Expected` variants. In strict replay mode, a missing
    /// verdict under `/assay/judge/<rubric>` is a hard error instead of a
    /// live judge call.
    async fn enrich_judge(&self, tc: &TestCase, resp: &mut LlmResponse) -> anyhow::Result<()> {
        use crate::model::Expected;

        // Only these two variants are judged; anything else passes through.
        let (rubric_id, rubric_version) = match &tc.expected {
            Expected::Faithfulness { rubric_version, .. } => {
                ("faithfulness", rubric_version.as_deref())
            }
            Expected::Relevance { rubric_version, .. } => ("relevance", rubric_version.as_deref()),
            _ => return Ok(()),
        };

        // Check if judge result exists in meta is handled by JudgeService::evaluate
        // BUT for a better error message in strict mode we can check here too or rely on the StrictLlmClient failure.
        // User requested: "judge guard ... missing judge result in trace meta ... run precompute-judge"

        // Trace meta path for a precomputed verdict: /assay/judge/<rubric_id>
        let has_trace = resp
            .meta
            .pointer(&format!("/assay/judge/{}", rubric_id))
            .is_some();
        if self.policy.replay_strict && !has_trace {
            anyhow::bail!("config error: --replay-strict is on, but judge results are missing in trace for '{}'. Run 'assay trace precompute-judge' or disable strict mode.", rubric_id);
        }

        // Live evaluation requires an initialized judge service.
        let judge = self.judge.as_ref().ok_or_else(|| {
            anyhow::anyhow!("config error: judge required but service not initialized")
        })?;

        judge
            .evaluate(
                &tc.id,
                rubric_id,
                &tc.input,
                &resp.text,
                rubric_version,
                &mut resp.meta,
            )
            .await?;

        Ok(())
    }
686}
687
/// Cloneable, owned bundle of `Runner`'s fields, used to move the runner's
/// state into per-test tasks spawned by `run_suite` (see
/// `Runner::clone_for_task`). Fields mirror `Runner` exactly.
#[derive(Clone)]
struct RunnerRef {
    store: Store,
    cache: VcrCache,
    client: Arc<dyn LlmClient>,
    metrics: Vec<Arc<dyn Metric>>,
    policy: RunPolicy,
    embedder: Option<Arc<dyn crate::providers::embedder::Embedder>>,
    refresh_embeddings: bool,
    incremental: bool,
    refresh_cache: bool,
    judge: Option<crate::judge::JudgeService>,
    baseline: Option<crate::baseline::Baseline>,
}
702
703impl RunnerRef {
704    async fn run_test_with_policy(
705        &self,
706        cfg: &EvalConfig,
707        tc: &TestCase,
708        run_id: i64,
709    ) -> anyhow::Result<TestResultRow> {
710        let runner = Runner {
711            store: self.store.clone(),
712            cache: self.cache.clone(),
713            client: self.client.clone(),
714            metrics: self.metrics.clone(),
715            policy: self.policy.clone(),
716            embedder: self.embedder.clone(),
717            refresh_embeddings: self.refresh_embeddings,
718            incremental: self.incremental,
719            refresh_cache: self.refresh_cache,
720            judge: self.judge.clone(),
721            baseline: self.baseline.clone(),
722        };
723        runner.run_test_with_policy(cfg, tc, run_id).await
724    }
725}