rbf_interpolation/
rbf.rs

1use std::iter::{Product, Sum};
2
3use na::{Const, DimName, Matrix, RealField, Scalar, Storage, U1, Vector};
4
5use crate::{builder::RBFInterpolatorBuilder, powers::monomial_exponents};
6
7pub struct RBFInterpolator<
8    T,
9    const DEGREE: usize,
10    const MONOMIALS: usize,
11    const POINTS: usize,
12    const DIM: usize,
13    SP,
14    SW,
15> where
16    T: Scalar,
17    SP: Storage<T, Const<DIM>, Const<POINTS>>,
18    SW: Storage<T, Const<{ POINTS + MONOMIALS }>, U1>,
19    Const<{ POINTS + MONOMIALS }>: DimName,
20{
21    pub(crate) kernel: RBFInterpolatorBuilder<T, DEGREE, MONOMIALS, POINTS, DIM>,
22    pub(crate) points: Matrix<T, Const<DIM>, Const<POINTS>, SP>,
23    pub(crate) weights: Vector<T, Const<{ POINTS + MONOMIALS }>, SW>,
24}
25
26impl<T, const DEGREE: usize, const MONOMIALS: usize, const POINTS: usize, const DIM: usize, SP, SW>
27    RBFInterpolator<T, DEGREE, MONOMIALS, POINTS, DIM, SP, SW>
28where
29    T: Scalar + RealField + Copy + Sum + Copy + Product,
30    SP: Storage<T, Const<DIM>, Const<POINTS>>,
31    SW: Storage<T, Const<{ POINTS + MONOMIALS }>, U1>,
32    Const<{ POINTS + MONOMIALS }>: DimName,
33{
34    pub fn interpolate<S1>(&self, point_b: &Vector<T, Const<DIM>, S1>) -> T
35    where
36        S1: Storage<T, Const<DIM>, U1>,
37    {
38        let mut weights = self.weights.iter();
39
40        // Add the phi terms
41        let phi: T = self
42            .points
43            .column_iter()
44            .map(|point_a| {
45                let &weight = weights.next().unwrap();
46                let phi = self
47                    .kernel
48                    .kernel((point_a - point_b).map(|e| e.powi(2)).sum().sqrt());
49
50                weight * phi
51            })
52            .sum();
53
54        // Add the polynomial terms
55        let exponents = monomial_exponents::<DIM, DEGREE>();
56        debug_assert_eq!(exponents.len(), MONOMIALS);
57
58        let polynomial: T = exponents
59            .iter()
60            .map(|exponent| {
61                debug_assert_eq!(exponent.len(), point_b.len());
62                let weight = weights.next().unwrap();
63                let value: T = exponent
64                    .iter()
65                    .zip(point_b.row_iter())
66                    .map(|(&exponent, ordinate)| ordinate[(0, 0)].powi(exponent))
67                    .product();
68
69                *weight * value
70            })
71            .sum();
72
73        debug_assert!(weights.next().is_none());
74
75        phi + polynomial
76    }
77}