use rand::prelude::*;
use rand_distr::StandardNormal;
/// Shape parameters for a [`HyperplaneLSH`] hasher.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LSHConfig {
    /// Number of hyperplanes (and therefore hash bits) per table. Must not exceed 64,
    /// because each table's hash is packed into a `u64`.
    pub num_bits: usize,
    /// Number of independent hash tables.
    pub num_tables: usize,
    /// Length of each hyperplane, i.e. the dimensionality of the hashed vectors.
    pub dimensions: usize,
}
impl Default for LSHConfig {
fn default() -> Self {
Self {
num_bits: 12,
num_tables: 8,
dimensions: 768,
}
}
}
/// Random-hyperplane locality-sensitive hasher.
///
/// Holds `num_tables * num_bits` hyperplanes of `dimensions` floats each in one flat
/// buffer: the plane for bit `b` of table `t` starts at index
/// `(t * num_bits + b) * dimensions`. The constructors scale each plane to unit length;
/// planes restored via `from_bytes` are used exactly as stored.
#[derive(Clone)]
pub struct HyperplaneLSH {
    config: LSHConfig,
    planes: Vec<f32>,
}
impl HyperplaneLSH {
pub fn new(config: LSHConfig) -> Self {
let mut rng = rand::thread_rng();
let total_planes = config.num_tables * config.num_bits;
let total_floats = total_planes * config.dimensions;
let mut planes = Vec::with_capacity(total_floats);
for _ in 0..total_floats {
planes.push(rng.sample(StandardNormal));
}
for plane_idx in 0..total_planes {
let start = plane_idx * config.dimensions;
let end = start + config.dimensions;
let plane = &mut planes[start..end];
let norm: f32 = plane.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 1e-10 {
for x in plane.iter_mut() {
*x /= norm;
}
}
}
Self { config, planes }
}
pub fn with_seed(config: LSHConfig, seed: u64) -> Self {
let mut rng = StdRng::seed_from_u64(seed);
let total_planes = config.num_tables * config.num_bits;
let total_floats = total_planes * config.dimensions;
let mut planes = Vec::with_capacity(total_floats);
for _ in 0..total_floats {
planes.push(rng.sample(StandardNormal));
}
for plane_idx in 0..total_planes {
let start = plane_idx * config.dimensions;
let end = start + config.dimensions;
let plane = &mut planes[start..end];
let norm: f32 = plane.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 1e-10 {
for x in plane.iter_mut() {
*x /= norm;
}
}
}
Self { config, planes }
}
pub fn hash(&self, vector: &[f32]) -> Vec<u64> {
(0..self.config.num_tables)
.map(|table| self.hash_single_table(table, vector))
.collect()
}
pub fn hash_single_table(&self, table: usize, vector: &[f32]) -> u64 {
let mut hash: u64 = 0;
for bit in 0..self.config.num_bits {
let plane_idx = table * self.config.num_bits + bit;
let plane_start = plane_idx * self.config.dimensions;
let plane = &self.planes[plane_start..plane_start + self.config.dimensions];
let dot: f32 = vector.iter().zip(plane.iter()).map(|(v, p)| v * p).sum();
if dot > 0.0 {
hash |= 1 << bit;
}
}
hash
}
pub fn get_probe_sequence(&self, vector: &[f32], table: usize, num_probes: usize) -> Vec<u64> {
let original_hash = self.hash_single_table(table, vector);
let mut probes = Vec::with_capacity(num_probes);
probes.push(original_hash);
if num_probes <= 1 {
return probes;
}
let mut bit_distances: Vec<(usize, f32)> = Vec::with_capacity(self.config.num_bits);
for bit in 0..self.config.num_bits {
let plane_idx = table * self.config.num_bits + bit;
let plane_start = plane_idx * self.config.dimensions;
let plane = &self.planes[plane_start..plane_start + self.config.dimensions];
let dot: f32 = vector.iter().zip(plane.iter()).map(|(v, p)| v * p).sum();
bit_distances.push((bit, dot.abs()));
}
bit_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
for (bit, _) in bit_distances.iter() {
if probes.len() >= num_probes {
break;
}
let flipped = original_hash ^ (1 << bit);
if !probes.contains(&flipped) {
probes.push(flipped);
}
}
if probes.len() < num_probes {
'outer: for i in 0..bit_distances.len() {
for j in (i + 1)..bit_distances.len() {
if probes.len() >= num_probes {
break 'outer;
}
let flipped =
original_hash ^ (1 << bit_distances[i].0) ^ (1 << bit_distances[j].0);
if !probes.contains(&flipped) {
probes.push(flipped);
}
}
}
}
probes
}
pub fn config(&self) -> &LSHConfig {
&self.config
}
pub fn num_tables(&self) -> usize {
self.config.num_tables
}
pub fn num_bits(&self) -> usize {
self.config.num_bits
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::new();
bytes.extend_from_slice(&(self.config.num_tables as u32).to_le_bytes());
bytes.extend_from_slice(&(self.config.num_bits as u32).to_le_bytes());
bytes.extend_from_slice(&(self.config.dimensions as u32).to_le_bytes());
for &f in &self.planes {
bytes.extend_from_slice(&f.to_le_bytes());
}
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
if bytes.len() < 12 {
return None;
}
let num_tables = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
let num_bits = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]) as usize;
let dimensions = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize;
let expected_floats = num_tables * num_bits * dimensions;
let expected_bytes = 12 + expected_floats * 4;
if bytes.len() < expected_bytes {
return None;
}
let mut planes = Vec::with_capacity(expected_floats);
for i in 0..expected_floats {
let offset = 12 + i * 4;
let f = f32::from_le_bytes([
bytes[offset],
bytes[offset + 1],
bytes[offset + 2],
bytes[offset + 3],
]);
planes.push(f);
}
Some(Self {
config: LSHConfig {
num_tables,
num_bits,
dimensions,
},
planes,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Deterministic test vector `[0/n, 1/n, ..., (n-1)/n]`.
    fn ramp(n: usize) -> Vec<f32> {
        (0..n).map(|i| i as f32 / n as f32).collect()
    }

    #[test]
    fn test_lsh_deterministic_with_seed() {
        let config = LSHConfig {
            num_bits: 8,
            num_tables: 4,
            dimensions: 128,
        };
        let lsh1 = HyperplaneLSH::with_seed(config.clone(), 42);
        let lsh2 = HyperplaneLSH::with_seed(config, 42);
        let vector = ramp(128);
        assert_eq!(lsh1.hash(&vector), lsh2.hash(&vector));
    }

    #[test]
    fn test_lsh_similar_vectors_collide() {
        let config = LSHConfig {
            num_bits: 8,
            num_tables: 8,
            dimensions: 128,
        };
        let lsh = HyperplaneLSH::with_seed(config, 42);
        let v1 = ramp(128);
        let v2: Vec<f32> = v1.iter().map(|x| x + 0.001).collect();
        let h1 = lsh.hash(&v1);
        let h2 = lsh.hash(&v2);
        let matches = h1.iter().zip(h2.iter()).filter(|(a, b)| a == b).count();
        assert!(
            matches > 0,
            "Similar vectors should have some matching hashes"
        );
    }

    #[test]
    fn test_hash_returns_one_code_per_table() {
        let config = LSHConfig {
            num_bits: 6,
            num_tables: 5,
            dimensions: 32,
        };
        let lsh = HyperplaneLSH::with_seed(config, 7);
        let hashes = lsh.hash(&ramp(32));
        assert_eq!(hashes.len(), 5);
        assert!(hashes.iter().all(|&h| h < (1u64 << 6)));
    }

    #[test]
    fn test_probe_sequence_includes_original() {
        let config = LSHConfig {
            num_bits: 8,
            num_tables: 4,
            dimensions: 128,
        };
        let lsh = HyperplaneLSH::with_seed(config, 42);
        let vector = ramp(128);
        let probes = lsh.get_probe_sequence(&vector, 0, 5);
        let original = lsh.hash_single_table(0, &vector);
        assert_eq!(probes[0], original, "First probe should be original hash");
        assert_eq!(probes.len(), 5, "Should return requested number of probes");
    }

    #[test]
    fn test_probe_sequence_is_distinct_and_bounded() {
        let config = LSHConfig {
            num_bits: 4,
            num_tables: 1,
            dimensions: 16,
        };
        let lsh = HyperplaneLSH::with_seed(config, 3);
        let vector = ramp(16);
        let original = lsh.hash_single_table(0, &vector);
        // 1 original + 4 single flips + 6 pair flips is the most that can be produced.
        let probes = lsh.get_probe_sequence(&vector, 0, 100);
        assert_eq!(probes.len(), 11);
        assert_eq!(probes.iter().collect::<HashSet<_>>().len(), probes.len());
        assert!(probes[1..5].iter().all(|p| (p ^ original).count_ones() == 1));
        assert!(probes[5..].iter().all(|p| (p ^ original).count_ones() == 2));
    }

    #[test]
    fn test_serialization_roundtrip() {
        let config = LSHConfig {
            num_bits: 8,
            num_tables: 4,
            dimensions: 64,
        };
        let lsh = HyperplaneLSH::with_seed(config, 42);
        let bytes = lsh.to_bytes();
        let restored = HyperplaneLSH::from_bytes(&bytes).unwrap();
        let vector = ramp(64);
        assert_eq!(lsh.hash(&vector), restored.hash(&vector));
    }

    #[test]
    fn test_from_bytes_rejects_truncated_input() {
        let config = LSHConfig {
            num_bits: 2,
            num_tables: 2,
            dimensions: 4,
        };
        let bytes = HyperplaneLSH::with_seed(config, 1).to_bytes();
        assert!(HyperplaneLSH::from_bytes(&[]).is_none());
        assert!(HyperplaneLSH::from_bytes(&bytes[..11]).is_none());
        assert!(HyperplaneLSH::from_bytes(&bytes[..bytes.len() - 1]).is_none());
    }
}