oxiblas_sparse/
ell.rs

1//! ELLPACK (ELL) sparse matrix format.
2//!
3//! ELL stores matrix data using:
4//! - `data`: 2D array of shape (nrows × max_nnz_per_row)
5//! - `indices`: 2D array of column indices, same shape as data
6//!
7//! For an m×n matrix:
8//! - Each row stores exactly `max_nnz_per_row` entries
9//! - Rows with fewer non-zeros are padded with zeros and invalid indices
10//!
11//! # When to Use ELL
12//!
13//! ELL format is optimal for:
14//! - Matrices with roughly uniform number of non-zeros per row
15//! - GPU computation (enables coalesced memory access)
16//! - Vector processors with SIMD operations
17//!
18//! ELL is NOT efficient for:
19//! - Matrices with varying non-zeros per row (wastes memory on padding)
20//! - Power-law graphs (a few rows have many entries)
21
22use oxiblas_core::scalar::{Field, Scalar};
23
/// Errors reported while building or converting an [`EllMatrix`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EllError {
    /// The `data` array does not have the requested `nrows × width` shape.
    InvalidDataDimensions {
        /// Number of rows that was requested.
        expected_rows: usize,
        /// Number of rows actually supplied.
        actual_rows: usize,
        /// Number of slots per row that was requested.
        expected_width: usize,
        /// Number of slots found in the offending row.
        actual_width: usize,
    },
    /// The `indices` array does not have the same shape as `data`.
    DimensionMismatch {
        /// Shape of `data` as (rows, width).
        data_dims: (usize, usize),
        /// Shape of `indices` as (rows, width of the offending row).
        indices_dims: (usize, usize),
    },
    /// A non-padding column index is not smaller than `ncols`.
    InvalidColumnIndex {
        /// Row containing the bad index.
        row: usize,
        /// Slot within that row.
        pos: usize,
        /// The offending column index.
        index: usize,
        /// Number of columns of the matrix.
        ncols: usize,
    },
    /// A source row holds more non-zeros than the requested width allows.
    TooManyNonZeros {
        /// Row that overflows.
        row: usize,
        /// Number of non-zeros in that row.
        nnz: usize,
        /// Largest number of non-zeros permitted per row.
        max_nnz: usize,
    },
}

impl core::fmt::Display for EllError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // All payload fields are `usize` (Copy), so bind them by value.
        match *self {
            Self::TooManyNonZeros { row, nnz, max_nnz } => {
                write!(f, "Row {row} has {nnz} non-zeros, exceeds max {max_nnz}")
            }
            Self::InvalidColumnIndex {
                row,
                pos,
                index,
                ncols,
            } => write!(
                f,
                "Invalid column index {index} at row {row}, position {pos} (ncols={ncols})"
            ),
            Self::DimensionMismatch {
                data_dims: (data_rows, data_width),
                indices_dims: (index_rows, index_width),
            } => write!(
                f,
                "Dimension mismatch: data is {}×{}, indices is {}×{}",
                data_rows, data_width, index_rows, index_width
            ),
            Self::InvalidDataDimensions {
                expected_rows,
                actual_rows,
                expected_width,
                actual_width,
            } => write!(
                f,
                "Invalid data dimensions: expected {expected_rows}×{expected_width}, got {actual_rows}×{actual_width}"
            ),
        }
    }
}

