1use oxiblas_core::scalar::{Field, Scalar};
23
/// Errors reported while building an [`EllMatrix`]
/// (by [`EllMatrix::new`] or [`EllMatrix::from_csr`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EllError {
    /// `data` does not have the requested `nrows × width` shape.
    InvalidDataDimensions {
        /// Requested number of rows (`nrows`).
        expected_rows: usize,
        /// Number of rows actually present in `data`.
        actual_rows: usize,
        /// Requested number of slots per row (`width`).
        expected_width: usize,
        /// Length of the offending row; when the row count itself is wrong this is
        /// the length of the first row (0 if `data` is empty).
        actual_width: usize,
    },
    /// `indices` does not have the same `nrows × width` shape as `data`.
    DimensionMismatch {
        /// The expected `(nrows, width)` shape.
        data_dims: (usize, usize),
        /// `(rows, row length)` observed in `indices`; the row length is that of the
        /// offending row, or of the first row when the row count is wrong.
        indices_dims: (usize, usize),
    },
    /// A non-padding column index is out of range.
    InvalidColumnIndex {
        /// Row containing the bad index.
        row: usize,
        /// Slot position of the bad index within that row.
        pos: usize,
        /// The offending column index.
        index: usize,
        /// Number of columns of the matrix being built.
        ncols: usize,
    },
    /// A source row has more entries than the requested ELL width.
    TooManyNonZeros {
        /// Row that does not fit.
        row: usize,
        /// Number of entries stored for that row in the source matrix.
        nnz: usize,
        /// Width limit requested by the caller.
        max_nnz: usize,
    },
}
66
67impl core::fmt::Display for EllError {
68 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
69 match self {
70 Self::InvalidDataDimensions {
71 expected_rows,
72 actual_rows,
73 expected_width,
74 actual_width,
75 } => {
76 write!(
77 f,
78 "Invalid data dimensions: expected {expected_rows}×{expected_width}, got {actual_rows}×{actual_width}"
79 )
80 }
81 Self::DimensionMismatch {
82 data_dims,
83 indices_dims,
84 } => {
85 write!(
86 f,
87 "Dimension mismatch: data is {}×{}, indices is {}×{}",
88 data_dims.0, data_dims.1, indices_dims.0, indices_dims.1
89 )
90 }
91 Self::InvalidColumnIndex {
92 row,
93 pos,
94 index,
95 ncols,
96 } => {
97 write!(
98 f,
99 "Invalid column index {index} at row {row}, position {pos} (ncols={ncols})"
100 )
101 }
102 Self::TooManyNonZeros { row, nnz, max_nnz } => {
103 write!(f, "Row {row} has {nnz} non-zeros, exceeds max {max_nnz}")
104 }
105 }
106 }
107}
108
// `EllError` wraps no underlying cause, so the trait's default methods are used as-is.
impl std::error::Error for EllError {}
110
/// Sparse matrix in ELLPACK (ELL) format.
///
/// Every row holds the same number of slots (`width`), so the storage is logically
/// an `nrows × width` array of values plus a parallel array of column indices.
/// Rows with fewer entries are padded with slots whose column index is `usize::MAX`;
/// accessors skip those padding slots. The shape invariants are checked by
/// [`EllMatrix::new`].
#[derive(Debug, Clone)]
pub struct EllMatrix<T: Scalar> {
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Number of slots stored per row.
    width: usize,
    /// `data[i][k]` is the value held in slot `k` of row `i`.
    data: Vec<Vec<T>>,
    /// `indices[i][k]` is the column of slot `k` of row `i`, or `usize::MAX` for padding.
    indices: Vec<Vec<usize>>,
}
161
/// Column-index sentinel marking a padding slot in an ELL row.
///
/// `EllMatrix::new` accepts it in any slot without a range check, and the
/// accessors skip every slot that carries it.
const INVALID_INDEX: usize = usize::MAX;
164
165impl<T: Scalar + Clone> EllMatrix<T> {
166 pub fn new(
180 nrows: usize,
181 ncols: usize,
182 width: usize,
183 data: Vec<Vec<T>>,
184 indices: Vec<Vec<usize>>,
185 ) -> Result<Self, EllError> {
186 if data.len() != nrows {
188 return Err(EllError::InvalidDataDimensions {
189 expected_rows: nrows,
190 actual_rows: data.len(),
191 expected_width: width,
192 actual_width: if data.is_empty() { 0 } else { data[0].len() },
193 });
194 }
195
196 for (i, row) in data.iter().enumerate() {
197 if row.len() != width {
198 return Err(EllError::InvalidDataDimensions {
199 expected_rows: nrows,
200 actual_rows: data.len(),
201 expected_width: width,
202 actual_width: row.len(),
203 });
204 }
205
206 if i < indices.len() && indices[i].len() != width {
208 return Err(EllError::DimensionMismatch {
209 data_dims: (nrows, width),
210 indices_dims: (indices.len(), indices[i].len()),
211 });
212 }
213 }
214
215 if indices.len() != nrows {
217 return Err(EllError::DimensionMismatch {
218 data_dims: (nrows, width),
219 indices_dims: (
220 indices.len(),
221 if indices.is_empty() {
222 0
223 } else {
224 indices[0].len()
225 },
226 ),
227 });
228 }
229
230 for (row, row_indices) in indices.iter().enumerate() {
232 for (pos, &col) in row_indices.iter().enumerate() {
233 if col != INVALID_INDEX && col >= ncols {
234 return Err(EllError::InvalidColumnIndex {
235 row,
236 pos,
237 index: col,
238 ncols,
239 });
240 }
241 }
242 }
243
244 Ok(Self {
245 nrows,
246 ncols,
247 width,
248 data,
249 indices,
250 })
251 }
252
253 #[inline]
262 pub unsafe fn new_unchecked(
263 nrows: usize,
264 ncols: usize,
265 width: usize,
266 data: Vec<Vec<T>>,
267 indices: Vec<Vec<usize>>,
268 ) -> Self {
269 Self {
270 nrows,
271 ncols,
272 width,
273 data,
274 indices,
275 }
276 }
277
278 pub fn zeros(nrows: usize, ncols: usize) -> Self {
280 Self {
281 nrows,
282 ncols,
283 width: 0,
284 data: vec![Vec::new(); nrows],
285 indices: vec![Vec::new(); nrows],
286 }
287 }
288
289 pub fn eye(n: usize) -> Self
291 where
292 T: Field,
293 {
294 Self {
295 nrows: n,
296 ncols: n,
297 width: 1,
298 data: (0..n).map(|_| vec![T::one()]).collect(),
299 indices: (0..n).map(|i| vec![i]).collect(),
300 }
301 }
302
303 #[inline]
305 pub fn nrows(&self) -> usize {
306 self.nrows
307 }
308
309 #[inline]
311 pub fn ncols(&self) -> usize {
312 self.ncols
313 }
314
315 #[inline]
317 pub fn shape(&self) -> (usize, usize) {
318 (self.nrows, self.ncols)
319 }
320
321 #[inline]
323 pub fn width(&self) -> usize {
324 self.width
325 }
326
327 pub fn nnz(&self) -> usize
331 where
332 T: Field,
333 {
334 let eps = <T as Scalar>::epsilon();
335 let mut count = 0;
336
337 for (row, indices_row) in self.indices.iter().enumerate() {
338 for (k, &col) in indices_row.iter().enumerate() {
339 if col != INVALID_INDEX && Scalar::abs(self.data[row][k].clone()) > eps {
340 count += 1;
341 }
342 }
343 }
344
345 count
346 }
347
348 #[inline]
350 pub fn nstored(&self) -> usize {
351 self.nrows * self.width
352 }
353
354 pub fn efficiency(&self) -> f64
356 where
357 T: Field,
358 {
359 if self.nstored() == 0 {
360 1.0
361 } else {
362 self.nnz() as f64 / self.nstored() as f64
363 }
364 }
365
366 #[inline]
368 pub fn data(&self) -> &[Vec<T>] {
369 &self.data
370 }
371
372 #[inline]
374 pub fn data_mut(&mut self) -> &mut [Vec<T>] {
375 &mut self.data
376 }
377
378 #[inline]
380 pub fn indices(&self) -> &[Vec<usize>] {
381 &self.indices
382 }
383
384 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
386 if row >= self.nrows || col >= self.ncols {
387 return None;
388 }
389
390 for k in 0..self.width {
391 if self.indices[row][k] == col {
392 return Some(&self.data[row][k]);
393 }
394 }
395
396 None
397 }
398
399 pub fn get_or_zero(&self, row: usize, col: usize) -> T
401 where
402 T: Field,
403 {
404 self.get(row, col).cloned().unwrap_or_else(T::zero)
405 }
406
407 pub fn row_iter(&self, row: usize) -> impl Iterator<Item = (usize, &T)> {
409 self.indices[row]
410 .iter()
411 .zip(self.data[row].iter())
412 .filter(|(col, _)| **col != INVALID_INDEX)
413 .map(|(col, val)| (*col, val))
414 }
415
416 pub fn iter(&self) -> impl Iterator<Item = (usize, usize, &T)> + '_ {
418 (0..self.nrows).flat_map(move |row| {
419 self.indices[row]
420 .iter()
421 .zip(self.data[row].iter())
422 .filter(|(col, _)| **col != INVALID_INDEX)
423 .map(move |(col, val)| (row, *col, val))
424 })
425 }
426
427 pub fn matvec(&self, x: &[T], y: &mut [T])
429 where
430 T: Field,
431 {
432 assert_eq!(x.len(), self.ncols, "x length must equal ncols");
433 assert_eq!(y.len(), self.nrows, "y length must equal nrows");
434
435 for row in 0..self.nrows {
436 let mut sum = T::zero();
437 for k in 0..self.width {
438 let col = self.indices[row][k];
439 if col != INVALID_INDEX {
440 sum = sum + self.data[row][k].clone() * x[col].clone();
441 }
442 }
443 y[row] = sum;
444 }
445 }
446
447 pub fn mul_vec(&self, x: &[T]) -> Vec<T>
449 where
450 T: Field,
451 {
452 let mut y = vec![T::zero(); self.nrows];
453 self.matvec(x, &mut y);
454 y
455 }
456
457 pub fn to_csr(&self) -> crate::csr::CsrMatrix<T>
459 where
460 T: Field,
461 {
462 let eps = <T as Scalar>::epsilon();
463
464 let mut row_ptrs = vec![0usize; self.nrows + 1];
465 let mut col_indices = Vec::new();
466 let mut values = Vec::new();
467
468 for row in 0..self.nrows {
469 let mut row_entries: Vec<(usize, T)> = Vec::new();
470
471 for k in 0..self.width {
472 let col = self.indices[row][k];
473 if col != INVALID_INDEX {
474 let val = self.data[row][k].clone();
475 if Scalar::abs(val.clone()) > eps {
476 row_entries.push((col, val));
477 }
478 }
479 }
480
481 row_entries.sort_by_key(|(col, _)| *col);
483
484 for (col, val) in row_entries {
485 col_indices.push(col);
486 values.push(val);
487 }
488 row_ptrs[row + 1] = values.len();
489 }
490
491 unsafe {
493 crate::csr::CsrMatrix::new_unchecked(
494 self.nrows,
495 self.ncols,
496 row_ptrs,
497 col_indices,
498 values,
499 )
500 }
501 }
502
503 pub fn to_dense(&self) -> oxiblas_matrix::Mat<T>
505 where
506 T: Field + bytemuck::Zeroable,
507 {
508 let mut dense = oxiblas_matrix::Mat::zeros(self.nrows, self.ncols);
509
510 for row in 0..self.nrows {
511 for k in 0..self.width {
512 let col = self.indices[row][k];
513 if col != INVALID_INDEX {
514 dense[(row, col)] = self.data[row][k].clone();
515 }
516 }
517 }
518
519 dense
520 }
521
522 pub fn from_dense(dense: &oxiblas_matrix::MatRef<'_, T>, max_width: Option<usize>) -> Self
529 where
530 T: Field,
531 {
532 let (nrows, ncols) = dense.shape();
533 let eps = <T as Scalar>::epsilon();
534
535 let mut row_nnz = vec![0usize; nrows];
537 for i in 0..nrows {
538 for j in 0..ncols {
539 if Scalar::abs(dense[(i, j)].clone()) > eps {
540 row_nnz[i] += 1;
541 }
542 }
543 }
544
545 let width = max_width.unwrap_or_else(|| row_nnz.iter().copied().max().unwrap_or(0));
546
547 let mut data = Vec::with_capacity(nrows);
549 let mut indices = Vec::with_capacity(nrows);
550
551 for i in 0..nrows {
552 let mut row_data = Vec::with_capacity(width);
553 let mut row_indices = Vec::with_capacity(width);
554
555 for j in 0..ncols {
556 if row_data.len() >= width {
557 break;
558 }
559 let val = dense[(i, j)].clone();
560 if Scalar::abs(val.clone()) > eps {
561 row_data.push(val);
562 row_indices.push(j);
563 }
564 }
565
566 while row_data.len() < width {
568 row_data.push(T::zero());
569 row_indices.push(INVALID_INDEX);
570 }
571
572 data.push(row_data);
573 indices.push(row_indices);
574 }
575
576 unsafe { Self::new_unchecked(nrows, ncols, width, data, indices) }
578 }
579
580 pub fn from_csr(
587 csr: &crate::csr::CsrMatrix<T>,
588 max_width: Option<usize>,
589 ) -> Result<Self, EllError>
590 where
591 T: Field,
592 {
593 let (nrows, ncols) = csr.shape();
594 let row_ptrs = csr.row_ptrs();
595 let csr_indices = csr.col_indices();
596 let csr_values = csr.values();
597
598 let actual_max: usize = (0..nrows)
600 .map(|i| row_ptrs[i + 1] - row_ptrs[i])
601 .max()
602 .unwrap_or(0);
603
604 let width = max_width.unwrap_or(actual_max);
605
606 if let Some(max_w) = max_width {
608 for row in 0..nrows {
609 let row_nnz = row_ptrs[row + 1] - row_ptrs[row];
610 if row_nnz > max_w {
611 return Err(EllError::TooManyNonZeros {
612 row,
613 nnz: row_nnz,
614 max_nnz: max_w,
615 });
616 }
617 }
618 }
619
620 let mut data = Vec::with_capacity(nrows);
622 let mut indices = Vec::with_capacity(nrows);
623
624 for row in 0..nrows {
625 let start = row_ptrs[row];
626 let end = row_ptrs[row + 1];
627 let row_nnz = end - start;
628
629 let mut row_data = Vec::with_capacity(width);
630 let mut row_indices = Vec::with_capacity(width);
631
632 for k in 0..row_nnz {
633 row_data.push(csr_values[start + k].clone());
634 row_indices.push(csr_indices[start + k]);
635 }
636
637 while row_data.len() < width {
639 row_data.push(T::zero());
640 row_indices.push(INVALID_INDEX);
641 }
642
643 data.push(row_data);
644 indices.push(row_indices);
645 }
646
647 Ok(Self {
648 nrows,
649 ncols,
650 width,
651 data,
652 indices,
653 })
654 }
655
656 pub fn scale(&mut self, alpha: T) {
658 for row in &mut self.data {
659 for val in row.iter_mut() {
660 *val = val.clone() * alpha.clone();
661 }
662 }
663 }
664
665 pub fn scaled(&self, alpha: T) -> Self {
667 let mut result = self.clone();
668 result.scale(alpha);
669 result
670 }
671
672 pub fn transpose(&self) -> Self
676 where
677 T: Field,
678 {
679 let csr = self.to_csr();
681 let csr_t = csr.transpose();
682 Self::from_csr(&csr_t, Some(self.width)).unwrap_or_else(|_| {
683 Self::from_csr(&csr_t, None).expect("CSR transpose should be valid")
685 })
686 }
687}
688
#[cfg(test)]
mod tests {
    use super::*;

