use crate::category::Category;
use crate::error::{Result, SanitizeError};
use crate::generator::ReplacementGenerator;
use compact_str::CompactString;
use dashmap::DashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use zeroize::Zeroize;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ForwardKey {
category: Category,
original: String,
}
/// Concurrent cache that maps each `(category, original)` pair to a generated
/// replacement, optionally bounded in size.
pub struct MappingStore {
    /// Produces the replacement for a pair the store has not seen before.
    generator: Arc<dyn ReplacementGenerator>,
    /// Maximum number of distinct mappings; `None` means unbounded.
    capacity_limit: Option<usize>,
    /// Cached replacements, keyed by category and original value.
    forward: DashMap<ForwardKey, CompactString>,
    /// Mapping count, bumped by `get_or_insert` and reset by `clear`.
    len: AtomicUsize,
}
impl MappingStore {
#[must_use]
pub fn new(generator: Arc<dyn ReplacementGenerator>, capacity_limit: Option<usize>) -> Self {
Self {
forward: DashMap::with_capacity(1024),
generator,
len: AtomicUsize::new(0),
capacity_limit,
}
}
pub fn get_or_insert(&self, category: &Category, original: &str) -> Result<CompactString> {
let key = ForwardKey {
category: category.clone(),
original: original.to_owned(),
};
if let Some(existing) = self.forward.get(&key) {
return Ok(existing.value().clone());
}
if let Some(limit) = self.capacity_limit {
loop {
let current = self.len.load(Ordering::Acquire);
if current >= limit {
if let Some(existing) = self.forward.get(&key) {
return Ok(existing.value().clone());
}
return Err(SanitizeError::CapacityExceeded { current, limit });
}
if self
.len
.compare_exchange_weak(
current,
current + 1,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
break;
}
}
let mut was_inserted = false;
let result = self
.forward
.entry(key)
.or_insert_with(|| {
was_inserted = true;
let val = self.generator.generate(category, original);
CompactString::new(val)
})
.value()
.clone();
if !was_inserted {
self.len.fetch_sub(1, Ordering::Release);
}
Ok(result)
} else {
let mut was_inserted = false;
let result = self
.forward
.entry(key)
.or_insert_with(|| {
was_inserted = true;
let val = self.generator.generate(category, original);
CompactString::new(val)
})
.value()
.clone();
if was_inserted {
self.len.fetch_add(1, Ordering::Release);
}
Ok(result)
}
}
#[must_use]
pub fn forward_lookup(&self, category: &Category, original: &str) -> Option<CompactString> {
let key = ForwardKey {
category: category.clone(),
original: original.to_owned(),
};
self.forward.get(&key).map(|r| r.value().clone())
}
#[must_use]
pub fn len(&self) -> usize {
self.len.load(Ordering::Relaxed)
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn clear(&mut self) {
let old_map = std::mem::take(&mut self.forward);
for (mut key, _value) in old_map {
key.original.zeroize();
}
self.len.store(0, Ordering::Release);
}
pub fn snapshot_keys(&self) -> std::collections::HashSet<(Category, String)> {
self.forward
.iter()
.map(|e| (e.key().category.clone(), e.key().original.clone()))
.collect()
}
pub fn iter(&self) -> impl Iterator<Item = (Category, CompactString, CompactString)> + '_ {
self.forward.iter().map(|entry| {
(
entry.key().category.clone(),
CompactString::new(&entry.key().original),
entry.value().clone(),
)
})
}
}
impl Drop for MappingStore {
    /// Moves the map out of the store and wipes every stored original value before the
    /// entries are freed.
    fn drop(&mut self) {
        std::mem::take(&mut self.forward)
            .into_iter()
            .for_each(|(mut key, _)| key.original.zeroize());
    }
}
/// Fails compilation unless the given type is both `Send` and `Sync`.
///
/// The check lives in a closure that is never called. Type-checking the call inside it is
/// enough to enforce the bounds, and it emits no runtime code.
macro_rules! static_assertions_send_sync {
    ($t:ty) => {
        const _: fn() = || {
            fn require_send_and_sync<T: Send + Sync>() {}
            require_send_and_sync::<$t>();
        };
    };
}
static_assertions_send_sync!(MappingStore);
// Unit tests for `MappingStore`: caching, lookups, capacity limits, clearing and
// concurrent access.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::{HmacGenerator, RandomGenerator};
    use std::sync::Arc;

    /// Store backed by a deterministic HMAC generator with a fixed key.
    fn hmac_store(limit: Option<usize>) -> MappingStore {
        let gen = Arc::new(HmacGenerator::new([42u8; 32]));
        MappingStore::new(gen, limit)
    }

    /// Unbounded store backed by a non-deterministic generator.
    fn random_store() -> MappingStore {
        let gen = Arc::new(RandomGenerator::new());
        MappingStore::new(gen, None)
    }

    #[test]
    fn insert_and_lookup() {
        let store = hmac_store(None);
        let s1 = store
            .get_or_insert(&Category::Email, "alice@corp.com")
            .unwrap();
        assert!(!s1.is_empty());
        assert!(s1.contains("@corp.com"), "domain must be preserved");
        assert_eq!(s1.len(), "alice@corp.com".len(), "length must be preserved");
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn same_input_same_output() {
        let store = hmac_store(None);
        let s1 = store
            .get_or_insert(&Category::Email, "alice@corp.com")
            .unwrap();
        let s2 = store
            .get_or_insert(&Category::Email, "alice@corp.com")
            .unwrap();
        assert_eq!(s1, s2, "repeated insert must return cached value");
        assert_eq!(store.len(), 1, "no duplicate entry");
    }

    #[test]
    fn different_inputs_different_outputs() {
        let store = hmac_store(None);
        let s1 = store
            .get_or_insert(&Category::Email, "alice@corp.com")
            .unwrap();
        let s2 = store
            .get_or_insert(&Category::Email, "bob@corp.com")
            .unwrap();
        assert_ne!(s1, s2);
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn different_categories_different_outputs() {
        let store = hmac_store(None);
        let s1 = store.get_or_insert(&Category::Email, "test").unwrap();
        let s2 = store.get_or_insert(&Category::Name, "test").unwrap();
        assert_ne!(s1, s2);
    }

    #[test]
    fn forward_lookup_works() {
        let store = hmac_store(None);
        let sanitized = store.get_or_insert(&Category::IpV4, "192.168.1.1").unwrap();
        let found = store.forward_lookup(&Category::IpV4, "192.168.1.1");
        assert_eq!(found, Some(sanitized));
    }

    #[test]
    fn forward_lookup_missing() {
        let store = hmac_store(None);
        assert!(store.forward_lookup(&Category::Email, "nope").is_none());
    }

    #[test]
    fn capacity_limit_enforced() {
        let store = hmac_store(Some(2));
        store.get_or_insert(&Category::Email, "a@a.com").unwrap();
        store.get_or_insert(&Category::Email, "b@b.com").unwrap();
        let result = store.get_or_insert(&Category::Email, "c@c.com");
        assert!(result.is_err());
        match result.unwrap_err() {
            SanitizeError::CapacityExceeded {
                current: 2,
                limit: 2,
            } => {}
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn capacity_limit_allows_duplicate() {
        let store = hmac_store(Some(1));
        store.get_or_insert(&Category::Email, "a@a.com").unwrap();
        let s2 = store.get_or_insert(&Category::Email, "a@a.com").unwrap();
        assert!(!s2.is_empty());
    }

    #[test]
    fn random_store_caches() {
        let store = random_store();
        let s1 = store
            .get_or_insert(&Category::Email, "alice@corp.com")
            .unwrap();
        let s2 = store
            .get_or_insert(&Category::Email, "alice@corp.com")
            .unwrap();
        assert_eq!(s1, s2, "random store must still cache the first result");
    }

    #[test]
    fn iter_yields_all_mappings() {
        let store = hmac_store(None);
        store.get_or_insert(&Category::Email, "a@a.com").unwrap();
        store.get_or_insert(&Category::IpV4, "1.2.3.4").unwrap();
        let collected: Vec<_> = store.iter().collect();
        assert_eq!(collected.len(), 2);
    }

    #[test]
    fn is_empty_tracks_inserts() {
        let store = hmac_store(None);
        assert!(store.is_empty());
        store.get_or_insert(&Category::Email, "a@a.com").unwrap();
        assert!(!store.is_empty());
    }

    #[test]
    fn clear_removes_mappings_and_frees_capacity() {
        let mut store = hmac_store(Some(1));
        store.get_or_insert(&Category::Email, "a@a.com").unwrap();
        store.clear();
        assert!(store.is_empty());
        assert!(store.forward_lookup(&Category::Email, "a@a.com").is_none());
        // The slot used before `clear` must be available again.
        store.get_or_insert(&Category::Email, "b@b.com").unwrap();
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn snapshot_keys_lists_every_original() {
        let store = hmac_store(None);
        store.get_or_insert(&Category::Email, "a@a.com").unwrap();
        store.get_or_insert(&Category::IpV4, "1.2.3.4").unwrap();
        let keys = store.snapshot_keys();
        assert_eq!(keys.len(), 2);
        assert!(keys.contains(&(Category::Email, "a@a.com".to_string())));
        assert!(keys.contains(&(Category::IpV4, "1.2.3.4".to_string())));
    }

    #[test]
    fn concurrent_inserts_no_panic() {
        use std::thread;
        let gen = Arc::new(HmacGenerator::new([99u8; 32]));
        let store = Arc::new(MappingStore::new(gen, None));
        let mut handles = vec![];
        for t in 0..8 {
            let store = Arc::clone(&store);
            handles.push(thread::spawn(move || {
                for i in 0..1000 {
                    let val = format!("thread{}-val{}", t, i);
                    store.get_or_insert(&Category::Email, &val).unwrap();
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(store.len(), 8000);
    }

    #[test]
    fn concurrent_inserts_respect_capacity() {
        use std::thread;
        let store = Arc::new(hmac_store(Some(100)));
        // 8 threads x 50 distinct keys = 400 attempts against a limit of 100.
        let handles: Vec<_> = (0..8)
            .map(|t| {
                let store = Arc::clone(&store);
                thread::spawn(move || {
                    (0..50)
                        .filter(|i| {
                            let val = format!("thread{}-val{}", t, i);
                            store.get_or_insert(&Category::Email, &val).is_ok()
                        })
                        .count()
                })
            })
            .collect();
        let stored: usize = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert_eq!(stored, 100, "exactly `limit` distinct inserts must succeed");
        assert_eq!(store.len(), 100);
    }

    #[test]
    fn concurrent_inserts_same_key_idempotent() {
        use std::thread;
        let gen = Arc::new(HmacGenerator::new([7u8; 32]));
        let store = Arc::new(MappingStore::new(gen, None));
        let mut handles = vec![];
        for _ in 0..8 {
            let store = Arc::clone(&store);
            handles.push(thread::spawn(move || {
                let mut results = Vec::new();
                for i in 0..100 {
                    let val = format!("shared-{}", i);
                    let r = store.get_or_insert(&Category::Email, &val).unwrap();
                    results.push((val, r));
                }
                results
            }));
        }
        let mut all_results: Vec<Vec<(String, CompactString)>> = vec![];
        for h in handles {
            all_results.push(h.join().unwrap());
        }
        assert_eq!(store.len(), 100);
        for i in 0..100 {
            let val = format!("shared-{}", i);
            let expected = store.forward_lookup(&Category::Email, &val).unwrap();
            for thread_results in &all_results {
                let (_, got) = &thread_results[i];
                assert_eq!(
                    got, &expected,
                    "all threads must see the same mapping for {}",
                    val
                );
            }
        }
    }
}