use crate::error::{SpatialError, SpatialResult};
use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
use super::super::concepts::QuantumState;
/// Nearest-neighbour searcher over a fixed set of reference points.
///
/// Candidates are scored either classically (Euclidean distance) or, when
/// quantum encoding is enabled and succeeded, by `1 - fidelity` between
/// quantum-encoded states. Scores can optionally be post-processed by an
/// amplitude-amplification heuristic.
#[derive(Debug, Clone)]
pub struct QuantumNearestNeighbor {
    // Quantum encodings of `classical_points`, one per row, in row order.
    // Empty until `with_quantum_encoding(true)` encodes successfully.
    quantum_points: Vec<QuantumState>,
    // Reference points, one point per row.
    classical_points: Array2<f64>,
    // Whether queries should use the quantum-fidelity score (only honoured
    // when `quantum_points` is non-empty).
    quantum_encoding: bool,
    // Whether query scores are post-processed by the amplification heuristic.
    amplitude_amplification: bool,
    // Number of iterations run by the amplification heuristic.
    grover_iterations: usize,
}
impl QuantumNearestNeighbor {
    /// Creates a searcher over `points`, where each row is one reference point.
    ///
    /// Quantum encoding and amplitude amplification both start disabled, and
    /// the amplification heuristic defaults to 3 iterations.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub fn new(points: &ArrayView2<'_, f64>) -> SpatialResult<Self> {
        Ok(Self {
            quantum_points: Vec::new(),
            classical_points: points.to_owned(),
            quantum_encoding: false,
            amplitude_amplification: false,
            grover_iterations: 3,
        })
    }

    /// Enables or disables scoring by quantum-state fidelity.
    ///
    /// When `enabled` is true, every reference point is encoded eagerly.
    /// Encoding is best-effort: if it fails, the stored quantum states are left
    /// unchanged. When none are stored, [`query_quantum`](Self::query_quantum)
    /// falls back to classical Euclidean distances, even though
    /// [`is_quantum_enabled`](Self::is_quantum_enabled) reports `true`.
    pub fn with_quantum_encoding(mut self, enabled: bool) -> Self {
        self.quantum_encoding = enabled;
        if enabled {
            // This builder cannot report errors; a failed encoding deliberately
            // degrades to the classical path instead of panicking.
            if let Ok(encoded) = self.encode_reference_points() {
                self.quantum_points = encoded;
            }
        }
        self
    }

    /// Enables or disables the amplitude-amplification post-processing of scores.
    pub fn with_amplitude_amplification(mut self, enabled: bool) -> Self {
        self.amplitude_amplification = enabled;
        self
    }

    /// Sets the number of iterations used by the amplification heuristic.
    pub fn with_grover_iterations(mut self, iterations: usize) -> Self {
        self.grover_iterations = iterations;
        self
    }

    /// Returns the indices and scores of the `k` best-ranked reference points.
    ///
    /// The score is chosen as follows:
    /// - By default it is the Euclidean distance to `query_point`.
    /// - When quantum encoding is active, it is `1 - fidelity` between the
    ///   encoded query and reference states.
    /// - When amplitude amplification is enabled, the reported scores are the
    ///   amplified values rather than raw distances.
    ///
    /// Results are sorted ascending (best first). Ties keep index order, and NaN
    /// scores sort last instead of panicking.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidInput`] if `k` exceeds the number of
    /// points, or if `query_point` does not have the same dimension as the
    /// indexed points.
    pub fn query_quantum(
        &self,
        query_point: &ArrayView1<f64>,
        k: usize,
    ) -> SpatialResult<(Vec<usize>, Vec<f64>)> {
        let n_points = self.classical_points.nrows();
        if k > n_points {
            return Err(SpatialError::InvalidInput(
                "k cannot be larger than number of points".to_string(),
            ));
        }
        let n_dims = self.classical_points.ncols();
        if query_point.len() != n_dims {
            // Without this check `zip` would silently truncate the longer vector
            // and the quantum path would compare states of different sizes.
            return Err(SpatialError::InvalidInput(format!(
                "query point has {} dimensions but the indexed points have {}",
                query_point.len(),
                n_dims
            )));
        }

        let mut scores = if self.quantum_encoding && !self.quantum_points.is_empty() {
            self.quantum_distance_computation(query_point)?
        } else {
            self.classical_distance_computation(query_point)
        };
        if self.amplitude_amplification {
            scores = self.apply_amplitude_amplification(scores)?;
        }

        let mut ranked: Vec<(usize, f64)> = scores.into_iter().enumerate().collect();
        // NaN-last total order. For non-NaN scores `partial_cmp` is always
        // `Some`, so finite values are ordered exactly as before, but a NaN
        // (e.g. from a NaN query coordinate) can no longer panic the sort.
        ranked.sort_by(|(_, a), (_, b)| {
            a.is_nan()
                .cmp(&b.is_nan())
                .then_with(|| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        });
        ranked.truncate(k);
        Ok(ranked.into_iter().unzip())
    }

    /// Encodes every reference point with [`Self::encode_point`], in row order.
    ///
    /// Stops at, and returns, the first encoding error.
    fn encode_reference_points(&self) -> SpatialResult<Vec<QuantumState>> {
        self.classical_points
            .outer_iter()
            .map(|row| Self::encode_point(&row))
            .collect()
    }

    /// Encodes a single point as a quantum state.
    ///
    /// 1. The state uses `len.next_power_of_two().trailing_zeros() + 2` qubits.
    /// 2. Coordinate `dim` is rescaled from `[-10, 10]` to `[0, 1]` (values
    ///    outside are clamped), multiplied by `PI`, and applied as a phase
    ///    rotation on qubit `dim`. Only the first `num_qubits - 1` coordinates
    ///    are encoded; any further coordinates are ignored.
    /// 3. Each adjacent qubit pair `(q, q + 1)` then receives a `PI / 4`
    ///    controlled rotation.
    fn encode_point(point: &ArrayView1<'_, f64>) -> SpatialResult<QuantumState> {
        let num_qubits = point.len().next_power_of_two().trailing_zeros() as usize + 2;
        let mut state = QuantumState::zero_state(num_qubits);
        for (dim, &coord) in point.iter().enumerate().take(num_qubits - 1) {
            let normalized_coord = (coord + 10.0) / 20.0;
            let angle = normalized_coord.clamp(0.0, 1.0) * PI;
            state.phase_rotation(dim, angle)?;
        }
        for qubit in 0..num_qubits - 1 {
            state.controlled_rotation(qubit, qubit + 1, PI / 4.0)?;
        }
        Ok(state)
    }

    /// Scores every stored quantum reference state against the encoded query.
    ///
    /// The score is `1 - fidelity`, in the order of `quantum_points`.
    fn quantum_distance_computation(
        &self,
        query_point: &ArrayView1<f64>,
    ) -> SpatialResult<Vec<f64>> {
        let query_state = Self::encode_point(query_point)?;
        Ok(self
            .quantum_points
            .iter()
            .map(|reference| 1.0 - Self::calculate_quantum_fidelity(&query_state, reference))
            .collect())
    }

    /// Euclidean distance from `query_point` to every reference point, in row order.
    fn classical_distance_computation(&self, query_point: &ArrayView1<f64>) -> Vec<f64> {
        self.classical_points
            .outer_iter()
            .map(|reference| {
                query_point
                    .iter()
                    .zip(reference.iter())
                    .map(|(&a, &b)| (a - b).powi(2))
                    .sum::<f64>()
                    .sqrt()
            })
            .collect()
    }

    /// Squared magnitude of the inner product `<state1|state2>`.
    ///
    /// Returns `0.0` when the states have different amplitude counts.
    fn calculate_quantum_fidelity(state1: &QuantumState, state2: &QuantumState) -> f64 {
        if state1.amplitudes.len() != state2.amplitudes.len() {
            return 0.0;
        }
        let inner_product: Complex64 = state1
            .amplitudes
            .iter()
            .zip(state2.amplitudes.iter())
            .map(|(a, b)| a.conj() * b)
            .sum();
        inner_product.norm_sqr()
    }

    /// Heuristic "amplitude amplification" applied to a score vector.
    ///
    /// Each of the `grover_iterations` passes does two things:
    /// - reflects every score about the mean of the *input* scores
    ///   (`s -> 2 * mean - s`);
    /// - multiplies by 0.9 the scores that then lie below that mean.
    ///
    /// A single reflection reverses the ranking. After an odd number of passes
    /// one extra reflection is therefore applied; without it the default of
    /// 3 passes would rank the farthest points first. For non-negative input
    /// scores every step is monotone, so the final ranking matches the input
    /// ranking and only the score values change. Finally, scores are shifted
    /// up so that none is negative.
    fn apply_amplitude_amplification(&self, mut distances: Vec<f64>) -> SpatialResult<Vec<f64>> {
        if distances.is_empty() {
            return Ok(distances);
        }
        let avg_distance: f64 = distances.iter().sum::<f64>() / distances.len() as f64;
        let reflect = |d: &mut f64| *d = 2.0 * avg_distance - *d;

        for _ in 0..self.grover_iterations {
            for distance in distances.iter_mut() {
                reflect(distance);
            }
            for distance in distances.iter_mut() {
                if *distance < avg_distance {
                    *distance *= 0.9;
                }
            }
        }
        if self.grover_iterations % 2 == 1 {
            // Undo the orientation flip so that smaller still means nearer.
            for distance in distances.iter_mut() {
                reflect(distance);
            }
        }

        let min_distance = distances.iter().copied().fold(f64::INFINITY, f64::min);
        if min_distance < 0.0 {
            for distance in distances.iter_mut() {
                *distance -= min_distance;
            }
        }
        Ok(distances)
    }

    /// Number of reference points.
    pub fn len(&self) -> usize {
        self.classical_points.nrows()
    }

    /// Returns `true` when there are no reference points.
    pub fn is_empty(&self) -> bool {
        self.classical_points.nrows() == 0
    }

    /// The reference points, one per row.
    pub fn classical_points(&self) -> &Array2<f64> {
        &self.classical_points
    }

    /// Whether quantum encoding was requested via
    /// [`with_quantum_encoding`](Self::with_quantum_encoding).
    pub fn is_quantum_enabled(&self) -> bool {
        self.quantum_encoding
    }

    /// Whether amplitude amplification is enabled.
    pub fn is_amplification_enabled(&self) -> bool {
        self.amplitude_amplification
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, Array2};

    /// Builds a `rows x 2` matrix from row-major coordinates.
    fn planar_points(rows: usize, coords: Vec<f64>) -> Array2<f64> {
        Array2::from_shape_vec((rows, 2), coords).expect("Operation failed")
    }

    /// Wraps `points` in a searcher with default settings.
    fn searcher_for(points: &Array2<f64>) -> QuantumNearestNeighbor {
        QuantumNearestNeighbor::new(&points.view()).expect("Operation failed")
    }

    #[test]
    fn test_quantum_nearest_neighbor_creation() {
        let points = planar_points(3, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);
        let searcher = searcher_for(&points);
        assert_eq!(searcher.len(), 3);
        assert!(!searcher.is_quantum_enabled());
        assert!(!searcher.is_amplification_enabled());
    }

    #[test]
    fn test_classical_search() {
        let points = planar_points(3, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);
        let query = arr1(&[0.5, 0.5]);
        let (indices, distances) = searcher_for(&points)
            .query_quantum(&query.view(), 2)
            .expect("Operation failed");
        assert_eq!(indices.len(), 2);
        assert_eq!(distances.len(), 2);
        assert!(indices.contains(&0));
        assert!(indices.contains(&1));
    }

    #[test]
    fn test_quantum_configuration() {
        let points = planar_points(2, vec![0.0, 0.0, 1.0, 1.0]);
        let searcher = searcher_for(&points)
            .with_quantum_encoding(true)
            .with_amplitude_amplification(true)
            .with_grover_iterations(5);
        assert!(searcher.is_quantum_enabled());
        assert!(searcher.is_amplification_enabled());
        assert_eq!(searcher.grover_iterations, 5);
    }

    #[test]
    fn test_empty_points() {
        let points = planar_points(0, Vec::new());
        let searcher = searcher_for(&points);
        assert!(searcher.is_empty());
        assert_eq!(searcher.len(), 0);
    }

    #[test]
    fn test_invalid_k() {
        let points = planar_points(2, vec![0.0, 0.0, 1.0, 1.0]);
        let query = arr1(&[0.5, 0.5]);
        let outcome = searcher_for(&points).query_quantum(&query.view(), 5);
        assert!(outcome.is_err());
    }
}