use crate::core::{Feature, ColmapError};
use image::GrayImage;
use nalgebra::Point2;
use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
use std::collections::HashMap;
/// BRIEF-style binary descriptor extractor.
///
/// Each descriptor bit compares the intensities of two pixels taken at fixed
/// offsets around a keypoint. The offset pairs are drawn once from an RNG seeded
/// with `seed`, so the same configuration always yields the same descriptors.
#[derive(Debug, Clone)]
pub struct BriefExtractor {
/// Number of pixel-pair comparisons, i.e. bits per descriptor.
pub descriptor_length: usize,
/// Side length of the square sampling window; offsets lie in `-patch_size/2..=patch_size/2`.
pub patch_size: i32,
/// Seed for the RNG that generates the sampling pattern.
pub seed: u64,
// Offset pairs derived from the three public fields above. NOTE(review): assigning
// those fields directly does not regenerate this list; go through the constructors
// or `set_params` so it stays in sync.
sampling_pairs: Vec<(Point2<i32>, Point2<i32>)>,
}
impl Default for BriefExtractor {
    /// Reference configuration: 256-bit descriptors, offsets drawn from a 15x15
    /// window, and a fixed seed of 42 for reproducible sampling patterns.
    fn default() -> Self {
        Self::new(256, 15, 42)
    }
}
impl BriefExtractor {
    /// Creates an extractor and pre-computes its sampling pattern.
    ///
    /// * `descriptor_length` - number of pixel-pair comparisons (bits) per descriptor.
    /// * `patch_size` - side of the square window the pairs are drawn from; a
    ///   negative value is treated like `0` (every pair samples the keypoint pixel).
    /// * `seed` - RNG seed, making the sampling pattern reproducible.
    pub fn new(descriptor_length: usize, patch_size: i32, seed: u64) -> Self {
        let mut extractor = Self {
            descriptor_length,
            patch_size,
            seed,
            sampling_pairs: Vec::new(),
        };
        extractor.generate_sampling_pairs();
        extractor
    }

    /// Half-width of the sampling window.
    ///
    /// Clamped at zero so a negative `patch_size` can neither produce an empty
    /// (panicking) `gen_range` range nor a negative border margin.
    fn half_patch(&self) -> i32 {
        self.patch_size.max(0) / 2
    }

    /// Bytes needed to store one bit per sampling pair.
    ///
    /// Derived from the pairs actually used rather than `descriptor_length`, so
    /// the buffer can never be smaller than the bits written into it.
    fn descriptor_bytes(&self) -> usize {
        (self.sampling_pairs.len() + 7) / 8
    }

    /// Error returned when a keypoint's sampling window leaves the image.
    fn boundary_error() -> ColmapError {
        ColmapError::InvalidParameter("Keypoint too close to image boundary".to_string())
    }

    /// Accepts only finite, non-negative whole numbers no greater than `max`.
    ///
    /// Guards the `f64 -> integer` casts in `set_params`, which would otherwise
    /// silently truncate fractions and saturate NaN/negative values to zero.
    fn check_integer_param(key: &str, value: f64, max: f64) -> Result<(), ColmapError> {
        if value.is_finite() && value >= 0.0 && value.fract() == 0.0 && value <= max {
            Ok(())
        } else {
            Err(ColmapError::InvalidParameter(format!(
                "Parameter {} must be a non-negative integer no greater than {}, got {}",
                key, max, value
            )))
        }
    }

    /// Rebuilds `descriptor_length` random offset pairs inside the sampling window.
    ///
    /// The RNG is re-seeded from `seed` on every call, so identical parameters
    /// always produce an identical pattern. Draw order is p1.x, p1.y, p2.x, p2.y.
    fn generate_sampling_pairs(&mut self) {
        let mut rng = StdRng::seed_from_u64(self.seed);
        let half_patch = self.half_patch();
        let mut random_offset = || {
            Point2::new(
                rng.gen_range(-half_patch..=half_patch),
                rng.gen_range(-half_patch..=half_patch),
            )
        };
        self.sampling_pairs = (0..self.descriptor_length)
            .map(|_| {
                let p1 = random_offset();
                let p2 = random_offset();
                (p1, p2)
            })
            .collect();
    }

