1use super::compress::build_compressed_format;
2use crate::error::SolverError;
3
/// Sparse matrix stored in compressed sparse column (CSC) format.
///
/// The stored entries of column `j` occupy positions
/// `col_ptr[j]..col_ptr[j + 1]` of `row_ind` (their row indices) and of
/// `values` (their values). `col_ptr` therefore has `ncols + 1` elements, and
/// its last element equals the number of stored entries.
///
/// The derived `PartialEq` compares the stored representation field by
/// field. Two matrices that are mathematically equal but differ in which
/// zeros are explicitly stored compare as unequal.
#[derive(Debug, Clone, PartialEq)]
pub struct CscMatrix {
    /// Start offset of each column in `row_ind`/`values` (length `ncols + 1`).
    pub(crate) col_ptr: Vec<usize>,
    /// Row index of every stored entry, grouped by column.
    pub(crate) row_ind: Vec<usize>,
    /// Value of every stored entry, parallel to `row_ind`.
    pub(crate) values: Vec<f64>,
    /// Number of rows.
    pub(crate) nrows: usize,
    /// Number of columns.
    pub(crate) ncols: usize,
}
22
23impl CscMatrix {
24 pub fn new(nrows: usize, ncols: usize) -> Self {
32 Self {
33 col_ptr: vec![0; ncols + 1],
34 row_ind: Vec::new(),
35 values: Vec::new(),
36 nrows,
37 ncols,
38 }
39 }
40
41 pub fn nnz(&self) -> usize {
42 self.values.len()
43 }
44
45 pub fn col_ptr(&self) -> &[usize] {
46 &self.col_ptr
47 }
48
49 pub fn row_ind(&self) -> &[usize] {
50 &self.row_ind
51 }
52
53 pub fn values(&self) -> &[f64] {
54 &self.values
55 }
56
57 pub fn scale_values(&self, factor: f64) -> Self {
59 Self {
60 col_ptr: self.col_ptr.clone(),
61 row_ind: self.row_ind.clone(),
62 values: self.values.iter().map(|&v| v * factor).collect(),
63 nrows: self.nrows,
64 ncols: self.ncols,
65 }
66 }
67
68 pub fn nrows(&self) -> usize {
69 self.nrows
70 }
71
72 pub fn ncols(&self) -> usize {
73 self.ncols
74 }
75
76 pub fn row_infinity_norms(&self) -> Vec<f64> {
81 let mut norms = vec![0.0_f64; self.nrows];
82 for (&val, &row) in self.values.iter().zip(self.row_ind.iter()) {
83 let abs_val = val.abs();
84 if abs_val > norms[row] {
85 norms[row] = abs_val;
86 }
87 }
88 norms
89 }
90
91 pub fn from_triplets(
95 rows: &[usize],
96 cols: &[usize],
97 vals: &[f64],
98 nrows: usize,
99 ncols: usize,
100 ) -> Result<Self, SolverError> {
101 if rows.len() != cols.len() || rows.len() != vals.len() {
102 return Err(SolverError::DimensionMismatch {
103 field: "triplet_arrays",
104 expected: rows.len(),
105 got: vals.len(),
106 });
107 }
108 let (col_ptr, row_ind, values) = build_compressed_format(ncols, nrows, cols, rows, vals)?;
110 Ok(Self {
111 col_ptr,
112 row_ind,
113 values,
114 nrows,
115 ncols,
116 })
117 }
118
119 pub fn transpose(&self) -> Self {
124 let nnz = self.nnz();
125 let mut row_count = vec![0usize; self.nrows];
128 for &r in &self.row_ind {
129 row_count[r] += 1;
130 }
131
132 let mut col_ptr = vec![0usize; self.nrows + 1];
134 for r in 0..self.nrows {
135 col_ptr[r + 1] = col_ptr[r] + row_count[r];
136 }
137
138 let mut row_ind = vec![0usize; nnz];
144 let mut values = vec![0.0f64; nnz];
145 let mut pos = col_ptr[..self.nrows].to_vec();
146
147 for col in 0..self.ncols {
148 let start = self.col_ptr[col];
149 let end = self.col_ptr[col + 1];
150 for k in start..end {
151 let row = self.row_ind[k];
152 let p = pos[row];
153 row_ind[p] = col;
154 values[p] = self.values[k];
155 pos[row] += 1;
156 }
157 }
158
159 Self {
160 col_ptr,
161 row_ind,
162 values,
163 nrows: self.ncols,
164 ncols: self.nrows,
165 }
166 }
167
168 pub fn mat_vec_mul(&self, x: &[f64]) -> Result<Vec<f64>, SolverError> {
170 if x.len() != self.ncols {
171 return Err(SolverError::DimensionMismatch {
172 field: "vector",
173 expected: self.ncols,
174 got: x.len(),
175 });
176 }
177
178 let mut y = vec![0.0; self.nrows];
179 for (col, &x_val) in x.iter().enumerate() {
180 let start = self.col_ptr[col];
181 let end = self.col_ptr[col + 1];
182 for idx in start..end {
183 let row = self.row_ind[idx];
184 let a_val = self.values[idx];
185 y[row] += a_val * x_val;
186 }
187 }
188 Ok(y)
189 }
190
191 pub fn get_column(&self, j: usize) -> Result<(&[usize], &[f64]), SolverError> {
193 if j >= self.ncols {
194 return Err(SolverError::IndexOutOfBounds {
195 context: "column",
196 index: j,
197 bound: self.ncols,
198 });
199 }
200 let start = self.col_ptr[j];
201 let end = self.col_ptr[j + 1];
202 Ok((&self.row_ind[start..end], &self.values[start..end]))
203 }
204
205 pub fn identity(n: usize) -> Self {
206 let col_ptr: Vec<usize> = (0..=n).collect();
207 let row_ind: Vec<usize> = (0..n).collect();
208 let values = vec![1.0; n];
209 Self {
210 col_ptr,
211 row_ind,
212 values,
213 nrows: n,
214 ncols: n,
215 }
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_is_empty() {
        let mat = CscMatrix::new(2, 3);
        assert_eq!(mat.nrows(), 2);
        assert_eq!(mat.ncols(), 3);
        assert_eq!(mat.nnz(), 0);
        assert_eq!(mat.col_ptr(), &[0, 0, 0, 0]);
        assert!(mat.row_ind().is_empty());
        assert!(mat.values().is_empty());
    }

    #[test]
    fn test_from_triplets_basic() {
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];

        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();

        assert_eq!(mat.nrows, 3);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 5);

        let (row_idx, values) = mat.get_column(0).unwrap();
        assert_eq!(row_idx, &[0, 2]);
        assert_eq!(values, &[1.0, 4.0]);

        let (row_idx, values) = mat.get_column(1).unwrap();
        assert_eq!(row_idx, &[1]);
        assert_eq!(values, &[3.0]);

        let (row_idx, values) = mat.get_column(2).unwrap();
        assert_eq!(row_idx, &[0, 2]);
        assert_eq!(values, &[2.0, 5.0]);
    }

    #[test]
    fn test_from_triplets_duplicate_entries() {
        // (0, 0) appears twice; the two values must be summed.
        let rows = vec![0, 0, 1];
        let cols = vec![0, 0, 1];
        let vals = vec![1.0, 2.0, 3.0];

        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 2).unwrap();

        let (row_idx, values) = mat.get_column(0).unwrap();
        assert_eq!(row_idx, &[0]);
        assert_eq!(values, &[3.0]);

        let (row_idx, values) = mat.get_column(1).unwrap();
        assert_eq!(row_idx, &[1]);
        assert_eq!(values, &[3.0]);
    }

    #[test]
    fn test_transpose() {
        let rows = vec![0, 0, 1];
        let cols = vec![0, 1, 2];
        let vals = vec![1.0, 2.0, 3.0];

        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        let mat_t = mat.transpose();

        assert_eq!(mat_t.nrows, 3);
        assert_eq!(mat_t.ncols, 2);
        assert_eq!(mat_t.nnz(), 3);

        let (row_idx, values) = mat_t.get_column(0).unwrap();
        assert_eq!(row_idx, &[0, 1]);
        assert_eq!(values, &[1.0, 2.0]);

        let (row_idx, values) = mat_t.get_column(1).unwrap();
        assert_eq!(row_idx, &[2]);
        assert_eq!(values, &[3.0]);

        // Transposing twice must reproduce the original storage exactly.
        let mat_tt = mat_t.transpose();
        assert_eq!(mat_tt.nrows, mat.nrows);
        assert_eq!(mat_tt.ncols, mat.ncols);
        assert_eq!(mat_tt.row_ind, mat.row_ind);
        assert_eq!(mat_tt.col_ptr, mat.col_ptr);
        assert_eq!(mat_tt.values, mat.values);
    }

    #[test]
    fn test_mat_vec_mul() {
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];
        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();

        let x = vec![1.0, 2.0, 3.0];
        let y = mat.mat_vec_mul(&x).unwrap();

        assert_eq!(y.len(), 3);
        assert!((y[0] - 7.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 19.0).abs() < 1e-10);
    }

    #[test]
    fn test_mat_vec_mul_dimension_mismatch() {
        let mat = CscMatrix::identity(3);
        let x = vec![1.0, 2.0];
        let result = mat.mat_vec_mul(&x);
        assert!(result.is_err());
    }

    #[test]
    fn test_identity() {
        let id = CscMatrix::identity(4);
        assert_eq!(id.nrows, 4);
        assert_eq!(id.ncols, 4);
        assert_eq!(id.nnz(), 4);

        for j in 0..4 {
            let (row_idx, values) = id.get_column(j).unwrap();
            assert_eq!(row_idx, &[j]);
            assert_eq!(values, &[1.0]);
        }

        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = id.mat_vec_mul(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn test_scale_values() {
        let mat = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[1.5, -2.0], 2, 2).unwrap();
        let scaled = mat.scale_values(2.0);
        assert_eq!(scaled.nrows(), 2);
        assert_eq!(scaled.ncols(), 2);
        assert_eq!(scaled.col_ptr(), mat.col_ptr());
        assert_eq!(scaled.row_ind(), mat.row_ind());
        assert_eq!(scaled.values(), &[3.0, -4.0]);
    }

    #[test]
    fn test_row_infinity_norms() {
        // Row 3 has no stored entries and must report 0.0.
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, -6.0, 3.0, -2.0, 5.0];
        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 4, 3).unwrap();

        assert_eq!(mat.row_infinity_norms(), vec![2.0, 3.0, 6.0, 0.0]);
    }

    #[test]
    fn test_empty_matrix() {
        let mat = CscMatrix::from_triplets(&[], &[], &[], 2, 3).unwrap();
        assert_eq!(mat.nrows, 2);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 0);

        for j in 0..3 {
            let (row_idx, values) = mat.get_column(j).unwrap();
            assert_eq!(row_idx.len(), 0);
            assert_eq!(values.len(), 0);
        }

        let y = mat.mat_vec_mul(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(y, vec![0.0, 0.0]);
    }

    #[test]
    fn test_get_column_out_of_bounds() {
        let mat = CscMatrix::identity(3);
        let result = mat.get_column(3);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_triplets_out_of_bounds() {
        // Row index 3 is out of range for a 3-row matrix.
        let result = CscMatrix::from_triplets(&[0, 3], &[0, 0], &[1.0, 2.0], 3, 2);
        assert!(result.is_err());

        // Column index 2 is out of range for a 2-column matrix.
        let result = CscMatrix::from_triplets(&[0, 0], &[0, 2], &[1.0, 2.0], 3, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_triplets_mismatched_lengths() {
        // `cols` is the short array.
        let result = CscMatrix::from_triplets(&[0, 1], &[0], &[1.0, 2.0], 2, 2);
        assert!(result.is_err());

        // `vals` is the short array.
        let result = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[1.0], 2, 2);
        assert!(result.is_err());
    }
}