1use log::{trace, warn};
4use num_complex::Complex64;
5use runmat_accelerate_api::HostTensorView;
6use runmat_builtins::{ComplexTensor, Tensor, Value};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::spec::{
10 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
12};
13use crate::builtins::common::tensor;
14use crate::builtins::math::poly::type_resolvers::polyint_type;
15use crate::dispatcher;
16use crate::{build_runtime_error, BuiltinResult, RuntimeError};
17
/// Tolerance below which an imaginary component is treated as zero, letting
/// complex results demote to plain real tensors (see `coeffs_to_value`).
const EPS: f64 = 1.0e-12;
/// Builtin name attached to every `RuntimeError` raised from this module.
const BUILTIN_NAME: &str = "polyint";
20
/// Acceleration metadata registered for the `polyint` builtin.
///
/// Integration is a custom transform over a coefficient vector rather than an
/// elementwise or reduction kernel, so no broadcasting semantics apply and
/// providers opt in through the dedicated `"polyint"` custom hook (invoked
/// from `try_polyint_gpu`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::polyint")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "polyint",
    op_kind: GpuOpKind::Custom("polynomial-integral"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("polyint")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // The provider returns a freshly allocated tensor (one extra coefficient).
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers implement the polyint hook for real coefficient vectors; complex inputs fall back to the host.",
};
36
37fn polyint_error(message: impl Into<String>) -> RuntimeError {
38 build_runtime_error(message)
39 .with_builtin(BUILTIN_NAME)
40 .build()
41}
42
/// Fusion metadata registered for the `polyint` builtin.
///
/// The operation rewrites a whole coefficient vector symbolically, so it is
/// neither elementwise nor a reduction and cannot participate in kernel
/// fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::polyint")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "polyint",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Symbolic operation on coefficient vectors; fusion does not apply.",
};
53
54#[runtime_builtin(
55 name = "polyint",
56 category = "math/poly",
57 summary = "Integrate polynomial coefficient vectors and append a constant of integration.",
58 keywords = "polyint,polynomial,integral,antiderivative",
59 type_resolver(polyint_type),
60 builtin_path = "crate::builtins::math::poly::polyint"
61)]
62async fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
63 if rest.len() > 1 {
64 return Err(polyint_error("polyint: too many input arguments"));
65 }
66
67 let constant = match rest.into_iter().next() {
68 Some(value) => parse_constant(value).await?,
69 None => Complex64::new(0.0, 0.0),
70 };
71
72 if let Value::GpuTensor(handle) = &coeffs {
73 if let Some(device_result) = try_polyint_gpu(handle, constant)? {
74 return Ok(Value::GpuTensor(device_result));
75 }
76 }
77
78 let was_gpu = matches!(coeffs, Value::GpuTensor(_));
79 polyint_host_value(coeffs, constant, was_gpu).await
80}
81
82async fn polyint_host_value(
83 coeffs: Value,
84 constant: Complex64,
85 was_gpu: bool,
86) -> BuiltinResult<Value> {
87 let polynomial = parse_polynomial(coeffs).await?;
88 let mut integrated = integrate_coeffs(&polynomial.coeffs);
89 if integrated.is_empty() {
90 integrated.push(constant);
91 } else if let Some(last) = integrated.last_mut() {
92 *last += constant;
93 }
94 let value = coeffs_to_value(&integrated, polynomial.orientation)?;
95 maybe_return_gpu(value, was_gpu)
96}
97
/// Attempt to integrate a device-resident coefficient vector through the
/// provider's custom `polyint` hook.
///
/// Returns `Ok(None)` whenever the GPU path cannot be used — a non-real
/// constant (the hook only accepts an `f64`), no registered provider, or a
/// provider whose hook fails — so the caller falls back to the host path.
/// Only a non-vector shape is treated as a hard error.
fn try_polyint_gpu(
    handle: &runmat_accelerate_api::GpuTensorHandle,
    constant: Complex64,
) -> BuiltinResult<Option<runmat_accelerate_api::GpuTensorHandle>> {
    // A complex constant can produce a complex result, which the provider
    // hook cannot represent — force the host fallback.
    if constant.im.abs() > EPS {
        return Ok(None);
    }
    ensure_vector_shape(&handle.shape)?;
    let Some(provider) = runmat_accelerate_api::provider() else {
        return Ok(None);
    };
    match provider.polyint(handle, constant.re) {
        Ok(result) => Ok(Some(result)),
        Err(err) => {
            // Hook failure is non-fatal: log and let the host compute it.
            trace!("polyint: provider hook unavailable, falling back to host: {err}");
            Ok(None)
        }
    }
}
117
118fn integrate_coeffs(coeffs: &[Complex64]) -> Vec<Complex64> {
119 if coeffs.is_empty() {
120 return Vec::new();
121 }
122 let mut result = Vec::with_capacity(coeffs.len() + 1);
123 for (idx, coeff) in coeffs.iter().enumerate() {
124 let power = (coeffs.len() - idx) as f64;
125 if power <= 0.0 {
126 result.push(Complex64::new(0.0, 0.0));
127 } else {
128 result.push(*coeff / Complex64::new(power, 0.0));
129 }
130 }
131 result.push(Complex64::new(0.0, 0.0));
132 result
133}
134
/// Re-upload a host-computed result to the device when the input was GPU
/// resident.
///
/// Only real tensors are uploaded; complex tensors (and any other variant)
/// are returned on the host as-is. Upload failures degrade gracefully to the
/// host tensor with a warning rather than failing the builtin.
fn maybe_return_gpu(value: Value, was_gpu: bool) -> BuiltinResult<Value> {
    if !was_gpu {
        return Ok(value);
    }
    match value {
        Value::Tensor(tensor) => {
            if let Some(provider) = runmat_accelerate_api::provider() {
                let view = HostTensorView {
                    data: &tensor.data,
                    shape: &tensor.shape,
                };
                match provider.upload(&view) {
                    Ok(handle) => return Ok(Value::GpuTensor(handle)),
                    Err(err) => {
                        // Keep the host result instead of surfacing the error.
                        warn!("polyint: provider upload failed, keeping result on host: {err}");
                    }
                }
            } else {
                trace!("polyint: no provider available to re-upload result");
            }
            Ok(Value::Tensor(tensor))
        }
        // Complex tensors / scalars stay on the host.
        other => Ok(other),
    }
}
160
161fn coeffs_to_value(coeffs: &[Complex64], orientation: Orientation) -> BuiltinResult<Value> {
162 if coeffs.iter().all(|c| c.im.abs() <= EPS) {
163 let data: Vec<f64> = coeffs.iter().map(|c| c.re).collect();
164 let shape = orientation.shape_for_len(data.len());
165 let tensor =
166 Tensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
167 Ok(tensor::tensor_into_value(tensor))
168 } else {
169 let data: Vec<(f64, f64)> = coeffs.iter().map(|c| (c.re, c.im)).collect();
170 let shape = orientation.shape_for_len(data.len());
171 let tensor =
172 ComplexTensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
173 Ok(Value::ComplexTensor(tensor))
174 }
175}
176
/// Normalise an arbitrary runtime value into a `Polynomial`.
///
/// GPU-resident values are gathered to the host first. Scalars (numeric,
/// integer, logical, complex) become one-coefficient polynomials; tensors and
/// logical arrays must be vectors (enforced by the parse helpers). Any other
/// variant is a type error.
async fn parse_polynomial(value: Value) -> BuiltinResult<Polynomial> {
    let gathered = dispatcher::gather_if_needed_async(&value).await?;
    match gathered {
        Value::Tensor(tensor) => parse_tensor_coeffs(&tensor),
        Value::ComplexTensor(tensor) => parse_complex_tensor_coeffs(&tensor),
        Value::LogicalArray(logical) => {
            // Logical arrays are widened to a real tensor (0.0 / 1.0).
            let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
            parse_tensor_coeffs(&tensor)
        }
        Value::Num(n) => Ok(Polynomial {
            coeffs: vec![Complex64::new(n, 0.0)],
            orientation: Orientation::Scalar,
        }),
        Value::Int(i) => Ok(Polynomial {
            coeffs: vec![Complex64::new(i.to_f64(), 0.0)],
            orientation: Orientation::Scalar,
        }),
        Value::Bool(b) => Ok(Polynomial {
            coeffs: vec![Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)],
            orientation: Orientation::Scalar,
        }),
        Value::Complex(re, im) => Ok(Polynomial {
            coeffs: vec![Complex64::new(re, im)],
            orientation: Orientation::Scalar,
        }),
        other => Err(polyint_error(format!(
            "polyint: expected a numeric coefficient vector, got {:?}",
            other
        ))),
    }
}
208
209fn parse_tensor_coeffs(tensor: &Tensor) -> BuiltinResult<Polynomial> {
210 ensure_vector_shape(&tensor.shape)?;
211 let orientation = orientation_from_shape(&tensor.shape);
212 Ok(Polynomial {
213 coeffs: tensor
214 .data
215 .iter()
216 .map(|&v| Complex64::new(v, 0.0))
217 .collect(),
218 orientation,
219 })
220}
221
222fn parse_complex_tensor_coeffs(tensor: &ComplexTensor) -> BuiltinResult<Polynomial> {
223 ensure_vector_shape(&tensor.shape)?;
224 let orientation = orientation_from_shape(&tensor.shape);
225 Ok(Polynomial {
226 coeffs: tensor
227 .data
228 .iter()
229 .map(|&(re, im)| Complex64::new(re, im))
230 .collect(),
231 orientation,
232 })
233}
234
235async fn parse_constant(value: Value) -> BuiltinResult<Complex64> {
236 let gathered = dispatcher::gather_if_needed_async(&value).await?;
237 match gathered {
238 Value::Tensor(tensor) => {
239 if tensor.data.len() != 1 {
240 return Err(polyint_error(
241 "polyint: constant of integration must be a scalar",
242 ));
243 }
244 Ok(Complex64::new(tensor.data[0], 0.0))
245 }
246 Value::ComplexTensor(tensor) => {
247 if tensor.data.len() != 1 {
248 return Err(polyint_error(
249 "polyint: constant of integration must be a scalar",
250 ));
251 }
252 let (re, im) = tensor.data[0];
253 Ok(Complex64::new(re, im))
254 }
255 Value::Num(n) => Ok(Complex64::new(n, 0.0)),
256 Value::Int(i) => Ok(Complex64::new(i.to_f64(), 0.0)),
257 Value::Bool(b) => Ok(Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)),
258 Value::Complex(re, im) => Ok(Complex64::new(re, im)),
259 Value::LogicalArray(logical) => {
260 let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
261 if tensor.data.len() != 1 {
262 return Err(polyint_error(
263 "polyint: constant of integration must be a scalar",
264 ));
265 }
266 Ok(Complex64::new(tensor.data[0], 0.0))
267 }
268 other => Err(polyint_error(format!(
269 "polyint: constant of integration must be numeric, got {:?}",
270 other
271 ))),
272 }
273}
274
275fn ensure_vector_shape(shape: &[usize]) -> BuiltinResult<()> {
276 let non_unit = shape.iter().filter(|&&dim| dim > 1).count();
277 if non_unit <= 1 {
278 Ok(())
279 } else {
280 Err(polyint_error("polyint: coefficients must form a vector"))
281 }
282}
283
284fn orientation_from_shape(shape: &[usize]) -> Orientation {
285 for (idx, &dim) in shape.iter().enumerate() {
286 if dim != 1 {
287 return match idx {
288 0 => Orientation::Column,
289 1 => Orientation::Row,
290 _ => Orientation::Column,
291 };
292 }
293 }
294 Orientation::Scalar
295}
296
/// Host-side polynomial: coefficients in descending powers plus the input's
/// vector orientation, preserved so the output matches the input layout.
#[derive(Clone)]
struct Polynomial {
    // Coefficients ordered highest power first (MATLAB convention).
    coeffs: Vec<Complex64>,
    // Row/column/scalar orientation of the original input.
    orientation: Orientation,
}
302
/// Vector orientation of the coefficient input, used to shape the result.
#[derive(Clone, Copy)]
enum Orientation {
    /// Scalar input; the integrated result is emitted as a row vector.
    Scalar,
    /// Row-vector input (`1 x n`).
    Row,
    /// Column-vector input (`n x 1`).
    Column,
}
309
310impl Orientation {
311 fn shape_for_len(self, len: usize) -> Vec<usize> {
312 if len <= 1 {
313 return vec![1, 1];
314 }
315 match self {
316 Orientation::Scalar | Orientation::Row => vec![1, len],
317 Orientation::Column => vec![len, 1],
318 }
319 }
320}
321
/// Unit tests covering host integration, orientation handling, error paths,
/// and the GPU round-trip / fallback behaviour of `polyint`.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::LogicalArray;

    // Assert that a runtime error's message mentions `needle`.
    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
        assert!(
            err.message().contains(needle),
            "expected error containing '{needle}', got '{}'",
            err.message()
        );
    }

    // polyint([3 -2 5 7]) with an implicit zero constant of integration.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_polynomial_without_constant() {
        let tensor = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 5]);
                let expected = [0.75, -2.0 / 3.0, 2.5, 7.0, 0.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // An explicit real constant lands in the trailing coefficient.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_with_constant() {
        let tensor = Tensor::new(vec![4.0, 0.0, -8.0], vec![1, 3]).unwrap();
        let args = vec![Value::Num(3.0)];
        let result = polyint_builtin(Value::Tensor(tensor), args).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                let expected = [4.0 / 3.0, 0.0, -8.0, 3.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // A scalar polynomial integrates to a 1x2 row vector.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_scalar_value() {
        let result = polyint_builtin(Value::Num(5.0), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert!((t.data[0] - 5.0).abs() < 1e-12);
                assert!(t.data[1].abs() < 1e-12);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Logical coefficient arrays are widened to 0.0/1.0 before integrating.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_logical_coefficients() {
        let logical = LogicalArray::new(vec![1, 0, 1], vec![1, 3]).unwrap();
        let result =
            polyint_builtin(Value::LogicalArray(logical), Vec::new()).expect("polyint logical");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                let expected = [1.0 / 3.0, 0.0, 1.0, 0.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Column-vector input yields a column-vector result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn preserves_column_vector_orientation() {
        let tensor = Tensor::new(vec![2.0, 0.0, -6.0], vec![3, 1]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                let expected = [2.0 / 3.0, 0.0, -6.0, 0.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected column tensor, got {other:?}"),
        }
    }

    // Complex coefficients plus a complex constant stay complex.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_complex_coefficients() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (0.0, 4.0)], vec![1, 3]).unwrap();
        let args = vec![Value::Complex(0.0, -1.0)];
        let result = polyint_builtin(Value::ComplexTensor(tensor), args).expect("polyint");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                let expected = [(1.0 / 3.0, 2.0 / 3.0), (-1.5, 0.0), (0.0, 4.0), (0.0, -1.0)];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|((lre, lim), (rre, rim))| {
                        (lre - rre).abs() < 1e-12 && (lim - rim).abs() < 1e-12
                    }));
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Matrix input is not a coefficient vector.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_matrix_coefficients() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let err = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect_err("expected error");
        assert_error_contains(err, "vector");
    }

    // The constant of integration must have exactly one element.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_non_scalar_constant() {
        let coeffs = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
        let constant = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let err = polyint_builtin(Value::Tensor(coeffs), vec![Value::Tensor(constant)])
            .expect_err("expected error");
        assert_error_contains(err, "constant of integration must be a scalar");
    }

    // More than one trailing argument is rejected up front.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_excess_arguments() {
        let tensor = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let err = polyint_builtin(
            Value::Tensor(tensor),
            vec![Value::Num(1.0), Value::Num(2.0)],
        )
        .expect_err("expected error");
        assert_error_contains(err, "too many input arguments");
    }

    // An empty coefficient vector integrates to the (zero) constant.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_empty_input_as_zero_polynomial() {
        let tensor = Tensor::new(vec![], vec![1, 0]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Num(v) => assert!(v.abs() < 1e-12),
            Value::Tensor(t) => {
                assert_eq!(t.data.len(), 1);
                assert!(t.data[0].abs() < 1e-12);
            }
            other => panic!("expected numeric result, got {other:?}"),
        }
    }

    // Empty input plus a complex constant yields just that constant.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_input_with_constant() {
        let tensor = Tensor::new(vec![], vec![1, 0]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), vec![Value::Complex(1.5, -2.0)])
            .expect("polyint");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert_eq!(t.data.len(), 1);
                let (re, im) = t.data[0];
                assert!((re - 1.5).abs() < 1e-12);
                assert!((im + 2.0).abs() < 1e-12);
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    // Device input with a usable provider stays on the device end-to-end.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint");
            match result {
                Value::GpuTensor(handle) => {
                    let gathered = test_support::gather(Value::GpuTensor(handle)).expect("gather");
                    assert_eq!(gathered.shape, vec![1, 4]);
                    let expected = [1.0 / 3.0, -2.0, 6.0, 0.0];
                    assert!(gathered
                        .data
                        .iter()
                        .zip(expected.iter())
                        .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
                }
                other => panic!("expected GPU tensor result, got {other:?}"),
            }
        });
    }

    // A complex constant forces the host fallback even for a GPU input,
    // and the result is left on the host as a complex tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_complex_constant_falls_back_to_host() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = polyint_builtin(Value::GpuTensor(handle), vec![Value::Complex(0.0, 2.0)])
                .expect("polyint");
            match result {
                Value::ComplexTensor(ct) => {
                    assert_eq!(ct.shape, vec![1, 3]);
                    let expected = [(0.5, 0.0), (0.0, 0.0), (0.0, 2.0)];
                    assert!(ct
                        .data
                        .iter()
                        .zip(expected.iter())
                        .all(|((lre, lim), (rre, rim))| {
                            (lre - rre).abs() < 1e-12 && (lim - rim).abs() < 1e-12
                        }));
                }
                other => panic!("expected complex tensor fall-back, got {other:?}"),
            }
        });
    }

    // A GPU-resident scalar constant is gathered and used transparently.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_with_gpu_constant() {
        test_support::with_test_provider(|provider| {
            let coeffs = Tensor::new(vec![2.0, 0.0], vec![1, 2]).unwrap();
            let coeff_view = HostTensorView {
                data: &coeffs.data,
                shape: &coeffs.shape,
            };
            let coeff_handle = provider.upload(&coeff_view).expect("upload coeffs");
            let constant = Tensor::new(vec![3.0], vec![1, 1]).unwrap();
            let constant_view = HostTensorView {
                data: &constant.data,
                shape: &constant.shape,
            };
            let constant_handle = provider.upload(&constant_view).expect("upload constant");
            let result = polyint_builtin(
                Value::GpuTensor(coeff_handle),
                vec![Value::GpuTensor(constant_handle)],
            )
            .expect("polyint");
            match result {
                Value::GpuTensor(handle) => {
                    let gathered =
                        test_support::gather(Value::GpuTensor(handle)).expect("gather result");
                    assert_eq!(gathered.shape, vec![1, 3]);
                    let expected = [1.0, 0.0, 3.0];
                    assert!(gathered
                        .data
                        .iter()
                        .zip(expected.iter())
                        .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
                }
                other => panic!("expected gpu tensor result, got {other:?}"),
            }
        });
    }

    // Cross-check the wgpu backend against the CPU implementation, with a
    // tolerance chosen from the provider's floating-point precision.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn polyint_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let tensor = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_value = polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint gpu");
        let gathered = test_support::gather(gpu_value).expect("gather");
        let cpu_value =
            polyint_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("polyint cpu");
        let expected = match cpu_value {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected cpu result {other:?}"),
        };
        assert_eq!(gathered.shape, expected.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        gathered
            .data
            .iter()
            .zip(expected.data.iter())
            .for_each(|(lhs, rhs)| assert!((lhs - rhs).abs() < tol));
    }

    // Synchronous wrapper so tests can call the async builtin directly.
    fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::polyint_builtin(coeffs, rest))
    }
}