ndarray_glm/
math.rs

//! Mathematical helper functions
2use crate::num::Float;
3use ndarray::Array2;
4use ndarray_linalg::QRSquareInto;
5
6/// The product-logarithm function (not the W function) x * log(x). If x == 0, 0 is returned.
7pub fn prod_log<F>(x: F) -> F
8where
9    F: Float,
10{
11    if x == F::zero() {
12        return F::zero();
13    }
14    x * num_traits::Float::ln(x)
15}
16
17/// Returns true iff the matrix is rank deficient with tolerance `eps` using QR
18/// decomposition.
19// NOTE: SVD may be faster
20pub fn is_rank_deficient<F>(matrix: Array2<F>, eps: F) -> ndarray_linalg::error::Result<bool>
21where
22    F: Float,
23{
24    if matrix.ncols() != matrix.nrows() {
25        return Ok(true);
26    }
27    let (_, r) = matrix.qr_square_into()?;
28    let diag = r.into_diag();
29    for e in diag.into_iter() {
30        if num_traits::Float::abs(e) < eps {
31            return Ok(true);
32        }
33    }
34    Ok(false)
35}
36
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_prod_log() {
        // Special-cased input: 0 * ln(0) is taken to be 0.
        assert_abs_diff_eq!(0., prod_log(0.));
        // ln(e) == 1, so e * ln(e) == e.
        let euler: f64 = std::f64::consts::E;
        assert_abs_diff_eq!(euler, prod_log(euler));
    }

    #[test]
    fn test_rank_def() {
        // Non-square input is reported as deficient before any decomposition happens.
        let wide = array![[0., 1.]];
        assert!(is_rank_deficient(wide, 0.).unwrap());

        // A permuted diagonal matrix is invertible.
        let invertible = array![[0., 1.], [2., 0.]];
        let loose_tol = f32::EPSILON as f64;
        assert!(!is_rank_deficient(invertible, loose_tol).unwrap());

        // An all-zero first column makes the matrix singular.
        let zero_column = array![[0., 1.], [0., 2.342]];
        assert!(is_rank_deficient(zero_column, f64::EPSILON).unwrap());

        // The first column is the sum of the other two.
        let dependent_columns = array![[1., 1., 0.], [1., 0.5, 0.5], [1., 0.2, 0.8]];
        assert!(is_rank_deficient(dependent_columns, f64::EPSILON).unwrap());
    }
}