1use crate::system::{IdeOptions, IdeResult, IdeSolver, IdeStats, IdeSystem};
16use numra_core::Scalar;
17
/// Fixed-step explicit Euler-type solver for Volterra integro-differential equations
/// `y'(t) = f(t, y) + ∫ K(t, s, y(s)) ds`.
///
/// The memory integral is approximated with trapezoidal weights over the solution
/// history stored at previous grid points.
pub struct VolterraSolver;
24
25impl<S: Scalar> IdeSolver<S> for VolterraSolver {
26 fn solve<Sys: IdeSystem<S>>(
27 system: &Sys,
28 t0: S,
29 tf: S,
30 y0: &[S],
31 options: &IdeOptions<S>,
32 ) -> Result<IdeResult<S>, String> {
33 let dim = system.dim();
34
35 if y0.len() != dim {
36 return Err(format!(
37 "Initial state dimension {} doesn't match system dimension {}",
38 y0.len(),
39 dim
40 ));
41 }
42
43 let dt = options.dt;
44 let n_steps = ((tf - t0) / dt).to_f64().ceil() as usize;
45
46 if n_steps > options.max_steps {
47 return Err(format!(
48 "Required steps {} exceeds maximum {}",
49 n_steps, options.max_steps
50 ));
51 }
52
53 let mut t_history: Vec<S> = Vec::with_capacity(n_steps + 1);
55 let mut y_history: Vec<Vec<S>> = Vec::with_capacity(n_steps + 1);
56
57 t_history.push(t0);
58 y_history.push(y0.to_vec());
59
60 let mut t_out = vec![t0];
61 let mut y_out = y0.to_vec();
62 let mut stats = IdeStats::default();
63
64 let mut f_buf = vec![S::ZERO; dim];
65 let mut k_buf = vec![S::ZERO; dim];
66
67 let mut t = t0;
68 let mut y = y0.to_vec();
69
70 for n in 1..=n_steps {
71 let t_new = t0 + S::from_usize(n) * dt;
72
73 let mut integral = vec![S::ZERO; dim];
75 compute_integral(
76 system,
77 t_new,
78 &t_history,
79 &y_history,
80 &mut integral,
81 &mut k_buf,
82 &mut stats,
83 );
84
85 system.rhs(t, &y, &mut f_buf);
87 stats.n_rhs += 1;
88
89 let mut y_new = vec![S::ZERO; dim];
91 for i in 0..dim {
92 y_new[i] = y[i] + dt * (f_buf[i] + integral[i]);
93 }
94
95 t_history.push(t_new);
97 y_history.push(y_new.clone());
98
99 t_out.push(t_new);
101 y_out.extend_from_slice(&y_new);
102 stats.n_steps += 1;
103
104 t = t_new;
105 y = y_new;
106 }
107
108 Ok(IdeResult::new(t_out, y_out, dim, stats))
109 }
110}
111
112fn compute_integral<S: Scalar, Sys: IdeSystem<S>>(
114 system: &Sys,
115 t: S,
116 t_history: &[S],
117 y_history: &[Vec<S>],
118 integral: &mut [S],
119 k_buf: &mut [S],
120 stats: &mut IdeStats,
121) {
122 let dim = integral.len();
123 let n = t_history.len();
124
125 if n < 2 {
126 return;
128 }
129
130 for item in integral.iter_mut().take(dim) {
132 *item = S::ZERO;
133 }
134
135 for j in 0..n {
136 let s = t_history[j];
137 let y_s = &y_history[j];
138
139 system.kernel(t, s, y_s, k_buf);
140 stats.n_kernel += 1;
141
142 let weight = if j == 0 || j == n - 1 {
144 S::from_f64(0.5)
145 } else {
146 S::ONE
147 };
148
149 let dt = if j < n - 1 {
151 t_history[j + 1] - s
152 } else if j > 0 {
153 s - t_history[j - 1]
154 } else {
155 S::ZERO
156 };
157
158 for i in 0..dim {
159 integral[i] += weight * dt * k_buf[i];
160 }
161 }
162}
163
/// Fixed-step solver for Volterra integro-differential equations that combines the
/// classical four-stage Runge–Kutta update with trapezoidal quadrature of the memory
/// integral over the stored solution history.
///
/// NOTE(review): the trapezoidal quadrature is second-order, so the overall convergence
/// order is bounded by it rather than being fourth order.
pub struct VolterraRK4Solver;
166
167impl<S: Scalar> IdeSolver<S> for VolterraRK4Solver {
168 fn solve<Sys: IdeSystem<S>>(
169 system: &Sys,
170 t0: S,
171 tf: S,
172 y0: &[S],
173 options: &IdeOptions<S>,
174 ) -> Result<IdeResult<S>, String> {
175 let dim = system.dim();
176
177 if y0.len() != dim {
178 return Err(format!(
179 "Initial state dimension {} doesn't match system dimension {}",
180 y0.len(),
181 dim
182 ));
183 }
184
185 let dt = options.dt;
186 let n_steps = ((tf - t0) / dt).to_f64().ceil() as usize;
187
188 if n_steps > options.max_steps {
189 return Err(format!(
190 "Required steps {} exceeds maximum {}",
191 n_steps, options.max_steps
192 ));
193 }
194
195 let mut t_history: Vec<S> = Vec::with_capacity(n_steps + 1);
197 let mut y_history: Vec<Vec<S>> = Vec::with_capacity(n_steps + 1);
198
199 t_history.push(t0);
200 y_history.push(y0.to_vec());
201
202 let mut t_out = vec![t0];
203 let mut y_out = y0.to_vec();
204 let mut stats = IdeStats::default();
205
206 let mut t = t0;
207 let mut y = y0.to_vec();
208
209 let half = S::from_f64(0.5);
210 let sixth = S::ONE / S::from_f64(6.0);
211 let two = S::from_f64(2.0);
212
213 for n in 1..=n_steps {
214 let t_new = t0 + S::from_usize(n) * dt;
215
216 let k1 = compute_derivative(system, t, &y, &t_history, &y_history, &mut stats);
218
219 let y_mid1: Vec<S> = y
221 .iter()
222 .zip(k1.iter())
223 .map(|(&yi, &ki)| yi + half * dt * ki)
224 .collect();
225 let k2 = compute_derivative(
226 system,
227 t + half * dt,
228 &y_mid1,
229 &t_history,
230 &y_history,
231 &mut stats,
232 );
233
234 let y_mid2: Vec<S> = y
236 .iter()
237 .zip(k2.iter())
238 .map(|(&yi, &ki)| yi + half * dt * ki)
239 .collect();
240 let k3 = compute_derivative(
241 system,
242 t + half * dt,
243 &y_mid2,
244 &t_history,
245 &y_history,
246 &mut stats,
247 );
248
249 let y_end: Vec<S> = y
251 .iter()
252 .zip(k3.iter())
253 .map(|(&yi, &ki)| yi + dt * ki)
254 .collect();
255 let k4 = compute_derivative(system, t + dt, &y_end, &t_history, &y_history, &mut stats);
256
257 let mut y_new = vec![S::ZERO; dim];
259 for i in 0..dim {
260 y_new[i] = y[i] + sixth * dt * (k1[i] + two * k2[i] + two * k3[i] + k4[i]);
261 }
262
263 t_history.push(t_new);
265 y_history.push(y_new.clone());
266
267 t_out.push(t_new);
269 y_out.extend_from_slice(&y_new);
270 stats.n_steps += 1;
271
272 t = t_new;
273 y = y_new;
274 }
275
276 Ok(IdeResult::new(t_out, y_out, dim, stats))
277 }
278}
279
280fn compute_derivative<S: Scalar, Sys: IdeSystem<S>>(
282 system: &Sys,
283 t: S,
284 y: &[S],
285 t_history: &[S],
286 y_history: &[Vec<S>],
287 stats: &mut IdeStats,
288) -> Vec<S> {
289 let dim = y.len();
290 let mut f_buf = vec![S::ZERO; dim];
291 let mut k_buf = vec![S::ZERO; dim];
292 let mut integral = vec![S::ZERO; dim];
293
294 system.rhs(t, y, &mut f_buf);
296 stats.n_rhs += 1;
297
298 compute_integral(
300 system,
301 t,
302 t_history,
303 y_history,
304 &mut integral,
305 &mut k_buf,
306 stats,
307 );
308
309 for i in 0..dim {
311 f_buf[i] += integral[i];
312 }
313
314 f_buf
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// `y' = -y + ∫_0^t e^{-(t-s)} y(s) ds` with a convolution kernel.
    ///
    /// With `z(t) = ∫_0^t e^{-(t-s)} y(s) ds` one has `z' = y - z`, so `y + z` is constant.
    /// For `y(0) = 1` the exact solution is `y(t) = 1/2 + e^{-2t}/2` (see [`simple_exact`]).
    struct SimpleIde;

    impl IdeSystem<f64> for SimpleIde {
        fn dim(&self) -> usize {
            1
        }

        fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
            f[0] = -y[0];
        }

        fn kernel(&self, t: f64, s: f64, y_s: &[f64], k: &mut [f64]) {
            k[0] = (-(t - s)).exp() * y_s[0];
        }

        fn is_convolution_kernel(&self) -> bool {
            true
        }
    }

    /// Exact solution of [`SimpleIde`] with `y(0) = 1`.
    fn simple_exact(t: f64) -> f64 {
        0.5 + 0.5 * (-2.0 * t).exp()
    }

    #[test]
    fn test_volterra_simple() {
        let options = IdeOptions::default().dt(0.01);
        let result =
            VolterraSolver::solve(&SimpleIde, 0.0, 1.0, &[1.0], &options).expect("Solve failed");

        assert!(result.success);
        assert!(result.t.len() > 1);

        let y_final = result.y_final().unwrap()[0];
        assert!(y_final > 0.0 && y_final < 2.0);
    }

    #[test]
    fn test_volterra_rk4_more_accurate() {
        let options = IdeOptions::default().dt(0.05);

        let euler_result = VolterraSolver::solve(&SimpleIde, 0.0, 1.0, &[1.0], &options)
            .expect("Euler solve failed");
        let rk4_result = VolterraRK4Solver::solve(&SimpleIde, 0.0, 1.0, &[1.0], &options)
            .expect("RK4 solve failed");

        let y_euler = euler_result.y_final().unwrap()[0];
        let y_rk4 = rk4_result.y_final().unwrap()[0];

        assert!((y_euler - y_rk4).abs() < 0.1);
        assert!((y_euler - y_rk4).abs() > 1e-6);
    }

    #[test]
    fn test_volterra_matches_analytic_solution() {
        let options = IdeOptions::default().dt(0.01);
        let expected = simple_exact(1.0);

        let euler = VolterraSolver::solve(&SimpleIde, 0.0, 1.0, &[1.0], &options)
            .expect("Euler solve failed");
        let y_euler = euler.y_final().unwrap()[0];
        assert!(
            (y_euler - expected).abs() < 0.05,
            "Euler: got {}, expected {}",
            y_euler,
            expected
        );

        let rk4 = VolterraRK4Solver::solve(&SimpleIde, 0.0, 1.0, &[1.0], &options)
            .expect("RK4 solve failed");
        let y_rk4 = rk4.y_final().unwrap()[0];
        assert!(
            (y_rk4 - expected).abs() < 0.02,
            "RK4: got {}, expected {}",
            y_rk4,
            expected
        );
    }

    /// Two weakly coupled, decaying components with exponentially fading memory.
    struct TwoDIde;

    impl IdeSystem<f64> for TwoDIde {
        fn dim(&self) -> usize {
            2
        }

        fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
            f[0] = -y[0] + 0.1 * y[1];
            f[1] = -y[1];
        }

        fn kernel(&self, t: f64, s: f64, y_s: &[f64], k: &mut [f64]) {
            let decay = (-(t - s)).exp();
            k[0] = 0.5 * decay * y_s[0];
            k[1] = 0.2 * decay * y_s[1];
        }
    }

    #[test]
    fn test_volterra_2d() {
        let options = IdeOptions::default().dt(0.01);
        let result =
            VolterraSolver::solve(&TwoDIde, 0.0, 1.0, &[1.0, 1.0], &options).expect("Solve failed");

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert_eq!(y_final.len(), 2);

        assert!(y_final[0] < 1.0);
        assert!(y_final[1] < 1.0);
    }

    #[test]
    fn test_dimension_mismatch() {
        let options = IdeOptions::default();
        let result = VolterraSolver::solve(&SimpleIde, 0.0, 1.0, &[1.0, 2.0], &options);
        assert!(result.is_err());
    }

    #[test]
    fn test_volterra_rk4_2d() {
        let options = IdeOptions::default().dt(0.01);
        let result = VolterraRK4Solver::solve(&TwoDIde, 0.0, 1.0, &[1.0, 1.0], &options)
            .expect("RK4 2D solve failed");

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert_eq!(y_final.len(), 2);

        assert!(y_final[0] < 1.0, "y[0] should decay: {}", y_final[0]);
        assert!(y_final[1] < 1.0, "y[1] should decay: {}", y_final[1]);
        assert!(y_final[0] > 0.0, "y[0] should remain positive");
        assert!(y_final[1] > 0.0, "y[1] should remain positive");
    }

    #[test]
    fn test_volterra_max_steps_exceeded() {
        // 1000 steps are needed but only 5 are allowed.
        let options = IdeOptions::default().dt(0.001).max_steps(5);
        let result = VolterraSolver::solve(&SimpleIde, 0.0, 1.0, &[1.0], &options);
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(msg.contains("exceeds maximum"), "Error message: {}", msg);
    }

    #[test]
    fn test_volterra_zero_kernel() {
        /// With a zero kernel the IDE reduces to the ODE `y' = -y`.
        struct ZeroKernelIde;

        impl IdeSystem<f64> for ZeroKernelIde {
            fn dim(&self) -> usize {
                1
            }

            fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
                f[0] = -y[0];
            }

            fn kernel(&self, _t: f64, _s: f64, _y_s: &[f64], k: &mut [f64]) {
                k[0] = 0.0;
            }
        }

        let options = IdeOptions::default().dt(0.001);
        let result = VolterraRK4Solver::solve(&ZeroKernelIde, 0.0, 1.0, &[1.0], &options)
            .expect("Solve failed");

        let y_final = result.y_final().unwrap()[0];
        let expected = (-1.0_f64).exp();
        assert!(
            (y_final - expected).abs() < 1e-4,
            "Zero kernel should match pure ODE: got {}, expected {}",
            y_final,
            expected
        );
    }

    #[test]
    fn test_volterra_rk4_convergence() {
        let options_coarse = IdeOptions::default().dt(0.02);
        let options_fine = IdeOptions::default().dt(0.01);
        let options_ref = IdeOptions::default().dt(0.0005);

        let result_coarse = VolterraRK4Solver::solve(&SimpleIde, 0.0, 1.0, &[1.0], &options_coarse)
            .expect("Coarse solve failed");
        let result_fine = VolterraRK4Solver::solve(&SimpleIde, 0.0, 1.0, &[1.0], &options_fine)
            .expect("Fine solve failed");
        let result_ref = VolterraRK4Solver::solve(&SimpleIde, 0.0, 1.0, &[1.0], &options_ref)
            .expect("Reference solve failed");

        let y_ref = result_ref.y_final().unwrap()[0];
        let err_coarse = (result_coarse.y_final().unwrap()[0] - y_ref).abs();
        let err_fine = (result_fine.y_final().unwrap()[0] - y_ref).abs();

        // Halving dt must shrink the error noticeably (≈2× for first order, ≈4× for second).
        let ratio = err_coarse / err_fine;
        assert!(
            ratio > 1.5,
            "RK4 solver should converge: ratio={:.2} (err_coarse={:.2e}, err_fine={:.2e})",
            ratio,
            err_coarse,
            err_fine
        );
        assert!(
            err_fine < err_coarse,
            "Finer dt should give smaller error: fine={:.2e}, coarse={:.2e}",
            err_fine,
            err_coarse
        );
    }
}