impl std::error::Error for EllError {}
110
/// ELLPACK sparse matrix format.
///
/// Efficient for:
/// - GPU computation
/// - Vectorized operations
/// - Matrices with uniform row lengths
///
/// # Storage
///
/// Each row stores exactly `width` entries. The `width` is typically the maximum
/// number of non-zeros in any row. Rows with fewer non-zeros are padded with
/// the sentinel column index `usize::MAX`; the matching data slot is ignored
/// (the converting constructors fill it with zero). Note that `ncols` is *not*
/// a valid padding index: [`EllMatrix::new`] rejects every non-sentinel index
/// that is `>= ncols`.
///
/// # Example
///
/// ```
/// use oxiblas_sparse::EllMatrix;
///
/// // Create a sparse matrix:
/// // [1 2 0 0]
/// // [0 3 4 0]
/// // [5 0 0 6]
/// let width = 2; // max 2 non-zeros per row
/// let data = vec![
///     vec![1.0, 2.0],  // row 0
///     vec![3.0, 4.0],  // row 1
///     vec![5.0, 6.0],  // row 2
/// ];
/// let indices = vec![
///     vec![0, 1],  // row 0
///     vec![1, 2],  // row 1
///     vec![0, 3],  // row 2
/// ];
///
/// let ell = EllMatrix::new(3, 4, width, data, indices).unwrap();
/// assert_eq!(ell.width(), 2);
/// ```
#[derive(Debug, Clone)]
pub struct EllMatrix<T: Scalar> {
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Maximum number of non-zeros per row (width of data/indices arrays).
    width: usize,
    /// Data array: data[row][k] is the k-th stored value in row (padding
    /// slots included).
    data: Vec<Vec<T>>,
    /// Column indices: indices[row][k] is the column of data[row][k], or
    /// `INVALID_INDEX` when that slot is padding.
    indices: Vec<Vec<usize>>,
}

