Skip to main content

oxicuda_solver/dense/
ldlt.rs

1//! Symmetric Indefinite Factorization (LDL^T / Bunch-Kaufman).
2//!
3//! Computes `P * A * P^T = L * D * L^T` where:
4//! - P is a permutation matrix (encoded in pivot_info)
5//! - L is unit lower triangular
6//! - D is block diagonal with 1x1 and 2x2 blocks
7//!
8//! # Algorithm — Bunch-Kaufman Pivoting
9//!
10//! At each step, the algorithm decides whether to use a 1x1 or 2x2 pivot
11//! by examining the magnitudes of diagonal and off-diagonal elements:
12//!
13//! 1. Let `alpha = (1 + sqrt(17)) / 8` (~0.6404).
14//! 2. Find `lambda = max |A[i, k]|` for i != k (largest off-diagonal in column k).
15//! 3. If `|A[k,k]| >= alpha * lambda`: use 1x1 pivot at position k.
16//! 4. Otherwise, find `sigma = max |A[i, r]|` for i != r, where r is the row
17//!    achieving lambda.
18//! 5. If `|A[k,k]| * sigma >= alpha * lambda^2`: use 1x1 pivot.
19//! 6. If `|A[r,r]| >= alpha * sigma`: use 1x1 pivot at position r (with swap).
20//! 7. Otherwise: use 2x2 pivot at positions (k, r).
21//!
//! This bounds element growth in the reduced (trailing) submatrices; the
//! entries of L itself are not guaranteed to stay bounded.
23
24#![allow(dead_code)]
25
26use oxicuda_blas::types::{FillMode, GpuFloat};
27use oxicuda_memory::DeviceBuffer;
28
29use crate::error::{SolverError, SolverResult};
30use crate::handle::SolverHandle;
31
32// ---------------------------------------------------------------------------
33// GpuFloat <-> f64 conversion helpers
34// ---------------------------------------------------------------------------
35
36fn to_f64<T: GpuFloat>(val: T) -> f64 {
37    if T::SIZE == 4 {
38        f32::from_bits(val.to_bits_u64() as u32) as f64
39    } else {
40        f64::from_bits(val.to_bits_u64())
41    }
42}
43
44fn from_f64<T: GpuFloat>(val: f64) -> T {
45    if T::SIZE == 4 {
46        T::from_bits_u64(u64::from((val as f32).to_bits()))
47    } else {
48        T::from_bits_u64(val.to_bits())
49    }
50}
51
52// ---------------------------------------------------------------------------
53// Constants
54// ---------------------------------------------------------------------------
55
/// Bunch-Kaufman pivot threshold `alpha = (1 + sqrt(17)) / 8`, used to decide
/// between 1x1 and 2x2 pivots.
const BUNCH_KAUFMAN_ALPHA: f64 = 0.640_388_203_202_207_6;
58
59// ---------------------------------------------------------------------------
60// Result type
61// ---------------------------------------------------------------------------
62
63/// LDL^T factorization result.
64///
65/// Contains pivot information describing the block diagonal structure of D:
66/// - `pivot_info[k] > 0`: 1x1 pivot block at position k; rows/cols were swapped
67///   with row/col `pivot_info[k] - 1`.
68/// - `pivot_info[k] < 0`: start of a 2x2 pivot block at positions (k, k+1);
69///   rows/cols were swapped with row/col `(-pivot_info[k]) - 1`.
70pub struct LdltResult {
71    /// Pivot block sizes: positive for 1x1, negative for start of 2x2.
72    pub pivot_info: DeviceBuffer<i32>,
73}
74
75impl std::fmt::Debug for LdltResult {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.debug_struct("LdltResult")
78            .field("pivot_info_len", &self.pivot_info.len())
79            .finish()
80    }
81}
82
83// ---------------------------------------------------------------------------
84// Public API
85// ---------------------------------------------------------------------------
86
87/// Computes the Bunch-Kaufman LDL^T factorization of a symmetric indefinite matrix.
88///
89/// On exit, the specified triangle of `a` is overwritten with L and D:
90/// - The unit lower triangular factor L (with unit diagonal implicit).
91/// - The block diagonal factor D (1x1 and 2x2 blocks on the diagonal/super-diagonal).
92///
93/// # Arguments
94///
95/// * `handle` — solver handle.
96/// * `a` — symmetric matrix (n x n, column-major), overwritten with L and D.
97/// * `n` — matrix dimension.
98/// * `uplo` — which triangle to read/write (Lower or Upper).
99///
100/// # Returns
101///
102/// An [`LdltResult`] containing the pivot information.
103///
104/// # Errors
105///
106/// Returns [`SolverError::DimensionMismatch`] for invalid dimensions.
107/// Returns [`SolverError::SingularMatrix`] if the matrix is singular.
108pub fn ldlt<T: GpuFloat>(
109    handle: &mut SolverHandle,
110    a: &mut DeviceBuffer<T>,
111    n: usize,
112    uplo: FillMode,
113) -> SolverResult<LdltResult> {
114    if n == 0 {
115        let pivot_info = DeviceBuffer::<i32>::zeroed(0)?;
116        return Ok(LdltResult { pivot_info });
117    }
118    if a.len() < n * n {
119        return Err(SolverError::DimensionMismatch(format!(
120            "ldlt: buffer too small ({} < {})",
121            a.len(),
122            n * n
123        )));
124    }
125    if uplo == FillMode::Full {
126        return Err(SolverError::DimensionMismatch(
127            "ldlt: uplo must be Upper or Lower, not Full".into(),
128        ));
129    }
130
131    // Workspace for the host-side factorization.
132    let ws = n * n * std::mem::size_of::<f64>();
133    handle.ensure_workspace(ws)?;
134
135    // Read the matrix into host memory for the factorization.
136    let mut a_host = vec![0.0_f64; n * n];
137    read_device_to_host(a, &mut a_host, n * n)?;
138
139    // Perform the Bunch-Kaufman factorization on the host.
140    let mut ipiv = vec![0_i32; n];
141    bunch_kaufman_factorize(&mut a_host, n, uplo, &mut ipiv)?;
142
143    // Write back to device.
144    let a_device: Vec<T> = a_host.iter().map(|&v| from_f64(v)).collect();
145    write_host_to_device(a, &a_device, n * n)?;
146
147    let mut pivot_info = DeviceBuffer::<i32>::zeroed(n)?;
148    write_host_to_device_i32(&mut pivot_info, &ipiv, n)?;
149
150    Ok(LdltResult { pivot_info })
151}
152
153/// Solves `A * x = b` using the LDL^T factorization.
154///
155/// The LDL^T factors must have been computed by [`ldlt`].
156///
157/// Algorithm:
158/// 1. Apply row permutations P to b.
159/// 2. Forward substitution: solve `L * y = P * b`.
160/// 3. Block diagonal solve: solve `D * z = y`.
161/// 4. Backward substitution: solve `L^T * w = z`.
162/// 5. Apply column permutations P^T to w to get x.
163///
164/// # Arguments
165///
166/// * `handle` — solver handle.
167/// * `a` — LDL^T-factored matrix (output of `ldlt`).
168/// * `pivot_info` — pivot information from `ldlt`.
169/// * `b` — right-hand side matrix (n x nrhs), overwritten with solution.
170/// * `n` — system dimension.
171/// * `nrhs` — number of right-hand side columns.
172/// * `uplo` — which triangle contains the factor.
173///
174/// # Errors
175///
176/// Returns [`SolverError::DimensionMismatch`] for invalid dimensions.
177pub fn ldlt_solve<T: GpuFloat>(
178    handle: &mut SolverHandle,
179    a: &DeviceBuffer<T>,
180    pivot_info: &DeviceBuffer<i32>,
181    b: &mut DeviceBuffer<T>,
182    n: usize,
183    nrhs: usize,
184    uplo: FillMode,
185) -> SolverResult<()> {
186    if n == 0 || nrhs == 0 {
187        return Ok(());
188    }
189    if a.len() < n * n {
190        return Err(SolverError::DimensionMismatch(
191            "ldlt_solve: factor buffer too small".into(),
192        ));
193    }
194    if pivot_info.len() < n {
195        return Err(SolverError::DimensionMismatch(
196            "ldlt_solve: pivot_info buffer too small".into(),
197        ));
198    }
199    if b.len() < n * nrhs {
200        return Err(SolverError::DimensionMismatch(
201            "ldlt_solve: B buffer too small".into(),
202        ));
203    }
204
205    // Workspace.
206    let ws = (n * n + n * nrhs) * std::mem::size_of::<f64>();
207    handle.ensure_workspace(ws)?;
208
209    // Read to host.
210    let mut a_host = vec![0.0_f64; n * n];
211    read_device_to_host(a, &mut a_host, n * n)?;
212
213    let mut ipiv = vec![0_i32; n];
214    read_device_to_host_i32(pivot_info, &mut ipiv, n)?;
215
216    let mut b_host = vec![0.0_f64; n * nrhs];
217    read_device_to_host(b, &mut b_host, n * nrhs)?;
218
219    // Solve on host.
220    bunch_kaufman_solve(&a_host, &ipiv, &mut b_host, n, nrhs, uplo)?;
221
222    // Write back.
223    let b_device: Vec<T> = b_host.iter().map(|&v| from_f64(v)).collect();
224    write_host_to_device(b, &b_device, n * nrhs)?;
225
226    Ok(())
227}
228
229// ---------------------------------------------------------------------------
230// Bunch-Kaufman factorization (host-side)
231// ---------------------------------------------------------------------------
232
233/// Bunch-Kaufman factorization on a host-side column-major matrix.
234fn bunch_kaufman_factorize(
235    a: &mut [f64],
236    n: usize,
237    uplo: FillMode,
238    ipiv: &mut [i32],
239) -> SolverResult<()> {
240    match uplo {
241        FillMode::Lower => bunch_kaufman_lower(a, n, ipiv),
242        FillMode::Upper => {
243            // Normalize to a lower-storage representation and reuse the same
244            // decomposition path to avoid divergent structural implementations.
245            mirror_upper_to_lower(a, n);
246            bunch_kaufman_lower(a, n, ipiv)
247        }
248        FillMode::Full => Err(SolverError::DimensionMismatch(
249            "ldlt: uplo must be Lower or Upper".into(),
250        )),
251    }
252}
253
254/// Lower-triangular Bunch-Kaufman: P*A*P^T = L*D*L^T.
255fn bunch_kaufman_lower(a: &mut [f64], n: usize, ipiv: &mut [i32]) -> SolverResult<()> {
256    let mut k = 0;
257
258    while k < n {
259        // Find the largest off-diagonal in column k (below diagonal).
260        let (lambda, r_idx) = column_max_offdiag(a, n, k, true);
261
262        let akk = a[k * n + k].abs();
263
264        if akk < 1e-300 && lambda < 1e-300 {
265            // Entire column is zero — matrix is singular at this point.
266            return Err(SolverError::SingularMatrix);
267        }
268
269        if akk >= BUNCH_KAUFMAN_ALPHA * lambda {
270            // Case 1: Use 1x1 pivot at position k.
271            perform_1x1_pivot_lower(a, n, k);
272            ipiv[k] = (k + 1) as i32; // 1-based, positive => 1x1
273            k += 1;
274        } else {
275            // Find sigma = max |A[i, r]| for i != r.
276            let (sigma, _) = column_max_offdiag(a, n, r_idx, true);
277
278            if akk * sigma >= BUNCH_KAUFMAN_ALPHA * lambda * lambda {
279                // Case 2: 1x1 pivot at k is fine.
280                perform_1x1_pivot_lower(a, n, k);
281                ipiv[k] = (k + 1) as i32;
282                k += 1;
283            } else if a[r_idx * n + r_idx].abs() >= BUNCH_KAUFMAN_ALPHA * sigma {
284                // Case 3: 1x1 pivot at r (swap k <-> r first).
285                if r_idx != k {
286                    swap_rows_and_cols(a, n, k, r_idx);
287                }
288                perform_1x1_pivot_lower(a, n, k);
289                ipiv[k] = (r_idx + 1) as i32;
290                k += 1;
291            } else {
292                // Case 4: 2x2 pivot at (k, k+1).
293                if k + 1 >= n {
294                    // Edge case: can't form 2x2 block at the last row.
295                    perform_1x1_pivot_lower(a, n, k);
296                    ipiv[k] = (k + 1) as i32;
297                    k += 1;
298                } else {
299                    if r_idx != k + 1 {
300                        swap_rows_and_cols(a, n, k + 1, r_idx);
301                    }
302                    perform_2x2_pivot_lower(a, n, k)?;
303                    ipiv[k] = -((r_idx + 1) as i32); // Negative => start of 2x2
304                    ipiv[k + 1] = ipiv[k];
305                    k += 2;
306                }
307            }
308        }
309    }
310
311    Ok(())
312}
313
314/// Upper-triangular Bunch-Kaufman: P*A*P^T = U^T*D*U.
315fn bunch_kaufman_upper(a: &mut [f64], n: usize, ipiv: &mut [i32]) -> SolverResult<()> {
316    if n == 0 {
317        return Ok(());
318    }
319
320    let mut k = n;
321
322    while k > 0 {
323        let col = k - 1;
324        let (lambda, r_idx) = column_max_offdiag(a, n, col, false);
325        let akk = a[col * n + col].abs();
326
327        if akk < 1e-300 && lambda < 1e-300 {
328            return Err(SolverError::SingularMatrix);
329        }
330
331        if akk >= BUNCH_KAUFMAN_ALPHA * lambda {
332            ipiv[col] = (col + 1) as i32;
333            k -= 1;
334        } else {
335            let (sigma, _) = column_max_offdiag(a, n, r_idx, false);
336
337            if akk * sigma >= BUNCH_KAUFMAN_ALPHA * lambda * lambda {
338                ipiv[col] = (col + 1) as i32;
339                k -= 1;
340            } else if a[r_idx * n + r_idx].abs() >= BUNCH_KAUFMAN_ALPHA * sigma {
341                if r_idx != col {
342                    swap_rows_and_cols(a, n, col, r_idx);
343                }
344                ipiv[col] = (r_idx + 1) as i32;
345                k -= 1;
346            } else {
347                if col == 0 {
348                    ipiv[col] = (col + 1) as i32;
349                    k -= 1;
350                } else {
351                    let col2 = col - 1;
352                    if r_idx != col2 {
353                        swap_rows_and_cols(a, n, col2, r_idx);
354                    }
355                    ipiv[col] = -((r_idx + 1) as i32);
356                    ipiv[col2] = ipiv[col];
357                    k -= 2;
358                }
359            }
360        }
361    }
362
363    Ok(())
364}
365
366// ---------------------------------------------------------------------------
367// Pivot operations
368// ---------------------------------------------------------------------------
369
/// Scans one column of a column-major `n x n` matrix for its largest
/// off-diagonal magnitude.
///
/// With `lower == true` the rows below the diagonal (`col + 1..n`) are
/// scanned, otherwise the rows above it (`0..col`).
///
/// Returns `(max_value, row_index)`. The first row attaining the maximum
/// wins; when no scanned entry is strictly greater than zero the result is
/// `(0.0, col)`.
fn column_max_offdiag(a: &[f64], n: usize, col: usize, lower: bool) -> (f64, usize) {
    let rows = if lower { (col + 1)..n } else { 0..col };
    let base = col * n;

    rows.fold((0.0_f64, col), |(best, best_row), row| {
        let magnitude = a[base + row].abs();
        if magnitude > best {
            (magnitude, row)
        } else {
            (best, best_row)
        }
    })
}
396
/// Exchanges rows `i` and `j` and then columns `i` and `j` of a column-major
/// `n x n` matrix, touching the full square storage (both triangles).
///
/// Does nothing when `i == j`.
fn swap_rows_and_cols(a: &mut [f64], n: usize, i: usize, j: usize) {
    if i != j {
        // Row exchange: visit the two entries of every column.
        for col_start in (0..n).map(|col| col * n) {
            a.swap(col_start + i, col_start + j);
        }

        // Column exchange: walk both columns in lockstep.
        let (col_i, col_j) = (i * n, j * n);
        for row in 0..n {
            a.swap(col_i + row, col_j + row);
        }
    }
}
411
/// Eliminates column `k` with a 1x1 pivot in a lower-stored, column-major
/// `n x n` matrix.
///
/// The entries below the diagonal of column `k` are divided by the pivot
/// (they become column `k` of L) and the lower triangle of the trailing
/// submatrix receives the rank-1 update `A[i, j] -= L[i, k] * D[k, k] * L[j, k]`.
/// The diagonal entry itself is left in place as `D[k, k]`.
///
/// When the pivot magnitude is below `1e-300` the function returns without
/// modifying anything.
fn perform_1x1_pivot_lower(a: &mut [f64], n: usize, k: usize) {
    let pivot_col = k * n;
    let pivot = a[pivot_col + k];
    if pivot.abs() < 1e-300 {
        // A (numerically) zero pivot cannot be divided by.
        return;
    }

    // Form the multipliers: L[i, k] = A[i, k] / A[k, k].
    let reciprocal = 1.0 / pivot;
    for row in (k + 1)..n {
        a[pivot_col + row] *= reciprocal;
    }

    // Rank-1 update of the trailing lower triangle, one column at a time.
    for j in (k + 1)..n {
        let l_jk = a[pivot_col + j];
        let target_col = j * n;
        for i in j..n {
            let l_ik = a[pivot_col + i];
            a[target_col + i] -= l_ik * pivot * l_jk;
        }
    }
}
434
435/// Performs a 2x2 pivot step at positions (k, k+1) for lower-triangular factorization.
436fn perform_2x2_pivot_lower(a: &mut [f64], n: usize, k: usize) -> SolverResult<()> {
437    if k + 1 >= n {
438        return Err(SolverError::InternalError(
439            "ldlt: 2x2 pivot at boundary".into(),
440        ));
441    }
442
443    // Extract 2x2 block D.
444    let d11 = a[k * n + k];
445    let d21 = a[k * n + (k + 1)];
446    let d22 = a[(k + 1) * n + (k + 1)];
447
448    // Invert the 2x2 block: D^{-1} = adj(D) / det(D).
449    let det = d11 * d22 - d21 * d21;
450    if det.abs() < 1e-300 {
451        return Err(SolverError::SingularMatrix);
452    }
453    let inv_det = 1.0 / det;
454
455    // Compute L columns below the 2x2 block.
456    // [L[i,k], L[i,k+1]] = [A[i,k], A[i,k+1]] * D^{-1}
457    for i in (k + 2)..n {
458        let aik = a[k * n + i];
459        let aik1 = a[(k + 1) * n + i];
460
461        a[k * n + i] = (d22 * aik - d21 * aik1) * inv_det;
462        a[(k + 1) * n + i] = (-d21 * aik + d11 * aik1) * inv_det;
463    }
464
465    // Update trailing submatrix.
466    for j in (k + 2)..n {
467        let ljk = a[k * n + j];
468        let ljk1 = a[(k + 1) * n + j];
469
470        for i in j..n {
471            let lik = a[k * n + i];
472            let lik1 = a[(k + 1) * n + i];
473
474            // A[i,j] -= L[i,k]*D[k,k]*L[j,k] + L[i,k]*D[k,k+1]*L[j,k+1]
475            //         + L[i,k+1]*D[k+1,k]*L[j,k] + L[i,k+1]*D[k+1,k+1]*L[j,k+1]
476            a[j * n + i] -=
477                lik * d11 * ljk + lik * d21 * ljk1 + lik1 * d21 * ljk + lik1 * d22 * ljk1;
478        }
479    }
480
481    Ok(())
482}
483
484// ---------------------------------------------------------------------------
485// Bunch-Kaufman solve
486// ---------------------------------------------------------------------------
487
488/// Solves the system using the LDL^T factorization.
489fn bunch_kaufman_solve(
490    a: &[f64],
491    ipiv: &[i32],
492    b: &mut [f64],
493    n: usize,
494    nrhs: usize,
495    uplo: FillMode,
496) -> SolverResult<()> {
497    match uplo {
498        FillMode::Lower => bunch_kaufman_solve_lower(a, ipiv, b, n, nrhs),
499        FillMode::Upper => bunch_kaufman_solve_lower(a, ipiv, b, n, nrhs),
500        FillMode::Full => Err(SolverError::DimensionMismatch(
501            "ldlt_solve: uplo must be Lower or Upper".into(),
502        )),
503    }
504}
505
506/// Lower-triangular solve: L*D*L^T * x = b.
507fn bunch_kaufman_solve_lower(
508    a: &[f64],
509    ipiv: &[i32],
510    b: &mut [f64],
511    n: usize,
512    nrhs: usize,
513) -> SolverResult<()> {
514    for rhs in 0..nrhs {
515        let b_col = &mut b[rhs * n..(rhs + 1) * n];
516
517        // Step 1: Apply permutations and forward substitution (L * y = P * b).
518        let mut k = 0;
519        while k < n {
520            if ipiv[k] > 0 {
521                // 1x1 pivot.
522                let p = (ipiv[k] - 1) as usize;
523                if p != k {
524                    b_col.swap(k, p);
525                }
526                // Forward sub: b[i] -= L[i,k] * b[k] for i > k.
527                for i in (k + 1)..n {
528                    b_col[i] -= a[k * n + i] * b_col[k];
529                }
530                k += 1;
531            } else {
532                // 2x2 pivot.
533                let p = ((-ipiv[k]) - 1) as usize;
534                if p != k + 1 {
535                    b_col.swap(k + 1, p);
536                }
537                for i in (k + 2)..n {
538                    b_col[i] -= a[k * n + i] * b_col[k];
539                    b_col[i] -= a[(k + 1) * n + i] * b_col[k + 1];
540                }
541                k += 2;
542            }
543        }
544
545        // Step 2: Solve D * z = y.
546        k = 0;
547        while k < n {
548            if ipiv[k] > 0 {
549                // 1x1 block.
550                let dkk = a[k * n + k];
551                if dkk.abs() < 1e-300 {
552                    return Err(SolverError::SingularMatrix);
553                }
554                b_col[k] /= dkk;
555                k += 1;
556            } else {
557                // 2x2 block.
558                if k + 1 >= n {
559                    return Err(SolverError::InternalError(
560                        "ldlt_solve: invalid 2x2 pivot at boundary".into(),
561                    ));
562                }
563                let d11 = a[k * n + k];
564                let d21 = a[k * n + (k + 1)];
565                let d22 = a[(k + 1) * n + (k + 1)];
566                let det = d11 * d22 - d21 * d21;
567                if det.abs() < 1e-300 {
568                    return Err(SolverError::SingularMatrix);
569                }
570                let inv_det = 1.0 / det;
571                let y1 = b_col[k];
572                let y2 = b_col[k + 1];
573                b_col[k] = (d22 * y1 - d21 * y2) * inv_det;
574                b_col[k + 1] = (-d21 * y1 + d11 * y2) * inv_det;
575                k += 2;
576            }
577        }
578
579        // Step 3: Backward substitution (L^T * w = z) and apply P^T.
580        k = n;
581        while k > 0 {
582            k -= 1;
583            if ipiv[k] > 0 {
584                // 1x1 pivot — backward sub.
585                for i in (k + 1)..n {
586                    b_col[k] -= a[k * n + i] * b_col[i];
587                }
588                let p = (ipiv[k] - 1) as usize;
589                if p != k {
590                    b_col.swap(k, p);
591                }
592            } else if k > 0 && ipiv[k] < 0 && ipiv[k - 1] == ipiv[k] {
593                // 2x2 pivot — process both rows.
594                let k2 = k - 1;
595                for i in (k + 1)..n {
596                    b_col[k] -= a[k * n + i] * b_col[i]; // Note: this is L^T
597                    b_col[k2] -= a[k2 * n + i] * b_col[i];
598                }
599                let p = ((-ipiv[k]) - 1) as usize;
600                if p != k {
601                    b_col.swap(k, p);
602                }
603                k = k2; // Skip past the 2x2 block.
604            }
605        }
606    }
607
608    Ok(())
609}
610
/// Copies the strict upper triangle of a column-major `n x n` matrix onto its
/// strict lower triangle, making the lower triangle a valid representation of
/// a symmetric matrix that was supplied in upper storage.
///
/// The diagonal and the upper triangle are left unchanged.
fn mirror_upper_to_lower(a: &mut [f64], n: usize) {
    for col in 0..n {
        for row in 0..col {
            // `a[col * n + row]` is element (row, col) above the diagonal;
            // `a[row * n + col]` is its mirror image (col, row) below it.
            a[row * n + col] = a[col * n + row];
        }
    }
}
618
619// ---------------------------------------------------------------------------
620// Device buffer read/write helpers (structural)
621// ---------------------------------------------------------------------------
622
623fn read_device_to_host<T: GpuFloat>(
624    buf: &DeviceBuffer<T>,
625    host: &mut [f64],
626    count: usize,
627) -> SolverResult<()> {
628    if host.len() < count {
629        return Err(SolverError::DimensionMismatch(format!(
630            "read_device_to_host: host buffer too small ({} < {})",
631            host.len(),
632            count
633        )));
634    }
635    let mut staged = vec![T::gpu_zero(); count];
636    buf.copy_to_host(&mut staged)?;
637    for (dst, src) in host.iter_mut().zip(staged.iter()) {
638        *dst = to_f64(*src);
639    }
640    Ok(())
641}
642
643fn write_host_to_device<T: GpuFloat>(
644    buf: &mut DeviceBuffer<T>,
645    data: &[T],
646    count: usize,
647) -> SolverResult<()> {
648    if data.len() < count {
649        return Err(SolverError::DimensionMismatch(format!(
650            "write_host_to_device: source buffer too small ({} < {})",
651            data.len(),
652            count
653        )));
654    }
655    buf.copy_from_host(&data[..count])?;
656    Ok(())
657}
658
659fn read_device_to_host_i32(
660    buf: &DeviceBuffer<i32>,
661    host: &mut [i32],
662    count: usize,
663) -> SolverResult<()> {
664    if host.len() < count {
665        return Err(SolverError::DimensionMismatch(format!(
666            "read_device_to_host_i32: host buffer too small ({} < {})",
667            host.len(),
668            count
669        )));
670    }
671    buf.copy_to_host(&mut host[..count])?;
672    Ok(())
673}
674
675fn write_host_to_device_i32(
676    buf: &mut DeviceBuffer<i32>,
677    data: &[i32],
678    count: usize,
679) -> SolverResult<()> {
680    if data.len() < count {
681        return Err(SolverError::DimensionMismatch(format!(
682            "write_host_to_device_i32: source buffer too small ({} < {})",
683            data.len(),
684            count
685        )));
686    }
687    buf.copy_from_host(&data[..count])?;
688    Ok(())
689}
690
691// ---------------------------------------------------------------------------
692// Tests
693// ---------------------------------------------------------------------------
694
#[cfg(test)]
mod tests {
    use super::*;

