use std::collections::HashMap;
/// Parameters that size a [`SimpleLSH`] index.
#[derive(Clone, Debug)]
pub struct SimpleLSHConfig {
    /// Hyperplanes (and therefore hash bits) per table. Only the first 64
    /// can be represented in the `u64` hash.
    pub num_bits: usize,
    /// Number of independent hash tables.
    pub num_tables: usize,
    /// Number of components generated for every hyperplane.
    pub dimensions: usize,
}

impl Default for SimpleLSHConfig {
    /// Returns 6 bits per table, 10 tables and 768 dimensions.
    fn default() -> Self {
        Self {
            num_bits: 6,
            num_tables: 10,
            dimensions: 768,
        }
    }
}

/// Sign-of-projection locality-sensitive hash index.
///
/// Every table owns `num_bits` hyperplanes; a vector's hash in a table has
/// bit `i` set when its dot product with hyperplane `i` is strictly positive.
/// Ids are stored per table in buckets keyed by that hash.
pub struct SimpleLSH {
    /// Indexed as `hyperplanes[table][bit][dimension]`.
    hyperplanes: Vec<Vec<Vec<f32>>>,
    /// One `hash -> ids` map per table.
    buckets: Vec<HashMap<u64, Vec<u32>>>,
    config: SimpleLSHConfig,
}

impl SimpleLSH {
    /// Widest hash a table can produce: one bit per hyperplane in a `u64`.
    const MAX_HASH_BITS: usize = u64::BITS as usize;
    /// Size of the serialized header: three little-endian `u32` counts.
    const HEADER_LEN: usize = 12;
    /// Size of one serialized hyperplane component.
    const F32_LEN: usize = std::mem::size_of::<f32>();

    /// Builds an empty index whose hyperplanes are derived from hashing
    /// `(table, bit, dimension)`, so no random number generator is involved.
    ///
    /// NOTE(review): the values come from `DefaultHasher`, whose algorithm the
    /// standard library does not promise to keep stable across releases;
    /// persist the planes with [`SimpleLSH::to_bytes`] if they must survive a
    /// toolchain change.
    pub fn new(config: SimpleLSHConfig) -> Self {
        let hyperplanes = (0..config.num_tables)
            .map(|table| {
                (0..config.num_bits)
                    .map(|bit| Self::make_hyperplane(table, bit, config.dimensions))
                    .collect()
            })
            .collect();
        let buckets = (0..config.num_tables).map(|_| HashMap::new()).collect();
        Self {
            hyperplanes,
            buckets,
            config,
        }
    }

    /// Generates one hyperplane: components mapped from a 64-bit hash into
    /// `[-1, 1]`, then scaled to unit length unless the norm is ~0.
    fn make_hyperplane(table: usize, bit: usize, dimensions: usize) -> Vec<f32> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut plane: Vec<f32> = (0..dimensions)
            .map(|dim| {
                let mut hasher = DefaultHasher::new();
                (table, bit, dim, "hyperplane").hash(&mut hasher);
                let unit = hasher.finish() as f64 / u64::MAX as f64;
                (unit * 2.0 - 1.0) as f32
            })
            .collect();

