1use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::Arc;
13use std::time::Instant;
14
15use numra_core::Scalar;
16use numra_ode::{DoPri5, OdeProblem, Solver, SolverOptions};
17use numra_optim::OptimProblem;
18
19use crate::error::OcpError;
20
/// Type-erased ODE right-hand side with parameters.
///
/// Invoked as `f(t, y, dydt, p)`: given time `t`, current state `y` and the
/// parameter vector `p`, it writes the state derivative into `dydt`.
/// `Send + Sync` lets the boxed closure be shared through an `Arc` by the
/// optimizer callbacks.
type ModelFn<S> = dyn Fn(S, &[S], &mut [S], &[S]) + Send + Sync;
23
/// Selects the ODE integrator used to simulate the model between samples.
///
/// `DoPri5` is the only variant and the default. Deriving `PartialEq`/`Eq`
/// lets callers compare or assert on the configured choice.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum OdeSolverChoice {
    /// Integrate with `numra_ode::DoPri5`.
    #[default]
    DoPri5,
}
35
/// Outcome of `ParamEstProblem::solve`.
///
/// The struct itself carries no trait bound on `S`. Only the code that
/// produces it needs `Scalar`, so the bound lives on that impl.
#[derive(Clone, Debug, PartialEq)]
pub struct ParamEstResult<S> {
    /// Parameter vector returned by the optimizer.
    pub params: Vec<S>,
    /// Euclidean norm of `predicted - y_data`, evaluated at `params`.
    pub residual_norm: S,
    /// Iteration count reported by the optimizer.
    pub iterations: usize,
    /// Whether the optimizer reported convergence.
    pub converged: bool,
    /// Status message reported by the optimizer.
    pub message: String,
    /// Predicted observed states at `params`, laid out like `y_data`:
    /// entry `i * n_observed + j` is the `j`-th observed state at sample `i`.
    pub predicted: Vec<S>,
    /// Number of full-trajectory integrations performed, including the final
    /// evaluation at `params`.
    pub n_integrations: usize,
    /// Wall-clock seconds spent inside `solve`.
    pub wall_time_secs: f64,
}
56
57pub struct ParamEstProblem<S: Scalar> {
63 n_params: usize,
64 n_states: usize,
65 model: Option<Box<ModelFn<S>>>,
66 y0: Option<Vec<S>>,
67 params0: Option<Vec<S>>,
68 param_bounds: Vec<Option<(S, S)>>,
69 t_data: Vec<S>,
70 y_data: Vec<S>,
71 observed_indices: Option<Vec<usize>>,
72 solver: OdeSolverChoice,
73 ode_rtol: S,
74 ode_atol: S,
75 max_iter: usize,
76}
77
78impl<S: Scalar> ParamEstProblem<S> {
79 pub fn new(n_params: usize, n_states: usize) -> Self {
84 Self {
85 n_params,
86 n_states,
87 model: None,
88 y0: None,
89 params0: None,
90 param_bounds: vec![None; n_params],
91 t_data: Vec::new(),
92 y_data: Vec::new(),
93 observed_indices: None,
94 solver: OdeSolverChoice::default(),
95 ode_rtol: S::from_f64(1e-8),
96 ode_atol: S::from_f64(1e-10),
97 max_iter: 100,
98 }
99 }
100
101 pub fn model<F>(mut self, f: F) -> Self
103 where
104 F: Fn(S, &[S], &mut [S], &[S]) + Send + Sync + 'static,
105 {
106 self.model = Some(Box::new(f));
107 self
108 }
109
110 pub fn initial_state(mut self, y0: Vec<S>) -> Self {
112 self.y0 = Some(y0);
113 self
114 }
115
116 pub fn params(mut self, p0: Vec<S>) -> Self {
118 self.params0 = Some(p0);
119 self
120 }
121
122 pub fn param_bounds(mut self, i: usize, bounds: (S, S)) -> Self {
124 self.param_bounds[i] = Some(bounds);
125 self
126 }
127
128 pub fn all_param_bounds(mut self, bounds: Vec<Option<(S, S)>>) -> Self {
130 self.param_bounds = bounds;
131 self
132 }
133
134 pub fn data(mut self, t_data: Vec<S>, y_data: Vec<S>) -> Self {
139 self.t_data = t_data;
140 self.y_data = y_data;
141 self
142 }
143
144 pub fn observed(mut self, indices: Vec<usize>) -> Self {
148 self.observed_indices = Some(indices);
149 self
150 }
151
152 pub fn ode_solver(mut self, choice: OdeSolverChoice) -> Self {
154 self.solver = choice;
155 self
156 }
157
158 pub fn ode_tolerances(mut self, rtol: S, atol: S) -> Self {
160 self.ode_rtol = rtol;
161 self.ode_atol = atol;
162 self
163 }
164
165 pub fn max_iter(mut self, n: usize) -> Self {
167 self.max_iter = n;
168 self
169 }
170
171 pub fn solve(self) -> Result<ParamEstResult<S>, OcpError>
177 where
178 S: faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
179 {
180 let start = Instant::now();
181
182 let model = self.model.ok_or(OcpError::NoModel)?;
184 let y0 = self.y0.ok_or(OcpError::NoInitialState)?;
185 let params0 = self
186 .params0
187 .ok_or(OcpError::Other("no initial parameter guess".to_string()))?;
188 if self.t_data.is_empty() || self.y_data.is_empty() {
189 return Err(OcpError::NoData);
190 }
191 if y0.len() != self.n_states {
192 return Err(OcpError::DimensionMismatch(format!(
193 "y0 length {} != n_states {}",
194 y0.len(),
195 self.n_states
196 )));
197 }
198 if params0.len() != self.n_params {
199 return Err(OcpError::DimensionMismatch(format!(
200 "params0 length {} != n_params {}",
201 params0.len(),
202 self.n_params
203 )));
204 }
205
206 let obs_idx: Vec<usize> = self
207 .observed_indices
208 .unwrap_or_else(|| (0..self.n_states).collect());
209 let n_observed = obs_idx.len();
210 let n_data = self.t_data.len();
211 let n_residuals = n_data * n_observed;
212
213 if self.y_data.len() != n_residuals {
214 return Err(OcpError::DimensionMismatch(format!(
215 "y_data length {} != n_data({}) * n_observed({})",
216 self.y_data.len(),
217 n_data,
218 n_observed,
219 )));
220 }
221
222 let model = Arc::new(model);
224 let y0 = Arc::new(y0);
225 let t_data = Arc::new(self.t_data);
226 let y_data = Arc::new(self.y_data);
227 let obs_idx = Arc::new(obs_idx);
228 let n_states = self.n_states;
229 let ode_rtol = self.ode_rtol;
230 let ode_atol = self.ode_atol;
231 let counter = Arc::new(AtomicUsize::new(0));
232 let has_bounds = self.param_bounds.iter().any(|b| b.is_some());
233
234 let optim_result = if has_bounds {
236 let m = Arc::clone(&model);
238 let y0c = Arc::clone(&y0);
239 let td = Arc::clone(&t_data);
240 let yd = Arc::clone(&y_data);
241 let oi = Arc::clone(&obs_idx);
242 let ctr = Arc::clone(&counter);
243
244 let mut prob = OptimProblem::new(self.n_params)
245 .x0(¶ms0)
246 .objective(move |p: &[S]| {
247 let pred = integrate_at_params(&m, &y0c, &td, p, n_states, ode_rtol, ode_atol);
248 ctr.fetch_add(1, Ordering::Relaxed);
249 let mut sos = S::ZERO;
250 for i in 0..td.len() {
251 for (j, &idx) in oi.iter().enumerate() {
252 let r = pred[i * n_states + idx] - yd[i * oi.len() + j];
253 sos += r * r;
254 }
255 }
256 sos
257 })
258 .max_iter(self.max_iter);
259
260 for (i, b) in self.param_bounds.iter().enumerate() {
261 if let Some(&(lo, hi)) = b.as_ref() {
262 prob = prob.bounds(i, (lo, hi));
263 }
264 }
265 prob.solve().map_err(OcpError::OptimFailed)?
266 } else {
267 let m = Arc::clone(&model);
269 let y0c = Arc::clone(&y0);
270 let td = Arc::clone(&t_data);
271 let yd = Arc::clone(&y_data);
272 let oi = Arc::clone(&obs_idx);
273 let ctr = Arc::clone(&counter);
274
275 OptimProblem::new(self.n_params)
276 .x0(¶ms0)
277 .least_squares(n_residuals, move |p: &[S], r: &mut [S]| {
278 let pred = integrate_at_params(&m, &y0c, &td, p, n_states, ode_rtol, ode_atol);
279 ctr.fetch_add(1, Ordering::Relaxed);
280 for i in 0..td.len() {
281 for (j, &idx) in oi.iter().enumerate() {
282 r[i * oi.len() + j] = pred[i * n_states + idx] - yd[i * oi.len() + j];
283 }
284 }
285 })
286 .max_iter(self.max_iter)
287 .solve()
288 .map_err(OcpError::OptimFailed)?
289 };
290
291 let optimal_params = &optim_result.x;
293 let pred_full = integrate_at_params(
294 &model,
295 &y0,
296 &t_data,
297 optimal_params,
298 n_states,
299 ode_rtol,
300 ode_atol,
301 );
302 counter.fetch_add(1, Ordering::Relaxed);
303
304 let mut predicted = Vec::with_capacity(n_residuals);
306 for i in 0..n_data {
307 for &idx in obs_idx.iter() {
308 predicted.push(pred_full[i * n_states + idx]);
309 }
310 }
311
312 let mut rnorm2 = S::ZERO;
314 for k in 0..n_residuals {
315 let r = predicted[k] - y_data[k];
316 rnorm2 += r * r;
317 }
318 let residual_norm = rnorm2.sqrt();
319
320 Ok(ParamEstResult {
321 params: optimal_params.clone(),
322 residual_norm,
323 iterations: optim_result.iterations,
324 converged: optim_result.converged,
325 message: optim_result.message.clone(),
326 predicted,
327 n_integrations: counter.load(Ordering::Relaxed),
328 wall_time_secs: start.elapsed().as_secs_f64(),
329 })
330 }
331}
332
333fn integrate_at_params<S: Scalar>(
346 model: &Arc<Box<ModelFn<S>>>,
347 y0: &Arc<Vec<S>>,
348 t_data: &Arc<Vec<S>>,
349 params: &[S],
350 n_states: usize,
351 rtol: S,
352 atol: S,
353) -> Vec<S> {
354 let n_data = t_data.len();
355 let total = n_data * n_states;
356
357 let options = SolverOptions::default().rtol(rtol).atol(atol);
358
359 let mut out = Vec::with_capacity(total);
361
362 let mut y_cur = y0.as_ref().clone();
364
365 out.extend_from_slice(&y_cur);
367
368 let big = S::from_f64(1e10);
369 let tiny = S::from_f64(1e-15);
370
371 for i in 0..(n_data - 1) {
373 let t_start = t_data[i];
374 let t_end = t_data[i + 1];
375
376 if (t_end - t_start).abs() < tiny {
378 out.extend_from_slice(&y_cur);
379 continue;
380 }
381
382 let p = params.to_vec();
383 let model_ref = Arc::clone(model);
384 let rhs = move |t: S, y: &[S], dydt: &mut [S]| {
385 model_ref(t, y, dydt, &p);
386 };
387
388 let problem = OdeProblem::new(rhs, t_start, t_end, y_cur.clone());
389
390 match DoPri5::solve(&problem, t_start, t_end, &y_cur, &options) {
391 Ok(result) if result.success => {
392 if let Some(y_final) = result.y_final() {
394 y_cur = y_final.to_vec();
395 out.extend_from_slice(&y_cur);
396 } else {
397 return vec![big; total];
398 }
399 }
400 _ => return vec![big; total],
401 }
402 }
403
404 out
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Samples `y0 * exp(-k t)` at `t = 0.0, 0.5, ..., 5.0` (11 points).
    fn sample_decay(k: f64, y0: f64) -> (Vec<f64>, Vec<f64>) {
        let times: Vec<f64> = (0..=10).map(|i| i as f64 * 0.5).collect();
        let values: Vec<f64> = times.iter().map(|&t| y0 * (-k * t).exp()).collect();
        (times, values)
    }

    /// Recovers the rate constant of `y' = -k y` from noise-free samples.
    #[test]
    fn test_exponential_decay() {
        let k_true = 0.5;
        let y0_val = 1.0;
        let (t_data, y_data) = sample_decay(k_true, y0_val);

        let fit = ParamEstProblem::new(1, 1)
            .model(|_t: f64, y, dydt, p| {
                dydt[0] = -p[0] * y[0];
            })
            .initial_state(vec![y0_val])
            .params(vec![1.0])
            .data(t_data, y_data)
            .solve()
            .expect("parameter estimation failed");

        assert!(fit.converged, "optimizer did not converge: {}", fit.message);

        let k_est = fit.params[0];
        assert!(
            (k_est - k_true).abs() < 0.01,
            "k_est = {k_est}, expected ~{k_true}"
        );
        assert!(
            fit.residual_norm < 1e-4,
            "residual_norm = {}",
            fit.residual_norm
        );
        assert!(fit.n_integrations > 0);
    }

    /// Recovers both coefficients of the affine model `y' = -a y + b`.
    #[test]
    fn test_two_param_model() {
        let a_true = 1.0;
        let b_true = 2.0;
        let y0_val = 1.0;

        let t_data: Vec<f64> = (0..=20).map(|i| i as f64 * 0.25).collect();
        let steady = b_true / a_true;
        let y_data: Vec<f64> = t_data
            .iter()
            .map(|&t| b_true / a_true + (y0_val - steady) * (-a_true * t).exp())
            .collect();

        let fit = ParamEstProblem::new(2, 1)
            .model(|_t: f64, y, dydt, p| {
                dydt[0] = -p[0] * y[0] + p[1];
            })
            .initial_state(vec![y0_val])
            .params(vec![0.5, 1.0])
            .data(t_data, y_data)
            .solve()
            .expect("parameter estimation failed");

        assert!(fit.converged, "optimizer did not converge: {}", fit.message);
        assert!(
            (fit.params[0] - a_true).abs() < 0.1,
            "a_est = {}, expected ~{a_true}",
            fit.params[0]
        );
        assert!(
            (fit.params[1] - b_true).abs() < 0.1,
            "b_est = {}, expected ~{b_true}",
            fit.params[1]
        );
    }

    /// Same decay problem, solved through the bounded-objective path.
    #[test]
    fn test_param_est_with_bounds() {
        let k_true = 0.5;
        let y0_val = 1.0;
        let (t_data, y_data) = sample_decay(k_true, y0_val);

        let fit = ParamEstProblem::new(1, 1)
            .model(|_t: f64, y, dydt, p| {
                dydt[0] = -p[0] * y[0];
            })
            .initial_state(vec![y0_val])
            .params(vec![3.0])
            .param_bounds(0, (0.01, 5.0))
            .data(t_data, y_data)
            .solve()
            .expect("parameter estimation failed");

        assert!(fit.converged, "optimizer did not converge: {}", fit.message);

        let k_est = fit.params[0];
        assert!(
            (k_est - k_true).abs() < 0.05,
            "k_est = {k_est}, expected ~{k_true}"
        );
        assert!(
            (0.01..=5.0).contains(&k_est),
            "k_est out of bounds: {k_est}"
        );
    }

    /// Two-state linear system where only the first state is measured.
    #[test]
    fn test_partial_observation() {
        let a_true = 0.5;
        let b_true = 1.0;
        let x0 = 1.0;
        let y0_val = 0.0;

        let t_data: Vec<f64> = (0..=20).map(|i| i as f64 * 0.5).collect();

        // Reference data is generated at much tighter tolerances than the fit uses.
        let opts = numra_ode::SolverOptions::default().rtol(1e-12).atol(1e-14);

        let mut state = vec![x0, y0_val];
        let mut y_data = vec![state[0]];
        for span in t_data.windows(2) {
            let (t_s, t_e) = (span[0], span[1]);
            let prob = numra_ode::OdeProblem::new(
                move |_t: f64, y: &[f64], dydt: &mut [f64]| {
                    dydt[0] = -a_true * y[0] + y[1];
                    dydt[1] = y[0] - b_true * y[1];
                },
                t_s,
                t_e,
                state.clone(),
            );
            let res = numra_ode::DoPri5::solve(&prob, t_s, t_e, &state, &opts).unwrap();
            state = res.y_final().unwrap().to_vec();
            y_data.push(state[0]);
        }

        let fit = ParamEstProblem::new(2, 2)
            .model(|_t: f64, y, dydt, p| {
                dydt[0] = -p[0] * y[0] + y[1];
                dydt[1] = y[0] - p[1] * y[1];
            })
            .initial_state(vec![x0, y0_val])
            .params(vec![0.8, 1.5])
            .observed(vec![0])
            .data(t_data, y_data)
            .max_iter(200)
            .solve()
            .expect("parameter estimation failed");

        assert!(fit.converged, "optimizer did not converge: {}", fit.message);
        assert!(
            (fit.params[0] - a_true).abs() < 0.2,
            "a_est = {}, expected ~{a_true}",
            fit.params[0]
        );
        assert!(
            (fit.params[1] - b_true).abs() < 0.2,
            "b_est = {}, expected ~{b_true}",
            fit.params[1]
        );
    }
}