colmap 0.1.2

A comprehensive Rust library for COLMAP-style computer vision and 3D reconstruction
Documentation
//! ORB (Oriented FAST and Rotated BRIEF) 特征检测器实现

use crate::core::{Feature, Point2, Result};
use crate::feature::detector::FeatureDetector;
use image::GrayImage;

/// Configuration for the ORB (Oriented FAST and Rotated BRIEF) keypoint detector.
///
/// All fields are public and may be adjusted directly after construction.
/// `Default` supplies the baseline values: 500 features, 8 pyramid levels,
/// a 1.2 scale step and a FAST threshold of 20.
#[derive(Debug, Clone)]
pub struct OrbDetector {
    /// Upper bound on the number of keypoints kept by a single detection.
    pub max_features: usize,
    /// Maximum number of pyramid levels to build, counting the full-resolution level.
    pub num_levels: usize,
    /// Ratio by which each pyramid level is shrunk relative to the previous one.
    pub scale_factor: f32,
    /// Intensity margin a circle pixel must exceed (above or below the centre pixel)
    /// to count as brighter/darker in the FAST test.
    pub fast_threshold: u8,
}

impl Default for OrbDetector {
    /// Baseline configuration: 500 features, 8 levels, scale step 1.2, FAST threshold 20.
    fn default() -> Self {
        OrbDetector {
            fast_threshold: 20,
            scale_factor: 1.2,
            num_levels: 8,
            max_features: 500,
        }
    }
}

impl OrbDetector {
    /// Offsets `(dx, dy)` of the 16 pixels on the radius-3 Bresenham circle used by the
    /// FAST segment test. They are listed in angular order, so consecutive entries
    /// (including last → first) are adjacent circle pixels.
    const FAST_CIRCLE: [(i32, i32); 16] = [
        (0, -3),
        (1, -3),
        (2, -2),
        (3, -1),
        (3, 0),
        (3, 1),
        (2, 2),
        (1, 3),
        (0, 3),
        (-1, 3),
        (-2, 2),
        (-3, 1),
        (-3, 0),
        (-3, -1),
        (-2, -2),
        (-1, -3),
    ];
    /// Minimum number of contiguous circle pixels required by the FAST-9 test.
    const FAST_ARC_LEN: usize = 9;
    /// Radius of the FAST circle; pixels closer than this to the border are not tested.
    const FAST_BORDER: u32 = 3;
    /// Radius of the circular patch used for the intensity-centroid orientation.
    const ORIENTATION_RADIUS: i32 = 15;
    /// Half-width of the square non-maximum-suppression window (7×7 pixels).
    const NMS_RADIUS: u32 = 3;
    /// Smallest width/height allowed for any pyramid level after level 0.
    const MIN_LEVEL_SIZE: u32 = 10;

    /// Creates a detector with the default configuration (same as `OrbDetector::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a detector with explicit parameters.
    ///
    /// Values are stored as given; no validation is performed here.
    pub fn with_params(
        max_features: usize,
        num_levels: usize,
        scale_factor: f32,
        fast_threshold: u8,
    ) -> Self {
        Self {
            max_features,
            num_levels,
            scale_factor,
            fast_threshold,
        }
    }

    /// Builds the image pyramid used for multi-scale detection.
    ///
    /// Level 0 is a copy of `image`. Each further level is the previous level divided by
    /// `scale_factor` in both dimensions (truncated to whole pixels, Lanczos3 resampling).
    /// Construction stops after `num_levels` levels, or as soon as the next level would be
    /// narrower or shorter than 10 pixels. At least one level is always returned.
    ///
    /// When `scale_factor` is not a finite value greater than 1, only level 0 is returned.
    /// Such factors would otherwise:
    /// - duplicate the same resolution (`== 1`);
    /// - grow every level (`< 1`), multiplying memory use per level;
    /// - divide by zero (`0`), which saturates to `u32::MAX` and requests an enormous
    ///   allocation.
    fn build_pyramid(&self, image: &GrayImage) -> Vec<GrayImage> {
        let mut pyramid = vec![image.clone()];

        if !(self.scale_factor.is_finite() && self.scale_factor > 1.0) {
            return pyramid;
        }

        for _ in 1..self.num_levels {
            let previous = pyramid.last().expect("pyramid always contains level 0");
            let (width, height) = previous.dimensions();
            let new_width = (width as f32 / self.scale_factor) as u32;
            let new_height = (height as f32 / self.scale_factor) as u32;

            if new_width < Self::MIN_LEVEL_SIZE || new_height < Self::MIN_LEVEL_SIZE {
                break;
            }

            // Resample straight from the previous level; no intermediate clone is needed.
            let next = image::imageops::resize(
                previous,
                new_width,
                new_height,
                image::imageops::FilterType::Lanczos3,
            );
            pyramid.push(next);
        }

        pyramid
    }

