//! # logp
//!
//! Information theory primitives: entropies and divergences.
//!
//! ## Scope
//!
//! Scalar information measures that appear across clustering, ranking,
//! evaluation, and geometry:
//!
//! - Shannon entropy and cross-entropy
//! - KL / Jensen–Shannon divergences
//! - Csiszár \(f\)-divergences (a.k.a. *information monotone* divergences)
//! - Bhattacharyya coefficient, Rényi/Tsallis families
//! - Bregman divergences (convex-analytic, not generally monotone)
//!
//! ## Distances vs divergences (terminology that prevents bugs)
//!
//! A **divergence** \(D(p:q)\) is usually required to satisfy:
//!
//! - \(D(p:q) \ge 0\)
//! - \(D(p:p) = 0\)
//!
//! but it is typically **not** symmetric and **not** a metric (no triangle inequality).
//! Many failures in downstream code are caused by treating a divergence as a distance metric.
//!
//! ## Key invariants (what tests should enforce)
//!
//! - **Jensen–Shannon** is bounded on the simplex:
//!   \(0 \le JS(p,q) \le \ln 2\) (nats).
//! - **Csiszár \(f\)-divergences** are monotone under coarse-graining (Markov kernels):
//!   merging bins cannot increase the divergence.
//!
//! ## Further reading
//!
//! - Frank Nielsen, “Divergences” portal (taxonomy diagrams + references):
//!   <https://franknielsen.github.io/Divergence/index.html>
//! - `nocotan/awesome-information-geometry` (curated reading list):
//!   <https://github.com/nocotan/awesome-information-geometry>
//! - Csiszár (1967): \(f\)-divergences and information monotonicity.
//! - Amari & Nagaoka (2000): *Methods of Information Geometry*.
//!
//! ## Taxonomy of Divergences (Nielsen)
//!
//! | Family | Generator | Key Property |
//! |---|---|---|
//! | **f-divergences** | Convex \(f(t)\) with \(f(1)=0\) | Monotone under Markov morphisms (coarse-graining) |
//! | **Bregman** | Convex \(F(x)\) | Dually flat geometry; generalized Pythagorean theorem |
//! | **Jensen-Shannon** | \(f\)-div + metric | Symmetric, bounded \([0, \ln 2]\), \(\sqrt{JS}\) is a metric |
//! | **Alpha** | \(\rho_\alpha = \int p^\alpha q^{1-\alpha}\) | Encodes Rényi, Tsallis, Bhattacharyya, Hellinger |
//!
//! ## References
//!
//! - Shannon (1948). "A Mathematical Theory of Communication"
//! - Cover & Thomas (2006). "Elements of Information Theory"

#![forbid(unsafe_code)]
#![warn(missing_docs)]

use thiserror::Error;

mod ksg;
pub use ksg::{mutual_information_ksg, KsgVariant};

use core::f64::consts::LN_2;

66/// KL Divergence between two diagonal Multivariate Gaussians.
67///
68/// Assumes **diagonal** covariance (independent dimensions): each dimension is
69/// parameterized by a mean and a standard deviation, with no cross-covariance terms.
70/// For full-covariance Gaussians, the formula involves log-determinant ratios and
71/// matrix inverses; this function does not handle that case.
72///
73/// Returns 0.5 * Σ [ (std1/std2)^2 + (mu2-mu1)^2 / std2^2 - 1 + 2*ln(std2/std1) ]
74///
75/// # Examples
76///
77/// ```
78/// # use logp::kl_divergence_gaussians;
79/// // KL(N || N) = 0 (identical Gaussians).
80/// let mu = [0.0, 1.0];
81/// let std = [1.0, 2.0];
82/// let kl = kl_divergence_gaussians(&mu, &std, &mu, &std).unwrap();
83/// assert!(kl.abs() < 1e-12);
84///
85/// // KL is non-negative for distinct Gaussians.
86/// let kl = kl_divergence_gaussians(&[0.0], &[1.0], &[1.0], &[1.0]).unwrap();
87/// assert!(kl >= 0.0);
88/// ```
89pub fn kl_divergence_gaussians(
90    mu1: &[f64],
91    std1: &[f64],
92    mu2: &[f64],
93    std2: &[f64],
94) -> Result<f64> {
95    ensure_same_len(mu1, std1)?;
96    ensure_same_len(mu1, mu2)?;
97    ensure_same_len(mu1, std2)?;
98
99    let mut kl = 0.0;
100    for (((&m1, &s1), &m2), &s2) in mu1.iter().zip(std1).zip(mu2).zip(std2) {
101        if s1 <= 0.0 || s2 <= 0.0 {
102            return Err(Error::Domain("standard deviation must be positive"));
103        }
104        let v1 = s1 * s1;
105        let v2 = s2 * s2;
106        kl += (v1 / v2) + (m2 - m1).powi(2) / v2 - 1.0 + 2.0 * (s2.ln() - s1.ln());
107    }
108    Ok(0.5 * kl)
109}
110
/// Errors for information-measure computations.
///
/// Each variant names the input precondition that was violated; every fallible
/// function in this crate reports failures through the [`Result`] alias below.
#[derive(Debug, Error)]
pub enum Error {
    /// Two input slices have different lengths (e.g., `p` and `q` in a divergence).
    #[error("length mismatch: {0} vs {1}")]
    LengthMismatch(usize, usize),

    /// An input slice is empty where at least one element is required.
    #[error("empty input")]
    Empty,

    /// An entry is NaN or infinite where a finite value is required.
    #[error("non-finite entry at index {idx}: {value}")]
    NonFinite {
        /// Position of the offending entry.
        idx: usize,
        /// The non-finite value found.
        value: f64,
    },

    /// An entry is negative where a nonnegative value is required (probability distributions).
    #[error("negative entry at index {idx}: {value}")]
    Negative {
        /// Position of the offending entry.
        idx: usize,
        /// The negative value found.
        value: f64,
    },

    /// The distribution does not sum to 1 within the specified tolerance.
    #[error("not normalized (expected sum≈1): sum={sum}")]
    NotNormalized {
        /// Actual sum of the distribution.
        sum: f64,
    },

    /// The alpha parameter is outside its valid domain (e.g., negative or non-finite).
    #[error("invalid alpha: {alpha} (must be finite and not equal to {forbidden})")]
    InvalidAlpha {
        /// The invalid alpha value provided.
        alpha: f64,
        /// The value alpha must not equal (e.g., 0.0).
        forbidden: f64,
    },

