runmat_runtime/builtins/math/linalg/factor/lu.rs

//! MATLAB-compatible `lu` builtin with CPU-backed semantics.
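//!
//! Supported call forms mirror MATLAB:
//! - `Y = lu(A)` returns the combined factors: `U` on and above the diagonal,
//!   the `L` multipliers below it.
//! - `[L, U] = lu(A)` returns the triangular factors separately.
//! - `[L, U, P] = lu(A)` additionally returns the permutation, as a matrix by
//!   default or as a 1-based pivot vector when the `'vector'` option is passed.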

use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::{gpu_helpers, tensor};
use crate::builtins::math::linalg::type_resolvers::matrix_unary_type;
use crate::{build_runtime_error, BuiltinResult, RuntimeError};

use num_complex::Complex64;
use runmat_accelerate_api::{GpuTensorHandle, ProviderLuResult};
use runmat_builtins::{ComplexTensor, Tensor, Value};
use runmat_macros::runtime_builtin;

const BUILTIN_NAME: &str = "lu";

#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::linalg::factor::lu")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "lu",
    op_kind: GpuOpKind::Custom("lu-factor"),
    supported_precisions: &[ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("lu")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prefers the provider `lu` hook; when no provider support is registered, the input is gathered to the host and the CPU implementation is used.",
};

fn lu_error(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message)
        .with_builtin(BUILTIN_NAME)
        .build()
}

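/// Attach the `lu` builtin name to an error bubbling up from a helper so the
/// report is attributed to this builtin.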
fn with_lu_context(mut error: RuntimeError) -> RuntimeError {
    if error.message() == "interaction pending..." {
        return build_runtime_error("interaction pending...")
            .with_builtin(BUILTIN_NAME)
            .build();
    }
    if error.context.builtin.is_none() {
        error.context = error.context.with_builtin(BUILTIN_NAME);
    }
    error
}

#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::linalg::factor::lu")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "lu",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "LU decomposition is not part of expression fusion; calls execute eagerly on the CPU.",
};

#[runtime_builtin(
    name = "lu",
    category = "math/linalg/factor",
    summary = "LU decomposition with partial pivoting.",
    keywords = "lu,factorization,decomposition,permutation",
    accel = "sink",
    sink = true,
    type_resolver(matrix_unary_type),
    builtin_path = "crate::builtins::math::linalg::factor::lu"
)]
async fn lu_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
    let eval = evaluate(value, &rest).await?;
    if let Some(out_count) = crate::output_count::current_output_count() {
        if out_count == 0 {
            return Ok(Value::OutputList(Vec::new()));
        }
        if out_count == 1 {
            return Ok(Value::OutputList(vec![eval.combined()]));
        }
        if out_count == 2 {
            return Ok(Value::OutputList(vec![eval.lower(), eval.upper()]));
        }
        if out_count == 3 {
            return Ok(Value::OutputList(vec![
                eval.lower(),
                eval.upper(),
                eval.permutation(),
            ]));
        }
        return Err(lu_error("lu currently supports at most three outputs"));
    }
    Ok(eval.combined())
}

/// Output form for `lu`, reused by both the builtin wrapper and the VM multi-output path.
#[derive(Clone)]
pub struct LuEval {
    combined: Value,
    lower: Value,
    upper: Value,
    perm_matrix: Value,
    perm_vector: Value,
    pivot_mode: PivotMode,
}

impl LuEval {
    /// Combined LU factor (single-output form).
    pub fn combined(&self) -> Value {
        self.combined.clone()
    }

    /// Lower-triangular factor.
    pub fn lower(&self) -> Value {
        self.lower.clone()
    }

    /// Upper-triangular factor.
    pub fn upper(&self) -> Value {
        self.upper.clone()
    }

    /// Permutation value respecting the selected pivot mode.
    pub fn permutation(&self) -> Value {
        match self.pivot_mode {
            PivotMode::Matrix => self.perm_matrix.clone(),
            PivotMode::Vector => self.perm_vector.clone(),
        }
    }

    /// Permutation matrix (always available, useful for tests).
    pub fn permutation_matrix(&self) -> Value {
        self.perm_matrix.clone()
    }

    /// Pivot vector (always available, useful for tests).
    pub fn pivot_vector(&self) -> Value {
        self.perm_vector.clone()
    }

    /// The pivot mode that was requested.
    pub fn pivot_mode(&self) -> PivotMode {
        self.pivot_mode
    }

    fn from_components(components: LuComponents, pivot_mode: PivotMode) -> BuiltinResult<Self> {
        let combined = matrix_to_value(&components.combined)?;
        let lower = matrix_to_value(&components.lower)?;
        let upper = matrix_to_value(&components.upper)?;
        let perm_matrix = matrix_to_value(&components.permutation)?;
        let perm_vector = pivot_vector_to_value(&components.pivot_vector)?;
        Ok(Self {
            combined,
            lower,
            upper,
            perm_matrix,
            perm_vector,
            pivot_mode,
        })
    }

    fn from_provider(result: ProviderLuResult, pivot_mode: PivotMode) -> Self {
        Self {
            combined: Value::GpuTensor(result.combined),
            lower: Value::GpuTensor(result.lower),
            upper: Value::GpuTensor(result.upper),
            perm_matrix: Value::GpuTensor(result.perm_matrix),
            perm_vector: Value::GpuTensor(result.perm_vector),
            pivot_mode,
        }
    }
}

/// Permutation output mode.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PivotMode {
    Matrix,
    Vector,
}

impl Default for PivotMode {
    fn default() -> Self {
        Self::Matrix
    }
}

/// Evaluate `lu` while preserving all output forms for later extraction.
pub async fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<LuEval> {
    let pivot_mode = parse_pivot_mode(args)?;
    match value {
        Value::GpuTensor(handle) => {
            if let Some(eval) = evaluate_gpu(&handle, pivot_mode).await? {
                return Ok(eval);
            }
            let tensor = gpu_helpers::gather_tensor_async(&handle)
                .await
                .map_err(with_lu_context)?;
            evaluate_host_value(Value::Tensor(tensor), pivot_mode).await
        }
        other => evaluate_host_value(other, pivot_mode).await,
    }
}

async fn evaluate_host_value(value: Value, pivot_mode: PivotMode) -> BuiltinResult<LuEval> {
    let matrix = extract_matrix(value).await?;
    let components = lu_factor(matrix)?;
    LuEval::from_components(components, pivot_mode)
}

async fn evaluate_gpu(
    handle: &GpuTensorHandle,
    pivot_mode: PivotMode,
) -> BuiltinResult<Option<LuEval>> {
    if let Some(provider) = runmat_accelerate_api::provider() {
        if let Ok(result) = provider.lu(handle).await {
            return Ok(Some(LuEval::from_provider(result, pivot_mode)));
        }
    }
    Ok(None)
}

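/// Parse the optional trailing option: `'matrix'` (the default) or
/// `'vector'`, matched case-insensitively after trimming whitespace.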
fn parse_pivot_mode(args: &[Value]) -> BuiltinResult<PivotMode> {
    if args.is_empty() {
        return Ok(PivotMode::Matrix);
    }
    if args.len() > 1 {
        return Err(lu_error("lu: too many option arguments"));
    }
    let Some(option) = tensor::value_to_string(&args[0]) else {
        return Err(lu_error("lu: option must be a string or character vector"));
    };
    match option.trim().to_ascii_lowercase().as_str() {
        "matrix" => Ok(PivotMode::Matrix),
        "vector" => Ok(PivotMode::Vector),
        other => Err(lu_error(format!("lu: unknown option '{other}'"))),
    }
}

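/// Coerce a runtime value into a dense row-major complex matrix, gathering
/// GPU tensors to the host and widening logical and scalar inputs as needed.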
async fn extract_matrix(value: Value) -> BuiltinResult<RowMajorMatrix> {
    match value {
        Value::Tensor(t) => RowMajorMatrix::from_tensor(&t),
        Value::ComplexTensor(ct) => RowMajorMatrix::from_complex_tensor(&ct),
        Value::GpuTensor(handle) => {
            let tensor = gpu_helpers::gather_tensor_async(&handle)
                .await
                .map_err(with_lu_context)?;
            RowMajorMatrix::from_tensor(&tensor)
        }
        Value::LogicalArray(logical) => {
            let tensor = tensor::logical_to_tensor(&logical)
                .map_err(|err| lu_error(format!("lu: {err}")))?;
            RowMajorMatrix::from_tensor(&tensor)
        }
        Value::Num(n) => Ok(RowMajorMatrix::from_scalar(Complex64::new(n, 0.0))),
        Value::Int(i) => Ok(RowMajorMatrix::from_scalar(Complex64::new(i.to_f64(), 0.0))),
        Value::Bool(b) => Ok(RowMajorMatrix::from_scalar(Complex64::new(
            if b { 1.0 } else { 0.0 },
            0.0,
        ))),
        Value::Complex(re, im) => Ok(RowMajorMatrix::from_scalar(Complex64::new(re, im))),
        Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(lu_error(
            "lu: character data is not supported; convert to numeric values first",
        )),
        other => Err(lu_error(format!("lu: unsupported input type {:?}", other))),
    }
}

struct LuComponents {
    combined: RowMajorMatrix,
    lower: RowMajorMatrix,
    upper: RowMajorMatrix,
    permutation: RowMajorMatrix,
    pivot_vector: Vec<f64>,
}

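/// In-place Gaussian elimination with partial pivoting (Doolittle form).
/// After the loop, the strictly lower triangle of `matrix` holds the L
/// multipliers and the diagonal and above hold U; `perm` records which
/// original row ended up in each position.
///
/// Illustrative sketch: for A = [4 3; 6 3], the pivot search swaps the two
/// rows (perm = [1, 0]), the single multiplier is 4/6 = 2/3, and the
/// combined storage becomes [6 3; 2/3 1], i.e. U = [6 3; 0 1].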
fn lu_factor(mut matrix: RowMajorMatrix) -> BuiltinResult<LuComponents> {
    let rows = matrix.rows;
    let cols = matrix.cols;
    let min_dim = rows.min(cols);
    let mut perm: Vec<usize> = (0..rows).collect();

    for k in 0..min_dim {
        // Select pivot row with maximal absolute value in column k.
        let mut pivot_row = k;
        let mut pivot_abs = 0.0;
        for r in k..rows {
            let val = matrix.get(r, k);
            let abs = val.norm();
            if abs > pivot_abs {
                pivot_abs = abs;
                pivot_row = r;
            }
        }

        if pivot_row != k {
            matrix.swap_rows(pivot_row, k);
            perm.swap(pivot_row, k);
        }

        if pivot_abs <= EPS {
            // Entire column is effectively zero; set multipliers to zero and continue.
            for r in (k + 1)..rows {
                matrix.set(r, k, Complex64::new(0.0, 0.0));
            }
            continue;
        }

        let pivot_value = matrix.get(k, k);
        for r in (k + 1)..rows {
            let factor = matrix.get(r, k) / pivot_value;
            matrix.set(r, k, factor);
            for c in (k + 1)..cols {
                let updated = matrix.get(r, c) - factor * matrix.get(k, c);
                matrix.set(r, c, updated);
            }
        }
    }

    let combined = matrix.clone();
    let lower = build_lower(&matrix);
    let upper = build_upper(&matrix);
    let permutation = build_permutation(rows, &perm);
    let pivot_vector: Vec<f64> = perm.iter().map(|idx| (*idx + 1) as f64).collect();

    Ok(LuComponents {
        combined,
        lower,
        upper,
        permutation,
        pivot_vector,
    })
}

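/// Build L from the combined storage: a `rows`-by-`rows` identity with the
/// elimination multipliers copied in below the diagonal.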
fn build_lower(matrix: &RowMajorMatrix) -> RowMajorMatrix {
    let rows = matrix.rows;
    let cols = matrix.cols;
    let min_dim = rows.min(cols);
    let mut lower = RowMajorMatrix::identity(rows);
    for i in 0..rows {
        for j in 0..min_dim {
            if i > j {
                lower.set(i, j, matrix.get(i, j));
            }
        }
    }
    lower
}

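/// Build U from the combined storage by keeping the diagonal and everything
/// above it.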
fn build_upper(matrix: &RowMajorMatrix) -> RowMajorMatrix {
    let rows = matrix.rows;
    let cols = matrix.cols;
    let mut upper = RowMajorMatrix::zeros(rows, cols);
    for i in 0..rows {
        for j in 0..cols {
            if i <= j {
                upper.set(i, j, matrix.get(i, j));
            }
        }
    }
    upper
}

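/// Build the permutation matrix with `P[i, perm[i]] = 1`, so that `P * A`
/// reorders the rows of `A` into pivot order.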
fn build_permutation(rows: usize, perm: &[usize]) -> RowMajorMatrix {
    let mut matrix = RowMajorMatrix::zeros(rows, rows);
    for (i, &col) in perm.iter().enumerate() {
        if col < rows {
            matrix.set(i, col, Complex64::new(1.0, 0.0));
        }
    }
    matrix
}

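/// Magnitudes at or below this tolerance are treated as zero, both for pivot
/// selection and for deciding whether a result is real.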
const EPS: f64 = 1.0e-12;

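/// Convert the row-major working matrix back into a runtime value,
/// transposing to RunMat's column-major layout and returning a real tensor
/// whenever every imaginary part is negligible.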
fn matrix_to_value(matrix: &RowMajorMatrix) -> BuiltinResult<Value> {
    let mut has_imag = false;
    for val in &matrix.data {
        if val.im.abs() > EPS {
            has_imag = true;
            break;
        }
    }
    if has_imag {
        let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
        for col in 0..matrix.cols {
            for row in 0..matrix.rows {
                let idx = row * matrix.cols + col;
                let v = matrix.data[idx];
                data.push((v.re, v.im));
            }
        }
        let tensor = ComplexTensor::new(data, vec![matrix.rows, matrix.cols])
            .map_err(|e| lu_error(format!("lu: {e}")))?;
        Ok(Value::ComplexTensor(tensor))
    } else {
        let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
        for col in 0..matrix.cols {
            for row in 0..matrix.rows {
                let idx = row * matrix.cols + col;
                data.push(matrix.data[idx].re);
            }
        }
        let tensor = Tensor::new(data, vec![matrix.rows, matrix.cols])
            .map_err(|e| lu_error(format!("lu: {e}")))?;
        Ok(Value::Tensor(tensor))
    }
}

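/// Pack the 1-based pivot indices into an n-by-1 column tensor.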
fn pivot_vector_to_value(pivot: &[f64]) -> BuiltinResult<Value> {
    let rows = pivot.len();
    let tensor =
        Tensor::new(pivot.to_vec(), vec![rows, 1]).map_err(|e| lu_error(format!("lu: {e}")))?;
    Ok(Value::Tensor(tensor))
}

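/// Dense 2-D scratch matrix stored in row-major order. RunMat tensors are
/// column-major, so conversions in and out transpose the element order.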
#[derive(Clone)]
struct RowMajorMatrix {
    rows: usize,
    cols: usize,
    data: Vec<Complex64>,
}

impl RowMajorMatrix {
    fn zeros(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            data: vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)],
        }
    }

    fn identity(size: usize) -> Self {
        let mut matrix = Self::zeros(size, size);
        for i in 0..size {
            matrix.set(i, i, Complex64::new(1.0, 0.0));
        }
        matrix
    }

    fn from_scalar(value: Complex64) -> Self {
        Self {
            rows: 1,
            cols: 1,
            data: vec![value],
        }
    }

    fn from_tensor(tensor: &Tensor) -> BuiltinResult<Self> {
        if tensor.shape.len() > 2 {
            return Err(lu_error("lu: input must be 2-D"));
        }
        let rows = tensor.rows();
        let cols = tensor.cols();
        let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
        for col in 0..cols {
            for row in 0..rows {
                let idx_col_major = row + col * rows;
                let idx_row_major = row * cols + col;
                data[idx_row_major] = Complex64::new(tensor.data[idx_col_major], 0.0);
            }
        }
        Ok(Self { rows, cols, data })
    }

    fn from_complex_tensor(tensor: &ComplexTensor) -> BuiltinResult<Self> {
        if tensor.shape.len() > 2 {
            return Err(lu_error("lu: input must be 2-D"));
        }
        let rows = tensor.rows;
        let cols = tensor.cols;
        let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
        for col in 0..cols {
            for row in 0..rows {
                let idx_col_major = row + col * rows;
                let idx_row_major = row * cols + col;
                let (re, im) = tensor.data[idx_col_major];
                data[idx_row_major] = Complex64::new(re, im);
            }
        }
        Ok(Self { rows, cols, data })
    }

    fn get(&self, row: usize, col: usize) -> Complex64 {
        self.data[row * self.cols + col]
    }

    fn set(&mut self, row: usize, col: usize, value: Complex64) {
        self.data[row * self.cols + col] = value;
    }

    fn swap_rows(&mut self, r1: usize, r2: usize) {
        if r1 == r2 {
            return;
        }
        for col in 0..self.cols {
            self.data.swap(r1 * self.cols + col, r2 * self.cols + col);
        }
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor as CMatrix, ResolveContext, Tensor as Matrix, Type};

    fn error_message(err: RuntimeError) -> String {
        err.message().to_string()
    }

    fn tensor_from_value(value: Value) -> Matrix {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected dense tensor, got {other:?}"),
        }
    }

    fn row_major_from_value(value: Value) -> RowMajorMatrix {
        match value {
            Value::Tensor(t) => RowMajorMatrix::from_tensor(&t).expect("row-major tensor"),
            Value::ComplexTensor(ct) => {
                RowMajorMatrix::from_complex_tensor(&ct).expect("row-major complex tensor")
            }
            other => panic!("expected tensor value, got {other:?}"),
        }
    }

    #[test]
    fn lu_type_preserves_matrix_shape() {
        let out = matrix_unary_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    fn row_major_matmul(a: &RowMajorMatrix, b: &RowMajorMatrix) -> RowMajorMatrix {
        assert_eq!(a.cols, b.rows, "incompatible shapes for matmul");
        let mut out = RowMajorMatrix::zeros(a.rows, b.cols);
        for i in 0..a.rows {
            for k in 0..a.cols {
                let aik = a.get(i, k);
                for j in 0..b.cols {
                    let acc = out.get(i, j) + aik * b.get(k, j);
                    out.set(i, j, acc);
                }
            }
        }
        out
    }

    fn assert_tensor_close(a: &Matrix, b: &Matrix, tol: f64) {
        assert_eq!(a.shape, b.shape);
        for (lhs, rhs) in a.data.iter().zip(&b.data) {
            assert!(
                (lhs - rhs).abs() <= tol,
                "mismatch: lhs={lhs}, rhs={rhs}, tol={tol}"
            );
        }
    }

    fn assert_row_major_close(a: &RowMajorMatrix, b: &RowMajorMatrix, tol: f64) {
        assert_eq!(a.rows, b.rows, "row mismatch");
        assert_eq!(a.cols, b.cols, "col mismatch");
        for row in 0..a.rows {
            for col in 0..a.cols {
                let lhs = a.get(row, col);
                let rhs = b.get(row, col);
                let diff = (lhs - rhs).norm();
                assert!(
                    diff <= tol,
                    "mismatch at ({row}, {col}): lhs={lhs:?}, rhs={rhs:?}, diff={diff}, tol={tol}"
                );
            }
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_single_output_produces_combined_matrix() {
        let a = Matrix::new(
            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
            vec![3, 3],
        )
        .unwrap();
        let result = lu_builtin(Value::Tensor(a.clone()), Vec::new()).expect("lu");
        let lu = tensor_from_value(result);
        let eval = evaluate(Value::Tensor(a), &[]).expect("evaluate");
        let expected = tensor_from_value(eval.combined());
        assert_tensor_close(&lu, &expected, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_three_outputs_matches_factorization() {
        let data = vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0];
        let a = Matrix::new(data.clone(), vec![3, 3]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_complex_matrix_factorization() {
        let data = vec![(1.0, 2.0), (3.0, -1.0), (2.0, -1.0), (4.0, 2.0)];
        let a = CMatrix::new(data.clone(), vec![2, 2]).expect("complex tensor");
        let eval = evaluate(Value::ComplexTensor(a.clone()), &[]).expect("evaluate complex");

        let l = row_major_from_value(eval.lower());
        let u = row_major_from_value(eval.upper());
        let p = row_major_from_value(eval.permutation_matrix());
        let input = RowMajorMatrix::from_complex_tensor(&a).expect("row-major input");

        let pa = row_major_matmul(&p, &input);
        let lu = row_major_matmul(&l, &u);
        assert_row_major_close(&pa, &lu, 1e-9);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_handles_singular_matrix() {
        let a = Matrix::new(vec![0.0, 0.0, 0.0, 0.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate singular");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());

        assert!(u.data.iter().any(|&v| v.abs() <= 1e-12));

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_vector_option_returns_pivot_vector() {
        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("vector")]).expect("evaluate vector mode");
        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
        let pivot = tensor_from_value(eval.pivot_vector());
        assert_eq!(pivot.shape, vec![2, 1]);
        assert_eq!(pivot.data, vec![2.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_vector_option_case_insensitive() {
        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("VECTOR")]).expect("evaluate vector option");
        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_matrix_option_returns_permutation_matrix() {
        let a = Matrix::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("matrix")]).expect("evaluate matrix option");
        assert_eq!(eval.pivot_mode(), PivotMode::Matrix);
        let perm_selected = tensor_from_value(eval.permutation());
        let perm_matrix = tensor_from_value(eval.permutation_matrix());
        assert_eq!(perm_selected.shape, perm_matrix.shape);
        assert_tensor_close(&perm_selected, &perm_matrix, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_handles_rectangular_matrices() {
        let a = Matrix::new(vec![3.0, 6.0, 1.0, 3.0, 2.0, 4.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate rectangular");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());
        assert_eq!(l.shape, vec![2, 2]);
        assert_eq!(u.shape, vec![2, 3]);
        assert_eq!(p.shape, vec![2, 2]);

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_rejects_unknown_option() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(Value::Tensor(a), &[Value::from("invalid")]) {
            Ok(_) => panic!("expected option parse failure"),
            Err(err) => error_message(err),
        };
        assert!(err.contains("unknown option"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_rejects_non_string_option() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(Value::Tensor(a), &[Value::Num(2.0)]) {
            Ok(_) => panic!("expected option parse failure"),
            Err(err) => error_message(err),
        };
        assert!(err.contains("unknown option"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_rejects_multiple_options() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(
            Value::Tensor(a),
            &[Value::from("matrix"), Value::from("vector")],
        ) {
            Ok(_) => panic!("expected option arity failure"),
            Err(err) => error_message(err),
        };
        assert!(err.contains("too many option arguments"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = Matrix::new(vec![10.0, 3.0, 7.0, 2.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("evaluate gpu input");
            let lower_val = eval.lower();
            let upper_val = eval.upper();
            let perm_val = eval.permutation_matrix();
            assert!(matches!(lower_val, Value::GpuTensor(_)));
            assert!(matches!(upper_val, Value::GpuTensor(_)));
            assert!(matches!(perm_val, Value::GpuTensor(_)));
            let l = test_support::gather(lower_val).expect("gather lower");
            let u = test_support::gather(upper_val).expect("gather upper");
            let p = test_support::gather(perm_val).expect("gather permutation");
            let pa = crate::matrix::matrix_mul(&p, &host).expect("P*A");
            let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
            assert_tensor_close(&pa, &lu_product, 1e-9);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_gpu_vector_option_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval =
                evaluate(Value::GpuTensor(handle), &[Value::from("vector")]).expect("gpu vector");
            let pivot_val = eval.permutation();
            assert!(matches!(pivot_val, Value::GpuTensor(_)));
            let pivot = test_support::gather(pivot_val).expect("gather pivot");
            assert_eq!(pivot.shape, vec![2, 1]);
            let expected = Matrix::new(vec![2.0, 1.0], vec![2, 1]).unwrap();
            assert_tensor_close(&pivot, &expected, 1e-12);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_accepts_scalar_inputs() {
        let eval = evaluate(Value::Num(5.0), &[]).expect("evaluate scalar");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());
        assert_eq!(l.data, vec![1.0]);
        assert_eq!(u.data, vec![5.0]);
        assert_eq!(p.data, vec![1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn lu_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let host = Matrix::new(
            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
            vec![3, 3],
        )
        .unwrap();
        let cpu_eval = evaluate(Value::Tensor(host.clone()), &[]).expect("cpu evaluate");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &host.data,
            shape: &host.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");

        let l_cpu = tensor_from_value(cpu_eval.lower());
        let u_cpu = tensor_from_value(cpu_eval.upper());
        let p_cpu = tensor_from_value(cpu_eval.permutation_matrix());
        let lu_cpu = tensor_from_value(cpu_eval.combined());

        let l_gpu = test_support::gather(gpu_eval.lower()).expect("gather L");
        let u_gpu = test_support::gather(gpu_eval.upper()).expect("gather U");
        let p_gpu = test_support::gather(gpu_eval.permutation_matrix()).expect("gather P");
        let lu_gpu = test_support::gather(gpu_eval.combined()).expect("gather LU");

        assert_tensor_close(&l_cpu, &l_gpu, 1e-12);
        assert_tensor_close(&u_cpu, &u_gpu, 1e-12);
        assert_tensor_close(&p_cpu, &p_gpu, 1e-12);
        assert_tensor_close(&lu_cpu, &lu_gpu, 1e-12);

        let pivot_cpu = tensor_from_value(cpu_eval.pivot_vector());
        let pivot_gpu = test_support::gather(gpu_eval.pivot_vector()).expect("gather pivot vector");
        assert_tensor_close(&pivot_cpu, &pivot_gpu, 1e-12);

        let handle_vector = provider.upload(&view).expect("upload vector option");
        let gpu_vector_eval = evaluate(Value::GpuTensor(handle_vector), &[Value::from("vector")])
            .expect("gpu vector evaluate");
        let pivot_vector =
            test_support::gather(gpu_vector_eval.permutation()).expect("gather vector pivot");
        assert_tensor_close(&pivot_cpu, &pivot_vector, 1e-12);
    }

    fn lu_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::lu_builtin(value, rest))
    }

    fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<LuEval> {
        block_on(super::evaluate(value, args))
    }
}