use ndarray::ArrayView2;
/// Estimates `tr(A)` with Hutchinson's stochastic trace estimator.
///
/// Column `j` of `probes` is a probe vector `z_j`, and column `j` of `a_probes` is
/// expected to hold the product `A z_j`. The function only sees these products, so it
/// cannot verify this. The estimate is the sample mean of `z_jᵀ (A z_j)` over the
/// `k = probes.ncols()` probes:
///
/// * For probes with `E[z zᵀ] = I` (e.g. Rademacher vectors) it is an unbiased estimate
///   of `tr(A)`.
/// * For a probe set with `Z Zᵀ = k I` (e.g. the columns of a Hadamard matrix) it is exact,
///   because `Σ_j z_jᵀ A z_j = tr(A Z Zᵀ) = k · tr(A)`.
///
/// Returns `0.0` when there are no probes.
///
/// # Panics
///
/// Panics if `probes` and `a_probes` do not have the same shape. Without this check,
/// extra columns in `a_probes` would be silently ignored, and missing ones would fail
/// with an unhelpful index-out-of-bounds panic.
pub fn hutchinson_trace(probes: ArrayView2<'_, f64>, a_probes: ArrayView2<'_, f64>) -> f64 {
    assert_eq!(
        probes.dim(),
        a_probes.dim(),
        "hutchinson_trace: probes and a_probes must have the same shape"
    );
    let k = probes.ncols();
    if k == 0 {
        return 0.0;
    }
    // Same left-to-right accumulation order as a plain `+=` loop starting at 0.0.
    let total = (0..k).fold(0.0, |acc, j| acc + probes.column(j).dot(&a_probes.column(j)));
    total / k as f64
}
/// Estimates `tr(A)` as an exact trace over a subspace plus a Hutchinson estimate of
/// the remainder.
///
/// # Arguments
///
/// * `q` — basis matrix with `r = q.ncols()` columns `q_i`.
/// * `a_q` — expected to hold `A q_i` in column `i` (same shape as `q`).
/// * `probes` — probe vectors `z_j`, one per column (`k = probes.ncols()`).
/// * `a_probes` — expected to hold `A z_j` in column `j` (same shape as `probes`).
///
/// # Method
///
/// The exact part is `Σ_i q_iᵀ (A q_i) = tr(QᵀAQ)`. Each probe is deflated to
/// `(I − QQᵀ) z_j` and dotted with the *undeflated* product `A z_j`, so no extra
/// products with `A` are needed. The residual part is the sample mean of
/// `z_jᵀ (I − QQᵀ) A z_j`.
///
/// For probes with `E[z zᵀ] = I`, that mean has expectation
/// `tr(A) − tr(QQᵀA) = tr(A) − tr(QᵀAQ)`, so the total is an unbiased estimate of
/// `tr(A)`. The same identity makes the result exact when `Z Zᵀ = k I` (e.g. Hadamard
/// probes). Unbiasedness does not need orthonormal columns in `q`, but `I − QQᵀ` is
/// only a projector when `QᵀQ = I`.
///
/// With no probes (`k == 0`), only the exact part is returned.
///
/// # Panics
///
/// Panics if:
///
/// * `a_q` does not have the same shape as `q`;
/// * `a_probes` does not have the same shape as `probes`;
/// * `k > 0` and `q` and `probes` have different numbers of rows.
pub fn controlled_trace(
    q: ArrayView2<'_, f64>,
    a_q: ArrayView2<'_, f64>,
    probes: ArrayView2<'_, f64>,
    a_probes: ArrayView2<'_, f64>,
) -> f64 {
    assert_eq!(
        q.dim(),
        a_q.dim(),
        "controlled_trace: q and a_q must have the same shape"
    );
    assert_eq!(
        probes.dim(),
        a_probes.dim(),
        "controlled_trace: probes and a_probes must have the same shape"
    );

    // Fold from +0.0 (not `sum()`, whose f64 identity may be -0.0) so an empty basis
    // yields exactly 0.0.
    let exact = (0..q.ncols()).fold(0.0, |acc, i| acc + q.column(i).dot(&a_q.column(i)));

    let k = probes.ncols();
    if k == 0 {
        return exact;
    }
    assert_eq!(
        q.nrows(),
        probes.nrows(),
        "controlled_trace: q and probes must have the same number of rows"
    );

    let q_t = q.t();
    let residual = (0..k).fold(0.0, |acc, j| {
        let z = probes.column(j);
        // Remove the component of z in span(Q): z − Q (Qᵀ z). No owned copy of z is needed.
        let z_def = &z - &q.dot(&q_t.dot(&z));
        acc + z_def.dot(&a_probes.column(j))
    });
    exact + residual / k as f64
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// 4 × 4 Sylvester–Hadamard matrix: entry `(i, j)` is `(-1)^popcount(i & j)`.
    /// It satisfies `H Hᵀ = 4 I`, so its columns form a probe set for which the
    /// Hutchinson average is exact.
    fn hadamard4() -> Array2<f64> {
        Array2::from_shape_fn((4, 4), |(i, j)| {
            if (i & j).count_ones() % 2 == 0 {
                1.0
            } else {
                -1.0
            }
        })
    }

    /// Fixed symmetric positive-definite test matrix `M Mᵀ + 0.5 I`.
    fn spd4() -> Array2<f64> {
        let m = Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 0.2, -0.3, 0.1, 0.0, 1.1, 0.4, -0.2, 0.5, -0.1, 0.9, 0.3, -0.2, 0.3, 0.0, 1.2,
            ],
        )
        .unwrap();
        let mut gram = m.dot(&m.t());
        gram.diag_mut().mapv_inplace(|d| d + 0.5);
        gram
    }

    /// Exact trace: sum of the main diagonal.
    fn trace(a: ArrayView2<'_, f64>) -> f64 {
        a.diag().iter().sum()
    }

    #[test]
    fn hutchinson_with_full_hadamard_probes_is_exact() {
        let a = spd4();
        let z = hadamard4();
        let az = a.dot(&z);
        let expected = trace(a.view());
        let est = hutchinson_trace(z.view(), az.view());
        assert!((est - expected).abs() < 1e-10, "{est} vs {}", expected);
    }

    #[test]
    fn controlled_trace_with_partial_basis_and_hadamard_probes_is_exact() {
        let a = spd4();
        let h = hadamard4();
        // Orthonormal basis from the first two Hadamard columns (each has norm 2).
        let q = Array2::from_shape_fn((4, 2), |(i, j)| h[[i, j]] / 2.0);
        let a_q = a.dot(&q);
        let a_h = a.dot(&h);
        let expected = trace(a.view());
        let est = controlled_trace(q.view(), a_q.view(), h.view(), a_h.view());
        assert!((est - expected).abs() < 1e-10, "{est} vs {}", expected);
    }

    #[test]
    fn full_rank_basis_makes_residual_vanish() {
        let a = spd4();
        let q = hadamard4() / 2.0;
        let a_q = a.dot(&q);
        let no_probes = Array2::<f64>::zeros((4, 0));
        let est = controlled_trace(q.view(), a_q.view(), no_probes.view(), no_probes.view());
        assert!((est - trace(a.view())).abs() < 1e-10);
    }

    #[test]
    fn controlled_reduces_to_hutchinson_with_empty_basis() {
        let a = spd4();
        let z = hadamard4();
        let az = a.dot(&z);
        let (no_q, no_aq) = (Array2::<f64>::zeros((4, 0)), Array2::<f64>::zeros((4, 0)));
        let hutch = hutchinson_trace(z.view(), az.view());
        let controlled = controlled_trace(no_q.view(), no_aq.view(), z.view(), az.view());
        assert!((controlled - hutch).abs() < 1e-12);
    }
}