use crate::error::Result;
use crate::feature::image_to_array;
use image::DynamicImage;
use scirs2_core::ndarray::Array2;
use std::f32::consts::PI;
/// A detected interest point, located in (column `x`, row `y`) pixel coordinates.
#[derive(Clone, Debug, Default)]
pub struct KeyPoint {
    /// Horizontal position (column index).
    pub x: f32,
    /// Vertical position (row index).
    pub y: f32,
    /// Detection scale; `detect_and_compute` always sets this to 1.0.
    pub scale: f32,
    /// Reference orientation in radians; `detect_and_compute` stores the
    /// pixel's gradient angle `atan2(dy, dx)`.
    pub orientation: f32,
    /// Detector response; `detect_and_compute` stores the gradient magnitude.
    pub response: f32,
}
/// A keypoint together with the feature vector computed around it.
#[derive(Clone, Debug)]
pub struct Descriptor {
    /// The keypoint this descriptor describes.
    pub keypoint: KeyPoint,
    /// Feature vector. When produced by `detect_and_compute` it has 128
    /// entries and is L2-normalised (unless the sampled window had
    /// near-zero gradient energy).
    pub vector: Vec<f32>,
}
/// Detects keypoints in `img` and computes a SIFT-style descriptor for each.
///
/// The image is converted with `image_to_array` (presumably to a
/// single-channel intensity array — TODO confirm), and central-difference
/// gradients are computed. A pixel becomes a keypoint candidate when its
/// gradient magnitude is at least `threshold` and strictly greater than all
/// eight neighbours. Candidates are ranked by magnitude (strongest first), at
/// most `max_features` are kept, and a 128-element descriptor is computed for
/// each one.
///
/// Only pixels at least 8 px away from every border are considered, because
/// the descriptor samples a 16x16 window around the keypoint. This filter is
/// applied *before* the `max_features` cut, so the feature budget is never
/// spent on candidates that could not receive a descriptor. Images smaller
/// than 17 px in either dimension therefore yield no descriptors (instead of
/// panicking on `usize` underflow).
///
/// # Errors
/// Propagates any error returned by `image_to_array` or `compute_descriptor`.
#[allow(dead_code)]
pub fn detect_and_compute(
    img: &DynamicImage,
    max_features: usize,
    threshold: f32,
) -> Result<Vec<Descriptor>> {
    // Half-width of the descriptor window: keypoints closer than this to a
    // border are skipped by the descriptor stage.
    const BORDER: usize = 8;

    let array = image_to_array(img)?;
    let (height, width) = array.dim();

    // Central-difference gradients. The outermost ring of pixels lacks a
    // neighbour on one side and keeps magnitude/orientation 0.
    // `saturating_sub` avoids underflow for 0/1-pixel-wide images.
    let mut magnitude = Array2::<f32>::zeros((height, width));
    let mut orientation = Array2::<f32>::zeros((height, width));
    for y in 1..height.saturating_sub(1) {
        for x in 1..width.saturating_sub(1) {
            let dx = array[[y, x + 1]] - array[[y, x - 1]];
            let dy = array[[y + 1, x]] - array[[y - 1, x]];
            magnitude[[y, x]] = (dx * dx + dy * dy).sqrt();
            orientation[[y, x]] = dy.atan2(dx);
        }
    }

    // 3x3 non-maximum suppression, restricted to pixels that can be described.
    let mut keypoints = Vec::new();
    for y in BORDER..height.saturating_sub(BORDER) {
        for x in BORDER..width.saturating_sub(BORDER) {
            let current_mag = magnitude[[y, x]];
            // NaN magnitudes (from NaN pixels) are rejected so ranking stays well-defined.
            if current_mag.is_nan() || current_mag < threshold {
                continue;
            }
            // A neighbour that ties or exceeds the centre disqualifies it (strict maximum).
            let dominated = (y - 1..=y + 1)
                .flat_map(|ny| (x - 1..=x + 1).map(move |nx| (ny, nx)))
                .any(|(ny, nx)| (ny, nx) != (y, x) && magnitude[[ny, nx]] >= current_mag);
            if !dominated {
                keypoints.push(KeyPoint {
                    x: x as f32,
                    y: y as f32,
                    scale: 1.0,
                    orientation: orientation[[y, x]],
                    response: current_mag,
                });
            }
        }
    }

    // Strongest first; `total_cmp` is a total order, so `sort_by` cannot misbehave.
    keypoints.sort_by(|a, b| b.response.total_cmp(&a.response));
    keypoints.truncate(max_features);

    let mut descriptors = Vec::with_capacity(keypoints.len());
    for kp in keypoints {
        let vector = compute_descriptor(&array, &magnitude, &orientation, &kp)?;
        descriptors.push(Descriptor {
            keypoint: kp,
            vector,
        });
    }
    Ok(descriptors)
}
/// Builds a SIFT-style descriptor for `keypoint` from precomputed gradients.
///
/// A 16x16 grid of sample offsets is rotated by the keypoint orientation and
/// centred on the keypoint. The grid is split into 4x4 cells, and each cell
/// accumulates an 8-bin histogram of gradient orientations measured relative
/// to the keypoint orientation. Each sample is weighted by its gradient
/// magnitude and a Gaussian (sigma = 4) of its distance from the centre. The
/// resulting 4*4*8 = 128 values are L2-normalised, clamped at 0.2, and
/// L2-normalised again.
///
/// * `image` - only its dimensions are used, to skip samples that fall
///   outside the array.
/// * `magnitude`, `orientation` - per-pixel gradient magnitude and angle in
///   radians, indexed `[[y, x]]`.
/// * `keypoint` - window centre (truncated to whole pixels) and reference
///   orientation.
///
/// # Errors
/// Never returns `Err` at present; the `Result` keeps the signature uniform
/// with the rest of the pipeline.
#[allow(dead_code)]
fn compute_descriptor(
    image: &Array2<f32>,
    magnitude: &Array2<f32>,
    orientation: &Array2<f32>,
    keypoint: &KeyPoint,
) -> Result<Vec<f32>> {
    const HALF_WINDOW: isize = 8;
    const NUM_SPATIAL_BINS: usize = 4;
    const NUM_ORIENTATION_BINS: usize = 8;
    const SIGMA: f32 = 4.0;
    // Caps single dominant bins to reduce sensitivity to large gradients.
    const CLAMP: f32 = 0.2;

    let (height, width) = image.dim();
    let mut descriptor =
        vec![0.0_f32; NUM_SPATIAL_BINS * NUM_SPATIAL_BINS * NUM_ORIENTATION_BINS];
    let (sin_angle, cos_angle) = keypoint.orientation.sin_cos();
    let orientation_bin_width = 2.0 * PI / NUM_ORIENTATION_BINS as f32;
    // Number of sample rows/columns covered by one spatial cell (16 / 4 = 4).
    let cell_size = (2 * HALF_WINDOW) as usize / NUM_SPATIAL_BINS;
    let center_y = keypoint.y as isize;
    let center_x = keypoint.x as isize;

    for i in -HALF_WINDOW..HALF_WINDOW {
        for j in -HALF_WINDOW..HALF_WINDOW {
            let (fi, fj) = (i as f32, j as f32);
            // Rotate the canonical offset (x = j, y = i) by +orientation so the
            // window's x-axis follows the keypoint's gradient direction. This
            // matches subtracting the same orientation from sample angles
            // below; rotating by -orientation would break rotation invariance.
            let dy = (sin_angle * fj + cos_angle * fi).round() as isize;
            let dx = (cos_angle * fj - sin_angle * fi).round() as isize;
            let img_y = center_y + dy;
            let img_x = center_x + dx;
            if img_y < 0 || img_y >= height as isize || img_x < 0 || img_x >= width as isize {
                continue;
            }
            let (py, px) = (img_y as usize, img_x as usize);
            let mag = magnitude[[py, px]];
            let ori = orientation[[py, px]];

            // `i + HALF_WINDOW` is in 0..16, so these are always valid cell indices.
            let bin_i = ((i + HALF_WINDOW) as usize / cell_size).min(NUM_SPATIAL_BINS - 1);
            let bin_j = ((j + HALF_WINDOW) as usize / cell_size).min(NUM_SPATIAL_BINS - 1);
            let rel_ori = (ori - keypoint.orientation).rem_euclid(2.0 * PI);
            // `rem_euclid` may round up to exactly 2*PI; the modulo folds that into bin 0.
            let ori_bin = (rel_ori / orientation_bin_width) as usize % NUM_ORIENTATION_BINS;
            let weight = (-(fi * fi + fj * fj) / (2.0 * SIGMA * SIGMA)).exp();
            let idx = (bin_i * NUM_SPATIAL_BINS + bin_j) * NUM_ORIENTATION_BINS + ori_bin;
            descriptor[idx] += mag * weight;
        }
    }

    l2_normalize_in_place(&mut descriptor);
    for value in &mut descriptor {
        *value = value.min(CLAMP);
    }
    l2_normalize_in_place(&mut descriptor);
    Ok(descriptor)
}

