Skip to main content

assay_core/engine/
runner.rs

1use crate::cache::key::cache_key;
2use crate::cache::vcr::VcrCache;
3use crate::metrics_api::Metric;
4use crate::model::{EvalConfig, LlmResponse, TestCase, TestResultRow, TestStatus};
5use crate::providers::llm::LlmClient;
6use crate::quarantine::QuarantineMode;
7use crate::report::progress::ProgressSink;
8use crate::report::RunArtifacts;
9use crate::storage::store::Store;
10use std::sync::Arc;
11use tracing::{info_span, Instrument};
12
13#[path = "runner_next/mod.rs"]
14mod runner_next;
15
/// Policy knobs governing how a suite run handles failures, quarantined
/// tests, and trace replay.
#[derive(Debug, Clone)]
pub struct RunPolicy {
    /// Re-run budget for failing tests: a test may be attempted up to
    /// `rerun_failures` additional times (the contract tests in this file show
    /// `rerun_failures + 1` recorded attempts, and a fail-then-pass sequence
    /// is classified `Flaky`).
    pub rerun_failures: u32,
    /// How quarantined tests are treated (defaults to `QuarantineMode::Warn`).
    pub quarantine_mode: QuarantineMode,
    /// NOTE(review): presumably enforces strict matching when replaying from
    /// traces — the flag is consumed in `runner_next`, not in this file; confirm there.
    pub replay_strict: bool,
}
22
23impl Default for RunPolicy {
24    fn default() -> Self {
25        Self {
26            rerun_failures: 1,
27            quarantine_mode: QuarantineMode::Warn,
28            replay_strict: false,
29        }
30    }
31}
32
/// Executes an evaluation suite: resolves LLM responses (live, VCR-cached, or
/// skipped via fingerprint), runs metrics, and records results.
pub struct Runner {
    /// Result store; also serves last-passing-fingerprint lookups for the
    /// incremental skip path.
    pub store: Store,
    /// VCR-style response cache, keyed by model, prompt, fingerprint, and
    /// client fingerprint (see `run_test_once`).
    pub cache: VcrCache,
    /// LLM provider used for live completions.
    pub client: Arc<dyn LlmClient>,
    /// Metrics evaluated against each response, in declaration order.
    pub metrics: Vec<Arc<dyn Metric>>,
    /// Retry / quarantine / replay policy for this run.
    pub policy: RunPolicy,
    /// Held for its lifetime only, never read — NOTE(review): presumably an
    /// RAII guard enforcing the network policy while the runner lives; confirm
    /// in `providers::network`.
    pub _network_guard: Option<crate::providers::network::NetworkPolicyGuard>,
    /// Optional embedder used by semantic enrichment / `embed_text`.
    pub embedder: Option<Arc<dyn crate::providers::embedder::Embedder>>,
    /// NOTE(review): presumably forces recomputation of cached embeddings —
    /// consumed in `runner_next`; confirm there.
    pub refresh_embeddings: bool,
    /// When true (and `refresh_cache` is false), tests whose fingerprint has a
    /// previous passing result are skipped.
    pub incremental: bool,
    /// When true, bypasses both the incremental skip and the VCR cache.
    pub refresh_cache: bool,
    /// Optional LLM-as-judge service used by `enrich_judge`.
    pub judge: Option<crate::judge::JudgeService>,
    /// Optional baseline used for post-metric regression checks.
    pub baseline: Option<crate::baseline::Baseline>,
}
47
48impl Runner {
49    /// Run the suite; results are collected in completion order internally but returned
50    /// sorted by test_id for deterministic output. If `progress` is set, it is called
51    /// after each test completes (E4.3 realtime progress).
52    pub async fn run_suite(
53        &self,
54        cfg: &EvalConfig,
55        progress: Option<ProgressSink>,
56    ) -> anyhow::Result<RunArtifacts> {
57        runner_next::execute::run_suite_impl(self, cfg, progress).await
58    }
59
60    fn apply_agent_assertions(
61        &self,
62        run_id: i64,
63        tc: &TestCase,
64        final_row: &mut TestResultRow,
65    ) -> anyhow::Result<()> {
66        if let Some(assertions) = &tc.assertions {
67            if !assertions.is_empty() {
68                match crate::agent_assertions::verify_assertions(
69                    &self.store,
70                    run_id,
71                    &tc.id,
72                    assertions,
73                ) {
74                    Ok(diags) => {
75                        if !diags.is_empty() {
76                            // Assertion Failures
77                            final_row.status = TestStatus::Fail;
78
79                            let diag_json: Vec<serde_json::Value> = diags
80                                .iter()
81                                .map(|d| serde_json::to_value(d).unwrap_or_default())
82                                .collect();
83
84                            final_row.details["assertions"] = serde_json::Value::Array(diag_json);
85
86                            let fail_msg = format!("assertions failed ({})", diags.len());
87                            if final_row.message == "ok" {
88                                final_row.message = fail_msg;
89                            } else {
90                                final_row.message = format!("{}; {}", final_row.message, fail_msg);
91                            }
92                        } else {
93                            // passed
94                            final_row.details["assertions"] = serde_json::json!({ "passed": true });
95                        }
96                    }
97                    Err(e) => {
98                        // Missing or Ambiguous Episode -> Fail
99                        final_row.status = TestStatus::Fail;
100                        final_row.message = format!("assertions error: {}", e);
101                        final_row.details["assertions"] =
102                            serde_json::json!({ "error": e.to_string() });
103                    }
104                }
105            }
106        }
107        Ok(())
108    }
109
110    async fn run_test_once(
111        &self,
112        cfg: &EvalConfig,
113        tc: &TestCase,
114    ) -> anyhow::Result<(TestResultRow, LlmResponse)> {
115        let expected_json = serde_json::to_string(&tc.expected).unwrap_or_default();
116        let metric_versions = [("assay", env!("CARGO_PKG_VERSION"))];
117
118        let policy_hash = if let Some(path) = tc.expected.get_policy_path() {
119            // Read policy content to ensure cache invalidation on content change
120            match std::fs::read_to_string(path) {
121                Ok(content) => Some(crate::fingerprint::sha256_hex(&content)),
122                Err(_) => None, // If file missing, finding it later will error.
123                                // We don't fail here to allow error reporting in metrics phase or main loop.
124            }
125        } else {
126            None
127        };
128
129        let fp = crate::fingerprint::compute(crate::fingerprint::Context {
130            suite: &cfg.suite,
131            model: &cfg.model,
132            test_id: &tc.id,
133            prompt: &tc.input.prompt,
134            context: tc.input.context.as_deref(),
135            expected_canonical: &expected_json,
136            policy_hash: policy_hash.as_deref(),
137            metric_versions: &metric_versions,
138        });
139
140        // Incremental cache check.
141        if self.incremental && !self.refresh_cache {
142            if let Some(prev) = self.store.get_last_passing_by_fingerprint(&fp.hex)? {
143                // Return Skipped Result
144                let row = TestResultRow {
145                    test_id: tc.id.clone(),
146                    status: TestStatus::Skipped,
147                    score: prev.score,
148                    cached: true,
149                    message: "skipped: fingerprint match".into(),
150                    details: serde_json::json!({
151                        "skip": {
152                             "reason": "fingerprint_match",
153                             "fingerprint": fp.hex,
154                             "previous_run_id": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("previous_run_id")).and_then(|v: &serde_json::Value| v.as_i64()),
155                             "previous_at": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("previous_at")).and_then(|v: &serde_json::Value| v.as_str()),
156                             "origin_run_id": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("origin_run_id")).and_then(|v: &serde_json::Value| v.as_i64()),
157                             "previous_score": prev.score
158                        }
159                    }),
160                    duration_ms: Some(0), // Instant
161                    fingerprint: Some(fp.hex.clone()),
162                    skip_reason: Some("fingerprint_match".into()),
163                    attempts: None,
164                    error_policy_applied: None,
165                };
166
167                // Construct placeholder response for pipeline consistency
168                let resp = LlmResponse {
169                    text: "".into(),
170                    provider: "skipped".into(),
171                    model: cfg.model.clone(),
172                    cached: true,
173                    meta: serde_json::json!({}),
174                };
175                return Ok((row, resp));
176            }
177        }
178
179        // Original Execution Logic
180        // We use the computed fingerprint for caching key to distinguish config variations
181        let key = cache_key(
182            &cfg.model,
183            &tc.input.prompt,
184            &fp.hex,
185            self.client.fingerprint().as_deref(),
186        );
187
188        let start = std::time::Instant::now();
189        let mut cached = false;
190
191        let mut resp: LlmResponse = if cfg.settings.cache.unwrap_or(true) && !self.refresh_cache {
192            if let Some(r) = self.cache.get(&key)? {
193                cached = true;
194                eprintln!(
195                    "  [CACHE HIT] key={} prompt_len={}",
196                    key,
197                    tc.input.prompt.len()
198                );
199                r
200            } else {
201                let r = self.call_llm(cfg, tc).await?;
202                self.cache.put(&key, &r)?;
203                r
204            }
205        } else {
206            self.call_llm(cfg, tc).await?
207        };
208        resp.cached = resp.cached || cached;
209
210        // Semantic Enrichment
211        self.enrich_semantic(cfg, tc, &mut resp).await?;
212        self.enrich_judge(cfg, tc, &mut resp).await?;
213
214        let mut final_status = TestStatus::Pass;
215        let mut final_score: Option<f64> = None;
216        let mut msg = String::new();
217        let mut details = serde_json::json!({ "metrics": {} });
218
219        for m in &self.metrics {
220            let metric_name = m.name();
221            let metric_span = info_span!(
222                "assay.eval.metric",
223                "assay.eval.test_id" = tc.id.as_str(),
224                "assay.eval.metric.name" = metric_name,
225                "assay.eval.response.cached" = resp.cached,
226                "assay.eval.metric.score" = tracing::field::Empty,
227                "assay.eval.metric.passed" = tracing::field::Empty,
228                "assay.eval.metric.unstable" = tracing::field::Empty,
229                "assay.eval.metric.duration_ms" = tracing::field::Empty,
230                "error" = tracing::field::Empty,
231                "error.message" = tracing::field::Empty
232            );
233            let metric_start = std::time::Instant::now();
234            let metric_result = async { m.evaluate(tc, &tc.expected, &resp).await }
235                .instrument(metric_span.clone())
236                .await;
237            let metric_duration_ms = metric_start.elapsed().as_millis() as u64;
238            metric_span.record("assay.eval.metric.duration_ms", metric_duration_ms);
239
240            let r = match metric_result {
241                Ok(result) => {
242                    metric_span.record("assay.eval.metric.score", result.score);
243                    metric_span.record("assay.eval.metric.passed", result.passed);
244                    metric_span.record("assay.eval.metric.unstable", result.unstable);
245                    result
246                }
247                Err(err) => {
248                    let error_message = err.to_string();
249                    metric_span.record("error", true);
250                    metric_span.record("error.message", error_message.as_str());
251                    return Err(err);
252                }
253            };
254
255            details["metrics"][metric_name] = serde_json::json!({
256                "score": r.score, "passed": r.passed, "unstable": r.unstable, "details": r.details
257            });
258            final_score = Some(r.score);
259
260            if r.unstable {
261                // gate stability first: treat unstable as warn in MVP
262                final_status = TestStatus::Warn;
263                msg = format!("unstable metric: {}", metric_name);
264                break;
265            }
266            if !r.passed {
267                final_status = TestStatus::Fail;
268                msg = format!("failed: {}", metric_name);
269                break;
270            }
271        }
272
273        // Post-metric baseline check
274        if let Some(baseline) = &self.baseline {
275            if let Some((new_status, new_msg)) =
276                self.check_baseline_regressions(tc, cfg, &details, &self.metrics, baseline)
277            {
278                if matches!(new_status, TestStatus::Fail | TestStatus::Warn) {
279                    final_status = new_status;
280                    msg = new_msg;
281                }
282            }
283        }
284
285        let duration_ms = start.elapsed().as_millis() as u64;
286        let mut row = TestResultRow {
287            test_id: tc.id.clone(),
288            status: final_status,
289            score: final_score,
290            cached: resp.cached,
291            message: if msg.is_empty() { "ok".into() } else { msg },
292            details,
293            duration_ms: Some(duration_ms),
294            fingerprint: Some(fp.hex),
295            skip_reason: None,
296            attempts: None,
297            error_policy_applied: None,
298        };
299
300        if self.client.provider_name() == "trace" {
301            row.details["assay.replay"] = serde_json::json!(true);
302        }
303
304        row.details["prompt"] = serde_json::Value::String(tc.input.prompt.clone());
305
306        Ok((row, resp))
307    }
308
309    async fn call_llm(&self, cfg: &EvalConfig, tc: &TestCase) -> anyhow::Result<LlmResponse> {
310        runner_next::execute::call_llm_impl(self, cfg, tc).await
311    }
312
313    fn check_baseline_regressions(
314        &self,
315        tc: &TestCase,
316        cfg: &EvalConfig,
317        details: &serde_json::Value,
318        metrics: &[Arc<dyn Metric>],
319        baseline: &crate::baseline::Baseline,
320    ) -> Option<(TestStatus, String)> {
321        runner_next::baseline::check_baseline_regressions_impl(
322            self, tc, cfg, details, metrics, baseline,
323        )
324    }
325
326    // Embeddings logic
327    async fn enrich_semantic(
328        &self,
329        _cfg: &EvalConfig,
330        tc: &TestCase,
331        resp: &mut LlmResponse,
332    ) -> anyhow::Result<()> {
333        runner_next::scoring::enrich_semantic_impl(self, _cfg, tc, resp).await
334    }
335
336    pub async fn embed_text(
337        &self,
338        model_id: &str,
339        embedder: &dyn crate::providers::embedder::Embedder,
340        text: &str,
341    ) -> anyhow::Result<(Vec<f32>, &'static str)> {
342        runner_next::cache::embed_text_impl(self, model_id, embedder, text).await
343    }
344
345    async fn enrich_judge(
346        &self,
347        cfg: &EvalConfig,
348        tc: &TestCase,
349        resp: &mut LlmResponse,
350    ) -> anyhow::Result<()> {
351        runner_next::scoring::enrich_judge_impl(self, cfg, tc, resp).await
352    }
353}
354
355#[cfg(test)]
356mod tests {
357    use super::*;
358    use crate::metrics_api::{Metric, MetricResult};
359    use crate::model::{Expected, Settings, TestInput};
360    use crate::on_error::ErrorPolicy;
361    use crate::providers::llm::fake::FakeClient;
362    use crate::providers::llm::LlmClient;
363    use async_trait::async_trait;
364    use std::sync::atomic::{AtomicUsize, Ordering};
365
    /// Scripted verdict sequence for `ScriptedMetric`.
    #[derive(Clone, Copy)]
    enum MetricMode {
        /// Fail on the first evaluation, pass on every later one.
        FailThenPass,
        /// Fail on every evaluation.
        AlwaysFail,
        /// Pass on every evaluation.
        AlwaysPass,
    }
372
    /// Test metric whose pass/fail behavior follows a `MetricMode` script.
    struct ScriptedMetric {
        /// Which scripted sequence to follow.
        mode: MetricMode,
        /// Number of `evaluate` calls so far (lets `FailThenPass` fail only once).
        calls: AtomicUsize,
    }
377
378    impl ScriptedMetric {
379        fn fail_then_pass() -> Self {
380            Self {
381                mode: MetricMode::FailThenPass,
382                calls: AtomicUsize::new(0),
383            }
384        }
385
386        fn always_fail() -> Self {
387            Self {
388                mode: MetricMode::AlwaysFail,
389                calls: AtomicUsize::new(0),
390            }
391        }
392
393        fn always_pass() -> Self {
394            Self {
395                mode: MetricMode::AlwaysPass,
396                calls: AtomicUsize::new(0),
397            }
398        }
399    }
400
401    #[async_trait]
402    impl Metric for ScriptedMetric {
403        fn name(&self) -> &'static str {
404            "scripted"
405        }
406
407        async fn evaluate(
408            &self,
409            _tc: &TestCase,
410            _expected: &Expected,
411            _resp: &LlmResponse,
412        ) -> anyhow::Result<MetricResult> {
413            let n = self.calls.fetch_add(1, Ordering::SeqCst);
414            match self.mode {
415                MetricMode::FailThenPass => {
416                    if n == 0 {
417                        Ok(MetricResult::fail(0.0, "scripted_fail_once"))
418                    } else {
419                        Ok(MetricResult::pass(1.0))
420                    }
421                }
422                MetricMode::AlwaysFail => Ok(MetricResult::fail(0.0, "scripted_fail")),
423                MetricMode::AlwaysPass => Ok(MetricResult::pass(1.0)),
424            }
425        }
426    }
427
    /// LLM client stub whose `complete` always errors; used to exercise
    /// on-error policy handling.
    struct ErrorClient;
