use std::collections::HashMap;
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand_distr::{Distribution, Normal, StandardNormal};
use serde::{Deserialize, Serialize};
use mnemonist_core::distance::normalize;
/// A retrieval benchmark: a document corpus together with the queries
/// that are evaluated against it.
#[derive(Debug, Clone)]
pub struct BenchmarkDataset {
/// Documents that make up the corpus being searched.
pub documents: Vec<Document>,
/// Queries, each carrying its own relevance judgments.
pub queries: Vec<BenchmarkQuery>,
}
/// A single corpus entry: an identifier, its embedding vector and the
/// group it belongs to.
#[derive(Debug, Clone)]
pub struct Document {
/// Document identifier; `generate_synthetic` formats it as
/// `doc_{cluster}_{index}`.
pub id: String,
/// Embedding vector; `generate_synthetic` passes it through `normalize`
/// before storing it.
pub embedding: Vec<f32>,
/// Group label; `generate_synthetic` stores the index of the cluster the
/// document was sampled from.
pub group: usize,
}
/// A benchmark query together with its relevance judgments.
#[derive(Debug, Clone)]
pub struct BenchmarkQuery {
/// Query identifier; `generate_synthetic` formats it as
/// `query_{cluster}_{index}`.
pub id: String,
/// Embedding vector of the query, with the same length as the document
/// embeddings when produced by `generate_synthetic`.
pub embedding: Vec<f32>,
/// Relevance grade per document id. `generate_synthetic` records grade 2
/// for every document of the query's own cluster and leaves all other
/// documents out of the map.
pub judgments: HashMap<String, u32>,
}
/// Parameters controlling `generate_synthetic`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyntheticConfig {
/// Requested corpus size; the generator divides it among the clusters.
pub n_documents: usize,
/// Number of cluster centroids to draw.
pub n_clusters: usize,
/// Requested number of queries; the generator divides it among the clusters.
pub n_queries: usize,
/// Length of every generated embedding vector.
pub dimension: usize,
/// Standard deviation of the zero-mean Gaussian noise added to each
/// centroid component when sampling documents; queries use half of it.
pub noise: f32,
/// Seed for the `StdRng` that drives all sampling.
pub seed: u64,
}
impl Default for SyntheticConfig {
fn default() -> Self {
Self {
n_documents: 500,
n_clusters: 10,
n_queries: 50,
dimension: 32,
noise: 0.3,
seed: 42,
}
}
}
pub fn generate_synthetic(config: &SyntheticConfig) -> BenchmarkDataset {
let mut rng = StdRng::seed_from_u64(config.seed);
let normal = Normal::new(0.0f32, config.noise).unwrap();
let centroids: Vec<Vec<f32>> = (0..config.n_clusters)
.map(|_| {
let mut v: Vec<f32> = (0..config.dimension)
.map(|_| StandardNormal.sample(&mut rng))
.collect();
normalize(&mut v);
v
})
.collect();
let docs_per_cluster = config.n_documents / config.n_clusters.max(1);
let mut documents = Vec::with_capacity(config.n_documents);
for (cluster_id, centroid) in centroids.iter().enumerate() {
for j in 0..docs_per_cluster {
let mut emb: Vec<f32> = centroid
.iter()
.map(|&c| c + normal.sample(&mut rng))
.collect();
normalize(&mut emb);
documents.push(Document {
id: format!("doc_{cluster_id}_{j}"),
embedding: emb,
group: cluster_id,
});
}
}
let queries_per_cluster = config.n_queries / config.n_clusters.max(1);
let query_noise = Normal::new(0.0f32, config.noise * 0.5).unwrap();
let mut queries = Vec::with_capacity(config.n_queries);
for (cluster_id, centroid) in centroids.iter().enumerate() {
for j in 0..queries_per_cluster {
let mut emb: Vec<f32> = centroid
.iter()
.map(|&c| c + query_noise.sample(&mut rng))
.collect();
normalize(&mut emb);
let mut judgments = HashMap::new();
for doc in &documents {
if doc.group == cluster_id {
judgments.insert(doc.id.clone(), 2);
}
}
queries.push(BenchmarkQuery {
id: format!("query_{cluster_id}_{j}"),
embedding: emb,
judgments,
});
}
}
BenchmarkDataset { documents, queries }
}
/// Orders two similarity scores from highest to lowest, placing NaN after
/// every comparable value so the comparison never fails.
fn descending_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
        // Both operands are comparable here; the fallback is never taken.
        (false, false) => b.partial_cmp(&a).unwrap_or(std::cmp::Ordering::Equal),
    }
}

