use crate::core::{Feature, ColmapError};
use std::collections::HashMap;
/// A putative correspondence between one query feature and one train feature.
#[derive(Debug, Clone)]
pub struct FeatureMatch {
    /// Index into the query slice (the first argument of `match_features`).
    pub query_idx: usize,
    /// Index into the train slice (the second argument of `match_features`).
    pub train_idx: usize,
    /// Descriptor distance between the two features; smaller is better.
    pub distance: f64,
    /// Score derived from `distance` as `1 / (1 + distance)`.
    pub confidence: f64,
}
impl FeatureMatch {
    /// Builds a match and derives its confidence from `distance`.
    ///
    /// The confidence is `1 / (1 + distance)`: a zero distance yields `1.0`, and for
    /// non-negative distances the score shrinks as the distance grows.
    pub fn new(query_idx: usize, train_idx: usize, distance: f64) -> Self {
        let confidence = (1.0 + distance).recip();
        Self {
            query_idx,
            train_idx,
            distance,
            confidence,
        }
    }

    /// Returns `true` when the match distance does not exceed `max_distance`.
    ///
    /// A NaN on either side makes the comparison (and thus the result) false.
    pub fn is_valid(&self, max_distance: f64) -> bool {
        max_distance >= self.distance
    }
}
/// Common interface for descriptor matchers that pair features across two sets.
///
/// Implementors must be `Send + Sync`.
pub trait FeatureMatcher: Send + Sync {
    /// Matches `features1` (queries) against `features2` (train set).
    ///
    /// In the implementations in this module, each returned [`FeatureMatch`] stores
    /// an index into `features1` as `query_idx` and an index into `features2` as
    /// `train_idx`.
    fn match_features(
        &self,
        features1: &[Feature],
        features2: &[Feature],
    ) -> Result<Vec<FeatureMatch>, ColmapError>;

    /// Identifier of the matching strategy.
    fn name(&self) -> &str;

    /// Snapshot of the tunable numeric parameters, keyed by parameter name.
    fn params(&self) -> HashMap<String, f64>;

