Skip to main content

runmat_runtime/builtins/math/linalg/factor/
chol.rs

1//! MATLAB-compatible `chol` builtin with upper/lower forms and failure flag.
2//!
3//! This implementation matches MATLAB semantics for dense matrices, including
4//! the two-output form that reports the leading minor index when a matrix fails
5//! the positive-definiteness test. GPU execution is delegated to acceleration
6//! providers when available, with automatic host fallbacks.
7
8use crate::builtins::common::spec::{
9    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
11};
12use crate::builtins::common::{gpu_helpers, random_args, tensor};
13use crate::builtins::math::linalg::type_resolvers::matrix_unary_type;
14use crate::{build_runtime_error, BuiltinResult, RuntimeError};
15use num_complex::Complex64;
16use runmat_accelerate_api::{GpuTensorHandle, ProviderCholResult};
17use runmat_builtins::{ComplexTensor, Tensor, Value};
18use runmat_macros::runtime_builtin;
19
/// Builtin name attached to error context and used for provider hook lookup.
const BUILTIN_NAME: &str = "chol";
21
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::linalg::factor::chol")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "chol",
    // Factorization is a bespoke provider operation, not a generic elementwise/reduce kernel.
    op_kind: GpuOpKind::Custom("chol-factor"),
    supported_precisions: &[ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    // Providers implement the factorization behind their custom "chol" hook.
    provider_hooks: &[ProviderHook::Custom("chol")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // The factor comes back as a fresh GPU handle; the input handle is left alone.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Uses the provider 'chol' hook when present; otherwise gathers to the host implementation.",
};
38
39fn chol_error(message: impl Into<String>) -> RuntimeError {
40    build_runtime_error(message)
41        .with_builtin(BUILTIN_NAME)
42        .build()
43}
44
45fn with_chol_context(mut error: RuntimeError) -> RuntimeError {
46    if error.message() == "interaction pending..." {
47        return build_runtime_error("interaction pending...")
48            .with_builtin(BUILTIN_NAME)
49            .build();
50    }
51    if error.context.builtin.is_none() {
52        error.context = error.context.with_builtin(BUILTIN_NAME);
53    }
54    error
55}
56
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::linalg::factor::chol")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "chol",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // No elementwise/reduction entry points: chol is opted out of fusion entirely.
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Factorisation executes eagerly and does not participate in expression fusion.",
};
67
68#[runtime_builtin(
69    name = "chol",
70    category = "math/linalg/factor",
71    summary = "Cholesky factorization with MATLAB-compatible upper and lower forms.",
72    keywords = "chol,cholesky,factorization,positive-definite",
73    accel = "sink",
74    sink = true,
75    type_resolver(matrix_unary_type),
76    builtin_path = "crate::builtins::math::linalg::factor::chol"
77)]
78async fn chol_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
79    let eval = evaluate(value, &rest).await?;
80    if let Some(out_count) = crate::output_count::current_output_count() {
81        if out_count == 0 {
82            return Ok(Value::OutputList(Vec::new()));
83        }
84        if out_count == 1 {
85            if !eval.is_positive_definite() {
86                return Err(chol_error("Matrix must be positive definite."));
87            }
88            return Ok(Value::OutputList(vec![eval.factor()]));
89        }
90        if out_count == 2 {
91            return Ok(Value::OutputList(vec![eval.factor(), eval.flag()]));
92        }
93        return Err(chol_error("chol currently supports at most two outputs"));
94    }
95    if !eval.is_positive_definite() {
96        return Err(chol_error("Matrix must be positive definite."));
97    }
98    Ok(eval.factor())
99}
100
/// Evaluate `chol` while keeping both the factor and the failure index available.
#[derive(Clone)]
pub struct CholEval {
    /// Factor value (host tensor or GPU handle) for the requested triangle.
    factor: Value,
    /// 0 on success; otherwise the 1-based index of the failing leading minor.
    flag: usize,
    /// Which triangle variant the caller requested.
    triangle: CholTriangle,
}
108
109impl CholEval {
110    /// The factor (`R` or `L`) requested by the caller.
111    pub fn factor(&self) -> Value {
112        self.factor.clone()
113    }
114
115    /// MATLAB-compatible failure index (0 indicates success).
116    pub fn flag(&self) -> Value {
117        Value::Num(self.flag as f64)
118    }
119
120    /// Zero-based flag value (0 indicates success).
121    pub fn flag_index(&self) -> usize {
122        self.flag
123    }
124
125    /// The triangle variant that was requested.
126    pub fn triangle(&self) -> CholTriangle {
127        self.triangle
128    }
129
130    /// Returns true when the input matrix was positive definite.
131    pub fn is_positive_definite(&self) -> bool {
132        self.flag == 0
133    }
134
135    fn from_components(components: CholComponents, triangle: CholTriangle) -> BuiltinResult<Self> {
136        let factor_matrix = match triangle {
137            CholTriangle::Upper => components.upper.clone(),
138            CholTriangle::Lower => components.upper.conjugate_transpose(),
139        };
140        let factor = matrix_to_value("chol", &factor_matrix)?;
141        Ok(Self {
142            factor,
143            flag: components.info,
144            triangle,
145        })
146    }
147
148    fn from_provider(result: ProviderCholResult, triangle: CholTriangle) -> Self {
149        Self {
150            factor: Value::GpuTensor(result.factor),
151            flag: result.info as usize,
152            triangle,
153        }
154    }
155}
156
/// Triangle variant for the Cholesky factor.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CholTriangle {
    /// Default form: upper factor `R` with `A = R' * R`.
    Upper,
    /// `'lower'` option: lower factor `L` with `A = L * L'`.
    Lower,
}
163
164/// Compute the Cholesky factorization for the given value and option list.
165pub async fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<CholEval> {
166    let triangle = parse_triangle(args)?;
167    match value {
168        Value::GpuTensor(handle) => {
169            if let Some(eval) = evaluate_gpu(&handle, triangle).await? {
170                return Ok(eval);
171            }
172            let tensor = gpu_helpers::gather_tensor_async(&handle)
173                .await
174                .map_err(with_chol_context)?;
175            evaluate_host_value(Value::Tensor(tensor), triangle).await
176        }
177        other => evaluate_host_value(other, triangle).await,
178    }
179}
180
181async fn evaluate_host_value(value: Value, triangle: CholTriangle) -> BuiltinResult<CholEval> {
182    let matrix = extract_matrix(value).await?;
183    if matrix.rows != matrix.cols {
184        return Err(chol_error("chol: input matrix must be square"));
185    }
186    let components = chol_factor(matrix)?;
187    CholEval::from_components(components, triangle)
188}
189
190async fn evaluate_gpu(
191    handle: &GpuTensorHandle,
192    triangle: CholTriangle,
193) -> BuiltinResult<Option<CholEval>> {
194    if let Some(provider) = runmat_accelerate_api::provider() {
195        let lower = matches!(triangle, CholTriangle::Lower);
196        if let Ok(result) = provider.chol(handle, lower).await {
197            return Ok(Some(CholEval::from_provider(result, triangle)));
198        }
199    }
200    Ok(None)
201}
202
203fn parse_triangle(args: &[Value]) -> BuiltinResult<CholTriangle> {
204    if args.is_empty() {
205        return Ok(CholTriangle::Upper);
206    }
207    if args.len() > 1 {
208        return Err(chol_error("chol: too many option arguments"));
209    }
210    let Some(option) = tensor::value_to_string(&args[0]) else {
211        return Err(chol_error(
212            "chol: option must be a string or character vector",
213        ));
214    };
215    match option.trim().to_ascii_lowercase().as_str() {
216        "upper" => Ok(CholTriangle::Upper),
217        "lower" => Ok(CholTriangle::Lower),
218        other => Err(chol_error(format!("chol: unknown option '{other}'"))),
219    }
220}
221
/// Tolerance used for the Hermitian-symmetry check, the imaginary-part test on
/// diagonal pivots, the zero-pivot guard, and the real-vs-complex output decision.
const EPS: f64 = 1.0e-12;
223
224#[inline]
225fn hermitian_pair_matches(a: Complex64, b: Complex64) -> bool {
226    let diff = a - b.conj();
227    let scale = a.norm().max(b.norm()).max(1.0);
228    diff.norm() <= EPS * scale
229}
230
/// Dense upper Cholesky factorization of a row-major complex matrix.
///
/// Produces `R` with `A = R' * R` and an `info` flag: 0 on success, otherwise
/// the 1-based index of the leading minor that failed either the Hermitian
/// symmetry check or the positive-pivot test (mirroring the two-output flag
/// semantics described in the module docs).
fn chol_factor(matrix: RowMajorMatrix) -> BuiltinResult<CholComponents> {
    let n = matrix.rows;
    if n == 0 {
        // Empty input: empty factor, success flag.
        return Ok(CholComponents {
            upper: RowMajorMatrix::zeros(0, 0),
            info: 0,
        });
    }
    let mut upper = RowMajorMatrix::zeros(n, n);
    let mut info = 0usize;

    'outer: for j in 0..n {
        // Verify Hermitian symmetry of column j against row j before using it;
        // a mismatch is reported via the flag, not an error.
        for i in 0..j {
            if !hermitian_pair_matches(matrix.get(i, j), matrix.get(j, i)) {
                info = j + 1;
                break 'outer;
            }
        }

        for i in 0..=j {
            // sum = A(i,j) - sum_k conj(R(k,i)) * R(k,j)
            let mut sum = matrix.get(i, j);
            for k in 0..i {
                let rik = upper.get(k, i).conj();
                let rkj = upper.get(k, j);
                sum -= rik * rkj;
            }
            if i == j {
                // Diagonal pivot must be finite, strictly positive, and real
                // to within a relative tolerance on its imaginary part.
                let imag_tol = EPS * sum.re.abs().max(1.0);
                if !sum.re.is_finite()
                    || !sum.im.is_finite()
                    || sum.re <= 0.0
                    || sum.im.abs() > imag_tol
                {
                    info = j + 1;
                    break 'outer;
                }
                let diag = sum.re.sqrt();
                upper.set(i, i, Complex64::new(diag, 0.0));
            } else {
                // Guard against dividing by a (near-)zero pivot computed earlier.
                let denom = upper.get(i, i);
                if denom.norm() <= EPS {
                    info = i + 1;
                    break 'outer;
                }
                upper.set(i, j, sum / denom);
            }
        }
    }

    if info != 0 {
        // On failure, zero the rows at and after the failing minor so the
        // two-output form returns the expected partial factor; columns before
        // the failure point keep their computed values.
        let start = info.saturating_sub(1).min(n);
        for row in start..n {
            for col in row..n {
                upper.set(row, col, Complex64::new(0.0, 0.0));
            }
        }
    }

    Ok(CholComponents { upper, info })
}
291
292async fn extract_matrix(value: Value) -> BuiltinResult<RowMajorMatrix> {
293    match value {
294        Value::Tensor(tensor) => RowMajorMatrix::from_tensor(&tensor, "chol"),
295        Value::ComplexTensor(ct) => RowMajorMatrix::from_complex_tensor(&ct, "chol"),
296        Value::LogicalArray(logical) => {
297            let tensor = tensor::logical_to_tensor(&logical)
298                .map_err(|err| chol_error(format!("chol: {err}")))?;
299            RowMajorMatrix::from_tensor(&tensor, "chol")
300        }
301        Value::Num(n) => Ok(RowMajorMatrix::from_scalar(Complex64::new(n, 0.0))),
302        Value::Int(i) => Ok(RowMajorMatrix::from_scalar(Complex64::new(i.to_f64(), 0.0))),
303        Value::Bool(b) => Ok(RowMajorMatrix::from_scalar(Complex64::new(
304            if b { 1.0 } else { 0.0 },
305            0.0,
306        ))),
307        Value::Complex(re, im) => Ok(RowMajorMatrix::from_scalar(Complex64::new(re, im))),
308        Value::GpuTensor(handle) => {
309            let tensor = gpu_helpers::gather_tensor_async(&handle)
310                .await
311                .map_err(with_chol_context)?;
312            RowMajorMatrix::from_tensor(&tensor, "chol")
313        }
314        other => Err(chol_error(format!(
315            "chol: unsupported input type {:?}; expected numeric or logical values",
316            other
317        ))),
318    }
319}
320
321fn matrix_to_value(label: &str, matrix: &RowMajorMatrix) -> BuiltinResult<Value> {
322    let mut has_imag = false;
323    for val in &matrix.data {
324        if val.im.abs() > EPS {
325            has_imag = true;
326            break;
327        }
328    }
329    if has_imag {
330        let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
331        for col in 0..matrix.cols {
332            for row in 0..matrix.rows {
333                let idx = row * matrix.cols + col;
334                let v = matrix.data[idx];
335                data.push((v.re, v.im));
336            }
337        }
338        let tensor = ComplexTensor::new(data, vec![matrix.rows, matrix.cols])
339            .map_err(|e| chol_error(format!("{label}: {e}")))?;
340        Ok(random_args::complex_tensor_into_value(tensor))
341    } else {
342        let mut data = Vec::with_capacity(matrix.rows * matrix.cols);
343        for col in 0..matrix.cols {
344            for row in 0..matrix.rows {
345                let idx = row * matrix.cols + col;
346                data.push(matrix.data[idx].re);
347            }
348        }
349        let tensor = Tensor::new(data, vec![matrix.rows, matrix.cols])
350            .map_err(|e| chol_error(format!("{label}: {e}")))?;
351        Ok(tensor::tensor_into_value(tensor))
352    }
353}
354
/// Raw output of `chol_factor`: the upper factor plus the failure index.
struct CholComponents {
    /// Upper-triangular factor; rows at/after a failure point are zeroed.
    upper: RowMajorMatrix,
    /// 0 on success; otherwise the 1-based index of the failing leading minor.
    info: usize,
}
359
/// Dense row-major complex matrix used as the internal working representation
/// (runtime tensors are column-major; conversion happens at the boundaries).
#[derive(Clone)]
struct RowMajorMatrix {
    rows: usize,
    cols: usize,
    /// Row-major storage: element (r, c) lives at index `r * cols + c`.
    data: Vec<Complex64>,
}
366
367impl RowMajorMatrix {
368    fn zeros(rows: usize, cols: usize) -> Self {
369        Self {
370            rows,
371            cols,
372            data: vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)],
373        }
374    }
375
376    fn from_scalar(value: Complex64) -> Self {
377        Self {
378            rows: 1,
379            cols: 1,
380            data: vec![value],
381        }
382    }
383
384    fn from_tensor(tensor: &Tensor, label: &str) -> BuiltinResult<Self> {
385        if tensor.shape.len() > 2 {
386            return Err(chol_error(format!("{label}: input must be 2-D")));
387        }
388        let rows = tensor.rows();
389        let cols = tensor.cols();
390        let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
391        for col in 0..cols {
392            for row in 0..rows {
393                let idx_col_major = row + col * rows;
394                let idx_row_major = row * cols + col;
395                data[idx_row_major] = Complex64::new(tensor.data[idx_col_major], 0.0);
396            }
397        }
398        Ok(Self { rows, cols, data })
399    }
400
401    fn from_complex_tensor(tensor: &ComplexTensor, label: &str) -> BuiltinResult<Self> {
402        if tensor.shape.len() > 2 {
403            return Err(chol_error(format!("{label}: input must be 2-D")));
404        }
405        let rows = tensor.rows;
406        let cols = tensor.cols;
407        let mut data = vec![Complex64::new(0.0, 0.0); rows.saturating_mul(cols)];
408        for col in 0..cols {
409            for row in 0..rows {
410                let idx_col_major = row + col * rows;
411                let idx_row_major = row * cols + col;
412                let (re, im) = tensor.data[idx_col_major];
413                data[idx_row_major] = Complex64::new(re, im);
414            }
415        }
416        Ok(Self { rows, cols, data })
417    }
418
419    fn get(&self, row: usize, col: usize) -> Complex64 {
420        self.data[row * self.cols + col]
421    }
422
423    fn set(&mut self, row: usize, col: usize, value: Complex64) {
424        self.data[row * self.cols + col] = value;
425    }
426
427    fn conjugate_transpose(&self) -> Self {
428        let mut out = RowMajorMatrix::zeros(self.cols, self.rows);
429        for row in 0..self.rows {
430            for col in row..self.cols {
431                let value = self.get(row, col);
432                out.set(col, row, value.conj());
433            }
434        }
435        out
436    }
437}
438
#[cfg(test)]
pub(crate) mod tests {
    //! Host-side unit tests for `chol`: upper/lower factors, the two-output
    //! failure flag, complex/logical/scalar inputs, and GPU provider paths.

    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{LogicalArray, ResolveContext, Tensor as Matrix, Type};

