use crate::{Result, VisionError};
use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView2};
use std::collections::HashMap;
use torsh_nn::Module;
use torsh_tensor::Tensor;
/// A detected keypoint together with its descriptor vector.
#[derive(Debug, Clone)]
pub struct Feature {
    /// Image-space x coordinate of the keypoint.
    pub x: f32,
    /// Image-space y coordinate of the keypoint.
    pub y: f32,
    /// Detector response (strength) at this keypoint.
    pub response: f32,
    /// Descriptor vector; all features in one set must share the same length.
    pub descriptor: Array1<f32>,
    /// Scale at which the keypoint was detected.
    pub scale: f32,
    /// Keypoint orientation (presumably radians — TODO confirm against detectors).
    pub orientation: f32,
}
/// Configuration for [`SuperPointDetector`].
#[derive(Debug, Clone)]
pub struct SuperPointConfig {
    /// Minimum detector response for a keypoint to be kept.
    pub detection_threshold: f32,
    /// Radius (in pixels) used for non-maximum suppression.
    pub nms_radius: usize,
    /// Upper bound on the number of returned keypoints.
    pub max_features: usize,
    /// Whether to run the network on the GPU (currently unused by the stub).
    pub use_gpu: bool,
}
impl Default for SuperPointConfig {
fn default() -> Self {
Self {
detection_threshold: 0.015,
nms_radius: 4,
max_features: 1000,
use_gpu: false,
}
}
}
/// SuperPoint-style learned keypoint detector (network integration pending).
pub struct SuperPointDetector {
    // Stored for when the neural-network backbone is wired up.
    config: SuperPointConfig,
}
impl SuperPointDetector {
    /// Builds a detector holding the supplied configuration.
    pub fn new(config: SuperPointConfig) -> Result<Self> {
        Ok(Self { config })
    }

    /// Detects keypoints in `image`.
    ///
    /// # Errors
    /// Currently always fails: the neural-network integration is not
    /// implemented yet.
    pub fn detect(&self, _image: &Tensor) -> Result<Vec<Feature>> {
        let msg =
            "SuperPoint detection not yet implemented - requires neural network integration";
        Err(VisionError::InvalidParameter(msg.to_string()))
    }

    /// Detects keypoints and computes descriptors; delegates to [`Self::detect`].
    pub fn detect_and_compute(&self, image: &Tensor) -> Result<Vec<Feature>> {
        self.detect(image)
    }
}
/// Configuration for [`LearnedSiftDetector`].
#[derive(Debug, Clone)]
pub struct LearnedSiftConfig {
    /// Number of octaves in the scale-space pyramid.
    pub num_octaves: usize,
    /// Number of scales per octave.
    pub num_scales: usize,
    /// Minimum detector response for a keypoint to be kept.
    pub detection_threshold: f32,
    /// Edge-rejection threshold (ratio of principal curvatures in classic SIFT
    /// — TODO confirm once the implementation lands).
    pub edge_threshold: f32,
    /// Minimum local contrast for a keypoint to be kept.
    pub contrast_threshold: f32,
}
impl Default for LearnedSiftConfig {
fn default() -> Self {
Self {
num_octaves: 4,
num_scales: 3,
detection_threshold: 0.04,
edge_threshold: 10.0,
contrast_threshold: 0.03,
}
}
}
/// Learned SIFT-style keypoint detector (network integration pending).
pub struct LearnedSiftDetector {
    // Stored for when the neural-network backbone is wired up.
    config: LearnedSiftConfig,
}
impl LearnedSiftDetector {
pub fn new(config: LearnedSiftConfig) -> Result<Self> {
Ok(Self { config })
}
pub fn detect(&self, _image: &Tensor) -> Result<Vec<Feature>> {
Err(VisionError::InvalidParameter(
"Learned SIFT not yet implemented - requires neural network integration".to_string(),
))
}
}
/// Configuration for [`AttentionMatcher`].
#[derive(Debug, Clone)]
pub struct AttentionMatcherConfig {
    /// Number of attention heads (not used by the current cosine-similarity path).
    pub num_heads: usize,
    /// Hidden dimension of the attention layers (not used by the current path).
    pub hidden_dim: usize,
    /// Minimum similarity for a pair to be reported as a match.
    pub match_threshold: f32,
    /// If true, keep only mutual (two-way) nearest-neighbor matches.
    pub mutual_match: bool,
}
impl Default for AttentionMatcherConfig {
fn default() -> Self {
Self {
num_heads: 8,
hidden_dim: 256,
match_threshold: 0.5,
mutual_match: true,
}
}
}
/// A correspondence between one feature in each of two feature sets.
#[derive(Debug, Clone)]
pub struct FeatureMatch {
    /// Index into the first feature set.
    pub idx1: usize,
    /// Index into the second feature set.
    pub idx2: usize,
    /// Match confidence; higher is better.
    pub confidence: f32,
    /// Descriptor distance; lower is better.
    pub distance: f32,
}
/// Feature matcher based on descriptor similarity (currently plain cosine
/// similarity; attention parameters in the config are not yet used).
pub struct AttentionMatcher {
    config: AttentionMatcherConfig,
}
impl AttentionMatcher {
    /// Creates a matcher from `config`.
    ///
    /// NOTE(review): `num_heads` and `hidden_dim` are stored but not consumed
    /// by the current similarity computation.
    pub fn new(config: AttentionMatcherConfig) -> Result<Self> {
        Ok(Self { config })
    }

