Skip to main content

serdes_ai_evals/
case.rs

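//! Evaluation case definitions: the generic [`Case`] type and the legacy string-based [`EvalCase`] with its [`Expected`] output criteria.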
2
3use serde::{Deserialize, Serialize};
4use std::fmt;
5
/// A single evaluation test case.
///
/// Generic over the task's `Inputs`, the `Output` it is expected to produce
/// (defaults to `()`), and evaluator-facing `Metadata` (defaults to `()`).
/// Build one with [`Case::new`] followed by the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct Case<Inputs, Output = (), Metadata = ()> {
    /// Optional human-readable case name; see [`Case::display_name`] for the
    /// fallback used when this is `None`.
    pub name: Option<String>,
    /// Inputs passed to the task under evaluation.
    pub inputs: Inputs,
    /// Expected output, if known.
    pub expected_output: Option<Output>,
    /// Arbitrary metadata made available to evaluators.
    pub metadata: Option<Metadata>,
    /// Tags used for filtering cases.
    pub tags: Vec<String>,
}
20
21impl<Inputs, Output, Metadata> Case<Inputs, Output, Metadata> {
22    /// Create a new case with inputs.
23    pub fn new(inputs: Inputs) -> Self {
24        Self {
25            name: None,
26            inputs,
27            expected_output: None,
28            metadata: None,
29            tags: Vec::new(),
30        }
31    }
32
33    /// Set the case name.
34    pub fn with_name(mut self, name: impl Into<String>) -> Self {
35        self.name = Some(name.into());
36        self
37    }
38
39    /// Set the expected output.
40    pub fn with_expected_output(mut self, output: Output) -> Self {
41        self.expected_output = Some(output);
42        self
43    }
44
45    /// Set the metadata.
46    pub fn with_metadata(mut self, metadata: Metadata) -> Self {
47        self.metadata = Some(metadata);
48        self
49    }
50
51    /// Add a tag.
52    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
53        self.tags.push(tag.into());
54        self
55    }
56
57    /// Add multiple tags.
58    pub fn with_tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
59        self.tags.extend(tags.into_iter().map(Into::into));
60        self
61    }
62
63    /// Get the display name.
64    pub fn display_name(&self, index: usize) -> String {
65        self.name
66            .clone()
67            .unwrap_or_else(|| format!("case_{}", index))
68    }
69
70    /// Check if case has a tag.
71    pub fn has_tag(&self, tag: &str) -> bool {
72        self.tags.iter().any(|t| t == tag)
73    }
74}
75
76impl<Inputs: Default, Output, Metadata> Default for Case<Inputs, Output, Metadata> {
77    fn default() -> Self {
78        Self::new(Inputs::default())
79    }
80}
81
/// Legacy eval case for backward compatibility.
///
/// A string-prompt test case whose [`Expected`] criteria are checked against
/// an output string by [`EvalCase::check`] / [`EvalCase::all_pass`].
/// NOTE(review): presumably superseded by the generic [`Case`]; confirm before
/// extending this type.
///
/// When deserializing, only `input` is required; `name`, `expected` and `tags`
/// fall back to `None` / empty when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalCase {
    /// Test case name.
    #[serde(default)]
    pub name: Option<String>,
    /// Input prompt. Required when deserializing (no `#[serde(default)]`).
    pub input: String,
    /// Expected output criteria; each entry is evaluated independently.
    #[serde(default)]
    pub expected: Vec<Expected>,
    /// Tags for filtering.
    #[serde(default)]
    pub tags: Vec<String>,
}
97
98/// Expected output criteria.
99#[derive(Debug, Clone, Serialize, Deserialize)]
100#[serde(tag = "type", rename_all = "snake_case")]
101pub enum Expected {
102    /// Exact match.
103    Exact {
104        /// Expected value.
105        value: String,
106    },
107    /// Contains substring.
108    Contains {
109        /// Pattern to find.
110        pattern: String,
111    },
112    /// Matches regex pattern.
113    Regex {
114        /// Regex pattern.
115        pattern: String,
116    },
117    /// Semantic similarity above threshold.
118    Semantic {
119        /// Expected text.
120        text: String,
121        /// Minimum similarity score.
122        threshold: f32,
123    },
124    /// Custom check function name.
125    Custom {
126        /// Function name.
127        name: String,
128    },
129}
130
131impl Expected {
132    /// Create an exact match expectation.
133    pub fn exact(s: impl Into<String>) -> Self {
134        Self::Exact { value: s.into() }
135    }
136
137    /// Create a contains expectation.
138    pub fn contains(s: impl Into<String>) -> Self {
139        Self::Contains { pattern: s.into() }
140    }
141
142    /// Create a regex expectation.
143    pub fn regex(pattern: impl Into<String>) -> Self {
144        Self::Regex {
145            pattern: pattern.into(),
146        }
147    }
148
149    /// Create a semantic similarity expectation.
150    pub fn semantic(text: impl Into<String>, threshold: f32) -> Self {
151        Self::Semantic {
152            text: text.into(),
153            threshold,
154        }
155    }
156}
157
158impl fmt::Display for Expected {
159    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160        match self {
161            Self::Exact { value } => write!(f, "exact({})", value),
162            Self::Contains { pattern } => write!(f, "contains({})", pattern),
163            Self::Regex { pattern } => write!(f, "regex({})", pattern),
164            Self::Semantic { text, threshold } => {
165                write!(f, "semantic({}, threshold={})", text, threshold)
166            }
167            Self::Custom { name } => write!(f, "custom({})", name),
168        }
169    }
170}
171
172impl EvalCase {
173    /// Create a new eval case.
174    pub fn new() -> Self {
175        Self {
176            name: None,
177            input: String::new(),
178            expected: Vec::new(),
179            tags: Vec::new(),
180        }
181    }
182
183    /// Set the input.
184    pub fn input(mut self, input: impl Into<String>) -> Self {
185        self.input = input.into();
186        self
187    }
188
189    /// Add an exact match expectation.
190    pub fn expected_exact(mut self, s: impl Into<String>) -> Self {
191        self.expected.push(Expected::exact(s));
192        self
193    }
194
195    /// Add a contains expectation.
196    pub fn expected_contains(mut self, s: impl Into<String>) -> Self {
197        self.expected.push(Expected::contains(s));
198        self
199    }
200
201    /// Add a regex expectation.
202    pub fn expected_regex(mut self, pattern: impl Into<String>) -> Self {
203        self.expected.push(Expected::regex(pattern));
204        self
205    }
206
207    /// Add a semantic similarity expectation.
208    pub fn expected_semantic(mut self, text: impl Into<String>, threshold: f32) -> Self {
209        self.expected.push(Expected::semantic(text, threshold));
210        self
211    }
212
213    /// Set the name.
214    pub fn name(mut self, name: impl Into<String>) -> Self {
215        self.name = Some(name.into());
216        self
217    }
218
219    /// Add a tag.
220    pub fn tag(mut self, tag: impl Into<String>) -> Self {
221        self.tags.push(tag.into());
222        self
223    }
224
225    /// Check if all expectations are satisfied.
226    pub fn check(&self, output: &str) -> Vec<(&Expected, bool)> {
227        self.expected
228            .iter()
229            .map(|exp| {
230                let passed = match exp {
231                    Expected::Exact { value } => output == value,
232                    Expected::Contains { pattern } => output.contains(pattern),
233                    Expected::Regex { pattern } => regex::Regex::new(pattern)
234                        .map(|re| re.is_match(output))
235                        .unwrap_or(false),
236                    Expected::Semantic { .. } => false, // Requires embedding model
237                    Expected::Custom { .. } => false,   // Requires external handler
238                };
239                (exp, passed)
240            })
241            .collect()
242    }
243
244    /// Check if all expectations pass.
245    pub fn all_pass(&self, output: &str) -> bool {
246        self.check(output).iter().all(|(_, passed)| *passed)
247    }
248}
249
250impl Default for EvalCase {
251    fn default() -> Self {
252        Self::new()
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_case_new() {
        let case: Case<String, String, ()> = Case::new("test input".to_string())
            .with_name("test case")
            .with_expected_output("expected".to_string())
            .with_tag("unit");

        assert_eq!(case.name, Some("test case".to_string()));
        assert_eq!(case.inputs, "test input");
        assert_eq!(case.expected_output, Some("expected".to_string()));
        assert!(case.metadata.is_none());
        assert!(case.has_tag("unit"));
        assert!(!case.has_tag("integration"));
    }

    #[test]
    fn test_case_with_tags_appends_in_order() {
        let case: Case<()> = Case::new(()).with_tag("a").with_tags(vec!["b", "c"]);
        assert_eq!(case.tags, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_case_default() {
        let case: Case<String> = Case::default();
        assert!(case.inputs.is_empty());
        assert!(case.name.is_none());
        assert!(case.expected_output.is_none());
        assert!(case.tags.is_empty());
    }

    #[test]
    fn test_case_display_name() {
        let case: Case<String> = Case::new("input".to_string());
        assert_eq!(case.display_name(0), "case_0");

        let named = case.with_name("my_case");
        assert_eq!(named.display_name(0), "my_case");
    }

    #[test]
    fn test_eval_case_check() {
        let case = EvalCase::new()
            .input("What is 2+2?")
            .expected_contains("4")
            .expected_contains("four");

        let results = case.check("The answer is 4");
        assert_eq!(results.len(), 2);
        assert!(results[0].1); // contains "4"
        assert!(!results[1].1); // doesn't contain "four"
    }

    #[test]
    fn test_eval_case_exact() {
        let case = EvalCase::new().input("test").expected_exact("hello");

        assert!(case.all_pass("hello"));
        assert!(!case.all_pass("hello world"));
    }

    #[test]
    fn test_eval_case_regex() {
        let case = EvalCase::new().input("count").expected_regex(r"^\d+$");
        assert!(case.all_pass("42"));
        assert!(!case.all_pass("forty-two"));
    }

    #[test]
    fn test_eval_case_invalid_regex_fails() {
        // An unparsable pattern is reported as a failed expectation, not a panic.
        let case = EvalCase::new().expected_regex("(unclosed");
        assert!(!case.all_pass("(unclosed"));
    }

    #[test]
    fn test_eval_case_unsupported_expectations_fail() {
        let semantic = EvalCase::new().expected_semantic("hello", 0.5);
        assert!(!semantic.all_pass("hello"));

        let custom = EvalCase {
            expected: vec![Expected::Custom {
                name: "my_check".to_string(),
            }],
            ..EvalCase::default()
        };
        assert!(!custom.all_pass("anything"));
    }

    #[test]
    fn test_eval_case_no_expectations_passes() {
        let case = EvalCase::new().input("anything");
        assert!(case.check("output").is_empty());
        assert!(case.all_pass("output"));
    }

    #[test]
    fn test_expected_display() {
        assert_eq!(Expected::exact("foo").to_string(), "exact(foo)");
        assert_eq!(Expected::contains("bar").to_string(), "contains(bar)");
        assert_eq!(Expected::regex("a+").to_string(), "regex(a+)");
        assert_eq!(
            Expected::semantic("hi", 0.5).to_string(),
            "semantic(hi, threshold=0.5)"
        );
        assert_eq!(
            Expected::Custom {
                name: "f".to_string()
            }
            .to_string(),
            "custom(f)"
        );
    }

    #[test]
    fn test_eval_case_serialize() {
        let case = EvalCase::new().input("hello").expected_contains("world");

        let json = serde_json::to_string(&case).unwrap();
        assert!(json.contains("hello"));
        assert!(json.contains(r#""type":"contains""#));
        assert!(json.contains(r#""pattern":"world""#));
    }

    #[test]
    fn test_eval_case_deserialize_defaults() {
        let case: EvalCase = serde_json::from_str(r#"{"input": "hi"}"#).unwrap();
        assert_eq!(case.input, "hi");
        assert!(case.name.is_none());
        assert!(case.expected.is_empty());
        assert!(case.tags.is_empty());
    }

    #[test]
    fn test_eval_case_deserialize_tagged_expectations() {
        let json = r#"{"name":"n","input":"q","expected":[{"type":"regex","pattern":"x+"},{"type":"semantic","text":"t","threshold":0.25}],"tags":["smoke"]}"#;
        let case: EvalCase = serde_json::from_str(json).unwrap();
        assert_eq!(case.name.as_deref(), Some("n"));
        assert_eq!(case.expected.len(), 2);
        assert!(matches!(&case.expected[0], Expected::Regex { pattern } if pattern == "x+"));
        assert!(matches!(&case.expected[1], Expected::Semantic { text, .. } if text == "t"));
        assert!(case.tags.iter().any(|t| t == "smoke"));
        assert!(case.all_pass("xxx") == false); // semantic expectation always fails
    }
}