logp/lib.rs
1//! # logp
2//!
3//! Information theory primitives: entropies and divergences.
4//!
5//! ## Scope
6//!
7//! This crate is **L1 (Logic)** in the mathematical foundation: it should stay small and reusable.
8//! It provides *scalar* information measures that appear across clustering, ranking,
9//! evaluation, and geometry:
10//!
11//! - Shannon entropy and cross-entropy
12//! - KL / Jensen–Shannon divergences
13//! - Csiszár \(f\)-divergences (a.k.a. *information monotone* divergences)
14//! - Bhattacharyya coefficient, Rényi/Tsallis families
15//! - Bregman divergences (convex-analytic, not generally monotone)
16//!
17//! ## Distances vs divergences (terminology that prevents bugs)
18//!
19//! A **divergence** \(D(p:q)\) is usually required to satisfy:
20//!
21//! - \(D(p:q) \ge 0\)
22//! - \(D(p:p) = 0\)
23//!
24//! but it is typically **not** symmetric and **not** a metric (no triangle inequality).
25//! Many failures in downstream code are caused by treating a divergence as a distance metric.
26//!
27//! ## Key invariants (what tests should enforce)
28//!
29//! - **Jensen–Shannon** is bounded on the simplex:
30//! \(0 \le JS(p,q) \le \ln 2\) (nats).
31//! - **Csiszár \(f\)-divergences** are monotone under coarse-graining (Markov kernels):
32//! merging bins cannot increase the divergence.
33//!
34//! ## Further reading
35//!
36//! - Frank Nielsen, “Divergences” portal (taxonomy diagrams + references):
37//! <https://franknielsen.github.io/Divergence/index.html>
38//! - `nocotan/awesome-information-geometry` (curated reading list):
39//! <https://github.com/nocotan/awesome-information-geometry>
40//! - Csiszár (1967): \(f\)-divergences and information monotonicity.
41//! - Amari & Nagaoka (2000): *Methods of Information Geometry*.
42//!
43//! ## Taxonomy of Divergences (Nielsen)
44//!
45//! | Family | Generator | Key Property |
46//! |---|---|---|
47//! | **f-divergences** | Convex \(f(t)\) with \(f(1)=0\) | Monotone under Markov morphisms (coarse-graining) |
48//! | **Bregman** | Convex \(F(x)\) | Dually flat geometry; generalized Pythagorean theorem |
49//! | **Jensen-Shannon** | \(f\)-div + metric | Symmetric, bounded \([0, \ln 2]\), \(\sqrt{JS}\) is a metric |
50//! | **Alpha** | \(\rho_\alpha = \int p^\alpha q^{1-\alpha}\) | Encodes Rényi, Tsallis, Bhattacharyya, Hellinger |
51//!
52//! ## Connections
53//!
54//! - [`rkhs`](../rkhs): MMD and KL both measure distribution "distance"
55//! - [`wass`](../wass): Wasserstein vs entropy-based divergences
56//! - [`stratify`](../stratify): NMI for cluster evaluation uses this crate
57//! - [`fynch`](../fynch): Temperature scaling affects entropy calibration
58//!
59//! ## References
60//!
61//! - Shannon (1948). "A Mathematical Theory of Communication"
62//! - Cover & Thomas (2006). "Elements of Information Theory"
63
64#![forbid(unsafe_code)]
65
66use thiserror::Error;
67
68mod ksg;
69pub use ksg::{mutual_information_ksg, KsgVariant};
70
/// Natural log of 2 (≈ 0.6931), re-exported from `core`.
///
/// Useful when converting nats ↔ bits (divide nats by `LN_2`) or bounding
/// Jensen–Shannon, whose maximum on the simplex is exactly `LN_2` nats.
pub const LN_2: f64 = core::f64::consts::LN_2;
73
74/// KL Divergence between two diagonal Multivariate Gaussians.
75///
76/// Used for Variational Information Bottleneck (VIB) to regularize latent spaces.
77///
78/// Returns 0.5 * Σ [ (std1/std2)^2 + (mu2-mu1)^2 / std2^2 - 1 + 2*ln(std2/std1) ]
79///
80/// # Examples
81///
82/// ```
83/// # use logp::kl_divergence_gaussians;
84/// // KL(N || N) = 0 (identical Gaussians).
85/// let mu = [0.0, 1.0];
86/// let std = [1.0, 2.0];
87/// let kl = kl_divergence_gaussians(&mu, &std, &mu, &std).unwrap();
88/// assert!(kl.abs() < 1e-12);
89///
90/// // KL is non-negative for distinct Gaussians.
91/// let kl = kl_divergence_gaussians(&[0.0], &[1.0], &[1.0], &[1.0]).unwrap();
92/// assert!(kl >= 0.0);
93/// ```
94pub fn kl_divergence_gaussians(
95 mu1: &[f64],
96 std1: &[f64],
97 mu2: &[f64],
98 std2: &[f64],
99) -> Result<f64> {
100 ensure_same_len(mu1, std1)?;
101 ensure_same_len(mu1, mu2)?;
102 ensure_same_len(mu1, std2)?;
103
104 let mut kl = 0.0;
105 for (((&m1, &s1), &m2), &s2) in mu1.iter().zip(std1).zip(mu2).zip(std2) {
106 if s1 <= 0.0 || s2 <= 0.0 {
107 return Err(Error::Domain("standard deviation must be positive"));
108 }
109 let v1 = s1 * s1;
110 let v2 = s2 * s2;
111 kl += (v1 / v2) + (m2 - m1).powi(2) / v2 - 1.0 + 2.0 * (s2.ln() - s1.ln());
112 }
113 Ok(0.5 * kl)
114}
115
/// Errors for information-measure computations.
#[derive(Debug, Error)]
pub enum Error {
    /// Two slices that must be paired element-wise have different lengths.
    #[error("length mismatch: {0} vs {1}")]
    LengthMismatch(usize, usize),

    /// An input slice was empty where at least one element is required.
    #[error("empty input")]
    Empty,

    /// An entry was NaN or ±∞; all probability inputs must be finite.
    #[error("non-finite entry at index {idx}: {value}")]
    NonFinite { idx: usize, value: f64 },

    /// An entry was negative; probabilities must be >= 0.
    #[error("negative entry at index {idx}: {value}")]
    Negative { idx: usize, value: f64 },

    /// A distribution did not sum to 1 within the caller-supplied tolerance.
    #[error("not normalized (expected sum≈1): sum={sum}")]
    NotNormalized { sum: f64 },

    /// The order parameter of an alpha-family divergence was out of range
    /// (e.g. alpha = 1 for Rényi, where the formula degenerates to KL).
    #[error("invalid alpha: {alpha} (must be finite and not equal to {forbidden})")]
    InvalidAlpha { alpha: f64, forbidden: f64 },

