// rstsr-sci-traits 0.7.8
//
// An n-Dimension Rust Tensor Toolkit
// Documentation
use crate::distance::metric::{MetricDistAPI, MetricDistWeightedAPI, MetricEuclidean};
use crate::distance::native_impl::{cdist_serial, cdist_weighted_serial};
use crate::distance::traits::CDistAPI;
use num::Float;
use rstsr_core::prelude_dev::*;

/// Weighted pairwise-distance computation on the serial CPU device.
///
/// The input is the tuple `(xa, xb, kernel, weight)`: two tensor views sharing one element type,
/// a metric implementing [`MetricDistWeightedAPI`], and a view holding the weights.
///
/// The result is a tensor of `M::Out` whose shape is `[m, n]`, where `m` and `n` are the
/// first-axis lengths of `xa` and `xb` respectively, converted back to the caller's dimension
/// type `D`.
///
/// # Errors
///
/// Returns an error when `xa` or `xb` is not 2-D, when `weight` is not 1-D, when `xb` or
/// `weight` is not on the same device as `xa`, or when any later step (layout conversion,
/// making the weight contiguous, the distance kernel, building the output tensor) fails.
impl<T, D, M, TW, DW> CDistAPI<DeviceCpuSerial>
    for (
        TensorView<'_, T, DeviceCpuSerial, D>,
        TensorView<'_, T, DeviceCpuSerial, D>,
        M,
        TensorView<'_, TW, DeviceCpuSerial, DW>,
    )
where
    M: MetricDistWeightedAPI<Vec<T>, Weight = Vec<TW>, Out = TW>,
    TW: Float + Clone,
    DeviceCpuSerial: DeviceAPI<T, Raw = Vec<T>>
        + DeviceAPI<TW, Raw = M::Weight>
        + DeviceAPI<M::Out, Raw = Vec<M::Out>>
        + DeviceCreationAnyAPI<M::Out>
        + DeviceCreationAnyAPI<TW>
        + OpAssignArbitaryAPI<TW, DW, DW>
        + OpAssignAPI<TW, DW>,
    D: DimAPI + DimIntoAPI<Ix2>,
    DW: DimAPI + DimIntoAPI<Ix1>,
{
    type Out = Tensor<M::Out, DeviceCpuSerial, D>;

    fn cdist_f(self) -> Result<Self::Out> {
        let (xa, xb, kernel, weight) = self;

        // Rank checks come first, then device agreement; the first failing check is reported.
        rstsr_assert_eq!(xa.ndim(), 2, InvalidLayout, "xa must be a 2D tensor")?;
        rstsr_assert_eq!(xb.ndim(), 2, InvalidLayout, "xb must be a 2D tensor")?;
        rstsr_assert_eq!(weight.ndim(), 1, InvalidLayout, "weight must be a 1D tensor")?;
        rstsr_assert!(xa.device().same_device(xb.device()), DeviceMismatch)?;
        rstsr_assert!(xa.device().same_device(weight.device()), DeviceMismatch)?;

        // Fixed-rank layouts are what the native kernel consumes.
        let layout_a = xa.layout().to_dim::<Ix2>()?;
        let layout_b = xb.layout().to_dim::<Ix2>()?;
        // Output shape: first-axis length of `xa` by first-axis length of `xb`.
        let out_shape = [layout_a.shape()[0], layout_b.shape()[0]];

        let device = xa.device().clone();
        let default_order = device.default_order();

        // The kernel receives the weight's raw storage, so request a contiguous copy/view first.
        let weight = weight.into_contig_f(RowMajor)?;
        let dist = cdist_weighted_serial(
            xa.raw(),
            xb.raw(),
            &layout_a,
            &layout_b,
            weight.raw(),
            kernel,
            default_order,
        )?;

        // Wrap the raw result as a tensor and hand it back in the caller's dimension type.
        asarray_f((dist, out_shape, &device))?.into_dim_f::<D>()
    }
}

