// colmap 0.1.2
//
// A comprehensive Rust library for COLMAP-style computer vision and 3D reconstruction.
// Documentation
//! SIFT (Scale-Invariant Feature Transform) feature detector implementation.

use crate::core::{Feature, Point2, Result};
use crate::feature::detector::FeatureDetector;
use image::{GrayImage, ImageBuffer, Luma};

/// SIFT (Scale-Invariant Feature Transform) feature detector.
///
/// Holds the tunable parameters used when building the Gaussian /
/// difference-of-Gaussians pyramids and when turning extrema into features.
#[derive(Debug, Clone)]
pub struct SiftDetector {
    /// Number of octaves (successively half-resolution levels) in the
    /// Gaussian pyramid.
    pub num_octaves: usize,
    /// Number of scale intervals per octave. Each octave of the Gaussian
    /// pyramid holds `num_scales + 3` blurred images.
    pub num_scales: usize,
    /// Contrast threshold.
    ///
    /// NOTE(review): only stored and reported via `params`/`set_params`; the
    /// detection steps visible in this file do not read it — confirm intent.
    pub contrast_threshold: f32,
    /// Edge threshold.
    ///
    /// NOTE(review): only stored and reported via `params`/`set_params`; the
    /// detection steps visible in this file do not read it — confirm intent.
    pub edge_threshold: f32,
    /// Base blur sigma. The blur applied at scale index `s` of an octave is
    /// `sigma * 2^(s / num_scales)`.
    pub sigma: f32,
}

impl Default for SiftDetector {
    /// Returns the default configuration: 4 octaves, 3 scales per octave,
    /// contrast threshold 0.04, edge threshold 10.0 and base sigma 1.6.
    fn default() -> Self {
        Self::with_params(
            4,    // num_octaves
            3,    // num_scales
            0.04, // contrast_threshold
            10.0, // edge_threshold
            1.6,  // sigma
        )
    }
}

impl SiftDetector {
    /// 创建新的 SIFT 检测器
    pub fn new() -> Self {
        Self::default()
    }

    /// 设置参数
    pub fn with_params(
        num_octaves: usize,
        num_scales: usize,
        contrast_threshold: f32,
        edge_threshold: f32,
        sigma: f32,
    ) -> Self {
        Self {
            num_octaves,
            num_scales,
            contrast_threshold,
            edge_threshold,
            sigma,
        }
    }

    /// Builds the Gaussian pyramid for `image`.
    ///
    /// Each octave contains `num_scales + 3` copies of that octave's base
    /// image, blurred with `sigma * 2^(scale / num_scales)` for
    /// `scale = 0..=num_scales + 2`. The base image of every following octave
    /// is the previous base image downsampled by a factor of two.
    ///
    /// At most `num_octaves` octaves are returned. Construction stops early
    /// once the base image has shrunk to zero width or height, so small
    /// inputs never yield empty octaves that later stages cannot process.
    fn build_gaussian_pyramid(&self, image: &GrayImage) -> Vec<Vec<GrayImage>> {
        // A zero `num_scales` would make the exponent 0/0 or n/0 (NaN / inf);
        // treat it as a single interval so every sigma stays finite.
        let scales_per_octave = self.num_scales.max(1) as f32;

        // The blur levels are identical for every octave, so compute them once.
        let sigmas: Vec<f32> = (0..=self.num_scales + 2)
            .map(|scale| self.sigma * 2.0_f32.powf(scale as f32 / scales_per_octave))
            .collect();

        let mut pyramid = Vec::new();
        let mut current_image = image.clone();

        for octave in 0..self.num_octaves {
            let (width, height) = current_image.dimensions();
            if width == 0 || height == 0 {
                // Repeated halving has consumed the image; nothing left to blur.
                break;
            }

            let octave_images: Vec<GrayImage> = sigmas
                .iter()
                .map(|&sigma| self.gaussian_blur(&current_image, sigma))
                .collect();
            pyramid.push(octave_images);

            // Downsample to obtain the base image of the next octave
            // (skipped after the last one).
            if octave + 1 < self.num_octaves {
                current_image = self.downsample(&current_image);
            }
        }

        pyramid
    }

