1use std::sync::Arc;
31use std::time::Instant;
32
33use numra_core::Scalar;
34use numra_ode::{DoPri5, OdeProblem, Solver, SolverOptions};
35use numra_optim::OptimProblem;
36
37use crate::error::OcpError;
38
/// System dynamics `f(t, x, dxdt, u)`: writes the state derivative for state
/// `x` under control `u` at time `t` into `dxdt`.
type DynamicsFn<S> = dyn Fn(S, &[S], &mut [S], &[S]) + Send + Sync;

/// Terminal (Mayer) cost evaluated on the final state.
type TerminalCostFn<S> = dyn Fn(&[S]) -> S + Send + Sync;

/// Running (Lagrange) cost integrand `L(t, x, u)`.
type RunningCostFn<S> = dyn Fn(S, &[S], &[S]) -> S + Send + Sync;

/// Terminal constraint residuals evaluated on the final state; each returned
/// component is used as a separate equality constraint.
type TerminalConstraintFn<S> = dyn Fn(&[S]) -> Vec<S> + Send + Sync;
54
55#[derive(Clone, Debug)]
57pub struct MultipleShootingResult<S: Scalar> {
58 pub controls: Vec<S>,
60 pub final_state: Vec<S>,
62 pub objective: S,
64 pub converged: bool,
66 pub message: String,
68 pub iterations: usize,
70 pub wall_time_secs: f64,
72 pub t_trajectory: Vec<S>,
74 pub y_trajectory: Vec<S>,
76 pub n_states: usize,
78}
79
80#[derive(Clone, Copy)]
86struct MsLayout<S: Scalar> {
87 nx: usize,
88 nu: usize,
89 n_seg: usize,
90 x_offset: usize, u_offset: usize, t0: S,
93 dt: S,
94 ode_rtol: S,
95 ode_atol: S,
96}
97
98impl<S: Scalar> MsLayout<S> {
99 fn n_decision(&self) -> usize {
100 self.u_offset + self.n_seg * self.nu
101 }
102
103 fn x_start(&self, k: usize) -> usize {
104 self.x_offset + k * self.nx
105 }
106
107 fn u_start(&self, k: usize) -> usize {
108 self.u_offset + k * self.nu
109 }
110
111 fn t_start(&self, k: usize) -> S {
112 self.t0 + S::from_usize(k) * self.dt
113 }
114
115 fn t_end(&self, k: usize) -> S {
116 self.t0 + S::from_usize(k + 1) * self.dt
117 }
118}
119
120fn integrate_segment<S: Scalar>(
122 dynamics: &Arc<Box<DynamicsFn<S>>>,
123 x_k: &[S],
124 u_k: &[S],
125 t_start: S,
126 t_end: S,
127 rtol: S,
128 atol: S,
129) -> Result<Vec<S>, String> {
130 let dyn_ref = Arc::clone(dynamics);
131 let u_seg = u_k.to_vec();
132 let rhs = move |t: S, y: &[S], dydt: &mut [S]| {
133 dyn_ref(t, y, dydt, &u_seg);
134 };
135 let opts = SolverOptions::default().rtol(rtol).atol(atol);
136 let problem = OdeProblem::new(rhs, t_start, t_end, x_k.to_vec());
137 let result = DoPri5::solve(&problem, t_start, t_end, x_k, &opts).map_err(|e| e.to_string())?;
138 if !result.success {
139 return Err(result.message.clone());
140 }
141 result.y_final().ok_or_else(|| "empty result".into())
142}
143
144#[allow(clippy::too_many_arguments)]
146fn integrate_segment_full<S: Scalar>(
147 dynamics: &Arc<Box<DynamicsFn<S>>>,
148 x_k: &[S],
149 u_k: &[S],
150 t_start: S,
151 t_end: S,
152 rtol: S,
153 atol: S,
154 running_cost: Option<&RunningCostFn<S>>,
155 nx: usize,
156) -> Result<(Vec<S>, Vec<S>, S), String> {
157 let dyn_ref = Arc::clone(dynamics);
158 let u_seg = u_k.to_vec();
159 let u_for_cost = u_k.to_vec();
160 let rhs = move |t: S, y: &[S], dydt: &mut [S]| {
161 dyn_ref(t, y, dydt, &u_seg);
162 };
163 let opts = SolverOptions::default().rtol(rtol).atol(atol);
164 let problem = OdeProblem::new(rhs, t_start, t_end, x_k.to_vec());
165 let result = DoPri5::solve(&problem, t_start, t_end, x_k, &opts).map_err(|e| e.to_string())?;
166 if !result.success {
167 return Err(result.message);
168 }
169
170 let mut cost = S::ZERO;
171 if let Some(rc) = running_cost {
172 let n_pts = result.t.len();
173 for i in 0..n_pts.saturating_sub(1) {
174 let ti = result.t[i];
175 let ti1 = result.t[i + 1];
176 let yi = &result.y[i * nx..(i + 1) * nx];
177 let yi1 = &result.y[(i + 1) * nx..(i + 2) * nx];
178 let li = rc(ti, yi, &u_for_cost);
179 let li1 = rc(ti1, yi1, &u_for_cost);
180 cost += S::HALF * (ti1 - ti) * (li + li1);
181 }
182 }
183
184 Ok((result.t, result.y, cost))
185}
186
187pub struct MultipleShootingProblem<S: Scalar> {
193 n_states: usize,
194 n_controls: usize,
195 dynamics: Option<Box<DynamicsFn<S>>>,
196 y0: Option<Vec<S>>,
197 t0: S,
198 tf: S,
199 n_segments: usize,
200 control_bounds: Vec<Option<(S, S)>>,
201 terminal_cost: Option<Box<TerminalCostFn<S>>>,
202 running_cost: Option<Box<RunningCostFn<S>>>,
203 terminal_constraints: Option<Box<TerminalConstraintFn<S>>>,
204 ode_rtol: S,
205 ode_atol: S,
206 max_iter: usize,
207}
208
209impl<S: Scalar> MultipleShootingProblem<S> {
210 pub fn new(n_states: usize, n_controls: usize) -> Self {
212 Self {
213 n_states,
214 n_controls,
215 dynamics: None,
216 y0: None,
217 t0: S::ZERO,
218 tf: S::ONE,
219 n_segments: 10,
220 control_bounds: vec![None; n_controls],
221 terminal_cost: None,
222 running_cost: None,
223 terminal_constraints: None,
224 ode_rtol: S::from_f64(1e-8),
225 ode_atol: S::from_f64(1e-10),
226 max_iter: 200,
227 }
228 }
229
230 pub fn dynamics<F>(mut self, f: F) -> Self
232 where
233 F: Fn(S, &[S], &mut [S], &[S]) + Send + Sync + 'static,
234 {
235 self.dynamics = Some(Box::new(f));
236 self
237 }
238
239 pub fn initial_state(mut self, y0: Vec<S>) -> Self {
241 self.y0 = Some(y0);
242 self
243 }
244
245 pub fn time_span(mut self, t0: S, tf: S) -> Self {
247 self.t0 = t0;
248 self.tf = tf;
249 self
250 }
251
252 pub fn n_segments(mut self, n: usize) -> Self {
254 self.n_segments = n;
255 self
256 }
257
258 pub fn control_bounds(mut self, bounds: Vec<Option<(S, S)>>) -> Self {
260 self.control_bounds = bounds;
261 self
262 }
263
264 pub fn terminal_cost<F>(mut self, f: F) -> Self
266 where
267 F: Fn(&[S]) -> S + Send + Sync + 'static,
268 {
269 self.terminal_cost = Some(Box::new(f));
270 self
271 }
272
273 pub fn running_cost<F>(mut self, f: F) -> Self
275 where
276 F: Fn(S, &[S], &[S]) -> S + Send + Sync + 'static,
277 {
278 self.running_cost = Some(Box::new(f));
279 self
280 }
281
282 pub fn terminal_constraint<F>(mut self, f: F) -> Self
284 where
285 F: Fn(&[S]) -> Vec<S> + Send + Sync + 'static,
286 {
287 self.terminal_constraints = Some(Box::new(f));
288 self
289 }
290
291 pub fn ode_tolerances(mut self, rtol: S, atol: S) -> Self {
293 self.ode_rtol = rtol;
294 self.ode_atol = atol;
295 self
296 }
297
298 pub fn max_iter(mut self, n: usize) -> Self {
300 self.max_iter = n;
301 self
302 }
303
304 pub fn solve(self) -> Result<MultipleShootingResult<S>, OcpError>
310 where
311 S: faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
312 {
313 let start = Instant::now();
314
315 let dynamics = self.dynamics.ok_or(OcpError::NoDynamics)?;
317 let y0 = self.y0.ok_or(OcpError::NoInitialState)?;
318
319 if y0.len() != self.n_states {
320 return Err(OcpError::DimensionMismatch(format!(
321 "y0 length {} != n_states {}",
322 y0.len(),
323 self.n_states,
324 )));
325 }
326
327 if self.terminal_cost.is_none() && self.running_cost.is_none() {
328 return Err(OcpError::Other(
329 "at least one of terminal_cost or running_cost must be set".into(),
330 ));
331 }
332
333 let nx = self.n_states;
334 let nu = self.n_controls;
335 let n_seg = self.n_segments;
336 let dt = (self.tf - self.t0) / S::from_usize(n_seg);
337
338 let lay = MsLayout {
339 nx,
340 nu,
341 n_seg,
342 x_offset: 0,
343 u_offset: n_seg * nx,
344 t0: self.t0,
345 dt,
346 ode_rtol: self.ode_rtol,
347 ode_atol: self.ode_atol,
348 };
349
350 let n_decision = lay.n_decision();
351
352 let dynamics = Arc::new(dynamics);
355 let mut z0 = vec![S::ZERO; n_decision];
356
357 z0[..nx].copy_from_slice(&y0);
359
360 let mut x_cur = y0.clone();
362 for k in 1..n_seg {
363 match integrate_segment(
364 &dynamics,
365 &x_cur,
366 &vec![S::ZERO; nu],
367 lay.t_start(k - 1),
368 lay.t_end(k - 1),
369 lay.ode_rtol,
370 lay.ode_atol,
371 ) {
372 Ok(x_next) => {
373 z0[lay.x_start(k)..lay.x_start(k) + nx].copy_from_slice(&x_next);
374 x_cur = x_next;
375 }
376 Err(_) => {
377 z0[lay.x_start(k)..lay.x_start(k) + nx].copy_from_slice(&y0);
379 }
380 }
381 }
382
383 let terminal_cost: Option<Arc<Box<TerminalCostFn<S>>>> = self.terminal_cost.map(Arc::new);
387 let running_cost: Option<Arc<Box<RunningCostFn<S>>>> = self.running_cost.map(Arc::new);
388
389 let dyn_obj = Arc::clone(&dynamics);
391 let tc_obj = terminal_cost.clone();
392 let rc_obj = running_cost.clone();
393
394 let big = S::from_f64(1e20);
395 let objective_fn = move |z: &[S]| -> S {
396 let mut total_cost = S::ZERO;
397 let mut x_final = vec![S::ZERO; nx];
398
399 for k in 0..n_seg {
400 let x_k = &z[lay.x_start(k)..lay.x_start(k) + nx];
401 let u_k = &z[lay.u_start(k)..lay.u_start(k) + nu];
402 let rc_ref = rc_obj.as_ref().map(|b| &***b as &RunningCostFn<S>);
403
404 match integrate_segment_full(
405 &dyn_obj,
406 x_k,
407 u_k,
408 lay.t_start(k),
409 lay.t_end(k),
410 lay.ode_rtol,
411 lay.ode_atol,
412 rc_ref,
413 nx,
414 ) {
415 Ok((_t, y, seg_cost)) => {
416 total_cost += seg_cost;
417 let n_pts = _t.len();
418 if n_pts > 0 {
419 x_final.copy_from_slice(&y[(n_pts - 1) * nx..n_pts * nx]);
420 }
421 }
422 Err(_) => return big,
423 }
424 }
425
426 if let Some(ref tc) = tc_obj {
427 total_cost += tc(&x_final);
428 }
429
430 total_cost
431 };
432
433 let mut prob = OptimProblem::new(n_decision)
435 .x0(&z0)
436 .objective(objective_fn)
437 .max_iter(self.max_iter);
438
439 for j in 0..nx {
441 let y0_j = y0[j];
442 prob = prob.constraint_eq(move |z: &[S]| -> S { z[j] - y0_j });
443 }
444
445 let big_c = S::from_f64(1e10);
448 for k in 0..(n_seg - 1) {
449 for j in 0..nx {
450 let dyn_c = Arc::clone(&dynamics);
451 prob = prob.constraint_eq(move |z: &[S]| -> S {
452 let x_k = &z[lay.x_start(k)..lay.x_start(k) + nx];
453 let u_k = &z[lay.u_start(k)..lay.u_start(k) + nu];
454
455 match integrate_segment(
456 &dyn_c,
457 x_k,
458 u_k,
459 lay.t_start(k),
460 lay.t_end(k),
461 lay.ode_rtol,
462 lay.ode_atol,
463 ) {
464 Ok(x_end) => x_end[j] - z[lay.x_start(k + 1) + j],
465 Err(_) => big_c,
466 }
467 });
468 }
469 }
470
471 if let Some(tc_fn) = self.terminal_constraints {
473 let tc_fn = Arc::new(tc_fn);
474 let dummy = vec![S::ZERO; nx];
475 let n_tc = tc_fn(&dummy).len();
476
477 for ci in 0..n_tc {
478 let tc_c = Arc::clone(&tc_fn);
479 let dyn_c = Arc::clone(&dynamics);
480
481 prob = prob.constraint_eq(move |z: &[S]| -> S {
482 let k = n_seg - 1;
484 let x_k = &z[lay.x_start(k)..lay.x_start(k) + nx];
485 let u_k = &z[lay.u_start(k)..lay.u_start(k) + nu];
486
487 match integrate_segment(
488 &dyn_c,
489 x_k,
490 u_k,
491 lay.t_start(k),
492 lay.t_end(k),
493 lay.ode_rtol,
494 lay.ode_atol,
495 ) {
496 Ok(x_final) => tc_c(&x_final)[ci],
497 Err(_) => big_c,
498 }
499 });
500 }
501 }
502
503 for seg in 0..n_seg {
505 for ctrl in 0..nu {
506 if let Some(&Some((lo, hi))) = self.control_bounds.get(ctrl) {
507 prob = prob.bounds(lay.u_start(seg) + ctrl, (lo, hi));
508 }
509 }
510 }
511
512 let optim_result = prob.solve().map_err(OcpError::OptimFailed)?;
514 let z_opt = &optim_result.x;
515
516 let mut traj_t: Vec<S> = Vec::new();
518 let mut traj_y: Vec<S> = Vec::new();
519 let mut x_final = y0.clone();
520 let mut total_obj = S::ZERO;
521
522 let rc_ref = running_cost.as_ref().map(|b| &***b as &RunningCostFn<S>);
523
524 for k in 0..n_seg {
525 let x_k = &z_opt[lay.x_start(k)..lay.x_start(k) + nx];
526 let u_k = &z_opt[lay.u_start(k)..lay.u_start(k) + nu];
527
528 let (seg_t, seg_y, seg_cost) = integrate_segment_full(
529 &dynamics,
530 x_k,
531 u_k,
532 lay.t_start(k),
533 lay.t_end(k),
534 lay.ode_rtol,
535 lay.ode_atol,
536 rc_ref,
537 nx,
538 )
539 .map_err(OcpError::IntegrationFailed)?;
540
541 total_obj += seg_cost;
542
543 let skip = if k == 0 { 0 } else { 1 };
545 for i in skip..seg_t.len() {
546 traj_t.push(seg_t[i]);
547 traj_y.extend_from_slice(&seg_y[i * nx..(i + 1) * nx]);
548 }
549
550 let n_pts = seg_t.len();
551 if n_pts > 0 {
552 x_final = seg_y[(n_pts - 1) * nx..n_pts * nx].to_vec();
553 }
554 }
555
556 if let Some(ref tc) = terminal_cost {
557 total_obj += tc(&x_final);
558 }
559
560 let controls = z_opt[lay.u_offset..].to_vec();
562
563 Ok(MultipleShootingResult {
564 controls,
565 final_state: x_final,
566 objective: total_obj,
567 converged: optim_result.converged,
568 message: optim_result.message.clone(),
569 iterations: optim_result.iterations,
570 wall_time_secs: start.elapsed().as_secs_f64(),
571 t_trajectory: traj_t,
572 y_trajectory: traj_y,
573 n_states: nx,
574 })
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    /// Double integrator `x'' = u` driven from rest towards `x = 1, v = 0`,
    /// using a heavy terminal penalty and a light control-effort cost.
    #[test]
    fn test_double_integrator() {
        let x_final = MultipleShootingProblem::new(2, 1)
            .dynamics(|_t, y, dydt, u| {
                dydt[0] = y[1];
                dydt[1] = u[0];
            })
            .initial_state(vec![0.0, 0.0])
            .time_span(0.0, 2.0)
            .n_segments(10)
            .terminal_cost(|y| 100.0 * ((y[0] - 1.0).powi(2) + y[1].powi(2)))
            .running_cost(|_t, _y, u| 0.01 * u[0].powi(2))
            .max_iter(200)
            .solve()
            .expect("multiple shooting solve failed")
            .final_state[0];

        assert!(
            (x_final - 1.0).abs() < 0.3,
            "x(T) = {x_final}, expected within 0.3 of 1.0"
        );
    }

    /// Single integrator `x' = u` reaching `x = 1` with a quadratic control cost.
    #[test]
    fn test_minimum_energy() {
        let x_final = MultipleShootingProblem::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .time_span(0.0, 1.0)
            .n_segments(10)
            .terminal_cost(|y| 1000.0 * (y[0] - 1.0).powi(2))
            .running_cost(|_t, _y, u| u[0].powi(2))
            .max_iter(200)
            .solve()
            .expect("multiple shooting solve failed")
            .final_state[0];

        assert!(
            (x_final - 1.0).abs() < 0.3,
            "x(T) = {x_final}, expected within 0.3 of 1.0"
        );
    }

    /// Multiple and single shooting should land on comparable final states
    /// for the same simple problem.
    #[test]
    fn test_vs_single_shooting() {
        let ms_result = MultipleShootingProblem::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .time_span(0.0, 1.0)
            .n_segments(5)
            .terminal_cost(|y| (y[0] - 2.0).powi(2))
            .max_iter(100)
            .solve()
            .expect("multiple shooting solve failed");

        let ss_result = crate::ShootingProblem::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .time_span(0.0, 1.0)
            .n_segments(5)
            .terminal_cost(|y| (y[0] - 2.0).powi(2))
            .max_iter(100)
            .solve()
            .expect("single shooting solve failed");

        let ms_x = ms_result.final_state[0];
        let ss_x = ss_result.final_state[0];
        assert!(
            (ms_x - ss_x).abs() < 0.5,
            "MS x(T) = {}, SS x(T) = {}",
            ms_x,
            ss_x,
        );
    }

    /// The reported trajectory covers `[t0, tf]` and its state array matches
    /// the time array in length.
    #[test]
    fn test_trajectory_structure() {
        let result = MultipleShootingProblem::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .time_span(0.0, 1.0)
            .n_segments(5)
            .terminal_cost(|y| y[0].powi(2))
            .max_iter(50)
            .solve()
            .expect("multiple shooting solve failed");

        let times = &result.t_trajectory;
        let states = &result.y_trajectory;

        assert!(!times.is_empty(), "t_trajectory should be non-empty");
        assert!(!states.is_empty(), "y_trajectory should be non-empty");

        assert!((times[0] - 0.0).abs() < 1e-12, "first time should be 0.0");

        let t_last = times[times.len() - 1];
        assert!(
            (t_last - 1.0).abs() < 1e-6,
            "last time should be ~1.0, got {t_last}"
        );

        assert_eq!(states.len(), times.len() * result.n_states);
    }
}