use std::cell::RefCell;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::error::{ZiporaError, Result};
use crate::fsa::version_sync::{ConcurrencyLevel, VersionManager, ReaderToken, WriterToken};
/// Per-thread cache holding at most one reader and one writer token, so hot
/// paths can reuse recently returned tokens without re-acquiring them from
/// the shared `VersionManager`.
#[derive(Debug)]
pub struct TokenCache {
// Most recently returned reader token, if any is cached.
cached_reader: Option<ReaderToken>,
// Most recently returned writer token, if any is cached.
cached_writer: Option<WriterToken>,
// Hit/miss/invalidation counters for this cache instance.
stats: TokenCacheStats,
}
impl TokenCache {
pub fn new() -> Self {
Self {
cached_reader: None,
cached_writer: None,
stats: TokenCacheStats::default(),
}
}
pub fn get_reader_token(&mut self) -> Option<ReaderToken> {
if let Some(token) = self.cached_reader.take() {
if token.is_valid() {
self.stats.reader_cache_hits += 1;
return Some(token);
} else {
self.stats.reader_cache_invalidations += 1;
}
}
self.stats.reader_cache_misses += 1;
None
}
pub fn get_writer_token(&mut self) -> Option<WriterToken> {
if let Some(token) = self.cached_writer.take() {
if token.is_valid() {
self.stats.writer_cache_hits += 1;
return Some(token);
} else {
self.stats.writer_cache_invalidations += 1;
}
}
self.stats.writer_cache_misses += 1;
None
}
pub fn cache_reader_token(&mut self, token: ReaderToken) {
self.cached_reader = Some(token);
self.stats.reader_tokens_cached += 1;
}
pub fn cache_writer_token(&mut self, token: WriterToken) {
self.cached_writer = Some(token);
self.stats.writer_tokens_cached += 1;
}
pub fn clear(&mut self) {
self.cached_reader = None;
self.cached_writer = None;
self.stats.cache_clears += 1;
}
pub fn stats(&self) -> &TokenCacheStats {
&self.stats
}
pub fn clear_stats(&mut self) {
self.stats = TokenCacheStats::default();
}
}
impl Default for TokenCache {
fn default() -> Self {
Self::new()
}
}
/// Counters describing how effective a single thread's `TokenCache` has been.
#[derive(Debug, Default, Clone)]
pub struct TokenCacheStats {
    /// Reader lookups satisfied from the cache.
    pub reader_cache_hits: u64,
    /// Reader lookups that found no usable cached token.
    pub reader_cache_misses: u64,
    /// Cached reader tokens discarded because they were no longer valid.
    pub reader_cache_invalidations: u64,
    /// Reader tokens placed into the cache.
    pub reader_tokens_cached: u64,
    /// Writer lookups satisfied from the cache.
    pub writer_cache_hits: u64,
    /// Writer lookups that found no usable cached token.
    pub writer_cache_misses: u64,
    /// Cached writer tokens discarded because they were no longer valid.
    pub writer_cache_invalidations: u64,
    /// Writer tokens placed into the cache.
    pub writer_tokens_cached: u64,
    /// Number of explicit cache clears.
    pub cache_clears: u64,
}

impl TokenCacheStats {
    /// Divides `part` by `whole`, treating an empty denominator as 0.0.
    fn fraction(part: u64, whole: u64) -> f64 {
        if whole == 0 {
            0.0
        } else {
            part as f64 / whole as f64
        }
    }

    /// Fraction of reader lookups served from the cache.
    pub fn reader_hit_rate(&self) -> f64 {
        Self::fraction(
            self.reader_cache_hits,
            self.reader_cache_hits + self.reader_cache_misses,
        )
    }

    /// Fraction of writer lookups served from the cache.
    pub fn writer_hit_rate(&self) -> f64 {
        Self::fraction(
            self.writer_cache_hits,
            self.writer_cache_hits + self.writer_cache_misses,
        )
    }

    /// Fraction of all lookups (reader and writer) served from the cache.
    pub fn overall_hit_rate(&self) -> f64 {
        let hits = self.reader_cache_hits + self.writer_cache_hits;
        let requests = hits + self.reader_cache_misses + self.writer_cache_misses;
        Self::fraction(hits, requests)
    }