    // Extract the plain message text from a runtime error.
    fn error_message(err: RuntimeError) -> String {
        err.message().to_string()
    }

    // Unwrap a runtime value into a dense tensor (numeric scalars become 1x1).
    fn tensor_from_value(value: Value) -> Matrix {
        match value {
            Value::Tensor(t) => t,
            Value::Num(n) => Matrix::new(vec![n], vec![1, 1]).expect("tensor"),
            other => panic!("expected tensor value, got {other:?}"),
        }
    }

    // The type resolver should map a 3x3 tensor to a 3x3 tensor.
    #[test]
    fn chol_type_preserves_matrix_shape() {
        let out = matrix_unary_type(
            &[Type::Tensor {
                shape: Some(vec![Some(3), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(3), Some(3)])
            }
        );
    }

    // Rebuild A = R' * R from an upper factor stored column-major.
    fn reconstruct_from_upper(matrix: &Matrix) -> Matrix {
        let rows = matrix.rows();
        let cols = matrix.cols();
        assert_eq!(rows, cols, "expected square matrix");
        let mut data = vec![0.0; rows * cols];
        // Compute R' * R for validation (column-major input)
        for i in 0..rows {
            for j in 0..rows {
                let mut sum = 0.0;
                for k in 0..rows {
                    // R(k, i): zero below the diagonal of the upper factor.
                    let rik = if k <= i {
                        matrix.data[k + i * rows]
                    } else {
                        0.0
                    };
                    let rjk = if k <= j {
                        matrix.data[k + j * rows]
                    } else {
                        0.0
                    };
                    sum += rik * rjk;
                }
                data[i + j * rows] = sum;
            }
        }
        Matrix::new(data, vec![rows, rows]).expect("matrix")
    }

    // Rebuild A = L * L' from a lower factor stored column-major.
    fn reconstruct_from_lower(matrix: &Matrix) -> Matrix {
        let rows = matrix.rows();
        let cols = matrix.cols();
        assert_eq!(rows, cols, "expected square matrix");
        let mut data = vec![0.0; rows * cols];
        for i in 0..rows {
            for j in 0..rows {
                let mut sum = 0.0;
                for k in 0..rows {
                    // L(i, k): zero above the diagonal of the lower factor.
                    let lik = if i >= k {
                        matrix.data[i + k * rows]
                    } else {
                        0.0
                    };
                    let ljk = if j >= k {
                        matrix.data[j + k * rows]
                    } else {
                        0.0
                    };
                    sum += lik * ljk;
                }
                data[i + j * rows] = sum;
            }
        }
        Matrix::new(data, vec![rows, rows]).expect("matrix")
    }

    // Assert two real tensors agree elementwise within an absolute tolerance.
    fn tensor_close(lhs: &Matrix, rhs: &Matrix, tol: f64) {
        assert_eq!(lhs.shape, rhs.shape, "shape mismatch");
        for (a, b) in lhs.data.iter().zip(rhs.data.iter()) {
            assert!(
                (a - b).abs() <= tol,
                "tensors differ: {a} vs {b} (tol {tol})"
            );
        }
    }

    // Unwrap a runtime value into a complex tensor, widening real values.
    fn complex_tensor_from_value(value: Value) -> ComplexTensor {
        match value {
            Value::ComplexTensor(ct) => ct,
            Value::Complex(re, im) => {
                ComplexTensor::new(vec![(re, im)], vec![1, 1]).expect("complex tensor")
            }
            Value::Tensor(t) => {
                let data: Vec<(f64, f64)> = t.data.iter().map(|&v| (v, 0.0)).collect();
                ComplexTensor::new(data, t.shape.clone()).expect("complex tensor")
            }
            Value::Num(n) => {
                ComplexTensor::new(vec![(n, 0.0)], vec![1, 1]).expect("complex tensor")
            }
            other => panic!("expected complex-capable value, got {other:?}"),
        }
    }

    // Rebuild A = R' * R for a complex upper factor (conjugating R').
    fn reconstruct_complex_upper(matrix: &ComplexTensor) -> ComplexTensor {
        let rows = matrix.rows;
        let cols = matrix.cols;
        assert_eq!(rows, cols, "expected square matrix");
        let mut data = vec![(0.0, 0.0); rows * rows];
        for i in 0..rows {
            for j in 0..rows {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..rows {
                    let rik = if k <= i {
                        let (re, im) = matrix.data[k + i * rows];
                        Complex64::new(re, im)
                    } else {
                        Complex64::new(0.0, 0.0)
                    };
                    let rjk = if k <= j {
                        let (re, im) = matrix.data[k + j * rows];
                        Complex64::new(re, im)
                    } else {
                        Complex64::new(0.0, 0.0)
                    };
                    sum += rik.conj() * rjk;
                }
                data[i + j * rows] = (sum.re, sum.im);
            }
        }
        ComplexTensor::new(data, vec![rows, rows]).expect("complex tensor")
    }

    // Rebuild A = L * L' for a complex lower factor (conjugating L').
    fn reconstruct_complex_lower(matrix: &ComplexTensor) -> ComplexTensor {
        let rows = matrix.rows;
        let cols = matrix.cols;
        assert_eq!(rows, cols, "expected square matrix");
        let mut data = vec![(0.0, 0.0); rows * rows];
        for i in 0..rows {
            for j in 0..rows {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..rows {
                    let lik = if i >= k {
                        let (re, im) = matrix.data[i + k * rows];
                        Complex64::new(re, im)
                    } else {
                        Complex64::new(0.0, 0.0)
                    };
                    let ljk = if j >= k {
                        let (re, im) = matrix.data[j + k * rows];
                        Complex64::new(re, im)
                    } else {
                        Complex64::new(0.0, 0.0)
                    };
                    sum += lik * ljk.conj();
                }
                data[i + j * rows] = (sum.re, sum.im);
            }
        }
        ComplexTensor::new(data, vec![rows, rows]).expect("complex tensor")
    }

    // Assert two complex tensors agree elementwise within an absolute tolerance.
    fn complex_tensor_close(lhs: &ComplexTensor, rhs: &ComplexTensor, tol: f64) {
        assert_eq!(lhs.shape, rhs.shape, "shape mismatch");
        for ((ar, ai), (br, bi)) in lhs.data.iter().zip(rhs.data.iter()) {
            let a = Complex64::new(*ar, *ai);
            let b = Complex64::new(*br, *bi);
            assert!(
                (a - b).norm() <= tol,
                "tensors differ: {a:?} vs {b:?} (tol {tol})"
            );
        }
    }

    // Default (upper) factorization round-trips: R' * R == A, positive diagonal.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_upper_factor_matches_reference() {
        let a = Matrix::new(
            vec![
                4.0, 12.0, -16.0, //
                12.0, 37.0, -43.0, //
                -16.0, -43.0, 98.0,
            ],
            vec![3, 3],
        )
        .unwrap();
        let r = chol_builtin(Value::Tensor(a.clone()), Vec::new()).expect("chol");
        let r_tensor = tensor_from_value(r);
        assert_eq!(r_tensor.shape, vec![3, 3]);
        for diag in 0..3 {
            let value = r_tensor.data[diag + diag * 3];
            assert!(value > 0.0, "Cholesky diagonal must be positive");
        }
        let recon = reconstruct_from_upper(&r_tensor);
        tensor_close(&recon, &a, 1e-10);
    }

