ndarray_stats/entropy.rs

//! Information theory (e.g. entropy, KL divergence, etc.).
use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
use ndarray::{Array, ArrayRef, Dimension, Zip};
use num_traits::Float;

/// Extension trait for `ndarray` providing methods
/// to compute information theory quantities
/// (e.g. entropy, Kullback–Leibler divergence, etc.).
pub trait EntropyExt<A, D>
where
    D: Dimension,
{
    /// Computes the [entropy] *S* of the array values, defined as
    ///
    /// ```text
    ///       n
    /// S = - ∑ xᵢ ln(xᵢ)
    ///      i=1
    /// ```
    ///
    /// If the array is empty, `Err(EmptyInput)` is returned.
    ///
    /// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
    ///
    /// ## Remarks
    ///
    /// The entropy is a measure used in [Information Theory]
    /// to describe a probability distribution: it only makes sense
    /// when the array values sum to 1, with each entry between
    /// 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the entropy, to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the array is already normalised).
    ///
    /// By definition, *xᵢ ln(xᵢ)* is set to 0 if *xᵢ* is 0.
    ///
    /// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
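    ///
    /// ## Example
    ///
    /// A minimal usage sketch; the `use ndarray_stats::EntropyExt;` import path assumes
    /// the trait is re-exported at the crate root.
    ///
    /// ```
    /// use ndarray::array;
    /// use ndarray_stats::EntropyExt; // assumed re-export path
    ///
    /// // Uniform distribution over two outcomes: S = ln(2).
    /// let uniform = array![0.5, 0.5];
    /// let entropy = uniform.entropy().unwrap();
    /// assert!((entropy - 2_f64.ln()).abs() < 1e-12);
    /// ```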
    fn entropy(&self) -> Result<A, EmptyInput>
    where
        A: Float;

    /// Computes the [Kullback-Leibler divergence] *Dₖₗ(p,q)* between two arrays,
    /// where `self`=*p*.
    ///
    /// The Kullback-Leibler divergence is defined as:
    ///
    /// ```text
    ///              n
    /// Dₖₗ(p,q) = - ∑ pᵢ ln(qᵢ/pᵢ)
    ///             i=1
    /// ```
    ///
    /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
    /// If the array shapes are not identical,
    /// `Err(MultiInputError::ShapeMismatch)` is returned.
    ///
    /// **Panics** if, for any pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
    /// *ln(qᵢ/pᵢ)* panics for `A`.
    ///
    /// ## Remarks
    ///
    /// The Kullback-Leibler divergence is a measure used in [Information Theory]
    /// to describe the relationship between two probability distributions: it only makes sense
    /// when each array sums to 1, with entries between 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the divergence, to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the arrays are already normalised).
    ///
    /// By definition, *pᵢ ln(qᵢ/pᵢ)* is set to 0 if *pᵢ* is 0.
    ///
    /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
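    ///
    /// ## Example
    ///
    /// A minimal usage sketch; the `use ndarray_stats::EntropyExt;` import path assumes
    /// the trait is re-exported at the crate root.
    ///
    /// ```
    /// use ndarray::array;
    /// use ndarray_stats::EntropyExt; // assumed re-export path
    ///
    /// let p = array![0.5, 0.5];
    /// let q = array![0.75, 0.25];
    /// // Dₖₗ(p,q) = -(0.5 ln(0.75/0.5) + 0.5 ln(0.25/0.5))
    /// let expected = -(0.5 * (0.75_f64 / 0.5).ln() + 0.5 * (0.25_f64 / 0.5).ln());
    /// let kl = p.kl_divergence(&q).unwrap();
    /// assert!((kl - expected).abs() < 1e-12);
    /// ```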
    fn kl_divergence(&self, q: &ArrayRef<A, D>) -> Result<A, MultiInputError>
    where
        A: Float;

    /// Computes the [cross entropy] *H(p,q)* between two arrays,
    /// where `self`=*p*.
    ///
    /// The cross entropy is defined as:
    ///
    /// ```text
    ///            n
    /// H(p,q) = - ∑ pᵢ ln(qᵢ)
    ///           i=1
    /// ```
    ///
    /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
    /// If the array shapes are not identical,
    /// `Err(MultiInputError::ShapeMismatch)` is returned.
    ///
    /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
    /// panics for `A`.
    ///
    /// ## Remarks
    ///
    /// The cross entropy is a measure used in [Information Theory]
    /// to describe the relationship between two probability distributions: it only makes sense
    /// when each array sums to 1, with entries between 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the cross entropy, to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the arrays are already normalised).
    ///
    /// The cross entropy is often used as an objective/loss function in
    /// [optimization problems], including [machine learning].
    ///
    /// By definition, *pᵢ ln(qᵢ)* is set to 0 if *pᵢ* is 0.
    ///
    /// [cross entropy]: https://en.wikipedia.org/wiki/Cross-entropy
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
    /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
    /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
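    ///
    /// ## Example
    ///
    /// A minimal usage sketch; the `use ndarray_stats::EntropyExt;` import path assumes
    /// the trait is re-exported at the crate root. Note that *H(p,q) = S(p) + Dₖₗ(p,q)*.
    ///
    /// ```
    /// use ndarray::array;
    /// use ndarray_stats::EntropyExt; // assumed re-export path
    ///
    /// let p = array![0.5, 0.5];
    /// let q = array![0.75, 0.25];
    /// // H(p,q) = -(0.5 ln(0.75) + 0.5 ln(0.25))
    /// let expected = -(0.5 * 0.75_f64.ln() + 0.5 * 0.25_f64.ln());
    /// let h = p.cross_entropy(&q).unwrap();
    /// assert!((h - expected).abs() < 1e-12);
    /// ```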
    fn cross_entropy(&self, q: &ArrayRef<A, D>) -> Result<A, MultiInputError>
    where
        A: Float;

    private_decl! {}
}

impl<A, D> EntropyExt<A, D> for ArrayRef<A, D>
where
    D: Dimension,
{
    fn entropy(&self) -> Result<A, EmptyInput>
    where
        A: Float,
    {
        if self.is_empty() {
            Err(EmptyInput)
        } else {
            let entropy = -self
                .mapv(|x| {
                    // By definition, xᵢ ln(xᵢ) is taken to be 0 when xᵢ = 0.
                    if x == A::zero() {
                        A::zero()
                    } else {
                        x * x.ln()
                    }
                })
                .sum();
            Ok(entropy)
        }
    }

    fn kl_divergence(&self, q: &ArrayRef<A, D>) -> Result<A, MultiInputError>
    where
        A: Float,
    {
        if self.is_empty() {
            return Err(MultiInputError::EmptyInput);
        }
        if self.shape() != q.shape() {
            return Err(ShapeMismatch {
                first_shape: self.shape().to_vec(),
                second_shape: q.shape().to_vec(),
            }
            .into());
        }

        // Accumulate pᵢ ln(qᵢ/pᵢ) element-wise; pᵢ = 0 contributes 0 by definition.
        let mut temp = Array::zeros(self.raw_dim());
        Zip::from(&mut temp)
            .and(self)
            .and(q)
            .for_each(|result, &p, &q| {
                *result = {
                    if p == A::zero() {
                        A::zero()
                    } else {
                        p * (q / p).ln()
                    }
                }
            });
        let kl_divergence = -temp.sum();
        Ok(kl_divergence)
    }

    fn cross_entropy(&self, q: &ArrayRef<A, D>) -> Result<A, MultiInputError>
    where
        A: Float,
    {
        if self.is_empty() {
            return Err(MultiInputError::EmptyInput);
        }
        if self.shape() != q.shape() {
            return Err(ShapeMismatch {
                first_shape: self.shape().to_vec(),
                second_shape: q.shape().to_vec(),
            }
            .into());
        }

        // Accumulate pᵢ ln(qᵢ) element-wise; pᵢ = 0 contributes 0 by definition.
        let mut temp = Array::zeros(self.raw_dim());
        Zip::from(&mut temp)
            .and(self)
            .and(q)
            .for_each(|result, &p, &q| {
                *result = {
                    if p == A::zero() {
                        A::zero()
                    } else {
                        p * q.ln()
                    }
                }
            });
        let cross_entropy = -temp.sum();
        Ok(cross_entropy)
    }

    private_impl! {}
}

#[cfg(test)]
mod tests {
    use super::EntropyExt;
    use crate::errors::{EmptyInput, MultiInputError};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1};
    use noisy_float::types::n64;
    use std::f64;

    #[test]
    fn test_entropy_with_nan_values() {
        let a = array![f64::NAN, 1.];
        assert!(a.entropy().unwrap().is_nan());
    }

    #[test]
    fn test_entropy_with_empty_array_of_floats() {
        let a: Array1<f64> = array![];
        assert_eq!(a.entropy(), Err(EmptyInput));
    }

    #[test]
    fn test_entropy_with_array_of_floats() {
        // Array of probability values - normalized and positive.
        let a: Array1<f64> = array![
            0.03602474, 0.01900344, 0.03510129, 0.03414964, 0.00525311, 0.03368976, 0.00065396,
            0.02906146, 0.00063687, 0.01597306, 0.00787625, 0.00208243, 0.01450896, 0.01803418,
            0.02055336, 0.03029759, 0.03323628, 0.01218822, 0.0001873, 0.01734179, 0.03521668,
            0.02564429, 0.02421992, 0.03540229, 0.03497635, 0.03582331, 0.026558, 0.02460495,
            0.02437716, 0.01212838, 0.00058464, 0.00335236, 0.02146745, 0.00930306, 0.01821588,
            0.02381928, 0.02055073, 0.01483779, 0.02284741, 0.02251385, 0.00976694, 0.02864634,
            0.00802828, 0.03464088, 0.03557152, 0.01398894, 0.01831756, 0.0227171, 0.00736204,
            0.01866295,
        ];
        // Computed using scipy.stats.entropy
        let expected_entropy = 3.721606155686918;

        assert_abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = 1e-6);
    }

    #[test]
    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> {
        let a = array![f64::NAN, 1.];
        let b = array![2., 1.];
        assert!(a.cross_entropy(&b)?.is_nan());
        assert!(b.cross_entropy(&a)?.is_nan());
        assert!(a.kl_divergence(&b)?.is_nan());
        assert!(b.kl_divergence(&a)?.is_nan());
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_same_n_dimension_but_different_n_elements() {
        let p = array![f64::NAN, 1.];
        let q = array![2., 1., 5.];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_different_shape_but_same_n_elements() {
        // p: 3x2, 6 elements
        let p = array![[f64::NAN, 1.], [6., 7.], [10., 20.]];
        // q: 2x3, 6 elements
        let q = array![[2., 1., 5.], [1., 1., 7.]];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_empty_array_of_floats() {
        let p: Array1<f64> = array![];
        let q: Array1<f64> = array![];
        assert!(p.cross_entropy(&q).unwrap_err().is_empty_input());
        assert!(p.kl_divergence(&q).unwrap_err().is_empty_input());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> {
        let p = array![1.];
        let q = array![-1.];
        let cross_entropy: f64 = p.cross_entropy(&q)?;
        let kl_divergence: f64 = p.kl_divergence(&q)?;
        assert!(cross_entropy.is_nan());
        assert!(kl_divergence.is_nan());
        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_cross_entropy_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.cross_entropy(&q);
    }

    #[test]
    #[should_panic]
    fn test_kl_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.kl_divergence(&q);
    }

    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> {
        let p = array![0., 0.];
        let q = array![0., 0.5];
        assert_eq!(p.cross_entropy(&q)?, 0.);
        assert_eq!(p.kl_divergence(&q)?, 0.);
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
    ) -> Result<(), MultiInputError> {
        let p = array![0.5, 0.5];
        let mut q = array![0.5, 0.];
        assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY);
        assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY);
        Ok(())
    }

    #[test]
    fn test_cross_entropy() -> Result<(), MultiInputError> {
        // Arrays of probability values - normalized and positive.
        let p: Array1<f64> = array![
            0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
            0.0471258, 0.05594036, 0.01630048, 0.07085162, 0.05365855, 0.01959158, 0.05020174,
            0.03801479, 0.00092234, 0.08515856, 0.00580683, 0.0156542, 0.0860375, 0.0724246,
            0.00727477, 0.01004402, 0.01854399, 0.03504082,
        ];
        let q: Array1<f64> = array![
            0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, 0.05604812,
            0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, 0.0625685, 0.07381292,
            0.05489067, 0.01385491, 0.03639174, 0.00511611, 0.05700415, 0.05183825, 0.06703064,
            0.01813342, 0.0007763, 0.0735472, 0.05857833,
        ];
        // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
        let expected_cross_entropy = 3.385347705020779;

        assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6);
        Ok(())
    }

    #[test]
    fn test_kl() -> Result<(), MultiInputError> {
        // Arrays of probability values - normalized and positive.
        let p: Array1<f64> = array![
            0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
            0.02213802, 0.02745017, 0.02163975, 0.0324602, 0.03622766, 0.00782343, 0.00222498,
            0.03028156, 0.02346124, 0.00071105, 0.00794496, 0.0127609, 0.02899124, 0.01281487,
            0.0230803, 0.01531864, 0.00518158, 0.02233383, 0.0220279, 0.03196097, 0.03710063,
            0.01817856, 0.03524661, 0.02902393, 0.00853364, 0.01255615, 0.03556958, 0.00400151,
            0.01335932, 0.01864965, 0.02371322, 0.02026543, 0.0035375, 0.01988341, 0.02621831,
            0.03564644, 0.01389121, 0.03151622, 0.03195532, 0.00717521, 0.03547256, 0.00371394,
            0.01108706,
        ];
        let q: Array1<f64> = array![
            0.02038386, 0.03143914, 0.02630206, 0.0171595, 0.0067072, 0.00911324, 0.02635717,
            0.01269113, 0.0302361, 0.02243133, 0.01902902, 0.01297185, 0.02118908, 0.03309548,
            0.01266687, 0.0184529, 0.01830936, 0.03430437, 0.02898924, 0.02238251, 0.0139771,
            0.01879774, 0.02396583, 0.03019978, 0.01421278, 0.02078981, 0.03542451, 0.02887438,
            0.01261783, 0.01014241, 0.03263407, 0.0095969, 0.01923903, 0.0051315, 0.00924686,
            0.00148845, 0.00341391, 0.01480373, 0.01920798, 0.03519871, 0.03315135, 0.02099325,
            0.03251755, 0.00337555, 0.03432165, 0.01763753, 0.02038337, 0.01923023, 0.01438769,
            0.02082707,
        ];
        // Computed using scipy.stats.entropy(p, q)
        let expected_kl = 0.3555862567800096;

        assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6);
        Ok(())
    }
}