#[cfg(feature = "stdsimd")]
use super::distsimd::*;
#[cfg(feature = "simdeez_f")]
use super::disteez::*;
use std::os::raw::*;
use num_traits::float::*;
/// Tags naming the kinds of distance known to this module.
///
/// Most variants carry a `String` payload.
/// NOTE(review): the payload is presumably a display name — confirm with the code that builds it.
/// Nothing in the visible part of this file constructs or matches on this enum,
/// which is why it is marked `allow(unused)`.
#[allow(unused)]
enum DistKind {
DistL1(String),
DistL2(String),
DistDot(String),
DistCosine(String),
DistHamming(String),
DistJaccard(String),
DistHellinger(String),
DistJeffreys(String),
DistJensenShannon(String),
// The three variants below carry no payload.
DistCFnPtr,
DistFn,
DistPtr,
DistLevenshtein(String),
DistNoDist(String),
}
/// Common interface implemented by every distance in this module.
///
/// `T` is the element type of the compared slices and must be `Send + Sync`.
pub trait Distance<T: Send + Sync> {
/// Computes the distance between `va` and `vb` and returns it as an `f32`.
fn eval(&self, va: &[T], vb: &[T]) -> f32;
}
/// Placeholder distance whose `eval` must never be called.
#[derive(Default, Copy, Clone)]
pub struct NoDist;
impl<T: Send + Sync> Distance<T> for NoDist {
/// Logs an error and then panics; the arguments are ignored.
///
/// # Panics
/// Unconditionally.
fn eval(&self, _va: &[T], _vb: &[T]) -> f32 {
log::error!("panic error : cannot call eval on NoDist");
panic!("cannot call distance with NoDist");
}
}
/// L1 (sum of absolute coordinate differences) distance.
#[derive(Default, Copy, Clone)]
pub struct DistL1;

// Generates `Distance<$ty>` for `DistL1`.
// Every coordinate is converted to `f32` *before* the subtraction, and the
// absolute differences are summed in `f32`.
// Panics when the two slices have different lengths.
macro_rules! implementL1Distance (
    ($ty:ty) => (
        impl Distance<$ty> for DistL1 {
            fn eval(&self, va: &[$ty], vb: &[$ty]) -> f32 {
                assert_eq!(va.len(), vb.len());
                va.iter()
                    .zip(vb)
                    .map(|(a, b)| (*a as f32 - *b as f32).abs())
                    .sum()
            }
        }
    )
);

implementL1Distance!(i32);
implementL1Distance!(f64);
implementL1Distance!(i64);
implementL1Distance!(u32);
implementL1Distance!(u16);
implementL1Distance!(u8);
/// `f32` implementation of [`DistL1`].
///
/// Which kernel runs is decided at compile time by the `simdeez_f` / `stdsimd`
/// features and, for `simdeez_f`, at run time by CPU feature detection.
/// Without either feature a scalar loop is used.
///
/// # Panics
/// The scalar paths panic when `va` and `vb` have different lengths.
/// NOTE(review): the SIMD kernels are defined outside this file; whether they
/// check lengths themselves is not visible here.
impl Distance<f32> for DistL1 {
    fn eval(&self, va: &[f32], vb: &[f32]) -> f32 {
        cfg_if::cfg_if! {
            if #[cfg(feature = "simdeez_f")] {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
                    if is_x86_feature_detected!("avx2") {
                        distance_l1_f32_simdeez(va,vb)
                    }
                    else {
                        assert_eq!(va.len(), vb.len());
                        va.iter().zip(vb.iter()).map(|t| (*t.0 - *t.1).abs()).sum()
                    }
                }
                #[cfg(any(target_arch = "aarch64"))] {
                    if std::arch::is_aarch64_feature_detected!("asimd") {
                        distance_l1_f32_simdeez(va,vb)
                    }
                    else {
                        assert_eq!(va.len(), vb.len());
                        va.iter().zip(vb.iter()).map(|t| (*t.0 - *t.1).abs()).sum()
                    }
                }
            }
            else if #[cfg(feature = "stdsimd")] {
                distance_l1_f32_simd(va,vb)
            }
            else {
                // Same precondition as the other scalar L1 paths: without it `zip`
                // would silently ignore the tail of the longer slice.
                assert_eq!(va.len(), vb.len());
                va.iter().zip(vb.iter()).map(|t| (*t.0 - *t.1 ).abs()).sum()
            }
        }
    }
}
/// L2 (Euclidean) distance.
#[derive(Default, Copy, Clone)]
pub struct DistL2;

// Generates `Distance<$ty>` for `DistL2`.
// Coordinates are converted to `f32` before subtracting; the squared
// differences are summed in `f32` and the square root of the sum is returned.
// Panics when the two slices have different lengths.
macro_rules! implementL2Distance (
    ($ty:ty) => (
        impl Distance<$ty> for DistL2 {
            fn eval(&self, va: &[$ty], vb: &[$ty]) -> f32 {
                assert_eq!(va.len(), vb.len());
                let squared: f32 = va
                    .iter()
                    .zip(vb)
                    .map(|(a, b)| {
                        let delta = *a as f32 - *b as f32;
                        delta * delta
                    })
                    .sum();
                squared.sqrt()
            }
        }
    )
);

