// logp/lib.rs
1//! # logp
2//!
3//! Information theory primitives: entropies and divergences.
4//!
5//! ## Scope
6//!
7//! Scalar information measures that appear across clustering, ranking,
8//! evaluation, and geometry:
9//!
10//! - Shannon entropy and cross-entropy
11//! - KL / Jensen–Shannon divergences
12//! - Csiszár \(f\)-divergences (a.k.a. *information monotone* divergences)
13//! - Bhattacharyya coefficient, Rényi/Tsallis families
14//! - Bregman divergences (convex-analytic, not generally monotone)
15//!
16//! ## Distances vs divergences (terminology that prevents bugs)
17//!
18//! A **divergence** \(D(p:q)\) is usually required to satisfy:
19//!
20//! - \(D(p:q) \ge 0\)
21//! - \(D(p:p) = 0\)
22//!
23//! but it is typically **not** symmetric and **not** a metric (no triangle inequality).
24//! Many failures in downstream code are caused by treating a divergence as a distance metric.
25//!
26//! ## Key invariants (what tests should enforce)
27//!
28//! - **Jensen–Shannon** is bounded on the simplex:
29//! \(0 \le JS(p,q) \le \ln 2\) (nats).
30//! - **Csiszár \(f\)-divergences** are monotone under coarse-graining (Markov kernels):
31//! merging bins cannot increase the divergence.
32//!
33//! ## Further reading
34//!
35//! - Frank Nielsen, “Divergences” portal (taxonomy diagrams + references):
36//! <https://franknielsen.github.io/Divergence/index.html>
37//! - `nocotan/awesome-information-geometry` (curated reading list):
38//! <https://github.com/nocotan/awesome-information-geometry>
39//! - Csiszár (1967): \(f\)-divergences and information monotonicity.
40//! - Amari & Nagaoka (2000): *Methods of Information Geometry*.
41//!
42//! ## Taxonomy of Divergences (Nielsen)
43//!
44//! | Family | Generator | Key Property |
45//! |---|---|---|
46//! | **f-divergences** | Convex \(f(t)\) with \(f(1)=0\) | Monotone under Markov morphisms (coarse-graining) |
47//! | **Bregman** | Convex \(F(x)\) | Dually flat geometry; generalized Pythagorean theorem |
//! | **Jensen-Shannon** | \(f\)-div + metric | Symmetric, bounded \([0, \ln 2]\), \(\sqrt{JS}\) is a metric |
49//! | **Alpha** | \(\rho_\alpha = \int p^\alpha q^{1-\alpha}\) | Encodes Rényi, Tsallis, Bhattacharyya, Hellinger |
50//!
51//! ## References
52//!
53//! - Shannon (1948). "A Mathematical Theory of Communication"
54//! - Cover & Thomas (2006). "Elements of Information Theory"
55
56#![forbid(unsafe_code)]
57#![warn(missing_docs)]
58
59use thiserror::Error;
60
61mod ksg;
62pub use ksg::{mutual_information_ksg, KsgVariant};
63
64use core::f64::consts::LN_2;
65
66/// KL Divergence between two diagonal Multivariate Gaussians.
67///
68/// Assumes **diagonal** covariance (independent dimensions): each dimension is
69/// parameterized by a mean and a standard deviation, with no cross-covariance terms.
70/// For full-covariance Gaussians, the formula involves log-determinant ratios and
71/// matrix inverses; this function does not handle that case.
72///
73/// Returns 0.5 * Σ [ (std1/std2)^2 + (mu2-mu1)^2 / std2^2 - 1 + 2*ln(std2/std1) ]
74///
75/// # Examples
76///
77/// ```
78/// # use logp::kl_divergence_gaussians;
79/// // KL(N || N) = 0 (identical Gaussians).
80/// let mu = [0.0, 1.0];
81/// let std = [1.0, 2.0];
82/// let kl = kl_divergence_gaussians(&mu, &std, &mu, &std).unwrap();
83/// assert!(kl.abs() < 1e-12);
84///
85/// // KL is non-negative for distinct Gaussians.
86/// let kl = kl_divergence_gaussians(&[0.0], &[1.0], &[1.0], &[1.0]).unwrap();
87/// assert!(kl >= 0.0);
88/// ```
89pub fn kl_divergence_gaussians(
90 mu1: &[f64],
91 std1: &[f64],
92 mu2: &[f64],
93 std2: &[f64],
94) -> Result<f64> {
95 ensure_same_len(mu1, std1)?;
96 ensure_same_len(mu1, mu2)?;
97 ensure_same_len(mu1, std2)?;
98
99 let mut kl = 0.0;
100 for (((&m1, &s1), &m2), &s2) in mu1.iter().zip(std1).zip(mu2).zip(std2) {
101 if s1 <= 0.0 || s2 <= 0.0 {
102 return Err(Error::Domain("standard deviation must be positive"));
103 }
104 let v1 = s1 * s1;
105 let v2 = s2 * s2;
106 kl += (v1 / v2) + (m2 - m1).powi(2) / v2 - 1.0 + 2.0 * (s2.ln() - s1.ln());
107 }
108 Ok(0.5 * kl)
109}
110
/// Errors for information-measure computations.
///
/// Each variant corresponds to a rejected precondition on caller-supplied
/// input; the payload records what was actually observed.
#[derive(Debug, Error)]
pub enum Error {
    /// Two input slices have different lengths (e.g., `p` and `q` in a divergence).
    /// Payload order is (left length, right length).
    #[error("length mismatch: {0} vs {1}")]
    LengthMismatch(usize, usize),

    /// An input slice is empty where at least one element is required.
    #[error("empty input")]
    Empty,

    /// An entry is NaN or infinite where a finite value is required.
    #[error("non-finite entry at index {idx}: {value}")]
    NonFinite {
        /// Position of the offending entry.
        idx: usize,
        /// The non-finite value found.
        value: f64,
    },

    /// An entry is negative where a nonnegative value is required (probability distributions).
    #[error("negative entry at index {idx}: {value}")]
    Negative {
        /// Position of the offending entry.
        idx: usize,
        /// The negative value found.
        value: f64,
    },

    /// The distribution does not sum to 1 within the specified tolerance.
    #[error("not normalized (expected sum≈1): sum={sum}")]
    NotNormalized {
        /// Actual sum of the distribution.
        sum: f64,
    },

    /// The alpha parameter is outside its valid domain (e.g., negative or non-finite).
    #[error("invalid alpha: {alpha} (must be finite and not equal to {forbidden})")]
    InvalidAlpha {
        /// The invalid alpha value provided.
        alpha: f64,
        /// The value alpha must not equal (e.g., 0.0). Callers that only
        /// require alpha to be finite and nonnegative set this to NaN.
        forbidden: f64,
    },

