1use log::{trace, warn};
4use num_complex::Complex64;
5use runmat_accelerate_api::HostTensorView;
6use runmat_builtins::{
7 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9 ComplexTensor, Tensor, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::gpu_helpers;
14use crate::builtins::common::spec::{
15 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
16 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
17};
18use crate::builtins::common::tensor;
19use crate::builtins::math::poly::type_resolvers::polyint_type;
20use crate::dispatcher;
21use crate::{build_runtime_error, BuiltinResult, RuntimeError};
22
/// Absolute tolerance under which an imaginary part is treated as zero when deciding
/// whether a result is real or complex.
const EPS: f64 = 1e-12;
/// Builtin name attached to every runtime error raised from this module.
const BUILTIN_NAME: &str = "polyint";
25
26const POLYINT_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
27 name: "q",
28 ty: BuiltinParamType::Any,
29 arity: BuiltinParamArity::Required,
30 default: None,
31 description: "Integrated polynomial coefficient vector.",
32}];
33
34const POLYINT_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
35 name: "p",
36 ty: BuiltinParamType::Any,
37 arity: BuiltinParamArity::Required,
38 default: None,
39 description: "Polynomial coefficient vector.",
40}];
41
42const POLYINT_INPUTS_WITH_K: [BuiltinParamDescriptor; 2] = [
43 BuiltinParamDescriptor {
44 name: "p",
45 ty: BuiltinParamType::Any,
46 arity: BuiltinParamArity::Required,
47 default: None,
48 description: "Polynomial coefficient vector.",
49 },
50 BuiltinParamDescriptor {
51 name: "k",
52 ty: BuiltinParamType::Any,
53 arity: BuiltinParamArity::Optional,
54 default: None,
55 description: "Constant of integration.",
56 },
57];
58
59const POLYINT_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
60 BuiltinSignatureDescriptor {
61 label: "q = polyint(p)",
62 inputs: &POLYINT_INPUTS,
63 outputs: &POLYINT_OUTPUT,
64 },
65 BuiltinSignatureDescriptor {
66 label: "q = polyint(p, k)",
67 inputs: &POLYINT_INPUTS_WITH_K,
68 outputs: &POLYINT_OUTPUT,
69 },
70];
71
72const POLYINT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
73 code: "RM.POLYINT.INVALID_ARGUMENT",
74 identifier: Some("RunMat:polyint:InvalidArgument"),
75 when: "Input arity or integration-constant argument is malformed.",
76 message: "polyint: invalid argument",
77};
78
79const POLYINT_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
80 code: "RM.POLYINT.INVALID_INPUT",
81 identifier: Some("RunMat:polyint:InvalidInput"),
82 when: "Input polynomial cannot be interpreted as a numeric coefficient vector.",
83 message: "polyint: invalid input",
84};
85
86const POLYINT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
87 code: "RM.POLYINT.INTERNAL",
88 identifier: Some("RunMat:polyint:Internal"),
89 when: "Runtime fails while building output tensors or provider fallback paths.",
90 message: "polyint: internal runtime failure",
91};
92
93const POLYINT_ERRORS: [BuiltinErrorDescriptor; 3] = [
94 POLYINT_ERROR_INVALID_ARGUMENT,
95 POLYINT_ERROR_INVALID_INPUT,
96 POLYINT_ERROR_INTERNAL,
97];
98
99pub const POLYINT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
100 signatures: &POLYINT_SIGNATURES,
101 output_mode: BuiltinOutputMode::Fixed,
102 completion_policy: BuiltinCompletionPolicy::Public,
103 errors: &POLYINT_ERRORS,
104};
105
// GPU capability metadata for `polyint`, registered through the attribute macro below.
// It declares a custom provider hook named "polyint", F32/F64 support, and no
// broadcasting; the `notes` string records that complex integration constants are
// handled on the host and re-uploaded.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::polyint")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "polyint",
    op_kind: GpuOpKind::Custom("polynomial-integral"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("polyint")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers implement the polyint hook for real and complex-interleaved coefficient vectors; complex integration constants fall back to host integration and re-upload.",
};
121
122fn polyint_error(message: impl Into<String>) -> RuntimeError {
123 polyint_error_with(message, &POLYINT_ERROR_INVALID_INPUT)
124}
125
126fn polyint_argument_error(message: impl Into<String>) -> RuntimeError {
127 polyint_error_with(message, &POLYINT_ERROR_INVALID_ARGUMENT)
128}
129
130fn polyint_error_with(
131 message: impl Into<String>,
132 error: &'static BuiltinErrorDescriptor,
133) -> RuntimeError {
134 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
135 if let Some(identifier) = error.identifier {
136 builder = builder.with_identifier(identifier);
137 }
138 builder.build()
139}
140
// Fusion metadata for `polyint`. The builtin rewrites a whole coefficient vector (the
// result is one element longer than the input), so it declares neither an elementwise
// nor a reduction kernel.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::polyint")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "polyint",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Symbolic operation on coefficient vectors; fusion does not apply.",
};
151
152#[runtime_builtin(
153 name = "polyint",
154 category = "math/poly",
155 summary = "Integrate polynomial coefficient vectors and append a constant of integration.",
156 keywords = "polyint,polynomial,integral,antiderivative",
157 type_resolver(polyint_type),
158 descriptor(crate::builtins::math::poly::polyint::POLYINT_DESCRIPTOR),
159 builtin_path = "crate::builtins::math::poly::polyint"
160)]
161async fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
162 if rest.len() > 1 {
163 return Err(polyint_argument_error("polyint: too many input arguments"));
164 }
165
166 let constant = match rest.into_iter().next() {
167 Some(value) => parse_constant(value).await?,
168 None => Complex64::new(0.0, 0.0),
169 };
170
171 if let Value::GpuTensor(handle) = &coeffs {
172 if let Some(device_result) = try_polyint_gpu(handle, constant)? {
173 return Ok(Value::GpuTensor(device_result));
174 }
175 }
176
177 let source_gpu = match &coeffs {
178 Value::GpuTensor(handle) => Some(handle.clone()),
179 _ => None,
180 };
181 polyint_host_value(coeffs, constant, source_gpu).await
182}
183
184async fn polyint_host_value(
185 coeffs: Value,
186 constant: Complex64,
187 source_gpu: Option<runmat_accelerate_api::GpuTensorHandle>,
188) -> BuiltinResult<Value> {
189 let polynomial = parse_polynomial(coeffs).await?;
190 let mut integrated = integrate_coeffs(&polynomial.coeffs);
191 if integrated.is_empty() {
192 integrated.push(constant);
193 } else if let Some(last) = integrated.last_mut() {
194 *last += constant;
195 }
196 let value = coeffs_to_value(&integrated, polynomial.orientation)?;
197 maybe_return_gpu(value, source_gpu.as_ref())
198}
199
200fn try_polyint_gpu(
201 handle: &runmat_accelerate_api::GpuTensorHandle,
202 constant: Complex64,
203) -> BuiltinResult<Option<runmat_accelerate_api::GpuTensorHandle>> {
204 if constant.im.abs() > EPS {
205 return Ok(None);
206 }
207 ensure_vector_shape(&handle.shape)?;
208 let Some(provider) =
209 runmat_accelerate_api::provider_for_handle(handle).or_else(runmat_accelerate_api::provider)
210 else {
211 return Ok(None);
212 };
213 match provider.polyint(handle, constant.re) {
214 Ok(result) => Ok(Some(result)),
215 Err(err) => {
216 trace!("polyint: provider hook unavailable, falling back to host: {err}");
217 Ok(None)
218 }
219 }
220}
221
222fn integrate_coeffs(coeffs: &[Complex64]) -> Vec<Complex64> {
223 if coeffs.is_empty() {
224 return Vec::new();
225 }
226 let mut result = Vec::with_capacity(coeffs.len() + 1);
227 for (idx, coeff) in coeffs.iter().enumerate() {
228 let power = (coeffs.len() - idx) as f64;
229 if power <= 0.0 {
230 result.push(Complex64::new(0.0, 0.0));
231 } else {
232 result.push(*coeff / Complex64::new(power, 0.0));
233 }
234 }
235 result.push(Complex64::new(0.0, 0.0));
236 result
237}
238
239fn maybe_return_gpu(
240 value: Value,
241 source_gpu: Option<&runmat_accelerate_api::GpuTensorHandle>,
242) -> BuiltinResult<Value> {
243 let Some(source_gpu) = source_gpu else {
244 return Ok(value);
245 };
246 let provider = runmat_accelerate_api::provider_for_handle(source_gpu);
247 match value {
248 Value::Tensor(tensor) => {
249 if let Some(provider) = provider {
250 let view = HostTensorView {
251 data: &tensor.data,
252 shape: &tensor.shape,
253 };
254 match provider.upload(&view) {
255 Ok(handle) => return Ok(Value::GpuTensor(handle)),
256 Err(err) => {
257 warn!("polyint: provider upload failed, keeping result on host: {err}");
258 }
259 }
260 } else {
261 trace!("polyint: no provider available to re-upload result");
262 }
263 Ok(Value::Tensor(tensor))
264 }
265 Value::ComplexTensor(tensor) => {
266 if let Some(provider) = provider {
267 match gpu_helpers::upload_complex_tensor(provider, &tensor) {
268 Ok(handle) => return Ok(gpu_helpers::complex_gpu_value(handle)),
269 Err(err) => {
270 warn!(
271 "polyint: provider complex upload failed, keeping result on host: {err}"
272 );
273 }
274 }
275 } else {
276 trace!("polyint: no provider available to re-upload complex result");
277 }
278 Ok(Value::ComplexTensor(tensor))
279 }
280 other => Ok(other),
281 }
282}
283
284fn coeffs_to_value(coeffs: &[Complex64], orientation: Orientation) -> BuiltinResult<Value> {
285 if coeffs.iter().all(|c| c.im.abs() <= EPS) {
286 let data: Vec<f64> = coeffs.iter().map(|c| c.re).collect();
287 let shape = orientation.shape_for_len(data.len());
288 let tensor =
289 Tensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
290 Ok(tensor::tensor_into_value(tensor))
291 } else {
292 let data: Vec<(f64, f64)> = coeffs.iter().map(|c| (c.re, c.im)).collect();
293 let shape = orientation.shape_for_len(data.len());
294 let tensor =
295 ComplexTensor::new(data, shape).map_err(|e| polyint_error(format!("polyint: {e}")))?;
296 Ok(Value::ComplexTensor(tensor))
297 }
298}
299
300async fn parse_polynomial(value: Value) -> BuiltinResult<Polynomial> {
301 let gathered = dispatcher::gather_if_needed_async(&value).await?;
302 match gathered {
303 Value::Tensor(tensor) => parse_tensor_coeffs(&tensor),
304 Value::ComplexTensor(tensor) => parse_complex_tensor_coeffs(&tensor),
305 Value::LogicalArray(logical) => {
306 let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
307 parse_tensor_coeffs(&tensor)
308 }
309 Value::Num(n) => Ok(Polynomial {
310 coeffs: vec![Complex64::new(n, 0.0)],
311 orientation: Orientation::Scalar,
312 }),
313 Value::Int(i) => Ok(Polynomial {
314 coeffs: vec![Complex64::new(i.to_f64(), 0.0)],
315 orientation: Orientation::Scalar,
316 }),
317 Value::Bool(b) => Ok(Polynomial {
318 coeffs: vec![Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)],
319 orientation: Orientation::Scalar,
320 }),
321 Value::Complex(re, im) => Ok(Polynomial {
322 coeffs: vec![Complex64::new(re, im)],
323 orientation: Orientation::Scalar,
324 }),
325 other => Err(polyint_error(format!(
326 "polyint: expected a numeric coefficient vector, got {:?}",
327 other
328 ))),
329 }
330}
331
332fn parse_tensor_coeffs(tensor: &Tensor) -> BuiltinResult<Polynomial> {
333 ensure_vector_shape(&tensor.shape)?;
334 let orientation = orientation_from_shape(&tensor.shape);
335 Ok(Polynomial {
336 coeffs: tensor
337 .data
338 .iter()
339 .map(|&v| Complex64::new(v, 0.0))
340 .collect(),
341 orientation,
342 })
343}
344
345fn parse_complex_tensor_coeffs(tensor: &ComplexTensor) -> BuiltinResult<Polynomial> {
346 ensure_vector_shape(&tensor.shape)?;
347 let orientation = orientation_from_shape(&tensor.shape);
348 Ok(Polynomial {
349 coeffs: tensor
350 .data
351 .iter()
352 .map(|&(re, im)| Complex64::new(re, im))
353 .collect(),
354 orientation,
355 })
356}
357
358async fn parse_constant(value: Value) -> BuiltinResult<Complex64> {
359 let gathered = dispatcher::gather_if_needed_async(&value).await?;
360 match gathered {
361 Value::Tensor(tensor) => {
362 if tensor.data.len() != 1 {
363 return Err(polyint_error(
364 "polyint: constant of integration must be a scalar",
365 ));
366 }
367 Ok(Complex64::new(tensor.data[0], 0.0))
368 }
369 Value::ComplexTensor(tensor) => {
370 if tensor.data.len() != 1 {
371 return Err(polyint_error(
372 "polyint: constant of integration must be a scalar",
373 ));
374 }
375 let (re, im) = tensor.data[0];
376 Ok(Complex64::new(re, im))
377 }
378 Value::Num(n) => Ok(Complex64::new(n, 0.0)),
379 Value::Int(i) => Ok(Complex64::new(i.to_f64(), 0.0)),
380 Value::Bool(b) => Ok(Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)),
381 Value::Complex(re, im) => Ok(Complex64::new(re, im)),
382 Value::LogicalArray(logical) => {
383 let tensor = tensor::logical_to_tensor(&logical).map_err(polyint_error)?;
384 if tensor.data.len() != 1 {
385 return Err(polyint_error(
386 "polyint: constant of integration must be a scalar",
387 ));
388 }
389 Ok(Complex64::new(tensor.data[0], 0.0))
390 }
391 other => Err(polyint_error(format!(
392 "polyint: constant of integration must be numeric, got {:?}",
393 other
394 ))),
395 }
396}
397
398fn ensure_vector_shape(shape: &[usize]) -> BuiltinResult<()> {
399 let non_unit = shape.iter().filter(|&&dim| dim > 1).count();
400 if non_unit <= 1 {
401 Ok(())
402 } else {
403 Err(polyint_error("polyint: coefficients must form a vector"))
404 }
405}
406
407fn orientation_from_shape(shape: &[usize]) -> Orientation {
408 for (idx, &dim) in shape.iter().enumerate() {
409 if dim != 1 {
410 return match idx {
411 0 => Orientation::Column,
412 1 => Orientation::Row,
413 _ => Orientation::Column,
414 };
415 }
416 }
417 Orientation::Scalar
418}
419
420#[derive(Clone)]
421struct Polynomial {
422 coeffs: Vec<Complex64>,
423 orientation: Orientation,
424}
425
/// Layout of an input coefficient vector, used to lay out the result the same way.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Orientation {
    /// Every input dimension was 1; multi-element results are laid out as rows.
    Scalar,
    /// Row-vector layout (`1 x n`).
    Row,
    /// Column-vector layout (`n x 1`).
    Column,
}

