use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Chart that maps rows of non-negative weights to coordinates in the tangent plane of the
/// unit sphere of square-root ("half-density") vectors.
///
/// Both fields are produced by [`SphereTangentEmbedding::fit`] and are read-only afterwards.
#[derive(Clone, Debug)]
pub struct SphereTangentEmbedding {
// Unit vector of length V: the normalized mean of the fitted half-density rows.
basepoint: Array1<f64>,
// (V, V - 1) matrix built by `tangent_basis_orthogonal_to` from `basepoint`; its columns
// span the directions orthogonal to the basepoint.
tangent_basis: Array2<f64>,
}
impl SphereTangentEmbedding {
/// Fits the chart to a batch of probability rows and returns it together with the rows'
/// tangent coordinates.
///
/// Each row of `prob_rows` is rescaled to unit mass and replaced by its element-wise square
/// root, which is a unit vector `q`. The basepoint is the normalized mean of those vectors,
/// and the returned `(n, V - 1)` array holds `√2 · q · tangent_basis` for every row.
///
/// # Errors
/// Returns `Err` when there are no rows or fewer than two columns, when an entry is
/// non-finite or negative, when a row has no positive mass, when a row's mass overflows to
/// infinity, or when the mean half-density has zero norm.
pub fn fit(prob_rows: ArrayView2<'_, f64>) -> Result<(Self, Array2<f64>), String> {
    let (n, v) = prob_rows.dim();
    if n == 0 || v < 2 {
        return Err(format!(
            "SphereTangentEmbedding::fit: need n ≥ 1 rows and V ≥ 2 tokens; got ({n}, {v})"
        ));
    }
    let mut q = Array2::<f64>::zeros((n, v));
    let mut mean = Array1::<f64>::zeros(v);
    for i in 0..n {
        let row = prob_rows.row(i);
        let mut sum = 0.0_f64;
        for &value in row.iter() {
            if !(value.is_finite() && value >= 0.0) {
                return Err(format!(
                    "SphereTangentEmbedding::fit: row {i} has a non-finite or negative \
                     probability entry ({value})"
                ));
            }
            sum += value;
        }
        if !(sum > 0.0) {
            return Err(format!(
                "SphereTangentEmbedding::fit: row {i} sums to {sum}; a behavioral summary \
                 must have positive mass"
            ));
        }
        // Finite entries can still add up to +inf. That would make the scale factor below
        // exactly zero and silently turn the row into the zero vector, so reject it.
        if !sum.is_finite() {
            return Err(format!(
                "SphereTangentEmbedding::fit: row {i} sums to {sum}; the row mass overflowed \
                 and cannot be normalized"
            ));
        }
        let inv_sqrt_sum = 1.0 / sum.sqrt();
        let mut q_row = q.row_mut(i);
        for (j, &value) in row.iter().enumerate() {
            // sqrt(p_j / sum): the half-density of the row rescaled to unit mass.
            let qij = value.sqrt() * inv_sqrt_sum;
            q_row[j] = qij;
            mean[j] += qij;
        }
    }
    // `mean` holds the column sums of q; only its direction is used, so it is never divided
    // by n.
    let mean_norm = mean.dot(&mean).sqrt();
    if !(mean_norm > 0.0) {
        return Err(
            "SphereTangentEmbedding::fit: the extrinsic mean of the half-densities is the \
             zero vector (antipodally balanced behavior); no basepoint is defined"
                .to_string(),
        );
    }
    let basepoint = &mean / mean_norm;
    let tangent_basis = tangent_basis_orthogonal_to(basepoint.view())?;
    let root_two = std::f64::consts::SQRT_2;
    let mut target = q.dot(&tangent_basis);
    target.mapv_inplace(|value| root_two * value);
    Ok((
        Self {
            basepoint,
            tangent_basis,
        },
        target,
    ))
}
pub fn vocab(&self) -> usize {
self.basepoint.len()
}
/// Dimension of the tangent coordinates: the number of columns of the tangent basis.
pub fn behavior_dim(&self) -> usize {
    let (_, columns) = self.tangent_basis.dim();
    columns
}
pub fn basepoint(&self) -> ArrayView1<'_, f64> {
self.basepoint.view()
}
/// Maps new probability rows into the tangent coordinates of an already fitted chart.
///
/// Each row is rescaled to unit mass, replaced by its element-wise square root `q`, and
/// projected as `√2 · q · tangent_basis`. The result has shape `(m, V - 1)`; the basepoint
/// is not re-estimated.
///
/// # Errors
/// Returns `Err` when the column count differs from the chart's vocabulary, when an entry
/// is non-finite or negative, when a row has no positive mass, or when a row's mass
/// overflows to infinity.
pub fn embed(&self, prob_rows: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    let v = self.vocab();
    let (m, v_in) = prob_rows.dim();
    if v_in != v {
        return Err(format!(
            "SphereTangentEmbedding::embed: rows have {v_in} tokens; chart is over {v}"
        ));
    }
    let mut q = Array2::<f64>::zeros((m, v));
    for i in 0..m {
        let row = prob_rows.row(i);
        let mut sum = 0.0_f64;
        for &value in row.iter() {
            if !(value.is_finite() && value >= 0.0) {
                return Err(format!(
                    "SphereTangentEmbedding::embed: row {i} has a non-finite or negative entry \
                     ({value})"
                ));
            }
            sum += value;
        }
        if !(sum > 0.0) {
            return Err(format!(
                "SphereTangentEmbedding::embed: row {i} sums to {sum}"
            ));
        }
        // Finite entries can still add up to +inf. That would make the scale factor below
        // exactly zero and silently turn the row into the zero vector, so reject it.
        if !sum.is_finite() {
            return Err(format!(
                "SphereTangentEmbedding::embed: row {i} sums to {sum}; the row mass overflowed \
                 and cannot be normalized"
            ));
        }
        let inv_sqrt_sum = 1.0 / sum.sqrt();
        let mut q_row = q.row_mut(i);
        for (j, &value) in row.iter().enumerate() {
            q_row[j] = value.sqrt() * inv_sqrt_sum;
        }
    }
    let root_two = std::f64::consts::SQRT_2;
    let mut coords = q.dot(&self.tangent_basis);
    coords.mapv_inplace(|value| root_two * value);
    Ok(coords)
}
/// Lifts a tangent coordinate back to a unit vector of length V.
///
/// With `c = y / √2`, the result is `tangent_basis · c + r · basepoint`, where
/// `r = sqrt(1 - ‖c‖²)` when that quantity is positive and `0` otherwise. The sum is then
/// rescaled to unit norm unless its norm is not positive.
///
/// # Errors
/// Returns `Err` when `y` does not have `behavior_dim()` entries.
pub fn decode_sphere(&self, y: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
    let dim = self.behavior_dim();
    let len = y.len();
    if len != dim {
        return Err(format!(
            "SphereTangentEmbedding::decode_sphere: coordinate has length {len}; chart tangent \
             dim is {dim}"
        ));
    }
    // Undo the √2 scaling applied when the coordinates were produced.
    let c = y.mapv(|value| value * std::f64::consts::FRAC_1_SQRT_2);
    let tangent = self.tangent_basis.dot(&c);
    // Component along the basepoint; coordinates with ‖c‖ ≥ 1 get no radial part.
    let radial_sq = 1.0 - c.dot(&c);
    let mut radial = 0.0;
    if radial_sq > 0.0 {
        radial = radial_sq.sqrt();
    }
    let radial_part = self.basepoint.mapv(|value| value * radial);
    let mut q = tangent + &radial_part;
    let norm = q.dot(&q).sqrt();
    if norm > 0.0 {
        q /= norm;
    }
    Ok(q)
}
/// Decodes a tangent coordinate into a vector of length V by squaring every entry of
/// [`Self::decode_sphere`]'s output.
///
/// # Errors
/// Propagates the length-mismatch error of `decode_sphere`.
pub fn decode(&self, y: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
    self.decode_sphere(y).map(|mut q| {
        q.mapv_inplace(|amplitude| amplitude * amplitude);
        q
    })
}
/// Squared Euclidean norm of a tangent-coordinate difference `delta_y`.
pub fn predicted_nats(delta_y: ArrayView1<'_, f64>) -> f64 {
    let squared_norm = delta_y.dot(&delta_y);
    squared_norm
}
/// Computes `Σ a · ln(a / b)` over the entries of `p_a` that are positive, in nats.
///
/// Neither input is renormalized. An entry with `a > 0` and `b == 0` makes the result
/// `+∞`.
///
/// # Errors
/// Returns `Err` when the lengths differ or when any entry of either input is non-finite
/// or negative. Without that check a NaN in `p_a` would be skipped silently and a negative
/// `p_b` entry would poison the sum with NaN.
pub fn exact_kl(p_a: ArrayView1<'_, f64>, p_b: ArrayView1<'_, f64>) -> Result<f64, String> {
    if p_a.len() != p_b.len() {
        return Err(format!(
            "SphereTangentEmbedding::exact_kl: length mismatch {} vs {}",
            p_a.len(),
            p_b.len()
        ));
    }
    let mut kl = 0.0_f64;
    for (i, (&a, &b)) in p_a.iter().zip(p_b.iter()).enumerate() {
        if !(a.is_finite() && a >= 0.0 && b.is_finite() && b >= 0.0) {
            return Err(format!(
                "SphereTangentEmbedding::exact_kl: entry {i} is non-finite or negative \
                 (p_a = {a}, p_b = {b})"
            ));
        }
        if a > 0.0 {
            // Compare against zero explicitly so that b == -0.0 also gives +∞ instead of
            // ln(-∞) = NaN.
            kl += if b == 0.0 {
                f64::INFINITY
            } else {
                a * (a / b).ln()
            };
        }
    }
    Ok(kl)
}
/// Returns `2 · acos(⟨q_a, q_b⟩)`, with the inner product clamped to `[-1, 1]` first.
///
/// The inputs are used as given; nothing here checks that they are unit vectors.
pub fn fisher_rao_distance(q_a: ArrayView1<'_, f64>, q_b: ArrayView1<'_, f64>) -> f64 {
    let cosine = q_a.dot(&q_b);
    // Clamp so rounding just outside [-1, 1] cannot push acos out of its domain.
    let bounded = cosine.clamp(-1.0, 1.0);
    2.0 * f64::acos(bounded)
}
}
/// Builds a `(V, V - 1)` matrix whose columns are an orthonormal basis of the hyperplane
/// orthogonal to `axis`.
///
/// The basis comes from the Householder reflection `H = I - 2wwᵀ` that sends the coordinate
/// vector `e_p` to the normalized axis, where `p` is the index of the largest `|axis[j]|`.
/// The columns are `H·e_j` for every `j ≠ p`. `axis` may have any non-zero length; it is
/// normalized internally.
///
/// # Errors
/// Returns `Err` when `axis` has fewer than two entries, or when its norm is zero or
/// non-finite.
fn tangent_basis_orthogonal_to(axis: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
    let v = axis.len();
    if v < 2 {
        return Err(format!("tangent_basis_orthogonal_to: need V ≥ 2; got {v}"));
    }
    let axis_norm = axis.dot(&axis).sqrt();
    if !(axis_norm > 0.0 && axis_norm.is_finite()) {
        return Err(format!(
            "tangent_basis_orthogonal_to: axis must be finite and non-zero; got norm {axis_norm}"
        ));
    }
    // Pivot on the largest-magnitude entry of the axis.
    let mut pivot = 0usize;
    let mut best = axis[0].abs();
    for j in 1..v {
        let a = axis[j].abs();
        if a > best {
            best = a;
            pivot = j;
        }
    }
    let unit = axis.mapv(|value| value / axis_norm);
    let a_p = unit[pivot];
    // Householder vector w ∝ e_p - unit. Its pivot component is 1 - a_p, which cancels
    // catastrophically when the axis is close to +e_p. For a unit vector
    // 1 - a_p = (Σ_{k≠p} a_k²) / (1 + a_p), which keeps full relative accuracy.
    let mut w = unit.mapv(|value| -value);
    w[pivot] = if a_p > 0.0 {
        let tail_sq: f64 = unit
            .iter()
            .enumerate()
            .filter(|&(k, _)| k != pivot)
            .map(|(_, &a)| a * a)
            .sum();
        tail_sq / (1.0 + a_p)
    } else {
        // No cancellation: both terms are non-negative.
        1.0 - a_p
    };
    let w_norm = w.dot(&w).sqrt();
    if w_norm > 0.0 {
        w.mapv_inplace(|value| value / w_norm);
    } else {
        // The axis is exactly e_p: the reflection is the identity.
        w.fill(0.0);
    }
    let mut basis = Array2::<f64>::zeros((v, v - 1));
    let mut col = 0usize;
    for j in 0..v {
        if j == pivot {
            continue;
        }
        // Column `col` is H·e_j = e_j - 2·w_j·w.
        let two_wj = 2.0 * w[j];
        for i in 0..v {
            let e_ij = if i == j { 1.0 } else { 0.0 };
            basis[[i, col]] = e_ij - two_wj * w[i];
        }
        col += 1;
    }
    Ok(basis)
}
/// A fitted behavior chart bundled with its per-row targets and the settings used to append
/// those targets to an activation matrix.
#[derive(Clone, Debug)]
pub struct BehaviorBlock {
// Chart fitted on the probability rows passed to `BehaviorBlock::fit`.
pub embedding: SphereTangentEmbedding,
// Tangent coordinates returned by `SphereTangentEmbedding::fit`, one row per input row.
pub target: Array2<f64>,
// Number of activation columns expected by `augmented_target` and `split_decoder`.
pub activation_dim: usize,
// Natural log of the behavior weight λ_y; the behavior columns are scaled by sqrt(λ_y).
pub log_lambda_y: f64,
}
impl BehaviorBlock {
/// Validates the block settings, fits a [`SphereTangentEmbedding`] on `prob_rows`, and
/// stores the chart together with its targets.
///
/// # Errors
/// Returns `Err` when `activation_dim` is zero, when `log_lambda_y` is not finite, or when
/// the embedding fit itself fails.
pub fn fit(
    prob_rows: ArrayView2<'_, f64>,
    activation_dim: usize,
    log_lambda_y: f64,
) -> Result<Self, String> {
    if activation_dim < 1 {
        return Err(String::from(
            "BehaviorBlock::fit: activation_dim must be positive",
        ));
    }
    if log_lambda_y.is_nan() || log_lambda_y.is_infinite() {
        return Err(format!(
            "BehaviorBlock::fit: log_lambda_y must be finite; got {log_lambda_y}"
        ));
    }
    SphereTangentEmbedding::fit(prob_rows).map(|(embedding, target)| Self {
        embedding,
        target,
        activation_dim,
        log_lambda_y,
    })
}
pub fn behavior_dim(&self) -> usize {
self.embedding.behavior_dim()
}
/// Total column count of the augmented target: activation columns plus behavior columns.
pub fn augmented_dim(&self) -> usize {
    let px = self.activation_dim;
    let py = self.behavior_dim();
    px + py
}
/// The behavior weight `λ_y = exp(log_lambda_y)`.
pub fn lambda_y(&self) -> f64 {
    f64::exp(self.log_lambda_y)
}
/// Square root of the behavior weight, computed as `exp(log_lambda_y / 2)`.
pub fn sqrt_lambda_y(&self) -> f64 {
    let half_log = 0.5 * self.log_lambda_y;
    half_log.exp()
}
/// Concatenates `activation` with the stored behavior target scaled by `sqrt(λ_y)`.
///
/// The result has shape `(n, activation_dim + behavior_dim)`: the first `activation_dim`
/// columns copy `activation`, the remaining columns hold `sqrt_lambda_y() * target`.
///
/// # Errors
/// Returns `Err` when `activation` does not have `activation_dim` columns or its row count
/// differs from the stored target's.
pub fn augmented_target(&self, activation: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    let (n, px) = activation.dim();
    let expected_cols = self.activation_dim;
    if px != expected_cols {
        return Err(format!(
            "BehaviorBlock::augmented_target: activation has {px} columns; block activation_dim \
             is {expected_cols}"
        ));
    }
    let target_rows = self.target.nrows();
    if target_rows != n {
        return Err(format!(
            "BehaviorBlock::augmented_target: activation has {n} rows but behavior target has {target_rows}"
        ));
    }
    let py = self.behavior_dim();
    let sqrt_lambda = self.sqrt_lambda_y();
    // Columns [0, px) come from the activation, columns [px, px + py) from the scaled target.
    let augmented = Array2::from_shape_fn((n, px + py), |(i, j)| {
        if j < px {
            activation[[i, j]]
        } else {
            sqrt_lambda * self.target[[i, j - px]]
        }
    });
    Ok(augmented)
}
/// Splits an augmented decoder into its activation part and its behavior part.
///
/// Returns `(b, c)`: `b` copies the first `activation_dim` columns unchanged and `c` holds
/// the remaining `behavior_dim` columns multiplied by `1 / sqrt_lambda_y()`, which undoes
/// the scaling applied by `augmented_target`.
///
/// # Errors
/// Returns `Err` when the decoder does not have `activation_dim + behavior_dim` columns.
pub fn split_decoder(
    &self,
    augmented_decoder: ArrayView2<'_, f64>,
) -> Result<(Array2<f64>, Array2<f64>), String> {
    let px = self.activation_dim;
    let py = self.behavior_dim();
    let (m, p_tot) = augmented_decoder.dim();
    let expected = px + py;
    if p_tot != expected {
        return Err(format!(
            "BehaviorBlock::split_decoder: decoder has {p_tot} output columns; expected \
             p_x + p_y = {px} + {py} = {expected}"
        ));
    }
    let inv_sqrt_lambda = 1.0 / self.sqrt_lambda_y();
    let b = Array2::from_shape_fn((m, px), |(row, j)| augmented_decoder[[row, j]]);
    let c = Array2::from_shape_fn((m, py), |(row, j)| {
        inv_sqrt_lambda * augmented_decoder[[row, px + j]]
    });
    Ok((b, c))
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array1, Array2};
/// The basis must have V - 1 orthonormal columns, each orthogonal to the axis.
#[test]
fn tangent_basis_is_orthonormal_and_orthogonal_to_axis() {
    // Arbitrary mixed-sign direction, rescaled to unit length.
    let mut axis = Array1::<f64>::from(vec![0.3, -0.5, 0.2, 0.7, -0.34]);
    let norm = axis.dot(&axis).sqrt();
    axis.mapv_inplace(|v| v / norm);

    let e = tangent_basis_orthogonal_to(axis.view()).unwrap();
    assert_eq!(e.dim(), (5, 4));

    for col in 0..e.ncols() {
        let dot = e.column(col).dot(&axis);
        assert!(dot.abs() < 1e-12, "column {col} not ⟂ axis: {dot}");
    }

    // EᵀE must be the 4×4 identity.
    let gram = e.t().dot(&e);
    for ((i, j), &entry) in gram.indexed_iter() {
        let expected = if i == j { 1.0 } else { 0.0 };
        assert!(
            (entry - expected).abs() < 1e-12,
            "EᵀE[{i},{j}] = {entry} != {expected}"
        );
    }
}
/// Decoding the fitted targets must reproduce the original distributions.
#[test]
fn embed_decode_round_trips_distributions() {
    let rows = vec![
        vec![0.4, 0.2, 0.1, 0.1, 0.1, 0.1],
        vec![0.1, 0.5, 0.1, 0.1, 0.1, 0.1],
        vec![0.2, 0.2, 0.2, 0.2, 0.1, 0.1],
        vec![0.05, 0.05, 0.6, 0.1, 0.1, 0.1],
    ];
    let n = rows.len();
    let v = rows[0].len();
    let p = Array2::<f64>::from_shape_fn((n, v), |(i, j)| rows[i][j]);

    let (chart, y) = SphereTangentEmbedding::fit(p.view()).unwrap();
    assert_eq!(chart.behavior_dim(), v - 1);

    for i in 0..n {
        let decoded = chart.decode(y.row(i)).unwrap();
        let original = p.row(i);
        for (j, (&got, &want)) in decoded.iter().zip(original.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-10,
                "row {i} token {j}: decoded {got} != original {want}"
            );
        }
    }
}
/// As the perturbation shrinks, the relative gap between the squared coordinate distance
/// and the exact KL divergence must shrink by at least a factor of 0.6 per halving.
#[test]
fn predicted_nats_matches_exact_kl_to_second_order() {
    let base = Array1::from(vec![0.25, 0.25, 0.2, 0.15, 0.15]);
    let dir = Array1::from(vec![0.1, -0.05, -0.02, -0.02, -0.01]);
    let v = base.len();
    // Row 0 is the base distribution; row 1 is the base shifted by eps along `dir`.
    let make = |eps: f64| -> Array2<f64> {
        Array2::from_shape_fn((2, v), |(i, j)| {
            if i == 0 {
                base[j]
            } else {
                base[j] + eps * dir[j]
            }
        })
    };

    let mut prev_rel: Option<f64> = None;
    for eps in [0.2_f64, 0.1, 0.05, 0.025].iter().copied() {
        let p = make(eps);
        let (chart, y) = SphereTangentEmbedding::fit(p.view()).unwrap();

        let delta_y = &y.row(1) - &y.row(0);
        let predicted = SphereTangentEmbedding::predicted_nats(delta_y.view());

        let p0 = chart.decode(y.row(0)).unwrap();
        let p1 = chart.decode(y.row(1)).unwrap();
        let kl = SphereTangentEmbedding::exact_kl(p1.view(), p0.view()).unwrap();

        let rel = (predicted - kl).abs() / kl.max(1e-12);
        if let Some(prev) = prev_rel {
            assert!(
                rel < prev * 0.6,
                "relative KL error did not fall second-order: {prev} → {rel} at ε={eps}"
            );
        }
        prev_rel = Some(rel);
    }
}
/// Identical rows all sit at the basepoint, so every tangent coordinate must vanish.
#[test]
fn constant_behavior_has_zero_tangent_target() {
    let base = vec![0.3, 0.3, 0.2, 0.2];
    let n = 5;
    let v = base.len();
    let p = Array2::<f64>::from_shape_fn((n, v), |(_, j)| base[j]);

    let (_chart, y) = SphereTangentEmbedding::fit(p.view()).unwrap();
    for &value in y.iter() {
        assert!(
            value.abs() < 1e-12,
            "constant behavior gave nonzero target {value}"
        );
    }
}
/// For nearby distributions, half the squared distance must match the KL divergence to
/// within 5 %, and the distance from a point to itself must be zero.
#[test]
fn fisher_rao_matches_half_kl_leading_order() {
    let a = Array1::from(vec![0.25_f64, 0.25, 0.25, 0.25]);
    let raw_b = Array1::from(vec![0.26_f64, 0.25, 0.25, 0.24]);
    let bsum: f64 = raw_b.iter().sum();
    let b = raw_b.mapv(|v| v / bsum);

    let qa = a.mapv(f64::sqrt);
    let qb = b.mapv(f64::sqrt);

    let self_dist = SphereTangentEmbedding::fisher_rao_distance(qa.view(), qa.view());
    assert!(
        self_dist < 1e-9,
        "self F-R distance should be 0, got {self_dist}"
    );

    let d = SphereTangentEmbedding::fisher_rao_distance(qa.view(), qb.view());
    let kl = SphereTangentEmbedding::exact_kl(a.view(), b.view()).unwrap();
    let half_dsq = 0.5 * d * d;
    let rel = (half_dsq - kl).abs() / kl;
    assert!(rel < 0.05, "½ d_FR² = {half_dsq} vs KL = {kl}, rel {rel}");
}
}