use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::ndarray::{Array1, Array2};
use super::gram_schmidt::{gram_schmidt, size_reduce_step, update_gram_schmidt_after_swap};
/// Tuning parameters for LLL (Lenstra–Lenstra–Lovász) lattice basis reduction.
#[derive(Debug, Clone)]
pub struct LLLConfig {
/// Lovász condition parameter; the classic textbook value is 0.75 (see `Default`).
pub delta: f64,
/// Size-reduction threshold on |mu[k][j]|; set slightly above 0.5 in
/// `Default` (0.501) to absorb floating-point rounding.
pub eta: f64,
/// Upper bound on outer-loop passes in `reduce` before it gives up with a
/// `ConvergenceError`.
pub max_iterations: usize,
}
impl Default for LLLConfig {
fn default() -> Self {
LLLConfig {
delta: 0.75,
eta: 0.501,
max_iterations: 10_000,
}
}
}
/// Output of a successful LLL reduction.
#[derive(Debug, Clone)]
pub struct LLLResult {
/// The LLL-reduced basis; rows are the reduced lattice basis vectors.
pub reduced_basis: Array2<f64>,
/// Accumulated row-operation matrix U with `reduced_basis = U * original_basis`
/// (expected unimodular, det(U) = ±1; see `test_lll_transformation_unimodular`).
pub transformation: Array2<f64>,
/// Number of row swaps performed during reduction.
pub n_swaps: usize,
/// Number of size-reduction steps performed during reduction.
pub n_size_reductions: usize,
}
/// Driver that applies LLL reduction using a stored [`LLLConfig`].
pub struct LLLReducer {
/// Parameters used by `reduce` and `verify`.
config: LLLConfig,
}
impl LLLReducer {
    /// Creates a reducer that uses `config` for all subsequent reductions.
    pub fn new(config: LLLConfig) -> Self {
        LLLReducer { config }
    }

    /// Performs LLL reduction on `basis`, whose rows are the basis vectors.
    ///
    /// Returns the reduced basis `B'`, the transformation matrix `U` with
    /// `B' = U * B`, and counters for the swaps and size reductions applied.
    ///
    /// # Errors
    ///
    /// * `OptimizeError::ValueError` if `basis` is empty, or if the configured
    ///   parameters fall outside the ranges LLL requires
    ///   (`0.25 < delta <= 1.0`, `0.5 <= eta < 1.0`).
    /// * `OptimizeError::ConvergenceError` if the algorithm does not finish
    ///   within `max_iterations` outer-loop passes.
    pub fn reduce(&self, basis: &Array2<f64>) -> OptimizeResult<LLLResult> {
        let n = basis.nrows();
        let d = basis.ncols();
        if n == 0 || d == 0 {
            return Err(OptimizeError::ValueError(
                "Basis matrix must be non-empty".to_string(),
            ));
        }
        let delta = self.config.delta;
        let eta = self.config.eta;
        // LLL's termination and quality guarantees hold only for
        // delta in (1/4, 1] and eta in [1/2, 1). Reject bad parameters up
        // front instead of letting them surface later as an opaque
        // ConvergenceError or a silently non-reduced result.
        if !(delta > 0.25 && delta <= 1.0) {
            return Err(OptimizeError::ValueError(format!(
                "LLL delta must lie in (0.25, 1.0], got {}",
                delta
            )));
        }
        if !(0.5..1.0).contains(&eta) {
            return Err(OptimizeError::ValueError(format!(
                "LLL eta must lie in [0.5, 1.0), got {}",
                eta
            )));
        }
        let mut b = basis.to_owned();
        // U accumulates every row operation so that b = U * basis throughout.
        let mut u = Array2::<f64>::eye(n);
        // mu: Gram-Schmidt coefficients; bnorm_sq: squared norms of the
        // orthogonalized vectors b*_i.
        let (mut mu, mut bnorm_sq) = gram_schmidt(&b);
        let mut n_swaps = 0usize;
        let mut n_size_reductions = 0usize;
        let mut k = 1usize;
        let mut iteration = 0usize;
        while k < n {
            if iteration >= self.config.max_iterations {
                return Err(OptimizeError::ConvergenceError(format!(
                    "LLL did not converge within {} iterations (n_swaps={})",
                    self.config.max_iterations, n_swaps
                )));
            }
            iteration += 1;
            // Size-reduce row k against rows k-1 .. 0 so |mu[k][j]| <= eta.
            // (size reduction leaves the Gram-Schmidt norms unchanged, so
            // bnorm_sq needs no update here.)
            for j in (0..k).rev() {
                if mu[[k, j]].abs() > eta {
                    size_reduce_step(&mut b, &mut u, &mut mu, k, j);
                    n_size_reductions += 1;
                }
            }
            // Lovász condition:
            // ||b*_k||^2 >= (delta - mu_{k,k-1}^2) * ||b*_{k-1}||^2.
            let mu_k_km1 = mu[[k, k - 1]];
            let lovasz_rhs = (delta - mu_k_km1 * mu_k_km1) * bnorm_sq[k - 1];
            if bnorm_sq[k] >= lovasz_rhs {
                k += 1;
            } else {
                // Condition violated: swap rows k-1 and k (in both b and u),
                // patch the Gram-Schmidt data incrementally, and step back.
                swap_rows(&mut b, k - 1, k);
                swap_rows(&mut u, k - 1, k);
                n_swaps += 1;
                update_gram_schmidt_after_swap(&b, &mut mu, &mut bnorm_sq, k);
                if k > 1 {
                    k -= 1;
                }
            }
        }
        Ok(LLLResult {
            reduced_basis: b,
            transformation: u,
            n_swaps,
            n_size_reductions,
        })
    }

