pub mod inner;
mod sparse;
pub use inner::{Inner, KernelInner};
use linfa_nn::CommonNearestNeighbour;
use linfa_nn::NearestNeighbour;
use ndarray::prelude::*;
use ndarray::Data;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use sprs::{CsMat, CsMatView};
use std::ops::Mul;
use linfa::{
dataset::AsTargets, dataset::DatasetBase, dataset::FromTargetArray, dataset::Records,
traits::Transformer, Float,
};
/// Storage layout used for a computed kernel matrix.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// Every pairwise entry is computed and stored in a dense matrix.
    Dense,
    /// Only the entries present in the adjacency matrix built by `sparse::adjacency_matrix`
    /// are computed. The payload is the `k` passed to that function (presumably the number
    /// of nearest neighbours; confirm in the `sparse` module).
    Sparse(usize),
}
/// A kernel matrix together with the kernel function that produced it.
///
/// The entries live in [`KernelInner`], which holds either the dense storage `K1` or the
/// sparse storage `K2`. Both storages must have a floating-point element type.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct KernelBase<K1: Inner, K2: Inner>
where
    K1::Elem: Float,
    K2::Elem: Float,
{
    /// The matrix entries, stored densely (`K1`) or sparsely (`K2`).
    #[cfg_attr(
        feature = "serde",
        serde(bound(
            serialize = "KernelInner<K1, K2>: Serialize",
            deserialize = "KernelInner<K1, K2>: Deserialize<'de>"
        ))
    )]
    pub inner: KernelInner<K1, K2>,
    /// The kernel function that was used to compute the entries.
    #[cfg_attr(
        feature = "serde",
        serde(bound(
            serialize = "KernelMethod<K1::Elem>: Serialize",
            deserialize = "KernelMethod<K1::Elem>: Deserialize<'de>"
        ))
    )]
    pub method: KernelMethod<K1::Elem>,
}
/// Owned kernel matrix: dense entries in an [`Array2`], sparse entries in a [`CsMat`].
pub type Kernel<F> = KernelBase<Array2<F>, CsMat<F>>;
/// Borrowed kernel matrix: dense entries in an [`ArrayView2`], sparse entries in a [`CsMatView`].
pub type KernelView<'a, F> = KernelBase<ArrayView2<'a, F>, CsMatView<'a, F>>;
impl<F: Float, K1: Inner<Elem = F>, K2: Inner<Elem = F>> KernelBase<K1, K2> {
    /// Returns `true` when the kernel function is [`KernelMethod::Linear`].
    pub fn is_linear(&self) -> bool {
        KernelMethod::is_linear(&self.method)
    }

    /// Default parameters: a dense Gaussian kernel with `eps = 0.5`, using
    /// [`CommonNearestNeighbour::KdTree`] as the nearest-neighbour algorithm.
    pub fn params() -> KernelParams<F, CommonNearestNeighbour> {
        let nn_algo = CommonNearestNeighbour::KdTree;
        Self::params_with_nn(nn_algo)
    }

    /// Same defaults as [`KernelBase::params`] (dense Gaussian, `eps = 0.5`), but with a
    /// caller-chosen nearest-neighbour algorithm. That algorithm is only consulted when the
    /// kind is switched to [`KernelType::Sparse`].
    pub fn params_with_nn<N: NearestNeighbour>(nn_algo: N) -> KernelParams<F, N> {
        let method = KernelMethod::Gaussian(F::cast(0.5));
        KernelParams {
            kind: KernelType::Dense,
            method,
            nn_algo,
        }
    }