    /// Matches two feature sets and returns index pairs with confidences.
    ///
    /// Returns an empty vector when either input is empty. Pairs below
    /// `config.match_threshold` are discarded; with `config.mutual_match` set,
    /// only mutual nearest neighbors survive.
    ///
    /// # Errors
    /// Fails when descriptor dimensions are inconsistent within a set.
    pub fn match_features(
        &self,
        features1: &[Feature],
        features2: &[Feature],
    ) -> Result<Vec<FeatureMatch>> {
        if features1.is_empty() || features2.is_empty() {
            return Ok(Vec::new());
        }
        let desc1 = self.stack_descriptors(features1)?;
        let desc2 = self.stack_descriptors(features2)?;
        let similarity = self.compute_attention_similarity(&desc1, &desc2)?;
        self.extract_matches(&similarity)
    }

    /// Stacks per-feature descriptors into an `(n, d)` matrix.
    ///
    /// # Errors
    /// Fails on an empty set or when descriptor lengths differ.
    fn stack_descriptors(&self, features: &[Feature]) -> Result<Array2<f32>> {
        if features.is_empty() {
            return Err(VisionError::InvalidParameter(
                "Cannot stack empty feature set".to_string(),
            ));
        }
        let desc_dim = features[0].descriptor.len();
        let mut descriptors = Array2::zeros((features.len(), desc_dim));
        for (i, feature) in features.iter().enumerate() {
            if feature.descriptor.len() != desc_dim {
                return Err(VisionError::InvalidParameter(
                    "Inconsistent descriptor dimensions".to_string(),
                ));
            }
            for (j, &val) in feature.descriptor.iter().enumerate() {
                descriptors[[i, j]] = val;
            }
        }
        Ok(descriptors)
    }

    /// Computes the cosine-similarity matrix between two descriptor sets.
    ///
    /// Row norms are precomputed once per descriptor (O(n1 + n2) norm work
    /// instead of recomputing both norms for every (i, j) pair). Pairs
    /// involving a zero-norm descriptor get similarity 0.0, as before.
    fn compute_attention_similarity(
        &self,
        desc1: &Array2<f32>,
        desc2: &Array2<f32>,
    ) -> Result<Array2<f32>> {
        let n1 = desc1.nrows();
        let n2 = desc2.nrows();
        // Hoisted out of the pair loop: the original recomputed these per pair.
        let norms1: Vec<f32> = (0..n1)
            .map(|i| desc1.row(i).iter().map(|v| v * v).sum::<f32>().sqrt())
            .collect();
        let norms2: Vec<f32> = (0..n2)
            .map(|j| desc2.row(j).iter().map(|v| v * v).sum::<f32>().sqrt())
            .collect();
        let mut similarity = Array2::zeros((n1, n2));
        for i in 0..n1 {
            for j in 0..n2 {
                let mut dot = 0.0;
                for k in 0..desc1.ncols() {
                    dot += desc1[[i, k]] * desc2[[j, k]];
                }
                similarity[[i, j]] = if norms1[i] > 0.0 && norms2[j] > 0.0 {
                    dot / (norms1[i] * norms2[j])
                } else {
                    0.0
                };
            }
        }
        Ok(similarity)
    }

