Skip to main content

numra_ocp/
collocation.rs

1//! Direct collocation for optimal control.
2//!
3//! Transcribes an optimal control problem into a nonlinear programme (NLP)
4//! by parameterising states and controls at mesh nodes and enforcing
5//! dynamics via collocation defect constraints.
6//!
7//! Two collocation schemes are supported:
8//!
9//! - **Trapezoidal**: 2nd-order, simplest. Defect constraint:
10//!   `x_{k+1} - x_k - (h/2)*(f_k + f_{k+1}) = 0`
11//!
//! - **Hermite-Simpson**: 4th-order, the industry standard for direct
//!   collocation. Interpolates the state with a cubic Hermite polynomial and
//!   enforces the dynamics at each interval midpoint via Simpson's rule.
14//!
15//! Author: Moussa Leblouba
16//! Date: 9 February 2026
17//! Modified: 2 May 2026
18
19use std::sync::Arc;
20use std::time::Instant;
21
22use numra_core::Scalar;
23use numra_optim::OptimProblem;
24
25use crate::error::OcpError;
26
27// ---------------------------------------------------------------------------
28// Types
29// ---------------------------------------------------------------------------
30
/// Dynamics closure `f(t, x, u, p, dxdt)`: given time, state, control and
/// parameter slices, writes the state derivative into `dxdt` (a buffer of
/// length `n_states` supplied by the caller).
type DynamicsFn<S> = dyn Fn(S, &[S], &[S], &[S], &mut [S]) + Send + Sync;

/// Terminal cost closure `phi(x(tf), tf) -> S`, added once to the objective
/// using the state at the last mesh node.
type TerminalCostFn<S> = dyn Fn(&[S], S) -> S + Send + Sync;

/// Running cost closure `L(t, x, u) -> S`, integrated over the mesh with the
/// quadrature rule that matches the chosen collocation scheme.
type RunningCostFn<S> = dyn Fn(S, &[S], &[S]) -> S + Send + Sync;

/// Terminal constraint closure `psi(x(tf)) -> Vec<S>`; every returned
/// component is constrained to equal zero. The component count is taken from
/// one evaluation at an all-zero state, so it must not depend on the input.
type TerminalConstraintFn<S> = dyn Fn(&[S]) -> Vec<S> + Send + Sync;

/// Path constraint closure `c(t, x, u) -> Vec<S>`, evaluated at every mesh
/// node.
type PathConstraintFn<S> = dyn Fn(S, &[S], &[S]) -> Vec<S> + Send + Sync;

/// Path constraint closure paired with one `(lo, hi)` pair per output
/// component, enforcing `lo <= c_i(t, x, u) <= hi` at every node.
type PathConstraints<S> = (Box<PathConstraintFn<S>>, Vec<(S, S)>);
48
/// Collocation scheme used to discretise the dynamics on each mesh interval.
///
/// The default is [`CollocationScheme::HermiteSimpson`], which is also the
/// scheme a freshly constructed problem builder starts with.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CollocationScheme {
    /// Trapezoidal rule (2nd order): the dynamics are averaged over the two
    /// endpoints of each interval.
    Trapezoidal,
    /// Hermite-Simpson (4th order) -- the recommended default. Uses a cubic
    /// Hermite state interpolant and Simpson's rule on each interval
    /// (equivalent to the 3-stage Lobatto IIIA method).
    HermiteSimpson,
}

