Skip to main content

numra_ocp/
shooting.rs

1//! Single-shooting optimal control.
2//!
3//! Given a controlled ODE `dy/dt = f(t, y, u)`, an initial state `y(t0)`,
4//! and a cost functional (terminal and/or running cost), find the piecewise-
5//! constant control sequence `u_0, u_1, ..., u_{N-1}` that minimizes the
6//! total cost subject to optional terminal equality constraints.
7//!
8//! The control vector is parameterised as `N` segments of `n_controls` each,
9//! yielding `n_decision = n_controls * n_segments` decision variables.
10//!
11//! Author: Moussa Leblouba
12//! Date: 9 February 2026
13//! Modified: 2 May 2026
14
15use std::sync::Arc;
16use std::time::Instant;
17
18use numra_core::Scalar;
19use numra_ode::{DoPri5, OdeProblem, Solver, SolverOptions};
20use numra_optim::OptimProblem;
21
22use crate::error::OcpError;
23
24// ---------------------------------------------------------------------------
25// Types
26// ---------------------------------------------------------------------------
27
/// Controlled right-hand side, invoked as `f(t, y, dydt, u)`.
///
/// Receives the time `t`, the current state `y`, and the control `u` that is
/// active on the current segment, and writes the state derivative into `dydt`.
type DynamicsFn<S> = dyn Fn(S, &[S], &mut [S], &[S]) + Send + Sync;

/// Terminal cost `phi(y(T))`, evaluated on the final state of the trajectory.
type TerminalCostFn<S> = dyn Fn(&[S]) -> S + Send + Sync;

/// Running cost `L(t, y, u)`, integrated over the horizon using the control
/// of the segment that contains `t`.
type RunningCostFn<S> = dyn Fn(S, &[S], &[S]) -> S + Send + Sync;