    /// Extracts row-wise best matches from a similarity matrix, applying the
    /// score threshold and (optionally) the mutual-nearest-neighbor check.
    fn extract_matches(&self, similarity: &Array2<f32>) -> Result<Vec<FeatureMatch>> {
        let n1 = similarity.nrows();
        let n2 = similarity.ncols();
        // Guard: the seeds below index column 0 / row 0 and would panic on an
        // empty matrix if this were called outside `match_features`.
        if n1 == 0 || n2 == 0 {
            return Ok(Vec::new());
        }
        let mut matches = Vec::new();
        for i in 0..n1 {
            // Best column for row i; first strict maximum wins (unchanged
            // tie-breaking).
            let mut best_j = 0;
            let mut best_score = similarity[[i, 0]];
            for j in 1..n2 {
                if similarity[[i, j]] > best_score {
                    best_score = similarity[[i, j]];
                    best_j = j;
                }
            }
            if best_score < self.config.match_threshold {
                continue;
            }
            if self.config.mutual_match {
                // Best row for the chosen column; keep only mutual pairs.
                let mut reverse_best_i = 0;
                let mut reverse_best_score = similarity[[0, best_j]];
                for ii in 1..n1 {
                    if similarity[[ii, best_j]] > reverse_best_score {
                        reverse_best_score = similarity[[ii, best_j]];
                        reverse_best_i = ii;
                    }
                }
                if reverse_best_i != i {
                    continue;
                }
            }
            matches.push(FeatureMatch {
                idx1: i,
                idx2: best_j,
                confidence: best_score,
                distance: 1.0 - best_score,
            });
        }
        Ok(matches)
    }
}
/// Configuration for [`MultiScaleDetector`].
#[derive(Debug, Clone)]
pub struct MultiScaleConfig {
    /// Number of pyramid levels to run detection on.
    pub num_scales: usize,
    /// Multiplicative scale step between consecutive levels.
    pub scale_factor: f32,
    /// If true, treat each scale independently (no cross-scale suppression
    /// — presumably; behavior is defined by the wrapped detector's use).
    pub independent_scales: bool,
}
impl Default for MultiScaleConfig {
fn default() -> Self {
Self {
num_scales: 5,
scale_factor: 1.2,
independent_scales: false,
}
}
}
/// Wraps any detector `D` with multi-scale (pyramid) detection settings.
pub struct MultiScaleDetector<D> {
    // The wrapped single-scale detector.
    detector: D,
    // Pyramid configuration applied around `detector`.
    config: MultiScaleConfig,
}
impl<D> MultiScaleDetector<D> {
pub fn new(detector: D, config: MultiScaleConfig) -> Self {
Self { detector, config }
}
}
/// Exhaustive descriptor matcher: compares every feature in the first set
/// against every feature in the second.
pub struct BruteForceMatcher {
    // Distance function used to compare descriptors.
    metric: DistanceMetric,
    // If true, keep only matches that are mutual nearest neighbors.
    cross_check: bool,
}
/// Distance function for descriptor comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Euclidean distance.
    L2,
    /// Manhattan (sum of absolute differences) distance.
    L1,
    /// Bit-difference count; each component is binarized at 0.5.
    Hamming,
    /// Cosine distance (1 - cosine similarity).
    Cosine,
}
impl BruteForceMatcher {
    /// Creates a matcher with the given metric; `cross_check` keeps only
    /// mutual nearest-neighbor matches.
    pub fn new(metric: DistanceMetric, cross_check: bool) -> Self {
        Self {
            metric,
            cross_check,
        }
    }

