use std::collections::HashMap;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use torsh_core::{
dtype::TensorElement,
error::{Result, TorshError},
shape::Shape,
};
/// Tuning knobs for lazy chunk loading and caching.
#[derive(Debug, Clone)]
pub struct LazyLoadConfig {
    /// Number of tensor *elements* per chunk (not bytes).
    pub chunk_size: usize,
    /// Soft cap on the number of chunks kept in the cache at once.
    pub max_cached_chunks: usize,
    /// Chunks idle longer than this are evicted during cache cleanup.
    pub cache_ttl: Duration,
    /// Estimated cache memory (bytes) above which memory-pressure handling
    /// aggressively evicts chunks not accessed recently.
    pub memory_pressure_threshold: usize,
}
impl Default for LazyLoadConfig {
    /// Sensible defaults: 1 MiB-sized chunks (by element count), up to 16
    /// cached chunks, a 5-minute cache TTL, and a 1 GiB pressure threshold.
    fn default() -> Self {
        let chunk_size = 1024 * 1024;
        Self {
            chunk_size,
            max_cached_chunks: 16,
            cache_ttl: Duration::from_secs(300),
            memory_pressure_threshold: 1024 * chunk_size,
        }
    }
}
/// Describes where and how tensor data is laid out inside a backing file.
#[derive(Debug, Clone)]
pub struct LazyTensorMetadata {
    /// Logical shape of the tensor.
    pub shape: Shape,
    /// Element type name as stored in the file header (e.g. "f32").
    pub dtype: String,
    /// Total number of elements — presumably the product of the shape dims;
    /// callers are responsible for keeping the two consistent (not checked).
    pub total_elements: usize,
    /// Size of one element in bytes. NOTE(review): the chunk reader
    /// reinterprets raw bytes as `T`, so this is assumed to equal
    /// `size_of::<T>()` of the `LazyTensor<T>` reading the data — confirm.
    pub element_size: usize,
    /// Byte offset of the first element within the file.
    pub data_offset: u64,
}
/// A chunk of decoded tensor data held in the in-memory cache.
#[derive(Debug, Clone)]
struct CachedChunk<T: TensorElement> {
    /// Decoded elements belonging to this chunk.
    data: Vec<T>,
    /// Half-open element range `[start, end)` this chunk covers in the tensor.
    range: (usize, usize),
    /// Timestamp used for TTL- and LRU-style eviction decisions.
    last_accessed: Instant,
}
/// A tensor whose data stays on disk and is loaded chunk-by-chunk on demand.
pub struct LazyTensor<T: TensorElement> {
    /// Layout of the on-disk data (shape, dtype, offset).
    metadata: LazyTensorMetadata,
    /// Shared handle to the backing file; the mutex serializes the
    /// seek+read pairs performed when a chunk is loaded.
    file: Arc<Mutex<File>>,
    /// Path the file was opened from; currently kept for diagnostics only.
    #[allow(dead_code)]
    file_path: PathBuf,
    /// Chunk index -> cached chunk, guarded for concurrent readers.
    chunk_cache: Arc<RwLock<HashMap<usize, CachedChunk<T>>>>,
    /// Caching/loading configuration.
    config: LazyLoadConfig,
    /// NOTE(review): redundant — `T` already appears in `chunk_cache`, so
    /// this marker (and its initializer in `new`) could be removed.
    _phantom: std::marker::PhantomData<T>,
}
impl<T: TensorElement> LazyTensor<T> {
    /// Opens `file_path` and prepares a lazily-loaded tensor view over it.
    ///
    /// No tensor data is read here; chunks are loaded on demand by
    /// [`Self::get_element`] / [`Self::get_range`] and cached per `config`.
    ///
    /// # Errors
    /// Returns `TorshError::IoError` if the file cannot be opened.
    pub fn new<P: AsRef<Path>>(
        file_path: P,
        metadata: LazyTensorMetadata,
        config: LazyLoadConfig,
    ) -> Result<Self> {
        let file_path = file_path.as_ref().to_path_buf();
        let file = File::open(&file_path)
            .map_err(|e| TorshError::IoError(format!("Failed to open file: {}", e)))?;
        Ok(Self {
            metadata,
            file: Arc::new(Mutex::new(file)),
            file_path,
            chunk_cache: Arc::new(RwLock::new(HashMap::new())),
            config,
            _phantom: std::marker::PhantomData,
        })
    }

