Skip to main content

assay_core/model/
serde.rs

1use crate::on_error::ErrorPolicy;
2use serde::Deserialize;
3
4use super::types::{Expected, TestCase, TestInput};
5
6impl<'de> Deserialize<'de> for TestCase {
7    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
8    where
9        D: serde::Deserializer<'de>,
10    {
11        #[derive(Deserialize)]
12        #[serde(deny_unknown_fields)]
13        struct RawTestCase {
14            id: String,
15            input: TestInput,
16            #[serde(default)]
17            expected: Option<serde_json::Value>,
18            assertions: Option<Vec<crate::agent_assertions::model::TraceAssertion>>,
19            #[serde(default)]
20            on_error: Option<ErrorPolicy>,
21            #[serde(default)]
22            tags: Vec<String>,
23            metadata: Option<serde_json::Value>,
24        }
25
26        let raw = RawTestCase::deserialize(deserializer)?;
27        let mut expected_main = Expected::default();
28        let extra_assertions = raw.assertions.unwrap_or_default();
29
30        if let Some(val) = raw.expected {
31            if let Some(arr) = val.as_array() {
32                // Legacy list format
33                for (i, item) in arr.iter().enumerate() {
34                    // Try to parse as Expected
35                    // Try to parse as Expected (Strict V1)
36                    if let Ok(exp) = serde_json::from_value::<Expected>(item.clone()) {
37                        if i == 0 {
38                            expected_main = exp;
39                        }
40                    } else if let Some(obj) = item.as_object() {
41                        // Try Legacy Heuristics
42                        let mut parsed = None;
43                        let mut matched_keys = Vec::new();
44
45                        if let Some(r) = obj.get("$ref") {
46                            parsed = Some(Expected::Reference {
47                                path: r.as_str().unwrap_or("").to_string(),
48                            });
49                            matched_keys.push("$ref");
50                        }
51
52                        // Don't chain else-ifs, check all to detect ambiguity
53                        if let Some(mc) = obj.get("must_contain") {
54                            let val = if mc.is_string() {
55                                vec![mc.as_str().unwrap().to_string()]
56                            } else {
57                                serde_json::from_value(mc.clone()).unwrap_or_default()
58                            };
59                            // Last match wins for parsed, but we warn below
60                            if parsed.is_none() {
61                                parsed = Some(Expected::MustContain { must_contain: val });
62                            }
63                            matched_keys.push("must_contain");
64                        }
65
66                        if obj.get("sequence").is_some() {
67                            if parsed.is_none() {
68                                parsed = Some(Expected::SequenceValid {
69                                    policy: None,
70                                    sequence: serde_json::from_value(
71                                        obj.get("sequence").unwrap().clone(),
72                                    )
73                                    .ok(),
74                                    rules: None,
75                                });
76                            }
77                            matched_keys.push("sequence");
78                        }
79
80                        if obj.get("schema").is_some() {
81                            if parsed.is_none() {
82                                parsed = Some(Expected::ArgsValid {
83                                    policy: None,
84                                    schema: obj.get("schema").cloned(),
85                                });
86                            }
87                            matched_keys.push("schema");
88                        }
89
90                        if matched_keys.len() > 1 {
91                            eprintln!(
92                                "WARN: Ambiguous legacy expected block. Found keys: {:?}. Using first match.",
93                                matched_keys
94                            );
95                        }
96
97                        if let Some(p) = parsed {
98                            if i == 0 {
99                                expected_main = p;
100                            }
101                            // else: drop or move to assertions (out of scope for quick fix, primary policy is priority)
102                        }
103                    }
104                }
105            } else {
106                // Try V1 single object
107                if let Ok(exp) = serde_json::from_value(val.clone()) {
108                    expected_main = exp;
109                }
110            }
111        }
112
113        Ok(TestCase {
114            id: raw.id,
115            input: raw.input,
116            expected: expected_main,
117            assertions: if extra_assertions.is_empty() {
118                None
119            } else {
120                Some(extra_assertions)
121            },
122            on_error: raw.on_error,
123            tags: raw.tags,
124            metadata: raw.metadata,
125        })
126    }
127}
128
129impl<'de> Deserialize<'de> for TestInput {
130    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
131    where
132        D: serde::Deserializer<'de>,
133    {
134        struct TestInputVisitor;
135
136        impl<'de> serde::de::Visitor<'de> for TestInputVisitor {
137            type Value = TestInput;
138
139            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140                formatter.write_str("string or struct TestInput")
141            }
142
143            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
144            where
145                E: serde::de::Error,
146            {
147                Ok(TestInput {
148                    prompt: value.to_owned(),
149                    context: None,
150                })
151            }
152
153            fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
154            where
155                A: serde::de::MapAccess<'de>,
156            {
157                // Default derivation logic manually implemented or use intermediate struct
158                // Using intermediate struct is easier to avoid massive boilerplate
159                #[derive(Deserialize)]
160                struct Helper {
161                    prompt: String,
162                    #[serde(default)]
163                    context: Option<Vec<String>>,
164                }
165                let helper =
166                    Helper::deserialize(serde::de::value::MapAccessDeserializer::new(map))?;
167                Ok(TestInput {
168                    prompt: helper.prompt,
169                    context: helper.context,
170                })
171            }
172        }
173
174        deserializer.deserialize_any(TestInputVisitor)
175    }
176}