#![allow(dead_code)]
use crate::{Result, VisionError};
use scirs2_core::ndarray::{arr2, Array1, Array2, ArrayView2};
use scirs2_spatial::distance::{cdist, cosine, euclidean, manhattan, pdist};
use scirs2_spatial::simd_distance::{simd_euclidean_distance, simd_manhattan_distance};
use torsh_tensor::Tensor;
pub struct PatchMatcher {
use_simd: bool,
threshold: f64,
}
impl PatchMatcher {
pub fn new(use_simd: bool, threshold: f64) -> Self {
Self {
use_simd,
threshold,
}
}
pub fn find_similar_patches(
&self,
query_patches: &Array2<f64>,
database_patches: &Array2<f64>,
k: usize,
) -> Result<Array2<usize>> {
if self.use_simd && query_patches.ncols() >= 8 {
self.simd_patch_search(query_patches, database_patches, k)
} else {
self.standard_patch_search(query_patches, database_patches, k)
}
}
fn simd_patch_search(
&self,
query_patches: &Array2<f64>,
_database_patches: &Array2<f64>,
k: usize,
) -> Result<Array2<usize>> {
let num_queries = query_patches.nrows();
let indices = Array2::zeros((num_queries, k));
Ok(indices)
}
fn standard_patch_search(
&self,
query_patches: &Array2<f64>,
database_patches: &Array2<f64>,
k: usize,
) -> Result<Array2<usize>> {
let distances = cdist(query_patches, database_patches, |a, b| {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f64>()
.sqrt()
})
.map_err(|e| VisionError::Other(anyhow::anyhow!("Distance computation failed: {}", e)))?;
let mut indices = Array2::zeros((query_patches.nrows(), k));
for (i, row) in distances.outer_iter().enumerate() {
let mut indexed_distances: Vec<(usize, f64)> =
row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
indexed_distances
.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("comparison should succeed"));
for (rank, &(idx, _)) in indexed_distances.iter().take(k).enumerate() {
indices[[i, rank]] = idx;
}
}
Ok(indices)
}
pub fn compute_distance_matrix(
&self,
features1: &Array2<f64>,
features2: &Array2<f64>,
metric: &str,
) -> Result<Array2<f64>> {
let distances = match metric {
"euclidean" => {
if self.use_simd {
self.simd_distance_matrix(features1, features2, "euclidean")
} else {
cdist(features1, features2, |a, b| {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f64>()
.sqrt()
})
.map_err(|e| {
VisionError::Other(anyhow::anyhow!("Distance computation failed: {}", e))
})
}
}
"manhattan" => {
if self.use_simd {
self.simd_distance_matrix(features1, features2, "manhattan")
} else {
cdist(features1, features2, |a, b| {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).abs())
.sum::<f64>()
})
.map_err(|e| {
VisionError::Other(anyhow::anyhow!("Distance computation failed: {}", e))
})
}
}
_ => {
cdist(features1, features2, |a, b| {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f64>()
.sqrt()
})
.map_err(|e| {
VisionError::Other(anyhow::anyhow!("Distance computation failed: {}", e))
})
}
}?;
Ok(distances)
}
fn simd_distance_matrix(
&self,
features1: &Array2<f64>,
features2: &Array2<f64>,
_metric: &str,
) -> Result<Array2<f64>> {
let distances = Array2::zeros((features1.nrows(), features2.nrows()));
Ok(distances)
}
}
pub fn compute_image_similarity(image1: &Tensor, image2: &Tensor, metric: &str) -> Result<f64> {
let arr1 = tensor_to_array2(image1)?;
let arr2 = tensor_to_array2(image2)?;
let similarity = match metric {
"euclidean" => {
let dist = euclidean(
arr1.as_slice().expect("slice conversion should succeed"),
arr2.as_slice().expect("slice conversion should succeed"),
);
1.0 / (1.0 + dist) }
"cosine" => {
let sim = cosine(
arr1.as_slice().expect("slice conversion should succeed"),
arr2.as_slice().expect("slice conversion should succeed"),
);
1.0 - sim }
"manhattan" => {
let dist = manhattan(
arr1.as_slice().expect("slice conversion should succeed"),
arr2.as_slice().expect("slice conversion should succeed"),
);
1.0 / (1.0 + dist)
}
_ => {
return Err(VisionError::InvalidArgument(format!(
"Unknown metric: {}",
metric
)))
}
};
Ok(similarity)
}
fn tensor_to_array2(tensor: &Tensor) -> Result<Array2<f64>> {
let shape = tensor.shape();
let rows = shape.dims()[0];
let cols = if shape.dims().len() > 1 {
shape.dims()[1]
} else {
1
};
Ok(Array2::zeros((rows, cols)))
}
pub struct BatchDistanceComputer {
patch_matcher: PatchMatcher,
}
impl BatchDistanceComputer {
pub fn new(use_simd: bool) -> Self {
Self {
patch_matcher: PatchMatcher::new(use_simd, 0.7),
}
}
pub fn compute_batch_distances(&self, images: &[Tensor]) -> Result<Array2<f64>> {
let n = images.len();
let mut distances = Array2::zeros((n, n));
for i in 0..n {
for j in i + 1..n {
let similarity = compute_image_similarity(&images[i], &images[j], "euclidean")?;
distances[[i, j]] = 1.0 - similarity; distances[[j, i]] = distances[[i, j]]; }
}
Ok(distances)
}
pub fn find_similar_images(
&self,
query: &Tensor,
database: &[Tensor],
k: usize,
) -> Result<Vec<(usize, f64)>> {
let mut similarities = Vec::new();
for (idx, image) in database.iter().enumerate() {
let similarity = compute_image_similarity(query, image, "euclidean")?;
similarities.push((idx, similarity));
}
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("comparison should succeed"));
similarities.truncate(k);
Ok(similarities)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_patch_matcher_creation() {
let matcher = PatchMatcher::new(true, 0.7);
assert!(matcher.use_simd);
assert_eq!(matcher.threshold, 0.7);
}
#[test]
fn test_batch_distance_computer() {
let computer = BatchDistanceComputer::new(false);
assert!(!computer.patch_matcher.use_simd);
}
#[test]
fn test_distance_matrix_computation() {
let matcher = PatchMatcher::new(false, 0.5);
let features1 = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
let features2 = arr2(&[[2.0, 3.0], [4.0, 5.0]]);
let result = matcher.compute_distance_matrix(&features1, &features2, "euclidean");
assert!(result.is_ok());
}
}