    /// Computes the binary descriptor of a single keypoint.
    ///
    /// Bit `i` is set when the pixel at the first offset of pair `i` is darker
    /// than the pixel at the second offset. Bits are packed LSB-first into bytes.
    /// The keypoint is snapped to the pixel containing it (`floor`).
    ///
    /// # Errors
    ///
    /// Returns `ColmapError::InvalidParameter` when a coordinate is not finite,
    /// or when the sampling window around the keypoint does not lie fully
    /// inside the image.
    pub fn compute_descriptor(
        &self,
        image: &GrayImage,
        keypoint: &Point2<f64>,
    ) -> Result<Vec<u8>, ColmapError> {
        // `NaN as i32` is 0, which would silently sample the image corner.
        if !keypoint.x.is_finite() || !keypoint.y.is_finite() {
            return Err(ColmapError::InvalidParameter(
                "Keypoint coordinates must be finite".to_string(),
            ));
        }
        let (width, height) = image.dimensions();
        // floor (not truncation toward zero) so keypoints in (-1, 0) are not
        // mistaken for pixel 0.
        let x = keypoint.x.floor() as i32;
        let y = keypoint.y.floor() as i32;
        let half_patch = self.half_patch();
        if x < half_patch
            || y < half_patch
            || x >= width as i32 - half_patch
            || y >= height as i32 - half_patch
        {
            return Err(Self::boundary_error());
        }

        // Per-sample bounds check: the pattern may be stale relative to
        // `patch_size` if the public fields were mutated directly, and an
        // out-of-range `get_pixel` would panic.
        let sample = |offset: &Point2<i32>| -> Option<u8> {
            let px = x.checked_add(offset.x)?;
            let py = y.checked_add(offset.y)?;
            if px < 0 || py < 0 || px as u32 >= width || py as u32 >= height {
                return None;
            }
            Some(image.get_pixel(px as u32, py as u32)[0])
        };

        let mut descriptor = vec![0u8; self.descriptor_bytes()];
        for (bit_idx, (p1, p2)) in self.sampling_pairs.iter().enumerate() {
            let intensity1 = sample(p1).ok_or_else(Self::boundary_error)?;
            let intensity2 = sample(p2).ok_or_else(Self::boundary_error)?;
            if intensity1 < intensity2 {
                descriptor[bit_idx / 8] |= 1 << (bit_idx % 8);
            }
        }
        Ok(descriptor)
    }

    /// Computes descriptors for many keypoints, one output per input, in order.
    ///
    /// Keypoints whose descriptor cannot be computed (e.g. too close to the
    /// border) get an all-zero descriptor so indices stay aligned with
    /// `keypoints`. Such a placeholder is indistinguishable from a genuine
    /// all-zero descriptor. Currently never returns `Err`.
    pub fn compute_descriptors(
        &self,
        image: &GrayImage,
        keypoints: &[Point2<f64>],
    ) -> Result<Vec<Vec<u8>>, ColmapError> {
        let fallback_len = self.descriptor_bytes();
        let descriptors = keypoints
            .iter()
            .map(|keypoint| {
                self.compute_descriptor(image, keypoint)
                    .unwrap_or_else(|_| vec![0u8; fallback_len])
            })
            .collect();
        Ok(descriptors)
    }

    /// Number of differing bits between two descriptors.
    ///
    /// Returns `u32::MAX` when the lengths differ, so mismatched descriptors
    /// always rank as the worst possible match.
    pub fn hamming_distance(desc1: &[u8], desc2: &[u8]) -> u32 {
        if desc1.len() != desc2.len() {
            return u32::MAX;
        }
        desc1
            .iter()
            .zip(desc2)
            .map(|(a, b)| (a ^ b).count_ones())
            .sum()
    }

