1use zer_core::{
2 comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
3 error::ZerError,
4 scoring::ModelParams,
5};
6
7const N_LEVELS: usize = 4; pub fn e_step(vector: &ComparisonVector, params: &ModelParams) -> f32 {
14 let log_odds: f32 = params.log_prior_odds
15 + vector.levels.iter().enumerate()
16 .map(|(i, &level)| {
17 if level == ComparisonLevel::Null { return 0.0_f32; }
18 let l = level as usize;
19 let m = params.m[i][l].max(1e-9_f32);
20 let u = params.u[i][l].max(1e-9_f32);
21 (m / u).ln()
22 })
23 .sum::<f32>();
24 1.0 / (1.0 + (-log_odds).exp())
25}
26
27#[inline]
28fn e_step_p(batch: &ComparisonBatch, p: usize, params: &ModelParams) -> f32 {
29 let n_pairs = batch.n_pairs;
30 let log_odds: f32 = params.log_prior_odds
31 + (0..batch.n_fields)
32 .map(|f| {
33 let l_u8 = batch.levels[f * n_pairs + p];
34 if l_u8 == 255 { return 0.0_f32; } let l = l_u8 as usize;
36 let m = params.m[f][l].max(1e-9_f32);
37 let u = params.u[f][l].max(1e-9_f32);
38 (m / u).ln()
39 })
40 .sum::<f32>();
41 1.0 / (1.0 + (-log_odds).exp())
42}
43
44fn m_step(
47 batch: &ComparisonBatch,
48 posteriors: &[f32],
49 prev: &ModelParams,
50) -> ModelParams {
51 let n_fields = batch.n_fields;
52 let n_pairs = batch.n_pairs;
53
54 let mut m_num = vec![vec![0.0f32; N_LEVELS]; n_fields];
55 let mut u_num = vec![vec![0.0f32; N_LEVELS]; n_fields];
56
57 let mut total_match = 0.0f32;
58 let mut total_nonmatch = 0.0f32;
59
60 for p in 0..n_pairs {
61 total_match += posteriors[p];
62 total_nonmatch += 1.0 - posteriors[p];
63 }
64
65 for f in 0..n_fields {
69 let field_slice = &batch.levels[f * n_pairs..(f + 1) * n_pairs];
70 for p in 0..n_pairs {
71 let l_u8 = field_slice[p];
72 if l_u8 == 255 { continue; } let l = l_u8 as usize;
74 m_num[f][l] += posteriors[p];
75 u_num[f][l] += 1.0 - posteriors[p];
76 }
77 }
78
79 let total_match = total_match.max(1e-9);
80 let total_nonmatch = total_nonmatch.max(1e-9);
81
82 let mut m = vec![vec![1e-9f32; N_LEVELS]; n_fields];
83 let mut u = vec![vec![1e-9f32; N_LEVELS]; n_fields];
84
85 for f in 0..n_fields {
86 for l in 0..N_LEVELS {
87 m[f][l] = (m_num[f][l] / total_match).max(1e-9);
88 u[f][l] = (u_num[f][l] / total_nonmatch).max(1e-9);
89 }
90 let m_sum: f32 = m[f].iter().sum();
91 let u_sum: f32 = u[f].iter().sum();
92 for l in 0..N_LEVELS {
93 m[f][l] /= m_sum;
94 u[f][l] /= u_sum;
95 }
96 }
97
98 let lambda = (total_match / n_pairs as f32).max(0.001).min(0.999);
99 let log_prior = (lambda / (1.0 - lambda)).ln();
100
101 ModelParams {
102 m,
103 u,
104 log_prior_odds: log_prior,
105 upper_threshold: prev.upper_threshold,
106 lower_threshold: prev.lower_threshold,
107 }
108}
109
110fn params_delta(a: &ModelParams, b: &ModelParams) -> f32 {
113 let mut max_delta = 0.0f32;
114 for (am, bm) in a.m.iter().zip(b.m.iter()) {
115 for (&av, &bv) in am.iter().zip(bm.iter()) {
116 max_delta = max_delta.max((av - bv).abs());
117 }
118 }
119 for (au, bu) in a.u.iter().zip(b.u.iter()) {
120 for (&av, &bv) in au.iter().zip(bu.iter()) {
121 max_delta = max_delta.max((av - bv).abs());
122 }
123 }
124 max_delta
125}
126
127fn init_from_priors(n_fields: usize) -> ModelParams {
130 let m = vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields];
131 let u = vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields];
132 ModelParams {
133 m,
134 u,
135 log_prior_odds: 0.0,
136 upper_threshold: 0.9,
137 lower_threshold: 0.1,
138 }
139}
140
141pub fn estimate_lambda(batch: &ComparisonBatch) -> f32 {
145 if batch.n_pairs == 0 { return 0.01; }
146 let exact = ComparisonLevel::Exact as u8;
147 let n_pairs = batch.n_pairs;
148 let high_sim_count = (0..n_pairs)
149 .filter(|&p| {
150 (0..batch.n_fields).any(|f| batch.levels[f * n_pairs + p] == exact)
151 })
152 .count();
153 let raw = high_sim_count as f32 / n_pairs as f32;
154 raw.max(0.001).min(0.5)
155}
156
/// Choose match / non-match decision thresholds from a sample of posterior
/// scores. Returns `(upper, lower)`.
///
/// * `upper` – the 5th percentile of scores `>= 0.7`, never below `0.85`.
/// * `lower` – the 95th percentile of scores `<= 0.3`, never above `0.15`.
///
/// A tail with fewer than 10 scores is too thin to estimate from and falls
/// back to the default (`0.9` upper / `0.1` lower). An empty `scores` returns
/// `(0.9, 0.1)`. NaN scores fall in neither tail and are ignored.
pub fn auto_calibrate_thresholds(scores: &[f32]) -> (f32, f32) {
    // Minimum tail size before its percentile is trusted.
    const MIN_TAIL: usize = 10;

    /// Element at floor index `len * q` in `f32::total_cmp` order.
    /// Partial selection is O(n) and yields exactly the element a full sort
    /// would place there. `q` must be in `[0, 1)`.
    fn quantile(values: &mut [f32], q: f32) -> f32 {
        let idx = (values.len() as f32 * q) as usize;
        *values.select_nth_unstable_by(idx, f32::total_cmp).1
    }

    if scores.is_empty() {
        return (0.9, 0.1);
    }

    // The filtered tails are freshly owned, so they can be reordered in place.
    let mut high: Vec<f32> = scores.iter().copied().filter(|&s| s >= 0.7).collect();
    let mut low: Vec<f32> = scores.iter().copied().filter(|&s| s <= 0.3).collect();

    let upper = if high.len() >= MIN_TAIL {
        quantile(&mut high, 0.05).max(0.85)
    } else {
        0.9
    };

    let lower = if low.len() >= MIN_TAIL {
        quantile(&mut low, 0.95).min(0.15)
    } else {
        0.1
    };

    (upper, lower)
}
182
183pub fn run_em(
185 batch: &ComparisonBatch,
186 init: Option<ModelParams>,
187 max_iter: usize,
188) -> Result<ModelParams, ZerError> {
189 if batch.n_pairs == 0 {
190 return Err(ZerError::SchemaMismatch { expected: 1, got: 0 });
191 }
192
193 let n_fields = batch.n_fields;
194 if n_fields == 0 {
195 return Err(ZerError::EmptySchema);
196 }
197
198 let mut params = init.unwrap_or_else(|| {
199 let mut p = init_from_priors(n_fields);
200 let lambda = estimate_lambda(batch);
201 p.log_prior_odds = (lambda / (1.0 - lambda)).ln();
202 tracing::debug!(lambda, "auto-estimated prior match rate");
203 p
204 });
205
206 for iter in 0..max_iter {
207 let posteriors: Vec<f32> = (0..batch.n_pairs)
208 .map(|p| e_step_p(batch, p, ¶ms))
209 .collect();
210
211 let new_params = m_step(batch, &posteriors, ¶ms);
212 let delta = params_delta(¶ms, &new_params);
213
214 params = new_params;
215 tracing::debug!(iter, delta, "EM iteration");
216
217 if delta < 1e-6 {
218 tracing::info!(iter, "EM converged");
219 break;
220 }
221 }
222
223 Ok(params)
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector};

    /// A comparison vector in which every one of `n_fields` fields is at `level`.
    fn uniform_vector(id_a: u64, id_b: u64, n_fields: usize, level: ComparisonLevel) -> ComparisonVector {
        ComparisonVector::new(id_a, id_b, vec![level; n_fields])
    }

    /// `n_match` all-`Exact` pairs followed by `n_nonmatch` all-`None` pairs.
    fn synthetic_batch(n_match: usize, n_nonmatch: usize, n_fields: usize) -> ComparisonBatch {
        let mut vecs = Vec::with_capacity(n_match + n_nonmatch);
        for i in 0..n_match {
            vecs.push(uniform_vector(i as u64, (i + 1_000_000) as u64, n_fields, ComparisonLevel::Exact));
        }
        for i in 0..n_nonmatch {
            vecs.push(uniform_vector((i + 2_000_000) as u64, (i + 3_000_000) as u64, n_fields, ComparisonLevel::None));
        }
        ComparisonBatch::from_vectors(&vecs)
    }

    #[test]
    fn em_converges_on_synthetic_data() {
        let batch = synthetic_batch(200, 800, 4);
        let params = run_em(&batch, None, 100).expect("EM should succeed");
        for f in 0..4 {
            let exact_idx = ComparisonLevel::Exact as usize;
            assert!(
                params.m[f][exact_idx] > params.u[f][exact_idx],
                "m[Exact] should exceed u[Exact] for field {f}: m={}, u={}",
                params.m[f][exact_idx], params.u[f][exact_idx]
            );
        }
    }

    /// A warm start already near the truth should separate the classes within
    /// a handful of iterations. (Iteration counts are not compared.)
    #[test]
    fn em_warm_start_separates_classes_in_few_iterations() {
        let batch = synthetic_batch(200, 800, 3);

        let warm = ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.78]; 3],
            u: vec![vec![0.75, 0.12, 0.08, 0.05]; 3],
            log_prior_odds: (0.2_f32 / 0.8_f32).ln(),
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        };

        let params = run_em(&batch, Some(warm), 5).expect("warm start EM should succeed");
        for f in 0..3 {
            let exact_idx = ComparisonLevel::Exact as usize;
            assert!(params.m[f][exact_idx] > params.u[f][exact_idx],
                "warm-start: m[Exact] should exceed u[Exact] for field {f}");
        }
    }

    /// Every fitted m / u row must remain a probability distribution.
    #[test]
    fn em_rows_are_probability_distributions() {
        let batch = synthetic_batch(150, 850, 3);
        let params = run_em(&batch, None, 50).expect("EM should succeed");
        for (f, (m_row, u_row)) in params.m.iter().zip(&params.u).enumerate() {
            let m_sum: f32 = m_row.iter().sum();
            let u_sum: f32 = u_row.iter().sum();
            assert!((m_sum - 1.0).abs() < 1e-4, "m row {f} sums to {m_sum}");
            assert!((u_sum - 1.0).abs() < 1e-4, "u row {f} sums to {u_sum}");
        }
    }

    /// The single-vector and batched E-steps must agree. Asymmetric levels and
    /// per-field parameters make a batch-layout mismatch visible.
    #[test]
    fn e_step_matches_batched_e_step() {
        let vecs = vec![
            ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact, ComparisonLevel::None]),
            ComparisonVector::new(3, 4, vec![ComparisonLevel::None, ComparisonLevel::Exact]),
            ComparisonVector::new(5, 6, vec![ComparisonLevel::None, ComparisonLevel::None]),
        ];
        let batch = ComparisonBatch::from_vectors(&vecs);
        let params = ModelParams {
            m: vec![vec![0.1, 0.1, 0.1, 0.7], vec![0.2, 0.1, 0.1, 0.6]],
            u: vec![vec![0.7, 0.1, 0.1, 0.1], vec![0.6, 0.2, 0.1, 0.1]],
            log_prior_odds: -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        };
        for (p, v) in vecs.iter().enumerate() {
            let single = e_step(v, &params);
            let batched = e_step_p(&batch, p, &params);
            assert!(
                (single - batched).abs() < 1e-6,
                "pair {p}: e_step={single}, e_step_p={batched}"
            );
        }
    }

    #[test]
    fn em_empty_batch_returns_error() {
        let batch = ComparisonBatch::new(0, 0, vec![]);
        let result = run_em(&batch, None, 50);
        assert!(result.is_err(), "empty batch should return an error");
    }

    #[test]
    fn estimate_lambda_all_exact() {
        let batch = synthetic_batch(100, 0, 2);
        let lambda = estimate_lambda(&batch);
        // Raw rate is 1.0; the estimate is capped at 0.5.
        assert_eq!(lambda, 0.5);
    }

    #[test]
    fn estimate_lambda_all_none() {
        let batch = synthetic_batch(0, 100, 2);
        let lambda = estimate_lambda(&batch);
        // Raw rate is 0.0; the estimate is floored at 0.001.
        assert_eq!(lambda, 0.001);
    }

    #[test]
    fn auto_calibrate_bimodal_distribution() {
        let mut scores = vec![];
        for _ in 0..50 { scores.push(0.95_f32); }
        for _ in 0..200 { scores.push(0.05_f32); }
        let (upper, lower) = auto_calibrate_thresholds(&scores);
        assert!(upper >= 0.85, "upper threshold should be ≥ 0.85, got {upper}");
        assert!(lower <= 0.15, "lower threshold should be ≤ 0.15, got {lower}");
    }

    #[test]
    fn auto_calibrate_empty_returns_defaults() {
        let (upper, lower) = auto_calibrate_thresholds(&[]);
        assert_eq!(upper, 0.9);
        assert_eq!(lower, 0.1);
    }

    /// Tails with fewer than 10 scores are too thin to calibrate from.
    #[test]
    fn auto_calibrate_thin_tails_use_defaults() {
        let (upper, lower) = auto_calibrate_thresholds(&[0.99, 0.98, 0.01, 0.02, 0.5]);
        assert_eq!(upper, 0.9);
        assert_eq!(lower, 0.1);
    }
}