Skip to main content

runmat_runtime/builtins/stats/summary/
cov.rs

1//! MATLAB-compatible `cov` builtin with GPU-aware semantics for RunMat.
2
3use runmat_accelerate_api::{CovNormalization, CovRows, CovarianceOptions};
4use runmat_builtins::{
5    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7    Tensor, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::gpu_helpers;
12use crate::builtins::common::spec::{
13    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::tensor;
17use crate::builtins::stats::type_resolvers::cov_type;
18use crate::{build_runtime_error, BuiltinResult, RuntimeError};
19
/// Builtin name attached (via `with_builtin`) to every runtime error raised by this module.
const NAME: &str = "cov";
/// Output descriptor shared by every `cov` signature: the covariance matrix `C`.
const COV_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "C",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Covariance matrix.",
}];
28
29const COV_INPUTS_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
30    name: "X",
31    ty: BuiltinParamType::Any,
32    arity: BuiltinParamArity::Required,
33    default: None,
34    description: "Input observations (rows are observations, columns are variables).",
35}];
36
37const COV_INPUTS_X_Y_OR_W: [BuiltinParamDescriptor; 2] = [
38    BuiltinParamDescriptor {
39        name: "X",
40        ty: BuiltinParamType::Any,
41        arity: BuiltinParamArity::Required,
42        default: None,
43        description: "Input observations (rows are observations, columns are variables).",
44    },
45    BuiltinParamDescriptor {
46        name: "Y_or_w",
47        ty: BuiltinParamType::Any,
48        arity: BuiltinParamArity::Required,
49        default: None,
50        description: "Second dataset (Y) or weight vector (w), depending on shape/position.",
51    },
52];
53
54const COV_INPUTS_X_NORMALIZATION: [BuiltinParamDescriptor; 2] = [
55    BuiltinParamDescriptor {
56        name: "X",
57        ty: BuiltinParamType::Any,
58        arity: BuiltinParamArity::Required,
59        default: None,
60        description: "Input observations (rows are observations, columns are variables).",
61    },
62    BuiltinParamDescriptor {
63        name: "normalization",
64        ty: BuiltinParamType::NumericScalar,
65        arity: BuiltinParamArity::Required,
66        default: Some("0"),
67        description: "Normalization flag: 0 (unbiased) or 1 (biased).",
68    },
69];
70
71const COV_INPUTS_X_ROWS: [BuiltinParamDescriptor; 2] = [
72    BuiltinParamDescriptor {
73        name: "X",
74        ty: BuiltinParamType::Any,
75        arity: BuiltinParamArity::Required,
76        default: None,
77        description: "Input observations (rows are observations, columns are variables).",
78    },
79    BuiltinParamDescriptor {
80        name: "rows_option",
81        ty: BuiltinParamType::StringScalar,
82        arity: BuiltinParamArity::Required,
83        default: Some("\"all\""),
84        description: "Rows handling mode: 'all', 'omitrows', or 'partialrows'.",
85    },
86];
87
88const COV_INPUTS_X_Y_OPT: [BuiltinParamDescriptor; 3] = [
89    BuiltinParamDescriptor {
90        name: "X",
91        ty: BuiltinParamType::Any,
92        arity: BuiltinParamArity::Required,
93        default: None,
94        description: "Input observations (rows are observations, columns are variables).",
95    },
96    BuiltinParamDescriptor {
97        name: "Y",
98        ty: BuiltinParamType::Any,
99        arity: BuiltinParamArity::Required,
100        default: None,
101        description: "Second dataset with matching row count.",
102    },
103    BuiltinParamDescriptor {
104        name: "opt",
105        ty: BuiltinParamType::Any,
106        arity: BuiltinParamArity::Required,
107        default: None,
108        description: "Normalization flag or rows option.",
109    },
110];
111
112const COV_INPUTS_X_Y_W: [BuiltinParamDescriptor; 3] = [
113    BuiltinParamDescriptor {
114        name: "X",
115        ty: BuiltinParamType::Any,
116        arity: BuiltinParamArity::Required,
117        default: None,
118        description: "Input observations (rows are observations, columns are variables).",
119    },
120    BuiltinParamDescriptor {
121        name: "Y",
122        ty: BuiltinParamType::Any,
123        arity: BuiltinParamArity::Required,
124        default: None,
125        description: "Second dataset with matching row count.",
126    },
127    BuiltinParamDescriptor {
128        name: "w",
129        ty: BuiltinParamType::Any,
130        arity: BuiltinParamArity::Required,
131        default: None,
132        description: "Weight vector with one weight per observation row.",
133    },
134];
135
136const COV_INPUTS_X_Y_W_OPT: [BuiltinParamDescriptor; 4] = [
137    BuiltinParamDescriptor {
138        name: "X",
139        ty: BuiltinParamType::Any,
140        arity: BuiltinParamArity::Required,
141        default: None,
142        description: "Input observations (rows are observations, columns are variables).",
143    },
144    BuiltinParamDescriptor {
145        name: "Y",
146        ty: BuiltinParamType::Any,
147        arity: BuiltinParamArity::Required,
148        default: None,
149        description: "Second dataset with matching row count.",
150    },
151    BuiltinParamDescriptor {
152        name: "w",
153        ty: BuiltinParamType::Any,
154        arity: BuiltinParamArity::Required,
155        default: None,
156        description: "Weight vector with one weight per observation row.",
157    },
158    BuiltinParamDescriptor {
159        name: "opt",
160        ty: BuiltinParamType::Any,
161        arity: BuiltinParamArity::Required,
162        default: None,
163        description: "Normalization flag or rows option.",
164    },
165];
166
/// Call forms advertised for `cov`, referenced by `COV_DESCRIPTOR`.
///
/// Argument classification at runtime happens in `CovArgs::parse` (this table is not
/// consulted there).
const COV_SIGNATURES: [BuiltinSignatureDescriptor; 7] = [
    BuiltinSignatureDescriptor {
        label: "C = cov(X)",
        inputs: &COV_INPUTS_X,
        outputs: &COV_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "C = cov(X, Y_or_w)",
        inputs: &COV_INPUTS_X_Y_OR_W,
        outputs: &COV_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "C = cov(X, normalization)",
        inputs: &COV_INPUTS_X_NORMALIZATION,
        outputs: &COV_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "C = cov(X, rows_option)",
        inputs: &COV_INPUTS_X_ROWS,
        outputs: &COV_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "C = cov(X, Y, opt)",
        inputs: &COV_INPUTS_X_Y_OPT,
        outputs: &COV_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "C = cov(X, Y, w)",
        inputs: &COV_INPUTS_X_Y_W,
        outputs: &COV_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "C = cov(X, Y, w, opt)",
        inputs: &COV_INPUTS_X_Y_W_OPT,
        outputs: &COV_OUTPUT,
    },
];
204
/// Generic argument error: malformed, unsupported, or wrongly typed inputs (including
/// non-2-D inputs and invalid weight values).
const COV_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.COV.INVALID_ARGUMENT",
    identifier: Some("RunMat:cov:InvalidArgument"),
    when: "Arguments are malformed or unsupported for cov.",
    message: "cov: invalid argument",
};