/// Returns the indices of the `k` documents most similar to `query`.
///
/// Every document is scored with `cosine_similarity` against `query`, and the
/// indices are returned in order of decreasing score. Documents with equal
/// scores keep their original relative order. If `k` exceeds the number of
/// documents, all indices are returned.
///
/// A NaN score no longer causes a panic: such documents are ranked after all
/// documents with a comparable score.
pub fn brute_force_top_k(query: &[f32], documents: &[Document], k: usize) -> Vec<usize> {
    let mut scored: Vec<(usize, f32)> = documents
        .iter()
        .enumerate()
        .map(|(i, doc)| {
            (
                i,
                mnemonist_core::distance::cosine_similarity(query, &doc.embedding),
            )
        })
        .collect();
    // Stable sort, so ties stay in corpus order.
    scored.sort_by(|a, b| descending_nan_last(a.1, b.1));
    scored.into_iter().take(k).map(|(i, _)| i).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The generated dataset has the requested sizes, every query carries
    /// judgments, and every document embedding has unit length.
    #[test]
    fn synthetic_dataset_structure() {
        let ds = generate_synthetic(&SyntheticConfig {
            n_documents: 100,
            n_clusters: 5,
            n_queries: 10,
            dimension: 16,
            noise: 0.3,
            seed: 42,
        });

        assert_eq!(ds.documents.len(), 100);
        assert_eq!(ds.queries.len(), 10);
        assert!(ds.queries.iter().all(|q| !q.judgments.is_empty()));

        for doc in &ds.documents {
            let squared_sum: f32 = doc.embedding.iter().map(|x| x * x).sum();
            assert!((squared_sum.sqrt() - 1.0).abs() < 1e-4);
        }
    }

    /// With low noise, most of the brute-force top 5 for a query come from
    /// the query's own cluster.
    #[test]
    fn brute_force_returns_nearest() {
        let dataset = generate_synthetic(&SyntheticConfig {
            n_documents: 50,
            n_clusters: 5,
            n_queries: 5,
            dimension: 8,
            noise: 0.1,
            seed: 99,
        });
        let query = &dataset.queries[0];

        let top = brute_force_top_k(&query.embedding, &dataset.documents, 5);
        assert_eq!(top.len(), 5);

        // Judged ids look like `doc_<cluster>_<n>`; any of them reveals the
        // query's cluster.
        let judged_id = query.judgments.keys().next().unwrap();
        let query_cluster: usize = judged_id.split('_').nth(1).unwrap().parse().unwrap();

        let mut same_cluster = 0;
        for &idx in &top {
            if dataset.documents[idx].group == query_cluster {
                same_cluster += 1;
            }
        }
        assert!(
            same_cluster >= 3,
            "expected >= 3 same-cluster, got {same_cluster}"
        );
    }

    /// Two runs with the same configuration agree on the corpus size and on
    /// the first document's embedding.
    #[test]
    fn deterministic_output() {
        let config = SyntheticConfig::default();
        let (first, second) = (generate_synthetic(&config), generate_synthetic(&config));

        assert_eq!(first.documents.len(), second.documents.len());
        assert_eq!(first.documents[0].embedding, second.documents[0].embedding);
    }
}