use crate::core::{Feature, ColmapError};
use image::GrayImage;
use imageproc::gradients;
use std::collections::HashMap;
/// Common interface for keypoint descriptor extractors.
///
/// The `Send + Sync` supertraits allow a boxed extractor (as returned by
/// `ExtractorFactory::create`) to be shared across threads.
pub trait DescriptorExtractor: Send + Sync {
    /// Computes a descriptor for every feature in `features`, storing it in
    /// that feature's `descriptor` field in place.
    ///
    /// # Errors
    /// Implementation-defined; returns a [`ColmapError`] on failure.
    fn compute(&self, image: &GrayImage, features: &mut Vec<Feature>) -> Result<(), ColmapError>;
    /// Length of each descriptor written by `compute`, counted in elements of
    /// the feature's `descriptor` vector (bytes for the extractors in this module).
    fn descriptor_size(&self) -> usize;
    /// Whether the descriptor is a packed bit string or float-valued.
    fn descriptor_type(&self) -> DescriptorType;
    /// Short human-readable extractor name (e.g. `"SIFT"`).
    fn name(&self) -> &str;
    /// Current tunable parameters, keyed by parameter name.
    fn params(&self) -> HashMap<String, f64>;
    /// Updates tunable parameters by name.
    ///
    /// # Errors
    /// Implementation-defined; implementations may reject unknown keys or
    /// invalid values with a [`ColmapError`].
    fn set_params(&mut self, params: HashMap<String, f64>) -> Result<(), ColmapError>;
}
/// Kind of values a descriptor holds.
///
/// Extractors in this module that report `Binary` pack one comparison result
/// per bit into bytes; `Float` descriptors are real-valued (possibly quantized
/// to bytes for storage).
///
/// The enum is fieldless, so it is `Copy`, `Eq` and `Hash` and can be passed by
/// value or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DescriptorType {
    /// Packed bit-string descriptor.
    Binary,
    /// Real-valued descriptor.
    Float,
}
/// Descriptor algorithm selected by [`ExtractorConfig::extractor_type`].
///
/// The enum is fieldless, so it is `Copy`, `Eq` and `Hash` and can be passed by
/// value or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtractorType {
    /// SIFT-style float descriptor.
    Sift,
    /// ORB-style binary descriptor.
    Orb,
    /// SURF; `ExtractorFactory::create` reports it as not implemented.
    Surf,
    /// BRIEF-style binary descriptor.
    Brief,
    /// BRISK-style binary descriptor.
    Brisk,
}
/// Configuration shared by all descriptor extractors.
///
/// Every extractor constructor clones this struct; not every field is
/// meaningful for every extractor.
#[derive(Debug, Clone)]
pub struct ExtractorConfig {
    /// Which extractor `ExtractorFactory::create` builds.
    pub extractor_type: ExtractorType,
    /// Descriptor length. The BRIEF extractor interprets it as a number of
    /// bits (its byte length is `descriptor_size / 8`); the SIFT, ORB and BRISK
    /// extractors use fixed lengths and do not read it.
    pub descriptor_size: usize,
    /// Number of octaves. NOTE(review): in this file it is only exposed via the
    /// SIFT extractor's `params`/`set_params`, not used to compute descriptors;
    /// presumably consumed by a detector elsewhere — confirm.
    pub num_octaves: i32,
    /// Layers per octave. NOTE(review): same caveat as `num_octaves`.
    pub num_octave_layers: i32,
    /// Sigma value. NOTE(review): same caveat as `num_octaves`.
    pub sigma: f64,
    /// Edge threshold. NOTE(review): same caveat as `num_octaves`.
    pub edge_threshold: f64,
    /// Contrast threshold. NOTE(review): same caveat as `num_octaves`.
    pub contrast_threshold: f64,
}
impl Default for ExtractorConfig {
    /// A SIFT configuration: 128-element descriptors, 4 octaves of 3 layers,
    /// sigma 1.6, edge threshold 10 and contrast threshold 0.04.
    fn default() -> Self {
        let extractor_type = ExtractorType::Sift;
        let (num_octaves, num_octave_layers) = (4, 3);
        Self {
            extractor_type,
            descriptor_size: 128,
            num_octaves,
            num_octave_layers,
            sigma: 1.6,
            contrast_threshold: 0.04,
            edge_threshold: 10.0,
        }
    }
}
/// Builds boxed [`DescriptorExtractor`]s from an [`ExtractorConfig`].
pub struct ExtractorFactory;

