tinympc_rs/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![allow(non_snake_case)]
3
4use nalgebra::{RealField, SMatrix, SVector, SVectorView, SVectorViewMut, Scalar, convert};
5
6pub mod cache;
7pub mod constraint;
8pub mod project;
9
10pub use cache::{Cache, Error};
11pub use constraint::Constraint;
12pub use project::*;
13
14mod util;
15
/// Signature of a user-supplied dynamics function: writes the next state into
/// the first (mutable) argument given the current state and input. Supplied
/// via [`TinyMpc::with_sys`] to replace the dense `A`/`B` products in the
/// forward pass (e.g. to exploit sparsity in the model).
pub type LtiFn<T, const NX: usize, const NU: usize> =
    fn(SVectorViewMut<T, NX>, SVectorView<T, NX>, SVectorView<T, NU>);
18
/// Why the solver stopped iterating; reported in [`Solution::reason`].
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum TerminationReason {
    /// The solver converged to within the defined tolerances
    Converged,
    /// The solver reached the maximum number of iterations allowed
    MaxIters,
}
26
/// ADMM-based model-predictive controller.
///
/// `NX`/`NU` are the state and input dimensions; `HX`/`HU` are the state and
/// input horizons (`HX > HU > 0`, enforced at compile time in [`TinyMpc::new`]).
#[derive(Debug)]
pub struct TinyMpc<
    T,
    CACHE: Cache<T, NX, NU>,
    const NX: usize,
    const NU: usize,
    const HX: usize,
    const HU: usize,
> {
    // Precomputed Riccati/ADMM matrices (Klqr, Plqr, RpBPBi, rho, ...)
    cache: CACHE,
    // Mutable per-solve workspace, preallocated and reused across solves
    state: State<T, NX, NU, HX, HU>,
    /// Solver configuration (tolerances, iteration limits, relaxation)
    pub config: Config<T>,
}
40
/// Solver configuration; defaults are set by [`TinyMpc::new`].
#[derive(Debug)]
pub struct Config<T> {
    /// The convergence tolerance for the primal residual (default 0.01)
    pub prim_tol: T,

    /// The convergence tolerance for the dual residual (default 0.01)
    pub dual_tol: T,

    /// Maximum iterations without converging before terminating (default 50)
    pub max_iter: usize,

    /// Number of iterations between evaluating convergence (default 5).
    /// Must be non-zero: the iteration counter is reduced modulo this value.
    pub do_check: usize,

    /// Relaxation, values `1.5-1.8` may improve convergence (default 1.0)
    pub relaxation: T,
}
58
/// Mutable solver workspace. All matrices are preallocated once and reused
/// across solves; several double as scratch buffers within one iteration.
#[derive(Debug)]
pub struct State<T, const NX: usize, const NU: usize, const HX: usize, const HU: usize> {
    // Linear state space model
    A: SMatrix<T, NX, NX>,
    B: SMatrix<T, NX, NU>,

    // For sparse system dynamics
    sys: Option<LtiFn<T, NX, NU>>,

    // State and input tracking error predictions
    ex: SMatrix<T, NX, HX>,
    eu: SMatrix<T, NU, HU>,

    // State tracking dynamics mismatch: cx[i] = A*x_ref[i] - x_ref[i+1] - B*u_ref[i],
    // and its precomputed product cp = Plqr * cx
    cx: SMatrix<T, NX, HX>,
    cp: SMatrix<T, NX, HX>,

    // Linear cost matrices (also reused as relaxation buffers in the
    // constraint update)
    q: SMatrix<T, NX, HX>,
    r: SMatrix<T, NU, HU>,

    // Riccati backward pass terms (also lent out as projection scratch)
    p: SMatrix<T, NX, HX>,
    d: SMatrix<T, NU, HU>,

    // Number of iterations for latest solve
    iter: usize,
}
87
/// Builder for a single MPC solve: holds the initial state plus optional
/// references and constraint sets, and is consumed by [`Problem::solve`].
/// Created via [`TinyMpc::initial_condition`].
pub struct Problem<
    'a,
    T,
    C,
    const NX: usize,
    const NU: usize,
    const HX: usize,
    const HU: usize,
    XProj = (),
    UProj = (),
