use blake3;
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use rand_distr::{Distribution, StandardNormal};
use super::encoding::encode_sortable;
use super::error::ElidError;
/// Derives the 32-byte RNG seed used for a single output bit.
///
/// The digest input is the 8 little-endian bytes of `base_seed` followed by
/// the single byte `bit_idx` (9 bytes in total), hashed with BLAKE3.
///
/// # Arguments
/// * `base_seed` - seed shared by every bit of one hash.
/// * `bit_idx` - index of the output bit the seed is derived for.
///
/// # Returns
/// The 32-byte BLAKE3 digest of the 9-byte input.
#[inline]
#[must_use]
pub fn derive_bit_seed(base_seed: u64, bit_idx: u8) -> [u8; 32] {
    let mut input = [0u8; 9];
    input[0..8].copy_from_slice(&base_seed.to_le_bytes());
    input[8] = bit_idx;
    // `Hash::as_bytes` borrows the digest; the array is `Copy`, so deref copies it out.
    *blake3::hash(&input).as_bytes()
}
/// Computes a 128-bit SimHash of `embedding` using seeded random projections.
///
/// For each of the 128 output bits a fresh `ChaCha20Rng` is seeded from
/// `derive_bit_seed(seed, bit_idx)`, one standard-normal sample is drawn per
/// embedding component (in order), and the dot product of the embedding with
/// those samples is accumulated in `f32`. Bit `bit_idx` of the result is set
/// when that dot product is strictly greater than zero.
///
/// # Arguments
/// * `embedding` - input vector; any length is accepted.
/// * `seed` - base seed; the same seed and embedding always give the same hash.
///
/// # Returns
/// The 128-bit hash. An empty embedding yields `0`, because every dot product
/// stays at `0.0`; a dot product that is zero or NaN leaves its bit cleared.
#[inline]
#[must_use]
pub fn simhash_128(embedding: &[f32], seed: u64) -> u128 {
    let mut hash: u128 = 0;
    for bit_idx in 0..128u8 {
        // Every bit gets its own independently seeded generator.
        let mut rng = ChaCha20Rng::from_seed(derive_bit_seed(seed, bit_idx));
        // Sample order and f32 accumulation order determine the output bits,
        // so this loop must stay sequential and in component order.
        let mut dot_product: f32 = 0.0;
        for &value in embedding {
            let projection_value: f32 = StandardNormal.sample(&mut rng);
            dot_product += value * projection_value;
        }
        if dot_product > 0.0 {
            hash |= 1u128 << bit_idx;
        }
    }
    hash
}
/// Counts the bit positions at which two 128-bit hashes differ.
///
/// Returns a value in `0..=128`; `0` means the hashes are identical.
#[inline]
#[must_use]
pub fn elid_hamming_distance(a: u128, b: u128) -> u32 {
    // XOR leaves a 1 exactly where the inputs disagree.
    let differing_bits = a ^ b;
    differing_bits.count_ones()
}
/// Serializes a 128-bit hash into 16 bytes in big-endian order.
///
/// Byte 0 of the result holds the most significant 8 bits of `hash`.
#[inline]
#[must_use]
pub fn simhash_to_bytes(hash: u128) -> [u8; 16] {
    u128::to_be_bytes(hash)
}
/// Reconstructs a 128-bit hash from its 16-byte big-endian encoding.
///
/// # Errors
/// Returns `ElidError::InvalidEncoding` when `bytes` is not exactly 16 bytes
/// long (including when it is empty).
#[inline]
pub fn simhash_from_bytes(bytes: &[u8]) -> Result<u128, ElidError> {
    match bytes.len() {
        // Big-endian: every byte shifts the accumulator up and fills the low 8 bits.
        16 => Ok(bytes
            .iter()
            .fold(0u128, |acc, &byte| (acc << 8) | u128::from(byte))),
        _ => Err(ElidError::InvalidEncoding),
    }
}
/// Estimates a similarity score from the Hamming distance of two hashes.
///
/// Computes `1 - (d / 128) * PI`, where `d` is the number of differing bits.
/// Identical hashes give exactly `1.0`; fully opposite hashes give `1 - PI`
/// (about `-2.14`), so the result is not confined to `[-1, 1]`.
///
/// NOTE(review): the usual SimHash estimator is `cos(PI * d / 128)`; this
/// linear form is what the existing tests pin — confirm it is intentional.
#[inline]
#[must_use]
pub fn cosine_similarity_approx(hash_a: u128, hash_b: u128) -> f32 {
    use std::f32::consts::PI;
    // Fraction of the 128 bits on which the two hashes disagree.
    let differing_fraction = elid_hamming_distance(hash_a, hash_b) as f32 / 128.0;
    1.0 - differing_fraction * PI
}
/// Splits a 16-byte hash into `num_bands` equal, contiguous byte bands and
/// encodes each band with `encode_sortable`.
///
/// # Arguments
/// * `hash` - the 16 hash bytes; band 0 covers the first bytes.
/// * `num_bands` - number of bands; must be non-zero and divide 16 evenly
///   (1, 2, 4, 8 or 16).
///
/// # Returns
/// One encoded string per band, in byte order. An empty vector is returned
/// when `num_bands` is zero or does not divide 16.
#[must_use]
pub fn mini128_to_bands(hash: &[u8; 16], num_bands: u8) -> Vec<String> {
    let band_count = usize::from(num_bands);
    if band_count == 0 || 16 % band_count != 0 {
        return Vec::new();
    }
    // The band width divides 16 exactly, so `chunks_exact` leaves no remainder.
    hash.chunks_exact(16 / band_count)
        .map(|band| encode_sortable(band))
        .collect()
}
/// Hashes `embedding` with `simhash_128` and returns the hash split into
/// `num_bands` encoded bands.
///
/// # Arguments
/// * `embedding` - input vector passed to `simhash_128`.
/// * `num_bands` - number of bands; must be non-zero and divide 16 evenly.
/// * `seed` - base seed passed to `simhash_128`.
///
/// # Returns
/// The band strings produced by `mini128_to_bands`, or an empty vector when
/// `num_bands` is invalid.
#[must_use]
pub fn embedding_to_bands(embedding: &[f32], num_bands: u8, seed: u64) -> Vec<String> {
    // Validate first so an invalid band count never pays for the hash computation.
    // The `!= 0` test short-circuits and guards the modulo against division by zero.
    let divides_evenly = num_bands != 0 && 16 % num_bands == 0;
    if !divides_evenly {
        return Vec::new();
    }
    let hash_bytes = simhash_to_bytes(simhash_128(embedding, seed));
    mini128_to_bands(&hash_bytes, num_bands)
}
// Unit tests for seed derivation, SimHash computation, Hamming distance,
// byte (de)serialization, the similarity estimate and band splitting.
// Most tests use the seed 0x454c4944_53494d48, whose bytes spell "ELIDSIMH" in ASCII.
#[cfg(test)]
mod tests {
use super::*;
// Same (base seed, bit index) pair must always derive the same seed.
#[test]
fn test_derive_bit_seed_deterministic() {
let seed1 = derive_bit_seed(0x1234_5678_9ABC_DEF0, 0);
let seed2 = derive_bit_seed(0x1234_5678_9ABC_DEF0, 0);
assert_eq!(seed1, seed2);
}
// Changing only the bit index must change the derived seed.
#[test]
fn test_derive_bit_seed_different_bits() {
let seed0 = derive_bit_seed(0x1234_5678_9ABC_DEF0, 0);
let seed1 = derive_bit_seed(0x1234_5678_9ABC_DEF0, 1);
assert_ne!(seed0, seed1);
}
// Changing only the base seed must change the derived seed.
#[test]
fn test_derive_bit_seed_different_base() {
let seed_a = derive_bit_seed(0x1111_1111_1111_1111, 0);
let seed_b = derive_bit_seed(0x2222_2222_2222_2222, 0);
assert_ne!(seed_a, seed_b);
}
// All 128 bit indices used by simhash_128 must yield pairwise distinct seeds.
#[test]
fn test_derive_bit_seed_coverage() {
let base = 0x454c4944_53494d48;
let mut seeds = std::collections::HashSet::new();
for bit_idx in 0..128 {
let seed = derive_bit_seed(base, bit_idx);
assert!(seeds.insert(seed), "Duplicate seed at bit {}", bit_idx);
}
}
// Hashing the same embedding twice with the same seed gives the same hash.
#[test]
fn test_simhash_128_deterministic() {
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
let embedding = embedding.into_iter().cycle().take(128).collect::<Vec<_>>();
let seed = 0x454c4944_53494d48;
let hash1 = simhash_128(&embedding, seed);
let hash2 = simhash_128(&embedding, seed);
assert_eq!(hash1, hash2, "SimHash must be deterministic");
}
// The same embedding hashed under two different seeds should not collide.
#[test]
fn test_simhash_128_different_seeds() {
let embedding = vec![0.1, 0.2, 0.3, 0.4];
let embedding = embedding.into_iter().cycle().take(128).collect::<Vec<_>>();
let hash1 = simhash_128(&embedding, 0x1111_1111_1111_1111);
let hash2 = simhash_128(&embedding, 0x2222_2222_2222_2222);
assert_ne!(
hash1, hash2,
"Different seeds should produce different hashes"
);
}
// Two embeddings with their non-zero components in different positions
// should hash differently under the same seed.
#[test]
fn test_simhash_128_different_embeddings() {
let seed = 0x454c4944_53494d48;
let emb1 = vec![1.0, 0.0, 0.0, 0.0]
.into_iter()
.cycle()
.take(128)
.collect::<Vec<_>>();
let emb2 = vec![0.0, 1.0, 0.0, 0.0]
.into_iter()
.cycle()
.take(128)
.collect::<Vec<_>>();
let hash1 = simhash_128(&emb1, seed);
let hash2 = simhash_128(&emb2, seed);
assert_ne!(
hash1, hash2,
"Different embeddings should produce different hashes"
);
}
// An all-zero embedding must hash without panicking and deterministically;
// the concrete hash value is not asserted here.
#[test]
fn test_simhash_128_all_zeros() {
let embedding = vec![0.0; 128];
let seed = 0x454c4944_53494d48;
let hash = simhash_128(&embedding, seed);
let hash2 = simhash_128(&embedding, seed);
assert_eq!(hash, hash2);
}
// Smoke test: hashing must complete for a range of embedding lengths.
// Nothing is asserted about the resulting values.
#[test]
fn test_simhash_128_various_dimensions() {
let seed = 0x454c4944_53494d48;
for dim in [64, 128, 256, 512, 768, 1024, 1536, 2048] {
let embedding = vec![0.1; dim];
let hash = simhash_128(&embedding, seed);
let _ = hash; }
}
// Negating every component negates every dot product, so the two hashes
// should disagree on more than half of the 128 bits.
#[test]
fn test_simhash_128_sign_extraction() {
let seed = 0x454c4944_53494d48;
let positive_embedding = vec![1.0; 128];
let negative_embedding = vec![-1.0; 128];
let hash_pos = simhash_128(&positive_embedding, seed);
let hash_neg = simhash_128(&negative_embedding, seed);
let dist = elid_hamming_distance(hash_pos, hash_neg);
assert!(
dist > 64,
"Opposite embeddings should have high Hamming distance, got {}",
dist
);
}
// Perturbing a single component slightly should flip fewer than half the bits.
#[test]
fn test_simhash_128_locality_preservation() {
let seed = 0x454c4944_53494d48;
let base = vec![0.5; 256];
let mut similar = base.clone();
similar[0] = 0.51;
let hash_base = simhash_128(&base, seed);
let hash_similar = simhash_128(&similar, seed);
let dist = elid_hamming_distance(hash_base, hash_similar);
assert!(
dist < 64,
"Similar embeddings should have low Hamming distance, got {}",
dist
);
}
// Distance from a value to itself is zero.
#[test]
fn test_hamming_distance_identical() {
let a = 0xDEAD_BEEF_CAFE_BABE_1234_5678_9ABC_DEF0_u128;
assert_eq!(elid_hamming_distance(a, a), 0);
}
// Values differing only in the lowest bit are at distance one.
#[test]
fn test_hamming_distance_one_bit() {
let a = 0b0000_u128;
let b = 0b0001_u128;
assert_eq!(elid_hamming_distance(a, b), 1);
}
// All-zeros versus all-ones differ in every one of the 128 bits.
#[test]
fn test_hamming_distance_all_bits() {
let a = 0_u128;
let b = !0_u128; assert_eq!(elid_hamming_distance(a, b), 128);
}
// Argument order must not matter.
#[test]
fn test_hamming_distance_symmetric() {
let a = 0x1234_5678_u128;
let b = 0x9ABC_DEF0_u128;
assert_eq!(elid_hamming_distance(a, b), elid_hamming_distance(b, a));
}
// 0b1010 ^ 0b1100 == 0b0110, which has two set bits.
#[test]
fn test_hamming_distance_known_pattern() {
let a = 0b1010_u128;
let b = 0b1100_u128;
assert_eq!(elid_hamming_distance(a, b), 2);
}
// Serialization is big-endian: most significant byte first.
#[test]
fn test_simhash_to_bytes_basic() {
let hash = 0x0102030405060708090A0B0C0D0E0F10_u128;
let bytes = simhash_to_bytes(hash);
assert_eq!(bytes.len(), 16);
assert_eq!(bytes[0], 0x01);
assert_eq!(bytes[15], 0x10);
}
// Deserialization reads the 16 bytes as a big-endian u128.
#[test]
fn test_simhash_from_bytes_valid() {
let bytes = [
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E,
0x0F, 0x10,
];
let hash = simhash_from_bytes(&bytes).unwrap();
assert_eq!(hash, 0x0102030405060708090A0B0C0D0E0F10_u128);
}
// Any input that is not exactly 16 bytes long must be rejected.
#[test]
fn test_simhash_from_bytes_invalid_length() {
let too_short = [0u8; 8];
assert!(simhash_from_bytes(&too_short).is_err());
let too_long = [0u8; 32];
assert!(simhash_from_bytes(&too_long).is_err());
let empty: [u8; 0] = [];
assert!(simhash_from_bytes(&empty).is_err());
}
// to_bytes followed by from_bytes must return the original value,
// including the all-zeros and all-ones extremes.
#[test]
fn test_simhash_bytes_roundtrip() {
let test_hashes = vec![
0_u128,
!0_u128,
0xDEADBEEFCAFEBABE1234567890ABCDEF_u128,
0x0000000000000000FFFFFFFFFFFFFFFF_u128,
0x5555555555555555AAAAAAAAAAAAAAAA_u128,
];
for hash in test_hashes {
let bytes = simhash_to_bytes(hash);
let recovered = simhash_from_bytes(&bytes).unwrap();
assert_eq!(hash, recovered, "Round-trip failed for 0x{:032x}", hash);
}
}
// Zero Hamming distance maps to a similarity of exactly 1.0.
#[test]
fn test_cosine_similarity_approx_identical() {
let hash = 0xDEADBEEFCAFEBABE_u128;
let sim = cosine_similarity_approx(hash, hash);
assert_eq!(sim, 1.0, "Identical hashes should have similarity 1.0");
}
// Pins the linear formula at maximum distance: 1 - (128 / 128) * PI == 1 - PI.
// Note that this value lies below -1.
#[test]
fn test_cosine_similarity_approx_opposite() {
let a = 0_u128;
let b = !0_u128; let sim = cosine_similarity_approx(a, b);
let expected = 1.0 - std::f32::consts::PI;
assert!(
(sim - expected).abs() < 0.001,
"Expected {}, got {}",
expected,
sim
);
assert!(sim < 0.0, "Opposite hashes should have negative similarity");
}
// The estimate depends only on the Hamming distance, so it is symmetric.
#[test]
fn test_cosine_similarity_approx_symmetry() {
let a = 0x1234567890ABCDEF_u128;
let b = 0xFEDCBA0987654321_u128;
let sim_ab = cosine_similarity_approx(a, b);
let sim_ba = cosine_similarity_approx(b, a);
assert_eq!(sim_ab, sim_ba, "Cosine similarity should be symmetric");
}
// Golden-value test: pins the exact hash of a fixed 768-dimensional embedding.
// Any change to seed derivation, RNG, sampling order or accumulation order
// will change this value.
#[test]
fn test_simhash_reference_output() {
let embedding = vec![
0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8,
]
.into_iter()
.cycle()
.take(768)
.collect::<Vec<_>>();
let seed = 0x454c4944_53494d48; let hash = simhash_128(&embedding, seed);
let expected: u128 = 0x9f52baea6db62f9b250de36caf8f0b13;
assert_eq!(
hash, expected,
"Reference hash mismatch! Got 0x{:032x}, expected 0x{:032x}",
hash, expected
);
}
// Band generation inherits determinism from the hash.
#[test]
fn test_bands_deterministic() {
let embedding: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
.into_iter()
.cycle()
.take(768)
.collect();
let seed = 0x454c4944_53494d48;
let bands1 = embedding_to_bands(&embedding, 4, seed);
let bands2 = embedding_to_bands(&embedding, 4, seed);
assert_eq!(bands1, bands2, "Bands must be deterministic");
}
// Shifting every component by 0.001 should leave at least one of the four
// bands (compared position by position) unchanged.
#[test]
fn test_similar_embeddings_share_band() {
let seed = 0x454c4944_53494d48;
let base: Vec<f32> = (0..768).map(|i| (i as f32 * 0.001).sin()).collect();
let similar: Vec<f32> = base.iter().map(|&x| x + 0.001).collect();
let bands_base = embedding_to_bands(&base, 4, seed);
let bands_similar = embedding_to_bands(&similar, 4, seed);
let shared_bands = bands_base
.iter()
.zip(bands_similar.iter())
.filter(|(a, b)| a == b)
.count();
assert!(
shared_bands >= 1,
"Similar embeddings should share at least one band, got {} shared out of 4",
shared_bands
);
}
// Band counts that are zero or do not divide 16 yield an empty vector;
// the divisors 1, 2, 4, 8 and 16 yield that many bands.
#[test]
fn test_invalid_num_bands() {
let hash = [0u8; 16];
assert!(mini128_to_bands(&hash, 0).is_empty(), "0 bands should fail");
assert!(mini128_to_bands(&hash, 3).is_empty(), "3 bands should fail");
assert!(mini128_to_bands(&hash, 5).is_empty(), "5 bands should fail");
assert!(mini128_to_bands(&hash, 6).is_empty(), "6 bands should fail");
assert!(mini128_to_bands(&hash, 7).is_empty(), "7 bands should fail");
assert!(mini128_to_bands(&hash, 9).is_empty(), "9 bands should fail");
assert!(
mini128_to_bands(&hash, 15).is_empty(),
"15 bands should fail"
);
assert!(
mini128_to_bands(&hash, 32).is_empty(),
"32 bands should fail"
);
assert_eq!(mini128_to_bands(&hash, 1).len(), 1);
assert_eq!(mini128_to_bands(&hash, 2).len(), 2);
assert_eq!(mini128_to_bands(&hash, 4).len(), 4);
assert_eq!(mini128_to_bands(&hash, 8).len(), 8);
assert_eq!(mini128_to_bands(&hash, 16).len(), 16);
}
// Pins the encoded length of each band: 16, 8, 4, 2 and 1 bytes encode to
// 26, 13, 7, 4 and 2 characters. These equal ceil(8 * n / 5), which is
// consistent with an unpadded 5-bits-per-character encoding.
#[test]
fn test_band_sizes() {
let hash = [
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E,
0x0F, 0x10,
];
let bands_1 = mini128_to_bands(&hash, 1);
assert_eq!(bands_1.len(), 1);
assert_eq!(bands_1[0].len(), 26);
let bands_2 = mini128_to_bands(&hash, 2);
assert_eq!(bands_2.len(), 2);
for band in &bands_2 {
assert_eq!(band.len(), 13);
}
let bands_4 = mini128_to_bands(&hash, 4);
assert_eq!(bands_4.len(), 4);
for band in &bands_4 {
assert_eq!(band.len(), 7);
}
let bands_8 = mini128_to_bands(&hash, 8);
assert_eq!(bands_8.len(), 8);
for band in &bands_8 {
assert_eq!(band.len(), 4);
}
let bands_16 = mini128_to_bands(&hash, 16);
assert_eq!(bands_16.len(), 16);
for band in &bands_16 {
assert_eq!(band.len(), 2);
}
}
// Each band must be the encoding of the corresponding contiguous 4-byte
// slice of the hash, in order.
#[test]
fn test_band_content_matches_hash_segments() {
let hash = [
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, ];
let bands = mini128_to_bands(&hash, 4);
use super::super::encoding::encode_sortable;
assert_eq!(bands[0], encode_sortable(&[0x01, 0x02, 0x03, 0x04]));
assert_eq!(bands[1], encode_sortable(&[0x05, 0x06, 0x07, 0x08]));
assert_eq!(bands[2], encode_sortable(&[0x09, 0x0A, 0x0B, 0x0C]));
assert_eq!(bands[3], encode_sortable(&[0x0D, 0x0E, 0x0F, 0x10]));
}
// embedding_to_bands applies the same band-count validation as mini128_to_bands.
#[test]
fn test_embedding_to_bands_invalid() {
let embedding = vec![0.1f32; 768];
let seed = 0x454c4944_53494d48;
assert!(embedding_to_bands(&embedding, 0, seed).is_empty());
assert!(embedding_to_bands(&embedding, 3, seed).is_empty());
assert!(embedding_to_bands(&embedding, 5, seed).is_empty());
}
// Every band character must be an ASCII digit or a lowercase letter in 'a'..='v'.
#[test]
fn test_bands_all_lowercase_base32hex() {
let embedding: Vec<f32> = (0..768).map(|i| (i as f32 * 0.01).cos()).collect();
let seed = 0x454c4944_53494d48;
let bands = embedding_to_bands(&embedding, 4, seed);
for band in &bands {
for c in band.chars() {
assert!(
c.is_ascii_digit() || ('a'..='v').contains(&c),
"Invalid character '{}' in band, expected base32hex (0-9, a-v)",
c
);
}
}
}
// The two embeddings here are exact negations of each other (not orthogonal,
// despite the wording of the assertion message); at least one of the four
// bands must differ between them.
#[test]
fn test_bands_different_embeddings_differ() {
let seed = 0x454c4944_53494d48;
let emb1: Vec<f32> = vec![1.0; 768];
let emb2: Vec<f32> = vec![-1.0; 768];
let bands1 = embedding_to_bands(&emb1, 4, seed);
let bands2 = embedding_to_bands(&emb2, 4, seed);
let matching_bands = bands1
.iter()
.zip(bands2.iter())
.filter(|(a, b)| a == b)
.count();
assert!(
matching_bands < 4,
"Orthogonal embeddings should have mostly different bands"
);
}
}