    /// Updates tunable parameters by name.
    ///
    /// The implementations in this module return an error for parameter names they
    /// do not recognise.
    fn set_params(&mut self, params: HashMap<String, f64>) -> Result<(), ColmapError>;
}
/// Matching strategy selected by [`MatcherFactory::create`].
///
/// The enum is fieldless, so it derives `Copy`, `Eq` and `Hash`. It can be passed
/// by value, compared exactly, and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatcherType {
    /// Exhaustive nearest-neighbour search (`BruteForceMatcher`).
    BruteForce,
    /// `FlannMatcher`. NOTE(review): its current implementation is also an exhaustive scan.
    Flann,
    /// Nearest neighbour filtered by Lowe's ratio test (`RatioTestMatcher`).
    RatioTest,
    /// Mutual nearest neighbours only (`CrossCheckMatcher`).
    CrossCheck,
}
/// Metric used to compare descriptor bytes.
///
/// The enum is fieldless, so it derives `Copy`, `Eq` and `Hash`. It can be passed
/// by value, compared exactly, and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceType {
    /// Euclidean distance.
    L2,
    /// Sum of absolute differences.
    L1,
    /// Number of differing bits.
    Hamming,
    /// `1 - cosine similarity`.
    Cosine,
}
/// Configuration shared by every matcher built by [`MatcherFactory`].
#[derive(Debug, Clone)]
pub struct MatcherConfig {
    /// Which matcher implementation [`MatcherFactory::create`] builds.
    pub matcher_type: MatcherType,
    /// Metric used to compare descriptors.
    pub distance_type: DistanceType,
    /// Largest descriptor distance a match may have and still be accepted.
    pub max_distance: f64,
    /// Used by `RatioTestMatcher`: best distance / second-best distance must be
    /// strictly below this value.
    pub ratio_threshold: f64,
    /// NOTE(review): not read by any matcher in this module. Mutual checking is
    /// selected via `MatcherType::CrossCheck` instead.
    pub cross_check: bool,
    /// Maximum number of matches returned; the closest matches are kept.
    pub max_matches: usize,
    /// Minimum `FeatureMatch::confidence` a match needs to be kept.
    pub min_confidence: f64,
}
impl Default for MatcherConfig {
    /// Brute-force matching with the L2 metric and conservative thresholds.
    ///
    /// NOTE(review): descriptors are byte vectors, so L2/L1/Hamming distances between
    /// differing descriptors are at least 1.0. With `max_distance: 0.7`, only
    /// identical descriptors are accepted under the default L2 metric. Confirm this
    /// is intended.
    fn default() -> Self {
        Self {
            max_distance: 0.7,
            ratio_threshold: 0.8,
            min_confidence: 0.1,
            max_matches: 1000,
            cross_check: true,
            distance_type: DistanceType::L2,
            matcher_type: MatcherType::BruteForce,
        }
    }
}
/// Builds boxed [`FeatureMatcher`] implementations from a [`MatcherConfig`].
pub struct MatcherFactory;
impl MatcherFactory {
    /// Instantiates the matcher selected by `config.matcher_type` and hands it a copy
    /// of `config`. `config.cross_check` is not consulted here.
    ///
    /// # Errors
    /// Propagates any error returned by the selected matcher's constructor.
    pub fn create(config: &MatcherConfig) -> Result<Box<dyn FeatureMatcher>, ColmapError> {
        let matcher: Box<dyn FeatureMatcher> = match config.matcher_type {
            MatcherType::BruteForce => Box::new(BruteForceMatcher::new(config)?),
            MatcherType::Flann => Box::new(FlannMatcher::new(config)?),
            MatcherType::RatioTest => Box::new(RatioTestMatcher::new(config)?),
            MatcherType::CrossCheck => Box::new(CrossCheckMatcher::new(config)?),
        };
        Ok(matcher)
    }
}
/// Exhaustive matcher: every query descriptor is compared with every train descriptor.
pub struct BruteForceMatcher {
    config: MatcherConfig,
}
impl BruteForceMatcher {
    /// Creates a matcher that owns a copy of `config`.
    ///
    /// # Errors
    /// Never fails at present; the `Result` matches the other matchers' constructors.
    pub fn new(config: &MatcherConfig) -> Result<Self, ColmapError> {
        Ok(Self {
            config: config.clone(),
        })
    }

