1#![cfg_attr(not(feature = "std"), no_std)]
2#![allow(non_snake_case)]
3
4use nalgebra::{RealField, SMatrix, SVector, SVectorView, SVectorViewMut, Scalar, convert};
5
6pub mod cache;
7pub mod constraint;
8pub mod project;
9
10pub use cache::{Cache, Error};
11pub use constraint::Constraint;
12pub use project::*;
13
14mod util;
15
/// Function pointer that can replace the dense `A`/`B` matrix products in the forward pass.
///
/// The solver calls it as `f(x_next, x, u)`: the first argument is a mutable view of the
/// next-state column, the second and third are views of the current state and input columns.
/// After the call the solver adds its own affine reference term to `x_next`.
// NOTE(review): presumably the function must write `A * x + B * u` into the first argument
// (that is what the dense code path computes) — confirm against the crate documentation.
pub type LtiFn<T, const NX: usize, const NU: usize> =
    fn(SVectorViewMut<T, NX>, SVectorView<T, NX>, SVectorView<T, NU>);
18
/// Why a call to [`TinyMpc::solve`] stopped iterating.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum TerminationReason {
    /// The termination check passed: both residuals were below their tolerances.
    Converged,
    /// The iteration limit (`Config::max_iter`) was reached without passing the check.
    MaxIters,
}
26
/// Model-predictive-control solver.
///
/// Generic over the scalar type `T`, the cache type `CACHE`, the state/input dimensions
/// `NX`/`NU`, and the number of state/input columns in the prediction `HX`/`HU`.
#[derive(Debug)]
pub struct TinyMpc<
    T,
    CACHE: Cache<T, NX, NU>,
    const NX: usize,
    const NU: usize,
    const HX: usize,
    const HU: usize,
> {
    /// Source of the active precomputed terms (`rho`, `Plqr`, `AmBKt`, `nKlqr`, `RpBPBi`).
    cache: CACHE,
    /// Working memory that persists between solves.
    state: State<T, NX, NU, HX, HU>,
    /// Solver settings; public so they can be adjusted between solves.
    pub config: Config<T>,
}
40
/// Solver settings. [`TinyMpc::new`] initialises them to
/// `prim_tol = 1e-2`, `dual_tol = 1e-2`, `max_iter = 50`, `do_check = 5`, `relaxation = 1`.
#[derive(Debug)]
pub struct Config<T> {
    /// Termination requires the largest primal residual over all constraints to be below this.
    pub prim_tol: T,

    /// Termination requires the largest dual residual, multiplied by the active `rho`,
    /// to be below this.
    pub dual_tol: T,

    /// Upper bound on the number of solver iterations per solve.
    pub max_iter: usize,

    /// Residuals are computed, and termination is checked, only on iterations whose index
    /// is a multiple of this value. Should be non-zero.
    pub do_check: usize,

    /// Relaxation factor. When it differs from one, the predicted states/inputs are scaled by
    /// it and blended with `(1 - relaxation)` times the constraint slack variables before the
    /// constraints are updated.
    pub relaxation: T,
}
58
/// Working memory of the solver. All matrices store one column per prediction step.
#[derive(Debug)]
pub struct State<T, const NX: usize, const NU: usize, const HX: usize, const HU: usize> {
    /// State transition matrix used by the dense forward pass and the reference term.
    A: SMatrix<T, NX, NX>,
    /// Input matrix used by the dense forward pass, the backward pass and the reference term.
    B: SMatrix<T, NX, NU>,

    /// Optional dynamics function; when set it replaces the `A`/`B` products in the forward pass.
    sys: Option<LtiFn<T, NX, NU>>,

    /// Predicted states, relative to the state reference when one is given.
    ex: SMatrix<T, NX, HX>,
    /// Predicted inputs, relative to the input reference when one is given.
    eu: SMatrix<T, NU, HU>,

    /// Affine term added to every propagated state; built from the references.
    cx: SMatrix<T, NX, HX>,
    /// `Plqr * cx`, recomputed whenever `cx` or the active cache changes.
    cp: SMatrix<T, NX, HX>,

