1use crate::attempts::{classify_attempts, FailureClass};
2use crate::cache::key::cache_key;
3use crate::cache::vcr::VcrCache;
4use crate::errors::try_map_error;
5use crate::metrics_api::Metric;
6use crate::model::{AttemptRow, EvalConfig, LlmResponse, TestCase, TestResultRow, TestStatus};
7use crate::on_error::{ErrorPolicy, ErrorPolicyResult};
8use crate::providers::llm::LlmClient;
9use crate::quarantine::{QuarantineMode, QuarantineService};
10use crate::report::RunArtifacts;
11use crate::storage::store::Store;
12use std::sync::Arc;
13use tokio::sync::Semaphore;
14use tokio::time::{timeout, Duration};
15
/// Knobs controlling how a suite run treats failures, quarantined tests,
/// and strict trace replay.
#[derive(Debug, Clone)]
pub struct RunPolicy {
    // Extra attempts after a failing first attempt; total attempts = 1 + this.
    pub rerun_failures: u32,
    // How a quarantined test affects its final row: Off (no change),
    // Warn (status -> Warn), or Strict (status -> Fail).
    pub quarantine_mode: QuarantineMode,
    // When true, missing embeddings/judge results in the replay trace are a
    // hard config error instead of falling back to live computation.
    pub replay_strict: bool,
}
22
23impl Default for RunPolicy {
24 fn default() -> Self {
25 Self {
26 rerun_failures: 1,
27 quarantine_mode: QuarantineMode::Warn,
28 replay_strict: false,
29 }
30 }
31}
32
/// Executes an evaluation suite: fans tests out across tokio tasks, runs each
/// test with retry/quarantine/error policy, scores metrics, and persists the
/// results through the store.
pub struct Runner {
    // Persistence backend for runs, results, and cached embeddings.
    pub store: Store,
    // Record/replay cache for LLM responses, keyed by `cache_key`.
    pub cache: VcrCache,
    // Provider used for completions (and whose fingerprint feeds the cache key).
    pub client: Arc<dyn LlmClient>,
    // Metrics evaluated in order per test; evaluation stops at the first
    // unstable or failing metric.
    pub metrics: Vec<Arc<dyn Metric>>,
    pub policy: RunPolicy,
    // Optional embedder for semantic-similarity expectations; required when
    // the trace carries no precomputed embeddings and replay is not strict.
    pub embedder: Option<Arc<dyn crate::providers::embedder::Embedder>>,
    // When true, bypass the stored embedding cache and recompute live.
    pub refresh_embeddings: bool,
    // When true, skip tests whose fingerprint matches a previous passing run.
    pub incremental: bool,
    // When true, bypass both the response cache and the incremental skip.
    pub refresh_cache: bool,
    // Optional LLM-as-judge service for faithfulness/relevance expectations.
    pub judge: Option<crate::judge::JudgeService>,
    // Optional score baseline used for regression detection.
    pub baseline: Option<crate::baseline::Baseline>,
}
46
47impl Runner {
48 pub async fn run_suite(&self, cfg: &EvalConfig) -> anyhow::Result<RunArtifacts> {
49 let run_id = self.store.create_run(cfg)?;
50
51 let parallel = cfg.settings.parallel.unwrap_or(4).max(1);
52 let sem = Arc::new(Semaphore::new(parallel));
53 let mut handles = Vec::new();
54
55 for tc in cfg.tests.iter() {
56 let permit = sem.clone().acquire_owned().await?;
57 let this = self.clone_for_task();
58 let cfg = cfg.clone();
59 let tc = tc.clone();
60 let h = tokio::spawn(async move {
61 let _permit = permit;
62 this.run_test_with_policy(&cfg, &tc, run_id).await
63 });
64 handles.push(h);
65 }
66
67 let mut rows = Vec::new();
68 let mut any_fail = false;
69 for h in handles {
70 let row = match h.await {
71 Ok(Ok(row)) => row,
72 Ok(Err(e)) => TestResultRow {
73 test_id: "unknown".into(),
74 status: TestStatus::Error,
75 score: None,
76 cached: false,
77 message: format!("task error: {}", e),
78 details: serde_json::json!({}),
79 duration_ms: None,
80 fingerprint: None,
81 skip_reason: None,
82 attempts: None,
83 error_policy_applied: None,
84 },
85 Err(e) => TestResultRow {
86 test_id: "unknown".into(),
87 status: TestStatus::Error,
88 score: None,
89 cached: false,
90 message: format!("join error: {}", e),
91 details: serde_json::json!({}),
92 duration_ms: None,
93 fingerprint: None,
94 skip_reason: None,
95 attempts: None,
96 error_policy_applied: None,
97 },
98 };
99 any_fail = any_fail || matches!(row.status, TestStatus::Fail | TestStatus::Error);
100 rows.push(row);
101 }
102
103 self.store
104 .finalize_run(run_id, if any_fail { "failed" } else { "passed" })?;
105 Ok(RunArtifacts {
106 run_id,
107 suite: cfg.suite.clone(),
108 results: rows,
109 })
110 }
111
    /// Runs a single test up to `1 + rerun_failures` times, classifies the
    /// attempt history (flaky/unstable/deterministic), applies quarantine
    /// and agent-assertion adjustments, then persists the final row.
    async fn run_test_with_policy(
        &self,
        cfg: &EvalConfig,
        tc: &TestCase,
        run_id: i64,
    ) -> anyhow::Result<TestResultRow> {
        let quarantine = QuarantineService::new(self.store.clone());
        let q_reason = quarantine.is_quarantined(&cfg.suite, &tc.id)?;
        // Effective allow/block error policy resolved from suite + test config.
        let error_policy = cfg.effective_error_policy(tc);

        let max_attempts = 1 + self.policy.rerun_failures;
        let mut attempts: Vec<AttemptRow> = Vec::new();
        let mut last_row: Option<TestResultRow> = None;
        let mut last_output: Option<LlmResponse> = None;

        for i in 0..max_attempts {
            let (row, output) = match self.run_test_once(cfg, tc).await {
                Ok(res) => res,
                Err(e) => {
                    // Prefer the mapped diagnostic message when one exists.
                    let msg = if let Some(diag) = try_map_error(&e) {
                        diag.to_string()
                    } else {
                        e.to_string()
                    };

                    // Error policy decides whether this attempt becomes a hard
                    // Error (Blocked) or a tolerated AllowedOnError (Allowed).
                    let policy_result = error_policy.apply_to_error(&e);
                    let (status, final_msg, applied_policy) = match policy_result {
                        ErrorPolicyResult::Blocked { reason } => {
                            (TestStatus::Error, reason, ErrorPolicy::Block)
                        }
                        ErrorPolicyResult::Allowed { warning } => {
                            crate::on_error::log_fail_safe(&warning, None);
                            (TestStatus::AllowedOnError, warning, ErrorPolicy::Allow)
                        }
                    };

                    // Synthesize a row + empty response so the attempt is
                    // still recorded in the attempt history.
                    (
                        TestResultRow {
                            test_id: tc.id.clone(),
                            status,
                            score: None,
                            cached: false,
                            message: final_msg,
                            details: serde_json::json!({
                                "error": msg,
                                "policy_applied": applied_policy
                            }),
                            duration_ms: None,
                            fingerprint: None,
                            skip_reason: None,
                            attempts: None,
                            error_policy_applied: Some(applied_policy),
                        },
                        LlmResponse {
                            text: "".into(),
                            provider: "error".into(),
                            model: cfg.model.clone(),
                            cached: false,
                            meta: serde_json::json!({}),
                        },
                    )
                }
            };
            attempts.push(AttemptRow {
                attempt_no: i + 1,
                status: row.status.clone(),
                message: row.message.clone(),
                duration_ms: row.duration_ms,
                details: row.details.clone(),
            });
            last_row = Some(row.clone());
            last_output = Some(output.clone());

            // Stop retrying on success-like or skipped outcomes; keep
            // retrying on failure-like outcomes until attempts run out.
            match row.status {
                TestStatus::Pass | TestStatus::Warn | TestStatus::AllowedOnError => break,
                TestStatus::Skipped => break,
                TestStatus::Fail | TestStatus::Error | TestStatus::Flaky | TestStatus::Unstable => {
                    continue
                }
            }
        }

        let class = classify_attempts(&attempts);
        // Fallback row for the (unexpected) case of zero attempts.
        let mut final_row = last_row.unwrap_or(TestResultRow {
            test_id: tc.id.clone(),
            status: TestStatus::Error,
            score: None,
            cached: false,
            message: "no attempts".into(),
            details: serde_json::json!({}),
            duration_ms: None,
            fingerprint: None,
            skip_reason: None,
            attempts: None,
            error_policy_applied: None,
        });

        // NOTE(review): quarantine adjustment runs BEFORE the classification
        // match below, so e.g. a Strict quarantine Fail can be overwritten
        // back to Pass by DeterministicPass (while the quarantine message is
        // kept) — confirm this ordering is intended.
        if let Some(reason) = q_reason {
            match self.policy.quarantine_mode {
                QuarantineMode::Off => {}
                QuarantineMode::Warn => {
                    final_row.status = TestStatus::Warn;
                    final_row.message = format!("quarantined: {}", reason);
                }
                QuarantineMode::Strict => {
                    final_row.status = TestStatus::Fail;
                    final_row.message = format!("quarantined (strict): {}", reason);
                }
            }
        }

        // Classification of the whole attempt history decides the final status.
        match class {
            FailureClass::Skipped => {
                final_row.status = TestStatus::Skipped;
            }
            FailureClass::Flaky => {
                final_row.status = TestStatus::Flaky;
                final_row.message = "flake detected (rerun passed)".into();
                final_row.details["flake"] = serde_json::json!({ "attempts": attempts.len() });
            }
            FailureClass::Unstable => {
                final_row.status = TestStatus::Unstable;
                final_row.message = "unstable outcomes detected".into();
                final_row.details["unstable"] = serde_json::json!({ "attempts": attempts.len() });
            }
            FailureClass::Error => final_row.status = TestStatus::Error,
            FailureClass::DeterministicFail => {
                final_row.status = TestStatus::Fail;
            }
            FailureClass::DeterministicPass => {
                final_row.status = TestStatus::Pass;
            }
        }

        // Output from the last attempt (empty placeholder if none recorded).
        let output = last_output.unwrap_or(LlmResponse {
            text: "".into(),
            provider: self.client.provider_name().to_string(),
            model: cfg.model.clone(),
            cached: false,
            meta: serde_json::json!({}),
        });

        final_row.attempts = Some(attempts.clone());

        // Agent assertions can still fail an otherwise passing row.
        if let Some(assertions) = &tc.assertions {
            if !assertions.is_empty() {
                match crate::agent_assertions::verify_assertions(
                    &self.store,
                    run_id,
                    &tc.id,
                    assertions,
                ) {
                    Ok(diags) => {
                        if !diags.is_empty() {
                            final_row.status = TestStatus::Fail;

                            let diag_json: Vec<serde_json::Value> = diags
                                .iter()
                                .map(|d| serde_json::to_value(d).unwrap_or_default())
                                .collect();

                            final_row.details["assertions"] = serde_json::Value::Array(diag_json);

                            // Append to (or replace) the existing message.
                            let fail_msg = format!("assertions failed ({})", diags.len());
                            if final_row.message == "ok" {
                                final_row.message = fail_msg;
                            } else {
                                final_row.message = format!("{}; {}", final_row.message, fail_msg);
                            }
                        } else {
                            final_row.details["assertions"] = serde_json::json!({ "passed": true });
                        }
                    }
                    Err(e) => {
                        // Assertion machinery errors are treated as test failures.
                        final_row.status = TestStatus::Fail;
                        final_row.message = format!("assertions error: {}", e);
                        final_row.details["assertions"] =
                            serde_json::json!({ "error": e.to_string() });
                    }
                }
            }
        }

        self.store
            .insert_result_embedded(run_id, &final_row, &attempts, &output)?;

        Ok(final_row)
    }
310
    /// Executes one attempt of a test: fingerprints it, honors incremental
    /// skip and the response cache, enriches the response (embeddings/judge),
    /// evaluates metrics in order, and applies baseline regression checks.
    async fn run_test_once(
        &self,
        cfg: &EvalConfig,
        tc: &TestCase,
    ) -> anyhow::Result<(TestResultRow, LlmResponse)> {
        let expected_json = serde_json::to_string(&tc.expected).unwrap_or_default();
        // Versioning input for the fingerprint (crate version at build time).
        let metric_versions = [("assay", env!("CARGO_PKG_VERSION"))];

        // Hash the referenced policy file so fingerprint changes when it does;
        // unreadable files contribute no hash rather than failing the test.
        let policy_hash = if let Some(path) = tc.expected.get_policy_path() {
            match std::fs::read_to_string(path) {
                Ok(content) => Some(crate::fingerprint::sha256_hex(&content)),
                Err(_) => None,
            }
        } else {
            None
        };

        let fp = crate::fingerprint::compute(crate::fingerprint::Context {
            suite: &cfg.suite,
            model: &cfg.model,
            test_id: &tc.id,
            prompt: &tc.input.prompt,
            context: tc.input.context.as_deref(),
            expected_canonical: &expected_json,
            policy_hash: policy_hash.as_deref(),
            metric_versions: &metric_versions,
        });

        // Incremental mode: skip the test entirely when an identical
        // fingerprint already passed (unless a cache refresh was requested).
        if self.incremental && !self.refresh_cache {
            if let Some(prev) = self.store.get_last_passing_by_fingerprint(&fp.hex)? {
                let row = TestResultRow {
                    test_id: tc.id.clone(),
                    status: TestStatus::Skipped,
                    score: prev.score,
                    cached: true,
                    message: "skipped: fingerprint match".into(),
                    // Carry forward provenance fields from the previous row's
                    // own skip details, when present.
                    details: serde_json::json!({
                        "skip": {
                            "reason": "fingerprint_match",
                            "fingerprint": fp.hex,
                            "previous_run_id": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("previous_run_id")).and_then(|v: &serde_json::Value| v.as_i64()),
                            "previous_at": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("previous_at")).and_then(|v: &serde_json::Value| v.as_str()),
                            "origin_run_id": prev.details.get("skip").and_then(|s: &serde_json::Value| s.get("origin_run_id")).and_then(|v: &serde_json::Value| v.as_i64()),
                            "previous_score": prev.score
                        }
                    }),
                    duration_ms: Some(0),
                    fingerprint: Some(fp.hex.clone()),
                    skip_reason: Some("fingerprint_match".into()),
                    attempts: None,
                    error_policy_applied: None,
                };

                let resp = LlmResponse {
                    text: "".into(),
                    provider: "skipped".into(),
                    model: cfg.model.clone(),
                    cached: true,
                    meta: serde_json::json!({}),
                };
                return Ok((row, resp));
            }
        }

        // Cache key includes the client fingerprint so provider/model changes
        // invalidate cached responses.
        let key = cache_key(
            &cfg.model,
            &tc.input.prompt,
            &fp.hex,
            self.client.fingerprint().as_deref(),
        );

        let start = std::time::Instant::now();
        let mut cached = false;

        // Response caching defaults to on; --refresh-cache forces a live call.
        let mut resp: LlmResponse = if cfg.settings.cache.unwrap_or(true) && !self.refresh_cache {
            if let Some(r) = self.cache.get(&key)? {
                cached = true;
                eprintln!(
                    " [CACHE HIT] key={} prompt_len={}",
                    key,
                    tc.input.prompt.len()
                );
                r
            } else {
                let r = self.call_llm(cfg, tc).await?;
                self.cache.put(&key, &r)?;
                r
            }
        } else {
            self.call_llm(cfg, tc).await?
        };
        // Preserve any cached flag the response itself already carries.
        resp.cached = resp.cached || cached;

        // Attach embeddings / judge verdicts to resp.meta as needed.
        self.enrich_semantic(tc, &mut resp).await?;
        self.enrich_judge(tc, &mut resp).await?;

        let mut final_status = TestStatus::Pass;
        let mut final_score: Option<f64> = None;
        let mut msg = String::new();
        let mut details = serde_json::json!({ "metrics": {} });

        // Metrics run in order; the loop stops at the first unstable or
        // failing metric, so `final_score` is the last-evaluated metric's score.
        for m in &self.metrics {
            let r = m.evaluate(tc, &tc.expected, &resp).await?;
            details["metrics"][m.name()] = serde_json::json!({
                "score": r.score, "passed": r.passed, "unstable": r.unstable, "details": r.details
            });
            final_score = Some(r.score);

            if r.unstable {
                final_status = TestStatus::Warn;
                msg = format!("unstable metric: {}", m.name());
                break;
            }
            if !r.passed {
                final_status = TestStatus::Fail;
                msg = format!("failed: {}", m.name());
                break;
            }
        }

        // Baseline regression check may escalate status to Fail/Warn.
        if let Some(baseline) = &self.baseline {
            if let Some((new_status, new_msg)) =
                self.check_baseline_regressions(tc, cfg, &details, &self.metrics, baseline)
            {
                if matches!(new_status, TestStatus::Fail | TestStatus::Warn) {
                    final_status = new_status;
                    msg = new_msg;
                }
            }
        }

        let duration_ms = start.elapsed().as_millis() as u64;
        let mut row = TestResultRow {
            test_id: tc.id.clone(),
            status: final_status,
            score: final_score,
            cached: resp.cached,
            message: if msg.is_empty() { "ok".into() } else { msg },
            details,
            duration_ms: Some(duration_ms),
            fingerprint: Some(fp.hex),
            skip_reason: None,
            attempts: None,
            error_policy_applied: None,
        };

        // Mark rows produced during trace replay.
        if self.client.provider_name() == "trace" {
            row.details["assay.replay"] = serde_json::json!(true);
        }

        row.details["prompt"] = serde_json::Value::String(tc.input.prompt.clone());

        Ok((row, resp))
    }
