// assay_core/model.rs

1use crate::on_error::ErrorPolicy;
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5#[serde(deny_unknown_fields)]
6pub struct EvalConfig {
7    #[serde(default, rename = "configVersion", alias = "version")]
8    pub version: u32,
9    pub suite: String,
10    pub model: String,
11    #[serde(default, skip_serializing_if = "is_default_settings")]
12    pub settings: Settings,
13    #[serde(default, skip_serializing_if = "is_default_thresholds")]
14    pub thresholds: crate::thresholds::ThresholdConfig,
15    #[serde(default, skip_serializing_if = "is_default_otel")]
16    pub otel: crate::config::otel::OtelConfig,
17    pub tests: Vec<TestCase>,
18}
19
20fn is_default_otel(o: &crate::config::otel::OtelConfig) -> bool {
21    o == &crate::config::otel::OtelConfig::default()
22}
23
24impl EvalConfig {
25    pub fn is_legacy(&self) -> bool {
26        self.version == 0
27    }
28
29    pub fn has_legacy_usage(&self) -> bool {
30        self.tests
31            .iter()
32            .any(|t| t.expected.get_policy_path().is_some())
33    }
34
35    pub fn validate(&self) -> anyhow::Result<()> {
36        if self.version >= 1 {
37            for test in &self.tests {
38                if matches!(test.expected, Expected::Reference { .. }) {
39                    anyhow::bail!("$ref in expected block is not allowed in configVersion >= 1. Run `assay migrate` to inline policies.");
40                }
41            }
42        }
43        Ok(())
44    }
45
46    /// Get the effective error policy for a test.
47    /// Test-level on_error overrides suite-level settings.
48    pub fn effective_error_policy(&self, test: &TestCase) -> ErrorPolicy {
49        test.on_error.unwrap_or(self.settings.on_error)
50    }
51}
52
53fn is_default_thresholds(t: &crate::thresholds::ThresholdConfig) -> bool {
54    t == &crate::thresholds::ThresholdConfig::default()
55}
56
57#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
58pub struct Settings {
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub parallel: Option<usize>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub timeout_seconds: Option<u64>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub cache: Option<bool>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub seed: Option<u64>,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub judge: Option<JudgeConfig>,
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub thresholding: Option<ThresholdingSettings>,
71
72    /// Global error handling policy (default: block)
73    /// Can be overridden per-test
74    #[serde(default, skip_serializing_if = "is_default_error_policy")]
75    pub on_error: ErrorPolicy,
76
77    /// Bail on first failure (useful for CI)
78    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
79    pub bail_on_first_failure: bool,
80}
81
82fn is_default_error_policy(p: &ErrorPolicy) -> bool {
83    *p == ErrorPolicy::default()
84}
85
86fn is_default_settings(s: &Settings) -> bool {
87    s == &Settings::default()
88}
89
90#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
91pub struct ThresholdingSettings {
92    pub mode: Option<String>,
93    pub max_drop: Option<f64>,
94    pub min_floor: Option<f64>,
95}
96
97#[derive(Debug, Clone, Default, Serialize)]
98pub struct TestCase {
99    pub id: String,
100    pub input: TestInput,
101    pub expected: Expected,
102    #[serde(default, skip_serializing_if = "Option::is_none")]
103    pub assertions: Option<Vec<crate::agent_assertions::model::TraceAssertion>>,
104    /// Per-test error handling policy override
105    /// If None, uses settings.on_error
106    #[serde(default, skip_serializing_if = "Option::is_none")]
107    pub on_error: Option<ErrorPolicy>,
108    #[serde(default, skip_serializing_if = "Vec::is_empty")]
109    pub tags: Vec<String>,
110    #[serde(default, skip_serializing_if = "Option::is_none")]
111    pub metadata: Option<serde_json::Value>,
112}
113
114impl<'de> Deserialize<'de> for TestCase {
115    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
116    where
117        D: serde::Deserializer<'de>,
118    {
119        #[derive(Deserialize)]
120        #[serde(deny_unknown_fields)]
121        struct RawTestCase {
122            id: String,
123            input: TestInput,
124            #[serde(default)]
125            expected: Option<serde_json::Value>,
126            assertions: Option<Vec<crate::agent_assertions::model::TraceAssertion>>,
127            #[serde(default)]
128            on_error: Option<ErrorPolicy>,
129            #[serde(default)]
130            tags: Vec<String>,
131            metadata: Option<serde_json::Value>,
132        }
133
134        let raw = RawTestCase::deserialize(deserializer)?;
135        let mut expected_main = Expected::default();
136        let extra_assertions = raw.assertions.unwrap_or_default();
137
138        if let Some(val) = raw.expected {
139            if let Some(arr) = val.as_array() {
140                // Legacy list format
141                for (i, item) in arr.iter().enumerate() {
142                    // Try to parse as Expected
143                    // Try to parse as Expected (Strict V1)
144                    if let Ok(exp) = serde_json::from_value::<Expected>(item.clone()) {
145                        if i == 0 {
146                            expected_main = exp;
147                        }
148                    } else if let Some(obj) = item.as_object() {
149                        // Try Legacy Heuristics
150                        let mut parsed = None;
151                        let mut matched_keys = Vec::new();
152
153                        if let Some(r) = obj.get("$ref") {
154                            parsed = Some(Expected::Reference {
155                                path: r.as_str().unwrap_or("").to_string(),
156                            });
157                            matched_keys.push("$ref");
158                        }
159
160                        // Don't chain else-ifs, check all to detect ambiguity
161                        if let Some(mc) = obj.get("must_contain") {
162                            let val = if mc.is_string() {
163                                vec![mc.as_str().unwrap().to_string()]
164                            } else {
165                                serde_json::from_value(mc.clone()).unwrap_or_default()
166                            };
167                            // Last match wins for parsed, but we warn below
168                            if parsed.is_none() {
169                                parsed = Some(Expected::MustContain { must_contain: val });
170                            }
171                            matched_keys.push("must_contain");
172                        }
173
174                        if obj.get("sequence").is_some() {
175                            if parsed.is_none() {
176                                parsed = Some(Expected::SequenceValid {
177                                    policy: None,
178                                    sequence: serde_json::from_value(
179                                        obj.get("sequence").unwrap().clone(),
180                                    )
181                                    .ok(),
182                                    rules: None,
183                                });
184                            }
185                            matched_keys.push("sequence");
186                        }
187
188                        if obj.get("schema").is_some() {
189                            if parsed.is_none() {
190                                parsed = Some(Expected::ArgsValid {
191                                    policy: None,
192                                    schema: obj.get("schema").cloned(),
193                                });
194                            }
195                            matched_keys.push("schema");
196                        }
197
198                        if matched_keys.len() > 1 {
199                            eprintln!("WARN: Ambiguous legacy expected block. Found keys: {:?}. Using first match.", matched_keys);
200                        }
201
202                        if let Some(p) = parsed {
203                            if i == 0 {
204                                expected_main = p;
205                            }
206                            // else: drop or move to assertions (out of scope for quick fix, primary policy is priority)
207                        }
208                    }
209                }
210            } else {
211                // Try V1 single object
212                if let Ok(exp) = serde_json::from_value(val.clone()) {
213                    expected_main = exp;
214                }
215            }
216        }
217
218        Ok(TestCase {
219            id: raw.id,
220            input: raw.input,
221            expected: expected_main,
222            assertions: if extra_assertions.is_empty() {
223                None
224            } else {
225                Some(extra_assertions)
226            },
227            on_error: raw.on_error,
228            tags: raw.tags,
229            metadata: raw.metadata,
230        })
231    }
232}
233
234#[derive(Debug, Clone, Default, Serialize)]
235pub struct TestInput {
236    pub prompt: String,
237    #[serde(default, skip_serializing_if = "Option::is_none")]
238    pub context: Option<Vec<String>>,
239}
240
241impl<'de> Deserialize<'de> for TestInput {
242    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
243    where
244        D: serde::Deserializer<'de>,
245    {
246        struct TestInputVisitor;
247
248        impl<'de> serde::de::Visitor<'de> for TestInputVisitor {
249            type Value = TestInput;
250
251            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252                formatter.write_str("string or struct TestInput")
253            }
254
255            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
256            where
257                E: serde::de::Error,
258            {
259                Ok(TestInput {
260                    prompt: value.to_owned(),
261                    context: None,
262                })
263            }
264
265            fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
266            where
267                A: serde::de::MapAccess<'de>,
268            {
269                // Default derivation logic manually implemented or use intermediate struct
270                // Using intermediate struct is easier to avoid massive boilerplate
271                #[derive(Deserialize)]
272                struct Helper {
273                    prompt: String,
274                    #[serde(default)]
275                    context: Option<Vec<String>>,
276                }
277                let helper =
278                    Helper::deserialize(serde::de::value::MapAccessDeserializer::new(map))?;
279                Ok(TestInput {
280                    prompt: helper.prompt,
281                    context: helper.context,
282                })
283            }
284        }
285
286        deserializer.deserialize_any(TestInputVisitor)
287    }
288}
289
290#[derive(Debug, Clone, Serialize, Deserialize)]
291#[serde(rename_all = "snake_case", tag = "type")]
292pub enum Expected {
293    MustContain {
294        #[serde(default)]
295        must_contain: Vec<String>,
296    },
297    MustNotContain {
298        #[serde(default)]
299        must_not_contain: Vec<String>,
300    },
301
302    RegexMatch {
303        pattern: String,
304        #[serde(default)]
305        flags: Vec<String>,
306    },
307    RegexNotMatch {
308        pattern: String,
309        #[serde(default)]
310        flags: Vec<String>,
311    },
312
313    JsonSchema {
314        json_schema: String,
315        #[serde(default)]
316        schema_file: Option<String>,
317    },
318    SemanticSimilarityTo {
319        // canonical field
320        #[serde(alias = "text")]
321        semantic_similarity_to: String,
322
323        // canonical field
324        #[serde(default = "default_min_score", alias = "threshold")]
325        min_score: f64,
326
327        #[serde(default)]
328        thresholding: Option<ThresholdingConfig>,
329    },
330    JudgeCriteria {
331        judge_criteria: serde_json::Value,
332    },
333    Faithfulness {
334        #[serde(default = "default_min_score")]
335        min_score: f64,
336        rubric_version: Option<String>,
337        #[serde(default)]
338        thresholding: Option<ThresholdingConfig>,
339    },
340    Relevance {
341        #[serde(default = "default_min_score")]
342        min_score: f64,
343        rubric_version: Option<String>,
344        #[serde(default)]
345        thresholding: Option<ThresholdingConfig>,
346    },
347
348    ArgsValid {
349        #[serde(skip_serializing_if = "Option::is_none")]
350        policy: Option<String>,
351        #[serde(default, skip_serializing_if = "Option::is_none")]
352        schema: Option<serde_json::Value>,
353    },
354    SequenceValid {
355        #[serde(skip_serializing_if = "Option::is_none")]
356        policy: Option<String>,
357        #[serde(default, skip_serializing_if = "Option::is_none")]
358        sequence: Option<Vec<String>>,
359        #[serde(default, skip_serializing_if = "Option::is_none")]
360        rules: Option<Vec<SequenceRule>>,
361    },
362    ToolBlocklist {
363        blocked: Vec<String>,
364    },
365    // For migration/legacy support
366    #[serde(rename = "$ref")]
367    Reference {
368        path: String,
369    },
370}
371
372impl Default for Expected {
373    fn default() -> Self {
374        Expected::MustContain {
375            must_contain: vec![],
376        }
377    }
378}
379
380#[derive(Debug, Clone, Serialize, Deserialize)]
381#[serde(deny_unknown_fields)]
382pub struct Policy {
383    pub version: String,
384    #[serde(default)]
385    pub name: String,
386    #[serde(default)]
387    pub metadata: Option<serde_json::Value>,
388    #[serde(default)]
389    pub tools: ToolsPolicy,
390    #[serde(default)]
391    pub sequences: Vec<SequenceRule>,
392    #[serde(default)]
393    pub aliases: std::collections::HashMap<String, Vec<String>>,
394    #[serde(default)]
395    pub on_error: ErrorPolicy,
396}
397
398#[derive(Debug, Clone, Default, Serialize, Deserialize)]
399#[serde(deny_unknown_fields)]
400pub struct ToolsPolicy {
401    #[serde(default)]
402    pub allow: Option<Vec<String>>,
403    #[serde(default)]
404    pub deny: Option<Vec<String>>,
405    #[serde(default)]
406    pub require_args: Option<std::collections::HashMap<String, Vec<String>>>,
407    #[serde(default)]
408    pub arg_constraints: Option<
409        std::collections::HashMap<String, std::collections::HashMap<String, serde_json::Value>>,
410    >,
411}
412
413#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
414#[serde(tag = "type", rename_all = "snake_case")]
415pub enum SequenceRule {
416    Require {
417        tool: String,
418    },
419    Eventually {
420        tool: String,
421        within: u32,
422    },
423    MaxCalls {
424        tool: String,
425        max: u32,
426    },
427    Before {
428        first: String,
429        then: String,
430    },
431    After {
432        trigger: String,
433        then: String,
434        #[serde(default = "default_one")]
435        within: u32,
436    },
437    NeverAfter {
438        trigger: String,
439        forbidden: String,
440    },
441    Sequence {
442        tools: Vec<String>,
443        #[serde(default)]
444        strict: bool,
445    },
446    Blocklist {
447        pattern: String,
448    },
449}
450
/// Serde default provider returning `1` (used for `SequenceRule::After::within`).
fn default_one() -> u32 {
    1_u32
}
454
455// Helper for alias resolution
456impl Policy {
457    pub fn load<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<Self> {
458        let content = std::fs::read_to_string(path)?;
459        let policy: Policy = serde_yaml::from_str(&content)?;
460        Ok(policy)
461    }
462
463    pub fn resolve_alias(&self, tool_name: &str) -> Vec<String> {
464        if let Some(members) = self.aliases.get(tool_name) {
465            members.clone()
466        } else {
467            // If not an alias, return strict singleton if no alias found?
468            // RFC says: "Matches SearchKnowledgeBase OR SearchWeb".
469            // "Alias can be used anywhere a tool name is expected".
470            // If we rely on resolve_alias to return all matches for a "rule target",
471            // AND we want to support literals:
472            // If 'Search' is in aliases, satisfy if match any alias member.
473            // If 'Search' is NOT in aliases, it's a literal.
474            vec![tool_name.to_string()]
475        }
476    }
477}
478
479impl Expected {
480    pub fn get_policy_path(&self) -> Option<&str> {
481        match self {
482            Expected::ArgsValid { policy, .. } => policy.as_deref(),
483            Expected::SequenceValid { policy, .. } => policy.as_deref(),
484            _ => None,
485        }
486    }
487
488    /// Per-test thresholding for baseline regression (mode/max_drop) when this Expected variant matches the metric.
489    pub fn thresholding_for_metric(&self, metric_name: &str) -> Option<&ThresholdingConfig> {
490        match (metric_name, self) {
491            ("semantic_similarity_to", Expected::SemanticSimilarityTo { thresholding, .. }) => {
492                thresholding.as_ref()
493            }
494            ("faithfulness", Expected::Faithfulness { thresholding, .. }) => thresholding.as_ref(),
495            ("relevance", Expected::Relevance { thresholding, .. }) => thresholding.as_ref(),
496            _ => None,
497        }
498    }
499}
500
501#[derive(Debug, Clone, Serialize, Deserialize)]
502pub struct ToolCallRecord {
503    pub id: String,
504    pub tool_name: String,
505    pub args: serde_json::Value,
506    pub result: Option<serde_json::Value>,
507    pub error: Option<serde_json::Value>,
508    pub index: usize,
509    pub ts_ms: u64,
510}
511
512#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
513pub struct ThresholdingConfig {
514    pub max_drop: Option<f64>,
515}
516
517#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
518pub struct JudgeConfig {
519    pub rubric_version: Option<String>,
520    pub samples: Option<u32>,
521    #[serde(default)]
522    pub reliability: crate::judge::reliability::ReliabilityConfig,
523}
524
/// Serde default provider for `min_score` fields: `0.80`.
fn default_min_score() -> f64 {
    0.80_f64
}
528
529#[derive(Debug, Clone, Default, Serialize, Deserialize)]
530pub struct LlmResponse {
531    pub text: String,
532    pub provider: String,
533    pub model: String,
534    pub cached: bool,
535    #[serde(default)]
536    pub meta: serde_json::Value,
537}
538
539#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
540#[serde(rename_all = "snake_case")]
541pub enum TestStatus {
542    Pass,
543    Fail,
544    Flaky,
545    Warn,
546    Error,
547    Skipped,
548    Unstable,
549    /// Action was allowed despite an upstream error (fail-open mode).
550    AllowedOnError,
551}
552
553impl TestStatus {
554    pub fn parse(s: &str) -> Self {
555        match s {
556            "pass" => TestStatus::Pass,
557            "fail" => TestStatus::Fail,
558            "flaky" => TestStatus::Flaky,
559            "warn" => TestStatus::Warn,
560            "error" => TestStatus::Error,
561            "skipped" => TestStatus::Skipped,
562            "unstable" => TestStatus::Unstable,
563            "allowed_on_error" => TestStatus::AllowedOnError,
564            _ => TestStatus::Error,
565        }
566    }
567
568    /// Returns true if this status should be treated as passing for CI purposes
569    pub fn is_passing(&self) -> bool {
570        matches!(
571            self,
572            TestStatus::Pass | TestStatus::AllowedOnError | TestStatus::Warn
573        )
574    }
575
576    /// Returns true if this status should block CI
577    pub fn is_blocking(&self) -> bool {
578        matches!(self, TestStatus::Fail | TestStatus::Error)
579    }
580}
581
582#[derive(Debug, Clone, Serialize, Deserialize)]
583pub struct TestResultRow {
584    pub test_id: String,
585    pub status: TestStatus,
586    pub score: Option<f64>,
587    pub cached: bool,
588    pub message: String,
589    #[serde(default)]
590    pub details: serde_json::Value,
591    pub duration_ms: Option<u64>,
592    #[serde(default)]
593    pub fingerprint: Option<String>,
594    #[serde(default)]
595    pub skip_reason: Option<String>,
596    #[serde(default)]
597    pub attempts: Option<Vec<AttemptRow>>,
598    /// Error policy that was applied (if error occurred)
599    #[serde(default, skip_serializing_if = "Option::is_none")]
600    pub error_policy_applied: Option<ErrorPolicy>,
601}
602
603#[derive(Debug, Clone, Serialize, Deserialize)]
604pub struct AttemptRow {
605    pub attempt_no: u32,
606    pub status: TestStatus,
607    pub message: String,
608    pub duration_ms: Option<u64>,
609    #[serde(default)]
610    pub details: serde_json::Value,
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare string `input` becomes the prompt.
    #[test]
    fn test_string_input_deserialize() {
        let yaml = r#"
            id: test1
            input: "simple string"
            expected:
              type: must_contain
              must_contain: ["foo"]
        "#;
        let parsed = serde_yaml::from_str::<TestCase>(yaml).expect("failed to parse");
        assert_eq!(parsed.input.prompt, "simple string");
    }

    /// In the legacy list form, only the first item becomes the primary expectation.
    #[test]
    fn test_legacy_list_expected() {
        let yaml = r#"
            id: test1
            input: "test"
            expected:
              - must_contain: "Paris"
              - must_not_contain: "London"
        "#;
        let parsed = serde_yaml::from_str::<TestCase>(yaml).expect("failed to parse");
        match parsed.expected {
            Expected::MustContain { must_contain } => {
                assert_eq!(must_contain, vec!["Paris"]);
            }
            other => panic!("Expected MustContain, got {:?}", other),
        }
    }

    /// A scalar `must_contain` is promoted to a one-element list.
    #[test]
    fn test_scalar_must_contain_promotion() {
        let yaml = r#"
            id: test1
            input: "test"
            expected:
              - must_contain: "single value"
        "#;
        let parsed = serde_yaml::from_str::<TestCase>(yaml).unwrap();
        match parsed.expected {
            Expected::MustContain { must_contain } => {
                assert_eq!(must_contain, vec!["single value"]);
            }
            _ => panic!("Expected MustContain"),
        }
    }

    /// `validate` rejects a `$ref` expectation once the config version is 1.
    #[test]
    fn test_validate_ref_in_v1() {
        let referencing_case = TestCase {
            id: "t1".into(),
            input: TestInput {
                prompt: "hi".into(),
                context: None,
            },
            expected: Expected::Reference {
                path: "foo.yaml".into(),
            },
            assertions: None,
            tags: vec![],
            metadata: None,
            on_error: None,
        };
        let config = EvalConfig {
            version: 1,
            suite: "test".into(),
            model: "test".into(),
            settings: Settings::default(),
            thresholds: Default::default(),
            tests: vec![referencing_case],
            otel: Default::default(),
        };

        let outcome = config.validate();
        assert!(outcome.is_err());
    }

    /// `thresholding_for_metric` only answers for the variant's own metric name.
    #[test]
    fn test_thresholding_for_metric() {
        // Matching metric, but the variant carries no thresholding.
        let without = Expected::SemanticSimilarityTo {
            semantic_similarity_to: "ref".into(),
            min_score: 0.8,
            thresholding: None,
        };
        assert!(without
            .thresholding_for_metric("semantic_similarity_to")
            .is_none());

        // Matching metric with thresholding present.
        let similarity = Expected::SemanticSimilarityTo {
            semantic_similarity_to: "ref".into(),
            min_score: 0.8,
            thresholding: Some(ThresholdingConfig {
                max_drop: Some(0.05),
            }),
        };
        let found = similarity
            .thresholding_for_metric("semantic_similarity_to")
            .unwrap();
        assert_eq!(found.max_drop, Some(0.05));

        // Same value queried under a different metric name.
        assert!(similarity.thresholding_for_metric("faithfulness").is_none());

        // Faithfulness variant answers to its own metric name.
        let faithfulness = Expected::Faithfulness {
            min_score: 0.7,
            rubric_version: None,
            thresholding: Some(ThresholdingConfig {
                max_drop: Some(0.1),
            }),
        };
        let found = faithfulness.thresholding_for_metric("faithfulness").unwrap();
        assert_eq!(found.max_drop, Some(0.1));
    }
}