use ndarray::Array2;
use ::hdf5::*;
use log::debug;
/// In-memory copy of the datasets read from one benchmark HDF5 file.
///
/// Instances are built by `AnnBenchmarkData::new`, which reads the
/// `distances`, `neighbors`, `test` and `train` datasets of the file.
pub struct AnnBenchmarkData {
/// Path of the HDF5 file, as given to `new`.
pub fname: String,
/// Content of the 2-D `f32` dataset named `distances`.
/// NOTE(review): presumably one row of reference distances per test vector — confirm
/// against the benchmark file format.
pub test_distances: Array2<f32>,
/// Content of the 2-D `i32` dataset named `neighbors`.
/// NOTE(review): presumably one row of reference neighbour ids per test vector — confirm
/// against the benchmark file format.
#[allow(unused)]
pub test_neighbours: Array2<i32>,
/// Rows of the 2-D `f32` dataset named `test`, one `Vec<f32>` per row.
pub test_data: Vec<Vec<f32>>,
/// Rows of the 2-D `f32` dataset named `train`, each paired with its row index
/// in that dataset.
pub train_data: Vec<(Vec<f32>, usize)>,
/// Created empty by `new`, with capacity reserved for one entry per test vector.
/// Nothing in this file fills it; presumably callers store their search results here.
#[allow(unused)]
pub searched_neighbours: Vec<Vec<i32>>,
/// Created empty by `new`, with capacity reserved for one entry per test vector.
/// Nothing in this file fills it; presumably callers store their search results here.
#[allow(unused)]
pub searched_distances: Vec<Vec<f32>>,
}
impl AnnBenchmarkData {
pub fn new(fname: String) -> Result<AnnBenchmarkData> {
let res = hdf5::File::open(fname.clone());
if res.is_err() {
println!("you must download file {:?}", fname);
panic!(
"download benchmark file some where and modify examples source file accordingly"
);
}
let file = res.ok().unwrap();
let res_distances = file.dataset("distances");
if res_distances.is_err() {
panic!("error getting distances dataset");
}
let distances = res_distances.unwrap();
let shape = distances.shape();
assert_eq!(shape.len(), 2);
let dataf32 = distances.dtype().unwrap().is::<f32>();
if !dataf32 {
panic!("error getting type distances dataset");
}
let res = distances.read_2d::<f32>();
if res.is_err() {
panic!("error reading distances dataset");
}
let test_distances = res.unwrap();
debug!(
"First 2 distances for first test {:?} {:?} ",
test_distances.get((0, 0)).unwrap(),
test_distances.get((0, 1)).unwrap()
);
let res_neighbours = file.dataset("neighbors");
if res_neighbours.is_err() {
panic!("error getting neighbours");
}
let neighbours = res_neighbours.unwrap();
let shape = neighbours.shape();
assert_eq!(shape.len(), 2);
println!("neighbours shape : {:?}", shape);
let datai32 = neighbours.dtype().unwrap().is::<i32>();
if !datai32 {
panic!("error getting type neighbours");
}
let res = neighbours.read_2d::<i32>();
if res.is_err() {
panic!("error reading neighbours dataset");
}
let test_neighbours = res.unwrap();
debug!(
"First 2 neighbours for first test {:?} {:?} ",
test_neighbours.get((0, 0)).unwrap(),
test_neighbours.get((0, 1)).unwrap()
);
println!("\n 10 first neighbours for first vector : ");
for i in 0..10 {
print!(" {:?} ", test_neighbours.get((0, i)).unwrap());
}
println!("\n 10 first neighbours for second vector : ");
for i in 0..10 {
print!(" {:?} ", test_neighbours.get((1, i)).unwrap());
}
let res_testdata = file.dataset("test");
if res_testdata.is_err() {
panic!("error getting test de notataset");
}
let test_data = res_testdata.unwrap();
let shape = test_data.shape(); assert_eq!(shape.len(), 2);
let dataf32 = test_data.dtype().unwrap().is::<f32>();
if !dataf32 {
panic!("error getting type de notistances dataset");
}
let res = test_data.read_2d::<f32>();
if res.is_err() {
panic!("error reading distances dataset");
}
let test_data_2d = res.unwrap();
let mut test_data = Vec::<Vec<f32>>::with_capacity(shape[1]);
let (nbrow, nbcolumn) = test_data_2d.dim();
println!(" test data, nb element {:?}, dim : {:?}", nbrow, nbcolumn);
for i in 0..nbrow {
let mut vec = Vec::with_capacity(nbcolumn);
for j in 0..nbcolumn {
vec.push(*test_data_2d.get((i, j)).unwrap());
}
test_data.push(vec);
}
let res_traindata = file.dataset("train");
if res_traindata.is_err() {
panic!("error getting distances dataset");
}
let train_data = res_traindata.unwrap();
let train_shape = train_data.shape();
assert_eq!(shape.len(), 2);
if test_data_2d.dim().1 != train_shape[1] {
println!("test and train have not the same dimension");
panic!();
}
println!(
"\n train data shape : {:?}, nbvector {:?} ",
train_shape, train_shape[0]
);
let dataf32 = train_data.dtype().unwrap().is::<f32>();
if !dataf32 {
panic!("error getting type distances dataset");
}
let res = train_data.read_2d::<f32>();
if res.is_err() {
panic!("error reading distances dataset");
}
let train_data_2d = res.unwrap();
let mut train_data = Vec::<(Vec<f32>, usize)>::with_capacity(shape[1]);
let (nbrow, nbcolumn) = train_data_2d.dim();
for i in 0..nbrow {
let mut vec = Vec::with_capacity(nbcolumn);
for j in 0..nbcolumn {
vec.push(*train_data_2d.get((i, j)).unwrap());
}
train_data.push((vec, i));
}
println!(
" allocating vector for search neighbours answer : {:?}",
test_data.len()
);
let searched_neighbours = Vec::<Vec<i32>>::with_capacity(test_data.len());
let searched_distances = Vec::<Vec<f32>>::with_capacity(test_data.len());
Ok(AnnBenchmarkData {
fname: fname.clone(),
test_distances,
test_neighbours,
test_data,
train_data,
searched_neighbours,
searched_distances,
})
}
#[allow(unused)]
pub fn do_l2_normalization(&mut self) {
for i in 0..self.test_data.len() {
anndists::dist::l2_normalize(&mut self.test_data[i]);
}
for i in 0..self.train_data.len() {
anndists::dist::l2_normalize(&mut self.train_data[i].0);
}
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a benchmark file from a fixed local path and checks that
    /// `AnnBenchmarkData::new` succeeds on it.
    ///
    /// The test is skipped (it returns early, reporting on stdout) when the
    /// file is not present, so it does not fail on machines without the data.
    #[test]
    fn test_load_hdf5() {
        // `try_init` instead of `init`: a logger may already have been installed by
        // another test of the same binary, in which case `init` would panic.
        let _ = env_logger::Builder::from_default_env().try_init();
        let fname = String::from("/home.2/Data/ANN/glove-25-angular.hdf5");
        println!("\n\n test_load_hdf5 {:?}", fname);
        if !std::path::Path::new(&fname).exists() {
            println!(" file {:?} not found, skipping test_load_hdf5", fname);
            return;
        }
        let _anndata = AnnBenchmarkData::new(fname).unwrap();
    }
}