impl ExtractorFactory {
    /// Creates the extractor selected by `config.extractor_type`.
    ///
    /// # Errors
    /// Returns `ColmapError::InvalidParameter` for `ExtractorType::Surf`, which
    /// has no implementation, and propagates any error from the selected
    /// extractor's constructor.
    pub fn create(config: &ExtractorConfig) -> Result<Box<dyn DescriptorExtractor>, ColmapError> {
        let extractor: Box<dyn DescriptorExtractor> = match config.extractor_type {
            ExtractorType::Sift => Box::new(SiftExtractor::new(config)?),
            ExtractorType::Orb => Box::new(OrbExtractor::new(config)?),
            ExtractorType::Brief => Box::new(BriefExtractor::new(config)?),
            ExtractorType::Brisk => Box::new(BriskExtractor::new(config)?),
            ExtractorType::Surf => {
                return Err(ColmapError::InvalidParameter(
                    "SURF extractor not implemented yet".to_string(),
                ));
            }
        };
        Ok(extractor)
    }
}
pub struct SiftExtractor {
config: ExtractorConfig,
}
impl SiftExtractor {
pub fn new(config: &ExtractorConfig) -> Result<Self, ColmapError> {
Ok(Self {
config: config.clone(),
})
}
fn compute_sift_descriptor(&self, image: &GrayImage, x: u32, y: u32, scale: f64, angle: f64) -> Vec<f32> {
let mut descriptor = vec![0.0f32; 128];
let patch_size = (16.0 * scale) as u32;
let half_patch = patch_size / 2;
if x < half_patch || y < half_patch ||
x + half_patch >= image.width() || y + half_patch >= image.height() {
return descriptor;
}
let grad_magnitude = gradients::sobel_gradients(image);
for sub_y in 0..4 {
for sub_x in 0..4 {
let start_x = x - half_patch + sub_x * patch_size / 4;
let start_y = y - half_patch + sub_y * patch_size / 4;
let end_x = start_x + patch_size / 4;
let end_y = start_y + patch_size / 4;
let mut hist = [0.0f32; 8];
for py in start_y..end_y {
for px in start_x..end_x {
if px < image.width() && py < image.height() {
let magnitude = grad_magnitude.get_pixel(px, py)[0] as f64;
let orientation = ((px + py) as f64 * 0.1) % 8.0;
let bin = (orientation as usize).min(7);
hist[bin] += magnitude as f32;
}
}
}
let desc_idx = (sub_y as usize * 4 + sub_x as usize) * 8;
for i in 0..8 {
if desc_idx + i < descriptor.len() {
descriptor[desc_idx + i] = hist[i];
}
}
}
}
let norm: f32 = descriptor.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for val in &mut descriptor {
*val /= norm;
}
}
descriptor
}
}
impl DescriptorExtractor for SiftExtractor {
    /// Replaces each feature's descriptor with 128 bytes: the float
    /// descriptor scaled by 255 and truncated to `u8`.
    ///
    /// Keypoint coordinates are truncated to integer pixels.
    fn compute(&self, image: &GrayImage, features: &mut Vec<Feature>) -> Result<(), ColmapError> {
        for feature in features.iter_mut() {
            let descriptor = self.compute_sift_descriptor(
                image,
                feature.point.x as u32,
                feature.point.y as u32,
                feature.scale as f64,
                feature.angle as f64,
            );
            feature.descriptor = descriptor.into_iter().map(|x| (x * 255.0) as u8).collect();
        }
        Ok(())
    }

    /// Always 128.
    fn descriptor_size(&self) -> usize {
        128
    }

    fn descriptor_type(&self) -> DescriptorType {
        DescriptorType::Float
    }

    fn name(&self) -> &str {
        "SIFT"
    }

