use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
/// A single tokenization result that can be stored in the cache.
#[derive(Debug, Clone)]
pub struct CachedToken {
    /// Surface form: the token text as it appeared in the input.
    pub surface: String,
    /// Part-of-speech tag (tests use "NNG"/"IC" — presumably a Korean
    /// POS tagset; confirm against the tokenizer that produces these).
    pub pos: String,
    /// Byte offset where the token starts in the original text.
    pub start_byte: usize,
    /// Exclusive end byte offset (per the tests: 0..9 for the 9-byte
    /// UTF-8 string "테스트").
    pub end_byte: usize,
}
/// Hash-derived key indexing cached token lists (see `TokenCache::make_key`).
pub type CacheKey = u64;
/// Tunable parameters for a token cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries held before LRU eviction kicks in.
    pub max_entries: usize,
    /// Texts longer than this many bytes bypass the cache entirely
    /// (see `get_or_insert_with_text`).
    pub max_key_length: usize,
    /// When false, hit/miss/eviction counters are not updated.
    pub track_stats: bool,
}
impl Default for CacheConfig {
    /// Delegates to [`CacheConfig::new`] so the default values
    /// (10_000 entries, 1024-byte key limit, stats enabled) are
    /// defined in exactly one place instead of being duplicated here.
    fn default() -> Self {
        Self::new()
    }
}
impl CacheConfig {
    /// Returns the built-in configuration: 10_000 entries, a 1024-byte
    /// key-length limit, and statistics tracking enabled.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            max_entries: 10_000,
            max_key_length: 1024,
            track_stats: true,
        }
    }

    /// Builder-style setter for the maximum number of cached entries.
    #[must_use]
    pub const fn with_max_entries(mut self, limit: usize) -> Self {
        self.max_entries = limit;
        self
    }

    /// Builder-style setter for the longest text (in bytes) that will
    /// still be cached.
    #[must_use]
    pub const fn with_max_key_length(mut self, limit: usize) -> Self {
        self.max_key_length = limit;
        self
    }

    /// Builder-style toggle for hit/miss/eviction counting.
    #[must_use]
    pub const fn with_track_stats(mut self, enabled: bool) -> Self {
        self.track_stats = enabled;
        self
    }
}
/// Thread-safe hit/miss/eviction counters. All updates use relaxed
/// atomics: counts are best-effort telemetry, not synchronization points.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Number of successful lookups.
    hits: AtomicU64,
    /// Number of failed lookups.
    misses: AtomicU64,
    /// Number of entries removed by LRU eviction.
    evictions: AtomicU64,
}
impl CacheStats {
    /// Total number of cache hits recorded so far.
    #[must_use]
    pub fn hits(&self) -> u64 {
        self.hits.load(Ordering::Relaxed)
    }

    /// Total number of cache misses recorded so far.
    #[must_use]
    pub fn misses(&self) -> u64 {
        self.misses.load(Ordering::Relaxed)
    }

    /// Hits plus misses.
    #[must_use]
    pub fn total_requests(&self) -> u64 {
        self.hits() + self.misses()
    }

    /// Fraction of lookups that hit, in `[0.0, 1.0]`.
    /// Returns `0.0` before any lookup has been recorded.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn hit_rate(&self) -> f64 {
        match self.total_requests() {
            0 => 0.0,
            total => self.hits() as f64 / total as f64,
        }
    }

    /// Total number of LRU evictions recorded so far.
    #[must_use]
    pub fn evictions(&self) -> u64 {
        self.evictions.load(Ordering::Relaxed)
    }

    // Internal recorders, called by the cache only when stats tracking
    // is enabled in its configuration.
    fn record_hit(&self) {
        self.hits.fetch_add(1, Ordering::Relaxed);
    }

    fn record_miss(&self) {
        self.misses.fetch_add(1, Ordering::Relaxed);
    }

    fn record_eviction(&self) {
        self.evictions.fetch_add(1, Ordering::Relaxed);
    }

    /// Zeroes every counter.
    pub fn reset(&self) {
        for counter in [&self.hits, &self.misses, &self.evictions] {
            counter.store(0, Ordering::Relaxed);
        }
    }
}
/// Internal cache slot: the cached tokens plus an LRU timestamp.
struct CacheEntry {
    tokens: Vec<CachedToken>,
    /// Logical-clock value taken from `TokenCache::access_counter`;
    /// the entry with the smallest value is the next eviction candidate.
    last_access: u64,
}
/// A bounded, thread-safe LRU cache mapping text hashes to token lists.
pub struct TokenCache {
    /// Capacity and behavior knobs.
    config: CacheConfig,
    /// Keyed entries behind a single lock. Note: even lookups take the
    /// write lock, because a hit mutates the entry's LRU timestamp.
    entries: RwLock<HashMap<CacheKey, CacheEntry>>,
    /// Hit/miss/eviction counters (updated only when `track_stats` is on).
    stats: CacheStats,
    /// Monotonic logical clock used to stamp entry accesses.
    access_counter: AtomicU64,
}
impl TokenCache {
    /// Creates a cache with the supplied configuration.
    #[must_use]
    pub fn new(config: CacheConfig) -> Self {
        Self {
            config,
            entries: RwLock::new(HashMap::new()),
            stats: CacheStats::default(),
            access_counter: AtomicU64::new(0),
        }
    }