impl Default for CollocationScheme {
    /// Returns [`CollocationScheme::HermiteSimpson`], the recommended scheme.
    fn default() -> Self {
        Self::HermiteSimpson
    }
}
57
/// Result of a direct collocation solve.
///
/// Node-wise quantities are stored as flat row-major arrays; use `n_states`
/// and `n_controls` to index them.
#[derive(Clone, Debug)]
pub struct CollocationResult<S: Scalar> {
    /// Time at each mesh node (`n_intervals + 1` values, uniformly spaced over
    /// `[t0, final_time]`).
    pub time: Vec<S>,
    /// States at each mesh node (flat row-major: `states[i * nx + j]` is state
    /// `j` at node `i`).
    pub states: Vec<S>,
    /// Controls at each mesh node (flat row-major: `controls[i * nu + j]` is
    /// control `j` at node `i`).
    pub controls: Vec<S>,
    /// Optimal parameters (empty if the problem has no free parameters).
    pub parameters: Vec<S>,
    /// Objective value reported by the optimizer.
    pub objective: S,
    /// Final time: the fixed `tf`, or the optimized value when it is free.
    pub final_time: S,
    /// Whether the optimizer reported convergence.
    pub converged: bool,
    /// Human-readable status message from the optimizer.
    pub message: String,
    /// Number of optimizer iterations.
    pub iterations: usize,
    /// Wall-clock time in seconds for the whole `solve` call, including NLP
    /// construction and result extraction.
    pub wall_time_secs: f64,
    /// Number of states (row width of `states`).
    pub n_states: usize,
    /// Number of controls (row width of `controls`).
    pub n_controls: usize,
}
86
87// ---------------------------------------------------------------------------
88// NLP layout helpers
89// ---------------------------------------------------------------------------
90
/// Layout of the NLP decision vector `z`, computed once in `solve` and copied
/// (the type is `Copy`) into every objective and constraint closure.
///
/// `z` is ordered as `[x_0 .. x_N | u_0 .. u_N | p | tf]` with `N = n_int`;
/// the trailing `tf` slot exists only when the final time is free.
///
/// The struct is deliberately non-generic: it holds only `usize` offsets and
/// the `f64` time parameters, which are converted to `S` on demand.
#[derive(Clone, Copy)]
struct NlpLayout {
    /// Number of states per node.
    nx: usize,
    /// Number of controls per node.
    nu: usize,
    /// Number of free parameters.
    np: usize,
    /// Number of mesh intervals (the mesh has `n_int + 1` nodes).
    n_int: usize,
    /// Start of the node-state block in `z`.
    x_offset: usize,
    /// Start of the node-control block in `z`.
    u_offset: usize,
    /// Start of the parameter block in `z`.
    p_offset: usize,
    /// Index of the free final time in `z` (one past the parameters); equal to
    /// the decision-vector length when `tf` is fixed.
    tf_offset: usize,
    /// Whether `tf` is a decision variable stored at `z[tf_offset]`.
    has_free_tf: bool,
    /// Fixed final time. When `tf` is free this holds the initial guess and is
    /// not read by the time helpers.
    tf_fixed: f64,
    /// Initial time.
    t0: f64,
}
110
111impl NlpLayout {
112    fn n_decision(&self) -> usize {
113        self.tf_offset + if self.has_free_tf { 1 } else { 0 }
114    }
115
116    fn tf<S: Scalar>(&self, z: &[S]) -> S {
117        if self.has_free_tf {
118            z[self.tf_offset]
119        } else {
120            S::from_f64(self.tf_fixed)
121        }
122    }
123
124    fn h<S: Scalar>(&self, z: &[S]) -> S {
125        (self.tf(z) - S::from_f64(self.t0)) / S::from_usize(self.n_int)
126    }
127
128    fn time_at<S: Scalar>(&self, z: &[S], k: usize) -> S {
129        S::from_f64(self.t0) + S::from_usize(k) * self.h(z)
130    }
131
132    fn x_range(&self, k: usize) -> (usize, usize) {
133        let start = self.x_offset + k * self.nx;
134        (start, start + self.nx)
135    }
136
137    fn u_range(&self, k: usize) -> (usize, usize) {
138        let start = self.u_offset + k * self.nu;
139        (start, start + self.nu)
140    }
141
142    fn p_range(&self) -> (usize, usize) {
143        (self.p_offset, self.p_offset + self.np)
144    }
145}
146
147// ---------------------------------------------------------------------------
148// Builder
149// ---------------------------------------------------------------------------
150
/// Builder for direct collocation optimal control problems.
///
/// Configure the problem with the chained setters, then call `solve`. The
/// dynamics, an initial state and at least one of the terminal/running costs
/// are required.
pub struct CollocationProblem<S: Scalar> {
    /// Number of state variables.
    n_states: usize,
    /// Number of control variables.
    n_controls: usize,
    /// Number of free (static) parameters.
    n_params: usize,
    /// System dynamics; required.
    dynamics: Option<Box<DynamicsFn<S>>>,
    /// Initial state `x(t0)`; required.
    x0: Option<Vec<S>>,
    /// Initial time.
    t0: S,
    /// Fixed final time; ignored when `free_final_time` is set.
    tf: S,
    /// Free final time as `(tf_init, tf_lo, tf_hi)`, or `None` for a fixed `tf`.
    free_final_time: Option<(S, S, S)>,
    /// Number of mesh intervals.
    n_intervals: usize,
    /// Collocation scheme for the defect constraints and cost quadrature.
    scheme: CollocationScheme,
    /// Optional terminal cost.
    terminal_cost: Option<Box<TerminalCostFn<S>>>,
    /// Optional running cost.
    running_cost: Option<Box<RunningCostFn<S>>>,
    /// Optional terminal equality constraints.
    terminal_constraints: Option<Box<TerminalConstraintFn<S>>>,
    /// Optional path constraints with per-component bounds.
    path_constraints: Option<PathConstraints<S>>,
    /// Per-control bounds applied at every node (`None` = unbounded).
    control_bounds: Vec<Option<(S, S)>>,
    /// Per-state bounds applied at every node (`None` = unbounded).
    state_bounds: Vec<Option<(S, S)>>,
    /// Initial guess for the parameters (zeros when `None`).
    params0: Option<Vec<S>>,
    /// Per-parameter bounds (`None` = unbounded).
    param_bounds: Vec<Option<(S, S)>>,
    /// Maximum number of optimizer iterations.
    max_iter: usize,
}
173
174impl<S: Scalar> CollocationProblem<S> {
175    /// Create a new collocation problem.
176    pub fn new(n_states: usize, n_controls: usize) -> Self {
177        Self {
178            n_states,
179            n_controls,
180            n_params: 0,
181            dynamics: None,
182            x0: None,
183            t0: S::ZERO,
184            tf: S::ONE,
185            free_final_time: None,
186            n_intervals: 20,
187            scheme: CollocationScheme::HermiteSimpson,
188            terminal_cost: None,
189            running_cost: None,
190            terminal_constraints: None,
191            path_constraints: None,
192            control_bounds: vec![None; n_controls],
193            state_bounds: vec![None; n_states],
194            params0: None,
195            param_bounds: Vec::new(),
196            max_iter: 500,
197        }
198    }
199
200    /// Set the dynamics: `f(t, x, u, p, dxdt)`.
201    pub fn dynamics<F>(mut self, f: F) -> Self
202    where
203        F: Fn(S, &[S], &[S], &[S], &mut [S]) + Send + Sync + 'static,
204    {
205        self.dynamics = Some(Box::new(f));
206        self
207    }
208
209    /// Set the number of free parameters.
210    pub fn n_params(mut self, n: usize) -> Self {
211        self.n_params = n;
212        self.param_bounds = vec![None; n];
213        self
214    }
215
216    /// Set the initial state.
217    pub fn x0(mut self, x0: &[S]) -> Self {
218        self.x0 = Some(x0.to_vec());
219        self
220    }
221
222    /// Set the time span `[t0, tf]`.
223    pub fn time_span(mut self, t0: S, tf: S) -> Self {
224        self.t0 = t0;
225        self.tf = tf;
226        self
227    }
228
229    /// Enable free final time.
230    pub fn free_final_time(mut self, tf_init: S, tf_bounds: (S, S)) -> Self {
231        self.free_final_time = Some((tf_init, tf_bounds.0, tf_bounds.1));
232        self
233    }
234
235    /// Set the number of collocation intervals.
236    pub fn n_intervals(mut self, n: usize) -> Self {
237        self.n_intervals = n;
238        self
239    }
240
241    /// Set the collocation scheme.
242    pub fn scheme(mut self, scheme: CollocationScheme) -> Self {
243        self.scheme = scheme;
244        self
245    }
246
247    /// Set the terminal cost `phi(x(tf), tf)`.
248    pub fn terminal_cost<F>(mut self, f: F) -> Self
249    where
250        F: Fn(&[S], S) -> S + Send + Sync + 'static,
251    {
252        self.terminal_cost = Some(Box::new(f));
253        self
254    }
255
256    /// Set the running cost `L(t, x, u)`.
257    pub fn running_cost<F>(mut self, f: F) -> Self
258    where
259        F: Fn(S, &[S], &[S]) -> S + Send + Sync + 'static,
260    {
261        self.running_cost = Some(Box::new(f));
262        self
263    }
264
265    /// Set terminal equality constraints `psi(x(tf)) = 0`.
266    pub fn terminal_constraint<F>(mut self, f: F) -> Self
267    where
268        F: Fn(&[S]) -> Vec<S> + Send + Sync + 'static,
269    {
270        self.terminal_constraints = Some(Box::new(f));
271        self
272    }
273
274    /// Set path constraints with bounds: `lo <= c(t, x, u) <= hi`.
275    pub fn path_constraint<F>(mut self, f: F, bounds: &[(S, S)]) -> Self
276    where
277        F: Fn(S, &[S], &[S]) -> Vec<S> + Send + Sync + 'static,
278    {
279        self.path_constraints = Some((Box::new(f), bounds.to_vec()));
280        self
281    }
282
283    /// Set bounds for each control variable (applied at every node).
284    pub fn control_bounds(mut self, bounds: &[(S, S)]) -> Self {
285        self.control_bounds = bounds.iter().map(|&b| Some(b)).collect();
286        self
287    }
288
289    /// Set bounds for each state variable (applied at every node).
290    pub fn state_bounds(mut self, bounds: &[(S, S)]) -> Self {
291        self.state_bounds = bounds.iter().map(|&b| Some(b)).collect();
292        self
293    }
294
295    /// Set initial guess for free parameters.
296    pub fn params0(mut self, p0: &[S]) -> Self {
297        self.params0 = Some(p0.to_vec());
298        self
299    }
300
301    /// Set bounds for a specific parameter.
302    pub fn param_bound(mut self, idx: usize, lo: S, hi: S) -> Self {
303        if idx < self.param_bounds.len() {
304            self.param_bounds[idx] = Some((lo, hi));
305        }
306        self
307    }
308
309    /// Set maximum optimizer iterations.
310    pub fn max_iter(mut self, n: usize) -> Self {
311        self.max_iter = n;
312        self
313    }
314
315    // -----------------------------------------------------------------------
316    // Solve
317    // -----------------------------------------------------------------------
318
319    /// Solve the collocation problem.
320    pub fn solve(self) -> Result<CollocationResult<S>, OcpError>
321    where
322        S: faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
323    {
324        let start = Instant::now();
325
326        // -- Validate -------------------------------------------------------
327        let dynamics = self.dynamics.ok_or(OcpError::NoDynamics)?;
328        let x0_val = self.x0.ok_or(OcpError::NoInitialState)?;
329
330        if x0_val.len() != self.n_states {
331            return Err(OcpError::DimensionMismatch(format!(
332                "x0 length {} != n_states {}",
333                x0_val.len(),
334                self.n_states,
335            )));
336        }
337
338        if self.terminal_cost.is_none() && self.running_cost.is_none() {
339            return Err(OcpError::Other(
340                "at least one of terminal_cost or running_cost must be set".into(),
341            ));
342        }
343
344        let nx = self.n_states;
345        let nu = self.n_controls;
346        let np = self.n_params;
347        let n_int = self.n_intervals;
348        let n_nodes = n_int + 1;
349        let scheme = self.scheme;
350
351        let has_free_tf = self.free_final_time.is_some();
352        let tf_fixed = if let Some((tfi, _, _)) = self.free_final_time {
353            tfi.to_f64()
354        } else {
355            self.tf.to_f64()
356        };
357
358        let lay = NlpLayout {
359            nx,
360            nu,
361            np,
362            n_int,
363            x_offset: 0,
364            u_offset: n_nodes * nx,
365            p_offset: n_nodes * nx + n_nodes * nu,
366            tf_offset: n_nodes * nx + n_nodes * nu + np,
367            has_free_tf,
368            tf_fixed,
369            t0: self.t0.to_f64(),
370        };
371
372        let n_decision = lay.n_decision();
373
374        // -- Build initial guess -------------------------------------------
375        let mut z0 = vec![S::ZERO; n_decision];
376
377        // Constant state initial guess (x0 everywhere).
378        for k in 0..n_nodes {
379            let (s, e) = lay.x_range(k);
380            z0[s..e].copy_from_slice(&x0_val);
381        }
382
383        // Parameters initial guess.
384        if let Some(ref p0) = self.params0 {
385            let (s, e) = lay.p_range();
386            z0[s..e].copy_from_slice(p0);
387        }
388
389        // Free final time initial.
390        if has_free_tf {
391            z0[lay.tf_offset] = S::from_f64(tf_fixed);
392        }
393
394        // -- Shared state for closures -------------------------------------
395        let dynamics = Arc::new(dynamics);
396        let terminal_cost: Option<Arc<Box<TerminalCostFn<S>>>> = self.terminal_cost.map(Arc::new);
397        let running_cost: Option<Arc<Box<RunningCostFn<S>>>> = self.running_cost.map(Arc::new);
398
399        // -- Objective function --------------------------------------------
400        let dyn_obj = Arc::clone(&dynamics);
401        let tc_obj = terminal_cost.clone();
402        let rc_obj = running_cost.clone();
403
404        let objective_fn = move |z: &[S]| -> S {
405            let tf_val = lay.tf(z);
406            let mut cost = S::ZERO;
407
408            if let Some(ref rc) = rc_obj {
409                for k in 0..n_int {
410                    let tk = lay.time_at(z, k);
411                    let tk1 = lay.time_at(z, k + 1);
412                    let h = tk1 - tk;
413                    let (xs, xe) = lay.x_range(k);
414                    let (us, ue) = lay.u_range(k);
415                    let (xs1, xe1) = lay.x_range(k + 1);
416                    let (us1, ue1) = lay.u_range(k + 1);
417                    let lk = rc(tk, &z[xs..xe], &z[us..ue]);
418                    let lk1 = rc(tk1, &z[xs1..xe1], &z[us1..ue1]);
419
420                    if scheme == CollocationScheme::HermiteSimpson {
421                        let t_mid = S::HALF * (tk + tk1);
422                        let mut x_mid = vec![S::ZERO; nx];
423                        let mut u_mid = vec![S::ZERO; nu];
424                        for j in 0..nx {
425                            x_mid[j] = S::HALF * (z[xs + j] + z[xs1 + j]);
426                        }
427                        for j in 0..nu {
428                            u_mid[j] = S::HALF * (z[us + j] + z[us1 + j]);
429                        }
430                        let (ps, pe) = lay.p_range();
431                        let mut fk = vec![S::ZERO; nx];
432                        let mut fk1 = vec![S::ZERO; nx];
433                        dyn_obj(tk, &z[xs..xe], &z[us..ue], &z[ps..pe], &mut fk);
434                        dyn_obj(tk1, &z[xs1..xe1], &z[us1..ue1], &z[ps..pe], &mut fk1);
435                        let eighth_h = h / S::from_f64(8.0);
436                        for j in 0..nx {
437                            x_mid[j] += eighth_h * (fk[j] - fk1[j]);
438                        }
439                        let l_mid = rc(t_mid, &x_mid, &u_mid);
440                        let sixth_h = h / S::from_f64(6.0);
441                        cost += sixth_h * (lk + S::from_f64(4.0) * l_mid + lk1);
442                    } else {
443                        cost += S::HALF * h * (lk + lk1);
444                    }
445                }
446            }
447
448            if let Some(ref tc) = tc_obj {
449                let (xs, xe) = lay.x_range(n_int);
450                cost += tc(&z[xs..xe], tf_val);
451            }
452
453            cost
454        };
455
456        // -- Build OptimProblem --------------------------------------------
457        let mut prob = OptimProblem::new(n_decision)
458            .x0(&z0)
459            .objective(objective_fn)
460            .max_iter(self.max_iter);
461
462        // Initial state equality constraints: x_0 = x0_val.
463        for j in 0..nx {
464            let x0_j = x0_val[j];
465            let x_off = lay.x_offset;
466            prob = prob.constraint_eq(move |z: &[S]| -> S { z[x_off + j] - x0_j });
467        }
468
469        // -- Defect constraints --------------------------------------------
470        match scheme {
471            CollocationScheme::Trapezoidal => {
472                for k in 0..n_int {
473                    for j in 0..nx {
474                        let dyn_c = Arc::clone(&dynamics);
475                        prob = prob.constraint_eq(move |z: &[S]| -> S {
476                            let h = lay.h(z);
477                            let tk = lay.time_at(z, k);
478                            let tk1 = lay.time_at(z, k + 1);
479                            let (xs, xe) = lay.x_range(k);
480                            let (us, ue) = lay.u_range(k);
481                            let (xs1, xe1) = lay.x_range(k + 1);
482                            let (us1, ue1) = lay.u_range(k + 1);
483                            let (ps, pe) = lay.p_range();
484
485                            let mut fk = vec![S::ZERO; nx];
486                            let mut fk1 = vec![S::ZERO; nx];
487                            dyn_c(tk, &z[xs..xe], &z[us..ue], &z[ps..pe], &mut fk);
488                            dyn_c(tk1, &z[xs1..xe1], &z[us1..ue1], &z[ps..pe], &mut fk1);
489
490                            z[xs1 + j] - z[xs + j] - S::HALF * h * (fk[j] + fk1[j])
491                        });
492                    }
493                }
494            }
495            CollocationScheme::HermiteSimpson => {
496                for k in 0..n_int {
497                    for j in 0..nx {
498                        let dyn_c = Arc::clone(&dynamics);
499                        prob = prob.constraint_eq(move |z: &[S]| -> S {
500                            let h = lay.h(z);
501                            let tk = lay.time_at(z, k);
502                            let tk1 = lay.time_at(z, k + 1);
503                            let t_mid = S::HALF * (tk + tk1);
504                            let (xs, xe) = lay.x_range(k);
505                            let (us, ue) = lay.u_range(k);
506                            let (xs1, xe1) = lay.x_range(k + 1);
507                            let (us1, ue1) = lay.u_range(k + 1);
508                            let (ps, pe) = lay.p_range();
509
510                            let mut fk = vec![S::ZERO; nx];
511                            let mut fk1 = vec![S::ZERO; nx];
512                            dyn_c(tk, &z[xs..xe], &z[us..ue], &z[ps..pe], &mut fk);
513                            dyn_c(tk1, &z[xs1..xe1], &z[us1..ue1], &z[ps..pe], &mut fk1);
514
515                            // Hermite midpoint state.
516                            let mut x_mid = vec![S::ZERO; nx];
517                            let mut u_mid = vec![S::ZERO; nu];
518                            let eighth_h = h / S::from_f64(8.0);
519                            for i in 0..nx {
520                                x_mid[i] = S::HALF * (z[xs + i] + z[xs1 + i])
521                                    + eighth_h * (fk[i] - fk1[i]);
522                            }
523                            for i in 0..nu {
524                                u_mid[i] = S::HALF * (z[us + i] + z[us1 + i]);
525                            }
526
527                            let mut f_mid = vec![S::ZERO; nx];
528                            dyn_c(t_mid, &x_mid, &u_mid, &z[ps..pe], &mut f_mid);
529
530                            // Simpson defect.
531                            let sixth_h = h / S::from_f64(6.0);
532                            z[xs1 + j]
533                                - z[xs + j]
534                                - sixth_h * (fk[j] + S::from_f64(4.0) * f_mid[j] + fk1[j])
535                        });
536                    }
537                }
538            }
539        }
540
541        // -- Terminal constraints ------------------------------------------
542        if let Some(tc_fn) = self.terminal_constraints {
543            let tc_fn = Arc::new(tc_fn);
544            let dummy = vec![S::ZERO; nx];
545            let n_tc = tc_fn(&dummy).len();
546
547            for ci in 0..n_tc {
548                let tc_c = Arc::clone(&tc_fn);
549                prob = prob.constraint_eq(move |z: &[S]| -> S {
550                    let (xs, xe) = lay.x_range(n_int);
551                    tc_c(&z[xs..xe])[ci]
552                });
553            }
554        }
555
556        // -- Path constraints (inequality, enforced at each node) ----------
557        if let Some((pc_fn, pc_bounds)) = self.path_constraints {
558            let pc_fn = Arc::new(pc_fn);
559            let n_pc = pc_bounds.len();
560
561            for k in 0..n_nodes {
562                #[allow(clippy::needless_range_loop)]
563                for ci in 0..n_pc {
564                    let (lo, hi) = pc_bounds[ci];
565                    let pc_lo = Arc::clone(&pc_fn);
566                    let pc_hi = Arc::clone(&pc_fn);
567
568                    // c(t, x, u) - lo >= 0
569                    prob = prob.constraint_ineq(move |z: &[S]| -> S {
570                        let tk = lay.time_at(z, k);
571                        let (xs, xe) = lay.x_range(k);
572                        let (us, ue) = lay.u_range(k);
573                        pc_lo(tk, &z[xs..xe], &z[us..ue])[ci] - lo
574                    });
575
576                    // hi - c(t, x, u) >= 0
577                    prob = prob.constraint_ineq(move |z: &[S]| -> S {
578                        let tk = lay.time_at(z, k);
579                        let (xs, xe) = lay.x_range(k);
580                        let (us, ue) = lay.u_range(k);
581                        hi - pc_hi(tk, &z[xs..xe], &z[us..ue])[ci]
582                    });
583                }
584            }
585        }
586
587        // -- Variable bounds -----------------------------------------------
588        for k in 0..n_nodes {
589            for j in 0..nx {
590                if let Some(&Some((lo, hi))) = self.state_bounds.get(j) {
591                    let (s, _) = lay.x_range(k);
592                    prob = prob.bounds(s + j, (lo, hi));
593                }
594            }
595            for j in 0..nu {
596                if let Some(&Some((lo, hi))) = self.control_bounds.get(j) {
597                    let (s, _) = lay.u_range(k);
598                    prob = prob.bounds(s + j, (lo, hi));
599                }
600            }
601        }
602
603        for j in 0..np {
604            if let Some(&Some((lo, hi))) = self.param_bounds.get(j) {
605                prob = prob.bounds(lay.p_offset + j, (lo, hi));
606            }
607        }
608
609        if let Some((_, tf_lo, tf_hi)) = self.free_final_time {
610            prob = prob.bounds(lay.tf_offset, (tf_lo, tf_hi));
611        }
612
613        // -- Solve ----------------------------------------------------------
614        let optim_result = prob.solve().map_err(OcpError::OptimFailed)?;
615        let z_opt = &optim_result.x;
616
617        // -- Extract results ------------------------------------------------
618        let tf_opt = lay.tf(z_opt);
619        let mut time = Vec::with_capacity(n_nodes);
620        let mut states = Vec::with_capacity(n_nodes * nx);
621        let mut controls = Vec::with_capacity(n_nodes * nu);
622
623        for k in 0..n_nodes {
624            time.push(lay.time_at(z_opt, k));
625            let (xs, xe) = lay.x_range(k);
626            states.extend_from_slice(&z_opt[xs..xe]);
627            let (us, ue) = lay.u_range(k);
628            controls.extend_from_slice(&z_opt[us..ue]);
629        }
630
631        let parameters = if np > 0 {
632            let (ps, pe) = lay.p_range();
633            z_opt[ps..pe].to_vec()
634        } else {
635            Vec::new()
636        };
637
638        Ok(CollocationResult {
639            time,
640            states,
641            controls,
642            parameters,
643            objective: optim_result.f,
644            final_time: tf_opt,
645            converged: optim_result.converged,
646            message: optim_result.message.clone(),
647            iterations: optim_result.iterations,
648            wall_time_secs: start.elapsed().as_secs_f64(),
649            n_states: nx,
650            n_controls: nu,
651        })
652    }
653}
654
655// ---------------------------------------------------------------------------
656// Tests
657// ---------------------------------------------------------------------------
658
#[cfg(test)]
mod tests {
    use super::*;

