/// Scalar type that carries a truncated Taylor jet through arithmetic.
///
/// Implementors propagate a value together with derivative information with
/// respect to `K` primary variables. Elementary arithmetic is supplied by the
/// implementor; smooth scalar functions (`exp`, `sqrt`, ...) are expressed through
/// [`JetScalar::compose_unary`], which receives the outer function's value and its
/// first four derivatives evaluated at `self.value()`.
pub trait JetScalar<const K: usize>: Copy {
    /// Jet representing the constant `c`.
    fn constant(c: f64) -> Self;
    /// Jet seeded as primary variable number `axis`, evaluated at `x`.
    fn variable(x: f64, axis: usize) -> Self;
    /// Value (zeroth-order) part of the jet.
    fn value(&self) -> f64;
    /// Sum `self + o`.
    fn add(&self, o: &Self) -> Self;
    /// Difference `self - o`.
    fn sub(&self, o: &Self) -> Self;
    /// Product `self * o`.
    fn mul(&self, o: &Self) -> Self;
    /// Negation `-self`.
    fn neg(&self) -> Self;
    /// Multiplication by the plain scalar `s`.
    fn scale(&self, s: f64) -> Self;
    /// Composes an outer scalar function `f` with `self`.
    ///
    /// `d` holds `[f(u), f'(u), f''(u), f'''(u), f''''(u)]` at `u = self.value()`.
    /// Implementations may ignore entries beyond the order they track.
    fn compose_unary(&self, d: [f64; 5]) -> Self;

    /// `exp(self)`: every derivative of `exp` equals its value.
    fn exp(&self) -> Self {
        let e = self.value().exp();
        self.compose_unary([e; 5])
    }

