runmat_runtime/builtins/math/linalg/factor/lu.rs

//! MATLAB-compatible `lu` builtin with CPU-backed semantics.

use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::{gpu_helpers, tensor};
use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
use num_complex::Complex64;
use runmat_accelerate_api::{GpuTensorHandle, ProviderLuResult};
use runmat_builtins::{ComplexTensor, Tensor, Value};
use runmat_macros::runtime_builtin;

#[cfg(feature = "doc_export")]
use crate::register_builtin_doc_text;

#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "lu"
category: "math/linalg/factor"
keywords: ["lu", "factorization", "decomposition", "lower-upper", "permutation"]
summary: "LU decomposition with partial pivoting, matching MATLAB semantics."
references: []
gpu_support:
  elementwise: false
  reduction: false
  precisions: ["f64"]
  broadcasting: "none"
  notes: "Uses the provider `lu` hook when available; otherwise gathers to the host fallback implementation."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 1
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::math::linalg::factor::lu::tests"
  integration: "builtins::math::linalg::factor::lu::tests::lu_three_outputs_matches_factorization"
---

# What does the `lu` function do in MATLAB / RunMat?
`lu(A)` computes the LU factorization of a real or complex matrix `A` using partial pivoting. It exposes the same calling forms as MATLAB:

- Single output: `lu(A)` returns a single matrix whose strictly lower-triangular entries encode `L` (with an implicit unit diagonal) and whose upper-triangular part encodes `U`.
- Two outputs: `[L, U] = lu(A)` returns the explicit unit-lower-triangular factor `L` and the upper-triangular factor `U`.
- Three outputs: `[L, U, P] = lu(A)` additionally returns a permutation matrix `P` so that `P * A = L * U`. Use the option `'vector'` to receive the permutation as a pivot vector instead of a matrix.

The implementation follows MATLAB’s dense behaviour for full matrices and supports rectangular inputs.

## How does the `lu` function behave in MATLAB / RunMat?
- Partial pivoting is applied to improve numerical stability. The permutation is encoded either as a dense matrix (`'matrix'`, default) or as a pivot vector (`'vector'`); the sketch after this list shows how the two encodings relate.
- Rectangular inputs are supported. `L` is always `m × m` (unit lower-triangular), and `U` is `m × n`, where `m` and `n` are the row and column counts of `A`.
- Singular matrices are permitted. Zero pivots propagate into the `U` factor just as in MATLAB; MATLAB-compatible warnings are not yet emitted.
- Only the first three outputs are implemented today. Column permutations (`Q`) and scaling (`R`) for the five-output sparse form are not yet available.
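
The two permutation encodings are related by row selection: the `'matrix'` form is just the identity matrix with its rows reordered by the pivot vector. A quick check (a sketch; `A` is any square matrix):

```matlab
[L, U, P] = lu(A);            % P * A = L * U
[L, U, p] = lu(A, 'vector');  % A(p, :) = L * U
I = eye(size(A, 1));
isequal(P, I(p, :))           % expected: logical 1
```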

## GPU execution in RunMat
- When an acceleration provider implements the `lu` hook (the WGPU provider does), the factorization executes through that provider, and the combined LU factor, `L`, `U`, and the permutation outputs all remain on the device. Until dedicated kernels land, the current WGPU backend performs the decomposition once on the host and immediately re-uploads the factors, preserving residency.
- The `'vector'` option likewise returns a GPU-resident pivot vector when a provider hook is active.
- If no provider hook is available, RunMat automatically gathers the input to host memory and falls back to the CPU implementation so behaviour stays MATLAB-compatible.

## Examples of using the `lu` function in MATLAB / RunMat

### Factorizing a square matrix with `lu`
```matlab
A = [2 1 1; 4 -6 0; -2 7 2];
[L, U, P] = lu(A);
```
Expected output (up to floating-point roundoff):
```matlab
L =
     1     0     0
   0.5     1     0
  -0.5     1     1

U =
     4    -6     0
     0     4     1
     0     0     1

P =
     0     1     0
     1     0     0
     0     0     1
```

### Obtaining only the combined LU factor
```matlab
LU = lu([1 3 5; 2 4 7; 1 1 0]);
```
Expected output:
```matlab
LU =
     2     4     7
   0.5     1   1.5
   0.5    -1    -2
```

### Requesting the permutation vector with the `'vector'` option
```matlab
[L, U, p] = lu([4 3; 6 3], 'vector');
```
Expected output:
```matlab
p =
     2
     1
```

### LU factorization of a rectangular matrix
```matlab
A = [3 1 2; 6 3 4];
[L, U, P] = lu(A);
```
Expected output:
```matlab
L =
     1     0
   0.5     1

U =
     6     3     4
     0  -0.5     0

P =
     0     1
     1     0
```

### Using LU factors to solve a linear system
```matlab
A = [3 1 2; 6 3 4];
b = [1; 2];
[L, U, P] = lu(A);
y = L \ (P * b);
x = U \ y;
```
Expected output (`U` is 2 × 3, so `U \ y` is underdetermined and MATLAB returns a basic solution):
```matlab
x =
    0.3333
         0
         0
```

### Running `lu` on a `gpuArray`
```matlab
G = gpuArray([10 7; 3 2]);
[L, U, P] = lu(G);
class(L)
class(U)
class(P)
```
Expected output:
```matlab
ans =
    'gpuArray'

ans =
    'gpuArray'

ans =
    'gpuArray'
```
If no acceleration provider exposes `lu`, RunMat gathers the input and returns the factors as host double arrays instead.

## FAQ

### Why does RunMat currently stop at three outputs?
Column pivoting (`Q`) and scaling (`R`) from MATLAB’s five-output sparse form are planned but not yet implemented. The dense three-output contract mirrors MATLAB’s default dense behaviour.

### Does the permutation vector use MATLAB’s 1-based indexing?
Yes. When you request `'vector'`, the returned pivot vector contains 1-based row indices so that `A(p, :) = L * U`.
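
In practice this lets you apply the permutation by indexing rather than by matrix multiplication, e.g. when solving (a sketch):

```matlab
[L, U, p] = lu(A, 'vector');
x = U \ (L \ b(p));   % row-permute b with p instead of forming P * b
```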

### How are singular matrices handled?
Partial pivoting proceeds exactly as in MATLAB. If a pivot column is entirely zero, the corresponding diagonal entries in `U` become zero. No warning is emitted yet.

### Are complex matrices supported?
Yes. Complex inputs produce complex `L`, `U`, and `LU`. The permutation remains real because it only contains zeros and ones.

### Will the factors stay on the GPU when I pass a `gpuArray`?
Yes. When the active acceleration provider exposes the `lu` hook (WGPU today), the combined factor, `L`, `U`, and the permutation outputs remain `gpuArray` values; the provider currently performs the decomposition on the host once and reuploads the results to preserve residency. Without provider support, RunMat gathers to host memory before returning the factors.

### Can I call `lu` on logical arrays?
Yes. Logical inputs are promoted to double precision before factorization, matching MATLAB semantics.

### Is pivoting deterministic?
Yes. Partial pivoting always chooses the first maximal entry in each column, mirroring MATLAB’s behaviour for dense matrices.

### How accurate is the factorization?
The implementation uses standard double-precision arithmetic (or complex double when needed). Numerical properties therefore match MATLAB’s dense fallback (without iterative refinement).

### What happens if I pass more than one option argument?
RunMat currently supports at most one option string (`'matrix'` or `'vector'`). Passing additional options raises an error.

### Can I reuse the combined LU factor to solve systems?
Yes. The combined matrix returned by `lu(A)` stores `L` in the strictly lower-triangular part (with an implicit unit diagonal) and `U` in the upper-triangular part, just like MATLAB. You can use forward/back substitution routines that understand this layout.
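
For a square `A`, the explicit factors can be unpacked from the combined form with `tril`/`triu` (a sketch; note the single-output form does not return the permutation, so the recovered factors satisfy `L * U = P * A` rather than `L * U = A`):

```matlab
LU = lu(A);
L = tril(LU, -1) + eye(size(LU, 1));  % restore the implicit unit diagonal
U = triu(LU);
```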

## See Also
[det](../../det), [inv](../../inv), [chol](./chol), [qr](./qr), [solve](../../solve/backslash), [gpuArray](../../../acceleration/gpu/gpuArray)

## Source & Feedback
- Implementation: `crates/runmat-runtime/src/builtins/math/linalg/factor/lu.rs`
- Found an issue or missing behaviour? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal reproduction.
"#;

pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "lu",
    op_kind: GpuOpKind::Custom("lu-factor"),
    supported_precisions: &[ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("lu")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prefers the provider `lu` hook; automatically gathers and falls back to the CPU implementation when no provider support is registered.",
};

register_builtin_gpu_spec!(GPU_SPEC);

pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "lu",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "LU decomposition is not part of expression fusion; calls execute eagerly on the CPU.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

#[cfg(feature = "doc_export")]
register_builtin_doc_text!("lu", DOC_MD);

#[runtime_builtin(
    name = "lu",
    category = "math/linalg/factor",
    summary = "LU decomposition with partial pivoting.",
    keywords = "lu,factorization,decomposition,permutation",
    accel = "sink",
    sink = true
)]
fn lu_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
    let eval = evaluate(value, &rest)?;
    Ok(eval.combined())
}

/// Output form for `lu`, reused by both the builtin wrapper and the VM multi-output path.
#[derive(Clone)]
pub struct LuEval {
    combined: Value,
    lower: Value,
    upper: Value,
    perm_matrix: Value,
    perm_vector: Value,
    pivot_mode: PivotMode,
}

impl LuEval {
    /// Combined LU factor (single-output form).
    pub fn combined(&self) -> Value {
        self.combined.clone()
    }

    /// Lower-triangular factor.
    pub fn lower(&self) -> Value {
        self.lower.clone()
    }

    /// Upper-triangular factor.
    pub fn upper(&self) -> Value {
        self.upper.clone()
    }

    /// Permutation value respecting the selected pivot mode.
    pub fn permutation(&self) -> Value {
        match self.pivot_mode {
            PivotMode::Matrix => self.perm_matrix.clone(),
            PivotMode::Vector => self.perm_vector.clone(),
        }
    }

    /// Permutation matrix (always available, useful for tests).
    pub fn permutation_matrix(&self) -> Value {
        self.perm_matrix.clone()
    }

    /// Pivot vector (always available, useful for tests).
    pub fn pivot_vector(&self) -> Value {
        self.perm_vector.clone()
    }

    /// The pivot mode that was requested.
    pub fn pivot_mode(&self) -> PivotMode {
        self.pivot_mode
    }

    fn from_components(components: LuComponents, pivot_mode: PivotMode) -> Result<Self, String> {
        let combined = matrix_to_value(&components.combined)?;
        let lower = matrix_to_value(&components.lower)?;
        let upper = matrix_to_value(&components.upper)?;
        let perm_matrix = matrix_to_value(&components.permutation)?;
        let perm_vector = pivot_vector_to_value(&components.pivot_vector)?;
        Ok(Self {
            combined,
            lower,
            upper,
            perm_matrix,
            perm_vector,
            pivot_mode,
        })
    }

    fn from_provider(result: ProviderLuResult, pivot_mode: PivotMode) -> Self {
        Self {
            combined: Value::GpuTensor(result.combined),
            lower: Value::GpuTensor(result.lower),
            upper: Value::GpuTensor(result.upper),
            perm_matrix: Value::GpuTensor(result.perm_matrix),
            perm_vector: Value::GpuTensor(result.perm_vector),
            pivot_mode,
        }
    }
}

/// Permutation output mode.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PivotMode {
    Matrix,
    Vector,
}

impl Default for PivotMode {
    fn default() -> Self {
        Self::Matrix
    }
}

/// Evaluate `lu` while preserving all output forms for later extraction.
pub fn evaluate(value: Value, args: &[Value]) -> Result<LuEval, String> {
    let pivot_mode = parse_pivot_mode(args)?;
    match value {
        Value::GpuTensor(handle) => {
            if let Some(eval) = evaluate_gpu(&handle, pivot_mode)? {
                return Ok(eval);
            }
            let tensor = gpu_helpers::gather_tensor(&handle)?;
            evaluate_host_value(Value::Tensor(tensor), pivot_mode)
        }
        other => evaluate_host_value(other, pivot_mode),
    }
}

fn evaluate_host_value(value: Value, pivot_mode: PivotMode) -> Result<LuEval, String> {
    let matrix = extract_matrix(value)?;
    let components = lu_factor(matrix)?;
    LuEval::from_components(components, pivot_mode)
}

fn evaluate_gpu(handle: &GpuTensorHandle, pivot_mode: PivotMode) -> Result<Option<LuEval>, String> {
    if let Some(provider) = runmat_accelerate_api::provider() {
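        // A provider that registers the hook but fails at runtime is treated
        // the same as no provider at all: we return `Ok(None)` and the caller
        // gathers the input and runs the host implementation instead.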
        if let Ok(result) = provider.lu(handle) {
            return Ok(Some(LuEval::from_provider(result, pivot_mode)));
        }
    }
    Ok(None)
}

fn parse_pivot_mode(args: &[Value]) -> Result<PivotMode, String> {
    if args.is_empty() {
        return Ok(PivotMode::Matrix);
    }
    if args.len() > 1 {
        return Err("lu: too many option arguments".to_string());
    }
    let Some(option) = tensor::value_to_string(&args[0]) else {
        return Err("lu: option must be a string or character vector".to_string());
    };
    match option.trim().to_ascii_lowercase().as_str() {
        "matrix" => Ok(PivotMode::Matrix),
        "vector" => Ok(PivotMode::Vector),
        other => Err(format!("lu: unknown option '{other}'")),
    }
}

fn extract_matrix(value: Value) -> Result<RowMajorMatrix, String> {
    match value {
        Value::Tensor(t) => RowMajorMatrix::from_tensor(&t),
        Value::ComplexTensor(ct) => RowMajorMatrix::from_complex_tensor(&ct),
        Value::GpuTensor(handle) => {
            let tensor = gpu_helpers::gather_tensor(&handle)?;
            RowMajorMatrix::from_tensor(&tensor)
        }
        Value::LogicalArray(logical) => {
            let tensor = tensor::logical_to_tensor(&logical)?;
            RowMajorMatrix::from_tensor(&tensor)
        }
        Value::Num(n) => Ok(RowMajorMatrix::from_scalar(Complex64::new(n, 0.0))),
        Value::Int(i) => Ok(RowMajorMatrix::from_scalar(Complex64::new(i.to_f64(), 0.0))),
        Value::Bool(b) => Ok(RowMajorMatrix::from_scalar(Complex64::new(
            if b { 1.0 } else { 0.0 },
            0.0,
        ))),
        Value::Complex(re, im) => Ok(RowMajorMatrix::from_scalar(Complex64::new(re, im))),
        Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => {
            Err("lu: character data is not supported; convert to numeric values first".to_string())
        }
        other => Err(format!("lu: unsupported input type {:?}", other)),
    }
}

struct LuComponents {
    combined: RowMajorMatrix,
    lower: RowMajorMatrix,
    upper: RowMajorMatrix,
    permutation: RowMajorMatrix,
    pivot_vector: Vec<f64>,
}

fn lu_factor(mut matrix: RowMajorMatrix) -> Result<LuComponents, String> {
    let rows = matrix.rows;
    let cols = matrix.cols;
    let min_dim = rows.min(cols);
    let mut perm: Vec<usize> = (0..rows).collect();
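    // `perm[i]` records which row of the original matrix ends up in row `i`
    // after pivoting, so `build_permutation` places a 1 at (i, perm[i]).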

    for k in 0..min_dim {
        // Select pivot row with maximal absolute value in column k.
        let mut pivot_row = k;
        let mut pivot_abs = 0.0;
        for r in k..rows {
            let val = matrix.get(r, k);
            let abs = val.norm();
            if abs > pivot_abs {
                pivot_abs = abs;
                pivot_row = r;
            }
        }

        if pivot_row != k {
            matrix.swap_rows(pivot_row, k);
            perm.swap(pivot_row, k);
        }

        if pivot_abs <= EPS {
            // Entire column is effectively zero; set multipliers to zero and continue.
            for r in (k + 1)..rows {
                matrix.set(r, k, Complex64::new(0.0, 0.0));
            }
            continue;
        }

        let pivot_value = matrix.get(k, k);
        for r in (k + 1)..rows {
            let factor = matrix.get(r, k) / pivot_value;
            matrix.set(r, k, factor);
            for c in (k + 1)..cols {
                let updated = matrix.get(r, c) - factor * matrix.get(k, c);
                matrix.set(r, c, updated);
            }
        }
    }

    let combined = matrix.clone();
    let lower = build_lower(&matrix);
    let upper = build_upper(&matrix);
    let permutation = build_permutation(rows, &perm);
    let pivot_vector: Vec<f64> = perm.iter().map(|idx| (*idx + 1) as f64).collect();

    Ok(LuComponents {
        combined,
        lower,
        upper,
        permutation,
        pivot_vector,
    })
}

fn build_lower(matrix: &RowMajorMatrix) -> RowMajorMatrix {
    let rows = matrix.rows;
    let cols = matrix.cols;
    let min_dim = rows.min(cols);
    let mut lower = RowMajorMatrix::identity(rows);
    for i in 0..rows {
        for j in 0..min_dim {
            if i > j {
                lower.set(i, j, matrix.get(i, j));
            }
        }
    }
    lower
}

fn build_upper(matrix: &RowMajorMatrix) -> RowMajorMatrix {
    let rows = matrix.rows;
    let cols = matrix.cols;
    let mut upper = RowMajorMatrix::zeros(rows, cols);
    for i in 0..rows {
        for j in 0..cols {
            if i <= j {
                upper.set(i, j, matrix.get(i, j));
            }
        }
    }
    upper
}

fn build_permutation(rows: usize, perm: &[usize]) -> RowMajorMatrix {
    let mut matrix = RowMajorMatrix::zeros(rows, rows);
    for (i, &col) in perm.iter().enumerate() {
        if col < rows {
            matrix.set(i, col, Complex64::new(1.0, 0.0));
        }
    }
    matrix
}

const EPS: f64 = 1.0e-12;

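/// Convert a row-major working matrix back into a column-major RunMat value,
/// producing a real tensor when every imaginary component is negligible.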
fn matrix_to_value(matrix: &RowMajorMatrix) -> Result<Value, String> {
    let mut has_imag = false;
    for val in &matrix.data {
        if val.im.abs() > EPS {
            has_imag = true;
            break;
        }
    }
    if has_imag {
        let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
        for col in 0..matrix.cols {
            for row in 0..matrix.rows {
                let idx = row * matrix.cols + col;
                let v = matrix.data[idx];
                data.push((v.re, v.im));
            }
        }
        let tensor = ComplexTensor::new(data, vec![matrix.rows, matrix.cols])
            .map_err(|e| format!("lu: {e}"))?;
        Ok(Value::ComplexTensor(tensor))
    } else {
        let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
        for col in 0..matrix.cols {
            for row in 0..matrix.rows {
                let idx = row * matrix.cols + col;
                data.push(matrix.data[idx].re);
            }
        }
        let tensor =
            Tensor::new(data, vec![matrix.rows, matrix.cols]).map_err(|e| format!("lu: {e}"))?;
        Ok(Value::Tensor(tensor))
    }
}

fn pivot_vector_to_value(pivot: &[f64]) -> Result<Value, String> {
    let rows = pivot.len();
    let tensor = Tensor::new(pivot.to_vec(), vec![rows, 1]).map_err(|e| format!("lu: {e}"))?;
    Ok(Value::Tensor(tensor))
}

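/// Dense 2-D working matrix stored in row-major order. RunMat tensors are
/// column-major, so the `from_*` constructors and `matrix_to_value` transpose
/// the layout at the module boundary.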
#[derive(Clone)]
struct RowMajorMatrix {
    rows: usize,
    cols: usize,
    data: Vec<Complex64>,
}

impl RowMajorMatrix {
    fn zeros(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            data: vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)],
        }
    }

    fn identity(size: usize) -> Self {
        let mut matrix = Self::zeros(size, size);
        for i in 0..size {
            matrix.set(i, i, Complex64::new(1.0, 0.0));
        }
        matrix
    }

    fn from_scalar(value: Complex64) -> Self {
        Self {
            rows: 1,
            cols: 1,
            data: vec![value],
        }
    }

    fn from_tensor(tensor: &Tensor) -> Result<Self, String> {
        if tensor.shape.len() > 2 {
            return Err("lu: input must be 2-D".to_string());
        }
        let rows = tensor.rows();
        let cols = tensor.cols();
        let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
        for col in 0..cols {
            for row in 0..rows {
                let idx_col_major = row + col * rows;
                let idx_row_major = row * cols + col;
                data[idx_row_major] = Complex64::new(tensor.data[idx_col_major], 0.0);
            }
        }
        Ok(Self { rows, cols, data })
    }

    fn from_complex_tensor(tensor: &ComplexTensor) -> Result<Self, String> {
        if tensor.shape.len() > 2 {
            return Err("lu: input must be 2-D".to_string());
        }
        let rows = tensor.rows;
        let cols = tensor.cols;
        let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
        for col in 0..cols {
            for row in 0..rows {
                let idx_col_major = row + col * rows;
                let idx_row_major = row * cols + col;
                let (re, im) = tensor.data[idx_col_major];
                data[idx_row_major] = Complex64::new(re, im);
            }
        }
        Ok(Self { rows, cols, data })
    }

    fn get(&self, row: usize, col: usize) -> Complex64 {
        self.data[row * self.cols + col]
    }

    fn set(&mut self, row: usize, col: usize, value: Complex64) {
        self.data[row * self.cols + col] = value;
    }

    fn swap_rows(&mut self, r1: usize, r2: usize) {
        if r1 == r2 {
            return;
        }
        for col in 0..self.cols {
            self.data.swap(r1 * self.cols + col, r2 * self.cols + col);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{ComplexTensor as CMatrix, Tensor as Matrix};

    fn tensor_from_value(value: Value) -> Matrix {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected dense tensor, got {other:?}"),
        }
    }

    fn row_major_from_value(value: Value) -> RowMajorMatrix {
        match value {
            Value::Tensor(t) => RowMajorMatrix::from_tensor(&t).expect("row-major tensor"),
            Value::ComplexTensor(ct) => {
                RowMajorMatrix::from_complex_tensor(&ct).expect("row-major complex tensor")
            }
            other => panic!("expected tensor value, got {other:?}"),
        }
    }

    fn row_major_matmul(a: &RowMajorMatrix, b: &RowMajorMatrix) -> RowMajorMatrix {
        assert_eq!(a.cols, b.rows, "incompatible shapes for matmul");
        let mut out = RowMajorMatrix::zeros(a.rows, b.cols);
        for i in 0..a.rows {
            for k in 0..a.cols {
                let aik = a.get(i, k);
                for j in 0..b.cols {
                    let acc = out.get(i, j) + aik * b.get(k, j);
                    out.set(i, j, acc);
                }
            }
        }
        out
    }

    fn assert_tensor_close(a: &Matrix, b: &Matrix, tol: f64) {
        assert_eq!(a.shape, b.shape);
        for (lhs, rhs) in a.data.iter().zip(&b.data) {
            assert!(
                (lhs - rhs).abs() <= tol,
                "mismatch: lhs={lhs}, rhs={rhs}, tol={tol}"
            );
        }
    }

    fn assert_row_major_close(a: &RowMajorMatrix, b: &RowMajorMatrix, tol: f64) {
        assert_eq!(a.rows, b.rows, "row mismatch");
        assert_eq!(a.cols, b.cols, "col mismatch");
        for row in 0..a.rows {
            for col in 0..a.cols {
                let lhs = a.get(row, col);
                let rhs = b.get(row, col);
                let diff = (lhs - rhs).norm();
                assert!(
                    diff <= tol,
                    "mismatch at ({row}, {col}): lhs={lhs:?}, rhs={rhs:?}, diff={diff}, tol={tol}"
                );
            }
        }
    }

    #[test]
    fn lu_single_output_produces_combined_matrix() {
        let a = Matrix::new(
            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
            vec![3, 3],
        )
        .unwrap();
        let result = lu_builtin(Value::Tensor(a.clone()), Vec::new()).expect("lu");
        let lu = tensor_from_value(result);
        let eval = evaluate(Value::Tensor(a), &[]).expect("evaluate");
        let expected = tensor_from_value(eval.combined());
        assert_tensor_close(&lu, &expected, 1e-12);
    }

    #[test]
    fn lu_three_outputs_matches_factorization() {
        let data = vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0];
        let a = Matrix::new(data.clone(), vec![3, 3]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[test]
    fn lu_complex_matrix_factorization() {
        let data = vec![(1.0, 2.0), (3.0, -1.0), (2.0, -1.0), (4.0, 2.0)];
        let a = CMatrix::new(data.clone(), vec![2, 2]).expect("complex tensor");
        let eval = evaluate(Value::ComplexTensor(a.clone()), &[]).expect("evaluate complex");

        let l = row_major_from_value(eval.lower());
        let u = row_major_from_value(eval.upper());
        let p = row_major_from_value(eval.permutation_matrix());
        let input = RowMajorMatrix::from_complex_tensor(&a).expect("row-major input");

        let pa = row_major_matmul(&p, &input);
        let lu = row_major_matmul(&l, &u);
        assert_row_major_close(&pa, &lu, 1e-9);
    }

    #[test]
    fn lu_handles_singular_matrix() {
        let a = Matrix::new(vec![0.0, 0.0, 0.0, 0.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate singular");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());

        assert!(u.data.iter().any(|&v| v.abs() <= 1e-12));

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[test]
    fn lu_vector_option_returns_pivot_vector() {
        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("vector")]).expect("evaluate vector mode");
        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
        let pivot = tensor_from_value(eval.pivot_vector());
        assert_eq!(pivot.shape, vec![2, 1]);
        assert_eq!(pivot.data, vec![2.0, 1.0]);
    }

    #[test]
    fn lu_vector_option_case_insensitive() {
        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("VECTOR")]).expect("evaluate vector option");
        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
    }

    #[test]
    fn lu_matrix_option_returns_permutation_matrix() {
        let a = Matrix::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("matrix")]).expect("evaluate matrix option");
        assert_eq!(eval.pivot_mode(), PivotMode::Matrix);
        let perm_selected = tensor_from_value(eval.permutation());
        let perm_matrix = tensor_from_value(eval.permutation_matrix());
        assert_eq!(perm_selected.shape, perm_matrix.shape);
        assert_tensor_close(&perm_selected, &perm_matrix, 1e-12);
    }

    #[test]
    fn lu_handles_rectangular_matrices() {
        let a = Matrix::new(vec![3.0, 6.0, 1.0, 3.0, 2.0, 4.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate rectangular");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());
        assert_eq!(l.shape, vec![2, 2]);
        assert_eq!(u.shape, vec![2, 3]);
        assert_eq!(p.shape, vec![2, 2]);

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[test]
    fn lu_rejects_unknown_option() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(Value::Tensor(a), &[Value::from("invalid")]) {
            Ok(_) => panic!("expected option parse failure"),
            Err(err) => err,
        };
        assert!(err.contains("unknown option"));
    }

    #[test]
    fn lu_rejects_non_string_option() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(Value::Tensor(a), &[Value::Num(2.0)]) {
            Ok(_) => panic!("expected option parse failure"),
            Err(err) => err,
        };
        assert!(err.contains("unknown option"));
    }

    #[test]
    fn lu_rejects_multiple_options() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(
            Value::Tensor(a),
            &[Value::from("matrix"), Value::from("vector")],
        ) {
            Ok(_) => panic!("expected option arity failure"),
            Err(err) => err,
        };
        assert!(err.contains("too many option arguments"));
    }

    #[test]
    fn lu_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = Matrix::new(vec![10.0, 3.0, 7.0, 2.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("evaluate gpu input");
            let lower_val = eval.lower();
            let upper_val = eval.upper();
            let perm_val = eval.permutation_matrix();
            assert!(matches!(lower_val, Value::GpuTensor(_)));
            assert!(matches!(upper_val, Value::GpuTensor(_)));
            assert!(matches!(perm_val, Value::GpuTensor(_)));
            let l = test_support::gather(lower_val).expect("gather lower");
            let u = test_support::gather(upper_val).expect("gather upper");
            let p = test_support::gather(perm_val).expect("gather permutation");
            let pa = crate::matrix::matrix_mul(&p, &host).expect("P*A");
            let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
            assert_tensor_close(&pa, &lu_product, 1e-9);
        });
    }

    #[test]
    fn lu_gpu_vector_option_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval =
                evaluate(Value::GpuTensor(handle), &[Value::from("vector")]).expect("gpu vector");
            let pivot_val = eval.permutation();
            assert!(matches!(pivot_val, Value::GpuTensor(_)));
            let pivot = test_support::gather(pivot_val).expect("gather pivot");
            assert_eq!(pivot.shape, vec![2, 1]);
            let expected = Matrix::new(vec![2.0, 1.0], vec![2, 1]).unwrap();
            assert_tensor_close(&pivot, &expected, 1e-12);
        });
    }

    #[test]
    fn lu_accepts_scalar_inputs() {
        let eval = evaluate(Value::Num(5.0), &[]).expect("evaluate scalar");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());
        assert_eq!(l.data, vec![1.0]);
        assert_eq!(u.data, vec![5.0]);
        assert_eq!(p.data, vec![1.0]);
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn lu_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let host = Matrix::new(
            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
            vec![3, 3],
        )
        .unwrap();
        let cpu_eval = evaluate(Value::Tensor(host.clone()), &[]).expect("cpu evaluate");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &host.data,
            shape: &host.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");

        let l_cpu = tensor_from_value(cpu_eval.lower());
        let u_cpu = tensor_from_value(cpu_eval.upper());
        let p_cpu = tensor_from_value(cpu_eval.permutation_matrix());
        let lu_cpu = tensor_from_value(cpu_eval.combined());

        let l_gpu = test_support::gather(gpu_eval.lower()).expect("gather L");
        let u_gpu = test_support::gather(gpu_eval.upper()).expect("gather U");
        let p_gpu = test_support::gather(gpu_eval.permutation_matrix()).expect("gather P");
        let lu_gpu = test_support::gather(gpu_eval.combined()).expect("gather LU");

        assert_tensor_close(&l_cpu, &l_gpu, 1e-12);
        assert_tensor_close(&u_cpu, &u_gpu, 1e-12);
        assert_tensor_close(&p_cpu, &p_gpu, 1e-12);
        assert_tensor_close(&lu_cpu, &lu_gpu, 1e-12);

        let pivot_cpu = tensor_from_value(cpu_eval.pivot_vector());
        let pivot_gpu = test_support::gather(gpu_eval.pivot_vector()).expect("gather pivot vector");
        assert_tensor_close(&pivot_cpu, &pivot_gpu, 1e-12);

        let handle_vector = provider.upload(&view).expect("upload vector option");
        let gpu_vector_eval = evaluate(Value::GpuTensor(handle_vector), &[Value::from("vector")])
            .expect("gpu vector evaluate");
        let pivot_vector =
            test_support::gather(gpu_vector_eval.permutation()).expect("gather vector pivot");
        assert_tensor_close(&pivot_cpu, &pivot_vector, 1e-12);
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}