> where
    T: Scalar + RealField + Copy,
    C: Cache<T, NX, NU>,
    XProj: Project<T, NX, HX>,
    UProj: Project<T, NU, HU>,
{
    // Borrowed solver; `solve` runs on it
    mpc: &'a mut TinyMpc<T, C, NX, NU, HX, HU>,
    // Current (measured) state
    x_now: SVector<T, NX>,
    // Optional state/input references to track
    x_ref: Option<&'a SMatrix<T, NX, HX>>,
    u_ref: Option<&'a SMatrix<T, NU, HU>>,
    // Optional state/input constraint sets (duals/slacks mutated during solve)
    x_con: Option<&'a mut [Constraint<T, XProj, NX, HX>]>,
    u_con: Option<&'a mut [Constraint<T, UProj, NU, HU>]>,
}
111
112impl<'a, T, C, XProj, UProj, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
113    Problem<'a, T, C, NX, NU, HX, HU, XProj, UProj>
114where
115    T: Scalar + RealField + Copy,
116    C: Cache<T, NX, NU>,
117    XProj: Project<T, NX, HX>,
118    UProj: Project<T, NU, HU>,
119{
120    /// Set the reference for state variables
121    pub fn x_reference(mut self, x_ref: &'a SMatrix<T, NX, HX>) -> Self {
122        self.x_ref = Some(x_ref);
123        self
124    }
125
126    /// Set the reference for input variables
127    pub fn u_reference(mut self, u_ref: &'a SMatrix<T, NU, HU>) -> Self {
128        self.u_ref = Some(u_ref);
129        self
130    }
131
132    /// Set constraints on the state variables
133    pub fn x_constraints<Proj: Project<T, NX, HX>>(
134        self,
135        x_con: &'a mut [Constraint<T, Proj, NX, HX>],
136    ) -> Problem<'a, T, C, NX, NU, HX, HU, Proj, UProj> {
137        Problem {
138            mpc: self.mpc,
139            x_now: self.x_now,
140            x_ref: self.x_ref,
141            u_ref: self.u_ref,
142            x_con: Some(x_con),
143            u_con: self.u_con,
144        }
145    }
146
147    /// Set constraints on the input variables
148    pub fn u_constraints<Proj: Project<T, NU, HU>>(
149        self,
150        u_con: &'a mut [Constraint<T, Proj, NU, HU>],
151    ) -> Problem<'a, T, C, NX, NU, HX, HU, XProj, Proj> {
152        Problem {
153            mpc: self.mpc,
154            x_now: self.x_now,
155            x_ref: self.x_ref,
156            u_ref: self.u_ref,
157            x_con: self.x_con,
158            u_con: Some(u_con),
159        }
160    }
161
162    /// Run the solver
163    pub fn solve(self) -> Solution<'a, T, NX, NU, HX, HU> {
164        self.mpc
165            .solve(self.x_now, self.x_ref, self.u_ref, self.x_con, self.u_con)
166    }
167}
168
impl<T, C: Cache<T, NX, NU>, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
    TinyMpc<T, C, NX, NU, HX, HU>