477
478 async fn call_llm(&self, cfg: &EvalConfig, tc: &TestCase) -> anyhow::Result<LlmResponse> {
479 let t = cfg.settings.timeout_seconds.unwrap_or(30);
480 let fut = self
481 .client
482 .complete(&tc.input.prompt, tc.input.context.as_deref());
483 let resp = timeout(Duration::from_secs(t), fut).await??;
484 Ok(resp)
485 }
486
487 fn clone_for_task(&self) -> RunnerRef {
488 RunnerRef {
489 store: self.store.clone(),
490 cache: self.cache.clone(),
491 client: self.client.clone(),
492 metrics: self.metrics.clone(),
493 policy: self.policy.clone(),
494 embedder: self.embedder.clone(),
495 refresh_embeddings: self.refresh_embeddings,
496 incremental: self.incremental,
497 refresh_cache: self.refresh_cache,
498 judge: self.judge.clone(),
499 baseline: self.baseline.clone(),
500 }
501 }
502
503 fn check_baseline_regressions(
504 &self,
505 tc: &TestCase,
506 cfg: &EvalConfig,
507 details: &serde_json::Value,
508 metrics: &[Arc<dyn Metric>],
509 baseline: &crate::baseline::Baseline,
510 ) -> Option<(TestStatus, String)> {
511 let suite_defaults = cfg.settings.thresholding.as_ref();
513
514 for m in metrics {
515 let metric_name = m.name();
516 let score = details["metrics"][metric_name]["score"].as_f64()?;
518
519 let (mode, max_drop) = self.resolve_threshold_config(tc, metric_name, suite_defaults);
524
525 if mode == "relative" {
526 if let Some(base_score) = baseline.get_score(&tc.id, metric_name) {
527 let delta = score - base_score;
528 if let Some(drop_limit) = max_drop {
529 if delta < -drop_limit {
530 return Some((
531 TestStatus::Fail,
532 format!(
533 "regression: {} dropped {:.3} (limit: {:.3})",
534 metric_name, -delta, drop_limit
535 ),
536 ));
537 }
538 }
539 } else {
540 return Some((
542 TestStatus::Warn,
543 format!("missing baseline for {}/{}", tc.id, metric_name),
544 ));
545 }
546 }
547 }
548 None
549 }
550
551 fn resolve_threshold_config(
552 &self,
553 _tc: &TestCase,
554 _metric_name: &str,
555 suite_defaults: Option<&crate::model::ThresholdingSettings>,
556 ) -> (String, Option<f64>) {
557 let mut mode = "absolute".to_string();
559 let mut max_drop = None;
560
561 if let Some(s) = suite_defaults {
562 if let Some(m) = &s.mode {
563 mode = m.clone();
564 }
565 max_drop = s.max_drop;
566 }
567
568 (mode, max_drop)
571 }
572
    /// Ensures `resp.meta` carries response/reference embeddings when the test
    /// expects semantic similarity; no-op for every other expectation kind.
    async fn enrich_semantic(&self, tc: &TestCase, resp: &mut LlmResponse) -> anyhow::Result<()> {
        use crate::model::Expected;

        // Only semantic-similarity expectations need embeddings.
        let Expected::SemanticSimilarityTo {
            semantic_similarity_to,
            ..
        } = &tc.expected
        else {
            return Ok(());
        };

        // Trace replay may already carry both precomputed vectors — done.
        if resp.meta.pointer("/assay/embeddings/response").is_some()
            && resp.meta.pointer("/assay/embeddings/reference").is_some()
        {
            return Ok(());
        }

        // Strict replay forbids computing embeddings live.
        if self.policy.replay_strict {
            anyhow::bail!("config error: --replay-strict is on, but embeddings are missing in trace. Run 'assay trace precompute-embeddings' or disable strict mode.");
        }

        let embedder = self.embedder.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "config error: semantic_similarity_to requires an embedder (--embedder) or trace meta embeddings"
            )
        })?;

        let model_id = embedder.model_id();

        // Each embedding comes back tagged with its source ("cache"/"live").
        let (resp_vec, src_resp) = self
            .embed_text(&model_id, embedder.as_ref(), &resp.text)
            .await?;
        let (ref_vec, src_ref) = self
            .embed_text(&model_id, embedder.as_ref(), semantic_similarity_to)
            .await?;

        // Ensure meta["assay"] is an object before indexing into it.
        if !resp.meta.get("assay").is_some_and(|v| v.is_object()) {
            resp.meta["assay"] = serde_json::json!({});
        }
        resp.meta["assay"]["embeddings"] = serde_json::json!({
            "model": model_id,
            "response": resp_vec,
            "reference": ref_vec,
            "source_response": src_resp,
            "source_reference": src_ref
        });

        Ok(())
    }
