1use super::compress::build_compressed_format;
2use crate::error::SolverError;
3
/// A sparse matrix stored in compressed sparse column (CSC) format.
///
/// The stored entries of column `j` live at positions `col_ptr[j]..col_ptr[j + 1]`
/// of `row_ind` (their row indices) and `values` (their numeric values).
#[derive(Clone, Debug)]
pub struct CscMatrix {
    /// Column pointers, one more than the number of columns.
    pub(crate) col_ptr: Vec<usize>,
    /// Row index of each stored entry, grouped by column.
    pub(crate) row_ind: Vec<usize>,
    /// Value of each stored entry; parallel to `row_ind`.
    pub(crate) values: Vec<f64>,
    /// Number of rows in the matrix.
    pub(crate) nrows: usize,
    /// Number of columns in the matrix.
    pub(crate) ncols: usize,
}
22
23impl CscMatrix {
24 pub fn new(nrows: usize, ncols: usize) -> Self {
32 Self {
33 col_ptr: vec![0; ncols + 1],
34 row_ind: Vec::new(),
35 values: Vec::new(),
36 nrows,
37 ncols,
38 }
39 }
40
41 pub fn nnz(&self) -> usize {
43 self.values.len()
44 }
45
46 pub fn col_ptr(&self) -> &[usize] {
48 &self.col_ptr
49 }
50
51 pub fn row_ind(&self) -> &[usize] {
53 &self.row_ind
54 }
55
56 pub fn values(&self) -> &[f64] {
58 &self.values
59 }
60
61 pub fn scale_values(&self, factor: f64) -> Self {
63 Self {
64 col_ptr: self.col_ptr.clone(),
65 row_ind: self.row_ind.clone(),
66 values: self.values.iter().map(|&v| v * factor).collect(),
67 nrows: self.nrows,
68 ncols: self.ncols,
69 }
70 }
71
72 pub fn nrows(&self) -> usize {
74 self.nrows
75 }
76
77 pub fn ncols(&self) -> usize {
79 self.ncols
80 }
81
82 pub fn row_infinity_norms(&self) -> Vec<f64> {
87 let mut norms = vec![0.0_f64; self.nrows];
88 for (&val, &row) in self.values.iter().zip(self.row_ind.iter()) {
89 let abs_val = val.abs();
90 if abs_val > norms[row] {
91 norms[row] = abs_val;
92 }
93 }
94 norms
95 }
96
97 pub fn from_triplets(
113 rows: &[usize],
114 cols: &[usize],
115 vals: &[f64],
116 nrows: usize,
117 ncols: usize,
118 ) -> Result<Self, SolverError> {
119 if rows.len() != cols.len() || rows.len() != vals.len() {
120 return Err(SolverError::DimensionMismatch { field: "triplet_arrays", expected: rows.len(), got: vals.len() });
121 }
122 let (col_ptr, row_ind, values) =
124 build_compressed_format(ncols, nrows, cols, rows, vals)?;
125 Ok(Self { col_ptr, row_ind, values, nrows, ncols })
126 }
127
128 pub fn transpose(&self) -> Self {
133 let nnz = self.nnz();
134 let mut row_count = vec![0usize; self.nrows];
137 for &r in &self.row_ind {
138 row_count[r] += 1;
139 }
140
141 let mut col_ptr = vec![0usize; self.nrows + 1];
143 for r in 0..self.nrows {
144 col_ptr[r + 1] = col_ptr[r] + row_count[r];
145 }
146
147 let mut row_ind = vec![0usize; nnz];
153 let mut values = vec![0.0f64; nnz];
154 let mut pos = col_ptr[..self.nrows].to_vec();
155
156 for col in 0..self.ncols {
157 let start = self.col_ptr[col];
158 let end = self.col_ptr[col + 1];
159 for k in start..end {
160 let row = self.row_ind[k];
161 let p = pos[row];
162 row_ind[p] = col;
163 values[p] = self.values[k];
164 pos[row] += 1;
165 }
166 }
167
168 Self {
169 col_ptr,
170 row_ind,
171 values,
172 nrows: self.ncols,
173 ncols: self.nrows,
174 }
175 }
176
177 pub fn mat_vec_mul(&self, x: &[f64]) -> Result<Vec<f64>, SolverError> {
188 if x.len() != self.ncols {
189 return Err(SolverError::DimensionMismatch { field: "vector", expected: self.ncols, got: x.len() });
190 }
191
192 let mut y = vec![0.0; self.nrows];
193 for (col, &x_val) in x.iter().enumerate() {
194 let start = self.col_ptr[col];
195 let end = self.col_ptr[col + 1];
196 for idx in start..end {
197 let row = self.row_ind[idx];
198 let a_val = self.values[idx];
199 y[row] += a_val * x_val;
200 }
201 }
202 Ok(y)
203 }
204
205 pub fn get_column(&self, j: usize) -> Result<(&[usize], &[f64]), SolverError> {
217 if j >= self.ncols {
218 return Err(SolverError::IndexOutOfBounds { context: "column", index: j, bound: self.ncols });
219 }
220 let start = self.col_ptr[j];
221 let end = self.col_ptr[j + 1];
222 Ok((&self.row_ind[start..end], &self.values[start..end]))
223 }
224
225 pub fn identity(n: usize) -> Self {
232 let col_ptr: Vec<usize> = (0..=n).collect();
233 let row_ind: Vec<usize> = (0..n).collect();
234 let values = vec![1.0; n];
235 Self {
236 col_ptr,
237 row_ind,
238 values,
239 nrows: n,
240 ncols: n,
241 }
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x3 fixture shared by several tests:
    ///
    /// ```text
    /// [1 0 2]
    /// [0 3 0]
    /// [4 0 5]
    /// ```
    fn sample_3x3() -> CscMatrix {
        CscMatrix::from_triplets(
            &[0, 2, 1, 0, 2],
            &[0, 0, 1, 2, 2],
            &[1.0, 4.0, 3.0, 2.0, 5.0],
            3,
            3,
        )
        .unwrap()
    }

    #[test]
    fn test_from_triplets_basic() {
        let mat = sample_3x3();

        assert_eq!(mat.nrows, 3);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 5);

        let (row_idx, values) = mat.get_column(0).unwrap();
        assert_eq!(row_idx, &[0, 2]);
        assert_eq!(values, &[1.0, 4.0]);

        let (row_idx, values) = mat.get_column(1).unwrap();
        assert_eq!(row_idx, &[1]);
        assert_eq!(values, &[3.0]);

        let (row_idx, values) = mat.get_column(2).unwrap();
        assert_eq!(row_idx, &[0, 2]);
        assert_eq!(values, &[2.0, 5.0]);
    }

    #[test]
    fn test_from_triplets_duplicate_entries() {
        let mat = CscMatrix::from_triplets(&[0, 0, 1], &[0, 0, 1], &[1.0, 2.0, 3.0], 2, 2)
            .unwrap();

        // Duplicate (0, 0) entries are expected to be summed: 1.0 + 2.0.
        let (row_idx, values) = mat.get_column(0).unwrap();
        assert_eq!(row_idx, &[0]);
        assert_eq!(values, &[3.0]);

        let (row_idx, values) = mat.get_column(1).unwrap();
        assert_eq!(row_idx, &[1]);
        assert_eq!(values, &[3.0]);
    }

    #[test]
    fn test_transpose() {
        let mat = CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 2], &[1.0, 2.0, 3.0], 2, 3)
            .unwrap();
        let mat_t = mat.transpose();

        assert_eq!(mat_t.nrows, 3);
        assert_eq!(mat_t.ncols, 2);
        assert_eq!(mat_t.nnz(), 3);

        let (row_idx, values) = mat_t.get_column(0).unwrap();
        assert_eq!(row_idx, &[0, 1]);
        assert_eq!(values, &[1.0, 2.0]);

        let (row_idx, values) = mat_t.get_column(1).unwrap();
        assert_eq!(row_idx, &[2]);
        assert_eq!(values, &[3.0]);

        // Transposing twice must reproduce the original storage exactly.
        let mat_tt = mat_t.transpose();
        assert_eq!(mat_tt.nrows, mat.nrows);
        assert_eq!(mat_tt.ncols, mat.ncols);
        assert_eq!(mat_tt.row_ind, mat.row_ind);
        assert_eq!(mat_tt.col_ptr, mat.col_ptr);
        assert_eq!(mat_tt.values, mat.values);
    }

    #[test]
    fn test_mat_vec_mul() {
        let mat = sample_3x3();
        let y = mat.mat_vec_mul(&[1.0, 2.0, 3.0]).unwrap();

        assert_eq!(y.len(), 3);
        assert!((y[0] - 7.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 19.0).abs() < 1e-10);
    }

    #[test]
    fn test_mat_vec_mul_dimension_mismatch() {
        let mat = CscMatrix::identity(3);
        let result = mat.mat_vec_mul(&[1.0, 2.0]);
        assert!(matches!(
            result,
            Err(SolverError::DimensionMismatch { expected: 3, got: 2, .. })
        ));
    }

    #[test]
    fn test_identity() {
        let id = CscMatrix::identity(4);
        assert_eq!(id.nrows, 4);
        assert_eq!(id.ncols, 4);
        assert_eq!(id.nnz(), 4);

        for j in 0..4 {
            let (row_idx, values) = id.get_column(j).unwrap();
            assert_eq!(row_idx, &[j]);
            assert_eq!(values, &[1.0]);
        }

        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = id.mat_vec_mul(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn test_empty_matrix() {
        let mat = CscMatrix::from_triplets(&[], &[], &[], 2, 3).unwrap();
        assert_eq!(mat.nrows, 2);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 0);

        for j in 0..3 {
            let (row_idx, values) = mat.get_column(j).unwrap();
            assert!(row_idx.is_empty());
            assert!(values.is_empty());
        }

        let y = mat.mat_vec_mul(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(y, vec![0.0, 0.0]);
    }

    #[test]
    fn test_new_is_empty() {
        let mat = CscMatrix::new(2, 3);
        assert_eq!(mat.nrows(), 2);
        assert_eq!(mat.ncols(), 3);
        assert_eq!(mat.nnz(), 0);
        assert_eq!(mat.col_ptr(), &[0, 0, 0, 0]);

        let y = mat.mat_vec_mul(&[1.0, 1.0, 1.0]).unwrap();
        assert_eq!(y, vec![0.0, 0.0]);
    }

    #[test]
    fn test_accessors() {
        let id = CscMatrix::identity(2);
        assert_eq!(id.col_ptr(), &[0, 1, 2]);
        assert_eq!(id.row_ind(), &[0, 1]);
        assert_eq!(id.values(), &[1.0, 1.0]);
        assert_eq!(id.nrows(), 2);
        assert_eq!(id.ncols(), 2);
    }

    #[test]
    fn test_scale_values() {
        let mat = sample_3x3();
        let scaled = mat.scale_values(2.0);

        // Sparsity pattern is preserved; only values change.
        assert_eq!(scaled.col_ptr(), mat.col_ptr());
        assert_eq!(scaled.row_ind(), mat.row_ind());
        assert_eq!(scaled.values(), &[2.0, 8.0, 6.0, 4.0, 10.0]);
        assert_eq!(scaled.nrows(), 3);
        assert_eq!(scaled.ncols(), 3);
    }

    #[test]
    fn test_row_infinity_norms() {
        // Row 3 has no entries and must report 0.0.
        let mat = CscMatrix::from_triplets(
            &[0, 1, 0, 2],
            &[0, 0, 1, 1],
            &[-3.0, 1.0, 2.0, -0.5],
            4,
            2,
        )
        .unwrap();
        assert_eq!(mat.row_infinity_norms(), vec![3.0, 1.0, 0.5, 0.0]);
    }

    #[test]
    fn test_get_column_out_of_bounds() {
        let mat = CscMatrix::identity(3);
        let result = mat.get_column(3);
        assert!(matches!(
            result,
            Err(SolverError::IndexOutOfBounds { index: 3, bound: 3, .. })
        ));
    }

    #[test]
    fn test_from_triplets_out_of_bounds() {
        // Row index 3 is outside a 3-row matrix.
        let result = CscMatrix::from_triplets(&[0, 3], &[0, 0], &[1.0, 2.0], 3, 2);
        assert!(result.is_err());

        // Column index 2 is outside a 2-column matrix.
        let result = CscMatrix::from_triplets(&[0, 0], &[0, 2], &[1.0, 2.0], 3, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_triplets_mismatched_lengths() {
        // `cols` is too short.
        let result = CscMatrix::from_triplets(&[0, 1], &[0], &[1.0, 2.0], 2, 2);
        assert!(matches!(result, Err(SolverError::DimensionMismatch { .. })));

        // `vals` is too short.
        let result = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[1.0], 2, 2);
        assert!(matches!(result, Err(SolverError::DimensionMismatch { .. })));
    }
}