    /// Reports the stored octave, sigma and threshold settings.
    fn params(&self) -> HashMap<String, f64> {
        let mut params = HashMap::new();
        params.insert("num_octaves".to_string(), self.config.num_octaves as f64);
        params.insert("num_octave_layers".to_string(), self.config.num_octave_layers as f64);
        params.insert("sigma".to_string(), self.config.sigma);
        params.insert("edge_threshold".to_string(), self.config.edge_threshold);
        params.insert("contrast_threshold".to_string(), self.config.contrast_threshold);
        params
    }

    /// Updates parameters by name: `num_octaves`, `num_octave_layers`
    /// (converted with `as i32`), `sigma`, `edge_threshold`, `contrast_threshold`.
    ///
    /// # Errors
    /// Returns `ColmapError::InvalidParameter` for an unknown key or a
    /// non-finite value. On error no parameter is changed.
    fn set_params(&mut self, params: HashMap<String, f64>) -> Result<(), ColmapError> {
        fn finite(key: &str, value: f64) -> Result<f64, ColmapError> {
            if value.is_finite() {
                Ok(value)
            } else {
                Err(ColmapError::InvalidParameter(format!(
                    "Parameter {} must be finite, got {}",
                    key, value
                )))
            }
        }

        // Stage changes on a copy and commit only once every entry is valid.
        // HashMap iteration order is unspecified, so updating in place would
        // leave a run-dependent subset applied when one entry is rejected.
        let mut updated = self.config.clone();
        for (key, value) in params {
            match key.as_str() {
                "num_octaves" => updated.num_octaves = finite(&key, value)? as i32,
                "num_octave_layers" => updated.num_octave_layers = finite(&key, value)? as i32,
                "sigma" => updated.sigma = finite(&key, value)?,
                "edge_threshold" => updated.edge_threshold = finite(&key, value)?,
                "contrast_threshold" => updated.contrast_threshold = finite(&key, value)?,
                _ => {
                    return Err(ColmapError::InvalidParameter(format!(
                        "Unknown parameter: {}",
                        key
                    )))
                }
            }
        }
        self.config = updated;
        Ok(())
    }
}
/// ORB-style binary descriptor extractor producing 32-byte (256-bit) descriptors.
pub struct OrbExtractor {
    config: ExtractorConfig,
}

impl OrbExtractor {
    /// Descriptor length in bits (32 bytes).
    const NUM_BITS: usize = 256;
    /// Side length of the square patch around the keypoint.
    const PATCH_SIZE: u32 = 31;

    /// Hand-picked comparison pairs `((dx1, dy1), (dx2, dy2))` used for the
    /// first 33 bits. Some offsets (e.g. 20, -16) exceed the patch radius;
    /// comparisons whose pixel falls outside the image are skipped.
    const BASE_PATTERN: [((i32, i32), (i32, i32)); 33] = [
        ((-8, -3), (9, 5)), ((-13, 2), (12, -6)), ((-6, -13), (-4, -8)),
        ((20, -1), (4, 2)), ((-13, -13), (5, -13)), ((16, -9), (-4, 6)),
        ((-16, -7), (-4, -10)), ((12, -6), (-13, -4)), ((-16, -3), (-2, -11)),
        ((1, -3), (15, -5)), ((-1, -8), (14, -15)), ((4, -6), (7, 12)),
        ((2, -4), (12, 12)), ((-15, -10), (-5, -7)), ((-4, 9), (1, -4)),
        ((0, 14), (-3, 10)), ((-8, 7), (-8, 1)), ((4, 2), (12, 1)),
        ((-5, -13), (-7, 0)), ((-13, -5), (-3, -4)), ((-1, 1), (5, 1)),
        ((-7, -10), (12, 14)), ((-13, 3), (-11, -5)), ((4, -2), (13, 2)),
        ((7, -15), (12, -6)), ((-7, -3), (11, 0)), ((-10, -5), (5, 10)),
        ((-13, -8), (7, 7)), ((1, 9), (-1, -13)), ((-3, 7), (7, 12)),
        ((12, 6), (-1, -9)), ((-10, 0), (10, -5)), ((-13, 0), (1, -8)),
    ];

    /// Creates an extractor holding a copy of `config`.
    ///
    /// # Errors
    /// Currently never fails.
    pub fn new(config: &ExtractorConfig) -> Result<Self, ColmapError> {
        Ok(Self {
            config: config.clone(),
        })
    }

