use crate::cache::dataset::CacheStats;
use crate::Dataset;
use std::collections::HashMap;
use std::fs::{create_dir_all, File};
use std::hash::Hash;
use std::io::{BufReader, BufWriter};
use std::marker::PhantomData;
use std::path::Path;
use std::sync::{Arc, Mutex};
use tenflowers_core::{Result, Tensor, TensorError};
/// Disk-backed LRU cache: each value lives in its own file under `cache_dir`,
/// with a JSON index (`cache_index.json`) mapping keys to
/// `(filename, last-access counter)` for LRU ordering across restarts.
#[cfg(feature = "serialize")]
pub struct PersistentCache<K, V> {
    // Directory holding the per-entry files and `cache_index.json`.
    cache_dir: std::path::PathBuf,
    // Maximum number of entries kept before LRU eviction kicks in.
    capacity: usize,
    // key -> (on-disk filename, access counter); `access_counter` is the
    // monotonically increasing LRU clock shared by all entries.
    index: HashMap<K, (String, usize)>, access_counter: usize,
    // Values are only ever (de)serialized to disk, never held in memory.
    _phantom: PhantomData<V>,
}
// Gated to match the struct definition: without this attribute the crate
// fails to compile when the "serialize" feature is disabled.
#[cfg(feature = "serialize")]
impl<K, V> PersistentCache<K, V>
where
    K: Clone + Eq + Hash + std::fmt::Display + std::str::FromStr,
    V: Clone + serde::Serialize + for<'de> serde::Deserialize<'de>,
{
    /// Opens (or creates) a cache rooted at `cache_dir` that keeps at most
    /// `capacity` entries, restoring any index left by a previous run.
    ///
    /// Note: a `capacity` of 0 still admits one entry at a time, because
    /// eviction runs before insertion.
    ///
    /// # Errors
    /// Fails if the directory cannot be created or an existing
    /// `cache_index.json` cannot be read or parsed.
    pub fn new<P: AsRef<Path>>(cache_dir: P, capacity: usize) -> Result<Self> {
        let cache_dir = cache_dir.as_ref().to_path_buf();
        if !cache_dir.exists() {
            create_dir_all(&cache_dir).map_err(|e| {
                TensorError::invalid_argument(format!("Failed to create cache directory: {e}"))
            })?;
        }
        let mut cache = Self {
            cache_dir,
            capacity,
            index: HashMap::new(),
            access_counter: 0,
            _phantom: PhantomData,
        };
        cache.load_index()?;
        Ok(cache)
    }

    /// Rebuilds the in-memory index from `cache_index.json`, if present.
    /// Entries whose key no longer parses via `K::from_str` are dropped
    /// silently; the LRU clock resumes past the newest persisted access.
    fn load_index(&mut self) -> Result<()> {
        let index_path = self.cache_dir.join("cache_index.json");
        if !index_path.exists() {
            return Ok(());
        }
        let file = File::open(&index_path).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to open cache index: {e}"))
        })?;
        let reader = BufReader::new(file);
        let index_data: HashMap<String, (String, usize)> = serde_json::from_reader(reader)
            .map_err(|e| {
                TensorError::invalid_argument(format!("Failed to parse cache index: {e}"))
            })?;
        for (key_str, (filename, access_order)) in index_data {
            if let Ok(key) = key_str.parse::<K>() {
                self.index.insert(key, (filename, access_order));
                self.access_counter = self.access_counter.max(access_order);
            }
        }
        self.access_counter += 1;
        Ok(())
    }

    /// Serializes the current index to `cache_index.json`. Keys are stored
    /// via `Display`, mirroring `load_index`'s `FromStr` round trip.
    fn save_index(&self) -> Result<()> {
        let index_path = self.cache_dir.join("cache_index.json");
        let file = File::create(&index_path).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to create cache index: {e}"))
        })?;
        let writer = BufWriter::new(file);
        let index_data: HashMap<String, (String, usize)> = self
            .index
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect();
        serde_json::to_writer(writer, &index_data).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to save cache index: {e}"))
        })?;
        Ok(())
    }

    /// Looks up `key`, deserializing the value from disk on a hit and
    /// refreshing its LRU access time. If the index references a file that
    /// has vanished, the stale entry is dropped and `Ok(None)` is returned.
    ///
    /// # Errors
    /// Fails if the cache file exists but cannot be opened or decoded.
    pub fn get(&mut self, key: &K) -> Result<Option<V>> {
        if let Some((filename, access_time)) = self.index.get_mut(key) {
            self.access_counter += 1;
            *access_time = self.access_counter;
            let file_path = self.cache_dir.join(filename);
            if !file_path.exists() {
                self.index.remove(key);
                // Best effort: persist the removal so the dangling entry does
                // not reappear from the on-disk index after a restart.
                let _ = self.save_index();
                return Ok(None);
            }
            let file = File::open(&file_path).map_err(|e| {
                TensorError::invalid_argument(format!("Failed to open cache file: {e}"))
            })?;
            let reader = BufReader::new(file);
            let value: V =
                oxicode::serde::decode_from_std_read(reader, oxicode::config::standard())
                    .map_err(|e| {
                        TensorError::invalid_argument(format!(
                            "Failed to deserialize cached value: {e}"
                        ))
                    })?
                    .0;
            Ok(Some(value))
        } else {
            Ok(None)
        }
    }

    /// Inserts `key` → `value`, writing the value to its own file and
    /// persisting the updated index. Evicts the least-recently-used entry
    /// first when the cache is full and `key` is new; replacing an existing
    /// key also deletes its previous file.
    ///
    /// # Errors
    /// Fails if the value file or the index cannot be written.
    pub fn insert(&mut self, key: K, value: V) -> Result<()> {
        self.access_counter += 1;
        if self.index.len() >= self.capacity && !self.index.contains_key(&key) {
            self.evict_lru()?;
        }
        // The access counter makes filenames unique, so a replaced entry's
        // old file can be removed without clobbering the new one.
        let filename = format!("cache_{}_{}.bin", key, self.access_counter);
        let file_path = self.cache_dir.join(&filename);
        let file = File::create(&file_path).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to create cache file: {e}"))
        })?;
        let writer = BufWriter::new(file);
        oxicode::serde::encode_into_std_write(&value, writer, oxicode::config::standard())
            .map_err(|e| {
                TensorError::invalid_argument(format!("Failed to serialize value: {e}"))
            })?;
        if let Some((old_filename, _)) = self.index.insert(key, (filename, self.access_counter)) {
            let old_path = self.cache_dir.join(old_filename);
            let _ = std::fs::remove_file(old_path);
        }
        self.save_index()?;
        Ok(())
    }

    /// Removes the entry with the smallest access time (least recently used),
    /// deleting its backing file. The index is persisted by the caller.
    fn evict_lru(&mut self) -> Result<()> {
        if let Some((lru_key, (filename, _))) = self
            .index
            .iter()
            .min_by_key(|(_, (_, access_time))| *access_time)
            .map(|(k, v)| (k.clone(), v.clone()))
        {
            let file_path = self.cache_dir.join(&filename);
            let _ = std::fs::remove_file(file_path);
            self.index.remove(&lru_key);
        }
        Ok(())
    }

    /// Number of entries currently indexed.
    pub fn len(&self) -> usize {
        self.index.len()
    }

    /// True when no entries are cached.
    pub fn is_empty(&self) -> bool {
        self.index.is_empty()
    }

    /// Deletes every cached file, clears the index, resets the LRU clock,
    /// and persists the now-empty index.
    pub fn clear(&mut self) -> Result<()> {
        for (filename, _) in self.index.values() {
            let file_path = self.cache_dir.join(filename);
            let _ = std::fs::remove_file(file_path);
        }
        self.index.clear();
        self.access_counter = 0;
        self.save_index()?;
        Ok(())
    }

    /// Maximum number of entries before eviction.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Directory holding the cache files and index.
    pub fn cache_dir(&self) -> &Path {
        &self.cache_dir
    }
}
/// Persistent cache specialized for `(features, labels)` tensor pairs,
/// keyed by sample index; tensors are stored as raw byte blobs.
#[cfg(feature = "serialize")]
pub struct TensorPersistentCache {
    // Values are (serialized features, serialized labels) byte buffers.
    cache: PersistentCache<usize, (Vec<u8>, Vec<u8>)>,
}
// Gated to match the struct definition: without this attribute the crate
// fails to compile when the "serialize" feature is disabled.
#[cfg(feature = "serialize")]
impl TensorPersistentCache {
    /// Creates a tensor-pair cache rooted at `cache_dir` with the given
    /// entry capacity.
    pub fn new<P: AsRef<Path>>(cache_dir: P, capacity: usize) -> Result<Self> {
        Ok(Self {
            cache: PersistentCache::new(cache_dir, capacity)?,
        })
    }

