1#![allow(dead_code)]
23
24use oxicuda_blas::GpuFloat;
25
26use crate::error::{SolverError, SolverResult};
27use crate::handle::SolverHandle;
28
29fn to_f64<T: GpuFloat>(val: T) -> f64 {
34 if T::SIZE == 4 {
35 f32::from_bits(val.to_bits_u64() as u32) as f64
36 } else {
37 f64::from_bits(val.to_bits_u64())
38 }
39}
40
41fn from_f64<T: GpuFloat>(val: f64) -> T {
42 if T::SIZE == 4 {
43 T::from_bits_u64(u64::from((val as f32).to_bits()))
44 } else {
45 T::from_bits_u64(val.to_bits())
46 }
47}
48
49pub fn tridiagonal_solve<T: GpuFloat>(
74 _handle: &SolverHandle,
75 lower: &[T],
76 diag: &[T],
77 upper: &[T],
78 rhs: &mut [T],
79 n: usize,
80) -> SolverResult<()> {
81 validate_tridiagonal_dims(lower, diag, upper, rhs, n)?;
82
83 if n == 0 {
84 return Ok(());
85 }
86 if n == 1 {
87 return solve_1x1(diag, rhs);
88 }
89
90 if n <= 64 {
94 thomas_solve(lower, diag, upper, rhs, n)
95 } else {
96 cyclic_reduction_solve(lower, diag, upper, rhs, n)
97 }
98}
99
100pub fn batched_tridiagonal_solve<T: GpuFloat>(
120 _handle: &SolverHandle,
121 lower: &[T],
122 diag: &[T],
123 upper: &[T],
124 rhs: &mut [T],
125 n: usize,
126 batch_count: usize,
127) -> SolverResult<()> {
128 if batch_count == 0 || n == 0 {
129 return Ok(());
130 }
131
132 let sub_len = n.saturating_sub(1);
133 let expected_lower = batch_count * sub_len;
134 let expected_diag = batch_count * n;
135 let expected_upper = batch_count * sub_len;
136 let expected_rhs = batch_count * n;
137
138 if lower.len() < expected_lower {
139 return Err(SolverError::DimensionMismatch(format!(
140 "batched_tridiagonal_solve: lower length ({}) < expected ({})",
141 lower.len(),
142 expected_lower
143 )));
144 }
145 if diag.len() < expected_diag {
146 return Err(SolverError::DimensionMismatch(format!(
147 "batched_tridiagonal_solve: diag length ({}) < expected ({})",
148 diag.len(),
149 expected_diag
150 )));
151 }
152 if upper.len() < expected_upper {
153 return Err(SolverError::DimensionMismatch(format!(
154 "batched_tridiagonal_solve: upper length ({}) < expected ({})",
155 upper.len(),
156 expected_upper
157 )));
158 }
159 if rhs.len() < expected_rhs {
160 return Err(SolverError::DimensionMismatch(format!(
161 "batched_tridiagonal_solve: rhs length ({}) < expected ({})",
162 rhs.len(),
163 expected_rhs
164 )));
165 }
166
167 for k in 0..batch_count {
170 let l_start = k * sub_len;
171 let d_start = k * n;
172 let u_start = k * sub_len;
173 let r_start = k * n;
174
175 let l_slice = &lower[l_start..l_start + sub_len];
176 let d_slice = &diag[d_start..d_start + n];
177 let u_slice = &upper[u_start..u_start + sub_len];
178 let r_slice = &mut rhs[r_start..r_start + n];
179
180 if n == 1 {
181 solve_1x1(d_slice, r_slice)?;
182 } else {
183 thomas_solve(l_slice, d_slice, u_slice, r_slice, n)?;
184 }
185 }
186
187 Ok(())
188}
189
190fn thomas_solve<T: GpuFloat>(
199 lower: &[T],
200 diag: &[T],
201 upper: &[T],
202 rhs: &mut [T],
203 n: usize,
204) -> SolverResult<()> {
205 let mut c_prime = vec![0.0_f64; n];
207 let mut d_prime = vec![0.0_f64; n];
208
209 let d0 = to_f64(diag[0]);
210 if d0.abs() < 1e-300 {
211 return Err(SolverError::SingularMatrix);
212 }
213
214 c_prime[0] = to_f64(upper[0]) / d0;
215 d_prime[0] = to_f64(rhs[0]) / d0;
216
217 for i in 1..n {
219 let a_i = to_f64(lower[i - 1]);
220 let b_i = to_f64(diag[i]);
221 let d_i = to_f64(rhs[i]);
222
223 let denom = b_i - a_i * c_prime[i - 1];
224 if denom.abs() < 1e-300 {
225 return Err(SolverError::SingularMatrix);
226 }
227
228 if i < n - 1 {
229 c_prime[i] = to_f64(upper[i]) / denom;
230 }
231 d_prime[i] = (d_i - a_i * d_prime[i - 1]) / denom;
232 }
233
234 rhs[n - 1] = from_f64(d_prime[n - 1]);
236 for i in (0..n - 1).rev() {
237 d_prime[i] -= c_prime[i] * to_f64(rhs[i + 1]);
238 rhs[i] = from_f64(d_prime[i]);
239 }
240
241 Ok(())
242}
243
244fn cyclic_reduction_solve<T: GpuFloat>(
256 lower: &[T],
257 diag: &[T],
258 upper: &[T],
259 rhs: &mut [T],
260 n: usize,
261) -> SolverResult<()> {
262 let mut a = vec![0.0_f64; n]; let mut b = vec![0.0_f64; n]; let mut c = vec![0.0_f64; n]; let mut d = vec![0.0_f64; n]; b[0] = to_f64(diag[0]);
270 d[0] = to_f64(rhs[0]);
271 if n > 1 {
272 c[0] = to_f64(upper[0]);
273 }
274
275 for i in 1..n {
276 a[i] = to_f64(lower[i - 1]);
277 b[i] = to_f64(diag[i]);
278 d[i] = to_f64(rhs[i]);
279 if i < n - 1 {
280 c[i] = to_f64(upper[i]);
281 }
282 }
283
284 let mut stride = 1_usize;
286 let mut active_n = n;
287
288 type ReductionLevel = (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>);
291 let mut levels: Vec<ReductionLevel> = Vec::new();
292
293 while active_n > 2 {
294 levels.push((a.clone(), b.clone(), c.clone(), d.clone()));
295
296 let mut a_new = vec![0.0_f64; n];
297 let mut b_new = vec![0.0_f64; n];
298 let mut c_new = vec![0.0_f64; n];
299 let mut d_new = vec![0.0_f64; n];
300
301 let mut count = 0;
304 let mut i = stride;
305 while i < n {
306 let left = i.saturating_sub(stride);
307 let right = if i + stride < n { i + stride } else { n - 1 };
308
309 let bi = b[i];
310 if bi.abs() < 1e-300 {
311 return Err(SolverError::SingularMatrix);
312 }
313
314 let alpha = if left < i && b[left].abs() > 1e-300 {
316 -a[i] / b[left]
317 } else {
318 0.0
319 };
320 let gamma = if right > i && b[right].abs() > 1e-300 {
321 -c[i] / b[right]
322 } else {
323 0.0
324 };
325
326 b_new[i] = bi + alpha * c[left] + gamma * a[right];
328 d_new[i] = d[i] + alpha * d[left] + gamma * d[right];
329 a_new[i] = alpha * a[left];
330 c_new[i] = gamma * c[right];
331
332 count += 1;
333 i += 2 * stride;
334 }
335
336 a = a_new;
337 b = b_new;
338 c = c_new;
339 d = d_new;
340
341 stride *= 2;
342 active_n = count;
343
344 if active_n <= 1 {
345 break;
346 }
347 }
348
349 let mut active_indices = Vec::new();
352 let mut idx = stride - 1;
353 if idx >= n {
354 return thomas_solve(lower, diag, upper, rhs, n);
356 }
357 while idx < n {
358 active_indices.push(idx);
359 idx += stride;
360 }
361
362 match active_indices.len() {
364 0 => {}
365 1 => {
366 let i = active_indices[0];
367 if b[i].abs() < 1e-300 {
368 return Err(SolverError::SingularMatrix);
369 }
370 d[i] /= b[i];
371 }
372 2 => {
373 let i0 = active_indices[0];
374 let i1 = active_indices[1];
375 let det = b[i0] * b[i1] - c[i0] * a[i1];
378 if det.abs() < 1e-300 {
379 return Err(SolverError::SingularMatrix);
380 }
381 let x0 = (d[i0] * b[i1] - c[i0] * d[i1]) / det;
382 let x1 = (b[i0] * d[i1] - a[i1] * d[i0]) / det;
383 d[i0] = x0;
384 d[i1] = x1;
385 }
386 _ => {
387 let k = active_indices.len();
389 let mut sub_a = vec![0.0_f64; k];
390 let mut sub_b = vec![0.0_f64; k];
391 let mut sub_c = vec![0.0_f64; k];
392 let mut sub_d = vec![0.0_f64; k];
393 for (j, &ai) in active_indices.iter().enumerate() {
394 sub_a[j] = a[ai];
395 sub_b[j] = b[ai];
396 sub_c[j] = c[ai];
397 sub_d[j] = d[ai];
398 }
399 if sub_b[0].abs() < 1e-300 {
401 return Err(SolverError::SingularMatrix);
402 }
403 let mut cp = vec![0.0_f64; k];
404 let mut dp = vec![0.0_f64; k];
405 cp[0] = sub_c[0] / sub_b[0];
406 dp[0] = sub_d[0] / sub_b[0];
407 for j in 1..k {
408 let denom = sub_b[j] - sub_a[j] * cp[j - 1];
409 if denom.abs() < 1e-300 {
410 return Err(SolverError::SingularMatrix);
411 }
412 cp[j] = sub_c[j] / denom;
413 dp[j] = (sub_d[j] - sub_a[j] * dp[j - 1]) / denom;
414 }
415 sub_d[k - 1] = dp[k - 1];
416 for j in (0..k - 1).rev() {
417 sub_d[j] = dp[j] - cp[j] * sub_d[j + 1];
418 }
419 for (j, &ai) in active_indices.iter().enumerate() {
420 d[ai] = sub_d[j];
421 }
422 }
423 }
424
425 for level_data in levels.iter().rev() {
427 let (ref la, ref lb, ref _lc, ref ld) = *level_data;
428 let half_stride = stride / 2;
430 let mut i = half_stride.saturating_sub(1);
431 while i < n {
432 let is_solved = if stride > 0 {
434 (i + 1) % stride == 0
435 } else {
436 false
437 };
438
439 if !is_solved {
440 let bi = lb[i];
442 if bi.abs() < 1e-300 {
443 return Err(SolverError::SingularMatrix);
444 }
445 let left_val = if i >= half_stride {
446 d[i - half_stride]
447 } else {
448 0.0
449 };
450 let right_val = if i + half_stride < n {
451 d[i + half_stride]
452 } else {
453 0.0
454 };
455 d[i] = (ld[i] - la[i] * left_val - level_data.2[i] * right_val) / bi;
456 }
457 i += half_stride;
458 }
459 stride /= 2;
460 }
461
462 for i in 0..n {
464 rhs[i] = from_f64(d[i]);
465 }
466
467 Ok(())
468}
469
470fn validate_tridiagonal_dims<T: GpuFloat>(
476 lower: &[T],
477 diag: &[T],
478 upper: &[T],
479 rhs: &[T],
480 n: usize,
481) -> SolverResult<()> {
482 if diag.len() < n {
483 return Err(SolverError::DimensionMismatch(format!(
484 "tridiagonal_solve: diag length ({}) < n ({n})",
485 diag.len()
486 )));
487 }
488 if rhs.len() < n {
489 return Err(SolverError::DimensionMismatch(format!(
490 "tridiagonal_solve: rhs length ({}) < n ({n})",
491 rhs.len()
492 )));
493 }
494 if n > 1 {
495 if lower.len() < n - 1 {
496 return Err(SolverError::DimensionMismatch(format!(
497 "tridiagonal_solve: lower length ({}) < n-1 ({})",
498 lower.len(),
499 n - 1
500 )));
501 }
502 if upper.len() < n - 1 {
503 return Err(SolverError::DimensionMismatch(format!(
504 "tridiagonal_solve: upper length ({}) < n-1 ({})",
505 upper.len(),
506 n - 1
507 )));
508 }
509 }
510 Ok(())
511}
512
513fn solve_1x1<T: GpuFloat>(diag: &[T], rhs: &mut [T]) -> SolverResult<()> {
515 let d = to_f64(diag[0]);
516 if d.abs() < 1e-300 {
517 return Err(SolverError::SingularMatrix);
518 }
519 rhs[0] = from_f64(to_f64(rhs[0]) / d);
520 Ok(())
521}
522
#[cfg(test)]
mod tests {
    //! Host-side unit tests for the dimension checks, the individual solver
    //! kernels and the `f64` conversion helpers.

