use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
/// A set of dense `f32` vectors split into a training part and a test part.
///
/// Within this module the training vectors serve as the searchable database
/// and the test vectors serve as the queries (see `compute_all_ground_truth`).
///
/// `PartialEq` allows two datasets to be compared directly (for example to
/// check that the same seed reproduces the same data) and `Default` yields an
/// empty dataset with `dimension == 0`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Dataset {
    /// Training vectors; used as the database when computing ground truth.
    pub train: Vec<Vec<f32>>,
    /// Test vectors; used as the queries when computing ground truth.
    pub test: Vec<Vec<f32>>,
    /// Number of components each vector is expected to have. The struct does
    /// not enforce this; it is simply recorded by the constructors.
    pub dimension: usize,
}
impl Dataset {
    /// Number of vectors in the training set.
    pub fn n_train(&self) -> usize {
        self.train.len()
    }

    /// Number of vectors in the test set.
    pub fn n_test(&self) -> usize {
        self.test.len()
    }

    /// Size in bytes of the raw `f32` payload, computed from the vector
    /// counts and the recorded `dimension`.
    ///
    /// The figure is derived from `dimension`, not from the actual length of
    /// each stored vector, and it does not include any container overhead.
    pub fn memory_bytes(&self) -> usize {
        let vector_count = self.n_train() + self.n_test();
        // Multiply in the order count -> dimension -> element size.
        let scalar_count = vector_count * self.dimension;
        scalar_count * std::mem::size_of::<f32>()
    }
}
/// Builds a dataset of random vectors from a seeded generator.
///
/// Every component is a fresh `rng.random::<f32>()` draw (per the `rand`
/// documentation this samples the unit interval). All training vectors are
/// generated first, then all test vectors, so a given `seed` always produces
/// the same dataset.
///
/// * `n_train` - number of training vectors to generate.
/// * `n_test` - number of test vectors to generate.
/// * `dimension` - number of components per vector.
/// * `seed` - seed for the `StdRng` that drives all sampling.
pub fn create_benchmark_dataset(
    n_train: usize,
    n_test: usize,
    dimension: usize,
    seed: u64,
) -> Dataset {
    let mut rng = StdRng::seed_from_u64(seed);

    // Fills `count` rows of `dimension` components, drawing from the shared
    // generator row by row.
    let mut uniform_rows = |count: usize| -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(count);
        for _ in 0..count {
            let mut row = Vec::with_capacity(dimension);
            for _ in 0..dimension {
                row.push(rng.random::<f32>());
            }
            rows.push(row);
        }
        rows
    };

    // Order matters for reproducibility: train consumes the generator first.
    let train = uniform_rows(n_train);
    let test = uniform_rows(n_test);

    Dataset {
        train,
        test,
        dimension,
    }
}
/// Builds a dataset whose vectors are scattered around random cluster centers.
///
/// `n_clusters` centers are drawn first, each component from
/// `rng.random::<f32>()`. Every generated vector then picks one center at
/// random and perturbs each component with a Box-Muller normal deviate scaled
/// by `cluster_std`; the result is clamped to `[0.0, 1.0]`. Training vectors
/// are generated before test vectors, so a given `seed` always produces the
/// same dataset.
///
/// * `n_train` - number of training vectors to generate.
/// * `n_test` - number of test vectors to generate.
/// * `dimension` - number of components per vector.
/// * `n_clusters` - number of cluster centers.
/// * `cluster_std` - scale applied to the normal deviate around a center.
/// * `seed` - seed for the `StdRng` that drives all sampling.
///
/// # Panics
///
/// Panics if `n_clusters` is zero while `n_train` or `n_test` is non-zero,
/// because there would be no center to sample around.
pub fn create_clustered_dataset(
    n_train: usize,
    n_test: usize,
    dimension: usize,
    n_clusters: usize,
    cluster_std: f32,
    seed: u64,
) -> Dataset {
    assert!(
        n_clusters > 0 || (n_train == 0 && n_test == 0),
        "n_clusters must be non-zero when any vectors are requested"
    );

    let mut rng = StdRng::seed_from_u64(seed);
    let centers: Vec<Vec<f32>> = (0..n_clusters)
        .map(|_| (0..dimension).map(|_| rng.random::<f32>()).collect())
        .collect();

    // Generates `count` vectors. Per vector the draw order is: cluster index,
    // then (u1, u2) for each component.
    let mut sample_points = |count: usize| -> Vec<Vec<f32>> {
        (0..count)
            .map(|_| {
                let center = &centers[rng.random_range(0..n_clusters)];
                center
                    .iter()
                    .map(|&c| {
                        // Floor u1 at the smallest positive normal f32: a draw
                        // of exactly 0.0 would make `ln` return -inf and turn
                        // the deviate into inf (or NaN once multiplied by a
                        // zero `cluster_std`). Any non-zero draw is unchanged.
                        let u1 = rng.random::<f32>().max(f32::MIN_POSITIVE);
                        let u2: f32 = rng.random();
                        // Box-Muller transform: one standard normal deviate.
                        let z = (-2.0 * u1.ln()).sqrt()
                            * (2.0 * std::f32::consts::PI * u2).cos();
                        (c + z * cluster_std).clamp(0.0, 1.0)
                    })
                    .collect()
            })
            .collect()
    };

    // Order matters for reproducibility: train consumes the generator first.
    let train = sample_points(n_train);
    let test = sample_points(n_test);

    Dataset {
        train,
        test,
        dimension,
    }
}
/// Returns the ids of the (up to) `k` database vectors closest to `query`.
///
/// Closeness is measured with `l2_distance_squared`. Ids are positions in
/// `database`, returned nearest first; vectors at equal distance are ordered
/// by ascending id. Fewer than `k` ids are returned when `database` holds
/// fewer than `k` vectors.
///
/// * `query` - the vector to search for.
/// * `database` - the vectors to search among.
/// * `k` - maximum number of neighbours to return.
pub fn compute_ground_truth(query: &[f32], database: &[Vec<f32>], k: usize) -> Vec<u32> {
    if k == 0 || database.is_empty() {
        return Vec::new();
    }

    let mut distances: Vec<(u32, f32)> = database
        .iter()
        .enumerate()
        // Ids are `u32` by contract; the cast truncates for databases with
        // more than `u32::MAX` vectors, which are not supported here.
        .map(|(i, vec)| (i as u32, l2_distance_squared(query, vec)))
        .collect();

    // Distance first, id second: a total order that matches what a stable
    // sort by distance alone produces, so unstable algorithms are safe.
    let nearest_first =
        |a: &(u32, f32), b: &(u32, f32)| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0));

    // Only the k best entries are needed, so partition them to the front in
    // linear time instead of sorting the whole database.
    if k < distances.len() {
        distances.select_nth_unstable_by(k - 1, nearest_first);
        distances.truncate(k);
    }
    distances.sort_unstable_by(nearest_first);

    distances.into_iter().map(|(id, _)| id).collect()
}
/// Computes the nearest-neighbour ground truth for every test vector.
///
/// Each vector in `dataset.test` is used as a query against `dataset.train`
/// via `compute_ground_truth`. The outer vector is in test-set order; entry
/// `i` holds the ids returned for `dataset.test[i]`.
///
/// * `dataset` - supplies both the queries (`test`) and the database (`train`).
/// * `k` - maximum number of neighbours to return per query.
pub fn compute_all_ground_truth(dataset: &Dataset, k: usize) -> Vec<Vec<u32>> {
    let database = dataset.train.as_slice();
    let mut neighbours_per_query = Vec::with_capacity(dataset.test.len());
    for query in &dataset.test {
        neighbours_per_query.push(compute_ground_truth(query, database, k));
    }
    neighbours_per_query
}
/// Squared Euclidean distance between `a` and `b`.
///
/// Components are paired positionally with `zip`, so when the slices differ
/// in length only the common prefix contributes to the result.
#[inline]
fn l2_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b).map(|(&lhs, &rhs)| {
        let delta = lhs - rhs;
        delta * delta
    });
    squared_diffs.sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The uniform generator honours the requested counts and dimension.
    #[test]
    fn test_create_benchmark_dataset() {
        let ds = create_benchmark_dataset(100, 10, 64, 42);

        assert_eq!(ds.n_train(), 100);
        assert_eq!(ds.n_test(), 10);
        assert_eq!(ds.dimension, 64);
        assert_eq!(ds.train[0].len(), 64);
    }

    /// The clustered generator honours the requested counts and keeps every
    /// training component inside the unit interval.
    #[test]
    fn test_create_clustered_dataset() {
        let ds = create_clustered_dataset(1000, 100, 64, 10, 0.1, 42);

        assert_eq!(ds.n_train(), 1000);
        assert_eq!(ds.n_test(), 100);
        assert!(ds
            .train
            .iter()
            .flatten()
            .all(|value| (0.0_f32..=1.0).contains(value)));
    }

    /// The nearest corner of the unit square comes first; the runner-up is
    /// one of the two equidistant neighbours.
    #[test]
    fn test_compute_ground_truth() {
        let database = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        let nearest = compute_ground_truth(&[0.1, 0.1], &database, 2);

        assert_eq!(nearest[0], 0);
        assert!(matches!(nearest[1], 1 | 2));
    }

    /// The memory footprint is vector count x dimension x size of `f32`.
    #[test]
    fn test_memory_bytes() {
        let ds = create_benchmark_dataset(100, 10, 64, 42);

        assert_eq!(
            ds.memory_bytes(),
            (100 + 10) * 64 * std::mem::size_of::<f32>()
        );
    }
}