/// Raised when any argument is a complex tensor.
const COV_ERROR_COMPLEX_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.COV.COMPLEX_UNSUPPORTED",
    identifier: Some("RunMat:cov:ComplexUnsupported"),
    when: "Any argument is complex-valued.",
    message: "cov: complex inputs are not supported yet",
};

/// Raised when `X` and `Y` have different row (observation) counts.
const COV_ERROR_ROWS_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.COV.ROWS_MISMATCH",
    identifier: Some("RunMat:cov:RowsMismatch"),
    when: "Two input datasets do not have the same number of rows.",
    message: "cov: inputs must have the same number of rows",
};

/// Raised when the scalar normalization flag is not a finite 0 or 1.
const COV_ERROR_NORMALIZATION_INVALID: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.COV.NORMALIZATION_INVALID",
    identifier: Some("RunMat:cov:NormalizationInvalid"),
    when: "Normalization flag is non-finite, non-integer, or not 0/1.",
    message: "cov: normalization flag is invalid",
};

/// Raised when the weight vector does not have one entry per observation row.
const COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.COV.WEIGHT_VECTOR_LENGTH_MISMATCH",
    identifier: Some("RunMat:cov:WeightVectorLengthMismatch"),
    when: "Weight vector length does not match observation row count.",
    message: "cov: weight vector length mismatch",
};

/// Raised for a rows option other than the recognised spellings.
const COV_ERROR_ROWS_OPTION_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.COV.ROWS_OPTION_UNKNOWN",
    identifier: Some("RunMat:cov:RowsOptionUnknown"),
    when: "Rows option is not one of all/omitrows/partialrows.",
    message: "cov: unknown rows option",
};

/// Raised when a normalization flag is repeated or combined with an explicit weight vector.
const COV_ERROR_NORMALIZATION_DUPLICATE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.COV.NORMALIZATION_DUPLICATE",
    identifier: Some("RunMat:cov:NormalizationDuplicate"),
    when: "Normalization flag is provided more than once.",
    message: "cov: normalization flag specified more than once",
};

/// Raised when more than two extra array arguments follow `X`.
const COV_ERROR_TOO_MANY_ARRAY_ARGUMENTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.COV.TOO_MANY_ARRAY_ARGUMENTS",
    identifier: Some("RunMat:cov:TooManyArrayArguments"),
    when: "More than two data arrays (or Y plus weight) are provided.",
    message: "cov: too many array arguments",
};

/// Wraps failures from tensor conversion/allocation helpers.
const COV_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.COV.INTERNAL",
    identifier: Some("RunMat:cov:Internal"),
    when: "Internal tensor conversion/allocation or covariance computation fails.",
    message: "cov: internal operation failed",
};

/// Every error `cov` can raise, published through `COV_DESCRIPTOR`.
const COV_ERRORS: [BuiltinErrorDescriptor; 9] = [
    COV_ERROR_INVALID_ARGUMENT,
    COV_ERROR_COMPLEX_UNSUPPORTED,
    COV_ERROR_ROWS_MISMATCH,
    COV_ERROR_NORMALIZATION_INVALID,
    COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH,
    COV_ERROR_ROWS_OPTION_UNKNOWN,
    COV_ERROR_NORMALIZATION_DUPLICATE,
    COV_ERROR_TOO_MANY_ARRAY_ARGUMENTS,
    COV_ERROR_INTERNAL,
];

/// Builtin metadata for `cov`; attached through the `descriptor(...)` argument of the
/// `runtime_builtin` attribute on `cov_builtin`.
pub const COV_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &COV_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &COV_ERRORS,
};
286
287fn cov_error_with(
288    error: &'static BuiltinErrorDescriptor,
289    message: impl Into<String>,
290) -> RuntimeError {
291    let mut builder = build_runtime_error(message).with_builtin(NAME);
292    if let Some(identifier) = error.identifier {
293        builder = builder.with_identifier(identifier);
294    }
295    builder.build()
296}
297
298fn cov_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
299    cov_error_with(error, error.message)
300}
301
302fn cov_error_with_detail(
303    error: &'static BuiltinErrorDescriptor,
304    detail: impl std::fmt::Display,
305) -> RuntimeError {
306    cov_error_with(error, format!("{}: {detail}", error.message))
307}
308
309fn cov_internal_error(message: impl Into<String>) -> RuntimeError {
310    cov_error_with(&COV_ERROR_INTERNAL, message)
311}
312
// GPU metadata for `cov`: acceleration goes through the provider's custom "covariance"
// hook. As the notes say, only rows='all' without a weight vector is offloaded (this is
// the condition `cov_try_gpu` checks); every other form runs on the CPU.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::stats::summary::cov")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "cov",
    op_kind: GpuOpKind::Custom("summary-stats"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("covariance")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "GPU execution is available when rows='all' and no weight vector is supplied; other cases fall back to the CPU path.",
};

