Skip to main content

fdars_core/irreg_fdata/
kernels.rs

1//! Kernel functions and mean estimation for irregular functional data.
2
3use crate::slice_maybe_parallel;
4#[cfg(feature = "parallel")]
5use rayon::iter::ParallelIterator;
6
7use super::IrregFdata;
8
/// Smoothing kernel selectable by the kernel-based estimators in this module.
///
/// Both kernels are symmetric about zero; in [`mean_irreg`] they are evaluated
/// on the scaled distance `u = (t_obs - t) / bandwidth`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelType {
    /// Epanechnikov kernel, compactly supported on `[-1, 1]`:
    /// `K(u) = 0.75 * (1 - u^2)` for `|u| <= 1`, and `0` elsewhere.
    Epanechnikov,
    /// Gaussian kernel (standard normal density), unbounded support:
    /// `K(u) = exp(-u^2 / 2) / sqrt(2 * pi)`.
    Gaussian,
}
18
/// Epanechnikov kernel `K(u) = 0.75 * (1 - u^2)`, zero outside `[-1, 1]`.
///
/// `NaN` and infinite inputs lie outside the support and therefore yield `0.0`.
#[inline]
pub(crate) fn kernel_epanechnikov(u: f64) -> f64 {
    let in_support = (-1.0..=1.0).contains(&u);
    if in_support {
        0.75 * (1.0 - u * u)
    } else {
        0.0
    }
}
28
/// Gaussian kernel: the standard normal density `exp(-u^2 / 2) / sqrt(2 * pi)`.
///
/// Positive for every finite `u` mathematically, but in `f64` the exponential
/// underflows to `0.0` once `|u|` exceeds roughly 38.6.
#[inline]
pub(crate) fn kernel_gaussian(u: f64) -> f64 {
    let normalizer = (2.0 * std::f64::consts::PI).sqrt();
    let unnormalized = (-0.5 * u * u).exp();
    unnormalized / normalizer
}
34
35impl KernelType {
36    #[inline]
37    pub(crate) fn as_fn(self) -> fn(f64) -> f64 {
38        match self {
39            KernelType::Epanechnikov => kernel_epanechnikov,
40            KernelType::Gaussian => kernel_gaussian,
41        }
42    }
43}
44
45/// Estimate mean function at specified target points using kernel smoothing.
46///
47/// Uses local weighted averaging (Nadaraya-Watson estimator) at each target point:
48/// mu_hat(t) = sum_{i,j} K_h(t - t_{ij}) x_{ij} / sum_{i,j} K_h(t - t_{ij})
49///
50/// # Arguments
51/// * `ifd` - Irregular functional data
52/// * `target_argvals` - Points at which to estimate the mean
53/// * `bandwidth` - Kernel bandwidth
54/// * `kernel_type` - Kernel function to use
55///
56/// # Returns
57/// Estimated mean function values at target points
58pub fn mean_irreg(
59    ifd: &IrregFdata,
60    target_argvals: &[f64],
61    bandwidth: f64,
62    kernel_type: KernelType,
63) -> Vec<f64> {
64    let n = ifd.n_obs();
65    let kernel = kernel_type.as_fn();
66
67    slice_maybe_parallel!(target_argvals)
68        .map(|&t| {
69            let mut sum_weights = 0.0;
70            let mut sum_values = 0.0;
71
72            for i in 0..n {
73                let (obs_t, obs_x) = ifd.get_obs(i);
74
75                for (&ti, &xi) in obs_t.iter().zip(obs_x.iter()) {
76                    let u = (ti - t) / bandwidth;
77                    let w = kernel(u);
78                    sum_weights += w;
79                    sum_values += w * xi;
80                }
81            }
82
83            if sum_weights > 0.0 {
84                sum_values / sum_weights
85            } else {
86                f64::NAN
87            }
88        })
89        .collect()
90}