use std::f64::consts::PI;
/// Returns `x · log2(x)`, using the continuous extension `0 · log2(0) = 0`.
///
/// Every `x <= 0` maps to `0.0`; NaN propagates.
#[inline]
pub fn xlog2x(x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    x * x.log2()
}
/// Returns `x · ln(x)`, using the continuous extension `0 · ln(0) = 0`.
///
/// Every `x <= 0` maps to `0.0`; NaN propagates.
#[inline]
pub fn xlnx(x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    x * x.ln()
}
/// Row marginal of a joint table: entry `i` is the sum of row `joint[i]`.
pub fn marginal_x(joint: &[Vec<f64>]) -> Vec<f64> {
    let mut row_sums = Vec::with_capacity(joint.len());
    for row in joint {
        row_sums.push(row.iter().sum::<f64>());
    }
    row_sums
}
/// Column marginal of a joint table: entry `j` is the sum of `joint[i][j]` over all rows.
///
/// The output width is the length of the first row. Shorter rows contribute only to their
/// existing columns. A row longer than the first row panics on out-of-bounds indexing.
/// An empty table yields an empty vector.
pub fn marginal_y(joint: &[Vec<f64>]) -> Vec<f64> {
    let width = match joint.first() {
        Some(first_row) => first_row.len(),
        None => return Vec::new(),
    };
    joint.iter().fold(vec![0.0_f64; width], |mut column_sums, row| {
        for (col, &value) in row.iter().enumerate() {
            column_sums[col] += value;
        }
        column_sums
    })
}
/// Divides every element of `v` by the total so the slice sums to 1.
///
/// Returns `false` and leaves `v` untouched when the total is `<= 0`. Otherwise it returns
/// `true`; a NaN total still divides and reports success.
pub fn normalise_inplace(v: &mut [f64]) -> bool {
    let total = v.iter().sum::<f64>();
    if total <= 0.0 {
        return false;
    }
    v.iter_mut().for_each(|x| *x /= total);
    true
}
/// Numerically stable softmax: `exp(z_i - max z) / Σ_j exp(z_j - max z)`.
///
/// Subtracting the maximum keeps `exp` from overflowing. An empty input yields an empty
/// vector.
pub fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let shifted: Vec<f64> = logits.iter().map(|&z| (z - peak).exp()).collect();
    let norm = shifted.iter().sum::<f64>();
    shifted.into_iter().map(|e| e / norm).collect()
}
pub fn entropy_bits(probs: &[f64]) -> f64 {
-probs.iter().map(|&p| xlog2x(p)).sum::<f64>()
}
pub fn entropy_nats(probs: &[f64]) -> f64 {
-probs.iter().map(|&p| xlnx(p)).sum::<f64>()
}
pub fn joint_entropy(joint: &[Vec<f64>]) -> f64 {
-joint
.iter()
.flat_map(|row| row.iter())
.map(|&p| xlog2x(p))
.sum::<f64>()
}
/// Conditional entropy `H(Y | X) = H(X, Y) - H(X)` in bits.
///
/// X indexes the rows of `joint`.
pub fn conditional_entropy(joint: &[Vec<f64>]) -> f64 {
    let h_joint = joint_entropy(joint);
    let h_x = entropy_bits(&marginal_x(joint));
    h_joint - h_x
}
/// Mutual information `I(X; Y) = H(X) + H(Y) - H(X, Y)` in bits.
///
/// Rows of `joint` index X and columns index Y.
pub fn mutual_information(joint: &[Vec<f64>]) -> f64 {
    let h_x = entropy_bits(&marginal_x(joint));
    let h_y = entropy_bits(&marginal_y(joint));
    let h_xy = joint_entropy(joint);
    h_x + h_y - h_xy
}
/// Mutual information normalised by the geometric mean of the marginal entropies,
/// `I(X; Y) / sqrt(H(X) · H(Y))`.
///
/// Returns 0.0 when either marginal entropy is `<= 0`, i.e. a deterministic marginal.
pub fn normalised_mutual_information(joint: &[Vec<f64>]) -> f64 {
    let h_x = entropy_bits(&marginal_x(joint));
    let h_y = entropy_bits(&marginal_y(joint));
    let degenerate = h_x <= 0.0 || h_y <= 0.0;
    if degenerate {
        return 0.0;
    }
    let geometric_mean = (h_x * h_y).sqrt();
    mutual_information(joint) / geometric_mean
}
/// Variation of information `VI(X; Y) = H(X, Y) - I(X; Y)` in bits.
///
/// This equals `H(X | Y) + H(Y | X)`.
pub fn variation_of_information(joint: &[Vec<f64>]) -> f64 {
    let h_xy = joint_entropy(joint);
    h_xy - mutual_information(joint)
}
/// Three-way interaction information in bits, computed as
/// `H(X,Y) + H(X,Z) + H(Y,Z) - H(X) - H(Y) - H(Z) - H(X,Y,Z)`.
///
/// This is the negative of the co-information `H(X)+H(Y)+H(Z)-H(X,Y)-H(X,Z)-H(Y,Z)+H(X,Y,Z)`.
/// NOTE(review): authors use both sign conventions; confirm callers expect this one.
///
/// The marginals are taken as follows:
/// - `p(x)` is the rows of `joint_xy`.
/// - `p(y)` is the columns of `joint_xy`.
/// - `p(z)` is the columns of `joint_xz`.
///
/// `joint_yz` and `joint_xyz` enter only through their joint entropies. Because every cell of
/// `joint_xyz` is summed, any 2-D flattening of the three-way table gives the same result.
pub fn interaction_information_3(
    joint_xy: &[Vec<f64>],
    joint_xz: &[Vec<f64>],
    joint_yz: &[Vec<f64>],
    joint_xyz: &[Vec<f64>],
) -> f64 {
    let h_x = entropy_bits(&marginal_x(joint_xy));
    let h_y = entropy_bits(&marginal_y(joint_xy));
    let h_z = entropy_bits(&marginal_y(joint_xz));
    let singles = -h_x - h_y - h_z;
    singles + joint_entropy(joint_xy) + joint_entropy(joint_xz) + joint_entropy(joint_yz)
        - joint_entropy(joint_xyz)
}
/// Kullback–Leibler divergence `D(p ‖ q) = Σ p_i ln(p_i / q_i)` in nats.
///
/// Terms with `p_i <= 0` contribute 0. A term with `p_i > 0` but `q_i <= 0` is infinite.
///
/// # Panics
/// Panics if `p` and `q` have different lengths.
pub fn kl_divergence(p: &[f64], q: &[f64]) -> f64 {
    assert_eq!(p.len(), q.len(), "kl_divergence: length mismatch");
    let term = |pi: f64, qi: f64| -> f64 {
        if pi <= 0.0 {
            return 0.0;
        }
        if qi <= 0.0 {
            return f64::INFINITY;
        }
        pi * (pi / qi).ln()
    };
    p.iter().zip(q).map(|(&pi, &qi)| term(pi, qi)).sum()
}
/// Kullback–Leibler divergence `D(p ‖ q) = Σ p_i log2(p_i / q_i)` in bits.
///
/// Terms with `p_i <= 0` contribute 0. A term with `p_i > 0` but `q_i <= 0` is infinite.
///
/// # Panics
/// Panics if `p` and `q` have different lengths.
pub fn kl_divergence_bits(p: &[f64], q: &[f64]) -> f64 {
    assert_eq!(p.len(), q.len(), "kl_divergence_bits: length mismatch");
    let term = |pi: f64, qi: f64| -> f64 {
        if pi <= 0.0 {
            return 0.0;
        }
        if qi <= 0.0 {
            return f64::INFINITY;
        }
        pi * (pi / qi).log2()
    };
    p.iter().zip(q).map(|(&pi, &qi)| term(pi, qi)).sum()
}
/// Jensen–Shannon divergence in nats: `(D(p ‖ m) + D(q ‖ m)) / 2`, where `m = (p + q) / 2`.
///
/// # Panics
/// Panics if `p` and `q` have different lengths.
pub fn js_divergence(p: &[f64], q: &[f64]) -> f64 {
    assert_eq!(p.len(), q.len(), "js_divergence: length mismatch");
    let midpoint: Vec<f64> = p.iter().zip(q).map(|(a, b)| (a + b) / 2.0).collect();
    let from_p = kl_divergence(p, &midpoint);
    let from_q = kl_divergence(q, &midpoint);
    (from_p + from_q) / 2.0
}
/// Cross-entropy `H(p, q) = -Σ p_i log2 q_i` in bits.
///
/// Terms with `p_i <= 0` contribute 0. A term with `p_i > 0` but `q_i <= 0` makes the result
/// `+∞`, since `q` cannot encode an outcome that `p` produces.
///
/// # Panics
/// Panics if `p` and `q` have different lengths.
pub fn cross_entropy(p: &[f64], q: &[f64]) -> f64 {
    assert_eq!(p.len(), q.len(), "cross_entropy: length mismatch");
    // The minus sign is applied per term so the infinite case stays +∞. Negating the whole
    // sum would turn it into -∞.
    p.iter()
        .zip(q)
        .map(|(&pi, &qi)| {
            if pi <= 0.0 {
                0.0
            } else if qi <= 0.0 {
                f64::INFINITY
            } else {
                -pi * qi.log2()
            }
        })
        .sum()
}
/// Total variation distance `½ Σ |p_i - q_i|`.
///
/// # Panics
/// Panics if `p` and `q` have different lengths.
pub fn total_variation_distance(p: &[f64], q: &[f64]) -> f64 {
    assert_eq!(p.len(), q.len(), "tv_distance: length mismatch");
    let l1: f64 = p.iter().zip(q).map(|(a, b)| (a - b).abs()).sum();
    0.5 * l1
}
/// Squared Hellinger distance `1 - BC(p, q)`, where `BC = Σ sqrt(p_i q_i)` is the
/// Bhattacharyya coefficient.
///
/// The result is floored at 0 to absorb rounding.
///
/// # Panics
/// Panics if `p` and `q` have different lengths.
pub fn hellinger_distance_sq(p: &[f64], q: &[f64]) -> f64 {
    assert_eq!(p.len(), q.len(), "hellinger: length mismatch");
    let coefficient: f64 = p.iter().zip(q).map(|(&a, &b)| (a * b).sqrt()).sum();
    let h2 = 1.0 - coefficient;
    h2.max(0.0)
}
/// Bhattacharyya distance `-ln BC(p, q)`, where `BC = Σ sqrt(p_i q_i)`.
///
/// Returns `+∞` when the coefficient is `<= 0`, i.e. the supports are disjoint.
///
/// # Panics
/// Panics if `p` and `q` have different lengths.
pub fn bhattacharyya_distance(p: &[f64], q: &[f64]) -> f64 {
    assert_eq!(p.len(), q.len(), "bhattacharyya: length mismatch");
    let coefficient: f64 = p.iter().zip(q).map(|(&a, &b)| (a * b).sqrt()).sum();
    if coefficient <= 0.0 {
        return f64::INFINITY;
    }
    -coefficient.ln()
}
/// Pearson χ² divergence `Σ (p_i - q_i)² / q_i`.
///
/// A cell with `q_i <= 0` contributes `+∞` if `p_i > 0` and 0 otherwise.
///
/// # Panics
/// Panics if `p` and `q` have different lengths.
pub fn chi_squared_divergence(p: &[f64], q: &[f64]) -> f64 {
    assert_eq!(p.len(), q.len(), "chi_squared: length mismatch");
    p.iter()
        .zip(q)
        .map(|(&pi, &qi)| match qi <= 0.0 {
            true if pi > 0.0 => f64::INFINITY,
            true => 0.0,
            false => (pi - qi).powi(2) / qi,
        })
        .sum()
}
/// Generic f-divergence `Σ q_i · f(p_i / q_i)`.
///
/// Cells with `q_i <= 0` are skipped (they contribute 0). NOTE(review): the usual convention
/// instead adds `p_i · lim_{t→∞} f(t)/t` for such cells; confirm that ignoring them is
/// intended.
///
/// # Panics
/// Panics if `p` and `q` have different lengths.
pub fn f_divergence(p: &[f64], q: &[f64], f_fn: impl Fn(f64) -> f64) -> f64 {
    assert_eq!(p.len(), q.len(), "f_divergence: length mismatch");
    let weighted = |pi: f64, qi: f64| -> f64 {
        if qi <= 0.0 {
            return 0.0;
        }
        qi * f_fn(pi / qi)
    };
    p.iter().zip(q).map(|(&pi, &qi)| weighted(pi, qi)).sum()
}
pub fn renyi_entropy(probs: &[f64], alpha: f64) -> f64 {
if (alpha - 1.0).abs() < 1e-12 {
return entropy_bits(probs);
}
let sum: f64 = probs.iter().map(|&p| p.powf(alpha)).sum();
sum.log2() / (1.0 - alpha)
}
/// Min-entropy `-log2 max_i p_i` in bits.
///
/// Returns 0.0 for an empty or all-non-positive input.
pub fn min_entropy(probs: &[f64]) -> f64 {
    let largest = probs.iter().copied().fold(0.0_f64, f64::max);
    if largest <= 0.0 {
        return 0.0;
    }
    -largest.log2()
}
/// Tsallis entropy of index `q`: `S_q(p) = (1 - Σ p_i^q) / (q - 1)`.
///
/// Only strictly positive probabilities enter the power sum, so zero-probability outcomes
/// never contribute. For `q == 0` this gives `|support| - 1`. Indices within `1e-12` of 1
/// return the Shannon entropy in nats, which is the `q → 1` limit.
pub fn tsallis_entropy(probs: &[f64], q: f64) -> f64 {
    if (q - 1.0).abs() < 1e-12 {
        return entropy_nats(probs);
    }
    // Filter zeros: 0.0.powf(0.0) == 1 would otherwise count impossible outcomes.
    let power_sum: f64 = probs
        .iter()
        .filter(|&&p| p > 0.0)
        .map(|&p| p.powf(q))
        .sum();
    (1.0 - power_sum) / (q - 1.0)
}
/// Collision (order-2 Rényi) entropy `-log2 Σ p_i²` in bits.
///
/// Returns 0.0 when the sum of squares is `<= 0`, e.g. for an empty input.
pub fn collision_entropy(probs: &[f64]) -> f64 {
    let purity: f64 = probs.iter().map(|p| p * p).sum();
    if purity <= 0.0 {
        return 0.0;
    }
    -purity.log2()
}
/// Hartley (max-) entropy `log2 |{i : p_i > 0}|` in bits.
///
/// Returns 0.0 when no entry is positive.
pub fn hartley_entropy(probs: &[f64]) -> f64 {
    let support = probs.iter().filter(|&&p| p > 0.0).count();
    if support == 0 {
        return 0.0;
    }
    (support as f64).log2()
}
/// Differential entropy of `N(μ, σ²)` in nats: `½ ln(2πe σ²)`.
///
/// Only `sigma²` is used, so the sign of `sigma` does not matter.
pub fn gaussian_differential_entropy(sigma: f64) -> f64 {
    let two_pi_e = 2.0 * PI * std::f64::consts::E;
    let log_term = (two_pi_e * sigma * sigma).ln();
    0.5 * log_term
}
/// Differential entropy of `Uniform(a, b)` in nats: `ln(b - a)`.
///
/// Returns `-∞` when the interval is empty or inverted (`b - a <= 0`).
pub fn uniform_differential_entropy(a: f64, b: f64) -> f64 {
    let width = b - a;
    if width <= 0.0 {
        return f64::NEG_INFINITY;
    }
    width.ln()
}
/// Differential entropy of `Exponential(rate)` in nats: `1 - ln(rate)`.
pub fn exponential_differential_entropy(rate: f64) -> f64 {
    let log_rate = rate.ln();
    1.0 - log_rate
}
/// Differential entropy of `Laplace(μ, scale)` in nats: `1 + ln(2 · scale)`.
pub fn laplace_differential_entropy(scale: f64) -> f64 {
    let log_width = (2.0 * scale).ln();
    1.0 + log_width
}
/// Differential entropy of a `dimension`-variate Gaussian in nats:
/// `½ d ln(2πe) + ½ ln det Σ`.
///
/// `det_cov` is the determinant of the covariance matrix.
pub fn multivariate_gaussian_entropy(dimension: usize, det_cov: f64) -> f64 {
    let d = dimension as f64;
    let log_two_pi_e = (2.0 * PI * std::f64::consts::E).ln();
    let log_det = det_cov.ln();
    0.5 * d * log_two_pi_e + 0.5 * log_det
}
/// Differential entropy (nats) of a Gamma distribution with shape `alpha` and rate `beta`:
/// `h = α - ln β + ln Γ(α) + (1 - α) ψ(α)`.
///
/// `ln Γ` and the digamma function `ψ` are evaluated with their asymptotic series, including
/// the correction terms, after shifting the argument to at least 10 with the recurrences. The
/// result is accurate to roughly 1e-10.
///
/// Returns NaN when `alpha` is NaN or `alpha <= 0`, which is outside the distribution's
/// domain. A non-positive `beta` yields NaN through `ln`.
pub fn gamma_differential_entropy(alpha: f64, beta: f64) -> f64 {
    // Arguments below this are shifted up with lnΓ(x) = lnΓ(x + 1) - ln x and
    // ψ(x) = ψ(x + 1) - 1/x before the asymptotic series is applied.
    const ASYMPTOTIC_FROM: f64 = 10.0;

    /// `ln Γ(x)` for `x > 0`: Stirling series through the `x⁻⁵` term.
    fn ln_gamma(x: f64) -> f64 {
        let mut z = x;
        let mut log_shift = 0.0_f64;
        while z < ASYMPTOTIC_FROM {
            log_shift += z.ln();
            z += 1.0;
        }
        let inv = 1.0 / z;
        let inv2 = inv * inv;
        // 1/(12z) - 1/(360z³) + 1/(1260z⁵)
        let series = inv * (1.0 / 12.0 - inv2 * (1.0 / 360.0 - inv2 / 1260.0));
        (z - 0.5) * z.ln() - z + 0.5 * (2.0 * PI).ln() + series - log_shift
    }

    /// Digamma `ψ(x)` for `x > 0`: asymptotic series through the `x⁻⁶` term.
    fn digamma(x: f64) -> f64 {
        let mut z = x;
        let mut shift = 0.0_f64;
        while z < ASYMPTOTIC_FROM {
            shift += 1.0 / z;
            z += 1.0;
        }
        let inv = 1.0 / z;
        let inv2 = inv * inv;
        // ln z - 1/(2z) - 1/(12z²) + 1/(120z⁴) - 1/(252z⁶)
        z.ln() - 0.5 * inv - inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0)) - shift
    }

    // Reject invalid shapes up front. This also stops the shift loops above from never
    // terminating for huge negative shapes, where `z + 1.0 == z`.
    if alpha.is_nan() || alpha <= 0.0 {
        return f64::NAN;
    }
    alpha - beta.ln() + ln_gamma(alpha) + (1.0 - alpha) * digamma(alpha)
}
pub fn max_entropy_mean_constraint(n: usize, target_mean: f64) -> Vec<f64> {
if n == 0 {
return vec![];
}
if n == 1 {
return vec![1.0];
}
let mut lo = -20.0_f64;
let mut hi = 20.0_f64;
for _ in 0..200 {
let mid = (lo + hi) / 2.0;
let probs = softmax(&(0..n).map(|i| -mid * i as f64).collect::<Vec<_>>());
let mean: f64 = probs.iter().enumerate().map(|(i, &p)| i as f64 * p).sum();
if mean < target_mean {
hi = mid;
} else {
lo = mid;
}
}
let lambda = (lo + hi) / 2.0;
softmax(&(0..n).map(|i| -lambda * i as f64).collect::<Vec<_>>())
}
/// Approximate maximum-entropy distribution over `{0, …, n-1}` with a given mean and
/// variance.
///
/// The distribution has the form `p_i ∝ exp(-λ₁ i - λ₂ i²)`. The multipliers start at 0 and
/// take 500 fixed-step (0.01) updates proportional to the mean and variance errors. There is no
/// convergence check, so the constraints may be met only approximately.
/// `n == 0` yields an empty vector and `n == 1` yields `[1.0]`.
pub fn max_entropy_mean_variance_constraint(
    n: usize,
    target_mean: f64,
    target_var: f64,
) -> Vec<f64> {
    match n {
        0 => return Vec::new(),
        1 => return vec![1.0],
        _ => {}
    }
    let logits_for = |l1: f64, l2: f64| -> Vec<f64> {
        (0..n)
            .map(|i| -l1 * i as f64 - l2 * (i as f64).powi(2))
            .collect()
    };
    let (mut lambda1, mut lambda2) = (0.0_f64, 0.0_f64);
    let step = 0.01_f64;
    for _ in 0..500 {
        let probs = softmax(&logits_for(lambda1, lambda2));
        let mean: f64 = probs.iter().enumerate().map(|(i, &p)| i as f64 * p).sum();
        let var: f64 = probs
            .iter()
            .enumerate()
            .map(|(i, &p)| (i as f64 - mean).powi(2) * p)
            .sum();
        lambda1 += step * (mean - target_mean);
        lambda2 += step * (var - target_var);
    }
    softmax(&logits_for(lambda1, lambda2))
}
/// The uniform distribution on `n` outcomes, which is the unconstrained maximum-entropy
/// distribution.
///
/// `n == 0` yields an empty vector.
pub fn max_entropy_uniform(n: usize) -> Vec<f64> {
    match n {
        0 => Vec::new(),
        _ => vec![1.0 / n as f64; n],
    }
}
/// Capacity of a Gaussian channel in bits per use: `½ log2(1 + power / noise)`.
pub fn max_entropy_gaussian_channel(power: f64, noise: f64) -> f64 {
    let snr = power / noise;
    0.5 * (1.0 + snr).log2()
}
/// Plug-in (histogram) estimate of the transfer entropy from `source` to `target`, in nats.
///
/// Both series are truncated to their common length `n`. Every value is assigned to one of
/// 8 equal-width bins spanning the combined minimum..maximum of the two truncated series.
///
/// For each `t` in `lag..n` the triple `(y_t, y_l, x_l)` is counted with weight
/// `1 / (n - lag)`, where `y_t = target[t]`, `y_l = target[t - lag]` and
/// `x_l = source[t - lag]`. The estimate is
///
/// `Σ p(y_t, y_l, x_l) · ln[ p(y_t, y_l, x_l) · p(y_l) / (p(y_t, y_l) · p(y_l, x_l)) ]`.
///
/// Returns 0.0 when `n <= lag + 1` or when the combined value range is below `1e-12`.
/// Negative estimates are clamped to 0.0.
pub fn transfer_entropy(source: &[f64], target: &[f64], lag: usize) -> f64 {
    let n = source.len().min(target.len());
    if n <= lag + 1 {
        return 0.0;
    }
    // Fixed histogram resolution per variable.
    let bins = 8_usize;
    // The bin edges are shared by both series, so they come from the pooled values.
    let all_vals: Vec<f64> = source[..n]
        .iter()
        .chain(target[..n].iter())
        .copied()
        .collect();
    let lo = all_vals.iter().cloned().fold(f64::INFINITY, f64::min);
    let hi = all_vals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    // Constant data: everything falls in one bin and there is nothing to estimate.
    if (hi - lo).abs() < 1e-12 {
        return 0.0;
    }
    // `v == hi` would compute index `bins`, so the index is clamped into the last bin.
    let bin_of = |v: f64| -> usize {
        let idx = ((v - lo) / (hi - lo) * bins as f64) as usize;
        idx.min(bins - 1)
    };
    // Empirical probability tables, flattened row-major:
    // p3[y_t][y_l][x_l], p2_ty[y_t][y_l], p2_src_ty[y_l][x_l], p1_ty[y_l].
    let mut p3 = vec![0.0_f64; bins * bins * bins];
    let mut p2_ty = vec![0.0_f64; bins * bins];
    let mut p2_src_ty = vec![0.0_f64; bins * bins];
    let mut p1_ty = vec![0.0_f64; bins];
    let count = (n - lag) as f64;
    for t in lag..n {
        let yt = bin_of(target[t]);
        let yt_lag = bin_of(target[t - lag]);
        let xt_lag = bin_of(source[t - lag]);
        p3[yt * bins * bins + yt_lag * bins + xt_lag] += 1.0 / count;
        p2_ty[yt * bins + yt_lag] += 1.0 / count;
        p2_src_ty[yt_lag * bins + xt_lag] += 1.0 / count;
        p1_ty[yt_lag] += 1.0 / count;
    }
    let mut te = 0.0_f64;
    for yt in 0..bins {
        for yt_l in 0..bins {
            for xt_l in 0..bins {
                let joint = p3[yt * bins * bins + yt_l * bins + xt_l];
                // Empty cells contribute 0 (the 0·ln 0 convention).
                if joint <= 0.0 {
                    continue;
                }
                let a = p2_ty[yt * bins + yt_l];
                let b = p2_src_ty[yt_l * bins + xt_l];
                let c = p1_ty[yt_l];
                if a > 0.0 && b > 0.0 && c > 0.0 {
                    // ln of p(y_t | y_l, x_l) / p(y_t | y_l), written with joint tables.
                    te += joint * (joint * c / (a * b)).ln();
                }
            }
        }
    }
    // The plug-in estimate can dip slightly below zero from rounding; clamp it.
    te.max(0.0)
}
/// Heuristic "conditional" transfer entropy: `max(TE(source → target) - TE(cond → target), 0)`.
///
/// Each term comes from `transfer_entropy` on its own pair of series, with its own binning.
/// NOTE(review): this is a difference of two pairwise estimates, not the true conditional
/// transfer entropy `TE(source → target | cond)`; confirm this approximation is intended.
///
/// Returns 0.0 when the shortest of the three series has at most `lag + 1` samples.
pub fn conditional_transfer_entropy(
    source: &[f64],
    target: &[f64],
    cond: &[f64],
    lag: usize,
) -> f64 {
    let common_len = source.len().min(target.len()).min(cond.len());
    if common_len <= lag + 1 {
        return 0.0;
    }
    let from_source = transfer_entropy(source, target, lag);
    let from_cond = transfer_entropy(cond, target, lag);
    let excess = from_source - from_cond;
    excess.max(0.0)
}
/// Shannon–Hartley capacity in bits per second: `B · log2(1 + SNR)`.
///
/// `snr` is a linear power ratio, not decibels.
pub fn shannon_hartley_capacity(bandwidth_hz: f64, snr: f64) -> f64 {
    let spectral_efficiency = (1.0 + snr).log2();
    bandwidth_hz * spectral_efficiency
}
/// Capacity (in bits) of a discrete memoryless channel, computed with the Blahut–Arimoto
/// iteration.
///
/// `transition[i][j]` is read as the probability of output `j` given input `i`. The output
/// alphabet size is taken from `transition[0].len()`. Rows are assumed to have equal length:
/// a shorter row panics when indexed by `j`, and a longer row panics on `py[j]`.
///
/// The input distribution starts uniform and receives 200 multiplicative updates
/// `q_i ← q_i · exp(D_i) / Σ`, where `D_i = Σ_j P(j|i) ln(P(j|i) / p(j))`. There is no
/// convergence check. The function returns the mutual information of the final `q`, in bits,
/// clamped at 0. It returns 0.0 when there are no inputs or no outputs.
pub fn channel_capacity_blahut(transition: &[Vec<f64>]) -> f64 {
    let n_in = transition.len();
    if n_in == 0 {
        return 0.0;
    }
    let n_out = transition[0].len();
    if n_out == 0 {
        return 0.0;
    }
    let mut q = vec![1.0 / n_in as f64; n_in];
    for _ in 0..200 {
        // Output distribution induced by the current input distribution.
        let mut py = vec![0.0_f64; n_out];
        for (i, qi) in q.iter().enumerate() {
            for (j, &tij) in transition[i].iter().enumerate() {
                py[j] += qi * tij;
            }
        }
        // c[i] = KL(P(·|i) ‖ p_y) in nats; zero-probability cells are skipped.
        let mut c = vec![0.0_f64; n_in];
        for i in 0..n_in {
            let mut s = 0.0_f64;
            for j in 0..n_out {
                let pij = transition[i][j];
                if pij > 0.0 && py[j] > 0.0 {
                    s += pij * (pij / py[j]).ln();
                }
            }
            c[i] = s;
        }
        // Subtracting the maximum keeps exp() from overflowing; it cancels in the
        // normalisation below.
        let c_max = c.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let mut new_q = vec![0.0_f64; n_in];
        let mut sum = 0.0_f64;
        for i in 0..n_in {
            new_q[i] = q[i] * (c[i] - c_max).exp();
            sum += new_q[i];
        }
        for qi in &mut new_q {
            *qi /= sum;
        }
        q = new_q;
    }
    // Evaluate I(X; Y) in bits for the final input distribution.
    let mut py = vec![0.0_f64; n_out];
    for (i, qi) in q.iter().enumerate() {
        for (j, &tij) in transition[i].iter().enumerate() {
            py[j] += qi * tij;
        }
    }
    let mut cap = 0.0_f64;
    for i in 0..n_in {
        if q[i] <= 0.0 {
            continue;
        }
        for j in 0..n_out {
            let pij = transition[i][j];
            if pij > 0.0 && py[j] > 0.0 {
                cap += q[i] * pij * (pij / py[j]).log2();
            }
        }
    }
    cap.max(0.0)
}
/// Capacity of a binary erasure channel with erasure probability `epsilon`: `1 - ε` bits.
///
/// The result is floored at 0.
pub fn bec_capacity(epsilon: f64) -> f64 {
    let capacity = 1.0 - epsilon;
    capacity.max(0.0)
}
/// Capacity of a binary symmetric channel with crossover probability `p`: `1 - H_b(p)` bits.
///
/// The result is floored at 0.
pub fn bsc_capacity(p: f64) -> f64 {
    let binary_entropy = -xlog2x(p) - xlog2x(1.0 - p);
    let capacity = 1.0 - binary_entropy;
    capacity.max(0.0)
}
/// AWGN channel capacity in bits per channel use: `½ log2(1 + P / N)`.
pub fn awgn_capacity(power: f64, noise: f64) -> f64 {
    let snr = power / noise;
    let bits = (1.0 + snr).log2();
    0.5 * bits
}
/// Asymptotic sphere-packing (Hamming) rate bound `1 - H_b(d / 2n)` for binary codes of
/// length `n` and minimum distance `d`.
///
/// The result is floored at 0.
pub fn sphere_packing_bound(n: usize, d: usize) -> f64 {
    let (length, distance) = (n as f64, d as f64);
    let radius_fraction = distance / (2.0 * length);
    let hb = -xlog2x(radius_fraction) - xlog2x(1.0 - radius_fraction);
    (1.0 - hb).max(0.0)
}
/// Binary Huffman codeword lengths for the symbol probabilities `probs`.
///
/// Entry `i` of the result is the depth of symbol `i` in the Huffman tree. An empty input
/// yields an empty vector, and a single symbol gets length 1.
///
/// The two least-probable subtrees are merged at each step, using a binary min-heap, for
/// O(n log n) selection. Ties are broken by creation order: original symbols in index order,
/// then merged subtrees in the order they were formed. This makes the output deterministic.
/// NaN probabilities compare as equal to everything, so their placement is unspecified.
pub fn huffman_lengths(probs: &[f64]) -> Vec<usize> {
    use std::cmp::{Ordering, Reverse};
    use std::collections::BinaryHeap;

    // A subtree of the code tree: its total probability, its creation sequence number and the
    // symbol indices (leaves) it contains.
    struct Node {
        prob: f64,
        seq: usize,
        leaves: Vec<usize>,
    }

    // Order by probability, then by creation order (older first).
    impl Ord for Node {
        fn cmp(&self, other: &Self) -> Ordering {
            self.prob
                .partial_cmp(&other.prob)
                .unwrap_or(Ordering::Equal)
                .then(self.seq.cmp(&other.seq))
        }
    }
    impl PartialOrd for Node {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }
    impl PartialEq for Node {
        fn eq(&self, other: &Self) -> bool {
            self.cmp(other) == Ordering::Equal
        }
    }
    impl Eq for Node {}

    let n = probs.len();
    if n == 0 {
        return vec![];
    }
    if n == 1 {
        return vec![1];
    }

    // `Reverse` turns std's max-heap into a min-heap.
    let mut heap: BinaryHeap<Reverse<Node>> = probs
        .iter()
        .enumerate()
        .map(|(i, &p)| {
            Reverse(Node {
                prob: p,
                seq: i,
                leaves: vec![i],
            })
        })
        .collect();
    let mut lengths = vec![0_usize; n];
    let mut next_seq = n;
    while heap.len() > 1 {
        let Reverse(first) = heap.pop().expect("heap has at least 2 elements");
        let Reverse(second) = heap.pop().expect("heap has at least 2 elements");
        // Every symbol under the two merged subtrees moves one level deeper.
        for &leaf in first.leaves.iter().chain(second.leaves.iter()) {
            lengths[leaf] += 1;
        }
        let prob = first.prob + second.prob;
        let mut leaves = first.leaves;
        leaves.extend(second.leaves);
        heap.push(Reverse(Node {
            prob,
            seq: next_seq,
            leaves,
        }));
        next_seq += 1;
    }
    lengths
}
/// Expected codeword length `Σ p_i · ℓ_i` (in bits) of the Huffman code for `probs`.
pub fn huffman_expected_length(probs: &[f64]) -> f64 {
    let code_lengths = huffman_lengths(probs);
    probs
        .iter()
        .zip(&code_lengths)
        .map(|(&p, &len)| p * len as f64)
        .sum()
}
/// Rate–distortion function of a fair binary source under Hamming distortion:
/// `R(D) = 1 - H_b(D)` bits.
///
/// `d` is clamped to `[0, 0.5]` and the result is floored at 0.
pub fn binary_rate_distortion(d: f64) -> f64 {
    let clamped = d.clamp(0.0, 0.5);
    let hd = -xlog2x(clamped) - xlog2x(1.0 - clamped);
    let rate = 1.0 - hd;
    rate.max(0.0)
}
/// Rate–distortion function of a Gaussian source with variance `sigma_sq` under squared-error
/// distortion: `R(D) = ½ log2(σ² / D)` bits for `0 < D < σ²`.
///
/// The edge cases are:
/// - `D >= σ²` returns 0, since nothing needs to be sent.
/// - `D <= 0` (with `D < σ²`) returns `+∞`, because exact reconstruction of a continuous
///   source needs unbounded rate.
pub fn gaussian_rate_distortion(sigma_sq: f64, d: f64) -> f64 {
    if d >= sigma_sq {
        return 0.0;
    }
    if d <= 0.0 {
        return f64::INFINITY;
    }
    0.5 * (sigma_sq / d).log2()
}
/// Rate–distortion function of a uniform `m`-ary source under Hamming distortion:
/// `R(D) = log2 m - H_b(D) - D log2(m - 1)` bits.
///
/// `d` is clamped to `[0, 1 - 1/m]` and the result is floored at 0. Returns 0.0 for `m <= 1`.
pub fn uniform_rate_distortion(m: usize, d: f64) -> f64 {
    if m <= 1 {
        return 0.0;
    }
    let alphabet = m as f64;
    let d = d.clamp(0.0, 1.0 - 1.0 / alphabet);
    let hd = -xlog2x(d) - xlog2x(1.0 - d);
    let rate = alphabet.log2() - hd - d * ((m - 1) as f64).log2();
    rate.max(0.0)
}
/// Distortion–rate function of a Gaussian source: `D(R) = σ² · 2^(-2R)`.
pub fn gaussian_distortion_at_rate(sigma_sq: f64, rate: f64) -> f64 {
    let exponent = -2.0 * rate;
    sigma_sq * 2.0_f64.powf(exponent)
}
/// Checks the Kraft inequality `Σ 2^(-ℓ_i) <= 1` (with `1e-9` slack for rounding).
///
/// Codeword lengths beyond `i32::MAX` are treated as `i32::MAX`. Their contribution underflows
/// to 0 either way, and this avoids the wrap-around of an `as i32` cast.
pub fn kraft_inequality(lengths: &[usize]) -> bool {
    let sum: f64 = lengths
        .iter()
        .map(|&l| {
            // Saturate instead of truncating: `usize::MAX as i32` would be -1 (term 2^1), and
            // `-(i32::MIN)` overflows.
            let exponent = l.min(i32::MAX as usize) as i32;
            2.0_f64.powi(-exponent)
        })
        .sum();
    sum <= 1.0 + 1e-9
}
/// Redundancy of a code: expected length `Σ p_i ℓ_i` minus the source entropy, in bits.
///
/// # Panics
/// Panics if `probs` and `lengths` have different lengths.
pub fn code_redundancy(probs: &[f64], lengths: &[usize]) -> f64 {
    assert_eq!(
        probs.len(),
        lengths.len(),
        "code_redundancy: length mismatch"
    );
    let average_length: f64 = probs
        .iter()
        .zip(lengths)
        .map(|(&p, &l)| p * l as f64)
        .sum();
    let entropy = entropy_bits(probs);
    average_length - entropy
}
/// Codeword lengths `⌈-log2 p_i⌉ + 1`, as in the Shannon–Fano–Elias construction.
///
/// Symbols with `p_i <= 0` get length 0.
pub fn shannon_fano_lengths(probs: &[f64]) -> Vec<usize> {
    let mut lengths = Vec::with_capacity(probs.len());
    for &p in probs {
        let len = if p <= 0.0 {
            0
        } else {
            (-p.log2()).ceil() as usize + 1
        };
        lengths.push(len);
    }
    lengths
}
/// Akaike information criterion `2k - 2 ln L`.
pub fn aic(k: usize, log_likelihood: f64) -> f64 {
    let penalty = 2.0 * k as f64;
    penalty - 2.0 * log_likelihood
}
/// Small-sample corrected AIC: `AIC + 2k(k + 1) / (n - k - 1)`.
///
/// Falls back to plain AIC when `n - k - 1 <= 0`, where the correction is undefined.
pub fn aicc(k: usize, n: usize, log_likelihood: f64) -> f64 {
    let (params, samples) = (k as f64, n as f64);
    let denom = samples - params - 1.0;
    let aic_value = aic(k, log_likelihood);
    if denom <= 0.0 {
        return aic_value;
    }
    aic_value + 2.0 * params * (params + 1.0) / denom
}
/// Bayesian information criterion `k ln n - 2 ln L`.
pub fn bic(k: usize, n: usize, log_likelihood: f64) -> f64 {
    let penalty = k as f64 * (n as f64).ln();
    penalty - 2.0 * log_likelihood
}
pub fn select_by_aic(models: &[(usize, f64)]) -> Option<usize> {
models
.iter()
.enumerate()
.min_by(|a, b| {
let aic_a = aic(a.1.0, a.1.1);
let aic_b = aic(b.1.0, b.1.1);
aic_a
.partial_cmp(&aic_b)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
}
/// Two-part minimum description length score: `-ln L + ½ k ln n`.
pub fn mdl(k: usize, n: usize, log_likelihood: f64) -> f64 {
    let complexity = 0.5 * k as f64 * (n as f64).ln();
    -log_likelihood + complexity
}
/// Fisher metric of a categorical distribution, `diag(1 / p_i)`, as a row-major `k × k`
/// matrix.
///
/// Diagonal entries with `p_i <= 0` are left at 0.
pub fn fisher_metric_categorical(probs: &[f64]) -> Vec<f64> {
    let dim = probs.len();
    let mut metric = vec![0.0_f64; dim * dim];
    probs
        .iter()
        .enumerate()
        .filter(|&(_, &p)| p > 0.0)
        .for_each(|(i, &p)| metric[i * (dim + 1)] = 1.0 / p);
    metric
}
/// Fisher information about the mean of `N(μ, σ²)`: `1 / σ²`.
///
/// Returns `+∞` for `sigma <= 0`.
pub fn fisher_information_gaussian_mean(sigma: f64) -> f64 {
    if sigma <= 0.0 {
        return f64::INFINITY;
    }
    let variance = sigma * sigma;
    1.0 / variance
}
/// Fisher information of a Bernoulli(p) observation: `1 / (p (1 - p))`.
///
/// Returns `+∞` when `p (1 - p) <= 0`.
pub fn fisher_information_bernoulli(p: f64) -> f64 {
    let variance = p * (1.0 - p);
    if variance <= 0.0 {
        return f64::INFINITY;
    }
    1.0 / variance
}
/// Fisher information of a Poisson(λ) observation: `1 / λ`.
///
/// Returns `+∞` for `lambda <= 0`.
pub fn fisher_information_poisson(lambda: f64) -> f64 {
    if lambda <= 0.0 {
        return f64::INFINITY;
    }
    1.0 / lambda
}
/// Fisher information about the rate of an Exponential(λ) observation: `1 / λ²`.
///
/// Returns `+∞` for `lambda <= 0`.
pub fn fisher_information_exponential(lambda: f64) -> f64 {
    if lambda <= 0.0 {
        return f64::INFINITY;
    }
    let rate_sq = lambda * lambda;
    1.0 / rate_sq
}
/// Fisher information matrix of `N(μ, σ²)` in the `(μ, σ)` parameterisation:
/// `diag(1/σ², 2/σ²)`, returned row-major as `[I_μμ, I_μσ, I_σμ, I_σσ]`.
pub fn fisher_matrix_gaussian(sigma: f64) -> [f64; 4] {
    let variance = sigma * sigma;
    let mean_info = 1.0 / variance;
    let scale_info = 2.0 / variance;
    [mean_info, 0.0, 0.0, scale_info]
}
/// Empirical Fisher information: the mean squared score over `data`.
///
/// The score at each `x` is estimated by a central finite difference
/// `(log_p(θ + h, x) - log_p(θ - h, x)) / 2h`. The last argument is the step `h`; a
/// non-positive (or NaN) step falls back to `1e-6`. Returns 0.0 for empty `data`.
pub fn fisher_information_numerical(
    log_p: impl Fn(f64, f64) -> f64,
    theta: f64,
    data: &[f64],
    step: f64,
) -> f64 {
    if data.is_empty() {
        return 0.0;
    }
    let h = if step > 0.0 { step } else { 1e-6 };
    let sample_count = data.len() as f64;
    let total: f64 = data
        .iter()
        .map(|&x| {
            let score = (log_p(theta + h, x) - log_p(theta - h, x)) / (2.0 * h);
            score * score
        })
        .sum();
    total / sample_count
}
/// One natural-gradient descent step for categorical parameters.
///
/// Each entry becomes `p_i - lr · p_i · g_i`. This is the step along `F⁻¹ g` for the diagonal
/// Fisher metric `diag(1/p_i)`. The result is not projected or renormalised, so it need not sum
/// to 1 or stay non-negative.
///
/// # Panics
/// Panics if `probs` and `grad` have different lengths.
pub fn natural_gradient_step_categorical(
    probs: &[f64],
    grad: &[f64],
    learning_rate: f64,
) -> Vec<f64> {
    assert_eq!(
        probs.len(),
        grad.len(),
        "natural_gradient_step: length mismatch"
    );
    let mut stepped = Vec::with_capacity(probs.len());
    for (&p, &g) in probs.iter().zip(grad) {
        stepped.push(p - learning_rate * p * g);
    }
    stepped
}
/// Fisher–Rao geodesic distance between two categorical distributions:
/// `2 · arccos(Σ sqrt(p_i q_i))`.
///
/// The Bhattacharyya coefficient is clamped to `[-1, 1]` to guard `acos` against rounding.
///
/// # Panics
/// Panics if `p` and `q` have different lengths.
pub fn fisher_rao_distance(p: &[f64], q: &[f64]) -> f64 {
    assert_eq!(p.len(), q.len(), "fisher_rao_distance: length mismatch");
    let overlap: f64 = p.iter().zip(q).map(|(&a, &b)| (a * b).sqrt()).sum();
    let half_angle = overlap.clamp(-1.0, 1.0).acos();
    2.0 * half_angle
}
/// Fisher–Rao distance between `N(mu1, sigma1²)` and `N(mu2, sigma2²)`.
///
/// The Fisher metric `ds² = (dμ² + 2 dσ²) / σ²` is √2 times the hyperbolic half-plane metric
/// in `(μ/√2, σ)`. That gives the closed form
///
/// `d = √2 · acosh(1 + ((μ1 - μ2)²/2 + (σ1 - σ2)²) / (2 σ1 σ2))`.
///
/// The standard deviations are taken by absolute value. A zero standard deviation gives `+∞`.
pub fn fisher_rao_gaussian(mu1: f64, sigma1: f64, mu2: f64, sigma2: f64) -> f64 {
    let s1 = sigma1.abs();
    let s2 = sigma2.abs();
    if s1 <= 0.0 || s2 <= 0.0 {
        return f64::INFINITY;
    }
    let dmu = mu1 - mu2;
    let ds = s1 - s2;
    let x = (0.5 * dmu * dmu + ds * ds) / (2.0 * s1 * s2);
    // acosh(1 + x) = ln(1 + x + sqrt(x (x + 2))); ln_1p keeps full precision for nearby points.
    std::f64::consts::SQRT_2 * (x + (x * (x + 2.0)).sqrt()).ln_1p()
}
/// Christoffel symbol of the first kind `Γ_ijk` for the diagonal metric `g_ii = 1 / p_i`.
///
/// Only the fully diagonal symbol is non-zero: `Γ_iii = -1 / (2 p_i²)`. All others, and
/// `Γ_iii` with `p_i <= 0`, are 0.
///
/// # Panics
/// Panics if `i == j == k` and `i` is out of bounds for `probs`.
pub fn christoffel_first_kind_simplex(probs: &[f64], i: usize, j: usize, k: usize) -> f64 {
    let all_equal = i == j && j == k;
    if !all_equal {
        return 0.0;
    }
    let p = probs[i];
    if p > 0.0 {
        -1.0 / (2.0 * p * p)
    } else {
        0.0
    }
}
/// Exponential map on the probability simplex, computed through the square-root embedding
/// `ξ_i = √p_i` onto the unit sphere. This is intended as the Fisher-metric geodesic.
///
/// The computation proceeds as follows:
/// - The tangent vector `v` at `p` is mapped to sphere coordinates `v_i / (2√p_i)`. Entries
///   with `p_i <= 0` are set to 0.
/// - The component of that vector along `ξ` is removed.
/// - The point is rotated along the great circle by angle `|τ| · t`, where `τ` is the
///   projected tangent.
/// - The result is the squared coordinates, renormalised to sum to 1. Renormalisation is
///   skipped if their sum is not positive.
///
/// If `|τ| < 1e-15`, a copy of `p` is returned.
///
/// # Panics
/// Panics if `p` and `v` have different lengths.
pub fn exponential_map_simplex(p: &[f64], v: &[f64], t: f64) -> Vec<f64> {
    assert_eq!(p.len(), v.len(), "exponential_map: length mismatch");
    let n = p.len();
    // Point on the unit sphere.
    let xi: Vec<f64> = p.iter().map(|&pi| pi.sqrt()).collect();
    // Push the tangent vector forward through p ↦ √p (derivative 1 / (2√p)).
    let mut tang: Vec<f64> = (0..n)
        .map(|i| {
            if p[i] > 0.0 {
                v[i] / (2.0 * p[i].sqrt())
            } else {
                0.0
            }
        })
        .collect();
    // Project out the radial component so the direction is tangent to the sphere.
    let dot: f64 = xi.iter().zip(tang.iter()).map(|(&x, &t_)| x * t_).sum();
    for i in 0..n {
        tang[i] -= dot * xi[i];
    }
    let tang_norm: f64 = tang.iter().map(|&t_| t_ * t_).sum::<f64>().sqrt();
    // No usable direction: stay at p.
    if tang_norm < 1e-15 {
        return p.to_vec();
    }
    // Great-circle step: cos(θ)·ξ + sin(θ)·τ/|τ| with θ = |τ|·t.
    let angle = tang_norm * t;
    let cos_a = angle.cos();
    let sin_a = angle.sin();
    let mut result = vec![0.0_f64; n];
    for i in 0..n {
        let new_xi = cos_a * xi[i] + sin_a * tang[i] / tang_norm;
        // Map back to the simplex by squaring.
        result[i] = new_xi * new_xi;
    }
    // Renormalise to absorb rounding drift; the success flag is intentionally ignored.
    let _ = normalise_inplace(&mut result);
    result
}
/// Approximate transport of a tangent vector `v` at `p` toward `q`.
///
/// Each component is scaled by `sqrt(p_t,i / p_i)`, where `p_t = (1 - t) p + t q` is the
/// straight-line interpolation. Components with `p_i <= 0` or `p_t,i <= 0` become 0. If the
/// Fisher–Rao distance between `p` and `q` is below `1e-15`, `v` is returned unchanged.
///
/// NOTE(review): this is not the Levi-Civita parallel transport of the Fisher metric, which
/// would follow the geodesic rather than the linear path. Confirm this heuristic is intended.
///
/// # Panics
/// Panics if `q` or `v` differ in length from `p`.
pub fn parallel_transport_simplex(p: &[f64], q: &[f64], v: &[f64], t: f64) -> Vec<f64> {
    let n = p.len();
    assert_eq!(n, q.len());
    assert_eq!(n, v.len());
    if fisher_rao_distance(p, q) < 1e-15 {
        return v.to_vec();
    }
    (0..n)
        .map(|i| {
            let mixed = (1.0 - t) * p[i] + t * q[i];
            if p[i] > 0.0 && mixed > 0.0 {
                v[i] * (mixed / p[i]).sqrt()
            } else {
                0.0
            }
        })
        .collect()
}