// runmat_runtime/builtins/math/poly/polyint.rs

1//! MATLAB-compatible `polyint` builtin with GPU-aware semantics for RunMat.
2
3use log::{trace, warn};
4use num_complex::Complex64;
5use runmat_accelerate_api::HostTensorView;
6use runmat_builtins::{
7    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9    ComplexTensor, Tensor, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18use crate::builtins::math::poly::type_resolvers::polyint_type;
19use crate::dispatcher;
20use crate::{build_runtime_error, BuiltinResult, RuntimeError};
21
/// Magnitude threshold for imaginary parts: `coeffs_to_value` emits a real tensor only when every
/// imaginary part is at or below this value, and `try_polyint_gpu` declines constants above it.
const EPS: f64 = 1.0e-12;
/// Builtin name attached to every error built by `polyint_error_with`.
const BUILTIN_NAME: &str = "polyint";
24
/// Output descriptor shared by both signatures: the single required output `q`.
const POLYINT_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "q",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Integrated polynomial coefficient vector.",
}];
32
/// Input descriptors for the one-argument form `q = polyint(p)`.
const POLYINT_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "p",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Polynomial coefficient vector.",
}];
40
/// Input descriptors for the two-argument form `q = polyint(p, k)`; `k` is marked optional.
const POLYINT_INPUTS_WITH_K: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "p",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Polynomial coefficient vector.",
    },
    BuiltinParamDescriptor {
        name: "k",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Constant of integration.",
    },
];
57
/// The two documented call forms of `polyint`; both produce the same single output.
const POLYINT_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
    BuiltinSignatureDescriptor {
        label: "q = polyint(p)",
        inputs: &POLYINT_INPUTS,
        outputs: &POLYINT_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "q = polyint(p, k)",
        inputs: &POLYINT_INPUTS_WITH_K,
        outputs: &POLYINT_OUTPUT,
    },
];
70
/// Error descriptor for a bad argument count or a malformed integration constant.
const POLYINT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.POLYINT.INVALID_ARGUMENT",
    identifier: Some("RunMat:polyint:InvalidArgument"),
    when: "Input arity or integration-constant argument is malformed.",
    message: "polyint: invalid argument",
};
77
/// Error descriptor for a polynomial argument that is not a numeric coefficient vector.
const POLYINT_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.POLYINT.INVALID_INPUT",
    identifier: Some("RunMat:polyint:InvalidInput"),
    when: "Input polynomial cannot be interpreted as a numeric coefficient vector.",
    message: "polyint: invalid input",
};
84
/// Error descriptor for runtime-side failures (output tensor construction, provider fallback).
const POLYINT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.POLYINT.INTERNAL",
    identifier: Some("RunMat:polyint:Internal"),
    when: "Runtime fails while building output tensors or provider fallback paths.",
    message: "polyint: internal runtime failure",
};
91
/// Every error descriptor published for `polyint`, in the order exposed by the descriptor.
const POLYINT_ERRORS: [BuiltinErrorDescriptor; 3] = [
    POLYINT_ERROR_INVALID_ARGUMENT,
    POLYINT_ERROR_INVALID_INPUT,
    POLYINT_ERROR_INTERNAL,
];
97
/// Public descriptor for `polyint`, referenced by the `runtime_builtin` attribute on
/// `polyint_builtin`; bundles the signature and error tables declared in this file.
pub const POLYINT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &POLYINT_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &POLYINT_ERRORS,
};
104
/// GPU metadata registered for the `polyint` builtin.
///
/// Declares a custom `"polyint"` provider hook; presumably this is the hook reached through
/// `provider.polyint(..)` in `try_polyint_gpu` (confirm against the provider trait).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::polyint")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "polyint",
    op_kind: GpuOpKind::Custom("polynomial-integral"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("polyint")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers implement the polyint hook for real coefficient vectors; complex inputs fall back to the host.",
};
120
/// Builds a `polyint` runtime error tagged with the invalid-input identifier.
fn polyint_error(message: impl Into<String>) -> RuntimeError {
    polyint_error_with(message, &POLYINT_ERROR_INVALID_INPUT)
}
124
/// Builds a `polyint` runtime error tagged with the invalid-argument identifier.
fn polyint_argument_error(message: impl Into<String>) -> RuntimeError {
    polyint_error_with(message, &POLYINT_ERROR_INVALID_ARGUMENT)
}
128
129fn polyint_error_with(
130    message: impl Into<String>,
131    error: &'static BuiltinErrorDescriptor,
132) -> RuntimeError {
133    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
134    if let Some(identifier) = error.identifier {
135        builder = builder.with_identifier(identifier);
136    }
137    builder.build()
138}
139
/// Fusion metadata registered for the `polyint` builtin.
///
/// No elementwise or reduction template is supplied (both are `None`), in line with the note
/// that fusion does not apply to this operation.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::polyint")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "polyint",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Symbolic operation on coefficient vectors; fusion does not apply.",
};
150
151#[runtime_builtin(
152    name = "polyint",
153    category = "math/poly",
154    summary = "Integrate polynomial coefficient vectors and append a constant of integration.",
155    keywords = "polyint,polynomial,integral,antiderivative",
156    type_resolver(polyint_type),
157    descriptor(crate::builtins::math::poly::polyint::POLYINT_DESCRIPTOR),
158    builtin_path = "crate::builtins::math::poly::polyint"
159)]
160async fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
161    if rest.len() > 1 {
162        return Err(polyint_argument_error("polyint: too many input arguments"));
163    }
164
165    let constant = match rest.into_iter().next() {
166        Some(value) => parse_constant(value).await?,
167        None => Complex64::new(0.0, 0.0),
168    };
169
170    if let Value::GpuTensor(handle) = &coeffs {
171        if let Some(device_result) = try_polyint_gpu(handle, constant)? {
172            return Ok(Value::GpuTensor(device_result));
173        }
174    }
175
176    let was_gpu = matches!(coeffs, Value::GpuTensor(_));
177    polyint_host_value(coeffs, constant, was_gpu).await
178}
179
180async fn polyint_host_value(
181    coeffs: Value,
182    constant: Complex64,
183    was_gpu: bool,
184) -> BuiltinResult<Value> {
185    let polynomial = parse_polynomial(coeffs).await?;
186    let mut integrated = integrate_coeffs(&polynomial.coeffs);
187    if integrated.is_empty() {
188        integrated.push(constant);
189    } else if let Some(last) = integrated.last_mut() {
190        *last += constant;
191    }
192    let value = coeffs_to_value(&integrated, polynomial.orientation)?;
193    maybe_return_gpu(value, was_gpu)
194}
195
196fn try_polyint_gpu(
197    handle: &runmat_accelerate_api::GpuTensorHandle,
198    constant: Complex64,
199) -> BuiltinResult<Option<runmat_accelerate_api::GpuTensorHandle>> {
200    if constant.im.abs() > EPS {
201        return Ok(None);
202    }
203    ensure_vector_shape(&handle.shape)?;
204    let Some(provider) = runmat_accelerate_api::provider() else {
205        return Ok(None);
206    };
207    match provider.polyint(handle, constant.re) {
208        Ok(result) => Ok(Some(result)),
209        Err(err) => {
210            trace!("polyint: provider hook unavailable, falling back to host: {err}");
211            Ok(None)
212        }
213    }
214}
215
216fn integrate_coeffs(coeffs: &[Complex64]) -> Vec<Complex64> {
217    if coeffs.is_empty() {
218        return Vec::new();
219    }
220    let mut result = Vec::with_capacity(coeffs.len() + 1);
221    for (idx, coeff) in coeffs.iter().enumerate() {
222        let power = (coeffs.len() - idx) as f64;
223        if power <= 0.0 {
224            result.push(Complex64::new(0.0, 0.0));
225        } else {
226            result.push(*coeff / Complex64::new(power, 0.0));
227        }
228    }
229    result.push(Complex64::new(0.0, 0.0));
230    result
231}
232
233fn maybe_return_gpu(value: Value, was_gpu: bool) -> BuiltinResult<Value> {
234    if !was_gpu {
235        return Ok(value);
236    }
237    match value {
238        Value::Tensor(tensor) => {
239            if let Some(provider) = runmat_accelerate_api::provider() {
240                let view = HostTensorView {
241                    data: &tensor.data,
242                    shape: &tensor.shape,
243                };
244                match provider.upload(&view) {
245                    Ok(handle) => return Ok(Value::GpuTensor(handle)),
246                    Err(err) => {
247                        warn!("polyint: provider upload failed, keeping result on host: {err}");
248                    }
249                }
250            } else {
251                trace!("polyint: no provider available to re-upload result");
252            }
253            Ok(Value::Tensor(tensor))
254        }
255        other => Ok(other),
256    }
257}
258
259fn coeffs_to_value(coeffs: &[Complex64], orientation: Orientation) -> BuiltinResult<Value> {
260    if coeffs.iter().all(|c| c.im.abs() <= EPS) {
261        let data: Vec<f64> = coeffs.iter().map(|c| c.re).collect();
262        let shape = orientation.shape_for_len(data.len());
263        let tensor =
264            Tensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
265        Ok(tensor::tensor_into_value(tensor))
266    } else {
267        let data: Vec<(f64, f64)> = coeffs.iter().map(|c| (c.re, c.im)).collect();
268        let shape = orientation.shape_for_len(data.len());
269        let tensor =
270            ComplexTensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
271        Ok(Value::ComplexTensor(tensor))
272    }
273}
274
275async fn parse_polynomial(value: Value) -> BuiltinResult<Polynomial> {
276    let gathered = dispatcher::gather_if_needed_async(&value).await?;
277    match gathered {
278        Value::Tensor(tensor) => parse_tensor_coeffs(&tensor),
279        Value::ComplexTensor(tensor) => parse_complex_tensor_coeffs(&tensor),
280        Value::LogicalArray(logical) => {
281            let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
282            parse_tensor_coeffs(&tensor)
283        }
284        Value::Num(n) => Ok(Polynomial {
285            coeffs: vec![Complex64::new(n, 0.0)],
286            orientation: Orientation::Scalar,
287        }),
288        Value::Int(i) => Ok(Polynomial {
289            coeffs: vec![Complex64::new(i.to_f64(), 0.0)],
290            orientation: Orientation::Scalar,
291        }),
292        Value::Bool(b) => Ok(Polynomial {
293            coeffs: vec![Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)],
294            orientation: Orientation::Scalar,
295        }),
296        Value::Complex(re, im) => Ok(Polynomial {
297            coeffs: vec![Complex64::new(re, im)],
298            orientation: Orientation::Scalar,
299        }),
300        other => Err(polyint_error(format!(
301            "polyint: expected a numeric coefficient vector, got {:?}",
302            other
303        ))),
304    }
305}
306
307fn parse_tensor_coeffs(tensor: &Tensor) -> BuiltinResult<Polynomial> {
308    ensure_vector_shape(&tensor.shape)?;
309    let orientation = orientation_from_shape(&tensor.shape);
310    Ok(Polynomial {
311        coeffs: tensor
312            .data
313            .iter()
314            .map(|&v| Complex64::new(v, 0.0))
315            .collect(),
316        orientation,
317    })
318}
319
320fn parse_complex_tensor_coeffs(tensor: &ComplexTensor) -> BuiltinResult<Polynomial> {
321    ensure_vector_shape(&tensor.shape)?;
322    let orientation = orientation_from_shape(&tensor.shape);
323    Ok(Polynomial {
324        coeffs: tensor
325            .data
326            .iter()
327            .map(|&(re, im)| Complex64::new(re, im))
328            .collect(),
329        orientation,
330    })
331}
332
333async fn parse_constant(value: Value) -> BuiltinResult<Complex64> {
334    let gathered = dispatcher::gather_if_needed_async(&value).await?;
335    match gathered {
336        Value::Tensor(tensor) => {
337            if tensor.data.len() != 1 {
338                return Err(polyint_error(
339                    "polyint: constant of integration must be a scalar",
340                ));
341            }
342            Ok(Complex64::new(tensor.data[0], 0.0))
343        }
344        Value::ComplexTensor(tensor) => {
345            if tensor.data.len() != 1 {
346                return Err(polyint_error(
347                    "polyint: constant of integration must be a scalar",
348                ));
349            }
350            let (re, im) = tensor.data[0];
351            Ok(Complex64::new(re, im))
352        }
353        Value::Num(n) => Ok(Complex64::new(n, 0.0)),
354        Value::Int(i) => Ok(Complex64::new(i.to_f64(), 0.0)),
355        Value::Bool(b) => Ok(Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)),
356        Value::Complex(re, im) => Ok(Complex64::new(re, im)),
357        Value::LogicalArray(logical) => {
358            let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
359            if tensor.data.len() != 1 {
360                return Err(polyint_error(
361                    "polyint: constant of integration must be a scalar",
362                ));
363            }
364            Ok(Complex64::new(tensor.data[0], 0.0))
365        }
366        other => Err(polyint_error(format!(
367            "polyint: constant of integration must be numeric, got {:?}",
368            other
369        ))),
370    }
371}
372
373fn ensure_vector_shape(shape: &[usize]) -> BuiltinResult<()> {
374    let non_unit = shape.iter().filter(|&&dim| dim > 1).count();
375    if non_unit <= 1 {
376        Ok(())
377    } else {
378        Err(polyint_error("polyint: coefficients must form a vector"))
379    }
380}
381
382fn orientation_from_shape(shape: &[usize]) -> Orientation {
383    for (idx, &dim) in shape.iter().enumerate() {
384        if dim != 1 {
385            return match idx {
386                0 => Orientation::Column,
387                1 => Orientation::Row,
388                _ => Orientation::Column,
389            };
390        }
391    }
392    Orientation::Scalar
393}
394
/// Parsed polynomial argument: its coefficients plus the orientation of the input it came from.
#[derive(Clone)]
struct Polynomial {
    /// Coefficients in input order, promoted to complex.
    coeffs: Vec<Complex64>,
    /// Layout of the input, reused when shaping the output.
    orientation: Orientation,
}
400
/// Layout of a coefficient input, used to shape the integrated output.
#[derive(Clone, Copy)]
enum Orientation {
    /// Input was a scalar; output is laid out as a row.
    Scalar,
    /// Input was a row vector.
    Row,
    /// Input was a column vector.
    Column,
}

