1use serde::{Deserialize, Serialize};
12use sprs::{CsMat, TriMat};
13
14#[derive(Clone, Debug, Serialize, Deserialize)]
20#[serde(transparent)]
21pub struct StoichMatrix(CsMat<f64>);
22
23impl Default for StoichMatrix {
24 fn default() -> Self {
25 Self::zeros(0, 0)
26 }
27}
28
29impl StoichMatrix {
30 pub fn from_triplets(
34 n_rows: usize,
35 n_cols: usize,
36 triplets: impl IntoIterator<Item = (usize, usize, f64)>,
37 ) -> Self {
38 let mut tri: TriMat<f64> = TriMat::new((n_rows, n_cols));
39 for (r, c, v) in triplets {
40 tri.add_triplet(r, c, v);
41 }
42 Self(tri.to_csc())
43 }
44
45 pub fn zeros(n_rows: usize, n_cols: usize) -> Self {
47 let tri: TriMat<f64> = TriMat::new((n_rows, n_cols));
48 Self(tri.to_csc())
49 }
50
51 pub fn rows(&self) -> usize {
52 self.0.rows()
53 }
54 pub fn cols(&self) -> usize {
55 self.0.cols()
56 }
57 pub fn nnz(&self) -> usize {
58 self.0.nnz()
59 }
60
61 pub fn inner(&self) -> &CsMat<f64> {
62 &self.0
63 }
64 pub fn into_inner(self) -> CsMat<f64> {
65 self.0
66 }
67
68 pub fn column(&self, col: usize) -> Vec<(usize, f64)> {
74 let mut out = Vec::new();
75 if self.0.is_csc() {
76 if let Some(view) = self.0.outer_view(col) {
77 for (row, &val) in view.iter() {
78 out.push((row, val));
79 }
80 }
81 } else {
82 for (val, (row, c)) in self.0.iter() {
83 if c == col {
84 out.push((row, *val));
85 }
86 }
87 }
88 out
89 }
90}
91
92impl From<CsMat<f64>> for StoichMatrix {
93 fn from(m: CsMat<f64>) -> Self {
94 Self(m)
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Shape, nnz, and CSC column extraction on a small hand-built matrix.
    #[test]
    fn build_and_query() {
        let s = StoichMatrix::from_triplets(3, 2, vec![(0, 0, -1.0), (1, 0, 1.0), (2, 1, 2.0)]);
        assert_eq!(s.rows(), 3);
        assert_eq!(s.cols(), 2);
        assert_eq!(s.nnz(), 3);

        let col0 = s.column(0);
        assert_eq!(col0.len(), 2);
        assert!(col0.contains(&(0, -1.0)));
        assert!(col0.contains(&(1, 1.0)));
    }

    /// `column` must give the same answer when the matrix is stored as CSR
    /// (the non-CSC fallback branch).
    #[test]
    fn column_on_csr_matches_csc() {
        let csc = StoichMatrix::from_triplets(3, 2, vec![(0, 0, -1.0), (1, 0, 1.0), (2, 1, 2.0)]);
        let csr = StoichMatrix::from(csc.inner().to_other_storage());
        assert!(!csr.inner().is_csc());

        assert_eq!(csr.column(0), vec![(0, -1.0), (1, 1.0)]);
        assert_eq!(csr.column(1), vec![(2, 2.0)]);
        assert_eq!(csc.column(0), csr.column(0));
        assert_eq!(csc.column(1), csr.column(1));
    }

    /// Out-of-range columns return an empty vector in both storage layouts.
    #[test]
    fn column_out_of_range_is_empty() {
        let csc = StoichMatrix::from_triplets(2, 2, vec![(0, 0, -1.0), (1, 1, 1.0)]);
        assert!(csc.column(5).is_empty());
        let csr = StoichMatrix::from(csc.inner().to_other_storage());
        assert!(csr.column(5).is_empty());
    }

    /// Duplicate triplets are summed during the TriMat -> CSC conversion.
    #[test]
    fn duplicate_triplets_are_summed() {
        let s = StoichMatrix::from_triplets(1, 1, vec![(0, 0, 1.0), (0, 0, 2.0)]);
        assert_eq!(s.nnz(), 1);
        assert_eq!(s.column(0), vec![(0, 3.0)]);
    }

    /// The default matrix is empty.
    #[test]
    fn default_is_empty() {
        let d = StoichMatrix::default();
        assert_eq!((d.rows(), d.cols(), d.nnz()), (0, 0, 0));
    }

    /// JSON round-trip preserves shape and stored entries.
    #[test]
    fn serde_json_roundtrip() {
        let s = StoichMatrix::from_triplets(2, 2, vec![(0, 0, -1.0), (1, 1, 1.0)]);
        let j = serde_json::to_string(&s).unwrap();
        let back: StoichMatrix = serde_json::from_str(&j).unwrap();
        assert_eq!(back.rows(), 2);
        assert_eq!(back.cols(), 2);
        assert_eq!(back.nnz(), 2);
        assert_eq!(back.column(0), s.column(0));
        assert_eq!(back.column(1), s.column(1));
    }
}