    /// `sqrt(self)`, with the k-th derivative of `sqrt(u)` written in terms of
    /// `root = sqrt(u)`.
    ///
    /// No domain check is performed: at `u = 0` the derivative coefficients are
    /// infinite, and for `u < 0` every coefficient is NaN.
    fn sqrt(&self) -> Self {
        let u = self.value();
        let root = u.sqrt();
        // Same association order as `u * u * root` / `u * u * u * root`.
        let u_sq = u * u;
        let u_cu = u_sq * u;
        self.compose_unary([
            root,
            0.5 / root,
            -0.25 / (u * root),
            0.375 / (u_sq * root),
            -0.9375 / (u_cu * root),
        ])
    }
}
/// Runs `iters` fixed-slope updates `a <- a - f(a) * inv_fa`, starting from the
/// constant jet `a0`.
///
/// The start point carries no derivative content of its own. Any derivative
/// information in the result enters through `f`, for example through jets it
/// captures. The same fixed map is applied to the value and to every derivative
/// channel. With `iters == 0`, the constant jet for `a0` is returned.
///
/// NOTE(review): `inv_fa` is presumably the inverse of `f`'s slope at the root,
/// evaluated once by the caller. Confirm against call sites.
pub fn filtered_implicit_solve_scalar<const K: usize, S: JetScalar<K>>(
    a0: f64,
    inv_fa: f64,
    iters: usize,
    f: impl Fn(&S) -> S,
) -> S {
    (0..iters).fold(S::constant(a0), |current, _| {
        let step = f(&current).scale(inv_fa);
        current.sub(&step)
    })
}
/// Second-order jet: value, gradient and Hessian with respect to `K` primaries.
///
/// A thin newtype over `super::jet_tower::Tower2` so it can implement [`JetScalar`].
#[derive(Clone, Copy, Debug)]
pub struct Order2<const K: usize>(pub super::jet_tower::Tower2<K>);
impl<const K: usize> Order2<K> {
    /// Gradient part of the jet, copied out of the underlying tower.
    #[inline]
    pub fn g(&self) -> [f64; K] {
        let Self(tower) = self;
        tower.g
    }
    /// Hessian part of the jet, copied out of the underlying tower.
    #[inline]
    pub fn h(&self) -> [[f64; K]; K] {
        let Self(tower) = self;
        tower.h
    }
}
/// Forwards jet arithmetic to the wrapped `Tower2`.
impl<const K: usize> JetScalar<K> for Order2<K> {
    fn constant(c: f64) -> Self {
        Self(super::jet_tower::Tower2::constant(c))
    }
    fn variable(x: f64, axis: usize) -> Self {
        Self(super::jet_tower::Tower2::variable(x, axis))
    }
    fn value(&self) -> f64 {
        let Self(tower) = self;
        tower.v
    }
    fn add(&self, o: &Self) -> Self {
        Self(self.0 + o.0)
    }
    /// Subtraction as addition of the right operand scaled by `-1`.
    fn sub(&self, o: &Self) -> Self {
        let negated = o.0.scale(-1.0);
        Self(self.0 + negated)
    }
    fn mul(&self, o: &Self) -> Self {
        let (lhs, rhs) = (&self.0, &o.0);
        Self(super::jet_tower::Tower2::mul(lhs, rhs))
    }
    fn neg(&self) -> Self {
        Self(self.0.scale(-1.0))
    }
    fn scale(&self, s: f64) -> Self {
        Self(self.0.scale(s))
    }
    /// A second-order jet only needs `f`, `f'` and `f''`; `d[3]` and `d[4]` are ignored.
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        let [f0, f1, f2, ..] = d;
        Self(self.0.compose_unary([f0, f1, f2]))
    }
}
/// Second-order jet augmented with one directional channel `eps`.
///
/// Arithmetic treats the pair as `base + eps·ε` with `ε² = 0`. `eps` therefore
/// carries the jet of the derivative along the seeded direction, and `eps.h()` is
/// the third-derivative tensor contracted once with that direction.
#[derive(Clone, Copy, Debug)]
pub struct OneSeed<const K: usize> {
    /// Jet of the expression itself.
    pub base: Order2<K>,
    /// Jet of the directional derivative along the seed direction.
    pub eps: Order2<K>,
}
impl<const K: usize> OneSeed<K> {
    /// Seeds primary `axis` at `x`; `u_axis` is the seed direction's component on that axis.
    pub fn seed_direction(x: f64, axis: usize, u_axis: f64) -> Self {
        let base = Order2::variable(x, axis);
        let eps = Order2::constant(u_axis);
        Self { base, eps }
    }
    /// Third-derivative tensor contracted with the seed direction (the Hessian of `eps`).
    pub fn contracted_third(&self) -> [[f64; K]; K] {
        let Self { eps, .. } = self;
        eps.h()
    }
}
/// First-order dual arithmetic over `Order2` coefficients: `base + eps·ε`, `ε² = 0`.
impl<const K: usize> JetScalar<K> for OneSeed<K> {
    fn constant(c: f64) -> Self {
        Self {
            base: Order2::constant(c),
            eps: Order2::constant(0.0),
        }
    }
    fn variable(x: f64, axis: usize) -> Self {
        Self {
            base: Order2::variable(x, axis),
            eps: Order2::constant(0.0),
        }
    }
    fn value(&self) -> f64 {
        self.base.value()
    }
    fn add(&self, o: &Self) -> Self {
        let Self { base, eps } = self;
        Self {
            base: base.add(&o.base),
            eps: eps.add(&o.eps),
        }
    }
    fn sub(&self, o: &Self) -> Self {
        let Self { base, eps } = self;
        Self {
            base: base.sub(&o.base),
            eps: eps.sub(&o.eps),
        }
    }
    /// Product rule: `(a + a'ε)(b + b'ε) = ab + (a·b' + a'·b)ε`.
    fn mul(&self, o: &Self) -> Self {
        let base_times_eps = self.base.mul(&o.eps);
        let eps_times_base = self.eps.mul(&o.base);
        Self {
            base: self.base.mul(&o.base),
            eps: base_times_eps.add(&eps_times_base),
        }
    }
    fn neg(&self) -> Self {
        let Self { base, eps } = self;
        Self {
            base: base.neg(),
            eps: eps.neg(),
        }
    }
    fn scale(&self, s: f64) -> Self {
        let Self { base, eps } = self;
        Self {
            base: base.scale(s),
            eps: eps.scale(s),
        }
    }
    /// Chain rule: `f(a + a'ε) = f(a) + f'(a)·a'ε`.
    ///
    /// `f'(base)` is built as its own jet by shifting the coefficients down by one.
    /// The trailing slot is padding that `Order2::compose_unary` never reads.
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        let [_, d1, d2, d3, d4] = d;
        let outer = self.base.compose_unary(d);
        let slope = self.base.compose_unary([d1, d2, d3, d4, d4]);
        Self {
            base: outer,
            eps: slope.mul(&self.eps),
        }
    }
}
/// Second-order jet augmented with two directional channels and their mixed term.
///
/// Arithmetic treats the value as `base + eps·ε + del·δ + eps_del·εδ` with
/// `ε² = δ² = 0`. `eps` and `del` follow the derivatives along the two seed
/// directions. `eps_del.h()` is the fourth-derivative tensor contracted with both
/// directions.
#[derive(Clone, Copy, Debug)]
pub struct TwoSeed<const K: usize> {
    /// Jet of the expression itself.
    pub base: Order2<K>,
    /// Jet of the derivative along the first seed direction.
    pub eps: Order2<K>,
    /// Jet of the derivative along the second seed direction.
    pub del: Order2<K>,
    /// Jet of the mixed second derivative along both seed directions.
    pub eps_del: Order2<K>,
}
impl<const K: usize> TwoSeed<K> {
    /// Seeds primary `axis` at `x`. `u_axis` and `v_axis` are the two seed
    /// directions' components on that axis.
    pub fn seed(x: f64, axis: usize, u_axis: f64, v_axis: f64) -> Self {
        Self {
            base: Order2::variable(x, axis),
            eps: Order2::constant(u_axis),
            del: Order2::constant(v_axis),
            eps_del: Order2::constant(0.0),
        }
    }
    /// Fourth-derivative tensor contracted with both seed directions (the Hessian of `eps_del`).
    pub fn contracted_fourth(&self) -> [[f64; K]; K] {
        let Self { eps_del, .. } = self;
        eps_del.h()
    }
}
impl<const K: usize> JetScalar<K> for TwoSeed<K> {
fn constant(c: f64) -> Self {
TwoSeed {
base: Order2::constant(c),
eps: Order2::constant(0.0),
del: Order2::constant(0.0),
eps_del: Order2::constant(0.0),
}
}
fn variable(x: f64, axis: usize) -> Self {
TwoSeed {
base: Order2::variable(x, axis),
eps: Order2::constant(0.0),
del: Order2::constant(0.0),
eps_del: Order2::constant(0.0),
}
}
fn value(&self) -> f64 {
self.base.value()
}
fn add(&self, o: &Self) -> Self {
TwoSeed {
base: self.base.add(&o.base),
eps: self.eps.add(&o.eps),
del: self.del.add(&o.del),
eps_del: self.eps_del.add(&o.eps_del),
}
}
fn sub(&self, o: &Self) -> Self {
TwoSeed {
base: self.base.sub(&o.base),
eps: self.eps.sub(&o.eps),
del: self.del.sub(&o.del),
eps_del: self.eps_del.sub(&o.eps_del),
}
}
fn mul(&self, o: &Self) -> Self {
let a = self;
let b = o;
let base = a.base.mul(&b.base);
let eps = a.base.mul(&b.eps).add(&a.eps.mul(&b.base));
let del = a.base.mul(&b.del).add(&a.del.mul(&b.base));
let eps_del = a
.base
.mul(&b.eps_del)
.add(&a.eps.mul(&b.del))
.add(&a.del.mul(&b.eps))
.add(&a.eps_del.mul(&b.base));
TwoSeed {
base,
eps,
del,
eps_del,
}
}
fn neg(&self) -> Self {
TwoSeed {
base: self.base.neg(),
eps: self.eps.neg(),
del: self.del.neg(),
eps_del: self.eps_del.neg(),
}
}
fn scale(&self, s: f64) -> Self {
TwoSeed {
base: self.base.scale(s),
eps: self.eps.scale(s),
del: self.del.scale(s),
eps_del: self.eps_del.scale(s),
}
}
fn compose_unary(&self, d: [f64; 5]) -> Self {
let base = self.base.compose_unary([d[0], d[1], d[2], d[3], d[4]]);
let fprime = self.base.compose_unary([d[1], d[2], d[3], d[4], d[4]]); let fsecond = self.base.compose_unary([d[2], d[3], d[4], d[4], d[4]]); let eps = fprime.mul(&self.eps);
let del = fprime.mul(&self.del);
let eps_del = fsecond
.mul(&self.eps)
.mul(&self.del)
.add(&fprime.mul(&self.eps_del));
TwoSeed {
base,
eps,
del,
eps_del,
}
}
}
/// Lets `Tower4` be used directly as a [`JetScalar`]. Every operation forwards to
/// the tower's own operators or associated functions.
///
/// NOTE(review): the `Tower4::name(..)` paths below resolve to `Tower4`'s inherent
/// items, which take precedence over these trait methods. This assumes
/// `jet_tower` defines them; if it did not, the calls would recurse into this impl.
impl<const K: usize> JetScalar<K> for super::jet_tower::Tower4<K> {
    fn constant(c: f64) -> Self {
        super::jet_tower::Tower4::constant(c)
    }
    fn variable(x: f64, axis: usize) -> Self {
        super::jet_tower::Tower4::variable(x, axis)
    }
    fn value(&self) -> f64 {
        self.v
    }
    fn add(&self, o: &Self) -> Self {
        let (lhs, rhs) = (*self, *o);
        lhs + rhs
    }
    fn sub(&self, o: &Self) -> Self {
        let (lhs, rhs) = (*self, *o);
        lhs - rhs
    }
    fn mul(&self, o: &Self) -> Self {
        super::jet_tower::Tower4::mul(self, o)
    }
    /// Negation as scaling by `-1`, through the same `Tower4::scale` that `scale` uses.
    fn neg(&self) -> Self {
        super::jet_tower::Tower4::scale(self, -1.0)
    }
    fn scale(&self, s: f64) -> Self {
        super::jet_tower::Tower4::scale(self, s)
    }
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        super::jet_tower::Tower4::compose_unary(self, d)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::families::jet_tower::{evaluate_program, RowNllProgram, Tower4};

