// Source: aeon_tk/element/helpers.rs

1use faer::diag::generic::Diag;
2use faer::dyn_stack::{MemBuffer, MemStack, StackReq};
3use faer::linalg::matmul::matmul;
4use faer::linalg::svd::{
5    ComputeSvdVectors, SvdError, pseudoinverse_from_svd, pseudoinverse_from_svd_scratch, svd,
6    svd_scratch,
7};
8use faer::{Accum, Mat, MatMut, MatRef, Par};
9use reborrow::{Reborrow, ReborrowMut};
10
11/// A faer workspace
12pub struct Workspace {
13    req: StackReq,
14    buffer: MemBuffer,
15}
16
17impl Workspace {
18    pub fn empty() -> Self {
19        Self {
20            req: StackReq::empty(),
21            buffer: MemBuffer::new(StackReq::empty()),
22        }
23    }
24
25    pub fn stack(&mut self, req: StackReq) -> &mut MemStack {
26        if self.req.or(req) != self.req {
27            self.req = req;
28            self.buffer = MemBuffer::new(req);
29        }
30
31        MemStack::new(&mut self.buffer)
32    }
33}
34
35impl Clone for Workspace {
36    fn clone(&self) -> Self {
37        Self {
38            req: self.req,
39            buffer: MemBuffer::new(self.req),
40        }
41    }
42}
43
44#[derive(Clone)]
45pub struct LeastSquares {
46    /// Cache for workspace memory used when computing svd.
47    workspace: Workspace,
48    s: Vec<f64>,
49    u: Mat<f64>,
50    v: Mat<f64>,
51    pinv: Mat<f64>,
52}
53
54impl Default for LeastSquares {
55    fn default() -> Self {
56        Self {
57            workspace: Workspace::empty(),
58            s: Vec::default(),
59            u: Mat::zeros(0, 0),
60            v: Mat::zeros(0, 0),
61            pinv: Mat::zeros(0, 0),
62        }
63    }
64}
65
66impl LeastSquares {
67    fn compute_psuedo_inverse(&mut self, m: MatRef<f64>) -> Result<(), SvdError> {
68        // We can only compute the psuedo inverse for overdetermined systems.
69        assert!(m.nrows() >= m.ncols());
70
71        let compute = ComputeSvdVectors::Full;
72        let par = Par::Seq;
73
74        let nrows = m.nrows();
75        let ncols = m.ncols();
76
77        // Compute memory requirements
78        let svd_reqs = svd_scratch::<f64>(
79            nrows,
80            ncols,
81            compute,
82            compute,
83            par,
84            faer::prelude::default(),
85        );
86        let pinv_regs = pseudoinverse_from_svd_scratch::<f64>(nrows, ncols, par);
87        let stack = self
88            .workspace
89            .stack(StackReq::any_of(&[svd_reqs, pinv_regs]));
90
91        let size = nrows.min(ncols);
92        self.s.resize(size, 0.0);
93        self.u.resize_with(nrows, nrows, |_, _| 0.0);
94        self.v.resize_with(ncols, ncols, |_, _| 0.0);
95
96        svd(
97            m,
98            Diag::from_slice_mut(&mut self.s),
99            Some(self.u.rb_mut()),
100            Some(self.v.rb_mut()),
101            par,
102            stack,
103            faer::prelude::default(),
104        )?;
105
106        self.pinv.resize_with(ncols, nrows, |_, _| 0.0);
107        pseudoinverse_from_svd(
108            self.pinv.rb_mut(),
109            Diag::from_slice(&mut self.s),
110            self.u.rb(),
111            self.v.rb(),
112            par,
113            stack,
114        );
115
116        Ok(())
117    }
118
119    pub fn overdetermined(
120        &mut self,
121        m: MatRef<f64>,
122        a: MatMut<f64>,
123        b: MatRef<f64>,
124    ) -> Result<(), SvdError> {
125        assert!(a.nrows() == m.ncols() && a.ncols() == b.ncols() && b.nrows() == m.nrows());
126        assert!(m.nrows() >= m.ncols());
127
128        self.compute_psuedo_inverse(m)?;
129        matmul(a, Accum::Replace, self.pinv.rb(), b, 1.0, Par::Seq);
130
131        Ok(())
132    }
133
134    pub fn underdetermined(
135        &mut self,
136        m: MatRef<f64>,
137        a: MatMut<f64>,
138        b: MatRef<f64>,
139    ) -> Result<(), SvdError> {
140        assert!(a.nrows() == m.ncols() && a.ncols() == b.ncols() && b.nrows() == m.nrows());
141        assert!(m.nrows() <= m.ncols());
142
143        self.compute_psuedo_inverse(m.transpose())?;
144        matmul(a, Accum::Replace, self.pinv.transpose(), b, 1.0, Par::Seq);
145
146        Ok(())
147    }
148
149    pub fn least_squares(
150        &mut self,
151        m: MatRef<f64>,
152        a: MatMut<f64>,
153        b: MatRef<f64>,
154    ) -> Result<(), SvdError> {
155        if m.nrows() >= m.ncols() {
156            self.overdetermined(m, a, b)
157        } else {
158            self.underdetermined(m, a, b)
159        }
160    }
161}