use crate::coo::CooMatrix;
use crate::coo_array::CooArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::numeric::{Float, SparseElement};
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::ops::{Add, Div, Mul, Sub};
14
15#[derive(Debug, Clone)]
25pub struct SymCooMatrix<T>
26where
27 T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone,
28{
29 pub data: Vec<T>,
31
32 pub rows: Vec<usize>,
34
35 pub cols: Vec<usize>,
37
38 pub shape: (usize, usize),
40}
41
42impl<T> SymCooMatrix<T>
43where
44 T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone,
45{
46 pub fn new(
67 data: Vec<T>,
68 rows: Vec<usize>,
69 cols: Vec<usize>,
70 shape: (usize, usize),
71 ) -> SparseResult<Self> {
72 let (nrows, ncols) = shape;
73
74 if nrows != ncols {
76 return Err(SparseError::ValueError(
77 "Symmetric matrix must be square".to_string(),
78 ));
79 }
80
81 let nnz = data.len();
83 if rows.len() != nnz || cols.len() != nnz {
84 return Err(SparseError::ValueError(format!(
85 "Data ({}), row ({}) and column ({}) arrays must have same length",
86 nnz,
87 rows.len(),
88 cols.len()
89 )));
90 }
91
92 for i in 0..nnz {
94 let row = rows[i];
95 let col = cols[i];
96
97 if row >= nrows {
98 return Err(SparseError::IndexOutOfBounds {
99 index: (row, 0),
100 shape: (nrows, ncols),
101 });
102 }
103
104 if col >= ncols {
105 return Err(SparseError::IndexOutOfBounds {
106 index: (row, col),
107 shape: (nrows, ncols),
108 });
109 }
110
111 if col > row {
113 return Err(SparseError::ValueError(
114 "Symmetric COO should only store the lower triangular part".to_string(),
115 ));
116 }
117 }
118
119 Ok(Self {
120 data,
121 rows,
122 cols,
123 shape,
124 })
125 }
126
127 pub fn from_coo(matrix: &CooMatrix<T>) -> SparseResult<Self> {
140 let (rows, cols) = matrix.shape();
141
142 if rows != cols {
144 return Err(SparseError::ValueError(
145 "Symmetric matrix must be square".to_string(),
146 ));
147 }
148
149 if !Self::is_symmetric(matrix) {
151 return Err(SparseError::ValueError(
152 "Matrix must be symmetric to convert to SymCOO format".to_string(),
153 ));
154 }
155
156 let mut data = Vec::new();
158 let mut row_indices = Vec::new();
159 let mut col_indices = Vec::new();
160
161 let rowsvec = matrix.row_indices();
162 let cols_vec = matrix.col_indices();
163 let data_vec = matrix.data();
164
165 for i in 0..data_vec.len() {
166 let row = rowsvec[i];
167 let col = cols_vec[i];
168
169 if col <= row {
171 data.push(data_vec[i]);
172 row_indices.push(row);
173 col_indices.push(col);
174 }
175 }
176
177 Ok(Self {
178 data,
179 rows: row_indices,
180 cols: col_indices,
181 shape: (rows, cols),
182 })
183 }
184
185 pub fn is_symmetric(matrix: &CooMatrix<T>) -> bool {
195 let (rows, cols) = matrix.shape();
196
197 if rows != cols {
199 return false;
200 }
201
202 let dense = matrix.to_dense();
204
205 for i in 0..rows {
206 for j in 0..i {
207 let diff = (dense[i][j] - dense[j][i]).abs();
210 let epsilon = T::epsilon() * T::from(100.0).unwrap();
211 if diff > epsilon {
212 return false;
213 }
214 }
215 }
216
217 true
218 }
219
220 pub fn shape(&self) -> (usize, usize) {
226 self.shape
227 }
228
229 pub fn nnz_stored(&self) -> usize {
235 self.data.len()
236 }
237
238 pub fn nnz(&self) -> usize {
244 let mut count = 0;
245
246 for i in 0..self.data.len() {
247 let row = self.rows[i];
248 let col = self.cols[i];
249
250 if row == col {
251 count += 1;
253 } else {
254 count += 2;
256 }
257 }
258
259 count
260 }
261
262 pub fn get(&self, row: usize, col: usize) -> T {
273 if row >= self.shape.0 || col >= self.shape.1 {
275 return T::sparse_zero();
276 }
277
278 let (actual_row, actual_col) = if row < col { (col, row) } else { (row, col) };
281
282 for i in 0..self.data.len() {
284 if self.rows[i] == actual_row && self.cols[i] == actual_col {
285 return self.data[i];
286 }
287 }
288
289 T::sparse_zero()
290 }
291
292 pub fn to_coo(&self) -> SparseResult<CooMatrix<T>> {
298 let mut data = Vec::new();
299 let mut rows = Vec::new();
300 let mut cols = Vec::new();
301
302 data.extend_from_slice(&self.data);
304 rows.extend_from_slice(&self.rows);
305 cols.extend_from_slice(&self.cols);
306
307 for i in 0..self.data.len() {
309 let row = self.rows[i];
310 let col = self.cols[i];
311
312 if row != col {
314 data.push(self.data[i]);
316 rows.push(col);
317 cols.push(row);
318 }
319 }
320
321 CooMatrix::new(data, rows, cols, self.shape)
322 }
323
324 pub fn to_dense(&self) -> Vec<Vec<T>> {
330 let n = self.shape.0;
331 let mut dense = vec![vec![T::sparse_zero(); n]; n];
332
333 for i in 0..self.data.len() {
335 let row = self.rows[i];
336 let col = self.cols[i];
337 dense[row][col] = self.data[i];
338
339 if row != col {
341 dense[col][row] = self.data[i];
342 }
343 }
344
345 dense
346 }
347}
348
349#[derive(Clone)]
351pub struct SymCooArray<T>
352where
353 T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone,
354{
355 inner: SymCooMatrix<T>,
357}
358
359impl<T> SymCooArray<T>
360where
361 T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone + Div<Output = T> + 'static,
362{
363 pub fn new(matrix: SymCooMatrix<T>) -> Self {
373 Self { inner: matrix }
374 }
375
376 pub fn from_triplets(
390 rows: &[usize],
391 cols: &[usize],
392 data: &[T],
393 shape: (usize, usize),
394 enforce_symmetric: bool,
395 ) -> SparseResult<Self> {
396 if shape.0 != shape.1 {
397 return Err(SparseError::ValueError(
398 "Symmetric matrix must be square".to_string(),
399 ));
400 }
401
402 if !enforce_symmetric {
403 let n = shape.0;
405 let mut dense = vec![vec![T::sparse_zero(); n]; n];
406 let nnz = data.len().min(rows.len().min(cols.len()));
407
408 for i in 0..nnz {
410 let row = rows[i];
411 let col = cols[i];
412
413 if row >= n || col >= n {
414 return Err(SparseError::IndexOutOfBounds {
415 index: (row, col),
416 shape,
417 });
418 }
419
420 dense[row][col] = data[i];
421 }
422
423 for i in 0..n {
425 for j in 0..i {
426 if (dense[i][j] - dense[j][i]).abs() > T::epsilon() {
427 return Err(SparseError::ValueError(
428 "Input is not symmetric. Use enforce_symmetric=true to force symmetry"
429 .to_string(),
430 ));
431 }
432 }
433 }
434
435 let mut sym_data = Vec::new();
437 let mut sym_rows = Vec::new();
438 let mut sym_cols = Vec::new();
439
440 for (i, row) in dense.iter().enumerate().take(n) {
441 for (j, &val) in row.iter().enumerate().take(i + 1) {
442 if val != T::sparse_zero() {
443 sym_data.push(val);
444 sym_rows.push(i);
445 sym_cols.push(j);
446 }
447 }
448 }
449
450 let sym_coo = SymCooMatrix::new(sym_data, sym_rows, sym_cols, shape)?;
452 return Ok(Self { inner: sym_coo });
453 }
454
455 let n = shape.0;
457
458 let mut dense = vec![vec![T::sparse_zero(); n]; n];
460 let nnz = data.len();
461
462 for i in 0..nnz {
464 if i >= rows.len() || i >= cols.len() {
465 return Err(SparseError::ValueError(
466 "Inconsistent input arrays".to_string(),
467 ));
468 }
469
470 let row = rows[i];
471 let col = cols[i];
472
473 if row >= n || col >= n {
474 return Err(SparseError::IndexOutOfBounds {
475 index: (row, col),
476 shape: (n, n),
477 });
478 }
479
480 dense[row][col] = data[i];
481 }
482
483 for i in 0..n {
485 for j in 0..i {
486 let avg = (dense[i][j] + dense[j][i]) / (T::sparse_one() + T::sparse_one());
487 dense[i][j] = avg;
488 dense[j][i] = avg;
489 }
490 }
491
492 let mut sym_data = Vec::new();
494 let mut sym_rows = Vec::new();
495 let mut sym_cols = Vec::new();
496
497 for (i, row) in dense.iter().enumerate().take(n) {
498 for (j, &val) in row.iter().enumerate().take(i + 1) {
499 if val != T::sparse_zero() {
500 sym_data.push(val);
501 sym_rows.push(i);
502 sym_cols.push(j);
503 }
504 }
505 }
506
507 let sym_coo = SymCooMatrix::new(sym_data, sym_rows, sym_cols, shape)?;
508 Ok(Self { inner: sym_coo })
509 }
510
511 pub fn from_coo_array(array: &CooArray<T>) -> SparseResult<Self> {
521 let shape = array.shape();
522 let (rows, cols) = shape;
523
524 if rows != cols {
526 return Err(SparseError::ValueError(
527 "Symmetric matrix must be square".to_string(),
528 ));
529 }
530
531 let coomatrix = CooMatrix::new(
533 array.get_data().to_vec(),
534 array.get_rows().to_vec(),
535 array.get_cols().to_vec(),
536 shape,
537 )?;
538
539 let sym_coo = SymCooMatrix::from_coo(&coomatrix)?;
541
542 Ok(Self { inner: sym_coo })
543 }
544
545 pub fn inner(&self) -> &SymCooMatrix<T> {
551 &self.inner
552 }
553
554 pub fn data(&self) -> &[T] {
560 &self.inner.data
561 }
562
563 pub fn rows(&self) -> &[usize] {
569 &self.inner.rows
570 }
571
572 pub fn cols(&self) -> &[usize] {
578 &self.inner.cols
579 }
580
581 pub fn shape(&self) -> (usize, usize) {
587 self.inner.shape
588 }
589
590 pub fn to_coo_array(&self) -> SparseResult<CooArray<T>> {
596 let coo = self.inner.to_coo()?;
597
598 let rows = coo.row_indices();
600 let cols = coo.col_indices();
601 let data = coo.data();
602
603 CooArray::from_triplets(rows, cols, data, coo.shape(), false)
605 }
606}
607
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparray::SparseArray;

    #[test]
    fn test_sym_coo_creation() {
        // Lower triangle of [[2, 1, 0], [1, 2, 3], [0, 3, 1]].
        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let rows = vec![0, 1, 1, 2, 2];
        let cols = vec![0, 0, 1, 1, 2];

        let sym = SymCooMatrix::new(data, rows, cols, (3, 3)).unwrap();

        assert_eq!(sym.shape(), (3, 3));
        assert_eq!(sym.nnz_stored(), 5);
        assert_eq!(sym.nnz(), 7);

        assert_eq!(sym.get(0, 0), 2.0);
        assert_eq!(sym.get(0, 1), 1.0);
        assert_eq!(sym.get(1, 0), 1.0);
        assert_eq!(sym.get(1, 1), 2.0);
        assert_eq!(sym.get(1, 2), 3.0);
        assert_eq!(sym.get(2, 1), 3.0);
        assert_eq!(sym.get(2, 2), 1.0);
        assert_eq!(sym.get(0, 2), 0.0);
        assert_eq!(sym.get(2, 0), 0.0);
        // Out-of-range lookups read as zero.
        assert_eq!(sym.get(3, 0), 0.0);
    }

    #[test]
    fn test_sym_coo_rejects_invalid_input() {
        // Non-square shape.
        assert!(SymCooMatrix::new(vec![1.0], vec![0], vec![0], (2, 3)).is_err());
        // Entry above the diagonal.
        assert!(SymCooMatrix::new(vec![1.0], vec![0], vec![1], (2, 2)).is_err());
        // Mismatched array lengths.
        assert!(SymCooMatrix::new(vec![1.0, 2.0], vec![0], vec![0], (2, 2)).is_err());
        // Out-of-range row index.
        assert!(SymCooMatrix::new(vec![1.0], vec![2], vec![0], (2, 2)).is_err());
    }

    #[test]
    fn test_sym_coo_to_dense_mirrors_lower_triangle() {
        let sym = SymCooMatrix::new(vec![4.0, 5.0], vec![0, 1], vec![0, 0], (2, 2)).unwrap();
        assert_eq!(sym.to_dense(), vec![vec![4.0, 5.0], vec![5.0, 0.0]]);
    }

    #[test]
    fn test_sym_coo_from_standard() {
        let data = vec![2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 1.0];
        let rows = vec![0, 0, 1, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 2, 1, 2];

        let coo = CooMatrix::new(data, rows, cols, (3, 3)).unwrap();
        let sym = SymCooMatrix::from_coo(&coo).unwrap();

        assert_eq!(sym.shape(), (3, 3));

        let coo2 = sym.to_coo().unwrap();
        let dense = coo2.to_dense();

        assert_eq!(dense[0][0], 2.0);
        assert_eq!(dense[0][1], 1.0);
        assert_eq!(dense[0][2], 0.0);
        assert_eq!(dense[1][0], 1.0);
        assert_eq!(dense[1][1], 2.0);
        assert_eq!(dense[1][2], 3.0);
        assert_eq!(dense[2][0], 0.0);
        assert_eq!(dense[2][1], 3.0);
        assert_eq!(dense[2][2], 1.0);
    }

    #[test]
    fn test_sym_coo_from_coo_rejects_asymmetric() {
        let coo = CooMatrix::new(vec![1.0, 2.0], vec![0, 1], vec![1, 0], (2, 2)).unwrap();
        assert!(SymCooMatrix::from_coo(&coo).is_err());
    }

    #[test]
    fn test_sym_coo_array() {
        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let rows = vec![0, 1, 1, 2, 2];
        let cols = vec![0, 0, 1, 1, 2];

        let symmatrix = SymCooMatrix::new(data, rows, cols, (3, 3)).unwrap();
        let sym_array = SymCooArray::new(symmatrix);

        assert_eq!(sym_array.inner().shape(), (3, 3));

        let coo_array = sym_array.to_coo_array().unwrap();

        assert_eq!(coo_array.shape(), (3, 3));
        assert_eq!(coo_array.get(0, 0), 2.0);
        assert_eq!(coo_array.get(0, 1), 1.0);
        assert_eq!(coo_array.get(1, 0), 1.0);
        assert_eq!(coo_array.get(1, 1), 2.0);
        assert_eq!(coo_array.get(1, 2), 3.0);
        assert_eq!(coo_array.get(2, 1), 3.0);
        assert_eq!(coo_array.get(2, 2), 1.0);
        assert_eq!(coo_array.get(0, 2), 0.0);
        assert_eq!(coo_array.get(2, 0), 0.0);
    }

    #[test]
    fn test_sym_coo_array_from_triplets() {
        // Full symmetric matrix, entries in arbitrary order.
        let rows = vec![0, 1, 1, 2, 1, 0, 2];
        let cols = vec![0, 1, 2, 2, 0, 1, 1];
        let data = vec![2.0, 2.0, 3.0, 1.0, 1.0, 1.0, 3.0];

        let sym_array = SymCooArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        assert_eq!(sym_array.shape(), (3, 3));

        // Asymmetric input: (0, 1) = 1 and (1, 0) = 2 are averaged to 1.5.
        let rows2 = vec![0, 0, 1, 1, 2, 1];
        let cols2 = vec![0, 1, 1, 2, 2, 0];
        let data2 = vec![2.0, 1.0, 2.0, 3.0, 1.0, 2.0];
        let sym_array2 =
            SymCooArray::from_triplets(&rows2, &cols2, &data2, (3, 3), true).unwrap();

        assert_eq!(sym_array2.inner().get(1, 0), 1.5);
        assert_eq!(sym_array2.inner().get(0, 1), 1.5);
    }

    #[test]
    fn test_sym_coo_array_from_triplets_averages_upper_only_entry() {
        // Only (0, 1) is given; its missing mirror counts as zero.
        let sym = SymCooArray::from_triplets(&[0], &[1], &[4.0], (2, 2), true).unwrap();
        assert_eq!(sym.inner().get(1, 0), 2.0);
        assert_eq!(sym.inner().get(0, 1), 2.0);
        assert_eq!(sym.data().len(), 1);
    }

    #[test]
    fn test_sym_coo_array_from_triplets_rejects_bad_input() {
        // Asymmetric input without enforcement.
        let asym = SymCooArray::from_triplets(&[0, 1], &[1, 0], &[1.0, 2.0], (2, 2), false);
        assert!(asym.is_err());
        // Non-square shape.
        let non_square = SymCooArray::from_triplets(&[0], &[0], &[1.0], (2, 3), true);
        assert!(non_square.is_err());
        // Index outside the shape.
        let oob = SymCooArray::from_triplets(&[2], &[0], &[1.0], (2, 2), true);
        assert!(oob.is_err());
    }
}