    /// Logical shape of the tensor.
    pub fn shape(&self) -> &Shape {
        &self.metadata.shape
    }

    /// Total number of elements in the tensor.
    pub fn len(&self) -> usize {
        self.metadata.total_elements
    }

    /// `true` if the tensor contains no elements.
    pub fn is_empty(&self) -> bool {
        self.metadata.total_elements == 0
    }

    /// Returns the element at flat `index`, loading its chunk if needed.
    ///
    /// # Errors
    /// `TorshError::InvalidArgument` if `index` is out of bounds, or an I/O
    /// error if the chunk cannot be read from disk.
    pub fn get_element(&self, index: usize) -> Result<T> {
        if index >= self.metadata.total_elements {
            return Err(TorshError::InvalidArgument(format!(
                "Index {} out of bounds for tensor with {} elements",
                index, self.metadata.total_elements
            )));
        }
        let chunk_index = index / self.config.chunk_size;
        let chunk_offset = index % self.config.chunk_size;
        let chunk = self.load_chunk(chunk_index)?;
        Ok(chunk.data[chunk_offset])
    }

    /// Returns the elements in the half-open flat range `[start, end)`.
    ///
    /// # Errors
    /// `TorshError::InvalidArgument` if `start > end` or `end` exceeds the
    /// element count, or an I/O error if a chunk cannot be read.
    pub fn get_range(&self, start: usize, end: usize) -> Result<Vec<T>> {
        if start > end || end > self.metadata.total_elements {
            return Err(TorshError::InvalidArgument(format!(
                "Invalid range [{}..{}] for tensor with {} elements",
                start, end, self.metadata.total_elements
            )));
        }
        // Empty range: return early so `end - 1` below cannot underflow
        // (the original code panicked/wrapped on e.g. `get_range(0, 0)`).
        if start == end {
            return Ok(Vec::new());
        }
        let mut result = Vec::with_capacity(end - start);
        let start_chunk = start / self.config.chunk_size;
        let end_chunk = (end - 1) / self.config.chunk_size;
        for chunk_idx in start_chunk..=end_chunk {
            let chunk = self.load_chunk(chunk_idx)?;
            let chunk_start = chunk_idx * self.config.chunk_size;
            let chunk_end = std::cmp::min(
                (chunk_idx + 1) * self.config.chunk_size,
                self.metadata.total_elements,
            );
            // Translate the global range into chunk-local indices.
            let range_start = std::cmp::max(start, chunk_start) - chunk_start;
            let range_end = std::cmp::min(end, chunk_end) - chunk_start;
            result.extend_from_slice(&chunk.data[range_start..range_end]);
        }
        Ok(result)
    }

    /// Loads the entire tensor into memory.
    pub fn load_all(&self) -> Result<Vec<T>> {
        self.get_range(0, self.metadata.total_elements)
    }

    /// Returns the chunk with index `chunk_index`, from cache or from disk.
    fn load_chunk(&self, chunk_index: usize) -> Result<Arc<CachedChunk<T>>> {
        // Fast path: cache hit. A write lock is taken so the *stored* entry's
        // `last_accessed` can be refreshed; the original only stamped the
        // returned copy, so TTL/LRU eviction treated hot chunks as stale.
        {
            let mut cache = self
                .chunk_cache
                .write()
                .expect("lock should not be poisoned");
            if let Some(cached) = cache.get_mut(&chunk_index) {
                cached.last_accessed = Instant::now();
                return Ok(Arc::new(cached.clone()));
            }
        }
        // Miss: compute this chunk's element range (the last chunk may be
        // shorter than `chunk_size`) and read it from the file.
        let start_element = chunk_index * self.config.chunk_size;
        let end_element = std::cmp::min(
            (chunk_index + 1) * self.config.chunk_size,
            self.metadata.total_elements,
        );
        let chunk_size = end_element - start_element;
        let data = self.load_chunk_from_file(start_element, chunk_size)?;
        let chunk = Arc::new(CachedChunk {
            data,
            range: (start_element, end_element),
            last_accessed: Instant::now(),
        });
        {
            let mut cache = self
                .chunk_cache
                .write()
                .expect("lock should not be poisoned");
            // Evict stale/excess entries before inserting the new one.
            self.cleanup_cache(&mut cache);
            cache.insert(chunk_index, (*chunk).clone());
        }
        Ok(chunk)
    }

