opensrdk_linear_algebra/matrix/sp_hp/
mod.rs

1use serde::Deserialize;
2use serde::Serialize;
3
4use crate::number::*;
5use crate::Matrix;
6use crate::MatrixError;
7
8pub mod pp;
9
10pub mod trf;
11pub mod tri;
12pub mod trs;
13
/// A square `dim × dim` matrix stored in packed triangular form: only
/// `dim * (dim + 1) / 2` elements are kept, in the order defined by
/// [`SymmetricPackedMatrix::from_mat`] / [`SymmetricPackedMatrix::to_mat`].
///
/// Invariant: `elems.len() == dim * (dim + 1) / 2`. The constructors `new`,
/// `from` and `from_mat` enforce it, and the derived `Default` yields the
/// consistent empty matrix (`dim == 0`, no elements).
///
/// NOTE(review): the derived `Deserialize` fills the fields directly and does
/// not re-check the invariant above; consider `#[serde(try_from = ...)]` if
/// untrusted data is ever deserialized into this type.
#[derive(Clone, Debug, Default, PartialEq, Hash, Serialize, Deserialize)]
pub struct SymmetricPackedMatrix<T = f64>
where
    T: Number,
{
    /// Dimension `n` of the represented `n × n` matrix.
    dim: usize,
    /// Packed elements; length is always `dim * (dim + 1) / 2`.
    elems: Vec<T>,
}
22
23impl<T> SymmetricPackedMatrix<T>
24where
25    T: Number,
26{
27    pub fn new(dim: usize) -> Self {
28        Self {
29            dim,
30            elems: vec![T::default(); dim * (dim + 1) / 2],
31        }
32    }
33
34    /// You can do `unwrap()` if you have a conviction that `elems.len() == dim * (dim + 1) / 2`
35    pub fn from(dim: usize, elems: Vec<T>) -> Result<Self, MatrixError> {
36        if elems.len() != dim * (dim + 1) / 2 {
37            return Err(MatrixError::DimensionMismatch);
38        }
39
40        Ok(Self { dim, elems })
41    }
42
43    pub fn dim(&self) -> usize {
44        self.dim
45    }
46
47    pub fn eject(self) -> Vec<T> {
48        self.elems
49    }
50
51    pub fn elems(&self) -> &[T] {
52        &self.elems
53    }
54
55    pub fn elems_mut(&mut self) -> &mut [T] {
56        &mut self.elems
57    }
58
59    pub fn from_mat(mat: &Matrix<T>) -> Result<Self, MatrixError> {
60        let n = mat.rows();
61        if n != mat.cols() {
62            return Err(MatrixError::DimensionMismatch);
63        }
64
65        let elems = (0..n)
66            .into_iter()
67            .map(|j| (j, &mat[j]))
68            .flat_map(|(j, col)| col[j..n].into_iter())
69            .map(|e| *e)
70            .collect::<Vec<_>>();
71        Self::from(n, elems)
72    }
73
74    pub fn to_mat(&self) -> Matrix<T> {
75        let n = self.dim;
76        let elems = (0..n)
77            .into_iter()
78            .flat_map(|j| {
79                let index = n * (n + 1) / 2 - (n - j) * (n - j + 1) / 2;
80                vec![T::default(); j]
81                    .into_iter()
82                    .chain(self.elems[index..index + (n - j)].iter().map(|e| *e))
83            })
84            .collect::<Vec<_>>();
85
86        Matrix::<T>::from(n, elems).unwrap()
87    }
88}
89
#[cfg(test)]
mod tests {
    use crate::*;

    /// Round-trips a 6 × 6 lower-triangular matrix through packed storage and
    /// checks the last packed slot holds the bottom-right diagonal entry.
    #[test]
    fn it_works() {
        let dense = mat!(
           1.0,  0.0,  0.0,  0.0,  0.0,  0.0;
           2.0,  3.0,  0.0,  0.0,  0.0,  0.0;
           4.0,  5.0,  6.0,  0.0,  0.0,  0.0;
           7.0,  8.0,  9.0, 10.0,  0.0,  0.0;
          11.0, 12.0, 13.0, 14.0, 15.0,  0.0;
          16.0, 17.0, 18.0, 19.0, 20.0, 21.0
        );

        let packed = SymmetricPackedMatrix::from_mat(&dense).unwrap();
        let size = packed.dim();
        let last_slot = size * (size + 1) / 2 - 1;

        assert_eq!(packed.elems()[last_slot], 21.0);

        let unpacked = packed.to_mat();
        assert_eq!(dense, unpacked);
    }
}
114}