mod errors;
mod hyperparams;
mod pls_generic;
mod pls_svd;
mod utils;
use crate::pls_generic::*;
use linfa::{traits::Fit, traits::PredictInplace, traits::Transformer, DatasetBase, Float};
use ndarray::{Array2, ArrayBase, Data, Ix2};
pub use errors::*;
pub use hyperparams::*;
pub use pls_svd::*;
/// Generates a concrete, strongly-typed PLS model wrapping the generic `Pls<F>`.
///
/// For an identifier `$name` this expands to:
/// * a public tuple struct `Pls$name<F>` holding a `Pls<F>`;
/// * inherent accessors that forward to the wrapped model;
/// * `Fit` for the matching `Pls$name ValidParams` hyperparameter type;
/// * `Transformer` and `PredictInplace` impls for `Pls$name`.
///
/// The `Pls$name Params` / `Pls$name ValidParams` types and the lowercase
/// constructor `Pls::<name>(n_components)` must already exist; they are
/// expected to come from the `hyperparams` / `pls_generic` modules.
macro_rules! pls_algo {
    ($name:ident) => {
        paste::item! {
            /// Fitted Partial Least Squares model for one algorithm variant.
            ///
            /// This is a thin newtype around the generic `Pls` model; every
            /// method forwards to the wrapped value.
            pub struct [<Pls $name>]<F: Float>(Pls<F>);

            impl<F> [<Pls $name>]<F>
            where
                F: Float,
            {
                /// Creates hyperparameters for this variant that extract
                /// `n_components` components.
                ///
                /// The underlying settings come from the generic
                /// `Pls::<variant>(n_components)` constructor.
                pub fn params(n_components: usize) -> [<Pls $name Params>]<F> {
                    [<Pls $name Params>](
                        [<Pls $name ValidParams>](Pls::[<$name:lower>](n_components).0),
                    )
                }

                /// Returns the coefficient matrix of the wrapped model.
                pub fn coefficients(&self) -> &Array2<F> {
                    self.0.coefficients()
                }

                /// Returns the pair of weight matrices of the wrapped model.
                ///
                /// NOTE(review): the tuple order is presumably (X side, Y side).
                /// Confirm in `pls_generic`.
                pub fn weights(&self) -> (&Array2<F>, &Array2<F>) {
                    self.0.weights()
                }

                /// Returns the pair of loading matrices of the wrapped model.
                pub fn loadings(&self) -> (&Array2<F>, &Array2<F>) {
                    self.0.loadings()
                }

                /// Returns the pair of rotation matrices of the wrapped model.
                pub fn rotations(&self) -> (&Array2<F>, &Array2<F>) {
                    self.0.rotations()
                }

                /// Maps `dataset` back through the wrapped model's
                /// `inverse_transform` and returns owned records and targets.
                pub fn inverse_transform(
                    &self,
                    dataset: DatasetBase<
                        ArrayBase<impl Data<Elem = F>, Ix2>,
                        ArrayBase<impl Data<Elem = F>, Ix2>,
                    >,
                ) -> DatasetBase<Array2<F>, Array2<F>> {
                    self.0.inverse_transform(dataset)
                }
            }

            impl<F, D> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, PlsError>
                for [<Pls $name ValidParams>]<F>
            where
                F: Float,
                D: Data<Elem = F>,
            {
                type Object = [<Pls $name>]<F>;

                /// Fits the generic model, then wraps it in this variant's type.
                fn fit(
                    &self,
                    ds: &DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
                ) -> Result<Self::Object> {
                    let fitted = self.0.fit(ds)?;
                    Ok([<Pls $name>](fitted))
                }
            }

            impl<F, D> Transformer<
                DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
                DatasetBase<Array2<F>, Array2<F>>,
            > for [<Pls $name>]<F>
            where
                F: Float,
                D: Data<Elem = F>,
            {
                /// Delegates to the wrapped model's `transform`.
                fn transform(
                    &self,
                    ds: DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
                ) -> DatasetBase<Array2<F>, Array2<F>> {
                    self.0.transform(ds)
                }
            }

            impl<F, D> PredictInplace<ArrayBase<D, Ix2>, Array2<F>> for [<Pls $name>]<F>
            where
                F: Float,
                D: Data<Elem = F>,
            {
                /// Writes predictions for `records` into `targets` via the wrapped model.
                fn predict_inplace<'a>(
                    &'a self,
                    records: &ArrayBase<D, Ix2>,
                    targets: &mut Array2<F>,
                ) {
                    self.0.predict_inplace(records, targets);
                }

                /// Allocates the target buffer the wrapped model expects for `records`.
                fn default_target(&self, records: &ArrayBase<D, Ix2>) -> Array2<F> {
                    self.0.default_target(records)
                }
            }
        }
    };
}
// Concrete PLS variants exposed by this crate. Each invocation of `pls_algo!`
// generates the `Pls<Name>` wrapper type together with its `Fit`,
// `Transformer` and `PredictInplace` implementations.
pls_algo!(Canonical);
pls_algo!(Cca);
pls_algo!(Regression);
#[cfg(test)]
mod test {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::traits::{Fit, Predict, Transformer};
    use linfa_datasets::linnerud;
    use ndarray::array;

