1use numra_core::Scalar;
29
30use crate::error::SolverError;
31
32pub struct TEvalEmitter<'a, S: Scalar> {
37 points: &'a [S],
38 idx: usize,
39 direction: S,
40}
41
42impl<'a, S: Scalar> TEvalEmitter<'a, S> {
43 pub fn new(points: &'a [S], direction: S) -> Self {
44 Self {
45 points,
46 idx: 0,
47 direction,
48 }
49 }
50
51 pub fn is_done(&self) -> bool {
53 self.idx >= self.points.len()
54 }
55
56 pub fn emit_step(
60 &mut self,
61 t_old: S,
62 y_old: &[S],
63 dy_old: &[S],
64 t_new: S,
65 y_new: &[S],
66 dy_new: &[S],
67 t_out: &mut Vec<S>,
68 y_out: &mut Vec<S>,
69 ) {
70 let dim = y_old.len();
71 let h = t_new - t_old;
72 if h == S::ZERO {
75 while self.idx < self.points.len() && self.points[self.idx] == t_old {
76 t_out.push(t_old);
77 y_out.extend_from_slice(y_old);
78 self.idx += 1;
79 }
80 return;
81 }
82
83 while self.idx < self.points.len() {
84 let t_q = self.points[self.idx];
85 let in_step = if self.direction > S::ZERO {
86 t_q >= t_old && t_q <= t_new
87 } else {
88 t_q <= t_old && t_q >= t_new
89 };
90 if !in_step {
91 break;
92 }
93
94 let theta = (t_q - t_old) / h;
95 let theta2 = theta * theta;
96 let theta3 = theta2 * theta;
97 let three = S::from_f64(3.0);
98 let h00 = S::TWO * theta3 - three * theta2 + S::ONE;
99 let h10 = theta3 - S::TWO * theta2 + theta;
100 let h01 = -S::TWO * theta3 + three * theta2;
101 let h11 = theta3 - theta2;
102
103 t_out.push(t_q);
104 for i in 0..dim {
105 y_out.push(
106 h00 * y_old[i] + h10 * h * dy_old[i] + h01 * y_new[i] + h11 * h * dy_new[i],
107 );
108 }
109 self.idx += 1;
110 }
111 }
112}
113
114pub fn validate_grid<S: Scalar>(grid: &[S], t0: S, tf: S) -> Result<(), SolverError> {
119 if grid.is_empty() {
120 return Ok(());
121 }
122 let direction = if tf >= t0 { S::ONE } else { -S::ONE };
123 let (lo, hi) = if direction > S::ZERO {
124 (t0, tf)
125 } else {
126 (tf, t0)
127 };
128 let span = (tf - t0).abs();
131 let tol = S::EPSILON * S::from_f64(16.0) * (span + S::ONE);
132 for window in grid.windows(2) {
133 let d = window[1] - window[0];
134 if d * direction < -tol {
135 return Err(SolverError::Other(
136 "t_eval must be sorted in the direction of integration".into(),
137 ));
138 }
139 }
140 for &t in grid {
141 if t < lo - tol || t > hi + tol {
142 return Err(SolverError::Other(format!(
143 "t_eval contains {} which lies outside [t0, tf] = [{}, {}]",
144 t.to_f64(),
145 t0.to_f64(),
146 tf.to_f64()
147 )));
148 }
149 }
150 Ok(())
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_rejects_out_of_range() {
        assert!(validate_grid::<f64>(&[0.0, 1.0, 5.0], 0.0, 4.0).is_err());
        assert!(validate_grid::<f64>(&[-1.0, 0.0, 1.0], 0.0, 4.0).is_err());
    }

    #[test]
    fn validate_rejects_unsorted() {
        assert!(validate_grid::<f64>(&[0.0, 2.0, 1.0], 0.0, 4.0).is_err());
        // Backward integration expects a non-increasing grid.
        assert!(validate_grid::<f64>(&[4.0, 1.0, 2.0], 4.0, 0.0).is_err());
    }

    #[test]
    fn validate_accepts_descending_for_backward() {
        assert!(validate_grid::<f64>(&[4.0, 3.0, 2.0, 1.0, 0.0], 4.0, 0.0).is_ok());
    }

    #[test]
    fn validate_accepts_empty_grid() {
        assert!(validate_grid::<f64>(&[], 0.0, 4.0).is_ok());
    }

    #[test]
    fn validate_tolerates_endpoint_rounding() {
        // A grid whose last point rounds slightly past tf is still accepted.
        assert!(validate_grid::<f64>(&[0.0, 0.5, 1.0 + 1e-15], 0.0, 1.0).is_ok());
    }

    #[test]
    fn emit_reproduces_endpoints_exactly() {
        let grid = vec![0.0, 0.5, 1.0];
        // y(t) = t: the Hermite interpolant is exact for linear data.
        let mut emitter = TEvalEmitter::new(&grid, 1.0_f64);

        let mut t_out = Vec::new();
        let mut y_out = Vec::new();

        emitter.emit_step(
            0.0,
            &[0.0],
            &[1.0],
            1.0,
            &[1.0],
            &[1.0],
            &mut t_out,
            &mut y_out,
        );

        assert_eq!(t_out, vec![0.0, 0.5, 1.0]);
        assert!((y_out[0] - 0.0).abs() < 1e-15);
        assert!((y_out[1] - 0.5).abs() < 1e-15);
        assert!((y_out[2] - 1.0).abs() < 1e-15);
        assert!(emitter.is_done());
    }

    #[test]
    fn emit_advances_across_step_boundary_without_double_count() {
        let grid = vec![1.0, 2.0];
        // The point at 1.0 sits on the shared boundary of two steps and must be emitted once.
        let mut emitter = TEvalEmitter::new(&grid, 1.0_f64);
        let mut t_out = Vec::new();
        let mut y_out = Vec::new();

        // Step [0, 1].
        emitter.emit_step(
            0.0,
            &[0.0],
            &[1.0],
            1.0,
            &[1.0],
            &[1.0],
            &mut t_out,
            &mut y_out,
        );
        assert_eq!(t_out.len(), 1);
        // Step [1, 2].
        emitter.emit_step(
            1.0,
            &[1.0],
            &[1.0],
            2.0,
            &[2.0],
            &[1.0],
            &mut t_out,
            &mut y_out,
        );
        assert_eq!(t_out, vec![1.0, 2.0]);
    }

    #[test]
    fn emit_handles_backward_direction() {
        let grid = vec![1.0, 0.5, 0.0];
        let mut emitter = TEvalEmitter::new(&grid, -1.0_f64);
        let mut t_out = Vec::new();
        let mut y_out = Vec::new();

        emitter.emit_step(
            1.0,
            &[1.0],
            &[1.0],
            0.0,
            &[0.0],
            &[1.0],
            &mut t_out,
            &mut y_out,
        );
        assert_eq!(t_out, vec![1.0, 0.5, 0.0]);
        assert!((y_out[1] - 0.5).abs() < 1e-15);
    }

    #[test]
    fn emit_zero_length_step_emits_coincident_points() {
        let grid = vec![0.0, 0.0, 0.5];
        let mut emitter = TEvalEmitter::new(&grid, 1.0_f64);
        let mut t_out = Vec::new();
        let mut y_out = Vec::new();

        // A zero-length step copies y_old for every pending point at t_old.
        emitter.emit_step(
            0.0,
            &[2.0],
            &[9.0],
            0.0,
            &[2.0],
            &[9.0],
            &mut t_out,
            &mut y_out,
        );
        assert_eq!(t_out, vec![0.0, 0.0]);
        assert_eq!(y_out, vec![2.0, 2.0]);
        assert!(!emitter.is_done());

        emitter.emit_step(
            0.0,
            &[0.0],
            &[1.0],
            1.0,
            &[1.0],
            &[1.0],
            &mut t_out,
            &mut y_out,
        );
        assert_eq!(t_out, vec![0.0, 0.0, 0.5]);
        assert!(emitter.is_done());
    }

    #[test]
    fn emit_interpolates_each_component() {
        let grid = vec![0.25];
        let mut emitter = TEvalEmitter::new(&grid, 1.0_f64);
        let mut t_out = Vec::new();
        let mut y_out = Vec::new();

        // y(t) = (t, 2t): one row of two values per emitted time.
        emitter.emit_step(
            0.0,
            &[0.0, 0.0],
            &[1.0, 2.0],
            1.0,
            &[1.0, 2.0],
            &[1.0, 2.0],
            &mut t_out,
            &mut y_out,
        );
        assert_eq!(t_out, vec![0.25]);
        assert_eq!(y_out.len(), 2);
        assert!((y_out[0] - 0.25).abs() < 1e-15);
        assert!((y_out[1] - 0.5).abs() < 1e-15);
    }
}