1use log::debug;
4use num_complex::Complex64;
5use runmat_accelerate_api::{HostTensorView, ProviderPolyvalMu, ProviderPolyvalOptions};
6use runmat_builtins::{
7 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9 ComplexTensor, LogicalArray, Tensor, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16};
17use crate::builtins::common::{gpu_helpers, tensor};
18use crate::builtins::math::poly::type_resolvers::polyval_type;
19use crate::{build_runtime_error, dispatcher::download_handle_async, BuiltinResult, RuntimeError};
20
/// Tolerance used throughout this module: imaginary parts whose magnitude is at
/// or below this value are treated as zero, and it is also the threshold at or
/// below which `mu(2)` and the diagonal entries of `S.R` are rejected as zero.
const EPS: f64 = 1.0e-12;
/// Builtin name attached to every runtime error raised by this module.
const BUILTIN_NAME: &str = "polyval";
23
24const POLYVAL_OUTPUT_Y: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
25 name: "y",
26 ty: BuiltinParamType::Any,
27 arity: BuiltinParamArity::Required,
28 default: None,
29 description: "Evaluated polynomial values at x.",
30}];
31
32const POLYVAL_OUTPUT_Y_DELTA: [BuiltinParamDescriptor; 2] = [
33 BuiltinParamDescriptor {
34 name: "y",
35 ty: BuiltinParamType::Any,
36 arity: BuiltinParamArity::Required,
37 default: None,
38 description: "Evaluated polynomial values at x.",
39 },
40 BuiltinParamDescriptor {
41 name: "delta",
42 ty: BuiltinParamType::Any,
43 arity: BuiltinParamArity::Required,
44 default: None,
45 description: "Prediction interval values when S is supplied.",
46 },
47];
48
49const POLYVAL_INPUTS: [BuiltinParamDescriptor; 2] = [
50 BuiltinParamDescriptor {
51 name: "p",
52 ty: BuiltinParamType::Any,
53 arity: BuiltinParamArity::Required,
54 default: None,
55 description: "Polynomial coefficient vector.",
56 },
57 BuiltinParamDescriptor {
58 name: "x",
59 ty: BuiltinParamType::Any,
60 arity: BuiltinParamArity::Required,
61 default: None,
62 description: "Evaluation points.",
63 },
64];
65
66const POLYVAL_INPUTS_WITH_S: [BuiltinParamDescriptor; 3] = [
67 BuiltinParamDescriptor {
68 name: "p",
69 ty: BuiltinParamType::Any,
70 arity: BuiltinParamArity::Required,
71 default: None,
72 description: "Polynomial coefficient vector.",
73 },
74 BuiltinParamDescriptor {
75 name: "x",
76 ty: BuiltinParamType::Any,
77 arity: BuiltinParamArity::Required,
78 default: None,
79 description: "Evaluation points.",
80 },
81 BuiltinParamDescriptor {
82 name: "S",
83 ty: BuiltinParamType::Any,
84 arity: BuiltinParamArity::Optional,
85 default: None,
86 description: "Optional polyfit statistics structure.",
87 },
88];
89
90const POLYVAL_INPUTS_WITH_S_MU: [BuiltinParamDescriptor; 4] = [
91 BuiltinParamDescriptor {
92 name: "p",
93 ty: BuiltinParamType::Any,
94 arity: BuiltinParamArity::Required,
95 default: None,
96 description: "Polynomial coefficient vector.",
97 },
98 BuiltinParamDescriptor {
99 name: "x",
100 ty: BuiltinParamType::Any,
101 arity: BuiltinParamArity::Required,
102 default: None,
103 description: "Evaluation points.",
104 },
105 BuiltinParamDescriptor {
106 name: "S",
107 ty: BuiltinParamType::Any,
108 arity: BuiltinParamArity::Optional,
109 default: None,
110 description: "Optional polyfit statistics structure (or []).",
111 },
112 BuiltinParamDescriptor {
113 name: "mu",
114 ty: BuiltinParamType::Any,
115 arity: BuiltinParamArity::Optional,
116 default: None,
117 description: "Optional centering/scaling vector [mean, std].",
118 },
119];
120
121const POLYVAL_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
122 BuiltinSignatureDescriptor {
123 label: "y = polyval(p, x)",
124 inputs: &POLYVAL_INPUTS,
125 outputs: &POLYVAL_OUTPUT_Y,
126 },
127 BuiltinSignatureDescriptor {
128 label: "y = polyval(p, x, S)",
129 inputs: &POLYVAL_INPUTS_WITH_S,
130 outputs: &POLYVAL_OUTPUT_Y,
131 },
132 BuiltinSignatureDescriptor {
133 label: "y = polyval(p, x, S, mu)",
134 inputs: &POLYVAL_INPUTS_WITH_S_MU,
135 outputs: &POLYVAL_OUTPUT_Y,
136 },
137 BuiltinSignatureDescriptor {
138 label: "[y, delta] = polyval(p, x)",
139 inputs: &POLYVAL_INPUTS,
140 outputs: &POLYVAL_OUTPUT_Y_DELTA,
141 },
142 BuiltinSignatureDescriptor {
143 label: "[y, delta] = polyval(p, x, S)",
144 inputs: &POLYVAL_INPUTS_WITH_S,
145 outputs: &POLYVAL_OUTPUT_Y_DELTA,
146 },
147 BuiltinSignatureDescriptor {
148 label: "[y, delta] = polyval(p, x, S, mu)",
149 inputs: &POLYVAL_INPUTS_WITH_S_MU,
150 outputs: &POLYVAL_OUTPUT_Y_DELTA,
151 },
152];
153
/// Error category for malformed or unsupported option arguments (`S`, `mu`,
/// output arity).
const POLYVAL_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.POLYVAL.INVALID_ARGUMENT",
    identifier: Some("RunMat:polyval:InvalidArgument"),
    when: "Option arguments (S/mu/output arity) are malformed or unsupported.",
    message: "polyval: invalid argument",
};

/// Error category for coefficients or evaluation points that cannot be read as
/// numeric input.
const POLYVAL_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.POLYVAL.INVALID_INPUT",
    identifier: Some("RunMat:polyval:InvalidInput"),
    when: "Polynomial coefficients or evaluation points cannot be interpreted as numeric inputs.",
    message: "polyval: invalid input",
};