    /// Detects FAST-9 corners in a single image.
    ///
    /// A pixel is a corner when at least 9 *contiguous* pixels of the 16-pixel circle of
    /// radius 3 around it are either:
    /// - all brighter than `center + fast_threshold`, or
    /// - all darker than `center - fast_threshold`.
    ///
    /// The circle is treated as wrapping around, and both thresholds saturate at 0 and 255.
    /// Pixels within 3 px of the border are not tested, so images with a width or height
    /// of 6 px or less yield no corners.
    ///
    /// Returns `(x, y, response)` triples in row-major scan order. `response` comes from
    /// `calculate_fast_response`.
    fn detect_fast_corners(&self, image: &GrayImage) -> Vec<(u32, u32, f32)> {
        let (width, height) = image.dimensions();
        let border = Self::FAST_BORDER;

        // Without this guard `height - 3` / `width - 3` underflow for images smaller than
        // 3 px. That panics in debug builds and indexes out of bounds in release builds.
        if width <= 2 * border || height <= 2 * border {
            return Vec::new();
        }

        let threshold = self.fast_threshold;
        let mut corners = Vec::new();

        for y in border..height - border {
            for x in border..width - border {
                let center = image.get_pixel(x, y)[0];
                let brighter_than = center.saturating_add(threshold);
                let darker_than = center.saturating_sub(threshold);

                // Classify every circle pixel: +1 brighter, -1 darker, 0 similar.
                let mut states = [0i8; 16];
                for (state, &(dx, dy)) in states.iter_mut().zip(Self::FAST_CIRCLE.iter()) {
                    let px = (x as i32 + dx) as u32;
                    let py = (y as i32 + dy) as u32;
                    let intensity = image.get_pixel(px, py)[0];
                    *state = if intensity > brighter_than {
                        1
                    } else if intensity < darker_than {
                        -1
                    } else {
                        0
                    };
                }

                if Self::has_contiguous_arc(&states, Self::FAST_ARC_LEN) {
                    let response =
                        self.calculate_fast_response(image, x, y, &Self::FAST_CIRCLE);
                    corners.push((x, y, response));
                }
            }
        }

        corners
    }

    /// Returns `true` when `states` holds a run of at least `arc_len` equal, non-zero entries.
    ///
    /// `states` has one entry per circle pixel in angular order, and the slice is treated as
    /// circular. Returns `false` when `arc_len` is 0 or longer than `states`.
    fn has_contiguous_arc(states: &[i8], arc_len: usize) -> bool {
        let n = states.len();
        if arc_len == 0 || arc_len > n {
            return false;
        }

        let mut run_state = 0i8;
        let mut run_len = 0usize;
        // Walking `arc_len - 1` entries past the end lets an arc that wraps around
        // index 0 be counted as one run.
        for &state in states.iter().cycle().take(n + arc_len - 1) {
            if state != 0 && state == run_state {
                run_len += 1;
            } else {
                run_state = state;
                run_len = usize::from(state != 0);
            }
            if run_len >= arc_len {
                return true;
            }
        }
        false
    }