implementL2Distance!(i32);
implementL2Distance!(f64);
implementL2Distance!(i64);
implementL2Distance!(u32);
implementL2Distance!(u16);
implementL2Distance!(u8);
/// Scalar Euclidean distance between two `f32` slices.
///
/// Coordinates are paired with `zip`, so extra elements of the longer slice
/// are ignored.
///
/// # Panics
/// Panics when the sum of squared differences is not `>= 0`, which is the
/// case when it is NaN.
#[allow(unused)]
fn scalar_l2_f32(va: &[f32], vb: &[f32]) -> f32 {
    let squared_norm: f32 = va
        .iter()
        .zip(vb)
        .map(|(x, y)| {
            let delta = *x - *y;
            delta * delta
        })
        .sum();
    assert!(squared_norm >= 0.);
    squared_norm.sqrt()
}
/// `f32` implementation of [`DistL2`].
///
/// The kernel is selected at compile time by the `simdeez_f` / `stdsimd`
/// features (plus run-time CPU detection for `simdeez_f`); every other case
/// goes through `scalar_l2_f32`.
impl Distance<f32> for DistL2 {
    fn eval(&self, va: &[f32], vb: &[f32]) -> f32 {
        cfg_if::cfg_if! {
            if #[cfg(feature = "simdeez_f")] {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                {
                    if is_x86_feature_detected!("avx2") {
                        distance_l2_f32_simdeez(va, vb)
                    } else {
                        scalar_l2_f32(va, vb)
                    }
                }
                #[cfg(any(target_arch = "aarch64"))]
                {
                    if std::arch::is_aarch64_feature_detected!("asimd") {
                        distance_l2_f32_simdeez(va, vb)
                    } else {
                        scalar_l2_f32(va, vb)
                    }
                }
            } else if #[cfg(feature = "stdsimd")] {
                distance_l2_f32_simd(va, vb)
            } else {
                scalar_l2_f32(va, vb)
            }
        }
    }
}
/// Cosine distance, `1 - cos(va, vb)`.
#[derive(Default, Copy, Clone)]
pub struct DistCosine;

// Generates `Distance<$ty>` for `DistCosine`.
//
// Both coordinates are widened to `f64` BEFORE being multiplied: multiplying in
// the element type first overflows for integer inputs (e.g. 300 * 300 as u16).
// The dot product and both squared norms are accumulated in `f64`.
//
// Returns 0 when either vector has a zero norm.
// Panics when the slices differ in length, or when the distance is below
// -0.00002 (a small negative value caused by rounding is clamped to 0).
macro_rules! implementCosDistance(
    ($ty:ty) => (
        impl Distance<$ty> for DistCosine {
            fn eval(&self, va: &[$ty], vb: &[$ty]) -> f32 {
                assert_eq!(va.len(), vb.len());
                let (dot, norm_a, norm_b) = va.iter().zip(vb.iter()).fold(
                    (0f64, 0f64, 0f64),
                    |acc, (a, b)| {
                        let (a, b) = (*a as f64, *b as f64);
                        (acc.0 + a * b, acc.1 + a * a, acc.2 + b * b)
                    },
                );
                if norm_a > 0. && norm_b > 0. {
                    let dist_unchecked = 1. - dot / (norm_a * norm_b).sqrt();
                    assert!(dist_unchecked >= -0.00002);
                    dist_unchecked.max(0.) as f32
                } else {
                    0.
                }
            }
        }
    )
);

implementCosDistance!(f32);
implementCosDistance!(f64);
implementCosDistance!(i64);
implementCosDistance!(i32);
implementCosDistance!(u16);
/// Dot-product distance, `1 - <va, vb>`.
///
/// NOTE(review): the assertions in this file suggest it is meant for
/// L2-normalized vectors (see `l2_normalize`) — confirm with callers.
#[derive(Default, Copy, Clone)]
pub struct DistDot;

// Generates `Distance<$ty>` for `DistDot` (not invoked in the visible part of this file).
//
// Coordinates are widened to `f64` before being multiplied so integer inputs
// cannot overflow; the dot product is accumulated in `f64`.
// Panics when the slices differ in length or when `1 - dot` is below -0.00002;
// a small negative value caused by rounding is clamped to 0.
#[allow(unused)]
macro_rules! implementDotDistance(
    ($ty:ty) => (
        impl Distance<$ty> for DistDot {
            fn eval(&self, va: &[$ty], vb: &[$ty]) -> f32 {
                assert_eq!(va.len(), vb.len());
                let dot: f64 = va
                    .iter()
                    .zip(vb.iter())
                    .map(|(a, b)| (*a as f64) * (*b as f64))
                    .sum();
                let dist = 1. - dot;
                assert!(dist >= -0.00002);
                dist.max(0.) as f32
            }
        }
    )
);
/// Scalar dot-product distance, `1 - <va, vb>`, for `f32` slices.
///
/// Coordinates are paired with `zip`, so extra elements of the longer slice
/// are ignored.
///
/// Rounding can push the dot product of two (nearly) identical normalized
/// vectors slightly above 1. Such a tiny negative distance is clamped to 0
/// instead of aborting, using the same tolerance as `DistCosine`.
///
/// # Panics
/// Panics when `1 - <va, vb>` is below `-0.00002` or is NaN.
#[allow(unused)]
fn scalar_dot_f32(va: &[f32], vb: &[f32]) -> f32 {
    let dot: f32 = va
        .iter()
        .zip(vb.iter())
        .map(|(a, b)| a * b)
        .fold(0., |acc, t| acc + t);
    let dist = 1. - dot;
    assert!(dist >= -0.00002);
    dist.max(0.)
}
/// `f32` implementation of [`DistDot`].
///
/// The kernel is selected at compile time by the `simdeez_f` / `stdsimd`
/// features (plus run-time CPU detection for `simdeez_f`); every other case
/// goes through `scalar_dot_f32`.
impl Distance<f32> for DistDot {
    fn eval(&self, va: &[f32], vb: &[f32]) -> f32 {
        cfg_if::cfg_if! {
            if #[cfg(feature = "simdeez_f")] {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                {
                    // avx2 and sse2 both dispatch to the same kernel.
                    if is_x86_feature_detected!("avx2") || is_x86_feature_detected!("sse2") {
                        distance_dot_f32_simdeez(va, vb)
                    }
                    else {
                        scalar_dot_f32(va, vb)
                    }
                }
                #[cfg(any(target_arch = "aarch64"))]
                {
                    if std::arch::is_aarch64_feature_detected!("asimd") {
                        distance_dot_f32_simdeez(va, vb)
                    }
                    else {
                        // Must be the dot fallback: the L2 one returns a different metric.
                        scalar_dot_f32(va, vb)
                    }
                }
            } else if #[cfg(feature = "stdsimd")] {
                distance_dot_f32_simd_iter(va,vb)
            }
            else {
                scalar_dot_f32(va, vb)
            }
        }
    }
}
/// Scales `va` in place so that its Euclidean norm becomes 1.
///
/// The slice is left untouched when its norm is not strictly positive
/// (all-zero input, or a NaN norm).
pub fn l2_normalize(va: &mut [f32]) {
    let norm = va.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0. {
        va.iter_mut().for_each(|x| *x /= norm);
    }
}
/// Hellinger distance, `sqrt(1 - sum_i sqrt(p_i * q_i))`.
#[derive(Default, Copy, Clone)]
pub struct DistHellinger;