    /// Fetches the cached `(features, labels)` pair for sample `index`, if
    /// present, reconstructing both tensors from their byte blobs.
    ///
    /// # Errors
    /// Fails if the underlying cache read fails or a blob is malformed.
    pub fn get<T>(&mut self, index: &usize) -> Result<Option<(Tensor<T>, Tensor<T>)>>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + Send
            + Sync
            + 'static
            + scirs2_core::num_traits::cast::NumCast,
    {
        if let Some((features_bytes, labels_bytes)) = self.cache.get(index)? {
            let features_tensor = Self::deserialize_tensor(&features_bytes)?;
            let labels_tensor = Self::deserialize_tensor(&labels_bytes)?;
            Ok(Some((features_tensor, labels_tensor)))
        } else {
            Ok(None)
        }
    }

    /// Serializes and stores the `(features, labels)` pair for sample `index`.
    ///
    /// # Errors
    /// Fails for tensors without CPU-resident data or on cache I/O errors.
    pub fn insert<T>(
        &mut self,
        index: usize,
        features: &Tensor<T>,
        labels: &Tensor<T>,
    ) -> Result<()>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + Send
            + Sync
            + 'static
            + scirs2_core::num_traits::cast::NumCast,
    {
        let features_bytes = Self::serialize_tensor(features)?;
        let labels_bytes = Self::serialize_tensor(labels)?;
        self.cache.insert(index, (features_bytes, labels_bytes))?;
        Ok(())
    }

