use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::ndarray::{Array1, Array2};
use super::gram_schmidt::gram_schmidt;
use super::lll::{LLLConfig, LLLReducer};
use super::svp::solve_svp;
/// Tuning parameters for [`BKZReducer`].
#[derive(Debug, Clone)]
pub struct BKZConfig {
    /// Block size β: dimension of the projected sub-lattices handed to the
    /// SVP solver (clamped to the number of basis rows by the reducer).
    pub block_size: usize,
    /// Upper bound on the number of BKZ tours.
    pub max_tours: usize,
    /// `delta` forwarded to the LLL reducer (`LLLConfig::delta`).
    pub lll_delta: f64,
    /// Passed as the second argument of `solve_svp` for every block;
    /// presumably an enumeration node budget — confirm in the SVP module.
    pub svp_max_nodes: usize,
}

impl Default for BKZConfig {
    /// β = 10, at most 8 tours, LLL δ = 0.99, 100 000 SVP nodes.
    fn default() -> Self {
        Self {
            block_size: 10,
            max_tours: 8,
            lll_delta: 0.99,
            svp_max_nodes: 100_000,
        }
    }
}
/// Outcome of a BKZ reduction run ([`BKZReducer::reduce`]).
#[derive(Debug, Clone)]
pub struct BKZResult {
    /// The reduced basis, one basis vector per row.
    pub reduced_basis: Array2<f64>,
    /// How many tours were started (never more than `BKZConfig::max_tours`).
    pub n_tours: usize,
    /// Euclidean length of the first row of `reduced_basis`.
    pub first_vector_norm: f64,
}
pub struct BKZReducer {
config: BKZConfig,
}
impl BKZReducer {
pub fn new(config: BKZConfig) -> Self {
BKZReducer { config }
}
pub fn reduce(&self, basis: &Array2<f64>) -> OptimizeResult<BKZResult> {
let n = basis.nrows();
let d = basis.ncols();
if n == 0 || d == 0 {
return Err(OptimizeError::ValueError(
"BKZ: basis matrix must be non-empty".to_string(),
));
}
let beta = self.config.block_size.min(n);
let lll_config = LLLConfig {
delta: self.config.lll_delta,
..Default::default()
};
let lll_reducer = LLLReducer::new(lll_config.clone());
let lll_result = lll_reducer.reduce(basis)?;
let mut b = lll_result.reduced_basis;
let mut n_tours = 0usize;
for _tour in 0..self.config.max_tours {
n_tours += 1;
let mut improved = false;
for i in 0..=(n.saturating_sub(beta)) {
let block_end = (i + beta).min(n);
let block_size = block_end - i;
let block_basis = extract_projected_block(&b, i, block_end)?;
let svp_vec = match solve_svp(&block_basis, self.config.svp_max_nodes) {
Ok(v) => v,
Err(_) => continue, };
let svp_norm_sq: f64 = svp_vec.iter().map(|x| x * x).sum();
let (_, bnorm_sq) = gram_schmidt(&block_basis);
let current_b0_proj_norm_sq = if !bnorm_sq.is_empty() { bnorm_sq[0] } else { 0.0 };
if svp_norm_sq < current_b0_proj_norm_sq - 1e-8 {
if let Ok(true) = insert_svp_vector(&mut b, &block_basis, &svp_vec, i, block_end, d) {
let lll_r = lll_reducer.reduce(&b)?;
b = lll_r.reduced_basis;
improved = true;
}
}
if block_end == n && block_size < beta {
break;
}
}
if !improved {
break;
}
}
let first_vector_norm: f64 = b.row(0).iter().map(|x| x * x).sum::<f64>().sqrt();
Ok(BKZResult {
reduced_basis: b,
n_tours,
first_vector_norm,
})
}
}
/// Projects rows `start..end` of `basis` orthogonally to the span of the
/// first `start` rows.
///
/// The returned matrix has `end - start` rows and `basis.ncols()` columns.
/// Row `i - start` is
/// `pi_start(b_i) = b_i - sum_{j < start} mu[i][j] * b*_j`, where `mu` comes
/// from `gram_schmidt(basis)` and `b*_j` are the Gram–Schmidt vectors
/// rebuilt from it. These projected rows generate the local lattice that BKZ
/// hands to the SVP solver.
///
/// # Errors
/// Returns `OptimizeError::ValueError` if the range is empty (or reversed),
/// or if `end` exceeds the number of rows of `basis`.
fn extract_projected_block(
    basis: &Array2<f64>,
    start: usize,
    end: usize,
) -> OptimizeResult<Array2<f64>> {
    // `end <= start` also catches reversed ranges, which would otherwise
    // underflow in `end - start`.
    if end <= start {
        return Err(OptimizeError::ValueError(
            "BKZ block must be non-empty".to_string(),
        ));
    }
    let n = basis.nrows();
    if end > n {
        return Err(OptimizeError::ValueError(format!(
            "BKZ block end {} exceeds basis row count {}",
            end, n
        )));
    }
    let d = basis.ncols();
    let (mu, _) = gram_schmidt(basis);

    // Only b*_0 .. b*_{start-1} are needed for the projection.
    let mut b_star: Vec<Vec<f64>> = Vec::with_capacity(start);
    for i in 0..start {
        let mut bsi: Vec<f64> = (0..d).map(|k| basis[[i, k]]).collect();
        for (j, bsj) in b_star.iter().enumerate() {
            let c = mu[[i, j]];
            for (x, &y) in bsi.iter_mut().zip(bsj.iter()) {
                *x -= c * y;
            }
        }
        b_star.push(bsi);
    }

    let mut block = Array2::<f64>::zeros((end - start, d));
    for (bi, i) in (start..end).enumerate() {
        for k in 0..d {
            block[[bi, k]] = basis[[i, k]];
        }
        for (j, bsj) in b_star.iter().enumerate() {
            let c = mu[[i, j]];
            for (k, &y) in bsj.iter().enumerate() {
                block[[bi, k]] -= c * y;
            }
        }
    }
    Ok(block)
}
fn insert_svp_vector(
basis: &mut Array2<f64>,
block_basis: &Array2<f64>,
svp_vec: &Array1<f64>,
insert_pos: usize,
block_end: usize,
_d: usize,
) -> OptimizeResult<bool> {
let n_rows = basis.nrows();
let d = basis.ncols();
let block_n = block_basis.nrows();
let mut coeffs = vec![0.0f64; block_n];
let (block_mu, block_bnorm_sq) = gram_schmidt(block_basis);
let mut b_star: Vec<Vec<f64>> = Vec::with_capacity(block_n);
for i in 0..block_n {
let mut bsi: Vec<f64> = (0..d).map(|k| block_basis[[i, k]]).collect();
for j in 0..i {
let c = block_mu[[i, j]];
for k in 0..d {
bsi[k] -= c * b_star[j][k];
}
}
b_star.push(bsi);
}
let mut residual: Vec<f64> = (0..d).map(|k| svp_vec[k]).collect();
for i in (0..block_n).rev() {
if block_bnorm_sq[i] < 1e-14 {
continue;
}
let dot: f64 = residual.iter().zip(b_star[i].iter()).map(|(a, b)| a * b).sum();
let c = dot / block_bnorm_sq[i];
coeffs[i] = c;
for k in 0..d {
residual[k] -= c * b_star[i][k];
}
}
let int_coeffs: Vec<i64> = coeffs.iter().map(|&c| c.round() as i64).collect();
let mut lattice_vec = vec![0.0f64; d];
for (bi, i) in (insert_pos..block_end).enumerate() {
let c = int_coeffs[bi] as f64;
if c != 0.0 {
for k in 0..d {
lattice_vec[k] += c * basis[[i, k]];
}
}
}
let lv_norm_sq: f64 = lattice_vec.iter().map(|x| x * x).sum();
let b0_norm_sq: f64 = (0..d).map(|k| basis[[insert_pos, k]].powi(2)).sum();
if lv_norm_sq < b0_norm_sq - 1e-8 && lv_norm_sq > 1e-14 {
let new_vec = lattice_vec.clone();
for row in (insert_pos + 1..block_end).rev() {
for k in 0..d {
let val = basis[[row - 1, k]];
basis[[row, k]] = val;
}
}
for k in 0..d {
basis[[insert_pos, k]] = new_vec[k];
}
let _ = n_rows; return Ok(true);
}
Ok(false)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// |det(m)| of a square matrix via Gaussian elimination with partial
    /// pivoting.
    ///
    /// A unimodular change of basis leaves this value unchanged, so it is a
    /// cheap check that reduction did not alter the lattice.
    fn abs_det(m: &Array2<f64>) -> f64 {
        let n = m.nrows();
        assert_eq!(n, m.ncols(), "abs_det needs a square matrix");
        let mut a = m.clone();
        let mut det = 1.0f64;
        for col in 0..n {
            let mut pivot = col;
            for r in col + 1..n {
                if a[[r, col]].abs() > a[[pivot, col]].abs() {
                    pivot = r;
                }
            }
            if a[[pivot, col]].abs() < 1e-12 {
                return 0.0;
            }
            // Row swaps only flip the sign, which abs() discards.
            if pivot != col {
                for k in 0..n {
                    let tmp = a[[pivot, k]];
                    a[[pivot, k]] = a[[col, k]];
                    a[[col, k]] = tmp;
                }
            }
            det *= a[[col, col]];
            for r in col + 1..n {
                let f = a[[r, col]] / a[[col, col]];
                for k in col..n {
                    let pivot_val = a[[col, k]];
                    a[[r, k]] -= f * pivot_val;
                }
            }
        }
        det.abs()
    }

    /// Runs BKZ with `block_size` and asserts |det| is preserved.
    fn assert_det_preserved(basis: &Array2<f64>, block_size: usize) {
        let bkz = BKZReducer::new(BKZConfig {
            block_size,
            max_tours: 5,
            lll_delta: 0.99,
            svp_max_nodes: 50_000,
        });
        let result = bkz.reduce(basis).expect("BKZ should succeed");
        let before = abs_det(basis);
        let after = abs_det(&result.reduced_basis);
        assert!(
            (before - after).abs() <= 1e-6 * before.max(1.0),
            "BKZ changed the lattice: |det| {} -> {}",
            before,
            after
        );
    }

    #[test]
    fn test_bkz_block2_equivalent_to_lll() {
        let basis = array![[1.0, 1.0, 1.0], [-1.0, 0.0, 2.0], [3.0, 5.0, 6.0]];
        let bkz_config = BKZConfig {
            block_size: 2,
            max_tours: 10,
            lll_delta: 0.75,
            svp_max_nodes: 10_000,
        };
        let bkz = BKZReducer::new(bkz_config);
        let bkz_result = bkz.reduce(&basis).expect("BKZ should succeed");
        let lll = LLLReducer::new(LLLConfig::default());
        let lll_result = lll.reduce(&basis).expect("LLL should succeed");
        assert_eq!(bkz_result.reduced_basis.nrows(), 3);
        assert_eq!(lll_result.reduced_basis.nrows(), 3);
    }

    #[test]
    fn test_bkz_first_vector_not_longer_than_lll() {
        let basis = array![
            [10.0, 3.0, -1.0, 2.0],
            [-4.0, 7.0, 2.0, 1.0],
            [1.0, -2.0, 5.0, 3.0],
            [2.0, 1.0, -3.0, 6.0]
        ];
        let lll = LLLReducer::new(LLLConfig::default());
        let lll_result = lll.reduce(&basis).expect("LLL should succeed");
        let lll_norm: f64 = lll_result
            .reduced_basis
            .row(0)
            .iter()
            .map(|x| x * x)
            .sum::<f64>()
            .sqrt();
        let bkz = BKZReducer::new(BKZConfig {
            block_size: 3,
            max_tours: 5,
            lll_delta: 0.99,
            svp_max_nodes: 50_000,
        });
        let bkz_result = bkz.reduce(&basis).expect("BKZ should succeed");
        assert!(
            bkz_result.first_vector_norm <= lll_norm + 1e-6,
            "BKZ norm {} should be <= LLL norm {}",
            bkz_result.first_vector_norm,
            lll_norm
        );
    }

    #[test]
    fn test_bkz_respects_max_tours() {
        let basis = array![[1.0, 0.0], [0.0, 1.0]];
        let config = BKZConfig {
            block_size: 2,
            max_tours: 3,
            ..Default::default()
        };
        let bkz = BKZReducer::new(config);
        let result = bkz.reduce(&basis).expect("BKZ should succeed");
        assert!(result.n_tours <= 3, "Should not exceed max_tours");
    }

    #[test]
    fn test_bkz_preserves_lattice_determinant() {
        let small = array![[1.0, 1.0, 1.0], [-1.0, 0.0, 2.0], [3.0, 5.0, 6.0]];
        assert_det_preserved(&small, 2);
        let big = array![
            [10.0, 3.0, -1.0, 2.0],
            [-4.0, 7.0, 2.0, 1.0],
            [1.0, -2.0, 5.0, 3.0],
            [2.0, 1.0, -3.0, 6.0]
        ];
        assert_det_preserved(&big, 3);
        assert_det_preserved(&big, 4);
    }

    #[test]
    fn test_bkz_rejects_zero_block_size() {
        let basis = array![[1.0, 0.0], [0.0, 1.0]];
        let bkz = BKZReducer::new(BKZConfig {
            block_size: 0,
            ..Default::default()
        });
        assert!(bkz.reduce(&basis).is_err());
    }
}