    // Explicit 'upper' option must match the no-option default.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_upper_option_matches_default() {
        let a = Matrix::new(
            vec![
                7.0, 2.0, 1.0, //
                2.0, 5.0, 2.0, //
                1.0, 2.0, 3.0,
            ],
            vec![3, 3],
        )
        .unwrap();
        let default = chol_builtin(Value::Tensor(a.clone()), Vec::new()).expect("chol");
        let explicit =
            chol_builtin(Value::Tensor(a.clone()), vec![Value::from("upper")]).expect("chol upper");
        let default_tensor = tensor_from_value(default);
        let explicit_tensor = tensor_from_value(explicit);
        tensor_close(&default_tensor, &explicit_tensor, 1e-12);
    }

    // 'lower' option returns L with L * L' == A.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_lower_option_returns_lower_factor() {
        let a = Matrix::new(
            vec![
                25.0, 15.0, -5.0, //
                15.0, 18.0, 0.0, //
                -5.0, 0.0, 11.0,
            ],
            vec![3, 3],
        )
        .unwrap();
        let result =
            chol_builtin(Value::Tensor(a.clone()), vec![Value::from("lower")]).expect("chol");
        let l = tensor_from_value(result);
        assert_eq!(l.shape, vec![3, 3]);
        for diag in 0..3 {
            let value = l.data[diag + diag * 3];
            assert!(value > 0.0, "Cholesky diagonal must be positive");
        }
        let recon = reconstruct_from_lower(&l);
        tensor_close(&recon, &a, 1e-10);
    }

    // Two-output evaluation with the 'lower' option: flag 0, valid L factor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_two_output_lower_variant() {
        let a = Matrix::new(
            vec![
                9.0, 3.0, 3.0, //
                3.0, 5.0, 1.0, //
                3.0, 1.0, 7.0,
            ],
            vec![3, 3],
        )
        .unwrap();
        let eval = evaluate(Value::Tensor(a.clone()), &[Value::from("lower")]).expect("chol eval");
        assert_eq!(eval.flag_index(), 0);
        assert_eq!(eval.triangle(), CholTriangle::Lower);
        let factor = tensor_from_value(eval.factor());
        let recon = reconstruct_from_lower(&factor);
        tensor_close(&recon, &a, 1e-10);
    }

    // Indefinite matrix: flag reports the failing minor and the partial factor
    // keeps the successfully computed first row with the rest zeroed.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_two_output_reports_failure() {
        let a = Matrix::new(vec![1.0, 2.0, 2.0, 1.0], vec![2, 2]).expect("matrix");
        let eval = evaluate(Value::Tensor(a), &[]).expect("chol eval");
        assert_eq!(eval.flag_index(), 2);
        let factor = tensor_from_value(eval.factor());
        assert_eq!(factor.shape, vec![2, 2]);
        assert!((factor.data[0] - 1.0).abs() < 1e-12);
        assert!((factor.data[1] - 0.0).abs() < 1e-12);
        assert!((factor.data[2] - 2.0).abs() < 1e-12);
        assert!((factor.data[3] - 0.0).abs() < 1e-12);
    }

    // Single-output form must error (not flag) on an indefinite matrix.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_single_output_errors_on_failure() {
        let a = Matrix::new(vec![1.0, 2.0, 2.0, 1.0], vec![2, 2]).expect("matrix");
        let err = error_message(chol_builtin(Value::Tensor(a), Vec::new()).unwrap_err());
        assert!(err.contains("positive definite"));
    }

    // Unrecognised option strings are rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_invalid_option_errors() {
        let a = Matrix::new(vec![4.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let err = error_message(
            chol_builtin(Value::Tensor(a), vec![Value::from("diagonal")]).unwrap_err(),
        );
        assert!(err.to_ascii_lowercase().contains("unknown option"));
    }

    // Non-square input is rejected before factorization.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_non_square_errors() {
        let a = Matrix::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let err = error_message(chol_builtin(Value::Tensor(a), Vec::new()).unwrap_err());
        assert!(err.to_ascii_lowercase().contains("square"));
    }

    // 0x0 input succeeds with an empty factor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_empty_matrix_returns_empty() {
        let empty = Matrix::new(Vec::<f64>::new(), vec![0, 0]).unwrap();
        let eval = evaluate(Value::Tensor(empty.clone()), &[]).expect("chol eval");
        assert_eq!(eval.flag_index(), 0);
        let factor = tensor_from_value(eval.factor());
        assert_eq!(factor.shape, vec![0, 0]);
        assert!(factor.data.is_empty());
    }

    // Asymmetric input fails the Hermitian check and reports the minor index.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_non_hermitian_reports_failure() {
        let a = Matrix::new(vec![2.0, 1.0, 0.0, 2.0], vec![2, 2]).expect("matrix");
        let eval = evaluate(Value::Tensor(a), &[]).expect("chol eval");
        assert_eq!(eval.flag_index(), 2);
    }

    // Logical identity widens to doubles and factors to the identity.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_logical_input_factorizes() {
        let logical = LogicalArray::new(vec![1, 0, 0, 1], vec![2, 2]).expect("logical array");
        let result = chol_builtin(Value::LogicalArray(logical), Vec::new()).expect("chol");
        let factor = tensor_from_value(result);
        let recon = reconstruct_from_upper(&factor);
        let identity = Matrix::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        tensor_close(&recon, &identity, 1e-12);
    }

    // Hermitian positive definite complex input: R' * R reconstructs A.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_complex_positive_definite() {
        let complex = ComplexTensor::new(
            vec![(5.0, 0.0), (1.0, 2.0), (1.0, -2.0), (4.0, 0.0)],
            vec![2, 2],
        )
        .unwrap();
        let eval = evaluate(Value::ComplexTensor(complex.clone()), &[]).expect("chol eval");
        assert_eq!(eval.flag_index(), 0);
        let factor = complex_tensor_from_value(eval.factor());
        let recon = reconstruct_complex_upper(&factor);
        complex_tensor_close(&recon, &complex, 1e-10);
    }

    // Same complex input via the 'lower' option: L * L' reconstructs A.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_complex_lower_variant() {
        let complex = ComplexTensor::new(
            vec![(5.0, 0.0), (1.0, 2.0), (1.0, -2.0), (4.0, 0.0)],
            vec![2, 2],
        )
        .unwrap();
        let eval = evaluate(
            Value::ComplexTensor(complex.clone()),
            &[Value::from("lower")],
        )
        .expect("chol eval");
        assert_eq!(eval.flag_index(), 0);
        let factor = complex_tensor_from_value(eval.factor());
        let recon = reconstruct_complex_lower(&factor);
        complex_tensor_close(&recon, &complex, 1e-10);
    }

    // Upload -> GPU chol -> gather round-trip matches the host reconstruction.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let a = Matrix::new(vec![6.0, 2.0, 2.0, 5.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = chol_builtin(Value::GpuTensor(handle), Vec::new()).expect("chol");
            let gathered = test_support::gather(result).expect("gather");
            let recon = reconstruct_from_upper(&gathered);
            tensor_close(&recon, &a, 1e-10);
        });
    }

    // GPU path reports the same failure flag and partial factor as the host.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_gpu_failure_flag() {
        test_support::with_test_provider(|provider| {
            let a = Matrix::new(vec![1.0, 2.0, 2.0, 1.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("chol eval");
            assert_eq!(eval.flag_index(), 2);
            let factor = eval.factor();
            assert!(matches!(factor, Value::GpuTensor(_)));
            let gathered = test_support::gather(factor).expect("gather factor");
            assert!((gathered.data[0] - 1.0).abs() < 1e-12);
            assert!((gathered.data[1] - 0.0).abs() < 1e-12);
            assert!((gathered.data[2] - 2.0).abs() < 1e-12);
            assert!((gathered.data[3] - 0.0).abs() < 1e-12);
        });
    }

    // wgpu provider result must agree with the host factor (tolerance by precision).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn chol_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .expect("register wgpu provider");

        let tol = match runmat_accelerate_api::provider()
            .expect("provider")
            .precision()
        {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };

        let tensor = Matrix::new(
            vec![
                10.0, 2.0, 3.0, //
                2.0, 9.0, 1.0, //
                3.0, 1.0, 7.0,
            ],
            vec![3, 3],
        )
        .unwrap();

        let host_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("host eval");
        let host_factor = tensor_from_value(host_eval.factor());

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");

        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu eval");
        assert_eq!(gpu_eval.flag_index(), 0, "gpu chol should succeed");
        let gpu_factor = test_support::gather(gpu_eval.factor()).expect("gather factor");

        tensor_close(&gpu_factor, &host_factor, tol);
    }

    // Scalar input: chol(9) is sqrt(9) = 3, as a scalar or 1x1 tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn chol_accepts_scalar() {
        let result = chol_builtin(Value::Num(9.0), Vec::new()).expect("chol");
        match result {
            Value::Num(n) => assert!((n - 3.0).abs() < 1e-12),
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert!((t.data[0] - 3.0).abs() < 1e-12);
            }
            other => panic!("expected scalar-like, got {other:?}"),
        }
    }

    // Synchronous wrapper over the async builtin for use in tests.
    fn chol_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::chol_builtin(value, rest))
    }

    // Synchronous wrapper over the async evaluator for use in tests.
    fn evaluate(value: Value, args: &[Value]) -> BuiltinResult<CholEval> {
        block_on(super::evaluate(value, args))
    }
}