    /// Catch-all for mathematically undefined results (e.g. q_i = 0 while
    /// p_i > 0 in KL, or an infinite Bhattacharyya distance).
    #[error("domain error: {0}")]
    Domain(&'static str),
}

/// Crate-wide result alias over [`Error`].
pub type Result<T> = core::result::Result<T, Error>;
142
143fn ensure_nonempty(x: &[f64]) -> Result<()> {
144 if x.is_empty() {
145 return Err(Error::Empty);
146 }
147 Ok(())
148}
149
150fn ensure_same_len(a: &[f64], b: &[f64]) -> Result<()> {
151 if a.len() != b.len() {
152 return Err(Error::LengthMismatch(a.len(), b.len()));
153 }
154 Ok(())
155}
156
157fn ensure_nonnegative(x: &[f64]) -> Result<()> {
158 for (i, &v) in x.iter().enumerate() {
159 if !v.is_finite() {
160 return Err(Error::NonFinite { idx: i, value: v });
161 }
162 if v < 0.0 {
163 return Err(Error::Negative { idx: i, value: v });
164 }
165 }
166 Ok(())
167}
168
/// Plain left-to-right accumulation of a slice (no compensated summation).
fn sum(x: &[f64]) -> f64 {
    x.iter().fold(0.0, |acc, &v| acc + v)
}
172
173/// Validate that `p` is a probability distribution on the simplex (within `tol`).
174///
175/// # Examples
176///
177/// ```
178/// # use logp::validate_simplex;
179/// // Valid simplex.
180/// assert!(validate_simplex(&[0.3, 0.7], 1e-9).is_ok());
181/// assert!(validate_simplex(&[1.0], 1e-9).is_ok());
182///
183/// // Rejects bad sum.
184/// assert!(validate_simplex(&[0.3, 0.6], 1e-9).is_err());
185///
186/// // Rejects negative entries.
187/// assert!(validate_simplex(&[1.5, -0.5], 1e-9).is_err());
188///
189/// // Rejects empty input.
190/// assert!(validate_simplex(&[], 1e-9).is_err());
191/// ```
192pub fn validate_simplex(p: &[f64], tol: f64) -> Result<()> {
193 ensure_nonempty(p)?;
194 ensure_nonnegative(p)?;
195 let s = sum(p);
196 if (s - 1.0).abs() > tol {
197 return Err(Error::NotNormalized { sum: s });
198 }
199 Ok(())
200}
201
202/// Normalize a nonnegative vector in-place to sum to 1.
203///
204/// Returns the original sum.
205///
206/// # Examples
207///
208/// ```
209/// # use logp::normalize_in_place;
210/// let mut v = vec![2.0, 3.0, 5.0];
211/// let original_sum = normalize_in_place(&mut v).unwrap();
212/// assert!((original_sum - 10.0).abs() < 1e-12);
213/// assert!((v[0] - 0.2).abs() < 1e-12);
214/// assert!((v[1] - 0.3).abs() < 1e-12);
215/// assert!((v[2] - 0.5).abs() < 1e-12);
216///
217/// // Rejects all-zero input.
218/// assert!(normalize_in_place(&mut vec![0.0, 0.0]).is_err());
219/// ```
220pub fn normalize_in_place(p: &mut [f64]) -> Result<f64> {
221 ensure_nonempty(p)?;
222 ensure_nonnegative(p)?;
223 let s = sum(p);
224 if s <= 0.0 {
225 return Err(Error::Domain("cannot normalize: sum <= 0"));
226 }
227 for v in p.iter_mut() {
228 *v /= s;
229 }
230 Ok(s)
231}
232
233/// Shannon entropy in nats: the expected surprise under distribution \(p\).
234///
235/// \[H(p) = -\sum_i p_i \ln p_i\]
236///
237/// # Key properties
238///
239/// - **Non-negative**: \(H(p) \ge 0\), with equality iff \(p\) is a delta (point mass).
240/// - **Maximized by uniform**: among distributions on \(n\) outcomes,
241/// \(H(p) \le \ln n\), with equality iff \(p_i = 1/n\) for all \(i\).
242/// - **Concavity**: \(H\) is a concave function of \(p\) on the simplex.
243/// Mixing distributions never decreases entropy.
244/// - **Units**: result is in nats (base \(e\)); divide by \(\ln 2\) for bits.
245///
246/// # Domain
247///
248/// Requires `p` to be a valid simplex distribution (within `tol`).
249///
250/// # Examples
251///
252/// ```
253/// # use logp::entropy_nats;
254/// // Uniform distribution over 4 outcomes: H = ln(4).
255/// let p = [0.25, 0.25, 0.25, 0.25];
256/// let h = entropy_nats(&p, 1e-9).unwrap();
257/// assert!((h - 4.0_f64.ln()).abs() < 1e-12);
258///
259/// // Delta (point mass): H = 0.
260/// let delta = [1.0, 0.0, 0.0];
261/// assert!(entropy_nats(&delta, 1e-9).unwrap().abs() < 1e-15);
262/// ```
263pub fn entropy_nats(p: &[f64], tol: f64) -> Result<f64> {
264 validate_simplex(p, tol)?;
265 let mut h = 0.0;
266 for &pi in p {
267 if pi > 0.0 {
268 h -= pi * pi.ln();
269 }
270 }
271 Ok(h)
272}
273
274/// Shannon entropy in bits.
275///
276/// # Examples
277///
278/// ```
279/// # use logp::{entropy_bits, entropy_nats, LN_2};
280/// // Fair coin: H = 1 bit.
281/// let p = [0.5, 0.5];
282/// let bits = entropy_bits(&p, 1e-9).unwrap();
283/// assert!((bits - 1.0).abs() < 1e-12);
284///
285/// // Consistent with nats / ln(2).
286/// let nats = entropy_nats(&p, 1e-9).unwrap();
287/// assert!((bits - nats / LN_2).abs() < 1e-12);
288/// ```
289pub fn entropy_bits(p: &[f64], tol: f64) -> Result<f64> {
290 Ok(entropy_nats(p, tol)? / LN_2)
291}
292
/// Fast Shannon entropy calculation without simplex validation.
///
/// Used in performance-critical loops like Sinkhorn iteration for Optimal Transport.
///
/// # Invariant
/// Assumes `p` is non-negative and normalized; the caller is responsible for
/// validation (see [`entropy_nats`] for the checked version).
///
/// # Examples
///
/// ```
/// # use logp::{entropy_unchecked, LN_2};
/// // Fair coin: H = ln(2).
/// let h = entropy_unchecked(&[0.5, 0.5]);
/// assert!((h - LN_2).abs() < 1e-12);
///
/// // Agrees with the checked version on valid input.
/// let p = [0.3, 0.7];
/// let h_checked = logp::entropy_nats(&p, 1e-9).unwrap();
/// assert!((entropy_unchecked(&p) - h_checked).abs() < 1e-15);
/// ```
#[inline]
pub fn entropy_unchecked(p: &[f64]) -> f64 {
    // Zero-probability bins are skipped: 0 ln 0 := 0 by continuity.
    p.iter()
        .filter(|&&pi| pi > 0.0)
        .map(|&pi| -(pi * pi.ln()))
        .sum()
}
323
324/// Cross-entropy in nats: the expected code length when using model \(q\) to encode
325/// data drawn from true distribution \(p\).
326///
327/// \[H(p, q) = -\sum_i p_i \ln q_i\]
328///
329/// # Key properties
330///
331/// - **Decomposition identity**: cross-entropy splits into entropy plus KL divergence:
332/// \(H(p, q) = H(p) + D_{KL}(p \| q)\).
333/// This means \(H(p, q) \ge H(p)\) with equality iff \(p = q\).
334/// - **Loss function**: minimizing \(H(p, q)\) over \(q\) is equivalent to minimizing
335/// \(D_{KL}(p \| q)\), which is why cross-entropy is the standard classification loss.
336/// - **Not symmetric**: \(H(p, q) \ne H(q, p)\) in general.
337///
338/// # Domain
339///
340/// `p` must be on the simplex; `q` must be nonnegative and normalized; and
341/// whenever `p_i > 0`, we require `q_i > 0` (otherwise cross-entropy is infinite).
342///
343/// # Examples
344///
345/// ```
346/// # use logp::{cross_entropy_nats, entropy_nats, kl_divergence};
347/// let p = [0.3, 0.7];
348/// let q = [0.5, 0.5];
349/// let h_pq = cross_entropy_nats(&p, &q, 1e-9).unwrap();
350///
351/// // Decomposition: H(p,q) = H(p) + KL(p||q).
352/// let h_p = entropy_nats(&p, 1e-9).unwrap();
353/// let kl = kl_divergence(&p, &q, 1e-9).unwrap();
354/// assert!((h_pq - (h_p + kl)).abs() < 1e-12);
355///
356/// // Self-cross-entropy equals entropy: H(p,p) = H(p).
357/// let h_pp = cross_entropy_nats(&p, &p, 1e-9).unwrap();
358/// assert!((h_pp - h_p).abs() < 1e-12);
359/// ```
360pub fn cross_entropy_nats(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
361 validate_simplex(p, tol)?;
362 validate_simplex(q, tol)?;
363 let mut h = 0.0;
364 for (&pi, &qi) in p.iter().zip(q.iter()) {
365 if pi == 0.0 {
366 continue;
367 }
368 if qi <= 0.0 {
369 return Err(Error::Domain("cross-entropy undefined: q_i=0 while p_i>0"));
370 }
371 h -= pi * qi.ln();
372 }
373 Ok(h)
374}
375
376/// Kullback--Leibler divergence in nats: the information lost when \(q\) is used to
377/// approximate \(p\).
378///
379/// \[D_{KL}(p \| q) = \sum_i p_i \ln \frac{p_i}{q_i}\]
380///
381/// # Key properties
382///
383/// - **Gibbs' inequality**: \(D_{KL}(p \| q) \ge 0\), with equality iff \(p = q\).
384/// This follows directly from Jensen's inequality applied to \(-\ln\).
385/// - **Not symmetric**: \(D_{KL}(p \| q) \ne D_{KL}(q \| p)\) in general;
386/// this is why KL is a divergence, not a distance.
387/// - **Not bounded above**: KL can be arbitrarily large when supports differ.
388/// - **Connection to MLE**: minimizing \(D_{KL}(p_{data} \| q_\theta)\) over \(\theta\)
389/// is equivalent to maximum likelihood estimation.
390/// - **Additive for independent distributions**: if \(p = p_1 \otimes p_2\) and
391/// \(q = q_1 \otimes q_2\), then
392/// \(D_{KL}(p \| q) = D_{KL}(p_1 \| q_1) + D_{KL}(p_2 \| q_2)\).
393///
394/// # Domain
395///
396/// `p` and `q` must be valid simplex distributions; and whenever `p_i > 0`,
397/// we require `q_i > 0`.
398///
399/// # Examples
400///
401/// ```
402/// # use logp::kl_divergence;
403/// // KL(p || p) = 0 (Gibbs' inequality, tight case).
404/// let p = [0.2, 0.3, 0.5];
405/// assert!(kl_divergence(&p, &p, 1e-9).unwrap().abs() < 1e-15);
406///
407/// // KL is non-negative.
408/// let q = [0.5, 0.25, 0.25];
409/// assert!(kl_divergence(&p, &q, 1e-9).unwrap() >= 0.0);
410///
411/// // Not symmetric in general.
412/// let kl_pq = kl_divergence(&p, &q, 1e-9).unwrap();
413/// let kl_qp = kl_divergence(&q, &p, 1e-9).unwrap();
414/// assert!((kl_pq - kl_qp).abs() > 1e-6);
415/// ```
416pub fn kl_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
417 ensure_same_len(p, q)?;
418 validate_simplex(p, tol)?;
419 validate_simplex(q, tol)?;
420 let mut d = 0.0;
421 for (&pi, &qi) in p.iter().zip(q.iter()) {
422 if pi == 0.0 {
423 continue;
424 }
425 if qi <= 0.0 {
426 return Err(Error::Domain("KL undefined: q_i=0 while p_i>0"));
427 }
428 d += pi * (pi / qi).ln();
429 }
430 Ok(d)
431}
432
433/// Jensen--Shannon divergence in nats: a symmetric, bounded smoothing of KL divergence.
434///
435/// \[JS(p, q) = \tfrac{1}{2} D_{KL}(p \| m) + \tfrac{1}{2} D_{KL}(q \| m), \quad m = \tfrac{1}{2}(p + q)\]
436///
437/// # Key properties
438///
439/// - **Symmetric**: \(JS(p, q) = JS(q, p)\), unlike KL.
440/// - **Bounded**: \(0 \le JS(p, q) \le \ln 2\). The upper bound is attained when \(p\)
441/// and \(q\) have disjoint supports.
442/// - **Square root is a metric**: \(\sqrt{JS(p, q)}\) satisfies the triangle inequality
443/// (Endres & Schindelin, 2003), so it can be used as a proper distance function.
444/// - **Connection to mutual information**: \(JS(p, q) = I(X; Z)\) where \(Z\) is a
445/// fair coin selecting between \(p\) and \(q\), and \(X\) is drawn from the selected
446/// distribution.
447/// - **Always finite**: because \(m_i > 0\) whenever \(p_i > 0\) or \(q_i > 0\), the
448/// KL terms are always well-defined (no division by zero).
449///
450/// # Domain
451///
452/// `p`, `q` must be simplex distributions.
453///
454/// # Examples
455///
456/// ```
457/// # use logp::{jensen_shannon_divergence, LN_2};
458/// // JS(p, p) = 0.
459/// let p = [0.3, 0.7];
460/// assert!(jensen_shannon_divergence(&p, &p, 1e-9).unwrap().abs() < 1e-15);
461///
462/// // Disjoint supports: JS = ln(2).
463/// let a = [1.0, 0.0];
464/// let b = [0.0, 1.0];
465/// let js = jensen_shannon_divergence(&a, &b, 1e-9).unwrap();
466/// assert!((js - LN_2).abs() < 1e-12);
467///
468/// // Symmetric: JS(p, q) = JS(q, p).
469/// let q = [0.5, 0.5];
470/// let js_pq = jensen_shannon_divergence(&p, &q, 1e-9).unwrap();
471/// let js_qp = jensen_shannon_divergence(&q, &p, 1e-9).unwrap();
472/// assert!((js_pq - js_qp).abs() < 1e-15);
473/// ```
474pub fn jensen_shannon_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
475 ensure_same_len(p, q)?;
476 validate_simplex(p, tol)?;
477 validate_simplex(q, tol)?;
478
479 let mut m = vec![0.0; p.len()];
480 for i in 0..p.len() {
481 m[i] = 0.5 * (p[i] + q[i]);
482 }
483
484 Ok(0.5 * kl_divergence(p, &m, tol)? + 0.5 * kl_divergence(q, &m, tol)?)
485}
486
487/// Mutual information in nats: how much knowing \(Y\) reduces uncertainty about \(X\).
488///
489/// \[I(X; Y) = \sum_{x,y} p(x,y) \ln \frac{p(x,y)}{p(x)\,p(y)}\]
490///
491/// # Key properties
492///
493/// - **KL form**: \(I(X; Y) = D_{KL}\bigl(p(x,y) \;\|\; p(x)\,p(y)\bigr)\), measuring
494/// how far the joint distribution is from the product of its marginals.
495/// - **Non-negative**: \(I(X; Y) \ge 0\), with equality iff \(X\) and \(Y\) are
496/// independent.
497/// - **Symmetric**: \(I(X; Y) = I(Y; X)\).
498/// - **Bounded by entropy**: \(I(X; Y) \le \min\{H(X),\, H(Y)\}\).
499/// - **Data processing inequality**: for any Markov chain \(X \to Y \to Z\),
500/// \(I(X; Z) \le I(X; Y)\). Processing cannot create information.
501/// - **Entropy decomposition**: \(I(X; Y) = H(X) + H(Y) - H(X, Y) = H(X) - H(X|Y)\).
502///
503/// # Layout
504///
505/// For discrete distributions, given a **row-major** joint distribution table `p_xy`
506/// with shape `(n_x, n_y)`.
507///
508/// Public invariant (this is the important one): this API is **backend-agnostic**.
509/// It does not force `ndarray` into the public surface of an L1 crate.
510///
511/// # Examples
512///
513/// ```
514/// # use logp::{mutual_information, entropy_nats, LN_2};
515/// // Independent joint: p(x,y) = p(x)*p(y), so I(X;Y) = 0.
516/// let p_xy = [0.15, 0.35, 0.15, 0.35]; // 2x2, marginals [0.5,0.5] x [0.3,0.7]
517/// let mi = mutual_information(&p_xy, 2, 2, 1e-9).unwrap();
518/// assert!(mi.abs() < 1e-12);
519///
520/// // Perfect correlation (Y = X, uniform bit): I(X;Y) = H(X) = ln(2).
521/// let diag = [0.5, 0.0, 0.0, 0.5];
522/// let mi = mutual_information(&diag, 2, 2, 1e-9).unwrap();
523/// assert!((mi - LN_2).abs() < 1e-12);
524/// ```
525pub fn mutual_information(p_xy: &[f64], n_x: usize, n_y: usize, tol: f64) -> Result<f64> {
526 if n_x == 0 || n_y == 0 {
527 return Err(Error::Domain(
528 "mutual_information: n_x and n_y must be >= 1",
529 ));
530 }
531 if p_xy.len() != n_x * n_y {
532 return Err(Error::LengthMismatch(p_xy.len(), n_x * n_y));
533 }
534 validate_simplex(p_xy, tol)?;
535
536 let mut p_x = vec![0.0; n_x];
537 let mut p_y = vec![0.0; n_y];
538 for i in 0..n_x {
539 for j in 0..n_y {
540 let p = p_xy[i * n_y + j];
541 p_x[i] += p;
542 p_y[j] += p;
543 }
544 }
545
546 let mut mi = 0.0;
547 for i in 0..n_x {
548 for j in 0..n_y {
549 let pxy = p_xy[i * n_y + j];
550 if pxy > 0.0 {
551 let px = p_x[i];
552 let py = p_y[j];
553 if px <= 0.0 || py <= 0.0 {
554 return Err(Error::Domain(
555 "mutual_information: p(x)=0 or p(y)=0 while p(x,y)>0",
556 ));
557 }
558 mi += pxy * (pxy / (px * py)).ln();
559 }
560 }
561 }
562 Ok(mi)
563}
564
/// `ndarray` adapter for discrete mutual information.
///
/// Flattens the 2-D joint table to a row-major slice and delegates to
/// [`mutual_information`]. Requires `logp` feature `ndarray`.
#[cfg(feature = "ndarray")]
pub fn mutual_information_ndarray(p_xy: &ndarray::Array2<f64>, tol: f64) -> Result<f64> {
    let (rows, cols) = p_xy.dim();
    let buf: Vec<f64> = p_xy.iter().copied().collect();
    mutual_information(&buf, rows, cols, tol)
}
574
/// Pointwise mutual information: the log-ratio measuring how much more (or less)
/// likely two specific outcomes co-occur than if they were independent.
///
/// \[PMI(x; y) = \ln \frac{p(x, y)}{p(x)\,p(y)}\]
///
/// # Key properties
///
/// - **Sign**: positive when \(x\) and \(y\) co-occur more than chance; negative when
///   less; zero when independent.
/// - **Unbounded**: \(PMI \in (-\infty, -\ln p(x,y)]\), which is why PPMI
///   (positive PMI, clamped at 0) is common in practice.
/// - **Connection to mutual information**: \(I(X; Y) = \mathbb{E}_{p(x,y)}[PMI(x; y)]\).
/// - **Connection to word2vec**: Levy & Goldberg (2014) showed skip-gram with
///   negative sampling implicitly factorizes a shifted PMI matrix.
///
/// # Examples
///
/// ```
/// # use logp::pmi;
/// // Independent events: p(x,y) = p(x)*p(y), so PMI = 0.
/// let val = pmi(0.06, 0.3, 0.2);
/// assert!(val.abs() < 1e-10);
///
/// // Positive correlation: p(x,y) > p(x)*p(y).
/// assert!(pmi(0.4, 0.5, 0.5) > 0.0);
///
/// // Negative correlation: p(x,y) < p(x)*p(y).
/// assert!(pmi(0.1, 0.5, 0.5) < 0.0);
///
/// // Zero joint probability returns 0 by convention.
/// assert_eq!(pmi(0.0, 0.5, 0.5), 0.0);
/// ```
pub fn pmi(pxy: f64, px: f64, py: f64) -> f64 {
    // Convention: any zero (or negative) probability makes the ratio
    // ill-posed; return 0 rather than ±∞.
    if pxy <= 0.0 || px <= 0.0 || py <= 0.0 {
        return 0.0;
    }
    (pxy / (px * py)).ln()
}
615
/// Log-sum-exp: numerically stable computation of `ln(exp(a_1) + ... + exp(a_n))`.
///
/// The fundamental primitive for working in log-probability space. The naive
/// `values.iter().map(|v| v.exp()).sum::<f64>().ln()` overflows for large
/// values and underflows for small ones; shifting by the maximum avoids both.
///
/// Returns `NEG_INFINITY` for an empty slice.
///
/// # Examples
///
/// ```
/// # use logp::log_sum_exp;
/// // ln(e^0 + e^0) = ln(2)
/// let lse = log_sum_exp(&[0.0, 0.0]);
/// assert!((lse - 2.0_f64.ln()).abs() < 1e-12);
///
/// // Dominated term: ln(e^1000 + e^0) ≈ 1000
/// let lse = log_sum_exp(&[1000.0, 0.0]);
/// assert!((lse - 1000.0).abs() < 1e-10);
///
/// // Single element: identity.
/// assert_eq!(log_sum_exp(&[42.0]), 42.0);
///
/// // Empty: -inf.
/// assert_eq!(log_sum_exp(&[]), f64::NEG_INFINITY);
/// ```
#[inline]
pub fn log_sum_exp(values: &[f64]) -> f64 {
    if values.is_empty() {
        return f64::NEG_INFINITY;
    }
    let peak = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // An infinite peak (all -inf, or any +inf) is its own answer; shifting
    // by it would produce NaN.
    if peak.is_infinite() {
        return peak;
    }
    let shifted_total: f64 = values.iter().map(|&v| (v - peak).exp()).sum();
    peak + shifted_total.ln()
}
655
/// Log-sum-exp for two values (common special case).
///
/// Equivalent to `log_sum_exp(&[a, b])` but avoids the slice overhead.
///
/// # Examples
///
/// ```
/// # use logp::log_sum_exp2;
/// let lse = log_sum_exp2(0.0, 0.0);
/// assert!((lse - 2.0_f64.ln()).abs() < 1e-12);
/// ```
#[inline]
pub fn log_sum_exp2(a: f64, b: f64) -> f64 {
    let peak = a.max(b);
    // An infinite peak is its own answer; shifting by it would yield NaN.
    if peak.is_infinite() {
        peak
    } else {
        peak + ((a - peak).exp() + (b - peak).exp()).ln()
    }
}
675
/// Digamma function: the logarithmic derivative of the Gamma function.
///
/// \[\psi(x) = \frac{d}{dx} \ln \Gamma(x) = \frac{\Gamma'(x)}{\Gamma(x)}\]
///
/// # Key properties
///
/// - **Recurrence**: \(\psi(x+1) = \psi(x) + \frac{1}{x}\), from \(\Gamma(x+1) = x\,\Gamma(x)\).
/// - **Special value**: \(\psi(1) = -\gamma \approx -0.5772\) (Euler--Mascheroni).
/// - **Asymptotic**: \(\psi(x) \sim \ln x - \frac{1}{2x}\) for large \(x\).
/// - **Why it appears here**: the KSG estimator ([`mutual_information_ksg`])
///   uses digamma to debias nearest-neighbor density estimates.
///
/// # Domain
///
/// Defined for \(x > 0\). Returns `NaN` for \(x \le 0\).
///
/// # Implementation
///
/// Shifts small arguments up to \(x \ge 7\) via the recurrence, then applies
/// the asymptotic expansion with Bernoulli-number correction terms.
///
/// # Examples
///
/// ```
/// # use logp::digamma;
/// // psi(1) = -gamma (Euler-Mascheroni constant).
/// let psi1 = digamma(1.0);
/// assert!((psi1 - (-0.5772156649)).abs() < 1e-8);
///
/// // Recurrence: psi(x+1) = psi(x) + 1/x.
/// let x = 3.5;
/// assert!((digamma(x + 1.0) - digamma(x) - 1.0 / x).abs() < 1e-10);
///
/// // Non-positive input returns NaN.
/// assert!(digamma(0.0).is_nan());
/// assert!(digamma(-1.0).is_nan());
/// ```
pub fn digamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NAN;
    }
    // Shift the argument into the asymptotic region using
    // psi(t) = psi(t + 1) - 1/t, accumulating the correction terms.
    let mut acc = 0.0;
    let mut t = x;
    while t < 7.0 {
        acc -= 1.0 / t;
        t += 1.0;
    }
    // Asymptotic expansion: ln t - 1/(2t) - Bernoulli series in 1/t² (Horner form).
    let r = 1.0 / t;
    acc += t.ln() - 0.5 * r;
    let r2 = r * r;
    acc -= r2 * (1.0 / 12.0 - r2 * (1.0 / 120.0 - r2 / 252.0));
    acc
}
731
732/// Bhattacharyya coefficient: the geometric-mean overlap between two distributions.
733///
734/// \[BC(p, q) = \sum_i \sqrt{p_i \, q_i}\]
735///
736/// # Key properties
737///
738/// - **Geometric mean interpretation**: each term \(\sqrt{p_i q_i}\) is the geometric
739/// mean of the two probabilities at bin \(i\). BC sums these, measuring how much
740/// the distributions overlap.
741/// - **Bounded**: \(BC \in [0, 1]\). Equals 1 iff \(p = q\); equals 0 iff supports
742/// are disjoint.
743/// - **Relationship to Hellinger**: \(H^2(p, q) = 1 - BC(p, q)\), so the squared
744/// Hellinger distance is just one minus the Bhattacharyya coefficient.
745/// - **Relationship to Renyi**: at \(\alpha = \tfrac{1}{2}\), the Renyi divergence
746/// gives \(D_{1/2}^R(p \| q) = -2 \ln BC(p, q)\).
747/// - **Connection to alpha family**: \(BC = \rho_{1/2}(p, q)\), a special case of
748/// [`rho_alpha`].
749///
750/// # Examples
751///
752/// ```
753/// # use logp::bhattacharyya_coeff;
754/// // BC(p, p) = 1.
755/// let p = [0.3, 0.7];
756/// assert!((bhattacharyya_coeff(&p, &p, 1e-9).unwrap() - 1.0).abs() < 1e-12);
757///
758/// // Disjoint supports: BC = 0.
759/// let a = [1.0, 0.0];
760/// let b = [0.0, 1.0];
761/// assert!(bhattacharyya_coeff(&a, &b, 1e-9).unwrap().abs() < 1e-15);
762///
763/// // BC is in [0, 1].
764/// let q = [0.5, 0.5];
765/// let bc = bhattacharyya_coeff(&p, &q, 1e-9).unwrap();
766/// assert!(bc >= 0.0 && bc <= 1.0);
767/// ```
768pub fn bhattacharyya_coeff(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
769 ensure_same_len(p, q)?;
770 validate_simplex(p, tol)?;
771 validate_simplex(q, tol)?;
772 let bc: f64 = p
773 .iter()
774 .zip(q.iter())
775 .map(|(&pi, &qi)| (pi * qi).sqrt())
776 .sum();
777 Ok(bc)
778}
779
780/// Bhattacharyya distance \(D_B(p,q) = -\ln BC(p,q)\).
781///
782/// # Examples
783///
784/// ```
785/// # use logp::bhattacharyya_distance;
786/// // D_B(p, p) = 0.
787/// let p = [0.4, 0.6];
788/// assert!(bhattacharyya_distance(&p, &p, 1e-9).unwrap().abs() < 1e-12);
789///
790/// // Non-negative for distinct distributions.
791/// let q = [0.5, 0.5];
792/// assert!(bhattacharyya_distance(&p, &q, 1e-9).unwrap() >= 0.0);
793/// ```
794pub fn bhattacharyya_distance(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
795 let bc = bhattacharyya_coeff(p, q, tol)?;
796 // When supports are disjoint, bc can be 0 (=> +∞ distance). Keep it explicit.
797 if bc == 0.0 {
798 return Err(Error::Domain("Bhattacharyya distance is infinite (BC=0)"));
799 }
800 Ok(-bc.ln())
801}
802
803/// Squared Hellinger distance: one minus the Bhattacharyya coefficient.
804///
805/// \[H^2(p, q) = 1 - \sum_i \sqrt{p_i \, q_i} = 1 - BC(p, q)\]
806///
807/// Bounded in \([0, 1]\). Equals the Amari \(\alpha\)-divergence at \(\alpha = 0\)
808/// (up to a factor of 2).
809///
810/// # Examples
811///
812/// ```
813/// # use logp::hellinger_squared;
814/// // H^2(p, p) = 0.
815/// let p = [0.25, 0.75];
816/// assert!(hellinger_squared(&p, &p, 1e-9).unwrap().abs() < 1e-15);
817///
818/// // Disjoint supports: H^2 = 1.
819/// let a = [1.0, 0.0];
820/// let b = [0.0, 1.0];
821/// assert!((hellinger_squared(&a, &b, 1e-9).unwrap() - 1.0).abs() < 1e-12);
822/// ```
823pub fn hellinger_squared(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
824 let bc = bhattacharyya_coeff(p, q, tol)?;
825 Ok((1.0 - bc).max(0.0))
826}
827
828/// Hellinger distance: the square root of the squared Hellinger distance.
829///
830/// \[H(p, q) = \sqrt{1 - BC(p, q)}\]
831///
832/// Unlike KL divergence, Hellinger is a **proper metric**: it is symmetric, satisfies
833/// the triangle inequality, and is bounded in \([0, 1]\).
834///
835/// # Examples
836///
837/// ```
838/// # use logp::hellinger;
839/// // H(p, p) = 0.
840/// let p = [0.3, 0.7];
841/// assert!(hellinger(&p, &p, 1e-9).unwrap().abs() < 1e-15);
842///
843/// // Symmetric: H(p, q) = H(q, p).
844/// let q = [0.5, 0.5];
845/// let h_pq = hellinger(&p, &q, 1e-9).unwrap();
846/// let h_qp = hellinger(&q, &p, 1e-9).unwrap();
847/// assert!((h_pq - h_qp).abs() < 1e-15);
848///
849/// // Bounded in [0, 1].
850/// assert!(h_pq >= 0.0 && h_pq <= 1.0);
851/// ```
852pub fn hellinger(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
853 Ok(hellinger_squared(p, q, tol)?.sqrt())
854}
855
856fn pow_nonneg(x: f64, a: f64) -> Result<f64> {
857 if x < 0.0 || !x.is_finite() || !a.is_finite() {
858 return Err(Error::Domain("pow_nonneg: invalid input"));
859 }
860 if x == 0.0 {
861 if a == 0.0 {
862 // By continuity in the divergence formulas, treat 0^0 as 1.
863 return Ok(1.0);
864 }
865 if a > 0.0 {
866 return Ok(0.0);
867 }
868 return Err(Error::Domain("0^a for a<0 is infinite"));
869 }
870 Ok(x.powf(a))
871}
872
873/// Alpha-integral: the workhorse behind the entire alpha-family of divergences.
874///
875/// \[\rho_\alpha(p, q) = \sum_i p_i^\alpha \, q_i^{1-\alpha}\]
876///
877/// # Why this matters
878///
879/// This single quantity generates multiple divergence families via simple transforms:
880///
881/// - **Renyi**: \(D_\alpha^R = \frac{1}{\alpha - 1} \ln \rho_\alpha\)
882/// - **Tsallis**: \(D_\alpha^T = \frac{\rho_\alpha - 1}{\alpha - 1}\)
883/// - **Bhattacharyya coefficient**: \(BC = \rho_{1/2}\)
884/// - **Chernoff information**: \(\min_\alpha (-\ln \rho_\alpha)\)
885///
886/// # Key properties
887///
888/// - \(\rho_\alpha(p, p) = 1\) for all \(\alpha\) (since \(\sum p_i = 1\)).
889/// - By Holder's inequality, \(\rho_\alpha(p, q) \le 1\) for \(\alpha \in [0, 1]\).
890/// - Continuous and log-convex in \(\alpha\).
891///
892/// # Examples
893///
894/// ```
895/// # use logp::rho_alpha;
896/// // rho_alpha(p, p, alpha) = 1 for any alpha (since sum(p) = 1).
897/// let p = [0.2, 0.3, 0.5];
898/// assert!((rho_alpha(&p, &p, 0.5, 1e-9).unwrap() - 1.0).abs() < 1e-12);
899/// assert!((rho_alpha(&p, &p, 2.0, 1e-9).unwrap() - 1.0).abs() < 1e-12);
900///
901/// // At alpha = 0.5, rho equals the Bhattacharyya coefficient.
902/// let q = [0.5, 0.25, 0.25];
903/// let rho = rho_alpha(&p, &q, 0.5, 1e-9).unwrap();
904/// let bc = logp::bhattacharyya_coeff(&p, &q, 1e-9).unwrap();
905/// assert!((rho - bc).abs() < 1e-12);
906/// ```
907pub fn rho_alpha(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
908 ensure_same_len(p, q)?;
909 validate_simplex(p, tol)?;
910 validate_simplex(q, tol)?;
911 if !alpha.is_finite() {
912 return Err(Error::InvalidAlpha {
913 alpha,
914 forbidden: f64::NAN,
915 });
916 }
917 let mut s = 0.0;
918 for (&pi, &qi) in p.iter().zip(q.iter()) {
919 let a = pow_nonneg(pi, alpha)?;
920 let b = pow_nonneg(qi, 1.0 - alpha)?;
921 s += a * b;
922 }
923 Ok(s)
924}
925
926/// Renyi divergence in nats: a one-parameter family that interpolates between
927/// different notions of distributional difference.
928///
929/// \[D_\alpha^R(p \| q) = \frac{1}{\alpha - 1} \ln \rho_\alpha(p, q), \quad \alpha > 0,\; \alpha \ne 1\]
930///
931/// # Key properties
932///
933/// - **Limit to KL**: \(\lim_{\alpha \to 1} D_\alpha^R(p \| q) = D_{KL}(p \| q)\)
934/// by L'Hopital's rule (the logarithm and denominator both vanish).
935/// - **Alpha = 1/2**: \(D_{1/2}^R = -2 \ln BC(p, q)\), twice the negative log
936/// Bhattacharyya coefficient.
937/// - **Alpha = infinity**: \(D_\infty^R = \ln \max_i (p_i / q_i)\), the log of the
938/// maximum likelihood ratio. This bounds all other Renyi orders.
939/// - **Monotone in alpha**: \(D_\alpha^R\) is non-decreasing in \(\alpha\).
940/// - **Non-negative**: \(D_\alpha^R(p \| q) \ge 0\), with equality iff \(p = q\).
941///
942/// # Domain
943///
944/// \(\alpha > 0\), \(\alpha \ne 1\). Both `p` and `q` must be simplex distributions.
945///
946/// # Examples
947///
948/// ```
949/// # use logp::renyi_divergence;
950/// // D_alpha(p || p) = 0 for any valid alpha.
951/// let p = [0.3, 0.7];
952/// assert!(renyi_divergence(&p, &p, 2.0, 1e-9).unwrap().abs() < 1e-12);
953///
954/// // Non-negative.
955/// let q = [0.5, 0.5];
956/// assert!(renyi_divergence(&p, &q, 0.5, 1e-9).unwrap() >= -1e-12);
957///
958/// // alpha = 1.0 is forbidden (use KL instead).
959/// assert!(renyi_divergence(&p, &q, 1.0, 1e-9).is_err());
960/// ```
961pub fn renyi_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
962 if alpha == 1.0 {
963 return Err(Error::InvalidAlpha {
964 alpha,
965 forbidden: 1.0,
966 });
967 }
968 let rho = rho_alpha(p, q, alpha, tol)?;
969 if rho <= 0.0 {
970 return Err(Error::Domain("rho_alpha <= 0"));
971 }
972 Ok(rho.ln() / (alpha - 1.0))
973}
974
975/// Tsallis divergence: a non-extensive generalization of KL divergence from
976/// statistical mechanics.
977///
978/// \[D_\alpha^T(p \| q) = \frac{\rho_\alpha(p, q) - 1}{\alpha - 1}, \quad \alpha \ne 1\]
979///
980/// # Key properties
981///
982/// - **Limit to KL**: \(\lim_{\alpha \to 1} D_\alpha^T(p \| q) = D_{KL}(p \| q)\),
983/// same limit as Renyi but via a different path.
984/// - **Connection to Renyi via deformed logarithm**: Tsallis uses the q-logarithm
985/// \(\ln_q(x) = \frac{x^{1-q} - 1}{1-q}\) where Renyi uses the ordinary log.
986/// Formally: \(D_\alpha^T = \frac{e^{(\alpha-1) D_\alpha^R} - 1}{\alpha - 1}\).
987/// - **Non-extensive**: for independent systems, Tsallis divergence is **not** additive
988/// (unlike KL and Renyi). This property is intentional and models systems with
989/// long-range correlations in statistical physics.
990/// - **Non-negative**: \(D_\alpha^T(p \| q) \ge 0\), with equality iff \(p = q\).
991///
992/// # Domain
993///
994/// \(\alpha \ne 1\). Both `p` and `q` must be simplex distributions.
995///
996/// # Examples
997///
998/// ```
999/// # use logp::tsallis_divergence;
1000/// // D_alpha^T(p || p) = 0 for any valid alpha.
1001/// let p = [0.4, 0.6];
1002/// assert!(tsallis_divergence(&p, &p, 2.0, 1e-9).unwrap().abs() < 1e-12);
1003///
1004/// // Non-negative.
1005/// let q = [0.5, 0.5];
1006/// assert!(tsallis_divergence(&p, &q, 0.5, 1e-9).unwrap() >= -1e-12);
1007///
1008/// // alpha = 1.0 is forbidden.
1009/// assert!(tsallis_divergence(&p, &q, 1.0, 1e-9).is_err());
1010/// ```
1011pub fn tsallis_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1012 if alpha == 1.0 {
1013 return Err(Error::InvalidAlpha {
1014 alpha,
1015 forbidden: 1.0,
1016 });
1017 }
1018 Ok((rho_alpha(p, q, alpha, tol)? - 1.0) / (alpha - 1.0))
1019}
1020
1021/// Amari alpha-divergence: a one-parameter family from information geometry that
1022/// continuously interpolates between forward KL, reverse KL, and squared Hellinger.
1023///
1024/// For \(\alpha \notin \{-1, 1\}\):
1025///
1026/// \[D^\alpha(p : q) = \frac{4}{1 - \alpha^2}\left(1 - \rho_{\frac{1-\alpha}{2}}(p, q)\right)\]
1027///
1028/// # Key properties
1029///
1030/// - **\(\alpha = -1\)**: recovers \(D_{KL}(p \| q)\), the forward KL divergence.
1031/// - **\(\alpha = +1\)**: recovers \(D_{KL}(q \| p)\), the reverse KL divergence.
1032/// - **\(\alpha = 0\)**: gives \(4(1 - BC(p,q)) = 4\,H^2(p,q)\), proportional to
1033/// the squared Hellinger distance.
1034/// - **Duality**: \(D^\alpha(p : q) = D^{-\alpha}(q : p)\). Swapping the sign of
1035/// \(\alpha\) is the same as swapping the arguments.
1036/// - **Non-negative**: \(D^\alpha(p : q) \ge 0\), with equality iff \(p = q\).
1037/// - **Information geometry**: the Amari family parameterizes the \(\alpha\)-connections
1038/// on the statistical manifold (Amari & Nagaoka, 2000).
1039///
1040/// # Examples
1041///
1042/// ```
1043/// # use logp::{amari_alpha_divergence, kl_divergence, hellinger_squared};
1044/// let p = [0.3, 0.7];
1045/// let q = [0.5, 0.5];
1046/// let tol = 1e-9;
1047///
1048/// // alpha = -1 gives forward KL(p || q).
1049/// let amari_neg1 = amari_alpha_divergence(&p, &q, -1.0, tol).unwrap();
1050/// let kl_fwd = kl_divergence(&p, &q, tol).unwrap();
1051/// assert!((amari_neg1 - kl_fwd).abs() < 1e-6);
1052///
1053/// // alpha = 0 gives 4 * H^2(p, q).
1054/// let amari_0 = amari_alpha_divergence(&p, &q, 0.0, tol).unwrap();
1055/// let h2 = hellinger_squared(&p, &q, tol).unwrap();
1056/// assert!((amari_0 - 4.0 * h2).abs() < 1e-10);
1057/// ```
1058pub fn amari_alpha_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1059 if !alpha.is_finite() {
1060 return Err(Error::InvalidAlpha {
1061 alpha,
1062 forbidden: f64::NAN,
1063 });
1064 }
1065 // Numerically stable handling near ±1.
1066 let eps = 1e-10;
1067 if (alpha + 1.0).abs() <= eps {
1068 return kl_divergence(p, q, tol);
1069 }
1070 if (alpha - 1.0).abs() <= eps {
1071 return kl_divergence(q, p, tol);
1072 }
1073 let t = (1.0 - alpha) / 2.0;
1074 let rho = rho_alpha(p, q, t, tol)?;
1075 Ok((4.0 / (1.0 - alpha * alpha)) * (1.0 - rho))
1076}
1077
1078/// Csiszar f-divergence: the most general class of divergences that respect
1079/// sufficient statistics (information monotonicity).
1080///
1081/// \[D_f(p \| q) = \sum_i q_i \, f\!\left(\frac{p_i}{q_i}\right)\]
1082///
1083/// where \(f\) is a convex function with \(f(1) = 0\).
1084///
1085/// # Information monotonicity theorem
1086///
1087/// The defining property of f-divergences (Csiszar, 1967): for any Markov kernel
1088/// (stochastic map) \(T\),
1089///
1090/// \[D_f(Tp \| Tq) \le D_f(p \| q)\]
1091///
1092/// Coarse-graining (merging bins) cannot increase the divergence. This is the
1093/// information-theoretic analogue of the data processing inequality.
1094///
1095/// # Common f-generators
1096///
1097/// | Divergence | \(f(t)\) |
1098/// |---|---|
1099/// | KL divergence | \(t \ln t\) |
1100/// | Reverse KL | \(-\ln t\) |
1101/// | Squared Hellinger | \((\sqrt{t} - 1)^2\) |
1102/// | Total variation | \(\tfrac{1}{2} |t - 1|\) |
1103/// | Chi-squared | \((t - 1)^2\) |
1104/// | Jensen-Shannon | \(t \ln t - (1+t) \ln \tfrac{1+t}{2}\) |
1105///
1106/// # Edge cases
1107///
1108/// When `q_i = 0`:
1109/// - if `p_i = 0`, the contribution is treated as 0 (by continuity).
1110/// - if `p_i > 0`, the divergence is infinite; we return an error.
1111///
1112/// # Examples
1113///
1114/// ```
1115/// # use logp::{csiszar_f_divergence, kl_divergence};
1116/// let p = [0.3, 0.7];
1117/// let q = [0.5, 0.5];
1118///
1119/// // f(t) = t*ln(t) recovers KL divergence.
1120/// let cs = csiszar_f_divergence(&p, &q, |t| t * t.ln(), 1e-9).unwrap();
1121/// let kl = kl_divergence(&p, &q, 1e-9).unwrap();
1122/// assert!((cs - kl).abs() < 1e-10);
1123///
1124/// // f(t) = (t - 1)^2 gives chi-squared divergence.
1125/// let chi2 = csiszar_f_divergence(&p, &q, |t| (t - 1.0).powi(2), 1e-9).unwrap();
1126/// assert!(chi2 >= 0.0);
1127/// ```
1128pub fn csiszar_f_divergence(p: &[f64], q: &[f64], f: impl Fn(f64) -> f64, tol: f64) -> Result<f64> {
1129 ensure_same_len(p, q)?;
1130 validate_simplex(p, tol)?;
1131 validate_simplex(q, tol)?;
1132
1133 let mut d = 0.0;
1134 for (&pi, &qi) in p.iter().zip(q.iter()) {
1135 if qi == 0.0 {
1136 if pi == 0.0 {
1137 continue;
1138 }
1139 return Err(Error::Domain("f-divergence undefined: q_i=0 while p_i>0"));
1140 }
1141 d += qi * f(pi / qi);
1142 }
1143 Ok(d)
1144}
1145
/// Bregman generator: a convex function \(F\) and its gradient.
///
/// Implementors supply the two ingredients that [`bregman_divergence`] combines
/// into \(B_F(p, q) = F(p) - F(q) - \langle p - q, \nabla F(q)\rangle\).
/// See [`SquaredL2`] for the canonical squared-Euclidean instance.
pub trait BregmanGenerator {
    /// Evaluate the potential \(F(x)\).
    ///
    /// Returns an error for inputs outside the generator's domain
    /// (e.g. `SquaredL2` rejects an empty slice).
    fn f(&self, x: &[f64]) -> Result<f64>;

