opensrdk_linear_algebra/matrix/sp_hp/
mod.rs1use serde::Deserialize;
2use serde::Serialize;
3
4use crate::number::*;
5use crate::Matrix;
6use crate::MatrixError;
7
8pub mod pp;
9
10pub mod trf;
11pub mod tri;
12pub mod trs;
13
14#[derive(Clone, Debug, Default, PartialEq, Hash, Serialize, Deserialize)]
15pub struct SymmetricPackedMatrix<T = f64>
16where
17 T: Number,
18{
19 dim: usize,
20 elems: Vec<T>,
21}
22
23impl<T> SymmetricPackedMatrix<T>
24where
25 T: Number,
26{
27 pub fn new(dim: usize) -> Self {
28 Self {
29 dim,
30 elems: vec![T::default(); dim * (dim + 1) / 2],
31 }
32 }
33
34 pub fn from(dim: usize, elems: Vec<T>) -> Result<Self, MatrixError> {
36 if elems.len() != dim * (dim + 1) / 2 {
37 return Err(MatrixError::DimensionMismatch);
38 }
39
40 Ok(Self { dim, elems })
41 }
42
43 pub fn dim(&self) -> usize {
44 self.dim
45 }
46
47 pub fn eject(self) -> Vec<T> {
48 self.elems
49 }
50
51 pub fn elems(&self) -> &[T] {
52 &self.elems
53 }
54
55 pub fn elems_mut(&mut self) -> &mut [T] {
56 &mut self.elems
57 }
58
59 pub fn from_mat(mat: &Matrix<T>) -> Result<Self, MatrixError> {
60 let n = mat.rows();
61 if n != mat.cols() {
62 return Err(MatrixError::DimensionMismatch);
63 }
64
65 let elems = (0..n)
66 .into_iter()
67 .map(|j| (j, &mat[j]))
68 .flat_map(|(j, col)| col[j..n].into_iter())
69 .map(|e| *e)
70 .collect::<Vec<_>>();
71 Self::from(n, elems)
72 }
73
74 pub fn to_mat(&self) -> Matrix<T> {
75 let n = self.dim;
76 let elems = (0..n)
77 .into_iter()
78 .flat_map(|j| {
79 let index = n * (n + 1) / 2 - (n - j) * (n - j + 1) / 2;
80 vec![T::default(); j]
81 .into_iter()
82 .chain(self.elems[index..index + (n - j)].iter().map(|e| *e))
83 })
84 .collect::<Vec<_>>();
85
86 Matrix::<T>::from(n, elems).unwrap()
87 }
88}
89
#[cfg(test)]
mod tests {
    use crate::*;

    /// Packing a lower-triangular matrix and expanding it again must reproduce it exactly,
    /// and the final packed slot must hold the bottom-right entry.
    #[test]
    fn round_trips_through_packed_storage() {
        let dense = mat!(
            1.0, 0.0, 0.0, 0.0, 0.0, 0.0;
            2.0, 3.0, 0.0, 0.0, 0.0, 0.0;
            4.0, 5.0, 6.0, 0.0, 0.0, 0.0;
            7.0, 8.0, 9.0, 10.0, 0.0, 0.0;
            11.0, 12.0, 13.0, 14.0, 15.0, 0.0;
            16.0, 17.0, 18.0, 19.0, 20.0, 21.0
        );

        let packed = SymmetricPackedMatrix::from_mat(&dense).unwrap();
        let size = packed.dim();
        let last_slot = size * (size + 1) / 2 - 1;
        assert_eq!(packed.elems()[last_slot], 21.0);

        let restored = packed.to_mat();
        assert_eq!(dense, restored);
    }
}