// lib.rs — crate root of `logp` (scraper header removed)
//! # logp
//!
//! Information theory primitives: entropies and divergences.
//!
//! ## Scope
//!
//! This crate is **L1 (Logic)** in the mathematical foundation: it should stay small and reusable.
//! It provides *scalar* information measures that appear across clustering, ranking,
//! evaluation, and geometry:
//!
//! - Shannon entropy and cross-entropy
//! - KL / Jensen–Shannon divergences
//! - Csiszár \(f\)-divergences (a.k.a. *information monotone* divergences)
//! - Bhattacharyya coefficient, Rényi/Tsallis families
//! - Bregman divergences (convex-analytic, not generally monotone)
//!
//! ## Distances vs divergences (terminology that prevents bugs)
//!
//! A **divergence** \(D(p:q)\) is usually required to satisfy:
//!
//! - \(D(p:q) \ge 0\)
//! - \(D(p:p) = 0\)
//!
//! but it is typically **not** symmetric and **not** a metric (no triangle inequality).
//! Many failures in downstream code are caused by treating a divergence as a distance metric.
//!
//! ## Key invariants (what tests should enforce)
//!
//! - **Jensen–Shannon** is bounded on the simplex:
//!   \(0 \le JS(p,q) \le \ln 2\) (nats).
//! - **Csiszár \(f\)-divergences** are monotone under coarse-graining (Markov kernels):
//!   merging bins cannot increase the divergence.
//!
//! ## Further reading
//!
//! - Frank Nielsen, “Divergences” portal (taxonomy diagrams + references):
//!   <https://franknielsen.github.io/Divergence/index.html>
//! - `nocotan/awesome-information-geometry` (curated reading list):
//!   <https://github.com/nocotan/awesome-information-geometry>
//! - Csiszár (1967): \(f\)-divergences and information monotonicity.
//! - Amari & Nagaoka (2000): *Methods of Information Geometry*.
//!
//! ## Taxonomy of Divergences (Nielsen)
//!
//! | Family | Generator | Key Property |
//! |---|---|---|
//! | **f-divergences** | Convex \(f(t)\) with \(f(1)=0\) | Monotone under Markov morphisms (coarse-graining) |
//! | **Bregman** | Convex \(F(x)\) | Dually flat geometry; generalized Pythagorean theorem |
//! | **Jensen-Shannon** | \(f\)-div + metric | Symmetric, bounded \([0, \ln 2]\), \(\sqrt{JS}\) is a metric |
//! | **Alpha** | \(\rho_\alpha = \int p^\alpha q^{1-\alpha}\) | Encodes Rényi, Tsallis, Bhattacharyya, Hellinger |
//!
//! ## Connections
//!
//! - [`rkhs`](../rkhs): MMD and KL both measure distribution "distance"
//! - [`wass`](../wass): Wasserstein vs entropy-based divergences
//! - [`stratify`](../stratify): NMI for cluster evaluation uses this crate
//! - [`fynch`](../fynch): Temperature scaling affects entropy calibration
//!
//! ## References
//!
//! - Shannon (1948). "A Mathematical Theory of Communication"
//! - Cover & Thomas (2006). "Elements of Information Theory"

64#![forbid(unsafe_code)]
65
66use thiserror::Error;
67
68mod ksg;
69pub use ksg::{mutual_information_ksg, KsgVariant};
70
/// Natural log of 2 (`ln 2`), in nats.
///
/// Handy when converting nats ↔ bits or bounding Jensen–Shannon divergence.
pub const LN_2: f64 = core::f64::consts::LN_2;