// Fusion metadata: no elementwise or reduction template is supplied, so (per the notes)
// the fusion planner treats `cov` as a boundary rather than fusing it.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::stats::summary::cov")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "cov",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "The covariance builtin is treated as a fusion boundary and executes via dedicated kernels or the host reference.",
};
339
340#[runtime_builtin(
341    name = "cov",
342    category = "stats/summary",
343    summary = "Compute covariance matrices.",
344    keywords = "cov,covariance,statistics,weights,gpu",
345    accel = "reduction",
346    type_resolver(cov_type),
347    descriptor(crate::builtins::stats::summary::cov::COV_DESCRIPTOR),
348    builtin_path = "crate::builtins::stats::summary::cov"
349)]
350async fn cov_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
351    let args = CovArgs::parse(value, rest)?;
352    if let Some(result) = cov_try_gpu(&args).await? {
353        return Ok(result);
354    }
355    cov_host(args).await
356}
357
358/// Public entry point for providers that need the reference implementation.
359pub fn cov_from_tensors(
360    left: Tensor,
361    right: Option<Tensor>,
362    rows: CovRows,
363    weight: CovWeightSpec,
364) -> BuiltinResult<Tensor> {
365    let matrix = combine_tensors(left, right)?;
366    if let CovWeightSpec::Vector(ref vec) = weight {
367        if matrix.rows != vec.len() {
368            return Err(cov_error_with_detail(
369                &COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH,
370                format!("expected {} elements", matrix.rows),
371            ));
372        }
373    }
374    match rows {
375        CovRows::All => covariance_dense(&matrix, &weight),
376        CovRows::OmitRows => {
377            let (filtered, filtered_weight) = filter_complete_rows(&matrix, weight);
378            covariance_dense(&filtered, &filtered_weight)
379        }
380        CovRows::PartialRows => covariance_pairwise(&matrix, &weight),
381    }
382}
383
/// Classified arguments for a single `cov` call, produced by `CovArgs::parse`.
#[derive(Debug)]
struct CovArgs {
    /// Primary observation matrix `X`.
    first: Value,
    /// Optional second dataset `Y`; on the host path its columns are appended to `first`'s.
    second: Option<Value>,
    /// Scalar normalization flag (unbiased unless overridden). The host path does not use
    /// it when `weight_vector` is set.
    normalization: CovNormalization,
    /// Row-handling mode selected by a text option (defaults to `All`).
    rows: CovRows,
    /// Optional per-observation weight vector (not yet converted or validated).
    weight_vector: Option<Value>,
}
392
393impl CovArgs {
394    fn parse(first: Value, rest: Vec<Value>) -> BuiltinResult<Self> {
395        let mut second_candidate: Option<Value> = None;
396        let mut weight_candidate: Option<Value> = None;
397        let mut normalization = CovNormalization::Unbiased;
398        let mut normalization_explicit = false;
399        let mut rows = CovRows::All;
400
401        let iter = rest.into_iter();
402        for arg in iter {
403            match arg {
404                Value::String(_) | Value::StringArray(_) | Value::CharArray(_) => {
405                    let key = tensor::value_to_string(&arg)
406                        .ok_or_else(|| cov_error(&COV_ERROR_INVALID_ARGUMENT))?;
407                    let lowered = key.trim().to_ascii_lowercase();
408                    rows = parse_rows_option(&lowered)?;
409                }
410                Value::Tensor(_) | Value::LogicalArray(_) | Value::GpuTensor(_) => {
411                    if second_candidate.is_none() {
412                        second_candidate = Some(arg);
413                    } else if weight_candidate.is_none() {
414                        weight_candidate = Some(arg);
415                    } else {
416                        return Err(cov_error(&COV_ERROR_TOO_MANY_ARRAY_ARGUMENTS));
417                    }
418                }
419                Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
420                    if normalization_explicit || weight_candidate.is_some() {
421                        return Err(cov_error(&COV_ERROR_NORMALIZATION_DUPLICATE));
422                    }
423                    normalization = parse_normalization(arg)?;
424                    normalization_explicit = true;
425                }
426                Value::ComplexTensor(_) => {
427                    return Err(cov_error(&COV_ERROR_COMPLEX_UNSUPPORTED));
428                }
429                other => {
430                    return Err(cov_error_with_detail(
431                        &COV_ERROR_INVALID_ARGUMENT,
432                        format!("{other:?}"),
433                    ))
434                }
435            }
436        }
437
438        if let Some(weight_array) = weight_candidate {
439            // Explicit weight vector always takes precedence over dataset detection.
440            return Ok(Self {
441                first,
442                second: second_candidate,
443                normalization,
444                rows,
445                weight_vector: Some(weight_array),
446            });
447        }
448
449        let mut second = second_candidate;
450        let mut weight_vector: Option<Value> = None;
451
452        if let Some(candidate) = second.take() {
453            if should_treat_as_weight(&first, &candidate, normalization_explicit, rows)? {
454                weight_vector = Some(candidate);
455            } else {
456                second = Some(candidate);
457            }
458        }
459
460        Ok(Self {
461            first,
462            second,
463            normalization,
464            rows,
465            weight_vector,
466        })
467    }
468}
469
/// Observation weighting used by the host covariance kernels.
#[derive(Debug, Clone)]
pub enum CovWeightSpec {
    /// Uniform weighting. The denominator is `N - 1` for `Unbiased` and `N` for `Biased`.
    Scalar(CovNormalization),
    /// One weight per observation row. On the builtin path these are validated as finite
    /// and non-negative by `value_to_weight_vector`; direct callers of
    /// `cov_from_tensors` are only checked for length.
    Vector(Vec<f64>),
}
475
476async fn cov_try_gpu(args: &CovArgs) -> BuiltinResult<Option<Value>> {
477    if args.rows != CovRows::All || args.weight_vector.is_some() {
478        return Ok(None);
479    }
480
481    let provider = match runmat_accelerate_api::provider() {
482        Some(p) => p,
483        None => return Ok(None),
484    };
485
486    let first_handle = match &args.first {
487        Value::GpuTensor(handle) => handle,
488        _ => return Ok(None),
489    };
490
491    let maybe_second_handle = match &args.second {
492        Some(Value::GpuTensor(handle)) => Some(handle),
493        Some(_) => return Ok(None),
494        None => None,
495    };
496
497    let options = CovarianceOptions {
498        normalization: args.normalization,
499        rows: args.rows,
500        has_weight_vector: false,
501    };
502
503    match provider
504        .covariance(first_handle, maybe_second_handle, None, &options)
505        .await
506    {
507        Ok(result) => Ok(Some(Value::GpuTensor(result))),
508        Err(_) => Ok(None),
509    }
510}
511
512async fn cov_host(args: CovArgs) -> BuiltinResult<Value> {
513    let CovArgs {
514        first,
515        second,
516        normalization,
517        rows,
518        weight_vector,
519    } = args;
520
521    let left = value_to_tensor_gather(first).await?;
522    let right = match second {
523        Some(value) => Some(value_to_tensor_gather(value).await?),
524        None => None,
525    };
526
527    let weight_spec = if let Some(weight_value) = weight_vector {
528        let vector = value_to_weight_vector(weight_value, left.rows()).await?;
529        CovWeightSpec::Vector(vector)
530    } else {
531        CovWeightSpec::Scalar(normalization)
532    };
533
534    let tensor = cov_from_tensors(left, right, rows, weight_spec)?;
535    Ok(Value::Tensor(tensor))
536}
537
538async fn value_to_tensor_gather(value: Value) -> BuiltinResult<Tensor> {
539    match value {
540        Value::GpuTensor(handle) => gpu_helpers::gather_tensor_async(&handle).await,
541        Value::LogicalArray(logical) => {
542            tensor::logical_to_tensor(&logical).map_err(cov_internal_error)
543        }
544        other => tensor::value_into_tensor_for("cov", other).map_err(cov_internal_error),
545    }
546}
547
548async fn value_to_weight_vector(value: Value, expected_rows: usize) -> BuiltinResult<Vec<f64>> {
549    let tensor = match value {
550        Value::GpuTensor(handle) => gpu_helpers::gather_tensor_async(&handle).await?,
551        Value::LogicalArray(logical) => {
552            tensor::logical_to_tensor(&logical).map_err(cov_internal_error)?
553        }
554        other => tensor::value_into_tensor_for("cov", other).map_err(cov_internal_error)?,
555    };
556
557    if tensor.shape.len() > 2 {
558        return Err(cov_error_with_detail(
559            &COV_ERROR_INVALID_ARGUMENT,
560            "weight vector must be one-dimensional",
561        ));
562    }
563    if tensor.rows() != expected_rows && tensor.cols() != expected_rows {
564        return Err(cov_error_with_detail(
565            &COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH,
566            format!("expected {expected_rows} elements"),
567        ));
568    }
569    for (idx, weight) in tensor.data.iter().enumerate() {
570        if !weight.is_finite() || *weight < 0.0 {
571            return Err(cov_error_with_detail(
572                &COV_ERROR_INVALID_ARGUMENT,
573                format!("weights must be non-negative finite values (index {idx})"),
574            ));
575        }
576    }
577    if tensor.data.is_empty() {
578        return Err(cov_error_with_detail(
579            &COV_ERROR_INVALID_ARGUMENT,
580            "weight vector cannot be empty",
581        ));
582    }
583    Ok(tensor.data)
584}
585
586fn parse_rows_option(value: &str) -> BuiltinResult<CovRows> {
587    match value {
588        "all" => Ok(CovRows::All),
589        "omitrows" | "omit" => Ok(CovRows::OmitRows),
590        "partialrows" | "partial" | "pairwise" => Ok(CovRows::PartialRows),
591        other => Err(cov_error_with_detail(
592            &COV_ERROR_ROWS_OPTION_UNKNOWN,
593            format!("'{other}'"),
594        )),
595    }
596}
597
598fn parse_normalization(value: Value) -> BuiltinResult<CovNormalization> {
599    match value {
600        Value::Int(i) => match i.to_i64() {
601            0 => Ok(CovNormalization::Unbiased),
602            1 => Ok(CovNormalization::Biased),
603            other => Err(cov_error_with_detail(
604                &COV_ERROR_NORMALIZATION_INVALID,
605                format!("expected 0 or 1, received {other}"),
606            )),
607        },
608        Value::Num(n) => {
609            if !n.is_finite() {
610                return Err(cov_error_with_detail(
611                    &COV_ERROR_NORMALIZATION_INVALID,
612                    "value must be finite",
613                ));
614            }
615            let rounded = n.round();
616            if (rounded - n).abs() > 1.0e-12 {
617                return Err(cov_error_with_detail(
618                    &COV_ERROR_NORMALIZATION_INVALID,
619                    "value must be an integer",
620                ));
621            }
622            match rounded as i64 {
623                0 => Ok(CovNormalization::Unbiased),
624                1 => Ok(CovNormalization::Biased),
625                other => Err(cov_error_with_detail(
626                    &COV_ERROR_NORMALIZATION_INVALID,
627                    format!("expected 0 or 1, received {other}"),
628                )),
629            }
630        }
631        Value::Bool(flag) => Ok(if flag {
632            CovNormalization::Biased
633        } else {
634            CovNormalization::Unbiased
635        }),
636        other => Err(cov_error_with_detail(
637            &COV_ERROR_NORMALIZATION_INVALID,
638            format!("value must be numeric, received {other:?}"),
639        )),
640    }
641}
642
643fn should_treat_as_weight(
644    first: &Value,
645    candidate: &Value,
646    normalization_explicit: bool,
647    rows_option: CovRows,
648) -> BuiltinResult<bool> {
649    let (rows_first, cols_first) = value_rows_cols(first)?;
650    let (rows_candidate, cols_candidate) = value_rows_cols(candidate)?;
651
652    let is_vector = rows_candidate == 1
653        || cols_candidate == 1
654        || rows_candidate * cols_candidate == rows_candidate
655            && (rows_candidate == rows_first || cols_candidate == rows_first);
656
657    if !is_vector {
658        return Ok(false);
659    }
660
661    if rows_candidate != rows_first && cols_candidate != rows_first {
662        // Length mismatch, treat as dataset so the later validation emits the proper error.
663        return Ok(false);
664    }
665
666    if cols_first == 1 && !normalization_explicit && matches!(rows_option, CovRows::All) {
667        // Ambiguous `cov(x, y)` case – prefer dataset semantics for compatibility.
668        return Ok(false);
669    }
670
671    Ok(true)
672}
673
674fn value_rows_cols(value: &Value) -> BuiltinResult<(usize, usize)> {
675    match value {
676        Value::Tensor(tensor) => Ok((tensor.rows(), tensor.cols())),
677        Value::LogicalArray(array) => {
678            if array.shape.len() > 2 {
679                return Err(cov_error_with_detail(
680                    &COV_ERROR_INVALID_ARGUMENT,
681                    "inputs must be 2-D matrices or vectors",
682                ));
683            }
684            let rows = if array.shape.is_empty() {
685                1
686            } else {
687                array.shape[0]
688            };
689            let cols = if array.shape.len() >= 2 {
690                array.shape[1]
691            } else {
692                1
693            };
694            Ok((rows, cols))
695        }
696        Value::GpuTensor(handle) => {
697            if handle.shape.len() > 2 {
698                return Err(cov_error_with_detail(
699                    &COV_ERROR_INVALID_ARGUMENT,
700                    "inputs must be 2-D matrices or vectors",
701                ));
702            }
703            let rows = if handle.shape.is_empty() {
704                1
705            } else {
706                handle.shape[0]
707            };
708            let cols = if handle.shape.len() >= 2 {
709                handle.shape[1]
710            } else {
711                1
712            };
713            Ok((rows, cols))
714        }
715        Value::Num(_) | Value::Int(_) | Value::Bool(_) => Ok((1, 1)),
716        other => Err(cov_error_with_detail(
717            &COV_ERROR_INVALID_ARGUMENT,
718            format!("unsupported input type for shape inspection: {other:?}"),
719        )),
720    }
721}
722
/// Dense host matrix used by the covariance kernels (rows = observations,
/// columns = variables).
#[derive(Debug, Clone)]
struct Matrix {
    /// Column-major storage: entry `(row, col)` lives at `row + col * rows`.
    data: Vec<f64>,
    /// Number of observations.
    rows: usize,
    /// Number of variables.
    cols: usize,
}
729
730impl Matrix {
731    fn from_tensor(name: &str, tensor: Tensor) -> BuiltinResult<Self> {
732        if tensor.shape.len() > 2 {
733            return Err(cov_error_with_detail(
734                &COV_ERROR_INVALID_ARGUMENT,
735                format!("{name}: inputs must be 2-D matrices or vectors"),
736            ));
737        }
738        Ok(Self {
739            rows: tensor.rows(),
740            cols: tensor.cols(),
741            data: tensor.data,
742        })
743    }
744
745    #[inline]
746    fn get(&self, row: usize, col: usize) -> f64 {
747        self.data[row + col * self.rows]
748    }
749
750    #[inline]
751    fn column(&self, col: usize) -> &[f64] {
752        let start = col * self.rows;
753        let end = start + self.rows;
754        &self.data[start..end]
755    }
756}
757
758fn combine_tensors(left: Tensor, right: Option<Tensor>) -> BuiltinResult<Matrix> {
759    let mut matrix = Matrix::from_tensor("cov", left)?;
760    if let Some(second) = right {
761        let right_matrix = Matrix::from_tensor("cov", second)?;
762        if matrix.rows != right_matrix.rows {
763            return Err(cov_error(&COV_ERROR_ROWS_MISMATCH));
764        }
765        matrix.cols += right_matrix.cols;
766        matrix
767            .data
768            .extend_from_slice(&right_matrix.data[..right_matrix.rows * right_matrix.cols]);
769    }
770    Ok(matrix)
771}
772
/// Covariance over every row of `matrix`. It is used for `'all'` and, after row filtering,
/// for `'omitrows'`.
///
/// The result is a `cols`-by-`cols` matrix pre-filled with NaN. It is returned all-NaN when
/// the normalizing denominator is not positive (e.g. fewer than two rows with unbiased
/// normalization). A column containing any non-finite value gets a NaN mean.
///
/// Only pairs with `j >= i` are evaluated here. `set_entry` (defined elsewhere in this file)
/// presumably mirrors each value across the diagonal — confirm.
///
/// # Errors
/// * `RunMat:cov:WeightVectorLengthMismatch` when a weight vector does not have one entry
///   per row.
/// * `RunMat:cov:Internal` if the result tensor cannot be constructed.
fn covariance_dense(matrix: &Matrix, weight: &CovWeightSpec) -> BuiltinResult<Tensor> {
    let cols = matrix.cols;
    let rows = matrix.rows;

    // No variables: the result is an empty 0x0 matrix.
    if cols == 0 {
        return Tensor::new(Vec::new(), vec![0, 0]).map_err(cov_internal_error);
    }

    // Every entry starts as NaN so the early exits below yield an all-NaN result.
    let mut result = vec![f64::NAN; cols * cols];

    match weight {
        CovWeightSpec::Scalar(normalization) => {
            // Unbiased divides by N - 1, biased by N.
            let denom = match normalization {
                CovNormalization::Unbiased => (rows as f64) - 1.0,
                CovNormalization::Biased => rows as f64,
            };
            if denom <= 0.0 {
                return Tensor::new(result, vec![cols, cols]).map_err(cov_internal_error);
            }

            let mut means = vec![0.0; cols];
            for (col, mean_slot) in means.iter_mut().enumerate() {
                let column = matrix.column(col);
                let mut sum = 0.0;
                let mut valid = true;
                for &value in column {
                    // Any NaN/Inf in the column poisons its mean.
                    if !value.is_finite() {
                        valid = false;
                        break;
                    }
                    sum += value;
                }
                *mean_slot = if valid { sum / (rows as f64) } else { f64::NAN };
            }

            // Upper triangle only (j >= i).
            for i in 0..cols {
                for j in i..cols {
                    let value = covariance_unweighted_pair(matrix, i, j, &means, denom);
                    set_entry(&mut result, cols, i, j, sanitize_covariance(i == j, value));
                }
            }
        }
        CovWeightSpec::Vector(weights) => {
            // Also checked by `cov_from_tensors`; repeated here as a local guard.
            if weights.len() != rows {
                return Err(cov_error_with_detail(
                    &COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH,
                    format!("expected {rows} elements"),
                ));
            }
            let sum_w: f64 = weights.iter().sum();
            if sum_w <= 0.0 {
                return Tensor::new(result, vec![cols, cols]).map_err(cov_internal_error);
            }
            // NOTE(review): `sum_w - 1` is a frequency-weight (unbiased) denominator, so
            // weight vectors whose total is <= 1 (e.g. weights normalized to sum to 1)
            // yield an all-NaN result. Confirm this matches the intended MATLAB
            // weight-vector normalization.
            let denom = sum_w - 1.0;
            if denom <= 0.0 {
                return Tensor::new(result, vec![cols, cols]).map_err(cov_internal_error);
            }

            // Weighted means: sum(w * x) / sum(w); NaN if the column has non-finite values.
            let mut means = vec![0.0; cols];
            for (col, mean_slot) in means.iter_mut().enumerate() {
                let column = matrix.column(col);
                let mut weighted_sum = 0.0;
                let mut valid = true;
                for (row, &value) in column.iter().enumerate() {
                    if !value.is_finite() {
                        valid = false;
                        break;
                    }
                    weighted_sum += weights[row] * value;
                }
                *mean_slot = if valid {
                    weighted_sum / sum_w
                } else {
                    f64::NAN
                };
            }

            // Upper triangle only (j >= i).
            for i in 0..cols {
                for j in i..cols {
                    let value = covariance_weighted_pair(matrix, i, j, weights, &means, denom);
                    set_entry(&mut result, cols, i, j, sanitize_covariance(i == j, value));
                }
            }
        }
    }

    Tensor::new(result, vec![cols, cols]).map_err(cov_internal_error)
}
861
862fn filter_complete_rows(matrix: &Matrix, weight: CovWeightSpec) -> (Matrix, CovWeightSpec) {
863    if matrix.rows == 0 {
864        return (
865            Matrix {
866                data: Vec::new(),
867                rows: 0,
868                cols: matrix.cols,
869            },
870            weight,
871        );
872    }
873
874    let mut valid_rows = Vec::new();
875    for row in 0..matrix.rows {
876        let mut is_valid = true;
877        for col in 0..matrix.cols {
878            if !matrix.get(row, col).is_finite() {
879                is_valid = false;
880                break;
881            }
882        }
883        if is_valid {
884            valid_rows.push(row);
885        }
886    }
887
888    if valid_rows.len() == matrix.rows {
889        // No filtering required.
890        return (matrix.clone(), weight);
891    }
892
893    let mut data = Vec::with_capacity(valid_rows.len() * matrix.cols);
894    for col in 0..matrix.cols {
895        for &row in &valid_rows {
896            data.push(matrix.get(row, col));
897        }
898    }
899
900    let filtered_matrix = Matrix {
901        data,
902        rows: valid_rows.len(),
903        cols: matrix.cols,
904    };
905
906    let filtered_weight = match weight {
907        CovWeightSpec::Scalar(norm) => CovWeightSpec::Scalar(norm),
908        CovWeightSpec::Vector(vec) => {
909            let mut filtered = Vec::with_capacity(valid_rows.len());
910            for &row in &valid_rows {
911                filtered.push(vec[row]);
912            }
913            CovWeightSpec::Vector(filtered)
914        }
915    };
916
917    (filtered_matrix, filtered_weight)
918}
919
920fn covariance_pairwise(matrix: &Matrix, weight: &CovWeightSpec) -> BuiltinResult<Tensor> {
921    let cols = matrix.cols;
922    if cols == 0 {
923        return Tensor::new(Vec::new(), vec![0, 0]).map_err(cov_internal_error);
924    }
925    let mut result = vec![f64::NAN; cols * cols];
926    for i in 0..cols {
927        let variance = covariance_pair(matrix, i, i, weight);
928        set_entry(&mut result, cols, i, i, sanitize_covariance(true, variance));
929        for j in (i + 1)..cols {
930            let value = covariance_pair(matrix, i, j, weight);
931            set_entry(&mut result, cols, i, j, sanitize_covariance(false, value));
932        }
933    }
934    Tensor::new(result, vec![cols, cols]).map_err(cov_internal_error)
935}
936
937fn covariance_unweighted_pair(
938    matrix: &Matrix,
939    lhs: usize,
940    rhs: usize,
941    means: &[f64],
942    denom: f64,
943) -> f64 {
944    if !means[lhs].is_finite() || !means[rhs].is_finite() {
945        return f64::NAN;
946    }
947    let mut accumulator = 0.0;
948    for row in 0..matrix.rows {
949        let x = matrix.get(row, lhs);
950        let y = matrix.get(row, rhs);
951        if !x.is_finite() || !y.is_finite() {
952            return f64::NAN;
953        }
954        accumulator += (x - means[lhs]) * (y - means[rhs]);
955    }
956    accumulator / denom
957}
958
959fn covariance_weighted_pair(
960    matrix: &Matrix,
961    lhs: usize,
962    rhs: usize,
963    weights: &[f64],
964    means: &[f64],
965    denom: f64,
966) -> f64 {
967    if !means[lhs].is_finite() || !means[rhs].is_finite() {
968        return f64::NAN;
969    }
970    let mut accumulator = 0.0;
971    for (row, &weight) in weights.iter().enumerate().take(matrix.rows) {
972        if weight == 0.0 {
973            continue;
974        }
975        let x = matrix.get(row, lhs);
976        let y = matrix.get(row, rhs);
977        if !x.is_finite() || !y.is_finite() {
978            return f64::NAN;
979        }
980        accumulator += weight * (x - means[lhs]) * (y - means[rhs]);
981    }
982    accumulator / denom
983}
984
985fn covariance_pair(matrix: &Matrix, lhs: usize, rhs: usize, weight: &CovWeightSpec) -> f64 {
986    match weight {
987        CovWeightSpec::Scalar(normalization) => {
988            let mut xs = Vec::new();
989            let mut ys = Vec::new();
990            for row in 0..matrix.rows {
991                let x = matrix.get(row, lhs);
992                let y = matrix.get(row, rhs);
993                if x.is_finite() && y.is_finite() {
994                    xs.push(x);
995                    ys.push(y);
996                }
997            }
998            covariance_unweighted_slice(&xs, &ys, *normalization)
999        }
1000        CovWeightSpec::Vector(weights) => {
1001            let mut xs = Vec::new();
1002            let mut ys = Vec::new();
1003            let mut ws = Vec::new();
1004            for (row, &weight) in weights.iter().enumerate().take(matrix.rows) {
1005                let x = matrix.get(row, lhs);
1006                let y = matrix.get(row, rhs);
1007                if x.is_finite() && y.is_finite() {
1008                    xs.push(x);
1009                    ys.push(y);
1010                    ws.push(weight);
1011                }
1012            }
1013            covariance_weighted_slice(&xs, &ys, &ws)
1014        }
1015    }
1016}
1017
1018fn covariance_unweighted_slice(xs: &[f64], ys: &[f64], normalization: CovNormalization) -> f64 {
1019    if xs.is_empty() || ys.is_empty() {
1020        return f64::NAN;
1021    }
1022    let n = xs.len().min(ys.len());
1023    if n == 0 {
1024        return f64::NAN;
1025    }
1026    let denom = match normalization {
1027        CovNormalization::Unbiased => (n as f64) - 1.0,
1028        CovNormalization::Biased => n as f64,
1029    };
1030    if denom <= 0.0 {
1031        return f64::NAN;
1032    }
1033    let sum_x: f64 = xs.iter().take(n).sum();
1034    let sum_y: f64 = ys.iter().take(n).sum();
1035    let mean_x = sum_x / (n as f64);
1036    let mean_y = sum_y / (n as f64);
1037    let mut accumulator = 0.0;
1038    for idx in 0..n {
1039        accumulator += (xs[idx] - mean_x) * (ys[idx] - mean_y);
1040    }
1041    accumulator / denom
1042}
1043
/// Weighted covariance of two paired samples.
///
/// Only the first `min(xs.len(), ys.len(), weights.len())` elements are used.
/// - The weighted means are `sum(w * x) / sum(w)`.
/// - The weighted sum of centered products is divided by `sum(w) - 1`.
///
/// Returns NaN when there are no observations, or when `sum(w)` or `sum(w) - 1` is not
/// positive. A NaN weight sum slips past both checks and propagates to a NaN result.
fn covariance_weighted_slice(xs: &[f64], ys: &[f64], weights: &[f64]) -> f64 {
    let count = xs.len().min(ys.len()).min(weights.len());
    if count == 0 {
        return f64::NAN;
    }
    let (xs, ys, weights) = (&xs[..count], &ys[..count], &weights[..count]);

    let sum_w: f64 = weights.iter().sum();
    if sum_w <= 0.0 {
        return f64::NAN;
    }
    // Kept as two separate `<= 0.0` tests so a NaN sum falls through, as in the checks above.
    let denom = sum_w - 1.0;
    if denom <= 0.0 {
        return f64::NAN;
    }

    let (weighted_x, weighted_y) = weights
        .iter()
        .zip(xs.iter().zip(ys))
        .fold((0.0_f64, 0.0_f64), |(sx, sy), (w, (x, y))| {
            (sx + w * x, sy + w * y)
        });
    let mean_x = weighted_x / sum_w;
    let mean_y = weighted_y / sum_w;

    let centered_sum = weights
        .iter()
        .zip(xs.iter().zip(ys))
        .fold(0.0_f64, |acc, (w, (x, y))| {
            acc + w * (x - mean_x) * (y - mean_y)
        });
    centered_sum / denom
}
1074
/// Clamps tiny negative variances on the diagonal to exactly zero.
///
/// A diagonal entry strictly inside `(-1e-12, 0)` becomes `0.0`. Presumably such values
/// are floating-point round-off from the centered-product sums rather than true negative
/// variances. Off-diagonal entries, non-finite values, and anything outside that open
/// interval are returned unchanged.
fn sanitize_covariance(is_diag: bool, value: f64) -> f64 {
    const ROUND_OFF_FLOOR: f64 = -1.0e-12;
    let is_round_off = is_diag && value.is_finite() && value > ROUND_OFF_FLOOR && value < 0.0;
    if is_round_off {
        0.0
    } else {
        value
    }
}
1085
/// Writes `value` into a `dim x dim` buffer at `(row, col)`, using the column-major index
/// `row + col * dim`.
///
/// When the position is off-diagonal, the same value is also written at `(col, row)`,
/// which keeps the matrix symmetric.
fn set_entry(buffer: &mut [f64], dim: usize, row: usize, col: usize, value: f64) {
    let positions = [(row, col), (col, row)];
    // The diagonal has no distinct mirror, so only the first position applies.
    let targets = if row == col {
        &positions[..1]
    } else {
        &positions[..]
    };
    for &(r, c) in targets {
        buffer[r + c * dim] = value;
    }
}
1094
#[cfg(test)]
pub(crate) mod tests {
    // Unit tests for the `cov` builtin. They cover type inference, descriptor metadata,
    // numeric results on host tensors, argument-validation errors, and the GPU paths
    // (test provider always; wgpu only with the `wgpu` feature).
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ResolveContext, Tensor, Type};

    /// Asserts that `actual` is a square `dim x dim` tensor matching `expected`.
    ///
    /// `dim` is the square root of `expected.len()`, truncated to an integer. Values are
    /// compared element-wise in linear order:
    /// - a NaN in `expected` requires a NaN in `actual`;
    /// - any other value must lie within `tol`, so an unexpected NaN also fails.
    fn assert_tensor_close(actual: &Tensor, expected: &[f64], tol: f64) {
        let dim = (expected.len() as f64).sqrt() as usize;
        assert_eq!(actual.shape, vec![dim, dim], "unexpected tensor shape");
        for (idx, (&got, &want)) in actual.data.iter().zip(expected.iter()).enumerate() {
            if want.is_nan() {
                assert!(
                    got.is_nan(),
                    "expected NaN at linear index {idx}, found {got}"
                );
            } else {
                assert!(
                    (got - want).abs() <= tol,
                    "mismatch at linear index {idx}: got {got}, expected {want}"
                );
            }
        }
    }

    /// A 5x3 input resolves to a 3x3 output type (one row/column per variable).
    #[test]
    fn cov_type_preserves_column_count() {
        let out = cov_type(
            &[Type::Tensor {
                shape: Some(vec![Some(5), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(3), Some(3)])
            }
        );
    }

    /// A 1x4 row vector resolves to a scalar (`Type::Num`) result type.
    #[test]
    fn cov_type_vector_returns_scalar() {
        let out = cov_type(
            &[Type::Tensor {
                shape: Some(vec![Some(1), Some(4)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(out, Type::Num);
    }

    /// The builtin descriptor advertises at least these three call forms.
    #[test]
    fn cov_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = COV_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"C = cov(X)"));
        assert!(labels.contains(&"C = cov(X, normalization)"));
        assert!(labels.contains(&"C = cov(X, Y, w, opt)"));
    }

    /// Blocking wrapper around the async `cov_builtin`. Only the wgpu test uses it, so it
    /// is compiled only with the `wgpu` feature.
    #[cfg(feature = "wgpu")]
    fn cov_builtin_sync(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::cov_builtin(value, rest))
    }

    /// Default call on a 5x3 matrix. The expected values are the sample (N-1 normalized)
    /// covariances of the three columns.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_matrix_basic() {
        let tensor = Tensor::new(
            vec![
                4.0, 4.2, 3.9, 4.3, 4.1, //
                2.0, 2.1, 2.0, 2.1, 2.2, //
                0.60, 0.59, 0.58, 0.62, 0.63,
            ],
            vec![5, 3],
        )
        .unwrap();
        let result = block_on(cov_builtin(Value::Tensor(tensor), Vec::new())).expect("cov");
        let tensor = match result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };
        let expected = [
            0.0250, 0.0075, 0.00175, //
            0.0075, 0.0070, 0.00135, //
            0.00175, 0.00135, 0.00043,
        ];
        assert_tensor_close(&tensor, &expected, 1.0e-6);
    }

    /// `cov(x, y)` with two 4x1 column vectors yields a 2x2 matrix. The diagonal holds
    /// var(x) and var(y); the off-diagonal entries hold cov(x, y).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_two_vectors() {
        let x = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
        let y = Tensor::new(vec![10.0, 11.0, 9.0, 12.0], vec![4, 1]).unwrap();
        let result = block_on(cov_builtin(Value::Tensor(x), vec![Value::Tensor(y)])).expect("cov");
        let tensor = match result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };
        let expected = [
            1.6666666666666667,
            0.6666666666666666, //
            0.6666666666666666,
            1.6666666666666667,
        ];
        assert_tensor_close(&tensor, &expected, 1.0e-6);
    }

    /// A length-5 second argument on a 5x2 input is treated as a weight vector. The
    /// expected values are consistent with weighted means normalized by sum(w) - 1,
    /// here 7 - 1 = 6.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_weighted_vector() {
        let tensor = Tensor::new(
            vec![
                4.0, 4.2, 3.9, 4.3, 4.1, //
                2.0, 2.1, 2.0, 2.1, 2.2,
            ],
            vec![5, 2],
        )
        .unwrap();
        let weights = Tensor::new(vec![1.0, 1.0, 1.0, 2.0, 2.0], vec![5, 1]).unwrap();
        let result = block_on(cov_builtin(
            Value::Tensor(tensor),
            vec![Value::Tensor(weights)],
        ))
        .expect("cov");
        let tensor = match result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };
        let expected = [
            0.022380952380952376,
            0.004999999999999994, //
            0.004999999999999994,
            0.006666666666666678,
        ];
        assert_tensor_close(&tensor, &expected, 1.0e-6);
    }

    /// `'omitrows'` drops every row containing a NaN. In this 4x3 input only rows 1 and 3
    /// (0-based) are complete. Each column changes by 5 between those two rows, so every
    /// entry is 5 * 5 / (2 - 1) = 12.5.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_omitrows() {
        let tensor = Tensor::new(
            vec![
                1.0,
                3.0,
                f64::NAN,
                8.0, //
                f64::NAN,
                4.0,
                6.0,
                9.0, //
                2.0,
                5.0,
                7.0,
                10.0,
            ],
            vec![4, 3],
        )
        .unwrap();
        let result = block_on(cov_builtin(
            Value::Tensor(tensor),
            vec![Value::from("omitrows")],
        ))
        .expect("cov");
        let tensor = match result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };
        let expected = [
            12.5, 12.5, 12.5, //
            12.5, 12.5, 12.5, //
            12.5, 12.5, 12.5,
        ];
        assert_tensor_close(&tensor, &expected, 1.0e-6);
    }

    /// `'partialrows'` uses, for each column pair, only the rows where both columns are
    /// finite. Columns 1 and 2 (0-based) share just one such row (row 2). With N - 1 = 0,
    /// their covariance is expected to be NaN.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_partialrows() {
        let tensor = Tensor::new(
            vec![
                1.0,
                4.0,
                7.0, //
                2.0,
                f64::NAN,
                8.0, //
                f64::NAN,
                6.0,
                9.0,
            ],
            vec![3, 3],
        )
        .unwrap();
        let result = block_on(cov_builtin(
            Value::Tensor(tensor),
            vec![Value::from("partialrows")],
        ))
        .expect("cov");
        let tensor = match result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };
        let expected = [
            9.0,
            18.0,
            4.5, //
            18.0,
            18.0,
            f64::NAN, //
            4.5,
            f64::NAN,
            4.5,
        ];
        assert_tensor_close(&tensor, &expected, 1.0e-6);
    }

    /// X and Y with different row counts (4 vs 3) must fail with the rows-mismatch error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_mismatched_rows_errors() {
        let left = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
        let right = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = block_on(cov_builtin(Value::Tensor(left), vec![Value::Tensor(right)]))
            .expect_err("expected mismatch error");
        assert_eq!(err.identifier(), COV_ERROR_ROWS_MISMATCH.identifier);
    }

    /// A scalar normalization flag of 2.5 is rejected as an invalid normalization.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_invalid_flag_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = block_on(cov_builtin(Value::Tensor(tensor), vec![Value::Num(2.5)]))
            .expect_err("expected invalid flag error");
        assert_eq!(err.identifier(), COV_ERROR_NORMALIZATION_INVALID.identifier);
    }

    /// `cov(X, Y, w)` where `w` has 2 elements but X and Y have 3 rows must fail with
    /// the weight-length error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_weight_vector_length_mismatch_errors() {
        let x = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
        let y = Tensor::new(vec![10.0, 11.0, 12.0], vec![3, 1]).unwrap();
        let w = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = block_on(cov_builtin(
            Value::Tensor(x),
            vec![Value::Tensor(y), Value::Tensor(w)],
        ))
        .expect_err("expected weight length mismatch");
        assert_eq!(
            err.identifier(),
            COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH.identifier
        );
    }

    /// `"rows"` followed by an unrecognized option value is rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_unknown_rows_option_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = block_on(cov_builtin(
            Value::Tensor(tensor),
            vec![Value::from("rows"), Value::from("bogus")],
        ))
        .expect_err("expected unknown rows option error");
        assert_eq!(err.identifier(), COV_ERROR_ROWS_OPTION_UNKNOWN.identifier);
    }

    /// Supplying two scalar normalization flags (0 and 1) is rejected as a duplicate.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_duplicate_normalization_flag_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = block_on(cov_builtin(
            Value::Tensor(tensor),
            vec![Value::Num(0.0), Value::Num(1.0)],
        ))
        .expect_err("expected duplicate normalization flag error");
        assert_eq!(
            err.identifier(),
            COV_ERROR_NORMALIZATION_DUPLICATE.identifier
        );
    }

    /// Four array arguments (X plus three more) exceed what `cov` accepts.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_too_many_array_arguments_errors() {
        let x = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let y = Tensor::new(vec![4.0, 5.0, 6.0], vec![3, 1]).unwrap();
        let w = Tensor::new(vec![1.0, 1.0, 1.0], vec![3, 1]).unwrap();
        let z = Tensor::new(vec![7.0, 8.0, 9.0], vec![3, 1]).unwrap();
        let err = block_on(cov_builtin(
            Value::Tensor(x),
            vec![Value::Tensor(y), Value::Tensor(w), Value::Tensor(z)],
        ))
        .expect_err("expected too many array arguments error");
        assert_eq!(
            err.identifier(),
            COV_ERROR_TOO_MANY_ARRAY_ARGUMENTS.identifier
        );
    }

    /// Uploads a 5x2 input to the test provider, runs `cov` on the GPU handle, and gathers
    /// the result. The expected values equal the first two columns of `cov_matrix_basic`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(
                vec![
                    4.0, 4.2, 3.9, 4.3, 4.1, //
                    2.0, 2.1, 2.0, 2.1, 2.2,
                ],
                vec![5, 2],
            )
            .unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = block_on(cov_builtin(Value::GpuTensor(handle), Vec::new())).expect("cov");
            let gathered = test_support::gather(result).expect("gather");
            let expected = [
                0.0250, 0.0075, //
                0.0075, 0.0070,
            ];
            assert_tensor_close(&gathered, &expected, 1.0e-6);
        });
    }

    /// With the `wgpu` feature, the GPU result must match the CPU result within 1e-6.
    /// The registration result is ignored (`let _`); presumably this tolerates a provider
    /// that is already registered — confirm.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn cov_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(
            vec![
                4.0, 4.2, 3.9, 4.3, 4.1, //
                2.0, 2.1, 2.0, 2.1, 2.2,
            ],
            vec![5, 2],
        )
        .unwrap();

        let cpu_result =
            block_on(cov_builtin(Value::Tensor(tensor.clone()), Vec::new())).expect("cov");
        let cpu_tensor = match cpu_result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };

        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");

        let gpu_value = cov_builtin_sync(Value::GpuTensor(handle), Vec::new()).expect("cov");
        let gathered = test_support::gather(gpu_value).expect("gather");

        // The CPU tensor's flat data (4 values) doubles as the expected 2x2 matrix.
        assert_tensor_close(&gathered, &cpu_tensor.data, 1.0e-6);
    }
}