// assay_core/engine/runner.rs
1use crate::cache::key::cache_key;
2use crate::cache::vcr::VcrCache;
3use crate::metrics_api::Metric;
4use crate::model::{EvalConfig, LlmResponse, TestCase, TestResultRow, TestStatus};
5use crate::providers::llm::LlmClient;
6use crate::quarantine::QuarantineMode;
7use crate::report::progress::ProgressSink;
8use crate::report::RunArtifacts;
9use crate::storage::store::Store;
10use std::sync::Arc;
11
12#[path = "runner_next/mod.rs"]
13mod runner_next;
14
/// Knobs controlling how a run treats failures, quarantined tests, and replay.
#[derive(Debug, Clone)]
pub struct RunPolicy {
    /// How many times a failing test is re-run (used to classify flakes).
    pub rerun_failures: u32,
    /// How quarantined tests are reported (Warn by default).
    pub quarantine_mode: QuarantineMode,
    /// Strict replay mode flag — presumably turns replay mismatches into hard
    /// errors; semantics live in runner_next, TODO confirm.
    pub replay_strict: bool,
}
21
22impl Default for RunPolicy {
23    fn default() -> Self {
24        Self {
25            rerun_failures: 1,
26            quarantine_mode: QuarantineMode::Warn,
27            replay_strict: false,
28        }
29    }
30}
31
/// Orchestrates suite execution: LLM calls (with VCR caching), metric
/// evaluation, agent-assertion verification, baseline checks, and persistence.
pub struct Runner {
    /// Persistence backend (results, fingerprints, episodes).
    pub store: Store,
    /// Record/replay cache for LLM responses, keyed by `cache_key`.
    pub cache: VcrCache,
    /// Provider used to answer test prompts.
    pub client: Arc<dyn LlmClient>,
    /// Metrics evaluated in order against each response.
    pub metrics: Vec<Arc<dyn Metric>>,
    /// Rerun/quarantine/replay policy for this run.
    pub policy: RunPolicy,
    /// Held only for its side effect (network policy enforcement while alive).
    pub _network_guard: Option<crate::providers::network::NetworkPolicyGuard>,
    /// Optional embedder for semantic enrichment.
    pub embedder: Option<Arc<dyn crate::providers::embedder::Embedder>>,
    /// When true, recompute embeddings even if cached — TODO confirm in runner_next::cache.
    pub refresh_embeddings: bool,
    /// When true, skip tests whose fingerprint matches a previous passing run.
    pub incremental: bool,
    /// When true, bypass both the incremental skip and the response cache.
    pub refresh_cache: bool,
    /// Optional LLM-as-judge service for judge enrichment.
    pub judge: Option<crate::judge::JudgeService>,
    /// Optional baseline; when set, post-metric regression checks run.
    pub baseline: Option<crate::baseline::Baseline>,
}
46
47impl Runner {
48    /// Run the suite; results are collected in completion order internally but returned
49    /// sorted by test_id for deterministic output. If `progress` is set, it is called
50    /// after each test completes (E4.3 realtime progress).
51    pub async fn run_suite(
52        &self,
53        cfg: &EvalConfig,
54        progress: Option<ProgressSink>,
55    ) -> anyhow::Result<RunArtifacts> {
56        runner_next::execute::run_suite_impl(self, cfg, progress).await
57    }
58
59    fn apply_agent_assertions(
60        &self,
61        run_id: i64,
62        tc: &TestCase,
63        final_row: &mut TestResultRow,
64    ) -> anyhow::Result<()> {
65        if let Some(assertions) = &tc.assertions {
66            if !assertions.is_empty() {
67                match crate::agent_assertions::verify_assertions(
68                    &self.store,
69                    run_id,
70                    &tc.id,
71                    assertions,
72                ) {
73                    Ok(diags) => {
74                        if !diags.is_empty() {
75                            // Assertion Failures
76                            final_row.status = TestStatus::Fail;
77
78                            let diag_json: Vec<serde_json::Value> = diags
79                                .iter()
80                                .map(|d| serde_json::to_value(d).unwrap_or_default())
81                                .collect();
82
83                            final_row.details["assertions"] = serde_json::Value::Array(diag_json);
84
85                            let fail_msg = format!("assertions failed ({})", diags.len());
86                            if final_row.message == "ok" {
87                                final_row.message = fail_msg;
88                            } else {
89                                final_row.message = format!("{}; {}", final_row.message, fail_msg);
90                            }
91                        } else {
92                            // passed
93                            final_row.details["assertions"] = serde_json::json!({ "passed": true });
94                        }
95                    }
96                    Err(e) => {
97                        // Missing or Ambiguous Episode -> Fail
98                        final_row.status = TestStatus::Fail;
99                        final_row.message = format!("assertions error: {}", e);
100                        final_row.details["assertions"] =
101                            serde_json::json!({ "error": e.to_string() });
102                    }
103                }
104            }
105        }
106        Ok(())
107    }
108
109    async fn run_test_once(
110        &self,
111        cfg: &EvalConfig,
112        tc: &TestCase,
113    ) -> anyhow::Result<(TestResultRow, LlmResponse)> {
114        let expected_json = serde_json::to_string(&tc.expected).unwrap_or_default();
115        let metric_versions = [("assay", env!("CARGO_PKG_VERSION"))];
116
117        let policy_hash = if let Some(path) = tc.expected.get_policy_path() {
118            // Read policy content to ensure cache invalidation on content change
119            match std::fs::read_to_string(path) {
120                Ok(content) => Some(crate::fingerprint::sha256_hex(&content)),
121                Err(_) => None, // If file missing, finding it later will error.
122                                // We don't fail here to allow error reporting in metrics phase or main loop.
123            }
124        } else {
125            None
126        };
127
128        let fp = crate::fingerprint::compute(crate::fingerprint::Context {
129            suite: &cfg.suite,
130            model: &cfg.model,
131            test_id: &tc.id,
132            prompt: &tc.input.prompt,
133            context: tc.input.context.as_deref(),
134            expected_canonical: &expected_json,
135            policy_hash: policy_hash.as_deref(),
136            metric_versions: &metric_versions,
137        });
138
139        // Incremental cache check.
140        if self.incremental && !self.refresh_cache {
141            if let Some(prev) = self.store.get_last_passing_by_fingerprint(&fp.hex)? {
142                // Return Skipped Result
143                let row = TestResultRow {
144                    test_id: tc.id.clone(),
145                    status: TestStatus::Skipped,
146                    score: prev.score,
147                    cached: true,
148                    message: "skipped: fingerprint match".into(),
149                    details: serde_json::json!({
150                        "skip": {
151                             "reason": "fingerprint_match",
152                             "fingerprint": fp.hex,
153                             "previous_run_id": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("previous_run_id")).and_then(|v: &serde_json::Value| v.as_i64()),
154                             "previous_at": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("previous_at")).and_then(|v: &serde_json::Value| v.as_str()),
155                             "origin_run_id": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("origin_run_id")).and_then(|v: &serde_json::Value| v.as_i64()),
156                             "previous_score": prev.score
157                        }
158                    }),
159                    duration_ms: Some(0), // Instant
160                    fingerprint: Some(fp.hex.clone()),
161                    skip_reason: Some("fingerprint_match".into()),
162                    attempts: None,
163                    error_policy_applied: None,
164                };
165
166                // Construct placeholder response for pipeline consistency
167                let resp = LlmResponse {
168                    text: "".into(),
169                    provider: "skipped".into(),
170                    model: cfg.model.clone(),
171                    cached: true,
172                    meta: serde_json::json!({}),
173                };
174                return Ok((row, resp));
175            }
176        }
177
178        // Original Execution Logic
179        // We use the computed fingerprint for caching key to distinguish config variations
180        let key = cache_key(
181            &cfg.model,
182            &tc.input.prompt,
183            &fp.hex,
184            self.client.fingerprint().as_deref(),
185        );
186
187        let start = std::time::Instant::now();
188        let mut cached = false;
189
190        let mut resp: LlmResponse = if cfg.settings.cache.unwrap_or(true) && !self.refresh_cache {
191            if let Some(r) = self.cache.get(&key)? {
192                cached = true;
193                eprintln!(
194                    "  [CACHE HIT] key={} prompt_len={}",
195                    key,
196                    tc.input.prompt.len()
197                );
198                r
199            } else {
200                let r = self.call_llm(cfg, tc).await?;
201                self.cache.put(&key, &r)?;
202                r
203            }
204        } else {
205            self.call_llm(cfg, tc).await?
206        };
207        resp.cached = resp.cached || cached;
208
209        // Semantic Enrichment
210        self.enrich_semantic(cfg, tc, &mut resp).await?;
211        self.enrich_judge(cfg, tc, &mut resp).await?;
212
213        let mut final_status = TestStatus::Pass;
214        let mut final_score: Option<f64> = None;
215        let mut msg = String::new();
216        let mut details = serde_json::json!({ "metrics": {} });
217
218        for m in &self.metrics {
219            let r = m.evaluate(tc, &tc.expected, &resp).await?;
220            details["metrics"][m.name()] = serde_json::json!({
221                "score": r.score, "passed": r.passed, "unstable": r.unstable, "details": r.details
222            });
223            final_score = Some(r.score);
224
225            if r.unstable {
226                // gate stability first: treat unstable as warn in MVP
227                final_status = TestStatus::Warn;
228                msg = format!("unstable metric: {}", m.name());
229                break;
230            }
231            if !r.passed {
232                final_status = TestStatus::Fail;
233                msg = format!("failed: {}", m.name());
234                break;
235            }
236        }
237
238        // Post-metric baseline check
239        if let Some(baseline) = &self.baseline {
240            if let Some((new_status, new_msg)) =
241                self.check_baseline_regressions(tc, cfg, &details, &self.metrics, baseline)
242            {
243                if matches!(new_status, TestStatus::Fail | TestStatus::Warn) {
244                    final_status = new_status;
245                    msg = new_msg;
246                }
247            }
248        }
249
250        let duration_ms = start.elapsed().as_millis() as u64;
251        let mut row = TestResultRow {
252            test_id: tc.id.clone(),
253            status: final_status,
254            score: final_score,
255            cached: resp.cached,
256            message: if msg.is_empty() { "ok".into() } else { msg },
257            details,
258            duration_ms: Some(duration_ms),
259            fingerprint: Some(fp.hex),
260            skip_reason: None,
261            attempts: None,
262            error_policy_applied: None,
263        };
264
265        if self.client.provider_name() == "trace" {
266            row.details["assay.replay"] = serde_json::json!(true);
267        }
268
269        row.details["prompt"] = serde_json::Value::String(tc.input.prompt.clone());
270
271        Ok((row, resp))
272    }
273
274    async fn call_llm(&self, cfg: &EvalConfig, tc: &TestCase) -> anyhow::Result<LlmResponse> {
275        runner_next::execute::call_llm_impl(self, cfg, tc).await
276    }
277
278    fn check_baseline_regressions(
279        &self,
280        tc: &TestCase,
281        cfg: &EvalConfig,
282        details: &serde_json::Value,
283        metrics: &[Arc<dyn Metric>],
284        baseline: &crate::baseline::Baseline,
285    ) -> Option<(TestStatus, String)> {
286        runner_next::baseline::check_baseline_regressions_impl(
287            self, tc, cfg, details, metrics, baseline,
288        )
289    }
290
291    // Embeddings logic
292    async fn enrich_semantic(
293        &self,
294        _cfg: &EvalConfig,
295        tc: &TestCase,
296        resp: &mut LlmResponse,
297    ) -> anyhow::Result<()> {
298        runner_next::scoring::enrich_semantic_impl(self, _cfg, tc, resp).await
299    }
300
301    pub async fn embed_text(
302        &self,
303        model_id: &str,
304        embedder: &dyn crate::providers::embedder::Embedder,
305        text: &str,
306    ) -> anyhow::Result<(Vec<f32>, &'static str)> {
307        runner_next::cache::embed_text_impl(self, model_id, embedder, text).await
308    }
309
310    async fn enrich_judge(
311        &self,
312        cfg: &EvalConfig,
313        tc: &TestCase,
314        resp: &mut LlmResponse,
315    ) -> anyhow::Result<()> {
316        runner_next::scoring::enrich_judge_impl(self, cfg, tc, resp).await
317    }
318}
319
#[cfg(test)]
mod tests {
    //! Runner contract tests: flake classification, retry exhaustion,
    //! on-error policies, deterministic result ordering, progress reporting,
    //! and relative-baseline verdicts.
    use super::*;
    use crate::metrics_api::{Metric, MetricResult};
    use crate::model::{Expected, Settings, TestInput};
    use crate::on_error::ErrorPolicy;
    use crate::providers::llm::fake::FakeClient;
    use crate::providers::llm::LlmClient;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Scripted outcome sequences for the fake metric below.
    #[derive(Clone, Copy)]
    enum MetricMode {
        FailThenPass,
        AlwaysFail,
        AlwaysPass,
    }