where
    T: Scalar + RealField + Copy,
{
    /// Create a new solver from a discrete-time linear model `x+ = A x + B u`
    /// and a precomputed cache of Riccati/ADMM matrices.
    ///
    /// Fails to compile unless `HX > HU > 0`.
    #[must_use]
    #[inline(always)]
    pub fn new(A: SMatrix<T, NX, NX>, B: SMatrix<T, NX, NU>, cache: C) -> Self {
        // Compile-time guard against invalid horizon lengths
        const {
            assert!(HX > HU, "`HX` must be larger than `HU`");
            assert!(HU > 0, "`HU` must be non-zero");
        }

        Self {
            config: Config {
                prim_tol: convert(1e-2),
                dual_tol: convert(1e-2),
                max_iter: 50,
                do_check: 5,
                relaxation: T::one(),
            },
            cache,
            state: State {
                A,
                B,
                sys: None,
                cx: SMatrix::zeros(),
                cp: SMatrix::zeros(),
                q: SMatrix::zeros(),
                r: SMatrix::zeros(),
                p: SMatrix::zeros(),
                d: SMatrix::zeros(),
                ex: SMatrix::zeros(),
                eu: SMatrix::zeros(),
                iter: 0,
            },
        }
    }

    /// Use a custom dynamics function instead of the dense `A`/`B` products
    /// in the forward pass (e.g. to exploit sparsity in the model).
    pub fn with_sys(mut self, sys: LtiFn<T, NX, NU>) -> Self {
        self.state.sys = Some(sys);
        self
    }

    /// Start building a [`Problem`] for the given measured state; references
    /// and constraints can be attached before calling `solve`.
    #[inline(always)]
    pub fn initial_condition(
        &mut self,
        x_now: SVector<T, NX>,
    ) -> Problem<'_, T, C, NX, NU, HX, HU> {
        Problem {
            mpc: self,
            x_now,
            x_ref: None,
            u_ref: None,
            x_con: None,
            u_con: None,
        }
    }

    /// Run the ADMM solve from the given initial state with optional
    /// references and constraint sets.
    ///
    /// Each iteration performs: cost update, Riccati backward pass, LQR
    /// forward rollout, then the slack/dual (constraint) update. Convergence
    /// is only checked every `config.do_check` iterations.
    #[inline(always)]
    pub fn solve<'a>(
        &'a mut self,
        x_now: SVector<T, NX>,
        x_ref: Option<&'a SMatrix<T, NX, HX>>,
        u_ref: Option<&'a SMatrix<T, NU, HU>>,
        x_con: Option<&mut [Constraint<T, impl Project<T, NX, HX>, NX, HX>]>,
        u_con: Option<&mut [Constraint<T, impl Project<T, NU, HU>, NU, HU>]>,
    ) -> Solution<'a, T, NX, NU, HX, HU> {
        let mut reason = TerminationReason::MaxIters;

        // We flatten the None variant into an empty slice
        let x_con = x_con.unwrap_or(&mut [][..]);
        let u_con = u_con.unwrap_or(&mut [][..]);

        // Set initial error state and warm-start constraints
        self.set_initial_conditions(x_now, x_ref, u_ref);
        self.warm_start_constraints(x_con, u_con);

        let mut prim_residual = T::zero();
        let mut dual_residual = T::zero();

        self.state.iter = 0;
        while self.state.iter < self.config.max_iter {
            profiling::scope!("solve loop", format!("iter: {}", self.state.iter));

            self.update_cost(x_con, u_con);

            self.backward_pass();

            self.forward_pass();

            self.update_constraints(x_ref, u_ref, x_con, u_con);

            if self.check_termination(&mut prim_residual, &mut dual_residual, x_con, u_con) {
                reason = TerminationReason::Converged;
                self.state.iter += 1;
                break;
            }

            self.state.iter += 1;
        }

        Solution {
            x_ref,
            u_ref,
            x: &self.state.ex,
            u: &self.state.eu,
            reason,
            iterations: self.state.iter,
            prim_residual,
            // The dual residual is tracked unscaled internally; it is compared
            // against the tolerance (and reported) scaled by the ADMM penalty rho.
            dual_residual: dual_residual * self.cache.get_active().rho,
        }
    }

    /// Residuals are only evaluated every `do_check`-th iteration to save work.
    // NOTE(review): `config.do_check == 0` panics here (remainder by zero) —
    // confirm callers never set it to 0, or clamp it where it is set.
    #[inline(always)]
    fn should_compute_residuals(&self) -> bool {
        self.state.iter % self.config.do_check == 0
    }

    /// Convert the absolute initial state into error coordinates and build the
    /// affine mismatch term `cx[i] = A*x_ref[i] - x_ref[i+1] - B*u_ref[i]`,
    /// then refresh `cp = Plqr * cx`.
    #[inline(always)]
    #[profiling::function]
    fn set_initial_conditions(
        &mut self,
        x_now: SVector<T, NX>,
        x_ref: Option<&SMatrix<T, NX, HX>>,
        u_ref: Option<&SMatrix<T, NU, HU>>,
    ) {
        if let Some(x_ref) = x_ref {
            profiling::scope!("affine state reference term");
            // ex[0] = x_now - x_ref[0]
            x_now.sub_to(&x_ref.column(0), &mut self.state.ex.column_mut(0));
            // cx = A * x_ref, then subtract the one-step-shifted reference
            self.state.A.mul_to(&x_ref, &mut self.state.cx);
            for i in 0..HX - 1 {
                let mut cx_col = self.state.cx.column_mut(i);
                cx_col.axpy(-T::one(), &x_ref.column(i + 1), T::one());
            }
        } else {
            // No reference: error coordinates coincide with absolute state.
            // NOTE(review): `cx` is not cleared in this branch, so it keeps
            // whatever a previous solve with `x_ref` wrote — confirm that
            // references are passed consistently across successive solves.
            self.state.ex.set_column(0, &x_now);
        }

        if let Some(u_ref) = u_ref {
            profiling::scope!("affine input reference term");
            for i in 0..HX - 1 {
                let mut cx_col = self.state.cx.column_mut(i);
                // Beyond the control horizon, hold the last reference input
                let u_ref_col = u_ref.column(i.min(HU - 1));
                cx_col.gemv(-T::one(), &self.state.B, &u_ref_col, T::one());
            }
        }

        self.update_tracking_mismatch_plqr();
    }

    /// Recompute `cp = Plqr * cx`; required whenever `cx` changes or the
    /// active cache entry (and hence Plqr) is swapped.
    #[inline(always)]
    fn update_tracking_mismatch_plqr(&mut self) {
        // Note: using `sygemv` to exploit the symmetry of Plqr is actually
        // slower than just doing a regular matrix-vector multiplication,
        // since sygemv adds additional indexing overhead.
        let cache = self.cache.get_active();
        cache.Plqr.mul_to(&self.state.cx, &mut self.state.cp);
    }

    /// Shift the dual variables by one time step for more accurate hot starting
    #[inline(always)]
    #[profiling::function]
    fn warm_start_constraints(
        &mut self,
        x_con: &mut [Constraint<T, impl Project<T, NX, HX>, NX, HX>],
        u_con: &mut [Constraint<T, impl Project<T, NU, HU>, NU, HU>],
    ) {
        // Shifting left realigns last solve's duals/slacks with the horizon,
        // which has advanced by one sample since then.
        for con in x_con {
            util::shift_columns_left(&mut con.dual);
            util::shift_columns_left(&mut con.slac);
        }

        for con in u_con {
            util::shift_columns_left(&mut con.dual);
            util::shift_columns_left(&mut con.slac);
        }
    }

    /// Update linear control cost terms based on constraint violations
    ///
    /// Accumulates each constraint's cost contribution into `q`/`r`, scales by
    /// the ADMM penalty `rho`, and zeroes the buffers when no constraints are
    /// present. Also seeds the Riccati terminal condition `p[HX-1] = q[HX-1]`.
    #[inline(always)]
    #[profiling::function]
    fn update_cost(
        &mut self,
        x_con: &mut [Constraint<T, impl Project<T, NX, HX>, NX, HX>],
        u_con: &mut [Constraint<T, impl Project<T, NU, HU>, NU, HU>],
    ) {
        let s = &mut self.state;
        let c = self.cache.get_active();

        // Add cost contribution for state constraint violations.
        // The first constraint overwrites (set_cost), the rest accumulate
        // (add_cost), avoiding a separate zeroing pass.
        let mut x_con_iter = x_con.iter_mut();
        if let Some(x_con_first) = x_con_iter.next() {
            profiling::scope!("update state cost");
            x_con_first.set_cost(&mut s.q);
            for x_con_next in x_con_iter {
                x_con_next.add_cost(&mut s.q);
            }
            s.q.scale_mut(c.rho);
        } else {
            s.q = SMatrix::<T, NX, HX>::zeros()
        }

        // Add cost contribution for input constraint violations
        let mut u_con_iter = u_con.iter_mut();
        if let Some(u_con_first) = u_con_iter.next() {
            profiling::scope!("update input cost");
            u_con_first.set_cost(&mut s.r);
            for u_con_next in u_con_iter {
                u_con_next.add_cost(&mut s.r);
            }
            s.r.scale_mut(c.rho);
        } else {
            s.r = SMatrix::<T, NU, HU>::zeros()
        }

        // Extract ADMM cost term for Riccati terminal condition
        s.p.set_column(HX - 1, &(s.q.column(HX - 1)));
    }

    /// Backward pass to update Ricatti variables
    ///
    /// Walks the horizon backwards, propagating the linear cost-to-go `p` and
    /// computing feedforward terms `d` inside the control horizon.
    #[inline(always)]
    #[profiling::function]
    fn backward_pass(&mut self) {
        let s = &mut self.state;
        let c = self.cache.get_active();

        for i in (0..HX - 1).rev() {
            let (mut p_now, mut p_fut) = util::column_pair_mut(&mut s.p, i, i + 1);
            // Past the control horizon the input cost column saturates at HU-1
            let mut r_col = s.r.column_mut(i.min(HU - 1));

            // Reused calculation: [[[i+1]]] <- (p[i+1] + Plqr * w[i])
            p_fut.axpy(T::one(), &s.cp.column(i), T::one());

            // Calc: p[i] = AmBKt * [[[i+1]]] - Klqr' * r[:,u_index] + q[i]
            p_now.gemv(T::one(), &c.AmBKt, &p_fut, T::zero());
            p_now.gemv_tr(T::one(), &c.nKlqr, &r_col, T::one());
            p_now.axpy(T::one(), &s.q.column(i), T::one());

            if i < HU {
                let mut d_col = s.d.column_mut(i);

                // Calc: d[i] = RpBPBi * (B' * [[[i+1]]] + r[i])
                r_col.gemv_tr(T::one(), &s.B, &p_fut, T::one());
                d_col.gemv(T::one(), &c.RpBPBi, &r_col, T::zero());
            }
        }
    }

    /// Use LQR feedback policy to roll out trajectory
    ///
    /// Applies `u[i] = -Klqr * ex[i] - d[i]` up to the control horizon and
    /// holds the last input afterwards; dynamics come either from the
    /// user-supplied `sys` function or from dense `A`/`B` products.
    #[inline(always)]
    #[profiling::function]
    fn forward_pass(&mut self) {
        let s = &mut self.state;
        let c = self.cache.get_active();

        if let Some(system) = s.sys {
            // Roll out trajectory up to the control horizon (HU)
            for i in 0..HU {
                let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
                let mut u_col = s.eu.column_mut(i);

                u_col.gemv(T::one(), &c.nKlqr, &ex_now, T::zero());
                u_col.axpy(-T::one(), &s.d.column(i), T::one());

                system(ex_fut.as_view_mut(), ex_now.as_view(), u_col.as_view());
                // Add the affine reference-mismatch term w[i]
                ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
            }

            // Roll out rest of trajectory keeping u constant
            for i in HU..HX - 1 {
                let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
                let u_col = s.eu.column(HU - 1);

                system(ex_fut.as_view_mut(), ex_now.as_view(), u_col.as_view());
                ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
            }
        } else {
            // Roll out trajectory up to the control horizon (HU)
            for i in 0..HU {
                let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
                let mut u_col = s.eu.column_mut(i);

                // Calc: u[i] = -Klqr * ex[i] + d[i]
                u_col.gemv(T::one(), &c.nKlqr, &ex_now, T::zero());
                u_col.axpy(-T::one(), &s.d.column(i), T::one());

                // Calc x[i+1] = A * x[i] + B * u[i] + w[i]
                ex_fut.gemv(T::one(), &s.A, &ex_now, T::zero());
                ex_fut.gemv(T::one(), &s.B, &u_col, T::one());
                ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
            }

            // Roll out rest of trajectory keeping u constant
            for i in HU..HX - 1 {
                let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
                let u_col = s.eu.column(HU - 1);

                // Calc x[i+1] = A * x[i] + B * u[i] + w[i]
                ex_fut.gemv(T::one(), &s.A, &ex_now, T::zero());
                ex_fut.gemv(T::one(), &s.B, &u_col, T::one());
                ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
            }
        }
    }

    /// Project slack variables into their feasible domain and update dual variables
    ///
    /// Optionally applies over-relaxation `x' = alpha*x + (1-alpha)*z` first,
    /// reusing `q`/`r` as buffers for the relaxed points and `p`/`d` as
    /// projection scratch space (all are rebuilt next iteration).
    #[inline(always)]
    #[profiling::function]
    fn update_constraints(
        &mut self,
        x_ref: Option<&SMatrix<T, NX, HX>>,
        u_ref: Option<&SMatrix<T, NU, HU>>,
        x_con: &mut [Constraint<T, impl Project<T, NX, HX>, NX, HX>],
        u_con: &mut [Constraint<T, impl Project<T, NU, HU>, NU, HU>],
    ) {
        let compute_residuals = self.should_compute_residuals();
        let s = &mut self.state;

        let (x_points, u_points) = if self.config.relaxation != T::one() {
            profiling::scope!("apply relaxation to state and input");

            // Use Riccati matrices to store relaxed x and u matrices.
            s.q.copy_from(&s.ex);
            s.r.copy_from(&s.eu);

            let alpha = self.config.relaxation;

            s.q.scale_mut(alpha);
            s.r.scale_mut(alpha);

            for con in x_con.as_mut() {
                for (mut prim, slac) in s.q.column_iter_mut().zip(con.slac.column_iter()) {
                    prim.axpy(T::one() - alpha, &slac, T::one());
                }
            }

            for con in u_con.as_mut() {
                for (mut prim, slac) in s.r.column_iter_mut().zip(con.slac.column_iter()) {
                    prim.axpy(T::one() - alpha, &slac, T::one());
                }
            }

            // Buffers now contain: x' = alpha * x + (1 - alpha) * z
            (&s.q, &s.r)
        } else {
            // Just use original predictions
            (&s.ex, &s.eu)
        };

        // Use cost matrices as scratch buffers
        let u_scratch = &mut s.d;
        let x_scratch = &mut s.p;

        for con in x_con {
            con.constrain(compute_residuals, x_points, x_ref, x_scratch);
        }

        for con in u_con {
            con.constrain(compute_residuals, u_points, u_ref, u_scratch);
        }
    }

    /// Check for termination condition by evaluating residuals
    ///
    /// Aggregates the worst primal/dual residual over all constraints and
    /// compares against the configured tolerances (the dual residual scaled
    /// by rho). When not terminating, gives the cache a chance to adapt rho;
    /// if it does, all duals are rescaled and `cp` is recomputed to match.
    #[inline(always)]
    #[profiling::function]
    fn check_termination(
        &mut self,
        max_prim_residual: &mut T,
        max_dual_residual: &mut T,
        x_con: &mut [Constraint<T, impl Project<T, NX, HX>, NX, HX>],
        u_con: &mut [Constraint<T, impl Project<T, NU, HU>, NU, HU>],
    ) -> bool {
        let c = self.cache.get_active();
        let cfg = &self.config;

        if !self.should_compute_residuals() {
            return false;
        }

        *max_prim_residual = T::zero();
        *max_dual_residual = T::zero();

        for con in x_con.iter() {
            *max_prim_residual = (*max_prim_residual).max(con.max_prim_residual);
            *max_dual_residual = (*max_dual_residual).max(con.max_dual_residual);
        }

        for con in u_con.iter() {
            *max_prim_residual = (*max_prim_residual).max(con.max_prim_residual);
            *max_dual_residual = (*max_dual_residual).max(con.max_dual_residual);
        }

        let terminate =
            *max_prim_residual < cfg.prim_tol && *max_dual_residual * c.rho < cfg.dual_tol;

        // Try to adapt rho
        if !terminate
            && let Some(scalar) = self
                .cache
                .update_active(*max_prim_residual, *max_dual_residual)
        {
            profiling::scope!("cache updated, rescale all dual variables");

            // Plqr changed with the new rho, so cp must be refreshed
            self.update_tracking_mismatch_plqr();

            for con in x_con.iter_mut() {
                con.rescale_dual(scalar)
            }

            for con in u_con.iter_mut() {
                con.rescale_dual(scalar)
            }
        }

        terminate
    }

    /// Get the number of iterations of the previous solve
    pub fn get_num_iters(&self) -> usize {
        self.state.iter
    }

    /// Get the system input `u` for the time `i`
    ///
    /// `i` must be less than `HU` (column indexing panics otherwise).
    /// NOTE(review): returns the error-coordinate input; when a `u_ref` was
    /// supplied, prefer [`Solution::u_now`], which adds the reference back.
    pub fn get_u_at(&self, i: usize) -> SVector<T, NU> {
        self.state.eu.column(i).into()
    }

    /// Get the system input `u` for the current time
    /// (error-coordinate; see `get_u_at`)
    pub fn get_u(&self) -> SVector<T, NU> {
        self.state.eu.column(0).into()
    }

    /// Get reference to matrix containing state predictions
    /// (error-coordinate; columns are time steps)
    pub fn get_x_matrix(&self) -> &SMatrix<T, NX, HX> {
        &self.state.ex
    }

    /// Get reference to matrix containing input predictions
    /// (error-coordinate; columns are time steps)
    pub fn get_u_matrix(&self) -> &SMatrix<T, NU, HU> {
        &self.state.eu
    }
}
613
/// Result of a solve: borrowed views of the error-coordinate predictions plus
/// termination diagnostics. Accessors add the references back when present.
pub struct Solution<'a, T, const NX: usize, const NU: usize, const HX: usize, const HU: usize> {
    // References used during the solve (needed to undo error coordinates)
    x_ref: Option<&'a SMatrix<T, NX, HX>>,
    u_ref: Option<&'a SMatrix<T, NU, HU>>,
    // Error-coordinate state/input predictions borrowed from solver state
    x: &'a SMatrix<T, NX, HX>,
    u: &'a SMatrix<T, NU, HU>,
    /// Why the solver terminated
    pub reason: TerminationReason,
    /// Number of ADMM iterations performed
    pub iterations: usize,
    /// Worst primal residual at the last convergence check
    pub prim_residual: T,
    /// Worst dual residual (scaled by rho) at the last convergence check
    pub dual_residual: T,
}
624
625impl<T: RealField + Copy, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
626    Solution<'_, T, NX, NU, HX, HU>
627{
628    /// Get the predictiction of states for this solution
629    pub fn x_prediction(&self) -> SMatrix<T, NX, HX> {
630        if let Some(x_ref) = self.x_ref.as_ref() {
631            self.x + *x_ref
632        } else {
633            self.x.clone_owned()
634        }
635    }
636
637    /// Get the predictiction of inputs for this solution
638    pub fn u_prediction(&self) -> SMatrix<T, NU, HU> {
639        if let Some(u_ref) = self.u_ref.as_ref() {
640            self.u + *u_ref
641        } else {
642            self.u.clone_owned()
643        }
644    }
645
646    /// Get the current contron input to be applied
647    pub fn u_now(&self) -> SVector<T, NU> {
648        if let Some(u_ref) = self.u_ref.as_ref() {
649            self.u.column(0) + u_ref.column(0)
650        } else {
651            self.u.column(0).into()
652        }
653    }
654}