1use std::sync::Arc;
16use std::time::Instant;
17
18use numra_core::Scalar;
19use numra_ode::{DoPri5, OdeProblem, Solver, SolverOptions};
20use numra_optim::OptimProblem;
21
22use crate::error::OcpError;
23
/// System right-hand side `f(t, y, dydt, u)`: writes `dy/dt` into `dydt` for
/// time `t`, state `y` and control vector `u`.
type DynamicsFn<S> = dyn Fn(S, &[S], &mut [S], &[S]) + Sync + Send;

/// Cost evaluated on the state at the end of the horizon.
type TerminalCostFn<S> = dyn Fn(&[S]) -> S + Sync + Send;

/// Running-cost integrand `L(t, y, u)`.
type RunningCostFn<S> = dyn Fn(S, &[S], &[S]) -> S + Sync + Send;

/// Terminal constraint `c(y(tf))`; every returned component is driven to zero.
type TerminalConstraintFn<S> = dyn Fn(&[S]) -> Vec<S> + Sync + Send;
39
40#[derive(Clone, Debug)]
42pub struct ShootingResult<S: Scalar> {
43 pub controls: Vec<S>,
45 pub final_state: Vec<S>,
47 pub objective: S,
49 pub converged: bool,
51 pub message: String,
53 pub iterations: usize,
55 pub wall_time_secs: f64,
57 pub t_trajectory: Vec<S>,
59 pub y_trajectory: Vec<S>,
61 pub n_states: usize,
63}
64
65pub struct ShootingProblem<S: Scalar> {
71 n_states: usize,
72 n_controls: usize,
73 dynamics: Option<Box<DynamicsFn<S>>>,
74 y0: Option<Vec<S>>,
75 t0: S,
76 tf: S,
77 n_segments: usize,
78 control_bounds: Vec<Option<(S, S)>>,
79 terminal_cost: Option<Box<TerminalCostFn<S>>>,
80 running_cost: Option<Box<RunningCostFn<S>>>,
81 terminal_constraints: Option<Box<TerminalConstraintFn<S>>>,
82 ode_rtol: S,
83 ode_atol: S,
84 max_iter: usize,
85}
86
87impl<S: Scalar> ShootingProblem<S> {
88 pub fn new(n_states: usize, n_controls: usize) -> Self {
93 Self {
94 n_states,
95 n_controls,
96 dynamics: None,
97 y0: None,
98 t0: S::ZERO,
99 tf: S::ONE,
100 n_segments: 10,
101 control_bounds: vec![None; n_controls],
102 terminal_cost: None,
103 running_cost: None,
104 terminal_constraints: None,
105 ode_rtol: S::from_f64(1e-8),
106 ode_atol: S::from_f64(1e-10),
107 max_iter: 200,
108 }
109 }
110
111 pub fn dynamics<F>(mut self, f: F) -> Self
113 where
114 F: Fn(S, &[S], &mut [S], &[S]) + Send + Sync + 'static,
115 {
116 self.dynamics = Some(Box::new(f));
117 self
118 }
119
120 pub fn initial_state(mut self, y0: Vec<S>) -> Self {
122 self.y0 = Some(y0);
123 self
124 }
125
126 pub fn time_span(mut self, t0: S, tf: S) -> Self {
128 self.t0 = t0;
129 self.tf = tf;
130 self
131 }
132
133 pub fn n_segments(mut self, n: usize) -> Self {
135 self.n_segments = n;
136 self
137 }
138
139 pub fn control_bounds(mut self, bounds: Vec<Option<(S, S)>>) -> Self {
141 self.control_bounds = bounds;
142 self
143 }
144
145 pub fn terminal_cost<F>(mut self, f: F) -> Self
147 where
148 F: Fn(&[S]) -> S + Send + Sync + 'static,
149 {
150 self.terminal_cost = Some(Box::new(f));
151 self
152 }
153
154 pub fn running_cost<F>(mut self, f: F) -> Self
156 where
157 F: Fn(S, &[S], &[S]) -> S + Send + Sync + 'static,
158 {
159 self.running_cost = Some(Box::new(f));
160 self
161 }
162
163 pub fn terminal_constraint<F>(mut self, f: F) -> Self
165 where
166 F: Fn(&[S]) -> Vec<S> + Send + Sync + 'static,
167 {
168 self.terminal_constraints = Some(Box::new(f));
169 self
170 }
171
172 pub fn ode_tolerances(mut self, rtol: S, atol: S) -> Self {
174 self.ode_rtol = rtol;
175 self.ode_atol = atol;
176 self
177 }
178
179 pub fn max_iter(mut self, n: usize) -> Self {
181 self.max_iter = n;
182 self
183 }
184
185 pub fn solve(self) -> Result<ShootingResult<S>, OcpError>
191 where
192 S: faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
193 {
194 let start = Instant::now();
195
196 let dynamics = self.dynamics.ok_or(OcpError::NoDynamics)?;
198 let y0 = self.y0.ok_or(OcpError::NoInitialState)?;
199
200 if y0.len() != self.n_states {
201 return Err(OcpError::DimensionMismatch(format!(
202 "y0 length {} != n_states {}",
203 y0.len(),
204 self.n_states,
205 )));
206 }
207
208 if self.terminal_cost.is_none() && self.running_cost.is_none() {
209 return Err(OcpError::Other(
210 "at least one of terminal_cost or running_cost must be set".into(),
211 ));
212 }
213
214 let n_states = self.n_states;
215 let n_controls = self.n_controls;
216 let n_segments = self.n_segments;
217 let n_decision = n_controls * n_segments;
218 let t0 = self.t0;
219 let tf = self.tf;
220 let dt = (tf - t0) / S::from_usize(n_segments);
221 let ode_rtol = self.ode_rtol;
222 let ode_atol = self.ode_atol;
223
224 let dynamics = Arc::new(dynamics);
226 let y0 = Arc::new(y0);
227 let terminal_cost: Option<Arc<Box<TerminalCostFn<S>>>> = self.terminal_cost.map(Arc::new);
228 let running_cost: Option<Arc<Box<RunningCostFn<S>>>> = self.running_cost.map(Arc::new);
229
230 let params = SimParams {
231 n_states,
232 n_controls,
233 n_segments,
234 t0,
235 dt,
236 ode_rtol,
237 ode_atol,
238 };
239
240 let dyn_obj = Arc::clone(&dynamics);
242 let y0_obj = Arc::clone(&y0);
243 let tc_obj = terminal_cost.clone();
244 let rc_obj = running_cost.clone();
245 let p_obj = params;
246
247 let big = S::from_f64(1e20);
248 let objective_fn = move |u: &[S]| -> S {
249 let rc_ref = rc_obj.as_ref().map(|b| &***b as &RunningCostFn<S>);
250 let tc_ref = tc_obj.as_ref().map(|b| &***b as &TerminalCostFn<S>);
251 match simulate(&dyn_obj, &y0_obj, u, &p_obj, rc_ref, tc_ref) {
252 Ok((_traj_t, _traj_y, cost)) => cost,
253 Err(_) => big,
254 }
255 };
256
257 let u0 = vec![S::ZERO; n_decision];
259 let mut prob = OptimProblem::new(n_decision)
260 .x0(&u0)
261 .objective(objective_fn)
262 .max_iter(self.max_iter);
263
264 for seg in 0..n_segments {
266 for ctrl in 0..n_controls {
267 if let Some(&Some((lo, hi))) = self.control_bounds.get(ctrl) {
268 prob = prob.bounds(seg * n_controls + ctrl, (lo, hi));
269 }
270 }
271 }
272
273 if let Some(tc_fn) = self.terminal_constraints {
275 let tc_fn = Arc::new(tc_fn);
276
277 let dummy = vec![S::ZERO; n_states];
279 let n_constraints = tc_fn(&dummy).len();
280
281 let big_c = S::from_f64(1e20);
282 for ci in 0..n_constraints {
283 let dyn_c = Arc::clone(&dynamics);
284 let y0_c = Arc::clone(&y0);
285 let tc_c = Arc::clone(&tc_fn);
286 let p_c = params;
287
288 prob = prob.constraint_eq(move |u: &[S]| -> S {
289 match simulate_final_state(&dyn_c, &y0_c, u, &p_c) {
290 Ok(y_final) => tc_c(&y_final)[ci],
291 Err(_) => big_c,
292 }
293 });
294 }
295 }
296
297 let optim_result = prob.solve().map_err(OcpError::OptimFailed)?;
299
300 let optimal_u = &optim_result.x;
302 let rc_final = running_cost.as_ref().map(|b| &***b as &RunningCostFn<S>);
303 let tc_final = terminal_cost.as_ref().map(|b| &***b as &TerminalCostFn<S>);
304 let (traj_t, traj_y, obj) =
305 simulate(&dynamics, &y0, optimal_u, ¶ms, rc_final, tc_final)
306 .map_err(OcpError::IntegrationFailed)?;
307
308 let final_state = if traj_t.is_empty() {
309 y0.as_ref().clone()
310 } else {
311 let last_idx = traj_t.len() - 1;
312 traj_y[last_idx * n_states..(last_idx + 1) * n_states].to_vec()
313 };
314
315 Ok(ShootingResult {
316 controls: optimal_u.clone(),
317 final_state,
318 objective: obj,
319 converged: optim_result.converged,
320 message: optim_result.message.clone(),
321 iterations: optim_result.iterations,
322 wall_time_secs: start.elapsed().as_secs_f64(),
323 t_trajectory: traj_t,
324 y_trajectory: traj_y,
325 n_states,
326 })
327 }
328}
329
330#[derive(Clone, Copy)]
336struct SimParams<S: Scalar> {
337 n_states: usize,
338 n_controls: usize,
339 n_segments: usize,
340 t0: S,
341 dt: S,
342 ode_rtol: S,
343 ode_atol: S,
344}
345
346fn simulate<S: Scalar>(
354 dynamics: &Arc<Box<DynamicsFn<S>>>,
355 y0: &Arc<Vec<S>>,
356 u: &[S],
357 p: &SimParams<S>,
358 running_cost: Option<&RunningCostFn<S>>,
359 terminal_cost: Option<&TerminalCostFn<S>>,
360) -> Result<(Vec<S>, Vec<S>, S), String> {
361 let options = SolverOptions::default().rtol(p.ode_rtol).atol(p.ode_atol);
362
363 let mut traj_t: Vec<S> = Vec::new();
364 let mut traj_y: Vec<S> = Vec::new();
365 let mut y_cur = y0.as_ref().clone();
366 let mut total_cost = S::ZERO;
367
368 for seg in 0..p.n_segments {
369 let t_start = p.t0 + S::from_usize(seg) * p.dt;
370 let t_end = p.t0 + S::from_usize(seg + 1) * p.dt;
371 let u_seg: Vec<S> = u[seg * p.n_controls..(seg + 1) * p.n_controls].to_vec();
372
373 let dyn_ref = Arc::clone(dynamics);
375 let u_seg_clone = u_seg.clone();
376 let rhs = move |t: S, y: &[S], dydt: &mut [S]| {
377 dyn_ref(t, y, dydt, &u_seg_clone);
378 };
379
380 let problem = OdeProblem::new(rhs, t_start, t_end, y_cur.clone());
381 let result = DoPri5::solve(&problem, t_start, t_end, &y_cur, &options)
382 .map_err(|e| format!("segment {seg}: {e}"))?;
383
384 if !result.success {
385 return Err(format!("segment {seg}: {}", result.message));
386 }
387
388 if let Some(rc) = running_cost {
390 let n_pts = result.t.len();
391 for k in 0..n_pts.saturating_sub(1) {
392 let tk = result.t[k];
393 let tk1 = result.t[k + 1];
394 let yk = &result.y[k * p.n_states..(k + 1) * p.n_states];
395 let yk1 = &result.y[(k + 1) * p.n_states..(k + 2) * p.n_states];
396 let lk = rc(tk, yk, &u_seg);
397 let lk1 = rc(tk1, yk1, &u_seg);
398 total_cost += S::HALF * (tk1 - tk) * (lk + lk1);
399 }
400 }
401
402 let skip = if seg == 0 { 0 } else { 1 };
405 for k in skip..result.t.len() {
406 traj_t.push(result.t[k]);
407 traj_y.extend_from_slice(&result.y[k * p.n_states..(k + 1) * p.n_states]);
408 }
409
410 y_cur = result
412 .y_final()
413 .ok_or_else(|| format!("segment {seg}: empty result"))?;
414 }
415
416 if let Some(tc) = terminal_cost {
418 total_cost += tc(&y_cur);
419 }
420
421 Ok((traj_t, traj_y, total_cost))
422}
423
424fn simulate_final_state<S: Scalar>(
426 dynamics: &Arc<Box<DynamicsFn<S>>>,
427 y0: &Arc<Vec<S>>,
428 u: &[S],
429 p: &SimParams<S>,
430) -> Result<Vec<S>, String> {
431 let options = SolverOptions::default().rtol(p.ode_rtol).atol(p.ode_atol);
432 let mut y_cur = y0.as_ref().clone();
433
434 for seg in 0..p.n_segments {
435 let t_start = p.t0 + S::from_usize(seg) * p.dt;
436 let t_end = p.t0 + S::from_usize(seg + 1) * p.dt;
437 let u_seg: Vec<S> = u[seg * p.n_controls..(seg + 1) * p.n_controls].to_vec();
438
439 let dyn_ref = Arc::clone(dynamics);
440 let rhs = move |t: S, y: &[S], dydt: &mut [S]| {
441 dyn_ref(t, y, dydt, &u_seg);
442 };
443
444 let problem = OdeProblem::new(rhs, t_start, t_end, y_cur.clone());
445 let result = DoPri5::solve(&problem, t_start, t_end, &y_cur, &options)
446 .map_err(|e| format!("segment {seg}: {e}"))?;
447
448 if !result.success {
449 return Err(format!("segment {seg}: {}", result.message));
450 }
451
452 y_cur = result
453 .y_final()
454 .ok_or_else(|| format!("segment {seg}: empty result"))?;
455 }
456
457 Ok(y_cur)
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Double integrator x'' = u driven toward (x, v) = (1, 0) at T = 2 with a
    /// small control-effort penalty.
    #[test]
    fn test_double_integrator_terminal_cost() {
        let result = ShootingProblem::new(2, 1)
            .dynamics(|_t, y, dydt, u| {
                dydt[0] = y[1];
                dydt[1] = u[0];
            })
            .initial_state(vec![0.0, 0.0])
            .time_span(0.0, 2.0)
            .n_segments(10)
            .terminal_cost(|y| 100.0 * ((y[0] - 1.0).powi(2) + y[1].powi(2)))
            .running_cost(|_t, _y, u| 0.01 * u[0].powi(2))
            .max_iter(200)
            .solve()
            .expect("shooting solve failed");

        let x_final = result.final_state[0];
        assert!(
            (x_final - 1.0).abs() < 0.3,
            "x(T) = {x_final}, expected within 0.3 of 1.0"
        );
    }

    /// Single integrator x' = u with quadratic effort and a stiff terminal
    /// penalty pulling x(1) toward 1.
    #[test]
    fn test_minimum_energy_control() {
        let result = ShootingProblem::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .time_span(0.0, 1.0)
            .n_segments(10)
            .terminal_cost(|y| 1000.0 * (y[0] - 1.0).powi(2))
            .running_cost(|_t, _y, u| u[0].powi(2))
            .max_iter(200)
            .solve()
            .expect("shooting solve failed");

        let x_final = result.final_state[0];
        assert!(
            (x_final - 1.0).abs() < 0.3,
            "x(T) = {x_final}, expected within 0.3 of 1.0"
        );
    }

    /// Only a terminal cost is set (no running cost).
    #[test]
    fn test_pure_terminal_cost() {
        let result = ShootingProblem::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .time_span(0.0, 1.0)
            .n_segments(5)
            .terminal_cost(|y| (y[0] - 3.0).powi(2))
            .max_iter(200)
            .solve()
            .expect("shooting solve failed");

        let x_final = result.final_state[0];
        assert!(
            (x_final - 3.0).abs() < 0.5,
            "x(T) = {x_final}, expected within 0.5 of 3.0"
        );
    }

    /// Checks the shape and endpoints of the reported trajectory.
    #[test]
    fn test_trajectory_output() {
        let result = ShootingProblem::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .time_span(0.0, 1.0)
            .n_segments(5)
            .terminal_cost(|y| y[0].powi(2))
            .max_iter(50)
            .solve()
            .expect("shooting solve failed");

        assert!(
            !result.t_trajectory.is_empty(),
            "t_trajectory should be non-empty"
        );
        assert!(
            !result.y_trajectory.is_empty(),
            "y_trajectory should be non-empty"
        );

        assert!(
            result.t_trajectory[0].abs() < 1e-12,
            "first time should be t0=0.0, got {}",
            result.t_trajectory[0],
        );

        let t_last = *result.t_trajectory.last().unwrap();
        assert!(
            (t_last - 1.0).abs() < 1e-6,
            "last time should be ~tf=1.0, got {t_last}"
        );

        assert_eq!(
            result.y_trajectory.len(),
            result.t_trajectory.len() * result.n_states,
            "y_trajectory length mismatch"
        );
        assert_eq!(result.n_states, 1);
        assert_eq!(result.controls.len(), 5, "one control per segment expected");

        // The reported final state is the last row of the trajectory.
        let n_y = result.y_trajectory.len();
        assert_eq!(
            result.final_state,
            result.y_trajectory[n_y - 1..].to_vec(),
            "final_state should match the last trajectory row"
        );
    }

    #[test]
    fn test_missing_dynamics_is_rejected() {
        let result = ShootingProblem::<f64>::new(1, 1)
            .initial_state(vec![0.0])
            .terminal_cost(|y| y[0].powi(2))
            .solve();
        assert!(matches!(result, Err(OcpError::NoDynamics)));
    }

    #[test]
    fn test_missing_initial_state_is_rejected() {
        let result = ShootingProblem::<f64>::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .terminal_cost(|y| y[0].powi(2))
            .solve();
        assert!(matches!(result, Err(OcpError::NoInitialState)));
    }

    #[test]
    fn test_initial_state_dimension_mismatch_is_rejected() {
        let result = ShootingProblem::<f64>::new(2, 1)
            .dynamics(|_t, y, dydt, u| {
                dydt[0] = y[1];
                dydt[1] = u[0];
            })
            .initial_state(vec![0.0])
            .terminal_cost(|y| y[0].powi(2))
            .solve();
        assert!(matches!(result, Err(OcpError::DimensionMismatch(_))));
    }

    #[test]
    fn test_missing_cost_is_rejected() {
        let result = ShootingProblem::<f64>::new(1, 1)
            .dynamics(|_t, _y, dydt, u| {
                dydt[0] = u[0];
            })
            .initial_state(vec![0.0])
            .solve();
        assert!(matches!(result, Err(OcpError::Other(_))));
    }
}