Skip to main content

runmat_runtime/builtins/math/poly/
polyint.rs

1//! MATLAB-compatible `polyint` builtin with GPU-aware semantics for RunMat.
2
3use log::{trace, warn};
4use num_complex::Complex64;
5use runmat_accelerate_api::HostTensorView;
6use runmat_builtins::{
7    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9    ComplexTensor, Tensor, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::gpu_helpers;
14use crate::builtins::common::spec::{
15    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
16    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
17};
18use crate::builtins::common::tensor;
19use crate::builtins::math::poly::type_resolvers::polyint_type;
20use crate::dispatcher;
21use crate::{build_runtime_error, BuiltinResult, RuntimeError};
22
/// Absolute tolerance below which an imaginary part is treated as zero: it decides whether a
/// constant of integration may go to the GPU hook and whether a result is returned as real.
const EPS: f64 = 1e-12;
/// Builtin name attached to every error built by `polyint_error_with`.
const BUILTIN_NAME: &str = "polyint";
25
26const POLYINT_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
27    name: "q",
28    ty: BuiltinParamType::Any,
29    arity: BuiltinParamArity::Required,
30    default: None,
31    description: "Integrated polynomial coefficient vector.",
32}];
33
34const POLYINT_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
35    name: "p",
36    ty: BuiltinParamType::Any,
37    arity: BuiltinParamArity::Required,
38    default: None,
39    description: "Polynomial coefficient vector.",
40}];
41
42const POLYINT_INPUTS_WITH_K: [BuiltinParamDescriptor; 2] = [
43    BuiltinParamDescriptor {
44        name: "p",
45        ty: BuiltinParamType::Any,
46        arity: BuiltinParamArity::Required,
47        default: None,
48        description: "Polynomial coefficient vector.",
49    },
50    BuiltinParamDescriptor {
51        name: "k",
52        ty: BuiltinParamType::Any,
53        arity: BuiltinParamArity::Optional,
54        default: None,
55        description: "Constant of integration.",
56    },
57];
58
59const POLYINT_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
60    BuiltinSignatureDescriptor {
61        label: "q = polyint(p)",
62        inputs: &POLYINT_INPUTS,
63        outputs: &POLYINT_OUTPUT,
64    },
65    BuiltinSignatureDescriptor {
66        label: "q = polyint(p, k)",
67        inputs: &POLYINT_INPUTS_WITH_K,
68        outputs: &POLYINT_OUTPUT,
69    },
70];
71
72const POLYINT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
73    code: "RM.POLYINT.INVALID_ARGUMENT",
74    identifier: Some("RunMat:polyint:InvalidArgument"),
75    when: "Input arity or integration-constant argument is malformed.",
76    message: "polyint: invalid argument",
77};
78
79const POLYINT_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
80    code: "RM.POLYINT.INVALID_INPUT",
81    identifier: Some("RunMat:polyint:InvalidInput"),
82    when: "Input polynomial cannot be interpreted as a numeric coefficient vector.",
83    message: "polyint: invalid input",
84};
85
86const POLYINT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
87    code: "RM.POLYINT.INTERNAL",
88    identifier: Some("RunMat:polyint:Internal"),
89    when: "Runtime fails while building output tensors or provider fallback paths.",
90    message: "polyint: internal runtime failure",
91};
92
93const POLYINT_ERRORS: [BuiltinErrorDescriptor; 3] = [
94    POLYINT_ERROR_INVALID_ARGUMENT,
95    POLYINT_ERROR_INVALID_INPUT,
96    POLYINT_ERROR_INTERNAL,
97];
98
99pub const POLYINT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
100    signatures: &POLYINT_SIGNATURES,
101    output_mode: BuiltinOutputMode::Fixed,
102    completion_policy: BuiltinCompletionPolicy::Public,
103    errors: &POLYINT_ERRORS,
104};
105
/// GPU capability metadata for `polyint`, registered through `register_gpu_spec`.
///
/// Device execution goes through the provider's custom `polyint` hook (invoked from
/// `try_polyint_gpu`). Only F32/F64 precisions are listed, and broadcasting is not supported.
/// NOTE(review): `ResidencyPolicy::NewHandle` presumably means the hook returns a fresh
/// device handle rather than writing in place — confirm against the spec type.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::polyint")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "polyint",
    op_kind: GpuOpKind::Custom("polynomial-integral"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("polyint")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers implement the polyint hook for real and complex-interleaved coefficient vectors; complex integration constants fall back to host integration and re-upload.",
};
121
122fn polyint_error(message: impl Into<String>) -> RuntimeError {
123    polyint_error_with(message, &POLYINT_ERROR_INVALID_INPUT)
124}
125
126fn polyint_argument_error(message: impl Into<String>) -> RuntimeError {
127    polyint_error_with(message, &POLYINT_ERROR_INVALID_ARGUMENT)
128}
129
130fn polyint_error_with(
131    message: impl Into<String>,
132    error: &'static BuiltinErrorDescriptor,
133) -> RuntimeError {
134    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
135    if let Some(identifier) = error.identifier {
136        builder = builder.with_identifier(identifier);
137    }
138    builder.build()
139}
140
/// Fusion metadata for `polyint`, registered through `register_fusion_spec`.
///
/// The spec declares neither an elementwise nor a reduction template (both `None`), so
/// `polyint` offers nothing to fuse; it rewrites a whole coefficient vector at once.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::polyint")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "polyint",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Symbolic operation on coefficient vectors; fusion does not apply.",
};
151
152#[runtime_builtin(
153    name = "polyint",
154    category = "math/poly",
155    summary = "Integrate polynomial coefficient vectors and append a constant of integration.",
156    keywords = "polyint,polynomial,integral,antiderivative",
157    type_resolver(polyint_type),
158    descriptor(crate::builtins::math::poly::polyint::POLYINT_DESCRIPTOR),
159    builtin_path = "crate::builtins::math::poly::polyint"
160)]
161async fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
162    if rest.len() > 1 {
163        return Err(polyint_argument_error("polyint: too many input arguments"));
164    }
165
166    let constant = match rest.into_iter().next() {
167        Some(value) => parse_constant(value).await?,
168        None => Complex64::new(0.0, 0.0),
169    };
170
171    if let Value::GpuTensor(handle) = &coeffs {
172        if let Some(device_result) = try_polyint_gpu(handle, constant)? {
173            return Ok(Value::GpuTensor(device_result));
174        }
175    }
176
177    let source_gpu = match &coeffs {
178        Value::GpuTensor(handle) => Some(handle.clone()),
179        _ => None,
180    };
181    polyint_host_value(coeffs, constant, source_gpu).await
182}
183
184async fn polyint_host_value(
185    coeffs: Value,
186    constant: Complex64,
187    source_gpu: Option<runmat_accelerate_api::GpuTensorHandle>,
188) -> BuiltinResult<Value> {
189    let polynomial = parse_polynomial(coeffs).await?;
190    let mut integrated = integrate_coeffs(&polynomial.coeffs);
191    if integrated.is_empty() {
192        integrated.push(constant);
193    } else if let Some(last) = integrated.last_mut() {
194        *last += constant;
195    }
196    let value = coeffs_to_value(&integrated, polynomial.orientation)?;
197    maybe_return_gpu(value, source_gpu.as_ref())
198}
199
200fn try_polyint_gpu(
201    handle: &runmat_accelerate_api::GpuTensorHandle,
202    constant: Complex64,
203) -> BuiltinResult<Option<runmat_accelerate_api::GpuTensorHandle>> {
204    if constant.im.abs() > EPS {
205        return Ok(None);
206    }
207    ensure_vector_shape(&handle.shape)?;
208    let Some(provider) =
209        runmat_accelerate_api::provider_for_handle(handle).or_else(runmat_accelerate_api::provider)
210    else {
211        return Ok(None);
212    };
213    match provider.polyint(handle, constant.re) {
214        Ok(result) => Ok(Some(result)),
215        Err(err) => {
216            trace!("polyint: provider hook unavailable, falling back to host: {err}");
217            Ok(None)
218        }
219    }
220}
221
222fn integrate_coeffs(coeffs: &[Complex64]) -> Vec<Complex64> {
223    if coeffs.is_empty() {
224        return Vec::new();
225    }
226    let mut result = Vec::with_capacity(coeffs.len() + 1);
227    for (idx, coeff) in coeffs.iter().enumerate() {
228        let power = (coeffs.len() - idx) as f64;
229        if power <= 0.0 {
230            result.push(Complex64::new(0.0, 0.0));
231        } else {
232            result.push(*coeff / Complex64::new(power, 0.0));
233        }
234    }
235    result.push(Complex64::new(0.0, 0.0));
236    result
237}
238
239fn maybe_return_gpu(
240    value: Value,
241    source_gpu: Option<&runmat_accelerate_api::GpuTensorHandle>,
242) -> BuiltinResult<Value> {
243    let Some(source_gpu) = source_gpu else {
244        return Ok(value);
245    };
246    let provider = runmat_accelerate_api::provider_for_handle(source_gpu);
247    match value {
248        Value::Tensor(tensor) => {
249            if let Some(provider) = provider {
250                let view = HostTensorView {
251                    data: &tensor.data,
252                    shape: &tensor.shape,
253                };
254                match provider.upload(&view) {
255                    Ok(handle) => return Ok(Value::GpuTensor(handle)),
256                    Err(err) => {
257                        warn!("polyint: provider upload failed, keeping result on host: {err}");
258                    }
259                }
260            } else {
261                trace!("polyint: no provider available to re-upload result");
262            }
263            Ok(Value::Tensor(tensor))
264        }
265        Value::ComplexTensor(tensor) => {
266            if let Some(provider) = provider {
267                match gpu_helpers::upload_complex_tensor(provider, &tensor) {
268                    Ok(handle) => return Ok(gpu_helpers::complex_gpu_value(handle)),
269                    Err(err) => {
270                        warn!(
271                            "polyint: provider complex upload failed, keeping result on host: {err}"
272                        );
273                    }
274                }
275            } else {
276                trace!("polyint: no provider available to re-upload complex result");
277            }
278            Ok(Value::ComplexTensor(tensor))
279        }
280        other => Ok(other),
281    }
282}
283
284fn coeffs_to_value(coeffs: &[Complex64], orientation: Orientation) -> BuiltinResult<Value> {
285    if coeffs.iter().all(|c| c.im.abs() <= EPS) {
286        let data: Vec<f64> = coeffs.iter().map(|c| c.re).collect();
287        let shape = orientation.shape_for_len(data.len());
288        let tensor =
289            Tensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
290        Ok(tensor::tensor_into_value(tensor))
291    } else {
292        let data: Vec<(f64, f64)> = coeffs.iter().map(|c| (c.re, c.im)).collect();
293        let shape = orientation.shape_for_len(data.len());
294        let tensor =
295            ComplexTensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
296        Ok(Value::ComplexTensor(tensor))
297    }
298}
299
300async fn parse_polynomial(value: Value) -> BuiltinResult<Polynomial> {
301    let gathered = dispatcher::gather_if_needed_async(&value).await?;
302    match gathered {
303        Value::Tensor(tensor) => parse_tensor_coeffs(&tensor),
304        Value::ComplexTensor(tensor) => parse_complex_tensor_coeffs(&tensor),
305        Value::LogicalArray(logical) => {
306            let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
307            parse_tensor_coeffs(&tensor)
308        }
309        Value::Num(n) => Ok(Polynomial {
310            coeffs: vec![Complex64::new(n, 0.0)],
311            orientation: Orientation::Scalar,
312        }),
313        Value::Int(i) => Ok(Polynomial {
314            coeffs: vec![Complex64::new(i.to_f64(), 0.0)],
315            orientation: Orientation::Scalar,
316        }),
317        Value::Bool(b) => Ok(Polynomial {
318            coeffs: vec![Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)],
319            orientation: Orientation::Scalar,
320        }),
321        Value::Complex(re, im) => Ok(Polynomial {
322            coeffs: vec![Complex64::new(re, im)],
323            orientation: Orientation::Scalar,
324        }),
325        other => Err(polyint_error(format!(
326            "polyint: expected a numeric coefficient vector, got {:?}",
327            other
328        ))),
329    }
330}
331
332fn parse_tensor_coeffs(tensor: &Tensor) -> BuiltinResult<Polynomial> {
333    ensure_vector_shape(&tensor.shape)?;
334    let orientation = orientation_from_shape(&tensor.shape);
335    Ok(Polynomial {
336        coeffs: tensor
337            .data
338            .iter()
339            .map(|&v| Complex64::new(v, 0.0))
340            .collect(),
341        orientation,
342    })
343}
344
345fn parse_complex_tensor_coeffs(tensor: &ComplexTensor) -> BuiltinResult<Polynomial> {
346    ensure_vector_shape(&tensor.shape)?;
347    let orientation = orientation_from_shape(&tensor.shape);
348    Ok(Polynomial {
349        coeffs: tensor
350            .data
351            .iter()
352            .map(|&(re, im)| Complex64::new(re, im))
353            .collect(),
354        orientation,
355    })
356}
357
358async fn parse_constant(value: Value) -> BuiltinResult<Complex64> {
359    let gathered = dispatcher::gather_if_needed_async(&value).await?;
360    match gathered {
361        Value::Tensor(tensor) => {
362            if tensor.data.len() != 1 {
363                return Err(polyint_error(
364                    "polyint: constant of integration must be a scalar",
365                ));
366            }
367            Ok(Complex64::new(tensor.data[0], 0.0))
368        }
369        Value::ComplexTensor(tensor) => {
370            if tensor.data.len() != 1 {
371                return Err(polyint_error(
372                    "polyint: constant of integration must be a scalar",
373                ));
374            }
375            let (re, im) = tensor.data[0];
376            Ok(Complex64::new(re, im))
377        }
378        Value::Num(n) => Ok(Complex64::new(n, 0.0)),
379        Value::Int(i) => Ok(Complex64::new(i.to_f64(), 0.0)),
380        Value::Bool(b) => Ok(Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)),
381        Value::Complex(re, im) => Ok(Complex64::new(re, im)),
382        Value::LogicalArray(logical) => {
383            let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
384            if tensor.data.len() != 1 {
385                return Err(polyint_error(
386                    "polyint: constant of integration must be a scalar",
387                ));
388            }
389            Ok(Complex64::new(tensor.data[0], 0.0))
390        }
391        other => Err(polyint_error(format!(
392            "polyint: constant of integration must be numeric, got {:?}",
393            other
394        ))),
395    }
396}
397
398fn ensure_vector_shape(shape: &[usize]) -> BuiltinResult<()> {
399    let non_unit = shape.iter().filter(|&&dim| dim > 1).count();
400    if non_unit <= 1 {
401        Ok(())
402    } else {
403        Err(polyint_error("polyint: coefficients must form a vector"))
404    }
405}
406
407fn orientation_from_shape(shape: &[usize]) -> Orientation {
408    for (idx, &dim) in shape.iter().enumerate() {
409        if dim != 1 {
410            return match idx {
411                0 => Orientation::Column,
412                1 => Orientation::Row,
413                _ => Orientation::Column,
414            };
415        }
416    }
417    Orientation::Scalar
418}
419
420#[derive(Clone)]
421struct Polynomial {
422    coeffs: Vec<Complex64>,
423    orientation: Orientation,
424}
425
/// Layout of a coefficient vector, preserved from input to output.
///
/// Derives `Debug`, `PartialEq`, and `Eq` so orientations can be compared and printed (for
/// example in assertions).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Orientation {
    /// Single-element input; multi-element results are laid out as rows.
    Scalar,
    /// `1 x N` vector.
    Row,
    /// `N x 1` vector.
    Column,
}

