use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use trustformers_core::errors::{Result, TrustformersError};
/// Tuning parameters for building a [`MinimalPerfectHash`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MinimalPerfectHashConfig {
/// Target ratio of keys to table slots. The table is sized to the first
/// prime at or above `ceil(num_keys / load_factor)`.
pub load_factor: f64,
/// Number of `(a, b)` hash parameter pairs to generate. Pair 0 selects the
/// primary bucket; pairs `1..` are tried when placing colliding keys in the
/// secondary table.
pub num_hash_functions: usize,
/// When `false`, primary-bucket collisions are never resolved through the
/// secondary table, so a build attempt only succeeds if no two keys share
/// a primary bucket.
pub use_double_hashing: bool,
}
impl Default for MinimalPerfectHashConfig {
    /// Returns the stock configuration: a 0.9 load factor, three hash
    /// functions, and secondary-table collision resolution enabled.
    fn default() -> Self {
        let load_factor = 0.9;
        let num_hash_functions = 3;
        let use_double_hashing = true;
        Self {
            load_factor,
            num_hash_functions,
            use_double_hashing,
        }
    }
}
/// Hash table that maps each key of a fixed key set to its index in that set.
///
/// Lookups hash the key into the primary table and, for buckets that held
/// several keys at build time, into the secondary table; every candidate is
/// confirmed against the stored copy of the key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MinimalPerfectHash {
/// Configuration the table was built with.
config: MinimalPerfectHashConfig,
/// Primary table. `Some(idx)` holds the index of the only key in that
/// bucket, `Some(u32::MAX)` marks a bucket whose keys were moved to
/// `secondary_table`, and `None` is an empty bucket.
table: Vec<Option<u32>>,
/// Secondary table holding key indices for keys from colliding buckets.
secondary_table: Vec<Option<u32>>,
/// `(a, b)` pairs used as `(a * hash + b) % table_size`; one pair per
/// configured hash function.
hash_params: Vec<(u64, u64)>,
/// Number of slots in each of the two tables.
table_size: usize,
/// Number of keys the table was built from.
num_keys: usize,
/// Copy of the keys, indexed by key index; used to verify lookup results.
keys: Vec<String>,
}
impl MinimalPerfectHash {
    /// Number of random parameter sets tried before construction gives up.
    const MAX_BUILD_ATTEMPTS: usize = 100;

    /// Primary-table marker meaning "this bucket's keys live in the secondary table".
    /// Because it shares the `u32` value space with key indices, the key count
    /// must stay below it.
    const SECONDARY_MARKER: u32 = u32::MAX;

    /// Builds a table that maps every string in `keys` to its position in `keys`.
    ///
    /// Both tables get the first prime at or above `ceil(keys.len() / load_factor)`
    /// slots. An empty key set is accepted and yields a table on which every
    /// lookup misses.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - `config.load_factor` is not a finite value greater than zero,
    /// - `config.num_hash_functions` is zero,
    /// - `keys` has `u32::MAX` or more entries,
    /// - `keys` contains the same string twice (such keys can never be separated),
    /// - the requested table would be too large to allocate, or
    /// - no collision-free placement was found within the attempt limit.
    pub fn new(keys: &[String], config: MinimalPerfectHashConfig) -> Result<Self> {
        let num_keys = keys.len();

        if !config.load_factor.is_finite() || config.load_factor <= 0.0 {
            return Err(TrustformersError::other(format!(
                "Invalid load_factor {}: expected a finite value greater than 0",
                config.load_factor
            )));
        }
        if config.num_hash_functions == 0 {
            return Err(TrustformersError::other(
                "num_hash_functions must be at least 1".to_string(),
            ));
        }
        if num_keys >= Self::SECONDARY_MARKER as usize {
            return Err(TrustformersError::other(format!(
                "Too many keys ({}): key indices must be smaller than {}",
                num_keys,
                Self::SECONDARY_MARKER
            )));
        }

        // Identical keys always hash identically, so no parameter choice can
        // separate them. Fail fast with a precise message instead of
        // exhausting every build attempt.
        let mut first_seen: HashMap<&str, usize> = HashMap::with_capacity(num_keys);
        for (idx, key) in keys.iter().enumerate() {
            if let Some(previous) = first_seen.insert(key.as_str(), idx) {
                return Err(TrustformersError::other(format!(
                    "Duplicate key {:?} at indices {} and {}: keys must be unique",
                    key, previous, idx
                )));
            }
        }
        drop(first_seen);

        // A float-to-int `as` cast saturates, so an absurd target (for example
        // from a tiny load factor) must be rejected before it reaches `next_prime`.
        let max_slots = isize::MAX as usize / std::mem::size_of::<Option<u32>>();
        let target = (num_keys as f64 / config.load_factor).ceil();
        if !target.is_finite() || target >= max_slots as f64 {
            return Err(TrustformersError::other(format!(
                "Requested table size is too large for {} keys at load_factor {}",
                num_keys, config.load_factor
            )));
        }
        let table_size = Self::next_prime(target as usize);

        let mut mph = Self {
            config,
            table: vec![None; table_size],
            secondary_table: vec![None; table_size],
            hash_params: Vec::new(),
            table_size,
            num_keys,
            keys: keys.to_vec(),
        };
        mph.build_hash_function(keys)?;
        Ok(mph)
    }