    /// Write \(\nabla F(x)\) into `out`.
    ///
    /// `out` must have the same length as `x`; implementations are expected to
    /// error on a mismatch (as `SquaredL2` does) rather than truncate.
    fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()>;
}
1154
1155/// Bregman divergence: the gap between a convex function and its tangent approximation.
1156///
1157/// \[B_F(p, q) = F(p) - F(q) - \langle p - q,\, \nabla F(q) \rangle\]
1158///
1159/// # Key properties
1160///
1161/// - **Non-negative**: \(B_F(p, q) \ge 0\) by convexity of \(F\), with equality iff
1162/// \(p = q\).
1163/// - **Not symmetric** in general: \(B_F(p, q) \ne B_F(q, p)\).
1164/// - **Generalized Pythagorean theorem**: for an affine subspace \(S\) and its
1165/// Bregman projection \(q^* = \arg\min_{q \in S} B_F(p, q)\), the three-point
1166/// identity holds: \(B_F(p, q) = B_F(p, q^*) + B_F(q^*, q)\) for all \(q \in S\).
1167/// This is the foundation of dually flat geometry (Amari).
1168/// - **Not an f-divergence**: Bregman divergences are **not** information monotone
1169/// in general. They live in a different branch of the divergence taxonomy.
1170/// - **Examples**: squared Euclidean (\(F = \tfrac{1}{2}\|x\|^2\)) gives
1171/// \(B_F(p,q) = \tfrac{1}{2}\|p - q\|^2\); negative entropy
1172/// (\(F = \sum x_i \ln x_i\)) gives the KL divergence.
1173///
1174/// # Examples
1175///
1176/// ```
1177/// # use logp::{bregman_divergence, SquaredL2};
1178/// // Squared-L2 generator: B_F(p, q) = 0.5 * ||p - q||^2.
1179/// let gen = SquaredL2;
1180/// let p = [1.0, 2.0, 3.0];
1181/// let q = [1.5, 1.5, 2.5];
1182/// let mut grad = [0.0; 3];
1183/// let b = bregman_divergence(&gen, &p, &q, &mut grad).unwrap();
1184/// let expected = 0.5 * ((0.5_f64).powi(2) + (0.5_f64).powi(2) + (0.5_f64).powi(2));
1185/// assert!((b - expected).abs() < 1e-12);
1186///
1187/// // B_F(p, p) = 0.
1188/// let mut grad2 = [0.0; 3];
1189/// assert!(bregman_divergence(&gen, &p, &p, &mut grad2).unwrap().abs() < 1e-15);
1190/// ```
1191pub fn bregman_divergence(
1192 gen: &impl BregmanGenerator,
1193 p: &[f64],
1194 q: &[f64],
1195 grad_q: &mut [f64],
1196) -> Result<f64> {
1197 ensure_nonempty(p)?;
1198 ensure_same_len(p, q)?;
1199 if grad_q.len() != q.len() {
1200 return Err(Error::LengthMismatch(grad_q.len(), q.len()));
1201 }
1202 gen.grad_into(q, grad_q)?;
1203 let fp = gen.f(p)?;
1204 let fq = gen.f(q)?;
1205 let mut inner = 0.0;
1206 for i in 0..p.len() {
1207 inner += (p[i] - q[i]) * grad_q[i];
1208 }
1209 Ok(fp - fq - inner)
1210}
1211
1212/// Total Bregman divergence as shown in Nielsen’s taxonomy diagram:
1213///
1214/// \(tB_F(p,q) = \frac{B_F(p,q)}{\sqrt{1 + \|\nabla F(q)\|^2}}\).
1215///
1216/// # Examples
1217///
1218/// ```
1219/// # use logp::{total_bregman_divergence, bregman_divergence, SquaredL2};
1220/// let gen = SquaredL2;
1221/// let p = [1.0, 2.0];
1222/// let q = [3.0, 4.0];
1223/// let mut grad = [0.0; 2];
1224///
1225/// let tb = total_bregman_divergence(&gen, &p, &q, &mut grad).unwrap();
1226///
1227/// // Total Bregman <= Bregman (normalization divides by >= 1).
1228/// let mut grad2 = [0.0; 2];
1229/// let b = bregman_divergence(&gen, &p, &q, &mut grad2).unwrap();
1230/// assert!(tb <= b + 1e-12);
1231/// ```
1232pub fn total_bregman_divergence(
1233 gen: &impl BregmanGenerator,
1234 p: &[f64],
1235 q: &[f64],
1236 grad_q: &mut [f64],
1237) -> Result<f64> {
1238 let b = bregman_divergence(gen, p, q, grad_q)?;
1239 let grad_norm_sq: f64 = grad_q.iter().map(|&x| x * x).sum();
1240 Ok(b / (1.0 + grad_norm_sq).sqrt())
1241}
1242
1243/// Squared Euclidean Bregman generator: \(F(x)=\tfrac12\|x\|_2^2\), \(\nabla F(x)=x\).
1244#[derive(Debug, Clone, Copy, Default)]
1245pub struct SquaredL2;
1246
1247impl BregmanGenerator for SquaredL2 {
1248 fn f(&self, x: &[f64]) -> Result<f64> {
1249 ensure_nonempty(x)?;
1250 Ok(0.5 * x.iter().map(|&v| v * v).sum::<f64>())
1251 }
1252
1253 fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()> {
1254 ensure_nonempty(x)?;
1255 if out.len() != x.len() {
1256 return Err(Error::LengthMismatch(out.len(), x.len()));
1257 }
1258 out.copy_from_slice(x);
1259 Ok(())
1260 }
1261}
1262
#[cfg(test)]
mod tests {
    // Unit tests pin exact constants (ln 2, entropy of uniform, etc.);
    // the proptest blocks below check theorem-level invariants
    // (non-negativity, boundedness, Pinsker, information monotonicity).
    use super::*;
    use proptest::prelude::*;