    /// Double integrator `x' = v`, `v' = u`, starting at rest on `t in [0, 1]`.
    fn double_integrator() -> CollocationProblem<f64> {
        CollocationProblem::new(2, 1)
            .dynamics(|_t, x, u, _p, dxdt| {
                dxdt[0] = x[1];
                dxdt[1] = u[0];
            })
            .x0(&[0.0, 0.0])
            .time_span(0.0, 1.0)
    }

    /// Single integrator `x' = u` starting from `x(0) = 0`; callers pick the span.
    fn single_integrator() -> CollocationProblem<f64> {
        CollocationProblem::new(1, 1)
            .dynamics(|_t, _x, u, _p, dxdt| {
                dxdt[0] = u[0];
            })
            .x0(&[0.0])
    }

    /// Rest-to-rest transfer of the double integrator from `x = 0` to `x = 1`
    /// on 20 intervals, minimising the integral of `u^2`.
    fn solve_rest_to_rest(scheme: CollocationScheme) -> CollocationResult<f64> {
        double_integrator()
            .n_intervals(20)
            .scheme(scheme)
            .running_cost(|_t, _x, u| u[0] * u[0])
            .terminal_constraint(|x| vec![x[0] - 1.0, x[1]])
            .max_iter(500)
            .solve()
            .expect("collocation solve failed")
    }

