use anyhow::Result;
use oxirs_vec::{LshConfig, LshFamily, LshIndex, Vector, VectorIndex};
use std::time::Instant;
fn main() -> Result<()> {
println!("Locality Sensitive Hashing (LSH) Example");
println!("========================================\n");
println!("1. Random Projection LSH (Cosine Similarity)");
random_projection_example()?;
println!();
println!("2. MinHash LSH (Jaccard Similarity)");
minhash_example()?;
println!();
println!("3. Multi-probe LSH");
multiprobe_example()?;
println!();
println!("4. Performance Comparison");
performance_comparison()?;
Ok(())
}
/// Demonstrates random-projection LSH (the cosine-similarity family).
///
/// Indexes four 5-component "document" vectors, runs a top-2 query and prints
/// the hits with their distances, followed by the index statistics.
fn random_projection_example() -> Result<()> {
    let mut index = LshIndex::new(LshConfig {
        num_tables: 10,
        num_hash_functions: 8,
        lsh_family: LshFamily::RandomProjection,
        seed: 42,
        multi_probe: false,
        num_probes: 0,
    });

    let corpus = [
        ("doc1", vec![0.8, 0.2, 0.1, 0.0, 0.3]),
        ("doc2", vec![0.1, 0.9, 0.2, 0.1, 0.0]),
        ("doc3", vec![0.0, 0.1, 0.8, 0.9, 0.2]),
        ("doc4", vec![0.7, 0.3, 0.2, 0.1, 0.4]),
    ];
    // The corpus is not reused, so each component vector can be moved in.
    for (name, components) in corpus {
        index.insert(name.to_string(), Vector::new(components))?;
    }

    let query = Vector::new(vec![0.9, 0.1, 0.2, 0.0, 0.2]);
    let hits = index.search_knn(&query, 2)?;
    println!(" Query vector: [0.9, 0.1, 0.2, 0.0, 0.2]");
    println!(" Top 2 results:");
    hits.into_iter()
        .for_each(|(uri, distance)| println!(" {uri} (distance: {distance:.4})"));

    let stats = index.stats();
    println!(
        " Index stats: {} vectors, {} tables, avg bucket size: {:.2}",
        stats.num_vectors, stats.num_tables, stats.avg_bucket_size
    );
    Ok(())
}
/// Demonstrates MinHash LSH (the Jaccard-similarity family).
///
/// Each document is a 100-slot binary vector with `1.0` at its "term"
/// positions. `doc1` is also used as the query, and every hit is printed with
/// `1.0 - distance` labelled as its Jaccard similarity.
fn minhash_example() -> Result<()> {
    // Builds a 100-slot vector with 1.0 at each listed term position.
    let sparse_doc = |terms: &[usize]| -> Vec<f32> {
        let mut slots = vec![0.0; 100];
        for &term in terms {
            slots[term] = 1.0;
        }
        slots
    };

    let mut index = LshIndex::new(LshConfig {
        num_tables: 5,
        num_hash_functions: 64,
        lsh_family: LshFamily::MinHash,
        seed: 42,
        multi_probe: false,
        num_probes: 0,
    });

    let doc1 = sparse_doc(&[5, 10, 15, 20]);
    let doc2 = sparse_doc(&[5, 10, 25, 30]);
    let doc3 = sparse_doc(&[50, 55, 60, 65]);

    index.insert("doc1".to_string(), Vector::new(doc1.clone()))?;
    index.insert("doc2".to_string(), Vector::new(doc2))?;
    index.insert("doc3".to_string(), Vector::new(doc3))?;

    // Querying with doc1 itself: it should come back as its own best match.
    let results = index.search_knn(&Vector::new(doc1), 3)?;
    println!(" Query: Document with terms at positions [5, 10, 15, 20]");
    println!(" Results ordered by Jaccard similarity:");
    for (uri, distance) in results {
        let similarity = 1.0 - distance;
        println!(" {uri} (Jaccard similarity: {similarity:.4})");
    }
    Ok(())
}
/// Compares standard LSH with multi-probe LSH on 20 points on the unit circle.
///
/// Both indices share every setting except the multi-probe flag and probe
/// count. The number of neighbours each returns for the query `[1.0, 0.0]` is
/// printed.
fn multiprobe_example() -> Result<()> {
    // Only the probing settings differ between the two indices.
    let config_for = |multi_probe, num_probes| LshConfig {
        num_tables: 3,
        num_hash_functions: 4,
        lsh_family: LshFamily::RandomProjection,
        seed: 42,
        multi_probe,
        num_probes,
    };
    let mut index_standard = LshIndex::new(config_for(false, 0));
    let mut index_multiprobe = LshIndex::new(config_for(true, 3));

    let num_points = 20;
    for i in 0..num_points {
        // Evenly spaced angles around the full circle.
        let theta = i as f32 * 2.0 * std::f32::consts::PI / num_points as f32;
        let point = Vector::new(vec![theta.cos(), theta.sin()]);
        let name = format!("point_{i}");
        index_standard.insert(name.clone(), point.clone())?;
        index_multiprobe.insert(name, point)?;
    }

    let query = Vector::new(vec![1.0, 0.0]);
    let standard_hits = index_standard.search_knn(&query, 5)?.len();
    let multiprobe_hits = index_multiprobe.search_knn(&query, 5)?.len();

    println!(" Query point: [1.0, 0.0]");
    println!(" Standard LSH found {standard_hits} neighbors");
    println!(" Multi-probe LSH found {multiprobe_hits} neighbors");
    Ok(())
}
/// Benchmarks LSH against an exact brute-force index on synthetic data.
///
/// Both indices receive the same `num_vectors` deterministic vectors with
/// `dimensions` components each. Each of `num_queries` deterministic queries
/// is answered by both indices. The brute-force answer is the ground truth
/// for LSH recall@k: the fraction of the exact top-k that LSH also returned.
///
/// The following are printed:
/// - per-index indexing time
/// - average query latency
/// - speedup
/// - average recall
/// - LSH index statistics
///
/// Ratios that cannot be computed (zero timings, empty ground truth) are
/// reported as `n/a` instead of `NaN`/`inf`.
///
/// # Errors
///
/// Propagates any error returned by index insertion or search.
fn performance_comparison() -> Result<()> {
    use oxirs_vec::MemoryVectorIndex;
    use std::collections::HashSet;
    use std::time::Duration;

    let dimensions: usize = 128;
    let num_vectors: usize = 5000;
    let num_queries: usize = 100;
    let k = 10;

    let lsh_config = LshConfig {
        num_tables: 10,
        num_hash_functions: 8,
        lsh_family: LshFamily::RandomProjection,
        seed: 42,
        multi_probe: true,
        num_probes: 2,
    };
    let mut lsh_index = LshIndex::new(lsh_config);
    let mut brute_force = MemoryVectorIndex::new();

    println!(" Indexing {num_vectors} {dimensions}-dimensional vectors...");
    // Time each index on its own so the LSH indexing cost is not conflated with
    // the cost of building the brute-force baseline.
    let mut lsh_index_time = Duration::ZERO;
    let mut brute_index_time = Duration::ZERO;
    for i in 0..num_vectors {
        // Deterministic pseudo-data with components in [-0.5, 0.5).
        let values: Vec<f32> = (0..dimensions)
            .map(|j| ((i * j + i) as f32 % 100.0) / 100.0 - 0.5)
            .collect();
        let vector = Vector::new(values);
        let uri = format!("vec_{i}");
        // Clone before starting the clock so the copies are not billed to LSH.
        let (lsh_uri, lsh_vector) = (uri.clone(), vector.clone());

        let started = Instant::now();
        lsh_index.insert(lsh_uri, lsh_vector)?;
        lsh_index_time += started.elapsed();

        let started = Instant::now();
        brute_force.insert(uri, vector)?;
        brute_index_time += started.elapsed();
    }
    let indexing_time = lsh_index_time + brute_index_time;
    println!(
        " Indexing completed in {indexing_time:?} (LSH: {lsh_index_time:?}, brute force: {brute_index_time:?})"
    );

    println!(" Running {num_queries} queries...");
    let mut lsh_total = Duration::ZERO;
    let mut brute_total = Duration::ZERO;
    let mut recall_sum = 0.0f32;
    let mut scored_queries = 0usize;
    for q in 0..num_queries {
        let query_values: Vec<f32> = (0..dimensions)
            .map(|j| ((q * j * 7) as f32 % 100.0) / 100.0 - 0.5)
            .collect();
        let query = Vector::new(query_values);

        let started = Instant::now();
        let lsh_results = lsh_index.search_knn(&query, k)?;
        lsh_total += started.elapsed();

        let started = Instant::now();
        let brute_results = brute_force.search_knn(&query, k)?;
        brute_total += started.elapsed();

        // Recall is undefined without ground truth; skip rather than produce NaN.
        let truth: HashSet<_> = brute_results.iter().map(|(uri, _)| uri).collect();
        if truth.is_empty() {
            continue;
        }
        let found: HashSet<_> = lsh_results.iter().map(|(uri, _)| uri).collect();
        let hits = found.intersection(&truth).count();
        recall_sum += hits as f32 / truth.len() as f32;
        scored_queries += 1;
    }

    // Guard against a zero query count, which would make `Duration / 0` panic.
    let divisor = num_queries.max(1) as u32;
    let avg_lsh_time = lsh_total / divisor;
    let avg_brute_time = brute_total / divisor;
    let lsh_secs = avg_lsh_time.as_secs_f64();
    let speedup = (lsh_secs > 0.0).then(|| avg_brute_time.as_secs_f64() / lsh_secs);
    let avg_recall = (scored_queries > 0).then(|| recall_sum / scored_queries as f32);

    println!("\n Results:");
    println!(" LSH average query time: {avg_lsh_time:?}");
    println!(" Brute force average query time: {avg_brute_time:?}");
    match speedup {
        Some(s) => println!(" Speedup: {s:.2}x"),
        None => println!(" Speedup: n/a (LSH query time below timer resolution)"),
    }
    match avg_recall {
        Some(r) => println!(" Average recall@{k}: {:.2}%", r * 100.0),
        None => println!(" Average recall@{k}: n/a (no ground-truth neighbours)"),
    }

    let lsh_stats = lsh_index.stats();
    println!("\n LSH Index Statistics:");
    println!(" Number of tables: {}", lsh_stats.num_tables);
    println!(" Average bucket size: {:.2}", lsh_stats.avg_bucket_size);
    println!(" Memory usage: {} KB", lsh_stats.memory_usage / 1024);
    Ok(())
}