// oxicuda_solver/dense/lu.rs
1//! LU Factorization with partial pivoting.
2//!
3//! Computes `P * A = L * U` where:
4//! - P is a permutation matrix (represented by pivot indices)
5//! - L is unit lower triangular
6//! - U is upper triangular
7//!
8//! Uses a blocked right-looking algorithm:
9//! 1. Panel factorization: factor a narrow column panel using a dedicated GPU kernel
10//! 2. Apply pivots: swap rows in the trailing portion
11//! 3. TRSM: solve for the upper triangle block
12//! 4. GEMM: update the trailing submatrix
13//!
14//! The L and U factors overwrite the input matrix A in-place (LAPACK-style packed
15//! storage with unit diagonal for L implicitly assumed).
16
17use oxicuda_blas::types::{
18    DiagType, FillMode, GpuFloat, Layout, MatrixDesc, MatrixDescMut, Side, Transpose,
19};
20use oxicuda_memory::DeviceBuffer;
21use oxicuda_ptx::prelude::*;
22
23use crate::error::{SolverError, SolverResult};
24use crate::handle::SolverHandle;
25use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
26
/// Block size (panel width, in columns) for the panel factorization step of
/// the blocked right-looking LU algorithm.
const LU_BLOCK_SIZE: u32 = 64;
29
30// ---------------------------------------------------------------------------
31// Result type
32// ---------------------------------------------------------------------------
33
/// Result of an LU factorization.
///
/// Carries the LAPACK-style `info` diagnostic reported by the factorization.
#[derive(Debug, Clone)]
pub struct LuResult {
    /// Status info:
    /// - 0: successful factorization
    /// - i > 0: U(i,i) is exactly zero, matrix is singular at column i
    pub info: i32,
}
44
45// ---------------------------------------------------------------------------
46// Public API
47// ---------------------------------------------------------------------------
48
49/// Performs LU factorization with partial pivoting in-place.
50///
51/// On exit, the lower triangle of `a` (with implicit unit diagonal) contains L,
52/// and the upper triangle contains U. The `pivots` array records the row
53/// permutations: row `i` was interchanged with row `pivots[i]`.
54///
55/// The matrix is stored in column-major order with leading dimension `lda`.
56///
57/// # Arguments
58///
59/// * `handle` — solver handle.
60/// * `a` — matrix buffer (n x n, column-major, lda stride), modified in-place.
61/// * `n` — matrix dimension.
62/// * `lda` — leading dimension (>= n).
63/// * `pivots` — output pivot indices buffer (length >= n).
64///
65/// # Returns
66///
67/// [`LuResult`] with `info == 0` on success, `info > 0` if singular.
68///
69/// # Errors
70///
71/// Returns [`SolverError`] if dimensions are invalid or a kernel launch fails.
72pub fn lu_factorize<T: GpuFloat>(
73    handle: &mut SolverHandle,
74    a: &mut DeviceBuffer<T>,
75    n: u32,
76    lda: u32,
77    pivots: &mut DeviceBuffer<i32>,
78) -> SolverResult<LuResult> {
79    // Validate dimensions.
80    if n == 0 {
81        return Ok(LuResult { info: 0 });
82    }
83    if lda < n {
84        return Err(SolverError::DimensionMismatch(format!(
85            "lu_factorize: lda ({lda}) must be >= n ({n})"
86        )));
87    }
88    let required = n as usize * lda as usize;
89    if a.len() < required {
90        return Err(SolverError::DimensionMismatch(format!(
91            "lu_factorize: buffer too small ({} < {required})",
92            a.len()
93        )));
94    }
95    if pivots.len() < n as usize {
96        return Err(SolverError::DimensionMismatch(format!(
97            "lu_factorize: pivots buffer too small ({} < {n})",
98            pivots.len()
99        )));
100    }
101
102    // Ensure workspace is large enough for panel temporaries.
103    let panel_workspace = n as usize * LU_BLOCK_SIZE as usize * T::SIZE;
104    handle.ensure_workspace(panel_workspace)?;
105
106    blocked_lu::<T>(handle, a, n, lda, pivots)
107}
108
109/// Solves `A * X = B` given an LU-factored matrix.
110///
111/// The LU factors must have been computed by [`lu_factorize`]. The solution
112/// overwrites `b` in-place.
113///
114/// # Arguments
115///
116/// * `handle` — solver handle.
117/// * `lu` — LU-factored matrix (output of `lu_factorize`).
118/// * `pivots` — pivot indices from `lu_factorize`.
119/// * `b` — right-hand side matrix (n x nrhs), overwritten with solution.
120/// * `n` — matrix dimension.
121/// * `nrhs` — number of right-hand side columns.
122///
123/// # Errors
124///
125/// Returns [`SolverError`] if dimensions are invalid or BLAS operations fail.
126pub fn lu_solve<T: GpuFloat>(
127    handle: &SolverHandle,
128    lu: &DeviceBuffer<T>,
129    pivots: &DeviceBuffer<i32>,
130    b: &mut DeviceBuffer<T>,
131    n: u32,
132    nrhs: u32,
133) -> SolverResult<()> {
134    if n == 0 || nrhs == 0 {
135        return Ok(());
136    }
137    if lu.len() < (n as usize * n as usize) {
138        return Err(SolverError::DimensionMismatch(
139            "lu_solve: LU buffer too small".into(),
140        ));
141    }
142    if pivots.len() < n as usize {
143        return Err(SolverError::DimensionMismatch(
144            "lu_solve: pivots buffer too small".into(),
145        ));
146    }
147    if b.len() < (n as usize * nrhs as usize) {
148        return Err(SolverError::DimensionMismatch(
149            "lu_solve: B buffer too small".into(),
150        ));
151    }
152
153    // Step 1: Apply row permutations to B.
154    // Each pivot[i] says row i was swapped with row pivot[i] during
155    // factorization, so we replay the swaps in forward order.
156    apply_pivots_to_rhs::<T>(handle, b, pivots, n, nrhs)?;
157
158    // Step 2: Solve L * Y = P * B (forward substitution) via TRSM.
159    let l_desc = MatrixDesc::<T>::from_raw(lu.as_device_ptr(), n, n, n, Layout::ColMajor);
160    let mut b_desc = MatrixDescMut::<T>::from_raw(b.as_device_ptr(), n, nrhs, n, Layout::ColMajor);
161
162    oxicuda_blas::level3::trsm(
163        handle.blas(),
164        Side::Left,
165        FillMode::Lower,
166        Transpose::NoTrans,
167        DiagType::Unit,
168        T::gpu_one(),
169        &l_desc,
170        &mut b_desc,
171    )?;
172
173    // Step 3: Solve U * X = Y (backward substitution) via TRSM.
174    let u_desc = MatrixDesc::<T>::from_raw(lu.as_device_ptr(), n, n, n, Layout::ColMajor);
175
176    oxicuda_blas::level3::trsm(
177        handle.blas(),
178        Side::Left,
179        FillMode::Upper,
180        Transpose::NoTrans,
181        DiagType::NonUnit,
182        T::gpu_one(),
183        &u_desc,
184        &mut b_desc,
185    )?;
186
187    Ok(())
188}
189
190// ---------------------------------------------------------------------------
191// Blocked LU implementation
192// ---------------------------------------------------------------------------
193
194/// Blocked right-looking LU factorization.
195///
196/// Processes the matrix in column panels of width `LU_BLOCK_SIZE`:
197/// 1. Factor the panel (find pivots, compute L column, compute U row).
198/// 2. Swap rows in the trailing matrix according to pivots.
199/// 3. TRSM: compute U block for the panel's upper triangle.
200/// 4. GEMM: update the trailing submatrix.
201fn blocked_lu<T: GpuFloat>(
202    handle: &mut SolverHandle,
203    a: &mut DeviceBuffer<T>,
204    n: u32,
205    lda: u32,
206    pivots: &mut DeviceBuffer<i32>,
207) -> SolverResult<LuResult> {
208    let nb = LU_BLOCK_SIZE.min(n);
209    let num_blocks = n.div_ceil(nb);
210    let mut info: i32 = 0;
211
212    for block_idx in 0..num_blocks {
213        let j = block_idx * nb;
214        let jb = nb.min(n - j); // Actual panel width (may be smaller for last block).
215
216        // Step 1: Panel factorization — factorize columns j..j+jb of the
217        // submatrix A[j:n, j:j+jb].
218        let panel_info = panel_lu::<T>(handle, a, n, lda, j, jb, pivots)?;
219        if panel_info > 0 && info == 0 {
220            info = panel_info + j as i32;
221        }
222
223        // Step 2: Apply pivots to columns outside the panel.
224        // Left side (columns 0..j): swap rows according to pivots.
225        if j > 0 {
226            apply_panel_pivots::<T>(handle, a, lda, j, jb, pivots, 0, j)?;
227        }
228        // Right side (columns j+jb..n): swap rows according to pivots.
229        let right_start = j + jb;
230        if right_start < n {
231            apply_panel_pivots::<T>(handle, a, lda, j, jb, pivots, right_start, n - right_start)?;
232        }
233
234        // Step 3: TRSM — solve L[j:j+jb, j:j+jb] * U[j:j+jb, j+jb:n] = A[j:j+jb, j+jb:n].
235        if right_start < n {
236            let l_desc = MatrixDesc::<T>::from_raw(
237                a.as_device_ptr() + (j as u64 + j as u64 * lda as u64) * T::SIZE as u64,
238                jb,
239                jb,
240                lda,
241                Layout::ColMajor,
242            );
243            let mut u_desc = MatrixDescMut::<T>::from_raw(
244                a.as_device_ptr() + (j as u64 + right_start as u64 * lda as u64) * T::SIZE as u64,
245                jb,
246                n - right_start,
247                lda,
248                Layout::ColMajor,
249            );
250            oxicuda_blas::level3::trsm(
251                handle.blas(),
252                Side::Left,
253                FillMode::Lower,
254                Transpose::NoTrans,
255                DiagType::Unit,
256                T::gpu_one(),
257                &l_desc,
258                &mut u_desc,
259            )?;
260        }
261
262        // Step 4: GEMM — update trailing matrix:
263        // A[j+jb:n, j+jb:n] -= A[j+jb:n, j:j+jb] * A[j:j+jb, j+jb:n]
264        let remaining_rows = n.saturating_sub(j + jb);
265        let remaining_cols = n.saturating_sub(j + jb);
266        if remaining_rows > 0 && remaining_cols > 0 {
267            let a21_desc = MatrixDesc::<T>::from_raw(
268                a.as_device_ptr() + ((j + jb) as u64 + j as u64 * lda as u64) * T::SIZE as u64,
269                remaining_rows,
270                jb,
271                lda,
272                Layout::ColMajor,
273            );
274            let a12_desc = MatrixDesc::<T>::from_raw(
275                a.as_device_ptr() + (j as u64 + (j + jb) as u64 * lda as u64) * T::SIZE as u64,
276                jb,
277                remaining_cols,
278                lda,
279                Layout::ColMajor,
280            );
281            let mut a22_desc = MatrixDescMut::<T>::from_raw(
282                a.as_device_ptr()
283                    + ((j + jb) as u64 + (j + jb) as u64 * lda as u64) * T::SIZE as u64,
284                remaining_rows,
285                remaining_cols,
286                lda,
287                Layout::ColMajor,
288            );
289
290            // Compute the negative one for alpha.
291            let neg_one = T::from_bits_u64({
292                let one = T::gpu_one();
293                // Negate by XORing the sign bit.
294                let bits = one.to_bits_u64();
295                if T::SIZE == 4 {
296                    bits ^ 0x8000_0000
297                } else {
298                    bits ^ 0x8000_0000_0000_0000
299                }
300            });
301
302            oxicuda_blas::level3::gemm_api::gemm(
303                handle.blas(),
304                Transpose::NoTrans,
305                Transpose::NoTrans,
306                neg_one,
307                &a21_desc,
308                &a12_desc,
309                T::gpu_one(),
310                &mut a22_desc,
311            )?;
312        }
313    }
314
315    Ok(LuResult { info })
316}
317
318/// Panel factorization: factorizes columns j..j+jb of A[j:n, j:j+jb].
319///
320/// This performs unblocked LU within the panel, finding pivots, scaling the
321/// column below the pivot, and updating the panel's trailing columns.
322///
323/// Returns the panel-local info (0 if success, >0 if singular at panel-local column).
324fn panel_lu<T: GpuFloat>(
325    _handle: &SolverHandle,
326    a: &mut DeviceBuffer<T>,
327    n: u32,
328    lda: u32,
329    j: u32,
330    jb: u32,
331    pivots: &mut DeviceBuffer<i32>,
332) -> SolverResult<i32> {
333    // Keep PTX generation path exercised while host fallback is active.
334    let _ = emit_panel_lu::<T>(_handle.sm_version(), jb)?;
335
336    let n_usize = n as usize;
337    let lda_usize = lda as usize;
338    let j_usize = j as usize;
339    let jb_usize = jb as usize;
340
341    let mut a_host = vec![T::gpu_zero(); a.len()];
342    a.copy_to_host(&mut a_host)?;
343
344    let mut piv_host = vec![0_i32; pivots.len()];
345    pivots.copy_to_host(&mut piv_host)?;
346
347    let mut info: i32 = 0;
348    let panel_end = (j_usize + jb_usize).min(n_usize);
349
350    for kk in 0..jb_usize {
351        let col = j_usize + kk;
352        if col >= n_usize {
353            break;
354        }
355
356        // Pivot search in column `col` over rows col..n-1.
357        let mut pivot_row = col;
358        let mut max_abs = 0.0_f64;
359        for row in col..n_usize {
360            let bits = a_host[col * lda_usize + row].to_bits_u64();
361            let val = if T::SIZE == 8 {
362                f64::from_bits(bits)
363            } else {
364                f64::from(f32::from_bits(bits as u32))
365            };
366            let abs = val.abs();
367            if abs > max_abs {
368                max_abs = abs;
369                pivot_row = row;
370            }
371        }
372
373        piv_host[col] = pivot_row as i32;
374
375        // Swap within panel columns; trailing columns are swapped later.
376        if pivot_row != col {
377            for c in j_usize..panel_end {
378                a_host.swap(c * lda_usize + col, c * lda_usize + pivot_row);
379            }
380        }
381
382        // Detect singular pivot in the panel (1-based panel-local info).
383        let pivot_bits = a_host[col * lda_usize + col].to_bits_u64();
384        let pivot_val = if T::SIZE == 8 {
385            f64::from_bits(pivot_bits)
386        } else {
387            f64::from(f32::from_bits(pivot_bits as u32))
388        };
389        if info == 0 && pivot_val.abs() <= 1e-30 {
390            info = (kk + 1) as i32;
391            continue;
392        }
393
394        // Scale below-diagonal entries in this panel column.
395        for row in (col + 1)..n_usize {
396            let x_bits = a_host[col * lda_usize + row].to_bits_u64();
397            let x = if T::SIZE == 8 {
398                f64::from_bits(x_bits)
399            } else {
400                f64::from(f32::from_bits(x_bits as u32))
401            };
402            let scaled = x / pivot_val;
403            a_host[col * lda_usize + row] = if T::SIZE == 8 {
404                T::from_bits_u64(scaled.to_bits())
405            } else {
406                T::from_bits_u64(u64::from((scaled as f32).to_bits()))
407            };
408        }
409
410        // Update trailing panel columns.
411        for c in (col + 1)..panel_end {
412            let uk_bits = a_host[c * lda_usize + col].to_bits_u64();
413            let u_kc = if T::SIZE == 8 {
414                f64::from_bits(uk_bits)
415            } else {
416                f64::from(f32::from_bits(uk_bits as u32))
417            };
418            for row in (col + 1)..n_usize {
419                let l_bits = a_host[col * lda_usize + row].to_bits_u64();
420                let l_rc = if T::SIZE == 8 {
421                    f64::from_bits(l_bits)
422                } else {
423                    f64::from(f32::from_bits(l_bits as u32))
424                };
425                let a_bits = a_host[c * lda_usize + row].to_bits_u64();
426                let a_rc = if T::SIZE == 8 {
427                    f64::from_bits(a_bits)
428                } else {
429                    f64::from(f32::from_bits(a_bits as u32))
430                };
431                let updated = a_rc - l_rc * u_kc;
432                a_host[c * lda_usize + row] = if T::SIZE == 8 {
433                    T::from_bits_u64(updated.to_bits())
434                } else {
435                    T::from_bits_u64(u64::from((updated as f32).to_bits()))
436                };
437            }
438        }
439    }
440
441    a.copy_from_host(&a_host)?;
442    pivots.copy_from_host(&piv_host)?;
443
444    Ok(info)
445}
446
447/// Applies pivot swaps from panel factorization to columns outside the panel.
448///
449/// For each pivot in `pivots[j..j+jb]`, swaps rows in the column range
450/// `[col_start..col_start+col_count]`.
451#[allow(clippy::too_many_arguments)]
452fn apply_panel_pivots<T: GpuFloat>(
453    _handle: &SolverHandle,
454    a: &mut DeviceBuffer<T>,
455    lda: u32,
456    j: u32,
457    jb: u32,
458    pivots: &DeviceBuffer<i32>,
459    col_start: u32,
460    col_count: u32,
461) -> SolverResult<()> {
462    if col_count == 0 || jb == 0 {
463        return Ok(());
464    }
465
466    // Keep PTX generation path exercised while host fallback is active.
467    let _ = emit_pivot_swap::<T>(_handle.sm_version())?;
468
469    let lda_usize = lda as usize;
470    let j_usize = j as usize;
471    let jb_usize = jb as usize;
472    let col_start_usize = col_start as usize;
473    let col_end = col_start_usize + col_count as usize;
474
475    let mut a_host = vec![T::gpu_zero(); a.len()];
476    a.copy_to_host(&mut a_host)?;
477    let mut piv_host = vec![0_i32; pivots.len()];
478    pivots.copy_to_host(&mut piv_host)?;
479
480    for t in 0..jb_usize {
481        let row = j_usize + t;
482        if row >= piv_host.len() {
483            break;
484        }
485        let piv = piv_host[row].max(0) as usize;
486        if piv >= lda_usize {
487            return Err(SolverError::DimensionMismatch(format!(
488                "apply_panel_pivots: pivot index out of range ({piv} >= lda {lda_usize})"
489            )));
490        }
491        if piv == row {
492            continue;
493        }
494        for col in col_start_usize..col_end {
495            a_host.swap(col * lda_usize + row, col * lda_usize + piv);
496        }
497    }
498
499    a.copy_from_host(&a_host)?;
500
501    Ok(())
502}
503
504/// Applies pivot permutations to the right-hand side B.
505fn apply_pivots_to_rhs<T: GpuFloat>(
506    _handle: &SolverHandle,
507    b: &mut DeviceBuffer<T>,
508    pivots: &DeviceBuffer<i32>,
509    n: u32,
510    nrhs: u32,
511) -> SolverResult<()> {
512    if n == 0 || nrhs == 0 {
513        return Ok(());
514    }
515
516    // Keep PTX generation path exercised while host fallback is active.
517    let _ = emit_pivot_swap::<T>(_handle.sm_version())?;
518
519    let n_usize = n as usize;
520    let nrhs_usize = nrhs as usize;
521
522    let mut b_host = vec![T::gpu_zero(); b.len()];
523    b.copy_to_host(&mut b_host)?;
524    let mut piv_host = vec![0_i32; pivots.len()];
525    pivots.copy_to_host(&mut piv_host)?;
526
527    // Apply all pivots across all RHS columns (column-major, lda = n).
528    for row in 0..n_usize {
529        if row >= piv_host.len() {
530            break;
531        }
532        let piv = piv_host[row].max(0) as usize;
533        if piv >= n_usize {
534            return Err(SolverError::DimensionMismatch(format!(
535                "apply_pivots_to_rhs: pivot index out of range ({piv} >= n {n_usize})"
536            )));
537        }
538        if piv == row {
539            continue;
540        }
541        for col in 0..nrhs_usize {
542            b_host.swap(col * n_usize + row, col * n_usize + piv);
543        }
544    }
545
546    b.copy_from_host(&b_host)?;
547
548    Ok(())
549}
550
551// ---------------------------------------------------------------------------
552// PTX kernel generation
553// ---------------------------------------------------------------------------
554
555fn panel_lu_name<T: GpuFloat>(block_size: u32) -> String {
556    format!("solver_panel_lu_{}_{}", T::NAME, block_size)
557}
558
559fn pivot_swap_name<T: GpuFloat>() -> String {
560    format!("solver_pivot_swap_{}", T::NAME)
561}
562
/// Emits PTX for a single-CTA panel LU factorization kernel.
///
/// The kernel factorizes a `panel_rows x panel_cols` submatrix in shared memory.
/// Each column is processed sequentially: find pivot (max abs), swap rows,
/// scale below-diagonal elements, and update trailing columns.
///
/// NOTE(review): the emitted body is currently a stub — it loads the kernel
/// parameters and returns without performing any work (see the discarded
/// tuple below). The host fallback in `panel_lu` does the actual computation.
fn emit_panel_lu<T: GpuFloat>(sm: SmVersion, panel_cols: u32) -> SolverResult<String> {
    // Kernel name encodes element type and panel width so each specialization
    // is cached/compiled separately.
    let name = panel_lu_name::<T>(panel_cols);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("panel_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("panel_rows", PtxType::U32)
        .param("panel_cols", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let panel_rows_reg = b.load_param_u32("panel_rows");
            let panel_cols_reg = b.load_param_u32("panel_cols");
            let lda_reg = b.load_param_u32("lda");
            let panel_ptr = b.load_param_u64("panel_ptr");

            // Each thread handles elements in the column below the diagonal.
            // This is a simplified single-CTA panel factorization.
            // For each column k = 0..panel_cols:
            //   1. Find pivot (thread 0 finds max abs in column k, rows k..panel_rows)
            //   2. Swap pivot row with row k
            //   3. Scale elements below diagonal: A[i,k] /= A[k,k] for i > k
            //   4. Update trailing: A[i,j] -= A[i,k] * A[k,j] for i > k, j > k

            // The kernel processes panel_cols columns sequentially.
            // Each column step uses all threads in the CTA cooperatively.
            // Discard the loaded registers: the factorization loop above is
            // not emitted yet, so the kernel is an immediate return.
            let _ = (
                tid,
                panel_rows_reg,
                panel_cols_reg,
                lda_reg,
                panel_ptr,
                float_ty,
            );

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
612
/// Emits PTX for a row-permutation kernel.
///
/// Each thread handles one column: for each pivot in `pivots[j..j+jb]`,
/// swaps rows in columns `col_start..col_start+col_count`.
///
/// NOTE(review): the emitted body is currently a stub — it computes the
/// per-thread column base address but performs no swaps (`_col_base` is
/// unused). The host fallbacks in `apply_panel_pivots` and
/// `apply_pivots_to_rhs` do the actual permutation.
fn emit_pivot_swap<T: GpuFloat>(sm: SmVersion) -> SolverResult<String> {
    let name = pivot_swap_name::<T>();
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("a_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("j", PtxType::U32)
        .param("jb", PtxType::U32)
        .param("col_start", PtxType::U32)
        .param("col_count", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let gid = b.global_thread_id_x();
            let col_count_reg = b.load_param_u32("col_count");

            // One thread per column; threads beyond col_count do nothing.
            b.if_lt_u32(gid.clone(), col_count_reg, |b| {
                let a_ptr = b.load_param_u64("a_ptr");
                let col_start = b.load_param_u32("col_start");
                let lda = b.load_param_u32("lda");

                // Compute the actual column index.
                let col_idx = b.add_u32(gid, col_start);

                // Column base address: a_ptr + col_idx * lda * sizeof(T)
                let col_elem_offset = b.mul_lo_u32(col_idx, lda);
                let _col_base = b.byte_offset_addr(a_ptr, col_elem_offset, T::size_u32());

                // In the full implementation, this would loop over pivots[j..j+jb]
                // and swap the corresponding rows.
                let _ = float_ty;
            });

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
658
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------------
    // CPU reference helpers for LU integration tests
    // ---------------------------------------------------------------------------

    /// Doolittle LU factorization (no pivoting) on a 4×4 f64 matrix.
    ///
    /// Returns (L, U) where L is unit lower triangular and U is upper triangular,
    /// such that A = L * U.
    ///
    /// Note: skips the L update when the diagonal of U is (near) zero, so the
    /// caller must supply a matrix with nonzero leading minors for A = L * U
    /// to hold exactly.
    fn doolittle_lu_4x4(a: &[[f64; 4]; 4]) -> ([[f64; 4]; 4], [[f64; 4]; 4]) {
        let mut l = [[0.0_f64; 4]; 4];
        let mut u = [[0.0_f64; 4]; 4];

        for i in 0..4 {
            l[i][i] = 1.0; // Unit diagonal for L.

            // U row i: u[i][j] = a[i][j] - sum_k<i l[i][k]*u[k][j].
            for j in i..4 {
                let sum: f64 = (0..i).map(|k| l[i][k] * u[k][j]).sum();
                u[i][j] = a[i][j] - sum;
            }

            // L column i (below diagonal): l[j][i] = (a[j][i] - sum) / u[i][i].
            for j in (i + 1)..4 {
                let sum: f64 = (0..i).map(|k| l[j][k] * u[k][i]).sum();
                if u[i][i].abs() > 1e-15 {
                    l[j][i] = (a[j][i] - sum) / u[i][i];
                }
            }
        }

        (l, u)
    }

    /// 4×4 matrix multiply (row-major).
    fn matmul_4x4(a: &[[f64; 4]; 4], b: &[[f64; 4]; 4]) -> [[f64; 4]; 4] {
        let mut c = [[0.0_f64; 4]; 4];
        for i in 0..4 {
            for j in 0..4 {
                for k in 0..4 {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        c
    }

    // ---------------------------------------------------------------------------
    // LU + GEMM/TRSM integration tests
    // ---------------------------------------------------------------------------

    #[test]
    fn lu_trsm_trailing_update() {
        // Verify Doolittle LU on a 4×4 matrix: A = L * U to tolerance 1e-10.
        // The matrix is diagonally dominant, so no pivoting is needed and all
        // of the reference factorization's divisions are well-defined.
        let a = [
            [4.0_f64, 3.0, 2.0, 1.0],
            [2.0, 5.0, 3.0, 2.0],
            [1.0, 2.0, 6.0, 3.0],
            [1.0, 1.0, 2.0, 7.0],
        ];
        let (l, u) = doolittle_lu_4x4(&a);

        // L must be unit lower triangular.
        for (i, l_row) in l.iter().enumerate() {
            assert!(
                (l_row[i] - 1.0).abs() < 1e-15,
                "L[{i},{i}] must be 1.0 (unit diagonal)"
            );
            for (j, &val) in l_row.iter().enumerate().filter(|(j, _)| *j > i) {
                assert!(
                    val.abs() < 1e-15,
                    "L[{i},{j}] = {val} must be 0.0 (upper triangle)",
                );
            }
        }

        // U must be upper triangular.
        for (i, u_row) in u.iter().enumerate() {
            for (j, &val) in u_row.iter().enumerate().filter(|(j, _)| *j < i) {
                assert!(
                    val.abs() < 1e-15,
                    "U[{i},{j}] = {val} must be 0.0 (lower triangle)",
                );
            }
        }

        // Reconstruct: L*U must equal A.
        let reconstructed = matmul_4x4(&l, &u);
        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    (reconstructed[i][j] - a[i][j]).abs() < 1e-10,
                    "LU[{i},{j}] = {} ≠ A[{i},{j}] = {} (diff = {})",
                    reconstructed[i][j],
                    a[i][j],
                    (reconstructed[i][j] - a[i][j]).abs()
                );
            }
        }
    }

    #[test]
    fn lu_gemm_rank_update_correctness() {
        // Verify that the GEMM trailing update for k=0 is correct on a 3×3 example.
        //
        // After the first column of LU (k=0):
        //   L[:,0] is computed, U[0,:] is computed.
        //   Trailing update: A[1:3, 1:3] -= L[1:3, 0:1] * U[0:1, 1:3]
        //
        // Use a = [[2, 4, 6], [1, 3, 5], [1, 2, 4]] (simple example).
        let a = [[2.0_f64, 4.0, 6.0], [1.0, 3.0, 5.0], [1.0, 2.0, 4.0]];

        // After first pivot (k=0), L column 0 = [1, a[1,0]/a[0,0], a[2,0]/a[0,0]]
        //                                      = [1, 0.5, 0.5]
        // U row 0 = a[0,:] = [2, 4, 6]
        // Trailing update for A[1:3, 1:3]:
        //   A[1,1] -= L[1,0]*U[0,1] = 3 - 0.5*4 = 1
        //   A[1,2] -= L[1,0]*U[0,2] = 5 - 0.5*6 = 2
        //   A[2,1] -= L[2,0]*U[0,1] = 2 - 0.5*4 = 0
        //   A[2,2] -= L[2,0]*U[0,2] = 4 - 0.5*6 = 1
        let l_col0 = [1.0_f64, a[1][0] / a[0][0], a[2][0] / a[0][0]];
        let u_row0 = [a[0][0], a[0][1], a[0][2]];

        // Trailing submatrix after k=0 update (rank-1 update, as in Step 4 of
        // blocked_lu but computed on the host).
        let mut trailing = [[0.0_f64; 2]; 2];
        for i in 0..2 {
            for j in 0..2 {
                trailing[i][j] = a[i + 1][j + 1] - l_col0[i + 1] * u_row0[j + 1];
            }
        }

        assert!(
            (trailing[0][0] - 1.0).abs() < 1e-12,
            "trailing[0,0] should be 1"
        );
        assert!(
            (trailing[0][1] - 2.0).abs() < 1e-12,
            "trailing[0,1] should be 2"
        );
        assert!(trailing[1][0].abs() < 1e-12, "trailing[1,0] should be 0");
        assert!(
            (trailing[1][1] - 1.0).abs() < 1e-12,
            "trailing[1,1] should be 1"
        );
    }

    #[test]
    fn lu_block_size_positive() {
        // Sanity bounds on the panel width constant.
        let block_size = LU_BLOCK_SIZE;
        assert!(block_size > 0);
        assert!(block_size <= 256);
    }

    #[test]
    fn lu_result_info() {
        // info == 0 means success; info > 0 flags the singular column.
        let result = LuResult { info: 0 };
        assert_eq!(result.info, 0);

        let singular = LuResult { info: 3 };
        assert!(singular.info > 0);
    }

    #[test]
    fn panel_lu_name_format() {
        // The kernel name must encode both the element type and block size.
        let name = panel_lu_name::<f32>(64);
        assert!(name.contains("f32"));
        assert!(name.contains("64"));
    }

    #[test]
    fn pivot_swap_name_format() {
        // The kernel name must encode the element type.
        let name = pivot_swap_name::<f64>();
        assert!(name.contains("f64"));
    }

    #[test]
    fn neg_one_f32() {
        // XORing the f32 sign bit of 1.0 must yield -1.0 (matches the
        // neg_one construction used for the GEMM alpha in blocked_lu).
        let neg = f32::from_bits_u64(f32::gpu_one().to_bits_u64() ^ 0x8000_0000);
        assert!((neg + 1.0).abs() < 1e-10);
    }

    #[test]
    fn neg_one_f64() {
        // Same check for the f64 sign-bit position.
        let neg = f64::from_bits_u64(f64::gpu_one().to_bits_u64() ^ 0x8000_0000_0000_0000);
        assert!((neg + 1.0).abs() < 1e-15);
    }
}