impl Orientation {
    /// Output shape for a vector of `len` elements in this orientation.
    ///
    /// Lengths of 0 or 1 always map to `[1, 1]`.
    fn shape_for_len(self, len: usize) -> Vec<usize> {
        match self {
            _ if len <= 1 => vec![1, 1],
            Orientation::Scalar | Orientation::Row => vec![1, len],
            Orientation::Column => vec![len, 1],
        }
    }
}
444
445#[cfg(test)]
446pub(crate) mod tests {
447    use super::*;
448    use crate::builtins::common::gpu_helpers;
449    use crate::builtins::common::test_support;
450    use futures::executor::block_on;
451    #[cfg(feature = "wgpu")]
452    use runmat_accelerate_api::AccelProvider;
453    use runmat_builtins::LogicalArray;
454
455    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
456        assert!(
457            err.message().contains(needle),
458            "expected error containing '{needle}', got '{}'",
459            err.message()
460        );
461    }
462
463    #[test]
464    fn polyint_descriptor_signatures_cover_core_forms() {
465        let labels: Vec<&str> = POLYINT_DESCRIPTOR
466            .signatures
467            .iter()
468            .map(|signature| signature.label)
469            .collect();
470        assert!(labels.contains(&"q = polyint(p)"));
471        assert!(labels.contains(&"q = polyint(p, k)"));
472    }
473
474    #[test]
475    fn polyint_descriptor_errors_have_stable_codes() {
476        let codes: Vec<&str> = POLYINT_DESCRIPTOR
477            .errors
478            .iter()
479            .map(|error| error.code)
480            .collect();
481        assert!(codes.contains(&"RM.POLYINT.INVALID_ARGUMENT"));
482        assert!(codes.contains(&"RM.POLYINT.INVALID_INPUT"));
483        assert!(codes.contains(&"RM.POLYINT.INTERNAL"));
484    }
485
486    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
487    #[test]
488    fn integrates_polynomial_without_constant() {
489        let tensor = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
490        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
491        match result {
492            Value::Tensor(t) => {
493                assert_eq!(t.shape, vec![1, 5]);
494                let expected = [0.75, -2.0 / 3.0, 2.5, 7.0, 0.0];
495                assert!(t
496                    .data
497                    .iter()
498                    .zip(expected.iter())
499                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
500            }
501            other => panic!("expected tensor result, got {other:?}"),
502        }
503    }
504
505    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
506    #[test]
507    fn integrates_with_constant() {
508        let tensor = Tensor::new(vec![4.0, 0.0, -8.0], vec![1, 3]).unwrap();
509        let args = vec![Value::Num(3.0)];
510        let result = polyint_builtin(Value::Tensor(tensor), args).expect("polyint");
511        match result {
512            Value::Tensor(t) => {
513                assert_eq!(t.shape, vec![1, 4]);
514                let expected = [4.0 / 3.0, 0.0, -8.0, 3.0];
515                assert!(t
516                    .data
517                    .iter()
518                    .zip(expected.iter())
519                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
520            }
521            other => panic!("expected tensor result, got {other:?}"),
522        }
523    }
524
525    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
526    #[test]
527    fn integrates_scalar_value() {
528        let result = polyint_builtin(Value::Num(5.0), Vec::new()).expect("polyint");
529        match result {
530            Value::Tensor(t) => {
531                assert_eq!(t.shape, vec![1, 2]);
532                assert!((t.data[0] - 5.0).abs() < 1e-12);
533                assert!(t.data[1].abs() < 1e-12);
534            }
535            other => panic!("expected tensor result, got {other:?}"),
536        }
537    }
538
539    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
540    #[test]
541    fn integrates_logical_coefficients() {
542        let logical = LogicalArray::new(vec![1, 0, 1], vec![1, 3]).unwrap();
543        let result =
544            polyint_builtin(Value::LogicalArray(logical), Vec::new()).expect("polyint logical");
545        match result {
546            Value::Tensor(t) => {
547                assert_eq!(t.shape, vec![1, 4]);
548                let expected = [1.0 / 3.0, 0.0, 1.0, 0.0];
549                assert!(t
550                    .data
551                    .iter()
552                    .zip(expected.iter())
553                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
554            }
555            other => panic!("expected tensor result, got {other:?}"),
556        }
557    }
558
559    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
560    #[test]
561    fn preserves_column_vector_orientation() {
562        let tensor = Tensor::new(vec![2.0, 0.0, -6.0], vec![3, 1]).unwrap();
563        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
564        match result {
565            Value::Tensor(t) => {
566                assert_eq!(t.shape, vec![4, 1]);
567                let expected = [2.0 / 3.0, 0.0, -6.0, 0.0];
568                assert!(t
569                    .data
570                    .iter()
571                    .zip(expected.iter())
572                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
573            }
574            other => panic!("expected column tensor, got {other:?}"),
575        }
576    }
577
578    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
579    #[test]
580    fn integrates_complex_coefficients() {
581        let tensor =
582            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (0.0, 4.0)], vec![1, 3]).unwrap();
583        let args = vec![Value::Complex(0.0, -1.0)];
584        let result = polyint_builtin(Value::ComplexTensor(tensor), args).expect("polyint");
585        match result {
586            Value::ComplexTensor(t) => {
587                assert_eq!(t.shape, vec![1, 4]);
588                let expected = [(1.0 / 3.0, 2.0 / 3.0), (-1.5, 0.0), (0.0, 4.0), (0.0, -1.0)];
589                assert!(t
590                    .data
591                    .iter()
592                    .zip(expected.iter())
593                    .all(|((lre, lim), (rre, rim))| {
594                        (lre - rre).abs() < 1e-12 && (lim - rim).abs() < 1e-12
595                    }));
596            }
597            other => panic!("expected complex tensor, got {other:?}"),
598        }
599    }
600
601    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
602    #[test]
603    fn rejects_matrix_coefficients() {
604        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
605        let err = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect_err("expected error");
606        assert_error_contains(err, "vector");
607    }
608
609    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
610    #[test]
611    fn rejects_non_scalar_constant() {
612        let coeffs = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
613        let constant = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
614        let err = polyint_builtin(Value::Tensor(coeffs), vec![Value::Tensor(constant)])
615            .expect_err("expected error");
616        assert_error_contains(err, "constant of integration must be a scalar");
617    }
618
619    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
620    #[test]
621    fn rejects_excess_arguments() {
622        let tensor = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
623        let err = polyint_builtin(
624            Value::Tensor(tensor),
625            vec![Value::Num(1.0), Value::Num(2.0)],
626        )
627        .expect_err("expected error");
628        assert_eq!(err.identifier(), POLYINT_ERROR_INVALID_ARGUMENT.identifier);
629        assert_error_contains(err, "too many input arguments");
630    }
631
632    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
633    #[test]
634    fn handles_empty_input_as_zero_polynomial() {
635        let tensor = Tensor::new(vec![], vec![1, 0]).unwrap();
636        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
637        match result {
638            Value::Num(v) => assert!(v.abs() < 1e-12),
639            Value::Tensor(t) => {
640                // Allow tensor fallback if scalar auto-boxing changes in future
641                assert_eq!(t.data.len(), 1);
642                assert!(t.data[0].abs() < 1e-12);
643            }
644            other => panic!("expected numeric result, got {other:?}"),
645        }
646    }
647
648    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
649    #[test]
650    fn empty_input_with_constant() {
651        let tensor = Tensor::new(vec![], vec![1, 0]).unwrap();
652        let result = polyint_builtin(Value::Tensor(tensor), vec![Value::Complex(1.5, -2.0)])
653            .expect("polyint");
654        match result {
655            Value::ComplexTensor(t) => {
656                assert_eq!(t.shape, vec![1, 1]);
657                assert_eq!(t.data.len(), 1);
658                let (re, im) = t.data[0];
659                assert!((re - 1.5).abs() < 1e-12);
660                assert!((im + 2.0).abs() < 1e-12);
661            }
662            other => panic!("expected complex tensor result, got {other:?}"),
663        }
664    }
665
666    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
667    #[test]
668    fn polyint_gpu_roundtrip() {
669        test_support::with_test_provider(|provider| {
670            let tensor = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
671            let view = HostTensorView {
672                data: &tensor.data,
673                shape: &tensor.shape,
674            };
675            let handle = provider.upload(&view).expect("upload");
676            let result = polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint");
677            match result {
678                Value::GpuTensor(handle) => {
679                    let gathered = test_support::gather(Value::GpuTensor(handle)).expect("gather");
680                    assert_eq!(gathered.shape, vec![1, 4]);
681                    let expected = [1.0 / 3.0, -2.0, 6.0, 0.0];
682                    assert!(gathered
683                        .data
684                        .iter()
685                        .zip(expected.iter())
686                        .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
687                }
688                other => panic!("expected GPU tensor result, got {other:?}"),
689            }
690        });
691    }
692
693    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
694    #[test]
695    fn polyint_gpu_complex_constant_reuploads_complex_result() {
696        test_support::with_test_provider(|provider| {
697            let tensor = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
698            let view = HostTensorView {
699                data: &tensor.data,
700                shape: &tensor.shape,
701            };
702            let handle = provider.upload(&view).expect("upload");
703            let result = polyint_builtin(Value::GpuTensor(handle), vec![Value::Complex(0.0, 2.0)])
704                .expect("polyint");
705            match result {
706                Value::GpuTensor(handle) => {
707                    assert_eq!(
708                        runmat_accelerate_api::handle_storage(&handle),
709                        runmat_accelerate_api::GpuTensorStorage::ComplexInterleaved
710                    );
711                    let gathered =
712                        block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle)))
713                            .expect("gather");
714                    let Value::ComplexTensor(ct) = gathered else {
715                        panic!("expected complex tensor");
716                    };
717                    assert_eq!(ct.shape, vec![1, 3]);
718                    let expected = [(0.5, 0.0), (0.0, 0.0), (0.0, 2.0)];
719                    assert!(ct
720                        .data
721                        .iter()
722                        .zip(expected.iter())
723                        .all(|((lre, lim), (rre, rim))| {
724                            (lre - rre).abs() < 1e-12 && (lim - rim).abs() < 1e-12
725                        }));
726                }
727                other => panic!("expected complex gpu tensor, got {other:?}"),
728            }
729        });
730    }
731
732    #[test]
733    fn polyint_complex_gpu_coefficients_stay_resident() {
734        test_support::with_test_provider(|provider| {
735            let coeffs = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
736            let handle = gpu_helpers::upload_complex_tensor(provider, &coeffs).expect("upload");
737            let result =
738                polyint_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("polyint");
739            let Value::GpuTensor(handle) = result else {
740                panic!("expected complex gpu tensor");
741            };
742            assert_eq!(
743                runmat_accelerate_api::handle_storage(&handle),
744                runmat_accelerate_api::GpuTensorStorage::ComplexInterleaved
745            );
746            let gathered = block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle)))
747                .expect("gather");
748            let Value::ComplexTensor(ct) = gathered else {
749                panic!("expected complex tensor");
750            };
751            assert_eq!(ct.shape, vec![1, 3]);
752            let expected = [(0.5, 0.5), (2.0, -1.0), (2.0, 0.0)];
753            assert!(ct
754                .data
755                .iter()
756                .zip(expected.iter())
757                .all(|((lre, lim), (rre, rim))| {
758                    (lre - rre).abs() < 1e-12 && (lim - rim).abs() < 1e-12
759                }));
760        });
761    }
762
763    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
764    #[test]
765    fn polyint_gpu_with_gpu_constant() {
766        test_support::with_test_provider(|provider| {
767            let coeffs = Tensor::new(vec![2.0, 0.0], vec![1, 2]).unwrap();
768            let coeff_view = HostTensorView {
769                data: &coeffs.data,
770                shape: &coeffs.shape,
771            };
772            let coeff_handle = provider.upload(&coeff_view).expect("upload coeffs");
773            let constant = Tensor::new(vec![3.0], vec![1, 1]).unwrap();
774            let constant_view = HostTensorView {
775                data: &constant.data,
776                shape: &constant.shape,
777            };
778            let constant_handle = provider.upload(&constant_view).expect("upload constant");
779            let result = polyint_builtin(
780                Value::GpuTensor(coeff_handle),
781                vec![Value::GpuTensor(constant_handle)],
782            )
783            .expect("polyint");
784            match result {
785                Value::GpuTensor(handle) => {
786                    let gathered =
787                        test_support::gather(Value::GpuTensor(handle)).expect("gather result");
788                    assert_eq!(gathered.shape, vec![1, 3]);
789                    let expected = [1.0, 0.0, 3.0];
790                    assert!(gathered
791                        .data
792                        .iter()
793                        .zip(expected.iter())
794                        .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
795                }
796                other => panic!("expected gpu tensor result, got {other:?}"),
797            }
798        });
799    }
800
801    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
802    #[test]
803    #[cfg(feature = "wgpu")]
804    fn polyint_wgpu_matches_cpu() {
805        let _guard = test_support::accel_test_lock();
806        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
807            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
808        ) else {
809            return;
810        };
811        let tensor = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
812        let view = HostTensorView {
813            data: &tensor.data,
814            shape: &tensor.shape,
815        };
816        let handle = provider.upload(&view).expect("upload");
817        let gpu_value = polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint gpu");
818        let gathered = test_support::gather(gpu_value).expect("gather");
819        let cpu_value =
820            polyint_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("polyint cpu");
821        let expected = match cpu_value {
822            Value::Tensor(t) => t,
823            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
824            other => panic!("unexpected cpu result {other:?}"),
825        };
826        assert_eq!(gathered.shape, expected.shape);
827        let tol = match provider.precision() {
828            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
829            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
830        };
831        gathered
832            .data
833            .iter()
834            .zip(expected.data.iter())
835            .for_each(|(lhs, rhs)| assert!((lhs - rhs).abs() < tol));
836    }
837
838    #[test]
839    #[cfg(feature = "wgpu")]
840    fn polyint_wgpu_complex_coefficients_match_cpu() {
841        let _guard = test_support::accel_test_lock();
842        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
843            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
844        ) else {
845            return;
846        };
847        let coeffs =
848            ComplexTensor::new(vec![(3.0, 1.5), (-2.0, 0.5), (5.0, -1.0)], vec![1, 3]).unwrap();
849        let cpu_value =
850            polyint_builtin(Value::ComplexTensor(coeffs.clone()), vec![Value::Num(2.0)])
851                .expect("polyint cpu");
852        let cpu = match cpu_value {
853            Value::ComplexTensor(t) => t,
854            other => panic!("unexpected cpu result {other:?}"),
855        };
856
857        let handle = gpu_helpers::upload_complex_tensor(provider, &coeffs).expect("upload");
858        let gpu_value =
859            polyint_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("polyint gpu");
860        let Value::GpuTensor(handle) = gpu_value else {
861            panic!("expected gpu tensor");
862        };
863        assert_eq!(
864            runmat_accelerate_api::handle_storage(&handle),
865            runmat_accelerate_api::GpuTensorStorage::ComplexInterleaved
866        );
867        let gathered =
868            block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle))).expect("gather");
869        let Value::ComplexTensor(gpu) = gathered else {
870            panic!("expected complex tensor");
871        };
872        assert_eq!(gpu.shape, cpu.shape);
873        let tol = match provider.precision() {
874            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
875            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
876        };
877        gpu.data
878            .iter()
879            .zip(cpu.data.iter())
880            .for_each(|((lre, lim), (rre, rim))| {
881                assert!((lre - rre).abs() < tol);
882                assert!((lim - rim).abs() < tol);
883            });
884    }
885
886    fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
887        block_on(super::polyint_builtin(coeffs, rest))
888    }
889}