74/// KL Divergence between two diagonal Multivariate Gaussians.
75///
76/// Used for Variational Information Bottleneck (VIB) to regularize latent spaces.
77///
78/// Returns 0.5 * Σ [ (std1/std2)^2 + (mu2-mu1)^2 / std2^2 - 1 + 2*ln(std2/std1) ]
79///
80/// # Examples
81///
82/// ```
83/// # use logp::kl_divergence_gaussians;
84/// // KL(N || N) = 0 (identical Gaussians).
85/// let mu = [0.0, 1.0];
86/// let std = [1.0, 2.0];
87/// let kl = kl_divergence_gaussians(&mu, &std, &mu, &std).unwrap();
88/// assert!(kl.abs() < 1e-12);
89///
90/// // KL is non-negative for distinct Gaussians.
91/// let kl = kl_divergence_gaussians(&[0.0], &[1.0], &[1.0], &[1.0]).unwrap();
92/// assert!(kl >= 0.0);
93/// ```
94pub fn kl_divergence_gaussians(
95    mu1: &[f64],
96    std1: &[f64],
97    mu2: &[f64],
98    std2: &[f64],
99) -> Result<f64> {
100    ensure_same_len(mu1, std1)?;
101    ensure_same_len(mu1, mu2)?;
102    ensure_same_len(mu1, std2)?;
103
104    let mut kl = 0.0;
105    for (((&m1, &s1), &m2), &s2) in mu1.iter().zip(std1).zip(mu2).zip(std2) {
106        if s1 <= 0.0 || s2 <= 0.0 {
107            return Err(Error::Domain("standard deviation must be positive"));
108        }
109        let v1 = s1 * s1;
110        let v2 = s2 * s2;
111        kl += (v1 / v2) + (m2 - m1).powi(2) / v2 - 1.0 + 2.0 * (s2.ln() - s1.ln());
112    }
113    Ok(0.5 * kl)
114}
115
116/// Errors for information-measure computations.
117#[derive(Debug, Error)]
118pub enum Error {
119    #[error("length mismatch: {0} vs {1}")]
120    LengthMismatch(usize, usize),
121
122    #[error("empty input")]
123    Empty,
124
125    #[error("non-finite entry at index {idx}: {value}")]
126    NonFinite { idx: usize, value: f64 },
127
128    #[error("negative entry at index {idx}: {value}")]
129    Negative { idx: usize, value: f64 },
130
131    #[error("not normalized (expected sum≈1): sum={sum}")]
132    NotNormalized { sum: f64 },
133
134    #[error("invalid alpha: {alpha} (must be finite and not equal to {forbidden})")]
135    InvalidAlpha { alpha: f64, forbidden: f64 },
136
137    #[error("domain error: {0}")]
138    Domain(&'static str),
139}
140
141pub type Result<T> = core::result::Result<T, Error>;
142
143fn ensure_nonempty(x: &[f64]) -> Result<()> {
144    if x.is_empty() {
145        return Err(Error::Empty);
146    }
147    Ok(())
148}
149
150fn ensure_same_len(a: &[f64], b: &[f64]) -> Result<()> {
151    if a.len() != b.len() {
152        return Err(Error::LengthMismatch(a.len(), b.len()));
153    }
154    Ok(())
155}
156
157fn ensure_nonnegative(x: &[f64]) -> Result<()> {
158    for (i, &v) in x.iter().enumerate() {
159        if !v.is_finite() {
160            return Err(Error::NonFinite { idx: i, value: v });
161        }
162        if v < 0.0 {
163            return Err(Error::Negative { idx: i, value: v });
164        }
165    }
166    Ok(())
167}
168
/// Left-to-right sum of a slice (same order as `Iterator::sum`).
fn sum(x: &[f64]) -> f64 {
    x.iter().fold(0.0, |acc, &v| acc + v)
}

173/// Validate that `p` is a probability distribution on the simplex (within `tol`).
174///
175/// # Examples
176///
177/// ```
178/// # use logp::validate_simplex;
179/// // Valid simplex.
180/// assert!(validate_simplex(&[0.3, 0.7], 1e-9).is_ok());
181/// assert!(validate_simplex(&[1.0], 1e-9).is_ok());
182///
183/// // Rejects bad sum.
184/// assert!(validate_simplex(&[0.3, 0.6], 1e-9).is_err());
185///
186/// // Rejects negative entries.
187/// assert!(validate_simplex(&[1.5, -0.5], 1e-9).is_err());
188///
189/// // Rejects empty input.
190/// assert!(validate_simplex(&[], 1e-9).is_err());
191/// ```
192pub fn validate_simplex(p: &[f64], tol: f64) -> Result<()> {
193    ensure_nonempty(p)?;
194    ensure_nonnegative(p)?;
195    let s = sum(p);
196    if (s - 1.0).abs() > tol {
197        return Err(Error::NotNormalized { sum: s });
198    }
199    Ok(())
200}
201
202/// Normalize a nonnegative vector in-place to sum to 1.
203///
204/// Returns the original sum.
205///
206/// # Examples
207///
208/// ```
209/// # use logp::normalize_in_place;
210/// let mut v = vec![2.0, 3.0, 5.0];
211/// let original_sum = normalize_in_place(&mut v).unwrap();
212/// assert!((original_sum - 10.0).abs() < 1e-12);
213/// assert!((v[0] - 0.2).abs() < 1e-12);
214/// assert!((v[1] - 0.3).abs() < 1e-12);
215/// assert!((v[2] - 0.5).abs() < 1e-12);
216///
217/// // Rejects all-zero input.
218/// assert!(normalize_in_place(&mut vec![0.0, 0.0]).is_err());
219/// ```
220pub fn normalize_in_place(p: &mut [f64]) -> Result<f64> {
221    ensure_nonempty(p)?;
222    ensure_nonnegative(p)?;
223    let s = sum(p);
224    if s <= 0.0 {
225        return Err(Error::Domain("cannot normalize: sum <= 0"));
226    }
227    for v in p.iter_mut() {
228        *v /= s;
229    }
230    Ok(s)
231}
232
233/// Shannon entropy in nats: the expected surprise under distribution \(p\).
234///
235/// \[H(p) = -\sum_i p_i \ln p_i\]
236///
237/// # Key properties
238///
239/// - **Non-negative**: \(H(p) \ge 0\), with equality iff \(p\) is a delta (point mass).
240/// - **Maximized by uniform**: among distributions on \(n\) outcomes,
241///   \(H(p) \le \ln n\), with equality iff \(p_i = 1/n\) for all \(i\).
242/// - **Concavity**: \(H\) is a concave function of \(p\) on the simplex.
243///   Mixing distributions never decreases entropy.
244/// - **Units**: result is in nats (base \(e\)); divide by \(\ln 2\) for bits.
245///
246/// # Domain
247///
248/// Requires `p` to be a valid simplex distribution (within `tol`).
249///
250/// # Examples
251///
252/// ```
253/// # use logp::entropy_nats;
254/// // Uniform distribution over 4 outcomes: H = ln(4).
255/// let p = [0.25, 0.25, 0.25, 0.25];
256/// let h = entropy_nats(&p, 1e-9).unwrap();
257/// assert!((h - 4.0_f64.ln()).abs() < 1e-12);
258///
259/// // Delta (point mass): H = 0.
260/// let delta = [1.0, 0.0, 0.0];
261/// assert!(entropy_nats(&delta, 1e-9).unwrap().abs() < 1e-15);
262/// ```
263pub fn entropy_nats(p: &[f64], tol: f64) -> Result<f64> {
264    validate_simplex(p, tol)?;
265    let mut h = 0.0;
266    for &pi in p {
267        if pi > 0.0 {
268            h -= pi * pi.ln();
269        }
270    }
271    Ok(h)
272}
273
274/// Shannon entropy in bits.
275///
276/// # Examples
277///
278/// ```
279/// # use logp::{entropy_bits, entropy_nats, LN_2};
280/// // Fair coin: H = 1 bit.
281/// let p = [0.5, 0.5];
282/// let bits = entropy_bits(&p, 1e-9).unwrap();
283/// assert!((bits - 1.0).abs() < 1e-12);
284///
285/// // Consistent with nats / ln(2).
286/// let nats = entropy_nats(&p, 1e-9).unwrap();
287/// assert!((bits - nats / LN_2).abs() < 1e-12);
288/// ```
289pub fn entropy_bits(p: &[f64], tol: f64) -> Result<f64> {
290    Ok(entropy_nats(p, tol)? / LN_2)
291}
292
/// Fast Shannon entropy (nats) without simplex validation.
///
/// Intended for performance-critical loops such as Sinkhorn iteration in
/// Optimal Transport, where the input is validated once up front.
///
/// # Invariant
///
/// Assumes `p` is non-negative and normalized; no checks are performed.
///
/// # Examples
///
/// ```
/// # use logp::{entropy_unchecked, LN_2};
/// // Fair coin: H = ln(2).
/// let h = entropy_unchecked(&[0.5, 0.5]);
/// assert!((h - LN_2).abs() < 1e-12);
///
/// // Agrees with the checked version on valid input.
/// let p = [0.3, 0.7];
/// let h_checked = logp::entropy_nats(&p, 1e-9).unwrap();
/// assert!((entropy_unchecked(&p) - h_checked).abs() < 1e-15);
/// ```
#[inline]
pub fn entropy_unchecked(p: &[f64]) -> f64 {
    // 0 ln 0 := 0 by continuity, hence the positivity filter.
    p.iter()
        .filter(|&&pi| pi > 0.0)
        .map(|&pi| -pi * pi.ln())
        .sum()
}

324/// Cross-entropy in nats: the expected code length when using model \(q\) to encode
325/// data drawn from true distribution \(p\).
326///
327/// \[H(p, q) = -\sum_i p_i \ln q_i\]
328///
329/// # Key properties
330///
331/// - **Decomposition identity**: cross-entropy splits into entropy plus KL divergence:
332///   \(H(p, q) = H(p) + D_{KL}(p \| q)\).
333///   This means \(H(p, q) \ge H(p)\) with equality iff \(p = q\).
334/// - **Loss function**: minimizing \(H(p, q)\) over \(q\) is equivalent to minimizing
335///   \(D_{KL}(p \| q)\), which is why cross-entropy is the standard classification loss.
336/// - **Not symmetric**: \(H(p, q) \ne H(q, p)\) in general.
337///
338/// # Domain
339///
340/// `p` must be on the simplex; `q` must be nonnegative and normalized; and
341/// whenever `p_i > 0`, we require `q_i > 0` (otherwise cross-entropy is infinite).
342///
343/// # Examples
344///
345/// ```
346/// # use logp::{cross_entropy_nats, entropy_nats, kl_divergence};
347/// let p = [0.3, 0.7];
348/// let q = [0.5, 0.5];
349/// let h_pq = cross_entropy_nats(&p, &q, 1e-9).unwrap();
350///
351/// // Decomposition: H(p,q) = H(p) + KL(p||q).
352/// let h_p = entropy_nats(&p, 1e-9).unwrap();
353/// let kl = kl_divergence(&p, &q, 1e-9).unwrap();
354/// assert!((h_pq - (h_p + kl)).abs() < 1e-12);
355///
356/// // Self-cross-entropy equals entropy: H(p,p) = H(p).
357/// let h_pp = cross_entropy_nats(&p, &p, 1e-9).unwrap();
358/// assert!((h_pp - h_p).abs() < 1e-12);
359/// ```
360pub fn cross_entropy_nats(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
361    validate_simplex(p, tol)?;
362    validate_simplex(q, tol)?;
363    let mut h = 0.0;
364    for (&pi, &qi) in p.iter().zip(q.iter()) {
365        if pi == 0.0 {
366            continue;
367        }
368        if qi <= 0.0 {
369            return Err(Error::Domain("cross-entropy undefined: q_i=0 while p_i>0"));
370        }
371        h -= pi * qi.ln();
372    }
373    Ok(h)
374}
375
376/// Kullback--Leibler divergence in nats: the information lost when \(q\) is used to
377/// approximate \(p\).
378///
379/// \[D_{KL}(p \| q) = \sum_i p_i \ln \frac{p_i}{q_i}\]
380///
381/// # Key properties
382///
383/// - **Gibbs' inequality**: \(D_{KL}(p \| q) \ge 0\), with equality iff \(p = q\).
384///   This follows directly from Jensen's inequality applied to \(-\ln\).
385/// - **Not symmetric**: \(D_{KL}(p \| q) \ne D_{KL}(q \| p)\) in general;
386///   this is why KL is a divergence, not a distance.
387/// - **Not bounded above**: KL can be arbitrarily large when supports differ.
388/// - **Connection to MLE**: minimizing \(D_{KL}(p_{data} \| q_\theta)\) over \(\theta\)
389///   is equivalent to maximum likelihood estimation.
390/// - **Additive for independent distributions**: if \(p = p_1 \otimes p_2\) and
391///   \(q = q_1 \otimes q_2\), then
392///   \(D_{KL}(p \| q) = D_{KL}(p_1 \| q_1) + D_{KL}(p_2 \| q_2)\).
393///
394/// # Domain
395///
396/// `p` and `q` must be valid simplex distributions; and whenever `p_i > 0`,
397/// we require `q_i > 0`.
398///
399/// # Examples
400///
401/// ```
402/// # use logp::kl_divergence;
403/// // KL(p || p) = 0 (Gibbs' inequality, tight case).
404/// let p = [0.2, 0.3, 0.5];
405/// assert!(kl_divergence(&p, &p, 1e-9).unwrap().abs() < 1e-15);
406///
407/// // KL is non-negative.
408/// let q = [0.5, 0.25, 0.25];
409/// assert!(kl_divergence(&p, &q, 1e-9).unwrap() >= 0.0);
410///
411/// // Not symmetric in general.
412/// let kl_pq = kl_divergence(&p, &q, 1e-9).unwrap();
413/// let kl_qp = kl_divergence(&q, &p, 1e-9).unwrap();
414/// assert!((kl_pq - kl_qp).abs() > 1e-6);
415/// ```
416pub fn kl_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
417    ensure_same_len(p, q)?;
418    validate_simplex(p, tol)?;
419    validate_simplex(q, tol)?;
420    let mut d = 0.0;
421    for (&pi, &qi) in p.iter().zip(q.iter()) {
422        if pi == 0.0 {
423            continue;
424        }
425        if qi <= 0.0 {
426            return Err(Error::Domain("KL undefined: q_i=0 while p_i>0"));
427        }
428        d += pi * (pi / qi).ln();
429    }
430    Ok(d)
431}
432
433/// Jensen--Shannon divergence in nats: a symmetric, bounded smoothing of KL divergence.
434///
435/// \[JS(p, q) = \tfrac{1}{2} D_{KL}(p \| m) + \tfrac{1}{2} D_{KL}(q \| m), \quad m = \tfrac{1}{2}(p + q)\]
436///
437/// # Key properties
438///
439/// - **Symmetric**: \(JS(p, q) = JS(q, p)\), unlike KL.
440/// - **Bounded**: \(0 \le JS(p, q) \le \ln 2\). The upper bound is attained when \(p\)
441///   and \(q\) have disjoint supports.
442/// - **Square root is a metric**: \(\sqrt{JS(p, q)}\) satisfies the triangle inequality
443///   (Endres & Schindelin, 2003), so it can be used as a proper distance function.
444/// - **Connection to mutual information**: \(JS(p, q) = I(X; Z)\) where \(Z\) is a
445///   fair coin selecting between \(p\) and \(q\), and \(X\) is drawn from the selected
446///   distribution.
447/// - **Always finite**: because \(m_i > 0\) whenever \(p_i > 0\) or \(q_i > 0\), the
448///   KL terms are always well-defined (no division by zero).
449///
450/// # Domain
451///
452/// `p`, `q` must be simplex distributions.
453///
454/// # Examples
455///
456/// ```
457/// # use logp::{jensen_shannon_divergence, LN_2};
458/// // JS(p, p) = 0.
459/// let p = [0.3, 0.7];
460/// assert!(jensen_shannon_divergence(&p, &p, 1e-9).unwrap().abs() < 1e-15);
461///
462/// // Disjoint supports: JS = ln(2).
463/// let a = [1.0, 0.0];
464/// let b = [0.0, 1.0];
465/// let js = jensen_shannon_divergence(&a, &b, 1e-9).unwrap();
466/// assert!((js - LN_2).abs() < 1e-12);
467///
468/// // Symmetric: JS(p, q) = JS(q, p).
469/// let q = [0.5, 0.5];
470/// let js_pq = jensen_shannon_divergence(&p, &q, 1e-9).unwrap();
471/// let js_qp = jensen_shannon_divergence(&q, &p, 1e-9).unwrap();
472/// assert!((js_pq - js_qp).abs() < 1e-15);
473/// ```
474pub fn jensen_shannon_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
475    ensure_same_len(p, q)?;
476    validate_simplex(p, tol)?;
477    validate_simplex(q, tol)?;
478
479    let mut m = vec![0.0; p.len()];
480    for i in 0..p.len() {
481        m[i] = 0.5 * (p[i] + q[i]);
482    }
483
484    Ok(0.5 * kl_divergence(p, &m, tol)? + 0.5 * kl_divergence(q, &m, tol)?)
485}
486
487/// Mutual information in nats: how much knowing \(Y\) reduces uncertainty about \(X\).
488///
489/// \[I(X; Y) = \sum_{x,y} p(x,y) \ln \frac{p(x,y)}{p(x)\,p(y)}\]
490///
491/// # Key properties
492///
493/// - **KL form**: \(I(X; Y) = D_{KL}\bigl(p(x,y) \;\|\; p(x)\,p(y)\bigr)\), measuring
494///   how far the joint distribution is from the product of its marginals.
495/// - **Non-negative**: \(I(X; Y) \ge 0\), with equality iff \(X\) and \(Y\) are
496///   independent.
497/// - **Symmetric**: \(I(X; Y) = I(Y; X)\).
498/// - **Bounded by entropy**: \(I(X; Y) \le \min\{H(X),\, H(Y)\}\).
499/// - **Data processing inequality**: for any Markov chain \(X \to Y \to Z\),
500///   \(I(X; Z) \le I(X; Y)\). Processing cannot create information.
501/// - **Entropy decomposition**: \(I(X; Y) = H(X) + H(Y) - H(X, Y) = H(X) - H(X|Y)\).
502///
503/// # Layout
504///
505/// For discrete distributions, given a **row-major** joint distribution table `p_xy`
506/// with shape `(n_x, n_y)`.
507///
508/// Public invariant (this is the important one): this API is **backend-agnostic**.
509/// It does not force `ndarray` into the public surface of an L1 crate.
510///
511/// # Examples
512///
513/// ```
514/// # use logp::{mutual_information, entropy_nats, LN_2};
515/// // Independent joint: p(x,y) = p(x)*p(y), so I(X;Y) = 0.
516/// let p_xy = [0.15, 0.35, 0.15, 0.35]; // 2x2, marginals [0.5,0.5] x [0.3,0.7]
517/// let mi = mutual_information(&p_xy, 2, 2, 1e-9).unwrap();
518/// assert!(mi.abs() < 1e-12);
519///
520/// // Perfect correlation (Y = X, uniform bit): I(X;Y) = H(X) = ln(2).
521/// let diag = [0.5, 0.0, 0.0, 0.5];
522/// let mi = mutual_information(&diag, 2, 2, 1e-9).unwrap();
523/// assert!((mi - LN_2).abs() < 1e-12);
524/// ```
525pub fn mutual_information(p_xy: &[f64], n_x: usize, n_y: usize, tol: f64) -> Result<f64> {
526    if n_x == 0 || n_y == 0 {
527        return Err(Error::Domain(
528            "mutual_information: n_x and n_y must be >= 1",
529        ));
530    }
531    if p_xy.len() != n_x * n_y {
532        return Err(Error::LengthMismatch(p_xy.len(), n_x * n_y));
533    }
534    validate_simplex(p_xy, tol)?;
535
536    let mut p_x = vec![0.0; n_x];
537    let mut p_y = vec![0.0; n_y];
538    for i in 0..n_x {
539        for j in 0..n_y {
540            let p = p_xy[i * n_y + j];
541            p_x[i] += p;
542            p_y[j] += p;
543        }
544    }
545
546    let mut mi = 0.0;
547    for i in 0..n_x {
548        for j in 0..n_y {
549            let pxy = p_xy[i * n_y + j];
550            if pxy > 0.0 {
551                let px = p_x[i];
552                let py = p_y[j];
553                if px <= 0.0 || py <= 0.0 {
554                    return Err(Error::Domain(
555                        "mutual_information: p(x)=0 or p(y)=0 while p(x,y)>0",
556                    ));
557                }
558                mi += pxy * (pxy / (px * py)).ln();
559            }
560        }
561    }
562    Ok(mi)
563}
564
/// `ndarray` adapter for [`mutual_information`].
///
/// Requires `logp` feature `ndarray`.
#[cfg(feature = "ndarray")]
pub fn mutual_information_ndarray(p_xy: &ndarray::Array2<f64>, tol: f64) -> Result<f64> {
    let (n_x, n_y) = p_xy.dim();
    // Copy into a flat buffer; iteration follows logical (row-major) order,
    // so this also handles non-contiguous views.
    let buffer: Vec<f64> = p_xy.iter().copied().collect();
    mutual_information(&buffer, n_x, n_y, tol)
}

