use super::srsf::srsf_single;
use super::{dp_edge_weight, dp_lambda_penalty, dp_path_to_gamma};
use crate::error::FdarError;
use crate::helpers::{l2_distance, simpsons_weights};
/// Tuning parameters for [`elastic_partial_match`].
///
/// Construct with `PartialMatchConfig::default()` or with struct-update
/// syntax, e.g. `PartialMatchConfig { min_span: 0.3, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct PartialMatchConfig {
    /// Weight forwarded as the `lambda` argument of `dp_lambda_penalty` for
    /// every edge of the alignment DP. Presumably `0.0` (the default) disables
    /// the penalty; confirm against `dp_lambda_penalty`.
    pub lambda: f64,
    /// Smallest candidate window, as a fraction of the target's sample count.
    /// Must lie in `(0, 1]`; defaults to `0.5`.
    pub min_span: f64,
}

impl Default for PartialMatchConfig {
    /// Returns `lambda = 0.0` and `min_span = 0.5`.
    fn default() -> Self {
        const DEFAULT_LAMBDA: f64 = 0.0;
        const DEFAULT_MIN_SPAN: f64 = 0.5;
        PartialMatchConfig {
            min_span: DEFAULT_MIN_SPAN,
            lambda: DEFAULT_LAMBDA,
        }
    }
}
/// Outcome of [`elastic_partial_match`]: the best-matching window of the
/// target curve and the warping found for it.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct PartialMatchResult {
    /// Index into the target of the first sample of the matched window.
    pub start_index: usize,
    /// Index into the target of the last sample of the matched window
    /// (inclusive).
    pub end_index: usize,
    /// Warping function for the matched window, sampled on the template's
    /// argument grid (one value per template sample). If no window produced a
    /// distance below infinity, this is a copy of the template grid.
    pub gamma: Vec<f64>,
    /// Distance (from `l2_distance`) between the template's SRSF and the SRSF
    /// of the warped, resampled window; `f64::INFINITY` if no window scored
    /// below infinity.
    pub distance: f64,
    /// Span of the matched window in argument units divided by the full span
    /// of the target grid (`1.0` when the target span is not positive).
    pub domain_fraction: f64,
}
/// Finds the contiguous portion of `target` that best matches `template` under
/// elastic (warping-based) alignment.
///
/// Candidate windows of `target` are enumerated over up to five window lengths,
/// spread evenly between `ceil(min_span * target.len())` samples (at least 2)
/// and the full target. For each length, window starts advance in steps of
/// about a tenth of the window length, and the window ending at the last target
/// sample is always included.
///
/// Each window is processed as follows:
/// 1. It is linearly resampled onto `template.len()` equally spaced points.
/// 2. It is aligned to the template with `dp_align_partial` on their
///    `srsf_single` representations.
/// 3. It is scored with `l2_distance` (Simpson weights on the template grid)
///    between the template SRSF and the SRSF of the warped window.
///
/// The window with the smallest distance is returned.
///
/// # Arguments
/// * `template` / `argvals_template` - the curve to look for and its sample
///   grid (equal lengths, at least 2 points).
/// * `target` / `argvals_target` - the curve to search in and its sample grid
///   (equal lengths, at least 2 points).
/// * `config` - window-size and alignment-penalty settings.
///
/// Argument grids are presumably expected to be strictly increasing, since
/// linear interpolation relies on it. This is not validated here.
///
/// # Errors
/// * [`FdarError::InvalidDimension`] if a curve and its grid differ in length,
///   or if either curve has fewer than 2 samples.
/// * [`FdarError::InvalidParameter`] if `config.min_span` is not in `(0, 1]`
///   (NaN included).
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_partial_match(
    template: &[f64],
    target: &[f64],
    argvals_template: &[f64],
    argvals_target: &[f64],
    config: &PartialMatchConfig,
) -> Result<PartialMatchResult, FdarError> {
    let m_t = template.len();
    let m_f = target.len();
    if m_t != argvals_template.len() {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals_template",
            expected: format!("{m_t}"),
            actual: format!("{}", argvals_template.len()),
        });
    }
    if m_f != argvals_target.len() {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals_target",
            expected: format!("{m_f}"),
            actual: format!("{}", argvals_target.len()),
        });
    }
    if m_t < 2 || m_f < 2 {
        return Err(FdarError::InvalidDimension {
            parameter: "template/target",
            expected: "length >= 2".to_string(),
            actual: format!("template={m_t}, target={m_f}"),
        });
    }
    // Phrased as a negated in-range test so NaN (which fails every comparison)
    // is rejected instead of silently becoming a 2-sample minimum window.
    if !(config.min_span > 0.0 && config.min_span <= 1.0) {
        return Err(FdarError::InvalidParameter {
            parameter: "min_span",
            message: format!("must be in (0, 1], got {}", config.min_span),
        });
    }

    let q_template = srsf_single(template, argvals_template);
    // Quadrature weights depend only on the template grid: compute them once.
    let weights = simpsons_weights(argvals_template);
    // min_span <= 1 guarantees min_window <= m_f; `.max(2)` keeps windows usable.
    let min_window = ((m_f as f64 * config.min_span).ceil() as usize).max(2);

    let mut best_start = 0;
    let mut best_end = m_f - 1;
    let mut best_dist = f64::INFINITY;
    let mut best_gamma = argvals_template.to_vec();

    for win_size in candidate_window_sizes(m_f, min_window) {
        for start in candidate_window_starts(m_f, win_size) {
            let end = start + win_size - 1;
            // Resample target over [argvals_target[start], argvals_target[end]]
            // onto m_t equally spaced points so it lives on the template grid.
            let (t0, t1) = (argvals_target[start], argvals_target[end]);
            let sub_target: Vec<f64> = (0..m_t)
                .map(|i| {
                    let t = t0 + (t1 - t0) * i as f64 / (m_t - 1) as f64;
                    interp_target(target, argvals_target, t)
                })
                .collect();
            let q_sub = srsf_single(&sub_target, argvals_template);
            let gamma = dp_align_partial(&q_template, &q_sub, argvals_template, config.lambda);
            // Compose the resampled window with the warp: sub_target(gamma(t)).
            let sub_aligned: Vec<f64> = argvals_template
                .iter()
                .map(|&t| {
                    let warped_t = interp_target(&gamma, argvals_template, t);
                    interp_target(&sub_target, argvals_template, warped_t)
                })
                .collect();
            let q_aligned = srsf_single(&sub_aligned, argvals_template);
            let dist = l2_distance(&q_template, &q_aligned, &weights);
            // NaN distances never compare below best_dist and are thus skipped.
            if dist < best_dist {
                best_dist = dist;
                best_start = start;
                best_end = end;
                best_gamma = gamma;
            }
        }
    }

    let total_domain = argvals_target[m_f - 1] - argvals_target[0];
    let match_domain = argvals_target[best_end] - argvals_target[best_start];
    let domain_fraction = if total_domain > 0.0 {
        match_domain / total_domain
    } else {
        1.0
    };
    Ok(PartialMatchResult {
        start_index: best_start,
        end_index: best_end,
        gamma: best_gamma,
        distance: best_dist,
        domain_fraction,
    })
}

