1use oxiblas_core::scalar::{Field, Scalar};
14use std::ops::Index;
15
/// Errors reported when validating the raw components of a [`CscMatrix`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CscError {
    /// A `col_ptrs` check failed.
    ///
    /// Used both for a wrong `col_ptrs` length (`ncols + 1` expected) and for a
    /// wrong boundary value in `col_ptrs` (e.g. the final pointer not equal to
    /// the number of stored entries), so the fields are not always lengths.
    InvalidColPtrs {
        /// The value that was required.
        expected: usize,
        /// The value that was found.
        actual: usize,
    },
    /// `values` and `row_indices` have different lengths.
    LengthMismatch {
        /// Length of the `values` vector.
        values_len: usize,
        /// Length of the `row_indices` vector.
        row_indices_len: usize,
    },
    /// A stored row index is not smaller than the number of rows.
    InvalidRowIndex {
        /// The offending row index.
        index: usize,
        /// Number of rows of the matrix.
        nrows: usize,
    },
    /// `col_ptrs` decreases somewhere.
    InvalidColPtrOrder,
    /// The same `(row, col)` position is stored more than once.
    DuplicateEntry {
        /// Row of the duplicated entry.
        row: usize,
        /// Column of the duplicated entry.
        col: usize,
    },
}

impl core::fmt::Display for CscError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            // Deliberately does not say "length": this variant is also used for
            // wrong pointer *values* (see the variant docs).
            Self::InvalidColPtrs { expected, actual } => {
                write!(f, "Invalid col_ptrs: expected {expected}, got {actual}")
            }
            Self::LengthMismatch {
                values_len,
                row_indices_len,
            } => write!(
                f,
                "Length mismatch: values={values_len}, row_indices={row_indices_len}"
            ),
            Self::InvalidRowIndex { index, nrows } => {
                write!(f, "Row index {index} out of bounds for {nrows} rows")
            }
            Self::InvalidColPtrOrder => {
                f.write_str("Column pointers must be monotonically increasing")
            }
            Self::DuplicateEntry { row, col } => {
                write!(f, "Duplicate entry at ({row}, {col})")
            }
        }
    }
}