/// Pointwise mutual information: the log-ratio measuring how much more (or less)
/// likely two specific outcomes co-occur than if they were independent.
///
/// \[PMI(x; y) = \ln \frac{p(x, y)}{p(x)\,p(y)}\]
///
/// # Key properties
///
/// - **Sign**: positive when \(x, y\) co-occur more than chance; negative when
///   less; zero when independent.
/// - **Unbounded**: \(PMI \in (-\infty, -\ln p(x,y)]\); rare events yield very
///   large PMI, which is why PPMI (clamped at 0) is common in practice.
/// - **Relation to MI**: \(I(X; Y) = \mathbb{E}_{p(x,y)}[PMI(x; y)]\).
/// - **word2vec**: Levy & Goldberg (2014) showed Skip-gram with negative
///   sampling implicitly factorizes a PMI matrix shifted by \(\ln k\).
///
/// # Examples
///
/// ```
/// # use logp::pmi;
/// // Independent events: p(x,y) = p(x)*p(y), so PMI = 0.
/// let val = pmi(0.06, 0.3, 0.2);
/// assert!(val.abs() < 1e-10);
///
/// // Positive correlation: p(x,y) > p(x)*p(y).
/// assert!(pmi(0.4, 0.5, 0.5) > 0.0);
///
/// // Negative correlation: p(x,y) < p(x)*p(y).
/// assert!(pmi(0.1, 0.5, 0.5) < 0.0);
///
/// // Zero joint probability returns 0 by convention.
/// assert_eq!(pmi(0.0, 0.5, 0.5), 0.0);
/// ```
pub fn pmi(pxy: f64, px: f64, py: f64) -> f64 {
    // Convention: any non-positive probability makes PMI 0 rather than ±inf/NaN.
    let degenerate = pxy <= 0.0 || px <= 0.0 || py <= 0.0;
    if degenerate {
        return 0.0;
    }
    (pxy / (px * py)).ln()
}

/// Log-sum-exp: numerically stable `ln(exp(a_1) + ... + exp(a_n))`.
///
/// The fundamental primitive for working in log-probability space. A naive
/// `values.iter().map(|v| v.exp()).sum::<f64>().ln()` overflows for large
/// inputs and underflows for small ones; subtracting the maximum first
/// ("max-shift trick") avoids both.
///
/// Returns `NEG_INFINITY` for an empty slice.
///
/// # Examples
///
/// ```
/// # use logp::log_sum_exp;
/// // ln(e^0 + e^0) = ln(2)
/// let lse = log_sum_exp(&[0.0, 0.0]);
/// assert!((lse - 2.0_f64.ln()).abs() < 1e-12);
///
/// // Dominated term: ln(e^1000 + e^0) ≈ 1000
/// let lse = log_sum_exp(&[1000.0, 0.0]);
/// assert!((lse - 1000.0).abs() < 1e-10);
///
/// // Single element: identity.
/// assert_eq!(log_sum_exp(&[42.0]), 42.0);
///
/// // Empty: -inf.
/// assert_eq!(log_sum_exp(&[]), f64::NEG_INFINITY);
/// ```
#[inline]
pub fn log_sum_exp(values: &[f64]) -> f64 {
    let peak = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // Empty input folds to -inf; an infinite peak is returned as-is:
    // +inf dominates the sum, and -inf means every term underflows to zero.
    if peak.is_infinite() {
        return peak;
    }
    let shifted: f64 = values.iter().map(|&v| (v - peak).exp()).sum();
    peak + shifted.ln()
}

/// Log-sum-exp specialized to two values.
///
/// Equivalent to `log_sum_exp(&[a, b])` without building a slice.
///
/// # Examples
///
/// ```
/// # use logp::log_sum_exp2;
/// let lse = log_sum_exp2(0.0, 0.0);
/// assert!((lse - 2.0_f64.ln()).abs() < 1e-12);
/// ```
#[inline]
pub fn log_sum_exp2(a: f64, b: f64) -> f64 {
    let peak = b.max(a);
    if peak.is_infinite() {
        // ±inf peak: +inf dominates; -inf means both terms vanish.
        peak
    } else {
        peak + ((a - peak).exp() + (b - peak).exp()).ln()
    }
}

