1use crate::on_error::ErrorPolicy;
2use serde::{Deserialize, Serialize};
3
/// Top-level evaluation suite configuration.
///
/// Unknown top-level fields are rejected (`deny_unknown_fields`). Optional
/// sections (`settings`, `thresholds`, `otel`) default when absent and are
/// omitted from serialized output when equal to their defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct EvalConfig {
    /// Config schema version, read from `configVersion` (alias `version`).
    /// Defaults to `0` when absent; `0` is what `EvalConfig::is_legacy` treats
    /// as a legacy config.
    #[serde(default, rename = "configVersion", alias = "version")]
    pub version: u32,
    /// Suite name.
    pub suite: String,
    /// Model identifier. NOTE(review): expected format is not visible here.
    pub model: String,
    /// Run settings; skipped on serialization when equal to `Settings::default()`.
    #[serde(default, skip_serializing_if = "is_default_settings")]
    pub settings: Settings,
    /// Threshold configuration; skipped on serialization when default.
    #[serde(default, skip_serializing_if = "is_default_thresholds")]
    pub thresholds: crate::thresholds::ThresholdConfig,
    /// OpenTelemetry configuration; skipped on serialization when default.
    #[serde(default, skip_serializing_if = "is_default_otel")]
    pub otel: crate::config::otel::OtelConfig,
    /// The test cases in this suite.
    pub tests: Vec<TestCase>,
}
19
20fn is_default_otel(o: &crate::config::otel::OtelConfig) -> bool {
21 o == &crate::config::otel::OtelConfig::default()
22}
23
24impl EvalConfig {
25 pub fn is_legacy(&self) -> bool {
26 self.version == 0
27 }
28
29 pub fn has_legacy_usage(&self) -> bool {
30 self.tests
31 .iter()
32 .any(|t| t.expected.get_policy_path().is_some())
33 }
34
35 pub fn validate(&self) -> anyhow::Result<()> {
36 if self.version >= 1 {
37 for test in &self.tests {
38 if matches!(test.expected, Expected::Reference { .. }) {
39 anyhow::bail!("$ref in expected block is not allowed in configVersion >= 1. Run `assay migrate` to inline policies.");
40 }
41 }
42 }
43 Ok(())
44 }
45
46 pub fn effective_error_policy(&self, test: &TestCase) -> ErrorPolicy {
49 test.on_error.unwrap_or(self.settings.on_error)
50 }
51}
52
53fn is_default_thresholds(t: &crate::thresholds::ThresholdConfig) -> bool {
54 t == &crate::thresholds::ThresholdConfig::default()
55}
56
/// Suite-wide run settings.
///
/// Every field is optional or defaulted; `None`/default values are omitted
/// when serializing.
///
/// NOTE(review): unlike `EvalConfig`, this struct does not use
/// `deny_unknown_fields`, so misspelled settings keys are silently ignored.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct Settings {
    /// Presumably the maximum number of tests run concurrently — confirm at the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel: Option<usize>,
    /// Timeout in seconds (unit from the field name). NOTE(review): scope (per test vs suite) not visible here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout_seconds: Option<u64>,
    /// Cache toggle. NOTE(review): what is cached is decided by the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache: Option<bool>,
    /// Optional seed value. NOTE(review): which randomness it seeds is not visible here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    /// Judge configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub judge: Option<JudgeConfig>,
    /// Suite-level thresholding settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thresholding: Option<ThresholdingSettings>,

    /// Default error policy; a test's own `on_error` takes precedence
    /// (see `EvalConfig::effective_error_policy`).
    #[serde(default, skip_serializing_if = "is_default_error_policy")]
    pub on_error: ErrorPolicy,

    /// Defaults to `false` and is only serialized when `true`.
    /// Presumably stops the run after the first failure — confirm at the consumer.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub bail_on_first_failure: bool,
}
81
82fn is_default_error_policy(p: &ErrorPolicy) -> bool {
83 *p == ErrorPolicy::default()
84}
85
86fn is_default_settings(s: &Settings) -> bool {
87 s == &Settings::default()
88}
89
/// Suite-level thresholding settings (all optional).
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ThresholdingSettings {
    /// Thresholding mode name. NOTE(review): accepted values are not validated here.
    pub mode: Option<String>,
    /// Presumably the maximum allowed score drop — confirm against the consumer.
    pub max_drop: Option<f64>,
    /// Presumably an absolute minimum score — confirm against the consumer.
    pub min_floor: Option<f64>,
}
96
/// A single test case in a suite.
///
/// `Serialize` is derived; `Deserialize` is implemented by hand so that legacy
/// `expected` formats are still accepted.
#[derive(Debug, Clone, Default, Serialize)]
pub struct TestCase {
    /// Test identifier.
    pub id: String,
    /// Prompt (and optional context) sent to the model.
    pub input: TestInput,
    /// The primary expectation checked for this test.
    pub expected: Expected,
    /// Optional trace assertions; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub assertions: Option<Vec<crate::agent_assertions::model::TraceAssertion>>,
    /// Per-test error policy override; falls back to `settings.on_error` when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub on_error: Option<ErrorPolicy>,
    /// Free-form tags; omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
    /// Arbitrary user metadata.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
113
114impl<'de> Deserialize<'de> for TestCase {
115 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
116 where
117 D: serde::Deserializer<'de>,
118 {
119 #[derive(Deserialize)]
120 #[serde(deny_unknown_fields)]
121 struct RawTestCase {
122 id: String,
123 input: TestInput,
124 #[serde(default)]
125 expected: Option<serde_json::Value>,
126 assertions: Option<Vec<crate::agent_assertions::model::TraceAssertion>>,
127 #[serde(default)]
128 on_error: Option<ErrorPolicy>,
129 #[serde(default)]
130 tags: Vec<String>,
131 metadata: Option<serde_json::Value>,
132 }
133
134 let raw = RawTestCase::deserialize(deserializer)?;
135 let mut expected_main = Expected::default();
136 let extra_assertions = raw.assertions.unwrap_or_default();
137
138 if let Some(val) = raw.expected {
139 if let Some(arr) = val.as_array() {
140 for (i, item) in arr.iter().enumerate() {
142 if let Ok(exp) = serde_json::from_value::<Expected>(item.clone()) {
145 if i == 0 {
146 expected_main = exp;
147 }
148 } else if let Some(obj) = item.as_object() {
149 let mut parsed = None;
151 let mut matched_keys = Vec::new();
152
153 if let Some(r) = obj.get("$ref") {
154 parsed = Some(Expected::Reference {
155 path: r.as_str().unwrap_or("").to_string(),
156 });
157 matched_keys.push("$ref");
158 }
159
160 if let Some(mc) = obj.get("must_contain") {
162 let val = if mc.is_string() {
163 vec![mc.as_str().unwrap().to_string()]
164 } else {
165 serde_json::from_value(mc.clone()).unwrap_or_default()
166 };
167 if parsed.is_none() {
169 parsed = Some(Expected::MustContain { must_contain: val });
170 }
171 matched_keys.push("must_contain");
172 }
173
174 if obj.get("sequence").is_some() {
175 if parsed.is_none() {
176 parsed = Some(Expected::SequenceValid {
177 policy: None,
178 sequence: serde_json::from_value(
179 obj.get("sequence").unwrap().clone(),
180 )
181 .ok(),
182 rules: None,
183 });
184 }
185 matched_keys.push("sequence");
186 }
187
188 if obj.get("schema").is_some() {
189 if parsed.is_none() {
190 parsed = Some(Expected::ArgsValid {
191 policy: None,
192 schema: obj.get("schema").cloned(),
193 });
194 }
195 matched_keys.push("schema");
196 }
197
198 if matched_keys.len() > 1 {
199 eprintln!("WARN: Ambiguous legacy expected block. Found keys: {:?}. Using first match.", matched_keys);
200 }
201
202 if let Some(p) = parsed {
203 if i == 0 {
204 expected_main = p;
205 }
206 }
208 }
209 }
210 } else {
211 if let Ok(exp) = serde_json::from_value(val.clone()) {
213 expected_main = exp;
214 }
215 }
216 }
217
218 Ok(TestCase {
219 id: raw.id,
220 input: raw.input,
221 expected: expected_main,
222 assertions: if extra_assertions.is_empty() {
223 None
224 } else {
225 Some(extra_assertions)
226 },
227 on_error: raw.on_error,
228 tags: raw.tags,
229 metadata: raw.metadata,
230 })
231 }
232}
233
/// Model input for a test case.
///
/// Deserialization is hand-written: a bare string is accepted as the prompt.
#[derive(Debug, Clone, Default, Serialize)]
pub struct TestInput {
    /// The prompt text.
    pub prompt: String,
    /// Optional context passages; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context: Option<Vec<String>>,
}
240
241impl<'de> Deserialize<'de> for TestInput {
242 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
243 where
244 D: serde::Deserializer<'de>,
245 {
246 struct TestInputVisitor;
247
248 impl<'de> serde::de::Visitor<'de> for TestInputVisitor {
249 type Value = TestInput;
250
251 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252 formatter.write_str("string or struct TestInput")
253 }
254
255 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
256 where
257 E: serde::de::Error,
258 {
259 Ok(TestInput {
260 prompt: value.to_owned(),
261 context: None,
262 })
263 }
264
265 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
266 where
267 A: serde::de::MapAccess<'de>,
268 {
269 #[derive(Deserialize)]
272 struct Helper {
273 prompt: String,
274 #[serde(default)]
275 context: Option<Vec<String>>,
276 }
277 let helper =
278 Helper::deserialize(serde::de::value::MapAccessDeserializer::new(map))?;
279 Ok(TestInput {
280 prompt: helper.prompt,
281 context: helper.context,
282 })
283 }
284 }
285
286 deserializer.deserialize_any(TestInputVisitor)
287 }
288}
289
/// A test expectation, internally tagged by `type` with snake_case variant
/// names (e.g. `type: must_contain`). `Reference` uses the tag `$ref`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum Expected {
    /// Strings listed in `must_contain` (defaults to empty).
    MustContain {
        #[serde(default)]
        must_contain: Vec<String>,
    },
    /// Strings listed in `must_not_contain` (defaults to empty).
    MustNotContain {
        #[serde(default)]
        must_not_contain: Vec<String>,
    },

    /// Regex `pattern` with optional `flags` (defaults to empty).
    RegexMatch {
        pattern: String,
        #[serde(default)]
        flags: Vec<String>,
    },
    /// Negated regex check: `pattern` with optional `flags`.
    RegexNotMatch {
        pattern: String,
        #[serde(default)]
        flags: Vec<String>,
    },

    /// JSON schema check. NOTE(review): relationship between `json_schema`
    /// and `schema_file` is decided by the evaluator, not visible here.
    JsonSchema {
        json_schema: String,
        #[serde(default)]
        schema_file: Option<String>,
    },
    /// Semantic similarity to a reference text.
    SemanticSimilarityTo {
        /// Reference text; `text` is accepted as an alias.
        #[serde(alias = "text")]
        semantic_similarity_to: String,

        /// Minimum score; `threshold` is accepted as an alias. Defaults via
        /// `default_min_score` (0.80).
        #[serde(default = "default_min_score", alias = "threshold")]
        min_score: f64,

        /// Optional per-test thresholding override.
        #[serde(default)]
        thresholding: Option<ThresholdingConfig>,
    },
    /// Judge criteria as arbitrary JSON.
    JudgeCriteria {
        judge_criteria: serde_json::Value,
    },
    /// Faithfulness metric; `min_score` defaults via `default_min_score`.
    Faithfulness {
        #[serde(default = "default_min_score")]
        min_score: f64,
        rubric_version: Option<String>,
        #[serde(default)]
        thresholding: Option<ThresholdingConfig>,
    },
    /// Relevance metric; `min_score` defaults via `default_min_score`.
    Relevance {
        #[serde(default = "default_min_score")]
        min_score: f64,
        rubric_version: Option<String>,
        #[serde(default)]
        thresholding: Option<ThresholdingConfig>,
    },

    /// Tool-argument validation, from an external `policy` path and/or an
    /// inline `schema`. `policy` is what `get_policy_path` reports.
    ArgsValid {
        #[serde(skip_serializing_if = "Option::is_none")]
        policy: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        schema: Option<serde_json::Value>,
    },
    /// Tool-sequence validation, from an external `policy` path and/or an
    /// inline `sequence`/`rules`.
    SequenceValid {
        #[serde(skip_serializing_if = "Option::is_none")]
        policy: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sequence: Option<Vec<String>>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        rules: Option<Vec<SequenceRule>>,
    },
    /// List of `blocked` tool names.
    ToolBlocklist {
        blocked: Vec<String>,
    },
    /// Reference to an external file (`type: $ref`). Rejected by
    /// `EvalConfig::validate` for `configVersion >= 1`.
    #[serde(rename = "$ref")]
    Reference {
        path: String,
    },
}
371
372impl Default for Expected {
373 fn default() -> Self {
374 Expected::MustContain {
375 must_contain: vec![],
376 }
377 }
378}
379
/// A tool-usage policy document (loaded from YAML by `Policy::load`).
///
/// Unknown fields are rejected. Everything except `version` has a default.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Policy {
    /// Policy format version (a free-form string; not validated here).
    pub version: String,
    /// Policy name; empty when absent.
    #[serde(default)]
    pub name: String,
    /// Arbitrary metadata.
    #[serde(default)]
    pub metadata: Option<serde_json::Value>,
    /// Allow/deny lists and argument requirements for tools.
    #[serde(default)]
    pub tools: ToolsPolicy,
    /// Ordering/count rules over tool calls.
    #[serde(default)]
    pub sequences: Vec<SequenceRule>,
    /// Alias name -> member tool names, expanded by `Policy::resolve_alias`.
    #[serde(default)]
    pub aliases: std::collections::HashMap<String, Vec<String>>,
    /// Error policy for this policy; the type's default when absent.
    #[serde(default)]
    pub on_error: ErrorPolicy,
}
397
/// Tool-level policy section. All fields are optional; unknown fields are rejected.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ToolsPolicy {
    /// Allowed tool names (or patterns — NOTE(review): matching rules live in the evaluator).
    #[serde(default)]
    pub allow: Option<Vec<String>>,
    /// Denied tool names (same NOTE as `allow`).
    #[serde(default)]
    pub deny: Option<Vec<String>>,
    /// Tool name -> argument names. Presumably arguments that must be present — confirm.
    #[serde(default)]
    pub require_args: Option<std::collections::HashMap<String, Vec<String>>>,
    /// Tool name -> argument name -> constraint as raw JSON.
    #[serde(default)]
    pub arg_constraints: Option<
        std::collections::HashMap<String, std::collections::HashMap<String, serde_json::Value>>,
    >,
}
412
/// A rule over the sequence of tool calls, internally tagged by `type`
/// (snake_case variant names, e.g. `type: never_after`).
///
/// NOTE(review): rule semantics are enforced elsewhere; the descriptions
/// below are inferred from field names and should be confirmed there.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SequenceRule {
    /// Presumably: `tool` must be called.
    Require {
        tool: String,
    },
    /// Presumably: `tool` must be called within `within` calls.
    Eventually {
        tool: String,
        within: u32,
    },
    /// Presumably: `tool` may be called at most `max` times.
    MaxCalls {
        tool: String,
        max: u32,
    },
    /// Presumably: `first` must precede `then`.
    Before {
        first: String,
        then: String,
    },
    /// Presumably: `then` must follow `trigger`; `within` defaults to 1.
    After {
        trigger: String,
        then: String,
        #[serde(default = "default_one")]
        within: u32,
    },
    /// Presumably: `forbidden` may not be called after `trigger`.
    NeverAfter {
        trigger: String,
        forbidden: String,
    },
    /// An ordered list of `tools`; `strict` defaults to `false`.
    Sequence {
        tools: Vec<String>,
        #[serde(default)]
        strict: bool,
    },
    /// Block tools matching `pattern` (matching syntax not visible here).
    Blocklist {
        pattern: String,
    },
}
450
/// Serde default for `SequenceRule::After::within`.
fn default_one() -> u32 {
    1_u32
}
454
455impl Policy {
457 pub fn load<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<Self> {
458 let content = std::fs::read_to_string(path)?;
459 let policy: Policy = serde_yaml::from_str(&content)?;
460 Ok(policy)
461 }
462
463 pub fn resolve_alias(&self, tool_name: &str) -> Vec<String> {
464 if let Some(members) = self.aliases.get(tool_name) {
465 members.clone()
466 } else {
467 vec![tool_name.to_string()]
475 }
476 }
477}
478
479impl Expected {
480 pub fn get_policy_path(&self) -> Option<&str> {
481 match self {
482 Expected::ArgsValid { policy, .. } => policy.as_deref(),
483 Expected::SequenceValid { policy, .. } => policy.as_deref(),
484 _ => None,
485 }
486 }
487
488 pub fn thresholding_for_metric(&self, metric_name: &str) -> Option<&ThresholdingConfig> {
490 match (metric_name, self) {
491 ("semantic_similarity_to", Expected::SemanticSimilarityTo { thresholding, .. }) => {
492 thresholding.as_ref()
493 }
494 ("faithfulness", Expected::Faithfulness { thresholding, .. }) => thresholding.as_ref(),
495 ("relevance", Expected::Relevance { thresholding, .. }) => thresholding.as_ref(),
496 _ => None,
497 }
498 }
499}
500
/// A recorded tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Call identifier.
    pub id: String,
    /// Name of the invoked tool.
    pub tool_name: String,
    /// Call arguments as raw JSON.
    pub args: serde_json::Value,
    /// Result payload, if any.
    pub result: Option<serde_json::Value>,
    /// Error payload, if any.
    pub error: Option<serde_json::Value>,
    /// Presumably the call's position in the trace — confirm with the recorder.
    pub index: usize,
    /// Timestamp; the `_ms` suffix suggests milliseconds (epoch not visible here).
    pub ts_ms: u64,
}
511
/// Per-expectation thresholding override (see `Expected::thresholding_for_metric`).
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ThresholdingConfig {
    /// Presumably the maximum allowed score drop for this metric — confirm at the consumer.
    pub max_drop: Option<f64>,
}
516
/// LLM-judge settings.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct JudgeConfig {
    /// Rubric version identifier.
    pub rubric_version: Option<String>,
    /// Sample count. NOTE(review): semantics (e.g. votes per test) decided by the judge.
    pub samples: Option<u32>,
    /// Reliability settings; the type's default when absent.
    #[serde(default)]
    pub reliability: crate::judge::reliability::ReliabilityConfig,
}
524
/// Serde default for `min_score` on score-based expectations.
fn default_min_score() -> f64 {
    0.8_f64
}
528
/// A model response.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Response text.
    pub text: String,
    /// Provider name.
    pub provider: String,
    /// Model name.
    pub model: String,
    /// Whether the response was served from a cache.
    pub cached: bool,
    /// Provider-specific metadata; JSON `null` when absent.
    #[serde(default)]
    pub meta: serde_json::Value,
}
538
/// Outcome of a test, serialized as snake_case (e.g. `allowed_on_error`).
///
/// `Pass`, `Warn` and `AllowedOnError` count as passing; `Fail` and `Error`
/// are blocking. `Flaky`, `Skipped` and `Unstable` are neither.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum TestStatus {
    Pass,
    Fail,
    Flaky,
    Warn,
    Error,
    Skipped,
    Unstable,
    AllowedOnError,
}
552
553impl TestStatus {
554 pub fn parse(s: &str) -> Self {
555 match s {
556 "pass" => TestStatus::Pass,
557 "fail" => TestStatus::Fail,
558 "flaky" => TestStatus::Flaky,
559 "warn" => TestStatus::Warn,
560 "error" => TestStatus::Error,
561 "skipped" => TestStatus::Skipped,
562 "unstable" => TestStatus::Unstable,
563 "allowed_on_error" => TestStatus::AllowedOnError,
564 _ => TestStatus::Error,
565 }
566 }
567
568 pub fn is_passing(&self) -> bool {
570 matches!(
571 self,
572 TestStatus::Pass | TestStatus::AllowedOnError | TestStatus::Warn
573 )
574 }
575
576 pub fn is_blocking(&self) -> bool {
578 matches!(self, TestStatus::Fail | TestStatus::Error)
579 }
580}
581
/// Result record for one test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestResultRow {
    /// Id of the test this row belongs to.
    pub test_id: String,
    /// Final status.
    pub status: TestStatus,
    /// Score, when the check produced one.
    pub score: Option<f64>,
    /// Whether the result was served from a cache.
    pub cached: bool,
    /// Human-readable message.
    pub message: String,
    /// Extra details as JSON; `null` when absent.
    #[serde(default)]
    pub details: serde_json::Value,
    /// Duration in milliseconds (unit from the field name).
    pub duration_ms: Option<u64>,
    /// Optional fingerprint. NOTE(review): what is fingerprinted is decided by the producer.
    #[serde(default)]
    pub fingerprint: Option<String>,
    /// Reason text for a skipped test.
    #[serde(default)]
    pub skip_reason: Option<String>,
    /// Per-attempt records, when more than the final outcome is kept.
    #[serde(default)]
    pub attempts: Option<Vec<AttemptRow>>,
    /// Error policy that was applied, if any; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error_policy_applied: Option<ErrorPolicy>,
}
602
/// Record of a single attempt at running a test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttemptRow {
    /// Attempt number. NOTE(review): 0- or 1-based is decided by the producer.
    pub attempt_no: u32,
    /// Status of this attempt.
    pub status: TestStatus,
    /// Human-readable message.
    pub message: String,
    /// Duration in milliseconds (unit from the field name).
    pub duration_ms: Option<u64>,
    /// Extra details as JSON; `null` when absent.
    #[serde(default)]
    pub details: serde_json::Value,
}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare string `input` is promoted to the prompt.
    #[test]
    fn test_string_input_deserialize() {
        let yaml = r#"
id: test1
input: "simple string"
expected:
  type: must_contain
  must_contain: ["foo"]
"#;
        let case: TestCase = serde_yaml::from_str(yaml).expect("failed to parse");
        assert_eq!(case.input.prompt, "simple string");
    }

    /// The legacy list form uses its first entry as the primary expectation.
    #[test]
    fn test_legacy_list_expected() {
        let yaml = r#"
id: test1
input: "test"
expected:
  - must_contain: "Paris"
  - must_not_contain: "London"
"#;
        let case: TestCase = serde_yaml::from_str(yaml).expect("failed to parse");
        match case.expected {
            Expected::MustContain { must_contain } => assert_eq!(must_contain, vec!["Paris"]),
            other => panic!("Expected MustContain, got {:?}", other),
        }
    }

    /// A scalar legacy `must_contain` becomes a one-element list.
    #[test]
    fn test_scalar_must_contain_promotion() {
        let yaml = r#"
id: test1
input: "test"
expected:
  - must_contain: "single value"
"#;
        let case: TestCase = serde_yaml::from_str(yaml).unwrap();
        match case.expected {
            Expected::MustContain { must_contain } => {
                assert_eq!(must_contain, vec!["single value"])
            }
            _ => panic!("Expected MustContain"),
        }
    }

    /// `$ref` expectations are rejected once `configVersion >= 1`.
    #[test]
    fn test_validate_ref_in_v1() {
        let referencing_test = TestCase {
            id: "t1".into(),
            input: TestInput {
                prompt: "hi".into(),
                context: None,
            },
            expected: Expected::Reference {
                path: "foo.yaml".into(),
            },
            ..Default::default()
        };
        let config = EvalConfig {
            version: 1,
            suite: "test".into(),
            model: "test".into(),
            settings: Settings::default(),
            thresholds: Default::default(),
            otel: Default::default(),
            tests: vec![referencing_test],
        };
        assert!(config.validate().is_err());
    }

    /// Thresholding overrides are only returned for the expectation's own metric.
    #[test]
    fn test_thresholding_for_metric() {
        let without_override = Expected::SemanticSimilarityTo {
            semantic_similarity_to: "ref".into(),
            min_score: 0.8,
            thresholding: None,
        };
        assert!(without_override
            .thresholding_for_metric("semantic_similarity_to")
            .is_none());

        let with_override = Expected::SemanticSimilarityTo {
            semantic_similarity_to: "ref".into(),
            min_score: 0.8,
            thresholding: Some(ThresholdingConfig {
                max_drop: Some(0.05),
            }),
        };
        let similarity_cfg = with_override
            .thresholding_for_metric("semantic_similarity_to")
            .unwrap();
        assert_eq!(similarity_cfg.max_drop, Some(0.05));
        assert!(with_override.thresholding_for_metric("faithfulness").is_none());

        let faithfulness = Expected::Faithfulness {
            min_score: 0.7,
            rubric_version: None,
            thresholding: Some(ThresholdingConfig {
                max_drop: Some(0.1),
            }),
        };
        let faithfulness_cfg = faithfulness.thresholding_for_metric("faithfulness").unwrap();
        assert_eq!(faithfulness_cfg.max_drop, Some(0.1));
    }
}