use crate::{AppxDbscan, AppxDbscanParams, AppxDbscanParamsError, AppxDbscanValidParams, Dbscan};
use linfa::traits::Transformer;
use linfa::ParamGuard;
use linfa_datasets::generate;
use linfa_nn::distance::L2Dist;
use ndarray::{arr1, arr2, concatenate, s, Array1, Array2};
use ndarray_rand::rand::SeedableRng;
use ndarray_rand::rand_distr::Uniform;
use rand_xoshiro::Xoshiro256Plus;
use std::collections::HashMap;
/// Compile-time check that the DBSCAN model and parameter types are
/// `Send`, `Sync`, `Sized` and `Unpin`.
#[test]
fn autotraits() {
    // Instantiating this helper only compiles when `T` satisfies every bound,
    // so the test body has nothing to do at runtime.
    fn has_autotraits<T>()
    where
        T: Send + Sync + Sized + Unpin,
    {
    }

    // Models.
    has_autotraits::<Dbscan>();
    has_autotraits::<AppxDbscan>();
    // Unchecked and checked parameter sets.
    has_autotraits::<AppxDbscanParams<f64, L2Dist>>();
    has_autotraits::<AppxDbscanValidParams<f64, L2Dist>>();
}
/// Checks that approximate DBSCAN run with a very small slack labels a
/// dataset the same way exact DBSCAN does, up to a renaming of cluster ids.
///
/// The dataset is a set of blobs generated around four fixed centroids with
/// three hand-picked extra points appended at the end.
#[test]
fn appx_dbscan_parity() {
    // Fixed seed so the generated blobs are reproducible.
    let mut rng = Xoshiro256Plus::seed_from_u64(40);
    let min_points = 4;
    let tolerance = 0.8;
    let centroids = arr2(&[
        [-99.9, -88.3, 78.9],
        [-69.3, 90.1, -87.3],
        [20., 43.2, 10.2],
        [-1.3, 56.0, 98.9],
    ]);
    let outliers = arr2(&[[40.0, 55.5, 78.0], [-33.3, -1., 0.3], [-87.1, 0., 33.3]]);
    let clusters =
        generate::blobs_with_distribution(100, &centroids, Uniform::new(-1., 1.), &mut rng);
    let dataset = concatenate![ndarray::Axis(0), clusters, outliers];

    let appx_res = AppxDbscan::params(min_points)
        .tolerance(tolerance)
        .slack(1e-6)
        .transform(&dataset)
        .unwrap();
    let ex_res = Dbscan::params(min_points)
        .tolerance(tolerance)
        .check()
        .unwrap()
        .transform(&dataset);

    // Maps each exact label to the approximate label first observed next to
    // it. Unlabelled points (`None`) are encoded as -1 on both sides.
    let mut ex_appx_correspondence: HashMap<i64, i64> = HashMap::new();
    for (i, (ex_label, appx_label)) in ex_res.iter().zip(appx_res.iter()).enumerate() {
        println!("{:?} = {:?} {}", ex_label, appx_label, dataset.row(i));
        let ex_value = match ex_label {
            Some(value) => *value as i64,
            None => -1,
        };
        let appx_value = match appx_label {
            Some(value) => *value as i64,
            None => -1,
        };
        // A point left unlabelled by the exact algorithm must also be left
        // unlabelled by the approximate one.
        if ex_value == -1 {
            assert_eq!(appx_value, -1);
        }
        // Every later occurrence of an exact label must pair with the same
        // approximate label as its first occurrence.
        let expected_appx_val = ex_appx_correspondence.entry(ex_value).or_insert(appx_value);
        assert_eq!(*expected_appx_val, appx_value);
    }
}
/// Four points at the origin plus one point at (10, 10), clustered with a
/// minimum of 4 points and no explicit tolerance or slack: the far point must
/// stay unlabelled while the origin points share cluster 0.
#[test]
fn non_cluster_points() {
    let mut points: Array2<f64> = Array2::zeros((5, 2));
    // Move only the first observation away from the origin.
    points.row_mut(0).fill(10.0);

    let params = AppxDbscan::params(4).check().unwrap();
    let labels = params.transform(&points);

    let expected = arr1(&[None, Some(0), Some(0), Some(0), Some(0)]);
    assert_eq!(labels, expected);
}
/// A plus-shaped group of five points around the origin, with one extra point
/// at (0, 2). With a minimum of 5 points and tolerance 1.1 the extra point is
/// expected to stay unlabelled while the other five all land in cluster 0.
#[test]
fn test_border() {
    let points: Array2<f64> = arr2(&[
        [0.0, 2.0],
        [0.0, 0.0],
        [0.0, 1.0],
        [0.0, -1.0],
        [-1.0, 0.0],
        [1.0, 0.0],
    ]);

    let labels = AppxDbscan::params(5)
        .tolerance(1.1)
        .slack(1e-5)
        .transform(&points)
        .unwrap();

    // The first row is the point that sits beyond the plus shape.
    assert_eq!(labels[0], None);
    // Every remaining row belongs to the single cluster.
    for label in labels.iter().skip(1) {
        assert_eq!(label, &Some(0));
    }
}
/// Fifty closely spaced points on the diagonal between (0, 0) and (0.8, 0.8)
/// followed by fifty widely spaced diagonal points (25 positive, 25 negative).
/// The first fifty must form cluster 0 and the rest must stay unlabelled.
#[test]
fn test_outliers() {
    let mut points: Array2<f64> = Array2::zeros((100, 2));
    let dense = Array1::linspace(0.0, 0.8, 50);
    let far_positive = Array1::linspace(5., 1000., 25);
    let far_negative = Array1::linspace(-1000., -5., 25);

    // Both coordinates receive the same values, so every point lies on y = x.
    for axis in 0..2 {
        let mut column = points.column_mut(axis);
        column.slice_mut(s![0..50]).assign(&dense);
        column.slice_mut(s![50..75]).assign(&far_positive);
        column.slice_mut(s![75..100]).assign(&far_negative);
    }

    let labels = AppxDbscan::params(2)
        .tolerance(1.0)
        .slack(1e-4)
        .transform(&points)
        .unwrap();

    // Row `idx` is a dense point; row `idx + 50` is one of the far points.
    for idx in 0..50 {
        assert_eq!(labels[idx], Some(0));
        assert_eq!(labels[idx + 50], None);
    }
}
/// Forty points laid out along the four edges of the square spanning
/// (0, 0)..(8, 8), plus ten points all placed at (5, 5). The edge points must
/// share cluster 0 and the (5, 5) points must form cluster 1.
#[test]
fn nested_clusters() {
    let mut points: Array2<f64> = Array2::zeros((50, 2));
    let edge = Array1::linspace(0.0, 8.0, 10);

    // x coordinates: two horizontal edges, two vertical edges, then the
    // group at (5, 5).
    {
        let mut xs = points.column_mut(0);
        xs.slice_mut(s![0..10]).assign(&edge);
        xs.slice_mut(s![10..20]).assign(&edge);
        xs.slice_mut(s![20..30]).fill(0.0);
        xs.slice_mut(s![30..40]).fill(8.0);
        xs.slice_mut(s![40..]).fill(5.0);
    }
    // y coordinates for the same row groups.
    {
        let mut ys = points.column_mut(1);
        ys.slice_mut(s![0..10]).fill(0.0);
        ys.slice_mut(s![10..20]).fill(8.0);
        ys.slice_mut(s![20..30]).assign(&edge);
        ys.slice_mut(s![30..40]).assign(&edge);
        ys.slice_mut(s![40..]).fill(5.0);
    }

    let labels = AppxDbscan::params(2)
        .tolerance(1.0)
        .slack(1e-4)
        .transform(&points)
        .unwrap();

    assert!(labels.iter().take(40).all(|label| *label == Some(0)));
    assert!(labels.iter().skip(40).all(|label| *label == Some(1)));
}
/// Parameter validation must report a tolerance error when tolerance is 0.
#[test]
fn tolerance_cannot_be_zero() {
    let params = AppxDbscan::params(2).tolerance(0.0).slack(0.1);
    let outcome = params.check();
    assert!(matches!(outcome, Err(AppxDbscanParamsError::Tolerance)));
}
/// Parameter validation must report a slack error when slack is 0.
#[test]
fn slack_cannot_be_zero() {
    let params = AppxDbscan::params(2).tolerance(0.1).slack(0.0);
    let outcome = params.check();
    assert!(matches!(outcome, Err(AppxDbscanParamsError::Slack)));
}
/// Parameter validation must report a min-points error when only a single
/// point is requested.
#[test]
fn min_points_at_least_2() {
    let params = AppxDbscan::params(1).tolerance(0.1).slack(0.1);
    let outcome = params.check();
    assert!(matches!(outcome, Err(AppxDbscanParamsError::MinPoints)));
}
/// Parameter validation must report a tolerance error when tolerance is
/// negative.
#[test]
fn tolerance_should_be_positive() {
    let params = AppxDbscan::params(2).tolerance(-1.0).slack(0.1);
    let outcome = params.check();
    assert!(matches!(outcome, Err(AppxDbscanParamsError::Tolerance)));
}
/// Parameter validation must report a slack error when slack is negative.
#[test]
fn slack_should_be_positive() {
    let params = AppxDbscan::params(2).tolerance(0.1).slack(-1.0);
    let outcome = params.check();
    assert!(matches!(outcome, Err(AppxDbscanParamsError::Slack)));
}