        let norm: f32 = plane.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for x in plane.iter_mut() {
                *x /= norm;
            }
        }
        plane
    }

    /// Yields the exact bucket key followed by the keys obtained by flipping
    /// bit 0, bit 1, ... of `base_hash`, for at most `num_probes` flips.
    ///
    /// The flip count is capped at both `num_bits` and 64 so the shift can
    /// never overflow the `u64` hash.
    fn probe_hashes(&self, base_hash: u64, num_probes: usize) -> impl Iterator<Item = u64> {
        let flips = num_probes
            .min(self.config.num_bits)
            .min(Self::MAX_HASH_BITS);
        std::iter::once(base_hash).chain((0..flips).map(move |bit| base_hash ^ (1u64 << bit)))
    }

    /// Reads a little-endian `u32` at `offset`, or `None` if out of range.
    fn read_u32_le(bytes: &[u8], offset: usize) -> Option<u32> {
        let raw = bytes.get(offset..offset + 4)?;
        Some(u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]))
    }

    /// Hashes `vector` in `table`: bit `i` is set when the dot product with
    /// hyperplane `i` is strictly positive.
    ///
    /// Components are paired positionally; if `vector` and a hyperplane differ
    /// in length the extra components of the longer one are ignored. When the
    /// table has more than 64 hyperplanes only the first 64 contribute, since
    /// that is all a `u64` can hold.
    ///
    /// # Panics
    ///
    /// Panics if `table` is not a valid table index.
    #[inline]
    pub fn hash(&self, vector: &[f32], table: usize) -> u64 {
        self.hyperplanes[table]
            .iter()
            .take(Self::MAX_HASH_BITS)
            .enumerate()
            .fold(0u64, |acc, (bit, plane)| {
                let dot: f32 = vector.iter().zip(plane.iter()).map(|(v, p)| v * p).sum();
                if dot > 0.0 {
                    acc | (1u64 << bit)
                } else {
                    acc
                }
            })
    }

    /// Returns the hash of `vector` in every table, in table order.
    pub fn hash_all(&self, vector: &[f32]) -> Vec<u64> {
        (0..self.config.num_tables)
            .map(|table| self.hash(vector, table))
            .collect()
    }

    /// Appends `id` to the bucket `vector` hashes to in every table.
    ///
    /// No de-duplication is performed: inserting the same id twice stores it
    /// twice.
    pub fn insert(&mut self, id: u32, vector: &[f32]) {
        for table in 0..self.config.num_tables {
            let key = self.hash(vector, table);
            self.buckets[table].entry(key).or_default().push(id);
        }
    }

    /// Returns the ids stored in the bucket `vector` hashes to in `table`,
    /// or an empty slice when that bucket does not exist.
    ///
    /// # Panics
    ///
    /// Panics if `table` is not a valid table index.
    pub fn query_table(&self, vector: &[f32], table: usize) -> &[u32] {
        let key = self.hash(vector, table);
        match self.buckets[table].get(&key) {
            Some(ids) => ids.as_slice(),
            None => &[],
        }
    }

    /// Collects candidate ids from every table, looking in the exact bucket
    /// and in up to `num_probes` neighbouring buckets that differ by a single
    /// low-order bit.
    ///
    /// Each id appears once, in first-seen order.
    pub fn query_multiprobe(&self, vector: &[f32], num_probes: usize) -> Vec<u32> {
        let mut candidates = Vec::new();
        let mut seen = std::collections::HashSet::new();
        for (table, table_buckets) in self.buckets.iter().enumerate() {
            let base_hash = self.hash(vector, table);
            for probe_hash in self.probe_hashes(base_hash, num_probes) {
                if let Some(ids) = table_buckets.get(&probe_hash) {
                    // `insert` returns false for ids already emitted.
                    candidates.extend(ids.iter().copied().filter(|&id| seen.insert(id)));
                }
            }
        }
        candidates
    }

    /// Returns the bucket keys [`SimpleLSH::query_multiprobe`] would visit for
    /// `vector` in `table`: the exact hash first, then one key per flipped bit.
    ///
    /// # Panics
    ///
    /// Panics if `table` is not a valid table index.
    pub fn get_probe_sequence(&self, vector: &[f32], table: usize, num_probes: usize) -> Vec<u64> {
        let base_hash = self.hash(vector, table);
        self.probe_hashes(base_hash, num_probes).collect()
    }

    /// Number of hash tables in the index.
    pub fn num_tables(&self) -> usize {
        self.config.num_tables
    }

    /// Total number of non-empty buckets across all tables.
    pub fn bucket_count(&self) -> usize {
        self.buckets.iter().map(|table| table.len()).sum()
    }

    /// Serializes the configuration and hyperplanes (not the buckets).
    ///
    /// Layout: `num_bits`, `num_tables`, `dimensions` as little-endian `u32`,
    /// followed by every hyperplane component as little-endian `f32` in
    /// table, bit, dimension order.
    pub fn to_bytes(&self) -> Vec<u8> {
        let value_count: usize = self.hyperplanes.iter().flatten().map(Vec::len).sum();
        let mut bytes = Vec::with_capacity(Self::HEADER_LEN + value_count * Self::F32_LEN);

        // The wire format stores counts as u32; larger counts would truncate.
        bytes.extend_from_slice(&(self.config.num_bits as u32).to_le_bytes());
        bytes.extend_from_slice(&(self.config.num_tables as u32).to_le_bytes());
        bytes.extend_from_slice(&(self.config.dimensions as u32).to_le_bytes());

        for value in self.hyperplanes.iter().flatten().flatten() {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        bytes
    }

    /// Rebuilds an index from the output of [`SimpleLSH::to_bytes`].
    ///
    /// The restored index has the original hyperplanes and empty buckets.
    /// Returns `None` when the header is incomplete or the payload is shorter
    /// than the header's counts require; trailing bytes are ignored.
    ///
    /// The payload length is validated with overflow-checked arithmetic
    /// *before* anything is allocated, so a malformed header cannot request an
    /// allocation larger than the input justifies. A header with zero bits or
    /// zero dimensions carries no payload, so in that degenerate case the
    /// table count is bounded only by the `u32` header field.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let num_bits = Self::read_u32_le(bytes, 0)? as usize;
        let num_tables = Self::read_u32_le(bytes, 4)? as usize;
        let dimensions = Self::read_u32_le(bytes, 8)? as usize;

        let payload = bytes.get(Self::HEADER_LEN..)?;
        let required = num_tables
            .checked_mul(num_bits)?
            .checked_mul(dimensions)?
            .checked_mul(Self::F32_LEN)?;
        if payload.len() < required {
            return None;
        }

        // The length check above guarantees every `take(dimensions)` below is
        // fully satisfied.
        let mut values = payload
            .chunks_exact(Self::F32_LEN)
            .map(|raw| f32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]));
        let hyperplanes: Vec<Vec<Vec<f32>>> = (0..num_tables)
            .map(|_| {
                (0..num_bits)
                    .map(|_| values.by_ref().take(dimensions).collect())
                    .collect()
            })
            .collect();

        let buckets = (0..num_tables).map(|_| HashMap::new()).collect();
        Some(Self {
            hyperplanes,
            buckets,
            config: SimpleLSHConfig {
                num_bits,
                num_tables,
                dimensions,
            },
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two nearly parallel vectors should share a bucket in at least one
    /// table, and hashing must be repeatable and stay within `num_bits` bits.
    #[test]
    fn test_simple_lsh_hash() {
        let config = SimpleLSHConfig {
            num_bits: 6,
            num_tables: 4,
            dimensions: 64,
        };
        let lsh = SimpleLSH::new(config);
        let vec1: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();
        let vec2: Vec<f32> = (0..64).map(|i| i as f32 * 0.01 + 0.001).collect();
        let h1 = lsh.hash_all(&vec1);
        let h2 = lsh.hash_all(&vec2);

        assert_eq!(h1.len(), 4, "one hash per table");
        assert_eq!(h1, lsh.hash_all(&vec1), "hashing must be deterministic");
        assert!(
            h1.iter().all(|&h| h < (1u64 << 6)),
            "hash must fit in num_bits bits"
        );

        let matches = h1.iter().zip(h2.iter()).filter(|(a, b)| a == b).count();
        assert!(
            matches > 0,
            "Similar vectors should have some matching hashes"
        );
    }

    /// The query equals the vector inserted under id 0, so id 0 must come
    /// back, and the candidate list must not contain duplicates.
    #[test]
    fn test_simple_lsh_insert_query() {
        let config = SimpleLSHConfig {
            num_bits: 4,
            num_tables: 2,
            dimensions: 32,
        };
        let mut lsh = SimpleLSH::new(config);
        for i in 0..10 {
            let vec: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32 * 0.01).collect();
            lsh.insert(i, &vec);
        }
        let query: Vec<f32> = (0..32).map(|j| j as f32 * 0.01).collect();
        let candidates = lsh.query_multiprobe(&query, 4);
        assert!(!candidates.is_empty());
        assert!(
            candidates.contains(&0),
            "an exact match must be among the candidates"
        );
        assert!(lsh.query_table(&query, 0).contains(&0));

        let mut unique = candidates.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(
            unique.len(),
            candidates.len(),
            "candidates must be de-duplicated"
        );
    }

    /// A serialized index must hash identically after restoring, and
    /// incomplete input must be rejected rather than accepted.
    #[test]
    fn test_simple_lsh_serialization() {
        let config = SimpleLSHConfig {
            num_bits: 4,
            num_tables: 2,
            dimensions: 16,
        };
        let lsh = SimpleLSH::new(config);
        let bytes = lsh.to_bytes();
        assert_eq!(bytes.len(), 12 + 4 * 2 * 16 * 4);

        let restored = SimpleLSH::from_bytes(&bytes).unwrap();
        let vec: Vec<f32> = (0..16).map(|i| i as f32).collect();
        assert_eq!(lsh.hash_all(&vec), restored.hash_all(&vec));
        assert_eq!(restored.num_tables(), 2);
        assert_eq!(restored.bucket_count(), 0);

        assert!(SimpleLSH::from_bytes(&[]).is_none());
        assert!(SimpleLSH::from_bytes(&bytes[..bytes.len() - 1]).is_none());
    }
}