impl std::error::Error for CscError {}
83
84#[derive(Debug, Clone)]
92pub struct CscMatrix<T: Scalar> {
93 nrows: usize,
95 ncols: usize,
97 col_ptrs: Vec<usize>,
99 row_indices: Vec<usize>,
101 values: Vec<T>,
103}
104
105impl<T: Scalar + Clone> CscMatrix<T> {
106 pub fn new(
120 nrows: usize,
121 ncols: usize,
122 col_ptrs: Vec<usize>,
123 row_indices: Vec<usize>,
124 values: Vec<T>,
125 ) -> Result<Self, CscError> {
126 if col_ptrs.len() != ncols + 1 {
128 return Err(CscError::InvalidColPtrs {
129 expected: ncols + 1,
130 actual: col_ptrs.len(),
131 });
132 }
133
134 if values.len() != row_indices.len() {
136 return Err(CscError::LengthMismatch {
137 values_len: values.len(),
138 row_indices_len: row_indices.len(),
139 });
140 }
141
142 for i in 1..col_ptrs.len() {
144 if col_ptrs[i] < col_ptrs[i - 1] {
145 return Err(CscError::InvalidColPtrOrder);
146 }
147 }
148
149 let nnz = values.len();
151 if col_ptrs[ncols] != nnz {
152 return Err(CscError::InvalidColPtrs {
153 expected: nnz,
154 actual: col_ptrs[ncols],
155 });
156 }
157
158 for &row in &row_indices {
160 if row >= nrows {
161 return Err(CscError::InvalidRowIndex { index: row, nrows });
162 }
163 }
164
165 Ok(Self {
166 nrows,
167 ncols,
168 col_ptrs,
169 row_indices,
170 values,
171 })
172 }
173
174 #[inline]
184 pub unsafe fn new_unchecked(
185 nrows: usize,
186 ncols: usize,
187 col_ptrs: Vec<usize>,
188 row_indices: Vec<usize>,
189 values: Vec<T>,
190 ) -> Self {
191 Self {
192 nrows,
193 ncols,
194 col_ptrs,
195 row_indices,
196 values,
197 }
198 }
199
200 pub fn zeros(nrows: usize, ncols: usize) -> Self {
202 Self {
203 nrows,
204 ncols,
205 col_ptrs: vec![0; ncols + 1],
206 row_indices: Vec::new(),
207 values: Vec::new(),
208 }
209 }
210
211 pub fn eye(n: usize) -> Self
213 where
214 T: Field,
215 {
216 let mut col_ptrs = Vec::with_capacity(n + 1);
217 let mut row_indices = Vec::with_capacity(n);
218 let mut values = Vec::with_capacity(n);
219
220 for i in 0..n {
221 col_ptrs.push(i);
222 row_indices.push(i);
223 values.push(T::one());
224 }
225 col_ptrs.push(n);
226
227 Self {
228 nrows: n,
229 ncols: n,
230 col_ptrs,
231 row_indices,
232 values,
233 }
234 }
235
236 #[inline]
238 pub fn nrows(&self) -> usize {
239 self.nrows
240 }
241
242 #[inline]
244 pub fn ncols(&self) -> usize {
245 self.ncols
246 }
247
248 #[inline]
250 pub fn shape(&self) -> (usize, usize) {
251 (self.nrows, self.ncols)
252 }
253
254 #[inline]
256 pub fn nnz(&self) -> usize {
257 self.values.len()
258 }
259
260 #[inline]
262 pub fn density(&self) -> f64 {
263 if self.nrows == 0 || self.ncols == 0 {
264 0.0
265 } else {
266 self.nnz() as f64 / (self.nrows * self.ncols) as f64
267 }
268 }
269
270 #[inline]
272 pub fn col_ptrs(&self) -> &[usize] {
273 &self.col_ptrs
274 }
275
276 #[inline]
278 pub fn row_indices(&self) -> &[usize] {
279 &self.row_indices
280 }
281
282 #[inline]
284 pub fn values(&self) -> &[T] {
285 &self.values
286 }
287
288 #[inline]
290 pub fn values_mut(&mut self) -> &mut [T] {
291 &mut self.values
292 }
293
294 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
296 if row >= self.nrows || col >= self.ncols {
297 return None;
298 }
299
300 let start = self.col_ptrs[col];
301 let end = self.col_ptrs[col + 1];
302
303 for i in start..end {
304 if self.row_indices[i] == row {
305 return Some(&self.values[i]);
306 }
307 }
308
309 None
310 }
311
312 pub fn get_or_zero(&self, row: usize, col: usize) -> T
314 where
315 T: Field,
316 {
317 self.get(row, col).cloned().unwrap_or_else(T::zero)
318 }
319
320 pub fn col_iter(&self, col: usize) -> impl Iterator<Item = (usize, &T)> {
322 let start = self.col_ptrs[col];
323 let end = self.col_ptrs[col + 1];
324
325 self.row_indices[start..end]
326 .iter()
327 .zip(self.values[start..end].iter())
328 .map(|(&row, val)| (row, val))
329 }
330
331 pub fn iter(&self) -> impl Iterator<Item = (usize, usize, &T)> {
333 (0..self.ncols).flat_map(move |col| {
334 let start = self.col_ptrs[col];
335 let end = self.col_ptrs[col + 1];
336
337 self.row_indices[start..end]
338 .iter()
339 .zip(self.values[start..end].iter())
340 .map(move |(&row, val)| (row, col, val))
341 })
342 }
343
344 pub fn to_csr(&self) -> crate::csr::CsrMatrix<T> {
346 crate::convert::csc_to_csr(self)
347 }
348
349 pub fn to_dense(&self) -> oxiblas_matrix::Mat<T>
351 where
352 T: Field + bytemuck::Zeroable,
353 {
354 let mut dense = oxiblas_matrix::Mat::zeros(self.nrows, self.ncols);
355
356 for col in 0..self.ncols {
357 let start = self.col_ptrs[col];
358 let end = self.col_ptrs[col + 1];
359
360 for i in start..end {
361 dense[(self.row_indices[i], col)] = self.values[i].clone();
362 }
363 }
364
365 dense
366 }
367
368 pub fn from_dense(dense: &oxiblas_matrix::MatRef<'_, T>) -> Self
370 where
371 T: Field,
372 {
373 let (nrows, ncols) = dense.shape();
374 let mut col_ptrs = Vec::with_capacity(ncols + 1);
375 let mut row_indices = Vec::new();
376 let mut values = Vec::new();
377
378 let eps = <T as Scalar>::epsilon();
379
380 col_ptrs.push(0);
381
382 for j in 0..ncols {
383 for i in 0..nrows {
384 let val = dense[(i, j)].clone();
385 if Scalar::abs(val.clone()) > eps {
386 row_indices.push(i);
387 values.push(val);
388 }
389 }
390 col_ptrs.push(values.len());
391 }
392
393 Self {
394 nrows,
395 ncols,
396 col_ptrs,
397 row_indices,
398 values,
399 }
400 }
401
402 pub fn transpose(&self) -> Self {
404 let csr = self.to_csr();
407 csr.to_csc()
408 }
409
410 pub fn scale(&mut self, alpha: T) {
412 for val in &mut self.values {
413 *val = val.clone() * alpha.clone();
414 }
415 }
416
417 pub fn scaled(&self, alpha: T) -> Self {
419 let mut result = self.clone();
420 result.scale(alpha);
421 result
422 }
423
424 #[inline]
426 pub fn col_nnz(&self, col: usize) -> usize {
427 self.col_ptrs[col + 1] - self.col_ptrs[col]
428 }
429
430 pub fn is_structurally_symmetric(&self) -> bool {
434 if self.nrows != self.ncols {
435 return false;
436 }
437
438 for col in 0..self.ncols {
439 let start = self.col_ptrs[col];
440 let end = self.col_ptrs[col + 1];
441
442 for i in start..end {
443 let row = self.row_indices[i];
444 if self.get(col, row).is_none() {
445 return false;
446 }
447 }
448 }
449
450 true
451 }
452}
453
454impl<T: Scalar + Clone> Index<(usize, usize)> for CscMatrix<T>
455where
456 T: Field,
457{
458 type Output = T;
459
460 fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
461 self.get(row, col)
462 .expect("Index out of bounds or zero element")
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    // 3x3 fixture used by several tests:
    //
    //   [1 0 4]
    //   [0 3 0]
    //   [2 0 5]
    fn sample_3x3() -> CscMatrix<f64> {
        CscMatrix::new(
            3,
            3,
            vec![0, 2, 3, 5],
            vec![0, 2, 1, 0, 2],
            vec![1.0f64, 2.0, 3.0, 4.0, 5.0],
        )
        .unwrap()
    }

    // 3x3 diagonal fixture with entries 1, 2, 3.
    fn diagonal_3x3() -> CscMatrix<f64> {
        CscMatrix::new(3, 3, vec![0, 1, 2, 3], vec![0, 1, 2], vec![1.0f64, 2.0, 3.0]).unwrap()
    }

    #[test]
    fn test_csc_new() {
        let m = sample_3x3();

        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 3);
        assert_eq!(m.nnz(), 5);
    }

    #[test]
    fn test_csc_get() {
        let m = sample_3x3();

        let stored = [(0, 0, 1.0), (2, 0, 2.0), (1, 1, 3.0), (0, 2, 4.0), (2, 2, 5.0)];
        for (row, col, expected) in stored {
            assert_eq!(m.get(row, col), Some(&expected));
        }

        // Structural zeros.
        assert_eq!(m.get(1, 0), None);
        assert_eq!(m.get(0, 1), None);
    }

    #[test]
    fn test_csc_zeros() {
        let m: CscMatrix<f64> = CscMatrix::zeros(5, 3);

        assert_eq!(m.nrows(), 5);
        assert_eq!(m.ncols(), 3);
        assert_eq!(m.nnz(), 0);
    }

    #[test]
    fn test_csc_eye() {
        let m: CscMatrix<f64> = CscMatrix::eye(4);

        assert_eq!(m.nrows(), 4);
        assert_eq!(m.ncols(), 4);
        assert_eq!(m.nnz(), 4);

        (0..4).for_each(|k| assert_eq!(m.get(k, k), Some(&1.0)));
    }

    #[test]
    fn test_csc_density() {
        let m = diagonal_3x3();

        let expected = 3.0 / 9.0;
        assert!((m.density() - expected).abs() < 1e-10);
    }

    #[test]
    fn test_csc_col_iter() {
        let m = sample_3x3();

        let first: Vec<_> = m.col_iter(0).collect();
        assert_eq!(first, vec![(0, &1.0), (2, &2.0)]);

        let second: Vec<_> = m.col_iter(1).collect();
        assert_eq!(second, vec![(1, &3.0)]);
    }

    #[test]
    fn test_csc_scale() {
        let mut m = diagonal_3x3();
        m.scale(2.0);

        assert_eq!(m.values(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_csc_invalid_col_ptrs() {
        // Two columns need three pointers; only two are given.
        let result = CscMatrix::new(2, 2, vec![0, 1], vec![0, 1], vec![1.0f64, 2.0]);

        assert!(matches!(result, Err(CscError::InvalidColPtrs { .. })));
    }

    #[test]
    fn test_csc_invalid_row_index() {
        // Row 5 does not exist in a 3-row matrix.
        let result = CscMatrix::new(3, 1, vec![0, 1], vec![5], vec![1.0f64]);

        assert!(matches!(result, Err(CscError::InvalidRowIndex { .. })));
    }

    #[test]
    fn test_csc_col_nnz() {
        let m = sample_3x3();

        let counts: Vec<usize> = (0..3).map(|c| m.col_nnz(c)).collect();
        assert_eq!(counts, vec![2, 1, 2]);
    }

    #[test]
    fn test_csc_structurally_symmetric() {
        // Fully dense 2x2 pattern.
        let m = CscMatrix::new(
            2,
            2,
            vec![0, 2, 4],
            vec![0, 1, 0, 1],
            vec![1.0f64, 2.0, 2.0, 3.0],
        )
        .unwrap();

        assert!(m.is_structurally_symmetric());
    }
}