1use crate::builtins::common::spec::{
4 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
5 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
6};
7use crate::builtins::common::{gpu_helpers, tensor};
8use crate::builtins::math::linalg::type_resolvers::matrix_unary_type;
9use crate::{build_runtime_error, BuiltinResult, RuntimeError};
10
11use num_complex::Complex64;
12use runmat_accelerate_api::{GpuTensorHandle, ProviderLuResult};
13use runmat_builtins::{ComplexTensor, Tensor, Value};
14use runmat_macros::runtime_builtin;
15
/// Builtin name attached to every runtime error raised from this module.
const BUILTIN_NAME: &str = "lu";
17
/// Acceleration metadata for `lu`: advertises the provider `lu` custom hook
/// and documents the gather-then-CPU fallback when no provider is registered.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::linalg::factor::lu")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "lu",
    op_kind: GpuOpKind::Custom("lu-factor"),
    supported_precisions: &[ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("lu")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prefers the provider `lu` hook; automatically gathers and falls back to the CPU implementation when no provider support is registered.",
};
33
34fn lu_error(message: impl Into<String>) -> RuntimeError {
35 build_runtime_error(message)
36 .with_builtin(BUILTIN_NAME)
37 .build()
38}
39
40fn with_lu_context(mut error: RuntimeError) -> RuntimeError {
41 if error.message() == "interaction pending..." {
42 return build_runtime_error("interaction pending...")
43 .with_builtin(BUILTIN_NAME)
44 .build();
45 }
46 if error.context.builtin.is_none() {
47 error.context = error.context.with_builtin(BUILTIN_NAME);
48 }
49 error
50}
51
/// Fusion metadata for `lu`: the factorization never participates in fused
/// expression kernels and always executes eagerly.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::linalg::factor::lu")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "lu",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "LU decomposition is not part of expression fusion; calls execute eagerly on the CPU.",
};
62
63#[runtime_builtin(
64 name = "lu",
65 category = "math/linalg/factor",
66 summary = "LU decomposition with partial pivoting.",
67 keywords = "lu,factorization,decomposition,permutation",
68 accel = "sink",
69 sink = true,
70 type_resolver(matrix_unary_type),
71 builtin_path = "crate::builtins::math::linalg::factor::lu"
72)]
73async fn lu_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
74 let eval = evaluate(value, &rest).await?;
75 if let Some(out_count) = crate::output_count::current_output_count() {
76 if out_count == 0 {
77 return Ok(Value::OutputList(Vec::new()));
78 }
79 if out_count == 1 {
80 return Ok(Value::OutputList(vec![eval.combined()]));
81 }
82 if out_count == 2 {
83 return Ok(Value::OutputList(vec![eval.lower(), eval.upper()]));
84 }
85 if out_count == 3 {
86 return Ok(Value::OutputList(vec![
87 eval.lower(),
88 eval.upper(),
89 eval.permutation(),
90 ]));
91 }
92 return Err(lu_error("lu currently supports at most three outputs"));
93 }
94 Ok(eval.combined())
95}
96
/// Fully-materialized result of an `lu` evaluation, holding every output a
/// caller might request so multi-output dispatch is just a clone.
#[derive(Clone)]
pub struct LuEval {
    // Single-output form: L and U packed into one matrix.
    combined: Value,
    // Unit lower-triangular factor L.
    lower: Value,
    // Upper-triangular factor U.
    upper: Value,
    // Permutation as a dense matrix P.
    perm_matrix: Value,
    // Permutation as a 1-based pivot row vector.
    perm_vector: Value,
    // Which permutation representation `permutation()` returns.
    pivot_mode: PivotMode,
}
107
108impl LuEval {
109 pub fn combined(&self) -> Value {
111 self.combined.clone()
112 }
113
114 pub fn lower(&self) -> Value {
116 self.lower.clone()
117 }
118
119 pub fn upper(&self) -> Value {
121 self.upper.clone()
122 }
123
124 pub fn permutation(&self) -> Value {
126 match self.pivot_mode {
127 PivotMode::Matrix => self.perm_matrix.clone(),
128 PivotMode::Vector => self.perm_vector.clone(),
129 }
130 }
131
132 pub fn permutation_matrix(&self) -> Value {
134 self.perm_matrix.clone()
135 }
136
137 pub fn pivot_vector(&self) -> Value {
139 self.perm_vector.clone()
140 }
141
142 pub fn pivot_mode(&self) -> PivotMode {
144 self.pivot_mode
145 }
146
147 fn from_components(components: LuComponents, pivot_mode: PivotMode) -> BuiltinResult<Self> {
148 let combined = matrix_to_value(&components.combined)?;
149 let lower = matrix_to_value(&components.lower)?;
150 let upper = matrix_to_value(&components.upper)?;
151 let perm_matrix = matrix_to_value(&components.permutation)?;
152 let perm_vector = pivot_vector_to_value(&components.pivot_vector)?;
153 Ok(Self {
154 combined,
155 lower,
156 upper,
157 perm_matrix,
158 perm_vector,
159 pivot_mode,
160 })
161 }
162
163 fn from_provider(result: ProviderLuResult, pivot_mode: PivotMode) -> Self {
164 Self {
165 combined: Value::GpuTensor(result.combined),
166 lower: Value::GpuTensor(result.lower),
167 upper: Value::GpuTensor(result.upper),
168 perm_matrix: Value::GpuTensor(result.perm_matrix),
169 perm_vector: Value::GpuTensor(result.perm_vector),
170 pivot_mode,
171 }
172 }
173}
174
/// Selects how the third `lu` output represents the row permutation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum PivotMode {
    /// Return a dense permutation matrix `P` (the default).
    #[default]
    Matrix,
    /// Return a 1-based pivot row vector (the `'vector'` option).
    Vector,
}
187
188pub async fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<LuEval> {
190 let pivot_mode = parse_pivot_mode(args)?;
191 match value {
192 Value::GpuTensor(handle) => {
193 if let Some(eval) = evaluate_gpu(&handle, pivot_mode).await? {
194 return Ok(eval);
195 }
196 let tensor = gpu_helpers::gather_tensor_async(&handle)
197 .await
198 .map_err(with_lu_context)?;
199 evaluate_host_value(Value::Tensor(tensor), pivot_mode).await
200 }
201 other => evaluate_host_value(other, pivot_mode).await,
202 }
203}
204
205async fn evaluate_host_value(value: Value, pivot_mode: PivotMode) -> BuiltinResult<LuEval> {
206 let matrix = extract_matrix(value).await?;
207 let components = lu_factor(matrix)?;
208 LuEval::from_components(components, pivot_mode)
209}
210
211async fn evaluate_gpu(
212 handle: &GpuTensorHandle,
213 pivot_mode: PivotMode,
214) -> BuiltinResult<Option<LuEval>> {
215 if let Some(provider) = runmat_accelerate_api::provider() {
216 if let Ok(result) = provider.lu(handle).await {
217 return Ok(Some(LuEval::from_provider(result, pivot_mode)));
218 }
219 }
220 Ok(None)
221}
222
223fn parse_pivot_mode(args: &[Value]) -> BuiltinResult<PivotMode> {
224 if args.is_empty() {
225 return Ok(PivotMode::Matrix);
226 }
227 if args.len() > 1 {
228 return Err(lu_error("lu: too many option arguments"));
229 }
230 let Some(option) = tensor::value_to_string(&args[0]) else {
231 return Err(lu_error("lu: option must be a string or character vector"));
232 };
233 match option.trim().to_ascii_lowercase().as_str() {
234 "matrix" => Ok(PivotMode::Matrix),
235 "vector" => Ok(PivotMode::Vector),
236 other => Err(lu_error(format!("lu: unknown option '{other}'"))),
237 }
238}
239
240async fn extract_matrix(value: Value) -> BuiltinResult<RowMajorMatrix> {
241 match value {
242 Value::Tensor(t) => RowMajorMatrix::from_tensor(&t),
243 Value::ComplexTensor(ct) => RowMajorMatrix::from_complex_tensor(&ct),
244 Value::GpuTensor(handle) => {
245 let tensor = gpu_helpers::gather_tensor_async(&handle)
246 .await
247 .map_err(with_lu_context)?;
248 RowMajorMatrix::from_tensor(&tensor)
249 }
250 Value::LogicalArray(logical) => {
251 let tensor = tensor::logical_to_tensor(&logical)
252 .map_err(|err| lu_error(format!("lu: {err}")))?;
253 RowMajorMatrix::from_tensor(&tensor)
254 }
255 Value::Num(n) => Ok(RowMajorMatrix::from_scalar(Complex64::new(n, 0.0))),
256 Value::Int(i) => Ok(RowMajorMatrix::from_scalar(Complex64::new(i.to_f64(), 0.0))),
257 Value::Bool(b) => Ok(RowMajorMatrix::from_scalar(Complex64::new(
258 if b { 1.0 } else { 0.0 },
259 0.0,
260 ))),
261 Value::Complex(re, im) => Ok(RowMajorMatrix::from_scalar(Complex64::new(re, im))),
262 Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(lu_error(
263 "lu: character data is not supported; convert to numeric values first",
264 )),
265 other => Err(lu_error(format!("lu: unsupported input type {:?}", other))),
266 }
267}
268
/// Raw CPU factorization outputs before conversion into runtime `Value`s.
struct LuComponents {
    // L and U packed into one matrix (L's multipliers below the diagonal).
    combined: RowMajorMatrix,
    // Unit lower-triangular factor L.
    lower: RowMajorMatrix,
    // Upper-triangular factor U.
    upper: RowMajorMatrix,
    // Dense permutation matrix P.
    permutation: RowMajorMatrix,
    // 1-based pivot row indices.
    pivot_vector: Vec<f64>,
}
276
/// In-place Gaussian elimination with partial pivoting (Doolittle form:
/// unit lower-triangular `L`).
///
/// On success, `combined` stores the elimination multipliers strictly below
/// the diagonal and `U` on/above it; `pivot_vector` is 1-based.
fn lu_factor(mut matrix: RowMajorMatrix) -> BuiltinResult<LuComponents> {
    let rows = matrix.rows;
    let cols = matrix.cols;
    let min_dim = rows.min(cols);
    // Row permutation accumulated from pivot swaps; starts as the identity.
    let mut perm: Vec<usize> = (0..rows).collect();

    for k in 0..min_dim {
        // Partial pivoting: choose the row at or below `k` whose entry in
        // column `k` has the largest modulus.
        let mut pivot_row = k;
        let mut pivot_abs = 0.0;
        for r in k..rows {
            let val = matrix.get(r, k);
            let abs = val.norm();
            if abs > pivot_abs {
                pivot_abs = abs;
                pivot_row = r;
            }
        }

        if pivot_row != k {
            matrix.swap_rows(pivot_row, k);
            perm.swap(pivot_row, k);
        }

        // Numerically zero pivot column: clear the sub-diagonal entries so the
        // combined matrix stays consistent, and skip elimination for this k.
        if pivot_abs <= EPS {
            for r in (k + 1)..rows {
                matrix.set(r, k, Complex64::new(0.0, 0.0));
            }
            continue;
        }

        // Eliminate below the pivot, storing each multiplier in place so the
        // strict lower triangle ends up holding `L`.
        let pivot_value = matrix.get(k, k);
        for r in (k + 1)..rows {
            let factor = matrix.get(r, k) / pivot_value;
            matrix.set(r, k, factor);
            for c in (k + 1)..cols {
                let updated = matrix.get(r, c) - factor * matrix.get(k, c);
                matrix.set(r, c, updated);
            }
        }
    }

    let combined = matrix.clone();
    let lower = build_lower(&matrix);
    let upper = build_upper(&matrix);
    let permutation = build_permutation(rows, &perm);
    // Pivot indices are reported 1-based, matching the language convention.
    let pivot_vector: Vec<f64> = perm.iter().map(|idx| (*idx + 1) as f64).collect();

    Ok(LuComponents {
        combined,
        lower,
        upper,
        permutation,
        pivot_vector,
    })
}
334
335fn build_lower(matrix: &RowMajorMatrix) -> RowMajorMatrix {
336 let rows = matrix.rows;
337 let cols = matrix.cols;
338 let min_dim = rows.min(cols);
339 let mut lower = RowMajorMatrix::identity(rows);
340 for i in 0..rows {
341 for j in 0..min_dim {
342 if i > j {
343 lower.set(i, j, matrix.get(i, j));
344 }
345 }
346 }
347 lower
348}
349
350fn build_upper(matrix: &RowMajorMatrix) -> RowMajorMatrix {
351 let rows = matrix.rows;
352 let cols = matrix.cols;
353 let mut upper = RowMajorMatrix::zeros(rows, cols);
354 for i in 0..rows {
355 for j in 0..cols {
356 if i <= j {
357 upper.set(i, j, matrix.get(i, j));
358 }
359 }
360 }
361 upper
362}
363
364fn build_permutation(rows: usize, perm: &[usize]) -> RowMajorMatrix {
365 let mut matrix = RowMajorMatrix::zeros(rows, rows);
366 for (i, &col) in perm.iter().enumerate() {
367 if col < rows {
368 matrix.set(i, col, Complex64::new(1.0, 0.0));
369 }
370 }
371 matrix
372}
373
/// Magnitude threshold below which pivots and imaginary parts are treated as zero.
const EPS: f64 = 1.0e-12;
375
376fn matrix_to_value(matrix: &RowMajorMatrix) -> BuiltinResult<Value> {
377 let mut has_imag = false;
378 for val in &matrix.data {
379 if val.im.abs() > EPS {
380 has_imag = true;
381 break;
382 }
383 }
384 if has_imag {
385 let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
386 for col in 0..matrix.cols {
387 for row in 0..matrix.rows {
388 let idx = row * matrix.cols + col;
389 let v = matrix.data[idx];
390 data.push((v.re, v.im));
391 }
392 }
393 let tensor = ComplexTensor::new(data, vec![matrix.rows, matrix.cols])
394 .map_err(|e| lu_error(format!("lu: {e}")))?;
395 Ok(Value::ComplexTensor(tensor))
396 } else {
397 let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
398 for col in 0..matrix.cols {
399 for row in 0..matrix.rows {
400 let idx = row * matrix.cols + col;
401 data.push(matrix.data[idx].re);
402 }
403 }
404 let tensor = Tensor::new(data, vec![matrix.rows, matrix.cols])
405 .map_err(|e| lu_error(format!("lu: {e}")))?;
406 Ok(Value::Tensor(tensor))
407 }
408}
409
410fn pivot_vector_to_value(pivot: &[f64]) -> BuiltinResult<Value> {
411 let rows = pivot.len();
412 let tensor =
413 Tensor::new(pivot.to_vec(), vec![rows, 1]).map_err(|e| lu_error(format!("lu: {e}")))?;
414 Ok(Value::Tensor(tensor))
415}
416
/// Dense complex matrix stored row-major (`data[row * cols + col]`), the
/// working representation for the CPU elimination loop.
#[derive(Clone)]
struct RowMajorMatrix {
    rows: usize,
    cols: usize,
    // Length is rows * cols; row-major element order.
    data: Vec<Complex64>,
}
423
424impl RowMajorMatrix {
425 fn zeros(rows: usize, cols: usize) -> Self {
426 Self {
427 rows,
428 cols,
429 data: vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)],
430 }
431 }
432
433 fn identity(size: usize) -> Self {
434 let mut matrix = Self::zeros(size, size);
435 for i in 0..size {
436 matrix.set(i, i, Complex64::new(1.0, 0.0));
437 }
438 matrix
439 }
440
441 fn from_scalar(value: Complex64) -> Self {
442 Self {
443 rows: 1,
444 cols: 1,
445 data: vec![value],
446 }
447 }
448
449 fn from_tensor(tensor: &Tensor) -> BuiltinResult<Self> {
450 if tensor.shape.len() > 2 {
451 return Err(lu_error("lu: input must be 2-D"));
452 }
453 let rows = tensor.rows();
454 let cols = tensor.cols();
455 let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
456 for col in 0..cols {
457 for row in 0..rows {
458 let idx_col_major = row + col * rows;
459 let idx_row_major = row * cols + col;
460 data[idx_row_major] = Complex64::new(tensor.data[idx_col_major], 0.0);
461 }
462 }
463 Ok(Self { rows, cols, data })
464 }
465
466 fn from_complex_tensor(tensor: &ComplexTensor) -> BuiltinResult<Self> {
467 if tensor.shape.len() > 2 {
468 return Err(lu_error("lu: input must be 2-D"));
469 }
470 let rows = tensor.rows;
471 let cols = tensor.cols;
472 let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
473 for col in 0..cols {
474 for row in 0..rows {
475 let idx_col_major = row + col * rows;
476 let idx_row_major = row * cols + col;
477 let (re, im) = tensor.data[idx_col_major];
478 data[idx_row_major] = Complex64::new(re, im);
479 }
480 }
481 Ok(Self { rows, cols, data })
482 }
483
484 fn get(&self, row: usize, col: usize) -> Complex64 {
485 self.data[row * self.cols + col]
486 }
487
488 fn set(&mut self, row: usize, col: usize, value: Complex64) {
489 self.data[row * self.cols + col] = value;
490 }
491
492 fn swap_rows(&mut self, r1: usize, r2: usize) {
493 if r1 == r2 {
494 return;
495 }
496 for col in 0..self.cols {
497 self.data.swap(r1 * self.cols + col, r2 * self.cols + col);
498 }
499 }
500}
501
// Unit tests covering CPU, complex, rectangular, singular, option-parsing,
// and GPU-provider paths of the `lu` builtin. Async entry points are driven
// through `block_on` wrappers at the bottom of the module.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor as CMatrix, ResolveContext, Tensor as Matrix, Type};

    // Extract just the message text from a runtime error.
    fn error_message(err: RuntimeError) -> String {
        err.message().to_string()
    }

    // Unwrap a dense real tensor, panicking on any other value kind.
    fn tensor_from_value(value: Value) -> Matrix {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected dense tensor, got {other:?}"),
        }
    }

    // Convert a real or complex tensor value into the row-major working type.
    fn row_major_from_value(value: Value) -> RowMajorMatrix {
        match value {
            Value::Tensor(t) => RowMajorMatrix::from_tensor(&t).expect("row-major tensor"),
            Value::ComplexTensor(ct) => {
                RowMajorMatrix::from_complex_tensor(&ct).expect("row-major complex tensor")
            }
            other => panic!("expected tensor value, got {other:?}"),
        }
    }

    // The type resolver should pass the input tensor shape through unchanged.
    #[test]
    fn lu_type_preserves_matrix_shape() {
        let out = matrix_unary_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    // Naive O(n^3) matrix multiply over the row-major test representation.
    fn row_major_matmul(a: &RowMajorMatrix, b: &RowMajorMatrix) -> RowMajorMatrix {
        assert_eq!(a.cols, b.rows, "incompatible shapes for matmul");
        let mut out = RowMajorMatrix::zeros(a.rows, b.cols);
        for i in 0..a.rows {
            for k in 0..a.cols {
                let aik = a.get(i, k);
                for j in 0..b.cols {
                    let acc = out.get(i, j) + aik * b.get(k, j);
                    out.set(i, j, acc);
                }
            }
        }
        out
    }

    // Element-wise closeness assertion for real tensors.
    fn assert_tensor_close(a: &Matrix, b: &Matrix, tol: f64) {
        assert_eq!(a.shape, b.shape);
        for (lhs, rhs) in a.data.iter().zip(&b.data) {
            assert!(
                (lhs - rhs).abs() <= tol,
                "mismatch: lhs={lhs}, rhs={rhs}, tol={tol}"
            );
        }
    }

    // Element-wise closeness assertion for row-major complex matrices.
    fn assert_row_major_close(a: &RowMajorMatrix, b: &RowMajorMatrix, tol: f64) {
        assert_eq!(a.rows, b.rows, "row mismatch");
        assert_eq!(a.cols, b.cols, "col mismatch");
        for row in 0..a.rows {
            for col in 0..a.cols {
                let lhs = a.get(row, col);
                let rhs = b.get(row, col);
                let diff = (lhs - rhs).norm();
                assert!(
                    diff <= tol,
                    "mismatch at ({row}, {col}): lhs={lhs:?}, rhs={rhs:?}, diff={diff}, tol={tol}"
                );
            }
        }
    }

    // One requested output yields the packed LU matrix from `evaluate`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_single_output_produces_combined_matrix() {
        let a = Matrix::new(
            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
            vec![3, 3],
        )
        .unwrap();
        let result = lu_builtin(Value::Tensor(a.clone()), Vec::new()).expect("lu");
        let lu = tensor_from_value(result);
        let eval = evaluate(Value::Tensor(a), &[]).expect("evaluate");
        let expected = tensor_from_value(eval.combined());
        assert_tensor_close(&lu, &expected, 1e-12);
    }

    // The factorization identity P*A == L*U must hold for a real matrix.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_three_outputs_matches_factorization() {
        let data = vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0];
        let a = Matrix::new(data.clone(), vec![3, 3]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    // P*A == L*U must also hold for complex input.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_complex_matrix_factorization() {
        let data = vec![(1.0, 2.0), (3.0, -1.0), (2.0, -1.0), (4.0, 2.0)];
        let a = CMatrix::new(data.clone(), vec![2, 2]).expect("complex tensor");
        let eval = evaluate(Value::ComplexTensor(a.clone()), &[]).expect("evaluate complex");

        let l = row_major_from_value(eval.lower());
        let u = row_major_from_value(eval.upper());
        let p = row_major_from_value(eval.permutation_matrix());
        let input = RowMajorMatrix::from_complex_tensor(&a).expect("row-major input");

        let pa = row_major_matmul(&p, &input);
        let lu = row_major_matmul(&l, &u);
        assert_row_major_close(&pa, &lu, 1e-9);
    }

    // An all-zero (singular) matrix must not panic and must keep P*A == L*U.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_handles_singular_matrix() {
        let a = Matrix::new(vec![0.0, 0.0, 0.0, 0.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate singular");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());

        assert!(u.data.iter().any(|&v| v.abs() <= 1e-12));

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    // The 'vector' option selects the 1-based pivot vector as third output.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_vector_option_returns_pivot_vector() {
        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("vector")]).expect("evaluate vector mode");
        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
        let pivot = tensor_from_value(eval.pivot_vector());
        assert_eq!(pivot.shape, vec![2, 1]);
        assert_eq!(pivot.data, vec![2.0, 1.0]);
    }

    // Option matching is case-insensitive.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_vector_option_case_insensitive() {
        let a = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("VECTOR")]).expect("evaluate vector option");
        assert_eq!(eval.pivot_mode(), PivotMode::Vector);
    }

    // The explicit 'matrix' option selects the permutation matrix output.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_matrix_option_returns_permutation_matrix() {
        let a = Matrix::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), &[Value::from("matrix")]).expect("evaluate matrix option");
        assert_eq!(eval.pivot_mode(), PivotMode::Matrix);
        let perm_selected = tensor_from_value(eval.permutation());
        let perm_matrix = tensor_from_value(eval.permutation_matrix());
        assert_eq!(perm_selected.shape, perm_matrix.shape);
        assert_tensor_close(&perm_selected, &perm_matrix, 1e-12);
    }

    // Wide (2x3) input: L is 2x2, U is 2x3, and P*A == L*U still holds.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_handles_rectangular_matrices() {
        let a = Matrix::new(vec![3.0, 6.0, 1.0, 3.0, 2.0, 4.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[]).expect("evaluate rectangular");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());
        assert_eq!(l.shape, vec![2, 2]);
        assert_eq!(u.shape, vec![2, 3]);
        assert_eq!(p.shape, vec![2, 2]);

        let pa = crate::matrix::matrix_mul(&p, &a).expect("P*A");
        let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
        assert_tensor_close(&pa, &lu_product, 1e-9);
    }

    // Unknown option strings are reported with a descriptive error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_rejects_unknown_option() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(Value::Tensor(a), &[Value::from("invalid")]) {
            Ok(_) => panic!("expected option parse failure"),
            Err(err) => error_message(err),
        };
        assert!(err.contains("unknown option"));
    }

    // A numeric option stringifies to something that is not a valid mode.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_rejects_non_string_option() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(Value::Tensor(a), &[Value::Num(2.0)]) {
            Ok(_) => panic!("expected option parse failure"),
            Err(err) => error_message(err),
        };
        assert!(err.contains("unknown option"));
    }

    // At most one trailing option argument is accepted.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_rejects_multiple_options() {
        let a = Matrix::new(vec![1.0], vec![1, 1]).unwrap();
        let err = match evaluate(
            Value::Tensor(a),
            &[Value::from("matrix"), Value::from("vector")],
        ) {
            Ok(_) => panic!("expected option parse failure"),
            Err(err) => error_message(err),
        };
        assert!(err.contains("too many option arguments"));
    }

    // GPU inputs stay device-resident through the provider hook; gathered
    // results must still satisfy P*A == L*U.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = Matrix::new(vec![10.0, 3.0, 7.0, 2.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("evaluate gpu input");
            let lower_val = eval.lower();
            let upper_val = eval.upper();
            let perm_val = eval.permutation_matrix();
            assert!(matches!(lower_val, Value::GpuTensor(_)));
            assert!(matches!(upper_val, Value::GpuTensor(_)));
            assert!(matches!(perm_val, Value::GpuTensor(_)));
            let l = test_support::gather(lower_val).expect("gather lower");
            let u = test_support::gather(upper_val).expect("gather upper");
            let p = test_support::gather(perm_val).expect("gather permutation");
            let pa = crate::matrix::matrix_mul(&p, &host).expect("P*A");
            let lu_product = crate::matrix::matrix_mul(&l, &u).expect("L*U");
            assert_tensor_close(&pa, &lu_product, 1e-9);
        });
    }

    // The 'vector' option also works through the GPU provider path.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_gpu_vector_option_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = Matrix::new(vec![4.0, 6.0, 3.0, 3.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval =
                evaluate(Value::GpuTensor(handle), &[Value::from("vector")]).expect("gpu vector");
            let pivot_val = eval.permutation();
            assert!(matches!(pivot_val, Value::GpuTensor(_)));
            let pivot = test_support::gather(pivot_val).expect("gather pivot");
            assert_eq!(pivot.shape, vec![2, 1]);
            let expected = Matrix::new(vec![2.0, 1.0], vec![2, 1]).unwrap();
            assert_tensor_close(&pivot, &expected, 1e-12);
        });
    }

    // Scalar input factors trivially: L = 1, U = value, P = 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lu_accepts_scalar_inputs() {
        let eval = evaluate(Value::Num(5.0), &[]).expect("evaluate scalar");
        let l = tensor_from_value(eval.lower());
        let u = tensor_from_value(eval.upper());
        let p = tensor_from_value(eval.permutation_matrix());
        assert_eq!(l.data, vec![1.0]);
        assert_eq!(u.data, vec![5.0]);
        assert_eq!(p.data, vec![1.0]);
    }

    // The wgpu backend must agree with the CPU implementation bit-for-bit
    // (within tolerance) on all outputs, in both pivot modes.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn lu_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let host = Matrix::new(
            vec![2.0, 4.0, -2.0, 1.0, -6.0, 7.0, 1.0, 0.0, 2.0],
            vec![3, 3],
        )
        .unwrap();
        let cpu_eval = evaluate(Value::Tensor(host.clone()), &[]).expect("cpu evaluate");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &host.data,
            shape: &host.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");

        let l_cpu = tensor_from_value(cpu_eval.lower());
        let u_cpu = tensor_from_value(cpu_eval.upper());
        let p_cpu = tensor_from_value(cpu_eval.permutation_matrix());
        let lu_cpu = tensor_from_value(cpu_eval.combined());

        let l_gpu = test_support::gather(gpu_eval.lower()).expect("gather L");
        let u_gpu = test_support::gather(gpu_eval.upper()).expect("gather U");
        let p_gpu = test_support::gather(gpu_eval.permutation_matrix()).expect("gather P");
        let lu_gpu = test_support::gather(gpu_eval.combined()).expect("gather LU");

        assert_tensor_close(&l_cpu, &l_gpu, 1e-12);
        assert_tensor_close(&u_cpu, &u_gpu, 1e-12);
        assert_tensor_close(&p_cpu, &p_gpu, 1e-12);
        assert_tensor_close(&lu_cpu, &lu_gpu, 1e-12);

        let pivot_cpu = tensor_from_value(cpu_eval.pivot_vector());
        let pivot_gpu = test_support::gather(gpu_eval.pivot_vector()).expect("gather pivot vector");
        assert_tensor_close(&pivot_cpu, &pivot_gpu, 1e-12);

        let handle_vector = provider.upload(&view).expect("upload vector option");
        let gpu_vector_eval = evaluate(Value::GpuTensor(handle_vector), &[Value::from("vector")])
            .expect("gpu vector evaluate");
        let pivot_vector =
            test_support::gather(gpu_vector_eval.permutation()).expect("gather vector pivot");
        assert_tensor_close(&pivot_cpu, &pivot_vector, 1e-12);
    }

    // Synchronous wrapper so tests can call the async builtin directly.
    fn lu_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::lu_builtin(value, rest))
    }

    // Synchronous wrapper around the async `evaluate` entry point.
    fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<LuEval> {
        block_on(super::evaluate(value, args))
    }
}