// zer_cluster/threshold.rs

1use zer_core::scoring::{MatchBand, ModelParams, ScoredPair};
2
3/// Pairs partitioned by their match band.
4pub struct BandedPairs {
5    pub auto_match:  Vec<ScoredPair>,
6    pub borderline:  Vec<ScoredPair>,
7    pub auto_reject: Vec<ScoredPair>,
8}
9
10/// Classify each pair by `match_probability` vs the upper/lower thresholds in
11/// `params`. A pair is `AutoMatch` if `prob >= upper_threshold`, `AutoReject`
12/// if `prob < lower_threshold`, and `Borderline` otherwise.
13///
14/// The band already stored in `ScoredPair::band` is used directly, it must
15/// have been assigned by the same `ModelParams` that are passed here. If the
16/// stored band disagrees with the thresholds (e.g., params were updated after
17/// scoring), the stored band takes precedence so that provenance is preserved.
18pub fn partition_by_band(pairs: Vec<ScoredPair>, _params: &ModelParams) -> BandedPairs {
19    let mut auto_match  = Vec::new();
20    let mut borderline  = Vec::new();
21    let mut auto_reject = Vec::new();
22
23    for pair in pairs {
24        match pair.band {
25            MatchBand::AutoMatch  => auto_match.push(pair),
26            MatchBand::Borderline => borderline.push(pair),
27            MatchBand::AutoReject => auto_reject.push(pair),
28        }
29    }
30
31    BandedPairs { auto_match, borderline, auto_reject }
32}
33
// ── Unit tests ────────────────────────────────────────────────────────────────
35
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{comparison::ComparisonVector, scoring::MatchBand};

    /// Model parameters shared by every test: upper threshold 0.8, lower
    /// threshold 0.2, and empty `m`/`u` vectors.
    fn params() -> ModelParams {
        ModelParams {
            m: vec![],
            u: vec![],
            log_prior_odds: 0.0,
            upper_threshold: 0.8,
            lower_threshold: 0.2,
        }
    }

    /// Build a `ScoredPair` for records `a` and `b` with the given probability
    /// and stored band; the weight is zero and the comparison vector is empty.
    fn pair(a: u64, b: u64, prob: f32, band: MatchBand) -> ScoredPair {
        ScoredPair {
            record_a: a,
            record_b: b,
            match_weight: 0.0,
            match_probability: prob,
            vector: ComparisonVector { record_a: a, record_b: b, levels: vec![] },
            band,
        }
    }

    #[test]
    fn empty_input_returns_empty_partitions() {
        let result = partition_by_band(vec![], &params());
        assert!(result.auto_match.is_empty());
        assert!(result.borderline.is_empty());
        assert!(result.auto_reject.is_empty());
    }

    #[test]
    fn all_auto_match() {
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(3, 4, 0.90, MatchBand::AutoMatch),
        ];
        let result = partition_by_band(pairs, &params());
        assert_eq!(result.auto_match.len(), 2);
        assert!(result.borderline.is_empty());
        assert!(result.auto_reject.is_empty());
    }

    #[test]
    fn all_auto_reject() {
        let pairs = vec![
            pair(1, 2, 0.05, MatchBand::AutoReject),
            pair(3, 4, 0.10, MatchBand::AutoReject),
        ];
        let result = partition_by_band(pairs, &params());
        assert!(result.auto_match.is_empty());
        assert!(result.borderline.is_empty());
        assert_eq!(result.auto_reject.len(), 2);
    }

    #[test]
    fn mixed_bands() {
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(2, 3, 0.50, MatchBand::Borderline),
            pair(4, 5, 0.05, MatchBand::AutoReject),
        ];
        let result = partition_by_band(pairs, &params());
        assert_eq!(result.auto_match.len(), 1);
        assert_eq!(result.borderline.len(), 1);
        assert_eq!(result.auto_reject.len(), 1);
    }

    #[test]
    fn stored_band_takes_precedence_over_thresholds() {
        // Probabilities deliberately contradict the stored bands under
        // `params()`; the stored band must decide the bucket.
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoReject),
            pair(3, 4, 0.05, MatchBand::AutoMatch),
            pair(5, 6, 0.99, MatchBand::Borderline),
        ];
        let result = partition_by_band(pairs, &params());

        assert_eq!(result.auto_match.len(), 1);
        assert_eq!(result.auto_match[0].record_a, 3);

        assert_eq!(result.borderline.len(), 1);
        assert_eq!(result.borderline[0].record_a, 5);

        assert_eq!(result.auto_reject.len(), 1);
        assert_eq!(result.auto_reject[0].record_a, 1);
    }

    #[test]
    fn input_order_is_preserved_within_each_band() {
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(3, 4, 0.05, MatchBand::AutoReject),
            pair(5, 6, 0.90, MatchBand::AutoMatch),
            pair(7, 8, 0.10, MatchBand::AutoReject),
        ];
        let result = partition_by_band(pairs, &params());

        let matched: Vec<u64> = result.auto_match.iter().map(|p| p.record_a).collect();
        let rejected: Vec<u64> = result.auto_reject.iter().map(|p| p.record_a).collect();

        assert_eq!(matched, vec![1, 5]);
        assert_eq!(rejected, vec![3, 7]);
        assert!(result.borderline.is_empty());
    }
}
106}