    /// Computes the FAST corner score for the pixel at `(x, y)`.
    ///
    /// The score is the mean absolute intensity difference between the centre pixel and
    /// the pixels at `circle_offsets`. An empty offset list scores 0.
    ///
    /// The caller must keep every `(x + dx, y + dy)` inside the image; out-of-range
    /// coordinates make `get_pixel` panic.
    fn calculate_fast_response(
        &self,
        image: &GrayImage,
        x: u32,
        y: u32,
        circle_offsets: &[(i32, i32)],
    ) -> f32 {
        if circle_offsets.is_empty() {
            return 0.0;
        }

        let center = f32::from(image.get_pixel(x, y)[0]);
        let sum_diff: f32 = circle_offsets
            .iter()
            .map(|&(dx, dy)| {
                let px = (x as i32 + dx) as u32;
                let py = (y as i32 + dy) as u32;
                (f32::from(image.get_pixel(px, py)[0]) - center).abs()
            })
            .sum();

        sum_diff / circle_offsets.len() as f32
    }

    /// Estimates the keypoint orientation (radians, in `[-π, π]`) with the intensity centroid.
    ///
    /// It accumulates the first-order moments `m10 = Σ dx·I` and `m01 = Σ dy·I` over a
    /// circular patch of radius 15 centred on `(x, y)`, then returns `atan2(m01, m10)`.
    /// Patch pixels outside the image are skipped.
    ///
    /// The patch is circular rather than square, so the set of sampled pixels does not
    /// change when the image is rotated about the keypoint. A square window's corners
    /// rotate in and out of the sample set and bias the angle.
    fn compute_orientation(&self, image: &GrayImage, x: u32, y: u32) -> f32 {
        let (width, height) = image.dimensions();
        let radius = Self::ORIENTATION_RADIUS;
        let radius_sq = radius * radius;

        let mut m01 = 0.0f32;
        let mut m10 = 0.0f32;

        for dy in -radius..=radius {
            let py = i64::from(y) + i64::from(dy);
            if py < 0 || py >= i64::from(height) {
                continue;
            }
            for dx in -radius..=radius {
                if dx * dx + dy * dy > radius_sq {
                    continue;
                }
                let px = i64::from(x) + i64::from(dx);
                if px < 0 || px >= i64::from(width) {
                    continue;
                }
                let intensity = f32::from(image.get_pixel(px as u32, py as u32)[0]);
                m01 += dy as f32 * intensity;
                m10 += dx as f32 * intensity;
            }
        }

        m01.atan2(m10)
    }

    /// Non-maximum suppression on `(x, y, response)` corners.
    ///
    /// A corner is dropped when some *other* corner inside the 7×7 window around it
    /// (Chebyshev distance ≤ 3) has a strictly larger response. Equal responses survive
    /// together, and survivors keep their input order.
    ///
    /// The strongest response at each position is indexed in a hash map, so each corner
    /// needs only 49 lookups. The previous approach compared all pairs, which was O(n²).
    fn non_maximum_suppression(&self, corners: &[(u32, u32, f32)]) -> Vec<(u32, u32, f32)> {
        let radius = Self::NMS_RADIUS;

        // Strongest response seen at each pixel; duplicates at one position keep the max
        // (`f32::max` ignores a NaN operand, matching the pairwise `>` semantics).
        let mut strongest: std::collections::HashMap<(u32, u32), f32> =
            std::collections::HashMap::with_capacity(corners.len());
        for &(x, y, response) in corners {
            strongest
                .entry((x, y))
                .and_modify(|best| *best = best.max(response))
                .or_insert(response);
        }

        corners
            .iter()
            .copied()
            .filter(|&(x, y, response)| {
                let dominated = (y.saturating_sub(radius)..=y.saturating_add(radius)).any(|ny| {
                    (x.saturating_sub(radius)..=x.saturating_add(radius)).any(|nx| {
                        matches!(strongest.get(&(nx, ny)), Some(&best) if best > response)
                    })
                });
                !dominated
            })
            .collect()
    }
}

impl FeatureDetector for OrbDetector {
    /// Returns the detector's identifier, `"ORB"`.
    fn name(&self) -> &str {
        "ORB"
    }

    /// Reports the tunable parameters as `f64` values keyed by field name.
    fn params(&self) -> std::collections::HashMap<String, f64> {
        let mut params = std::collections::HashMap::new();
        params.insert("max_features".to_string(), self.max_features as f64);
        params.insert("num_levels".to_string(), self.num_levels as f64);
        params.insert("scale_factor".to_string(), self.scale_factor as f64);
        params.insert("fast_threshold".to_string(), self.fast_threshold as f64);
        params
    }

