use crate::types::VectorId;
use ndarray::{Array1, Array2};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub trait NeuralHash {
fn encode(&self, vector: &[f32]) -> Vec<u8>;
fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32;
fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32;
}
#[derive(Clone, Serialize, Deserialize)]
pub struct DeepHashEmbedding {
projections: Vec<Array2<f32>>,
biases: Vec<Array1<f32>>,
output_bits: usize,
input_dims: usize,
}
impl DeepHashEmbedding {
pub fn new(input_dims: usize, hidden_dims: Vec<usize>, output_bits: usize) -> Self {
let mut rng = rand::thread_rng();
let mut projections = Vec::new();
let mut biases = Vec::new();
let mut layer_dims = vec![input_dims];
layer_dims.extend(&hidden_dims);
layer_dims.push(output_bits);
for i in 0..layer_dims.len() - 1 {
let in_dim = layer_dims[i];
let out_dim = layer_dims[i + 1];
let scale = (2.0 / (in_dim + out_dim) as f32).sqrt();
let proj = Array2::from_shape_fn((out_dim, in_dim), |_| {
rng.gen::<f32>() * 2.0 * scale - scale
});
let bias = Array1::zeros(out_dim);
projections.push(proj);
biases.push(bias);
}
Self {
projections,
biases,
output_bits,
input_dims,
}
}
fn forward(&self, input: &[f32]) -> Vec<f32> {
let mut activations = Array1::from_vec(input.to_vec());
for (proj, bias) in self.projections.iter().zip(self.biases.iter()) {
activations = proj.dot(&activations) + bias;
if proj.nrows() != self.output_bits {
activations.mapv_inplace(|x| x.max(0.0));
}
}
activations.to_vec()
}
pub fn train(
&mut self,
positive_pairs: &[(Vec<f32>, Vec<f32>)],
negative_pairs: &[(Vec<f32>, Vec<f32>)],
learning_rate: f32,
epochs: usize,
) {
for _ in 0..epochs {
for (a, b) in positive_pairs {
let code_a = self.encode(a);
let code_b = self.encode(b);
let dist = self.hamming_distance(&code_a, &code_b);
if dist as f32 > self.output_bits as f32 * 0.3 {
self.update_weights(a, b, learning_rate, true);
}
}
for (a, b) in negative_pairs {
let code_a = self.encode(a);
let code_b = self.encode(b);
let dist = self.hamming_distance(&code_a, &code_b);
if (dist as f32) < self.output_bits as f32 * 0.6 {
self.update_weights(a, b, learning_rate, false);
}
}
}
}
fn update_weights(&mut self, a: &[f32], b: &[f32], lr: f32, attract: bool) {
let direction = if attract { 1.0 } else { -1.0 };
if let Some(last_proj) = self.projections.last_mut() {
let a_arr = Array1::from_vec(a.to_vec());
let b_arr = Array1::from_vec(b.to_vec());
for i in 0..last_proj.nrows() {
for j in 0..last_proj.ncols() {
let grad = direction * lr * (a_arr[j] - b_arr[j]);
last_proj[[i, j]] += grad * 0.001; }
}
}
}
pub fn dimensions(&self) -> (usize, usize) {
(self.input_dims, self.output_bits)
}
}
impl NeuralHash for DeepHashEmbedding {
fn encode(&self, vector: &[f32]) -> Vec<u8> {
if vector.len() != self.input_dims {
return vec![0; self.output_bits.div_ceil(8)];
}
let logits = self.forward(vector);
let mut bits = vec![0u8; self.output_bits.div_ceil(8)];
for (i, &logit) in logits.iter().enumerate() {
if logit > 0.0 {
let byte_idx = i / 8;
let bit_idx = i % 8;
bits[byte_idx] |= 1 << bit_idx;
}
}
bits
}
fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
code_a
.iter()
.zip(code_b.iter())
.map(|(a, b)| (a ^ b).count_ones())
.sum()
}
fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
let normalized_dist = hamming_dist as f32 / code_bits as f32;
1.0 - 2.0 * normalized_dist
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct SimpleLSH {
projections: Array2<f32>,
num_bits: usize,
}
impl SimpleLSH {
pub fn new(input_dims: usize, num_bits: usize) -> Self {
let mut rng = rand::thread_rng();
let projections =
Array2::from_shape_fn((num_bits, input_dims), |_| rng.gen::<f32>() * 2.0 - 1.0);
Self {
projections,
num_bits,
}
}
}
impl NeuralHash for SimpleLSH {
fn encode(&self, vector: &[f32]) -> Vec<u8> {
let input = Array1::from_vec(vector.to_vec());
let projections = self.projections.dot(&input);
let mut bits = vec![0u8; self.num_bits.div_ceil(8)];
for (i, &val) in projections.iter().enumerate() {
if val > 0.0 {
let byte_idx = i / 8;
let bit_idx = i % 8;
bits[byte_idx] |= 1 << bit_idx;
}
}
bits
}
fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
code_a
.iter()
.zip(code_b.iter())
.map(|(a, b)| (a ^ b).count_ones())
.sum()
}
fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
let normalized_dist = hamming_dist as f32 / code_bits as f32;
1.0 - 2.0 * normalized_dist
}
}
pub struct HashIndex<H: NeuralHash + Clone> {
hasher: H,
tables: HashMap<Vec<u8>, Vec<VectorId>>,
vectors: HashMap<VectorId, Vec<f32>>,
code_bits: usize,
}
impl<H: NeuralHash + Clone> HashIndex<H> {
pub fn new(hasher: H, code_bits: usize) -> Self {
Self {
hasher,
tables: HashMap::new(),
vectors: HashMap::new(),
code_bits,
}
}
pub fn insert(&mut self, id: VectorId, vector: Vec<f32>) {
let code = self.hasher.encode(&vector);
self.tables.entry(code).or_default().push(id.clone());
self.vectors.insert(id, vector);
}
pub fn search(&self, query: &[f32], k: usize, max_hamming: u32) -> Vec<(VectorId, f32)> {
let query_code = self.hasher.encode(query);
let mut candidates = Vec::new();
for (code, ids) in &self.tables {
let hamming = self.hasher.hamming_distance(&query_code, code);
if hamming <= max_hamming {
for id in ids {
if let Some(vec) = self.vectors.get(id) {
let similarity = cosine_similarity(query, vec);
candidates.push((id.clone(), similarity));
}
}
}
}
candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
candidates.truncate(k);
candidates
}
pub fn compression_ratio(&self) -> f32 {
if self.vectors.is_empty() {
return 0.0;
}
let original_size: usize = self
.vectors
.values()
.map(|v| v.len() * std::mem::size_of::<f32>())
.sum();
let compressed_size = self.tables.len() * self.code_bits.div_ceil(8);
original_size as f32 / compressed_size as f32
}
pub fn stats(&self) -> HashIndexStats {
let buckets = self.tables.len();
let total_vectors = self.vectors.len();
let avg_bucket_size = if buckets > 0 {
total_vectors as f32 / buckets as f32
} else {
0.0
};
HashIndexStats {
total_vectors,
num_buckets: buckets,
avg_bucket_size,
compression_ratio: self.compression_ratio(),
}
}
}
#[derive(Debug, Clone)]
pub struct HashIndexStats {
pub total_vectors: usize,
pub num_buckets: usize,
pub avg_bucket_size: f32,
pub compression_ratio: f32,
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a > 0.0 && norm_b > 0.0 {
dot / (norm_a * norm_b)
} else {
0.0
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_deep_hash_encoding() {
let hash = DeepHashEmbedding::new(4, vec![8], 16);
let vector = vec![0.1, 0.2, 0.3, 0.4];
let code = hash.encode(&vector);
assert_eq!(code.len(), 2); }
#[test]
fn test_hamming_distance() {
let hash = DeepHashEmbedding::new(2, vec![], 8);
let code_a = vec![0b10101010];
let code_b = vec![0b11001100];
let dist = hash.hamming_distance(&code_a, &code_b);
assert_eq!(dist, 4); }
#[test]
fn test_lsh_encoding() {
let lsh = SimpleLSH::new(4, 16);
let vector = vec![1.0, 2.0, 3.0, 4.0];
let code = lsh.encode(&vector);
assert_eq!(code.len(), 2);
let code2 = lsh.encode(&vector);
assert_eq!(code, code2);
}
#[test]
fn test_hash_index() {
let lsh = SimpleLSH::new(3, 8);
let mut index = HashIndex::new(lsh, 8);
index.insert("0".to_string(), vec![1.0, 0.0, 0.0]);
index.insert("1".to_string(), vec![0.9, 0.1, 0.0]);
index.insert("2".to_string(), vec![0.0, 1.0, 0.0]);
let results = index.search(&[1.0, 0.0, 0.0], 2, 4);
assert!(!results.is_empty());
let stats = index.stats();
assert_eq!(stats.total_vectors, 3);
}
#[test]
fn test_compression_ratio() {
let lsh = SimpleLSH::new(128, 32); let mut index = HashIndex::new(lsh, 32);
for i in 0..10 {
let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 128.0).collect();
index.insert(i.to_string(), vec);
}
let ratio = index.compression_ratio();
assert!(ratio > 1.0); }
}