    // Shared absolute tolerance for simplex validation in tests.
    const TOL: f64 = 1e-9;

    // Strategy: arbitrary point on the probability simplex (zeros allowed).
    fn simplex_vec(len: usize) -> impl Strategy<Value = Vec<f64>> {
        // Draw nonnegative weights then normalize.
        prop::collection::vec(0.0f64..10.0, len).prop_map(|mut v| {
            let s: f64 = v.iter().sum();
            if s == 0.0 {
                v[0] = 1.0;
                return v;
            }
            for x in v.iter_mut() {
                *x /= s;
            }
            v
        })
    }

    // Strategy: strictly positive simplex point (floor `eps` before normalizing).
    fn simplex_vec_pos(len: usize, eps: f64) -> impl Strategy<Value = Vec<f64>> {
        prop::collection::vec(0.0f64..10.0, len).prop_map(move |mut v| {
            // Add a small floor to avoid exact zeros (needed for KL-style domains).
            for x in v.iter_mut() {
                *x += eps;
            }
            let s: f64 = v.iter().sum();
            for x in v.iter_mut() {
                *x /= s;
            }
            v
        })
    }

    // Strategy: random surjective labeling of n indices onto compact labels
    // 0..k-1, used to model coarse-graining (bin merging) as a Markov kernel.
    fn random_partition(n: usize) -> impl Strategy<Value = Vec<usize>> {
        // Partition indices into k buckets (k chosen implicitly).
        // We generate a label in [0, n) for each index and later reindex to compact labels.
        prop::collection::vec(0usize..n, n).prop_map(|labels| {
            // Compress labels to 0..k-1 while preserving equality pattern.
            use std::collections::BTreeMap;
            let mut map = BTreeMap::<usize, usize>::new();
            let mut next = 0usize;
            labels
                .into_iter()
                .map(|l| {
                    *map.entry(l).or_insert_with(|| {
                        let id = next;
                        next += 1;
                        id
                    })
                })
                .collect::<Vec<_>>()
        })
    }