    /// Repeatedly draws fresh hash parameters until every key can be placed,
    /// giving up after [`Self::MAX_BUILD_ATTEMPTS`] tries.
    fn build_hash_function(&mut self, keys: &[String]) -> Result<()> {
        // The parameter-independent part of each key's hash is computed once
        // and shared by all attempts.
        let key_hashes: Vec<u64> = keys.iter().map(|key| Self::key_hash(key)).collect();

        for _attempt in 0..Self::MAX_BUILD_ATTEMPTS {
            self.hash_params = Self::fresh_hash_params(self.config.num_hash_functions);
            self.table.fill(None);
            self.secondary_table.fill(None);
            if self.try_build_hash_function(&key_hashes)? {
                return Ok(());
            }
        }
        Err(TrustformersError::other(format!(
            "Failed to build minimal perfect hash function after {} attempts",
            Self::MAX_BUILD_ATTEMPTS
        )))
    }

    /// Attempts one placement of all keys with the current `hash_params`.
    ///
    /// `key_hashes[i]` is the base hash of key `i`. Returns `Ok(false)` when
    /// some colliding bucket could not be resolved; the tables are then left
    /// partially filled and must be cleared before the next attempt.
    fn try_build_hash_function(&mut self, key_hashes: &[u64]) -> Result<bool> {
        let mut buckets: Vec<Vec<(usize, u64)>> = vec![Vec::new(); self.table_size];
        for (key_idx, &hash) in key_hashes.iter().enumerate() {
            buckets[self.slot(hash, 0)].push((key_idx, hash));
        }

        for (bucket_idx, bucket) in buckets.iter().enumerate() {
            match bucket.as_slice() {
                [] => {}
                // `new` guarantees key indices fit in `u32` below the marker value.
                [(key_idx, _)] => self.table[bucket_idx] = Some(*key_idx as u32),
                crowded => {
                    if !self.resolve_bucket_conflicts(bucket_idx, crowded)? {
                        return Ok(false);
                    }
                }
            }
        }
        Ok(true)
    }