impl Orientation {
    /// Returns the 2-D shape for an output of `len` elements.
    ///
    /// Lengths 0 and 1 always map to `[1, 1]`; otherwise columns become `[len, 1]` and
    /// everything else becomes `[1, len]`.
    fn shape_for_len(self, len: usize) -> Vec<usize> {
        match (len, self) {
            (0 | 1, _) => vec![1, 1],
            (_, Orientation::Column) => vec![len, 1],
            (_, Orientation::Scalar | Orientation::Row) => vec![1, len],
        }
    }
}
419
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::LogicalArray;

    /// Absolute tolerance used by the host-side comparisons.
    const TOL: f64 = 1e-12;

    /// Synchronous wrapper that drives the async builtin to completion.
    fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::polyint_builtin(coeffs, rest))
    }

    /// Asserts that the error message contains `needle`.
    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
        assert!(
            err.message().contains(needle),
            "expected error containing '{needle}', got '{}'",
            err.message()
        );
    }

    /// Asserts pairwise closeness of real values (pairs are formed with `zip`).
    fn assert_real_close(actual: &[f64], expected: &[f64]) {
        assert!(actual
            .iter()
            .zip(expected.iter())
            .all(|(lhs, rhs)| (lhs - rhs).abs() < TOL));
    }

    /// Asserts pairwise closeness of `(re, im)` values (pairs are formed with `zip`).
    fn assert_complex_close(actual: &[(f64, f64)], expected: &[(f64, f64)]) {
        assert!(actual
            .iter()
            .zip(expected.iter())
            .all(|((lre, lim), (rre, rim))| {
                (lre - rre).abs() < TOL && (lim - rim).abs() < TOL
            }));
    }

    #[test]
    fn polyint_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = POLYINT_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        for wanted in ["q = polyint(p)", "q = polyint(p, k)"] {
            assert!(labels.contains(&wanted));
        }
    }

    #[test]
    fn polyint_descriptor_errors_have_stable_codes() {
        let codes: Vec<&str> = POLYINT_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        for wanted in [
            "RM.POLYINT.INVALID_ARGUMENT",
            "RM.POLYINT.INVALID_INPUT",
            "RM.POLYINT.INTERNAL",
        ] {
            assert!(codes.contains(&wanted));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_polynomial_without_constant() {
        let tensor = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 5]);
                assert_real_close(&t.data, &[0.75, -2.0 / 3.0, 2.5, 7.0, 0.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_with_constant() {
        let tensor = Tensor::new(vec![4.0, 0.0, -8.0], vec![1, 3]).unwrap();
        let args = vec![Value::Num(3.0)];
        let result = polyint_builtin(Value::Tensor(tensor), args).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                assert_real_close(&t.data, &[4.0 / 3.0, 0.0, -8.0, 3.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_scalar_value() {
        let result = polyint_builtin(Value::Num(5.0), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert!((t.data[0] - 5.0).abs() < 1e-12);
                assert!(t.data[1].abs() < 1e-12);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_logical_coefficients() {
        let logical = LogicalArray::new(vec![1, 0, 1], vec![1, 3]).unwrap();
        let result =
            polyint_builtin(Value::LogicalArray(logical), Vec::new()).expect("polyint logical");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                assert_real_close(&t.data, &[1.0 / 3.0, 0.0, 1.0, 0.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn preserves_column_vector_orientation() {
        let tensor = Tensor::new(vec![2.0, 0.0, -6.0], vec![3, 1]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                assert_real_close(&t.data, &[2.0 / 3.0, 0.0, -6.0, 0.0]);
            }
            other => panic!("expected column tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_complex_coefficients() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (0.0, 4.0)], vec![1, 3]).unwrap();
        let args = vec![Value::Complex(0.0, -1.0)];
        let result = polyint_builtin(Value::ComplexTensor(tensor), args).expect("polyint");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                assert_complex_close(
                    &t.data,
                    &[(1.0 / 3.0, 2.0 / 3.0), (-1.5, 0.0), (0.0, 4.0), (0.0, -1.0)],
                );
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_matrix_coefficients() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let err = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect_err("expected error");
        assert_error_contains(err, "vector");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_non_scalar_constant() {
        let coeffs = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
        let constant = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let err = polyint_builtin(Value::Tensor(coeffs), vec![Value::Tensor(constant)])
            .expect_err("expected error");
        assert_error_contains(err, "constant of integration must be a scalar");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_excess_arguments() {
        let tensor = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let err = polyint_builtin(
            Value::Tensor(tensor),
            vec![Value::Num(1.0), Value::Num(2.0)],
        )
        .expect_err("expected error");
        assert_eq!(err.identifier(), POLYINT_ERROR_INVALID_ARGUMENT.identifier);
        assert_error_contains(err, "too many input arguments");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_empty_input_as_zero_polynomial() {
        let tensor = Tensor::new(vec![], vec![1, 0]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Num(v) => assert!(v.abs() < 1e-12),
            Value::Tensor(t) => {
                // Allow tensor fallback if scalar auto-boxing changes in future
                assert_eq!(t.data.len(), 1);
                assert!(t.data[0].abs() < 1e-12);
            }
            other => panic!("expected numeric result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_input_with_constant() {
        let tensor = Tensor::new(vec![], vec![1, 0]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), vec![Value::Complex(1.5, -2.0)])
            .expect("polyint");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert_eq!(t.data.len(), 1);
                let (re, im) = t.data[0];
                assert!((re - 1.5).abs() < 1e-12);
                assert!((im + 2.0).abs() < 1e-12);
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint");
            match result {
                Value::GpuTensor(handle) => {
                    let gathered = test_support::gather(Value::GpuTensor(handle)).expect("gather");
                    assert_eq!(gathered.shape, vec![1, 4]);
                    assert_real_close(&gathered.data, &[1.0 / 3.0, -2.0, 6.0, 0.0]);
                }
                other => panic!("expected GPU tensor result, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_complex_constant_falls_back_to_host() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = polyint_builtin(Value::GpuTensor(handle), vec![Value::Complex(0.0, 2.0)])
                .expect("polyint");
            match result {
                Value::ComplexTensor(ct) => {
                    assert_eq!(ct.shape, vec![1, 3]);
                    assert_complex_close(&ct.data, &[(0.5, 0.0), (0.0, 0.0), (0.0, 2.0)]);
                }
                other => panic!("expected complex tensor fall-back, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_with_gpu_constant() {
        test_support::with_test_provider(|provider| {
            let coeffs = Tensor::new(vec![2.0, 0.0], vec![1, 2]).unwrap();
            let coeff_view = HostTensorView {
                data: &coeffs.data,
                shape: &coeffs.shape,
            };
            let coeff_handle = provider.upload(&coeff_view).expect("upload coeffs");
            let constant = Tensor::new(vec![3.0], vec![1, 1]).unwrap();
            let constant_view = HostTensorView {
                data: &constant.data,
                shape: &constant.shape,
            };
            let constant_handle = provider.upload(&constant_view).expect("upload constant");
            let result = polyint_builtin(
                Value::GpuTensor(coeff_handle),
                vec![Value::GpuTensor(constant_handle)],
            )
            .expect("polyint");
            match result {
                Value::GpuTensor(handle) => {
                    let gathered =
                        test_support::gather(Value::GpuTensor(handle)).expect("gather result");
                    assert_eq!(gathered.shape, vec![1, 3]);
                    assert_real_close(&gathered.data, &[1.0, 0.0, 3.0]);
                }
                other => panic!("expected gpu tensor result, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn polyint_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let tensor = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_value = polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint gpu");
        let gathered = test_support::gather(gpu_value).expect("gather");
        let cpu_value =
            polyint_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("polyint cpu");
        let expected = match cpu_value {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected cpu result {other:?}"),
        };
        assert_eq!(gathered.shape, expected.shape);
        // Comparison tolerance depends on the precision the provider reports.
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (lhs, rhs) in gathered.data.iter().zip(expected.data.iter()) {
            assert!((lhs - rhs).abs() < tol);
        }
    }
}