    /// Reads `chunk_size` elements starting at flat index `start_element`.
    fn load_chunk_from_file(&self, start_element: usize, chunk_size: usize) -> Result<Vec<T>> {
        // The byte reinterpretation below is only sound when the on-disk
        // element size matches the in-memory size of `T`.
        if self.metadata.element_size != std::mem::size_of::<T>() {
            return Err(TorshError::InvalidArgument(format!(
                "Element size mismatch: metadata says {} bytes but size_of::<T>() is {}",
                self.metadata.element_size,
                std::mem::size_of::<T>()
            )));
        }
        let mut file = self.file.lock().expect("lock should not be poisoned");
        let file_offset =
            self.metadata.data_offset + (start_element as u64 * self.metadata.element_size as u64);
        file.seek(SeekFrom::Start(file_offset))
            .map_err(|e| TorshError::IoError(format!("Failed to seek: {}", e)))?;
        let mut buffer = vec![0u8; chunk_size * self.metadata.element_size];
        file.read_exact(&mut buffer)
            .map_err(|e| TorshError::IoError(format!("Failed to read chunk: {}", e)))?;
        // Copy the bytes into a properly aligned Vec<T>. The original built a
        // `&[T]` directly over the Vec<u8> allocation with
        // `slice::from_raw_parts`, which is UB whenever align_of::<T>() > 1
        // because a byte buffer carries no alignment guarantee for `T`.
        let mut data: Vec<T> = Vec::with_capacity(chunk_size);
        // SAFETY: `buffer` holds exactly `chunk_size * size_of::<T>()`
        // initialized bytes (size check above + `read_exact`), and `data` has
        // capacity for `chunk_size` elements, so the byte copy stays in
        // bounds and `set_len` exposes only fully written elements.
        // `copy_nonoverlapping` on `*const u8` has no alignment requirement.
        // Like the original transmute, this assumes every bit pattern is a
        // valid `T` — TODO confirm for all `TensorElement` implementors.
        unsafe {
            std::ptr::copy_nonoverlapping(
                buffer.as_ptr(),
                data.as_mut_ptr() as *mut u8,
                buffer.len(),
            );
            data.set_len(chunk_size);
        }
        Ok(data)
    }

    /// Evicts TTL-expired chunks, then the least-recently-used chunks beyond
    /// `max_cached_chunks`. Caller must hold the cache write lock.
    fn cleanup_cache(&self, cache: &mut HashMap<usize, CachedChunk<T>>) {
        let now = Instant::now();
        cache.retain(|_, chunk| now.duration_since(chunk.last_accessed) < self.config.cache_ttl);
        if cache.len() > self.config.max_cached_chunks {
            // Sort surviving entries oldest-first and drop the overflow.
            let mut chunks_to_remove: Vec<_> = cache
                .iter()
                .map(|(idx, chunk)| (*idx, chunk.last_accessed))
                .collect();
            chunks_to_remove.sort_by_key(|(_, accessed)| *accessed);
            let to_remove = cache.len() - self.config.max_cached_chunks;
            for (idx, _) in chunks_to_remove.iter().take(to_remove) {
                cache.remove(idx);
            }
        }
    }

