Skip to main content

runmat_runtime/builtins/math/linalg/factor/
lu.rs

1//! MATLAB-compatible `lu` builtin with CPU-backed semantics.
2
3use crate::builtins::common::spec::{
4    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
5    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
6};
7use crate::builtins::common::{gpu_helpers, tensor};
8use crate::builtins::math::linalg::type_resolvers::matrix_unary_type;
9use crate::{build_runtime_error, BuiltinResult, RuntimeError};
10
11use num_complex::Complex64;
12use runmat_accelerate_api::{GpuTensorHandle, ProviderLuResult};
13use runmat_builtins::{
14    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
15    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
16    ComplexTensor, Tensor, Value,
17};
18use runmat_macros::runtime_builtin;
19
20const BUILTIN_NAME: &str = "lu";
21
22const LU_OUTPUT_COMBINED: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
23    name: "LU",
24    ty: BuiltinParamType::NumericArray,
25    arity: BuiltinParamArity::Required,
26    default: None,
27    description: "Combined LU factors.",
28}];
29
30const LU_OUTPUT_LU: [BuiltinParamDescriptor; 2] = [
31    BuiltinParamDescriptor {
32        name: "L",
33        ty: BuiltinParamType::NumericArray,
34        arity: BuiltinParamArity::Required,
35        default: None,
36        description: "Lower-triangular factor.",
37    },
38    BuiltinParamDescriptor {
39        name: "U",
40        ty: BuiltinParamType::NumericArray,
41        arity: BuiltinParamArity::Required,
42        default: None,
43        description: "Upper-triangular factor.",
44    },
45];
46
47const LU_OUTPUT_LUP: [BuiltinParamDescriptor; 3] = [
48    BuiltinParamDescriptor {
49        name: "L",
50        ty: BuiltinParamType::NumericArray,
51        arity: BuiltinParamArity::Required,
52        default: None,
53        description: "Lower-triangular factor.",
54    },
55    BuiltinParamDescriptor {
56        name: "U",
57        ty: BuiltinParamType::NumericArray,
58        arity: BuiltinParamArity::Required,
59        default: None,
60        description: "Upper-triangular factor.",
61    },
62    BuiltinParamDescriptor {
63        name: "P",
64        ty: BuiltinParamType::NumericArray,
65        arity: BuiltinParamArity::Required,
66        default: None,
67        description: "Permutation matrix or vector based on pivot mode.",
68    },
69];
70
71const LU_INPUTS_A: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
72    name: "A",
73    ty: BuiltinParamType::NumericArray,
74    arity: BuiltinParamArity::Required,
75    default: None,
76    description: "Input matrix to factorize.",
77}];
78
79const LU_INPUTS_A_MODE: [BuiltinParamDescriptor; 2] = [
80    BuiltinParamDescriptor {
81        name: "A",
82        ty: BuiltinParamType::NumericArray,
83        arity: BuiltinParamArity::Required,
84        default: None,
85        description: "Input matrix to factorize.",
86    },
87    BuiltinParamDescriptor {
88        name: "pivotMode",
89        ty: BuiltinParamType::StringScalar,
90        arity: BuiltinParamArity::Required,
91        default: Some("\"matrix\""),
92        description: "Permutation mode (`\"matrix\"` or `\"vector\"`).",
93    },
94];
95
96const LU_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
97    BuiltinSignatureDescriptor {
98        label: "LU = lu(A)",
99        inputs: &LU_INPUTS_A,
100        outputs: &LU_OUTPUT_COMBINED,
101    },
102    BuiltinSignatureDescriptor {
103        label: "LU = lu(A, pivotMode)",
104        inputs: &LU_INPUTS_A_MODE,
105        outputs: &LU_OUTPUT_COMBINED,
106    },
107    BuiltinSignatureDescriptor {
108        label: "[L, U] = lu(A)",
109        inputs: &LU_INPUTS_A,
110        outputs: &LU_OUTPUT_LU,
111    },
112    BuiltinSignatureDescriptor {
113        label: "[L, U] = lu(A, pivotMode)",
114        inputs: &LU_INPUTS_A_MODE,
115        outputs: &LU_OUTPUT_LU,
116    },
117    BuiltinSignatureDescriptor {
118        label: "[L, U, P] = lu(A)",
119        inputs: &LU_INPUTS_A,
120        outputs: &LU_OUTPUT_LUP,
121    },
122    BuiltinSignatureDescriptor {
123        label: "[L, U, P] = lu(A, pivotMode)",
124        inputs: &LU_INPUTS_A_MODE,
125        outputs: &LU_OUTPUT_LUP,
126    },
127];
128
129const LU_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
130    code: "RM.LU.INVALID_ARGUMENT",
131    identifier: Some("RunMat:lu:InvalidArgument"),
132    when: "Option arguments or requested output count are invalid.",
133    message: "lu currently supports at most three outputs",
134};
135
136const LU_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
137    code: "RM.LU.INVALID_INPUT",
138    identifier: Some("RunMat:lu:InvalidInput"),
139    when: "Input is unsupported for LU factorization.",
140    message: "lu: expected numeric or logical input values",
141};
142
143const LU_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
144    code: "RM.LU.INTERNAL",
145    identifier: Some("RunMat:lu:Internal"),
146    when: "Runtime cannot materialize LU outputs.",
147    message: "lu: internal runtime failure",
148};
149
150const LU_ERRORS: [BuiltinErrorDescriptor; 3] = [
151    LU_ERROR_INVALID_ARGUMENT,
152    LU_ERROR_INVALID_INPUT,
153    LU_ERROR_INTERNAL,
154];
155
156pub const LU_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
157    signatures: &LU_SIGNATURES,
158    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
159    completion_policy: BuiltinCompletionPolicy::Public,
160    errors: &LU_ERRORS,
161};
162
// GPU dispatch metadata for `lu`, registered through `register_gpu_spec`.
// The advertised provider hook is the custom "lu" hook. `evaluate_gpu` in this file
// calls `provider.lu(handle)` and returns `None` (so the handle is gathered and
// factored on the host) when no provider is registered or that call returns an error.
163#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::linalg::factor::lu")]
164pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
165    name: "lu",
    // Custom op kind: LU is not an elementwise or reduction kernel.