    /// Catch-all for domain violations: zero standard deviation, q_i=0 while p_i>0,
    /// insufficient sample size for KSG, and similar precondition failures.
    #[error("domain error: {0}")]
    Domain(&'static str),
}

/// Convenience alias for `Result<T, logp::Error>`.
///
/// Shadows `core::result::Result` inside this crate's signatures.
pub type Result<T> = core::result::Result<T, Error>;
165fn ensure_nonempty(x: &[f64]) -> Result<()> {
166    if x.is_empty() {
167        return Err(Error::Empty);
168    }
169    Ok(())
170}
171
172fn ensure_same_len(a: &[f64], b: &[f64]) -> Result<()> {
173    if a.len() != b.len() {
174        return Err(Error::LengthMismatch(a.len(), b.len()));
175    }
176    Ok(())
177}
178
179fn ensure_nonnegative(x: &[f64]) -> Result<()> {
180    for (i, &v) in x.iter().enumerate() {
181        if !v.is_finite() {
182            return Err(Error::NonFinite { idx: i, value: v });
183        }
184        if v < 0.0 {
185            return Err(Error::Negative { idx: i, value: v });
186        }
187    }
188    Ok(())
189}
190
/// Plain left-to-right sum of a slice (no compensation/Kahan tricks).
fn sum(x: &[f64]) -> f64 {
    x.iter().fold(0.0, |acc, &v| acc + v)
}
195/// Validate that `p` is a probability distribution on the simplex (within `tol`).
196///
197/// # Examples
198///
199/// ```
200/// # use logp::validate_simplex;
201/// // Valid simplex.
202/// assert!(validate_simplex(&[0.3, 0.7], 1e-9).is_ok());
203/// assert!(validate_simplex(&[1.0], 1e-9).is_ok());
204///
205/// // Rejects bad sum.
206/// assert!(validate_simplex(&[0.3, 0.6], 1e-9).is_err());
207///
208/// // Rejects negative entries.
209/// assert!(validate_simplex(&[1.5, -0.5], 1e-9).is_err());
210///
211/// // Rejects empty input.
212/// assert!(validate_simplex(&[], 1e-9).is_err());
213/// ```
214pub fn validate_simplex(p: &[f64], tol: f64) -> Result<()> {
215    ensure_nonempty(p)?;
216    ensure_nonnegative(p)?;
217    let s = sum(p);
218    if (s - 1.0).abs() > tol {
219        return Err(Error::NotNormalized { sum: s });
220    }
221    Ok(())
222}
223
224/// Normalize a nonnegative vector in-place to sum to 1.
225///
226/// Returns the original sum.
227///
228/// # Examples
229///
230/// ```
231/// # use logp::normalize_in_place;
232/// let mut v = vec![2.0, 3.0, 5.0];
233/// let original_sum = normalize_in_place(&mut v).unwrap();
234/// assert!((original_sum - 10.0).abs() < 1e-12);
235/// assert!((v[0] - 0.2).abs() < 1e-12);
236/// assert!((v[1] - 0.3).abs() < 1e-12);
237/// assert!((v[2] - 0.5).abs() < 1e-12);
238///
239/// // Rejects all-zero input.
240/// assert!(normalize_in_place(&mut vec![0.0, 0.0]).is_err());
241/// ```
242pub fn normalize_in_place(p: &mut [f64]) -> Result<f64> {
243    ensure_nonempty(p)?;
244    ensure_nonnegative(p)?;
245    let s = sum(p);
246    if s <= 0.0 {
247        return Err(Error::Domain("cannot normalize: sum <= 0"));
248    }
249    for v in p.iter_mut() {
250        *v /= s;
251    }
252    Ok(s)
253}
254
255/// Shannon entropy in nats: the expected surprise under distribution \(p\).
256///
257/// \[H(p) = -\sum_i p_i \ln p_i\]
258///
259/// # Key properties
260///
261/// - **Non-negative**: \(H(p) \ge 0\), with equality iff \(p\) is a delta (point mass).
262/// - **Maximized by uniform**: among distributions on \(n\) outcomes,
263///   \(H(p) \le \ln n\), with equality iff \(p_i = 1/n\) for all \(i\).
264/// - **Concavity**: \(H\) is a concave function of \(p\) on the simplex.
265///   Mixing distributions never decreases entropy.
266/// - **Units**: result is in nats (base \(e\)); divide by \(\ln 2\) for bits.
267/// - **Joint distributions**: for joint entropy \(H(X,Y)\), flatten the joint
268///   distribution to a 1D slice and pass it directly. Shannon entropy of the
269///   flattened joint is mathematically identical to joint entropy.
270///
271/// # Domain
272///
273/// Requires `p` to be a valid simplex distribution (within `tol`).
274///
275/// # Examples
276///
277/// ```
278/// # use logp::entropy_nats;
279/// // Uniform distribution over 4 outcomes: H = ln(4).
280/// let p = [0.25, 0.25, 0.25, 0.25];
281/// let h = entropy_nats(&p, 1e-9).unwrap();
282/// assert!((h - 4.0_f64.ln()).abs() < 1e-12);
283///
284/// // Delta (point mass): H = 0.
285/// let delta = [1.0, 0.0, 0.0];
286/// assert!(entropy_nats(&delta, 1e-9).unwrap().abs() < 1e-15);
287/// ```
288pub fn entropy_nats(p: &[f64], tol: f64) -> Result<f64> {
289    validate_simplex(p, tol)?;
290    let mut h = 0.0;
291    for &pi in p {
292        if pi > 0.0 {
293            h -= pi * pi.ln();
294        }
295    }
296    Ok(h)
297}
298
299/// Shannon entropy in bits.
300///
301/// # Examples
302///
303/// ```
304/// # use logp::{entropy_bits, entropy_nats};
305/// // Fair coin: H = 1 bit.
306/// let p = [0.5, 0.5];
307/// let bits = entropy_bits(&p, 1e-9).unwrap();
308/// assert!((bits - 1.0).abs() < 1e-12);
309///
310/// // Consistent with nats / ln(2).
311/// let nats = entropy_nats(&p, 1e-9).unwrap();
312/// assert!((bits - nats / core::f64::consts::LN_2).abs() < 1e-12);
313/// ```
314pub fn entropy_bits(p: &[f64], tol: f64) -> Result<f64> {
315    Ok(entropy_nats(p, tol)? / LN_2)
316}
317
/// Fast Shannon entropy calculation without simplex validation.
///
/// Used in performance-critical loops like Sinkhorn iteration for Optimal Transport.
/// Zero entries are skipped, so the \(0 \ln 0 = 0\) convention holds.
///
/// # Invariant
/// Assumes `p` is non-negative and normalized; garbage in, garbage out.
///
/// # Examples
///
/// ```
/// # use logp::entropy_unchecked;
/// // Fair coin: H = ln(2).
/// let h = entropy_unchecked(&[0.5, 0.5]);
/// assert!((h - core::f64::consts::LN_2).abs() < 1e-12);
///
/// // Agrees with the checked version on valid input.
/// let p = [0.3, 0.7];
/// let h_checked = logp::entropy_nats(&p, 1e-9).unwrap();
/// assert!((entropy_unchecked(&p) - h_checked).abs() < 1e-15);
/// ```
#[inline]
pub fn entropy_unchecked(p: &[f64]) -> f64 {
    p.iter()
        .filter(|&&pi| pi > 0.0)
        .map(|&pi| -(pi * pi.ln()))
        .sum()
}
349/// Renyi entropy in nats: a one-parameter generalization of Shannon entropy.
350///
351/// \[H_\alpha(p) = \frac{1}{1-\alpha} \ln \sum_i p_i^\alpha, \quad \alpha > 0,\;\alpha \ne 1\]
352///
353/// # Key properties
354///
355/// - **Limit to Shannon**: \(\lim_{\alpha \to 1} H_\alpha(p) = H(p)\) (Shannon entropy).
356/// - **Alpha = 0**: \(H_0(p) = \ln |\text{supp}(p)|\), the log of the support size (Hartley entropy).
357/// - **Alpha = 2**: \(H_2(p) = -\ln \sum_i p_i^2\), the collision entropy (negative log of
358///   the probability that two independent draws match).
359/// - **Alpha = infinity**: \(H_\infty(p) = -\ln \max_i p_i\), the min-entropy (worst-case surprise).
360/// - **Monotone in alpha**: \(H_\alpha(p)\) is non-increasing in \(\alpha\).
361/// - **Non-negative**: \(H_\alpha(p) \ge 0\).
362///
363/// # Examples
364///
365/// ```
366/// # use logp::renyi_entropy;
367/// // Uniform over 4: H_alpha = ln(4) for all alpha.
368/// let p = [0.25, 0.25, 0.25, 0.25];
369/// let h = renyi_entropy(&p, 2.0, 1e-9).unwrap();
370/// assert!((h - 4.0_f64.ln()).abs() < 1e-12);
371///
372/// // Collision entropy: H_2 = -ln(sum(p_i^2)).
373/// let q = [0.3, 0.7];
374/// let h2 = renyi_entropy(&q, 2.0, 1e-9).unwrap();
375/// let expected = -(0.3_f64.powi(2) + 0.7_f64.powi(2)).ln();
376/// assert!((h2 - expected).abs() < 1e-12);
377/// ```
378pub fn renyi_entropy(p: &[f64], alpha: f64, tol: f64) -> Result<f64> {
379    validate_simplex(p, tol)?;
380    if !alpha.is_finite() || alpha < 0.0 {
381        return Err(Error::InvalidAlpha {
382            alpha,
383            forbidden: f64::NAN,
384        });
385    }
386    if (alpha - 1.0).abs() < 1e-12 {
387        return entropy_nats(p, tol);
388    }
389    let mut s = 0.0;
390    for &pi in p {
391        if pi > 0.0 {
392            s += pi.powf(alpha);
393        }
394    }
395    if s <= 0.0 {
396        return Err(Error::Domain("renyi_entropy: sum of p_i^alpha <= 0"));
397    }
398    Ok(s.ln() / (1.0 - alpha))
399}
400
401/// Tsallis entropy: a non-extensive generalization of Shannon entropy from
402/// statistical mechanics.
403///
404/// \[S_\alpha(p) = \frac{1}{\alpha - 1}\left(1 - \sum_i p_i^\alpha\right), \quad \alpha \ne 1\]
405///
406/// # Key properties
407///
408/// - **Limit to Shannon**: \(\lim_{\alpha \to 1} S_\alpha(p) = H(p)\).
409/// - **Non-extensive**: for independent systems \(A, B\):
410///   \(S_\alpha(A \otimes B) = S_\alpha(A) + S_\alpha(B) + (1-\alpha)\,S_\alpha(A)\,S_\alpha(B)\).
411/// - **Non-negative**: \(S_\alpha(p) \ge 0\).
412/// - **Connection to Renyi**: related via \(S_\alpha = \frac{e^{(1-\alpha)H_\alpha} - 1}{\alpha - 1}\).
413///
414/// # Examples
415///
416/// ```
417/// # use logp::tsallis_entropy;
418/// // Uniform over 4: S_2 = 1 - 1/4 = 0.75.
419/// let p = [0.25, 0.25, 0.25, 0.25];
420/// let s = tsallis_entropy(&p, 2.0, 1e-9).unwrap();
421/// assert!((s - 0.75).abs() < 1e-12);
422///
423/// // Delta: S_alpha = 0 for any alpha.
424/// let delta = [1.0, 0.0, 0.0];
425/// assert!(tsallis_entropy(&delta, 2.0, 1e-9).unwrap().abs() < 1e-12);
426/// ```
427pub fn tsallis_entropy(p: &[f64], alpha: f64, tol: f64) -> Result<f64> {
428    validate_simplex(p, tol)?;
429    if !alpha.is_finite() || alpha < 0.0 {
430        return Err(Error::InvalidAlpha {
431            alpha,
432            forbidden: f64::NAN,
433        });
434    }
435    if (alpha - 1.0).abs() < 1e-12 {
436        return entropy_nats(p, tol);
437    }
438    let mut s = 0.0;
439    for &pi in p {
440        if pi > 0.0 {
441            s += pi.powf(alpha);
442        }
443    }
444    Ok((1.0 - s) / (alpha - 1.0))
445}
446
447/// Cross-entropy in nats: the expected code length when using model \(q\) to encode
448/// data drawn from true distribution \(p\).
449///
450/// \[H(p, q) = -\sum_i p_i \ln q_i\]
451///
452/// # Key properties
453///
454/// - **Decomposition identity**: cross-entropy splits into entropy plus KL divergence:
455///   \(H(p, q) = H(p) + D_{KL}(p \| q)\).
456///   This means \(H(p, q) \ge H(p)\) with equality iff \(p = q\).
457/// - **Loss function**: minimizing \(H(p, q)\) over \(q\) is equivalent to minimizing
458///   \(D_{KL}(p \| q)\), which is why cross-entropy is the standard classification loss.
459/// - **Not symmetric**: \(H(p, q) \ne H(q, p)\) in general.
460///
461/// # Domain
462///
463/// `p` must be on the simplex; `q` must be nonnegative and normalized; and
464/// whenever `p_i > 0`, we require `q_i > 0` (otherwise cross-entropy is infinite).
465///
466/// # Examples
467///
468/// ```
469/// # use logp::{cross_entropy_nats, entropy_nats, kl_divergence};
470/// let p = [0.3, 0.7];
471/// let q = [0.5, 0.5];
472/// let h_pq = cross_entropy_nats(&p, &q, 1e-9).unwrap();
473///
474/// // Decomposition: H(p,q) = H(p) + KL(p||q).
475/// let h_p = entropy_nats(&p, 1e-9).unwrap();
476/// let kl = kl_divergence(&p, &q, 1e-9).unwrap();
477/// assert!((h_pq - (h_p + kl)).abs() < 1e-12);
478///
479/// // Self-cross-entropy equals entropy: H(p,p) = H(p).
480/// let h_pp = cross_entropy_nats(&p, &p, 1e-9).unwrap();
481/// assert!((h_pp - h_p).abs() < 1e-12);
482/// ```
483pub fn cross_entropy_nats(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
484    validate_simplex(p, tol)?;
485    validate_simplex(q, tol)?;
486    let mut h = 0.0;
487    for (&pi, &qi) in p.iter().zip(q.iter()) {
488        if pi == 0.0 {
489            continue;
490        }
491        if qi <= 0.0 {
492            return Err(Error::Domain("cross-entropy undefined: q_i=0 while p_i>0"));
493        }
494        h -= pi * qi.ln();
495    }
496    Ok(h)
497}
498
499/// Kullback--Leibler divergence in nats: the information lost when \(q\) is used to
500/// approximate \(p\).
501///
502/// \[D_{KL}(p \| q) = \sum_i p_i \ln \frac{p_i}{q_i}\]
503///
504/// # Key properties
505///
506/// - **Gibbs' inequality**: \(D_{KL}(p \| q) \ge 0\), with equality iff \(p = q\).
507///   This follows directly from Jensen's inequality applied to \(-\ln\).
508/// - **Not symmetric**: \(D_{KL}(p \| q) \ne D_{KL}(q \| p)\) in general;
509///   this is why KL is a divergence, not a distance.
510/// - **Not bounded above**: KL can be arbitrarily large when supports differ.
511/// - **Connection to MLE**: minimizing \(D_{KL}(p_{data} \| q_\theta)\) over \(\theta\)
512///   is equivalent to maximum likelihood estimation.
513/// - **Additive for independent distributions**: if \(p = p_1 \otimes p_2\) and
514///   \(q = q_1 \otimes q_2\), then
515///   \(D_{KL}(p \| q) = D_{KL}(p_1 \| q_1) + D_{KL}(p_2 \| q_2)\).
516///
517/// # Domain
518///
519/// `p` and `q` must be valid simplex distributions; and whenever `p_i > 0`,
520/// we require `q_i > 0`.
521///
522/// # Examples
523///
524/// ```
525/// # use logp::kl_divergence;
526/// // KL(p || p) = 0 (Gibbs' inequality, tight case).
527/// let p = [0.2, 0.3, 0.5];
528/// assert!(kl_divergence(&p, &p, 1e-9).unwrap().abs() < 1e-15);
529///
530/// // KL is non-negative.
531/// let q = [0.5, 0.25, 0.25];
532/// assert!(kl_divergence(&p, &q, 1e-9).unwrap() >= 0.0);
533///
534/// // Not symmetric in general.
535/// let kl_pq = kl_divergence(&p, &q, 1e-9).unwrap();
536/// let kl_qp = kl_divergence(&q, &p, 1e-9).unwrap();
537/// assert!((kl_pq - kl_qp).abs() > 1e-6);
538/// ```
539pub fn kl_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
540    ensure_same_len(p, q)?;
541    validate_simplex(p, tol)?;
542    validate_simplex(q, tol)?;
543    let mut d = 0.0;
544    for (&pi, &qi) in p.iter().zip(q.iter()) {
545        if pi == 0.0 {
546            continue;
547        }
548        if qi <= 0.0 {
549            return Err(Error::Domain("KL undefined: q_i=0 while p_i>0"));
550        }
551        d += pi * (pi / qi).ln();
552    }
553    Ok(d)
554}
555
556/// Jensen--Shannon divergence in nats: a symmetric, bounded smoothing of KL divergence.
557///
558/// \[JS(p, q) = \tfrac{1}{2} D_{KL}(p \| m) + \tfrac{1}{2} D_{KL}(q \| m), \quad m = \tfrac{1}{2}(p + q)\]
559///
560/// # Key properties
561///
562/// - **Symmetric**: \(JS(p, q) = JS(q, p)\), unlike KL.
563/// - **Bounded**: \(0 \le JS(p, q) \le \ln 2\). The upper bound is attained when \(p\)
564///   and \(q\) have disjoint supports.
565/// - **Square root is a metric**: \(\sqrt{JS(p, q)}\) satisfies the triangle inequality
566///   (Endres & Schindelin, 2003), so it can be used as a proper distance function.
567/// - **Connection to mutual information**: \(JS(p, q) = I(X; Z)\) where \(Z\) is a
568///   fair coin selecting between \(p\) and \(q\), and \(X\) is drawn from the selected
569///   distribution.
570/// - **Always finite**: because \(m_i > 0\) whenever \(p_i > 0\) or \(q_i > 0\), the
571///   KL terms are always well-defined (no division by zero).
572///
573/// # Domain
574///
575/// `p`, `q` must be simplex distributions.
576///
577/// # Examples
578///
579/// ```
580/// # use logp::jensen_shannon_divergence;
581/// // JS(p, p) = 0.
582/// let p = [0.3, 0.7];
583/// assert!(jensen_shannon_divergence(&p, &p, 1e-9).unwrap().abs() < 1e-15);
584///
585/// // Disjoint supports: JS = ln(2).
586/// let a = [1.0, 0.0];
587/// let b = [0.0, 1.0];
588/// let js = jensen_shannon_divergence(&a, &b, 1e-9).unwrap();
589/// assert!((js - core::f64::consts::LN_2).abs() < 1e-12);
590///
591/// // Symmetric: JS(p, q) = JS(q, p).
592/// let q = [0.5, 0.5];
593/// let js_pq = jensen_shannon_divergence(&p, &q, 1e-9).unwrap();
594/// let js_qp = jensen_shannon_divergence(&q, &p, 1e-9).unwrap();
595/// assert!((js_pq - js_qp).abs() < 1e-15);
596/// ```
597pub fn jensen_shannon_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
598    ensure_same_len(p, q)?;
599    validate_simplex(p, tol)?;
600    validate_simplex(q, tol)?;
601
602    let mut m = vec![0.0; p.len()];
603    for i in 0..p.len() {
604        m[i] = 0.5 * (p[i] + q[i]);
605    }
606
607    Ok(0.5 * kl_divergence(p, &m, tol)? + 0.5 * kl_divergence(q, &m, tol)?)
608}
609
610/// Weighted Jensen--Shannon divergence: a generalization that allows unequal mixture
611/// weights.
612///
613/// \[JS_\pi(p, q) = \pi_1\,D_{KL}(p \| m) + \pi_2\,D_{KL}(q \| m), \quad
614///   m = \pi_1 p + \pi_2 q\]
615///
616/// where \(\pi_1 + \pi_2 = 1\). At \(\pi_1 = \pi_2 = 0.5\) this reduces to
617/// [`jensen_shannon_divergence`].
618///
619/// # Key properties
620///
621/// - **Bounded**: \(0 \le JS_\pi \le H(\pi)\) where \(H(\pi) = -\pi_1 \ln \pi_1 - \pi_2 \ln \pi_2\).
622///   The standard bound \(\ln 2\) is the special case \(\pi_1 = \pi_2 = 0.5\).
623/// - **Symmetric in \((p, \pi_1)\) and \((q, \pi_2)\)**: swapping both distributions
624///   and weights gives the same value.
625///
626/// # Examples
627///
628/// ```
629/// # use logp::{jensen_shannon_weighted, jensen_shannon_divergence};
630/// let p = [0.3, 0.7];
631/// let q = [0.5, 0.5];
632///
633/// // Equal weights recovers standard JS.
634/// let jsw = jensen_shannon_weighted(&p, &q, 0.5, 1e-9).unwrap();
635/// let js = jensen_shannon_divergence(&p, &q, 1e-9).unwrap();
636/// assert!((jsw - js).abs() < 1e-12);
637///
638/// // Extreme weight: pi1=1 means m=p, so JS=0.
639/// let js1 = jensen_shannon_weighted(&p, &q, 1.0, 1e-9).unwrap();
640/// assert!(js1.abs() < 1e-12);
641/// ```
642pub fn jensen_shannon_weighted(p: &[f64], q: &[f64], pi1: f64, tol: f64) -> Result<f64> {
643    ensure_same_len(p, q)?;
644    validate_simplex(p, tol)?;
645    validate_simplex(q, tol)?;
646    if !(0.0..=1.0).contains(&pi1) || !pi1.is_finite() {
647        return Err(Error::Domain("pi1 must be in [0, 1]"));
648    }
649    let pi2 = 1.0 - pi1;
650
651    let mut m = vec![0.0; p.len()];
652    for i in 0..p.len() {
653        m[i] = pi1 * p[i] + pi2 * q[i];
654    }
655
656    // When pi1 or pi2 is zero, the corresponding KL term is 0 * D_KL = 0.
657    let kl_p = if pi1 > 0.0 {
658        kl_divergence(p, &m, tol)?
659    } else {
660        0.0
661    };
662    let kl_q = if pi2 > 0.0 {
663        kl_divergence(q, &m, tol)?
664    } else {
665        0.0
666    };
667    Ok(pi1 * kl_p + pi2 * kl_q)
668}
669
670/// Mutual information in nats: how much knowing \(Y\) reduces uncertainty about \(X\).
671///
672/// \[I(X; Y) = \sum_{x,y} p(x,y) \ln \frac{p(x,y)}{p(x)\,p(y)}\]
673///
674/// # Key properties
675///
676/// - **KL form**: \(I(X; Y) = D_{KL}\bigl(p(x,y) \;\|\; p(x)\,p(y)\bigr)\), measuring
677///   how far the joint distribution is from the product of its marginals.
678/// - **Non-negative**: \(I(X; Y) \ge 0\), with equality iff \(X\) and \(Y\) are
679///   independent.
680/// - **Symmetric**: \(I(X; Y) = I(Y; X)\).
681/// - **Bounded by entropy**: \(I(X; Y) \le \min\{H(X),\, H(Y)\}\).
682/// - **Data processing inequality**: for any Markov chain \(X \to Y \to Z\),
683///   \(I(X; Z) \le I(X; Y)\). Processing cannot create information.
684/// - **Entropy decomposition**: \(I(X; Y) = H(X) + H(Y) - H(X, Y) = H(X) - H(X|Y)\).
685///
686/// # Layout
687///
688/// For discrete distributions, given a **row-major** joint distribution table `p_xy`
689/// with shape `(n_x, n_y)`.
690///
691/// Public invariant (this is the important one): this API is **backend-agnostic**.
692/// It does not force `ndarray` into the public surface of a leaf crate.
693///
694/// # Examples
695///
696/// ```
697/// # use logp::{mutual_information, entropy_nats};
698/// // Independent joint: p(x,y) = p(x)*p(y), so I(X;Y) = 0.
699/// let p_xy = [0.15, 0.35, 0.15, 0.35]; // 2x2, marginals [0.5,0.5] x [0.3,0.7]
700/// let mi = mutual_information(&p_xy, 2, 2, 1e-9).unwrap();
701/// assert!(mi.abs() < 1e-12);
702///
703/// // Perfect correlation (Y = X, uniform bit): I(X;Y) = H(X) = ln(2).
704/// let diag = [0.5, 0.0, 0.0, 0.5];
705/// let mi = mutual_information(&diag, 2, 2, 1e-9).unwrap();
706/// assert!((mi - core::f64::consts::LN_2).abs() < 1e-12);
707/// ```
708pub fn mutual_information(p_xy: &[f64], n_x: usize, n_y: usize, tol: f64) -> Result<f64> {
709    if n_x == 0 || n_y == 0 {
710        return Err(Error::Domain(
711            "mutual_information: n_x and n_y must be >= 1",
712        ));
713    }
714    if p_xy.len() != n_x * n_y {
715        return Err(Error::LengthMismatch(p_xy.len(), n_x * n_y));
716    }
717    validate_simplex(p_xy, tol)?;
718
719    let mut p_x = vec![0.0; n_x];
720    let mut p_y = vec![0.0; n_y];
721    for i in 0..n_x {
722        for j in 0..n_y {
723            let p = p_xy[i * n_y + j];
724            p_x[i] += p;
725            p_y[j] += p;
726        }
727    }
728
729    let mut mi = 0.0;
730    for i in 0..n_x {
731        for j in 0..n_y {
732            let pxy = p_xy[i * n_y + j];
733            if pxy > 0.0 {
734                let px = p_x[i];
735                let py = p_y[j];
736                if px <= 0.0 || py <= 0.0 {
737                    return Err(Error::Domain(
738                        "mutual_information: p(x)=0 or p(y)=0 while p(x,y)>0",
739                    ));
740                }
741                mi += pxy * (pxy / (px * py)).ln();
742            }
743        }
744    }
745    Ok(mi)
746}
747
/// `ndarray` adapter for discrete mutual information.
///
/// Flattens the 2D joint table and delegates to [`mutual_information`].
/// `iter()` visits elements in logical (row-major) order, which matches the
/// row-major layout that `mutual_information` expects — presumably regardless
/// of the array's underlying memory layout; verify against the `ndarray` docs
/// if non-standard-layout views are passed here.
///
/// Requires `logp` feature `ndarray`.
#[cfg(feature = "ndarray")]
pub fn mutual_information_ndarray(p_xy: &ndarray::Array2<f64>, tol: f64) -> Result<f64> {
    let (n_x, n_y) = p_xy.dim();
    let flat: Vec<f64> = p_xy.iter().copied().collect();
    mutual_information(&flat, n_x, n_y, tol)
}
758/// Conditional entropy in nats: the remaining uncertainty about \(X\) after observing \(Y\).
759///
760/// \[H(X|Y) = H(X, Y) - H(Y) = H(X) - I(X; Y)\]
761///
762/// # Key properties
763///
764/// - **Non-negative**: \(H(X|Y) \ge 0\).
765/// - **Bounded**: \(H(X|Y) \le H(X)\), with equality iff \(X\) and \(Y\) are independent.
766/// - **Zero iff deterministic**: \(H(X|Y) = 0\) iff \(X\) is a function of \(Y\).
767///
768/// # Layout
769///
770/// Row-major joint distribution `p_xy` with shape `(n_x, n_y)`, same as
771/// [`mutual_information`].
772///
773/// # Examples
774///
775/// ```
776/// # use logp::conditional_entropy;
777/// // Independent: H(X|Y) = H(X).
778/// let p_xy = [0.15, 0.35, 0.15, 0.35]; // 2x2
779/// let h_x_given_y = conditional_entropy(&p_xy, 2, 2, 1e-9).unwrap();
780/// // H(X) for marginal [0.5, 0.5] = ln(2)
781/// assert!((h_x_given_y - 2.0_f64.ln()).abs() < 1e-10);
782///
783/// // Deterministic (Y = X): H(X|Y) = 0.
784/// let diag = [0.5, 0.0, 0.0, 0.5];
785/// assert!(conditional_entropy(&diag, 2, 2, 1e-9).unwrap().abs() < 1e-10);
786/// ```
787pub fn conditional_entropy(p_xy: &[f64], n_x: usize, n_y: usize, tol: f64) -> Result<f64> {
788    let mi = mutual_information(p_xy, n_x, n_y, tol)?;
789    // Compute H(X) from marginal.
790    let mut p_x = vec![0.0; n_x];
791    for i in 0..n_x {
792        for j in 0..n_y {
793            p_x[i] += p_xy[i * n_y + j];
794        }
795    }
796    let h_x = entropy_nats(&p_x, tol)?;
797    Ok(h_x - mi)
798}
799
800/// Normalized mutual information: MI scaled to \([0, 1]\) for comparing clusterings
801/// of different sizes.
802///
803/// \[NMI(X, Y) = \frac{2\,I(X; Y)}{H(X) + H(Y)}\]
804///
805/// # Key properties
806///
807/// - **Bounded**: \(NMI \in [0, 1]\). Equals 0 for independent variables;
808///   equals 1 for perfectly correlated (identical clustering).
809/// - **Symmetric**: \(NMI(X, Y) = NMI(Y, X)\).
810///
811/// Returns 0 when both marginal entropies are zero (trivial single-cluster case).
812///
813/// # Examples
814///
815/// ```
816/// # use logp::normalized_mutual_information;
817/// // Perfect correlation: NMI = 1.
818/// let diag = [0.5, 0.0, 0.0, 0.5];
819/// let nmi = normalized_mutual_information(&diag, 2, 2, 1e-9).unwrap();
820/// assert!((nmi - 1.0).abs() < 1e-10);
821///
822/// // Independent: NMI = 0.
823/// let indep = [0.15, 0.35, 0.15, 0.35];
824/// let nmi = normalized_mutual_information(&indep, 2, 2, 1e-9).unwrap();
825/// assert!(nmi.abs() < 1e-10);
826/// ```
827pub fn normalized_mutual_information(
828    p_xy: &[f64],
829    n_x: usize,
830    n_y: usize,
831    tol: f64,
832) -> Result<f64> {
833    if n_x == 0 || n_y == 0 {
834        return Err(Error::Domain("nmi: n_x and n_y must be >= 1"));
835    }
836    if p_xy.len() != n_x * n_y {
837        return Err(Error::LengthMismatch(p_xy.len(), n_x * n_y));
838    }
839    validate_simplex(p_xy, tol)?;
840
841    let mut p_x = vec![0.0; n_x];
842    let mut p_y = vec![0.0; n_y];
843    for i in 0..n_x {
844        for j in 0..n_y {
845            let p = p_xy[i * n_y + j];
846            p_x[i] += p;
847            p_y[j] += p;
848        }
849    }
850
851    let h_x = entropy_nats(&p_x, tol)?;
852    let h_y = entropy_nats(&p_y, tol)?;
853    let denom = h_x + h_y;
854    if denom <= 0.0 {
855        return Ok(0.0);
856    }
857
858    let mi = mutual_information(p_xy, n_x, n_y, tol)?;
859    Ok(2.0 * mi / denom)
860}
861
862/// Pointwise mutual information: the log-ratio measuring how much more (or less)
863/// likely two specific outcomes co-occur than if they were independent.
864///
865/// \[PMI(x; y) = \ln \frac{p(x, y)}{p(x)\,p(y)}\]
866///
867/// # Key properties
868///
869/// - **Sign**: positive when \(x\) and \(y\) co-occur more than chance; negative when
870///   less; zero when independent.
871/// - **Unbounded**: \(PMI \in (-\infty, -\ln p(x,y)]\). In practice, rare events yield
872///   very large PMI, which is why PPMI (positive PMI, clamped at 0) is common.
873/// - **Connection to mutual information**: \(I(X; Y) = \mathbb{E}_{p(x,y)}[PMI(x; y)]\).
874///   MI is the expected value of PMI over the joint distribution.
875/// - **Connection to word2vec**: Levy & Goldberg (2014) showed that Skip-gram with
876///   negative sampling implicitly factorizes a PMI matrix (shifted by \(\ln k\)).
877///
878/// # Examples
879///
880/// ```
881/// # use logp::pmi;
882/// // Independent events: p(x,y) = p(x)*p(y), so PMI = 0.
883/// let val = pmi(0.06, 0.3, 0.2).unwrap();
884/// assert!(val.abs() < 1e-10);
885///
886/// // Positive correlation: p(x,y) > p(x)*p(y).
887/// assert!(pmi(0.4, 0.5, 0.5).unwrap() > 0.0);
888///
889/// // Negative correlation: p(x,y) < p(x)*p(y).
890/// assert!(pmi(0.1, 0.5, 0.5).unwrap() < 0.0);
891///
892/// // Zero joint probability returns 0 by convention.
893/// assert_eq!(pmi(0.0, 0.5, 0.5).unwrap(), 0.0);
894///
895/// // Impossible: p(x,y) > 0 but p(x) = 0.
896/// assert!(pmi(0.1, 0.0, 0.5).is_err());
897/// ```
898pub fn pmi(pxy: f64, px: f64, py: f64) -> Result<f64> {
899    if pxy > 0.0 && px == 0.0 {
900        return Err(Error::Domain("pmi: p(x,y)>0 but p(x)=0 is impossible"));
901    }
902    if pxy > 0.0 && py == 0.0 {
903        return Err(Error::Domain("pmi: p(x,y)>0 but p(y)=0 is impossible"));
904    }
905    if pxy <= 0.0 || px <= 0.0 || py <= 0.0 {
906        Ok(0.0)
907    } else {
908        Ok((pxy / (px * py)).ln())
909    }
910}
911
/// Log-sum-exp: numerically stable computation of `ln(exp(a_1) + ... + exp(a_n))`.
///
/// This is the fundamental primitive for working in log-probability space.
/// Computing `values.iter().map(|v| v.exp()).sum::<f64>().ln()` directly
/// overflows for large inputs and underflows for small ones; shifting by the
/// maximum before exponentiating sidesteps both failure modes.
///
/// Returns `NEG_INFINITY` for an empty slice.
///
/// # Examples
///
/// ```
/// # use logp::log_sum_exp;
/// // ln(e^0 + e^0) = ln(2)
/// let lse = log_sum_exp(&[0.0, 0.0]);
/// assert!((lse - 2.0_f64.ln()).abs() < 1e-12);
///
/// // Dominated term: ln(e^1000 + e^0) ≈ 1000
/// let lse = log_sum_exp(&[1000.0, 0.0]);
/// assert!((lse - 1000.0).abs() < 1e-10);
///
/// // Single element: identity.
/// assert_eq!(log_sum_exp(&[42.0]), 42.0);
///
/// // Empty: -inf.
/// assert_eq!(log_sum_exp(&[]), f64::NEG_INFINITY);
/// ```
#[inline]
pub fn log_sum_exp(values: &[f64]) -> f64 {
    if values.is_empty() {
        return f64::NEG_INFINITY;
    }
    // Pass 1: locate the maximum (the shift point).
    let mut m = f64::NEG_INFINITY;
    for &v in values {
        m = m.max(v);
    }
    // All -inf (or any +inf): the answer is the max itself.
    if m.is_infinite() {
        return m;
    }
    // Pass 2: sum the shifted exponentials; every term is <= 1.
    let mut acc = 0.0;
    for &v in values {
        acc += (v - m).exp();
    }
    m + acc.ln()
}
951
/// Log-sum-exp for two values (common special case).
///
/// Equivalent to `log_sum_exp(&[a, b])` but avoids the slice overhead.
///
/// # Examples
///
/// ```
/// # use logp::log_sum_exp2;
/// let lse = log_sum_exp2(0.0, 0.0);
/// assert!((lse - 2.0_f64.ln()).abs() < 1e-12);
/// ```
#[inline]
pub fn log_sum_exp2(a: f64, b: f64) -> f64 {
    let m = a.max(b);
    if m.is_infinite() {
        // Covers both -inf (both inputs -inf) and +inf (either input +inf).
        m
    } else {
        let shifted = (a - m).exp() + (b - m).exp();
        m + shifted.ln()
    }
}
971
/// Streaming log-sum-exp: single-pass, O(1) memory computation over an iterator.
///
/// Equivalent to `log_sum_exp` but processes elements one at a time without
/// materializing a slice. Useful for large-scale or iterator-chain workloads.
///
/// Uses a running `(max, sum_exp)` pair that rescales when a new maximum arrives.
///
/// Returns `NEG_INFINITY` for an empty iterator. Elements equal to
/// `NEG_INFINITY` contribute `exp(-inf) = 0` and are skipped, matching the
/// slice-based [`log_sum_exp`].
///
/// # Examples
///
/// ```
/// # use logp::{log_sum_exp_iter, log_sum_exp};
/// let values = vec![1.0, 2.0, 3.0];
/// let lse_iter = log_sum_exp_iter(values.iter().copied());
/// let lse_slice = log_sum_exp(&values);
/// assert!((lse_iter - lse_slice).abs() < 1e-12);
///
/// // Works with any iterator.
/// let lse = log_sum_exp_iter((0..5).map(|i| i as f64));
/// assert!(lse.is_finite());
///
/// // Empty iterator: -inf.
/// assert_eq!(log_sum_exp_iter(std::iter::empty::<f64>()), f64::NEG_INFINITY);
/// ```
#[inline]
pub fn log_sum_exp_iter(iter: impl Iterator<Item = f64>) -> f64 {
    let mut max = f64::NEG_INFINITY;
    let mut sum_exp = 0.0;

    for v in iter {
        // exp(-inf) contributes 0 to the sum. Skipping is also required for
        // correctness: while `max` is still NEG_INFINITY, the unshifted
        // `(v - max)` below would be `(-inf) - (-inf) = NaN`, which previously
        // poisoned the whole result for streams containing -inf.
        if v == f64::NEG_INFINITY {
            continue;
        }
        if v > max {
            if max.is_finite() {
                // Rescale the running sum when a new max arrives.
                sum_exp *= (max - v).exp();
            }
            max = v;
        }
        // (v - max).exp() is <= 1.0 since v <= max.
        sum_exp += (v - max).exp();
    }

    if max.is_infinite() {
        return max; // Handles NEG_INFINITY (empty / all -inf) and POS_INFINITY.
    }
    max + sum_exp.ln()
}
1019
/// Digamma function: the logarithmic derivative of the Gamma function.
///
/// \[\psi(x) = \frac{d}{dx} \ln \Gamma(x) = \frac{\Gamma'(x)}{\Gamma(x)}\]
///
/// # Key properties
///
/// - **Recurrence**: \(\psi(x+1) = \psi(x) + \frac{1}{x}\), which follows from
///   \(\Gamma(x+1) = x\,\Gamma(x)\).
/// - **Special value**: \(\psi(1) = -\gamma \approx -0.5772\), where \(\gamma\) is
///   the Euler--Mascheroni constant.
/// - **Asymptotic**: \(\psi(x) \sim \ln x - \frac{1}{2x}\) for large \(x\).
/// - **Why it appears here**: the KSG estimator for mutual information
///   ([`mutual_information_ksg`]) uses digamma to correct for the bias of
///   nearest-neighbor density estimates.
///
/// # Domain
///
/// Defined for \(x > 0\). Returns `NaN` for \(x \le 0\).
///
/// # Implementation
///
/// Uses the recurrence to shift small \(x\) up to \(x \ge 10\), then applies the
/// asymptotic expansion with 4 Bernoulli-number correction terms (~1e-14 accuracy).
///
/// # Examples
///
/// ```
/// # use logp::digamma;
/// // psi(1) = -gamma (Euler-Mascheroni constant).
/// let psi1 = digamma(1.0);
/// assert!((psi1 - (-0.5772156649)).abs() < 1e-12);
///
/// // Recurrence: psi(x+1) = psi(x) + 1/x.
/// let x = 3.5;
/// assert!((digamma(x + 1.0) - digamma(x) - 1.0 / x).abs() < 1e-12);
///
/// // Non-positive input returns NaN.
/// assert!(digamma(0.0).is_nan());
/// assert!(digamma(-1.0).is_nan());
/// ```
pub fn digamma(mut x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NAN;
    }
    // Climb the recurrence psi(x) = psi(x+1) - 1/x until x >= 10, where the
    // asymptotic series below is accurate to ~1e-14 with 4 Bernoulli terms.
    // (The earlier threshold of 7 with 3 terms gave only ~1e-10.)
    let mut acc = 0.0;
    while x < 10.0 {
        acc -= 1.0 / x;
        x += 1.0;
    }
    // Leading terms of the expansion: ln(x) - 1/(2x).
    let r = 1.0 / x;
    acc += x.ln() - 0.5 * r;
    // Bernoulli tail in Horner form: B_{2k} / (2k * x^{2k}) for k = 1..4, i.e.
    // B2/2 = 1/12, B4/4 = 1/120, B6/6 = 1/252, B8/8 = 1/240.
    let r2 = r * r;
    acc -= r2 * (1.0 / 12.0 - r2 * (1.0 / 120.0 - r2 * (1.0 / 252.0 - r2 / 240.0)));
    acc
}
1080
1081/// Bhattacharyya coefficient: the geometric-mean overlap between two distributions.
1082///
1083/// \[BC(p, q) = \sum_i \sqrt{p_i \, q_i}\]
1084///
1085/// # Key properties
1086///
1087/// - **Geometric mean interpretation**: each term \(\sqrt{p_i q_i}\) is the geometric
1088///   mean of the two probabilities at bin \(i\). BC sums these, measuring how much
1089///   the distributions overlap.
1090/// - **Bounded**: \(BC \in [0, 1]\). Equals 1 iff \(p = q\); equals 0 iff supports
1091///   are disjoint.
1092/// - **Relationship to Hellinger**: \(H^2(p, q) = 1 - BC(p, q)\), so the squared
1093///   Hellinger distance is just one minus the Bhattacharyya coefficient.
1094/// - **Relationship to Renyi**: at \(\alpha = \tfrac{1}{2}\), the Renyi divergence
1095///   gives \(D_{1/2}^R(p \| q) = -2 \ln BC(p, q)\).
1096/// - **Connection to alpha family**: \(BC = \rho_{1/2}(p, q)\), a special case of
1097///   [`rho_alpha`].
1098///
1099/// # Examples
1100///
1101/// ```
1102/// # use logp::bhattacharyya_coeff;
1103/// // BC(p, p) = 1.
1104/// let p = [0.3, 0.7];
1105/// assert!((bhattacharyya_coeff(&p, &p, 1e-9).unwrap() - 1.0).abs() < 1e-12);
1106///
1107/// // Disjoint supports: BC = 0.
1108/// let a = [1.0, 0.0];
1109/// let b = [0.0, 1.0];
1110/// assert!(bhattacharyya_coeff(&a, &b, 1e-9).unwrap().abs() < 1e-15);
1111///
1112/// // BC is in [0, 1].
1113/// let q = [0.5, 0.5];
1114/// let bc = bhattacharyya_coeff(&p, &q, 1e-9).unwrap();
1115/// assert!((0.0..=1.0).contains(&bc));
1116/// ```
1117pub fn bhattacharyya_coeff(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1118    ensure_same_len(p, q)?;
1119    validate_simplex(p, tol)?;
1120    validate_simplex(q, tol)?;
1121    let bc: f64 = p
1122        .iter()
1123        .zip(q.iter())
1124        .map(|(&pi, &qi)| pi.sqrt() * qi.sqrt())
1125        .sum();
1126    Ok(bc)
1127}
1128
1129/// Bhattacharyya distance \(D_B(p,q) = -\ln BC(p,q)\).
1130///
1131/// # Examples
1132///
1133/// ```
1134/// # use logp::bhattacharyya_distance;
1135/// // D_B(p, p) = 0.
1136/// let p = [0.4, 0.6];
1137/// assert!(bhattacharyya_distance(&p, &p, 1e-9).unwrap().abs() < 1e-12);
1138///
1139/// // Non-negative for distinct distributions.
1140/// let q = [0.5, 0.5];
1141/// assert!(bhattacharyya_distance(&p, &q, 1e-9).unwrap() >= 0.0);
1142/// ```
1143pub fn bhattacharyya_distance(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1144    let bc = bhattacharyya_coeff(p, q, tol)?;
1145    // When supports are disjoint, bc can be 0 (=> +∞ distance). Keep it explicit.
1146    if bc == 0.0 {
1147        return Err(Error::Domain("Bhattacharyya distance is infinite (BC=0)"));
1148    }
1149    Ok(-bc.ln())
1150}
1151
1152/// Squared Hellinger distance.
1153///
1154/// \[H^2(p, q) = \frac{1}{2}\sum_i \left(\sqrt{p_i} - \sqrt{q_i}\right)^2\]
1155///
1156/// Equivalent to \(1 - BC(p,q)\) but computed via the sum-of-squared-differences
1157/// form to avoid catastrophic cancellation when \(p \approx q\).
1158///
1159/// Bounded in \([0, 1]\). Equals the Amari \(\alpha\)-divergence at \(\alpha = 0\)
1160/// (up to a factor of 2).
1161///
1162/// # Examples
1163///
1164/// ```
1165/// # use logp::hellinger_squared;
1166/// // H^2(p, p) = 0.
1167/// let p = [0.25, 0.75];
1168/// assert!(hellinger_squared(&p, &p, 1e-9).unwrap().abs() < 1e-15);
1169///
1170/// // Disjoint supports: H^2 = 1.
1171/// let a = [1.0, 0.0];
1172/// let b = [0.0, 1.0];
1173/// assert!((hellinger_squared(&a, &b, 1e-9).unwrap() - 1.0).abs() < 1e-12);
1174/// ```
1175pub fn hellinger_squared(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1176    ensure_same_len(p, q)?;
1177    validate_simplex(p, tol)?;
1178    validate_simplex(q, tol)?;
1179    let h2: f64 = p
1180        .iter()
1181        .zip(q.iter())
1182        .map(|(&pi, &qi)| {
1183            let diff = pi.sqrt() - qi.sqrt();
1184            diff * diff
1185        })
1186        .sum();
1187    Ok((0.5 * h2).max(0.0))
1188}
1189
1190/// Hellinger distance: the square root of the squared Hellinger distance.
1191///
1192/// \[H(p, q) = \sqrt{1 - BC(p, q)}\]
1193///
1194/// Unlike KL divergence, Hellinger is a **proper metric**: it is symmetric, satisfies
1195/// the triangle inequality, and is bounded in \([0, 1]\).
1196///
1197/// # Examples
1198///
1199/// ```
1200/// # use logp::hellinger;
1201/// // H(p, p) = 0.
1202/// let p = [0.3, 0.7];
1203/// assert!(hellinger(&p, &p, 1e-9).unwrap().abs() < 1e-15);
1204///
1205/// // Symmetric: H(p, q) = H(q, p).
1206/// let q = [0.5, 0.5];
1207/// let h_pq = hellinger(&p, &q, 1e-9).unwrap();
1208/// let h_qp = hellinger(&q, &p, 1e-9).unwrap();
1209/// assert!((h_pq - h_qp).abs() < 1e-15);
1210///
1211/// // Bounded in [0, 1].
1212/// assert!((0.0..=1.0).contains(&h_pq));
1213/// ```
1214pub fn hellinger(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1215    Ok(hellinger_squared(p, q, tol)?.sqrt())
1216}
1217
1218fn pow_nonneg(x: f64, a: f64) -> Result<f64> {
1219    if x < 0.0 || !x.is_finite() || !a.is_finite() {
1220        return Err(Error::Domain("pow_nonneg: invalid input"));
1221    }
1222    if x == 0.0 {
1223        if a == 0.0 {
1224            // By continuity in the divergence formulas, treat 0^0 as 1.
1225            return Ok(1.0);
1226        }
1227        if a > 0.0 {
1228            return Ok(0.0);
1229        }
1230        return Err(Error::Domain("0^a for a<0 is infinite"));
1231    }
1232    Ok(x.powf(a))
1233}
1234
1235/// Alpha-integral: the workhorse behind the entire alpha-family of divergences.
1236///
1237/// \[\rho_\alpha(p, q) = \sum_i p_i^\alpha \, q_i^{1-\alpha}\]
1238///
1239/// # Why this matters
1240///
1241/// This single quantity generates multiple divergence families via simple transforms:
1242///
1243/// - **Renyi**: \(D_\alpha^R = \frac{1}{\alpha - 1} \ln \rho_\alpha\)
1244/// - **Tsallis**: \(D_\alpha^T = \frac{\rho_\alpha - 1}{\alpha - 1}\)
1245/// - **Bhattacharyya coefficient**: \(BC = \rho_{1/2}\)
1246/// - **Chernoff information**: \(\min_\alpha (-\ln \rho_\alpha)\)
1247///
1248/// # Key properties
1249///
1250/// - \(\rho_\alpha(p, p) = 1\) for all \(\alpha\) (since \(\sum p_i = 1\)).
1251/// - By Holder's inequality, \(\rho_\alpha(p, q) \le 1\) for \(\alpha \in [0, 1]\).
1252/// - Continuous and log-convex in \(\alpha\).
1253///
1254/// # Examples
1255///
1256/// ```
1257/// # use logp::rho_alpha;
1258/// // rho_alpha(p, p, alpha) = 1 for any alpha (since sum(p) = 1).
1259/// let p = [0.2, 0.3, 0.5];
1260/// assert!((rho_alpha(&p, &p, 0.5, 1e-9).unwrap() - 1.0).abs() < 1e-12);
1261/// assert!((rho_alpha(&p, &p, 2.0, 1e-9).unwrap() - 1.0).abs() < 1e-12);
1262///
1263/// // At alpha = 0.5, rho equals the Bhattacharyya coefficient.
1264/// let q = [0.5, 0.25, 0.25];
1265/// let rho = rho_alpha(&p, &q, 0.5, 1e-9).unwrap();
1266/// let bc = logp::bhattacharyya_coeff(&p, &q, 1e-9).unwrap();
1267/// assert!((rho - bc).abs() < 1e-12);
1268/// ```
1269pub fn rho_alpha(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1270    ensure_same_len(p, q)?;
1271    validate_simplex(p, tol)?;
1272    validate_simplex(q, tol)?;
1273    if !alpha.is_finite() {
1274        return Err(Error::InvalidAlpha {
1275            alpha,
1276            forbidden: f64::NAN,
1277        });
1278    }
1279    let mut s = 0.0;
1280    for (&pi, &qi) in p.iter().zip(q.iter()) {
1281        let a = pow_nonneg(pi, alpha)?;
1282        let b = pow_nonneg(qi, 1.0 - alpha)?;
1283        s += a * b;
1284    }
1285    Ok(s)
1286}
1287
1288/// Renyi divergence in nats: a one-parameter family that interpolates between
1289/// different notions of distributional difference.
1290///
1291/// \[D_\alpha^R(p \| q) = \frac{1}{\alpha - 1} \ln \rho_\alpha(p, q), \quad \alpha > 0,\; \alpha \ne 1\]
1292///
1293/// # Key properties
1294///
1295/// - **Limit to KL**: \(\lim_{\alpha \to 1} D_\alpha^R(p \| q) = D_{KL}(p \| q)\)
1296///   by L'Hopital's rule (the logarithm and denominator both vanish).
1297/// - **Alpha = 1/2**: \(D_{1/2}^R = -2 \ln BC(p, q)\), twice the negative log
1298///   Bhattacharyya coefficient.
1299/// - **Alpha = infinity**: \(D_\infty^R = \ln \max_i (p_i / q_i)\), the log of the
1300///   maximum likelihood ratio. This bounds all other Renyi orders.
1301/// - **Monotone in alpha**: \(D_\alpha^R\) is non-decreasing in \(\alpha\).
1302/// - **Non-negative**: \(D_\alpha^R(p \| q) \ge 0\), with equality iff \(p = q\).
1303///
1304/// # Domain
1305///
1306/// \(\alpha > 0\), \(\alpha \ne 1\). Both `p` and `q` must be simplex distributions.
1307///
1308/// # Examples
1309///
1310/// ```
1311/// # use logp::renyi_divergence;
1312/// // D_alpha(p || p) = 0 for any valid alpha.
1313/// let p = [0.3, 0.7];
1314/// assert!(renyi_divergence(&p, &p, 2.0, 1e-9).unwrap().abs() < 1e-12);
1315///
1316/// // Non-negative.
1317/// let q = [0.5, 0.5];
1318/// assert!(renyi_divergence(&p, &q, 0.5, 1e-9).unwrap() >= -1e-12);
1319///
1320/// // alpha = 1.0 returns KL divergence (Shannon limit).
1321/// let kl = logp::kl_divergence(&p, &q, 1e-9).unwrap();
1322/// let r1 = renyi_divergence(&p, &q, 1.0, 1e-9).unwrap();
1323/// assert!((r1 - kl).abs() < 1e-12);
1324/// ```
1325pub fn renyi_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1326    if (alpha - 1.0).abs() < 1e-12 {
1327        return kl_divergence(p, q, tol);
1328    }
1329    let rho = rho_alpha(p, q, alpha, tol)?;
1330    if rho <= 0.0 {
1331        return Err(Error::Domain("rho_alpha <= 0"));
1332    }
1333    Ok(rho.ln() / (alpha - 1.0))
1334}
1335
1336/// Tsallis divergence: a non-extensive generalization of KL divergence from
1337/// statistical mechanics.
1338///
1339/// \[D_\alpha^T(p \| q) = \frac{\rho_\alpha(p, q) - 1}{\alpha - 1}, \quad \alpha \ne 1\]
1340///
1341/// # Key properties
1342///
1343/// - **Limit to KL**: \(\lim_{\alpha \to 1} D_\alpha^T(p \| q) = D_{KL}(p \| q)\),
1344///   same limit as Renyi but via a different path.
1345/// - **Connection to Renyi via deformed logarithm**: Tsallis uses the q-logarithm
1346///   \(\ln_q(x) = \frac{x^{1-q} - 1}{1-q}\) where Renyi uses the ordinary log.
1347///   Formally: \(D_\alpha^T = \frac{e^{(\alpha-1) D_\alpha^R} - 1}{\alpha - 1}\).
1348/// - **Non-extensive**: for independent systems, Tsallis divergence is **not** additive
1349///   (unlike KL and Renyi). This property is intentional and models systems with
1350///   long-range correlations in statistical physics.
1351/// - **Non-negative**: \(D_\alpha^T(p \| q) \ge 0\), with equality iff \(p = q\).
1352///
1353/// # Domain
1354///
1355/// \(\alpha \ne 1\). Both `p` and `q` must be simplex distributions.
1356///
1357/// # Examples
1358///
1359/// ```
1360/// # use logp::tsallis_divergence;
1361/// // D_alpha^T(p || p) = 0 for any valid alpha.
1362/// let p = [0.4, 0.6];
1363/// assert!(tsallis_divergence(&p, &p, 2.0, 1e-9).unwrap().abs() < 1e-12);
1364///
1365/// // Non-negative.
1366/// let q = [0.5, 0.5];
1367/// assert!(tsallis_divergence(&p, &q, 0.5, 1e-9).unwrap() >= -1e-12);
1368///
1369/// // alpha = 1.0 returns KL divergence (Shannon limit).
1370/// let kl = logp::kl_divergence(&p, &q, 1e-9).unwrap();
1371/// let t1 = tsallis_divergence(&p, &q, 1.0, 1e-9).unwrap();
1372/// assert!((t1 - kl).abs() < 1e-12);
1373/// ```
1374pub fn tsallis_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1375    if (alpha - 1.0).abs() < 1e-12 {
1376        return kl_divergence(p, q, tol);
1377    }
1378    Ok((rho_alpha(p, q, alpha, tol)? - 1.0) / (alpha - 1.0))
1379}
1380
1381/// Amari alpha-divergence: a one-parameter family from information geometry that
1382/// continuously interpolates between forward KL, reverse KL, and squared Hellinger.
1383///
1384/// For \(\alpha \notin \{-1, 1\}\):
1385///
1386/// \[D^\alpha(p : q) = \frac{4}{1 - \alpha^2}\left(1 - \rho_{\frac{1-\alpha}{2}}(p, q)\right)\]
1387///
1388/// # Key properties
1389///
1390/// - **\(\alpha = -1\)**: recovers \(D_{KL}(p \| q)\), the forward KL divergence.
1391/// - **\(\alpha = +1\)**: recovers \(D_{KL}(q \| p)\), the reverse KL divergence.
1392/// - **\(\alpha = 0\)**: gives \(4(1 - BC(p,q)) = 4\,H^2(p,q)\), proportional to
1393///   the squared Hellinger distance.
1394/// - **Duality**: \(D^\alpha(p : q) = D^{-\alpha}(q : p)\). Swapping the sign of
1395///   \(\alpha\) is the same as swapping the arguments.
1396/// - **Non-negative**: \(D^\alpha(p : q) \ge 0\), with equality iff \(p = q\).
1397/// - **Information geometry**: the Amari family parameterizes the \(\alpha\)-connections
1398///   on the statistical manifold (Amari & Nagaoka, 2000).
1399///
1400/// # Examples
1401///
1402/// ```
1403/// # use logp::{amari_alpha_divergence, kl_divergence, hellinger_squared};
1404/// let p = [0.3, 0.7];
1405/// let q = [0.5, 0.5];
1406/// let tol = 1e-9;
1407///
1408/// // alpha = -1 gives forward KL(p || q).
1409/// let amari_neg1 = amari_alpha_divergence(&p, &q, -1.0, tol).unwrap();
1410/// let kl_fwd = kl_divergence(&p, &q, tol).unwrap();
1411/// assert!((amari_neg1 - kl_fwd).abs() < 1e-6);
1412///
1413/// // alpha = 0 gives 4 * H^2(p, q).
1414/// let amari_0 = amari_alpha_divergence(&p, &q, 0.0, tol).unwrap();
1415/// let h2 = hellinger_squared(&p, &q, tol).unwrap();
1416/// assert!((amari_0 - 4.0 * h2).abs() < 1e-10);
1417/// ```
1418pub fn amari_alpha_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1419    if !alpha.is_finite() {
1420        return Err(Error::InvalidAlpha {
1421            alpha,
1422            forbidden: f64::NAN,
1423        });
1424    }
1425    // Numerically stable handling near ±1.
1426    let eps = tol.sqrt();
1427    if (alpha + 1.0).abs() <= eps {
1428        return kl_divergence(p, q, tol);
1429    }
1430    if (alpha - 1.0).abs() <= eps {
1431        return kl_divergence(q, p, tol);
1432    }
1433    let t = (1.0 - alpha) / 2.0;
1434    let rho = rho_alpha(p, q, t, tol)?;
1435    Ok((4.0 / (1.0 - alpha * alpha)) * (1.0 - rho))
1436}
1437
1438/// Csiszar f-divergence: the most general class of divergences that respect
1439/// sufficient statistics (information monotonicity).
1440///
1441/// \[D_f(p \| q) = \sum_i q_i \, f\!\left(\frac{p_i}{q_i}\right)\]
1442///
1443/// where \(f\) is a convex function with \(f(1) = 0\).
1444///
1445/// # Information monotonicity theorem
1446///
1447/// The defining property of f-divergences (Csiszar, 1967): for any Markov kernel
1448/// (stochastic map) \(T\),
1449///
1450/// \[D_f(Tp \| Tq) \le D_f(p \| q)\]
1451///
1452/// Coarse-graining (merging bins) cannot increase the divergence. This is the
1453/// information-theoretic analogue of the data processing inequality.
1454///
1455/// # Common f-generators
1456///
1457/// | Divergence | \(f(t)\) |
1458/// |---|---|
1459/// | KL divergence | \(t \ln t\) |
1460/// | Reverse KL | \(-\ln t\) |
1461/// | Squared Hellinger | \((\sqrt{t} - 1)^2\) |
1462/// | Total variation | \(\tfrac{1}{2} |t - 1|\) |
1463/// | Chi-squared | \((t - 1)^2\) |
1464/// | Jensen-Shannon | \(t \ln t - (1+t) \ln \tfrac{1+t}{2}\) |
1465///
1466/// # Convention
1467///
1468/// This function uses Csiszar's original convention: \(D_f(p \| q) = \sum q_i f(p_i/q_i)\).
1469/// Some textbooks reverse the roles of \(p\) and \(q\), writing
1470/// \(\sum p_i f(q_i/p_i)\), which silently computes the *conjugate* divergence
1471/// \(D_{f^*}\) where \(f^*(t) = t\,f(1/t)\). The generators in the table above
1472/// follow this function's convention.
1473///
1474/// # Edge cases
1475///
1476/// When `q_i = 0`:
1477/// - if `p_i = 0`, the contribution is treated as 0 (by continuity).
1478/// - if `p_i > 0`, the divergence is infinite; we return an error.
1479///
1480/// # Examples
1481///
1482/// ```
1483/// # use logp::{csiszar_f_divergence, kl_divergence};
1484/// let p = [0.3, 0.7];
1485/// let q = [0.5, 0.5];
1486///
1487/// // f(t) = t*ln(t) recovers KL divergence.
1488/// let cs = csiszar_f_divergence(&p, &q, |t| t * t.ln(), 1e-9).unwrap();
1489/// let kl = kl_divergence(&p, &q, 1e-9).unwrap();
1490/// assert!((cs - kl).abs() < 1e-10);
1491///
1492/// // f(t) = (t - 1)^2 gives chi-squared divergence.
1493/// let chi2 = csiszar_f_divergence(&p, &q, |t| (t - 1.0).powi(2), 1e-9).unwrap();
1494/// assert!(chi2 >= 0.0);
1495/// ```
1496pub fn csiszar_f_divergence(p: &[f64], q: &[f64], f: impl Fn(f64) -> f64, tol: f64) -> Result<f64> {
1497    ensure_same_len(p, q)?;
1498    validate_simplex(p, tol)?;
1499    validate_simplex(q, tol)?;
1500
1501    let mut d = 0.0;
1502    for (&pi, &qi) in p.iter().zip(q.iter()) {
1503        if qi == 0.0 {
1504            if pi == 0.0 {
1505                continue;
1506            }
1507            return Err(Error::Domain("f-divergence undefined: q_i=0 while p_i>0"));
1508        }
1509        d += qi * f(pi / qi);
1510    }
1511    Ok(d)
1512}
1513
1514/// Total variation distance: half the L1 norm between two distributions.
1515///
1516/// \[TV(p, q) = \frac{1}{2} \sum_i |p_i - q_i|\]
1517///
1518/// Equivalently, this is the Csiszar f-divergence with generator
1519/// \(f(t) = \frac{1}{2} |t - 1|\).
1520///
1521/// # Key properties
1522///
1523/// - **Metric**: symmetric, satisfies the triangle inequality, and \(TV(p, p) = 0\).
1524/// - **Bounded**: \(TV \in [0, 1]\).
1525/// - **Pinsker's inequality**: \(TV(p, q) \le \sqrt{\frac{1}{2} D_{KL}(p \| q)}\).
1526///
1527/// # Examples
1528///
1529/// ```
1530/// # use logp::total_variation;
1531/// // TV(p, p) = 0.
1532/// let p = [0.3, 0.7];
1533/// assert!(total_variation(&p, &p, 1e-9).unwrap().abs() < 1e-15);
1534///
1535/// // Disjoint supports: TV = 1.
1536/// let a = [1.0, 0.0];
1537/// let b = [0.0, 1.0];
1538/// assert!((total_variation(&a, &b, 1e-9).unwrap() - 1.0).abs() < 1e-12);
1539///
1540/// // Symmetric.
1541/// let q = [0.5, 0.5];
1542/// let tv_pq = total_variation(&p, &q, 1e-9).unwrap();
1543/// let tv_qp = total_variation(&q, &p, 1e-9).unwrap();
1544/// assert!((tv_pq - tv_qp).abs() < 1e-15);
1545/// ```
1546pub fn total_variation(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1547    ensure_same_len(p, q)?;
1548    validate_simplex(p, tol)?;
1549    validate_simplex(q, tol)?;
1550    let tv: f64 = p
1551        .iter()
1552        .zip(q.iter())
1553        .map(|(&pi, &qi)| (pi - qi).abs())
1554        .sum();
1555    Ok(0.5 * tv)
1556}
1557
1558/// Chi-squared divergence: a member of the Csiszar f-divergence family that is
1559/// particularly sensitive to tail differences.
1560///
1561/// \[\chi^2(p \| q) = \sum_i \frac{(p_i - q_i)^2}{q_i}\]
1562///
1563/// Equivalently, the f-divergence with generator \(f(t) = (t - 1)^2\).
1564///
1565/// # Key properties
1566///
1567/// - **Non-negative**: \(\chi^2 \ge 0\), with equality iff \(p = q\).
1568/// - **Upper bounds KL**: \(D_{KL}(p \| q) \le \ln(1 + \chi^2(p \| q))\).
1569/// - **Sensitivity warning**: unbounded and extremely sensitive to small \(q_i\).
1570///   When \(q_i\) is near zero but \(p_i\) is not, the ratio \((p_i - q_i)^2 / q_i\)
1571///   can be very large even when KL divergence would be moderate.
1572///
1573/// # Domain
1574///
1575/// Requires \(q_i > 0\) whenever \(p_i > 0\) (same as KL).
1576///
1577/// # Examples
1578///
1579/// ```
1580/// # use logp::chi_squared_divergence;
1581/// // chi^2(p, p) = 0.
1582/// let p = [0.3, 0.7];
1583/// assert!(chi_squared_divergence(&p, &p, 1e-9).unwrap().abs() < 1e-15);
1584///
1585/// // Non-negative.
1586/// let q = [0.5, 0.5];
1587/// assert!(chi_squared_divergence(&p, &q, 1e-9).unwrap() >= 0.0);
1588/// ```
1589pub fn chi_squared_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1590    ensure_same_len(p, q)?;
1591    validate_simplex(p, tol)?;
1592    validate_simplex(q, tol)?;
1593    let mut d = 0.0;
1594    for (&pi, &qi) in p.iter().zip(q.iter()) {
1595        if qi == 0.0 {
1596            if pi == 0.0 {
1597                continue;
1598            }
1599            return Err(Error::Domain("chi-squared undefined: q_i=0 while p_i>0"));
1600        }
1601        let diff = pi - qi;
1602        d += diff * diff / qi;
1603    }
1604    Ok(d)
1605}
1606
/// Bregman generator: a convex function \(F\) and its gradient.
///
/// Implementors provide the potential \(F\) and \(\nabla F\); [`bregman_divergence`]
/// and [`total_bregman_divergence`] are defined entirely in terms of these two
/// callbacks. See [`SquaredL2`] and [`NegEntropy`] for the generators shipped
/// with this crate.
pub trait BregmanGenerator {
    /// Evaluate the potential \(F(x)\).
    fn f(&self, x: &[f64]) -> Result<f64>;