    use super::*;

    /// Returns `true` when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    // ---- validate_tridiagonal_dims ------------------------------------------

    /// Slices that exactly fit a 3x3 system are accepted.
    #[test]
    fn validate_dims_ok() {
        let sub = [1.0_f64; 2];
        let main = [2.0_f64; 3];
        let sup = [1.0_f64; 2];
        let b = [1.0_f64; 3];
        assert!(validate_tridiagonal_dims(&sub, &main, &sup, &b, 3).is_ok());
    }

    /// A main diagonal shorter than `n` is rejected.
    #[test]
    fn validate_dims_diag_too_short() {
        let sub = [1.0_f64; 2];
        let main = [2.0_f64; 2];
        let sup = [1.0_f64; 2];
        let b = [1.0_f64; 3];
        assert!(validate_tridiagonal_dims(&sub, &main, &sup, &b, 3).is_err());
    }

    /// A sub-diagonal shorter than `n - 1` is rejected.
    #[test]
    fn validate_dims_lower_too_short() {
        let sub = [1.0_f64; 1];
        let main = [2.0_f64; 3];
        let sup = [1.0_f64; 2];
        let b = [1.0_f64; 3];
        assert!(validate_tridiagonal_dims(&sub, &main, &sup, &b, 3).is_err());
    }