166    op_kind: GpuOpKind::Custom("lu-factor"),
    // Only f64 is listed as a supported precision.
167    supported_precisions: &[ScalarType::F64],
168    broadcast: BroadcastSemantics::None,
169    provider_hooks: &[ProviderHook::Custom("lu")],
170    constant_strategy: ConstantStrategy::InlineLiteral,
    // `LuEval::from_provider` wraps each handle returned by the provider as its own
    // `Value::GpuTensor`; NOTE(review): presumably this is what `NewHandle` advertises.
171    residency: ResidencyPolicy::NewHandle,
172    nan_mode: ReductionNaN::Include,
173    two_pass_threshold: None,
174    workgroup_size: None,
175    accepts_nan_mode: false,
176    notes: "Prefers the provider `lu` hook; automatically gathers and falls back to the CPU implementation when no provider support is registered.",
177};
178
179fn lu_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
180    lu_error_with_message(error.message, error)
181}
182
183fn lu_error_with_message(
184    message: impl Into<String>,
185    error: &'static BuiltinErrorDescriptor,
186) -> RuntimeError {
187    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
188    if let Some(identifier) = error.identifier {
189        builder = builder.with_identifier(identifier);
190    }
191    builder.build()
192}
193
194fn lu_invalid_argument(message: impl Into<String>) -> RuntimeError {
195    lu_error_with_message(message, &LU_ERROR_INVALID_ARGUMENT)
196}
197
198fn lu_invalid_input(message: impl Into<String>) -> RuntimeError {
199    lu_error_with_message(message, &LU_ERROR_INVALID_INPUT)
200}
201
202fn lu_internal_error(message: impl Into<String>) -> RuntimeError {
203    lu_error_with_message(message, &LU_ERROR_INTERNAL)
204}
205
206fn with_lu_context(mut error: RuntimeError) -> RuntimeError {
207    if error.message() == "interaction pending..." {
208        return build_runtime_error("interaction pending...")
209            .with_builtin(BUILTIN_NAME)
210            .build();
211    }
212    if error.context.builtin.is_none() {
213        error.context = error.context.with_builtin(BUILTIN_NAME);
214    }
215    error
216}
217
// Fusion-planner metadata for `lu`, registered through `register_fusion_spec`.
// `elementwise` and `reduction` are both `None`, so no fusion template is supplied
// for this builtin.
// NOTE(review): the notes string says calls run "eagerly on the CPU", but GPU-resident
// inputs first try the provider `lu` hook (`evaluate_gpu`); consider rewording.
218#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::linalg::factor::lu")]
219pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
220    name: "lu",
221    shape: ShapeRequirements::Any,
222    constant_strategy: ConstantStrategy::InlineLiteral,
223    elementwise: None,
224    reduction: None,
225    emits_nan: false,
226    notes: "LU decomposition is not part of expression fusion; calls execute eagerly on the CPU.",
227};
228
229#[runtime_builtin(
230    name = "lu",
231    category = "math/linalg/factor",
232    summary = "Compute LU decompositions with partial pivoting.",
233    keywords = "lu,factorization,decomposition,permutation",
234    accel = "sink",
235    sink = true,
236    type_resolver(matrix_unary_type),
237    descriptor(crate::builtins::math::linalg::factor::lu::LU_DESCRIPTOR),
238    builtin_path = "crate::builtins::math::linalg::factor::lu"
239)]
240async fn lu_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
241    let eval = evaluate(value, &rest).await?;
242    if let Some(out_count) = crate::output_count::current_output_count() {
243        if out_count == 0 {
244            return Ok(Value::OutputList(Vec::new()));
245        }
246        if out_count == 1 {
247            return Ok(Value::OutputList(vec![eval.combined()]));
248        }
249        if out_count == 2 {
250            return Ok(Value::OutputList(vec![eval.lower(), eval.upper()]));
251        }
252        if out_count == 3 {
253            return Ok(Value::OutputList(vec![
254                eval.lower(),
255                eval.upper(),
256                eval.permutation(),
257            ]));
258        }
259        return Err(lu_error(&LU_ERROR_INVALID_ARGUMENT));
260    }
261    Ok(eval.combined())
262}
263
264/// Output form for `lu`, reused by both the builtin wrapper and the VM multi-output path.
265#[derive(Clone)]
266pub struct LuEval {
267    combined: Value,
268    lower: Value,
269    upper: Value,
270    perm_matrix: Value,
271    perm_vector: Value,
272    pivot_mode: PivotMode,
273}
274
275impl LuEval {
276    /// Combined LU factor (single-output form).
277    pub fn combined(&self) -> Value {
278        self.combined.clone()
279    }
280
281    /// Lower-triangular factor.
282    pub fn lower(&self) -> Value {
283        self.lower.clone()
284    }
285
286    /// Upper-triangular factor.
287    pub fn upper(&self) -> Value {
288        self.upper.clone()
289    }
290
291    /// Permutation value respecting the selected pivot mode.
292    pub fn permutation(&self) -> Value {
293        match self.pivot_mode {
294            PivotMode::Matrix => self.perm_matrix.clone(),
295            PivotMode::Vector => self.perm_vector.clone(),
296        }
297    }
298
299    /// Permutation matrix (always available, useful for tests).
300    pub fn permutation_matrix(&self) -> Value {
301        self.perm_matrix.clone()
302    }
303
304    /// Pivot vector (always available, useful for tests).
305    pub fn pivot_vector(&self) -> Value {
306        self.perm_vector.clone()
307    }
308
309    /// The pivot mode that was requested.
310    pub fn pivot_mode(&self) -> PivotMode {
311        self.pivot_mode
312    }
313
314    fn from_components(components: LuComponents, pivot_mode: PivotMode) -> BuiltinResult<Self> {
315        let combined = matrix_to_value(&components.combined)?;
316        let lower = matrix_to_value(&components.lower)?;
317        let upper = matrix_to_value(&components.upper)?;
318        let perm_matrix = matrix_to_value(&components.permutation)?;
319        let perm_vector = pivot_vector_to_value(&components.pivot_vector)?;
320        Ok(Self {
321            combined,
322            lower,
323            upper,
324            perm_matrix,
325            perm_vector,
326            pivot_mode,
327        })
328    }
329
330    fn from_provider(result: ProviderLuResult, pivot_mode: PivotMode) -> Self {
331        Self {
332            combined: Value::GpuTensor(result.combined),
333            lower: Value::GpuTensor(result.lower),
334            upper: Value::GpuTensor(result.upper),
335            perm_matrix: Value::GpuTensor(result.perm_matrix),
336            perm_vector: Value::GpuTensor(result.perm_vector),
337            pivot_mode,
338        }
339    }
340}
341
/// Shape of the permutation output requested via the `pivotMode` argument.
///
/// `Matrix` is the default, used when no option is given.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PivotMode {
    /// Return `P` as a permutation matrix.
    #[default]
    Matrix,
    /// Return `P` as a pivot index vector.
    Vector,
}
354
355/// Evaluate `lu` while preserving all output forms for later extraction.
356pub async fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<LuEval> {
357    let pivot_mode = parse_pivot_mode(args)?;
358    match value {
359        Value::GpuTensor(handle) => {
360            if let Some(eval) = evaluate_gpu(&handle, pivot_mode).await? {
361                return Ok(eval);
362            }
363            let tensor = gpu_helpers::gather_tensor_async(&handle)
364                .await
365                .map_err(with_lu_context)?;
366            evaluate_host_value(Value::Tensor(tensor), pivot_mode).await
367        }
368        other => evaluate_host_value(other, pivot_mode).await,
369    }
370}
371
372async fn evaluate_host_value(value: Value, pivot_mode: PivotMode) -> BuiltinResult<LuEval> {
373    let matrix = extract_matrix(value).await?;
374    let components = lu_factor(matrix)?;
375    LuEval::from_components(components, pivot_mode)
376}
377
378async fn evaluate_gpu(
379    handle: &GpuTensorHandle,
380    pivot_mode: PivotMode,
381) -> BuiltinResult<Option<LuEval>> {
382    if let Some(provider) = runmat_accelerate_api::provider() {
383        if let Ok(result) = provider.lu(handle).await {
384            return Ok(Some(LuEval::from_provider(result, pivot_mode)));
385        }
386    }
387    Ok(None)
388}
389
390fn parse_pivot_mode(args: &[Value]) -> BuiltinResult<PivotMode> {
391    if args.is_empty() {
392        return Ok(PivotMode::Matrix);
393    }
394    if args.len() > 1 {
395        return Err(lu_invalid_argument("lu: too many option arguments"));
396    }
397    let Some(option) = tensor::value_to_string(&args[0]) else {
398        return Err(lu_invalid_argument(
399            "lu: option must be a string or character vector",
400        ));
401    };
402    match option.trim().to_ascii_lowercase().as_str() {
403        "matrix" => Ok(PivotMode::Matrix),
404        "vector" => Ok(PivotMode::Vector),
405        other => Err(lu_invalid_argument(format!("lu: unknown option '{other}'"))),
406    }
407}
408
409async fn extract_matrix(value: Value) -> BuiltinResult<RowMajorMatrix> {
410    match value {
411        Value::Tensor(t) => RowMajorMatrix::from_tensor(&t),
412        Value::ComplexTensor(ct) => RowMajorMatrix::from_complex_tensor(&ct),
413        Value::GpuTensor(handle) => {
414            let tensor = gpu_helpers::gather_tensor_async(&handle)
415                .await
416                .map_err(with_lu_context)?;
417            RowMajorMatrix::from_tensor(&tensor)
418        }
419        Value::LogicalArray(logical) => {
420            let tensor = tensor::logical_to_tensor(&logical)
421                .map_err(|err| lu_invalid_input(format!("lu: {err}")))?;
422            RowMajorMatrix::from_tensor(&tensor)
423        }
424        Value::Num(n) => Ok(RowMajorMatrix::from_scalar(Complex64::new(n, 0.0))),
425        Value::Int(i) => Ok(RowMajorMatrix::from_scalar(Complex64::new(i.to_f64(), 0.0))),
426        Value::Bool(b) => Ok(RowMajorMatrix::from_scalar(Complex64::new(
427            if b { 1.0 } else { 0.0 },
428            0.0,
429        ))),
430        Value::Complex(re, im) => Ok(RowMajorMatrix::from_scalar(Complex64::new(re, im))),
431        Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(lu_invalid_input(
432            "lu: character data is not supported; convert to numeric values first",
433        )),
434        other => Err(lu_invalid_input(format!(
435            "lu: unsupported input type {:?}",
436            other
437        ))),
438    }
439}
440
441struct LuComponents {
442    combined: RowMajorMatrix,
443    lower: RowMajorMatrix,
444    upper: RowMajorMatrix,
445    permutation: RowMajorMatrix,
446    pivot_vector: Vec<f64>,
447}
448
449fn lu_factor(mut matrix: RowMajorMatrix) -> BuiltinResult<LuComponents> {
450    let rows = matrix.rows;
451    let cols = matrix.cols;
452    let min_dim = rows.min(cols);
453    let mut perm: Vec<usize> = (0..rows).collect();
454
455    for k in 0..min_dim {
456        // Select pivot row with maximal absolute value in column k.
457        let mut pivot_row = k;
458        let mut pivot_abs = 0.0;
459        for r in k..rows {
460            let val = matrix.get(r, k);
461            let abs = val.norm();
462            if abs > pivot_abs {
463                pivot_abs = abs;
464                pivot_row = r;
465            }
466        }
467
468        if pivot_row != k {
469            matrix.swap_rows(pivot_row, k);
470            perm.swap(pivot_row, k);
471        }
472
473        if pivot_abs <= EPS {
474            // Entire column is effectively zero; set multipliers to zero and continue.
475            for r in (k + 1)..rows {
476                matrix.set(r, k, Complex64::new(0.0, 0.0));
477            }
478            continue;
479        }
480
481        let pivot_value = matrix.get(k, k);
482        for r in (k + 1)..rows {
483            let factor = matrix.get(r, k) / pivot_value;
484            matrix.set(r, k, factor);
485            for c in (k + 1)..cols {
486                let updated = matrix.get(r, c) - factor * matrix.get(k, c);
487                matrix.set(r, c, updated);
488            }
489        }
490    }
491
492    let combined = matrix.clone();
493    let lower = build_lower(&matrix);
494    let upper = build_upper(&matrix);
495    let permutation = build_permutation(rows, &perm);
496    let pivot_vector: Vec<f64> = perm.iter().map(|idx| (*idx + 1) as f64).collect();
497
498    Ok(LuComponents {
499        combined,
500        lower,
501        upper,
502        permutation,
503        pivot_vector,
504    })
505}
506
507fn build_lower(matrix: &RowMajorMatrix) -> RowMajorMatrix {
508    let rows = matrix.rows;
509    let cols = matrix.cols;
510    let min_dim = rows.min(cols);
511    let mut lower = RowMajorMatrix::identity(rows);
512    for i in 0..rows {
513        for j in 0..min_dim {
514            if i > j {
515                lower.set(i, j, matrix.get(i, j));
516            }
517        }
518    }
519    lower
520}
521
522fn build_upper(matrix: &RowMajorMatrix) -> RowMajorMatrix {
523    let rows = matrix.rows;
524    let cols = matrix.cols;
525    let mut upper = RowMajorMatrix::zeros(rows, cols);
526    for i in 0..rows {
527        for j in 0..cols {
528            if i <= j {
529                upper.set(i, j, matrix.get(i, j));
530            }
531        }
532    }
533    upper
534}
535
536fn build_permutation(rows: usize, perm: &[usize]) -> RowMajorMatrix {
537    let mut matrix = RowMajorMatrix::zeros(rows, rows);
538    for (i, &col) in perm.iter().enumerate() {
539        if col < rows {
540            matrix.set(i, col, Complex64::new(1.0, 0.0));
541        }
542    }
543    matrix
544}
545
546const EPS: f64 = 1.0e-12;
547
548fn matrix_to_value(matrix: &RowMajorMatrix) -> BuiltinResult<Value> {
549    let mut has_imag = false;
550    for val in &matrix.data {
551        if val.im.abs() > EPS {
552            has_imag = true;
553            break;
554        }
555    }
556    if has_imag {
557        let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
558        for col in 0..matrix.cols {
559            for row in 0..matrix.rows {
560                let idx = row * matrix.cols + col;
561                let v = matrix.data[idx];
562                data.push((v.re, v.im));
563            }
564        }
565        let tensor = ComplexTensor::new(data, vec![matrix.rows, matrix.cols])
566            .map_err(|e| lu_internal_error(format!("lu: {e}")))?;
567        Ok(Value::ComplexTensor(tensor))
568    } else {
569        let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
570        for col in 0..matrix.cols {
571            for row in 0..matrix.rows {
572                let idx = row * matrix.cols + col;
573                data.push(matrix.data[idx].re);
574            }
575        }
576        let tensor = Tensor::new(data, vec![matrix.rows, matrix.cols])
577            .map_err(|e| lu_internal_error(format!("lu: {e}")))?;
578        Ok(Value::Tensor(tensor))
579    }
580}
581
582fn pivot_vector_to_value(pivot: &[f64]) -> BuiltinResult<Value> {
583    let rows = pivot.len();
584    let tensor = Tensor::new(pivot.to_vec(), vec![rows, 1])
585        .map_err(|e| lu_internal_error(format!("lu: {e}")))?;
586    Ok(Value::Tensor(tensor))
587}
588
589#[derive(Clone)]
590struct RowMajorMatrix {
591    rows: usize,
592    cols: usize,
593    data: Vec<Complex64>,
594}
595
596impl RowMajorMatrix {
597    fn zeros(rows: usize, cols: usize) -> Self {
598        Self {
599            rows,
600            cols,
601            data: vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)],
602        }
603    }
604
605    fn identity(size: usize) -> Self {
606        let mut matrix = Self::zeros(size, size);
607        for i in 0..size {
608            matrix.set(i, i, Complex64::new(1.0, 0.0));
609        }
610        matrix
611    }
612
613    fn from_scalar(value: Complex64) -> Self {
614        Self {
615            rows: 1,
616            cols: 1,
617            data: vec![value],
618        }
619    }
620
621    fn from_tensor(tensor: &Tensor) -> BuiltinResult<Self> {
622        if tensor.shape.len() > 2 {
623            return Err(lu_invalid_input("lu: input must be 2-D"));
624        }
625        let rows = tensor.rows();
626        let cols = tensor.cols();
627        let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
628        for col in 0..cols {
629            for row in 0..rows {
630                let idx_col_major = row + col * rows;
631                let idx_row_major = row * cols + col;
632                data[idx_row_major] = Complex64::new(tensor.data[idx_col_major], 0.0);
633            }
634        }
635        Ok(Self { rows, cols, data })
636    }
637
638    fn from_complex_tensor(tensor: &ComplexTensor) -> BuiltinResult<Self> {
639        if tensor.shape.len() > 2 {
640            return Err(lu_invalid_input("lu: input must be 2-D"));
641        }
642        let rows = tensor.rows;
643        let cols = tensor.cols;
644        let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
645        for col in 0..cols {
646            for row in 0..rows {
647                let idx_col_major = row + col * rows;
648                let idx_row_major = row * cols + col;
649                let (re, im) = tensor.data[idx_col_major];
650                data[idx_row_major] = Complex64::new(re, im);
651            }
652        }
653        Ok(Self { rows, cols, data })
654    }
655
656    fn get(&self, row: usize, col: usize) -> Complex64 {
657        self.data[row * self.cols + col]
658    }
659
660    fn set(&mut self, row: usize, col: usize, value: Complex64) {
661        self.data[row * self.cols + col] = value;
662    }
663
664    fn swap_rows(&mut self, r1: usize, r2: usize) {
665        if r1 == r2 {
666            return;
667        }
668        for col in 0..self.cols {
669            self.data.swap(r1 * self.cols + col, r2 * self.cols + col);
670        }
671    }
672}
673
674#[cfg(test)]
675pub(crate) mod tests {
676    use super::*;
677    use crate::builtins::common::test_support;
678    use futures::executor::block_on;
679    use runmat_builtins::{ComplexTensor as CMatrix, ResolveContext, Tensor as Matrix, Type};
680
681    fn error_message(err: RuntimeError) -> String {
682        err.message().to_string()
683    }
684
685    fn tensor_from_value(value: Value) -> Matrix {
686        match value {
687            Value::Tensor(t) => t,
688            other => panic!("expected dense tensor, got {other:?}"),
689        }
690    }
691
692    fn row_major_from_value(value: Value) -> RowMajorMatrix {
693        match value {
694            Value::Tensor(t) => RowMajorMatrix::from_tensor(&t).expect("row-major tensor"),
695            Value::ComplexTensor(ct) => {
696                RowMajorMatrix::from_complex_tensor(&ct).expect("row-major complex tensor")
697            }
698            other => panic!("expected tensor value, got {other:?}"),
699        }
700    }
701
702    #[test]
703    fn lu_type_preserves_matrix_shape() {
704        let out = matrix_unary_type(
705            &[Type::Tensor {
706                shape: Some(vec![Some(2), Some(3)]),
707            }],
708            &ResolveContext::new(Vec::new()),
709        );
710        assert_eq!(
711            out,
712            Type::Tensor {
713                shape: Some(vec![Some(2), Some(3)])
714            }
715        );
716    }
717
718    #[test]
719    fn lu_descriptor_signatures_cover_core_forms() {
720        let labels: Vec<&str> = LU_DESCRIPTOR
721            .signatures
722            .iter()
723            .map(|signature| signature.label)
724            .collect();
725        assert!(labels.contains(&"LU = lu(A)"));
726        assert!(labels.contains(&"LU = lu(A, pivotMode)"));
727        assert!(labels.contains(&"[L, U] = lu(A)"));
728        assert!(labels.contains(&"[L, U] = lu(A, pivotMode)"));
729        assert!(labels.contains(&"[L, U, P] = lu(A)"));
730        assert!(labels.contains(&"[L, U, P] = lu(A, pivotMode)"));
731    }
732
733    #[test]
734    fn lu_descriptor_errors_have_stable_codes() {
735        let codes: Vec<&str> = LU_DESCRIPTOR.errors.iter().map(|err| err.code).collect();
736        assert!(codes.contains(&"RM.LU.INVALID_ARGUMENT"));
737        assert!(codes.contains(&"RM.LU.INVALID_INPUT"));
738        assert!(codes.contains(&"RM.LU.INTERNAL"));
739    }
740
741    fn row_major_matmul(a: &RowMajorMatrix, b: &RowMajorMatrix) -> RowMajorMatrix {
742        assert_eq!(a.cols, b.rows, "incompatible shapes for matmul");
743        let mut out = RowMajorMatrix::zeros(a.rows, b.cols);
744        for i in 0..a.rows {
745            for k in 0..a.cols {
746                let aik = a.get(i, k);
747                for j in 0..b.cols {
748                    let acc = out.get(i, j) + aik * b.get(k, j);
749                    out.set(i, j, acc);
750                }
751            }
752        }
753        out
754    }
755
756    fn assert_tensor_close(a: &Matrix, b: &Matrix, tol: f64) {
757        assert_eq!(a.shape, b.shape);
758        for (lhs, rhs) in a.data.iter().zip(&b.data) {
759            assert!(
760                (lhs - rhs).abs() <= tol,
761                "mismatch: lhs={lhs}, rhs={rhs}, tol={tol}"
762            );
763        }
764    }
765
766    fn assert_row_major_close(a: &RowMajorMatrix, b: &RowMajorMatrix, tol: f64) {
767        assert_eq!(a.rows, b.rows, "row mismatch");
768        assert_eq!(a.cols, b.cols, "col mismatch");
769        for row in 0..a.rows {
770            for col in 0..a.cols {
771                let lhs = a.get(row, col);
772                let rhs = b.get(row, col);
773                let diff = (lhs - rhs).norm();
774                assert!(
775                    diff <= tol,
776                    "mismatch at ({row}, {col}): lhs={lhs:?}, rhs={rhs:?}, diff={diff}, tol={tol}"
777                );
778            }
779        }
780    }
781
782    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
783    #[test]
784    fn lu_single_output_produces_combined_matrix() {
785        let a = Matrix::new(
786            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
787            vec![3, 3],
788        )
789        .unwrap();
790        let result = lu_builtin(Value::Tensor(a.clone()), Vec::new()).expect("lu");
791        let lu = tensor_from_value(result);
792        let eval = evaluate(Value::Tensor(a), &[]).expect("evaluate");
793        let expected = tensor_from_value(eval.combined());
794        assert_tensor_close(&lu, &expected, 1e-12);
795    }
796
797    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
798    #[test]
799    fn lu_three_outputs_matches_factorization() {
800        let data = vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0];
801        let a = Matrix::new(data.clone(), vec![3, 3]).unwrap();
802        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate");
803        let l = tensor_from_value(eval.lower());
804        let u = tensor_from_value(eval.upper());
805        let p = tensor_from_value(eval.permutation_matrix());
806
807        let pa = crate::builtins::common::matrix::matrix_mul(&p, &a).expect("P*A");
808        let lu_product = crate::builtins::common::matrix::matrix_mul(&l, &u).expect("L*U");
809        assert_tensor_close(&pa, &lu_product, 1e-9);
810    }
811
812    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
813    #[test]
814    fn lu_complex_matrix_factorization() {
815        let data = vec![(1.0, 2.0), (3.0, -1.0), (2.0, -1.0), (4.0, 2.0)];
816        let a = CMatrix::new(data.clone(), vec![2, 2]).expect("complex tensor");
817        let eval = evaluate(Value::ComplexTensor(a.clone()), &[]).expect("evaluate complex");
818
819        let l = row_major_from_value(eval.lower());
820        let u = row_major_from_value(eval.upper());
821        let p = row_major_from_value(eval.permutation_matrix());
822        let input = RowMajorMatrix::from_complex_tensor(&a).expect("row-major input");
823
824        let pa = row_major_matmul(&p, &input);
825        let lu = row_major_matmul(&l, &u);
826        assert_row_major_close(&pa, &lu, 1e-9);
827    }
828
829    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
830    #[test]
831    fn lu_handles_singular_matrix() {
832        let a = Matrix::new(vec![0.0, 0.0, 0.0, 0.0], vec![2, 2]).unwrap();
833        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate singular");
834        let l = tensor_from_value(eval.lower());
835        let u = tensor_from_value(eval.upper());
836        let p = tensor_from_value(eval.permutation_matrix());
837
838        assert!(u.data.iter().any(|&v| v.abs() <= 1e-12));
839
840        let pa = crate::builtins::common::matrix::matrix_mul(&p, &a).expect("P*A");
841        let lu_product = crate::builtins::common::matrix::matrix_mul(&l, &u).expect("L*U");
842        assert_tensor_close(&pa, &lu_product, 1e-9);
843    }
844
845    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
846    #[test]
847    fn lu_vector_option_returns_pivot_vector() {
848        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
849        let eval =
850            evaluate(Value::Tensor(a), &[Value::from("vector")]).expect("evaluate vector mode");
851        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
852        let pivot = tensor_from_value(eval.pivot_vector());
853        assert_eq!(pivot.shape, vec![2, 1]);
854        assert_eq!(pivot.data, vec![2.0, 1.0]);
855    }
856
857    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
858    #[test]
859    fn lu_vector_option_case_insensitive() {
860        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
861        let eval =
862            evaluate(Value::Tensor(a), &[Value::from("VECTOR")]).expect("evaluate vector option");
863        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
864    }
865
866    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
867    #[test]
868    fn lu_matrix_option_returns_permutation_matrix() {
869        let a = Matrix::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
870        let eval =
871            evaluate(Value::Tensor(a), &[Value::from("matrix")]).expect("evaluate matrix option");
872        assert_eq!(eval.pivot_mode(), PivotMode::Matrix);
873        let perm_selected = tensor_from_value(eval.permutation());
874        let perm_matrix = tensor_from_value(eval.permutation_matrix());
875        assert_eq!(perm_selected.shape, perm_matrix.shape);
876        assert_tensor_close(&perm_selected, &perm_matrix, 1e-12);
877    }
878
879    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
880    #[test]
881    fn lu_handles_rectangular_matrices() {
882        let a = Matrix::new(vec![3.0, 6.0, 1.0, 3.0, 2.0, 4.0], vec![2, 3]).unwrap();
883        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate rectangular");
884        let l = tensor_from_value(eval.lower());
885        let u = tensor_from_value(eval.upper());
886        let p = tensor_from_value(eval.permutation_matrix());
887        assert_eq!(l.shape, vec![2, 2]);
888        assert_eq!(u.shape, vec![2, 3]);
889        assert_eq!(p.shape, vec![2, 2]);
890
891        let pa = crate::builtins::common::matrix::matrix_mul(&p, &a).expect("P*A");
892        let lu_product = crate::builtins::common::matrix::matrix_mul(&l, &u).expect("L*U");
893        assert_tensor_close(&pa, &lu_product, 1e-9);
894    }
895
896    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
897    #[test]
898    fn lu_rejects_unknown_option() {
899        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
900        let err = match evaluate(Value::Tensor(a), &[Value::from("invalid")]) {
901            Ok(_) => panic!("expected option parse failure"),
902            Err(err) => {
903                assert_eq!(err.identifier(), LU_ERROR_INVALID_ARGUMENT.identifier);
904                error_message(err)
905            }
906        };
907        assert!(err.contains("unknown option"));
908    }
909
910    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
911    #[test]
912    fn lu_rejects_non_string_option() {
913        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
914        let err = match evaluate(Value::Tensor(a), &[Value::Num(2.0)]) {
915            Ok(_) => panic!("expected option parse failure"),
916            Err(err) => {
917                assert_eq!(err.identifier(), LU_ERROR_INVALID_ARGUMENT.identifier);
918                error_message(err)
919            }
920        };
921        assert!(err.contains("unknown option"));
922    }
923
924    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
925    #[test]
926    fn lu_rejects_multiple_options() {
927        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
928        let err = match evaluate(
929            Value::Tensor(a),
930            &[Value::from("matrix"), Value::from("vector")],
931        ) {
932            Ok(_) => panic!("expected option arity failure"),
933            Err(err) => {
934                assert_eq!(err.identifier(), LU_ERROR_INVALID_ARGUMENT.identifier);
935                error_message(err)
936            }
937        };
938        assert!(err.contains("too many option arguments"));
939    }
940
941    #[test]
942    fn lu_invalid_input_identifier_is_stable() {
943        let tensor = Matrix::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]).expect("tensor");
944        let err = match evaluate(Value::Tensor(tensor), &[]) {
945            Ok(_) => panic!("expected 2-D input failure"),
946            Err(err) => err,
947        };
948        assert_eq!(err.identifier(), LU_ERROR_INVALID_INPUT.identifier);
949    }
950
951    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
952    #[test]
953    fn lu_gpu_provider_roundtrip() {
954        test_support::with_test_provider(|provider| {
955            let host = Matrix::new(vec![10.0, 3.0, 7.0, 2.0], vec![2, 2]).unwrap();
956            let view = runmat_accelerate_api::HostTensorView {
957                data: &host.data,
958                shape: &host.shape,
959            };
960            let handle = provider.upload(&view).expect("upload");
961            let eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("evaluate gpu input");
962            let lower_val = eval.lower();
963            let upper_val = eval.upper();
964            let perm_val = eval.permutation_matrix();
965            assert!(matches!(lower_val, Value::GpuTensor(_)));
966            assert!(matches!(upper_val, Value::GpuTensor(_)));
967            assert!(matches!(perm_val, Value::GpuTensor(_)));
968            let l = test_support::gather(lower_val).expect("gather lower");
969            let u = test_support::gather(upper_val).expect("gather upper");
970            let p = test_support::gather(perm_val).expect("gather permutation");
971            let pa = crate::builtins::common::matrix::matrix_mul(&p, &host).expect("P*A");
972            let lu_product = crate::builtins::common::matrix::matrix_mul(&l, &u).expect("L*U");
973            assert_tensor_close(&pa, &lu_product, 1e-9);
974        });
975    }
976
977    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
978    #[test]
979    fn lu_gpu_vector_option_roundtrip() {
980        test_support::with_test_provider(|provider| {
981            let host = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
982            let view = runmat_accelerate_api::HostTensorView {
983                data: &host.data,
984                shape: &host.shape,
985            };
986            let handle = provider.upload(&view).expect("upload");
987            let eval =
988                evaluate(Value::GpuTensor(handle), &[Value::from("vector")]).expect("gpu vector");
989            let pivot_val = eval.permutation();
990            assert!(matches!(pivot_val, Value::GpuTensor(_)));
991            let pivot = test_support::gather(pivot_val).expect("gather pivot");
992            assert_eq!(pivot.shape, vec![2, 1]);
993            let expected = Matrix::new(vec![2.0, 1.0], vec![2, 1]).unwrap();
994            assert_tensor_close(&pivot, &expected, 1e-12);
995        });
996    }
997
998    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
999    #[test]
1000    fn lu_accepts_scalar_inputs() {
1001        let eval = evaluate(Value::Num(5.0), &[]).expect("evaluate scalar");
1002        let l = tensor_from_value(eval.lower());
1003        let u = tensor_from_value(eval.upper());
1004        let p = tensor_from_value(eval.permutation_matrix());
1005        assert_eq!(l.data, vec![1.0]);
1006        assert_eq!(u.data, vec![5.0]);
1007        assert_eq!(p.data, vec![1.0]);
1008    }
1009
1010    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1011    #[test]
1012    #[cfg(feature = "wgpu")]
1013    fn lu_wgpu_matches_cpu() {
1014        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1015            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1016        );
1017        let host = Matrix::new(
1018            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
1019            vec![3, 3],
1020        )
1021        .unwrap();
1022        let cpu_eval = evaluate(Value::Tensor(host.clone()), &[]).expect("cpu evaluate");
1023        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1024        let view = runmat_accelerate_api::HostTensorView {
1025            data: &host.data,
1026            shape: &host.shape,
1027        };
1028        let handle = provider.upload(&view).expect("upload");
1029        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");
1030
1031        let l_cpu = tensor_from_value(cpu_eval.lower());
1032        let u_cpu = tensor_from_value(cpu_eval.upper());
1033        let p_cpu = tensor_from_value(cpu_eval.permutation_matrix());
1034        let lu_cpu = tensor_from_value(cpu_eval.combined());
1035
1036        let l_gpu = test_support::gather(gpu_eval.lower()).expect("gather L");
1037        let u_gpu = test_support::gather(gpu_eval.upper()).expect("gather U");
1038        let p_gpu = test_support::gather(gpu_eval.permutation_matrix()).expect("gather P");
1039        let lu_gpu = test_support::gather(gpu_eval.combined()).expect("gather LU");
1040
1041        assert_tensor_close(&l_cpu, &l_gpu, 1e-12);
1042        assert_tensor_close(&u_cpu, &u_gpu, 1e-12);
1043        assert_tensor_close(&p_cpu, &p_gpu, 1e-12);
1044        assert_tensor_close(&lu_cpu, &lu_gpu, 1e-12);
1045
1046        let pivot_cpu = tensor_from_value(cpu_eval.pivot_vector());
1047        let pivot_gpu = test_support::gather(gpu_eval.pivot_vector()).expect("gather pivot vector");
1048        assert_tensor_close(&pivot_cpu, &pivot_gpu, 1e-12);
1049
1050        let handle_vector = provider.upload(&view).expect("upload vector option");
1051        let gpu_vector_eval = evaluate(Value::GpuTensor(handle_vector), &[Value::from("vector")])
1052            .expect("gpu vector evaluate");
1053        let pivot_vector =
1054            test_support::gather(gpu_vector_eval.permutation()).expect("gather vector pivot");
1055        assert_tensor_close(&pivot_cpu, &pivot_vector, 1e-12);
1056    }
1057
1058    fn lu_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
1059        block_on(super::lu_builtin(value, rest))
1060    }
1061
1062    fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<LuEval> {
1063        block_on(super::evaluate(value, args))
1064    }
1065}