// Generates `Distance<$ty>` for `DistHellinger`.
//
// The affinity `sum_i sqrt(p_i) * sqrt(q_i)` is accumulated in `f64`; the terms
// are summed as they are (not squared), matching the `f32` implementation.
// As in that implementation, a radicand that is slightly negative because of
// rounding is clamped to 0.
// Panics when the slices differ in length or when `1 - affinity` is below -0.000001.
macro_rules! implementHellingerDistance (
    ($ty:ty) => (
        impl Distance<$ty> for DistHellinger {
            fn eval(&self, va: &[$ty], vb: &[$ty]) -> f32 {
                assert_eq!(va.len(), vb.len());
                let affinity: f64 = va
                    .iter()
                    .zip(vb.iter())
                    .map(|(p, q)| ((*p).sqrt() * (*q).sqrt()) as f64)
                    .sum();
                let radicand = 1. - affinity;
                assert!(radicand >= -0.000001);
                radicand.max(0.).sqrt() as f32
            }
        }
    )
);

implementHellingerDistance!(f64);
/// `f32` implementation of [`DistHellinger`].
///
/// With the `simdeez_f` feature and a supporting CPU the SIMD kernel is used;
/// otherwise the affinity `sum_i sqrt(p_i * q_i)` is accumulated in a scalar loop.
///
/// # Panics
/// The scalar path panics when `1 - affinity` is below `-0.000001`; smaller
/// negative values are clamped to 0 before the square root.
impl Distance<f32> for DistHellinger {
    fn eval(&self, va: &[f32], vb: &[f32]) -> f32 {
        #[cfg(feature = "simdeez_f")]
        {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if is_x86_feature_detected!("avx2") {
                    return distance_hellinger_f32_simdeez(va, vb);
                }
            }
            #[cfg(any(target_arch = "aarch64"))]
            {
                if std::arch::is_aarch64_feature_detected!("asimd") {
                    return distance_hellinger_f32_simdeez(va, vb);
                }
            }
        }
        // Scalar fallback.
        let mut affinity: f32 = 0.;
        for (p, q) in va.iter().zip(vb.iter()) {
            affinity += (*p * *q).sqrt();
        }
        let radicand = 1. - affinity;
        assert!(radicand >= -0.000001);
        radicand.max(0.).sqrt()
    }
}
/// Jeffreys divergence, `sum_i (p_i - q_i) * ln(p_i / q_i)`.
#[derive(Default, Copy, Clone)]
pub struct DistJeffreys;

/// Lower bound applied to each coordinate before the ratio and logarithm are taken.
pub const M_MIN: f32 = 1.0e-30;

// Generates `Distance<$ty>` for `DistJeffreys` (the body uses `f64` explicitly).
//
// Each coordinate is floored at `M_MIN` before the ratio is formed; the terms
// are summed as they are (not squared), matching the `f32` implementation.
// Coordinates are paired with `zip`, so extra elements of the longer slice are ignored.
macro_rules! implementJeffreysDistance (
    ($ty:ty) => (
        impl Distance<$ty> for DistJeffreys {
            fn eval(&self, va: &[$ty], vb: &[$ty]) -> f32 {
                let floor = M_MIN as f64;
                let dist = va
                    .iter()
                    .zip(vb.iter())
                    .map(|(p, q)| (*p - *q) * ((*p).max(floor) / (*q).max(floor)).ln())
                    .fold(0., |acc, t| acc + t);
                dist as f32
            }
        }
    )
);

implementJeffreysDistance!(f64);
/// `f32` implementation of [`DistJeffreys`].
///
/// With the `simdeez_f` feature and a supporting CPU the SIMD kernel is used;
/// otherwise the terms `(p - q) * ln(max(p, M_MIN) / max(q, M_MIN))` are
/// summed in a scalar loop. Coordinates are paired with `zip`, so extra
/// elements of the longer slice are ignored.
impl Distance<f32> for DistJeffreys {
    fn eval(&self, va: &[f32], vb: &[f32]) -> f32 {
        #[cfg(feature = "simdeez_f")]
        {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if is_x86_feature_detected!("avx2") {
                    return distance_jeffreys_f32_simdeez(va, vb);
                }
            }
            #[cfg(any(target_arch = "aarch64"))]
            {
                if std::arch::is_aarch64_feature_detected!("asimd") {
                    return distance_jeffreys_f32_simdeez(va, vb);
                }
            }
        }
        // Scalar fallback.
        let mut total: f32 = 0.;
        for (p, q) in va.iter().zip(vb.iter()) {
            let log_ratio = ((*p).max(M_MIN) / (*q).max(M_MIN)).ln();
            total += (*p - *q) * log_ratio;
        }
        total
    }
}
/// Jensen-Shannon distance (square root of the Jensen-Shannon divergence).
#[derive(Default, Copy, Clone)]
pub struct DistJensenShannon;

