1use ferrolearn_core::backend::Backend;
39use ferrolearn_core::backend_faer::NdarrayFaerBackend;
40use ferrolearn_core::error::FerroError;
41use ferrolearn_core::traits::{Fit, Predict, Transform};
42use ndarray::{Array1, Array2};
43use num_traits::Float;
44use std::any::TypeId;
45
46type SvdResult<F> = Result<(Array2<F>, Array1<F>, Array2<F>), FerroError>;
48
49fn centre_scale<F: Float + Send + Sync + 'static>(
58 x: &Array2<F>,
59 scale: bool,
60) -> (Array2<F>, Array1<F>, Option<Array1<F>>) {
61 let (n_samples, n_features) = x.dim();
62 let n_f = F::from(n_samples).unwrap();
63
64 let mean = Array1::from_shape_fn(n_features, |j| {
66 x.column(j).iter().copied().fold(F::zero(), |a, b| a + b) / n_f
67 });
68
69 let mut xc = x.to_owned();
71 for mut row in xc.rows_mut() {
72 for (v, &m) in row.iter_mut().zip(mean.iter()) {
73 *v = *v - m;
74 }
75 }
76
77 if scale {
78 let n_minus_1 = F::from(n_samples.saturating_sub(1).max(1)).unwrap();
79 let std_dev = Array1::from_shape_fn(n_features, |j| {
80 let var = xc
81 .column(j)
82 .iter()
83 .copied()
84 .fold(F::zero(), |a, b| a + b * b)
85 / n_minus_1;
86 let s = var.sqrt();
87 if s < F::epsilon() { F::one() } else { s }
88 });
89 for mut row in xc.rows_mut() {
90 for (v, &s) in row.iter_mut().zip(std_dev.iter()) {
91 *v = *v / s;
92 }
93 }
94 (xc, mean, Some(std_dev))
95 } else {
96 (xc, mean, None)
97 }
98}
99
100fn apply_centre_scale<F: Float + Send + Sync + 'static>(
102 x: &Array2<F>,
103 mean: &Array1<F>,
104 std_dev: &Option<Array1<F>>,
105 context: &str,
106) -> Result<Array2<F>, FerroError> {
107 if x.ncols() != mean.len() {
108 return Err(FerroError::ShapeMismatch {
109 expected: vec![x.nrows(), mean.len()],
110 actual: vec![x.nrows(), x.ncols()],
111 context: context.into(),
112 });
113 }
114 let mut xc = x.to_owned();
115 for mut row in xc.rows_mut() {
116 for (v, &m) in row.iter_mut().zip(mean.iter()) {
117 *v = *v - m;
118 }
119 }
120 if let Some(ref s) = *std_dev {
121 for mut row in xc.rows_mut() {
122 for (v, &sd) in row.iter_mut().zip(s.iter()) {
123 *v = *v / sd;
124 }
125 }
126 }
127 Ok(xc)
128}
129
/// Thin SVD of `a`, dispatching on the concrete scalar type `F`.
///
/// * `f64`: calls the faer backend directly, reinterpreting arrays via
///   `TypeId`-checked transmutes (no element copies for the input).
/// * `f32`: up-casts to `f64`, runs the faer backend, then down-casts.
/// * any other `Float`: falls back to a pure-Rust Jacobi
///   eigendecomposition of AᵀA (see [`svd_via_eigen`]).
///
/// Returns `(U, S, Vᵀ)` in thin form: U is m×k, Vᵀ is k×n, k = S.len().
fn svd_dispatch<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> SvdResult<F> {
    if TypeId::of::<F>() == TypeId::of::<f64>() {
        // SAFETY: the TypeId check above proves F == f64, so Array2<F>
        // and Array2<f64> are the exact same type.
        let a_f64: &Array2<f64> = unsafe { &*(std::ptr::from_ref(a).cast::<Array2<f64>>()) };
        let (u, s, vt) = NdarrayFaerBackend::svd(a_f64)?;
        let k = s.len();
        // Keep only the thin factors: first k columns of U / rows of Vᵀ.
        let u_thin = u.slice(ndarray::s![.., ..k]).to_owned();
        let vt_thin = vt.slice(ndarray::s![..k, ..]).to_owned();

        // SAFETY: F == f64 (checked above). transmute_copy duplicates
        // the array handles; the mem::forget calls below prevent the
        // originals from being dropped, avoiding a double free.
        let u_f: Array2<F> = unsafe { std::mem::transmute_copy::<Array2<f64>, Array2<F>>(&u_thin) };
        let s_f: Array1<F> = unsafe { std::mem::transmute_copy::<Array1<f64>, Array1<F>>(&s) };
        let vt_f: Array2<F> =
            unsafe { std::mem::transmute_copy::<Array2<f64>, Array2<F>>(&vt_thin) };
        std::mem::forget(u_thin);
        std::mem::forget(s);
        std::mem::forget(vt_thin);
        Ok((u_f, s_f, vt_f))
    } else if TypeId::of::<F>() == TypeId::of::<f32>() {
        let (m, n) = a.dim();
        // Up-cast element-wise; f32 -> f64 is lossless.
        let a_f64 =
            Array2::<f64>::from_shape_fn((m, n), |(i, j)| a[[i, j]].to_f64().unwrap_or(0.0));
        let (u64, s64, vt64) = NdarrayFaerBackend::svd(&a_f64)?;
        let k = s64.len();
        let u_thin = u64.slice(ndarray::s![.., ..k]).to_owned();
        let vt_thin = vt64.slice(ndarray::s![..k, ..]).to_owned();

        // Down-cast the thin factors back to F (= f32) element-wise.
        let u_f =
            Array2::<F>::from_shape_fn(u_thin.dim(), |(i, j)| F::from(u_thin[[i, j]]).unwrap());
        let s_f = Array1::<F>::from_shape_fn(s64.len(), |i| F::from(s64[i]).unwrap());
        let vt_f =
            Array2::<F>::from_shape_fn(vt_thin.dim(), |(i, j)| F::from(vt_thin[[i, j]]).unwrap());
        Ok((u_f, s_f, vt_f))
    } else {
        // Exotic float type: generic (slower) eigen-based fallback.
        svd_via_eigen(a)
    }
}
183
184fn svd_via_eigen<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> SvdResult<F> {
186 let (m, n) = a.dim();
187 let k = m.min(n);
188
189 let ata = a.t().dot(a);
191
192 let max_iter = n * n * 100 + 1000;
194 let (eigenvalues, eigenvectors) = jacobi_eigen_symmetric(&ata, max_iter)?;
195
196 let mut indices: Vec<usize> = (0..n).collect();
198 indices.sort_by(|&i, &j| {
199 eigenvalues[j]
200 .partial_cmp(&eigenvalues[i])
201 .unwrap_or(std::cmp::Ordering::Equal)
202 });
203
204 let mut s = Array1::<F>::zeros(k);
206 let mut v = Array2::<F>::zeros((n, k));
207 for (col, &idx) in indices.iter().take(k).enumerate() {
208 let eval = eigenvalues[idx];
209 s[col] = if eval > F::zero() {
210 eval.sqrt()
211 } else {
212 F::zero()
213 };
214 for row in 0..n {
215 v[[row, col]] = eigenvectors[[row, idx]];
216 }
217 }
218
219 let av = a.dot(&v);
221 let mut u = Array2::<F>::zeros((m, k));
222 for col in 0..k {
223 if s[col] > F::epsilon() {
224 let inv_s = F::one() / s[col];
225 for row in 0..m {
226 u[[row, col]] = av[[row, col]] * inv_s;
227 }
228 }
229 }
230
231 let mut vt = Array2::<F>::zeros((k, n));
233 for i in 0..k {
234 for j in 0..n {
235 vt[[i, j]] = v[[j, i]];
236 }
237 }
238
239 Ok((u, s, vt))
240}
241
242fn jacobi_eigen_symmetric<F: Float + Send + Sync + 'static>(
244 a: &Array2<F>,
245 max_iter: usize,
246) -> Result<(Array1<F>, Array2<F>), FerroError> {
247 let n = a.nrows();
248 let mut mat = a.to_owned();
249 let mut v = Array2::<F>::zeros((n, n));
250 for i in 0..n {
251 v[[i, i]] = F::one();
252 }
253
254 let tol = F::from(1e-12).unwrap_or_else(F::epsilon);
255
256 for _iteration in 0..max_iter {
257 let mut max_off = F::zero();
258 let mut p = 0;
259 let mut q = 1;
260 for i in 0..n {
261 for j in (i + 1)..n {
262 let val = mat[[i, j]].abs();
263 if val > max_off {
264 max_off = val;
265 p = i;
266 q = j;
267 }
268 }
269 }
270
271 if max_off < tol {
272 let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
273 return Ok((eigenvalues, v));
274 }
275
276 let app = mat[[p, p]];
277 let aqq = mat[[q, q]];
278 let apq = mat[[p, q]];
279
280 let theta = if (app - aqq).abs() < tol {
281 F::from(std::f64::consts::FRAC_PI_4).unwrap_or_else(F::one)
282 } else {
283 let tau = (aqq - app) / (F::from(2.0).unwrap() * apq);
284 let t = if tau >= F::zero() {
285 F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
286 } else {
287 -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
288 };
289 t.atan()
290 };
291
292 let c = theta.cos();
293 let s = theta.sin();
294
295 let mut new_mat = mat.clone();
296 for i in 0..n {
297 if i != p && i != q {
298 let mip = mat[[i, p]];
299 let miq = mat[[i, q]];
300 new_mat[[i, p]] = c * mip - s * miq;
301 new_mat[[p, i]] = new_mat[[i, p]];
302 new_mat[[i, q]] = s * mip + c * miq;
303 new_mat[[q, i]] = new_mat[[i, q]];
304 }
305 }
306
307 new_mat[[p, p]] = c * c * app - F::from(2.0).unwrap() * s * c * apq + s * s * aqq;
308 new_mat[[q, q]] = s * s * app + F::from(2.0).unwrap() * s * c * apq + c * c * aqq;
309 new_mat[[p, q]] = F::zero();
310 new_mat[[q, p]] = F::zero();
311
312 mat = new_mat;
313
314 for i in 0..n {
315 let vip = v[[i, p]];
316 let viq = v[[i, q]];
317 v[[i, p]] = c * vip - s * viq;
318 v[[i, q]] = s * vip + c * viq;
319 }
320 }
321
322 Err(FerroError::ConvergenceFailure {
323 iterations: max_iter,
324 message: "Jacobi eigendecomposition did not converge in cross_decomposition SVD fallback"
325 .into(),
326 })
327}
328
329fn norm<F: Float>(v: &Array1<F>) -> F {
335 v.iter().copied().fold(F::zero(), |a, b| a + b * b).sqrt()
336}
337
338fn dot<F: Float>(a: &Array1<F>, b: &Array1<F>) -> F {
340 a.iter()
341 .copied()
342 .zip(b.iter().copied())
343 .fold(F::zero(), |acc, (x, y)| acc + x * y)
344}
345
/// Invert a square matrix by Gauss-Jordan elimination with partial
/// pivoting on an augmented [A | I] matrix.
///
/// Near-singular pivots are not rejected: when the best available pivot
/// in a column is below a scale-relative tolerance, the diagonal entry
/// is replaced by that tolerance (a crude regularisation), so the
/// function returns an approximate inverse rather than failing.
///
/// # Errors
/// Returns [`FerroError::ShapeMismatch`] when `a` is not square.
fn invert_square<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> Result<Array2<F>, FerroError> {
    let n = a.nrows();
    if n != a.ncols() {
        return Err(FerroError::ShapeMismatch {
            expected: vec![n, n],
            actual: vec![a.nrows(), a.ncols()],
            context: "invert_square: matrix must be square".into(),
        });
    }

    // Build the augmented matrix [A | I].
    let mut aug = Array2::<F>::zeros((n, 2 * n));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n + i]] = F::one();
    }

    // Tolerance scaled by the largest |entry| of A, plus an absolute floor.
    let max_abs = a.iter().copied().fold(F::zero(), |m, v| {
        let abs = v.abs();
        if abs > m { abs } else { m }
    });
    let regularise_tol = max_abs * F::from(1e-12).unwrap_or_else(F::epsilon)
        + F::from(1e-15).unwrap_or_else(F::epsilon);

    // Forward elimination with partial (row) pivoting.
    for col in 0..n {
        // Find the largest |entry| on or below the diagonal in this column.
        let mut max_val = aug[[col, col]].abs();
        let mut max_row = col;
        for row in (col + 1)..n {
            let val = aug[[row, col]].abs();
            if val > max_val {
                max_val = val;
                max_row = row;
            }
        }

        if max_val < regularise_tol {
            // Column is numerically singular: regularise the pivot
            // instead of aborting (no row swap in this branch).
            aug[[col, col]] = regularise_tol;
        } else {
            if max_row != col {
                // Swap the pivot row into place.
                for j in 0..(2 * n) {
                    let tmp = aug[[col, j]];
                    aug[[col, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = tmp;
                }
            }
        }

        // Eliminate entries below the pivot.
        let pivot = aug[[col, col]];
        for row in (col + 1)..n {
            let factor = aug[[row, col]] / pivot;
            for j in col..(2 * n) {
                let above = aug[[col, j]];
                aug[[row, j]] = aug[[row, j]] - factor * above;
            }
        }
    }

    // Back substitution: normalise each pivot row, then clear above it.
    for col in (0..n).rev() {
        let pivot = aug[[col, col]];
        for j in 0..(2 * n) {
            aug[[col, j]] = aug[[col, j]] / pivot;
        }
        for row in 0..col {
            let factor = aug[[row, col]];
            for j in 0..(2 * n) {
                let below = aug[[col, j]];
                aug[[row, j]] = aug[[row, j]] - factor * below;
            }
        }
    }

    // The right half of the augmented matrix now holds the inverse.
    let mut inv = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            inv[[i, j]] = aug[[i, n + j]];
        }
    }
    Ok(inv)
}
445
/// Partial Least Squares via a single SVD of the cross-covariance matrix
/// XᵀY. Builder-style configuration; fit to obtain a [`FittedPLSSVD`].
#[derive(Debug, Clone)]
pub struct PLSSVD<F> {
    /// Number of latent components to extract.
    n_components: usize,
    /// Scale columns to unit variance in addition to centring.
    scale: bool,
    /// Ties the scalar type F to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
486
impl<F: Float + Send + Sync + 'static> PLSSVD<F> {
    /// Create a `PLSSVD` extracting `n_components` components; scaling
    /// to unit variance is enabled by default.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            scale: true,
            _marker: std::marker::PhantomData,
        }
    }

    /// Builder: enable or disable unit-variance scaling.
    #[must_use]
    pub fn with_scale(mut self, scale: bool) -> Self {
        self.scale = scale;
        self
    }

    /// Number of components this estimator will extract.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }
}
511
/// A fitted [`PLSSVD`] model: SVD weights plus the centring/scaling
/// statistics of the training data.
#[derive(Debug, Clone)]
pub struct FittedPLSSVD<F> {
    /// X weights (leading left singular vectors of XᵀY).
    x_weights_: Array2<F>,
    /// Y weights (leading right singular vectors of XᵀY).
    y_weights_: Array2<F>,
    /// Per-column means of the training X.
    x_mean_: Array1<F>,
    /// Per-column means of the training Y.
    y_mean_: Array1<F>,
    /// Per-column std devs of the training X (None when scale = false).
    x_std_: Option<Array1<F>>,
    /// Per-column std devs of the training Y (None when scale = false).
    y_std_: Option<Array1<F>>,
}
531
impl<F: Float + Send + Sync + 'static> FittedPLSSVD<F> {
    /// X weight matrix (n_features_x × n_components).
    #[must_use]
    pub fn x_weights(&self) -> &Array2<F> {
        &self.x_weights_
    }

    /// Y weight matrix (n_features_y × n_components).
    #[must_use]
    pub fn y_weights(&self) -> &Array2<F> {
        &self.y_weights_
    }

    /// Per-column means of the training X.
    #[must_use]
    pub fn x_mean(&self) -> &Array1<F> {
        &self.x_mean_
    }

    /// Per-column means of the training Y.
    #[must_use]
    pub fn y_mean(&self) -> &Array1<F> {
        &self.y_mean_
    }

    /// Project new Y data onto the Y weights after applying the stored
    /// centring/scaling.
    ///
    /// # Errors
    /// `ShapeMismatch` if `y`'s column count differs from the training Y.
    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let yc = apply_centre_scale(y, &self.y_mean_, &self.y_std_, "FittedPLSSVD::transform_y")?;
        Ok(yc.dot(&self.y_weights_))
    }
}
567
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSSVD<F> {
    type Fitted = FittedPLSSVD<F>;
    type Error = FerroError;

    /// Fit the PLSSVD model: centre/scale X and Y, then take the SVD of
    /// the cross matrix C = XᵀY. The leading left/right singular
    /// vectors become the X/Y weights.
    ///
    /// # Errors
    /// * `ShapeMismatch` — X and Y row counts differ.
    /// * `InvalidParameter` — `n_components` is 0 or exceeds
    ///   min(n_features_x, n_features_y).
    /// * `InsufficientSamples` — fewer than 2 rows.
    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSSVD<F>, FerroError> {
        let (n_samples_x, n_features_x) = x.dim();
        let (n_samples_y, n_features_y) = y.dim();

        if n_samples_x != n_samples_y {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples_x, n_features_y],
                actual: vec![n_samples_y, n_features_y],
                context: "PLSSVD::fit: X and Y must have the same number of rows".into(),
            });
        }

        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }

        // NOTE(review): unlike the other fit impls in this file, this cap
        // is not also bounded by n_samples — confirm that is intentional.
        let max_components = n_features_x.min(n_features_y);
        if self.n_components > max_components {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds min(n_features_x, n_features_y) ({})",
                    self.n_components, max_components
                ),
            });
        }

        if n_samples_x < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples_x,
                context: "PLSSVD::fit requires at least 2 samples".into(),
            });
        }

        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
        let (yc, y_mean, y_std) = centre_scale(y, self.scale);

        // Cross matrix XᵀY (a constant factor does not change the
        // singular vectors).
        let c = xc.t().dot(&yc);

        let (u, _s, vt) = svd_dispatch(&c)?;

        // Keep the leading n_components singular vectors on each side.
        let nc = self.n_components;
        let x_weights = u.slice(ndarray::s![.., ..nc]).to_owned();
        let y_weights = vt.t().slice(ndarray::s![.., ..nc]).to_owned();

        Ok(FittedPLSSVD {
            x_weights_: x_weights,
            y_weights_: y_weights,
            x_mean_: x_mean,
            y_mean_: y_mean,
            x_std_: x_std,
            y_std_: y_std,
        })
    }
}
644
645impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSSVD<F> {
646 type Output = Array2<F>;
647 type Error = FerroError;
648
649 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
655 let xc = apply_centre_scale(x, &self.x_mean_, &self.x_std_, "FittedPLSSVD::transform")?;
656 Ok(xc.dot(&self.x_weights_))
657 }
658}
659
/// Y-deflation strategy used by the NIPALS loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NipalsMode {
    /// Deflate Y with the X scores (regression mode).
    Regression,
    /// Deflate Y with its own scores (canonical / symmetric mode).
    Canonical,
}
672
/// Whether NIPALS normalises the Y weights and standardises the scores.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScoreNorm {
    /// Leave weights/scores on their natural scale.
    None,
    /// Unit-norm Y weights and unit-variance scores (CCA behaviour).
    UnitVariance,
}
681
/// Raw output of the NIPALS algorithm, consumed by the fitted estimators.
#[derive(Debug, Clone)]
struct NipalsResult<F> {
    /// X weight vectors W, one column per component.
    x_weights: Array2<F>,
    /// X loadings P.
    x_loadings: Array2<F>,
    /// X scores T (one column per component).
    x_scores: Array2<F>,
    /// Y loadings Q.
    y_loadings: Array2<F>,
    /// Y scores U.
    y_scores: Array2<F>,
    /// Inner-loop iteration count for each component.
    n_iter: Vec<usize>,
}
702
/// NIPALS (Nonlinear Iterative Partial Least Squares) core loop shared
/// by [`PLSRegression`], [`PLSCanonical`] and [`CCA`].
///
/// For each component: alternate power-style updates of the X weight `w`,
/// X score `t`, Y weight `q` and Y score `u` until `u` stabilises
/// (relative change < `tol`), then deflate `xk` (and `yk`, depending on
/// `mode`) and move to the next component.
///
/// `x` and `y` are assumed to be already centred (and optionally scaled).
///
/// # Errors
/// `ConvergenceFailure` if the inner loop hits `max_iter` without
/// converging while Y has more than one column.
fn nipals<F: Float + Send + Sync + 'static>(
    x: &Array2<F>,
    y: &Array2<F>,
    n_components: usize,
    max_iter: usize,
    tol: F,
    mode: NipalsMode,
    score_norm: ScoreNorm,
) -> Result<NipalsResult<F>, FerroError> {
    let (n_samples, n_features_x) = x.dim();
    let n_features_y = y.ncols();

    // Working copies that get deflated component by component.
    let mut xk = x.to_owned();
    let mut yk = y.to_owned();

    let mut x_weights = Array2::<F>::zeros((n_features_x, n_components));
    let mut x_loadings = Array2::<F>::zeros((n_features_x, n_components));
    let mut x_scores = Array2::<F>::zeros((n_samples, n_components));
    let mut y_loadings = Array2::<F>::zeros((n_features_y, n_components));
    let mut y_scores = Array2::<F>::zeros((n_samples, n_components));
    let mut n_iter_vec = Vec::with_capacity(n_components);

    for k in 0..n_components {
        // Seed u with the Y column of largest remaining sum of squares.
        let best_col = (0..n_features_y)
            .max_by(|&a, &b| {
                let var_a: F = yk
                    .column(a)
                    .iter()
                    .copied()
                    .fold(F::zero(), |s, v| s + v * v);
                let var_b: F = yk
                    .column(b)
                    .iter()
                    .copied()
                    .fold(F::zero(), |s, v| s + v * v);
                var_a
                    .partial_cmp(&var_b)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or(0);

        let mut u = yk.column(best_col).to_owned();

        let mut converged = false;
        let mut iters = 0;

        for iteration in 0..max_iter {
            iters = iteration + 1;

            // w ∝ Xᵀu / (uᵀu), then normalised to unit length.
            let utu = dot(&u, &u);
            let mut w = xk.t().dot(&u);
            if utu > F::epsilon() {
                w.mapv_inplace(|v| v / utu);
            }
            let w_norm = norm(&w);
            if w_norm < F::epsilon() {
                // Degenerate direction: nothing left to extract.
                break;
            }
            w.mapv_inplace(|v| v / w_norm);

            // t = Xw.
            let t = xk.dot(&w);

            // q ∝ Yᵀt / (tᵀt).
            let ttt = dot(&t, &t);
            let mut q = yk.t().dot(&t);
            if ttt > F::epsilon() {
                q.mapv_inplace(|v| v / ttt);
            }

            // CCA variant additionally unit-normalises the Y weights.
            if score_norm == ScoreNorm::UnitVariance {
                let q_norm = norm(&q);
                if q_norm > F::epsilon() {
                    q.mapv_inplace(|v| v / q_norm);
                }
            }

            // u ∝ Yq / (qᵀq).
            let qtq = dot(&q, &q);
            let mut u_new = yk.dot(&q);
            if qtq > F::epsilon() {
                u_new.mapv_inplace(|v| v / qtq);
            }

            // Relative change of u between iterations drives convergence.
            let diff_norm = {
                let diff: Array1<F> = &u_new - &u;
                norm(&diff)
            };
            let u_new_norm = norm(&u_new);

            u = u_new;

            if u_new_norm > F::epsilon() && diff_norm / u_new_norm < tol {
                converged = true;
                // One final refresh of w, t, q, u from the converged u so
                // all quantities are mutually consistent.
                let utu2 = dot(&u, &u);
                w = xk.t().dot(&u);
                if utu2 > F::epsilon() {
                    w.mapv_inplace(|v| v / utu2);
                }
                let w_norm2 = norm(&w);
                if w_norm2 > F::epsilon() {
                    w.mapv_inplace(|v| v / w_norm2);
                }
                let t_final = xk.dot(&w);
                let ttt2 = dot(&t_final, &t_final);
                q = yk.t().dot(&t_final);
                if ttt2 > F::epsilon() {
                    q.mapv_inplace(|v| v / ttt2);
                }
                if score_norm == ScoreNorm::UnitVariance {
                    let q_norm2 = norm(&q);
                    if q_norm2 > F::epsilon() {
                        q.mapv_inplace(|v| v / q_norm2);
                    }
                }
                let qtq2 = dot(&q, &q);
                u = yk.dot(&q);
                if qtq2 > F::epsilon() {
                    u.mapv_inplace(|v| v / qtq2);
                }
                break;
            }
        }

        // Recompute final w/t/p/q/u from the last u (also covers the
        // non-converged and early-break cases).
        let utu_final = dot(&u, &u);
        let mut w_final = xk.t().dot(&u);
        if utu_final > F::epsilon() {
            w_final.mapv_inplace(|v| v / utu_final);
        }
        let w_norm_final = norm(&w_final);
        if w_norm_final > F::epsilon() {
            w_final.mapv_inplace(|v| v / w_norm_final);
        }

        let mut t_final = xk.dot(&w_final);
        let ttt_final = dot(&t_final, &t_final);

        // X loadings p = Xᵀt / (tᵀt).
        let mut p = xk.t().dot(&t_final);
        if ttt_final > F::epsilon() {
            p.mapv_inplace(|v| v / ttt_final);
        }

        // Y loadings q = Yᵀt / (tᵀt).
        let mut q_final = yk.t().dot(&t_final);
        if ttt_final > F::epsilon() {
            q_final.mapv_inplace(|v| v / ttt_final);
        }

        if score_norm == ScoreNorm::UnitVariance {
            let q_norm = norm(&q_final);
            if q_norm > F::epsilon() {
                q_final.mapv_inplace(|v| v / q_norm);
            }
        }

        let qtq_final = dot(&q_final, &q_final);
        let mut u_final = yk.dot(&q_final);
        if qtq_final > F::epsilon() {
            u_final.mapv_inplace(|v| v / qtq_final);
        }

        // CCA variant: standardise the scores to unit sample variance.
        if score_norm == ScoreNorm::UnitVariance {
            let t_std = {
                let t_mean = t_final.iter().copied().fold(F::zero(), |a, b| a + b)
                    / F::from(n_samples).unwrap();
                let var = t_final
                    .iter()
                    .copied()
                    .fold(F::zero(), |a, b| a + (b - t_mean) * (b - t_mean))
                    / F::from(n_samples.saturating_sub(1).max(1)).unwrap();
                var.sqrt()
            };
            if t_std > F::epsilon() {
                t_final.mapv_inplace(|v| v / t_std);
            }

            let u_std = {
                let u_mean = u_final.iter().copied().fold(F::zero(), |a, b| a + b)
                    / F::from(n_samples).unwrap();
                let var = u_final
                    .iter()
                    .copied()
                    .fold(F::zero(), |a, b| a + (b - u_mean) * (b - u_mean))
                    / F::from(n_samples.saturating_sub(1).max(1)).unwrap();
                var.sqrt()
            };
            if u_std > F::epsilon() {
                u_final.mapv_inplace(|v| v / u_std);
            }
        }

        // Store this component's vectors.
        x_weights.column_mut(k).assign(&w_final);
        x_loadings.column_mut(k).assign(&p);
        x_scores.column_mut(k).assign(&t_final);
        y_loadings.column_mut(k).assign(&q_final);
        y_scores.column_mut(k).assign(&u_final);

        // Deflate X: remove the rank-1 contribution t·pᵀ.
        for i in 0..n_samples {
            let ti = t_final[i];
            for j in 0..n_features_x {
                xk[[i, j]] = xk[[i, j]] - ti * p[j];
            }
        }

        match mode {
            NipalsMode::Regression => {
                // Deflate Y with the X scores: Y ← Y − t·qᵀ.
                for i in 0..n_samples {
                    let ti = t_final[i];
                    for j in 0..n_features_y {
                        yk[[i, j]] = yk[[i, j]] - ti * q_final[j];
                    }
                }
            }
            NipalsMode::Canonical => {
                // Deflate Y with its own scores: Y ← Y − u·cᵀ,
                // where c = Yᵀu / (uᵀu).
                let utu_c = dot(&u_final, &u_final);
                let mut c = yk.t().dot(&u_final);
                if utu_c > F::epsilon() {
                    c.mapv_inplace(|v| v / utu_c);
                }
                for i in 0..n_samples {
                    let ui = u_final[i];
                    for j in 0..n_features_y {
                        yk[[i, j]] = yk[[i, j]] - ui * c[j];
                    }
                }
            }
        }

        n_iter_vec.push(iters);

        // Non-convergence is only an error when Y has multiple columns.
        if !converged && n_features_y > 1 && iters == max_iter {
            return Err(FerroError::ConvergenceFailure {
                iterations: max_iter,
                message: format!("NIPALS did not converge for component {k}"),
            });
        }
    }

    Ok(NipalsResult {
        x_weights,
        x_loadings,
        x_scores,
        y_loadings,
        y_scores,
        n_iter: n_iter_vec,
    })
}
971
/// PLS regression via NIPALS with regression-mode Y deflation.
#[derive(Debug, Clone)]
pub struct PLSRegression<F> {
    /// Number of latent components.
    n_components: usize,
    /// Cap on NIPALS inner-loop iterations per component.
    max_iter: usize,
    /// Relative convergence tolerance for the inner loop.
    tol: F,
    /// Scale columns to unit variance in addition to centring.
    scale: bool,
    /// Ties the scalar type F to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
1018
impl<F: Float + Send + Sync + 'static> PLSRegression<F> {
    /// Create a `PLSRegression` with defaults: 500 max iterations,
    /// tolerance 1e-6, scaling enabled.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_iter: 500,
            tol: F::from(1e-6).unwrap_or_else(F::epsilon),
            scale: true,
            _marker: std::marker::PhantomData,
        }
    }

    /// Builder: set the NIPALS iteration cap.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Builder: set the NIPALS convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Builder: enable or disable unit-variance scaling.
    #[must_use]
    pub fn with_scale(mut self, scale: bool) -> Self {
        self.scale = scale;
        self
    }

    /// Number of components this estimator will extract.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }
}
1061
/// A fitted [`PLSRegression`] model: NIPALS factors, regression
/// coefficients and the training centring/scaling statistics.
#[derive(Debug, Clone)]
pub struct FittedPLSRegression<F> {
    /// X weights W.
    x_weights_: Array2<F>,
    /// X loadings P.
    x_loadings_: Array2<F>,
    /// Y loadings Q.
    y_loadings_: Array2<F>,
    /// Coefficients B = W (PᵀW)⁻¹ Qᵀ on the centred/scaled scale.
    coefficients_: Array2<F>,
    /// X scores T from training.
    x_scores_: Array2<F>,
    /// Y scores U from training.
    y_scores_: Array2<F>,
    /// NIPALS iteration counts per component.
    n_iter_: Vec<usize>,
    /// Per-column means of the training X.
    x_mean_: Array1<F>,
    /// Per-column means of the training Y.
    y_mean_: Array1<F>,
    /// Per-column std devs of the training X (None when scale = false).
    x_std_: Option<Array1<F>>,
    /// Per-column std devs of the training Y (None when scale = false).
    y_std_: Option<Array1<F>>,
}
1093
impl<F: Float + Send + Sync + 'static> FittedPLSRegression<F> {
    /// X weight matrix W (n_features_x × n_components).
    #[must_use]
    pub fn x_weights(&self) -> &Array2<F> {
        &self.x_weights_
    }

    /// X loading matrix P (n_features_x × n_components).
    #[must_use]
    pub fn x_loadings(&self) -> &Array2<F> {
        &self.x_loadings_
    }

    /// Y loading matrix Q (n_features_y × n_components).
    #[must_use]
    pub fn y_loadings(&self) -> &Array2<F> {
        &self.y_loadings_
    }

    /// Regression coefficient matrix on the centred/scaled scale.
    #[must_use]
    pub fn coefficients(&self) -> &Array2<F> {
        &self.coefficients_
    }

    /// Training X scores T (n_samples × n_components).
    #[must_use]
    pub fn x_scores(&self) -> &Array2<F> {
        &self.x_scores_
    }

    /// Training Y scores U (n_samples × n_components).
    #[must_use]
    pub fn y_scores(&self) -> &Array2<F> {
        &self.y_scores_
    }

    /// NIPALS iteration counts, one entry per component.
    #[must_use]
    pub fn n_iter(&self) -> &[usize] {
        &self.n_iter_
    }
}
1139
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSRegression<F> {
    type Fitted = FittedPLSRegression<F>;
    type Error = FerroError;

    /// Fit via NIPALS in regression mode, then form the coefficient
    /// matrix B = W (PᵀW)⁻¹ Qᵀ mapping centred/scaled X to
    /// centred/scaled Y.
    ///
    /// # Errors
    /// * `ShapeMismatch` — X and Y row counts differ.
    /// * `InvalidParameter` — `n_components` is 0 or exceeds
    ///   min(n_features_x, n_features_y, n_samples).
    /// * `InsufficientSamples` — fewer than 2 rows.
    /// * `ConvergenceFailure` — NIPALS failed to converge.
    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSRegression<F>, FerroError> {
        let (n_samples_x, n_features_x) = x.dim();
        let (n_samples_y, n_features_y) = y.dim();

        if n_samples_x != n_samples_y {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples_x, n_features_y],
                actual: vec![n_samples_y, n_features_y],
                context: "PLSRegression::fit: X and Y must have the same number of rows".into(),
            });
        }

        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }

        let max_components = n_features_x.min(n_features_y).min(n_samples_x);
        if self.n_components > max_components {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
                    self.n_components, max_components
                ),
            });
        }

        if n_samples_x < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples_x,
                context: "PLSRegression::fit requires at least 2 samples".into(),
            });
        }

        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
        let (yc, y_mean, y_std) = centre_scale(y, self.scale);

        let result = nipals(
            &xc,
            &yc,
            self.n_components,
            self.max_iter,
            self.tol,
            NipalsMode::Regression,
            ScoreNorm::None,
        )?;

        // B = W (PᵀW)⁻¹ Qᵀ.
        let ptw = result.x_loadings.t().dot(&result.x_weights);
        let ptw_inv = invert_square(&ptw)?;
        let coefficients = result.x_weights.dot(&ptw_inv).dot(&result.y_loadings.t());

        Ok(FittedPLSRegression {
            x_weights_: result.x_weights,
            x_loadings_: result.x_loadings,
            y_loadings_: result.y_loadings,
            coefficients_: coefficients,
            x_scores_: result.x_scores,
            y_scores_: result.y_scores,
            n_iter_: result.n_iter,
            x_mean_: x_mean,
            y_mean_: y_mean,
            x_std_: x_std,
            y_std_: y_std,
        })
    }
}
1230
1231impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedPLSRegression<F> {
1232 type Output = Array2<F>;
1233 type Error = FerroError;
1234
1235 fn predict(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1243 let xc = apply_centre_scale(
1244 x,
1245 &self.x_mean_,
1246 &self.x_std_,
1247 "FittedPLSRegression::predict",
1248 )?;
1249
1250 let mut y_pred = xc.dot(&self.coefficients_);
1251
1252 if let Some(ref ys) = self.y_std_ {
1254 for mut row in y_pred.rows_mut() {
1255 for (v, &s) in row.iter_mut().zip(ys.iter()) {
1256 *v = *v * s;
1257 }
1258 }
1259 }
1260
1261 for mut row in y_pred.rows_mut() {
1263 for (v, &m) in row.iter_mut().zip(self.y_mean_.iter()) {
1264 *v = *v + m;
1265 }
1266 }
1267
1268 Ok(y_pred)
1269 }
1270}
1271
1272impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSRegression<F> {
1273 type Output = Array2<F>;
1274 type Error = FerroError;
1275
1276 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1284 let xc = apply_centre_scale(
1285 x,
1286 &self.x_mean_,
1287 &self.x_std_,
1288 "FittedPLSRegression::transform",
1289 )?;
1290
1291 let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1295 let ptw_inv = invert_square(&ptw)?;
1296 let rotation = self.x_weights_.dot(&ptw_inv);
1297 Ok(xc.dot(&rotation))
1298 }
1299}
1300
/// Canonical PLS via NIPALS with symmetric (canonical) Y deflation.
#[derive(Debug, Clone)]
pub struct PLSCanonical<F> {
    /// Number of latent components.
    n_components: usize,
    /// Cap on NIPALS inner-loop iterations per component.
    max_iter: usize,
    /// Relative convergence tolerance for the inner loop.
    tol: F,
    /// Scale columns to unit variance in addition to centring.
    scale: bool,
    /// Ties the scalar type F to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
1345
impl<F: Float + Send + Sync + 'static> PLSCanonical<F> {
    /// Create a `PLSCanonical` with defaults: 500 max iterations,
    /// tolerance 1e-6, scaling enabled.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_iter: 500,
            tol: F::from(1e-6).unwrap_or_else(F::epsilon),
            scale: true,
            _marker: std::marker::PhantomData,
        }
    }

    /// Builder: set the NIPALS iteration cap.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Builder: set the NIPALS convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Builder: enable or disable unit-variance scaling.
    #[must_use]
    pub fn with_scale(mut self, scale: bool) -> Self {
        self.scale = scale;
        self
    }

    /// Number of components this estimator will extract.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }
}
1388
/// A fitted [`PLSCanonical`] model: NIPALS factors plus the training
/// centring/scaling statistics.
#[derive(Debug, Clone)]
pub struct FittedPLSCanonical<F> {
    /// X weights W.
    x_weights_: Array2<F>,
    /// X loadings P.
    x_loadings_: Array2<F>,
    /// Y loadings Q.
    y_loadings_: Array2<F>,
    /// X scores T from training.
    x_scores_: Array2<F>,
    /// Y scores U from training.
    y_scores_: Array2<F>,
    /// NIPALS iteration counts per component.
    n_iter_: Vec<usize>,
    /// Per-column means of the training X.
    x_mean_: Array1<F>,
    /// Per-column means of the training Y.
    y_mean_: Array1<F>,
    /// Per-column std devs of the training X (None when scale = false).
    x_std_: Option<Array1<F>>,
    /// Per-column std devs of the training Y (None when scale = false).
    y_std_: Option<Array1<F>>,
}
1417
impl<F: Float + Send + Sync + 'static> FittedPLSCanonical<F> {
    /// X weight matrix W (n_features_x × n_components).
    #[must_use]
    pub fn x_weights(&self) -> &Array2<F> {
        &self.x_weights_
    }

    /// X loading matrix P (n_features_x × n_components).
    #[must_use]
    pub fn x_loadings(&self) -> &Array2<F> {
        &self.x_loadings_
    }

    /// Y loading matrix Q (n_features_y × n_components).
    #[must_use]
    pub fn y_loadings(&self) -> &Array2<F> {
        &self.y_loadings_
    }

    /// Training X scores T (n_samples × n_components).
    #[must_use]
    pub fn x_scores(&self) -> &Array2<F> {
        &self.x_scores_
    }

    /// Training Y scores U (n_samples × n_components).
    #[must_use]
    pub fn y_scores(&self) -> &Array2<F> {
        &self.y_scores_
    }

    /// NIPALS iteration counts, one entry per component.
    #[must_use]
    pub fn n_iter(&self) -> &[usize] {
        &self.n_iter_
    }

    /// Project new Y data onto the Y loadings after applying the stored
    /// centring/scaling.
    ///
    /// # Errors
    /// `ShapeMismatch` if `y`'s column count differs from the training Y.
    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let yc = apply_centre_scale(
            y,
            &self.y_mean_,
            &self.y_std_,
            "FittedPLSCanonical::transform_y",
        )?;
        Ok(yc.dot(&self.y_loadings_))
    }
}
1470
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSCanonical<F> {
    type Fitted = FittedPLSCanonical<F>;
    type Error = FerroError;

    /// Fit via NIPALS in canonical mode (Y deflated by its own scores).
    ///
    /// # Errors
    /// * `ShapeMismatch` — X and Y row counts differ.
    /// * `InvalidParameter` — `n_components` is 0 or exceeds
    ///   min(n_features_x, n_features_y, n_samples).
    /// * `InsufficientSamples` — fewer than 2 rows.
    /// * `ConvergenceFailure` — NIPALS failed to converge.
    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSCanonical<F>, FerroError> {
        let (n_samples_x, n_features_x) = x.dim();
        let (n_samples_y, n_features_y) = y.dim();

        if n_samples_x != n_samples_y {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples_x, n_features_y],
                actual: vec![n_samples_y, n_features_y],
                context: "PLSCanonical::fit: X and Y must have the same number of rows".into(),
            });
        }

        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }

        let max_components = n_features_x.min(n_features_y).min(n_samples_x);
        if self.n_components > max_components {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
                    self.n_components, max_components
                ),
            });
        }

        if n_samples_x < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples_x,
                context: "PLSCanonical::fit requires at least 2 samples".into(),
            });
        }

        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
        let (yc, y_mean, y_std) = centre_scale(y, self.scale);

        let result = nipals(
            &xc,
            &yc,
            self.n_components,
            self.max_iter,
            self.tol,
            NipalsMode::Canonical,
            ScoreNorm::None,
        )?;

        Ok(FittedPLSCanonical {
            x_weights_: result.x_weights,
            x_loadings_: result.x_loadings,
            y_loadings_: result.y_loadings,
            x_scores_: result.x_scores,
            y_scores_: result.y_scores,
            n_iter_: result.n_iter,
            x_mean_: x_mean,
            y_mean_: y_mean,
            x_std_: x_std,
            y_std_: y_std,
        })
    }
}
1548
1549impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSCanonical<F> {
1550 type Output = Array2<F>;
1551 type Error = FerroError;
1552
1553 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1559 let xc = apply_centre_scale(
1560 x,
1561 &self.x_mean_,
1562 &self.x_std_,
1563 "FittedPLSCanonical::transform",
1564 )?;
1565
1566 let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1567 let ptw_inv = invert_square(&ptw)?;
1568 let rotation = self.x_weights_.dot(&ptw_inv);
1569 Ok(xc.dot(&rotation))
1570 }
1571}
1572
/// Canonical Correlation Analysis via NIPALS in canonical mode with
/// unit-variance score normalisation.
#[derive(Debug, Clone)]
pub struct CCA<F> {
    /// Number of latent components.
    n_components: usize,
    /// Cap on NIPALS inner-loop iterations per component.
    max_iter: usize,
    /// Relative convergence tolerance for the inner loop.
    tol: F,
    /// Scale columns to unit variance in addition to centring.
    scale: bool,
    /// Ties the scalar type F to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
1614
impl<F: Float + Send + Sync + 'static> CCA<F> {
    /// Create a `CCA` with defaults: 500 max iterations, tolerance 1e-6,
    /// scaling enabled.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_iter: 500,
            tol: F::from(1e-6).unwrap_or_else(F::epsilon),
            scale: true,
            _marker: std::marker::PhantomData,
        }
    }

    /// Builder: set the NIPALS iteration cap.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Builder: set the NIPALS convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Builder: enable or disable unit-variance scaling.
    #[must_use]
    pub fn with_scale(mut self, scale: bool) -> Self {
        self.scale = scale;
        self
    }

    /// Number of components this estimator will extract.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }
}
1657
/// A fitted [`CCA`] model: NIPALS factors plus the training
/// centring/scaling statistics.
#[derive(Debug, Clone)]
pub struct FittedCCA<F> {
    /// X weights W.
    x_weights_: Array2<F>,
    /// X loadings P.
    x_loadings_: Array2<F>,
    /// Y loadings Q.
    y_loadings_: Array2<F>,
    /// X scores T from training.
    x_scores_: Array2<F>,
    /// Y scores U from training.
    y_scores_: Array2<F>,
    /// NIPALS iteration counts per component.
    n_iter_: Vec<usize>,
    /// Per-column means of the training X.
    x_mean_: Array1<F>,
    /// Per-column means of the training Y.
    y_mean_: Array1<F>,
    /// Per-column std devs of the training X (None when scale = false).
    x_std_: Option<Array1<F>>,
    /// Per-column std devs of the training Y (None when scale = false).
    y_std_: Option<Array1<F>>,
}
1685
impl<F: Float + Send + Sync + 'static> FittedCCA<F> {
    /// X weight matrix W (n_features_x × n_components).
    #[must_use]
    pub fn x_weights(&self) -> &Array2<F> {
        &self.x_weights_
    }

    /// X loading matrix P (n_features_x × n_components).
    #[must_use]
    pub fn x_loadings(&self) -> &Array2<F> {
        &self.x_loadings_
    }

    /// Y loading matrix Q (n_features_y × n_components).
    #[must_use]
    pub fn y_loadings(&self) -> &Array2<F> {
        &self.y_loadings_
    }

    /// Training X scores T (n_samples × n_components).
    #[must_use]
    pub fn x_scores(&self) -> &Array2<F> {
        &self.x_scores_
    }

    /// Training Y scores U (n_samples × n_components).
    #[must_use]
    pub fn y_scores(&self) -> &Array2<F> {
        &self.y_scores_
    }

    /// NIPALS iteration counts, one entry per component.
    #[must_use]
    pub fn n_iter(&self) -> &[usize] {
        &self.n_iter_
    }

    /// Project new Y data onto the Y loadings after applying the stored
    /// centring/scaling.
    ///
    /// # Errors
    /// `ShapeMismatch` if `y`'s column count differs from the training Y.
    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let yc = apply_centre_scale(y, &self.y_mean_, &self.y_std_, "FittedCCA::transform_y")?;
        Ok(yc.dot(&self.y_loadings_))
    }
}
1733
1734impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for CCA<F> {
1735 type Fitted = FittedCCA<F>;
1736 type Error = FerroError;
1737
1738 fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedCCA<F>, FerroError> {
1747 let (n_samples_x, n_features_x) = x.dim();
1748 let (n_samples_y, n_features_y) = y.dim();
1749
1750 if n_samples_x != n_samples_y {
1751 return Err(FerroError::ShapeMismatch {
1752 expected: vec![n_samples_x, n_features_y],
1753 actual: vec![n_samples_y, n_features_y],
1754 context: "CCA::fit: X and Y must have the same number of rows".into(),
1755 });
1756 }
1757
1758 if self.n_components == 0 {
1759 return Err(FerroError::InvalidParameter {
1760 name: "n_components".into(),
1761 reason: "must be at least 1".into(),
1762 });
1763 }
1764
1765 let max_components = n_features_x.min(n_features_y).min(n_samples_x);
1766 if self.n_components > max_components {
1767 return Err(FerroError::InvalidParameter {
1768 name: "n_components".into(),
1769 reason: format!(
1770 "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
1771 self.n_components, max_components
1772 ),
1773 });
1774 }
1775
1776 if n_samples_x < 2 {
1777 return Err(FerroError::InsufficientSamples {
1778 required: 2,
1779 actual: n_samples_x,
1780 context: "CCA::fit requires at least 2 samples".into(),
1781 });
1782 }
1783
1784 let (xc, x_mean, x_std) = centre_scale(x, self.scale);
1785 let (yc, y_mean, y_std) = centre_scale(y, self.scale);
1786
1787 let result = nipals(
1788 &xc,
1789 &yc,
1790 self.n_components,
1791 self.max_iter,
1792 self.tol,
1793 NipalsMode::Canonical,
1794 ScoreNorm::UnitVariance,
1795 )?;
1796
1797 Ok(FittedCCA {
1798 x_weights_: result.x_weights,
1799 x_loadings_: result.x_loadings,
1800 y_loadings_: result.y_loadings,
1801 x_scores_: result.x_scores,
1802 y_scores_: result.y_scores,
1803 n_iter_: result.n_iter,
1804 x_mean_: x_mean,
1805 y_mean_: y_mean,
1806 x_std_: x_std,
1807 y_std_: y_std,
1808 })
1809 }
1810}
1811
1812impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedCCA<F> {
1813 type Output = Array2<F>;
1814 type Error = FerroError;
1815
1816 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1822 let xc = apply_centre_scale(x, &self.x_mean_, &self.x_std_, "FittedCCA::transform")?;
1823
1824 let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1825 let ptw_inv = invert_square(&ptw)?;
1826 let rotation = self.x_weights_.dot(&ptw_inv);
1827 Ok(xc.dot(&rotation))
1828 }
1829}
1830
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // ---------------------------------------------------------------
    // PLSSVD
    // ---------------------------------------------------------------

    #[test]
    fn test_plssvd_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 1));
    }

    #[test]
    fn test_plssvd_two_components() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(2);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (4, 2));
    }

    #[test]
    fn test_plssvd_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_plssvd_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f64>::new(1).with_scale(false);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_plssvd_x_weights_shape() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(2);
        let fitted = svd.fit(&x, &y).unwrap();
        // Weights are (n_features, n_components) per block.
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.y_weights().dim(), (2, 2));
    }

    #[test]
    fn test_plssvd_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0]];
        let svd = PLSSVD::<f64>::new(0);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_too_many_components() {
        // y has only 1 feature, so at most 1 component is possible.
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let svd = PLSSVD::<f64>::new(2);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0]];
        let svd = PLSSVD::<f64>::new(1);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_insufficient_samples() {
        let x = array![[1.0, 2.0]];
        let y = array![[1.0]];
        let svd = PLSSVD::<f64>::new(1);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        // 3 columns vs the 2 the model was fitted on.
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plssvd_n_components_getter() {
        let svd = PLSSVD::<f64>::new(3);
        assert_eq!(svd.n_components(), 3);
    }

    #[test]
    fn test_plssvd_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![[1.0f32], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f32>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // ---------------------------------------------------------------
    // PLSRegression
    // ---------------------------------------------------------------

    #[test]
    fn test_plsregression_basic_fit_predict() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_prediction_quality() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        // y = x1 + x2 exactly, so a single component should fit it to
        // numerical precision.
        let y = array![[3.0], [7.0], [11.0], [15.0], [19.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();

        for (pred, actual) in y_pred.column(0).iter().zip(y.column(0).iter()) {
            assert_abs_diff_eq!(pred, actual, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_plsregression_multi_target() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let pls = PLSRegression::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.dim(), (5, 2));
    }

    #[test]
    fn test_plsregression_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_coefficients_shape() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let pls = PLSRegression::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        // Coefficient matrix is (n_features, n_targets).
        assert_eq!(fitted.coefficients().dim(), (3, 2));
    }

    #[test]
    fn test_plsregression_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1).with_scale(false);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_builder() {
        let pls = PLSRegression::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(pls.n_components(), 2);
    }

    #[test]
    fn test_plsregression_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0]];
        let pls = PLSRegression::<f64>::new(0);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_too_many_components() {
        // y has only 1 feature, so at most 1 component is possible.
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSRegression::<f64>::new(2);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0]];
        let pls = PLSRegression::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_insufficient_samples() {
        let x = array![[1.0, 2.0]];
        let y = array![[1.0]];
        let pls = PLSRegression::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_predict_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_plsregression_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plsregression_x_scores_shape() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0], [2.0], [3.0], [4.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        assert_eq!(fitted.x_scores().dim(), (4, 1));
        assert_eq!(fitted.y_scores().dim(), (4, 1));
        assert_eq!(fitted.n_iter().len(), 1);
    }

    #[test]
    fn test_plsregression_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![[1.0f32], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f32>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.ncols(), 1);
    }

    // ---------------------------------------------------------------
    // PLSCanonical
    // ---------------------------------------------------------------

    #[test]
    fn test_plscanonical_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let pls = PLSCanonical::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 2));
    }

    #[test]
    fn test_plscanonical_single_component() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_plscanonical_scores_shape() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        // Scores are (n_samples, n_components); weights/loadings are
        // (n_features, n_components) per block.
        assert_eq!(fitted.x_scores().dim(), (3, 2));
        assert_eq!(fitted.y_scores().dim(), (3, 2));
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.x_loadings().dim(), (3, 2));
        assert_eq!(fitted.y_loadings().dim(), (2, 2));
    }

    #[test]
    fn test_plscanonical_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_plscanonical_builder() {
        let pls = PLSCanonical::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(pls.n_components(), 2);
    }

    #[test]
    fn test_plscanonical_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0]];
        let pls = PLSCanonical::<f64>::new(0);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSCanonical::<f64>::new(2);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plscanonical_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![
            [1.0f32, 0.5],
            [2.0, 1.0],
            [3.0, 1.5],
            [4.0, 2.0],
            [5.0, 2.5],
        ];
        let pls = PLSCanonical::<f32>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // ---------------------------------------------------------------
    // CCA
    // ---------------------------------------------------------------

    #[test]
    fn test_cca_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let cca = CCA::<f64>::new(2);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 2));
    }

    #[test]
    fn test_cca_single_component() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_cca_scores_shape() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(2);
        let fitted = cca.fit(&x, &y).unwrap();
        assert_eq!(fitted.x_scores().dim(), (3, 2));
        assert_eq!(fitted.y_scores().dim(), (3, 2));
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.x_loadings().dim(), (3, 2));
        assert_eq!(fitted.y_loadings().dim(), (2, 2));
    }

    #[test]
    fn test_cca_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_cca_builder() {
        let cca = CCA::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(cca.n_components(), 2);
    }

    #[test]
    fn test_cca_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0]];
        let cca = CCA::<f64>::new(0);
        assert!(cca.fit(&x, &y).is_err());
    }

    #[test]
    fn test_cca_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let cca = CCA::<f64>::new(2);
        assert!(cca.fit(&x, &y).is_err());
    }

    #[test]
    fn test_cca_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(1);
        assert!(cca.fit(&x, &y).is_err());
    }

    #[test]
    fn test_cca_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_cca_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![
            [1.0f32, 0.5],
            [2.0, 1.0],
            [3.0, 1.5],
            [4.0, 2.0],
            [5.0, 2.5],
        ];
        let cca = CCA::<f32>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // ---------------------------------------------------------------
    // Cross-algorithm checks and helper functions
    // ---------------------------------------------------------------

    #[test]
    fn test_pls_regression_and_canonical_give_different_scores() {
        let x = array![
            [1.0, 2.0, 0.5],
            [3.0, 1.0, 2.5],
            [5.0, 6.0, 1.0],
            [7.0, 3.0, 4.5],
            [9.0, 10.0, 2.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];

        let pls_reg = PLSRegression::<f64>::new(2);
        let fitted_reg = pls_reg.fit(&x, &y).unwrap();
        let scores_reg = fitted_reg.transform(&x).unwrap();

        let pls_can = PLSCanonical::<f64>::new(2);
        let fitted_can = pls_can.fit(&x, &y).unwrap();
        let scores_can = fitted_can.transform(&x).unwrap();

        let diff: f64 = scores_reg
            .iter()
            .zip(scores_can.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert_eq!(scores_reg.dim(), scores_can.dim());
        // NOTE(review): despite the test name, this only checks that the
        // score difference is finite, not that it is non-zero. Tightening
        // it to `diff > tol` may be unreliable (the Y columns here are
        // exactly collinear) — confirm intent before strengthening.
        assert!(diff.is_finite());
    }

    #[test]
    fn test_centre_scale_helper() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (xc, mean, std_dev) = centre_scale(&x, true);
        assert_abs_diff_eq!(mean[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(mean[1], 4.0, epsilon = 1e-10);
        assert!(std_dev.is_some());

        // After centring, each column of xc must have zero mean.
        let col_mean_0: f64 = xc.column(0).iter().sum::<f64>() / 3.0;
        assert_abs_diff_eq!(col_mean_0, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_centre_scale_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (_xc, _mean, std_dev) = centre_scale(&x, false);
        assert!(std_dev.is_none());
    }

    #[test]
    fn test_invert_square_identity() {
        // The identity matrix is its own inverse.
        let eye = Array2::<f64>::from_shape_fn((3, 3), |(i, j)| if i == j { 1.0 } else { 0.0 });
        let inv = invert_square(&eye).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(inv[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_invert_square_2x2() {
        // A * A^-1 must reproduce the identity.
        let a = array![[4.0, 7.0], [2.0, 6.0]];
        let inv = invert_square(&a).unwrap();
        let prod = a.dot(&inv);
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(prod[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }
}