    /// Creates a cache with [`CacheConfig::default`].
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new(CacheConfig::default())
    }

    /// Hashes `text` into a [`CacheKey`].
    ///
    /// NOTE: `DefaultHasher` output is only stable within a single
    /// process run, so keys must not be persisted or shared across runs.
    #[must_use]
    pub fn make_key(&self, text: &str) -> CacheKey {
        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns a clone of the tokens cached under `key`, if present.
    ///
    /// A hit refreshes the entry's LRU timestamp, which is why this takes
    /// the write lock even on the read path. Returns `None` on a miss or
    /// when the lock is poisoned.
    #[must_use]
    pub fn get(&self, key: CacheKey) -> Option<Vec<CachedToken>> {
        let mut entries = self.entries.write().ok()?;
        if let Some(entry) = entries.get_mut(&key) {
            entry.last_access = self.access_counter.fetch_add(1, Ordering::Relaxed);
            if self.config.track_stats {
                self.stats.record_hit();
            }
            Some(entry.tokens.clone())
        } else {
            if self.config.track_stats {
                self.stats.record_miss();
            }
            None
        }
    }

    /// Stores `tokens` under `key`, evicting least-recently-used entries
    /// as needed to stay within `max_entries`.
    ///
    /// Silently does nothing when the lock is poisoned or when the cache
    /// is configured with zero capacity.
    pub fn insert(&self, key: CacheKey, tokens: Vec<CachedToken>) {
        // BUGFIX: with `max_entries == 0` the previous eviction loop spun
        // forever (`len() >= 0` is always true while `evict_lru` no-ops on
        // an empty map). A zero-capacity cache simply never stores.
        if self.config.max_entries == 0 {
            return;
        }
        let Ok(mut entries) = self.entries.write() else {
            return;
        };
        // BUGFIX: replacing an existing key does not grow the map, so it
        // must not force the eviction of an unrelated entry.
        if !entries.contains_key(&key) {
            while entries.len() >= self.config.max_entries {
                self.evict_lru(&mut entries);
            }
        }
        let access = self.access_counter.fetch_add(1, Ordering::Relaxed);
        entries.insert(
            key,
            CacheEntry {
                tokens,
                last_access: access,
            },
        );
    }

    /// Returns the tokens for `key`, computing and caching them on a miss.
    ///
    /// `compute` runs outside the lock, so two concurrent callers may both
    /// compute the same value; the later insert simply overwrites.
    pub fn get_or_insert<F>(&self, key: CacheKey, compute: F) -> Vec<CachedToken>
    where
        F: FnOnce() -> Vec<CachedToken>,
    {
        if let Some(tokens) = self.get(key) {
            return tokens;
        }
        let tokens = compute();
        self.insert(key, tokens.clone());
        tokens
    }

    /// Like [`Self::get_or_insert`], but derives the key from `text` and
    /// bypasses the cache entirely for texts longer than `max_key_length`
    /// bytes (long inputs are unlikely to repeat and would bloat the map).
    pub fn get_or_insert_with_text<F>(&self, text: &str, compute: F) -> Vec<CachedToken>
    where
        F: FnOnce() -> Vec<CachedToken>,
    {
        if text.len() > self.config.max_key_length {
            return compute();
        }
        let key = self.make_key(text);
        self.get_or_insert(key, compute)
    }

    /// Removes the entry with the smallest `last_access` stamp.
    /// O(n) scan over the map — acceptable for the modest default capacity.
    fn evict_lru(&self, entries: &mut HashMap<CacheKey, CacheEntry>) {
        if entries.is_empty() {
            return;
        }
        let oldest_key = entries
            .iter()
            .min_by_key(|(_, entry)| entry.last_access)
            .map(|(key, _)| *key);
        if let Some(key) = oldest_key {
            entries.remove(&key);
            if self.config.track_stats {
                self.stats.record_eviction();
            }
        }
    }

    /// Drops every cached entry. Statistics are left untouched.
    pub fn clear(&self) {
        if let Ok(mut entries) = self.entries.write() {
            entries.clear();
        }
    }

    /// Number of entries currently cached (0 if the lock is poisoned).
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.read().map_or(0, |e| e.len())
    }

    /// `true` when the cache holds no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrows the hit/miss/eviction counters.
    #[must_use]
    pub const fn stats(&self) -> &CacheStats {
        &self.stats
    }

    /// Borrows the active configuration.
    #[must_use]
    pub const fn config(&self) -> &CacheConfig {
        &self.config
    }
}
impl Default for TokenCache {
fn default() -> Self {
Self::with_defaults()
}
}
/// Pairs an arbitrary tokenizer `T` with a [`TokenCache`].
/// (No tokenizer trait is visible in this file — presumably callers use
/// `inner()`/`cache()` to wire tokenization through the cache; confirm
/// against the rest of the crate.)
pub struct CachingTokenizer<T> {
    inner: T,
    cache: TokenCache,
}
impl<T> CachingTokenizer<T> {
    /// Wraps `inner` with a token cache built from `config`.
    pub fn new(inner: T, config: CacheConfig) -> Self {
        let cache = TokenCache::new(config);
        Self { inner, cache }
    }