    /// Computes the product of the kernel matrix with `rhs`, delegating to the stored
    /// matrix's [`Inner::dot`].
    pub fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
        match self.inner {
            KernelInner::Dense(ref matrix) => matrix.dot(rhs),
            KernelInner::Sparse(ref matrix) => matrix.dot(rhs),
        }
    }

    /// Sums the kernel matrix, giving one value per row.
    pub fn sum(&self) -> Array1<F> {
        match self.inner {
            KernelInner::Dense(ref matrix) => matrix.sum(),
            KernelInner::Sparse(ref matrix) => matrix.sum(),
        }
    }

    /// Returns the number of rows (and of columns) of the square kernel matrix.
    pub fn size(&self) -> usize {
        match self.inner {
            KernelInner::Dense(ref matrix) => matrix.size(),
            KernelInner::Sparse(ref matrix) => matrix.size(),
        }
    }

    /// Returns column `i` of the kernel matrix as an owned vector.
    pub fn column(&self, i: usize) -> Vec<F> {
        match self.inner {
            KernelInner::Dense(ref matrix) => matrix.column(i),
            KernelInner::Sparse(ref matrix) => matrix.column(i),
        }
    }

    /// Returns the entries strictly above the main diagonal, flattened row by row.
    pub fn to_upper_triangle(&self) -> Vec<F> {
        match self.inner {
            KernelInner::Dense(ref matrix) => matrix.to_upper_triangle(),
            KernelInner::Sparse(ref matrix) => matrix.to_upper_triangle(),
        }
    }

    /// Returns the main diagonal of the kernel matrix.
    pub fn diagonal(&self) -> Array1<F> {
        match self.inner {
            KernelInner::Dense(ref matrix) => matrix.diagonal(),
            KernelInner::Sparse(ref matrix) => matrix.diagonal(),
        }
    }
}
impl<'a, F: Float> Kernel<F> {
    /// Computes the kernel matrix of the rows of `dataset`.
    ///
    /// * [`KernelType::Dense`]: every pair of rows is evaluated with `params.method`.
    /// * [`KernelType::Sparse`]`(k)`: only the entries of the adjacency matrix built from
    ///   `k` and `params.nn_algo` are evaluated. All other entries are left implicit.
    ///
    /// The returned kernel stores a clone of `params.method`.
    pub fn new<N: NearestNeighbour>(
        dataset: ArrayView2<'a, F>,
        params: &KernelParams<F, N>,
    ) -> Kernel<F> {
        let inner = match params.kind {
            KernelType::Dense => KernelInner::Dense(dense_from_fn(&dataset, &params.method)),
            KernelType::Sparse(k) => {
                KernelInner::Sparse(sparse_from_fn(&dataset, k, &params.method, &params.nn_algo))
            }
        };
        Kernel {
            inner,
            method: params.method.clone(),
        }
    }

    /// Borrows this kernel as a [`KernelView`] without copying the matrix entries.
    pub fn view(&'a self) -> KernelView<'a, F> {
        KernelView {
            inner: match &self.inner {
                KernelInner::Dense(inn) => KernelInner::Dense(inn.view()),
                KernelInner::Sparse(inn) => KernelInner::Sparse(inn.view()),
            },
            method: self.method.clone(),
        }
    }
}
impl<'a, F: Float> KernelView<'a, F> {
    /// Copies the borrowed matrix entries into a new owned [`Kernel`], keeping the same
    /// storage layout (dense or sparse) and kernel function.
    pub fn to_owned(&self) -> Kernel<F> {
        let inner = match self.inner {
            KernelInner::Dense(ref borrowed) => KernelInner::Dense(borrowed.to_owned()),
            KernelInner::Sparse(ref borrowed) => KernelInner::Sparse(borrowed.to_owned()),
        };
        let method = self.method.clone();
        Kernel { inner, method }
    }
}
impl<F: Float, K1: Inner<Elem = F>, K2: Inner<Elem = F>> Records for KernelBase<K1, K2> {
type Elem = F;
fn nsamples(&self) -> usize {
self.size()
}
fn nfeatures(&self) -> usize {
self.size()
}
}
/// The kernel function used to compare two samples `a` and `b`.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub enum KernelMethod<F> {
    /// Gaussian kernel `exp(-||a - b||^2 / eps)`; the payload is `eps`.
    Gaussian(F),
    /// Linear kernel: the dot product `<a, b>`.
    Linear,
    /// Polynomial kernel `(<a, b> + c)^d`; the payload is `(c, d)`.
    Polynomial(F, F),
}
impl<F: Float> KernelMethod<F> {
    /// Evaluates the kernel function on the samples `a` and `b`.
    ///
    /// * `Gaussian(eps)`: `exp(-||a - b||^2 / eps)`
    /// * `Linear`: `<a, b>`
    /// * `Polynomial(c, d)`: `(<a, b> + c)^d`
    pub fn distance(&self, a: ArrayView1<F>, b: ArrayView1<F>) -> F {
        match self {
            KernelMethod::Linear => a.mul(&b).sum(),
            KernelMethod::Polynomial(c, d) => (a.mul(&b).sum() + *c).powf(*d),
            KernelMethod::Gaussian(eps) => {
                // Squared Euclidean distance between the two samples.
                let sq_dist: F = a
                    .iter()
                    .zip(b.iter())
                    .map(|(&x, &y)| {
                        let diff = x - y;
                        diff * diff
                    })
                    .sum();
                (-sq_dist / *eps).exp()
            }
        }
    }