// Generates `Distance<$ty>` for `DistJensenShannon`.
//
// For every coordinate pair `(p, q)` with midpoint `m = (p + q) / 2`, the terms
// `p * ln(p / m)` and `q * ln(q / m)` are added, each only when its coordinate
// is strictly positive. The result is `sqrt(sum / 2)`.
// Panics when the slices differ in length.
macro_rules! implementDistJensenShannon (
    ($ty:ty) => (
        impl Distance<$ty> for DistJensenShannon {
            fn eval(&self, va: &[$ty], vb: &[$ty]) -> f32 {
                assert_eq!(va.len(), vb.len());
                let mut divergence: $ty = 0.;
                for (&p, &q) in va.iter().zip(vb.iter()) {
                    let midpoint = 0.5 * (p + q);
                    if p > 0. {
                        divergence += p * (p / midpoint).ln();
                    }
                    if q > 0. {
                        divergence += q * (q / midpoint).ln();
                    }
                }
                (0.5 * divergence).sqrt() as f32
            }
        }
    )
);

implementDistJensenShannon!(f64);
implementDistJensenShannon!(f32);
/// Normalized Hamming distance: the fraction of positions at which two slices differ.
#[derive(Default, Copy, Clone)]
pub struct DistHamming;

// Generates `Distance<$ty>` for `DistHamming`.
// Counts the differing positions and divides by the slice length, in `f32`.
// Panics when the slices differ in length.
macro_rules! implementHammingDistance (
    ($ty:ty) => (
        impl Distance<$ty> for DistHamming {
            fn eval(&self, va: &[$ty], vb: &[$ty]) -> f32 {
                assert_eq!(va.len(), vb.len());
                let mismatches = va.iter().zip(vb).filter(|(a, b)| a != b).count();
                mismatches as f32 / va.len() as f32
            }
        }
    )
);
/// `i32` implementation of [`DistHamming`].
///
/// With the `simdeez_f` feature and a supporting CPU the SIMD kernel is used;
/// otherwise differing positions are counted in a scalar pass and divided by
/// the slice length.
///
/// # Panics
/// The scalar path panics when the slices differ in length.
impl Distance<i32> for DistHamming {
    fn eval(&self, va: &[i32], vb: &[i32]) -> f32 {
        #[cfg(feature = "simdeez_f")]
        {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if is_x86_feature_detected!("avx2") {
                    return distance_hamming_i32_simdeez(va, vb);
                }
            }
            #[cfg(any(target_arch = "aarch64"))]
            {
                if std::arch::is_aarch64_feature_detected!("asimd") {
                    return distance_hamming_i32_simdeez(va, vb);
                }
            }
        }
        // Scalar fallback.
        assert_eq!(va.len(), vb.len());
        let mismatches = va.iter().zip(vb).filter(|(a, b)| a != b).count();
        mismatches as f32 / va.len() as f32
    }
}
/// `f64` implementation of [`DistHamming`]: the fraction of positions whose
/// values compare unequal, computed in `f64` and returned as `f32`.
///
/// # Panics
/// Panics when the slices differ in length.
impl Distance<f64> for DistHamming {
    fn eval(&self, va: &[f64], vb: &[f64]) -> f32 {
        assert_eq!(va.len(), vb.len());
        let mismatches = va.iter().zip(vb).filter(|(a, b)| a != b).count();
        let fraction = mismatches as f64 / va.len() as f64;
        fraction as f32
    }
}
/// `f32` implementation of [`DistHamming`].
///
/// With the `stdsimd` feature the call is forwarded to a SIMD kernel;
/// otherwise differing positions are counted and divided by the slice length.
/// NOTE(review): the SIMD kernel is named `distance_jaccard_f32_16_simd` —
/// confirm that it returns the normalized mismatch count.
///
/// # Panics
/// The scalar path panics when the slices differ in length.
impl Distance<f32> for DistHamming {
    fn eval(&self, va: &[f32], vb: &[f32]) -> f32 {
        cfg_if::cfg_if! {
            if #[cfg(feature = "stdsimd")] {
                distance_jaccard_f32_16_simd(va, vb)
            }
            else {
                assert_eq!(va.len(), vb.len());
                let mismatches = va.iter().zip(vb).filter(|(a, b)| a != b).count();
                (mismatches as f64 / va.len() as f64) as f32
            }
        }
    }
}
/// `u32` implementation of [`DistHamming`], only built with the `stdsimd` feature.
/// Forwards to the SIMD kernel.
/// NOTE(review): the kernel is named "jaccard" — confirm it computes the Hamming fraction.
#[cfg(feature = "stdsimd")]
impl Distance<u32> for DistHamming {
    fn eval(&self, va: &[u32], vb: &[u32]) -> f32 {
        distance_jaccard_u32_16_simd(va, vb)
    }
}
/// `u64` implementation of [`DistHamming`], only built with the `stdsimd` feature.
/// Forwards to the SIMD kernel.
/// NOTE(review): the kernel is named "jaccard" — confirm it computes the Hamming fraction.
#[cfg(feature = "stdsimd")]
impl Distance<u64> for DistHamming {
    fn eval(&self, va: &[u64], vb: &[u64]) -> f32 {
        distance_jaccard_u64_8_simd(va, vb)
    }
}
/// `u16` implementation of [`DistHamming`], only built with the `stdsimd` feature.
/// Forwards to the SIMD kernel.
/// NOTE(review): the kernel is named "jaccard" — confirm it computes the Hamming fraction.
#[cfg(feature = "stdsimd")]
impl Distance<u16> for DistHamming {
    fn eval(&self, va: &[u16], vb: &[u16]) -> f32 {
        distance_jaccard_u16_32_simd(va, vb)
    }
}
// Scalar Hamming implementations generated by `implementHammingDistance!`.
// `u8` and `i16` are always generated by the macro.
implementHammingDistance!(u8);
// `u16`, `u32` and `u64` use the macro only when `stdsimd` is disabled;
// with that feature they have dedicated impls earlier in this file.
#[cfg(not(feature = "stdsimd"))]
implementHammingDistance!(u16);
#[cfg(not(feature = "stdsimd"))]
implementHammingDistance!(u32);
#[cfg(not(feature = "stdsimd"))]
implementHammingDistance!(u64);
implementHammingDistance!(i16);
/// Weighted Jaccard distance, `1 - sum_i min(a_i, b_i) / sum_i max(a_i, b_i)`.
#[derive(Default, Copy, Clone)]
pub struct DistJaccard;