    /// Snapshot of cache occupancy and estimated memory usage.
    pub fn cache_stats(&self) -> CacheStats {
        let cache = self
            .chunk_cache
            .read()
            .expect("lock should not be poisoned");
        let total_cached_elements: usize = cache.values().map(|chunk| chunk.data.len()).sum();
        CacheStats {
            cached_chunks: cache.len(),
            total_cached_elements,
            estimated_memory_usage: total_cached_elements * std::mem::size_of::<T>(),
        }
    }

    /// Drops every cached chunk.
    pub fn clear_cache(&self) {
        let mut cache = self
            .chunk_cache
            .write()
            .expect("lock should not be poisoned");
        cache.clear();
    }

    /// If estimated cache memory exceeds the configured threshold, evicts
    /// every chunk not accessed within the last 60 seconds.
    pub fn check_memory_pressure(&self) -> Result<()> {
        let stats = self.cache_stats();
        if stats.estimated_memory_usage > self.config.memory_pressure_threshold {
            let mut cache = self
                .chunk_cache
                .write()
                .expect("lock should not be poisoned");
            let recent_threshold = Duration::from_secs(60);
            let now = Instant::now();
            cache.retain(|_, chunk| now.duration_since(chunk.last_accessed) < recent_threshold);
        }
        Ok(())
    }
}
/// Snapshot of the chunk cache, as returned by `LazyTensor::cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of chunks currently cached.
    pub cached_chunks: usize,
    /// Sum of element counts across all cached chunks.
    pub total_cached_elements: usize,
    /// `total_cached_elements * size_of::<T>()` — element data only; does
    /// not count map/bookkeeping overhead.
    pub estimated_memory_usage: usize,
}
/// Fluent builder for constructing a `LazyTensor` with a custom config.
pub struct LazyTensorBuilder {
    /// Configuration accumulated by the builder's setter methods.
    config: LazyLoadConfig,
}
impl LazyTensorBuilder {
    /// Creates a builder seeded with `LazyLoadConfig::default()`.
    pub fn new() -> Self {
        let config = LazyLoadConfig::default();
        Self { config }
    }

    /// Sets the number of elements loaded per chunk.
    pub fn chunk_size(mut self, size: usize) -> Self {
        self.config.chunk_size = size;
        self
    }

    /// Sets the soft cap on cached chunk count.
    pub fn max_cached_chunks(mut self, max: usize) -> Self {
        self.config.max_cached_chunks = max;
        self
    }

    /// Sets the idle time after which cached chunks become evictable.
    pub fn cache_ttl(mut self, ttl: Duration) -> Self {
        self.config.cache_ttl = ttl;
        self
    }

    /// Sets the cache-memory threshold (bytes) for pressure handling.
    pub fn memory_pressure_threshold(mut self, threshold: usize) -> Self {
        self.config.memory_pressure_threshold = threshold;
        self
    }

    /// Consumes the builder and opens a lazy tensor over `file_path`
    /// using the accumulated configuration.
    pub fn build<T: TensorElement, P: AsRef<Path>>(
        self,
        file_path: P,
        metadata: LazyTensorMetadata,
    ) -> Result<LazyTensor<T>> {
        LazyTensor::new(file_path, metadata, self.config)
    }
}
impl Default for LazyTensorBuilder {
fn default() -> Self {
Self::new()
}
}
pub mod utils {
    use super::*;
    use std::fs::File;
    use std::io::{BufReader, Read};

