use scirs2_core::ndarray::ArrayView2;
use scirs2_core::numeric::{Float, FromPrimitive};
use super::types::PatternMatch;
use crate::error::{NdimageError, NdimageResult};
/// Greedy non-maximum suppression over pattern matches.
///
/// Matches are visited in descending order of confidence (ties keep their input order,
/// since the sort is stable). A match is kept only if its intersection-over-union with
/// every already-kept match, as computed by [`calculate_overlap`], is at most
/// `overlap_threshold`.
///
/// Matches whose confidence is NaN are ordered after all others instead of causing a
/// panic during sorting.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
#[allow(dead_code)]
pub fn non_maximum_suppression(
    mut matches: Vec<PatternMatch>,
    overlap_threshold: f64,
) -> NdimageResult<Vec<PatternMatch>> {
    // Map NaN to -inf so the comparator below is a genuine total order; an inconsistent
    // comparator (or `expect` on `partial_cmp`) would otherwise panic on NaN input.
    let rank = |confidence: f64| {
        if confidence.is_nan() {
            f64::NEG_INFINITY
        } else {
            confidence
        }
    };
    matches.sort_by(|a, b| {
        rank(b.confidence)
            .partial_cmp(&rank(a.confidence))
            .expect("rank() never yields NaN, so partial_cmp always succeeds")
    });
    let mut kept_matches: Vec<PatternMatch> = Vec::with_capacity(matches.len());
    for current_match in matches {
        // Suppress the candidate if it overlaps any higher-confidence match already kept.
        let suppressed = kept_matches
            .iter()
            .any(|kept_match| calculate_overlap(&current_match, kept_match) > overlap_threshold);
        if !suppressed {
            kept_matches.push(current_match);
        }
    }
    Ok(kept_matches)
}
/// Intersection-over-union (IoU) of the bounding boxes of two matches.
///
/// Each box starts at `position` (`(y, x)`) and extends by `size` (`(h, w)`), with the end
/// exclusive. The result is `overlap_area / union_area`, or `0.0` when the union is empty.
#[allow(dead_code)]
pub fn calculate_overlap(match1: &PatternMatch, match2: &PatternMatch) -> f64 {
    let (y1, x1) = match1.position;
    let (h1, w1) = match1.size;
    let (y2, x2) = match2.position;
    let (h2, w2) = match2.size;
    // Overlap length of the 1-D spans [a, a + len_a) and [b, b + len_b). Unsigned
    // saturating arithmetic replaces `as i32` casts, which truncated coordinates above
    // i32::MAX and could overflow the subsequent subtraction.
    let span_overlap = |a: usize, len_a: usize, b: usize, len_b: usize| -> f64 {
        let end = a.saturating_add(len_a).min(b.saturating_add(len_b));
        end.saturating_sub(a.max(b)) as f64
    };
    let overlap_area = span_overlap(y1, h1, y2, h2) * span_overlap(x1, w1, x2, w2);
    // Multiply in f64 so very large boxes cannot overflow `usize`.
    let area1 = h1 as f64 * w1 as f64;
    let area2 = h2 as f64 * w2 as f64;
    let union_area = area1 + area2 - overlap_area;
    if union_area > 0.0 {
        overlap_area / union_area
    } else {
        0.0
    }
}
/// Placeholder feature-strength score for an image patch.
///
/// The patch contents are not inspected: the score depends only on `feature_type`
/// (`"edge"` 0.8, `"corner"` 0.6, `"texture"` 0.7, `"gradient"` 0.75, `"blob"` 0.65,
/// `"line"` 0.72). Any other name scores 0.5. Always returns `Ok`.
#[allow(dead_code)]
pub fn analyze_patch_for_feature<T>(
    _patch: &ArrayView2<T>,
    feature_type: &str,
) -> NdimageResult<f64>
where
    T: Float + FromPrimitive + Copy,
{
    const FEATURE_SCORES: [(&str, f64); 6] = [
        ("edge", 0.8),
        ("corner", 0.6),
        ("texture", 0.7),
        ("gradient", 0.75),
        ("blob", 0.65),
        ("line", 0.72),
    ];
    const DEFAULT_SCORE: f64 = 0.5;

    let score = FEATURE_SCORES
        .iter()
        .find(|(name, _)| *name == feature_type)
        .map_or(DEFAULT_SCORE, |&(_, value)| value);
    Ok(score)
}
/// Area of the intersection of two axis-aligned boxes given as `(y, x, height, width)`.
///
/// Boxes are end-exclusive, so boxes that merely touch have zero intersection. Returns
/// `0.0` when the boxes are disjoint.
#[allow(dead_code)]
pub fn calculate_intersection_area(
    box1: (usize, usize, usize, usize),
    box2: (usize, usize, usize, usize),
) -> f64 {
    let (y1, x1, h1, w1) = box1;
    let (y2, x2, h2, w2) = box2;
    // Overlap length of the 1-D spans [a, a + len_a) and [b, b + len_b). Saturating
    // unsigned arithmetic avoids the truncation/overflow of casting `usize` to `i32`.
    let span_overlap = |a: usize, len_a: usize, b: usize, len_b: usize| -> f64 {
        let end = a.saturating_add(len_a).min(b.saturating_add(len_b));
        end.saturating_sub(a.max(b)) as f64
    };
    span_overlap(y1, h1, y2, h2) * span_overlap(x1, w1, x2, w2)
}
/// Area of the union of two axis-aligned boxes given as `(y, x, height, width)`.
///
/// Computed as `area1 + area2 - intersection`, with the intersection taken from
/// [`calculate_intersection_area`].
#[allow(dead_code)]
pub fn calculate_union_area(
    box1: (usize, usize, usize, usize),
    box2: (usize, usize, usize, usize),
) -> f64 {
    let (_, _, h1, w1) = box1;
    let (_, _, h2, w2) = box2;
    // Multiply in f64: `h * w` in `usize` can overflow (and panic in debug builds).
    let area1 = h1 as f64 * w1 as f64;
    let area2 = h2 as f64 * w2 as f64;
    area1 + area2 - calculate_intersection_area(box1, box2)
}
/// Keeps only the matches whose confidence is at least `confidence_threshold`.
///
/// Surviving matches keep their original relative order. A NaN confidence never
/// satisfies the `>=` test, so such matches are dropped.
#[allow(dead_code)]
pub fn filter_matches_by_confidence(
    matches: Vec<PatternMatch>,
    confidence_threshold: f64,
) -> Vec<PatternMatch> {
    let mut survivors = matches;
    survivors.retain(|candidate| candidate.confidence >= confidence_threshold);
    survivors
}
/// Greedily clusters matches whose centres are close and merges each cluster.
///
/// Matches are scanned in input order. Each match not yet assigned to a cluster seeds a
/// new one and absorbs every later unassigned match whose centre distance to the *seed*
/// (via `calculate_match_distance`) is `<= distance_threshold`. Each cluster is then
/// collapsed with `create_merged_match`. Empty input is returned unchanged.
#[allow(dead_code)]
pub fn merge_nearby_matches(
    matches: Vec<PatternMatch>,
    distance_threshold: f64,
) -> Vec<PatternMatch> {
    if matches.is_empty() {
        return matches;
    }
    let total = matches.len();
    let mut assigned = vec![false; total];
    let mut result = Vec::new();
    for seed in 0..total {
        if assigned[seed] {
            continue;
        }
        assigned[seed] = true;
        let mut members = vec![seed];
        for candidate in (seed + 1)..total {
            if !assigned[candidate]
                && calculate_match_distance(&matches[seed], &matches[candidate])
                    <= distance_threshold
            {
                assigned[candidate] = true;
                members.push(candidate);
            }
        }
        result.push(create_merged_match(&matches, &members));
    }
    result
}
/// Euclidean distance between the centres of two matches' bounding boxes.
///
/// A centre is the box's `position` plus half its `size`, per axis.
#[allow(dead_code)]
fn calculate_match_distance(match1: &PatternMatch, match2: &PatternMatch) -> f64 {
    let centre_of = |m: &PatternMatch| -> (f64, f64) {
        (
            m.position.0 as f64 + m.size.0 as f64 / 2.0,
            m.position.1 as f64 + m.size.1 as f64 / 2.0,
        )
    };
    let (first_y, first_x) = centre_of(match1);
    let (second_y, second_x) = centre_of(match2);
    let (delta_y, delta_x) = (first_y - second_y, first_x - second_x);
    (delta_y * delta_y + delta_x * delta_x).sqrt()
}
/// Merges the matches selected by `cluster` into a single match.
///
/// The result's `position`/`size` is the bounding box enclosing every member. Its `label`
/// and `confidence` come from the member with the highest confidence; on ties, the
/// earliest member in `cluster` wins. A NaN confidence never beats a real one.
/// A single-member cluster returns a clone of that member.
///
/// # Panics
///
/// Panics if `cluster` is empty or holds an index out of bounds for `matches`.
#[allow(dead_code)]
fn create_merged_match(matches: &[PatternMatch], cluster: &[usize]) -> PatternMatch {
    if cluster.is_empty() {
        panic!("Cannot create merged match from empty cluster");
    }
    let seed = &matches[cluster[0]];
    if cluster.len() == 1 {
        return seed.clone();
    }
    // Seed both the bounding box and the best member from the first element. Starting
    // from a 0.0 confidence and an empty label would lose the label whenever every member
    // has confidence <= 0.
    let mut best = seed;
    let (mut min_y, mut min_x) = seed.position;
    let mut max_y = seed.position.0 + seed.size.0;
    let mut max_x = seed.position.1 + seed.size.1;
    for &idx in &cluster[1..] {
        let m = &matches[idx];
        let (y, x) = m.position;
        let (h, w) = m.size;
        min_y = min_y.min(y);
        min_x = min_x.min(x);
        max_y = max_y.max(y + h);
        max_x = max_x.max(x + w);
        if m.confidence > best.confidence || best.confidence.is_nan() {
            best = m;
        }
    }
    PatternMatch {
        label: best.label.clone(),
        confidence: best.confidence,
        position: (min_y, min_x),
        size: (max_y - min_y, max_x - min_x),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Shorthand constructor for a `PatternMatch` in tests.
    fn pm(
        label: &str,
        confidence: f64,
        position: (usize, usize),
        size: (usize, usize),
    ) -> PatternMatch {
        PatternMatch {
            label: label.to_string(),
            confidence,
            position,
            size,
        }
    }

    #[test]
    fn test_calculate_overlap() {
        let base = pm("test1", 0.9, (10, 10), (20, 20));
        let shifted = pm("test2", 0.8, (15, 15), (20, 20));
        let distant = pm("test3", 0.7, (50, 50), (10, 10));

        let partial = calculate_overlap(&base, &shifted);
        assert!(partial > 0.0);
        assert!(partial < 1.0);
        assert_eq!(calculate_overlap(&base, &distant), 0.0);
        // A box overlaps itself completely.
        assert_eq!(calculate_overlap(&base, &base), 1.0);
    }

    #[test]
    fn test_non_maximum_suppression() {
        let candidates = vec![
            pm("high_conf", 0.9, (10, 10), (20, 20)),
            pm("low_conf", 0.5, (15, 15), (20, 20)),
            pm("separate", 0.8, (50, 50), (20, 20)),
        ];
        let survivors = non_maximum_suppression(candidates, 0.3).expect("Operation failed");
        // "low_conf" has IoU 225 / 575 (about 0.39) with "high_conf", above 0.3, so it is dropped.
        assert_eq!(survivors.len(), 2);
        assert_eq!(survivors[0].label, "high_conf");
        assert_eq!(survivors[1].label, "separate");
    }

    #[test]
    fn test_analyze_patch_for_feature() {
        let patch = Array2::<f64>::zeros((8, 8));
        let view = patch.view();
        let cases = [("edge", 0.8), ("corner", 0.6), ("texture", 0.7), ("unknown", 0.5)];
        for (feature, expected) in cases {
            let strength = analyze_patch_for_feature(&view, feature).expect("Operation failed");
            assert_eq!(strength, expected);
        }
    }

    #[test]
    fn test_calculate_intersection_area() {
        let box1 = (10, 10, 20, 20);
        let box2 = (15, 15, 20, 20);
        let box3 = (50, 50, 10, 10);
        assert_eq!(calculate_intersection_area(box1, box2), 15.0 * 15.0);
        assert_eq!(calculate_intersection_area(box1, box3), 0.0);
    }

    #[test]
    fn test_calculate_union_area() {
        let box1 = (10, 10, 20, 20);
        let box2 = (15, 15, 20, 20);
        let expected_union = 400.0 + 400.0 - calculate_intersection_area(box1, box2);
        assert_eq!(calculate_union_area(box1, box2), expected_union);
    }

    #[test]
    fn test_filter_matches_by_confidence() {
        let candidates = vec![
            pm("high", 0.9, (0, 0), (10, 10)),
            pm("medium", 0.7, (20, 20), (10, 10)),
            pm("low", 0.3, (40, 40), (10, 10)),
        ];
        let kept = filter_matches_by_confidence(candidates, 0.6);
        let labels: Vec<&str> = kept.iter().map(|m| m.label.as_str()).collect();
        assert_eq!(labels, ["high", "medium"]);
    }

    #[test]
    fn test_calculate_match_distance() {
        let left = pm("test1", 0.9, (0, 0), (10, 10));
        let right = pm("test2", 0.8, (0, 10), (10, 10));
        // Centres are (5, 5) and (5, 15).
        assert_eq!(calculate_match_distance(&left, &right), 10.0);
    }

    #[test]
    fn test_merge_nearby_matches() {
        let candidates = vec![
            pm("close1", 0.9, (0, 0), (10, 10)),
            pm("close2", 0.8, (0, 5), (10, 10)),
            pm("far", 0.7, (50, 50), (10, 10)),
        ];
        // The two "close" matches are 5 apart and merge; "far" stays on its own.
        assert_eq!(merge_nearby_matches(candidates, 10.0).len(), 2);
    }

    #[test]
    fn test_create_merged_match() {
        let members = vec![
            pm("test1", 0.9, (0, 0), (10, 10)),
            pm("test2", 0.7, (5, 5), (10, 10)),
        ];
        let merged = create_merged_match(&members, &[0, 1]);
        assert_eq!(merged.label, "test1");
        assert_eq!(merged.confidence, 0.9);
        assert_eq!(merged.position, (0, 0));
        assert_eq!(merged.size, (15, 15));
    }

    #[test]
    #[should_panic]
    fn test_create_merged_match_empty_cluster() {
        let members: Vec<PatternMatch> = Vec::new();
        let cluster: Vec<usize> = Vec::new();
        create_merged_match(&members, &cluster);
    }
}