1use log::{trace, warn};
4use num_complex::Complex64;
5use runmat_accelerate_api::HostTensorView;
6use runmat_builtins::{
7 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9 ComplexTensor, Tensor, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18use crate::builtins::math::poly::type_resolvers::polyint_type;
19use crate::dispatcher;
20use crate::{build_runtime_error, BuiltinResult, RuntimeError};
21
22const EPS: f64 = 1.0e-12;
23const BUILTIN_NAME: &str = "polyint";
24
25const POLYINT_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
26 name: "q",
27 ty: BuiltinParamType::Any,
28 arity: BuiltinParamArity::Required,
29 default: None,
30 description: "Integrated polynomial coefficient vector.",
31}];
32
33const POLYINT_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
34 name: "p",
35 ty: BuiltinParamType::Any,
36 arity: BuiltinParamArity::Required,
37 default: None,
38 description: "Polynomial coefficient vector.",
39}];
40
41const POLYINT_INPUTS_WITH_K: [BuiltinParamDescriptor; 2] = [
42 BuiltinParamDescriptor {
43 name: "p",
44 ty: BuiltinParamType::Any,
45 arity: BuiltinParamArity::Required,
46 default: None,
47 description: "Polynomial coefficient vector.",
48 },
49 BuiltinParamDescriptor {
50 name: "k",
51 ty: BuiltinParamType::Any,
52 arity: BuiltinParamArity::Optional,
53 default: None,
54 description: "Constant of integration.",
55 },
56];
57
58const POLYINT_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
59 BuiltinSignatureDescriptor {
60 label: "q = polyint(p)",
61 inputs: &POLYINT_INPUTS,
62 outputs: &POLYINT_OUTPUT,
63 },
64 BuiltinSignatureDescriptor {
65 label: "q = polyint(p, k)",
66 inputs: &POLYINT_INPUTS_WITH_K,
67 outputs: &POLYINT_OUTPUT,
68 },
69];
70
71const POLYINT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
72 code: "RM.POLYINT.INVALID_ARGUMENT",
73 identifier: Some("RunMat:polyint:InvalidArgument"),
74 when: "Input arity or integration-constant argument is malformed.",
75 message: "polyint: invalid argument",
76};
77
78const POLYINT_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
79 code: "RM.POLYINT.INVALID_INPUT",
80 identifier: Some("RunMat:polyint:InvalidInput"),
81 when: "Input polynomial cannot be interpreted as a numeric coefficient vector.",
82 message: "polyint: invalid input",
83};
84
85const POLYINT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
86 code: "RM.POLYINT.INTERNAL",
87 identifier: Some("RunMat:polyint:Internal"),
88 when: "Runtime fails while building output tensors or provider fallback paths.",
89 message: "polyint: internal runtime failure",
90};
91
92const POLYINT_ERRORS: [BuiltinErrorDescriptor; 3] = [
93 POLYINT_ERROR_INVALID_ARGUMENT,
94 POLYINT_ERROR_INVALID_INPUT,
95 POLYINT_ERROR_INTERNAL,
96];
97
98pub const POLYINT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
99 signatures: &POLYINT_SIGNATURES,
100 output_mode: BuiltinOutputMode::Fixed,
101 completion_policy: BuiltinCompletionPolicy::Public,
102 errors: &POLYINT_ERRORS,
103};
104
// GPU registration for `polyint`: a custom `polyint` provider hook for F32/F64 data.
// `ResidencyPolicy::NewHandle` presumably means the hook returns a fresh device handle
// rather than mutating its input (TODO confirm against the accelerate API).
// The host code only invokes the hook when the constant of integration is real;
// complex work falls back to the host path (see `notes`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::polyint")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "polyint",
    op_kind: GpuOpKind::Custom("polynomial-integral"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("polyint")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers implement the polyint hook for real coefficient vectors; complex inputs fall back to the host.",
};
120
121fn polyint_error(message: impl Into<String>) -> RuntimeError {
122 polyint_error_with(message, &POLYINT_ERROR_INVALID_INPUT)
123}
124
125fn polyint_argument_error(message: impl Into<String>) -> RuntimeError {
126 polyint_error_with(message, &POLYINT_ERROR_INVALID_ARGUMENT)
127}
128
129fn polyint_error_with(
130 message: impl Into<String>,
131 error: &'static BuiltinErrorDescriptor,
132) -> RuntimeError {
133 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
134 if let Some(identifier) = error.identifier {
135 builder = builder.with_identifier(identifier);
136 }
137 builder.build()
138}
139
// Fusion metadata for `polyint`. Both `elementwise` and `reduction` are `None`, so the
// builtin declares no kernel that could be fused with neighbouring operations.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::polyint")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "polyint",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Symbolic operation on coefficient vectors; fusion does not apply.",
};
150
151#[runtime_builtin(
152 name = "polyint",
153 category = "math/poly",
154 summary = "Integrate polynomial coefficient vectors and append a constant of integration.",
155 keywords = "polyint,polynomial,integral,antiderivative",
156 type_resolver(polyint_type),
157 descriptor(crate::builtins::math::poly::polyint::POLYINT_DESCRIPTOR),
158 builtin_path = "crate::builtins::math::poly::polyint"
159)]
160async fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
161 if rest.len() > 1 {
162 return Err(polyint_argument_error("polyint: too many input arguments"));
163 }
164
165 let constant = match rest.into_iter().next() {
166 Some(value) => parse_constant(value).await?,
167 None => Complex64::new(0.0, 0.0),
168 };
169
170 if let Value::GpuTensor(handle) = &coeffs {
171 if let Some(device_result) = try_polyint_gpu(handle, constant)? {
172 return Ok(Value::GpuTensor(device_result));
173 }
174 }
175
176 let was_gpu = matches!(coeffs, Value::GpuTensor(_));
177 polyint_host_value(coeffs, constant, was_gpu).await
178}
179
180async fn polyint_host_value(
181 coeffs: Value,
182 constant: Complex64,
183 was_gpu: bool,
184) -> BuiltinResult<Value> {
185 let polynomial = parse_polynomial(coeffs).await?;
186 let mut integrated = integrate_coeffs(&polynomial.coeffs);
187 if integrated.is_empty() {
188 integrated.push(constant);
189 } else if let Some(last) = integrated.last_mut() {
190 *last += constant;
191 }
192 let value = coeffs_to_value(&integrated, polynomial.orientation)?;
193 maybe_return_gpu(value, was_gpu)
194}
195
196fn try_polyint_gpu(
197 handle: &runmat_accelerate_api::GpuTensorHandle,
198 constant: Complex64,
199) -> BuiltinResult<Option<runmat_accelerate_api::GpuTensorHandle>> {
200 if constant.im.abs() > EPS {
201 return Ok(None);
202 }
203 ensure_vector_shape(&handle.shape)?;
204 let Some(provider) = runmat_accelerate_api::provider() else {
205 return Ok(None);
206 };
207 match provider.polyint(handle, constant.re) {
208 Ok(result) => Ok(Some(result)),
209 Err(err) => {
210 trace!("polyint: provider hook unavailable, falling back to host: {err}");
211 Ok(None)
212 }
213 }
214}
215
216fn integrate_coeffs(coeffs: &[Complex64]) -> Vec<Complex64> {
217 if coeffs.is_empty() {
218 return Vec::new();
219 }
220 let mut result = Vec::with_capacity(coeffs.len() + 1);
221 for (idx, coeff) in coeffs.iter().enumerate() {
222 let power = (coeffs.len() - idx) as f64;
223 if power <= 0.0 {
224 result.push(Complex64::new(0.0, 0.0));
225 } else {
226 result.push(*coeff / Complex64::new(power, 0.0));
227 }
228 }
229 result.push(Complex64::new(0.0, 0.0));
230 result
231}
232
233fn maybe_return_gpu(value: Value, was_gpu: bool) -> BuiltinResult<Value> {
234 if !was_gpu {
235 return Ok(value);
236 }
237 match value {
238 Value::Tensor(tensor) => {
239 if let Some(provider) = runmat_accelerate_api::provider() {
240 let view = HostTensorView {
241 data: &tensor.data,
242 shape: &tensor.shape,
243 };
244 match provider.upload(&view) {
245 Ok(handle) => return Ok(Value::GpuTensor(handle)),
246 Err(err) => {
247 warn!("polyint: provider upload failed, keeping result on host: {err}");
248 }
249 }
250 } else {
251 trace!("polyint: no provider available to re-upload result");
252 }
253 Ok(Value::Tensor(tensor))
254 }
255 other => Ok(other),
256 }
257}
258
259fn coeffs_to_value(coeffs: &[Complex64], orientation: Orientation) -> BuiltinResult<Value> {
260 if coeffs.iter().all(|c| c.im.abs() <= EPS) {
261 let data: Vec<f64> = coeffs.iter().map(|c| c.re).collect();
262 let shape = orientation.shape_for_len(data.len());
263 let tensor =
264 Tensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
265 Ok(tensor::tensor_into_value(tensor))
266 } else {
267 let data: Vec<(f64, f64)> = coeffs.iter().map(|c| (c.re, c.im)).collect();
268 let shape = orientation.shape_for_len(data.len());
269 let tensor =
270 ComplexTensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
271 Ok(Value::ComplexTensor(tensor))
272 }
273}
274
275async fn parse_polynomial(value: Value) -> BuiltinResult<Polynomial> {
276 let gathered = dispatcher::gather_if_needed_async(&value).await?;
277 match gathered {
278 Value::Tensor(tensor) => parse_tensor_coeffs(&tensor),
279 Value::ComplexTensor(tensor) => parse_complex_tensor_coeffs(&tensor),
280 Value::LogicalArray(logical) => {
281 let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
282 parse_tensor_coeffs(&tensor)
283 }
284 Value::Num(n) => Ok(Polynomial {
285 coeffs: vec![Complex64::new(n, 0.0)],
286 orientation: Orientation::Scalar,
287 }),
288 Value::Int(i) => Ok(Polynomial {
289 coeffs: vec![Complex64::new(i.to_f64(), 0.0)],
290 orientation: Orientation::Scalar,
291 }),
292 Value::Bool(b) => Ok(Polynomial {
293 coeffs: vec![Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)],
294 orientation: Orientation::Scalar,
295 }),
296 Value::Complex(re, im) => Ok(Polynomial {
297 coeffs: vec![Complex64::new(re, im)],
298 orientation: Orientation::Scalar,
299 }),
300 other => Err(polyint_error(format!(
301 "polyint: expected a numeric coefficient vector, got {:?}",
302 other
303 ))),
304 }
305}
306
307fn parse_tensor_coeffs(tensor: &Tensor) -> BuiltinResult<Polynomial> {
308 ensure_vector_shape(&tensor.shape)?;
309 let orientation = orientation_from_shape(&tensor.shape);
310 Ok(Polynomial {
311 coeffs: tensor
312 .data
313 .iter()
314 .map(|&v| Complex64::new(v, 0.0))
315 .collect(),
316 orientation,
317 })
318}
319
320fn parse_complex_tensor_coeffs(tensor: &ComplexTensor) -> BuiltinResult<Polynomial> {
321 ensure_vector_shape(&tensor.shape)?;
322 let orientation = orientation_from_shape(&tensor.shape);
323 Ok(Polynomial {
324 coeffs: tensor
325 .data
326 .iter()
327 .map(|&(re, im)| Complex64::new(re, im))
328 .collect(),
329 orientation,
330 })
331}
332
333async fn parse_constant(value: Value) -> BuiltinResult<Complex64> {
334 let gathered = dispatcher::gather_if_needed_async(&value).await?;
335 match gathered {
336 Value::Tensor(tensor) => {
337 if tensor.data.len() != 1 {
338 return Err(polyint_error(
339 "polyint: constant of integration must be a scalar",
340 ));
341 }
342 Ok(Complex64::new(tensor.data[0], 0.0))
343 }
344 Value::ComplexTensor(tensor) => {
345 if tensor.data.len() != 1 {
346 return Err(polyint_error(
347 "polyint: constant of integration must be a scalar",
348 ));
349 }
350 let (re, im) = tensor.data[0];
351 Ok(Complex64::new(re, im))
352 }
353 Value::Num(n) => Ok(Complex64::new(n, 0.0)),
354 Value::Int(i) => Ok(Complex64::new(i.to_f64(), 0.0)),
355 Value::Bool(b) => Ok(Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)),
356 Value::Complex(re, im) => Ok(Complex64::new(re, im)),
357 Value::LogicalArray(logical) => {
358 let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
359 if tensor.data.len() != 1 {
360 return Err(polyint_error(
361 "polyint: constant of integration must be a scalar",
362 ));
363 }
364 Ok(Complex64::new(tensor.data[0], 0.0))
365 }
366 other => Err(polyint_error(format!(
367 "polyint: constant of integration must be numeric, got {:?}",
368 other
369 ))),
370 }
371}
372
373fn ensure_vector_shape(shape: &[usize]) -> BuiltinResult<()> {
374 let non_unit = shape.iter().filter(|&&dim| dim > 1).count();
375 if non_unit <= 1 {
376 Ok(())
377 } else {
378 Err(polyint_error("polyint: coefficients must form a vector"))
379 }
380}
381
382fn orientation_from_shape(shape: &[usize]) -> Orientation {
383 for (idx, &dim) in shape.iter().enumerate() {
384 if dim != 1 {
385 return match idx {
386 0 => Orientation::Column,
387 1 => Orientation::Row,
388 _ => Orientation::Column,
389 };
390 }
391 }
392 Orientation::Scalar
393}
394
/// Coefficient vector parsed from the `p` argument.
///
/// Coefficients are stored as complex numbers so that real, logical and complex inputs share
/// one integration path.
#[derive(Clone)]
struct Polynomial {
    /// Coefficients ordered from the highest power down to the constant term
    /// (`integrate_coeffs` divides entry `idx` by `len - idx`).
    coeffs: Vec<Complex64>,
    /// Layout of the input, reused to shape the integrated output.
    orientation: Orientation,
}
400
/// Layout of the input coefficient vector, used to give the output the same layout.
#[derive(Clone, Copy)]
enum Orientation {
    /// Every dimension has extent 1; the output is laid out as a row.
    Scalar,
    /// 1×N row vector.
    Row,
    /// N×1 column vector.
    Column,
}

