use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::ndarray::{Array2};
use super::bkz::{BKZConfig, BKZReducer};
use super::lll::{LLLConfig, LLLReducer};
/// Lattice basis reduction algorithm used by the preprocessor.
///
/// Marked `#[non_exhaustive]` so additional algorithms can be added without
/// breaking downstream crates.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionMethod {
    /// Lenstra–Lenstra–Lovász reduction (see [`LLLReducer`]).
    LLL,
    /// Block Korkine–Zolotarev reduction (see [`BKZReducer`]).
    BKZ,
}
impl Default for ReductionMethod {
fn default() -> Self {
ReductionMethod::LLL
}
}
/// Configuration for [`LatticePreprocessor`].
#[derive(Debug, Clone)]
pub struct LatticePreprocessorConfig {
    /// Which reduction algorithm to run.
    pub method: ReductionMethod,
    /// Settings forwarded to the LLL reducer (used when `method` is `LLL`).
    pub lll_config: LLLConfig,
    /// Settings forwarded to the BKZ reducer (used when `method` is `BKZ`).
    pub bkz_config: BKZConfig,
    /// If `true`, the reduction is applied to the columns of the input
    /// matrix (i.e. the rows of its transpose); otherwise to its rows.
    pub apply_to_columns: bool,
}
impl Default for LatticePreprocessorConfig {
    /// Defaults: LLL reduction with default reducer settings, applied to
    /// the columns of the input matrix.
    fn default() -> Self {
        Self {
            method: ReductionMethod::default(),
            lll_config: LLLConfig::default(),
            bkz_config: BKZConfig::default(),
            apply_to_columns: true,
        }
    }
}
/// Output of [`LatticePreprocessor::preprocess`].
#[derive(Debug, Clone)]
pub struct LatticePreprocessorResult {
    /// The constraint matrix after the change of basis.
    pub transformed_a: Array2<f64>,
    /// Change-of-basis matrix `U`. In column mode the original matrix is
    /// reconstructed as `transformed_a * U`; in row mode as
    /// `U * transformed_a` (this is what `verify` checks).
    pub transform: Array2<f64>,
    /// Inverse of `transform`.
    pub transform_inv: Array2<f64>,
    /// Which reduction algorithm produced this result.
    pub reduction_method: ReductionMethod,
}
/// Applies lattice basis reduction (LLL or BKZ) to a constraint matrix,
/// producing an equivalent matrix in a reduced basis plus the unimodular
/// change-of-basis matrices.
pub struct LatticePreprocessor {
    config: LatticePreprocessorConfig,
}
impl LatticePreprocessor {
    /// Creates a preprocessor with the given configuration.
    pub fn new(config: LatticePreprocessorConfig) -> Self {
        LatticePreprocessor { config }
    }

    /// Runs the configured basis reduction on `a` and returns the
    /// transformed matrix together with the change-of-basis matrix `U`
    /// and its inverse.
    ///
    /// With `apply_to_columns == true` the reduction operates on `a`
    /// transposed and `a == transformed_a * U`; otherwise it operates on
    /// the rows of `a` and `a == U * transformed_a` (see [`Self::verify`]).
    ///
    /// # Errors
    ///
    /// * `OptimizeError::ValueError` if `a` has zero rows or columns.
    /// * `OptimizeError::ComputationError` if `U` cannot be inverted.
    /// * Any error reported by the underlying reducer.
    pub fn preprocess(&self, a: &Array2<f64>) -> OptimizeResult<LatticePreprocessorResult> {
        if a.nrows() == 0 || a.ncols() == 0 {
            return Err(OptimizeError::ValueError(
                "LatticePreprocessor: constraint matrix must be non-empty".to_string(),
            ));
        }
        // The reducers operate on the rows of their input, so column mode
        // hands them the transpose of `a`.
        let basis: Array2<f64> = if self.config.apply_to_columns {
            a.t().to_owned()
        } else {
            a.to_owned()
        };
        let (u_transform, _reduced_basis) = self.apply_reduction(&basis)?;
        // Expose U = T^T so the reconstruction identities documented above
        // hold for both modes.
        let u = u_transform.t().to_owned();
        let u_inv = Self::integer_matrix_inverse(&u)?;
        let transformed_a = if self.config.apply_to_columns {
            matmul_f64(a, &u_inv)
        } else {
            matmul_f64(&u_inv, a)
        };
        Ok(LatticePreprocessorResult {
            transformed_a,
            transform: u,
            transform_inv: u_inv,
            reduction_method: self.config.method,
        })
    }