    /// Distance between two descriptors under `self.config.distance_type`.
    ///
    /// Bytes are paired position by position with `zip`, so when the lengths differ
    /// only the common prefix contributes. The matchers in this module skip such
    /// pairs before calling.
    ///
    /// * `L2`: Euclidean distance.
    /// * `L1`: sum of absolute byte differences.
    /// * `Hamming`: number of differing bits.
    /// * `Cosine`: `1 - cos θ`, clamped to `[0, 2]`. Returns `1.0` if either
    ///   descriptor is all zeros.
    fn compute_distance(&self, desc1: &[u8], desc2: &[u8]) -> f64 {
        match self.config.distance_type {
            DistanceType::L2 => desc1
                .iter()
                .zip(desc2)
                .map(|(&a, &b)| (f64::from(a) - f64::from(b)).powi(2))
                .sum::<f64>()
                .sqrt(),
            DistanceType::L1 => desc1
                .iter()
                .zip(desc2)
                .map(|(&a, &b)| (f64::from(a) - f64::from(b)).abs())
                .sum::<f64>(),
            DistanceType::Hamming => desc1
                .iter()
                .zip(desc2)
                .map(|(&a, &b)| f64::from((a ^ b).count_ones()))
                .sum::<f64>(),
            DistanceType::Cosine => {
                let dot_product: f64 = desc1
                    .iter()
                    .zip(desc2)
                    .map(|(&a, &b)| f64::from(a) * f64::from(b))
                    .sum();
                let norm1 = desc1.iter().map(|&x| f64::from(x).powi(2)).sum::<f64>().sqrt();
                let norm2 = desc2.iter().map(|&x| f64::from(x).powi(2)).sum::<f64>().sqrt();
                if norm1 == 0.0 || norm2 == 0.0 {
                    // The direction of an all-zero descriptor is undefined.
                    1.0
                } else {
                    // Rounding in `norm1 * norm2` can push the cosine slightly above 1
                    // for parallel descriptors: `[1, 1, 1]` against itself gives about
                    // -2.2e-16, which would make `FeatureMatch::confidence` exceed 1.
                    // Clamp back into the mathematically valid range.
                    (1.0 - dot_product / (norm1 * norm2)).clamp(0.0, 2.0)
                }
            }
        }
    }
}
impl FeatureMatcher for BruteForceMatcher {
    /// Pairs each feature in `features1` with its nearest neighbour in `features2`.
    ///
    /// Rules applied:
    /// * Query features with an empty descriptor are skipped.
    /// * Train features whose descriptor is empty or differs in length from the
    ///   query's are ignored.
    /// * A candidate must lie within `max_distance`; ties keep the lowest train index.
    /// * Matches below `min_confidence` are dropped.
    /// * The remaining matches are sorted by ascending distance (ties keep query
    ///   order) and truncated to `max_matches`.
    ///
    /// # Errors
    /// Never returns an error at present.
    fn match_features(
        &self,
        features1: &[Feature],
        features2: &[Feature],
    ) -> Result<Vec<FeatureMatch>, ColmapError> {
        let mut matches = Vec::new();
        for (i, feat1) in features1.iter().enumerate() {
            if feat1.descriptor.is_empty() {
                continue;
            }
            let mut best_distance = f64::INFINITY;
            let mut best_idx = None;
            for (j, feat2) in features2.iter().enumerate() {
                if feat2.descriptor.is_empty() || feat1.descriptor.len() != feat2.descriptor.len() {
                    continue;
                }
                let distance = self.compute_distance(&feat1.descriptor, &feat2.descriptor);
                if distance < best_distance && distance <= self.config.max_distance {
                    best_distance = distance;
                    best_idx = Some(j);
                }
            }
            if let Some(j) = best_idx {
                let match_result = FeatureMatch::new(i, j, best_distance);
                if match_result.confidence >= self.config.min_confidence {
                    matches.push(match_result);
                }
            }
        }
        // `total_cmp` is a total order, so the sort cannot panic even on a NaN distance.
        matches.sort_by(|a, b| a.distance.total_cmp(&b.distance));
        matches.truncate(self.config.max_matches);
        Ok(matches)
    }

    fn name(&self) -> &str {
        "BruteForce"
    }

    /// Reports `max_distance`, `max_matches` (as `f64`) and `min_confidence`.
    fn params(&self) -> HashMap<String, f64> {
        let mut params = HashMap::new();
        params.insert("max_distance".to_string(), self.config.max_distance);
        params.insert("max_matches".to_string(), self.config.max_matches as f64);
        params.insert("min_confidence".to_string(), self.config.min_confidence);
        params
    }

