1use crate::coo::CooMatrix;
27use crate::csr::CsrMatrix;
28use crate::ell::EllMatrix;
29use oxiblas_core::scalar::{Field, Scalar};
30
/// Errors reported when assembling a [`HybMatrix`] from raw parts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HybError {
    /// The requested shape or the lengths of the supplied component vectors
    /// are inconsistent. Carries the requested `nrows` / `ncols` for context.
    InvalidDimensions { nrows: usize, ncols: usize },
    /// An ELL width of zero was rejected (not produced by `HybMatrix::new`).
    ZeroEllWidth,
    /// The ELL and COO parts describe different matrix shapes
    /// (not produced by `HybMatrix::new`).
    IncompatibleParts {
        ell_shape: (usize, usize),
        coo_shape: (usize, usize),
    },
}
51
52impl core::fmt::Display for HybError {
53 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
54 match self {
55 Self::InvalidDimensions { nrows, ncols } => {
56 write!(f, "Invalid dimensions: {nrows}×{ncols}")
57 }
58 Self::ZeroEllWidth => {
59 write!(f, "ELL width cannot be zero")
60 }
61 Self::IncompatibleParts {
62 ell_shape,
63 coo_shape,
64 } => {
65 write!(
66 f,
67 "ELL shape {:?} incompatible with COO shape {:?}",
68 ell_shape, coo_shape
69 )
70 }
71 }
72 }
73}
74
75impl std::error::Error for HybError {}
76
/// Rule used when converting to HYB to pick the ELL width (slots per row)
/// from the number of stored entries in each row. Entries beyond the width
/// are stored in the COO part.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum HybWidthStrategy {
    /// Exactly this many slots per row (`HybMatrix::from_csr` raises 0 to 1).
    Fixed(usize),
    /// Mean row length, rounded up.
    Mean,
    /// `ceil(mean + k * stddev)` of the row lengths, where `k` is the payload
    /// and `stddev` is the population standard deviation.
    MeanPlusStddev(f64),
    /// Median row length; for an even row count the midpoint of the two
    /// central values is rounded up.
    Median,
    /// The sorted row length at fraction `p` (clamped to `[0, 1]`) of the way
    /// from the shortest to the longest row.
    Percentile(f64),
    /// Longest row length, so every entry fits in ELL.
    Max,
}
93
94impl Default for HybWidthStrategy {
95 fn default() -> Self {
96 Self::MeanPlusStddev(1.0)
98 }
99}
100
/// Sparse matrix in hybrid (HYB) format.
///
/// The matrix is split into two parts:
/// - an ELLPACK part with exactly `ell_width` slots per row, stored row-major;
/// - a COO part holding the entries that do not fit into their row's ELL slots.
///
/// ELL slots whose value is within epsilon of zero are treated as padding.
#[derive(Debug, Clone)]
pub struct HybMatrix<T: Scalar> {
    // Number of rows.
    nrows: usize,
    // Number of columns.
    ncols: usize,
    // ELL slots per row; `ell_data` / `ell_indices` hold `nrows * ell_width` items.
    ell_width: usize,
    // ELL values, row-major: slot `k` of row `i` is at `i * ell_width + k`.
    ell_data: Vec<T>,
    // Column index of each ELL slot (padding slots created in this file use 0).
    ell_indices: Vec<usize>,
    // COO overflow: row index of each entry (parallel to `coo_cols` / `coo_data`).
    coo_rows: Vec<usize>,
    // COO overflow: column index of each entry.
    coo_cols: Vec<usize>,
    // COO overflow: value of each entry.
    coo_data: Vec<T>,
}
144
145impl<T: Scalar + Clone> HybMatrix<T> {
146 #[allow(clippy::too_many_arguments)]
163 pub fn new(
164 nrows: usize,
165 ncols: usize,
166 ell_width: usize,
167 ell_data: Vec<T>,
168 ell_indices: Vec<usize>,
169 coo_rows: Vec<usize>,
170 coo_cols: Vec<usize>,
171 coo_data: Vec<T>,
172 ) -> Result<Self, HybError> {
173 if nrows == 0 || ncols == 0 {
174 return Err(HybError::InvalidDimensions { nrows, ncols });
175 }
176
177 let expected_ell_size = nrows * ell_width;
178 if ell_data.len() != expected_ell_size || ell_indices.len() != expected_ell_size {
179 return Err(HybError::InvalidDimensions { nrows, ncols });
180 }
181
182 if coo_rows.len() != coo_cols.len() || coo_rows.len() != coo_data.len() {
183 return Err(HybError::InvalidDimensions { nrows, ncols });
184 }
185
186 Ok(Self {
187 nrows,
188 ncols,
189 ell_width,
190 ell_data,
191 ell_indices,
192 coo_rows,
193 coo_cols,
194 coo_data,
195 })
196 }
197
198 pub fn zeros(nrows: usize, ncols: usize, ell_width: usize) -> Self
200 where
201 T: Field,
202 {
203 let size = nrows * ell_width;
204 Self {
205 nrows,
206 ncols,
207 ell_width,
208 ell_data: vec![T::zero(); size],
209 ell_indices: vec![0; size],
210 coo_rows: Vec::new(),
211 coo_cols: Vec::new(),
212 coo_data: Vec::new(),
213 }
214 }
215
216 pub fn eye(n: usize) -> Self
218 where
219 T: Field,
220 {
221 let ell_width = 1;
222 let mut ell_data = Vec::with_capacity(n);
223 let mut ell_indices = Vec::with_capacity(n);
224
225 for i in 0..n {
226 ell_data.push(T::one());
227 ell_indices.push(i);
228 }
229
230 Self {
231 nrows: n,
232 ncols: n,
233 ell_width,
234 ell_data,
235 ell_indices,
236 coo_rows: Vec::new(),
237 coo_cols: Vec::new(),
238 coo_data: Vec::new(),
239 }
240 }
241
242 #[inline]
244 pub fn nrows(&self) -> usize {
245 self.nrows
246 }
247
248 #[inline]
250 pub fn ncols(&self) -> usize {
251 self.ncols
252 }
253
254 #[inline]
256 pub fn shape(&self) -> (usize, usize) {
257 (self.nrows, self.ncols)
258 }
259
260 #[inline]
262 pub fn ell_width(&self) -> usize {
263 self.ell_width
264 }
265
266 pub fn ell_nnz(&self) -> usize
268 where
269 T: Field,
270 {
271 let eps = <T as Scalar>::epsilon();
272 self.ell_data
273 .iter()
274 .filter(|v| Scalar::abs((*v).clone()) > eps)
275 .count()
276 }
277
278 #[inline]
280 pub fn coo_nnz(&self) -> usize {
281 self.coo_data.len()
282 }
283
284 pub fn nnz(&self) -> usize
286 where
287 T: Field,
288 {
289 self.ell_nnz() + self.coo_nnz()
290 }
291
292 #[inline]
294 pub fn ell_data(&self) -> &[T] {
295 &self.ell_data
296 }
297
298 #[inline]
300 pub fn ell_indices(&self) -> &[usize] {
301 &self.ell_indices
302 }
303
304 #[inline]
306 pub fn coo_rows(&self) -> &[usize] {
307 &self.coo_rows
308 }
309
310 #[inline]
312 pub fn coo_cols(&self) -> &[usize] {
313 &self.coo_cols
314 }
315
316 #[inline]
318 pub fn coo_data(&self) -> &[T] {
319 &self.coo_data
320 }
321
322 pub fn ell_fraction(&self) -> f64
324 where
325 T: Field,
326 {
327 let total = self.nnz();
328 if total == 0 {
329 return 1.0;
330 }
331 self.ell_nnz() as f64 / total as f64
332 }
333
334 pub fn storage_efficiency(&self) -> f64
336 where
337 T: Field,
338 {
339 let nnz = self.nnz();
340 if nnz == 0 {
341 return 0.0;
342 }
343 let stored = self.nrows * self.ell_width + self.coo_data.len();
344 nnz as f64 / stored as f64
345 }
346
347 pub fn get(&self, row: usize, col: usize) -> Option<T>
349 where
350 T: Field,
351 {
352 if row >= self.nrows || col >= self.ncols {
353 return None;
354 }
355
356 let eps = <T as Scalar>::epsilon();
357
358 let ell_start = row * self.ell_width;
360 for k in 0..self.ell_width {
361 let idx = ell_start + k;
362 if self.ell_indices[idx] == col {
363 let val = self.ell_data[idx].clone();
364 if Scalar::abs(val.clone()) > eps {
365 return Some(val);
366 }
367 }
368 }
369
370 for (i, (&r, &c)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
372 if r == row && c == col {
373 return Some(self.coo_data[i].clone());
374 }
375 }
376
377 None
378 }
379
380 pub fn get_or_zero(&self, row: usize, col: usize) -> T
382 where
383 T: Field,
384 {
385 self.get(row, col).unwrap_or_else(T::zero)
386 }
387
388 pub fn matvec(&self, x: &[T], y: &mut [T])
390 where
391 T: Field,
392 {
393 assert_eq!(x.len(), self.ncols, "x length must equal ncols");
394 assert_eq!(y.len(), self.nrows, "y length must equal nrows");
395
396 let eps = <T as Scalar>::epsilon();
397
398 for yi in y.iter_mut() {
400 *yi = T::zero();
401 }
402
403 for row in 0..self.nrows {
405 let ell_start = row * self.ell_width;
406 for k in 0..self.ell_width {
407 let idx = ell_start + k;
408 let val = &self.ell_data[idx];
409 if Scalar::abs(val.clone()) > eps {
410 let col = self.ell_indices[idx];
411 y[row] = y[row].clone() + val.clone() * x[col].clone();
412 }
413 }
414 }
415
416 for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
418 y[row] = y[row].clone() + self.coo_data[i].clone() * x[col].clone();
419 }
420 }
421
422 pub fn mul_vec(&self, x: &[T]) -> Vec<T>
424 where
425 T: Field,
426 {
427 let mut y = vec![T::zero(); self.nrows];
428 self.matvec(x, &mut y);
429 y
430 }
431
432 pub fn matvec_transpose(&self, x: &[T], y: &mut [T])
434 where
435 T: Field,
436 {
437 assert_eq!(x.len(), self.nrows, "x length must equal nrows");
438 assert_eq!(y.len(), self.ncols, "y length must equal ncols");
439
440 let eps = <T as Scalar>::epsilon();
441
442 for yi in y.iter_mut() {
444 *yi = T::zero();
445 }
446
447 for row in 0..self.nrows {
449 let ell_start = row * self.ell_width;
450 for k in 0..self.ell_width {
451 let idx = ell_start + k;
452 let val = &self.ell_data[idx];
453 if Scalar::abs(val.clone()) > eps {
454 let col = self.ell_indices[idx];
455 y[col] = y[col].clone() + val.clone() * x[row].clone();
456 }
457 }
458 }
459
460 for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
462 y[col] = y[col].clone() + self.coo_data[i].clone() * x[row].clone();
463 }
464 }
465
466 pub fn from_csr(csr: &CsrMatrix<T>, strategy: HybWidthStrategy) -> Self
468 where
469 T: Field,
470 {
471 let (nrows, ncols) = csr.shape();
472 let eps = <T as Scalar>::epsilon();
473
474 let mut row_lengths: Vec<usize> = Vec::with_capacity(nrows);
476 for row in 0..nrows {
477 let mut count = 0;
478 for (_, val) in csr.row_iter(row) {
479 if Scalar::abs(val.clone()) > eps {
480 count += 1;
481 }
482 }
483 row_lengths.push(count);
484 }
485
486 let ell_width = compute_ell_width(&row_lengths, strategy);
488 let ell_width = ell_width.max(1); let mut ell_data = vec![T::zero(); nrows * ell_width];
492 let mut ell_indices = vec![0usize; nrows * ell_width];
493 let mut coo_rows = Vec::new();
494 let mut coo_cols = Vec::new();
495 let mut coo_data = Vec::new();
496
497 for row in 0..nrows {
498 let ell_start = row * ell_width;
499 let mut ell_count = 0;
500
501 for (col, val) in csr.row_iter(row) {
502 if Scalar::abs(val.clone()) <= eps {
503 continue;
504 }
505
506 if ell_count < ell_width {
507 ell_data[ell_start + ell_count] = val.clone();
508 ell_indices[ell_start + ell_count] = col;
509 ell_count += 1;
510 } else {
511 coo_rows.push(row);
513 coo_cols.push(col);
514 coo_data.push(val.clone());
515 }
516 }
517 }
518
519 Self {
520 nrows,
521 ncols,
522 ell_width,
523 ell_data,
524 ell_indices,
525 coo_rows,
526 coo_cols,
527 coo_data,
528 }
529 }
530
531 pub fn from_coo(coo: &CooMatrix<T>, strategy: HybWidthStrategy) -> Self
533 where
534 T: Scalar<Real = T> + Field + oxiblas_core::Real,
535 {
536 let csr = coo.to_csr();
537 Self::from_csr(&csr, strategy)
538 }
539
540 pub fn from_ell(ell: &EllMatrix<T>) -> Self
542 where
543 T: Field,
544 {
545 let (nrows, ncols) = ell.shape();
546 let ell_width = ell.width();
547
548 let mut ell_data = Vec::with_capacity(nrows * ell_width);
550 let mut ell_indices = Vec::with_capacity(nrows * ell_width);
551
552 let data = ell.data();
553 let indices = ell.indices();
554
555 for row in 0..nrows {
556 for k in 0..ell_width {
557 ell_data.push(data[row][k].clone());
558 ell_indices.push(indices[row][k]);
559 }
560 }
561
562 Self {
563 nrows,
564 ncols,
565 ell_width,
566 ell_data,
567 ell_indices,
568 coo_rows: Vec::new(),
569 coo_cols: Vec::new(),
570 coo_data: Vec::new(),
571 }
572 }
573
574 pub fn to_csr(&self) -> CsrMatrix<T>
576 where
577 T: Field,
578 {
579 let eps = <T as Scalar>::epsilon();
580
581 let mut entries: Vec<(usize, usize, T)> = Vec::new();
583
584 for row in 0..self.nrows {
586 let ell_start = row * self.ell_width;
587 for k in 0..self.ell_width {
588 let idx = ell_start + k;
589 let val = self.ell_data[idx].clone();
590 if Scalar::abs(val.clone()) > eps {
591 entries.push((row, self.ell_indices[idx], val));
592 }
593 }
594 }
595
596 for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
598 entries.push((row, col, self.coo_data[i].clone()));
599 }
600
601 entries.sort_by_key(|(r, c, _)| (*r, *c));
603
604 let mut row_ptrs = vec![0usize; self.nrows + 1];
606 let mut col_indices = Vec::with_capacity(entries.len());
607 let mut values = Vec::with_capacity(entries.len());
608
609 for (row, col, val) in entries {
610 col_indices.push(col);
611 values.push(val);
612 row_ptrs[row + 1] += 1;
613 }
614
615 for i in 1..=self.nrows {
617 row_ptrs[i] += row_ptrs[i - 1];
618 }
619
620 unsafe { CsrMatrix::new_unchecked(self.nrows, self.ncols, row_ptrs, col_indices, values) }
622 }
623
624 pub fn to_coo(&self) -> CooMatrix<T>
626 where
627 T: Field,
628 {
629 let eps = <T as Scalar>::epsilon();
630
631 let mut builder = crate::coo::CooMatrixBuilder::new(self.nrows, self.ncols);
632
633 for row in 0..self.nrows {
635 let ell_start = row * self.ell_width;
636 for k in 0..self.ell_width {
637 let idx = ell_start + k;
638 let val = self.ell_data[idx].clone();
639 if Scalar::abs(val.clone()) > eps {
640 builder.add(row, self.ell_indices[idx], val);
641 }
642 }
643 }
644
645 for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
647 builder.add(row, col, self.coo_data[i].clone());
648 }
649
650 builder.build()
651 }
652
653 pub fn to_ell(&self) -> EllMatrix<T>
657 where
658 T: Field,
659 {
660 let eps = <T as Scalar>::epsilon();
661
662 let mut row_lengths = vec![0usize; self.nrows];
664
665 for row in 0..self.nrows {
667 let ell_start = row * self.ell_width;
668 for k in 0..self.ell_width {
669 let idx = ell_start + k;
670 if Scalar::abs(self.ell_data[idx].clone()) > eps {
671 row_lengths[row] += 1;
672 }
673 }
674 }
675
676 for &row in &self.coo_rows {
678 row_lengths[row] += 1;
679 }
680
681 let max_width = row_lengths.iter().max().copied().unwrap_or(0).max(1);
682
683 let mut data: Vec<Vec<T>> = vec![vec![T::zero(); max_width]; self.nrows];
685 let mut indices: Vec<Vec<usize>> = vec![vec![0usize; max_width]; self.nrows];
686 let mut current_counts = vec![0usize; self.nrows];
687
688 for row in 0..self.nrows {
690 let ell_start = row * self.ell_width;
691 for k in 0..self.ell_width {
692 let idx = ell_start + k;
693 let val = self.ell_data[idx].clone();
694 if Scalar::abs(val.clone()) > eps {
695 let pos = current_counts[row];
696 data[row][pos] = val;
697 indices[row][pos] = self.ell_indices[idx];
698 current_counts[row] += 1;
699 }
700 }
701 }
702
703 for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
705 let pos = current_counts[row];
706 data[row][pos] = self.coo_data[i].clone();
707 indices[row][pos] = col;
708 current_counts[row] += 1;
709 }
710
711 unsafe { EllMatrix::new_unchecked(self.nrows, self.ncols, max_width, data, indices) }
713 }
714
715 pub fn to_dense(&self) -> oxiblas_matrix::Mat<T>
717 where
718 T: Field + bytemuck::Zeroable,
719 {
720 let eps = <T as Scalar>::epsilon();
721 let mut dense = oxiblas_matrix::Mat::zeros(self.nrows, self.ncols);
722
723 for row in 0..self.nrows {
725 let ell_start = row * self.ell_width;
726 for k in 0..self.ell_width {
727 let idx = ell_start + k;
728 let val = self.ell_data[idx].clone();
729 if Scalar::abs(val.clone()) > eps {
730 let col = self.ell_indices[idx];
731 dense[(row, col)] = val;
732 }
733 }
734 }
735
736 for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
738 dense[(row, col)] = self.coo_data[i].clone();
739 }
740
741 dense
742 }
743
744 pub fn scale(&mut self, alpha: T) {
746 for val in &mut self.ell_data {
747 *val = val.clone() * alpha.clone();
748 }
749 for val in &mut self.coo_data {
750 *val = val.clone() * alpha.clone();
751 }
752 }
753
754 pub fn scaled(&self, alpha: T) -> Self {
756 let mut result = self.clone();
757 result.scale(alpha);
758 result
759 }
760
761 pub fn iter(&self) -> impl Iterator<Item = (usize, usize, T)> + '_
763 where
764 T: Field,
765 {
766 let eps = <T as Scalar>::epsilon();
767 let ell_width = self.ell_width;
768
769 let ell_iter = (0..self.nrows).flat_map(move |row| {
771 let ell_start = row * ell_width;
772 (0..ell_width).filter_map(move |k| {
773 let idx = ell_start + k;
774 let val = self.ell_data[idx].clone();
775 if Scalar::abs(val.clone()) > eps {
776 Some((row, self.ell_indices[idx], val))
777 } else {
778 None
779 }
780 })
781 });
782
783 let coo_iter = self
785 .coo_rows
786 .iter()
787 .zip(self.coo_cols.iter())
788 .zip(self.coo_data.iter())
789 .map(|((&row, &col), val)| (row, col, val.clone()));
790
791 ell_iter.chain(coo_iter)
792 }
793
794 pub fn rebalance(&mut self, new_width: usize)
798 where
799 T: Field,
800 {
801 if new_width == self.ell_width {
802 return;
803 }
804
805 let eps = <T as Scalar>::epsilon();
806
807 let mut entries: Vec<(usize, usize, T)> = Vec::new();
809
810 for row in 0..self.nrows {
812 let ell_start = row * self.ell_width;
813 for k in 0..self.ell_width {
814 let idx = ell_start + k;
815 let val = self.ell_data[idx].clone();
816 if Scalar::abs(val.clone()) > eps {
817 entries.push((row, self.ell_indices[idx], val));
818 }
819 }
820 }
821
822 for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
824 entries.push((row, col, self.coo_data[i].clone()));
825 }
826
827 entries.sort_by_key(|(r, c, _)| (*r, *c));
829
830 self.ell_width = new_width;
832 self.ell_data = vec![T::zero(); self.nrows * new_width];
833 self.ell_indices = vec![0usize; self.nrows * new_width];
834 self.coo_rows.clear();
835 self.coo_cols.clear();
836 self.coo_data.clear();
837
838 let mut current_row = 0;
839 let mut count_in_row = 0;
840
841 for (row, col, val) in entries {
842 if row != current_row {
843 current_row = row;
844 count_in_row = 0;
845 }
846
847 if count_in_row < new_width {
848 let idx = row * new_width + count_in_row;
849 self.ell_data[idx] = val;
850 self.ell_indices[idx] = col;
851 count_in_row += 1;
852 } else {
853 self.coo_rows.push(row);
854 self.coo_cols.push(col);
855 self.coo_data.push(val);
856 }
857 }
858 }
859
860 pub fn stats(&self) -> HybStats
862 where
863 T: Field,
864 {
865 let eps = <T as Scalar>::epsilon();
866
867 let mut row_lengths = vec![0usize; self.nrows];
869
870 for row in 0..self.nrows {
871 let ell_start = row * self.ell_width;
872 for k in 0..self.ell_width {
873 let idx = ell_start + k;
874 if Scalar::abs(self.ell_data[idx].clone()) > eps {
875 row_lengths[row] += 1;
876 }
877 }
878 }
879
880 for &row in &self.coo_rows {
881 row_lengths[row] += 1;
882 }
883
884 let ell_nnz = self.ell_nnz();
885 let coo_nnz = self.coo_nnz();
886 let total_nnz = ell_nnz + coo_nnz;
887
888 let max_row_len = row_lengths.iter().max().copied().unwrap_or(0);
889 let min_row_len = row_lengths.iter().min().copied().unwrap_or(0);
890 let avg_row_len = if self.nrows > 0 {
891 total_nnz as f64 / self.nrows as f64
892 } else {
893 0.0
894 };
895
896 HybStats {
897 nrows: self.nrows,
898 ncols: self.ncols,
899 ell_width: self.ell_width,
900 ell_nnz,
901 coo_nnz,
902 total_nnz,
903 ell_fraction: if total_nnz > 0 {
904 ell_nnz as f64 / total_nnz as f64
905 } else {
906 1.0
907 },
908 max_row_length: max_row_len,
909 min_row_length: min_row_len,
910 avg_row_length: avg_row_len,
911 storage_efficiency: self.storage_efficiency(),
912 }
913 }
914}
915
/// Summary of a [`HybMatrix`]'s storage layout, as produced by `HybMatrix::stats`.
#[derive(Debug, Clone, Copy)]
pub struct HybStats {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// ELL slots per row.
    pub ell_width: usize,
    /// ELL slots whose magnitude exceeds epsilon.
    pub ell_nnz: usize,
    /// Number of COO entries.
    pub coo_nnz: usize,
    /// `ell_nnz + coo_nnz`.
    pub total_nnz: usize,
    /// `ell_nnz / total_nnz`, or `1.0` when `total_nnz == 0`.
    pub ell_fraction: f64,
    /// Largest per-row entry count (ELL non-negligible slots plus COO entries).
    pub max_row_length: usize,
    /// Smallest per-row entry count (same counting as `max_row_length`).
    pub min_row_length: usize,
    /// `total_nnz / nrows`, or `0.0` when there are no rows.
    pub avg_row_length: f64,
    /// `total_nnz / (nrows * ell_width + coo_nnz)`, or `0.0` when nothing is stored.
    pub storage_efficiency: f64,
}
942
943fn compute_ell_width(row_lengths: &[usize], strategy: HybWidthStrategy) -> usize {
945 if row_lengths.is_empty() {
946 return 1;
947 }
948
949 match strategy {
950 HybWidthStrategy::Fixed(k) => k,
951
952 HybWidthStrategy::Mean => {
953 let sum: usize = row_lengths.iter().sum();
954 let mean = sum as f64 / row_lengths.len() as f64;
955 mean.ceil() as usize
956 }
957
958 HybWidthStrategy::MeanPlusStddev(k) => {
959 let n = row_lengths.len() as f64;
960 let sum: usize = row_lengths.iter().sum();
961 let mean = sum as f64 / n;
962
963 let variance: f64 = row_lengths
964 .iter()
965 .map(|&x| {
966 let diff = x as f64 - mean;
967 diff * diff
968 })
969 .sum::<f64>()
970 / n;
971 let stddev = variance.sqrt();
972
973 (mean + k * stddev).ceil() as usize
974 }
975
976 HybWidthStrategy::Median => {
977 let mut sorted = row_lengths.to_vec();
978 sorted.sort_unstable();
979 let mid = sorted.len() / 2;
980 if sorted.len() % 2 == 0 {
981 (sorted[mid - 1] + sorted[mid]).div_ceil(2)
982 } else {
983 sorted[mid]
984 }
985 }
986
987 HybWidthStrategy::Percentile(p) => {
988 let p = p.clamp(0.0, 1.0);
989 let mut sorted = row_lengths.to_vec();
990 sorted.sort_unstable();
991 let idx = ((sorted.len() - 1) as f64 * p) as usize;
992 sorted[idx]
993 }
994
995 HybWidthStrategy::Max => row_lengths.iter().max().copied().unwrap_or(1),
996 }
997}
998
#[cfg(test)]
mod tests {
    use super::*;