    // ---- thomas_solve ---------------------------------------------------------

    /// `[2 1; 1 3] x = [5, 7]` has the solution `x = [1.6, 1.8]`.
    #[test]
    fn thomas_solve_2x2() {
        let sub = [1.0_f64];
        let main = [2.0_f64, 3.0];
        let sup = [1.0_f64];
        let mut x = [5.0_f64, 7.0];

        assert!(thomas_solve(&sub, &main, &sup, &mut x, 2).is_ok());
        assert!(close(x[0], 1.6, 1e-10));
        assert!(close(x[1], 1.8, 1e-10));
    }

    /// A 3x3 system with diagonal 4 and off-diagonals 1 whose solution is all ones.
    #[test]
    fn thomas_solve_3x3() {
        let sub = [1.0_f64, 1.0];
        let main = [4.0_f64, 4.0, 4.0];
        let sup = [1.0_f64, 1.0];
        let mut x = [5.0_f64, 6.0, 5.0];

        assert!(thomas_solve(&sub, &main, &sup, &mut x, 3).is_ok());
        assert!(x.iter().all(|&xi| close(xi, 1.0, 1e-10)));
    }

    /// A zero leading pivot is reported as an error.
    #[test]
    fn thomas_solve_singular() {
        let sub = [1.0_f64];
        let main = [0.0_f64, 1.0];
        let sup = [1.0_f64];
        let mut x = [1.0_f64, 1.0];

        assert!(thomas_solve(&sub, &main, &sup, &mut x, 2).is_err());
    }

