use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::{gpu_helpers, tensor};
use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
use num_complex::Complex64;
use runmat_accelerate_api::{GpuTensorHandle, ProviderLuResult};
use runmat_builtins::{ComplexTensor, Tensor, Value};
use runmat_macros::runtime_builtin;

#[cfg(feature = "doc_export")]
use crate::register_builtin_doc_text;

#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "lu"
category: "math/linalg/factor"
keywords: ["lu", "factorization", "decomposition", "lower-upper", "permutation"]
summary: "LU decomposition with partial pivoting, matching MATLAB semantics."
references: []
gpu_support:
  elementwise: false
  reduction: false
  precisions: ["f64"]
  broadcasting: "none"
  notes: "Uses the provider `lu` hook when available; otherwise gathers to the host fallback implementation."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 1
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::math::linalg::factor::lu::tests"
  integration: "builtins::math::linalg::factor::lu::tests::lu_three_outputs_matches_factorization"
---

# What does the `lu` function do in MATLAB / RunMat?
`lu(A)` computes the LU factorization of a real or complex matrix `A` using partial pivoting. It exposes the same calling forms as MATLAB:

- Single output: `lu(A)` returns a single matrix whose strictly lower-triangular entries encode `L` (with an implicit unit diagonal) and whose upper-triangular part encodes `U`.
- Two outputs: `[L, U] = lu(A)` returns the explicit unit-lower-triangular factor `L` and the upper-triangular factor `U`.
- Three outputs: `[L, U, P] = lu(A)` additionally returns a permutation so that `P * A = L * U`. Use the option `'vector'` to receive the permutation as a pivot vector instead of a matrix.

The implementation follows MATLAB’s dense behaviour for full matrices and supports rectangular inputs.

## How does the `lu` function behave in MATLAB / RunMat?
- Partial pivoting is applied to improve numerical stability. The permutation is encoded either as a dense matrix (`'matrix'`, default) or as a pivot vector (`'vector'`); both encodings are demonstrated below.
- Rectangular inputs are supported. `L` is always `m × m` (unit lower-triangular), and `U` is `m × n`, where `m` and `n` are the row and column counts of `A`.
- Singular matrices are permitted. Zero pivots propagate into the `U` factor just as in MATLAB; MATLAB-compatible warnings are not yet emitted.
- Only the first three outputs are implemented today. Column permutations (`Q`) and scaling (`R`) for the five-output sparse form are not yet available.
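
Both permutation encodings describe the same factorization, and the documented identities `P * A = L * U` and `A(p, :) = L * U` can be checked directly (a small illustrative sketch):

```matlab
A = [0 2; 1 1];
[L, U, P] = lu(A);             % matrix form: P * A == L * U
[L2, U2, p] = lu(A, 'vector'); % vector form: A(p, :) == L2 * U2
norm(P * A - L * U)            % ~0 up to roundoff
norm(A(p, :) - L2 * U2)        % ~0 up to roundoff
```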

## GPU execution in RunMat
- When an acceleration provider implements the `lu` hook (the WGPU provider does), the factorization executes through that provider and the combined LU factor, `L`, `U`, and permutation outputs all remain on the device. The current WGPU backend performs the decomposition on the host once and immediately reuploads the factors so residency is preserved until dedicated kernels land.
- The `'vector'` option likewise returns a GPU-resident pivot vector when a provider hook is active.
- If no provider hook is available, RunMat automatically gathers the input to host memory and falls back to the CPU implementation so behaviour stays MATLAB-compatible.

## Examples of using the `lu` function in MATLAB / RunMat

### Factorizing a square matrix with `lu`
```matlab
A = [2 1 1; 4 -6 0; -2 7 2];
[L, U, P] = lu(A);
```
Expected output (up to floating-point roundoff):
```matlab
L =
    1      0      0
    0.5    1      0
   -0.5    1      1

U =
    4     -6      0
    0      4      1
    0      0      1

P =
    0      1      0
    1      0      0
    0      0      1
```

### Obtaining only the combined LU factor
```matlab
LU = lu([1 3 5; 2 4 7; 1 1 0]);
```
Expected output:
```matlab
LU =
    2      4      7
    0.5    1      1.5
    0.5   -1     -2
```

### Requesting the permutation vector with the `'vector'` option
```matlab
[L, U, p] = lu([4 3; 6 3], 'vector');
```
Expected output:
```matlab
p =
    2
    1
```

### LU factorization of a rectangular matrix
```matlab
A = [3 1 2; 6 3 4];
[L, U, P] = lu(A);
```
Expected output:
```matlab
L =
    1      0
    0.5    1

U =
    6      3      4
    0     -0.5    0

P =
    0      1
    1      0
```

### Using LU factors to solve a linear system
```matlab
A = [3 1 2; 6 3 4];
b = [1; 2];
[L, U, P] = lu(A);
y = L \ (P * b);
x = U \ y;
```
Expected output (a basic solution; `U` is rectangular, so `U \ y` is underdetermined and backslash returns one valid solution):
```matlab
x =
    0.3333
         0
         0
```

### Running `lu` on a `gpuArray`
```matlab
G = gpuArray([10 7; 3 2]);
[L, U, P] = lu(G);
class(L)
class(U)
class(P)
```
Expected output:
```matlab
ans =
    'gpuArray'

ans =
    'gpuArray'

ans =
    'gpuArray'
```
If no acceleration provider exposes `lu`, RunMat gathers the input and returns the factors as host double arrays instead.

## FAQ

### Why does RunMat currently stop at three outputs?
Column pivoting (`Q`) and scaling (`R`) from MATLAB’s five-output sparse form are planned but not yet implemented. The dense three-output contract mirrors MATLAB’s default dense behaviour.

### Does the permutation vector use MATLAB’s 1-based indexing?
Yes. When you request `'vector'`, the returned pivot vector contains 1-based row indices so that `A(p, :) = L * U`.

### How are singular matrices handled?
Partial pivoting proceeds exactly as in MATLAB. If a pivot column is entirely zero, the corresponding diagonal entries in `U` become zero. No warning is emitted yet.

### Are complex matrices supported?
Yes. Complex inputs produce complex `L`, `U`, and `LU`. The permutation remains real because it only contains zeros and ones.

### Will the factors stay on the GPU when I pass a `gpuArray`?
Yes. When the active acceleration provider exposes the `lu` hook (WGPU today), the combined factor, `L`, `U`, and the permutation outputs remain `gpuArray` values; the provider currently performs the decomposition on the host once and reuploads the results to preserve residency. Without provider support, RunMat gathers to host memory before returning the factors.

### Can I call `lu` on logical arrays?
Yes. Logical inputs are promoted to double precision before factorization, matching MATLAB semantics.

### Is pivoting deterministic?
Yes. Partial pivoting always chooses the first maximal entry in each column, mirroring MATLAB’s behaviour for dense matrices.

### How accurate is the factorization?
The implementation uses standard double-precision arithmetic (or complex double when needed). Numerical properties therefore match MATLAB’s dense fallback (without iterative refinement).

### What happens if I pass more than one option argument?
RunMat currently supports at most one option string (`'matrix'` or `'vector'`). Passing additional options raises an error.

### Can I reuse the combined LU factor to solve systems?
Yes. The combined matrix returned by `lu(A)` stores `L` in the strictly lower-triangular part (with an implicit unit diagonal) and `U` in the upper-triangular part, just like MATLAB. You can use forward/back substitution routines that understand this layout, as sketched below.
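
A minimal sketch of that workflow, assuming a square nonsingular `A` and the standard `tril`/`triu`/`eye` builtins:

```matlab
A  = [4 3; 6 3];
LU = lu(A);                        % combined factor
L  = tril(LU, -1) + eye(size(A));  % restore the implicit unit diagonal
U  = triu(LU);
[L0, U0, P] = lu(A);               % same factors, explicit form
norm(L * U - P * A)                % ~0 up to roundoff
```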

## See Also
[det](../../det), [inv](../../inv), [chol](./chol), [qr](./qr), [solve](../../solve/backslash), [gpuArray](../../../acceleration/gpu/gpuArray)

## Source & Feedback
- Implementation: `crates/runmat-runtime/src/builtins/math/linalg/factor/lu.rs`
- Found an issue or missing behaviour? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal reproduction.
"#;

pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "lu",
    op_kind: GpuOpKind::Custom("lu-factor"),
    supported_precisions: &[ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("lu")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prefers the provider `lu` hook; automatically gathers and falls back to the CPU implementation when no provider support is registered.",
};

register_builtin_gpu_spec!(GPU_SPEC);

pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "lu",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "LU decomposition is not part of expression fusion; calls execute eagerly on the CPU.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

#[cfg(feature = "doc_export")]
register_builtin_doc_text!("lu", DOC_MD);

#[runtime_builtin(
    name = "lu",
    category = "math/linalg/factor",
    summary = "LU decomposition with partial pivoting.",
    keywords = "lu,factorization,decomposition,permutation",
    accel = "sink",
    sink = true
)]
fn lu_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
    let eval = evaluate(value, &rest)?;
    Ok(eval.combined())
}

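/// Host- or device-resident results of an `lu` evaluation. All factor
/// variants are computed up front so callers can pick the outputs that the
/// requested MATLAB calling form needs (one, two, or three results).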
#[derive(Clone)]
pub struct LuEval {
    combined: Value,
    lower: Value,
    upper: Value,
    perm_matrix: Value,
    perm_vector: Value,
    pivot_mode: PivotMode,
}

impl LuEval {
    pub fn combined(&self) -> Value {
        self.combined.clone()
    }

    pub fn lower(&self) -> Value {
        self.lower.clone()
    }

    pub fn upper(&self) -> Value {
        self.upper.clone()
    }

    pub fn permutation(&self) -> Value {
        match self.pivot_mode {
            PivotMode::Matrix => self.perm_matrix.clone(),
            PivotMode::Vector => self.perm_vector.clone(),
        }
    }

    pub fn permutation_matrix(&self) -> Value {
        self.perm_matrix.clone()
    }

    pub fn pivot_vector(&self) -> Value {
        self.perm_vector.clone()
    }

    pub fn pivot_mode(&self) -> PivotMode {
        self.pivot_mode
    }

    fn from_components(components: LuComponents, pivot_mode: PivotMode) -> Result<Self, String> {
        let combined = matrix_to_value(&components.combined)?;
        let lower = matrix_to_value(&components.lower)?;
        let upper = matrix_to_value(&components.upper)?;
        let perm_matrix = matrix_to_value(&components.permutation)?;
        let perm_vector = pivot_vector_to_value(&components.pivot_vector)?;
        Ok(Self {
            combined,
            lower,
            upper,
            perm_matrix,
            perm_vector,
            pivot_mode,
        })
    }

    fn from_provider(result: ProviderLuResult, pivot_mode: PivotMode) -> Self {
        Self {
            combined: Value::GpuTensor(result.combined),
            lower: Value::GpuTensor(result.lower),
            upper: Value::GpuTensor(result.upper),
            perm_matrix: Value::GpuTensor(result.perm_matrix),
            perm_vector: Value::GpuTensor(result.perm_vector),
            pivot_mode,
        }
    }
}

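/// Encoding of the permutation output: a dense permutation matrix
/// (`'matrix'`, the default) or a 1-based pivot vector (`'vector'`).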
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PivotMode {
    Matrix,
    Vector,
}

impl Default for PivotMode {
    fn default() -> Self {
        Self::Matrix
    }
}

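/// Evaluates `lu` for `value` plus any trailing option arguments. GPU tensors
/// try the provider `lu` hook first and fall back to gathering the data and
/// running the host implementation when the hook is unavailable or fails.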
pub fn evaluate(value: Value, args: &[Value]) -> Result<LuEval, String> {
    let pivot_mode = parse_pivot_mode(args)?;
    match value {
        Value::GpuTensor(handle) => {
            if let Some(eval) = evaluate_gpu(&handle, pivot_mode)? {
                return Ok(eval);
            }
            let tensor = gpu_helpers::gather_tensor(&handle)?;
            evaluate_host_value(Value::Tensor(tensor), pivot_mode)
        }
        other => evaluate_host_value(other, pivot_mode),
    }
}

fn evaluate_host_value(value: Value, pivot_mode: PivotMode) -> Result<LuEval, String> {
    let matrix = extract_matrix(value)?;
    let components = lu_factor(matrix)?;
    LuEval::from_components(components, pivot_mode)
}

fn evaluate_gpu(handle: &GpuTensorHandle, pivot_mode: PivotMode) -> Result<Option<LuEval>, String> {
    if let Some(provider) = runmat_accelerate_api::provider() {
        if let Ok(result) = provider.lu(handle) {
            return Ok(Some(LuEval::from_provider(result, pivot_mode)));
        }
    }
    Ok(None)
}

fn parse_pivot_mode(args: &[Value]) -> Result<PivotMode, String> {
    if args.is_empty() {
        return Ok(PivotMode::Matrix);
    }
    if args.len() > 1 {
        return Err("lu: too many option arguments".to_string());
    }
    let Some(option) = tensor::value_to_string(&args[0]) else {
        return Err("lu: option must be a string or character vector".to_string());
    };
    match option.trim().to_ascii_lowercase().as_str() {
        "matrix" => Ok(PivotMode::Matrix),
        "vector" => Ok(PivotMode::Vector),
        other => Err(format!("lu: unknown option '{other}'")),
    }
}

fn extract_matrix(value: Value) -> Result<RowMajorMatrix, String> {
    match value {
        Value::Tensor(t) => RowMajorMatrix::from_tensor(&t),
        Value::ComplexTensor(ct) => RowMajorMatrix::from_complex_tensor(&ct),
        Value::GpuTensor(handle) => {
            let tensor = gpu_helpers::gather_tensor(&handle)?;
            RowMajorMatrix::from_tensor(&tensor)
        }
        Value::LogicalArray(logical) => {
            let tensor = tensor::logical_to_tensor(&logical)?;
            RowMajorMatrix::from_tensor(&tensor)
        }
        Value::Num(n) => Ok(RowMajorMatrix::from_scalar(Complex64::new(n, 0.0))),
        Value::Int(i) => Ok(RowMajorMatrix::from_scalar(Complex64::new(i.to_f64(), 0.0))),
        Value::Bool(b) => Ok(RowMajorMatrix::from_scalar(Complex64::new(
            if b { 1.0 } else { 0.0 },
            0.0,
        ))),
        Value::Complex(re, im) => Ok(RowMajorMatrix::from_scalar(Complex64::new(re, im))),
        Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => {
            Err("lu: character data is not supported; convert to numeric values first".to_string())
        }
        other => Err(format!("lu: unsupported input type {:?}", other)),
    }
}

struct LuComponents {
    combined: RowMajorMatrix,
    lower: RowMajorMatrix,
    upper: RowMajorMatrix,
    permutation: RowMajorMatrix,
    pivot_vector: Vec<f64>,
}

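/// Dense LU factorization with partial pivoting, performed in place. The
/// returned `combined` matrix holds the `L` multipliers strictly below the
/// diagonal (unit diagonal implied) and `U` on and above it; `permutation`
/// and `pivot_vector` record the row exchanges made while pivoting.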
fn lu_factor(mut matrix: RowMajorMatrix) -> Result<LuComponents, String> {
    let rows = matrix.rows;
    let cols = matrix.cols;
    let min_dim = rows.min(cols);
    let mut perm: Vec<usize> = (0..rows).collect();

    for k in 0..min_dim {
        // Partial pivoting: pick the first entry of largest magnitude in column k.
        let mut pivot_row = k;
        let mut pivot_abs = 0.0;
        for r in k..rows {
            let val = matrix.get(r, k);
            let abs = val.norm();
            if abs > pivot_abs {
                pivot_abs = abs;
                pivot_row = r;
            }
        }

        if pivot_row != k {
            matrix.swap_rows(pivot_row, k);
            perm.swap(pivot_row, k);
        }

        // Numerically zero pivot column: clear the sub-diagonal entries and move
        // on, leaving a zero pivot in `U` as MATLAB does for singular input.
        if pivot_abs <= EPS {
            for r in (k + 1)..rows {
                matrix.set(r, k, Complex64::new(0.0, 0.0));
            }
            continue;
        }

        // Eliminate below the pivot, storing the multipliers in place so the
        // strictly lower triangle ends up holding `L`.
        let pivot_value = matrix.get(k, k);
        for r in (k + 1)..rows {
            let factor = matrix.get(r, k) / pivot_value;
            matrix.set(r, k, factor);
            for c in (k + 1)..cols {
                let updated = matrix.get(r, c) - factor * matrix.get(k, c);
                matrix.set(r, c, updated);
            }
        }
    }

    let combined = matrix.clone();
    let lower = build_lower(&matrix);
    let upper = build_upper(&matrix);
    let permutation = build_permutation(rows, &perm);
    let pivot_vector: Vec<f64> = perm.iter().map(|idx| (*idx + 1) as f64).collect();

    Ok(LuComponents {
        combined,
        lower,
        upper,
        permutation,
        pivot_vector,
    })
}

fn build_lower(matrix: &RowMajorMatrix) -> RowMajorMatrix {
    let rows = matrix.rows;
    let cols = matrix.cols;
    let min_dim = rows.min(cols);
    let mut lower = RowMajorMatrix::identity(rows);
    for i in 0..rows {
        for j in 0..min_dim {
            if i > j {
                lower.set(i, j, matrix.get(i, j));
            }
        }
    }
    lower
}

fn build_upper(matrix: &RowMajorMatrix) -> RowMajorMatrix {
    let rows = matrix.rows;
    let cols = matrix.cols;
    let mut upper = RowMajorMatrix::zeros(rows, cols);
    for i in 0..rows {
        for j in 0..cols {
            if i <= j {
                upper.set(i, j, matrix.get(i, j));
            }
        }
    }
    upper
}

fn build_permutation(rows: usize, perm: &[usize]) -> RowMajorMatrix {
    let mut matrix = RowMajorMatrix::zeros(rows, rows);
    for (i, &col) in perm.iter().enumerate() {
        if col < rows {
            matrix.set(i, col, Complex64::new(1.0, 0.0));
        }
    }
    matrix
}

// Magnitude below which pivots (and imaginary parts) are treated as zero.
const EPS: f64 = 1.0e-12;

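/// Converts a row-major working matrix back into a column-major `Value`,
/// demoting to a real `Tensor` when every imaginary part is within `EPS`.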
fn matrix_to_value(matrix: &RowMajorMatrix) -> Result<Value, String> {
    let mut has_imag = false;
    for val in &matrix.data {
        if val.im.abs() > EPS {
            has_imag = true;
            break;
        }
    }
    if has_imag {
        let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
        for col in 0..matrix.cols {
            for row in 0..matrix.rows {
                let idx = row * matrix.cols + col;
                let v = matrix.data[idx];
                data.push((v.re, v.im));
            }
        }
        let tensor = ComplexTensor::new(data, vec![matrix.rows, matrix.cols])
            .map_err(|e| format!("lu: {e}"))?;
        Ok(Value::ComplexTensor(tensor))
    } else {
        let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
        for col in 0..matrix.cols {
            for row in 0..matrix.rows {
                let idx = row * matrix.cols + col;
                data.push(matrix.data[idx].re);
            }
        }
        let tensor =
            Tensor::new(data, vec![matrix.rows, matrix.cols]).map_err(|e| format!("lu: {e}"))?;
        Ok(Value::Tensor(tensor))
    }
}

fn pivot_vector_to_value(pivot: &[f64]) -> Result<Value, String> {
    let rows = pivot.len();
    let tensor = Tensor::new(pivot.to_vec(), vec![rows, 1]).map_err(|e| format!("lu: {e}"))?;
    Ok(Value::Tensor(tensor))
}

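/// Minimal dense matrix stored in row-major order. `Tensor` and
/// `ComplexTensor` are column-major, so the conversions below transpose the
/// element layout in both directions.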
#[derive(Clone)]
struct RowMajorMatrix {
    rows: usize,
    cols: usize,
    data: Vec<Complex64>,
}

impl RowMajorMatrix {
    fn zeros(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            data: vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)],
        }
    }

    fn identity(size: usize) -> Self {
        let mut matrix = Self::zeros(size, size);
        for i in 0..size {
            matrix.set(i, i, Complex64::new(1.0, 0.0));
        }
        matrix
    }

    fn from_scalar(value: Complex64) -> Self {
        Self {
            rows: 1,
            cols: 1,
            data: vec![value],
        }
    }

    fn from_tensor(tensor: &Tensor) -> Result<Self, String> {
        if tensor.shape.len() > 2 {
            return Err("lu: input must be 2-D".to_string());
        }
        let rows = tensor.rows();
        let cols = tensor.cols();
        let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
        for col in 0..cols {
            for row in 0..rows {
                let idx_col_major = row + col * rows;
                let idx_row_major = row * cols + col;
                data[idx_row_major] = Complex64::new(tensor.data[idx_col_major], 0.0);
            }
        }
        Ok(Self { rows, cols, data })
    }

    fn from_complex_tensor(tensor: &ComplexTensor) -> Result<Self, String> {
        if tensor.shape.len() > 2 {
            return Err("lu: input must be 2-D".to_string());
        }
        let rows = tensor.rows;
        let cols = tensor.cols;
        let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
        for col in 0..cols {
            for row in 0..rows {
                let idx_col_major = row + col * rows;
                let idx_row_major = row * cols + col;
                let (re, im) = tensor.data[idx_col_major];
                data[idx_row_major] = Complex64::new(re, im);
            }
        }
        Ok(Self { rows, cols, data })
    }

    fn get(&self, row: usize, col: usize) -> Complex64 {
        self.data[row * self.cols + col]
    }

    fn set(&mut self, row: usize, col: usize, value: Complex64) {
        self.data[row * self.cols + col] = value;
    }

    fn swap_rows(&mut self, r1: usize, r2: usize) {
        if r1 == r2 {
            return;
        }
        for col in 0..self.cols {
            self.data.swap(r1 * self.cols + col, r2 * self.cols + col);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{ComplexTensor as CMatrix, Tensor as Matrix};

    fn tensor_from_value(value: Value) -> Matrix {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected dense tensor, got {other:?}"),
        }
    }

    fn row_major_from_value(value: Value) -> RowMajorMatrix {
        match value {
            Value::Tensor(t) => RowMajorMatrix::from_tensor(&t).expect("row-major tensor"),
            Value::ComplexTensor(ct) => {
                RowMajorMatrix::from_complex_tensor(&ct).expect("row-major complex tensor")
            }
            other => panic!("expected tensor value, got {other:?}"),
        }
    }

    fn row_major_matmul(a: &RowMajorMatrix, b: &RowMajorMatrix) -> RowMajorMatrix {
        assert_eq!(a.cols, b.rows, "incompatible shapes for matmul");
        let mut out = RowMajorMatrix::zeros(a.rows, b.cols);
        for i in 0..a.rows {
            for k in 0..a.cols {
                let aik = a.get(i, k);
                for j in 0..b.cols {
                    let acc = out.get(i, j) + aik * b.get(k, j);
                    out.set(i, j, acc);
                }
            }
        }
        out
    }

    fn assert_tensor_close(a: &Matrix, b: &Matrix, tol: f64) {
        assert_eq!(a.shape, b.shape);
        for (lhs, rhs) in a.data.iter().zip(&b.data) {
            assert!(
                (lhs - rhs).abs() <= tol,
                "mismatch: lhs={lhs}, rhs={rhs}, tol={tol}"
            );
        }
    }

    fn assert_row_major_close(a: &RowMajorMatrix, b: &RowMajorMatrix, tol: f64) {
        assert_eq!(a.rows, b.rows, "row mismatch");
        assert_eq!(a.cols, b.cols, "col mismatch");
        for row in 0..a.rows {
            for col in 0..a.cols {
                let lhs = a.get(row, col);
                let rhs = b.get(row, col);
                let diff = (lhs - rhs).norm();
                assert!(
                    diff <= tol,
                    "mismatch at ({row}, {col}): lhs={lhs:?}, rhs={rhs:?}, diff={diff}, tol={tol}"
                );
            }
        }
    }

    #[test]
    fn lu_single_output_produces_combined_matrix() {
        let a = Matrix::new(
            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
            vec![3, 3],
        )
        .unwrap();
        let result = lu_builtin(Value::Tensor(a.clone()), Vec::new()).expect("lu");
        let lu = tensor_from_value(result);
        let eval = evaluate(Value::Tensor(a), &[]).expect("evaluate");
        let expected = tensor_from_value(eval.combined());
        assert_tensor_close(&lu, &expected, 1e-12);
    }

    #[test]
    fn lu_three_outputs_matches_factorization() {
        let data = vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0];
        let a = Matrix::new(data.clone(), vec![3, 3]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[test]
    fn lu_complex_matrix_factorization() {
        let data = vec![(1.0, 2.0), (3.0, -1.0), (2.0, -1.0), (4.0, 2.0)];
        let a = CMatrix::new(data.clone(), vec![2, 2]).expect("complex tensor");
        let eval = evaluate(Value::ComplexTensor(a.clone()), &[]).expect("evaluate complex");

        let l = row_major_from_value(eval.lower());
        let u = row_major_from_value(eval.upper());
        let p = row_major_from_value(eval.permutation_matrix());
        let input = RowMajorMatrix::from_complex_tensor(&a).expect("row-major input");

        let pa = row_major_matmul(&p, &input);
        let lu = row_major_matmul(&l, &u);
        assert_row_major_close(&pa, &lu, 1e-9);
    }

    #[test]
    fn lu_handles_singular_matrix() {
        let a = Matrix::new(vec![0.0, 0.0, 0.0, 0.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate singular");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());

        assert!(u.data.iter().any(|&v| v.abs() <= 1e-12));

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[test]
    fn lu_vector_option_returns_pivot_vector() {
        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("vector")]).expect("evaluate vector mode");
        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
        let pivot = tensor_from_value(eval.pivot_vector());
        assert_eq!(pivot.shape, vec![2, 1]);
        assert_eq!(pivot.data, vec![2.0, 1.0]);
    }

    #[test]
    fn lu_vector_option_case_insensitive() {
        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("VECTOR")]).expect("evaluate vector option");
        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
    }

    #[test]
    fn lu_matrix_option_returns_permutation_matrix() {
        let a = Matrix::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("matrix")]).expect("evaluate matrix option");
        assert_eq!(eval.pivot_mode(), PivotMode::Matrix);
        let perm_selected = tensor_from_value(eval.permutation());
        let perm_matrix = tensor_from_value(eval.permutation_matrix());
        assert_eq!(perm_selected.shape, perm_matrix.shape);
        assert_tensor_close(&perm_selected, &perm_matrix, 1e-12);
    }

    #[test]
    fn lu_handles_rectangular_matrices() {
        let a = Matrix::new(vec![3.0, 6.0, 1.0, 3.0, 2.0, 4.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate rectangular");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());
        assert_eq!(l.shape, vec![2, 2]);
        assert_eq!(u.shape, vec![2, 3]);
        assert_eq!(p.shape, vec![2, 2]);

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[test]
    fn lu_rejects_unknown_option() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(Value::Tensor(a), &[Value::from("invalid")]) {
            Ok(_) => panic!("expected option parse failure"),
            Err(err) => err,
        };
        assert!(err.contains("unknown option"));
    }

    #[test]
    fn lu_rejects_non_string_option() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(Value::Tensor(a), &[Value::Num(2.0)]) {
            Ok(_) => panic!("expected option parse failure"),
            Err(err) => err,
        };
        assert!(err.contains("unknown option"));
    }

    #[test]
    fn lu_rejects_multiple_options() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(
            Value::Tensor(a),
            &[Value::from("matrix"), Value::from("vector")],
        ) {
            Ok(_) => panic!("expected option arity failure"),
            Err(err) => err,
        };
        assert!(err.contains("too many option arguments"));
    }

    #[test]
    fn lu_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = Matrix::new(vec![10.0, 3.0, 7.0, 2.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("evaluate gpu input");
            let lower_val = eval.lower();
            let upper_val = eval.upper();
            let perm_val = eval.permutation_matrix();
            assert!(matches!(lower_val, Value::GpuTensor(_)));
            assert!(matches!(upper_val, Value::GpuTensor(_)));
            assert!(matches!(perm_val, Value::GpuTensor(_)));
            let l = test_support::gather(lower_val).expect("gather lower");
            let u = test_support::gather(upper_val).expect("gather upper");
            let p = test_support::gather(perm_val).expect("gather permutation");
            let pa = crate::matrix::matrix_mul(&p, &host).expect("P*A");
            let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
            assert_tensor_close(&pa, &lu_product, 1e-9);
        });
    }

    #[test]
    fn lu_gpu_vector_option_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval =
                evaluate(Value::GpuTensor(handle), &[Value::from("vector")]).expect("gpu vector");
            let pivot_val = eval.permutation();
            assert!(matches!(pivot_val, Value::GpuTensor(_)));
            let pivot = test_support::gather(pivot_val).expect("gather pivot");
            assert_eq!(pivot.shape, vec![2, 1]);
            let expected = Matrix::new(vec![2.0, 1.0], vec![2, 1]).unwrap();
            assert_tensor_close(&pivot, &expected, 1e-12);
        });
    }

    #[test]
    fn lu_accepts_scalar_inputs() {
        let eval = evaluate(Value::Num(5.0), &[]).expect("evaluate scalar");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());
        assert_eq!(l.data, vec![1.0]);
        assert_eq!(u.data, vec![5.0]);
        assert_eq!(p.data, vec![1.0]);
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn lu_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let host = Matrix::new(
            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
            vec![3, 3],
        )
        .unwrap();
        let cpu_eval = evaluate(Value::Tensor(host.clone()), &[]).expect("cpu evaluate");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &host.data,
            shape: &host.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");

        let l_cpu = tensor_from_value(cpu_eval.lower());
        let u_cpu = tensor_from_value(cpu_eval.upper());
        let p_cpu = tensor_from_value(cpu_eval.permutation_matrix());
        let lu_cpu = tensor_from_value(cpu_eval.combined());

        let l_gpu = test_support::gather(gpu_eval.lower()).expect("gather L");
        let u_gpu = test_support::gather(gpu_eval.upper()).expect("gather U");
        let p_gpu = test_support::gather(gpu_eval.permutation_matrix()).expect("gather P");
        let lu_gpu = test_support::gather(gpu_eval.combined()).expect("gather LU");

        assert_tensor_close(&l_cpu, &l_gpu, 1e-12);
        assert_tensor_close(&u_cpu, &u_gpu, 1e-12);
        assert_tensor_close(&p_cpu, &p_gpu, 1e-12);
        assert_tensor_close(&lu_cpu, &lu_gpu, 1e-12);

        let pivot_cpu = tensor_from_value(cpu_eval.pivot_vector());
        let pivot_gpu = test_support::gather(gpu_eval.pivot_vector()).expect("gather pivot vector");
        assert_tensor_close(&pivot_cpu, &pivot_gpu, 1e-12);

        let handle_vector = provider.upload(&view).expect("upload vector option");
        let gpu_vector_eval = evaluate(Value::GpuTensor(handle_vector), &[Value::from("vector")])
            .expect("gpu vector evaluate");
        let pivot_vector =
            test_support::gather(gpu_vector_eval.permutation()).expect("gather vector pivot");
        assert_tensor_close(&pivot_cpu, &pivot_vector, 1e-12);
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}
959}