624
625 pub async fn embed_text(
626 &self,
627 model_id: &str,
628 embedder: &dyn crate::providers::embedder::Embedder,
629 text: &str,
630 ) -> anyhow::Result<(Vec<f32>, &'static str)> {
631 use crate::embeddings::util::embed_cache_key;
632
633 let key = embed_cache_key(model_id, text);
634
635 if !self.refresh_embeddings {
636 if let Some((_m, vec)) = self.store.get_embedding(&key)? {
637 return Ok((vec, "cache"));
638 }
639 }
640
641 let vec = embedder.embed(text).await?;
642 self.store.put_embedding(&key, model_id, &vec)?;
643 Ok((vec, "live"))
644 }
645
    /// Runs the LLM judge for faithfulness/relevance expectations, storing
    /// its verdict into `resp.meta`; no-op for other expectation kinds.
    async fn enrich_judge(&self, tc: &TestCase, resp: &mut LlmResponse) -> anyhow::Result<()> {
        use crate::model::Expected;

        // Map the expectation to a rubric id; other kinds need no judge.
        let (rubric_id, rubric_version) = match &tc.expected {
            Expected::Faithfulness { rubric_version, .. } => {
                ("faithfulness", rubric_version.as_deref())
            }
            Expected::Relevance { rubric_version, .. } => ("relevance", rubric_version.as_deref()),
            _ => return Ok(()),
        };

        // Strict replay requires a precomputed judge verdict in the trace.
        let has_trace = resp
            .meta
            .pointer(&format!("/assay/judge/{}", rubric_id))
            .is_some();
        if self.policy.replay_strict && !has_trace {
            anyhow::bail!("config error: --replay-strict is on, but judge results are missing in trace for '{}'. Run 'assay trace precompute-judge' or disable strict mode.", rubric_id);
        }

        let judge = self.judge.as_ref().ok_or_else(|| {
            anyhow::anyhow!("config error: judge required but service not initialized")
        })?;

        // NOTE(review): evaluate is invoked even when `has_trace` is true —
        // presumably JudgeService reuses the /assay/judge entry in `meta`
        // rather than re-judging; confirm against JudgeService::evaluate.
        judge
            .evaluate(
                &tc.id,
                rubric_id,
                &tc.input,
                &resp.text,
                rubric_version,
                &mut resp.meta,
            )
            .await?;

        Ok(())
    }