    /// Removes all cached tensor pairs.
    pub fn clear(&mut self) -> Result<()> {
        self.cache.clear()
    }

    /// Encodes a tensor as: 1 byte element size, u32 LE rank, u32 LE dims,
    /// then the raw host-layout bytes of each element.
    ///
    /// NOTE(review): the leading byte records only `size_of::<T>()`, not the
    /// actual element type, so blobs are only portable between identical
    /// `T`s on the same architecture; dimensions above `u32::MAX` would be
    /// truncated — confirm whether that range is possible for callers.
    fn serialize_tensor<T>(tensor: &Tensor<T>) -> Result<Vec<u8>>
    where
        T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
    {
        let mut bytes = Vec::new();
        let type_id = std::mem::size_of::<T>() as u8;
        bytes.push(type_id);
        let shape = tensor.shape().dims();
        let shape_len = shape.len() as u32;
        bytes.extend_from_slice(&shape_len.to_le_bytes());
        for &dim in shape {
            bytes.extend_from_slice(&(dim as u32).to_le_bytes());
        }
        if let Some(data_slice) = tensor.as_slice() {
            // SAFETY: `data_slice` is a valid, initialized `[T]`, so viewing
            // its backing storage as `size_of_val(data_slice)` bytes is
            // in-bounds, and `u8` has no alignment requirement.
            #[allow(unsafe_code)]
            let element_data = unsafe {
                std::slice::from_raw_parts(
                    data_slice.as_ptr().cast::<u8>(),
                    std::mem::size_of_val(data_slice),
                )
            };
            bytes.extend_from_slice(element_data);
        } else {
            return Err(TensorError::invalid_argument(
                "Cannot serialize GPU tensors or tensors without CPU data".to_string(),
            ));
        }
        Ok(bytes)
    }