    // Metric whose pass/fail behavior is scripted; counts evaluate() calls so
    // FailThenPass can fail only on the first attempt.
    struct ScriptedMetric {
        mode: MetricMode,
        calls: AtomicUsize,
    }

    impl ScriptedMetric {
        // Fails the first evaluation, passes afterwards (flake simulation).
        fn fail_then_pass() -> Self {
            Self {
                mode: MetricMode::FailThenPass,
                calls: AtomicUsize::new(0),
            }
        }

        // Fails every evaluation.
        fn always_fail() -> Self {
            Self {
                mode: MetricMode::AlwaysFail,
                calls: AtomicUsize::new(0),
            }
        }

        // Passes every evaluation.
        fn always_pass() -> Self {
            Self {
                mode: MetricMode::AlwaysPass,
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl Metric for ScriptedMetric {
        fn name(&self) -> &'static str {
            "scripted"
        }

        async fn evaluate(
            &self,
            _tc: &TestCase,
            _expected: &Expected,
            _resp: &LlmResponse,
        ) -> anyhow::Result<MetricResult> {
            // Attempt index (0-based) before this call.
            let n = self.calls.fetch_add(1, Ordering::SeqCst);
            match self.mode {
                MetricMode::FailThenPass => {
                    if n == 0 {
                        Ok(MetricResult::fail(0.0, "scripted_fail_once"))
                    } else {
                        Ok(MetricResult::pass(1.0))
                    }
                }
                MetricMode::AlwaysFail => Ok(MetricResult::fail(0.0, "scripted_fail")),
                MetricMode::AlwaysPass => Ok(MetricResult::pass(1.0)),
            }
        }
    }

    // Client whose complete() always errors, to exercise on_error policies.
    struct ErrorClient;

    #[async_trait]
    impl LlmClient for ErrorClient {
        async fn complete(
            &self,
            _prompt: &str,
            _context: Option<&[String]>,
        ) -> anyhow::Result<LlmResponse> {
            Err(anyhow::anyhow!("scripted provider error"))
        }

        fn provider_name(&self) -> &'static str {
            "error_client"
        }
    }

