1use std::{
10 fs::{self, File},
11 io::{BufWriter, Write as IoWrite},
12 path::{Path, PathBuf},
13};
14
15use zer_core::{error::ZerError, scoring::MatchBand};
16
/// Precision, recall and F1 computed from raw confusion counts, kept together
/// with the counts they were derived from.
#[derive(Debug, Clone)]
pub struct AccuracyMetrics {
    /// True-positive count.
    pub true_pos: usize,
    /// False-positive count.
    pub false_pos: usize,
    /// False-negative count.
    pub false_neg: usize,
    /// `true_pos / (true_pos + false_pos)`, or `0.0` when that denominator is zero.
    pub precision: f32,
    /// `true_pos / (true_pos + false_neg)`, or `0.0` when that denominator is zero.
    pub recall: f32,
    /// Harmonic mean of `precision` and `recall`, or `0.0` when both are zero.
    pub f1: f32,
}

impl AccuracyMetrics {
    /// Builds the metrics from confusion counts.
    ///
    /// Every ratio with a zero denominator is reported as `0.0` instead of NaN,
    /// so the result is always finite.
    pub fn from_counts(true_pos: usize, false_pos: usize, false_neg: usize) -> Self {
        // The denominator is summed in `usize` and only then converted to f32.
        let ratio = |num: usize, den: usize| -> f32 {
            if den == 0 {
                0.0
            } else {
                num as f32 / den as f32
            }
        };

        let precision = ratio(true_pos, true_pos + false_pos);
        let recall = ratio(true_pos, true_pos + false_neg);

        let denom = precision + recall;
        let f1 = if denom > 0.0 {
            2.0 * precision * recall / denom
        } else {
            0.0
        };

        Self {
            f1,
            precision,
            recall,
            false_neg,
            false_pos,
            true_pos,
        }
    }
}
56
57#[derive(Debug, Clone, serde::Serialize)]
59pub struct PairRecord {
60 pub run_id: String,
61 pub record_key_a: String,
62 pub source_a: Option<String>,
63 pub record_key_b: String,
64 pub source_b: Option<String>,
65 pub match_probability: f32,
66 pub predicted_match: bool,
67}
68
69#[derive(Debug, Clone, serde::Serialize)]
71struct SummaryRow {
72 library: String,
73 mode: String,
74 dataset: String,
75 run_id: String,
76 timestamp: String,
77 total_records: usize,
78 candidate_pairs: usize,
79 auto_matched: usize,
80 borderline: usize,
81 auto_rejected: usize,
82 elapsed_ms: u64,
83 true_pos: Option<usize>,
84 false_pos: Option<usize>,
85 false_neg: Option<usize>,
86 precision: Option<f32>,
87 recall: Option<f32>,
88 f1: Option<f32>,
89}
90
/// Aggregate counts and timing for one benchmark batch.
///
/// This is the input to [`BenchResultWriter::write_summary`]. All fields are
/// plain values, so the type derives the usual comparison and debugging traits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BenchBatchSummary {
    /// Number of input records in the batch.
    pub total_records: usize,
    /// Number of candidate pairs considered.
    pub candidate_pairs: usize,
    /// Number of pairs placed in the automatic-match band.
    pub auto_matched: usize,
    /// Number of pairs placed in the borderline band.
    pub borderline: usize,
    /// Number of pairs placed in the automatic-reject band.
    pub auto_rejected: usize,
    /// Elapsed time in milliseconds.
    pub elapsed_ms: u64,
    /// Link mode label; the summary writer lower-cases it.
    pub link_mode: String,
    /// Dataset label recorded in the summary.
    pub dataset: String,
}
104
105pub struct BenchResultWriter {
106 run_id: String,
107 out_dir: PathBuf,
108}
109
110impl BenchResultWriter {
111 pub fn new(out_dir: &Path, run_id: &str) -> Result<Self, ZerError> {
113 fs::create_dir_all(out_dir)
114 .map_err(|e| ZerError::Store(format!("cannot create output dir: {e}")))?;
115 Ok(Self {
116 run_id: run_id.to_owned(),
117 out_dir: out_dir.to_path_buf(),
118 })
119 }
120
121 pub fn write_pairs(&self, pairs: &[PairRecord]) -> Result<(), ZerError> {
123 let path = self.out_dir.join(format!("{}_pairs.ndjson", self.run_id));
124 let file = File::create(&path)
125 .map_err(|e| ZerError::Store(format!("cannot create pairs file: {e}")))?;
126 let mut w = BufWriter::new(file);
127 for pair in pairs {
128 let line = serde_json::to_string(pair)
129 .map_err(|e| ZerError::Store(format!("JSON serialise error: {e}")))?;
130 writeln!(w, "{line}").map_err(|e| ZerError::Store(format!("write error: {e}")))?;
131 }
132 Ok(())
133 }
134
135 pub fn write_summary(
142 &self,
143 summary: &BenchBatchSummary,
144 accuracy: Option<&AccuracyMetrics>,
145 ) -> Result<(), ZerError> {
146 self.write_summary_with_library(summary, accuracy, "zer")
147 }
148
149 pub fn write_summary_with_library(
155 &self,
156 summary: &BenchBatchSummary,
157 accuracy: Option<&AccuracyMetrics>,
158 library: &str,
159 ) -> Result<(), ZerError> {
160 let path = self.out_dir.join(format!("{}_summary.csv", self.run_id));
161 let file = File::create(&path)
162 .map_err(|e| ZerError::Store(format!("cannot create summary file: {e}")))?;
163
164 let timestamp = crate::time::utc_timestamp_iso();
165 let row = SummaryRow {
166 library: library.to_owned(),
167 mode: summary.link_mode.to_lowercase(),
168 dataset: summary.dataset.clone(),
169 run_id: self.run_id.clone(),
170 timestamp,
171 total_records: summary.total_records,
172 candidate_pairs: summary.candidate_pairs,
173 auto_matched: summary.auto_matched,
174 borderline: summary.borderline,
175 auto_rejected: summary.auto_rejected,
176 elapsed_ms: summary.elapsed_ms,
177 true_pos: accuracy.map(|a| a.true_pos),
178 false_pos: accuracy.map(|a| a.false_pos),
179 false_neg: accuracy.map(|a| a.false_neg),
180 precision: accuracy.map(|a| a.precision),
181 recall: accuracy.map(|a| a.recall),
182 f1: accuracy.map(|a| a.f1),
183 };
184
185 let mut wtr = csv::Writer::from_writer(file);
186 wtr.serialize(&row)
187 .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
188 wtr.flush()
189 .map_err(|e| ZerError::Store(format!("CSV flush error: {e}")))?;
190 Ok(())
191 }
192
193 pub fn write_scored_pairs_csv(&self, pairs: &[(f32, bool)]) -> Result<(), ZerError> {
199 let path = self
200 .out_dir
201 .join(format!("{}_scored_pairs.csv", self.run_id));
202 let file = File::create(&path)
203 .map_err(|e| ZerError::Store(format!("cannot create scored pairs file: {e}")))?;
204 let mut w = csv::Writer::from_writer(file);
205 w.write_record(["score", "is_match"])
206 .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
207 let mut sorted: Vec<(f32, bool)> = pairs.to_vec();
208 sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
209 for (score, is_match) in &sorted {
210 w.write_record(&[score.to_string(), (*is_match as u8).to_string()])
211 .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
212 }
213 w.flush()
214 .map_err(|e| ZerError::Store(format!("CSV flush error: {e}")))?;
215 Ok(())
216 }
217
218 pub fn run_id(&self) -> &str {
219 &self.run_id
220 }
221
222 pub fn out_dir(&self) -> &Path {
223 &self.out_dir
224 }
225}
226
227pub fn band_to_match(band: MatchBand) -> bool {
229 matches!(band, MatchBand::AutoMatch)
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Fixed batch summary shared by the summary-CSV tests.
    fn sample_summary() -> BenchBatchSummary {
        BenchBatchSummary {
            total_records: 100,
            candidate_pairs: 500,
            auto_matched: 400,
            borderline: 50,
            auto_rejected: 50,
            elapsed_ms: 1200,
            link_mode: "deduplicate".into(),
            dataset: "test_dataset".into(),
        }
    }

    /// Reads a CSV file with a header row, returning the header and all data rows.
    fn read_csv(path: &std::path::Path) -> (csv::StringRecord, Vec<csv::StringRecord>) {
        let mut rdr = csv::Reader::from_path(path).unwrap();
        let headers = rdr.headers().unwrap().clone();
        let rows = rdr.records().collect::<Result<Vec<_>, _>>().unwrap();
        (headers, rows)
    }

    /// Returns column `name` of `row`, panicking with a clear message if it is absent.
    fn field<'a>(headers: &csv::StringRecord, row: &'a csv::StringRecord, name: &str) -> &'a str {
        let idx = headers
            .iter()
            .position(|h| h == name)
            .unwrap_or_else(|| panic!("missing CSV column `{name}`"));
        row.get(idx)
            .unwrap_or_else(|| panic!("row has no value for column `{name}`"))
    }

    #[test]
    fn new_creates_missing_output_dir() {
        let dir = TempDir::new().unwrap();
        let nested = dir.path().join("a").join("b");
        let writer = BenchResultWriter::new(&nested, "r").unwrap();
        assert!(nested.is_dir());
        assert_eq!(writer.out_dir(), nested.as_path());
        assert_eq!(writer.run_id(), "r");
    }

    #[test]
    fn write_pairs_ndjson_line_count() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "test_run").unwrap();

        let pairs: Vec<PairRecord> = (0..5)
            .map(|i| PairRecord {
                run_id: "test_run".into(),
                record_key_a: i.to_string(),
                source_a: Some("brp".into()),
                record_key_b: (i + 100).to_string(),
                source_b: Some("kvk".into()),
                match_probability: 0.9,
                predicted_match: true,
            })
            .collect();

        writer.write_pairs(&pairs).unwrap();

        let path = dir.path().join("test_run_pairs.ndjson");
        let content = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(lines.len(), 5, "NDJSON file must have exactly N lines");

        // Each line must be a standalone JSON object, in input order.
        for (i, line) in lines.iter().enumerate() {
            let v: serde_json::Value = serde_json::from_str(line).unwrap();
            assert_eq!(v["run_id"], "test_run");
            assert_eq!(v["record_key_a"], i.to_string());
            assert_eq!(v["record_key_b"], (i + 100).to_string());
            assert!(v.get("match_probability").is_some());
            assert_eq!(v["predicted_match"], true);
        }
    }

    #[test]
    fn write_pairs_empty_input_creates_empty_file() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "empty").unwrap();
        writer.write_pairs(&[]).unwrap();
        let content = std::fs::read_to_string(dir.path().join("empty_pairs.ndjson")).unwrap();
        assert!(content.is_empty());
    }

    #[test]
    fn write_summary_csv_no_accuracy() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_no_acc").unwrap();
        let summary = sample_summary();

        writer.write_summary(&summary, None).unwrap();

        let (headers, rows) = read_csv(&dir.path().join("run_no_acc_summary.csv"));
        assert_eq!(rows.len(), 1, "summary CSV must contain exactly one data row");
        let row = &rows[0];
        assert_eq!(field(&headers, row, "library"), "zer");
        assert_eq!(field(&headers, row, "mode"), "deduplicate");
        assert_eq!(field(&headers, row, "dataset"), "test_dataset");
        assert_eq!(field(&headers, row, "run_id"), "run_no_acc");
        assert_eq!(field(&headers, row, "total_records"), "100");
        assert_eq!(field(&headers, row, "candidate_pairs"), "500");
        assert_eq!(field(&headers, row, "elapsed_ms"), "1200");
        for col in ["true_pos", "false_pos", "false_neg", "precision", "recall", "f1"] {
            assert_eq!(
                field(&headers, row, col),
                "",
                "`{col}` must be empty when no accuracy is supplied"
            );
        }
    }

    #[test]
    fn write_summary_csv_with_accuracy() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_acc").unwrap();
        let summary = sample_summary();
        let acc = AccuracyMetrics::from_counts(96, 4, 2);

        writer.write_summary(&summary, Some(&acc)).unwrap();

        let (headers, rows) = read_csv(&dir.path().join("run_acc_summary.csv"));
        assert_eq!(rows.len(), 1);
        let row = &rows[0];
        assert_eq!(field(&headers, row, "true_pos"), "96");
        assert_eq!(field(&headers, row, "false_pos"), "4");
        assert_eq!(field(&headers, row, "false_neg"), "2");
        let precision: f32 = field(&headers, row, "precision").parse().unwrap();
        let recall: f32 = field(&headers, row, "recall").parse().unwrap();
        let f1: f32 = field(&headers, row, "f1").parse().unwrap();
        assert!((precision - acc.precision).abs() < 1e-6);
        assert!((recall - acc.recall).abs() < 1e-6);
        assert!((f1 - acc.f1).abs() < 1e-6);
    }

    #[test]
    fn write_summary_with_library_sets_library_column() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_lib").unwrap();
        let mut summary = sample_summary();
        summary.link_mode = "Link_Only".into();

        writer
            .write_summary_with_library(&summary, None, "splink")
            .unwrap();

        let (headers, rows) = read_csv(&dir.path().join("run_lib_summary.csv"));
        let row = &rows[0];
        assert_eq!(field(&headers, row, "library"), "splink");
        assert_eq!(field(&headers, row, "mode"), "link_only", "mode must be lower-cased");
    }

    #[test]
    fn write_scored_pairs_sorted_descending() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "scored").unwrap();

        writer
            .write_scored_pairs_csv(&[(0.2, false), (0.9, true), (0.5, false)])
            .unwrap();

        let (headers, rows) = read_csv(&dir.path().join("scored_scored_pairs.csv"));
        assert_eq!(headers.iter().collect::<Vec<_>>(), vec!["score", "is_match"]);
        let scores: Vec<f32> = rows
            .iter()
            .map(|r| r.get(0).unwrap().parse().unwrap())
            .collect();
        let flags: Vec<&str> = rows.iter().map(|r| r.get(1).unwrap()).collect();
        assert_eq!(scores, vec![0.9, 0.5, 0.2]);
        assert_eq!(flags, vec!["1", "0", "0"]);
    }

    #[test]
    fn band_to_match_accepts_auto_match() {
        assert!(band_to_match(MatchBand::AutoMatch));
    }

    #[test]
    fn accuracy_metrics_from_counts() {
        let acc = AccuracyMetrics::from_counts(90, 10, 5);
        assert!((acc.precision - 0.9).abs() < 0.001);
        assert!((acc.recall - (90.0 / 95.0)).abs() < 0.001);
        // F1 = 2·TP / (2·TP + FP + FN) = 180 / 195.
        assert!((acc.f1 - 180.0 / 195.0).abs() < 0.001);
    }

    #[test]
    fn accuracy_metrics_perfect_score() {
        let acc = AccuracyMetrics::from_counts(5, 0, 0);
        assert_eq!(acc.precision, 1.0);
        assert_eq!(acc.recall, 1.0);
        assert_eq!(acc.f1, 1.0);
    }

    #[test]
    fn accuracy_metrics_zero_denominator() {
        let acc = AccuracyMetrics::from_counts(0, 0, 0);
        assert_eq!(acc.precision, 0.0);
        assert_eq!(acc.recall, 0.0);
        assert_eq!(acc.f1, 0.0);
    }
}