1use serde::{Deserialize, Serialize};
4use std::fmt;
5
/// A single evaluation case, generic over its inputs, expected output and metadata.
///
/// `Output` and `Metadata` default to `()`, so a case that only carries inputs
/// can be written as `Case<MyInputs>`.
#[derive(Clone, Debug)]
pub struct Case<Inputs, Output = (), Metadata = ()> {
    /// Optional human-readable name; [`Case::display_name`] supplies a fallback.
    pub name: Option<String>,
    /// The inputs for this case.
    pub inputs: Inputs,
    /// The expected output, if one was provided.
    pub expected_output: Option<Output>,
    /// Caller-defined metadata attached to the case, if any.
    pub metadata: Option<Metadata>,
    /// Free-form labels, queried with [`Case::has_tag`].
    pub tags: Vec<String>,
}

impl<Inputs, Output, Metadata> Case<Inputs, Output, Metadata> {
    /// Creates an unnamed, untagged case around `inputs` with no expected
    /// output and no metadata.
    pub fn new(inputs: Inputs) -> Self {
        Self {
            inputs,
            name: None,
            expected_output: None,
            metadata: None,
            tags: Vec::new(),
        }
    }

    /// Sets the case name, replacing any previous one.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Sets the expected output, replacing any previous one.
    pub fn with_expected_output(self, output: Output) -> Self {
        Self {
            expected_output: Some(output),
            ..self
        }
    }

    /// Sets the metadata, replacing any previous value.
    pub fn with_metadata(self, metadata: Metadata) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Appends a single tag after any existing tags.
    pub fn with_tag(self, tag: impl Into<String>) -> Self {
        self.with_tags(std::iter::once(tag))
    }

    /// Appends every tag from `tags`, in iteration order, after any existing tags.
    pub fn with_tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
        for tag in tags {
            self.tags.push(tag.into());
        }
        self
    }

    /// Returns the case name, or `case_<index>` when the case is unnamed.
    pub fn display_name(&self, index: usize) -> String {
        match &self.name {
            Some(name) => name.clone(),
            None => format!("case_{}", index),
        }
    }

    /// Returns `true` if any tag equals `tag` exactly (case-sensitive).
    pub fn has_tag(&self, tag: &str) -> bool {
        self.tags
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == tag)
    }
}

