Skip to main content

numra_ocp/
param_est.rs

1//! Parameter estimation for ODE models.
2//!
3//! Given an ODE model `dy/dt = f(t, y; p)` and observed data `(t_i, y_i)`,
4//! find the parameters `p` that minimize the residual between predicted and
5//! observed states.
6//!
7//! Author: Moussa Leblouba
8//! Date: 9 February 2026
9//! Modified: 2 May 2026
10
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::Arc;
13use std::time::Instant;
14
15use numra_core::Scalar;
16use numra_ode::{DoPri5, OdeProblem, Solver, SolverOptions};
17use numra_optim::OptimProblem;
18
19use crate::error::OcpError;
20
/// ODE model closure: `(t, y, dydt, params)`.
///
/// The closure receives the time `t`, the current state slice `y` and the
/// parameter slice `params`, and writes the time derivative into `dydt`.
/// It is used as a trait object and must be `Send + Sync`.
type ModelFn<S> = dyn Fn(S, &[S], &mut [S], &[S]) + Send + Sync;
23
24// ---------------------------------------------------------------------------
25// Types
26// ---------------------------------------------------------------------------
27
/// Which ODE solver to use for forward integrations.
///
/// `DoPri5` is currently the only variant and is also the `Default`.
#[derive(Clone, Debug, Default)]
pub enum OdeSolverChoice {
    /// Dormand-Prince 5(4) explicit method (non-stiff).
    #[default]
    DoPri5,
}
35
/// Result of a parameter estimation run.
///
/// Returned by `ParamEstProblem::solve`.
#[derive(Clone, Debug)]
pub struct ParamEstResult<S: Scalar> {
    /// Estimated parameters (the optimizer's final point).
    pub params: Vec<S>,
    /// Final residual norm (L2): square root of the summed squared
    /// differences between `predicted` and the measurement data.
    pub residual_norm: S,
    /// Optimizer iterations, as reported by the optimizer.
    pub iterations: usize,
    /// Whether the optimizer converged, as reported by the optimizer.
    pub converged: bool,
    /// Human-readable status message from the optimizer.
    pub message: String,
    /// Predicted observations at the data times (flat row-major):
    /// `predicted[i * n_observed + j]` is observed state `j` at time index `i`.
    pub predicted: Vec<S>,
    /// Total number of ODE integrations performed: one per objective or
    /// residual evaluation, plus the final integration at the estimated
    /// parameters. Each integration spans the whole data time grid.
    pub n_integrations: usize,
    /// Wall-clock time in seconds, measured over the whole `solve` call.
    pub wall_time_secs: f64,
}
56
57// ---------------------------------------------------------------------------
58// Builder
59// ---------------------------------------------------------------------------
60
/// Builder for ODE parameter estimation problems.
///
/// Configure the problem with the chained setter methods, then call `solve`.
/// The model, the initial state, an initial parameter guess and measurement
/// data are required; everything else has a default set in `new`.
pub struct ParamEstProblem<S: Scalar> {
    /// Number of parameters to estimate.
    n_params: usize,
    /// Dimension of the ODE state vector.
    n_states: usize,
    /// ODE right-hand side; `None` until `model` is called.
    model: Option<Box<ModelFn<S>>>,
    /// Initial state; `None` until `initial_state` is called.
    y0: Option<Vec<S>>,
    /// Initial parameter guess; `None` until `params` is called.
    params0: Option<Vec<S>>,
    /// Optional `(lower, upper)` bounds, one slot per parameter.
    param_bounds: Vec<Option<(S, S)>>,
    /// Measurement times.
    t_data: Vec<S>,
    /// Measurements, flat row-major: `y_data[i * n_observed + j]`.
    y_data: Vec<S>,
    /// Indices of the observed states; `None` means all states are observed.
    observed_indices: Option<Vec<usize>>,
    /// ODE solver used for forward integrations.
    solver: OdeSolverChoice,
    /// Relative tolerance passed to the ODE solver (default `1e-8`).
    ode_rtol: S,
    /// Absolute tolerance passed to the ODE solver (default `1e-10`).
    ode_atol: S,
    /// Maximum number of optimizer iterations (default `100`).
    max_iter: usize,
}
77
78impl<S: Scalar> ParamEstProblem<S> {
79    /// Create a new parameter estimation problem.
80    ///
81    /// - `n_params`: number of parameters to estimate.
82    /// - `n_states`: dimension of the ODE state vector.
83    pub fn new(n_params: usize, n_states: usize) -> Self {
84        Self {
85            n_params,
86            n_states,
87            model: None,
88            y0: None,
89            params0: None,
90            param_bounds: vec![None; n_params],
91            t_data: Vec::new(),
92            y_data: Vec::new(),
93            observed_indices: None,
94            solver: OdeSolverChoice::default(),
95            ode_rtol: S::from_f64(1e-8),
96            ode_atol: S::from_f64(1e-10),
97            max_iter: 100,
98        }
99    }
100
101    /// Set the ODE right-hand side: `f(t, y, dydt, params)`.
102    pub fn model<F>(mut self, f: F) -> Self
103    where
104        F: Fn(S, &[S], &mut [S], &[S]) + Send + Sync + 'static,
105    {
106        self.model = Some(Box::new(f));
107        self
108    }
109
110    /// Set the initial state `y(t0)`.
111    pub fn initial_state(mut self, y0: Vec<S>) -> Self {
112        self.y0 = Some(y0);
113        self
114    }
115
116    /// Set the initial parameter guess.
117    pub fn params(mut self, p0: Vec<S>) -> Self {
118        self.params0 = Some(p0);
119        self
120    }
121
122    /// Set bounds for parameter `i`.
123    pub fn param_bounds(mut self, i: usize, bounds: (S, S)) -> Self {
124        self.param_bounds[i] = Some(bounds);
125        self
126    }
127
128    /// Set bounds for all parameters at once.
129    pub fn all_param_bounds(mut self, bounds: Vec<Option<(S, S)>>) -> Self {
130        self.param_bounds = bounds;
131        self
132    }
133
134    /// Set measurement data.
135    ///
136    /// `y_data` is flat row-major: `y_data[i * n_observed + j]` for time
137    /// index `i` and observed state `j`.
138    pub fn data(mut self, t_data: Vec<S>, y_data: Vec<S>) -> Self {
139        self.t_data = t_data;
140        self.y_data = y_data;
141        self
142    }
143
144    /// Specify which state indices are observed.
145    ///
146    /// If not called, all states are assumed observed.
147    pub fn observed(mut self, indices: Vec<usize>) -> Self {
148        self.observed_indices = Some(indices);
149        self
150    }
151
152    /// Choose the ODE solver for forward integrations.
153    pub fn ode_solver(mut self, choice: OdeSolverChoice) -> Self {
154        self.solver = choice;
155        self
156    }
157
158    /// Set ODE solver tolerances.
159    pub fn ode_tolerances(mut self, rtol: S, atol: S) -> Self {
160        self.ode_rtol = rtol;
161        self.ode_atol = atol;
162        self
163    }
164
165    /// Set maximum optimizer iterations.
166    pub fn max_iter(mut self, n: usize) -> Self {
167        self.max_iter = n;
168        self
169    }
170
171    // -----------------------------------------------------------------------
172    // Solve
173    // -----------------------------------------------------------------------
174
175    /// Run parameter estimation.
176    pub fn solve(self) -> Result<ParamEstResult<S>, OcpError>
177    where
178        S: faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
179    {
180        let start = Instant::now();
181
182        // -- Validate -------------------------------------------------------
183        let model = self.model.ok_or(OcpError::NoModel)?;
184        let y0 = self.y0.ok_or(OcpError::NoInitialState)?;
185        let params0 = self
186            .params0
187            .ok_or(OcpError::Other("no initial parameter guess".to_string()))?;
188        if self.t_data.is_empty() || self.y_data.is_empty() {
189            return Err(OcpError::NoData);
190        }
191        if y0.len() != self.n_states {
192            return Err(OcpError::DimensionMismatch(format!(
193                "y0 length {} != n_states {}",
194                y0.len(),
195                self.n_states
196            )));
197        }
198        if params0.len() != self.n_params {
199            return Err(OcpError::DimensionMismatch(format!(
200                "params0 length {} != n_params {}",
201                params0.len(),
202                self.n_params
203            )));
204        }
205
206        let obs_idx: Vec<usize> = self
207            .observed_indices
208            .unwrap_or_else(|| (0..self.n_states).collect());
209        let n_observed = obs_idx.len();
210        let n_data = self.t_data.len();
211        let n_residuals = n_data * n_observed;
212
213        if self.y_data.len() != n_residuals {
214            return Err(OcpError::DimensionMismatch(format!(
215                "y_data length {} != n_data({}) * n_observed({})",
216                self.y_data.len(),
217                n_data,
218                n_observed,
219            )));
220        }
221
222        // -- Shared state ---------------------------------------------------
223        let model = Arc::new(model);
224        let y0 = Arc::new(y0);
225        let t_data = Arc::new(self.t_data);
226        let y_data = Arc::new(self.y_data);
227        let obs_idx = Arc::new(obs_idx);
228        let n_states = self.n_states;
229        let ode_rtol = self.ode_rtol;
230        let ode_atol = self.ode_atol;
231        let counter = Arc::new(AtomicUsize::new(0));
232        let has_bounds = self.param_bounds.iter().any(|b| b.is_some());
233
234        // -- Optimize -------------------------------------------------------
235        let optim_result = if has_bounds {
236            // L-BFGS-B with scalar sum-of-squares objective.
237            let m = Arc::clone(&model);
238            let y0c = Arc::clone(&y0);
239            let td = Arc::clone(&t_data);
240            let yd = Arc::clone(&y_data);
241            let oi = Arc::clone(&obs_idx);
242            let ctr = Arc::clone(&counter);
243
244            let mut prob = OptimProblem::new(self.n_params)
245                .x0(&params0)
246                .objective(move |p: &[S]| {
247                    let pred = integrate_at_params(&m, &y0c, &td, p, n_states, ode_rtol, ode_atol);
248                    ctr.fetch_add(1, Ordering::Relaxed);
249                    let mut sos = S::ZERO;
250                    for i in 0..td.len() {
251                        for (j, &idx) in oi.iter().enumerate() {
252                            let r = pred[i * n_states + idx] - yd[i * oi.len() + j];
253                            sos += r * r;
254                        }
255                    }
256                    sos
257                })
258                .max_iter(self.max_iter);
259
260            for (i, b) in self.param_bounds.iter().enumerate() {
261                if let Some(&(lo, hi)) = b.as_ref() {
262                    prob = prob.bounds(i, (lo, hi));
263                }
264            }
265            prob.solve().map_err(OcpError::OptimFailed)?
266        } else {
267            // LM (least squares).
268            let m = Arc::clone(&model);
269            let y0c = Arc::clone(&y0);
270            let td = Arc::clone(&t_data);
271            let yd = Arc::clone(&y_data);
272            let oi = Arc::clone(&obs_idx);
273            let ctr = Arc::clone(&counter);
274
275            OptimProblem::new(self.n_params)
276                .x0(&params0)
277                .least_squares(n_residuals, move |p: &[S], r: &mut [S]| {
278                    let pred = integrate_at_params(&m, &y0c, &td, p, n_states, ode_rtol, ode_atol);
279                    ctr.fetch_add(1, Ordering::Relaxed);
280                    for i in 0..td.len() {
281                        for (j, &idx) in oi.iter().enumerate() {
282                            r[i * oi.len() + j] = pred[i * n_states + idx] - yd[i * oi.len() + j];
283                        }
284                    }
285                })
286                .max_iter(self.max_iter)
287                .solve()
288                .map_err(OcpError::OptimFailed)?
289        };
290
291        // -- Final integration at optimal params ----------------------------
292        let optimal_params = &optim_result.x;
293        let pred_full = integrate_at_params(
294            &model,
295            &y0,
296            &t_data,
297            optimal_params,
298            n_states,
299            ode_rtol,
300            ode_atol,
301        );
302        counter.fetch_add(1, Ordering::Relaxed);
303
304        // Extract predicted observations (only observed indices).
305        let mut predicted = Vec::with_capacity(n_residuals);
306        for i in 0..n_data {
307            for &idx in obs_idx.iter() {
308                predicted.push(pred_full[i * n_states + idx]);
309            }
310        }
311
312        // Residual norm.
313        let mut rnorm2 = S::ZERO;
314        for k in 0..n_residuals {
315            let r = predicted[k] - y_data[k];
316            rnorm2 += r * r;
317        }
318        let residual_norm = rnorm2.sqrt();
319
320        Ok(ParamEstResult {
321            params: optimal_params.clone(),
322            residual_norm,
323            iterations: optim_result.iterations,
324            converged: optim_result.converged,
325            message: optim_result.message.clone(),
326            predicted,
327            n_integrations: counter.load(Ordering::Relaxed),
328            wall_time_secs: start.elapsed().as_secs_f64(),
329        })
330    }
331}
332
333// ---------------------------------------------------------------------------
334// Helper: forward integration and flat extraction
335// ---------------------------------------------------------------------------
336
337/// Integrate the ODE at the given parameters and return the full state
338/// at each data time as a flat vector of length `n_data * n_states`.
339///
340/// We integrate segment-by-segment between successive data times so
341/// the solver lands exactly on each measurement time.
342///
343/// If the ODE integration fails, returns a vector filled with `1e10` to
344/// steer the optimizer away from this region.
345fn integrate_at_params<S: Scalar>(
346    model: &Arc<Box<ModelFn<S>>>,
347    y0: &Arc<Vec<S>>,
348    t_data: &Arc<Vec<S>>,
349    params: &[S],
350    n_states: usize,
351    rtol: S,
352    atol: S,
353) -> Vec<S> {
354    let n_data = t_data.len();
355    let total = n_data * n_states;
356
357    let options = SolverOptions::default().rtol(rtol).atol(atol);
358
359    // Store the state at each data time.
360    let mut out = Vec::with_capacity(total);
361
362    // Current state, starting from y0.
363    let mut y_cur = y0.as_ref().clone();
364
365    // First data point: the initial state itself.
366    out.extend_from_slice(&y_cur);
367
368    let big = S::from_f64(1e10);
369    let tiny = S::from_f64(1e-15);
370
371    // Integrate from t_data[i] to t_data[i+1] for each segment.
372    for i in 0..(n_data - 1) {
373        let t_start = t_data[i];
374        let t_end = t_data[i + 1];
375
376        // Skip zero-length segments.
377        if (t_end - t_start).abs() < tiny {
378            out.extend_from_slice(&y_cur);
379            continue;
380        }
381
382        let p = params.to_vec();
383        let model_ref = Arc::clone(model);
384        let rhs = move |t: S, y: &[S], dydt: &mut [S]| {
385            model_ref(t, y, dydt, &p);
386        };
387
388        let problem = OdeProblem::new(rhs, t_start, t_end, y_cur.clone());
389
390        match DoPri5::solve(&problem, t_start, t_end, &y_cur, &options) {
391            Ok(result) if result.success => {
392                // The last time point should be t_end; grab the final state.
393                if let Some(y_final) = result.y_final() {
394                    y_cur = y_final.to_vec();
395                    out.extend_from_slice(&y_cur);
396                } else {
397                    return vec![big; total];
398                }
399            }
400            _ => return vec![big; total],
401        }
402    }
403
404    out
405}
406
407// ---------------------------------------------------------------------------
408// Tests
409// ---------------------------------------------------------------------------
410
411#[cfg(test)]
412mod tests {
413    use super::*;
414
415    /// Exponential decay: dy/dt = -k*y, true k=0.5.
416    #[test]
417    fn test_exponential_decay() {
418        let k_true = 0.5;
419        let y0_val = 1.0;
420        let t_data: Vec<f64> = (0..=10).map(|i| i as f64 * 0.5).collect();
421        let y_data: Vec<f64> = t_data
422            .iter()
423            .map(|&t| y0_val * (-k_true * t).exp())
424            .collect();
425
426        let result = ParamEstProblem::new(1, 1)
427            .model(|_t: f64, y, dydt, p| {
428                dydt[0] = -p[0] * y[0];
429            })
430            .initial_state(vec![y0_val])
431            .params(vec![1.0])
432            .data(t_data, y_data)
433            .solve()
434            .expect("parameter estimation failed");
435
436        assert!(
437            result.converged,
438            "optimizer did not converge: {}",
439            result.message
440        );
441        let k_est = result.params[0];
442        assert!(
443            (k_est - k_true).abs() < 0.01,
444            "k_est = {k_est}, expected ~{k_true}"
445        );
446        assert!(
447            result.residual_norm < 1e-4,
448            "residual_norm = {}",
449            result.residual_norm
450        );
451        assert!(result.n_integrations > 0);
452    }
453
454    /// Two-parameter model: dy/dt = -a*y + b, true a=1, b=2.
455    /// Analytical: y(t) = b/a + (y0 - b/a)*exp(-a*t) = 2 + (y0-2)*exp(-t).
456    #[test]
457    fn test_two_param_model() {
458        let a_true = 1.0;
459        let b_true = 2.0;
460        let y0_val = 1.0;
461
462        let t_data: Vec<f64> = (0..=20).map(|i| i as f64 * 0.25).collect();
463        let y_data: Vec<f64> = t_data
464            .iter()
465            .map(|&t| b_true / a_true + (y0_val - b_true / a_true) * (-a_true * t).exp())
466            .collect();
467
468        let result = ParamEstProblem::new(2, 1)
469            .model(|_t: f64, y, dydt, p| {
470                dydt[0] = -p[0] * y[0] + p[1];
471            })
472            .initial_state(vec![y0_val])
473            .params(vec![0.5, 1.0])
474            .data(t_data, y_data)
475            .solve()
476            .expect("parameter estimation failed");
477
478        assert!(
479            result.converged,
480            "optimizer did not converge: {}",
481            result.message
482        );
483        assert!(
484            (result.params[0] - a_true).abs() < 0.1,
485            "a_est = {}, expected ~{a_true}",
486            result.params[0]
487        );
488        assert!(
489            (result.params[1] - b_true).abs() < 0.1,
490            "b_est = {}, expected ~{b_true}",
491            result.params[1]
492        );
493    }
494
495    /// Exponential decay with bounds: k in [0.01, 5.0].
496    #[test]
497    fn test_param_est_with_bounds() {
498        let k_true = 0.5;
499        let y0_val = 1.0;
500        let t_data: Vec<f64> = (0..=10).map(|i| i as f64 * 0.5).collect();
501        let y_data: Vec<f64> = t_data
502            .iter()
503            .map(|&t| y0_val * (-k_true * t).exp())
504            .collect();
505
506        let result = ParamEstProblem::new(1, 1)
507            .model(|_t: f64, y, dydt, p| {
508                dydt[0] = -p[0] * y[0];
509            })
510            .initial_state(vec![y0_val])
511            .params(vec![3.0])
512            .param_bounds(0, (0.01, 5.0))
513            .data(t_data, y_data)
514            .solve()
515            .expect("parameter estimation failed");
516
517        assert!(
518            result.converged,
519            "optimizer did not converge: {}",
520            result.message
521        );
522        let k_est = result.params[0];
523        assert!(
524            (k_est - k_true).abs() < 0.05,
525            "k_est = {k_est}, expected ~{k_true}"
526        );
527        assert!(
528            (0.01..=5.0).contains(&k_est),
529            "k_est out of bounds: {k_est}"
530        );
531    }
532
533    /// Partial observation: 2-state coupled system, observe only state 0.
534    ///   dx/dt = -a*x + y
535    ///   dy/dt = x - b*y
536    /// True: a=0.5, b=1.0, x0=1, y0=0.
537    /// Both parameters influence x(t) through coupling.
538    #[test]
539    fn test_partial_observation() {
540        let a_true = 0.5;
541        let b_true = 1.0;
542        let x0 = 1.0;
543        let y0_val = 0.0;
544
545        // Generate "exact" data by integrating the ODE with the true params.
546        let t_data: Vec<f64> = (0..=20).map(|i| i as f64 * 0.5).collect();
547
548        // Integrate with true parameters to get reference data.
549        let opts = numra_ode::SolverOptions::default().rtol(1e-12).atol(1e-14);
550
551        // Integrate segment by segment to get exact values at data times.
552        let mut y_data = Vec::new();
553        let mut y_cur = vec![x0, y0_val];
554        y_data.push(y_cur[0]); // state 0 at t=0
555        for i in 0..(t_data.len() - 1) {
556            let t_s = t_data[i];
557            let t_e = t_data[i + 1];
558            let prob = numra_ode::OdeProblem::new(
559                move |_t: f64, y: &[f64], dydt: &mut [f64]| {
560                    dydt[0] = -a_true * y[0] + y[1];
561                    dydt[1] = y[0] - b_true * y[1];
562                },
563                t_s,
564                t_e,
565                y_cur.clone(),
566            );
567            let res = numra_ode::DoPri5::solve(&prob, t_s, t_e, &y_cur, &opts).unwrap();
568            y_cur = res.y_final().unwrap().to_vec();
569            y_data.push(y_cur[0]); // only state 0
570        }
571
572        let result = ParamEstProblem::new(2, 2)
573            .model(|_t: f64, y, dydt, p| {
574                dydt[0] = -p[0] * y[0] + y[1];
575                dydt[1] = y[0] - p[1] * y[1];
576            })
577            .initial_state(vec![x0, y0_val])
578            .params(vec![0.8, 1.5]) // initial guess
579            .observed(vec![0]) // only observe x
580            .data(t_data, y_data)
581            .max_iter(200)
582            .solve()
583            .expect("parameter estimation failed");
584
585        assert!(
586            result.converged,
587            "optimizer did not converge: {}",
588            result.message
589        );
590        assert!(
591            (result.params[0] - a_true).abs() < 0.2,
592            "a_est = {}, expected ~{a_true}",
593            result.params[0]
594        );
595        assert!(
596            (result.params[1] - b_true).abs() < 0.2,
597            "b_est = {}, expected ~{b_true}",
598            result.params[1]
599        );
600    }
601}