    // Builds a Runner over an in-memory store with quarantine, incremental
    // mode, judge, and baseline all disabled.
    fn runner_for_contract_tests(
        client: Arc<dyn LlmClient>,
        metrics: Vec<Arc<dyn Metric>>,
        rerun_failures: u32,
    ) -> Runner {
        let store = Store::memory().expect("in-memory store");
        store.init_schema().expect("schema init");
        Runner {
            store: store.clone(),
            cache: VcrCache::new(store),
            client,
            metrics,
            policy: RunPolicy {
                rerun_failures,
                quarantine_mode: QuarantineMode::Off,
                replay_strict: false,
            },
            _network_guard: None,
            embedder: None,
            refresh_embeddings: false,
            incremental: false,
            refresh_cache: false,
            judge: None,
            baseline: None,
        }
    }

    // Minimal one-test config ("t1") with caching off for determinism.
    fn single_test_config(on_error: ErrorPolicy) -> EvalConfig {
        EvalConfig {
            version: 1,
            suite: "runner-contract".to_string(),
            model: "fake-model".to_string(),
            settings: Settings {
                parallel: Some(1),
                cache: Some(false),
                seed: Some(1234),
                on_error,
                ..Default::default()
            },
            thresholds: Default::default(),
            otel: Default::default(),
            tests: vec![TestCase {
                id: "t1".to_string(),
                input: TestInput {
                    prompt: "contract prompt".to_string(),
                    context: None,
                },
                // Expected payload is not used by scripted metrics, but keeps test case valid.
                expected: Expected::MustContain {
                    must_contain: vec!["ok".to_string()],
                },
                assertions: None,
                on_error: None,
                tags: vec![],
                metadata: None,
            }],
        }
    }

