Skip to main content

anofox_ml_preprocessing/
pls.rs

//! Partial Least Squares Regression (PLS1).
//!
//! Mirrors `sklearn.cross_decomposition.PLSRegression`. Implements the
//! NIPALS algorithm for `n_components` latent variables on 1-D `y` (PLS1).
//! 2-D `y` (PLS2) is not currently supported.
6
7use anofox_ml_core::{Fit, Predict, Result, RustMlError};
8use ndarray::{Array1, Array2};
9
/// Hyper-parameters of an unfitted PLS1 regression.
///
/// Holds configuration only; the fitted state is produced by `fit`.
#[derive(Debug, Clone)]
pub struct PlsRegression {
    /// Number of latent components to extract. `fit` in this file rejects `0`
    /// and any value larger than `min(rows, columns)` of the training matrix.
    pub n_components: usize,
    /// Iteration cap. Stored on the estimator but not read by the `fit`
    /// implementation in this file.
    pub max_iter: usize,
    /// Convergence tolerance. Stored on the estimator but not read by the
    /// `fit` implementation in this file.
    pub tol: f64,
}

impl PlsRegression {
    /// Value assigned to `max_iter` by [`PlsRegression::new`].
    const DEFAULT_MAX_ITER: usize = 500;
    /// Value assigned to `tol` by [`PlsRegression::new`].
    const DEFAULT_TOL: f64 = 1e-6;