    /// Catch-all for domain violations: zero standard deviation, q_i=0 while p_i>0,
    /// insufficient sample size for KSG, and similar precondition failures.
    #[error("domain error: {0}")]
    Domain(&'static str),
}
161
/// Convenience alias for `Result<T, logp::Error>`.
///
/// Fallible functions in this crate return this alias.
pub type Result<T> = core::result::Result<T, Error>;
164
165fn ensure_nonempty(x: &[f64]) -> Result<()> {
166 if x.is_empty() {
167 return Err(Error::Empty);
168 }
169 Ok(())
170}
171
172fn ensure_same_len(a: &[f64], b: &[f64]) -> Result<()> {
173 if a.len() != b.len() {
174 return Err(Error::LengthMismatch(a.len(), b.len()));
175 }
176 Ok(())
177}
178
179fn ensure_nonnegative(x: &[f64]) -> Result<()> {
180 for (i, &v) in x.iter().enumerate() {
181 if !v.is_finite() {
182 return Err(Error::NonFinite { idx: i, value: v });
183 }
184 if v < 0.0 {
185 return Err(Error::Negative { idx: i, value: v });
186 }
187 }
188 Ok(())
189}
190
// Plain left-to-right accumulation, same order as `Iterator::sum`.
fn sum(x: &[f64]) -> f64 {
    x.iter().fold(0.0, |acc, &v| acc + v)
}
194
195/// Validate that `p` is a probability distribution on the simplex (within `tol`).
196///
197/// # Examples
198///
199/// ```
200/// # use logp::validate_simplex;
201/// // Valid simplex.
202/// assert!(validate_simplex(&[0.3, 0.7], 1e-9).is_ok());
203/// assert!(validate_simplex(&[1.0], 1e-9).is_ok());
204///
205/// // Rejects bad sum.
206/// assert!(validate_simplex(&[0.3, 0.6], 1e-9).is_err());
207///
208/// // Rejects negative entries.
209/// assert!(validate_simplex(&[1.5, -0.5], 1e-9).is_err());
210///
211/// // Rejects empty input.
212/// assert!(validate_simplex(&[], 1e-9).is_err());
213/// ```
214pub fn validate_simplex(p: &[f64], tol: f64) -> Result<()> {
215 ensure_nonempty(p)?;
216 ensure_nonnegative(p)?;
217 let s = sum(p);
218 if (s - 1.0).abs() > tol {
219 return Err(Error::NotNormalized { sum: s });
220 }
221 Ok(())
222}
223
224/// Normalize a nonnegative vector in-place to sum to 1.
225///
226/// Returns the original sum.
227///
228/// # Examples
229///
230/// ```
231/// # use logp::normalize_in_place;
232/// let mut v = vec![2.0, 3.0, 5.0];
233/// let original_sum = normalize_in_place(&mut v).unwrap();
234/// assert!((original_sum - 10.0).abs() < 1e-12);
235/// assert!((v[0] - 0.2).abs() < 1e-12);
236/// assert!((v[1] - 0.3).abs() < 1e-12);
237/// assert!((v[2] - 0.5).abs() < 1e-12);
238///
239/// // Rejects all-zero input.
240/// assert!(normalize_in_place(&mut vec![0.0, 0.0]).is_err());
241/// ```
242pub fn normalize_in_place(p: &mut [f64]) -> Result<f64> {
243 ensure_nonempty(p)?;
244 ensure_nonnegative(p)?;
245 let s = sum(p);
246 if s <= 0.0 {
247 return Err(Error::Domain("cannot normalize: sum <= 0"));
248 }
249 for v in p.iter_mut() {
250 *v /= s;
251 }
252 Ok(s)
253}
254
255/// Shannon entropy in nats: the expected surprise under distribution \(p\).
256///
257/// \[H(p) = -\sum_i p_i \ln p_i\]
258///
259/// # Key properties
260///
261/// - **Non-negative**: \(H(p) \ge 0\), with equality iff \(p\) is a delta (point mass).
262/// - **Maximized by uniform**: among distributions on \(n\) outcomes,
263/// \(H(p) \le \ln n\), with equality iff \(p_i = 1/n\) for all \(i\).
264/// - **Concavity**: \(H\) is a concave function of \(p\) on the simplex.
265/// Mixing distributions never decreases entropy.
266/// - **Units**: result is in nats (base \(e\)); divide by \(\ln 2\) for bits.
267/// - **Joint distributions**: for joint entropy \(H(X,Y)\), flatten the joint
268/// distribution to a 1D slice and pass it directly. Shannon entropy of the
269/// flattened joint is mathematically identical to joint entropy.
270///
271/// # Domain
272///
273/// Requires `p` to be a valid simplex distribution (within `tol`).
274///
275/// # Examples
276///
277/// ```
278/// # use logp::entropy_nats;
279/// // Uniform distribution over 4 outcomes: H = ln(4).
280/// let p = [0.25, 0.25, 0.25, 0.25];
281/// let h = entropy_nats(&p, 1e-9).unwrap();
282/// assert!((h - 4.0_f64.ln()).abs() < 1e-12);
283///
284/// // Delta (point mass): H = 0.
285/// let delta = [1.0, 0.0, 0.0];
286/// assert!(entropy_nats(&delta, 1e-9).unwrap().abs() < 1e-15);
287/// ```
288pub fn entropy_nats(p: &[f64], tol: f64) -> Result<f64> {
289 validate_simplex(p, tol)?;
290 let mut h = 0.0;
291 for &pi in p {
292 if pi > 0.0 {
293 h -= pi * pi.ln();
294 }
295 }
296 Ok(h)
297}
298
299/// Shannon entropy in bits.
300///
301/// # Examples
302///
303/// ```
304/// # use logp::{entropy_bits, entropy_nats};
305/// // Fair coin: H = 1 bit.
306/// let p = [0.5, 0.5];
307/// let bits = entropy_bits(&p, 1e-9).unwrap();
308/// assert!((bits - 1.0).abs() < 1e-12);
309///
310/// // Consistent with nats / ln(2).
311/// let nats = entropy_nats(&p, 1e-9).unwrap();
312/// assert!((bits - nats / core::f64::consts::LN_2).abs() < 1e-12);
313/// ```
314pub fn entropy_bits(p: &[f64], tol: f64) -> Result<f64> {
315 Ok(entropy_nats(p, tol)? / LN_2)
316}
317
318/// Fast Shannon entropy calculation without simplex validation.
319///
320/// Used in performance-critical loops like Sinkhorn iteration for Optimal Transport.
321///
322/// # Invariant
323/// Assumes `p` is non-negative and normalized.
324///
325/// # Examples
326///
327/// ```
328/// # use logp::entropy_unchecked;
329/// // Fair coin: H = ln(2).
330/// let h = entropy_unchecked(&[0.5, 0.5]);
331/// assert!((h - core::f64::consts::LN_2).abs() < 1e-12);
332///
333/// // Agrees with the checked version on valid input.
334/// let p = [0.3, 0.7];
335/// let h_checked = logp::entropy_nats(&p, 1e-9).unwrap();
336/// assert!((entropy_unchecked(&p) - h_checked).abs() < 1e-15);
337/// ```
#[inline]
pub fn entropy_unchecked(p: &[f64]) -> f64 {
    // 0 * ln(0) = 0 by convention: zero-probability entries are skipped,
    // which also avoids ln(0) = -inf.
    p.iter()
        .filter(|&&pi| pi > 0.0)
        .map(|&pi| -pi * pi.ln())
        .sum()
}
348
349/// Renyi entropy in nats: a one-parameter generalization of Shannon entropy.
350///
351/// \[H_\alpha(p) = \frac{1}{1-\alpha} \ln \sum_i p_i^\alpha, \quad \alpha > 0,\;\alpha \ne 1\]
352///
353/// # Key properties
354///
355/// - **Limit to Shannon**: \(\lim_{\alpha \to 1} H_\alpha(p) = H(p)\) (Shannon entropy).
356/// - **Alpha = 0**: \(H_0(p) = \ln |\text{supp}(p)|\), the log of the support size (Hartley entropy).
357/// - **Alpha = 2**: \(H_2(p) = -\ln \sum_i p_i^2\), the collision entropy (negative log of
358/// the probability that two independent draws match).
359/// - **Alpha = infinity**: \(H_\infty(p) = -\ln \max_i p_i\), the min-entropy (worst-case surprise).
360/// - **Monotone in alpha**: \(H_\alpha(p)\) is non-increasing in \(\alpha\).
361/// - **Non-negative**: \(H_\alpha(p) \ge 0\).
362///
363/// # Examples
364///
365/// ```
366/// # use logp::renyi_entropy;
367/// // Uniform over 4: H_alpha = ln(4) for all alpha.
368/// let p = [0.25, 0.25, 0.25, 0.25];
369/// let h = renyi_entropy(&p, 2.0, 1e-9).unwrap();
370/// assert!((h - 4.0_f64.ln()).abs() < 1e-12);
371///
372/// // Collision entropy: H_2 = -ln(sum(p_i^2)).
373/// let q = [0.3, 0.7];
374/// let h2 = renyi_entropy(&q, 2.0, 1e-9).unwrap();
375/// let expected = -(0.3_f64.powi(2) + 0.7_f64.powi(2)).ln();
376/// assert!((h2 - expected).abs() < 1e-12);
377/// ```
378pub fn renyi_entropy(p: &[f64], alpha: f64, tol: f64) -> Result<f64> {
379 validate_simplex(p, tol)?;
380 if !alpha.is_finite() || alpha < 0.0 {
381 return Err(Error::InvalidAlpha {
382 alpha,
383 forbidden: f64::NAN,
384 });
385 }
386 if (alpha - 1.0).abs() < 1e-12 {
387 return entropy_nats(p, tol);
388 }
389 let mut s = 0.0;
390 for &pi in p {
391 if pi > 0.0 {
392 s += pi.powf(alpha);
393 }
394 }
395 if s <= 0.0 {
396 return Err(Error::Domain("renyi_entropy: sum of p_i^alpha <= 0"));
397 }
398 Ok(s.ln() / (1.0 - alpha))
399}
400
401/// Tsallis entropy: a non-extensive generalization of Shannon entropy from
402/// statistical mechanics.
403///
404/// \[S_\alpha(p) = \frac{1}{\alpha - 1}\left(1 - \sum_i p_i^\alpha\right), \quad \alpha \ne 1\]
405///
406/// # Key properties
407///
408/// - **Limit to Shannon**: \(\lim_{\alpha \to 1} S_\alpha(p) = H(p)\).
409/// - **Non-extensive**: for independent systems \(A, B\):
410/// \(S_\alpha(A \otimes B) = S_\alpha(A) + S_\alpha(B) + (1-\alpha)\,S_\alpha(A)\,S_\alpha(B)\).
411/// - **Non-negative**: \(S_\alpha(p) \ge 0\).
412/// - **Connection to Renyi**: related via \(S_\alpha = \frac{e^{(1-\alpha)H_\alpha} - 1}{\alpha - 1}\).
413///
414/// # Examples
415///
416/// ```
417/// # use logp::tsallis_entropy;
418/// // Uniform over 4: S_2 = 1 - 1/4 = 0.75.
419/// let p = [0.25, 0.25, 0.25, 0.25];
420/// let s = tsallis_entropy(&p, 2.0, 1e-9).unwrap();
421/// assert!((s - 0.75).abs() < 1e-12);
422///
423/// // Delta: S_alpha = 0 for any alpha.
424/// let delta = [1.0, 0.0, 0.0];
425/// assert!(tsallis_entropy(&delta, 2.0, 1e-9).unwrap().abs() < 1e-12);
426/// ```
427pub fn tsallis_entropy(p: &[f64], alpha: f64, tol: f64) -> Result<f64> {
428 validate_simplex(p, tol)?;
429 if !alpha.is_finite() || alpha < 0.0 {
430 return Err(Error::InvalidAlpha {
431 alpha,
432 forbidden: f64::NAN,
433 });
434 }
435 if (alpha - 1.0).abs() < 1e-12 {
436 return entropy_nats(p, tol);
437 }
438 let mut s = 0.0;
439 for &pi in p {
440 if pi > 0.0 {
441 s += pi.powf(alpha);
442 }
443 }
444 Ok((1.0 - s) / (alpha - 1.0))
445}
446
447/// Cross-entropy in nats: the expected code length when using model \(q\) to encode
448/// data drawn from true distribution \(p\).
449///
450/// \[H(p, q) = -\sum_i p_i \ln q_i\]
451///
452/// # Key properties
453///
454/// - **Decomposition identity**: cross-entropy splits into entropy plus KL divergence:
455/// \(H(p, q) = H(p) + D_{KL}(p \| q)\).
456/// This means \(H(p, q) \ge H(p)\) with equality iff \(p = q\).
457/// - **Loss function**: minimizing \(H(p, q)\) over \(q\) is equivalent to minimizing
458/// \(D_{KL}(p \| q)\), which is why cross-entropy is the standard classification loss.
459/// - **Not symmetric**: \(H(p, q) \ne H(q, p)\) in general.
460///
461/// # Domain
462///
463/// `p` must be on the simplex; `q` must be nonnegative and normalized; and
464/// whenever `p_i > 0`, we require `q_i > 0` (otherwise cross-entropy is infinite).
465///
466/// # Examples
467///
468/// ```
469/// # use logp::{cross_entropy_nats, entropy_nats, kl_divergence};
470/// let p = [0.3, 0.7];
471/// let q = [0.5, 0.5];
472/// let h_pq = cross_entropy_nats(&p, &q, 1e-9).unwrap();
473///
474/// // Decomposition: H(p,q) = H(p) + KL(p||q).
475/// let h_p = entropy_nats(&p, 1e-9).unwrap();
476/// let kl = kl_divergence(&p, &q, 1e-9).unwrap();
477/// assert!((h_pq - (h_p + kl)).abs() < 1e-12);
478///
479/// // Self-cross-entropy equals entropy: H(p,p) = H(p).
480/// let h_pp = cross_entropy_nats(&p, &p, 1e-9).unwrap();
481/// assert!((h_pp - h_p).abs() < 1e-12);
482/// ```
483pub fn cross_entropy_nats(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
484 validate_simplex(p, tol)?;
485 validate_simplex(q, tol)?;
486 let mut h = 0.0;
487 for (&pi, &qi) in p.iter().zip(q.iter()) {
488 if pi == 0.0 {
489 continue;
490 }
491 if qi <= 0.0 {
492 return Err(Error::Domain("cross-entropy undefined: q_i=0 while p_i>0"));
493 }
494 h -= pi * qi.ln();
495 }
496 Ok(h)
497}
498
499/// Kullback--Leibler divergence in nats: the information lost when \(q\) is used to
500/// approximate \(p\).
501///
502/// \[D_{KL}(p \| q) = \sum_i p_i \ln \frac{p_i}{q_i}\]
503///
504/// # Key properties
505///
506/// - **Gibbs' inequality**: \(D_{KL}(p \| q) \ge 0\), with equality iff \(p = q\).
507/// This follows directly from Jensen's inequality applied to \(-\ln\).
508/// - **Not symmetric**: \(D_{KL}(p \| q) \ne D_{KL}(q \| p)\) in general;
509/// this is why KL is a divergence, not a distance.
510/// - **Not bounded above**: KL can be arbitrarily large when supports differ.
511/// - **Connection to MLE**: minimizing \(D_{KL}(p_{data} \| q_\theta)\) over \(\theta\)
512/// is equivalent to maximum likelihood estimation.
513/// - **Additive for independent distributions**: if \(p = p_1 \otimes p_2\) and
514/// \(q = q_1 \otimes q_2\), then
515/// \(D_{KL}(p \| q) = D_{KL}(p_1 \| q_1) + D_{KL}(p_2 \| q_2)\).
516///
517/// # Domain
518///
519/// `p` and `q` must be valid simplex distributions; and whenever `p_i > 0`,
520/// we require `q_i > 0`.
521///
522/// # Examples
523///
524/// ```
525/// # use logp::kl_divergence;
526/// // KL(p || p) = 0 (Gibbs' inequality, tight case).
527/// let p = [0.2, 0.3, 0.5];
528/// assert!(kl_divergence(&p, &p, 1e-9).unwrap().abs() < 1e-15);
529///
530/// // KL is non-negative.
531/// let q = [0.5, 0.25, 0.25];
532/// assert!(kl_divergence(&p, &q, 1e-9).unwrap() >= 0.0);
533///
534/// // Not symmetric in general.
535/// let kl_pq = kl_divergence(&p, &q, 1e-9).unwrap();
536/// let kl_qp = kl_divergence(&q, &p, 1e-9).unwrap();
537/// assert!((kl_pq - kl_qp).abs() > 1e-6);
538/// ```
539pub fn kl_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
540 ensure_same_len(p, q)?;
541 validate_simplex(p, tol)?;
542 validate_simplex(q, tol)?;
543 let mut d = 0.0;
544 for (&pi, &qi) in p.iter().zip(q.iter()) {
545 if pi == 0.0 {
546 continue;
547 }
548 if qi <= 0.0 {
549 return Err(Error::Domain("KL undefined: q_i=0 while p_i>0"));
550 }
551 d += pi * (pi / qi).ln();
552 }
553 Ok(d)
554}
555
556/// Jensen--Shannon divergence in nats: a symmetric, bounded smoothing of KL divergence.
557///
558/// \[JS(p, q) = \tfrac{1}{2} D_{KL}(p \| m) + \tfrac{1}{2} D_{KL}(q \| m), \quad m = \tfrac{1}{2}(p + q)\]
559///
560/// # Key properties
561///
562/// - **Symmetric**: \(JS(p, q) = JS(q, p)\), unlike KL.
563/// - **Bounded**: \(0 \le JS(p, q) \le \ln 2\). The upper bound is attained when \(p\)
564/// and \(q\) have disjoint supports.
565/// - **Square root is a metric**: \(\sqrt{JS(p, q)}\) satisfies the triangle inequality
566/// (Endres & Schindelin, 2003), so it can be used as a proper distance function.
567/// - **Connection to mutual information**: \(JS(p, q) = I(X; Z)\) where \(Z\) is a
568/// fair coin selecting between \(p\) and \(q\), and \(X\) is drawn from the selected
569/// distribution.
570/// - **Always finite**: because \(m_i > 0\) whenever \(p_i > 0\) or \(q_i > 0\), the
571/// KL terms are always well-defined (no division by zero).
572///
573/// # Domain
574///
575/// `p`, `q` must be simplex distributions.
576///
577/// # Examples
578///
579/// ```
580/// # use logp::jensen_shannon_divergence;
581/// // JS(p, p) = 0.
582/// let p = [0.3, 0.7];
583/// assert!(jensen_shannon_divergence(&p, &p, 1e-9).unwrap().abs() < 1e-15);
584///
585/// // Disjoint supports: JS = ln(2).
586/// let a = [1.0, 0.0];
587/// let b = [0.0, 1.0];
588/// let js = jensen_shannon_divergence(&a, &b, 1e-9).unwrap();
589/// assert!((js - core::f64::consts::LN_2).abs() < 1e-12);
590///
591/// // Symmetric: JS(p, q) = JS(q, p).
592/// let q = [0.5, 0.5];
593/// let js_pq = jensen_shannon_divergence(&p, &q, 1e-9).unwrap();
594/// let js_qp = jensen_shannon_divergence(&q, &p, 1e-9).unwrap();
595/// assert!((js_pq - js_qp).abs() < 1e-15);
596/// ```
597pub fn jensen_shannon_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
598 ensure_same_len(p, q)?;
599 validate_simplex(p, tol)?;
600 validate_simplex(q, tol)?;
601
602 let mut m = vec![0.0; p.len()];
603 for i in 0..p.len() {
604 m[i] = 0.5 * (p[i] + q[i]);
605 }
606
607 Ok(0.5 * kl_divergence(p, &m, tol)? + 0.5 * kl_divergence(q, &m, tol)?)
608}
609
610/// Weighted Jensen--Shannon divergence: a generalization that allows unequal mixture
611/// weights.
612///
613/// \[JS_\pi(p, q) = \pi_1\,D_{KL}(p \| m) + \pi_2\,D_{KL}(q \| m), \quad
614/// m = \pi_1 p + \pi_2 q\]
615///
616/// where \(\pi_1 + \pi_2 = 1\). At \(\pi_1 = \pi_2 = 0.5\) this reduces to
617/// [`jensen_shannon_divergence`].
618///
619/// # Key properties
620///
621/// - **Bounded**: \(0 \le JS_\pi \le H(\pi)\) where \(H(\pi) = -\pi_1 \ln \pi_1 - \pi_2 \ln \pi_2\).
622/// The standard bound \(\ln 2\) is the special case \(\pi_1 = \pi_2 = 0.5\).
623/// - **Symmetric in \((p, \pi_1)\) and \((q, \pi_2)\)**: swapping both distributions
624/// and weights gives the same value.
625///
626/// # Examples
627///
628/// ```
629/// # use logp::{jensen_shannon_weighted, jensen_shannon_divergence};
630/// let p = [0.3, 0.7];
631/// let q = [0.5, 0.5];
632///
633/// // Equal weights recovers standard JS.
634/// let jsw = jensen_shannon_weighted(&p, &q, 0.5, 1e-9).unwrap();
635/// let js = jensen_shannon_divergence(&p, &q, 1e-9).unwrap();
636/// assert!((jsw - js).abs() < 1e-12);
637///
638/// // Extreme weight: pi1=1 means m=p, so JS=0.
639/// let js1 = jensen_shannon_weighted(&p, &q, 1.0, 1e-9).unwrap();
640/// assert!(js1.abs() < 1e-12);
641/// ```
642pub fn jensen_shannon_weighted(p: &[f64], q: &[f64], pi1: f64, tol: f64) -> Result<f64> {
643 ensure_same_len(p, q)?;
644 validate_simplex(p, tol)?;
645 validate_simplex(q, tol)?;
646 if !(0.0..=1.0).contains(&pi1) || !pi1.is_finite() {
647 return Err(Error::Domain("pi1 must be in [0, 1]"));
648 }
649 let pi2 = 1.0 - pi1;
650
651 let mut m = vec![0.0; p.len()];
652 for i in 0..p.len() {
653 m[i] = pi1 * p[i] + pi2 * q[i];
654 }
655
656 // When pi1 or pi2 is zero, the corresponding KL term is 0 * D_KL = 0.
657 let kl_p = if pi1 > 0.0 {
658 kl_divergence(p, &m, tol)?
659 } else {
660 0.0
661 };
662 let kl_q = if pi2 > 0.0 {
663 kl_divergence(q, &m, tol)?
664 } else {
665 0.0
666 };
667 Ok(pi1 * kl_p + pi2 * kl_q)
668}
669
670/// Mutual information in nats: how much knowing \(Y\) reduces uncertainty about \(X\).
671///
672/// \[I(X; Y) = \sum_{x,y} p(x,y) \ln \frac{p(x,y)}{p(x)\,p(y)}\]
673///
674/// # Key properties
675///
676/// - **KL form**: \(I(X; Y) = D_{KL}\bigl(p(x,y) \;\|\; p(x)\,p(y)\bigr)\), measuring
677/// how far the joint distribution is from the product of its marginals.
678/// - **Non-negative**: \(I(X; Y) \ge 0\), with equality iff \(X\) and \(Y\) are
679/// independent.
680/// - **Symmetric**: \(I(X; Y) = I(Y; X)\).
681/// - **Bounded by entropy**: \(I(X; Y) \le \min\{H(X),\, H(Y)\}\).
682/// - **Data processing inequality**: for any Markov chain \(X \to Y \to Z\),
683/// \(I(X; Z) \le I(X; Y)\). Processing cannot create information.
684/// - **Entropy decomposition**: \(I(X; Y) = H(X) + H(Y) - H(X, Y) = H(X) - H(X|Y)\).
685///
686/// # Layout
687///
688/// For discrete distributions, given a **row-major** joint distribution table `p_xy`
689/// with shape `(n_x, n_y)`.
690///
691/// Public invariant (this is the important one): this API is **backend-agnostic**.
692/// It does not force `ndarray` into the public surface of a leaf crate.
693///
694/// # Examples
695///
696/// ```
697/// # use logp::{mutual_information, entropy_nats};
698/// // Independent joint: p(x,y) = p(x)*p(y), so I(X;Y) = 0.
699/// let p_xy = [0.15, 0.35, 0.15, 0.35]; // 2x2, marginals [0.5,0.5] x [0.3,0.7]
700/// let mi = mutual_information(&p_xy, 2, 2, 1e-9).unwrap();
701/// assert!(mi.abs() < 1e-12);
702///
703/// // Perfect correlation (Y = X, uniform bit): I(X;Y) = H(X) = ln(2).
704/// let diag = [0.5, 0.0, 0.0, 0.5];
705/// let mi = mutual_information(&diag, 2, 2, 1e-9).unwrap();
706/// assert!((mi - core::f64::consts::LN_2).abs() < 1e-12);
707/// ```
708pub fn mutual_information(p_xy: &[f64], n_x: usize, n_y: usize, tol: f64) -> Result<f64> {
709 if n_x == 0 || n_y == 0 {
710 return Err(Error::Domain(
711 "mutual_information: n_x and n_y must be >= 1",
712 ));
713 }
714 if p_xy.len() != n_x * n_y {
715 return Err(Error::LengthMismatch(p_xy.len(), n_x * n_y));
716 }
717 validate_simplex(p_xy, tol)?;
718
719 let mut p_x = vec![0.0; n_x];
720 let mut p_y = vec![0.0; n_y];
721 for i in 0..n_x {
722 for j in 0..n_y {
723 let p = p_xy[i * n_y + j];
724 p_x[i] += p;
725 p_y[j] += p;
726 }
727 }
728
729 let mut mi = 0.0;
730 for i in 0..n_x {
731 for j in 0..n_y {
732 let pxy = p_xy[i * n_y + j];
733 if pxy > 0.0 {
734 let px = p_x[i];
735 let py = p_y[j];
736 if px <= 0.0 || py <= 0.0 {
737 return Err(Error::Domain(
738 "mutual_information: p(x)=0 or p(y)=0 while p(x,y)>0",
739 ));
740 }
741 mi += pxy * (pxy / (px * py)).ln();
742 }
743 }
744 }
745 Ok(mi)
746}
747
/// `ndarray` adapter for discrete mutual information.
///
/// Flattens the 2-D joint table and delegates to [`mutual_information`].
/// Requires `logp` feature `ndarray`.
#[cfg(feature = "ndarray")]
pub fn mutual_information_ndarray(p_xy: &ndarray::Array2<f64>, tol: f64) -> Result<f64> {
    let (n_x, n_y) = p_xy.dim();
    // NOTE(review): assumes `iter()` visits elements in row-major logical
    // order, matching the row-major layout `mutual_information` expects —
    // confirm for arrays with non-standard memory layout.
    let flat: Vec<f64> = p_xy.iter().copied().collect();
    mutual_information(&flat, n_x, n_y, tol)
}
757
758/// Conditional entropy in nats: the remaining uncertainty about \(X\) after observing \(Y\).
759///
760/// \[H(X|Y) = H(X, Y) - H(Y) = H(X) - I(X; Y)\]
761///
762/// # Key properties
763///
764/// - **Non-negative**: \(H(X|Y) \ge 0\).
765/// - **Bounded**: \(H(X|Y) \le H(X)\), with equality iff \(X\) and \(Y\) are independent.
766/// - **Zero iff deterministic**: \(H(X|Y) = 0\) iff \(X\) is a function of \(Y\).
767///
768/// # Layout
769///
770/// Row-major joint distribution `p_xy` with shape `(n_x, n_y)`, same as
771/// [`mutual_information`].
772///
773/// # Examples
774///
775/// ```
776/// # use logp::conditional_entropy;
777/// // Independent: H(X|Y) = H(X).
778/// let p_xy = [0.15, 0.35, 0.15, 0.35]; // 2x2
779/// let h_x_given_y = conditional_entropy(&p_xy, 2, 2, 1e-9).unwrap();
780/// // H(X) for marginal [0.5, 0.5] = ln(2)
781/// assert!((h_x_given_y - 2.0_f64.ln()).abs() < 1e-10);
782///
783/// // Deterministic (Y = X): H(X|Y) = 0.
784/// let diag = [0.5, 0.0, 0.0, 0.5];
785/// assert!(conditional_entropy(&diag, 2, 2, 1e-9).unwrap().abs() < 1e-10);
786/// ```
787pub fn conditional_entropy(p_xy: &[f64], n_x: usize, n_y: usize, tol: f64) -> Result<f64> {
788 let mi = mutual_information(p_xy, n_x, n_y, tol)?;
789 // Compute H(X) from marginal.
790 let mut p_x = vec![0.0; n_x];
791 for i in 0..n_x {
792 for j in 0..n_y {
793 p_x[i] += p_xy[i * n_y + j];
794 }
795 }
796 let h_x = entropy_nats(&p_x, tol)?;
797 Ok(h_x - mi)
798}
799
800/// Normalized mutual information: MI scaled to \([0, 1]\) for comparing clusterings
801/// of different sizes.
802///
803/// \[NMI(X, Y) = \frac{2\,I(X; Y)}{H(X) + H(Y)}\]
804///
805/// # Key properties
806///
807/// - **Bounded**: \(NMI \in [0, 1]\). Equals 0 for independent variables;
808/// equals 1 for perfectly correlated (identical clustering).
809/// - **Symmetric**: \(NMI(X, Y) = NMI(Y, X)\).
810///
811/// Returns 0 when both marginal entropies are zero (trivial single-cluster case).
812///
813/// # Examples
814///
815/// ```
816/// # use logp::normalized_mutual_information;
817/// // Perfect correlation: NMI = 1.
818/// let diag = [0.5, 0.0, 0.0, 0.5];
819/// let nmi = normalized_mutual_information(&diag, 2, 2, 1e-9).unwrap();
820/// assert!((nmi - 1.0).abs() < 1e-10);
821///
822/// // Independent: NMI = 0.
823/// let indep = [0.15, 0.35, 0.15, 0.35];
824/// let nmi = normalized_mutual_information(&indep, 2, 2, 1e-9).unwrap();
825/// assert!(nmi.abs() < 1e-10);
826/// ```
827pub fn normalized_mutual_information(
828 p_xy: &[f64],
829 n_x: usize,
830 n_y: usize,
831 tol: f64,
832) -> Result<f64> {
833 if n_x == 0 || n_y == 0 {
834 return Err(Error::Domain("nmi: n_x and n_y must be >= 1"));
835 }
836 if p_xy.len() != n_x * n_y {
837 return Err(Error::LengthMismatch(p_xy.len(), n_x * n_y));
838 }
839 validate_simplex(p_xy, tol)?;
840
841 let mut p_x = vec![0.0; n_x];
842 let mut p_y = vec![0.0; n_y];
843 for i in 0..n_x {
844 for j in 0..n_y {
845 let p = p_xy[i * n_y + j];
846 p_x[i] += p;
847 p_y[j] += p;
848 }
849 }
850
851 let h_x = entropy_nats(&p_x, tol)?;
852 let h_y = entropy_nats(&p_y, tol)?;
853 let denom = h_x + h_y;
854 if denom <= 0.0 {
855 return Ok(0.0);
856 }
857
858 let mi = mutual_information(p_xy, n_x, n_y, tol)?;
859 Ok(2.0 * mi / denom)
860}
861
862/// Pointwise mutual information: the log-ratio measuring how much more (or less)
863/// likely two specific outcomes co-occur than if they were independent.
864///
865/// \[PMI(x; y) = \ln \frac{p(x, y)}{p(x)\,p(y)}\]
866///
867/// # Key properties
868///
869/// - **Sign**: positive when \(x\) and \(y\) co-occur more than chance; negative when
870/// less; zero when independent.
871/// - **Unbounded**: \(PMI \in (-\infty, -\ln p(x,y)]\). In practice, rare events yield
872/// very large PMI, which is why PPMI (positive PMI, clamped at 0) is common.
873/// - **Connection to mutual information**: \(I(X; Y) = \mathbb{E}_{p(x,y)}[PMI(x; y)]\).
874/// MI is the expected value of PMI over the joint distribution.
875/// - **Connection to word2vec**: Levy & Goldberg (2014) showed that Skip-gram with
876/// negative sampling implicitly factorizes a PMI matrix (shifted by \(\ln k\)).
877///
878/// # Examples
879///
880/// ```
881/// # use logp::pmi;
882/// // Independent events: p(x,y) = p(x)*p(y), so PMI = 0.
883/// let val = pmi(0.06, 0.3, 0.2).unwrap();
884/// assert!(val.abs() < 1e-10);
885///
886/// // Positive correlation: p(x,y) > p(x)*p(y).
887/// assert!(pmi(0.4, 0.5, 0.5).unwrap() > 0.0);
888///
889/// // Negative correlation: p(x,y) < p(x)*p(y).
890/// assert!(pmi(0.1, 0.5, 0.5).unwrap() < 0.0);
891///
892/// // Zero joint probability returns 0 by convention.
893/// assert_eq!(pmi(0.0, 0.5, 0.5).unwrap(), 0.0);
894///
895/// // Impossible: p(x,y) > 0 but p(x) = 0.
896/// assert!(pmi(0.1, 0.0, 0.5).is_err());
897/// ```
898pub fn pmi(pxy: f64, px: f64, py: f64) -> Result<f64> {
899 if pxy > 0.0 && px == 0.0 {
900 return Err(Error::Domain("pmi: p(x,y)>0 but p(x)=0 is impossible"));
901 }
902 if pxy > 0.0 && py == 0.0 {
903 return Err(Error::Domain("pmi: p(x,y)>0 but p(y)=0 is impossible"));
904 }
905 if pxy <= 0.0 || px <= 0.0 || py <= 0.0 {
906 Ok(0.0)
907 } else {
908 Ok((pxy / (px * py)).ln())
909 }
910}
911
912/// Log-sum-exp: numerically stable computation of `ln(exp(a_1) + ... + exp(a_n))`.
913///
914/// This is the fundamental primitive for working in log-probability space.
915/// The naive `values.iter().map(|v| v.exp()).sum::<f64>().ln()` overflows
916/// for large values and underflows for small ones; the max-shift trick
917/// avoids both.
918///
919/// Returns `NEG_INFINITY` for an empty slice.
920///
921/// # Examples
922///
923/// ```
924/// # use logp::log_sum_exp;
925/// // ln(e^0 + e^0) = ln(2)
926/// let lse = log_sum_exp(&[0.0, 0.0]);
927/// assert!((lse - 2.0_f64.ln()).abs() < 1e-12);
928///
929/// // Dominated term: ln(e^1000 + e^0) ≈ 1000
930/// let lse = log_sum_exp(&[1000.0, 0.0]);
931/// assert!((lse - 1000.0).abs() < 1e-10);
932///
933/// // Single element: identity.
934/// assert_eq!(log_sum_exp(&[42.0]), 42.0);
935///
936/// // Empty: -inf.
937/// assert_eq!(log_sum_exp(&[]), f64::NEG_INFINITY);
938/// ```
#[inline]
pub fn log_sum_exp(values: &[f64]) -> f64 {
    // An empty sum is 0, and ln(0) = -inf.
    if values.is_empty() {
        return f64::NEG_INFINITY;
    }
    let m = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // All -inf (or any +inf): the max itself is the answer.
    if m.is_infinite() {
        return m;
    }
    // Shift by the max so every exponent is <= 0; no overflow possible.
    let shifted_sum: f64 = values.iter().map(|&v| (v - m).exp()).sum();
    m + shifted_sum.ln()
}
951
952/// Log-sum-exp for two values (common special case).
953///
954/// Equivalent to `log_sum_exp(&[a, b])` but avoids the slice overhead.
955///
956/// # Examples
957///
958/// ```
959/// # use logp::log_sum_exp2;
960/// let lse = log_sum_exp2(0.0, 0.0);
961/// assert!((lse - 2.0_f64.ln()).abs() < 1e-12);
962/// ```
#[inline]
pub fn log_sum_exp2(a: f64, b: f64) -> f64 {
    // Use f64::max (not a hand-rolled compare) to keep its NaN semantics.
    let m = a.max(b);
    // Both -inf, or either +inf: the max is already the answer.
    if m.is_infinite() {
        return m;
    }
    // After shifting by the larger argument, both exponents are <= 0.
    let shifted = (a - m).exp() + (b - m).exp();
    m + shifted.ln()
}
971
972/// Streaming log-sum-exp: single-pass, O(1) memory computation over an iterator.
973///
974/// Equivalent to `log_sum_exp` but processes elements one at a time without
975/// materializing a slice. Useful for large-scale or iterator-chain workloads.
976///
977/// Uses a running `(max, sum_exp)` pair that rescales when a new maximum arrives.
978///
979/// Returns `NEG_INFINITY` for an empty iterator.
980///
981/// # Examples
982///
983/// ```
984/// # use logp::{log_sum_exp_iter, log_sum_exp};
985/// let values = vec![1.0, 2.0, 3.0];
986/// let lse_iter = log_sum_exp_iter(values.iter().copied());
987/// let lse_slice = log_sum_exp(&values);
988/// assert!((lse_iter - lse_slice).abs() < 1e-12);
989///
990/// // Works with any iterator.
991/// let lse = log_sum_exp_iter((0..5).map(|i| i as f64));
992/// assert!(lse.is_finite());
993///
994/// // Empty iterator: -inf.
995/// assert_eq!(log_sum_exp_iter(std::iter::empty::<f64>()), f64::NEG_INFINITY);
996/// ```
#[inline]
pub fn log_sum_exp_iter(iter: impl Iterator<Item = f64>) -> f64 {
    let mut max = f64::NEG_INFINITY;
    let mut sum_exp = 0.0;

    for v in iter {
        // A -inf element contributes exp(-inf) = 0 to the sum. Skipping it is
        // also required for correctness: while `max` is still -inf,
        // `(v - max)` would be `-inf - -inf = NaN`, poisoning the sum
        // (e.g. [-inf, 0.0] used to return NaN instead of 0.0, disagreeing
        // with the slice-based `log_sum_exp`).
        if v == f64::NEG_INFINITY {
            continue;
        }
        if v > max {
            if max.is_finite() {
                // Rescale the running sum when a new max arrives.
                sum_exp *= (max - v).exp();
            }
            max = v;
        }
        // (v - max).exp() is <= 1.0 since v <= max.
        sum_exp += (v - max).exp();
    }

    if max.is_infinite() {
        // NEG_INFINITY: empty iterator or all elements were -inf.
        // POS_INFINITY: a +inf element dominates everything.
        return max;
    }
    max + sum_exp.ln()
}
1019
1020/// Digamma function: the logarithmic derivative of the Gamma function.
1021///
1022/// \[\psi(x) = \frac{d}{dx} \ln \Gamma(x) = \frac{\Gamma'(x)}{\Gamma(x)}\]
1023///
1024/// # Key properties
1025///
1026/// - **Recurrence**: \(\psi(x+1) = \psi(x) + \frac{1}{x}\), which follows from
1027/// \(\Gamma(x+1) = x\,\Gamma(x)\).
1028/// - **Special value**: \(\psi(1) = -\gamma \approx -0.5772\), where \(\gamma\) is
1029/// the Euler--Mascheroni constant.
1030/// - **Asymptotic**: \(\psi(x) \sim \ln x - \frac{1}{2x}\) for large \(x\).
1031/// - **Why it appears here**: the KSG estimator for mutual information
1032/// ([`mutual_information_ksg`]) uses digamma to correct for the bias of
1033/// nearest-neighbor density estimates.
1034///
1035/// # Domain
1036///
1037/// Defined for \(x > 0\). Returns `NaN` for \(x \le 0\).
1038///
1039/// # Implementation
1040///
1041/// Uses the recurrence to shift small \(x\) up to \(x \ge 10\), then applies the
1042/// asymptotic expansion with 4 Bernoulli-number correction terms (~1e-14 accuracy).
1043///
1044/// # Examples
1045///
1046/// ```
1047/// # use logp::digamma;
1048/// // psi(1) = -gamma (Euler-Mascheroni constant).
1049/// let psi1 = digamma(1.0);
1050/// assert!((psi1 - (-0.5772156649)).abs() < 1e-12);
1051///
1052/// // Recurrence: psi(x+1) = psi(x) + 1/x.
1053/// let x = 3.5;
1054/// assert!((digamma(x + 1.0) - digamma(x) - 1.0 / x).abs() < 1e-12);
1055///
1056/// // Non-positive input returns NaN.
1057/// assert!(digamma(0.0).is_nan());
1058/// assert!(digamma(-1.0).is_nan());
1059/// ```
pub fn digamma(mut x: f64) -> f64 {
    // psi is only computed for positive arguments here.
    if x <= 0.0 {
        return f64::NAN;
    }
    // Push x into the region x >= 10 via psi(x) = psi(x+1) - 1/x, where the
    // asymptotic series below reaches ~1e-14 with four Bernoulli terms.
    let mut acc = 0.0;
    while x < 10.0 {
        acc -= x.recip();
        x += 1.0;
    }
    // Leading asymptotic terms: ln x - 1/(2x).
    let inv = x.recip();
    let inv2 = inv * inv;
    acc += x.ln() - 0.5 * inv;
    // Bernoulli tail B_{2k}/(2k x^{2k}) in Horner form:
    // B2/2 = 1/12, B4/4 = 1/120, B6/6 = 1/252, B8/8 = 1/240.
    acc -= inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 * (1.0 / 252.0 - inv2 / 240.0)));
    acc
}
1080
1081/// Bhattacharyya coefficient: the geometric-mean overlap between two distributions.
1082///
1083/// \[BC(p, q) = \sum_i \sqrt{p_i \, q_i}\]
1084///
1085/// # Key properties
1086///
1087/// - **Geometric mean interpretation**: each term \(\sqrt{p_i q_i}\) is the geometric
1088/// mean of the two probabilities at bin \(i\). BC sums these, measuring how much
1089/// the distributions overlap.
1090/// - **Bounded**: \(BC \in [0, 1]\). Equals 1 iff \(p = q\); equals 0 iff supports
1091/// are disjoint.
1092/// - **Relationship to Hellinger**: \(H^2(p, q) = 1 - BC(p, q)\), so the squared
1093/// Hellinger distance is just one minus the Bhattacharyya coefficient.
1094/// - **Relationship to Renyi**: at \(\alpha = \tfrac{1}{2}\), the Renyi divergence
1095/// gives \(D_{1/2}^R(p \| q) = -2 \ln BC(p, q)\).
1096/// - **Connection to alpha family**: \(BC = \rho_{1/2}(p, q)\), a special case of
1097/// [`rho_alpha`].
1098///
1099/// # Examples
1100///
1101/// ```
1102/// # use logp::bhattacharyya_coeff;
1103/// // BC(p, p) = 1.
1104/// let p = [0.3, 0.7];
1105/// assert!((bhattacharyya_coeff(&p, &p, 1e-9).unwrap() - 1.0).abs() < 1e-12);
1106///
1107/// // Disjoint supports: BC = 0.
1108/// let a = [1.0, 0.0];
1109/// let b = [0.0, 1.0];
1110/// assert!(bhattacharyya_coeff(&a, &b, 1e-9).unwrap().abs() < 1e-15);
1111///
1112/// // BC is in [0, 1].
1113/// let q = [0.5, 0.5];
1114/// let bc = bhattacharyya_coeff(&p, &q, 1e-9).unwrap();
1115/// assert!((0.0..=1.0).contains(&bc));
1116/// ```
1117pub fn bhattacharyya_coeff(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1118 ensure_same_len(p, q)?;
1119 validate_simplex(p, tol)?;
1120 validate_simplex(q, tol)?;
1121 let bc: f64 = p
1122 .iter()
1123 .zip(q.iter())
1124 .map(|(&pi, &qi)| pi.sqrt() * qi.sqrt())
1125 .sum();
1126 Ok(bc)
1127}
1128
1129/// Bhattacharyya distance \(D_B(p,q) = -\ln BC(p,q)\).
1130///
1131/// # Examples
1132///
1133/// ```
1134/// # use logp::bhattacharyya_distance;
1135/// // D_B(p, p) = 0.
1136/// let p = [0.4, 0.6];
1137/// assert!(bhattacharyya_distance(&p, &p, 1e-9).unwrap().abs() < 1e-12);
1138///
1139/// // Non-negative for distinct distributions.
1140/// let q = [0.5, 0.5];
1141/// assert!(bhattacharyya_distance(&p, &q, 1e-9).unwrap() >= 0.0);
1142/// ```
1143pub fn bhattacharyya_distance(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1144 let bc = bhattacharyya_coeff(p, q, tol)?;
1145 // When supports are disjoint, bc can be 0 (=> +∞ distance). Keep it explicit.
1146 if bc == 0.0 {
1147 return Err(Error::Domain("Bhattacharyya distance is infinite (BC=0)"));
1148 }
1149 Ok(-bc.ln())
1150}
1151
1152/// Squared Hellinger distance.
1153///
1154/// \[H^2(p, q) = \frac{1}{2}\sum_i \left(\sqrt{p_i} - \sqrt{q_i}\right)^2\]
1155///
1156/// Equivalent to \(1 - BC(p,q)\) but computed via the sum-of-squared-differences
1157/// form to avoid catastrophic cancellation when \(p \approx q\).
1158///
/// Bounded in \([0, 1]\). Equals the Amari \(\alpha\)-divergence at \(\alpha = 0\)
/// (up to a factor of 4: \(D^0(p:q) = 4\,H^2(p,q)\)).
1161///
1162/// # Examples
1163///
1164/// ```
1165/// # use logp::hellinger_squared;
1166/// // H^2(p, p) = 0.
1167/// let p = [0.25, 0.75];
1168/// assert!(hellinger_squared(&p, &p, 1e-9).unwrap().abs() < 1e-15);
1169///
1170/// // Disjoint supports: H^2 = 1.
1171/// let a = [1.0, 0.0];
1172/// let b = [0.0, 1.0];
1173/// assert!((hellinger_squared(&a, &b, 1e-9).unwrap() - 1.0).abs() < 1e-12);
1174/// ```
1175pub fn hellinger_squared(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1176 ensure_same_len(p, q)?;
1177 validate_simplex(p, tol)?;
1178 validate_simplex(q, tol)?;
1179 let h2: f64 = p
1180 .iter()
1181 .zip(q.iter())
1182 .map(|(&pi, &qi)| {
1183 let diff = pi.sqrt() - qi.sqrt();
1184 diff * diff
1185 })
1186 .sum();
1187 Ok((0.5 * h2).max(0.0))
1188}
1189
1190/// Hellinger distance: the square root of the squared Hellinger distance.
1191///
1192/// \[H(p, q) = \sqrt{1 - BC(p, q)}\]
1193///
1194/// Unlike KL divergence, Hellinger is a **proper metric**: it is symmetric, satisfies
1195/// the triangle inequality, and is bounded in \([0, 1]\).
1196///
1197/// # Examples
1198///
1199/// ```
1200/// # use logp::hellinger;
1201/// // H(p, p) = 0.
1202/// let p = [0.3, 0.7];
1203/// assert!(hellinger(&p, &p, 1e-9).unwrap().abs() < 1e-15);
1204///
1205/// // Symmetric: H(p, q) = H(q, p).
1206/// let q = [0.5, 0.5];
1207/// let h_pq = hellinger(&p, &q, 1e-9).unwrap();
1208/// let h_qp = hellinger(&q, &p, 1e-9).unwrap();
1209/// assert!((h_pq - h_qp).abs() < 1e-15);
1210///
1211/// // Bounded in [0, 1].
1212/// assert!((0.0..=1.0).contains(&h_pq));
1213/// ```
1214pub fn hellinger(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1215 Ok(hellinger_squared(p, q, tol)?.sqrt())
1216}
1217
1218fn pow_nonneg(x: f64, a: f64) -> Result<f64> {
1219 if x < 0.0 || !x.is_finite() || !a.is_finite() {
1220 return Err(Error::Domain("pow_nonneg: invalid input"));
1221 }
1222 if x == 0.0 {
1223 if a == 0.0 {
1224 // By continuity in the divergence formulas, treat 0^0 as 1.
1225 return Ok(1.0);
1226 }
1227 if a > 0.0 {
1228 return Ok(0.0);
1229 }
1230 return Err(Error::Domain("0^a for a<0 is infinite"));
1231 }
1232 Ok(x.powf(a))
1233}
1234
1235/// Alpha-integral: the workhorse behind the entire alpha-family of divergences.
1236///
1237/// \[\rho_\alpha(p, q) = \sum_i p_i^\alpha \, q_i^{1-\alpha}\]
1238///
1239/// # Why this matters
1240///
1241/// This single quantity generates multiple divergence families via simple transforms:
1242///
1243/// - **Renyi**: \(D_\alpha^R = \frac{1}{\alpha - 1} \ln \rho_\alpha\)
1244/// - **Tsallis**: \(D_\alpha^T = \frac{\rho_\alpha - 1}{\alpha - 1}\)
1245/// - **Bhattacharyya coefficient**: \(BC = \rho_{1/2}\)
1246/// - **Chernoff information**: \(\min_\alpha (-\ln \rho_\alpha)\)
1247///
1248/// # Key properties
1249///
1250/// - \(\rho_\alpha(p, p) = 1\) for all \(\alpha\) (since \(\sum p_i = 1\)).
1251/// - By Holder's inequality, \(\rho_\alpha(p, q) \le 1\) for \(\alpha \in [0, 1]\).
1252/// - Continuous and log-convex in \(\alpha\).
1253///
1254/// # Examples
1255///
1256/// ```
1257/// # use logp::rho_alpha;
1258/// // rho_alpha(p, p, alpha) = 1 for any alpha (since sum(p) = 1).
1259/// let p = [0.2, 0.3, 0.5];
1260/// assert!((rho_alpha(&p, &p, 0.5, 1e-9).unwrap() - 1.0).abs() < 1e-12);
1261/// assert!((rho_alpha(&p, &p, 2.0, 1e-9).unwrap() - 1.0).abs() < 1e-12);
1262///
1263/// // At alpha = 0.5, rho equals the Bhattacharyya coefficient.
1264/// let q = [0.5, 0.25, 0.25];
1265/// let rho = rho_alpha(&p, &q, 0.5, 1e-9).unwrap();
1266/// let bc = logp::bhattacharyya_coeff(&p, &q, 1e-9).unwrap();
1267/// assert!((rho - bc).abs() < 1e-12);
1268/// ```
1269pub fn rho_alpha(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1270 ensure_same_len(p, q)?;
1271 validate_simplex(p, tol)?;
1272 validate_simplex(q, tol)?;
1273 if !alpha.is_finite() {
1274 return Err(Error::InvalidAlpha {
1275 alpha,
1276 forbidden: f64::NAN,
1277 });
1278 }
1279 let mut s = 0.0;
1280 for (&pi, &qi) in p.iter().zip(q.iter()) {
1281 let a = pow_nonneg(pi, alpha)?;
1282 let b = pow_nonneg(qi, 1.0 - alpha)?;
1283 s += a * b;
1284 }
1285 Ok(s)
1286}
1287
1288/// Renyi divergence in nats: a one-parameter family that interpolates between
1289/// different notions of distributional difference.
1290///
1291/// \[D_\alpha^R(p \| q) = \frac{1}{\alpha - 1} \ln \rho_\alpha(p, q), \quad \alpha > 0,\; \alpha \ne 1\]
1292///
1293/// # Key properties
1294///
1295/// - **Limit to KL**: \(\lim_{\alpha \to 1} D_\alpha^R(p \| q) = D_{KL}(p \| q)\)
1296/// by L'Hopital's rule (the logarithm and denominator both vanish).
1297/// - **Alpha = 1/2**: \(D_{1/2}^R = -2 \ln BC(p, q)\), twice the negative log
1298/// Bhattacharyya coefficient.
1299/// - **Alpha = infinity**: \(D_\infty^R = \ln \max_i (p_i / q_i)\), the log of the
1300/// maximum likelihood ratio. This bounds all other Renyi orders.
1301/// - **Monotone in alpha**: \(D_\alpha^R\) is non-decreasing in \(\alpha\).
1302/// - **Non-negative**: \(D_\alpha^R(p \| q) \ge 0\), with equality iff \(p = q\).
1303///
1304/// # Domain
1305///
1306/// \(\alpha > 0\), \(\alpha \ne 1\). Both `p` and `q` must be simplex distributions.
1307///
1308/// # Examples
1309///
1310/// ```
1311/// # use logp::renyi_divergence;
1312/// // D_alpha(p || p) = 0 for any valid alpha.
1313/// let p = [0.3, 0.7];
1314/// assert!(renyi_divergence(&p, &p, 2.0, 1e-9).unwrap().abs() < 1e-12);
1315///
1316/// // Non-negative.
1317/// let q = [0.5, 0.5];
1318/// assert!(renyi_divergence(&p, &q, 0.5, 1e-9).unwrap() >= -1e-12);
1319///
1320/// // alpha = 1.0 returns KL divergence (Shannon limit).
1321/// let kl = logp::kl_divergence(&p, &q, 1e-9).unwrap();
1322/// let r1 = renyi_divergence(&p, &q, 1.0, 1e-9).unwrap();
1323/// assert!((r1 - kl).abs() < 1e-12);
1324/// ```
1325pub fn renyi_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1326 if (alpha - 1.0).abs() < 1e-12 {
1327 return kl_divergence(p, q, tol);
1328 }
1329 let rho = rho_alpha(p, q, alpha, tol)?;
1330 if rho <= 0.0 {
1331 return Err(Error::Domain("rho_alpha <= 0"));
1332 }
1333 Ok(rho.ln() / (alpha - 1.0))
1334}
1335
1336/// Tsallis divergence: a non-extensive generalization of KL divergence from
1337/// statistical mechanics.
1338///
1339/// \[D_\alpha^T(p \| q) = \frac{\rho_\alpha(p, q) - 1}{\alpha - 1}, \quad \alpha \ne 1\]
1340///
1341/// # Key properties
1342///
1343/// - **Limit to KL**: \(\lim_{\alpha \to 1} D_\alpha^T(p \| q) = D_{KL}(p \| q)\),
1344/// same limit as Renyi but via a different path.
1345/// - **Connection to Renyi via deformed logarithm**: Tsallis uses the q-logarithm
1346/// \(\ln_q(x) = \frac{x^{1-q} - 1}{1-q}\) where Renyi uses the ordinary log.
1347/// Formally: \(D_\alpha^T = \frac{e^{(\alpha-1) D_\alpha^R} - 1}{\alpha - 1}\).
1348/// - **Non-extensive**: for independent systems, Tsallis divergence is **not** additive
1349/// (unlike KL and Renyi). This property is intentional and models systems with
1350/// long-range correlations in statistical physics.
1351/// - **Non-negative**: \(D_\alpha^T(p \| q) \ge 0\), with equality iff \(p = q\).
1352///
1353/// # Domain
1354///
1355/// \(\alpha \ne 1\). Both `p` and `q` must be simplex distributions.
1356///
1357/// # Examples
1358///
1359/// ```
1360/// # use logp::tsallis_divergence;
1361/// // D_alpha^T(p || p) = 0 for any valid alpha.
1362/// let p = [0.4, 0.6];
1363/// assert!(tsallis_divergence(&p, &p, 2.0, 1e-9).unwrap().abs() < 1e-12);
1364///
1365/// // Non-negative.
1366/// let q = [0.5, 0.5];
1367/// assert!(tsallis_divergence(&p, &q, 0.5, 1e-9).unwrap() >= -1e-12);
1368///
1369/// // alpha = 1.0 returns KL divergence (Shannon limit).
1370/// let kl = logp::kl_divergence(&p, &q, 1e-9).unwrap();
1371/// let t1 = tsallis_divergence(&p, &q, 1.0, 1e-9).unwrap();
1372/// assert!((t1 - kl).abs() < 1e-12);
1373/// ```
1374pub fn tsallis_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1375 if (alpha - 1.0).abs() < 1e-12 {
1376 return kl_divergence(p, q, tol);
1377 }
1378 Ok((rho_alpha(p, q, alpha, tol)? - 1.0) / (alpha - 1.0))
1379}
1380
1381/// Amari alpha-divergence: a one-parameter family from information geometry that
1382/// continuously interpolates between forward KL, reverse KL, and squared Hellinger.
1383///
1384/// For \(\alpha \notin \{-1, 1\}\):
1385///
1386/// \[D^\alpha(p : q) = \frac{4}{1 - \alpha^2}\left(1 - \rho_{\frac{1-\alpha}{2}}(p, q)\right)\]
1387///
1388/// # Key properties
1389///
1390/// - **\(\alpha = -1\)**: recovers \(D_{KL}(p \| q)\), the forward KL divergence.
1391/// - **\(\alpha = +1\)**: recovers \(D_{KL}(q \| p)\), the reverse KL divergence.
1392/// - **\(\alpha = 0\)**: gives \(4(1 - BC(p,q)) = 4\,H^2(p,q)\), proportional to
1393/// the squared Hellinger distance.
1394/// - **Duality**: \(D^\alpha(p : q) = D^{-\alpha}(q : p)\). Swapping the sign of
1395/// \(\alpha\) is the same as swapping the arguments.
1396/// - **Non-negative**: \(D^\alpha(p : q) \ge 0\), with equality iff \(p = q\).
1397/// - **Information geometry**: the Amari family parameterizes the \(\alpha\)-connections
1398/// on the statistical manifold (Amari & Nagaoka, 2000).
1399///
1400/// # Examples
1401///
1402/// ```
1403/// # use logp::{amari_alpha_divergence, kl_divergence, hellinger_squared};
1404/// let p = [0.3, 0.7];
1405/// let q = [0.5, 0.5];
1406/// let tol = 1e-9;
1407///
1408/// // alpha = -1 gives forward KL(p || q).
1409/// let amari_neg1 = amari_alpha_divergence(&p, &q, -1.0, tol).unwrap();
1410/// let kl_fwd = kl_divergence(&p, &q, tol).unwrap();
1411/// assert!((amari_neg1 - kl_fwd).abs() < 1e-6);
1412///
1413/// // alpha = 0 gives 4 * H^2(p, q).
1414/// let amari_0 = amari_alpha_divergence(&p, &q, 0.0, tol).unwrap();
1415/// let h2 = hellinger_squared(&p, &q, tol).unwrap();
1416/// assert!((amari_0 - 4.0 * h2).abs() < 1e-10);
1417/// ```
1418pub fn amari_alpha_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
1419 if !alpha.is_finite() {
1420 return Err(Error::InvalidAlpha {
1421 alpha,
1422 forbidden: f64::NAN,
1423 });
1424 }
1425 // Numerically stable handling near ±1.
1426 let eps = tol.sqrt();
1427 if (alpha + 1.0).abs() <= eps {
1428 return kl_divergence(p, q, tol);
1429 }
1430 if (alpha - 1.0).abs() <= eps {
1431 return kl_divergence(q, p, tol);
1432 }
1433 let t = (1.0 - alpha) / 2.0;
1434 let rho = rho_alpha(p, q, t, tol)?;
1435 Ok((4.0 / (1.0 - alpha * alpha)) * (1.0 - rho))
1436}
1437
1438/// Csiszar f-divergence: the most general class of divergences that respect
1439/// sufficient statistics (information monotonicity).
1440///
1441/// \[D_f(p \| q) = \sum_i q_i \, f\!\left(\frac{p_i}{q_i}\right)\]
1442///
1443/// where \(f\) is a convex function with \(f(1) = 0\).
1444///
1445/// # Information monotonicity theorem
1446///
1447/// The defining property of f-divergences (Csiszar, 1967): for any Markov kernel
1448/// (stochastic map) \(T\),
1449///
1450/// \[D_f(Tp \| Tq) \le D_f(p \| q)\]
1451///
1452/// Coarse-graining (merging bins) cannot increase the divergence. This is the
1453/// information-theoretic analogue of the data processing inequality.
1454///
1455/// # Common f-generators
1456///
1457/// | Divergence | \(f(t)\) |
1458/// |---|---|
1459/// | KL divergence | \(t \ln t\) |
1460/// | Reverse KL | \(-\ln t\) |
1461/// | Squared Hellinger | \((\sqrt{t} - 1)^2\) |
1462/// | Total variation | \(\tfrac{1}{2} |t - 1|\) |
1463/// | Chi-squared | \((t - 1)^2\) |
1464/// | Jensen-Shannon | \(t \ln t - (1+t) \ln \tfrac{1+t}{2}\) |
1465///
1466/// # Convention
1467///
1468/// This function uses Csiszar's original convention: \(D_f(p \| q) = \sum q_i f(p_i/q_i)\).
1469/// Some textbooks reverse the roles of \(p\) and \(q\), writing
1470/// \(\sum p_i f(q_i/p_i)\), which silently computes the *conjugate* divergence
1471/// \(D_{f^*}\) where \(f^*(t) = t\,f(1/t)\). The generators in the table above
1472/// follow this function's convention.
1473///
1474/// # Edge cases
1475///
1476/// When `q_i = 0`:
1477/// - if `p_i = 0`, the contribution is treated as 0 (by continuity).
1478/// - if `p_i > 0`, the divergence is infinite; we return an error.
1479///
1480/// # Examples
1481///
1482/// ```
1483/// # use logp::{csiszar_f_divergence, kl_divergence};
1484/// let p = [0.3, 0.7];
1485/// let q = [0.5, 0.5];
1486///
1487/// // f(t) = t*ln(t) recovers KL divergence.
1488/// let cs = csiszar_f_divergence(&p, &q, |t| t * t.ln(), 1e-9).unwrap();
1489/// let kl = kl_divergence(&p, &q, 1e-9).unwrap();
1490/// assert!((cs - kl).abs() < 1e-10);
1491///
1492/// // f(t) = (t - 1)^2 gives chi-squared divergence.
1493/// let chi2 = csiszar_f_divergence(&p, &q, |t| (t - 1.0).powi(2), 1e-9).unwrap();
1494/// assert!(chi2 >= 0.0);
1495/// ```
1496pub fn csiszar_f_divergence(p: &[f64], q: &[f64], f: impl Fn(f64) -> f64, tol: f64) -> Result<f64> {
1497 ensure_same_len(p, q)?;
1498 validate_simplex(p, tol)?;
1499 validate_simplex(q, tol)?;
1500
1501 let mut d = 0.0;
1502 for (&pi, &qi) in p.iter().zip(q.iter()) {
1503 if qi == 0.0 {
1504 if pi == 0.0 {
1505 continue;
1506 }
1507 return Err(Error::Domain("f-divergence undefined: q_i=0 while p_i>0"));
1508 }
1509 d += qi * f(pi / qi);
1510 }
1511 Ok(d)
1512}
1513
1514/// Total variation distance: half the L1 norm between two distributions.
1515///
1516/// \[TV(p, q) = \frac{1}{2} \sum_i |p_i - q_i|\]
1517///
1518/// Equivalently, this is the Csiszar f-divergence with generator
1519/// \(f(t) = \frac{1}{2} |t - 1|\).
1520///
1521/// # Key properties
1522///
1523/// - **Metric**: symmetric, satisfies the triangle inequality, and \(TV(p, p) = 0\).
1524/// - **Bounded**: \(TV \in [0, 1]\).
1525/// - **Pinsker's inequality**: \(TV(p, q) \le \sqrt{\frac{1}{2} D_{KL}(p \| q)}\).
1526///
1527/// # Examples
1528///
1529/// ```
1530/// # use logp::total_variation;
1531/// // TV(p, p) = 0.
1532/// let p = [0.3, 0.7];
1533/// assert!(total_variation(&p, &p, 1e-9).unwrap().abs() < 1e-15);
1534///
1535/// // Disjoint supports: TV = 1.
1536/// let a = [1.0, 0.0];
1537/// let b = [0.0, 1.0];
1538/// assert!((total_variation(&a, &b, 1e-9).unwrap() - 1.0).abs() < 1e-12);
1539///
1540/// // Symmetric.
1541/// let q = [0.5, 0.5];
1542/// let tv_pq = total_variation(&p, &q, 1e-9).unwrap();
1543/// let tv_qp = total_variation(&q, &p, 1e-9).unwrap();
1544/// assert!((tv_pq - tv_qp).abs() < 1e-15);
1545/// ```
1546pub fn total_variation(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1547 ensure_same_len(p, q)?;
1548 validate_simplex(p, tol)?;
1549 validate_simplex(q, tol)?;
1550 let tv: f64 = p
1551 .iter()
1552 .zip(q.iter())
1553 .map(|(&pi, &qi)| (pi - qi).abs())
1554 .sum();
1555 Ok(0.5 * tv)
1556}
1557
1558/// Chi-squared divergence: a member of the Csiszar f-divergence family that is
1559/// particularly sensitive to tail differences.
1560///
1561/// \[\chi^2(p \| q) = \sum_i \frac{(p_i - q_i)^2}{q_i}\]
1562///
1563/// Equivalently, the f-divergence with generator \(f(t) = (t - 1)^2\).
1564///
1565/// # Key properties
1566///
1567/// - **Non-negative**: \(\chi^2 \ge 0\), with equality iff \(p = q\).
1568/// - **Upper bounds KL**: \(D_{KL}(p \| q) \le \ln(1 + \chi^2(p \| q))\).
1569/// - **Sensitivity warning**: unbounded and extremely sensitive to small \(q_i\).
1570/// When \(q_i\) is near zero but \(p_i\) is not, the ratio \((p_i - q_i)^2 / q_i\)
1571/// can be very large even when KL divergence would be moderate.
1572///
1573/// # Domain
1574///
1575/// Requires \(q_i > 0\) whenever \(p_i > 0\) (same as KL).
1576///
1577/// # Examples
1578///
1579/// ```
1580/// # use logp::chi_squared_divergence;
1581/// // chi^2(p, p) = 0.
1582/// let p = [0.3, 0.7];
1583/// assert!(chi_squared_divergence(&p, &p, 1e-9).unwrap().abs() < 1e-15);
1584///
1585/// // Non-negative.
1586/// let q = [0.5, 0.5];
1587/// assert!(chi_squared_divergence(&p, &q, 1e-9).unwrap() >= 0.0);
1588/// ```
1589pub fn chi_squared_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
1590 ensure_same_len(p, q)?;
1591 validate_simplex(p, tol)?;
1592 validate_simplex(q, tol)?;
1593 let mut d = 0.0;
1594 for (&pi, &qi) in p.iter().zip(q.iter()) {
1595 if qi == 0.0 {
1596 if pi == 0.0 {
1597 continue;
1598 }
1599 return Err(Error::Domain("chi-squared undefined: q_i=0 while p_i>0"));
1600 }
1601 let diff = pi - qi;
1602 d += diff * diff / qi;
1603 }
1604 Ok(d)
1605}
1606
/// Bregman generator: a convex function \(F\) and its gradient.
///
/// Implementors supply the potential and its gradient; [`bregman_divergence`]
/// and [`total_bregman_divergence`] consume this trait. See [`SquaredL2`] and
/// [`NegEntropy`] for the two reference implementations in this crate.
pub trait BregmanGenerator {
    /// Evaluate the potential \(F(x)\).
    fn f(&self, x: &[f64]) -> Result<f64>;

