use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::ndarray::{Array1, Array2};
use super::gram_schmidt::gram_schmidt;
/// Returns the squared norm of the lattice vector `sum_j coeffs[j] * b_j` projected onto
/// the Gram–Schmidt directions `b*_k, ..., b*_{n-1}` (`n = coeffs.len()`).
///
/// For every level `i >= k` this adds
/// `(coeffs[i] + sum_{j > i} coeffs[j] * mu[[j, i]])^2 * bnorm_sq[i]`.
/// With `k == 0` this is the full squared norm of the combination; with `k >= n` it is `0.0`.
pub fn projected_norm_sq(
    coeffs: &[f64],
    mu: &Array2<f64>,
    bnorm_sq: &Array1<f64>,
    k: usize,
) -> f64 {
    let mut norm_sq = 0.0;
    for (level, &own) in coeffs.iter().enumerate().skip(k) {
        // Coordinate of the combination along b*_level: own coefficient plus the
        // Gram–Schmidt spill-over from every higher-indexed coefficient.
        let mut sigma = own;
        for (upper, &c) in coeffs.iter().enumerate().skip(level + 1) {
            sigma += c * mu[[upper, level]];
        }
        norm_sq += sigma * sigma * bnorm_sq[level];
    }
    norm_sq
}
/// Searches for a shortest non-zero vector of the lattice spanned by the rows of `basis`.
///
/// The search is seeded with the shortest basis row whose squared norm exceeds `1e-14`.
/// `enumerate_svp` then looks for integer combinations of the rows that are strictly
/// shorter, visiting at most `max_nodes` search-tree nodes.
///
/// # Arguments
/// * `basis` - lattice basis, one basis vector per row.
/// * `max_nodes` - budget on the number of enumeration nodes visited.
///
/// # Returns
/// The shortest lattice vector found, with `basis.ncols()` entries.
///
/// # Errors
/// * `OptimizeError::ValueError` if `basis` has no rows or no columns.
/// * `OptimizeError::ComputationError` if the resulting vector has norm below `1e-10`
///   (for example when every basis row is zero).
pub fn solve_svp(basis: &Array2<f64>, max_nodes: usize) -> OptimizeResult<Array1<f64>> {
    let n = basis.nrows();
    let d = basis.ncols();
    if n == 0 || d == 0 {
        return Err(OptimizeError::ValueError(
            "SVP: basis must be non-empty".to_string(),
        ));
    }
    let (mu, bnorm_sq) = gram_schmidt(basis);

    let row_norm_sq = |i: usize| -> f64 { (0..d).map(|j| basis[[i, j]].powi(2)).sum() };

    // Seed with the shortest row above the 1e-14 noise floor (ties keep the earliest row).
    // Every row is considered, so a zero first row can no longer mask the usable rows
    // after it.
    let mut shortest: Option<(usize, f64)> = None;
    for i in 0..n {
        let norm_sq = row_norm_sq(i);
        let beats_current = match shortest {
            Some((_, best)) => norm_sq < best,
            None => true,
        };
        if norm_sq > 1e-14 && beats_current {
            shortest = Some((i, norm_sq));
        }
    }
    // No usable row: fall back to row 0 so the zero-vector check below reports the
    // degenerate basis.
    let (seed_row, seed_norm_sq) = shortest.unwrap_or_else(|| (0, row_norm_sq(0)));
    let mut seed_coeffs = vec![0.0; n];
    seed_coeffs[seed_row] = 1.0;

    let mut nodes_visited = 0usize;
    let mut coeffs = vec![0.0f64; n];
    let (found_norm_sq, found_coeffs) = enumerate_svp(
        &mu,
        &bnorm_sq,
        n,
        &mut coeffs,
        0,
        seed_norm_sq,
        &mut nodes_visited,
        max_nodes,
    )?;
    // The enumeration echoes the bound back unchanged when nothing strictly shorter
    // was found, so only a strictly smaller norm replaces the seed.
    let best_coeffs = if found_norm_sq < seed_norm_sq {
        found_coeffs
    } else {
        seed_coeffs
    };

    let mut lattice_vec = Array1::<f64>::zeros(d);
    for (i, &c) in best_coeffs.iter().enumerate() {
        for j in 0..d {
            lattice_vec[j] += c * basis[[i, j]];
        }
    }
    let norm: f64 = lattice_vec.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm < 1e-10 {
        return Err(OptimizeError::ComputationError(
            "SVP enumeration found a zero vector; basis may be degenerate".to_string(),
        ));
    }
    Ok(lattice_vec)
}
/// Depth-first enumeration of integer coefficient vectors shorter than `current_bound`.
///
/// The squared norm of `sum_i c_i * b_i` is
/// `sum_l (c_l + sum_{j > l} c_j * mu[[j, l]])^2 * bnorm_sq[l]`. The term for level `l`
/// depends only on `c_l` and higher-indexed coefficients. Coefficients are therefore fixed
/// from the last level downwards: the call at depth `k` chooses `coeffs[n - 1 - k]`, after
/// its ancestors have fixed `coeffs[n - k..]`. Only in this order are the projected center
/// and the partial norm of the higher levels known when a level is enumerated, which is
/// what makes the pruning bound valid.
///
/// Candidates for the current coefficient are visited in Schnorr–Euchner zig-zag order
/// around the projected center, i.e. in order of non-decreasing contribution. The first
/// candidate that cannot beat the best norm found so far therefore ends the level.
///
/// # Arguments
/// * `mu`, `bnorm_sq` - Gram–Schmidt coefficients and squared Gram–Schmidt norms, indexed
///   as in `projected_norm_sq`.
/// * `n` - number of basis vectors; `coeffs.len()` is expected to equal `n`.
/// * `coeffs` - scratch coefficient vector. Entries not yet fixed (`coeffs[..n - k]`) are
///   expected to be zero on entry (an all-zero vector for the top-level call). The entry
///   chosen at this depth is reset to zero before returning.
/// * `k` - recursion depth, i.e. how many coefficients are already fixed; callers start at 0.
/// * `current_bound` - only combinations with squared norm strictly below this are reported.
/// * `nodes_visited`, `max_nodes` - shared node counter and node budget. Once the budget is
///   exceeded, the search unwinds and returns the best result found so far.
///
/// # Returns
/// `(norm_sq, coeffs)` for the shortest non-zero combination found with
/// `1e-14 < norm_sq < current_bound`. If there is none, `current_bound` is returned unchanged
/// together with a copy of the scratch vector, which callers should ignore.
#[allow(clippy::too_many_arguments)] // private recursive worker; signature kept for existing callers
fn enumerate_svp(
    mu: &Array2<f64>,
    bnorm_sq: &Array1<f64>,
    n: usize,
    coeffs: &mut [f64],
    k: usize,
    current_bound: f64,
    nodes_visited: &mut usize,
    max_nodes: usize,
) -> OptimizeResult<(f64, Vec<f64>)> {
    *nodes_visited += 1;
    if *nodes_visited > max_nodes {
        return Ok((current_bound, coeffs.to_vec()));
    }
    if k >= n {
        // Leaf: every coefficient is fixed. The all-zero combination is never an answer.
        if coeffs.iter().all(|&c| c == 0.0) {
            return Ok((current_bound, coeffs.to_vec()));
        }
        let norm_sq = projected_norm_sq(coeffs, mu, bnorm_sq, 0);
        if norm_sq > 1e-14 && norm_sq < current_bound {
            return Ok((norm_sq, coeffs.to_vec()));
        }
        return Ok((current_bound, coeffs.to_vec()));
    }

    // Gram–Schmidt level enumerated at this depth; all higher levels are already fixed.
    let level = n - 1 - k;
    // Shift of this level's projection caused by the fixed higher coefficients.
    let sigma_parent: f64 = ((level + 1)..n).map(|j| coeffs[j] * mu[[j, level]]).sum();
    let upper_contrib = projected_norm_sq_partial(coeffs, mu, bnorm_sq, level + 1, n);
    if current_bound - upper_contrib <= 0.0 {
        return Ok((current_bound, coeffs.to_vec()));
    }

    let bk = bnorm_sq[level];
    if bk < 1e-14 {
        // Numerically vanishing Gram–Schmidt norm: the coefficient is pinned to zero.
        coeffs[level] = 0.0;
        return enumerate_svp(
            mu,
            bnorm_sq,
            n,
            coeffs,
            k + 1,
            current_bound,
            nodes_visited,
            max_nodes,
        );
    }

    let center = -sigma_parent;
    let nearest = center.round();
    // Step first towards the side of `center` that `nearest` fell short of. Combined with
    // alternating offsets, this makes |c - center| non-decreasing along the sequence.
    let step = if center >= nearest { 1.0 } else { -1.0 };

    let mut best_norm = current_bound;
    let mut best_c = coeffs.to_vec();
    // Zig-zag offset from `nearest`: 0, 1, -1, 2, -2, ... (scaled by `step`).
    let mut delta = 0.0f64;
    loop {
        let c = nearest + step * delta;
        let sigma_k = sigma_parent + c;
        if upper_contrib + sigma_k * sigma_k * bk >= best_norm {
            // Every later candidate lies at least as far from `center`, and `best_norm`
            // only shrinks, so none of them can improve on the best found.
            break;
        }
        coeffs[level] = c;
        let (sub_norm, sub_coeffs) = enumerate_svp(
            mu,
            bnorm_sq,
            n,
            coeffs,
            k + 1,
            best_norm,
            nodes_visited,
            max_nodes,
        )?;
        if sub_norm < best_norm {
            best_norm = sub_norm;
            best_c = sub_coeffs;
        }
        if *nodes_visited > max_nodes {
            break;
        }
        delta = if delta > 0.0 { -delta } else { 1.0 - delta };
    }
    coeffs[level] = 0.0;
    Ok((best_norm, best_c))
}
/// Sums the squared Gram–Schmidt projections of `sum_j coeffs[j] * b_j` over levels
/// `from..to`.
///
/// Level `i` contributes `(coeffs[i] + sum_{j > i} coeffs[j] * mu[[j, i]])^2 * bnorm_sq[i]`.
/// The inner sum always runs up to `coeffs.len()`, so every higher coefficient is read even
/// when `to < coeffs.len()`. An empty range (`from >= to`) yields `0.0`.
///
/// # Panics
/// Panics on out-of-bounds indexing if `to > coeffs.len()`.
fn projected_norm_sq_partial(
    coeffs: &[f64],
    mu: &Array2<f64>,
    bnorm_sq: &Array1<f64>,
    from: usize,
    to: usize,
) -> f64 {
    let len = coeffs.len();
    (from..to).fold(0.0, |acc, level| {
        // Projection coefficient on b*_level, accumulated left-to-right from the own
        // coefficient through each higher-indexed one.
        let sigma = ((level + 1)..len)
            .fold(coeffs[level], |s, upper| s + coeffs[upper] * mu[[upper, level]]);
        acc + sigma * sigma * bnorm_sq[level]
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Squared Euclidean norm of a vector returned by `solve_svp`.
    fn sq_norm(v: &Array1<f64>) -> f64 {
        v.iter().map(|x| x * x).sum()
    }

    #[test]
    fn test_svp_2d_known_shortest() {
        // The lattice 3Z x 2Z has (0, ±2) as its shortest non-zero vectors, so norm^2 must be
        // exactly 4. Accepting 9 as well would let a non-shortest answer pass.
        let basis = array![[3.0, 0.0], [0.0, 2.0]];
        let v = solve_svp(&basis, 100_000).expect("SVP should succeed");
        let norm_sq = sq_norm(&v);
        assert!(
            (norm_sq - 4.0).abs() < 1e-6,
            "Expected norm^2 = 4, got {}",
            norm_sq
        );
    }

    #[test]
    fn test_svp_integer_lattice_shortest_vector() {
        let basis = array![[1.0, 0.0], [0.0, 1.0]];
        let v = solve_svp(&basis, 100_000).expect("SVP should succeed");
        let norm_sq = sq_norm(&v);
        assert!((norm_sq - 1.0).abs() < 1e-6, "Expected norm^2 = 1, got {}", norm_sq);
    }

    #[test]
    fn test_svp_empty_basis_is_rejected() {
        let no_rows = Array2::<f64>::zeros((0, 2));
        assert!(matches!(
            solve_svp(&no_rows, 10),
            Err(OptimizeError::ValueError(_))
        ));
        let no_cols = Array2::<f64>::zeros((2, 0));
        assert!(matches!(
            solve_svp(&no_cols, 10),
            Err(OptimizeError::ValueError(_))
        ));
    }

    #[test]
    fn test_projected_norm_sq_zero_for_zero_vector() {
        let basis = array![[1.0, 0.0], [0.0, 1.0]];
        let (mu, bnorm_sq) = gram_schmidt(&basis);
        let coeffs = vec![0.0, 0.0];
        let pn = projected_norm_sq(&coeffs, &mu, &bnorm_sq, 0);
        assert!(pn.abs() < 1e-12, "Zero vector should have zero projected norm, got {}", pn);
    }
}