    /// Reads the length-prefixed JSON header of a tensor file and builds
    /// `LazyTensorMetadata` from it.
    ///
    /// Expected file layout: 4-byte little-endian header length, then a
    /// UTF-8 JSON header, then raw tensor data. Recognized header fields
    /// (extracted best-effort, with no JSON dependency): `"shape":[..]`,
    /// `"dtype":"..."`, `"total_elements":N`. Fields that are absent or
    /// malformed fall back to the previous hardcoded defaults
    /// (shape `[100, 100]`, dtype `"f32"`). The original implementation
    /// read the header and then ignored it entirely.
    ///
    /// # Errors
    /// I/O errors reading the prefix/header, or a serialization error if
    /// the header bytes are not valid UTF-8.
    pub fn create_metadata_from_header<P: AsRef<Path>>(file_path: P) -> Result<LazyTensorMetadata> {
        let file = File::open(file_path)?;
        let mut reader = BufReader::new(file);
        let mut header_size_bytes = [0u8; 4];
        reader
            .read_exact(&mut header_size_bytes)
            .map_err(|e| TorshError::IoError(format!("Failed to read header size: {}", e)))?;
        let header_size = u32::from_le_bytes(header_size_bytes) as usize;
        let mut header_data = vec![0u8; header_size];
        reader
            .read_exact(&mut header_data)
            .map_err(|e| TorshError::IoError(format!("Failed to read header: {}", e)))?;
        let header_str = String::from_utf8(header_data)
            .map_err(|e| TorshError::SerializationError(format!("Invalid header: {}", e)))?;

        // Best-effort extraction of the known fields, falling back to the
        // old defaults so malformed headers behave no worse than before.
        let dims = parse_shape(&header_str).unwrap_or_else(|| vec![100, 100]);
        let dtype = parse_string_field(&header_str, "dtype").unwrap_or_else(|| "f32".to_string());
        let total_elements = parse_usize_field(&header_str, "total_elements")
            .unwrap_or_else(|| dims.iter().product());
        let metadata = LazyTensorMetadata {
            shape: Shape::new(dims),
            element_size: dtype_size(&dtype),
            dtype,
            total_elements,
            data_offset: 4 + header_size as u64,
        };
        Ok(metadata)
    }

    /// Extracts `"shape":[a,b,...]` as dimensions; `None` if absent/malformed.
    fn parse_shape(header: &str) -> Option<Vec<usize>> {
        let pat = "\"shape\":";
        let after = &header[header.find(pat)? + pat.len()..];
        let open = after.find('[')?;
        let close = after.find(']')?;
        if close <= open {
            return None;
        }
        let mut dims = Vec::new();
        for part in after[open + 1..close].split(',') {
            dims.push(part.trim().parse().ok()?);
        }
        if dims.is_empty() {
            None
        } else {
            Some(dims)
        }
    }

    /// Extracts `"key":<digits>` as a usize; `None` if absent/malformed.
    fn parse_usize_field(header: &str, key: &str) -> Option<usize> {
        let pat = format!("\"{}\":", key);
        let after = header[header.find(&pat)? + pat.len()..].trim_start();
        let digits: String = after.chars().take_while(|c| c.is_ascii_digit()).collect();
        digits.parse().ok()
    }

    /// Extracts `"key":"value"` as a String; `None` if absent/malformed.
    fn parse_string_field(header: &str, key: &str) -> Option<String> {
        let pat = format!("\"{}\":\"", key);
        let after = &header[header.find(&pat)? + pat.len()..];
        let end = after.find('"')?;
        Some(after[..end].to_string())
    }

    /// Maps a dtype name to its element size in bytes.
    fn dtype_size(dtype: &str) -> usize {
        match dtype {
            "f64" | "i64" | "u64" => 8,
            "f16" | "bf16" | "i16" | "u16" => 2,
            "i8" | "u8" | "bool" => 1,
            // f32/i32/u32 — and any unknown name — default to 4 bytes,
            // matching the previous hardcoded behavior.
            _ => 4,
        }
    }

