use alloc::boxed::Box;
use num::Zero;
use num_complex::{Complex, Complex64};
use rand::Rng;
use super::{fft::FastFft, polynomial::Polynomial, samplerz::sampler_z};
use crate::utils::zeroize::{Zeroize, ZeroizeOnDrop};
/// Value handed to `sampler_z` as its third argument by every leaf-level call in `ffsampling`.
///
/// NOTE(review): the name and magnitude suggest this is the `sigma_min` parameter of the
/// Falcon-512 parameter set — confirm against the specification before reusing it elsewhere.
const SIGMIN: f64 = 1.2778336969128337;
/// Computes the Gram matrix `G = B · B*` of a 2×2 matrix of polynomials.
///
/// `b` is read in row-major order (`[b00, b01, b10, b11]`) and the result uses the same layout.
/// Entry `(i, j)` is `Σ_k b[i][k] ⊙ conj(b[j][k])`, where `⊙` is the coefficient-wise
/// (Hadamard) product and `conj` is applied to every coefficient.
///
/// NOTE(review): Hadamard products only correspond to polynomial multiplication when the inputs
/// are in an evaluation (FFT) representation — presumably callers guarantee that; confirm.
pub fn gram(b: [Polynomial<Complex64>; 4]) -> [Polynomial<Complex64>; 4] {
    const N: usize = 2;

    // Accumulate each entry by value so the running sum is moved rather than cloned on every
    // step; the additions happen in the same order (k = 0, then k = 1) as a plain nested loop.
    let entry = |row: usize, col: usize| {
        (0..N).fold(Polynomial::<Complex64>::zero(), |acc, k| {
            acc + b[N * row + k].hadamard_mul(&b[N * col + k].map(Complex::conj))
        })
    };

    [entry(0, 0), entry(0, 1), entry(1, 0), entry(1, 1)]
}
/// Computes the LDL* decomposition of a 2×2 matrix of polynomials, coefficient by coefficient.
///
/// `g` is read in row-major order (`[g00, g01, g10, g11]`); `g01` is not used. The return value
/// is `(l10, d00, d11)` such that, with `L = [[1, 0], [l10, 1]]` and `D = diag(d00, d11)`:
///
/// * `l10 = g10 / g00` (coefficient-wise division),
/// * `d00 = g00`,
/// * `d11 = g11 - g00 ⊙ |l10|²`.
///
/// There is no guard against zero coefficients in `g00`; the outcome in that case is whatever
/// `hadamard_div` produces.
pub fn ldl(
    g: [Polynomial<Complex64>; 4],
) -> (Polynomial<Complex64>, Polynomial<Complex64>, Polynomial<Complex64>) {
    // The array is owned, so take its entries by move instead of cloning them.
    let [g00, _, g10, g11] = g;

    let l10 = g10.hadamard_div(&g00);
    // |l10|² per coefficient, kept as a complex value with zero imaginary part.
    let l10_squared_norm = l10.map(|c| c * c.conj());
    let d11 = g11 - g00.hadamard_mul(&l10_squared_norm);

    (l10, g00, d11)
}
/// Binary tree built by `ffldl` and walked by `normalize_tree` and `ffsampling`.
#[derive(Debug, Clone)]
pub enum LdlTree {
    /// Internal node: the `l10` factor returned by `ldl` for this level, followed by the subtree
    /// derived from `d00` (left) and the subtree derived from `d11` (right).
    Branch(Polynomial<Complex64>, Box<LdlTree>, Box<LdlTree>),
    /// Terminal node holding two complex values. `ffldl` fills it with the two coefficients of a
    /// diagonal factor; `normalize_tree` then overwrites entry 0 with `sigma / sqrt(re)` and
    /// entry 1 with zero. `ffsampling` reads only the real part of entry 0.
    Leaf([Complex64; 2]),
}
impl Zeroize for LdlTree {
    /// Overwrites every complex value stored in this node, and in all of its descendants, with
    /// zero.
    ///
    /// Each store is a volatile write and each arm ends with a `SeqCst` compiler fence, so the
    /// compiler may neither drop the writes as dead stores nor reorder later code before them.
    fn zeroize(&mut self) {
        use core::sync::atomic::{compiler_fence, Ordering};

        let zero = Complex64::new(0.0, 0.0);
        match self {
            LdlTree::Leaf(values) => {
                values.iter_mut().for_each(|slot| {
                    // SAFETY: `slot` is a unique, valid and aligned reference handed out by
                    // `iter_mut`, so writing through it as a raw pointer is sound.
                    unsafe { core::ptr::write_volatile(slot, zero) }
                });
                compiler_fence(Ordering::SeqCst);
            },
            LdlTree::Branch(ell, left, right) => {
                ell.coefficients.iter_mut().for_each(|slot| {
                    // SAFETY: `slot` is a unique, valid and aligned reference handed out by
                    // `iter_mut`, so writing through it as a raw pointer is sound.
                    unsafe { core::ptr::write_volatile(slot, zero) }
                });
                // Wipe the subtrees after this node's own polynomial, left before right.
                left.zeroize();
                right.zeroize();
                compiler_fence(Ordering::SeqCst);
            },
        }
    }
}
impl Drop for LdlTree {
fn drop(&mut self) {
self.zeroize();
}
}
// Marker impl with no items of its own: the wiping on drop is carried out by the `Drop` impl
// for `LdlTree`.
impl ZeroizeOnDrop for LdlTree {}
/// Recursively builds the [`LdlTree`] of a 2×2 Gram matrix given in row-major order.
///
/// At every level the matrix is decomposed with `ldl`; `l10` becomes the payload of a `Branch`.
/// While the polynomials have more than two coefficients, each diagonal factor `d` is split
/// with `split_fft` into `(d_left, d_right)` and the child matrix
/// `[d_left, d_right, conj(d_right), d_left]` is decomposed in turn (`d00` for the left
/// subtree, `d11` for the right one). Otherwise the two coefficients of `d00` and `d11` are
/// stored directly in leaves.
///
/// # Panics
///
/// Panics if the recursion bottoms out on polynomials that do not have exactly two
/// coefficients (for example when the input length is 0 or 1).
pub fn ffldl(gram_matrix: [Polynomial<Complex64>; 4]) -> LdlTree {
    let n = gram_matrix[0].coefficients.len();
    let (l10, d00, d11) = ldl(gram_matrix);

    if n > 2 {
        // Builds the child Gram matrix of one diagonal factor. The right half is moved into
        // the matrix after its conjugate has been taken, so only the left half needs a clone.
        let child_gram = |d: &Polynomial<Complex64>| {
            let (d_left, d_right) = d.split_fft();
            let d_right_conj = d_right.map(Complex::conj);
            [d_left.clone(), d_right, d_right_conj, d_left]
        };

        let left = ffldl(child_gram(&d00));
        let right = ffldl(child_gram(&d11));
        LdlTree::Branch(l10, Box::new(left), Box::new(right))
    } else {
        let leaf00 = d00
            .coefficients
            .try_into()
            .expect("ffldl: base case requires d00 to have exactly 2 coefficients");
        let leaf11 = d11
            .coefficients
            .try_into()
            .expect("ffldl: base case requires d11 to have exactly 2 coefficients");
        LdlTree::Branch(l10, Box::new(LdlTree::Leaf(leaf00)), Box::new(LdlTree::Leaf(leaf11)))
    }
}
/// Rewrites every leaf of `tree` in place so that it stores a standard deviation.
///
/// For each leaf, entry 0 becomes `sigma / sqrt(re(entry 0))` (with zero imaginary part) and
/// entry 1 becomes zero. Branch payloads are left untouched.
pub fn normalize_tree(tree: &mut LdlTree, sigma: f64) {
    match tree {
        LdlTree::Leaf([first, second]) => {
            let scaled = sigma / first.re.sqrt();
            *first = Complex::new(scaled, 0.0);
            *second = Complex64::zero();
        },
        LdlTree::Branch(_, left, right) => {
            for child in [left, right] {
                normalize_tree(child, sigma);
            }
        },
    }
}
/// Samples a pair of polynomials `(z0, z1)` around the target `t`, guided by `tree`.
///
/// * At a `Branch`, the second target component is split with `split_fft`, sampled recursively
///   against the right subtree and merged back into `z1`. The first component is then shifted
///   to `t0 + (t1 - z1) ⊙ ell`, split, sampled against the left subtree and merged into `z0`.
/// * At a `Leaf`, `sampler_z` is called once per component — first for `t.0`, then for `t.1` —
///   centred on the real part of coefficient 0, with the real part of the leaf's entry 0 as
///   the second argument and `SIGMIN` as the third. Each result is wrapped in a
///   one-coefficient polynomial.
///
/// The order in which randomness is drawn from `rng` (right subtree before left) is part of
/// the observable behaviour for a seeded generator.
pub fn ffsampling<R: Rng>(
    t: &(Polynomial<Complex64>, Polynomial<Complex64>),
    tree: &LdlTree,
    mut rng: &mut R,
) -> (Polynomial<Complex64>, Polynomial<Complex64>) {
    let (t0, t1) = t;

    match tree {
        LdlTree::Leaf(value) => {
            let sigma = value[0].re;
            let z0 = sampler_z(t0.coefficients[0].re, sigma, SIGMIN, &mut rng);
            let z1 = sampler_z(t1.coefficients[0].re, sigma, SIGMIN, &mut rng);

            let z0_poly = Polynomial::new(vec![Complex64::new(z0 as f64, 0.0)]);
            let z1_poly = Polynomial::new(vec![Complex64::new(z1 as f64, 0.0)]);
            (z0_poly, z1_poly)
        },
        LdlTree::Branch(ell, left, right) => {
            // Second component first: it feeds the adjustment of the first one.
            let t1_halves = t1.split_fft();
            let (z1_low, z1_high) = ffsampling(&t1_halves, right, rng);
            let z1 = Polynomial::<Complex64>::merge_fft(&z1_low, &z1_high);

            // t0' = t0 + (t1 - z1) ⊙ ell
            let residual = t1.clone() - z1.clone();
            let t0_adjusted = t0.clone() + residual.hadamard_mul(ell);

            let t0_halves = t0_adjusted.split_fft();
            let (z0_low, z0_high) = ffsampling(&t0_halves, left, rng);
            let z0 = Polynomial::<Complex64>::merge_fft(&z0_low, &z0_high);

            (z0, z1)
        },
    }
}
#[cfg(test)]
mod tests {
    use num_complex::Complex64;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha20Rng;

