use crate::error::{Result, ZiporaError};
use crate::memory::{SecureMemoryPool, SecurePooledPtr, get_global_pool_for_size};
use ahash::AHasher;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::mem;
use std::sync::Arc;
/// Hash index that keeps keys inline in an open-addressing probe table and stores each
/// value type-erased in a chunk obtained from a [`SecureMemoryPool`].
pub struct GoldHashIdx<K: Hash + Eq + Clone, V: Clone> {
    /// Probe table; `None` marks an empty slot. Its length always equals `capacity`.
    table: Vec<Option<Bucket<K>>>,
    /// Value slots addressed by `Bucket::value_index`; `None` marks a freed slot.
    values: Vec<Option<SecurePooledPtr>>,
    /// Indices of freed entries in `values`, reused before `values` is grown.
    free_values: Vec<usize>,
    /// Number of live key/value entries.
    len: usize,
    /// Number of slots in `table` (a power of two, at least 16, when built by the constructors).
    capacity: usize,
    /// Fill ratio at which `insert` grows the table (constructors set 0.75).
    load_factor: f64,
    /// Custom value pool supplied via `with_pool`; `None` selects the global pool by value size.
    pool: Option<Arc<SecureMemoryPool>>,
    /// Approximate bytes attributed to keys, as reported by `memory_usage`.
    key_memory: usize,
    /// Approximate bytes attributed to values, as reported by `memory_usage`.
    value_memory: usize,
    /// Values are stored type-erased in `values`; this records that the index logically owns `V`s.
    _phantom: PhantomData<V>,
}
/// One occupied slot of the probe table.
#[derive(Clone, Debug)]
struct Bucket<K> {
    /// The stored key, compared on lookup after the cached hash matches.
    key: K,
    /// Position of this entry's value inside the index's value-slot vector.
    value_index: usize,
    /// Cached hash of `key`, used to pick the home slot and to skip most key comparisons.
    hash: u64,
}
impl<K, V> GoldHashIdx<K, V>
where
K: Hash + Eq + Clone,
V: Clone,
{
pub fn new() -> Self {
Self::with_capacity(16)
}
pub fn with_capacity(capacity: usize) -> Self {
let capacity = capacity.next_power_of_two().max(16);
Self {
table: vec![None; capacity],
values: Vec::new(),
free_values: Vec::new(),
len: 0,
capacity,
load_factor: 0.75,
pool: None,
key_memory: 0,
value_memory: 0,
_phantom: PhantomData,
}
}
pub fn with_pool(capacity: usize, pool: Arc<SecureMemoryPool>) -> Self {
let capacity = capacity.next_power_of_two().max(16);
Self {
table: vec![None; capacity],
values: Vec::new(),
free_values: Vec::new(),
len: 0,
capacity,
load_factor: 0.75,
pool: Some(pool),
key_memory: 0,
value_memory: 0,
_phantom: PhantomData,
}
}
pub fn insert(&mut self, key: K, value: V) -> Result<Option<V>> {
if self.len >= (self.capacity as f64 * self.load_factor) as usize {
self.resize()?;
}
let hash = self.hash_key(&key);
let mut index = (hash % self.capacity as u64) as usize;
loop {
if self.table[index].is_none() {
let value_index = self.allocate_value(value)?;
self.table[index] = Some(Bucket {
key: key.clone(),
value_index,
hash,
});
self.len += 1;
self.key_memory += mem::size_of::<K>() + mem::size_of::<Bucket<K>>();
return Ok(None);
} else if let Some(bucket) = &self.table[index] {
if bucket.hash == hash && bucket.key == key {
let value_index = bucket.value_index;
let old_value = self.get_value(value_index)?.clone();
self.store_value(value_index, value)?;
return Ok(Some(old_value));
}
index = (index + 1) % self.capacity;
}
}
}
pub fn get(&self, key: &K) -> Option<&V> {
let hash = self.hash_key(key);
let mut index = (hash % self.capacity as u64) as usize;
for _ in 0..self.capacity {
match &self.table[index] {
None => return None,
Some(bucket) => {
if bucket.hash == hash && bucket.key == *key {
return self.get_value(bucket.value_index).ok();
}
index = (index + 1) % self.capacity;
}
}
}
None
}
pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
let hash = self.hash_key(key);
let mut index = (hash % self.capacity as u64) as usize;
for _ in 0..self.capacity {
match &self.table[index] {
None => return None,
Some(bucket) => {
if bucket.hash == hash && bucket.key == *key {
let value_index = bucket.value_index;
return self.get_value_mut(value_index).ok();
}
index = (index + 1) % self.capacity;
}
}
}
None
}
pub fn remove(&mut self, key: &K) -> Option<V> {
let hash = self.hash_key(key);
let mut index = (hash % self.capacity as u64) as usize;
for _ in 0..self.capacity {
match &self.table[index] {
None => return None,
Some(bucket) => {
if bucket.hash == hash && bucket.key == *key {
let value_index = bucket.value_index;
let value = self.get_value(value_index).ok()?.clone();
self.table[index] = None;
self.free_value(value_index);
self.len -= 1;
self.key_memory = self
.key_memory
.saturating_sub(mem::size_of::<K>() + mem::size_of::<Bucket<K>>());
self.rehash_after_removal(index);
return Some(value);
}
index = (index + 1) % self.capacity;
}
}
}
None
}
pub fn contains_key(&self, key: &K) -> bool {
self.get(key).is_some()
}
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn insert_batch(&mut self, items: Vec<(K, V)>) -> Result<()> {
let target_capacity = (self.len + items.len()) * 2;
if target_capacity > self.capacity {
let new_capacity = target_capacity.next_power_of_two();
self.resize_to(new_capacity)?;
}
for (key, value) in items {
self.insert(key, value)?;
}
Ok(())
}
pub fn get_batch(&self, keys: &[K]) -> Vec<Option<&V>> {
keys.iter().map(|key| self.get(key)).collect()
}
pub fn shrink_to_fit(&mut self) {
let min_capacity = ((self.len as f64 / self.load_factor) as usize)
.next_power_of_two()
.max(16);
if min_capacity < self.capacity {
if let Err(_) = self.resize_to(min_capacity) {
}
}
self.values.shrink_to_fit();
self.free_values.shrink_to_fit();
}
#[inline]
pub fn memory_usage(&self) -> (usize, usize) {
(self.key_memory, self.value_memory)
}
fn hash_key(&self, key: &K) -> u64 {
let mut hasher = AHasher::default();
key.hash(&mut hasher);
hasher.finish()
}
fn allocate_value(&mut self, value: V) -> Result<usize> {
if let Some(index) = self.free_values.pop() {
self.store_value(index, value)?;
return Ok(index);
}
let value_ptr = self.allocate_pooled_value(value)?;
let index = self.values.len();
self.values.push(Some(value_ptr));
self.value_memory += mem::size_of::<V>();
Ok(index)
}
fn store_value(&mut self, index: usize, value: V) -> Result<()> {
if index >= self.values.len() {
return Err(ZiporaError::invalid_data("Invalid value index"));
}
let value_ptr = self.allocate_pooled_value(value)?;
self.values[index] = Some(value_ptr);
Ok(())
}
fn get_value(&self, index: usize) -> Result<&V> {
self.values
.get(index)
.and_then(|opt| opt.as_ref())
.map(|ptr| unsafe { &*(ptr.as_ptr() as *const V) })
.ok_or_else(|| ZiporaError::invalid_data("Invalid value index"))
}
fn get_value_mut(&mut self, index: usize) -> Result<&mut V> {
self.values
.get_mut(index)
.and_then(|opt| opt.as_ref())
.map(|ptr| unsafe { &mut *(ptr.as_ptr() as *mut V) })
.ok_or_else(|| ZiporaError::invalid_data("Invalid value index"))
}
fn free_value(&mut self, index: usize) {
if index < self.values.len() {
self.values[index] = None;
self.free_values.push(index);
self.value_memory = self.value_memory.saturating_sub(mem::size_of::<V>());
}
}
fn allocate_pooled_value(&mut self, value: V) -> Result<SecurePooledPtr> {
match &self.pool {
Some(pool) => {
let ptr = pool.allocate()?;
unsafe {
std::ptr::write(ptr.as_ptr() as *mut V, value);
}
Ok(ptr)
}
None => {
let pool = get_global_pool_for_size(mem::size_of::<V>());
let ptr = pool.allocate()?;
unsafe {
std::ptr::write(ptr.as_ptr() as *mut V, value);
}
Ok(ptr)
}
}
}
fn resize(&mut self) -> Result<()> {
self.resize_to(self.capacity * 2)
}
fn resize_to(&mut self, new_capacity: usize) -> Result<()> {
let old_table = mem::replace(&mut self.table, vec![None; new_capacity]);
let _old_capacity = self.capacity;
self.capacity = new_capacity;
self.len = 0;
self.key_memory = 0;
for bucket in old_table.into_iter().flatten() {
let value = self.get_value(bucket.value_index)?.clone();
self.insert(bucket.key, value)?;
}
Ok(())
}
fn rehash_after_removal(&mut self, start_index: usize) {
let mut index = (start_index + 1) % self.capacity;
while let Some(bucket) = self.table[index].take() {
let ideal_index = (bucket.hash % self.capacity as u64) as usize;
let mut new_index = ideal_index;
while self.table[new_index].is_some() {
new_index = (new_index + 1) % self.capacity;
}
self.table[new_index] = Some(bucket);
index = (index + 1) % self.capacity;
}
}
}
impl<K, V> Default for GoldHashIdx<K, V>
where
K: Hash + Eq + Clone,
V: Clone,
{
fn default() -> Self {
Self::new()
}
}
impl<K, V> std::fmt::Debug for GoldHashIdx<K, V>
where
    K: Hash + Eq + Clone + std::fmt::Debug,
    V: Clone + std::fmt::Debug,
{
    /// Prints sizing and bookkeeping statistics only; keys and values are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let free_slots = self.free_values.len();
        let pooled = self.pool.is_some();

        let mut summary = f.debug_struct("GoldHashIdx");
        summary.field("len", &self.len);
        summary.field("capacity", &self.capacity);
        summary.field("load_factor", &self.load_factor);
        summary.field("key_memory", &self.key_memory);
        summary.field("value_memory", &self.value_memory);
        summary.field("free_values_count", &free_slots);
        summary.field("has_pool", &pooled);
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::SecurePoolConfig;

    /// Test value with a 256-byte payload plus an identifier.
    #[derive(Clone, Debug, PartialEq)]
    struct LargeValue {
        data: [u8; 256],
        id: u32,
    }

    impl LargeValue {
        /// Builds a value whose payload bytes all equal `id` truncated to a byte.
        fn new(id: u32) -> Self {
            let fill = id as u8;
            LargeValue {
                data: [fill; 256],
                id,
            }
        }
    }

    /// Shorthand for an owned `String` key.
    fn key(text: &str) -> String {
        String::from(text)
    }

    #[test]
    fn test_basic_operations() -> Result<()> {
        let mut idx = GoldHashIdx::new();
        let (val1, val2) = (LargeValue::new(1), LargeValue::new(2));

        // Fresh keys report no previous value.
        assert_eq!(idx.insert(key("key1"), val1.clone())?, None);
        assert_eq!(idx.insert(key("key2"), val2.clone())?, None);
        assert_eq!(idx.len(), 2);

        // Hits and misses.
        assert_eq!(idx.get(&key("key1")), Some(&val1));
        assert_eq!(idx.get(&key("key2")), Some(&val2));
        assert_eq!(idx.get(&key("key3")), None);
        assert!(idx.contains_key(&key("key1")));
        assert!(!idx.contains_key(&key("key3")));

        // Overwriting hands back the replaced value.
        let replacement = LargeValue::new(11);
        let previous = idx.insert(key("key1"), replacement.clone())?;
        assert_eq!(previous, Some(val1));
        assert_eq!(idx.get(&key("key1")), Some(&replacement));

        // Removal returns the value and drops the entry.
        assert_eq!(idx.remove(&key("key1")), Some(replacement));
        assert_eq!(idx.get(&key("key1")), None);
        assert_eq!(idx.len(), 1);
        Ok(())
    }

    #[test]
    fn test_mutable_access() -> Result<()> {
        let mut idx = GoldHashIdx::new();
        idx.insert(key("key"), LargeValue::new(42))?;

        idx.get_mut(&key("key")).expect("key was just inserted").id = 99;

        assert_eq!(idx.get(&key("key")).map(|v| v.id), Some(99));
        Ok(())
    }

    #[test]
    fn test_batch_operations() -> Result<()> {
        let mut idx = GoldHashIdx::new();
        let keys: Vec<String> = (0..100).map(|i| format!("key{}", i)).collect();
        let items: Vec<(String, LargeValue)> = keys
            .iter()
            .cloned()
            .zip((0u32..).map(LargeValue::new))
            .collect();

        idx.insert_batch(items)?;
        assert_eq!(idx.len(), 100);

        for (expected_id, found) in (0u32..).zip(idx.get_batch(&keys)) {
            let found = found.expect("batch-inserted key must be present");
            assert_eq!(found.id, expected_id);
        }
        Ok(())
    }

    #[test]
    fn test_with_custom_pool() -> Result<()> {
        let pool = SecureMemoryPool::new(SecurePoolConfig::small_secure())?;
        let mut idx = GoldHashIdx::with_pool(32, pool);
        let expected = LargeValue::new(42);

        idx.insert(key("test"), expected.clone())?;

        assert_eq!(idx.get(&key("test")), Some(&expected));
        Ok(())
    }

    #[test]
    fn test_memory_usage_tracking() -> Result<()> {
        let mut idx = GoldHashIdx::new();
        assert_eq!(idx.memory_usage(), (0, 0));

        for n in 0u32..10 {
            idx.insert(format!("key{}", n), LargeValue::new(n))?;
        }

        let (key_mem, value_mem) = idx.memory_usage();
        assert!(key_mem > 0);
        assert!(value_mem > 0);
        println!(
            "Memory usage - Keys: {} bytes, Values: {} bytes",
            key_mem, value_mem
        );
        Ok(())
    }

    #[test]
    fn test_shrink_to_fit() -> Result<()> {
        let mut idx = GoldHashIdx::with_capacity(1024);
        let names: Vec<String> = (0..10).map(|i| format!("key{}", i)).collect();
        for (n, name) in (0u32..).zip(&names) {
            idx.insert(name.clone(), LargeValue::new(n))?;
        }

        let before = idx.capacity;
        idx.shrink_to_fit();

        assert!(idx.capacity <= before);
        assert_eq!(idx.len(), 10);
        assert!(names.iter().all(|name| idx.contains_key(name)));
        Ok(())
    }

    #[test]
    fn test_large_scale() -> Result<()> {
        let mut idx = GoldHashIdx::new();
        for n in 0u32..1000 {
            idx.insert(n.to_string(), LargeValue::new(n))?;
        }
        assert_eq!(idx.len(), 1000);
        for n in 0u32..1000 {
            assert_eq!(idx.get(&n.to_string()).map(|v| v.id), Some(n));
        }

        for n in 0u32..500 {
            assert!(idx.remove(&n.to_string()).is_some());
        }
        assert_eq!(idx.len(), 500);
        for n in 500u32..1000 {
            assert_eq!(idx.get(&n.to_string()).map(|v| v.id), Some(n));
        }
        Ok(())
    }

    #[test]
    fn test_empty_operations() {
        let idx: GoldHashIdx<String, LargeValue> = GoldHashIdx::new();
        let missing = key("key");

        assert_eq!(idx.len(), 0);
        assert!(idx.is_empty());
        assert_eq!(idx.get(&missing), None);
        assert!(!idx.contains_key(&missing));
        assert_eq!(idx.memory_usage(), (0, 0));
    }
}