colmap 0.1.2

A comprehensive Rust library for COLMAP-style computer vision and 3D reconstruction
Documentation
//! BRIEF (Binary Robust Independent Elementary Features) 描述符实现
//!
//! BRIEF 是一种二进制特征描述符,通过比较图像块中像素对的强度来生成二进制字符串。
//! 它具有计算快速、存储空间小的优点,适合实时应用。

use crate::core::{Feature, ColmapError};
use image::GrayImage;
use nalgebra::Point2;
use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
use std::collections::HashMap;

/// BRIEF descriptor extractor.
///
/// Holds the descriptor configuration together with the random sampling
/// pattern precomputed from it. Treat the public fields as read-only: assigning
/// them directly does not regenerate `sampling_pairs`, which leaves the pattern
/// out of sync with the configuration and may make `compute_descriptor` index
/// out of bounds. Use `set_params` to change the configuration instead.
#[derive(Debug, Clone)]
pub struct BriefExtractor {
    /// Descriptor length in bits (one bit per sampling pair).
    pub descriptor_length: usize,
    /// Side length of the square sampling window, in pixels.
    ///
    /// Offsets are drawn from `-(patch_size / 2)..=(patch_size / 2)` on each
    /// axis, so an odd value `n` gives an `n x n` window (an even `n` gives
    /// `(n + 1) x (n + 1)`).
    pub patch_size: i32,
    /// Random seed used to generate a reproducible sampling pattern.
    pub seed: u64,
    /// Precomputed sampling point pairs, as offsets relative to the keypoint.
    sampling_pairs: Vec<(Point2<i32>, Point2<i32>)>,
}

impl Default for BriefExtractor {
    /// Creates a 256-bit (32-byte) extractor with `patch_size = 15` and seed 42.
    ///
    /// `patch_size = 15` means sampling offsets lie in `-7..=7`, i.e. a 15x15
    /// window around the keypoint (not 31x31).
    fn default() -> Self {
        // Delegate to `new` so the sampling pattern is built in exactly one place.
        Self::new(256, 15, 42)
    }
}

impl BriefExtractor {
    /// Creates a BRIEF extractor and precomputes its sampling pattern.
    ///
    /// * `descriptor_length` - number of bits, one per sampling pair; each
    ///   descriptor occupies `(descriptor_length + 7) / 8` bytes.
    /// * `patch_size` - side length of the square sampling window; offsets are
    ///   drawn from `-(patch_size / 2)..=(patch_size / 2)` on each axis.
    /// * `seed` - RNG seed. For a given `rand` version, equal
    ///   `(descriptor_length, patch_size, seed)` yield identical patterns, so
    ///   descriptors from different extractor instances stay comparable.
    ///
    /// # Panics
    ///
    /// Panics if `patch_size / 2` is negative (`patch_size <= -2`) while
    /// `descriptor_length > 0`, because the sampling range is then empty.
    pub fn new(descriptor_length: usize, patch_size: i32, seed: u64) -> Self {
        let mut extractor = Self {
            descriptor_length,
            patch_size,
            seed,
            sampling_pairs: Vec::new(),
        };
        extractor.generate_sampling_pairs();
        extractor
    }

    /// Rebuilds `sampling_pairs` from `seed`, `patch_size` and `descriptor_length`.
    ///
    /// The RNG draw order (p1.x, p1.y, p2.x, p2.y for each pair) is part of the
    /// descriptor's definition: changing it changes every descriptor produced.
    fn generate_sampling_pairs(&mut self) {
        let mut rng = StdRng::seed_from_u64(self.seed);
        self.sampling_pairs.clear();
        self.sampling_pairs.reserve(self.descriptor_length);

        let half_patch = self.patch_size / 2;

        for _ in 0..self.descriptor_length {
            let p1 = Point2::new(
                rng.gen_range(-half_patch..=half_patch),
                rng.gen_range(-half_patch..=half_patch),
            );
            let p2 = Point2::new(
                rng.gen_range(-half_patch..=half_patch),
                rng.gen_range(-half_patch..=half_patch),
            );

            self.sampling_pairs.push((p1, p2));
        }
    }