/// Terminal equality constraints `h(y(T))`; every returned component is
/// constrained to equal 0. The number of constraints is the length of the
/// returned vector.
type TerminalConstraintFn<S> = dyn Fn(&[S]) -> Vec<S> + Send + Sync;
39
/// Result of a single-shooting optimal control solve.
///
/// After the optimizer finishes, the dynamics are integrated once more at the
/// optimal controls. The trajectory fields and `objective` come from that
/// final simulation.
#[derive(Clone, Debug)]
pub struct ShootingResult<S: Scalar> {
    /// Optimal control vector, flat and segment-major: the control of segment
    /// `k` is `controls[k * n_controls..(k + 1) * n_controls]`.
    pub controls: Vec<S>,
    /// Final state `y(T)` at the optimum: the last row of `y_trajectory`, or
    /// the initial state if the trajectory is empty.
    pub final_state: Vec<S>,
    /// Objective value recomputed by re-simulating at `controls`.
    pub objective: S,
    /// Whether the optimizer reported convergence.
    pub converged: bool,
    /// Human-readable status message from the optimizer.
    pub message: String,
    /// Number of optimizer iterations.
    pub iterations: usize,
    /// Wall-clock time of the whole solve, in seconds.
    pub wall_time_secs: f64,
    /// Time grid of the reconstructed trajectory: the ODE output points of all
    /// segments concatenated, with duplicate segment-boundary points removed.
    pub t_trajectory: Vec<S>,
    /// State trajectory (flat row-major: `y[i * n_states + j]`).
    pub y_trajectory: Vec<S>,
    /// Number of states (useful for interpreting `y_trajectory`).
    pub n_states: usize,
}
64
65// ---------------------------------------------------------------------------
66// Builder
67// ---------------------------------------------------------------------------
68
/// Builder for single-shooting optimal control problems.
///
/// Configure the problem with the chained setter methods, then call `solve`.
pub struct ShootingProblem<S: Scalar> {
    /// Dimension of the ODE state vector.
    n_states: usize,
    /// Number of control components per segment.
    n_controls: usize,
    /// Controlled right-hand side; `solve` fails without it.
    dynamics: Option<Box<DynamicsFn<S>>>,
    /// Initial state `y(t0)`; `solve` fails without it.
    y0: Option<Vec<S>>,
    /// Start of the time horizon.
    t0: S,
    /// End of the time horizon.
    tf: S,
    /// Number of piecewise-constant control intervals on `[t0, tf]`.
    n_segments: usize,
    /// Per-control-component `(lo, hi)` bounds, applied to every segment.
    control_bounds: Vec<Option<(S, S)>>,
    /// Optional terminal cost `phi(y(T))`.
    terminal_cost: Option<Box<TerminalCostFn<S>>>,
    /// Optional running cost `L(t, y, u)`.
    running_cost: Option<Box<RunningCostFn<S>>>,
    /// Optional terminal equality constraints `h(y(T)) = 0`.
    terminal_constraints: Option<Box<TerminalConstraintFn<S>>>,
    /// Relative tolerance passed to the ODE solver.
    ode_rtol: S,
    /// Absolute tolerance passed to the ODE solver.
    ode_atol: S,
    /// Maximum number of optimizer iterations.
    max_iter: usize,
}
86
87impl<S: Scalar> ShootingProblem<S> {
88    /// Create a new shooting problem.
89    ///
90    /// - `n_states`: dimension of the ODE state vector.
91    /// - `n_controls`: dimension of the control vector per segment.
92    pub fn new(n_states: usize, n_controls: usize) -> Self {
93        Self {
94            n_states,
95            n_controls,
96            dynamics: None,
97            y0: None,
98            t0: S::ZERO,
99            tf: S::ONE,
100            n_segments: 10,
101            control_bounds: vec![None; n_controls],
102            terminal_cost: None,
103            running_cost: None,
104            terminal_constraints: None,
105            ode_rtol: S::from_f64(1e-8),
106            ode_atol: S::from_f64(1e-10),
107            max_iter: 200,
108        }
109    }
110
111    /// Set the controlled ODE right-hand side: `f(t, y, dydt, u)`.
112    pub fn dynamics<F>(mut self, f: F) -> Self
113    where
114        F: Fn(S, &[S], &mut [S], &[S]) + Send + Sync + 'static,
115    {
116        self.dynamics = Some(Box::new(f));
117        self
118    }
119
120    /// Set the initial state `y(t0)`.
121    pub fn initial_state(mut self, y0: Vec<S>) -> Self {
122        self.y0 = Some(y0);
123        self
124    }
125
126    /// Set the time interval `[t0, tf]`.
127    pub fn time_span(mut self, t0: S, tf: S) -> Self {
128        self.t0 = t0;
129        self.tf = tf;
130        self
131    }
132
133    /// Set the number of control intervals.
134    pub fn n_segments(mut self, n: usize) -> Self {
135        self.n_segments = n;
136        self
137    }
138
139    /// Set bounds for each control variable (applied to every segment).
140    pub fn control_bounds(mut self, bounds: Vec<Option<(S, S)>>) -> Self {
141        self.control_bounds = bounds;
142        self
143    }
144
145    /// Set the terminal cost `phi(y(T))`.
146    pub fn terminal_cost<F>(mut self, f: F) -> Self
147    where
148        F: Fn(&[S]) -> S + Send + Sync + 'static,
149    {
150        self.terminal_cost = Some(Box::new(f));
151        self
152    }
153
154    /// Set the running cost `L(t, y, u)`.
155    pub fn running_cost<F>(mut self, f: F) -> Self
156    where
157        F: Fn(S, &[S], &[S]) -> S + Send + Sync + 'static,
158    {
159        self.running_cost = Some(Box::new(f));
160        self
161    }
162
163    /// Set terminal equality constraints `h(y(T)) = 0`.
164    pub fn terminal_constraint<F>(mut self, f: F) -> Self
165    where
166        F: Fn(&[S]) -> Vec<S> + Send + Sync + 'static,
167    {
168        self.terminal_constraints = Some(Box::new(f));
169        self
170    }
171
172    /// Set ODE solver tolerances.
173    pub fn ode_tolerances(mut self, rtol: S, atol: S) -> Self {
174        self.ode_rtol = rtol;
175        self.ode_atol = atol;
176        self
177    }
178
179    /// Set maximum optimizer iterations.
180    pub fn max_iter(mut self, n: usize) -> Self {
181        self.max_iter = n;
182        self
183    }
184
185    // -----------------------------------------------------------------------
186    // Solve
187    // -----------------------------------------------------------------------
188
189    /// Execute the single-shooting optimal control solve.
190    pub fn solve(self) -> Result<ShootingResult<S>, OcpError>
191    where
192        S: faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
193    {
194        let start = Instant::now();
195
196        // -- Validate inputs ------------------------------------------------
197        let dynamics = self.dynamics.ok_or(OcpError::NoDynamics)?;
198        let y0 = self.y0.ok_or(OcpError::NoInitialState)?;
199
200        if y0.len() != self.n_states {
201            return Err(OcpError::DimensionMismatch(format!(
202                "y0 length {} != n_states {}",
203                y0.len(),
204                self.n_states,
205            )));
206        }
207
208        if self.terminal_cost.is_none() && self.running_cost.is_none() {
209            return Err(OcpError::Other(
210                "at least one of terminal_cost or running_cost must be set".into(),
211            ));
212        }
213
214        let n_states = self.n_states;
215        let n_controls = self.n_controls;
216        let n_segments = self.n_segments;
217        let n_decision = n_controls * n_segments;
218        let t0 = self.t0;
219        let tf = self.tf;
220        let dt = (tf - t0) / S::from_usize(n_segments);
221        let ode_rtol = self.ode_rtol;
222        let ode_atol = self.ode_atol;
223
224        // -- Shared state ---------------------------------------------------
225        let dynamics = Arc::new(dynamics);
226        let y0 = Arc::new(y0);
227        let terminal_cost: Option<Arc<Box<TerminalCostFn<S>>>> = self.terminal_cost.map(Arc::new);
228        let running_cost: Option<Arc<Box<RunningCostFn<S>>>> = self.running_cost.map(Arc::new);
229
230        let params = SimParams {
231            n_states,
232            n_controls,
233            n_segments,
234            t0,
235            dt,
236            ode_rtol,
237            ode_atol,
238        };
239
240        // -- Build objective ------------------------------------------------
241        let dyn_obj = Arc::clone(&dynamics);
242        let y0_obj = Arc::clone(&y0);
243        let tc_obj = terminal_cost.clone();
244        let rc_obj = running_cost.clone();
245        let p_obj = params;
246
247        let big = S::from_f64(1e20);
248        let objective_fn = move |u: &[S]| -> S {
249            let rc_ref = rc_obj.as_ref().map(|b| &***b as &RunningCostFn<S>);
250            let tc_ref = tc_obj.as_ref().map(|b| &***b as &TerminalCostFn<S>);
251            match simulate(&dyn_obj, &y0_obj, u, &p_obj, rc_ref, tc_ref) {
252                Ok((_traj_t, _traj_y, cost)) => cost,
253                Err(_) => big,
254            }
255        };
256
257        // -- Build OptimProblem ---------------------------------------------
258        let u0 = vec![S::ZERO; n_decision];
259        let mut prob = OptimProblem::new(n_decision)
260            .x0(&u0)
261            .objective(objective_fn)
262            .max_iter(self.max_iter);
263
264        // Apply control bounds to every segment.
265        for seg in 0..n_segments {
266            for ctrl in 0..n_controls {
267                if let Some(&Some((lo, hi))) = self.control_bounds.get(ctrl) {
268                    prob = prob.bounds(seg * n_controls + ctrl, (lo, hi));
269                }
270            }
271        }
272
273        // -- Terminal constraints -------------------------------------------
274        if let Some(tc_fn) = self.terminal_constraints {
275            let tc_fn = Arc::new(tc_fn);
276
277            // Probe to determine number of constraints.
278            let dummy = vec![S::ZERO; n_states];
279            let n_constraints = tc_fn(&dummy).len();
280
281            let big_c = S::from_f64(1e20);
282            for ci in 0..n_constraints {
283                let dyn_c = Arc::clone(&dynamics);
284                let y0_c = Arc::clone(&y0);
285                let tc_c = Arc::clone(&tc_fn);
286                let p_c = params;
287
288                prob = prob.constraint_eq(move |u: &[S]| -> S {
289                    match simulate_final_state(&dyn_c, &y0_c, u, &p_c) {
290                        Ok(y_final) => tc_c(&y_final)[ci],
291                        Err(_) => big_c,
292                    }
293                });
294            }
295        }
296
297        // -- Solve ----------------------------------------------------------
298        let optim_result = prob.solve().map_err(OcpError::OptimFailed)?;
299
300        // -- Reconstruct trajectory at optimal controls ---------------------
301        let optimal_u = &optim_result.x;
302        let rc_final = running_cost.as_ref().map(|b| &***b as &RunningCostFn<S>);
303        let tc_final = terminal_cost.as_ref().map(|b| &***b as &TerminalCostFn<S>);
304        let (traj_t, traj_y, obj) =
305            simulate(&dynamics, &y0, optimal_u, &params, rc_final, tc_final)
306                .map_err(OcpError::IntegrationFailed)?;
307
308        let final_state = if traj_t.is_empty() {
309            y0.as_ref().clone()
310        } else {
311            let last_idx = traj_t.len() - 1;
312            traj_y[last_idx * n_states..(last_idx + 1) * n_states].to_vec()
313        };
314
315        Ok(ShootingResult {
316            controls: optimal_u.clone(),
317            final_state,
318            objective: obj,
319            converged: optim_result.converged,
320            message: optim_result.message.clone(),
321            iterations: optim_result.iterations,
322            wall_time_secs: start.elapsed().as_secs_f64(),
323            t_trajectory: traj_t,
324            y_trajectory: traj_y,
325            n_states,
326        })
327    }
328}
329
330// ---------------------------------------------------------------------------
331// Internal helpers
332// ---------------------------------------------------------------------------
333
/// Simulation parameters shared across closures.
///
/// `solve` copies this value into each optimizer closure, which is why it is
/// `Copy`.
#[derive(Clone, Copy)]
struct SimParams<S: Scalar> {
    /// Dimension of the state vector.
    n_states: usize,
    /// Number of control components per segment.
    n_controls: usize,
    /// Number of piecewise-constant control segments.
    n_segments: usize,
    /// Start time of the horizon.
    t0: S,
    /// Segment length `(tf - t0) / n_segments`; segment `k` spans
    /// `[t0 + k * dt, t0 + (k + 1) * dt]`.
    dt: S,
    /// Relative tolerance passed to the ODE solver.
    ode_rtol: S,
    /// Absolute tolerance passed to the ODE solver.
    ode_atol: S,
}
345
346/// Simulate the full trajectory under piecewise-constant controls and
347/// return `(t_grid, y_flat, total_cost)`.
348///
349/// `t_grid` and `y_flat` concatenate the ODE output of all segments.
350/// Duplicate boundary points between consecutive segments are removed
351/// (only the first segment keeps its initial point; subsequent segments
352/// skip the duplicated initial state).
353fn simulate<S: Scalar>(
354    dynamics: &Arc<Box<DynamicsFn<S>>>,
355    y0: &Arc<Vec<S>>,
356    u: &[S],
357    p: &SimParams<S>,
358    running_cost: Option<&RunningCostFn<S>>,
359    terminal_cost: Option<&TerminalCostFn<S>>,
360) -> Result<(Vec<S>, Vec<S>, S), String> {
361    let options = SolverOptions::default().rtol(p.ode_rtol).atol(p.ode_atol);
362
363    let mut traj_t: Vec<S> = Vec::new();
364    let mut traj_y: Vec<S> = Vec::new();
365    let mut y_cur = y0.as_ref().clone();
366    let mut total_cost = S::ZERO;
367
368    for seg in 0..p.n_segments {
369        let t_start = p.t0 + S::from_usize(seg) * p.dt;
370        let t_end = p.t0 + S::from_usize(seg + 1) * p.dt;
371        let u_seg: Vec<S> = u[seg * p.n_controls..(seg + 1) * p.n_controls].to_vec();
372
373        // Build ODE RHS with this segment's control baked in.
374        let dyn_ref = Arc::clone(dynamics);
375        let u_seg_clone = u_seg.clone();
376        let rhs = move |t: S, y: &[S], dydt: &mut [S]| {
377            dyn_ref(t, y, dydt, &u_seg_clone);
378        };
379
380        let problem = OdeProblem::new(rhs, t_start, t_end, y_cur.clone());
381        let result = DoPri5::solve(&problem, t_start, t_end, &y_cur, &options)
382            .map_err(|e| format!("segment {seg}: {e}"))?;
383
384        if !result.success {
385            return Err(format!("segment {seg}: {}", result.message));
386        }
387
388        // Accumulate running cost via trapezoidal rule.
389        if let Some(rc) = running_cost {
390            let n_pts = result.t.len();
391            for k in 0..n_pts.saturating_sub(1) {
392                let tk = result.t[k];
393                let tk1 = result.t[k + 1];
394                let yk = &result.y[k * p.n_states..(k + 1) * p.n_states];
395                let yk1 = &result.y[(k + 1) * p.n_states..(k + 2) * p.n_states];
396                let lk = rc(tk, yk, &u_seg);
397                let lk1 = rc(tk1, yk1, &u_seg);
398                total_cost += S::HALF * (tk1 - tk) * (lk + lk1);
399            }
400        }
401
402        // Append trajectory, skipping the first point for segments > 0
403        // to avoid duplicating the boundary.
404        let skip = if seg == 0 { 0 } else { 1 };
405        for k in skip..result.t.len() {
406            traj_t.push(result.t[k]);
407            traj_y.extend_from_slice(&result.y[k * p.n_states..(k + 1) * p.n_states]);
408        }
409
410        // Chain: final state of this segment is initial state of the next.
411        y_cur = result
412            .y_final()
413            .ok_or_else(|| format!("segment {seg}: empty result"))?;
414    }
415
416    // Add terminal cost.
417    if let Some(tc) = terminal_cost {
418        total_cost += tc(&y_cur);
419    }
420
421    Ok((traj_t, traj_y, total_cost))
422}
423
424/// Simulate only to obtain the final state `y(T)`.
425fn simulate_final_state<S: Scalar>(
426    dynamics: &Arc<Box<DynamicsFn<S>>>,
427    y0: &Arc<Vec<S>>,
428    u: &[S],
429    p: &SimParams<S>,
430) -> Result<Vec<S>, String> {
431    let options = SolverOptions::default().rtol(p.ode_rtol).atol(p.ode_atol);
432    let mut y_cur = y0.as_ref().clone();
433
434    for seg in 0..p.n_segments {
435        let t_start = p.t0 + S::from_usize(seg) * p.dt;
436        let t_end = p.t0 + S::from_usize(seg + 1) * p.dt;
437        let u_seg: Vec<S> = u[seg * p.n_controls..(seg + 1) * p.n_controls].to_vec();
438
439        let dyn_ref = Arc::clone(dynamics);
440        let rhs = move |t: S, y: &[S], dydt: &mut [S]| {
441            dyn_ref(t, y, dydt, &u_seg);
442        };
443
444        let problem = OdeProblem::new(rhs, t_start, t_end, y_cur.clone());
445        let result = DoPri5::solve(&problem, t_start, t_end, &y_cur, &options)
446            .map_err(|e| format!("segment {seg}: {e}"))?;
447
448        if !result.success {
449            return Err(format!("segment {seg}: {}", result.message));
450        }
451
452        y_cur = result
453            .y_final()
454            .ok_or_else(|| format!("segment {seg}: empty result"))?;
455    }
456
457    Ok(y_cur)
458}
459
460// ---------------------------------------------------------------------------
461// Tests
462// ---------------------------------------------------------------------------
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Double integrator: dx/dt = v, dv/dt = u.
    /// Terminal cost: 100*((x-1)^2 + v^2).
    /// Running cost: 0.01*u^2.
    /// T = 2.0, 10 segments.
    #[test]
    fn test_double_integrator_terminal_cost() {
        let result = ShootingProblem::new(2, 1)
            .dynamics(|_t, y, dydt, u| {
                dydt[0] = y[1]; // dx/dt = v
                dydt[1] = u[0]; // dv/dt = u
            })
            .initial_state(vec![0.0, 0.0])
            .time_span(0.0, 2.0)
            .n_segments(10)
            .terminal_cost(|y| 100.0 * ((y[0] - 1.0).powi(2) + y[1].powi(2)))
            .running_cost(|_t, _y, u| 0.01 * u[0].powi(2))
            .max_iter(200)
            .solve()
            .expect("shooting solve failed");

        let x_final = result.final_state[0];
        assert!(
            (x_final - 1.0).abs() < 0.3,
            "x(T) = {x_final}, expected within 0.3 of 1.0"
        );
    }

    /// Minimum-energy control: 1-state dx/dt = u.
    /// Terminal cost: 1000*(x-1)^2.
    /// Running cost: u^2.
    /// T = 1.0, 10 segments.
    #[test]
    fn test_minimum_energy_control() {
        let result = ShootingProblem::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .time_span(0.0, 1.0)
            .n_segments(10)
            .terminal_cost(|y| 1000.0 * (y[0] - 1.0).powi(2))
            .running_cost(|_t, _y, u| u[0].powi(2))
            .max_iter(200)
            .solve()
            .expect("shooting solve failed");

        let x_final = result.final_state[0];
        assert!(
            (x_final - 1.0).abs() < 0.3,
            "x(T) = {x_final}, expected within 0.3 of 1.0"
        );
    }

    /// Pure terminal cost (no running cost).
    /// 1-state dx/dt = u, terminal cost: (x-3)^2.
    /// T = 1.0, 5 segments.
    #[test]
    fn test_pure_terminal_cost() {
        let result = ShootingProblem::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .time_span(0.0, 1.0)
            .n_segments(5)
            .terminal_cost(|y| (y[0] - 3.0).powi(2))
            .max_iter(200)
            .solve()
            .expect("shooting solve failed");

        let x_final = result.final_state[0];
        assert!(
            (x_final - 3.0).abs() < 0.5,
            "x(T) = {x_final}, expected within 0.5 of 3.0"
        );
    }

    /// Trajectory output structure check.
    /// 1-state dx/dt = u, terminal cost: x^2, 5 segments.
    #[test]
    fn test_trajectory_output() {
        let result = ShootingProblem::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .time_span(0.0, 1.0)
            .n_segments(5)
            .terminal_cost(|y| y[0].powi(2))
            .max_iter(50)
            .solve()
            .expect("shooting solve failed");

        let n_points = result.t_trajectory.len();
        assert!(n_points > 0, "t_trajectory should be non-empty");
        assert_eq!(result.n_states, 1);

        // First time is t0.
        assert!(
            result.t_trajectory[0].abs() < 1e-12,
            "first time should be t0=0.0, got {}",
            result.t_trajectory[0],
        );

        // Last time is approximately tf.
        let t_last = result.t_trajectory[n_points - 1];
        assert!(
            (t_last - 1.0).abs() < 1e-6,
            "last time should be ~tf=1.0, got {t_last}"
        );

        // Times strictly increase: segment-boundary points are not duplicated.
        assert!(
            result.t_trajectory.windows(2).all(|w| w[0] < w[1]),
            "t_trajectory should be strictly increasing"
        );

        // One row of n_states values per time point.
        assert_eq!(
            result.y_trajectory.len(),
            n_points * result.n_states,
            "y_trajectory length mismatch"
        );

        // final_state is the last trajectory row.
        assert_eq!(result.final_state.len(), result.n_states);
        assert_eq!(
            &result.y_trajectory[(n_points - 1) * result.n_states..],
            &result.final_state[..],
            "final_state should equal the last trajectory row"
        );

        // One control value per (segment, control component).
        assert_eq!(result.controls.len(), 5);
    }

    /// Solving without dynamics reports `NoDynamics`.
    #[test]
    fn test_missing_dynamics_is_error() {
        let err = ShootingProblem::<f64>::new(1, 1)
            .initial_state(vec![0.0])
            .terminal_cost(|y| y[0].powi(2))
            .solve()
            .unwrap_err();
        assert!(matches!(err, OcpError::NoDynamics), "got {err:?}");
    }

    /// Solving without an initial state reports `NoInitialState`.
    #[test]
    fn test_missing_initial_state_is_error() {
        let err = ShootingProblem::<f64>::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .terminal_cost(|y| y[0].powi(2))
            .solve()
            .unwrap_err();
        assert!(matches!(err, OcpError::NoInitialState), "got {err:?}");
    }

    /// An initial state of the wrong length reports `DimensionMismatch`.
    #[test]
    fn test_initial_state_dimension_mismatch() {
        let err = ShootingProblem::<f64>::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0, 0.0])
            .terminal_cost(|y| y[0].powi(2))
            .solve()
            .unwrap_err();
        assert!(matches!(err, OcpError::DimensionMismatch(_)), "got {err:?}");
    }

    /// A problem with neither terminal nor running cost is rejected.
    #[test]
    fn test_missing_cost_is_error() {
        let err = ShootingProblem::<f64>::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .solve()
            .unwrap_err();
        assert!(matches!(err, OcpError::Other(_)), "got {err:?}");
    }
}