assay_core/model/
serde.rs1use crate::on_error::ErrorPolicy;
2use serde::Deserialize;
3
4use super::types::{Expected, TestCase, TestInput};
5
6impl<'de> Deserialize<'de> for TestCase {
7 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
8 where
9 D: serde::Deserializer<'de>,
10 {
11 #[derive(Deserialize)]
12 #[serde(deny_unknown_fields)]
13 struct RawTestCase {
14 id: String,
15 input: TestInput,
16 #[serde(default)]
17 expected: Option<serde_json::Value>,
18 assertions: Option<Vec<crate::agent_assertions::model::TraceAssertion>>,
19 #[serde(default)]
20 on_error: Option<ErrorPolicy>,
21 #[serde(default)]
22 tags: Vec<String>,
23 metadata: Option<serde_json::Value>,
24 }
25
26 let raw = RawTestCase::deserialize(deserializer)?;
27 let mut expected_main = Expected::default();
28 let extra_assertions = raw.assertions.unwrap_or_default();
29
30 if let Some(val) = raw.expected {
31 if let Some(arr) = val.as_array() {
32 for (i, item) in arr.iter().enumerate() {
34 if let Ok(exp) = serde_json::from_value::<Expected>(item.clone()) {
37 if i == 0 {
38 expected_main = exp;
39 }
40 } else if let Some(obj) = item.as_object() {
41 let mut parsed = None;
43 let mut matched_keys = Vec::new();
44
45 if let Some(r) = obj.get("$ref") {
46 parsed = Some(Expected::Reference {
47 path: r.as_str().unwrap_or("").to_string(),
48 });
49 matched_keys.push("$ref");
50 }
51
52 if let Some(mc) = obj.get("must_contain") {
54 let val = if mc.is_string() {
55 vec![mc.as_str().unwrap().to_string()]
56 } else {
57 serde_json::from_value(mc.clone()).unwrap_or_default()
58 };
59 if parsed.is_none() {
61 parsed = Some(Expected::MustContain { must_contain: val });
62 }
63 matched_keys.push("must_contain");
64 }
65
66 if obj.get("sequence").is_some() {
67 if parsed.is_none() {
68 parsed = Some(Expected::SequenceValid {
69 policy: None,
70 sequence: serde_json::from_value(
71 obj.get("sequence").unwrap().clone(),
72 )
73 .ok(),
74 rules: None,
75 });
76 }
77 matched_keys.push("sequence");
78 }
79
80 if obj.get("schema").is_some() {
81 if parsed.is_none() {
82 parsed = Some(Expected::ArgsValid {
83 policy: None,
84 schema: obj.get("schema").cloned(),
85 });
86 }
87 matched_keys.push("schema");
88 }
89
90 if matched_keys.len() > 1 {
91 eprintln!(
92 "WARN: Ambiguous legacy expected block. Found keys: {:?}. Using first match.",
93 matched_keys
94 );
95 }
96
97 if let Some(p) = parsed {
98 if i == 0 {
99 expected_main = p;
100 }
101 }
103 }
104 }
105 } else {
106 if let Ok(exp) = serde_json::from_value(val.clone()) {
108 expected_main = exp;
109 }
110 }
111 }
112
113 Ok(TestCase {
114 id: raw.id,
115 input: raw.input,
116 expected: expected_main,
117 assertions: if extra_assertions.is_empty() {
118 None
119 } else {
120 Some(extra_assertions)
121 },
122 on_error: raw.on_error,
123 tags: raw.tags,
124 metadata: raw.metadata,
125 })
126 }
127}
128
129impl<'de> Deserialize<'de> for TestInput {
130 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
131 where
132 D: serde::Deserializer<'de>,
133 {
134 struct TestInputVisitor;
135
136 impl<'de> serde::de::Visitor<'de> for TestInputVisitor {
137 type Value = TestInput;
138
139 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 formatter.write_str("string or struct TestInput")
141 }
142
143 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
144 where
145 E: serde::de::Error,
146 {
147 Ok(TestInput {
148 prompt: value.to_owned(),
149 context: None,
150 })
151 }
152
153 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
154 where
155 A: serde::de::MapAccess<'de>,
156 {
157 #[derive(Deserialize)]
160 struct Helper {
161 prompt: String,
162 #[serde(default)]
163 context: Option<Vec<String>>,
164 }
165 let helper =
166 Helper::deserialize(serde::de::value::MapAccessDeserializer::new(map))?;
167 Ok(TestInput {
168 prompt: helper.prompt,
169 context: helper.context,
170 })
171 }
172 }
173
174 deserializer.deserialize_any(TestInputVisitor)
175 }
176}