1use numra_core::Scalar;
21
/// One step's worth of dense output: the time interval covered by the step plus
/// the interpolation coefficients for every state component.
#[derive(Clone, Debug)]
pub struct DenseSegment<S: Scalar> {
    /// Time at the beginning of the step.
    pub t_start: S,
    /// Time at the end of the step (smaller than `t_start` when integrating backward).
    pub t_end: S,
    /// Interpolant coefficients, stored as consecutive blocks of `dim` values;
    /// the number of blocks and their meaning are defined by the interpolant
    /// (the DOPRI5 interpolant in this file reads 6 blocks).
    pub coeffs: Vec<S>,
    /// Number of state components described by `coeffs`.
    pub dim: usize,
}
35
36impl<S: Scalar> DenseSegment<S> {
37 pub fn new(t_start: S, t_end: S, coeffs: Vec<S>, dim: usize) -> Self {
39 Self {
40 t_start,
41 t_end,
42 coeffs,
43 dim,
44 }
45 }
46
47 #[inline]
49 pub fn contains(&self, t: S) -> bool {
50 t >= self.t_start && t <= self.t_end
51 }
52
53 #[inline]
55 pub fn h(&self) -> S {
56 self.t_end - self.t_start
57 }
58
59 #[inline]
61 pub fn theta(&self, t: S) -> S {
62 (t - self.t_start) / self.h()
63 }
64}
65
/// Piecewise dense output of an integration run: one [`DenseSegment`] per step,
/// kept in the order they were added.
#[derive(Clone, Debug)]
pub struct DenseOutput<S: Scalar> {
    /// Segments in insertion order (expected to be integration order).
    segments: Vec<DenseSegment<S>>,
    /// State dimension supplied at construction.
    // No method in this file reads `dim` (it is only written by `new`), hence the
    // allow; it is kept so the dimension travels with the output.
    #[allow(dead_code)]
    dim: usize,
    /// Integration direction; only its sign is consulted (`> 0` means forward).
    direction: S,
}
79
80impl<S: Scalar> Default for DenseOutput<S> {
81 fn default() -> Self {
82 Self::new(0, S::ONE)
83 }
84}
85
86impl<S: Scalar> DenseOutput<S> {
87 pub fn new(dim: usize, direction: S) -> Self {
89 Self {
90 segments: Vec::new(),
91 dim,
92 direction,
93 }
94 }
95
96 pub fn add_segment(&mut self, segment: DenseSegment<S>) {
98 self.segments.push(segment);
99 }
100
101 pub fn len(&self) -> usize {
103 self.segments.len()
104 }
105
106 pub fn is_empty(&self) -> bool {
108 self.segments.is_empty()
109 }
110
111 pub fn tspan(&self) -> Option<(S, S)> {
113 if self.segments.is_empty() {
114 None
115 } else {
116 let t0 = self.segments.first().unwrap().t_start;
117 let tf = self.segments.last().unwrap().t_end;
118 Some((t0, tf))
119 }
120 }
121
122 pub fn find_segment(&self, t: S) -> Option<&DenseSegment<S>> {
124 if self.segments.is_empty() {
126 return None;
127 }
128
129 let first = &self.segments[0];
131 let last = &self.segments[self.segments.len() - 1];
132
133 if self.direction > S::ZERO {
134 if t < first.t_start || t > last.t_end {
135 return None;
136 }
137 } else {
138 if t > first.t_start || t < last.t_end {
140 return None;
141 }
142 }
143
144 let mut lo = 0;
146 let mut hi = self.segments.len();
147
148 while lo < hi {
149 let mid = (lo + hi) / 2;
150 let seg = &self.segments[mid];
151
152 if seg.contains(t) {
153 return Some(seg);
154 }
155
156 if self.direction > S::ZERO {
157 if t < seg.t_start {
158 hi = mid;
159 } else {
160 lo = mid + 1;
161 }
162 } else {
163 if t > seg.t_start {
164 hi = mid;
165 } else {
166 lo = mid + 1;
167 }
168 }
169 }
170
171 None
172 }
173
174 pub fn clear(&mut self) {
176 self.segments.clear();
177 }
178
179 pub fn segments(&self) -> &[DenseSegment<S>] {
181 &self.segments
182 }
183}
184
/// Evaluates the continuous extension stored in a [`DenseSegment`].
pub trait DenseInterpolant<S: Scalar> {
    /// Writes the interpolated state at time `t` into `y_out`.
    ///
    /// `y_out` should hold at least `segment.dim` entries (the `DoPri5Interpolant`
    /// impl in this file indexes that many and panics otherwise).
    fn interpolate(&self, segment: &DenseSegment<S>, t: S, y_out: &mut [S]);