    /// Computes the BRIEF descriptor for a single keypoint.
    ///
    /// The keypoint coordinates are converted with `as i32` (truncation toward
    /// zero) to pick the centre pixel. Bit `i` of the result, stored in byte
    /// `i / 8` at bit position `i % 8` (LSB first), is set when the pixel at the
    /// first point of sampling pair `i` has strictly lower intensity than the
    /// pixel at the second point.
    ///
    /// # Errors
    ///
    /// Returns `ColmapError::InvalidParameter` when the window of half-size
    /// `patch_size / 2` centred on the keypoint does not fit inside the image.
    pub fn compute_descriptor(
        &self,
        image: &GrayImage,
        keypoint: &Point2<f64>,
    ) -> Result<Vec<u8>, ColmapError> {
        let (width, height) = image.dimensions();
        let x = keypoint.x as i32;
        let y = keypoint.y as i32;

        // Every sampling offset lies in -half_patch..=half_patch, so these bounds
        // guarantee all sampled coordinates are inside the image.
        let half_patch = self.patch_size / 2;
        if x < half_patch
            || y < half_patch
            || x >= (width as i32 - half_patch)
            || y >= (height as i32 - half_patch)
        {
            return Err(ColmapError::InvalidParameter(
                "Keypoint too close to image boundary".to_string(),
            ));
        }

        let mut descriptor = vec![0u8; (self.descriptor_length + 7) / 8];

        for (bit_idx, (p1, p2)) in self.sampling_pairs.iter().enumerate() {
            // Non-negative thanks to the boundary check above.
            let x1 = (x + p1.x) as u32;
            let y1 = (y + p1.y) as u32;
            let x2 = (x + p2.x) as u32;
            let y2 = (y + p2.y) as u32;

            let pixel1 = image.get_pixel(x1, y1)[0];
            let pixel2 = image.get_pixel(x2, y2)[0];

            if pixel1 < pixel2 {
                let byte_idx = bit_idx / 8;
                let bit_pos = bit_idx % 8;
                descriptor[byte_idx] |= 1 << bit_pos;
            }
        }

        Ok(descriptor)
    }

    /// Computes descriptors for many keypoints, one output per input, in order.
    ///
    /// Keypoints for which `compute_descriptor` fails (too close to the border)
    /// get an all-zero placeholder of the normal length instead of an error.
    /// Note that placeholders are indistinguishable from a genuine all-zero
    /// descriptor and have Hamming distance 0 to each other, so callers that
    /// match descriptors may want to filter border keypoints beforehand.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; the `Result` is kept for API stability.
    pub fn compute_descriptors(
        &self,
        image: &GrayImage,
        keypoints: &[Point2<f64>],
    ) -> Result<Vec<Vec<u8>>, ColmapError> {
        let descriptor_bytes = (self.descriptor_length + 7) / 8;
        let descriptors = keypoints
            .iter()
            .map(|keypoint| {
                self.compute_descriptor(image, keypoint)
                    .unwrap_or_else(|_| vec![0u8; descriptor_bytes])
            })
            .collect();
        Ok(descriptors)
    }

    /// Returns the number of differing bits between two descriptors.
    ///
    /// Returns `u32::MAX` when the descriptors have different byte lengths, so
    /// mismatched descriptors always compare as maximally distant.
    pub fn hamming_distance(desc1: &[u8], desc2: &[u8]) -> u32 {
        if desc1.len() != desc2.len() {
            return u32::MAX;
        }

        desc1
            .iter()
            .zip(desc2)
            .map(|(b1, b2)| (b1 ^ b2).count_ones())
            .sum()
    }

    /// Returns the configuration as a name -> value map.
    ///
    /// Keys are `"descriptor_length"`, `"patch_size"` and `"seed"`. Values are
    /// `f64`, so seeds above 2^53 are not represented exactly.
    pub fn params(&self) -> HashMap<String, f64> {
        let mut params = HashMap::new();
        params.insert("descriptor_length".to_string(), self.descriptor_length as f64);
        params.insert("patch_size".to_string(), self.patch_size as f64);
        params.insert("seed".to_string(), self.seed as f64);
        params
    }

    /// Updates the configuration from a name -> value map and regenerates the
    /// sampling pattern.
    ///
    /// Accepted keys are `"descriptor_length"`, `"patch_size"` and `"seed"`.
    /// Each value must be a finite, non-negative whole number that fits the
    /// field's type. The update is all-or-nothing: on error the extractor is
    /// left exactly as it was. An empty map is a no-op.
    ///
    /// # Errors
    ///
    /// Returns `ColmapError::InvalidParameter` for an unknown key or an
    /// out-of-range / non-integral value.
    pub fn set_params(&mut self, params: HashMap<String, f64>) -> Result<(), ColmapError> {
        // Stage every change before touching `self`. The map's iteration order
        // is unspecified, so applying keys one by one could leave a half-updated
        // extractor (new descriptor_length/patch_size with stale sampling_pairs,
        // which makes compute_descriptor panic) when a later key is rejected.
        let mut descriptor_length = self.descriptor_length;
        let mut patch_size = self.patch_size;
        let mut seed = self.seed;

        for (key, &value) in &params {
            match key.as_str() {
                "descriptor_length" => {
                    descriptor_length =
                        Self::whole_number_param(key, value, usize::MAX as f64)? as usize;
                }
                "patch_size" => {
                    // Negative sizes would make the sampling range empty and
                    // panic inside generate_sampling_pairs.
                    patch_size =
                        Self::whole_number_param(key, value, f64::from(i32::MAX))? as i32;
                }
                "seed" => {
                    seed = Self::whole_number_param(key, value, u64::MAX as f64)? as u64;
                }
                _ => {
                    return Err(ColmapError::InvalidParameter(format!(
                        "Unknown parameter: {}",
                        key
                    )))
                }
            }
        }

        // Every accepted key affects the sampling pattern, so any non-empty
        // (and fully validated) update triggers a regeneration.
        if !params.is_empty() {
            self.descriptor_length = descriptor_length;
            self.patch_size = patch_size;
            self.seed = seed;
            self.generate_sampling_pairs();
        }

        Ok(())
    }