// Generates `Distance<$ty>` for `DistJaccard`.
//
// The coordinate-wise maxima and minima are accumulated as `u64`.
// Returns 0 when the sum of maxima is 0. Coordinates are paired with `zip`,
// so extra elements of the longer slice are ignored.
macro_rules! implementJaccardDistance (
    ($ty:ty) => (
        impl Distance<$ty> for DistJaccard {
            fn eval(&self, va: &[$ty], vb: &[$ty]) -> f32 {
                let mut sum_of_max = 0u64;
                let mut sum_of_min = 0u64;
                for (a, b) in va.iter().zip(vb.iter()) {
                    let (larger, smaller) = if a > b { (*a, *b) } else { (*b, *a) };
                    sum_of_max += larger as u64;
                    sum_of_min += smaller as u64;
                }
                if sum_of_max == 0 {
                    return 0.;
                }
                let dist = 1. - (sum_of_min as f64) / (sum_of_max as f64);
                assert!(dist >= 0.);
                dist as f32
            }
        }
    )
);

implementJaccardDistance!(u8);
implementJaccardDistance!(u16);
implementJaccardDistance!(u32);
/// Levenshtein (edit) distance between two `u16` sequences.
#[derive(Default, Copy, Clone)]
pub struct DistLevenshtein;

impl Distance<u16> for DistLevenshtein {
    /// Returns the number of single-element insertions, deletions and
    /// substitutions needed to turn `a` into `b`, as an `f32`.
    ///
    /// Only one row of the dynamic-programming table is kept; it is sized by
    /// the shorter of the two inputs.
    fn eval(&self, a: &[u16], b: &[u16]) -> f32 {
        // Work with the longer sequence on the outer loop.
        let (long, short) = if a.len() < b.len() { (b, a) } else { (a, b) };
        if long.is_empty() {
            return short.len() as f32;
        }
        if short.is_empty() {
            return long.len() as f32;
        }

        let width = short.len() + 1;
        // row[j] = distance between the processed prefix of `long` and `short[..j]`.
        let mut row: Vec<usize> = (0..width).collect();
        for (i, item_long) in long.iter().enumerate() {
            // `diagonal` holds the previous row's value one column to the left.
            let mut diagonal = row[0];
            row[0] = i + 1;
            for (j, item_short) in short.iter().enumerate() {
                let above = row[j + 1];
                let substitution = diagonal + if item_long == item_short { 0 } else { 1 };
                let deletion = above + 1;
                let insertion = row[j] + 1;
                row[j + 1] = deletion.min(insertion.min(substitution));
                diagonal = above;
            }
        }
        row[width - 1] as f32
    }
}
/// Signature of a C distance callback: two element pointers and a length,
/// returning the distance as `f32`.
type DistCFnPtr<T> = extern "C" fn(*const T, *const T, len: c_ulonglong) -> f32;

/// Distance whose computation is delegated to a C function pointer.
pub struct DistCFFI<T: Copy + Clone + Sized + Send + Sync> {
    dist_function: DistCFnPtr<T>,
}

impl<T: Copy + Clone + Sized + Send + Sync> DistCFFI<T> {
    /// Wraps the C callback `f`.
    pub fn new(f: DistCFnPtr<T>) -> Self {
        Self { dist_function: f }
    }
}
impl<T: Copy + Clone + Sized + Send + Sync> Distance<T> for DistCFFI<T> {
    /// Calls the wrapped C function with pointers to both slices and their
    /// common length, and returns its result (also logged at trace level).
    ///
    /// # Panics
    /// Panics when `va` and `vb` have different lengths. The callback receives
    /// a single length, so with a shorter `vb` it could read past the end of
    /// that slice.
    fn eval(&self, va: &[T], vb: &[T]) -> f32 {
        assert_eq!(va.len(), vb.len());
        let len = va.len();
        let ptr_a = va.as_ptr();
        let ptr_b = vb.as_ptr();
        let dist = (self.dist_function)(ptr_a, ptr_b, len as c_ulonglong);
        log::trace!(
            "DistCFFI dist_function_ptr {:?} returning {:?} ",
            self.dist_function,
            dist
        );
        dist
    }
}
/// Distance computed by a boxed closure.
#[allow(clippy::type_complexity)]
pub struct DistFn<T: Copy + Clone + Sized + Send + Sync> {
    dist_function: Box<dyn Fn(&[T], &[T]) -> f32 + Send + Sync>,
}

#[allow(clippy::type_complexity)]
impl<T: Copy + Clone + Sized + Send + Sync> DistFn<T> {
    /// Wraps the boxed closure `f`.
    pub fn new(f: Box<dyn Fn(&[T], &[T]) -> f32 + Send + Sync>) -> Self {
        Self { dist_function: f }
    }
}

impl<T: Copy + Clone + Sized + Send + Sync> Distance<T> for DistFn<T> {
    /// Forwards both slices to the wrapped closure and returns its result.
    fn eval(&self, va: &[T], vb: &[T]) -> f32 {
        let closure = &self.dist_function;
        closure(va, vb)
    }
}
/// Distance computed by a plain function pointer returning any `Float` type `F`.
#[derive(Copy, Clone)]
pub struct DistPtr<T: Copy + Clone + Sized + Send + Sync, F: Float> {
    dist_function: fn(&[T], &[T]) -> F,
}