    /// Updates `max_distance`, `max_matches` and/or `min_confidence`.
    ///
    /// The update is all-or-nothing: if any entry is rejected, the configuration is
    /// left unchanged.
    ///
    /// # Errors
    /// Returns `ColmapError::InvalidParameter` for an unknown key, a NaN value, or a
    /// negative `max_matches`.
    fn set_params(&mut self, params: HashMap<String, f64>) -> Result<(), ColmapError> {
        /// Rejects NaN, which would make every threshold comparison false and
        /// silently suppress all matches.
        fn not_nan(key: &str, value: f64) -> Result<f64, ColmapError> {
            if value.is_nan() {
                Err(ColmapError::InvalidParameter(format!(
                    "Parameter {} must not be NaN",
                    key
                )))
            } else {
                Ok(value)
            }
        }

        // Validate into a staged copy and commit only if every entry is accepted.
        // HashMap iteration order is unspecified, so updating `self.config` in place
        // would leave an unpredictable subset applied whenever an entry is rejected.
        let mut staged = self.config.clone();
        for (key, value) in params {
            match key.as_str() {
                "max_distance" => staged.max_distance = not_nan(&key, value)?,
                "max_matches" => {
                    let value = not_nan(&key, value)?;
                    if value < 0.0 {
                        return Err(ColmapError::InvalidParameter(format!(
                            "Parameter {} must be non-negative, got {}",
                            key, value
                        )));
                    }
                    // Float-to-int `as` saturates: +inf means "no limit", fractions truncate.
                    staged.max_matches = value as usize;
                }
                "min_confidence" => staged.min_confidence = not_nan(&key, value)?,
                _ => return Err(ColmapError::InvalidParameter(format!("Unknown parameter: {}", key))),
            }
        }
        self.config = staged;
        Ok(())
    }
}
/// Matcher registered under the name "FLANN".
///
/// NOTE(review): despite the name, the `FeatureMatcher` impl for this type performs
/// the same exhaustive nearest-neighbour scan as `BruteForceMatcher`. No approximate
/// index is built.
pub struct FlannMatcher {
    config: MatcherConfig,
}
impl FlannMatcher {
    /// Creates a matcher holding its own copy of `config`.
    ///
    /// # Errors
    /// Never fails at present.
    pub fn new(config: &MatcherConfig) -> Result<Self, ColmapError> {
        let config = config.clone();
        Ok(Self { config })
    }
}
impl FeatureMatcher for FlannMatcher {
    /// Pairs each feature in `features1` with its nearest neighbour in `features2`.
    ///
    /// NOTE(review): this is an exhaustive O(n·m) scan, not an approximate FLANN search.
    ///
    /// Rules applied:
    /// * Query features with an empty descriptor are skipped.
    /// * Train features whose descriptor is empty or differs in length from the
    ///   query's are ignored.
    /// * A candidate must lie within `max_distance`; ties keep the lowest train index.
    /// * Matches below `min_confidence` are dropped.
    /// * The remaining matches are sorted by ascending distance and truncated to
    ///   `max_matches`.
    ///
    /// # Errors
    /// Never returns an error at present.
    fn match_features(
        &self,
        features1: &[Feature],
        features2: &[Feature],
    ) -> Result<Vec<FeatureMatch>, ColmapError> {
        let mut matches = Vec::new();
        for (i, feat1) in features1.iter().enumerate() {
            if feat1.descriptor.is_empty() {
                continue;
            }
            let mut best_distance = f64::INFINITY;
            let mut best_idx = None;
            for (j, feat2) in features2.iter().enumerate() {
                if feat2.descriptor.is_empty() || feat1.descriptor.len() != feat2.descriptor.len() {
                    continue;
                }
                let distance = self.compute_distance(&feat1.descriptor, &feat2.descriptor);
                if distance < best_distance && distance <= self.config.max_distance {
                    best_distance = distance;
                    best_idx = Some(j);
                }
            }
            if let Some(j) = best_idx {
                let match_result = FeatureMatch::new(i, j, best_distance);
                if match_result.confidence >= self.config.min_confidence {
                    matches.push(match_result);
                }
            }
        }
        // `total_cmp` is a total order, so the sort cannot panic even on a NaN distance.
        matches.sort_by(|a, b| a.distance.total_cmp(&b.distance));
        matches.truncate(self.config.max_matches);
        Ok(matches)
    }

    fn name(&self) -> &str {
        "FLANN"
    }