    /// Matches two feature sets by exhaustive descriptor comparison.
    ///
    /// Returns an empty vector when either input is empty. All four metrics
    /// are symmetric, so the full distance matrix is computed once up front;
    /// the previous implementation recomputed a whole column of distances for
    /// every cross-check, which was O(n1^2 * n2) distance evaluations in the
    /// worst case instead of O(n1 * n2). Tie-breaking is unchanged: the first
    /// strict minimum wins in both directions.
    ///
    /// # Errors
    /// Fails when any pair of descriptors has mismatched dimensions.
    pub fn match_features(
        &self,
        features1: &[Feature],
        features2: &[Feature],
    ) -> Result<Vec<FeatureMatch>> {
        if features1.is_empty() || features2.is_empty() {
            return Ok(Vec::new());
        }
        let n1 = features1.len();
        let n2 = features2.len();
        // Precompute all pairwise distances once.
        let mut dist = vec![vec![0.0f32; n2]; n1];
        for (i, f1) in features1.iter().enumerate() {
            for (j, f2) in features2.iter().enumerate() {
                dist[i][j] = self.compute_distance(&f1.descriptor, &f2.descriptor)?;
            }
        }
        let mut matches = Vec::new();
        for i in 0..n1 {
            // Nearest neighbor of feature i in the second set.
            let mut best_dist = f32::MAX;
            let mut best_j = 0;
            for j in 0..n2 {
                if dist[i][j] < best_dist {
                    best_dist = dist[i][j];
                    best_j = j;
                }
            }
            if self.cross_check {
                // Nearest neighbor of features2[best_j] in the first set;
                // symmetric metrics let us reuse the precomputed matrix.
                let mut reverse_best_dist = f32::MAX;
                let mut reverse_best_i = 0;
                for i2 in 0..n1 {
                    if dist[i2][best_j] < reverse_best_dist {
                        reverse_best_dist = dist[i2][best_j];
                        reverse_best_i = i2;
                    }
                }
                if reverse_best_i != i {
                    continue;
                }
            }
            matches.push(FeatureMatch {
                idx1: i,
                idx2: best_j,
                // Map distance in [0, inf) to confidence in (0, 1].
                confidence: 1.0 / (1.0 + best_dist),
                distance: best_dist,
            });
        }
        Ok(matches)
    }

