use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier for the per-element Bhattacharyya primitive.
///
/// Used as the generator ident of the emitted `Node::Region` and passed to
/// `crate::invalid_output_program` when the builder rejects its arguments.
pub const OP_ID: &str = "vyre-primitives::math::bhattacharyya_coefficient";
#[must_use]
pub fn bhattacharyya_per_element(p: &str, q: &str, out_per_elem: &str, n: u32) -> Program {
if n == 0 {
return crate::invalid_output_program(
OP_ID,
out_per_elem,
DataType::U32,
format!("Fix: bhattacharyya_per_element requires n > 0, got {n}."),
);
}
let t = Expr::InvocationId { axis: 0 };
let isqrt_inline = |val_var: &'static str| {
let mut steps = vec![
Node::let_bind(val_var, Expr::load(p, t.clone())),
Node::let_bind(
"x",
Expr::select(
Expr::eq(Expr::var(val_var), Expr::u32(0)),
Expr::u32(1),
Expr::var(val_var),
),
),
];
for _ in 0..4 {
steps.push(Node::assign(
"x",
Expr::shr(
Expr::add(
Expr::var("x"),
Expr::div(Expr::var(val_var), Expr::var("x")),
),
Expr::u32(1),
),
));
}
steps
};
let mut body_inner = isqrt_inline("pv");
body_inner.push(Node::let_bind("xp", Expr::var("x")));
body_inner.push(Node::let_bind("qv", Expr::load(q, t.clone())));
body_inner.push(Node::let_bind(
"y",
Expr::select(
Expr::eq(Expr::var("qv"), Expr::u32(0)),
Expr::u32(1),
Expr::var("qv"),
),
));
for _ in 0..4 {
body_inner.push(Node::assign(
"y",
Expr::shr(
Expr::add(Expr::var("y"), Expr::div(Expr::var("qv"), Expr::var("y"))),
Expr::u32(1),
),
));
}
body_inner.push(Node::store(
out_per_elem,
t.clone(),
Expr::shr(Expr::mul(Expr::var("xp"), Expr::var("y")), Expr::u32(16)),
));
let body = vec![Node::if_then(Expr::lt(t.clone(), Expr::u32(n)), body_inner)];
Program::wrapped(
vec![
BufferDecl::storage(p, 0, BufferAccess::ReadOnly, DataType::U32).with_count(n),
BufferDecl::storage(q, 1, BufferAccess::ReadOnly, DataType::U32).with_count(n),
BufferDecl::storage(out_per_elem, 2, BufferAccess::ReadWrite, DataType::U32)
.with_count(n),
],
[256, 1, 1],
vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(body),
}],
)
}
/// CPU reference for the Bhattacharyya coefficient `sum_i sqrt(p_i * q_i)`.
///
/// Each entry is clamped to be non-negative before the product is taken, so
/// negative entries contribute zero. Inputs are paired positionally; when the
/// slices differ in length the surplus tail of the longer one is ignored.
#[must_use]
pub fn bhattacharyya_coefficient_cpu(p: &[f64], q: &[f64]) -> f64 {
    let overlap_terms = p.iter().zip(q.iter()).map(|(&mass_p, &mass_q)| {
        let clamped_p = mass_p.max(0.0);
        let clamped_q = mass_q.max(0.0);
        (clamped_p * clamped_q).sqrt()
    });
    overlap_terms.sum()
}
/// CPU reference for the Fisher-Rao distance `2 * acos(BC(p, q))`.
///
/// The Bhattacharyya coefficient is clamped into `[0, 1]` before `acos` is
/// applied, so a coefficient above 1 maps to a distance of zero instead of
/// NaN.
#[must_use]
pub fn fisher_rao_distance_cpu(p: &[f64], q: &[f64]) -> f64 {
    let overlap = bhattacharyya_coefficient_cpu(p, q);
    let cos_half_angle = overlap.clamp(0.0, 1.0);
    let half_angle = cos_half_angle.acos();
    2.0 * half_angle
}
/// CPU reference for one step along the Amari alpha-family interpolation
/// between `p` and `q`.
///
/// `t` is clamped into `[0, 1]` and weights `p`; `1 - t` weights `q`. Inputs
/// are paired positionally and the output has the length of the shorter
/// slice. The per-element formula depends on `alpha` (compared with a
/// tolerance of `1e-12`):
///
/// * `alpha == 1`  - geometric mix `p^t * q^(1 - t)`.
/// * `alpha == -1` - linear mix `t * p + (1 - t) * q`.
/// * `alpha == 0`  - square of the linear mix of square roots.
/// * otherwise     - `(t * p^beta + (1 - t) * q^beta)^(1 / beta)` with
///   `beta = (1 - alpha) / 2`.
///
/// Every branch that takes a power or root clamps its inputs to be
/// non-negative first, so slightly negative entries (for example rounding
/// noise) contribute zero instead of producing NaN. The linear branch uses
/// the inputs as given.
#[must_use]
pub fn amari_alpha_step_cpu(p: &[f64], q: &[f64], alpha: f64, t: f64) -> Vec<f64> {
    /// Tolerance used to recognise the special `alpha` values.
    const ALPHA_TOL: f64 = 1e-12;

    let t = t.clamp(0.0, 1.0);
    let s = 1.0 - t;
    let pairs = p.iter().zip(q.iter());

    // `alpha` is the same for every element, so the formula is chosen once
    // here rather than inside the per-element closure.
    if (alpha - 1.0).abs() < ALPHA_TOL {
        // Clamp before `powf`: a negative base with a fractional exponent is
        // NaN, which the other power-based branches already avoid.
        pairs
            .map(|(&pi, &qi)| pi.max(0.0).powf(t) * qi.max(0.0).powf(s))
            .collect()
    } else if (alpha + 1.0).abs() < ALPHA_TOL {
        pairs.map(|(&pi, &qi)| t * pi + s * qi).collect()
    } else if alpha.abs() < ALPHA_TOL {
        pairs
            .map(|(&pi, &qi)| {
                let blended = t * pi.max(0.0).sqrt() + s * qi.max(0.0).sqrt();
                blended * blended
            })
            .collect()
    } else {
        let beta = (1.0 - alpha) / 2.0;
        let inv_beta = 1.0 / beta;
        pairs
            .map(|(&pi, &qi)| {
                let blended = t * pi.max(0.0).powf(beta) + s * qi.max(0.0).powf(beta);
                blended.powf(inv_beta)
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative-plus-absolute closeness check used by the float assertions.
    fn nearly_equal(lhs: f64, rhs: f64) -> bool {
        let tolerance = 1e-6 * (1.0 + lhs.abs() + rhs.abs());
        (lhs - rhs).abs() < tolerance
    }

    /// A distribution is at distance zero from itself.
    #[test]
    fn cpu_self_distance_is_zero() {
        let dist = [0.5, 0.3, 0.2];
        let distance = fisher_rao_distance_cpu(&dist, &dist);
        assert!(nearly_equal(distance, 0.0));
    }

    /// Distributions with disjoint support are at distance pi.
    #[test]
    fn cpu_orthogonal_distance_is_pi() {
        let first = [1.0, 0.0];
        let second = [0.0, 1.0];
        let distance = fisher_rao_distance_cpu(&first, &second);
        assert!(nearly_equal(distance, std::f64::consts::PI));
    }

    /// Swapping the arguments does not change the coefficient.
    #[test]
    fn cpu_bhattacharyya_symmetric() {
        let first = [0.4, 0.6];
        let second = [0.7, 0.3];
        let forward = bhattacharyya_coefficient_cpu(&first, &second);
        let backward = bhattacharyya_coefficient_cpu(&second, &first);
        assert!(nearly_equal(forward, backward));
    }

    /// Length mismatches truncate to the shorter input and `t` is clamped.
    #[test]
    fn cpu_mismatched_inputs_truncate_and_t_clamps() {
        assert_eq!(bhattacharyya_coefficient_cpu(&[1.0], &[]), 0.0);
        let stepped = amari_alpha_step_cpu(&[1.0, 2.0], &[3.0], -1.0, 2.0);
        assert_eq!(stepped, [1.0]);
    }

    /// `alpha = -1` yields the linear mixture.
    #[test]
    fn cpu_amari_alpha_neg_one_recovers_linear_mix() {
        let first = [1.0, 0.0];
        let second = [0.0, 1.0];
        let mixed = amari_alpha_step_cpu(&first, &second, -1.0, 0.25);
        assert!(nearly_equal(mixed[0], 0.25));
        assert!(nearly_equal(mixed[1], 0.75));
    }

    /// `alpha = 1` yields the geometric mixture.
    #[test]
    fn cpu_amari_alpha_one_recovers_geometric_mix() {
        let first = [0.5, 0.5];
        let second = [0.5, 0.5];
        let mixed = amari_alpha_step_cpu(&first, &second, 1.0, 0.5);
        assert!(nearly_equal(mixed[0], 0.5));
        assert!(nearly_equal(mixed[1], 0.5));
    }

    /// `alpha = 0` returns `q` at `t = 0` and `p` at `t = 1`.
    #[test]
    fn cpu_amari_alpha_zero_recovers_spherical_slerp() {
        let first = [1.0, 0.0];
        let second = [0.0, 1.0];
        let at_start = amari_alpha_step_cpu(&first, &second, 0.0, 0.0);
        let at_end = amari_alpha_step_cpu(&first, &second, 0.0, 1.0);
        assert!(nearly_equal(at_start[0], 0.0) && nearly_equal(at_start[1], 1.0));
        assert!(nearly_equal(at_end[0], 1.0) && nearly_equal(at_end[1], 0.0));
    }

    /// The IR program declares its buffers in `p`, `q`, `out` order.
    #[test]
    fn ir_program_buffer_layout() {
        let program = bhattacharyya_per_element("p", "q", "out", 16);
        assert_eq!(program.workgroup_size, [256, 1, 1]);
        let declared: Vec<&str> = program.buffers.iter().map(|decl| decl.name()).collect();
        assert_eq!(declared, ["p", "q", "out"]);
    }

    /// A zero element count produces a trapping program.
    #[test]
    fn zero_n_traps() {
        let program = bhattacharyya_per_element("p", "q", "out", 0);
        assert!(program.stats().trap());
    }
}