    /// Creates an estimator that extracts `n_components` latent components,
    /// with `max_iter = 500` and `tol = 1e-6`.
    ///
    /// No validation happens here; `n_components` is checked against the data
    /// when `fit` is called.
    pub fn new(n_components: usize) -> Self {
        let (max_iter, tol) = (Self::DEFAULT_MAX_ITER, Self::DEFAULT_TOL);
        Self {
            n_components,
            max_iter,
            tol,
        }
    }
}
26
/// Fitted state of a PLS1 regression, as produced by `PlsRegression::fit`.
///
/// Stores the standardisation statistics of the training data together with
/// the regression coefficients expressed in standardised space. `predict`
/// standardises new rows with these statistics, applies `coef`, and maps the
/// result back to the original `y` scale.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FittedPlsRegression {
    /// Per-column mean of the training `X`.
    pub x_mean: Array1<f64>,
    /// Mean of the training `y`.
    pub y_mean: f64,
    /// Per-column standard deviation of the training `X`, computed with an
    /// `n` divisor and floored at `1e-12` so constant columns do not divide
    /// by zero.
    pub x_std: Array1<f64>,
    /// Standard deviation of the training `y` (same `n` divisor and `1e-12`
    /// floor as `x_std`).
    pub y_std: f64,
    /// Regression coefficients in centred+scaled space.
    pub coef: Array1<f64>,
    /// Number of columns of the training `X`; `predict` rejects inputs with a
    /// different column count.
    n_features: usize,
}
37
38impl Fit<f64> for PlsRegression {
39    type Fitted = FittedPlsRegression;
40
41    fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
42        if x.nrows() != y.len() {
43            return Err(RustMlError::ShapeMismatch(format!(
44                "X has {} rows but y has {}",
45                x.nrows(),
46                y.len()
47            )));
48        }
49        let n = x.nrows();
50        let d = x.ncols();
51        if self.n_components == 0 || self.n_components > d.min(n) {
52            return Err(RustMlError::InvalidParameter(format!(
53                "n_components must be in 1..={}",
54                d.min(n)
55            )));
56        }
57
58        // Standardize columns of X and y to unit variance (sklearn default).
59        let n_f = n as f64;
60        let mut x_mean = Array1::<f64>::zeros(d);
61        for j in 0..d {
62            x_mean[j] = x.column(j).sum() / n_f;
63        }
64        let y_mean = y.sum() / n_f;
65        let mut x_std = Array1::<f64>::ones(d);
66        for j in 0..d {
67            let mut v = 0.0;
68            for i in 0..n {
69                let dv = x[[i, j]] - x_mean[j];
70                v += dv * dv;
71            }
72            x_std[j] = (v / n_f).sqrt().max(1e-12);
73        }
74        let mut yv = 0.0;
75        for i in 0..n {
76            let dv = y[i] - y_mean;
77            yv += dv * dv;
78        }
79        let y_std = (yv / n_f).sqrt().max(1e-12);
80
81        let mut xs = Array2::<f64>::zeros((n, d));
82        let mut ys = Array1::<f64>::zeros(n);
83        for i in 0..n {
84            for j in 0..d {
85                xs[[i, j]] = (x[[i, j]] - x_mean[j]) / x_std[j];
86            }
87            ys[i] = (y[i] - y_mean) / y_std;
88        }
89
90        // NIPALS for PLS1 — Y is a single column.
91        let k = self.n_components;
92        let mut p_mat = Array2::<f64>::zeros((d, k));
93        let mut w_mat = Array2::<f64>::zeros((d, k));
94        let mut q_vec = Array1::<f64>::zeros(k);
95        let mut x_def = xs.clone();
96        let mut y_def = ys.clone();
97
98        for comp in 0..k {
99            // Weights w = X'y / ||X'y||
100            let mut w = Array1::<f64>::zeros(d);
101            for j in 0..d {
102                let mut s = 0.0;
103                for i in 0..n {
104                    s += x_def[[i, j]] * y_def[i];
105                }
106                w[j] = s;
107            }
108            let nw = w.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-12);
109            for j in 0..d {
110                w[j] /= nw;
111            }
112            // Scores t = X w
113            let mut t = Array1::<f64>::zeros(n);
114            for i in 0..n {
115                let mut s = 0.0;
116                for j in 0..d {
117                    s += x_def[[i, j]] * w[j];
118                }
119                t[i] = s;
120            }
121            // Loadings p = X' t / (t' t)
122            let tt: f64 = t.iter().map(|v| v * v).sum::<f64>().max(1e-12);
123            let mut p = Array1::<f64>::zeros(d);
124            for j in 0..d {
125                let mut s = 0.0;
126                for i in 0..n {
127                    s += x_def[[i, j]] * t[i];
128                }
129                p[j] = s / tt;
130            }
131            // Regression coef on Y: q = y' t / (t' t)
132            let mut q = 0.0;
133            for i in 0..n {
134                q += y_def[i] * t[i];
135            }
136            q /= tt;
137
138            // Deflate.
139            for i in 0..n {
140                for j in 0..d {
141                    x_def[[i, j]] -= t[i] * p[j];
142                }
143                y_def[i] -= t[i] * q;
144            }
145
146            for j in 0..d {
147                p_mat[[j, comp]] = p[j];
148                w_mat[[j, comp]] = w[j];
149            }
150            q_vec[comp] = q;
151        }
152
153        // Final regression coefficients in centred+scaled space:
154        // beta = W (P' W)^{-1} q
155        // For PLS1 this simplifies; we compute beta numerically.
156        // PtW is k×k. Solve PtW * z = q, then beta = W z.
157        let mut pt_w = Array2::<f64>::zeros((k, k));
158        for a in 0..k {
159            for b in 0..k {
160                let mut s = 0.0;
161                for j in 0..d {
162                    s += p_mat[[j, a]] * w_mat[[j, b]];
163                }
164                pt_w[[a, b]] = s;
165            }
166        }
167        // Forward solve PtW z = q (PtW is upper triangular for PLS1).
168        let mut z = Array1::<f64>::zeros(k);
169        // General solve via Gauss elimination (k is small).
170        let mut m = pt_w.clone();
171        let mut rhs = q_vec.clone();
172        for col in 0..k {
173            // Find pivot.
174            let mut piv = col;
175            for r in (col + 1)..k {
176                if m[[r, col]].abs() > m[[piv, col]].abs() {
177                    piv = r;
178                }
179            }
180            if piv != col {
181                for c in 0..k {
182                    let tmp = m[[col, c]];
183                    m[[col, c]] = m[[piv, c]];
184                    m[[piv, c]] = tmp;
185                }
186                let tmp = rhs[col];
187                rhs[col] = rhs[piv];
188                rhs[piv] = tmp;
189            }
190            let pv = m[[col, col]];
191            if pv.abs() < 1e-14 {
192                continue;
193            }
194            for r in (col + 1)..k {
195                let f = m[[r, col]] / pv;
196                for c in col..k {
197                    m[[r, c]] -= f * m[[col, c]];
198                }
199                rhs[r] -= f * rhs[col];
200            }
201        }
202        // Back-substitution.
203        for r in (0..k).rev() {
204            let mut s = rhs[r];
205            for c in (r + 1)..k {
206                s -= m[[r, c]] * z[c];
207            }
208            let pv = m[[r, r]];
209            if pv.abs() > 1e-14 {
210                z[r] = s / pv;
211            }
212        }
213
214        let mut coef = Array1::<f64>::zeros(d);
215        for j in 0..d {
216            let mut s = 0.0;
217            for c in 0..k {
218                s += w_mat[[j, c]] * z[c];
219            }
220            coef[j] = s;
221        }
222
223        Ok(FittedPlsRegression {
224            x_mean,
225            y_mean,
226            x_std,
227            y_std,
228            coef,
229            n_features: d,
230        })
231    }
232}
233
234impl Predict<f64> for FittedPlsRegression {
235    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
236        if x.ncols() != self.n_features {
237            return Err(RustMlError::ShapeMismatch(format!(
238                "expected {} features, got {}",
239                self.n_features,
240                x.ncols()
241            )));
242        }
243        let n = x.nrows();
244        let mut out = Array1::<f64>::zeros(n);
245        for i in 0..n {
246            let mut s = 0.0;
247            for j in 0..self.n_features {
248                let xs = (x[[i, j]] - self.x_mean[j]) / self.x_std[j];
249                s += self.coef[j] * xs;
250            }
251            out[i] = s * self.y_std + self.y_mean;
252        }
253        Ok(out)
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// A target that is an exact linear function of the features must be
    /// reproduced almost perfectly on the training data.
    #[test]
    fn test_pls1_recovers_linear() {
        // Deterministic design matrix: every column is an affine function of
        // the row index, so the centred features are perfectly collinear.
        let features: Vec<f64> = (0..40)
            .flat_map(|i| {
                let i = i as f64;
                vec![i, 0.5 * i, -0.3 * i + 1.0]
            })
            .collect();
        let x = Array2::from_shape_vec((40, 3), features).unwrap();
        let y: Array1<f64> = x.column(0).mapv(|v| 2.0 * v) + x.column(1).mapv(|v| 1.5 * v);

        let fitted = PlsRegression::new(2).fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let rss: f64 = preds
            .iter()
            .zip(y.iter())
            .map(|(p, t)| (t - p).powi(2))
            .sum();
        let mean = y.iter().sum::<f64>() / y.len() as f64;
        let tss: f64 = y.iter().map(|t| (t - mean).powi(2)).sum();
        let r2 = 1.0 - rss / tss;
        assert!(r2 > 0.99, "R² too low: {r2}");
    }

    /// `fit` must reject an `X`/`y` pair whose row counts disagree.
    #[test]
    fn test_fit_rejects_row_count_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0];
        let res = PlsRegression::new(1).fit(&x, &y);
        assert!(matches!(res, Err(RustMlError::ShapeMismatch(_))));
    }

    /// `n_components` must lie in `1..=min(rows, cols)`.
    #[test]
    fn test_fit_rejects_invalid_n_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];

        let too_few = PlsRegression::new(0).fit(&x, &y);
        assert!(matches!(too_few, Err(RustMlError::InvalidParameter(_))));

        // min(rows, cols) is 2 here, so 3 components is out of range.
        let too_many = PlsRegression::new(3).fit(&x, &y);
        assert!(matches!(too_many, Err(RustMlError::InvalidParameter(_))));
    }
}