    // Apply the coarse-graining: sum the mass of every bin sharing a label.
    fn coarse_grain(p: &[f64], labels: &[usize]) -> Vec<f64> {
        let k = labels.iter().copied().max().unwrap_or(0) + 1;
        let mut out = vec![0.0; k];
        for (i, &lab) in labels.iter().enumerate() {
            out[lab] += p[i];
        }
        out
    }

    // L1 distance between two vectors (used for the Pinsker bound).
    fn l1(p: &[f64], q: &[f64]) -> f64 {
        p.iter().zip(q.iter()).map(|(&a, &b)| (a - b).abs()).sum()
    }

    #[test]
    fn test_entropy_unchecked() {
        let p = [0.5, 0.5];
        let h = entropy_unchecked(&p);
        // -0.5*ln(0.5) - 0.5*ln(0.5) = -ln(0.5) = ln(2)
        assert!((h - LN_2).abs() < 1e-12);
    }

    #[test]
    fn js_is_bounded_by_ln2() {
        // Disjoint supports achieve the maximum JS value of ln 2.
        let p = [1.0, 0.0];
        let q = [0.0, 1.0];
        let js = jensen_shannon_divergence(&p, &q, TOL).unwrap();
        assert!(js <= LN_2 + 1e-12);
        assert!(js >= 0.0);
    }