    use super::*;

    /// Rebuilds `[g00, g01, g10, g11]` from the factors returned by `ldl`, i.e. the product
    /// `L · D · L*` with `L = [[1, 0], [l10, 1]]` and `D = diag(d00, d11)`.
    fn reconstruct_g(
        l10: &Polynomial<Complex64>,
        d00: &Polynomial<Complex64>,
        d11: &Polynomial<Complex64>,
    ) -> [Polynomial<Complex64>; 4] {
        let l10_conj = l10.map(Complex::conj);
        let l10_d00 = l10.hadamard_mul(d00);
        let g11 = l10_d00.hadamard_mul(&l10_conj) + d11.clone();

        [d00.clone(), d00.hadamard_mul(&l10_conj), l10_d00, g11]
    }

    /// Draws a random 2×2 Hermitian matrix of `n`-coefficient polynomials: real diagonal
    /// entries, a complex `g01`, and `g10 = conj(g01)`. All parts lie in `[-10, 10)`.
    fn random_hermitian_matrix(n: usize, rng: &mut impl Rng) -> [Polynomial<Complex64>; 4] {
        let zero = Complex64::new(0.0, 0.0);
        let mut g00 = vec![zero; n];
        let mut g01 = vec![zero; n];
        let mut g11 = vec![zero; n];

        // Per index the draw order is: g00, g11, then the real and imaginary parts of g01.
        for ((diag0, off_diag), diag1) in g00.iter_mut().zip(g01.iter_mut()).zip(g11.iter_mut())
        {
            *diag0 = Complex64::new(rng.random_range(-10.0..10.0), 0.0);
            *diag1 = Complex64::new(rng.random_range(-10.0..10.0), 0.0);
            *off_diag =
                Complex64::new(rng.random_range(-10.0..10.0), rng.random_range(-10.0..10.0));
        }

        let g10 = g01.iter().map(Complex::conj).collect();

        [
            Polynomial::new(g00),
            Polynomial::new(g01),
            Polynomial::new(g10),
            Polynomial::new(g11),
        ]
    }