    /// Reports `max_distance`, `max_matches` (as `f64`) and `min_confidence`.
    fn params(&self) -> HashMap<String, f64> {
        let mut params = HashMap::new();
        params.insert("max_distance".to_string(), self.config.max_distance);
        params.insert("max_matches".to_string(), self.config.max_matches as f64);
        params.insert("min_confidence".to_string(), self.config.min_confidence);
        params
    }

    /// Updates `max_distance`, `max_matches` and/or `min_confidence`.
    ///
    /// The update is all-or-nothing: if any entry is rejected, the configuration is
    /// left unchanged.
    ///
    /// # Errors
    /// Returns `ColmapError::InvalidParameter` for an unknown key, a NaN value, or a
    /// negative `max_matches`.
    fn set_params(&mut self, params: HashMap<String, f64>) -> Result<(), ColmapError> {
        /// Rejects NaN, which would make every threshold comparison false and
        /// silently suppress all matches.
        fn not_nan(key: &str, value: f64) -> Result<f64, ColmapError> {
            if value.is_nan() {
                Err(ColmapError::InvalidParameter(format!(
                    "Parameter {} must not be NaN",
                    key
                )))
            } else {
                Ok(value)
            }
        }

        // Validate into a staged copy and commit only if every entry is accepted.
        // HashMap iteration order is unspecified, so updating `self.config` in place
        // would leave an unpredictable subset applied whenever an entry is rejected.
        let mut staged = self.config.clone();
        for (key, value) in params {
            match key.as_str() {
                "max_distance" => staged.max_distance = not_nan(&key, value)?,
                "max_matches" => {
                    let value = not_nan(&key, value)?;
                    if value < 0.0 {
                        return Err(ColmapError::InvalidParameter(format!(
                            "Parameter {} must be non-negative, got {}",
                            key, value
                        )));
                    }
                    // Float-to-int `as` saturates: +inf means "no limit", fractions truncate.
                    staged.max_matches = value as usize;
                }
                "min_confidence" => staged.min_confidence = not_nan(&key, value)?,
                _ => return Err(ColmapError::InvalidParameter(format!("Unknown parameter: {}", key))),
            }
        }
        self.config = staged;
        Ok(())
    }
}
impl FlannMatcher {
    /// Distance between two descriptors under `self.config.distance_type`.
    ///
    /// Bytes are paired position by position with `zip`, so when the lengths differ
    /// only the common prefix contributes. The caller skips such pairs beforehand.
    ///
    /// * `L2`: Euclidean distance.
    /// * `L1`: sum of absolute byte differences.
    /// * `Hamming`: number of differing bits.
    /// * `Cosine`: `1 - cos θ`, clamped to `[0, 2]`. Returns `1.0` if either
    ///   descriptor is all zeros.
    ///
    /// NOTE(review): duplicates `BruteForceMatcher::compute_distance`; consider sharing
    /// a single implementation.
    fn compute_distance(&self, desc1: &[u8], desc2: &[u8]) -> f64 {
        match self.config.distance_type {
            DistanceType::L2 => desc1
                .iter()
                .zip(desc2)
                .map(|(&a, &b)| (f64::from(a) - f64::from(b)).powi(2))
                .sum::<f64>()
                .sqrt(),
            DistanceType::L1 => desc1
                .iter()
                .zip(desc2)
                .map(|(&a, &b)| (f64::from(a) - f64::from(b)).abs())
                .sum::<f64>(),
            DistanceType::Hamming => desc1
                .iter()
                .zip(desc2)
                .map(|(&a, &b)| f64::from((a ^ b).count_ones()))
                .sum::<f64>(),
            DistanceType::Cosine => {
                let dot_product: f64 = desc1
                    .iter()
                    .zip(desc2)
                    .map(|(&a, &b)| f64::from(a) * f64::from(b))
                    .sum();
                let norm1 = desc1.iter().map(|&x| f64::from(x).powi(2)).sum::<f64>().sqrt();
                let norm2 = desc2.iter().map(|&x| f64::from(x).powi(2)).sum::<f64>().sqrt();
                if norm1 == 0.0 || norm2 == 0.0 {
                    // The direction of an all-zero descriptor is undefined.
                    1.0
                } else {
                    // Rounding in `norm1 * norm2` can push the cosine slightly above 1
                    // for parallel descriptors: `[1, 1, 1]` against itself gives about
                    // -2.2e-16, which would make `FeatureMatch::confidence` exceed 1.
                    // Clamp back into the mathematically valid range.
                    (1.0 - dot_product / (norm1 * norm2)).clamp(0.0, 2.0)
                }
            }
        }
    }
}
/// Nearest-neighbour matcher that applies the ratio test to reject ambiguous matches.
///
/// Distances are computed by an inner [`BruteForceMatcher`]. `config` additionally
/// carries `ratio_threshold`.
pub struct RatioTestMatcher {
    config: MatcherConfig,
    base_matcher: BruteForceMatcher,
}
impl RatioTestMatcher {
    /// Builds the matcher and its inner brute-force matcher from the same `config`.
    ///
    /// # Errors
    /// Propagates any error from [`BruteForceMatcher::new`].
    pub fn new(config: &MatcherConfig) -> Result<Self, ColmapError> {
        Ok(Self {
            base_matcher: BruteForceMatcher::new(config)?,
            config: config.clone(),
        })
    }
}
impl FeatureMatcher for RatioTestMatcher {
fn match_features(
&self,
features1: &[Feature],
features2: &[Feature],
) -> Result<Vec<FeatureMatch>, ColmapError> {
let mut matches = Vec::new();
for (i, feat1) in features1.iter().enumerate() {
if feat1.descriptor.is_empty() {
continue;
}
let mut distances: Vec<(usize, f64)> = Vec::new();
for (j, feat2) in features2.iter().enumerate() {
if feat2.descriptor.is_empty() || feat1.descriptor.len() != feat2.descriptor.len() {
continue;
}
let distance = self.base_matcher.compute_distance(&feat1.descriptor, &feat2.descriptor);
distances.push((j, distance));
}
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
if distances.len() >= 2 {
let best_distance = distances[0].1;
let second_best_distance = distances[1].1;
if best_distance <= self.config.max_distance &&
best_distance / second_best_distance < self.config.ratio_threshold {
let match_result = FeatureMatch::new(i, distances[0].0, best_distance);
if match_result.confidence >= self.config.min_confidence {
matches.push(match_result);
}
}
}
}
matches.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
matches.truncate(self.config.max_matches);
Ok(matches)
}
fn name(&self) -> &str {
"RatioTest"
}
fn params(&self) -> HashMap<String, f64> {
let mut params = self.base_matcher.params();
params.insert("ratio_threshold".to_string(), self.config.ratio_threshold);
params
}
fn set_params(&mut self, params: HashMap<String, f64>) -> Result<(), ColmapError> {
if let Some(&ratio) = params.get("ratio_threshold") {
self.config.ratio_threshold = ratio;
}
self.base_matcher.set_params(params)
}
}
/// Matcher that keeps only mutual nearest neighbours between two feature sets.
///
/// Distances are computed by an inner [`BruteForceMatcher`].
pub struct CrossCheckMatcher {
    config: MatcherConfig,
    base_matcher: BruteForceMatcher,
}
impl CrossCheckMatcher {
    /// Builds the matcher and its inner brute-force matcher from the same `config`.
    ///
    /// # Errors
    /// Propagates any error from [`BruteForceMatcher::new`].
    pub fn new(config: &MatcherConfig) -> Result<Self, ColmapError> {
        Ok(Self {
            base_matcher: BruteForceMatcher::new(config)?,
            config: config.clone(),
        })
    }
}
impl FeatureMatcher for CrossCheckMatcher {
fn match_features(
&self,
features1: &[Feature],
features2: &[Feature],
) -> Result<Vec<FeatureMatch>, ColmapError> {
let forward_matches = self.base_matcher.match_features(features1, features2)?;
let backward_matches = self.base_matcher.match_features(features2, features1)?;
let mut cross_checked_matches = Vec::new();
for forward_match in forward_matches {
for backward_match in &backward_matches {
if backward_match.query_idx == forward_match.train_idx &&
backward_match.train_idx == forward_match.query_idx {
let final_distance = forward_match.distance.min(backward_match.distance);
let cross_match = FeatureMatch::new(
forward_match.query_idx,
forward_match.train_idx,
final_distance,
);
if cross_match.confidence >= self.config.min_confidence {
cross_checked_matches.push(cross_match);
}
break;
}
}
}
cross_checked_matches.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
cross_checked_matches.truncate(self.config.max_matches);
Ok(cross_checked_matches)
}
fn name(&self) -> &str {
"CrossCheck"
}
fn params(&self) -> HashMap<String, f64> {
self.base_matcher.params()
}
fn set_params(&mut self, params: HashMap<String, f64>) -> Result<(), ColmapError> {
self.base_matcher.set_params(params)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Feature;
    use nalgebra::Point2;

    /// Two features with distinct, non-empty 4-byte descriptors.
    fn create_test_features() -> Vec<Feature> {
        let first = Feature {
            point: Point2::new(10.0, 20.0),
            descriptor: vec![1, 2, 3, 4],
            response: 0.8,
            angle: 0.0,
            octave: 0,
            scale: 1.0,
            point3d_id: None,
        };
        let second = Feature {
            point: Point2::new(30.0, 40.0),
            descriptor: vec![5, 6, 7, 8],
            response: 0.9,
            angle: 0.0,
            octave: 0,
            scale: 1.0,
            point3d_id: None,
        };
        vec![first, second]
    }

    #[test]
    fn test_feature_match_creation() {
        let m = FeatureMatch::new(0, 1, 0.5);
        assert_eq!((m.query_idx, m.train_idx, m.distance), (0, 1, 0.5));
        assert!(m.is_valid(1.0));
        assert!(!m.is_valid(0.3));
    }

    #[test]
    fn test_matcher_config_default() {
        let MatcherConfig {
            matcher_type,
            distance_type,
            max_distance,
            ratio_threshold,
            ..
        } = MatcherConfig::default();
        assert_eq!(matcher_type, MatcherType::BruteForce);
        assert_eq!(distance_type, DistanceType::L2);
        assert_eq!(max_distance, 0.7);
        assert_eq!(ratio_threshold, 0.8);
    }

    #[test]
    fn test_matcher_factory() {
        let matcher = MatcherFactory::create(&MatcherConfig::default()).unwrap();
        assert_eq!(matcher.name(), "BruteForce");
    }

    #[test]
    fn test_brute_force_matcher_creation() {
        assert!(BruteForceMatcher::new(&MatcherConfig::default()).is_ok());
    }

    #[test]
    fn test_brute_force_matching() {
        let config = MatcherConfig {
            max_distance: 10.0,
            ..MatcherConfig::default()
        };
        let matcher = BruteForceMatcher::new(&config).unwrap();
        let (features1, features2) = (create_test_features(), create_test_features());
        let matches = matcher.match_features(&features1, &features2).unwrap();
        assert!(!matches.is_empty());
    }

    #[test]
    fn test_distance_computation() {
        let matcher = BruteForceMatcher::new(&MatcherConfig {
            distance_type: DistanceType::L2,
            ..MatcherConfig::default()
        })
        .unwrap();
        let reference = [1u8, 2, 3, 4];
        assert_eq!(matcher.compute_distance(&reference, &[1, 2, 3, 4]), 0.0);
        assert!(matcher.compute_distance(&reference, &[2, 3, 4, 5]) > 0.0);
    }
}