    /// Write \(\nabla F(x)\) into `out`. Implementations error when
    /// `out.len() != x.len()` (see [`SquaredL2`] / [`NegEntropy`]).
    fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()>;
}
1615
1616/// Bregman divergence: the gap between a convex function and its tangent approximation.
1617///
1618/// \[B_F(p, q) = F(p) - F(q) - \langle p - q,\, \nabla F(q) \rangle\]
1619///
1620/// # Key properties
1621///
1622/// - **Non-negative**: \(B_F(p, q) \ge 0\) by convexity of \(F\), with equality iff
1623/// \(p = q\).
1624/// - **Not symmetric** in general: \(B_F(p, q) \ne B_F(q, p)\).
1625/// - **Generalized Pythagorean theorem**: for an affine subspace \(S\) and its
1626/// Bregman projection \(q^* = \arg\min_{q \in S} B_F(p, q)\), the three-point
1627/// identity holds: \(B_F(p, q) = B_F(p, q^*) + B_F(q^*, q)\) for all \(q \in S\).
1628/// This is the foundation of dually flat geometry (Amari).
1629/// - **Not an f-divergence**: Bregman divergences are **not** information monotone
1630/// in general. They live in a different branch of the divergence taxonomy.
1631/// - **Examples**: squared Euclidean (\(F = \tfrac{1}{2}\|x\|^2\)) gives
1632/// \(B_F(p,q) = \tfrac{1}{2}\|p - q\|^2\); negative entropy
1633/// (\(F = \sum x_i \ln x_i\)) gives the KL divergence.
1634///
1635/// # Examples
1636///
1637/// ```
1638/// # use logp::{bregman_divergence, SquaredL2};
1639/// // Squared-L2 generator: B_F(p, q) = 0.5 * ||p - q||^2.
1640/// let gen = SquaredL2;
1641/// let p = [1.0, 2.0, 3.0];
1642/// let q = [1.5, 1.5, 2.5];
1643/// let b = bregman_divergence(&gen, &p, &q).unwrap();
1644/// let expected = 0.5 * ((0.5_f64).powi(2) + (0.5_f64).powi(2) + (0.5_f64).powi(2));
1645/// assert!((b - expected).abs() < 1e-12);
1646///
1647/// // B_F(p, p) = 0.
1648/// assert!(bregman_divergence(&gen, &p, &p).unwrap().abs() < 1e-15);
1649/// ```
1650pub fn bregman_divergence(gen: &impl BregmanGenerator, p: &[f64], q: &[f64]) -> Result<f64> {
1651 ensure_nonempty(p)?;
1652 ensure_same_len(p, q)?;
1653 let mut grad_q = vec![0.0; q.len()];
1654 gen.grad_into(q, &mut grad_q)?;
1655 let fp = gen.f(p)?;
1656 let fq = gen.f(q)?;
1657 let mut inner = 0.0;
1658 for i in 0..p.len() {
1659 inner += (p[i] - q[i]) * grad_q[i];
1660 }
1661 Ok(fp - fq - inner)
1662}
1663
1664/// Total Bregman divergence as shown in Nielsen’s taxonomy diagram:
1665///
1666/// \(tB_F(p,q) = \frac{B_F(p,q)}{\sqrt{1 + \|\nabla F(q)\|^2}}\).
1667///
1668/// # Examples
1669///
1670/// ```
1671/// # use logp::{total_bregman_divergence, bregman_divergence, SquaredL2};
1672/// let gen = SquaredL2;
1673/// let p = [1.0, 2.0];
1674/// let q = [3.0, 4.0];
1675///
1676/// let tb = total_bregman_divergence(&gen, &p, &q).unwrap();
1677///
1678/// // Total Bregman <= Bregman (normalization divides by >= 1).
1679/// let b = bregman_divergence(&gen, &p, &q).unwrap();
1680/// assert!(tb <= b + 1e-12);
1681/// ```
1682pub fn total_bregman_divergence(gen: &impl BregmanGenerator, p: &[f64], q: &[f64]) -> Result<f64> {
1683 ensure_nonempty(p)?;
1684 ensure_same_len(p, q)?;
1685 let mut grad_q = vec![0.0; q.len()];
1686 gen.grad_into(q, &mut grad_q)?;
1687 let fp = gen.f(p)?;
1688 let fq = gen.f(q)?;
1689 let mut inner = 0.0;
1690 for i in 0..p.len() {
1691 inner += (p[i] - q[i]) * grad_q[i];
1692 }
1693 let b = fp - fq - inner;
1694 let grad_norm_sq: f64 = grad_q.iter().map(|&x| x * x).sum();
1695 Ok(b / (1.0 + grad_norm_sq).sqrt())
1696}
1697
1698/// Squared Euclidean Bregman generator: \(F(x)=\tfrac12\|x\|_2^2\), \(\nabla F(x)=x\).
1699#[derive(Debug, Clone, Copy, Default)]
1700pub struct SquaredL2;
1701
1702impl BregmanGenerator for SquaredL2 {
1703 fn f(&self, x: &[f64]) -> Result<f64> {
1704 ensure_nonempty(x)?;
1705 Ok(0.5 * x.iter().map(|&v| v * v).sum::<f64>())
1706 }
1707
1708 fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()> {
1709 ensure_nonempty(x)?;
1710 if out.len() != x.len() {
1711 return Err(Error::LengthMismatch(out.len(), x.len()));
1712 }
1713 out.copy_from_slice(x);
1714 Ok(())
1715 }
1716}
1717
1718/// Negative-entropy Bregman generator: \(F(x) = \sum_i x_i \ln x_i\),
1719/// \(\nabla F(x)_i = 1 + \ln x_i\).
1720///
1721/// The Bregman divergence with this generator is the (unnormalized) KL divergence:
1722/// \(B_F(p, q) = \sum_i p_i \ln(p_i / q_i) - \sum_i (p_i - q_i)\).
1723/// When \(p\) and \(q\) are normalized (simplex), the second sum vanishes and
1724/// \(B_F = D_{KL}(p \| q)\).
1725///
1726/// This connects the information-theoretic (f-divergence) and geometric (Bregman)
1727/// views of KL divergence. The dually-flat structure of the probability simplex
1728/// under this generator is the foundation of information geometry (Amari & Nagaoka, 2000).
1729#[derive(Debug, Clone, Copy, Default)]
1730pub struct NegEntropy;
1731
1732impl BregmanGenerator for NegEntropy {
1733 fn f(&self, x: &[f64]) -> Result<f64> {
1734 ensure_nonempty(x)?;
1735 let mut s = 0.0;
1736 for &xi in x {
1737 if xi < 0.0 {
1738 return Err(Error::Domain("NegEntropy: input must be nonnegative"));
1739 }
1740 if xi > 0.0 {
1741 s += xi * xi.ln();
1742 }
1743 }
1744 Ok(s)
1745 }
1746
1747 fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()> {
1748 ensure_nonempty(x)?;
1749 if out.len() != x.len() {
1750 return Err(Error::LengthMismatch(out.len(), x.len()));
1751 }
1752 for (o, &xi) in out.iter_mut().zip(x.iter()) {
1753 if xi <= 0.0 {
1754 return Err(Error::Domain("NegEntropy grad: input must be positive"));
1755 }
1756 *o = 1.0 + xi.ln();
1757 }
1758 Ok(())
1759 }
1760}
1761
1762#[cfg(test)]
1763mod tests {
1764 use super::*;
1765 use proptest::prelude::*;
1766
1767 const TOL: f64 = 1e-9;
1768
1769 fn simplex_vec(len: usize) -> impl Strategy<Value = Vec<f64>> {
1770 // Draw nonnegative weights then normalize.
1771 prop::collection::vec(0.0f64..10.0, len).prop_map(|mut v| {
1772 let s: f64 = v.iter().sum();
1773 if s == 0.0 {
1774 v[0] = 1.0;
1775 return v;
1776 }
1777 for x in v.iter_mut() {
1778 *x /= s;
1779 }
1780 v
1781 })
1782 }
1783
1784 fn simplex_vec_pos(len: usize, eps: f64) -> impl Strategy<Value = Vec<f64>> {
1785 prop::collection::vec(0.0f64..10.0, len).prop_map(move |mut v| {
1786 // Add a small floor to avoid exact zeros (needed for KL-style domains).
1787 for x in v.iter_mut() {
1788 *x += eps;
1789 }
1790 let s: f64 = v.iter().sum();
1791 for x in v.iter_mut() {
1792 *x /= s;
1793 }
1794 v
1795 })
1796 }
1797
1798 fn random_partition(n: usize) -> impl Strategy<Value = Vec<usize>> {
1799 // Partition indices into k buckets (k chosen implicitly).
1800 // We generate a label in [0, n) for each index and later reindex to compact labels.
1801 prop::collection::vec(0usize..n, n).prop_map(|labels| {
1802 // Compress labels to 0..k-1 while preserving equality pattern.
1803 use std::collections::BTreeMap;
1804 let mut map = BTreeMap::<usize, usize>::new();
1805 let mut next = 0usize;
1806 labels
1807 .into_iter()
1808 .map(|l| {
1809 *map.entry(l).or_insert_with(|| {
1810 let id = next;
1811 next += 1;
1812 id
1813 })
1814 })
1815 .collect::<Vec<_>>()
1816 })
1817 }
1818
1819 fn coarse_grain(p: &[f64], labels: &[usize]) -> Vec<f64> {
1820 let k = labels.iter().copied().max().unwrap_or(0) + 1;
1821 let mut out = vec![0.0; k];
1822 for (i, &lab) in labels.iter().enumerate() {
1823 out[lab] += p[i];
1824 }
1825 out
1826 }
1827
1828 fn l1(p: &[f64], q: &[f64]) -> f64 {
1829 p.iter().zip(q.iter()).map(|(&a, &b)| (a - b).abs()).sum()
1830 }
1831
1832 #[test]
1833 fn test_entropy_unchecked() {
1834 let p = [0.5, 0.5];
1835 let h = entropy_unchecked(&p);
1836 // -0.5*ln(0.5) - 0.5*ln(0.5) = -ln(0.5) = ln(2)
1837 assert!((h - core::f64::consts::LN_2).abs() < 1e-12);
1838 }
1839
/// Disjoint distributions attain the JS upper bound ln 2 (nats); the
/// result must still lie within [0, ln 2].
#[test]
fn js_is_bounded_by_ln2() {
    let p = [1.0, 0.0];
    let q = [0.0, 1.0];
    let js = jensen_shannon_divergence(&p, &q, TOL).unwrap();
    assert!((0.0..=core::f64::consts::LN_2 + 1e-12).contains(&js));
}
1848
// A product measure (independent X, Y) carries zero mutual information.
#[test]
fn mutual_information_independent_is_zero() {
    // p(x,y) = p(x)p(y) ⇒ I(X;Y)=0
    let p_x = [0.5, 0.5];
    let p_y = [0.25, 0.75];
    // Row-major 2x2:
    // [0.125, 0.375,
    //  0.125, 0.375]
    let p_xy = [
        p_x[0] * p_y[0],
        p_x[0] * p_y[1],
        p_x[1] * p_y[0],
        p_x[1] * p_y[1],
    ];
    let mi = mutual_information(&p_xy, 2, 2, TOL).unwrap();
    assert!(mi.abs() < 1e-12, "mi={}", mi);
}

// Fully dependent variables: MI equals the (shared) entropy, here ln 2.
#[test]
fn mutual_information_perfect_correlation_is_ln2() {
    // X=Y uniform bit ⇒ I(X;Y)=ln 2 (nats)
    let p_xy = [0.5, 0.0, 0.0, 0.5]; // 2x2 diagonal
    let mi = mutual_information(&p_xy, 2, 2, TOL).unwrap();
    assert!((mi - core::f64::consts::LN_2).abs() < 1e-12, "mi={}", mi);
}
1874
// For generator F(x) = ||x||^2 / 2, the Bregman divergence reduces to half
// the squared Euclidean distance.
#[test]
fn bregman_squared_l2_matches_half_l2() {
    let gen = SquaredL2;
    let p = [1.0, 2.0, 3.0];
    let q = [1.5, 1.5, 2.5];
    let b = bregman_divergence(&gen, &p, &q).unwrap();
    let expected = 0.5
        * p.iter()
            .zip(q.iter())
            .map(|(&a, &b)| (a - b) * (a - b))
            .sum::<f64>();
    assert!((b - expected).abs() < 1e-12);
}
1888
// --- Entropy tests ---

#[test]
fn entropy_nats_uniform_is_ln_n() {
    // Uniform distribution over n items: H = ln(n)
    for n in [2, 4, 8, 16] {
        let p: Vec<f64> = vec![1.0 / n as f64; n];
        let h = entropy_nats(&p, TOL).unwrap();
        let expected = (n as f64).ln();
        assert!(
            (h - expected).abs() < 1e-12,
            "n={n}: h={h} expected={expected}"
        );
    }
}

// A deterministic (single-outcome) distribution has zero entropy.
#[test]
fn entropy_nats_singleton_is_zero() {
    let h = entropy_nats(&[1.0], TOL).unwrap();
    assert!(h.abs() < 1e-15);
}

// Bits and nats differ by exactly a factor of ln 2.
#[test]
fn entropy_bits_converts_correctly() {
    let p = [0.25, 0.75];
    let nats = entropy_nats(&p, TOL).unwrap();
    let bits = entropy_bits(&p, TOL).unwrap();
    assert!((bits - nats / core::f64::consts::LN_2).abs() < 1e-12);
}

// --- Cross-entropy tests ---

// Standard decomposition: H(p, q) = H(p) + KL(p || q).
#[test]
fn cross_entropy_identity_h_pq_eq_h_p_plus_kl() {
    let p = [0.3, 0.7];
    let q = [0.5, 0.5];
    let h_pq = cross_entropy_nats(&p, &q, TOL).unwrap();
    let h_p = entropy_nats(&p, TOL).unwrap();
    let kl = kl_divergence(&p, &q, TOL).unwrap();
    assert!((h_pq - (h_p + kl)).abs() < 1e-12);
}
1930
// Cross-entropy is +inf (domain error) when q has a zero where p has mass.
#[test]
fn cross_entropy_rejects_zero_q_with_positive_p() {
    let p = [0.5, 0.5];
    let q = [1.0, 0.0]; // q[1]=0 but p[1]=0.5
    assert!(cross_entropy_nats(&p, &q, TOL).is_err());
}

// --- Validate / normalize tests ---

#[test]
fn validate_simplex_accepts_valid() {
    assert!(validate_simplex(&[0.3, 0.7], TOL).is_ok());
    assert!(validate_simplex(&[1.0], TOL).is_ok());
}

#[test]
fn validate_simplex_rejects_bad_sum() {
    assert!(validate_simplex(&[0.3, 0.6], TOL).is_err()); // sum=0.9
}

// Negative mass must be rejected even when the total still sums to 1.
#[test]
fn validate_simplex_rejects_negative() {
    assert!(validate_simplex(&[1.5, -0.5], TOL).is_err());
}

#[test]
fn validate_simplex_rejects_empty() {
    assert!(validate_simplex(&[], TOL).is_err());
}

// normalize_in_place returns the original sum and scales entries to sum 1.
#[test]
fn normalize_in_place_works() {
    let mut v = vec![2.0, 3.0];
    let s = normalize_in_place(&mut v).unwrap();
    assert!((s - 5.0).abs() < 1e-12);
    assert!((v[0] - 0.4).abs() < 1e-12);
    assert!((v[1] - 0.6).abs() < 1e-12);
}

#[test]
fn normalize_in_place_rejects_zero_sum() {
    let mut v = vec![0.0, 0.0];
    assert!(normalize_in_place(&mut v).is_err());
}
1975
// --- Hellinger / Bhattacharyya tests ---

// Identity of indiscernibles: H(p, p) = 0.
#[test]
fn hellinger_identical_is_zero() {
    let p = [0.25, 0.75];
    let h = hellinger(&p, &p, TOL).unwrap();
    assert!(h.abs() < 1e-12);
}

// Squared Hellinger distance is bounded in [0, 1].
#[test]
fn hellinger_squared_in_unit_interval() {
    let p = [0.1, 0.9];
    let q = [0.9, 0.1];
    let h2 = hellinger_squared(&p, &q, TOL).unwrap();
    assert!((-1e-12..=1.0 + 1e-12).contains(&h2), "h2={h2}");
}

// BC(p, p) = sum sqrt(p_i * p_i) = sum p_i = 1.
#[test]
fn bhattacharyya_coeff_identical_is_one() {
    let p = [0.3, 0.7];
    let bc = bhattacharyya_coeff(&p, &p, TOL).unwrap();
    assert!((bc - 1.0).abs() < 1e-12);
}

// D_B(p, p) = -ln(BC(p, p)) = -ln(1) = 0.
#[test]
fn bhattacharyya_distance_identical_is_zero() {
    let p = [0.5, 0.5];
    let d = bhattacharyya_distance(&p, &p, TOL).unwrap();
    assert!(d.abs() < 1e-12);
}

// --- Renyi / Tsallis tests ---

#[test]
fn renyi_alpha_half_on_simple_case() {
    let p = [0.5, 0.5];
    let q = [0.25, 0.75];
    // alpha=0.5 should be well-defined and non-negative
    let r = renyi_divergence(&p, &q, 0.5, TOL).unwrap();
    assert!(r >= -1e-12, "renyi={r}");
}

#[test]
fn renyi_identical_is_zero() {
    let p = [0.3, 0.7];
    let r = renyi_divergence(&p, &p, 2.0, TOL).unwrap();
    assert!(r.abs() < 1e-12, "renyi(p,p)={r}");
}

#[test]
fn tsallis_identical_is_zero() {
    let p = [0.4, 0.6];
    let t = tsallis_divergence(&p, &p, 2.0, TOL).unwrap();
    assert!(t.abs() < 1e-12, "tsallis(p,p)={t}");
}
2031
// --- Digamma test ---

// Classical special value psi(1) = -gamma (Euler–Mascheroni constant).
#[test]
fn digamma_at_one_is_neg_euler_mascheroni() {
    let psi1 = digamma(1.0);
    // digamma(1) = -gamma where gamma ~= 0.57721566490153286
    assert!(
        (psi1 - (-0.57721566490153286)).abs() < 1e-12,
        "psi(1)={psi1}"
    );
}

// The functional equation psi(x+1) = psi(x) + 1/x pins the implementation
// across both the series and asymptotic regimes.
#[test]
fn digamma_recurrence_relation() {
    // digamma(x+1) = digamma(x) + 1/x
    for &x in &[1.0, 2.0, 3.5, 10.0] {
        let lhs = digamma(x + 1.0);
        let rhs = digamma(x) + 1.0 / x;
        assert!(
            (lhs - rhs).abs() < 1e-12,
            "recurrence at x={x}: {lhs} vs {rhs}"
        );
    }
}

#[test]
fn pmi_independent_is_zero() {
    // PMI(x,y) = log(p(x,y) / (p(x)*p(y))). If independent: p(x,y) = p(x)*p(y)
    let pmi_val = pmi(0.06, 0.3, 0.2).unwrap(); // 0.3 * 0.2 = 0.06
    assert!(
        pmi_val.abs() < 1e-10,
        "PMI of independent events should be 0: {pmi_val}"
    );
}

#[test]
fn pmi_positive_for_correlated() {
    // If p(x,y) > p(x)*p(y), events are positively correlated
    let pmi_val = pmi(0.4, 0.5, 0.5).unwrap(); // 0.4 > 0.5*0.5 = 0.25
    assert!(
        pmi_val > 0.0,
        "correlated events should have positive PMI: {pmi_val}"
    );
}
2076
// Rényi divergence converges to KL as alpha -> 1 (from below here).
#[test]
fn renyi_approaches_kl_as_alpha_to_one() {
    let p = [0.3, 0.7];
    let q = [0.5, 0.5];
    let tol = 1e-9;
    let kl = kl_divergence(&p, &q, tol).unwrap();
    // Renyi(alpha) -> KL as alpha -> 1
    let r099 = renyi_divergence(&p, &q, 0.99, tol).unwrap();
    let r0999 = renyi_divergence(&p, &q, 0.999, tol).unwrap();
    assert!((r099 - kl).abs() < 0.01, "Renyi(0.99)={r099}, KL={kl}");
    assert!((r0999 - kl).abs() < 0.001, "Renyi(0.999)={r0999}, KL={kl}");
}

#[test]
fn amari_alpha_neg1_is_kl_forward() {
    // Amari alpha=-1 returns KL(p||q) per the implementation
    let p = [0.3, 0.7];
    let q = [0.5, 0.5];
    let tol = 1e-9;
    let kl_pq = kl_divergence(&p, &q, tol).unwrap();
    let amari = amari_alpha_divergence(&p, &q, -1.0, tol).unwrap();
    assert!(
        (amari - kl_pq).abs() < 1e-6,
        "Amari(-1)={amari}, KL(p||q)={kl_pq}"
    );
}

#[test]
fn amari_alpha_pos1_is_kl_reverse() {
    // Amari alpha=+1 returns KL(q||p) per the implementation
    let p = [0.3, 0.7];
    let q = [0.5, 0.5];
    let tol = 1e-9;
    let kl_qp = kl_divergence(&q, &p, tol).unwrap();
    let amari = amari_alpha_divergence(&p, &q, 1.0, tol).unwrap();
    assert!(
        (amari - kl_qp).abs() < 1e-6,
        "Amari(1)={amari}, KL(q||p)={kl_qp}"
    );
}

// Csiszár f-divergence with generator f(t) = t ln t recovers KL.
#[test]
fn csiszar_with_kl_generator_matches_kl() {
    // f(t) = t*ln(t) gives KL divergence
    let p = [0.3, 0.7];
    let q = [0.5, 0.5];
    let tol = 1e-9;
    let kl = kl_divergence(&p, &q, tol).unwrap();
    let cs = csiszar_f_divergence(&p, &q, |t| t * t.ln(), tol).unwrap();
    assert!((cs - kl).abs() < 1e-6, "Csiszar(t*ln(t))={cs}, KL={kl}");
}

// When Y is a deterministic function of X, I(X;Y) = H(X).
#[test]
fn mutual_information_deterministic_equals_entropy() {
    // If Y = f(X), MI(X;Y) = H(X)
    // Joint: p(x=0,y=0)=0.3, p(x=1,y=1)=0.7
    let p_xy = [0.3, 0.0, 0.0, 0.7]; // 2x2 joint
    let mi = mutual_information(&p_xy, 2, 2, 1e-9).unwrap();
    let h_x = entropy_nats(&[0.3, 0.7], 1e-9).unwrap();
    assert!((mi - h_x).abs() < 1e-6, "MI={mi}, H(X)={h_x}");
}
2138
// Property tests: core non-negativity / boundedness / monotonicity
// invariants over randomly drawn simplex vectors.
proptest! {
    // Gibbs' inequality: KL(p||q) >= 0 for any p, q on the simplex.
    #[test]
    fn kl_is_nonnegative(p in simplex_vec_pos(8, 1e-6), q in simplex_vec_pos(8, 1e-6)) {
        let d = kl_divergence(&p, &q, 1e-6).unwrap();
        prop_assert!(d >= -1e-12);
    }

    // JS is always in [0, ln 2] (nats) — see the module-level invariants.
    #[test]
    fn js_is_bounded(p in simplex_vec(16), q in simplex_vec(16)) {
        let js = jensen_shannon_divergence(&p, &q, 1e-6).unwrap();
        prop_assert!(js >= -1e-12);
        prop_assert!(js <= core::f64::consts::LN_2 + 1e-9);
    }

    #[test]
    fn prop_kl_gaussians_is_nonnegative(
        mu1 in prop::collection::vec(-10.0f64..10.0, 1..16),
        std1 in prop::collection::vec(0.1f64..5.0, 1..16),
        mu2 in prop::collection::vec(-10.0f64..10.0, 1..16),
        std2 in prop::collection::vec(0.1f64..5.0, 1..16),
    ) {
        // Truncate all four draws to a common length before comparing.
        let n = mu1.len().min(std1.len()).min(mu2.len()).min(std2.len());
        let d = kl_divergence_gaussians(&mu1[..n], &std1[..n], &mu2[..n], &std2[..n]).unwrap();
        // KL divergence is always non-negative.
        prop_assert!(d >= -1e-12);
    }

    #[test]
    fn prop_kl_gaussians_is_zero_for_identical(
        mu in prop::collection::vec(-10.0f64..10.0, 1..16),
        std in prop::collection::vec(0.1f64..5.0, 1..16),
    ) {
        let n = mu.len().min(std.len());
        let d = kl_divergence_gaussians(&mu[..n], &std[..n], &mu[..n], &std[..n]).unwrap();
        prop_assert!(d.abs() < 1e-12);
    }

    // Information monotonicity (Csiszár): merging bins cannot increase an
    // f-divergence — the defining property of the f-divergence family.
    #[test]
    fn f_divergence_monotone_under_coarse_graining(
        p in simplex_vec_pos(12, 1e-6),
        q in simplex_vec_pos(12, 1e-6),
        labels in random_partition(12),
    ) {
        // Use KL as an f-divergence instance: f(t)=t ln t.
        // D_KL(p||q) = Σ q_i f(p_i/q_i).
        let f = |t: f64| if t == 0.0 { 0.0 } else { t * t.ln() };
        let d_f = csiszar_f_divergence(&p, &q, f, 1e-6).unwrap();

        let pc = coarse_grain(&p, &labels);
        let qc = coarse_grain(&q, &labels);
        let d_fc = csiszar_f_divergence(&pc, &qc, f, 1e-6).unwrap();

        // Coarse graining should not increase.
        prop_assert!(d_fc <= d_f + 1e-9);
    }
}
2195
// Heavier “theorem-ish” checks: keep case count modest so `cargo test` stays fast.
proptest! {
    #![proptest_config(ProptestConfig { cases: 64, .. ProptestConfig::default() })]

    #[test]
    fn pinsker_kl_lower_bounds_l1_squared(
        p in simplex_vec_pos(16, 1e-6),
        q in simplex_vec_pos(16, 1e-6),
    ) {
        // Pinsker: TV(p,q)^2 <= (1/2) KL(p||q)
        // where TV = (1/2)||p-q||_1. Rearranged: KL(p||q) >= 0.5 * ||p-q||_1^2.
        let kl = kl_divergence(&p, &q, 1e-6).unwrap();
        let d1 = l1(&p, &q);
        prop_assert!(kl + 1e-9 >= 0.5 * d1 * d1, "kl={kl} l1={d1}");
    }

    #[test]
    fn sqrt_js_satisfies_triangle_inequality(
        p in simplex_vec(12),
        q in simplex_vec(12),
        r in simplex_vec(12),
    ) {
        // Known fact: sqrt(JS) is a metric on the simplex.
        // max(0.0) guards against tiny negative values from floating point.
        let js_pq = jensen_shannon_divergence(&p, &q, 1e-6).unwrap().max(0.0).sqrt();
        let js_qr = jensen_shannon_divergence(&q, &r, 1e-6).unwrap().max(0.0).sqrt();
        let js_pr = jensen_shannon_divergence(&p, &r, 1e-6).unwrap().max(0.0).sqrt();
        prop_assert!(js_pr <= js_pq + js_qr + 1e-7, "js_pr={js_pr} js_pq+js_qr={}", js_pq+js_qr);
    }

    // Defining identity: I(X;Y) = KL(joint || product-of-marginals).
    #[test]
    fn mutual_information_equals_kl_to_product(
        // Ensure strictly positive so KL domains are satisfied.
        p_xy in simplex_vec_pos(16, 1e-6),
        nx in 2usize..=4,
        ny in 2usize..=4,
    ) {
        // We need p_xy to have length nx*ny; we will truncate/renormalize a fixed-length draw.
        let n = nx * ny;
        let mut joint = p_xy;
        joint.truncate(n);
        // Renormalize after truncation.
        let _ = normalize_in_place(&mut joint).unwrap();

        // Compute MI via the dedicated function.
        let mi = mutual_information(&joint, nx, ny, 1e-6).unwrap();

        // Compute product of marginals and KL(joint || product).
        let mut p_x = vec![0.0; nx];
        let mut p_y = vec![0.0; ny];
        for i in 0..nx {
            for j in 0..ny {
                let p = joint[i * ny + j];
                p_x[i] += p;
                p_y[j] += p;
            }
        }
        let mut prod = vec![0.0; n];
        for i in 0..nx {
            for j in 0..ny {
                prod[i * ny + j] = p_x[i] * p_y[j];
            }
        }
        let kl = kl_divergence(&joint, &prod, 1e-6).unwrap();

        prop_assert!((mi - kl).abs() < 1e-9, "mi={mi} kl={kl}");
    }

    // Hellinger distance is a genuine metric, so the triangle inequality holds.
    #[test]
    fn hellinger_satisfies_triangle_inequality(
        p in simplex_vec(8),
        q in simplex_vec(8),
        r in simplex_vec(8),
    ) {
        let h_pq = hellinger(&p, &q, 1e-6).unwrap();
        let h_qr = hellinger(&q, &r, 1e-6).unwrap();
        let h_pr = hellinger(&p, &r, 1e-6).unwrap();
        prop_assert!(h_pr <= h_pq + h_qr + 1e-7, "h_pr={h_pr} h_pq+h_qr={}", h_pq + h_qr);
    }
}
2275
// --- total_bregman_divergence ---

#[test]
fn total_bregman_le_bregman() {
    // tB_F(p, q) <= B_F(p, q) because the denominator sqrt(1 + ||grad||^2) >= 1.
    let gen = SquaredL2;
    let p = [1.0, 2.0, 3.0];
    let q = [4.0, 5.0, 6.0];
    let b = bregman_divergence(&gen, &p, &q).unwrap();
    let tb = total_bregman_divergence(&gen, &p, &q).unwrap();
    assert!(tb <= b + 1e-12, "total_bregman={tb} > bregman={b}");
    assert!(tb >= 0.0);
}

// Total Bregman inherits the identity-of-indiscernibles property.
#[test]
fn total_bregman_is_zero_for_identical() {
    let gen = SquaredL2;
    let p = [1.0, 2.0];
    let tb = total_bregman_divergence(&gen, &p, &p).unwrap();
    assert!(tb.abs() < 1e-15);
}

// --- rho_alpha ---

// The alpha-correlation of a distribution with itself is 1 for every alpha.
#[test]
fn rho_alpha_self_is_one() {
    let p = [0.1, 0.2, 0.3, 0.4];
    for alpha in [0.0, 0.25, 0.5, 0.75, 1.0, 2.0, -1.0] {
        let r = rho_alpha(&p, &p, alpha, TOL).unwrap();
        assert!((r - 1.0).abs() < 1e-10, "rho_alpha(p,p,{alpha})={r}");
    }
}

// --- digamma negative domain ---

// Non-positive arguments are outside the supported domain and yield NaN.
#[test]
fn digamma_nonpositive_is_nan() {
    assert!(digamma(0.0).is_nan());
    assert!(digamma(-1.0).is_nan());
    assert!(digamma(-100.0).is_nan());
}
2317
// --- pmi edge cases ---

// Zero joint probability contributes nothing: return 0 by convention.
#[test]
fn pmi_zero_joint_returns_zero() {
    assert_eq!(pmi(0.0, 0.5, 0.5).unwrap(), 0.0);
}

#[test]
fn pmi_zero_marginal_with_zero_joint_returns_zero() {
    // When px or py is zero AND pxy is also zero, return 0 by convention.
    assert_eq!(pmi(0.0, 0.0, 0.5).unwrap(), 0.0);
    assert_eq!(pmi(0.0, 0.5, 0.0).unwrap(), 0.0);
}

#[test]
fn pmi_all_zero_returns_zero() {
    assert_eq!(pmi(0.0, 0.0, 0.0).unwrap(), 0.0);
}

// --- Enrich-motivated tests ---

// Reference values from DLMF §5.4: psi at small integers/half-integer
// expressed via the Euler–Mascheroni constant and harmonic numbers.
#[test]
fn digamma_at_dlmf_reference_values() {
    // psi(0.5) = -gamma - 2*ln(2)
    let gamma = 0.57721566490153286;
    let expected_half = -gamma - 2.0 * core::f64::consts::LN_2;
    let psi_half = digamma(0.5);
    assert!(
        (psi_half - expected_half).abs() < 1e-12,
        "psi(0.5)={psi_half} expected={expected_half}"
    );

    // psi(2) = 1 - gamma
    let expected_2 = 1.0 - gamma;
    let psi_2 = digamma(2.0);
    assert!(
        (psi_2 - expected_2).abs() < 1e-12,
        "psi(2)={psi_2} expected={expected_2}"
    );

    // psi(3) = 1 + 1/2 - gamma
    let expected_3 = 1.5 - gamma;
    let psi_3 = digamma(3.0);
    assert!(
        (psi_3 - expected_3).abs() < 1e-12,
        "psi(3)={psi_3} expected={expected_3}"
    );

    // psi(4) = 1 + 1/2 + 1/3 - gamma
    let expected_4 = 1.0 + 0.5 + 1.0 / 3.0 - gamma;
    let psi_4 = digamma(4.0);
    assert!(
        (psi_4 - expected_4).abs() < 1e-12,
        "psi(4)={psi_4} expected={expected_4}"
    );
}
2374
// Tsallis divergence converges to KL as alpha -> 1 from either side.
#[test]
fn tsallis_approaches_kl_as_alpha_to_one() {
    let p = [0.3, 0.7];
    let q = [0.5, 0.5];
    let tol = 1e-9;
    let kl = kl_divergence(&p, &q, tol).unwrap();
    // Tsallis(alpha) -> KL as alpha -> 1 (from both sides)
    let t099 = tsallis_divergence(&p, &q, 0.99, tol).unwrap();
    let t0999 = tsallis_divergence(&p, &q, 0.999, tol).unwrap();
    let t101 = tsallis_divergence(&p, &q, 1.01, tol).unwrap();
    let t1001 = tsallis_divergence(&p, &q, 1.001, tol).unwrap();
    assert!((t099 - kl).abs() < 0.01, "Tsallis(0.99)={t099}, KL={kl}");
    assert!(
        (t0999 - kl).abs() < 0.001,
        "Tsallis(0.999)={t0999}, KL={kl}"
    );
    assert!((t101 - kl).abs() < 0.01, "Tsallis(1.01)={t101}, KL={kl}");
    assert!(
        (t1001 - kl).abs() < 0.001,
        "Tsallis(1.001)={t1001}, KL={kl}"
    );
}

// Same limit from above alpha > 1 (complements renyi_approaches_kl_as_alpha_to_one).
#[test]
fn renyi_approaches_kl_from_above() {
    let p = [0.3, 0.7];
    let q = [0.5, 0.5];
    let tol = 1e-9;
    let kl = kl_divergence(&p, &q, tol).unwrap();
    let r101 = renyi_divergence(&p, &q, 1.01, tol).unwrap();
    let r1001 = renyi_divergence(&p, &q, 1.001, tol).unwrap();
    assert!((r101 - kl).abs() < 0.01, "Renyi(1.01)={r101}, KL={kl}");
    assert!((r1001 - kl).abs() < 0.001, "Renyi(1.001)={r1001}, KL={kl}");
}

// Closed-form link between Rényi at alpha=1/2 and the Bhattacharyya coefficient.
#[test]
fn renyi_at_half_equals_neg2_ln_bc() {
    // D_{1/2}^R(p || q) = -2 * ln(BC(p, q))
    let p = [0.2, 0.3, 0.5];
    let q = [0.4, 0.4, 0.2];
    let tol = 1e-9;
    let renyi_half = renyi_divergence(&p, &q, 0.5, tol).unwrap();
    let bc = bhattacharyya_coeff(&p, &q, tol).unwrap();
    let expected = -2.0 * bc.ln();
    assert!(
        (renyi_half - expected).abs() < 1e-10,
        "Renyi(0.5)={renyi_half}, -2*ln(BC)={expected}"
    );
}
2424
// Standard identity: H^2(p, q) = 1 - BC(p, q).
#[test]
fn hellinger_squared_equals_one_minus_bc() {
    let p = [0.1, 0.4, 0.5];
    let q = [0.3, 0.3, 0.4];
    let tol = 1e-9;
    let h2 = hellinger_squared(&p, &q, tol).unwrap();
    let bc = bhattacharyya_coeff(&p, &q, tol).unwrap();
    assert!(
        (h2 - (1.0 - bc)).abs() < 1e-12,
        "H^2={h2}, 1-BC={}",
        1.0 - bc
    );
}

#[test]
fn csiszar_hellinger_generator_matches_twice_hellinger_squared() {
    // f(t) = (sqrt(t) - 1)^2 gives sum(q_i*(sqrt(p_i/q_i)-1)^2)
    //       = sum((sqrt(p_i) - sqrt(q_i))^2) = 2*(1 - BC) = 2*H^2.
    let p = [0.2, 0.3, 0.5];
    let q = [0.4, 0.4, 0.2];
    let tol = 1e-9;
    let h2 = hellinger_squared(&p, &q, tol).unwrap();
    let cs = csiszar_f_divergence(&p, &q, |t| (t.sqrt() - 1.0).powi(2), tol).unwrap();
    assert!(
        (cs - 2.0 * h2).abs() < 1e-10,
        "Csiszar(Hellinger)={cs}, 2*H^2={}",
        2.0 * h2
    );
}

#[test]
fn csiszar_chi_squared_generator_is_nonneg() {
    // f(t) = (t - 1)^2 gives chi-squared divergence.
    let p = [0.2, 0.3, 0.5];
    let q = [0.4, 0.4, 0.2];
    let tol = 1e-9;
    let chi2 = csiszar_f_divergence(&p, &q, |t| (t - 1.0).powi(2), tol).unwrap();
    assert!(chi2 >= 0.0, "chi2={chi2}");
    // Chi-squared(p,p) = 0
    let chi2_self = csiszar_f_divergence(&p, &p, |t| (t - 1.0).powi(2), tol).unwrap();
    assert!(chi2_self.abs() < 1e-12, "chi2(p,p)={chi2_self}");
}
2467
// Smoke test: extreme but valid inputs must not produce NaN/Inf anywhere.
#[test]
fn near_boundary_inputs_no_nan() {
    // Distributions with entries near the f64 underflow threshold
    // (1e-300 is far below machine epsilon ~2.2e-16, stressing log/sqrt paths).
    let tiny = 1e-300;
    let p = [tiny, 1.0 - tiny];
    let q = [tiny * 2.0, 1.0 - tiny * 2.0];
    let tol = 1e-6;

    let kl = kl_divergence(&p, &q, tol).unwrap();
    assert!(kl.is_finite(), "kl={kl}");
    // Allow tiny negative from floating-point at extreme values.
    assert!(kl >= -1e-12, "kl negative: {kl}");

    let js = jensen_shannon_divergence(&p, &q, tol).unwrap();
    assert!(js.is_finite(), "js={js}");

    let h = hellinger(&p, &q, tol).unwrap();
    assert!(h.is_finite(), "hellinger={h}");

    let bc = bhattacharyya_coeff(&p, &q, tol).unwrap();
    assert!(bc.is_finite(), "bc={bc}");

    let ent = entropy_nats(&p, tol).unwrap();
    assert!(ent.is_finite(), "entropy={ent}");
}
2493
// Property tests: classical identities and monotonicity laws, run with a
// reduced case count since each case evaluates several divergences.
proptest! {
    #![proptest_config(ProptestConfig { cases: 64, .. ProptestConfig::default() })]

    // Jensen's inequality for Shannon entropy (concavity on the simplex).
    #[test]
    fn entropy_is_concave(
        p in simplex_vec(8),
        q in simplex_vec(8),
        lambda in 0.0f64..=1.0,
    ) {
        // H(lambda*p + (1-lambda)*q) >= lambda*H(p) + (1-lambda)*H(q)
        let mix: Vec<f64> = p.iter().zip(q.iter())
            .map(|(&pi, &qi)| lambda * pi + (1.0 - lambda) * qi)
            .collect();
        let h_mix = entropy_nats(&mix, 1e-6).unwrap();
        let h_p = entropy_nats(&p, 1e-6).unwrap();
        let h_q = entropy_nats(&q, 1e-6).unwrap();
        let rhs = lambda * h_p + (1.0 - lambda) * h_q;
        prop_assert!(h_mix + 1e-10 >= rhs, "h_mix={h_mix} rhs={rhs}");
    }

    #[test]
    fn renyi_monotone_in_alpha(
        p in simplex_vec_pos(8, 1e-6),
        q in simplex_vec_pos(8, 1e-6),
    ) {
        // Renyi divergence is non-decreasing in alpha for fixed p, q.
        let alphas = [0.1, 0.25, 0.5, 0.75, 0.99];
        let vals: Vec<f64> = alphas.iter()
            .map(|&a| renyi_divergence(&p, &q, a, 1e-6).unwrap())
            .collect();
        for i in 1..vals.len() {
            prop_assert!(
                vals[i] + 1e-9 >= vals[i - 1],
                "Renyi not monotone: D({})={} < D({})={}",
                alphas[i], vals[i], alphas[i - 1], vals[i - 1]
            );
        }
    }

    #[test]
    fn cross_entropy_decomposition(
        p in simplex_vec_pos(8, 1e-6),
        q in simplex_vec_pos(8, 1e-6),
    ) {
        // H(p, q) = H(p) + KL(p || q)
        let h_pq = cross_entropy_nats(&p, &q, 1e-6).unwrap();
        let h_p = entropy_nats(&p, 1e-6).unwrap();
        let kl = kl_divergence(&p, &q, 1e-6).unwrap();
        prop_assert!(
            (h_pq - (h_p + kl)).abs() < 1e-9,
            "H(p,q)={h_pq} != H(p)+KL={}", h_p + kl
        );
    }

    #[test]
    fn bhattacharyya_renyi_consistency(
        p in simplex_vec_pos(8, 1e-6),
        q in simplex_vec_pos(8, 1e-6),
    ) {
        // D_{1/2}^R(p || q) = -2 * ln(BC(p, q))
        let renyi_half = renyi_divergence(&p, &q, 0.5, 1e-6).unwrap();
        let bc = bhattacharyya_coeff(&p, &q, 1e-6).unwrap();
        let expected = -2.0 * bc.ln();
        prop_assert!(
            (renyi_half - expected).abs() < 1e-8,
            "Renyi(0.5)={renyi_half}, -2*ln(BC)={expected}"
        );
    }

    #[test]
    fn csiszar_hellinger_consistency(
        p in simplex_vec_pos(8, 1e-6),
        q in simplex_vec_pos(8, 1e-6),
    ) {
        // Csiszar with f(t)=(sqrt(t)-1)^2 = 2 * squared Hellinger
        let h2 = hellinger_squared(&p, &q, 1e-6).unwrap();
        let cs = csiszar_f_divergence(&p, &q, |t| (t.sqrt() - 1.0).powi(2), 1e-6).unwrap();
        prop_assert!(
            (cs - 2.0 * h2).abs() < 1e-8,
            "Csiszar(Hellinger)={cs}, 2*H^2={}", 2.0 * h2
        );
    }

    #[test]
    fn pinsker_tightness_for_nearby_distributions(
        p in simplex_vec_pos(8, 1e-6),
    ) {
        // For a distribution slightly perturbed from p, check Pinsker is non-vacuous.
        // q = 0.99*p + 0.01*uniform (epsilon-perturbation).
        let n = p.len();
        let q: Vec<f64> = p.iter().map(|&pi| 0.99 * pi + 0.01 / n as f64).collect();
        let kl = kl_divergence(&p, &q, 1e-6).unwrap();
        let d1: f64 = p.iter().zip(q.iter()).map(|(&a, &b)| (a - b).abs()).sum();
        let pinsker_rhs = 0.5 * d1 * d1;
        // Pinsker: KL >= 0.5 * L1^2
        prop_assert!(kl + 1e-12 >= pinsker_rhs, "kl={kl} pinsker_rhs={pinsker_rhs}");
        // Non-vacuity: for nearby distributions, KL / (0.5*L1^2) should be O(1),
        // not astronomically large (would indicate the bound is useful).
        if pinsker_rhs > 1e-20 {
            let ratio = kl / pinsker_rhs;
            prop_assert!(ratio < 1000.0, "Pinsker ratio too large: {ratio}");
        }
    }

    // Total variation is a metric, hence satisfies the triangle inequality.
    #[test]
    fn total_variation_satisfies_triangle(
        p in simplex_vec(8),
        q in simplex_vec(8),
        r in simplex_vec(8),
    ) {
        let tv_pq = total_variation(&p, &q, 1e-6).unwrap();
        let tv_qr = total_variation(&q, &r, 1e-6).unwrap();
        let tv_pr = total_variation(&p, &r, 1e-6).unwrap();
        prop_assert!(tv_pr <= tv_pq + tv_qr + 1e-10);
    }

    // chi² is the Csiszár f-divergence with generator f(t) = (t-1)^2.
    #[test]
    fn chi_squared_matches_csiszar(
        p in simplex_vec_pos(8, 1e-6),
        q in simplex_vec_pos(8, 1e-6),
    ) {
        let chi2 = chi_squared_divergence(&p, &q, 1e-6).unwrap();
        let cs = csiszar_f_divergence(&p, &q, |t| (t - 1.0).powi(2), 1e-6).unwrap();
        prop_assert!(
            (chi2 - cs).abs() < 1e-8,
            "chi2={chi2}, csiszar={cs}"
        );
    }

    #[test]
    fn renyi_entropy_monotone_in_alpha(
        p in simplex_vec_pos(8, 1e-6),
    ) {
        // H_alpha is non-increasing in alpha.
        let alphas = [0.1, 0.25, 0.5, 0.75, 0.99];
        let vals: Vec<f64> = alphas.iter()
            .map(|&a| renyi_entropy(&p, a, 1e-6).unwrap())
            .collect();
        for i in 1..vals.len() {
            prop_assert!(
                vals[i] <= vals[i - 1] + 1e-9,
                "H_alpha not monotone: H({})={} > H({})={}",
                alphas[i], vals[i], alphas[i - 1], vals[i - 1]
            );
        }
    }
}
2641
// --- New function unit tests ---

// With weight 0 on p, the mixture m collapses to q and the divergence vanishes.
#[test]
fn weighted_js_at_extreme_weights() {
    let p = [0.3, 0.7];
    let q = [0.5, 0.5];
    // pi1 = 0 means m = q, KL(q || q) = 0
    let js0 = jensen_shannon_weighted(&p, &q, 0.0, TOL).unwrap();
    assert!(js0.abs() < 1e-12, "JS(pi=0)={js0}");
}

#[test]
fn conditional_entropy_chain_rule() {
    // H(X|Y) = H(X,Y) - H(Y)
    let p_xy = [0.2, 0.1, 0.3, 0.4]; // 2x2
    let h_xy = entropy_nats(&p_xy, TOL).unwrap();
    // Marginal over Y: sum each column of the row-major 2x2 joint.
    let p_y = [p_xy[0] + p_xy[2], p_xy[1] + p_xy[3]];
    let h_y = entropy_nats(&p_y, TOL).unwrap();
    let h_x_given_y = conditional_entropy(&p_xy, 2, 2, TOL).unwrap();
    assert!(
        (h_x_given_y - (h_xy - h_y)).abs() < 1e-10,
        "H(X|Y)={h_x_given_y}, H(X,Y)-H(Y)={}",
        h_xy - h_y
    );
}

// Conditioning never makes entropy negative.
#[test]
fn conditional_entropy_nonnegative() {
    let p_xy = [0.1, 0.2, 0.3, 0.4];
    let h = conditional_entropy(&p_xy, 2, 2, TOL).unwrap();
    assert!(h >= -1e-12, "H(X|Y) negative: {h}");
}

#[test]
fn nmi_bounds() {
    // NMI for a general joint: should be in [0, 1].
    let p_xy = [0.1, 0.2, 0.3, 0.4];
    let nmi = normalized_mutual_information(&p_xy, 2, 2, TOL).unwrap();
    assert!((-1e-12..=1.0 + 1e-12).contains(&nmi), "nmi={nmi}");
}

#[test]
fn total_variation_self_is_zero() {
    let p = [0.3, 0.7];
    assert!(total_variation(&p, &p, TOL).unwrap().abs() < 1e-15);
}

// TV attains its maximum value 1 on distributions with disjoint support.
#[test]
fn total_variation_disjoint_is_one() {
    let a = [1.0, 0.0];
    let b = [0.0, 1.0];
    assert!((total_variation(&a, &b, TOL).unwrap() - 1.0).abs() < 1e-12);
}

#[test]
fn chi_squared_self_is_zero() {
    let p = [0.3, 0.7];
    assert!(chi_squared_divergence(&p, &p, TOL).unwrap().abs() < 1e-15);
}
2701
// Standard bound linking KL and chi² divergences.
#[test]
fn chi_squared_upper_bounds_kl() {
    // KL(p||q) <= ln(1 + chi2(p||q))
    let p = [0.2, 0.3, 0.5];
    let q = [0.4, 0.4, 0.2];
    let kl = kl_divergence(&p, &q, TOL).unwrap();
    let chi2 = chi_squared_divergence(&p, &q, TOL).unwrap();
    assert!(
        kl <= (1.0 + chi2).ln() + 1e-10,
        "kl={kl} > ln(1+chi2)={}",
        (1.0 + chi2).ln()
    );
}

// Every Rényi entropy of the uniform distribution equals ln(n), independent of alpha.
#[test]
fn renyi_entropy_uniform_is_ln_n() {
    let p = [0.25, 0.25, 0.25, 0.25];
    for alpha in [0.5, 2.0, 3.0, 10.0] {
        let h = renyi_entropy(&p, alpha, TOL).unwrap();
        let expected = 4.0_f64.ln();
        assert!(
            (h - expected).abs() < 1e-12,
            "H_{alpha}(uniform) = {h}, expected {expected}"
        );
    }
}

// A point mass has zero Tsallis entropy for any alpha.
#[test]
fn tsallis_entropy_delta_is_zero() {
    let delta = [1.0, 0.0, 0.0];
    for alpha in [0.5, 2.0, 3.0] {
        let s = tsallis_entropy(&delta, alpha, TOL).unwrap();
        assert!(s.abs() < 1e-12, "Tsallis({alpha}) of delta = {s}");
    }
}

// alpha=2 gives the collision entropy: -ln of the collision probability.
#[test]
fn renyi_entropy_collision() {
    // H_2(p) = -ln(sum(p_i^2))
    let p = [0.3, 0.7];
    let h2 = renyi_entropy(&p, 2.0, TOL).unwrap();
    let expected = -(0.3_f64.powi(2) + 0.7_f64.powi(2)).ln();
    assert!(
        (h2 - expected).abs() < 1e-12,
        "H_2={h2} expected={expected}"
    );
}

#[test]
fn neg_entropy_bregman_matches_kl_on_simplex() {
    // For normalized p, q: B_NegEntropy(p, q) = KL(p || q).
    let p = [0.2, 0.3, 0.5];
    let q = [0.4, 0.4, 0.2];
    let kl = kl_divergence(&p, &q, TOL).unwrap();
    let gen = NegEntropy;
    let breg = bregman_divergence(&gen, &p, &q).unwrap();
    assert!(
        (breg - kl).abs() < 1e-10,
        "Bregman(NegEntropy)={breg}, KL={kl}"
    );
}

#[test]
fn neg_entropy_bregman_self_is_zero() {
    let p = [0.3, 0.7];
    let gen = NegEntropy;
    let breg = bregman_divergence(&gen, &p, &p).unwrap();
    assert!(breg.abs() < 1e-14, "Bregman(p,p)={breg}");
}
2771
// --- Streaming log-sum-exp tests ---

// The streaming (iterator) variant must agree with the slice variant.
#[test]
fn log_sum_exp_iter_matches_slice() {
    let values = [1.0, 2.0, 3.0, -1.0, 0.5];
    let lse_slice = log_sum_exp(&values);
    let lse_iter = log_sum_exp_iter(values.iter().copied());
    assert!(
        (lse_slice - lse_iter).abs() < 1e-12,
        "slice={lse_slice} iter={lse_iter}"
    );
}

// An empty sum is 0, so its log is -inf by convention.
#[test]
fn log_sum_exp_iter_empty() {
    assert_eq!(log_sum_exp_iter(std::iter::empty()), f64::NEG_INFINITY);
}

// A single element passes through exactly (max-shift makes it exact).
#[test]
fn log_sum_exp_iter_single() {
    assert_eq!(log_sum_exp_iter(std::iter::once(42.0)), 42.0);
}

// Numerical stability: naive exp(1000) would overflow to +inf.
#[test]
fn log_sum_exp_iter_large_values() {
    // Same stability test as log_sum_exp: dominated term.
    let lse = log_sum_exp_iter([1000.0, 0.0].iter().copied());
    assert!((lse - 1000.0).abs() < 1e-10);
}
2801
// --- Data processing inequality for discrete MI ---

#[test]
fn data_processing_inequality_mi() {
    // Apply a deterministic coarse-graining (Markov kernel) to Y.
    // MI(X; f(Y)) <= MI(X; Y).
    let p_xy = [0.3, 0.1, 0.05, 0.05, 0.1, 0.2, 0.05, 0.15]; // 2x4
    let n_x = 2;
    let n_y = 4;
    let mi_full = mutual_information(&p_xy, n_x, n_y, TOL).unwrap();

    // Coarse-grain Y: merge bins 0+1 and 2+3 -> 2x2.
    let mut p_coarse = [0.0; 4]; // 2x2
    for i in 0..n_x {
        p_coarse[i * 2] = p_xy[i * n_y] + p_xy[i * n_y + 1];
        p_coarse[i * 2 + 1] = p_xy[i * n_y + 2] + p_xy[i * n_y + 3];
    }
    let mi_coarse = mutual_information(&p_coarse, n_x, 2, TOL).unwrap();

    assert!(
        mi_coarse <= mi_full + 1e-10,
        "DPI violated: MI(coarse)={mi_coarse} > MI(full)={mi_full}"
    );
}

// --- Weighted JS additional tests ---

// The weighted JS divergence is bounded by the entropy of its weight vector.
#[test]
fn weighted_js_bounded_by_entropy_of_weights() {
    // JS_pi(p, q) <= H(pi) = -pi1*ln(pi1) - pi2*ln(pi2)
    let p = [0.1, 0.9];
    let q = [0.9, 0.1];
    let pi1 = 0.3;
    let jsw = jensen_shannon_weighted(&p, &q, pi1, TOL).unwrap();
    let pi2 = 1.0 - pi1;
    let h_pi = -(pi1 * pi1.ln() + pi2 * pi2.ln());
    assert!(jsw <= h_pi + 1e-10, "JS_pi={jsw} > H(pi)={h_pi}");
}

// --- Renyi/Tsallis entropy approach Shannon ---

// H_alpha -> H (Shannon) as alpha -> 1, approached from both sides.
#[test]
fn renyi_entropy_approaches_shannon_as_alpha_to_one() {
    let p = [0.2, 0.3, 0.5];
    let h_shannon = entropy_nats(&p, TOL).unwrap();
    let h_099 = renyi_entropy(&p, 0.99, TOL).unwrap();
    let h_0999 = renyi_entropy(&p, 0.999, TOL).unwrap();
    let h_101 = renyi_entropy(&p, 1.01, TOL).unwrap();
    assert!((h_099 - h_shannon).abs() < 0.01);
    assert!((h_0999 - h_shannon).abs() < 0.001);
    assert!((h_101 - h_shannon).abs() < 0.01);
}

// S_alpha -> H (Shannon) as alpha -> 1.
#[test]
fn tsallis_entropy_approaches_shannon_as_alpha_to_one() {
    let p = [0.2, 0.3, 0.5];
    let h_shannon = entropy_nats(&p, TOL).unwrap();
    let s_099 = tsallis_entropy(&p, 0.99, TOL).unwrap();
    let s_0999 = tsallis_entropy(&p, 0.999, TOL).unwrap();
    assert!((s_099 - h_shannon).abs() < 0.01);
    assert!((s_0999 - h_shannon).abs() < 0.001);
}
2864
2865 // --- New tests for 0.2.0 ---
2866
#[test]
fn bhattacharyya_precision_near_identical() {
    // Distributions that differ only at the 1e-15 level: the
    // Bhattacharyya coefficient must sit essentially at 1, and the
    // squared Hellinger distance must be tiny but well-defined.
    let nearly_uniform = [0.5 + 1e-15, 0.5 - 1e-15];
    let uniform = [0.5, 0.5];

    let bc = bhattacharyya_coeff(&nearly_uniform, &uniform, 1e-6).unwrap();
    assert!(
        (bc - 1.0).abs() < 1e-14,
        "BC should be very close to 1.0: {bc}"
    );

    let h2 = hellinger_squared(&nearly_uniform, &uniform, 1e-6).unwrap();
    // h2 should be tiny but not exactly zero (numerics permitting).
    assert!(h2 < 1e-14, "h2 should be tiny: {h2}");
    assert!(h2.is_finite(), "h2 should be finite");
}
2881
#[test]
fn renyi_alpha_sweep_continuity() {
    // Sweep alpha across [0.5, 2.0] in steps of 0.1 and require that
    // neither the Rényi nor the Tsallis entropy jumps between adjacent
    // grid points (continuity of both families in alpha).
    //
    // The grid is driven by an integer counter so each alpha is the
    // exact rounding of k/10. The previous `alpha += 0.1` accumulation
    // drifted (0.6 + 5*0.1 != 1.1 exactly), making both the sampled
    // points and the loop-termination condition depend on rounding
    // error and requiring a `+ 1e-9` slop on the bound.
    //
    // Note: k = 10 yields alpha == 1.0 exactly, where both entropies
    // are expected to take their Shannon-limit value (the library
    // already accepts alpha = 1.0 for renyi_divergence); this probes
    // continuity through the singular point itself.
    let p = [0.2, 0.3, 0.5];
    let tol = 1e-9;
    let mut prev_renyi = renyi_entropy(&p, 0.5, tol).unwrap();
    let mut prev_tsallis = tsallis_entropy(&p, 0.5, tol).unwrap();
    for k in 6..=20 {
        let alpha = f64::from(k) / 10.0;
        let r = renyi_entropy(&p, alpha, tol).unwrap();
        let t = tsallis_entropy(&p, alpha, tol).unwrap();
        let jump_r = (r - prev_renyi).abs();
        let jump_t = (t - prev_tsallis).abs();
        assert!(
            jump_r < 0.5,
            "Renyi discontinuity at alpha={alpha}: jump={jump_r}"
        );
        assert!(
            jump_t < 0.5,
            "Tsallis discontinuity at alpha={alpha}: jump={jump_t}"
        );
        prev_renyi = r;
        prev_tsallis = t;
    }
}
2907
#[test]
fn ksg_ties_finite() {
    // Integer-valued samples create exact ties in the kNN distances;
    // both KSG estimator variants must still return a finite value.
    let xs: Vec<Vec<f64>> = (0..50).map(|i| vec![f64::from(i % 5)]).collect();
    let ys: Vec<Vec<f64>> = (0..50).map(|i| vec![f64::from(i % 3)]).collect();

    let mi1 = mutual_information_ksg(&xs, &ys, 3, KsgVariant::Alg1).unwrap();
    assert!(
        mi1.is_finite(),
        "KSG Alg1 with ties returned NaN/Inf: {mi1}"
    );

    let mi2 = mutual_information_ksg(&xs, &ys, 3, KsgVariant::Alg2).unwrap();
    assert!(
        mi2.is_finite(),
        "KSG Alg2 with ties returned NaN/Inf: {mi2}"
    );
}
2924
2925 proptest! {
2926 #![proptest_config(ProptestConfig { cases: 64, .. ProptestConfig::default() })]
2927
2928 #[test]
2929 fn bregman_nonnegative(
2930 p in simplex_vec_pos(8, 1e-6),
2931 q in simplex_vec_pos(8, 1e-6),
2932 ) {
2933 let gen = NegEntropy;
2934 let b = bregman_divergence(&gen, &p, &q).unwrap();
2935 prop_assert!(b >= -1e-12, "Bregman(NegEntropy) negative: {b}");
2936 }
2937
2938 #[test]
2939 fn renyi_divergence_alpha1_equals_kl(
2940 p in simplex_vec_pos(8, 1e-6),
2941 q in simplex_vec_pos(8, 1e-6),
2942 ) {
2943 let tol = 1e-6;
2944 let kl = kl_divergence(&p, &q, tol).unwrap();
2945 let r1 = renyi_divergence(&p, &q, 1.0, tol).unwrap();
2946 prop_assert!(
2947 (r1 - kl).abs() < 1e-9,
2948 "renyi(alpha=1)={r1} != kl={kl}"
2949 );
2950 }
2951 }
2952
#[test]
fn pmi_impossible_input_errors() {
    // A joint probability p(x,y) > 0 is inconsistent with either
    // marginal being zero; pmi must reject both configurations.
    for (p_xy, p_x, p_y) in [(0.1, 0.0, 0.5), (0.1, 0.5, 0.0)] {
        assert!(pmi(p_xy, p_x, p_y).is_err());
    }
}
2960}