686}
687
/// Owned, clonable snapshot of `Runner`'s fields, movable into spawned tasks.
///
/// Mirrors `Runner` field-for-field; `Runner::clone_for_task` builds one and
/// `RunnerRef::run_test_with_policy` rehydrates a `Runner` from it.
#[derive(Clone)]
struct RunnerRef {
    store: Store,
    cache: VcrCache,
    client: Arc<dyn LlmClient>,
    metrics: Vec<Arc<dyn Metric>>,
    policy: RunPolicy,
    embedder: Option<Arc<dyn crate::providers::embedder::Embedder>>,
    refresh_embeddings: bool,
    incremental: bool,
    refresh_cache: bool,
    judge: Option<crate::judge::JudgeService>,
    baseline: Option<crate::baseline::Baseline>,
}
702
703impl RunnerRef {
704 async fn run_test_with_policy(
705 &self,
706 cfg: &EvalConfig,
707 tc: &TestCase,
708 run_id: i64,
709 ) -> anyhow::Result<TestResultRow> {
710 let runner = Runner {
711 store: self.store.clone(),
712 cache: self.cache.clone(),
713 client: self.client.clone(),
714 metrics: self.metrics.clone(),
715 policy: self.policy.clone(),
716 embedder: self.embedder.clone(),
717 refresh_embeddings: self.refresh_embeddings,
718 incremental: self.incremental,
719 refresh_cache: self.refresh_cache,
720 judge: self.judge.clone(),
721 baseline: self.baseline.clone(),
722 };
723 runner.run_test_with_policy(cfg, tc, run_id).await
724 }
725}