    /// Writes the time derivative dy/dt of the interpolant at time `t` into `dydt_out`.
    ///
    /// Sizing requirements are the same as for [`interpolate`](Self::interpolate).
    fn interpolate_derivative(&self, segment: &DenseSegment<S>, t: S, dydt_out: &mut [S]);
}
193
/// Dense-output interpolant for the Dormand–Prince 5(4) method.
///
/// Stateless: all per-step data lives in the segment coefficients produced by
/// `DoPri5Interpolant::build_coefficients`.
#[derive(Clone, Debug, Default)]
pub struct DoPri5Interpolant;
205
206impl<S: Scalar> DenseInterpolant<S> for DoPri5Interpolant {
207 fn interpolate(&self, segment: &DenseSegment<S>, t: S, y_out: &mut [S]) {
208 let theta = segment.theta(t);
209 let h = segment.h();
210 let dim = segment.dim;
211
212 for i in 0..dim {
215 let y0 = segment.coeffs[i];
216 let d0 = segment.coeffs[dim + i];
217 let d1 = segment.coeffs[2 * dim + i];
218 let d2 = segment.coeffs[3 * dim + i];
219 let d3 = segment.coeffs[4 * dim + i];
220 let d4 = segment.coeffs[5 * dim + i];
221
222 let poly = d0 + theta * (d1 + theta * (d2 + theta * (d3 + theta * d4)));
224 y_out[i] = y0 + h * theta * poly;
225 }
226 }
227
228 fn interpolate_derivative(&self, segment: &DenseSegment<S>, t: S, dydt_out: &mut [S]) {
229 let theta = segment.theta(t);
230 let dim = segment.dim;
231
232 for i in 0..dim {
236 let d0 = segment.coeffs[dim + i];
237 let d1 = segment.coeffs[2 * dim + i];
238 let d2 = segment.coeffs[3 * dim + i];
239 let d3 = segment.coeffs[4 * dim + i];
240 let d4 = segment.coeffs[5 * dim + i];
241
242 let two = S::from_f64(2.0);
243 let three = S::from_f64(3.0);
244 let four = S::from_f64(4.0);
245 let five = S::from_f64(5.0);
246
247 let theta2 = theta * theta;
248 let theta3 = theta2 * theta;
249 let theta4 = theta3 * theta;
250
251 dydt_out[i] = d0
252 + two * theta * d1
253 + three * theta2 * d2
254 + four * theta3 * d3
255 + five * theta4 * d4;
256 }
257 }
258}
259
260impl DoPri5Interpolant {
261 pub fn build_coefficients<S: Scalar>(y0: &[S], y1: &[S], k: &[S], h: S, dim: usize) -> Vec<S> {
272 let mut coeffs = vec![S::ZERO; 6 * dim];
276
277 for i in 0..dim {
279 coeffs[i] = y0[i];
280 }
281
282 let k1 = &k[0..dim];
284 let _k2 = &k[dim..2 * dim];
285 let k3 = &k[2 * dim..3 * dim];
286 let k4 = &k[3 * dim..4 * dim];
287 let k5 = &k[4 * dim..5 * dim];
288 let _k6 = &k[5 * dim..6 * dim];
289 let k7 = &k[6 * dim..7 * dim];
290
291 for i in 0..dim {
294 let d0 = k1[i];
296
297 let ydiff = y1[i] - y0[i];
308 let bspl = h * k1[i] - ydiff;
309
310 let d1 = ydiff - h * k1[i];
312
313 let d2 = S::from_f64(2.0) * bspl - h * (k7[i] - k1[i]);
316
317 let d3 = -S::from_f64(2.0) * bspl
319 + h * (k7[i] - k1[i])
320 + h * (S::from_f64(-5.0 / 3.0) * k1[i]
321 + S::from_f64(1.0 / 3.0) * k3[i]
322 + S::from_f64(1.0 / 3.0) * k4[i]
323 + S::from_f64(-1.0 / 3.0) * k5[i]
324 + S::from_f64(5.0 / 3.0) * k7[i]);
325
326 let d4 = S::ZERO;
329
330 coeffs[dim + i] = d0;
331 coeffs[2 * dim + i] = d1;
332 coeffs[3 * dim + i] = d2;
333 coeffs[4 * dim + i] = d3;
334 coeffs[5 * dim + i] = d4;
335 }
336
337 coeffs
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons below.
    const TOL: f64 = 1e-10;

    /// `true` when `a` and `b` agree to within [`TOL`].
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < TOL
    }

    /// A one-component segment without interpolation data (enough for geometry/lookup tests).
    fn empty_segment(t_start: f64, t_end: f64) -> DenseSegment<f64> {
        DenseSegment::new(t_start, t_end, Vec::new(), 1)
    }

    #[test]
    fn test_dense_segment_contains() {
        let seg = empty_segment(0.0, 1.0);
        for &t in &[0.0, 0.5, 1.0] {
            assert!(seg.contains(t), "{} should lie inside [0, 1]", t);
        }
        for &t in &[-0.1, 1.1] {
            assert!(!seg.contains(t), "{} should lie outside [0, 1]", t);
        }
    }

    #[test]
    fn test_dense_segment_theta() {
        let seg = empty_segment(1.0, 2.0);
        let cases = [(1.0, 0.0), (1.5, 0.5), (2.0, 1.0)];
        for &(t, expected) in cases.iter() {
            assert!(close(seg.theta(t), expected), "theta({}) != {}", t, expected);
        }
    }

    #[test]
    fn test_dense_output_find_segment() {
        let mut dense = DenseOutput::<f64>::new(1, 1.0);
        for k in 0u8..3 {
            let t0 = f64::from(k);
            dense.add_segment(empty_segment(t0, t0 + 1.0));
        }

        for &t in &[0.5, 1.5, 2.5] {
            assert!(dense.find_segment(t).is_some(), "no segment found for t = {}", t);
        }
        for &t in &[-0.5, 3.5] {
            assert!(dense.find_segment(t).is_none(), "unexpected segment for t = {}", t);
        }
    }

    #[test]
    fn test_dopri5_interpolant_endpoints() {
        let y0 = vec![1.0];
        let y1 = vec![2.0];
        let h = 1.0;
        let k = vec![1.0; 7];

        let coeffs = DoPri5Interpolant::build_coefficients(&y0, &y1, &k, h, 1);
        let seg = DenseSegment::new(0.0, 1.0, coeffs, 1);
        let interp = DoPri5Interpolant;

        let mut y_start = vec![0.0];
        let mut y_end = vec![0.0];
        interp.interpolate(&seg, 0.0, &mut y_start);
        // The end point is evaluated to exercise the full polynomial; only the start
        // value is compared against the step data here.
        interp.interpolate(&seg, 1.0, &mut y_end);

        assert!(close(y_start[0], y0[0]));
    }

    #[test]
    fn test_dense_output_tspan() {
        let mut dense = DenseOutput::<f64>::new(1, 1.0);
        assert!(dense.tspan().is_none());

        dense.add_segment(empty_segment(0.0, 1.0));
        dense.add_segment(empty_segment(1.0, 2.0));

        let (t0, tf) = dense.tspan().expect("two segments were added");
        assert!(close(t0, 0.0));
        assert!(close(tf, 2.0));
    }
}