    #[test]
    fn mutual_information_independent_is_zero() {
        // p(x,y) = p(x)p(y) ⇒ I(X;Y)=0
        let p_x = [0.5, 0.5];
        let p_y = [0.25, 0.75];
        // Row-major 2x2:
        // [0.125, 0.375,
        // 0.125, 0.375]
        let p_xy = [
            p_x[0] * p_y[0],
            p_x[0] * p_y[1],
            p_x[1] * p_y[0],
            p_x[1] * p_y[1],
        ];
        let mi = mutual_information(&p_xy, 2, 2, TOL).unwrap();
        assert!(mi.abs() < 1e-12, "mi={}", mi);
    }

    #[test]
    fn mutual_information_perfect_correlation_is_ln2() {
        // X=Y uniform bit ⇒ I(X;Y)=ln 2 (nats)
        let p_xy = [0.5, 0.0, 0.0, 0.5]; // 2x2 diagonal
        let mi = mutual_information(&p_xy, 2, 2, TOL).unwrap();
        assert!((mi - LN_2).abs() < 1e-12, "mi={}", mi);
    }

    #[test]
    fn bregman_squared_l2_matches_half_l2() {
        // SquaredL2 generator must reproduce half the squared Euclidean distance.
        let gen = SquaredL2;
        let p = [1.0, 2.0, 3.0];
        let q = [1.5, 1.5, 2.5];
        let mut grad = [0.0; 3];
        let b = bregman_divergence(&gen, &p, &q, &mut grad).unwrap();
        let expected = 0.5
            * p.iter()
                .zip(q.iter())
                .map(|(&a, &b)| (a - b) * (a - b))
                .sum::<f64>();
        assert!((b - expected).abs() < 1e-12);
    }

