use linfa::Float;
use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2, s};
#[cfg(feature = "serializable")]
use serde::{Deserialize, Serialize};
#[derive(Debug)]
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
pub(crate) struct NormalizedData<F: Float> {
pub data: Array2<F>,
pub mean: Array1<F>,
pub std: Array1<F>,
}
impl<F: Float> Clone for NormalizedData<F> {
fn clone(&self) -> NormalizedData<F> {
NormalizedData {
data: self.data.to_owned(),
mean: self.mean.to_owned(),
std: self.std.to_owned(),
}
}
}
impl<F: Float> NormalizedData<F> {
pub fn new(x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> NormalizedData<F> {
let (data, mean, std) = normalize(x);
NormalizedData {
data: data.to_owned(),
mean: mean.to_owned(),
std: std.to_owned(),
}
}
pub fn ncols(&self) -> usize {
self.data.ncols()
}
}
pub fn normalize<F: Float>(
x: &ArrayBase<impl Data<Elem = F>, Ix2>,
) -> (Array2<F>, Array1<F>, Array1<F>) {
let x_mean = x.mean_axis(Axis(0)).unwrap();
let mut x_std = x.std_axis(Axis(0), F::one());
x_std.mapv_inplace(|v| if v == F::zero() { F::one() } else { v });
let xnorm = (x - &x_mean) / &x_std;
(xnorm, x_mean, x_std)
}
#[derive(Debug)]
pub struct DiffMatrix<F: Float> {
pub d: Array2<F>,
pub d_indices: Array2<usize>,
pub n_obs: usize,
}
impl<F: Float> DiffMatrix<F> {
pub fn new(x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> DiffMatrix<F> {
let (d, d_indices) = Self::_cross_diff(x);
let n_obs = x.nrows();
DiffMatrix {
d,
d_indices,
n_obs,
}
}
fn _cross_diff(x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> (Array2<F>, Array2<usize>) {
let n_obs = x.nrows();
let nx = x.ncols();
let n_non_zero_cross_dist = n_obs * (n_obs - 1) / 2;
let mut indices = Array2::<usize>::zeros((n_non_zero_cross_dist, 2));
let mut d = Array2::zeros((n_non_zero_cross_dist, nx));
let mut idx = 0;
for k in 0..(n_obs - 1) {
let idx0 = idx;
let offset = n_obs - k - 1;
idx = idx0 + offset;
for i in (k + 1)..n_obs {
let r = idx0 + i - k - 1;
indices[[r, 0]] = k;
indices[[r, 1]] = i;
}
let diff = &x.slice(s![k, ..]) - &x.slice(s![k + 1..n_obs, ..]);
d.slice_mut(s![idx0..idx, ..]).assign(&diff);
}
d = d.mapv(|v| v.abs());
(d, indices)
}
}
pub fn pairwise_differences<F: Float>(
x: &ArrayBase<impl Data<Elem = F>, Ix2>,
y: &ArrayBase<impl Data<Elem = F>, Ix2>,
) -> Array2<F> {
assert!(x.ncols() == y.ncols());
let nx = x.nrows();
let ny = y.nrows();
let ncols = x.ncols();
let mut result = Array2::zeros((nx * ny, ncols));
for (i, x_row) in x.rows().into_iter().enumerate() {
for (j, y_row) in y.rows().into_iter().enumerate() {
let idx = i * ny + j;
for k in 0..ncols {
result[[idx, k]] = x_row[k] - y_row[k];
}
}
}
result
}
pub fn differences<F: Float>(
x: &ArrayBase<impl Data<Elem = F>, Ix1>,
y: &ArrayBase<impl Data<Elem = F>, Ix2>,
) -> Array2<F> {
assert!(x.len() == y.ncols());
x.to_owned() - y
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use ndarray::array;
#[test]
fn test_pairwise_differences() {
let x = array![[-0.9486833], [-0.82219219]];
let y = array![
[-1.26491106],
[-0.63245553],
[0.],
[0.63245553],
[1.26491106]
];
assert_abs_diff_eq!(
&array![
[0.31622777],
[-0.31622777],
[-0.9486833],
[-1.58113883],
[-2.21359436],
[0.44271887],
[-0.18973666],
[-0.82219219],
[-1.45464772],
[-2.08710326]
],
&pairwise_differences(&x, &y),
epsilon = 1e-6
)
}
#[test]
fn test_differences() {
let x = array![-0.9486833];
let y = array![
[-1.26491106],
[-0.63245553],
[0.],
[0.63245553],
[1.26491106]
];
assert_abs_diff_eq!(
&array![
[0.31622777],
[-0.31622777],
[-0.9486833],
[-1.58113883],
[-2.21359436],
],
&differences(&x, &y),
epsilon = 1e-6
)
}
#[test]
fn test_normalized_matrix() {
let x = array![[1., 2.], [3., 4.]];
let xnorm = NormalizedData::new(&x);
assert_eq!(xnorm.ncols(), 2);
assert_eq!(array![2., 3.], xnorm.mean);
assert_eq!(array![f64::sqrt(2.), f64::sqrt(2.)], xnorm.std);
}
#[test]
fn test_diff_matrix() {
let xt = array![[0.5], [1.2], [2.0], [3.0], [4.0]];
let expected = (
array![
[0.7],
[1.5],
[2.5],
[3.5],
[0.8],
[1.8],
[2.8],
[1.],
[2.],
[1.]
],
array![
[0, 1],
[0, 2],
[0, 3],
[0, 4],
[1, 2],
[1, 3],
[1, 4],
[2, 3],
[2, 4],
[3, 4]
],
);
let dm = DiffMatrix::new(&xt);
assert_eq!(expected.0, dm.d);
assert_eq!(expected.1, dm.d_indices);
}
}