use crate::core::{Feature, ColmapError};
use crate::feature::{
FeatureDetector, DescriptorExtractor, FeatureMatcher, FeatureMatch,
DetectorConfig, ExtractorConfig, MatcherConfig,
DetectorFactory, ExtractorFactory, MatcherFactory,
};
use image::{GrayImage, open};
use rayon::prelude::*;
use std::collections::HashMap;
use std::path::Path;
/// Configuration for a [`FeaturePipeline`]: one configuration per stage
/// (detection, description, matching) plus parallelism settings.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Settings handed to `DetectorFactory::create` to build the keypoint detector.
    pub detector_config: DetectorConfig,
    /// Settings handed to `ExtractorFactory::create` to build the descriptor extractor.
    pub extractor_config: ExtractorConfig,
    /// Settings handed to `MatcherFactory::create` to build the feature matcher.
    pub matcher_config: MatcherConfig,
    /// Whether `process_images` may extract features from several images
    /// concurrently (only relevant when more than one path is supplied).
    pub parallel: bool,
    /// Requested number of worker threads for parallel work; also reported by
    /// `get_stats` under the `"num_threads"` key.
    pub num_threads: usize,
}
impl Default for PipelineConfig {
    /// Default pipeline configuration.
    ///
    /// Every stage uses its own `Default` configuration, parallel extraction
    /// is enabled, and the thread count is whatever `num_cpus::get()` reports
    /// for the current machine.
    fn default() -> Self {
        let detector_config = DetectorConfig::default();
        let extractor_config = ExtractorConfig::default();
        let matcher_config = MatcherConfig::default();
        let num_threads = num_cpus::get();

        Self {
            detector_config,
            extractor_config,
            matcher_config,
            parallel: true,
            num_threads,
        }
    }
}
/// Outcome of running detection followed by description on a single image.
#[derive(Debug, Clone)]
pub struct ExtractionResult {
    /// Detected features that ended up with a non-empty descriptor.
    pub features: Vec<Feature>,
    /// Elapsed wall-clock time in milliseconds, measured from the start instant
    /// supplied to `extract_from_image` (includes image loading when the result
    /// comes from `extract_from_file`).
    pub extraction_time_ms: u64,
    /// Number of keypoints returned by the detector.
    pub num_detected: usize,
    /// Number of features kept after descriptor computation (equals `features.len()`).
    pub num_described: usize,
}
/// Outcome of matching two feature sets.
#[derive(Debug, Clone)]
pub struct MatchingResult {
    /// Matches kept after one-to-one filtering.
    pub matches: Vec<FeatureMatch>,
    /// Elapsed wall-clock matching time in milliseconds (0 when either input is empty).
    pub matching_time_ms: u64,
    /// Number of matches produced by the matcher before filtering.
    pub num_initial_matches: usize,
    /// Number of matches remaining after filtering (equals `matches.len()`).
    pub num_filtered_matches: usize,
    /// Heuristic score combining match ratio, mean confidence and distance
    /// consistency, capped at 1.0; 0.0 when there are no matches.
    pub quality_score: f64,
}
/// Detect → describe → match pipeline whose stages are trait objects created
/// by `DetectorFactory`, `ExtractorFactory` and `MatcherFactory`.
pub struct FeaturePipeline {
    /// Current pipeline configuration.
    config: PipelineConfig,
    /// Keypoint detector (first stage).
    detector: Box<dyn FeatureDetector>,
    /// Descriptor extractor applied to the detected keypoints (second stage).
    extractor: Box<dyn DescriptorExtractor>,
    /// Matcher comparing two described feature sets (third stage).
    matcher: Box<dyn FeatureMatcher>,
}
impl FeaturePipeline {
pub fn new(config: PipelineConfig) -> Result<Self, ColmapError> {
let detector = DetectorFactory::create(&config.detector_config)?;
let extractor = ExtractorFactory::create(&config.extractor_config)?;
let matcher = MatcherFactory::create(&config.matcher_config)?;
Ok(Self {
config,
detector,
extractor,
matcher,
})
}
pub fn extract_from_file<P: AsRef<Path>>(&self, image_path: P) -> Result<ExtractionResult, ColmapError> {
let start_time = std::time::Instant::now();
let image = open(image_path)
.map_err(|e| ColmapError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to load image: {}", e))))?
.to_luma8();
if image.width() == 0 || image.height() == 0 {
return Err(ColmapError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, "Empty image loaded")));
}
self.extract_from_image(&image, start_time)
}
pub fn extract_from_image(&self, image: &GrayImage, start_time: std::time::Instant) -> Result<ExtractionResult, ColmapError> {
let mut features = self.detector.detect(image)?;
let num_detected = features.len();
if features.is_empty() {
return Ok(ExtractionResult {
features,
extraction_time_ms: start_time.elapsed().as_millis() as u64,
num_detected,
num_described: 0,
});
}
self.extractor.compute(image, &mut features)?;
features.retain(|f| !f.descriptor.is_empty());
let num_described = features.len();
let extraction_time_ms = start_time.elapsed().as_millis() as u64;
Ok(ExtractionResult {
features,
extraction_time_ms,
num_detected,
num_described,
})
}
pub fn match_features(
&self,
features1: &[Feature],
features2: &[Feature],
) -> Result<MatchingResult, ColmapError> {
let start_time = std::time::Instant::now();
if features1.is_empty() || features2.is_empty() {
return Ok(MatchingResult {
matches: Vec::new(),
matching_time_ms: 0,
num_initial_matches: 0,
num_filtered_matches: 0,
quality_score: 0.0,
});
}
let mut matches = self.matcher.match_features(features1, features2)?;
let num_initial_matches = matches.len();
matches = self.filter_matches(matches)?;
let num_filtered_matches = matches.len();
let quality_score = self.compute_quality_score(&matches, features1.len(), features2.len());
let matching_time_ms = start_time.elapsed().as_millis() as u64;
Ok(MatchingResult {
matches,
matching_time_ms,
num_initial_matches,
num_filtered_matches,
quality_score,
})
}
pub fn process_image_pair<P1: AsRef<Path>, P2: AsRef<Path>>(
&self,
image1_path: P1,
image2_path: P2,
) -> Result<(ExtractionResult, ExtractionResult, MatchingResult), ColmapError> {
let result1 = self.extract_from_file(image1_path)?;
let result2 = self.extract_from_file(image2_path)?;
let match_result = self.match_features(&result1.features, &result2.features)?;
Ok((result1, result2, match_result))
}
pub fn process_images<P: AsRef<Path> + Sync>(
&self,
image_paths: &[P],
) -> Result<Vec<ExtractionResult>, ColmapError> {
let mut results = Vec::with_capacity(image_paths.len());
if self.config.parallel && image_paths.len() > 1 {
let parallel_results: Result<Vec<_>, _> = image_paths
.par_iter()
.map(|path| self.extract_from_file(path))
.collect();
results = parallel_results?;
} else {
for path in image_paths {
results.push(self.extract_from_file(path)?);
}
}
Ok(results)
}
fn filter_matches(&self, mut matches: Vec<FeatureMatch>) -> Result<Vec<FeatureMatch>, ColmapError> {
matches.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
let mut used_query = std::collections::HashSet::new();
let mut used_train = std::collections::HashSet::new();
let mut filtered_matches = Vec::new();
for match_result in matches {
if !used_query.contains(&match_result.query_idx) &&
!used_train.contains(&match_result.train_idx) {
used_query.insert(match_result.query_idx);
used_train.insert(match_result.train_idx);
filtered_matches.push(match_result);
}
}
Ok(filtered_matches)
}
fn compute_quality_score(
&self,
matches: &[FeatureMatch],
num_features1: usize,
num_features2: usize,
) -> f64 {
if matches.is_empty() || num_features1 == 0 || num_features2 == 0 {
return 0.0;
}
let match_ratio = matches.len() as f64 / num_features1.min(num_features2) as f64;
let avg_confidence = matches.iter().map(|m| m.confidence).sum::<f64>() / matches.len() as f64;
let distances: Vec<f64> = matches.iter().map(|m| m.distance).collect();
let mean_distance = distances.iter().sum::<f64>() / distances.len() as f64;
let variance = distances.iter()
.map(|d| (d - mean_distance).powi(2))
.sum::<f64>() / distances.len() as f64;
let std_dev = variance.sqrt();
let consistency = 1.0 / (1.0 + std_dev);
(match_ratio * 0.4 + avg_confidence * 0.4 + consistency * 0.2).min(1.0)
}
pub fn get_stats(&self) -> HashMap<String, String> {
let mut stats = HashMap::new();
stats.insert("detector".to_string(), self.detector.name().to_string());
stats.insert("extractor".to_string(), self.extractor.name().to_string());
stats.insert("matcher".to_string(), self.matcher.name().to_string());
stats.insert("parallel".to_string(), self.config.parallel.to_string());
stats.insert("num_threads".to_string(), self.config.num_threads.to_string());
stats
}
pub fn update_config(&mut self, config: PipelineConfig) -> Result<(), ColmapError> {
if self.config.detector_config.detector_type != config.detector_config.detector_type {
self.detector = DetectorFactory::create(&config.detector_config)?;
}
if self.config.extractor_config.extractor_type != config.extractor_config.extractor_type {
self.extractor = ExtractorFactory::create(&config.extractor_config)?;
}
if self.config.matcher_config.matcher_type != config.matcher_config.matcher_type {
self.matcher = MatcherFactory::create(&config.matcher_config)?;
}
self.config = config;
Ok(())
}
}
/// Fluent builder for [`FeaturePipeline`], starting from `PipelineConfig::default()`.
pub struct PipelineBuilder {
    /// Configuration accumulated by the builder methods and consumed by `build`.
    config: PipelineConfig,
}
impl PipelineBuilder {
pub fn new() -> Self {
Self {
config: PipelineConfig::default(),
}
}
pub fn detector_config(mut self, config: DetectorConfig) -> Self {
self.config.detector_config = config;
self
}
pub fn extractor_config(mut self, config: ExtractorConfig) -> Self {
self.config.extractor_config = config;
self
}
pub fn matcher_config(mut self, config: MatcherConfig) -> Self {
self.config.matcher_config = config;
self
}
pub fn parallel(mut self, parallel: bool) -> Self {
self.config.parallel = parallel;
self
}
pub fn num_threads(mut self, num_threads: usize) -> Self {
self.config.num_threads = num_threads;
self
}
pub fn build(self) -> Result<FeaturePipeline, ColmapError> {
FeaturePipeline::new(self.config)
}
}
impl Default for PipelineBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::feature::{DetectorType, ExtractorType, MatcherType};

    /// Builds a pipeline from `PipelineConfig::default()`.
    fn default_pipeline() -> FeaturePipeline {
        FeaturePipeline::new(PipelineConfig::default()).expect("default pipeline must build")
    }

    #[test]
    fn test_pipeline_config_default() {
        let config = PipelineConfig::default();
        assert_eq!(config.detector_config.detector_type, DetectorType::Sift);
        assert_eq!(config.extractor_config.extractor_type, ExtractorType::Sift);
        assert_eq!(config.matcher_config.matcher_type, MatcherType::BruteForce);
        assert!(config.parallel);
        assert_eq!(config.num_threads, num_cpus::get());
    }

    #[test]
    fn test_pipeline_builder() {
        let pipeline = PipelineBuilder::new()
            .parallel(false)
            .num_threads(4)
            .build()
            .expect("builder with default stages must build");
        // The builder settings must actually reach the pipeline.
        let stats = pipeline.get_stats();
        assert_eq!(stats["parallel"], "false");
        assert_eq!(stats["num_threads"], "4");
    }

    #[test]
    fn test_pipeline_creation() {
        let pipeline = FeaturePipeline::new(PipelineConfig::default());
        assert!(pipeline.is_ok());
    }

    #[test]
    fn test_pipeline_stats() {
        let pipeline = default_pipeline();
        let stats = pipeline.get_stats();
        assert!(stats.contains_key("detector"));
        assert!(stats.contains_key("extractor"));
        assert!(stats.contains_key("matcher"));
        assert_eq!(stats["detector"], "SIFT");
        assert_eq!(stats["extractor"], "SIFT");
        assert_eq!(stats["matcher"], "BruteForce");
        assert_eq!(stats["parallel"], "true");
        assert_eq!(stats["num_threads"], num_cpus::get().to_string());
    }

    #[test]
    fn test_quality_score_computation() {
        let pipeline = default_pipeline();
        let matches = vec![
            FeatureMatch::new(0, 0, 0.1),
            FeatureMatch::new(1, 1, 0.2),
            FeatureMatch::new(2, 2, 0.15),
        ];
        let score = pipeline.compute_quality_score(&matches, 10, 10);
        assert!(score > 0.0 && score <= 1.0);
    }

    #[test]
    fn test_quality_score_empty_inputs() {
        let pipeline = default_pipeline();
        assert_eq!(pipeline.compute_quality_score(&[], 10, 10), 0.0);
        let matches = vec![FeatureMatch::new(0, 0, 0.1)];
        assert_eq!(pipeline.compute_quality_score(&matches, 0, 10), 0.0);
        assert_eq!(pipeline.compute_quality_score(&matches, 10, 0), 0.0);
    }

    #[test]
    fn test_match_filtering() {
        let pipeline = default_pipeline();
        let matches = vec![
            FeatureMatch::new(0, 0, 0.1),
            FeatureMatch::new(0, 1, 0.2),
            FeatureMatch::new(1, 0, 0.15),
            FeatureMatch::new(2, 2, 0.3),
        ];
        let filtered = pipeline.filter_matches(matches).unwrap();
        // (0,0) wins query 0 and train 0, so (0,1) and (1,0) are rejected; (2,2) is free.
        assert_eq!(filtered.len(), 2);
        assert_eq!((filtered[0].query_idx, filtered[0].train_idx), (0, 0));
        assert_eq!((filtered[1].query_idx, filtered[1].train_idx), (2, 2));
    }

    #[test]
    fn test_match_features_empty_inputs() {
        let pipeline = default_pipeline();
        let result = pipeline.match_features(&[], &[]).unwrap();
        assert!(result.matches.is_empty());
        assert_eq!(result.num_initial_matches, 0);
        assert_eq!(result.num_filtered_matches, 0);
        assert_eq!(result.quality_score, 0.0);
    }

    #[test]
    fn test_process_images_empty() {
        let pipeline = default_pipeline();
        let paths: &[&str] = &[];
        let results = pipeline.process_images(paths).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_update_config_applies_settings() {
        let mut pipeline = default_pipeline();
        let mut config = PipelineConfig::default();
        config.parallel = false;
        config.num_threads = 2;
        pipeline.update_config(config).unwrap();
        let stats = pipeline.get_stats();
        assert_eq!(stats["parallel"], "false");
        assert_eq!(stats["num_threads"], "2");
        assert_eq!(stats["detector"], "SIFT");
    }
}