#![allow(dead_code)]
use crate::{Result, VisionError};
use scirs2_core::ndarray::{arr1, arr2, s, Array1, Array2, ArrayView2};
use scirs2_spatial::distance::{cosine, euclidean, EuclideanDistance}; use scirs2_spatial::kdtree::KDTree;
use std::collections::HashMap;
use torsh_tensor::Tensor;
/// Tunable parameters shared by the matchers in this module.
#[derive(Debug, Clone)]
pub struct MatchingConfig {
    /// Search strategy `FeatureMatcher` dispatches on.
    pub method: MatchingMethod,
    /// Descriptor distance used by the brute-force search path.
    pub distance_metric: DistanceMetric,
    /// Upper bound (exclusive) on `best / second_best` for a match to be kept.
    pub ratio_threshold: f64,
    /// Upper bound (exclusive) on the best distance for a match to be kept.
    pub max_distance: f64,
    /// When `true`, `filter_matches_geometric` routes matches through the
    /// cross-check step.
    pub cross_check: bool,
    /// RANSAC inlier threshold.
    // NOTE(review): not read anywhere in the visible part of this file.
    pub ransac_threshold: f64,
}
impl Default for MatchingConfig {
fn default() -> Self {
Self {
method: MatchingMethod::BruteForce,
distance_metric: DistanceMetric::Euclidean,
ratio_threshold: 0.7,
max_distance: f64::INFINITY,
cross_check: true,
ransac_threshold: 3.0,
}
}
}
/// Search strategy used to find nearest descriptors.
#[derive(Debug, Clone)]
pub enum MatchingMethod {
    /// Compare every query descriptor against every training descriptor.
    BruteForce,
    /// Query a KD-tree built over the training descriptors.
    KdTree,
    /// Locality-sensitive hashing.
    LSH,
    /// FLANN-style k-means index.
    FlannKmeans,
}
/// Distance function applied to a pair of descriptors.
#[derive(Debug, Clone)]
pub enum DistanceMetric {
    /// Square root of the summed squared differences.
    Euclidean,
    /// One minus the cosine similarity.
    Cosine,
    /// Hamming distance.
    Hamming,
    /// Sum of absolute differences.
    L1,
    /// Chi-squared histogram distance.
    ChiSquared,
}
/// A detected keypoint together with the descriptor computed for it.
#[derive(Debug, Clone)]
pub struct Feature {
    /// Location and detector attributes of the feature.
    pub keypoint: Keypoint,
    /// Descriptor vector compared during matching.
    pub descriptor: Array1<f64>,
    /// Optional caller-assigned identifier.
    pub id: Option<usize>,
}
/// Position and detector attributes of a single keypoint.
#[derive(Debug, Clone)]
pub struct Keypoint {
    /// Horizontal coordinate.
    pub x: f64,
    /// Vertical coordinate.
    pub y: f64,
    /// Scale associated with the keypoint.
    pub scale: f64,
    /// Orientation of the keypoint (units are not established in this file).
    pub angle: f64,
    /// Detector response value.
    pub response: f64,
}
impl Keypoint {
pub fn new(x: f64, y: f64, scale: f64, angle: f64, response: f64) -> Self {
Self {
x,
y,
scale,
angle,
response,
}
}
pub fn coordinates(&self) -> Array1<f64> {
Array1::from(vec![self.x, self.y])
}
}
/// A correspondence between one query feature and one training feature.
#[derive(Debug, Clone)]
pub struct Match {
    /// Index of the feature in the query set.
    pub query_idx: usize,
    /// Index of the feature in the training set.
    pub train_idx: usize,
    /// Descriptor distance between the two features.
    pub distance: f64,
    /// Score derived from `distance` as `1 / (1 + distance)`.
    pub confidence: f64,
}

