Skip to main content

runmat_runtime/builtins/math/poly/
polyint.rs

//! MATLAB-compatible `polyint` builtin (polynomial antiderivative) with GPU-aware semantics for RunMat.
2
3use log::{trace, warn};
4use num_complex::Complex64;
5use runmat_accelerate_api::HostTensorView;
6use runmat_builtins::{ComplexTensor, Tensor, Value};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::spec::{
10    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
12};
13use crate::builtins::common::tensor;
14use crate::builtins::math::poly::type_resolvers::polyint_type;
15use crate::dispatcher;
16use crate::{build_runtime_error, BuiltinResult, RuntimeError};
17
/// Magnitude at or below which an imaginary component is treated as zero when
/// choosing between a real and a complex result (and when deciding whether a
/// constant of integration may be passed to the real-only GPU hook).
const EPS: f64 = 1e-12;

/// Builtin name attached to every runtime error raised by this module.
const BUILTIN_NAME: &str = "polyint";
20
21#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::polyint")]
22pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
23    name: "polyint",
24    op_kind: GpuOpKind::Custom("polynomial-integral"),
25    supported_precisions: &[ScalarType::F32, ScalarType::F64],
26    broadcast: BroadcastSemantics::None,
27    provider_hooks: &[ProviderHook::Custom("polyint")],
28    constant_strategy: ConstantStrategy::InlineLiteral,
29    residency: ResidencyPolicy::NewHandle,
30    nan_mode: ReductionNaN::Include,
31    two_pass_threshold: None,
32    workgroup_size: None,
33    accepts_nan_mode: false,
34    notes: "Providers implement the polyint hook for real coefficient vectors; complex inputs fall back to the host.",
35};
36
37fn polyint_error(message: impl Into<String>) -> RuntimeError {
38    build_runtime_error(message)
39        .with_builtin(BUILTIN_NAME)
40        .build()
41}
42
43#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::polyint")]
44pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
45    name: "polyint",
46    shape: ShapeRequirements::Any,
47    constant_strategy: ConstantStrategy::InlineLiteral,
48    elementwise: None,
49    reduction: None,
50    emits_nan: false,
51    notes: "Symbolic operation on coefficient vectors; fusion does not apply.",
52};
53
54#[runtime_builtin(
55    name = "polyint",
56    category = "math/poly",
57    summary = "Integrate polynomial coefficient vectors and append a constant of integration.",
58    keywords = "polyint,polynomial,integral,antiderivative",
59    type_resolver(polyint_type),
60    builtin_path = "crate::builtins::math::poly::polyint"
61)]
62async fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
63    if rest.len() > 1 {
64        return Err(polyint_error("polyint: too many input arguments"));
65    }
66
67    let constant = match rest.into_iter().next() {
68        Some(value) => parse_constant(value).await?,
69        None => Complex64::new(0.0, 0.0),
70    };
71
72    if let Value::GpuTensor(handle) = &coeffs {
73        if let Some(device_result) = try_polyint_gpu(handle, constant)? {
74            return Ok(Value::GpuTensor(device_result));
75        }
76    }
77
78    let was_gpu = matches!(coeffs, Value::GpuTensor(_));
79    polyint_host_value(coeffs, constant, was_gpu).await
80}
81
82async fn polyint_host_value(
83    coeffs: Value,
84    constant: Complex64,
85    was_gpu: bool,
86) -> BuiltinResult<Value> {
87    let polynomial = parse_polynomial(coeffs).await?;
88    let mut integrated = integrate_coeffs(&polynomial.coeffs);
89    if integrated.is_empty() {
90        integrated.push(constant);
91    } else if let Some(last) = integrated.last_mut() {
92        *last += constant;
93    }
94    let value = coeffs_to_value(&integrated, polynomial.orientation)?;
95    maybe_return_gpu(value, was_gpu)
96}
97
98fn try_polyint_gpu(
99    handle: &runmat_accelerate_api::GpuTensorHandle,
100    constant: Complex64,
101) -> BuiltinResult<Option<runmat_accelerate_api::GpuTensorHandle>> {
102    if constant.im.abs() > EPS {
103        return Ok(None);
104    }
105    ensure_vector_shape(&handle.shape)?;
106    let Some(provider) = runmat_accelerate_api::provider() else {
107        return Ok(None);
108    };
109    match provider.polyint(handle, constant.re) {
110        Ok(result) => Ok(Some(result)),
111        Err(err) => {
112            trace!("polyint: provider hook unavailable, falling back to host: {err}");
113            Ok(None)
114        }
115    }
116}
117
118fn integrate_coeffs(coeffs: &[Complex64]) -> Vec<Complex64> {
119    if coeffs.is_empty() {
120        return Vec::new();
121    }
122    let mut result = Vec::with_capacity(coeffs.len() + 1);
123    for (idx, coeff) in coeffs.iter().enumerate() {
124        let power = (coeffs.len() - idx) as f64;
125        if power <= 0.0 {
126            result.push(Complex64::new(0.0, 0.0));
127        } else {
128            result.push(*coeff / Complex64::new(power, 0.0));
129        }
130    }
131    result.push(Complex64::new(0.0, 0.0));
132    result
133}
134
135fn maybe_return_gpu(value: Value, was_gpu: bool) -> BuiltinResult<Value> {
136    if !was_gpu {
137        return Ok(value);
138    }
139    match value {
140        Value::Tensor(tensor) => {
141            if let Some(provider) = runmat_accelerate_api::provider() {
142                let view = HostTensorView {
143                    data: &tensor.data,
144                    shape: &tensor.shape,
145                };
146                match provider.upload(&view) {
147                    Ok(handle) => return Ok(Value::GpuTensor(handle)),
148                    Err(err) => {
149                        warn!("polyint: provider upload failed, keeping result on host: {err}");
150                    }
151                }
152            } else {
153                trace!("polyint: no provider available to re-upload result");
154            }
155            Ok(Value::Tensor(tensor))
156        }
157        other => Ok(other),
158    }
159}
160
161fn coeffs_to_value(coeffs: &[Complex64], orientation: Orientation) -> BuiltinResult<Value> {
162    if coeffs.iter().all(|c| c.im.abs() <= EPS) {
163        let data: Vec<f64> = coeffs.iter().map(|c| c.re).collect();
164        let shape = orientation.shape_for_len(data.len());
165        let tensor =
166            Tensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
167        Ok(tensor::tensor_into_value(tensor))
168    } else {
169        let data: Vec<(f64, f64)> = coeffs.iter().map(|c| (c.re, c.im)).collect();
170        let shape = orientation.shape_for_len(data.len());
171        let tensor =
172            ComplexTensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
173        Ok(Value::ComplexTensor(tensor))
174    }
175}
176
177async fn parse_polynomial(value: Value) -> BuiltinResult<Polynomial> {
178    let gathered = dispatcher::gather_if_needed_async(&value).await?;
179    match gathered {
180        Value::Tensor(tensor) => parse_tensor_coeffs(&tensor),
181        Value::ComplexTensor(tensor) => parse_complex_tensor_coeffs(&tensor),
182        Value::LogicalArray(logical) => {
183            let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
184            parse_tensor_coeffs(&tensor)
185        }
186        Value::Num(n) => Ok(Polynomial {
187            coeffs: vec![Complex64::new(n, 0.0)],
188            orientation: Orientation::Scalar,
189        }),
190        Value::Int(i) => Ok(Polynomial {
191            coeffs: vec![Complex64::new(i.to_f64(), 0.0)],
192            orientation: Orientation::Scalar,
193        }),
194        Value::Bool(b) => Ok(Polynomial {
195            coeffs: vec![Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)],
196            orientation: Orientation::Scalar,
197        }),
198        Value::Complex(re, im) => Ok(Polynomial {
199            coeffs: vec![Complex64::new(re, im)],
200            orientation: Orientation::Scalar,
201        }),
202        other => Err(polyint_error(format!(
203            "polyint: expected a numeric coefficient vector, got {:?}",
204            other
205        ))),
206    }
207}
208
209fn parse_tensor_coeffs(tensor: &Tensor) -> BuiltinResult<Polynomial> {
210    ensure_vector_shape(&tensor.shape)?;
211    let orientation = orientation_from_shape(&tensor.shape);
212    Ok(Polynomial {
213        coeffs: tensor
214            .data
215            .iter()
216            .map(|&v| Complex64::new(v, 0.0))
217            .collect(),
218        orientation,
219    })
220}
221
222fn parse_complex_tensor_coeffs(tensor: &ComplexTensor) -> BuiltinResult<Polynomial> {
223    ensure_vector_shape(&tensor.shape)?;
224    let orientation = orientation_from_shape(&tensor.shape);
225    Ok(Polynomial {
226        coeffs: tensor
227            .data
228            .iter()
229            .map(|&(re, im)| Complex64::new(re, im))
230            .collect(),
231        orientation,
232    })
233}
234
235async fn parse_constant(value: Value) -> BuiltinResult<Complex64> {
236    let gathered = dispatcher::gather_if_needed_async(&value).await?;
237    match gathered {
238        Value::Tensor(tensor) => {
239            if tensor.data.len() != 1 {
240                return Err(polyint_error(
241                    "polyint: constant of integration must be a scalar",
242                ));
243            }
244            Ok(Complex64::new(tensor.data[0], 0.0))
245        }
246        Value::ComplexTensor(tensor) => {
247            if tensor.data.len() != 1 {
248                return Err(polyint_error(
249                    "polyint: constant of integration must be a scalar",
250                ));
251            }
252            let (re, im) = tensor.data[0];
253            Ok(Complex64::new(re, im))
254        }
255        Value::Num(n) => Ok(Complex64::new(n, 0.0)),
256        Value::Int(i) => Ok(Complex64::new(i.to_f64(), 0.0)),
257        Value::Bool(b) => Ok(Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)),
258        Value::Complex(re, im) => Ok(Complex64::new(re, im)),
259        Value::LogicalArray(logical) => {
260            let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
261            if tensor.data.len() != 1 {
262                return Err(polyint_error(
263                    "polyint: constant of integration must be a scalar",
264                ));
265            }
266            Ok(Complex64::new(tensor.data[0], 0.0))
267        }
268        other => Err(polyint_error(format!(
269            "polyint: constant of integration must be numeric, got {:?}",
270            other
271        ))),
272    }
273}
274
275fn ensure_vector_shape(shape: &[usize]) -> BuiltinResult<()> {
276    let non_unit = shape.iter().filter(|&&dim| dim > 1).count();
277    if non_unit <= 1 {
278        Ok(())
279    } else {
280        Err(polyint_error("polyint: coefficients must form a vector"))
281    }
282}
283
284fn orientation_from_shape(shape: &[usize]) -> Orientation {
285    for (idx, &dim) in shape.iter().enumerate() {
286        if dim != 1 {
287            return match idx {
288                0 => Orientation::Column,
289                1 => Orientation::Row,
290                _ => Orientation::Column,
291            };
292        }
293    }
294    Orientation::Scalar
295}
296
297#[derive(Clone)]
298struct Polynomial {
299    coeffs: Vec<Complex64>,
300    orientation: Orientation,
301}
302
/// Layout of a coefficient vector, used to shape the integrated result.
#[derive(Clone, Copy)]
enum Orientation {
    /// A single value (all dimensions 1); expands to a row vector.
    Scalar,
    /// A `1 x n` vector.
    Row,
    /// An `n x 1` vector.
    Column,
}