    /// Blurs `image` with a Gaussian of standard deviation `sigma`.
    ///
    /// Simplified implementation that delegates to
    /// `imageproc::filter::gaussian_blur_f32`. That function panics for a
    /// non-positive sigma, and such a value can reach this method through the
    /// publicly settable `sigma` parameter. A sigma that is not a finite
    /// positive number therefore yields an unblurred copy of `image` instead
    /// of a panic.
    fn gaussian_blur(&self, image: &GrayImage, sigma: f32) -> GrayImage {
        // `is_finite` also rejects NaN, which a plain `<= 0.0` test would miss.
        if !sigma.is_finite() || sigma <= 0.0 {
            return image.clone();
        }

        imageproc::filter::gaussian_blur_f32(image, sigma)
    }

    /// 下采样图像
    fn downsample(&self, image: &GrayImage) -> GrayImage {
        let (width, height) = image.dimensions();
        let new_width = width / 2;
        let new_height = height / 2;
        
        ImageBuffer::from_fn(new_width, new_height, |x, y| {
            *image.get_pixel(x * 2, y * 2)
        })
    }

    /// Builds the difference-of-Gaussians (DoG) pyramid.
    ///
    /// For every octave, each pair of adjacent Gaussian levels is combined
    /// with `subtract_images`, so an octave with `n` levels produces `n - 1`
    /// DoG layers (none when it has fewer than two levels).
    fn build_dog_pyramid(&self, gaussian_pyramid: &[Vec<GrayImage>]) -> Vec<Vec<GrayImage>> {
        gaussian_pyramid
            .iter()
            .map(|levels| {
                levels
                    .windows(2)
                    // pair[1] is the more strongly blurred level.
                    .map(|pair| self.subtract_images(&pair[1], &pair[0]))
                    .collect()
            })
            .collect()
    }

    /// 图像相减
    fn subtract_images(&self, img1: &GrayImage, img2: &GrayImage) -> GrayImage {
        let (width, height) = img1.dimensions();
        ImageBuffer::from_fn(width, height, |x, y| {
            let val1 = img1.get_pixel(x, y)[0] as i16;
            let val2 = img2.get_pixel(x, y)[0] as i16;
            let diff = (val1 - val2).unsigned_abs() as u8;
            Luma([diff])
        })
    }

    /// Scans the DoG pyramid for extrema and converts them into features.
    ///
    /// The first and last DoG layer of each octave are skipped, as are the
    /// one-pixel image borders, so every candidate has a complete 3x3x3
    /// neighbourhood. For each pixel accepted by `is_extremum` a [`Feature`]
    /// is emitted with:
    ///
    /// * `point` — the pixel position multiplied by `2^octave`, i.e. mapped
    ///   back to input-image coordinates;
    /// * `scale` — `sigma * 2^(octave + layer / num_scales)`;
    /// * `response` — the raw DoG pixel value;
    /// * `descriptor` empty and `angle` zero (left for later stages).
    ///
    /// Layers narrower or shorter than three pixels (including zero-sized
    /// ones) contain no interior pixels and contribute no features.
    ///
    /// NOTE(review): `contrast_threshold` and `edge_threshold` are not
    /// applied here — confirm whether filtering is expected at this stage.
    fn detect_extrema(&self, dog_pyramid: &[Vec<GrayImage>]) -> Vec<Feature> {
        let mut keypoints = Vec::new();

        // A zero `num_scales` would turn the scale exponent into NaN / inf.
        let scales_per_octave = self.num_scales.max(1) as f32;

        for (octave_idx, octave_images) in dog_pyramid.iter().enumerate() {
            // Factor mapping octave coordinates back to the input image.
            let to_input = 2.0_f64.powf(octave_idx as f64);
            let last_layer = octave_images.len().saturating_sub(1);

            for (scale_idx, image) in octave_images.iter().enumerate() {
                // Boundary layers lack a neighbour above or below.
                if scale_idx == 0 || scale_idx == last_layer {
                    continue;
                }

                let (width, height) = image.dimensions();
                let scale = self.sigma
                    * 2.0_f32.powf(octave_idx as f32 + scale_idx as f32 / scales_per_octave);

                // `saturating_sub` keeps the ranges empty for zero-sized
                // layers instead of underflowing `u32`.
                for y in 1..height.saturating_sub(1) {
                    for x in 1..width.saturating_sub(1) {
                        if !self.is_extremum(octave_images, scale_idx, x, y) {
                            continue;
                        }

                        keypoints.push(Feature {
                            point: Point2::new(x as f64 * to_input, y as f64 * to_input),
                            descriptor: Vec::new(),
                            scale,
                            angle: 0.0, // computed in a later step
                            response: image.get_pixel(x, y)[0] as f32,
                            octave: octave_idx as i32,
                            point3d_id: None,
                        });
                    }
                }
            }
        }

        keypoints
    }

