#![allow(dead_code)]
use super::fft::{Cplx, Fft, add_fft, adj_fft, div_fft, mul_fft, sub_fft};
use super::fpr::Fpr;
use super::sampler::{SamplerRng, sampler_z};
use alloc::boxed::Box;
use alloc::vec::Vec;
/// 2×2 matrix of polynomials in FFT representation, indexed `[row][col]`.
/// Each entry is a `Vec<Cplx>` as built by [`gram`] and consumed by [`ffldl`].
pub(crate) type Gram = [[Vec<Cplx>; 2]; 2];
/// Normalized LDL* tree produced by [`ffldl`] and walked by [`ff_sampling`].
pub(crate) enum FftTree {
    /// Leaf holding the standard deviation handed to `sampler_z` for both
    /// coordinates at that position (computed by `ffldl` as `sigma / sqrt(d)`).
    Leaf(Fpr),
    /// Internal node for one level of the LDL* decomposition.
    Node {
        /// Off-diagonal factor `L10 = G10 / G00` of this level's decomposition.
        l10: Vec<Cplx>,
        /// Subtree derived from `D00` (a leaf at the base case); sampled second.
        left: Box<FftTree>,
        /// Subtree derived from `D11` (a leaf at the base case); sampled first.
        right: Box<FftTree>,
    },
}
/// Computes the Gram matrix `G = B · B*` of a 2×2 matrix of FFT-domain polynomials.
///
/// Entry `(i, j)` is `Σ_k b[i][k] ⊙ adj(b[j][k])`, accumulated element-wise
/// starting from a zero vector whose length is that of `b[0][0]`.
pub(crate) fn gram(b: &[[Vec<Cplx>; 2]; 2]) -> Gram {
    let len = b[0][0].len();
    // One Gram entry: fold the two k-terms into a zero-initialised accumulator.
    let entry = |row: usize, col: usize| -> Vec<Cplx> {
        (0..2).fold(vec_zero(len), |mut sum, k| {
            let prod = mul_fft(&b[row][k], &adj_fft(&b[col][k]));
            sum.iter_mut().zip(prod).for_each(|(s, p)| *s = s.add(p));
            sum
        })
    };
    [[entry(0, 0), entry(0, 1)], [entry(1, 0), entry(1, 1)]]
}
/// Returns a vector holding `m` copies of `Cplx::zero()`.
fn vec_zero(m: usize) -> Vec<Cplx> {
    core::iter::repeat(Cplx::zero()).take(m).collect()
}
pub(crate) fn ffldl(fft: &Fft, g: &Gram, sigma: Fpr) -> FftTree {
let m = g[0][0].len();
let d00 = g[0][0].clone();
let l10 = div_fft(&g[1][0], &g[0][0]);
let tmp = mul_fft(&mul_fft(&l10, &adj_fft(&l10)), &g[0][0]);
let d11 = sub_fft(&g[1][1], &tmp);
if m > 2 {
let (d00a, d00b) = fft.split_fft(&d00);
let (d11a, d11b) = fft.split_fft(&d11);
let g0: Gram = [[d00a.clone(), d00b.clone()], [adj_fft(&d00b), d00a]];
let g1: Gram = [[d11a.clone(), d11b.clone()], [adj_fft(&d11b), d11a]];
FftTree::Node {
l10,
left: Box::new(ffldl(fft, &g0, sigma)),
right: Box::new(ffldl(fft, &g1, sigma)),
}
} else {
let leaf0 = sigma.div(d00[0].re.sqrt());
let leaf1 = sigma.div(d11[0].re.sqrt());
FftTree::Node {
l10,
left: Box::new(FftTree::Leaf(leaf0)),
right: Box::new(FftTree::Leaf(leaf1)),
}
}
}
/// Fast-Fourier sampling over a tree produced by [`ffldl`].
///
/// Given a target `(t0, t1)` in FFT representation, returns a pair `(z0, z1)`
/// (also in FFT representation) whose values come from `sampler_z` draws.
///
/// * At a leaf, each target is a single value; `z0` is drawn centred on
///   `Re(t0[0])` and then `z1` on `Re(t1[0])`, both with the leaf's sigma and
///   `sigmin`.
/// * At a node, `t1` is split and sampled against the right subtree first.
///   `t0` is then shifted by `(t1 − z1) ⊙ l10` before being split and sampled
///   against the left subtree.
pub(crate) fn ff_sampling<R: SamplerRng>(
    fft: &Fft,
    t0: &[Cplx],
    t1: &[Cplx],
    tree: &FftTree,
    sigmin: Fpr,
    rng: &mut R,
) -> (Vec<Cplx>, Vec<Cplx>) {
    let (l10, left, right) = match tree {
        FftTree::Node { l10, left, right } => (l10, left, right),
        FftTree::Leaf(sigma) => {
            // Base case: lift each sampled integer into a one-element FFT vector.
            let to_poly = |z| alloc::vec![Cplx::new(Fpr::of_i64(z), Fpr::from_f64(0.0))];
            let z0 = sampler_z(t0[0].re, *sigma, sigmin, rng);
            let z1 = sampler_z(t1[0].re, *sigma, sigmin, rng);
            return (to_poly(z0), to_poly(z1));
        }
    };

    // Second coordinate first, using the right (D11) subtree.
    let (t1_0, t1_1) = fft.split_fft(t1);
    let (z1_0, z1_1) = ff_sampling(fft, &t1_0, &t1_1, right, sigmin, rng);
    let z1 = fft.merge_fft_pub(&z1_0, &z1_1);

    // Correct the first target by the error made on the second, weighted by L10.
    let shifted_t0 = add_fft(t0, &mul_fft(&sub_fft(t1, &z1), l10));
    let (t0_0, t0_1) = fft.split_fft(&shifted_t0);
    let (z0_0, z0_1) = ff_sampling(fft, &t0_0, &t0_1, left, sigmin, rng);
    let z0 = fft.merge_fft_pub(&z0_0, &z0_1);

    (z0, z1)
}
// Unit tests for this module live in the sibling file `tree_tests.rs`
// and are compiled only in test builds.
#[cfg(test)]
#[path = "tree_tests.rs"]
mod tree_tests;