runmat_runtime/builtins/math/interpolation/
interp1.rs1use runmat_builtins::{ResolveContext, Type, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
9};
10
11use super::pp::{
12 build_pchip_pp, build_spline_pp, evaluate_linear_or_nearest, evaluate_pp,
13 implicit_series_from_values, interp_error, parse_extrapolation, parse_method, query_points,
14 series_from_values, Extrapolation, InterpMethod,
15};
16
/// Builtin name shared by the GPU/fusion descriptors and passed to the
/// interpolation helpers (e.g. `interp_error`) for error reporting.
const NAME: &str = "interp1";
18
19#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::interpolation::interp1")]
20pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
21 name: NAME,
22 op_kind: GpuOpKind::Custom("interpolation-1d"),
23 supported_precisions: &[ScalarType::F32, ScalarType::F64],
24 broadcast: BroadcastSemantics::Matlab,
25 provider_hooks: &[],
26 constant_strategy: ConstantStrategy::InlineLiteral,
27 residency: ResidencyPolicy::GatherImmediately,
28 nan_mode: ReductionNaN::Include,
29 two_pass_threshold: None,
30 workgroup_size: None,
31 accepts_nan_mode: false,
32 notes: "Initial implementation gathers GPU inputs to the CPU reference path. Provider kernels can later accelerate linear and nearest evaluation.",
33};
34
35#[runmat_macros::register_fusion_spec(
36 builtin_path = "crate::builtins::math::interpolation::interp1"
37)]
38pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
39 name: NAME,
40 shape: ShapeRequirements::Any,
41 constant_strategy: ConstantStrategy::InlineLiteral,
42 elementwise: None,
43 reduction: None,
44 emits_nan: true,
45 notes: "Interpolation is currently a runtime sink.",
46};
47
48fn interp1_type(args: &[Type], _ctx: &ResolveContext) -> Type {
49 let query = match args.len() {
50 0 | 1 => return Type::tensor(),
51 2 => args.get(1),
52 _ => args.get(2),
53 };
54 match query {
55 Some(Type::Num | Type::Int | Type::Bool) => Type::Num,
56 Some(Type::Tensor { shape }) | Some(Type::Logical { shape }) => Type::Tensor {
57 shape: shape.clone(),
58 },
59 _ => Type::tensor(),
60 }
61}
62
63#[runtime_builtin(
64 name = "interp1",
65 category = "math/interpolation",
66 summary = "One-dimensional interpolation for sampled data.",
67 keywords = "interp1,interpolation,linear,nearest,spline,pchip",
68 accel = "sink",
69 sink = true,
70 type_resolver(interp1_type),
71 builtin_path = "crate::builtins::math::interpolation::interp1"
72)]
73async fn interp1_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
74 let parsed = ParsedInterp1::parse(args).await?;
75 match parsed.method {
76 InterpMethod::Linear | InterpMethod::Nearest => evaluate_linear_or_nearest(
77 &parsed.series,
78 &parsed.query,
79 parsed.method,
80 &parsed.extrap,
81 NAME,
82 ),
83 InterpMethod::Spline => {
84 let pp = build_spline_pp(&parsed.series, NAME)?;
85 evaluate_pp(&pp, &parsed.query, &parsed.extrap_for_cubic(), NAME)
86 }
87 InterpMethod::Pchip => {
88 let pp = build_pchip_pp(&parsed.series, NAME)?;
89 evaluate_pp(&pp, &parsed.query, &parsed.extrap_for_cubic(), NAME)
90 }
91 }
92}
93
94struct ParsedInterp1 {
95 series: super::pp::NumericSeries,
96 query: super::pp::QueryPoints,
97 method: InterpMethod,
98 extrap: Extrapolation,
99}
100
101impl ParsedInterp1 {
102 async fn parse(args: Vec<Value>) -> crate::BuiltinResult<Self> {
103 if args.len() < 2 {
104 return Err(interp_error(
105 NAME,
106 "interp1: expected at least Y and Xq arguments",
107 ));
108 }
109
110 let mut method = InterpMethod::Linear;
111 let mut extrap = Extrapolation::Nan;
112 let (series, query, options) = if args.len() == 2 || third_arg_is_option(&args) {
113 let mut iter = args.into_iter();
114 let y = iter.next().expect("Y argument");
115 let xq = iter.next().expect("Xq argument");
116 let series = implicit_series_from_values(y, NAME).await?;
117 let query = query_points(xq, NAME).await?;
118 (series, query, iter.collect::<Vec<_>>())
119 } else {
120 let mut iter = args.into_iter();
121 let x = iter.next().expect("X argument");
122 let y = iter.next().expect("Y argument");
123 let xq = iter.next().expect("Xq argument");
124 let series = series_from_values(x, y, NAME).await?;
125 let query = query_points(xq, NAME).await?;
126 (series, query, iter.collect::<Vec<_>>())
127 };
128
129 for option in &options {
130 if let Some(parsed) = parse_extrapolation(option, NAME).await? {
131 extrap = parsed;
132 continue;
133 }
134 if let Some(parsed) = parse_method(option, NAME)? {
135 method = parsed;
136 continue;
137 }
138 return Err(interp_error(
139 NAME,
140 "interp1: unsupported interpolation option",
141 ));
142 }
143
144 Ok(Self {
145 series,
146 query,
147 method,
148 extrap,
149 })
150 }
151
152 fn extrap_for_cubic(&self) -> Extrapolation {
153 self.extrap.clone()
154 }
155}
156
157fn third_arg_is_option(args: &[Value]) -> bool {
158 args.get(2)
159 .and_then(|value| crate::builtins::common::random_args::keyword_of(value))
160 .is_some()
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;

    /// Builds a 1xN row-vector tensor value from `values`.
    fn row_vec(values: &[f64]) -> Value {
        let shape = vec![1, values.len()];
        let tensor = Tensor::new(values.to_vec(), shape).expect("tensor");
        Value::Tensor(tensor)
    }

    /// Drives the async builtin to completion on the current thread.
    fn call(args: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(interp1_builtin(args))
    }

    #[test]
    fn interp1_linear_midpoints() {
        let args = vec![
            row_vec(&[1.0, 2.0, 3.0]),
            row_vec(&[10.0, 20.0, 40.0]),
            row_vec(&[1.5, 2.5]),
        ];
        match call(args).expect("interp1") {
            Value::Tensor(tensor) => assert_eq!(tensor.data, vec![15.0, 30.0]),
            _ => panic!("expected tensor"),
        }
    }

    #[test]
    fn interp1_nearest() {
        let args = vec![
            row_vec(&[1.0, 2.0, 3.0]),
            row_vec(&[10.0, 20.0, 40.0]),
            row_vec(&[1.2, 2.8]),
            Value::String("nearest".to_string()),
        ];
        match call(args).expect("interp1") {
            Value::Tensor(tensor) => assert_eq!(tensor.data, vec![10.0, 40.0]),
            _ => panic!("expected tensor"),
        }
    }

    #[test]
    fn interp1_default_out_of_range_is_nan() {
        let args = vec![row_vec(&[1.0, 2.0]), row_vec(&[10.0, 20.0]), Value::Num(0.0)];
        match call(args).expect("interp1") {
            Value::Num(value) => assert!(value.is_nan()),
            _ => panic!("expected scalar"),
        }
    }

    #[test]
    fn interp1_extrapolates_when_requested() {
        let args = vec![
            row_vec(&[1.0, 2.0]),
            row_vec(&[10.0, 20.0]),
            Value::Num(0.0),
            Value::String("extrap".to_string()),
        ];
        let outcome = call(args).expect("interp1");
        assert_eq!(outcome, Value::Num(0.0));
    }
}