    /// 4×4 fixture with 7 entries:
    ///
    /// ```text
    /// row 0: (0,0)=1, (0,2)=2
    /// row 1: (1,1)=3
    /// row 2: (2,0)=4, (2,2)=5, (2,3)=6
    /// row 3: (3,3)=7
    /// ```
    ///
    /// Row lengths are [2, 1, 3, 1], so any ELL width below 3 spills part of row 2.
    fn make_test_csr() -> CsrMatrix<f64> {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let col_indices = vec![0, 2, 1, 0, 2, 3, 3];
        let row_ptrs = vec![0, 2, 3, 6, 7];
        CsrMatrix::new(4, 4, row_ptrs, col_indices, values).unwrap()
    }

    #[test]
    fn test_hyb_from_csr_fixed() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        assert_eq!(hyb.nrows(), 4);
        assert_eq!(hyb.ncols(), 4);
        assert_eq!(hyb.ell_width(), 2);

        // Row 2 has three entries; with width 2 its third entry goes to COO.
        assert_eq!(hyb.coo_nnz(), 1);
        assert_eq!(hyb.nnz(), 7);
    }

    #[test]
    fn test_hyb_from_csr_max() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Max);

        // Longest row has 3 entries, so everything fits in ELL.
        assert_eq!(hyb.ell_width(), 3);
        assert_eq!(hyb.coo_nnz(), 0);
    }

    #[test]
    fn test_hyb_matvec() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        // y0 = 1*1 + 2*3 = 7
        // y1 = 3*2 = 6
        // y2 = 4*1 + 5*3 + 6*4 = 43 (the 6*4 term comes from the COO part)
        // y3 = 7*4 = 28
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = hyb.mul_vec(&x);

        assert!((y[0] - 7.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 43.0).abs() < 1e-10);
        assert!((y[3] - 28.0).abs() < 1e-10);
    }

    #[test]
    fn test_hyb_matvec_transpose() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        // With x = ones, Aᵀx is the vector of column sums.
        let x = vec![1.0, 1.0, 1.0, 1.0];
        let mut y = vec![0.0; 4];
        hyb.matvec_transpose(&x, &mut y);

        assert!((y[0] - 5.0).abs() < 1e-10); // 1 + 4
        assert!((y[1] - 3.0).abs() < 1e-10); // 3
        assert!((y[2] - 7.0).abs() < 1e-10); // 2 + 5
        assert!((y[3] - 13.0).abs() < 1e-10); // 6 + 7
    }

    #[test]
    fn test_hyb_to_csr_roundtrip() {
        let csr1 = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr1, HybWidthStrategy::Fixed(2));
        let csr2 = hyb.to_csr();

        assert_eq!(csr1.nnz(), csr2.nnz());

        // Compare every position, treating missing entries as zero.
        for row in 0..4 {
            for col in 0..4 {
                let v1 = csr1.get(row, col).cloned().unwrap_or(0.0);
                let v2 = csr2.get(row, col).cloned().unwrap_or(0.0);
                assert!((v1 - v2).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_hyb_to_dense() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));
        let dense = hyb.to_dense();

        // (2, 3) lives in the COO part with width 2.
        assert!((dense[(0, 0)] - 1.0).abs() < 1e-10);
        assert!((dense[(0, 2)] - 2.0).abs() < 1e-10);
        assert!((dense[(1, 1)] - 3.0).abs() < 1e-10);
        assert!((dense[(2, 0)] - 4.0).abs() < 1e-10);
        assert!((dense[(2, 2)] - 5.0).abs() < 1e-10);
        assert!((dense[(2, 3)] - 6.0).abs() < 1e-10);
        assert!((dense[(3, 3)] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_hyb_get() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        assert_eq!(hyb.get(0, 0), Some(1.0));
        assert_eq!(hyb.get(0, 2), Some(2.0));
        assert_eq!(hyb.get(1, 1), Some(3.0));
        assert_eq!(hyb.get(2, 0), Some(4.0));
        assert_eq!(hyb.get(2, 2), Some(5.0));
        // (2, 3) is found in the COO part.
        assert_eq!(hyb.get(2, 3), Some(6.0));
        assert_eq!(hyb.get(3, 3), Some(7.0));

        // Unstored position inside the bounds.
        assert_eq!(hyb.get(0, 1), None);
    }

    #[test]
    fn test_hyb_eye() {
        let hyb: HybMatrix<f64> = HybMatrix::eye(4);

        assert_eq!(hyb.nrows(), 4);
        assert_eq!(hyb.ncols(), 4);
        assert_eq!(hyb.ell_width(), 1);
        assert_eq!(hyb.nnz(), 4);

        for i in 0..4 {
            assert_eq!(hyb.get(i, i), Some(1.0));
        }
    }

    #[test]
    fn test_hyb_scale() {
        let csr = make_test_csr();
        let mut hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        hyb.scale(2.0);

        assert_eq!(hyb.get(0, 0), Some(2.0));
        assert_eq!(hyb.get(2, 2), Some(10.0));
    }

    #[test]
    fn test_hyb_rebalance() {
        let csr = make_test_csr();
        let mut hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(1));

        // Width 1 forces the extra entries of rows 0 and 2 into COO.
        assert!(hyb.coo_nnz() > 0);

        // Width 3 fits the longest row, so COO must empty out.
        hyb.rebalance(3);

        assert_eq!(hyb.ell_width(), 3);
        assert_eq!(hyb.coo_nnz(), 0);
        assert_eq!(hyb.nnz(), 7);
    }

    #[test]
    fn test_hyb_stats() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));
        let stats = hyb.stats();

        // Row lengths are [2, 1, 3, 1].
        assert_eq!(stats.nrows, 4);
        assert_eq!(stats.ncols, 4);
        assert_eq!(stats.ell_width, 2);
        assert_eq!(stats.total_nnz, 7);
        assert_eq!(stats.max_row_length, 3);
        assert_eq!(stats.min_row_length, 1);
    }

    #[test]
    fn test_hyb_iter() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        // 6 ELL entries + 1 COO entry.
        let entries: Vec<_> = hyb.iter().collect();
        assert_eq!(entries.len(), 7);
    }

    #[test]
    fn test_compute_ell_width() {
        let row_lengths = vec![1, 2, 5, 3, 2];

        assert_eq!(
            compute_ell_width(&row_lengths, HybWidthStrategy::Fixed(4)),
            4
        );
        assert_eq!(compute_ell_width(&row_lengths, HybWidthStrategy::Max), 5);
        // Sorted [1, 2, 2, 3, 5] -> middle element 2.
        assert_eq!(compute_ell_width(&row_lengths, HybWidthStrategy::Median), 2);

        // Mean 13 / 5 = 2.6, rounded up to 3.
        assert_eq!(compute_ell_width(&row_lengths, HybWidthStrategy::Mean), 3);
    }

    #[test]
    fn test_hyb_from_ell() {
        let csr = make_test_csr();
        // NOTE(review): assumes `None` lets the ELL conversion choose a width
        // that holds every row — confirm against `EllMatrix::from_csr`.
        let ell = crate::ell::EllMatrix::from_csr(&csr, None).unwrap();

        let hyb = HybMatrix::from_ell(&ell);

        // `from_ell` never creates COO entries.
        assert_eq!(hyb.coo_nnz(), 0);
        assert_eq!(hyb.nnz(), 7);
    }
}