    /// Returns `true` only for [`KernelMethod::Linear`].
    pub fn is_linear(&self) -> bool {
        match self {
            KernelMethod::Linear => true,
            KernelMethod::Gaussian(_) | KernelMethod::Polynomial(_, _) => false,
        }
    }
}
/// Builder-style parameters for constructing a [`Kernel`].
///
/// Defaults come from [`KernelBase::params`] / [`KernelBase::params_with_nn`].
#[derive(Debug, Clone, PartialEq)]
pub struct KernelParams<F, N = CommonNearestNeighbour> {
    /// Dense or sparse storage of the resulting matrix.
    kind: KernelType,
    /// Kernel function applied to each computed pair of samples.
    method: KernelMethod<F>,
    /// Nearest-neighbour algorithm; only used for [`KernelType::Sparse`].
    nn_algo: N,
}
impl<F, N> KernelParams<F, N> {
    /// Sets the kernel function, returning the updated parameters.
    pub fn method(self, method: KernelMethod<F>) -> Self {
        Self { method, ..self }
    }

    /// Sets the storage layout (dense or sparse), returning the updated parameters.
    pub fn kind(self, kind: KernelType) -> Self {
        Self { kind, ..self }
    }

    /// Sets the nearest-neighbour algorithm used by sparse kernels.
    pub fn nn_algo(self, nn_algo: N) -> Self {
        Self { nn_algo, ..self }
    }
}
impl<F: Float, N: NearestNeighbour> Transformer<&Array2<F>, Kernel<F>> for KernelParams<F, N> {
    /// Builds the kernel matrix of the rows of a borrowed owned array.
    fn transform(&self, x: &Array2<F>) -> Kernel<F> {
        let records = x.view();
        Kernel::new(records, self)
    }
}
impl<'a, F: Float, N: NearestNeighbour> Transformer<ArrayView2<'a, F>, Kernel<F>>
    for KernelParams<F, N>
{
    /// Builds the kernel matrix of the rows of an array view.
    fn transform(&self, x: ArrayView2<'a, F>) -> Kernel<F> {
        Kernel::<F>::new(x, self)
    }
}
impl<'a, F: Float, N: NearestNeighbour> Transformer<&ArrayView2<'a, F>, Kernel<F>>
    for KernelParams<F, N>
{
    /// Builds the kernel matrix of the rows of a borrowed array view.
    fn transform(&self, x: &ArrayView2<'a, F>) -> Kernel<F> {
        // Re-borrow the same data (same shape and strides) as a fresh view.
        let records = x.view();
        Kernel::new(records, self)
    }
}
// The impl previously declared an unused lifetime `'a` (clippy `extra_unused_lifetimes`);
// it is removed. Impl-level generics are not nameable by callers, so nothing changes for them.
impl<F: Float, T: AsTargets, N: NearestNeighbour>
    Transformer<DatasetBase<Array2<F>, T>, DatasetBase<Kernel<F>, T>> for KernelParams<F, N>
{
    /// Replaces the records of an owned dataset with their kernel matrix, moving the
    /// targets over unchanged.
    fn transform(&self, x: DatasetBase<Array2<F>, T>) -> DatasetBase<Kernel<F>, T> {
        let kernel = Kernel::new(x.records.view(), self);
        DatasetBase::new(kernel, x.targets)
    }
}
impl<'a, F: Float, L: 'a, T: AsTargets<Elem = L> + FromTargetArray<'a>, N: NearestNeighbour>
    Transformer<&'a DatasetBase<Array2<F>, T>, DatasetBase<Kernel<F>, T::View>>
    for KernelParams<F, N>
{
    /// Builds the kernel matrix of a borrowed dataset's records and pairs it with a view
    /// of that dataset's targets.
    fn transform(&self, x: &'a DatasetBase<Array2<F>, T>) -> DatasetBase<Kernel<F>, T::View> {
        let targets = T::new_targets_view(x.as_targets());
        DatasetBase::new(Kernel::new(x.records.view(), self), targets)
    }
}
impl<
        'a,
        'b,
        F: Float,
        L: 'b,
        T: AsTargets<Elem = L> + FromTargetArray<'b>,
        N: NearestNeighbour,
    > Transformer<&'b DatasetBase<ArrayView2<'a, F>, T>, DatasetBase<Kernel<F>, T::View>>
    for KernelParams<F, N>
{
    /// Builds the kernel matrix of a borrowed view-backed dataset's records and pairs it
    /// with a view of that dataset's targets.
    fn transform(
        &self,
        x: &'b DatasetBase<ArrayView2<'a, F>, T>,
    ) -> DatasetBase<Kernel<F>, T::View> {
        let records = x.records.view();
        let targets = T::new_targets_view(x.as_targets());
        DatasetBase::new(Kernel::new(records, self), targets)
    }
}
/// Builds the dense `n x n` kernel matrix of the `n` rows of `dataset`.
///
/// Entry `(i, j)` is `method.distance(row_i, row_j)`, and every entry is evaluated.
fn dense_from_fn<F: Float, D: Data<Elem = F>>(
    dataset: &ArrayBase<D, Ix2>,
    method: &KernelMethod<F>,
) -> Array2<F> {
    let n = dataset.len_of(Axis(0));
    Array2::from_shape_fn((n, n), |(row, col)| {
        method.distance(dataset.row(row), dataset.row(col))
    })
}
/// Builds a sparse kernel matrix over the rows of `dataset`.
///
/// The sparsity pattern comes from `sparse::adjacency_matrix(dataset, k, nn_algo)`. Each
/// stored entry at (outer index `i`, inner index `j`) is then overwritten with
/// `method.distance(row_i, row_j)`. Entries absent from the pattern stay implicit.
fn sparse_from_fn<F: Float, D: Data<Elem = F>, N: NearestNeighbour>(
    dataset: &ArrayBase<D, Ix2>,
    k: usize,
    method: &KernelMethod<F>,
    nn_algo: &N,
) -> CsMat<F> {
    let mut kernel = sparse::adjacency_matrix(dataset, k, nn_algo);
    for (outer_idx, mut lane) in kernel.outer_iterator_mut().enumerate() {
        // The outer row stays the same for every stored entry in this lane.
        let lhs = dataset.row(outer_idx);
        for (inner_idx, entry) in lane.iter_mut() {
            *entry = method.distance(lhs, dataset.row(inner_idx));
        }
    }
    kernel
}
// Unit tests for the kernel functions, the dense/sparse constructors, the `KernelBase`
// accessors and the `Transformer` implementations.
#[cfg(test)]
mod tests {
    use super::*;
    use linfa::Dataset;
    use linfa_nn::{BallTree, KdTree};
    use ndarray::{Array1, Array2};
    use std::f64::consts;

