use crate::cache::key::cache_key;
use crate::cache::vcr::VcrCache;
use crate::metrics_api::Metric;
use crate::model::{EvalConfig, LlmResponse, TestCase, TestResultRow, TestStatus};
use crate::providers::llm::LlmClient;
use crate::quarantine::QuarantineMode;
use crate::report::progress::ProgressSink;
use crate::report::RunArtifacts;
use crate::storage::store::Store;
use std::sync::Arc;
use tracing::{info_span, Instrument};

#[path = "runner_next/mod.rs"]
mod runner_next;

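/// Execution policy for a suite run: how many times a failing test is re-run,
/// how quarantined tests are treated, and whether trace replay is strict.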
#[derive(Debug, Clone)]
pub struct RunPolicy {
    pub rerun_failures: u32,
    pub quarantine_mode: QuarantineMode,
    pub replay_strict: bool,
}

impl Default for RunPolicy {
    fn default() -> Self {
        Self {
            rerun_failures: 1,
            quarantine_mode: QuarantineMode::Warn,
            replay_strict: false,
        }
    }
}

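/// Drives a suite run: resolves LLM responses through the VCR cache (or a live
/// provider call), enriches them with semantic/judge scores, evaluates the
/// configured metrics, and consults the store for incremental skips and agent
/// assertions.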
pub struct Runner {
    pub store: Store,
    pub cache: VcrCache,
    pub client: Arc<dyn LlmClient>,
    pub metrics: Vec<Arc<dyn Metric>>,
    pub policy: RunPolicy,
    pub _network_guard: Option<crate::providers::network::NetworkPolicyGuard>,
    pub embedder: Option<Arc<dyn crate::providers::embedder::Embedder>>,
    pub refresh_embeddings: bool,
    pub incremental: bool,
    pub refresh_cache: bool,
    pub judge: Option<crate::judge::JudgeService>,
    pub baseline: Option<crate::baseline::Baseline>,
}

impl Runner {
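    /// Runs every test in `cfg`, reporting progress through `progress` when
    /// provided. Delegates to `runner_next::execute::run_suite_impl`.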
    pub async fn run_suite(
        &self,
        cfg: &EvalConfig,
        progress: Option<ProgressSink>,
    ) -> anyhow::Result<RunArtifacts> {
        runner_next::execute::run_suite_impl(self, cfg, progress).await
    }

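    /// Verifies any agent assertions attached to the test case and folds the
    /// outcome into `final_row` (status, message, and an `assertions` detail).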
    fn apply_agent_assertions(
        &self,
        run_id: i64,
        tc: &TestCase,
        final_row: &mut TestResultRow,
    ) -> anyhow::Result<()> {
        if let Some(assertions) = &tc.assertions {
            if !assertions.is_empty() {
                match crate::agent_assertions::verify_assertions(
                    &self.store,
                    run_id,
                    &tc.id,
                    assertions,
                ) {
                    Ok(diags) => {
                        if !diags.is_empty() {
                            final_row.status = TestStatus::Fail;

                            let diag_json: Vec<serde_json::Value> = diags
                                .iter()
                                .map(|d| serde_json::to_value(d).unwrap_or_default())
                                .collect();

                            final_row.details["assertions"] = serde_json::Value::Array(diag_json);

                            let fail_msg = format!("assertions failed ({})", diags.len());
                            if final_row.message == "ok" {
                                final_row.message = fail_msg;
                            } else {
                                final_row.message = format!("{}; {}", final_row.message, fail_msg);
                            }
                        } else {
                            final_row.details["assertions"] = serde_json::json!({ "passed": true });
                        }
                    }
                    Err(e) => {
                        final_row.status = TestStatus::Fail;
                        final_row.message = format!("assertions error: {}", e);
                        final_row.details["assertions"] =
                            serde_json::json!({ "error": e.to_string() });
                    }
                }
            }
        }
        Ok(())
    }

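    /// Executes a single test case once: fingerprints the input, honours
    /// incremental skips and the response cache, obtains the LLM response,
    /// enriches it, and evaluates the configured metrics into a result row.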
    async fn run_test_once(
        &self,
        cfg: &EvalConfig,
        tc: &TestCase,
    ) -> anyhow::Result<(TestResultRow, LlmResponse)> {
        let expected_json = serde_json::to_string(&tc.expected).unwrap_or_default();
        let metric_versions = [("assay", env!("CARGO_PKG_VERSION"))];

        let policy_hash = if let Some(path) = tc.expected.get_policy_path() {
            match std::fs::read_to_string(path) {
                Ok(content) => Some(crate::fingerprint::sha256_hex(&content)),
                Err(_) => None,
            }
        } else {
            None
        };

        let fp = crate::fingerprint::compute(crate::fingerprint::Context {
            suite: &cfg.suite,
            model: &cfg.model,
            test_id: &tc.id,
            prompt: &tc.input.prompt,
            context: tc.input.context.as_deref(),
            expected_canonical: &expected_json,
            policy_hash: policy_hash.as_deref(),
            metric_versions: &metric_versions,
        });

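        // Incremental mode: if a previous passing run matches this fingerprint,
        // skip the test and carry its score forward.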
        if self.incremental && !self.refresh_cache {
            if let Some(prev) = self.store.get_last_passing_by_fingerprint(&fp.hex)? {
                let row = TestResultRow {
                    test_id: tc.id.clone(),
                    status: TestStatus::Skipped,
                    score: prev.score,
                    cached: true,
                    message: "skipped: fingerprint match".into(),
                    details: serde_json::json!({
                        "skip": {
                            "reason": "fingerprint_match",
                            "fingerprint": fp.hex,
                            "previous_run_id": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("previous_run_id")).and_then(|v: &serde_json::Value| v.as_i64()),
                            "previous_at": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("previous_at")).and_then(|v: &serde_json::Value| v.as_str()),
                            "origin_run_id": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("origin_run_id")).and_then(|v: &serde_json::Value| v.as_i64()),
                            "previous_score": prev.score
                        }
                    }),
                    duration_ms: Some(0),
                    fingerprint: Some(fp.hex.clone()),
                    skip_reason: Some("fingerprint_match".into()),
                    attempts: None,
                    error_policy_applied: None,
                };

                let resp = LlmResponse {
                    text: "".into(),
                    provider: "skipped".into(),
                    model: cfg.model.clone(),
                    cached: true,
                    meta: serde_json::json!({}),
                };
                return Ok((row, resp));
            }
        }

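        // The cache key covers the model, prompt, test fingerprint, and the
        // provider's own fingerprint, so a change to any of them yields a
        // fresh cache entry.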
        let key = cache_key(
            &cfg.model,
            &tc.input.prompt,
            &fp.hex,
            self.client.fingerprint().as_deref(),
        );

        let start = std::time::Instant::now();
        let mut cached = false;

        let mut resp: LlmResponse = if cfg.settings.cache.unwrap_or(true) && !self.refresh_cache {
            if let Some(r) = self.cache.get(&key)? {
                cached = true;
                eprintln!(
                    " [CACHE HIT] key={} prompt_len={}",
                    key,
                    tc.input.prompt.len()
                );
                r
            } else {
                let r = self.call_llm(cfg, tc).await?;
                self.cache.put(&key, &r)?;
                r
            }
        } else {
            self.call_llm(cfg, tc).await?
        };
        resp.cached = resp.cached || cached;

        self.enrich_semantic(cfg, tc, &mut resp).await?;
        self.enrich_judge(cfg, tc, &mut resp).await?;

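        // Evaluate each metric inside its own tracing span; the first unstable
        // or failing metric decides the status and short-circuits the loop.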
        let mut final_status = TestStatus::Pass;
        let mut final_score: Option<f64> = None;
        let mut msg = String::new();
        let mut details = serde_json::json!({ "metrics": {} });

        for m in &self.metrics {
            let metric_name = m.name();
            let metric_span = info_span!(
                "assay.eval.metric",
                "assay.eval.test_id" = tc.id.as_str(),
                "assay.eval.metric.name" = metric_name,
                "assay.eval.response.cached" = resp.cached,
                "assay.eval.metric.score" = tracing::field::Empty,
                "assay.eval.metric.passed" = tracing::field::Empty,
                "assay.eval.metric.unstable" = tracing::field::Empty,
                "assay.eval.metric.duration_ms" = tracing::field::Empty,
                "error" = tracing::field::Empty,
                "error.message" = tracing::field::Empty
            );
            let metric_start = std::time::Instant::now();
            let metric_result = async { m.evaluate(tc, &tc.expected, &resp).await }
                .instrument(metric_span.clone())
                .await;
            let metric_duration_ms = metric_start.elapsed().as_millis() as u64;
            metric_span.record("assay.eval.metric.duration_ms", metric_duration_ms);

            let r = match metric_result {
                Ok(result) => {
                    metric_span.record("assay.eval.metric.score", result.score);
                    metric_span.record("assay.eval.metric.passed", result.passed);
                    metric_span.record("assay.eval.metric.unstable", result.unstable);
                    result
                }
                Err(err) => {
                    let error_message = err.to_string();
                    metric_span.record("error", true);
                    metric_span.record("error.message", error_message.as_str());
                    return Err(err);
                }
            };

            details["metrics"][metric_name] = serde_json::json!({
                "score": r.score, "passed": r.passed, "unstable": r.unstable, "details": r.details
            });
            final_score = Some(r.score);

            if r.unstable {
                final_status = TestStatus::Warn;
                msg = format!("unstable metric: {}", metric_name);
                break;
            }
            if !r.passed {
                final_status = TestStatus::Fail;
                msg = format!("failed: {}", metric_name);
                break;
            }
        }

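        // Baseline comparison may override the status with a regression
        // failure or warning.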
        if let Some(baseline) = &self.baseline {
            if let Some((new_status, new_msg)) =
                self.check_baseline_regressions(tc, cfg, &details, &self.metrics, baseline)
            {
                if matches!(new_status, TestStatus::Fail | TestStatus::Warn) {
                    final_status = new_status;
                    msg = new_msg;
                }
            }
        }

        let duration_ms = start.elapsed().as_millis() as u64;
        let mut row = TestResultRow {
            test_id: tc.id.clone(),
            status: final_status,
            score: final_score,
            cached: resp.cached,
            message: if msg.is_empty() { "ok".into() } else { msg },
            details,
            duration_ms: Some(duration_ms),
            fingerprint: Some(fp.hex),
            skip_reason: None,
            attempts: None,
            error_policy_applied: None,
        };

        if self.client.provider_name() == "trace" {
            row.details["assay.replay"] = serde_json::json!(true);
        }

        row.details["prompt"] = serde_json::Value::String(tc.input.prompt.clone());

        Ok((row, resp))
    }

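    /// Calls the configured provider for a single test case; delegates to
    /// `runner_next::execute::call_llm_impl`.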
    async fn call_llm(&self, cfg: &EvalConfig, tc: &TestCase) -> anyhow::Result<LlmResponse> {
        runner_next::execute::call_llm_impl(self, cfg, tc).await
    }

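    /// Compares metric scores in `details` against the loaded baseline and
    /// returns a replacement status/message when a regression (or a missing
    /// baseline entry) is detected. Delegates to
    /// `runner_next::baseline::check_baseline_regressions_impl`.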
    fn check_baseline_regressions(
        &self,
        tc: &TestCase,
        cfg: &EvalConfig,
        details: &serde_json::Value,
        metrics: &[Arc<dyn Metric>],
        baseline: &crate::baseline::Baseline,
    ) -> Option<(TestStatus, String)> {
        runner_next::baseline::check_baseline_regressions_impl(
            self, tc, cfg, details, metrics, baseline,
        )
    }

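    /// Enriches the response with semantic scoring (using the configured
    /// embedder, if any). Delegates to `runner_next::scoring::enrich_semantic_impl`.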
    async fn enrich_semantic(
        &self,
        _cfg: &EvalConfig,
        tc: &TestCase,
        resp: &mut LlmResponse,
    ) -> anyhow::Result<()> {
        runner_next::scoring::enrich_semantic_impl(self, _cfg, tc, resp).await
    }

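    /// Embeds `text` with the given embedder. Delegates to
    /// `runner_next::cache::embed_text_impl`.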
    pub async fn embed_text(
        &self,
        model_id: &str,
        embedder: &dyn crate::providers::embedder::Embedder,
        text: &str,
    ) -> anyhow::Result<(Vec<f32>, &'static str)> {
        runner_next::cache::embed_text_impl(self, model_id, embedder, text).await
    }

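    /// Enriches the response with LLM-judge scoring when a judge service is
    /// configured. Delegates to `runner_next::scoring::enrich_judge_impl`.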
    async fn enrich_judge(
        &self,
        cfg: &EvalConfig,
        tc: &TestCase,
        resp: &mut LlmResponse,
    ) -> anyhow::Result<()> {
        runner_next::scoring::enrich_judge_impl(self, cfg, tc, resp).await
    }
}

#[cfg(test)]
mod tests {
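    // Contract tests pinning the runner's observable behaviour: flake
    // classification, retry exhaustion, on_error policies, result ordering,
    // progress reporting, and relative-baseline handling.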
    use super::*;
    use crate::metrics_api::{Metric, MetricResult};
    use crate::model::{Expected, Settings, TestInput};
    use crate::on_error::ErrorPolicy;
    use crate::providers::llm::fake::FakeClient;
    use crate::providers::llm::LlmClient;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Clone, Copy)]
    enum MetricMode {
        FailThenPass,
        AlwaysFail,
        AlwaysPass,
    }

    struct ScriptedMetric {
        mode: MetricMode,
        calls: AtomicUsize,
    }

    impl ScriptedMetric {
        fn fail_then_pass() -> Self {
            Self {
                mode: MetricMode::FailThenPass,
                calls: AtomicUsize::new(0),
            }
        }

        fn always_fail() -> Self {
            Self {
                mode: MetricMode::AlwaysFail,
                calls: AtomicUsize::new(0),
            }
        }

        fn always_pass() -> Self {
            Self {
                mode: MetricMode::AlwaysPass,
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl Metric for ScriptedMetric {
        fn name(&self) -> &'static str {
            "scripted"
        }

        async fn evaluate(
            &self,
            _tc: &TestCase,
            _expected: &Expected,
            _resp: &LlmResponse,
        ) -> anyhow::Result<MetricResult> {
            let n = self.calls.fetch_add(1, Ordering::SeqCst);
            match self.mode {
                MetricMode::FailThenPass => {
                    if n == 0 {
                        Ok(MetricResult::fail(0.0, "scripted_fail_once"))
                    } else {
                        Ok(MetricResult::pass(1.0))
                    }
                }
                MetricMode::AlwaysFail => Ok(MetricResult::fail(0.0, "scripted_fail")),
                MetricMode::AlwaysPass => Ok(MetricResult::pass(1.0)),
            }
        }
    }

    struct ErrorClient;

    #[async_trait]
    impl LlmClient for ErrorClient {
        async fn complete(
            &self,
            _prompt: &str,
            _context: Option<&[String]>,
        ) -> anyhow::Result<LlmResponse> {
            Err(anyhow::anyhow!("scripted provider error"))
        }

        fn provider_name(&self) -> &'static str {
            "error_client"
        }
    }

    fn runner_for_contract_tests(
        client: Arc<dyn LlmClient>,
        metrics: Vec<Arc<dyn Metric>>,
        rerun_failures: u32,
    ) -> Runner {
        let store = Store::memory().expect("in-memory store");
        store.init_schema().expect("schema init");
        Runner {
            store: store.clone(),
            cache: VcrCache::new(store),
            client,
            metrics,
            policy: RunPolicy {
                rerun_failures,
                quarantine_mode: QuarantineMode::Off,
                replay_strict: false,
            },
            _network_guard: None,
            embedder: None,
            refresh_embeddings: false,
            incremental: false,
            refresh_cache: false,
            judge: None,
            baseline: None,
        }
    }

    fn single_test_config(on_error: ErrorPolicy) -> EvalConfig {
        EvalConfig {
            version: 1,
            suite: "runner-contract".to_string(),
            model: "fake-model".to_string(),
            settings: Settings {
                parallel: Some(1),
                cache: Some(false),
                seed: Some(1234),
                on_error,
                ..Default::default()
            },
            thresholds: Default::default(),
            otel: Default::default(),
            tests: vec![TestCase {
                id: "t1".to_string(),
                input: TestInput {
                    prompt: "contract prompt".to_string(),
                    context: None,
                },
                expected: Expected::MustContain {
                    must_contain: vec!["ok".to_string()],
                },
                assertions: None,
                on_error: None,
                tags: vec![],
                metadata: None,
            }],
        }
    }

    fn config_with_test_ids(ids: &[&str], on_error: ErrorPolicy) -> EvalConfig {
        EvalConfig {
            version: 1,
            suite: "runner-contract".to_string(),
            model: "fake-model".to_string(),
            settings: Settings {
                parallel: Some(1),
                cache: Some(false),
                seed: Some(1234),
                on_error,
                ..Default::default()
            },
            thresholds: Default::default(),
            otel: Default::default(),
            tests: ids
                .iter()
                .map(|id| TestCase {
                    id: (*id).to_string(),
                    input: TestInput {
                        prompt: format!("prompt-{id}"),
                        context: None,
                    },
                    expected: Expected::MustContain {
                        must_contain: vec!["ok".to_string()],
                    },
                    assertions: None,
                    on_error: None,
                    tags: vec![],
                    metadata: None,
                })
                .collect(),
        }
    }

    #[tokio::test]
    async fn runner_contract_flake_fail_then_pass_classified_flaky() -> anyhow::Result<()> {
        let cfg = single_test_config(ErrorPolicy::Block);
        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::fail_then_pass());
        let runner = runner_for_contract_tests(client, vec![metric], 1);

        let artifacts = runner.run_suite(&cfg, None).await?;
        let row = artifacts
            .results
            .iter()
            .find(|r| r.test_id == "t1")
            .expect("result for t1");

        assert_eq!(row.status, TestStatus::Flaky);
        assert_eq!(row.message, "flake detected (rerun passed)");
        let attempts = row.attempts.as_ref().expect("attempts");
        assert_eq!(attempts.len(), 2);
        assert_eq!(attempts[0].status, TestStatus::Fail);
        assert_eq!(attempts[1].status, TestStatus::Pass);
        Ok(())
    }

    #[tokio::test]
    async fn runner_contract_fail_after_retries_stays_fail() -> anyhow::Result<()> {
        let cfg = single_test_config(ErrorPolicy::Block);
        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::always_fail());
        let runner = runner_for_contract_tests(client, vec![metric], 1);

        let artifacts = runner.run_suite(&cfg, None).await?;
        let row = artifacts
            .results
            .iter()
            .find(|r| r.test_id == "t1")
            .expect("result for t1");

        assert_eq!(row.status, TestStatus::Fail);
        assert!(
            row.message.contains("failed: scripted"),
            "expected stable failure reason, got: {}",
            row.message
        );
        let attempts = row.attempts.as_ref().expect("attempts");
        assert_eq!(attempts.len(), 2);
        assert_eq!(attempts[0].status, TestStatus::Fail);
        assert_eq!(attempts[1].status, TestStatus::Fail);
        Ok(())
    }

    #[tokio::test]
    async fn runner_contract_on_error_allow_marks_allowed_and_policy_applied() -> anyhow::Result<()>
    {
        let cfg = single_test_config(ErrorPolicy::Allow);
        let client = Arc::new(ErrorClient);
        let runner = runner_for_contract_tests(client, vec![], 2);

        let artifacts = runner.run_suite(&cfg, None).await?;
        let row = artifacts
            .results
            .iter()
            .find(|r| r.test_id == "t1")
            .expect("result for t1");

        assert_eq!(row.status, TestStatus::AllowedOnError);
        assert_eq!(row.error_policy_applied, Some(ErrorPolicy::Allow));
        assert_eq!(row.details["policy_applied"], serde_json::json!("allow"));
        let attempts = row.attempts.as_ref().expect("attempts");
        assert_eq!(attempts.len(), 1);
        assert_eq!(attempts[0].status, TestStatus::AllowedOnError);
        Ok(())
    }

    #[tokio::test]
    async fn runner_contract_results_sorted_by_test_id() -> anyhow::Result<()> {
        let mut cfg = config_with_test_ids(&["t3", "t1", "t2"], ErrorPolicy::Block);
        cfg.settings.parallel = Some(3);
        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::always_pass());
        let runner = runner_for_contract_tests(client, vec![metric], 0);

        let artifacts = runner.run_suite(&cfg, None).await?;
        let ids: Vec<_> = artifacts
            .results
            .iter()
            .map(|r| r.test_id.as_str())
            .collect();
        assert_eq!(ids, vec!["t1", "t2", "t3"]);
        Ok(())
    }

    #[tokio::test]
    async fn runner_contract_progress_sink_reports_done_total() -> anyhow::Result<()> {
        let cfg = config_with_test_ids(&["p1", "p2", "p3"], ErrorPolicy::Block);
        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::always_pass());
        let runner = runner_for_contract_tests(client, vec![metric], 0);

        let events = Arc::new(std::sync::Mutex::new(Vec::<(usize, usize)>::new()));
        let sink = {
            let events = Arc::clone(&events);
            Arc::new(move |ev: crate::report::progress::ProgressEvent| {
                events
                    .lock()
                    .expect("progress lock")
                    .push((ev.done, ev.total));
            }) as crate::report::progress::ProgressSink
        };

        let artifacts = runner.run_suite(&cfg, Some(sink)).await?;
        assert_eq!(artifacts.results.len(), 3);

        let observed = events.lock().expect("progress lock");
        assert_eq!(observed.len(), 3);
        assert_eq!(observed.last(), Some(&(3, 3)));
        assert!(observed.windows(2).all(|w| w[0].0 < w[1].0));
        Ok(())
    }

    #[tokio::test]
    async fn runner_contract_relative_baseline_missing_warns_in_helper() -> anyhow::Result<()> {
        let mut cfg = single_test_config(ErrorPolicy::Block);
        cfg.settings.thresholding = Some(crate::model::ThresholdingSettings {
            mode: Some("relative".to_string()),
            max_drop: Some(0.05),
            min_floor: None,
        });

        let client = Arc::new(FakeClient::new("fake-model".to_string()).with_response("ok".into()));
        let metric = Arc::new(ScriptedMetric::always_pass());
        let runner = runner_for_contract_tests(client, vec![], 0);
        let baseline = crate::baseline::Baseline {
            schema_version: 1,
            suite: "runner-contract".to_string(),
            assay_version: env!("CARGO_PKG_VERSION").to_string(),
            created_at: "2026-01-01T00:00:00Z".to_string(),
            config_fingerprint: "md5:test".to_string(),
            git_info: None,
            entries: vec![],
        };
        let tc = cfg.tests.first().cloned().expect("single test case");
        let details = serde_json::json!({
            "metrics": {
                "scripted": {
                    "score": 1.0,
                    "passed": true,
                    "unstable": false,
                    "details": {}
                }
            }
        });

        let verdict = runner.check_baseline_regressions(&tc, &cfg, &details, &[metric], &baseline);
        let (status, message) = verdict.expect("relative baseline decision");
        assert_eq!(status, TestStatus::Warn);
        assert_eq!(message, "missing baseline for t1/scripted");
        Ok(())
    }
}