1use faer::diag::generic::Diag;
2use faer::dyn_stack::{MemBuffer, MemStack, StackReq};
3use faer::linalg::matmul::matmul;
4use faer::linalg::svd::{
5 ComputeSvdVectors, SvdError, pseudoinverse_from_svd, pseudoinverse_from_svd_scratch, svd,
6 svd_scratch,
7};
8use faer::{Accum, Mat, MatMut, MatRef, Par};
9use reborrow::{Reborrow, ReborrowMut};
10
11pub struct Workspace {
13 req: StackReq,
14 buffer: MemBuffer,
15}
16
17impl Workspace {
18 pub fn empty() -> Self {
19 Self {
20 req: StackReq::empty(),
21 buffer: MemBuffer::new(StackReq::empty()),
22 }
23 }
24
25 pub fn stack(&mut self, req: StackReq) -> &mut MemStack {
26 if self.req.or(req) != self.req {
27 self.req = req;
28 self.buffer = MemBuffer::new(req);
29 }
30
31 MemStack::new(&mut self.buffer)
32 }
33}
34
35impl Clone for Workspace {
36 fn clone(&self) -> Self {
37 Self {
38 req: self.req,
39 buffer: MemBuffer::new(self.req),
40 }
41 }
42}
43
44#[derive(Clone)]
45pub struct LeastSquares {
46 workspace: Workspace,
48 s: Vec<f64>,
49 u: Mat<f64>,
50 v: Mat<f64>,
51 pinv: Mat<f64>,
52}
53
54impl Default for LeastSquares {
55 fn default() -> Self {
56 Self {
57 workspace: Workspace::empty(),
58 s: Vec::default(),
59 u: Mat::zeros(0, 0),
60 v: Mat::zeros(0, 0),
61 pinv: Mat::zeros(0, 0),
62 }
63 }
64}
65
66impl LeastSquares {
67 fn compute_psuedo_inverse(&mut self, m: MatRef<f64>) -> Result<(), SvdError> {
68 assert!(m.nrows() >= m.ncols());
70
71 let compute = ComputeSvdVectors::Full;
72 let par = Par::Seq;
73
74 let nrows = m.nrows();
75 let ncols = m.ncols();
76
77 let svd_reqs = svd_scratch::<f64>(
79 nrows,
80 ncols,
81 compute,
82 compute,
83 par,
84 faer::prelude::default(),
85 );
86 let pinv_regs = pseudoinverse_from_svd_scratch::<f64>(nrows, ncols, par);
87 let stack = self
88 .workspace
89 .stack(StackReq::any_of(&[svd_reqs, pinv_regs]));
90
91 let size = nrows.min(ncols);
92 self.s.resize(size, 0.0);
93 self.u.resize_with(nrows, nrows, |_, _| 0.0);
94 self.v.resize_with(ncols, ncols, |_, _| 0.0);
95
96 svd(
97 m,
98 Diag::from_slice_mut(&mut self.s),
99 Some(self.u.rb_mut()),
100 Some(self.v.rb_mut()),
101 par,
102 stack,
103 faer::prelude::default(),
104 )?;
105
106 self.pinv.resize_with(ncols, nrows, |_, _| 0.0);
107 pseudoinverse_from_svd(
108 self.pinv.rb_mut(),
109 Diag::from_slice(&mut self.s),
110 self.u.rb(),
111 self.v.rb(),
112 par,
113 stack,
114 );
115
116 Ok(())
117 }
118
119 pub fn overdetermined(
120 &mut self,
121 m: MatRef<f64>,
122 a: MatMut<f64>,
123 b: MatRef<f64>,
124 ) -> Result<(), SvdError> {
125 assert!(a.nrows() == m.ncols() && a.ncols() == b.ncols() && b.nrows() == m.nrows());
126 assert!(m.nrows() >= m.ncols());
127
128 self.compute_psuedo_inverse(m)?;
129 matmul(a, Accum::Replace, self.pinv.rb(), b, 1.0, Par::Seq);
130
131 Ok(())
132 }
133
134 pub fn underdetermined(
135 &mut self,
136 m: MatRef<f64>,
137 a: MatMut<f64>,
138 b: MatRef<f64>,
139 ) -> Result<(), SvdError> {
140 assert!(a.nrows() == m.ncols() && a.ncols() == b.ncols() && b.nrows() == m.nrows());
141 assert!(m.nrows() <= m.ncols());
142
143 self.compute_psuedo_inverse(m.transpose())?;
144 matmul(a, Accum::Replace, self.pinv.transpose(), b, 1.0, Par::Seq);
145
146 Ok(())
147 }
148
149 pub fn least_squares(
150 &mut self,
151 m: MatRef<f64>,
152 a: MatMut<f64>,
153 b: MatRef<f64>,
154 ) -> Result<(), SvdError> {
155 if m.nrows() >= m.ncols() {
156 self.overdetermined(m, a, b)
157 } else {
158 self.underdetermined(m, a, b)
159 }
160 }
161}