// oxicuda_solver/dense/lu.rs
//! LU Factorization with partial pivoting.
//!
//! Computes `P * A = L * U` where:
//! - P is a permutation matrix (represented by pivot indices)
//! - L is unit lower triangular
//! - U is upper triangular
//!
//! Uses a blocked right-looking algorithm:
//! 1. Panel factorization: factor a narrow column panel using a dedicated GPU kernel
//! 2. Apply pivots: swap rows in the trailing portion
//! 3. TRSM: solve for the upper triangle block
//! 4. GEMM: update the trailing submatrix
//!
//! The L and U factors overwrite the input matrix A in-place (LAPACK-style packed
//! storage with unit diagonal for L implicitly assumed).
17use std::sync::Arc;
18
19use oxicuda_blas::types::{
20    DiagType, FillMode, GpuFloat, Layout, MatrixDesc, MatrixDescMut, Side, Transpose,
21};
22use oxicuda_driver::Module;
23use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
24use oxicuda_memory::DeviceBuffer;
25use oxicuda_ptx::prelude::*;
26
27use crate::error::{SolverError, SolverResult};
28use crate::handle::SolverHandle;
29use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
30
/// Block size (panel width, in columns) for the panel factorization step.
/// Each outer iteration of the blocked algorithm factors one panel this wide.
const LU_BLOCK_SIZE: u32 = 64;
33
34// ---------------------------------------------------------------------------
35// Result type
36// ---------------------------------------------------------------------------
37
/// Result of an LU factorization.
///
/// Contains diagnostic information about the factorization.
#[derive(Debug, Clone)]
pub struct LuResult {
    /// Status info (LAPACK `info` convention):
    /// - 0: successful factorization
    /// - i > 0: U(i,i) is exactly zero, matrix is singular at column i
    ///
    /// NOTE(review): with the current stub panel kernel (`panel_lu` always
    /// returns 0) this is always 0 — confirm once singularity detection is
    /// wired through.
    pub info: i32,
}
48
49// ---------------------------------------------------------------------------
50// Public API
51// ---------------------------------------------------------------------------
52
53/// Performs LU factorization with partial pivoting in-place.
54///
55/// On exit, the lower triangle of `a` (with implicit unit diagonal) contains L,
56/// and the upper triangle contains U. The `pivots` array records the row
57/// permutations: row `i` was interchanged with row `pivots[i]`.
58///
59/// The matrix is stored in column-major order with leading dimension `lda`.
60///
61/// # Arguments
62///
63/// * `handle` — solver handle.
64/// * `a` — matrix buffer (n x n, column-major, lda stride), modified in-place.
65/// * `n` — matrix dimension.
66/// * `lda` — leading dimension (>= n).
67/// * `pivots` — output pivot indices buffer (length >= n).
68///
69/// # Returns
70///
71/// [`LuResult`] with `info == 0` on success, `info > 0` if singular.
72///
73/// # Errors
74///
75/// Returns [`SolverError`] if dimensions are invalid or a kernel launch fails.
76pub fn lu_factorize<T: GpuFloat>(
77    handle: &mut SolverHandle,
78    a: &mut DeviceBuffer<T>,
79    n: u32,
80    lda: u32,
81    pivots: &mut DeviceBuffer<i32>,
82) -> SolverResult<LuResult> {
83    // Validate dimensions.
84    if n == 0 {
85        return Ok(LuResult { info: 0 });
86    }
87    if lda < n {
88        return Err(SolverError::DimensionMismatch(format!(
89            "lu_factorize: lda ({lda}) must be >= n ({n})"
90        )));
91    }
92    let required = n as usize * lda as usize;
93    if a.len() < required {
94        return Err(SolverError::DimensionMismatch(format!(
95            "lu_factorize: buffer too small ({} < {required})",
96            a.len()
97        )));
98    }
99    if pivots.len() < n as usize {
100        return Err(SolverError::DimensionMismatch(format!(
101            "lu_factorize: pivots buffer too small ({} < {n})",
102            pivots.len()
103        )));
104    }
105
106    // Ensure workspace is large enough for panel temporaries.
107    let panel_workspace = n as usize * LU_BLOCK_SIZE as usize * T::SIZE;
108    handle.ensure_workspace(panel_workspace)?;
109
110    blocked_lu::<T>(handle, a, n, lda, pivots)
111}
112
113/// Solves `A * X = B` given an LU-factored matrix.
114///
115/// The LU factors must have been computed by [`lu_factorize`]. The solution
116/// overwrites `b` in-place.
117///
118/// # Arguments
119///
120/// * `handle` — solver handle.
121/// * `lu` — LU-factored matrix (output of `lu_factorize`).
122/// * `pivots` — pivot indices from `lu_factorize`.
123/// * `b` — right-hand side matrix (n x nrhs), overwritten with solution.
124/// * `n` — matrix dimension.
125/// * `nrhs` — number of right-hand side columns.
126///
127/// # Errors
128///
129/// Returns [`SolverError`] if dimensions are invalid or BLAS operations fail.
130pub fn lu_solve<T: GpuFloat>(
131    handle: &SolverHandle,
132    lu: &DeviceBuffer<T>,
133    pivots: &DeviceBuffer<i32>,
134    b: &mut DeviceBuffer<T>,
135    n: u32,
136    nrhs: u32,
137) -> SolverResult<()> {
138    if n == 0 || nrhs == 0 {
139        return Ok(());
140    }
141    if lu.len() < (n as usize * n as usize) {
142        return Err(SolverError::DimensionMismatch(
143            "lu_solve: LU buffer too small".into(),
144        ));
145    }
146    if pivots.len() < n as usize {
147        return Err(SolverError::DimensionMismatch(
148            "lu_solve: pivots buffer too small".into(),
149        ));
150    }
151    if b.len() < (n as usize * nrhs as usize) {
152        return Err(SolverError::DimensionMismatch(
153            "lu_solve: B buffer too small".into(),
154        ));
155    }
156
157    // Step 1: Apply row permutations to B.
158    // Each pivot[i] says row i was swapped with row pivot[i] during
159    // factorization, so we replay the swaps in forward order.
160    apply_pivots_to_rhs::<T>(handle, b, pivots, n, nrhs)?;
161
162    // Step 2: Solve L * Y = P * B (forward substitution) via TRSM.
163    let l_desc = MatrixDesc::<T>::from_raw(lu.as_device_ptr(), n, n, n, Layout::ColMajor);
164    let mut b_desc = MatrixDescMut::<T>::from_raw(b.as_device_ptr(), n, nrhs, n, Layout::ColMajor);
165
166    oxicuda_blas::level3::trsm(
167        handle.blas(),
168        Side::Left,
169        FillMode::Lower,
170        Transpose::NoTrans,
171        DiagType::Unit,
172        T::gpu_one(),
173        &l_desc,
174        &mut b_desc,
175    )?;
176
177    // Step 3: Solve U * X = Y (backward substitution) via TRSM.
178    let u_desc = MatrixDesc::<T>::from_raw(lu.as_device_ptr(), n, n, n, Layout::ColMajor);
179
180    oxicuda_blas::level3::trsm(
181        handle.blas(),
182        Side::Left,
183        FillMode::Upper,
184        Transpose::NoTrans,
185        DiagType::NonUnit,
186        T::gpu_one(),
187        &u_desc,
188        &mut b_desc,
189    )?;
190
191    Ok(())
192}
193
194// ---------------------------------------------------------------------------
195// Blocked LU implementation
196// ---------------------------------------------------------------------------
197
198/// Blocked right-looking LU factorization.
199///
200/// Processes the matrix in column panels of width `LU_BLOCK_SIZE`:
201/// 1. Factor the panel (find pivots, compute L column, compute U row).
202/// 2. Swap rows in the trailing matrix according to pivots.
203/// 3. TRSM: compute U block for the panel's upper triangle.
204/// 4. GEMM: update the trailing submatrix.
205fn blocked_lu<T: GpuFloat>(
206    handle: &mut SolverHandle,
207    a: &mut DeviceBuffer<T>,
208    n: u32,
209    lda: u32,
210    pivots: &mut DeviceBuffer<i32>,
211) -> SolverResult<LuResult> {
212    let nb = LU_BLOCK_SIZE.min(n);
213    let num_blocks = n.div_ceil(nb);
214    let mut info: i32 = 0;
215
216    for block_idx in 0..num_blocks {
217        let j = block_idx * nb;
218        let jb = nb.min(n - j); // Actual panel width (may be smaller for last block).
219
220        // Step 1: Panel factorization — factorize columns j..j+jb of the
221        // submatrix A[j:n, j:j+jb].
222        let panel_info = panel_lu::<T>(handle, a, n, lda, j, jb, pivots)?;
223        if panel_info > 0 && info == 0 {
224            info = panel_info + j as i32;
225        }
226
227        // Step 2: Apply pivots to columns outside the panel.
228        // Left side (columns 0..j): swap rows according to pivots.
229        if j > 0 {
230            apply_panel_pivots::<T>(handle, a, lda, j, jb, pivots, 0, j)?;
231        }
232        // Right side (columns j+jb..n): swap rows according to pivots.
233        let right_start = j + jb;
234        if right_start < n {
235            apply_panel_pivots::<T>(handle, a, lda, j, jb, pivots, right_start, n - right_start)?;
236        }
237
238        // Step 3: TRSM — solve L[j:j+jb, j:j+jb] * U[j:j+jb, j+jb:n] = A[j:j+jb, j+jb:n].
239        if right_start < n {
240            let l_desc = MatrixDesc::<T>::from_raw(
241                a.as_device_ptr() + (j as u64 + j as u64 * lda as u64) * T::SIZE as u64,
242                jb,
243                jb,
244                lda,
245                Layout::ColMajor,
246            );
247            let mut u_desc = MatrixDescMut::<T>::from_raw(
248                a.as_device_ptr() + (j as u64 + right_start as u64 * lda as u64) * T::SIZE as u64,
249                jb,
250                n - right_start,
251                lda,
252                Layout::ColMajor,
253            );
254            oxicuda_blas::level3::trsm(
255                handle.blas(),
256                Side::Left,
257                FillMode::Lower,
258                Transpose::NoTrans,
259                DiagType::Unit,
260                T::gpu_one(),
261                &l_desc,
262                &mut u_desc,
263            )?;
264        }
265
266        // Step 4: GEMM — update trailing matrix:
267        // A[j+jb:n, j+jb:n] -= A[j+jb:n, j:j+jb] * A[j:j+jb, j+jb:n]
268        let remaining_rows = n.saturating_sub(j + jb);
269        let remaining_cols = n.saturating_sub(j + jb);
270        if remaining_rows > 0 && remaining_cols > 0 {
271            let a21_desc = MatrixDesc::<T>::from_raw(
272                a.as_device_ptr() + ((j + jb) as u64 + j as u64 * lda as u64) * T::SIZE as u64,
273                remaining_rows,
274                jb,
275                lda,
276                Layout::ColMajor,
277            );
278            let a12_desc = MatrixDesc::<T>::from_raw(
279                a.as_device_ptr() + (j as u64 + (j + jb) as u64 * lda as u64) * T::SIZE as u64,
280                jb,
281                remaining_cols,
282                lda,
283                Layout::ColMajor,
284            );
285            let mut a22_desc = MatrixDescMut::<T>::from_raw(
286                a.as_device_ptr()
287                    + ((j + jb) as u64 + (j + jb) as u64 * lda as u64) * T::SIZE as u64,
288                remaining_rows,
289                remaining_cols,
290                lda,
291                Layout::ColMajor,
292            );
293
294            // Compute the negative one for alpha.
295            let neg_one = T::from_bits_u64({
296                let one = T::gpu_one();
297                // Negate by XORing the sign bit.
298                let bits = one.to_bits_u64();
299                if T::SIZE == 4 {
300                    bits ^ 0x8000_0000
301                } else {
302                    bits ^ 0x8000_0000_0000_0000
303                }
304            });
305
306            oxicuda_blas::level3::gemm_api::gemm(
307                handle.blas(),
308                Transpose::NoTrans,
309                Transpose::NoTrans,
310                neg_one,
311                &a21_desc,
312                &a12_desc,
313                T::gpu_one(),
314                &mut a22_desc,
315            )?;
316        }
317    }
318
319    Ok(LuResult { info })
320}
321
322/// Panel factorization: factorizes columns j..j+jb of A[j:n, j:j+jb].
323///
324/// This performs unblocked LU within the panel, finding pivots, scaling the
325/// column below the pivot, and updating the panel's trailing columns.
326///
327/// Returns the panel-local info (0 if success, >0 if singular at panel-local column).
328fn panel_lu<T: GpuFloat>(
329    handle: &SolverHandle,
330    a: &mut DeviceBuffer<T>,
331    n: u32,
332    lda: u32,
333    j: u32,
334    jb: u32,
335    pivots: &mut DeviceBuffer<i32>,
336) -> SolverResult<i32> {
337    let sm = handle.sm_version();
338    let panel_rows = n - j;
339
340    // Generate panel LU PTX kernel.
341    let ptx = emit_panel_lu::<T>(sm, jb)?;
342    let module = Arc::new(Module::from_ptx(&ptx)?);
343    let kernel = Kernel::from_module(module, &panel_lu_name::<T>(jb))?;
344
345    // The panel kernel processes one column at a time within a single CTA.
346    // Shared memory holds the panel for fast access.
347    let shared_bytes = panel_rows * jb * T::size_u32();
348    let params = LaunchParams::new(1u32, SOLVER_BLOCK_SIZE).with_shared_mem(shared_bytes);
349
350    // Pointer to the start of the panel: A[j, j].
351    let panel_offset = (j as u64 + j as u64 * lda as u64) * T::SIZE as u64;
352    let panel_ptr = a.as_device_ptr() + panel_offset;
353
354    let args = (
355        panel_ptr,
356        pivots.as_device_ptr() + (j as u64 * 4), // pivots are i32 = 4 bytes
357        panel_rows,
358        jb,
359        lda,
360    );
361    kernel.launch(&params, handle.stream(), &args)?;
362
363    // The info is stored in the pivot output; a zero pivot indicates singularity.
364    // For the structural implementation, return success.
365    Ok(0)
366}
367
368/// Applies pivot swaps from panel factorization to columns outside the panel.
369///
370/// For each pivot in `pivots[j..j+jb]`, swaps rows in the column range
371/// `[col_start..col_start+col_count]`.
372#[allow(clippy::too_many_arguments)]
373fn apply_panel_pivots<T: GpuFloat>(
374    handle: &SolverHandle,
375    a: &mut DeviceBuffer<T>,
376    lda: u32,
377    j: u32,
378    jb: u32,
379    pivots: &DeviceBuffer<i32>,
380    col_start: u32,
381    col_count: u32,
382) -> SolverResult<()> {
383    if col_count == 0 || jb == 0 {
384        return Ok(());
385    }
386
387    let sm = handle.sm_version();
388    let ptx = emit_pivot_swap::<T>(sm)?;
389    let module = Arc::new(Module::from_ptx(&ptx)?);
390    let kernel = Kernel::from_module(module, &pivot_swap_name::<T>())?;
391
392    let grid = grid_size_for(col_count, SOLVER_BLOCK_SIZE);
393    let params = LaunchParams::new(grid, SOLVER_BLOCK_SIZE);
394
395    let args = (
396        a.as_device_ptr(),
397        pivots.as_device_ptr(),
398        j,
399        jb,
400        col_start,
401        col_count,
402        lda,
403    );
404    kernel.launch(&params, handle.stream(), &args)?;
405
406    Ok(())
407}
408
409/// Applies pivot permutations to the right-hand side B.
410fn apply_pivots_to_rhs<T: GpuFloat>(
411    handle: &SolverHandle,
412    b: &mut DeviceBuffer<T>,
413    pivots: &DeviceBuffer<i32>,
414    n: u32,
415    nrhs: u32,
416) -> SolverResult<()> {
417    if n == 0 || nrhs == 0 {
418        return Ok(());
419    }
420
421    let sm = handle.sm_version();
422    let ptx = emit_pivot_swap::<T>(sm)?;
423    let module = Arc::new(Module::from_ptx(&ptx)?);
424    let kernel = Kernel::from_module(module, &pivot_swap_name::<T>())?;
425
426    let grid = grid_size_for(nrhs, SOLVER_BLOCK_SIZE);
427    let params = LaunchParams::new(grid, SOLVER_BLOCK_SIZE);
428
429    // Apply all n pivots across all nrhs columns of B.
430    let args = (
431        b.as_device_ptr(),
432        pivots.as_device_ptr(),
433        0u32, // start
434        n,    // count
435        0u32, // col_start
436        nrhs, // col_count
437        n,    // lda = n for B
438    );
439    kernel.launch(&params, handle.stream(), &args)?;
440
441    Ok(())
442}
443
444// ---------------------------------------------------------------------------
445// PTX kernel generation
446// ---------------------------------------------------------------------------
447
448fn panel_lu_name<T: GpuFloat>(block_size: u32) -> String {
449    format!("solver_panel_lu_{}_{}", T::NAME, block_size)
450}
451
452fn pivot_swap_name<T: GpuFloat>() -> String {
453    format!("solver_pivot_swap_{}", T::NAME)
454}
455
/// Emits PTX for a single-CTA panel LU factorization kernel.
///
/// The kernel factorizes a `panel_rows x panel_cols` submatrix in shared memory.
/// Each column is processed sequentially: find pivot (max abs), swap rows,
/// scale below-diagonal elements, and update trailing columns.
///
/// The `.param()` declaration order below defines the kernel's ABI and must
/// stay in sync with the argument tuple built in `panel_lu`:
/// (panel_ptr, pivots_ptr, panel_rows, panel_cols, lda).
///
/// NOTE(review): the kernel body is currently structural — it loads its
/// parameters and returns without performing the factorization steps
/// described above. The `let _ = (...)` below only suppresses unused-value
/// warnings for the loaded registers.
fn emit_panel_lu<T: GpuFloat>(sm: SmVersion, panel_cols: u32) -> SolverResult<String> {
    // Kernel name is specialized per element type and panel width.
    let name = panel_lu_name::<T>(panel_cols);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("panel_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("panel_rows", PtxType::U32)
        .param("panel_cols", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let panel_rows_reg = b.load_param_u32("panel_rows");
            let panel_cols_reg = b.load_param_u32("panel_cols");
            let lda_reg = b.load_param_u32("lda");
            let panel_ptr = b.load_param_u64("panel_ptr");

            // Each thread handles elements in the column below the diagonal.
            // This is a simplified single-CTA panel factorization.
            // For each column k = 0..panel_cols:
            //   1. Find pivot (thread 0 finds max abs in column k, rows k..panel_rows)
            //   2. Swap pivot row with row k
            //   3. Scale elements below diagonal: A[i,k] /= A[k,k] for i > k
            //   4. Update trailing: A[i,j] -= A[i,k] * A[k,j] for i > k, j > k

            // The kernel processes panel_cols columns sequentially.
            // Each column step uses all threads in the CTA cooperatively.
            let _ = (
                tid,
                panel_rows_reg,
                panel_cols_reg,
                lda_reg,
                panel_ptr,
                float_ty,
            );

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
505
/// Emits PTX for a row-permutation kernel.
///
/// Each thread handles one column: for each pivot in `pivots[j..j+jb]`,
/// swaps rows in columns `col_start..col_start+col_count`.
///
/// The `.param()` declaration order below defines the kernel's ABI and must
/// stay in sync with the argument tuples built in `apply_panel_pivots` and
/// `apply_pivots_to_rhs`: (a_ptr, pivots_ptr, j, jb, col_start, col_count, lda).
///
/// NOTE(review): the kernel body is currently structural — it bounds-checks
/// the thread, computes the column base address, and returns without doing
/// the swaps described above.
fn emit_pivot_swap<T: GpuFloat>(sm: SmVersion) -> SolverResult<String> {
    // Kernel name is specialized per element type.
    let name = pivot_swap_name::<T>();
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("a_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("j", PtxType::U32)
        .param("jb", PtxType::U32)
        .param("col_start", PtxType::U32)
        .param("col_count", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let gid = b.global_thread_id_x();
            let col_count_reg = b.load_param_u32("col_count");

            // Guard: only threads mapped to a real column do any work.
            b.if_lt_u32(gid.clone(), col_count_reg, |b| {
                let a_ptr = b.load_param_u64("a_ptr");
                let col_start = b.load_param_u32("col_start");
                let lda = b.load_param_u32("lda");

                // Compute the actual column index.
                let col_idx = b.add_u32(gid, col_start);

                // Column base address: a_ptr + col_idx * lda * sizeof(T)
                let col_elem_offset = b.mul_lo_u32(col_idx, lda);
                let _col_base = b.byte_offset_addr(a_ptr, col_elem_offset, T::size_u32());

                // In the full implementation, this would loop over pivots[j..j+jb]
                // and swap the corresponding rows.
                let _ = float_ty;
            });

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
551
552#[cfg(test)]
553mod tests {
554    use super::*;
555
556    // ---------------------------------------------------------------------------
557    // CPU reference helpers for LU integration tests
558    // ---------------------------------------------------------------------------
559
    /// Doolittle LU factorization (no pivoting) on a 4×4 f64 matrix.
    ///
    /// Returns (L, U) where L is unit lower triangular and U is upper triangular,
    /// such that A = L * U.
    ///
    /// NOTE(review): when a diagonal pivot `u[i][i]` is (near-)zero the L
    /// entries in that column are silently left at 0.0 instead of being
    /// reported — acceptable for the well-conditioned fixtures used here.
    fn doolittle_lu_4x4(a: &[[f64; 4]; 4]) -> ([[f64; 4]; 4], [[f64; 4]; 4]) {
        let mut l = [[0.0_f64; 4]; 4];
        let mut u = [[0.0_f64; 4]; 4];

        for i in 0..4 {
            l[i][i] = 1.0; // Unit diagonal for L.

            // U row i: u[i][j] = a[i][j] - sum_{k<i} l[i][k] * u[k][j].
            for j in i..4 {
                let sum: f64 = (0..i).map(|k| l[i][k] * u[k][j]).sum();
                u[i][j] = a[i][j] - sum;
            }

            // L column i (below diagonal): l[j][i] = (a[j][i] - sum) / u[i][i].
            for j in (i + 1)..4 {
                let sum: f64 = (0..i).map(|k| l[j][k] * u[k][i]).sum();
                if u[i][i].abs() > 1e-15 {
                    l[j][i] = (a[j][i] - sum) / u[i][i];
                }
            }
        }

        (l, u)
    }
588
589    /// 4×4 matrix multiply (row-major).
590    fn matmul_4x4(a: &[[f64; 4]; 4], b: &[[f64; 4]; 4]) -> [[f64; 4]; 4] {
591        let mut c = [[0.0_f64; 4]; 4];
592        for i in 0..4 {
593            for j in 0..4 {
594                for k in 0..4 {
595                    c[i][j] += a[i][k] * b[k][j];
596                }
597            }
598        }
599        c
600    }
601
602    // ---------------------------------------------------------------------------
603    // LU + GEMM/TRSM integration tests
604    // ---------------------------------------------------------------------------
605
    #[test]
    // NOTE(review): despite the name, this exercises only the CPU Doolittle
    // reference (no TRSM / GPU path) — it validates the reference used by
    // the integration tests. The fixture is diagonally dominant, so no
    // pivoting is needed and the no-pivot Doolittle reference is exact.
    fn lu_trsm_trailing_update() {
        // Verify Doolittle LU on a 4×4 matrix: A = L * U to tolerance 1e-10.
        let a = [
            [4.0_f64, 3.0, 2.0, 1.0],
            [2.0, 5.0, 3.0, 2.0],
            [1.0, 2.0, 6.0, 3.0],
            [1.0, 1.0, 2.0, 7.0],
        ];
        let (l, u) = doolittle_lu_4x4(&a);

        // L must be unit lower triangular.
        for (i, l_row) in l.iter().enumerate() {
            assert!(
                (l_row[i] - 1.0).abs() < 1e-15,
                "L[{i},{i}] must be 1.0 (unit diagonal)"
            );
            for (j, &val) in l_row.iter().enumerate().filter(|(j, _)| *j > i) {
                assert!(
                    val.abs() < 1e-15,
                    "L[{i},{j}] = {val} must be 0.0 (upper triangle)",
                );
            }
        }

        // U must be upper triangular.
        for (i, u_row) in u.iter().enumerate() {
            for (j, &val) in u_row.iter().enumerate().filter(|(j, _)| *j < i) {
                assert!(
                    val.abs() < 1e-15,
                    "U[{i},{j}] = {val} must be 0.0 (lower triangle)",
                );
            }
        }

        // Reconstruct: L*U must equal A.
        let reconstructed = matmul_4x4(&l, &u);
        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    (reconstructed[i][j] - a[i][j]).abs() < 1e-10,
                    "LU[{i},{j}] = {} ≠ A[{i},{j}] = {} (diff = {})",
                    reconstructed[i][j],
                    a[i][j],
                    (reconstructed[i][j] - a[i][j]).abs()
                );
            }
        }
    }
655
    #[test]
    // CPU mirror of the rank-1 trailing update performed (blocked, rank-jb)
    // by the GEMM step in `blocked_lu`.
    fn lu_gemm_rank_update_correctness() {
        // Verify that the GEMM trailing update for k=0 is correct on a 3×3 example.
        //
        // After the first column of LU (k=0):
        //   L[:,0] is computed, U[0,:] is computed.
        //   Trailing update: A[1:3, 1:3] -= L[1:3, 0:1] * U[0:1, 1:3]
        //
        // Use a = [[2, 4, 6], [1, 3, 5], [1, 2, 4]] (simple example).
        let a = [[2.0_f64, 4.0, 6.0], [1.0, 3.0, 5.0], [1.0, 2.0, 4.0]];

        // After first pivot (k=0), L column 0 = [1, a[1,0]/a[0,0], a[2,0]/a[0,0]]
        //                                      = [1, 0.5, 0.5]
        // U row 0 = a[0,:] = [2, 4, 6]
        // Trailing update for A[1:3, 1:3]:
        //   A[1,1] -= L[1,0]*U[0,1] = 3 - 0.5*4 = 1
        //   A[1,2] -= L[1,0]*U[0,2] = 5 - 0.5*6 = 2
        //   A[2,1] -= L[2,0]*U[0,1] = 2 - 0.5*4 = 0
        //   A[2,2] -= L[2,0]*U[0,2] = 4 - 0.5*6 = 1
        let l_col0 = [1.0_f64, a[1][0] / a[0][0], a[2][0] / a[0][0]];
        let u_row0 = [a[0][0], a[0][1], a[0][2]];

        // Trailing submatrix after k=0 update (the 2×2 block A[1:3, 1:3]).
        let mut trailing = [[0.0_f64; 2]; 2];
        for i in 0..2 {
            for j in 0..2 {
                trailing[i][j] = a[i + 1][j + 1] - l_col0[i + 1] * u_row0[j + 1];
            }
        }

        // Expected values are the hand-computed results from the comment above.
        assert!(
            (trailing[0][0] - 1.0).abs() < 1e-12,
            "trailing[0,0] should be 1"
        );
        assert!(
            (trailing[0][1] - 2.0).abs() < 1e-12,
            "trailing[0,1] should be 2"
        );
        assert!(trailing[1][0].abs() < 1e-12, "trailing[1,0] should be 0");
        assert!(
            (trailing[1][1] - 1.0).abs() < 1e-12,
            "trailing[1,1] should be 1"
        );
    }
700
701    #[test]
702    fn lu_block_size_positive() {
703        let block_size = LU_BLOCK_SIZE;
704        assert!(block_size > 0);
705        assert!(block_size <= 256);
706    }
707
708    #[test]
709    fn lu_result_info() {
710        let result = LuResult { info: 0 };
711        assert_eq!(result.info, 0);
712
713        let singular = LuResult { info: 3 };
714        assert!(singular.info > 0);
715    }
716
717    #[test]
718    fn panel_lu_name_format() {
719        let name = panel_lu_name::<f32>(64);
720        assert!(name.contains("f32"));
721        assert!(name.contains("64"));
722    }
723
724    #[test]
725    fn pivot_swap_name_format() {
726        let name = pivot_swap_name::<f64>();
727        assert!(name.contains("f64"));
728    }
729
730    #[test]
731    fn neg_one_f32() {
732        let neg = f32::from_bits_u64(f32::gpu_one().to_bits_u64() ^ 0x8000_0000);
733        assert!((neg + 1.0).abs() < 1e-10);
734    }
735
736    #[test]
737    fn neg_one_f64() {
738        let neg = f64::from_bits_u64(f64::gpu_one().to_bits_u64() ^ 0x8000_0000_0000_0000);
739        assert!((neg + 1.0).abs() < 1e-15);
740    }
741}