    /// State vector at the last mesh node.
    fn final_state(result: &CollocationResult<f64>) -> &[f64] {
        let width = result.n_states;
        let last = result.time.len() - 1;
        &result.states[last * width..(last + 1) * width]
    }

    /// Hermite-Simpson on the rest-to-rest problem. The analytical optimum is
    /// `u(t) = 6 - 12 t` with cost 12.
    #[test]
    fn test_double_integrator_hermite_simpson() {
        let result = solve_rest_to_rest(CollocationScheme::HermiteSimpson);
        assert!(result.converged, "should converge, got: {}", result.message);

        let terminal = final_state(&result);
        let (xf, vf) = (terminal[0], terminal[1]);
        assert!((xf - 1.0).abs() < 0.1, "x(T) = {xf}, expected ~1.0");
        assert!(vf.abs() < 0.1, "v(T) = {vf}, expected ~0.0");
        assert!(
            (result.objective - 12.0).abs() < 2.0,
            "cost = {}, expected ~12.0",
            result.objective,
        );
    }

    /// The same transfer solved with the trapezoidal scheme.
    #[test]
    fn test_double_integrator_trapezoidal() {
        let result = solve_rest_to_rest(CollocationScheme::Trapezoidal);
        assert!(result.converged, "should converge, got: {}", result.message);

        let xf = final_state(&result)[0];
        assert!((xf - 1.0).abs() < 0.1, "x(T) = {xf}, expected ~1.0");
    }

