use crate::core::{Feature, Point2, Result};
use crate::feature::detector::FeatureDetector;
use image::GrayImage;
#[derive(Debug, Clone)]
pub struct OrbDetector {
pub max_features: usize,
pub num_levels: usize,
pub scale_factor: f32,
pub fast_threshold: u8,
}
impl Default for OrbDetector {
fn default() -> Self {
Self {
max_features: 500,
num_levels: 8,
scale_factor: 1.2,
fast_threshold: 20,
}
}
}
impl OrbDetector {
pub fn new() -> Self {
Self::default()
}
pub fn with_params(
max_features: usize,
num_levels: usize,
scale_factor: f32,
fast_threshold: u8,
) -> Self {
Self {
max_features,
num_levels,
scale_factor,
fast_threshold,
}
}
fn build_pyramid(&self, image: &GrayImage) -> Vec<GrayImage> {
let mut pyramid = vec![image.clone()];
let mut current_image = image.clone();
for _ in 1..self.num_levels {
let (width, height) = current_image.dimensions();
let new_width = (width as f32 / self.scale_factor) as u32;
let new_height = (height as f32 / self.scale_factor) as u32;
if new_width < 10 || new_height < 10 {
break;
}
current_image = image::imageops::resize(
¤t_image,
new_width,
new_height,
image::imageops::FilterType::Lanczos3,
);
pyramid.push(current_image.clone());
}
pyramid
}
fn detect_fast_corners(&self, image: &GrayImage) -> Vec<(u32, u32, f32)> {
let mut corners = Vec::new();
let (width, height) = image.dimensions();
let circle_offsets = [
(0, -3), (1, -3), (2, -2), (3, -1),
(3, 0), (3, 1), (2, 2), (1, 3),
(0, 3), (-1, 3), (-2, 2), (-3, 1),
(-3, 0), (-3, -1), (-2, -2), (-1, -3),
];
for y in 3..height-3 {
for x in 3..width-3 {
let center_intensity = image.get_pixel(x, y)[0];
let threshold = self.fast_threshold;
let mut brighter_count = 0;
let mut darker_count = 0;
for &(dx, dy) in &circle_offsets {
let px = (x as i32 + dx) as u32;
let py = (y as i32 + dy) as u32;
let intensity = image.get_pixel(px, py)[0];
if intensity > center_intensity.saturating_add(threshold) {
brighter_count += 1;
} else if intensity < center_intensity.saturating_sub(threshold) {
darker_count += 1;
}
}
if brighter_count >= 9 || darker_count >= 9 {
let response = self.calculate_fast_response(image, x, y, &circle_offsets);
corners.push((x, y, response));
}
}
}
corners
}
fn calculate_fast_response(
&self,
image: &GrayImage,
x: u32,
y: u32,
circle_offsets: &[(i32, i32)],
) -> f32 {
let center_intensity = image.get_pixel(x, y)[0] as f32;
let mut sum_diff = 0.0;
for &(dx, dy) in circle_offsets {
let px = (x as i32 + dx) as u32;
let py = (y as i32 + dy) as u32;
let intensity = image.get_pixel(px, py)[0] as f32;
sum_diff += (intensity - center_intensity).abs();
}
sum_diff / circle_offsets.len() as f32
}
fn compute_orientation(&self, image: &GrayImage, x: u32, y: u32) -> f32 {
let (width, height) = image.dimensions();
let radius = 15;
let mut m01 = 0.0;
let mut m10 = 0.0;
for dy in -radius..=radius {
for dx in -radius..=radius {
let px = x as i32 + dx;
let py = y as i32 + dy;
if px >= 0 && px < width as i32 && py >= 0 && py < height as i32 {
let intensity = image.get_pixel(px as u32, py as u32)[0] as f32;
m01 += dy as f32 * intensity;
m10 += dx as f32 * intensity;
}
}
}
m01.atan2(m10)
}
fn non_maximum_suppression(&self, corners: &[(u32, u32, f32)]) -> Vec<(u32, u32, f32)> {
let mut suppressed = Vec::new();
let radius = 3;
for (i, &(x1, y1, response1)) in corners.iter().enumerate() {
let mut is_maximum = true;
for (j, &(x2, y2, response2)) in corners.iter().enumerate() {
if i == j { continue; }
let dx = (x1 as i32 - x2 as i32).abs();
let dy = (y1 as i32 - y2 as i32).abs();
if dx <= radius && dy <= radius && response2 > response1 {
is_maximum = false;
break;
}
}
if is_maximum {
suppressed.push((x1, y1, response1));
}
}
suppressed
}
}
impl FeatureDetector for OrbDetector {
fn name(&self) -> &str {
"ORB"
}
fn params(&self) -> std::collections::HashMap<String, f64> {
let mut params = std::collections::HashMap::new();
params.insert("max_features".to_string(), self.max_features as f64);
params.insert("num_levels".to_string(), self.num_levels as f64);
params.insert("scale_factor".to_string(), self.scale_factor as f64);
params.insert("fast_threshold".to_string(), self.fast_threshold as f64);
params
}
fn set_params(&mut self, params: std::collections::HashMap<String, f64>) -> Result<()> {
if let Some(&max_features) = params.get("max_features") {
self.max_features = max_features as usize;
}
if let Some(&num_levels) = params.get("num_levels") {
self.num_levels = num_levels as usize;
}
if let Some(&scale_factor) = params.get("scale_factor") {
self.scale_factor = scale_factor as f32;
}
if let Some(&fast_threshold) = params.get("fast_threshold") {
self.fast_threshold = fast_threshold as u8;
}
Ok(())
}
fn detect(&self, image: &GrayImage) -> Result<Vec<Feature>> {
let mut all_features = Vec::new();
let pyramid = self.build_pyramid(image);
for (level, level_image) in pyramid.iter().enumerate() {
let corners = self.detect_fast_corners(level_image);
let suppressed_corners = self.non_maximum_suppression(&corners);
for (x, y, response) in suppressed_corners {
let scale_factor = self.scale_factor.powi(level as i32);
let orientation = self.compute_orientation(level_image, x, y);
let feature = Feature {
point: Point2::new(
(x as f64) * scale_factor as f64,
(y as f64) * scale_factor as f64,
),
descriptor: Vec::new(), scale: scale_factor,
angle: orientation,
response,
octave: level as i32,
point3d_id: None,
};
all_features.push(feature);
}
}
all_features.sort_by(|a, b| b.response.partial_cmp(&a.response).unwrap());
all_features.truncate(self.max_features);
Ok(all_features)
}
}
#[cfg(test)]
mod tests {
use super::*;
use image::{ImageBuffer, Luma};
#[test]
fn test_orb_detector_creation() {
let detector = OrbDetector::new();
assert_eq!(detector.max_features, 500);
assert_eq!(detector.num_levels, 8);
}
#[test]
fn test_orb_detection() {
let detector = OrbDetector::new();
let image = ImageBuffer::from_fn(100, 100, |x, y| {
Luma([((x + y) % 256) as u8])
});
let result = detector.detect(&image);
assert!(result.is_ok());
}
#[test]
fn test_pyramid_building() {
let detector = OrbDetector::new();
let image = ImageBuffer::from_fn(100, 100, |_, _| Luma([128]));
let pyramid = detector.build_pyramid(&image);
assert!(!pyramid.is_empty());
assert_eq!(pyramid[0].dimensions(), (100, 100));
}
}