impl<Inputs, Output, Metadata> Default for Case<Inputs, Output, Metadata>
where
    Inputs: Default,
{
    /// Equivalent to `Case::new(Inputs::default())`.
    fn default() -> Self {
        Self::new(Inputs::default())
    }
}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct EvalCase {
85 #[serde(default)]
87 pub name: Option<String>,
88 pub input: String,
90 #[serde(default)]
92 pub expected: Vec<Expected>,
93 #[serde(default)]
95 pub tags: Vec<String>,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100#[serde(tag = "type", rename_all = "snake_case")]
101pub enum Expected {
102 Exact {
104 value: String,
106 },
107 Contains {
109 pattern: String,
111 },
112 Regex {
114 pattern: String,
116 },
117 Semantic {
119 text: String,
121 threshold: f32,
123 },
124 Custom {
126 name: String,
128 },
129}
130
131impl Expected {
132 pub fn exact(s: impl Into<String>) -> Self {
134 Self::Exact { value: s.into() }
135 }
136
137 pub fn contains(s: impl Into<String>) -> Self {
139 Self::Contains { pattern: s.into() }
140 }
141
142 pub fn regex(pattern: impl Into<String>) -> Self {
144 Self::Regex {
145 pattern: pattern.into(),
146 }
147 }
148
149 pub fn semantic(text: impl Into<String>, threshold: f32) -> Self {
151 Self::Semantic {
152 text: text.into(),
153 threshold,
154 }
155 }
156}
157
158impl fmt::Display for Expected {
159 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160 match self {
161 Self::Exact { value } => write!(f, "exact({})", value),
162 Self::Contains { pattern } => write!(f, "contains({})", pattern),
163 Self::Regex { pattern } => write!(f, "regex({})", pattern),
164 Self::Semantic { text, threshold } => {
165 write!(f, "semantic({}, threshold={})", text, threshold)
166 }
167 Self::Custom { name } => write!(f, "custom({})", name),
168 }
169 }
170}
171
172impl EvalCase {
173 pub fn new() -> Self {
175 Self {
176 name: None,
177 input: String::new(),
178 expected: Vec::new(),
179 tags: Vec::new(),
180 }
181 }
182
183 pub fn input(mut self, input: impl Into<String>) -> Self {
185 self.input = input.into();
186 self
187 }
188
189 pub fn expected_exact(mut self, s: impl Into<String>) -> Self {
191 self.expected.push(Expected::exact(s));
192 self
193 }
194
195 pub fn expected_contains(mut self, s: impl Into<String>) -> Self {
197 self.expected.push(Expected::contains(s));
198 self
199 }
200
201 pub fn expected_regex(mut self, pattern: impl Into<String>) -> Self {
203 self.expected.push(Expected::regex(pattern));
204 self
205 }
206
207 pub fn expected_semantic(mut self, text: impl Into<String>, threshold: f32) -> Self {
209 self.expected.push(Expected::semantic(text, threshold));
210 self
211 }
212
213 pub fn name(mut self, name: impl Into<String>) -> Self {
215 self.name = Some(name.into());
216 self
217 }
218
219 pub fn tag(mut self, tag: impl Into<String>) -> Self {
221 self.tags.push(tag.into());
222 self
223 }
224
225 pub fn check(&self, output: &str) -> Vec<(&Expected, bool)> {
227 self.expected
228 .iter()
229 .map(|exp| {
230 let passed = match exp {
231 Expected::Exact { value } => output == value,
232 Expected::Contains { pattern } => output.contains(pattern),
233 Expected::Regex { pattern } => regex::Regex::new(pattern)
234 .map(|re| re.is_match(output))
235 .unwrap_or(false),
236 Expected::Semantic { .. } => false, Expected::Custom { .. } => false, };
239 (exp, passed)
240 })
241 .collect()
242 }
243
244 pub fn all_pass(&self, output: &str) -> bool {
246 self.check(output).iter().all(|(_, passed)| *passed)
247 }
248}
249
250impl Default for EvalCase {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_case_new() {
        let case: Case<String, String, ()> = Case::new("test input".to_string())
            .with_name("test case")
            .with_expected_output("expected".to_string())
            .with_tag("unit");

        assert_eq!(case.name, Some("test case".to_string()));
        assert_eq!(case.inputs, "test input");
        assert!(case.has_tag("unit"));
    }

    #[test]
    fn test_case_display_name() {
        let case: Case<String> = Case::new("input".to_string());
        assert_eq!(case.display_name(0), "case_0");

        let named = case.with_name("my_case");
        assert_eq!(named.display_name(0), "my_case");
    }

    #[test]
    fn test_eval_case_check() {
        let case = EvalCase::new()
            .input("What is 2+2?")
            .expected_contains("4")
            .expected_contains("four");

        let results = case.check("The answer is 4");
        assert_eq!(results.len(), 2);
        assert!(results[0].1); // "4" is present
        assert!(!results[1].1); // "four" is not
    }

    #[test]
    fn test_eval_case_exact() {
        let case = EvalCase::new().input("test").expected_exact("hello");

        assert!(case.all_pass("hello"));
        assert!(!case.all_pass("hello world"));
    }

    #[test]
    fn test_eval_case_regex() {
        let case = EvalCase::new().input("q").expected_regex(r"^\d+$");

        assert!(case.all_pass("42"));
        assert!(!case.all_pass("42 apples"));
    }

    #[test]
    fn test_eval_case_invalid_regex_fails() {
        // An unparsable pattern is reported as a failed check, not a panic.
        let case = EvalCase::new().expected_regex("(unclosed");

        let results = case.check("anything");
        assert_eq!(results.len(), 1);
        assert!(!results[0].1);
        assert!(!case.all_pass("anything"));
    }

    #[test]
    fn test_eval_case_unevaluated_kinds_fail() {
        let mut case = EvalCase::new().expected_semantic("hello", 0.5);
        case.expected.push(Expected::Custom {
            name: "judge".to_string(),
        });

        let results = case.check("hello");
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|(_, passed)| !*passed));
    }

    #[test]
    fn test_all_pass_with_no_expectations() {
        assert!(EvalCase::new().all_pass("whatever"));
    }

    #[test]
    fn test_expected_display() {
        assert_eq!(Expected::exact("foo").to_string(), "exact(foo)");
        assert_eq!(Expected::contains("bar").to_string(), "contains(bar)");
        assert_eq!(Expected::regex("a+").to_string(), "regex(a+)");
        assert_eq!(
            Expected::semantic("hi", 0.8).to_string(),
            "semantic(hi, threshold=0.8)"
        );
        assert_eq!(
            Expected::Custom {
                name: "judge".to_string()
            }
            .to_string(),
            "custom(judge)"
        );
    }

    #[test]
    fn test_eval_case_serialize() {
        let case = EvalCase::new().input("hello").expected_contains("world");

        let json = serde_json::to_string(&case).unwrap();
        assert!(json.contains("hello"));
        assert!(json.contains("contains"));
    }

    #[test]
    fn test_eval_case_deserialize_defaults() {
        let case: EvalCase = serde_json::from_str(r#"{"input": "hi"}"#).unwrap();

        assert_eq!(case.input, "hi");
        assert!(case.name.is_none());
        assert!(case.expected.is_empty());
        assert!(case.tags.is_empty());
    }

    #[test]
    fn test_eval_case_json_round_trip() {
        let original = EvalCase::new()
            .name("n")
            .input("hello world")
            .expected_contains("world")
            .expected_regex("^hello")
            .tag("smoke");

        let json = serde_json::to_string(&original).unwrap();
        let restored: EvalCase = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.name.as_deref(), Some("n"));
        assert_eq!(restored.input, "hello world");
        assert_eq!(restored.tags, vec!["smoke".to_string()]);
        assert_eq!(restored.expected.len(), 2);
        assert!(restored.all_pass("hello world"));
    }
}