//! tinympc_rs — `cache.rs`: pre-computed per-rho solver caches.

1use nalgebra::{RealField, SMatrix, Scalar, convert};
2
/// Errors that can occur during cache setup.
///
/// All variants are unit-like, so the enum derives `Eq` in addition to
/// `PartialEq` (clippy: `derive_partial_eq_without_eq`).
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Error {
    /// The value of rho must be strictly positive `(rho > 0)`
    RhoNotPositive,
    /// The matrix `R_aug + B^T * P * B` is not invertible
    RpBPBNotInvertible,
    /// The resulting matrices contained non-finite elements (Inf or NaN)
    NonFiniteValues,
}
13
/// Common interface over a single pre-computed cache and an adaptive set of them.
pub trait Cache<T, const NX: usize, const NU: usize> {
    /// Updates which cache is active by evaluating the primal and dual residuals.
    ///
    /// Returns `Some(old_rho / new_rho)` — the factor the caller must apply to
    /// its constraint duals — when the active cache (and thus rho) changed,
    /// `None` when it stayed the same.
    fn update_active(&mut self, prim_residual: T, dual_residual: T) -> Option<T>;

    /// Get a reference to the currently active cache.
    fn get_active(&self) -> &SingleCache<T, NX, NU>;
}
23
/// Contains all pre-computed values for a given problem and value of rho
#[derive(Debug)]
pub struct SingleCache<T, const NX: usize, const NU: usize> {
    /// Penalty-parameter for this cache; validated strictly positive in `new`.
    pub(crate) rho: T,

    /// (Negated) Infinite-time horizon LQR gain, i.e. `-Klqr`
    pub(crate) nKlqr: SMatrix<T, NU, NX>,

    /// Infinite-time horizon LQR cost-to-go
    pub(crate) Plqr: SMatrix<T, NX, NX>,

    /// Precomputed `inv(R_aug + B^T * Plqr * B)`
    pub(crate) RpBPBi: SMatrix<T, NU, NU>,

    /// Precomputed `(A - B * Klqr)^T`
    pub(crate) AmBKt: SMatrix<T, NX, NX>,
}
42
43impl<T, const NX: usize, const NU: usize> SingleCache<T, NX, NU>
44where
45    T: Scalar + RealField + Copy,
46{
47    pub fn new(
48        rho: T,
49        iters: usize,
50        A: &SMatrix<T, NX, NX>,
51        B: &SMatrix<T, NX, NU>,
52        Q: &SMatrix<T, NX, NX>,
53        R: &SMatrix<T, NU, NU>,
54        S: &SMatrix<T, NX, NU>,
55    ) -> Result<Self, Error> {
56        if !rho.is_positive() {
57            return Err(Error::RhoNotPositive);
58        }
59
60        let Q = Q.symmetric_part();
61        let R = R.symmetric_part();
62
63        // ADMM-augmented cost matrices for LQR problem
64        let Q_aug = Q + SMatrix::from_diagonal_element(rho);
65        let R_aug = R + SMatrix::from_diagonal_element(rho);
66
67        let mut Klqr = SMatrix::zeros();
68        let mut Plqr = Q_aug.clone_owned();
69
70        for _ in 0..iters {
71            Klqr = (R_aug + B.transpose() * Plqr * B)
72                .try_inverse()
73                .ok_or(Error::RpBPBNotInvertible)?
74                * (S.transpose() + B.transpose() * Plqr * A);
75            Plqr = A.transpose() * Plqr * A - A.transpose() * Plqr * B * Klqr + Q_aug;
76        }
77
78        let RpBPBi = (R_aug + B.transpose() * Plqr * B)
79            .try_inverse()
80            .ok_or(Error::RpBPBNotInvertible)?;
81        let AmBKt = (A - B * Klqr).transpose();
82        let nKlqr = -Klqr;
83
84        // If RpBPBi and AmBKt are finite, so are all the other values
85        ([].iter())
86            .chain(RpBPBi.iter())
87            .chain(AmBKt.iter())
88            .all(|x| x.is_finite())
89            .then_some(SingleCache {
90                rho,
91                nKlqr,
92                Plqr,
93                RpBPBi,
94                AmBKt,
95            })
96            .ok_or(Error::NonFiniteValues)
97    }
98}
99
100impl<T, const NX: usize, const NU: usize> Cache<T, NX, NU> for SingleCache<T, NX, NU>
101where
102    T: Scalar + RealField + Copy,
103{
104    fn update_active(&mut self, _prim_residual: T, _dual_residual: T) -> Option<T> {
105        None
106    }
107
108    fn get_active(&self) -> &SingleCache<T, NX, NU> {
109        self
110    }
111}
112
/// Contains an array of pre-computed values for a given problem and value of rho
#[derive(Debug)]
pub struct ArrayCache<T, const NX: usize, const NU: usize, const NUM: usize> {
    /// Ratio by which one residual must exceed the other before the active
    /// cache is switched (see `update_active`).
    threshold: T,
    /// Index into `caches` of the currently active entry.
    active_index: usize,
    /// Caches for a geometric ladder of rho values (`central_rho * factor^k`).
    caches: [SingleCache<T, NX, NU>; NUM],
}
120
121impl<T, const NX: usize, const NU: usize, const NUM: usize> ArrayCache<T, NX, NU, NUM>
122where
123    T: Scalar + RealField + Copy,
124{
125    pub fn new(
126        central_rho: T,
127        threshold: T,
128        factor: T,
129        iters: usize,
130        A: &SMatrix<T, NX, NX>,
131        B: &SMatrix<T, NX, NU>,
132        Q: &SMatrix<T, NX, NX>,
133        R: &SMatrix<T, NU, NU>,
134        S: &SMatrix<T, NX, NU>,
135    ) -> Result<Self, Error> {
136        let active_index = NUM / 2;
137        let caches = crate::util::try_array_from_fn(|index| {
138            let diff = index as i32 - active_index as i32;
139            let mult = factor.powf(convert(diff as f64));
140            let rho = central_rho * mult;
141            SingleCache::new(rho, iters, A, B, Q, R, S)
142        })?;
143
144        Ok(Self {
145            threshold,
146            active_index,
147            caches,
148        })
149    }
150}
151
152impl<T, const NX: usize, const NU: usize, const NUM: usize> Cache<T, NX, NU>
153    for ArrayCache<T, NX, NU, NUM>
154where
155    T: Scalar + RealField + Copy,
156{
157    #[inline(always)]
158    fn update_active(&mut self, prim_residual: T, dual_residual: T) -> Option<T> {
159        let mut cache = &self.caches[self.active_index];
160        let prev_rho = cache.rho;
161
162        // TODO: It seems to work better without this?
163        // Since we are using a scaled dual formulation
164        let dual_residual = dual_residual * prev_rho;
165
166        // For much larger primal residuals, increase rho
167        if prim_residual > dual_residual * self.threshold {
168            if self.active_index < NUM - 1 {
169                self.active_index += 1;
170                cache = &self.caches[self.active_index];
171            }
172        }
173        // For much larger dual residuals, decrease rho
174        else if dual_residual > prim_residual * self.threshold {
175            if self.active_index > 0 {
176                self.active_index -= 1;
177                cache = &self.caches[self.active_index];
178            }
179        }
180
181        // If the value of rho changed we must also rescale all duals
182        (prev_rho != cache.rho).then(|| prev_rho / cache.rho)
183    }
184
185    #[inline(always)]
186    fn get_active(&self) -> &SingleCache<T, NX, NU> {
187        &self.caches[self.active_index]
188    }
189}