    /// Returns the comparison pair for descriptor bit `bit_idx`.
    ///
    /// Bits covered by `BASE_PATTERN` use it. The remaining bits (which
    /// would otherwise always be zero) use offsets in `[-15, 15]` derived from a
    /// fixed SplitMix64 mix of the bit index, so the pattern is identical on
    /// every platform and compiler version.
    fn test_pair(bit_idx: usize) -> ((i32, i32), (i32, i32)) {
        if let Some(&pair) = Self::BASE_PATTERN.get(bit_idx) {
            return pair;
        }
        let mut z = (bit_idx as u64)
            .wrapping_add(1)
            .wrapping_mul(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^= z >> 31;

        let radius = (Self::PATCH_SIZE / 2) as i32;
        let offset = |shift: u32| ((z >> shift) & 0xFF) as i32 % (2 * radius + 1) - radius;
        let first = (offset(0), offset(8));
        let mut second = (offset(16), offset(24));
        if second == first {
            // A pixel compared with itself always yields 0; move by one pixel.
            second.0 += if second.0 < radius { 1 } else { -1 };
        }
        (first, second)
    }

    /// Computes the 32-byte descriptor of the keypoint at (`x`, `y`).
    ///
    /// Bit `i` (byte `i / 8`, position `i % 8`) is set when the first pixel of
    /// `test_pair(i)` is darker than the second. The result is all zero when
    /// the 31x31 patch does not fit in the image.
    ///
    /// NOTE(review): the keypoint angle is not used, so the pattern is not
    /// steered and descriptors are not rotation-invariant.
    fn compute_orb_descriptor(&self, image: &GrayImage, x: u32, y: u32, _angle: f64) -> Vec<u8> {
        let mut descriptor = vec![0u8; Self::NUM_BITS / 8];
        let half_patch = Self::PATCH_SIZE / 2;
        let (width, height) = image.dimensions();
        if x < half_patch
            || y < half_patch
            || x.saturating_add(half_patch) >= width
            || y.saturating_add(half_patch) >= height
        {
            return descriptor;
        }

        // Pixel at an offset from the keypoint, or None when it lies outside the image.
        let sample = |dx: i32, dy: i32| -> Option<u8> {
            let px = i64::from(x) + i64::from(dx);
            let py = i64::from(y) + i64::from(dy);
            let inside = px >= 0 && py >= 0 && px < i64::from(width) && py < i64::from(height);
            inside.then(|| image.get_pixel(px as u32, py as u32)[0])
        };

        for bit_idx in 0..Self::NUM_BITS {
            let ((dx1, dy1), (dx2, dy2)) = Self::test_pair(bit_idx);
            if let (Some(p1), Some(p2)) = (sample(dx1, dy1), sample(dx2, dy2)) {
                if p1 < p2 {
                    descriptor[bit_idx / 8] |= 1 << (bit_idx % 8);
                }
            }
        }
        descriptor
    }
}
impl DescriptorExtractor for OrbExtractor {
fn compute(&self, image: &GrayImage, features: &mut Vec<Feature>) -> Result<(), ColmapError> {
for feature in features.iter_mut() {
let descriptor = self.compute_orb_descriptor(
image,
feature.point.x as u32,
feature.point.y as u32,
feature.angle as f64,
);
feature.descriptor = descriptor;
}
Ok(())
}
fn descriptor_size(&self) -> usize {
32 }
fn descriptor_type(&self) -> DescriptorType {
DescriptorType::Binary
}
fn name(&self) -> &str {
"ORB"
}
fn params(&self) -> HashMap<String, f64> {
HashMap::new()
}
fn set_params(&mut self, _params: HashMap<String, f64>) -> Result<(), ColmapError> {
Ok(())
}
}
/// BRIEF-style binary descriptor extractor.
///
/// Each bit compares the intensities of two pixels whose offsets from the
/// keypoint come from a fixed pseudo-random pattern inside a 48x48 patch. The
/// number of bits is taken from `config.descriptor_size`.
pub struct BriefExtractor {
    config: ExtractorConfig,
}

impl BriefExtractor {
    /// Creates an extractor holding a copy of `config`.
    ///
    /// # Errors
    /// Currently never fails.
    pub fn new(config: &ExtractorConfig) -> Result<Self, ColmapError> {
        Ok(Self {
            config: config.clone(),
        })
    }

