1use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
3use ndarray::{Array, ArrayRef, Dimension, Zip};
4use num_traits::Float;
5
/// An extension trait for `ArrayRef` providing information-theoretic
/// quantities for arrays of probability values.
///
/// All quantities use the natural logarithm (results are in nats), and the
/// elements are treated as probabilities — the trait does not verify that
/// they are non-negative or sum to one.
pub trait EntropyExt<A, D>
where
    D: Dimension,
{
    /// Computes the Shannon entropy, `-Σ x·ln(x)`, of the array.
    ///
    /// Elements equal to zero contribute `0` (the `0·ln(0) = 0` convention).
    ///
    /// Returns `Err(EmptyInput)` if the array is empty.
    fn entropy(&self) -> Result<A, EmptyInput>
    where
        A: Float;

    /// Computes the Kullback–Leibler divergence `Σ p·ln(p / q)` of `self`
    /// (as `p`) from `q`.
    ///
    /// Elements of `p` equal to zero contribute `0`.
    ///
    /// Returns an error if either array is empty or if their shapes differ.
    fn kl_divergence(&self, q: &ArrayRef<A, D>) -> Result<A, MultiInputError>
    where
        A: Float;

    /// Computes the cross entropy `-Σ p·ln(q)` between `self` (as `p`) and `q`.
    ///
    /// Elements of `p` equal to zero contribute `0`.
    ///
    /// Returns an error if either array is empty or if their shapes differ.
    fn cross_entropy(&self, q: &ArrayRef<A, D>) -> Result<A, MultiInputError>
    where
        A: Float;

    private_decl! {}
}
123
124impl<A, D> EntropyExt<A, D> for ArrayRef<A, D>
125where
126 D: Dimension,
127{
128 fn entropy(&self) -> Result<A, EmptyInput>
129 where
130 A: Float,
131 {
132 if self.is_empty() {
133 Err(EmptyInput)
134 } else {
135 let entropy = -self
136 .mapv(|x| {
137 if x == A::zero() {
138 A::zero()
139 } else {
140 x * x.ln()
141 }
142 })
143 .sum();
144 Ok(entropy)
145 }
146 }
147
148 fn kl_divergence(&self, q: &ArrayRef<A, D>) -> Result<A, MultiInputError>
149 where
150 A: Float,
151 {
152 if self.is_empty() {
153 return Err(MultiInputError::EmptyInput);
154 }
155 if self.shape() != q.shape() {
156 return Err(ShapeMismatch {
157 first_shape: self.shape().to_vec(),
158 second_shape: q.shape().to_vec(),
159 }
160 .into());
161 }
162
163 let mut temp = Array::zeros(self.raw_dim());
164 Zip::from(&mut temp)
165 .and(self)
166 .and(q)
167 .for_each(|result, &p, &q| {
168 *result = {
169 if p == A::zero() {
170 A::zero()
171 } else {
172 p * (q / p).ln()
173 }
174 }
175 });
176 let kl_divergence = -temp.sum();
177 Ok(kl_divergence)
178 }
179
180 fn cross_entropy(&self, q: &ArrayRef<A, D>) -> Result<A, MultiInputError>
181 where
182 A: Float,
183 {
184 if self.is_empty() {
185 return Err(MultiInputError::EmptyInput);
186 }
187 if self.shape() != q.shape() {
188 return Err(ShapeMismatch {
189 first_shape: self.shape().to_vec(),
190 second_shape: q.shape().to_vec(),
191 }
192 .into());
193 }
194
195 let mut temp = Array::zeros(self.raw_dim());
196 Zip::from(&mut temp)
197 .and(self)
198 .and(q)
199 .for_each(|result, &p, &q| {
200 *result = {
201 if p == A::zero() {
202 A::zero()
203 } else {
204 p * q.ln()
205 }
206 }
207 });
208 let cross_entropy = -temp.sum();
209 Ok(cross_entropy)
210 }
211
212 private_impl! {}
213}
214
#[cfg(test)]
mod tests {
    //! Unit tests for `EntropyExt`: NaN propagation, empty-input and
    //! shape-mismatch errors, the 0·ln(0) = 0 convention, and numerical
    //! values against precomputed reference results.
    use super::EntropyExt;
    use crate::errors::{EmptyInput, MultiInputError};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1};
    use noisy_float::types::n64;
    use std::f64;

    // A NaN element must propagate into the entropy result.
    #[test]
    fn test_entropy_with_nan_values() {
        let a = array![f64::NAN, 1.];
        assert!(a.entropy().unwrap().is_nan());
    }

    // An empty array yields `Err(EmptyInput)`.
    #[test]
    fn test_entropy_with_empty_array_of_floats() {
        let a: Array1<f64> = array![];
        assert_eq!(a.entropy(), Err(EmptyInput));
    }

    // Entropy of a fixed distribution matches a precomputed reference value.
    #[test]
    fn test_entropy_with_array_of_floats() {
        let a: Array1<f64> = array![
            0.03602474, 0.01900344, 0.03510129, 0.03414964, 0.00525311, 0.03368976, 0.00065396,
            0.02906146, 0.00063687, 0.01597306, 0.00787625, 0.00208243, 0.01450896, 0.01803418,
            0.02055336, 0.03029759, 0.03323628, 0.01218822, 0.0001873, 0.01734179, 0.03521668,
            0.02564429, 0.02421992, 0.03540229, 0.03497635, 0.03582331, 0.026558, 0.02460495,
            0.02437716, 0.01212838, 0.00058464, 0.00335236, 0.02146745, 0.00930306, 0.01821588,
            0.02381928, 0.02055073, 0.01483779, 0.02284741, 0.02251385, 0.00976694, 0.02864634,
            0.00802828, 0.03464088, 0.03557152, 0.01398894, 0.01831756, 0.0227171, 0.00736204,
            0.01866295,
        ];
        let expected_entropy = 3.721606155686918;

        assert_abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = 1e-6);
    }

    // NaN in either operand propagates through both cross entropy and KL.
    #[test]
    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> {
        let a = array![f64::NAN, 1.];
        let b = array![2., 1.];
        assert!(a.cross_entropy(&b)?.is_nan());
        assert!(b.cross_entropy(&a)?.is_nan());
        assert!(a.kl_divergence(&b)?.is_nan());
        assert!(b.kl_divergence(&a)?.is_nan());
        Ok(())
    }

    // Same dimensionality but different lengths is a shape-mismatch error.
    #[test]
    fn test_cross_entropy_and_kl_with_same_n_dimension_but_different_n_elements() {
        let p = array![f64::NAN, 1.];
        let q = array![2., 1., 5.];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    // Equal element counts do not excuse a shape mismatch (3x2 vs 2x3).
    #[test]
    fn test_cross_entropy_and_kl_with_different_shape_but_same_n_elements() {
        let p = array![[f64::NAN, 1.], [6., 7.], [10., 20.]];
        let q = array![[2., 1., 5.], [1., 1., 7.],];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    // Empty inputs yield the `EmptyInput` variant of `MultiInputError`.
    #[test]
    fn test_cross_entropy_and_kl_with_empty_array_of_floats() {
        let p: Array1<f64> = array![];
        let q: Array1<f64> = array![];
        assert!(p.cross_entropy(&q).unwrap_err().is_empty_input());
        assert!(p.kl_divergence(&q).unwrap_err().is_empty_input());
    }

    // ln of a negative f64 is NaN, so negative q elements produce NaN results.
    #[test]
    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> {
        let p = array![1.];
        let q = array![-1.];
        let cross_entropy: f64 = p.cross_entropy(&q)?;
        let kl_divergence: f64 = p.kl_divergence(&q)?;
        assert!(cross_entropy.is_nan());
        assert!(kl_divergence.is_nan());
        Ok(())
    }

    // With noisy_float's N64 the NaN from ln(-1) panics instead.
    #[test]
    #[should_panic]
    fn test_cross_entropy_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.cross_entropy(&q);
    }

    // Same as above for KL divergence: NaN intermediate must panic for N64.
    #[test]
    #[should_panic]
    fn test_kl_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.kl_divergence(&q);
    }

    // Zero elements in p contribute nothing (the 0·ln(0) = 0 convention),
    // even where the matching q element is zero.
    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> {
        let p = array![0., 0.];
        let q = array![0., 0.5];
        assert_eq!(p.cross_entropy(&q)?, 0.);
        assert_eq!(p.kl_divergence(&q)?, 0.);
        Ok(())
    }

    // A zero in q with nonzero p gives ln(0) = -inf, so the result is +inf;
    // also exercises passing a mutable view rather than an owned array.
    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
    ) -> Result<(), MultiInputError> {
        let p = array![0.5, 0.5];
        let mut q = array![0.5, 0.];
        assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY);
        assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY);
        Ok(())
    }

    // Cross entropy of two fixed distributions matches a precomputed reference.
    #[test]
    fn test_cross_entropy() -> Result<(), MultiInputError> {
        let p: Array1<f64> = array![
            0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
            0.0471258, 0.05594036, 0.01630048, 0.07085162, 0.05365855, 0.01959158, 0.05020174,
            0.03801479, 0.00092234, 0.08515856, 0.00580683, 0.0156542, 0.0860375, 0.0724246,
            0.00727477, 0.01004402, 0.01854399, 0.03504082,
        ];
        let q: Array1<f64> = array![
            0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, 0.05604812,
            0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, 0.0625685, 0.07381292,
            0.05489067, 0.01385491, 0.03639174, 0.00511611, 0.05700415, 0.05183825, 0.06703064,
            0.01813342, 0.0007763, 0.0735472, 0.05857833,
        ];
        let expected_cross_entropy = 3.385347705020779;

        assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6);
        Ok(())
    }

    // KL divergence of two fixed distributions matches a precomputed reference.
    #[test]
    fn test_kl() -> Result<(), MultiInputError> {
        let p: Array1<f64> = array![
            0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
            0.02213802, 0.02745017, 0.02163975, 0.0324602, 0.03622766, 0.00782343, 0.00222498,
            0.03028156, 0.02346124, 0.00071105, 0.00794496, 0.0127609, 0.02899124, 0.01281487,
            0.0230803, 0.01531864, 0.00518158, 0.02233383, 0.0220279, 0.03196097, 0.03710063,
            0.01817856, 0.03524661, 0.02902393, 0.00853364, 0.01255615, 0.03556958, 0.00400151,
            0.01335932, 0.01864965, 0.02371322, 0.02026543, 0.0035375, 0.01988341, 0.02621831,
            0.03564644, 0.01389121, 0.03151622, 0.03195532, 0.00717521, 0.03547256, 0.00371394,
            0.01108706,
        ];
        let q: Array1<f64> = array![
            0.02038386, 0.03143914, 0.02630206, 0.0171595, 0.0067072, 0.00911324, 0.02635717,
            0.01269113, 0.0302361, 0.02243133, 0.01902902, 0.01297185, 0.02118908, 0.03309548,
            0.01266687, 0.0184529, 0.01830936, 0.03430437, 0.02898924, 0.02238251, 0.0139771,
            0.01879774, 0.02396583, 0.03019978, 0.01421278, 0.02078981, 0.03542451, 0.02887438,
            0.01261783, 0.01014241, 0.03263407, 0.0095969, 0.01923903, 0.0051315, 0.00924686,
            0.00148845, 0.00341391, 0.01480373, 0.01920798, 0.03519871, 0.03315135, 0.02099325,
            0.03251755, 0.00337555, 0.03432165, 0.01763753, 0.02038337, 0.01923023, 0.01438769,
            0.02082707,
        ];
        let expected_kl = 0.3555862567800096;

        assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6);
        Ok(())
    }
}