1use numra_core::Scalar;
11
12pub trait HistoryFunction<S: Scalar>: Fn(S) -> Vec<S> {}
14impl<S: Scalar, F: Fn(S) -> Vec<S>> HistoryFunction<S> for F {}
15
16#[derive(Clone, Debug)]
18pub struct HistoryStep<S: Scalar> {
19 pub t: S,
21 pub y: Vec<S>,
23 pub f: Vec<S>,
25 pub t_next: S,
27 pub y_next: Vec<S>,
29 pub f_next: Vec<S>,
31}
32
33impl<S: Scalar> HistoryStep<S> {
34 pub fn new(t: S, y: Vec<S>, f: Vec<S>, t_next: S, y_next: Vec<S>, f_next: Vec<S>) -> Self {
35 Self {
36 t,
37 y,
38 f,
39 t_next,
40 y_next,
41 f_next,
42 }
43 }
44
45 pub fn h(&self) -> S {
47 self.t_next - self.t
48 }
49
50 pub fn contains(&self, t: S) -> bool {
52 t >= self.t && t <= self.t_next
53 }
54}
55
56pub struct HermiteInterpolator;
61
62impl HermiteInterpolator {
63 pub fn interpolate<S: Scalar>(step: &HistoryStep<S>, t: S) -> Vec<S> {
70 let h = step.h();
71 if h.abs() < S::from_f64(1e-15) {
72 return step.y.clone();
73 }
74
75 let theta = (t - step.t) / h;
76 let theta2 = theta * theta;
77 let theta3 = theta2 * theta;
78
79 let h00 = S::ONE - S::from_f64(3.0) * theta2 + S::from_f64(2.0) * theta3; let h10 = theta - S::from_f64(2.0) * theta2 + theta3; let h01 = S::from_f64(3.0) * theta2 - S::from_f64(2.0) * theta3; let h11 = -theta2 + theta3; let dim = step.y.len();
86 let mut result = vec![S::ZERO; dim];
87
88 for i in 0..dim {
89 result[i] = h00 * step.y[i]
90 + h10 * h * step.f[i]
91 + h01 * step.y_next[i]
92 + h11 * h * step.f_next[i];
93 }
94
95 result
96 }
97
98 pub fn interpolate_derivative<S: Scalar>(step: &HistoryStep<S>, t: S) -> Vec<S> {
100 let h = step.h();
101 if h.abs() < S::from_f64(1e-15) {
102 return step.f.clone();
103 }
104
105 let theta = (t - step.t) / h;
106 let theta2 = theta * theta;
107
108 let dh00 = -S::from_f64(6.0) * theta + S::from_f64(6.0) * theta2;
110 let dh10 = S::ONE - S::from_f64(4.0) * theta + S::from_f64(3.0) * theta2;
111 let dh01 = S::from_f64(6.0) * theta - S::from_f64(6.0) * theta2;
112 let dh11 = -S::from_f64(2.0) * theta + S::from_f64(3.0) * theta2;
113
114 let dim = step.y.len();
115 let mut result = vec![S::ZERO; dim];
116
117 for i in 0..dim {
118 result[i] = (dh00 * step.y[i]
119 + dh10 * h * step.f[i]
120 + dh01 * step.y_next[i]
121 + dh11 * h * step.f_next[i])
122 / h;
123 }
124
125 result
126 }
127}
128
129pub struct History<S: Scalar, H: Fn(S) -> Vec<S>> {
134 initial_history: H,
136 t0: S,
138 steps: Vec<HistoryStep<S>>,
140 dim: usize,
142}
143
144impl<S: Scalar, H: Fn(S) -> Vec<S>> History<S, H> {
145 pub fn new(initial_history: H, t0: S, dim: usize) -> Self {
147 Self {
148 initial_history,
149 t0,
150 steps: Vec::new(),
151 dim,
152 }
153 }
154
155 pub fn add_step(&mut self, step: HistoryStep<S>) {
157 self.steps.push(step);
158 }
159
160 pub fn evaluate(&self, t: S) -> Vec<S> {
162 if t <= self.t0 {
164 return (self.initial_history)(t);
165 }
166
167 match self.find_step(t) {
169 Some(idx) => HermiteInterpolator::interpolate(&self.steps[idx], t),
170 None => {
171 if let Some(last) = self.steps.last() {
173 if t <= last.t_next {
174 HermiteInterpolator::interpolate(last, t)
175 } else {
176 last.y_next.clone()
178 }
179 } else {
180 (self.initial_history)(self.t0)
182 }
183 }
184 }
185 }
186
187 pub fn evaluate_derivative(&self, t: S) -> Vec<S> {
189 if t <= self.t0 {
190 let eps = S::from_f64(1e-8);
193 let y1 = (self.initial_history)(t - eps);
194 let y2 = (self.initial_history)(t + eps);
195 let mut deriv = vec![S::ZERO; self.dim];
196 for i in 0..self.dim {
197 deriv[i] = (y2[i] - y1[i]) / (S::from_f64(2.0) * eps);
198 }
199 return deriv;
200 }
201
202 match self.find_step(t) {
203 Some(idx) => HermiteInterpolator::interpolate_derivative(&self.steps[idx], t),
204 None => {
205 if let Some(last) = self.steps.last() {
206 last.f_next.clone()
207 } else {
208 vec![S::ZERO; self.dim]
209 }
210 }
211 }
212 }
213
214 fn find_step(&self, t: S) -> Option<usize> {
216 if self.steps.is_empty() {
217 return None;
218 }
219
220 let mut lo = 0;
222 let mut hi = self.steps.len();
223
224 while lo < hi {
225 let mid = (lo + hi) / 2;
226 if t < self.steps[mid].t {
227 hi = mid;
228 } else if t > self.steps[mid].t_next {
229 lo = mid + 1;
230 } else {
231 return Some(mid);
232 }
233 }
234
235 if lo < self.steps.len() && self.steps[lo].contains(t) {
237 Some(lo)
238 } else if lo > 0 && self.steps[lo - 1].contains(t) {
239 Some(lo - 1)
240 } else {
241 None
242 }
243 }
244
245 pub fn current_time(&self) -> S {
247 self.steps.last().map(|s| s.t_next).unwrap_or(self.t0)
248 }
249
250 pub fn n_steps(&self) -> usize {
252 self.steps.len()
253 }
254
255 pub fn clear(&mut self) {
257 self.steps.clear();
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Step on `[0, 1]` sampling `y(t) = t²`: y(0) = 0, y'(0) = 0, y(1) = 1, y'(1) = 2.
    /// A cubic Hermite interpolant reproduces this quadratic exactly.
    fn quadratic_step() -> HistoryStep<f64> {
        HistoryStep::new(0.0, vec![0.0], vec![0.0], 1.0, vec![1.0], vec![2.0])
    }

    /// Appends unit steps `[i, i + 1]` for `i in 0..n` sampling `y(t) = t` (slope 1).
    fn add_linear_steps<H: Fn(f64) -> Vec<f64>>(history: &mut History<f64, H>, n: usize) {
        for i in 0..n {
            let t = i as f64;
            let t_next = (i + 1) as f64;
            history.add_step(HistoryStep::new(
                t,
                vec![t],
                vec![1.0],
                t_next,
                vec![t_next],
                vec![1.0],
            ));
        }
    }

    #[test]
    fn test_hermite_interpolation() {
        let step = quadratic_step();

        // Interior points reproduce t².
        let y_mid = HermiteInterpolator::interpolate(&step, 0.5);
        assert!((y_mid[0] - 0.25).abs() < 1e-10);
        let y_quarter = HermiteInterpolator::interpolate(&step, 0.25);
        assert!((y_quarter[0] - 0.0625).abs() < 1e-10);

        // Endpoints reproduce the stored states.
        assert!((HermiteInterpolator::interpolate(&step, 0.0)[0] - 0.0).abs() < 1e-12);
        assert!((HermiteInterpolator::interpolate(&step, 1.0)[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_hermite_derivative() {
        let step = quadratic_step();

        // d/dt t² = 2t.
        let f_mid = HermiteInterpolator::interpolate_derivative(&step, 0.5);
        assert!((f_mid[0] - 1.0).abs() < 1e-10);
        let f_quarter = HermiteInterpolator::interpolate_derivative(&step, 0.25);
        assert!((f_quarter[0] - 0.5).abs() < 1e-10);

        // Endpoints reproduce the stored slopes.
        assert!((HermiteInterpolator::interpolate_derivative(&step, 0.0)[0] - 0.0).abs() < 1e-12);
        assert!((HermiteInterpolator::interpolate_derivative(&step, 1.0)[0] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_hermite_degenerate_step() {
        // Zero-width step: left-endpoint state and slope are returned unchanged.
        let step = HistoryStep::new(1.0, vec![3.0], vec![4.0], 1.0, vec![5.0], vec![6.0]);
        assert_eq!(HermiteInterpolator::interpolate(&step, 1.0), vec![3.0]);
        assert_eq!(HermiteInterpolator::interpolate_derivative(&step, 1.0), vec![4.0]);
    }

    #[test]
    fn test_history_initial() {
        let history = History::new(|t: f64| vec![t.sin()], 0.0, 1);

        let y = history.evaluate(-1.0);
        assert!((y[0] - (-1.0_f64).sin()).abs() < 1e-10);

        // Finite-difference derivative of the initial history: d/dt sin t = cos t.
        let dy = history.evaluate_derivative(-1.0);
        assert!((dy[0] - (-1.0_f64).cos()).abs() < 1e-6);
    }

    #[test]
    fn test_history_with_steps() {
        let mut history = History::new(|_t: f64| vec![1.0], 0.0, 1);
        history.add_step(HistoryStep::new(0.0, vec![1.0], vec![1.0], 1.0, vec![2.0], vec![1.0]));

        // The step samples the line y = 1 + t, which the cubic reproduces exactly.
        let y = history.evaluate(0.5);
        assert!((y[0] - 1.5).abs() < 1e-12);

        // Before t0 the initial history is used.
        let y_pre = history.evaluate(-0.5);
        assert!((y_pre[0] - 1.0).abs() < 1e-10);

        let y_end = history.evaluate(1.0);
        assert!((y_end[0] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_history_multiple_steps() {
        let mut history = History::new(|_t: f64| vec![0.0], 0.0, 1);
        add_linear_steps(&mut history, 5);

        assert_eq!(history.n_steps(), 5);
        assert!((history.current_time() - 5.0).abs() < 1e-12);

        // Linear data is reproduced exactly by the cubic.
        assert!((history.evaluate(2.5)[0] - 2.5).abs() < 1e-12);
        assert!((history.evaluate(4.0)[0] - 4.0).abs() < 1e-10);
        assert!((history.evaluate_derivative(2.5)[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_history_beyond_last_step() {
        let mut history = History::new(|_t: f64| vec![0.0], 0.0, 1);
        add_linear_steps(&mut history, 5);

        // Past the final step the last state and slope are held.
        assert_eq!(history.evaluate(10.0), vec![5.0]);
        assert_eq!(history.evaluate_derivative(10.0), vec![1.0]);
    }

    #[test]
    fn test_history_without_steps() {
        let history = History::new(|t: f64| vec![t + 1.0], 0.0, 1);

        assert_eq!(history.n_steps(), 0);
        assert_eq!(history.current_time(), 0.0);

        // With no steps, later times hold the initial history's value at t0.
        assert_eq!(history.evaluate(2.0), vec![1.0]);
        assert_eq!(history.evaluate_derivative(2.0), vec![0.0]);
    }

    #[test]
    fn test_history_clear() {
        let mut history = History::new(|_t: f64| vec![0.0], 0.0, 1);
        add_linear_steps(&mut history, 3);
        assert_eq!(history.n_steps(), 3);

        history.clear();
        assert_eq!(history.n_steps(), 0);
        assert_eq!(history.current_time(), 0.0);
    }
}