    /// Returns `true` when both polynomials have the same number of coefficients and every
    /// pair of coefficients differs by less than `eps` in both real and imaginary part.
    fn polynomials_approx_eq(
        a: &Polynomial<Complex64>,
        b: &Polynomial<Complex64>,
        eps: f64,
    ) -> bool {
        let (lhs, rhs) = (&a.coefficients, &b.coefficients);

        lhs.len() == rhs.len()
            && lhs.iter().zip(rhs.iter()).all(|(x, y)| {
                let re_close = (x.re - y.re).abs() < eps;
                let im_close = (x.im - y.im).abs() < eps;
                re_close && im_close
            })
    }

    /// `ldl` followed by `reconstruct_g` must give back the original matrix for several sizes.
    #[test]
    fn test_ldl_decomposition_random() {
        const TOLERANCE: f64 = 1e-10;

        let mut rng = ChaCha20Rng::from_seed([42u8; 32]);

        for degree in [1, 2, 16, 512] {
            let g = random_hermitian_matrix(degree, &mut rng);
            let (l10, d00, d11) = ldl(g.clone());
            let [r00, r01, r10, r11] = reconstruct_g(&l10, &d00, &d11);

            assert!(
                polynomials_approx_eq(&r00, &g[0], TOLERANCE),
                "degree {degree}: G[0,0] mismatch"
            );
            assert!(
                polynomials_approx_eq(&r01, &g[1], TOLERANCE),
                "degree {degree}: G[0,1] mismatch"
            );
            assert!(
                polynomials_approx_eq(&r10, &g[2], TOLERANCE),
                "degree {degree}: G[1,0] mismatch"
            );
            assert!(
                polynomials_approx_eq(&r11, &g[3], TOLERANCE),
                "degree {degree}: G[1,1] mismatch"
            );
        }
    }
}