1use super::srsf::srsf_single;
9use super::{dp_edge_weight, dp_lambda_penalty, dp_path_to_gamma};
10use crate::error::FdarError;
11use crate::helpers::{l2_distance, simpsons_weights};
12
/// Tuning knobs for [`elastic_partial_match`].
#[derive(Clone, Debug, PartialEq)]
pub struct PartialMatchConfig {
    /// Penalty weight handed to the dynamic-programming aligner
    /// (forwarded unchanged to `dp_lambda_penalty`).
    pub lambda: f64,
    /// Smallest admissible window, expressed as a fraction of the number of
    /// target samples. Must lie in `(0, 1]`; the resulting window is never
    /// shorter than two samples.
    pub min_span: f64,
}
24
25impl Default for PartialMatchConfig {
26 fn default() -> Self {
27 Self {
28 lambda: 0.0,
29 min_span: 0.5,
30 }
31 }
32}
33
/// Outcome of [`elastic_partial_match`]: the best-matching window of the
/// target curve and the alignment found for it.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct PartialMatchResult {
    /// Index into the target samples where the selected window begins.
    pub start_index: usize,
    /// Index into the target samples where the selected window ends
    /// (inclusive).
    pub end_index: usize,
    /// Warping function produced by the aligner for the selected window,
    /// expressed on the template grid.
    pub gamma: Vec<f64>,
    /// Weighted L2 distance between the template SRSF and the SRSF of the
    /// aligned window; the smallest value seen during the search.
    pub distance: f64,
    /// Length of the selected window divided by the total length of the
    /// target domain (`1.0` when the target domain has non-positive length).
    pub domain_fraction: f64,
}
49
50#[must_use = "expensive computation whose result should not be discarded"]
69pub fn elastic_partial_match(
70 template: &[f64],
71 target: &[f64],
72 argvals_template: &[f64],
73 argvals_target: &[f64],
74 config: &PartialMatchConfig,
75) -> Result<PartialMatchResult, FdarError> {
76 let m_t = template.len();
77 let m_f = target.len();
78
79 if m_t != argvals_template.len() {
80 return Err(FdarError::InvalidDimension {
81 parameter: "argvals_template",
82 expected: format!("{m_t}"),
83 actual: format!("{}", argvals_template.len()),
84 });
85 }
86 if m_f != argvals_target.len() {
87 return Err(FdarError::InvalidDimension {
88 parameter: "argvals_target",
89 expected: format!("{m_f}"),
90 actual: format!("{}", argvals_target.len()),
91 });
92 }
93 if m_t < 2 || m_f < 2 {
94 return Err(FdarError::InvalidDimension {
95 parameter: "template/target",
96 expected: "length >= 2".to_string(),
97 actual: format!("template={m_t}, target={m_f}"),
98 });
99 }
100 if config.min_span <= 0.0 || config.min_span > 1.0 {
101 return Err(FdarError::InvalidParameter {
102 parameter: "min_span",
103 message: format!("must be in (0, 1], got {}", config.min_span),
104 });
105 }
106
107 let q_template = srsf_single(template, argvals_template);
108
109 let min_window = ((m_f as f64 * config.min_span).ceil() as usize).max(2);
111
112 let mut best_start = 0;
113 let mut best_end = m_f - 1;
114 let mut best_dist = f64::INFINITY;
115 let mut best_gamma = argvals_template.to_vec();
116
117 let n_sizes = 5.min(m_f - min_window + 1);
120 let sizes: Vec<usize> = if n_sizes <= 1 {
121 vec![m_f]
122 } else {
123 (0..n_sizes)
124 .map(|i| min_window + i * (m_f - min_window) / (n_sizes - 1))
125 .collect()
126 };
127
128 for &win_size in &sizes {
129 let step = (win_size / 10).max(1);
130 let mut start = 0;
131 while start + win_size <= m_f {
132 let end = start + win_size - 1;
133
134 let sub_argvals: Vec<f64> = (0..m_t)
136 .map(|i| {
137 argvals_target[start]
138 + (argvals_target[end] - argvals_target[start]) * i as f64
139 / (m_t - 1) as f64
140 })
141 .collect();
142
143 let sub_target: Vec<f64> = sub_argvals
145 .iter()
146 .map(|&t| interp_target(target, argvals_target, t))
147 .collect();
148
149 let q_sub = srsf_single(&sub_target, argvals_template);
150
151 let gamma = dp_align_partial(&q_template, &q_sub, argvals_template, config.lambda);
153
154 let sub_aligned: Vec<f64> = argvals_template
156 .iter()
157 .map(|&t| {
158 interp_target(
159 &sub_target,
160 argvals_template,
161 interp_target(&gamma, argvals_template, t),
162 )
163 })
164 .collect();
165 let q_aligned = srsf_single(&sub_aligned, argvals_template);
166 let weights = simpsons_weights(argvals_template);
167 let dist = l2_distance(&q_template, &q_aligned, &weights);
168
169 if dist < best_dist {
170 best_dist = dist;
171 best_start = start;
172 best_end = end;
173 best_gamma = gamma;
174 }
175
176 start += step;
177 }
178 }
179
180 let total_domain = argvals_target[m_f - 1] - argvals_target[0];
181 let match_domain = argvals_target[best_end] - argvals_target[best_start];
182 let domain_fraction = if total_domain > 0.0 {
183 match_domain / total_domain
184 } else {
185 1.0
186 };
187
188 Ok(PartialMatchResult {
189 start_index: best_start,
190 end_index: best_end,
191 gamma: best_gamma,
192 distance: best_dist,
193 domain_fraction,
194 })
195}
196
/// Piecewise-linear interpolation of the samples `values`, given on `grid`,
/// at abscissa `t`.
///
/// * An empty `grid` yields `0.0`.
/// * `t` at or left of the first grid point yields the first value; `t` at or
///   right of the last grid point yields the last value (constant
///   extrapolation).
/// * Otherwise the bracketing pair of neighbouring grid points is located by
///   bisection and the two values are blended linearly.
///
/// `values` is indexed with positions derived from `grid`, so it must be at
/// least as long as `grid`. The bisection presumes `grid` is sorted in
/// ascending order.
fn interp_target(values: &[f64], grid: &[f64], t: f64) -> f64 {
    if grid.is_empty() {
        return 0.0;
    }
    let last = grid.len() - 1;

    // Clamp outside the grid.
    if t <= grid[0] {
        return values[0];
    }
    if t >= grid[last] {
        return values[last];
    }

    // Bisection keeping grid[left] <= t < grid[right].
    let (mut left, mut right) = (0_usize, last);
    while right - left > 1 {
        let probe = left + (right - left) / 2;
        if grid[probe] <= t {
            left = probe;
        } else {
            right = probe;
        }
    }

    let cell_width = grid[right] - grid[left];
    let frac = (t - grid[left]) / cell_width;
    values[left] * (1.0 - frac) + values[right] * frac
}
225
226fn dp_align_partial(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
228 let m = argvals.len();
229 if m < 2 {
230 return argvals.to_vec();
231 }
232
233 let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
234 let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
235 let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
236 let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
237
238 let path = super::dp_grid_solve(m, m, |sr, sc, tr, tc| {
239 dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr)
240 + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
241 });
242
243 dp_path_to_gamma(&path, argvals)
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// Matching a curve against itself must produce a small distance.
    #[test]
    fn partial_match_identity() {
        let grid = uniform_grid(30);
        let curve: Vec<f64> = grid.iter().map(|&x| (x * 6.0).sin()).collect();

        let config = PartialMatchConfig {
            min_span: 0.5,
            ..Default::default()
        };
        let outcome = elastic_partial_match(&curve, &curve, &grid, &grid, &config).unwrap();

        assert!(
            outcome.distance < 0.5,
            "matching a curve to itself should give small distance, got {}",
            outcome.distance
        );
    }

    /// A template cut from the middle half of the target yields a well-formed
    /// result that respects the minimum span.
    #[test]
    fn partial_match_subcurve() {
        let target_grid = uniform_grid(40);
        let target: Vec<f64> = target_grid.iter().map(|&x| (x * 6.0).sin()).collect();

        // Template: the target restricted to [0.25, 0.75], on its own grid.
        let template_len = 20;
        let template_grid = uniform_grid(template_len);
        let template: Vec<f64> = template_grid
            .iter()
            .map(|&x| ((x * 0.5 + 0.25) * 6.0).sin())
            .collect();

        let config = PartialMatchConfig {
            min_span: 0.3,
            ..Default::default()
        };
        let outcome =
            elastic_partial_match(&template, &target, &template_grid, &target_grid, &config)
                .unwrap();

        assert!(outcome.start_index < outcome.end_index);
        assert!(outcome.domain_fraction >= 0.3);
        assert!(outcome.gamma.len() == template_len);
    }

    /// `min_span == 0` lies outside `(0, 1]` and must be rejected.
    #[test]
    fn partial_match_rejects_bad_min_span() {
        let grid = uniform_grid(10);
        let curve: Vec<f64> = grid.iter().map(|&x| x * x).collect();
        let config = PartialMatchConfig {
            min_span: 0.0,
            ..Default::default()
        };
        let outcome = elastic_partial_match(&curve, &curve, &grid, &grid, &config);
        assert!(outcome.is_err());
    }

    /// The default configuration is `lambda = 0.0`, `min_span = 0.5`.
    #[test]
    fn partial_match_config_default() {
        let defaults = PartialMatchConfig::default();
        assert!((defaults.lambda - 0.0).abs() < f64::EPSILON);
        assert!((defaults.min_span - 0.5).abs() < f64::EPSILON);
    }
}