    // --- Entropy tests ---

    #[test]
    fn entropy_nats_uniform_is_ln_n() {
        // Uniform distribution over n items: H = ln(n)
        for n in [2, 4, 8, 16] {
            let p: Vec<f64> = vec![1.0 / n as f64; n];
            let h = entropy_nats(&p, TOL).unwrap();
            let expected = (n as f64).ln();
            assert!((h - expected).abs() < 1e-12, "n={n}: h={h} expected={expected}");
        }
    }

    #[test]
    fn entropy_nats_singleton_is_zero() {
        let h = entropy_nats(&[1.0], TOL).unwrap();
        assert!(h.abs() < 1e-15);
    }

    #[test]
    fn entropy_bits_converts_correctly() {
        // bits = nats / ln 2.
        let p = [0.25, 0.75];
        let nats = entropy_nats(&p, TOL).unwrap();
        let bits = entropy_bits(&p, TOL).unwrap();
        assert!((bits - nats / LN_2).abs() < 1e-12);
    }

    // --- Cross-entropy tests ---

    #[test]
    fn cross_entropy_identity_h_pq_eq_h_p_plus_kl() {
        // Identity: H(p, q) = H(p) + KL(p || q).
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let h_pq = cross_entropy_nats(&p, &q, TOL).unwrap();
        let h_p = entropy_nats(&p, TOL).unwrap();
        let kl = kl_divergence(&p, &q, TOL).unwrap();
        assert!((h_pq - (h_p + kl)).abs() < 1e-12);
    }

    #[test]
    fn cross_entropy_rejects_zero_q_with_positive_p() {
        let p = [0.5, 0.5];
        let q = [1.0, 0.0]; // q[1]=0 but p[1]=0.5
        assert!(cross_entropy_nats(&p, &q, TOL).is_err());
    }

    // --- Validate / normalize tests ---

    #[test]
    fn validate_simplex_accepts_valid() {
        assert!(validate_simplex(&[0.3, 0.7], TOL).is_ok());
        assert!(validate_simplex(&[1.0], TOL).is_ok());
    }

    #[test]
    fn validate_simplex_rejects_bad_sum() {
        assert!(validate_simplex(&[0.3, 0.6], TOL).is_err()); // sum=0.9
    }

    #[test]
    fn validate_simplex_rejects_negative() {
        assert!(validate_simplex(&[1.5, -0.5], TOL).is_err());
    }

    #[test]
    fn validate_simplex_rejects_empty() {
        assert!(validate_simplex(&[], TOL).is_err());
    }

    #[test]
    fn normalize_in_place_works() {
        let mut v = vec![2.0, 3.0];
        let s = normalize_in_place(&mut v).unwrap();
        assert!((s - 5.0).abs() < 1e-12);
        assert!((v[0] - 0.4).abs() < 1e-12);
        assert!((v[1] - 0.6).abs() < 1e-12);
    }

    #[test]
    fn normalize_in_place_rejects_zero_sum() {
        let mut v = vec![0.0, 0.0];
        assert!(normalize_in_place(&mut v).is_err());
    }

    // --- Hellinger / Bhattacharyya tests ---

    #[test]
    fn hellinger_identical_is_zero() {
        let p = [0.25, 0.75];
        let h = hellinger(&p, &p, TOL).unwrap();
        assert!(h.abs() < 1e-12);
    }

    #[test]
    fn hellinger_squared_in_unit_interval() {
        let p = [0.1, 0.9];
        let q = [0.9, 0.1];
        let h2 = hellinger_squared(&p, &q, TOL).unwrap();
        assert!(h2 >= -1e-12 && h2 <= 1.0 + 1e-12, "h2={h2}");
    }

    #[test]
    fn bhattacharyya_coeff_identical_is_one() {
        let p = [0.3, 0.7];
        let bc = bhattacharyya_coeff(&p, &p, TOL).unwrap();
        assert!((bc - 1.0).abs() < 1e-12);
    }

    #[test]
    fn bhattacharyya_distance_identical_is_zero() {
        let p = [0.5, 0.5];
        let d = bhattacharyya_distance(&p, &p, TOL).unwrap();
        assert!(d.abs() < 1e-12);
    }

    // --- Renyi / Tsallis tests ---

    #[test]
    fn renyi_alpha_half_on_simple_case() {
        let p = [0.5, 0.5];
        let q = [0.25, 0.75];
        // alpha=0.5 should be well-defined and non-negative
        let r = renyi_divergence(&p, &q, 0.5, TOL).unwrap();
        assert!(r >= -1e-12, "renyi={r}");
    }

    #[test]
    fn renyi_identical_is_zero() {
        let p = [0.3, 0.7];
        let r = renyi_divergence(&p, &p, 2.0, TOL).unwrap();
        assert!(r.abs() < 1e-12, "renyi(p,p)={r}");
    }

    #[test]
    fn tsallis_identical_is_zero() {
        let p = [0.4, 0.6];
        let t = tsallis_divergence(&p, &p, 2.0, TOL).unwrap();
        assert!(t.abs() < 1e-12, "tsallis(p,p)={t}");
    }

    // --- Digamma test ---

    #[test]
    fn digamma_at_one_is_neg_euler_mascheroni() {
        let psi1 = digamma(1.0);
        // digamma(1) = -gamma where gamma ~= 0.5772156649
        assert!((psi1 - (-0.5772156649)).abs() < 1e-8, "psi(1)={psi1}");
    }

    #[test]
    fn digamma_recurrence_relation() {
        // digamma(x+1) = digamma(x) + 1/x
        for &x in &[1.0, 2.0, 3.5, 10.0] {
            let lhs = digamma(x + 1.0);
            let rhs = digamma(x) + 1.0 / x;
            assert!((lhs - rhs).abs() < 1e-8, "recurrence at x={x}: {lhs} vs {rhs}");
        }
    }

    #[test]
    fn pmi_independent_is_zero() {
        // PMI(x,y) = log(p(x,y) / (p(x)*p(y))). If independent: p(x,y) = p(x)*p(y)
        let pmi_val = pmi(0.06, 0.3, 0.2); // 0.3 * 0.2 = 0.06
        assert!(pmi_val.abs() < 1e-10, "PMI of independent events should be 0: {pmi_val}");
    }

    #[test]
    fn pmi_positive_for_correlated() {
        // If p(x,y) > p(x)*p(y), events are positively correlated
        let pmi_val = pmi(0.4, 0.5, 0.5); // 0.4 > 0.5*0.5 = 0.25
        assert!(pmi_val > 0.0, "correlated events should have positive PMI: {pmi_val}");
    }

    #[test]
    fn renyi_approaches_kl_as_alpha_to_one() {
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let tol = 1e-9;
        let kl = kl_divergence(&p, &q, tol).unwrap();
        // Renyi(alpha) -> KL as alpha -> 1
        let r099 = renyi_divergence(&p, &q, 0.99, tol).unwrap();
        let r0999 = renyi_divergence(&p, &q, 0.999, tol).unwrap();
        assert!((r099 - kl).abs() < 0.01, "Renyi(0.99)={r099}, KL={kl}");
        assert!((r0999 - kl).abs() < 0.001, "Renyi(0.999)={r0999}, KL={kl}");
    }

    #[test]
    fn amari_alpha_neg1_is_kl_forward() {
        // Amari alpha=-1 returns KL(p||q) per the implementation
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let tol = 1e-9;
        let kl_pq = kl_divergence(&p, &q, tol).unwrap();
        let amari = amari_alpha_divergence(&p, &q, -1.0, tol).unwrap();
        assert!((amari - kl_pq).abs() < 1e-6, "Amari(-1)={amari}, KL(p||q)={kl_pq}");
    }