/// Scales `values` to unit L2 norm in place; near-zero vectors (norm <= 1e-6)
/// are left unchanged to avoid dividing by ~0.
fn l2_normalize_in_place(values: &mut [f32]) {
    let norm = values.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 1e-6 {
        for v in values.iter_mut() {
            *v /= norm;
        }
    }
}
/// Matches each descriptor in `descriptors1` to its nearest neighbour in
/// `descriptors2` by Euclidean distance.
///
/// A match `(i, j, distance)` is emitted only when the nearest distance is
/// below `threshold` and passes Lowe's ratio test against the second-nearest
/// distance (`nearest < 0.7 * second_nearest`). With a single candidate the
/// second-nearest distance stays at `f32::MAX`, so only `threshold` applies.
/// Several queries may match the same `j`. Results are sorted by ascending
/// distance.
#[allow(dead_code)]
pub fn match_descriptors(
    descriptors1: &[Descriptor],
    descriptors2: &[Descriptor],
    threshold: f32,
) -> Vec<(usize, usize, f32)> {
    let mut matches: Vec<(usize, usize, f32)> = descriptors1
        .iter()
        .enumerate()
        .filter_map(|(query_idx, query)| {
            // Track (index of nearest, nearest distance, runner-up distance);
            // strict `<` keeps the earliest index on ties, like a manual scan.
            let (nearest_idx, nearest, runner_up) = descriptors2.iter().enumerate().fold(
                (0usize, f32::MAX, f32::MAX),
                |(idx, nearest, runner_up), (candidate_idx, candidate)| {
                    let dist = euclidean_distance(&query.vector, &candidate.vector);
                    if dist < nearest {
                        (candidate_idx, dist, nearest)
                    } else if dist < runner_up {
                        (idx, nearest, dist)
                    } else {
                        (idx, nearest, runner_up)
                    }
                },
            );
            let accepted = nearest < threshold && nearest < 0.7 * runner_up;
            if accepted {
                Some((query_idx, nearest_idx, nearest))
            } else {
                None
            }
        })
        .collect();
    matches.sort_by(|lhs, rhs| {
        lhs.2
            .partial_cmp(&rhs.2)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    matches
}
/// Returns the Euclidean (L2) distance between two vectors.
///
/// Only the first `min(vec1.len(), vec2.len())` components are compared;
/// trailing components of the longer slice are ignored. Two empty slices
/// are at distance 0.
#[allow(dead_code)]
fn euclidean_distance(vec1: &[f32], vec2: &[f32]) -> f32 {
    vec1.iter()
        .zip(vec2)
        .fold(0.0_f32, |acc, (a, b)| {
            let delta = a - b;
            acc + delta * delta
        })
        .sqrt()
}