1use crate::on_error::ErrorPolicy;
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5#[serde(deny_unknown_fields)]
6pub struct EvalConfig {
7 #[serde(default, rename = "configVersion", alias = "version")]
8 pub version: u32,
9 pub suite: String,
10 pub model: String,
11 #[serde(default, skip_serializing_if = "is_default_settings")]
12 pub settings: Settings,
13 #[serde(default, skip_serializing_if = "is_default_thresholds")]
14 pub thresholds: crate::thresholds::ThresholdConfig,
15 pub tests: Vec<TestCase>,
16}
17
18impl EvalConfig {
19 pub fn is_legacy(&self) -> bool {
20 self.version == 0
21 }
22
23 pub fn has_legacy_usage(&self) -> bool {
24 self.tests
25 .iter()
26 .any(|t| t.expected.get_policy_path().is_some())
27 }
28
29 pub fn validate(&self) -> anyhow::Result<()> {
30 if self.version >= 1 {
31 for test in &self.tests {
32 if matches!(test.expected, Expected::Reference { .. }) {
33 anyhow::bail!("$ref in expected block is not allowed in configVersion >= 1. Run `assay migrate` to inline policies.");
34 }
35 }
36 }
37 Ok(())
38 }
39
40 pub fn effective_error_policy(&self, test: &TestCase) -> ErrorPolicy {
43 test.on_error.unwrap_or(self.settings.on_error)
44 }
45}
46
47fn is_default_thresholds(t: &crate::thresholds::ThresholdConfig) -> bool {
48 t == &crate::thresholds::ThresholdConfig::default()
49}
50
51#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
52pub struct Settings {
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub parallel: Option<usize>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub timeout_seconds: Option<u64>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 pub cache: Option<bool>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 pub seed: Option<u64>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 pub judge: Option<JudgeConfig>,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 pub thresholding: Option<ThresholdingSettings>,
65
66 #[serde(default, skip_serializing_if = "is_default_error_policy")]
69 pub on_error: ErrorPolicy,
70
71 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
73 pub bail_on_first_failure: bool,
74}
75
76fn is_default_error_policy(p: &ErrorPolicy) -> bool {
77 *p == ErrorPolicy::default()
78}
79
80fn is_default_settings(s: &Settings) -> bool {
81 s == &Settings::default()
82}
83
84#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
85pub struct ThresholdingSettings {
86 pub mode: Option<String>,
87 pub max_drop: Option<f64>,
88 pub min_floor: Option<f64>,
89}
90
91#[derive(Debug, Clone, Default, Serialize)]
92pub struct TestCase {
93 pub id: String,
94 pub input: TestInput,
95 pub expected: Expected,
96 #[serde(default, skip_serializing_if = "Option::is_none")]
97 pub assertions: Option<Vec<crate::agent_assertions::model::TraceAssertion>>,
98 #[serde(default, skip_serializing_if = "Option::is_none")]
101 pub on_error: Option<ErrorPolicy>,
102 #[serde(default, skip_serializing_if = "Vec::is_empty")]
103 pub tags: Vec<String>,
104 #[serde(default, skip_serializing_if = "Option::is_none")]
105 pub metadata: Option<serde_json::Value>,
106}
107
108impl<'de> Deserialize<'de> for TestCase {
109 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
110 where
111 D: serde::Deserializer<'de>,
112 {
113 #[derive(Deserialize)]
114 #[serde(deny_unknown_fields)]
115 struct RawTestCase {
116 id: String,
117 input: TestInput,
118 #[serde(default)]
119 expected: Option<serde_json::Value>,
120 assertions: Option<Vec<crate::agent_assertions::model::TraceAssertion>>,
121 #[serde(default)]
122 on_error: Option<ErrorPolicy>,
123 #[serde(default)]
124 tags: Vec<String>,
125 metadata: Option<serde_json::Value>,
126 }
127
128 let raw = RawTestCase::deserialize(deserializer)?;
129 let mut expected_main = Expected::default();
130 let extra_assertions = raw.assertions.unwrap_or_default();
131
132 if let Some(val) = raw.expected {
133 if let Some(arr) = val.as_array() {
134 for (i, item) in arr.iter().enumerate() {
136 if let Ok(exp) = serde_json::from_value::<Expected>(item.clone()) {
139 if i == 0 {
140 expected_main = exp;
141 }
142 } else if let Some(obj) = item.as_object() {
143 let mut parsed = None;
145 let mut matched_keys = Vec::new();
146
147 if let Some(r) = obj.get("$ref") {
148 parsed = Some(Expected::Reference {
149 path: r.as_str().unwrap_or("").to_string(),
150 });
151 matched_keys.push("$ref");
152 }
153
154 if let Some(mc) = obj.get("must_contain") {
156 let val = if mc.is_string() {
157 vec![mc.as_str().unwrap().to_string()]
158 } else {
159 serde_json::from_value(mc.clone()).unwrap_or_default()
160 };
161 if parsed.is_none() {
163 parsed = Some(Expected::MustContain { must_contain: val });
164 }
165 matched_keys.push("must_contain");
166 }
167
168 if obj.get("sequence").is_some() {
169 if parsed.is_none() {
170 parsed = Some(Expected::SequenceValid {
171 policy: None,
172 sequence: serde_json::from_value(
173 obj.get("sequence").unwrap().clone(),
174 )
175 .ok(),
176 rules: None,
177 });
178 }
179 matched_keys.push("sequence");
180 }
181
182 if obj.get("schema").is_some() {
183 if parsed.is_none() {
184 parsed = Some(Expected::ArgsValid {
185 policy: None,
186 schema: obj.get("schema").cloned(),
187 });
188 }
189 matched_keys.push("schema");
190 }
191
192 if matched_keys.len() > 1 {
193 eprintln!("WARN: Ambiguous legacy expected block. Found keys: {:?}. Using first match.", matched_keys);
194 }
195
196 if let Some(p) = parsed {
197 if i == 0 {
198 expected_main = p;
199 }
200 }
202 }
203 }
204 } else {
205 if let Ok(exp) = serde_json::from_value(val.clone()) {
207 expected_main = exp;
208 }
209 }
210 }
211
212 Ok(TestCase {
213 id: raw.id,
214 input: raw.input,
215 expected: expected_main,
216 assertions: if extra_assertions.is_empty() {
217 None
218 } else {
219 Some(extra_assertions)
220 },
221 on_error: raw.on_error,
222 tags: raw.tags,
223 metadata: raw.metadata,
224 })
225 }
226}
227
228#[derive(Debug, Clone, Default, Serialize)]
229pub struct TestInput {
230 pub prompt: String,
231 #[serde(default, skip_serializing_if = "Option::is_none")]
232 pub context: Option<Vec<String>>,
233}
234
235impl<'de> Deserialize<'de> for TestInput {
236 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
237 where
238 D: serde::Deserializer<'de>,
239 {
240 struct TestInputVisitor;
241
242 impl<'de> serde::de::Visitor<'de> for TestInputVisitor {
243 type Value = TestInput;
244
245 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 formatter.write_str("string or struct TestInput")
247 }
248
249 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
250 where
251 E: serde::de::Error,
252 {
253 Ok(TestInput {
254 prompt: value.to_owned(),
255 context: None,
256 })
257 }
258
259 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
260 where
261 A: serde::de::MapAccess<'de>,
262 {
263 #[derive(Deserialize)]
266 struct Helper {
267 prompt: String,
268 #[serde(default)]
269 context: Option<Vec<String>>,
270 }
271 let helper =
272 Helper::deserialize(serde::de::value::MapAccessDeserializer::new(map))?;
273 Ok(TestInput {
274 prompt: helper.prompt,
275 context: helper.context,
276 })
277 }
278 }
279
280 deserializer.deserialize_any(TestInputVisitor)
281 }
282}
283
284#[derive(Debug, Clone, Serialize, Deserialize)]
285#[serde(rename_all = "snake_case", tag = "type")]
286pub enum Expected {
287 MustContain {
288 #[serde(default)]
289 must_contain: Vec<String>,
290 },
291 MustNotContain {
292 #[serde(default)]
293 must_not_contain: Vec<String>,
294 },
295
296 RegexMatch {
297 pattern: String,
298 #[serde(default)]
299 flags: Vec<String>,
300 },
301 RegexNotMatch {
302 pattern: String,
303 #[serde(default)]
304 flags: Vec<String>,
305 },
306
307 JsonSchema {
308 json_schema: String,
309 #[serde(default)]
310 schema_file: Option<String>,
311 },
312 SemanticSimilarityTo {
313 #[serde(alias = "text")]
315 semantic_similarity_to: String,
316
317 #[serde(default = "default_min_score", alias = "threshold")]
319 min_score: f64,
320
321 #[serde(default)]
322 thresholding: Option<ThresholdingConfig>,
323 },
324 JudgeCriteria {
325 judge_criteria: serde_json::Value,
326 },
327 Faithfulness {
328 #[serde(default = "default_min_score")]
329 min_score: f64,
330 rubric_version: Option<String>,
331 #[serde(default)]
332 thresholding: Option<ThresholdingConfig>,
333 },
334 Relevance {
335 #[serde(default = "default_min_score")]
336 min_score: f64,
337 rubric_version: Option<String>,
338 #[serde(default)]
339 thresholding: Option<ThresholdingConfig>,
340 },
341
342 ArgsValid {
343 #[serde(skip_serializing_if = "Option::is_none")]
344 policy: Option<String>,
345 #[serde(default, skip_serializing_if = "Option::is_none")]
346 schema: Option<serde_json::Value>,
347 },
348 SequenceValid {
349 #[serde(skip_serializing_if = "Option::is_none")]
350 policy: Option<String>,
351 #[serde(default, skip_serializing_if = "Option::is_none")]
352 sequence: Option<Vec<String>>,
353 #[serde(default, skip_serializing_if = "Option::is_none")]
354 rules: Option<Vec<SequenceRule>>,
355 },
356 ToolBlocklist {
357 blocked: Vec<String>,
358 },
359 #[serde(rename = "$ref")]
361 Reference {
362 path: String,
363 },
364}
365
366impl Default for Expected {
367 fn default() -> Self {
368 Expected::MustContain {
369 must_contain: vec![],
370 }
371 }
372}
373
374#[derive(Debug, Clone, Serialize, Deserialize)]
375#[serde(deny_unknown_fields)]
376pub struct Policy {
377 pub version: String,
378 #[serde(default)]
379 pub name: String,
380 #[serde(default)]
381 pub metadata: Option<serde_json::Value>,
382 #[serde(default)]
383 pub tools: ToolsPolicy,
384 #[serde(default)]
385 pub sequences: Vec<SequenceRule>,
386 #[serde(default)]
387 pub aliases: std::collections::HashMap<String, Vec<String>>,
388 #[serde(default)]
389 pub on_error: ErrorPolicy,
390}
391
392#[derive(Debug, Clone, Default, Serialize, Deserialize)]
393#[serde(deny_unknown_fields)]
394pub struct ToolsPolicy {
395 #[serde(default)]
396 pub allow: Option<Vec<String>>,
397 #[serde(default)]
398 pub deny: Option<Vec<String>>,
399 #[serde(default)]
400 pub require_args: Option<std::collections::HashMap<String, Vec<String>>>,
401 #[serde(default)]
402 pub arg_constraints: Option<
403 std::collections::HashMap<String, std::collections::HashMap<String, serde_json::Value>>,
404 >,
405}
406
407#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
408#[serde(tag = "type", rename_all = "snake_case")]
409pub enum SequenceRule {
410 Require {
411 tool: String,
412 },
413 Eventually {
414 tool: String,
415 within: u32,
416 },
417 MaxCalls {
418 tool: String,
419 max: u32,
420 },
421 Before {
422 first: String,
423 then: String,
424 },
425 After {
426 trigger: String,
427 then: String,
428 #[serde(default = "default_one")]
429 within: u32,
430 },
431 NeverAfter {
432 trigger: String,
433 forbidden: String,
434 },
435 Sequence {
436 tools: Vec<String>,
437 #[serde(default)]
438 strict: bool,
439 },
440 Blocklist {
441 pattern: String,
442 },
443}
444
/// serde default provider that yields `1`.
fn default_one() -> u32 {
    1_u32
}
448
449impl Policy {
451 pub fn load<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<Self> {
452 let content = std::fs::read_to_string(path)?;
453 let policy: Policy = serde_yaml::from_str(&content)?;
454 Ok(policy)
455 }
456
457 pub fn resolve_alias(&self, tool_name: &str) -> Vec<String> {
458 if let Some(members) = self.aliases.get(tool_name) {
459 members.clone()
460 } else {
461 vec![tool_name.to_string()]
469 }
470 }
471}
472
473impl Expected {
474 pub fn get_policy_path(&self) -> Option<&str> {
475 match self {
476 Expected::ArgsValid { policy, .. } => policy.as_deref(),
477 Expected::SequenceValid { policy, .. } => policy.as_deref(),
478 _ => None,
479 }
480 }
481}
482
483#[derive(Debug, Clone, Serialize, Deserialize)]
484pub struct ToolCallRecord {
485 pub id: String,
486 pub tool_name: String,
487 pub args: serde_json::Value,
488 pub result: Option<serde_json::Value>,
489 pub error: Option<serde_json::Value>,
490 pub index: usize,
491 pub ts_ms: u64,
492}
493
494#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
495pub struct ThresholdingConfig {
496 pub max_drop: Option<f64>,
497}
498
499#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
500pub struct JudgeConfig {
501 pub rubric_version: Option<String>,
502 pub samples: Option<u32>,
503}
504
/// serde default provider for `min_score` fields: yields `0.8`.
fn default_min_score() -> f64 {
    0.8_f64
}
508
509#[derive(Debug, Clone, Default, Serialize, Deserialize)]
510pub struct LlmResponse {
511 pub text: String,
512 pub provider: String,
513 pub model: String,
514 pub cached: bool,
515 #[serde(default)]
516 pub meta: serde_json::Value,
517}
518
519#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
520#[serde(rename_all = "snake_case")]
521pub enum TestStatus {
522 Pass,
523 Fail,
524 Flaky,
525 Warn,
526 Error,
527 Skipped,
528 Unstable,
529 AllowedOnError,
531}
532
533impl TestStatus {
534 pub fn parse(s: &str) -> Self {
535 match s {
536 "pass" => TestStatus::Pass,
537 "fail" => TestStatus::Fail,
538 "flaky" => TestStatus::Flaky,
539 "warn" => TestStatus::Warn,
540 "error" => TestStatus::Error,
541 "skipped" => TestStatus::Skipped,
542 "unstable" => TestStatus::Unstable,
543 "allowed_on_error" => TestStatus::AllowedOnError,
544 _ => TestStatus::Error,
545 }
546 }
547
548 pub fn is_passing(&self) -> bool {
550 matches!(
551 self,
552 TestStatus::Pass | TestStatus::AllowedOnError | TestStatus::Warn
553 )
554 }
555
556 pub fn is_blocking(&self) -> bool {
558 matches!(self, TestStatus::Fail | TestStatus::Error)
559 }
560}
561
562#[derive(Debug, Clone, Serialize, Deserialize)]
563pub struct TestResultRow {
564 pub test_id: String,
565 pub status: TestStatus,
566 pub score: Option<f64>,
567 pub cached: bool,
568 pub message: String,
569 #[serde(default)]
570 pub details: serde_json::Value,
571 pub duration_ms: Option<u64>,
572 #[serde(default)]
573 pub fingerprint: Option<String>,
574 #[serde(default)]
575 pub skip_reason: Option<String>,
576 #[serde(default)]
577 pub attempts: Option<Vec<AttemptRow>>,
578 #[serde(default, skip_serializing_if = "Option::is_none")]
580 pub error_policy_applied: Option<ErrorPolicy>,
581}
582
583#[derive(Debug, Clone, Serialize, Deserialize)]
584pub struct AttemptRow {
585 pub attempt_no: u32,
586 pub status: TestStatus,
587 pub message: String,
588 pub duration_ms: Option<u64>,
589 #[serde(default)]
590 pub details: serde_json::Value,
591}
592
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare string `input` is accepted and becomes the prompt.
    #[test]
    fn test_string_input_deserialize() {
        // `type` and `must_contain` must be nested under `expected`; at the
        // top level they would be rejected as unknown test-case fields.
        let yaml = r#"
            id: test1
            input: "simple string"
            expected:
              type: must_contain
              must_contain: ["foo"]
        "#;
        let tc: TestCase = serde_yaml::from_str(yaml).expect("failed to parse");
        assert_eq!(tc.input.prompt, "simple string");
    }

    /// The legacy list form keeps only its first entry.
    #[test]
    fn test_legacy_list_expected() {
        let yaml = r#"
            id: test1
            input: "test"
            expected:
              - must_contain: "Paris"
              - must_not_contain: "London"
        "#;
        let tc: TestCase = serde_yaml::from_str(yaml).expect("failed to parse");
        if let Expected::MustContain { must_contain } = tc.expected {
            assert_eq!(must_contain, vec!["Paris"]);
        } else {
            panic!("Expected MustContain, got {:?}", tc.expected);
        }
    }

    /// A scalar legacy `must_contain` is promoted to a one-element list.
    #[test]
    fn test_scalar_must_contain_promotion() {
        let yaml = r#"
            id: test1
            input: "test"
            expected:
              - must_contain: "single value"
        "#;
        let tc: TestCase = serde_yaml::from_str(yaml).unwrap();
        if let Expected::MustContain { must_contain } = tc.expected {
            assert_eq!(must_contain, vec!["single value"]);
        } else {
            panic!("Expected MustContain");
        }
    }

    /// `$ref` expectations are rejected for version >= 1 and accepted for
    /// version 0.
    #[test]
    fn test_validate_ref_in_v1() {
        let mut config = EvalConfig {
            version: 1,
            suite: "test".into(),
            model: "test".into(),
            settings: Settings::default(),
            thresholds: Default::default(),
            tests: vec![TestCase {
                id: "t1".into(),
                input: TestInput {
                    prompt: "hi".into(),
                    context: None,
                },
                expected: Expected::Reference {
                    path: "foo.yaml".into(),
                },
                assertions: None,
                tags: vec![],
                metadata: None,
                on_error: None,
            }],
        };
        assert!(config.validate().is_err());

        // The same config is valid once it is marked as legacy.
        config.version = 0;
        assert!(config.validate().is_ok());
    }
}