use crate::errors::AlkahestError;
use rug::{Assign, Float, Integer, Rational};
use std::fmt;
/// Errors reported by the LLL lattice-reduction routines in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LatticeError {
    /// The input basis contained no rows.
    EmptyBasis,
    /// A basis row's length differs from the length of row 0.
    RaggedBasis {
        /// Index of the first row whose length does not match.
        row: usize,
        /// Length of row 0, used as the ambient dimension.
        expected_cols: usize,
        /// Length of the offending row.
        got_cols: usize,
    },
    /// The Lovász factor δ was not strictly between ¼ and 1; `provided` is the rejected value.
    InvalidDelta { provided: Rational },
    /// The reduction loop exceeded its internal budget; `iterations` is the
    /// guard counter value at the moment it aborted.
    IterationLimit { iterations: usize },
}
impl fmt::Display for LatticeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LatticeError::EmptyBasis => write!(f, "LLL expects at least one basis row"),
LatticeError::RaggedBasis {
row,
expected_cols,
got_cols,
} => write!(
f,
"row {row} has length {got_cols}; expected ambient dimension {expected_cols}"
),
LatticeError::InvalidDelta { .. } => {
write!(f, "LLL Lovász factor δ must lie strictly between ¼ and 1")
}
LatticeError::IterationLimit { iterations } => write!(
f,
"LLL reduction aborted after {iterations} swaps (degenerate span or oversized basis)"
),
}
}
}
impl std::error::Error for LatticeError {}
/// Project-wide diagnostic metadata for lattice errors.
impl AlkahestError for LatticeError {
    /// Stable diagnostic code per variant, `E-LAT-001` through `E-LAT-004`.
    fn code(&self) -> &'static str {
        match self {
            Self::EmptyBasis => "E-LAT-001",
            Self::RaggedBasis { .. } => "E-LAT-002",
            Self::InvalidDelta { .. } => "E-LAT-003",
            Self::IterationLimit { .. } => "E-LAT-004",
        }
    }

    /// Suggested fix for the caller; every variant has one.
    fn remediation(&self) -> Option<&'static str> {
        let hint = match self {
            Self::EmptyBasis => "pass a non-empty list of equally long integer coefficient rows",
            Self::RaggedBasis { .. } => {
                "pad or trim rows so every basis vector lies in ℤ^m for fixed m"
            }
            Self::InvalidDelta { .. } => {
                "use the default δ = ¾, or choose another rational strictly between ¼ and 1"
            }
            Self::IterationLimit { .. } => {
                "check for rank-deficient rows, reduce dimension, or report a bug with a minimal basis"
            }
        };
        Some(hint)
    }
}
/// Exact inner product of an integer row with a rational vector.
///
/// Elements are paired with `zip`, so trailing entries of the longer slice are ignored.
#[inline]
fn dot_int_rat(row: &[Integer], v: &[Rational]) -> Rational {
    let mut acc = Rational::new();
    for (zi, vi) in row.iter().zip(v) {
        // One product per term; no throwaway zero temporary or extra `assign` copy.
        acc += Rational::from(zi) * vi;
    }
    acc
}
/// Exact inner product of two rational vectors.
///
/// Elements are paired with `zip`, so trailing entries of the longer slice are ignored.
fn dot_rat(a: &[Rational], b: &[Rational]) -> Rational {
    let mut acc = Rational::new();
    for (x, y) in a.iter().zip(b) {
        // Multiply through references: no clone of either operand per term.
        acc += Rational::from(x * y);
    }
    acc
}
/// Lifts an integer row into ℚ, entry by entry, preserving order.
fn int_row_as_rat(row: &[Integer]) -> Vec<Rational> {
    let mut lifted = Vec::with_capacity(row.len());
    for entry in row {
        lifted.push(Rational::from(entry));
    }
    lifted
}
/// Classical Gram–Schmidt orthogonalisation of the integer rows of `basis`,
/// computed exactly over ℚ.
///
/// Returns `(mu, star, b_norm_sq)`:
/// - `star[i]` is the orthogonalised vector b*_i,
/// - `b_norm_sq[i]` is ‖b*_i‖²,
/// - `mu[i][j] = ⟨b_i, b*_j⟩ / ‖b*_j‖²` for `j < i`; entries with `j >= i` are 0.
///
/// If some b*_j is the zero vector (linearly dependent rows), `mu[i][j]` is left
/// at 0 rather than dividing by zero, which would panic.
/// Callers are expected to have validated that all rows share one length.
fn gram_schmidt_rows(
    basis: &[Vec<Integer>],
) -> (Vec<Vec<Rational>>, Vec<Vec<Rational>>, Vec<Rational>) {
    let n = basis.len();
    let mut star: Vec<Vec<Rational>> = Vec::with_capacity(n);
    let mut mu = vec![vec![Rational::new(); n]; n];
    let mut b_norm_sq: Vec<Rational> = Vec::with_capacity(n);
    for (i, row) in basis.iter().enumerate() {
        let mut v = int_row_as_rat(row);
        for j in 0..i {
            // A zero b*_j contributes no projection; skipping also avoids 0/0.
            if b_norm_sq[j].is_zero() {
                continue;
            }
            let coeff = dot_int_rat(row, &star[j]) / &b_norm_sq[j];
            for (vt, st) in v.iter_mut().zip(&star[j]) {
                *vt -= Rational::from(&coeff * st);
            }
            mu[i][j] = coeff;
        }
        b_norm_sq.push(dot_rat(&v, &v));
        star.push(v);
    }
    (mu, star, b_norm_sq)
}
/// Rounds `x` to the nearest integer; values exactly halfway between two
/// integers are rounded away from zero.
///
/// Rounding is done exactly on the rational. Converting through a fixed-precision
/// float loses low bits once the integer part exceeds the float's precision.
/// It can also fail on non-finite conversion, which yields a wrong multiplier.
fn nearest_integer_rational(x: &Rational) -> Integer {
    Integer::from(x.round_ref())
}
/// Checks that `basis` is non-empty and rectangular.
///
/// Returns the common row length (the ambient dimension). Otherwise reports
/// `EmptyBasis`, or `RaggedBasis` for the first row whose length differs from row 0.
fn validate_rows(basis: &[Vec<Integer>]) -> Result<usize, LatticeError> {
    let expected_cols = match basis.first() {
        Some(first) => first.len(),
        None => return Err(LatticeError::EmptyBasis),
    };
    match basis.iter().position(|r| r.len() != expected_cols) {
        Some(row) => Err(LatticeError::RaggedBasis {
            row,
            expected_cols,
            got_cols: basis[row].len(),
        }),
        None => Ok(expected_cols),
    }
}
/// Accepts the Lovász factor only when ¼ < δ < 1 (both bounds exclusive).
fn validate_delta(delta: &Rational) -> Result<(), LatticeError> {
    let quarter = Rational::from((1i32, 4i32));
    let one = Rational::from(1u32);
    let in_open_interval = *delta > quarter && *delta < one;
    if in_open_interval {
        Ok(())
    } else {
        Err(LatticeError::InvalidDelta {
            provided: delta.clone(),
        })
    }
}
fn size_reduce_single(
basis: &mut [Vec<Integer>],
mu: &[Vec<Rational>],
b_norm_sq: &[Rational],
k: usize,
) -> bool {
let mut altered = false;
for j in (0..k).rev() {
if b_norm_sq[j].is_zero() {
continue;
}
let mij = &mu[k][j];
let q = nearest_integer_rational(mij);
if q == 0 {
continue;
}
altered = true;
for col in 0..basis[k].len() {
let bjk = basis[j][col].clone();
basis[k][col] -= &(q.clone() * bjk);
}
return altered;
}
altered
}
/// Tests the Lovász condition at index `k`:
/// ‖b*_k‖² ≥ (δ − μ_{k,k−1}²) · ‖b*_{k−1}‖².
///
/// Index 0 always passes. If ‖b*_{k−1}‖² is zero the check reports failure.
fn lovasz_ok(b_norm_sq: &[Rational], mu: &[Vec<Rational>], delta: &Rational, k: usize) -> bool {
    match k.checked_sub(1) {
        None => true,
        Some(prev) => {
            let prev_norm = &b_norm_sq[prev];
            if prev_norm.is_zero() {
                return false;
            }
            let coeff = &mu[k][prev];
            let coeff_sq = Rational::from(coeff * coeff);
            let bound = (delta.clone() - coeff_sq) * prev_norm;
            b_norm_sq[k] >= bound
        }
    }
}
/// Runs the LLL reduction of `basis_rows` with Lovász factor `delta`.
///
/// Works on a copy of the input and returns the reduced rows. Gram–Schmidt data
/// is recomputed exactly (over ℚ) after every change to the basis.
///
/// # Errors
/// - `LatticeError::EmptyBasis` / `LatticeError::RaggedBasis` from row validation.
/// - `LatticeError::InvalidDelta` when `delta` is not strictly between ¼ and 1.
/// - `LatticeError::IterationLimit` once more than `MAX_LLL_SWAPS` row swaps
///   have been performed; `iterations` then holds the swap count.
fn lll_reduce_once(
    basis_rows: &[Vec<Integer>],
    delta: &Rational,
) -> Result<Vec<Vec<Integer>>, LatticeError> {
    // Budget on row swaps before aborting with `IterationLimit`.
    const MAX_LLL_SWAPS: usize = 2_000_000;

    validate_rows(basis_rows)?;
    validate_delta(delta)?;
    let n = basis_rows.len();
    let mut basis: Vec<Vec<Integer>> = basis_rows.to_vec();
    let mut k: usize = 1;
    let mut swaps: usize = 0;
    while k < n {
        // Size-reduce row k until a pass changes nothing. The Gram–Schmidt data
        // from that final pass still describes `basis`, so reuse it below.
        let (mu, b_norm_sq) = loop {
            let (mu, _, b_norm_sq) = gram_schmidt_rows(&basis);
            if !size_reduce_single(&mut basis, &mu, &b_norm_sq, k) {
                break (mu, b_norm_sq);
            }
        };
        if lovasz_ok(&b_norm_sq, &mu, delta, k) {
            k += 1;
        } else {
            swaps += 1;
            if swaps > MAX_LLL_SWAPS {
                return Err(LatticeError::IterationLimit { iterations: swaps });
            }
            basis.swap(k, k - 1);
            // Step back one row, but never below 1 (row 0 has no predecessor).
            k = (k - 1).max(1);
        }
    }
    Ok(basis)
}
pub fn lattice_reduce_rows(basis_rows: &[Vec<Integer>]) -> Result<Vec<Vec<Integer>>, LatticeError> {
let delta = Rational::from((3u32, 4u32));
lll_reduce_once(basis_rows, &delta)
}
/// LLL-reduces the integer basis `basis_rows` (one basis vector per row) with a
/// caller-chosen Lovász factor `delta`.
///
/// # Errors
/// Returns `EmptyBasis` / `RaggedBasis` for malformed input, `InvalidDelta` unless
/// ¼ < δ < 1, and `IterationLimit` if the reduction gives up.
pub fn lattice_reduce_rows_with_delta(
    basis_rows: &[Vec<Integer>],
    delta: Rational,
) -> Result<Vec<Vec<Integer>>, LatticeError> {
    let lovasz_factor: &Rational = &delta;
    lll_reduce_once(basis_rows, lovasz_factor)
}
/// Checks whether `basis_rows` is LLL-reduced for the Lovász factor `delta`.
///
/// # Errors
/// Returns a short tag naming the first failed check:
/// - `"shape"`: the basis is empty or ragged,
/// - `"delta"`: δ is not strictly between ¼ and 1,
/// - `"size"`: some |μ_{i,j}| > ½ with j < i,
/// - `"lovasz"`: the Lovász condition fails at some row i ≥ 1.
pub fn validate_lll_rows(
    basis_rows: &[Vec<Integer>],
    delta: &Rational,
) -> Result<(), &'static str> {
    if validate_rows(basis_rows).is_err() {
        return Err("shape");
    }
    if validate_delta(delta).is_err() {
        return Err("delta");
    }
    let (mu, _, b_sq) = gram_schmidt_rows(basis_rows);
    let half = Rational::from((1u32, 2u32));
    // Row 0 has no coefficients to check, so a single-row basis passes trivially.
    for i in 1..basis_rows.len() {
        let oversized = mu[i][..i].iter().any(|m| m.clone().abs() > half);
        if oversized {
            return Err("size");
        }
        if !lovasz_ok(&b_sq, &mu, delta, i) {
            return Err("lovasz");
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an integer basis from small literal rows.
    fn int_rows(raw: &[&[i64]]) -> Vec<Vec<Integer>> {
        raw.iter()
            .map(|row| row.iter().map(|&x| Integer::from(x)).collect())
            .collect()
    }

    /// The Lovász factor used by `lattice_reduce_rows`.
    fn three_quarters() -> Rational {
        Rational::from((3u32, 4u32))
    }

    /// Largest squared Euclidean norm over the rows of `basis`.
    fn max_row_norm_squared(basis: &[Vec<Integer>]) -> Integer {
        basis
            .iter()
            .map(|row| {
                row.iter()
                    .fold(Integer::new(), |acc, z| acc + Integer::from(z * z))
            })
            .max()
            .expect("basis is non-empty")
    }

    #[test]
    fn planar_two_vectors_lll() {
        let rows = int_rows(&[&[2, 15], &[1, 21]]);
        let reduced = lattice_reduce_rows(&rows).unwrap();
        validate_lll_rows(&reduced, &three_quarters()).unwrap();
    }

    #[test]
    fn knapsack_row_weighted_near_origin() {
        let rows = int_rows(&[&[1, 0, 5], &[0, 1, 6], &[0, 0, 33]]);
        let reduced = lattice_reduce_rows(&rows).unwrap();
        validate_lll_rows(&reduced, &three_quarters()).unwrap();
        assert!(
            max_row_norm_squared(&reduced) <= max_row_norm_squared(&rows),
            "maximum squared row norm should shrink on this scaffold"
        );
    }
}