impl Orientation {
    /// Returns the 2-D shape for a result of `len` elements.
    ///
    /// Lengths of 0 or 1 always map to `[1, 1]`. Longer results are `[len, 1]`
    /// for columns and `[1, len]` for rows and scalars.
    fn shape_for_len(self, len: usize) -> Vec<usize> {
        match self {
            _ if len <= 1 => vec![1, 1],
            Orientation::Column => vec![len, 1],
            Orientation::Scalar | Orientation::Row => vec![1, len],
        }
    }
}
321
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for the `polyint` builtin: host results, argument validation,
    //! and GPU residency via the test provider (and wgpu when that feature is on).

    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::LogicalArray;

    /// Absolute tolerance for comparing host results against expected values.
    const TOL: f64 = 1e-12;

    /// Drives the async builtin to completion so tests can call it synchronously.
    fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::polyint_builtin(coeffs, rest))
    }

    /// Fails unless the error message contains `needle`.
    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
        let message = err.message();
        assert!(
            message.contains(needle),
            "expected error containing '{needle}', got '{message}'"
        );
    }

    /// Element-wise closeness check over the zipped prefix of two real slices.
    fn assert_real_close(actual: &[f64], expected: &[f64]) {
        assert!(actual
            .iter()
            .zip(expected)
            .all(|(lhs, rhs)| (lhs - rhs).abs() < TOL));
    }

    /// Complex counterpart of [`assert_real_close`], comparing both parts.
    fn assert_complex_close(actual: &[(f64, f64)], expected: &[(f64, f64)]) {
        assert!(actual.iter().zip(expected).all(|((lre, lim), (rre, rim))| {
            (lre - rre).abs() < TOL && (lim - rim).abs() < TOL
        }));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_polynomial_without_constant() {
        let coeffs = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
        match polyint_builtin(Value::Tensor(coeffs), Vec::new()).expect("polyint") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 5]);
                assert_real_close(&t.data, &[0.75, -2.0 / 3.0, 2.5, 7.0, 0.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_with_constant() {
        let coeffs = Tensor::new(vec![4.0, 0.0, -8.0], vec![1, 3]).unwrap();
        match polyint_builtin(Value::Tensor(coeffs), vec![Value::Num(3.0)]).expect("polyint") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                assert_real_close(&t.data, &[4.0 / 3.0, 0.0, -8.0, 3.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_scalar_value() {
        match polyint_builtin(Value::Num(5.0), Vec::new()).expect("polyint") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert!((t.data[0] - 5.0).abs() < TOL);
                assert!(t.data[1].abs() < TOL);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_logical_coefficients() {
        let mask = LogicalArray::new(vec![1, 0, 1], vec![1, 3]).unwrap();
        match polyint_builtin(Value::LogicalArray(mask), Vec::new()).expect("polyint logical") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                assert_real_close(&t.data, &[1.0 / 3.0, 0.0, 1.0, 0.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn preserves_column_vector_orientation() {
        let coeffs = Tensor::new(vec![2.0, 0.0, -6.0], vec![3, 1]).unwrap();
        match polyint_builtin(Value::Tensor(coeffs), Vec::new()).expect("polyint") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                assert_real_close(&t.data, &[2.0 / 3.0, 0.0, -6.0, 0.0]);
            }
            other => panic!("expected column tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_complex_coefficients() {
        let coeffs =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (0.0, 4.0)], vec![1, 3]).unwrap();
        let args = vec![Value::Complex(0.0, -1.0)];
        match polyint_builtin(Value::ComplexTensor(coeffs), args).expect("polyint") {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                assert_complex_close(
                    &t.data,
                    &[(1.0 / 3.0, 2.0 / 3.0), (-1.5, 0.0), (0.0, 4.0), (0.0, -1.0)],
                );
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_matrix_coefficients() {
        let matrix = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let err = polyint_builtin(Value::Tensor(matrix), Vec::new()).expect_err("expected error");
        assert_error_contains(err, "vector");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_non_scalar_constant() {
        let coeffs = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
        let constant = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let err = polyint_builtin(Value::Tensor(coeffs), vec![Value::Tensor(constant)])
            .expect_err("expected error");
        assert_error_contains(err, "constant of integration must be a scalar");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_excess_arguments() {
        let coeffs = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let extra = vec![Value::Num(1.0), Value::Num(2.0)];
        let err = polyint_builtin(Value::Tensor(coeffs), extra).expect_err("expected error");
        assert_error_contains(err, "too many input arguments");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_empty_input_as_zero_polynomial() {
        let empty = Tensor::new(vec![], vec![1, 0]).unwrap();
        match polyint_builtin(Value::Tensor(empty), Vec::new()).expect("polyint") {
            Value::Num(v) => assert!(v.abs() < TOL),
            Value::Tensor(t) => {
                // Allow tensor fallback if scalar auto-boxing changes in future
                assert_eq!(t.data.len(), 1);
                assert!(t.data[0].abs() < TOL);
            }
            other => panic!("expected numeric result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_input_with_constant() {
        let empty = Tensor::new(vec![], vec![1, 0]).unwrap();
        let args = vec![Value::Complex(1.5, -2.0)];
        match polyint_builtin(Value::Tensor(empty), args).expect("polyint") {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert_eq!(t.data.len(), 1);
                assert_complex_close(&t.data, &[(1.5, -2.0)]);
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let coeffs = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
            let view = HostTensorView {
                data: &coeffs.data,
                shape: &coeffs.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            match polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint") {
                Value::GpuTensor(result) => {
                    let gathered = test_support::gather(Value::GpuTensor(result)).expect("gather");
                    assert_eq!(gathered.shape, vec![1, 4]);
                    assert_real_close(&gathered.data, &[1.0 / 3.0, -2.0, 6.0, 0.0]);
                }
                other => panic!("expected GPU tensor result, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_complex_constant_falls_back_to_host() {
        test_support::with_test_provider(|provider| {
            let coeffs = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
            let view = HostTensorView {
                data: &coeffs.data,
                shape: &coeffs.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let args = vec![Value::Complex(0.0, 2.0)];
            match polyint_builtin(Value::GpuTensor(handle), args).expect("polyint") {
                Value::ComplexTensor(ct) => {
                    assert_eq!(ct.shape, vec![1, 3]);
                    assert_complex_close(&ct.data, &[(0.5, 0.0), (0.0, 0.0), (0.0, 2.0)]);
                }
                other => panic!("expected complex tensor fall-back, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_with_gpu_constant() {
        test_support::with_test_provider(|provider| {
            let coeffs = Tensor::new(vec![2.0, 0.0], vec![1, 2]).unwrap();
            let coeff_view = HostTensorView {
                data: &coeffs.data,
                shape: &coeffs.shape,
            };
            let coeff_handle = provider.upload(&coeff_view).expect("upload coeffs");

            let constant = Tensor::new(vec![3.0], vec![1, 1]).unwrap();
            let constant_view = HostTensorView {
                data: &constant.data,
                shape: &constant.shape,
            };
            let constant_handle = provider.upload(&constant_view).expect("upload constant");

            let outcome = polyint_builtin(
                Value::GpuTensor(coeff_handle),
                vec![Value::GpuTensor(constant_handle)],
            )
            .expect("polyint");
            match outcome {
                Value::GpuTensor(result) => {
                    let gathered =
                        test_support::gather(Value::GpuTensor(result)).expect("gather result");
                    assert_eq!(gathered.shape, vec![1, 3]);
                    assert_real_close(&gathered.data, &[1.0, 0.0, 3.0]);
                }
                other => panic!("expected gpu tensor result, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn polyint_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let coeffs = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
        let view = HostTensorView {
            data: &coeffs.data,
            shape: &coeffs.shape,
        };
        let handle = provider.upload(&view).expect("upload");

        let gpu_value = polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint gpu");
        let gathered = test_support::gather(gpu_value).expect("gather");

        let cpu_value =
            polyint_builtin(Value::Tensor(coeffs.clone()), Vec::new()).expect("polyint cpu");
        let expected = match cpu_value {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected cpu result {other:?}"),
        };

        assert_eq!(gathered.shape, expected.shape);
        // Single-precision providers need a looser tolerance.
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (lhs, rhs) in gathered.data.iter().zip(expected.data.iter()) {
            assert!((lhs - rhs).abs() < tol);
        }
    }
}