    /// Decodes a blob produced by `serialize_tensor` back into a tensor.
    ///
    /// Each element is reconstructed by copying its raw bytes back into a
    /// `T`, exactly mirroring how `serialize_tensor` wrote them. (The
    /// previous implementation reinterpreted 4-/8-byte elements as f32/f64
    /// and round-tripped through `NumCast`, which corrupted integer element
    /// types such as i32/i64; the raw copy is bit-identical for the
    /// float/u8/u16 cases that already worked.)
    ///
    /// # Errors
    /// Fails if the blob is too short for its declared shape/data or its
    /// recorded element size does not match `T`.
    fn deserialize_tensor<T>(bytes: &[u8]) -> Result<Tensor<T>>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + Send
            + Sync
            + 'static
            + scirs2_core::num_traits::cast::NumCast,
    {
        // Minimum header: 1 byte element size + 4 bytes rank.
        if bytes.len() < 5 {
            return Err(TensorError::invalid_argument(
                "Invalid tensor serialization: too few bytes".to_string(),
            ));
        }
        let element_size = std::mem::size_of::<T>();
        let mut offset = 0;
        // The header stores only the element size (written as `size as u8`),
        // not a real type tag; reject blobs written for a different size so
        // we never reinterpret mismatched data.
        let stored_size = bytes[offset];
        offset += 1;
        if stored_size != element_size as u8 {
            return Err(TensorError::invalid_argument(format!(
                "Invalid tensor serialization: element size mismatch (stored {stored_size}, expected {element_size})"
            )));
        }
        let shape_len = u32::from_le_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ]) as usize;
        offset += 4;
        if bytes.len() < offset + shape_len * 4 {
            return Err(TensorError::invalid_argument(
                "Invalid tensor serialization: insufficient bytes for shape".to_string(),
            ));
        }
        let mut shape = Vec::with_capacity(shape_len);
        for _ in 0..shape_len {
            let dim = u32::from_le_bytes([
                bytes[offset],
                bytes[offset + 1],
                bytes[offset + 2],
                bytes[offset + 3],
            ]) as usize;
            shape.push(dim);
            offset += 4;
        }
        let total_elements = shape.iter().product::<usize>();
        let expected_data_bytes = total_elements * element_size;
        if bytes.len() < offset + expected_data_bytes {
            return Err(TensorError::invalid_argument(
                "Invalid tensor serialization: insufficient bytes for data".to_string(),
            ));
        }
        let data_bytes = &bytes[offset..offset + expected_data_bytes];
        let mut data = Vec::with_capacity(total_elements);
        for i in 0..total_elements {
            let element_offset = i * element_size;
            let mut value = T::default();
            // SAFETY: `element_offset + element_size <= data_bytes.len()` by
            // the bounds check above, and `serialize_tensor` wrote exactly
            // `size_of::<T>()` raw bytes per element with the same layout,
            // so copying them back restores the original bit pattern.
            #[allow(unsafe_code)]
            unsafe {
                std::ptr::copy_nonoverlapping(
                    data_bytes.as_ptr().add(element_offset),
                    (&mut value as *mut T).cast::<u8>(),
                    element_size,
                );
            }
            data.push(value);
        }
        Tensor::from_vec(data, &shape)
    }
}
/// Dataset adapter that transparently caches samples of the wrapped dataset
/// on disk and tracks hit/miss statistics.
#[cfg(feature = "serialize")]
pub struct PersistentlyCachedDataset<T, D: Dataset<T>> {
    // The underlying dataset; consulted on every cache miss.
    dataset: D,
    // Arc<Mutex<..>> so the cache can be mutated from `&self` in `Dataset::get`.
    cache: Arc<Mutex<TensorPersistentCache>>,
    // Hit/miss/total counters, updated on every `get`.
    cache_stats: Arc<Mutex<CacheStats>>,
    // Ties the element type `T` to the wrapper without storing a `T`.
    _phantom: PhantomData<T>,
}
// Gated to match the struct definition: without this attribute the crate
// fails to compile when the "serialize" feature is disabled.
#[cfg(feature = "serialize")]
impl<T, D: Dataset<T>> PersistentlyCachedDataset<T, D>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::cast::NumCast,
{
    /// Wraps `dataset` with an on-disk cache rooted at `cache_dir` holding
    /// at most `cache_capacity` samples.
    pub fn new<P: AsRef<Path>>(dataset: D, cache_dir: P, cache_capacity: usize) -> Result<Self> {
        let cache = TensorPersistentCache::new(cache_dir, cache_capacity)?;
        Ok(Self {
            dataset,
            cache: Arc::new(Mutex::new(cache)),
            cache_stats: Arc::new(Mutex::new(CacheStats::default())),
            _phantom: PhantomData,
        })
    }

    /// Returns a snapshot of the hit/miss statistics.
    ///
    /// # Errors
    /// Fails if the stats mutex is poisoned.
    pub fn cache_stats(&self) -> Result<CacheStats> {
        match self.cache_stats.lock() {
            Ok(stats) => Ok(stats.clone()),
            Err(_) => Err(TensorError::CacheError {
                operation: "persistent_cache_stats".to_string(),
                details: "Persistent cache stats mutex poisoned".to_string(),
                recoverable: true,
                context: None,
            }),
        }
    }