/// Error category for runtime failures while assembling outputs.
const POLYVAL_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.POLYVAL.INTERNAL",
    identifier: Some("RunMat:polyval:Internal"),
    when: "Runtime fails while building output tensors, deltas, or provider fallbacks.",
    message: "polyval: internal runtime failure",
};

/// All error categories `polyval` may report, published via the descriptor.
const POLYVAL_ERRORS: [BuiltinErrorDescriptor; 3] = [
    POLYVAL_ERROR_INVALID_ARGUMENT,
    POLYVAL_ERROR_INVALID_INPUT,
    POLYVAL_ERROR_INTERNAL,
];

/// Public descriptor for `polyval` (referenced by the `descriptor(...)` argument
/// of the `runtime_builtin` attribute): call signatures, output-count mode, and
/// error categories.
pub const POLYVAL_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &POLYVAL_SIGNATURES,
    // The number of returned values follows the caller's requested output count.
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &POLYVAL_ERRORS,
};
187
// GPU metadata for `polyval`: device work goes through the custom `polyval`
// provider hook (the path taken by `try_gpu_polyval`) for F32/F64 data.
// NOTE(review): the exact meaning of `ResidencyPolicy::NewHandle` is defined by
// the spec module; presumably each call yields a fresh output handle.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::polyval")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "polyval",
    op_kind: GpuOpKind::Custom("polyval"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("polyval")],
    constant_strategy: ConstantStrategy::UniformBuffer,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Uses provider-level Horner kernels for real coefficients/inputs; falls back to host evaluation (with upload) for complex or prediction-interval paths.",
};
204
205fn polyval_error(message: impl Into<String>) -> RuntimeError {
206 polyval_error_with(message, &POLYVAL_ERROR_INVALID_INPUT)
207}
208
209fn polyval_argument_error(message: impl Into<String>) -> RuntimeError {
210 polyval_error_with(message, &POLYVAL_ERROR_INVALID_ARGUMENT)
211}
212
213fn polyval_error_with(
214 message: impl Into<String>,
215 error: &'static BuiltinErrorDescriptor,
216) -> RuntimeError {
217 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
218 if let Some(identifier) = error.identifier {
219 builder = builder.with_identifier(identifier);
220 }
221 builder.build()
222}
223
// Fusion metadata for `polyval`: no elementwise or reduction template is
// provided (`elementwise`/`reduction` are `None`), so the planner only sees the
// shape/constant requirements and the NaN flag declared here.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::polyval")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "polyval",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::UniformBuffer,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "Acts as a fusion sink; real-valued workloads stay on device, while complex/delta paths gather to the host.",
};
234
// Entry point for `polyval(p, x, ...)`.
//
// When the dispatcher reports a requested output count, the delta output is
// computed only for two or more requested outputs, and the results are returned
// as an output list padded by `output_list_with_padding`. A requested count of
// zero still runs the evaluation (so argument errors surface) but returns an
// empty output list. Without an output-count context only `y` is returned.
#[runtime_builtin(
    name = "polyval",
    category = "math/poly",
    summary = "Evaluate polynomials at specified points.",
    keywords = "polyval,polynomial,polyfit,delta,gpu",
    accel = "sink",
    sink = true,
    type_resolver(polyval_type),
    descriptor(crate::builtins::math::poly::polyval::POLYVAL_DESCRIPTOR),
    builtin_path = "crate::builtins::math::poly::polyval"
)]
async fn polyval_builtin(p: Value, x: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
    if let Some(out_count) = crate::output_count::current_output_count() {
        // `delta` is only computed when the caller actually asks for it.
        let eval = evaluate(p, x, &rest, out_count >= 2).await?;
        if out_count == 0 {
            return Ok(Value::OutputList(Vec::new()));
        }
        let mut outputs = vec![eval.value()];
        if out_count >= 2 {
            outputs.push(eval.delta()?);
        }
        return Ok(crate::output_count::output_list_with_padding(
            out_count, outputs,
        ));
    }
    // No output-count context: single-output call.
    let eval = evaluate(p, x, &rest, false).await?;
    Ok(eval.value())
}
263
264pub async fn evaluate(
266 coefficients: Value,
267 points: Value,
268 rest: &[Value],
269 want_delta: bool,
270) -> BuiltinResult<PolyvalEval> {
271 let options = parse_option_values(rest).await?;
272
273 let coeff_clone = coefficients.clone();
274 let points_clone = points.clone();
275
276 let coeff_was_gpu = matches!(coefficients, Value::GpuTensor(_));
277 let (coeffs, coeff_real) = convert_coefficients(coeff_clone).await?;
278
279 let (mut inputs, prefer_gpu_points) = convert_points(points_clone).await?;
280 let prefer_gpu_output = prefer_gpu_points || coeff_was_gpu;
281
282 let mu = match options.mu.clone() {
283 Some(mu_value) => Some(parse_mu(mu_value).await?),
284 None => None,
285 };
286
287 if prefer_gpu_output && !want_delta && options.s.is_none() {
288 if let Some(value) =
289 try_gpu_polyval(&coeffs, coeff_real, &inputs, mu, prefer_gpu_output).await?
290 {
291 return Ok(PolyvalEval::new(value, None));
292 }
293 }
294
295 if let Some(mu_val) = mu {
296 apply_mu(&mut inputs.data, mu_val)?;
297 }
298
299 let stats = if let Some(s_value) = options.s {
300 parse_stats(s_value, coeffs.len()).await?
301 } else {
302 None
303 };
304
305 if want_delta && stats.is_none() {
306 return Err(polyval_argument_error(
307 "polyval: S input (structure returned by polyfit) is required for delta output",
308 ));
309 }
310
311 if inputs.data.is_empty() {
312 let y = zeros_like(&inputs.shape, prefer_gpu_output)?;
313 let delta = if want_delta {
314 Some(zeros_like(&inputs.shape, prefer_gpu_output)?)
315 } else {
316 None
317 };
318 return Ok(PolyvalEval::new(y, delta));
319 }
320
321 if coeffs.is_empty() {
322 let zeros = zeros_like(&inputs.shape, prefer_gpu_output)?;
323 let delta = if want_delta {
324 Some(zeros_like(&inputs.shape, prefer_gpu_output)?)
325 } else {
326 None
327 };
328 return Ok(PolyvalEval::new(zeros, delta));
329 }
330
331 let output_real = coeff_real && inputs.all_real;
332 let values = evaluate_polynomial(&coeffs, &inputs.data);
333 let result_value = finalize_values(
334 &values,
335 &inputs.shape,
336 prefer_gpu_output,
337 output_real && values_are_real(&values),
338 )?;
339
340 let delta_value = if want_delta {
341 let stats = stats.expect("delta requires stats");
342 let delta = compute_prediction_interval(&coeffs, &inputs.data, &stats)?;
343 let prefer = prefer_gpu_output && stats.is_real;
344 Some(finalize_delta(delta, &inputs.shape, prefer)?)
345 } else {
346 None
347 };
348
349 Ok(PolyvalEval::new(result_value, delta_value))
350}
351
/// Attempts to evaluate the polynomial through the acceleration provider's
/// `polyval` hook.
///
/// Returns `Ok(None)`, telling the caller to fall back to host evaluation, when
/// the coefficients or points are not all real, when either is empty, when no
/// provider is registered, or when an upload, kernel, or download step fails.
/// Each failure is logged at debug level. Only real parts are uploaded, and
/// `mu` is forwarded to the provider rather than applied to `inputs` here.
///
/// When `prefer_gpu_output` is true the result handle is returned as-is;
/// otherwise the result is downloaded, its handle freed, and it is wrapped as a
/// host value.
///
/// # Errors
/// Fails only if the downloaded data cannot be assembled into a `Tensor`.
async fn try_gpu_polyval(
    coeffs: &[Complex64],
    coeff_real: bool,
    inputs: &NumericArray,
    mu: Option<Mu>,
    prefer_gpu_output: bool,
) -> BuiltinResult<Option<Value>> {
    // The device path handles real data only.
    if !coeff_real || !inputs.all_real {
        return Ok(None);
    }
    // Empty work is left to the host path.
    if coeffs.is_empty() || inputs.data.is_empty() {
        return Ok(None);
    }
    let Some(provider) = runmat_accelerate_api::provider() else {
        return Ok(None);
    };

    // Coefficients are uploaded as a 1 x n row of their real parts.
    let coeff_data: Vec<f64> = coeffs.iter().map(|c| c.re).collect();
    let coeff_shape = vec![1usize, coeffs.len()];
    let coeff_view = HostTensorView {
        data: &coeff_data,
        shape: &coeff_shape,
    };
    let coeff_handle = match provider.upload(&coeff_view) {
        Ok(handle) => handle,
        Err(err) => {
            debug!("polyval: GPU upload of coefficients failed, falling back: {err}");
            return Ok(None);
        }
    };

    // Points keep the shape of X so the result comes back shaped like X.
    let input_data: Vec<f64> = inputs.data.iter().map(|c| c.re).collect();
    let input_shape = inputs.shape.clone();
    let input_view = HostTensorView {
        data: &input_data,
        shape: &input_shape,
    };
    let input_handle = match provider.upload(&input_view) {
        Ok(handle) => handle,
        Err(err) => {
            debug!("polyval: GPU upload of evaluation points failed, falling back: {err}");
            // Release the coefficient buffer allocated above before bailing out.
            let _ = provider.free(&coeff_handle);
            return Ok(None);
        }
    };

    let options = ProviderPolyvalOptions {
        mu: mu.map(|m| ProviderPolyvalMu {
            mean: m.mean,
            scale: m.scale,
        }),
    };

    let result_handle = match provider.polyval(&coeff_handle, &input_handle, &options) {
        Ok(handle) => handle,
        Err(err) => {
            debug!("polyval: GPU kernel execution failed, falling back: {err}");
            let _ = provider.free(&coeff_handle);
            let _ = provider.free(&input_handle);
            return Ok(None);
        }
    };

    // The temporary input buffers are no longer needed once the kernel has run.
    let _ = provider.free(&coeff_handle);
    let _ = provider.free(&input_handle);

    if prefer_gpu_output {
        return Ok(Some(Value::GpuTensor(result_handle)));
    }

    let host = match download_handle_async(provider, &result_handle).await {
        Ok(host) => host,
        Err(err) => {
            debug!("polyval: GPU download failed, falling back: {err}");
            let _ = provider.free(&result_handle);
            return Ok(None);
        }
    };
    let _ = provider.free(&result_handle);

    let tensor =
        Tensor::new(host.data, host.shape).map_err(|e| polyval_error(format!("polyval: {e}")))?;
    Ok(Some(tensor::tensor_into_value(tensor)))
}
436
437#[derive(Debug)]
439pub struct PolyvalEval {
440 value: Value,
441 delta: Option<Value>,
442}
443
444impl PolyvalEval {
445 fn new(value: Value, delta: Option<Value>) -> Self {
446 Self { value, delta }
447 }
448
449 pub fn value(&self) -> Value {
451 self.value.clone()
452 }
453
454 pub fn delta(&self) -> BuiltinResult<Value> {
456 self.delta
457 .clone()
458 .ok_or_else(|| polyval_argument_error("polyval: delta output not computed"))
459 }
460
461 pub fn into_value(self) -> Value {
463 self.value
464 }
465
466 pub fn into_pair(self) -> BuiltinResult<(Value, Value)> {
468 match self.delta {
469 Some(delta) => Ok((self.value, delta)),
470 None => Err(polyval_argument_error("polyval: delta output not computed")),
471 }
472 }
473}
474
475#[derive(Clone, Copy)]
476struct Mu {
477 mean: f64,
478 scale: f64,
479}
480
481impl Mu {
482 fn new(mean: f64, scale: f64) -> BuiltinResult<Self> {
483 if !mean.is_finite() || !scale.is_finite() {
484 return Err(polyval_error("polyval: mu values must be finite"));
485 }
486 if scale.abs() <= EPS {
487 return Err(polyval_error("polyval: mu(2) must be non-zero"));
488 }
489 Ok(Self { mean, scale })
490 }
491}
492
/// Evaluation points converted to complex form, together with the shape of the
/// original `X` argument (used to shape the outputs).
#[derive(Clone)]
struct NumericArray {
    /// Point values, in the element order of the source value.
    data: Vec<Complex64>,
    /// Shape of the original `X` argument (`[1, 1]` for scalars).
    shape: Vec<usize>,
    /// True when every point's imaginary part is within `EPS` of zero.
    all_real: bool,
}
499
500#[derive(Clone)]
501struct PolyfitStats {
502 r: Matrix,
503 df: f64,
504 normr: f64,
505 is_real: bool,
506}
507
508impl PolyfitStats {
509 fn is_effective(&self) -> bool {
510 self.r.len() > 0 && self.df > 0.0 && self.normr.is_finite()
511 }
512}
513
514#[derive(Clone)]
515struct Matrix {
516 rows: usize,
517 cols: usize,
518 data: Vec<Complex64>,
519}
520
521impl Matrix {
522 fn get(&self, row: usize, col: usize) -> Complex64 {
523 self.data[row + col * self.rows]
524 }
525
526 fn len(&self) -> usize {
527 self.rows * self.cols
528 }
529}
530
/// Optional trailing arguments that follow `p` and `x`.
struct ParsedOptions {
    /// The `S` argument; `None` when omitted or passed as an empty value.
    s: Option<Value>,
    /// The raw `mu` argument (parsed later by `parse_mu`); `None` when omitted.
    mu: Option<Value>,
}
535
536async fn parse_option_values(rest: &[Value]) -> BuiltinResult<ParsedOptions> {
537 match rest.len() {
538 0 => Ok(ParsedOptions { s: None, mu: None }),
539 1 => Ok(ParsedOptions {
540 s: if is_empty_value(&rest[0]).await? {
541 None
542 } else {
543 Some(rest[0].clone())
544 },
545 mu: None,
546 }),
547 2 => Ok(ParsedOptions {
548 s: if is_empty_value(&rest[0]).await? {
549 None
550 } else {
551 Some(rest[0].clone())
552 },
553 mu: Some(rest[1].clone()),
554 }),
555 _ => Err(polyval_argument_error("polyval: too many input arguments")),
556 }
557}
558
559#[async_recursion::async_recursion(?Send)]
560async fn convert_coefficients(value: Value) -> BuiltinResult<(Vec<Complex64>, bool)> {
561 match value {
562 Value::GpuTensor(handle) => {
563 let gathered =
564 gpu_helpers::gather_value_async(&Value::GpuTensor(handle.clone())).await?;
565 convert_coefficients(gathered).await
566 }
567 Value::Tensor(mut tensor) => {
568 ensure_vector_shape("polyval", &tensor.shape)?;
569 let data = tensor
570 .data
571 .drain(..)
572 .map(|re| Complex64::new(re, 0.0))
573 .collect();
574 Ok((data, true))
575 }
576 Value::ComplexTensor(mut tensor) => {
577 ensure_vector_shape("polyval", &tensor.shape)?;
578 let all_real = tensor.data.iter().all(|&(_, im)| im.abs() <= EPS);
579 let data = tensor
580 .data
581 .drain(..)
582 .map(|(re, im)| Complex64::new(re, im))
583 .collect();
584 Ok((data, all_real))
585 }
586 Value::LogicalArray(mut array) => {
587 ensure_vector_data_shape("polyval", &array.shape)?;
588 let data = array
589 .data
590 .drain(..)
591 .map(|bit| Complex64::new(if bit != 0 { 1.0 } else { 0.0 }, 0.0))
592 .collect();
593 Ok((data, true))
594 }
595 Value::Num(n) => Ok((vec![Complex64::new(n, 0.0)], true)),
596 Value::Int(i) => Ok((vec![Complex64::new(i.to_f64(), 0.0)], true)),
597 Value::Bool(flag) => Ok((
598 vec![Complex64::new(if flag { 1.0 } else { 0.0 }, 0.0)],
599 true,
600 )),
601 Value::Complex(re, im) => Ok((vec![Complex64::new(re, im)], im.abs() <= EPS)),
602 other => Err(polyval_error(format!(
603 "polyval: coefficients must be numeric, got {other:?}"
604 ))),
605 }
606}
607
608async fn convert_points(value: Value) -> BuiltinResult<(NumericArray, bool)> {
609 match value {
610 Value::GpuTensor(handle) => {
611 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
612 let array = NumericArray {
613 data: tensor
614 .data
615 .iter()
616 .map(|&re| Complex64::new(re, 0.0))
617 .collect(),
618 shape: tensor.shape.clone(),
619 all_real: true,
620 };
621 Ok((array, true))
622 }
623 Value::Tensor(tensor) => Ok((
624 NumericArray {
625 data: tensor
626 .data
627 .iter()
628 .map(|&re| Complex64::new(re, 0.0))
629 .collect(),
630 shape: tensor.shape.clone(),
631 all_real: true,
632 },
633 false,
634 )),
635 Value::ComplexTensor(tensor) => Ok((
636 NumericArray {
637 data: tensor
638 .data
639 .iter()
640 .map(|&(re, im)| Complex64::new(re, im))
641 .collect(),
642 shape: tensor.shape.clone(),
643 all_real: tensor.data.iter().all(|&(_, im)| im.abs() <= EPS),
644 },
645 false,
646 )),
647 Value::LogicalArray(array) => Ok((
648 NumericArray {
649 data: array
650 .data
651 .iter()
652 .map(|&bit| Complex64::new(if bit != 0 { 1.0 } else { 0.0 }, 0.0))
653 .collect(),
654 shape: array.shape.clone(),
655 all_real: true,
656 },
657 false,
658 )),
659 Value::Num(n) => Ok((
660 NumericArray {
661 data: vec![Complex64::new(n, 0.0)],
662 shape: vec![1, 1],
663 all_real: true,
664 },
665 false,
666 )),
667 Value::Int(i) => Ok((
668 NumericArray {
669 data: vec![Complex64::new(i.to_f64(), 0.0)],
670 shape: vec![1, 1],
671 all_real: true,
672 },
673 false,
674 )),
675 Value::Bool(flag) => Ok((
676 NumericArray {
677 data: vec![Complex64::new(if flag { 1.0 } else { 0.0 }, 0.0)],
678 shape: vec![1, 1],
679 all_real: true,
680 },
681 false,
682 )),
683 Value::Complex(re, im) => Ok((
684 NumericArray {
685 data: vec![Complex64::new(re, im)],
686 shape: vec![1, 1],
687 all_real: im.abs() <= EPS,
688 },
689 false,
690 )),
691 other => Err(polyval_error(format!(
692 "polyval: X must be numeric, got {other:?}"
693 ))),
694 }
695}
696
697#[async_recursion::async_recursion(?Send)]
698async fn parse_mu(value: Value) -> BuiltinResult<Mu> {
699 match value {
700 Value::GpuTensor(handle) => {
701 let gathered = gpu_helpers::gather_tensor_async(&handle).await?;
702 parse_mu(Value::Tensor(gathered)).await
703 }
704 Value::Tensor(tensor) => {
705 if tensor.data.len() < 2 {
706 return Err(polyval_error(
707 "polyval: mu must contain at least two elements",
708 ));
709 }
710 Mu::new(tensor.data[0], tensor.data[1])
711 }
712 Value::LogicalArray(array) => {
713 if array.data.len() < 2 {
714 return Err(polyval_error(
715 "polyval: mu must contain at least two elements",
716 ));
717 }
718 let mean = if array.data[0] != 0 { 1.0 } else { 0.0 };
719 let scale = if array.data[1] != 0 { 1.0 } else { 0.0 };
720 Mu::new(mean, scale)
721 }
722 Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => Err(
723 polyval_error("polyval: mu must be a numeric vector with at least two values"),
724 ),
725 Value::ComplexTensor(tensor) => {
726 if tensor.data.len() < 2 {
727 return Err(polyval_error(
728 "polyval: mu must contain at least two elements",
729 ));
730 }
731 let (mean_re, mean_im) = tensor.data[0];
732 let (scale_re, scale_im) = tensor.data[1];
733 if mean_im.abs() > EPS || scale_im.abs() > EPS {
734 return Err(polyval_error("polyval: mu values must be real"));
735 }
736 Mu::new(mean_re, scale_re)
737 }
738 _ => Err(polyval_error(
739 "polyval: mu must be a numeric vector with at least two values",
740 )),
741 }
742}
743
744#[async_recursion::async_recursion(?Send)]
745async fn parse_stats(value: Value, coeff_len: usize) -> BuiltinResult<Option<PolyfitStats>> {
746 if is_empty_value(&value).await? {
747 return Ok(None);
748 }
749 let struct_value = match value {
750 Value::Struct(s) => s,
751 Value::GpuTensor(handle) => {
752 let gathered = gpu_helpers::gather_value_async(&Value::GpuTensor(handle)).await?;
753 return parse_stats(gathered, coeff_len).await;
754 }
755 other => {
756 return Err(polyval_error(format!(
757 "polyval: S input must be the structure returned by polyfit, got {other:?}"
758 )))
759 }
760 };
761 let r_value = struct_value
762 .fields
763 .get("R")
764 .cloned()
765 .ok_or_else(|| polyval_error("polyval: S input is missing the field 'R'"))?;
766 let df_value = struct_value
767 .fields
768 .get("df")
769 .cloned()
770 .ok_or_else(|| polyval_error("polyval: S input is missing the field 'df'"))?;
771 let normr_value = struct_value
772 .fields
773 .get("normr")
774 .cloned()
775 .ok_or_else(|| polyval_error("polyval: S input is missing the field 'normr'"))?;
776
777 let (matrix, is_real) = convert_matrix(r_value, coeff_len).await?;
778 let df = scalar_to_f64(df_value, "polyval: S.df").await?;
779 let normr = scalar_to_f64(normr_value, "polyval: S.normr").await?;
780
781 Ok(Some(PolyfitStats {
782 r: matrix,
783 df,
784 normr,
785 is_real,
786 }))
787}
788
789#[async_recursion::async_recursion(?Send)]
790async fn convert_matrix(value: Value, coeff_len: usize) -> BuiltinResult<(Matrix, bool)> {
791 match value {
792 Value::GpuTensor(handle) => {
793 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
794 convert_matrix(Value::Tensor(tensor), coeff_len).await
795 }
796 Value::Tensor(tensor) => {
797 let Tensor {
798 data, rows, cols, ..
799 } = tensor;
800 if rows != coeff_len || cols != coeff_len {
801 return Err(polyval_error("polyval: size of S.R must match the coefficient vector"));
802 }
803 let data = data.into_iter().map(|re| Complex64::new(re, 0.0)).collect();
804 Ok((Matrix { rows, cols, data }, true))
805 }
806 Value::ComplexTensor(tensor) => {
807 let ComplexTensor {
808 data, rows, cols, ..
809 } = tensor;
810 if rows != coeff_len || cols != coeff_len {
811 return Err(polyval_error("polyval: size of S.R must match the coefficient vector"));
812 }
813 let imag_small = data.iter().all(|&(_, im)| im.abs() <= EPS);
814 let data = data
815 .into_iter()
816 .map(|(re, im)| Complex64::new(re, im))
817 .collect();
818 Ok((Matrix { rows, cols, data }, imag_small))
819 }
820 Value::LogicalArray(array) => {
821 let LogicalArray { data, shape } = array;
822 let rows = shape.first().copied().unwrap_or(0);
823 let cols = shape.get(1).copied().unwrap_or(0);
824 if rows != coeff_len || cols != coeff_len {
825 return Err(polyval_error("polyval: size of S.R must match the coefficient vector"));
826 }
827 let data = data
828 .into_iter()
829 .map(|bit| Complex64::new(if bit != 0 { 1.0 } else { 0.0 }, 0.0))
830 .collect();
831 Ok((Matrix { rows, cols, data }, true))
832 }
833 Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => Err(
834 polyval_error(
835 "polyval: S.R must be a square numeric matrix matching the coefficient vector length",
836 ),
837 ),
838 Value::Struct(_)
839 | Value::Cell(_)
840 | Value::String(_)
841 | Value::StringArray(_)
842 | Value::CharArray(_) => Err(
843 polyval_error(
844 "polyval: S.R must be a square numeric matrix matching the coefficient vector length",
845 ),
846 ),
847 _ => Err(
848 polyval_error(
849 "polyval: S.R must be a square numeric matrix matching the coefficient vector length",
850 ),
851 ),
852 }
853}
854
855#[async_recursion::async_recursion(?Send)]
856async fn scalar_to_f64(value: Value, context: &str) -> BuiltinResult<f64> {
857 match value {
858 Value::Num(n) => Ok(n),
859 Value::Int(i) => Ok(i.to_f64()),
860 Value::Bool(flag) => Ok(if flag { 1.0 } else { 0.0 }),
861 Value::Tensor(tensor) => {
862 if tensor.data.len() != 1 {
863 return Err(polyval_error(format!("{context} must be a scalar")));
864 }
865 Ok(tensor.data[0])
866 }
867 Value::LogicalArray(array) => {
868 if array.data.len() != 1 {
869 return Err(polyval_error(format!("{context} must be a scalar")));
870 }
871 Ok(if array.data[0] != 0 { 1.0 } else { 0.0 })
872 }
873 Value::GpuTensor(handle) => {
874 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
875 scalar_to_f64(Value::Tensor(tensor), context).await
876 }
877 Value::Complex(_, _) | Value::ComplexTensor(_) => {
878 Err(polyval_error(format!("{context} must be real-valued")))
879 }
880 other => Err(polyval_error(format!(
881 "{context} must be a scalar, got {other:?}"
882 ))),
883 }
884}
885
886fn apply_mu(values: &mut [Complex64], mu: Mu) -> BuiltinResult<()> {
887 let mean = Complex64::new(mu.mean, 0.0);
888 let scale = Complex64::new(mu.scale, 0.0);
889 for v in values.iter_mut() {
890 *v = (*v - mean) / scale;
891 }
892 Ok(())
893}
894
895fn evaluate_polynomial(coeffs: &[Complex64], inputs: &[Complex64]) -> Vec<Complex64> {
896 let mut outputs = Vec::with_capacity(inputs.len());
897 for &x in inputs {
898 let mut acc = Complex64::new(0.0, 0.0);
899 for &c in coeffs {
900 acc = acc * x + c;
901 }
902 outputs.push(acc);
903 }
904 outputs
905}
906
907fn compute_prediction_interval(
908 coeffs: &[Complex64],
909 inputs: &[Complex64],
910 stats: &PolyfitStats,
911) -> BuiltinResult<Vec<f64>> {
912 if !stats.is_effective() {
913 return Ok(vec![0.0; inputs.len()]);
914 }
915 let n = coeffs.len();
916 let mut delta = Vec::with_capacity(inputs.len());
917 for &x in inputs {
918 let row = vandermonde_row(x, n);
919 let solved = solve_row_against_upper(&row, &stats.r)?;
920 let sum_sq: f64 = solved.iter().map(|c| c.norm_sqr()).sum();
921 let interval = (1.0 + sum_sq).sqrt() * (stats.normr / stats.df.sqrt());
922 delta.push(interval);
923 }
924 Ok(delta)
925}
926
927fn vandermonde_row(x: Complex64, len: usize) -> Vec<Complex64> {
928 if len == 0 {
929 return vec![Complex64::new(1.0, 0.0)];
930 }
931 let degree = len - 1;
932 let mut powers = vec![Complex64::new(1.0, 0.0); degree + 1];
933 for idx in 1..=degree {
934 powers[idx] = powers[idx - 1] * x;
935 }
936 let mut row = vec![Complex64::new(0.0, 0.0); degree + 1];
937 for (i, value) in powers.into_iter().enumerate() {
938 row[degree - i] = value;
939 }
940 row
941}
942
943fn solve_row_against_upper(row: &[Complex64], matrix: &Matrix) -> BuiltinResult<Vec<Complex64>> {
944 let n = row.len();
945 if matrix.rows != n || matrix.cols != n {
946 return Err(polyval_error(
947 "polyval: size of S.R must match the coefficient vector",
948 ));
949 }
950 let mut result = vec![Complex64::new(0.0, 0.0); n];
951 for j in (0..n).rev() {
952 let mut acc = row[j];
953 for (k, value) in result.iter().enumerate().skip(j + 1) {
954 acc -= *value * matrix.get(k, j);
955 }
956 let diag = matrix.get(j, j);
957 if diag.norm() <= EPS {
958 return Err(polyval_error("polyval: S.R is singular"));
959 }
960 result[j] = acc / diag;
961 }
962 Ok(result)
963}
964
965fn finalize_values(
966 data: &[Complex64],
967 shape: &[usize],
968 prefer_gpu: bool,
969 real_only: bool,
970) -> BuiltinResult<Value> {
971 if real_only {
972 let real_data: Vec<f64> = data.iter().map(|c| c.re).collect();
973 finalize_real(real_data, shape, prefer_gpu)
974 } else if data.len() == 1 {
975 let value = data[0];
976 Ok(Value::Complex(value.re, value.im))
977 } else {
978 let complex_data: Vec<(f64, f64)> = data.iter().map(|c| (c.re, c.im)).collect();
979 let tensor = ComplexTensor::new(complex_data, shape.to_vec())
980 .map_err(|e| polyval_error(format!("polyval: failed to build complex tensor: {e}")))?;
981 Ok(Value::ComplexTensor(tensor))
982 }
983}
984
985fn finalize_delta(data: Vec<f64>, shape: &[usize], prefer_gpu: bool) -> BuiltinResult<Value> {
986 finalize_real(data, shape, prefer_gpu)
987}
988
989fn finalize_real(data: Vec<f64>, shape: &[usize], prefer_gpu: bool) -> BuiltinResult<Value> {
990 let tensor = Tensor::new(data, shape.to_vec())
991 .map_err(|e| polyval_error(format!("polyval: failed to build tensor: {e}")))?;
992 if prefer_gpu {
993 if let Some(provider) = runmat_accelerate_api::provider() {
994 let view = HostTensorView {
995 data: &tensor.data,
996 shape: &tensor.shape,
997 };
998 if let Ok(handle) = provider.upload(&view) {
999 return Ok(Value::GpuTensor(handle));
1000 }
1001 }
1002 }
1003 Ok(tensor::tensor_into_value(tensor))
1004}
1005
1006fn zeros_like(shape: &[usize], prefer_gpu: bool) -> BuiltinResult<Value> {
1007 let len = shape.iter().product();
1008 finalize_real(vec![0.0; len], shape, prefer_gpu)
1009}
1010
1011fn ensure_vector_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
1012 if !is_vector_shape(shape) {
1013 Err(polyval_error(format!(
1014 "{name}: coefficients must be a scalar, row vector, or column vector"
1015 )))
1016 } else {
1017 Ok(())
1018 }
1019}
1020
1021fn ensure_vector_data_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
1022 if !is_vector_shape(shape) {
1023 Err(polyval_error(format!(
1024 "{name}: inputs must be vectors or scalars"
1025 )))
1026 } else {
1027 Ok(())
1028 }
1029}
1030
/// True when at most one dimension exceeds one (scalars, vectors, and empties).
fn is_vector_shape(shape: &[usize]) -> bool {
    // A second dimension larger than one disqualifies the shape.
    shape.iter().filter(|&&dim| dim > 1).nth(1).is_none()
}
1034
1035#[async_recursion::async_recursion(?Send)]
1036async fn is_empty_value(value: &Value) -> BuiltinResult<bool> {
1037 match value {
1038 Value::Tensor(t) => Ok(t.data.is_empty()),
1039 Value::LogicalArray(l) => Ok(l.data.is_empty()),
1040 Value::Cell(ca) => Ok(ca.data.is_empty()),
1041 Value::GpuTensor(handle) => {
1042 let gathered =
1043 gpu_helpers::gather_value_async(&Value::GpuTensor(handle.clone())).await?;
1044 is_empty_value(&gathered).await
1045 }
1046 _ => Ok(false),
1047 }
1048}
1049
1050fn values_are_real(values: &[Complex64]) -> bool {
1051 values.iter().all(|c| c.im.abs() <= EPS)
1052}
1053
// Unit tests for `polyval`. They cover descriptor metadata, real and complex
// CPU evaluation, `mu` centering/scaling, the `delta` output, argument
// validation errors, and GPU-provider execution paths.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::StructValue;

    // Asserts that `err`'s message contains `needle`; on failure, the actual
    // message is included in the panic text.
    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
        assert!(
            err.message().contains(needle),
            "expected error containing '{needle}', got '{}'",
            err.message()
        );
    }

    // The descriptor must advertise the documented call forms, including the
    // two-output `[y, delta]` form.
    #[test]
    fn polyval_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = POLYVAL_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert!(labels.contains(&"y = polyval(p, x)"));
        assert!(labels.contains(&"y = polyval(p, x, S)"));
        assert!(labels.contains(&"y = polyval(p, x, S, mu)"));
        assert!(labels.contains(&"[y, delta] = polyval(p, x, S)"));
    }

    // The error codes exposed by the descriptor are part of the public
    // contract, so pin them here.
    #[test]
    fn polyval_descriptor_errors_have_stable_codes() {
        let codes: Vec<&str> = POLYVAL_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert!(codes.contains(&"RM.POLYVAL.INVALID_ARGUMENT"));
        assert!(codes.contains(&"RM.POLYVAL.INVALID_INPUT"));
        assert!(codes.contains(&"RM.POLYVAL.INTERNAL"));
    }

    // p = [2, -3, 5] is 2x^2 - 3x + 5. At the scalar x = 4 the expected value
    // is 32 - 12 + 5 = 25, and the result must come back as `Value::Num`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_scalar() {
        let coeffs = Tensor::new(vec![2.0, -3.0, 5.0], vec![1, 3]).unwrap();
        let value =
            polyval_builtin(Value::Tensor(coeffs), Value::Num(4.0), Vec::new()).expect("polyval");
        match value {
            Value::Num(n) => assert!((n - (2.0 * 16.0 - 12.0 + 5.0)).abs() < 1e-12),
            other => panic!("expected scalar, got {other:?}"),
        }
    }

    // p = [1, 0, -2, 1] is x^3 - 2x + 1, evaluated on a 5x1 column of points.
    // The output must keep the shape of x.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_matrix_input() {
        let coeffs = Tensor::new(vec![1.0, 0.0, -2.0, 1.0], vec![1, 4]).unwrap();
        let points = Tensor::new(vec![-2.0, -1.0, 0.0, 1.0, 2.0], vec![5, 1]).unwrap();
        let value = polyval_builtin(
            Value::Tensor(coeffs),
            Value::Tensor(points.clone()),
            Vec::new(),
        )
        .expect("polyval");
        match value {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, points.shape);
                let expected = vec![-3.0, 2.0, 1.0, 0.0, 5.0];
                assert_eq!(tensor.data, expected);
            }
            other => panic!("expected tensor output, got {other:?}"),
        }
    }

    // Complex coefficients and complex points must produce a complex tensor
    // shaped like x. Only shape and element count are checked here, not values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_complex_inputs() {
        let coeffs =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (0.0, 4.0)], vec![1, 3]).unwrap();
        let points =
            ComplexTensor::new(vec![(-1.0, 1.0), (0.0, 0.0), (1.0, -2.0)], vec![1, 3]).unwrap();
        let value = polyval_builtin(
            Value::ComplexTensor(coeffs),
            Value::ComplexTensor(points.clone()),
            Vec::new(),
        )
        .expect("polyval");
        match value {
            Value::ComplexTensor(tensor) => {
                assert_eq!(tensor.shape, points.shape);
                assert_eq!(tensor.data.len(), 3);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // An empty 0x0 placeholder fills the S slot so that `mu` is the fourth
    // argument. With p = x^2 and mu = [1, 2], the expected values equal
    // ((x - 1) / 2)^2 for x = 0, 1, 2.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_with_mu() {
        let coeffs = Tensor::new(vec![1.0, 0.0, 0.0], vec![1, 3]).unwrap();
        let points = Tensor::new(vec![0.0, 1.0, 2.0], vec![1, 3]).unwrap();
        let mu = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let value = polyval_builtin(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            vec![
                Value::Tensor(Tensor::new(vec![], vec![0, 0]).unwrap()),
                Value::Tensor(mu),
            ],
        )
        .expect("polyval");
        match value {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.data, vec![0.25, 0.0, 0.25]);
            }
            other => panic!("expected tensor output, got {other:?}"),
        }
    }

    // Builds an S struct with fields R (3x3 identity), df and normr, then calls
    // `evaluate` directly. The trailing `true` argument appears to request the
    // second (delta) output; `into_pair` is expected to yield it.
    // Only the shape and length of delta are asserted, not its values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_delta_computation() {
        let coeffs = Tensor::new(vec![1.0, -3.0, 2.0], vec![1, 3]).unwrap();
        let points = Tensor::new(vec![0.0, 1.0, 2.0], vec![1, 3]).unwrap();
        let mut st = StructValue::new();
        let r = Tensor::new(
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            vec![3, 3],
        )
        .unwrap();
        st.fields.insert("R".to_string(), Value::Tensor(r));
        st.fields.insert("df".to_string(), Value::Num(4.0));
        st.fields.insert("normr".to_string(), Value::Num(2.0));
        let stats = Value::Struct(st);
        let eval = futures::executor::block_on(evaluate(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            &[stats],
            true,
        ))
        .expect("polyval");
        let (_, delta) = eval.into_pair().expect("delta available");
        match delta {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 3]);
                assert_eq!(tensor.data.len(), 3);
            }
            other => panic!("expected tensor delta, got {other:?}"),
        }
    }

    // Requesting delta with no S argument must fail with a message that
    // mentions the S input.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_delta_requires_stats() {
        let coeffs = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let points = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = futures::executor::block_on(evaluate(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            &[],
            true,
        ))
        .expect_err("expected error");
        assert_error_contains(err, "S input");
    }

    // A one-element mu is rejected, because mu must supply both a center and
    // a scale.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_invalid_mu_length_errors() {
        let coeffs = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let points = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let mu = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let placeholder = Tensor::new(vec![], vec![0, 0]).unwrap();
        let err = polyval_builtin(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            vec![Value::Tensor(placeholder), Value::Tensor(mu)],
        )
        .expect_err("expected mu length error");
        assert_error_contains(err, "mu must contain at least two elements");
    }

    // Five total inputs (p, x, plus three extras) exceed the supported
    // signatures. The call must fail with the INVALID_ARGUMENT identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_rejects_excess_optional_arguments() {
        let coeffs = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let points = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let err = polyval_builtin(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            vec![Value::Num(1.0), Value::Num(2.0), Value::Num(3.0)],
        )
        .expect_err("expected too many arguments error");
        assert_eq!(err.identifier(), POLYVAL_ERROR_INVALID_ARGUMENT.identifier);
        assert_error_contains(err, "too many input arguments");
    }

    // A mu whose imaginary part is non-zero is rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_complex_mu_rejected() {
        let coeffs = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let points = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let complex_mu =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 0.5)], vec![1, 2]).expect("complex mu");
        let placeholder = Tensor::new(vec![], vec![0, 0]).unwrap();
        let err = polyval_builtin(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            vec![Value::Tensor(placeholder), Value::ComplexTensor(complex_mu)],
        )
        .expect_err("expected complex mu error");
        assert_error_contains(err, "mu values must be real");
    }

    // An S struct that has df and normr but no R field is rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_invalid_stats_missing_r() {
        let coeffs = Tensor::new(vec![1.0, -3.0, 2.0], vec![1, 3]).unwrap();
        let points = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let mut st = StructValue::new();
        st.fields.insert("df".to_string(), Value::Num(1.0));
        st.fields.insert("normr".to_string(), Value::Num(1.0));
        let stats = Value::Struct(st);
        let err = polyval_builtin(Value::Tensor(coeffs), Value::Tensor(points), vec![stats])
            .expect_err("expected missing R error");
        assert_error_contains(err, "missing the field 'R'");
    }

    // With the in-process test provider, GPU-resident p and x must produce a
    // GPU-resident result. Here p = x^2 + 1 at x = -1, 0, 1, so the gathered
    // data is [2, 1, 2].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let coeffs = Tensor::new(vec![1.0, 0.0, 1.0], vec![1, 3]).unwrap();
            let points = Tensor::new(vec![-1.0, 0.0, 1.0], vec![3, 1]).unwrap();
            let coeff_handle = provider
                .upload(&HostTensorView {
                    data: &coeffs.data,
                    shape: &coeffs.shape,
                })
                .expect("upload coeff");
            let point_handle = provider
                .upload(&HostTensorView {
                    data: &points.data,
                    shape: &points.shape,
                })
                .expect("upload points");
            let value = polyval_builtin(
                Value::GpuTensor(coeff_handle),
                Value::GpuTensor(point_handle),
                Vec::new(),
            )
            .expect("polyval");
            match value {
                Value::GpuTensor(handle) => {
                    let gathered = test_support::gather(Value::GpuTensor(handle)).expect("gather");
                    assert_eq!(gathered.data, vec![2.0, 1.0, 2.0]);
                }
                other => panic!("expected gpu tensor, got {other:?}"),
            }
        });
    }

    // Only built with the `wgpu` feature. Registers the wgpu provider (the
    // registration result is ignored) and evaluates p = x^2 - 3x + 2 on the
    // GPU. The gathered result is compared with the CPU `evaluate_polynomial`
    // helper.
    // NOTE(review): the comparison uses exact f64 equality; if the GPU kernel
    // evaluates in a different precision or order, this may be brittle. Confirm.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn polyval_wgpu_matches_cpu_real_inputs() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let coeffs = Tensor::new(vec![1.0, -3.0, 2.0], vec![1, 3]).unwrap();
        let points = Tensor::new(vec![-2.0, -1.0, 0.5, 2.5], vec![4, 1]).unwrap();

        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let coeff_handle = provider
            .upload(&HostTensorView {
                data: &coeffs.data,
                shape: &coeffs.shape,
            })
            .expect("upload coeffs");
        let point_handle = provider
            .upload(&HostTensorView {
                data: &points.data,
                shape: &points.shape,
            })
            .expect("upload points");

        let gpu_value = polyval_builtin(
            Value::GpuTensor(coeff_handle.clone()),
            Value::GpuTensor(point_handle.clone()),
            Vec::new(),
        )
        .expect("polyval gpu");

        // Release the input buffers; failures here are deliberately ignored.
        let _ = provider.free(&coeff_handle);
        let _ = provider.free(&point_handle);

        let gathered = test_support::gather(gpu_value).expect("gather");

        // CPU reference: lift the real inputs to complex and reuse the shared
        // evaluator, keeping only the real parts.
        let coeff_complex: Vec<Complex64> = coeffs
            .data
            .iter()
            .map(|&c| Complex64::new(c, 0.0))
            .collect();
        let point_complex: Vec<Complex64> = points
            .data
            .iter()
            .map(|&x| Complex64::new(x, 0.0))
            .collect();
        let expected_vals = evaluate_polynomial(&coeff_complex, &point_complex);
        let expected: Vec<f64> = expected_vals.iter().map(|c| c.re).collect();

        assert_eq!(gathered.shape, vec![4, 1]);
        assert_eq!(gathered.data, expected);
    }

    // Synchronous shim that drives the async `super::polyval_builtin` to
    // completion with `block_on`, so the tests above can call it directly.
    fn polyval_builtin(p: Value, x: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::polyval_builtin(p, x, rest))
    }
}