oxiblas_sparse/
hyb.rs

1//! Hybrid ELL+COO (HYB) matrix format.
2//!
3//! HYB stores sparse matrices using a combination of:
//! - **ELL part**: up to K non-zeros of every row, kept in K fixed slots per row (K is the ELL width)
//! - **COO part**: the overflow entries of rows that have more than K non-zeros
6//!
7//! This hybrid approach provides:
8//! - Efficient vectorized operations for uniform parts (ELL)
9//! - Flexibility for rows with many entries (COO overflow)
10//! - Better GPU performance than pure ELL for irregular matrices
11//!
12//! # When to Use HYB
13//!
14//! HYB format is optimal for:
15//! - Matrices with mostly uniform row lengths but some outliers
16//! - GPU computation where ELL alone wastes too much memory
17//! - Power-law degree distributions (social networks, web graphs)
18//!
19//! # K Selection
20//!
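//! The ELL width K is chosen by a [`HybWidthStrategy`]:
//! - Mean-based: the mean row length, or the mean plus k standard deviations (the default, k = 1)
//! - Fixed: a user-specified width
//! - Order statistics: the median, a chosen percentile, or the maximum row length (pure ELL)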
25
26use crate::coo::CooMatrix;
27use crate::csr::CsrMatrix;
28use crate::ell::EllMatrix;
29use oxiblas_core::scalar::{Field, Scalar};
30
/// Errors reported when building or combining HYB matrices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HybError {
    /// The requested shape, or the buffers supplied for it, are not usable.
    InvalidDimensions {
        /// Row count that was requested.
        nrows: usize,
        /// Column count that was requested.
        ncols: usize,
    },
    /// An ELL width of zero was supplied where a positive width is required.
    ZeroEllWidth,
    /// The ELL and COO parts describe matrices of different shapes.
    IncompatibleParts {
        /// `(rows, cols)` of the ELL part.
        ell_shape: (usize, usize),
        /// `(rows, cols)` of the COO part.
        coo_shape: (usize, usize),
    },
}
51
52impl core::fmt::Display for HybError {
53    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
54        match self {
55            Self::InvalidDimensions { nrows, ncols } => {
56                write!(f, "Invalid dimensions: {nrows}×{ncols}")
57            }
58            Self::ZeroEllWidth => {
59                write!(f, "ELL width cannot be zero")
60            }
61            Self::IncompatibleParts {
62                ell_shape,
63                coo_shape,
64            } => {
65                write!(
66                    f,
67                    "ELL shape {:?} incompatible with COO shape {:?}",
68                    ell_shape, coo_shape
69                )
70            }
71        }
72    }
73}
74
75impl std::error::Error for HybError {}
76
/// How the ELL width `K` is derived from the per-row non-zero counts when converting to HYB.
///
/// Entries of a row beyond the first `K` are stored in the COO overflow part.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum HybWidthStrategy {
    /// Use exactly this width.
    Fixed(usize),
    /// Mean row length, rounded up.
    Mean,
    /// Mean row length plus `k` population standard deviations, rounded up.
    MeanPlusStddev(f64),
    /// Median row length; with an even row count, the two middle values are averaged
    /// and rounded up.
    Median,
    /// Row length at fraction `p` of the sorted lengths (`p` is clamped to 0.0..=1.0).
    Percentile(f64),
    /// Longest row, so nothing overflows into COO (equivalent to pure ELL).
    Max,
}
93
94impl Default for HybWidthStrategy {
95    fn default() -> Self {
96        // Default: mean + 1 stddev catches ~84% of rows in ELL
97        Self::MeanPlusStddev(1.0)
98    }
99}
100
101/// Hybrid ELL+COO matrix format.
102///
103/// Combines ELLPACK for regular entries with COO for overflow.
104///
105/// # Storage
106///
107/// - ELL part: Fixed-width storage for up to K entries per row
108/// - COO part: Overflow entries beyond K per row
109///
110/// # Example
111///
112/// ```
113/// use oxiblas_sparse::{CsrMatrix, HybMatrix, HybWidthStrategy};
114///
115/// // Create a sparse matrix
116/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
117/// let col_indices = vec![0, 2, 1, 0, 2];
118/// let row_ptrs = vec![0, 2, 3, 5];
119/// let csr = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();
120///
121/// // Convert to HYB with automatic width selection
122/// let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::default());
123/// assert_eq!(hyb.shape(), (3, 3));
124/// ```
125#[derive(Debug, Clone)]
126pub struct HybMatrix<T: Scalar> {
127    /// Number of rows.
128    nrows: usize,
129    /// Number of columns.
130    ncols: usize,
131    /// ELL width (max entries per row in ELL part).
132    ell_width: usize,
133    /// ELL data: nrows × ell_width values (row-major).
134    ell_data: Vec<T>,
135    /// ELL column indices: nrows × ell_width indices.
136    ell_indices: Vec<usize>,
137    /// COO row indices for overflow.
138    coo_rows: Vec<usize>,
139    /// COO column indices for overflow.
140    coo_cols: Vec<usize>,
141    /// COO values for overflow.
142    coo_data: Vec<T>,
143}
144
145impl<T: Scalar + Clone> HybMatrix<T> {
146    /// Creates a new HYB matrix from raw components.
147    ///
148    /// # Arguments
149    ///
150    /// * `nrows` - Number of rows
151    /// * `ncols` - Number of columns
152    /// * `ell_width` - Maximum entries per row in ELL part
153    /// * `ell_data` - ELL values (nrows × ell_width)
154    /// * `ell_indices` - ELL column indices (nrows × ell_width)
155    /// * `coo_rows` - COO row indices
156    /// * `coo_cols` - COO column indices
157    /// * `coo_data` - COO values
158    ///
159    /// # Errors
160    ///
161    /// Returns an error if the input is invalid.
162    #[allow(clippy::too_many_arguments)]
163    pub fn new(
164        nrows: usize,
165        ncols: usize,
166        ell_width: usize,
167        ell_data: Vec<T>,
168        ell_indices: Vec<usize>,
169        coo_rows: Vec<usize>,
170        coo_cols: Vec<usize>,
171        coo_data: Vec<T>,
172    ) -> Result<Self, HybError> {
173        if nrows == 0 || ncols == 0 {
174            return Err(HybError::InvalidDimensions { nrows, ncols });
175        }
176
177        let expected_ell_size = nrows * ell_width;
178        if ell_data.len() != expected_ell_size || ell_indices.len() != expected_ell_size {
179            return Err(HybError::InvalidDimensions { nrows, ncols });
180        }
181
182        if coo_rows.len() != coo_cols.len() || coo_rows.len() != coo_data.len() {
183            return Err(HybError::InvalidDimensions { nrows, ncols });
184        }
185
186        Ok(Self {
187            nrows,
188            ncols,
189            ell_width,
190            ell_data,
191            ell_indices,
192            coo_rows,
193            coo_cols,
194            coo_data,
195        })
196    }
197
198    /// Creates an empty HYB matrix.
199    pub fn zeros(nrows: usize, ncols: usize, ell_width: usize) -> Self
200    where
201        T: Field,
202    {
203        let size = nrows * ell_width;
204        Self {
205            nrows,
206            ncols,
207            ell_width,
208            ell_data: vec![T::zero(); size],
209            ell_indices: vec![0; size],
210            coo_rows: Vec::new(),
211            coo_cols: Vec::new(),
212            coo_data: Vec::new(),
213        }
214    }
215
216    /// Creates an identity matrix in HYB format.
217    pub fn eye(n: usize) -> Self
218    where
219        T: Field,
220    {
221        let ell_width = 1;
222        let mut ell_data = Vec::with_capacity(n);
223        let mut ell_indices = Vec::with_capacity(n);
224
225        for i in 0..n {
226            ell_data.push(T::one());
227            ell_indices.push(i);
228        }
229
230        Self {
231            nrows: n,
232            ncols: n,
233            ell_width,
234            ell_data,
235            ell_indices,
236            coo_rows: Vec::new(),
237            coo_cols: Vec::new(),
238            coo_data: Vec::new(),
239        }
240    }
241
242    /// Returns the number of rows.
243    #[inline]
244    pub fn nrows(&self) -> usize {
245        self.nrows
246    }
247
248    /// Returns the number of columns.
249    #[inline]
250    pub fn ncols(&self) -> usize {
251        self.ncols
252    }
253
254    /// Returns the shape (nrows, ncols).
255    #[inline]
256    pub fn shape(&self) -> (usize, usize) {
257        (self.nrows, self.ncols)
258    }
259
260    /// Returns the ELL width.
261    #[inline]
262    pub fn ell_width(&self) -> usize {
263        self.ell_width
264    }
265
266    /// Returns the number of non-zeros in the ELL part.
267    pub fn ell_nnz(&self) -> usize
268    where
269        T: Field,
270    {
271        let eps = <T as Scalar>::epsilon();
272        self.ell_data
273            .iter()
274            .filter(|v| Scalar::abs((*v).clone()) > eps)
275            .count()
276    }
277
278    /// Returns the number of non-zeros in the COO part.
279    #[inline]
280    pub fn coo_nnz(&self) -> usize {
281        self.coo_data.len()
282    }
283
284    /// Returns the total number of non-zeros.
285    pub fn nnz(&self) -> usize
286    where
287        T: Field,
288    {
289        self.ell_nnz() + self.coo_nnz()
290    }
291
292    /// Returns the ELL data.
293    #[inline]
294    pub fn ell_data(&self) -> &[T] {
295        &self.ell_data
296    }
297
298    /// Returns the ELL column indices.
299    #[inline]
300    pub fn ell_indices(&self) -> &[usize] {
301        &self.ell_indices
302    }
303
304    /// Returns the COO row indices.
305    #[inline]
306    pub fn coo_rows(&self) -> &[usize] {
307        &self.coo_rows
308    }
309
310    /// Returns the COO column indices.
311    #[inline]
312    pub fn coo_cols(&self) -> &[usize] {
313        &self.coo_cols
314    }
315
316    /// Returns the COO data.
317    #[inline]
318    pub fn coo_data(&self) -> &[T] {
319        &self.coo_data
320    }
321
322    /// Returns the fraction of entries stored in ELL.
323    pub fn ell_fraction(&self) -> f64
324    where
325        T: Field,
326    {
327        let total = self.nnz();
328        if total == 0 {
329            return 1.0;
330        }
331        self.ell_nnz() as f64 / total as f64
332    }
333
334    /// Returns the storage efficiency (actual nnz / stored values).
335    pub fn storage_efficiency(&self) -> f64
336    where
337        T: Field,
338    {
339        let nnz = self.nnz();
340        if nnz == 0 {
341            return 0.0;
342        }
343        let stored = self.nrows * self.ell_width + self.coo_data.len();
344        nnz as f64 / stored as f64
345    }
346
347    /// Gets the value at (row, col).
348    pub fn get(&self, row: usize, col: usize) -> Option<T>
349    where
350        T: Field,
351    {
352        if row >= self.nrows || col >= self.ncols {
353            return None;
354        }
355
356        let eps = <T as Scalar>::epsilon();
357
358        // Check ELL part
359        let ell_start = row * self.ell_width;
360        for k in 0..self.ell_width {
361            let idx = ell_start + k;
362            if self.ell_indices[idx] == col {
363                let val = self.ell_data[idx].clone();
364                if Scalar::abs(val.clone()) > eps {
365                    return Some(val);
366                }
367            }
368        }
369
370        // Check COO part
371        for (i, (&r, &c)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
372            if r == row && c == col {
373                return Some(self.coo_data[i].clone());
374            }
375        }
376
377        None
378    }
379
380    /// Gets the value at (row, col), returning zero if not present.
381    pub fn get_or_zero(&self, row: usize, col: usize) -> T
382    where
383        T: Field,
384    {
385        self.get(row, col).unwrap_or_else(T::zero)
386    }
387
388    /// Matrix-vector product: y = A * x.
389    pub fn matvec(&self, x: &[T], y: &mut [T])
390    where
391        T: Field,
392    {
393        assert_eq!(x.len(), self.ncols, "x length must equal ncols");
394        assert_eq!(y.len(), self.nrows, "y length must equal nrows");
395
396        let eps = <T as Scalar>::epsilon();
397
398        // Initialize y to zero
399        for yi in y.iter_mut() {
400            *yi = T::zero();
401        }
402
403        // ELL part
404        for row in 0..self.nrows {
405            let ell_start = row * self.ell_width;
406            for k in 0..self.ell_width {
407                let idx = ell_start + k;
408                let val = &self.ell_data[idx];
409                if Scalar::abs(val.clone()) > eps {
410                    let col = self.ell_indices[idx];
411                    y[row] = y[row].clone() + val.clone() * x[col].clone();
412                }
413            }
414        }
415
416        // COO part
417        for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
418            y[row] = y[row].clone() + self.coo_data[i].clone() * x[col].clone();
419        }
420    }
421
422    /// Matrix-vector product returning a new vector.
423    pub fn mul_vec(&self, x: &[T]) -> Vec<T>
424    where
425        T: Field,
426    {
427        let mut y = vec![T::zero(); self.nrows];
428        self.matvec(x, &mut y);
429        y
430    }
431
432    /// Transposed matrix-vector product: y = A^T * x.
433    pub fn matvec_transpose(&self, x: &[T], y: &mut [T])
434    where
435        T: Field,
436    {
437        assert_eq!(x.len(), self.nrows, "x length must equal nrows");
438        assert_eq!(y.len(), self.ncols, "y length must equal ncols");
439
440        let eps = <T as Scalar>::epsilon();
441
442        // Initialize y to zero
443        for yi in y.iter_mut() {
444            *yi = T::zero();
445        }
446
447        // ELL part
448        for row in 0..self.nrows {
449            let ell_start = row * self.ell_width;
450            for k in 0..self.ell_width {
451                let idx = ell_start + k;
452                let val = &self.ell_data[idx];
453                if Scalar::abs(val.clone()) > eps {
454                    let col = self.ell_indices[idx];
455                    y[col] = y[col].clone() + val.clone() * x[row].clone();
456                }
457            }
458        }
459
460        // COO part
461        for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
462            y[col] = y[col].clone() + self.coo_data[i].clone() * x[row].clone();
463        }
464    }
465
466    /// Creates HYB matrix from CSR with automatic width selection.
467    pub fn from_csr(csr: &CsrMatrix<T>, strategy: HybWidthStrategy) -> Self
468    where
469        T: Field,
470    {
471        let (nrows, ncols) = csr.shape();
472        let eps = <T as Scalar>::epsilon();
473
474        // Compute row lengths (actual non-zeros)
475        let mut row_lengths: Vec<usize> = Vec::with_capacity(nrows);
476        for row in 0..nrows {
477            let mut count = 0;
478            for (_, val) in csr.row_iter(row) {
479                if Scalar::abs(val.clone()) > eps {
480                    count += 1;
481                }
482            }
483            row_lengths.push(count);
484        }
485
486        // Determine ELL width based on strategy
487        let ell_width = compute_ell_width(&row_lengths, strategy);
488        let ell_width = ell_width.max(1); // At least 1
489
490        // Build HYB from CSR
491        let mut ell_data = vec![T::zero(); nrows * ell_width];
492        let mut ell_indices = vec![0usize; nrows * ell_width];
493        let mut coo_rows = Vec::new();
494        let mut coo_cols = Vec::new();
495        let mut coo_data = Vec::new();
496
497        for row in 0..nrows {
498            let ell_start = row * ell_width;
499            let mut ell_count = 0;
500
501            for (col, val) in csr.row_iter(row) {
502                if Scalar::abs(val.clone()) <= eps {
503                    continue;
504                }
505
506                if ell_count < ell_width {
507                    ell_data[ell_start + ell_count] = val.clone();
508                    ell_indices[ell_start + ell_count] = col;
509                    ell_count += 1;
510                } else {
511                    // Overflow to COO
512                    coo_rows.push(row);
513                    coo_cols.push(col);
514                    coo_data.push(val.clone());
515                }
516            }
517        }
518
519        Self {
520            nrows,
521            ncols,
522            ell_width,
523            ell_data,
524            ell_indices,
525            coo_rows,
526            coo_cols,
527            coo_data,
528        }
529    }
530
531    /// Creates HYB matrix from COO format.
532    pub fn from_coo(coo: &CooMatrix<T>, strategy: HybWidthStrategy) -> Self
533    where
534        T: Scalar<Real = T> + Field + oxiblas_core::Real,
535    {
536        let csr = coo.to_csr();
537        Self::from_csr(&csr, strategy)
538    }
539
540    /// Creates HYB matrix from ELL format (no COO overflow).
541    pub fn from_ell(ell: &EllMatrix<T>) -> Self
542    where
543        T: Field,
544    {
545        let (nrows, ncols) = ell.shape();
546        let ell_width = ell.width();
547
548        // Copy ELL data - convert from Vec<Vec<T>> to flat layout
549        let mut ell_data = Vec::with_capacity(nrows * ell_width);
550        let mut ell_indices = Vec::with_capacity(nrows * ell_width);
551
552        let data = ell.data();
553        let indices = ell.indices();
554
555        for row in 0..nrows {
556            for k in 0..ell_width {
557                ell_data.push(data[row][k].clone());
558                ell_indices.push(indices[row][k]);
559            }
560        }
561
562        Self {
563            nrows,
564            ncols,
565            ell_width,
566            ell_data,
567            ell_indices,
568            coo_rows: Vec::new(),
569            coo_cols: Vec::new(),
570            coo_data: Vec::new(),
571        }
572    }
573
574    /// Converts to CSR format.
575    pub fn to_csr(&self) -> CsrMatrix<T>
576    where
577        T: Field,
578    {
579        let eps = <T as Scalar>::epsilon();
580
581        // Collect all entries
582        let mut entries: Vec<(usize, usize, T)> = Vec::new();
583
584        // ELL entries
585        for row in 0..self.nrows {
586            let ell_start = row * self.ell_width;
587            for k in 0..self.ell_width {
588                let idx = ell_start + k;
589                let val = self.ell_data[idx].clone();
590                if Scalar::abs(val.clone()) > eps {
591                    entries.push((row, self.ell_indices[idx], val));
592                }
593            }
594        }
595
596        // COO entries
597        for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
598            entries.push((row, col, self.coo_data[i].clone()));
599        }
600
601        // Sort by row, then column
602        entries.sort_by_key(|(r, c, _)| (*r, *c));
603
604        // Build CSR
605        let mut row_ptrs = vec![0usize; self.nrows + 1];
606        let mut col_indices = Vec::with_capacity(entries.len());
607        let mut values = Vec::with_capacity(entries.len());
608
609        for (row, col, val) in entries {
610            col_indices.push(col);
611            values.push(val);
612            row_ptrs[row + 1] += 1;
613        }
614
615        // Cumulative sum
616        for i in 1..=self.nrows {
617            row_ptrs[i] += row_ptrs[i - 1];
618        }
619
620        // Safety: we constructed valid CSR data
621        unsafe { CsrMatrix::new_unchecked(self.nrows, self.ncols, row_ptrs, col_indices, values) }
622    }
623
624    /// Converts to COO format.
625    pub fn to_coo(&self) -> CooMatrix<T>
626    where
627        T: Field,
628    {
629        let eps = <T as Scalar>::epsilon();
630
631        let mut builder = crate::coo::CooMatrixBuilder::new(self.nrows, self.ncols);
632
633        // ELL entries
634        for row in 0..self.nrows {
635            let ell_start = row * self.ell_width;
636            for k in 0..self.ell_width {
637                let idx = ell_start + k;
638                let val = self.ell_data[idx].clone();
639                if Scalar::abs(val.clone()) > eps {
640                    builder.add(row, self.ell_indices[idx], val);
641                }
642            }
643        }
644
645        // COO entries
646        for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
647            builder.add(row, col, self.coo_data[i].clone());
648        }
649
650        builder.build()
651    }
652
653    /// Converts to ELL format.
654    ///
655    /// Note: This may require increasing the ELL width to accommodate all entries.
656    pub fn to_ell(&self) -> EllMatrix<T>
657    where
658        T: Field,
659    {
660        let eps = <T as Scalar>::epsilon();
661
662        // Compute actual max row length including COO
663        let mut row_lengths = vec![0usize; self.nrows];
664
665        // Count ELL entries
666        for row in 0..self.nrows {
667            let ell_start = row * self.ell_width;
668            for k in 0..self.ell_width {
669                let idx = ell_start + k;
670                if Scalar::abs(self.ell_data[idx].clone()) > eps {
671                    row_lengths[row] += 1;
672                }
673            }
674        }
675
676        // Count COO entries
677        for &row in &self.coo_rows {
678            row_lengths[row] += 1;
679        }
680
681        let max_width = row_lengths.iter().max().copied().unwrap_or(0).max(1);
682
683        // Build ELL with Vec<Vec<T>> format
684        let mut data: Vec<Vec<T>> = vec![vec![T::zero(); max_width]; self.nrows];
685        let mut indices: Vec<Vec<usize>> = vec![vec![0usize; max_width]; self.nrows];
686        let mut current_counts = vec![0usize; self.nrows];
687
688        // Add ELL entries
689        for row in 0..self.nrows {
690            let ell_start = row * self.ell_width;
691            for k in 0..self.ell_width {
692                let idx = ell_start + k;
693                let val = self.ell_data[idx].clone();
694                if Scalar::abs(val.clone()) > eps {
695                    let pos = current_counts[row];
696                    data[row][pos] = val;
697                    indices[row][pos] = self.ell_indices[idx];
698                    current_counts[row] += 1;
699                }
700            }
701        }
702
703        // Add COO entries
704        for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
705            let pos = current_counts[row];
706            data[row][pos] = self.coo_data[i].clone();
707            indices[row][pos] = col;
708            current_counts[row] += 1;
709        }
710
711        // Safety: we constructed valid ELL data
712        unsafe { EllMatrix::new_unchecked(self.nrows, self.ncols, max_width, data, indices) }
713    }
714
715    /// Converts to dense matrix.
716    pub fn to_dense(&self) -> oxiblas_matrix::Mat<T>
717    where
718        T: Field + bytemuck::Zeroable,
719    {
720        let eps = <T as Scalar>::epsilon();
721        let mut dense = oxiblas_matrix::Mat::zeros(self.nrows, self.ncols);
722
723        // ELL entries
724        for row in 0..self.nrows {
725            let ell_start = row * self.ell_width;
726            for k in 0..self.ell_width {
727                let idx = ell_start + k;
728                let val = self.ell_data[idx].clone();
729                if Scalar::abs(val.clone()) > eps {
730                    let col = self.ell_indices[idx];
731                    dense[(row, col)] = val;
732                }
733            }
734        }
735
736        // COO entries
737        for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
738            dense[(row, col)] = self.coo_data[i].clone();
739        }
740
741        dense
742    }
743
744    /// Scales all values by a scalar.
745    pub fn scale(&mut self, alpha: T) {
746        for val in &mut self.ell_data {
747            *val = val.clone() * alpha.clone();
748        }
749        for val in &mut self.coo_data {
750            *val = val.clone() * alpha.clone();
751        }
752    }
753
754    /// Returns a scaled copy of this matrix.
755    pub fn scaled(&self, alpha: T) -> Self {
756        let mut result = self.clone();
757        result.scale(alpha);
758        result
759    }
760
761    /// Returns an iterator over non-zero entries as (row, col, value).
762    pub fn iter(&self) -> impl Iterator<Item = (usize, usize, T)> + '_
763    where
764        T: Field,
765    {
766        let eps = <T as Scalar>::epsilon();
767        let ell_width = self.ell_width;
768
769        // ELL entries
770        let ell_iter = (0..self.nrows).flat_map(move |row| {
771            let ell_start = row * ell_width;
772            (0..ell_width).filter_map(move |k| {
773                let idx = ell_start + k;
774                let val = self.ell_data[idx].clone();
775                if Scalar::abs(val.clone()) > eps {
776                    Some((row, self.ell_indices[idx], val))
777                } else {
778                    None
779                }
780            })
781        });
782
783        // COO entries
784        let coo_iter = self
785            .coo_rows
786            .iter()
787            .zip(self.coo_cols.iter())
788            .zip(self.coo_data.iter())
789            .map(|((&row, &col), val)| (row, col, val.clone()));
790
791        ell_iter.chain(coo_iter)
792    }
793
794    /// Rebalances the HYB matrix with a new ELL width.
795    ///
796    /// This redistributes entries between ELL and COO based on the new width.
797    pub fn rebalance(&mut self, new_width: usize)
798    where
799        T: Field,
800    {
801        if new_width == self.ell_width {
802            return;
803        }
804
805        let eps = <T as Scalar>::epsilon();
806
807        // Collect all entries
808        let mut entries: Vec<(usize, usize, T)> = Vec::new();
809
810        // ELL entries
811        for row in 0..self.nrows {
812            let ell_start = row * self.ell_width;
813            for k in 0..self.ell_width {
814                let idx = ell_start + k;
815                let val = self.ell_data[idx].clone();
816                if Scalar::abs(val.clone()) > eps {
817                    entries.push((row, self.ell_indices[idx], val));
818                }
819            }
820        }
821
822        // COO entries
823        for (i, (&row, &col)) in self.coo_rows.iter().zip(self.coo_cols.iter()).enumerate() {
824            entries.push((row, col, self.coo_data[i].clone()));
825        }
826
827        // Sort by row, then column
828        entries.sort_by_key(|(r, c, _)| (*r, *c));
829
830        // Rebuild with new width
831        self.ell_width = new_width;
832        self.ell_data = vec![T::zero(); self.nrows * new_width];
833        self.ell_indices = vec![0usize; self.nrows * new_width];
834        self.coo_rows.clear();
835        self.coo_cols.clear();
836        self.coo_data.clear();
837
838        let mut current_row = 0;
839        let mut count_in_row = 0;
840
841        for (row, col, val) in entries {
842            if row != current_row {
843                current_row = row;
844                count_in_row = 0;
845            }
846
847            if count_in_row < new_width {
848                let idx = row * new_width + count_in_row;
849                self.ell_data[idx] = val;
850                self.ell_indices[idx] = col;
851                count_in_row += 1;
852            } else {
853                self.coo_rows.push(row);
854                self.coo_cols.push(col);
855                self.coo_data.push(val);
856            }
857        }
858    }
859
860    /// Returns statistics about the HYB matrix.
861    pub fn stats(&self) -> HybStats
862    where
863        T: Field,
864    {
865        let eps = <T as Scalar>::epsilon();
866
867        // Compute actual row lengths
868        let mut row_lengths = vec![0usize; self.nrows];
869
870        for row in 0..self.nrows {
871            let ell_start = row * self.ell_width;
872            for k in 0..self.ell_width {
873                let idx = ell_start + k;
874                if Scalar::abs(self.ell_data[idx].clone()) > eps {
875                    row_lengths[row] += 1;
876                }
877            }
878        }
879
880        for &row in &self.coo_rows {
881            row_lengths[row] += 1;
882        }
883
884        let ell_nnz = self.ell_nnz();
885        let coo_nnz = self.coo_nnz();
886        let total_nnz = ell_nnz + coo_nnz;
887
888        let max_row_len = row_lengths.iter().max().copied().unwrap_or(0);
889        let min_row_len = row_lengths.iter().min().copied().unwrap_or(0);
890        let avg_row_len = if self.nrows > 0 {
891            total_nnz as f64 / self.nrows as f64
892        } else {
893            0.0
894        };
895
896        HybStats {
897            nrows: self.nrows,
898            ncols: self.ncols,
899            ell_width: self.ell_width,
900            ell_nnz,
901            coo_nnz,
902            total_nnz,
903            ell_fraction: if total_nnz > 0 {
904                ell_nnz as f64 / total_nnz as f64
905            } else {
906                1.0
907            },
908            max_row_length: max_row_len,
909            min_row_length: min_row_len,
910            avg_row_length: avg_row_len,
911            storage_efficiency: self.storage_efficiency(),
912        }
913    }
914}
915
/// Summary of how a [`HybMatrix`] splits its entries between the ELL and COO parts.
#[derive(Debug, Clone, Copy)]
pub struct HybStats {
    /// Row count of the matrix.
    pub nrows: usize,
    /// Column count of the matrix.
    pub ncols: usize,
    /// Slots per row in the ELL part.
    pub ell_width: usize,
    /// Entries held in the ELL part.
    pub ell_nnz: usize,
    /// Entries held in the COO overflow part.
    pub coo_nnz: usize,
    /// `ell_nnz + coo_nnz`.
    pub total_nnz: usize,
    /// Share of entries held in ELL, between 0.0 and 1.0.
    pub ell_fraction: f64,
    /// Entry count of the longest row.
    pub max_row_length: usize,
    /// Entry count of the shortest row.
    pub min_row_length: usize,
    /// Mean entries per row.
    pub avg_row_length: f64,
    /// Entries divided by allocated value slots.
    pub storage_efficiency: f64,
}
942
943/// Computes ELL width based on row length statistics.
944fn compute_ell_width(row_lengths: &[usize], strategy: HybWidthStrategy) -> usize {
945    if row_lengths.is_empty() {
946        return 1;
947    }
948
949    match strategy {
950        HybWidthStrategy::Fixed(k) => k,
951
952        HybWidthStrategy::Mean => {
953            let sum: usize = row_lengths.iter().sum();
954            let mean = sum as f64 / row_lengths.len() as f64;
955            mean.ceil() as usize
956        }
957
958        HybWidthStrategy::MeanPlusStddev(k) => {
959            let n = row_lengths.len() as f64;
960            let sum: usize = row_lengths.iter().sum();
961            let mean = sum as f64 / n;
962
963            let variance: f64 = row_lengths
964                .iter()
965                .map(|&x| {
966                    let diff = x as f64 - mean;
967                    diff * diff
968                })
969                .sum::<f64>()
970                / n;
971            let stddev = variance.sqrt();
972
973            (mean + k * stddev).ceil() as usize
974        }
975
976        HybWidthStrategy::Median => {
977            let mut sorted = row_lengths.to_vec();
978            sorted.sort_unstable();
979            let mid = sorted.len() / 2;
980            if sorted.len() % 2 == 0 {
981                (sorted[mid - 1] + sorted[mid]).div_ceil(2)
982            } else {
983                sorted[mid]
984            }
985        }
986
987        HybWidthStrategy::Percentile(p) => {
988            let p = p.clamp(0.0, 1.0);
989            let mut sorted = row_lengths.to_vec();
990            sorted.sort_unstable();
991            let idx = ((sorted.len() - 1) as f64 * p) as usize;
992            sorted[idx]
993        }
994
995        HybWidthStrategy::Max => row_lengths.iter().max().copied().unwrap_or(1),
996    }
997}
998
#[cfg(test)]
mod tests {
    // Unit tests for HYB construction, conversion, products and width selection.
    use super::*;

    /// Builds the 4×4 fixture used by most tests. Row lengths are `[2, 1, 3, 1]`.
    fn make_test_csr() -> CsrMatrix<f64> {
        // [1 0 2 0]
        // [0 3 0 0]
        // [4 0 5 6]
        // [0 0 0 7]
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let col_indices = vec![0, 2, 1, 0, 2, 3, 3];
        let row_ptrs = vec![0, 2, 3, 6, 7];
        CsrMatrix::new(4, 4, row_ptrs, col_indices, values).unwrap()
    }

    #[test]
    fn test_hyb_from_csr_fixed() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        assert_eq!(hyb.nrows(), 4);
        assert_eq!(hyb.ncols(), 4);
        assert_eq!(hyb.ell_width(), 2);

        // Row 2 has 3 entries, so 1 should overflow to COO
        assert_eq!(hyb.coo_nnz(), 1);
        assert_eq!(hyb.nnz(), 7);
    }

    #[test]
    fn test_hyb_from_csr_max() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Max);

        assert_eq!(hyb.ell_width(), 3); // Max row has 3 entries
        assert_eq!(hyb.coo_nnz(), 0); // No overflow
    }

    #[test]
    fn test_hyb_matvec() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        // [1 0 2 0]   [1]   [7]
        // [0 3 0 0] * [2] = [6]
        // [4 0 5 6]   [3]   [43]
        // [0 0 0 7]   [4]   [28]
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = hyb.mul_vec(&x);

        assert!((y[0] - 7.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 43.0).abs() < 1e-10);
        assert!((y[3] - 28.0).abs() < 1e-10);
    }

    #[test]
    fn test_hyb_matvec_transpose() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        let x = vec![1.0, 1.0, 1.0, 1.0];
        let mut y = vec![0.0; 4];
        hyb.matvec_transpose(&x, &mut y);

        // A^T * [1,1,1,1] = column sums
        assert!((y[0] - 5.0).abs() < 1e-10); // 1 + 4
        assert!((y[1] - 3.0).abs() < 1e-10); // 3
        assert!((y[2] - 7.0).abs() < 1e-10); // 2 + 5
        assert!((y[3] - 13.0).abs() < 1e-10); // 6 + 7
    }

    #[test]
    fn test_hyb_to_csr_roundtrip() {
        let csr1 = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr1, HybWidthStrategy::Fixed(2));
        let csr2 = hyb.to_csr();

        assert_eq!(csr1.nnz(), csr2.nnz());

        // Check all values match
        for row in 0..4 {
            for col in 0..4 {
                let v1 = csr1.get(row, col).cloned().unwrap_or(0.0);
                let v2 = csr2.get(row, col).cloned().unwrap_or(0.0);
                assert!((v1 - v2).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_hyb_to_dense() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));
        let dense = hyb.to_dense();

        assert!((dense[(0, 0)] - 1.0).abs() < 1e-10);
        assert!((dense[(0, 2)] - 2.0).abs() < 1e-10);
        assert!((dense[(1, 1)] - 3.0).abs() < 1e-10);
        assert!((dense[(2, 0)] - 4.0).abs() < 1e-10);
        assert!((dense[(2, 2)] - 5.0).abs() < 1e-10);
        assert!((dense[(2, 3)] - 6.0).abs() < 1e-10);
        assert!((dense[(3, 3)] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_hyb_get() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        assert_eq!(hyb.get(0, 0), Some(1.0));
        assert_eq!(hyb.get(0, 2), Some(2.0));
        assert_eq!(hyb.get(1, 1), Some(3.0));
        // Row 2 has 3 entries but K = 2, so one of these three lives in the
        // COO part; `get` must find values in either part.
        assert_eq!(hyb.get(2, 0), Some(4.0));
        assert_eq!(hyb.get(2, 2), Some(5.0));
        assert_eq!(hyb.get(2, 3), Some(6.0));
        assert_eq!(hyb.get(3, 3), Some(7.0));

        assert_eq!(hyb.get(0, 1), None);
    }

    #[test]
    fn test_hyb_eye() {
        let hyb: HybMatrix<f64> = HybMatrix::eye(4);

        assert_eq!(hyb.nrows(), 4);
        assert_eq!(hyb.ncols(), 4);
        assert_eq!(hyb.ell_width(), 1);
        assert_eq!(hyb.nnz(), 4);

        for i in 0..4 {
            assert_eq!(hyb.get(i, i), Some(1.0));
        }
    }

    #[test]
    fn test_hyb_scale() {
        let csr = make_test_csr();
        let mut hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        hyb.scale(2.0);

        assert_eq!(hyb.get(0, 0), Some(2.0));
        assert_eq!(hyb.get(2, 2), Some(10.0));
    }

    #[test]
    fn test_hyb_rebalance() {
        let csr = make_test_csr();
        let mut hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(1));

        // Row lengths are [2, 1, 3, 1]; with K = 1 rows 0 and 2 overflow,
        // putting (2 - 1) + (3 - 1) = 3 of the 7 entries in COO.
        assert_eq!(hyb.ell_width(), 1);
        assert_eq!(hyb.coo_nnz(), 3);

        // Width 3 equals the longest row, so every entry fits in ELL.
        hyb.rebalance(3);

        assert_eq!(hyb.ell_width(), 3);
        assert_eq!(hyb.coo_nnz(), 0);
        assert_eq!(hyb.nnz(), 7);
    }

    #[test]
    fn test_hyb_stats() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));
        let stats = hyb.stats();

        assert_eq!(stats.nrows, 4);
        assert_eq!(stats.ncols, 4);
        assert_eq!(stats.ell_width, 2);
        assert_eq!(stats.total_nnz, 7);
        assert_eq!(stats.max_row_length, 3);
        assert_eq!(stats.min_row_length, 1);
    }

    #[test]
    fn test_hyb_iter() {
        let csr = make_test_csr();
        let hyb = HybMatrix::from_csr(&csr, HybWidthStrategy::Fixed(2));

        let entries: Vec<_> = hyb.iter().collect();
        assert_eq!(entries.len(), 7);
    }

    #[test]
    fn test_compute_ell_width() {
        let row_lengths = vec![1, 2, 5, 3, 2];

        assert_eq!(
            compute_ell_width(&row_lengths, HybWidthStrategy::Fixed(4)),
            4
        );
        assert_eq!(compute_ell_width(&row_lengths, HybWidthStrategy::Max), 5);
        assert_eq!(compute_ell_width(&row_lengths, HybWidthStrategy::Median), 2);

        // Mean = 2.6, so ceil = 3
        assert_eq!(compute_ell_width(&row_lengths, HybWidthStrategy::Mean), 3);
    }

    #[test]
    fn test_compute_ell_width_empty() {
        // No rows: every strategy falls back to the minimal width.
        assert_eq!(compute_ell_width(&[], HybWidthStrategy::Fixed(4)), 1);
        assert_eq!(compute_ell_width(&[], HybWidthStrategy::Mean), 1);
        assert_eq!(compute_ell_width(&[], HybWidthStrategy::Median), 1);
        assert_eq!(compute_ell_width(&[], HybWidthStrategy::Percentile(0.5)), 1);
        assert_eq!(compute_ell_width(&[], HybWidthStrategy::Max), 1);
    }

    #[test]
    fn test_compute_ell_width_median_even() {
        // Sorted [1, 2, 3, 4]: middle pair (2, 3) averages 2.5, rounded up.
        assert_eq!(compute_ell_width(&[4, 1, 3, 2], HybWidthStrategy::Median), 3);
        // Sorted [1, 7, 7, 9, 9, 9]: middle pair (7, 9).
        assert_eq!(
            compute_ell_width(&[7, 7, 1, 9, 9, 9], HybWidthStrategy::Median),
            8
        );
    }

    #[test]
    fn test_compute_ell_width_percentile() {
        // Sorted: [1, 2, 2, 3, 5]; rank = floor(4 * p).
        let row_lengths = vec![1, 2, 5, 3, 2];

        assert_eq!(
            compute_ell_width(&row_lengths, HybWidthStrategy::Percentile(0.0)),
            1
        );
        assert_eq!(
            compute_ell_width(&row_lengths, HybWidthStrategy::Percentile(0.5)),
            2
        );
        assert_eq!(
            compute_ell_width(&row_lengths, HybWidthStrategy::Percentile(0.9)),
            3
        );
        assert_eq!(
            compute_ell_width(&row_lengths, HybWidthStrategy::Percentile(1.0)),
            5
        );
        // Out-of-range percentiles are clamped to [0, 1].
        assert_eq!(
            compute_ell_width(&row_lengths, HybWidthStrategy::Percentile(2.0)),
            5
        );
        assert_eq!(
            compute_ell_width(&row_lengths, HybWidthStrategy::Percentile(-1.0)),
            1
        );
    }

    #[test]
    fn test_compute_ell_width_mean_plus_stddev() {
        // mean = 2.6, population stddev = sqrt(1.84) ≈ 1.356
        let row_lengths = vec![1, 2, 5, 3, 2];

        assert_eq!(
            compute_ell_width(&row_lengths, HybWidthStrategy::MeanPlusStddev(0.0)),
            3
        );
        assert_eq!(
            compute_ell_width(&row_lengths, HybWidthStrategy::MeanPlusStddev(1.0)),
            4
        );
        // Uniform row lengths have zero spread.
        assert_eq!(
            compute_ell_width(&[3, 3, 3, 3], HybWidthStrategy::MeanPlusStddev(2.0)),
            3
        );
    }

    #[test]
    fn test_hyb_from_ell() {
        // Create a simple ELL from CSR
        let csr = make_test_csr();
        let ell = crate::ell::EllMatrix::from_csr(&csr, None).unwrap();

        let hyb = HybMatrix::from_ell(&ell);

        assert_eq!(hyb.coo_nnz(), 0);
        assert_eq!(hyb.nnz(), 7);
    }
}