    /// Deterministic 64-bit mix of `index` (the SplitMix64 output function).
    ///
    /// Used instead of `std`'s `DefaultHasher`. That hasher's algorithm is
    /// unspecified, and its hash of a `usize` depends on the target's pointer
    /// width and endianness. That would make the sampling pattern, and
    /// therefore stored descriptors, differ between builds.
    fn pattern_hash(index: usize) -> u64 {
        let mut z = (index as u64)
            .wrapping_add(1)
            .wrapping_mul(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Computes the descriptor of the keypoint at (`x`, `y`).
    ///
    /// Returns `descriptor_size / 8` bytes. Bit `i` is stored in byte `i / 8`
    /// at position `i % 8`. The result is all zero when the 48x48 patch does
    /// not fit in the image.
    fn compute_brief_descriptor(&self, image: &GrayImage, x: u32, y: u32) -> Vec<u8> {
        // Emit whole bytes only, so every bit index lands inside the buffer even
        // when descriptor_size is not a multiple of 8 (this used to panic).
        let num_bytes = self.config.descriptor_size / 8;
        let num_bits = num_bytes * 8;
        let mut descriptor = vec![0u8; num_bytes];

        let patch_size = 48u32;
        let half_patch = patch_size / 2;
        let (width, height) = image.dimensions();
        if x < half_patch
            || y < half_patch
            || x.saturating_add(half_patch) >= width
            || y.saturating_add(half_patch) >= height
        {
            return descriptor;
        }

        // Map one hash byte to an offset in [-24, 23], i.e. within the patch.
        let offset = |byte: u64| ((byte & 0xFF) as i64 - 128) * i64::from(patch_size) / 256;
        // Pixel at an offset from the keypoint, or None outside the image.
        let sample = |dx: i64, dy: i64| -> Option<u8> {
            let px = i64::from(x) + dx;
            let py = i64::from(y) + dy;
            let inside = px >= 0 && py >= 0 && px < i64::from(width) && py < i64::from(height);
            inside.then(|| image.get_pixel(px as u32, py as u32)[0])
        };

        for bit_idx in 0..num_bits {
            let hash = Self::pattern_hash(bit_idx);
            let first = sample(offset(hash), offset(hash >> 8));
            let second = sample(offset(hash >> 16), offset(hash >> 24));
            if let (Some(p1), Some(p2)) = (first, second) {
                if p1 < p2 {
                    descriptor[bit_idx / 8] |= 1 << (bit_idx % 8);
                }
            }
        }
        descriptor
    }
}
impl DescriptorExtractor for BriefExtractor {
    /// Replaces each feature's descriptor with its BRIEF-style bit string.
    ///
    /// Keypoint coordinates are truncated to integer pixels.
    fn compute(&self, image: &GrayImage, features: &mut Vec<Feature>) -> Result<(), ColmapError> {
        for feature in features.iter_mut() {
            feature.descriptor = self.compute_brief_descriptor(
                image,
                feature.point.x as u32,
                feature.point.y as u32,
            );
        }
        Ok(())
    }

    /// Descriptor length in bytes: the configured bit count divided by 8.
    fn descriptor_size(&self) -> usize {
        self.config.descriptor_size / 8
    }

    fn descriptor_type(&self) -> DescriptorType {
        DescriptorType::Binary
    }

    fn name(&self) -> &str {
        "BRIEF"
    }

    /// Reports the configured descriptor size (in bits).
    fn params(&self) -> HashMap<String, f64> {
        let mut params = HashMap::new();
        params.insert("descriptor_size".to_string(), self.config.descriptor_size as f64);
        params
    }

    /// Updates `descriptor_size` (in bits).
    ///
    /// # Errors
    /// Returns `ColmapError::InvalidParameter` for an unknown key, or when
    /// `descriptor_size` is not a finite, positive multiple of 8. On error no
    /// parameter is changed.
    fn set_params(&mut self, params: HashMap<String, f64>) -> Result<(), ColmapError> {
        // Stage on a copy and commit only if every entry is valid, so a rejected
        // entry never leaves the config partially updated.
        let mut updated = self.config.clone();
        for (key, value) in params {
            match key.as_str() {
                "descriptor_size" => {
                    // The descriptor is packed into whole bytes.
                    if !(value.is_finite() && value >= 8.0 && value % 8.0 == 0.0) {
                        return Err(ColmapError::InvalidParameter(format!(
                            "descriptor_size must be a positive multiple of 8, got {}",
                            value
                        )));
                    }
                    updated.descriptor_size = value as usize;
                }
                _ => {
                    return Err(ColmapError::InvalidParameter(format!(
                        "Unknown parameter: {}",
                        key
                    )))
                }
            }
        }
        self.config = updated;
        Ok(())
    }
}
/// BRISK-style binary descriptor extractor producing 64-byte descriptors.
///
/// Compares the intensities of every pair of points from a small sampling
/// pattern rotated by the keypoint angle.
pub struct BriskExtractor {
    config: ExtractorConfig,
}

impl BriskExtractor {
    /// Sampling offsets in pixels: two rings of 8 points, at radii of about 2.5 and 4.
    const SAMPLING_POINTS: [(f64, f64); 16] = [
        (0.0, -2.5), (1.8, -1.8), (2.5, 0.0), (1.8, 1.8),
        (0.0, 2.5), (-1.8, 1.8), (-2.5, 0.0), (-1.8, -1.8),
        (0.0, -4.0), (2.8, -2.8), (4.0, 0.0), (2.8, 2.8),
        (0.0, 4.0), (-2.8, 2.8), (-4.0, 0.0), (-2.8, -2.8),
    ];

    /// Creates an extractor holding a copy of `config`.
    ///
    /// # Errors
    /// Currently never fails.
    pub fn new(config: &ExtractorConfig) -> Result<Self, ColmapError> {
        Ok(Self {
            config: config.clone(),
        })
    }

    /// Computes the 64-byte descriptor of the keypoint at (`x`, `y`).
    ///
    /// The sampling pattern is rotated by `angle`, which is passed straight to
    /// `sin`/`cos` and is therefore interpreted as radians. Bit `k` corresponds
    /// to the `k`-th pair `(i, j)`, `i < j`, in lexicographic order. It is set
    /// when point `i` is darker than point `j`. The result is all zero when the
    /// 60x60 patch does not fit. Points whose coordinates are non-finite (NaN
    /// `angle`) are skipped.
    ///
    /// NOTE(review): 16 points give only 120 pairs, so bits 120..512 are
    /// always zero; a fuller sampling pattern is needed to use the whole descriptor.
    fn compute_brisk_descriptor(&self, image: &GrayImage, x: u32, y: u32, angle: f64) -> Vec<u8> {
        let mut descriptor = vec![0u8; 64];
        let patch_size = 60u32;
        let half_patch = patch_size / 2;
        let (width, height) = image.dimensions();
        if x < half_patch
            || y < half_patch
            || x.saturating_add(half_patch) >= width
            || y.saturating_add(half_patch) >= height
        {
            return descriptor;
        }

        let (sin_a, cos_a) = angle.sin_cos();
        // Rotate each offset and round to the nearest pixel. A bare `as u32`
        // truncates, shifting negative and positive offsets by different amounts.
        let intensities: Vec<Option<u8>> = Self::SAMPLING_POINTS
            .iter()
            .map(|&(px, py)| {
                let sx = (f64::from(x) + px * cos_a - py * sin_a).round();
                let sy = (f64::from(y) + px * sin_a + py * cos_a).round();
                let inside =
                    sx >= 0.0 && sy >= 0.0 && sx < f64::from(width) && sy < f64::from(height);
                inside.then(|| image.get_pixel(sx as u32, sy as u32)[0])
            })
            .collect();

        let n = intensities.len();
        let pairs = (0..n).flat_map(|i| (i + 1..n).map(move |j| (i, j)));
        let max_bits = descriptor.len() * 8;
        for (bit_idx, (i, j)) in pairs.take(max_bits).enumerate() {
            if let (Some(a), Some(b)) = (intensities[i], intensities[j]) {
                if a < b {
                    descriptor[bit_idx / 8] |= 1 << (bit_idx % 8);
                }
            }
        }
        descriptor
    }
}
impl DescriptorExtractor for BriskExtractor {
fn compute(&self, image: &GrayImage, features: &mut Vec<Feature>) -> Result<(), ColmapError> {
for feature in features.iter_mut() {
let descriptor = self.compute_brisk_descriptor(
image,
feature.point.x as u32,
feature.point.y as u32,
feature.angle as f64,
);
feature.descriptor = descriptor;
}
Ok(())
}
fn descriptor_size(&self) -> usize {
64 }
fn descriptor_type(&self) -> DescriptorType {
DescriptorType::Binary
}
fn name(&self) -> &str {
"BRISK"
}
fn params(&self) -> HashMap<String, f64> {
HashMap::new()
}
fn set_params(&mut self, _params: HashMap<String, f64>) -> Result<(), ColmapError> {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Point2;

    /// A keypoint at the centre of a 100x100 test image, with no descriptor yet.
    fn centred_feature() -> Feature {
        Feature {
            point: Point2::new(50.0, 50.0),
            scale: 1.0,
            angle: 0.0,
            response: 1.0,
            octave: 0,
            descriptor: Vec::new(),
            point3d_id: None,
        }
    }

    /// Runs `extractor` on a blank 100x100 image with one centred keypoint,
    /// asserts that extraction succeeds, and returns the produced descriptor.
    fn describe_blank_centre(extractor: &dyn DescriptorExtractor) -> Vec<u8> {
        let image = GrayImage::new(100, 100);
        let mut features = vec![centred_feature()];
        assert!(extractor.compute(&image, &mut features).is_ok());
        std::mem::take(&mut features[0].descriptor)
    }

    #[test]
    fn test_extractor_config_default() {
        let config = ExtractorConfig::default();
        assert_eq!(config.extractor_type, ExtractorType::Sift);
        assert_eq!(config.descriptor_size, 128);
    }

    #[test]
    fn test_extractor_factory() {
        let created = ExtractorFactory::create(&ExtractorConfig::default());
        assert!(created.is_ok());
        assert_eq!(created.unwrap().name(), "SIFT");
    }

    #[test]
    fn test_sift_extractor_creation() {
        let built = SiftExtractor::new(&ExtractorConfig::default());
        assert!(built.is_ok());
        let extractor = built.unwrap();
        assert_eq!(extractor.descriptor_size(), 128);
        assert_eq!(extractor.descriptor_type(), DescriptorType::Float);
    }

    #[test]
    fn test_sift_extractor() {
        let extractor = SiftExtractor::new(&ExtractorConfig::default()).unwrap();
        let descriptor = describe_blank_centre(&extractor);
        assert!(!descriptor.is_empty());
        assert_eq!(extractor.descriptor_size(), 128);
        assert_eq!(extractor.descriptor_type(), DescriptorType::Float);
    }

    #[test]
    fn test_orb_extractor_creation() {
        let config = ExtractorConfig {
            extractor_type: ExtractorType::Orb,
            ..Default::default()
        };
        let built = OrbExtractor::new(&config);
        assert!(built.is_ok());
        let extractor = built.unwrap();
        assert_eq!(extractor.descriptor_size(), 32);
        assert_eq!(extractor.descriptor_type(), DescriptorType::Binary);
        assert_eq!(describe_blank_centre(&extractor).len(), 32);
    }

    #[test]
    fn test_brief_extractor_creation() {
        let config = ExtractorConfig {
            extractor_type: ExtractorType::Brief,
            descriptor_size: 256,
            ..Default::default()
        };
        let built = BriefExtractor::new(&config);
        assert!(built.is_ok());
        let extractor = built.unwrap();
        // 256 bits -> 32 bytes.
        assert_eq!(extractor.descriptor_size(), 32);
        assert_eq!(extractor.descriptor_type(), DescriptorType::Binary);
        assert_eq!(describe_blank_centre(&extractor).len(), 32);
    }

    #[test]
    fn test_brisk_extractor_creation() {
        let config = ExtractorConfig {
            extractor_type: ExtractorType::Brisk,
            ..Default::default()
        };
        let built = BriskExtractor::new(&config);
        assert!(built.is_ok());
        let extractor = built.unwrap();
        assert_eq!(extractor.descriptor_size(), 64);
        assert_eq!(extractor.descriptor_type(), DescriptorType::Binary);
    }
}