1use crate::core::generator::{Generator, Generator1D, Generator2D, Generator3D, Generator4D};
2
/// Error produced when spline input data is rejected.
///
/// In this module the `NotEnoughKnots` variant is used for every validation
/// failure (too few knots, length mismatch, ordering, non-finite values); the
/// payload is a human-readable description of the problem.
#[derive(Debug)]
pub enum SplineError {
    /// Invalid knot data; the string explains what was wrong.
    NotEnoughKnots(String),
}

impl std::fmt::Display for SplineError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotEnoughKnots(message) => f.write_str(message),
        }
    }
}

// Lets callers propagate the error into `Box<dyn Error>` / error-reporting crates.
impl std::error::Error for SplineError {}
8
/// One polynomial piece of a piecewise spline, valid on a single knot interval.
trait SplineCoefficients {
    /// Evaluates this piece at `point`, where the piece is expressed relative to
    /// the start of interval `interval_idx` of `knot_vector`.
    fn evaluate(&self, point: f64, knot_vector: &[f64], interval_idx: usize) -> f64;
}

/// Coefficients of the cubic `a + b*t + c*t^2 + d*t^3`, where `t` is the offset of
/// the evaluation point from the left end of its interval.
#[derive(Debug, Clone, Copy)]
struct CubicSplineCoefficients {
    /// Constant term (value at the left end of the interval).
    a: f64,
    /// Linear coefficient.
    b: f64,
    /// Quadratic coefficient.
    c: f64,
    /// Cubic coefficient.
    d: f64,
}