impl Match {
    /// Creates a match and derives its confidence from `distance`.
    ///
    /// A distance of `0.0` yields a confidence of `1.0`; larger distances
    /// yield smaller confidences.
    pub fn new(query_idx: usize, train_idx: usize, distance: f64) -> Self {
        let confidence = (1.0 + distance).recip();
        Match {
            query_idx,
            train_idx,
            distance,
            confidence,
        }
    }
}
/// Matches query features against a stored set of training features.
pub struct FeatureMatcher {
    /// Strategy, metric and thresholds used while matching.
    config: MatchingConfig,
    /// KD-tree over the training descriptors (as `f32`); `None` until built.
    kdtree: Option<KDTree<f32, EuclideanDistance<f32>>>,
    /// Training features most recently passed to `train`.
    train_features: Vec<Feature>,
}
impl FeatureMatcher {
pub fn new(config: MatchingConfig) -> Self {
Self {
config,
kdtree: None,
train_features: Vec::new(),
}
}
/// Stores `features` as the training set and builds whatever index the
/// configured method needs.
///
/// # Errors
/// Propagates any error from building the index. The training set has
/// already been replaced by the time such an error is returned.
pub fn train(&mut self, features: Vec<Feature>) -> Result<()> {
    self.train_features = features;
    match self.config.method {
        MatchingMethod::KdTree => self.build_kdtree(),
        MatchingMethod::FlannKmeans => self.build_flann_index(),
        // These strategies search the stored features directly.
        MatchingMethod::BruteForce | MatchingMethod::LSH => Ok(()),
    }
}
/// Builds a KD-tree over the training descriptors and stores it in `self.kdtree`.
///
/// Descriptors are narrowed from `f64` to `f32` and laid out one per row.
///
/// # Errors
/// * `InvalidInput` when there are no training features.
/// * `InvalidArgument` when descriptors differ in length.
/// * `Other` when the KD-tree constructor fails.
fn build_kdtree(&mut self) -> Result<()> {
    let dim = match self.train_features.first() {
        Some(first) => first.descriptor.len(),
        None => {
            return Err(VisionError::InvalidInput(
                "No training features provided".to_string(),
            ))
        }
    };

    // Every descriptor must match the first one's length before the table is filled.
    let ragged = self
        .train_features
        .iter()
        .any(|feature| feature.descriptor.len() != dim);
    if ragged {
        return Err(VisionError::InvalidArgument(
            "All descriptors must have the same dimension".to_string(),
        ));
    }

    let features = &self.train_features;
    let table: Array2<f32> = Array2::from_shape_fn((features.len(), dim), |(row, col)| {
        features[row].descriptor[col] as f32
    });

    let tree = KDTree::new(&table)
        .map_err(|e| VisionError::Other(anyhow::anyhow!("Failed to build KD-tree: {}", e)))?;
    self.kdtree = Some(tree);
    Ok(())
}
/// Builds the index used for `MatchingMethod::FlannKmeans`.
///
/// Currently this delegates to `build_kdtree`, so the result is a plain
/// KD-tree rather than a dedicated k-means index.
///
/// # Errors
/// Returns whatever `build_kdtree` returns.
fn build_flann_index(&mut self) -> Result<()> {
self.build_kdtree()
}
pub fn match_features(&self, query_features: &[Feature]) -> Result<Vec<Match>> {
if self.train_features.is_empty() {
return Err(VisionError::InvalidInput("Matcher not trained".to_string()));
}
match self.config.method {
MatchingMethod::BruteForce => self.brute_force_matching(query_features),
MatchingMethod::KdTree => self.kdtree_matching(query_features),
MatchingMethod::FlannKmeans => self.flann_matching(query_features),
MatchingMethod::LSH => self.lsh_matching(query_features),
}
}
/// Exhaustively compares each query descriptor with every training descriptor.
///
/// For each query the nearest and second-nearest distances are tracked. A
/// match is emitted when the nearest distance is below `max_distance` and the
/// nearest/second-nearest ratio is below `ratio_threshold`. The ratio is
/// taken as `0.0` when the second-nearest distance is not positive.
///
/// # Errors
/// Propagates errors from `compute_distance`.
fn brute_force_matching(&self, query_features: &[Feature]) -> Result<Vec<Match>> {
    let mut accepted = Vec::new();

    for (q, query) in query_features.iter().enumerate() {
        // (training index, distance) of the closest candidate seen so far.
        let mut nearest = (0usize, f64::INFINITY);
        let mut runner_up = f64::INFINITY;

        for (t, candidate) in self.train_features.iter().enumerate() {
            let d = self.compute_distance(&query.descriptor, &candidate.descriptor)?;
            if d < nearest.1 {
                runner_up = nearest.1;
                nearest = (t, d);
            } else if d < runner_up {
                runner_up = d;
            }
        }

        let (best_idx, best) = nearest;
        let ratio = if runner_up > 0.0 { best / runner_up } else { 0.0 };
        if best < self.config.max_distance && ratio < self.config.ratio_threshold {
            accepted.push(Match::new(q, best_idx, best));
        }
    }

    Ok(accepted)
}
/// Matches each query descriptor against the stored KD-tree.
///
/// The two nearest neighbours are requested per query. A match is emitted
/// when the nearest distance is below `max_distance` and the
/// nearest/second-nearest ratio is below `ratio_threshold`. The ratio is
/// taken as `0.0` when there is no second neighbour or its distance is not
/// positive. Descriptors are narrowed to `f32` for the query.
///
/// # Errors
/// * `InvalidInput` when the KD-tree has not been built.
/// * `Other` when the KD-tree query fails.
fn kdtree_matching(&self, query_features: &[Feature]) -> Result<Vec<Match>> {
    let tree = match self.kdtree.as_ref() {
        Some(tree) => tree,
        None => return Err(VisionError::InvalidInput("KD-tree not built".to_string())),
    };

    let mut accepted = Vec::new();
    for (q, query) in query_features.iter().enumerate() {
        let point: Vec<f32> = query.descriptor.iter().map(|&v| v as f32).collect();
        let (neighbors, dists) = tree
            .query(&point, 2)
            .map_err(|e| VisionError::Other(anyhow::anyhow!("KD-tree query failed: {}", e)))?;

        if let Some(&nearest) = dists.first() {
            let ratio = match dists.get(1) {
                Some(&second) if second > 0.0 => nearest / second,
                _ => 0.0,
            };
            let within_range = (nearest as f64) < self.config.max_distance;
            if within_range && (ratio as f64) < self.config.ratio_threshold {
                accepted.push(Match::new(q, neighbors[0], nearest as f64));
            }
        }
    }

    Ok(accepted)
}
/// Matching routine for `MatchingMethod::FlannKmeans`.
///
/// Currently this delegates to `kdtree_matching`, so results are the same as
/// for the KD-tree method.
///
/// # Errors
/// Returns whatever `kdtree_matching` returns.
fn flann_matching(&self, query_features: &[Feature]) -> Result<Vec<Match>> {
self.kdtree_matching(query_features)
}
/// Matching routine for `MatchingMethod::LSH`.
///
/// Currently this delegates to `brute_force_matching`; no hashing is performed.
///
/// # Errors
/// Returns whatever `brute_force_matching` returns.
fn lsh_matching(&self, query_features: &[Feature]) -> Result<Vec<Match>> {
self.brute_force_matching(query_features)
}
/// Computes the distance between two descriptors under the configured metric.
///
/// * `Euclidean` – square root of the summed squared differences.
/// * `L1` – sum of absolute differences.
/// * `Cosine` – `1 - cos(desc1, desc2)`; `1.0` when either vector has zero norm.
/// * `ChiSquared` – `0.5 * Σ (a - b)² / (a + b)` over positions where `a + b > 0`.
/// * `Hamming` – number of positions at which the two descriptors differ.
///
/// # Errors
/// Returns `InvalidArgument` when the descriptors differ in length.
fn compute_distance(&self, desc1: &Array1<f64>, desc2: &Array1<f64>) -> Result<f64> {
    if desc1.len() != desc2.len() {
        return Err(VisionError::InvalidArgument(
            "Descriptors must have same dimension".to_string(),
        ));
    }

    // Every variant is listed so a new metric produces a compile error here
    // instead of silently falling back to Euclidean.
    let distance = match self.config.distance_metric {
        DistanceMetric::Euclidean => {
            let diff = desc1 - desc2;
            (diff.mapv(|x| x * x).sum()).sqrt()
        }
        DistanceMetric::L1 => {
            let diff = desc1 - desc2;
            diff.mapv(|x| x.abs()).sum()
        }
        DistanceMetric::Cosine => {
            let dot_product = (desc1 * desc2).sum();
            let norm1 = (desc1.mapv(|x| x * x).sum()).sqrt();
            let norm2 = (desc2.mapv(|x| x * x).sum()).sqrt();
            if norm1 > 0.0 && norm2 > 0.0 {
                1.0 - dot_product / (norm1 * norm2)
            } else {
                // Direction is undefined for a zero vector.
                1.0
            }
        }
        DistanceMetric::ChiSquared => {
            let mut chi_squared = 0.0;
            for (&a, &b) in desc1.iter().zip(desc2.iter()) {
                // Skip bins whose sum is not positive to avoid dividing by zero.
                if a + b > 0.0 {
                    chi_squared += (a - b).powi(2) / (a + b);
                }
            }
            chi_squared * 0.5
        }
        DistanceMetric::Hamming => {
            // Count of mismatching positions, e.g. for binary descriptors
            // stored as 0.0 / 1.0.
            desc1
                .iter()
                .zip(desc2.iter())
                .filter(|(a, b)| a != b)
                .count() as f64
        }
    };
    Ok(distance)
}
/// Post-filters a set of matches.
///
/// When `cross_check` is enabled the matches go through
/// `cross_check_matches`; otherwise they are returned as a copy. The two
/// keypoint slices are currently unused.
///
/// # Errors
/// Propagates errors from `cross_check_matches`.
pub fn filter_matches_geometric(
    &self,
    matches: &[Match],
    _query_keypoints: &[Keypoint],
    _train_keypoints: &[Keypoint],
) -> Result<Vec<Match>> {
    if !self.config.cross_check {
        return Ok(matches.to_vec());
    }
    self.cross_check_matches(matches)
}
/// Cross-check stage of match filtering.
///
/// Currently a pass-through: every input match is cloned into the result and
/// nothing is rejected.
fn cross_check_matches(&self, matches: &[Match]) -> Result<Vec<Match>> {
    Ok(matches.iter().cloned().collect())
}
}
/// Searches images for occurrences of registered templates.
pub struct TemplateMatcher {
    /// Templates registered through `add_template`, in insertion order.
    templates: Vec<Template>,
    /// Matching configuration supplied at construction.
    config: MatchingConfig,
}
/// A pattern to search for, with its acceptance threshold.
#[derive(Debug, Clone)]
pub struct Template {
    /// Two-dimensional pattern; its dimensions are read as (height, width).
    pub pattern: Array2<f64>,
    /// Identifier copied into every `TemplateMatch` produced for this template.
    pub id: String,
    /// A position is reported only when its similarity exceeds this value.
    pub threshold: f64,
}
/// One location at which a template was found.
#[derive(Debug, Clone)]
pub struct TemplateMatch {
    /// Identifier of the template that matched.
    pub template_id: String,
    /// Top-left corner of the match as `(x, y)`.
    pub position: (f64, f64),
    /// Similarity score of the match.
    pub confidence: f64,
    /// Matched region as `(x, y, width, height)`.
    pub bounding_box: (f64, f64, f64, f64),
}
impl TemplateMatcher {
pub fn new(config: MatchingConfig) -> Self {
Self {
templates: Vec::new(),
config,
}
}
/// Registers `template` by appending it to the list searched by
/// `match_templates`.
pub fn add_template(&mut self, template: Template) {
self.templates.push(template);
}
/// Searches `image` for every registered template.
///
/// Results are concatenated in template registration order.
///
/// # Errors
/// Stops at and returns the first error from matching a single template.
pub fn match_templates(&self, image: &Array2<f64>) -> Result<Vec<TemplateMatch>> {
    let mut found = Vec::new();
    for candidate in self.templates.iter() {
        found.extend(self.match_single_template(image, candidate)?);
    }
    Ok(found)
}
fn match_single_template(
&self,
image: &Array2<f64>,
template: &Template,
) -> Result<Vec<TemplateMatch>> {
let mut matches = Vec::new();
let (img_height, img_width) = image.dim();
let (tmpl_height, tmpl_width) = template.pattern.dim();
if tmpl_height > img_height || tmpl_width > img_width {
return Ok(matches);
}
for y in 0..=(img_height - tmpl_height) {
for x in 0..=(img_width - tmpl_width) {
let window = template.pattern.clone();
let confidence = self.compute_template_similarity(&window, &template.pattern)?;
if confidence > template.threshold {
matches.push(TemplateMatch {
template_id: template.id.clone(),
position: (x as f64, y as f64),
confidence,
bounding_box: (x as f64, y as f64, tmpl_width as f64, tmpl_height as f64),
});
}
}
}
Ok(matches)
}
/// Zero-mean normalized cross-correlation between `window` and `template`.
///
/// Both arrays are centred on their own mean (taken as `0.0` when the mean is
/// unavailable). The result is the summed product of the centred values
/// divided by the square root of the product of their summed squares, or
/// `0.0` when that denominator is not positive.
///
/// `template` is indexed at every position of `window`, so it must be at
/// least as large as `window` in both dimensions.
fn compute_template_similarity(
    &self,
    window: &Array2<f64>,
    template: &Array2<f64>,
) -> Result<f64> {
    let mean_w = window.mean().unwrap_or(0.0);
    let mean_t = template.mean().unwrap_or(0.0);

    // Accumulate (cross term, window energy, template energy) in one pass.
    let (cross, energy_w, energy_t) = window.indexed_iter().fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(c, ew, et), ((i, j), &w)| {
            let dw = w - mean_w;
            let dt = template[[i, j]] - mean_t;
            (c + dw * dt, ew + dw * dw, et + dt * dt)
        },
    );