429
430    #[async_trait]
431    impl LlmClient for ErrorClient {
432        async fn complete(
433            &self,
434            _prompt: &str,
435            _context: Option<&[String]>,
436        ) -> anyhow::Result<LlmResponse> {
437            Err(anyhow::anyhow!("scripted provider error"))
438        }
439
440        fn provider_name(&self) -> &'static str {
441            "error_client"
442        }
443    }
444
445    fn runner_for_contract_tests(
446        client: Arc<dyn LlmClient>,
447        metrics: Vec<Arc<dyn Metric>>,
448        rerun_failures: u32,
449    ) -> Runner {
450        let store = Store::memory().expect("in-memory store");
451        store.init_schema().expect("schema init");
452        Runner {
453            store: store.clone(),
454            cache: VcrCache::new(store),
455            client,
456            metrics,
457            policy: RunPolicy {
458                rerun_failures,
459                quarantine_mode: QuarantineMode::Off,
460                replay_strict: false,
461            },
462            _network_guard: None,
463            embedder: None,
464            refresh_embeddings: false,
465            incremental: false,
466            refresh_cache: false,
467            judge: None,
468            baseline: None,
469        }
470    }
471
472    fn single_test_config(on_error: ErrorPolicy) -> EvalConfig {
473        EvalConfig {
474            version: 1,
475            suite: "runner-contract".to_string(),
476            model: "fake-model".to_string(),
477            settings: Settings {
478                parallel: Some(1),
479                cache: Some(false),
480                seed: Some(1234),
481                on_error,
482                ..Default::default()
483            },
484            thresholds: Default::default(),
485            otel: Default::default(),
486            tests: vec![TestCase {
487                id: "t1".to_string(),
488                input: TestInput {
489                    prompt: "contract prompt".to_string(),
490                    context: None,
491                },
492                // Expected payload is not used by scripted metrics, but keeps test case valid.
493                expected: Expected::MustContain {
494                    must_contain: vec!["ok".to_string()],
495                },
496                assertions: None,
497                on_error: None,
498                tags: vec![],
499                metadata: None,
500            }],
501        }
502    }
503
504    fn config_with_test_ids(ids: &[&str], on_error: ErrorPolicy) -> EvalConfig {
505        EvalConfig {
506            version: 1,
507            suite: "runner-contract".to_string(),
508            model: "fake-model".to_string(),
509            settings: Settings {
510                parallel: Some(1),
511                cache: Some(false),
512                seed: Some(1234),
513                on_error,
514                ..Default::default()
515            },
516            thresholds: Default::default(),
517            otel: Default::default(),
518            tests: ids
519                .iter()
520                .map(|id| TestCase {
521                    id: (*id).to_string(),
522                    input: TestInput {
523                        prompt: format!("prompt-{id}"),
524                        context: None,
525                    },
526                    expected: Expected::MustContain {
527                        must_contain: vec!["ok".to_string()],
528                    },
529                    assertions: None,
530                    on_error: None,
531                    tags: vec![],
532                    metadata: None,
533                })
534                .collect(),
535        }
536    }
537
538    #[tokio::test]
539    async fn runner_contract_flake_fail_then_pass_classified_flaky() -> anyhow::Result<()> {
540        let cfg = single_test_config(ErrorPolicy::Block);
541        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
542        let metric = Arc::new(ScriptedMetric::fail_then_pass());
543        let runner = runner_for_contract_tests(client, vec![metric], 1);
544
545        let artifacts = runner.run_suite(&cfg, None).await?;
546        let row = artifacts
547            .results
548            .iter()
549            .find(|r| r.test_id == "t1")
550            .expect("result for t1");
551
552        assert_eq!(row.status, TestStatus::Flaky);
553        assert_eq!(row.message, "flake detected (rerun passed)");
554        let attempts = row.attempts.as_ref().expect("attempts");
555        assert_eq!(attempts.len(), 2);
556        assert_eq!(attempts[0].status, TestStatus::Fail);
557        assert_eq!(attempts[1].status, TestStatus::Pass);
558        Ok(())
559    }
560
561    #[tokio::test]
562    async fn runner_contract_fail_after_retries_stays_fail() -> anyhow::Result<()> {
563        let cfg = single_test_config(ErrorPolicy::Block);
564        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
565        let metric = Arc::new(ScriptedMetric::always_fail());
566        let runner = runner_for_contract_tests(client, vec![metric], 1);
567
568        let artifacts = runner.run_suite(&cfg, None).await?;
569        let row = artifacts
570            .results
571            .iter()
572            .find(|r| r.test_id == "t1")
573            .expect("result for t1");
574
575        assert_eq!(row.status, TestStatus::Fail);
576        assert!(
577            row.message.contains("failed: scripted"),
578            "expected stable failure reason, got: {}",
579            row.message
580        );
581        let attempts = row.attempts.as_ref().expect("attempts");
582        assert_eq!(attempts.len(), 2);
583        assert_eq!(attempts[0].status, TestStatus::Fail);
584        assert_eq!(attempts[1].status, TestStatus::Fail);
585        Ok(())
586    }
587
588    #[tokio::test]
589    async fn runner_contract_on_error_allow_marks_allowed_and_policy_applied() -> anyhow::Result<()>
590    {
591        let cfg = single_test_config(ErrorPolicy::Allow);
592        let client = Arc::new(ErrorClient);
593        let runner = runner_for_contract_tests(client, vec![], 2);
594
595        let artifacts = runner.run_suite(&cfg, None).await?;
596        let row = artifacts
597            .results
598            .iter()
599            .find(|r| r.test_id == "t1")
600            .expect("result for t1");
601
602        assert_eq!(row.status, TestStatus::AllowedOnError);
603        assert_eq!(row.error_policy_applied, Some(ErrorPolicy::Allow));
604        assert_eq!(row.details["policy_applied"], serde_json::json!("allow"));
605        let attempts = row.attempts.as_ref().expect("attempts");
606        assert_eq!(attempts.len(), 1);
607        assert_eq!(attempts[0].status, TestStatus::AllowedOnError);
608        Ok(())
609    }
610
611    #[tokio::test]
612    async fn runner_contract_results_sorted_by_test_id() -> anyhow::Result<()> {
613        let mut cfg = config_with_test_ids(&["t3", "t1", "t2"], ErrorPolicy::Block);
614        cfg.settings.parallel = Some(3);
615        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
616        let metric = Arc::new(ScriptedMetric::always_pass());
617        let runner = runner_for_contract_tests(client, vec![metric], 0);
618
619        let artifacts = runner.run_suite(&cfg, None).await?;
620        let ids: Vec<_> = artifacts
621            .results
622            .iter()
623            .map(|r| r.test_id.as_str())
624            .collect();
625        assert_eq!(ids, vec!["t1", "t2", "t3"]);
626        Ok(())
627    }
628
    #[tokio::test]
    async fn runner_contract_progress_sink_reports_done_total() -> anyhow::Result<()> {
        let cfg = config_with_test_ids(&["p1", "p2", "p3"], ErrorPolicy::Block);
        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::always_pass());
        let runner = runner_for_contract_tests(client, vec![metric], 0);

        // Record every (done, total) pair the progress callback observes.
        let events = Arc::new(std::sync::Mutex::new(Vec::<(usize, usize)>::new()));
        let sink = {
            let events = Arc::clone(&events);
            Arc::new(move |ev: crate::report::progress::ProgressEvent| {
                events
                    .lock()
                    .expect("progress lock")
                    .push((ev.done, ev.total));
            }) as crate::report::progress::ProgressSink
        };

        let artifacts = runner.run_suite(&cfg, Some(sink)).await?;
        assert_eq!(artifacts.results.len(), 3);

        // One event per test, the final one reporting completion, and `done`
        // strictly increasing across events.
        let observed = events.lock().expect("progress lock");
        assert_eq!(observed.len(), 3);
        assert_eq!(observed.last(), Some(&(3, 3)));
        assert!(observed.windows(2).all(|w| w[0].0 < w[1].0));
        Ok(())
    }