    /// Computes the configured distance between two descriptors.
    ///
    /// # Errors
    /// Fails when the descriptor lengths differ.
    fn compute_distance(&self, desc1: &Array1<f32>, desc2: &Array1<f32>) -> Result<f32> {
        if desc1.len() != desc2.len() {
            return Err(VisionError::InvalidParameter(
                "Descriptor dimensions mismatch".to_string(),
            ));
        }
        let dist = match self.metric {
            DistanceMetric::L2 => {
                let mut sum = 0.0;
                for i in 0..desc1.len() {
                    let diff = desc1[i] - desc2[i];
                    sum += diff * diff;
                }
                sum.sqrt()
            }
            DistanceMetric::L1 => {
                let mut sum = 0.0;
                for i in 0..desc1.len() {
                    sum += (desc1[i] - desc2[i]).abs();
                }
                sum
            }
            DistanceMetric::Hamming => {
                // Binarize each component at 0.5, then count disagreements.
                let mut count = 0;
                for i in 0..desc1.len() {
                    if (desc1[i] > 0.5) != (desc2[i] > 0.5) {
                        count += 1;
                    }
                }
                count as f32
            }
            DistanceMetric::Cosine => {
                let mut dot = 0.0;
                let mut norm1 = 0.0;
                let mut norm2 = 0.0;
                for i in 0..desc1.len() {
                    dot += desc1[i] * desc2[i];
                    norm1 += desc1[i] * desc1[i];
                    norm2 += desc2[i] * desc2[i];
                }
                if norm1 > 0.0 && norm2 > 0.0 {
                    1.0 - dot / (norm1.sqrt() * norm2.sqrt())
                } else {
                    // Zero-norm descriptors are treated as maximally distant.
                    1.0
                }
            }
        };
        Ok(dist)
    }
}
/// Applies Lowe's ratio test to a set of matches.
///
/// Candidates are grouped by their source index (`idx1`). A group with a
/// single candidate is kept as-is; otherwise the lowest-distance candidate is
/// kept only when it is sufficiently better than the runner-up. The test is
/// written as `best < ratio_threshold * second_best` rather than a division,
/// which avoids a divide-by-zero when the runner-up distance is 0 — such a
/// best match is rejected, matching the original NaN/inf comparison outcome
/// for the non-negative distances produced by the matchers in this module.
/// The result is sorted by `idx1` so the output order is deterministic instead
/// of following `HashMap` iteration order (previously unspecified).
pub fn apply_ratio_test(matches: &[FeatureMatch], ratio_threshold: f32) -> Vec<FeatureMatch> {
    let mut matches_by_src: HashMap<usize, Vec<&FeatureMatch>> = HashMap::new();
    for m in matches {
        matches_by_src.entry(m.idx1).or_default().push(m);
    }
    let mut filtered = Vec::new();
    for (_, mut group) in matches_by_src {
        if group.len() < 2 {
            // No runner-up: keep the sole candidate unconditionally.
            if let Some(&m) = group.first() {
                filtered.push(m.clone());
            }
            continue;
        }
        // Ascending by distance; incomparable (NaN) pairs sort as equal.
        group.sort_by(|a, b| {
            a.distance
                .partial_cmp(&b.distance)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let best = group[0];
        let second_best = group[1];
        if best.distance < ratio_threshold * second_best.distance {
            filtered.push(best.clone());
        }
    }
    // Deterministic output order.
    filtered.sort_by_key(|m| m.idx1);
    filtered
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal feature at (x, y) with the given descriptor values.
    fn make_feature(x: f32, y: f32, descriptor: Vec<f32>) -> Feature {
        Feature {
            x,
            y,
            response: 1.0,
            descriptor: Array1::from_vec(descriptor),
            scale: 1.0,
            orientation: 0.0,
        }
    }

    #[test]
    fn test_superpoint_config_default() {
        let cfg = SuperPointConfig::default();
        assert_eq!(cfg.detection_threshold, 0.015);
        assert_eq!(cfg.nms_radius, 4);
        assert_eq!(cfg.max_features, 1000);
    }

    #[test]
    fn test_attention_matcher_config_default() {
        let cfg = AttentionMatcherConfig::default();
        assert!(cfg.mutual_match);
        assert_eq!(cfg.num_heads, 8);
        assert_eq!(cfg.hidden_dim, 256);
        assert_eq!(cfg.match_threshold, 0.5);
    }

    #[test]
    fn test_brute_force_matcher_l2() {
        let matcher = BruteForceMatcher::new(DistanceMetric::L2, false);
        let query = make_feature(0.0, 0.0, vec![1.0, 0.0, 0.0]);
        let target = make_feature(1.0, 1.0, vec![0.9, 0.1, 0.0]);
        let result = matcher
            .match_features(&[query], &[target])
            .expect("Matching failed");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_attention_matcher_empty_features() {
        let matcher = AttentionMatcher::new(AttentionMatcherConfig::default())
            .expect("Failed to create matcher");
        let result = matcher.match_features(&[], &[]).expect("Matching failed");
        assert!(result.is_empty());
    }

    #[test]
    fn test_ratio_test() {
        // Two candidates for the same source feature; only the clearly better
        // one should survive a 0.8 ratio threshold.
        let strong = FeatureMatch {
            idx1: 0,
            idx2: 0,
            confidence: 0.9,
            distance: 0.1,
        };
        let weak = FeatureMatch {
            idx1: 0,
            idx2: 1,
            confidence: 0.5,
            distance: 0.5,
        };
        let filtered = apply_ratio_test(&[strong, weak], 0.8);
        assert_eq!(filtered.len(), 1);
    }

    #[test]
    fn test_multi_scale_config_default() {
        let cfg = MultiScaleConfig::default();
        assert_eq!(cfg.num_scales, 5);
        assert_eq!(cfg.scale_factor, 1.2);
    }
}