use crate::error::{Result, VisionError};
use crate::feature::KeyPoint;
use crate::gpu_ops::GpuVisionContext;
use scirs2_core::ndarray::ArrayStatCompat;
use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
use statrs::statistics::Statistics;
pub struct NeuralFeatureNetwork {
#[allow(dead_code)]
detection_weights: ModelWeights,
#[allow(dead_code)]
descriptor_weights: ModelWeights,
gpu_context: Option<GpuVisionContext>,
config: NeuralFeatureConfig,
}
#[derive(Clone)]
pub struct ModelWeights {
#[allow(dead_code)]
conv_weights: Vec<Array3<f32>>,
#[allow(dead_code)]
conv_biases: Vec<Array1<f32>>,
#[allow(dead_code)]
bn_weights: Vec<Array1<f32>>,
#[allow(dead_code)]
bn_biases: Vec<Array1<f32>>,
#[allow(dead_code)]
fc_weights: Vec<Array2<f32>>,
#[allow(dead_code)]
fc_biases: Vec<Array1<f32>>,
}
#[derive(Clone)]
pub struct NeuralFeatureConfig {
pub input_size: (usize, usize),
pub max_keypoints: usize,
pub detection_threshold: f32,
pub nms_radius: usize,
pub descriptor_dim: usize,
pub border_remove: usize,
pub use_gpu: bool,
}
impl Default for NeuralFeatureConfig {
fn default() -> Self {
Self {
input_size: (480, 640),
max_keypoints: 1024,
detection_threshold: 0.015,
nms_radius: 4,
descriptor_dim: 256,
border_remove: 4,
use_gpu: true,
}
}
}
pub struct SuperPointNet {
network: NeuralFeatureNetwork,
}
impl SuperPointNet {
pub fn new(config: Option<NeuralFeatureConfig>) -> Result<Self> {
let config = config.unwrap_or_default();
let detection_weights = Self::create_detection_weights(&config)?;
let descriptor_weights = Self::create_descriptor_weights(&config)?;
let gpu_context = if config.use_gpu {
GpuVisionContext::new().ok()
} else {
None
};
let network = NeuralFeatureNetwork {
detection_weights,
descriptor_weights,
gpu_context,
config,
};
Ok(Self { network })
}
#[allow(dead_code)]
pub fn from_file(_modelpath: &str, config: Option<NeuralFeatureConfig>) -> Result<Self> {
let config = config.unwrap_or_default();
Self::new(Some(config))
}
pub fn detect_and_describe(
&self,
image: &ArrayView2<f32>,
) -> Result<(Vec<KeyPoint>, Array2<f32>)> {
let (height, width) = image.dim();
if height % 8 != 0 || width % 8 != 0 {
return Err(VisionError::InvalidInput(
"Input image dimensions must be multiples of 8 for neural feature detection"
.to_string(),
));
}
let processed_image = if (height, width) != self.network.config.input_size {
self.resize_image(image, self.network.config.input_size)?
} else {
image.to_owned()
};
if let Some(ref gpu_ctx) = self.network.gpu_context {
self.gpu_inference(gpu_ctx, &processed_image.view())
} else {
self.cpu_inference(&processed_image.view())
}
}
fn gpu_inference(
&self,
gpu_ctx: &GpuVisionContext,
image: &ArrayView2<f32>,
) -> Result<(Vec<KeyPoint>, Array2<f32>)> {
let featuremap = self.gpu_forward_detection(gpu_ctx, image)?;
let descriptor_map = self.gpu_forward_descriptors(gpu_ctx, image)?;
self.post_process_features(&featuremap, &descriptor_map)
}
fn cpu_inference(&self, image: &ArrayView2<f32>) -> Result<(Vec<KeyPoint>, Array2<f32>)> {
let featuremap = self.cpu_forward_detection(image)?;
let descriptor_map = self.cpu_forward_descriptors(image)?;
self.post_process_features(&featuremap, &descriptor_map)
}
fn gpu_forward_detection(
&self,
gpu_ctx: &GpuVisionContext,
image: &ArrayView2<f32>,
) -> Result<Array2<f32>> {
let conv1_kernel =
Array2::from_shape_vec((3, 3), vec![-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0])?;
let conv1_result = crate::gpu_ops::gpu_convolve_2d(gpu_ctx, image, &conv1_kernel.view())?;
let activated = conv1_result.mapv(|x| x.max(0.0));
let pooled = crate::gpu_ops::gpu_gaussian_blur(gpu_ctx, &activated.view(), 2.0)?;
let (height, width) = pooled.dim();
let out_height = height / 8;
let out_width = width / 8;
let mut detection_map = Array2::zeros((out_height, out_width));
for y in 0..out_height {
for x in 0..out_width {
let src_y = (y * 8).min(height - 1);
let src_x = (x * 8).min(width - 1);
detection_map[[y, x]] = pooled[[src_y, src_x]].abs();
}
}
Ok(detection_map)
}
fn cpu_forward_detection(&self, image: &ArrayView2<f32>) -> Result<Array2<f32>> {
let (_, _, magnitude) = self.compute_simple_gradients(image)?;
let (height, width) = magnitude.dim();
let out_height = height / 8;
let out_width = width / 8;
let mut detection_map = Array2::zeros((out_height, out_width));
for y in 0..out_height {
for x in 0..out_width {
let mut max_val = 0.0f32;
for dy in 0..8 {
for dx in 0..8 {
let src_y = (y * 8 + dy).min(height - 1);
let src_x = (x * 8 + dx).min(width - 1);
max_val = max_val.max(magnitude[[src_y, src_x]]);
}
}
detection_map[[y, x]] = max_val;
}
}
Ok(detection_map)
}
fn gpu_forward_descriptors(
&self,
gpu_ctx: &GpuVisionContext,
image: &ArrayView2<f32>,
) -> Result<Array3<f32>> {
let blurred = crate::gpu_ops::gpu_gaussian_blur(gpu_ctx, image, 1.0)?;
let (height, width) = blurred.dim();
let desc_height = height / 8;
let desc_width = width / 8;
let desc_dim = self.network.config.descriptor_dim;
let mut descriptor_map = Array3::zeros((desc_height, desc_width, desc_dim));
for y in 0..desc_height {
for x in 0..desc_width {
let patch_y = y * 8;
let patch_x = x * 8;
let mut descriptor = Array1::zeros(desc_dim);
for i in 0..desc_dim {
let dy = i % 16;
let dx = i / 16;
let sample_y = (patch_y + dy).min(height - 1);
let sample_x = (patch_x + dx).min(width - 1);
descriptor[i] = blurred[[sample_y, sample_x]];
}
let norm = descriptor.dot(&descriptor).sqrt();
if norm > 1e-6 {
descriptor.mapv_inplace(|x| x / norm);
}
descriptor_map.slice_mut(s![y, x, ..]).assign(&descriptor);
}
}
Ok(descriptor_map)
}
fn compute_simple_gradients(
&self,
image: &ArrayView2<f32>,
) -> Result<(Array2<f32>, Array2<f32>, Array2<f32>)> {
let (height, width) = image.dim();
let mut gx = Array2::zeros((height, width));
let mut gy = Array2::zeros((height, width));
let mut magnitude = Array2::zeros((height, width));
for y in 1..height - 1 {
for x in 1..width - 1 {
let dx = image[[y, x + 1]] - image[[y, x - 1]];
let dy = image[[y + 1, x]] - image[[y - 1, x]];
gx[[y, x]] = dx;
gy[[y, x]] = dy;
magnitude[[y, x]] = (dx * dx + dy * dy).sqrt();
}
}
Ok((gx, gy, magnitude))
}
fn simple_gaussian_blur(&self, image: &ArrayView2<f32>, sigma: f32) -> Result<Array2<f32>> {
let (height, width) = image.dim();
let mut blurred = Array2::zeros((height, width));
for y in 1..height - 1 {
for x in 1..width - 1 {
let avg = (image[[y - 1, x - 1]]
+ image[[y - 1, x]]
+ image[[y - 1, x + 1]]
+ image[[y, x - 1]]
+ image[[y, x]]
+ image[[y, x + 1]]
+ image[[y + 1, x - 1]]
+ image[[y + 1, x]]
+ image[[y + 1, x + 1]])
/ 9.0;
blurred[[y, x]] = avg;
}
}
for y in 0..height {
blurred[[y, 0]] = image[[y, 0]];
if width > 1 {
blurred[[y, width - 1]] = image[[y, width - 1]];
}
}
for x in 0..width {
blurred[[0, x]] = image[[0, x]];
if height > 1 {
blurred[[height - 1, x]] = image[[height - 1, x]];
}
}
Ok(blurred)
}
fn cpu_forward_descriptors(&self, image: &ArrayView2<f32>) -> Result<Array3<f32>> {
let blurred = self.simple_gaussian_blur(image, 1.0)?;
let (height, width) = blurred.dim();
let desc_height = height / 8;
let desc_width = width / 8;
let desc_dim = self.network.config.descriptor_dim;
let mut descriptor_map = Array3::zeros((desc_height, desc_width, desc_dim));
for y in 0..desc_height {
for x in 0..desc_width {
let patch_y = y * 8;
let patch_x = x * 8;
let mut descriptor = Array1::zeros(desc_dim);
for i in 0..desc_dim.min(64) {
let angle = i as f32 * std::f32::consts::PI / 32.0;
let cos_a = angle.cos();
let sin_a = angle.sin();
let mut sum = 0.0f32;
for dy in 0..8 {
for dx in 0..8 {
let sample_y = (patch_y + dy).min(height - 1);
let sample_x = (patch_x + dx).min(width - 1);
let value = blurred[[sample_y, sample_x]];
let weight = (cos_a * dx as f32 + sin_a * dy as f32).cos();
sum += value * weight;
}
}
descriptor[i] = sum;
}
let norm = descriptor.dot(&descriptor).sqrt();
if norm > 1e-6 {
descriptor.mapv_inplace(|x| x / norm);
}
descriptor_map.slice_mut(s![y, x, ..]).assign(&descriptor);
}
}
Ok(descriptor_map)
}
fn post_process_features(
&self,
featuremap: &Array2<f32>,
descriptor_map: &Array3<f32>,
) -> Result<(Vec<KeyPoint>, Array2<f32>)> {
let nms_result = self.non_maximum_suppression(featuremap)?;
let mut candidates: Vec<(f32, usize, usize)> = Vec::new();
let (height, width) = nms_result.dim();
for y in self.network.config.border_remove..height - self.network.config.border_remove {
for x in self.network.config.border_remove..width - self.network.config.border_remove {
let score = nms_result[[y, x]];
if score > self.network.config.detection_threshold {
candidates.push((score, y, x));
}
}
}
candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("Operation failed"));
candidates.truncate(self.network.config.max_keypoints);
let mut keypoints = Vec::new();
let mut descriptors = Array2::zeros((candidates.len(), self.network.config.descriptor_dim));
for (i, &(score, y, x)) in candidates.iter().enumerate() {
let orig_x = (x * 8) as f32;
let orig_y = (y * 8) as f32;
keypoints.push(KeyPoint {
x: orig_x,
y: orig_y,
response: score,
scale: 1.0,
orientation: 0.0, });
if y < descriptor_map.shape()[0] && x < descriptor_map.shape()[1] {
let desc = descriptor_map.slice(s![y, x, ..]);
descriptors.slice_mut(s![i, ..]).assign(&desc);
}
}
Ok((keypoints, descriptors))
}
fn non_maximum_suppression(&self, featuremap: &Array2<f32>) -> Result<Array2<f32>> {
let (height, width) = featuremap.dim();
let mut nms_result = Array2::zeros((height, width));
let radius = self.network.config.nms_radius;
for y in radius..height - radius {
for x in radius..width - radius {
let center_val = featuremap[[y, x]];
let mut is_maximum = true;
for dy in -(radius as isize)..=(radius as isize) {
for dx in -(radius as isize)..=(radius as isize) {
if dy == 0 && dx == 0 {
continue;
}
let ny = (y as isize + dy) as usize;
let nx = (x as isize + dx) as usize;
if featuremap[[ny, nx]] >= center_val {
is_maximum = false;
break;
}
}
if !is_maximum {
break;
}
}
if is_maximum {
nms_result[[y, x]] = center_val;
}
}
}
Ok(nms_result)
}
fn resize_image(
&self,
image: &ArrayView2<f32>,
target_size: (usize, usize),
) -> Result<Array2<f32>> {
let (src_height, src_width) = image.dim();
let (dst_height, dst_width) = target_size;
let mut resized = Array2::zeros((dst_height, dst_width));
let scale_y = src_height as f32 / dst_height as f32;
let scale_x = src_width as f32 / dst_width as f32;
for y in 0..dst_height {
for x in 0..dst_width {
let src_y = (y as f32 * scale_y) as usize;
let src_x = (x as f32 * scale_x) as usize;
let src_y = src_y.min(src_height - 1);
let src_x = src_x.min(src_width - 1);
resized[[y, x]] = image[[src_y, src_x]];
}
}
Ok(resized)
}
fn create_detection_weights(config: &NeuralFeatureConfig) -> Result<ModelWeights> {
let conv_weights = vec![
Array3::from_shape_fn((64, 1, 3), |___| scirs2_core::random::random::<f32>() * 0.1),
Array3::from_shape_fn((64, 64, 3), |___| {
scirs2_core::random::random::<f32>() * 0.1
}),
Array3::from_shape_fn((128, 64, 3), |___| {
scirs2_core::random::random::<f32>() * 0.1
}),
Array3::from_shape_fn((128, 128, 3), |___| {
scirs2_core::random::random::<f32>() * 0.1
}),
];
let conv_biases = vec![
Array1::zeros(64),
Array1::zeros(64),
Array1::zeros(128),
Array1::zeros(128),
];
let bn_weights = vec![
Array1::ones(64),
Array1::ones(64),
Array1::ones(128),
Array1::ones(128),
];
let bn_biases = vec![
Array1::zeros(64),
Array1::zeros(64),
Array1::zeros(128),
Array1::zeros(128),
];
let fc_weights = vec![Array2::from_shape_fn((65, 128), |_| {
scirs2_core::random::random::<f32>() * 0.1
})];
let fc_biases = vec![
Array1::zeros(65), ];
Ok(ModelWeights {
conv_weights,
conv_biases,
bn_weights,
bn_biases,
fc_weights,
fc_biases,
})
}
fn create_descriptor_weights(config: &NeuralFeatureConfig) -> Result<ModelWeights> {
let fc_weights = vec![Array2::from_shape_fn((config.descriptor_dim, 128), |_| {
scirs2_core::random::random::<f32>() * 0.1
})];
let fc_biases = vec![Array1::zeros(config.descriptor_dim)];
Ok(ModelWeights {
conv_weights: Vec::new(),
conv_biases: Vec::new(),
bn_weights: Vec::new(),
bn_biases: Vec::new(),
fc_weights,
fc_biases,
})
}
}
pub struct NeuralFeatureMatcher {
distance_threshold: f32,
ratio_threshold: f32,
#[allow(dead_code)]
use_gpu: bool,
}
impl Default for NeuralFeatureMatcher {
fn default() -> Self {
Self::new()
}
}
impl NeuralFeatureMatcher {
pub fn new() -> Self {
Self {
distance_threshold: 0.7,
ratio_threshold: 0.8,
use_gpu: true,
}
}
pub fn with_params(mut self, distance_threshold: f32, ratiothreshold: f32) -> Self {
self.distance_threshold = distance_threshold;
self.ratio_threshold = ratiothreshold;
self
}
pub fn match_descriptors(
&self,
desc1: &ArrayView2<f32>,
desc2: &ArrayView2<f32>,
) -> Result<Vec<(usize, usize)>> {
let n1 = desc1.shape()[0];
let n2 = desc2.shape()[0];
if n1 == 0 || n2 == 0 {
return Ok(Vec::new());
}
let distances = self.compute_pairwise_distances(desc1, desc2)?;
let mut matches = Vec::new();
for i in 0..n1 {
let mut best_dist = f32::INFINITY;
let mut second_best_dist = f32::INFINITY;
let mut best_idx = 0;
for j in 0..n2 {
let dist = distances[[i, j]];
if dist < best_dist {
second_best_dist = best_dist;
best_dist = dist;
best_idx = j;
} else if dist < second_best_dist {
second_best_dist = dist;
}
}
if best_dist < self.distance_threshold
&& best_dist / second_best_dist < self.ratio_threshold
{
matches.push((i, best_idx));
}
}
Ok(matches)
}
fn compute_pairwise_distances(
&self,
desc1: &ArrayView2<f32>,
desc2: &ArrayView2<f32>,
) -> Result<Array2<f32>> {
let n1 = desc1.shape()[0];
let n2 = desc2.shape()[0];
let mut distances = Array2::zeros((n1, n2));
for i in 0..n1 {
for j in 0..n2 {
let desc1_row = desc1.slice(s![i, ..]);
let desc2_row = desc2.slice(s![j, ..]);
let dot_product = desc1_row.dot(&desc2_row);
let norm1 = desc1_row.dot(&desc1_row).sqrt();
let norm2 = desc2_row.dot(&desc2_row).sqrt();
let cosine_sim = if norm1 > 1e-6 && norm2 > 1e-6 {
dot_product / (norm1 * norm2)
} else {
0.0
};
distances[[i, j]] = 1.0 - cosine_sim;
}
}
Ok(distances)
}
}
pub struct AttentionFeatureMatcher {
#[allow(dead_code)]
attention_dim: usize,
#[allow(dead_code)]
numheads: usize,
#[allow(dead_code)]
use_gpu: bool,
}
impl AttentionFeatureMatcher {
pub fn new(_attention_dim: usize, numheads: usize) -> Self {
Self {
attention_dim: _attention_dim,
numheads,
use_gpu: true,
}
}
pub fn match_with_attention(
&self,
keypoints1: &[KeyPoint],
descriptors1: &ArrayView2<f32>,
keypoints2: &[KeyPoint],
descriptors2: &ArrayView2<f32>,
) -> Result<Vec<(usize, usize)>> {
let n1 = descriptors1.shape()[0];
let n2 = descriptors2.shape()[0];
if n1 == 0 || n2 == 0 {
return Ok(Vec::new());
}
let pos_enc1 = self.compute_positional_encoding(keypoints1)?;
let pos_enc2 = self.compute_positional_encoding(keypoints2)?;
let enhanced_desc1 = self.enhance_descriptors(descriptors1, &pos_enc1)?;
let enhanced_desc2 = self.enhance_descriptors(descriptors2, &pos_enc2)?;
let attention_scores = self.compute_attention_scores(&enhanced_desc1, &enhanced_desc2)?;
self.extract_matches_from_attention(&attention_scores)
}
fn compute_positional_encoding(&self, keypoints: &[KeyPoint]) -> Result<Array2<f32>> {
let n = keypoints.len();
let mut pos_encoding = Array2::zeros((n, 4));
for (i, kp) in keypoints.iter().enumerate() {
pos_encoding[[i, 0]] = kp.x / 1000.0; pos_encoding[[i, 1]] = kp.y / 1000.0;
pos_encoding[[i, 2]] = (kp.x * 0.01).cos();
pos_encoding[[i, 3]] = (kp.y * 0.01).sin();
}
Ok(pos_encoding)
}
fn enhance_descriptors(
&self,
descriptors: &ArrayView2<f32>,
pos_encoding: &Array2<f32>,
) -> Result<Array2<f32>> {
let n = descriptors.shape()[0];
let desc_dim = descriptors.shape()[1];
let pos_dim = pos_encoding.shape()[1];
let mut enhanced = Array2::zeros((n, desc_dim + pos_dim));
for i in 0..n {
enhanced
.slice_mut(s![i, ..desc_dim])
.assign(&descriptors.slice(s![i, ..]));
enhanced
.slice_mut(s![i, desc_dim..])
.assign(&pos_encoding.slice(s![i, ..]));
}
Ok(enhanced)
}
fn compute_attention_scores(
&self,
desc1: &Array2<f32>,
desc2: &Array2<f32>,
) -> Result<Array2<f32>> {
let n1 = desc1.shape()[0];
let n2 = desc2.shape()[0];
let dim = desc1.shape()[1];
let mut attention_scores = Array2::zeros((n1, n2));
let scale = 1.0 / (dim as f32).sqrt();
for i in 0..n1 {
for j in 0..n2 {
let query = desc1.slice(s![i, ..]);
let key = desc2.slice(s![j, ..]);
let score = query.dot(&key) * scale;
attention_scores[[i, j]] = score;
}
}
for i in 0..n1 {
let mut row = attention_scores.slice_mut(s![i, ..]);
let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
row.mapv_inplace(|x| (x - max_val).exp());
let sum = row.sum();
if sum > 1e-8 {
row.mapv_inplace(|x| x / sum);
}
}
Ok(attention_scores)
}
fn extract_matches_from_attention(
&self,
attention_scores: &Array2<f32>,
) -> Result<Vec<(usize, usize)>> {
let n1 = attention_scores.shape()[0];
let n2 = attention_scores.shape()[1];
let mut matches = Vec::new();
let mut used_j = vec![false; n2];
for i in 0..n1 {
let mut best_score = 0.0;
let mut best_j = None;
for j in 0..n2 {
if !used_j[j] && attention_scores[[i, j]] > best_score {
best_score = attention_scores[[i, j]];
best_j = Some(j);
}
}
if let Some(j) = best_j {
if best_score > 0.1 {
matches.push((i, j));
used_j[j] = true;
}
}
}
Ok(matches)
}
}
pub struct LearnedSIFT {
siftconfig: SIFTConfig,
enhancement_network: Option<NeuralFeatureNetwork>,
}
#[derive(Clone)]
pub struct SIFTConfig {
pub num_octaves: usize,
pub num_scales: usize,
pub sigma: f32,
pub edge_threshold: f32,
pub peak_threshold: f32,
}
impl Default for SIFTConfig {
fn default() -> Self {
Self {
num_octaves: 4,
num_scales: 3,
sigma: 1.6,
edge_threshold: 10.0,
peak_threshold: 0.03,
}
}
}
impl LearnedSIFT {
pub fn new(config: Option<SIFTConfig>) -> Self {
Self {
siftconfig: config.unwrap_or_default(),
enhancement_network: None,
}
}
fn simple_gaussian_blur(&self, image: &ArrayView2<f32>, sigma: f32) -> Result<Array2<f32>> {
let (height, width) = image.dim();
let mut blurred = Array2::zeros((height, width));
for y in 1..height - 1 {
for x in 1..width - 1 {
let avg = (image[[y - 1, x - 1]]
+ image[[y - 1, x]]
+ image[[y - 1, x + 1]]
+ image[[y, x - 1]]
+ image[[y, x]]
+ image[[y, x + 1]]
+ image[[y + 1, x - 1]]
+ image[[y + 1, x]]
+ image[[y + 1, x + 1]])
/ 9.0;
blurred[[y, x]] = avg;
}
}
for y in 0..height {
blurred[[y, 0]] = image[[y, 0]];
if width > 1 {
blurred[[y, width - 1]] = image[[y, width - 1]];
}
}
for x in 0..width {
blurred[[0, x]] = image[[0, x]];
if height > 1 {
blurred[[height - 1, x]] = image[[height - 1, x]];
}
}
Ok(blurred)
}
pub fn detect_keypoints(&self, image: &ArrayView2<f32>) -> Result<Vec<KeyPoint>> {
let scalespace = self.build_scale_space(image)?;
let dogspace = self.compute_dog_space(&scalespace)?;
let extrema = self.detect_extrema(&dogspace)?;
let refined_keypoints = self.refine_keypoints(&extrema, &dogspace)?;
let filtered_keypoints = self.filter_keypoints(&refined_keypoints, &dogspace)?;
Ok(filtered_keypoints)
}
pub fn compute_descriptors(
&self,
image: &ArrayView2<f32>,
keypoints: &[KeyPoint],
) -> Result<Array2<f32>> {
let mut descriptors = Array2::zeros((keypoints.len(), 128));
for (i, kp) in keypoints.iter().enumerate() {
let descriptor = self.compute_sift_descriptor(image, kp)?;
descriptors.slice_mut(s![i, ..]).assign(&descriptor);
}
if let Some(ref network) = self.enhancement_network {
self.enhance_descriptors_neural(&mut descriptors, network)?;
}
Ok(descriptors)
}
fn build_scale_space(&self, image: &ArrayView2<f32>) -> Result<Vec<Vec<Array2<f32>>>> {
let mut scalespace = Vec::new();
let mut current_image = image.to_owned();
for octave in 0..self.siftconfig.num_octaves {
let mut octave_images = Vec::new();
for scale in 0..self.siftconfig.num_scales + 3 {
let sigma = self.siftconfig.sigma
* 2.0_f32.powf(scale as f32 / self.siftconfig.num_scales as f32);
let blurred = self.simple_gaussian_blur(¤t_image.view(), sigma)?;
octave_images.push(blurred);
}
scalespace.push(octave_images);
if octave < self.siftconfig.num_octaves - 1 {
current_image = self.downsample(¤t_image)?;
}
}
Ok(scalespace)
}
fn compute_dog_space(&self, scalespace: &[Vec<Array2<f32>>]) -> Result<Vec<Vec<Array2<f32>>>> {
let mut dogspace = Vec::new();
for octave_images in scalespace {
let mut dog_octave = Vec::new();
for i in 0..octave_images.len() - 1 {
let dog = &octave_images[i + 1] - &octave_images[i];
dog_octave.push(dog);
}
dogspace.push(dog_octave);
}
Ok(dogspace)
}
fn detect_extrema(&self, dogspace: &[Vec<Array2<f32>>]) -> Result<Vec<KeyPoint>> {
let mut extrema = Vec::new();
for (octave, dog_octave) in dogspace.iter().enumerate() {
for (scale, dog_image) in dog_octave
.iter()
.enumerate()
.skip(1)
.take(dog_octave.len() - 2)
{
let (height, width) = dog_image.dim();
for y in 1..height - 1 {
for x in 1..width - 1 {
let center_val = dog_image[[y, x]];
if center_val.abs() < self.siftconfig.peak_threshold {
continue;
}
if self.is_extremum(dog_octave, scale, y, x, center_val) {
extrema.push(KeyPoint {
x: x as f32 * 2.0_f32.powi(octave as i32),
y: y as f32 * 2.0_f32.powi(octave as i32),
response: center_val.abs(),
scale: 2.0_f32.powi(octave as i32),
orientation: 0.0,
});
}
}
}
}
}
Ok(extrema)
}
fn is_extremum(
&self,
dog_octave: &[Array2<f32>],
scale: usize,
y: usize,
x: usize,
center_val: f32,
) -> bool {
let is_max = center_val > 0.0;
for s_offset in -1_isize..=1_isize {
let s = (scale as isize + s_offset) as usize;
for dy in -1_isize..=1_isize {
for dx in -1_isize..=1_isize {
if s_offset == 0 && dy == 0 && dx == 0 {
continue;
}
let ny = (y as isize + dy) as usize;
let nx = (x as isize + dx) as usize;
let neighbor_val = dog_octave[s][[ny, nx]];
if is_max && neighbor_val >= center_val {
return false;
}
if !is_max && neighbor_val <= center_val {
return false;
}
}
}
}
true
}
fn refine_keypoints(
&self,
keypoints: &[KeyPoint],
_dog_space: &[Vec<Array2<f32>>],
) -> Result<Vec<KeyPoint>> {
Ok(keypoints.to_vec())
}
fn filter_keypoints(
&self,
keypoints: &[KeyPoint],
_dog_space: &[Vec<Array2<f32>>],
) -> Result<Vec<KeyPoint>> {
let mut filtered = Vec::new();
for kp in keypoints {
if kp.response > self.siftconfig.peak_threshold {
filtered.push(kp.clone());
}
}
Ok(filtered)
}
fn compute_sift_descriptor(
&self,
image: &ArrayView2<f32>,
keypoint: &KeyPoint,
) -> Result<Array1<f32>> {
let mut descriptor = Array1::zeros(128);
let (height, width) = image.dim();
let x = keypoint.x as usize;
let y = keypoint.y as usize;
for i in 0..128 {
let angle = i as f32 * 2.0 * std::f32::consts::PI / 128.0;
let radius = 8.0 + (i % 16) as f32;
let sample_x = x as f32 + radius * angle.cos();
let sample_y = y as f32 + radius * angle.sin();
if sample_x >= 0.0
&& sample_x < width as f32
&& sample_y >= 0.0
&& sample_y < height as f32
{
let sx = sample_x as usize;
let sy = sample_y as usize;
descriptor[i] = image[[sy.min(height - 1), sx.min(width - 1)]];
}
}
let norm = descriptor.dot(&descriptor).sqrt();
if norm > 1e-6 {
descriptor.mapv_inplace(|x| x / norm);
}
Ok(descriptor)
}
fn enhance_descriptors_neural(
&self,
descriptors: &mut Array2<f32>,
_network: &NeuralFeatureNetwork,
) -> Result<()> {
for mut row in descriptors.rows_mut() {
let mean = row.mean_or(0.0);
let std = ((row.mapv(|x| (x - mean).powi(2)).mean_or(0.0)).sqrt()).max(1e-6);
row.mapv_inplace(|x| (x - mean) / std);
}
Ok(())
}
fn downsample(&self, image: &Array2<f32>) -> Result<Array2<f32>> {
let (height, width) = image.dim();
let new_height = height / 2;
let new_width = width / 2;
let mut downsampled = Array2::zeros((new_height, new_width));
for y in 0..new_height {
for x in 0..new_width {
downsampled[[y, x]] = image[[y * 2, x * 2]];
}
}
Ok(downsampled)
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::arr2;
#[test]
fn test_superpoint_creation() {
let config = NeuralFeatureConfig {
input_size: (480, 640),
max_keypoints: 512,
use_gpu: false, ..Default::default()
};
let result = SuperPointNet::new(Some(config));
assert!(result.is_ok());
}
#[test]
fn test_superpoint_detection() {
let config = NeuralFeatureConfig {
input_size: (480, 640),
max_keypoints: 100,
use_gpu: false,
..Default::default()
};
if let Ok(superpoint) = SuperPointNet::new(Some(config)) {
let image = Array2::from_shape_fn((480, 640), |(y, x)| {
((x as f32 / 10.0).sin() + (y as f32 / 10.0).cos()) * 0.5 + 0.5
});
let result = superpoint.detect_and_describe(&image.view());
assert!(result.is_ok());
let (keypoints, descriptors) = result.expect("Operation failed");
assert!(!keypoints.is_empty());
assert_eq!(descriptors.shape()[0], keypoints.len());
}
}
#[test]
fn test_neural_feature_matcher() {
let matcher = NeuralFeatureMatcher::new();
let desc1 = arr2(&[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
]);
let desc2 = arr2(&[
[0.9, 0.1, 0.0, 0.0],
[0.0, 0.0, 0.9, 0.1],
[0.1, 0.9, 0.0, 0.0],
]);
let matches = matcher
.match_descriptors(&desc1.view(), &desc2.view())
.expect("Operation failed");
assert!(!matches.is_empty());
}
#[test]
fn test_learned_sift() {
let sift = LearnedSIFT::new(None);
let image = Array2::from_shape_fn((100, 100), |(y, x)| {
if (x as i32 - 50).abs() < 5 && (y as i32 - 50).abs() < 5 {
1.0
} else {
0.0
}
});
let keypoints = sift
.detect_keypoints(&image.view())
.expect("Operation failed");
if !keypoints.is_empty() {
let descriptors = sift
.compute_descriptors(&image.view(), &keypoints)
.expect("Operation failed");
assert_eq!(descriptors.shape()[0], keypoints.len());
assert_eq!(descriptors.shape()[1], 128);
}
}
#[test]
fn test_attention_matcher() {
let matcher = AttentionFeatureMatcher::new(64, 4);
let keypoints1 = vec![
KeyPoint {
x: 10.0,
y: 10.0,
response: 1.0,
scale: 1.0,
orientation: 0.0,
},
KeyPoint {
x: 20.0,
y: 20.0,
response: 1.0,
scale: 1.0,
orientation: 0.0,
},
];
let keypoints2 = vec![
KeyPoint {
x: 12.0,
y: 11.0,
response: 1.0,
scale: 1.0,
orientation: 0.0,
},
KeyPoint {
x: 50.0,
y: 50.0,
response: 1.0,
scale: 1.0,
orientation: 0.0,
},
];
let desc1 = Array2::from_shape_fn((2, 64), |__| scirs2_core::random::random::<f32>());
let desc2 = Array2::from_shape_fn((2, 64), |__| scirs2_core::random::random::<f32>());
let result =
matcher.match_with_attention(&keypoints1, &desc1.view(), &keypoints2, &desc2.view());
assert!(result.is_ok());
}
}