mathru/algebra/linear/matrix/diagonal/diagonal.rs

1use crate::algebra::{
2    abstr::{Field, Scalar, Zero},
3    linear::matrix::General,
4};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7
8/// Diagonal matrix
9#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
10#[derive(Debug, Clone)]
11pub struct Diagonal<T> {
12    pub(crate) matrix: General<T>,
13}
14
15impl<T> Diagonal<T>
16where
17    T: Field + Scalar + Zero,
18{
19    /// Construct a matrix with vec as its diagonal.
20    ///
21    /// # Example
22    ///
23    /// ```
24    /// use mathru::algebra::linear::matrix::{General, Diagonal};
25    /// use mathru::matrix;
26    ///
27    /// let d: Diagonal<f64> = Diagonal::new(&[1.0, 2.0]);
28    ///
29    /// let d_ref: Diagonal<f64> = matrix![1.0, 0.0;
30    ///                                    0.0, 2.0].into();
31    ///
32    /// assert_eq!(d, d_ref);
33    /// ```
34    pub fn new(vec: &[T]) -> Diagonal<T> {
35        let mut g = General::zero(vec.len(), vec.len());
36        for (idx, v) in vec.iter().enumerate() {
37            g[[idx, idx]] = *v;
38        }
39        Diagonal { matrix: g }
40    }
41}
42impl<T> Diagonal<T> {
43    pub fn dim(&self) -> (usize, usize) {
44        self.matrix.dim()
45    }
46}
47
48impl<T> Diagonal<T>
49where
50    T: Clone,
51{
52    /// Applies the function f on every diagonal element in the matrix
53    pub fn apply_mut(mut self: Diagonal<T>, f: &dyn Fn(&T) -> T) -> Diagonal<T> {
54        let (m, n) = self.dim();
55        let k = m.min(n);
56        for i in 0..k {
57            self[[i, i]] = f(&self[[i, i]]);
58        }
59
60        self
61    }
62
63    pub fn apply(self: &Diagonal<T>, f: &dyn Fn(&T) -> T) -> Diagonal<T> {
64        (self.clone()).apply_mut(f)
65    }
66
67    pub fn mut_apply(self: &mut Diagonal<T>, f: &dyn Fn(&mut T) -> T) {
68        let (m, n) = self.dim();
69        let k = m.min(n);
70        for i in 0..k {
71            self[[i, i]] = f(&mut self[[i, i]]);
72        }
73    }
74}