    // Compile-time check that the public types are `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<KernelType>();
        has_autotraits::<KernelBase<ArrayView2<f64>, ArrayView2<f64>>>();
        has_autotraits::<KernelMethod<f64>>();
        has_autotraits::<KernelParams<f64, f64>>();
        has_autotraits::<KernelView<f64>>();
        has_autotraits::<KernelInner<ArrayView2<f64>, ArrayView2<f64>>>();
        has_autotraits::<Kernel<f64>>();
    }

    // Eight 2D points forming four tight pairs: (0,0)/(0.1,0.1), (1,1)/(1.1,1.1), ...
    // With k = 1 and a linear kernel, the assertions expect 16 stored entries: the 8
    // diagonal entries plus the 8 within-pair entries. Values are compared after scaling
    // and truncating to `usize`.
    #[test]
    fn sparse_from_fn_test() {
        let input_mat = vec![
            0., 0., 0.1, 0.1, 1., 1., 1.1, 1.1, 2., 2., 2.1, 2.1, 3., 3., 3.1, 3.1,
        ];
        let input_arr = Array2::from_shape_vec((8, 2), input_mat).unwrap();
        let adj_mat = sparse_from_fn(&input_arr, 1, &KernelMethod::Linear, &KdTree);
        assert_eq!(adj_mat.nnz(), 16);
        // Diagonal: <p, p> for each point.
        assert_eq!(*adj_mat.get(0, 0).unwrap() as usize, 0);
        assert_eq!((*adj_mat.get(1, 1).unwrap() * 100.) as usize, 2);
        assert_eq!(*adj_mat.get(2, 2).unwrap() as usize, 2);
        assert_eq!((*adj_mat.get(3, 3).unwrap() * 100.) as usize, 242);
        assert_eq!(*adj_mat.get(4, 4).unwrap() as usize, 8);
        assert_eq!((*adj_mat.get(5, 5).unwrap() * 100.) as usize, 882);
        assert_eq!(*adj_mat.get(6, 6).unwrap() as usize, 18);
        assert_eq!((*adj_mat.get(7, 7).unwrap() * 100.) as usize, 1922);
        // Off-diagonal: each point paired with its close neighbour, symmetric.
        assert_eq!(*adj_mat.get(0, 1).unwrap() as usize, 0);
        assert_eq!(*adj_mat.get(1, 0).unwrap() as usize, 0);
        assert_eq!((*adj_mat.get(2, 3).unwrap() * 10.) as usize, 22);
        assert_eq!((*adj_mat.get(3, 2).unwrap() * 10.) as usize, 22);
        assert_eq!((*adj_mat.get(4, 5).unwrap() * 10.) as usize, 84);
        assert_eq!((*adj_mat.get(5, 4).unwrap() * 10.) as usize, 84);
        assert_eq!((*adj_mat.get(6, 7).unwrap() * 10.) as usize, 186);
        assert_eq!((*adj_mat.get(7, 6).unwrap() * 10.) as usize, 186);
    }

    // Every dense entry (i, j) must equal the kernel evaluated on rows i and j.
    #[test]
    fn dense_from_fn_test() {
        let input_mat = vec![
            0., 0., 0.1, 0.1, 1., 1., 1.1, 1.1, 2., 2., 2.1, 2.1, 3., 3., 3.1, 3.1,
        ];
        let input_arr = Array2::from_shape_vec((8, 2), input_mat).unwrap();
        let method: KernelMethod<f64> = KernelMethod::Linear;
        let similarity_matrix = dense_from_fn(&input_arr, &method);
        for i in 0..8 {
            for j in 0..8 {
                assert!(
                    (similarity_matrix.row(i)[j]
                        - method.distance(input_arr.row(i), input_arr.row(j)))
                    .abs()
                        <= f64::EPSILON
                );
            }
        }
    }

    // Gaussian kernel exp(-||a - b||^2 / eps):
    // identical points give 1; (1,1) vs (5,5) has squared distance 32, so e^-32 for eps = 1;
    // (1,1) vs (2,2) has squared distance 2, so e^-20 for eps = 0.1.
    #[test]
    fn gaussian_test() {
        let gauss_1 = KernelMethod::Gaussian(1.);
        let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
        let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
        let distance = gauss_1.distance(p1.view(), p2.view());
        let expected = 1.;
        assert!(((distance - expected) as f64).abs() <= f64::EPSILON);
        let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
        let p2 = Array1::from_shape_vec(2, vec![5., 5.]).unwrap();
        let distance = gauss_1.distance(p1.view(), p2.view());
        let expected = (consts::E).powf(-32.);
        assert!(((distance - expected) as f64).abs() <= f64::EPSILON);
        let gauss_01 = KernelMethod::Gaussian(0.1);
        let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
        let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
        let distance = gauss_01.distance(p1.view(), p2.view());
        let expected = 1.;
        assert!(((distance - expected) as f64).abs() <= f64::EPSILON);
        let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
        let p2 = Array1::from_shape_vec(2, vec![2., 2.]).unwrap();
        let distance = gauss_01.distance(p1.view(), p2.view());
        let expected = (consts::E).powf(-20.);
        assert!(((distance - expected) as f64).abs() <= f64::EPSILON);
    }

    // Polynomial kernel (<a, b> + c)^d with d = 2:
    // c = 0: (0)^2 = 0 and (10)^2 = 100; c = 2: (0 + 2)^2 = 4 and (4 + 2)^2 = 36.
    #[test]
    fn poly2_test() {
        let pol_0 = KernelMethod::Polynomial(0., 2.);
        let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
        let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
        let distance = pol_0.distance(p1.view(), p2.view());
        let expected = 0.;
        assert!(((distance - expected) as f64).abs() <= f64::EPSILON);
        let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
        let p2 = Array1::from_shape_vec(2, vec![5., 5.]).unwrap();
        let distance = pol_0.distance(p1.view(), p2.view());
        let expected = 100.;
        assert!(((distance - expected) as f64).abs() <= f64::EPSILON);
        let pol_2 = KernelMethod::Polynomial(2., 2.);
        let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
        let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
        let distance = pol_2.distance(p1.view(), p2.view());
        let expected = 4.;
        assert!(((distance - expected) as f64).abs() <= f64::EPSILON);
        let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
        let p2 = Array1::from_shape_vec(2, vec![2., 2.]).unwrap();
        let distance = pol_2.distance(p1.view(), p2.view());
        let expected = 36.;
        assert!(((distance - expected) as f64).abs() <= f64::EPSILON);
    }

    // `Kernel::dot` must match multiplying the raw dense / sparse kernel matrices directly.
    #[test]
    fn test_kernel_dot() {
        let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
        let vec_to_multiply: Vec<f64> = (0..100).map(|v| v as f64 * 0.3).collect();
        let input_arr = Array2::from_shape_vec((10, 10), input_vec).unwrap();
        let to_multiply = Array2::from_shape_vec((10, 10), vec_to_multiply).unwrap();
        let mul_mat = dense_from_fn(&input_arr, &KernelMethod::Linear).dot(&to_multiply);
        let kernel = KernelView::params()
            .kind(KernelType::Dense)
            .method(KernelMethod::Linear)
            .transform(input_arr.view());
        let mul_ker = kernel.dot(&to_multiply.view());
        assert!(matrices_almost_equal(mul_mat.view(), mul_ker.view()));
        let mul_mat =
            sparse_from_fn(&input_arr, 3, &KernelMethod::Linear, &KdTree).mul(&to_multiply.view());
        let kernel = KernelView::params()
            .kind(KernelType::Sparse(3))
            .method(KernelMethod::Linear)
            .transform(input_arr.view());
        let mul_ker = kernel.dot(&to_multiply.view());
        assert!(matrices_almost_equal(mul_mat.view(), mul_ker.view()));
    }

    // Rows 5..10 are rows 0..5 in reverse order, so row i equals row 9 - i. With a very
    // narrow Gaussian (eps = 1e-5), the entry above the diagonal truncates to 1 exactly
    // when j == 9 - i, and to 0 otherwise. A 10x10 matrix has 45 strictly-upper entries.
    #[test]
    fn test_kernel_upper_triangle() {
        let input_vec: Vec<f64> = (0..50).map(|v| v as f64 * 0.1).collect();
        let input_arr_1 = Array2::from_shape_vec((5, 10), input_vec.clone()).unwrap();
        let mut input_arr_2 = Array2::from_shape_vec((5, 10), input_vec).unwrap();
        input_arr_2.invert_axis(Axis(0));
        let input_arr =
            ndarray::concatenate(Axis(0), &[input_arr_1.view(), input_arr_2.view()]).unwrap();
        for kind in vec![KernelType::Dense, KernelType::Sparse(1)] {
            let kernel = KernelView::params()
                .kind(kind)
                .method(KernelMethod::Gaussian(1e-5))
                .transform(input_arr.view());
            let mut kernel_upper_triang = kernel.to_upper_triangle();
            assert_eq!(kernel_upper_triang.len(), 45);
            // Reverse so that `pop` yields entries in their original (row-by-row) order.
            kernel_upper_triang.reverse();
            for i in 0..9 {
                for j in (i + 1)..10 {
                    if j == (9 - i) {
                        assert_eq!(kernel_upper_triang.pop().unwrap() as usize, 1);
                    } else {
                        assert_eq!(kernel_upper_triang.pop().unwrap() as usize, 0);
                    }
                }
            }
            assert!(kernel_upper_triang.is_empty());
        }
    }

    // `Kernel::sum` must match the per-row sums (`sum_axis(Axis(1))`) of the raw matrices.
    #[test]
    fn test_kernel_sum() {
        let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
        let input_arr = Array2::from_shape_vec((10, 10), input_vec).unwrap();
        let method = KernelMethod::Linear;
        let cols_sum = dense_from_fn(&input_arr, &method).sum_axis(Axis(1));
        let kernel = KernelView::params()
            .kind(KernelType::Dense)
            .method(method.clone())
            .transform(input_arr.view());
        let kers_sum = kernel.sum();
        assert!(arrays_almost_equal(cols_sum.view(), kers_sum.view()));
        let cols_sum = sparse_from_fn(&input_arr, 3, &method, &BallTree)
            .to_dense()
            .sum_axis(Axis(1));
        let kernel = KernelView::params()
            .kind(KernelType::Sparse(3))
            .method(method)
            .transform(input_arr.view());
        let kers_sum = kernel.sum();
        assert!(arrays_almost_equal(cols_sum.view(), kers_sum.view()));
    }

    // `Kernel::diagonal` must match the diagonal of the raw dense / sparse matrices.
    #[test]
    fn test_kernel_diag() {
        let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
        let input_arr = Array2::from_shape_vec((10, 10), input_vec).unwrap();
        let method = KernelMethod::Linear;
        let input_diagonal = dense_from_fn(&input_arr, &method).diag().into_owned();
        let kernel = KernelView::params()
            .kind(KernelType::Dense)
            .method(method.clone())
            .transform(input_arr.view());
        let kers_diagonal = kernel.diagonal();
        assert!(arrays_almost_equal(
            input_diagonal.view(),
            kers_diagonal.view()
        ));
        // Assumes each row i of the sparse matrix stores its diagonal entry (unwrap panics otherwise).
        let input_diagonal: Vec<_> = sparse_from_fn(&input_arr, 3, &method, &BallTree)
            .outer_iterator()
            .enumerate()
            .map(|(i, row)| *row.get(i).unwrap())
            .collect();
        let input_diagonal = Array1::from_shape_vec(10, input_diagonal).unwrap();
        let kernel = KernelView::params()
            .kind(KernelType::Sparse(3))
            .method(method)
            .transform(input_arr.view());
        let kers_diagonal = kernel.diagonal();
        assert!(arrays_almost_equal(
            input_diagonal.view(),
            kers_diagonal.view()
        ));
    }

    // The `Transformer` impls for `Array2` / `ArrayView2` inputs must agree with `Kernel::new`.
    #[test]
    fn test_kernel_transform_from_array2() {
        let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
        let input = Array2::from_shape_vec((50, 2), input_vec).unwrap();
        check_kernel_from_array2_type(&input, KernelType::Dense);
        check_kernel_from_array2_type(&input, KernelType::Sparse(3));
        check_kernel_from_array_view_2_type(input.view(), KernelType::Dense);
        check_kernel_from_array_view_2_type(input.view(), KernelType::Sparse(3));
    }

    // The `Transformer` impls for dataset inputs must agree with `Kernel::new`.
    #[test]
    fn test_kernel_transform_from_dataset() {
        let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
        let input_arr = Array2::from_shape_vec((50, 2), input_vec).unwrap();
        let input = Dataset::from(input_arr);
        check_kernel_from_dataset_type(&input, KernelType::Dense);
        check_kernel_from_dataset_type(&input, KernelType::Sparse(3));
        check_kernel_from_dataset_view_type(&input.view(), KernelType::Dense);
        check_kernel_from_dataset_view_type(&input.view(), KernelType::Sparse(3));
    }

    // For every kernel method, compares `Kernel::new` against transforming `&DatasetBase<Array2, T>`.
    fn check_kernel_from_dataset_type<'a, L: 'a, T: AsTargets<Elem = L> + FromTargetArray<'a>>(
        input: &'a DatasetBase<Array2<f64>, T>,
        k_type: KernelType,
    ) {
        let methods = vec![
            KernelMethod::Linear,
            KernelMethod::Gaussian(0.1),
            KernelMethod::Polynomial(1., 2.),
        ];
        for method in methods {
            let kernel_ref = Kernel::new(
                input.records().view(),
                &Kernel::params_with_nn(KdTree)
                    .method(method.clone())
                    .kind(k_type.clone()),
            );
            let kernel_tr = Kernel::params()
                .kind(k_type.clone())
                .method(method.clone())
                .transform(input);
            match (&kernel_ref.inner, &kernel_tr.records().inner) {
                (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
                    assert!(kernels_almost_equal(m1, m2))
                }
                (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
                    assert!(kernels_almost_equal(m1, m2))
                }
                _ => panic!("Kernel inners must match!"),
            };
        }
    }

    // For every kernel method, compares `Kernel::new` against transforming
    // `&DatasetBase<ArrayView2, T>`.
    fn check_kernel_from_dataset_view_type<
        'a,
        L: 'a,
        T: AsTargets<Elem = L> + FromTargetArray<'a>,
    >(
        input: &'a DatasetBase<ArrayView2<'a, f64>, T>,
        k_type: KernelType,
    ) {
        let methods = vec![
            KernelMethod::Linear,
            KernelMethod::Gaussian(0.1),
            KernelMethod::Polynomial(1., 2.),
        ];
        for method in methods {
            let kernel_ref = Kernel::new(
                *input.records(),
                &Kernel::params_with_nn(KdTree)
                    .method(method.clone())
                    .kind(k_type.clone()),
            );
            let kernel_tr = Kernel::params()
                .kind(k_type.clone())
                .method(method.clone())
                .transform(input);
            match (&kernel_ref.inner, &kernel_tr.records().inner) {
                (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
                    assert!(kernels_almost_equal(m1, m2))
                }
                (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
                    assert!(kernels_almost_equal(m1, m2))
                }
                _ => panic!("Kernel inners must match!"),
            };
        }
    }

    // For every kernel method, compares `Kernel::new` against transforming the array's view.
    fn check_kernel_from_array2_type(input: &Array2<f64>, k_type: KernelType) {
        let methods = vec![
            KernelMethod::Linear,
            KernelMethod::Gaussian(0.1),
            KernelMethod::Polynomial(1., 2.),
        ];
        for method in methods {
            let kernel_ref = Kernel::new(
                input.view(),
                &Kernel::params_with_nn(KdTree)
                    .method(method.clone())
                    .kind(k_type.clone()),
            );
            let kernel_tr = Kernel::params()
                .kind(k_type.clone())
                .method(method.clone())
                .transform(input.view());
            match (&kernel_ref.inner, &kernel_tr.inner) {
                (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
                    assert!(kernels_almost_equal(m1, m2))
                }
                (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
                    assert!(kernels_almost_equal(m1, m2))
                }
                _ => panic!("Kernel inners must match!"),
            };
        }
    }

    // For every kernel method, compares `Kernel::new` against transforming an `ArrayView2`.
    fn check_kernel_from_array_view_2_type(input: ArrayView2<f64>, k_type: KernelType) {
        let methods = vec![
            KernelMethod::Linear,
            KernelMethod::Gaussian(0.1),
            KernelMethod::Polynomial(1., 2.),
        ];
        for method in methods {
            let kernel_ref = Kernel::new(
                input,
                &Kernel::params_with_nn(KdTree)
                    .method(method.clone())
                    .kind(k_type.clone()),
            );
            let kernel_tr = Kernel::params()
                .kind(k_type.clone())
                .method(method.clone())
                .transform(input);
            match (&kernel_ref.inner, &kernel_tr.inner) {
                (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
                    assert!(kernels_almost_equal(m1, m2))
                }
                (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
                    assert!(kernels_almost_equal(m1, m2))
                }
                _ => panic!("Kernel inners must match!"),
            };
        }
    }

    // Row-wise comparison within `f64::EPSILON`.
    // NOTE(review): `zip` stops at the shorter input, so shape mismatches are not detected.
    fn matrices_almost_equal(reference: ArrayView2<f64>, transformed: ArrayView2<f64>) -> bool {
        for (ref_row, tr_row) in reference
            .axis_iter(Axis(0))
            .zip(transformed.axis_iter(Axis(0)))
        {
            if !arrays_almost_equal(ref_row, tr_row) {
                return false;
            }
        }
        true
    }

    // Element-wise comparison within `f64::EPSILON` (`zip` ignores a length mismatch).
    fn arrays_almost_equal(reference: ArrayView1<f64>, transformed: ArrayView1<f64>) -> bool {
        for (ref_item, tr_item) in reference.iter().zip(transformed.iter()) {
            if !values_almost_equal(ref_item, tr_item) {
                return false;
            }
        }
        true
    }

    // Column-by-column comparison over the first `reference.size()` columns.
    fn kernels_almost_equal<K: Inner<Elem = f64>>(reference: &K, transformed: &K) -> bool {
        for i in 0..reference.size() {
            if !vecs_almost_equal(reference.column(i), transformed.column(i)) {
                return false;
            }
        }
        true
    }

    // Element-wise comparison within `f64::EPSILON` (`zip` ignores a length mismatch).
    fn vecs_almost_equal(reference: Vec<f64>, transformed: Vec<f64>) -> bool {
        for (ref_item, tr_item) in reference.iter().zip(transformed.iter()) {
            if !values_almost_equal(ref_item, tr_item) {
                return false;
            }
        }
        true
    }

    // Absolute-tolerance comparison; suitable here because both sides are computed the same way.
    fn values_almost_equal(v1: &f64, v2: &f64) -> bool {
        (v1 - v2).abs() <= f64::EPSILON
    }
}