    /// The hard-coded constant must match the closed form (1 + sqrt(17)) / 8.
    #[test]
    fn bunch_kaufman_alpha_value() {
        let expected = (1.0_f64 + 17.0_f64.sqrt()) / 8.0;
        assert!((BUNCH_KAUFMAN_ALPHA - expected).abs() < 1e-10);
    }

    #[test]
    fn column_max_offdiag_lower() {
        // 3x3 matrix (column-major):
        // [1  0  0]
        // [5  2  0]
        // [3  7  4]
        let a = [1.0, 5.0, 3.0, 0.0, 2.0, 7.0, 0.0, 0.0, 4.0];
        let (max_val, max_idx) = column_max_offdiag(&a, 3, 0, true);
        assert!((max_val - 5.0).abs() < 1e-15);
        assert_eq!(max_idx, 1);
    }

    #[test]
    fn column_max_offdiag_upper() {
        // All entries above the diagonal of column 2 are zero: the index
        // stays at the column itself.
        let a = [1.0, 5.0, 3.0, 0.0, 2.0, 7.0, 0.0, 0.0, 4.0];
        let (max_val, max_idx) = column_max_offdiag(&a, 3, 2, false);
        assert!(max_val.abs() < 1e-15);
        assert_eq!(max_idx, 2);

        // Non-trivial upper triangle (column-major):
        // [1  5  3]
        // [0  2 -7]
        // [0  0  4]
        let b = [1.0, 0.0, 0.0, 5.0, 2.0, 0.0, 3.0, -7.0, 4.0];
        let (max_val, max_idx) = column_max_offdiag(&b, 3, 2, false);
        assert!((max_val - 7.0).abs() < 1e-15);
        assert_eq!(max_idx, 1);
    }

