Skip to main content

alkahest_cas/lattice/
lll.rs

//! Lenstra–Lenstra–Lovász lattice basis reduction over ℤ (row basis vectors).
//!
//! Algorithm structure follows Henri Cohen (*A Course in Computational Algebraic Number
//! Theory*, §2.6): Gram–Schmidt orthogonalisation with exact [`rug::Rational`] arithmetic,
//! iterative size reductions and pairwise swaps enforcing the Lovász condition.
//!
//! This is intended for modest dimensions (`n, m ≲ 300`), where the cost of exact rational
//! arithmetic on the Gram–Schmidt data stays manageable — the primary consumers (van Hoeij
//! knapsack lattices) rarely exceed that.
10
11use crate::errors::AlkahestError;
12use rug::{Assign, Float, Integer, Rational};
13use std::fmt;
14
15/// Lattice-basis reductions errors.
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub enum LatticeError {
18    /// No basis vectors supplied.
19    EmptyBasis,
20    /// Row `row` differs in length from the first row (ambient mismatch).
21    RaggedBasis {
22        row: usize,
23        expected_cols: usize,
24        got_cols: usize,
25    },
26    /// `δ ∉ (¼, 1)` as required by the LLL hypotheses.
27    InvalidDelta { provided: Rational },
28    /// Swap loop exceeded the iteration budget — basis may be degenerate or the implementation buggy.
29    IterationLimit { iterations: usize },
30}
31
32impl fmt::Display for LatticeError {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            LatticeError::EmptyBasis => write!(f, "LLL expects at least one basis row"),
36            LatticeError::RaggedBasis {
37                row,
38                expected_cols,
39                got_cols,
40            } => write!(
41                f,
42                "row {row} has length {got_cols}; expected ambient dimension {expected_cols}"
43            ),
44            LatticeError::InvalidDelta { .. } => {
45                write!(f, "LLL Lovász factor δ must lie strictly between ¼ and 1")
46            }
47            LatticeError::IterationLimit { iterations } => write!(
48                f,
49                "LLL reduction aborted after {iterations} swaps (degenerate span or oversized basis)"
50            ),
51        }
52    }
53}
54
55impl std::error::Error for LatticeError {}
56
57impl AlkahestError for LatticeError {
58    fn code(&self) -> &'static str {
59        match self {
60            LatticeError::EmptyBasis => "E-LAT-001",
61            LatticeError::RaggedBasis { .. } => "E-LAT-002",
62            LatticeError::InvalidDelta { .. } => "E-LAT-003",
63            LatticeError::IterationLimit { .. } => "E-LAT-004",
64        }
65    }
66
67    fn remediation(&self) -> Option<&'static str> {
68        match self {
69            LatticeError::EmptyBasis => {
70                Some("pass a non-empty list of equally long integer coefficient rows")
71            }
72            LatticeError::RaggedBasis { .. } => {
73                Some("pad or trim rows so every basis vector lies in ℤ^m for fixed m")
74            }
75            LatticeError::InvalidDelta { .. } => {
76                Some("use the default δ = ¾, or choose another rational strictly between ¼ and 1")
77            }
78            LatticeError::IterationLimit { .. } => Some(
79                "check for rank-deficient rows, reduce dimension, or report a bug with a minimal basis",
80            ),
81        }
82    }
83}
84
85#[inline]
86fn dot_int_rat(row: &[Integer], v: &[Rational]) -> Rational {
87    let mut acc = Rational::from(0u32);
88    for (zi, vv) in row.iter().zip(v.iter()) {
89        let mut term = Rational::from(0u32);
90        let prod = Rational::from(zi) * vv;
91        term.assign(&prod);
92        acc += term;
93    }
94    acc
95}
96
97fn dot_rat(a: &[Rational], b: &[Rational]) -> Rational {
98    let mut acc = Rational::from(0u32);
99    for (x, y) in a.iter().zip(b.iter()) {
100        let mut term = Rational::from(0u32);
101        let prod = x.clone() * y.clone();
102        term.assign(&prod);
103        acc += term;
104    }
105    acc
106}
107
108fn int_row_as_rat(row: &[Integer]) -> Vec<Rational> {
109    row.iter().map(Rational::from).collect()
110}
111
112/// Gram–Schmidt data for rows `basis[0 … n − 1]`.
113///
114/// * `star[i]` holds the orthogonal residual `b*_i`.
115/// * `mu[i][j]` for `j < i` is ⟨b_i,b*_j⟩ / ⟨b*_j,b*_j⟩.
116/// * `b_norm_sq[i]` is ⟨b*_i,b*_i⟩ ∈ ℚ.
117fn gram_schmidt_rows(
118    basis: &[Vec<Integer>],
119) -> (Vec<Vec<Rational>>, Vec<Vec<Rational>>, Vec<Rational>) {
120    let n = basis.len();
121    let ambient = basis[0].len();
122    let mut star = vec![vec![Rational::from(0); ambient]; n];
123    let mut mu = vec![vec![Rational::from(0); n]; n];
124    let mut b_norm_sq = vec![Rational::from(0); n];
125
126    for i in 0..n {
127        let mut vip = int_row_as_rat(&basis[i]);
128        for j in 0..i {
129            mu[i][j].assign(&dot_int_rat(&basis[i], &star[j]) / &b_norm_sq[j]);
130            for t in 0..ambient {
131                let m = mu[i][j].clone() * star[j][t].clone();
132                let vpt = vip[t].clone();
133                let sub = vpt - &m;
134                vip[t].assign(sub);
135            }
136        }
137        star[i] = vip;
138        let ni = dot_rat(&star[i], &star[i]);
139        b_norm_sq[i].assign(ni);
140    }
141    (mu, star, b_norm_sq)
142}
143
144fn nearest_integer_rational(x: &Rational) -> Integer {
145    Float::with_val(4096u32, x)
146        .round()
147        .to_integer()
148        .unwrap_or_else(|| Integer::from(0))
149}
150
151fn validate_rows(basis: &[Vec<Integer>]) -> Result<usize, LatticeError> {
152    if basis.is_empty() {
153        return Err(LatticeError::EmptyBasis);
154    }
155    let cols = basis[0].len();
156    for (i, row) in basis.iter().enumerate() {
157        if row.len() != cols {
158            return Err(LatticeError::RaggedBasis {
159                row: i,
160                expected_cols: cols,
161                got_cols: row.len(),
162            });
163        }
164    }
165    Ok(cols)
166}
167
168fn validate_delta(delta: &Rational) -> Result<(), LatticeError> {
169    let low = Rational::from((1i32, 4i32));
170    let hi = Rational::from(1u32);
171    if *delta <= low || *delta >= hi {
172        return Err(LatticeError::InvalidDelta {
173            provided: delta.clone(),
174        });
175    }
176    Ok(())
177}
178
179fn size_reduce_single(
180    basis: &mut [Vec<Integer>],
181    mu: &[Vec<Rational>],
182    b_norm_sq: &[Rational],
183    k: usize,
184) -> bool {
185    let mut altered = false;
186    for j in (0..k).rev() {
187        if b_norm_sq[j].is_zero() {
188            continue;
189        }
190        let mij = &mu[k][j];
191        let q = nearest_integer_rational(mij);
192        if q == 0 {
193            continue;
194        }
195        altered = true;
196        for col in 0..basis[k].len() {
197            let bjk = basis[j][col].clone();
198            basis[k][col] -= &(q.clone() * bjk);
199        }
200        return altered;
201    }
202    altered
203}
204
205/// Lovász predicate at index `k` (1-based outer loop index = `k` here 0-indexed):
206/// \(B*_k ≥ (δ − μ²_{k,k−1}) B*_{k−1}\).
207fn lovasz_ok(b_norm_sq: &[Rational], mu: &[Vec<Rational>], delta: &Rational, k: usize) -> bool {
208    if k == 0 {
209        return true;
210    }
211    let bk = &b_norm_sq[k];
212    let bkm1 = &b_norm_sq[k - 1];
213    if bkm1.is_zero() {
214        return false;
215    }
216    let mux = mu[k][k - 1].clone();
217    let mux_sq = Rational::from(&mux * &mux);
218    let mut slack = delta.clone();
219    slack -= &mux_sq;
220    let rhs: Rational = slack * bkm1;
221    bk.clone() >= rhs
222}
223
224fn lll_reduce_once(
225    basis_rows: &[Vec<Integer>],
226    delta: &Rational,
227) -> Result<Vec<Vec<Integer>>, LatticeError> {
228    validate_rows(basis_rows)?;
229    validate_delta(delta)?;
230    let ambient = basis_rows[0].len();
231    let n = basis_rows.len();
232    let mut basis: Vec<Vec<Integer>> = basis_rows.to_vec();
233
234    let mut k: usize = 1;
235    let mut guard: usize = 0;
236    const MAX_LLL_SWAPS: usize = 2_000_000;
237    loop {
238        if k >= n {
239            break;
240        }
241        guard += 1;
242        if guard > MAX_LLL_SWAPS {
243            return Err(LatticeError::IterationLimit { iterations: guard });
244        }
245        // Size-reduce row k until stable (against each successive Gram–Schmidt refresh).
246        loop {
247            let (mu_ref, _, b_norm_sq) = gram_schmidt_rows(&basis);
248            if !size_reduce_single(&mut basis, &mu_ref, &b_norm_sq, k) {
249                break;
250            }
251            // Projection numbers changed materially — reorthogonalise implicitly next loop.
252        }
253        let (mu, _, b_norm_sq) = gram_schmidt_rows(&basis);
254        if lovasz_ok(&b_norm_sq, &mu, delta, k) {
255            k += 1;
256        } else {
257            basis.swap(k, k - 1);
258            k = k.saturating_sub(1);
259            if k < 1 {
260                k = 1;
261            }
262        }
263        // Guard against malformed rank-deficient setups that spin forever:
264        let _ = ambient;
265        if k >= n && n > 8000 {
266            break;
267        }
268    }
269
270    Ok(basis)
271}
272
273/// Run LLL on integer row vectors using the conventional Lovász parameter `δ = ¾`.
274pub fn lattice_reduce_rows(basis_rows: &[Vec<Integer>]) -> Result<Vec<Vec<Integer>>, LatticeError> {
275    let delta = Rational::from((3u32, 4u32));
276    lll_reduce_once(basis_rows, &delta)
277}
278
279/// Same as [`lattice_reduce_rows`], with an explicit `δ ∈ (¼, 1)`.
280pub fn lattice_reduce_rows_with_delta(
281    basis_rows: &[Vec<Integer>],
282    delta: Rational,
283) -> Result<Vec<Vec<Integer>>, LatticeError> {
284    lll_reduce_once(basis_rows, &delta)
285}
286
287/// Check that `(δ, Lovász residuals, coefficient bounds)` satisfy the textbook LLL
288/// inequalities (useful as a regression oracle).
289///
290/// Uses the **current row order**.
291pub fn validate_lll_rows(
292    basis_rows: &[Vec<Integer>],
293    delta: &Rational,
294) -> Result<(), &'static str> {
295    validate_rows(basis_rows).map_err(|_| "shape")?;
296    validate_delta(delta).map_err(|_| "delta")?;
297    let n = basis_rows.len();
298    let (mu, _, b_sq) = gram_schmidt_rows(basis_rows);
299    if n == 1 {
300        return Ok(());
301    }
302    let half = Rational::from((1u32, 2u32));
303    for i in 1..n {
304        for mij in mu[i].iter().take(i) {
305            let mut absmu = mij.clone();
306            absmu.abs_mut();
307            if absmu > half {
308                return Err("size");
309            }
310        }
311        if !lovasz_ok(&b_sq, &mu, delta, i) {
312            return Err("lovasz");
313        }
314    }
315    Ok(())
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds integer row vectors from small literal rows.
    fn int_rows(rows: &[&[i64]]) -> Vec<Vec<Integer>> {
        rows.iter()
            .map(|row| row.iter().map(|&x| Integer::from(x)).collect())
            .collect()
    }

    /// Largest squared Euclidean norm among the rows of `basis`.
    fn max_row_norm_squared(basis: &[Vec<Integer>]) -> Integer {
        basis
            .iter()
            .map(|row| {
                row.iter()
                    .fold(Integer::from(0), |acc, entry| acc + Integer::from(entry * entry))
            })
            .max()
            .expect("basis has at least one row")
    }

    #[test]
    fn planar_two_vectors_lll() {
        let rows = int_rows(&[&[2, 15], &[1, 21]]);
        let reduced = lattice_reduce_rows(&rows).unwrap();
        validate_lll_rows(&reduced, &Rational::from((3u32, 4u32))).unwrap();
    }

    #[test]
    fn knapsack_row_weighted_near_origin() {
        let rows = int_rows(&[&[1, 0, 5], &[0, 1, 6], &[0, 0, 33]]);
        let reduced = lattice_reduce_rows(&rows).unwrap();
        validate_lll_rows(&reduced, &Rational::from((3u32, 4u32))).unwrap();
        assert!(
            max_row_norm_squared(&reduced) <= max_row_norm_squared(&rows),
            "maximum squared row norm should shrink on this scaffold"
        );
    }
}