impl Orientation {
    /// Returns the 2-D shape for an output holding `len` coefficients.
    ///
    /// Zero or one coefficient always yields `[1, 1]`. Otherwise a `Column` yields
    /// `[len, 1]`, and `Row` and `Scalar` both yield `[1, len]`.
    fn shape_for_len(self, len: usize) -> Vec<usize> {
        match (self, len) {
            (_, 0 | 1) => vec![1, 1],
            (Orientation::Column, n) => vec![n, 1],
            (Orientation::Scalar | Orientation::Row, n) => vec![1, n],
        }
    }
}
419
// Unit tests for `polyint`. They cover:
// - descriptor metadata,
// - host integration of real, logical, complex, scalar and empty inputs,
// - argument validation,
// - GPU residency and host fallback using the shared test provider,
// - an optional wgpu comparison test behind the `wgpu` feature.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::LogicalArray;

    // Asserts that `err`'s message contains `needle`, printing the actual message on failure.
    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
        assert!(
            err.message().contains(needle),
            "expected error containing '{needle}', got '{}'",
            err.message()
        );
    }

    // Both public call forms are advertised by the descriptor.
    #[test]
    fn polyint_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = POLYINT_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert!(labels.contains(&"q = polyint(p)"));
        assert!(labels.contains(&"q = polyint(p, k)"));
    }

    // Error codes are part of the public contract and must stay stable.
    #[test]
    fn polyint_descriptor_errors_have_stable_codes() {
        let codes: Vec<&str> = POLYINT_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert!(codes.contains(&"RM.POLYINT.INVALID_ARGUMENT"));
        assert!(codes.contains(&"RM.POLYINT.INVALID_INPUT"));
        assert!(codes.contains(&"RM.POLYINT.INTERNAL"));
    }

    // Row input without `k`: each coefficient is divided by its new power and a 0 is appended.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_polynomial_without_constant() {
        let tensor = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 5]);
                let expected = [0.75, -2.0 / 3.0, 2.5, 7.0, 0.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // A real `k` becomes the trailing constant term.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_with_constant() {
        let tensor = Tensor::new(vec![4.0, 0.0, -8.0], vec![1, 3]).unwrap();
        let args = vec![Value::Num(3.0)];
        let result = polyint_builtin(Value::Tensor(tensor), args).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                let expected = [4.0 / 3.0, 0.0, -8.0, 3.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // A scalar coefficient integrates to a 1x2 row `[c, 0]`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_scalar_value() {
        let result = polyint_builtin(Value::Num(5.0), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert!((t.data[0] - 5.0).abs() < 1e-12);
                assert!(t.data[1].abs() < 1e-12);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Logical inputs are treated as 0/1 numeric coefficients.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_logical_coefficients() {
        let logical = LogicalArray::new(vec![1, 0, 1], vec![1, 3]).unwrap();
        let result =
            polyint_builtin(Value::LogicalArray(logical), Vec::new()).expect("polyint logical");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                let expected = [1.0 / 3.0, 0.0, 1.0, 0.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Column-vector input produces a column-vector output.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn preserves_column_vector_orientation() {
        let tensor = Tensor::new(vec![2.0, 0.0, -6.0], vec![3, 1]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                let expected = [2.0 / 3.0, 0.0, -6.0, 0.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected column tensor, got {other:?}"),
        }
    }

    // Complex coefficients with a complex `k` produce a complex result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_complex_coefficients() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (0.0, 4.0)], vec![1, 3]).unwrap();
        let args = vec![Value::Complex(0.0, -1.0)];
        let result = polyint_builtin(Value::ComplexTensor(tensor), args).expect("polyint");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                let expected = [(1.0 / 3.0, 2.0 / 3.0), (-1.5, 0.0), (0.0, 4.0), (0.0, -1.0)];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|((lre, lim), (rre, rim))| {
                        (lre - rre).abs() < 1e-12 && (lim - rim).abs() < 1e-12
                    }));
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Matrices are not coefficient vectors.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_matrix_coefficients() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let err = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect_err("expected error");
        assert_error_contains(err, "vector");
    }

    // `k` must be a single value.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_non_scalar_constant() {
        let coeffs = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
        let constant = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let err = polyint_builtin(Value::Tensor(coeffs), vec![Value::Tensor(constant)])
            .expect_err("expected error");
        assert_error_contains(err, "constant of integration must be a scalar");
    }

    // More than two inputs is an arity error tagged with the InvalidArgument identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_excess_arguments() {
        let tensor = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let err = polyint_builtin(
            Value::Tensor(tensor),
            vec![Value::Num(1.0), Value::Num(2.0)],
        )
        .expect_err("expected error");
        assert_eq!(err.identifier(), POLYINT_ERROR_INVALID_ARGUMENT.identifier);
        assert_error_contains(err, "too many input arguments");
    }

    // An empty input integrates to the constant alone (0 by default). The result may be
    // returned either as a scalar or as a one-element tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_empty_input_as_zero_polynomial() {
        let tensor = Tensor::new(vec![], vec![1, 0]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Num(v) => assert!(v.abs() < 1e-12),
            Value::Tensor(t) => {
                assert_eq!(t.data.len(), 1);
                assert!(t.data[0].abs() < 1e-12);
            }
            other => panic!("expected numeric result, got {other:?}"),
        }
    }

    // An empty input with a complex `k` yields a 1x1 complex tensor holding `k`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_input_with_constant() {
        let tensor = Tensor::new(vec![], vec![1, 0]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), vec![Value::Complex(1.5, -2.0)])
            .expect("polyint");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert_eq!(t.data.len(), 1);
                let (re, im) = t.data[0];
                assert!((re - 1.5).abs() < 1e-12);
                assert!((im + 2.0).abs() < 1e-12);
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    // A GPU input with the test provider stays on the device and matches the host math.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint");
            match result {
                Value::GpuTensor(handle) => {
                    let gathered = test_support::gather(Value::GpuTensor(handle)).expect("gather");
                    assert_eq!(gathered.shape, vec![1, 4]);
                    let expected = [1.0 / 3.0, -2.0, 6.0, 0.0];
                    assert!(gathered
                        .data
                        .iter()
                        .zip(expected.iter())
                        .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
                }
                other => panic!("expected GPU tensor result, got {other:?}"),
            }
        });
    }

    // A complex `k` bypasses the provider hook; the complex result is returned on the host.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_complex_constant_falls_back_to_host() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = polyint_builtin(Value::GpuTensor(handle), vec![Value::Complex(0.0, 2.0)])
                .expect("polyint");
            match result {
                Value::ComplexTensor(ct) => {
                    assert_eq!(ct.shape, vec![1, 3]);
                    let expected = [(0.5, 0.0), (0.0, 0.0), (0.0, 2.0)];
                    assert!(ct
                        .data
                        .iter()
                        .zip(expected.iter())
                        .all(|((lre, lim), (rre, rim))| {
                            (lre - rre).abs() < 1e-12 && (lim - rim).abs() < 1e-12
                        }));
                }
                other => panic!("expected complex tensor fall-back, got {other:?}"),
            }
        });
    }

    // A GPU-resident `k` is gathered and applied; the result stays on the device.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_with_gpu_constant() {
        test_support::with_test_provider(|provider| {
            let coeffs = Tensor::new(vec![2.0, 0.0], vec![1, 2]).unwrap();
            let coeff_view = HostTensorView {
                data: &coeffs.data,
                shape: &coeffs.shape,
            };
            let coeff_handle = provider.upload(&coeff_view).expect("upload coeffs");
            let constant = Tensor::new(vec![3.0], vec![1, 1]).unwrap();
            let constant_view = HostTensorView {
                data: &constant.data,
                shape: &constant.shape,
            };
            let constant_handle = provider.upload(&constant_view).expect("upload constant");
            let result = polyint_builtin(
                Value::GpuTensor(coeff_handle),
                vec![Value::GpuTensor(constant_handle)],
            )
            .expect("polyint");
            match result {
                Value::GpuTensor(handle) => {
                    let gathered =
                        test_support::gather(Value::GpuTensor(handle)).expect("gather result");
                    assert_eq!(gathered.shape, vec![1, 3]);
                    let expected = [1.0, 0.0, 3.0];
                    assert!(gathered
                        .data
                        .iter()
                        .zip(expected.iter())
                        .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
                }
                other => panic!("expected gpu tensor result, got {other:?}"),
            }
        });
    }

    // The wgpu backend agrees with the CPU path, with tolerance chosen by provider precision.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn polyint_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let tensor = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_value = polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint gpu");
        let gathered = test_support::gather(gpu_value).expect("gather");
        let cpu_value =
            polyint_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("polyint cpu");
        let expected = match cpu_value {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected cpu result {other:?}"),
        };
        assert_eq!(gathered.shape, expected.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        gathered
            .data
            .iter()
            .zip(expected.data.iter())
            .for_each(|(lhs, rhs)| assert!((lhs - rhs).abs() < tol));
    }

    // Synchronous wrapper that drives the async builtin to completion for these tests.
    // It shadows the glob-imported `super::polyint_builtin`.
    fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::polyint_builtin(coeffs, rest))
    }
}