1use super::compress::build_compressed_format;
2use crate::error::SolverError;
3
/// Sparse matrix stored in compressed sparse column (CSC) layout.
///
/// The stored entries of column `j` occupy positions
/// `col_ptr[j]..col_ptr[j + 1]` of the parallel `row_ind` / `values` arrays.
#[derive(Clone, Debug)]
pub struct CscMatrix {
    /// Start offset of each column within `row_ind`/`values`, followed by a
    /// final end offset (so it has one more element than there are columns).
    pub(crate) col_ptr: Vec<usize>,
    /// Row index of every stored entry.
    pub(crate) row_ind: Vec<usize>,
    /// Numeric value of every stored entry, parallel to `row_ind`.
    pub(crate) values: Vec<f64>,
    /// Number of rows.
    pub(crate) nrows: usize,
    /// Number of columns.
    pub(crate) ncols: usize,
}
22
23impl CscMatrix {
24 pub fn new(nrows: usize, ncols: usize) -> Self {
32 Self {
33 col_ptr: vec![0; ncols + 1],
34 row_ind: Vec::new(),
35 values: Vec::new(),
36 nrows,
37 ncols,
38 }
39 }
40
41 pub fn nnz(&self) -> usize {
42 self.values.len()
43 }
44
45 pub fn col_ptr(&self) -> &[usize] {
46 &self.col_ptr
47 }
48
49 pub fn row_ind(&self) -> &[usize] {
50 &self.row_ind
51 }
52
53 pub fn values(&self) -> &[f64] {
54 &self.values
55 }
56
57 pub fn scale_values(&self, factor: f64) -> Self {
59 Self {
60 col_ptr: self.col_ptr.clone(),
61 row_ind: self.row_ind.clone(),
62 values: self.values.iter().map(|&v| v * factor).collect(),
63 nrows: self.nrows,
64 ncols: self.ncols,
65 }
66 }
67
68 pub fn nrows(&self) -> usize {
69 self.nrows
70 }
71
72 pub fn ncols(&self) -> usize {
73 self.ncols
74 }
75
76 pub fn row_infinity_norms(&self) -> Vec<f64> {
81 let mut norms = vec![0.0_f64; self.nrows];
82 for (&val, &row) in self.values.iter().zip(self.row_ind.iter()) {
83 let abs_val = val.abs();
84 if abs_val > norms[row] {
85 norms[row] = abs_val;
86 }
87 }
88 norms
89 }
90
91 pub fn from_triplets(
95 rows: &[usize],
96 cols: &[usize],
97 vals: &[f64],
98 nrows: usize,
99 ncols: usize,
100 ) -> Result<Self, SolverError> {
101 if rows.len() != cols.len() || rows.len() != vals.len() {
102 return Err(SolverError::DimensionMismatch {
103 field: "triplet_arrays",
104 expected: rows.len(),
105 got: vals.len(),
106 });
107 }
108 for (i, &v) in vals.iter().enumerate() {
109 if !v.is_finite() {
110 return Err(SolverError::NonFiniteCoefficient {
111 field: "matrix",
112 index: i,
113 });
114 }
115 }
116 let (col_ptr, row_ind, values) = build_compressed_format(ncols, nrows, cols, rows, vals)?;
118 Ok(Self {
119 col_ptr,
120 row_ind,
121 values,
122 nrows,
123 ncols,
124 })
125 }
126
127 pub fn transpose(&self) -> Self {
132 let nnz = self.nnz();
133 let mut row_count = vec![0usize; self.nrows];
136 for &r in &self.row_ind {
137 row_count[r] += 1;
138 }
139
140 let mut col_ptr = vec![0usize; self.nrows + 1];
142 for r in 0..self.nrows {
143 col_ptr[r + 1] = col_ptr[r] + row_count[r];
144 }
145
146 let mut row_ind = vec![0usize; nnz];
152 let mut values = vec![0.0f64; nnz];
153 let mut pos = col_ptr[..self.nrows].to_vec();
154
155 for col in 0..self.ncols {
156 let start = self.col_ptr[col];
157 let end = self.col_ptr[col + 1];
158 for k in start..end {
159 let row = self.row_ind[k];
160 let p = pos[row];
161 row_ind[p] = col;
162 values[p] = self.values[k];
163 pos[row] += 1;
164 }
165 }
166
167 Self {
168 col_ptr,
169 row_ind,
170 values,
171 nrows: self.ncols,
172 ncols: self.nrows,
173 }
174 }
175
176 pub fn mat_vec_mul(&self, x: &[f64]) -> Result<Vec<f64>, SolverError> {
178 if x.len() != self.ncols {
179 return Err(SolverError::DimensionMismatch {
180 field: "vector",
181 expected: self.ncols,
182 got: x.len(),
183 });
184 }
185
186 let mut y = vec![0.0; self.nrows];
187 for (col, &x_val) in x.iter().enumerate() {
188 let start = self.col_ptr[col];
189 let end = self.col_ptr[col + 1];
190 for idx in start..end {
191 let row = self.row_ind[idx];
192 let a_val = self.values[idx];
193 y[row] += a_val * x_val;
194 }
195 }
196 Ok(y)
197 }
198
199 pub fn get_column(&self, j: usize) -> Result<(&[usize], &[f64]), SolverError> {
201 if j >= self.ncols {
202 return Err(SolverError::IndexOutOfBounds {
203 context: "column",
204 index: j,
205 bound: self.ncols,
206 });
207 }
208 let start = self.col_ptr[j];
209 let end = self.col_ptr[j + 1];
210 Ok((&self.row_ind[start..end], &self.values[start..end]))
211 }
212
213 pub fn identity(n: usize) -> Self {
214 let col_ptr: Vec<usize> = (0..=n).collect();
215 let row_ind: Vec<usize> = (0..n).collect();
216 let values = vec![1.0; n];
217 Self {
218 col_ptr,
219 row_ind,
220 values,
221 nrows: n,
222 ncols: n,
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    // Shared 3×3 fixture:
    //   [1 0 2]
    //   [0 3 0]
    //   [4 0 5]
    fn sample_3x3() -> CscMatrix {
        let rows = [0, 2, 1, 0, 2];
        let cols = [0, 0, 1, 2, 2];
        let vals = [1.0, 4.0, 3.0, 2.0, 5.0];
        CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap()
    }

    #[test]
    fn test_from_triplets_basic() {
        let mat = sample_3x3();

        assert_eq!(mat.nrows, 3);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 5);

        // Expected (row indices, values) for each column in turn.
        let expected: [(&[usize], &[f64]); 3] = [
            (&[0, 2], &[1.0, 4.0]),
            (&[1], &[3.0]),
            (&[0, 2], &[2.0, 5.0]),
        ];
        for (col, &(want_rows, want_vals)) in expected.iter().enumerate() {
            let (got_rows, got_vals) = mat.get_column(col).unwrap();
            assert_eq!(got_rows, want_rows);
            assert_eq!(got_vals, want_vals);
        }
    }

    #[test]
    fn test_from_triplets_duplicate_entries() {
        // (0, 0) is given twice (1.0 and 2.0) and must collapse to one entry.
        let mat =
            CscMatrix::from_triplets(&[0, 0, 1], &[0, 0, 1], &[1.0, 2.0, 3.0], 2, 2).unwrap();

        let (first_rows, first_vals) = mat.get_column(0).unwrap();
        assert_eq!(first_rows, &[0]);
        assert_eq!(first_vals, &[3.0]);

        let (second_rows, second_vals) = mat.get_column(1).unwrap();
        assert_eq!(second_rows, &[1]);
        assert_eq!(second_vals, &[3.0]);
    }

    #[test]
    fn test_transpose() {
        // 2×3 source: [[1, 2, 0], [0, 0, 3]].
        let mat =
            CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 2], &[1.0, 2.0, 3.0], 2, 3).unwrap();
        let mat_t = mat.transpose();

        assert_eq!(mat_t.nrows, 3);
        assert_eq!(mat_t.ncols, 2);
        assert_eq!(mat_t.nnz(), 3);

        let (rows0, vals0) = mat_t.get_column(0).unwrap();
        assert_eq!(rows0, &[0, 1]);
        assert_eq!(vals0, &[1.0, 2.0]);

        let (rows1, vals1) = mat_t.get_column(1).unwrap();
        assert_eq!(rows1, &[2]);
        assert_eq!(vals1, &[3.0]);

        // Transposing twice must reproduce the original storage exactly.
        let round_trip = mat_t.transpose();
        assert_eq!(round_trip.nrows, mat.nrows);
        assert_eq!(round_trip.ncols, mat.ncols);
        assert_eq!(round_trip.row_ind, mat.row_ind);
        assert_eq!(round_trip.col_ptr, mat.col_ptr);
        assert_eq!(round_trip.values, mat.values);
    }

    #[test]
    fn test_mat_vec_mul() {
        let y = sample_3x3().mat_vec_mul(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(y.len(), 3);

        let expected: [f64; 3] = [7.0, 6.0, 19.0];
        for (got, want) in y.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_mat_vec_mul_dimension_mismatch() {
        let short_input = [1.0, 2.0];
        assert!(CscMatrix::identity(3).mat_vec_mul(&short_input).is_err());
    }

    #[test]
    fn test_identity() {
        let id = CscMatrix::identity(4);
        assert_eq!(id.nrows, 4);
        assert_eq!(id.ncols, 4);
        assert_eq!(id.nnz(), 4);

        (0..4).for_each(|j| {
            let (rows, vals) = id.get_column(j).unwrap();
            assert_eq!(rows, &[j]);
            assert_eq!(vals, &[1.0]);
        });

        let x = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(id.mat_vec_mul(&x).unwrap(), x);
    }

    #[test]
    fn test_empty_matrix() {
        let mat = CscMatrix::from_triplets(&[], &[], &[], 2, 3).unwrap();
        assert_eq!(mat.nrows, 2);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 0);

        for j in 0..3 {
            let (rows, vals) = mat.get_column(j).unwrap();
            assert!(rows.is_empty());
            assert!(vals.is_empty());
        }

        assert_eq!(mat.mat_vec_mul(&[1.0, 2.0, 3.0]).unwrap(), vec![0.0, 0.0]);
    }

    #[test]
    fn test_get_column_out_of_bounds() {
        assert!(CscMatrix::identity(3).get_column(3).is_err());
    }

    #[test]
    fn test_from_triplets_out_of_bounds() {
        let vals = [1.0, 2.0];
        // Row index 3 lies outside a 3-row matrix.
        assert!(CscMatrix::from_triplets(&[0, 3], &[0, 0], &vals, 3, 2).is_err());
        // Column index 2 lies outside a 2-column matrix.
        assert!(CscMatrix::from_triplets(&[0, 0], &[0, 2], &vals, 3, 2).is_err());
    }

    #[test]
    fn test_from_triplets_mismatched_lengths() {
        let outcome = CscMatrix::from_triplets(&[0, 1], &[0], &[1.0, 2.0], 2, 2);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_sentinel_triplet_non_finite_rejected() {
        let rejected = [
            (f64::NAN, "NaN in triplet vals must be rejected"),
            (f64::INFINITY, "+Inf in triplet vals must be rejected"),
            (f64::NEG_INFINITY, "-Inf in triplet vals must be rejected"),
        ];
        for &(bad, message) in rejected.iter() {
            let outcome = CscMatrix::from_triplets(&[0], &[0], &[bad], 1, 1);
            assert!(outcome.is_err(), "{}", message);
        }

        let accepted = CscMatrix::from_triplets(&[0], &[0], &[1.0], 1, 1);
        assert!(accepted.is_ok(), "finite value must still be accepted");
    }
}