1use crate::csr_array::CsrArray;
17use crate::error::{SparseError, SparseResult};
18use crate::sparray::SparseArray;
19use scirs2_core::numeric::{Float, SparseElement};
20use std::fmt::Debug;
21use std::ops::Div;
22
23pub fn sparse_eye<T>(n: usize) -> SparseResult<CsrArray<T>>
40where
41 T: Float + SparseElement + Div<Output = T> + 'static,
42{
43 if n == 0 {
44 return Err(SparseError::ValueError(
45 "Matrix dimension must be positive".to_string(),
46 ));
47 }
48
49 let rows: Vec<usize> = (0..n).collect();
50 let cols: Vec<usize> = (0..n).collect();
51 let data: Vec<T> = vec![T::sparse_one(); n];
52
53 CsrArray::from_triplets(&rows, &cols, &data, (n, n), true)
54}
55
56pub fn sparse_eye_rect<T>(m: usize, n: usize) -> SparseResult<CsrArray<T>>
64where
65 T: Float + SparseElement + Div<Output = T> + 'static,
66{
67 if m == 0 || n == 0 {
68 return Err(SparseError::ValueError(
69 "Matrix dimensions must be positive".to_string(),
70 ));
71 }
72
73 let diag_len = m.min(n);
74 let rows: Vec<usize> = (0..diag_len).collect();
75 let cols: Vec<usize> = (0..diag_len).collect();
76 let data: Vec<T> = vec![T::sparse_one(); diag_len];
77
78 CsrArray::from_triplets(&rows, &cols, &data, (m, n), true)
79}
80
81pub fn sparse_random(
92 m: usize,
93 n: usize,
94 density: f64,
95 seed: Option<u64>,
96) -> SparseResult<CsrArray<f64>> {
97 if m == 0 || n == 0 {
98 return Err(SparseError::ValueError(
99 "Matrix dimensions must be positive".to_string(),
100 ));
101 }
102 if !(0.0..=1.0).contains(&density) {
103 return Err(SparseError::ValueError(
104 "Density must be between 0.0 and 1.0".to_string(),
105 ));
106 }
107
108 let total_elements = m * n;
109 let nnz_target = (density * total_elements as f64).round() as usize;
110
111 if nnz_target == 0 {
112 let rows: Vec<usize> = Vec::new();
114 let cols: Vec<usize> = Vec::new();
115 let data: Vec<f64> = Vec::new();
116 return CsrArray::from_triplets(&rows, &cols, &data, (m, n), false);
117 }
118
119 use scirs2_core::random::{Rng, SeedableRng};
120 let mut rng = match seed {
121 Some(s) => scirs2_core::random::StdRng::seed_from_u64(s),
122 None => scirs2_core::random::StdRng::seed_from_u64(42),
123 };
124
125 let mut positions: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
127
128 if density < 0.5 {
131 while positions.len() < nnz_target {
132 let r = rng.random_range(0..m);
133 let c = rng.random_range(0..n);
134 positions.insert((r, c));
135 }
136 } else {
137 let mut all_positions: Vec<(usize, usize)> = Vec::with_capacity(total_elements);
139 for r in 0..m {
140 for c in 0..n {
141 all_positions.push((r, c));
142 }
143 }
144 for i in 0..nnz_target.min(all_positions.len()) {
146 let j = rng.random_range(i..all_positions.len());
147 all_positions.swap(i, j);
148 positions.insert(all_positions[i]);
149 }
150 }
151
152 let mut rows: Vec<usize> = Vec::with_capacity(nnz_target);
153 let mut cols: Vec<usize> = Vec::with_capacity(nnz_target);
154 let mut data: Vec<f64> = Vec::with_capacity(nnz_target);
155
156 for (r, c) in positions {
157 rows.push(r);
158 cols.push(c);
159 data.push(rng.random::<f64>());
160 }
161
162 CsrArray::from_triplets(&rows, &cols, &data, (m, n), false)
163}
164
165pub fn sparse_kron<T>(a: &CsrArray<T>, b: &CsrArray<T>) -> SparseResult<CsrArray<T>>
184where
185 T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
186{
187 let (m, n) = a.shape();
188 let (p, q) = b.shape();
189 let result_rows = m * p;
190 let result_cols = n * q;
191
192 let (a_rows, a_cols, a_vals) = a.find();
193 let (b_rows, b_cols, b_vals) = b.find();
194
195 let a_nnz = a_vals.len();
196 let b_nnz = b_vals.len();
197
198 let mut rows = Vec::with_capacity(a_nnz * b_nnz);
199 let mut cols = Vec::with_capacity(a_nnz * b_nnz);
200 let mut data = Vec::with_capacity(a_nnz * b_nnz);
201
202 for i in 0..a_nnz {
203 let ar = a_rows[i];
204 let ac = a_cols[i];
205 let av = a_vals[i];
206
207 for j in 0..b_nnz {
208 let br = b_rows[j];
209 let bc = b_cols[j];
210 let bv = b_vals[j];
211
212 rows.push(ar * p + br);
213 cols.push(ac * q + bc);
214 data.push(av * bv);
215 }
216 }
217
218 CsrArray::from_triplets(&rows, &cols, &data, (result_rows, result_cols), false)
219}
220
221pub fn sparse_hstack<T>(arrays: &[&CsrArray<T>]) -> SparseResult<CsrArray<T>>
228where
229 T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
230{
231 if arrays.is_empty() {
232 return Err(SparseError::ValueError(
233 "Cannot stack empty list of arrays".to_string(),
234 ));
235 }
236
237 let m = arrays[0].shape().0;
238 for (idx, &arr) in arrays.iter().enumerate().skip(1) {
239 if arr.shape().0 != m {
240 return Err(SparseError::DimensionMismatch {
241 expected: m,
242 found: arr.shape().0,
243 });
244 }
245 }
246
247 let total_cols: usize = arrays.iter().map(|a| a.shape().1).sum();
248
249 let mut rows = Vec::new();
250 let mut cols = Vec::new();
251 let mut data = Vec::new();
252
253 let mut col_offset = 0usize;
254 for &arr in arrays {
255 let (ar, ac, av) = arr.find();
256 for i in 0..av.len() {
257 rows.push(ar[i]);
258 cols.push(ac[i] + col_offset);
259 data.push(av[i]);
260 }
261 col_offset += arr.shape().1;
262 }
263
264 CsrArray::from_triplets(&rows, &cols, &data, (m, total_cols), false)
265}
266
267pub fn sparse_vstack<T>(arrays: &[&CsrArray<T>]) -> SparseResult<CsrArray<T>>
274where
275 T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
276{
277 if arrays.is_empty() {
278 return Err(SparseError::ValueError(
279 "Cannot stack empty list of arrays".to_string(),
280 ));
281 }
282
283 let n = arrays[0].shape().1;
284 for (idx, &arr) in arrays.iter().enumerate().skip(1) {
285 if arr.shape().1 != n {
286 return Err(SparseError::DimensionMismatch {
287 expected: n,
288 found: arr.shape().1,
289 });
290 }
291 }
292
293 let total_rows: usize = arrays.iter().map(|a| a.shape().0).sum();
294
295 let mut rows = Vec::new();
296 let mut cols = Vec::new();
297 let mut data = Vec::new();
298
299 let mut row_offset = 0usize;
300 for &arr in arrays {
301 let (ar, ac, av) = arr.find();
302 for i in 0..av.len() {
303 rows.push(ar[i] + row_offset);
304 cols.push(ac[i]);
305 data.push(av[i]);
306 }
307 row_offset += arr.shape().0;
308 }
309
310 CsrArray::from_triplets(&rows, &cols, &data, (total_rows, n), false)
311}
312
313pub fn sparse_block_diag<T>(arrays: &[&CsrArray<T>]) -> SparseResult<CsrArray<T>>
333where
334 T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
335{
336 if arrays.is_empty() {
337 return Err(SparseError::ValueError(
338 "Cannot create block diagonal from empty list".to_string(),
339 ));
340 }
341
342 let total_rows: usize = arrays.iter().map(|a| a.shape().0).sum();
343 let total_cols: usize = arrays.iter().map(|a| a.shape().1).sum();
344
345 let mut rows = Vec::new();
346 let mut cols = Vec::new();
347 let mut data = Vec::new();
348
349 let mut row_offset = 0usize;
350 let mut col_offset = 0usize;
351
352 for &arr in arrays {
353 let (ar, ac, av) = arr.find();
354 for i in 0..av.len() {
355 rows.push(ar[i] + row_offset);
356 cols.push(ac[i] + col_offset);
357 data.push(av[i]);
358 }
359 row_offset += arr.shape().0;
360 col_offset += arr.shape().1;
361 }
362
363 CsrArray::from_triplets(&rows, &cols, &data, (total_rows, total_cols), false)
364}
365
366pub fn sparse_diags<T>(
388 diags: &[&[T]],
389 offsets: &[isize],
390 shape: (usize, usize),
391) -> SparseResult<CsrArray<T>>
392where
393 T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
394{
395 if diags.len() != offsets.len() {
396 return Err(SparseError::DimensionMismatch {
397 expected: offsets.len(),
398 found: diags.len(),
399 });
400 }
401
402 let (nrows, ncols) = shape;
403 if nrows == 0 || ncols == 0 {
404 return Err(SparseError::ValueError(
405 "Matrix dimensions must be positive".to_string(),
406 ));
407 }
408
409 let mut rows = Vec::new();
410 let mut cols = Vec::new();
411 let mut data = Vec::new();
412
413 for (d, &offset) in offsets.iter().enumerate() {
414 let diag = diags[d];
415 if offset >= 0 {
416 let off = offset as usize;
417 let diag_len = nrows.min(ncols.saturating_sub(off));
418 if diag.len() < diag_len {
419 return Err(SparseError::DimensionMismatch {
420 expected: diag_len,
421 found: diag.len(),
422 });
423 }
424 for k in 0..diag_len {
425 let v = diag[k];
426 if !SparseElement::is_zero(&v) {
427 rows.push(k);
428 cols.push(k + off);
429 data.push(v);
430 }
431 }
432 } else {
433 let off = (-offset) as usize;
434 let diag_len = ncols.min(nrows.saturating_sub(off));
435 if diag.len() < diag_len {
436 return Err(SparseError::DimensionMismatch {
437 expected: diag_len,
438 found: diag.len(),
439 });
440 }
441 for k in 0..diag_len {
442 let v = diag[k];
443 if !SparseElement::is_zero(&v) {
444 rows.push(k + off);
445 cols.push(k);
446 data.push(v);
447 }
448 }
449 }
450 }
451
452 CsrArray::from_triplets(&rows, &cols, &data, shape, false)
453}
454
455pub fn sparse_diag_matrix<T>(
462 diag: &[T],
463 offset: isize,
464 shape: Option<(usize, usize)>,
465) -> SparseResult<CsrArray<T>>
466where
467 T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
468{
469 let n = diag.len();
470 let (nrows, ncols) = shape.unwrap_or_else(|| {
471 if offset >= 0 {
472 (n, n + offset as usize)
473 } else {
474 (n + (-offset) as usize, n)
475 }
476 });
477
478 sparse_diags(&[diag], &[offset], (nrows, ncols))
479}
480
481pub fn sparse_kronsum<T>(a: &CsrArray<T>, b: &CsrArray<T>) -> SparseResult<CsrArray<T>>
489where
490 T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
491{
492 let (p, pa) = a.shape();
493 let (q, qb) = b.shape();
494
495 if p != pa {
496 return Err(SparseError::ValueError(
497 "First matrix must be square for Kronecker sum".to_string(),
498 ));
499 }
500 if q != qb {
501 return Err(SparseError::ValueError(
502 "Second matrix must be square for Kronecker sum".to_string(),
503 ));
504 }
505
506 let iq = sparse_eye::<T>(q)?;
507 let ip = sparse_eye::<T>(p)?;
508
509 let a_kron_iq = sparse_kron(a, &iq)?;
510 let ip_kron_b = sparse_kron(&ip, b)?;
511
512 let result = a_kron_iq.add(&ip_kron_b)?;
514
515 let (rr, rc, rv) = result.find();
517 let rows_vec: Vec<usize> = rr.to_vec();
518 let cols_vec: Vec<usize> = rc.to_vec();
519 let vals_vec: Vec<T> = rv.to_vec();
520 let shape = result.shape();
521
522 CsrArray::from_triplets(&rows_vec, &cols_vec, &vals_vec, shape, false)
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_sparse_eye() {
        let eye = sparse_eye::<f64>(4).expect("eye");
        assert_eq!(eye.shape(), (4, 4));
        assert_eq!(eye.nnz(), 4);
        for i in 0..4 {
            assert_relative_eq!(eye.get(i, i), 1.0);
            if i > 0 {
                assert_relative_eq!(eye.get(i, i - 1), 0.0);
            }
        }
    }

    #[test]
    fn test_sparse_eye_rect() {
        let eye = sparse_eye_rect::<f64>(3, 5).expect("eye_rect");
        assert_eq!(eye.shape(), (3, 5));
        assert_eq!(eye.nnz(), 3);
        for i in 0..3 {
            assert_relative_eq!(eye.get(i, i), 1.0);
        }
        assert_relative_eq!(eye.get(0, 3), 0.0);
    }

    #[test]
    fn test_sparse_eye_zero_dimension_is_error() {
        assert!(sparse_eye::<f64>(0).is_err());
        assert!(sparse_eye_rect::<f64>(0, 3).is_err());
        assert!(sparse_eye_rect::<f64>(3, 0).is_err());
    }

    #[test]
    fn test_sparse_random() {
        let mat = sparse_random(10, 10, 0.3, Some(42)).expect("random");
        assert_eq!(mat.shape(), (10, 10));
        // Exactly round(density * m * n) = round(0.3 * 100) distinct positions are stored.
        assert_eq!(mat.nnz(), 30);
    }

    #[test]
    fn test_sparse_random_empty() {
        let mat = sparse_random(5, 5, 0.0, Some(1)).expect("random empty");
        assert_eq!(mat.nnz(), 0);
    }

    #[test]
    fn test_sparse_random_full() {
        let mat = sparse_random(3, 3, 1.0, Some(1)).expect("random full");
        assert_eq!(mat.shape(), (3, 3));
        assert_eq!(mat.nnz(), 9);
    }

    #[test]
    fn test_sparse_random_invalid_arguments() {
        assert!(sparse_random(0, 3, 0.5, Some(1)).is_err());
        assert!(sparse_random(3, 3, 1.5, Some(1)).is_err());
        assert!(sparse_random(3, 3, -0.1, Some(1)).is_err());
        assert!(sparse_random(3, 3, f64::NAN, Some(1)).is_err());
    }

    #[test]
    fn test_sparse_kron_identity() {
        let i2 = sparse_eye::<f64>(2).expect("eye");
        let result = sparse_kron(&i2, &i2).expect("kron");
        assert_eq!(result.shape(), (4, 4));
        assert_eq!(result.nnz(), 4);

        // I2 ⊗ I2 == I4.
        for i in 0..4 {
            assert_relative_eq!(result.get(i, i), 1.0);
            for j in 0..4 {
                if i != j {
                    assert_relative_eq!(result.get(i, j), 0.0);
                }
            }
        }
    }

    #[test]
    fn test_sparse_kron_general() {
        // a = [[1, 2], [3, 4]]
        let a = CsrArray::from_triplets(
            &[0, 0, 1, 1],
            &[0, 1, 0, 1],
            &[1.0, 2.0, 3.0, 4.0],
            (2, 2),
            false,
        )
        .expect("a");

        // b = [[0, 5], [6, 7]]
        let b = CsrArray::from_triplets(&[0, 1, 1], &[1, 0, 1], &[5.0, 6.0, 7.0], (2, 2), false)
            .expect("b");

        let result = sparse_kron(&a, &b).expect("kron");
        assert_eq!(result.shape(), (4, 4));

        assert_relative_eq!(result.get(0, 0), 0.0);
        assert_relative_eq!(result.get(0, 1), 5.0);
        assert_relative_eq!(result.get(0, 2), 0.0);
        assert_relative_eq!(result.get(0, 3), 10.0);
        assert_relative_eq!(result.get(1, 0), 6.0);
        assert_relative_eq!(result.get(3, 3), 28.0);
    }

    #[test]
    fn test_sparse_hstack() {
        let a =
            CsrArray::from_triplets(&[0, 1], &[0, 1], &[1.0f64, 2.0], (2, 2), false).expect("a");

        let b =
            CsrArray::from_triplets(&[0, 1], &[0, 0], &[3.0f64, 4.0], (2, 1), false).expect("b");

        let result = sparse_hstack(&[&a, &b]).expect("hstack");
        assert_eq!(result.shape(), (2, 3));
        assert_relative_eq!(result.get(0, 0), 1.0);
        assert_relative_eq!(result.get(1, 1), 2.0);
        assert_relative_eq!(result.get(0, 2), 3.0);
        assert_relative_eq!(result.get(1, 2), 4.0);
    }

    #[test]
    fn test_sparse_vstack() {
        let a =
            CsrArray::from_triplets(&[0, 0], &[0, 1], &[1.0f64, 2.0], (1, 3), false).expect("a");

        let b =
            CsrArray::from_triplets(&[0, 1], &[1, 2], &[3.0f64, 4.0], (2, 3), false).expect("b");

        let result = sparse_vstack(&[&a, &b]).expect("vstack");
        assert_eq!(result.shape(), (3, 3));
        assert_relative_eq!(result.get(0, 0), 1.0);
        assert_relative_eq!(result.get(0, 1), 2.0);
        assert_relative_eq!(result.get(1, 1), 3.0);
        assert_relative_eq!(result.get(2, 2), 4.0);
    }

    #[test]
    fn test_sparse_stack_errors() {
        let a = sparse_eye::<f64>(2).expect("a");
        let b = sparse_eye::<f64>(3).expect("b");
        // Row / column counts disagree.
        assert!(sparse_hstack(&[&a, &b]).is_err());
        assert!(sparse_vstack(&[&a, &b]).is_err());
        // Empty input lists.
        assert!(sparse_hstack::<f64>(&[]).is_err());
        assert!(sparse_vstack::<f64>(&[]).is_err());
        assert!(sparse_block_diag::<f64>(&[]).is_err());
    }

    #[test]
    fn test_sparse_block_diag() {
        let a = sparse_eye::<f64>(2).expect("eye");
        let b = CsrArray::from_triplets(
            &[0, 0, 1, 1, 2, 2],
            &[0, 1, 0, 1, 0, 1],
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            (3, 2),
            false,
        )
        .expect("b");

        let result = sparse_block_diag(&[&a, &b]).expect("block_diag");
        assert_eq!(result.shape(), (5, 4));

        // First block: 2x2 identity in the top-left corner.
        assert_relative_eq!(result.get(0, 0), 1.0);
        assert_relative_eq!(result.get(1, 1), 1.0);
        assert_relative_eq!(result.get(0, 1), 0.0);
        assert_relative_eq!(result.get(1, 0), 0.0);

        // Second block: 3x2 matrix starting at (2, 2).
        assert_relative_eq!(result.get(2, 2), 1.0);
        assert_relative_eq!(result.get(2, 3), 2.0);
        assert_relative_eq!(result.get(4, 3), 6.0);

        // Off-diagonal blocks are zero.
        assert_relative_eq!(result.get(0, 2), 0.0);
        assert_relative_eq!(result.get(2, 0), 0.0);
    }

    #[test]
    fn test_sparse_diags() {
        let main = [2.0f64, 2.0, 2.0];
        let upper = [-1.0f64, -1.0];
        let lower = [-1.0f64, -1.0];

        let a =
            sparse_diags(&[&lower[..], &main[..], &upper[..]], &[-1, 0, 1], (3, 3)).expect("diags");

        assert_eq!(a.shape(), (3, 3));
        assert_relative_eq!(a.get(0, 0), 2.0);
        assert_relative_eq!(a.get(0, 1), -1.0);
        assert_relative_eq!(a.get(1, 0), -1.0);
        assert_relative_eq!(a.get(1, 1), 2.0);
        assert_relative_eq!(a.get(1, 2), -1.0);
        assert_relative_eq!(a.get(2, 1), -1.0);
        assert_relative_eq!(a.get(2, 2), 2.0);
        assert_relative_eq!(a.get(0, 2), 0.0);
    }

    #[test]
    fn test_sparse_diags_errors() {
        let main = [1.0f64, 2.0];
        // Number of diagonals and offsets differ.
        assert!(sparse_diags(&[&main[..]], &[0, 1], (2, 2)).is_err());
        // Diagonal slice shorter than the diagonal it must fill.
        assert!(sparse_diags(&[&main[..]], &[0], (3, 3)).is_err());
        // Zero-sized shape.
        assert!(sparse_diags(&[&main[..]], &[0], (0, 2)).is_err());
    }

    #[test]
    fn test_sparse_diag_matrix() {
        let diag = vec![3.0f64, 5.0, 7.0];
        let m = sparse_diag_matrix(&diag, 0, None).expect("diag_matrix");
        assert_eq!(m.shape(), (3, 3));
        assert_relative_eq!(m.get(0, 0), 3.0);
        assert_relative_eq!(m.get(1, 1), 5.0);
        assert_relative_eq!(m.get(2, 2), 7.0);

        // Super-diagonal: the inferred shape grows by one column.
        let sd = vec![1.0f64, 2.0];
        let m2 = sparse_diag_matrix(&sd, 1, None).expect("super_diag");
        assert_eq!(m2.shape(), (2, 3));
        assert_relative_eq!(m2.get(0, 1), 1.0);
        assert_relative_eq!(m2.get(1, 2), 2.0);
    }

    #[test]
    fn test_sparse_kronsum() {
        // a = diag(1, 2), b = diag(3, 4)
        let a =
            CsrArray::from_triplets(&[0, 1], &[0, 1], &[1.0f64, 2.0], (2, 2), false).expect("a");

        let b =
            CsrArray::from_triplets(&[0, 1], &[0, 1], &[3.0f64, 4.0], (2, 2), false).expect("b");

        let result = sparse_kronsum(&a, &b).expect("kronsum");
        assert_eq!(result.shape(), (4, 4));

        // a ⊗ I2 = diag(1, 1, 2, 2); I2 ⊗ b = diag(3, 4, 3, 4); sum = diag(4, 5, 5, 6).
        assert_relative_eq!(result.get(0, 0), 4.0);
        assert_relative_eq!(result.get(1, 1), 5.0);
        assert_relative_eq!(result.get(2, 2), 5.0);
        assert_relative_eq!(result.get(3, 3), 6.0);
    }

    #[test]
    fn test_sparse_kronsum_requires_square() {
        let rect = sparse_eye_rect::<f64>(2, 3).expect("rect");
        let square = sparse_eye::<f64>(2).expect("square");
        assert!(sparse_kronsum(&rect, &square).is_err());
        assert!(sparse_kronsum(&square, &rect).is_err());
    }
}