1use crate::builtins::common::spec::{
4 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
5 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
6};
7use crate::builtins::common::{gpu_helpers, tensor};
8use crate::builtins::math::linalg::type_resolvers::matrix_unary_type;
9use crate::{build_runtime_error, BuiltinResult, RuntimeError};
10
11use num_complex::Complex64;
12use runmat_accelerate_api::{GpuTensorHandle, ProviderLuResult};
13use runmat_builtins::{
14 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
15 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
16 ComplexTensor, Tensor, Value,
17};
18use runmat_macros::runtime_builtin;
19
/// Builtin name attached to every error raised (or re-contextualized) by `lu`.
const BUILTIN_NAME: &str = "lu";
21
22const LU_OUTPUT_COMBINED: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
23 name: "LU",
24 ty: BuiltinParamType::NumericArray,
25 arity: BuiltinParamArity::Required,
26 default: None,
27 description: "Combined LU factors.",
28}];
29
30const LU_OUTPUT_LU: [BuiltinParamDescriptor; 2] = [
31 BuiltinParamDescriptor {
32 name: "L",
33 ty: BuiltinParamType::NumericArray,
34 arity: BuiltinParamArity::Required,
35 default: None,
36 description: "Lower-triangular factor.",
37 },
38 BuiltinParamDescriptor {
39 name: "U",
40 ty: BuiltinParamType::NumericArray,
41 arity: BuiltinParamArity::Required,
42 default: None,
43 description: "Upper-triangular factor.",
44 },
45];
46
47const LU_OUTPUT_LUP: [BuiltinParamDescriptor; 3] = [
48 BuiltinParamDescriptor {
49 name: "L",
50 ty: BuiltinParamType::NumericArray,
51 arity: BuiltinParamArity::Required,
52 default: None,
53 description: "Lower-triangular factor.",
54 },
55 BuiltinParamDescriptor {
56 name: "U",
57 ty: BuiltinParamType::NumericArray,
58 arity: BuiltinParamArity::Required,
59 default: None,
60 description: "Upper-triangular factor.",
61 },
62 BuiltinParamDescriptor {
63 name: "P",
64 ty: BuiltinParamType::NumericArray,
65 arity: BuiltinParamArity::Required,
66 default: None,
67 description: "Permutation matrix or vector based on pivot mode.",
68 },
69];
70
71const LU_INPUTS_A: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
72 name: "A",
73 ty: BuiltinParamType::NumericArray,
74 arity: BuiltinParamArity::Required,
75 default: None,
76 description: "Input matrix to factorize.",
77}];
78
79const LU_INPUTS_A_MODE: [BuiltinParamDescriptor; 2] = [
80 BuiltinParamDescriptor {
81 name: "A",
82 ty: BuiltinParamType::NumericArray,
83 arity: BuiltinParamArity::Required,
84 default: None,
85 description: "Input matrix to factorize.",
86 },
87 BuiltinParamDescriptor {
88 name: "pivotMode",
89 ty: BuiltinParamType::StringScalar,
90 arity: BuiltinParamArity::Required,
91 default: Some("\"matrix\""),
92 description: "Permutation mode (`\"matrix\"` or `\"vector\"`).",
93 },
94];
95
/// Every call form advertised for `lu`: one, two, or three outputs, each with
/// and without the trailing `pivotMode` option.
const LU_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
    BuiltinSignatureDescriptor {
        label: "LU = lu(A)",
        inputs: &LU_INPUTS_A,
        outputs: &LU_OUTPUT_COMBINED,
    },
    BuiltinSignatureDescriptor {
        label: "LU = lu(A, pivotMode)",
        inputs: &LU_INPUTS_A_MODE,
        outputs: &LU_OUTPUT_COMBINED,
    },
    BuiltinSignatureDescriptor {
        label: "[L, U] = lu(A)",
        inputs: &LU_INPUTS_A,
        outputs: &LU_OUTPUT_LU,
    },
    BuiltinSignatureDescriptor {
        label: "[L, U] = lu(A, pivotMode)",
        inputs: &LU_INPUTS_A_MODE,
        outputs: &LU_OUTPUT_LU,
    },
    BuiltinSignatureDescriptor {
        label: "[L, U, P] = lu(A)",
        inputs: &LU_INPUTS_A,
        outputs: &LU_OUTPUT_LUP,
    },
    BuiltinSignatureDescriptor {
        label: "[L, U, P] = lu(A, pivotMode)",
        inputs: &LU_INPUTS_A_MODE,
        outputs: &LU_OUTPUT_LUP,
    },
];
128
/// Raised for malformed option arguments and for unsupported output counts.
/// Its canonical `message` is the one reported when more than three outputs
/// are requested.
const LU_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LU.INVALID_ARGUMENT",
    identifier: Some("RunMat:lu:InvalidArgument"),
    when: "Option arguments or requested output count are invalid.",
    message: "lu currently supports at most three outputs",
};

