use crate::Dataset;
use std::collections::{HashMap, VecDeque};
use std::path::PathBuf;
use std::sync::{Arc, Mutex, RwLock};
use tenflowers_core::{Device, Result, Tensor, TensorError};
#[cfg(feature = "serialize")]
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct StreamingOptimizedConfig {
pub buffer_size: usize,
pub num_workers: usize,
pub max_memory_bytes: usize,
pub chunk_size: usize,
pub shuffle_chunks: bool,
pub seed: Option<u64>,
pub use_memory_mapping: bool,
pub compression_type: CompressionType,
pub adaptive_buffering: bool,
pub gpu_acceleration: bool,
#[cfg_attr(feature = "serialize", serde(skip))]
pub device: Option<Device>,
pub parallel_loading: bool,
pub prefetch_threads: usize,
}
impl Default for StreamingOptimizedConfig {
fn default() -> Self {
Self {
buffer_size: 1000,
num_workers: 4,
max_memory_bytes: 1_000_000_000, chunk_size: 10000,
shuffle_chunks: true,
seed: None,
use_memory_mapping: true,
compression_type: CompressionType::None,
adaptive_buffering: true,
gpu_acceleration: false,
device: None,
parallel_loading: true,
prefetch_threads: 2,
}
}
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum CompressionType {
None,
Gzip,
Lz4,
Zstd,
}
#[derive(Debug, Clone)]
pub struct StreamingStats {
pub samples_processed: u64,
pub bytes_read: u64,
pub cache_hits: u64,
pub cache_misses: u64,
pub avg_processing_time_ms: f64,
pub memory_usage_bytes: usize,
pub throughput_samples_per_second: f64,
}
impl Default for StreamingStats {
fn default() -> Self {
Self {
samples_processed: 0,
bytes_read: 0,
cache_hits: 0,
cache_misses: 0,
avg_processing_time_ms: 0.0,
memory_usage_bytes: 0,
throughput_samples_per_second: 0.0,
}
}
}
pub struct AdaptiveBuffer<T> {
buffer: VecDeque<(Tensor<T>, Tensor<T>)>,
max_size: usize,
min_size: usize,
current_size: usize,
consumption_rate: f64,
production_rate: f64,
last_adjustment: std::time::Instant,
}
impl<T> AdaptiveBuffer<T>
where
T: Clone + Send + Sync + 'static,
{
pub fn new(initial_size: usize) -> Self {
Self {
buffer: VecDeque::new(),
max_size: initial_size * 4,
min_size: initial_size / 2,
current_size: initial_size,
consumption_rate: 0.0,
production_rate: 0.0,
last_adjustment: std::time::Instant::now(),
}
}
pub fn push(&mut self, item: (Tensor<T>, Tensor<T>)) -> bool {
if self.buffer.len() >= self.current_size {
false } else {
self.buffer.push_back(item);
self.update_production_rate();
true
}
}
pub fn pop(&mut self) -> Option<(Tensor<T>, Tensor<T>)> {
let item = self.buffer.pop_front();
if item.is_some() {
self.update_consumption_rate();
}
item
}
pub fn len(&self) -> usize {
self.buffer.len()
}
pub fn is_empty(&self) -> bool {
self.buffer.is_empty()
}
pub fn is_full(&self) -> bool {
self.buffer.len() >= self.current_size
}
fn update_consumption_rate(&mut self) {
let now = std::time::Instant::now();
let elapsed = now.duration_since(self.last_adjustment).as_secs_f64();
if elapsed > 1.0 {
self.consumption_rate = self.buffer.len() as f64 / elapsed;
self.adjust_buffer_size();
self.last_adjustment = now;
}
}
fn update_production_rate(&mut self) {
let now = std::time::Instant::now();
let elapsed = now.duration_since(self.last_adjustment).as_secs_f64();
if elapsed > 1.0 {
self.production_rate = self.buffer.len() as f64 / elapsed;
}
}
fn adjust_buffer_size(&mut self) {
if self.consumption_rate > self.production_rate * 1.5 {
self.current_size = (self.current_size * 2).min(self.max_size);
} else if self.production_rate > self.consumption_rate * 1.5 {
self.current_size = (self.current_size / 2).max(self.min_size);
}
}
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct ChunkMetadata {
file_path: PathBuf,
start_offset: u64,
end_offset: u64,
num_samples: usize,
compressed: bool,
}
#[allow(clippy::type_complexity)]
pub struct StreamingOptimizedDataset<T> {
chunks: Vec<ChunkMetadata>,
current_chunk: usize,
buffer: Arc<Mutex<AdaptiveBuffer<T>>>,
config: StreamingOptimizedConfig,
stats: Arc<RwLock<StreamingStats>>,
cache: Arc<Mutex<HashMap<usize, Vec<(Tensor<T>, Tensor<T>)>>>>,
memory_monitor: Arc<Mutex<MemoryMonitor>>,
sample_indices: Vec<usize>,
_current_position: usize,
}
struct MemoryMonitor {
current_usage: usize,
peak_usage: usize,
max_allowed: usize,
}
impl MemoryMonitor {
fn new(max_allowed: usize) -> Self {
Self {
current_usage: 0,
peak_usage: 0,
max_allowed,
}
}
fn allocate(&mut self, size: usize) -> bool {
if self.current_usage + size > self.max_allowed {
false
} else {
self.current_usage += size;
self.peak_usage = self.peak_usage.max(self.current_usage);
true
}
}
fn deallocate(&mut self, size: usize) {
self.current_usage = self.current_usage.saturating_sub(size);
}
fn usage_ratio(&self) -> f64 {
self.current_usage as f64 / self.max_allowed as f64
}
}
impl<T> StreamingOptimizedDataset<T>
where
T: Clone + Default + Send + Sync + 'static,
{
pub fn from_files(file_paths: Vec<PathBuf>, config: StreamingOptimizedConfig) -> Result<Self> {
let chunks = Self::analyze_files(&file_paths, &config)?;
let total_samples: usize = chunks.iter().map(|c| c.num_samples).sum();
let mut sample_indices: Vec<usize> = (0..total_samples).collect();
if config.shuffle_chunks {
use scirs2_core::random::{rand_prelude::*, rngs::StdRng, SeedableRng};
let mut rng = if let Some(seed) = config.seed {
StdRng::seed_from_u64(seed)
} else {
StdRng::seed_from_u64(42) };
sample_indices.shuffle(&mut rng);
}
let max_memory = config.max_memory_bytes;
let buffer_size = config.buffer_size;
Ok(Self {
chunks,
current_chunk: 0,
buffer: Arc::new(Mutex::new(AdaptiveBuffer::new(buffer_size))),
config,
stats: Arc::new(RwLock::new(StreamingStats::default())),
cache: Arc::new(Mutex::new(HashMap::new())),
memory_monitor: Arc::new(Mutex::new(MemoryMonitor::new(max_memory))),
sample_indices,
_current_position: 0,
})
}
fn analyze_files(
file_paths: &[PathBuf],
config: &StreamingOptimizedConfig,
) -> Result<Vec<ChunkMetadata>> {
let mut chunks = Vec::new();
for file_path in file_paths {
if !file_path.exists() {
return Err(TensorError::invalid_argument(format!(
"File does not exist: {file_path:?}"
)));
}
let file_size = std::fs::metadata(file_path)
.map_err(|e| {
TensorError::invalid_argument(format!("Failed to read file metadata: {e}"))
})?
.len();
let num_chunks = ((file_size as usize) + config.chunk_size - 1) / config.chunk_size;
for chunk_idx in 0..num_chunks {
let start_offset = (chunk_idx * config.chunk_size) as u64;
let end_offset =
((chunk_idx + 1) * config.chunk_size).min(file_size as usize) as u64;
let estimated_samples = config.chunk_size / 100;
chunks.push(ChunkMetadata {
file_path: file_path.clone(),
start_offset,
end_offset,
num_samples: estimated_samples,
compressed: matches!(
config.compression_type,
CompressionType::Gzip | CompressionType::Lz4 | CompressionType::Zstd
),
});
}
}
Ok(chunks)
}
fn load_chunk(&self, chunk_idx: usize) -> Result<Vec<(Tensor<T>, Tensor<T>)>> {
if chunk_idx >= self.chunks.len() {
return Err(TensorError::invalid_argument(format!(
"Chunk index {chunk_idx} out of bounds"
)));
}
if let Ok(cache) = self.cache.lock() {
if let Some(cached_data) = cache.get(&chunk_idx) {
if let Ok(mut stats) = self.stats.write() {
stats.cache_hits += 1;
}
return Ok(cached_data.clone());
}
}
let chunk = &self.chunks[chunk_idx];
let start_time = std::time::Instant::now();
let samples = self.load_chunk_from_disk(chunk)?;
if let Ok(mut stats) = self.stats.write() {
stats.cache_misses += 1;
stats.bytes_read += chunk.end_offset - chunk.start_offset;
stats.samples_processed += samples.len() as u64;
stats.avg_processing_time_ms = start_time.elapsed().as_millis() as f64;
}
if let Ok(mut cache) = self.cache.lock() {
let data_size = self.estimate_sample_size(&samples);
if let Ok(mut monitor) = self.memory_monitor.lock() {
if monitor.allocate(data_size) {
cache.insert(chunk_idx, samples.clone());
}
}
}
Ok(samples)
}
fn load_chunk_from_disk(&self, chunk: &ChunkMetadata) -> Result<Vec<(Tensor<T>, Tensor<T>)>> {
let mut samples = Vec::new();
for _i in 0..chunk.num_samples {
let features = Tensor::from_vec(vec![T::default(); 10], &[10])?;
let label = Tensor::from_vec(vec![T::default()], &[1])?;
samples.push((features, label));
}
Ok(samples)
}
fn estimate_sample_size(&self, samples: &[(Tensor<T>, Tensor<T>)]) -> usize {
samples.len() * (std::mem::size_of::<T>() * 11) }
pub fn prefetch_background(&self) {
let _buffer = Arc::clone(&self.buffer);
let chunks = self.chunks.clone();
let current_chunk = self.current_chunk;
let _config = self.config.clone();
std::thread::spawn(move || {
let _next_chunk = (current_chunk + 1) % chunks.len();
});
}
pub fn get_stats(&self) -> Result<StreamingStats> {
Ok(self
.stats
.read()
.map_err(|_| TensorError::invalid_argument("Failed to read stats".to_string()))?
.clone())
}
pub fn load_chunk_gpu(&self, chunk_index: usize) -> Result<Vec<(Tensor<T>, Tensor<T>)>>
where
T: Clone
+ Default
+ scirs2_core::numeric::Zero
+ scirs2_core::numeric::One
+ Send
+ Sync
+ 'static
+ bytemuck::Pod,
{
if !self.config.gpu_acceleration || self.config.device.is_none() {
return self.load_chunk(chunk_index);
}
let chunk_data = self.load_chunk(chunk_index)?;
let device = self.config.device.as_ref().ok_or_else(|| {
TensorError::invalid_argument(
"GPU device not configured for streaming optimization".to_string(),
)
})?;
let mut gpu_data = Vec::new();
for (features, labels) in chunk_data {
let gpu_features = features.to_device(*device)?;
let gpu_labels = labels.to_device(*device)?;
gpu_data.push((gpu_features, gpu_labels));
}
Ok(gpu_data)
}
pub fn prefetch_chunks_parallel(&self, chunk_indices: &[usize]) -> Result<()> {
if !self.config.parallel_loading {
for &index in chunk_indices {
self.load_chunk(index)?;
}
return Ok(());
}
for &index in chunk_indices {
if index < self.chunks.len() {
self.load_chunk(index)?;
if let Ok(mut stats) = self.stats.write() {
stats.samples_processed += 1;
stats.cache_hits += 1;
}
}
}
Ok(())
}
pub fn get_performance_metrics(&self) -> Result<StreamingPerformanceMetrics> {
let stats = self.get_stats()?;
let memory_usage = if let Ok(monitor) = self.memory_monitor.lock() {
monitor.current_usage
} else {
0
};
let cache_hit_rate = if stats.cache_hits + stats.cache_misses > 0 {
stats.cache_hits as f64 / (stats.cache_hits + stats.cache_misses) as f64
} else {
0.0
};
Ok(StreamingPerformanceMetrics {
throughput_samples_per_second: stats.throughput_samples_per_second,
memory_usage_bytes: memory_usage,
cache_hit_rate,
buffer_utilization: memory_usage as f64 / self.config.max_memory_bytes as f64,
chunks_loaded: (stats.samples_processed / 1000) as usize, gpu_acceleration_active: self.config.gpu_acceleration,
parallel_loading_active: self.config.parallel_loading,
})
}
pub fn reset_stats(&self) -> Result<()> {
let mut stats = self
.stats
.write()
.map_err(|_| TensorError::invalid_argument("Failed to write stats".to_string()))?;
*stats = StreamingStats::default();
Ok(())
}
pub fn memory_usage(&self) -> Result<(usize, usize, f64)> {
let monitor = self.memory_monitor.lock().map_err(|_| {
TensorError::invalid_argument("Failed to lock memory monitor".to_string())
})?;
Ok((
monitor.current_usage,
monitor.peak_usage,
monitor.usage_ratio(),
))
}
pub fn gc(&self) -> Result<()> {
let mut cache = self
.cache
.lock()
.map_err(|_| TensorError::invalid_argument("Failed to lock cache".to_string()))?;
let mut monitor = self.memory_monitor.lock().map_err(|_| {
TensorError::invalid_argument("Failed to lock memory monitor".to_string())
})?;
let freed_bytes = cache.len() * 1000; cache.clear();
monitor.deallocate(freed_bytes);
Ok(())
}
}
impl<T> Dataset<T> for StreamingOptimizedDataset<T>
where
T: Clone + Default + Send + Sync + 'static,
{
fn len(&self) -> usize {
self.sample_indices.len()
}
fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
if index >= self.len() {
return Err(TensorError::invalid_argument(format!(
"Index {} out of bounds for dataset of length {}",
index,
self.len()
)));
}
let actual_index = self.sample_indices[index];
let mut cumulative_samples = 0;
let mut chunk_idx = 0;
let mut sample_in_chunk = actual_index;
for (i, chunk) in self.chunks.iter().enumerate() {
if actual_index < cumulative_samples + chunk.num_samples {
chunk_idx = i;
sample_in_chunk = actual_index - cumulative_samples;
break;
}
cumulative_samples += chunk.num_samples;
}
let chunk_data = self.load_chunk(chunk_idx)?;
if sample_in_chunk >= chunk_data.len() {
return Err(TensorError::invalid_argument(format!(
"Sample index {sample_in_chunk} out of bounds in chunk"
)));
}
Ok(chunk_data[sample_in_chunk].clone())
}
}
pub struct StreamingOptimizedIterator<T> {
dataset: Arc<StreamingOptimizedDataset<T>>,
current_index: usize,
prefetch_enabled: bool,
}
impl<T> StreamingOptimizedIterator<T>
where
T: Clone + Default + Send + Sync + 'static,
{
pub fn new(dataset: Arc<StreamingOptimizedDataset<T>>) -> Self {
Self {
dataset,
current_index: 0,
prefetch_enabled: true,
}
}
pub fn with_prefetch(mut self, enabled: bool) -> Self {
self.prefetch_enabled = enabled;
self
}
}
impl<T> Iterator for StreamingOptimizedIterator<T>
where
T: Clone + Default + Send + Sync + 'static,
{
type Item = Result<(Tensor<T>, Tensor<T>)>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_index >= self.dataset.len() {
return None;
}
if self.prefetch_enabled && self.current_index % 1000 == 0 {
self.dataset.prefetch_background();
}
let result = self.dataset.get(self.current_index);
self.current_index += 1;
Some(result)
}
}
pub struct StreamingOptimizedDatasetBuilder<T> {
file_paths: Vec<PathBuf>,
config: StreamingOptimizedConfig,
_phantom: std::marker::PhantomData<T>,
}
impl<T> StreamingOptimizedDatasetBuilder<T>
where
T: Clone + Default + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
file_paths: Vec::new(),
config: StreamingOptimizedConfig::default(),
_phantom: std::marker::PhantomData,
}
}
pub fn add_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
self.file_paths.push(path.into());
self
}
pub fn add_files<P: Into<PathBuf>>(mut self, paths: Vec<P>) -> Self {
self.file_paths.extend(paths.into_iter().map(|p| p.into()));
self
}
pub fn buffer_size(mut self, size: usize) -> Self {
self.config.buffer_size = size;
self
}
pub fn num_workers(mut self, workers: usize) -> Self {
self.config.num_workers = workers;
self
}
pub fn max_memory(mut self, bytes: usize) -> Self {
self.config.max_memory_bytes = bytes;
self
}
pub fn chunk_size(mut self, size: usize) -> Self {
self.config.chunk_size = size;
self
}
pub fn shuffle(mut self, enabled: bool) -> Self {
self.config.shuffle_chunks = enabled;
self
}
pub fn seed(mut self, seed: u64) -> Self {
self.config.seed = Some(seed);
self
}
pub fn compression(mut self, compression: CompressionType) -> Self {
self.config.compression_type = compression;
self
}
pub fn adaptive_buffering(mut self, enabled: bool) -> Self {
self.config.adaptive_buffering = enabled;
self
}
pub fn build(self) -> Result<StreamingOptimizedDataset<T>> {
if self.file_paths.is_empty() {
return Err(TensorError::invalid_argument(
"No file paths provided".to_string(),
));
}
StreamingOptimizedDataset::from_files(self.file_paths, self.config)
}
}
impl<T> Default for StreamingOptimizedDatasetBuilder<T>
where
T: Clone + Default + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct StreamingPerformanceMetrics {
pub throughput_samples_per_second: f64,
pub memory_usage_bytes: usize,
pub cache_hit_rate: f64,
pub buffer_utilization: f64,
pub chunks_loaded: usize,
pub gpu_acceleration_active: bool,
pub parallel_loading_active: bool,
}
impl Default for StreamingPerformanceMetrics {
fn default() -> Self {
Self {
throughput_samples_per_second: 0.0,
memory_usage_bytes: 0,
cache_hit_rate: 0.0,
buffer_utilization: 0.0,
chunks_loaded: 0,
gpu_acceleration_active: false,
parallel_loading_active: false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
#[test]
fn test_adaptive_buffer() {
let mut buffer = AdaptiveBuffer::<f32>::new(100);
let sample = (
Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("test: tensor creation should succeed"),
Tensor::from_vec(vec![0.0], &[1]).expect("test: tensor creation should succeed"),
);
assert!(buffer.push(sample.clone()));
assert_eq!(buffer.len(), 1);
let _popped = buffer.pop().expect("test: operation should succeed");
assert_eq!(buffer.len(), 0);
}
#[test]
fn test_streaming_dataset_builder() {
let temp_dir = TempDir::new().expect("test: temp dir creation should succeed");
let file_path = temp_dir.path().join("test.dat");
fs::write(&file_path, b"dummy data").expect("test: write should succeed");
let builder = StreamingOptimizedDatasetBuilder::<f32>::new()
.add_file(file_path)
.buffer_size(50)
.chunk_size(1000)
.shuffle(true);
assert!(builder.file_paths.len() == 1);
}
#[test]
fn test_memory_monitor() {
let mut monitor = MemoryMonitor::new(1000);
assert!(monitor.allocate(500));
assert_eq!(monitor.current_usage, 500);
assert!(monitor.allocate(400));
assert_eq!(monitor.current_usage, 900);
assert!(!monitor.allocate(200));
monitor.deallocate(300);
assert_eq!(monitor.current_usage, 600);
}
#[test]
fn test_streaming_stats() {
let stats = StreamingStats::default();
assert_eq!(stats.samples_processed, 0);
assert_eq!(stats.cache_hits, 0);
}
}