    /// Checks that a `set_params` value is a finite, non-negative whole number
    /// no greater than `max`, returning it unchanged on success.
    fn whole_number_param(key: &str, value: f64, max: f64) -> Result<f64, ColmapError> {
        if value.is_finite() && value >= 0.0 && value.fract() == 0.0 && value <= max {
            Ok(value)
        } else {
            Err(ColmapError::InvalidParameter(format!(
                "Parameter {} must be a non-negative integer in range, got {}",
                key, value
            )))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Luma};

    /// Byte length of a default (256-bit) descriptor.
    const DEFAULT_DESCRIPTOR_BYTES: usize = (256 + 7) / 8;

    /// 100x100 diagonal gradient: pixel (x, y) has intensity (x + y) mod 256.
    fn create_test_image() -> GrayImage {
        ImageBuffer::from_fn(100, 100, |x, y| Luma([((x + y) % 256) as u8]))
    }

    #[test]
    fn test_brief_extractor_creation() {
        let BriefExtractor {
            descriptor_length,
            patch_size,
            sampling_pairs,
            ..
        } = BriefExtractor::default();

        assert_eq!(descriptor_length, 256);
        assert_eq!(patch_size, 15);
        assert_eq!(sampling_pairs.len(), 256);
    }

    #[test]
    fn test_brief_extractor_custom() {
        let BriefExtractor {
            descriptor_length,
            patch_size,
            seed,
            sampling_pairs,
        } = BriefExtractor::new(128, 21, 123);

        assert_eq!((descriptor_length, patch_size, seed), (128, 21, 123));
        assert_eq!(sampling_pairs.len(), 128);
    }

    #[test]
    fn test_compute_descriptor() {
        let image = create_test_image();
        let desc = BriefExtractor::default()
            .compute_descriptor(&image, &Point2::new(50.0, 50.0))
            .expect("a centred keypoint must be far enough from the border");

        assert_eq!(desc.len(), DEFAULT_DESCRIPTOR_BYTES);
    }

    #[test]
    fn test_compute_descriptors_batch() {
        let image = create_test_image();
        let keypoints: Vec<Point2<f64>> = [30.0, 50.0, 70.0]
            .iter()
            .map(|&c| Point2::new(c, c))
            .collect();

        let descs = BriefExtractor::default()
            .compute_descriptors(&image, &keypoints)
            .expect("batch extraction should succeed");

        assert_eq!(descs.len(), 3);
        assert!(descs.iter().all(|d| d.len() == DEFAULT_DESCRIPTOR_BYTES));
    }

    #[test]
    fn test_hamming_distance() {
        // Each byte is the bitwise complement of its partner: all 16 bits differ.
        let lhs = [0b1010_1010u8, 0b1111_0000];
        let rhs = [0b0101_0101u8, 0b0000_1111];

        assert_eq!(BriefExtractor::hamming_distance(&lhs, &rhs), 16);
    }

    #[test]
    fn test_hamming_distance_same() {
        let original = [0b1010_1010u8, 0b1111_0000];
        let copy = original;

        assert_eq!(BriefExtractor::hamming_distance(&original, &copy), 0);
    }

    #[test]
    fn test_boundary_keypoint() {
        // (5, 5) lies inside the 7-pixel margin the default window requires.
        let image = create_test_image();
        let outcome = BriefExtractor::default().compute_descriptor(&image, &Point2::new(5.0, 5.0));

        assert!(outcome.is_err());
    }

    #[test]
    fn test_params() {
        let params = BriefExtractor::default().params();
        let expected = [("descriptor_length", 256.0), ("patch_size", 15.0), ("seed", 42.0)];

        for &(name, value) in expected.iter() {
            assert_eq!(params.get(name), Some(&value), "parameter {}", name);
        }
    }

    #[test]
    fn test_set_params() {
        let mut extractor = BriefExtractor::default();
        let update: HashMap<String, f64> = [("descriptor_length", 128.0), ("patch_size", 21.0)]
            .iter()
            .map(|&(name, value)| (name.to_string(), value))
            .collect();

        assert!(extractor.set_params(update).is_ok());
        assert_eq!(extractor.descriptor_length, 128);
        assert_eq!(extractor.patch_size, 21);
        assert_eq!(extractor.sampling_pairs.len(), 128);
    }
}