    /// Current configuration as a name -> value map.
    ///
    /// Values are exported as `f64`, so seeds above 2^53 lose precision.
    pub fn params(&self) -> HashMap<String, f64> {
        let mut params = HashMap::new();
        params.insert("descriptor_length".to_string(), self.descriptor_length as f64);
        params.insert("patch_size".to_string(), self.patch_size as f64);
        params.insert("seed".to_string(), self.seed as f64);
        params
    }

    /// Updates any of `descriptor_length`, `patch_size` and `seed`, then
    /// regenerates the sampling pattern.
    ///
    /// The update is all-or-nothing: every entry is validated before anything
    /// is modified. A partial update would leave `descriptor_length` out of
    /// sync with the sampling pairs.
    ///
    /// # Errors
    ///
    /// Returns `ColmapError::InvalidParameter` for an unknown key, or for a
    /// value that is not a finite non-negative whole number in range. `self`
    /// is left unchanged in that case.
    pub fn set_params(&mut self, params: HashMap<String, f64>) -> Result<(), ColmapError> {
        let mut descriptor_length = None;
        let mut patch_size = None;
        let mut seed = None;
        for (key, value) in params {
            match key.as_str() {
                "descriptor_length" => {
                    Self::check_integer_param(&key, value, usize::MAX as f64)?;
                    descriptor_length = Some(value as usize);
                }
                "patch_size" => {
                    Self::check_integer_param(&key, value, i32::MAX as f64)?;
                    patch_size = Some(value as i32);
                }
                "seed" => {
                    Self::check_integer_param(&key, value, u64::MAX as f64)?;
                    seed = Some(value as u64);
                }
                _ => {
                    return Err(ColmapError::InvalidParameter(format!(
                        "Unknown parameter: {}",
                        key
                    )))
                }
            }
        }

        if descriptor_length.is_none() && patch_size.is_none() && seed.is_none() {
            return Ok(());
        }
        if let Some(value) = descriptor_length {
            self.descriptor_length = value;
        }
        if let Some(value) = patch_size {
            self.patch_size = value;
        }
        if let Some(value) = seed {
            self.seed = value;
        }
        self.generate_sampling_pairs();
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Luma};

    /// 100x100 diagonal gradient: pixel value is `(x + y) % 256`.
    fn create_test_image() -> GrayImage {
        ImageBuffer::from_fn(100, 100, |x, y| Luma([((x + y) % 256) as u8]))
    }

    #[test]
    fn test_brief_extractor_creation() {
        let extractor = BriefExtractor::default();
        assert_eq!(extractor.descriptor_length, 256);
        assert_eq!(extractor.patch_size, 15);
        assert_eq!(extractor.sampling_pairs.len(), 256);
    }

    #[test]
    fn test_brief_extractor_custom() {
        let extractor = BriefExtractor::new(128, 21, 123);
        assert_eq!(extractor.descriptor_length, 128);
        assert_eq!(extractor.patch_size, 21);
        assert_eq!(extractor.seed, 123);
        assert_eq!(extractor.sampling_pairs.len(), 128);
    }

    #[test]
    fn test_sampling_pairs_within_patch() {
        // patch_size 9 -> offsets must lie in -4..=4.
        let extractor = BriefExtractor::new(64, 9, 3);
        for (p1, p2) in &extractor.sampling_pairs {
            for &c in [p1.x, p1.y, p2.x, p2.y].iter() {
                assert!((-4..=4).contains(&c), "offset {} outside patch", c);
            }
        }
    }

    #[test]
    fn test_compute_descriptor() {
        let extractor = BriefExtractor::default();
        let image = create_test_image();
        let keypoint = Point2::new(50.0, 50.0);
        let descriptor = extractor.compute_descriptor(&image, &keypoint);
        assert!(descriptor.is_ok());
        let desc = descriptor.unwrap();
        assert_eq!(desc.len(), (256 + 7) / 8);
    }