    // Like single_test_config, but with one test case per id in `ids`.
    fn config_with_test_ids(ids: &[&str], on_error: ErrorPolicy) -> EvalConfig {
        EvalConfig {
            version: 1,
            suite: "runner-contract".to_string(),
            model: "fake-model".to_string(),
            settings: Settings {
                parallel: Some(1),
                cache: Some(false),
                seed: Some(1234),
                on_error,
                ..Default::default()
            },
            thresholds: Default::default(),
            otel: Default::default(),
            tests: ids
                .iter()
                .map(|id| TestCase {
                    id: (*id).to_string(),
                    input: TestInput {
                        prompt: format!("prompt-{id}"),
                        context: None,
                    },
                    expected: Expected::MustContain {
                        must_contain: vec!["ok".to_string()],
                    },
                    assertions: None,
                    on_error: None,
                    tags: vec![],
                    metadata: None,
                })
                .collect(),
        }
    }

    // fail-then-pass with one rerun => Flaky, with both attempts recorded.
    #[tokio::test]
    async fn runner_contract_flake_fail_then_pass_classified_flaky() -> anyhow::Result<()> {
        let cfg = single_test_config(ErrorPolicy::Block);
        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::fail_then_pass());
        let runner = runner_for_contract_tests(client, vec![metric], 1);