    // ---- cyclic_reduction_solve -----------------------------------------------

    /// Same 3x3 system as `thomas_solve_3x3`, solved by cyclic reduction.
    #[test]
    fn cyclic_reduction_3x3() {
        let sub = [1.0_f64, 1.0];
        let main = [4.0_f64, 4.0, 4.0];
        let sup = [1.0_f64, 1.0];
        let mut x = [5.0_f64, 6.0, 5.0];

        assert!(cyclic_reduction_solve(&sub, &main, &sup, &mut x, 3).is_ok());
        assert!(x.iter().all(|&xi| close(xi, 1.0, 1e-8)));
    }

    /// A 4x4 `(-1, 2, -1)` system with right-hand side `[1, 0, 0, 1]`; the
    /// solution is all ones.
    #[test]
    fn cyclic_reduction_4x4() {
        let sub = [-1.0_f64, -1.0, -1.0];
        let main = [2.0_f64, 2.0, 2.0, 2.0];
        let sup = [-1.0_f64, -1.0, -1.0];
        let mut x = [1.0_f64, 0.0, 0.0, 1.0];

        assert!(cyclic_reduction_solve(&sub, &main, &sup, &mut x, 4).is_ok());
        for (i, &val) in x.iter().enumerate() {
            assert!(close(val, 1.0, 1e-8), "x[{i}] = {val} (expected 1.0)");
        }
    }

    // ---- solve_1x1 ------------------------------------------------------------

    /// `5 x = 10` gives `x = 2`.
    #[test]
    fn solve_1x1_basic() {
        let main = [5.0_f64];
        let mut x = [10.0_f64];
        assert!(solve_1x1(&main, &mut x).is_ok());
        assert!(close(x[0], 2.0, 1e-15));
    }

    /// A zero coefficient is reported as an error.
    #[test]
    fn solve_1x1_zero_diag() {
        let main = [0.0_f64];
        let mut x = [10.0_f64];
        assert!(solve_1x1(&main, &mut x).is_err());
    }

    // ---- single precision and conversion helpers ------------------------------

    /// The 2x2 system from `thomas_solve_2x2`, with `f32` inputs.
    #[test]
    fn thomas_solve_f32() {
        let sub = [1.0_f32];
        let main = [2.0_f32, 3.0];
        let sup = [1.0_f32];
        let mut x = [5.0_f32, 7.0];

        assert!(thomas_solve(&sub, &main, &sup, &mut x, 2).is_ok());
        assert!((x[0] - 1.6_f32).abs() < 1e-5);
        assert!((x[1] - 1.8_f32).abs() < 1e-5);
    }

    /// An `f64` survives the `to_f64` / `from_f64` round trip.
    #[test]
    fn f64_roundtrip() {
        let original = std::f64::consts::PI;
        let restored = from_f64::<f64>(to_f64(original));
        assert!(close(restored, original, 1e-15));
    }

    /// An `f32` survives the `to_f64` / `from_f64` round trip.
    #[test]
    fn f32_roundtrip() {
        let original = 3.15_f32;
        let restored = from_f64::<f32>(to_f64(original));
        assert!((restored - original).abs() < 1e-6);
    }
}