    /// Tests whether pixel `(x, y)` of layer `scale_idx` is a local maximum.
    ///
    /// The pixel must be strictly greater than every neighbour in its 3x3x3
    /// neighbourhood (same layer plus the layers directly below and above).
    /// Neighbouring layers that fall outside `octave_images` are ignored.
    /// Only maxima are tested. The caller must pass a pixel that is not on
    /// the image border, otherwise the neighbour lookup goes out of bounds.
    fn is_extremum(&self, octave_images: &[GrayImage], scale_idx: usize, x: u32, y: u32) -> bool {
        const OFFSETS: [i32; 3] = [-1, 0, 1];

        let centre = octave_images[scale_idx].get_pixel(x, y)[0];

        OFFSETS.iter().all(|&dz| {
            // A negative index wraps to a huge `usize` and is rejected by `get`.
            let layer_idx = (scale_idx as i32 + dz) as usize;

            match octave_images.get(layer_idx) {
                None => true,
                Some(layer) => OFFSETS.iter().all(|&dy| {
                    OFFSETS.iter().all(|&dx| {
                        let is_centre = dx == 0 && dy == 0 && dz == 0;
                        is_centre || {
                            let nx = (x as i32 + dx) as u32;
                            let ny = (y as i32 + dy) as u32;
                            // Ties disqualify the candidate.
                            centre > layer.get_pixel(nx, ny)[0]
                        }
                    })
                }),
            }
        })
    }
}

impl FeatureDetector for SiftDetector {
    /// Returns the detector's name, `"SIFT"`.
    fn name(&self) -> &str {
        "SIFT"
    }

    /// Returns the current parameters keyed by field name, each converted to `f64`.
    fn params(&self) -> std::collections::HashMap<String, f64> {
        [
            ("num_octaves", self.num_octaves as f64),
            ("num_scales", self.num_scales as f64),
            ("contrast_threshold", self.contrast_threshold as f64),
            ("edge_threshold", self.edge_threshold as f64),
            ("sigma", self.sigma as f64),
        ]
        .iter()
        .map(|&(key, value)| (key.to_string(), value))
        .collect()
    }

    /// Overwrites every parameter for which `params` has an entry.
    ///
    /// Unknown keys are ignored and missing keys leave the current value
    /// untouched. Values are converted with `as` casts and are not validated;
    /// this implementation always returns `Ok(())`.
    fn set_params(&mut self, params: std::collections::HashMap<String, f64>) -> Result<()> {
        for (key, value) in &params {
            match key.as_str() {
                "num_octaves" => self.num_octaves = *value as usize,
                "num_scales" => self.num_scales = *value as usize,
                "contrast_threshold" => self.contrast_threshold = *value as f32,
                "edge_threshold" => self.edge_threshold = *value as f32,
                "sigma" => self.sigma = *value as f32,
                _ => {}
            }
        }

        Ok(())
    }

    /// Detects features in `image`.
    ///
    /// Builds the Gaussian pyramid, derives the DoG pyramid from it and
    /// returns the extrema found in the DoG pyramid.
    fn detect(&self, image: &GrayImage) -> Result<Vec<Feature>> {
        let dog_pyramid = self.build_dog_pyramid(&self.build_gaussian_pyramid(image));
        Ok(self.detect_extrema(&dog_pyramid))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Luma};

    /// `new` yields the default octave and scale counts.
    #[test]
    fn test_sift_detector_creation() {
        let SiftDetector {
            num_octaves,
            num_scales,
            ..
        } = SiftDetector::new();

        assert_eq!(num_octaves, 4);
        assert_eq!(num_scales, 3);
    }

    /// Detection on a 100x100 diagonal-gradient image completes without error.
    #[test]
    fn test_sift_detection() {
        let gradient = ImageBuffer::from_fn(100, 100, |x, y| Luma([(x + y) as u8]));

        assert!(SiftDetector::new().detect(&gradient).is_ok());
    }
}