        let artifacts = runner.run_suite(&cfg, None).await?;
        let row = artifacts
            .results
            .iter()
            .find(|r| r.test_id == "t1")
            .expect("result for t1");

        assert_eq!(row.status, TestStatus::Flaky);
        assert_eq!(row.message, "flake detected (rerun passed)");
        let attempts = row.attempts.as_ref().expect("attempts");
        assert_eq!(attempts.len(), 2);
        assert_eq!(attempts[0].status, TestStatus::Fail);
        assert_eq!(attempts[1].status, TestStatus::Pass);
        Ok(())
    }

    // always-fail with one rerun => Fail, both attempts failing.
    #[tokio::test]
    async fn runner_contract_fail_after_retries_stays_fail() -> anyhow::Result<()> {
        let cfg = single_test_config(ErrorPolicy::Block);
        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::always_fail());
        let runner = runner_for_contract_tests(client, vec![metric], 1);

        let artifacts = runner.run_suite(&cfg, None).await?;
        let row = artifacts
            .results
            .iter()
            .find(|r| r.test_id == "t1")
            .expect("result for t1");

        assert_eq!(row.status, TestStatus::Fail);
        assert!(
            row.message.contains("failed: scripted"),
            "expected stable failure reason, got: {}",
            row.message
        );
        let attempts = row.attempts.as_ref().expect("attempts");
        assert_eq!(attempts.len(), 2);
        assert_eq!(attempts[0].status, TestStatus::Fail);
        assert_eq!(attempts[1].status, TestStatus::Fail);
        Ok(())
    }

    // Provider error + on_error=Allow => AllowedOnError, no reruns consumed.
    #[tokio::test]
    async fn runner_contract_on_error_allow_marks_allowed_and_policy_applied() -> anyhow::Result<()>
    {
        let cfg = single_test_config(ErrorPolicy::Allow);
        let client = Arc::new(ErrorClient);
        let runner = runner_for_contract_tests(client, vec![], 2);

        let artifacts = runner.run_suite(&cfg, None).await?;
        let row = artifacts
            .results
            .iter()
            .find(|r| r.test_id == "t1")
            .expect("result for t1");

        assert_eq!(row.status, TestStatus::AllowedOnError);
        assert_eq!(row.error_policy_applied, Some(ErrorPolicy::Allow));
        assert_eq!(row.details["policy_applied"], serde_json::json!("allow"));
        let attempts = row.attempts.as_ref().expect("attempts");
        assert_eq!(attempts.len(), 1);
        assert_eq!(attempts[0].status, TestStatus::AllowedOnError);
        Ok(())
    }

    // Even with parallel execution, results come back sorted by test_id.
    #[tokio::test]
    async fn runner_contract_results_sorted_by_test_id() -> anyhow::Result<()> {
        let mut cfg = config_with_test_ids(&["t3", "t1", "t2"], ErrorPolicy::Block);
        cfg.settings.parallel = Some(3);
        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::always_pass());
        let runner = runner_for_contract_tests(client, vec![metric], 0);

        let artifacts = runner.run_suite(&cfg, None).await?;
        let ids: Vec<_> = artifacts
            .results
            .iter()
            .map(|r| r.test_id.as_str())
            .collect();
        assert_eq!(ids, vec!["t1", "t2", "t3"]);
        Ok(())
    }

    // Progress sink fires once per test with a strictly increasing done count.
    #[tokio::test]
    async fn runner_contract_progress_sink_reports_done_total() -> anyhow::Result<()> {
        let cfg = config_with_test_ids(&["p1", "p2", "p3"], ErrorPolicy::Block);
        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::always_pass());
        let runner = runner_for_contract_tests(client, vec![metric], 0);

        let events = Arc::new(std::sync::Mutex::new(Vec::<(usize, usize)>::new()));
        let sink = {
            let events = Arc::clone(&events);
            Arc::new(move |ev: crate::report::progress::ProgressEvent| {
                events
                    .lock()
                    .expect("progress lock")
                    .push((ev.done, ev.total));
            }) as crate::report::progress::ProgressSink
        };

        let artifacts = runner.run_suite(&cfg, Some(sink)).await?;
        assert_eq!(artifacts.results.len(), 3);

        let observed = events.lock().expect("progress lock");
        assert_eq!(observed.len(), 3);
        assert_eq!(observed.last(), Some(&(3, 3)));
        assert!(observed.windows(2).all(|w| w[0].0 < w[1].0));
        Ok(())
    }

    // Relative thresholding with an empty baseline => Warn with a
    // "missing baseline" message from the helper (not a hard failure).
    #[tokio::test]
    async fn runner_contract_relative_baseline_missing_warns_in_helper() -> anyhow::Result<()> {
        let mut cfg = single_test_config(ErrorPolicy::Block);
        cfg.settings.thresholding = Some(crate::model::ThresholdingSettings {
            mode: Some("relative".to_string()),
            max_drop: Some(0.05),
            min_floor: None,
        });

        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::always_pass());
        let runner = runner_for_contract_tests(client, vec![], 0);
        // Baseline with no entries: every lookup for t1/scripted misses.
        let baseline = crate::baseline::Baseline {
            schema_version: 1,
            suite: "runner-contract".to_string(),
            assay_version: env!("CARGO_PKG_VERSION").to_string(),
            created_at: "2026-01-01T00:00:00Z".to_string(),
            config_fingerprint: "md5:test".to_string(),
            git_info: None,
            entries: vec![],
        };
        let tc = cfg.tests.first().cloned().expect("single test case");
        let details = serde_json::json!({
            "metrics": {
                "scripted": {
                    "score": 1.0,
                    "passed": true,
                    "unstable": false,
                    "details": {}
                }
            }
        });

        let verdict = runner.check_baseline_regressions(&tc, &cfg, &details, &[metric], &baseline);
        let (status, message) = verdict.expect("relative baseline decision");
        assert_eq!(status, TestStatus::Warn);
        assert_eq!(message, "missing baseline for t1/scripted");
        Ok(())
    }
}