use super::halton::halton_direction;
/// Returns the `(mesh, poll)` size pair associated with mesh index `ell`.
///
/// * poll size: `2^(-ell)`
/// * mesh size: `4^(-ell)`, capped at `1.0` so it never exceeds the unit mesh
///   (for negative `ell` the poll size keeps growing while the mesh stays at 1).
pub(crate) fn mesh_and_poll_size(ell: i32) -> (f64, f64) {
    let exponent = -ell;
    let poll_size = 2f64.powi(exponent);
    let mesh_size = f64::min(4f64.powi(exponent), 1.0);
    (mesh_size, poll_size)
}
/// Computes the adjusted Halton direction `q_{t,ℓ}` for index `t`, mesh index
/// `ell` and dimension `n`.
///
/// The point returned by [`halton_direction`] (`u_t`; presumably `n` components
/// in `[0, 1)` — confirm in the `halton` module) is mapped to `w = 2u_t − e` and
/// normalised to the unit vector `c = w / ‖w‖`. The result is `round(α·c)`
/// (component-wise, `f64::round`, i.e. half away from zero) for the largest
/// `α ≥ 0` whose rounded vector still satisfies `‖q‖² ≤ 2^|ell|`, i.e.
/// `‖q‖ ≤ 2^(|ell|/2)`.
///
/// `round(α·c_i)` is a step function of `α` that only jumps at
/// `α = (2j + 1) / (2|c_i|)`, so only those breakpoints are examined.
///
/// Returns an empty vector when the Halton point has no components.
///
/// # Panics
///
/// Panics if `2u_t − e` is the zero vector, for which no direction is defined.
pub(crate) fn adjusted_halton_direction(t: usize, ell: i32, n: usize) -> Vec<i64> {
    let u = halton_direction(t, n);
    let w: Vec<f64> = u.iter().map(|&ui| 2.0 * ui - 1.0).collect();
    if w.is_empty() {
        // Zero-dimensional problem: there is no direction to scale.
        return Vec::new();
    }
    let w_norm = w.iter().map(|wi| wi * wi).sum::<f64>().sqrt();
    // Checked in release builds too: for w = 0 every c_i would be NaN and the
    // breakpoint search below would be meaningless (and, with a plain
    // `bp > alpha_ub` exit test, would never terminate).
    assert!(w_norm > 0.0, "2u_t − e = 0 only for n = t = 1");
    let c: Vec<f64> = w.iter().map(|wi| wi / w_norm).collect();

    let abs_ell = ell.unsigned_abs();
    // Real-valued radius 2^(|ℓ|/2); only used to bound the α search.
    let target = 2f64.powf(f64::from(abs_ell) / 2.0);
    // Exact squared radius 2^|ℓ|. Comparing integers avoids relying on how
    // `powf(..)` and the subsequent squaring round when ‖q‖² == 2^|ℓ|.
    let limit_sq: u128 = 1u128.checked_shl(abs_ell).unwrap_or(u128::MAX);

    let q_at = |alpha: f64| -> Vec<i64> {
        c.iter().map(|ci| (alpha * ci).round() as i64).collect()
    };
    // Squared norm in u128 so it cannot overflow for any realistic component size.
    let norm_sq = |q: &[i64]| -> u128 {
        q.iter()
            .map(|&qi| u128::from(qi.unsigned_abs()).pow(2))
            .sum()
    };

    // At α = (target + 0.5) / max|c_i| the largest component alone rounds to more
    // than `target`, so no larger α can be feasible.
    let c_max = c.iter().fold(0.0_f64, |m, ci| m.max(ci.abs()));
    let alpha_ub = (target + 0.5) / c_max;

    let mut breakpoints: Vec<f64> = Vec::new();
    for aci in c.iter().map(|ci| ci.abs()).filter(|&aci| aci != 0.0) {
        // `take_while` also stops immediately on a NaN breakpoint.
        breakpoints.extend(
            (0u64..)
                .map(|j| (2 * j + 1) as f64 / (2.0 * aci))
                .take_while(|&bp| bp <= alpha_ub),
        );
    }
    // Ties are identical values, so an unstable sort is sufficient.
    breakpoints.sort_unstable_by(|a, b| a.partial_cmp(b).expect("finite breakpoints"));

    let mut best = vec![0i64; n];
    for &bp in &breakpoints {
        // Evaluate just past the breakpoint so the jump at `bp` is included
        // regardless of how the exact half-way product rounds in floating point.
        let q = q_at(bp + bp * 1e-12);
        if norm_sq(&q) > limit_sq {
            // ‖round(α·c)‖ is non-decreasing in α: every later breakpoint fails too.
            break;
        }
        best = q;
    }
    best
}
/// Builds the scaled Householder matrix `H = ‖q‖² I − 2 q qᵀ` from the integer
/// vector `q`.
///
/// The outer `Vec` has `q.len()` entries and entry `[j][i]` is
/// `δ_ij ‖q‖² − 2 q_j q_i`. `H` is symmetric, so each inner `Vec` is both a row
/// and a column. For `q ≠ 0` the columns are pairwise orthogonal and each has
/// squared norm `‖q‖⁴`.
pub(crate) fn householder_basis(q: &[i64]) -> Vec<Vec<i64>> {
    let dim = q.len();
    let scale: i64 = q.iter().map(|&x| x * x).sum();
    let mut basis: Vec<Vec<i64>> = Vec::with_capacity(dim);
    for (j, &qj) in q.iter().enumerate() {
        let mut column: Vec<i64> = Vec::with_capacity(dim);
        for (i, &qi) in q.iter().enumerate() {
            // Only the diagonal carries the ‖q‖² term of the identity part.
            let identity_part = if i == j { scale } else { 0 };
            column.push(identity_part - 2 * qj * qi);
        }
        basis.push(column);
    }
    basis
}
/// Returns the poll directions for index `t` at mesh index `ell` in dimension `n`.
///
/// The Householder matrix is built from the adjusted Halton direction
/// ([`adjusted_halton_direction`] → [`householder_basis`]). The result lists its
/// columns first, followed by the same columns negated in the same order. That
/// gives `2 · q.len()` directions, expected to be `2n`.
pub(crate) fn poll_directions(t: usize, ell: i32, n: usize) -> Vec<Vec<i64>> {
    let basis = householder_basis(&adjusted_halton_direction(t, ell, n));
    let opposites = basis
        .iter()
        .map(|column| column.iter().map(|&x| -x).collect::<Vec<i64>>());
    basis.iter().cloned().chain(opposites).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Squared Euclidean norm of an integer vector.
    fn norm_sq(q: &[i64]) -> i64 {
        q.iter().map(|qi| qi * qi).sum()
    }

    /// Integer dot product (pairs up to the shorter of the two slices).
    fn dot(a: &[i64], b: &[i64]) -> i64 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    #[test]
    fn mesh_sizes_eq_1() {
        assert_eq!(mesh_and_poll_size(0), (1.0, 1.0));
        assert_eq!(mesh_and_poll_size(3), (1.0 / 64.0, 1.0 / 8.0));
        assert_eq!(mesh_and_poll_size(-2), (1.0, 4.0));
    }

    /// Reference directions for n = 4 (presumably Table 2 of the OrthoMADS
    /// paper — confirm): `(t, ℓ, q_{t,ℓ}, ‖q_{t,ℓ}‖²)`.
    #[test]
    fn adjusted_direction_table_2() {
        let rows: [(usize, i32, [i64; 4], i64); 8] = [
            (7, 0, [0, 0, 0, -1], 1),
            (8, 1, [-1, 1, 0, 0], 2),
            (9, 2, [0, -1, 1, -1], 3),
            (10, 3, [-1, -1, -2, 0], 6),
            (11, 4, [2, 2, -2, 1], 13),
            (12, 5, [-3, -4, 0, 2], 29),
            (13, 6, [3, 0, 3, 6], 54),
            (14, 7, [-1, 5, 6, -8], 126),
        ];
        for (t, ell, q_expected, nsq_expected) in rows {
            let q = adjusted_halton_direction(t, ell, 4);
            assert_eq!(q, q_expected.to_vec(), "q_{{{t},{ell}}}");
            assert_eq!(norm_sq(&q), nsq_expected, "‖q_{{{t},{ell}}}‖²");
        }
    }

    /// End-to-end 2-D example: direction, Householder basis, poll set and mesh
    /// distances.
    #[test]
    fn figure_1_full_chain() {
        let q = adjusted_halton_direction(6, 3, 2);
        assert_eq!(q, vec![-1, -2]);
        let h = householder_basis(&q);
        assert_eq!(h[0], vec![3, -4]);
        assert_eq!(h[1], vec![-4, -3]);
        assert_eq!(dot(&h[0], &h[1]), 0);
        assert_eq!(norm_sq(&h[0]), 25);
        assert_eq!(norm_sq(&h[1]), 25);
        assert_eq!((norm_sq(&h[0]) as f64).sqrt(), 5.0);
        let dirs = poll_directions(6, 3, 2);
        assert_eq!(dirs.len(), 4);
        assert_eq!(dirs[2], vec![-3, 4]);
        assert_eq!(dirs[3], vec![4, 3]);
        let (mesh, poll) = mesh_and_poll_size(3);
        for d in &dirs {
            let dist = mesh * (norm_sq(d) as f64).sqrt();
            assert!((dist - 5.0 / 64.0).abs() < 1e-12, "Δᵐ‖d‖ = {dist}");
            assert!(dist < poll);
        }
    }

    /// The first half of the 4-D poll set is an orthogonal basis with equal
    /// column norms, and the second half is exactly its negation.
    #[test]
    fn orthogonal_basis_in_higher_dim() {
        let dirs = poll_directions(11, 4, 4);
        assert_eq!(dirs.len(), 8);
        let nsq = norm_sq(&dirs[0]);
        for i in 0..4 {
            assert_eq!(norm_sq(&dirs[i]), nsq, "equal column norms");
            for j in 0..4 {
                if i != j {
                    assert_eq!(dot(&dirs[i], &dirs[j]), 0, "columns {i},{j} orthogonal");
                }
            }
            let negated: Vec<i64> = dirs[i].iter().map(|&x| -x).collect();
            assert_eq!(dirs[i + 4], negated, "direction {} = −direction {i}", i + 4);
        }
    }

    /// For an arbitrary non-zero q, H = ‖q‖²I − 2qqᵀ has orthogonal columns of
    /// squared norm ‖q‖⁴.
    #[test]
    fn householder_basis_generic_q() {
        let q = [1, 2, 3];
        let h = householder_basis(&q);
        assert_eq!(h.len(), 3);
        let q_nsq = norm_sq(&q);
        for i in 0..3 {
            assert_eq!(norm_sq(&h[i]), q_nsq * q_nsq, "column {i} norm");
            for j in (i + 1)..3 {
                assert_eq!(dot(&h[i], &h[j]), 0, "columns {i},{j} orthogonal");
            }
        }
    }
}