676/// Digamma function: the logarithmic derivative of the Gamma function.
677///
678/// \[\psi(x) = \frac{d}{dx} \ln \Gamma(x) = \frac{\Gamma'(x)}{\Gamma(x)}\]
679///
680/// # Key properties
681///
682/// - **Recurrence**: \(\psi(x+1) = \psi(x) + \frac{1}{x}\), which follows from
683///   \(\Gamma(x+1) = x\,\Gamma(x)\).
684/// - **Special value**: \(\psi(1) = -\gamma \approx -0.5772\), where \(\gamma\) is
685///   the Euler--Mascheroni constant.
686/// - **Asymptotic**: \(\psi(x) \sim \ln x - \frac{1}{2x}\) for large \(x\).
687/// - **Why it appears here**: the KSG estimator for mutual information
688///   ([`mutual_information_ksg`]) uses digamma to correct for the bias of
689///   nearest-neighbor density estimates.
690///
691/// # Domain
692///
693/// Defined for \(x > 0\). Returns `NaN` for \(x \le 0\).
694///
695/// # Implementation
696///
697/// Uses the recurrence to shift small \(x\) up to \(x \ge 7\), then applies the
698/// asymptotic expansion with Bernoulli-number correction terms.
699///
700/// # Examples
701///
702/// ```
703/// # use logp::digamma;
704/// // psi(1) = -gamma (Euler-Mascheroni constant).
705/// let psi1 = digamma(1.0);
706/// assert!((psi1 - (-0.5772156649)).abs() < 1e-8);
707///
708/// // Recurrence: psi(x+1) = psi(x) + 1/x.
709/// let x = 3.5;
710/// assert!((digamma(x + 1.0) - digamma(x) - 1.0 / x).abs() < 1e-10);
711///
712/// // Non-positive input returns NaN.
713/// assert!(digamma(0.0).is_nan());
714/// assert!(digamma(-1.0).is_nan());
715/// ```
716pub fn digamma(mut x: f64) -> f64 {
717    if x <= 0.0 {
718        return f64::NAN;
719    }
720    let mut result = 0.0;
721    while x < 7.0 {
722        result -= 1.0 / x;
723        x += 1.0;
724    }
725    let r = 1.0 / x;
726    result += x.ln() - 0.5 * r;
727    let r2 = r * r;
728    result -= r2 * (1.0 / 12.0 - r2 * (1.0 / 120.0 - r2 / 252.0));
729    result
730}
731
732/// Bhattacharyya coefficient: the geometric-mean overlap between two distributions.
733///
734/// \[BC(p, q) = \sum_i \sqrt{p_i \, q_i}\]
735///
736/// # Key properties
737///
738/// - **Geometric mean interpretation**: each term \(\sqrt{p_i q_i}\) is the geometric
739///   mean of the two probabilities at bin \(i\). BC sums these, measuring how much
740///   the distributions overlap.
741/// - **Bounded**: \(BC \in [0, 1]\). Equals 1 iff \(p = q\); equals 0 iff supports
742///   are disjoint.
743/// - **Relationship to Hellinger**: \(H^2(p, q) = 1 - BC(p, q)\), so the squared
744///   Hellinger distance is just one minus the Bhattacharyya coefficient.
745/// - **Relationship to Renyi**: at \(\alpha = \tfrac{1}{2}\), the Renyi divergence
746///   gives \(D_{1/2}^R(p \| q) = -2 \ln BC(p, q)\).
747/// - **Connection to alpha family**: \(BC = \rho_{1/2}(p, q)\), a special case of
748///   [`rho_alpha`].
749///
750/// # Examples
751///
752/// ```
753/// # use logp::bhattacharyya_coeff;
754/// // BC(p, p) = 1.
755/// let p = [0.3, 0.7];
756/// assert!((bhattacharyya_coeff(&p, &p, 1e-9).unwrap() - 1.0).abs() < 1e-12);
757///
758/// // Disjoint supports: BC = 0.
759/// let a = [1.0, 0.0];
760/// let b = [0.0, 1.0];
761/// assert!(bhattacharyya_coeff(&a, &b, 1e-9).unwrap().abs() < 1e-15);
762///
763/// // BC is in [0, 1].
764/// let q = [0.5, 0.5];
765/// let bc = bhattacharyya_coeff(&p, &q, 1e-9).unwrap();
766/// assert!(bc >= 0.0 && bc <= 1.0);
767/// ```
768pub fn bhattacharyya_coeff(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
769    ensure_same_len(p, q)?;
770    validate_simplex(p, tol)?;
771    validate_simplex(q, tol)?;
772    let bc: f64 = p
773        .iter()
774        .zip(q.iter())
775        .map(|(&pi, &qi)| (pi * qi).sqrt())
776        .sum();
777    Ok(bc)
778}
779
780/// Bhattacharyya distance \(D_B(p,q) = -\ln BC(p,q)\).
781///
782/// # Examples
783///
784/// ```
785/// # use logp::bhattacharyya_distance;
786/// // D_B(p, p) = 0.
787/// let p = [0.4, 0.6];
788/// assert!(bhattacharyya_distance(&p, &p, 1e-9).unwrap().abs() < 1e-12);
789///
790/// // Non-negative for distinct distributions.
791/// let q = [0.5, 0.5];
792/// assert!(bhattacharyya_distance(&p, &q, 1e-9).unwrap() >= 0.0);
793/// ```
794pub fn bhattacharyya_distance(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
795    let bc = bhattacharyya_coeff(p, q, tol)?;
796    // When supports are disjoint, bc can be 0 (=> +∞ distance). Keep it explicit.
797    if bc == 0.0 {
798        return Err(Error::Domain("Bhattacharyya distance is infinite (BC=0)"));
799    }
800    Ok(-bc.ln())
801}
802
803/// Squared Hellinger distance: one minus the Bhattacharyya coefficient.
804///
805/// \[H^2(p, q) = 1 - \sum_i \sqrt{p_i \, q_i} = 1 - BC(p, q)\]
806///
807/// Bounded in \([0, 1]\). Equals the Amari \(\alpha\)-divergence at \(\alpha = 0\)
808/// (up to a factor of 2).
809///
810/// # Examples
811///
812/// ```
813/// # use logp::hellinger_squared;
814/// // H^2(p, p) = 0.
815/// let p = [0.25, 0.75];
816/// assert!(hellinger_squared(&p, &p, 1e-9).unwrap().abs() < 1e-15);
817///
818/// // Disjoint supports: H^2 = 1.
819/// let a = [1.0, 0.0];
820/// let b = [0.0, 1.0];
821/// assert!((hellinger_squared(&a, &b, 1e-9).unwrap() - 1.0).abs() < 1e-12);
822/// ```
823pub fn hellinger_squared(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
824    let bc = bhattacharyya_coeff(p, q, tol)?;
825    Ok((1.0 - bc).max(0.0))
826}
827
828/// Hellinger distance: the square root of the squared Hellinger distance.
829///
830/// \[H(p, q) = \sqrt{1 - BC(p, q)}\]
831///
832/// Unlike KL divergence, Hellinger is a **proper metric**: it is symmetric, satisfies
833/// the triangle inequality, and is bounded in \([0, 1]\).
834///
835/// # Examples
836///
837/// ```
838/// # use logp::hellinger;
839/// // H(p, p) = 0.
840/// let p = [0.3, 0.7];
841/// assert!(hellinger(&p, &p, 1e-9).unwrap().abs() < 1e-15);
842///
843/// // Symmetric: H(p, q) = H(q, p).
844/// let q = [0.5, 0.5];
845/// let h_pq = hellinger(&p, &q, 1e-9).unwrap();
846/// let h_qp = hellinger(&q, &p, 1e-9).unwrap();
847/// assert!((h_pq - h_qp).abs() < 1e-15);
848///
849/// // Bounded in [0, 1].
850/// assert!(h_pq >= 0.0 && h_pq <= 1.0);
851/// ```
852pub fn hellinger(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
853    Ok(hellinger_squared(p, q, tol)?.sqrt())
854}
855
856fn pow_nonneg(x: f64, a: f64) -> Result<f64> {
857    if x < 0.0 || !x.is_finite() || !a.is_finite() {
858        return Err(Error::Domain("pow_nonneg: invalid input"));
859    }
860    if x == 0.0 {
861        if a == 0.0 {
862            // By continuity in the divergence formulas, treat 0^0 as 1.
863            return Ok(1.0);
864        }
865        if a > 0.0 {
866            return Ok(0.0);
867        }
868        return Err(Error::Domain("0^a for a<0 is infinite"));
869    }
870    Ok(x.powf(a))
871}
872
873/// Alpha-integral: the workhorse behind the entire alpha-family of divergences.
874///
875/// \[\rho_\alpha(p, q) = \sum_i p_i^\alpha \, q_i^{1-\alpha}\]
876///
877/// # Why this matters
878///
879/// This single quantity generates multiple divergence families via simple transforms:
880///
881/// - **Renyi**: \(D_\alpha^R = \frac{1}{\alpha - 1} \ln \rho_\alpha\)
882/// - **Tsallis**: \(D_\alpha^T = \frac{\rho_\alpha - 1}{\alpha - 1}\)
883/// - **Bhattacharyya coefficient**: \(BC = \rho_{1/2}\)
884/// - **Chernoff information**: \(\min_\alpha (-\ln \rho_\alpha)\)
885///
886/// # Key properties
887///
888/// - \(\rho_\alpha(p, p) = 1\) for all \(\alpha\) (since \(\sum p_i = 1\)).
889/// - By Holder's inequality, \(\rho_\alpha(p, q) \le 1\) for \(\alpha \in [0, 1]\).
890/// - Continuous and log-convex in \(\alpha\).
891///
892/// # Examples
893///
894/// ```
895/// # use logp::rho_alpha;
896/// // rho_alpha(p, p, alpha) = 1 for any alpha (since sum(p) = 1).
897/// let p = [0.2, 0.3, 0.5];
898/// assert!((rho_alpha(&p, &p, 0.5, 1e-9).unwrap() - 1.0).abs() < 1e-12);
899/// assert!((rho_alpha(&p, &p, 2.0, 1e-9).unwrap() - 1.0).abs() < 1e-12);
900///
901/// // At alpha = 0.5, rho equals the Bhattacharyya coefficient.
902/// let q = [0.5, 0.25, 0.25];
903/// let rho = rho_alpha(&p, &q, 0.5, 1e-9).unwrap();
904/// let bc = logp::bhattacharyya_coeff(&p, &q, 1e-9).unwrap();
905/// assert!((rho - bc).abs() < 1e-12);
906/// ```
907pub fn rho_alpha(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
908    ensure_same_len(p, q)?;
909    validate_simplex(p, tol)?;
910    validate_simplex(q, tol)?;
911    if !alpha.is_finite() {
912        return Err(Error::InvalidAlpha {
913            alpha,
914            forbidden: f64::NAN,
915        });
916    }
917    let mut s = 0.0;
918    for (&pi, &qi) in p.iter().zip(q.iter()) {
919        let a = pow_nonneg(pi, alpha)?;
920        let b = pow_nonneg(qi, 1.0 - alpha)?;
921        s += a * b;
922    }
923    Ok(s)
924}
925
926/// Renyi divergence in nats: a one-parameter family that interpolates between
927/// different notions of distributional difference.
928///
929/// \[D_\alpha^R(p \| q) = \frac{1}{\alpha - 1} \ln \rho_\alpha(p, q), \quad \alpha > 0,\; \alpha \ne 1\]
930///
931/// # Key properties
932///
933/// - **Limit to KL**: \(\lim_{\alpha \to 1} D_\alpha^R(p \| q) = D_{KL}(p \| q)\)
934///   by L'Hopital's rule (the logarithm and denominator both vanish).
935/// - **Alpha = 1/2**: \(D_{1/2}^R = -2 \ln BC(p, q)\), twice the negative log
936///   Bhattacharyya coefficient.
937/// - **Alpha = infinity**: \(D_\infty^R = \ln \max_i (p_i / q_i)\), the log of the
938///   maximum likelihood ratio. This bounds all other Renyi orders.
939/// - **Monotone in alpha**: \(D_\alpha^R\) is non-decreasing in \(\alpha\).
940/// - **Non-negative**: \(D_\alpha^R(p \| q) \ge 0\), with equality iff \(p = q\).
941///
942/// # Domain
943///
944/// \(\alpha > 0\), \(\alpha \ne 1\). Both `p` and `q` must be simplex distributions.
945///
946/// # Examples
947///
948/// ```
949/// # use logp::renyi_divergence;
950/// // D_alpha(p || p) = 0 for any valid alpha.
951/// let p = [0.3, 0.7];
952/// assert!(renyi_divergence(&p, &p, 2.0, 1e-9).unwrap().abs() < 1e-12);
953///
954/// // Non-negative.
955/// let q = [0.5, 0.5];
956/// assert!(renyi_divergence(&p, &q, 0.5, 1e-9).unwrap() >= -1e-12);
957///
958/// // alpha = 1.0 is forbidden (use KL instead).
959/// assert!(renyi_divergence(&p, &q, 1.0, 1e-9).is_err());
960/// ```
961pub fn renyi_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
962    if alpha == 1.0 {
963        return Err(Error::InvalidAlpha {
964            alpha,
965            forbidden: 1.0,
966        });
967    }
968    let rho = rho_alpha(p, q, alpha, tol)?;
969    if rho <= 0.0 {
970        return Err(Error::Domain("rho_alpha <= 0"));
971    }
972    Ok(rho.ln() / (alpha - 1.0))
973}
974
975/// Tsallis divergence: a non-extensive generalization of KL divergence from
976/// statistical mechanics.
977///
978/// \[D_\alpha^T(p \| q) = \frac{\rho_\alpha(p, q) - 1}{\alpha - 1}, \quad \alpha \ne 1\]
979///
980/// # Key properties
981///
982/// - **Limit to KL**: \(\lim_{\alpha \to 1} D_\alpha^T(p \| q) = D_{KL}(p \| q)\),
983///   same limit as Renyi but via a different path.
984/// - **Connection to Renyi via deformed logarithm**: Tsallis uses the q-logarithm
985///   \(\ln_q(x) = \frac{x^{1-q} - 1}{1-q}\) where Renyi uses the ordinary log.
986///   Formally: \(D_\alpha^T = \frac{e^{(\alpha-1) D_\alpha^R} - 1}{\alpha - 1}\).
987/// - **Non-extensive**: for independent systems, Tsallis divergence is **not** additive
988///   (unlike KL and Renyi). This property is intentional and models systems with
989///   long-range correlations in statistical physics.
990/// - **Non-negative**: \(D_\alpha^T(p \| q) \ge 0\), with equality iff \(p = q\).
991///
992/// # Domain
993///
994/// \(\alpha \ne 1\). Both `p` and `q` must be simplex distributions.
995///
996/// # Examples
997///
998/// ```
999/// # use logp::tsallis_divergence;
1000/// // D_alpha^T(p || p) = 0 for any valid alpha.
1001/// let p = [0.4, 0.6];
1002/// assert!(tsallis_divergence(&p, &p, 2.0, 1e-9).unwrap().abs() < 1e-12);
1003///
1004/// // Non-negative.
1005/// let q = [0.5, 0.5];
1006/// assert!(tsallis_divergence(&p, &q, 0.5, 1e-9).unwrap() >= -1e-12);
1007///
1008/// // alpha = 1.0 is forbidden.
1009/// assert!(tsallis_divergence(&p, &q, 1.0, 1e-9).is_err());
1010/// ```
1011pub fn tsallis_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1012    if alpha == 1.0 {
1013        return Err(Error::InvalidAlpha {
1014            alpha,
1015            forbidden: 1.0,
1016        });
1017    }
1018    Ok((rho_alpha(p, q, alpha, tol)? - 1.0) / (alpha - 1.0))
1019}
1020
1021/// Amari alpha-divergence: a one-parameter family from information geometry that
1022/// continuously interpolates between forward KL, reverse KL, and squared Hellinger.
1023///
1024/// For \(\alpha \notin \{-1, 1\}\):
1025///
1026/// \[D^\alpha(p : q) = \frac{4}{1 - \alpha^2}\left(1 - \rho_{\frac{1-\alpha}{2}}(p, q)\right)\]
1027///
1028/// # Key properties
1029///
1030/// - **\(\alpha = -1\)**: recovers \(D_{KL}(p \| q)\), the forward KL divergence.
1031/// - **\(\alpha = +1\)**: recovers \(D_{KL}(q \| p)\), the reverse KL divergence.
1032/// - **\(\alpha = 0\)**: gives \(4(1 - BC(p,q)) = 4\,H^2(p,q)\), proportional to
1033///   the squared Hellinger distance.
1034/// - **Duality**: \(D^\alpha(p : q) = D^{-\alpha}(q : p)\). Swapping the sign of
1035///   \(\alpha\) is the same as swapping the arguments.
1036/// - **Non-negative**: \(D^\alpha(p : q) \ge 0\), with equality iff \(p = q\).
1037/// - **Information geometry**: the Amari family parameterizes the \(\alpha\)-connections
1038///   on the statistical manifold (Amari & Nagaoka, 2000).
1039///
1040/// # Examples
1041///
1042/// ```
1043/// # use logp::{amari_alpha_divergence, kl_divergence, hellinger_squared};
1044/// let p = [0.3, 0.7];
1045/// let q = [0.5, 0.5];
1046/// let tol = 1e-9;
1047///
1048/// // alpha = -1 gives forward KL(p || q).
1049/// let amari_neg1 = amari_alpha_divergence(&p, &q, -1.0, tol).unwrap();
1050/// let kl_fwd = kl_divergence(&p, &q, tol).unwrap();
1051/// assert!((amari_neg1 - kl_fwd).abs() < 1e-6);
1052///
1053/// // alpha = 0 gives 4 * H^2(p, q).
1054/// let amari_0 = amari_alpha_divergence(&p, &q, 0.0, tol).unwrap();
1055/// let h2 = hellinger_squared(&p, &q, tol).unwrap();
1056/// assert!((amari_0 - 4.0 * h2).abs() < 1e-10);
1057/// ```
1058pub fn amari_alpha_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1059    if !alpha.is_finite() {
1060        return Err(Error::InvalidAlpha {
1061            alpha,
1062            forbidden: f64::NAN,
1063        });
1064    }
1065    // Numerically stable handling near ±1.
1066    let eps = 1e-10;
1067    if (alpha + 1.0).abs() <= eps {
1068        return kl_divergence(p, q, tol);
1069    }
1070    if (alpha - 1.0).abs() <= eps {
1071        return kl_divergence(q, p, tol);
1072    }
1073    let t = (1.0 - alpha) / 2.0;
1074    let rho = rho_alpha(p, q, t, tol)?;
1075    Ok((4.0 / (1.0 - alpha * alpha)) * (1.0 - rho))
1076}
1077
1078/// Csiszar f-divergence: the most general class of divergences that respect
1079/// sufficient statistics (information monotonicity).
1080///
1081/// \[D_f(p \| q) = \sum_i q_i \, f\!\left(\frac{p_i}{q_i}\right)\]
1082///
1083/// where \(f\) is a convex function with \(f(1) = 0\).
1084///
1085/// # Information monotonicity theorem
1086///
1087/// The defining property of f-divergences (Csiszar, 1967): for any Markov kernel
1088/// (stochastic map) \(T\),
1089///
1090/// \[D_f(Tp \| Tq) \le D_f(p \| q)\]
1091///
1092/// Coarse-graining (merging bins) cannot increase the divergence. This is the
1093/// information-theoretic analogue of the data processing inequality.
1094///
1095/// # Common f-generators
1096///
1097/// | Divergence | \(f(t)\) |
1098/// |---|---|
1099/// | KL divergence | \(t \ln t\) |
1100/// | Reverse KL | \(-\ln t\) |
1101/// | Squared Hellinger | \((\sqrt{t} - 1)^2\) |
1102/// | Total variation | \(\tfrac{1}{2} |t - 1|\) |
1103/// | Chi-squared | \((t - 1)^2\) |
1104/// | Jensen-Shannon | \(t \ln t - (1+t) \ln \tfrac{1+t}{2}\) |
1105///
1106/// # Edge cases
1107///
1108/// When `q_i = 0`:
1109/// - if `p_i = 0`, the contribution is treated as 0 (by continuity).
1110/// - if `p_i > 0`, the divergence is infinite; we return an error.
1111///
1112/// # Examples
1113///
1114/// ```
1115/// # use logp::{csiszar_f_divergence, kl_divergence};
1116/// let p = [0.3, 0.7];
1117/// let q = [0.5, 0.5];
1118///
1119/// // f(t) = t*ln(t) recovers KL divergence.
1120/// let cs = csiszar_f_divergence(&p, &q, |t| t * t.ln(), 1e-9).unwrap();
1121/// let kl = kl_divergence(&p, &q, 1e-9).unwrap();
1122/// assert!((cs - kl).abs() < 1e-10);
1123///
1124/// // f(t) = (t - 1)^2 gives chi-squared divergence.
1125/// let chi2 = csiszar_f_divergence(&p, &q, |t| (t - 1.0).powi(2), 1e-9).unwrap();
1126/// assert!(chi2 >= 0.0);
1127/// ```
1128pub fn csiszar_f_divergence(p: &[f64], q: &[f64], f: impl Fn(f64) -> f64, tol: f64) -> Result<f64> {
1129    ensure_same_len(p, q)?;
1130    validate_simplex(p, tol)?;
1131    validate_simplex(q, tol)?;
1132
1133    let mut d = 0.0;
1134    for (&pi, &qi) in p.iter().zip(q.iter()) {
1135        if qi == 0.0 {
1136            if pi == 0.0 {
1137                continue;
1138            }
1139            return Err(Error::Domain("f-divergence undefined: q_i=0 while p_i>0"));
1140        }
1141        d += qi * f(pi / qi);
1142    }
1143    Ok(d)
1144}
1145
/// Bregman generator: a convex function \(F\) and its gradient.
///
/// Implementors supply the potential and gradient consumed by
/// `bregman_divergence` and `total_bregman_divergence`. Convexity of \(F\) is
/// assumed, not checked; a non-convex generator can produce negative values
/// from those functions.
pub trait BregmanGenerator {
    /// Evaluate the potential \(F(x)\).
    fn f(&self, x: &[f64]) -> Result<f64>;