    /// Tries to move all keys of one colliding primary bucket into free,
    /// pairwise distinct secondary-table slots.
    ///
    /// `keys` holds `(key index, base hash)` pairs. Secondary hash functions
    /// are tried in order; the first one that fits is used. Returns
    /// `Ok(false)` when double hashing is disabled or no function fits.
    fn resolve_bucket_conflicts(
        &mut self,
        bucket_idx: usize,
        keys: &[(usize, u64)],
    ) -> Result<bool> {
        if !self.config.use_double_hashing {
            return Ok(false);
        }
        for hash_idx in 1..self.config.num_hash_functions {
            let mut claimed = std::collections::HashSet::with_capacity(keys.len());
            let fits = keys.iter().all(|&(_, hash)| {
                let pos = self.slot(hash, hash_idx);
                self.secondary_table[pos].is_none() && claimed.insert(pos)
            });
            if fits {
                for &(key_idx, hash) in keys {
                    let pos = self.slot(hash, hash_idx);
                    // `new` guarantees key indices fit in `u32` below the marker value.
                    self.secondary_table[pos] = Some(key_idx as u32);
                }
                self.table[bucket_idx] = Some(Self::SECONDARY_MARKER);
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Parameter-independent 64-bit hash of `key`.
    ///
    /// NOTE(review): std does not specify `DefaultHasher`'s algorithm, so a
    /// serialized table may stop matching after a toolchain upgrade — confirm
    /// whether persisted instances need to survive that.
    fn key_hash(key: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Maps a base hash to a table slot using parameter pair `func_idx`:
    /// `(a * hash + b) % table_size`, with wrapping 64-bit arithmetic.
    ///
    /// Panics if `func_idx` is out of range or `table_size` is zero; callers
    /// on possibly inconsistent (deserialized) state must check first.
    fn slot(&self, key_hash: u64, func_idx: usize) -> usize {
        let (a, b) = self.hash_params[func_idx];
        (a.wrapping_mul(key_hash).wrapping_add(b) % self.table_size as u64) as usize
    }

    /// Draws `count` random `(a, b)` parameter pairs.
    ///
    /// The multiplier is forced odd: multiplying by an odd number is a
    /// bijection modulo 2^64, so distinct base hashes stay distinct before the
    /// final reduction (and `a` can never be zero).
    fn fresh_hash_params(count: usize) -> Vec<(u64, u64)> {
        (0..count)
            .map(|_| (Self::random_u64() | 1, Self::random_u64()))
            .collect()
    }

    /// Returns `true` when `id` is a valid key index whose stored key equals `key`.
    fn holds_key(&self, id: u32, key: &str) -> bool {
        self.keys
            .get(id as usize)
            .is_some_and(|stored| stored.as_str() == key)
    }

    /// Returns the index of `key` in the original key slice, or `None` if
    /// `key` was not part of the set.
    ///
    /// Every candidate slot is verified against the stored key, so unknown
    /// keys never produce a false positive. The method does not panic even if
    /// the instance was deserialized with inconsistent fields.
    pub fn lookup(&self, key: &str) -> Option<u32> {
        if self.table_size == 0 || self.hash_params.is_empty() {
            return None;
        }
        // Hash the string once; each probe only applies a cheap affine map.
        let key_hash = Self::key_hash(key);

        match *self.table.get(self.slot(key_hash, 0))? {
            None => None,
            Some(Self::SECONDARY_MARKER) => {
                let probes = self.config.num_hash_functions.min(self.hash_params.len());
                (1..probes)
                    .filter_map(|func_idx| {
                        self.secondary_table
                            .get(self.slot(key_hash, func_idx))
                            .copied()
                            .flatten()
                    })
                    .find(|&id| self.holds_key(id, key))
            }
            Some(id) => self.holds_key(id, key).then_some(id),
        }
    }

    /// Returns `true` if `key` belongs to the key set.
    pub fn contains(&self, key: &str) -> bool {
        self.lookup(key).is_some()
    }

    /// Number of keys the table was built from.
    pub fn len(&self) -> usize {
        self.num_keys
    }

    /// Returns `true` if the table was built from an empty key set.
    pub fn is_empty(&self) -> bool {
        self.num_keys == 0
    }

    /// Reports the bytes taken by the two slot tables and the hash parameters,
    /// together with the achieved load factor (`num_keys / table_size`).
    ///
    /// The stored copy of the keys is not included in the figures.
    pub fn memory_usage(&self) -> MemoryUsage {
        let slot_bytes = std::mem::size_of::<Option<u32>>();
        let table_size = self.table.len() * slot_bytes;
        let secondary_table_size = self.secondary_table.len() * slot_bytes;
        let params_size = self.hash_params.len() * std::mem::size_of::<(u64, u64)>();
        MemoryUsage {
            total_bytes: table_size + secondary_table_size + params_size,
            table_bytes: table_size,
            secondary_table_bytes: secondary_table_size,
            params_bytes: params_size,
            load_factor: self.num_keys as f64 / self.table_size as f64,
        }
    }

    /// Returns the smallest prime that is at least `n` (and at least 2).
    fn next_prime(n: usize) -> usize {
        if n <= 2 {
            return 2;
        }
        // Every prime above 2 is odd, so only odd candidates are tested.
        let mut candidate = if n.is_multiple_of(2) { n + 1 } else { n };
        while !Self::is_prime(candidate) {
            candidate += 2;
        }
        candidate
    }

    /// Trial-division primality test using the 6k ± 1 wheel.
    fn is_prime(n: usize) -> bool {
        if n <= 1 {
            return false;
        }
        if n <= 3 {
            return true;
        }
        if n.is_multiple_of(2) || n.is_multiple_of(3) {
            return false;
        }
        let mut i = 5;
        // `i <= n / i` is equivalent to `i * i <= n` but cannot overflow.
        while i <= n / i {
            if n.is_multiple_of(i) || n.is_multiple_of(i + 2) {
                return false;
            }
            i += 6;
        }
        true
    }

    /// Returns a pseudo-random 64-bit value.
    ///
    /// Each `RandomState` carries its own randomly seeded keys, so finishing
    /// an empty hasher yields a fresh value on every call. Unlike a value
    /// derived from the clock, back-to-back calls do not repeat when the
    /// timer is coarse.
    fn random_u64() -> u64 {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        RandomState::new().build_hasher().finish()
    }
}
/// Byte-level memory report for a hash table, plus its achieved load factor.
#[derive(Debug, Clone)]
pub struct MemoryUsage {
    /// Sum of all byte counts included in the report.
    pub total_bytes: usize,
    /// Bytes taken by the primary table.
    pub table_bytes: usize,
    /// Bytes taken by the secondary table.
    pub secondary_table_bytes: usize,
    /// Bytes taken by the hash parameter pairs.
    pub params_bytes: usize,
    /// Ratio of stored keys to table slots.
    pub load_factor: f64,
}

impl std::fmt::Display for MemoryUsage {
    /// Renders the report on one line; the load factor is shown with two decimals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            total_bytes,
            table_bytes,
            secondary_table_bytes,
            params_bytes,
            load_factor,
        } = self;
        write!(
            f,
            "Total: {} bytes, Primary: {} bytes, Secondary: {} bytes, Params: {} bytes, Load factor: {:.2}",
            total_bytes, table_bytes, secondary_table_bytes, params_bytes, load_factor
        )
    }
}
/// Token vocabulary backed by a [`MinimalPerfectHash`]: token → id through the
/// hash table, id → token through a plain vector.
#[derive(Debug, Clone)]
pub struct MinimalPerfectHashVocab {
/// Hash table mapping each token to its position in the input list.
mph: MinimalPerfectHash,
/// Tokens indexed by id; serves `get_token`.
id_to_token: Vec<String>,
/// The token list as passed to the constructor; returned by `tokens()`.
/// Holds the same strings as `id_to_token`.
tokens: Vec<String>,
}
impl MinimalPerfectHashVocab {
pub fn new(tokens: Vec<String>) -> Result<Self> {
let config = MinimalPerfectHashConfig::default();
let mph = MinimalPerfectHash::new(&tokens, config)?;
let id_to_token = tokens.clone();
Ok(Self {
mph,
id_to_token,
tokens,
})
}
pub fn with_config(tokens: Vec<String>, config: MinimalPerfectHashConfig) -> Result<Self> {
let mph = MinimalPerfectHash::new(&tokens, config)?;
let id_to_token = tokens.clone();
Ok(Self {
mph,
id_to_token,
tokens,
})
}
pub fn get_id(&self, token: &str) -> Option<u32> {
self.mph.lookup(token)
}
pub fn get_token(&self, id: u32) -> Option<&str> {
self.id_to_token.get(id as usize).map(|s| s.as_str())
}
pub fn contains(&self, token: &str) -> bool {
self.mph.contains(token)
}
pub fn size(&self) -> usize {
self.mph.len()
}
pub fn is_empty(&self) -> bool {
self.mph.is_empty()
}
pub fn memory_usage(&self) -> MemoryUsage {
let mut usage = self.mph.memory_usage();
let id_to_token_size: usize =
self.id_to_token.iter().map(|s| s.len() + std::mem::size_of::<String>()).sum();
usage.total_bytes += id_to_token_size;
usage
}
pub fn tokens(&self) -> &[String] {
&self.tokens
}
pub fn efficiency_comparison(&self) -> EfficiencyComparison {
let mph_usage = self.memory_usage();
let hashmap_overhead =
self.tokens.len() * (std::mem::size_of::<String>() + std::mem::size_of::<u32>());
let hashmap_capacity = self.tokens.len() * 2; let hashmap_buckets = hashmap_capacity * std::mem::size_of::<Option<(String, u32)>>();
let hashmap_total = hashmap_overhead + hashmap_buckets;
let token_storage: usize = self.tokens.iter().map(|s| s.len()).sum();
let hashmap_estimated = hashmap_total + token_storage;
EfficiencyComparison {
mph_bytes: mph_usage.total_bytes,
hashmap_bytes: hashmap_estimated,
space_savings: 1.0 - (mph_usage.total_bytes as f64 / hashmap_estimated as f64),
compression_ratio: hashmap_estimated as f64 / mph_usage.total_bytes as f64,
}
}
}
/// Side-by-side memory figures for the perfect-hash vocabulary and a
/// `HashMap`-based estimate.
#[derive(Debug, Clone)]
pub struct EfficiencyComparison {
    /// Bytes reported for the perfect-hash vocabulary.
    pub mph_bytes: usize,
    /// Estimated bytes for an equivalent `HashMap`.
    pub hashmap_bytes: usize,
    /// Fraction of the `HashMap` estimate saved (`1 - mph / hashmap`);
    /// negative when the perfect-hash vocabulary is larger.
    pub space_savings: f64,
    /// `HashMap` estimate divided by the perfect-hash figure.
    pub compression_ratio: f64,
}

impl std::fmt::Display for EfficiencyComparison {
    /// Renders the comparison on one line, with savings as a percentage.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let savings_percent = self.space_savings * 100.0;
        write!(
            f,
            "MPH: {} bytes, HashMap: {} bytes, Savings: {:.1}%, Compression: {:.2}x",
            self.mph_bytes, self.hashmap_bytes, savings_percent, self.compression_ratio
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned `String`s the constructors expect.
    fn owned(words: &[&str]) -> Vec<String> {
        words.iter().map(|word| word.to_string()).collect()
    }

    /// Construction reports the number of keys it was given.
    #[test]
    fn test_minimal_perfect_hash_creation() {
        let keys = owned(&["hello", "world", "test", "key", "value"]);
        let mph = MinimalPerfectHash::new(&keys, MinimalPerfectHashConfig::default())
            .expect("Construction failed");
        assert_eq!(mph.len(), 5);
        assert!(!mph.is_empty());
    }

    /// Member keys are found; unknown keys are rejected.
    #[test]
    fn test_minimal_perfect_hash_lookup() {
        let keys = owned(&["hello", "world", "test"]);
        let mph = MinimalPerfectHash::new(&keys, MinimalPerfectHashConfig::default())
            .expect("Construction failed");
        for present in ["hello", "world", "test"] {
            assert!(mph.contains(present));
        }
        for absent in ["nonexistent", "missing"] {
            assert!(!mph.contains(absent));
        }
    }

    /// Every key resolves to a value, and no two keys share one.
    #[test]
    fn test_minimal_perfect_hash_no_collisions() {
        let keys = owned(&["apple", "banana", "cherry", "date", "elderberry"]);
        let mph = MinimalPerfectHash::new(&keys, MinimalPerfectHashConfig::default())
            .expect("Construction failed");
        let mut hash_values = std::collections::HashSet::new();
        for key in &keys {
            let Some(value) = mph.lookup(key) else {
                panic!("Key not found: {}", key)
            };
            assert!(
                !hash_values.contains(&value),
                "Collision detected for key: {}",
                key
            );
            hash_values.insert(value);
        }
        assert_eq!(hash_values.len(), keys.len());
    }

    /// Ids equal input positions, and ids map back to the same tokens.
    #[test]
    fn test_minimal_perfect_hash_vocab() {
        let tokens = owned(&["the", "quick", "brown", "fox", "jumps"]);
        let vocab = MinimalPerfectHashVocab::new(tokens.clone()).expect("Construction failed");
        assert_eq!(vocab.size(), 5);
        assert!(!vocab.is_empty());
        for (expected_id, token) in tokens.iter().enumerate() {
            match vocab.get_id(token) {
                Some(id) => assert_eq!(id, expected_id as u32),
                None => panic!("Token not found: {}", token),
            }
        }
        for (id, expected_token) in tokens.iter().enumerate() {
            match vocab.get_token(id as u32) {
                Some(token) => assert_eq!(token, expected_token),
                None => panic!("ID not found: {}", id),
            }
        }
    }

    /// The memory report has positive sizes and a load factor in (0, 1].
    #[test]
    fn test_memory_usage() {
        let vocab = MinimalPerfectHashVocab::new(owned(&["token1", "token2", "token3"]))
            .expect("Construction failed");
        let usage = vocab.memory_usage();
        assert!(usage.total_bytes > 0);
        assert!(usage.table_bytes > 0);
        assert!(usage.load_factor > 0.0);
        assert!(usage.load_factor <= 1.0);
    }

    /// The comparison yields positive sizes and non-negative savings.
    #[test]
    fn test_efficiency_comparison() {
        let tokens = owned(&[
            "efficiency",
            "comparison",
            "test",
            "minimal",
            "perfect",
            "hash",
        ]);
        let vocab = MinimalPerfectHashVocab::new(tokens).expect("Construction failed");
        let comparison = vocab.efficiency_comparison();
        assert!(comparison.mph_bytes > 0);
        assert!(comparison.hashmap_bytes > 0);
        assert!(comparison.compression_ratio > 0.0);
        assert!(comparison.space_savings >= 0.0);
    }

    /// A twelve-token vocabulary at load factor 0.5 round-trips sampled ids.
    #[test]
    fn test_large_vocabulary() {
        let tokens = owned(&[
            "the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog", "hello", "world",
            "rust", "code",
        ]);
        let config = MinimalPerfectHashConfig {
            load_factor: 0.5,
            num_hash_functions: 3,
            use_double_hashing: true,
        };
        let vocab = MinimalPerfectHashVocab::with_config(tokens.clone(), config)
            .expect("Operation failed in test");
        assert_eq!(vocab.size(), tokens.len());
        for i in [0, 3, 7, tokens.len() - 1] {
            let token = &tokens[i];
            assert!(vocab.contains(token));
            assert_eq!(vocab.get_id(token), Some(i as u32));
            assert_eq!(vocab.get_token(i as u32), Some(token.as_str()));
        }
    }

    /// A custom configuration without double hashing still builds and
    /// respects the requested load factor.
    #[test]
    fn test_custom_config() {
        let config = MinimalPerfectHashConfig {
            load_factor: 0.8,
            num_hash_functions: 2,
            use_double_hashing: false,
        };
        let vocab = MinimalPerfectHashVocab::with_config(owned(&["custom", "config", "test"]), config)
            .expect("Operation failed in test");
        assert_eq!(vocab.size(), 3);
        let usage = vocab.memory_usage();
        println!("Load factor: {}, expected <= 0.8", usage.load_factor);
        assert!(usage.load_factor <= 0.8);
    }
}