use std::collections::{HashMap, HashSet};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;

use serde::{Deserialize, Serialize};

use crate::error::{EvalError, Result};
use crate::report::EvaluationResult;
use crate::schema::EvalCase;
36
/// One line of an annotation JSONL file: a single eval case prepared for, or
/// already reviewed by, a human annotator.
///
/// Produced by [`AnnotationStore::export`] (always with `verdict: None`) and read
/// back by [`AnnotationStore::import`]. Field order determines the key order of
/// the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnnotationRecord {
    /// Identifier of the eval case; `export` copies the case's `eval_id` here.
    pub case_id: String,
    /// User input of every conversation turn, joined with `'\n'` by `export`.
    pub input: String,
    /// Text of the last turn's expected final response, if the case has one.
    pub expected_response: Option<String>,
    /// Actual response taken from the last turn result of the matching
    /// evaluation result, if one exists.
    pub actual_response: Option<String>,
    /// Human verdict. `export` always writes `None`; presumably a reviewer fills
    /// it in before the file is imported again.
    pub verdict: Option<HumanVerdict>,
}
51
/// A reviewer's judgement of a single annotated case.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HumanVerdict {
    /// Score assigned by the reviewer.
    /// NOTE(review): no range is enforced here (tests use values like `0.9`);
    /// confirm the intended scale with the consumers of this score.
    pub score: f64,
    /// Free-text justification for the score.
    pub reasoning: String,
    /// Identifier of the annotator who produced this verdict.
    pub annotator_id: String,
}
62
/// Stateless namespace for exporting eval cases to, and importing human
/// annotations from, JSON Lines files holding one [`AnnotationRecord`] per line.
pub struct AnnotationStore;
65
66impl AnnotationStore {
67 pub fn export(
80 cases: &[EvalCase],
81 results: &[EvaluationResult],
82 output_path: impl AsRef<Path>,
83 ) -> Result<()> {
84 let file = std::fs::File::create(output_path.as_ref()).map_err(|e| {
85 EvalError::AnnotationError(format!(
86 "failed to create annotation file '{}': {e}",
87 output_path.as_ref().display()
88 ))
89 })?;
90 let mut writer = BufWriter::new(file);
91
92 for case in cases {
93 let input = case
95 .conversation
96 .iter()
97 .map(|turn| turn.user_content.get_text())
98 .collect::<Vec<_>>()
99 .join("\n");
100
101 let expected_response = case
103 .conversation
104 .last()
105 .and_then(|turn| turn.final_response.as_ref())
106 .map(|content| content.get_text());
107
108 let actual_response = results
110 .iter()
111 .find(|r| r.eval_id == case.eval_id)
112 .and_then(|r| r.turn_results.last())
113 .and_then(|tr| tr.actual_response.clone());
114
115 let record = AnnotationRecord {
116 case_id: case.eval_id.clone(),
117 input,
118 expected_response,
119 actual_response,
120 verdict: None,
121 };
122
123 let line = serde_json::to_string(&record).map_err(|e| {
124 EvalError::AnnotationError(format!(
125 "failed to serialize annotation record for case '{}': {e}",
126 case.eval_id
127 ))
128 })?;
129
130 writeln!(writer, "{line}").map_err(|e| {
131 EvalError::AnnotationError(format!("failed to write annotation line: {e}"))
132 })?;
133 }
134
135 writer.flush().map_err(|e| {
136 EvalError::AnnotationError(format!("failed to flush annotation file: {e}"))
137 })?;
138
139 Ok(())
140 }
141
142 pub fn import(
158 path: impl AsRef<Path>,
159 valid_case_ids: &HashSet<String>,
160 ) -> Result<(Vec<AnnotationRecord>, Vec<String>)> {
161 let file = std::fs::File::open(path.as_ref()).map_err(|e| {
162 EvalError::AnnotationError(format!(
163 "failed to open annotation file '{}': {e}",
164 path.as_ref().display()
165 ))
166 })?;
167 let reader = BufReader::new(file);
168
169 let mut records = Vec::new();
170 let mut warnings = Vec::new();
171
172 for (line_num, line_result) in reader.lines().enumerate() {
173 let line = line_result.map_err(|e| {
174 EvalError::AnnotationError(format!("failed to read line {}: {e}", line_num + 1))
175 })?;
176
177 let trimmed = line.trim();
179 if trimmed.is_empty() {
180 continue;
181 }
182
183 let record: AnnotationRecord = serde_json::from_str(trimmed).map_err(|e| {
184 EvalError::AnnotationError(format!(
185 "failed to parse annotation at line {}: {e}",
186 line_num + 1
187 ))
188 })?;
189
190 if valid_case_ids.contains(&record.case_id) {
191 records.push(record);
192 } else {
193 warnings.push(format!(
194 "unmatched case_id '{}' at line {}",
195 record.case_id,
196 line_num + 1
197 ));
198 }
199 }
200
201 Ok((records, warnings))
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::report::EvaluationResult;
    use crate::schema::{ContentData, EvalCase, Turn};
    use std::collections::HashMap;
    use std::time::Duration;
    use tempfile::NamedTempFile;

    /// Builds a single-turn case with user input `input` and final response
    /// `expected`, or no final response when `expected` is `None`.
    fn make_case(id: &str, input: &str, expected: Option<&str>) -> EvalCase {
        let conversation = vec![Turn {
            invocation_id: format!("inv_{id}"),
            user_content: ContentData::text(input),
            // `Option::map` already yields `None` when `expected` is `None`.
            final_response: expected.map(ContentData::model_response),
            intermediate_data: None,
        }];

        EvalCase {
            eval_id: id.to_string(),
            description: String::new(),
            conversation,
            session_input: Default::default(),
            tags: vec![],
            metadata: None,
        }
    }

    /// Builds a passing evaluation result for case `id`.
    fn make_result(id: &str) -> EvaluationResult {
        EvaluationResult::passed(id, HashMap::new(), Duration::from_millis(50))
    }

    #[test]
    fn test_export_creates_jsonl_file() {
        let cases = vec![
            make_case("case_1", "Hello", Some("Hi there")),
            make_case("case_2", "How are you?", None),
        ];
        let results = vec![make_result("case_1"), make_result("case_2")];

        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        AnnotationStore::export(&cases, &results, &path).unwrap();

        let content = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(lines.len(), 2);

        let record: AnnotationRecord = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(record.case_id, "case_1");
        assert_eq!(record.input, "Hello");
        assert_eq!(record.expected_response, Some("Hi there".to_string()));
        assert!(record.verdict.is_none());

        let record: AnnotationRecord = serde_json::from_str(lines[1]).unwrap();
        assert_eq!(record.case_id, "case_2");
        assert_eq!(record.input, "How are you?");
        assert_eq!(record.expected_response, None);
        assert!(record.verdict.is_none());
    }

    /// Multi-turn inputs are newline-joined, the expected response comes from
    /// the last turn, and a case with no matching result gets no actual response.
    #[test]
    fn test_export_multi_turn_case_without_result() {
        let mut case = make_case("multi", "first", Some("ignored"));
        case.conversation.push(Turn {
            invocation_id: "inv_multi_2".to_string(),
            user_content: ContentData::text("second"),
            final_response: Some(ContentData::model_response("done")),
            intermediate_data: None,
        });
        let cases = vec![case];
        let results: Vec<EvaluationResult> = vec![];

        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        AnnotationStore::export(&cases, &results, &path).unwrap();

        let content = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(lines.len(), 1);

        let record: AnnotationRecord = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(record.case_id, "multi");
        assert_eq!(record.input, "first\nsecond");
        assert_eq!(record.expected_response, Some("done".to_string()));
        assert_eq!(record.actual_response, None);
        assert!(record.verdict.is_none());
    }

    #[test]
    fn test_import_valid_records() {
        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        let records = [
            AnnotationRecord {
                case_id: "case_1".to_string(),
                input: "Hello".to_string(),
                expected_response: Some("Hi".to_string()),
                actual_response: Some("Hey".to_string()),
                verdict: Some(HumanVerdict {
                    score: 0.9,
                    reasoning: "Good response".to_string(),
                    annotator_id: "reviewer_1".to_string(),
                }),
            },
            AnnotationRecord {
                case_id: "case_2".to_string(),
                input: "Bye".to_string(),
                expected_response: None,
                actual_response: None,
                verdict: None,
            },
        ];

        let content: String = records
            .iter()
            .map(|r| serde_json::to_string(r).unwrap())
            .collect::<Vec<_>>()
            .join("\n");
        std::fs::write(&path, content).unwrap();

        let valid_ids: HashSet<String> =
            ["case_1", "case_2"].iter().map(|s| s.to_string()).collect();
        let (imported, warnings) = AnnotationStore::import(&path, &valid_ids).unwrap();

        assert_eq!(imported.len(), 2);
        assert!(warnings.is_empty());
        assert_eq!(imported[0].case_id, "case_1");
        assert!(imported[0].verdict.is_some());
        assert_eq!(imported[0].verdict.as_ref().unwrap().score, 0.9);
    }

    #[test]
    fn test_import_unmatched_case_ids_produce_warnings() {
        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        let record = AnnotationRecord {
            case_id: "unknown_case".to_string(),
            input: "test".to_string(),
            expected_response: None,
            actual_response: None,
            verdict: None,
        };

        let content = serde_json::to_string(&record).unwrap();
        std::fs::write(&path, content).unwrap();

        let valid_ids: HashSet<String> = ["case_1"].iter().map(|s| s.to_string()).collect();
        let (imported, warnings) = AnnotationStore::import(&path, &valid_ids).unwrap();

        assert!(imported.is_empty());
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].contains("unknown_case"));
    }

    #[test]
    fn test_import_malformed_json_returns_error() {
        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        std::fs::write(&path, "not valid json\n").unwrap();

        let valid_ids: HashSet<String> = HashSet::new();
        let result = AnnotationStore::import(&path, &valid_ids);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("annotation"));
    }

    #[test]
    fn test_import_skips_empty_lines() {
        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        let record = AnnotationRecord {
            case_id: "case_1".to_string(),
            input: "hello".to_string(),
            expected_response: None,
            actual_response: None,
            verdict: None,
        };

        let line = serde_json::to_string(&record).unwrap();
        let content = format!("\n{line}\n\n");
        std::fs::write(&path, content).unwrap();

        let valid_ids: HashSet<String> = ["case_1"].iter().map(|s| s.to_string()).collect();
        let (imported, warnings) = AnnotationStore::import(&path, &valid_ids).unwrap();

        assert_eq!(imported.len(), 1);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_export_import_round_trip() {
        let cases = vec![
            make_case("rt_1", "What is Rust?", Some("A systems programming language")),
            make_case("rt_2", "Tell me a joke", Some("Why did the crab cross the road?")),
        ];
        let results = vec![make_result("rt_1"), make_result("rt_2")];

        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        AnnotationStore::export(&cases, &results, &path).unwrap();

        let valid_ids: HashSet<String> = cases.iter().map(|c| c.eval_id.clone()).collect();
        let (imported, warnings) = AnnotationStore::import(&path, &valid_ids).unwrap();

        assert!(warnings.is_empty());
        assert_eq!(imported.len(), 2);

        assert_eq!(imported[0].case_id, "rt_1");
        assert_eq!(imported[0].input, "What is Rust?");
        assert_eq!(
            imported[0].expected_response,
            Some("A systems programming language".to_string())
        );

        assert_eq!(imported[1].case_id, "rt_2");
        assert_eq!(imported[1].input, "Tell me a joke");
        assert_eq!(
            imported[1].expected_response,
            Some("Why did the crab cross the road?".to_string())
        );
    }

    #[test]
    fn test_export_nonexistent_directory_returns_error() {
        let cases = vec![make_case("c1", "hi", None)];
        let results = vec![];

        let result =
            AnnotationStore::export(&cases, &results, "/nonexistent/dir/annotations.jsonl");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("annotation"));
    }

    #[test]
    fn test_import_nonexistent_file_returns_error() {
        let valid_ids: HashSet<String> = HashSet::new();
        let result = AnnotationStore::import("/nonexistent/file.jsonl", &valid_ids);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("annotation"));
    }
}