impl<T: Copy + Clone + Sized + Send + Sync, F: Float> DistPtr<T, F> {
    /// Wraps the function pointer `f`.
    pub fn new(f: fn(&[T], &[T]) -> F) -> Self {
        Self { dist_function: f }
    }
}

impl<T: Copy + Clone + Sized + Send + Sync, F: Float> Distance<T> for DistPtr<T, F> {
    /// Calls the wrapped function and converts its result to `f32`.
    ///
    /// # Panics
    /// Panics when `to_f32` returns `None` for the computed value.
    fn eval(&self, va: &[T], vb: &[T]) -> f32 {
        let raw: F = (self.dist_function)(va, vb);
        raw.to_f32().unwrap()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Initializes `env_logger` in test mode. The result of `try_init` is
    /// ignored, so calling this from several tests is harmless.
    fn init_log() -> u64 {
        let mut builder = env_logger::Builder::from_default_env();
        let _ = builder.is_test(true).try_init();
        println!("\n ************** initializing logger *****************\n");
        1
    }

    /// `DistL1` is reachable both through the trait path and as a method,
    /// for an integer and for the `f32` implementation.
    #[test]
    fn test_access_to_dist_l1() {
        let distl1 = DistL1;
        let v1: Vec<i32> = vec![1, 2, 3];
        let v2: Vec<i32> = vec![2, 2, 3];
        let d1 = Distance::eval(&distl1, &v1, &v2);
        assert_eq!(d1, 1_f32);
        let v3: Vec<f32> = vec![1., 2., 3.];
        let v4: Vec<f32> = vec![2., 2., 3.];
        let d2 = distl1.eval(&v3, &v4);
        assert_eq!(d2, 1_f32);
    }

    /// Informational: reports whether the CPU running the tests has AVX2.
    #[test]
    fn have_avx2() {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if is_x86_feature_detected!("avx2") {
                println!("I have avx2");
            } else {
                println!(" ************ I DO NOT have avx2 ***************");
            }
        }
    }

