1use crate::error::{NdimageError, NdimageResult};
14use scirs2_core::ndarray::{s, Array2, Array3};
15use std::f64::consts::PI;
16
/// Similarity measure used by [`template_match`].
///
/// For the squared-difference variants a *lower* score is a better match;
/// for the correlation variants a *higher* score is a better match.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MatchMethod {
    /// Plain sum of squared differences `Σ (I - T)²`; `0.0` is a perfect match.
    SumSquaredDiff,
    /// Sum of squared differences divided by `sqrt(Σ I² · Σ T²)`; lower is better.
    NormalizedSumSquaredDiff,
    /// Uncentred normalised cross-correlation `Σ I·T / (‖I‖ · ‖T‖)`; higher is better.
    NormalizedCrossCorrelation,
    /// Zero-mean normalised correlation (correlation coefficient); higher is better.
    CoeffCorrelation,
}
33
34pub fn template_match(
52 image: &Array2<f64>,
53 template: &Array2<f64>,
54 method: MatchMethod,
55) -> NdimageResult<Array2<f64>> {
56 let (ih, iw) = image.dim();
57 let (th, tw) = template.dim();
58
59 if th == 0 || tw == 0 {
60 return Err(NdimageError::InvalidInput(
61 "Template must not be empty".into(),
62 ));
63 }
64 if th > ih || tw > iw {
65 return Err(NdimageError::InvalidInput(
66 "Template must not be larger than the image".into(),
67 ));
68 }
69
70 match method {
71 MatchMethod::SumSquaredDiff => ssd_map(image, template, false),
72 MatchMethod::NormalizedSumSquaredDiff => ssd_map(image, template, true),
73 MatchMethod::NormalizedCrossCorrelation => normalized_cross_correlation(image, template),
74 MatchMethod::CoeffCorrelation => coeff_correlation(image, template),
75 }
76}
77
78fn ssd_map(
83 image: &Array2<f64>,
84 template: &Array2<f64>,
85 normalize: bool,
86) -> NdimageResult<Array2<f64>> {
87 let (ih, iw) = image.dim();
88 let (th, tw) = template.dim();
89 let out_h = ih - th + 1;
90 let out_w = iw - tw + 1;
91
92 let template_ss: f64 = template.iter().map(|&v| v * v).sum();
94
95 let mut result = Array2::zeros((out_h, out_w));
96
97 for r in 0..out_h {
98 for c in 0..out_w {
99 let patch = image.slice(s![r..r + th, c..c + tw]);
100 let mut ssd = 0.0;
101 for (iv, tv) in patch.iter().zip(template.iter()) {
102 let d = iv - tv;
103 ssd += d * d;
104 }
105
106 if normalize {
107 let patch_ss: f64 = patch.iter().map(|&v| v * v).sum();
109 let denom = (patch_ss * template_ss).sqrt();
110 result[[r, c]] = if denom > 1e-12 { ssd / denom } else { 0.0 };
111 } else {
112 result[[r, c]] = ssd;
113 }
114 }
115 }
116
117 Ok(result)
118}
119
120pub fn normalized_cross_correlation(
136 image: &Array2<f64>,
137 template: &Array2<f64>,
138) -> NdimageResult<Array2<f64>> {
139 let (ih, iw) = image.dim();
140 let (th, tw) = template.dim();
141
142 if th == 0 || tw == 0 {
143 return Err(NdimageError::InvalidInput(
144 "Template must not be empty".into(),
145 ));
146 }
147 if th > ih || tw > iw {
148 return Err(NdimageError::InvalidInput(
149 "Template must not be larger than the image".into(),
150 ));
151 }
152
153 let out_h = ih - th + 1;
154 let out_w = iw - tw + 1;
155
156 let template_norm: f64 = template.iter().map(|&v| v * v).sum::<f64>().sqrt();
157
158 let mut result = Array2::zeros((out_h, out_w));
159
160 for r in 0..out_h {
161 for c in 0..out_w {
162 let patch = image.slice(s![r..r + th, c..c + tw]);
163 let cross: f64 = patch.iter().zip(template.iter()).map(|(a, b)| a * b).sum();
164 let patch_norm: f64 = patch.iter().map(|&v| v * v).sum::<f64>().sqrt();
165 let denom = patch_norm * template_norm;
166 result[[r, c]] = if denom > 1e-12 { cross / denom } else { 0.0 };
167 }
168 }
169
170 Ok(result)
171}
172
173fn coeff_correlation(image: &Array2<f64>, template: &Array2<f64>) -> NdimageResult<Array2<f64>> {
178 let (ih, iw) = image.dim();
179 let (th, tw) = template.dim();
180 let out_h = ih - th + 1;
181 let out_w = iw - tw + 1;
182 let n = (th * tw) as f64;
183
184 let t_mean: f64 = template.iter().sum::<f64>() / n;
186 let t_centered: Vec<f64> = template.iter().map(|&v| v - t_mean).collect();
187 let t_std: f64 = t_centered.iter().map(|&v| v * v).sum::<f64>().sqrt();
188
189 let mut result = Array2::zeros((out_h, out_w));
190
191 for r in 0..out_h {
192 for c in 0..out_w {
193 let patch = image.slice(s![r..r + th, c..c + tw]);
194 let p_mean: f64 = patch.iter().sum::<f64>() / n;
195 let cross: f64 = patch
196 .iter()
197 .zip(t_centered.iter())
198 .map(|(a, b)| (a - p_mean) * b)
199 .sum();
200 let p_std: f64 = patch
201 .iter()
202 .map(|&v| (v - p_mean).powi(2))
203 .sum::<f64>()
204 .sqrt();
205 let denom = p_std * t_std;
206 result[[r, c]] = if denom > 1e-12 { cross / denom } else { 0.0 };
207 }
208 }
209
210 Ok(result)
211}
212
213pub fn find_matches(
234 correlation_map: &Array2<f64>,
235 threshold: f64,
236 min_distance: usize,
237) -> NdimageResult<Vec<(usize, usize, f64)>> {
238 let (rows, cols) = correlation_map.dim();
239 if rows == 0 || cols == 0 {
240 return Ok(Vec::new());
241 }
242
243 let mut candidates: Vec<(usize, usize, f64)> = correlation_map
245 .indexed_iter()
246 .filter_map(|((r, c), &score)| {
247 if score >= threshold {
248 Some((r, c, score))
249 } else {
250 None
251 }
252 })
253 .collect();
254
255 candidates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
257
258 let mut accepted: Vec<(usize, usize, f64)> = Vec::new();
261 let min_dist_sq = (min_distance as f64) * (min_distance as f64);
262
263 'outer: for (r, c, score) in candidates {
264 for &(ar, ac, _) in &accepted {
265 let dr = r as f64 - ar as f64;
266 let dc = c as f64 - ac as f64;
267 if dr * dr + dc * dc < min_dist_sq {
268 continue 'outer;
269 }
270 }
271 accepted.push((r, c, score));
272 }
273
274 Ok(accepted)
275}
276
277fn downsample_2x(image: &Array2<f64>) -> Array2<f64> {
283 let (h, w) = image.dim();
284 let oh = h / 2;
285 let ow = w / 2;
286 if oh == 0 || ow == 0 {
287 return image.clone();
288 }
289 let mut out = Array2::zeros((oh, ow));
290 for r in 0..oh {
291 for c in 0..ow {
292 out[[r, c]] = 0.25
293 * (image[[2 * r, 2 * c]]
294 + image[[2 * r, 2 * c + 1]]
295 + image[[2 * r + 1, 2 * c]]
296 + image[[2 * r + 1, 2 * c + 1]]);
297 }
298 }
299 out
300}
301
302pub fn pyramid_template_match(
322 image: &Array2<f64>,
323 template: &Array2<f64>,
324 n_scales: usize,
325) -> NdimageResult<Vec<(usize, usize, f64, f64)>> {
326 if n_scales == 0 {
327 return Err(NdimageError::InvalidInput(
328 "n_scales must be at least 1".into(),
329 ));
330 }
331 if template.dim().0 == 0 || template.dim().1 == 0 {
332 return Err(NdimageError::InvalidInput(
333 "Template must not be empty".into(),
334 ));
335 }
336
337 let (th, tw) = template.dim();
338 let mut results: Vec<(usize, usize, f64, f64)> = Vec::new();
339
340 let mut current_image = image.clone();
341 let mut current_template = template.clone();
342 let mut scale = 1.0_f64;
343
344 for _lvl in 0..n_scales {
345 let (ih, iw) = current_image.dim();
346 let (cth, ctw) = current_template.dim();
347
348 if cth == 0 || ctw == 0 || cth > ih || ctw > iw {
350 break;
351 }
352
353 let ncc = normalized_cross_correlation(¤t_image, ¤t_template)?;
354
355 let threshold = 0.5;
357 let min_dist = (th.max(tw) / 2).max(1);
358 let local_matches = find_matches(&ncc, threshold, min_dist)?;
359
360 for (r, c, score) in local_matches {
361 let orig_r = (r as f64 / scale).round() as usize;
363 let orig_c = (c as f64 / scale).round() as usize;
364 results.push((orig_r, orig_c, score, scale));
365 }
366
367 current_image = downsample_2x(¤t_image);
369 current_template = downsample_2x(¤t_template);
370 scale *= 0.5;
371 }
372
373 results.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
375
376 let nms_dist: usize = (th.max(tw) / 2).max(1);
378 let min_dist_sq = (nms_dist as f64).powi(2);
379 let mut accepted: Vec<(usize, usize, f64, f64)> = Vec::new();
380
381 'outer: for (r, c, score, s) in results {
382 for &(ar, ac, _, _) in &accepted {
383 let dr = r as f64 - ar as f64;
384 let dc = c as f64 - ac as f64;
385 if dr * dr + dc * dc < min_dist_sq {
386 continue 'outer;
387 }
388 }
389 accepted.push((r, c, score, s));
390 }
391
392 Ok(accepted)
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a `rows × cols` checkerboard: `1.0` where `(r + c)` is even, else `0.0`.
    fn checkerboard_image(rows: usize, cols: usize) -> Array2<f64> {
        Array2::from_shape_fn((rows, cols), |(r, c)| {
            if (r + c) % 2 == 0 {
                1.0
            } else {
                0.0
            }
        })
    }

    #[test]
    fn test_ssd_perfect_match() {
        let image: Array2<f64> = Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0,
            ],
        )
        .expect("shape ok");

        let template: Array2<f64> =
            Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).expect("shape ok");

        let map = template_match(&image, &template, MatchMethod::SumSquaredDiff).expect("ssd ok");
        assert!(
            map[[0, 0]] < 1e-12,
            "Expected zero SSD at perfect-match location"
        );
    }

    #[test]
    fn test_output_shape_all_methods() {
        let image = checkerboard_image(6, 6);
        let template: Array2<f64> = Array2::ones((2, 3));
        for &method in &[
            MatchMethod::SumSquaredDiff,
            MatchMethod::NormalizedSumSquaredDiff,
            MatchMethod::NormalizedCrossCorrelation,
            MatchMethod::CoeffCorrelation,
        ] {
            let map = template_match(&image, &template, method).expect("match ok");
            assert_eq!(map.dim(), (5, 4), "wrong output shape for {method:?}");
        }
    }

    #[test]
    fn test_nssd_perfect_match() {
        let image = checkerboard_image(4, 4);
        let template = image.slice(s![0..2, 0..2]).to_owned();
        let map = template_match(&image, &template, MatchMethod::NormalizedSumSquaredDiff)
            .expect("nssd ok");
        assert!(map[[0, 0]] < 1e-12, "Expected zero NSSD, got {}", map[[0, 0]]);
    }

    #[test]
    fn test_ncc_perfect_match() {
        let img = checkerboard_image(6, 6);
        let tpl = img.slice(s![1..3, 1..3]).to_owned();
        let ncc = normalized_cross_correlation(&img, &tpl).expect("ncc ok");
        let score = ncc[[1, 1]];
        assert!(
            score > 0.99,
            "NCC at matching position should be ~1, got {score}"
        );
    }

    #[test]
    fn test_coeff_correlation_perfect_and_inverted() {
        let img = checkerboard_image(6, 6);
        let tpl = img.slice(s![1..3, 1..3]).to_owned();
        let map = template_match(&img, &tpl, MatchMethod::CoeffCorrelation).expect("ccoeff ok");
        // Same phase as the template -> +1; opposite phase -> -1.
        assert!((map[[1, 1]] - 1.0).abs() < 1e-12, "got {}", map[[1, 1]]);
        assert!((map[[0, 1]] + 1.0).abs() < 1e-12, "got {}", map[[0, 1]]);
    }

    #[test]
    fn test_find_matches_basic() {
        let mut map: Array2<f64> = Array2::zeros((10, 10));
        map[[2, 3]] = 0.9;
        map[[7, 8]] = 0.8;
        // Within `min_distance` of the stronger (2, 3) peak, so it must be suppressed.
        map[[2, 4]] = 0.85;

        let matches = find_matches(&map, 0.7, 3).expect("matches ok");
        assert_eq!(matches, vec![(2, 3, 0.9), (7, 8, 0.8)]);
    }

    #[test]
    fn test_find_matches_empty_map() {
        let map: Array2<f64> = Array2::zeros((0, 5));
        assert!(find_matches(&map, 0.0, 1).expect("ok").is_empty());
    }

    #[test]
    fn test_downsample_averages_blocks_and_drops_odd_edge() {
        let image: Array2<f64> =
            Array2::from_shape_vec((3, 3), (1..=9).map(|v| v as f64).collect())
                .expect("shape ok");
        let half = downsample_2x(&image);
        assert_eq!(half.dim(), (1, 1));
        // (1 + 2 + 4 + 5) / 4
        assert!((half[[0, 0]] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_pyramid_match_runs() {
        let image = checkerboard_image(32, 32);
        let template: Array2<f64> = image.slice(s![4..8, 4..8]).to_owned();
        let results = pyramid_template_match(&image, &template, 3).expect("pyramid ok");
        assert!(!results.is_empty());
        assert!(results.iter().all(|m| m.2 >= 0.5));
        assert!(results.windows(2).all(|w| w[0].2 >= w[1].2));
    }

    #[test]
    fn test_pyramid_rejects_zero_scales() {
        let image = checkerboard_image(8, 8);
        let template = image.slice(s![0..2, 0..2]).to_owned();
        assert!(pyramid_template_match(&image, &template, 0).is_err());
    }

    #[test]
    fn test_template_larger_than_image_errors() {
        let small: Array2<f64> = Array2::zeros((3, 3));
        let large: Array2<f64> = Array2::zeros((5, 5));
        assert!(template_match(&small, &large, MatchMethod::SumSquaredDiff).is_err());
        assert!(normalized_cross_correlation(&small, &large).is_err());
    }

    #[test]
    fn test_empty_template_errors() {
        let image = checkerboard_image(4, 4);
        let empty: Array2<f64> = Array2::zeros((0, 2));
        assert!(template_match(&image, &empty, MatchMethod::SumSquaredDiff).is_err());
        assert!(normalized_cross_correlation(&image, &empty).is_err());
        assert!(pyramid_template_match(&image, &empty, 1).is_err());
    }
}