    /// Checks that `result.reduced_basis` is LLL-reduced for this reducer's
    /// `delta`/`eta`: every |mu[i][j]| (j < i) is at most `eta`, and each
    /// consecutive pair satisfies the Lovász condition, both within a small
    /// numerical tolerance. An empty basis is trivially reduced.
    pub fn verify(&self, result: &LLLResult) -> bool {
        let n = result.reduced_basis.nrows();
        if n == 0 {
            return true;
        }
        let (mu, bnorm_sq) = gram_schmidt(&result.reduced_basis);
        let tol = 1e-6;
        let delta = self.config.delta;
        let eta = self.config.eta;
        // Size-reduction condition.
        for i in 0..n {
            for j in 0..i {
                if mu[[i, j]].abs() > eta + tol {
                    return false;
                }
            }
        }
        // Lovász condition, with the tolerance scaled by the magnitude of
        // the right-hand side so large and small bases are treated fairly.
        for k in 1..n {
            let mu_k_km1 = mu[[k, k - 1]];
            let lhs = bnorm_sq[k];
            let rhs = (delta - mu_k_km1 * mu_k_km1) * bnorm_sq[k - 1];
            if lhs < rhs - tol * rhs.abs().max(1.0) {
                return false;
            }
        }
        true
    }
}
/// Exchanges rows `i` and `j` of `mat`, element by element.
fn swap_rows(mat: &mut Array2<f64>, i: usize, j: usize) {
    for col in 0..mat.ncols() {
        mat.swap([i, col], [j, col]);
    }
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
// Determinant by cofactor expansion along row 0. Exponential-time, but the
// matrices in these tests are tiny; 1x1..3x3 are hard-coded, larger sizes
// recurse over minors.
fn det_naive(m: &Array2<f64>) -> f64 {
let n = m.nrows();
assert_eq!(n, m.ncols());
if n == 1 {
return m[[0, 0]];
}
if n == 2 {
return m[[0, 0]] * m[[1, 1]] - m[[0, 1]] * m[[1, 0]];
}
if n == 3 {
return m[[0, 0]] * (m[[1, 1]] * m[[2, 2]] - m[[1, 2]] * m[[2, 1]])
- m[[0, 1]] * (m[[1, 0]] * m[[2, 2]] - m[[1, 2]] * m[[2, 0]])
+ m[[0, 2]] * (m[[1, 0]] * m[[2, 1]] - m[[1, 1]] * m[[2, 0]]);
}
// General case: sum signed products of row-0 entries and their minors.
let mut d = 0.0;
for col in 0..n {
let mut minor = Array2::<f64>::zeros((n - 1, n - 1));
for r in 1..n {
let mut c2 = 0;
for c in 0..n {
if c == col {
continue;
}
minor[[r - 1, c2]] = m[[r, c]];
c2 += 1;
}
}
let sign = if col % 2 == 0 { 1.0 } else { -1.0 };
d += sign * m[[0, col]] * det_naive(&minor);
}
d
}
// Naive triple-loop dense matrix product, kept local so the tests don't
// depend on any linear-algebra helpers outside this module.
fn matmul(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
let n = a.nrows();
let m = b.ncols();
let k = a.ncols();
let mut c = Array2::<f64>::zeros((n, m));
for i in 0..n {
for j in 0..m {
for l in 0..k {
c[[i, j]] += a[[i, l]] * b[[l, j]];
}
}
}
c
}
// An orthonormal basis is already LLL-reduced: rows should keep unit norm.
#[test]
fn test_lll_identity_basis_unchanged() {
let basis = Array2::<f64>::eye(3);
let reducer = LLLReducer::new(LLLConfig::default());
let result = reducer.reduce(&basis).expect("LLL should succeed");
for i in 0..3 {
let norm_sq: f64 = (0..3).map(|j| result.reduced_basis[[i, j]].powi(2)).sum();
assert!((norm_sq - 1.0).abs() < 1e-8, "Norm of row {} should be 1", i);
}
}
// The output must pass the reducer's own verify() check (size-reduction
// plus Lovász condition).
#[test]
fn test_lll_result_satisfies_lovász_condition() {
let basis = array![
[1.0, 1.0, 1.0],
[-1.0, 0.0, 2.0],
[3.0, 5.0, 6.0]
];
let reducer = LLLReducer::new(LLLConfig::default());
let result = reducer.reduce(&basis).expect("LLL should succeed");
assert!(
reducer.verify(&result),
"LLL result should satisfy Lovász condition"
);
}
// The transformation accumulates only swaps and integer size-reduction
// steps, so its determinant must be ±1 (unimodular).
#[test]
fn test_lll_transformation_unimodular() {
let basis = array![[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]];
let reducer = LLLReducer::new(LLLConfig::default());
let result = reducer.reduce(&basis).expect("LLL should succeed");
let det_u = det_naive(&result.transformation);
assert!(
(det_u.abs() - 1.0).abs() < 1e-6,
"det(U) should be ±1, got {}",
det_u
);
}
// Reconstruction check: the reported transformation must actually map the
// original basis onto the reduced one, i.e. U * B == B_red entrywise.
#[test]
fn test_lll_lattice_equivalence() {
let basis = array![[1.0, 2.0], [3.0, 4.0]];
let reducer = LLLReducer::new(LLLConfig::default());
let result = reducer.reduce(&basis).expect("LLL should succeed");
let reconstructed = matmul(&result.transformation, &basis);
for i in 0..2 {
for j in 0..2 {
assert!(
(reconstructed[[i, j]] - result.reduced_basis[[i, j]]).abs() < 1e-6,
"U * B_orig[{},{}] = {} != B_red[{},{}] = {}",
i, j, reconstructed[[i, j]], i, j, result.reduced_basis[[i, j]]
);
}
}
}
// Small 2x2 example: after reduction the first vector should be short
// (norm^2 <= 2 for this lattice).
#[test]
fn test_lll_2x2_classic() {
let basis = array![[1.0, 1.0], [0.0, 1.0]];
let reducer = LLLReducer::new(LLLConfig::default());
let result = reducer.reduce(&basis).expect("LLL should succeed");
assert!(reducer.verify(&result));
let norm0_sq: f64 = result.reduced_basis.row(0).iter().map(|x| x * x).sum();
assert!(norm0_sq <= 2.0 + 1e-8, "First vector should be short, got norm^2={}", norm0_sq);
}
// Smoke test on a denser 4x4 basis: reduce() must return Ok within the
// iteration budget rather than hitting ConvergenceError.
#[test]
fn test_lll_terminates_within_max_iterations() {
let basis = array![
[10.0, 3.0, -1.0, 2.0],
[-4.0, 7.0, 2.0, 1.0],
[1.0, -2.0, 5.0, 3.0],
[2.0, 1.0, -3.0, 6.0]
];
let config = LLLConfig {
max_iterations: 10_000,
..Default::default()
};
let reducer = LLLReducer::new(config);
let result = reducer.reduce(&basis).expect("LLL should terminate");
assert!(result.n_swaps < 10_000);
}
// For this particular basis, the reduced first vector is expected to be
// no longer than the shortest original basis vector.
#[test]
fn test_lll_known_reduction() {
let basis = array![[1.0, 1.0, 1.0], [-1.0, 0.0, 2.0], [3.0, 5.0, 6.0]];
let reducer = LLLReducer::new(LLLConfig::default());
let result = reducer.reduce(&basis).expect("LLL should succeed");
let orig_norms: Vec<f64> = (0..3)
.map(|i| (0..3).map(|j| basis[[i, j]].powi(2)).sum::<f64>())
.collect();
let min_orig_norm: f64 = orig_norms.into_iter().fold(f64::MAX, f64::min);
let reduced_norm0: f64 = (0..3).map(|j| result.reduced_basis[[0, j]].powi(2)).sum();
assert!(
reduced_norm0 <= min_orig_norm + 1e-6,
"LLL first vector norm^2={} should be <= min original norm^2={}",
reduced_norm0, min_orig_norm
);
}
// Unit test for the size-reduction helper in isolation: a single step on a
// basis with mu[1][0] = 5 must bring |mu[1][0]| under the eta threshold.
#[test]
fn test_size_reduce_step_reduces_mu() {
use super::super::gram_schmidt::gram_schmidt;
let basis = array![[1.0, 0.0], [5.0, 1.0]];
let (mu_before, _) = gram_schmidt(&basis);
assert!(mu_before[[1, 0]].abs() > 0.5, "Setup: mu[1][0] should be large");
let mut b = basis.clone();
let mut u = Array2::<f64>::eye(2);
let (mut mu, mut _bnorm_sq) = gram_schmidt(&b);
size_reduce_step(&mut b, &mut u, &mut mu, 1, 0);
assert!(
mu[[1, 0]].abs() <= 0.501 + 1e-10,
"After size-reduce, mu[1][0]={} should be <= 0.501",
mu[[1, 0]]
);
}
}