    /// Checks (to absolute tolerance `1e-6`) that `result` reconstructs
    /// `a`: `transformed_a * U == a` in column mode, `U * transformed_a
    /// == a` in row mode. Also fails if the reconstructed shape differs.
    pub fn verify(&self, result: &LatticePreprocessorResult, a: &Array2<f64>) -> bool {
        let tol = 1e-6;
        let reconstructed = if self.config.apply_to_columns {
            matmul_f64(&result.transformed_a, &result.transform)
        } else {
            matmul_f64(&result.transform, &result.transformed_a)
        };
        if reconstructed.nrows() != a.nrows() || reconstructed.ncols() != a.ncols() {
            return false;
        }
        reconstructed
            .iter()
            .zip(a.iter())
            .all(|(rec, orig)| (rec - orig).abs() <= tol)
    }

    /// Runs the configured reduction on `basis` and returns `(T, reduced)`
    /// where `reduced = T * basis`.
    fn apply_reduction(
        &self,
        basis: &Array2<f64>,
    ) -> OptimizeResult<(Array2<f64>, Array2<f64>)> {
        // No `_` arm: `#[non_exhaustive]` does not restrict matches inside
        // the defining crate, so a wildcard here would be an unreachable
        // pattern; letting the compiler enforce exhaustiveness is safer.
        match self.config.method {
            ReductionMethod::LLL => {
                let reducer = LLLReducer::new(self.config.lll_config.clone());
                let result = reducer.reduce(basis)?;
                Ok((result.transformation, result.reduced_basis))
            }
            ReductionMethod::BKZ => {
                // BUG FIX: this branch previously paired the BKZ-reduced
                // basis with the transformation from an *independent* LLL
                // run on the original basis, so the returned pair no longer
                // satisfied reduced = T * basis. Return BKZ's own
                // transformation alongside its reduced basis instead.
                // (Assumes the BKZ result exposes `transformation` like the
                // LLL result does — confirm against super::bkz.)
                let reducer = BKZReducer::new(self.config.bkz_config.clone());
                let result = reducer.reduce(basis)?;
                Ok((result.transformation, result.reduced_basis))
            }
        }
    }