    #[test]
    fn test_descriptor_deterministic_for_seed() {
        let image = create_test_image();
        let keypoint = Point2::new(50.0, 50.0);
        let a = BriefExtractor::new(128, 15, 7)
            .compute_descriptor(&image, &keypoint)
            .unwrap();
        let b = BriefExtractor::new(128, 15, 7)
            .compute_descriptor(&image, &keypoint)
            .unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn test_compute_descriptors_batch() {
        let extractor = BriefExtractor::default();
        let image = create_test_image();
        let keypoints = vec![
            Point2::new(30.0, 30.0),
            Point2::new(50.0, 50.0),
            Point2::new(70.0, 70.0),
        ];
        let descriptors = extractor.compute_descriptors(&image, &keypoints);
        assert!(descriptors.is_ok());
        let descs = descriptors.unwrap();
        assert_eq!(descs.len(), 3);
        for desc in descs {
            assert_eq!(desc.len(), (256 + 7) / 8);
        }
    }

    #[test]
    fn test_compute_descriptors_boundary_fallback() {
        // A keypoint too close to the border yields an all-zero placeholder,
        // keeping output indices aligned with the input.
        let extractor = BriefExtractor::default();
        let image = create_test_image();
        let keypoints = vec![Point2::new(2.0, 2.0), Point2::new(50.0, 50.0)];
        let descs = extractor.compute_descriptors(&image, &keypoints).unwrap();
        assert_eq!(descs.len(), 2);
        assert_eq!(descs[0], vec![0u8; (256 + 7) / 8]);
    }

    #[test]
    fn test_hamming_distance() {
        let desc1 = vec![0b10101010, 0b11110000];
        let desc2 = vec![0b01010101, 0b00001111];
        let distance = BriefExtractor::hamming_distance(&desc1, &desc2);
        assert_eq!(distance, 16);
    }

    #[test]
    fn test_hamming_distance_same() {
        let desc1 = vec![0b10101010, 0b11110000];
        let desc2 = vec![0b10101010, 0b11110000];
        let distance = BriefExtractor::hamming_distance(&desc1, &desc2);
        assert_eq!(distance, 0);
    }

    #[test]
    fn test_hamming_distance_length_mismatch() {
        let distance = BriefExtractor::hamming_distance(&[0u8, 1], &[0u8]);
        assert_eq!(distance, u32::MAX);
    }

    #[test]
    fn test_boundary_keypoint() {
        let extractor = BriefExtractor::default();
        let image = create_test_image();
        let keypoint = Point2::new(5.0, 5.0);
        let descriptor = extractor.compute_descriptor(&image, &keypoint);
        assert!(descriptor.is_err());
    }

    #[test]
    fn test_params() {
        let extractor = BriefExtractor::default();
        let params = extractor.params();
        assert_eq!(params.get("descriptor_length"), Some(&256.0));
        assert_eq!(params.get("patch_size"), Some(&15.0));
        assert_eq!(params.get("seed"), Some(&42.0));
    }

    #[test]
    fn test_set_params() {
        let mut extractor = BriefExtractor::default();
        let mut new_params = HashMap::new();
        new_params.insert("descriptor_length".to_string(), 128.0);
        new_params.insert("patch_size".to_string(), 21.0);
        let result = extractor.set_params(new_params);
        assert!(result.is_ok());
        assert_eq!(extractor.descriptor_length, 128);
        assert_eq!(extractor.patch_size, 21);
        assert_eq!(extractor.sampling_pairs.len(), 128);
    }

    #[test]
    fn test_set_params_unknown_key() {
        let mut extractor = BriefExtractor::default();
        let mut new_params = HashMap::new();
        new_params.insert("unknown".to_string(), 1.0);
        assert!(extractor.set_params(new_params).is_err());
        assert_eq!(extractor.descriptor_length, 256);
        assert_eq!(extractor.sampling_pairs.len(), 256);
    }
}