656
    #[tokio::test]
    async fn runner_contract_relative_baseline_missing_warns_in_helper() -> anyhow::Result<()> {
        // Relative thresholding against a baseline with no entries: the helper
        // should produce a Warn about the missing entry rather than a Fail.
        let mut cfg = single_test_config(ErrorPolicy::Block);
        cfg.settings.thresholding = Some(crate::model::ThresholdingSettings {
            mode: Some("relative".to_string()),
            max_drop: Some(0.05),
            min_floor: None,
        });

        // The client is never invoked; the helper is exercised directly below.
        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::always_pass());
        let runner = runner_for_contract_tests(client, vec![], 0);
        // Baseline with an empty `entries` list — nothing to compare against.
        let baseline = crate::baseline::Baseline {
            schema_version: 1,
            suite: "runner-contract".to_string(),
            assay_version: env!("CARGO_PKG_VERSION").to_string(),
            created_at: "2026-01-01T00:00:00Z".to_string(),
            config_fingerprint: "md5:test".to_string(),
            git_info: None,
            entries: vec![],
        };
        let tc = cfg.tests.first().cloned().expect("single test case");
        // Details shaped like a passing "scripted" metric result.
        let details = serde_json::json!({
            "metrics": {
                "scripted": {
                    "score": 1.0,
                    "passed": true,
                    "unstable": false,
                    "details": {}
                }
            }
        });

        let verdict = runner.check_baseline_regressions(&tc, &cfg, &details, &[metric], &baseline);
        let (status, message) = verdict.expect("relative baseline decision");
        assert_eq!(status, TestStatus::Warn);
        assert_eq!(message, "missing baseline for t1/scripted");
        Ok(())
    }
696}