1use std::{
10 fs::{self, File},
11 io::{BufWriter, Write as IoWrite},
12 path::{Path, PathBuf},
13};
14
15use zer_core::{error::ZerError, record::RecordId, scoring::MatchBand};
16
/// Confusion counts for a run together with the precision, recall and F1
/// scores derived from them.
#[derive(Debug, Clone)]
pub struct AccuracyMetrics {
    /// Number of true positives.
    pub true_pos: usize,
    /// Number of false positives.
    pub false_pos: usize,
    /// Number of false negatives.
    pub false_neg: usize,
    /// `true_pos / (true_pos + false_pos)`, or `0.0` when that sum is zero.
    pub precision: f32,
    /// `true_pos / (true_pos + false_neg)`, or `0.0` when that sum is zero.
    pub recall: f32,
    /// Harmonic mean of `precision` and `recall`, or `0.0` when both are zero.
    pub f1: f32,
}

impl AccuracyMetrics {
    /// Derives precision, recall and F1 from raw confusion counts.
    ///
    /// Any ratio whose denominator would be zero is reported as `0.0`
    /// instead of `NaN`, so the result always holds finite numbers.
    pub fn from_counts(true_pos: usize, false_pos: usize, false_neg: usize) -> Self {
        // Share of `total` that the true positives make up; empty totals yield 0.0.
        let share_of = |total: usize| -> f32 {
            if total > 0 {
                true_pos as f32 / total as f32
            } else {
                0.0
            }
        };

        let precision = share_of(true_pos + false_pos);
        let recall = share_of(true_pos + false_neg);

        let score_sum = precision + recall;
        let f1 = if score_sum > 0.0 {
            2.0 * precision * recall / score_sum
        } else {
            0.0
        };

        Self {
            true_pos,
            false_pos,
            false_neg,
            precision,
            recall,
            f1,
        }
    }
}
49
50#[derive(Debug, Clone, serde::Serialize)]
52pub struct PairRecord {
53 pub run_id: String,
54 pub record_id_a: RecordId,
55 pub source_a: Option<String>,
56 pub record_id_b: RecordId,
57 pub source_b: Option<String>,
58 pub match_probability: f32,
59 pub predicted_match: bool,
60}
61
62#[derive(Debug, Clone, serde::Serialize)]
64struct SummaryRow {
65 library: String,
66 mode: String,
67 dataset: String,
68 run_id: String,
69 timestamp: String,
70 total_records: usize,
71 candidate_pairs: usize,
72 auto_matched: usize,
73 borderline: usize,
74 auto_rejected: usize,
75 elapsed_ms: u64,
76 true_pos: Option<usize>,
77 false_pos: Option<usize>,
78 false_neg: Option<usize>,
79 precision: Option<f32>,
80 recall: Option<f32>,
81 f1: Option<f32>,
82}
83
/// Aggregate counters and metadata describing one benchmark batch.
///
/// This is the input to `BenchResultWriter::write_summary`, which copies the
/// values into the summary CSV row.
#[derive(Debug, Clone)]
pub struct BenchBatchSummary {
    /// Total number of records in the batch.
    pub total_records: usize,
    /// Number of candidate pairs.
    pub candidate_pairs: usize,
    /// Number of auto-matched pairs.
    pub auto_matched: usize,
    /// Number of borderline pairs.
    pub borderline: usize,
    /// Number of auto-rejected pairs.
    pub auto_rejected: usize,
    /// Elapsed time of the batch in milliseconds.
    pub elapsed_ms: u64,
    /// Link mode label; it is lower-cased when written to the summary CSV.
    pub link_mode: String,
    /// Name of the dataset the batch was run on.
    pub dataset: String,
}
97
/// Writes the result files of a single benchmark run into one output
/// directory; every file name is prefixed with the run id.
#[derive(Debug)]
pub struct BenchResultWriter {
    /// Identifier of the run, used as the file-name prefix.
    run_id: String,
    /// Directory all result files are written to.
    out_dir: PathBuf,
}
102
103impl BenchResultWriter {
104 pub fn new(out_dir: &Path, run_id: &str) -> Result<Self, ZerError> {
106 fs::create_dir_all(out_dir)
107 .map_err(|e| ZerError::Store(format!("cannot create output dir: {e}")))?;
108 Ok(Self {
109 run_id: run_id.to_owned(),
110 out_dir: out_dir.to_path_buf(),
111 })
112 }
113
114 pub fn write_pairs(&self, pairs: &[PairRecord]) -> Result<(), ZerError> {
116 let path = self.out_dir.join(format!("{}_pairs.ndjson", self.run_id));
117 let file = File::create(&path)
118 .map_err(|e| ZerError::Store(format!("cannot create pairs file: {e}")))?;
119 let mut w = BufWriter::new(file);
120 for pair in pairs {
121 let line = serde_json::to_string(pair)
122 .map_err(|e| ZerError::Store(format!("JSON serialise error: {e}")))?;
123 writeln!(w, "{line}")
124 .map_err(|e| ZerError::Store(format!("write error: {e}")))?;
125 }
126 Ok(())
127 }
128
129 pub fn write_summary(
136 &self,
137 summary: &BenchBatchSummary,
138 accuracy: Option<&AccuracyMetrics>,
139 ) -> Result<(), ZerError> {
140 self.write_summary_with_library(summary, accuracy, "zer")
141 }
142
143 pub fn write_summary_with_library(
149 &self,
150 summary: &BenchBatchSummary,
151 accuracy: Option<&AccuracyMetrics>,
152 library: &str,
153 ) -> Result<(), ZerError> {
154 let path = self.out_dir.join(format!("{}_summary.csv", self.run_id));
155 let file = File::create(&path)
156 .map_err(|e| ZerError::Store(format!("cannot create summary file: {e}")))?;
157
158 let timestamp = crate::time::utc_timestamp_iso();
159 let row = SummaryRow {
160 library: library.to_owned(),
161 mode: summary.link_mode.to_lowercase(),
162 dataset: summary.dataset.clone(),
163 run_id: self.run_id.clone(),
164 timestamp,
165 total_records: summary.total_records,
166 candidate_pairs: summary.candidate_pairs,
167 auto_matched: summary.auto_matched,
168 borderline: summary.borderline,
169 auto_rejected: summary.auto_rejected,
170 elapsed_ms: summary.elapsed_ms,
171 true_pos: accuracy.map(|a| a.true_pos),
172 false_pos: accuracy.map(|a| a.false_pos),
173 false_neg: accuracy.map(|a| a.false_neg),
174 precision: accuracy.map(|a| a.precision),
175 recall: accuracy.map(|a| a.recall),
176 f1: accuracy.map(|a| a.f1),
177 };
178
179 let mut wtr = csv::Writer::from_writer(file);
180 wtr.serialize(&row)
181 .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
182 wtr.flush()
183 .map_err(|e| ZerError::Store(format!("CSV flush error: {e}")))?;
184 Ok(())
185 }
186
187 pub fn write_scored_pairs_csv(&self, pairs: &[(f32, bool)]) -> Result<(), ZerError> {
193 let path = self.out_dir.join(format!("{}_scored_pairs.csv", self.run_id));
194 let file = File::create(&path)
195 .map_err(|e| ZerError::Store(format!("cannot create scored pairs file: {e}")))?;
196 let mut w = csv::Writer::from_writer(file);
197 w.write_record(["score", "is_match"])
198 .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
199 let mut sorted: Vec<(f32, bool)> = pairs.to_vec();
200 sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
201 for (score, is_match) in &sorted {
202 w.write_record(&[score.to_string(), (*is_match as u8).to_string()])
203 .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
204 }
205 w.flush().map_err(|e| ZerError::Store(format!("CSV flush error: {e}")))?;
206 Ok(())
207 }
208
209 pub fn run_id(&self) -> &str {
210 &self.run_id
211 }
212
213 pub fn out_dir(&self) -> &Path {
214 &self.out_dir
215 }
216}
217
218pub fn band_to_match(band: MatchBand) -> bool {
220 matches!(band, MatchBand::AutoMatch)
221}
222
223
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds the fixed batch summary shared by the summary tests.
    ///
    /// The `_dir` argument is accepted but not used.
    fn sample_summary(_dir: &TempDir) -> BenchBatchSummary {
        BenchBatchSummary {
            link_mode: "deduplicate".into(),
            dataset: "test_dataset".into(),
            total_records: 100,
            candidate_pairs: 500,
            auto_matched: 400,
            borderline: 50,
            auto_rejected: 50,
            elapsed_ms: 1200,
        }
    }

    /// Reads a file the writer produced inside `dir` back into a string.
    fn read_output(dir: &TempDir, file_name: &str) -> String {
        std::fs::read_to_string(dir.path().join(file_name)).unwrap()
    }

    #[test]
    fn write_pairs_ndjson_line_count() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "test_run").unwrap();

        let mut pairs: Vec<PairRecord> = Vec::with_capacity(5);
        for i in 0..5 {
            pairs.push(PairRecord {
                run_id: "test_run".into(),
                record_id_a: i,
                source_a: Some("brp".into()),
                record_id_b: i + 100,
                source_b: Some("kvk".into()),
                match_probability: 0.9,
                predicted_match: true,
            });
        }

        writer.write_pairs(&pairs).unwrap();

        let content = read_output(&dir, "test_run_pairs.ndjson");
        assert_eq!(
            content.lines().count(),
            5,
            "NDJSON file must have exactly N lines"
        );

        // Every line must be a standalone JSON document carrying the key fields.
        for line in content.lines() {
            let parsed: serde_json::Value = serde_json::from_str(line).unwrap();
            assert!(parsed.get("run_id").is_some());
            assert!(parsed.get("match_probability").is_some());
        }
    }

    #[test]
    fn write_summary_csv_no_accuracy() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_no_acc").unwrap();
        let summary = sample_summary(&dir);

        writer.write_summary(&summary, None).unwrap();

        let content = read_output(&dir, "run_no_acc_summary.csv");
        assert!(content.contains("zer"), "library field must be 'zer'");
        assert!(content.contains("test_dataset"));
        assert!(content.contains("100"));
    }

    #[test]
    fn write_summary_csv_with_accuracy() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_acc").unwrap();
        let summary = sample_summary(&dir);
        let acc = AccuracyMetrics::from_counts(96, 4, 2);

        writer.write_summary(&summary, Some(&acc)).unwrap();

        let content = read_output(&dir, "run_acc_summary.csv");
        assert!(content.contains("96"));
    }

    #[test]
    fn accuracy_metrics_from_counts() {
        let acc = AccuracyMetrics::from_counts(90, 10, 5);
        let expected_recall = 90.0 / 95.0;
        assert!((acc.precision - 0.9).abs() < 0.001);
        assert!((acc.recall - expected_recall).abs() < 0.001);
        assert!(acc.f1 > 0.0 && acc.f1 < 1.0);
    }

    #[test]
    fn accuracy_metrics_zero_denominator() {
        let acc = AccuracyMetrics::from_counts(0, 0, 0);
        assert_eq!(acc.precision, 0.0);
        assert_eq!(acc.recall, 0.0);
        assert_eq!(acc.f1, 0.0);
    }
}