1use crate::errors::AlkahestError;
12use rug::{Assign, Float, Integer, Rational};
13use std::fmt;
14
/// Errors reported by the LLL lattice-reduction entry points of this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LatticeError {
    /// The basis contained no rows at all.
    EmptyBasis,
    /// A row's length differs from the length of row 0.
    RaggedBasis {
        /// Index of the first row whose length differs from row 0.
        row: usize,
        /// Length of row 0, taken as the ambient dimension.
        expected_cols: usize,
        /// Actual length of the offending row.
        got_cols: usize,
    },
    /// The Lovász factor δ (`provided`) was not strictly between 1/4 and 1.
    InvalidDelta { provided: Rational },
    /// The reduction loop exhausted its internal iteration budget; `iterations` is the
    /// counter value at the moment it gave up.
    IterationLimit { iterations: usize },
}
31
32impl fmt::Display for LatticeError {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 LatticeError::EmptyBasis => write!(f, "LLL expects at least one basis row"),
36 LatticeError::RaggedBasis {
37 row,
38 expected_cols,
39 got_cols,
40 } => write!(
41 f,
42 "row {row} has length {got_cols}; expected ambient dimension {expected_cols}"
43 ),
44 LatticeError::InvalidDelta { .. } => {
45 write!(f, "LLL Lovász factor δ must lie strictly between ¼ and 1")
46 }
47 LatticeError::IterationLimit { iterations } => write!(
48 f,
49 "LLL reduction aborted after {iterations} swaps (degenerate span or oversized basis)"
50 ),
51 }
52 }
53}
54
55impl std::error::Error for LatticeError {}
56
57impl AlkahestError for LatticeError {
58 fn code(&self) -> &'static str {
59 match self {
60 LatticeError::EmptyBasis => "E-LAT-001",
61 LatticeError::RaggedBasis { .. } => "E-LAT-002",
62 LatticeError::InvalidDelta { .. } => "E-LAT-003",
63 LatticeError::IterationLimit { .. } => "E-LAT-004",
64 }
65 }
66
67 fn remediation(&self) -> Option<&'static str> {
68 match self {
69 LatticeError::EmptyBasis => {
70 Some("pass a non-empty list of equally long integer coefficient rows")
71 }
72 LatticeError::RaggedBasis { .. } => {
73 Some("pad or trim rows so every basis vector lies in ℤ^m for fixed m")
74 }
75 LatticeError::InvalidDelta { .. } => {
76 Some("use the default δ = ¾, or choose another rational strictly between ¼ and 1")
77 }
78 LatticeError::IterationLimit { .. } => Some(
79 "check for rank-deficient rows, reduce dimension, or report a bug with a minimal basis",
80 ),
81 }
82 }
83}
84
85#[inline]
86fn dot_int_rat(row: &[Integer], v: &[Rational]) -> Rational {
87 let mut acc = Rational::from(0u32);
88 for (zi, vv) in row.iter().zip(v.iter()) {
89 let mut term = Rational::from(0u32);
90 let prod = Rational::from(zi) * vv;
91 term.assign(&prod);
92 acc += term;
93 }
94 acc
95}
96
97fn dot_rat(a: &[Rational], b: &[Rational]) -> Rational {
98 let mut acc = Rational::from(0u32);
99 for (x, y) in a.iter().zip(b.iter()) {
100 let mut term = Rational::from(0u32);
101 let prod = x.clone() * y.clone();
102 term.assign(&prod);
103 acc += term;
104 }
105 acc
106}
107
108fn int_row_as_rat(row: &[Integer]) -> Vec<Rational> {
109 row.iter().map(Rational::from).collect()
110}
111
112fn gram_schmidt_rows(
118 basis: &[Vec<Integer>],
119) -> (Vec<Vec<Rational>>, Vec<Vec<Rational>>, Vec<Rational>) {
120 let n = basis.len();
121 let ambient = basis[0].len();
122 let mut star = vec![vec![Rational::from(0); ambient]; n];
123 let mut mu = vec![vec![Rational::from(0); n]; n];
124 let mut b_norm_sq = vec![Rational::from(0); n];
125
126 for i in 0..n {
127 let mut vip = int_row_as_rat(&basis[i]);
128 for j in 0..i {
129 mu[i][j].assign(&dot_int_rat(&basis[i], &star[j]) / &b_norm_sq[j]);
130 for t in 0..ambient {
131 let m = mu[i][j].clone() * star[j][t].clone();
132 let vpt = vip[t].clone();
133 let sub = vpt - &m;
134 vip[t].assign(sub);
135 }
136 }
137 star[i] = vip;
138 let ni = dot_rat(&star[i], &star[i]);
139 b_norm_sq[i].assign(ni);
140 }
141 (mu, star, b_norm_sq)
142}
143
144fn nearest_integer_rational(x: &Rational) -> Integer {
145 Float::with_val(4096u32, x)
146 .round()
147 .to_integer()
148 .unwrap_or_else(|| Integer::from(0))
149}
150
151fn validate_rows(basis: &[Vec<Integer>]) -> Result<usize, LatticeError> {
152 if basis.is_empty() {
153 return Err(LatticeError::EmptyBasis);
154 }
155 let cols = basis[0].len();
156 for (i, row) in basis.iter().enumerate() {
157 if row.len() != cols {
158 return Err(LatticeError::RaggedBasis {
159 row: i,
160 expected_cols: cols,
161 got_cols: row.len(),
162 });
163 }
164 }
165 Ok(cols)
166}
167
168fn validate_delta(delta: &Rational) -> Result<(), LatticeError> {
169 let low = Rational::from((1i32, 4i32));
170 let hi = Rational::from(1u32);
171 if *delta <= low || *delta >= hi {
172 return Err(LatticeError::InvalidDelta {
173 provided: delta.clone(),
174 });
175 }
176 Ok(())
177}
178
179fn size_reduce_single(
180 basis: &mut [Vec<Integer>],
181 mu: &[Vec<Rational>],
182 b_norm_sq: &[Rational],
183 k: usize,
184) -> bool {
185 let mut altered = false;
186 for j in (0..k).rev() {
187 if b_norm_sq[j].is_zero() {
188 continue;
189 }
190 let mij = &mu[k][j];
191 let q = nearest_integer_rational(mij);
192 if q == 0 {
193 continue;
194 }
195 altered = true;
196 for col in 0..basis[k].len() {
197 let bjk = basis[j][col].clone();
198 basis[k][col] -= &(q.clone() * bjk);
199 }
200 return altered;
201 }
202 altered
203}
204
205fn lovasz_ok(b_norm_sq: &[Rational], mu: &[Vec<Rational>], delta: &Rational, k: usize) -> bool {
208 if k == 0 {
209 return true;
210 }
211 let bk = &b_norm_sq[k];
212 let bkm1 = &b_norm_sq[k - 1];
213 if bkm1.is_zero() {
214 return false;
215 }
216 let mux = mu[k][k - 1].clone();
217 let mux_sq = Rational::from(&mux * &mux);
218 let mut slack = delta.clone();
219 slack -= &mux_sq;
220 let rhs: Rational = slack * bkm1;
221 bk.clone() >= rhs
222}
223
224fn lll_reduce_once(
225 basis_rows: &[Vec<Integer>],
226 delta: &Rational,
227) -> Result<Vec<Vec<Integer>>, LatticeError> {
228 validate_rows(basis_rows)?;
229 validate_delta(delta)?;
230 let ambient = basis_rows[0].len();
231 let n = basis_rows.len();
232 let mut basis: Vec<Vec<Integer>> = basis_rows.to_vec();
233
234 let mut k: usize = 1;
235 let mut guard: usize = 0;
236 const MAX_LLL_SWAPS: usize = 2_000_000;
237 loop {
238 if k >= n {
239 break;
240 }
241 guard += 1;
242 if guard > MAX_LLL_SWAPS {
243 return Err(LatticeError::IterationLimit { iterations: guard });
244 }
245 loop {
247 let (mu_ref, _, b_norm_sq) = gram_schmidt_rows(&basis);
248 if !size_reduce_single(&mut basis, &mu_ref, &b_norm_sq, k) {
249 break;
250 }
251 }
253 let (mu, _, b_norm_sq) = gram_schmidt_rows(&basis);
254 if lovasz_ok(&b_norm_sq, &mu, delta, k) {
255 k += 1;
256 } else {
257 basis.swap(k, k - 1);
258 k = k.saturating_sub(1);
259 if k < 1 {
260 k = 1;
261 }
262 }
263 let _ = ambient;
265 if k >= n && n > 8000 {
266 break;
267 }
268 }
269
270 Ok(basis)
271}
272
273pub fn lattice_reduce_rows(basis_rows: &[Vec<Integer>]) -> Result<Vec<Vec<Integer>>, LatticeError> {
275 let delta = Rational::from((3u32, 4u32));
276 lll_reduce_once(basis_rows, &delta)
277}
278
279pub fn lattice_reduce_rows_with_delta(
281 basis_rows: &[Vec<Integer>],
282 delta: Rational,
283) -> Result<Vec<Vec<Integer>>, LatticeError> {
284 lll_reduce_once(basis_rows, &delta)
285}
286
287pub fn validate_lll_rows(
292 basis_rows: &[Vec<Integer>],
293 delta: &Rational,
294) -> Result<(), &'static str> {
295 validate_rows(basis_rows).map_err(|_| "shape")?;
296 validate_delta(delta).map_err(|_| "delta")?;
297 let n = basis_rows.len();
298 let (mu, _, b_sq) = gram_schmidt_rows(basis_rows);
299 if n == 1 {
300 return Ok(());
301 }
302 let half = Rational::from((1u32, 2u32));
303 for i in 1..n {
304 for mij in mu[i].iter().take(i) {
305 let mut absmu = mij.clone();
306 absmu.abs_mut();
307 if absmu > half {
308 return Err("size");
309 }
310 }
311 if !lovasz_ok(&b_sq, &mu, delta, i) {
312 return Err("lovasz");
313 }
314 }
315 Ok(())
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use rug::Rational;

    /// Builds an integer basis from small literal rows.
    fn int_rows(rows: &[&[i64]]) -> Vec<Vec<Integer>> {
        rows.iter()
            .map(|row| row.iter().map(|&z| Integer::from(z)).collect())
            .collect()
    }

    /// Largest squared Euclidean norm among the rows of a non-empty basis.
    fn max_row_norm_squared(basis: &[Vec<Integer>]) -> Integer {
        basis
            .iter()
            .map(|row| {
                row.iter()
                    .fold(Integer::new(), |acc, z| acc + Integer::from(z * z))
            })
            .max()
            .expect("basis has at least one row")
    }

    #[test]
    fn planar_two_vectors_lll() {
        let rows = int_rows(&[&[2, 15], &[1, 21]]);
        let reduced = lattice_reduce_rows(&rows).unwrap();
        validate_lll_rows(&reduced, &Rational::from((3u32, 4u32))).unwrap();
        // Hand-traced result of LLL with δ = 3/4 on this basis (determinant 27 preserved).
        assert_eq!(reduced, int_rows(&[&[4, 3], &[-5, 3]]));
    }

    #[test]
    fn knapsack_row_weighted_near_origin() {
        let rows = int_rows(&[&[1, 0, 5], &[0, 1, 6], &[0, 0, 33]]);
        let reduced = lattice_reduce_rows(&rows).unwrap();
        validate_lll_rows(&reduced, &Rational::from((3u32, 4u32))).unwrap();
        assert!(
            max_row_norm_squared(&reduced) <= max_row_norm_squared(&rows),
            "maximum squared row norm should shrink on this scaffold"
        );
    }

    #[test]
    fn single_row_is_returned_unchanged() {
        let rows = int_rows(&[&[3, -4, 12]]);
        assert_eq!(lattice_reduce_rows(&rows).unwrap(), rows);
    }

    #[test]
    fn empty_basis_is_rejected() {
        assert_eq!(lattice_reduce_rows(&[]), Err(LatticeError::EmptyBasis));
    }

    #[test]
    fn ragged_basis_is_rejected() {
        let rows = vec![
            vec![Integer::from(1), Integer::from(2)],
            vec![Integer::from(3)],
        ];
        assert_eq!(
            lattice_reduce_rows(&rows),
            Err(LatticeError::RaggedBasis {
                row: 1,
                expected_cols: 2,
                got_cols: 1,
            })
        );
        assert_eq!(validate_lll_rows(&rows, &Rational::from((3u32, 4u32))), Err("shape"));
    }

    #[test]
    fn delta_outside_open_interval_is_rejected() {
        let rows = int_rows(&[&[1, 0], &[0, 1]]);
        let bad_deltas = [
            Rational::from((1i32, 4i32)),
            Rational::from(1u32),
            Rational::from((-1i32, 2i32)),
        ];
        for bad in &bad_deltas {
            assert_eq!(
                lattice_reduce_rows_with_delta(&rows, bad.clone()),
                Err(LatticeError::InvalidDelta {
                    provided: bad.clone(),
                })
            );
            assert_eq!(validate_lll_rows(&rows, bad), Err("delta"));
        }
    }
}