/// Unweighted pairwise-distance computation on the serial CPU device.
///
/// The input is the tuple `(xa, xb, kernel)`: two tensor views sharing one element type and a
/// metric implementing [`MetricDistAPI`].
///
/// The result is a tensor of `M::Out` whose shape is `[m, n]`, where `m` and `n` are the
/// first-axis lengths of `xa` and `xb` respectively, converted back to the caller's dimension
/// type `D`.
///
/// # Errors
///
/// Returns an error when `xa` or `xb` is not 2-D, when the two inputs are not on the same
/// device, or when any later step (layout conversion, the distance kernel, building the output
/// tensor) fails.
impl<T, D, M> CDistAPI<DeviceCpuSerial>
    for (TensorView<'_, T, DeviceCpuSerial, D>, TensorView<'_, T, DeviceCpuSerial, D>, M)
where
    DeviceCpuSerial: DeviceAPI<T, Raw = Vec<T>> + DeviceAPI<M::Out, Raw = Vec<M::Out>> + DeviceCreationAnyAPI<M::Out>,
    M: MetricDistAPI<Vec<T>>,
    D: DimAPI + DimIntoAPI<Ix2>,
{
    type Out = Tensor<M::Out, DeviceCpuSerial, D>;

    fn cdist_f(self) -> Result<Self::Out> {
        let (xa, xb, kernel) = self;

        // Rank checks come first, then device agreement; the first failing check is reported.
        rstsr_assert_eq!(xa.ndim(), 2, InvalidLayout, "xa must be a 2D tensor")?;
        rstsr_assert_eq!(xb.ndim(), 2, InvalidLayout, "xb must be a 2D tensor")?;
        rstsr_assert!(xa.device().same_device(xb.device()), DeviceMismatch)?;

        // Fixed-rank layouts are what the native kernel consumes.
        let layout_a = xa.layout().to_dim::<Ix2>()?;
        let layout_b = xb.layout().to_dim::<Ix2>()?;
        // Output shape: first-axis length of `xa` by first-axis length of `xb`.
        let out_shape = [layout_a.shape()[0], layout_b.shape()[0]];

        let device = xa.device().clone();
        let default_order = device.default_order();
        let dist = cdist_serial(xa.raw(), xb.raw(), &layout_a, &layout_b, kernel, default_order)?;

        // Wrap the raw result as a tensor and hand it back in the caller's dimension type.
        asarray_f((dist, out_shape, &device))?.into_dim_f::<D>()
    }
}

/// Pairwise-distance computation on the serial CPU device with the default metric.
///
/// The two-element form `(xa, xb)` forwards to the three-element `(xa, xb, kernel)`
/// implementation, supplying [`MetricEuclidean`] as the kernel. Any error is whatever that
/// implementation returns.
impl<T, D> CDistAPI<DeviceCpuSerial> for (TensorView<'_, T, DeviceCpuSerial, D>, TensorView<'_, T, DeviceCpuSerial, D>)
where
    T: Float,
    DeviceCpuSerial: DeviceAPI<T, Raw = Vec<T>> + DeviceCreationAnyAPI<T>,
    D: DimAPI + DimIntoAPI<Ix2>,
{
    type Out = Tensor<T, DeviceCpuSerial, D>;

    fn cdist_f(self) -> Result<Self::Out> {
        // Append the default metric and delegate to the metric-taking implementation.
        let with_default_metric = (self.0, self.1, MetricEuclidean);
        CDistAPI::<DeviceCpuSerial>::cdist_f(with_default_metric)
    }
}

// Smoke tests for the serial-CPU `cdist` implementations defined in this file.
#[cfg(test)]
mod test {
    use super::*;
    use crate::distance::metric::*;
    use crate::distance::traits::cdist;