    /// Linear state cost assembled from the state constraints; also reused as the buffer for
    /// relaxed state points.
    q: SMatrix<T, NX, HX>,
    /// Linear input cost assembled from the input constraints; also reused as the buffer for
    /// relaxed input points.
    r: SMatrix<T, NU, HU>,

    /// Backward-pass state term; also handed to state constraints as scratch space.
    p: SMatrix<T, NX, HX>,
    /// Backward-pass input term; also handed to input constraints as scratch space.
    d: SMatrix<T, NU, HU>,

    /// Iteration counter of the current/most recent solve.
    iter: usize,
}
87
/// One solve being configured: the initial state plus optional references and constraints.
///
/// Created by [`TinyMpc::initial_condition`]. The builder methods attach references and
/// constraints, and `solve` forwards everything to [`TinyMpc::solve`].
pub struct Problem<
    'a,
    T,
    C,
    const NX: usize,
    const NU: usize,
    const HX: usize,
    const HU: usize,
    XProj = (),
    UProj = (),
> where
    T: Scalar + RealField + Copy,
    C: Cache<T, NX, NU>,
    XProj: Project<T, NX, HX>,
    UProj: Project<T, NU, HU>,
{
    /// Solver that will run the problem; borrowed mutably for the lifetime of the solution.
    mpc: &'a mut TinyMpc<T, C, NX, NU, HX, HU>,
    /// Current state, used as the first column of the state prediction.
    x_now: SVector<T, NX>,
    /// Optional state reference, one column per state prediction step.
    x_ref: Option<&'a SMatrix<T, NX, HX>>,
    /// Optional input reference, one column per input prediction step.
    u_ref: Option<&'a SMatrix<T, NU, HU>>,
    /// Optional state constraints; `None` is treated as an empty slice by the solver.
    x_con: Option<&'a mut [Constraint<T, XProj, NX, HX>]>,
    /// Optional input constraints; `None` is treated as an empty slice by the solver.
    u_con: Option<&'a mut [Constraint<T, UProj, NU, HU>]>,
}
111
112impl<'a, T, C, XProj, UProj, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
113 Problem<'a, T, C, NX, NU, HX, HU, XProj, UProj>
114where
115 T: Scalar + RealField + Copy,
116 C: Cache<T, NX, NU>,
117 XProj: Project<T, NX, HX>,
118 UProj: Project<T, NU, HU>,
119{
120 pub fn x_reference(mut self, x_ref: &'a SMatrix<T, NX, HX>) -> Self {
122 self.x_ref = Some(x_ref);
123 self
124 }
125
126 pub fn u_reference(mut self, u_ref: &'a SMatrix<T, NU, HU>) -> Self {
128 self.u_ref = Some(u_ref);
129 self
130 }
131
132 pub fn x_constraints<Proj: Project<T, NX, HX>>(
134 self,
135 x_con: &'a mut [Constraint<T, Proj, NX, HX>],
136 ) -> Problem<'a, T, C, NX, NU, HX, HU, Proj, UProj> {
137 Problem {
138 mpc: self.mpc,
139 x_now: self.x_now,
140 x_ref: self.x_ref,
141 u_ref: self.u_ref,
142 x_con: Some(x_con),
143 u_con: self.u_con,
144 }
145 }
146
147 pub fn u_constraints<Proj: Project<T, NU, HU>>(
149 self,
150 u_con: &'a mut [Constraint<T, Proj, NU, HU>],
151 ) -> Problem<'a, T, C, NX, NU, HX, HU, XProj, Proj> {
152 Problem {
153 mpc: self.mpc,
154 x_now: self.x_now,
155 x_ref: self.x_ref,
156 u_ref: self.u_ref,
157 x_con: self.x_con,
158 u_con: Some(u_con),
159 }
160 }
161
162 pub fn solve(self) -> Solution<'a, T, NX, NU, HX, HU> {
164 self.mpc
165 .solve(self.x_now, self.x_ref, self.u_ref, self.x_con, self.u_con)
166 }
167}
168
169impl<T, C: Cache<T, NX, NU>, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
170 TinyMpc<T, C, NX, NU, HX, HU>
171where
172 T: Scalar + RealField + Copy,
173{
174 #[must_use]
175 #[inline(always)]
176 pub fn new(A: SMatrix<T, NX, NX>, B: SMatrix<T, NX, NU>, cache: C) -> Self {
177 const {
179 assert!(HX > HU, "`HX` must be larger than `HU`");
180 assert!(HU > 0, "`HU` must be non-zero");
181 }
182
183 Self {
184 config: Config {
185 prim_tol: convert(1e-2),
186 dual_tol: convert(1e-2),
187 max_iter: 50,
188 do_check: 5,
189 relaxation: T::one(),
190 },
191 cache,
192 state: State {
193 A,
194 B,
195 sys: None,
196 cx: SMatrix::zeros(),
197 cp: SMatrix::zeros(),
198 q: SMatrix::zeros(),
199 r: SMatrix::zeros(),
200 p: SMatrix::zeros(),
201 d: SMatrix::zeros(),
202 ex: SMatrix::zeros(),
203 eu: SMatrix::zeros(),
204 iter: 0,
205 },
206 }
207 }
208
209 pub fn with_sys(mut self, sys: LtiFn<T, NX, NU>) -> Self {
210 self.state.sys = Some(sys);
211 self
212 }
213
214 #[inline(always)]
215 pub fn initial_condition(
216 &mut self,
217 x_now: SVector<T, NX>,
218 ) -> Problem<'_, T, C, NX, NU, HX, HU> {
219 Problem {
220 mpc: self,
221 x_now,
222 x_ref: None,
223 u_ref: None,
224 x_con: None,
225 u_con: None,
226 }
227 }
228
229 #[inline(always)]
230 pub fn solve<'a>(
231 &'a mut self,
232 x_now: SVector<T, NX>,
233 x_ref: Option<&'a SMatrix<T, NX, HX>>,
234 u_ref: Option<&'a SMatrix<T, NU, HU>>,
235 x_con: Option<&mut [Constraint<T, impl Project<T, NX, HX>, NX, HX>]>,
236 u_con: Option<&mut [Constraint<T, impl Project<T, NU, HU>, NU, HU>]>,
237 ) -> Solution<'a, T, NX, NU, HX, HU> {
238 let mut reason = TerminationReason::MaxIters;
239
240 let x_con = x_con.unwrap_or(&mut [][..]);
242 let u_con = u_con.unwrap_or(&mut [][..]);
243
244 self.set_initial_conditions(x_now, x_ref, u_ref);
246 self.warm_start_constraints(x_con, u_con);
247
248 let mut prim_residual = T::zero();
249 let mut dual_residual = T::zero();
250
251 self.state.iter = 0;
252 while self.state.iter < self.config.max_iter {
253 profiling::scope!("solve loop", format!("iter: {}", self.state.iter));
254
255 self.update_cost(x_con, u_con);
256
257 self.backward_pass();
258
259 self.forward_pass();
260
261 self.update_constraints(x_ref, u_ref, x_con, u_con);
262
263 if self.check_termination(&mut prim_residual, &mut dual_residual, x_con, u_con) {
264 reason = TerminationReason::Converged;
265 self.state.iter += 1;
266 break;
267 }
268
269 self.state.iter += 1;
270 }
271
272 Solution {
273 x_ref,
274 u_ref,
275 x: &self.state.ex,
276 u: &self.state.eu,
277 reason,
278 iterations: self.state.iter,
279 prim_residual,
280 dual_residual: dual_residual * self.cache.get_active().rho,
281 }
282 }
283
284 #[inline(always)]
285 fn should_compute_residuals(&self) -> bool {
286 self.state.iter % self.config.do_check == 0
287 }
288
289 #[inline(always)]
290 #[profiling::function]
291 fn set_initial_conditions(
292 &mut self,
293 x_now: SVector<T, NX>,
294 x_ref: Option<&SMatrix<T, NX, HX>>,
295 u_ref: Option<&SMatrix<T, NU, HU>>,
296 ) {
297 if let Some(x_ref) = x_ref {
298 profiling::scope!("affine state reference term");
299 x_now.sub_to(&x_ref.column(0), &mut self.state.ex.column_mut(0));
300 self.state.A.mul_to(&x_ref, &mut self.state.cx);
301 for i in 0..HX - 1 {
302 let mut cx_col = self.state.cx.column_mut(i);
303 cx_col.axpy(-T::one(), &x_ref.column(i + 1), T::one());
304 }
305 } else {
306 self.state.ex.set_column(0, &x_now);
307 }
308
309 if let Some(u_ref) = u_ref {
310 profiling::scope!("affine input reference term");
311 for i in 0..HX - 1 {
312 let mut cx_col = self.state.cx.column_mut(i);
313 let u_ref_col = u_ref.column(i.min(HU - 1));
314 cx_col.gemv(-T::one(), &self.state.B, &u_ref_col, T::one());
315 }
316 }
317
318 self.update_tracking_mismatch_plqr();
319 }
320
321 #[inline(always)]
322 fn update_tracking_mismatch_plqr(&mut self) {
323 let cache = self.cache.get_active();
327 cache.Plqr.mul_to(&self.state.cx, &mut self.state.cp);
328 }
329
330 #[inline(always)]
332 #[profiling::function]
333 fn warm_start_constraints(
334 &mut self,
335 x_con: &mut [Constraint<T, impl Project<T, NX, HX>, NX, HX>],
336 u_con: &mut [Constraint<T, impl Project<T, NU, HU>, NU, HU>],
337 ) {
338 for con in x_con {
339 util::shift_columns_left(&mut con.dual);
340 util::shift_columns_left(&mut con.slac);
341 }
342
343 for con in u_con {
344 util::shift_columns_left(&mut con.dual);
345 util::shift_columns_left(&mut con.slac);
346 }
347 }
348
349 #[inline(always)]
351 #[profiling::function]
352 fn update_cost(
353 &mut self,
354 x_con: &mut [Constraint<T, impl Project<T, NX, HX>, NX, HX>],
355 u_con: &mut [Constraint<T, impl Project<T, NU, HU>, NU, HU>],
356 ) {
357 let s = &mut self.state;
358 let c = self.cache.get_active();
359
360 let mut x_con_iter = x_con.iter_mut();
362 if let Some(x_con_first) = x_con_iter.next() {
363 profiling::scope!("update state cost");
364 x_con_first.set_cost(&mut s.q);
365 for x_con_next in x_con_iter {
366 x_con_next.add_cost(&mut s.q);
367 }
368 s.q.scale_mut(c.rho);
369 } else {
370 s.q = SMatrix::<T, NX, HX>::zeros()
371 }
372
373 let mut u_con_iter = u_con.iter_mut();
375 if let Some(u_con_first) = u_con_iter.next() {
376 profiling::scope!("update input cost");
377 u_con_first.set_cost(&mut s.r);
378 for u_con_next in u_con_iter {
379 u_con_next.add_cost(&mut s.r);
380 }
381 s.r.scale_mut(c.rho);
382 } else {
383 s.r = SMatrix::<T, NU, HU>::zeros()
384 }
385
386 s.p.set_column(HX - 1, &(s.q.column(HX - 1)));
388 }
389
390 #[inline(always)]
392 #[profiling::function]
393 fn backward_pass(&mut self) {
394 let s = &mut self.state;
395 let c = self.cache.get_active();
396
397 for i in (0..HX - 1).rev() {
398 let (mut p_now, mut p_fut) = util::column_pair_mut(&mut s.p, i, i + 1);
399 let mut r_col = s.r.column_mut(i.min(HU - 1));
400
401 p_fut.axpy(T::one(), &s.cp.column(i), T::one());
403
404 p_now.gemv(T::one(), &c.AmBKt, &p_fut, T::zero());
406 p_now.gemv_tr(T::one(), &c.nKlqr, &r_col, T::one());
407 p_now.axpy(T::one(), &s.q.column(i), T::one());
408
409 if i < HU {
410 let mut d_col = s.d.column_mut(i);
411
412 r_col.gemv_tr(T::one(), &s.B, &p_fut, T::one());
414 d_col.gemv(T::one(), &c.RpBPBi, &r_col, T::zero());
415 }
416 }
417 }
418
419 #[inline(always)]
421 #[profiling::function]
422 fn forward_pass(&mut self) {
423 let s = &mut self.state;
424 let c = self.cache.get_active();
425
426 if let Some(system) = s.sys {
427 for i in 0..HU {
429 let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
430 let mut u_col = s.eu.column_mut(i);
431
432 u_col.gemv(T::one(), &c.nKlqr, &ex_now, T::zero());
433 u_col.axpy(-T::one(), &s.d.column(i), T::one());
434
435 system(ex_fut.as_view_mut(), ex_now.as_view(), u_col.as_view());
436 ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
437 }
438
439 for i in HU..HX - 1 {
441 let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
442 let u_col = s.eu.column(HU - 1);
443
444 system(ex_fut.as_view_mut(), ex_now.as_view(), u_col.as_view());
445 ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
446 }
447 } else {
448 for i in 0..HU {
450 let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
451 let mut u_col = s.eu.column_mut(i);
452
453 u_col.gemv(T::one(), &c.nKlqr, &ex_now, T::zero());
455 u_col.axpy(-T::one(), &s.d.column(i), T::one());
456
457 ex_fut.gemv(T::one(), &s.A, &ex_now, T::zero());
459 ex_fut.gemv(T::one(), &s.B, &u_col, T::one());
460 ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
461 }
462
463 for i in HU..HX - 1 {
465 let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
466 let u_col = s.eu.column(HU - 1);
467
468 ex_fut.gemv(T::one(), &s.A, &ex_now, T::zero());
470 ex_fut.gemv(T::one(), &s.B, &u_col, T::one());
471 ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
472 }
473 }
474 }
475
476 #[inline(always)]
478 #[profiling::function]
479 fn update_constraints(
480 &mut self,
481 x_ref: Option<&SMatrix<T, NX, HX>>,
482 u_ref: Option<&SMatrix<T, NU, HU>>,
483 x_con: &mut [Constraint<T, impl Project<T, NX, HX>, NX, HX>],
484 u_con: &mut [Constraint<T, impl Project<T, NU, HU>, NU, HU>],
485 ) {
486 let compute_residuals = self.should_compute_residuals();
487 let s = &mut self.state;
488
489 let (x_points, u_points) = if self.config.relaxation != T::one() {
490 profiling::scope!("apply relaxation to state and input");
491
492 s.q.copy_from(&s.ex);
494 s.r.copy_from(&s.eu);
495
496 let alpha = self.config.relaxation;
497
498 s.q.scale_mut(alpha);
499 s.r.scale_mut(alpha);
500
501 for con in x_con.as_mut() {
502 for (mut prim, slac) in s.q.column_iter_mut().zip(con.slac.column_iter()) {
503 prim.axpy(T::one() - alpha, &slac, T::one());
504 }
505 }
506
507 for con in u_con.as_mut() {
508 for (mut prim, slac) in s.r.column_iter_mut().zip(con.slac.column_iter()) {
509 prim.axpy(T::one() - alpha, &slac, T::one());
510 }
511 }
512
513 (&s.q, &s.r)
515 } else {
516 (&s.ex, &s.eu)
518 };
519
520 let u_scratch = &mut s.d;
522 let x_scratch = &mut s.p;
523
524 for con in x_con {
525 con.constrain(compute_residuals, x_points, x_ref, x_scratch);
526 }
527
528 for con in u_con {
529 con.constrain(compute_residuals, u_points, u_ref, u_scratch);
530 }
531 }
532
533 #[inline(always)]
535 #[profiling::function]
536 fn check_termination(
537 &mut self,
538 max_prim_residual: &mut T,
539 max_dual_residual: &mut T,
540 x_con: &mut [Constraint<T, impl Project<T, NX, HX>, NX, HX>],
541 u_con: &mut [Constraint<T, impl Project<T, NU, HU>, NU, HU>],
542 ) -> bool {
543 let c = self.cache.get_active();
544 let cfg = &self.config;
545
546 if !self.should_compute_residuals() {
547 return false;
548 }
549
550 *max_prim_residual = T::zero();
551 *max_dual_residual = T::zero();
552
553 for con in x_con.iter() {
554 *max_prim_residual = (*max_prim_residual).max(con.max_prim_residual);
555 *max_dual_residual = (*max_dual_residual).max(con.max_dual_residual);
556 }
557
558 for con in u_con.iter() {
559 *max_prim_residual = (*max_prim_residual).max(con.max_prim_residual);
560 *max_dual_residual = (*max_dual_residual).max(con.max_dual_residual);
561 }
562
563 let terminate =
564 *max_prim_residual < cfg.prim_tol && *max_dual_residual * c.rho < cfg.dual_tol;
565
566 if !terminate
568 && let Some(scalar) = self
569 .cache
570 .update_active(*max_prim_residual, *max_dual_residual)
571 {
572 profiling::scope!("cache updated, rescale all dual variables");
573
574 self.update_tracking_mismatch_plqr();
575
576 for con in x_con.iter_mut() {
577 con.rescale_dual(scalar)
578 }
579
580 for con in u_con.iter_mut() {
581 con.rescale_dual(scalar)
582 }
583 }
584
585 terminate
586 }
587
588 pub fn get_num_iters(&self) -> usize {
590 self.state.iter
591 }
592
593 pub fn get_u_at(&self, i: usize) -> SVector<T, NU> {
595 self.state.eu.column(i).into()
596 }
597
598 pub fn get_u(&self) -> SVector<T, NU> {
600 self.state.eu.column(0).into()
601 }
602
603 pub fn get_x_matrix(&self) -> &SMatrix<T, NX, HX> {
605 &self.state.ex
606 }
607
608 pub fn get_u_matrix(&self) -> &SMatrix<T, NU, HU> {
610 &self.state.eu
611 }
612}
613
/// Result of a solve. Borrows the solver's internal predictions and the references that
/// were passed in, so the solver cannot be used again while the solution is alive.
pub struct Solution<'a, T, const NX: usize, const NU: usize, const HX: usize, const HU: usize> {
    /// State reference used for the solve, if any.
    x_ref: Option<&'a SMatrix<T, NX, HX>>,
    /// Input reference used for the solve, if any.
    u_ref: Option<&'a SMatrix<T, NU, HU>>,
    /// Solver's state prediction, relative to `x_ref` when present.
    x: &'a SMatrix<T, NX, HX>,
    /// Solver's input prediction, relative to `u_ref` when present.
    u: &'a SMatrix<T, NU, HU>,
    /// Whether the solver converged or hit the iteration limit.
    pub reason: TerminationReason,
    /// Number of iterations that were run.
    pub iterations: usize,
    /// Largest primal residual from the most recent termination check.
    pub prim_residual: T,
    /// Largest dual residual from the most recent termination check, scaled by the active `rho`.
    pub dual_residual: T,
}
624
625impl<T: RealField + Copy, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
626 Solution<'_, T, NX, NU, HX, HU>
627{
628 pub fn x_prediction(&self) -> SMatrix<T, NX, HX> {
630 if let Some(x_ref) = self.x_ref.as_ref() {
631 self.x + *x_ref
632 } else {
633 self.x.clone_owned()
634 }
635 }
636
637 pub fn u_prediction(&self) -> SMatrix<T, NU, HU> {
639 if let Some(u_ref) = self.u_ref.as_ref() {
640 self.u + *u_ref
641 } else {
642 self.u.clone_owned()
643 }
644 }
645
646 pub fn u_now(&self) -> SVector<T, NU> {
648 if let Some(u_ref) = self.u_ref.as_ref() {
649 self.u.column(0) + u_ref.column(0)
650 } else {
651 self.u.column(0).into()
652 }
653 }
654}