    /// Generates one test per PLS variant on the Linnerud dataset.
    ///
    /// * `Svd` fits `PlsSvd` with 3 components and transforms the dataset.
    ///   It only checks that neither step returns an error.
    /// * `$name, $expected` fits `Pls$name` with 2 components, predicts two
    ///   new samples and compares the predicted targets to `$expected` with
    ///   an absolute tolerance of `1e-2`.
    macro_rules! test_pls_algo {
        (Svd) => {
            #[test]
            fn test_pls_svd() -> Result<()> {
                let dataset = linnerud();
                let model = PlsSvd::<f64>::params(3).fit(&dataset)?;
                let _transformed = model.transform(dataset);
                Ok(())
            }
        };
        ($name:ident, $expected:expr) => {
            paste::item! {
                #[test]
                fn [<test_pls_$name:lower>]() -> Result<()> {
                    let dataset = linnerud();
                    let model = [<Pls $name>]::<f64>::params(2).fit(&dataset)?;
                    let _transformed = model.transform(dataset);
                    // Two unseen samples of the three exercise features.
                    let exercises = array![[14., 146., 61.], [6., 80., 60.]];
                    let predicted = model.predict(exercises);
                    assert_abs_diff_eq!($expected, predicted.targets(), epsilon = 1e-2);
                    Ok(())
                }
            }
        };
    }

    test_pls_algo!(Svd);

    test_pls_algo!(
        Canonical,
        array![
            [180.56979423, 33.29543984, 56.90850758],
            [190.854022, 38.91963398, 53.26914489]
        ]
    );

    test_pls_algo!(
        Regression,
        array![
            [172.39580643, 34.11919145, 57.15430526],
            [192.11167813, 38.05058858, 53.99844922]
        ]
    );

    test_pls_algo!(
        Cca,
        array![
            [181.56238421, 34.42502589, 57.31447865],
            [205.11767414, 40.23445194, 52.26494323]
        ]
    );

    /// With a single component, the canonical, regression and SVD variants
    /// are expected to produce the same transformed records (within `1e-5`).
    #[test]
    fn test_one_component_equivalence() -> Result<()> {
        let train = linnerud();
        let canonical = PlsCanonical::params(1).fit(&train)?.transform(linnerud());
        let regression = PlsRegression::params(1).fit(&train)?.transform(linnerud());
        let svd = PlsSvd::<f64>::params(1).fit(&train)?.transform(linnerud());
        assert_abs_diff_eq!(regression.records(), canonical.records(), epsilon = 1e-5);
        assert_abs_diff_eq!(svd.records(), canonical.records(), epsilon = 1e-5);
        Ok(())
    }
}