1#![allow(dead_code)]
29
30use oxicuda_blas::GpuFloat;
31use oxicuda_memory::DeviceBuffer;
32
33use crate::error::{SolverError, SolverResult};
34use crate::handle::SolverHandle;
35
36fn to_f64<T: GpuFloat>(val: T) -> f64 {
41 if T::SIZE == 4 {
42 f32::from_bits(val.to_bits_u64() as u32) as f64
43 } else {
44 f64::from_bits(val.to_bits_u64())
45 }
46}
47
48fn from_f64<T: GpuFloat>(val: f64) -> T {
49 if T::SIZE == 4 {
50 T::from_bits_u64(u64::from((val as f32).to_bits()))
51 } else {
52 T::from_bits_u64(val.to_bits())
53 }
54}
55
/// Square band matrix whose coefficients are held in a device buffer.
///
/// Storage is column-major with `ldab = 2 * kl + ku + 1` slots per column;
/// entry `(i, j)` is addressed at row `kl + i - j` of column `j`
/// (see `band_index`).
pub struct BandMatrix<T: GpuFloat> {
    /// Band storage; `new` allocates `ldab * n` zeroed elements.
    pub data: DeviceBuffer<T>,
    /// Matrix order (number of columns held in `data`).
    pub n: usize,
    /// Lower bandwidth: the factorizations look at most `kl` rows below the diagonal.
    pub kl: usize,
    /// Upper bandwidth as supplied by the caller.
    pub ku: usize,
}
76
77impl<T: GpuFloat> BandMatrix<T> {
78 pub fn new(n: usize, kl: usize, ku: usize) -> SolverResult<Self> {
87 let ldab = 2 * kl + ku + 1;
88 let data = DeviceBuffer::<T>::zeroed(ldab * n)?;
89 Ok(Self { n, kl, ku, data })
90 }
91
92 pub fn ldab(&self) -> usize {
94 2 * self.kl + self.ku + 1
95 }
96
97 pub fn storage_len(&self) -> usize {
99 self.ldab() * self.n
100 }
101
102 pub fn band_index(&self, i: usize, j: usize) -> Option<usize> {
106 let row_in_band = self.kl + i;
107 if row_in_band < j {
108 return None; }
110 let band_row = row_in_band - j;
111 if band_row >= self.ldab() {
112 return None; }
114 Some(j * self.ldab() + band_row)
115 }
116}
117
118pub fn band_lu<T: GpuFloat>(
139 handle: &mut SolverHandle,
140 band: &mut BandMatrix<T>,
141 pivots: &mut DeviceBuffer<i32>,
142) -> SolverResult<()> {
143 let n = band.n;
144 let kl = band.kl;
145 let ku = band.ku;
146
147 if n == 0 {
148 return Ok(());
149 }
150 if pivots.len() < n {
151 return Err(SolverError::DimensionMismatch(format!(
152 "band_lu: pivots buffer too small ({} < {n})",
153 pivots.len()
154 )));
155 }
156 if band.data.len() < band.storage_len() {
157 return Err(SolverError::DimensionMismatch(format!(
158 "band_lu: band data buffer too small ({} < {})",
159 band.data.len(),
160 band.storage_len()
161 )));
162 }
163
164 let ldab = band.ldab();
166 let ws = ldab * n * std::mem::size_of::<f64>();
167 handle.ensure_workspace(ws)?;
168
169 let mut ab = vec![0.0_f64; ldab * n];
171 read_band_to_host(&band.data, &mut ab, ldab * n)?;
172
173 let mut ipiv = vec![0_i32; n];
174
175 band_lu_host(&mut ab, n, kl, ku, ldab, &mut ipiv)?;
177
178 write_host_to_band_f64(&mut band.data, &ab, ldab * n)?;
180 write_pivots_to_device(pivots, &ipiv, n)?;
181
182 Ok(())
183}
184
185pub fn band_solve<T: GpuFloat>(
202 handle: &mut SolverHandle,
203 band: &BandMatrix<T>,
204 pivots: &DeviceBuffer<i32>,
205 b: &mut DeviceBuffer<T>,
206 n: usize,
207 nrhs: usize,
208) -> SolverResult<()> {
209 if n == 0 || nrhs == 0 {
210 return Ok(());
211 }
212 if band.n != n {
213 return Err(SolverError::DimensionMismatch(format!(
214 "band_solve: band matrix dimension ({}) != n ({n})",
215 band.n
216 )));
217 }
218 if pivots.len() < n {
219 return Err(SolverError::DimensionMismatch(
220 "band_solve: pivots buffer too small".into(),
221 ));
222 }
223 if b.len() < n * nrhs {
224 return Err(SolverError::DimensionMismatch(
225 "band_solve: B buffer too small".into(),
226 ));
227 }
228
229 let ldab = band.ldab();
230 let kl = band.kl;
231 let ku = band.ku;
232 let ws = (ldab * n + n * nrhs) * std::mem::size_of::<f64>();
233 handle.ensure_workspace(ws)?;
234
235 let mut ab = vec![0.0_f64; ldab * n];
237 read_band_to_host(&band.data, &mut ab, ldab * n)?;
238
239 let mut ipiv = vec![0_i32; n];
240 read_pivots_from_device(pivots, &mut ipiv, n)?;
241
242 let mut b_host = vec![0.0_f64; n * nrhs];
243 read_band_to_host(b, &mut b_host, n * nrhs)?;
244
245 band_solve_host(&ab, &ipiv, &mut b_host, n, kl, ku, ldab, nrhs)?;
247
248 let b_device: Vec<T> = b_host.iter().map(|&v| from_f64(v)).collect();
250 write_host_to_band_t(b, &b_device, n * nrhs)?;
251
252 Ok(())
253}
254
255pub fn band_cholesky<T: GpuFloat>(
270 handle: &mut SolverHandle,
271 band: &mut BandMatrix<T>,
272) -> SolverResult<()> {
273 let n = band.n;
274 let kl = band.kl;
275 let ku = band.ku;
276
277 if n == 0 {
278 return Ok(());
279 }
280 if kl != ku {
281 return Err(SolverError::DimensionMismatch(format!(
282 "band_cholesky: kl ({kl}) must equal ku ({ku}) for symmetric matrix"
283 )));
284 }
285
286 let ldab = band.ldab();
287 let ws = ldab * n * std::mem::size_of::<f64>();
288 handle.ensure_workspace(ws)?;
289
290 let mut ab = vec![0.0_f64; ldab * n];
292 read_band_to_host(&band.data, &mut ab, ldab * n)?;
293
294 band_cholesky_host(&mut ab, n, kl, ldab)?;
296
297 write_host_to_band_f64(&mut band.data, &ab, ldab * n)?;
299
300 Ok(())
301}
302
303fn band_lu_host(
309 ab: &mut [f64],
310 n: usize,
311 kl: usize,
312 ku: usize,
313 ldab: usize,
314 ipiv: &mut [i32],
315) -> SolverResult<()> {
316 for k in 0..n {
317 let mut max_val = 0.0_f64;
319 let mut max_idx = k;
320 let end_row = n.min(k + kl + 1);
321
322 for i in k..end_row {
323 let band_row = kl + i - k;
324 if band_row < ldab {
325 let val = ab[k * ldab + band_row].abs();
326 if val > max_val {
327 max_val = val;
328 max_idx = i;
329 }
330 }
331 }
332
333 ipiv[k] = max_idx as i32;
334
335 if max_val < 1e-300 {
336 return Err(SolverError::SingularMatrix);
337 }
338
339 if max_idx != k {
341 let p = max_idx;
342 let col_start = k.saturating_sub(ku);
344 let col_end = n.min(k + kl + ku + 1);
345 for j in col_start..col_end {
346 let row_k = kl + k;
347 let row_p = kl + p;
348 if row_k >= j && row_k - j < ldab && row_p >= j && row_p - j < ldab {
349 ab.swap(j * ldab + (row_k - j), j * ldab + (row_p - j));
350 }
351 }
352 }
353
354 let pivot = ab[k * ldab + kl];
356 if pivot.abs() < 1e-300 {
357 return Err(SolverError::SingularMatrix);
358 }
359
360 for i in (k + 1)..end_row {
361 let band_row = kl + i - k;
362 if band_row < ldab {
363 let mult = ab[k * ldab + band_row] / pivot;
364 ab[k * ldab + band_row] = mult; let update_end = n.min(k + ku + 1);
368 for j in (k + 1)..update_end {
369 let src_row = kl + k - j + (j - k); let dst_row = kl + i - j;
371 if src_row < ldab && dst_row < ldab && j < n {
372 ab[j * ldab + dst_row] -= mult * ab[j * ldab + src_row];
373 }
374 }
375 }
376 }
377 }
378
379 Ok(())
380}
381
382#[allow(clippy::too_many_arguments)]
384fn band_solve_host(
385 ab: &[f64],
386 ipiv: &[i32],
387 b: &mut [f64],
388 n: usize,
389 kl: usize,
390 _ku: usize,
391 ldab: usize,
392 nrhs: usize,
393) -> SolverResult<()> {
394 for rhs in 0..nrhs {
395 let b_col = &mut b[rhs * n..(rhs + 1) * n];
396
397 for (k, &piv) in ipiv.iter().enumerate().take(n) {
399 let p = piv as usize;
400 if p != k {
401 b_col.swap(k, p);
402 }
403 }
404
405 for k in 0..n {
407 let end_row = n.min(k + kl + 1);
408 for i in (k + 1)..end_row {
409 let band_row = kl + i - k;
410 if band_row < ldab {
411 let mult = ab[k * ldab + band_row];
412 b_col[i] -= mult * b_col[k];
413 }
414 }
415 }
416
417 for k in (0..n).rev() {
419 let pivot = ab[k * ldab + kl];
420 if pivot.abs() < 1e-300 {
421 return Err(SolverError::SingularMatrix);
422 }
423 b_col[k] /= pivot;
424
425 let start_row = k.saturating_sub(kl);
427 for i in start_row..k {
428 let _band_row = kl + i - k;
430 let idx = kl + i;
431 if idx >= k {
432 let br = idx - k;
433 if br < ldab {
434 b_col[i] -= ab[k * ldab + br] * b_col[k];
435 }
436 }
437 }
438 }
439 }
440
441 Ok(())
442}
443
444fn band_cholesky_host(
453 ab: &mut [f64],
454 n: usize,
455 kd: usize, ldab: usize,
457) -> SolverResult<()> {
458 for j in 0..n {
459 let diag_idx = kd; let mut sum = ab[j * ldab + diag_idx];
462
463 let k_start = j.saturating_sub(kd);
465 for k in k_start..j {
466 let band_row_jk = kd + j - k;
467 if band_row_jk < ldab {
468 let ljk = ab[k * ldab + band_row_jk];
469 sum -= ljk * ljk;
470 }
471 }
472
473 if sum <= 0.0 {
474 return Err(SolverError::NotPositiveDefinite);
475 }
476
477 let ljj = sum.sqrt();
478 ab[j * ldab + diag_idx] = ljj;
479
480 let end_row = n.min(j + kd + 1);
482 for i in (j + 1)..end_row {
483 let band_row_ij = kd + i - j;
484 if band_row_ij >= ldab {
485 continue;
486 }
487
488 let mut s = ab[j * ldab + band_row_ij];
489
490 for k in k_start..j {
492 let br_ik = kd + i - k;
493 let br_jk = kd + j - k;
494 if br_ik < ldab && br_jk < ldab {
495 s -= ab[k * ldab + br_ik] * ab[k * ldab + br_jk];
496 }
497 }
498
499 ab[j * ldab + band_row_ij] = s / ljj;
500 }
501 }
502
503 Ok(())
504}
505
506fn read_band_to_host<T: GpuFloat>(
511 _buf: &DeviceBuffer<T>,
512 host: &mut [f64],
513 count: usize,
514) -> SolverResult<()> {
515 for val in host.iter_mut().take(count) {
516 *val = 0.0;
517 }
518 Ok(())
519}
520
/// Placeholder for a host-to-device copy of `f64` band data.
///
/// All arguments are ignored and `Ok(())` is returned; nothing is written to
/// the device buffer.
// NOTE(review): presumably a stand-in until a real device write (with
// conversion from `f64` to `T`) is wired in — results computed on the host
// are currently discarded.
fn write_host_to_band_f64<T: GpuFloat>(
    _buf: &mut DeviceBuffer<T>,
    _data: &[f64],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
528
/// Placeholder for a host-to-device copy of already-converted `T` data.
///
/// All arguments are ignored and `Ok(())` is returned; nothing is written to
/// the device buffer.
// NOTE(review): presumably a stand-in until a real device write is wired in.
fn write_host_to_band_t<T: GpuFloat>(
    _buf: &mut DeviceBuffer<T>,
    _data: &[T],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
536
/// Placeholder for a host-to-device copy of pivot indices.
///
/// All arguments are ignored and `Ok(())` is returned; the device pivot
/// buffer is left unchanged.
// NOTE(review): presumably a stand-in until a real device write is wired in.
fn write_pivots_to_device(
    _buf: &mut DeviceBuffer<i32>,
    _data: &[i32],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
544
545fn read_pivots_from_device(
546 _buf: &DeviceBuffer<i32>,
547 host: &mut [i32],
548 count: usize,
549) -> SolverResult<()> {
550 for (i, val) in host.iter_mut().enumerate().take(count) {
551 *val = i as i32;
552 }
553 Ok(())
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Re-derives by hand the flat index of entry (2, 2) for a 5x5
    /// tridiagonal layout (`kl = ku = 1`).
    #[test]
    fn band_index_tridiagonal() {
        let n = 5_usize;
        let (kl, ku) = (1_usize, 1_usize);
        let ldab = 2 * kl + ku + 1;
        let (i, j) = (2_usize, 2_usize);

        let row_in_band = kl + i;
        assert!(row_in_band >= j);
        let band_row = row_in_band - j;
        assert!(band_row < ldab);

        assert_eq!(j * ldab + band_row, 9);
        let _ = n;
    }

    /// Entry (0, 3) lies further right of the diagonal than `kl` allows.
    #[test]
    fn band_index_out_of_band() {
        let kl = 1_usize;
        let (i, j) = (0_usize, 3_usize);
        let row_in_band = kl + i;
        assert!(row_in_band < j);
    }

    /// `ldab` is `2 * kl + ku + 1`.
    #[test]
    fn band_matrix_ldab_formula() {
        let (kl, ku) = (2_usize, 3_usize);
        assert_eq!(2 * kl + ku + 1, 8);
    }

    /// The 3x3 tridiagonal matrix [[2,-1,0],[-1,2,-1],[0,-1,2]] factorizes
    /// without reporting an error.
    #[test]
    fn band_lu_host_tridiagonal() {
        let (n, ldab) = (3, 4);
        // One row per matrix column: [super-diagonal, diagonal, sub-diagonal, padding].
        let columns = [
            [0.0, 2.0, -1.0, 0.0],
            [-1.0, 2.0, -1.0, 0.0],
            [-1.0, 2.0, 0.0, 0.0],
        ];
        let mut ab = vec![0.0_f64; ldab * n];
        for (dst, src) in ab.chunks_exact_mut(ldab).zip(columns.iter()) {
            dst.copy_from_slice(src);
        }

        let mut ipiv = vec![0_i32; n];
        let outcome = band_lu_host(&mut ab, n, 1, 1, ldab, &mut ipiv);
        assert!(outcome.is_ok());
    }

    /// Cholesky of the same tridiagonal matrix succeeds and the first
    /// diagonal entry of the factor is sqrt(2).
    #[test]
    fn band_cholesky_host_tridiagonal() {
        let kd = 1;
        let ldab = 2 * kd + kd + 1;
        let n = 3;
        // One row per matrix column: [unused, diagonal, sub-diagonal, padding].
        let columns = [
            [0.0, 2.0, -1.0, 0.0],
            [0.0, 2.0, -1.0, 0.0],
            [0.0, 2.0, 0.0, 0.0],
        ];
        let mut ab = vec![0.0_f64; ldab * n];
        for (dst, src) in ab.chunks_exact_mut(ldab).zip(columns.iter()) {
            dst.copy_from_slice(src);
        }

        let outcome = band_cholesky_host(&mut ab, n, kd, ldab);
        assert!(outcome.is_ok());

        let expected_l00 = 2.0_f64.sqrt();
        assert!((ab[1] - expected_l00).abs() < 1e-10);
    }

    /// A negative leading diagonal entry is rejected.
    #[test]
    fn band_cholesky_host_not_spd() {
        let (kd, ldab, n) = (1, 4, 2);
        let mut ab = vec![0.0_f64; ldab * n];
        ab[1] = -1.0;
        ab[ldab + 1] = 2.0;

        assert!(band_cholesky_host(&mut ab, n, kd, ldab).is_err());
    }

    /// `f64 -> f64` conversion through the bit-level helpers is lossless.
    #[test]
    fn f64_conversion_roundtrip() {
        let original = std::f64::consts::E;
        let restored: f64 = from_f64(to_f64(original));
        assert!((restored - original).abs() < 1e-15);
    }

    /// `f32 -> f64 -> f32` conversion returns the starting value.
    #[test]
    fn f32_conversion_roundtrip() {
        let original = std::f32::consts::E;
        let widened = to_f64(original);
        let restored: f32 = from_f64(widened);
        assert!((restored - original).abs() < 1e-5);
    }
}