1use runmat_accelerate_api::{CovNormalization, CovRows, CovarianceOptions};
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 Tensor, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::gpu_helpers;
12use crate::builtins::common::spec::{
13 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::tensor;
17use crate::builtins::stats::type_resolvers::cov_type;
18use crate::{build_runtime_error, BuiltinResult, RuntimeError};
19
20const NAME: &str = "cov";
21const COV_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
22 name: "C",
23 ty: BuiltinParamType::NumericArray,
24 arity: BuiltinParamArity::Required,
25 default: None,
26 description: "Covariance matrix.",
27}];
28
29const COV_INPUTS_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
30 name: "X",
31 ty: BuiltinParamType::Any,
32 arity: BuiltinParamArity::Required,
33 default: None,
34 description: "Input observations (rows are observations, columns are variables).",
35}];
36
37const COV_INPUTS_X_Y_OR_W: [BuiltinParamDescriptor; 2] = [
38 BuiltinParamDescriptor {
39 name: "X",
40 ty: BuiltinParamType::Any,
41 arity: BuiltinParamArity::Required,
42 default: None,
43 description: "Input observations (rows are observations, columns are variables).",
44 },
45 BuiltinParamDescriptor {
46 name: "Y_or_w",
47 ty: BuiltinParamType::Any,
48 arity: BuiltinParamArity::Required,
49 default: None,
50 description: "Second dataset (Y) or weight vector (w), depending on shape/position.",
51 },
52];
53
54const COV_INPUTS_X_NORMALIZATION: [BuiltinParamDescriptor; 2] = [
55 BuiltinParamDescriptor {
56 name: "X",
57 ty: BuiltinParamType::Any,
58 arity: BuiltinParamArity::Required,
59 default: None,
60 description: "Input observations (rows are observations, columns are variables).",
61 },
62 BuiltinParamDescriptor {
63 name: "normalization",
64 ty: BuiltinParamType::NumericScalar,
65 arity: BuiltinParamArity::Required,
66 default: Some("0"),
67 description: "Normalization flag: 0 (unbiased) or 1 (biased).",
68 },
69];
70
71const COV_INPUTS_X_ROWS: [BuiltinParamDescriptor; 2] = [
72 BuiltinParamDescriptor {
73 name: "X",
74 ty: BuiltinParamType::Any,
75 arity: BuiltinParamArity::Required,
76 default: None,
77 description: "Input observations (rows are observations, columns are variables).",
78 },
79 BuiltinParamDescriptor {
80 name: "rows_option",
81 ty: BuiltinParamType::StringScalar,
82 arity: BuiltinParamArity::Required,
83 default: Some("\"all\""),
84 description: "Rows handling mode: 'all', 'omitrows', or 'partialrows'.",
85 },
86];
87
88const COV_INPUTS_X_Y_OPT: [BuiltinParamDescriptor; 3] = [
89 BuiltinParamDescriptor {
90 name: "X",
91 ty: BuiltinParamType::Any,
92 arity: BuiltinParamArity::Required,
93 default: None,
94 description: "Input observations (rows are observations, columns are variables).",
95 },
96 BuiltinParamDescriptor {
97 name: "Y",
98 ty: BuiltinParamType::Any,
99 arity: BuiltinParamArity::Required,
100 default: None,
101 description: "Second dataset with matching row count.",
102 },
103 BuiltinParamDescriptor {
104 name: "opt",
105 ty: BuiltinParamType::Any,
106 arity: BuiltinParamArity::Required,
107 default: None,
108 description: "Normalization flag or rows option.",
109 },
110];
111
112const COV_INPUTS_X_Y_W: [BuiltinParamDescriptor; 3] = [
113 BuiltinParamDescriptor {
114 name: "X",
115 ty: BuiltinParamType::Any,
116 arity: BuiltinParamArity::Required,
117 default: None,
118 description: "Input observations (rows are observations, columns are variables).",
119 },
120 BuiltinParamDescriptor {
121 name: "Y",
122 ty: BuiltinParamType::Any,
123 arity: BuiltinParamArity::Required,
124 default: None,
125 description: "Second dataset with matching row count.",
126 },
127 BuiltinParamDescriptor {
128 name: "w",
129 ty: BuiltinParamType::Any,
130 arity: BuiltinParamArity::Required,
131 default: None,
132 description: "Weight vector with one weight per observation row.",
133 },
134];
135
136const COV_INPUTS_X_Y_W_OPT: [BuiltinParamDescriptor; 4] = [
137 BuiltinParamDescriptor {
138 name: "X",
139 ty: BuiltinParamType::Any,
140 arity: BuiltinParamArity::Required,
141 default: None,
142 description: "Input observations (rows are observations, columns are variables).",
143 },
144 BuiltinParamDescriptor {
145 name: "Y",
146 ty: BuiltinParamType::Any,
147 arity: BuiltinParamArity::Required,
148 default: None,
149 description: "Second dataset with matching row count.",
150 },
151 BuiltinParamDescriptor {
152 name: "w",
153 ty: BuiltinParamType::Any,
154 arity: BuiltinParamArity::Required,
155 default: None,
156 description: "Weight vector with one weight per observation row.",
157 },
158 BuiltinParamDescriptor {
159 name: "opt",
160 ty: BuiltinParamType::Any,
161 arity: BuiltinParamArity::Required,
162 default: None,
163 description: "Normalization flag or rows option.",
164 },
165];
166
167const COV_SIGNATURES: [BuiltinSignatureDescriptor; 7] = [
168 BuiltinSignatureDescriptor {
169 label: "C = cov(X)",
170 inputs: &COV_INPUTS_X,
171 outputs: &COV_OUTPUT,
172 },
173 BuiltinSignatureDescriptor {
174 label: "C = cov(X, Y_or_w)",
175 inputs: &COV_INPUTS_X_Y_OR_W,
176 outputs: &COV_OUTPUT,
177 },
178 BuiltinSignatureDescriptor {
179 label: "C = cov(X, normalization)",
180 inputs: &COV_INPUTS_X_NORMALIZATION,
181 outputs: &COV_OUTPUT,
182 },
183 BuiltinSignatureDescriptor {
184 label: "C = cov(X, rows_option)",
185 inputs: &COV_INPUTS_X_ROWS,
186 outputs: &COV_OUTPUT,
187 },
188 BuiltinSignatureDescriptor {
189 label: "C = cov(X, Y, opt)",
190 inputs: &COV_INPUTS_X_Y_OPT,
191 outputs: &COV_OUTPUT,
192 },
193 BuiltinSignatureDescriptor {
194 label: "C = cov(X, Y, w)",
195 inputs: &COV_INPUTS_X_Y_W,
196 outputs: &COV_OUTPUT,
197 },
198 BuiltinSignatureDescriptor {
199 label: "C = cov(X, Y, w, opt)",
200 inputs: &COV_INPUTS_X_Y_W_OPT,
201 outputs: &COV_OUTPUT,
202 },
203];
204
205const COV_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
206 code: "RM.COV.INVALID_ARGUMENT",
207 identifier: Some("RunMat:cov:InvalidArgument"),
208 when: "Arguments are malformed or unsupported for cov.",
209 message: "cov: invalid argument",
210};
211
212const COV_ERROR_COMPLEX_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
213 code: "RM.COV.COMPLEX_UNSUPPORTED",
214 identifier: Some("RunMat:cov:ComplexUnsupported"),
215 when: "Any argument is complex-valued.",
216 message: "cov: complex inputs are not supported yet",
217};
218
219const COV_ERROR_ROWS_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
220 code: "RM.COV.ROWS_MISMATCH",
221 identifier: Some("RunMat:cov:RowsMismatch"),
222 when: "Two input datasets do not have the same number of rows.",
223 message: "cov: inputs must have the same number of rows",
224};
225
226const COV_ERROR_NORMALIZATION_INVALID: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
227 code: "RM.COV.NORMALIZATION_INVALID",
228 identifier: Some("RunMat:cov:NormalizationInvalid"),
229 when: "Normalization flag is non-finite, non-integer, or not 0/1.",
230 message: "cov: normalization flag is invalid",
231};
232
233const COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
234 code: "RM.COV.WEIGHT_VECTOR_LENGTH_MISMATCH",
235 identifier: Some("RunMat:cov:WeightVectorLengthMismatch"),
236 when: "Weight vector length does not match observation row count.",
237 message: "cov: weight vector length mismatch",
238};
239
240const COV_ERROR_ROWS_OPTION_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
241 code: "RM.COV.ROWS_OPTION_UNKNOWN",
242 identifier: Some("RunMat:cov:RowsOptionUnknown"),
243 when: "Rows option is not one of all/omitrows/partialrows.",
244 message: "cov: unknown rows option",
245};
246
247const COV_ERROR_NORMALIZATION_DUPLICATE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
248 code: "RM.COV.NORMALIZATION_DUPLICATE",
249 identifier: Some("RunMat:cov:NormalizationDuplicate"),
250 when: "Normalization flag is provided more than once.",
251 message: "cov: normalization flag specified more than once",
252};
253
254const COV_ERROR_TOO_MANY_ARRAY_ARGUMENTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
255 code: "RM.COV.TOO_MANY_ARRAY_ARGUMENTS",
256 identifier: Some("RunMat:cov:TooManyArrayArguments"),
257 when: "More than two data arrays (or Y plus weight) are provided.",
258 message: "cov: too many array arguments",
259};
260
261const COV_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
262 code: "RM.COV.INTERNAL",
263 identifier: Some("RunMat:cov:Internal"),
264 when: "Internal tensor conversion/allocation or covariance computation fails.",
265 message: "cov: internal operation failed",
266};
267
268const COV_ERRORS: [BuiltinErrorDescriptor; 9] = [
269 COV_ERROR_INVALID_ARGUMENT,
270 COV_ERROR_COMPLEX_UNSUPPORTED,
271 COV_ERROR_ROWS_MISMATCH,
272 COV_ERROR_NORMALIZATION_INVALID,
273 COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH,
274 COV_ERROR_ROWS_OPTION_UNKNOWN,
275 COV_ERROR_NORMALIZATION_DUPLICATE,
276 COV_ERROR_TOO_MANY_ARRAY_ARGUMENTS,
277 COV_ERROR_INTERNAL,
278];
279
280pub const COV_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
281 signatures: &COV_SIGNATURES,
282 output_mode: BuiltinOutputMode::Fixed,
283 completion_policy: BuiltinCompletionPolicy::Public,
284 errors: &COV_ERRORS,
285};
286
287fn cov_error_with(
288 error: &'static BuiltinErrorDescriptor,
289 message: impl Into<String>,
290) -> RuntimeError {
291 let mut builder = build_runtime_error(message).with_builtin(NAME);
292 if let Some(identifier) = error.identifier {
293 builder = builder.with_identifier(identifier);
294 }
295 builder.build()
296}
297
298fn cov_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
299 cov_error_with(error, error.message)
300}
301
302fn cov_error_with_detail(
303 error: &'static BuiltinErrorDescriptor,
304 detail: impl std::fmt::Display,
305) -> RuntimeError {
306 cov_error_with(error, format!("{}: {detail}", error.message))
307}
308
309fn cov_internal_error(message: impl Into<String>) -> RuntimeError {
310 cov_error_with(&COV_ERROR_INTERNAL, message)
311}
312
313#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::stats::summary::cov")]
314pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
315 name: "cov",
316 op_kind: GpuOpKind::Custom("summary-stats"),
317 supported_precisions: &[ScalarType::F32, ScalarType::F64],
318 broadcast: BroadcastSemantics::None,
319 provider_hooks: &[ProviderHook::Custom("covariance")],
320 constant_strategy: ConstantStrategy::InlineLiteral,
321 residency: ResidencyPolicy::NewHandle,
322 nan_mode: ReductionNaN::Include,
323 two_pass_threshold: None,
324 workgroup_size: None,
325 accepts_nan_mode: false,
326 notes: "GPU execution is available when rows='all' and no weight vector is supplied; other cases fall back to the CPU path.",
327};
328
329#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::stats::summary::cov")]
330pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
331 name: "cov",
332 shape: ShapeRequirements::Any,
333 constant_strategy: ConstantStrategy::InlineLiteral,
334 elementwise: None,
335 reduction: None,
336 emits_nan: true,
337 notes: "The covariance builtin is treated as a fusion boundary and executes via dedicated kernels or the host reference.",
338};
339
340#[runtime_builtin(
341 name = "cov",
342 category = "stats/summary",
343 summary = "Compute covariance matrices.",
344 keywords = "cov,covariance,statistics,weights,gpu",
345 accel = "reduction",
346 type_resolver(cov_type),
347 descriptor(crate::builtins::stats::summary::cov::COV_DESCRIPTOR),
348 builtin_path = "crate::builtins::stats::summary::cov"
349)]
350async fn cov_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
351 let args = CovArgs::parse(value, rest)?;
352 if let Some(result) = cov_try_gpu(&args).await? {
353 return Ok(result);
354 }
355 cov_host(args).await
356}
357
358pub fn cov_from_tensors(
360 left: Tensor,
361 right: Option<Tensor>,
362 rows: CovRows,
363 weight: CovWeightSpec,
364) -> BuiltinResult<Tensor> {
365 let matrix = combine_tensors(left, right)?;
366 if let CovWeightSpec::Vector(ref vec) = weight {
367 if matrix.rows != vec.len() {
368 return Err(cov_error_with_detail(
369 &COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH,
370 format!("expected {} elements", matrix.rows),
371 ));
372 }
373 }
374 match rows {
375 CovRows::All => covariance_dense(&matrix, &weight),
376 CovRows::OmitRows => {
377 let (filtered, filtered_weight) = filter_complete_rows(&matrix, weight);
378 covariance_dense(&filtered, &filtered_weight)
379 }
380 CovRows::PartialRows => covariance_pairwise(&matrix, &weight),
381 }
382}
383
384#[derive(Debug)]
385struct CovArgs {
386 first: Value,
387 second: Option<Value>,
388 normalization: CovNormalization,
389 rows: CovRows,
390 weight_vector: Option<Value>,
391}
392
393impl CovArgs {
394 fn parse(first: Value, rest: Vec<Value>) -> BuiltinResult<Self> {
395 let mut second_candidate: Option<Value> = None;
396 let mut weight_candidate: Option<Value> = None;
397 let mut normalization = CovNormalization::Unbiased;
398 let mut normalization_explicit = false;
399 let mut rows = CovRows::All;
400
401 let iter = rest.into_iter();
402 for arg in iter {
403 match arg {
404 Value::String(_) | Value::StringArray(_) | Value::CharArray(_) => {
405 let key = tensor::value_to_string(&arg)
406 .ok_or_else(|| cov_error(&COV_ERROR_INVALID_ARGUMENT))?;
407 let lowered = key.trim().to_ascii_lowercase();
408 rows = parse_rows_option(&lowered)?;
409 }
410 Value::Tensor(_) | Value::LogicalArray(_) | Value::GpuTensor(_) => {
411 if second_candidate.is_none() {
412 second_candidate = Some(arg);
413 } else if weight_candidate.is_none() {
414 weight_candidate = Some(arg);
415 } else {
416 return Err(cov_error(&COV_ERROR_TOO_MANY_ARRAY_ARGUMENTS));
417 }
418 }
419 Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
420 if normalization_explicit || weight_candidate.is_some() {
421 return Err(cov_error(&COV_ERROR_NORMALIZATION_DUPLICATE));
422 }
423 normalization = parse_normalization(arg)?;
424 normalization_explicit = true;
425 }
426 Value::ComplexTensor(_) => {
427 return Err(cov_error(&COV_ERROR_COMPLEX_UNSUPPORTED));
428 }
429 other => {
430 return Err(cov_error_with_detail(
431 &COV_ERROR_INVALID_ARGUMENT,
432 format!("{other:?}"),
433 ))
434 }
435 }
436 }
437
438 if let Some(weight_array) = weight_candidate {
439 return Ok(Self {
441 first,
442 second: second_candidate,
443 normalization,
444 rows,
445 weight_vector: Some(weight_array),
446 });
447 }
448
449 let mut second = second_candidate;
450 let mut weight_vector: Option<Value> = None;
451
452 if let Some(candidate) = second.take() {
453 if should_treat_as_weight(&first, &candidate, normalization_explicit, rows)? {
454 weight_vector = Some(candidate);
455 } else {
456 second = Some(candidate);
457 }
458 }
459
460 Ok(Self {
461 first,
462 second,
463 normalization,
464 rows,
465 weight_vector,
466 })
467 }
468}
469
470#[derive(Debug, Clone)]
471pub enum CovWeightSpec {
472 Scalar(CovNormalization),
473 Vector(Vec<f64>),
474}
475
476async fn cov_try_gpu(args: &CovArgs) -> BuiltinResult<Option<Value>> {
477 if args.rows != CovRows::All || args.weight_vector.is_some() {
478 return Ok(None);
479 }
480
481 let provider = match runmat_accelerate_api::provider() {
482 Some(p) => p,
483 None => return Ok(None),
484 };
485
486 let first_handle = match &args.first {
487 Value::GpuTensor(handle) => handle,
488 _ => return Ok(None),
489 };
490
491 let maybe_second_handle = match &args.second {
492 Some(Value::GpuTensor(handle)) => Some(handle),
493 Some(_) => return Ok(None),
494 None => None,
495 };
496
497 let options = CovarianceOptions {
498 normalization: args.normalization,
499 rows: args.rows,
500 has_weight_vector: false,
501 };
502
503 match provider
504 .covariance(first_handle, maybe_second_handle, None, &options)
505 .await
506 {
507 Ok(result) => Ok(Some(Value::GpuTensor(result))),
508 Err(_) => Ok(None),
509 }
510}
511
512async fn cov_host(args: CovArgs) -> BuiltinResult<Value> {
513 let CovArgs {
514 first,
515 second,
516 normalization,
517 rows,
518 weight_vector,
519 } = args;
520
521 let left = value_to_tensor_gather(first).await?;
522 let right = match second {
523 Some(value) => Some(value_to_tensor_gather(value).await?),
524 None => None,
525 };
526
527 let weight_spec = if let Some(weight_value) = weight_vector {
528 let vector = value_to_weight_vector(weight_value, left.rows()).await?;
529 CovWeightSpec::Vector(vector)
530 } else {
531 CovWeightSpec::Scalar(normalization)
532 };
533
534 let tensor = cov_from_tensors(left, right, rows, weight_spec)?;
535 Ok(Value::Tensor(tensor))
536}
537
538async fn value_to_tensor_gather(value: Value) -> BuiltinResult<Tensor> {
539 match value {
540 Value::GpuTensor(handle) => gpu_helpers::gather_tensor_async(&handle).await,
541 Value::LogicalArray(logical) => {
542 tensor::logical_to_tensor(&logical).map_err(cov_internal_error)
543 }
544 other => tensor::value_into_tensor_for("cov", other).map_err(cov_internal_error),
545 }
546}
547
548async fn value_to_weight_vector(value: Value, expected_rows: usize) -> BuiltinResult<Vec<f64>> {
549 let tensor = match value {
550 Value::GpuTensor(handle) => gpu_helpers::gather_tensor_async(&handle).await?,
551 Value::LogicalArray(logical) => {
552 tensor::logical_to_tensor(&logical).map_err(cov_internal_error)?
553 }
554 other => tensor::value_into_tensor_for("cov", other).map_err(cov_internal_error)?,
555 };
556
557 if tensor.shape.len() > 2 {
558 return Err(cov_error_with_detail(
559 &COV_ERROR_INVALID_ARGUMENT,
560 "weight vector must be one-dimensional",
561 ));
562 }
563 if tensor.rows() != expected_rows && tensor.cols() != expected_rows {
564 return Err(cov_error_with_detail(
565 &COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH,
566 format!("expected {expected_rows} elements"),
567 ));
568 }
569 for (idx, weight) in tensor.data.iter().enumerate() {
570 if !weight.is_finite() || *weight < 0.0 {
571 return Err(cov_error_with_detail(
572 &COV_ERROR_INVALID_ARGUMENT,
573 format!("weights must be non-negative finite values (index {idx})"),
574 ));
575 }
576 }
577 if tensor.data.is_empty() {
578 return Err(cov_error_with_detail(
579 &COV_ERROR_INVALID_ARGUMENT,
580 "weight vector cannot be empty",
581 ));
582 }
583 Ok(tensor.data)
584}
585
586fn parse_rows_option(value: &str) -> BuiltinResult<CovRows> {
587 match value {
588 "all" => Ok(CovRows::All),
589 "omitrows" | "omit" => Ok(CovRows::OmitRows),
590 "partialrows" | "partial" | "pairwise" => Ok(CovRows::PartialRows),
591 other => Err(cov_error_with_detail(
592 &COV_ERROR_ROWS_OPTION_UNKNOWN,
593 format!("'{other}'"),
594 )),
595 }
596}
597
598fn parse_normalization(value: Value) -> BuiltinResult<CovNormalization> {
599 match value {
600 Value::Int(i) => match i.to_i64() {
601 0 => Ok(CovNormalization::Unbiased),
602 1 => Ok(CovNormalization::Biased),
603 other => Err(cov_error_with_detail(
604 &COV_ERROR_NORMALIZATION_INVALID,
605 format!("expected 0 or 1, received {other}"),
606 )),
607 },
608 Value::Num(n) => {
609 if !n.is_finite() {
610 return Err(cov_error_with_detail(
611 &COV_ERROR_NORMALIZATION_INVALID,
612 "value must be finite",
613 ));
614 }
615 let rounded = n.round();
616 if (rounded - n).abs() > 1.0e-12 {
617 return Err(cov_error_with_detail(
618 &COV_ERROR_NORMALIZATION_INVALID,
619 "value must be an integer",
620 ));
621 }
622 match rounded as i64 {
623 0 => Ok(CovNormalization::Unbiased),
624 1 => Ok(CovNormalization::Biased),
625 other => Err(cov_error_with_detail(
626 &COV_ERROR_NORMALIZATION_INVALID,
627 format!("expected 0 or 1, received {other}"),
628 )),
629 }
630 }
631 Value::Bool(flag) => Ok(if flag {
632 CovNormalization::Biased
633 } else {
634 CovNormalization::Unbiased
635 }),
636 other => Err(cov_error_with_detail(
637 &COV_ERROR_NORMALIZATION_INVALID,
638 format!("value must be numeric, received {other:?}"),
639 )),
640 }
641}
642
643fn should_treat_as_weight(
644 first: &Value,
645 candidate: &Value,
646 normalization_explicit: bool,
647 rows_option: CovRows,
648) -> BuiltinResult<bool> {
649 let (rows_first, cols_first) = value_rows_cols(first)?;
650 let (rows_candidate, cols_candidate) = value_rows_cols(candidate)?;
651
652 let is_vector = rows_candidate == 1
653 || cols_candidate == 1
654 || rows_candidate * cols_candidate == rows_candidate
655 && (rows_candidate == rows_first || cols_candidate == rows_first);
656
657 if !is_vector {
658 return Ok(false);
659 }
660
661 if rows_candidate != rows_first && cols_candidate != rows_first {
662 return Ok(false);
664 }
665
666 if cols_first == 1 && !normalization_explicit && matches!(rows_option, CovRows::All) {
667 return Ok(false);
669 }
670
671 Ok(true)
672}
673
674fn value_rows_cols(value: &Value) -> BuiltinResult<(usize, usize)> {
675 match value {
676 Value::Tensor(tensor) => Ok((tensor.rows(), tensor.cols())),
677 Value::LogicalArray(array) => {
678 if array.shape.len() > 2 {
679 return Err(cov_error_with_detail(
680 &COV_ERROR_INVALID_ARGUMENT,
681 "inputs must be 2-D matrices or vectors",
682 ));
683 }
684 let rows = if array.shape.is_empty() {
685 1
686 } else {
687 array.shape[0]
688 };
689 let cols = if array.shape.len() >= 2 {
690 array.shape[1]
691 } else {
692 1
693 };
694 Ok((rows, cols))
695 }
696 Value::GpuTensor(handle) => {
697 if handle.shape.len() > 2 {
698 return Err(cov_error_with_detail(
699 &COV_ERROR_INVALID_ARGUMENT,
700 "inputs must be 2-D matrices or vectors",
701 ));
702 }
703 let rows = if handle.shape.is_empty() {
704 1
705 } else {
706 handle.shape[0]
707 };
708 let cols = if handle.shape.len() >= 2 {
709 handle.shape[1]
710 } else {
711 1
712 };
713 Ok((rows, cols))
714 }
715 Value::Num(_) | Value::Int(_) | Value::Bool(_) => Ok((1, 1)),
716 other => Err(cov_error_with_detail(
717 &COV_ERROR_INVALID_ARGUMENT,
718 format!("unsupported input type for shape inspection: {other:?}"),
719 )),
720 }
721}
722
723#[derive(Debug, Clone)]
724struct Matrix {
725 data: Vec<f64>,
726 rows: usize,
727 cols: usize,
728}
729
730impl Matrix {
731 fn from_tensor(name: &str, tensor: Tensor) -> BuiltinResult<Self> {
732 if tensor.shape.len() > 2 {
733 return Err(cov_error_with_detail(
734 &COV_ERROR_INVALID_ARGUMENT,
735 format!("{name}: inputs must be 2-D matrices or vectors"),
736 ));
737 }
738 Ok(Self {
739 rows: tensor.rows(),
740 cols: tensor.cols(),
741 data: tensor.data,
742 })
743 }
744
745 #[inline]
746 fn get(&self, row: usize, col: usize) -> f64 {
747 self.data[row + col * self.rows]
748 }
749
750 #[inline]
751 fn column(&self, col: usize) -> &[f64] {
752 let start = col * self.rows;
753 let end = start + self.rows;
754 &self.data[start..end]
755 }
756}
757
758fn combine_tensors(left: Tensor, right: Option<Tensor>) -> BuiltinResult<Matrix> {
759 let mut matrix = Matrix::from_tensor("cov", left)?;
760 if let Some(second) = right {
761 let right_matrix = Matrix::from_tensor("cov", second)?;
762 if matrix.rows != right_matrix.rows {
763 return Err(cov_error(&COV_ERROR_ROWS_MISMATCH));
764 }
765 matrix.cols += right_matrix.cols;
766 matrix
767 .data
768 .extend_from_slice(&right_matrix.data[..right_matrix.rows * right_matrix.cols]);
769 }
770 Ok(matrix)
771}
772
773fn covariance_dense(matrix: &Matrix, weight: &CovWeightSpec) -> BuiltinResult<Tensor> {
774 let cols = matrix.cols;
775 let rows = matrix.rows;
776
777 if cols == 0 {
778 return Tensor::new(Vec::new(), vec![0, 0]).map_err(cov_internal_error);
779 }
780
781 let mut result = vec![f64::NAN; cols * cols];
782
783 match weight {
784 CovWeightSpec::Scalar(normalization) => {
785 let denom = match normalization {
786 CovNormalization::Unbiased => (rows as f64) - 1.0,
787 CovNormalization::Biased => rows as f64,
788 };
789 if denom <= 0.0 {
790 return Tensor::new(result, vec![cols, cols]).map_err(cov_internal_error);
791 }
792
793 let mut means = vec![0.0; cols];
794 for (col, mean_slot) in means.iter_mut().enumerate() {
795 let column = matrix.column(col);
796 let mut sum = 0.0;
797 let mut valid = true;
798 for &value in column {
799 if !value.is_finite() {
800 valid = false;
801 break;
802 }
803 sum += value;
804 }
805 *mean_slot = if valid { sum / (rows as f64) } else { f64::NAN };
806 }
807
808 for i in 0..cols {
809 for j in i..cols {
810 let value = covariance_unweighted_pair(matrix, i, j, &means, denom);
811 set_entry(&mut result, cols, i, j, sanitize_covariance(i == j, value));
812 }
813 }
814 }
815 CovWeightSpec::Vector(weights) => {
816 if weights.len() != rows {
817 return Err(cov_error_with_detail(
818 &COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH,
819 format!("expected {rows} elements"),
820 ));
821 }
822 let sum_w: f64 = weights.iter().sum();
823 if sum_w <= 0.0 {
824 return Tensor::new(result, vec![cols, cols]).map_err(cov_internal_error);
825 }
826 let denom = sum_w - 1.0;
827 if denom <= 0.0 {
828 return Tensor::new(result, vec![cols, cols]).map_err(cov_internal_error);
829 }
830
831 let mut means = vec![0.0; cols];
832 for (col, mean_slot) in means.iter_mut().enumerate() {
833 let column = matrix.column(col);
834 let mut weighted_sum = 0.0;
835 let mut valid = true;
836 for (row, &value) in column.iter().enumerate() {
837 if !value.is_finite() {
838 valid = false;
839 break;
840 }
841 weighted_sum += weights[row] * value;
842 }
843 *mean_slot = if valid {
844 weighted_sum / sum_w
845 } else {
846 f64::NAN
847 };
848 }
849
850 for i in 0..cols {
851 for j in i..cols {
852 let value = covariance_weighted_pair(matrix, i, j, weights, &means, denom);
853 set_entry(&mut result, cols, i, j, sanitize_covariance(i == j, value));
854 }
855 }
856 }
857 }
858
859 Tensor::new(result, vec![cols, cols]).map_err(cov_internal_error)
860}
861
862fn filter_complete_rows(matrix: &Matrix, weight: CovWeightSpec) -> (Matrix, CovWeightSpec) {
863 if matrix.rows == 0 {
864 return (
865 Matrix {
866 data: Vec::new(),
867 rows: 0,
868 cols: matrix.cols,
869 },
870 weight,
871 );
872 }
873
874 let mut valid_rows = Vec::new();
875 for row in 0..matrix.rows {
876 let mut is_valid = true;
877 for col in 0..matrix.cols {
878 if !matrix.get(row, col).is_finite() {
879 is_valid = false;
880 break;
881 }
882 }
883 if is_valid {
884 valid_rows.push(row);
885 }
886 }
887
888 if valid_rows.len() == matrix.rows {
889 return (matrix.clone(), weight);
891 }
892
893 let mut data = Vec::with_capacity(valid_rows.len() * matrix.cols);
894 for col in 0..matrix.cols {
895 for &row in &valid_rows {
896 data.push(matrix.get(row, col));
897 }
898 }
899
900 let filtered_matrix = Matrix {
901 data,
902 rows: valid_rows.len(),
903 cols: matrix.cols,
904 };
905
906 let filtered_weight = match weight {
907 CovWeightSpec::Scalar(norm) => CovWeightSpec::Scalar(norm),
908 CovWeightSpec::Vector(vec) => {
909 let mut filtered = Vec::with_capacity(valid_rows.len());
910 for &row in &valid_rows {
911 filtered.push(vec[row]);
912 }
913 CovWeightSpec::Vector(filtered)
914 }
915 };
916
917 (filtered_matrix, filtered_weight)
918}
919
920fn covariance_pairwise(matrix: &Matrix, weight: &CovWeightSpec) -> BuiltinResult<Tensor> {
921 let cols = matrix.cols;
922 if cols == 0 {
923 return Tensor::new(Vec::new(), vec![0, 0]).map_err(cov_internal_error);
924 }
925 let mut result = vec![f64::NAN; cols * cols];
926 for i in 0..cols {
927 let variance = covariance_pair(matrix, i, i, weight);
928 set_entry(&mut result, cols, i, i, sanitize_covariance(true, variance));
929 for j in (i + 1)..cols {
930 let value = covariance_pair(matrix, i, j, weight);
931 set_entry(&mut result, cols, i, j, sanitize_covariance(false, value));
932 }
933 }
934 Tensor::new(result, vec![cols, cols]).map_err(cov_internal_error)
935}
936
937fn covariance_unweighted_pair(
938 matrix: &Matrix,
939 lhs: usize,
940 rhs: usize,
941 means: &[f64],
942 denom: f64,
943) -> f64 {
944 if !means[lhs].is_finite() || !means[rhs].is_finite() {
945 return f64::NAN;
946 }
947 let mut accumulator = 0.0;
948 for row in 0..matrix.rows {
949 let x = matrix.get(row, lhs);
950 let y = matrix.get(row, rhs);
951 if !x.is_finite() || !y.is_finite() {
952 return f64::NAN;
953 }
954 accumulator += (x - means[lhs]) * (y - means[rhs]);
955 }
956 accumulator / denom
957}
958
959fn covariance_weighted_pair(
960 matrix: &Matrix,
961 lhs: usize,
962 rhs: usize,
963 weights: &[f64],
964 means: &[f64],
965 denom: f64,
966) -> f64 {
967 if !means[lhs].is_finite() || !means[rhs].is_finite() {
968 return f64::NAN;
969 }
970 let mut accumulator = 0.0;
971 for (row, &weight) in weights.iter().enumerate().take(matrix.rows) {
972 if weight == 0.0 {
973 continue;
974 }
975 let x = matrix.get(row, lhs);
976 let y = matrix.get(row, rhs);
977 if !x.is_finite() || !y.is_finite() {
978 return f64::NAN;
979 }
980 accumulator += weight * (x - means[lhs]) * (y - means[rhs]);
981 }
982 accumulator / denom
983}
984
985fn covariance_pair(matrix: &Matrix, lhs: usize, rhs: usize, weight: &CovWeightSpec) -> f64 {
986 match weight {
987 CovWeightSpec::Scalar(normalization) => {
988 let mut xs = Vec::new();
989 let mut ys = Vec::new();
990 for row in 0..matrix.rows {
991 let x = matrix.get(row, lhs);
992 let y = matrix.get(row, rhs);
993 if x.is_finite() && y.is_finite() {
994 xs.push(x);
995 ys.push(y);
996 }
997 }
998 covariance_unweighted_slice(&xs, &ys, *normalization)
999 }
1000 CovWeightSpec::Vector(weights) => {
1001 let mut xs = Vec::new();
1002 let mut ys = Vec::new();
1003 let mut ws = Vec::new();
1004 for (row, &weight) in weights.iter().enumerate().take(matrix.rows) {
1005 let x = matrix.get(row, lhs);
1006 let y = matrix.get(row, rhs);
1007 if x.is_finite() && y.is_finite() {
1008 xs.push(x);
1009 ys.push(y);
1010 ws.push(weight);
1011 }
1012 }
1013 covariance_weighted_slice(&xs, &ys, &ws)
1014 }
1015 }
1016}
1017
1018fn covariance_unweighted_slice(xs: &[f64], ys: &[f64], normalization: CovNormalization) -> f64 {
1019 if xs.is_empty() || ys.is_empty() {
1020 return f64::NAN;
1021 }
1022 let n = xs.len().min(ys.len());
1023 if n == 0 {
1024 return f64::NAN;
1025 }
1026 let denom = match normalization {
1027 CovNormalization::Unbiased => (n as f64) - 1.0,
1028 CovNormalization::Biased => n as f64,
1029 };
1030 if denom <= 0.0 {
1031 return f64::NAN;
1032 }
1033 let sum_x: f64 = xs.iter().take(n).sum();
1034 let sum_y: f64 = ys.iter().take(n).sum();
1035 let mean_x = sum_x / (n as f64);
1036 let mean_y = sum_y / (n as f64);
1037 let mut accumulator = 0.0;
1038 for idx in 0..n {
1039 accumulator += (xs[idx] - mean_x) * (ys[idx] - mean_y);
1040 }
1041 accumulator / denom
1042}
1043
/// Weighted covariance of two samples, truncated to the shortest of the three
/// slices.
///
/// Uses weighted means and divides by `sum(w) - 1`; returns NaN when the common
/// length is zero, `sum(w) <= 0`, or `sum(w) - 1 <= 0`.
fn covariance_weighted_slice(xs: &[f64], ys: &[f64], weights: &[f64]) -> f64 {
    let n = xs.len().min(ys.len()).min(weights.len());
    if n == 0 {
        return f64::NAN;
    }
    let (xs, ys, weights) = (&xs[..n], &ys[..n], &weights[..n]);

    let sum_w: f64 = weights.iter().sum();
    if sum_w <= 0.0 {
        return f64::NAN;
    }
    let denom = sum_w - 1.0;
    if denom <= 0.0 {
        return f64::NAN;
    }

    let (mut weighted_x, mut weighted_y) = (0.0, 0.0);
    for ((&x, &y), &w) in xs.iter().zip(ys).zip(weights) {
        weighted_x += w * x;
        weighted_y += w * y;
    }
    let mean_x = weighted_x / sum_w;
    let mean_y = weighted_y / sum_w;

    let total = xs
        .iter()
        .zip(ys)
        .zip(weights)
        .fold(0.0, |acc, ((&x, &y), &w)| acc + w * (x - mean_x) * (y - mean_y));
    total / denom
}
1074
/// Clamps tiny negative diagonal entries (strictly between `-1e-12` and `0`) to
/// zero; every other value, including non-finite ones, passes through.
fn sanitize_covariance(is_diag: bool, value: f64) -> f64 {
    // NaN and -inf fail both comparisons, so no separate finiteness check is needed.
    let rounding_noise = value > -1.0e-12 && value < 0.0;
    if is_diag && rounding_noise {
        0.0
    } else {
        value
    }
}
1085
/// Writes `value` at `(row, col)` of a column-major `dim × dim` buffer and, for
/// off-diagonal positions, mirrors it to `(col, row)`.
fn set_entry(buffer: &mut [f64], dim: usize, row: usize, col: usize, value: f64) {
    buffer[col * dim + row] = value;
    if col != row {
        buffer[row * dim + col] = value;
    }
}
1094
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ResolveContext, Tensor, Type};

    /// Asserts that `actual` is a square matrix whose column-major data matches
    /// `expected` element-wise within `tol`. A `NaN` in `expected` requires a
    /// `NaN` at the same linear index in `actual`.
    fn assert_tensor_close(actual: &Tensor, expected: &[f64], tol: f64) {
        // Round rather than truncate so floating-point error in `sqrt` cannot
        // shave one off the dimension.
        let dim = (expected.len() as f64).sqrt().round() as usize;
        assert_eq!(
            dim * dim,
            expected.len(),
            "expected values do not describe a square matrix"
        );
        assert_eq!(actual.shape, vec![dim, dim], "unexpected tensor shape");
        // `zip` below stops at the shorter side, so check the element count
        // explicitly instead of silently ignoring extra or missing values.
        assert_eq!(
            actual.data.len(),
            expected.len(),
            "unexpected number of tensor elements"
        );
        for (idx, (&got, &want)) in actual.data.iter().zip(expected.iter()).enumerate() {
            if want.is_nan() {
                assert!(
                    got.is_nan(),
                    "expected NaN at linear index {idx}, found {got}"
                );
            } else {
                assert!(
                    (got - want).abs() <= tol,
                    "mismatch at linear index {idx}: got {got}, expected {want}"
                );
            }
        }
    }

    #[test]
    fn cov_type_preserves_column_count() {
        let out = cov_type(
            &[Type::Tensor {
                shape: Some(vec![Some(5), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(3), Some(3)])
            }
        );
    }

    #[test]
    fn cov_type_vector_returns_scalar() {
        let out = cov_type(
            &[Type::Tensor {
                shape: Some(vec![Some(1), Some(4)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(out, Type::Num);
    }

    #[test]
    fn cov_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = COV_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"C = cov(X)"));
        assert!(labels.contains(&"C = cov(X, normalization)"));
        assert!(labels.contains(&"C = cov(X, Y, w, opt)"));
    }

    #[cfg(feature = "wgpu")]
    fn cov_builtin_sync(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::cov_builtin(value, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_matrix_basic() {
        let tensor = Tensor::new(
            vec![
                4.0, 4.2, 3.9, 4.3, 4.1, // column 1
                2.0, 2.1, 2.0, 2.1, 2.2, // column 2
                0.60, 0.59, 0.58, 0.62, 0.63, // column 3
            ],
            vec![5, 3],
        )
        .unwrap();
        let result = block_on(cov_builtin(Value::Tensor(tensor), Vec::new())).expect("cov");
        let tensor = match result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };
        let expected = [
            0.0250, 0.0075, 0.00175, // column 1
            0.0075, 0.0070, 0.00135, // column 2
            0.00175, 0.00135, 0.00043, // column 3
        ];
        assert_tensor_close(&tensor, &expected, 1.0e-6);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_two_vectors() {
        let x = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
        let y = Tensor::new(vec![10.0, 11.0, 9.0, 12.0], vec![4, 1]).unwrap();
        let result = block_on(cov_builtin(Value::Tensor(x), vec![Value::Tensor(y)])).expect("cov");
        let tensor = match result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };
        let expected = [
            1.6666666666666667,
            0.6666666666666666,
            0.6666666666666666,
            1.6666666666666667,
        ];
        assert_tensor_close(&tensor, &expected, 1.0e-6);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_weighted_vector() {
        let tensor = Tensor::new(
            vec![
                4.0, 4.2, 3.9, 4.3, 4.1, // column 1
                2.0, 2.1, 2.0, 2.1, 2.2, // column 2
            ],
            vec![5, 2],
        )
        .unwrap();
        let weights = Tensor::new(vec![1.0, 1.0, 1.0, 2.0, 2.0], vec![5, 1]).unwrap();
        let result = block_on(cov_builtin(
            Value::Tensor(tensor),
            vec![Value::Tensor(weights)],
        ))
        .expect("cov");
        let tensor = match result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };
        // Weighted sums of cross-deviations divided by sum(w) - 1 = 6.
        let expected = [
            0.022380952380952376,
            0.004999999999999994,
            0.004999999999999994,
            0.006666666666666678,
        ];
        assert_tensor_close(&tensor, &expected, 1.0e-6);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_omitrows() {
        let tensor = Tensor::new(
            vec![
                1.0,
                3.0,
                f64::NAN,
                8.0, // column 1
                f64::NAN,
                4.0,
                6.0,
                9.0, // column 2
                2.0,
                5.0,
                7.0,
                10.0, // column 3
            ],
            vec![4, 3],
        )
        .unwrap();
        let result = block_on(cov_builtin(
            Value::Tensor(tensor),
            vec![Value::from("omitrows")],
        ))
        .expect("cov");
        let tensor = match result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };
        // Only rows [3, 4, 5] and [8, 9, 10] survive; every column moves by 5,
        // so every entry is (2.5^2 + 2.5^2) / 1 = 12.5.
        let expected = [12.5, 12.5, 12.5, 12.5, 12.5, 12.5, 12.5, 12.5, 12.5];
        assert_tensor_close(&tensor, &expected, 1.0e-6);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_partialrows() {
        let tensor = Tensor::new(
            vec![
                1.0,
                4.0,
                7.0, // column 1
                2.0,
                f64::NAN,
                8.0, // column 2
                f64::NAN,
                6.0,
                9.0, // column 3
            ],
            vec![3, 3],
        )
        .unwrap();
        let result = block_on(cov_builtin(
            Value::Tensor(tensor),
            vec![Value::from("partialrows")],
        ))
        .expect("cov");
        let tensor = match result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };
        // Columns 2 and 3 share only one complete row, which leaves no degrees
        // of freedom under N-1 normalization, hence NaN for that pair.
        let expected = [
            9.0,
            18.0,
            4.5, // column 1
            18.0,
            18.0,
            f64::NAN, // column 2
            4.5,
            f64::NAN,
            4.5, // column 3
        ];
        assert_tensor_close(&tensor, &expected, 1.0e-6);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_mismatched_rows_errors() {
        let left = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
        let right = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = block_on(cov_builtin(Value::Tensor(left), vec![Value::Tensor(right)]))
            .expect_err("expected mismatch error");
        assert_eq!(err.identifier(), COV_ERROR_ROWS_MISMATCH.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_invalid_flag_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = block_on(cov_builtin(Value::Tensor(tensor), vec![Value::Num(2.5)]))
            .expect_err("expected invalid flag error");
        assert_eq!(err.identifier(), COV_ERROR_NORMALIZATION_INVALID.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_weight_vector_length_mismatch_errors() {
        let x = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
        let y = Tensor::new(vec![10.0, 11.0, 12.0], vec![3, 1]).unwrap();
        let w = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = block_on(cov_builtin(
            Value::Tensor(x),
            vec![Value::Tensor(y), Value::Tensor(w)],
        ))
        .expect_err("expected weight length mismatch");
        assert_eq!(
            err.identifier(),
            COV_ERROR_WEIGHT_VECTOR_LENGTH_MISMATCH.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_unknown_rows_option_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = block_on(cov_builtin(
            Value::Tensor(tensor),
            vec![Value::from("rows"), Value::from("bogus")],
        ))
        .expect_err("expected unknown rows option error");
        assert_eq!(err.identifier(), COV_ERROR_ROWS_OPTION_UNKNOWN.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_duplicate_normalization_flag_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = block_on(cov_builtin(
            Value::Tensor(tensor),
            vec![Value::Num(0.0), Value::Num(1.0)],
        ))
        .expect_err("expected duplicate normalization flag error");
        assert_eq!(
            err.identifier(),
            COV_ERROR_NORMALIZATION_DUPLICATE.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_too_many_array_arguments_errors() {
        let x = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let y = Tensor::new(vec![4.0, 5.0, 6.0], vec![3, 1]).unwrap();
        let w = Tensor::new(vec![1.0, 1.0, 1.0], vec![3, 1]).unwrap();
        let z = Tensor::new(vec![7.0, 8.0, 9.0], vec![3, 1]).unwrap();
        let err = block_on(cov_builtin(
            Value::Tensor(x),
            vec![Value::Tensor(y), Value::Tensor(w), Value::Tensor(z)],
        ))
        .expect_err("expected too many array arguments error");
        assert_eq!(
            err.identifier(),
            COV_ERROR_TOO_MANY_ARRAY_ARGUMENTS.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cov_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(
                vec![
                    4.0, 4.2, 3.9, 4.3, 4.1, // column 1
                    2.0, 2.1, 2.0, 2.1, 2.2, // column 2
                ],
                vec![5, 2],
            )
            .unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = block_on(cov_builtin(Value::GpuTensor(handle), Vec::new())).expect("cov");
            let gathered = test_support::gather(result).expect("gather");
            let expected = [0.0250, 0.0075, 0.0075, 0.0070];
            assert_tensor_close(&gathered, &expected, 1.0e-6);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn cov_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(
            vec![
                4.0, 4.2, 3.9, 4.3, 4.1, // column 1
                2.0, 2.1, 2.0, 2.1, 2.2, // column 2
            ],
            vec![5, 2],
        )
        .unwrap();

        let cpu_result =
            block_on(cov_builtin(Value::Tensor(tensor.clone()), Vec::new())).expect("cov");
        let cpu_tensor = match cpu_result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };

        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");

        let gpu_value = cov_builtin_sync(Value::GpuTensor(handle), Vec::new()).expect("cov");
        let gathered = test_support::gather(gpu_value).expect("gather");

        assert_tensor_close(&gathered, &cpu_tensor.data, 1.0e-6);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn covariance_unweighted_slice_normalizations() {
        let xs = [1.0, 2.0, 3.0, 4.0];
        let ys = [10.0, 11.0, 9.0, 12.0];
        let unbiased = covariance_unweighted_slice(&xs, &ys, CovNormalization::Unbiased);
        assert!((unbiased - 2.0 / 3.0).abs() <= 1.0e-12);
        let biased = covariance_unweighted_slice(&xs, &ys, CovNormalization::Biased);
        assert!((biased - 0.5).abs() <= 1.0e-12);
        // One observation leaves no degrees of freedom under N-1 normalization.
        assert!(covariance_unweighted_slice(&[5.0], &[7.0], CovNormalization::Unbiased).is_nan());
        assert_eq!(
            covariance_unweighted_slice(&[5.0], &[7.0], CovNormalization::Biased),
            0.0
        );
        assert!(covariance_unweighted_slice(&[], &[], CovNormalization::Biased).is_nan());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn covariance_weighted_slice_divisor_edge_cases() {
        let xs = [4.0, 4.2, 3.9, 4.3, 4.1];
        let weights = [1.0, 1.0, 1.0, 2.0, 2.0];
        let var = covariance_weighted_slice(&xs, &xs, &weights);
        assert!((var - 0.022380952380952376).abs() <= 1.0e-12);
        // The divisor is sum(w) - 1, so total weight of at most one yields NaN.
        assert!(covariance_weighted_slice(&xs[..2], &xs[..2], &[0.5, 0.5]).is_nan());
        assert!(covariance_weighted_slice(&xs, &xs, &[0.0; 5]).is_nan());
        assert!(covariance_weighted_slice(&[], &[], &[]).is_nan());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sanitize_covariance_clamps_only_tiny_diagonal_negatives() {
        assert_eq!(sanitize_covariance(true, -1.0e-13), 0.0);
        assert_eq!(sanitize_covariance(false, -1.0e-13), -1.0e-13);
        assert_eq!(sanitize_covariance(true, -1.0e-6), -1.0e-6);
        assert_eq!(sanitize_covariance(true, 0.25), 0.25);
        assert!(sanitize_covariance(true, f64::NAN).is_nan());
        assert_eq!(
            sanitize_covariance(true, f64::NEG_INFINITY),
            f64::NEG_INFINITY
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn set_entry_mirrors_off_diagonal_entries() {
        let mut buffer = vec![0.0; 9];
        set_entry(&mut buffer, 3, 0, 2, 5.0);
        assert_eq!(buffer[6], 5.0);
        assert_eq!(buffer[2], 5.0);
        set_entry(&mut buffer, 3, 1, 1, 7.0);
        assert_eq!(buffer[4], 7.0);
        assert_eq!(buffer.iter().filter(|v| **v != 0.0).count(), 3);
    }
}