    /// Calls `cdist` with a range of metrics, with and without a weight vector, and prints
    /// every result.
    ///
    /// There are no assertions: the test only fails if one of the calls panics, so the printed
    /// values must be inspected by hand (e.g. with `--nocapture`).
    #[test]
    fn playground() {
        let device = DeviceCpuSerial::default();
        // Float inputs: `a` is reshaped to (16, 4); `b` is reshaped to (20, 4) and then passed
        // through `into_flip(-1)` (presumably reversing the last axis, which gives a view
        // with a non-default stride — TODO confirm).
        let a = linspace((0., 1., 64, &device)).into_shape((16, 4));
        let b = linspace((0., 1., 80, &device)).into_shape((20, 4)).into_flip(-1);

        // Explicit Euclidean metric.
        let d = cdist((a.view(), b.view(), MetricEuclidean));
        println!("{d:16.8?}");

        // Two-argument form, which delegates to the Euclidean metric.
        let d = cdist((a.view(), b.view()));
        println!("{d:16.8?}");

        // Four weights, matching the second-axis length (4) of `a` and `b`; reused by every
        // weighted call below, including those on the later boolean and 3x4 inputs.
        let w = asarray((vec![1.5, 1.2, 0.7, 1.3], &device));
        let d_w = cdist((a.view(), b.view(), MetricEuclidean, w.view()));
        println!("{d_w:16.8?}");

        let d = cdist((a.view(), b.view(), MetricMinkowski::new(3.0)));
        println!("{d:16.8?}");

        let d = cdist((a.view(), b.view(), MetricCityBlock));
        println!("{d:16.8?}");

        let d = cdist((a.view(), b.view(), MetricSqEuclidean));
        println!("{d:16.8?}");

        // Rebinds `a` and `b` using the same expressions as at the top of the test (shadowing
        // the earlier bindings).
        let a = linspace((0., 1., 64, &device)).into_shape((16, 4));
        let b = linspace((0., 1., 80, &device)).into_shape((20, 4)).into_flip(-1);
        let d = cdist((a.view(), b.view(), MetricHamming));
        println!("{d:16.8?}");

        let d = cdist((a.view(), b.view(), MetricHamming, w.view()));
        println!("{d:16.8?}");

        let d = cdist((a.view(), b.view(), MetricBrayCurtis, w.view()));
        println!("{d:16.8?}");

        let d = cdist((a.view(), b.view(), MetricCosine::new(), w.view()));
        println!("{d:16.8?}");

        let d = cdist((a.view(), b.view(), MetricJensenShannon::new()));
        println!("{d:16.8?}");

        // Boolean inputs: 16 values each, reshaped to (4, 4), for the metrics called below.
        let vec_a =
            vec![true, false, false, false, false, false, true, true, true, true, true, false, true, true, true, true];
        let vec_b = vec![
            false, false, true, false, false, true, true, false, false, true, true, false, false, false, false, false,
        ];
        let a = asarray((vec_a, &device)).into_shape((4, 4));
        let b = asarray((vec_b, &device)).into_shape((4, 4));

        let d = cdist((a.view(), b.view(), MetricSokalSneath, w.view()));
        println!("{d:16.8?}");

        let d = cdist((a.view(), b.view(), MetricYule, w.view()));
        println!("{d:16.8?}");

        // Signed float inputs: 12 values each, reshaped to (3, 4).
        let a = asarray((vec![-3., 6., 3., 6., -3., 3., 6., 1., 0., -8., 1., -2.], &device)).into_shape((3, 4));
        let b = asarray((vec![3., 1., -1., 1., -6., 0., 3., -1., -5., -4., 4., -2.], &device)).into_shape((3, 4));

        let d = cdist((a.view(), b.view(), MetricCorrelation::default(), w.view()));
        println!("{d:16.8?}");

        // Absolute values are taken before the Jensen-Shannon call (presumably because that
        // metric expects non-negative inputs — TODO confirm).
        let d = cdist((a.abs().view(), b.abs().view(), MetricJensenShannon::default()));
        println!("{d:16.8?}");
    }
}