    /// Test expression `(exp(p0·p1) + 2)·sqrt(p0² + 1) − p1²/2`.
    ///
    /// It is written once against `JetScalar`, so every jet type evaluates the
    /// same program.
    fn row_expr<S: JetScalar<2>>(p: &[S; 2]) -> S {
        let [x, y] = p;
        let growth = x.mul(y).exp().add(&S::constant(2.0));
        let root = x.mul(x).add(&S::constant(1.0)).sqrt();
        let penalty = y.mul(y).scale(0.5);
        growth.mul(&root).sub(&penalty)
    }

    /// Single-row program whose row NLL is `row_expr` at the fixed primaries `p`.
    struct ExprProgram {
        p: [f64; 2],
    }

    impl ExprProgram {
        /// Rejects any row index past the single row this program has.
        fn check_row(&self, row: usize) -> Result<(), String> {
            if row < self.n_rows() {
                Ok(())
            } else {
                Err(format!("ExprProgram: row {row} out of range"))
            }
        }
    }

    impl RowNllProgram<2> for ExprProgram {
        fn n_rows(&self) -> usize {
            1
        }
        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
            self.check_row(row)?;
            Ok(self.p)
        }
        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
            self.check_row(row)?;
            Ok(row_expr(p))
        }
    }

    const SEED: [f64; 2] = [0.37, -0.81];
    const U: [f64; 2] = [0.6, -0.2];
    const V: [f64; 2] = [-0.4, 1.1];
    const TOL: f64 = 1e-10;

    /// Mixed absolute/relative comparison with tolerance `TOL`.
    fn close(a: f64, b: f64, label: &str) {
        let magnitude = a.abs().max(b.abs());
        let band = TOL + TOL * magnitude;
        let gap = (a - b).abs();
        assert!(
            gap <= band,
            "{label}: {a:+.15e} vs {b:+.15e} (band {band:.3e})"
        );
    }

    /// Compares two 2×2 grids row-major, labelling each entry `{label}[a][b]`.
    fn close_grid(
        label: &str,
        got: impl Fn(usize, usize) -> f64,
        want: impl Fn(usize, usize) -> f64,
    ) {
        for a in 0..2 {
            for b in 0..2 {
                close(got(a, b), want(a, b), &format!("{label}[{a}][{b}]"));
            }
        }
    }

    /// Reference fourth-order tower produced through the program seam.
    fn tower() -> Tower4<2> {
        evaluate_program(&ExprProgram { p: SEED }, 0).expect("tower")
    }

    #[test]
    fn order2_matches_tower_value_grad_hessian() {
        let t = tower();
        let vars: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
        let s = row_expr(&vars);
        close(s.value(), t.v, "value");
        let (grad, hess) = (s.g(), s.h());
        for a in 0..2 {
            close(grad[a], t.g[a], &format!("grad[{a}]"));
            for b in 0..2 {
                close(hess[a][b], t.h[a][b], &format!("hess[{a}][{b}]"));
            }
        }
    }

    #[test]
    fn one_seed_matches_tower_third_contracted() {
        let t = tower();
        let truth = t.third_contracted(&U);
        let vars: [OneSeed<2>; 2] =
            std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
        let s = row_expr(&vars);
        close(s.value(), t.v, "value");
        let base_hess = s.base.h();
        close_grid("base hess", |a, b| base_hess[a][b], |a, b| t.h[a][b]);
        let third = s.contracted_third();
        close_grid("third", |a, b| third[a][b], |a, b| truth[a][b]);
    }

    #[test]
    fn two_seed_matches_tower_fourth_contracted() {
        let t = tower();
        let truth4 = t.fourth_contracted(&U, &V);
        let truth3_u = t.third_contracted(&U);
        let truth3_v = t.third_contracted(&V);
        let vars: [TwoSeed<2>; 2] =
            std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
        let s = row_expr(&vars);
        close(s.value(), t.v, "value");
        let (grad, base_hess) = (s.base.g(), s.base.h());
        let (eps_hess, del_hess) = (s.eps.h(), s.del.h());
        for a in 0..2 {
            close(grad[a], t.g[a], &format!("grad[{a}]"));
            for b in 0..2 {
                close(base_hess[a][b], t.h[a][b], &format!("base hess[{a}][{b}]"));
                close(eps_hess[a][b], truth3_u[a][b], &format!("eps third_u[{a}][{b}]"));
                close(del_hess[a][b], truth3_v[a][b], &format!("del third_v[{a}][{b}]"));
            }
        }
        let fourth = s.contracted_fourth();
        close_grid("fourth", |a, b| fourth[a][b], |a, b| truth4[a][b]);
    }

    #[test]
    fn generic_program_seam_matches_tower_for_every_channel() {
        let t = tower();

        let o2: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
        close(row_expr(&o2).value(), t.v, "seam order2 value");

        let os: [OneSeed<2>; 2] =
            std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
        let third = row_expr(&os).contracted_third();
        let truth3 = t.third_contracted(&U);
        close_grid("seam third", |a, b| third[a][b], |a, b| truth3[a][b]);

        let ts: [TwoSeed<2>; 2] =
            std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
        let fourth = row_expr(&ts).contracted_fourth();
        let truth4 = t.fourth_contracted(&U, &V);
        close_grid("seam fourth", |a, b| fourth[a][b], |a, b| truth4[a][b]);
    }

    #[test]
    fn tower4_as_jetscalar_matches_program_tower_all_channels() {
        let t = tower();
        let vars: [Tower4<2>; 2] = std::array::from_fn(|a| Tower4::variable(SEED[a], a));
        let s = row_expr(&vars);
        close(s.v, t.v, "tower-jetscalar value");
        for a in 0..2 {
            close(s.g[a], t.g[a], &format!("tower-jetscalar grad[{a}]"));
            for b in 0..2 {
                close(s.h[a][b], t.h[a][b], &format!("tower-jetscalar hess[{a}][{b}]"));
                for c in 0..2 {
                    let t3_label = format!("tower-jetscalar t3[{a}][{b}][{c}]");
                    close(s.t3[a][b][c], t.t3[a][b][c], &t3_label);
                    for d in 0..2 {
                        let t4_label = format!("tower-jetscalar t4[{a}][{b}][{c}][{d}]");
                        close(s.t4[a][b][c][d], t.t4[a][b][c][d], &t4_label);
                    }
                }
            }
        }
    }
}