    #[test]
    fn swap_rows_and_cols_identity() {
        // Swapping an index with itself must leave every entry untouched.
        let mut a = [1.0, 2.0, 3.0, 4.0];
        swap_rows_and_cols(&mut a, 2, 0, 0);
        assert_eq!(a, [1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn swap_rows_and_cols_basic() {
        // A matrix with nine distinct entries so that any misplaced element
        // is detected (the identity would be invariant under the swap).
        let mut a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        swap_rows_and_cols(&mut a, 3, 0, 2);
        assert_eq!(a, [9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn perform_1x1_pivot_lower_basic() {
        // 2x2 matrix: [[4, 2], [2, 3]] (column-major: [4, 2, 2, 3])
        let mut a = [4.0, 2.0, 2.0, 3.0];
        perform_1x1_pivot_lower(&mut a, 2, 0);
        // L[1,0] = A[1,0] / A[0,0] = 2/4 = 0.5
        assert!((a[1] - 0.5).abs() < 1e-15);
        // A[1,1] -= L[1,0] * D[0,0] * L[1,0] = 3 - 0.5*4*0.5 = 3 - 1 = 2
        assert!((a[3] - 2.0).abs() < 1e-15);
    }

    #[test]
    fn bunch_kaufman_identity_3x3() {
        // Identity matrix should trivially factorize with 1x1 pivots and no
        // interchanges.
        let mut a = vec![0.0; 9];
        a[0] = 1.0;
        a[4] = 1.0;
        a[8] = 1.0;
        let mut ipiv = vec![0_i32; 3];
        let result = bunch_kaufman_lower(&mut a, 3, &mut ipiv);
        assert!(result.is_ok());
        assert_eq!(ipiv, vec![1, 2, 3]);
    }

    #[test]
    fn f64_conversion_roundtrip() {
        let val = std::f64::consts::E;
        let converted: f64 = from_f64(to_f64(val));
        assert!((converted - val).abs() < 1e-15);
    }

    #[test]
    fn f32_conversion_roundtrip() {
        let val = std::f32::consts::E;
        let as_f64 = to_f64(val);
        let back: f32 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-5);
    }
}
787}