    /// Fraction of cached reader tokens that turned out to be stale when
    /// looked up.
    pub fn reader_invalidation_rate(&self) -> f64 {
        Self::fraction(
            self.reader_cache_invalidations,
            self.reader_cache_hits + self.reader_cache_invalidations,
        )
    }

    /// Fraction of cached writer tokens that turned out to be stale when
    /// looked up.
    pub fn writer_invalidation_rate(&self) -> f64 {
        Self::fraction(
            self.writer_cache_invalidations,
            self.writer_cache_hits + self.writer_cache_invalidations,
        )
    }
}
thread_local! {
// One TokenCache per thread: the fast path needs no synchronization, at the
// cost of each thread warming its own cache independently.
static TOKEN_CACHE: RefCell<TokenCache> = RefCell::new(TokenCache::new());
}
/// Facade over a shared `VersionManager` that adds per-thread token caching
/// plus process-wide acquisition statistics.
#[derive(Debug)]
pub struct TokenManager {
// Source of truth for acquiring new reader/writer tokens.
version_manager: Arc<VersionManager>,
// Counters aggregated across all threads; guarded by a standard mutex.
global_stats: Arc<std::sync::Mutex<GlobalTokenStats>>,
}
impl TokenManager {
pub fn new(concurrency_level: ConcurrencyLevel) -> Self {
Self {
version_manager: Arc::new(VersionManager::new(concurrency_level)),
global_stats: Arc::new(std::sync::Mutex::new(GlobalTokenStats::default())),
}
}
pub fn with_version_manager(version_manager: Arc<VersionManager>) -> Self {
Self {
version_manager,
global_stats: Arc::new(std::sync::Mutex::new(GlobalTokenStats::default())),
}
}
pub fn version_manager(&self) -> &Arc<VersionManager> {
&self.version_manager
}
pub fn acquire_reader_token(&self) -> Result<ReaderToken> {
let start_time = Instant::now();
let token = TOKEN_CACHE.with(|cache| {
cache.borrow_mut().get_reader_token()
});
let token = if let Some(cached_token) = token {
if let Ok(mut stats) = self.global_stats.lock() {
stats.total_reader_cache_hits += 1;
}
cached_token
} else {
let new_token = self.version_manager.acquire_reader_token()?;
if let Ok(mut stats) = self.global_stats.lock() {
stats.total_reader_cache_misses += 1;
stats.total_reader_acquisition_time += start_time.elapsed();
}
new_token
};
Ok(token)
}
pub fn acquire_writer_token(&self) -> Result<WriterToken> {
let start_time = Instant::now();
let token = TOKEN_CACHE.with(|cache| {
cache.borrow_mut().get_writer_token()
});
let token = if let Some(cached_token) = token {
if let Ok(mut stats) = self.global_stats.lock() {
stats.total_writer_cache_hits += 1;
}
cached_token
} else {
let new_token = self.version_manager.acquire_writer_token()?;
if let Ok(mut stats) = self.global_stats.lock() {
stats.total_writer_cache_misses += 1;
stats.total_writer_acquisition_time += start_time.elapsed();
}
new_token
};
Ok(token)
}
pub fn return_reader_token(&self, token: ReaderToken) {
TOKEN_CACHE.with(|cache| {
cache.borrow_mut().cache_reader_token(token);
});
if let Ok(mut stats) = self.global_stats.lock() {
stats.total_reader_tokens_returned += 1;
}
}
pub fn return_writer_token(&self, token: WriterToken) {
TOKEN_CACHE.with(|cache| {
cache.borrow_mut().cache_writer_token(token);
});
if let Ok(mut stats) = self.global_stats.lock() {
stats.total_writer_tokens_returned += 1;
}
}
pub fn clear_thread_cache(&self) {
TOKEN_CACHE.with(|cache| {
cache.borrow_mut().clear();
});
}
pub fn concurrency_level(&self) -> ConcurrencyLevel {
self.version_manager.concurrency_level()
}
pub fn global_stats(&self) -> Result<GlobalTokenStats> {
self.global_stats
.lock()
.map(|stats| stats.clone())
.map_err(|_| ZiporaError::system_error("Failed to acquire global stats mutex"))
}
pub fn thread_cache_stats(&self) -> TokenCacheStats {
TOKEN_CACHE.with(|cache| {
cache.borrow().stats().clone()
})
}
pub fn clear_all_stats(&self) -> Result<()> {
self.global_stats
.lock()
.map(|mut stats| *stats = GlobalTokenStats::default())
.map_err(|_| ZiporaError::system_error("Failed to acquire global stats mutex"))?;
TOKEN_CACHE.with(|cache| {
cache.borrow_mut().clear_stats();
});
self.version_manager.clear_stats()
}
}
/// Process-wide token statistics aggregated across every thread that used a
/// `TokenManager`.
#[derive(Debug, Default, Clone)]
pub struct GlobalTokenStats {
    /// Reader acquisitions served from a thread-local cache.
    pub total_reader_cache_hits: u64,
    /// Reader acquisitions that went to the version manager.
    pub total_reader_cache_misses: u64,
    /// Reader tokens handed back for caching.
    pub total_reader_tokens_returned: u64,
    /// Cumulative wall-clock time spent on reader cache misses.
    pub total_reader_acquisition_time: Duration,
    /// Writer acquisitions served from a thread-local cache.
    pub total_writer_cache_hits: u64,
    /// Writer acquisitions that went to the version manager.
    pub total_writer_cache_misses: u64,
    /// Writer tokens handed back for caching.
    pub total_writer_tokens_returned: u64,
    /// Cumulative wall-clock time spent on writer cache misses.
    pub total_writer_acquisition_time: Duration,
}