    /// Wraps `inner` with a default-configured cache.
    pub fn with_defaults(inner: T) -> Self {
        Self::new(inner, CacheConfig::default())
    }

    /// Borrows the wrapped tokenizer.
    #[must_use]
    pub const fn inner(&self) -> &T {
        &self.inner
    }

    /// Mutably borrows the wrapped tokenizer.
    pub fn inner_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    /// Borrows the underlying token cache.
    #[must_use]
    pub const fn cache(&self) -> &TokenCache {
        &self.cache
    }

    /// Convenience accessor for the cache's statistics.
    #[must_use]
    pub const fn stats(&self) -> &CacheStats {
        self.cache.stats()
    }

    /// Empties the cache; statistics are preserved.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Default config carries the documented built-in values.
    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.max_entries, 10_000);
        assert_eq!(config.max_key_length, 1024);
        assert!(config.track_stats);
    }

    // Builder methods override each field independently.
    #[test]
    fn test_cache_config_builder() {
        let config = CacheConfig::new()
            .with_max_entries(1000)
            .with_max_key_length(512)
            .with_track_stats(false);
        assert_eq!(config.max_entries, 1000);
        assert_eq!(config.max_key_length, 512);
        assert!(!config.track_stats);
    }

    #[test]
    fn test_cache_basic_operations() {
        let cache = TokenCache::with_defaults();
        let key = cache.make_key("테스트");
        // First lookup misses and is counted as such.
        assert!(cache.get(key).is_none());
        assert_eq!(cache.stats().misses(), 1);
        let tokens = vec![CachedToken {
            surface: "테스트".to_string(),
            pos: "NNG".to_string(),
            start_byte: 0,
            // "테스트" is 9 bytes in UTF-8 (3 bytes per Hangul syllable).
            end_byte: 9,
        }];
        cache.insert(key, tokens);
        // After insertion the same key hits and returns the stored tokens.
        let cached = cache.get(key).unwrap();
        assert_eq!(cached.len(), 1);
        assert_eq!(cached[0].surface, "테스트");
        assert_eq!(cache.stats().hits(), 1);
    }

    #[test]
    fn test_cache_get_or_insert() {
        let cache = TokenCache::with_defaults();
        let key = cache.make_key("안녕");
        let mut call_count = 0;
        // Miss: the closure runs once and its result is cached.
        let tokens1 = cache.get_or_insert(key, || {
            call_count += 1;
            vec![CachedToken {
                surface: "안녕".to_string(),
                pos: "IC".to_string(),
                start_byte: 0,
                end_byte: 6,
            }]
        });
        assert_eq!(call_count, 1);
        assert_eq!(tokens1.len(), 1);
        // Hit: the second closure must NOT run; cached value is returned.
        let tokens2 = cache.get_or_insert(key, || {
            call_count += 1;
            vec![]
        });
        assert_eq!(call_count, 1);
        assert_eq!(tokens2.len(), 1);
    }

    #[test]
    fn test_cache_lru_eviction() {
        let config = CacheConfig::new().with_max_entries(3);
        let cache = TokenCache::new(config);
        // Fill to capacity with text0, text1, text2 (in that access order).
        for i in 0..3 {
            let key = cache.make_key(&format!("text{i}"));
            cache.insert(key, vec![]);
        }
        assert_eq!(cache.len(), 3);
        // Touch text0 so text1 becomes the least-recently-used entry.
        let key0 = cache.make_key("text0");
        let _ = cache.get(key0);
        // Inserting a fourth entry must evict exactly one entry: text1.
        let key3 = cache.make_key("text3");
        cache.insert(key3, vec![]);
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.stats().evictions(), 1);
        assert!(cache.get(key0).is_some());
        let key1 = cache.make_key("text1");
        assert!(cache.get(key1).is_none());
    }

    #[test]
    fn test_cache_stats() {
        let cache = TokenCache::with_defaults();
        let key = cache.make_key("test");
        // One miss, zero hits -> hit rate 0.0.
        let _ = cache.get(key);
        assert_eq!(cache.stats().misses(), 1);
        assert_eq!(cache.stats().hits(), 0);
        assert!((cache.stats().hit_rate() - 0.0).abs() < f64::EPSILON);
        // One hit after insert -> 1 hit / 2 requests = 0.5.
        cache.insert(key, vec![]);
        let _ = cache.get(key);
        assert_eq!(cache.stats().hits(), 1);
        assert!((cache.stats().hit_rate() - 0.5).abs() < f64::EPSILON);
        // Reset zeroes every counter.
        cache.stats().reset();
        assert_eq!(cache.stats().total_requests(), 0);
    }

    #[test]
    fn test_cache_clear() {
        let cache = TokenCache::with_defaults();
        for i in 0..10 {
            let key = cache.make_key(&format!("text{i}"));
            cache.insert(key, vec![]);
        }
        assert_eq!(cache.len(), 10);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_skip_long_text() {
        let config = CacheConfig::new().with_max_key_length(10);
        let cache = TokenCache::new(config);
        let mut call_count = 0;
        // "짧은" is 6 bytes — under the 10-byte limit, so it is cached.
        let short = "짧은";
        cache.get_or_insert_with_text(short, || {
            call_count += 1;
            vec![]
        });
        cache.get_or_insert_with_text(short, || {
            call_count += 1;
            vec![]
        });
        assert_eq!(call_count, 1);
        // This text exceeds max_key_length, so the cache is bypassed and
        // the closure runs on every call.
        let long = "이것은 아주 긴 텍스트입니다";
        cache.get_or_insert_with_text(long, || {
            call_count += 1;
            vec![]
        });
        cache.get_or_insert_with_text(long, || {
            call_count += 1;
            vec![]
        });
        assert_eq!(call_count, 3);
    }

    #[test]
    fn test_caching_tokenizer() {
        // The wrapper is generic, so any type works as the inner tokenizer.
        struct DummyTokenizer;
        let caching = CachingTokenizer::with_defaults(DummyTokenizer);
        assert!(caching.cache().is_empty());
        assert_eq!(caching.stats().total_requests(), 0);
        let key = caching.cache().make_key("test");
        caching.cache().insert(key, vec![]);
        assert_eq!(caching.cache().len(), 1);
        caching.clear_cache();
        assert!(caching.cache().is_empty());
    }
}