    /// Informational: reports whether the CPU running the tests has AVX-512F.
    #[test]
    fn have_avx512f_x86() {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if is_x86_feature_detected!("avx512f") {
                println!("have_avx512f_x86 test : I have avx512f");
            } else {
                println!(
                    "have_avx512f_x86 test : ************ I DO NOT have avx512f ***************"
                );
            }
        }
    }

    /// Informational: reports whether the CPU running the tests has ASIMD.
    #[test]
    fn have_asimd_aarch64() {
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("asimd") {
                println!("have_asimd_aarch64 test : I have asimd");
            } else {
                println!(
                    "have_asimd_aarch64 test : ************ I DO NOT have asimd ***************"
                );
            }
        }
    }

    /// Informational: reports whether the CPU running the tests has SSE2.
    #[test]
    fn have_sse2() {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if is_x86_feature_detected!("sse2") {
                println!("I have sse2");
            } else {
                println!(" ************ I DO NOT have SSE2 ***************");
            }
        }
    }

    /// `DistCosine` on orthogonal integer vectors, then on `f32` vectors
    /// compared with a straightforward loop.
    #[test]
    fn test_access_to_dist_cos() {
        let distcos = DistCosine;
        let v1: Vec<i32> = vec![1, -1, 1];
        let v2: Vec<i32> = vec![2, 1, -1];
        let d1 = Distance::eval(&distcos, &v1, &v2);
        assert_eq!(d1, 1_f32);
        let v1: Vec<f32> = vec![1.234, -1.678, 1.367];
        let v2: Vec<f32> = vec![4.234, -6.678, 10.367];
        let d1 = Distance::eval(&distcos, &v1, &v2);
        let mut normv1 = 0.;
        let mut normv2 = 0.;
        let mut prod = 0.;
        for i in 0..v1.len() {
            prod += v1[i] * v2[i];
            normv1 += v1[i] * v1[i];
            normv2 += v2[i] * v2[i];
        }
        let dcos = 1. - prod / (normv1 * normv2).sqrt();
        println!("dist cos avec macro = {:?} , avec for {:?}", d1, dcos);
        assert!((d1 - dcos).abs() < 1.0e-5);
    }

    /// `DistDot` on pre-normalized vectors must agree with the cosine
    /// distance of the raw vectors.
    #[test]
    fn test_dot_distances() {
        let mut v1: Vec<f32> = vec![1.234, -1.678, 1.367];
        let mut v2: Vec<f32> = vec![4.234, -6.678, 10.367];
        let mut normv1 = 0.;
        let mut normv2 = 0.;
        let mut prod = 0.;
        for i in 0..v1.len() {
            prod += v1[i] * v2[i];
            normv1 += v1[i] * v1[i];
            normv2 += v2[i] * v2[i];
        }
        let dcos = 1. - prod / (normv1 * normv2).sqrt();
        l2_normalize(&mut v1);
        l2_normalize(&mut v2);
        println!(" after normalisation v1 = {:?}", v1);
        let dot = DistDot.eval(&v1, &v2);
        println!(
            "dot cos avec prenormalisation = {:?} , avec for {:?}",
            dot, dcos
        );
        assert!((dot - dcos).abs() < 1.0e-5);
    }

    /// `DistL1` for `f32` against a scalar reference computation.
    #[test]
    fn test_l1() {
        init_log();
        let va: Vec<f32> = vec![1.234, -1.678, 1.367, 1.234, -1.678, 1.367];
        let vb: Vec<f32> = vec![4.234, -6.678, 10.367, 1.234, -1.678, 1.367];
        let dist = DistL1.eval(&va, &vb);
        let dist_check = va
            .iter()
            .zip(vb.iter())
            .map(|t| (*t.0 - *t.1).abs())
            .sum::<f32>();
        log::info!(" dist : {:.5e} dist_check : {:.5e}", dist, dist_check);
        assert!((dist - dist_check).abs() / dist_check < 1.0e-5);
    }

    /// `DistJaccard` on `u16`: sum of minima 11, sum of maxima 16.
    #[test]
    fn test_jaccard_u16() {
        let v1: Vec<u16> = vec![1, 2, 1, 4, 3];
        let v2: Vec<u16> = vec![2, 2, 1, 5, 6];
        let dist = DistJaccard.eval(&v1, &v2);
        println!("dist jaccard = {:?}", dist);
        assert_eq!(dist, 1. - 11. / 16.);
    }

    /// `DistLevenshtein` on a few equal-length sequences, in both argument orders.
    #[test]
    fn test_levenshtein() {
        let mut v1: Vec<u16> = vec![1, 2, 3, 4];
        let mut v2: Vec<u16> = vec![1, 2, 3, 3];
        let mut dist = DistLevenshtein.eval(&v1, &v2);
        println!("dist levenshtein = {:?}", dist);
        assert_eq!(dist, 1.0);
        v1 = vec![1, 2, 3, 4];
        v2 = vec![1, 2, 3, 4];
        dist = DistLevenshtein.eval(&v1, &v2);
        println!("dist levenshtein = {:?}", dist);
        assert_eq!(dist, 0.0);
        v1 = vec![1, 1, 1, 4];
        v2 = vec![1, 2, 3, 4];
        dist = DistLevenshtein.eval(&v1, &v2);
        println!("dist levenshtein = {:?}", dist);
        assert_eq!(dist, 2.0);
        v2 = vec![1, 1, 1, 4];
        v1 = vec![1, 2, 3, 4];
        dist = DistLevenshtein.eval(&v1, &v2);
        println!("dist levenshtein = {:?}", dist);
        assert_eq!(dist, 2.0);
    }

    /// C-ABI distance used by `test_dist_ext_float`: sum of the square roots
    /// of the absolute coordinate differences.
    extern "C" fn dist_func_float(va: *const f32, vb: *const f32, len: c_ulonglong) -> f32 {
        let mut dist: f32 = 0.;
        // SAFETY: the callers in this module pass pointers obtained from slices
        // that hold at least `len` elements each.
        let sa = unsafe { std::slice::from_raw_parts(va, len as usize) };
        // SAFETY: same invariant as above, for `vb`.
        let sb = unsafe { std::slice::from_raw_parts(vb, len as usize) };
        for i in 0..len {
            dist += (sa[i as usize] - sb[i as usize]).abs().sqrt();
        }
        dist
    }

    /// `DistCFFI` must return exactly what a direct call to the C function returns.
    #[test]
    fn test_dist_ext_float() {
        let va: Vec<f32> = vec![1., 2., 3.];
        let vb: Vec<f32> = vec![1., 2., 3.];
        println!("in test_dist_ext_float");
        let dist1 = dist_func_float(va.as_ptr(), vb.as_ptr(), va.len() as c_ulonglong);
        println!("test_dist_ext_float computed : {:?}", dist1);
        let mydist = DistCFFI::<f32>::new(dist_func_float);
        let dist2 = mydist.eval(&va, &vb);
        assert_eq!(dist1, dist2);
    }

    /// `DistFn` wrapping a closure that captures a weight vector.
    #[test]
    fn test_my_closure() {
        let weight = [0.1, 0.8, 0.1];
        let my_fn = move |va: &[f32], vb: &[f32]| -> f32 {
            let mut dist: f32 = 0.;
            for i in 0..va.len() {
                dist += weight[i] * (va[i] - vb[i]).abs();
            }
            dist
        };
        let my_boxed_f = Box::new(my_fn);
        let my_boxed_dist = DistFn::<f32>::new(my_boxed_f);
        let va: Vec<f32> = vec![1., 2., 3.];
        let vb: Vec<f32> = vec![2., 2., 4.];
        let dist = my_boxed_dist.eval(&va, &vb);
        println!("test_my_closure computed : {:?}", dist);
        assert_eq!(dist, 0.2);
    }

    /// `DistHellinger` for `f32` against a closed-form value.
    #[test]
    fn test_hellinger() {
        init_log();
        let length = 9;
        let mut p_data = Vec::with_capacity(length);
        let mut q_data = Vec::with_capacity(length);
        for _ in 0..length {
            p_data.push(1. / length as f32);
            q_data.push(1. / length as f32);
        }
        p_data[0] -= 1. / (2 * length) as f32;
        p_data[1] += 1. / (2 * length) as f32;
        let dist = DistHellinger.eval(&p_data, &q_data);
        let dist_exact_fn = |n: usize| -> f32 {
            let d1 = (4. - (6_f32).sqrt() - (2_f32).sqrt()) / n as f32;
            d1.sqrt() / (2_f32).sqrt()
        };
        let dist_exact = dist_exact_fn(length);
        log::info!("dist computed {:?} dist exact{:?} ", dist, dist_exact);
        println!("dist computed {:?} , dist exact {:?} ", dist, dist_exact);
        assert!((dist - dist_exact).abs() < 1.0e-5);
    }

    /// `DistJeffreys` for `f32` against a scalar reference loop.
    #[test]
    fn test_jeffreys() {
        init_log();
        let length = 19;
        let mut p_data: Vec<f32> = Vec::with_capacity(length);
        let mut q_data: Vec<f32> = Vec::with_capacity(length);
        for _ in 0..length {
            p_data.push(1. / length as f32);
            q_data.push(1. / length as f32);
        }
        p_data[0] -= 1. / (2 * length) as f32;
        p_data[1] += 1. / (2 * length) as f32;
        q_data[10] += 1. / (2 * length) as f32;
        let dist_eval = DistJeffreys.eval(&p_data, &q_data);
        let mut dist_test = 0.;
        for i in 0..length {
            dist_test +=
                (p_data[i] - q_data[i]) * (p_data[i].max(M_MIN) / q_data[i].max(M_MIN)).ln();
        }
        log::info!("dist eval {:?} dist test{:?} ", dist_eval, dist_test);
        println!("dist eval {:?} , dist test {:?} ", dist_eval, dist_test);
        assert!(dist_test >= 0.);
        assert!((dist_eval - dist_test).abs() < 1.0e-5);
    }

    /// `DistJensenShannon` must cope with zero coordinates and return a
    /// finite, non-negative value.
    #[test]
    fn test_jensenshannon() {
        init_log();
        let length = 19;
        let mut p_data: Vec<f32> = Vec::with_capacity(length);
        let mut q_data: Vec<f32> = Vec::with_capacity(length);
        for _ in 0..length {
            p_data.push(1. / length as f32);
            q_data.push(1. / length as f32);
        }
        p_data[0] -= 1. / (2 * length) as f32;
        p_data[1] += 1. / (2 * length) as f32;
        q_data[10] += 1. / (2 * length) as f32;
        p_data[12] = 0.;
        q_data[12] = 0.;
        let dist_eval = DistJensenShannon.eval(&p_data, &q_data);
        log::info!("dist eval {:?} ", dist_eval);
        println!("dist eval {:?} ", dist_eval);
        assert!(dist_eval.is_finite() && dist_eval >= 0.);
    }

    #[allow(unused)]
    use rand::distr::{Distribution, Uniform};

    /// `DistHamming` for `f64` on random vectors whose first half is shared.
    /// A mismatch fails the test through an assertion, so the remaining tests
    /// of the harness still run.
    #[test]
    fn test_hamming_f64() {
        init_log();
        let size_test = 500;
        let fmax: f64 = 3.;
        let mut rng = rand::rng();
        for i in 300..size_test {
            let between = Uniform::<f64>::try_from(-fmax..fmax).unwrap();
            let va: Vec<f64> = (0..i).map(|_| between.sample(&mut rng)).collect();
            let mut vb: Vec<f64> = (0..i).map(|_| between.sample(&mut rng)).collect();
            vb[..(i / 2)].copy_from_slice(&va[..(i / 2)]);
            let easy_dist: u32 = va
                .iter()
                .zip(vb.iter())
                .map(|(a, b)| if a != b { 1 } else { 0 })
                .sum();
            let h_dist = DistHamming.eval(&va, &vb);
            let easy_dist = easy_dist as f32 / va.len() as f32;
            let j_exact = ((i / 2) as f32) / (i as f32);
            log::debug!(
                "test size {:?}  HammingDist {:.3e} easy : {:.3e} exact : {:.3e} ",
                i,
                h_dist,
                easy_dist,
                j_exact
            );
            assert!(
                (easy_dist - h_dist).abs() <= 1.0e-5,
                " jhamming = {:?} , j_easy = {:?}\nva = {:?}\nvb = {:?}",
                h_dist,
                easy_dist,
                va,
                vb
            );
            assert!(
                (j_exact - h_dist).abs() <= 1. / i as f32 + 1.0E-5,
                " jhamming = {:?} , jexact = {:?}, j_easy : {:?}\nva = {:?}\nvb = {:?}",
                h_dist,
                j_exact,
                easy_dist,
                va,
                vb
            );
        }
    }

    /// `DistHamming` for `f32` on random vectors whose first half is shared.
    /// A mismatch fails the test through an assertion, so the remaining tests
    /// of the harness still run.
    #[test]
    fn test_hamming_f32() {
        init_log();
        let size_test = 500;
        let fmax: f32 = 3.;
        let mut rng = rand::rng();
        for i in 300..size_test {
            let between = Uniform::<f32>::try_from(-fmax..fmax).unwrap();
            let va: Vec<f32> = (0..i).map(|_| between.sample(&mut rng)).collect();
            let mut vb: Vec<f32> = (0..i).map(|_| between.sample(&mut rng)).collect();
            vb[..(i / 2)].copy_from_slice(&va[..(i / 2)]);
            let easy_dist: u32 = va
                .iter()
                .zip(vb.iter())
                .map(|(a, b)| if a != b { 1 } else { 0 })
                .sum();
            let h_dist = DistHamming.eval(&va, &vb);
            let easy_dist = easy_dist as f32 / va.len() as f32;
            let j_exact = ((i / 2) as f32) / (i as f32);
            log::debug!(
                "test size {:?}  HammingDist {:.3e} easy : {:.3e} exact : {:.3e} ",
                i,
                h_dist,
                easy_dist,
                j_exact
            );
            assert!(
                (easy_dist - h_dist).abs() <= 1.0e-5,
                " jhamming = {:?} , jexact = {:?}, j_easy : {:?}\nva = {:?}\nvb = {:?}",
                h_dist,
                j_exact,
                easy_dist,
                va,
                vb
            );
            assert!(
                (j_exact - h_dist).abs() <= 1. / i as f32 + 1.0E-5,
                " jhamming = {:?} , jexact = {:?}, j_easy : {:?}\nva = {:?}\nvb = {:?}",
                h_dist,
                j_exact,
                easy_dist,
                va,
                vb
            );
        }
    }

    /// Informational: only compiled when the `stdsimd` feature is enabled.
    #[cfg(feature = "stdsimd")]
    #[test]
    fn test_feature_simd() {
        init_log();
        log::info!("I have activated stdsimd");
    }

    /// Informational: only compiled when the `simdeez_f` feature is enabled.
    #[test]
    #[cfg(feature = "simdeez_f")]
    fn test_feature_simdeez() {
        init_log();
        log::info!("I have activated simdeez");
    }
}