Skip to main content

assay_core/
model.rs

1use crate::on_error::ErrorPolicy;
2use serde::{Deserialize, Serialize};
3
/// Top-level evaluation suite configuration.
///
/// Unknown top-level keys are rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct EvalConfig {
    /// Config schema version, read from `configVersion` (alias: `version`).
    /// Defaults to `0` when absent, which `is_legacy` reports as legacy;
    /// `validate` applies stricter rules when it is `>= 1`.
    #[serde(default, rename = "configVersion", alias = "version")]
    pub version: u32,
    /// Name of the suite.
    pub suite: String,
    /// Model identifier — NOTE(review): expected format is not defined in this file.
    pub model: String,
    /// Suite-wide settings; omitted when serializing if equal to `Settings::default()`.
    #[serde(default, skip_serializing_if = "is_default_settings")]
    pub settings: Settings,
    /// Threshold configuration; omitted when serializing if equal to its default.
    #[serde(default, skip_serializing_if = "is_default_thresholds")]
    pub thresholds: crate::thresholds::ThresholdConfig,
    /// The test cases in this suite.
    pub tests: Vec<TestCase>,
}
17
18impl EvalConfig {
19    pub fn is_legacy(&self) -> bool {
20        self.version == 0
21    }
22
23    pub fn has_legacy_usage(&self) -> bool {
24        self.tests
25            .iter()
26            .any(|t| t.expected.get_policy_path().is_some())
27    }
28
29    pub fn validate(&self) -> anyhow::Result<()> {
30        if self.version >= 1 {
31            for test in &self.tests {
32                if matches!(test.expected, Expected::Reference { .. }) {
33                    anyhow::bail!("$ref in expected block is not allowed in configVersion >= 1. Run `assay migrate` to inline policies.");
34                }
35            }
36        }
37        Ok(())
38    }
39
40    /// Get the effective error policy for a test.
41    /// Test-level on_error overrides suite-level settings.
42    pub fn effective_error_policy(&self, test: &TestCase) -> ErrorPolicy {
43        test.on_error.unwrap_or(self.settings.on_error)
44    }
45}
46
47fn is_default_thresholds(t: &crate::thresholds::ThresholdConfig) -> bool {
48    t == &crate::thresholds::ThresholdConfig::default()
49}
50
/// Suite-wide execution settings (the `settings:` section of the config).
///
/// Every field is optional. Unlike `EvalConfig`, this struct does not use
/// `deny_unknown_fields`, so unrecognised (e.g. misspelled) keys are ignored.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct Settings {
    /// Presumably the number of tests run concurrently — TODO confirm in the runner.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel: Option<usize>,
    /// Timeout in seconds — NOTE(review): per-test vs. per-suite scope is not visible here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout_seconds: Option<u64>,
    /// Cache toggle — presumably for LLM responses; confirm with the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache: Option<bool>,
    /// Seed value — NOTE(review): its consumer is not visible in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    /// Judge configuration (see `JudgeConfig`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub judge: Option<JudgeConfig>,
    /// Suite-level thresholding settings (see `ThresholdingSettings`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thresholding: Option<ThresholdingSettings>,

    /// Global error handling policy (default: block)
    /// Can be overridden per-test
    #[serde(default, skip_serializing_if = "is_default_error_policy")]
    pub on_error: ErrorPolicy,

    /// Bail on first failure (useful for CI)
    /// Omitted when serializing if `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub bail_on_first_failure: bool,
}
75
76fn is_default_error_policy(p: &ErrorPolicy) -> bool {
77    *p == ErrorPolicy::default()
78}
79
80fn is_default_settings(s: &Settings) -> bool {
81    s == &Settings::default()
82}
83
/// Suite-level thresholding settings (`settings.thresholding`); all fields optional.
/// NOTE(review): the allowed `mode` values and the units of `max_drop` /
/// `min_floor` are not defined in this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ThresholdingSettings {
    pub mode: Option<String>,
    pub max_drop: Option<f64>,
    pub min_floor: Option<f64>,
}
90
/// A single test case.
///
/// Serialization is derived; deserialization is hand-written (see the
/// `Deserialize` impl for `TestCase`) so legacy `expected` formats are accepted.
#[derive(Debug, Clone, Default, Serialize)]
pub struct TestCase {
    /// Test identifier — NOTE(review): uniqueness is not enforced in this file.
    pub id: String,
    /// Prompt and optional context.
    pub input: TestInput,
    /// The primary expectation this test is checked against.
    pub expected: Expected,
    /// Additional trace assertions; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub assertions: Option<Vec<crate::agent_assertions::model::TraceAssertion>>,
    /// Per-test error handling policy override
    /// If None, uses settings.on_error
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub on_error: Option<ErrorPolicy>,
    /// Free-form tags; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
    /// Arbitrary JSON metadata; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
107
108impl<'de> Deserialize<'de> for TestCase {
109    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
110    where
111        D: serde::Deserializer<'de>,
112    {
113        #[derive(Deserialize)]
114        #[serde(deny_unknown_fields)]
115        struct RawTestCase {
116            id: String,
117            input: TestInput,
118            #[serde(default)]
119            expected: Option<serde_json::Value>,
120            assertions: Option<Vec<crate::agent_assertions::model::TraceAssertion>>,
121            #[serde(default)]
122            on_error: Option<ErrorPolicy>,
123            #[serde(default)]
124            tags: Vec<String>,
125            metadata: Option<serde_json::Value>,
126        }
127
128        let raw = RawTestCase::deserialize(deserializer)?;
129        let mut expected_main = Expected::default();
130        let extra_assertions = raw.assertions.unwrap_or_default();
131
132        if let Some(val) = raw.expected {
133            if let Some(arr) = val.as_array() {
134                // Legacy list format
135                for (i, item) in arr.iter().enumerate() {
136                    // Try to parse as Expected
137                    // Try to parse as Expected (Strict V1)
138                    if let Ok(exp) = serde_json::from_value::<Expected>(item.clone()) {
139                        if i == 0 {
140                            expected_main = exp;
141                        }
142                    } else if let Some(obj) = item.as_object() {
143                        // Try Legacy Heuristics
144                        let mut parsed = None;
145                        let mut matched_keys = Vec::new();
146
147                        if let Some(r) = obj.get("$ref") {
148                            parsed = Some(Expected::Reference {
149                                path: r.as_str().unwrap_or("").to_string(),
150                            });
151                            matched_keys.push("$ref");
152                        }
153
154                        // Don't chain else-ifs, check all to detect ambiguity
155                        if let Some(mc) = obj.get("must_contain") {
156                            let val = if mc.is_string() {
157                                vec![mc.as_str().unwrap().to_string()]
158                            } else {
159                                serde_json::from_value(mc.clone()).unwrap_or_default()
160                            };
161                            // Last match wins for parsed, but we warn below
162                            if parsed.is_none() {
163                                parsed = Some(Expected::MustContain { must_contain: val });
164                            }
165                            matched_keys.push("must_contain");
166                        }
167
168                        if obj.get("sequence").is_some() {
169                            if parsed.is_none() {
170                                parsed = Some(Expected::SequenceValid {
171                                    policy: None,
172                                    sequence: serde_json::from_value(
173                                        obj.get("sequence").unwrap().clone(),
174                                    )
175                                    .ok(),
176                                    rules: None,
177                                });
178                            }
179                            matched_keys.push("sequence");
180                        }
181
182                        if obj.get("schema").is_some() {
183                            if parsed.is_none() {
184                                parsed = Some(Expected::ArgsValid {
185                                    policy: None,
186                                    schema: obj.get("schema").cloned(),
187                                });
188                            }
189                            matched_keys.push("schema");
190                        }
191
192                        if matched_keys.len() > 1 {
193                            eprintln!("WARN: Ambiguous legacy expected block. Found keys: {:?}. Using first match.", matched_keys);
194                        }
195
196                        if let Some(p) = parsed {
197                            if i == 0 {
198                                expected_main = p;
199                            }
200                            // else: drop or move to assertions (out of scope for quick fix, primary policy is priority)
201                        }
202                    }
203                }
204            } else {
205                // Try V1 single object
206                if let Ok(exp) = serde_json::from_value(val.clone()) {
207                    expected_main = exp;
208                }
209            }
210        }
211
212        Ok(TestCase {
213            id: raw.id,
214            input: raw.input,
215            expected: expected_main,
216            assertions: if extra_assertions.is_empty() {
217                None
218            } else {
219                Some(extra_assertions)
220            },
221            on_error: raw.on_error,
222            tags: raw.tags,
223            metadata: raw.metadata,
224        })
225    }
226}
227
/// Input for a test case.
///
/// Deserializes from either a bare string (taken as the prompt) or a map with
/// `prompt` and optional `context` (see the `Deserialize` impl for `TestInput`).
#[derive(Debug, Clone, Default, Serialize)]
pub struct TestInput {
    /// The prompt text.
    pub prompt: String,
    /// Optional context strings; omitted when serializing if `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context: Option<Vec<String>>,
}
234
235impl<'de> Deserialize<'de> for TestInput {
236    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
237    where
238        D: serde::Deserializer<'de>,
239    {
240        struct TestInputVisitor;
241
242        impl<'de> serde::de::Visitor<'de> for TestInputVisitor {
243            type Value = TestInput;
244
245            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246                formatter.write_str("string or struct TestInput")
247            }
248
249            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
250            where
251                E: serde::de::Error,
252            {
253                Ok(TestInput {
254                    prompt: value.to_owned(),
255                    context: None,
256                })
257            }
258
259            fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
260            where
261                A: serde::de::MapAccess<'de>,
262            {
263                // Default derivation logic manually implemented or use intermediate struct
264                // Using intermediate struct is easier to avoid massive boilerplate
265                #[derive(Deserialize)]
266                struct Helper {
267                    prompt: String,
268                    #[serde(default)]
269                    context: Option<Vec<String>>,
270                }
271                let helper =
272                    Helper::deserialize(serde::de::value::MapAccessDeserializer::new(map))?;
273                Ok(TestInput {
274                    prompt: helper.prompt,
275                    context: helper.context,
276                })
277            }
278        }
279
280        deserializer.deserialize_any(TestInputVisitor)
281    }
282}
283
/// The expectation a test case is checked against.
///
/// Internally tagged: the `type` key selects the variant by its snake_case name
/// (e.g. `type: regex_match`); `Reference` uses the tag `$ref` instead.
/// Variants do not reject unknown fields.
/// NOTE(review): how each variant is evaluated is defined elsewhere, not in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum Expected {
    /// `type: must_contain`; the list defaults to empty.
    MustContain {
        #[serde(default)]
        must_contain: Vec<String>,
    },
    /// `type: must_not_contain`; the list defaults to empty.
    MustNotContain {
        #[serde(default)]
        must_not_contain: Vec<String>,
    },

    /// `type: regex_match`; `flags` defaults to empty.
    RegexMatch {
        pattern: String,
        #[serde(default)]
        flags: Vec<String>,
    },
    /// `type: regex_not_match`; `flags` defaults to empty.
    RegexNotMatch {
        pattern: String,
        #[serde(default)]
        flags: Vec<String>,
    },

    /// `type: json_schema`; the schema is given as a string, plus an optional
    /// `schema_file` — NOTE(review): precedence between the two is not defined here.
    JsonSchema {
        json_schema: String,
        #[serde(default)]
        schema_file: Option<String>,
    },
    /// `type: semantic_similarity_to`.
    SemanticSimilarityTo {
        // canonical field; `text` is accepted as an alias
        #[serde(alias = "text")]
        semantic_similarity_to: String,

        // canonical field; `threshold` is accepted as an alias; defaults to 0.80
        #[serde(default = "default_min_score", alias = "threshold")]
        min_score: f64,

        #[serde(default)]
        thresholding: Option<ThresholdingConfig>,
    },
    /// `type: judge_criteria`; the criteria are arbitrary JSON.
    JudgeCriteria {
        judge_criteria: serde_json::Value,
    },
    /// `type: faithfulness`; `min_score` defaults to 0.80.
    Faithfulness {
        #[serde(default = "default_min_score")]
        min_score: f64,
        rubric_version: Option<String>,
        #[serde(default)]
        thresholding: Option<ThresholdingConfig>,
    },
    /// `type: relevance`; `min_score` defaults to 0.80.
    Relevance {
        #[serde(default = "default_min_score")]
        min_score: f64,
        rubric_version: Option<String>,
        #[serde(default)]
        thresholding: Option<ThresholdingConfig>,
    },

    /// `type: args_valid`; `policy` is an optional policy file path (returned by
    /// `get_policy_path`) and `schema` an optional inline JSON value.
    ArgsValid {
        #[serde(skip_serializing_if = "Option::is_none")]
        policy: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        schema: Option<serde_json::Value>,
    },
    /// `type: sequence_valid`; carries an optional policy file path (returned by
    /// `get_policy_path`), an optional explicit tool sequence and optional inline rules.
    SequenceValid {
        #[serde(skip_serializing_if = "Option::is_none")]
        policy: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sequence: Option<Vec<String>>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        rules: Option<Vec<SequenceRule>>,
    },
    /// `type: tool_blocklist`.
    ToolBlocklist {
        blocked: Vec<String>,
    },
    // For migration/legacy support
    /// Legacy reference to an external policy file (tag `$ref`).
    /// `EvalConfig::validate` rejects it when `configVersion >= 1`.
    #[serde(rename = "$ref")]
    Reference {
        path: String,
    },
}
365
366impl Default for Expected {
367    fn default() -> Self {
368        Expected::MustContain {
369            must_contain: vec![],
370        }
371    }
372}
373
/// A tool-usage policy document, loaded from YAML by `Policy::load`.
///
/// Unknown keys are rejected (`deny_unknown_fields`). Only `version` is
/// required; every other field falls back to its `Default`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Policy {
    /// Policy format version (a string, e.g. `"1"`).
    pub version: String,
    #[serde(default)]
    pub name: String,
    #[serde(default)]
    pub metadata: Option<serde_json::Value>,
    /// Tool allow/deny lists and argument requirements.
    #[serde(default)]
    pub tools: ToolsPolicy,
    /// Sequencing rules for tool calls.
    #[serde(default)]
    pub sequences: Vec<SequenceRule>,
    /// Alias name → member tool names; expanded by `resolve_alias`.
    #[serde(default)]
    pub aliases: std::collections::HashMap<String, Vec<String>>,
    /// Error handling policy for this document; defaults to `ErrorPolicy::default()`.
    #[serde(default)]
    pub on_error: ErrorPolicy,
}
391
/// The `tools:` section of a `Policy`. All fields are optional and unknown
/// keys are rejected.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ToolsPolicy {
    /// Tool names on the allow list.
    #[serde(default)]
    pub allow: Option<Vec<String>>,
    /// Tool names on the deny list.
    #[serde(default)]
    pub deny: Option<Vec<String>>,
    /// Tool name → argument names — presumably arguments that must be present; confirm.
    #[serde(default)]
    pub require_args: Option<std::collections::HashMap<String, Vec<String>>>,
    /// Tool name → argument name → constraint (arbitrary JSON).
    /// NOTE(review): constraint schema is interpreted elsewhere.
    #[serde(default)]
    pub arg_constraints: Option<
        std::collections::HashMap<String, std::collections::HashMap<String, serde_json::Value>>,
    >,
}
406
/// A tool-call sequencing rule, used in `Policy::sequences` and in the `rules`
/// of `Expected::SequenceValid`.
///
/// Internally tagged by `type` with snake_case variant names (e.g.
/// `type: max_calls`). NOTE(review): rule semantics are implemented by the
/// evaluator, not in this file; tool names may be aliases (see `Policy::resolve_alias`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SequenceRule {
    /// `type: require`
    Require {
        tool: String,
    },
    /// `type: eventually`; `within` is required.
    Eventually {
        tool: String,
        within: u32,
    },
    /// `type: max_calls`
    MaxCalls {
        tool: String,
        max: u32,
    },
    /// `type: before`
    Before {
        first: String,
        then: String,
    },
    /// `type: after`; `within` defaults to 1.
    After {
        trigger: String,
        then: String,
        #[serde(default = "default_one")]
        within: u32,
    },
    /// `type: never_after`
    NeverAfter {
        trigger: String,
        forbidden: String,
    },
    /// `type: sequence`; `strict` defaults to `false`.
    Sequence {
        tools: Vec<String>,
        #[serde(default)]
        strict: bool,
    },
    /// `type: blocklist` — NOTE(review): pattern syntax is not defined here.
    Blocklist {
        pattern: String,
    },
}
444
/// Serde default for the `within` window of `SequenceRule::After`.
fn default_one() -> u32 {
    1_u32
}
448
449// Helper for alias resolution
450impl Policy {
451    pub fn load<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<Self> {
452        let content = std::fs::read_to_string(path)?;
453        let policy: Policy = serde_yaml::from_str(&content)?;
454        Ok(policy)
455    }
456
457    pub fn resolve_alias(&self, tool_name: &str) -> Vec<String> {
458        if let Some(members) = self.aliases.get(tool_name) {
459            members.clone()
460        } else {
461            // If not an alias, return strict singleton if no alias found?
462            // RFC says: "Matches SearchKnowledgeBase OR SearchWeb".
463            // "Alias can be used anywhere a tool name is expected".
464            // If we rely on resolve_alias to return all matches for a "rule target",
465            // AND we want to support literals:
466            // If 'Search' is in aliases, satisfy if match any alias member.
467            // If 'Search' is NOT in aliases, it's a literal.
468            vec![tool_name.to_string()]
469        }
470    }
471}
472
473impl Expected {
474    pub fn get_policy_path(&self) -> Option<&str> {
475        match self {
476            Expected::ArgsValid { policy, .. } => policy.as_deref(),
477            Expected::SequenceValid { policy, .. } => policy.as_deref(),
478            _ => None,
479        }
480    }
481}
482
/// A recorded tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Call identifier.
    pub id: String,
    /// Name of the invoked tool.
    pub tool_name: String,
    /// Call arguments as JSON.
    pub args: serde_json::Value,
    /// Result JSON, if any.
    pub result: Option<serde_json::Value>,
    /// Error JSON, if any.
    pub error: Option<serde_json::Value>,
    /// Presumably the call's position in the trace — TODO confirm with the producer.
    pub index: usize,
    /// Presumably a millisecond timestamp — TODO confirm epoch/units with the producer.
    pub ts_ms: u64,
}
493
/// Per-expectation thresholding options (used by the score-based `Expected` variants).
/// NOTE(review): units of `max_drop` are not defined in this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ThresholdingConfig {
    pub max_drop: Option<f64>,
}
498
/// Suite-level judge settings (`settings.judge`); both fields optional.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct JudgeConfig {
    pub rubric_version: Option<String>,
    /// Presumably the number of judge samples per evaluation — TODO confirm.
    pub samples: Option<u32>,
}
504
/// Serde default for `min_score` on `SemanticSimilarityTo`, `Faithfulness`
/// and `Relevance` expectations.
fn default_min_score() -> f64 {
    0.8_f64
}
508
/// A model response together with where it came from.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Response text.
    pub text: String,
    /// Provider identifier.
    pub provider: String,
    /// Model identifier.
    pub model: String,
    /// Whether the response was served from cache.
    pub cached: bool,
    /// Extra provider metadata; defaults to JSON `null` when absent.
    #[serde(default)]
    pub meta: serde_json::Value,
}
518
/// Outcome of a test. Serialized as snake_case (e.g. `allowed_on_error`).
/// CI classification is provided by `is_passing` / `is_blocking`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum TestStatus {
    Pass,
    Fail,
    Flaky,
    Warn,
    Error,
    Skipped,
    Unstable,
    /// New: Action was allowed despite error (fail-open mode)
    AllowedOnError,
}
532
533impl TestStatus {
534    pub fn parse(s: &str) -> Self {
535        match s {
536            "pass" => TestStatus::Pass,
537            "fail" => TestStatus::Fail,
538            "flaky" => TestStatus::Flaky,
539            "warn" => TestStatus::Warn,
540            "error" => TestStatus::Error,
541            "skipped" => TestStatus::Skipped,
542            "unstable" => TestStatus::Unstable,
543            "allowed_on_error" => TestStatus::AllowedOnError,
544            _ => TestStatus::Error,
545        }
546    }
547
548    /// Returns true if this status should be treated as passing for CI purposes
549    pub fn is_passing(&self) -> bool {
550        matches!(
551            self,
552            TestStatus::Pass | TestStatus::AllowedOnError | TestStatus::Warn
553        )
554    }
555
556    /// Returns true if this status should block CI
557    pub fn is_blocking(&self) -> bool {
558        matches!(self, TestStatus::Fail | TestStatus::Error)
559    }
560}
561
/// Final result of one test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestResultRow {
    pub test_id: String,
    pub status: TestStatus,
    /// Score, when the check produces one.
    pub score: Option<f64>,
    /// Whether the result came from cache.
    pub cached: bool,
    /// Human-readable message.
    pub message: String,
    /// Extra JSON details; defaults to `null` when absent.
    #[serde(default)]
    pub details: serde_json::Value,
    pub duration_ms: Option<u64>,
    #[serde(default)]
    pub fingerprint: Option<String>,
    #[serde(default)]
    pub skip_reason: Option<String>,
    /// Individual attempts, when the test was run more than once — presumably; confirm.
    #[serde(default)]
    pub attempts: Option<Vec<AttemptRow>>,
    /// Error policy that was applied (if error occurred)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error_policy_applied: Option<ErrorPolicy>,
}
582
/// Result of a single attempt of a test (see `TestResultRow::attempts`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttemptRow {
    /// Attempt number — NOTE(review): whether it starts at 0 or 1 is not defined here.
    pub attempt_no: u32,
    pub status: TestStatus,
    pub message: String,
    pub duration_ms: Option<u64>,
    /// Extra JSON details; defaults to `null` when absent.
    #[serde(default)]
    pub details: serde_json::Value,
}
592
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal test case with the given expectation.
    fn case_with(expected: Expected) -> TestCase {
        TestCase {
            id: "t1".into(),
            input: TestInput {
                prompt: "hi".into(),
                context: None,
            },
            expected,
            ..Default::default()
        }
    }

    /// Builds a config with the given version and tests, everything else default.
    fn config_with(version: u32, tests: Vec<TestCase>) -> EvalConfig {
        EvalConfig {
            version,
            suite: "test".into(),
            model: "test".into(),
            settings: Settings::default(),
            thresholds: Default::default(),
            tests,
        }
    }

    #[test]
    fn test_string_input_deserialize() {
        let yaml = r#"
            id: test1
            input: "simple string"
            expected:
              type: must_contain
              must_contain: ["foo"]
        "#;
        let tc: TestCase = serde_yaml::from_str(yaml).expect("failed to parse");
        assert_eq!(tc.input.prompt, "simple string");
    }

    #[test]
    fn test_map_input_deserialize() {
        let yaml = r#"
            id: test1
            input:
              prompt: "hello"
              context: ["a", "b"]
            expected:
              type: must_contain
              must_contain: ["foo"]
        "#;
        let tc: TestCase = serde_yaml::from_str(yaml).expect("failed to parse");
        assert_eq!(tc.input.prompt, "hello");
        assert_eq!(
            tc.input.context,
            Some(vec!["a".to_string(), "b".to_string()])
        );
    }

    #[test]
    fn test_legacy_list_expected() {
        let yaml = r#"
            id: test1
            input: "test"
            expected:
              - must_contain: "Paris"
              - must_not_contain: "London"
        "#;
        let tc: TestCase = serde_yaml::from_str(yaml).expect("failed to parse");
        if let Expected::MustContain { must_contain } = tc.expected {
            assert_eq!(must_contain, vec!["Paris"]);
        } else {
            panic!("Expected MustContain, got {:?}", tc.expected);
        }
    }

    #[test]
    fn test_scalar_must_contain_promotion() {
        let yaml = r#"
            id: test1
            input: "test"
            expected:
              - must_contain: "single value"
        "#;
        let tc: TestCase = serde_yaml::from_str(yaml).unwrap();
        if let Expected::MustContain { must_contain } = tc.expected {
            assert_eq!(must_contain, vec!["single value"]);
        } else {
            panic!("Expected MustContain");
        }
    }

    #[test]
    fn test_validate_ref_in_v1() {
        let config = config_with(
            1,
            vec![case_with(Expected::Reference {
                path: "foo.yaml".into(),
            })],
        );
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_allows_ref_in_legacy_config() {
        let config = config_with(
            0,
            vec![case_with(Expected::Reference {
                path: "foo.yaml".into(),
            })],
        );
        assert!(config.is_legacy());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_status_parse_known_and_unknown() {
        assert_eq!(TestStatus::parse("pass"), TestStatus::Pass);
        assert_eq!(TestStatus::parse("unstable"), TestStatus::Unstable);
        assert_eq!(
            TestStatus::parse("allowed_on_error"),
            TestStatus::AllowedOnError
        );
        // Unknown or differently-cased names fall back to Error.
        assert_eq!(TestStatus::parse("PASS"), TestStatus::Error);
        assert_eq!(TestStatus::parse(""), TestStatus::Error);
    }

    #[test]
    fn test_status_ci_classification() {
        assert!(TestStatus::Pass.is_passing());
        assert!(TestStatus::Warn.is_passing());
        assert!(TestStatus::AllowedOnError.is_passing());
        assert!(!TestStatus::Fail.is_passing());
        assert!(TestStatus::Fail.is_blocking());
        assert!(TestStatus::Error.is_blocking());
        assert!(!TestStatus::Flaky.is_blocking());
        assert!(!TestStatus::Skipped.is_blocking());
    }

    #[test]
    fn test_resolve_alias() {
        let yaml = r#"
            version: "1"
            aliases:
              Search: [SearchWeb, SearchKnowledgeBase]
        "#;
        let policy: Policy = serde_yaml::from_str(yaml).expect("failed to parse policy");
        assert_eq!(
            policy.resolve_alias("Search"),
            vec!["SearchWeb", "SearchKnowledgeBase"]
        );
        assert_eq!(policy.resolve_alias("Calculator"), vec!["Calculator"]);
    }
}