impl GlobalTokenStats {
    /// Divides `part` by `whole`, treating an empty denominator as 0.0.
    fn fraction(part: u64, whole: u64) -> f64 {
        if whole == 0 {
            0.0
        } else {
            part as f64 / whole as f64
        }
    }

    /// Fraction of reader acquisitions served from a cache.
    pub fn reader_hit_rate(&self) -> f64 {
        Self::fraction(
            self.total_reader_cache_hits,
            self.total_reader_cache_hits + self.total_reader_cache_misses,
        )
    }

    /// Fraction of writer acquisitions served from a cache.
    pub fn writer_hit_rate(&self) -> f64 {
        Self::fraction(
            self.total_writer_cache_hits,
            self.total_writer_cache_hits + self.total_writer_cache_misses,
        )
    }

    /// Fraction of all acquisitions (reader and writer) served from a cache.
    pub fn overall_hit_rate(&self) -> f64 {
        let hits = self.total_reader_cache_hits + self.total_writer_cache_hits;
        let requests = hits + self.total_reader_cache_misses + self.total_writer_cache_misses;
        Self::fraction(hits, requests)
    }

    /// Mean acquisition time per reader cache miss; zero if there were none.
    pub fn avg_reader_acquisition_time(&self) -> Duration {
        match self.total_reader_cache_misses {
            0 => Duration::ZERO,
            // NOTE(review): Duration division only accepts u32, so a miss
            // count above u32::MAX would truncate here — acceptable for stats.
            n => self.total_reader_acquisition_time / n as u32,
        }
    }