/// Window lengths to try: up to `MAX_WINDOW_SIZES` values spread evenly from
/// `min_window` to `m_f` inclusive. Requires `min_window <= m_f`.
fn candidate_window_sizes(m_f: usize, min_window: usize) -> Vec<usize> {
    const MAX_WINDOW_SIZES: usize = 5;
    let n_sizes = MAX_WINDOW_SIZES.min(m_f - min_window + 1);
    if n_sizes <= 1 {
        return vec![m_f];
    }
    (0..n_sizes)
        .map(|i| min_window + i * (m_f - min_window) / (n_sizes - 1))
        .collect()
}

/// Start indices for a `win_size`-sample window in an `m_f`-sample target:
/// every `max(win_size / START_STEP_DIVISOR, 1)`-th position, plus the final
/// position `m_f - win_size` so end-anchored windows are never skipped.
/// Requires `win_size <= m_f`.
fn candidate_window_starts(m_f: usize, win_size: usize) -> Vec<usize> {
    const START_STEP_DIVISOR: usize = 10;
    let last = m_f - win_size;
    let step = (win_size / START_STEP_DIVISOR).max(1);
    let mut starts: Vec<usize> = (0..=last).step_by(step).collect();
    if starts.last() != Some(&last) {
        starts.push(last);
    }
    starts
}
/// Linearly interpolates `values` (sampled at the points of `grid`) at `t`.
///
/// Queries at or beyond either end of `grid` return the first/last value
/// (constant extrapolation), and an empty grid yields `0.0`. The bracketing
/// interval is found by bisection, so `grid` is presumably expected to be
/// sorted ascending. `values` must have at least as many entries as `grid`.
fn interp_target(values: &[f64], grid: &[f64], t: f64) -> f64 {
    let last = match grid.len() {
        0 => return 0.0,
        len => len - 1,
    };
    if t <= grid[0] {
        return values[0];
    }
    if t >= grid[last] {
        return values[last];
    }
    // Shrink [left, right] while keeping grid[left] <= t < grid[right]
    // (for a sorted grid) until the two indices are adjacent.
    let (mut left, mut right) = (0, last);
    while right - left > 1 {
        let mid = left + (right - left) / 2;
        if grid[mid] <= t {
            left = mid;
        } else {
            right = mid;
        }
    }
    let w = (t - grid[left]) / (grid[right] - grid[left]);
    values[left] * (1.0 - w) + values[right] * w
}
/// Aligns `q2` to `q1` on the shared grid `argvals` with the module's
/// dynamic-programming solver and returns the warp from `dp_path_to_gamma`.
///
/// Both inputs are first scaled to unit Euclidean norm. The norm is floored at
/// `1e-10` to avoid dividing by zero, presumably so edge weights compare shape
/// rather than magnitude. Each DP edge costs `dp_edge_weight` plus
/// `dp_lambda_penalty(.., lambda)`. Grids with fewer than two points are
/// returned unchanged.
fn dp_align_partial(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
    let n_grid = argvals.len();
    if n_grid < 2 {
        return argvals.to_vec();
    }
    let to_unit = |q: &[f64]| -> Vec<f64> {
        let scale = q.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
        q.iter().map(|&v| v / scale).collect()
    };
    let q1_unit = to_unit(q1);
    let q2_unit = to_unit(q2);
    let path = super::dp_grid_solve(n_grid, n_grid, |sr, sc, tr, tc| {
        let shape_cost = dp_edge_weight(&q1_unit, &q2_unit, argvals, sc, tc, sr, tr);
        let warp_cost = dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda);
        shape_cost + warp_cost
    });
    dp_path_to_gamma(&path, argvals)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// `sin(6x)` sampled on `t`, shared by several tests.
    fn sine6(t: &[f64]) -> Vec<f64> {
        t.iter().map(|&x| (x * 6.0).sin()).collect()
    }

    #[test]
    fn partial_match_identity() {
        let m = 30;
        let t = uniform_grid(m);
        let f = sine6(&t);
        let config = PartialMatchConfig {
            min_span: 0.5,
            ..Default::default()
        };
        let result = elastic_partial_match(&f, &f, &t, &t, &config).unwrap();
        assert!(
            result.distance < 0.5,
            "matching a curve to itself should give small distance, got {}",
            result.distance
        );
    }

    #[test]
    fn partial_match_subcurve() {
        let m = 40;
        let t = uniform_grid(m);
        let target = sine6(&t);
        let m_t = 20;
        let t_template = uniform_grid(m_t);
        let template: Vec<f64> = t_template
            .iter()
            .map(|&x| ((x * 0.5 + 0.25) * 6.0).sin())
            .collect();
        let config = PartialMatchConfig {
            min_span: 0.3,
            ..Default::default()
        };
        let result = elastic_partial_match(&template, &target, &t_template, &t, &config).unwrap();
        assert!(result.start_index < result.end_index);
        assert!(result.domain_fraction >= 0.3);
        assert_eq!(result.gamma.len(), m_t);
    }

    #[test]
    fn partial_match_full_span_uses_whole_target() {
        // min_span = 1.0 leaves exactly one candidate: the whole target.
        let m = 25;
        let t = uniform_grid(m);
        let f = sine6(&t);
        let config = PartialMatchConfig {
            min_span: 1.0,
            ..Default::default()
        };
        let result = elastic_partial_match(&f, &f, &t, &t, &config).unwrap();
        assert_eq!(result.start_index, 0);
        assert_eq!(result.end_index, m - 1);
        assert!((result.domain_fraction - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn partial_match_rejects_bad_min_span() {
        let t = uniform_grid(10);
        let f: Vec<f64> = t.iter().map(|&x| x * x).collect();
        for bad in [0.0, -0.2, 1.5] {
            let config = PartialMatchConfig {
                min_span: bad,
                ..Default::default()
            };
            let err = elastic_partial_match(&f, &f, &t, &t, &config).unwrap_err();
            assert!(
                matches!(err, FdarError::InvalidParameter { parameter: "min_span", .. }),
                "min_span = {bad} should be rejected as InvalidParameter"
            );
        }
    }

    #[test]
    fn partial_match_rejects_mismatched_argvals() {
        let t = uniform_grid(10);
        let f: Vec<f64> = t.iter().map(|&x| x * x).collect();
        let config = PartialMatchConfig::default();
        let err = elastic_partial_match(&f, &f, &t[..9], &t, &config).unwrap_err();
        assert!(matches!(
            err,
            FdarError::InvalidDimension { parameter: "argvals_template", .. }
        ));
        let err = elastic_partial_match(&f, &f, &t, &t[..9], &config).unwrap_err();
        assert!(matches!(
            err,
            FdarError::InvalidDimension { parameter: "argvals_target", .. }
        ));
    }

    #[test]
    fn partial_match_rejects_too_short_input() {
        let t = uniform_grid(10);
        let f: Vec<f64> = t.iter().map(|&x| x * x).collect();
        let config = PartialMatchConfig::default();
        let err = elastic_partial_match(&[1.0], &f, &[0.0], &t, &config).unwrap_err();
        assert!(matches!(
            err,
            FdarError::InvalidDimension { parameter: "template/target", .. }
        ));
        let err = elastic_partial_match(&f, &[1.0], &t, &[0.0], &config).unwrap_err();
        assert!(matches!(
            err,
            FdarError::InvalidDimension { parameter: "template/target", .. }
        ));
    }

    #[test]
    fn partial_match_config_default() {
        let config = PartialMatchConfig::default();
        assert!((config.lambda - 0.0).abs() < f64::EPSILON);
        assert!((config.min_span - 0.5).abs() < f64::EPSILON);
    }
}