/// Raised when the input value cannot be factorized (for example, character
/// data or an array that is not 2-D).
const LU_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LU.INVALID_INPUT",
    identifier: Some("RunMat:lu:InvalidInput"),
    when: "Input is unsupported for LU factorization.",
    message: "lu: expected numeric or logical input values",
};

/// Raised when factorization results cannot be packaged into runtime tensors.
const LU_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LU.INTERNAL",
    identifier: Some("RunMat:lu:Internal"),
    when: "Runtime cannot materialize LU outputs.",
    message: "lu: internal runtime failure",
};

/// Complete error catalogue advertised by the `lu` descriptor.
const LU_ERRORS: [BuiltinErrorDescriptor; 3] = [
    LU_ERROR_INVALID_ARGUMENT,
    LU_ERROR_INVALID_INPUT,
    LU_ERROR_INTERNAL,
];
155
/// Public descriptor for `lu`: its call signatures, its error catalogue, and
/// the fact that the produced outputs depend on how many the caller requests.
pub const LU_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &LU_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &LU_ERRORS,
};
162
// Acceleration metadata for `lu`. Factorization is delegated to the provider's
// custom `lu` hook (f64 only), which returns new device handles. When no
// provider is registered, or the hook fails, `evaluate` gathers the input and
// runs the host implementation instead.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::linalg::factor::lu")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "lu",
    op_kind: GpuOpKind::Custom("lu-factor"),
    supported_precisions: &[ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("lu")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prefers the provider `lu` hook; automatically gathers and falls back to the CPU implementation when no provider support is registered.",
};
178
179fn lu_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
180 lu_error_with_message(error.message, error)
181}
182
183fn lu_error_with_message(
184 message: impl Into<String>,
185 error: &'static BuiltinErrorDescriptor,
186) -> RuntimeError {
187 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
188 if let Some(identifier) = error.identifier {
189 builder = builder.with_identifier(identifier);
190 }
191 builder.build()
192}
193
194fn lu_invalid_argument(message: impl Into<String>) -> RuntimeError {
195 lu_error_with_message(message, &LU_ERROR_INVALID_ARGUMENT)
196}
197
198fn lu_invalid_input(message: impl Into<String>) -> RuntimeError {
199 lu_error_with_message(message, &LU_ERROR_INVALID_INPUT)
200}
201
202fn lu_internal_error(message: impl Into<String>) -> RuntimeError {
203 lu_error_with_message(message, &LU_ERROR_INTERNAL)
204}
205
206fn with_lu_context(mut error: RuntimeError) -> RuntimeError {
207 if error.message() == "interaction pending..." {
208 return build_runtime_error("interaction pending...")
209 .with_builtin(BUILTIN_NAME)
210 .build();
211 }
212 if error.context.builtin.is_none() {
213 error.context = error.context.with_builtin(BUILTIN_NAME);
214 }
215 error
216}
217
// Fusion metadata for `lu`: the builtin exposes no elementwise or reduction
// form, so it never participates in expression fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::linalg::factor::lu")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "lu",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "LU decomposition is not part of expression fusion; calls execute eagerly on the CPU.",
};
228
229#[runtime_builtin(
230 name = "lu",
231 category = "math/linalg/factor",
232 summary = "Compute LU decompositions with partial pivoting.",
233 keywords = "lu,factorization,decomposition,permutation",
234 accel = "sink",
235 sink = true,
236 type_resolver(matrix_unary_type),
237 descriptor(crate::builtins::math::linalg::factor::lu::LU_DESCRIPTOR),
238 builtin_path = "crate::builtins::math::linalg::factor::lu"
239)]
240async fn lu_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
241 let eval = evaluate(value, &rest).await?;
242 if let Some(out_count) = crate::output_count::current_output_count() {
243 if out_count == 0 {
244 return Ok(Value::OutputList(Vec::new()));
245 }
246 if out_count == 1 {
247 return Ok(Value::OutputList(vec![eval.combined()]));
248 }
249 if out_count == 2 {
250 return Ok(Value::OutputList(vec![eval.lower(), eval.upper()]));
251 }
252 if out_count == 3 {
253 return Ok(Value::OutputList(vec![
254 eval.lower(),
255 eval.upper(),
256 eval.permutation(),
257 ]));
258 }
259 return Err(lu_error(&LU_ERROR_INVALID_ARGUMENT));
260 }
261 Ok(eval.combined())
262}
263
264#[derive(Clone)]
266pub struct LuEval {
267 combined: Value,
268 lower: Value,
269 upper: Value,
270 perm_matrix: Value,
271 perm_vector: Value,
272 pivot_mode: PivotMode,
273}
274
275impl LuEval {
276 pub fn combined(&self) -> Value {
278 self.combined.clone()
279 }
280
281 pub fn lower(&self) -> Value {
283 self.lower.clone()
284 }
285
286 pub fn upper(&self) -> Value {
288 self.upper.clone()
289 }
290
291 pub fn permutation(&self) -> Value {
293 match self.pivot_mode {
294 PivotMode::Matrix => self.perm_matrix.clone(),
295 PivotMode::Vector => self.perm_vector.clone(),
296 }
297 }
298
299 pub fn permutation_matrix(&self) -> Value {
301 self.perm_matrix.clone()
302 }
303
304 pub fn pivot_vector(&self) -> Value {
306 self.perm_vector.clone()
307 }
308
309 pub fn pivot_mode(&self) -> PivotMode {
311 self.pivot_mode
312 }
313
314 fn from_components(components: LuComponents, pivot_mode: PivotMode) -> BuiltinResult<Self> {
315 let combined = matrix_to_value(&components.combined)?;
316 let lower = matrix_to_value(&components.lower)?;
317 let upper = matrix_to_value(&components.upper)?;
318 let perm_matrix = matrix_to_value(&components.permutation)?;
319 let perm_vector = pivot_vector_to_value(&components.pivot_vector)?;
320 Ok(Self {
321 combined,
322 lower,
323 upper,
324 perm_matrix,
325 perm_vector,
326 pivot_mode,
327 })
328 }
329
330 fn from_provider(result: ProviderLuResult, pivot_mode: PivotMode) -> Self {
331 Self {
332 combined: Value::GpuTensor(result.combined),
333 lower: Value::GpuTensor(result.lower),
334 upper: Value::GpuTensor(result.upper),
335 perm_matrix: Value::GpuTensor(result.perm_matrix),
336 perm_vector: Value::GpuTensor(result.perm_vector),
337 pivot_mode,
338 }
339 }
340}
341
/// Representation of the permutation output `P` in `[L, U, P] = lu(A, pivotMode)`.
///
/// `Matrix` is the default and is used when no option is supplied. It yields a
/// permutation matrix with `P * A = L * U`. `Vector` yields 1-based row
/// indices `p` with `A(p, :) = L * U`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PivotMode {
    /// Return `P` as a permutation matrix.
    #[default]
    Matrix,
    /// Return `P` as a vector of 1-based pivot row indices.
    Vector,
}
354
355pub async fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<LuEval> {
357 let pivot_mode = parse_pivot_mode(args)?;
358 match value {
359 Value::GpuTensor(handle) => {
360 if let Some(eval) = evaluate_gpu(&handle, pivot_mode).await? {
361 return Ok(eval);
362 }
363 let tensor = gpu_helpers::gather_tensor_async(&handle)
364 .await
365 .map_err(with_lu_context)?;
366 evaluate_host_value(Value::Tensor(tensor), pivot_mode).await
367 }
368 other => evaluate_host_value(other, pivot_mode).await,
369 }
370}
371
372async fn evaluate_host_value(value: Value, pivot_mode: PivotMode) -> BuiltinResult<LuEval> {
373 let matrix = extract_matrix(value).await?;
374 let components = lu_factor(matrix)?;
375 LuEval::from_components(components, pivot_mode)
376}
377
378async fn evaluate_gpu(
379 handle: &GpuTensorHandle,
380 pivot_mode: PivotMode,
381) -> BuiltinResult<Option<LuEval>> {
382 if let Some(provider) = runmat_accelerate_api::provider() {
383 if let Ok(result) = provider.lu(handle).await {
384 return Ok(Some(LuEval::from_provider(result, pivot_mode)));
385 }
386 }
387 Ok(None)
388}
389
390fn parse_pivot_mode(args: &[Value]) -> BuiltinResult<PivotMode> {
391 if args.is_empty() {
392 return Ok(PivotMode::Matrix);
393 }
394 if args.len() > 1 {
395 return Err(lu_invalid_argument("lu: too many option arguments"));
396 }
397 let Some(option) = tensor::value_to_string(&args[0]) else {
398 return Err(lu_invalid_argument(
399 "lu: option must be a string or character vector",
400 ));
401 };
402 match option.trim().to_ascii_lowercase().as_str() {
403 "matrix" => Ok(PivotMode::Matrix),
404 "vector" => Ok(PivotMode::Vector),
405 other => Err(lu_invalid_argument(format!("lu: unknown option '{other}'"))),
406 }
407}
408
409async fn extract_matrix(value: Value) -> BuiltinResult<RowMajorMatrix> {
410 match value {
411 Value::Tensor(t) => RowMajorMatrix::from_tensor(&t),
412 Value::ComplexTensor(ct) => RowMajorMatrix::from_complex_tensor(&ct),
413 Value::GpuTensor(handle) => {
414 let tensor = gpu_helpers::gather_tensor_async(&handle)
415 .await
416 .map_err(with_lu_context)?;
417 RowMajorMatrix::from_tensor(&tensor)
418 }
419 Value::LogicalArray(logical) => {
420 let tensor = tensor::logical_to_tensor(&logical)
421 .map_err(|err| lu_invalid_input(format!("lu: {err}")))?;
422 RowMajorMatrix::from_tensor(&tensor)
423 }
424 Value::Num(n) => Ok(RowMajorMatrix::from_scalar(Complex64::new(n, 0.0))),
425 Value::Int(i) => Ok(RowMajorMatrix::from_scalar(Complex64::new(i.to_f64(), 0.0))),
426 Value::Bool(b) => Ok(RowMajorMatrix::from_scalar(Complex64::new(
427 if b { 1.0 } else { 0.0 },
428 0.0,
429 ))),
430 Value::Complex(re, im) => Ok(RowMajorMatrix::from_scalar(Complex64::new(re, im))),
431 Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(lu_invalid_input(
432 "lu: character data is not supported; convert to numeric values first",
433 )),
434 other => Err(lu_invalid_input(format!(
435 "lu: unsupported input type {:?}",
436 other
437 ))),
438 }
439}
440
/// Host-side results of `lu_factor`, still in row-major layout.
struct LuComponents {
    /// Packed factors: `L` multipliers strictly below the diagonal, `U` on and above it.
    combined: RowMajorMatrix,
    /// Unit lower-triangular factor (`rows x rows`).
    lower: RowMajorMatrix,
    /// Upper-triangular (trapezoidal) factor with the input's shape.
    upper: RowMajorMatrix,
    /// Row permutation matrix `P` with `P * A = L * U`.
    permutation: RowMajorMatrix,
    /// 1-based pivot row indices `p` with `A(p, :) = L * U`.
    pivot_vector: Vec<f64>,
}
448
449fn lu_factor(mut matrix: RowMajorMatrix) -> BuiltinResult<LuComponents> {
450 let rows = matrix.rows;
451 let cols = matrix.cols;
452 let min_dim = rows.min(cols);
453 let mut perm: Vec<usize> = (0..rows).collect();
454
455 for k in 0..min_dim {
456 let mut pivot_row = k;
458 let mut pivot_abs = 0.0;
459 for r in k..rows {
460 let val = matrix.get(r, k);
461 let abs = val.norm();
462 if abs > pivot_abs {
463 pivot_abs = abs;
464 pivot_row = r;
465 }
466 }
467
468 if pivot_row != k {
469 matrix.swap_rows(pivot_row, k);
470 perm.swap(pivot_row, k);
471 }
472
473 if pivot_abs <= EPS {
474 for r in (k + 1)..rows {
476 matrix.set(r, k, Complex64::new(0.0, 0.0));
477 }
478 continue;
479 }
480
481 let pivot_value = matrix.get(k, k);
482 for r in (k + 1)..rows {
483 let factor = matrix.get(r, k) / pivot_value;
484 matrix.set(r, k, factor);
485 for c in (k + 1)..cols {
486 let updated = matrix.get(r, c) - factor * matrix.get(k, c);
487 matrix.set(r, c, updated);
488 }
489 }
490 }
491
492 let combined = matrix.clone();
493 let lower = build_lower(&matrix);
494 let upper = build_upper(&matrix);
495 let permutation = build_permutation(rows, &perm);
496 let pivot_vector: Vec<f64> = perm.iter().map(|idx| (*idx + 1) as f64).collect();
497
498 Ok(LuComponents {
499 combined,
500 lower,
501 upper,
502 permutation,
503 pivot_vector,
504 })
505}
506
507fn build_lower(matrix: &RowMajorMatrix) -> RowMajorMatrix {
508 let rows = matrix.rows;
509 let cols = matrix.cols;
510 let min_dim = rows.min(cols);
511 let mut lower = RowMajorMatrix::identity(rows);
512 for i in 0..rows {
513 for j in 0..min_dim {
514 if i > j {
515 lower.set(i, j, matrix.get(i, j));
516 }
517 }
518 }
519 lower
520}
521
522fn build_upper(matrix: &RowMajorMatrix) -> RowMajorMatrix {
523 let rows = matrix.rows;
524 let cols = matrix.cols;
525 let mut upper = RowMajorMatrix::zeros(rows, cols);
526 for i in 0..rows {
527 for j in 0..cols {
528 if i <= j {
529 upper.set(i, j, matrix.get(i, j));
530 }
531 }
532 }
533 upper
534}
535
536fn build_permutation(rows: usize, perm: &[usize]) -> RowMajorMatrix {
537 let mut matrix = RowMajorMatrix::zeros(rows, rows);
538 for (i, &col) in perm.iter().enumerate() {
539 if col < rows {
540 matrix.set(i, col, Complex64::new(1.0, 0.0));
541 }
542 }
543 matrix
544}
545
/// Absolute magnitude tolerance used by the host-side LU helpers, e.g. when
/// deciding whether results carry a meaningful imaginary part.
const EPS: f64 = 1e-12;
547
548fn matrix_to_value(matrix: &RowMajorMatrix) -> BuiltinResult<Value> {
549 let mut has_imag = false;
550 for val in &matrix.data {
551 if val.im.abs() > EPS {
552 has_imag = true;
553 break;
554 }
555 }
556 if has_imag {
557 let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
558 for col in 0..matrix.cols {
559 for row in 0..matrix.rows {
560 let idx = row * matrix.cols + col;
561 let v = matrix.data[idx];
562 data.push((v.re, v.im));
563 }
564 }
565 let tensor = ComplexTensor::new(data, vec![matrix.rows, matrix.cols])
566 .map_err(|e| lu_internal_error(format!("lu: {e}")))?;
567 Ok(Value::ComplexTensor(tensor))
568 } else {
569 let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
570 for col in 0..matrix.cols {
571 for row in 0..matrix.rows {
572 let idx = row * matrix.cols + col;
573 data.push(matrix.data[idx].re);
574 }
575 }
576 let tensor = Tensor::new(data, vec![matrix.rows, matrix.cols])
577 .map_err(|e| lu_internal_error(format!("lu: {e}")))?;
578 Ok(Value::Tensor(tensor))
579 }
580}
581
582fn pivot_vector_to_value(pivot: &[f64]) -> BuiltinResult<Value> {
583 let rows = pivot.len();
584 let tensor = Tensor::new(pivot.to_vec(), vec![rows, 1])
585 .map_err(|e| lu_internal_error(format!("lu: {e}")))?;
586 Ok(Value::Tensor(tensor))
587}
588
589#[derive(Clone)]
590struct RowMajorMatrix {
591 rows: usize,
592 cols: usize,
593 data: Vec<Complex64>,
594}
595
596impl RowMajorMatrix {
597 fn zeros(rows: usize, cols: usize) -> Self {
598 Self {
599 rows,
600 cols,
601 data: vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)],
602 }
603 }
604
605 fn identity(size: usize) -> Self {
606 let mut matrix = Self::zeros(size, size);
607 for i in 0..size {
608 matrix.set(i, i, Complex64::new(1.0, 0.0));
609 }
610 matrix
611 }
612
613 fn from_scalar(value: Complex64) -> Self {
614 Self {
615 rows: 1,
616 cols: 1,
617 data: vec![value],
618 }
619 }
620
621 fn from_tensor(tensor: &Tensor) -> BuiltinResult<Self> {
622 if tensor.shape.len() > 2 {
623 return Err(lu_invalid_input("lu: input must be 2-D"));
624 }
625 let rows = tensor.rows();
626 let cols = tensor.cols();
627 let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
628 for col in 0..cols {
629 for row in 0..rows {
630 let idx_col_major = row + col * rows;
631 let idx_row_major = row * cols + col;
632 data[idx_row_major] = Complex64::new(tensor.data[idx_col_major], 0.0);
633 }
634 }
635 Ok(Self { rows, cols, data })
636 }
637
638 fn from_complex_tensor(tensor: &ComplexTensor) -> BuiltinResult<Self> {
639 if tensor.shape.len() > 2 {
640 return Err(lu_invalid_input("lu: input must be 2-D"));
641 }
642 let rows = tensor.rows;
643 let cols = tensor.cols;
644 let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
645 for col in 0..cols {
646 for row in 0..rows {
647 let idx_col_major = row + col * rows;
648 let idx_row_major = row * cols + col;
649 let (re, im) = tensor.data[idx_col_major];
650 data[idx_row_major] = Complex64::new(re, im);
651 }
652 }
653 Ok(Self { rows, cols, data })
654 }
655
656 fn get(&self, row: usize, col: usize) -> Complex64 {
657 self.data[row * self.cols + col]
658 }
659
660 fn set(&mut self, row: usize, col: usize, value: Complex64) {
661 self.data[row * self.cols + col] = value;
662 }
663
664 fn swap_rows(&mut self, r1: usize, r2: usize) {
665 if r1 == r2 {
666 return;
667 }
668 for col in 0..self.cols {
669 self.data.swap(r1 * self.cols + col, r2 * self.cols + col);
670 }
671 }
672}
673
// Unit tests for the `lu` builtin: type resolution, descriptor metadata, the
// factorization identity `P * A = L * U` for real, complex, singular and
// rectangular inputs, option parsing, error identifiers, and provider
// round-trips.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor as CMatrix, ResolveContext, Tensor as Matrix, Type};

    // Extract the human-readable message from a runtime error.
    fn error_message(err: RuntimeError) -> String {
        err.message().to_string()
    }

    // Unwrap a dense real tensor; panics on any other value kind.
    fn tensor_from_value(value: Value) -> Matrix {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected dense tensor, got {other:?}"),
        }
    }

    // Convert a real or complex tensor value into the row-major helper matrix.
    fn row_major_from_value(value: Value) -> RowMajorMatrix {
        match value {
            Value::Tensor(t) => RowMajorMatrix::from_tensor(&t).expect("row-major tensor"),
            Value::ComplexTensor(ct) => {
                RowMajorMatrix::from_complex_tensor(&ct).expect("row-major complex tensor")
            }
            other => panic!("expected tensor value, got {other:?}"),
        }
    }

    #[test]
    fn lu_type_preserves_matrix_shape() {
        let out = matrix_unary_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    #[test]
    fn lu_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = LU_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert!(labels.contains(&"LU = lu(A)"));
        assert!(labels.contains(&"LU = lu(A, pivotMode)"));
        assert!(labels.contains(&"[L, U] = lu(A)"));
        assert!(labels.contains(&"[L, U] = lu(A, pivotMode)"));
        assert!(labels.contains(&"[L, U, P] = lu(A)"));
        assert!(labels.contains(&"[L, U, P] = lu(A, pivotMode)"));
    }

    #[test]
    fn lu_descriptor_errors_have_stable_codes() {
        let codes: Vec<&str> = LU_DESCRIPTOR.errors.iter().map(|err| err.code).collect();
        assert!(codes.contains(&"RM.LU.INVALID_ARGUMENT"));
        assert!(codes.contains(&"RM.LU.INVALID_INPUT"));
        assert!(codes.contains(&"RM.LU.INTERNAL"));
    }

    // Naive dense product of two row-major matrices (test-only reference).
    fn row_major_matmul(a: &RowMajorMatrix, b: &RowMajorMatrix) -> RowMajorMatrix {
        assert_eq!(a.cols, b.rows, "incompatible shapes for matmul");
        let mut out = RowMajorMatrix::zeros(a.rows, b.cols);
        for i in 0..a.rows {
            for k in 0..a.cols {
                let aik = a.get(i, k);
                for j in 0..b.cols {
                    let acc = out.get(i, j) + aik * b.get(k, j);
                    out.set(i, j, acc);
                }
            }
        }
        out
    }

    // Assert equal shapes and element-wise absolute difference within `tol`.
    fn assert_tensor_close(a: &Matrix, b: &Matrix, tol: f64) {
        assert_eq!(a.shape, b.shape);
        for (lhs, rhs) in a.data.iter().zip(&b.data) {
            assert!(
                (lhs - rhs).abs() <= tol,
                "mismatch: lhs={lhs}, rhs={rhs}, tol={tol}"
            );
        }
    }

    // Complex analogue of `assert_tensor_close` for row-major matrices.
    fn assert_row_major_close(a: &RowMajorMatrix, b: &RowMajorMatrix, tol: f64) {
        assert_eq!(a.rows, b.rows, "row mismatch");
        assert_eq!(a.cols, b.cols, "col mismatch");
        for row in 0..a.rows {
            for col in 0..a.cols {
                let lhs = a.get(row, col);
                let rhs = b.get(row, col);
                let diff = (lhs - rhs).norm();
                assert!(
                    diff <= tol,
                    "mismatch at ({row}, {col}): lhs={lhs:?}, rhs={rhs:?}, diff={diff}, tol={tol}"
                );
            }
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_single_output_produces_combined_matrix() {
        let a = Matrix::new(
            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
            vec![3, 3],
        )
        .unwrap();
        let result = lu_builtin(Value::Tensor(a.clone()), Vec::new()).expect("lu");
        let lu = tensor_from_value(result);
        let eval = evaluate(Value::Tensor(a), &[]).expect("evaluate");
        let expected = tensor_from_value(eval.combined());
        assert_tensor_close(&lu, &expected, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_three_outputs_matches_factorization() {
        let data = vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0];
        let a = Matrix::new(data.clone(), vec![3, 3]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());

        // The defining identity of a pivoted LU factorization.
        let pa = crate::builtins::common::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::builtins::common::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_complex_matrix_factorization() {
        let data = vec![(1.0, 2.0), (3.0, -1.0), (2.0, -1.0), (4.0, 2.0)];
        let a = CMatrix::new(data.clone(), vec![2, 2]).expect("complex tensor");
        let eval = evaluate(Value::ComplexTensor(a.clone()), &[]).expect("evaluate complex");

        let l = row_major_from_value(eval.lower());
        let u = row_major_from_value(eval.upper());
        let p = row_major_from_value(eval.permutation_matrix());
        let input = RowMajorMatrix::from_complex_tensor(&a).expect("row-major input");

        let pa = row_major_matmul(&p, &input);
        let lu = row_major_matmul(&l, &u);
        assert_row_major_close(&pa, &lu, 1e-9);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_handles_singular_matrix() {
        let a = Matrix::new(vec![0.0, 0.0, 0.0, 0.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate singular");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());

        assert!(u.data.iter().any(|&v| v.abs() <= 1e-12));

        let pa = crate::builtins::common::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::builtins::common::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_vector_option_returns_pivot_vector() {
        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("vector")]).expect("evaluate vector mode");
        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
        let pivot = tensor_from_value(eval.pivot_vector());
        // Column 0 is [4; 6], so row 2 is the first pivot (1-based indices).
        assert_eq!(pivot.shape, vec![2, 1]);
        assert_eq!(pivot.data, vec![2.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_vector_option_case_insensitive() {
        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("VECTOR")]).expect("evaluate vector option");
        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_matrix_option_returns_permutation_matrix() {
        let a = Matrix::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("matrix")]).expect("evaluate matrix option");
        assert_eq!(eval.pivot_mode(), PivotMode::Matrix);
        let perm_selected = tensor_from_value(eval.permutation());
        let perm_matrix = tensor_from_value(eval.permutation_matrix());
        assert_eq!(perm_selected.shape, perm_matrix.shape);
        assert_tensor_close(&perm_selected, &perm_matrix, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_handles_rectangular_matrices() {
        // Wide (2x3) input: L is square in the row count, U keeps A's shape.
        let a = Matrix::new(vec![3.0, 6.0, 1.0, 3.0, 2.0, 4.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate rectangular");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());
        assert_eq!(l.shape, vec![2, 2]);
        assert_eq!(u.shape, vec![2, 3]);
        assert_eq!(p.shape, vec![2, 2]);

        let pa = crate::builtins::common::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::builtins::common::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_rejects_unknown_option() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(Value::Tensor(a), &[Value::from("invalid")]) {
            Ok(_) => panic!("expected option parse failure"),
            Err(err) => {
                assert_eq!(err.identifier(), LU_ERROR_INVALID_ARGUMENT.identifier);
                error_message(err)
            }
        };
        assert!(err.contains("unknown option"));
    }

    // NOTE(review): this expects "unknown option", which presumes
    // `tensor::value_to_string` renders numeric scalars as text (e.g. "2")
    // rather than returning `None` — confirm against that helper.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_rejects_non_string_option() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(Value::Tensor(a), &[Value::Num(2.0)]) {
            Ok(_) => panic!("expected option parse failure"),
            Err(err) => {
                assert_eq!(err.identifier(), LU_ERROR_INVALID_ARGUMENT.identifier);
                error_message(err)
            }
        };
        assert!(err.contains("unknown option"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_rejects_multiple_options() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(
            Value::Tensor(a),
            &[Value::from("matrix"), Value::from("vector")],
        ) {
            Ok(_) => panic!("expected option arity failure"),
            Err(err) => {
                assert_eq!(err.identifier(), LU_ERROR_INVALID_ARGUMENT.identifier);
                error_message(err)
            }
        };
        assert!(err.contains("too many option arguments"));
    }

    #[test]
    fn lu_invalid_input_identifier_is_stable() {
        // A genuinely 3-D input (third dimension of 2) must be rejected.
        let tensor = Matrix::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]).expect("tensor");
        let err = match evaluate(Value::Tensor(tensor), &[]) {
            Ok(_) => panic!("expected 2-D input failure"),
            Err(err) => err,
        };
        assert_eq!(err.identifier(), LU_ERROR_INVALID_INPUT.identifier);
    }

    // Factor on the test provider and check the results stay device-resident.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = Matrix::new(vec![10.0, 3.0, 7.0, 2.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("evaluate gpu input");
            let lower_val = eval.lower();
            let upper_val = eval.upper();
            let perm_val = eval.permutation_matrix();
            assert!(matches!(lower_val, Value::GpuTensor(_)));
            assert!(matches!(upper_val, Value::GpuTensor(_)));
            assert!(matches!(perm_val, Value::GpuTensor(_)));
            let l = test_support::gather(lower_val).expect("gather lower");
            let u = test_support::gather(upper_val).expect("gather upper");
            let p = test_support::gather(perm_val).expect("gather permutation");
            let pa = crate::builtins::common::matrix::matrix_mul(&p, &host).expect("P*A");
            let lu_product = crate::builtins::common::matrix::matrix_mul(&l, &u).expect("L*U");
            assert_tensor_close(&pa, &lu_product, 1e-9);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_gpu_vector_option_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval =
                evaluate(Value::GpuTensor(handle), &[Value::from("vector")]).expect("gpu vector");
            let pivot_val = eval.permutation();
            assert!(matches!(pivot_val, Value::GpuTensor(_)));
            let pivot = test_support::gather(pivot_val).expect("gather pivot");
            assert_eq!(pivot.shape, vec![2, 1]);
            let expected = Matrix::new(vec![2.0, 1.0], vec![2, 1]).unwrap();
            assert_tensor_close(&pivot, &expected, 1e-12);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_accepts_scalar_inputs() {
        let eval = evaluate(Value::Num(5.0), &[]).expect("evaluate scalar");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());
        assert_eq!(l.data, vec![1.0]);
        assert_eq!(u.data, vec![5.0]);
        assert_eq!(p.data, vec![1.0]);
    }

    // Cross-check the wgpu provider against the host implementation (only
    // compiled with the `wgpu` feature).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn lu_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let host = Matrix::new(
            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
            vec![3, 3],
        )
        .unwrap();
        let cpu_eval = evaluate(Value::Tensor(host.clone()), &[]).expect("cpu evaluate");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &host.data,
            shape: &host.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");

        let l_cpu = tensor_from_value(cpu_eval.lower());
        let u_cpu = tensor_from_value(cpu_eval.upper());
        let p_cpu = tensor_from_value(cpu_eval.permutation_matrix());
        let lu_cpu = tensor_from_value(cpu_eval.combined());

        let l_gpu = test_support::gather(gpu_eval.lower()).expect("gather L");
        let u_gpu = test_support::gather(gpu_eval.upper()).expect("gather U");
        let p_gpu = test_support::gather(gpu_eval.permutation_matrix()).expect("gather P");
        let lu_gpu = test_support::gather(gpu_eval.combined()).expect("gather LU");

        assert_tensor_close(&l_cpu, &l_gpu, 1e-12);
        assert_tensor_close(&u_cpu, &u_gpu, 1e-12);
        assert_tensor_close(&p_cpu, &p_gpu, 1e-12);
        assert_tensor_close(&lu_cpu, &lu_gpu, 1e-12);

        let pivot_cpu = tensor_from_value(cpu_eval.pivot_vector());
        let pivot_gpu = test_support::gather(gpu_eval.pivot_vector()).expect("gather pivot vector");
        assert_tensor_close(&pivot_cpu, &pivot_gpu, 1e-12);

        let handle_vector = provider.upload(&view).expect("upload vector option");
        let gpu_vector_eval = evaluate(Value::GpuTensor(handle_vector), &[Value::from("vector")])
            .expect("gpu vector evaluate");
        let pivot_vector =
            test_support::gather(gpu_vector_eval.permutation()).expect("gather vector pivot");
        assert_tensor_close(&pivot_cpu, &pivot_vector, 1e-12);
    }

    // Synchronous wrappers so the tests can drive the async entry points.
    fn lu_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::lu_builtin(value, rest))
    }

    fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<LuEval> {
        block_on(super::evaluate(value, args))
    }
}