scirs2_sparse/
banded_array.rs

1//! Banded matrix format for sparse matrices
2//!
3//! Banded matrices are matrices where all non-zero elements are within a band
4//! around the main diagonal. This format is highly efficient for matrices with
5//! this structure, especially for solving linear systems.
6
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
10use scirs2_core::numeric::{Float, One, Zero};
11use scirs2_core::SparseElement;
12use std::fmt::{Debug, Display};
13
14/// Banded array format for sparse matrices
15///
16/// The BandedArray format stores only the non-zero bands of a matrix.
17/// The data is stored in a 2D array where each row represents a diagonal
18/// and each column represents the matrix row.
19///
20/// For a matrix with lower bandwidth `kl` and upper bandwidth `ku`,
21/// the data array has shape `(kl + ku + 1, n)` where `n` is the number
22/// of matrix rows.
23#[derive(Debug, Clone)]
24pub struct BandedArray<T>
25where
26    T: std::ops::AddAssign + std::fmt::Display,
27{
28    /// Band data stored as (kl + ku + 1, n) array
29    data: Array2<T>,
30    /// Lower bandwidth (number of subdiagonals)
31    kl: usize,
32    /// Upper bandwidth (number of superdiagonals)
33    ku: usize,
34    /// Matrix shape
35    shape: (usize, usize),
36}
37
38impl<T> BandedArray<T>
39where
40    T: Float
41        + SparseElement
42        + Debug
43        + Display
44        + Copy
45        + Zero
46        + One
47        + Send
48        + Sync
49        + 'static
50        + std::ops::AddAssign,
51{
52    /// Create a new banded array
53    pub fn new(data: Array2<T>, kl: usize, ku: usize, shape: (usize, usize)) -> SparseResult<Self> {
54        let expected_bands = kl + ku + 1;
55        let (bands, cols) = data.dim();
56
57        if bands != expected_bands {
58            return Err(SparseError::ValueError(format!(
59                "Data array should have {expected_bands} bands, got {bands}"
60            )));
61        }
62
63        if cols != shape.0 {
64            return Err(SparseError::ValueError(format!(
65                "Data array columns {} should match matrix rows {}",
66                cols, shape.0
67            )));
68        }
69
70        Ok(Self {
71            data,
72            kl,
73            ku,
74            shape,
75        })
76    }
77
78    /// Create a new zero banded array
79    pub fn zeros(shape: (usize, usize), kl: usize, ku: usize) -> Self {
80        let bands = kl + ku + 1;
81        let data = Array2::zeros((bands, shape.0));
82
83        Self {
84            data,
85            kl,
86            ku,
87            shape,
88        }
89    }
90
91    /// Create a new identity banded array
92    pub fn eye(n: usize, kl: usize, ku: usize) -> Self {
93        let mut result = Self::zeros((n, n), kl, ku);
94
95        // Set main diagonal to 1
96        for i in 0..n {
97            result.set_unchecked(i, i, T::sparse_one());
98        }
99
100        result
101    }
102
103    /// Create from triplet format (row, col, data)
104    pub fn from_triplets(
105        rows: &[usize],
106        cols: &[usize],
107        data: &[T],
108        shape: (usize, usize),
109        kl: usize,
110        ku: usize,
111    ) -> SparseResult<Self> {
112        let mut result = Self::zeros(shape, kl, ku);
113
114        for (&row, (&col, &value)) in rows.iter().zip(cols.iter().zip(data.iter())) {
115            if row >= shape.0 || col >= shape.1 {
116                return Err(SparseError::ValueError("Index out of bounds".to_string()));
117            }
118
119            if result.is_in_band(row, col) {
120                result.set_unchecked(row, col, value);
121            } else if !SparseElement::is_zero(&value) {
122                return Err(SparseError::ValueError(format!(
123                    "Non-zero element at ({row}, {col}) is outside band structure"
124                )));
125            }
126        }
127
128        Ok(result)
129    }
130
131    /// Create tridiagonal matrix
132    pub fn tridiagonal(diag: &[T], lower: &[T], upper: &[T]) -> SparseResult<Self> {
133        let n = diag.len();
134
135        if lower.len() != n - 1 || upper.len() != n - 1 {
136            return Err(SparseError::ValueError(
137                "Off-diagonal arrays must have length n-1".to_string(),
138            ));
139        }
140
141        let mut result = Self::zeros((n, n), 1, 1);
142
143        // Main diagonal
144        for (i, &val) in diag.iter().enumerate() {
145            result.set_unchecked(i, i, val);
146        }
147
148        // Lower diagonal
149        for (i, &val) in lower.iter().enumerate() {
150            result.set_unchecked(i + 1, i, val);
151        }
152
153        // Upper diagonal
154        for (i, &val) in upper.iter().enumerate() {
155            result.set_unchecked(i, i + 1, val);
156        }
157
158        Ok(result)
159    }
160
161    /// Check if an element is within the band structure
162    pub fn is_in_band(&self, row: usize, col: usize) -> bool {
163        if row >= self.shape.0 || col >= self.shape.1 {
164            return false;
165        }
166
167        let diff = col as isize - row as isize;
168        diff >= -(self.kl as isize) && diff <= self.ku as isize
169    }
170
171    /// Set an element (unchecked for performance)
172    pub fn set_unchecked(&mut self, row: usize, col: usize, value: T) {
173        if let Some(band_idx) = self
174            .ku
175            .checked_add(row)
176            .and_then(|sum| sum.checked_sub(col))
177        {
178            if band_idx < self.data.nrows() {
179                self.data[[band_idx, col]] = value;
180            }
181        }
182    }
183
184    /// Set an element with bounds and band checking
185    pub fn set_direct(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
186        if row >= self.shape.0 || col >= self.shape.1 {
187            return Err(SparseError::ValueError(format!(
188                "Index ({}, {}) out of bounds for shape {:?}",
189                row, col, self.shape
190            )));
191        }
192
193        if !self.is_in_band(row, col) {
194            if !SparseElement::is_zero(&value) {
195                return Err(SparseError::ValueError(format!(
196                    "Cannot set non-zero value {value} at ({row}, {col}) - outside band structure"
197                )));
198            }
199            // For zero values outside the band, just ignore (they're implicitly zero)
200            return Ok(());
201        }
202
203        self.set_unchecked(row, col, value);
204        Ok(())
205    }
206
207    /// Get the raw band data
208    pub fn data(&self) -> &Array2<T> {
209        &self.data
210    }
211
212    /// Get mutable reference to the raw band data
213    pub fn data_mut(&mut self) -> &mut Array2<T> {
214        &mut self.data
215    }
216
217    /// Get lower bandwidth
218    pub fn kl(&self) -> usize {
219        self.kl
220    }
221
222    /// Get upper bandwidth
223    pub fn ku(&self) -> usize {
224        self.ku
225    }
226
227    /// Solve a banded linear system using LU decomposition
228    pub fn solve(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
229        if self.shape.0 != self.shape.1 {
230            return Err(SparseError::ValueError(
231                "Matrix must be square for solving".to_string(),
232            ));
233        }
234
235        if b.len() != self.shape.0 {
236            return Err(SparseError::DimensionMismatch {
237                expected: self.shape.0,
238                found: b.len(),
239            });
240        }
241
242        // Perform banded LU decomposition
243        let (l, u, p) = self.lu_decomposition()?;
244
245        // Solve L * U * x = P * b
246        let pb = apply_permutation(&p, b);
247        let y = l.forward_substitution(&pb.view())?;
248        let x = u.back_substitution(&y.view())?;
249
250        Ok(x)
251    }
252
253    /// LU decomposition for banded matrices
254    pub fn lu_decomposition(&self) -> SparseResult<(BandedArray<T>, BandedArray<T>, Vec<usize>)> {
255        let n = self.shape.0;
256        let mut l = BandedArray::zeros((n, n), self.kl, 0); // Lower triangular
257        let mut u = self.clone(); // Will become upper triangular
258        let mut p: Vec<usize> = (0..n).collect(); // Permutation vector
259
260        // Gaussian elimination with partial pivoting
261        for k in 0..(n - 1) {
262            // Find pivot within the band
263            let mut pivot_row = k;
264            let mut max_val = u.get(k, k).abs();
265
266            for i in (k + 1)..(k + 1 + self.kl).min(n) {
267                let val = u.get(i, k).abs();
268                if val > max_val {
269                    max_val = val;
270                    pivot_row = i;
271                }
272            }
273
274            // Swap rows if needed
275            if pivot_row != k {
276                u.swap_rows(k, pivot_row);
277                l.swap_rows(k, pivot_row);
278                p.swap(k, pivot_row);
279            }
280
281            let pivot = u.get(k, k);
282            if SparseElement::is_zero(&pivot) {
283                return Err(SparseError::ValueError("Matrix is singular".to_string()));
284            }
285
286            // Eliminate column
287            for i in (k + 1)..(k + 1 + self.kl).min(n) {
288                let factor = u.get(i, k) / pivot;
289                l.set_unchecked(i, k, factor);
290
291                for j in k..(k + 1 + self.ku).min(n) {
292                    let val = u.get(i, j) - factor * u.get(k, j);
293                    if u.is_in_band(i, j) {
294                        u.set_unchecked(i, j, val);
295                    }
296                }
297            }
298        }
299
300        // Set L diagonal to 1
301        for i in 0..n {
302            l.set_unchecked(i, i, T::sparse_one());
303        }
304
305        Ok((l, u, p))
306    }
307
308    /// Forward substitution for lower triangular banded matrix
309    pub fn forward_substitution(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
310        let n = self.shape.0;
311        let mut x = Array1::zeros(n);
312
313        for i in 0..n {
314            let mut sum = T::sparse_zero();
315            let start = i.saturating_sub(self.kl);
316
317            for j in start..i {
318                sum += self.get(i, j) * x[j];
319            }
320
321            x[i] = (b[i] - sum) / self.get(i, i);
322        }
323
324        Ok(x)
325    }
326
327    /// Back substitution for upper triangular banded matrix
328    pub fn back_substitution(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
329        let n = self.shape.0;
330        let mut x = Array1::zeros(n);
331
332        for i in (0..n).rev() {
333            let mut sum = T::sparse_zero();
334            let end = (i + self.ku + 1).min(n);
335
336            for j in (i + 1)..end {
337                sum += self.get(i, j) * x[j];
338            }
339
340            x[i] = (b[i] - sum) / self.get(i, i);
341        }
342
343        Ok(x)
344    }
345
346    /// Swap two rows in the banded matrix
347    fn swap_rows(&mut self, i: usize, j: usize) {
348        if i == j {
349            return;
350        }
351
352        // Determine the range of columns to swap
353        let min_col = i.saturating_sub(self.kl).max(j.saturating_sub(self.kl));
354        let max_col = (i + self.ku).min(j + self.ku).min(self.shape.1 - 1);
355
356        for col in min_col..=max_col {
357            if self.is_in_band(i, col) && self.is_in_band(j, col) {
358                let temp = self.get(i, col);
359                self.set_unchecked(i, col, self.get(j, col));
360                self.set_unchecked(j, col, temp);
361            }
362        }
363    }
364
365    /// Matrix-vector multiplication optimized for banded structure
366    pub fn matvec(&self, x: &ArrayView1<T>) -> SparseResult<Array1<T>> {
367        if x.len() != self.shape.1 {
368            return Err(SparseError::DimensionMismatch {
369                expected: self.shape.1,
370                found: x.len(),
371            });
372        }
373
374        let mut y = Array1::zeros(self.shape.0);
375
376        for i in 0..self.shape.0 {
377            let start_col = i.saturating_sub(self.kl);
378            let end_col = (i + self.ku + 1).min(self.shape.1);
379
380            for j in start_col..end_col {
381                y[i] += self.get(i, j) * x[j];
382            }
383        }
384
385        Ok(y)
386    }
387}
388
389impl<T> SparseArray<T> for BandedArray<T>
390where
391    T: Float
392        + SparseElement
393        + Debug
394        + Display
395        + Copy
396        + Zero
397        + One
398        + SparseElement
399        + Send
400        + Sync
401        + 'static
402        + std::ops::AddAssign,
403{
404    fn shape(&self) -> (usize, usize) {
405        self.shape
406    }
407
408    fn nnz(&self) -> usize {
409        let mut count = 0;
410        for band in 0..(self.kl + self.ku + 1) {
411            for col in 0..self.shape.0 {
412                if !SparseElement::is_zero(&self.data[[band, col]]) {
413                    count += 1;
414                }
415            }
416        }
417        count
418    }
419
420    fn get(&self, row: usize, col: usize) -> T {
421        if !self.is_in_band(row, col) {
422            return T::sparse_zero();
423        }
424
425        if let Some(band_idx) = self
426            .ku
427            .checked_add(row)
428            .and_then(|sum| sum.checked_sub(col))
429        {
430            if band_idx < self.kl + self.ku + 1 && col < self.shape.1 {
431                self.data[[band_idx, col]]
432            } else {
433                T::sparse_zero()
434            }
435        } else {
436            T::sparse_zero()
437        }
438    }
439
440    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
441        let mut rows = Vec::new();
442        let mut cols = Vec::new();
443        let mut data = Vec::new();
444
445        for i in 0..self.shape.0 {
446            let start_col = i.saturating_sub(self.kl);
447            let end_col = (i + self.ku + 1).min(self.shape.1);
448
449            for j in start_col..end_col {
450                let val = self.get(i, j);
451                if !SparseElement::is_zero(&val) {
452                    rows.push(i);
453                    cols.push(j);
454                    data.push(val);
455                }
456            }
457        }
458
459        (
460            Array1::from_vec(rows),
461            Array1::from_vec(cols),
462            Array1::from_vec(data),
463        )
464    }
465
466    fn to_array(&self) -> Array2<T> {
467        let mut result = Array2::zeros(self.shape);
468
469        for i in 0..self.shape.0 {
470            let start_col = i.saturating_sub(self.kl);
471            let end_col = (i + self.ku + 1).min(self.shape.1);
472
473            for j in start_col..end_col {
474                result[[i, j]] = self.get(i, j);
475            }
476        }
477
478        result
479    }
480
481    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
482        // For now, convert to dense and multiply
483        let a_dense = self.to_array();
484        let b_dense = other.to_array();
485
486        if a_dense.ncols() != b_dense.nrows() {
487            return Err(SparseError::DimensionMismatch {
488                expected: a_dense.ncols(),
489                found: b_dense.nrows(),
490            });
491        }
492
493        let result = a_dense.dot(&b_dense);
494
495        // Try to convert back to banded format if possible
496        // For simplicity, convert to CSR for now
497        let (rows, cols, data) = array_to_triplets(&result);
498        let csr =
499            crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
500
501        Ok(Box::new(csr))
502    }
503
504    fn dtype(&self) -> &str {
505        std::any::type_name::<T>()
506    }
507
508    fn toarray(&self) -> Array2<T> {
509        self.to_array()
510    }
511
512    fn as_any(&self) -> &dyn std::any::Any {
513        self
514    }
515
516    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
517        let (rows, cols, data) = self.find();
518        let coo = crate::coo_array::CooArray::from_triplets(
519            rows.as_slice().unwrap(),
520            cols.as_slice().unwrap(),
521            data.as_slice().unwrap(),
522            self.shape,
523            false,
524        )?;
525        Ok(Box::new(coo))
526    }
527
528    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
529        let (rows, cols, data) = self.find();
530        let csr = crate::csr_array::CsrArray::from_triplets(
531            rows.as_slice().unwrap(),
532            cols.as_slice().unwrap(),
533            data.as_slice().unwrap(),
534            self.shape,
535            false,
536        )?;
537        Ok(Box::new(csr))
538    }
539
540    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
541        let (rows, cols, data) = self.find();
542        let csc = crate::csc_array::CscArray::from_triplets(
543            rows.as_slice().unwrap(),
544            cols.as_slice().unwrap(),
545            data.as_slice().unwrap(),
546            self.shape,
547            false,
548        )?;
549        Ok(Box::new(csc))
550    }
551
552    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
553        let (rows, cols, data) = self.find();
554        let mut dok = crate::dok_array::DokArray::new(self.shape);
555        for ((row, col), &val) in rows.iter().zip(cols.iter()).zip(data.iter()) {
556            dok.set(*row, *col, val)?;
557        }
558        Ok(Box::new(dok))
559    }
560
561    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
562        let mut lil = crate::lil_array::LilArray::new(self.shape);
563        for i in 0..self.shape.0 {
564            let start_col = i.saturating_sub(self.kl);
565            let end_col = (i + self.ku + 1).min(self.shape.1);
566
567            for j in start_col..end_col {
568                let val = self.get(i, j);
569                if !SparseElement::is_zero(&val) {
570                    lil.set(i, j, val)?;
571                }
572            }
573        }
574        Ok(Box::new(lil))
575    }
576
577    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
578        // Convert banded to diagonal format
579        let mut diagonals = Vec::new();
580        let mut offsets = Vec::new();
581
582        for band in 0..(self.kl + self.ku + 1) {
583            let offset = (band as isize) - (self.ku as isize);
584            let mut diagonal = Vec::new();
585
586            for row in 0..self.shape.0 {
587                if row < self.shape.0 && band < self.data.dim().0 {
588                    diagonal.push(self.data[[band, row]]);
589                }
590            }
591
592            if diagonal.iter().any(|&x| !SparseElement::is_zero(&x)) {
593                diagonals.push(Array1::from_vec(diagonal));
594                offsets.push(offset);
595            }
596        }
597
598        let dia = crate::dia_array::DiaArray::new(diagonals, offsets, self.shape)?;
599        Ok(Box::new(dia))
600    }
601
602    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
603        // Convert to CSR first, then to BSR
604        let csr = self.to_csr()?;
605        csr.to_bsr()
606    }
607
608    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
609        if self.shape != other.shape() {
610            return Err(SparseError::DimensionMismatch {
611                expected: self.shape.0 * self.shape.1,
612                found: other.shape().0 * other.shape().1,
613            });
614        }
615
616        let a_dense = self.to_array();
617        let b_dense = other.to_array();
618        let result = a_dense + b_dense;
619
620        let (rows, cols, data) = array_to_triplets(&result);
621        let csr =
622            crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
623        Ok(Box::new(csr))
624    }
625
626    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
627        if self.shape != other.shape() {
628            return Err(SparseError::DimensionMismatch {
629                expected: self.shape.0 * self.shape.1,
630                found: other.shape().0 * other.shape().1,
631            });
632        }
633
634        let a_dense = self.to_array();
635        let b_dense = other.to_array();
636        let result = a_dense - b_dense;
637
638        let (rows, cols, data) = array_to_triplets(&result);
639        let csr =
640            crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
641        Ok(Box::new(csr))
642    }
643
644    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
645        if self.shape != other.shape() {
646            return Err(SparseError::DimensionMismatch {
647                expected: self.shape.0 * self.shape.1,
648                found: other.shape().0 * other.shape().1,
649            });
650        }
651
652        let a_dense = self.to_array();
653        let b_dense = other.to_array();
654        let result = a_dense * b_dense;
655
656        let (rows, cols, data) = array_to_triplets(&result);
657        let csr =
658            crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
659        Ok(Box::new(csr))
660    }
661
662    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
663        if self.shape != other.shape() {
664            return Err(SparseError::DimensionMismatch {
665                expected: self.shape.0 * self.shape.1,
666                found: other.shape().0 * other.shape().1,
667            });
668        }
669
670        let a_dense = self.to_array();
671        let b_dense = other.to_array();
672        let result = a_dense / b_dense;
673
674        let (rows, cols, data) = array_to_triplets(&result);
675        let csr =
676            crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
677        Ok(Box::new(csr))
678    }
679
680    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
681        if self.shape.1 != other.len() {
682            return Err(SparseError::DimensionMismatch {
683                expected: self.shape.1,
684                found: other.len(),
685            });
686        }
687
688        let mut result = Array1::zeros(self.shape.0);
689
690        for i in 0..self.shape.0 {
691            let start_col = i.saturating_sub(self.kl);
692            let end_col = (i + self.ku + 1).min(self.shape.1);
693
694            for j in start_col..end_col {
695                let val = self.get(i, j);
696                if !SparseElement::is_zero(&val) {
697                    result[i] += val * other[j];
698                }
699            }
700        }
701
702        Ok(result)
703    }
704
705    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
706        let mut transposed = BandedArray::zeros((self.shape.1, self.shape.0), self.ku, self.kl);
707
708        for i in 0..self.shape.0 {
709            let start_col = i.saturating_sub(self.kl);
710            let end_col = (i + self.ku + 1).min(self.shape.1);
711
712            for j in start_col..end_col {
713                let val = self.get(i, j);
714                if !SparseElement::is_zero(&val) {
715                    transposed.set_direct(j, i, val)?;
716                }
717            }
718        }
719
720        Ok(Box::new(transposed))
721    }
722
723    fn copy(&self) -> Box<dyn SparseArray<T>> {
724        Box::new(self.clone())
725    }
726
727    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
728        self.set_direct(i, j, value)
729    }
730
731    fn eliminate_zeros(&mut self) {
732        // For banded arrays, we typically don't eliminate structural zeros
733        // as they maintain the band structure
734    }
735
736    fn sort_indices(&mut self) {
737        // Banded arrays maintain sorted indices by structure
738    }
739
740    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
741        self.copy()
742    }
743
744    fn has_sorted_indices(&self) -> bool {
745        true // Banded arrays always have sorted indices by structure
746    }
747
748    fn sum(&self, axis: Option<usize>) -> SparseResult<crate::sparray::SparseSum<T>> {
749        match axis {
750            None => {
751                // Sum all elements
752                let total = self.data.iter().fold(T::sparse_zero(), |acc, &x| acc + x);
753                Ok(crate::sparray::SparseSum::Scalar(total))
754            }
755            Some(0) => {
756                // Sum along rows (result is column vector)
757                let mut result: Array1<T> = Array1::zeros(self.shape.1);
758                for i in 0..self.shape.0 {
759                    let start_col = i.saturating_sub(self.kl);
760                    let end_col = (i + self.ku + 1).min(self.shape.1);
761
762                    for j in start_col..end_col {
763                        let val = self.get(i, j);
764                        result[j] += val;
765                    }
766                }
767                // Convert to CSR format
768                let mut data = Vec::new();
769                let mut indices = Vec::new();
770                let mut indptr = vec![0];
771
772                for (col, &val) in result.iter().enumerate() {
773                    if !SparseElement::is_zero(&val) {
774                        data.push(val);
775                        indices.push(col);
776                    }
777                }
778                indptr.push(data.len());
779
780                let result_array = crate::csr_array::CsrArray::new(
781                    Array1::from_vec(data),
782                    Array1::from_vec(indices),
783                    Array1::from_vec(indptr),
784                    (1, self.shape.1),
785                )?;
786
787                Ok(crate::sparray::SparseSum::SparseArray(Box::new(
788                    result_array,
789                )))
790            }
791            Some(1) => {
792                // Sum along columns (result is column vector)
793                let mut result: Array1<T> = Array1::zeros(self.shape.0);
794                for i in 0..self.shape.0 {
795                    let start_col = i.saturating_sub(self.kl);
796                    let end_col = (i + self.ku + 1).min(self.shape.1);
797
798                    for j in start_col..end_col {
799                        let val = self.get(i, j);
800                        result[i] += val;
801                    }
802                }
803                // Convert to CSR format (column vector)
804                let mut data = Vec::new();
805                let mut indices = Vec::new();
806                let mut indptr = vec![0];
807
808                for &val in result.iter() {
809                    if !SparseElement::is_zero(&val) {
810                        data.push(val);
811                        indices.push(0); // All values are in column 0
812                    }
813                    indptr.push(data.len());
814                }
815
816                let result_array = crate::csr_array::CsrArray::new(
817                    Array1::from_vec(data),
818                    Array1::from_vec(indices),
819                    Array1::from_vec(indptr),
820                    (self.shape.0, 1),
821                )?;
822
823                Ok(crate::sparray::SparseSum::SparseArray(Box::new(
824                    result_array,
825                )))
826            }
827            Some(_) => Err(SparseError::ValueError("Invalid axis".to_string())),
828        }
829    }
830
831    fn max(&self) -> T {
832        self.data
833            .iter()
834            .fold(T::neg_infinity(), |a, &b| if a > b { a } else { b })
835    }
836
837    fn min(&self) -> T {
838        self.data
839            .iter()
840            .fold(T::infinity(), |a, &b| if a < b { a } else { b })
841    }
842
843    fn slice(
844        &self,
845        row_range: (usize, usize),
846        col_range: (usize, usize),
847    ) -> SparseResult<Box<dyn SparseArray<T>>> {
848        let (start_row, end_row) = row_range;
849        let (start_col, end_col) = col_range;
850
851        if end_row > self.shape.0 || end_col > self.shape.1 {
852            return Err(SparseError::ValueError(
853                "Slice bounds exceed matrix dimensions".to_string(),
854            ));
855        }
856
857        let mut rows = Vec::new();
858        let mut cols = Vec::new();
859        let mut data = Vec::new();
860
861        for i in start_row..end_row {
862            let band_start_col = i.saturating_sub(self.kl).max(start_col);
863            let band_end_col = (i + self.ku + 1).min(self.shape.1).min(end_col);
864
865            for j in band_start_col..band_end_col {
866                let val = self.get(i, j);
867                if !SparseElement::is_zero(&val) {
868                    rows.push(i - start_row);
869                    cols.push(j - start_col);
870                    data.push(val);
871                }
872            }
873        }
874
875        let shape = (end_row - start_row, end_col - start_col);
876        let csr = crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, shape, false)?;
877        Ok(Box::new(csr))
878    }
879}
880
881/// Apply permutation to a vector
882#[allow(dead_code)]
883fn apply_permutation<T: Copy + Zero>(p: &[usize], v: &ArrayView1<T>) -> Array1<T> {
884    let mut result = Array1::zeros(v.len());
885    for (i, &pi) in p.iter().enumerate() {
886        result[i] = v[pi];
887    }
888    result
889}
890
891/// Convert dense array to triplet format
892#[allow(dead_code)]
893fn array_to_triplets<T: Float + SparseElement + Debug + Copy + Zero>(
894    array: &Array2<T>,
895) -> (Vec<usize>, Vec<usize>, Vec<T>) {
896    let mut rows = Vec::new();
897    let mut cols = Vec::new();
898    let mut data = Vec::new();
899
900    for ((i, j), &val) in array.indexed_iter() {
901        if !SparseElement::is_zero(&val) {
902            rows.push(i);
903            cols.push(j);
904            data.push(val);
905        }
906    }
907
908    (rows, cols, data)
909}
910
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // The 3x3 tridiagonal matrix
    //   [2 5 0]
    //   [1 3 6]
    //   [0 1 4]
    fn sample_tridiagonal() -> BandedArray<f64> {
        BandedArray::tridiagonal(&[2.0, 3.0, 4.0], &[1.0, 1.0], &[5.0, 6.0]).unwrap()
    }

    #[test]
    fn test_banded_array_creation() {
        // Storage rows: superdiagonal, main diagonal, subdiagonal.
        // data[[0, 0]] and data[[2, 3]] are padding slots.
        let storage = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.0, 1.0, 2.0, 3.0, //
                4.0, 5.0, 6.0, 7.0, //
                8.0, 9.0, 10.0, 0.0,
            ],
        )
        .unwrap();
        let banded = BandedArray::new(storage, 1, 1, (4, 4)).unwrap();

        assert_eq!(banded.shape(), (4, 4));
        assert_eq!(banded.kl(), 1);
        assert_eq!(banded.ku(), 1);

        let expected = [
            // main diagonal
            (0, 0, 4.0),
            (1, 1, 5.0),
            (2, 2, 6.0),
            (3, 3, 7.0),
            // superdiagonal
            (0, 1, 1.0),
            (1, 2, 2.0),
            (2, 3, 3.0),
            // subdiagonal
            (1, 0, 8.0),
            (2, 1, 9.0),
            (3, 2, 10.0),
            // outside the band
            (0, 2, 0.0),
            (2, 0, 0.0),
        ];
        for &(row, col, value) in &expected {
            assert_eq!(banded.get(row, col), value, "entry ({row}, {col})");
        }
    }

    #[test]
    fn test_tridiagonal_matrix() {
        let banded = sample_tridiagonal();

        assert_eq!(banded.shape(), (3, 3));

        let expected = [
            (0, 0, 2.0),
            (1, 1, 3.0),
            (2, 2, 4.0),
            (1, 0, 1.0),
            (2, 1, 1.0),
            (0, 1, 5.0),
            (1, 2, 6.0),
        ];
        for &(row, col, value) in &expected {
            assert_eq!(banded.get(row, col), value, "entry ({row}, {col})");
        }
    }

    #[test]
    fn test_banded_matvec() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = sample_tridiagonal().matvec(&x.view()).unwrap();

        // [2 5 0] [1]   [2*1 + 5*2 + 0*3]   [12]
        // [1 3 6] [2] = [1*1 + 3*2 + 6*3] = [25]
        // [0 1 4] [3]   [0*1 + 1*2 + 4*3]   [14]
        assert_eq!(y.len(), 3);
        for (got, want) in y.iter().zip([12.0, 25.0, 14.0]) {
            assert_relative_eq!(*got, want);
        }
    }

    #[test]
    fn test_banded_solve() {
        let banded: BandedArray<f64> =
            BandedArray::tridiagonal(&[2.0, 2.0, 2.0], &[-1.0, -1.0], &[-1.0, -1.0]).unwrap();
        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);

        let x = banded.solve(&b.view()).unwrap();

        // Verify the solution by recomputing A * x.
        let ax = banded.matvec(&x.view()).unwrap();
        assert_eq!(ax.len(), b.len());
        for (lhs, rhs) in ax.iter().zip(b.iter()) {
            assert_relative_eq!(*lhs, *rhs, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_is_in_band() {
        let banded = BandedArray::<f64>::zeros((5, 5), 2, 1);

        // Main diagonal, one above it, two below it.
        for (row, col) in [(2, 2), (2, 3), (2, 0)] {
            assert!(banded.is_in_band(row, col), "({row}, {col}) should be in band");
        }

        // Too far above / below the diagonal.
        for (row, col) in [(0, 2), (4, 0)] {
            assert!(!banded.is_in_band(row, col), "({row}, {col}) should be outside band");
        }
    }

    #[test]
    fn test_eye_matrix() {
        let eye = BandedArray::<f64>::eye(3, 1, 1);

        for i in 0..3 {
            assert_eq!(eye.get(i, i), 1.0);
        }
        assert_eq!(eye.get(0, 1), 0.0);
        assert_eq!(eye.get(1, 0), 0.0);

        assert_eq!(eye.nnz(), 3);
    }
}