    /// Convenience: builds a `LazyTensor` directly from a file, deriving
    /// metadata from its header and using the default load configuration.
    pub fn lazy_tensor_from_file<T: TensorElement, P: AsRef<Path>>(
        file_path: P,
    ) -> Result<LazyTensor<T>> {
        let metadata = create_metadata_from_header(&file_path)?;
        LazyTensor::new(file_path, metadata, LazyLoadConfig::default())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;
    use torsh_core::shape::Shape;

    /// Writes a small test file — 4-byte LE header length, a JSON header,
    /// then 100 consecutive f32 values (0.0, 1.0, ...) — and returns the
    /// live temp-file handle plus matching metadata.
    fn create_test_file() -> (NamedTempFile, LazyTensorMetadata) {
        let mut file = NamedTempFile::new().expect("temp file creation should succeed");
        let header = r#"{"shape":[10,10],"dtype":"f32","total_elements":100}"#;
        file.write_all(&(header.len() as u32).to_le_bytes())
            .expect("write should succeed");
        file.write_all(header.as_bytes()).expect("write should succeed");
        for value in 0u32..100 {
            file.write_all(&(value as f32).to_le_bytes())
                .expect("write should succeed");
        }
        file.flush().expect("flush should succeed");
        let metadata = LazyTensorMetadata {
            shape: Shape::new(vec![10, 10]),
            dtype: "f32".to_string(),
            total_elements: 100,
            element_size: 4,
            data_offset: 4 + header.len() as u64,
        };
        (file, metadata)
    }

    #[test]
    fn test_lazy_tensor_creation() {
        let (file, metadata) = create_test_file();
        let tensor = LazyTensor::<f32>::new(file.path(), metadata, LazyLoadConfig::default())
            .expect("lazy tensor creation should succeed");
        assert_eq!(tensor.len(), 100);
        assert!(!tensor.is_empty());
        assert_eq!(tensor.shape().dims(), &[10, 10]);
    }

    #[test]
    fn test_lazy_loading_element_access() {
        let (file, metadata) = create_test_file();
        let config = LazyLoadConfig {
            chunk_size: 10,
            ..LazyLoadConfig::default()
        };
        let tensor = LazyTensor::<f32>::new(file.path(), metadata, config)
            .expect("lazy tensor creation should succeed");
        // Element value equals its flat index; check one index in the first
        // chunk and one that forces a different chunk to load.
        for index in [5usize, 50] {
            let value = tensor.get_element(index).expect("get_element should succeed");
            assert!((value - index as f32).abs() < f32::EPSILON);
        }
    }

    #[test]
    fn test_lazy_loading_range_access() {
        let (file, metadata) = create_test_file();
        let config = LazyLoadConfig {
            chunk_size: 10,
            ..LazyLoadConfig::default()
        };
        let tensor = LazyTensor::<f32>::new(file.path(), metadata, config)
            .expect("lazy tensor creation should succeed");
        let values = tensor.get_range(10, 20).expect("get_range should succeed");
        assert_eq!(values.len(), 10);
        for (offset, &value) in values.iter().enumerate() {
            assert!((value - (10 + offset) as f32).abs() < f32::EPSILON);
        }
    }

    #[test]
    fn test_cache_management() {
        let (file, metadata) = create_test_file();
        let config = LazyLoadConfig {
            chunk_size: 10,
            max_cached_chunks: 2,
            ..LazyLoadConfig::default()
        };
        let tensor = LazyTensor::<f32>::new(file.path(), metadata, config)
            .expect("lazy tensor creation should succeed");
        tensor.get_element(5).expect("get_element should succeed");
        tensor.get_element(15).expect("get_element should succeed");
        assert!(tensor.cache_stats().cached_chunks <= 2);
        // A third distinct chunk; cleanup runs before insert, so the cache
        // may briefly hold max + 1 entries.
        tensor.get_element(25).expect("get_element should succeed");
        let stats = tensor.cache_stats();
        assert!(stats.cached_chunks <= 3);
        assert!(stats.total_cached_elements > 0);
    }

    #[test]
    fn test_lazy_tensor_builder() {
        let (file, metadata) = create_test_file();
        let tensor: LazyTensor<f32> = LazyTensorBuilder::new()
            .chunk_size(5)
            .max_cached_chunks(3)
            .build(file.path(), metadata)
            .expect("lazy operation should succeed");
        assert_eq!(tensor.config.chunk_size, 5);
        assert_eq!(tensor.config.max_cached_chunks, 3);
    }

    #[test]
    fn test_out_of_bounds_access() {
        let (file, metadata) = create_test_file();
        let tensor = LazyTensor::<f32>::new(file.path(), metadata, LazyLoadConfig::default())
            .expect("lazy tensor creation should succeed");
        assert!(tensor.get_element(1000).is_err());
        assert!(tensor.get_range(90, 110).is_err());
    }
}