1use crate::kernels::PronyKernel;
22use crate::system::{IdeOptions, IdeResult, IdeStats};
23use numra_core::Scalar;
24
25pub trait PronySystem<S: Scalar> {
29 fn dim(&self) -> usize;
31
32 fn rhs(&self, t: S, y: &[S], f: &mut [S]);
34
35 fn kernel(&self) -> &PronyKernel<S>;
37
38 fn coupling(&self) -> Option<Vec<Vec<S>>> {
45 None
46 }
47}
48
49pub struct PronySolver;
55
56impl PronySolver {
57 pub fn solve<S: Scalar, Sys: PronySystem<S>>(
59 system: &Sys,
60 t0: S,
61 tf: S,
62 y0: &[S],
63 options: &IdeOptions<S>,
64 ) -> Result<IdeResult<S>, String> {
65 let dim = system.dim();
66 let kernel = system.kernel();
67 let n_terms = kernel.num_terms();
68
69 if y0.len() != dim {
70 return Err(format!(
71 "Initial state dimension {} doesn't match system dimension {}",
72 y0.len(),
73 dim
74 ));
75 }
76
77 let dt = options.dt;
78 let n_steps = ((tf - t0) / dt).to_f64().ceil() as usize;
79
80 if n_steps > options.max_steps {
81 return Err(format!(
82 "Required steps {} exceeds maximum {}",
83 n_steps, options.max_steps
84 ));
85 }
86
87 let mut y = y0.to_vec();
89 let mut integrals: Vec<Vec<S>> = vec![vec![S::ZERO; dim]; n_terms];
90
91 let mut t_out = vec![t0];
92 let mut y_out = y0.to_vec();
93 let mut stats = IdeStats::default();
94
95 let mut t = t0;
96 let mut f_buf = vec![S::ZERO; dim];
97
98 let coupling = system.coupling();
100
101 let half = S::from_f64(0.5);
102 let sixth = S::ONE / S::from_f64(6.0);
103 let two = S::from_f64(2.0);
104
105 for _n in 1..=n_steps {
106 let t_new = t + dt;
107
108 let (k1_y, k1_i) =
112 compute_derivatives(system, t, &y, &integrals, &coupling, &mut f_buf, &mut stats);
113
114 let y_mid1: Vec<S> = y
116 .iter()
117 .zip(k1_y.iter())
118 .map(|(&yi, &ki)| yi + half * dt * ki)
119 .collect();
120 let i_mid1: Vec<Vec<S>> = integrals
121 .iter()
122 .zip(k1_i.iter())
123 .map(|(ii, ki)| {
124 ii.iter()
125 .zip(ki.iter())
126 .map(|(&ij, &kij)| ij + half * dt * kij)
127 .collect()
128 })
129 .collect();
130 let (k2_y, k2_i) = compute_derivatives(
131 system,
132 t + half * dt,
133 &y_mid1,
134 &i_mid1,
135 &coupling,
136 &mut f_buf,
137 &mut stats,
138 );
139
140 let y_mid2: Vec<S> = y
142 .iter()
143 .zip(k2_y.iter())
144 .map(|(&yi, &ki)| yi + half * dt * ki)
145 .collect();
146 let i_mid2: Vec<Vec<S>> = integrals
147 .iter()
148 .zip(k2_i.iter())
149 .map(|(ii, ki)| {
150 ii.iter()
151 .zip(ki.iter())
152 .map(|(&ij, &kij)| ij + half * dt * kij)
153 .collect()
154 })
155 .collect();
156 let (k3_y, k3_i) = compute_derivatives(
157 system,
158 t + half * dt,
159 &y_mid2,
160 &i_mid2,
161 &coupling,
162 &mut f_buf,
163 &mut stats,
164 );
165
166 let y_end: Vec<S> = y
168 .iter()
169 .zip(k3_y.iter())
170 .map(|(&yi, &ki)| yi + dt * ki)
171 .collect();
172 let i_end: Vec<Vec<S>> = integrals
173 .iter()
174 .zip(k3_i.iter())
175 .map(|(ii, ki)| {
176 ii.iter()
177 .zip(ki.iter())
178 .map(|(&ij, &kij)| ij + dt * kij)
179 .collect()
180 })
181 .collect();
182 let (k4_y, k4_i) = compute_derivatives(
183 system,
184 t + dt,
185 &y_end,
186 &i_end,
187 &coupling,
188 &mut f_buf,
189 &mut stats,
190 );
191
192 for i in 0..dim {
194 y[i] += sixth * dt * (k1_y[i] + two * k2_y[i] + two * k3_y[i] + k4_y[i]);
195 }
196
197 for k in 0..n_terms {
199 for i in 0..dim {
200 integrals[k][i] += sixth
201 * dt
202 * (k1_i[k][i] + two * k2_i[k][i] + two * k3_i[k][i] + k4_i[k][i]);
203 }
204 }
205
206 t_out.push(t_new);
208 y_out.extend_from_slice(&y);
209 stats.n_steps += 1;
210
211 t = t_new;
212 }
213
214 Ok(IdeResult::new(t_out, y_out, dim, stats))
215 }
216}
217
218fn compute_derivatives<S: Scalar, Sys: PronySystem<S>>(
220 system: &Sys,
221 t: S,
222 y: &[S],
223 integrals: &[Vec<S>],
224 coupling: &Option<Vec<Vec<S>>>,
225 f_buf: &mut [S],
226 stats: &mut IdeStats,
227) -> (Vec<S>, Vec<Vec<S>>) {
228 let dim = y.len();
229 let kernel = system.kernel();
230 let n_terms = kernel.num_terms();
231
232 system.rhs(t, y, f_buf);
234 stats.n_rhs += 1;
235
236 let mut dy = f_buf.to_vec();
238
239 if let Some(c) = coupling {
240 for i in 0..dim {
242 for k in 0..n_terms {
243 dy[i] += c[i][k] * integrals[k][i];
244 }
245 }
246 } else {
247 for i in 0..dim {
249 for integral in integrals.iter().take(n_terms) {
250 dy[i] += integral[i];
251 }
252 }
253 }
254
255 let mut di: Vec<Vec<S>> = Vec::with_capacity(n_terms);
257 for (k, integral) in integrals.iter().enumerate().take(n_terms) {
258 let a_k = kernel.amplitudes[k];
259 let b_k = kernel.rates[k];
260 let mut di_k = vec![S::ZERO; dim];
261 for i in 0..dim {
262 di_k[i] = a_k * y[i] - b_k * integral[i];
263 }
264 di.push(di_k);
265 stats.n_kernel += dim;
266 }
267
268 (dy, di)
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar relaxation `y' = -k y + I` with a single kernel term `a e^{-b t}`.
    struct Viscoelastic {
        k: f64,
        kernel: PronyKernel<f64>,
    }

    impl Viscoelastic {
        fn new(k: f64, a: f64, b: f64) -> Self {
            Self {
                k,
                kernel: PronyKernel::single(a, b),
            }
        }
    }

    impl PronySystem<f64> for Viscoelastic {
        fn dim(&self) -> usize {
            1
        }

        fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
            f[0] = -self.k * y[0];
        }

        fn kernel(&self) -> &PronyKernel<f64> {
            &self.kernel
        }
    }

    /// Closed-form `y(t)` for `y' = -k y + I`, `I' = a y - b I`, `y(0) = 1`,
    /// `I(0) = 0`, i.e. the linear system `[[-k, 1], [a, -b]]`.
    /// Requires `(k - b)^2 + 4a > 0` (distinct real eigenvalues).
    fn viscoelastic_exact(k: f64, a: f64, b: f64, t: f64) -> f64 {
        let trace = -(k + b);
        let det = k * b - a;
        let disc = (trace * trace - 4.0 * det).sqrt();
        let (l1, l2) = ((trace + disc) / 2.0, (trace - disc) / 2.0);
        // Coefficients from y(0) = c1 + c2 = 1 and y'(0) = c1 l1 + c2 l2 = -k.
        let c1 = (-k - l2) / (l1 - l2);
        c1 * (l1 * t).exp() + (1.0 - c1) * (l2 * t).exp()
    }

    #[test]
    fn test_prony_viscoelastic() {
        let (k, a, b) = (1.0, 0.5, 0.3);
        let system = Viscoelastic::new(k, a, b);
        let options = IdeOptions::default().dt(0.01);

        let result = PronySolver::solve(&system, 0.0, 2.0, &[1.0], &options).expect("Solve failed");

        assert!(result.success);

        let y_final = result.y_final().unwrap()[0];
        assert!(y_final > 0.0, "Solution should remain positive");
        assert!(y_final < 1.0, "Solution should decay");

        // Without memory the solution is e^{-k t}; positive memory slows decay.
        let memoryless = (-k * 2.0_f64).exp();
        assert!(
            y_final > memoryless,
            "Memory should slow decay: y_final = {}",
            y_final
        );

        let exact = viscoelastic_exact(k, a, b, 2.0);
        assert!(
            (y_final - exact).abs() < 1e-6,
            "RK4 result {} should match closed form {}",
            y_final,
            exact
        );
    }

    #[test]
    fn test_prony_two_term() {
        struct TwoTermMaxwell {
            kernel: PronyKernel<f64>,
        }

        impl PronySystem<f64> for TwoTermMaxwell {
            fn dim(&self) -> usize {
                1
            }

            fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
                f[0] = -2.0 * y[0];
            }

            fn kernel(&self) -> &PronyKernel<f64> {
                &self.kernel
            }
        }

        let system = TwoTermMaxwell {
            kernel: PronyKernel::two_term(0.8, 0.5, 0.4, 2.0),
        };
        let options = IdeOptions::default().dt(0.01);

        let result = PronySolver::solve(&system, 0.0, 3.0, &[1.0], &options).expect("Solve failed");

        assert!(result.success);

        for (i, &t) in result.t.iter().enumerate() {
            let y = result.y_at(i)[0];
            assert!(y.is_finite(), "Solution should be finite at t={}", t);
            assert!(
                y >= 0.0 || y.abs() < 0.1,
                "Solution should be non-negative or small negative at t={}",
                t
            );
        }
    }

    #[test]
    fn test_prony_2d_system() {
        struct TwoDProny {
            kernel: PronyKernel<f64>,
        }

        impl PronySystem<f64> for TwoDProny {
            fn dim(&self) -> usize {
                2
            }

            fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
                f[0] = -y[0] + 0.1 * y[1];
                f[1] = -0.5 * y[1];
            }

            fn kernel(&self) -> &PronyKernel<f64> {
                &self.kernel
            }
        }

        let system = TwoDProny {
            kernel: PronyKernel::single(0.3, 0.5),
        };
        let options = IdeOptions::default().dt(0.01);

        let result =
            PronySolver::solve(&system, 0.0, 2.0, &[1.0, 1.0], &options).expect("Solve failed");

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert_eq!(y_final.len(), 2);
        assert!(
            y_final.iter().all(|v| v.is_finite()),
            "Solution should be finite: {:?}",
            y_final
        );

        // Component 1 does not depend on component 0 (and without coupling its
        // memory integral only sees y[1]), so it is the scalar viscoelastic
        // problem with k = 0.5, a = 0.3, b = 0.5.
        let exact = viscoelastic_exact(0.5, 0.3, 0.5, 2.0);
        assert!(
            (y_final[1] - exact).abs() < 1e-6,
            "Decoupled component {} should match closed form {}",
            y_final[1],
            exact
        );
    }

    #[test]
    fn test_prony_efficiency() {
        let system = Viscoelastic::new(1.0, 0.5, 0.3);

        let options_short = IdeOptions::default().dt(0.01);
        let result_short = PronySolver::solve(&system, 0.0, 1.0, &[1.0], &options_short)
            .expect("Short solve failed");

        let options_long = IdeOptions::default().dt(0.01);
        let result_long = PronySolver::solve(&system, 0.0, 10.0, &[1.0], &options_long)
            .expect("Long solve failed");

        assert!(result_short.success);
        assert!(result_long.success);

        // 10x the interval at the same dt should cost ~10x the kernel work;
        // a history-scanning method would grow quadratically (~100x).
        let ratio = result_long.stats.n_kernel as f64 / result_short.stats.n_kernel as f64;
        assert!(
            (ratio - 10.0).abs() < 0.5,
            "Kernel evals should scale linearly: ratio = {}",
            ratio
        );
    }

    #[test]
    fn test_prony_dimension_mismatch() {
        let system = Viscoelastic::new(1.0, 0.5, 0.3);
        let options = IdeOptions::default().dt(0.01);

        let result = PronySolver::solve(&system, 0.0, 1.0, &[1.0, 2.0], &options);
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(msg.contains("dimension"), "Error message: {}", msg);
    }

    #[test]
    fn test_prony_max_steps_exceeded() {
        let system = Viscoelastic::new(1.0, 0.5, 0.3);
        let options = IdeOptions::default().dt(0.001).max_steps(5);

        let result = PronySolver::solve(&system, 0.0, 1.0, &[1.0], &options);
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(msg.contains("exceeds maximum"), "Error message: {}", msg);
    }

    #[test]
    fn test_prony_zero_kernel() {
        struct PureDecay {
            kernel: PronyKernel<f64>,
        }

        impl PronySystem<f64> for PureDecay {
            fn dim(&self) -> usize {
                1
            }

            fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
                f[0] = -y[0];
            }

            fn kernel(&self) -> &PronyKernel<f64> {
                &self.kernel
            }
        }

        // Zero amplitude: the memory integral stays at zero, leaving y' = -y.
        let system = PureDecay {
            kernel: PronyKernel::single(0.0, 1.0),
        };
        let options = IdeOptions::default().dt(0.001);

        let result = PronySolver::solve(&system, 0.0, 1.0, &[1.0], &options).expect("Solve failed");

        let y_final = result.y_final().unwrap()[0];
        let expected = (-1.0_f64).exp();
        assert!(
            (y_final - expected).abs() < 1e-4,
            "Zero kernel Prony should match pure ODE: got {}, expected {}",
            y_final,
            expected
        );
    }
}