Skip to main content

numra_ide/
prony.rs

//! Prony series (sum of exponentials) solver for IDEs.
//!
//! For kernels of the form K(τ) = Σᵢ aᵢ exp(-bᵢ τ), the memory integral
//! can be computed recursively without storing the full history. Each term
//!
//! ```text
//! Iᵢ(t) = ∫₀ᵗ aᵢ exp(-bᵢ(t-s)) y(s) ds
//! ```
//!
//! satisfies the ODE:
//!
//! ```text
//! I'ᵢ = aᵢ y - bᵢ Iᵢ,  Iᵢ(0) = 0
//! ```
//!
//! This reduces memory from O(N) to O(M), where N is the number of time steps
//! and M is the number of exponential terms.
//!
//! Author: Moussa Leblouba
//! Date: 5 March 2026
//! Modified: 2 May 2026
20
21use crate::kernels::PronyKernel;
22use crate::system::{IdeOptions, IdeResult, IdeStats};
23use numra_core::Scalar;
24
25/// System for Prony solver: y' = f(t, y) + Σᵢ Iᵢ.
26///
27/// The kernel is implicitly defined by the Prony series.
28pub trait PronySystem<S: Scalar> {
29    /// Dimension of the state space.
30    fn dim(&self) -> usize;
31
32    /// Evaluate the local right-hand side f(t, y).
33    fn rhs(&self, t: S, y: &[S], f: &mut [S]);
34
35    /// Get the Prony kernel for computing memory effects.
36    fn kernel(&self) -> &PronyKernel<S>;
37
38    /// Optional coupling matrix: how each kernel term affects the state.
39    ///
40    /// By default, assumes kernel affects all components equally: I(t) * y(t).
41    /// For more complex coupling, override this method.
42    ///
43    /// Returns: `coupling[i][k]` = coefficient for how I_k affects dy_i/dt
44    fn coupling(&self) -> Option<Vec<Vec<S>>> {
45        None
46    }
47}
48
49/// Efficient solver for IDEs with sum-of-exponentials (Prony) kernels.
50///
51/// Uses the recursive property of exponential integrals to avoid
52/// storing full solution history. Memory complexity is O(dim × n_terms)
53/// instead of O(dim × n_steps).
54pub struct PronySolver;
55
56impl PronySolver {
57    /// Solve an IDE with Prony kernel.
58    pub fn solve<S: Scalar, Sys: PronySystem<S>>(
59        system: &Sys,
60        t0: S,
61        tf: S,
62        y0: &[S],
63        options: &IdeOptions<S>,
64    ) -> Result<IdeResult<S>, String> {
65        let dim = system.dim();
66        let kernel = system.kernel();
67        let n_terms = kernel.num_terms();
68
69        if y0.len() != dim {
70            return Err(format!(
71                "Initial state dimension {} doesn't match system dimension {}",
72                y0.len(),
73                dim
74            ));
75        }
76
77        let dt = options.dt;
78        let n_steps = ((tf - t0) / dt).to_f64().ceil() as usize;
79
80        if n_steps > options.max_steps {
81            return Err(format!(
82                "Required steps {} exceeds maximum {}",
83                n_steps, options.max_steps
84            ));
85        }
86
87        // State: y and auxiliary integrals I[k][i] for each term k and dimension i
88        let mut y = y0.to_vec();
89        let mut integrals: Vec<Vec<S>> = vec![vec![S::ZERO; dim]; n_terms];
90
91        let mut t_out = vec![t0];
92        let mut y_out = y0.to_vec();
93        let mut stats = IdeStats::default();
94
95        let mut t = t0;
96        let mut f_buf = vec![S::ZERO; dim];
97
98        // Get coupling matrix or use default
99        let coupling = system.coupling();
100
101        let half = S::from_f64(0.5);
102        let sixth = S::ONE / S::from_f64(6.0);
103        let two = S::from_f64(2.0);
104
105        for _n in 1..=n_steps {
106            let t_new = t + dt;
107
108            // RK4 for the extended system [y, I_1, I_2, ...]
109
110            // Stage 1: k1
111            let (k1_y, k1_i) =
112                compute_derivatives(system, t, &y, &integrals, &coupling, &mut f_buf, &mut stats);
113
114            // Stage 2: k2 at t + dt/2
115            let y_mid1: Vec<S> = y
116                .iter()
117                .zip(k1_y.iter())
118                .map(|(&yi, &ki)| yi + half * dt * ki)
119                .collect();
120            let i_mid1: Vec<Vec<S>> = integrals
121                .iter()
122                .zip(k1_i.iter())
123                .map(|(ii, ki)| {
124                    ii.iter()
125                        .zip(ki.iter())
126                        .map(|(&ij, &kij)| ij + half * dt * kij)
127                        .collect()
128                })
129                .collect();
130            let (k2_y, k2_i) = compute_derivatives(
131                system,
132                t + half * dt,
133                &y_mid1,
134                &i_mid1,
135                &coupling,
136                &mut f_buf,
137                &mut stats,
138            );
139
140            // Stage 3: k3 at t + dt/2
141            let y_mid2: Vec<S> = y
142                .iter()
143                .zip(k2_y.iter())
144                .map(|(&yi, &ki)| yi + half * dt * ki)
145                .collect();
146            let i_mid2: Vec<Vec<S>> = integrals
147                .iter()
148                .zip(k2_i.iter())
149                .map(|(ii, ki)| {
150                    ii.iter()
151                        .zip(ki.iter())
152                        .map(|(&ij, &kij)| ij + half * dt * kij)
153                        .collect()
154                })
155                .collect();
156            let (k3_y, k3_i) = compute_derivatives(
157                system,
158                t + half * dt,
159                &y_mid2,
160                &i_mid2,
161                &coupling,
162                &mut f_buf,
163                &mut stats,
164            );
165
166            // Stage 4: k4 at t + dt
167            let y_end: Vec<S> = y
168                .iter()
169                .zip(k3_y.iter())
170                .map(|(&yi, &ki)| yi + dt * ki)
171                .collect();
172            let i_end: Vec<Vec<S>> = integrals
173                .iter()
174                .zip(k3_i.iter())
175                .map(|(ii, ki)| {
176                    ii.iter()
177                        .zip(ki.iter())
178                        .map(|(&ij, &kij)| ij + dt * kij)
179                        .collect()
180                })
181                .collect();
182            let (k4_y, k4_i) = compute_derivatives(
183                system,
184                t + dt,
185                &y_end,
186                &i_end,
187                &coupling,
188                &mut f_buf,
189                &mut stats,
190            );
191
192            // RK4 combination for y
193            for i in 0..dim {
194                y[i] += sixth * dt * (k1_y[i] + two * k2_y[i] + two * k3_y[i] + k4_y[i]);
195            }
196
197            // RK4 combination for integrals
198            for k in 0..n_terms {
199                for i in 0..dim {
200                    integrals[k][i] += sixth
201                        * dt
202                        * (k1_i[k][i] + two * k2_i[k][i] + two * k3_i[k][i] + k4_i[k][i]);
203                }
204            }
205
206            // Store output
207            t_out.push(t_new);
208            y_out.extend_from_slice(&y);
209            stats.n_steps += 1;
210
211            t = t_new;
212        }
213
214        Ok(IdeResult::new(t_out, y_out, dim, stats))
215    }
216}
217
218/// Compute derivatives for y and all integral terms.
219fn compute_derivatives<S: Scalar, Sys: PronySystem<S>>(
220    system: &Sys,
221    t: S,
222    y: &[S],
223    integrals: &[Vec<S>],
224    coupling: &Option<Vec<Vec<S>>>,
225    f_buf: &mut [S],
226    stats: &mut IdeStats,
227) -> (Vec<S>, Vec<Vec<S>>) {
228    let dim = y.len();
229    let kernel = system.kernel();
230    let n_terms = kernel.num_terms();
231
232    // Compute local RHS
233    system.rhs(t, y, f_buf);
234    stats.n_rhs += 1;
235
236    // Derivative of y: f(t,y) + sum of integrals
237    let mut dy = f_buf.to_vec();
238
239    if let Some(c) = coupling {
240        // Custom coupling: dy[i] += sum_k c[i][k] * I_k[i]
241        for i in 0..dim {
242            for k in 0..n_terms {
243                dy[i] += c[i][k] * integrals[k][i];
244            }
245        }
246    } else {
247        // Default: dy[i] += sum_k I_k[i]
248        for i in 0..dim {
249            for integral in integrals.iter().take(n_terms) {
250                dy[i] += integral[i];
251            }
252        }
253    }
254
255    // Derivative of integrals: dI_k[i]/dt = a_k * y[i] - b_k * I_k[i]
256    let mut di: Vec<Vec<S>> = Vec::with_capacity(n_terms);
257    for (k, integral) in integrals.iter().enumerate().take(n_terms) {
258        let a_k = kernel.amplitudes[k];
259        let b_k = kernel.rates[k];
260        let mut di_k = vec![S::ZERO; dim];
261        for i in 0..dim {
262            di_k[i] = a_k * y[i] - b_k * integral[i];
263        }
264        di.push(di_k);
265        stats.n_kernel += dim;
266    }
267
268    (dy, di)
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Viscoelastic material: y' = -k*y + ∫₀ᵗ a*exp(-b*(t-s)) * y(s) ds
    struct Viscoelastic {
        k: f64,
        kernel: PronyKernel<f64>,
    }

    impl Viscoelastic {
        fn new(k: f64, a: f64, b: f64) -> Self {
            Self {
                k,
                kernel: PronyKernel::single(a, b),
            }
        }
    }

    impl PronySystem<f64> for Viscoelastic {
        fn dim(&self) -> usize {
            1
        }

        fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
            f[0] = -self.k * y[0];
        }

        fn kernel(&self) -> &PronyKernel<f64> {
            &self.kernel
        }
    }

    #[test]
    fn test_prony_viscoelastic() {
        let system = Viscoelastic::new(1.0, 0.5, 0.3);
        let options = IdeOptions::default().dt(0.01);

        let result = PronySolver::solve(&system, 0.0, 2.0, &[1.0], &options).expect("Solve failed");

        assert!(result.success);

        // Solution should be smooth and bounded
        let y_final = result.y_final().unwrap()[0];
        assert!(y_final > 0.0, "Solution should remain positive");
        assert!(y_final < 1.0, "Solution should decay");

        // Memory effect should slow down decay compared to pure exponential
        // Pure y' = -y gives y(2) = exp(-2) ≈ 0.135
        // With memory integral adding back, should be higher
        assert!(
            y_final > 0.135,
            "Memory should slow decay: y_final = {}",
            y_final
        );
    }

    #[test]
    fn test_prony_two_term() {
        // Two-term Prony series (generalized Maxwell model)
        struct TwoTermMaxwell {
            kernel: PronyKernel<f64>,
        }

        impl PronySystem<f64> for TwoTermMaxwell {
            fn dim(&self) -> usize {
                1
            }

            fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
                f[0] = -2.0 * y[0]; // Strong local damping
            }

            fn kernel(&self) -> &PronyKernel<f64> {
                &self.kernel
            }
        }

        let system = TwoTermMaxwell {
            kernel: PronyKernel::two_term(0.8, 0.5, 0.4, 2.0),
        };
        let options = IdeOptions::default().dt(0.01);

        let result = PronySolver::solve(&system, 0.0, 3.0, &[1.0], &options).expect("Solve failed");

        assert!(result.success);

        // Check solution at various points
        for (i, &t) in result.t.iter().enumerate() {
            let y = result.y_at(i)[0];
            assert!(y.is_finite(), "Solution should be finite at t={}", t);
            assert!(
                y >= 0.0 || y.abs() < 0.1,
                "Solution should be non-negative or small negative at t={}",
                t
            );
        }
    }

    #[test]
    fn test_prony_2d_system() {
        struct TwoDProny {
            kernel: PronyKernel<f64>,
        }

        impl PronySystem<f64> for TwoDProny {
            fn dim(&self) -> usize {
                2
            }

            fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
                f[0] = -y[0] + 0.1 * y[1];
                f[1] = -0.5 * y[1];
            }

            fn kernel(&self) -> &PronyKernel<f64> {
                &self.kernel
            }
        }

        let system = TwoDProny {
            kernel: PronyKernel::single(0.3, 0.5),
        };
        let options = IdeOptions::default().dt(0.01);

        let result =
            PronySolver::solve(&system, 0.0, 2.0, &[1.0, 1.0], &options).expect("Solve failed");

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert_eq!(y_final.len(), 2);

        // All feedback terms are non-negative, so both components stay positive.
        for (i, &value) in y_final.iter().enumerate() {
            assert!(
                value.is_finite() && value > 0.0,
                "Component {} should be finite and positive, got {}",
                i,
                value
            );
        }

        // The second component does not depend on the first. With its memory
        // integral I it obeys y' = -0.5*y + I, I' = 0.3*y - 0.5*I, y(0) = 1,
        // I(0) = 0, whose solution is y(t) = exp(-0.5*t) * cosh(sqrt(0.3)*t).
        let t_end = 2.0_f64;
        let expected_y1 = (-0.5 * t_end).exp() * (0.3_f64.sqrt() * t_end).cosh();
        assert!(
            (y_final[1] - expected_y1).abs() < 1e-6,
            "Second component should match closed form: got {}, expected {}",
            y_final[1],
            expected_y1
        );
    }

    #[test]
    fn test_prony_efficiency() {
        // Work per step must not grow with the length of the history.
        let system = Viscoelastic::new(1.0, 0.5, 0.3);
        let options = IdeOptions::default().dt(0.01);

        let result_short =
            PronySolver::solve(&system, 0.0, 1.0, &[1.0], &options).expect("Short solve failed");
        let result_long =
            PronySolver::solve(&system, 0.0, 10.0, &[1.0], &options).expect("Long solve failed");

        // Both should succeed - Prony doesn't accumulate memory errors
        assert!(result_short.success);
        assert!(result_long.success);

        // Kernel evaluations scale linearly with steps, not quadratically
        // (Unlike full quadrature which is O(n²))
        let ratio = result_long.stats.n_kernel as f64 / result_short.stats.n_kernel as f64;
        // Ratio should be close to 10 (time ratio), not 100 (quadratic)
        assert!(
            ratio < 15.0,
            "Kernel evals should scale linearly: ratio = {}",
            ratio
        );
    }

    #[test]
    fn test_prony_dimension_mismatch() {
        let system = Viscoelastic::new(1.0, 0.5, 0.3);
        let options = IdeOptions::default().dt(0.01);

        // y0 has 2 elements but system dim is 1
        let result = PronySolver::solve(&system, 0.0, 1.0, &[1.0, 2.0], &options);
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(msg.contains("dimension"), "Error message: {}", msg);
    }

    #[test]
    fn test_prony_max_steps_exceeded() {
        let system = Viscoelastic::new(1.0, 0.5, 0.3);
        // Very small dt with tiny max_steps to trigger error
        let options = IdeOptions::default().dt(0.001).max_steps(5);

        let result = PronySolver::solve(&system, 0.0, 1.0, &[1.0], &options);
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(msg.contains("exceeds maximum"), "Error message: {}", msg);
    }

    #[test]
    fn test_prony_zero_kernel() {
        // Kernel with amplitude 0 => no memory => pure ODE y' = -y
        struct PureDecay {
            kernel: PronyKernel<f64>,
        }

        impl PronySystem<f64> for PureDecay {
            fn dim(&self) -> usize {
                1
            }

            fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
                f[0] = -y[0];
            }

            fn kernel(&self) -> &PronyKernel<f64> {
                &self.kernel
            }
        }

        let system = PureDecay {
            kernel: PronyKernel::single(0.0, 1.0),
        };
        let options = IdeOptions::default().dt(0.001);

        let result = PronySolver::solve(&system, 0.0, 1.0, &[1.0], &options).expect("Solve failed");

        let y_final = result.y_final().unwrap()[0];
        let expected = (-1.0_f64).exp(); // exp(-1) ≈ 0.3679
        assert!(
            (y_final - expected).abs() < 1e-4,
            "Zero kernel Prony should match pure ODE: got {}, expected {}",
            y_final,
            expected
        );
    }
}