use crate::cache::lru::ThreadSafeLruCache;
use crate::Dataset;
use std::sync::{Arc, Mutex};
use tenflowers_core::{Result, Tensor, TensorError};
/// Wraps a [`Dataset`] with a thread-safe LRU cache so repeated `get` calls
/// for the same index can be served without hitting the underlying dataset.
pub struct CachedDataset<T, D: Dataset<T>> {
    // The wrapped dataset; consulted only on a cache miss.
    dataset: D,
    // LRU cache mapping a sample index to its pair of tensors
    // (presumably input/target — confirm against the `Dataset` contract).
    cache: ThreadSafeLruCache<usize, (Tensor<T>, Tensor<T>)>,
    // Hit/miss counters, behind a mutex so `&self` methods can update them.
    cache_stats: Arc<Mutex<CacheStats>>,
}
/// Counters describing how effective the sample cache has been so far.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    // Requests answered straight from the cache.
    pub hits: usize,
    // Requests that had to fall through to the underlying dataset.
    pub misses: usize,
    // Every request seen, hit or miss.
    pub total_requests: usize,
}

impl CacheStats {
    /// Fraction of all requests served from the cache.
    ///
    /// Returns `0.0` when no requests have been recorded yet, avoiding a
    /// division by zero.
    pub fn hit_ratio(&self) -> f64 {
        match self.total_requests {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }
}
impl<T, D: Dataset<T>> CachedDataset<T, D>
where
    T: Clone + Send + Sync + 'static,
{
    /// Wraps `dataset` with an LRU cache holding at most `cache_capacity`
    /// samples, starting with zeroed statistics.
    pub fn new(dataset: D, cache_capacity: usize) -> Self {
        Self {
            dataset,
            cache: ThreadSafeLruCache::new(cache_capacity),
            cache_stats: Arc::new(Mutex::new(CacheStats::default())),
        }
    }

    /// Returns a snapshot of the current cache statistics.
    ///
    /// # Errors
    /// Returns a `CacheError` if the statistics mutex has been poisoned.
    pub fn cache_stats(&self) -> Result<CacheStats> {
        let guard = self
            .cache_stats
            .lock()
            .map_err(|_| TensorError::CacheError {
                operation: "cache_stats".to_string(),
                details: "Cache stats mutex poisoned".to_string(),
                recoverable: true,
                context: None,
            })?;
        Ok(guard.clone())
    }

    /// Empties the cache and resets all statistics counters to zero.
    ///
    /// # Errors
    /// Propagates cache-clear failures, and returns a `CacheError` if the
    /// statistics mutex has been poisoned.
    pub fn clear_cache(&self) -> Result<()> {
        self.cache.clear()?;
        let mut guard = self
            .cache_stats
            .lock()
            .map_err(|_| TensorError::CacheError {
                operation: "clear_cache_stats".to_string(),
                details: "Cache stats mutex poisoned during clear".to_string(),
                recoverable: false,
                context: None,
            })?;
        *guard = CacheStats::default();
        Ok(())
    }

    /// Pre-populates the cache by fetching each of `indices` once,
    /// stopping at the first error.
    pub fn warm_cache(&self, indices: &[usize]) -> Result<()> {
        indices
            .iter()
            .try_for_each(|&index| self.get(index).map(|_| ()))
    }

    /// Consumes the wrapper, returning the underlying dataset.
    pub fn into_inner(self) -> D {
        self.dataset
    }

    /// Borrows the underlying dataset.
    pub fn inner(&self) -> &D {
        &self.dataset
    }
}
impl<T, D: Dataset<T>> Dataset<T> for CachedDataset<T, D>
where
T: Clone + Send + Sync + 'static,
{
fn len(&self) -> usize {
self.dataset.len()
}
fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
match self.cache_stats.lock() {
Ok(mut stats) => stats.total_requests += 1,
Err(_) => {
return Err(TensorError::CacheError {
operation: "cache_stats_update".to_string(),
details: "Cache stats mutex poisoned during total requests update".to_string(),
recoverable: false,
context: None,
})
}
}
if let Some(cached_sample) = self.cache.get(&index)? {
match self.cache_stats.lock() {
Ok(mut stats) => stats.hits += 1,
Err(_) => {
return Err(TensorError::CacheError {
operation: "cache_hit_stats".to_string(),
details: "Cache stats mutex poisoned during hit update".to_string(),
recoverable: false,
context: None,
})
}
}
return Ok(cached_sample);
}
let sample = self.dataset.get(index)?;
self.cache.insert(index, sample.clone())?;
match self.cache_stats.lock() {
Ok(mut stats) => stats.misses += 1,
Err(_) => {
return Err(TensorError::CacheError {
operation: "cache_miss_stats".to_string(),
details: "Cache stats mutex poisoned during miss update".to_string(),
recoverable: false,
context: None,
})
}
}
Ok(sample)
}
}
/// Strategies for choosing which sample indices to pre-load into the cache.
pub enum WarmingStrategy {
    /// Warm a contiguous run of `count` indices starting at `start`.
    Sequential { start: usize, count: usize },
    /// Warm `count` pseudo-randomly chosen distinct indices; a `seed` makes
    /// the selection reproducible.
    Random { count: usize, seed: Option<u64> },
    /// Warm exactly these indices (out-of-range entries are dropped).
    Specific(Vec<usize>),
}

impl WarmingStrategy {
    /// Produces the concrete list of indices to warm, clamped or filtered so
    /// every returned index is `< dataset_len`.
    pub fn generate_indices(&self, dataset_len: usize) -> Vec<usize> {
        match self {
            WarmingStrategy::Sequential { start, count } => {
                // Clamp the end of the range so we never emit an index past
                // the dataset.
                let stop = (*start + *count).min(dataset_len);
                (*start..stop).collect()
            }
            WarmingStrategy::Random { count, seed } => {
                use std::collections::HashSet;
                // Fall back to the wall clock when no seed is supplied; the
                // hard-coded duration keeps this total even if the clock is
                // before the epoch.
                let mut state = seed.unwrap_or_else(|| {
                    std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap_or_else(|_| std::time::Duration::from_secs(12345))
                        .as_secs()
                });
                // We can never produce more distinct indices than the dataset
                // holds, so cap the goal up front (also skips the loop — and
                // the `% 0` — entirely for an empty dataset).
                let target = (*count).min(dataset_len);
                let mut picked = Vec::with_capacity(target);
                let mut seen = HashSet::with_capacity(target);
                while picked.len() < target {
                    // One LCG step (glibc `rand` constants); deterministic
                    // for a given seed.
                    state = state.wrapping_mul(1103515245).wrapping_add(12345);
                    let candidate = state as usize % dataset_len;
                    if seen.insert(candidate) {
                        picked.push(candidate);
                    }
                }
                picked
            }
            WarmingStrategy::Specific(indices) => indices
                .iter()
                .copied()
                .filter(|&idx| idx < dataset_len)
                .collect(),
        }
    }
}