    /// Inverts a (nominally unimodular, integer-valued) matrix via the
    /// adjugate: `inv = adj(u) / det(u)`. For a unimodular matrix
    /// `det = ±1`, so the inverse has exact integer entries.
    ///
    /// # Errors
    ///
    /// * `ValueError` if `u` is not square.
    /// * `ComputationError` if `|det(u)| < 0.5` — since entries are
    ///   integers stored as `f64`, this means an integer determinant of 0.
    pub(crate) fn integer_matrix_inverse(u: &Array2<f64>) -> OptimizeResult<Array2<f64>> {
        let n = u.nrows();
        if n != u.ncols() {
            return Err(OptimizeError::ValueError(
                "Matrix must be square for inversion".to_string(),
            ));
        }
        if n == 0 {
            return Ok(Array2::<f64>::zeros((0, 0)));
        }
        if n == 1 {
            let val = u[[0, 0]];
            if val.abs() < 0.5 {
                return Err(OptimizeError::ComputationError(
                    "Matrix is singular (det ≈ 0)".to_string(),
                ));
            }
            let mut inv = Array2::<f64>::zeros((1, 1));
            inv[[0, 0]] = 1.0 / val;
            return Ok(inv);
        }
        let det = compute_det(u)?;
        if det.abs() < 0.5 {
            return Err(OptimizeError::ComputationError(format!(
                "Matrix is singular or not unimodular: det = {}",
                det
            )));
        }
        let mut inv = compute_adjugate(u)?;
        let inv_det = 1.0 / det;
        inv.mapv_inplace(|x| x * inv_det);
        Ok(inv)
    }
}
/// Determinant of a square matrix.
///
/// Sizes 1–3 use closed-form cofactor expansions; anything else falls back
/// to the LU-based routine.
fn compute_det(m: &Array2<f64>) -> OptimizeResult<f64> {
    match m.nrows() {
        1 => Ok(m[[0, 0]]),
        2 => Ok(m[[0, 0]] * m[[1, 1]] - m[[0, 1]] * m[[1, 0]]),
        3 => {
            let c0 = m[[1, 1]] * m[[2, 2]] - m[[1, 2]] * m[[2, 1]];
            let c1 = m[[1, 0]] * m[[2, 2]] - m[[1, 2]] * m[[2, 0]];
            let c2 = m[[1, 0]] * m[[2, 1]] - m[[1, 1]] * m[[2, 0]];
            Ok(m[[0, 0]] * c0 - m[[0, 1]] * c1 + m[[0, 2]] * c2)
        }
        _ => compute_det_lu(m),
    }
}
/// Determinant via in-place LU elimination with partial pivoting.
///
/// Returns `Ok(0.0)` when a pivot column has no entry above `1e-14` in
/// magnitude (numerically singular). Row swaps flip the determinant's sign.
fn compute_det_lu(m: &Array2<f64>) -> OptimizeResult<f64> {
    let n = m.nrows();
    let mut work = m.to_owned();
    let mut det_sign = 1.0f64;
    for col in 0..n {
        // Partial pivoting: pick the largest |entry| on or below the
        // diagonal (first occurrence wins on ties, as the strict `>` keeps
        // the earlier row).
        let (pivot_row, pivot_abs) = ((col + 1)..n).fold(
            (col, work[[col, col]].abs()),
            |(best_row, best_abs), row| {
                let cand = work[[row, col]].abs();
                if cand > best_abs {
                    (row, cand)
                } else {
                    (best_row, best_abs)
                }
            },
        );
        if pivot_abs < 1e-14 {
            return Ok(0.0);
        }
        if pivot_row != col {
            for j in 0..n {
                let tmp = work[[col, j]];
                work[[col, j]] = work[[pivot_row, j]];
                work[[pivot_row, j]] = tmp;
            }
            det_sign = -det_sign;
        }
        let pivot = work[[col, col]];
        for row in (col + 1)..n {
            let factor = work[[row, col]] / pivot;
            for j in col..n {
                let sub = factor * work[[col, j]];
                work[[row, j]] -= sub;
            }
        }
    }
    let mut diag_product = 1.0f64;
    for i in 0..n {
        diag_product *= work[[i, i]];
    }
    Ok(det_sign * diag_product)
}
/// Adjugate (transposed cofactor matrix) of `m`, so that
/// `m * adj(m) = det(m) * I`.
fn compute_adjugate(m: &Array2<f64>) -> OptimizeResult<Array2<f64>> {
    let n = m.nrows();
    let mut adjugate = Array2::<f64>::zeros((n, n));
    for row in 0..n {
        for col in 0..n {
            let cofactor = compute_det(&extract_minor(m, row, col))?;
            let parity = if (row + col) % 2 == 0 { 1.0 } else { -1.0 };
            // Transposed placement: the cofactor of (row, col) lands at
            // (col, row).
            adjugate[[col, row]] = parity * cofactor;
        }
    }
    Ok(adjugate)
}
/// Copy of `m` with row `del_row` and column `del_col` removed; used to
/// build the minors for cofactor expansion. Assumes `m` is at least 1x1.
fn extract_minor(m: &Array2<f64>, del_row: usize, del_col: usize) -> Array2<f64> {
    let n = m.nrows();
    let mut minor = Array2::<f64>::zeros((n - 1, n - 1));
    let kept_rows = (0..n).filter(|&i| i != del_row);
    for (dst_i, src_i) in kept_rows.enumerate() {
        let kept_cols = (0..n).filter(|&j| j != del_col);
        for (dst_j, src_j) in kept_cols.enumerate() {
            minor[[dst_i, dst_j]] = m[[src_i, src_j]];
        }
    }
    minor
}
/// Naive dense matrix product `a * b`. Performs no shape checking:
/// `a.ncols()` must equal `b.nrows()`.
fn matmul_f64(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
    let rows = a.nrows();
    let inner = a.ncols();
    let cols = b.ncols();
    let mut product = Array2::<f64>::zeros((rows, cols));
    for r in 0..rows {
        for c in 0..cols {
            // Left-to-right accumulation, matching a plain summation loop.
            product[[r, c]] = (0..inner).map(|t| a[[r, t]] * b[[t, c]]).sum();
        }
    }
    product
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_integer_inverse_2x2() {
        let u = array![[1.0, 2.0], [0.0, 1.0]];
        let u_inv = LatticePreprocessor::integer_matrix_inverse(&u).expect("Should invert");
        assert!((u_inv[[0, 0]] - 1.0).abs() < 1e-6);
        assert!((u_inv[[0, 1]] - (-2.0)).abs() < 1e-6);
        assert!((u_inv[[1, 0]] - 0.0).abs() < 1e-6);
        assert!((u_inv[[1, 1]] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_integer_inverse_3x3() {
        let u = array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let u_inv = LatticePreprocessor::integer_matrix_inverse(&u).expect("Should invert");
        assert!((u_inv[[0, 2]] - (-1.0)).abs() < 1e-6, "u_inv[0][2] = {}", u_inv[[0, 2]]);
        // u * u_inv must be the identity.
        let prod = matmul_f64(&u, &u_inv);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (prod[[i, j]] - expected).abs() < 1e-6,
                    "prod[{}][{}] = {}",
                    i,
                    j,
                    prod[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_lattice_preprocessor_lll_method() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let config = LatticePreprocessorConfig {
            method: ReductionMethod::LLL,
            ..Default::default()
        };
        let pp = LatticePreprocessor::new(config);
        let result = pp.preprocess(&a).expect("Preprocessing should succeed");
        assert_eq!(result.reduction_method, ReductionMethod::LLL);
        assert!(pp.verify(&result, &a), "Verification should pass for LLL method");
    }

    #[test]
    fn test_lattice_preprocessor_bkz_method() {
        let a = array![[1.0, 0.0], [0.0, 1.0]];
        let config = LatticePreprocessorConfig {
            method: ReductionMethod::BKZ,
            bkz_config: BKZConfig {
                block_size: 2,
                max_tours: 3,
                ..Default::default()
            },
            ..Default::default()
        };
        let pp = LatticePreprocessor::new(config);
        let result = pp.preprocess(&a).expect("Preprocessing should succeed");
        assert_eq!(result.reduction_method, ReductionMethod::BKZ);
    }

    #[test]
    fn test_lattice_preprocessor_verify() {
        let a = array![[2.0, 1.0, 3.0], [1.0, 4.0, 2.0]];
        let pp = LatticePreprocessor::new(LatticePreprocessorConfig::default());
        let result = pp.preprocess(&a).expect("Preprocessing should succeed");
        assert!(
            pp.verify(&result, &a),
            "A should equal transformed_a * U"
        );
    }

    #[test]
    fn test_lattice_preprocessor_deterministic() {
        let a = array![[1.0, 1.0], [-1.0, 2.0]];
        let pp = LatticePreprocessor::new(LatticePreprocessorConfig::default());
        let r1 = pp.preprocess(&a).expect("First call should succeed");
        let r2 = pp.preprocess(&a).expect("Second call should succeed");
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (r1.transformed_a[[i, j]] - r2.transformed_a[[i, j]]).abs() < 1e-8,
                    "Results should be deterministic"
                );
            }
        }
    }

    #[test]
    fn test_lattice_preprocessor_identity_matrix() {
        let a = Array2::<f64>::eye(3);
        let pp = LatticePreprocessor::new(LatticePreprocessorConfig::default());
        let result = pp.preprocess(&a).expect("Should succeed");
        assert!(pp.verify(&result, &a), "Verify should pass for identity");
    }

    #[test]
    fn test_reduction_method_non_exhaustive() {
        let m = ReductionMethod::LLL;
        // FIX: `#[non_exhaustive]` has no effect inside the defining crate,
        // so the wildcard arm below is an unreachable pattern here; it only
        // documents how downstream crates must match. Silence the lint so
        // `-D warnings` builds stay clean.
        #[allow(unreachable_patterns)]
        let _ = match m {
            ReductionMethod::LLL => "lll",
            ReductionMethod::BKZ => "bkz",
            _ => "other",
        };
    }
}