    /// Empties the on-disk cache and resets the statistics.
    ///
    /// # Errors
    /// Fails if either mutex is poisoned or the cache cannot be cleared.
    pub fn clear_cache(&self) -> Result<()> {
        match self.cache.lock() {
            Ok(mut cache) => cache.clear()?,
            Err(_) => {
                return Err(TensorError::CacheError {
                    operation: "persistent_cache_clear".to_string(),
                    details: "Persistent cache mutex poisoned during clear".to_string(),
                    recoverable: false,
                    context: None,
                })
            }
        }
        match self.cache_stats.lock() {
            Ok(mut stats) => {
                *stats = CacheStats::default();
                Ok(())
            }
            Err(_) => Err(TensorError::CacheError {
                operation: "persistent_cache_clear_stats".to_string(),
                details: "Persistent cache stats mutex poisoned during clear".to_string(),
                recoverable: false,
                context: None,
            }),
        }
    }

    /// Consumes the wrapper, returning the underlying dataset.
    pub fn into_inner(self) -> D {
        self.dataset
    }

    /// Borrows the underlying dataset.
    pub fn inner(&self) -> &D {
        &self.dataset
    }
}
// Gated to match the struct definition: without this attribute the crate
// fails to compile when the "serialize" feature is disabled.
#[cfg(feature = "serialize")]
impl<T, D: Dataset<T>> Dataset<T> for PersistentlyCachedDataset<T, D>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::cast::NumCast,
{
    fn len(&self) -> usize {
        self.dataset.len()
    }

    /// Returns sample `index`, serving it from the persistent cache when
    /// possible. On a miss (or an unreadable cache entry) the sample is
    /// fetched from the wrapped dataset and cached best-effort; caching
    /// failures are logged, never fatal.
    ///
    /// # Errors
    /// Fails if the underlying dataset fails or a mutex is poisoned.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        match self.cache_stats.lock() {
            Ok(mut stats) => stats.total_requests += 1,
            Err(_) => {
                return Err(TensorError::CacheError {
                    operation: "persistent_cache_stats_update".to_string(),
                    details: "Persistent cache stats mutex poisoned during total requests update"
                        .to_string(),
                    recoverable: false,
                    context: None,
                })
            }
        }
        let cache_result = match self.cache.lock() {
            Ok(mut cache) => cache.get(&index),
            Err(_) => {
                return Err(TensorError::CacheError {
                    operation: "persistent_cache_get".to_string(),
                    details: "Persistent cache mutex poisoned during get operation".to_string(),
                    recoverable: false,
                    context: None,
                })
            }
        };
        match cache_result {
            Ok(Some(cached_sample)) => {
                match self.cache_stats.lock() {
                    Ok(mut stats) => stats.hits += 1,
                    Err(_) => {
                        return Err(TensorError::CacheError {
                            operation: "persistent_cache_hit_stats".to_string(),
                            details: "Persistent cache stats mutex poisoned during hit update"
                                .to_string(),
                            recoverable: false,
                            context: None,
                        })
                    }
                }
                return Ok(cached_sample);
            }
            Ok(None) => {}
            // A corrupt or unreadable entry is treated as a miss (the sample
            // is refetched below), but the error is surfaced instead of being
            // silently discarded.
            Err(e) => {
                eprintln!("Warning: Failed to read cached sample {index}: {e}");
            }
        }
        let sample = self.dataset.get(index)?;
        match self.cache.lock() {
            Ok(mut cache) => {
                if let Err(e) = cache.insert(index, &sample.0, &sample.1) {
                    eprintln!("Warning: Failed to cache sample {index}: {e}");
                }
            }
            Err(_) => {
                eprintln!("Warning: Cache mutex poisoned during insert for sample {index}");
            }
        }
        match self.cache_stats.lock() {
            Ok(mut stats) => stats.misses += 1,
            Err(_) => {
                return Err(TensorError::CacheError {
                    operation: "persistent_cache_miss_stats".to_string(),
                    details: "Persistent cache stats mutex poisoned during miss update".to_string(),
                    recoverable: false,
                    context: None,
                })
            }
        }
        Ok(sample)
    }
}