use crate::core::{Feature, Point2, Result};
use crate::feature::detector::FeatureDetector;
use image::{GrayImage, ImageBuffer, Luma};
/// Configuration for a SIFT-style keypoint detector.
///
/// Keypoints are located as extrema of a difference-of-Gaussians pyramid built
/// from the input image.
#[derive(Debug, Clone, PartialEq)]
pub struct SiftDetector {
    /// Number of octaves in the pyramid; each successive octave halves the image size.
    pub num_octaves: usize,
    /// Number of scale intervals per octave; each octave holds `num_scales + 3` blurred levels.
    pub num_scales: usize,
    /// Contrast threshold (default `0.04`).
    pub contrast_threshold: f32,
    /// Edge-response threshold (default `10.0`).
    pub edge_threshold: f32,
    /// Base Gaussian standard deviation; level `s` of an octave is blurred with
    /// `sigma * 2^(s / num_scales)`.
    pub sigma: f32,
}
impl Default for SiftDetector {
fn default() -> Self {
Self {
num_octaves: 4,
num_scales: 3,
contrast_threshold: 0.04,
edge_threshold: 10.0,
sigma: 1.6,
}
}
}
impl SiftDetector {
pub fn new() -> Self {
Self::default()
}
pub fn with_params(
num_octaves: usize,
num_scales: usize,
contrast_threshold: f32,
edge_threshold: f32,
sigma: f32,
) -> Self {
Self {
num_octaves,
num_scales,
contrast_threshold,
edge_threshold,
sigma,
}
}
fn build_gaussian_pyramid(&self, image: &GrayImage) -> Vec<Vec<GrayImage>> {
let mut pyramid = Vec::new();
let mut current_image = image.clone();
for octave in 0..self.num_octaves {
let mut octave_images = Vec::new();
for scale in 0..=self.num_scales + 2 {
let sigma = self.sigma * (2.0_f32).powf(scale as f32 / self.num_scales as f32);
let blurred = self.gaussian_blur(¤t_image, sigma);
octave_images.push(blurred);
}
pyramid.push(octave_images);
if octave < self.num_octaves - 1 {
current_image = self.downsample(¤t_image);
}
}
pyramid
}
fn gaussian_blur(&self, image: &GrayImage, sigma: f32) -> GrayImage {
imageproc::filter::gaussian_blur_f32(image, sigma)
}
fn downsample(&self, image: &GrayImage) -> GrayImage {
let (width, height) = image.dimensions();
let new_width = width / 2;
let new_height = height / 2;
ImageBuffer::from_fn(new_width, new_height, |x, y| {
*image.get_pixel(x * 2, y * 2)
})
}
fn build_dog_pyramid(&self, gaussian_pyramid: &[Vec<GrayImage>]) -> Vec<Vec<GrayImage>> {
let mut dog_pyramid = Vec::new();
for octave_images in gaussian_pyramid {
let mut dog_octave = Vec::new();
for i in 1..octave_images.len() {
let dog = self.subtract_images(&octave_images[i], &octave_images[i-1]);
dog_octave.push(dog);
}
dog_pyramid.push(dog_octave);
}
dog_pyramid
}
fn subtract_images(&self, img1: &GrayImage, img2: &GrayImage) -> GrayImage {
let (width, height) = img1.dimensions();
ImageBuffer::from_fn(width, height, |x, y| {
let val1 = img1.get_pixel(x, y)[0] as i16;
let val2 = img2.get_pixel(x, y)[0] as i16;
let diff = (val1 - val2).unsigned_abs() as u8;
Luma([diff])
})
}
fn detect_extrema(&self, dog_pyramid: &[Vec<GrayImage>]) -> Vec<Feature> {
let mut keypoints = Vec::new();
for (octave_idx, octave_images) in dog_pyramid.iter().enumerate() {
for (scale_idx, image) in octave_images.iter().enumerate() {
if scale_idx == 0 || scale_idx == octave_images.len() - 1 {
continue; }
let (width, height) = image.dimensions();
for y in 1..height-1 {
for x in 1..width-1 {
if self.is_extremum(octave_images, scale_idx, x, y) {
let feature = Feature {
point: Point2::new(
(x as f64) * (2.0_f64).powf(octave_idx as f64),
(y as f64) * (2.0_f64).powf(octave_idx as f64)
),
descriptor: Vec::new(),
scale: self.sigma * (2.0_f32).powf(
octave_idx as f32 + scale_idx as f32 / self.num_scales as f32
),
angle: 0.0, response: image.get_pixel(x, y)[0] as f32,
octave: octave_idx as i32,
point3d_id: None,
};
keypoints.push(feature);
}
}
}
}
}
keypoints
}
fn is_extremum(&self, octave_images: &[GrayImage], scale_idx: usize, x: u32, y: u32) -> bool {
let current_val = octave_images[scale_idx].get_pixel(x, y)[0];
for dz in -1..=1 {
let z = (scale_idx as i32 + dz) as usize;
if z >= octave_images.len() { continue; }
for dy in -1..=1 {
for dx in -1..=1 {
if dx == 0 && dy == 0 && dz == 0 { continue; }
let nx = (x as i32 + dx) as u32;
let ny = (y as i32 + dy) as u32;
let neighbor_val = octave_images[z].get_pixel(nx, ny)[0];
if current_val <= neighbor_val {
return false; }
}
}
}
true
}
}
impl FeatureDetector for SiftDetector {
    /// Identifier reported for this detector.
    fn name(&self) -> &str {
        "SIFT"
    }

