1use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
11};
12use crate::builtins::common::{gpu_helpers, random_args, tensor};
13use crate::builtins::math::linalg::type_resolvers::matrix_unary_type;
14use crate::{build_runtime_error, BuiltinResult, RuntimeError};
15use num_complex::Complex64;
16use runmat_accelerate_api::{GpuTensorHandle, ProviderCholResult};
17use runmat_builtins::{ComplexTensor, Tensor, Value};
18use runmat_macros::runtime_builtin;
19
/// Builtin name used when tagging runtime errors with their origin.
const BUILTIN_NAME: &str = "chol";
21
/// GPU execution spec for `chol`.
///
/// Providers may implement the custom "chol" hook; when it is absent (or
/// fails) the runtime gathers the tensor and runs the host factorisation.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::linalg::factor::chol")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "chol",
    op_kind: GpuOpKind::Custom("chol-factor"),
    supported_precisions: &[ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("chol")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Uses the provider 'chol' hook when present; otherwise gathers to the host implementation.",
};
38
39fn chol_error(message: impl Into<String>) -> RuntimeError {
40 build_runtime_error(message)
41 .with_builtin(BUILTIN_NAME)
42 .build()
43}
44
/// Attach the `chol` builtin name to an error produced by shared helpers.
fn with_chol_context(mut error: RuntimeError) -> RuntimeError {
    // "interaction pending..." appears to be a control-flow sentinel used by
    // the wider runtime; rebuild it verbatim so only the builtin tag changes.
    // NOTE(review): assumes the sentinel text must survive exactly — confirm.
    if error.message() == "interaction pending..." {
        return build_runtime_error("interaction pending...")
            .with_builtin(BUILTIN_NAME)
            .build();
    }
    // Only claim the error if no other builtin has already tagged it.
    if error.context.builtin.is_none() {
        error.context = error.context.with_builtin(BUILTIN_NAME);
    }
    error
}
56
/// Fusion spec for `chol`: factorisations run eagerly and are never fused
/// into elementwise/reduction expression pipelines.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::linalg::factor::chol")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "chol",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Factorisation executes eagerly and does not participate in expression fusion.",
};
67
68#[runtime_builtin(
69 name = "chol",
70 category = "math/linalg/factor",
71 summary = "Cholesky factorization with MATLAB-compatible upper and lower forms.",
72 keywords = "chol,cholesky,factorization,positive-definite",
73 accel = "sink",
74 sink = true,
75 type_resolver(matrix_unary_type),
76 builtin_path = "crate::builtins::math::linalg::factor::chol"
77)]
78async fn chol_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
79 let eval = evaluate(value, &rest).await?;
80 if let Some(out_count) = crate::output_count::current_output_count() {
81 if out_count == 0 {
82 return Ok(Value::OutputList(Vec::new()));
83 }
84 if out_count == 1 {
85 if !eval.is_positive_definite() {
86 return Err(chol_error("Matrix must be positive definite."));
87 }
88 return Ok(Value::OutputList(vec![eval.factor()]));
89 }
90 if out_count == 2 {
91 return Ok(Value::OutputList(vec![eval.factor(), eval.flag()]));
92 }
93 return Err(chol_error("chol currently supports at most two outputs"));
94 }
95 if !eval.is_positive_definite() {
96 return Err(chol_error("Matrix must be positive definite."));
97 }
98 Ok(eval.factor())
99}
100
/// Result of a Cholesky evaluation: the triangular factor, the MATLAB-style
/// failure flag, and which triangle the factor represents.
#[derive(Clone)]
pub struct CholEval {
    // Triangular factor as a runtime value (host tensor or GPU handle).
    factor: Value,
    // 0 on success; otherwise the 1-based pivot index where chol failed.
    flag: usize,
    // Upper (default) or lower form requested by the caller.
    triangle: CholTriangle,
}
108
109impl CholEval {
110 pub fn factor(&self) -> Value {
112 self.factor.clone()
113 }
114
115 pub fn flag(&self) -> Value {
117 Value::Num(self.flag as f64)
118 }
119
120 pub fn flag_index(&self) -> usize {
122 self.flag
123 }
124
125 pub fn triangle(&self) -> CholTriangle {
127 self.triangle
128 }
129
130 pub fn is_positive_definite(&self) -> bool {
132 self.flag == 0
133 }
134
135 fn from_components(components: CholComponents, triangle: CholTriangle) -> BuiltinResult<Self> {
136 let factor_matrix = match triangle {
137 CholTriangle::Upper => components.upper.clone(),
138 CholTriangle::Lower => components.upper.conjugate_transpose(),
139 };
140 let factor = matrix_to_value("chol", &factor_matrix)?;
141 Ok(Self {
142 factor,
143 flag: components.info,
144 triangle,
145 })
146 }
147
148 fn from_provider(result: ProviderCholResult, triangle: CholTriangle) -> Self {
149 Self {
150 factor: Value::GpuTensor(result.factor),
151 flag: result.info as usize,
152 triangle,
153 }
154 }
155}
156
/// Which triangular factor `chol` should return.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CholTriangle {
    /// Upper factor `R` (MATLAB default).
    Upper,
    /// Lower factor `L` (the `'lower'` option).
    Lower,
}
163
164pub async fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<CholEval> {
166 let triangle = parse_triangle(args)?;
167 match value {
168 Value::GpuTensor(handle) => {
169 if let Some(eval) = evaluate_gpu(&handle, triangle).await? {
170 return Ok(eval);
171 }
172 let tensor = gpu_helpers::gather_tensor_async(&handle)
173 .await
174 .map_err(with_chol_context)?;
175 evaluate_host_value(Value::Tensor(tensor), triangle).await
176 }
177 other => evaluate_host_value(other, triangle).await,
178 }
179}
180
181async fn evaluate_host_value(value: Value, triangle: CholTriangle) -> BuiltinResult<CholEval> {
182 let matrix = extract_matrix(value).await?;
183 if matrix.rows != matrix.cols {
184 return Err(chol_error("chol: input matrix must be square"));
185 }
186 let components = chol_factor(matrix)?;
187 CholEval::from_components(components, triangle)
188}
189
190async fn evaluate_gpu(
191 handle: &GpuTensorHandle,
192 triangle: CholTriangle,
193) -> BuiltinResult<Option<CholEval>> {
194 if let Some(provider) = runmat_accelerate_api::provider() {
195 let lower = matches!(triangle, CholTriangle::Lower);
196 if let Ok(result) = provider.chol(handle, lower).await {
197 return Ok(Some(CholEval::from_provider(result, triangle)));
198 }
199 }
200 Ok(None)
201}
202
203fn parse_triangle(args: &[Value]) -> BuiltinResult<CholTriangle> {
204 if args.is_empty() {
205 return Ok(CholTriangle::Upper);
206 }
207 if args.len() > 1 {
208 return Err(chol_error("chol: too many option arguments"));
209 }
210 let Some(option) = tensor::value_to_string(&args[0]) else {
211 return Err(chol_error(
212 "chol: option must be a string or character vector",
213 ));
214 };
215 match option.trim().to_ascii_lowercase().as_str() {
216 "upper" => Ok(CholTriangle::Upper),
217 "lower" => Ok(CholTriangle::Lower),
218 other => Err(chol_error(format!("chol: unknown option '{other}'"))),
219 }
220}
221
/// Relative tolerance shared by the Hermitian symmetry check, the pivot
/// imaginary-part screen, and near-zero diagonal detection.
const EPS: f64 = 1.0e-12;
223
224#[inline]
225fn hermitian_pair_matches(a: Complex64, b: Complex64) -> bool {
226 let diff = a - b.conj();
227 let scale = a.norm().max(b.norm()).max(1.0);
228 diff.norm() <= EPS * scale
229}
230
/// Upper-Cholesky factorisation: computes `R` with `A = R^H * R`.
///
/// Returns the factor plus an `info` flag: `0` on success, otherwise the
/// 1-based column index where factorisation broke down (non-Hermitian entry
/// pair, non-finite or non-positive pivot, or a pivot with a significant
/// imaginary part). On failure, the upper-triangular entries from the
/// failing row onward are zeroed so callers never see half-written values.
fn chol_factor(matrix: RowMajorMatrix) -> BuiltinResult<CholComponents> {
    let n = matrix.rows;
    // Empty input factorises trivially to an empty factor.
    if n == 0 {
        return Ok(CholComponents {
            upper: RowMajorMatrix::zeros(0, 0),
            info: 0,
        });
    }
    let mut upper = RowMajorMatrix::zeros(n, n);
    let mut info = 0usize;

    'outer: for j in 0..n {
        // Verify Hermitian symmetry of column j against row j before use.
        for i in 0..j {
            if !hermitian_pair_matches(matrix.get(i, j), matrix.get(j, i)) {
                info = j + 1;
                break 'outer;
            }
        }

        for i in 0..=j {
            // sum = A[i,j] - sum_{k<i} conj(R[k,i]) * R[k,j]
            let mut sum = matrix.get(i, j);
            for k in 0..i {
                let rik = upper.get(k, i).conj();
                let rkj = upper.get(k, j);
                sum -= rik * rkj;
            }
            if i == j {
                // Diagonal pivot must be finite, effectively real, and
                // strictly positive for the matrix to be positive definite.
                let imag_tol = EPS * sum.re.abs().max(1.0);
                if !sum.re.is_finite()
                    || !sum.im.is_finite()
                    || sum.re <= 0.0
                    || sum.im.abs() > imag_tol
                {
                    info = j + 1;
                    break 'outer;
                }
                let diag = sum.re.sqrt();
                upper.set(i, i, Complex64::new(diag, 0.0));
            } else {
                // Off-diagonal entries divide by the (real) diagonal pivot;
                // guard against a vanishing denominator.
                let denom = upper.get(i, i);
                if denom.norm() <= EPS {
                    info = i + 1;
                    break 'outer;
                }
                upper.set(i, j, sum / denom);
            }
        }
    }

    // On failure, zero the upper triangle from the failing row so the
    // returned partial factor is deterministic.
    if info != 0 {
        let start = info.saturating_sub(1).min(n);
        for row in start..n {
            for col in row..n {
                upper.set(row, col, Complex64::new(0.0, 0.0));
            }
        }
    }

    Ok(CholComponents { upper, info })
}
291
292async fn extract_matrix(value: Value) -> BuiltinResult<RowMajorMatrix> {
293 match value {
294 Value::Tensor(tensor) => RowMajorMatrix::from_tensor(&tensor, "chol"),
295 Value::ComplexTensor(ct) => RowMajorMatrix::from_complex_tensor(&ct, "chol"),
296 Value::LogicalArray(logical) => {
297 let tensor = tensor::logical_to_tensor(&logical)
298 .map_err(|err| chol_error(format!("chol: {err}")))?;
299 RowMajorMatrix::from_tensor(&tensor, "chol")
300 }
301 Value::Num(n) => Ok(RowMajorMatrix::from_scalar(Complex64::new(n, 0.0))),
302 Value::Int(i) => Ok(RowMajorMatrix::from_scalar(Complex64::new(i.to_f64(), 0.0))),
303 Value::Bool(b) => Ok(RowMajorMatrix::from_scalar(Complex64::new(
304 if b { 1.0 } else { 0.0 },
305 0.0,
306 ))),
307 Value::Complex(re, im) => Ok(RowMajorMatrix::from_scalar(Complex64::new(re, im))),
308 Value::GpuTensor(handle) => {
309 let tensor = gpu_helpers::gather_tensor_async(&handle)
310 .await
311 .map_err(with_chol_context)?;
312 RowMajorMatrix::from_tensor(&tensor, "chol")
313 }
314 other => Err(chol_error(format!(
315 "chol: unsupported input type {:?}; expected numeric or logical values",
316 other
317 ))),
318 }
319}
320
321fn matrix_to_value(label: &str, matrix: &RowMajorMatrix) -> BuiltinResult<Value> {
322 let mut has_imag = false;
323 for val in &matrix.data {
324 if val.im.abs() > EPS {
325 has_imag = true;
326 break;
327 }
328 }
329 if has_imag {
330 let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
331 for col in 0..matrix.cols {
332 for row in 0..matrix.rows {
333 let idx = row * matrix.cols + col;
334 let v = matrix.data[idx];
335 data.push((v.re, v.im));
336 }
337 }
338 let tensor = ComplexTensor::new(data, vec![matrix.rows, matrix.cols])
339 .map_err(|e| chol_error(format!("{label}: {e}")))?;
340 Ok(random_args::complex_tensor_into_value(tensor))
341 } else {
342 let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
343 for col in 0..matrix.cols {
344 for row in 0..matrix.rows {
345 let idx = row * matrix.cols + col;
346 data.push(matrix.data[idx].re);
347 }
348 }
349 let tensor = Tensor::new(data, vec![matrix.rows, matrix.cols])
350 .map_err(|e| chol_error(format!("{label}: {e}")))?;
351 Ok(tensor::tensor_into_value(tensor))
352 }
353}
354
/// Raw host factorisation output: the upper factor plus the failure flag.
struct CholComponents {
    // Upper-triangular factor (partially zeroed when `info != 0`).
    upper: RowMajorMatrix,
    // 0 on success; otherwise the 1-based pivot index of the failure.
    info: usize,
}
359
/// Dense complex matrix stored in row-major order. Runtime tensors are
/// column-major; the `from_*` constructors perform the layout conversion.
#[derive(Clone)]
struct RowMajorMatrix {
    rows: usize,
    cols: usize,
    // Row-major storage: entry (r, c) lives at index r * cols + c.
    data: Vec<Complex64>,
}
366
367impl RowMajorMatrix {
368 fn zeros(rows: usize, cols: usize) -> Self {
369 Self {
370 rows,
371 cols,
372 data: vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)],
373 }
374 }
375
376 fn from_scalar(value: Complex64) -> Self {
377 Self {
378 rows: 1,
379 cols: 1,
380 data: vec![value],
381 }
382 }
383
384 fn from_tensor(tensor: &Tensor, label: &str) -> BuiltinResult<Self> {
385 if tensor.shape.len() > 2 {
386 return Err(chol_error(format!("{label}: input must be 2-D")));
387 }
388 let rows = tensor.rows();
389 let cols = tensor.cols();
390 let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
391 for col in 0..cols {
392 for row in 0..rows {
393 let idx_col_major = row + col * rows;
394 let idx_row_major = row * cols + col;
395 data[idx_row_major] = Complex64::new(tensor.data[idx_col_major], 0.0);
396 }
397 }
398 Ok(Self { rows, cols, data })
399 }
400
401 fn from_complex_tensor(tensor: &ComplexTensor, label: &str) -> BuiltinResult<Self> {
402 if tensor.shape.len() > 2 {
403 return Err(chol_error(format!("{label}: input must be 2-D")));
404 }
405 let rows = tensor.rows;
406 let cols = tensor.cols;
407 let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
408 for col in 0..cols {
409 for row in 0..rows {
410 let idx_col_major = row + col * rows;
411 let idx_row_major = row * cols + col;
412 let (re, im) = tensor.data[idx_col_major];
413 data[idx_row_major] = Complex64::new(re, im);
414 }
415 }
416 Ok(Self { rows, cols, data })
417 }
418
419 fn get(&self, row: usize, col: usize) -> Complex64 {
420 self.data[row * self.cols + col]
421 }
422
423 fn set(&mut self, row: usize, col: usize, value: Complex64) {
424 self.data[row * self.cols + col] = value;
425 }
426
427 fn conjugate_transpose(&self) -> Self {
428 let mut out = RowMajorMatrix::zeros(self.cols, self.rows);
429 for row in 0..self.rows {
430 for col in row..self.cols {
431 let value = self.get(row, col);
432 out.set(col, row, value.conj());
433 }
434 }
435 out
436 }
437}
438
// Unit tests for the chol builtin: reference reconstructions (A = R'R and
// A = LL') are used to validate factors rather than hard-coded expected
// matrices, covering real, complex, logical, scalar, and GPU inputs.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{LogicalArray, ResolveContext, Tensor as Matrix, Type};

    // Extract the display message from a runtime error.
    fn error_message(err: RuntimeError) -> String {
        err.message().to_string()
    }

    // Unwrap a real tensor result, promoting scalar `Num` to a 1x1 tensor.
    fn tensor_from_value(value: Value) -> Matrix {
        match value {
            Value::Tensor(t) => t,
            Value::Num(n) => Matrix::new(vec![n], vec![1, 1]).expect("tensor"),
            other => panic!("expected tensor value, got {other:?}"),
        }
    }

    #[test]
    fn chol_type_preserves_matrix_shape() {
        let out = matrix_unary_type(
            &[Type::Tensor {
                shape: Some(vec![Some(3), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(3), Some(3)])
            }
        );
    }

    // Rebuild A = R' * R from a column-major upper factor.
    fn reconstruct_from_upper(matrix: &Matrix) -> Matrix {
        let rows = matrix.rows();
        let cols = matrix.cols();
        assert_eq!(rows, cols, "expected square matrix");
        let mut data = vec![0.0; rows * cols];
        for i in 0..rows {
            for j in 0..rows {
                let mut sum = 0.0;
                for k in 0..rows {
                    let rik = if k <= i {
                        matrix.data[k + i * rows]
                    } else {
                        0.0
                    };
                    let rjk = if k <= j {
                        matrix.data[k + j * rows]
                    } else {
                        0.0
                    };
                    sum += rik * rjk;
                }
                data[i + j * rows] = sum;
            }
        }
        Matrix::new(data, vec![rows, rows]).expect("matrix")
    }

    // Rebuild A = L * L' from a column-major lower factor.
    fn reconstruct_from_lower(matrix: &Matrix) -> Matrix {
        let rows = matrix.rows();
        let cols = matrix.cols();
        assert_eq!(rows, cols, "expected square matrix");
        let mut data = vec![0.0; rows * cols];
        for i in 0..rows {
            for j in 0..rows {
                let mut sum = 0.0;
                for k in 0..rows {
                    let lik = if i >= k {
                        matrix.data[i + k * rows]
                    } else {
                        0.0
                    };
                    let ljk = if j >= k {
                        matrix.data[j + k * rows]
                    } else {
                        0.0
                    };
                    sum += lik * ljk;
                }
                data[i + j * rows] = sum;
            }
        }
        Matrix::new(data, vec![rows, rows]).expect("matrix")
    }

    // Elementwise absolute comparison of two real tensors.
    fn tensor_close(lhs: &Matrix, rhs: &Matrix, tol: f64) {
        assert_eq!(lhs.shape, rhs.shape, "shape mismatch");
        for (a, b) in lhs.data.iter().zip(rhs.data.iter()) {
            assert!(
                (a - b).abs() <= tol,
                "tensors differ: {a} vs {b} (tol {tol})"
            );
        }
    }

    // Unwrap a complex-capable result into a ComplexTensor.
    fn complex_tensor_from_value(value: Value) -> ComplexTensor {
        match value {
            Value::ComplexTensor(ct) => ct,
            Value::Complex(re, im) => {
                ComplexTensor::new(vec![(re, im)], vec![1, 1]).expect("complex tensor")
            }
            Value::Tensor(t) => {
                let data: Vec<(f64, f64)> = t.data.iter().map(|&v| (v, 0.0)).collect();
                ComplexTensor::new(data, t.shape.clone()).expect("complex tensor")
            }
            Value::Num(n) => {
                ComplexTensor::new(vec![(n, 0.0)], vec![1, 1]).expect("complex tensor")
            }
            other => panic!("expected complex-capable value, got {other:?}"),
        }
    }

    // Rebuild A = R^H * R from a column-major complex upper factor.
    fn reconstruct_complex_upper(matrix: &ComplexTensor) -> ComplexTensor {
        let rows = matrix.rows;
        let cols = matrix.cols;
        assert_eq!(rows, cols, "expected square matrix");
        let mut data = vec![(0.0, 0.0); rows * rows];
        for i in 0..rows {
            for j in 0..rows {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..rows {
                    let rik = if k <= i {
                        let (re, im) = matrix.data[k + i * rows];
                        Complex64::new(re, im)
                    } else {
                        Complex64::new(0.0, 0.0)
                    };
                    let rjk = if k <= j {
                        let (re, im) = matrix.data[k + j * rows];
                        Complex64::new(re, im)
                    } else {
                        Complex64::new(0.0, 0.0)
                    };
                    sum += rik.conj() * rjk;
                }
                data[i + j * rows] = (sum.re, sum.im);
            }
        }
        ComplexTensor::new(data, vec![rows, rows]).expect("complex tensor")
    }

    // Rebuild A = L * L^H from a column-major complex lower factor.
    fn reconstruct_complex_lower(matrix: &ComplexTensor) -> ComplexTensor {
        let rows = matrix.rows;
        let cols = matrix.cols;
        assert_eq!(rows, cols, "expected square matrix");
        let mut data = vec![(0.0, 0.0); rows * rows];
        for i in 0..rows {
            for j in 0..rows {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..rows {
                    let lik = if i >= k {
                        let (re, im) = matrix.data[i + k * rows];
                        Complex64::new(re, im)
                    } else {
                        Complex64::new(0.0, 0.0)
                    };
                    let ljk = if j >= k {
                        let (re, im) = matrix.data[j + k * rows];
                        Complex64::new(re, im)
                    } else {
                        Complex64::new(0.0, 0.0)
                    };
                    sum += lik * ljk.conj();
                }
                data[i + j * rows] = (sum.re, sum.im);
            }
        }
        ComplexTensor::new(data, vec![rows, rows]).expect("complex tensor")
    }

    // Elementwise norm comparison of two complex tensors.
    fn complex_tensor_close(lhs: &ComplexTensor, rhs: &ComplexTensor, tol: f64) {
        assert_eq!(lhs.shape, rhs.shape, "shape mismatch");
        for ((ar, ai), (br, bi)) in lhs.data.iter().zip(rhs.data.iter()) {
            let a = Complex64::new(*ar, *ai);
            let b = Complex64::new(*br, *bi);
            assert!(
                (a - b).norm() <= tol,
                "tensors differ: {a:?} vs {b:?} (tol {tol})"
            );
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_upper_factor_matches_reference() {
        let a = Matrix::new(
            vec![
                4.0, 12.0, -16.0, 12.0, 37.0, -43.0, -16.0, -43.0, 98.0,
            ],
            vec![3, 3],
        )
        .unwrap();
        let r = chol_builtin(Value::Tensor(a.clone()), Vec::new()).expect("chol");
        let r_tensor = tensor_from_value(r);
        assert_eq!(r_tensor.shape, vec![3, 3]);
        for diag in 0..3 {
            let value = r_tensor.data[diag + diag * 3];
            assert!(value > 0.0, "Cholesky diagonal must be positive");
        }
        let recon = reconstruct_from_upper(&r_tensor);
        tensor_close(&recon, &a, 1e-10);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_upper_option_matches_default() {
        let a = Matrix::new(
            vec![
                7.0, 2.0, 1.0, 2.0, 5.0, 2.0, 1.0, 2.0, 3.0,
            ],
            vec![3, 3],
        )
        .unwrap();
        let default = chol_builtin(Value::Tensor(a.clone()), Vec::new()).expect("chol");
        let explicit =
            chol_builtin(Value::Tensor(a.clone()), vec![Value::from("upper")]).expect("chol upper");
        let default_tensor = tensor_from_value(default);
        let explicit_tensor = tensor_from_value(explicit);
        tensor_close(&default_tensor, &explicit_tensor, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_lower_option_returns_lower_factor() {
        let a = Matrix::new(
            vec![
                25.0, 15.0, -5.0, 15.0, 18.0, 0.0, -5.0, 0.0, 11.0,
            ],
            vec![3, 3],
        )
        .unwrap();
        let result =
            chol_builtin(Value::Tensor(a.clone()), vec![Value::from("lower")]).expect("chol");
        let l = tensor_from_value(result);
        assert_eq!(l.shape, vec![3, 3]);
        for diag in 0..3 {
            let value = l.data[diag + diag * 3];
            assert!(value > 0.0, "Cholesky diagonal must be positive");
        }
        let recon = reconstruct_from_lower(&l);
        tensor_close(&recon, &a, 1e-10);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_two_output_lower_variant() {
        let a = Matrix::new(
            vec![
                9.0, 3.0, 3.0, 3.0, 5.0, 1.0, 3.0, 1.0, 7.0,
            ],
            vec![3, 3],
        )
        .unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[Value::from("lower")]).expect("chol eval");
        assert_eq!(eval.flag_index(), 0);
        assert_eq!(eval.triangle(), CholTriangle::Lower);
        let factor = tensor_from_value(eval.factor());
        let recon = reconstruct_from_lower(&factor);
        tensor_close(&recon, &a, 1e-10);
    }

    // Indefinite matrix: flag reports the failing pivot, factor is partial.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_two_output_reports_failure() {
        let a = Matrix::new(vec![1.0, 2.0, 2.0, 1.0], vec![2, 2]).expect("matrix");
        let eval = evaluate(Value::Tensor(a), &[]).expect("chol eval");
        assert_eq!(eval.flag_index(), 2);
        let factor = tensor_from_value(eval.factor());
        assert_eq!(factor.shape, vec![2, 2]);
        assert!((factor.data[0] - 1.0).abs() < 1e-12);
        assert!((factor.data[1] - 0.0).abs() < 1e-12);
        assert!((factor.data[2] - 2.0).abs() < 1e-12);
        assert!((factor.data[3] - 0.0).abs() < 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_single_output_errors_on_failure() {
        let a = Matrix::new(vec![1.0, 2.0, 2.0, 1.0], vec![2, 2]).expect("matrix");
        let err = error_message(chol_builtin(Value::Tensor(a), Vec::new()).unwrap_err());
        assert!(err.contains("positive definite"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_invalid_option_errors() {
        let a = Matrix::new(vec![4.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let err = error_message(
            chol_builtin(Value::Tensor(a), vec![Value::from("diagonal")]).unwrap_err(),
        );
        assert!(err.to_ascii_lowercase().contains("unknown option"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_non_square_errors() {
        let a = Matrix::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let err = error_message(chol_builtin(Value::Tensor(a), Vec::new()).unwrap_err());
        assert!(err.to_ascii_lowercase().contains("square"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_empty_matrix_returns_empty() {
        let empty = Matrix::new(Vec::<f64>::new(), vec![0, 0]).unwrap();
        let eval = evaluate(Value::Tensor(empty.clone()), &[]).expect("chol eval");
        assert_eq!(eval.flag_index(), 0);
        let factor = tensor_from_value(eval.factor());
        assert_eq!(factor.shape, vec![0, 0]);
        assert!(factor.data.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_non_hermitian_reports_failure() {
        let a = Matrix::new(vec![2.0, 1.0, 0.0, 2.0], vec![2, 2]).expect("matrix");
        let eval = evaluate(Value::Tensor(a), &[]).expect("chol eval");
        assert_eq!(eval.flag_index(), 2);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_logical_input_factorizes() {
        let logical = LogicalArray::new(vec![1, 0, 0, 1], vec![2, 2]).expect("logical array");
        let result = chol_builtin(Value::LogicalArray(logical), Vec::new()).expect("chol");
        let factor = tensor_from_value(result);
        let recon = reconstruct_from_upper(&factor);
        let identity = Matrix::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        tensor_close(&recon, &identity, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_complex_positive_definite() {
        let complex = ComplexTensor::new(
            vec![(5.0, 0.0), (1.0, 2.0), (1.0, -2.0), (4.0, 0.0)],
            vec![2, 2],
        )
        .unwrap();
        let eval = evaluate(Value::ComplexTensor(complex.clone()), &[]).expect("chol eval");
        assert_eq!(eval.flag_index(), 0);
        let factor = complex_tensor_from_value(eval.factor());
        let recon = reconstruct_complex_upper(&factor);
        complex_tensor_close(&recon, &complex, 1e-10);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_complex_lower_variant() {
        let complex = ComplexTensor::new(
            vec![(5.0, 0.0), (1.0, 2.0), (1.0, -2.0), (4.0, 0.0)],
            vec![2, 2],
        )
        .unwrap();
        let eval = evaluate(
            Value::ComplexTensor(complex.clone()),
            &[Value::from("lower")],
        )
        .expect("chol eval");
        assert_eq!(eval.flag_index(), 0);
        let factor = complex_tensor_from_value(eval.factor());
        let recon = reconstruct_complex_lower(&factor);
        complex_tensor_close(&recon, &complex, 1e-10);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let a = Matrix::new(vec![6.0, 2.0, 2.0, 5.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = chol_builtin(Value::GpuTensor(handle), Vec::new()).expect("chol");
            let gathered = test_support::gather(result).expect("gather");
            let recon = reconstruct_from_upper(&gathered);
            tensor_close(&recon, &a, 1e-10);
        });
    }

    // Provider path must mirror host failure semantics (flag + partial factor).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_gpu_failure_flag() {
        test_support::with_test_provider(|provider| {
            let a = Matrix::new(vec![1.0, 2.0, 2.0, 1.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("chol eval");
            assert_eq!(eval.flag_index(), 2);
            let factor = eval.factor();
            assert!(matches!(factor, Value::GpuTensor(_)));
            let gathered = test_support::gather(factor).expect("gather factor");
            assert!((gathered.data[0] - 1.0).abs() < 1e-12);
            assert!((gathered.data[1] - 0.0).abs() < 1e-12);
            assert!((gathered.data[2] - 2.0).abs() < 1e-12);
            assert!((gathered.data[3] - 0.0).abs() < 1e-12);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn chol_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .expect("register wgpu provider");

        // Tolerance depends on the provider's floating-point precision.
        let tol = match runmat_accelerate_api::provider()
            .expect("provider")
            .precision()
        {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };

        let tensor = Matrix::new(
            vec![
                10.0, 2.0, 3.0, 2.0, 9.0, 1.0, 3.0, 1.0, 7.0,
            ],
            vec![3, 3],
        )
        .unwrap();

        let host_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("host eval");
        let host_factor = tensor_from_value(host_eval.factor());

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");

        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu eval");
        assert_eq!(gpu_eval.flag_index(), 0, "gpu chol should succeed");
        let gpu_factor = test_support::gather(gpu_eval.factor()).expect("gather factor");

        tensor_close(&gpu_factor, &host_factor, tol);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_accepts_scalar() {
        let result = chol_builtin(Value::Num(9.0), Vec::new()).expect("chol");
        match result {
            Value::Num(n) => assert!((n - 3.0).abs() < 1e-12),
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert!((t.data[0] - 3.0).abs() < 1e-12);
            }
            other => panic!("expected scalar-like, got {other:?}"),
        }
    }

    // Synchronous wrappers so tests can drive the async builtin directly.
    fn chol_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::chol_builtin(value, rest))
    }

    fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<CholEval> {
        block_on(super::evaluate(value, args))
    }
}