/// Top-order Faà di Bruno contraction over two distinct directions.
///
/// `{0, 1}` has two set partitions:
/// - the two singletons `{0}{1}`, weighted by `m[1]`;
/// - the single block `{0, 1}`, weighted by `m[0]`.
///
/// In general `m[k - 1]` multiplies partitions with `k` blocks. Each block is
/// looked up in `q` by the bitmask of its members (direction `i` ↔ bit `1 << i`).
/// `q[0]` (the empty block) is never read.
#[inline(always)]
pub(crate) fn faa_top2(m: [f64; 2], q: &[f64; 4]) -> f64 {
    const D0: usize = 0b01;
    const D1: usize = 0b10;

    let [w1, w2] = m;
    // k = 2: each direction in its own block.
    let split = w2 * q[D0] * q[D1];
    // k = 1: both directions share one block.
    let merged = w1 * q[D0 | D1];
    split + merged
}
/// Top-order Faà di Bruno contraction over three distinct directions.
///
/// This sums `m[k - 1] * Π q[block]` over all 5 set partitions of the three
/// directions, where `k` is the number of blocks. The partitions group by
/// block count as 1 + 3 + 1. Each block is addressed in `q` by the bitmask of
/// its members (direction `i` ↔ bit `1 << i`). `q[0]` is never read.
#[inline(always)]
pub(crate) fn faa_top3(m: [f64; 3], q: &[f64; 8]) -> f64 {
    const D0: usize = 0b001;
    const D1: usize = 0b010;
    const D2: usize = 0b100;

    // k = 3: three singletons.
    let three_blocks = q[D0] * q[D1] * q[D2];
    // k = 2: one pair plus the leftover singleton.
    let two_blocks = q[D0 | D1] * q[D2] + q[D0 | D2] * q[D1] + q[D1 | D2] * q[D0];
    // k = 1: all three directions in one block.
    let one_block = q[D0 | D1 | D2];

    let [w1, w2, w3] = m;
    w3 * three_blocks + w2 * two_blocks + w1 * one_block
}
/// Top-order Faà di Bruno contraction over four distinct directions.
///
/// This sums `m[k - 1] * Π q[block]` over all 15 set partitions of the four
/// directions, where `k` is the number of blocks. The partitions group by
/// block count as 1 + 7 + 6 + 1 (Stirling numbers of the second kind
/// `S(4, k)`). Each block is addressed in `q` by the bitmask of its members
/// (direction `i` ↔ bit `1 << i`). `q[0]` is never read.
#[inline(always)]
pub(crate) fn faa_top4(m: [f64; 4], q: &[f64; 16]) -> f64 {
    const D0: usize = 0b0001;
    const D1: usize = 0b0010;
    const D2: usize = 0b0100;
    const D3: usize = 0b1000;

    // k = 4: every direction in its own block.
    let four_blocks = q[D0] * q[D1] * q[D2] * q[D3];

    // k = 3: one pair and two singletons (one term per pair).
    let three_blocks = q[D0 | D1] * q[D2] * q[D3]
        + q[D0 | D2] * q[D1] * q[D3]
        + q[D0 | D3] * q[D1] * q[D2]
        + q[D1 | D2] * q[D0] * q[D3]
        + q[D1 | D3] * q[D0] * q[D2]
        + q[D2 | D3] * q[D0] * q[D1];

    // k = 2: the three pair-pair splits, then the four triple-singleton splits.
    let two_blocks = q[D0 | D1] * q[D2 | D3]
        + q[D0 | D2] * q[D1 | D3]
        + q[D0 | D3] * q[D1 | D2]
        + q[D0 | D1 | D2] * q[D3]
        + q[D0 | D1 | D3] * q[D2]
        + q[D0 | D2 | D3] * q[D1]
        + q[D1 | D2 | D3] * q[D0];

    // k = 1: a single block holding all four directions.
    let one_block = q[D0 | D1 | D2 | D3];

    let [w1, w2, w3, w4] = m;
    w4 * four_blocks + w3 * three_blocks + w2 * two_blocks + w1 * one_block
}
#[cfg(test)]
mod oracle_tests {
    use super::*;
    use gam_math::jet_algebra::faa_di_bruno;

    /// Number of random cases each unrolled kernel is checked against.
    const CASES: usize = 500;

    /// Returns a deterministic generator of pseudo-random `f64`s in `[-1, 1)`.
    ///
    /// The generator advances a 64-bit linear congruential state (Knuth's MMIX
    /// multiplier and increment). It maps the top 53 bits of the state onto
    /// `[0, 1)`, then rescales to `[-1, 1)`.
    fn stream(seed: u64) -> impl FnMut() -> f64 {
        const MUL: u64 = 6364136223846793005;
        const INC: u64 = 1442695040888963407;
        let denom = (1u64 << 53) as f64;
        let mut state = seed;
        move || {
            state = state.wrapping_mul(MUL).wrapping_add(INC);
            let unit = (state >> 11) as f64 / denom;
            unit * 2.0 - 1.0
        }
    }

    /// Computes the reference value with the generic `faa_di_bruno` walker.
    ///
    /// The walker runs over positions `0..n`. Each block resolves to `q[mask]`,
    /// where `mask` has bit `p` set for every position `p` in the block.
    fn walker_top(n: usize, derivs: &[f64], q: &[f64]) -> f64 {
        let positions: Vec<usize> = (0..n).collect();
        faa_di_bruno(&positions, derivs, |block| {
            let mask = block.iter().map(|&p| 1usize << p).fold(0usize, |acc, bit| acc | bit);
            q[mask]
        })
    }

    /// Checks `kernel` against `walker_top` on `CASES` random inputs.
    ///
    /// Inputs come from `stream(seed)`: first the `N` weights, then `q[1..Q]` in
    /// mask order. `q[0]` stays zero. On the first mismatch beyond a
    /// `1e-12 * max(|want|, 1)` tolerance, it panics with `label` in the message.
    fn check_against_walker<const N: usize, const Q: usize>(
        label: &str,
        seed: u64,
        kernel: fn([f64; N], &[f64; Q]) -> f64,
    ) {
        let mut next = stream(seed);
        for _ in 0..CASES {
            let mut m = [0.0; N];
            for slot in &mut m {
                *slot = next();
            }
            let mut q = [0.0; Q];
            for slot in q.iter_mut().skip(1) {
                *slot = next();
            }
            // `derivs[k]` holds `m[k - 1]`. Index 0 is zero padding, presumably
            // because the walker indexes `derivs` by block count (TODO confirm).
            let derivs: Vec<f64> = std::iter::once(0.0).chain(m.iter().copied()).collect();
            let got = kernel(m, &q);
            let want = walker_top(N, &derivs, &q);
            assert!(
                (got - want).abs() <= 1e-12 * want.abs().max(1.0),
                "{label} {got:+.17e} vs walker {want:+.17e}"
            );
        }
    }

    #[test]
    fn faa_top2_matches_runtime_walker() {
        check_against_walker::<2, 4>("faa_top2", 0x2, faa_top2);
    }

    #[test]
    fn faa_top3_matches_runtime_walker() {
        check_against_walker::<3, 8>("faa_top3", 0x3, faa_top3);
    }

    #[test]
    fn faa_top4_matches_runtime_walker() {
        check_against_walker::<4, 16>("faa_top4", 0x4, faa_top4);
    }
}