    /// Current configuration as a name -> value map, with every value widened to `f64`.
    fn params(&self) -> std::collections::HashMap<String, f64> {
        [
            ("num_octaves", self.num_octaves as f64),
            ("num_scales", self.num_scales as f64),
            ("contrast_threshold", self.contrast_threshold as f64),
            ("edge_threshold", self.edge_threshold as f64),
            ("sigma", self.sigma as f64),
        ]
        .iter()
        .map(|&(key, value)| (key.to_string(), value))
        .collect()
    }

    /// Overwrites every recognised parameter present in `params`. Unknown keys are ignored.
    ///
    /// Values are narrowed with `as`, so for the integer fields negative or NaN inputs
    /// saturate to `0`. No validation is performed and this always returns `Ok(())`.
    fn set_params(&mut self, params: std::collections::HashMap<String, f64>) -> Result<()> {
        for (key, &value) in &params {
            match key.as_str() {
                "num_octaves" => self.num_octaves = value as usize,
                "num_scales" => self.num_scales = value as usize,
                "contrast_threshold" => self.contrast_threshold = value as f32,
                "edge_threshold" => self.edge_threshold = value as f32,
                "sigma" => self.sigma = value as f32,
                _ => {}
            }
        }
        Ok(())
    }

    /// Runs the pipeline: Gaussian pyramid, then DoG pyramid, then extrema detection.
    fn detect(&self, image: &GrayImage) -> Result<Vec<Feature>> {
        let dog_pyramid = self.build_dog_pyramid(&self.build_gaussian_pyramid(image));
        Ok(self.detect_extrema(&dog_pyramid))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Luma};

    #[test]
    fn test_sift_detector_creation() {
        let detector = SiftDetector::new();
        assert_eq!(detector.num_octaves, 4);
        assert_eq!(detector.num_scales, 3);
    }

    #[test]
    fn test_sift_detection() {
        let detector = SiftDetector::new();
        let image = ImageBuffer::from_fn(100, 100, |x, y| Luma([(x + y) as u8]));
        let result = detector.detect(&image);
        assert!(result.is_ok());
    }

    /// `with_params` stores every argument in the matching field.
    #[test]
    fn test_with_params_sets_fields() {
        let detector = SiftDetector::with_params(2, 4, 0.03, 8.0, 1.2);
        assert_eq!(detector.num_octaves, 2);
        assert_eq!(detector.num_scales, 4);
        assert_eq!(detector.contrast_threshold, 0.03);
        assert_eq!(detector.edge_threshold, 8.0);
        assert_eq!(detector.sigma, 1.2);
    }

    /// Values written through `set_params` are visible in the fields and in `params()`.
    #[test]
    fn test_params_round_trip() {
        let mut detector = SiftDetector::new();
        assert_eq!(detector.name(), "SIFT");
        let mut params = detector.params();
        assert_eq!(params.len(), 5);
        params.insert("num_octaves".to_string(), 2.0);
        params.insert("sigma".to_string(), 2.0);
        assert!(detector.set_params(params).is_ok());
        assert_eq!(detector.num_octaves, 2);
        assert_eq!(detector.sigma, 2.0);
        assert_eq!(detector.params()["num_octaves"], 2.0);
    }

    /// Features found on a synthetic bright blob carry consistent metadata.
    #[test]
    fn test_detected_features_are_well_formed() {
        let detector = SiftDetector::new();
        let image: GrayImage = ImageBuffer::from_fn(64, 64, |x, y| {
            let dx = x as f32 - 32.0;
            let dy = y as f32 - 32.0;
            Luma([(255.0 * (-(dx * dx + dy * dy) / 50.0).exp()) as u8])
        });
        let features = detector.detect(&image).expect("detection should succeed");
        for feature in &features {
            assert!(feature.octave >= 0);
            assert!((feature.octave as usize) < detector.num_octaves);
            assert!(feature.scale > 0.0);
            assert!(feature.response >= 0.0);
            assert!(feature.descriptor.is_empty());
            assert!(feature.point3d_id.is_none());
        }
    }
}