    /// Write \(\nabla F(x)\) into `out`.
    ///
    /// `out` is expected to have the same length as `x`; implementations
    /// should report a mismatch as an error (see `SquaredL2`).
    fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()>;
}
1154
1155/// Bregman divergence: the gap between a convex function and its tangent approximation.
1156///
1157/// \[B_F(p, q) = F(p) - F(q) - \langle p - q,\, \nabla F(q) \rangle\]
1158///
1159/// # Key properties
1160///
1161/// - **Non-negative**: \(B_F(p, q) \ge 0\) by convexity of \(F\), with equality iff
1162///   \(p = q\).
1163/// - **Not symmetric** in general: \(B_F(p, q) \ne B_F(q, p)\).
1164/// - **Generalized Pythagorean theorem**: for an affine subspace \(S\) and its
1165///   Bregman projection \(q^* = \arg\min_{q \in S} B_F(p, q)\), the three-point
1166///   identity holds: \(B_F(p, q) = B_F(p, q^*) + B_F(q^*, q)\) for all \(q \in S\).
1167///   This is the foundation of dually flat geometry (Amari).
1168/// - **Not an f-divergence**: Bregman divergences are **not** information monotone
1169///   in general. They live in a different branch of the divergence taxonomy.
1170/// - **Examples**: squared Euclidean (\(F = \tfrac{1}{2}\|x\|^2\)) gives
1171///   \(B_F(p,q) = \tfrac{1}{2}\|p - q\|^2\); negative entropy
1172///   (\(F = \sum x_i \ln x_i\)) gives the KL divergence.
1173///
1174/// # Examples
1175///
1176/// ```
1177/// # use logp::{bregman_divergence, SquaredL2};
1178/// // Squared-L2 generator: B_F(p, q) = 0.5 * ||p - q||^2.
1179/// let gen = SquaredL2;
1180/// let p = [1.0, 2.0, 3.0];
1181/// let q = [1.5, 1.5, 2.5];
1182/// let mut grad = [0.0; 3];
1183/// let b = bregman_divergence(&gen, &p, &q, &mut grad).unwrap();
1184/// let expected = 0.5 * ((0.5_f64).powi(2) + (0.5_f64).powi(2) + (0.5_f64).powi(2));
1185/// assert!((b - expected).abs() < 1e-12);
1186///
1187/// // B_F(p, p) = 0.
1188/// let mut grad2 = [0.0; 3];
1189/// assert!(bregman_divergence(&gen, &p, &p, &mut grad2).unwrap().abs() < 1e-15);
1190/// ```
1191pub fn bregman_divergence(
1192    gen: &impl BregmanGenerator,
1193    p: &[f64],
1194    q: &[f64],
1195    grad_q: &mut [f64],
1196) -> Result<f64> {
1197    ensure_nonempty(p)?;
1198    ensure_same_len(p, q)?;
1199    if grad_q.len() != q.len() {
1200        return Err(Error::LengthMismatch(grad_q.len(), q.len()));
1201    }
1202    gen.grad_into(q, grad_q)?;
1203    let fp = gen.f(p)?;
1204    let fq = gen.f(q)?;
1205    let mut inner = 0.0;
1206    for i in 0..p.len() {
1207        inner += (p[i] - q[i]) * grad_q[i];
1208    }
1209    Ok(fp - fq - inner)
1210}
1211
1212/// Total Bregman divergence as shown in Nielsen’s taxonomy diagram:
1213///
1214/// \(tB_F(p,q) = \frac{B_F(p,q)}{\sqrt{1 + \|\nabla F(q)\|^2}}\).
1215///
1216/// # Examples
1217///
1218/// ```
1219/// # use logp::{total_bregman_divergence, bregman_divergence, SquaredL2};
1220/// let gen = SquaredL2;
1221/// let p = [1.0, 2.0];
1222/// let q = [3.0, 4.0];
1223/// let mut grad = [0.0; 2];
1224///
1225/// let tb = total_bregman_divergence(&gen, &p, &q, &mut grad).unwrap();
1226///
1227/// // Total Bregman <= Bregman (normalization divides by >= 1).
1228/// let mut grad2 = [0.0; 2];
1229/// let b = bregman_divergence(&gen, &p, &q, &mut grad2).unwrap();
1230/// assert!(tb <= b + 1e-12);
1231/// ```
1232pub fn total_bregman_divergence(
1233    gen: &impl BregmanGenerator,
1234    p: &[f64],
1235    q: &[f64],
1236    grad_q: &mut [f64],
1237) -> Result<f64> {
1238    let b = bregman_divergence(gen, p, q, grad_q)?;
1239    let grad_norm_sq: f64 = grad_q.iter().map(|&x| x * x).sum();
1240    Ok(b / (1.0 + grad_norm_sq).sqrt())
1241}
1242
1243/// Squared Euclidean Bregman generator: \(F(x)=\tfrac12\|x\|_2^2\), \(\nabla F(x)=x\).
1244#[derive(Debug, Clone, Copy, Default)]
1245pub struct SquaredL2;
1246
1247impl BregmanGenerator for SquaredL2 {
1248    fn f(&self, x: &[f64]) -> Result<f64> {
1249        ensure_nonempty(x)?;
1250        Ok(0.5 * x.iter().map(|&v| v * v).sum::<f64>())
1251    }
1252
1253    fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()> {
1254        ensure_nonempty(x)?;
1255        if out.len() != x.len() {
1256            return Err(Error::LengthMismatch(out.len(), x.len()));
1257        }
1258        out.copy_from_slice(x);
1259        Ok(())
1260    }
1261}
1262
1263#[cfg(test)]
1264mod tests {
1265    use super::*;
1266    use proptest::prelude::*;
1267
1268    const TOL: f64 = 1e-9;
1269
1270    fn simplex_vec(len: usize) -> impl Strategy<Value = Vec<f64>> {
1271        // Draw nonnegative weights then normalize.
1272        prop::collection::vec(0.0f64..10.0, len).prop_map(|mut v| {
1273            let s: f64 = v.iter().sum();
1274            if s == 0.0 {
1275                v[0] = 1.0;
1276                return v;
1277            }
1278            for x in v.iter_mut() {
1279                *x /= s;
1280            }
1281            v
1282        })
1283    }
1284
1285    fn simplex_vec_pos(len: usize, eps: f64) -> impl Strategy<Value = Vec<f64>> {
1286        prop::collection::vec(0.0f64..10.0, len).prop_map(move |mut v| {
1287            // Add a small floor to avoid exact zeros (needed for KL-style domains).
1288            for x in v.iter_mut() {
1289                *x += eps;
1290            }
1291            let s: f64 = v.iter().sum();
1292            for x in v.iter_mut() {
1293                *x /= s;
1294            }
1295            v
1296        })
1297    }
1298
1299    fn random_partition(n: usize) -> impl Strategy<Value = Vec<usize>> {
1300        // Partition indices into k buckets (k chosen implicitly).
1301        // We generate a label in [0, n) for each index and later reindex to compact labels.
1302        prop::collection::vec(0usize..n, n).prop_map(|labels| {
1303            // Compress labels to 0..k-1 while preserving equality pattern.
1304            use std::collections::BTreeMap;
1305            let mut map = BTreeMap::<usize, usize>::new();
1306            let mut next = 0usize;
1307            labels
1308                .into_iter()
1309                .map(|l| {
1310                    *map.entry(l).or_insert_with(|| {
1311                        let id = next;
1312                        next += 1;
1313                        id
1314                    })
1315                })
1316                .collect::<Vec<_>>()
1317        })
1318    }
1319
1320    fn coarse_grain(p: &[f64], labels: &[usize]) -> Vec<f64> {
1321        let k = labels.iter().copied().max().unwrap_or(0) + 1;
1322        let mut out = vec![0.0; k];
1323        for (i, &lab) in labels.iter().enumerate() {
1324            out[lab] += p[i];
1325        }
1326        out
1327    }
1328
1329    fn l1(p: &[f64], q: &[f64]) -> f64 {
1330        p.iter().zip(q.iter()).map(|(&a, &b)| (a - b).abs()).sum()
1331    }
1332
1333    #[test]
1334    fn test_entropy_unchecked() {
1335        let p = [0.5, 0.5];
1336        let h = entropy_unchecked(&p);
1337        // -0.5*ln(0.5) - 0.5*ln(0.5) = -ln(0.5) = ln(2)
1338        assert!((h - LN_2).abs() < 1e-12);
1339    }
1340
1341    #[test]
1342    fn js_is_bounded_by_ln2() {
1343        let p = [1.0, 0.0];
1344        let q = [0.0, 1.0];
1345        let js = jensen_shannon_divergence(&p, &q, TOL).unwrap();
1346        assert!(js <= LN_2 + 1e-12);
1347        assert!(js >= 0.0);
1348    }
1349
1350    #[test]
1351    fn mutual_information_independent_is_zero() {
1352        // p(x,y) = p(x)p(y) ⇒ I(X;Y)=0
1353        let p_x = [0.5, 0.5];
1354        let p_y = [0.25, 0.75];
1355        // Row-major 2x2:
1356        // [0.125, 0.375,
1357        //  0.125, 0.375]
1358        let p_xy = [
1359            p_x[0] * p_y[0],
1360            p_x[0] * p_y[1],
1361            p_x[1] * p_y[0],
1362            p_x[1] * p_y[1],
1363        ];
1364        let mi = mutual_information(&p_xy, 2, 2, TOL).unwrap();
1365        assert!(mi.abs() < 1e-12, "mi={}", mi);
1366    }
1367
1368    #[test]
1369    fn mutual_information_perfect_correlation_is_ln2() {
1370        // X=Y uniform bit ⇒ I(X;Y)=ln 2 (nats)
1371        let p_xy = [0.5, 0.0, 0.0, 0.5]; // 2x2 diagonal
1372        let mi = mutual_information(&p_xy, 2, 2, TOL).unwrap();
1373        assert!((mi - LN_2).abs() < 1e-12, "mi={}", mi);
1374    }
1375
1376    #[test]
1377    fn bregman_squared_l2_matches_half_l2() {
1378        let gen = SquaredL2;
1379        let p = [1.0, 2.0, 3.0];
1380        let q = [1.5, 1.5, 2.5];
1381        let mut grad = [0.0; 3];
1382        let b = bregman_divergence(&gen, &p, &q, &mut grad).unwrap();
1383        let expected = 0.5
1384            * p.iter()
1385                .zip(q.iter())
1386                .map(|(&a, &b)| (a - b) * (a - b))
1387                .sum::<f64>();
1388        assert!((b - expected).abs() < 1e-12);
1389    }
1390
1391    // --- Entropy tests ---
1392
1393    #[test]
1394    fn entropy_nats_uniform_is_ln_n() {
1395        // Uniform distribution over n items: H = ln(n)
1396        for n in [2, 4, 8, 16] {
1397            let p: Vec<f64> = vec![1.0 / n as f64; n];
1398            let h = entropy_nats(&p, TOL).unwrap();
1399            let expected = (n as f64).ln();
1400            assert!((h - expected).abs() < 1e-12, "n={n}: h={h} expected={expected}");
1401        }
1402    }
1403
1404    #[test]
1405    fn entropy_nats_singleton_is_zero() {
1406        let h = entropy_nats(&[1.0], TOL).unwrap();
1407        assert!(h.abs() < 1e-15);
1408    }
1409
1410    #[test]
1411    fn entropy_bits_converts_correctly() {
1412        let p = [0.25, 0.75];
1413        let nats = entropy_nats(&p, TOL).unwrap();
1414        let bits = entropy_bits(&p, TOL).unwrap();
1415        assert!((bits - nats / LN_2).abs() < 1e-12);
1416    }
1417
1418    // --- Cross-entropy tests ---
1419
1420    #[test]
1421    fn cross_entropy_identity_h_pq_eq_h_p_plus_kl() {
1422        let p = [0.3, 0.7];
1423        let q = [0.5, 0.5];
1424        let h_pq = cross_entropy_nats(&p, &q, TOL).unwrap();
1425        let h_p = entropy_nats(&p, TOL).unwrap();
1426        let kl = kl_divergence(&p, &q, TOL).unwrap();
1427        assert!((h_pq - (h_p + kl)).abs() < 1e-12);
1428    }
1429
1430    #[test]
1431    fn cross_entropy_rejects_zero_q_with_positive_p() {
1432        let p = [0.5, 0.5];
1433        let q = [1.0, 0.0]; // q[1]=0 but p[1]=0.5
1434        assert!(cross_entropy_nats(&p, &q, TOL).is_err());
1435    }
1436
1437    // --- Validate / normalize tests ---
1438
1439    #[test]
1440    fn validate_simplex_accepts_valid() {
1441        assert!(validate_simplex(&[0.3, 0.7], TOL).is_ok());
1442        assert!(validate_simplex(&[1.0], TOL).is_ok());
1443    }
1444
1445    #[test]
1446    fn validate_simplex_rejects_bad_sum() {
1447        assert!(validate_simplex(&[0.3, 0.6], TOL).is_err()); // sum=0.9
1448    }
1449
1450    #[test]
1451    fn validate_simplex_rejects_negative() {
1452        assert!(validate_simplex(&[1.5, -0.5], TOL).is_err());
1453    }
1454
1455    #[test]
1456    fn validate_simplex_rejects_empty() {
1457        assert!(validate_simplex(&[], TOL).is_err());
1458    }
1459
1460    #[test]
1461    fn normalize_in_place_works() {
1462        let mut v = vec![2.0, 3.0];
1463        let s = normalize_in_place(&mut v).unwrap();
1464        assert!((s - 5.0).abs() < 1e-12);
1465        assert!((v[0] - 0.4).abs() < 1e-12);
1466        assert!((v[1] - 0.6).abs() < 1e-12);
1467    }
1468
1469    #[test]
1470    fn normalize_in_place_rejects_zero_sum() {
1471        let mut v = vec![0.0, 0.0];
1472        assert!(normalize_in_place(&mut v).is_err());
1473    }
1474
1475    // --- Hellinger / Bhattacharyya tests ---
1476
1477    #[test]
1478    fn hellinger_identical_is_zero() {
1479        let p = [0.25, 0.75];
1480        let h = hellinger(&p, &p, TOL).unwrap();
1481        assert!(h.abs() < 1e-12);
1482    }
1483
1484    #[test]
1485    fn hellinger_squared_in_unit_interval() {
1486        let p = [0.1, 0.9];
1487        let q = [0.9, 0.1];
1488        let h2 = hellinger_squared(&p, &q, TOL).unwrap();
1489        assert!(h2 >= -1e-12 && h2 <= 1.0 + 1e-12, "h2={h2}");
1490    }
1491
1492    #[test]
1493    fn bhattacharyya_coeff_identical_is_one() {
1494        let p = [0.3, 0.7];
1495        let bc = bhattacharyya_coeff(&p, &p, TOL).unwrap();
1496        assert!((bc - 1.0).abs() < 1e-12);
1497    }
1498
1499    #[test]
1500    fn bhattacharyya_distance_identical_is_zero() {
1501        let p = [0.5, 0.5];
1502        let d = bhattacharyya_distance(&p, &p, TOL).unwrap();
1503        assert!(d.abs() < 1e-12);
1504    }
1505
1506    // --- Renyi / Tsallis tests ---
1507
1508    #[test]
1509    fn renyi_alpha_half_on_simple_case() {
1510        let p = [0.5, 0.5];
1511        let q = [0.25, 0.75];
1512        // alpha=0.5 should be well-defined and non-negative
1513        let r = renyi_divergence(&p, &q, 0.5, TOL).unwrap();
1514        assert!(r >= -1e-12, "renyi={r}");
1515    }
1516
1517    #[test]
1518    fn renyi_identical_is_zero() {
1519        let p = [0.3, 0.7];
1520        let r = renyi_divergence(&p, &p, 2.0, TOL).unwrap();
1521        assert!(r.abs() < 1e-12, "renyi(p,p)={r}");
1522    }
1523
1524    #[test]
1525    fn tsallis_identical_is_zero() {
1526        let p = [0.4, 0.6];
1527        let t = tsallis_divergence(&p, &p, 2.0, TOL).unwrap();
1528        assert!(t.abs() < 1e-12, "tsallis(p,p)={t}");
1529    }
1530
1531    // --- Digamma test ---
1532
1533    #[test]
1534    fn digamma_at_one_is_neg_euler_mascheroni() {
1535        let psi1 = digamma(1.0);
1536        // digamma(1) = -gamma where gamma ~= 0.5772156649
1537        assert!((psi1 - (-0.5772156649)).abs() < 1e-8, "psi(1)={psi1}");
1538    }
1539
1540    #[test]
1541    fn digamma_recurrence_relation() {
1542        // digamma(x+1) = digamma(x) + 1/x
1543        for &x in &[1.0, 2.0, 3.5, 10.0] {
1544            let lhs = digamma(x + 1.0);
1545            let rhs = digamma(x) + 1.0 / x;
1546            assert!((lhs - rhs).abs() < 1e-8, "recurrence at x={x}: {lhs} vs {rhs}");
1547        }
1548    }
1549
1550    #[test]
1551    fn pmi_independent_is_zero() {
1552        // PMI(x,y) = log(p(x,y) / (p(x)*p(y))). If independent: p(x,y) = p(x)*p(y)
1553        let pmi_val = pmi(0.06, 0.3, 0.2); // 0.3 * 0.2 = 0.06
1554        assert!(pmi_val.abs() < 1e-10, "PMI of independent events should be 0: {pmi_val}");
1555    }
1556
1557    #[test]
1558    fn pmi_positive_for_correlated() {
1559        // If p(x,y) > p(x)*p(y), events are positively correlated
1560        let pmi_val = pmi(0.4, 0.5, 0.5); // 0.4 > 0.5*0.5 = 0.25
1561        assert!(pmi_val > 0.0, "correlated events should have positive PMI: {pmi_val}");
1562    }
1563
1564    #[test]
1565    fn renyi_approaches_kl_as_alpha_to_one() {
1566        let p = [0.3, 0.7];
1567        let q = [0.5, 0.5];
1568        let tol = 1e-9;
1569        let kl = kl_divergence(&p, &q, tol).unwrap();
1570        // Renyi(alpha) -> KL as alpha -> 1
1571        let r099 = renyi_divergence(&p, &q, 0.99, tol).unwrap();
1572        let r0999 = renyi_divergence(&p, &q, 0.999, tol).unwrap();
1573        assert!((r099 - kl).abs() < 0.01, "Renyi(0.99)={r099}, KL={kl}");
1574        assert!((r0999 - kl).abs() < 0.001, "Renyi(0.999)={r0999}, KL={kl}");
1575    }
1576
1577    #[test]
1578    fn amari_alpha_neg1_is_kl_forward() {
1579        // Amari alpha=-1 returns KL(p||q) per the implementation
1580        let p = [0.3, 0.7];
1581        let q = [0.5, 0.5];
1582        let tol = 1e-9;
1583        let kl_pq = kl_divergence(&p, &q, tol).unwrap();
1584        let amari = amari_alpha_divergence(&p, &q, -1.0, tol).unwrap();
1585        assert!((amari - kl_pq).abs() < 1e-6, "Amari(-1)={amari}, KL(p||q)={kl_pq}");
1586    }
1587
1588    #[test]
1589    fn amari_alpha_pos1_is_kl_reverse() {
1590        // Amari alpha=+1 returns KL(q||p) per the implementation
1591        let p = [0.3, 0.7];
1592        let q = [0.5, 0.5];
1593        let tol = 1e-9;
1594        let kl_qp = kl_divergence(&q, &p, tol).unwrap();
1595        let amari = amari_alpha_divergence(&p, &q, 1.0, tol).unwrap();
1596        assert!((amari - kl_qp).abs() < 1e-6, "Amari(1)={amari}, KL(q||p)={kl_qp}");
1597    }
1598
1599    #[test]
1600    fn csiszar_with_kl_generator_matches_kl() {
1601        // f(t) = t*ln(t) gives KL divergence
1602        let p = [0.3, 0.7];
1603        let q = [0.5, 0.5];
1604        let tol = 1e-9;
1605        let kl = kl_divergence(&p, &q, tol).unwrap();
1606        let cs = csiszar_f_divergence(&p, &q, |t| t * t.ln(), tol).unwrap();
1607        assert!((cs - kl).abs() < 1e-6, "Csiszar(t*ln(t))={cs}, KL={kl}");
1608    }
1609
1610    #[test]
1611    fn mutual_information_deterministic_equals_entropy() {
1612        // If Y = f(X), MI(X;Y) = H(X)
1613        // Joint: p(x=0,y=0)=0.3, p(x=1,y=1)=0.7
1614        let p_xy = [0.3, 0.0, 0.0, 0.7]; // 2x2 joint
1615        let mi = mutual_information(&p_xy, 2, 2, 1e-9).unwrap();
1616        let h_x = entropy_nats(&[0.3, 0.7], 1e-9).unwrap();
1617        assert!((mi - h_x).abs() < 1e-6, "MI={mi}, H(X)={h_x}");
1618    }
1619
1620    proptest! {
1621        #[test]
1622        fn kl_is_nonnegative(p in simplex_vec_pos(8, 1e-6), q in simplex_vec_pos(8, 1e-6)) {
1623            let d = kl_divergence(&p, &q, 1e-6).unwrap();
1624            prop_assert!(d >= -1e-12);
1625        }
1626
1627        #[test]
1628        fn js_is_bounded(p in simplex_vec(16), q in simplex_vec(16)) {
1629            let js = jensen_shannon_divergence(&p, &q, 1e-6).unwrap();
1630            prop_assert!(js >= -1e-12);
1631            prop_assert!(js <= LN_2 + 1e-9);
1632        }
1633
1634        #[test]
1635        fn prop_kl_gaussians_is_nonnegative(
1636            mu1 in prop::collection::vec(-10.0f64..10.0, 1..16),
1637            std1 in prop::collection::vec(0.1f64..5.0, 1..16),
1638            mu2 in prop::collection::vec(-10.0f64..10.0, 1..16),
1639            std2 in prop::collection::vec(0.1f64..5.0, 1..16),
1640        ) {
1641            let n = mu1.len().min(std1.len()).min(mu2.len()).min(std2.len());
1642            let d = kl_divergence_gaussians(&mu1[..n], &std1[..n], &mu2[..n], &std2[..n]).unwrap();
1643            // KL divergence is always non-negative.
1644            prop_assert!(d >= -1e-12);
1645        }
1646
1647        #[test]
1648        fn prop_kl_gaussians_is_zero_for_identical(
1649            mu in prop::collection::vec(-10.0f64..10.0, 1..16),
1650            std in prop::collection::vec(0.1f64..5.0, 1..16),
1651        ) {
1652            let n = mu.len().min(std.len());
1653            let d = kl_divergence_gaussians(&mu[..n], &std[..n], &mu[..n], &std[..n]).unwrap();
1654            prop_assert!(d.abs() < 1e-12);
1655        }
1656
1657        #[test]
1658        fn f_divergence_monotone_under_coarse_graining(
1659            p in simplex_vec_pos(12, 1e-6),
1660            q in simplex_vec_pos(12, 1e-6),
1661            labels in random_partition(12),
1662        ) {
1663            // Use KL as an f-divergence instance: f(t)=t ln t.
1664            // D_KL(p||q) = Σ q_i f(p_i/q_i).
1665            let f = |t: f64| if t == 0.0 { 0.0 } else { t * t.ln() };
1666            let d_f = csiszar_f_divergence(&p, &q, f, 1e-6).unwrap();
1667
1668            let pc = coarse_grain(&p, &labels);
1669            let qc = coarse_grain(&q, &labels);
1670            let d_fc = csiszar_f_divergence(&pc, &qc, f, 1e-6).unwrap();
1671
1672            // Coarse graining should not increase.
1673            prop_assert!(d_fc <= d_f + 1e-9);
1674        }
1675    }
1676
1677    // Heavier “theorem-ish” checks: keep case count modest so `cargo test` stays fast.
1678    proptest! {
1679        #![proptest_config(ProptestConfig { cases: 64, .. ProptestConfig::default() })]
1680
1681        #[test]
1682        fn pinsker_kl_lower_bounds_l1_squared(
1683            p in simplex_vec_pos(16, 1e-6),
1684            q in simplex_vec_pos(16, 1e-6),
1685        ) {
1686            // Pinsker: TV(p,q)^2 <= (1/2) KL(p||q)
1687            // where TV = (1/2)||p-q||_1. Rearranged: KL(p||q) >= 0.5 * ||p-q||_1^2.
1688            let kl = kl_divergence(&p, &q, 1e-6).unwrap();
1689            let d1 = l1(&p, &q);
1690            prop_assert!(kl + 1e-9 >= 0.5 * d1 * d1, "kl={kl} l1={d1}");
1691        }
1692
1693        #[test]
1694        fn sqrt_js_satisfies_triangle_inequality(
1695            p in simplex_vec(12),
1696            q in simplex_vec(12),
1697            r in simplex_vec(12),
1698        ) {
1699            // Known fact: sqrt(JS) is a metric on the simplex.
1700            let js_pq = jensen_shannon_divergence(&p, &q, 1e-6).unwrap().max(0.0).sqrt();
1701            let js_qr = jensen_shannon_divergence(&q, &r, 1e-6).unwrap().max(0.0).sqrt();
1702            let js_pr = jensen_shannon_divergence(&p, &r, 1e-6).unwrap().max(0.0).sqrt();
1703            prop_assert!(js_pr <= js_pq + js_qr + 1e-7, "js_pr={js_pr} js_pq+js_qr={}", js_pq+js_qr);
1704        }
1705
1706        #[test]
1707        fn mutual_information_equals_kl_to_product(
1708            // Ensure strictly positive so KL domains are satisfied.
1709            p_xy in simplex_vec_pos(16, 1e-6),
1710            nx in 2usize..=4,
1711            ny in 2usize..=4,
1712        ) {
1713            // We need p_xy to have length nx*ny; we will truncate/renormalize a fixed-length draw.
1714            let n = nx * ny;
1715            let mut joint = p_xy;
1716            joint.truncate(n);
1717            // Renormalize after truncation.
1718            let _ = normalize_in_place(&mut joint).unwrap();
1719
1720            // Compute MI via the dedicated function.
1721            let mi = mutual_information(&joint, nx, ny, 1e-6).unwrap();
1722
1723            // Compute product of marginals and KL(joint || product).
1724            let mut p_x = vec![0.0; nx];
1725            let mut p_y = vec![0.0; ny];
1726            for i in 0..nx {
1727                for j in 0..ny {
1728                    let p = joint[i * ny + j];
1729                    p_x[i] += p;
1730                    p_y[j] += p;
1731                }
1732            }
1733            let mut prod = vec![0.0; n];
1734            for i in 0..nx {
1735                for j in 0..ny {
1736                    prod[i * ny + j] = p_x[i] * p_y[j];
1737                }
1738            }
1739            let kl = kl_divergence(&joint, &prod, 1e-6).unwrap();
1740
1741            prop_assert!((mi - kl).abs() < 1e-9, "mi={mi} kl={kl}");
1742        }
1743
1744        #[test]
1745        fn hellinger_satisfies_triangle_inequality(
1746            p in simplex_vec(8),
1747            q in simplex_vec(8),
1748            r in simplex_vec(8),
1749        ) {
1750            let h_pq = hellinger(&p, &q, 1e-6).unwrap();
1751            let h_qr = hellinger(&q, &r, 1e-6).unwrap();
1752            let h_pr = hellinger(&p, &r, 1e-6).unwrap();
1753            prop_assert!(h_pr <= h_pq + h_qr + 1e-7, "h_pr={h_pr} h_pq+h_qr={}", h_pq + h_qr);
1754        }
1755    }
1756
1757    // --- total_bregman_divergence ---
1758
1759    #[test]
1760    fn total_bregman_le_bregman() {
1761        // tB_F(p, q) <= B_F(p, q) because the denominator sqrt(1 + ||grad||^2) >= 1.
1762        let gen = SquaredL2;
1763        let p = [1.0, 2.0, 3.0];
1764        let q = [4.0, 5.0, 6.0];
1765        let mut grad1 = [0.0; 3];
1766        let mut grad2 = [0.0; 3];
1767        let b = bregman_divergence(&gen, &p, &q, &mut grad1).unwrap();
1768        let tb = total_bregman_divergence(&gen, &p, &q, &mut grad2).unwrap();
1769        assert!(tb <= b + 1e-12, "total_bregman={tb} > bregman={b}");
1770        assert!(tb >= 0.0);
1771    }
1772
1773    #[test]
1774    fn total_bregman_is_zero_for_identical() {
1775        let gen = SquaredL2;
1776        let p = [1.0, 2.0];
1777        let mut grad = [0.0; 2];
1778        let tb = total_bregman_divergence(&gen, &p, &p, &mut grad).unwrap();
1779        assert!(tb.abs() < 1e-15);
1780    }
1781
1782    // --- rho_alpha ---
1783
1784    #[test]
1785    fn rho_alpha_self_is_one() {
1786        let p = [0.1, 0.2, 0.3, 0.4];
1787        for alpha in [0.0, 0.25, 0.5, 0.75, 1.0, 2.0, -1.0] {
1788            let r = rho_alpha(&p, &p, alpha, TOL).unwrap();
1789            assert!((r - 1.0).abs() < 1e-10, "rho_alpha(p,p,{alpha})={r}");
1790        }
1791    }
1792
1793    // --- digamma negative domain ---
1794
1795    #[test]
1796    fn digamma_nonpositive_is_nan() {
1797        assert!(digamma(0.0).is_nan());
1798        assert!(digamma(-1.0).is_nan());
1799        assert!(digamma(-100.0).is_nan());
1800    }
1801
1802    // --- pmi edge cases ---
1803
1804    #[test]
1805    fn pmi_zero_joint_returns_zero() {
1806        assert_eq!(pmi(0.0, 0.5, 0.5), 0.0);
1807    }
1808
1809    #[test]
1810    fn pmi_zero_marginal_returns_zero() {
1811        // When px or py is zero, return 0 by convention.
1812        assert_eq!(pmi(0.1, 0.0, 0.5), 0.0);
1813        assert_eq!(pmi(0.1, 0.5, 0.0), 0.0);
1814    }
1815
1816    #[test]
1817    fn pmi_all_zero_returns_zero() {
1818        assert_eq!(pmi(0.0, 0.0, 0.0), 0.0);
1819    }
1820}