    /// Write \(\nabla F(x)\) into `out`.
    ///
    /// Implementations are expected to reject `out.len() != x.len()` (the
    /// provided generators return `Error::LengthMismatch` in that case).
    fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()>;
}
1615
1616/// Bregman divergence: the gap between a convex function and its tangent approximation.
1617///
1618/// \[B_F(p, q) = F(p) - F(q) - \langle p - q,\, \nabla F(q) \rangle\]
1619///
1620/// # Key properties
1621///
1622/// - **Non-negative**: \(B_F(p, q) \ge 0\) by convexity of \(F\), with equality iff
1623///   \(p = q\).
1624/// - **Not symmetric** in general: \(B_F(p, q) \ne B_F(q, p)\).
1625/// - **Generalized Pythagorean theorem**: for an affine subspace \(S\) and its
1626///   Bregman projection \(q^* = \arg\min_{q \in S} B_F(p, q)\), the three-point
1627///   identity holds: \(B_F(p, q) = B_F(p, q^*) + B_F(q^*, q)\) for all \(q \in S\).
1628///   This is the foundation of dually flat geometry (Amari).
1629/// - **Not an f-divergence**: Bregman divergences are **not** information monotone
1630///   in general. They live in a different branch of the divergence taxonomy.
1631/// - **Examples**: squared Euclidean (\(F = \tfrac{1}{2}\|x\|^2\)) gives
1632///   \(B_F(p,q) = \tfrac{1}{2}\|p - q\|^2\); negative entropy
1633///   (\(F = \sum x_i \ln x_i\)) gives the KL divergence.
1634///
1635/// # Examples
1636///
1637/// ```
1638/// # use logp::{bregman_divergence, SquaredL2};
1639/// // Squared-L2 generator: B_F(p, q) = 0.5 * ||p - q||^2.
1640/// let gen = SquaredL2;
1641/// let p = [1.0, 2.0, 3.0];
1642/// let q = [1.5, 1.5, 2.5];
1643/// let b = bregman_divergence(&gen, &p, &q).unwrap();
1644/// let expected = 0.5 * ((0.5_f64).powi(2) + (0.5_f64).powi(2) + (0.5_f64).powi(2));
1645/// assert!((b - expected).abs() < 1e-12);
1646///
1647/// // B_F(p, p) = 0.
1648/// assert!(bregman_divergence(&gen, &p, &p).unwrap().abs() < 1e-15);
1649/// ```
1650pub fn bregman_divergence(gen: &impl BregmanGenerator, p: &[f64], q: &[f64]) -> Result<f64> {
1651    ensure_nonempty(p)?;
1652    ensure_same_len(p, q)?;
1653    let mut grad_q = vec![0.0; q.len()];
1654    gen.grad_into(q, &mut grad_q)?;
1655    let fp = gen.f(p)?;
1656    let fq = gen.f(q)?;
1657    let mut inner = 0.0;
1658    for i in 0..p.len() {
1659        inner += (p[i] - q[i]) * grad_q[i];
1660    }
1661    Ok(fp - fq - inner)
1662}
1663
1664/// Total Bregman divergence as shown in Nielsen’s taxonomy diagram:
1665///
1666/// \(tB_F(p,q) = \frac{B_F(p,q)}{\sqrt{1 + \|\nabla F(q)\|^2}}\).
1667///
1668/// # Examples
1669///
1670/// ```
1671/// # use logp::{total_bregman_divergence, bregman_divergence, SquaredL2};
1672/// let gen = SquaredL2;
1673/// let p = [1.0, 2.0];
1674/// let q = [3.0, 4.0];
1675///
1676/// let tb = total_bregman_divergence(&gen, &p, &q).unwrap();
1677///
1678/// // Total Bregman <= Bregman (normalization divides by >= 1).
1679/// let b = bregman_divergence(&gen, &p, &q).unwrap();
1680/// assert!(tb <= b + 1e-12);
1681/// ```
1682pub fn total_bregman_divergence(gen: &impl BregmanGenerator, p: &[f64], q: &[f64]) -> Result<f64> {
1683    ensure_nonempty(p)?;
1684    ensure_same_len(p, q)?;
1685    let mut grad_q = vec![0.0; q.len()];
1686    gen.grad_into(q, &mut grad_q)?;
1687    let fp = gen.f(p)?;
1688    let fq = gen.f(q)?;
1689    let mut inner = 0.0;
1690    for i in 0..p.len() {
1691        inner += (p[i] - q[i]) * grad_q[i];
1692    }
1693    let b = fp - fq - inner;
1694    let grad_norm_sq: f64 = grad_q.iter().map(|&x| x * x).sum();
1695    Ok(b / (1.0 + grad_norm_sq).sqrt())
1696}
1697
1698/// Squared Euclidean Bregman generator: \(F(x)=\tfrac12\|x\|_2^2\), \(\nabla F(x)=x\).
1699#[derive(Debug, Clone, Copy, Default)]
1700pub struct SquaredL2;
1701
1702impl BregmanGenerator for SquaredL2 {
1703    fn f(&self, x: &[f64]) -> Result<f64> {
1704        ensure_nonempty(x)?;
1705        Ok(0.5 * x.iter().map(|&v| v * v).sum::<f64>())
1706    }
1707
1708    fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()> {
1709        ensure_nonempty(x)?;
1710        if out.len() != x.len() {
1711            return Err(Error::LengthMismatch(out.len(), x.len()));
1712        }
1713        out.copy_from_slice(x);
1714        Ok(())
1715    }
1716}
1717
1718/// Negative-entropy Bregman generator: \(F(x) = \sum_i x_i \ln x_i\),
1719/// \(\nabla F(x)_i = 1 + \ln x_i\).
1720///
1721/// The Bregman divergence with this generator is the (unnormalized) KL divergence:
1722/// \(B_F(p, q) = \sum_i p_i \ln(p_i / q_i) - \sum_i (p_i - q_i)\).
1723/// When \(p\) and \(q\) are normalized (simplex), the second sum vanishes and
1724/// \(B_F = D_{KL}(p \| q)\).
1725///
1726/// This connects the information-theoretic (f-divergence) and geometric (Bregman)
1727/// views of KL divergence. The dually-flat structure of the probability simplex
1728/// under this generator is the foundation of information geometry (Amari & Nagaoka, 2000).
1729#[derive(Debug, Clone, Copy, Default)]
1730pub struct NegEntropy;
1731
1732impl BregmanGenerator for NegEntropy {
1733    fn f(&self, x: &[f64]) -> Result<f64> {
1734        ensure_nonempty(x)?;
1735        let mut s = 0.0;
1736        for &xi in x {
1737            if xi < 0.0 {
1738                return Err(Error::Domain("NegEntropy: input must be nonnegative"));
1739            }
1740            if xi > 0.0 {
1741                s += xi * xi.ln();
1742            }
1743        }
1744        Ok(s)
1745    }
1746
1747    fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()> {
1748        ensure_nonempty(x)?;
1749        if out.len() != x.len() {
1750            return Err(Error::LengthMismatch(out.len(), x.len()));
1751        }
1752        for (o, &xi) in out.iter_mut().zip(x.iter()) {
1753            if xi <= 0.0 {
1754                return Err(Error::Domain("NegEntropy grad: input must be positive"));
1755            }
1756            *o = 1.0 + xi.ln();
1757        }
1758        Ok(())
1759    }
1760}
1761
1762#[cfg(test)]
1763mod tests {
1764    use super::*;
1765    use proptest::prelude::*;
1766
1767    const TOL: f64 = 1e-9;
1768
1769    fn simplex_vec(len: usize) -> impl Strategy<Value = Vec<f64>> {
1770        // Draw nonnegative weights then normalize.
1771        prop::collection::vec(0.0f64..10.0, len).prop_map(|mut v| {
1772            let s: f64 = v.iter().sum();
1773            if s == 0.0 {
1774                v[0] = 1.0;
1775                return v;
1776            }
1777            for x in v.iter_mut() {
1778                *x /= s;
1779            }
1780            v
1781        })
1782    }
1783
1784    fn simplex_vec_pos(len: usize, eps: f64) -> impl Strategy<Value = Vec<f64>> {
1785        prop::collection::vec(0.0f64..10.0, len).prop_map(move |mut v| {
1786            // Add a small floor to avoid exact zeros (needed for KL-style domains).
1787            for x in v.iter_mut() {
1788                *x += eps;
1789            }
1790            let s: f64 = v.iter().sum();
1791            for x in v.iter_mut() {
1792                *x /= s;
1793            }
1794            v
1795        })
1796    }
1797
1798    fn random_partition(n: usize) -> impl Strategy<Value = Vec<usize>> {
1799        // Partition indices into k buckets (k chosen implicitly).
1800        // We generate a label in [0, n) for each index and later reindex to compact labels.
1801        prop::collection::vec(0usize..n, n).prop_map(|labels| {
1802            // Compress labels to 0..k-1 while preserving equality pattern.
1803            use std::collections::BTreeMap;
1804            let mut map = BTreeMap::<usize, usize>::new();
1805            let mut next = 0usize;
1806            labels
1807                .into_iter()
1808                .map(|l| {
1809                    *map.entry(l).or_insert_with(|| {
1810                        let id = next;
1811                        next += 1;
1812                        id
1813                    })
1814                })
1815                .collect::<Vec<_>>()
1816        })
1817    }
1818
1819    fn coarse_grain(p: &[f64], labels: &[usize]) -> Vec<f64> {
1820        let k = labels.iter().copied().max().unwrap_or(0) + 1;
1821        let mut out = vec![0.0; k];
1822        for (i, &lab) in labels.iter().enumerate() {
1823            out[lab] += p[i];
1824        }
1825        out
1826    }
1827
1828    fn l1(p: &[f64], q: &[f64]) -> f64 {
1829        p.iter().zip(q.iter()).map(|(&a, &b)| (a - b).abs()).sum()
1830    }
1831
    // Sanity check on the unvalidated entropy path (skips simplex checks).
    #[test]
    fn test_entropy_unchecked() {
        let p = [0.5, 0.5];
        let h = entropy_unchecked(&p);
        // -0.5*ln(0.5) - 0.5*ln(0.5) = -ln(0.5) = ln(2)
        assert!((h - core::f64::consts::LN_2).abs() < 1e-12);
    }

    // Disjoint supports attain the JS maximum of ln(2) nats (see module docs).
    #[test]
    fn js_is_bounded_by_ln2() {
        let p = [1.0, 0.0];
        let q = [0.0, 1.0];
        let js = jensen_shannon_divergence(&p, &q, TOL).unwrap();
        assert!(js <= core::f64::consts::LN_2 + 1e-12);
        assert!(js >= 0.0);
    }

    #[test]
    fn mutual_information_independent_is_zero() {
        // p(x,y) = p(x)p(y) ⇒ I(X;Y)=0
        let p_x = [0.5, 0.5];
        let p_y = [0.25, 0.75];
        // Row-major 2x2:
        // [0.125, 0.375,
        //  0.125, 0.375]
        let p_xy = [
            p_x[0] * p_y[0],
            p_x[0] * p_y[1],
            p_x[1] * p_y[0],
            p_x[1] * p_y[1],
        ];
        let mi = mutual_information(&p_xy, 2, 2, TOL).unwrap();
        assert!(mi.abs() < 1e-12, "mi={}", mi);
    }

    #[test]
    fn mutual_information_perfect_correlation_is_ln2() {
        // X=Y uniform bit ⇒ I(X;Y)=ln 2 (nats)
        let p_xy = [0.5, 0.0, 0.0, 0.5]; // 2x2 diagonal
        let mi = mutual_information(&p_xy, 2, 2, TOL).unwrap();
        assert!((mi - core::f64::consts::LN_2).abs() < 1e-12, "mi={}", mi);
    }

    // The SquaredL2 generator must reproduce B_F(p,q) = (1/2)‖p−q‖₂².
    #[test]
    fn bregman_squared_l2_matches_half_l2() {
        let gen = SquaredL2;
        let p = [1.0, 2.0, 3.0];
        let q = [1.5, 1.5, 2.5];
        let b = bregman_divergence(&gen, &p, &q).unwrap();
        let expected = 0.5
            * p.iter()
                .zip(q.iter())
                .map(|(&a, &b)| (a - b) * (a - b))
                .sum::<f64>();
        assert!((b - expected).abs() < 1e-12);
    }
1888
1889    // --- Entropy tests ---
1890
    #[test]
    fn entropy_nats_uniform_is_ln_n() {
        // Uniform distribution over n items: H = ln(n)
        for n in [2, 4, 8, 16] {
            let p: Vec<f64> = vec![1.0 / n as f64; n];
            let h = entropy_nats(&p, TOL).unwrap();
            let expected = (n as f64).ln();
            assert!(
                (h - expected).abs() < 1e-12,
                "n={n}: h={h} expected={expected}"
            );
        }
    }

    // A point mass carries no uncertainty: H = 0.
    #[test]
    fn entropy_nats_singleton_is_zero() {
        let h = entropy_nats(&[1.0], TOL).unwrap();
        assert!(h.abs() < 1e-15);
    }

    // Unit conversion: bits = nats / ln(2).
    #[test]
    fn entropy_bits_converts_correctly() {
        let p = [0.25, 0.75];
        let nats = entropy_nats(&p, TOL).unwrap();
        let bits = entropy_bits(&p, TOL).unwrap();
        assert!((bits - nats / core::f64::consts::LN_2).abs() < 1e-12);
    }

    // --- Cross-entropy tests ---

    // The fundamental decomposition H(p,q) = H(p) + KL(p‖q).
    #[test]
    fn cross_entropy_identity_h_pq_eq_h_p_plus_kl() {
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let h_pq = cross_entropy_nats(&p, &q, TOL).unwrap();
        let h_p = entropy_nats(&p, TOL).unwrap();
        let kl = kl_divergence(&p, &q, TOL).unwrap();
        assert!((h_pq - (h_p + kl)).abs() < 1e-12);
    }

    // q assigns zero mass where p is positive ⇒ the cross-entropy diverges;
    // the API must surface this as an error rather than +inf or NaN.
    #[test]
    fn cross_entropy_rejects_zero_q_with_positive_p() {
        let p = [0.5, 0.5];
        let q = [1.0, 0.0]; // q[1]=0 but p[1]=0.5
        assert!(cross_entropy_nats(&p, &q, TOL).is_err());
    }
1937
1938    // --- Validate / normalize tests ---
1939
1940    #[test]
1941    fn validate_simplex_accepts_valid() {
1942        assert!(validate_simplex(&[0.3, 0.7], TOL).is_ok());
1943        assert!(validate_simplex(&[1.0], TOL).is_ok());
1944    }
1945
1946    #[test]
1947    fn validate_simplex_rejects_bad_sum() {
1948        assert!(validate_simplex(&[0.3, 0.6], TOL).is_err()); // sum=0.9
1949    }
1950
1951    #[test]
1952    fn validate_simplex_rejects_negative() {
1953        assert!(validate_simplex(&[1.5, -0.5], TOL).is_err());
1954    }
1955
1956    #[test]
1957    fn validate_simplex_rejects_empty() {
1958        assert!(validate_simplex(&[], TOL).is_err());
1959    }
1960
1961    #[test]
1962    fn normalize_in_place_works() {
1963        let mut v = vec![2.0, 3.0];
1964        let s = normalize_in_place(&mut v).unwrap();
1965        assert!((s - 5.0).abs() < 1e-12);
1966        assert!((v[0] - 0.4).abs() < 1e-12);
1967        assert!((v[1] - 0.6).abs() < 1e-12);
1968    }
1969
1970    #[test]
1971    fn normalize_in_place_rejects_zero_sum() {
1972        let mut v = vec![0.0, 0.0];
1973        assert!(normalize_in_place(&mut v).is_err());
1974    }
1975
1976    // --- Hellinger / Bhattacharyya tests ---
1977
    // Identity of indiscernibles: H(p, p) = 0.
    #[test]
    fn hellinger_identical_is_zero() {
        let p = [0.25, 0.75];
        let h = hellinger(&p, &p, TOL).unwrap();
        assert!(h.abs() < 1e-12);
    }

    // Squared Hellinger is bounded: 0 <= H^2 <= 1, even for near-disjoint supports.
    #[test]
    fn hellinger_squared_in_unit_interval() {
        let p = [0.1, 0.9];
        let q = [0.9, 0.1];
        let h2 = hellinger_squared(&p, &q, TOL).unwrap();
        assert!((-1e-12..=1.0 + 1e-12).contains(&h2), "h2={h2}");
    }

    // BC(p, p) = sum_i p_i = 1 for any distribution p.
    #[test]
    fn bhattacharyya_coeff_identical_is_one() {
        let p = [0.3, 0.7];
        let bc = bhattacharyya_coeff(&p, &p, TOL).unwrap();
        assert!((bc - 1.0).abs() < 1e-12);
    }

    // BC(p, p) = 1 ⇒ the Bhattacharyya distance of p to itself is 0.
    #[test]
    fn bhattacharyya_distance_identical_is_zero() {
        let p = [0.5, 0.5];
        let d = bhattacharyya_distance(&p, &p, TOL).unwrap();
        assert!(d.abs() < 1e-12);
    }
2006
2007    // --- Renyi / Tsallis tests ---
2008
    #[test]
    fn renyi_alpha_half_on_simple_case() {
        let p = [0.5, 0.5];
        let q = [0.25, 0.75];
        // alpha=0.5 should be well-defined and non-negative
        let r = renyi_divergence(&p, &q, 0.5, TOL).unwrap();
        assert!(r >= -1e-12, "renyi={r}");
    }

    // D(p:p) = 0 must hold for Renyi at any admissible alpha (here alpha=2).
    #[test]
    fn renyi_identical_is_zero() {
        let p = [0.3, 0.7];
        let r = renyi_divergence(&p, &p, 2.0, TOL).unwrap();
        assert!(r.abs() < 1e-12, "renyi(p,p)={r}");
    }

    // Same identity-of-indiscernibles check for the Tsallis family.
    #[test]
    fn tsallis_identical_is_zero() {
        let p = [0.4, 0.6];
        let t = tsallis_divergence(&p, &p, 2.0, TOL).unwrap();
        assert!(t.abs() < 1e-12, "tsallis(p,p)={t}");
    }
2031
2032    // --- Digamma test ---
2033
    #[test]
    fn digamma_at_one_is_neg_euler_mascheroni() {
        let psi1 = digamma(1.0);
        // digamma(1) = -gamma where gamma ~= 0.57721566490153286
        assert!(
            (psi1 - (-0.57721566490153286)).abs() < 1e-12,
            "psi(1)={psi1}"
        );
    }

    // The defining functional equation of digamma; a strong internal-consistency
    // check independent of any particular reference value.
    #[test]
    fn digamma_recurrence_relation() {
        // digamma(x+1) = digamma(x) + 1/x
        for &x in &[1.0, 2.0, 3.5, 10.0] {
            let lhs = digamma(x + 1.0);
            let rhs = digamma(x) + 1.0 / x;
            assert!(
                (lhs - rhs).abs() < 1e-12,
                "recurrence at x={x}: {lhs} vs {rhs}"
            );
        }
    }

    #[test]
    fn pmi_independent_is_zero() {
        // PMI(x,y) = log(p(x,y) / (p(x)*p(y))). If independent: p(x,y) = p(x)*p(y)
        let pmi_val = pmi(0.06, 0.3, 0.2).unwrap(); // 0.3 * 0.2 = 0.06
        assert!(
            pmi_val.abs() < 1e-10,
            "PMI of independent events should be 0: {pmi_val}"
        );
    }

    #[test]
    fn pmi_positive_for_correlated() {
        // If p(x,y) > p(x)*p(y), events are positively correlated
        let pmi_val = pmi(0.4, 0.5, 0.5).unwrap(); // 0.4 > 0.5*0.5 = 0.25
        assert!(
            pmi_val > 0.0,
            "correlated events should have positive PMI: {pmi_val}"
        );
    }
2076
    // The alpha -> 1 limit of Renyi divergence is KL; approach from below.
    #[test]
    fn renyi_approaches_kl_as_alpha_to_one() {
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let tol = 1e-9;
        let kl = kl_divergence(&p, &q, tol).unwrap();
        // Renyi(alpha) -> KL as alpha -> 1
        let r099 = renyi_divergence(&p, &q, 0.99, tol).unwrap();
        let r0999 = renyi_divergence(&p, &q, 0.999, tol).unwrap();
        assert!((r099 - kl).abs() < 0.01, "Renyi(0.99)={r099}, KL={kl}");
        assert!((r0999 - kl).abs() < 0.001, "Renyi(0.999)={r0999}, KL={kl}");
    }

    #[test]
    fn amari_alpha_neg1_is_kl_forward() {
        // Amari alpha=-1 returns KL(p||q) per the implementation
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let tol = 1e-9;
        let kl_pq = kl_divergence(&p, &q, tol).unwrap();
        let amari = amari_alpha_divergence(&p, &q, -1.0, tol).unwrap();
        assert!(
            (amari - kl_pq).abs() < 1e-6,
            "Amari(-1)={amari}, KL(p||q)={kl_pq}"
        );
    }

    #[test]
    fn amari_alpha_pos1_is_kl_reverse() {
        // Amari alpha=+1 returns KL(q||p) per the implementation
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let tol = 1e-9;
        let kl_qp = kl_divergence(&q, &p, tol).unwrap();
        let amari = amari_alpha_divergence(&p, &q, 1.0, tol).unwrap();
        assert!(
            (amari - kl_qp).abs() < 1e-6,
            "Amari(1)={amari}, KL(q||p)={kl_qp}"
        );
    }

    // KL is the Csiszar f-divergence with generator f(t) = t ln t; the two
    // code paths must agree.
    #[test]
    fn csiszar_with_kl_generator_matches_kl() {
        // f(t) = t*ln(t) gives KL divergence
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let tol = 1e-9;
        let kl = kl_divergence(&p, &q, tol).unwrap();
        let cs = csiszar_f_divergence(&p, &q, |t| t * t.ln(), tol).unwrap();
        assert!((cs - kl).abs() < 1e-6, "Csiszar(t*ln(t))={cs}, KL={kl}");
    }

    #[test]
    fn mutual_information_deterministic_equals_entropy() {
        // If Y = f(X), MI(X;Y) = H(X)
        // Joint: p(x=0,y=0)=0.3, p(x=1,y=1)=0.7
        let p_xy = [0.3, 0.0, 0.0, 0.7]; // 2x2 joint
        let mi = mutual_information(&p_xy, 2, 2, 1e-9).unwrap();
        let h_x = entropy_nats(&[0.3, 0.7], 1e-9).unwrap();
        assert!((mi - h_x).abs() < 1e-6, "MI={mi}, H(X)={h_x}");
    }
2138
    // Property-based invariants over randomly drawn simplex points.
    proptest! {
        // Gibbs' inequality: KL(p||q) >= 0 for any p, q on the simplex.
        #[test]
        fn kl_is_nonnegative(p in simplex_vec_pos(8, 1e-6), q in simplex_vec_pos(8, 1e-6)) {
            let d = kl_divergence(&p, &q, 1e-6).unwrap();
            prop_assert!(d >= -1e-12);
        }

        // 0 <= JS(p,q) <= ln 2 (nats) — the key invariant from the module docs.
        #[test]
        fn js_is_bounded(p in simplex_vec(16), q in simplex_vec(16)) {
            let js = jensen_shannon_divergence(&p, &q, 1e-6).unwrap();
            prop_assert!(js >= -1e-12);
            prop_assert!(js <= core::f64::consts::LN_2 + 1e-9);
        }

        #[test]
        fn prop_kl_gaussians_is_nonnegative(
            mu1 in prop::collection::vec(-10.0f64..10.0, 1..16),
            std1 in prop::collection::vec(0.1f64..5.0, 1..16),
            mu2 in prop::collection::vec(-10.0f64..10.0, 1..16),
            std2 in prop::collection::vec(0.1f64..5.0, 1..16),
        ) {
            // The four vectors are drawn independently, so trim all of them
            // to the common length before calling.
            let n = mu1.len().min(std1.len()).min(mu2.len()).min(std2.len());
            let d = kl_divergence_gaussians(&mu1[..n], &std1[..n], &mu2[..n], &std2[..n]).unwrap();
            // KL divergence is always non-negative.
            prop_assert!(d >= -1e-12);
        }

        #[test]
        fn prop_kl_gaussians_is_zero_for_identical(
            mu in prop::collection::vec(-10.0f64..10.0, 1..16),
            std in prop::collection::vec(0.1f64..5.0, 1..16),
        ) {
            let n = mu.len().min(std.len());
            let d = kl_divergence_gaussians(&mu[..n], &std[..n], &mu[..n], &std[..n]).unwrap();
            prop_assert!(d.abs() < 1e-12);
        }

        // Information monotonicity (Csiszar 1967): merging bins via a Markov
        // coarse-graining cannot increase an f-divergence.
        #[test]
        fn f_divergence_monotone_under_coarse_graining(
            p in simplex_vec_pos(12, 1e-6),
            q in simplex_vec_pos(12, 1e-6),
            labels in random_partition(12),
        ) {
            // Use KL as an f-divergence instance: f(t)=t ln t.
            // D_KL(p||q) = Σ q_i f(p_i/q_i).
            let f = |t: f64| if t == 0.0 { 0.0 } else { t * t.ln() };
            let d_f = csiszar_f_divergence(&p, &q, f, 1e-6).unwrap();

            let pc = coarse_grain(&p, &labels);
            let qc = coarse_grain(&q, &labels);
            let d_fc = csiszar_f_divergence(&pc, &qc, f, 1e-6).unwrap();

            // Coarse graining should not increase.
            prop_assert!(d_fc <= d_f + 1e-9);
        }
    }
2195
    // Heavier “theorem-ish” checks: keep case count modest so `cargo test` stays fast.
    proptest! {
        #![proptest_config(ProptestConfig { cases: 64, .. ProptestConfig::default() })]

        #[test]
        fn pinsker_kl_lower_bounds_l1_squared(
            p in simplex_vec_pos(16, 1e-6),
            q in simplex_vec_pos(16, 1e-6),
        ) {
            // Pinsker: TV(p,q)^2 <= (1/2) KL(p||q)
            // where TV = (1/2)||p-q||_1. Rearranged: KL(p||q) >= 0.5 * ||p-q||_1^2.
            let kl = kl_divergence(&p, &q, 1e-6).unwrap();
            let d1 = l1(&p, &q);
            prop_assert!(kl + 1e-9 >= 0.5 * d1 * d1, "kl={kl} l1={d1}");
        }

        #[test]
        fn sqrt_js_satisfies_triangle_inequality(
            p in simplex_vec(12),
            q in simplex_vec(12),
            r in simplex_vec(12),
        ) {
            // Known fact: sqrt(JS) is a metric on the simplex.
            // `.max(0.0)` guards against a tiny negative JS from floating-point
            // before taking the square root.
            let js_pq = jensen_shannon_divergence(&p, &q, 1e-6).unwrap().max(0.0).sqrt();
            let js_qr = jensen_shannon_divergence(&q, &r, 1e-6).unwrap().max(0.0).sqrt();
            let js_pr = jensen_shannon_divergence(&p, &r, 1e-6).unwrap().max(0.0).sqrt();
            prop_assert!(js_pr <= js_pq + js_qr + 1e-7, "js_pr={js_pr} js_pq+js_qr={}", js_pq+js_qr);
        }

        // I(X;Y) = KL(p_xy || p_x ⊗ p_y): the dedicated MI routine must agree
        // with the definitional KL-to-product-of-marginals computation.
        #[test]
        fn mutual_information_equals_kl_to_product(
            // Ensure strictly positive so KL domains are satisfied.
            p_xy in simplex_vec_pos(16, 1e-6),
            nx in 2usize..=4,
            ny in 2usize..=4,
        ) {
            // We need p_xy to have length nx*ny; we will truncate/renormalize a fixed-length draw.
            let n = nx * ny;
            let mut joint = p_xy;
            joint.truncate(n);
            // Renormalize after truncation.
            let _ = normalize_in_place(&mut joint).unwrap();

            // Compute MI via the dedicated function.
            let mi = mutual_information(&joint, nx, ny, 1e-6).unwrap();

            // Compute product of marginals and KL(joint || product).
            let mut p_x = vec![0.0; nx];
            let mut p_y = vec![0.0; ny];
            for i in 0..nx {
                for j in 0..ny {
                    let p = joint[i * ny + j];
                    p_x[i] += p;
                    p_y[j] += p;
                }
            }
            let mut prod = vec![0.0; n];
            for i in 0..nx {
                for j in 0..ny {
                    prod[i * ny + j] = p_x[i] * p_y[j];
                }
            }
            let kl = kl_divergence(&joint, &prod, 1e-6).unwrap();

            prop_assert!((mi - kl).abs() < 1e-9, "mi={mi} kl={kl}");
        }

        // Hellinger is a genuine metric (unlike most divergences here), so the
        // triangle inequality must hold.
        #[test]
        fn hellinger_satisfies_triangle_inequality(
            p in simplex_vec(8),
            q in simplex_vec(8),
            r in simplex_vec(8),
        ) {
            let h_pq = hellinger(&p, &q, 1e-6).unwrap();
            let h_qr = hellinger(&q, &r, 1e-6).unwrap();
            let h_pr = hellinger(&p, &r, 1e-6).unwrap();
            prop_assert!(h_pr <= h_pq + h_qr + 1e-7, "h_pr={h_pr} h_pq+h_qr={}", h_pq + h_qr);
        }
    }
2275
2276    // --- total_bregman_divergence ---
2277
    #[test]
    fn total_bregman_le_bregman() {
        // tB_F(p, q) <= B_F(p, q) because the denominator sqrt(1 + ||grad||^2) >= 1.
        let gen = SquaredL2;
        let p = [1.0, 2.0, 3.0];
        let q = [4.0, 5.0, 6.0];
        let b = bregman_divergence(&gen, &p, &q).unwrap();
        let tb = total_bregman_divergence(&gen, &p, &q).unwrap();
        assert!(tb <= b + 1e-12, "total_bregman={tb} > bregman={b}");
        assert!(tb >= 0.0);
    }

    // B_F(p, p) = 0, and dividing by the positive normalizer keeps it 0.
    #[test]
    fn total_bregman_is_zero_for_identical() {
        let gen = SquaredL2;
        let p = [1.0, 2.0];
        let tb = total_bregman_divergence(&gen, &p, &p).unwrap();
        assert!(tb.abs() < 1e-15);
    }

    // --- rho_alpha ---

    // rho_alpha(p, p) = 1 across positive, zero, unit, and negative alpha.
    #[test]
    fn rho_alpha_self_is_one() {
        let p = [0.1, 0.2, 0.3, 0.4];
        for alpha in [0.0, 0.25, 0.5, 0.75, 1.0, 2.0, -1.0] {
            let r = rho_alpha(&p, &p, alpha, TOL).unwrap();
            assert!((r - 1.0).abs() < 1e-10, "rho_alpha(p,p,{alpha})={r}");
        }
    }
2308
2309    // --- digamma negative domain ---
2310
    // digamma has poles at the non-positive integers; this implementation maps
    // them to NaN. NOTE(review): only integer arguments are exercised here —
    // behavior for non-integer negatives (where psi is finite) is not pinned
    // by this test; confirm against the digamma implementation.
    #[test]
    fn digamma_nonpositive_is_nan() {
        assert!(digamma(0.0).is_nan());
        assert!(digamma(-1.0).is_nan());
        assert!(digamma(-100.0).is_nan());
    }

    // --- pmi edge cases ---

    // p(x,y) = 0 with positive marginals: PMI would be -inf; the API returns 0.
    #[test]
    fn pmi_zero_joint_returns_zero() {
        assert_eq!(pmi(0.0, 0.5, 0.5).unwrap(), 0.0);
    }

    #[test]
    fn pmi_zero_marginal_with_zero_joint_returns_zero() {
        // When px or py is zero AND pxy is also zero, return 0 by convention.
        assert_eq!(pmi(0.0, 0.0, 0.5).unwrap(), 0.0);
        assert_eq!(pmi(0.0, 0.5, 0.0).unwrap(), 0.0);
    }

    // Fully degenerate input also falls back to the 0-by-convention path.
    #[test]
    fn pmi_all_zero_returns_zero() {
        assert_eq!(pmi(0.0, 0.0, 0.0).unwrap(), 0.0);
    }
2336
2337    // --- Enrich-motivated tests ---
2338
    // Closed-form digamma values (DLMF §5.4): psi(n) = H_{n-1} - gamma for
    // integer n, and psi(1/2) = -gamma - 2 ln 2.
    #[test]
    fn digamma_at_dlmf_reference_values() {
        // psi(0.5) = -gamma - 2*ln(2)
        let gamma = 0.57721566490153286;
        let expected_half = -gamma - 2.0 * core::f64::consts::LN_2;
        let psi_half = digamma(0.5);
        assert!(
            (psi_half - expected_half).abs() < 1e-12,
            "psi(0.5)={psi_half} expected={expected_half}"
        );

        // psi(2) = 1 - gamma
        let expected_2 = 1.0 - gamma;
        let psi_2 = digamma(2.0);
        assert!(
            (psi_2 - expected_2).abs() < 1e-12,
            "psi(2)={psi_2} expected={expected_2}"
        );

        // psi(3) = 1 + 1/2 - gamma
        let expected_3 = 1.5 - gamma;
        let psi_3 = digamma(3.0);
        assert!(
            (psi_3 - expected_3).abs() < 1e-12,
            "psi(3)={psi_3} expected={expected_3}"
        );

        // psi(4) = 1 + 1/2 + 1/3 - gamma
        let expected_4 = 1.0 + 0.5 + 1.0 / 3.0 - gamma;
        let psi_4 = digamma(4.0);
        assert!(
            (psi_4 - expected_4).abs() < 1e-12,
            "psi(4)={psi_4} expected={expected_4}"
        );
    }
2374
    #[test]
    fn tsallis_approaches_kl_as_alpha_to_one() {
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let tol = 1e-9;
        let kl = kl_divergence(&p, &q, tol).unwrap();
        // Tsallis(alpha) -> KL as alpha -> 1 (from both sides)
        let t099 = tsallis_divergence(&p, &q, 0.99, tol).unwrap();
        let t0999 = tsallis_divergence(&p, &q, 0.999, tol).unwrap();
        let t101 = tsallis_divergence(&p, &q, 1.01, tol).unwrap();
        let t1001 = tsallis_divergence(&p, &q, 1.001, tol).unwrap();
        assert!((t099 - kl).abs() < 0.01, "Tsallis(0.99)={t099}, KL={kl}");
        assert!(
            (t0999 - kl).abs() < 0.001,
            "Tsallis(0.999)={t0999}, KL={kl}"
        );
        assert!((t101 - kl).abs() < 0.01, "Tsallis(1.01)={t101}, KL={kl}");
        assert!(
            (t1001 - kl).abs() < 0.001,
            "Tsallis(1.001)={t1001}, KL={kl}"
        );
    }

    // Complements renyi_approaches_kl_as_alpha_to_one: the limit also holds
    // approaching alpha = 1 from above.
    #[test]
    fn renyi_approaches_kl_from_above() {
        let p = [0.3, 0.7];
        let q = [0.5, 0.5];
        let tol = 1e-9;
        let kl = kl_divergence(&p, &q, tol).unwrap();
        let r101 = renyi_divergence(&p, &q, 1.01, tol).unwrap();
        let r1001 = renyi_divergence(&p, &q, 1.001, tol).unwrap();
        assert!((r101 - kl).abs() < 0.01, "Renyi(1.01)={r101}, KL={kl}");
        assert!((r1001 - kl).abs() < 0.001, "Renyi(1.001)={r1001}, KL={kl}");
    }

    // Cross-family identity linking Renyi at alpha=1/2 to the Bhattacharyya
    // coefficient, on an asymmetric pair (not just p = q).
    #[test]
    fn renyi_at_half_equals_neg2_ln_bc() {
        // D_{1/2}^R(p || q) = -2 * ln(BC(p, q))
        let p = [0.2, 0.3, 0.5];
        let q = [0.4, 0.4, 0.2];
        let tol = 1e-9;
        let renyi_half = renyi_divergence(&p, &q, 0.5, tol).unwrap();
        let bc = bhattacharyya_coeff(&p, &q, tol).unwrap();
        let expected = -2.0 * bc.ln();
        assert!(
            (renyi_half - expected).abs() < 1e-10,
            "Renyi(0.5)={renyi_half}, -2*ln(BC)={expected}"
        );
    }
2424
    // Algebraic identity H^2(p, q) = 1 - BC(p, q) on an asymmetric pair.
    #[test]
    fn hellinger_squared_equals_one_minus_bc() {
        let p = [0.1, 0.4, 0.5];
        let q = [0.3, 0.3, 0.4];
        let tol = 1e-9;
        let h2 = hellinger_squared(&p, &q, tol).unwrap();
        let bc = bhattacharyya_coeff(&p, &q, tol).unwrap();
        assert!(
            (h2 - (1.0 - bc)).abs() < 1e-12,
            "H^2={h2}, 1-BC={}",
            1.0 - bc
        );
    }

    #[test]
    fn csiszar_hellinger_generator_matches_twice_hellinger_squared() {
        // f(t) = (sqrt(t) - 1)^2 gives sum(q_i*(sqrt(p_i/q_i)-1)^2)
        // = sum((sqrt(p_i) - sqrt(q_i))^2) = 2*(1 - BC) = 2*H^2.
        let p = [0.2, 0.3, 0.5];
        let q = [0.4, 0.4, 0.2];
        let tol = 1e-9;
        let h2 = hellinger_squared(&p, &q, tol).unwrap();
        let cs = csiszar_f_divergence(&p, &q, |t| (t.sqrt() - 1.0).powi(2), tol).unwrap();
        assert!(
            (cs - 2.0 * h2).abs() < 1e-10,
            "Csiszar(Hellinger)={cs}, 2*H^2={}",
            2.0 * h2
        );
    }

    // Chi-squared as a Csiszar instance: non-negative, and zero at p = q.
    #[test]
    fn csiszar_chi_squared_generator_is_nonneg() {
        // f(t) = (t - 1)^2 gives chi-squared divergence.
        let p = [0.2, 0.3, 0.5];
        let q = [0.4, 0.4, 0.2];
        let tol = 1e-9;
        let chi2 = csiszar_f_divergence(&p, &q, |t| (t - 1.0).powi(2), tol).unwrap();
        assert!(chi2 >= 0.0, "chi2={chi2}");
        // Chi-squared(p,p) = 0
        let chi2_self = csiszar_f_divergence(&p, &p, |t| (t - 1.0).powi(2), tol).unwrap();
        assert!(chi2_self.abs() < 1e-12, "chi2(p,p)={chi2_self}");
    }
2467
    // Numerical-robustness smoke test near the boundary of the simplex:
    // no NaN/inf should leak out of any measure.
    #[test]
    fn near_boundary_inputs_no_nan() {
        // Entries of 1e-300 are extreme but still normal f64 values
        // (min normal ~ 2.2e-308); note this is far BELOW machine epsilon
        // (~2.2e-16), stressing log/sqrt paths near zero.
        let tiny = 1e-300;
        let p = [tiny, 1.0 - tiny];
        let q = [tiny * 2.0, 1.0 - tiny * 2.0];
        let tol = 1e-6;

        let kl = kl_divergence(&p, &q, tol).unwrap();
        assert!(kl.is_finite(), "kl={kl}");
        // Allow tiny negative from floating-point at extreme values.
        assert!(kl >= -1e-12, "kl negative: {kl}");

        let js = jensen_shannon_divergence(&p, &q, tol).unwrap();
        assert!(js.is_finite(), "js={js}");

        let h = hellinger(&p, &q, tol).unwrap();
        assert!(h.is_finite(), "hellinger={h}");

        let bc = bhattacharyya_coeff(&p, &q, tol).unwrap();
        assert!(bc.is_finite(), "bc={bc}");

        let ent = entropy_nats(&p, tol).unwrap();
        assert!(ent.is_finite(), "entropy={ent}");
    }
2493
2494    proptest! {
2495        #![proptest_config(ProptestConfig { cases: 64, .. ProptestConfig::default() })]
2496
2497        #[test]
2498        fn entropy_is_concave(
2499            p in simplex_vec(8),
2500            q in simplex_vec(8),
2501            lambda in 0.0f64..=1.0,
2502        ) {
2503            // H(lambda*p + (1-lambda)*q) >= lambda*H(p) + (1-lambda)*H(q)
2504            let mix: Vec<f64> = p.iter().zip(q.iter())
2505                .map(|(&pi, &qi)| lambda * pi + (1.0 - lambda) * qi)
2506                .collect();
2507            let h_mix = entropy_nats(&mix, 1e-6).unwrap();
2508            let h_p = entropy_nats(&p, 1e-6).unwrap();
2509            let h_q = entropy_nats(&q, 1e-6).unwrap();
2510            let rhs = lambda * h_p + (1.0 - lambda) * h_q;
2511            prop_assert!(h_mix + 1e-10 >= rhs, "h_mix={h_mix} rhs={rhs}");
2512        }
2513
2514        #[test]
2515        fn renyi_monotone_in_alpha(
2516            p in simplex_vec_pos(8, 1e-6),
2517            q in simplex_vec_pos(8, 1e-6),
2518        ) {
2519            // Renyi divergence is non-decreasing in alpha for fixed p, q.
2520            let alphas = [0.1, 0.25, 0.5, 0.75, 0.99];
2521            let vals: Vec<f64> = alphas.iter()
2522                .map(|&a| renyi_divergence(&p, &q, a, 1e-6).unwrap())
2523                .collect();
2524            for i in 1..vals.len() {
2525                prop_assert!(
2526                    vals[i] + 1e-9 >= vals[i - 1],
2527                    "Renyi not monotone: D({})={} < D({})={}",
2528                    alphas[i], vals[i], alphas[i - 1], vals[i - 1]
2529                );
2530            }
2531        }
2532
2533        #[test]
2534        fn cross_entropy_decomposition(
2535            p in simplex_vec_pos(8, 1e-6),
2536            q in simplex_vec_pos(8, 1e-6),
2537        ) {
2538            // H(p, q) = H(p) + KL(p || q)
2539            let h_pq = cross_entropy_nats(&p, &q, 1e-6).unwrap();
2540            let h_p = entropy_nats(&p, 1e-6).unwrap();
2541            let kl = kl_divergence(&p, &q, 1e-6).unwrap();
2542            prop_assert!(
2543                (h_pq - (h_p + kl)).abs() < 1e-9,
2544                "H(p,q)={h_pq} != H(p)+KL={}", h_p + kl
2545            );
2546        }
2547
2548        #[test]
2549        fn bhattacharyya_renyi_consistency(
2550            p in simplex_vec_pos(8, 1e-6),
2551            q in simplex_vec_pos(8, 1e-6),
2552        ) {
2553            // D_{1/2}^R(p || q) = -2 * ln(BC(p, q))
2554            let renyi_half = renyi_divergence(&p, &q, 0.5, 1e-6).unwrap();
2555            let bc = bhattacharyya_coeff(&p, &q, 1e-6).unwrap();
2556            let expected = -2.0 * bc.ln();
2557            prop_assert!(
2558                (renyi_half - expected).abs() < 1e-8,
2559                "Renyi(0.5)={renyi_half}, -2*ln(BC)={expected}"
2560            );
2561        }
2562
2563        #[test]
2564        fn csiszar_hellinger_consistency(
2565            p in simplex_vec_pos(8, 1e-6),
2566            q in simplex_vec_pos(8, 1e-6),
2567        ) {
2568            // Csiszar with f(t)=(sqrt(t)-1)^2 = 2 * squared Hellinger
2569            let h2 = hellinger_squared(&p, &q, 1e-6).unwrap();
2570            let cs = csiszar_f_divergence(&p, &q, |t| (t.sqrt() - 1.0).powi(2), 1e-6).unwrap();
2571            prop_assert!(
2572                (cs - 2.0 * h2).abs() < 1e-8,
2573                "Csiszar(Hellinger)={cs}, 2*H^2={}", 2.0 * h2
2574            );
2575        }
2576
2577        #[test]
2578        fn pinsker_tightness_for_nearby_distributions(
2579            p in simplex_vec_pos(8, 1e-6),
2580        ) {
2581            // For a distribution slightly perturbed from p, check Pinsker is non-vacuous.
2582            // q = 0.99*p + 0.01*uniform (epsilon-perturbation).
2583            let n = p.len();
2584            let q: Vec<f64> = p.iter().map(|&pi| 0.99 * pi + 0.01 / n as f64).collect();
2585            let kl = kl_divergence(&p, &q, 1e-6).unwrap();
2586            let d1: f64 = p.iter().zip(q.iter()).map(|(&a, &b)| (a - b).abs()).sum();
2587            let pinsker_rhs = 0.5 * d1 * d1;
2588            // Pinsker: KL >= 0.5 * L1^2
2589            prop_assert!(kl + 1e-12 >= pinsker_rhs, "kl={kl} pinsker_rhs={pinsker_rhs}");
2590            // Non-vacuity: for nearby distributions, KL / (0.5*L1^2) should be O(1),
2591            // not astronomically large (would indicate the bound is useful).
2592            if pinsker_rhs > 1e-20 {
2593                let ratio = kl / pinsker_rhs;
2594                prop_assert!(ratio < 1000.0, "Pinsker ratio too large: {ratio}");
2595            }
2596        }
2597
2598        #[test]
2599        fn total_variation_satisfies_triangle(
2600            p in simplex_vec(8),
2601            q in simplex_vec(8),
2602            r in simplex_vec(8),
2603        ) {
2604            let tv_pq = total_variation(&p, &q, 1e-6).unwrap();
2605            let tv_qr = total_variation(&q, &r, 1e-6).unwrap();
2606            let tv_pr = total_variation(&p, &r, 1e-6).unwrap();
2607            prop_assert!(tv_pr <= tv_pq + tv_qr + 1e-10);
2608        }
2609
2610        #[test]
2611        fn chi_squared_matches_csiszar(
2612            p in simplex_vec_pos(8, 1e-6),
2613            q in simplex_vec_pos(8, 1e-6),
2614        ) {
2615            let chi2 = chi_squared_divergence(&p, &q, 1e-6).unwrap();
2616            let cs = csiszar_f_divergence(&p, &q, |t| (t - 1.0).powi(2), 1e-6).unwrap();
2617            prop_assert!(
2618                (chi2 - cs).abs() < 1e-8,
2619                "chi2={chi2}, csiszar={cs}"
2620            );
2621        }
2622
2623        #[test]
2624        fn renyi_entropy_monotone_in_alpha(
2625            p in simplex_vec_pos(8, 1e-6),
2626        ) {
2627            // H_alpha is non-increasing in alpha.
2628            let alphas = [0.1, 0.25, 0.5, 0.75, 0.99];
2629            let vals: Vec<f64> = alphas.iter()
2630                .map(|&a| renyi_entropy(&p, a, 1e-6).unwrap())
2631                .collect();
2632            for i in 1..vals.len() {
2633                prop_assert!(
2634                    vals[i] <= vals[i - 1] + 1e-9,
2635                    "H_alpha not monotone: H({})={} > H({})={}",
2636                    alphas[i], vals[i], alphas[i - 1], vals[i - 1]
2637                );
2638            }
2639        }
2640    }
2641
2642    // --- New function unit tests ---
2643
2644    #[test]
2645    fn weighted_js_at_extreme_weights() {
2646        let p = [0.3, 0.7];
2647        let q = [0.5, 0.5];
2648        // pi1 = 0 means m = q, KL(q || q) = 0
2649        let js0 = jensen_shannon_weighted(&p, &q, 0.0, TOL).unwrap();
2650        assert!(js0.abs() < 1e-12, "JS(pi=0)={js0}");
2651    }
2652
2653    #[test]
2654    fn conditional_entropy_chain_rule() {
2655        // H(X|Y) = H(X,Y) - H(Y)
2656        let p_xy = [0.2, 0.1, 0.3, 0.4]; // 2x2
2657        let h_xy = entropy_nats(&p_xy, TOL).unwrap();
2658        let p_y = [p_xy[0] + p_xy[2], p_xy[1] + p_xy[3]];
2659        let h_y = entropy_nats(&p_y, TOL).unwrap();
2660        let h_x_given_y = conditional_entropy(&p_xy, 2, 2, TOL).unwrap();
2661        assert!(
2662            (h_x_given_y - (h_xy - h_y)).abs() < 1e-10,
2663            "H(X|Y)={h_x_given_y}, H(X,Y)-H(Y)={}",
2664            h_xy - h_y
2665        );
2666    }
2667
2668    #[test]
2669    fn conditional_entropy_nonnegative() {
2670        let p_xy = [0.1, 0.2, 0.3, 0.4];
2671        let h = conditional_entropy(&p_xy, 2, 2, TOL).unwrap();
2672        assert!(h >= -1e-12, "H(X|Y) negative: {h}");
2673    }
2674
2675    #[test]
2676    fn nmi_bounds() {
2677        // NMI for a general joint: should be in [0, 1].
2678        let p_xy = [0.1, 0.2, 0.3, 0.4];
2679        let nmi = normalized_mutual_information(&p_xy, 2, 2, TOL).unwrap();
2680        assert!((-1e-12..=1.0 + 1e-12).contains(&nmi), "nmi={nmi}");
2681    }
2682
2683    #[test]
2684    fn total_variation_self_is_zero() {
2685        let p = [0.3, 0.7];
2686        assert!(total_variation(&p, &p, TOL).unwrap().abs() < 1e-15);
2687    }
2688
2689    #[test]
2690    fn total_variation_disjoint_is_one() {
2691        let a = [1.0, 0.0];
2692        let b = [0.0, 1.0];
2693        assert!((total_variation(&a, &b, TOL).unwrap() - 1.0).abs() < 1e-12);
2694    }
2695
2696    #[test]
2697    fn chi_squared_self_is_zero() {
2698        let p = [0.3, 0.7];
2699        assert!(chi_squared_divergence(&p, &p, TOL).unwrap().abs() < 1e-15);
2700    }
2701
2702    #[test]
2703    fn chi_squared_upper_bounds_kl() {
2704        // KL(p||q) <= ln(1 + chi2(p||q))
2705        let p = [0.2, 0.3, 0.5];
2706        let q = [0.4, 0.4, 0.2];
2707        let kl = kl_divergence(&p, &q, TOL).unwrap();
2708        let chi2 = chi_squared_divergence(&p, &q, TOL).unwrap();
2709        assert!(
2710            kl <= (1.0 + chi2).ln() + 1e-10,
2711            "kl={kl} > ln(1+chi2)={}",
2712            (1.0 + chi2).ln()
2713        );
2714    }
2715
2716    #[test]
2717    fn renyi_entropy_uniform_is_ln_n() {
2718        let p = [0.25, 0.25, 0.25, 0.25];
2719        for alpha in [0.5, 2.0, 3.0, 10.0] {
2720            let h = renyi_entropy(&p, alpha, TOL).unwrap();
2721            let expected = 4.0_f64.ln();
2722            assert!(
2723                (h - expected).abs() < 1e-12,
2724                "H_{alpha}(uniform) = {h}, expected {expected}"
2725            );
2726        }
2727    }
2728
2729    #[test]
2730    fn tsallis_entropy_delta_is_zero() {
2731        let delta = [1.0, 0.0, 0.0];
2732        for alpha in [0.5, 2.0, 3.0] {
2733            let s = tsallis_entropy(&delta, alpha, TOL).unwrap();
2734            assert!(s.abs() < 1e-12, "Tsallis({alpha}) of delta = {s}");
2735        }
2736    }
2737
2738    #[test]
2739    fn renyi_entropy_collision() {
2740        // H_2(p) = -ln(sum(p_i^2))
2741        let p = [0.3, 0.7];
2742        let h2 = renyi_entropy(&p, 2.0, TOL).unwrap();
2743        let expected = -(0.3_f64.powi(2) + 0.7_f64.powi(2)).ln();
2744        assert!(
2745            (h2 - expected).abs() < 1e-12,
2746            "H_2={h2} expected={expected}"
2747        );
2748    }
2749
2750    #[test]
2751    fn neg_entropy_bregman_matches_kl_on_simplex() {
2752        // For normalized p, q: B_NegEntropy(p, q) = KL(p || q).
2753        let p = [0.2, 0.3, 0.5];
2754        let q = [0.4, 0.4, 0.2];
2755        let kl = kl_divergence(&p, &q, TOL).unwrap();
2756        let gen = NegEntropy;
2757        let breg = bregman_divergence(&gen, &p, &q).unwrap();
2758        assert!(
2759            (breg - kl).abs() < 1e-10,
2760            "Bregman(NegEntropy)={breg}, KL={kl}"
2761        );
2762    }
2763
2764    #[test]
2765    fn neg_entropy_bregman_self_is_zero() {
2766        let p = [0.3, 0.7];
2767        let gen = NegEntropy;
2768        let breg = bregman_divergence(&gen, &p, &p).unwrap();
2769        assert!(breg.abs() < 1e-14, "Bregman(p,p)={breg}");
2770    }
2771
2772    // --- Streaming log-sum-exp tests ---
2773
2774    #[test]
2775    fn log_sum_exp_iter_matches_slice() {
2776        let values = [1.0, 2.0, 3.0, -1.0, 0.5];
2777        let lse_slice = log_sum_exp(&values);
2778        let lse_iter = log_sum_exp_iter(values.iter().copied());
2779        assert!(
2780            (lse_slice - lse_iter).abs() < 1e-12,
2781            "slice={lse_slice} iter={lse_iter}"
2782        );
2783    }
2784
2785    #[test]
2786    fn log_sum_exp_iter_empty() {
2787        assert_eq!(log_sum_exp_iter(std::iter::empty()), f64::NEG_INFINITY);
2788    }
2789
2790    #[test]
2791    fn log_sum_exp_iter_single() {
2792        assert_eq!(log_sum_exp_iter(std::iter::once(42.0)), 42.0);
2793    }
2794
2795    #[test]
2796    fn log_sum_exp_iter_large_values() {
2797        // Same stability test as log_sum_exp: dominated term.
2798        let lse = log_sum_exp_iter([1000.0, 0.0].iter().copied());
2799        assert!((lse - 1000.0).abs() < 1e-10);
2800    }
2801
2802    // --- Data processing inequality for discrete MI ---
2803
2804    #[test]
2805    fn data_processing_inequality_mi() {
2806        // Apply a deterministic coarse-graining (Markov kernel) to Y.
2807        // MI(X; f(Y)) <= MI(X; Y).
2808        let p_xy = [0.3, 0.1, 0.05, 0.05, 0.1, 0.2, 0.05, 0.15]; // 2x4
2809        let n_x = 2;
2810        let n_y = 4;
2811        let mi_full = mutual_information(&p_xy, n_x, n_y, TOL).unwrap();
2812
2813        // Coarse-grain Y: merge bins 0+1 and 2+3 -> 2x2.
2814        let mut p_coarse = [0.0; 4]; // 2x2
2815        for i in 0..n_x {
2816            p_coarse[i * 2] = p_xy[i * n_y] + p_xy[i * n_y + 1];
2817            p_coarse[i * 2 + 1] = p_xy[i * n_y + 2] + p_xy[i * n_y + 3];
2818        }
2819        let mi_coarse = mutual_information(&p_coarse, n_x, 2, TOL).unwrap();
2820
2821        assert!(
2822            mi_coarse <= mi_full + 1e-10,
2823            "DPI violated: MI(coarse)={mi_coarse} > MI(full)={mi_full}"
2824        );
2825    }
2826
2827    // --- Weighted JS additional tests ---
2828
2829    #[test]
2830    fn weighted_js_bounded_by_entropy_of_weights() {
2831        // JS_pi(p, q) <= H(pi) = -pi1*ln(pi1) - pi2*ln(pi2)
2832        let p = [0.1, 0.9];
2833        let q = [0.9, 0.1];
2834        let pi1 = 0.3;
2835        let jsw = jensen_shannon_weighted(&p, &q, pi1, TOL).unwrap();
2836        let pi2 = 1.0 - pi1;
2837        let h_pi = -(pi1 * pi1.ln() + pi2 * pi2.ln());
2838        assert!(jsw <= h_pi + 1e-10, "JS_pi={jsw} > H(pi)={h_pi}");
2839    }
2840
2841    // --- Renyi/Tsallis entropy approach Shannon ---
2842
2843    #[test]
2844    fn renyi_entropy_approaches_shannon_as_alpha_to_one() {
2845        let p = [0.2, 0.3, 0.5];
2846        let h_shannon = entropy_nats(&p, TOL).unwrap();
2847        let h_099 = renyi_entropy(&p, 0.99, TOL).unwrap();
2848        let h_0999 = renyi_entropy(&p, 0.999, TOL).unwrap();
2849        let h_101 = renyi_entropy(&p, 1.01, TOL).unwrap();
2850        assert!((h_099 - h_shannon).abs() < 0.01);
2851        assert!((h_0999 - h_shannon).abs() < 0.001);
2852        assert!((h_101 - h_shannon).abs() < 0.01);
2853    }
2854
2855    #[test]
2856    fn tsallis_entropy_approaches_shannon_as_alpha_to_one() {
2857        let p = [0.2, 0.3, 0.5];
2858        let h_shannon = entropy_nats(&p, TOL).unwrap();
2859        let s_099 = tsallis_entropy(&p, 0.99, TOL).unwrap();
2860        let s_0999 = tsallis_entropy(&p, 0.999, TOL).unwrap();
2861        assert!((s_099 - h_shannon).abs() < 0.01);
2862        assert!((s_0999 - h_shannon).abs() < 0.001);
2863    }
2864
2865    // --- New tests for 0.2.0 ---
2866
2867    #[test]
2868    fn bhattacharyya_precision_near_identical() {
2869        let p = [0.5 + 1e-15, 0.5 - 1e-15];
2870        let q = [0.5, 0.5];
2871        let bc = bhattacharyya_coeff(&p, &q, 1e-6).unwrap();
2872        assert!(
2873            (bc - 1.0).abs() < 1e-14,
2874            "BC should be very close to 1.0: {bc}"
2875        );
2876        let h2 = hellinger_squared(&p, &q, 1e-6).unwrap();
2877        // h2 should be tiny but not exactly zero (numerics permitting).
2878        assert!(h2 < 1e-14, "h2 should be tiny: {h2}");
2879        assert!(h2.is_finite(), "h2 should be finite");
2880    }
2881
2882    #[test]
2883    fn renyi_alpha_sweep_continuity() {
2884        let p = [0.2, 0.3, 0.5];
2885        let tol = 1e-9;
2886        let mut prev_renyi = renyi_entropy(&p, 0.5, tol).unwrap();
2887        let mut prev_tsallis = tsallis_entropy(&p, 0.5, tol).unwrap();
2888        let mut alpha = 0.6;
2889        while alpha <= 2.0 + 1e-9 {
2890            let r = renyi_entropy(&p, alpha, tol).unwrap();
2891            let t = tsallis_entropy(&p, alpha, tol).unwrap();
2892            let jump_r = (r - prev_renyi).abs();
2893            let jump_t = (t - prev_tsallis).abs();
2894            assert!(
2895                jump_r < 0.5,
2896                "Renyi discontinuity at alpha={alpha}: jump={jump_r}"
2897            );
2898            assert!(
2899                jump_t < 0.5,
2900                "Tsallis discontinuity at alpha={alpha}: jump={jump_t}"
2901            );
2902            prev_renyi = r;
2903            prev_tsallis = t;
2904            alpha += 0.1;
2905        }
2906    }
2907
2908    #[test]
2909    fn ksg_ties_finite() {
2910        // Integer-valued data with exact ties.
2911        let x: Vec<Vec<f64>> = (0..50).map(|i| vec![(i % 5) as f64]).collect();
2912        let y: Vec<Vec<f64>> = (0..50).map(|i| vec![(i % 3) as f64]).collect();
2913        let mi1 = mutual_information_ksg(&x, &y, 3, KsgVariant::Alg1).unwrap();
2914        let mi2 = mutual_information_ksg(&x, &y, 3, KsgVariant::Alg2).unwrap();
2915        assert!(
2916            mi1.is_finite(),
2917            "KSG Alg1 with ties returned NaN/Inf: {mi1}"
2918        );
2919        assert!(
2920            mi2.is_finite(),
2921            "KSG Alg2 with ties returned NaN/Inf: {mi2}"
2922        );
2923    }
2924
2925    proptest! {
2926        #![proptest_config(ProptestConfig { cases: 64, .. ProptestConfig::default() })]
2927
2928        #[test]
2929        fn bregman_nonnegative(
2930            p in simplex_vec_pos(8, 1e-6),
2931            q in simplex_vec_pos(8, 1e-6),
2932        ) {
2933            let gen = NegEntropy;
2934            let b = bregman_divergence(&gen, &p, &q).unwrap();
2935            prop_assert!(b >= -1e-12, "Bregman(NegEntropy) negative: {b}");
2936        }
2937
2938        #[test]
2939        fn renyi_divergence_alpha1_equals_kl(
2940            p in simplex_vec_pos(8, 1e-6),
2941            q in simplex_vec_pos(8, 1e-6),
2942        ) {
2943            let tol = 1e-6;
2944            let kl = kl_divergence(&p, &q, tol).unwrap();
2945            let r1 = renyi_divergence(&p, &q, 1.0, tol).unwrap();
2946            prop_assert!(
2947                (r1 - kl).abs() < 1e-9,
2948                "renyi(alpha=1)={r1} != kl={kl}"
2949            );
2950        }
2951    }
2952
2953    #[test]
2954    fn pmi_impossible_input_errors() {
2955        // p(x,y)>0 but p(x)=0 is impossible.
2956        assert!(pmi(0.1, 0.0, 0.5).is_err());
2957        // p(x,y)>0 but p(y)=0 is impossible.
2958        assert!(pmi(0.1, 0.5, 0.0).is_err());
2959    }
2960}