    /// 3×4 CSR matrix shared by the `from_csr` tests:
    /// rows `[1 2 0 0]`, `[0 3 4 0]`, `[5 0 0 6]`.
    fn sample_csr() -> crate::csr::CsrMatrix<f64> {
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let col_indices = vec![0, 1, 1, 2, 0, 3];
        let row_ptrs = vec![0, 2, 4, 6];
        crate::csr::CsrMatrix::new(3, 4, row_ptrs, col_indices, values).unwrap()
    }

    #[test]
    fn test_ell_new() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let indices = vec![vec![0, 1], vec![1, 2], vec![0, 3]];

        let ell = EllMatrix::new(3, 4, 2, data, indices).unwrap();

        assert_eq!(ell.nrows(), 3);
        assert_eq!(ell.ncols(), 4);
        assert_eq!(ell.width(), 2);
    }

    #[test]
    fn test_ell_get() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let indices = vec![vec![0, 1], vec![1, 2], vec![0, 3]];

        let ell = EllMatrix::new(3, 4, 2, data, indices).unwrap();

        assert_eq!(ell.get(0, 0), Some(&1.0));
        assert_eq!(ell.get(0, 1), Some(&2.0));
        assert_eq!(ell.get(1, 1), Some(&3.0));
        assert_eq!(ell.get(1, 2), Some(&4.0));
        assert_eq!(ell.get(2, 0), Some(&5.0));
        assert_eq!(ell.get(2, 3), Some(&6.0));

        // Positions with no stored slot.
        assert_eq!(ell.get(0, 2), None);
        assert_eq!(ell.get(0, 3), None);
    }

    #[test]
    fn test_ell_matvec() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let indices = vec![vec![0, 1], vec![1, 2], vec![0, 3]];

        let ell = EllMatrix::new(3, 4, 2, data, indices).unwrap();
        let x = vec![1.0, 1.0, 1.0, 1.0];
        let y = ell.mul_vec(&x);

        assert!((y[0] - 3.0).abs() < 1e-10);
        assert!((y[1] - 7.0).abs() < 1e-10);
        assert!((y[2] - 11.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "x length must equal ncols")]
    fn test_ell_matvec_rejects_wrong_length() {
        let ell: EllMatrix<f64> = EllMatrix::eye(2);
        let _ = ell.mul_vec(&[1.0]);
    }

    #[test]
    fn test_ell_with_padding() {
        let data = vec![
            vec![1.0, 0.0, 0.0], // row 0: one entry, two padding slots
            vec![2.0, 3.0, 4.0], // row 1: full
            vec![5.0, 0.0, 0.0], // row 2: one entry, two padding slots
        ];
        let indices = vec![
            vec![0, INVALID_INDEX, INVALID_INDEX],
            vec![0, 1, 2],
            vec![1, INVALID_INDEX, INVALID_INDEX],
        ];

        let ell = EllMatrix::new(3, 3, 3, data, indices).unwrap();

        assert_eq!(ell.nnz(), 5);
        assert_eq!(ell.nstored(), 9);
        assert!((ell.efficiency() - 5.0 / 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_ell_eye() {
        let ell: EllMatrix<f64> = EllMatrix::eye(4);

        assert_eq!(ell.nrows(), 4);
        assert_eq!(ell.ncols(), 4);
        assert_eq!(ell.width(), 1);

        for i in 0..4 {
            assert_eq!(ell.get(i, i), Some(&1.0));
        }
    }

    #[test]
    fn test_ell_to_dense() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let indices = vec![vec![0, 1], vec![1, 2], vec![0, 3]];

        let ell = EllMatrix::new(3, 4, 2, data, indices).unwrap();
        let dense = ell.to_dense();

        assert!((dense[(0, 0)] - 1.0).abs() < 1e-10);
        assert!((dense[(0, 1)] - 2.0).abs() < 1e-10);
        assert!((dense[(0, 2)] - 0.0).abs() < 1e-10);
        assert!((dense[(1, 1)] - 3.0).abs() < 1e-10);
        assert!((dense[(1, 2)] - 4.0).abs() < 1e-10);
        assert!((dense[(2, 0)] - 5.0).abs() < 1e-10);
        assert!((dense[(2, 3)] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_ell_to_csr() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let indices = vec![vec![0, 1], vec![1, 2], vec![0, 3]];

        let ell = EllMatrix::new(3, 4, 2, data, indices).unwrap();
        let csr = ell.to_csr();

        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.ncols(), 4);
        assert_eq!(csr.nnz(), 6);
        assert_eq!(csr.get(0, 0), Some(&1.0));
        assert_eq!(csr.get(2, 3), Some(&6.0));
    }

    #[test]
    fn test_ell_to_csr_skips_padding_and_zeros_and_sorts() {
        // Row 0 is stored out of column order; row 1 holds only an explicit zero.
        let data = vec![vec![2.0, 1.0, 0.0], vec![0.0, 0.0, 0.0]];
        let indices = vec![
            vec![2, 0, INVALID_INDEX],
            vec![1, INVALID_INDEX, INVALID_INDEX],
        ];

        let ell = EllMatrix::new(2, 3, 3, data, indices).unwrap();
        let csr = ell.to_csr();

        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.col_indices()[0], 0);
        assert_eq!(csr.col_indices()[1], 2);
        assert_eq!(csr.get(0, 0), Some(&1.0));
        assert_eq!(csr.get(0, 2), Some(&2.0));
        assert_eq!(csr.get(1, 1), None);
    }

    #[test]
    fn test_ell_from_dense() {
        use oxiblas_matrix::Mat;

        let dense = Mat::from_rows(&[
            &[1.0f64, 2.0, 0.0, 0.0],
            &[0.0, 3.0, 4.0, 0.0],
            &[5.0, 0.0, 0.0, 6.0],
        ]);

        let ell = EllMatrix::from_dense(&dense.as_ref(), None);

        assert_eq!(ell.width(), 2);
        assert_eq!(ell.get(0, 0), Some(&1.0));
        assert_eq!(ell.get(1, 2), Some(&4.0));
    }

    #[test]
    fn test_ell_from_dense_with_max_width_truncates() {
        use oxiblas_matrix::Mat;

        let dense = Mat::from_rows(&[
            &[1.0f64, 2.0, 0.0, 0.0],
            &[0.0, 3.0, 4.0, 0.0],
            &[5.0, 0.0, 0.0, 6.0],
        ]);

        // Only the first non-zero of each row (in column order) is kept.
        let ell = EllMatrix::from_dense(&dense.as_ref(), Some(1));

        assert_eq!(ell.width(), 1);
        assert_eq!(ell.get(0, 0), Some(&1.0));
        assert_eq!(ell.get(0, 1), None);
        assert_eq!(ell.get(1, 1), Some(&3.0));
        assert_eq!(ell.get(2, 0), Some(&5.0));
        assert_eq!(ell.get(2, 3), None);
        assert_eq!(ell.nnz(), 3);
    }

    #[test]
    fn test_ell_from_dense_with_max_width_pads() {
        use oxiblas_matrix::Mat;

        let dense = Mat::from_rows(&[
            &[1.0f64, 2.0, 0.0, 0.0],
            &[0.0, 3.0, 4.0, 0.0],
            &[5.0, 0.0, 0.0, 6.0],
        ]);

        let ell = EllMatrix::from_dense(&dense.as_ref(), Some(3));

        assert_eq!(ell.width(), 3);
        assert_eq!(ell.nstored(), 9);
        assert_eq!(ell.nnz(), 6);
        assert_eq!(ell.row_iter(2).collect::<Vec<_>>(), vec![(0, &5.0), (3, &6.0)]);
    }

    #[test]
    fn test_ell_from_csr() {
        let csr = sample_csr();
        let ell = EllMatrix::from_csr(&csr, None).unwrap();

        assert_eq!(ell.width(), 2);
        assert_eq!(ell.get(0, 0), Some(&1.0));
        assert_eq!(ell.get(2, 3), Some(&6.0));
    }

    #[test]
    fn test_ell_from_csr_pads_to_requested_width() {
        let csr = sample_csr();
        let ell = EllMatrix::from_csr(&csr, Some(3)).unwrap();

        assert_eq!(ell.width(), 3);
        assert_eq!(ell.nstored(), 9);
        assert_eq!(ell.row_iter(1).collect::<Vec<_>>(), vec![(1, &3.0), (2, &4.0)]);
    }

    #[test]
    fn test_ell_from_csr_too_many_nonzeros() {
        let csr = sample_csr();
        let result = EllMatrix::from_csr(&csr, Some(1));

        assert!(matches!(
            result,
            Err(EllError::TooManyNonZeros {
                row: 0,
                nnz: 2,
                max_nnz: 1
            })
        ));
    }

    #[test]
    fn test_ell_scale() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let indices = vec![vec![0, 1], vec![0, 1]];

        let mut ell = EllMatrix::new(2, 2, 2, data, indices).unwrap();
        ell.scale(2.0);

        assert_eq!(ell.get(0, 0), Some(&2.0));
        assert_eq!(ell.get(0, 1), Some(&4.0));
    }

    #[test]
    fn test_ell_transpose() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 0.0]];
        let indices = vec![vec![0, 1], vec![0, INVALID_INDEX]];

        let ell = EllMatrix::new(2, 2, 2, data, indices).unwrap();
        let ell_t = ell.transpose();

        let dense = ell.to_dense();
        let dense_t = ell_t.to_dense();

        for i in 0..2 {
            for j in 0..2 {
                assert!((dense[(i, j)] - dense_t[(j, i)]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_ell_transpose_widens_when_needed() {
        // A single dense column: the transpose's first row needs three slots.
        let data = vec![vec![1.0], vec![2.0], vec![3.0]];
        let indices = vec![vec![0], vec![0], vec![0]];

        let ell = EllMatrix::new(3, 3, 1, data, indices).unwrap();
        let ell_t = ell.transpose();

        assert_eq!(ell_t.width(), 3);
        assert_eq!(ell_t.get(0, 0), Some(&1.0));
        assert_eq!(ell_t.get(0, 1), Some(&2.0));
        assert_eq!(ell_t.get(0, 2), Some(&3.0));
        assert_eq!(ell_t.nnz(), 3);
    }

    #[test]
    fn test_ell_row_iter() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let indices = vec![vec![0, 2], vec![1, INVALID_INDEX]];

        let ell = EllMatrix::new(2, 3, 2, data, indices).unwrap();

        let row0: Vec<_> = ell.row_iter(0).collect();
        assert_eq!(row0, vec![(0, &1.0), (2, &2.0)]);

        let row1: Vec<_> = ell.row_iter(1).collect();
        assert_eq!(row1, vec![(1, &3.0)]);
    }

    #[test]
    fn test_ell_invalid_column_index() {
        let data = vec![vec![1.0]];
        // Column 10 is out of range for a matrix with 3 columns.
        let indices = vec![vec![10]];
        let result = EllMatrix::new(1, 3, 1, data, indices);

        assert!(matches!(result, Err(EllError::InvalidColumnIndex { .. })));
    }

    #[test]
    fn test_ell_new_rejects_short_data() {
        let result = EllMatrix::new(2, 2, 1, vec![vec![1.0f64]], vec![vec![0], vec![1]]);

        assert!(matches!(
            result,
            Err(EllError::InvalidDataDimensions {
                expected_rows: 2,
                actual_rows: 1,
                expected_width: 1,
                actual_width: 1
            })
        ));
    }

    #[test]
    fn test_ell_new_rejects_index_width_mismatch() {
        let result = EllMatrix::new(1, 2, 1, vec![vec![1.0f64]], vec![vec![0, 1]]);

        assert!(matches!(
            result,
            Err(EllError::DimensionMismatch {
                data_dims: (1, 1),
                indices_dims: (1, 2)
            })
        ));
    }

    #[test]
    fn test_ell_error_display() {
        let err = EllError::TooManyNonZeros {
            row: 1,
            nnz: 3,
            max_nnz: 2,
        };
        assert_eq!(err.to_string(), "Row 1 has 3 non-zeros, exceeds max 2");
    }

    #[test]
    fn test_ell_zeros() {
        let ell: EllMatrix<f64> = EllMatrix::zeros(5, 3);

        assert_eq!(ell.nrows(), 5);
        assert_eq!(ell.ncols(), 3);
        assert_eq!(ell.width(), 0);
        assert_eq!(ell.nnz(), 0);
    }
}