1use ferrolearn_core::backend::Backend;
39use ferrolearn_core::backend_faer::NdarrayFaerBackend;
40use ferrolearn_core::error::FerroError;
41use ferrolearn_core::traits::{Fit, Predict, Transform};
42use ndarray::{Array1, Array2};
43use num_traits::Float;
44use std::any::TypeId;
45
/// Result of a thin SVD: `(U, singular values, Vᵀ)` or a backend error.
type SvdResult<F> = Result<(Array2<F>, Array1<F>, Array2<F>), FerroError>;
48
49fn centre_scale<F: Float + Send + Sync + 'static>(
58 x: &Array2<F>,
59 scale: bool,
60) -> (Array2<F>, Array1<F>, Option<Array1<F>>) {
61 let (n_samples, n_features) = x.dim();
62 let n_f = F::from(n_samples).unwrap();
63
64 let mean = Array1::from_shape_fn(n_features, |j| {
66 x.column(j).iter().copied().fold(F::zero(), |a, b| a + b) / n_f
67 });
68
69 let mut xc = x.to_owned();
71 for mut row in xc.rows_mut() {
72 for (v, &m) in row.iter_mut().zip(mean.iter()) {
73 *v = *v - m;
74 }
75 }
76
77 if scale {
78 let n_minus_1 = F::from(n_samples.saturating_sub(1).max(1)).unwrap();
79 let std_dev = Array1::from_shape_fn(n_features, |j| {
80 let var = xc
81 .column(j)
82 .iter()
83 .copied()
84 .fold(F::zero(), |a, b| a + b * b)
85 / n_minus_1;
86 let s = var.sqrt();
87 if s < F::epsilon() { F::one() } else { s }
88 });
89 for mut row in xc.rows_mut() {
90 for (v, &s) in row.iter_mut().zip(std_dev.iter()) {
91 *v = *v / s;
92 }
93 }
94 (xc, mean, Some(std_dev))
95 } else {
96 (xc, mean, None)
97 }
98}
99
100fn apply_centre_scale<F: Float + Send + Sync + 'static>(
102 x: &Array2<F>,
103 mean: &Array1<F>,
104 std_dev: &Option<Array1<F>>,
105 context: &str,
106) -> Result<Array2<F>, FerroError> {
107 if x.ncols() != mean.len() {
108 return Err(FerroError::ShapeMismatch {
109 expected: vec![x.nrows(), mean.len()],
110 actual: vec![x.nrows(), x.ncols()],
111 context: context.into(),
112 });
113 }
114 let mut xc = x.to_owned();
115 for mut row in xc.rows_mut() {
116 for (v, &m) in row.iter_mut().zip(mean.iter()) {
117 *v = *v - m;
118 }
119 }
120 if let Some(ref s) = *std_dev {
121 for mut row in xc.rows_mut() {
122 for (v, &sd) in row.iter_mut().zip(s.iter()) {
123 *v = *v / sd;
124 }
125 }
126 }
127 Ok(xc)
128}
129
/// Compute a thin SVD of `a`, dispatching on the concrete float type:
///
/// * `F == f64` — call the faer backend directly via a zero-cost pointer
///   reinterpretation.
/// * `F == f32` — widen to `f64`, run the backend, narrow the factors back.
/// * otherwise — fall back to a pure-Rust Jacobi eigendecomposition.
fn svd_dispatch<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> SvdResult<F> {
    if TypeId::of::<F>() == TypeId::of::<f64>() {
        // SAFETY: the TypeId check proves F is exactly f64, so Array2<F>
        // and Array2<f64> are the same type; the cast is a no-op.
        let a_f64: &Array2<f64> = unsafe { &*(std::ptr::from_ref(a).cast::<Array2<f64>>()) };
        let (u, s, vt) = NdarrayFaerBackend::svd(a_f64)?;
        let k = s.len();
        // Keep only the thin factors: first k columns of U, first k rows of Vᵀ.
        let u_thin = u.slice(ndarray::s![.., ..k]).to_owned();
        let vt_thin = vt.slice(ndarray::s![..k, ..]).to_owned();

        // SAFETY: F == f64; transmute_copy moves ownership of the buffers
        // into the F-typed arrays, and the mem::forget calls below prevent
        // the originals from double-freeing them.
        let u_f: Array2<F> = unsafe { std::mem::transmute_copy::<Array2<f64>, Array2<F>>(&u_thin) };
        let s_f: Array1<F> = unsafe { std::mem::transmute_copy::<Array1<f64>, Array1<F>>(&s) };
        let vt_f: Array2<F> =
            unsafe { std::mem::transmute_copy::<Array2<f64>, Array2<F>>(&vt_thin) };
        std::mem::forget(u_thin);
        std::mem::forget(s);
        std::mem::forget(vt_thin);
        Ok((u_f, s_f, vt_f))
    } else if TypeId::of::<F>() == TypeId::of::<f32>() {
        let (m, n) = a.dim();
        // Widen to f64 (the backend is f64-only), then convert results back.
        let a_f64 =
            Array2::<f64>::from_shape_fn((m, n), |(i, j)| a[[i, j]].to_f64().unwrap_or(0.0));
        let (u64, s64, vt64) = NdarrayFaerBackend::svd(&a_f64)?;
        let k = s64.len();
        let u_thin = u64.slice(ndarray::s![.., ..k]).to_owned();
        let vt_thin = vt64.slice(ndarray::s![..k, ..]).to_owned();

        let u_f =
            Array2::<F>::from_shape_fn(u_thin.dim(), |(i, j)| F::from(u_thin[[i, j]]).unwrap());
        let s_f = Array1::<F>::from_shape_fn(s64.len(), |i| F::from(s64[i]).unwrap());
        let vt_f =
            Array2::<F>::from_shape_fn(vt_thin.dim(), |(i, j)| F::from(vt_thin[[i, j]]).unwrap());
        Ok((u_f, s_f, vt_f))
    } else {
        // Generic float types: no backend available, use the eigen fallback.
        svd_via_eigen(a)
    }
}
183
/// Thin SVD fallback for generic floats via an eigendecomposition of AᵀA.
///
/// Singular values are the square roots of AᵀA's eigenvalues (clamped at
/// zero), V holds the corresponding eigenvectors, and U is recovered as
/// A·V·S⁻¹ for the non-degenerate singular values. Columns belonging to
/// singular values below epsilon are left as zeros in U.
fn svd_via_eigen<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> SvdResult<F> {
    let (m, n) = a.dim();
    let k = m.min(n);

    let ata = a.t().dot(a);

    // Generous sweep budget for the Jacobi iteration below.
    let max_iter = n * n * 100 + 1000;
    let (eigenvalues, eigenvectors) = jacobi_eigen_symmetric(&ata, max_iter)?;

    // Order eigenpairs by descending eigenvalue (largest singular value first).
    let mut indices: Vec<usize> = (0..n).collect();
    indices.sort_by(|&i, &j| {
        eigenvalues[j]
            .partial_cmp(&eigenvalues[i])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Take the top-k eigenpairs; negative eigenvalues (numerical noise)
    // map to zero singular values.
    let mut s = Array1::<F>::zeros(k);
    let mut v = Array2::<F>::zeros((n, k));
    for (col, &idx) in indices.iter().take(k).enumerate() {
        let eval = eigenvalues[idx];
        s[col] = if eval > F::zero() {
            eval.sqrt()
        } else {
            F::zero()
        };
        for row in 0..n {
            v[[row, col]] = eigenvectors[[row, idx]];
        }
    }

    // U = A·V·S⁻¹, skipping near-zero singular values.
    let av = a.dot(&v);
    let mut u = Array2::<F>::zeros((m, k));
    for col in 0..k {
        if s[col] > F::epsilon() {
            let inv_s = F::one() / s[col];
            for row in 0..m {
                u[[row, col]] = av[[row, col]] * inv_s;
            }
        }
    }

    // Return Vᵀ to match the backend's convention.
    let mut vt = Array2::<F>::zeros((k, n));
    for i in 0..k {
        for j in 0..n {
            vt[[i, j]] = v[[j, i]];
        }
    }

    Ok((u, s, vt))
}
241
/// Classical Jacobi eigendecomposition for a symmetric matrix.
///
/// Repeatedly applies Givens rotations that zero the largest off-diagonal
/// element until all off-diagonal entries fall below a fixed tolerance.
/// Returns `(eigenvalues, eigenvectors)` with eigenvectors as columns.
///
/// # Errors
/// `FerroError::ConvergenceFailure` when `max_iter` sweeps are exhausted
/// before the off-diagonal mass is annihilated.
fn jacobi_eigen_symmetric<F: Float + Send + Sync + 'static>(
    a: &Array2<F>,
    max_iter: usize,
) -> Result<(Array1<F>, Array2<F>), FerroError> {
    let n = a.nrows();
    let mut mat = a.to_owned();
    // Accumulator for the eigenvectors, initialised to the identity.
    let mut v = Array2::<F>::zeros((n, n));
    for i in 0..n {
        v[[i, i]] = F::one();
    }

    let tol = F::from(1e-12).unwrap_or(F::epsilon());

    for _iteration in 0..max_iter {
        // Locate the largest off-diagonal element (p, q), p < q.
        let mut max_off = F::zero();
        let mut p = 0;
        let mut q = 1;
        for i in 0..n {
            for j in (i + 1)..n {
                let val = mat[[i, j]].abs();
                if val > max_off {
                    max_off = val;
                    p = i;
                    q = j;
                }
            }
        }

        // Converged: the matrix is (numerically) diagonal.
        if max_off < tol {
            let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
            return Ok((eigenvalues, v));
        }

        let app = mat[[p, p]];
        let aqq = mat[[q, q]];
        let apq = mat[[p, q]];

        // Rotation angle that zeroes mat[p, q]; the stable tangent formula
        // avoids cancellation, and equal diagonal entries get a 45° rotation.
        let theta = if (app - aqq).abs() < tol {
            F::from(std::f64::consts::FRAC_PI_4).unwrap_or(F::one())
        } else {
            let tau = (aqq - app) / (F::from(2.0).unwrap() * apq);
            let t = if tau >= F::zero() {
                F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            } else {
                -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            };
            t.atan()
        };

        let c = theta.cos();
        let s = theta.sin();

        // Apply the rotation to rows/columns p and q, keeping symmetry.
        let mut new_mat = mat.clone();
        for i in 0..n {
            if i != p && i != q {
                let mip = mat[[i, p]];
                let miq = mat[[i, q]];
                new_mat[[i, p]] = c * mip - s * miq;
                new_mat[[p, i]] = new_mat[[i, p]];
                new_mat[[i, q]] = s * mip + c * miq;
                new_mat[[q, i]] = new_mat[[i, q]];
            }
        }

        new_mat[[p, p]] = c * c * app - F::from(2.0).unwrap() * s * c * apq + s * s * aqq;
        new_mat[[q, q]] = s * s * app + F::from(2.0).unwrap() * s * c * apq + c * c * aqq;
        // The targeted off-diagonal pair is zeroed exactly by construction.
        new_mat[[p, q]] = F::zero();
        new_mat[[q, p]] = F::zero();

        mat = new_mat;

        // Accumulate the same rotation into the eigenvector matrix.
        for i in 0..n {
            let vip = v[[i, p]];
            let viq = v[[i, q]];
            v[[i, p]] = c * vip - s * viq;
            v[[i, q]] = s * vip + c * viq;
        }
    }

    Err(FerroError::ConvergenceFailure {
        iterations: max_iter,
        message: "Jacobi eigendecomposition did not converge in cross_decomposition SVD fallback"
            .into(),
    })
}
328
329fn norm<F: Float>(v: &Array1<F>) -> F {
335 v.iter().copied().fold(F::zero(), |a, b| a + b * b).sqrt()
336}
337
338fn dot<F: Float>(a: &Array1<F>, b: &Array1<F>) -> F {
340 a.iter()
341 .copied()
342 .zip(b.iter().copied())
343 .fold(F::zero(), |acc, (x, y)| acc + x * y)
344}
345
/// Invert a square matrix by Gauss–Jordan elimination with partial pivoting.
///
/// Near-singular pivots are not rejected: a pivot below the relative
/// tolerance is replaced by the tolerance itself (Tikhonov-style nudge),
/// so the routine always returns *an* inverse — callers rely on this for
/// ill-conditioned PᵀW products.
///
/// # Errors
/// `FerroError::ShapeMismatch` when `a` is not square.
fn invert_square<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> Result<Array2<F>, FerroError> {
    let n = a.nrows();
    if n != a.ncols() {
        return Err(FerroError::ShapeMismatch {
            expected: vec![n, n],
            actual: vec![a.nrows(), a.ncols()],
            context: "invert_square: matrix must be square".into(),
        });
    }

    // Build the augmented matrix [A | I].
    let mut aug = Array2::<F>::zeros((n, 2 * n));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n + i]] = F::one();
    }

    // Relative tolerance scaled by the largest magnitude in A, with an
    // absolute floor so an all-zero matrix still gets a non-zero pivot.
    let max_abs = a.iter().copied().fold(F::zero(), |m, v| {
        let abs = v.abs();
        if abs > m { abs } else { m }
    });
    let regularise_tol =
        max_abs * F::from(1e-12).unwrap_or(F::epsilon()) + F::from(1e-15).unwrap_or(F::epsilon());

    // Forward elimination with partial (row) pivoting.
    for col in 0..n {
        let mut max_val = aug[[col, col]].abs();
        let mut max_row = col;
        for row in (col + 1)..n {
            let val = aug[[row, col]].abs();
            if val > max_val {
                max_val = val;
                max_row = row;
            }
        }

        if max_val < regularise_tol {
            // Singular (or nearly so): regularise the pivot instead of failing.
            aug[[col, col]] = regularise_tol;
        } else {
            if max_row != col {
                // Swap the pivot row into place.
                for j in 0..(2 * n) {
                    let tmp = aug[[col, j]];
                    aug[[col, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = tmp;
                }
            }
        }

        // Eliminate entries below the pivot.
        let pivot = aug[[col, col]];
        for row in (col + 1)..n {
            let factor = aug[[row, col]] / pivot;
            for j in col..(2 * n) {
                let above = aug[[col, j]];
                aug[[row, j]] = aug[[row, j]] - factor * above;
            }
        }
    }

    // Back substitution: normalise each pivot row, clear entries above it.
    for col in (0..n).rev() {
        let pivot = aug[[col, col]];
        for j in 0..(2 * n) {
            aug[[col, j]] = aug[[col, j]] / pivot;
        }
        for row in 0..col {
            let factor = aug[[row, col]];
            for j in 0..(2 * n) {
                let below = aug[[col, j]];
                aug[[row, j]] = aug[[row, j]] - factor * below;
            }
        }
    }

    // The right half of the augmented matrix now holds A⁻¹.
    let mut inv = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            inv[[i, j]] = aug[[i, n + j]];
        }
    }
    Ok(inv)
}
445
/// Partial Least Squares via SVD: extracts paired X/Y directions from a
/// single SVD of the centred cross-covariance matrix XᵀY (no deflation).
#[derive(Debug, Clone)]
pub struct PLSSVD<F> {
    /// Number of latent components to extract.
    n_components: usize,
    /// Scale columns to unit variance in addition to centring.
    scale: bool,
    /// Ties the float type parameter `F` to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
486
impl<F: Float + Send + Sync + 'static> PLSSVD<F> {
    /// Create a new estimator with `n_components` and scaling enabled.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            scale: true,
            _marker: std::marker::PhantomData,
        }
    }

    /// Builder: enable/disable unit-variance scaling (default: enabled).
    #[must_use]
    pub fn with_scale(mut self, scale: bool) -> Self {
        self.scale = scale;
        self
    }

    /// Number of components this estimator will extract.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }
}
511
/// Fitted state produced by [`PLSSVD::fit`].
#[derive(Debug, Clone)]
pub struct FittedPLSSVD<F> {
    /// Left singular vectors of XᵀY (X-side projection weights).
    x_weights_: Array2<F>,
    /// Right singular vectors of XᵀY (Y-side projection weights).
    y_weights_: Array2<F>,
    /// Per-feature means of X at fit time.
    x_mean_: Array1<F>,
    /// Per-feature means of Y at fit time.
    y_mean_: Array1<F>,
    /// Per-feature stds of X when scaling was enabled, else `None`.
    x_std_: Option<Array1<F>>,
    /// Per-feature stds of Y when scaling was enabled, else `None`.
    y_std_: Option<Array1<F>>,
}
531
impl<F: Float + Send + Sync + 'static> FittedPLSSVD<F> {
    /// X-side projection weights.
    #[must_use]
    pub fn x_weights(&self) -> &Array2<F> {
        &self.x_weights_
    }

    /// Y-side projection weights.
    #[must_use]
    pub fn y_weights(&self) -> &Array2<F> {
        &self.y_weights_
    }

    /// Per-feature X means learned at fit time.
    #[must_use]
    pub fn x_mean(&self) -> &Array1<F> {
        &self.x_mean_
    }

    /// Per-feature Y means learned at fit time.
    #[must_use]
    pub fn y_mean(&self) -> &Array1<F> {
        &self.y_mean_
    }

    /// Project `y` onto the Y weights after applying the stored
    /// centring/scaling.
    ///
    /// # Errors
    /// `ShapeMismatch` when `y`'s column count differs from the fitted one.
    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let yc = apply_centre_scale(y, &self.y_mean_, &self.y_std_, "FittedPLSSVD::transform_y")?;
        Ok(yc.dot(&self.y_weights_))
    }
}
567
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSSVD<F> {
    type Fitted = FittedPLSSVD<F>;
    type Error = FerroError;

    /// Fit the PLSSVD model: validate shapes/parameters, centre (and
    /// optionally scale) X and Y, then take the thin SVD of XcᵀYc and keep
    /// the first `n_components` singular directions as paired weights.
    ///
    /// # Errors
    /// `ShapeMismatch` (row counts differ), `InvalidParameter`
    /// (`n_components` is 0 or too large), `InsufficientSamples` (< 2 rows),
    /// or any error from the SVD backend.
    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSSVD<F>, FerroError> {
        let (n_samples_x, n_features_x) = x.dim();
        let (n_samples_y, n_features_y) = y.dim();

        if n_samples_x != n_samples_y {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples_x, n_features_y],
                actual: vec![n_samples_y, n_features_y],
                context: "PLSSVD::fit: X and Y must have the same number of rows".into(),
            });
        }

        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }

        // The cross-covariance matrix has rank at most min(p, q).
        let max_components = n_features_x.min(n_features_y);
        if self.n_components > max_components {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds min(n_features_x, n_features_y) ({})",
                    self.n_components, max_components
                ),
            });
        }

        if n_samples_x < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples_x,
                context: "PLSSVD::fit requires at least 2 samples".into(),
            });
        }

        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
        let (yc, y_mean, y_std) = centre_scale(y, self.scale);

        // Cross-covariance (up to a constant factor, irrelevant to the SVD).
        let c = xc.t().dot(&yc);

        let (u, _s, vt) = svd_dispatch(&c)?;

        // X weights: leading left singular vectors; Y weights: leading
        // right singular vectors (columns of V).
        let nc = self.n_components;
        let x_weights = u.slice(ndarray::s![.., ..nc]).to_owned();
        let y_weights = vt.t().slice(ndarray::s![.., ..nc]).to_owned();

        Ok(FittedPLSSVD {
            x_weights_: x_weights,
            y_weights_: y_weights,
            x_mean_: x_mean,
            y_mean_: y_mean,
            x_std_: x_std,
            y_std_: y_std,
        })
    }
}
644
645impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSSVD<F> {
646 type Output = Array2<F>;
647 type Error = FerroError;
648
649 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
655 let xc = apply_centre_scale(x, &self.x_mean_, &self.x_std_, "FittedPLSSVD::transform")?;
656 Ok(xc.dot(&self.x_weights_))
657 }
658}
659
/// Deflation mode for the NIPALS inner loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NipalsMode {
    /// Deflate Y with the X scores (PLS regression, "PLS2").
    Regression,
    /// Deflate Y with its own scores (symmetric/canonical PLS).
    Canonical,
}
672
/// Normalisation applied to NIPALS weights/scores.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScoreNorm {
    /// Leave scores unnormalised (PLS regression/canonical convention).
    None,
    /// Normalise Y weights and standardise scores (CCA convention).
    UnitVariance,
}
681
/// Output of the NIPALS algorithm; all matrices have `n_components` columns.
#[derive(Debug, Clone)]
struct NipalsResult<F> {
    /// X projection weights W.
    x_weights: Array2<F>,
    /// X loadings P.
    x_loadings: Array2<F>,
    /// X scores T (one row per sample).
    x_scores: Array2<F>,
    /// Y loadings Q.
    y_loadings: Array2<F>,
    /// Y scores U (one row per sample).
    y_scores: Array2<F>,
    /// Inner-loop iteration count per component.
    n_iter: Vec<usize>,
}
702
/// NIPALS (Nonlinear Iterative Partial Least Squares) for `n_components`
/// latent components.
///
/// For each component a power-iteration inner loop alternates between the
/// X and Y blocks until the Y-score vector `u` stabilises (relative change
/// below `tol`), then weights/loadings/scores are finalised and both blocks
/// are deflated according to `mode`. `score_norm` selects the CCA-style
/// unit-variance convention.
///
/// Inputs `x` and `y` are assumed already centred (and scaled) by the caller.
///
/// # Errors
/// `ConvergenceFailure` when an inner loop exhausts `max_iter` without
/// converging (only enforced for multi-column Y — a single-column Y
/// terminates in effectively one pass and is never flagged).
fn nipals<F: Float + Send + Sync + 'static>(
    x: &Array2<F>,
    y: &Array2<F>,
    n_components: usize,
    max_iter: usize,
    tol: F,
    mode: NipalsMode,
    score_norm: ScoreNorm,
) -> Result<NipalsResult<F>, FerroError> {
    let (n_samples, n_features_x) = x.dim();
    let n_features_y = y.ncols();

    // Working copies: deflated in place after each extracted component.
    let mut xk = x.to_owned();
    let mut yk = y.to_owned();

    let mut x_weights = Array2::<F>::zeros((n_features_x, n_components));
    let mut x_loadings = Array2::<F>::zeros((n_features_x, n_components));
    let mut x_scores = Array2::<F>::zeros((n_samples, n_components));
    let mut y_loadings = Array2::<F>::zeros((n_features_y, n_components));
    let mut y_scores = Array2::<F>::zeros((n_samples, n_components));
    let mut n_iter_vec = Vec::with_capacity(n_components);

    for k in 0..n_components {
        // Seed u with the Y column of largest sum of squares (most energy).
        let best_col = (0..n_features_y)
            .max_by(|&a, &b| {
                let var_a: F = yk
                    .column(a)
                    .iter()
                    .copied()
                    .fold(F::zero(), |s, v| s + v * v);
                let var_b: F = yk
                    .column(b)
                    .iter()
                    .copied()
                    .fold(F::zero(), |s, v| s + v * v);
                var_a
                    .partial_cmp(&var_b)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or(0);

        let mut u = yk.column(best_col).to_owned();

        let mut converged = false;
        let mut iters = 0;

        // Power-iteration inner loop: u -> w -> t -> q -> u until stable.
        for iteration in 0..max_iter {
            iters = iteration + 1;

            // w ∝ Xᵀu, normalised to unit length.
            let utu = dot(&u, &u);
            let mut w = xk.t().dot(&u);
            if utu > F::epsilon() {
                w.mapv_inplace(|v| v / utu);
            }
            let w_norm = norm(&w);
            if w_norm < F::epsilon() {
                // Degenerate direction: X block carries no signal for u.
                break;
            }
            w.mapv_inplace(|v| v / w_norm);

            // X scores for this candidate direction.
            let t = xk.dot(&w);

            // q ∝ Yᵀt.
            let ttt = dot(&t, &t);
            let mut q = yk.t().dot(&t);
            if ttt > F::epsilon() {
                q.mapv_inplace(|v| v / ttt);
            }

            if score_norm == ScoreNorm::UnitVariance {
                // CCA convention: Y weights are unit-length.
                let q_norm = norm(&q);
                if q_norm > F::epsilon() {
                    q.mapv_inplace(|v| v / q_norm);
                }
            }

            // Updated Y scores.
            let qtq = dot(&q, &q);
            let mut u_new = yk.dot(&q);
            if qtq > F::epsilon() {
                u_new.mapv_inplace(|v| v / qtq);
            }

            // Relative change of u drives the convergence test.
            let diff_norm = {
                let diff: Array1<F> = &u_new - &u;
                norm(&diff)
            };
            let u_new_norm = norm(&u_new);

            u = u_new;

            if u_new_norm > F::epsilon() && diff_norm / u_new_norm < tol {
                converged = true;
                // Recompute one clean, self-consistent pass from the
                // converged u so w, t, q, u all agree.
                let utu2 = dot(&u, &u);
                w = xk.t().dot(&u);
                if utu2 > F::epsilon() {
                    w.mapv_inplace(|v| v / utu2);
                }
                let w_norm2 = norm(&w);
                if w_norm2 > F::epsilon() {
                    w.mapv_inplace(|v| v / w_norm2);
                }
                let t_final = xk.dot(&w);
                let ttt2 = dot(&t_final, &t_final);
                q = yk.t().dot(&t_final);
                if ttt2 > F::epsilon() {
                    q.mapv_inplace(|v| v / ttt2);
                }
                if score_norm == ScoreNorm::UnitVariance {
                    let q_norm2 = norm(&q);
                    if q_norm2 > F::epsilon() {
                        q.mapv_inplace(|v| v / q_norm2);
                    }
                }
                let qtq2 = dot(&q, &q);
                u = yk.dot(&q);
                if qtq2 > F::epsilon() {
                    u.mapv_inplace(|v| v / qtq2);
                }
                break;
            }
        }

        // Finalise this component from the last u (converged or not).
        let utu_final = dot(&u, &u);
        let mut w_final = xk.t().dot(&u);
        if utu_final > F::epsilon() {
            w_final.mapv_inplace(|v| v / utu_final);
        }
        let w_norm_final = norm(&w_final);
        if w_norm_final > F::epsilon() {
            w_final.mapv_inplace(|v| v / w_norm_final);
        }

        let mut t_final = xk.dot(&w_final);
        let ttt_final = dot(&t_final, &t_final);

        // X loadings: p = Xᵀt / (tᵀt).
        let mut p = xk.t().dot(&t_final);
        if ttt_final > F::epsilon() {
            p.mapv_inplace(|v| v / ttt_final);
        }

        // Y loadings: q = Yᵀt / (tᵀt).
        let mut q_final = yk.t().dot(&t_final);
        if ttt_final > F::epsilon() {
            q_final.mapv_inplace(|v| v / ttt_final);
        }

        if score_norm == ScoreNorm::UnitVariance {
            let q_norm = norm(&q_final);
            if q_norm > F::epsilon() {
                q_final.mapv_inplace(|v| v / q_norm);
            }
        }

        let qtq_final = dot(&q_final, &q_final);
        let mut u_final = yk.dot(&q_final);
        if qtq_final > F::epsilon() {
            u_final.mapv_inplace(|v| v / qtq_final);
        }

        if score_norm == ScoreNorm::UnitVariance {
            // Standardise the final scores to unit sample variance.
            let t_std = {
                let t_mean = t_final.iter().copied().fold(F::zero(), |a, b| a + b)
                    / F::from(n_samples).unwrap();
                let var = t_final
                    .iter()
                    .copied()
                    .fold(F::zero(), |a, b| a + (b - t_mean) * (b - t_mean))
                    / F::from(n_samples.saturating_sub(1).max(1)).unwrap();
                var.sqrt()
            };
            if t_std > F::epsilon() {
                t_final.mapv_inplace(|v| v / t_std);
            }

            let u_std = {
                let u_mean = u_final.iter().copied().fold(F::zero(), |a, b| a + b)
                    / F::from(n_samples).unwrap();
                let var = u_final
                    .iter()
                    .copied()
                    .fold(F::zero(), |a, b| a + (b - u_mean) * (b - u_mean))
                    / F::from(n_samples.saturating_sub(1).max(1)).unwrap();
                var.sqrt()
            };
            if u_std > F::epsilon() {
                u_final.mapv_inplace(|v| v / u_std);
            }
        }

        // Store this component's vectors as column k of each output matrix.
        x_weights.column_mut(k).assign(&w_final);
        x_loadings.column_mut(k).assign(&p);
        x_scores.column_mut(k).assign(&t_final);
        y_loadings.column_mut(k).assign(&q_final);
        y_scores.column_mut(k).assign(&u_final);

        // Deflate X: remove the rank-1 contribution t·pᵀ.
        for i in 0..n_samples {
            let ti = t_final[i];
            for j in 0..n_features_x {
                xk[[i, j]] = xk[[i, j]] - ti * p[j];
            }
        }

        // Deflate Y according to the requested mode.
        match mode {
            NipalsMode::Regression => {
                // Regression mode: remove t·qᵀ (Y explained by X scores).
                for i in 0..n_samples {
                    let ti = t_final[i];
                    for j in 0..n_features_y {
                        yk[[i, j]] = yk[[i, j]] - ti * q_final[j];
                    }
                }
            }
            NipalsMode::Canonical => {
                // Canonical mode: remove u·cᵀ where c = Yᵀu / (uᵀu).
                let utu_c = dot(&u_final, &u_final);
                let mut c = yk.t().dot(&u_final);
                if utu_c > F::epsilon() {
                    c.mapv_inplace(|v| v / utu_c);
                }
                for i in 0..n_samples {
                    let ui = u_final[i];
                    for j in 0..n_features_y {
                        yk[[i, j]] = yk[[i, j]] - ui * c[j];
                    }
                }
            }
        }

        n_iter_vec.push(iters);

        if !converged && n_features_y > 1 && iters == max_iter {
            return Err(FerroError::ConvergenceFailure {
                iterations: max_iter,
                message: format!("NIPALS did not converge for component {k}"),
            });
        }
    }

    Ok(NipalsResult {
        x_weights,
        x_loadings,
        x_scores,
        y_loadings,
        y_scores,
        n_iter: n_iter_vec,
    })
}
971
/// PLS regression (PLS2): NIPALS with regression-mode deflation, producing
/// a linear predictor from X to Y.
#[derive(Debug, Clone)]
pub struct PLSRegression<F> {
    /// Number of latent components to extract.
    n_components: usize,
    /// Maximum NIPALS inner-loop iterations per component.
    max_iter: usize,
    /// Relative convergence tolerance for the inner loop.
    tol: F,
    /// Scale columns to unit variance in addition to centring.
    scale: bool,
    /// Ties the float type parameter `F` to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
1018
impl<F: Float + Send + Sync + 'static> PLSRegression<F> {
    /// Create a new estimator with defaults: `max_iter = 500`,
    /// `tol = 1e-6`, scaling enabled.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_iter: 500,
            tol: F::from(1e-6).unwrap_or(F::epsilon()),
            scale: true,
            _marker: std::marker::PhantomData,
        }
    }

    /// Builder: set the maximum NIPALS iterations per component.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Builder: set the inner-loop convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Builder: enable/disable unit-variance scaling (default: enabled).
    #[must_use]
    pub fn with_scale(mut self, scale: bool) -> Self {
        self.scale = scale;
        self
    }

    /// Number of components this estimator will extract.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }
}
1061
/// Fitted state produced by [`PLSRegression::fit`].
#[derive(Debug, Clone)]
pub struct FittedPLSRegression<F> {
    /// X projection weights W.
    x_weights_: Array2<F>,
    /// X loadings P.
    x_loadings_: Array2<F>,
    /// Y loadings Q.
    y_loadings_: Array2<F>,
    /// Regression coefficients B = W (PᵀW)⁻¹ Qᵀ mapping centred X to
    /// centred/scaled Y.
    coefficients_: Array2<F>,
    /// Training X scores T.
    x_scores_: Array2<F>,
    /// Training Y scores U.
    y_scores_: Array2<F>,
    /// NIPALS iterations used per component.
    n_iter_: Vec<usize>,
    /// Per-feature means of X at fit time.
    x_mean_: Array1<F>,
    /// Per-feature means of Y at fit time.
    y_mean_: Array1<F>,
    /// Per-feature stds of X when scaling was enabled, else `None`.
    x_std_: Option<Array1<F>>,
    /// Per-feature stds of Y when scaling was enabled, else `None`.
    y_std_: Option<Array1<F>>,
}
1093
impl<F: Float + Send + Sync + 'static> FittedPLSRegression<F> {
    /// X projection weights W.
    #[must_use]
    pub fn x_weights(&self) -> &Array2<F> {
        &self.x_weights_
    }

    /// X loadings P.
    #[must_use]
    pub fn x_loadings(&self) -> &Array2<F> {
        &self.x_loadings_
    }

    /// Y loadings Q.
    #[must_use]
    pub fn y_loadings(&self) -> &Array2<F> {
        &self.y_loadings_
    }

    /// Regression coefficients applied to centred/scaled inputs in
    /// [`Predict::predict`].
    #[must_use]
    pub fn coefficients(&self) -> &Array2<F> {
        &self.coefficients_
    }

    /// Training X scores T.
    #[must_use]
    pub fn x_scores(&self) -> &Array2<F> {
        &self.x_scores_
    }

    /// Training Y scores U.
    #[must_use]
    pub fn y_scores(&self) -> &Array2<F> {
        &self.y_scores_
    }

    /// NIPALS iterations used per component.
    #[must_use]
    pub fn n_iter(&self) -> &[usize] {
        &self.n_iter_
    }
}
1139
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSRegression<F> {
    type Fitted = FittedPLSRegression<F>;
    type Error = FerroError;

    /// Fit PLS regression: validate inputs, centre/scale both blocks, run
    /// NIPALS in regression mode, then form the coefficient matrix
    /// B = W (PᵀW)⁻¹ Qᵀ.
    ///
    /// # Errors
    /// `ShapeMismatch`, `InvalidParameter`, `InsufficientSamples`,
    /// `ConvergenceFailure` (NIPALS), or errors from matrix inversion.
    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSRegression<F>, FerroError> {
        let (n_samples_x, n_features_x) = x.dim();
        let (n_samples_y, n_features_y) = y.dim();

        if n_samples_x != n_samples_y {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples_x, n_features_y],
                actual: vec![n_samples_y, n_features_y],
                context: "PLSRegression::fit: X and Y must have the same number of rows".into(),
            });
        }

        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }

        // Components are bounded by both feature counts and the sample count.
        let max_components = n_features_x.min(n_features_y).min(n_samples_x);
        if self.n_components > max_components {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
                    self.n_components, max_components
                ),
            });
        }

        if n_samples_x < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples_x,
                context: "PLSRegression::fit requires at least 2 samples".into(),
            });
        }

        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
        let (yc, y_mean, y_std) = centre_scale(y, self.scale);

        let result = nipals(
            &xc,
            &yc,
            self.n_components,
            self.max_iter,
            self.tol,
            NipalsMode::Regression,
            ScoreNorm::None,
        )?;

        // B = W (PᵀW)⁻¹ Qᵀ maps centred X directly to centred/scaled Y.
        let ptw = result.x_loadings.t().dot(&result.x_weights);
        let ptw_inv = invert_square(&ptw)?;
        let coefficients = result.x_weights.dot(&ptw_inv).dot(&result.y_loadings.t());

        Ok(FittedPLSRegression {
            x_weights_: result.x_weights,
            x_loadings_: result.x_loadings,
            y_loadings_: result.y_loadings,
            coefficients_: coefficients,
            x_scores_: result.x_scores,
            y_scores_: result.y_scores,
            n_iter_: result.n_iter,
            x_mean_: x_mean,
            y_mean_: y_mean,
            x_std_: x_std,
            y_std_: y_std,
        })
    }
}
1230
1231impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedPLSRegression<F> {
1232 type Output = Array2<F>;
1233 type Error = FerroError;
1234
1235 fn predict(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1243 let xc = apply_centre_scale(
1244 x,
1245 &self.x_mean_,
1246 &self.x_std_,
1247 "FittedPLSRegression::predict",
1248 )?;
1249
1250 let mut y_pred = xc.dot(&self.coefficients_);
1251
1252 if let Some(ref ys) = self.y_std_ {
1254 for mut row in y_pred.rows_mut() {
1255 for (v, &s) in row.iter_mut().zip(ys.iter()) {
1256 *v = *v * s;
1257 }
1258 }
1259 }
1260
1261 for mut row in y_pred.rows_mut() {
1263 for (v, &m) in row.iter_mut().zip(self.y_mean_.iter()) {
1264 *v = *v + m;
1265 }
1266 }
1267
1268 Ok(y_pred)
1269 }
1270}
1271
1272impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSRegression<F> {
1273 type Output = Array2<F>;
1274 type Error = FerroError;
1275
1276 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1284 let xc = apply_centre_scale(
1285 x,
1286 &self.x_mean_,
1287 &self.x_std_,
1288 "FittedPLSRegression::transform",
1289 )?;
1290
1291 let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1295 let ptw_inv = invert_square(&ptw)?;
1296 let rotation = self.x_weights_.dot(&ptw_inv);
1297 Ok(xc.dot(&rotation))
1298 }
1299}
1300
/// Canonical PLS: NIPALS with symmetric (canonical-mode) deflation of both
/// blocks.
#[derive(Debug, Clone)]
pub struct PLSCanonical<F> {
    /// Number of latent components to extract.
    n_components: usize,
    /// Maximum NIPALS inner-loop iterations per component.
    max_iter: usize,
    /// Relative convergence tolerance for the inner loop.
    tol: F,
    /// Scale columns to unit variance in addition to centring.
    scale: bool,
    /// Ties the float type parameter `F` to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
1345
impl<F: Float + Send + Sync + 'static> PLSCanonical<F> {
    /// Create a new estimator with defaults: `max_iter = 500`,
    /// `tol = 1e-6`, scaling enabled.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_iter: 500,
            tol: F::from(1e-6).unwrap_or(F::epsilon()),
            scale: true,
            _marker: std::marker::PhantomData,
        }
    }

    /// Builder: set the maximum NIPALS iterations per component.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Builder: set the inner-loop convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Builder: enable/disable unit-variance scaling (default: enabled).
    #[must_use]
    pub fn with_scale(mut self, scale: bool) -> Self {
        self.scale = scale;
        self
    }

    /// Number of components this estimator will extract.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }
}
1388
1389#[derive(Debug, Clone)]
1395pub struct FittedPLSCanonical<F> {
1396 x_weights_: Array2<F>,
1398 x_loadings_: Array2<F>,
1400 y_loadings_: Array2<F>,
1402 x_scores_: Array2<F>,
1404 y_scores_: Array2<F>,
1406 n_iter_: Vec<usize>,
1408 x_mean_: Array1<F>,
1410 y_mean_: Array1<F>,
1412 x_std_: Option<Array1<F>>,
1414 #[allow(dead_code)]
1416 y_std_: Option<Array1<F>>,
1417}
1418
impl<F: Float + Send + Sync + 'static> FittedPLSCanonical<F> {
    /// X projection weights W.
    #[must_use]
    pub fn x_weights(&self) -> &Array2<F> {
        &self.x_weights_
    }

    /// X loadings P.
    #[must_use]
    pub fn x_loadings(&self) -> &Array2<F> {
        &self.x_loadings_
    }

    /// Y loadings Q.
    #[must_use]
    pub fn y_loadings(&self) -> &Array2<F> {
        &self.y_loadings_
    }

    /// Training X scores T.
    #[must_use]
    pub fn x_scores(&self) -> &Array2<F> {
        &self.x_scores_
    }

    /// Training Y scores U.
    #[must_use]
    pub fn y_scores(&self) -> &Array2<F> {
        &self.y_scores_
    }

    /// NIPALS iterations used per component.
    #[must_use]
    pub fn n_iter(&self) -> &[usize] {
        &self.n_iter_
    }

    /// Project `y` onto the Y loadings after applying the stored
    /// centring/scaling.
    ///
    /// # Errors
    /// `ShapeMismatch` when `y`'s column count differs from the fitted one.
    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let yc = apply_centre_scale(
            y,
            &self.y_mean_,
            &self.y_std_,
            "FittedPLSCanonical::transform_y",
        )?;
        Ok(yc.dot(&self.y_loadings_))
    }
}
1471
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSCanonical<F> {
    type Fitted = FittedPLSCanonical<F>;
    type Error = FerroError;

    /// Fit canonical PLS: validate inputs, centre/scale both blocks, and
    /// run NIPALS with canonical-mode deflation.
    ///
    /// # Errors
    /// `ShapeMismatch`, `InvalidParameter`, `InsufficientSamples`, or
    /// `ConvergenceFailure` from NIPALS.
    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSCanonical<F>, FerroError> {
        let (n_samples_x, n_features_x) = x.dim();
        let (n_samples_y, n_features_y) = y.dim();

        if n_samples_x != n_samples_y {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples_x, n_features_y],
                actual: vec![n_samples_y, n_features_y],
                context: "PLSCanonical::fit: X and Y must have the same number of rows".into(),
            });
        }

        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }

        // Components are bounded by both feature counts and the sample count.
        let max_components = n_features_x.min(n_features_y).min(n_samples_x);
        if self.n_components > max_components {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
                    self.n_components, max_components
                ),
            });
        }

        if n_samples_x < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples_x,
                context: "PLSCanonical::fit requires at least 2 samples".into(),
            });
        }

        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
        let (yc, y_mean, y_std) = centre_scale(y, self.scale);

        let result = nipals(
            &xc,
            &yc,
            self.n_components,
            self.max_iter,
            self.tol,
            NipalsMode::Canonical,
            ScoreNorm::None,
        )?;

        Ok(FittedPLSCanonical {
            x_weights_: result.x_weights,
            x_loadings_: result.x_loadings,
            y_loadings_: result.y_loadings,
            x_scores_: result.x_scores,
            y_scores_: result.y_scores,
            n_iter_: result.n_iter,
            x_mean_: x_mean,
            y_mean_: y_mean,
            x_std_: x_std,
            y_std_: y_std,
        })
    }
}
1549
1550impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSCanonical<F> {
1551 type Output = Array2<F>;
1552 type Error = FerroError;
1553
1554 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1560 let xc = apply_centre_scale(
1561 x,
1562 &self.x_mean_,
1563 &self.x_std_,
1564 "FittedPLSCanonical::transform",
1565 )?;
1566
1567 let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1568 let ptw_inv = invert_square(&ptw)?;
1569 let rotation = self.x_weights_.dot(&ptw_inv);
1570 Ok(xc.dot(&rotation))
1571 }
1572}
1573
/// Canonical Correlation Analysis implemented via NIPALS with canonical
/// deflation and unit-variance score normalisation.
#[derive(Debug, Clone)]
pub struct CCA<F> {
    /// Number of canonical components to extract.
    n_components: usize,
    /// Maximum NIPALS inner-loop iterations per component.
    max_iter: usize,
    /// Relative convergence tolerance for the inner loop.
    tol: F,
    /// Scale columns to unit variance in addition to centring.
    scale: bool,
    /// Ties the float type parameter `F` to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
1615
impl<F: Float + Send + Sync + 'static> CCA<F> {
    /// Create a new estimator with defaults: `max_iter = 500`,
    /// `tol = 1e-6`, scaling enabled.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_iter: 500,
            tol: F::from(1e-6).unwrap_or(F::epsilon()),
            scale: true,
            _marker: std::marker::PhantomData,
        }
    }

    /// Builder: set the maximum NIPALS iterations per component.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Builder: set the inner-loop convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Builder: enable/disable unit-variance scaling (default: enabled).
    #[must_use]
    pub fn with_scale(mut self, scale: bool) -> Self {
        self.scale = scale;
        self
    }

    /// Number of components this estimator will extract.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }
}
1658
1659#[derive(Debug, Clone)]
1664pub struct FittedCCA<F> {
1665 x_weights_: Array2<F>,
1667 x_loadings_: Array2<F>,
1669 y_loadings_: Array2<F>,
1671 x_scores_: Array2<F>,
1673 y_scores_: Array2<F>,
1675 n_iter_: Vec<usize>,
1677 x_mean_: Array1<F>,
1679 y_mean_: Array1<F>,
1681 x_std_: Option<Array1<F>>,
1683 #[allow(dead_code)]
1685 y_std_: Option<Array1<F>>,
1686}
1687
impl<F: Float + Send + Sync + 'static> FittedCCA<F> {
    /// X projection weights W.
    #[must_use]
    pub fn x_weights(&self) -> &Array2<F> {
        &self.x_weights_
    }

    /// X loadings P.
    #[must_use]
    pub fn x_loadings(&self) -> &Array2<F> {
        &self.x_loadings_
    }

    /// Y loadings Q.
    #[must_use]
    pub fn y_loadings(&self) -> &Array2<F> {
        &self.y_loadings_
    }

    /// Training X scores T.
    #[must_use]
    pub fn x_scores(&self) -> &Array2<F> {
        &self.x_scores_
    }

    /// Training Y scores U.
    #[must_use]
    pub fn y_scores(&self) -> &Array2<F> {
        &self.y_scores_
    }

    /// NIPALS iterations used per component.
    #[must_use]
    pub fn n_iter(&self) -> &[usize] {
        &self.n_iter_
    }

    /// Project `y` onto the Y loadings after applying the stored
    /// centring/scaling.
    ///
    /// # Errors
    /// `ShapeMismatch` when `y`'s column count differs from the fitted one.
    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let yc = apply_centre_scale(y, &self.y_mean_, &self.y_std_, "FittedCCA::transform_y")?;
        Ok(yc.dot(&self.y_loadings_))
    }
}
1735
1736impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for CCA<F> {
1737 type Fitted = FittedCCA<F>;
1738 type Error = FerroError;
1739
1740 fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedCCA<F>, FerroError> {
1749 let (n_samples_x, n_features_x) = x.dim();
1750 let (n_samples_y, n_features_y) = y.dim();
1751
1752 if n_samples_x != n_samples_y {
1753 return Err(FerroError::ShapeMismatch {
1754 expected: vec![n_samples_x, n_features_y],
1755 actual: vec![n_samples_y, n_features_y],
1756 context: "CCA::fit: X and Y must have the same number of rows".into(),
1757 });
1758 }
1759
1760 if self.n_components == 0 {
1761 return Err(FerroError::InvalidParameter {
1762 name: "n_components".into(),
1763 reason: "must be at least 1".into(),
1764 });
1765 }
1766
1767 let max_components = n_features_x.min(n_features_y).min(n_samples_x);
1768 if self.n_components > max_components {
1769 return Err(FerroError::InvalidParameter {
1770 name: "n_components".into(),
1771 reason: format!(
1772 "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
1773 self.n_components, max_components
1774 ),
1775 });
1776 }
1777
1778 if n_samples_x < 2 {
1779 return Err(FerroError::InsufficientSamples {
1780 required: 2,
1781 actual: n_samples_x,
1782 context: "CCA::fit requires at least 2 samples".into(),
1783 });
1784 }
1785
1786 let (xc, x_mean, x_std) = centre_scale(x, self.scale);
1787 let (yc, y_mean, y_std) = centre_scale(y, self.scale);
1788
1789 let result = nipals(
1790 &xc,
1791 &yc,
1792 self.n_components,
1793 self.max_iter,
1794 self.tol,
1795 NipalsMode::Canonical,
1796 ScoreNorm::UnitVariance,
1797 )?;
1798
1799 Ok(FittedCCA {
1800 x_weights_: result.x_weights,
1801 x_loadings_: result.x_loadings,
1802 y_loadings_: result.y_loadings,
1803 x_scores_: result.x_scores,
1804 y_scores_: result.y_scores,
1805 n_iter_: result.n_iter,
1806 x_mean_: x_mean,
1807 y_mean_: y_mean,
1808 x_std_: x_std,
1809 y_std_: y_std,
1810 })
1811 }
1812}
1813
1814impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedCCA<F> {
1815 type Output = Array2<F>;
1816 type Error = FerroError;
1817
1818 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1824 let xc = apply_centre_scale(x, &self.x_mean_, &self.x_std_, "FittedCCA::transform")?;
1825
1826 let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1827 let ptw_inv = invert_square(&ptw)?;
1828 let rotation = self.x_weights_.dot(&ptw_inv);
1829 Ok(xc.dot(&rotation))
1830 }
1831}
1832
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // -----------------------------------------------------------------
    // PLSSVD
    // -----------------------------------------------------------------

    #[test]
    fn test_plssvd_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 1));
    }

    #[test]
    fn test_plssvd_two_components() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(2);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (4, 2));
    }

    #[test]
    fn test_plssvd_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_plssvd_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f64>::new(1).with_scale(false);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_plssvd_x_weights_shape() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(2);
        let fitted = svd.fit(&x, &y).unwrap();
        // Weights are (n_features, n_components) for each block.
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.y_weights().dim(), (2, 2));
    }

    #[test]
    fn test_plssvd_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0]];
        let svd = PLSSVD::<f64>::new(0);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_too_many_components() {
        // n_components = 2 > n_features_y = 1, so fit must reject it.
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let svd = PLSSVD::<f64>::new(2);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0]];
        let svd = PLSSVD::<f64>::new(1);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_insufficient_samples() {
        let x = array![[1.0, 2.0]];
        let y = array![[1.0]];
        let svd = PLSSVD::<f64>::new(1);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        // Fitted on 2 features; transforming 3-feature data must fail.
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plssvd_n_components_getter() {
        let svd = PLSSVD::<f64>::new(3);
        assert_eq!(svd.n_components(), 3);
    }

    #[test]
    fn test_plssvd_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![[1.0f32], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f32>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // -----------------------------------------------------------------
    // PLSRegression
    // -----------------------------------------------------------------

    #[test]
    fn test_plsregression_basic_fit_predict() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_prediction_quality() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        // y = x1 + x2 exactly, so a single component should fit it perfectly.
        let y = array![[3.0], [7.0], [11.0], [15.0], [19.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();

        for (pred, actual) in y_pred.column(0).iter().zip(y.column(0).iter()) {
            assert_abs_diff_eq!(pred, actual, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_plsregression_multi_target() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let pls = PLSRegression::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.dim(), (5, 2));
    }

    #[test]
    fn test_plsregression_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_coefficients_shape() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let pls = PLSRegression::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        // Coefficients map n_features_x -> n_targets.
        assert_eq!(fitted.coefficients().dim(), (3, 2));
    }

    #[test]
    fn test_plsregression_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1).with_scale(false);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_builder() {
        let pls = PLSRegression::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(pls.n_components(), 2);
    }

    #[test]
    fn test_plsregression_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0]];
        let pls = PLSRegression::<f64>::new(0);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_too_many_components() {
        // n_components = 2 > n_features_y = 1, so fit must reject it.
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSRegression::<f64>::new(2);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0]];
        let pls = PLSRegression::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_insufficient_samples() {
        let x = array![[1.0, 2.0]];
        let y = array![[1.0]];
        let pls = PLSRegression::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_predict_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_plsregression_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plsregression_x_scores_shape() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0], [2.0], [3.0], [4.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        // Scores are (n_samples, n_components); one iteration count per component.
        assert_eq!(fitted.x_scores().dim(), (4, 1));
        assert_eq!(fitted.y_scores().dim(), (4, 1));
        assert_eq!(fitted.n_iter().len(), 1);
    }

    #[test]
    fn test_plsregression_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![[1.0f32], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f32>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.ncols(), 1);
    }

    // -----------------------------------------------------------------
    // PLSCanonical
    // -----------------------------------------------------------------

    #[test]
    fn test_plscanonical_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let pls = PLSCanonical::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 2));
    }

    #[test]
    fn test_plscanonical_single_component() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_plscanonical_scores_shape() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        assert_eq!(fitted.x_scores().dim(), (3, 2));
        assert_eq!(fitted.y_scores().dim(), (3, 2));
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.x_loadings().dim(), (3, 2));
        assert_eq!(fitted.y_loadings().dim(), (2, 2));
    }

    #[test]
    fn test_plscanonical_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_plscanonical_builder() {
        let pls = PLSCanonical::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(pls.n_components(), 2);
    }

    #[test]
    fn test_plscanonical_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0]];
        let pls = PLSCanonical::<f64>::new(0);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSCanonical::<f64>::new(2);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plscanonical_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![
            [1.0f32, 0.5],
            [2.0, 1.0],
            [3.0, 1.5],
            [4.0, 2.0],
            [5.0, 2.5],
        ];
        let pls = PLSCanonical::<f32>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // -----------------------------------------------------------------
    // CCA
    // -----------------------------------------------------------------

    #[test]
    fn test_cca_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let cca = CCA::<f64>::new(2);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 2));
    }

    #[test]
    fn test_cca_single_component() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_cca_scores_shape() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(2);
        let fitted = cca.fit(&x, &y).unwrap();
        assert_eq!(fitted.x_scores().dim(), (3, 2));
        assert_eq!(fitted.y_scores().dim(), (3, 2));
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.x_loadings().dim(), (3, 2));
        assert_eq!(fitted.y_loadings().dim(), (2, 2));
    }

    #[test]
    fn test_cca_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_cca_builder() {
        let cca = CCA::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(cca.n_components(), 2);
    }

    #[test]
    fn test_cca_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0]];
        let cca = CCA::<f64>::new(0);
        assert!(cca.fit(&x, &y).is_err());
    }

    #[test]
    fn test_cca_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let cca = CCA::<f64>::new(2);
        assert!(cca.fit(&x, &y).is_err());
    }

    #[test]
    fn test_cca_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(1);
        assert!(cca.fit(&x, &y).is_err());
    }

    #[test]
    fn test_cca_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_cca_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![
            [1.0f32, 0.5],
            [2.0, 1.0],
            [3.0, 1.5],
            [4.0, 2.0],
            [5.0, 2.5],
        ];
        let cca = CCA::<f32>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // -----------------------------------------------------------------
    // Cross-model and helper checks
    // -----------------------------------------------------------------

    #[test]
    fn test_pls_regression_and_canonical_give_different_scores() {
        let x = array![
            [1.0, 2.0, 0.5],
            [3.0, 1.0, 2.5],
            [5.0, 6.0, 1.0],
            [7.0, 3.0, 4.5],
            [9.0, 10.0, 2.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];

        let pls_reg = PLSRegression::<f64>::new(2);
        let fitted_reg = pls_reg.fit(&x, &y).unwrap();
        let scores_reg = fitted_reg.transform(&x).unwrap();

        let pls_can = PLSCanonical::<f64>::new(2);
        let fitted_can = pls_can.fit(&x, &y).unwrap();
        let scores_can = fitted_can.transform(&x).unwrap();

        // Only sanity checks here: same shape and a finite total difference.
        // No assertion that the scores actually differ, despite the test name.
        let diff: f64 = scores_reg
            .iter()
            .zip(scores_can.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert_eq!(scores_reg.dim(), scores_can.dim());
        assert!(diff.is_finite());
    }

    #[test]
    fn test_centre_scale_helper() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (xc, mean, std_dev) = centre_scale(&x, true);
        assert_abs_diff_eq!(mean[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(mean[1], 4.0, epsilon = 1e-10);
        assert!(std_dev.is_some());

        // Centred data must have (near-)zero column means.
        let col_mean_0: f64 = xc.column(0).iter().sum::<f64>() / 3.0;
        assert_abs_diff_eq!(col_mean_0, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_centre_scale_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (_xc, _mean, std_dev) = centre_scale(&x, false);
        assert!(std_dev.is_none());
    }

    #[test]
    fn test_invert_square_identity() {
        let eye = Array2::<f64>::from_shape_fn((3, 3), |(i, j)| if i == j { 1.0 } else { 0.0 });
        let inv = invert_square(&eye).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(inv[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_invert_square_2x2() {
        let a = array![[4.0, 7.0], [2.0, 6.0]];
        let inv = invert_square(&a).unwrap();
        // A * A^{-1} must reproduce the identity.
        let prod = a.dot(&inv);
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(prod[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }
}