    /// Mean acquisition time per writer cache miss; zero if there were none.
    pub fn avg_writer_acquisition_time(&self) -> Duration {
        match self.total_writer_cache_misses {
            0 => Duration::ZERO,
            n => self.total_writer_acquisition_time / n as u32,
        }
    }
}
/// Runs `f` with a freshly acquired reader token and returns the token to the
/// thread-local cache afterwards.
///
/// # Errors
///
/// Returns an error if token acquisition fails, or propagates the error
/// produced by `f`.
pub fn with_reader_token<F, R>(token_manager: &TokenManager, f: F) -> Result<R>
where
    F: FnOnce(&ReaderToken) -> Result<R>,
{
    let token = token_manager.acquire_reader_token()?;
    // Capture the closure's result before returning the token so the token is
    // recycled into the cache even when `f` fails. (Previously a `?` here
    // dropped the token on error, needlessly lowering the cache hit rate.)
    let result = f(&token);
    token_manager.return_reader_token(token);
    result
}
/// Runs `f` with a freshly acquired writer token and returns the token to the
/// thread-local cache afterwards.
///
/// # Errors
///
/// Returns an error if token acquisition fails, or propagates the error
/// produced by `f`.
pub fn with_writer_token<F, R>(token_manager: &TokenManager, f: F) -> Result<R>
where
    F: FnOnce(&WriterToken) -> Result<R>,
{
    let token = token_manager.acquire_writer_token()?;
    // Capture the closure's result before returning the token so the token is
    // recycled into the cache even when `f` fails. (Previously a `?` here
    // dropped the token on error, needlessly lowering the cache hit rate.)
    let result = f(&token);
    token_manager.return_writer_token(token);
    result
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
// Round-trip through a bare TokenCache: tokens come back exactly once,
// after which the cache is empty again and the hit counters reflect it.
#[test]
fn test_token_cache_basic() {
let mut cache = TokenCache::new();
assert!(cache.get_reader_token().is_none());
assert!(cache.get_writer_token().is_none());
let manager = TokenManager::new(ConcurrencyLevel::OneWriteMultiRead);
let reader_token = manager.acquire_reader_token().unwrap();
let writer_token = manager.acquire_writer_token().unwrap();
cache.cache_reader_token(reader_token);
cache.cache_writer_token(writer_token);
let cached_reader = cache.get_reader_token();
let cached_writer = cache.get_writer_token();
assert!(cached_reader.is_some());
assert!(cached_writer.is_some());
// get_* takes the token out, so a second lookup must miss.
assert!(cache.get_reader_token().is_none());
assert!(cache.get_writer_token().is_none());
let stats = cache.stats();
assert_eq!(stats.reader_cache_hits, 1);
assert_eq!(stats.writer_cache_hits, 1);
}
// Acquire/return through the manager and verify global return counters.
#[test]
fn test_token_manager_basic() -> Result<()> {
let manager = TokenManager::new(ConcurrencyLevel::OneWriteMultiRead);
let reader_token = manager.acquire_reader_token()?;
assert!(reader_token.is_valid());
manager.return_reader_token(reader_token);
let writer_token = manager.acquire_writer_token()?;
assert!(writer_token.is_valid());
manager.return_writer_token(writer_token);
let global_stats = manager.global_stats()?;
assert_eq!(global_stats.total_reader_tokens_returned, 1);
assert_eq!(global_stats.total_writer_tokens_returned, 1);
Ok(())
}
// First acquisition is a miss; after return_reader_token the second one
// must be served from the thread-local cache (a hit).
#[test]
fn test_token_caching_performance() -> Result<()> {
let manager = TokenManager::new(ConcurrencyLevel::MultiWriteMultiRead);
let reader1 = manager.acquire_reader_token()?;
manager.return_reader_token(reader1);
let reader2 = manager.acquire_reader_token()?;
manager.return_reader_token(reader2);
let global_stats = manager.global_stats()?;
assert_eq!(global_stats.total_reader_cache_misses, 1);
assert_eq!(global_stats.total_reader_cache_hits, 1);
assert_eq!(global_stats.reader_hit_rate(), 0.5);
Ok(())
}
// The with_reader_token helper yields a valid token and returns it.
#[test]
fn test_with_reader_token_convenience() -> Result<()> {
let manager = TokenManager::new(ConcurrencyLevel::OneWriteMultiRead);
let result = with_reader_token(&manager, |token| {
assert!(token.is_valid());
Ok(42)
})?;
assert_eq!(result, 42);
let global_stats = manager.global_stats()?;
assert_eq!(global_stats.total_reader_tokens_returned, 1);
Ok(())
}
// The with_writer_token helper yields a valid token and returns it.
#[test]
fn test_with_writer_token_convenience() -> Result<()> {
let manager = TokenManager::new(ConcurrencyLevel::OneWriteMultiRead);
let result = with_writer_token(&manager, |token| {
assert!(token.is_valid());
Ok(84)
})?;
assert_eq!(result, 84);
let global_stats = manager.global_stats()?;
assert_eq!(global_stats.total_writer_tokens_returned, 1);
Ok(())
}
// Multiple threads acquire/return concurrently; after warm-up each thread's
// cache should serve most acquisitions, so the overall hit rate exceeds 0.5.
#[test]
fn test_concurrent_token_caching() -> Result<()> {
let manager = Arc::new(TokenManager::new(ConcurrencyLevel::MultiWriteMultiRead));
let num_threads = 4;
let operations_per_thread = 20;
let handles: Vec<_> = (0..num_threads)
.map(|_| {
let manager_clone = Arc::clone(&manager);
thread::spawn(move || -> Result<()> {
for _ in 0..operations_per_thread {
let reader = manager_clone.acquire_reader_token()?;
manager_clone.return_reader_token(reader);
let writer = manager_clone.acquire_writer_token()?;
manager_clone.return_writer_token(writer);
thread::sleep(Duration::from_millis(1));
}
Ok(())
})
})
.collect();
for handle in handles {
handle.join().unwrap()?;
}
let global_stats = manager.global_stats()?;
assert_eq!(
global_stats.total_reader_tokens_returned,
num_threads * operations_per_thread
);
assert_eq!(
global_stats.total_writer_tokens_returned,
num_threads * operations_per_thread
);
assert!(global_stats.overall_hit_rate() > 0.5);
Ok(())
}
// Each spawned thread has its own TOKEN_CACHE, so both threads see exactly
// one local miss even though they share one TokenManager.
#[test]
fn test_thread_cache_isolation() -> Result<()> {
let manager = Arc::new(TokenManager::new(ConcurrencyLevel::MultiWriteMultiRead));
let handles: Vec<_> = (0..2)
.map(|thread_id| {
let manager_clone = Arc::clone(&manager);
thread::spawn(move || -> Result<()> {
let reader = manager_clone.acquire_reader_token()?;
manager_clone.return_reader_token(reader);
let thread_stats = manager_clone.thread_cache_stats();
assert_eq!(thread_stats.reader_cache_misses, 1);
assert_eq!(thread_stats.reader_tokens_cached, 1);
Ok(())
})
})
.collect();
for handle in handles {
handle.join().unwrap()?;
}
// Globally the two per-thread misses and returns add up.
let global_stats = manager.global_stats()?;
assert_eq!(global_stats.total_reader_cache_misses, 2); assert_eq!(global_stats.total_reader_tokens_returned, 2);
Ok(())
}
// Five round-trips per token kind: the first is a miss, the remaining four
// are hits, so hit rates come out to exactly 0.8.
#[test]
fn test_statistics_accuracy() -> Result<()> {
let manager = TokenManager::new(ConcurrencyLevel::OneWriteMultiRead);
for _ in 0..5 {
with_reader_token(&manager, |_| Ok(()))?;
with_writer_token(&manager, |_| Ok(()))?;
}
let global_stats = manager.global_stats()?;
let thread_stats = manager.thread_cache_stats();
assert_eq!(global_stats.total_reader_cache_misses, 1); assert_eq!(global_stats.total_reader_cache_hits, 4); assert_eq!(global_stats.total_writer_cache_misses, 1); assert_eq!(global_stats.total_writer_cache_hits, 4);
assert_eq!(thread_stats.reader_cache_misses, 1);
assert_eq!(thread_stats.reader_cache_hits, 4);
assert_eq!(thread_stats.writer_cache_misses, 1);
assert_eq!(thread_stats.writer_cache_hits, 4);
assert_eq!(global_stats.reader_hit_rate(), 0.8); assert_eq!(global_stats.writer_hit_rate(), 0.8); assert_eq!(global_stats.overall_hit_rate(), 0.8);
Ok(())
}
// Clearing the thread cache forces the next acquisition of each kind to
// miss again, and increments the cache_clears counter.
#[test]
fn test_cache_clearing() -> Result<()> {
let manager = TokenManager::new(ConcurrencyLevel::OneWriteMultiRead);
with_reader_token(&manager, |_| Ok(()))?;
with_writer_token(&manager, |_| Ok(()))?;
manager.clear_thread_cache();
with_reader_token(&manager, |_| Ok(()))?;
with_writer_token(&manager, |_| Ok(()))?;
let thread_stats = manager.thread_cache_stats();
assert_eq!(thread_stats.reader_cache_misses, 2);
assert_eq!(thread_stats.writer_cache_misses, 2);
assert_eq!(thread_stats.cache_clears, 1);
Ok(())
}
}