1use super::compress::build_compressed_format;
2use super::csc::CscMatrix;
3use crate::error::SolverError;
4
/// Sparse matrix held in compressed sparse row (CSR) form.
///
/// The entries of row `i` are stored contiguously at positions
/// `row_ptr[i]..row_ptr[i + 1]` of `col_ind` and `values`.
#[derive(Debug, Clone)]
pub struct CsrMatrix {
    /// Row offsets into `col_ind` / `values`; row `i` spans
    /// `row_ptr[i]..row_ptr[i + 1]`, so `nrows + 1` offsets are expected.
    pub row_ptr: Vec<usize>,
    /// Column index of each stored entry, parallel to `values`.
    pub col_ind: Vec<usize>,
    /// Value of each stored entry; its length is the stored-entry count.
    pub values: Vec<f64>,
    /// Number of rows in the matrix.
    pub nrows: usize,
    /// Number of columns in the matrix.
    pub ncols: usize,
}
29
30impl CsrMatrix {
31 pub fn nnz(&self) -> usize {
33 self.values.len()
34 }
35
36 pub fn from_triplets(
52 rows: &[usize],
53 cols: &[usize],
54 vals: &[f64],
55 nrows: usize,
56 ncols: usize,
57 ) -> Result<Self, SolverError> {
58 if rows.len() != cols.len() || rows.len() != vals.len() {
59 return Err(SolverError::DimensionMismatch { field: "triplet_arrays", expected: rows.len(), got: vals.len() });
60 }
61 let (row_ptr, col_ind, values) =
63 build_compressed_format(nrows, ncols, rows, cols, vals)?;
64 Ok(Self { row_ptr, col_ind, values, nrows, ncols })
65 }
66
67 pub fn get_row(&self, i: usize) -> Result<(&[usize], &[f64]), SolverError> {
79 if i >= self.nrows {
80 return Err(SolverError::IndexOutOfBounds { context: "row", index: i, bound: self.nrows });
81 }
82 let start = self.row_ptr[i];
83 let end = self.row_ptr[i + 1];
84 Ok((&self.col_ind[start..end], &self.values[start..end]))
85 }
86
87 pub fn from_csc(csc: &CscMatrix) -> Self {
98 let nnz = csc.nnz();
99 let nrows = csc.nrows;
100 let ncols = csc.ncols;
101
102 let mut row_ptr = vec![0usize; nrows + 1];
104 for &r in &csc.row_ind {
105 row_ptr[r + 1] += 1;
106 }
107 for i in 0..nrows {
108 row_ptr[i + 1] += row_ptr[i];
109 }
110
111 let mut col_ind = vec![0usize; nnz];
114 let mut values = vec![0.0f64; nnz];
115 let mut cur = row_ptr[..nrows].to_vec();
116
117 for j in 0..ncols {
118 let start = csc.col_ptr[j];
119 let end = csc.col_ptr[j + 1];
120 for k in start..end {
121 let r = csc.row_ind[k];
122 let pos = cur[r];
123 col_ind[pos] = j;
124 values[pos] = csc.values[k];
125 cur[r] += 1;
126 }
127 }
128
129 Self {
130 row_ptr,
131 col_ind,
132 values,
133 nrows,
134 ncols,
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that row `i` of `mat` holds exactly the given column indices
    /// and values, in that order.
    #[track_caller]
    fn assert_row(mat: &CsrMatrix, i: usize, cols: &[usize], vals: &[f64]) {
        let (ci, v) = mat.get_row(i).unwrap();
        assert_eq!(ci, cols);
        assert_eq!(v, vals);
    }

    /// Asserts that `mat` is the 3x3 reference matrix used by both tests:
    ///
    /// ```text
    /// [1 0 2]
    /// [0 3 0]
    /// [4 0 5]
    /// ```
    #[track_caller]
    fn assert_reference_matrix(mat: &CsrMatrix) {
        assert_eq!(mat.nrows, 3);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 5);

        assert_row(mat, 0, &[0, 2], &[1.0, 2.0]);
        assert_row(mat, 1, &[1], &[3.0]);
        assert_row(mat, 2, &[0, 2], &[4.0, 5.0]);
    }

    #[test]
    fn test_csr_from_triplets() {
        // Triplets supplied in row-major order.
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let mat = CsrMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        assert_reference_matrix(&mat);
    }

    #[test]
    fn test_csr_from_csc() {
        // Same matrix, triplets supplied in column-major order for the CSC side.
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];

        let csc = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        let csr = CsrMatrix::from_csc(&csc);
        assert_reference_matrix(&csr);
    }
}