use crate::error::{SpatialError, SpatialResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::random::{Rng, RngExt};
use std::f64::consts::PI;
/// Binary soft-margin SVM whose kernel is approximated with random Fourier
/// features (the "quantum feature map" spans a 2^n_qubits-dimensional space,
/// clamped to [4, 256] — see `feature_dim`). Untrained until `fit` succeeds.
#[derive(Debug, Clone)]
pub struct QuantumSVMModel {
    /// Number of simulated qubits; controls the random-feature dimension.
    n_qubits: usize,
    /// Box-constraint parameter C of the soft-margin SVM; validated > 0 in `fit`.
    regularization: f64,
    /// Training rows retained after `fit` (those with alpha above threshold).
    support_vectors: Vec<Array1<f64>>,
    /// Signed dual coefficients alpha_i * y_i, aligned index-wise with `support_vectors`.
    alphas: Vec<f64>,
    /// Intercept b of the decision function; updated during SMO.
    bias: f64,
    /// Random projection matrix, shape (feature_dim, input_dim); `None` until fitted.
    random_weights: Option<Array2<f64>>,
    /// Random phase offsets drawn from [0, 2*pi); `None` until fitted.
    random_offsets: Option<Array1<f64>>,
    /// RBF-style length scale; set by the median heuristic during `fit` (default 1.0).
    bandwidth: f64,
}
impl QuantumSVMModel {
pub fn new(n_qubits: usize, regularization: f64) -> Self {
Self {
n_qubits,
regularization,
support_vectors: Vec::new(),
alphas: Vec::new(),
bias: 0.0,
random_weights: None,
random_offsets: None,
bandwidth: 1.0,
}
}
pub fn n_qubits(&self) -> usize {
self.n_qubits
}
pub fn regularization(&self) -> f64 {
self.regularization
}
pub fn num_support_vectors(&self) -> usize {
self.support_vectors.len()
}
fn feature_dim(&self) -> usize {
let raw = 1usize << self.n_qubits;
raw.clamp(4, 256)
}
fn init_random_features(&mut self, d: usize) {
let big_d = self.feature_dim();
let mut rng = scirs2_core::random::rng();
let scale = 1.0 / self.bandwidth;
let mut weights = Array2::<f64>::zeros((big_d, d));
let mut offsets = Array1::<f64>::zeros(big_d);
for i in 0..big_d {
for j in 0..d {
let u1: f64 = rng.random_range(1e-10_f64..1.0_f64);
let u2: f64 = rng.random_range(0.0_f64..1.0_f64);
let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
weights[[i, j]] = z * scale;
}
offsets[i] = rng.random_range(0.0_f64..(2.0 * PI));
}
self.random_weights = Some(weights);
self.random_offsets = Some(offsets);
}
fn quantum_feature_map(&self, x: &ArrayView1<'_, f64>) -> SpatialResult<Array1<f64>> {
let weights = self.random_weights.as_ref().ok_or_else(|| {
SpatialError::InvalidInput("Model not fitted: call fit() first".to_string())
})?;
let offsets = self.random_offsets.as_ref().ok_or_else(|| {
SpatialError::InvalidInput("Model not fitted: call fit() first".to_string())
})?;
let big_d = weights.nrows();
let scale = (2.0 / big_d as f64).sqrt();
let mut phi = Array1::<f64>::zeros(big_d);
for i in 0..big_d {
let row = weights.row(i);
let dot: f64 = row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum();
phi[i] = scale * (dot + offsets[i]).cos();
}
Ok(phi)
}
fn quantum_kernel(
&self,
a: &ArrayView1<'_, f64>,
b: &ArrayView1<'_, f64>,
) -> SpatialResult<f64> {
let phi_a = self.quantum_feature_map(a)?;
let phi_b = self.quantum_feature_map(b)?;
Ok(phi_a.iter().zip(phi_b.iter()).map(|(ai, bi)| ai * bi).sum())
}
pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> SpatialResult<()> {
let (n, d) = x.dim();
if n == 0 {
return Err(SpatialError::InvalidInput(
"Training set must be non-empty".to_string(),
));
}
if y.len() != n {
return Err(SpatialError::InvalidInput(format!(
"x has {} rows but y has {} elements",
n,
y.len()
)));
}
if self.regularization <= 0.0 {
return Err(SpatialError::InvalidInput(
"regularization (C) must be positive".to_string(),
));
}
for (i, &yi) in y.iter().enumerate() {
if (yi - 1.0).abs() > 1e-9 && (yi + 1.0).abs() > 1e-9 {
return Err(SpatialError::InvalidInput(format!(
"Label y[{}] = {} is not in {{-1, +1}}",
i, yi
)));
}
}
let mut sq_dists: Vec<f64> = Vec::with_capacity(n * (n - 1) / 2);
for i in 0..n {
for j in (i + 1)..n {
let sq: f64 = (0..d).map(|k| (x[[i, k]] - x[[j, k]]).powi(2)).sum();
sq_dists.push(sq);
}
}
if !sq_dists.is_empty() {
sq_dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let median = sq_dists[sq_dists.len() / 2];
self.bandwidth = median.sqrt().max(1e-6);
}
self.init_random_features(d);
let mut kernel_matrix = Array2::<f64>::zeros((n, n));
for i in 0..n {
for j in i..n {
let kij = self.quantum_kernel(&x.row(i), &x.row(j))?;
kernel_matrix[[i, j]] = kij;
kernel_matrix[[j, i]] = kij;
}
}
let mut alpha = vec![0.0f64; n];
let max_smo_iter = 200;
let tol = 1e-4;
for _ in 0..max_smo_iter {
let mut changed = false;
for i in 0..n {
let fi: f64 = alpha
.iter()
.enumerate()
.map(|(j, &aj)| aj * y[j] * kernel_matrix[[i, j]])
.sum::<f64>()
+ self.bias;
let ri = fi * y[i] - 1.0;
let kkt_violated = (ri < -tol && alpha[i] < self.regularization - tol)
|| (ri > tol && alpha[i] > tol);
if !kkt_violated {
continue;
}
let j = (0..n)
.filter(|&k| k != i)
.max_by(|&a, &b| {
let fa: f64 = alpha
.iter()
.enumerate()
.map(|(l, &al)| al * y[l] * kernel_matrix[[a, l]])
.sum::<f64>()
+ self.bias;
let fb: f64 = alpha
.iter()
.enumerate()
.map(|(l, &al)| al * y[l] * kernel_matrix[[b, l]])
.sum::<f64>()
+ self.bias;
(fa * y[a] - 1.0)
.abs()
.partial_cmp(&(fb * y[b] - 1.0).abs())
.unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap_or((i + 1) % n);
let eta =
kernel_matrix[[i, i]] + kernel_matrix[[j, j]] - 2.0 * kernel_matrix[[i, j]];
if eta <= 1e-12 {
continue;
}
let fj: f64 = alpha
.iter()
.enumerate()
.map(|(l, &al)| al * y[l] * kernel_matrix[[j, l]])
.sum::<f64>()
+ self.bias;
let e_i = fi - y[i];
let e_j = fj - y[j];
let alpha_j_new =
(alpha[j] + y[j] * (e_i - e_j) / eta).clamp(0.0, self.regularization);
let alpha_i_new = alpha[i] + y[i] * y[j] * (alpha[j] - alpha_j_new);
let alpha_i_new = alpha_i_new.clamp(0.0, self.regularization);
if (alpha_i_new - alpha[i]).abs() > 1e-8 {
let delta_i = alpha_i_new - alpha[i];
let delta_j = alpha_j_new - alpha[j];
alpha[i] = alpha_i_new;
alpha[j] = alpha_j_new;
let b_i = -fi
- y[i] * delta_i * kernel_matrix[[i, i]]
- y[j] * delta_j * kernel_matrix[[i, j]];
let b_j = -fj
- y[i] * delta_i * kernel_matrix[[i, j]]
- y[j] * delta_j * kernel_matrix[[j, j]];
if alpha[i] > tol && alpha[i] < self.regularization - tol {
self.bias += b_i;
} else if alpha[j] > tol && alpha[j] < self.regularization - tol {
self.bias += b_j;
} else {
self.bias += (b_i + b_j) * 0.5;
}
changed = true;
}
}
if !changed {
break;
}
}
let sv_threshold = 1e-6;
self.support_vectors.clear();
self.alphas.clear();
for i in 0..n {
if alpha[i] > sv_threshold {
self.support_vectors.push(x.row(i).to_owned());
self.alphas.push(alpha[i] * y[i]);
}
}
Ok(())
}
pub fn predict(&self, x: &Array2<f64>) -> SpatialResult<Array1<f64>> {
if self.support_vectors.is_empty() {
return Err(SpatialError::InvalidInput(
"Model not fitted: call fit() first".to_string(),
));
}
let n_test = x.nrows();
let mut preds = Array1::<f64>::zeros(n_test);
for (idx, row) in x.outer_iter().enumerate() {
let mut decision: f64 = self.bias;
for (sv, &alpha) in self.support_vectors.iter().zip(self.alphas.iter()) {
let kval = self.quantum_kernel(&row, &sv.view())?;
decision += alpha * kval;
}
preds[idx] = if decision >= 0.0 { 1.0 } else { -1.0 };
}
Ok(preds)
}
}
/// Convenience classifier: optional per-column standardisation in front of a
/// `QuantumSVMModel`. Statistics are captured at `fit` time and re-applied at
/// `predict` time.
#[derive(Debug, Clone)]
pub struct QuantumClassifier {
    /// Underlying kernel SVM.
    svm: QuantumSVMModel,
    /// When true, columns are standardised (zero mean, unit variance) before
    /// training and prediction.
    standardise: bool,
    /// Per-column means captured during `fit`; `None` until fitted.
    feature_means: Option<Array1<f64>>,
    /// Per-column standard deviations (floored at 1e-8) captured during `fit`;
    /// `None` until fitted.
    feature_stds: Option<Array1<f64>>,
}
impl QuantumClassifier {
    /// Build a classifier wrapping a fresh `QuantumSVMModel`; `standardise`
    /// selects whether columns are normalised before training/prediction.
    pub fn new(n_qubits: usize, regularization: f64, standardise: bool) -> Self {
        Self {
            svm: QuantumSVMModel::new(n_qubits, regularization),
            standardise,
            feature_means: None,
            feature_stds: None,
        }
    }

    /// Train on `x` (rows = samples) with labels `y`; captures column
    /// statistics first when standardisation is enabled.
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> SpatialResult<()> {
        let x_proc = match self.standardise {
            true => self.compute_and_store_stats(x)?,
            false => x.clone(),
        };
        self.svm.fit(&x_proc, y)
    }

    /// Predict a ±1 label per row of `x`, applying the stored statistics when
    /// standardisation is enabled.
    pub fn predict(&self, x: &Array2<f64>) -> SpatialResult<Array1<f64>> {
        let x_proc = match self.standardise {
            true => self.apply_stored_stats(x)?,
            false => x.clone(),
        };
        self.svm.predict(&x_proc)
    }

    /// Compute per-column mean/std of `x`, remember them on `self`, and return
    /// the standardised copy of `x`.
    fn compute_and_store_stats(&mut self, x: &Array2<f64>) -> SpatialResult<Array2<f64>> {
        let (n, d) = x.dim();
        if n == 0 {
            return Err(SpatialError::InvalidInput(
                "Cannot standardise an empty matrix".to_string(),
            ));
        }
        let mut means = Array1::<f64>::zeros(d);
        let mut stds = Array1::<f64>::zeros(d);
        for j in 0..d {
            let col = x.column(j);
            let mu = col.iter().sum::<f64>() / n as f64;
            let var = col.iter().map(|&v| (v - mu).powi(2)).sum::<f64>() / n as f64;
            means[j] = mu;
            // Floor the std so constant columns never divide by zero.
            stds[j] = var.sqrt().max(1e-8);
        }
        let standardized = Self::standardized_copy(x, &means, &stds);
        self.feature_means = Some(means);
        self.feature_stds = Some(stds);
        Ok(standardized)
    }

    /// Standardise `x` with the statistics captured during `fit`.
    ///
    /// # Errors
    /// `InvalidInput` when the model is unfitted or the column count differs
    /// from training time.
    fn apply_stored_stats(&self, x: &Array2<f64>) -> SpatialResult<Array2<f64>> {
        let means = self.feature_means.as_ref().ok_or_else(|| {
            SpatialError::InvalidInput("Model not fitted: call fit() first".to_string())
        })?;
        let stds = self.feature_stds.as_ref().ok_or_else(|| {
            SpatialError::InvalidInput("Model not fitted: call fit() first".to_string())
        })?;
        let d = x.ncols();
        if d != means.len() {
            return Err(SpatialError::InvalidInput(format!(
                "Expected {} features but got {}",
                means.len(),
                d
            )));
        }
        Ok(Self::standardized_copy(x, means, stds))
    }

    /// (value - mean) / std applied column-wise; shared by the fit-time and
    /// predict-time paths.
    fn standardized_copy(x: &Array2<f64>, means: &Array1<f64>, stds: &Array1<f64>) -> Array2<f64> {
        let mut out = x.clone();
        for ((_, j), v) in out.indexed_iter_mut() {
            *v = (*v - means[j]) / stds[j];
        }
        out
    }

    /// Number of support vectors in the underlying SVM.
    pub fn num_support_vectors(&self) -> usize {
        self.svm.num_support_vectors()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Two well-separated clusters of three points each, labelled +1 and -1.
    fn two_class_data() -> (Array2<f64>, Array1<f64>) {
        let features = vec![0.1, 0.1, -0.1, 0.2, 0.2, -0.1, 5.0, 5.0, 5.5, 5.0, 5.0, 5.5];
        let x = Array2::from_shape_vec((6, 2), features).expect("shape is valid");
        let labels = Array1::from_vec(vec![1.0, 1.0, 1.0, -1.0, -1.0, -1.0]);
        (x, labels)
    }

    /// Every prediction must be exactly +1 or -1.
    fn assert_hard_labels(preds: &Array1<f64>) {
        for &p in preds.iter() {
            assert!(
                (p - 1.0).abs() < 1e-9 || (p + 1.0).abs() < 1e-9,
                "prediction {p} is not ±1"
            );
        }
    }

    #[test]
    fn test_quantum_svm_fit_and_predict() {
        let (x, y) = two_class_data();
        let mut model = QuantumSVMModel::new(3, 1.0);
        model.fit(&x, &y).expect("fit should succeed");
        assert!(
            model.num_support_vectors() > 0,
            "model must produce support vectors"
        );
        let preds = model.predict(&x).expect("predict should succeed");
        assert_eq!(preds.len(), x.nrows());
        assert_hard_labels(&preds);
    }

    #[test]
    fn test_quantum_classifier_with_standardisation() {
        let (x, y) = two_class_data();
        let mut clf = QuantumClassifier::new(3, 1.0, true);
        clf.fit(&x, &y).expect("fit should succeed");
        let preds = clf.predict(&x).expect("predict should succeed");
        assert_eq!(preds.len(), x.nrows());
        assert_hard_labels(&preds);
    }

    #[test]
    fn test_svm_rejects_bad_labels() {
        let x = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).expect("shape is valid");
        let bad_labels = Array1::from_vec(vec![0.0, 1.0]);
        let mut model = QuantumSVMModel::new(2, 1.0);
        assert!(model.fit(&x, &bad_labels).is_err());
    }

    #[test]
    fn test_predict_before_fit_errors() {
        let untrained = QuantumSVMModel::new(2, 1.0);
        let x_test = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).expect("shape is valid");
        assert!(untrained.predict(&x_test).is_err());
    }
}