Skip to main content

runmat_runtime/builtins/math/poly/
polyval.rs

1//! MATLAB-compatible `polyval` builtin with GPU-aware semantics for RunMat.
2
3use log::debug;
4use num_complex::Complex64;
5use runmat_accelerate_api::{HostTensorView, ProviderPolyvalMu, ProviderPolyvalOptions};
6use runmat_builtins::{
7    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9    ComplexTensor, LogicalArray, Tensor, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16};
17use crate::builtins::common::{gpu_helpers, tensor};
18use crate::builtins::math::poly::type_resolvers::polyval_type;
19use crate::{build_runtime_error, dispatcher::download_handle_async, BuiltinResult, RuntimeError};
20
/// Magnitude at or below which an imaginary part is treated as zero and a
/// `mu(2)` scale factor is rejected as zero.
const EPS: f64 = 1e-12;
/// Builtin name attached to every runtime error raised from this module.
const BUILTIN_NAME: &str = "polyval";
23
24const POLYVAL_OUTPUT_Y: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
25    name: "y",
26    ty: BuiltinParamType::Any,
27    arity: BuiltinParamArity::Required,
28    default: None,
29    description: "Evaluated polynomial values at x.",
30}];
31
32const POLYVAL_OUTPUT_Y_DELTA: [BuiltinParamDescriptor; 2] = [
33    BuiltinParamDescriptor {
34        name: "y",
35        ty: BuiltinParamType::Any,
36        arity: BuiltinParamArity::Required,
37        default: None,
38        description: "Evaluated polynomial values at x.",
39    },
40    BuiltinParamDescriptor {
41        name: "delta",
42        ty: BuiltinParamType::Any,
43        arity: BuiltinParamArity::Required,
44        default: None,
45        description: "Prediction interval values when S is supplied.",
46    },
47];
48
49const POLYVAL_INPUTS: [BuiltinParamDescriptor; 2] = [
50    BuiltinParamDescriptor {
51        name: "p",
52        ty: BuiltinParamType::Any,
53        arity: BuiltinParamArity::Required,
54        default: None,
55        description: "Polynomial coefficient vector.",
56    },
57    BuiltinParamDescriptor {
58        name: "x",
59        ty: BuiltinParamType::Any,
60        arity: BuiltinParamArity::Required,
61        default: None,
62        description: "Evaluation points.",
63    },
64];
65
66const POLYVAL_INPUTS_WITH_S: [BuiltinParamDescriptor; 3] = [
67    BuiltinParamDescriptor {
68        name: "p",
69        ty: BuiltinParamType::Any,
70        arity: BuiltinParamArity::Required,
71        default: None,
72        description: "Polynomial coefficient vector.",
73    },
74    BuiltinParamDescriptor {
75        name: "x",
76        ty: BuiltinParamType::Any,
77        arity: BuiltinParamArity::Required,
78        default: None,
79        description: "Evaluation points.",
80    },
81    BuiltinParamDescriptor {
82        name: "S",
83        ty: BuiltinParamType::Any,
84        arity: BuiltinParamArity::Optional,
85        default: None,
86        description: "Optional polyfit statistics structure.",
87    },
88];
89
90const POLYVAL_INPUTS_WITH_S_MU: [BuiltinParamDescriptor; 4] = [
91    BuiltinParamDescriptor {
92        name: "p",
93        ty: BuiltinParamType::Any,
94        arity: BuiltinParamArity::Required,
95        default: None,
96        description: "Polynomial coefficient vector.",
97    },
98    BuiltinParamDescriptor {
99        name: "x",
100        ty: BuiltinParamType::Any,
101        arity: BuiltinParamArity::Required,
102        default: None,
103        description: "Evaluation points.",
104    },
105    BuiltinParamDescriptor {
106        name: "S",
107        ty: BuiltinParamType::Any,
108        arity: BuiltinParamArity::Optional,
109        default: None,
110        description: "Optional polyfit statistics structure (or []).",
111    },
112    BuiltinParamDescriptor {
113        name: "mu",
114        ty: BuiltinParamType::Any,
115        arity: BuiltinParamArity::Optional,
116        default: None,
117        description: "Optional centering/scaling vector [mean, std].",
118    },
119];
120
121const POLYVAL_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
122    BuiltinSignatureDescriptor {
123        label: "y = polyval(p, x)",
124        inputs: &POLYVAL_INPUTS,
125        outputs: &POLYVAL_OUTPUT_Y,
126    },
127    BuiltinSignatureDescriptor {
128        label: "y = polyval(p, x, S)",
129        inputs: &POLYVAL_INPUTS_WITH_S,
130        outputs: &POLYVAL_OUTPUT_Y,
131    },
132    BuiltinSignatureDescriptor {
133        label: "y = polyval(p, x, S, mu)",
134        inputs: &POLYVAL_INPUTS_WITH_S_MU,
135        outputs: &POLYVAL_OUTPUT_Y,
136    },
137    BuiltinSignatureDescriptor {
138        label: "[y, delta] = polyval(p, x)",
139        inputs: &POLYVAL_INPUTS,
140        outputs: &POLYVAL_OUTPUT_Y_DELTA,
141    },
142    BuiltinSignatureDescriptor {
143        label: "[y, delta] = polyval(p, x, S)",
144        inputs: &POLYVAL_INPUTS_WITH_S,
145        outputs: &POLYVAL_OUTPUT_Y_DELTA,
146    },
147    BuiltinSignatureDescriptor {
148        label: "[y, delta] = polyval(p, x, S, mu)",
149        inputs: &POLYVAL_INPUTS_WITH_S_MU,
150        outputs: &POLYVAL_OUTPUT_Y_DELTA,
151    },
152];
153
154const POLYVAL_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
155    code: "RM.POLYVAL.INVALID_ARGUMENT",
156    identifier: Some("RunMat:polyval:InvalidArgument"),
157    when: "Option arguments (S/mu/output arity) are malformed or unsupported.",
158    message: "polyval: invalid argument",
159};
160
161const POLYVAL_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
162    code: "RM.POLYVAL.INVALID_INPUT",
163    identifier: Some("RunMat:polyval:InvalidInput"),
164    when: "Polynomial coefficients or evaluation points cannot be interpreted as numeric inputs.",
165    message: "polyval: invalid input",
166};
167
168const POLYVAL_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
169    code: "RM.POLYVAL.INTERNAL",
170    identifier: Some("RunMat:polyval:Internal"),
171    when: "Runtime fails while building output tensors, deltas, or provider fallbacks.",
172    message: "polyval: internal runtime failure",
173};
174
175const POLYVAL_ERRORS: [BuiltinErrorDescriptor; 3] = [
176    POLYVAL_ERROR_INVALID_ARGUMENT,
177    POLYVAL_ERROR_INVALID_INPUT,
178    POLYVAL_ERROR_INTERNAL,
179];
180
181pub const POLYVAL_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
182    signatures: &POLYVAL_SIGNATURES,
183    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
184    completion_policy: BuiltinCompletionPolicy::Public,
185    errors: &POLYVAL_ERRORS,
186};
187
188#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::polyval")]
189pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
190    name: "polyval",
191    op_kind: GpuOpKind::Custom("polyval"),
192    supported_precisions: &[ScalarType::F32, ScalarType::F64],
193    broadcast: BroadcastSemantics::Matlab,
194    provider_hooks: &[ProviderHook::Custom("polyval")],
195    constant_strategy: ConstantStrategy::UniformBuffer,
196    residency: ResidencyPolicy::NewHandle,
197    nan_mode: ReductionNaN::Include,
198    two_pass_threshold: None,
199    workgroup_size: None,
200    accepts_nan_mode: false,
201    notes:
202        "Uses provider-level Horner kernels for real coefficients/inputs; falls back to host evaluation (with upload) for complex or prediction-interval paths.",
203};
204
205fn polyval_error(message: impl Into<String>) -> RuntimeError {
206    polyval_error_with(message, &POLYVAL_ERROR_INVALID_INPUT)
207}
208
209fn polyval_argument_error(message: impl Into<String>) -> RuntimeError {
210    polyval_error_with(message, &POLYVAL_ERROR_INVALID_ARGUMENT)
211}
212
213fn polyval_error_with(
214    message: impl Into<String>,
215    error: &'static BuiltinErrorDescriptor,
216) -> RuntimeError {
217    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
218    if let Some(identifier) = error.identifier {
219        builder = builder.with_identifier(identifier);
220    }
221    builder.build()
222}
223
224#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::polyval")]
225pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
226    name: "polyval",
227    shape: ShapeRequirements::Any,
228    constant_strategy: ConstantStrategy::UniformBuffer,
229    elementwise: None,
230    reduction: None,
231    emits_nan: true,
232    notes: "Acts as a fusion sink; real-valued workloads stay on device, while complex/delta paths gather to the host.",
233};
234
235#[runtime_builtin(
236    name = "polyval",
237    category = "math/poly",
238    summary = "Evaluate polynomials at specified points.",
239    keywords = "polyval,polynomial,polyfit,delta,gpu",
240    accel = "sink",
241    sink = true,
242    type_resolver(polyval_type),
243    descriptor(crate::builtins::math::poly::polyval::POLYVAL_DESCRIPTOR),
244    builtin_path = "crate::builtins::math::poly::polyval"
245)]
246async fn polyval_builtin(p: Value, x: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
247    if let Some(out_count) = crate::output_count::current_output_count() {
248        let eval = evaluate(p, x, &rest, out_count >= 2).await?;
249        if out_count == 0 {
250            return Ok(Value::OutputList(Vec::new()));
251        }
252        let mut outputs = vec![eval.value()];
253        if out_count >= 2 {
254            outputs.push(eval.delta()?);
255        }
256        return Ok(crate::output_count::output_list_with_padding(
257            out_count, outputs,
258        ));
259    }
260    let eval = evaluate(p, x, &rest, false).await?;
261    Ok(eval.value())
262}
263
264/// Evaluate `polyval`, optionally computing the prediction interval.
265pub async fn evaluate(
266    coefficients: Value,
267    points: Value,
268    rest: &[Value],
269    want_delta: bool,
270) -> BuiltinResult<PolyvalEval> {
271    let options = parse_option_values(rest).await?;
272
273    let coeff_clone = coefficients.clone();
274    let points_clone = points.clone();
275
276    let coeff_was_gpu = matches!(coefficients, Value::GpuTensor(_));
277    let (coeffs, coeff_real) = convert_coefficients(coeff_clone).await?;
278
279    let (mut inputs, prefer_gpu_points) = convert_points(points_clone).await?;
280    let prefer_gpu_output = prefer_gpu_points || coeff_was_gpu;
281
282    let mu = match options.mu.clone() {
283        Some(mu_value) => Some(parse_mu(mu_value).await?),
284        None => None,
285    };
286
287    if prefer_gpu_output && !want_delta && options.s.is_none() {
288        if let Some(value) =
289            try_gpu_polyval(&coeffs, coeff_real, &inputs, mu, prefer_gpu_output).await?
290        {
291            return Ok(PolyvalEval::new(value, None));
292        }
293    }
294
295    if let Some(mu_val) = mu {
296        apply_mu(&mut inputs.data, mu_val)?;
297    }
298
299    let stats = if let Some(s_value) = options.s {
300        parse_stats(s_value, coeffs.len()).await?
301    } else {
302        None
303    };
304
305    if want_delta && stats.is_none() {
306        return Err(polyval_argument_error(
307            "polyval: S input (structure returned by polyfit) is required for delta output",
308        ));
309    }
310
311    if inputs.data.is_empty() {
312        let y = zeros_like(&inputs.shape, prefer_gpu_output)?;
313        let delta = if want_delta {
314            Some(zeros_like(&inputs.shape, prefer_gpu_output)?)
315        } else {
316            None
317        };
318        return Ok(PolyvalEval::new(y, delta));
319    }
320
321    if coeffs.is_empty() {
322        let zeros = zeros_like(&inputs.shape, prefer_gpu_output)?;
323        let delta = if want_delta {
324            Some(zeros_like(&inputs.shape, prefer_gpu_output)?)
325        } else {
326            None
327        };
328        return Ok(PolyvalEval::new(zeros, delta));
329    }
330
331    let output_real = coeff_real && inputs.all_real;
332    let values = evaluate_polynomial(&coeffs, &inputs.data);
333    let result_value = finalize_values(
334        &values,
335        &inputs.shape,
336        prefer_gpu_output,
337        output_real && values_are_real(&values),
338    )?;
339
340    let delta_value = if want_delta {
341        let stats = stats.expect("delta requires stats");
342        let delta = compute_prediction_interval(&coeffs, &inputs.data, &stats)?;
343        let prefer = prefer_gpu_output && stats.is_real;
344        Some(finalize_delta(delta, &inputs.shape, prefer)?)
345    } else {
346        None
347    };
348
349    Ok(PolyvalEval::new(result_value, delta_value))
350}
351
352async fn try_gpu_polyval(
353    coeffs: &[Complex64],
354    coeff_real: bool,
355    inputs: &NumericArray,
356    mu: Option<Mu>,
357    prefer_gpu_output: bool,
358) -> BuiltinResult<Option<Value>> {
359    if !coeff_real || !inputs.all_real {
360        return Ok(None);
361    }
362    if coeffs.is_empty() || inputs.data.is_empty() {
363        return Ok(None);
364    }
365    let Some(provider) = runmat_accelerate_api::provider() else {
366        return Ok(None);
367    };
368
369    let coeff_data: Vec<f64> = coeffs.iter().map(|c| c.re).collect();
370    let coeff_shape = vec![1usize, coeffs.len()];
371    let coeff_view = HostTensorView {
372        data: &coeff_data,
373        shape: &coeff_shape,
374    };
375    let coeff_handle = match provider.upload(&coeff_view) {
376        Ok(handle) => handle,
377        Err(err) => {
378            debug!("polyval: GPU upload of coefficients failed, falling back: {err}");
379            return Ok(None);
380        }
381    };
382
383    let input_data: Vec<f64> = inputs.data.iter().map(|c| c.re).collect();
384    let input_shape = inputs.shape.clone();
385    let input_view = HostTensorView {
386        data: &input_data,
387        shape: &input_shape,
388    };
389    let input_handle = match provider.upload(&input_view) {
390        Ok(handle) => handle,
391        Err(err) => {
392            debug!("polyval: GPU upload of evaluation points failed, falling back: {err}");
393            let _ = provider.free(&coeff_handle);
394            return Ok(None);
395        }
396    };
397
398    let options = ProviderPolyvalOptions {
399        mu: mu.map(|m| ProviderPolyvalMu {
400            mean: m.mean,
401            scale: m.scale,
402        }),
403    };
404
405    let result_handle = match provider.polyval(&coeff_handle, &input_handle, &options) {
406        Ok(handle) => handle,
407        Err(err) => {
408            debug!("polyval: GPU kernel execution failed, falling back: {err}");
409            let _ = provider.free(&coeff_handle);
410            let _ = provider.free(&input_handle);
411            return Ok(None);
412        }
413    };
414
415    let _ = provider.free(&coeff_handle);
416    let _ = provider.free(&input_handle);
417
418    if prefer_gpu_output {
419        return Ok(Some(Value::GpuTensor(result_handle)));
420    }
421
422    let host = match download_handle_async(provider, &result_handle).await {
423        Ok(host) => host,
424        Err(err) => {
425            debug!("polyval: GPU download failed, falling back: {err}");
426            let _ = provider.free(&result_handle);
427            return Ok(None);
428        }
429    };
430    let _ = provider.free(&result_handle);
431
432    let tensor =
433        Tensor::new(host.data, host.shape).map_err(|e| polyval_error(format!("polyval: {e}")))?;
434    Ok(Some(tensor::tensor_into_value(tensor)))
435}
436
437/// Result object for polyval evaluation.
438#[derive(Debug)]
439pub struct PolyvalEval {
440    value: Value,
441    delta: Option<Value>,
442}
443
444impl PolyvalEval {
445    fn new(value: Value, delta: Option<Value>) -> Self {
446        Self { value, delta }
447    }
448
449    /// Primary output (`y`).
450    pub fn value(&self) -> Value {
451        self.value.clone()
452    }
453
454    /// Optional prediction interval (`delta`).
455    pub fn delta(&self) -> BuiltinResult<Value> {
456        self.delta
457            .clone()
458            .ok_or_else(|| polyval_argument_error("polyval: delta output not computed"))
459    }
460
461    /// Consume into the main value.
462    pub fn into_value(self) -> Value {
463        self.value
464    }
465
466    /// Consume into `(value, delta)` pair.
467    pub fn into_pair(self) -> BuiltinResult<(Value, Value)> {
468        match self.delta {
469            Some(delta) => Ok((self.value, delta)),
470            None => Err(polyval_argument_error("polyval: delta output not computed")),
471        }
472    }
473}
474
475#[derive(Clone, Copy)]
476struct Mu {
477    mean: f64,
478    scale: f64,
479}
480
481impl Mu {
482    fn new(mean: f64, scale: f64) -> BuiltinResult<Self> {
483        if !mean.is_finite() || !scale.is_finite() {
484            return Err(polyval_error("polyval: mu values must be finite"));
485        }
486        if scale.abs() <= EPS {
487            return Err(polyval_error("polyval: mu(2) must be non-zero"));
488        }
489        Ok(Self { mean, scale })
490    }
491}
492
493#[derive(Clone)]
494struct NumericArray {
495    data: Vec<Complex64>,
496    shape: Vec<usize>,
497    all_real: bool,
498}
499
500#[derive(Clone)]
501struct PolyfitStats {
502    r: Matrix,
503    df: f64,
504    normr: f64,
505    is_real: bool,
506}
507
508impl PolyfitStats {
509    fn is_effective(&self) -> bool {
510        self.r.len() > 0 && self.df > 0.0 && self.normr.is_finite()
511    }
512}
513
514#[derive(Clone)]
515struct Matrix {
516    rows: usize,
517    cols: usize,
518    data: Vec<Complex64>,
519}
520
521impl Matrix {
522    fn get(&self, row: usize, col: usize) -> Complex64 {
523        self.data[row + col * self.rows]
524    }
525
526    fn len(&self) -> usize {
527        self.rows * self.cols
528    }
529}
530
531struct ParsedOptions {
532    s: Option<Value>,
533    mu: Option<Value>,
534}
535
536async fn parse_option_values(rest: &[Value]) -> BuiltinResult<ParsedOptions> {
537    match rest.len() {
538        0 => Ok(ParsedOptions { s: None, mu: None }),
539        1 => Ok(ParsedOptions {
540            s: if is_empty_value(&rest[0]).await? {
541                None
542            } else {
543                Some(rest[0].clone())
544            },
545            mu: None,
546        }),
547        2 => Ok(ParsedOptions {
548            s: if is_empty_value(&rest[0]).await? {
549                None
550            } else {
551                Some(rest[0].clone())
552            },
553            mu: Some(rest[1].clone()),
554        }),
555        _ => Err(polyval_argument_error("polyval: too many input arguments")),
556    }
557}
558
559#[async_recursion::async_recursion(?Send)]
560async fn convert_coefficients(value: Value) -> BuiltinResult<(Vec<Complex64>, bool)> {
561    match value {
562        Value::GpuTensor(handle) => {
563            let gathered =
564                gpu_helpers::gather_value_async(&Value::GpuTensor(handle.clone())).await?;
565            convert_coefficients(gathered).await
566        }
567        Value::Tensor(mut tensor) => {
568            ensure_vector_shape("polyval", &tensor.shape)?;
569            let data = tensor
570                .data
571                .drain(..)
572                .map(|re| Complex64::new(re, 0.0))
573                .collect();
574            Ok((data, true))
575        }
576        Value::ComplexTensor(mut tensor) => {
577            ensure_vector_shape("polyval", &tensor.shape)?;
578            let all_real = tensor.data.iter().all(|&(_, im)| im.abs() <= EPS);
579            let data = tensor
580                .data
581                .drain(..)
582                .map(|(re, im)| Complex64::new(re, im))
583                .collect();
584            Ok((data, all_real))
585        }
586        Value::LogicalArray(mut array) => {
587            ensure_vector_data_shape("polyval", &array.shape)?;
588            let data = array
589                .data
590                .drain(..)
591                .map(|bit| Complex64::new(if bit != 0 { 1.0 } else { 0.0 }, 0.0))
592                .collect();
593            Ok((data, true))
594        }
595        Value::Num(n) => Ok((vec![Complex64::new(n, 0.0)], true)),
596        Value::Int(i) => Ok((vec![Complex64::new(i.to_f64(), 0.0)], true)),
597        Value::Bool(flag) => Ok((
598            vec![Complex64::new(if flag { 1.0 } else { 0.0 }, 0.0)],
599            true,
600        )),
601        Value::Complex(re, im) => Ok((vec![Complex64::new(re, im)], im.abs() <= EPS)),
602        other => Err(polyval_error(format!(
603            "polyval: coefficients must be numeric, got {other:?}"
604        ))),
605    }
606}
607
608async fn convert_points(value: Value) -> BuiltinResult<(NumericArray, bool)> {
609    match value {
610        Value::GpuTensor(handle) => {
611            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
612            let array = NumericArray {
613                data: tensor
614                    .data
615                    .iter()
616                    .map(|&re| Complex64::new(re, 0.0))
617                    .collect(),
618                shape: tensor.shape.clone(),
619                all_real: true,
620            };
621            Ok((array, true))
622        }
623        Value::Tensor(tensor) => Ok((
624            NumericArray {
625                data: tensor
626                    .data
627                    .iter()
628                    .map(|&re| Complex64::new(re, 0.0))
629                    .collect(),
630                shape: tensor.shape.clone(),
631                all_real: true,
632            },
633            false,
634        )),
635        Value::ComplexTensor(tensor) => Ok((
636            NumericArray {
637                data: tensor
638                    .data
639                    .iter()
640                    .map(|&(re, im)| Complex64::new(re, im))
641                    .collect(),
642                shape: tensor.shape.clone(),
643                all_real: tensor.data.iter().all(|&(_, im)| im.abs() <= EPS),
644            },
645            false,
646        )),
647        Value::LogicalArray(array) => Ok((
648            NumericArray {
649                data: array
650                    .data
651                    .iter()
652                    .map(|&bit| Complex64::new(if bit != 0 { 1.0 } else { 0.0 }, 0.0))
653                    .collect(),
654                shape: array.shape.clone(),
655                all_real: true,
656            },
657            false,
658        )),
659        Value::Num(n) => Ok((
660            NumericArray {
661                data: vec![Complex64::new(n, 0.0)],
662                shape: vec![1, 1],
663                all_real: true,
664            },
665            false,
666        )),
667        Value::Int(i) => Ok((
668            NumericArray {
669                data: vec![Complex64::new(i.to_f64(), 0.0)],
670                shape: vec![1, 1],
671                all_real: true,
672            },
673            false,
674        )),
675        Value::Bool(flag) => Ok((
676            NumericArray {
677                data: vec![Complex64::new(if flag { 1.0 } else { 0.0 }, 0.0)],
678                shape: vec![1, 1],
679                all_real: true,
680            },
681            false,
682        )),
683        Value::Complex(re, im) => Ok((
684            NumericArray {
685                data: vec![Complex64::new(re, im)],
686                shape: vec![1, 1],
687                all_real: im.abs() <= EPS,
688            },
689            false,
690        )),
691        other => Err(polyval_error(format!(
692            "polyval: X must be numeric, got {other:?}"
693        ))),
694    }
695}
696
697#[async_recursion::async_recursion(?Send)]
698async fn parse_mu(value: Value) -> BuiltinResult<Mu> {
699    match value {
700        Value::GpuTensor(handle) => {
701            let gathered = gpu_helpers::gather_tensor_async(&handle).await?;
702            parse_mu(Value::Tensor(gathered)).await
703        }
704        Value::Tensor(tensor) => {
705            if tensor.data.len() < 2 {
706                return Err(polyval_error(
707                    "polyval: mu must contain at least two elements",
708                ));
709            }
710            Mu::new(tensor.data[0], tensor.data[1])
711        }
712        Value::LogicalArray(array) => {
713            if array.data.len() < 2 {
714                return Err(polyval_error(
715                    "polyval: mu must contain at least two elements",
716                ));
717            }
718            let mean = if array.data[0] != 0 { 1.0 } else { 0.0 };
719            let scale = if array.data[1] != 0 { 1.0 } else { 0.0 };
720            Mu::new(mean, scale)
721        }
722        Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => Err(
723            polyval_error("polyval: mu must be a numeric vector with at least two values"),
724        ),
725        Value::ComplexTensor(tensor) => {
726            if tensor.data.len() < 2 {
727                return Err(polyval_error(
728                    "polyval: mu must contain at least two elements",
729                ));
730            }
731            let (mean_re, mean_im) = tensor.data[0];
732            let (scale_re, scale_im) = tensor.data[1];
733            if mean_im.abs() > EPS || scale_im.abs() > EPS {
734                return Err(polyval_error("polyval: mu values must be real"));
735            }
736            Mu::new(mean_re, scale_re)
737        }
738        _ => Err(polyval_error(
739            "polyval: mu must be a numeric vector with at least two values",
740        )),
741    }
742}
743
744#[async_recursion::async_recursion(?Send)]
745async fn parse_stats(value: Value, coeff_len: usize) -> BuiltinResult<Option<PolyfitStats>> {
746    if is_empty_value(&value).await? {
747        return Ok(None);
748    }
749    let struct_value = match value {
750        Value::Struct(s) => s,
751        Value::GpuTensor(handle) => {
752            let gathered = gpu_helpers::gather_value_async(&Value::GpuTensor(handle)).await?;
753            return parse_stats(gathered, coeff_len).await;
754        }
755        other => {
756            return Err(polyval_error(format!(
757                "polyval: S input must be the structure returned by polyfit, got {other:?}"
758            )))
759        }
760    };
761    let r_value = struct_value
762        .fields
763        .get("R")
764        .cloned()
765        .ok_or_else(|| polyval_error("polyval: S input is missing the field 'R'"))?;
766    let df_value = struct_value
767        .fields
768        .get("df")
769        .cloned()
770        .ok_or_else(|| polyval_error("polyval: S input is missing the field 'df'"))?;
771    let normr_value = struct_value
772        .fields
773        .get("normr")
774        .cloned()
775        .ok_or_else(|| polyval_error("polyval: S input is missing the field 'normr'"))?;
776
777    let (matrix, is_real) = convert_matrix(r_value, coeff_len).await?;
778    let df = scalar_to_f64(df_value, "polyval: S.df").await?;
779    let normr = scalar_to_f64(normr_value, "polyval: S.normr").await?;
780
781    Ok(Some(PolyfitStats {
782        r: matrix,
783        df,
784        normr,
785        is_real,
786    }))
787}
788
789#[async_recursion::async_recursion(?Send)]
790async fn convert_matrix(value: Value, coeff_len: usize) -> BuiltinResult<(Matrix, bool)> {
791    match value {
792        Value::GpuTensor(handle) => {
793            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
794            convert_matrix(Value::Tensor(tensor), coeff_len).await
795        }
796        Value::Tensor(tensor) => {
797            let Tensor {
798                data, rows, cols, ..
799            } = tensor;
800            if rows != coeff_len || cols != coeff_len {
801                return Err(polyval_error("polyval: size of S.R must match the coefficient vector"));
802            }
803            let data = data.into_iter().map(|re| Complex64::new(re, 0.0)).collect();
804            Ok((Matrix { rows, cols, data }, true))
805        }
806        Value::ComplexTensor(tensor) => {
807            let ComplexTensor {
808                data, rows, cols, ..
809            } = tensor;
810            if rows != coeff_len || cols != coeff_len {
811                return Err(polyval_error("polyval: size of S.R must match the coefficient vector"));
812            }
813            let imag_small = data.iter().all(|&(_, im)| im.abs() <= EPS);
814            let data = data
815                .into_iter()
816                .map(|(re, im)| Complex64::new(re, im))
817                .collect();
818            Ok((Matrix { rows, cols, data }, imag_small))
819        }
820        Value::LogicalArray(array) => {
821            let LogicalArray { data, shape } = array;
822            let rows = shape.first().copied().unwrap_or(0);
823            let cols = shape.get(1).copied().unwrap_or(0);
824            if rows != coeff_len || cols != coeff_len {
825                return Err(polyval_error("polyval: size of S.R must match the coefficient vector"));
826            }
827            let data = data
828                .into_iter()
829                .map(|bit| Complex64::new(if bit != 0 { 1.0 } else { 0.0 }, 0.0))
830                .collect();
831            Ok((Matrix { rows, cols, data }, true))
832        }
833        Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => Err(
834            polyval_error(
835                "polyval: S.R must be a square numeric matrix matching the coefficient vector length",
836            ),
837        ),
838        Value::Struct(_)
839        | Value::Cell(_)
840        | Value::String(_)
841        | Value::StringArray(_)
842        | Value::CharArray(_) => Err(
843            polyval_error(
844                "polyval: S.R must be a square numeric matrix matching the coefficient vector length",
845            ),
846        ),
847        _ => Err(
848            polyval_error(
849                "polyval: S.R must be a square numeric matrix matching the coefficient vector length",
850            ),
851        ),
852    }
853}
854
855#[async_recursion::async_recursion(?Send)]
856async fn scalar_to_f64(value: Value, context: &str) -> BuiltinResult<f64> {
857    match value {
858        Value::Num(n) => Ok(n),
859        Value::Int(i) => Ok(i.to_f64()),
860        Value::Bool(flag) => Ok(if flag { 1.0 } else { 0.0 }),
861        Value::Tensor(tensor) => {
862            if tensor.data.len() != 1 {
863                return Err(polyval_error(format!("{context} must be a scalar")));
864            }
865            Ok(tensor.data[0])
866        }
867        Value::LogicalArray(array) => {
868            if array.data.len() != 1 {
869                return Err(polyval_error(format!("{context} must be a scalar")));
870            }
871            Ok(if array.data[0] != 0 { 1.0 } else { 0.0 })
872        }
873        Value::GpuTensor(handle) => {
874            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
875            scalar_to_f64(Value::Tensor(tensor), context).await
876        }
877        Value::Complex(_, _) | Value::ComplexTensor(_) => {
878            Err(polyval_error(format!("{context} must be real-valued")))
879        }
880        other => Err(polyval_error(format!(
881            "{context} must be a scalar, got {other:?}"
882        ))),
883    }
884}
885
886fn apply_mu(values: &mut [Complex64], mu: Mu) -> BuiltinResult<()> {
887    let mean = Complex64::new(mu.mean, 0.0);
888    let scale = Complex64::new(mu.scale, 0.0);
889    for v in values.iter_mut() {
890        *v = (*v - mean) / scale;
891    }
892    Ok(())
893}
894
895fn evaluate_polynomial(coeffs: &[Complex64], inputs: &[Complex64]) -> Vec<Complex64> {
896    let mut outputs = Vec::with_capacity(inputs.len());
897    for &x in inputs {
898        let mut acc = Complex64::new(0.0, 0.0);
899        for &c in coeffs {
900            acc = acc * x + c;
901        }
902        outputs.push(acc);
903    }
904    outputs
905}
906
907fn compute_prediction_interval(
908    coeffs: &[Complex64],
909    inputs: &[Complex64],
910    stats: &PolyfitStats,
911) -> BuiltinResult<Vec<f64>> {
912    if !stats.is_effective() {
913        return Ok(vec![0.0; inputs.len()]);
914    }
915    let n = coeffs.len();
916    let mut delta = Vec::with_capacity(inputs.len());
917    for &x in inputs {
918        let row = vandermonde_row(x, n);
919        let solved = solve_row_against_upper(&row, &stats.r)?;
920        let sum_sq: f64 = solved.iter().map(|c| c.norm_sqr()).sum();
921        let interval = (1.0 + sum_sq).sqrt() * (stats.normr / stats.df.sqrt());
922        delta.push(interval);
923    }
924    Ok(delta)
925}
926
927fn vandermonde_row(x: Complex64, len: usize) -> Vec<Complex64> {
928    if len == 0 {
929        return vec![Complex64::new(1.0, 0.0)];
930    }
931    let degree = len - 1;
932    let mut powers = vec![Complex64::new(1.0, 0.0); degree + 1];
933    for idx in 1..=degree {
934        powers[idx] = powers[idx - 1] * x;
935    }
936    let mut row = vec![Complex64::new(0.0, 0.0); degree + 1];
937    for (i, value) in powers.into_iter().enumerate() {
938        row[degree - i] = value;
939    }
940    row
941}
942
943fn solve_row_against_upper(row: &[Complex64], matrix: &Matrix) -> BuiltinResult<Vec<Complex64>> {
944    let n = row.len();
945    if matrix.rows != n || matrix.cols != n {
946        return Err(polyval_error(
947            "polyval: size of S.R must match the coefficient vector",
948        ));
949    }
950    let mut result = vec![Complex64::new(0.0, 0.0); n];
951    for j in (0..n).rev() {
952        let mut acc = row[j];
953        for (k, value) in result.iter().enumerate().skip(j + 1) {
954            acc -= *value * matrix.get(k, j);
955        }
956        let diag = matrix.get(j, j);
957        if diag.norm() <= EPS {
958            return Err(polyval_error("polyval: S.R is singular"));
959        }
960        result[j] = acc / diag;
961    }
962    Ok(result)
963}
964
965fn finalize_values(
966    data: &[Complex64],
967    shape: &[usize],
968    prefer_gpu: bool,
969    real_only: bool,
970) -> BuiltinResult<Value> {
971    if real_only {
972        let real_data: Vec<f64> = data.iter().map(|c| c.re).collect();
973        finalize_real(real_data, shape, prefer_gpu)
974    } else if data.len() == 1 {
975        let value = data[0];
976        Ok(Value::Complex(value.re, value.im))
977    } else {
978        let complex_data: Vec<(f64, f64)> = data.iter().map(|c| (c.re, c.im)).collect();
979        let tensor = ComplexTensor::new(complex_data, shape.to_vec())
980            .map_err(|e| polyval_error(format!("polyval: failed to build complex tensor: {e}")))?;
981        Ok(Value::ComplexTensor(tensor))
982    }
983}
984
985fn finalize_delta(data: Vec<f64>, shape: &[usize], prefer_gpu: bool) -> BuiltinResult<Value> {
986    finalize_real(data, shape, prefer_gpu)
987}
988
989fn finalize_real(data: Vec<f64>, shape: &[usize], prefer_gpu: bool) -> BuiltinResult<Value> {
990    let tensor = Tensor::new(data, shape.to_vec())
991        .map_err(|e| polyval_error(format!("polyval: failed to build tensor: {e}")))?;
992    if prefer_gpu {
993        if let Some(provider) = runmat_accelerate_api::provider() {
994            let view = HostTensorView {
995                data: &tensor.data,
996                shape: &tensor.shape,
997            };
998            if let Ok(handle) = provider.upload(&view) {
999                return Ok(Value::GpuTensor(handle));
1000            }
1001        }
1002    }
1003    Ok(tensor::tensor_into_value(tensor))
1004}
1005
1006fn zeros_like(shape: &[usize], prefer_gpu: bool) -> BuiltinResult<Value> {
1007    let len = shape.iter().product();
1008    finalize_real(vec![0.0; len], shape, prefer_gpu)
1009}
1010
1011fn ensure_vector_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
1012    if !is_vector_shape(shape) {
1013        Err(polyval_error(format!(
1014            "{name}: coefficients must be a scalar, row vector, or column vector"
1015        )))
1016    } else {
1017        Ok(())
1018    }
1019}
1020
1021fn ensure_vector_data_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
1022    if !is_vector_shape(shape) {
1023        Err(polyval_error(format!(
1024            "{name}: inputs must be vectors or scalars"
1025        )))
1026    } else {
1027        Ok(())
1028    }
1029}
1030
/// Returns `true` when at most one dimension of `shape` is larger than 1.
///
/// Singleton and zero-length dimensions are ignored, so scalars, empty arrays, and
/// row/column vectors (including N-D arrays with a single non-singleton axis) qualify.
fn is_vector_shape(shape: &[usize]) -> bool {
    // A second dimension longer than 1 is what disqualifies the shape.
    shape.iter().filter(|&&extent| extent > 1).nth(1).is_none()
}
1034
1035#[async_recursion::async_recursion(?Send)]
1036async fn is_empty_value(value: &Value) -> BuiltinResult<bool> {
1037    match value {
1038        Value::Tensor(t) => Ok(t.data.is_empty()),
1039        Value::LogicalArray(l) => Ok(l.data.is_empty()),
1040        Value::Cell(ca) => Ok(ca.data.is_empty()),
1041        Value::GpuTensor(handle) => {
1042            let gathered =
1043                gpu_helpers::gather_value_async(&Value::GpuTensor(handle.clone())).await?;
1044            is_empty_value(&gathered).await
1045        }
1046        _ => Ok(false),
1047    }
1048}
1049
1050fn values_are_real(values: &[Complex64]) -> bool {
1051    values.iter().all(|c| c.im.abs() <= EPS)
1052}
1053
// Unit tests for the polyval builtin: descriptor metadata, plain evaluation, complex inputs,
// centering/scaling via `mu`, prediction intervals (`delta`), argument validation, and
// GPU-resident inputs.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::StructValue;

    /// Asserts that `err`'s message contains `needle`, printing the actual message on failure.
    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
        assert!(
            err.message().contains(needle),
            "expected error containing '{needle}', got '{}'",
            err.message()
        );
    }

    // The descriptor must advertise the documented call forms.
    #[test]
    fn polyval_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = POLYVAL_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert!(labels.contains(&"y = polyval(p, x)"));
        assert!(labels.contains(&"y = polyval(p, x, S)"));
        assert!(labels.contains(&"y = polyval(p, x, S, mu)"));
        assert!(labels.contains(&"[y, delta] = polyval(p, x, S)"));
    }

    // Error codes are part of the public contract and must stay stable.
    #[test]
    fn polyval_descriptor_errors_have_stable_codes() {
        let codes: Vec<&str> = POLYVAL_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert!(codes.contains(&"RM.POLYVAL.INVALID_ARGUMENT"));
        assert!(codes.contains(&"RM.POLYVAL.INVALID_INPUT"));
        assert!(codes.contains(&"RM.POLYVAL.INTERNAL"));
    }

    // p(x) = 2x^2 - 3x + 5 at x = 4 must come back as a plain scalar (25).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_scalar() {
        let coeffs = Tensor::new(vec![2.0, -3.0, 5.0], vec![1, 3]).unwrap();
        let value =
            polyval_builtin(Value::Tensor(coeffs), Value::Num(4.0), Vec::new()).expect("polyval");
        match value {
            Value::Num(n) => assert!((n - (2.0 * 16.0 - 12.0 + 5.0)).abs() < 1e-12),
            other => panic!("expected scalar, got {other:?}"),
        }
    }

    // p(x) = x^3 - 2x + 1 over a column of points; the output keeps the input shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_matrix_input() {
        let coeffs = Tensor::new(vec![1.0, 0.0, -2.0, 1.0], vec![1, 4]).unwrap();
        let points = Tensor::new(vec![-2.0, -1.0, 0.0, 1.0, 2.0], vec![5, 1]).unwrap();
        let value = polyval_builtin(
            Value::Tensor(coeffs),
            Value::Tensor(points.clone()),
            Vec::new(),
        )
        .expect("polyval");
        match value {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, points.shape);
                let expected = vec![-3.0, 2.0, 1.0, 0.0, 5.0];
                assert_eq!(tensor.data, expected);
            }
            other => panic!("expected tensor output, got {other:?}"),
        }
    }

    // Complex coefficients and points yield a complex tensor of the input shape
    // (only shape/length are checked here, not the values).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_complex_inputs() {
        let coeffs =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (0.0, 4.0)], vec![1, 3]).unwrap();
        let points =
            ComplexTensor::new(vec![(-1.0, 1.0), (0.0, 0.0), (1.0, -2.0)], vec![1, 3]).unwrap();
        let value = polyval_builtin(
            Value::ComplexTensor(coeffs),
            Value::ComplexTensor(points.clone()),
            Vec::new(),
        )
        .expect("polyval");
        match value {
            Value::ComplexTensor(tensor) => {
                assert_eq!(tensor.shape, points.shape);
                assert_eq!(tensor.data.len(), 3);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // p(x) = x^2 with mu = [1, 2] evaluates ((x - 1) / 2)^2; an empty tensor stands in for S.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_with_mu() {
        let coeffs = Tensor::new(vec![1.0, 0.0, 0.0], vec![1, 3]).unwrap();
        let points = Tensor::new(vec![0.0, 1.0, 2.0], vec![1, 3]).unwrap();
        let mu = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let value = polyval_builtin(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            vec![
                Value::Tensor(Tensor::new(vec![], vec![0, 0]).unwrap()),
                Value::Tensor(mu),
            ],
        )
        .expect("polyval");
        match value {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.data, vec![0.25, 0.0, 0.25]);
            }
            other => panic!("expected tensor output, got {other:?}"),
        }
    }

    // Requesting delta with an identity S.R produces a real delta of the input shape
    // (only shape/length are checked).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_delta_computation() {
        let coeffs = Tensor::new(vec![1.0, -3.0, 2.0], vec![1, 3]).unwrap();
        let points = Tensor::new(vec![0.0, 1.0, 2.0], vec![1, 3]).unwrap();
        let mut st = StructValue::new();
        let r = Tensor::new(
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            vec![3, 3],
        )
        .unwrap();
        st.fields.insert("R".to_string(), Value::Tensor(r));
        st.fields.insert("df".to_string(), Value::Num(4.0));
        st.fields.insert("normr".to_string(), Value::Num(2.0));
        let stats = Value::Struct(st);
        let eval = futures::executor::block_on(evaluate(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            &[stats],
            true,
        ))
        .expect("polyval");
        let (_, delta) = eval.into_pair().expect("delta available");
        match delta {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 3]);
                assert_eq!(tensor.data.len(), 3);
            }
            other => panic!("expected tensor delta, got {other:?}"),
        }
    }

    // Asking for delta without any S argument must fail with an error mentioning "S input".
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_delta_requires_stats() {
        let coeffs = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let points = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = futures::executor::block_on(evaluate(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            &[],
            true,
        ))
        .expect_err("expected error");
        assert_error_contains(err, "S input");
    }

    // A one-element mu is rejected; mu needs both a mean and a scale.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_invalid_mu_length_errors() {
        let coeffs = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let points = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let mu = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let placeholder = Tensor::new(vec![], vec![0, 0]).unwrap();
        let err = polyval_builtin(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            vec![Value::Tensor(placeholder), Value::Tensor(mu)],
        )
        .expect_err("expected mu length error");
        assert_error_contains(err, "mu must contain at least two elements");
    }

    // Three optional arguments (beyond S and mu) must be rejected with the
    // INVALID_ARGUMENT identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_rejects_excess_optional_arguments() {
        let coeffs = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let points = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let err = polyval_builtin(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            vec![Value::Num(1.0), Value::Num(2.0), Value::Num(3.0)],
        )
        .expect_err("expected too many arguments error");
        assert_eq!(err.identifier(), POLYVAL_ERROR_INVALID_ARGUMENT.identifier);
        assert_error_contains(err, "too many input arguments");
    }

    // mu with a non-zero imaginary component is rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_complex_mu_rejected() {
        let coeffs = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let points = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let complex_mu =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 0.5)], vec![1, 2]).expect("complex mu");
        let placeholder = Tensor::new(vec![], vec![0, 0]).unwrap();
        let err = polyval_builtin(
            Value::Tensor(coeffs),
            Value::Tensor(points),
            vec![Value::Tensor(placeholder), Value::ComplexTensor(complex_mu)],
        )
        .expect_err("expected complex mu error");
        assert_error_contains(err, "mu values must be real");
    }

    // An S struct without the R field is reported by name.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_invalid_stats_missing_r() {
        let coeffs = Tensor::new(vec![1.0, -3.0, 2.0], vec![1, 3]).unwrap();
        let points = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let mut st = StructValue::new();
        st.fields.insert("df".to_string(), Value::Num(1.0));
        st.fields.insert("normr".to_string(), Value::Num(1.0));
        let stats = Value::Struct(st);
        let err = polyval_builtin(Value::Tensor(coeffs), Value::Tensor(points), vec![stats])
            .expect_err("expected missing R error");
        assert_error_contains(err, "missing the field 'R'");
    }

    // With GPU-resident inputs on the test provider, p(x) = x^2 + 1 stays on the device
    // and gathers to [2, 1, 2].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyval_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let coeffs = Tensor::new(vec![1.0, 0.0, 1.0], vec![1, 3]).unwrap();
            let points = Tensor::new(vec![-1.0, 0.0, 1.0], vec![3, 1]).unwrap();
            let coeff_handle = provider
                .upload(&HostTensorView {
                    data: &coeffs.data,
                    shape: &coeffs.shape,
                })
                .expect("upload coeff");
            let point_handle = provider
                .upload(&HostTensorView {
                    data: &points.data,
                    shape: &points.shape,
                })
                .expect("upload points");
            let value = polyval_builtin(
                Value::GpuTensor(coeff_handle),
                Value::GpuTensor(point_handle),
                Vec::new(),
            )
            .expect("polyval");
            match value {
                Value::GpuTensor(handle) => {
                    let gathered = test_support::gather(Value::GpuTensor(handle)).expect("gather");
                    assert_eq!(gathered.data, vec![2.0, 1.0, 2.0]);
                }
                other => panic!("expected gpu tensor, got {other:?}"),
            }
        });
    }

    // With the wgpu provider registered, GPU evaluation must match the host Horner
    // reference (`evaluate_polynomial`) exactly for real inputs.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn polyval_wgpu_matches_cpu_real_inputs() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let coeffs = Tensor::new(vec![1.0, -3.0, 2.0], vec![1, 3]).unwrap();
        let points = Tensor::new(vec![-2.0, -1.0, 0.5, 2.5], vec![4, 1]).unwrap();

        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let coeff_handle = provider
            .upload(&HostTensorView {
                data: &coeffs.data,
                shape: &coeffs.shape,
            })
            .expect("upload coeffs");
        let point_handle = provider
            .upload(&HostTensorView {
                data: &points.data,
                shape: &points.shape,
            })
            .expect("upload points");

        let gpu_value = polyval_builtin(
            Value::GpuTensor(coeff_handle.clone()),
            Value::GpuTensor(point_handle.clone()),
            Vec::new(),
        )
        .expect("polyval gpu");

        // Inputs are no longer needed once the result has been produced.
        let _ = provider.free(&coeff_handle);
        let _ = provider.free(&point_handle);

        let gathered = test_support::gather(gpu_value).expect("gather");

        let coeff_complex: Vec<Complex64> = coeffs
            .data
            .iter()
            .map(|&c| Complex64::new(c, 0.0))
            .collect();
        let point_complex: Vec<Complex64> = points
            .data
            .iter()
            .map(|&x| Complex64::new(x, 0.0))
            .collect();
        let expected_vals = evaluate_polynomial(&coeff_complex, &point_complex);
        let expected: Vec<f64> = expected_vals.iter().map(|c| c.re).collect();

        assert_eq!(gathered.shape, vec![4, 1]);
        assert_eq!(gathered.data, expected);
    }

    // Synchronous wrapper around the async builtin. This local item takes precedence over
    // the glob-imported `super::polyval_builtin`, so the tests above call it without `.await`.
    fn polyval_builtin(p: Value, x: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::polyval_builtin(p, x, rest))
    }
}