Skip to main content

oxicuda_solver/dense/
band.rs

1//! Band Matrix Solvers.
2//!
3//! Specialized solvers for banded matrices that exploit the narrow bandwidth
4//! to achieve O(n * b^2) complexity instead of O(n^3), where b is the bandwidth.
5//!
6//! # Storage Format
7//!
8//! Band matrices use LAPACK-style band storage (column-major):
//! - Element `(i, j)` is stored in band row `(kl + i - j)` of column `j`, so the
//!   diagonal occupies row `kl` and at most `kl` super-diagonals have a slot.
//! - [`BandMatrix`] allocates `(2*kl + ku + 1)` rows and `n` columns.
11//!
12//! For example, a 5x5 tridiagonal matrix (kl=1, ku=1) is stored as:
13//! ```text
//!   Band storage (first 3 of the 4 allocated rows x 5 cols):
15//!   [ *   a12  a23  a34  a45 ]   <- superdiagonal (ku=1)
16//!   [ a11 a22  a33  a44  a55 ]   <- diagonal
17//!   [ a21 a32  a43  a54  *   ]   <- subdiagonal (kl=1)
18//! ```
19//!
20//! # Algorithms
21//!
22//! - **Band LU**: Gaussian elimination with partial pivoting, adapted for
23//!   banded structure. Only operates on the non-zero bandwidth, giving
24//!   O(n * kl * ku) complexity.
25//! - **Band Cholesky**: Cholesky decomposition for banded SPD matrices,
26//!   O(n * kd^2) where kd = kl = ku for a symmetric band.
27
28#![allow(dead_code)]
29
30use oxicuda_blas::GpuFloat;
31use oxicuda_memory::DeviceBuffer;
32
33use crate::error::{SolverError, SolverResult};
34use crate::handle::SolverHandle;
35
36// ---------------------------------------------------------------------------
37// GpuFloat <-> f64 conversion helpers
38// ---------------------------------------------------------------------------
39
40fn to_f64<T: GpuFloat>(val: T) -> f64 {
41    if T::SIZE == 4 {
42        f32::from_bits(val.to_bits_u64() as u32) as f64
43    } else {
44        f64::from_bits(val.to_bits_u64())
45    }
46}
47
48fn from_f64<T: GpuFloat>(val: f64) -> T {
49    if T::SIZE == 4 {
50        T::from_bits_u64(u64::from((val as f32).to_bits()))
51    } else {
52        T::from_bits_u64(val.to_bits())
53    }
54}
55
56// ---------------------------------------------------------------------------
57// Band matrix descriptor
58// ---------------------------------------------------------------------------
59
60/// Band matrix descriptor.
61///
62/// Stores a banded matrix in LAPACK-style band storage format.
63/// The storage array has `(2*kl + ku + 1)` rows and `n` columns for LU
64/// (extra kl rows for fill-in during pivoting), or `(kl + ku + 1)` rows
65/// and `n` columns for non-pivoted operations.
66pub struct BandMatrix<T: GpuFloat> {
67    /// Band storage data on the device.
68    pub data: DeviceBuffer<T>,
69    /// Matrix dimension (n x n).
70    pub n: usize,
71    /// Number of sub-diagonals.
72    pub kl: usize,
73    /// Number of super-diagonals.
74    pub ku: usize,
75}
76
77impl<T: GpuFloat> BandMatrix<T> {
78    /// Creates a new band matrix with the given dimensions.
79    ///
80    /// Allocates a device buffer of size `(2*kl + ku + 1) * n` to accommodate
81    /// fill-in during LU factorization.
82    ///
83    /// # Errors
84    ///
85    /// Returns [`SolverError::Cuda`] if device allocation fails.
86    pub fn new(n: usize, kl: usize, ku: usize) -> SolverResult<Self> {
87        let ldab = 2 * kl + ku + 1;
88        let data = DeviceBuffer::<T>::zeroed(ldab * n)?;
89        Ok(Self { n, kl, ku, data })
90    }
91
92    /// Returns the leading dimension of the band storage.
93    pub fn ldab(&self) -> usize {
94        2 * self.kl + self.ku + 1
95    }
96
97    /// Returns the total number of elements in band storage.
98    pub fn storage_len(&self) -> usize {
99        self.ldab() * self.n
100    }
101
102    /// Computes the storage index for element (i, j) in the band.
103    ///
104    /// Returns `None` if (i, j) is outside the band.
105    pub fn band_index(&self, i: usize, j: usize) -> Option<usize> {
106        let row_in_band = self.kl + i;
107        if row_in_band < j {
108            return None; // Above the upper bandwidth.
109        }
110        let band_row = row_in_band - j;
111        if band_row >= self.ldab() {
112            return None; // Below the lower bandwidth.
113        }
114        Some(j * self.ldab() + band_row)
115    }
116}
117
118// ---------------------------------------------------------------------------
119// Public API
120// ---------------------------------------------------------------------------
121
122/// LU factorization for a banded matrix.
123///
124/// Performs Gaussian elimination with partial pivoting on the banded matrix.
125/// The factors L and U overwrite the band storage, with fill-in accommodated
126/// in the extra `kl` rows of the storage.
127///
128/// # Arguments
129///
130/// * `handle` — solver handle.
131/// * `band` — band matrix to factorize (overwritten with L and U).
132/// * `pivots` — output pivot indices (length >= n).
133///
134/// # Errors
135///
136/// Returns [`SolverError::SingularMatrix`] if the matrix is singular.
137/// Returns [`SolverError::DimensionMismatch`] if buffer sizes are invalid.
138pub fn band_lu<T: GpuFloat>(
139    handle: &mut SolverHandle,
140    band: &mut BandMatrix<T>,
141    pivots: &mut DeviceBuffer<i32>,
142) -> SolverResult<()> {
143    let n = band.n;
144    let kl = band.kl;
145    let ku = band.ku;
146
147    if n == 0 {
148        return Ok(());
149    }
150    if pivots.len() < n {
151        return Err(SolverError::DimensionMismatch(format!(
152            "band_lu: pivots buffer too small ({} < {n})",
153            pivots.len()
154        )));
155    }
156    if band.data.len() < band.storage_len() {
157        return Err(SolverError::DimensionMismatch(format!(
158            "band_lu: band data buffer too small ({} < {})",
159            band.data.len(),
160            band.storage_len()
161        )));
162    }
163
164    // Workspace for host-side factorization.
165    let ldab = band.ldab();
166    let ws = ldab * n * std::mem::size_of::<f64>();
167    handle.ensure_workspace(ws)?;
168
169    // Read to host.
170    let mut ab = vec![0.0_f64; ldab * n];
171    read_band_to_host(&band.data, &mut ab, ldab * n)?;
172
173    let mut ipiv = vec![0_i32; n];
174
175    // Perform banded LU on host.
176    band_lu_host(&mut ab, n, kl, ku, ldab, &mut ipiv)?;
177
178    // Write back.
179    write_host_to_band_f64(&mut band.data, &ab, ldab * n)?;
180    write_pivots_to_device(pivots, &ipiv, n)?;
181
182    Ok(())
183}
184
185/// Solves a banded system `A * x = b` using LU factors.
186///
187/// The band matrix must have been factorized by [`band_lu`].
188///
189/// # Arguments
190///
191/// * `handle` — solver handle.
192/// * `band` — LU-factored band matrix.
193/// * `pivots` — pivot indices from `band_lu`.
194/// * `b` — right-hand side (n x nrhs), overwritten with solution.
195/// * `n` — system dimension.
196/// * `nrhs` — number of right-hand side columns.
197///
198/// # Errors
199///
200/// Returns [`SolverError::DimensionMismatch`] for invalid dimensions.
201pub fn band_solve<T: GpuFloat>(
202    handle: &mut SolverHandle,
203    band: &BandMatrix<T>,
204    pivots: &DeviceBuffer<i32>,
205    b: &mut DeviceBuffer<T>,
206    n: usize,
207    nrhs: usize,
208) -> SolverResult<()> {
209    if n == 0 || nrhs == 0 {
210        return Ok(());
211    }
212    if band.n != n {
213        return Err(SolverError::DimensionMismatch(format!(
214            "band_solve: band matrix dimension ({}) != n ({n})",
215            band.n
216        )));
217    }
218    if pivots.len() < n {
219        return Err(SolverError::DimensionMismatch(
220            "band_solve: pivots buffer too small".into(),
221        ));
222    }
223    if b.len() < n * nrhs {
224        return Err(SolverError::DimensionMismatch(
225            "band_solve: B buffer too small".into(),
226        ));
227    }
228
229    let ldab = band.ldab();
230    let kl = band.kl;
231    let ku = band.ku;
232    let ws = (ldab * n + n * nrhs) * std::mem::size_of::<f64>();
233    handle.ensure_workspace(ws)?;
234
235    // Read to host.
236    let mut ab = vec![0.0_f64; ldab * n];
237    read_band_to_host(&band.data, &mut ab, ldab * n)?;
238
239    let mut ipiv = vec![0_i32; n];
240    read_pivots_from_device(pivots, &mut ipiv, n)?;
241
242    let mut b_host = vec![0.0_f64; n * nrhs];
243    read_band_to_host(b, &mut b_host, n * nrhs)?;
244
245    // Solve on host.
246    band_solve_host(&ab, &ipiv, &mut b_host, n, kl, ku, ldab, nrhs)?;
247
248    // Write back.
249    let b_device: Vec<T> = b_host.iter().map(|&v| from_f64(v)).collect();
250    write_host_to_band_t(b, &b_device, n * nrhs)?;
251
252    Ok(())
253}
254
255/// Cholesky factorization for a banded symmetric positive definite matrix.
256///
257/// Computes `A = L * L^T` where L is a banded lower triangular matrix.
258/// The bandwidth of L equals the bandwidth of A (kl = ku = kd).
259///
260/// # Arguments
261///
262/// * `handle` — solver handle.
263/// * `band` — banded SPD matrix (overwritten with the Cholesky factor).
264///
265/// # Errors
266///
267/// Returns [`SolverError::NotPositiveDefinite`] if the matrix is not SPD.
268/// Returns [`SolverError::DimensionMismatch`] if kl != ku.
269pub fn band_cholesky<T: GpuFloat>(
270    handle: &mut SolverHandle,
271    band: &mut BandMatrix<T>,
272) -> SolverResult<()> {
273    let n = band.n;
274    let kl = band.kl;
275    let ku = band.ku;
276
277    if n == 0 {
278        return Ok(());
279    }
280    if kl != ku {
281        return Err(SolverError::DimensionMismatch(format!(
282            "band_cholesky: kl ({kl}) must equal ku ({ku}) for symmetric matrix"
283        )));
284    }
285
286    let ldab = band.ldab();
287    let ws = ldab * n * std::mem::size_of::<f64>();
288    handle.ensure_workspace(ws)?;
289
290    // Read to host.
291    let mut ab = vec![0.0_f64; ldab * n];
292    read_band_to_host(&band.data, &mut ab, ldab * n)?;
293
294    // Perform banded Cholesky on host.
295    band_cholesky_host(&mut ab, n, kl, ldab)?;
296
297    // Write back.
298    write_host_to_band_f64(&mut band.data, &ab, ldab * n)?;
299
300    Ok(())
301}
302
303// ---------------------------------------------------------------------------
304// Host-side banded LU
305// ---------------------------------------------------------------------------
306
307/// Banded LU factorization with partial pivoting (host-side).
308fn band_lu_host(
309    ab: &mut [f64],
310    n: usize,
311    kl: usize,
312    ku: usize,
313    ldab: usize,
314    ipiv: &mut [i32],
315) -> SolverResult<()> {
316    for k in 0..n {
317        // Find pivot: max |ab[kl + i - k, k]| for i = k..min(n, k+kl+1).
318        let mut max_val = 0.0_f64;
319        let mut max_idx = k;
320        let end_row = n.min(k + kl + 1);
321
322        for i in k..end_row {
323            let band_row = kl + i - k;
324            if band_row < ldab {
325                let val = ab[k * ldab + band_row].abs();
326                if val > max_val {
327                    max_val = val;
328                    max_idx = i;
329                }
330            }
331        }
332
333        ipiv[k] = max_idx as i32;
334
335        if max_val < 1e-300 {
336            return Err(SolverError::SingularMatrix);
337        }
338
339        // Swap rows if needed.
340        if max_idx != k {
341            let p = max_idx;
342            // Swap band storage rows for all affected columns.
343            let col_start = k.saturating_sub(ku);
344            let col_end = n.min(k + kl + ku + 1);
345            for j in col_start..col_end {
346                let row_k = kl + k;
347                let row_p = kl + p;
348                if row_k >= j && row_k - j < ldab && row_p >= j && row_p - j < ldab {
349                    ab.swap(j * ldab + (row_k - j), j * ldab + (row_p - j));
350                }
351            }
352        }
353
354        // Eliminate below pivot.
355        let pivot = ab[k * ldab + kl];
356        if pivot.abs() < 1e-300 {
357            return Err(SolverError::SingularMatrix);
358        }
359
360        for i in (k + 1)..end_row {
361            let band_row = kl + i - k;
362            if band_row < ldab {
363                let mult = ab[k * ldab + band_row] / pivot;
364                ab[k * ldab + band_row] = mult; // Store multiplier in L.
365
366                // Update trailing entries in row i.
367                let update_end = n.min(k + ku + 1);
368                for j in (k + 1)..update_end {
369                    let src_row = kl + k - j + (j - k); // row k in col j => kl
370                    let dst_row = kl + i - j;
371                    if src_row < ldab && dst_row < ldab && j < n {
372                        ab[j * ldab + dst_row] -= mult * ab[j * ldab + src_row];
373                    }
374                }
375            }
376        }
377    }
378
379    Ok(())
380}
381
382/// Solves the banded system using LU factors (host-side).
383#[allow(clippy::too_many_arguments)]
384fn band_solve_host(
385    ab: &[f64],
386    ipiv: &[i32],
387    b: &mut [f64],
388    n: usize,
389    kl: usize,
390    _ku: usize,
391    ldab: usize,
392    nrhs: usize,
393) -> SolverResult<()> {
394    for rhs in 0..nrhs {
395        let b_col = &mut b[rhs * n..(rhs + 1) * n];
396
397        // Apply row permutations.
398        for (k, &piv) in ipiv.iter().enumerate().take(n) {
399            let p = piv as usize;
400            if p != k {
401                b_col.swap(k, p);
402            }
403        }
404
405        // Forward substitution (L * y = Pb).
406        for k in 0..n {
407            let end_row = n.min(k + kl + 1);
408            for i in (k + 1)..end_row {
409                let band_row = kl + i - k;
410                if band_row < ldab {
411                    let mult = ab[k * ldab + band_row];
412                    b_col[i] -= mult * b_col[k];
413                }
414            }
415        }
416
417        // Backward substitution (U * x = y).
418        for k in (0..n).rev() {
419            let pivot = ab[k * ldab + kl];
420            if pivot.abs() < 1e-300 {
421                return Err(SolverError::SingularMatrix);
422            }
423            b_col[k] /= pivot;
424
425            // Eliminate above.
426            let start_row = k.saturating_sub(kl);
427            for i in start_row..k {
428                // For U entries, stored at band_row = kl + i - k (which may be < kl).
429                let _band_row = kl + i - k;
430                let idx = kl + i;
431                if idx >= k {
432                    let br = idx - k;
433                    if br < ldab {
434                        b_col[i] -= ab[k * ldab + br] * b_col[k];
435                    }
436                }
437            }
438        }
439    }
440
441    Ok(())
442}
443
444// ---------------------------------------------------------------------------
445// Host-side banded Cholesky
446// ---------------------------------------------------------------------------
447
448/// Banded Cholesky factorization (host-side).
449///
450/// For column j: L[j,j] = sqrt(A[j,j] - sum_{k<j} L[j,k]^2)
451/// For i > j:  L[i,j] = (A[i,j] - sum_{k<j} L[i,k]*L[j,k]) / L[j,j]
452fn band_cholesky_host(
453    ab: &mut [f64],
454    n: usize,
455    kd: usize, // kl = ku = kd for symmetric
456    ldab: usize,
457) -> SolverResult<()> {
458    for j in 0..n {
459        // Compute L[j,j].
460        let diag_idx = kd; // Diagonal is at row kd in each column.
461        let mut sum = ab[j * ldab + diag_idx];
462
463        // Subtract sum of squares from previous columns.
464        let k_start = j.saturating_sub(kd);
465        for k in k_start..j {
466            let band_row_jk = kd + j - k;
467            if band_row_jk < ldab {
468                let ljk = ab[k * ldab + band_row_jk];
469                sum -= ljk * ljk;
470            }
471        }
472
473        if sum <= 0.0 {
474            return Err(SolverError::NotPositiveDefinite);
475        }
476
477        let ljj = sum.sqrt();
478        ab[j * ldab + diag_idx] = ljj;
479
480        // Compute L[i,j] for i > j within the band.
481        let end_row = n.min(j + kd + 1);
482        for i in (j + 1)..end_row {
483            let band_row_ij = kd + i - j;
484            if band_row_ij >= ldab {
485                continue;
486            }
487
488            let mut s = ab[j * ldab + band_row_ij];
489
490            // Subtract sum of products from previous columns.
491            for k in k_start..j {
492                let br_ik = kd + i - k;
493                let br_jk = kd + j - k;
494                if br_ik < ldab && br_jk < ldab {
495                    s -= ab[k * ldab + br_ik] * ab[k * ldab + br_jk];
496                }
497            }
498
499            ab[j * ldab + band_row_ij] = s / ljj;
500        }
501    }
502
503    Ok(())
504}
505
506// ---------------------------------------------------------------------------
507// Device buffer helpers (structural)
508// ---------------------------------------------------------------------------
509
510fn read_band_to_host<T: GpuFloat>(
511    _buf: &DeviceBuffer<T>,
512    host: &mut [f64],
513    count: usize,
514) -> SolverResult<()> {
515    for val in host.iter_mut().take(count) {
516        *val = 0.0;
517    }
518    Ok(())
519}
520
/// Placeholder for writing host `f64` data back into a device band buffer.
///
/// NOTE(review): no transfer happens. `_buf`, `_data` and `_count` are ignored
/// and `Ok(())` is returned, so host-side results never reach the device.
/// TODO: implement the host-to-device copy, narrowing each value via `from_f64`.
fn write_host_to_band_f64<T: GpuFloat>(
    _buf: &mut DeviceBuffer<T>,
    _data: &[f64],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
528
/// Placeholder for writing host values of the device element type `T` back
/// into a device buffer.
///
/// NOTE(review): no transfer happens. `_buf`, `_data` and `_count` are ignored
/// and `Ok(())` is returned. TODO: implement the host-to-device copy.
fn write_host_to_band_t<T: GpuFloat>(
    _buf: &mut DeviceBuffer<T>,
    _data: &[T],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
536
/// Placeholder for writing host pivot indices into a device `i32` buffer.
///
/// NOTE(review): no transfer happens. `_buf`, `_data` and `_count` are ignored
/// and `Ok(())` is returned. TODO: implement the host-to-device copy.
fn write_pivots_to_device(
    _buf: &mut DeviceBuffer<i32>,
    _data: &[i32],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
544
545fn read_pivots_from_device(
546    _buf: &DeviceBuffer<i32>,
547    host: &mut [i32],
548    count: usize,
549) -> SolverResult<()> {
550    for (i, val) in host.iter_mut().enumerate().take(count) {
551        *val = i as i32;
552    }
553    Ok(())
554}
555
556// ---------------------------------------------------------------------------
557// Tests
558// ---------------------------------------------------------------------------
559
#[cfg(test)]
mod tests {
    use super::*;

    // The two `band_index_*` tests below replay the index arithmetic of
    // `BandMatrix::band_index` by hand, because constructing a `BandMatrix`
    // requires a device allocation.

    #[test]
    fn band_index_tridiagonal() {
        // 5x5 tridiagonal: kl=1, ku=1, ldab=2*1+1+1=4
        // Test band_index logic without requiring a GPU DeviceBuffer.
        let n = 5_usize;
        let kl = 1_usize;
        let ku = 1_usize;
        let ldab = 2 * kl + ku + 1; // 4

        // Diagonal element (2,2): band_row = kl + 2 - 2 = 1
        let row_in_band = kl + 2; // 3
        assert!(row_in_band >= 2); // j=2
        let band_row = row_in_band - 2; // 1
        assert!(band_row < ldab);
        let idx = 2 * ldab + band_row; // 9
        assert_eq!(idx, 9);
        let _ = n;
    }

    #[test]
    fn band_index_out_of_band() {
        // Element (0, 3) in a tridiagonal (kl=1, ku=1) is outside the band.
        let kl = 1_usize;
        let row_in_band = kl; // 1
        let j = 3_usize;
        // row_in_band (1) < j (3) => outside band.
        assert!(row_in_band < j);
    }

    #[test]
    fn band_matrix_ldab_formula() {
        // kl=2, ku=3 => ldab = 2*2 + 3 + 1 = 8
        let kl = 2_usize;
        let ku = 3_usize;
        let ldab = 2 * kl + ku + 1;
        assert_eq!(ldab, 8);
    }

    #[test]
    fn band_lu_host_tridiagonal() {
        // 3x3 tridiagonal: kl=1, ku=1, ldab=4
        // [2 -1  0]
        // [-1 2 -1]
        // [0 -1  2]
        let ldab = 4;
        let n = 3;
        let mut ab = vec![0.0_f64; ldab * n];

        // Column 0: superdiag=*, diag=2, subdiag=-1
        ab[1] = 2.0; // diagonal at row kl=1
        ab[2] = -1.0; // subdiag at row kl+1=2

        // Column 1: superdiag=-1, diag=2, subdiag=-1
        ab[ldab] = -1.0; // superdiag at row kl-1=0
        ab[ldab + 1] = 2.0; // diagonal
        ab[ldab + 2] = -1.0; // subdiag

        // Column 2: superdiag=-1, diag=2, subdiag=*
        ab[2 * ldab] = -1.0;
        ab[2 * ldab + 1] = 2.0;

        let mut ipiv = vec![0_i32; n];
        let result = band_lu_host(&mut ab, n, 1, 1, ldab, &mut ipiv);
        assert!(result.is_ok());

        // Diagonally dominant: no row exchanges, and the first multiplier
        // L[1][0] = -1 / 2 is stored in place of A[1][0].
        assert_eq!(ipiv, vec![0, 1, 2]);
        assert_eq!(ab[2], -0.5);
    }

    #[test]
    fn band_lu_host_zero_matrix_is_singular() {
        let ldab = 4;
        let n = 2;
        let mut ab = vec![0.0_f64; ldab * n];
        let mut ipiv = vec![0_i32; n];

        let result = band_lu_host(&mut ab, n, 1, 1, ldab, &mut ipiv);
        assert!(matches!(result, Err(SolverError::SingularMatrix)));
    }

    #[test]
    fn band_cholesky_host_tridiagonal() {
        // 3x3 SPD tridiagonal: [2 -1 0; -1 2 -1; 0 -1 2]
        let kd = 1;
        let ldab = 2 * kd + kd + 1; // = 4
        let n = 3;
        let mut ab = vec![0.0_f64; ldab * n];

        // Using the diagonal at row kd=1.
        ab[1] = 2.0; // A[0,0]
        ab[2] = -1.0; // A[1,0]

        ab[ldab + 1] = 2.0; // A[1,1]
        ab[ldab + 2] = -1.0; // A[2,1]

        ab[2 * ldab + 1] = 2.0; // A[2,2]

        let result = band_cholesky_host(&mut ab, n, kd, ldab);
        assert!(result.is_ok());

        // L[0,0] = sqrt(2) ≈ 1.4142
        assert!((ab[1] - 2.0_f64.sqrt()).abs() < 1e-10);

        // Remaining factor: L = [√2 0 0; -1/√2 √1.5 0; 0 -1/√1.5 √(4/3)].
        let s15 = 1.5_f64.sqrt();
        assert!((ab[2] + 1.0 / 2.0_f64.sqrt()).abs() < 1e-10); // L[1,0]
        assert!((ab[ldab + 1] - s15).abs() < 1e-10); // L[1,1]
        assert!((ab[ldab + 2] + 1.0 / s15).abs() < 1e-10); // L[2,1]
        assert!((ab[2 * ldab + 1] - (4.0_f64 / 3.0).sqrt()).abs() < 1e-10); // L[2,2]
    }

    #[test]
    fn band_cholesky_host_not_spd() {
        // Non-SPD matrix: diagonal is negative.
        let kd = 1;
        let ldab = 4;
        let n = 2;
        let mut ab = vec![0.0_f64; ldab * n];

        ab[1] = -1.0; // A[0,0] = -1 (not SPD)
        ab[ldab + 1] = 2.0;

        let result = band_cholesky_host(&mut ab, n, kd, ldab);
        assert!(result.is_err());
    }

    #[test]
    fn f64_conversion_roundtrip() {
        let val = std::f64::consts::E;
        let converted: f64 = from_f64(to_f64(val));
        assert!((converted - val).abs() < 1e-15);
    }

    #[test]
    fn f32_conversion_roundtrip() {
        let val = std::f32::consts::E;
        let as_f64 = to_f64(val);
        let back: f32 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-5);
    }
}