    /// Updates any of `max_features`, `num_levels`, `scale_factor`, `fast_threshold`
    /// present in `params`. Unknown keys are ignored.
    ///
    /// Values are converted with `as`, so out-of-range or NaN inputs saturate (for
    /// example, negative or NaN counts become 0). No error is ever returned.
    fn set_params(&mut self, params: std::collections::HashMap<String, f64>) -> Result<()> {
        if let Some(&max_features) = params.get("max_features") {
            self.max_features = max_features as usize;
        }
        if let Some(&num_levels) = params.get("num_levels") {
            self.num_levels = num_levels as usize;
        }
        if let Some(&scale_factor) = params.get("scale_factor") {
            self.scale_factor = scale_factor as f32;
        }
        if let Some(&fast_threshold) = params.get("fast_threshold") {
            self.fast_threshold = fast_threshold as u8;
        }
        Ok(())
    }

    /// Detects up to `max_features` ORB keypoints, strongest response first.
    ///
    /// Detection works as follows:
    /// 1. Every pyramid level is scanned with FAST.
    /// 2. Each level is thinned by non-maximum suppression.
    /// 3. All candidates are ranked by response and cut to `max_features`.
    /// 4. Only the kept candidates get an orientation.
    ///
    /// Coordinates are mapped back to level-0 pixels by multiplying with
    /// `scale_factor^level`. Descriptors are left empty for a later extraction stage.
    fn detect(&self, image: &GrayImage) -> Result<Vec<Feature>> {
        let pyramid = self.build_pyramid(image);

        // Collect (level, x, y, response) first. Orientation needs a 31×31 patch per
        // keypoint, so it is deferred until after the `max_features` cut.
        let mut candidates: Vec<(usize, u32, u32, f32)> = Vec::new();
        for (level, level_image) in pyramid.iter().enumerate() {
            let corners = self.detect_fast_corners(level_image);
            candidates.extend(
                self.non_maximum_suppression(&corners)
                    .into_iter()
                    .map(|(x, y, response)| (level, x, y, response)),
            );
        }

        // Strongest first. `sort_by` is stable, so equal responses keep level/scan order.
        // `total_cmp` is a total order, so a NaN response can no longer panic the sort.
        candidates.sort_by(|a, b| b.3.total_cmp(&a.3));
        candidates.truncate(self.max_features);

        let features = candidates
            .into_iter()
            .map(|(level, x, y, response)| {
                let scale = self.scale_factor.powi(level as i32);
                Feature {
                    point: Point2::new((x as f64) * scale as f64, (y as f64) * scale as f64),
                    descriptor: Vec::new(), // filled in by the descriptor-extraction stage
                    scale,
                    angle: self.compute_orientation(&pyramid[level], x, y),
                    response,
                    octave: level as i32,
                    point3d_id: None,
                }
            })
            .collect();

        Ok(features)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Luma};

    /// `OrbDetector::new` exposes the baseline feature budget and level count.
    #[test]
    fn test_orb_detector_creation() {
        let orb = OrbDetector::new();
        assert_eq!(orb.max_features, 500);
        assert_eq!(orb.num_levels, 8);
    }

    /// Running the full detector on a smooth diagonal ramp completes without error.
    #[test]
    fn test_orb_detection() {
        let ramp: GrayImage =
            ImageBuffer::from_fn(100, 100, |col, row| Luma([((col + row) % 256) as u8]));

        let outcome = OrbDetector::new().detect(&ramp);

        assert!(outcome.is_ok());
    }

    /// The pyramid is non-empty and its first level keeps the input resolution.
    #[test]
    fn test_pyramid_building() {
        let flat: GrayImage = ImageBuffer::from_pixel(100, 100, Luma([128]));

        let levels = OrbDetector::new().build_pyramid(&flat);

        assert!(!levels.is_empty());
        assert_eq!(levels[0].dimensions(), (100, 100));
    }
}