/// Sentinel value for invalid/padding column indices.
///
/// Slots carrying this index are treated as padding and skipped by the
/// entry-walking methods (iteration, products, conversions).
const INVALID_INDEX: usize = usize::MAX;
164
165impl<T: Scalar + Clone> EllMatrix<T> {
166    /// Creates a new ELL matrix from raw components.
167    ///
168    /// # Arguments
169    ///
170    /// * `nrows` - Number of rows
171    /// * `ncols` - Number of columns
172    /// * `width` - Maximum non-zeros per row
173    /// * `data` - Data array (nrows × width)
174    /// * `indices` - Column indices array (nrows × width)
175    ///
176    /// # Errors
177    ///
178    /// Returns an error if the input is invalid.
179    pub fn new(
180        nrows: usize,
181        ncols: usize,
182        width: usize,
183        data: Vec<Vec<T>>,
184        indices: Vec<Vec<usize>>,
185    ) -> Result<Self, EllError> {
186        // Validate data dimensions
187        if data.len() != nrows {
188            return Err(EllError::InvalidDataDimensions {
189                expected_rows: nrows,
190                actual_rows: data.len(),
191                expected_width: width,
192                actual_width: if data.is_empty() { 0 } else { data[0].len() },
193            });
194        }
195
196        for (i, row) in data.iter().enumerate() {
197            if row.len() != width {
198                return Err(EllError::InvalidDataDimensions {
199                    expected_rows: nrows,
200                    actual_rows: data.len(),
201                    expected_width: width,
202                    actual_width: row.len(),
203                });
204            }
205
206            // Check corresponding indices row
207            if i < indices.len() && indices[i].len() != width {
208                return Err(EllError::DimensionMismatch {
209                    data_dims: (nrows, width),
210                    indices_dims: (indices.len(), indices[i].len()),
211                });
212            }
213        }
214
215        // Validate indices dimensions
216        if indices.len() != nrows {
217            return Err(EllError::DimensionMismatch {
218                data_dims: (nrows, width),
219                indices_dims: (
220                    indices.len(),
221                    if indices.is_empty() {
222                        0
223                    } else {
224                        indices[0].len()
225                    },
226                ),
227            });
228        }
229
230        // Validate column indices
231        for (row, row_indices) in indices.iter().enumerate() {
232            for (pos, &col) in row_indices.iter().enumerate() {
233                if col != INVALID_INDEX && col >= ncols {
234                    return Err(EllError::InvalidColumnIndex {
235                        row,
236                        pos,
237                        index: col,
238                        ncols,
239                    });
240                }
241            }
242        }
243
244        Ok(Self {
245            nrows,
246            ncols,
247            width,
248            data,
249            indices,
250        })
251    }
252
253    /// Creates an ELL matrix without validation (unsafe but faster).
254    ///
255    /// # Safety
256    ///
257    /// The caller must ensure:
258    /// - `data.len() == nrows` and each row has length `width`
259    /// - `indices.len() == nrows` and each row has length `width`
260    /// - All valid column indices are < ncols
261    #[inline]
262    pub unsafe fn new_unchecked(
263        nrows: usize,
264        ncols: usize,
265        width: usize,
266        data: Vec<Vec<T>>,
267        indices: Vec<Vec<usize>>,
268    ) -> Self {
269        Self {
270            nrows,
271            ncols,
272            width,
273            data,
274            indices,
275        }
276    }
277
278    /// Creates an empty ELL matrix with given dimensions.
279    pub fn zeros(nrows: usize, ncols: usize) -> Self {
280        Self {
281            nrows,
282            ncols,
283            width: 0,
284            data: vec![Vec::new(); nrows],
285            indices: vec![Vec::new(); nrows],
286        }
287    }
288
289    /// Creates an identity matrix in ELL format.
290    pub fn eye(n: usize) -> Self
291    where
292        T: Field,
293    {
294        Self {
295            nrows: n,
296            ncols: n,
297            width: 1,
298            data: (0..n).map(|_| vec![T::one()]).collect(),
299            indices: (0..n).map(|i| vec![i]).collect(),
300        }
301    }
302
303    /// Returns the number of rows.
304    #[inline]
305    pub fn nrows(&self) -> usize {
306        self.nrows
307    }
308
309    /// Returns the number of columns.
310    #[inline]
311    pub fn ncols(&self) -> usize {
312        self.ncols
313    }
314
315    /// Returns the shape (nrows, ncols).
316    #[inline]
317    pub fn shape(&self) -> (usize, usize) {
318        (self.nrows, self.ncols)
319    }
320
321    /// Returns the width (max non-zeros per row).
322    #[inline]
323    pub fn width(&self) -> usize {
324        self.width
325    }
326
327    /// Returns the number of non-zero elements.
328    ///
329    /// Note: This counts actual non-zeros, not stored values.
330    pub fn nnz(&self) -> usize
331    where
332        T: Field,
333    {
334        let eps = <T as Scalar>::epsilon();
335        let mut count = 0;
336
337        for (row, indices_row) in self.indices.iter().enumerate() {
338            for (k, &col) in indices_row.iter().enumerate() {
339                if col != INVALID_INDEX && Scalar::abs(self.data[row][k].clone()) > eps {
340                    count += 1;
341                }
342            }
343        }
344
345        count
346    }
347
348    /// Returns the total stored values (including padding).
349    #[inline]
350    pub fn nstored(&self) -> usize {
351        self.nrows * self.width
352    }
353
354    /// Returns the storage efficiency (nnz / nstored).
355    pub fn efficiency(&self) -> f64
356    where
357        T: Field,
358    {
359        if self.nstored() == 0 {
360            1.0
361        } else {
362            self.nnz() as f64 / self.nstored() as f64
363        }
364    }
365
366    /// Returns a reference to the data array.
367    #[inline]
368    pub fn data(&self) -> &[Vec<T>] {
369        &self.data
370    }
371
372    /// Returns a mutable reference to the data array.
373    #[inline]
374    pub fn data_mut(&mut self) -> &mut [Vec<T>] {
375        &mut self.data
376    }
377
378    /// Returns a reference to the indices array.
379    #[inline]
380    pub fn indices(&self) -> &[Vec<usize>] {
381        &self.indices
382    }
383
384    /// Gets the value at (row, col), returning None if not present.
385    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
386        if row >= self.nrows || col >= self.ncols {
387            return None;
388        }
389
390        for k in 0..self.width {
391            if self.indices[row][k] == col {
392                return Some(&self.data[row][k]);
393            }
394        }
395
396        None
397    }
398
399    /// Gets the value at (row, col), returning zero if not present.
400    pub fn get_or_zero(&self, row: usize, col: usize) -> T
401    where
402        T: Field,
403    {
404        self.get(row, col).cloned().unwrap_or_else(T::zero)
405    }
406
407    /// Returns an iterator over non-zeros in a row as (col, value).
408    pub fn row_iter(&self, row: usize) -> impl Iterator<Item = (usize, &T)> {
409        self.indices[row]
410            .iter()
411            .zip(self.data[row].iter())
412            .filter(|(col, _)| **col != INVALID_INDEX)
413            .map(|(col, val)| (*col, val))
414    }
415
416    /// Returns an iterator over all non-zeros as (row, col, value).
417    pub fn iter(&self) -> impl Iterator<Item = (usize, usize, &T)> + '_ {
418        (0..self.nrows).flat_map(move |row| {
419            self.indices[row]
420                .iter()
421                .zip(self.data[row].iter())
422                .filter(|(col, _)| **col != INVALID_INDEX)
423                .map(move |(col, val)| (row, *col, val))
424        })
425    }
426
427    /// Matrix-vector product: y = A * x.
428    pub fn matvec(&self, x: &[T], y: &mut [T])
429    where
430        T: Field,
431    {
432        assert_eq!(x.len(), self.ncols, "x length must equal ncols");
433        assert_eq!(y.len(), self.nrows, "y length must equal nrows");
434
435        for row in 0..self.nrows {
436            let mut sum = T::zero();
437            for k in 0..self.width {
438                let col = self.indices[row][k];
439                if col != INVALID_INDEX {
440                    sum = sum + self.data[row][k].clone() * x[col].clone();
441                }
442            }
443            y[row] = sum;
444        }
445    }
446
447    /// Matrix-vector product returning a new vector: y = A * x.
448    pub fn mul_vec(&self, x: &[T]) -> Vec<T>
449    where
450        T: Field,
451    {
452        let mut y = vec![T::zero(); self.nrows];
453        self.matvec(x, &mut y);
454        y
455    }
456
457    /// Converts to CSR format.
458    pub fn to_csr(&self) -> crate::csr::CsrMatrix<T>
459    where
460        T: Field,
461    {
462        let eps = <T as Scalar>::epsilon();
463
464        let mut row_ptrs = vec![0usize; self.nrows + 1];
465        let mut col_indices = Vec::new();
466        let mut values = Vec::new();
467
468        for row in 0..self.nrows {
469            let mut row_entries: Vec<(usize, T)> = Vec::new();
470
471            for k in 0..self.width {
472                let col = self.indices[row][k];
473                if col != INVALID_INDEX {
474                    let val = self.data[row][k].clone();
475                    if Scalar::abs(val.clone()) > eps {
476                        row_entries.push((col, val));
477                    }
478                }
479            }
480
481            // Sort by column index
482            row_entries.sort_by_key(|(col, _)| *col);
483
484            for (col, val) in row_entries {
485                col_indices.push(col);
486                values.push(val);
487            }
488            row_ptrs[row + 1] = values.len();
489        }
490
491        // Safety: we constructed valid CSR data
492        unsafe {
493            crate::csr::CsrMatrix::new_unchecked(
494                self.nrows,
495                self.ncols,
496                row_ptrs,
497                col_indices,
498                values,
499            )
500        }
501    }
502
503    /// Converts to dense matrix.
504    pub fn to_dense(&self) -> oxiblas_matrix::Mat<T>
505    where
506        T: Field + bytemuck::Zeroable,
507    {
508        let mut dense = oxiblas_matrix::Mat::zeros(self.nrows, self.ncols);
509
510        for row in 0..self.nrows {
511            for k in 0..self.width {
512                let col = self.indices[row][k];
513                if col != INVALID_INDEX {
514                    dense[(row, col)] = self.data[row][k].clone();
515                }
516            }
517        }
518
519        dense
520    }
521
522    /// Creates an ELL matrix from a dense matrix.
523    ///
524    /// # Arguments
525    ///
526    /// * `dense` - Source dense matrix
527    /// * `max_width` - Maximum width (if None, uses actual max non-zeros per row)
528    pub fn from_dense(dense: &oxiblas_matrix::MatRef<'_, T>, max_width: Option<usize>) -> Self
529    where
530        T: Field,
531    {
532        let (nrows, ncols) = dense.shape();
533        let eps = <T as Scalar>::epsilon();
534
535        // First pass: find max non-zeros per row
536        let mut row_nnz = vec![0usize; nrows];
537        for i in 0..nrows {
538            for j in 0..ncols {
539                if Scalar::abs(dense[(i, j)].clone()) > eps {
540                    row_nnz[i] += 1;
541                }
542            }
543        }
544
545        let width = max_width.unwrap_or_else(|| row_nnz.iter().copied().max().unwrap_or(0));
546
547        // Build data and indices
548        let mut data = Vec::with_capacity(nrows);
549        let mut indices = Vec::with_capacity(nrows);
550
551        for i in 0..nrows {
552            let mut row_data = Vec::with_capacity(width);
553            let mut row_indices = Vec::with_capacity(width);
554
555            for j in 0..ncols {
556                if row_data.len() >= width {
557                    break;
558                }
559                let val = dense[(i, j)].clone();
560                if Scalar::abs(val.clone()) > eps {
561                    row_data.push(val);
562                    row_indices.push(j);
563                }
564            }
565
566            // Pad to width
567            while row_data.len() < width {
568                row_data.push(T::zero());
569                row_indices.push(INVALID_INDEX);
570            }
571
572            data.push(row_data);
573            indices.push(row_indices);
574        }
575
576        // Safety: we constructed valid ELL data
577        unsafe { Self::new_unchecked(nrows, ncols, width, data, indices) }
578    }
579
580    /// Creates an ELL matrix from CSR format.
581    ///
582    /// # Arguments
583    ///
584    /// * `csr` - Source CSR matrix
585    /// * `max_width` - Maximum width (if None, uses actual max non-zeros per row)
586    pub fn from_csr(
587        csr: &crate::csr::CsrMatrix<T>,
588        max_width: Option<usize>,
589    ) -> Result<Self, EllError>
590    where
591        T: Field,
592    {
593        let (nrows, ncols) = csr.shape();
594        let row_ptrs = csr.row_ptrs();
595        let csr_indices = csr.col_indices();
596        let csr_values = csr.values();
597
598        // Find max row width
599        let actual_max: usize = (0..nrows)
600            .map(|i| row_ptrs[i + 1] - row_ptrs[i])
601            .max()
602            .unwrap_or(0);
603
604        let width = max_width.unwrap_or(actual_max);
605
606        // Check if any row exceeds max_width
607        if let Some(max_w) = max_width {
608            for row in 0..nrows {
609                let row_nnz = row_ptrs[row + 1] - row_ptrs[row];
610                if row_nnz > max_w {
611                    return Err(EllError::TooManyNonZeros {
612                        row,
613                        nnz: row_nnz,
614                        max_nnz: max_w,
615                    });
616                }
617            }
618        }
619
620        // Build data and indices
621        let mut data = Vec::with_capacity(nrows);
622        let mut indices = Vec::with_capacity(nrows);
623
624        for row in 0..nrows {
625            let start = row_ptrs[row];
626            let end = row_ptrs[row + 1];
627            let row_nnz = end - start;
628
629            let mut row_data = Vec::with_capacity(width);
630            let mut row_indices = Vec::with_capacity(width);
631
632            for k in 0..row_nnz {
633                row_data.push(csr_values[start + k].clone());
634                row_indices.push(csr_indices[start + k]);
635            }
636
637            // Pad to width
638            while row_data.len() < width {
639                row_data.push(T::zero());
640                row_indices.push(INVALID_INDEX);
641            }
642
643            data.push(row_data);
644            indices.push(row_indices);
645        }
646
647        Ok(Self {
648            nrows,
649            ncols,
650            width,
651            data,
652            indices,
653        })
654    }
655
656    /// Scales all values by a scalar.
657    pub fn scale(&mut self, alpha: T) {
658        for row in &mut self.data {
659            for val in row.iter_mut() {
660                *val = val.clone() * alpha.clone();
661            }
662        }
663    }
664
665    /// Returns a scaled copy of this matrix.
666    pub fn scaled(&self, alpha: T) -> Self {
667        let mut result = self.clone();
668        result.scale(alpha);
669        result
670    }
671
672    /// Returns the transpose of this matrix.
673    ///
674    /// Note: This is less efficient than CSR/CSC transpose.
675    pub fn transpose(&self) -> Self
676    where
677        T: Field,
678    {
679        // Convert to CSR, transpose, convert back
680        let csr = self.to_csr();
681        let csr_t = csr.transpose();
682        Self::from_csr(&csr_t, Some(self.width)).unwrap_or_else(|_| {
683            // If width is insufficient, use actual max
684            Self::from_csr(&csr_t, None).expect("CSR transpose should be valid")
685        })
686    }
687}
688
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ell_new() {
        // [1 2 0 0]
        // [0 3 4 0]
        // [5 0 0 6]
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let indices = vec![vec![0, 1], vec![1, 2], vec![0, 3]];

        let ell = EllMatrix::new(3, 4, 2, data, indices).unwrap();

        assert_eq!(ell.nrows(), 3);
        assert_eq!(ell.ncols(), 4);
        assert_eq!(ell.width(), 2);
    }

    #[test]
    fn test_ell_new_wrong_row_count() {
        let result = EllMatrix::new(3, 3, 1, vec![vec![1.0]], vec![vec![0]]);
        assert_eq!(
            result.unwrap_err(),
            EllError::InvalidDataDimensions {
                expected_rows: 3,
                actual_rows: 1,
                expected_width: 1,
                actual_width: 1,
            }
        );
    }

    #[test]
    fn test_ell_new_wrong_data_width() {
        let result = EllMatrix::new(
            2,
            3,
            2,
            vec![vec![1.0, 2.0], vec![3.0]],
            vec![vec![0, 1], vec![0, 1]],
        );
        assert_eq!(
            result.unwrap_err(),
            EllError::InvalidDataDimensions {
                expected_rows: 2,
                actual_rows: 2,
                expected_width: 2,
                actual_width: 1,
            }
        );
    }

    #[test]
    fn test_ell_new_indices_mismatch() {
        // An index row that is too short.
        let result = EllMatrix::new(
            2,
            3,
            2,
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            vec![vec![0, 1], vec![0]],
        );
        assert_eq!(
            result.unwrap_err(),
            EllError::DimensionMismatch {
                data_dims: (2, 2),
                indices_dims: (2, 1),
            }
        );

        // Too few index rows.
        let result = EllMatrix::new(
            2,
            3,
            2,
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            vec![vec![0, 1]],
        );
        assert_eq!(
            result.unwrap_err(),
            EllError::DimensionMismatch {
                data_dims: (2, 2),
                indices_dims: (1, 2),
            }
        );
    }

    #[test]
    fn test_ell_get() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let indices = vec![vec![0, 1], vec![1, 2], vec![0, 3]];

        let ell = EllMatrix::new(3, 4, 2, data, indices).unwrap();

        assert_eq!(ell.get(0, 0), Some(&1.0));
        assert_eq!(ell.get(0, 1), Some(&2.0));
        assert_eq!(ell.get(1, 1), Some(&3.0));
        assert_eq!(ell.get(1, 2), Some(&4.0));
        assert_eq!(ell.get(2, 0), Some(&5.0));
        assert_eq!(ell.get(2, 3), Some(&6.0));

        // Zero elements
        assert_eq!(ell.get(0, 2), None);
        assert_eq!(ell.get(0, 3), None);

        // Out of bounds
        assert_eq!(ell.get(3, 0), None);
        assert_eq!(ell.get(0, 4), None);
    }

    #[test]
    fn test_ell_matvec() {
        // [1 2 0 0]   [1]   [3]
        // [0 3 4 0] * [1] = [7]
        // [5 0 0 6]   [1]   [11]
        //             [1]
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let indices = vec![vec![0, 1], vec![1, 2], vec![0, 3]];

        let ell = EllMatrix::new(3, 4, 2, data, indices).unwrap();
        let x = vec![1.0, 1.0, 1.0, 1.0];
        let y = ell.mul_vec(&x);

        assert!((y[0] - 3.0).abs() < 1e-10);
        assert!((y[1] - 7.0).abs() < 1e-10);
        assert!((y[2] - 11.0).abs() < 1e-10);
    }

    #[test]
    fn test_ell_matvec_matches_dense() {
        let data = vec![
            vec![1.0, 0.0, 0.0],
            vec![2.0, 3.0, 4.0],
            vec![5.0, 0.0, 0.0],
        ];
        let indices = vec![
            vec![0, INVALID_INDEX, INVALID_INDEX],
            vec![0, 1, 2],
            vec![1, INVALID_INDEX, INVALID_INDEX],
        ];
        let ell = EllMatrix::new(3, 3, 3, data, indices).unwrap();

        let x = vec![1.0, -2.0, 0.5];
        let y = ell.mul_vec(&x);
        let dense = ell.to_dense();

        for i in 0..3 {
            let expected: f64 = (0..3).map(|j| dense[(i, j)] * x[j]).sum();
            assert!((y[i] - expected).abs() < 1e-12);
        }
    }

    #[test]
    #[should_panic(expected = "x length must equal ncols")]
    fn test_ell_matvec_wrong_x_len() {
        let ell: EllMatrix<f64> = EllMatrix::eye(3);
        let _ = ell.mul_vec(&[1.0, 2.0]);
    }

    #[test]
    fn test_ell_with_padding() {
        // [1 0 0]
        // [2 3 4]
        // [0 5 0]
        let data = vec![
            vec![1.0, 0.0, 0.0], // row 0: 1 value, padded
            vec![2.0, 3.0, 4.0], // row 1: 3 values
            vec![5.0, 0.0, 0.0], // row 2: 1 value, padded
        ];
        let indices = vec![
            vec![0, INVALID_INDEX, INVALID_INDEX],
            vec![0, 1, 2],
            vec![1, INVALID_INDEX, INVALID_INDEX],
        ];

        let ell = EllMatrix::new(3, 3, 3, data, indices).unwrap();

        assert_eq!(ell.nnz(), 5);
        assert_eq!(ell.nstored(), 9);
        assert!((ell.efficiency() - 5.0 / 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_ell_eye() {
        let ell: EllMatrix<f64> = EllMatrix::eye(4);

        assert_eq!(ell.nrows(), 4);
        assert_eq!(ell.ncols(), 4);
        assert_eq!(ell.width(), 1);

        for i in 0..4 {
            assert_eq!(ell.get(i, i), Some(&1.0));
        }
    }

    #[test]
    fn test_ell_to_dense() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let indices = vec![vec![0, 1], vec![1, 2], vec![0, 3]];

        let ell = EllMatrix::new(3, 4, 2, data, indices).unwrap();
        let dense = ell.to_dense();

        assert!((dense[(0, 0)] - 1.0).abs() < 1e-10);
        assert!((dense[(0, 1)] - 2.0).abs() < 1e-10);
        assert!((dense[(0, 2)] - 0.0).abs() < 1e-10);
        assert!((dense[(1, 1)] - 3.0).abs() < 1e-10);
        assert!((dense[(1, 2)] - 4.0).abs() < 1e-10);
        assert!((dense[(2, 0)] - 5.0).abs() < 1e-10);
        assert!((dense[(2, 3)] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_ell_to_csr() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let indices = vec![vec![0, 1], vec![1, 2], vec![0, 3]];

        let ell = EllMatrix::new(3, 4, 2, data, indices).unwrap();
        let csr = ell.to_csr();

        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.ncols(), 4);
        assert_eq!(csr.nnz(), 6);
        assert_eq!(csr.get(0, 0), Some(&1.0));
        assert_eq!(csr.get(2, 3), Some(&6.0));
    }

    #[test]
    fn test_ell_to_csr_drops_explicit_zeros() {
        let data = vec![vec![1.0, 0.0], vec![0.0, 4.0]];
        let indices = vec![vec![0, 1], vec![0, 1]];

        let ell = EllMatrix::new(2, 2, 2, data, indices).unwrap();
        assert_eq!(ell.nnz(), 2);

        let csr = ell.to_csr();
        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(1, 1), Some(&4.0));
    }

    #[test]
    fn test_ell_from_dense() {
        use oxiblas_matrix::Mat;

        let dense = Mat::from_rows(&[
            &[1.0f64, 2.0, 0.0, 0.0],
            &[0.0, 3.0, 4.0, 0.0],
            &[5.0, 0.0, 0.0, 6.0],
        ]);

        let ell = EllMatrix::from_dense(&dense.as_ref(), None);

        assert_eq!(ell.width(), 2);
        assert_eq!(ell.get(0, 0), Some(&1.0));
        assert_eq!(ell.get(1, 2), Some(&4.0));
    }

    #[test]
    fn test_ell_from_dense_truncates_to_max_width() {
        use oxiblas_matrix::Mat;

        let dense = Mat::from_rows(&[&[1.0f64, 2.0, 0.0], &[0.0, 0.0, 3.0]]);

        let ell = EllMatrix::from_dense(&dense.as_ref(), Some(1));

        assert_eq!(ell.width(), 1);
        assert_eq!(ell.get(0, 0), Some(&1.0));
        // The second non-zero of row 0 does not fit and is dropped.
        assert_eq!(ell.get(0, 1), None);
        assert_eq!(ell.get(1, 2), Some(&3.0));
    }

    #[test]
    fn test_ell_from_csr() {
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let col_indices = vec![0, 1, 1, 2, 0, 3];
        let row_ptrs = vec![0, 2, 4, 6];

        let csr = crate::csr::CsrMatrix::new(3, 4, row_ptrs, col_indices, values).unwrap();
        let ell = EllMatrix::from_csr(&csr, None).unwrap();

        assert_eq!(ell.width(), 2);
        assert_eq!(ell.get(0, 0), Some(&1.0));
        assert_eq!(ell.get(2, 3), Some(&6.0));
    }

    #[test]
    fn test_ell_from_csr_width_limits() {
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let col_indices = vec![0, 1, 1, 2, 0, 3];
        let row_ptrs = vec![0, 2, 4, 6];

        let csr = crate::csr::CsrMatrix::new(3, 4, row_ptrs, col_indices, values).unwrap();

        let err = EllMatrix::from_csr(&csr, Some(1)).unwrap_err();
        assert_eq!(
            err,
            EllError::TooManyNonZeros {
                row: 0,
                nnz: 2,
                max_nnz: 1,
            }
        );

        // A wider-than-needed width pads every row.
        let ell = EllMatrix::from_csr(&csr, Some(3)).unwrap();
        assert_eq!(ell.width(), 3);
        assert_eq!(ell.nstored(), 9);
        assert_eq!(ell.get(1, 2), Some(&4.0));
        assert_eq!(ell.row_iter(1).count(), 2);
    }

    #[test]
    fn test_ell_scale() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let indices = vec![vec![0, 1], vec![0, 1]];

        let mut ell = EllMatrix::new(2, 2, 2, data, indices).unwrap();
        ell.scale(2.0);

        assert_eq!(ell.get(0, 0), Some(&2.0));
        assert_eq!(ell.get(0, 1), Some(&4.0));
    }

    #[test]
    fn test_ell_transpose() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 0.0]];
        let indices = vec![vec![0, 1], vec![0, INVALID_INDEX]];

        let ell = EllMatrix::new(2, 2, 2, data, indices).unwrap();
        let ell_t = ell.transpose();

        let dense = ell.to_dense();
        let dense_t = ell_t.to_dense();

        for i in 0..2 {
            for j in 0..2 {
                assert!((dense[(i, j)] - dense_t[(j, i)]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_ell_row_iter() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let indices = vec![vec![0, 2], vec![1, INVALID_INDEX]];

        let ell = EllMatrix::new(2, 3, 2, data, indices).unwrap();

        let row0: Vec<_> = ell.row_iter(0).collect();
        assert_eq!(row0, vec![(0, &1.0), (2, &2.0)]);

        let row1: Vec<_> = ell.row_iter(1).collect();
        assert_eq!(row1, vec![(1, &3.0)]);
    }

    #[test]
    fn test_ell_iter_skips_padding() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let indices = vec![vec![0, 2], vec![1, INVALID_INDEX]];

        let ell = EllMatrix::new(2, 3, 2, data, indices).unwrap();

        let entries: Vec<_> = ell.iter().collect();
        assert_eq!(entries, vec![(0, 0, &1.0), (0, 2, &2.0), (1, 1, &3.0)]);
        assert_eq!(ell.get_or_zero(1, 2), 0.0);
    }

    #[test]
    fn test_ell_invalid_column_index() {
        let data = vec![vec![1.0]];
        let indices = vec![vec![10]]; // Out of bounds

        let result = EllMatrix::new(1, 3, 1, data, indices);
        assert!(matches!(result, Err(EllError::InvalidColumnIndex { .. })));
    }

    #[test]
    fn test_ell_error_display() {
        let err = EllError::TooManyNonZeros {
            row: 1,
            nnz: 5,
            max_nnz: 3,
        };
        assert_eq!(err.to_string(), "Row 1 has 5 non-zeros, exceeds max 3");
    }

    #[test]
    fn test_ell_zeros() {
        let ell: EllMatrix<f64> = EllMatrix::zeros(5, 3);

        assert_eq!(ell.nrows(), 5);
        assert_eq!(ell.ncols(), 3);
        assert_eq!(ell.width(), 0);
        assert_eq!(ell.nnz(), 0);
    }
}