runmat_runtime/builtins/math/interpolation/
ppval.rs1use runmat_builtins::{ResolveContext, Type, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
9};
10
11use super::pp::{evaluate_pp, pp_from_value, query_points, Extrapolation};
12
13const NAME: &str = "ppval";
14
15#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::interpolation::ppval")]
16pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
17 name: NAME,
18 op_kind: GpuOpKind::Custom("piecewise-polynomial-eval"),
19 supported_precisions: &[ScalarType::F32, ScalarType::F64],
20 broadcast: BroadcastSemantics::Matlab,
21 provider_hooks: &[],
22 constant_strategy: ConstantStrategy::UniformBuffer,
23 residency: ResidencyPolicy::GatherImmediately,
24 nan_mode: ReductionNaN::Include,
25 two_pass_threshold: None,
26 workgroup_size: None,
27 accepts_nan_mode: false,
28 notes:
29 "Initial implementation evaluates pp structs on the CPU after gathering GPU query points.",
30};
31
32#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::interpolation::ppval")]
33pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
34 name: NAME,
35 shape: ShapeRequirements::Any,
36 constant_strategy: ConstantStrategy::UniformBuffer,
37 elementwise: None,
38 reduction: None,
39 emits_nan: true,
40 notes: "ppval is currently a runtime sink.",
41};
42
43fn ppval_type(args: &[Type], _ctx: &ResolveContext) -> Type {
44 match args.get(1) {
45 Some(Type::Num | Type::Int | Type::Bool) => Type::Num,
46 Some(Type::Tensor { shape }) | Some(Type::Logical { shape }) => Type::Tensor {
47 shape: shape.clone(),
48 },
49 _ => Type::tensor(),
50 }
51}
52
53#[runtime_builtin(
54 name = "ppval",
55 category = "math/interpolation",
56 summary = "Evaluate a piecewise-polynomial structure at query points.",
57 keywords = "ppval,piecewise polynomial,spline,pchip",
58 accel = "sink",
59 sink = true,
60 type_resolver(ppval_type),
61 builtin_path = "crate::builtins::math::interpolation::ppval"
62)]
63async fn ppval_builtin(pp: Value, xq: Value) -> crate::BuiltinResult<Value> {
64 let parsed = pp_from_value(pp, NAME).await?;
65 let query = query_points(xq, NAME).await?;
66 evaluate_pp(&parsed, &query, &Extrapolation::Extrapolate, NAME)
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::math::interpolation::pp::{build_spline_pp, pp_to_value, NumericSeries};
    use futures::executor::block_on;
    use runmat_builtins::Tensor;

    /// Spline pp through y = x^2 sampled at x = 1, 2, 3.
    ///
    /// The interior assertions below (exact x^2 at 1.5 and 2.5) confirm that this spline
    /// reproduces the parabola. That matches MATLAB's not-a-knot spline through three points.
    fn square_spline_pp() -> Value {
        let series = NumericSeries {
            x: vec![1.0, 2.0, 3.0],
            y: vec![1.0, 4.0, 9.0],
            series: 1,
            trailing_shape: Vec::new(),
        };
        pp_to_value(
            build_spline_pp(&series, "spline").expect("spline"),
            "spline",
        )
        .expect("pp")
    }

    /// Calls `ppval` with a 1xN row of query points and returns the resulting tensor.
    fn eval_row(pp: Value, xq: &[f64]) -> Tensor {
        let query = Value::Tensor(Tensor::new(xq.to_vec(), vec![1, xq.len()]).expect("tensor"));
        let value = block_on(ppval_builtin(pp, query)).expect("ppval");
        let Value::Tensor(tensor) = value else {
            panic!("expected tensor");
        };
        tensor
    }

    #[test]
    fn ppval_evaluates_spline_struct() {
        let tensor = eval_row(square_spline_pp(), &[1.5, 2.5]);
        assert!((tensor.data[0] - 2.25).abs() < 1e-10);
        assert!((tensor.data[1] - 6.25).abs() < 1e-10);
    }

    #[test]
    fn ppval_extrapolates_outside_breaks() {
        // ppval evaluates with Extrapolation::Extrapolate, so points left of the first break and
        // right of the last break use the end pieces. For this parabola-reproducing spline those
        // pieces are still x^2, so 0 -> 0 and 4 -> 16 rather than NaN or clamped values.
        let tensor = eval_row(square_spline_pp(), &[0.0, 4.0]);
        assert!((tensor.data[0] - 0.0).abs() < 1e-10);
        assert!((tensor.data[1] - 16.0).abs() < 1e-10);
    }
}