    /// `x' = u`, penalty `100 (x(1) - 1)^2` plus running cost `u^2`.
    #[test]
    fn test_minimum_energy_collocation() {
        let result = single_integrator()
            .time_span(0.0, 1.0)
            .n_intervals(10)
            .scheme(CollocationScheme::HermiteSimpson)
            .terminal_cost(|x, _tf| 100.0 * (x[0] - 1.0).powi(2))
            .running_cost(|_t, _x, u| u[0] * u[0])
            .max_iter(300)
            .solve()
            .expect("collocation solve failed");

        let xf = final_state(&result)[0];
        assert!((xf - 1.0).abs() < 0.3, "x(T) = {xf}, expected ~1.0");
    }

    /// `x' = u` with `|u| <= 1` aiming for `x(2) = 3`; the best reachable
    /// value is `x(2) = 2` with `u = 1` throughout.
    #[test]
    fn test_control_bounded() {
        let result = single_integrator()
            .time_span(0.0, 2.0)
            .n_intervals(10)
            .scheme(CollocationScheme::HermiteSimpson)
            .terminal_cost(|x, _tf| (x[0] - 3.0).powi(2))
            .control_bounds(&[(-1.0, 1.0)])
            .max_iter(300)
            .solve()
            .expect("collocation solve failed");

        let xf = final_state(&result)[0];
        assert!((xf - 2.0).abs() < 0.3, "x(T) = {xf}, expected ~2.0");
    }

    /// Shape of the returned arrays and monotonicity of the time grid.
    #[test]
    fn test_result_structure() {
        let result = double_integrator()
            .n_intervals(5)
            .scheme(CollocationScheme::HermiteSimpson)
            .terminal_cost(|x, _tf| x[0].powi(2))
            .max_iter(100)
            .solve()
            .expect("collocation solve failed");

        assert_eq!(result.time.len(), 6);
        assert_eq!(result.states.len(), 12);
        assert_eq!(result.controls.len(), 6);
        assert_eq!(result.n_states, 2);
        assert_eq!(result.n_controls, 1);

        assert!(result.time.windows(2).all(|pair| pair[1] > pair[0]));

        assert!((result.time[0] - 0.0).abs() < 1e-12);
        assert!((result.time[5] - 1.0).abs() < 1e-12);
    }
}
797}