impl Orientation {
    /// Returns the tensor shape for a coefficient vector of `len` elements.
    ///
    /// A single element is always `1 x 1`. Otherwise:
    /// - `Scalar` and `Row` give `1 x len`;
    /// - `Column` gives `len x 1`.
    ///
    /// An empty vector keeps its orientation with a zero-length dimension (`1 x 0` or
    /// `0 x 1`). The element count of the returned shape therefore always equals `len`;
    /// previously `len == 0` produced `1 x 1`, which no 0-element tensor can satisfy.
    fn shape_for_len(self, len: usize) -> Vec<usize> {
        if len == 1 {
            return vec![1, 1];
        }
        match self {
            Orientation::Scalar | Orientation::Row => vec![1, len],
            Orientation::Column => vec![len, 1],
        }
    }
}
444
// Unit tests for `polyint`: descriptor metadata, host integration of real, logical,
// complex, scalar, and empty inputs, argument validation, and GPU residency through the
// test provider (plus wgpu comparisons when the `wgpu` feature is enabled).
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::gpu_helpers;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_builtins::LogicalArray;

    /// Panics unless `err`'s message contains `needle`, printing the full message.
    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
        assert!(
            err.message().contains(needle),
            "expected error containing '{needle}', got '{}'",
            err.message()
        );
    }

    // Both call forms must be listed in the descriptor.
    #[test]
    fn polyint_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = POLYINT_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert!(labels.contains(&"q = polyint(p)"));
        assert!(labels.contains(&"q = polyint(p, k)"));
    }

    // The published error codes must stay present in the descriptor.
    #[test]
    fn polyint_descriptor_errors_have_stable_codes() {
        let codes: Vec<&str> = POLYINT_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert!(codes.contains(&"RM.POLYINT.INVALID_ARGUMENT"));
        assert!(codes.contains(&"RM.POLYINT.INVALID_INPUT"));
        assert!(codes.contains(&"RM.POLYINT.INTERNAL"));
    }

    // [3 -2 5 7] integrates to [3/4, -2/3, 5/2, 7, 0] with the default zero constant.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_polynomial_without_constant() {
        let tensor = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 5]);
                let expected = [0.75, -2.0 / 3.0, 2.5, 7.0, 0.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // A supplied constant replaces the trailing zero coefficient.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_with_constant() {
        let tensor = Tensor::new(vec![4.0, 0.0, -8.0], vec![1, 3]).unwrap();
        let args = vec![Value::Num(3.0)];
        let result = polyint_builtin(Value::Tensor(tensor), args).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                let expected = [4.0 / 3.0, 0.0, -8.0, 3.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // A scalar coefficient c integrates to the row [c, 0].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_scalar_value() {
        let result = polyint_builtin(Value::Num(5.0), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert!((t.data[0] - 5.0).abs() < 1e-12);
                assert!(t.data[1].abs() < 1e-12);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Logical coefficients are treated as 0/1 doubles.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_logical_coefficients() {
        let logical = LogicalArray::new(vec![1, 0, 1], vec![1, 3]).unwrap();
        let result =
            polyint_builtin(Value::LogicalArray(logical), Vec::new()).expect("polyint logical");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                let expected = [1.0 / 3.0, 0.0, 1.0, 0.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Column-vector input must produce a column-vector result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn preserves_column_vector_orientation() {
        let tensor = Tensor::new(vec![2.0, 0.0, -6.0], vec![3, 1]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                let expected = [2.0 / 3.0, 0.0, -6.0, 0.0];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
            }
            other => panic!("expected column tensor, got {other:?}"),
        }
    }

    // Complex coefficients and a complex constant yield a complex tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integrates_complex_coefficients() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (0.0, 4.0)], vec![1, 3]).unwrap();
        let args = vec![Value::Complex(0.0, -1.0)];
        let result = polyint_builtin(Value::ComplexTensor(tensor), args).expect("polyint");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![1, 4]);
                let expected = [(1.0 / 3.0, 2.0 / 3.0), (-1.5, 0.0), (0.0, 4.0), (0.0, -1.0)];
                assert!(t
                    .data
                    .iter()
                    .zip(expected.iter())
                    .all(|((lre, lim), (rre, rim))| {
                        (lre - rre).abs() < 1e-12 && (lim - rim).abs() < 1e-12
                    }));
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // A 2x2 matrix is not a coefficient vector.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_matrix_coefficients() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let err = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect_err("expected error");
        assert_error_contains(err, "vector");
    }

    // The constant of integration must hold exactly one element.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_non_scalar_constant() {
        let coeffs = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
        let constant = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let err = polyint_builtin(Value::Tensor(coeffs), vec![Value::Tensor(constant)])
            .expect_err("expected error");
        assert_error_contains(err, "constant of integration must be a scalar");
    }

    // More than one trailing argument is an arity error with the InvalidArgument identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_excess_arguments() {
        let tensor = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
        let err = polyint_builtin(
            Value::Tensor(tensor),
            vec![Value::Num(1.0), Value::Num(2.0)],
        )
        .expect_err("expected error");
        assert_eq!(err.identifier(), POLYINT_ERROR_INVALID_ARGUMENT.identifier);
        assert_error_contains(err, "too many input arguments");
    }

    // Empty input integrates to a single zero (returned as a scalar or 1-element tensor).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_empty_input_as_zero_polynomial() {
        let tensor = Tensor::new(vec![], vec![1, 0]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), Vec::new()).expect("polyint");
        match result {
            Value::Num(v) => assert!(v.abs() < 1e-12),
            Value::Tensor(t) => {
                assert_eq!(t.data.len(), 1);
                assert!(t.data[0].abs() < 1e-12);
            }
            other => panic!("expected numeric result, got {other:?}"),
        }
    }

    // Empty input with a complex constant yields a 1x1 complex tensor holding the constant.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_input_with_constant() {
        let tensor = Tensor::new(vec![], vec![1, 0]).unwrap();
        let result = polyint_builtin(Value::Tensor(tensor), vec![Value::Complex(1.5, -2.0)])
            .expect("polyint");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert_eq!(t.data.len(), 1);
                let (re, im) = t.data[0];
                assert!((re - 1.5).abs() < 1e-12);
                assert!((im + 2.0).abs() < 1e-12);
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    // Real coefficients uploaded to the test provider come back as a GPU tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, -4.0, 6.0], vec![1, 3]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint");
            match result {
                Value::GpuTensor(handle) => {
                    let gathered = test_support::gather(Value::GpuTensor(handle)).expect("gather");
                    assert_eq!(gathered.shape, vec![1, 4]);
                    let expected = [1.0 / 3.0, -2.0, 6.0, 0.0];
                    assert!(gathered
                        .data
                        .iter()
                        .zip(expected.iter())
                        .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
                }
                other => panic!("expected GPU tensor result, got {other:?}"),
            }
        });
    }

    // A complex constant on real GPU coefficients must still return a GPU handle, stored
    // as interleaved complex data.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_complex_constant_reuploads_complex_result() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = polyint_builtin(Value::GpuTensor(handle), vec![Value::Complex(0.0, 2.0)])
                .expect("polyint");
            match result {
                Value::GpuTensor(handle) => {
                    assert_eq!(
                        runmat_accelerate_api::handle_storage(&handle),
                        runmat_accelerate_api::GpuTensorStorage::ComplexInterleaved
                    );
                    let gathered =
                        block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle)))
                            .expect("gather");
                    let Value::ComplexTensor(ct) = gathered else {
                        panic!("expected complex tensor");
                    };
                    assert_eq!(ct.shape, vec![1, 3]);
                    let expected = [(0.5, 0.0), (0.0, 0.0), (0.0, 2.0)];
                    assert!(ct
                        .data
                        .iter()
                        .zip(expected.iter())
                        .all(|((lre, lim), (rre, rim))| {
                            (lre - rre).abs() < 1e-12 && (lim - rim).abs() < 1e-12
                        }));
                }
                other => panic!("expected complex gpu tensor, got {other:?}"),
            }
        });
    }

    // Complex GPU coefficients with a real constant stay on the device as complex data.
    #[test]
    fn polyint_complex_gpu_coefficients_stay_resident() {
        test_support::with_test_provider(|provider| {
            let coeffs = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
            let handle = gpu_helpers::upload_complex_tensor(provider, &coeffs).expect("upload");
            let result =
                polyint_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("polyint");
            let Value::GpuTensor(handle) = result else {
                panic!("expected complex gpu tensor");
            };
            assert_eq!(
                runmat_accelerate_api::handle_storage(&handle),
                runmat_accelerate_api::GpuTensorStorage::ComplexInterleaved
            );
            let gathered = block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle)))
                .expect("gather");
            let Value::ComplexTensor(ct) = gathered else {
                panic!("expected complex tensor");
            };
            assert_eq!(ct.shape, vec![1, 3]);
            let expected = [(0.5, 0.5), (2.0, -1.0), (2.0, 0.0)];
            assert!(ct
                .data
                .iter()
                .zip(expected.iter())
                .all(|((lre, lim), (rre, rim))| {
                    (lre - rre).abs() < 1e-12 && (lim - rim).abs() < 1e-12
                }));
        });
    }

    // A GPU-resident constant is gathered and applied like a host scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn polyint_gpu_with_gpu_constant() {
        test_support::with_test_provider(|provider| {
            let coeffs = Tensor::new(vec![2.0, 0.0], vec![1, 2]).unwrap();
            let coeff_view = HostTensorView {
                data: &coeffs.data,
                shape: &coeffs.shape,
            };
            let coeff_handle = provider.upload(&coeff_view).expect("upload coeffs");
            let constant = Tensor::new(vec![3.0], vec![1, 1]).unwrap();
            let constant_view = HostTensorView {
                data: &constant.data,
                shape: &constant.shape,
            };
            let constant_handle = provider.upload(&constant_view).expect("upload constant");
            let result = polyint_builtin(
                Value::GpuTensor(coeff_handle),
                vec![Value::GpuTensor(constant_handle)],
            )
            .expect("polyint");
            match result {
                Value::GpuTensor(handle) => {
                    let gathered =
                        test_support::gather(Value::GpuTensor(handle)).expect("gather result");
                    assert_eq!(gathered.shape, vec![1, 3]);
                    let expected = [1.0, 0.0, 3.0];
                    assert!(gathered
                        .data
                        .iter()
                        .zip(expected.iter())
                        .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12));
                }
                other => panic!("expected gpu tensor result, got {other:?}"),
            }
        });
    }

    // wgpu result must match the host result; skipped if the provider cannot register.
    // Tolerance depends on the provider's precision.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn polyint_wgpu_matches_cpu() {
        let _guard = test_support::accel_test_lock();
        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) else {
            return;
        };
        let tensor = Tensor::new(vec![3.0, -2.0, 5.0, 7.0], vec![1, 4]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_value = polyint_builtin(Value::GpuTensor(handle), Vec::new()).expect("polyint gpu");
        let gathered = test_support::gather(gpu_value).expect("gather");
        let cpu_value =
            polyint_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("polyint cpu");
        let expected = match cpu_value {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected cpu result {other:?}"),
        };
        assert_eq!(gathered.shape, expected.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        gathered
            .data
            .iter()
            .zip(expected.data.iter())
            .for_each(|(lhs, rhs)| assert!((lhs - rhs).abs() < tol));
    }

    // Complex coefficients on wgpu must stay complex-interleaved and match the host result.
    #[test]
    #[cfg(feature = "wgpu")]
    fn polyint_wgpu_complex_coefficients_match_cpu() {
        let _guard = test_support::accel_test_lock();
        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) else {
            return;
        };
        let coeffs =
            ComplexTensor::new(vec![(3.0, 1.5), (-2.0, 0.5), (5.0, -1.0)], vec![1, 3]).unwrap();
        let cpu_value =
            polyint_builtin(Value::ComplexTensor(coeffs.clone()), vec![Value::Num(2.0)])
                .expect("polyint cpu");
        let cpu = match cpu_value {
            Value::ComplexTensor(t) => t,
            other => panic!("unexpected cpu result {other:?}"),
        };

        let handle = gpu_helpers::upload_complex_tensor(provider, &coeffs).expect("upload");
        let gpu_value =
            polyint_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("polyint gpu");
        let Value::GpuTensor(handle) = gpu_value else {
            panic!("expected gpu tensor");
        };
        assert_eq!(
            runmat_accelerate_api::handle_storage(&handle),
            runmat_accelerate_api::GpuTensorStorage::ComplexInterleaved
        );
        let gathered =
            block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle))).expect("gather");
        let Value::ComplexTensor(gpu) = gathered else {
            panic!("expected complex tensor");
        };
        assert_eq!(gpu.shape, cpu.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        gpu.data
            .iter()
            .zip(cpu.data.iter())
            .for_each(|((lre, lim), (rre, rim))| {
                assert!((lre - rre).abs() < tol);
                assert!((lim - rim).abs() < tol);
            });
    }

    /// Synchronous wrapper around the async builtin for use in tests. Within this module
    /// it shadows the glob-imported `super::polyint_builtin`.
    fn polyint_builtin(coeffs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::polyint_builtin(coeffs, rest))
    }
}