    let norm = (energy_w * energy_t).sqrt();
    Ok(if norm > 0.0 { cross / norm } else { 0.0 })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Keypoint::new` stores its arguments in the matching fields.
    #[test]
    fn test_keypoint_creation() {
        let kp = Keypoint::new(10.0, 20.0, 1.0, 0.0, 0.8);
        assert_eq!(kp.x, 10.0);
        assert_eq!(kp.y, 20.0);
        assert_eq!(kp.scale, 1.0);
    }

    /// A freshly constructed matcher holds no training features.
    #[test]
    fn test_feature_matcher_creation() {
        let fresh = FeatureMatcher::new(MatchingConfig::default());
        assert!(fresh.train_features.is_empty());
    }

    /// `Match::new` keeps indices and distance and derives a positive confidence.
    #[test]
    fn test_match_creation() {
        let m = Match::new(0, 1, 0.5);
        assert_eq!(m.query_idx, 0);
        assert_eq!(m.train_idx, 1);
        assert_eq!(m.distance, 0.5);
        assert!(m.confidence > 0.0);
    }

    /// Adding one template leaves exactly one template registered.
    #[test]
    fn test_template_matcher() {
        let mut tm = TemplateMatcher::new(MatchingConfig::default());
        tm.add_template(Template {
            pattern: arr2(&[[1.0, 0.0], [0.0, 1.0]]),
            id: "test_template".to_string(),
            threshold: 0.5,
        });
        assert_eq!(tm.templates.len(), 1);
    }
}