    #[test]
    fn amari_alpha_pos1_is_kl_reverse() {
        // Amari alpha=+1 returns KL(q||p) per the implementation
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let tol = 1e-9;
        let kl_qp = kl_divergence(&q, &p, tol).unwrap();
        let amari = amari_alpha_divergence(&p, &q, 1.0, tol).unwrap();
        assert!((amari - kl_qp).abs() < 1e-6, "Amari(1)={amari}, KL(q||p)={kl_qp}");
    }

    #[test]
    fn csiszar_with_kl_generator_matches_kl() {
        // f(t) = t*ln(t) gives KL divergence
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let tol = 1e-9;
        let kl = kl_divergence(&p, &q, tol).unwrap();
        let cs = csiszar_f_divergence(&p, &q, |t| t * t.ln(), tol).unwrap();
        assert!((cs - kl).abs() < 1e-6, "Csiszar(t*ln(t))={cs}, KL={kl}");
    }

    #[test]
    fn mutual_information_deterministic_equals_entropy() {
        // If Y = f(X), MI(X;Y) = H(X)
        // Joint: p(x=0,y=0)=0.3, p(x=1,y=1)=0.7
        let p_xy = [0.3, 0.0, 0.0, 0.7]; // 2x2 joint
        let mi = mutual_information(&p_xy, 2, 2, 1e-9).unwrap();
        let h_x = entropy_nats(&[0.3, 0.7], 1e-9).unwrap();
        assert!((mi - h_x).abs() < 1e-6, "MI={mi}, H(X)={h_x}");
    }

    // Property-based invariants at the default proptest case count.
    proptest! {
        #[test]
        fn kl_is_nonnegative(p in simplex_vec_pos(8, 1e-6), q in simplex_vec_pos(8, 1e-6)) {
            let d = kl_divergence(&p, &q, 1e-6).unwrap();
            prop_assert!(d >= -1e-12);
        }

        #[test]
        fn js_is_bounded(p in simplex_vec(16), q in simplex_vec(16)) {
            let js = jensen_shannon_divergence(&p, &q, 1e-6).unwrap();
            prop_assert!(js >= -1e-12);
            prop_assert!(js <= LN_2 + 1e-9);
        }

        #[test]
        fn prop_kl_gaussians_is_nonnegative(
            mu1 in prop::collection::vec(-10.0f64..10.0, 1..16),
            std1 in prop::collection::vec(0.1f64..5.0, 1..16),
            mu2 in prop::collection::vec(-10.0f64..10.0, 1..16),
            std2 in prop::collection::vec(0.1f64..5.0, 1..16),
        ) {
            let n = mu1.len().min(std1.len()).min(mu2.len()).min(std2.len());
            let d = kl_divergence_gaussians(&mu1[..n], &std1[..n], &mu2[..n], &std2[..n]).unwrap();
            // KL divergence is always non-negative.
            prop_assert!(d >= -1e-12);
        }

        #[test]
        fn prop_kl_gaussians_is_zero_for_identical(
            mu in prop::collection::vec(-10.0f64..10.0, 1..16),
            std in prop::collection::vec(0.1f64..5.0, 1..16),
        ) {
            let n = mu.len().min(std.len());
            let d = kl_divergence_gaussians(&mu[..n], &std[..n], &mu[..n], &std[..n]).unwrap();
            prop_assert!(d.abs() < 1e-12);
        }

        #[test]
        fn f_divergence_monotone_under_coarse_graining(
            p in simplex_vec_pos(12, 1e-6),
            q in simplex_vec_pos(12, 1e-6),
            labels in random_partition(12),
        ) {
            // Use KL as an f-divergence instance: f(t)=t ln t.
            // D_KL(p||q) = Σ q_i f(p_i/q_i).
            let f = |t: f64| if t == 0.0 { 0.0 } else { t * t.ln() };
            let d_f = csiszar_f_divergence(&p, &q, f, 1e-6).unwrap();

            let pc = coarse_grain(&p, &labels);
            let qc = coarse_grain(&q, &labels);
            let d_fc = csiszar_f_divergence(&pc, &qc, f, 1e-6).unwrap();

            // Coarse graining should not increase.
            prop_assert!(d_fc <= d_f + 1e-9);
        }
    }

    // Heavier “theorem-ish” checks: keep case count modest so `cargo test` stays fast.
    proptest! {
        #![proptest_config(ProptestConfig { cases: 64, .. ProptestConfig::default() })]

        #[test]
        fn pinsker_kl_lower_bounds_l1_squared(
            p in simplex_vec_pos(16, 1e-6),
            q in simplex_vec_pos(16, 1e-6),
        ) {
            // Pinsker: TV(p,q)^2 <= (1/2) KL(p||q)
            // where TV = (1/2)||p-q||_1. Rearranged: KL(p||q) >= 0.5 * ||p-q||_1^2.
            let kl = kl_divergence(&p, &q, 1e-6).unwrap();
            let d1 = l1(&p, &q);
            prop_assert!(kl + 1e-9 >= 0.5 * d1 * d1, "kl={kl} l1={d1}");
        }

        #[test]
        fn sqrt_js_satisfies_triangle_inequality(
            p in simplex_vec(12),
            q in simplex_vec(12),
            r in simplex_vec(12),
        ) {
            // Known fact: sqrt(JS) is a metric on the simplex.
            let js_pq = jensen_shannon_divergence(&p, &q, 1e-6).unwrap().max(0.0).sqrt();
            let js_qr = jensen_shannon_divergence(&q, &r, 1e-6).unwrap().max(0.0).sqrt();
            let js_pr = jensen_shannon_divergence(&p, &r, 1e-6).unwrap().max(0.0).sqrt();
            prop_assert!(js_pr <= js_pq + js_qr + 1e-7, "js_pr={js_pr} js_pq+js_qr={}", js_pq+js_qr);
        }

        #[test]
        fn mutual_information_equals_kl_to_product(
            // Ensure strictly positive so KL domains are satisfied.
            p_xy in simplex_vec_pos(16, 1e-6),
            nx in 2usize..=4,
            ny in 2usize..=4,
        ) {
            // We need p_xy to have length nx*ny; we will truncate/renormalize a fixed-length draw.
            let n = nx * ny;
            let mut joint = p_xy;
            joint.truncate(n);
            // Renormalize after truncation.
            let _ = normalize_in_place(&mut joint).unwrap();

            // Compute MI via the dedicated function.
            let mi = mutual_information(&joint, nx, ny, 1e-6).unwrap();

            // Compute product of marginals and KL(joint || product).
            let mut p_x = vec![0.0; nx];
            let mut p_y = vec![0.0; ny];
            for i in 0..nx {
                for j in 0..ny {
                    let p = joint[i * ny + j];
                    p_x[i] += p;
                    p_y[j] += p;
                }
            }
            let mut prod = vec![0.0; n];
            for i in 0..nx {
                for j in 0..ny {
                    prod[i * ny + j] = p_x[i] * p_y[j];
                }
            }
            let kl = kl_divergence(&joint, &prod, 1e-6).unwrap();

            prop_assert!((mi - kl).abs() < 1e-9, "mi={mi} kl={kl}");
        }

        #[test]
        fn hellinger_satisfies_triangle_inequality(
            p in simplex_vec(8),
            q in simplex_vec(8),
            r in simplex_vec(8),
        ) {
            let h_pq = hellinger(&p, &q, 1e-6).unwrap();
            let h_qr = hellinger(&q, &r, 1e-6).unwrap();
            let h_pr = hellinger(&p, &r, 1e-6).unwrap();
            prop_assert!(h_pr <= h_pq + h_qr + 1e-7, "h_pr={h_pr} h_pq+h_qr={}", h_pq + h_qr);
        }
    }

    // --- total_bregman_divergence ---

    #[test]
    fn total_bregman_le_bregman() {
        // tB_F(p, q) <= B_F(p, q) because the denominator sqrt(1 + ||grad||^2) >= 1.
        let gen = SquaredL2;
        let p = [1.0, 2.0, 3.0];
        let q = [4.0, 5.0, 6.0];
        let mut grad1 = [0.0; 3];
        let mut grad2 = [0.0; 3];
        let b = bregman_divergence(&gen, &p, &q, &mut grad1).unwrap();
        let tb = total_bregman_divergence(&gen, &p, &q, &mut grad2).unwrap();
        assert!(tb <= b + 1e-12, "total_bregman={tb} > bregman={b}");
        assert!(tb >= 0.0);
    }

    #[test]
    fn total_bregman_is_zero_for_identical() {
        let gen = SquaredL2;
        let p = [1.0, 2.0];
        let mut grad = [0.0; 2];
        let tb = total_bregman_divergence(&gen, &p, &p, &mut grad).unwrap();
        assert!(tb.abs() < 1e-15);
    }

    // --- rho_alpha ---

    #[test]
    fn rho_alpha_self_is_one() {
        // rho_alpha(p, p, alpha) = sum p_i = 1 for every alpha, including
        // values outside [0, 1] (the simplex sum collapses the exponents).
        let p = [0.1, 0.2, 0.3, 0.4];
        for alpha in [0.0, 0.25, 0.5, 0.75, 1.0, 2.0, -1.0] {
            let r = rho_alpha(&p, &p, alpha, TOL).unwrap();
            assert!((r - 1.0).abs() < 1e-10, "rho_alpha(p,p,{alpha})={r}");
        }
    }

    // --- digamma negative domain ---

    #[test]
    fn digamma_nonpositive_is_nan() {
        assert!(digamma(0.0).is_nan());
        assert!(digamma(-1.0).is_nan());
        assert!(digamma(-100.0).is_nan());
    }

    // --- pmi edge cases ---

    #[test]
    fn pmi_zero_joint_returns_zero() {
        assert_eq!(pmi(0.0, 0.5, 0.5), 0.0);
    }

    #[test]
    fn pmi_zero_marginal_returns_zero() {
        // When px or py is zero, return 0 by convention.
        assert_eq!(pmi(0.1, 0.0, 0.5), 0.0);
        assert_eq!(pmi(0.1, 0.5, 0.0), 0.0);
    }

    #[test]
    fn pmi_all_zero_returns_zero() {
        assert_eq!(pmi(0.0, 0.0, 0.0), 0.0);
    }
}