impl SplineCoefficients for CubicSplineCoefficients {
    /// # Panics
    ///
    /// Panics if `interval_idx >= knot_vector.len()`.
    fn evaluate(&self, point: f64, knot_vector: &[f64], interval_idx: usize) -> f64 {
        let t = point - knot_vector[interval_idx];
        // Horner's rule: three multiply-adds instead of separate powers, and
        // typically less rounding error than summing the monomials.
        self.a + t * (self.b + t * (self.c + t * self.d))
    }
}
27
28pub trait SplineImpl {
30 fn new(knot_vector: &[f64], knots: &[f64]) -> Result<Self, SplineError>
31 where
32 Self: Sized,
33 {
34 let mut spline = Self::init(knot_vector, knots);
35 spline.validate()?;
36 spline.precompute_coefficients();
37 Ok(spline)
38 }
39
40 fn init(knot_vector: &[f64], knots: &[f64]) -> Self;
41
42 fn validate(&self) -> Result<(), SplineError>;
43
44 fn precompute_coefficients(&mut self);
45
46 fn evaluate(&self, point: f64) -> f64;
47}
48
49#[derive(Clone, Debug)]
51pub struct NaturalCubicSpline {
52 knot_vector: Vec<f64>,
53 knots: Vec<f64>,
54 coefficients: Vec<CubicSplineCoefficients>,
55}
56
57impl SplineImpl for NaturalCubicSpline {
58 fn init(knot_vector: &[f64], knots: &[f64]) -> Self {
59 Self {
60 knot_vector: knot_vector.into(),
61 knots: knots.into(),
62 coefficients: Vec::new(),
63 }
64 }
65
66 fn validate(&self) -> Result<(), SplineError> {
67 if self.knots.len() < 4 {
68 return Err(SplineError::NotEnoughKnots(format!(
69 "Cubic spline expected at least 4 knots, but got {}.",
70 self.knots.len()
71 )));
72 }
73 if self.knots.len() != self.knot_vector.len() {
74 return Err(SplineError::NotEnoughKnots(
75 "Knot vector and knots must be the same length, but they were not.".to_owned(),
76 ));
77 }
78 if !self.knot_vector.is_sorted() {
79 return Err(SplineError::NotEnoughKnots(
80 "Knot vector must be sorted, but it was not.".to_owned(),
81 ));
82 }
83 if self.knot_vector.iter().any(|x| !x.is_finite()) {
84 return Err(SplineError::NotEnoughKnots(
85 "Knot vector must contain finite values, but encountered either NaN, Inf or -Inf."
86 .to_owned(),
87 ));
88 }
89 if self.knots.iter().any(|x| !x.is_finite()) {
90 return Err(SplineError::NotEnoughKnots(
91 "Knots must contain finite values, but encountered either NaN, Inf or -Inf."
92 .to_owned(),
93 ));
94 }
95 Ok(())
96 }
97
98 fn precompute_coefficients(&mut self) {
99 let n = self.knots.len();
100 let mut h = vec![0.0; n - 1];
101 let mut alpha = vec![0.0; n - 1];
102 for (i, hi) in h.iter_mut().enumerate().take(n - 1) {
104 *hi = self.knot_vector[i + 1] - self.knot_vector[i];
105 }
106 for i in 1..n - 1 {
107 alpha[i] = (3.0 / h[i]) * (self.knots[i + 1] - self.knots[i])
109 - (3.0 / h[i - 1]) * (self.knots[i] - self.knots[i - 1]);
110 }
111 let mut l = vec![0.0; n];
113 let mut mu = vec![0.0; n];
114 let mut z = vec![0.0; n];
115 let mut c = vec![0.0; n];
116 l[0] = 1.0;
117 mu[0] = 0.0;
118 z[0] = 0.0;
119 for i in 1..n - 1 {
120 l[i] = 2.0 * (self.knot_vector[i + 1] - self.knot_vector[i - 1]) - h[i - 1] * mu[i - 1];
121 mu[i] = h[i] / l[i];
122 z[i] = (alpha[i] - h[i - 1] * z[i - 1]) / l[i];
123 }
124 l[n - 1] = 1.0;
125 z[n - 1] = 0.0;
126 c[n - 1] = 0.0;
127 let mut b = vec![0.0; n - 1];
129 let mut d = vec![0.0; n - 1];
130 let a = self.knots[..n - 1].to_vec();
131 for j in (0..n - 1).rev() {
132 c[j] = z[j] - mu[j] * c[j + 1];
133 b[j] =
134 (self.knots[j + 1] - self.knots[j]) / h[j] - h[j] * (c[j + 1] + 2.0 * c[j]) / 3.0;
135 d[j] = (c[j + 1] - c[j]) / (3.0 * h[j]);
136 }
137 for i in 0..n - 1 {
139 self.coefficients.push(CubicSplineCoefficients {
140 a: a[i],
141 b: b[i],
142 c: c[i],
143 d: d[i],
144 });
145 }
146 }
147
148 fn evaluate(&self, point: f64) -> f64 {
149 if point < *self.knot_vector.first().unwrap() || point > *self.knot_vector.last().unwrap() {
151 return f64::NAN;
152 }
153 let idx = self
155 .knot_vector
156 .binary_search_by(|x| x.partial_cmp(&point).unwrap())
157 .unwrap_or_else(|idx| idx - 1);
158 self.coefficients[idx].evaluate(point, &self.knot_vector, idx)
160 }
161}
162
/// Adapter that passes every value sampled from a `D`-dimensional generator
/// through a one-dimensional spline.
///
/// The struct places no bounds on `G` or `S`; the `impl` blocks that need
/// `G: Generator<D>` and `S: SplineImpl` state them.
#[derive(Clone, Debug)]
pub struct Spline<const D: usize, G, S> {
    /// Source of the raw values.
    generator: G,
    /// Spline applied to every raw value.
    spline: S,
}
174
175impl<G: Generator<1>, S: SplineImpl> Generator1D for Spline<1, G, S> {}
176impl<G: Generator<2>, S: SplineImpl> Generator2D for Spline<2, G, S> {}
177impl<G: Generator<3>, S: SplineImpl> Generator3D for Spline<3, G, S> {}
178impl<G: Generator<4>, S: SplineImpl> Generator4D for Spline<4, G, S> {}
179
180impl<const D: usize, G, S: SplineImpl> Spline<D, G, S>
181where
182 G: Generator<D>,
183{
184 #[inline]
185 pub fn new(generator: G, knot_vector: &[f64], knots: &[f64]) -> Self {
186 let spline = SplineImpl::new(knot_vector, knots).unwrap();
187 Self { generator, spline }
188 }
189}
190
191impl<const D: usize, G, S> Generator<D> for Spline<D, G, S>
192where
193 G: Generator<D>,
194 S: SplineImpl,
195{
196 #[inline]
197 fn sample(&self, point: [f64; D]) -> f64 {
198 self.spline.evaluate(self.generator.sample(point))
199 }
200}