use super::estimator::Laplacian;
use super::jet::{directional_derivatives, taylor_jet_2nd_with_buf};
use super::pipeline::estimate;
use super::types::{EstimatorResult, WelfordAccumulator};
use crate::bytecode_tape::BytecodeTape;
use crate::taylor::Taylor;
use crate::Float;
/// Estimates the Laplacian at `x` by averaging over the supplied directions.
///
/// Every second-order term returned by [`directional_derivatives`] is doubled,
/// and the doubled terms are averaged over `directions.len()`.
///
/// Returns `(value, laplacian_estimate)`, where `value` is the first element
/// returned by `directional_derivatives` (presumably the function value at
/// `x` — confirm against that helper).
///
/// # Panics
///
/// Panics if `directions` is empty, or if `2.0` or `directions.len()` cannot
/// be converted to `F`.
pub fn laplacian<F: Float>(tape: &BytecodeTape<F>, x: &[F], directions: &[&[F]]) -> (F, F) {
    assert!(!directions.is_empty(), "directions must not be empty");
    let (value, _, second_order) = directional_derivatives(tape, x, directions);
    let two = F::from(2.0).unwrap();
    let count = F::from(directions.len()).unwrap();

    // Accumulate 2 * c2 over every direction, then take the mean.
    let mut total = F::zero();
    for &c2 in second_order.iter() {
        total = total + two * c2;
    }
    (value, total / count)
}
/// Runs the generic [`estimate`] pipeline with the [`Laplacian`] estimator.
///
/// `tape`, `x` and `directions` are forwarded unchanged; the returned
/// [`EstimatorResult`] is whatever the pipeline produces for this estimator.
pub fn laplacian_with_stats<F: Float>(
    tape: &BytecodeTape<F>,
    x: &[F],
    directions: &[&[F]],
) -> EstimatorResult<F> {
    let estimator = &Laplacian;
    estimate(estimator, tape, x, directions)
}
/// Laplacian estimate with a diagonal control variate.
///
/// For each direction `v` the sample fed into the running statistics is
///
/// ```text
/// 2 * c2  -  sum_i control_diagonal[i] * v[i]^2  +  sum_i control_diagonal[i]
/// ```
///
/// where `c2` is the third value returned by [`taylor_jet_2nd_with_buf`].
///
/// The returned [`EstimatorResult`] carries the mean, sample variance and
/// standard error produced by the [`WelfordAccumulator`], with `num_samples`
/// set to `directions.len()`. Its `value` field is the first jet output from
/// the last direction processed.
///
/// Note: the control-variate sum pairs entries with `zip`, so it only covers
/// the shorter of `control_diagonal` and the direction slice.
///
/// # Panics
///
/// Panics if `directions` is empty, if `control_diagonal.len()` differs from
/// `tape.num_inputs()`, or if `2.0` cannot be converted to `F`.
pub fn laplacian_with_control<F: Float>(
    tape: &BytecodeTape<F>,
    x: &[F],
    directions: &[&[F]],
    control_diagonal: &[F],
) -> EstimatorResult<F> {
    assert!(!directions.is_empty(), "directions must not be empty");
    assert_eq!(
        control_diagonal.len(),
        tape.num_inputs(),
        "control_diagonal.len() must match tape.num_inputs()"
    );
    let two = F::from(2.0).unwrap();

    // Sum of the control diagonal; this is added back to every sample.
    let mut trace_control = F::zero();
    for &d in control_diagonal {
        trace_control = trace_control + d;
    }

    let mut scratch = Vec::new();
    let mut value = F::zero();
    let mut stats = WelfordAccumulator::new();
    for direction in directions {
        let (c0, _, c2) = taylor_jet_2nd_with_buf(tape, x, direction, &mut scratch);
        value = c0;

        // Control term evaluated along this direction: sum_i d_i * v_i^2.
        let mut control = F::zero();
        for (&d, &vi) in control_diagonal.iter().zip(direction.iter()) {
            control = control + d * vi * vi;
        }
        stats.update(two * c2 - control + trace_control);
    }

    let (estimate, sample_variance, standard_error) = stats.finalize();
    EstimatorResult {
        value,
        estimate,
        sample_variance,
        standard_error,
        num_samples: directions.len(),
    }
}
pub fn hessian_diagonal<F: Float>(tape: &BytecodeTape<F>, x: &[F]) -> (F, Vec<F>) {
let mut buf = Vec::new();
hessian_diagonal_with_buf(tape, x, &mut buf)
}
/// Computes one diagonal entry per input by pushing each coordinate unit
/// vector through [`taylor_jet_2nd_with_buf`] and doubling the third value it
/// returns.
///
/// `buf` is the scratch buffer handed to every jet evaluation.
///
/// Returns `(value, diag)` where `diag` has `tape.num_inputs()` entries and
/// `value` is the first jet output from the last coordinate processed.
///
/// When the tape has no inputs, the tape is evaluated once with an empty
/// input slice; `value` is read from `tape.output_index()` (zero if that
/// index is not present) and `diag` is empty.
///
/// # Panics
///
/// Panics if `x.len()` differs from `tape.num_inputs()`, or if `2.0` cannot
/// be converted to `F`.
pub fn hessian_diagonal_with_buf<F: Float>(
    tape: &BytecodeTape<F>,
    x: &[F],
    buf: &mut Vec<Taylor<F, 3>>,
) -> (F, Vec<F>) {
    let n = tape.num_inputs();
    assert_eq!(x.len(), n, "x.len() must match tape.num_inputs()");
    let two = F::from(2.0).unwrap();

    // No inputs: there is nothing to differentiate, so only evaluate the tape.
    if n == 0 {
        let mut values_buf = Vec::new();
        tape.forward_into(&[], &mut values_buf);
        let value = match values_buf.get(tape.output_index()) {
            Some(&v) => v,
            None => F::zero(),
        };
        return (value, Vec::new());
    }

    let mut value = F::zero();
    let mut diag = Vec::with_capacity(n);
    // One reusable direction vector: switch on a single coordinate, evaluate,
    // then switch it off again before moving to the next coordinate.
    let mut unit = vec![F::zero(); n];
    for axis in 0..n {
        unit[axis] = F::one();
        let (c0, _, c2) = taylor_jet_2nd_with_buf(tape, x, &unit, buf);
        unit[axis] = F::zero();
        value = c0;
        diag.push(two * c2);
    }
    (value, diag)
}
/// Orthonormalises `columns` in place and returns how many vectors survive.
///
/// Each candidate column has its components along the already accepted
/// vectors subtracted one at a time. If the remaining norm is below `epsilon`
/// the column is discarded with `swap_remove`, which moves the last column
/// into its slot, so the processing order of later columns can change.
/// Otherwise the column is scaled to unit length and accepted.
///
/// On return `columns.len()` equals the returned rank.
fn modified_gram_schmidt<F: Float>(columns: &mut Vec<Vec<F>>, epsilon: F) -> usize {
    let mut accepted = 0;
    let mut cursor = 0;
    while cursor < columns.len() {
        // Split the storage so accepted vectors can be read while the
        // candidate at `cursor` is being modified.
        let (head, tail) = columns.split_at_mut(cursor);
        let candidate = &mut tail[0];
        for basis_vec in head.iter().take(accepted) {
            let mut coeff = F::zero();
            for (&a, &b) in basis_vec.iter().zip(candidate.iter()) {
                coeff = coeff + a * b;
            }
            for (c, &q) in candidate.iter_mut().zip(basis_vec.iter()) {
                *c = *c - coeff * q;
            }
        }

        let mut norm_sq = F::zero();
        for &entry in candidate.iter() {
            norm_sq = norm_sq + entry * entry;
        }
        let norm = norm_sq.sqrt();

        if norm < epsilon {
            // Numerically dependent on the accepted vectors: drop it and
            // re-examine whatever column was swapped into this slot.
            columns.swap_remove(cursor);
        } else {
            let scale = F::one() / norm;
            for entry in columns[cursor].iter_mut() {
                *entry = *entry * scale;
            }
            if cursor != accepted {
                columns.swap(cursor, accepted);
            }
            accepted += 1;
            cursor += 1;
        }
    }
    columns.truncate(accepted);
    accepted
}
/// Two-stage Laplacian estimate: an "exact" part over an orthonormal basis
/// plus a stochastic part over the remaining directions.
///
/// Steps performed:
/// 1. For every sketch direction, take the second value returned by
///    `tape.hvp_with_buf` (presumably the Hessian-vector product — confirm
///    against the tape API) and collect these as columns.
/// 2. Orthonormalise the columns with `modified_gram_schmidt`, using
///    `sqrt(F::epsilon())` as the drop threshold.
/// 3. Sum `2 * c2` over the surviving basis vectors, where `c2` is the third
///    value returned by [`taylor_jet_2nd_with_buf`].
/// 4. For every stochastic direction, subtract its components along the
///    basis, evaluate the jet along the remainder, and feed `2 * c2` into a
///    [`WelfordAccumulator`].
///
/// The returned `estimate` is the step-3 sum plus the step-4 mean.
/// `sample_variance` and `standard_error` come from the step-4 samples only,
/// and `num_samples` is `stochastic_directions.len()`. `value` is the first
/// jet output from the last evaluation performed.
///
/// # Panics
///
/// Panics if either direction set is empty, if any direction's length differs
/// from `tape.num_inputs()`, or if `2.0` cannot be converted to `F`.
pub fn laplacian_hutchpp<F: Float>(
    tape: &BytecodeTape<F>,
    x: &[F],
    sketch_directions: &[&[F]],
    stochastic_directions: &[&[F]],
) -> EstimatorResult<F> {
    assert!(
        !sketch_directions.is_empty(),
        "sketch_directions must not be empty"
    );
    assert!(
        !stochastic_directions.is_empty(),
        "stochastic_directions must not be empty"
    );
    let n = tape.num_inputs();
    let two = F::from(2.0).unwrap();
    let eps = F::epsilon().sqrt();

    // Stage 1: one column per sketch direction.
    let mut dual_vals_buf = Vec::new();
    let mut adjoint_buf = Vec::new();
    let mut basis: Vec<Vec<F>> = Vec::with_capacity(sketch_directions.len());
    for sketch in sketch_directions {
        assert_eq!(
            sketch.len(),
            n,
            "sketch direction length must match tape.num_inputs()"
        );
        let (_gradient, column) =
            tape.hvp_with_buf(x, sketch, &mut dual_vals_buf, &mut adjoint_buf);
        basis.push(column);
    }

    // Stage 2: orthonormalise; near-dependent columns are dropped.
    let rank = modified_gram_schmidt(&mut basis, eps);

    // Stage 3: deterministic contribution from the basis vectors.
    let mut taylor_buf = Vec::new();
    let mut value = F::zero();
    let mut exact_trace = F::zero();
    for basis_vec in basis.iter().take(rank) {
        let (c0, _, c2) = taylor_jet_2nd_with_buf(tape, x, basis_vec, &mut taylor_buf);
        value = c0;
        exact_trace = exact_trace + two * c2;
    }

    // Stage 4: stochastic contribution from directions with the basis
    // components removed.
    let mut stats = WelfordAccumulator::new();
    let mut residual = vec![F::zero(); n];
    for g in stochastic_directions {
        assert_eq!(
            g.len(),
            n,
            "stochastic direction length must match tape.num_inputs()"
        );
        residual.copy_from_slice(g);
        for basis_vec in basis.iter().take(rank) {
            // Coefficients are taken against the original direction `g`,
            // not the partially projected remainder.
            let mut coeff = F::zero();
            for (&a, &b) in basis_vec.iter().zip(g.iter()) {
                coeff = coeff + a * b;
            }
            for (r, &qv) in residual.iter_mut().zip(basis_vec.iter()) {
                *r = *r - coeff * qv;
            }
        }
        let (c0, _, c2) = taylor_jet_2nd_with_buf(tape, x, &residual, &mut taylor_buf);
        value = c0;
        stats.update(two * c2);
    }

    let (residual_mean, sample_variance, standard_error) = stats.finalize();
    EstimatorResult {
        value,
        estimate: exact_trace + residual_mean,
        sample_variance,
        standard_error,
        num_samples: stochastic_directions.len(),
    }
}