1use crate::Dataset;
7use std::collections::{HashMap, VecDeque};
8use std::path::PathBuf;
9use std::sync::{Arc, Mutex, RwLock};
10use tenflowers_core::{Device, Result, Tensor, TensorError};
11
12#[cfg(feature = "serialize")]
13use serde::{Deserialize, Serialize};
14
/// Tunables for the streaming-optimized dataset pipeline.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct StreamingOptimizedConfig {
    /// Target capacity (in samples) of the adaptive prefetch buffer.
    pub buffer_size: usize,
    /// Intended number of loader worker threads. NOTE(review): not read by
    /// the code visible in this file — confirm usage elsewhere.
    pub num_workers: usize,
    /// Hard memory budget (bytes) for the chunk cache.
    pub max_memory_bytes: usize,
    /// Chunk size in bytes used when splitting input files.
    pub chunk_size: usize,
    /// Whether to shuffle the flat sample-index permutation.
    pub shuffle_chunks: bool,
    /// RNG seed for shuffling; `None` falls back to a fixed seed (42).
    pub seed: Option<u64>,
    /// Whether memory-mapped I/O should be used. NOTE(review): currently
    /// informational only — confirm.
    pub use_memory_mapping: bool,
    /// Compression assumed for chunk data on disk.
    pub compression_type: CompressionType,
    /// Whether the prefetch buffer resizes itself from observed rates.
    pub adaptive_buffering: bool,
    /// Whether loaded chunks should be moved to a GPU device.
    pub gpu_acceleration: bool,
    /// Target device for GPU acceleration; skipped during serialization.
    #[cfg_attr(feature = "serialize", serde(skip))]
    pub device: Option<Device>,
    /// Whether prefetching may load chunks in parallel.
    pub parallel_loading: bool,
    /// Number of dedicated prefetch threads.
    pub prefetch_threads: usize,
}
47
48impl Default for StreamingOptimizedConfig {
49 fn default() -> Self {
50 Self {
51 buffer_size: 1000,
52 num_workers: 4,
53 max_memory_bytes: 1_000_000_000, chunk_size: 10000,
55 shuffle_chunks: true,
56 seed: None,
57 use_memory_mapping: true,
58 compression_type: CompressionType::None,
59 adaptive_buffering: true,
60 gpu_acceleration: false,
61 device: None,
62 parallel_loading: true,
63 prefetch_threads: 2,
64 }
65 }
66}
67
/// Compression codec assumed for chunk data on disk.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum CompressionType {
    /// Uncompressed data.
    None,
    /// gzip / DEFLATE.
    Gzip,
    /// LZ4.
    Lz4,
    /// Zstandard.
    Zstd,
}
77
/// Counters describing streaming progress and cache behavior.
///
/// All counters start at zero, so `Default` is derived; the previous
/// hand-written `impl Default` produced exactly the same all-zero value.
#[derive(Debug, Clone, Default)]
pub struct StreamingStats {
    /// Total samples materialized so far.
    pub samples_processed: u64,
    /// Total bytes covered by loaded chunk byte ranges.
    pub bytes_read: u64,
    /// Chunk loads served from the in-memory cache.
    pub cache_hits: u64,
    /// Chunk loads that had to go to the loader.
    pub cache_misses: u64,
    /// Processing time in ms. NOTE(review): the loading code assigns the
    /// most recent chunk-load duration rather than a running average.
    pub avg_processing_time_ms: f64,
    /// Bytes currently accounted by the memory monitor.
    pub memory_usage_bytes: usize,
    /// Observed throughput in samples per second.
    pub throughput_samples_per_second: f64,
}
103
/// Bounded FIFO of (features, label) tensor pairs whose effective capacity
/// adapts to the observed producer/consumer balance.
pub struct AdaptiveBuffer<T> {
    /// Queued samples.
    buffer: VecDeque<(Tensor<T>, Tensor<T>)>,
    /// Upper bound on `current_size` (4x the initial size).
    max_size: usize,
    /// Lower bound on `current_size` (half the initial size).
    min_size: usize,
    /// Current effective capacity; `push` refuses items beyond this.
    current_size: usize,
    /// Most recently estimated consumption rate.
    consumption_rate: f64,
    /// Most recently estimated production rate.
    production_rate: f64,
    /// Timestamp of the last rate/capacity adjustment.
    last_adjustment: std::time::Instant,
}
114
impl<T> AdaptiveBuffer<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Creates a buffer with an effective capacity of `initial_size`,
    /// allowed to grow to 4x and shrink to half of that.
    ///
    /// NOTE(review): `initial_size < 2` yields `min_size == 0` — confirm
    /// callers never pass 0 or 1.
    pub fn new(initial_size: usize) -> Self {
        Self {
            buffer: VecDeque::new(),
            max_size: initial_size * 4,
            min_size: initial_size / 2,
            current_size: initial_size,
            consumption_rate: 0.0,
            production_rate: 0.0,
            last_adjustment: std::time::Instant::now(),
        }
    }

    /// Enqueues `item`. Returns `false` (storing nothing) when the buffer
    /// is at its current capacity.
    pub fn push(&mut self, item: (Tensor<T>, Tensor<T>)) -> bool {
        if self.buffer.len() >= self.current_size {
            false
        } else {
            self.buffer.push_back(item);
            self.update_production_rate();
            true
        }
    }

    /// Dequeues the oldest item, refreshing the consumption-rate estimate
    /// on success.
    pub fn pop(&mut self) -> Option<(Tensor<T>, Tensor<T>)> {
        let item = self.buffer.pop_front();
        if item.is_some() {
            self.update_consumption_rate();
        }
        item
    }

    /// Number of queued items.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// True when no items are queued.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// True when the buffer is at (or beyond) its current capacity.
    pub fn is_full(&self) -> bool {
        self.buffer.len() >= self.current_size
    }

    /// Re-estimates the consumption rate at most once per second, then
    /// adjusts the capacity.
    ///
    /// NOTE(review): the "rate" is current occupancy divided by elapsed
    /// time, not items consumed per second — confirm this is intended.
    fn update_consumption_rate(&mut self) {
        let now = std::time::Instant::now();
        let elapsed = now.duration_since(self.last_adjustment).as_secs_f64();
        if elapsed > 1.0 {
            self.consumption_rate = self.buffer.len() as f64 / elapsed;
            self.adjust_buffer_size();
            self.last_adjustment = now;
        }
    }

    /// Re-estimates the production rate at most once per second.
    ///
    /// NOTE(review): unlike the consumption path this never resets
    /// `last_adjustment` and never triggers `adjust_buffer_size` — confirm
    /// the asymmetry is intended.
    fn update_production_rate(&mut self) {
        let now = std::time::Instant::now();
        let elapsed = now.duration_since(self.last_adjustment).as_secs_f64();
        if elapsed > 1.0 {
            self.production_rate = self.buffer.len() as f64 / elapsed;
        }
    }

    /// Doubles the capacity (up to `max_size`) when consumption clearly
    /// outpaces production, and halves it (down to `min_size`) in the
    /// opposite case; within a factor of 1.5 it stays put.
    fn adjust_buffer_size(&mut self) {
        if self.consumption_rate > self.production_rate * 1.5 {
            self.current_size = (self.current_size * 2).min(self.max_size);
        } else if self.production_rate > self.consumption_rate * 1.5 {
            self.current_size = (self.current_size / 2).max(self.min_size);
        }
    }
}
189
/// Byte-range description of one chunk of an input file.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct ChunkMetadata {
    /// File this chunk belongs to.
    file_path: PathBuf,
    /// Start byte offset within the file (inclusive).
    start_offset: u64,
    /// End byte offset within the file (exclusive).
    end_offset: u64,
    /// Estimated sample count — a heuristic, not measured from the data.
    num_samples: usize,
    /// True when the configured compression type is not `None`.
    compressed: bool,
}
200
/// Dataset that streams samples from chunked files with an in-memory chunk
/// cache, adaptive buffering, and memory-budget accounting.
#[allow(clippy::type_complexity)]
pub struct StreamingOptimizedDataset<T> {
    /// Pre-computed chunk metadata for all input files.
    chunks: Vec<ChunkMetadata>,
    /// Chunk index the background prefetcher starts from.
    current_chunk: usize,
    /// Shared adaptive prefetch buffer. NOTE(review): only cloned by the
    /// stub `prefetch_background` in the code visible here.
    buffer: Arc<Mutex<AdaptiveBuffer<T>>>,
    /// Pipeline configuration.
    config: StreamingOptimizedConfig,
    /// Shared streaming statistics.
    stats: Arc<RwLock<StreamingStats>>,
    /// chunk-index -> samples cache, bounded by the memory monitor.
    cache: Arc<Mutex<HashMap<usize, Vec<(Tensor<T>, Tensor<T>)>>>>,
    /// Budget accounting for cached chunk data.
    memory_monitor: Arc<Mutex<MemoryMonitor>>,
    /// Flat (possibly shuffled) permutation of sample indices.
    sample_indices: Vec<usize>,
    // Reserved for future iterator state; currently unused.
    _current_position: usize,
}
214
/// Tracks cache memory consumption against a fixed byte budget.
struct MemoryMonitor {
    /// Bytes currently accounted as allocated.
    current_usage: usize,
    /// High-water mark of `current_usage`.
    peak_usage: usize,
    /// Hard budget; allocations that would exceed it are refused.
    max_allowed: usize,
}

impl MemoryMonitor {
    /// Creates a monitor with a budget of `max_allowed` bytes.
    fn new(max_allowed: usize) -> Self {
        Self {
            current_usage: 0,
            peak_usage: 0,
            max_allowed,
        }
    }

    /// Tries to account `size` additional bytes. Returns `false` (leaving
    /// state unchanged) when the budget would be exceeded.
    ///
    /// Uses `checked_add` so a huge request is refused instead of
    /// overflowing (the old `current_usage + size` panicked in debug builds
    /// and wrapped — bypassing the budget — in release builds).
    fn allocate(&mut self, size: usize) -> bool {
        match self.current_usage.checked_add(size) {
            Some(total) if total <= self.max_allowed => {
                self.current_usage = total;
                self.peak_usage = self.peak_usage.max(self.current_usage);
                true
            }
            _ => false,
        }
    }

    /// Releases `size` bytes, saturating at zero so over-release cannot
    /// underflow.
    fn deallocate(&mut self, size: usize) {
        self.current_usage = self.current_usage.saturating_sub(size);
    }

    /// Fraction of the budget currently in use. A zero budget reports 1.0
    /// (fully used) instead of dividing by zero (NaN/inf previously).
    fn usage_ratio(&self) -> f64 {
        if self.max_allowed == 0 {
            1.0
        } else {
            self.current_usage as f64 / self.max_allowed as f64
        }
    }
}
249
impl<T> StreamingOptimizedDataset<T>
where
    T: Clone + Default + Send + Sync + 'static,
{
    /// Builds a dataset over `file_paths`: chunk metadata is computed up
    /// front and a (possibly shuffled) flat sample-index permutation is
    /// prepared. No sample data is read here.
    pub fn from_files(file_paths: Vec<PathBuf>, config: StreamingOptimizedConfig) -> Result<Self> {
        let chunks = Self::analyze_files(&file_paths, &config)?;
        let total_samples: usize = chunks.iter().map(|c| c.num_samples).sum();

        let mut sample_indices: Vec<usize> = (0..total_samples).collect();

        if config.shuffle_chunks {
            use scirs2_core::random::{rand_prelude::*, rngs::StdRng, SeedableRng};
            let mut rng = if let Some(seed) = config.seed {
                StdRng::seed_from_u64(seed)
            } else {
                // No seed configured: fall back to a fixed seed so even
                // "unseeded" runs shuffle deterministically.
                StdRng::seed_from_u64(42)
            };
            sample_indices.shuffle(&mut rng);
        }

        let max_memory = config.max_memory_bytes;
        let buffer_size = config.buffer_size;

        Ok(Self {
            chunks,
            current_chunk: 0,
            buffer: Arc::new(Mutex::new(AdaptiveBuffer::new(buffer_size))),
            config,
            stats: Arc::new(RwLock::new(StreamingStats::default())),
            cache: Arc::new(Mutex::new(HashMap::new())),
            memory_monitor: Arc::new(Mutex::new(MemoryMonitor::new(max_memory))),
            sample_indices,
            _current_position: 0,
        })
    }

    /// Splits each input file into fixed-size byte-range chunks.
    ///
    /// # Errors
    /// Fails when a path does not exist or its metadata cannot be read.
    ///
    /// NOTE(review): divides by `config.chunk_size` — a zero chunk size
    /// panics here; confirm the config is validated upstream.
    fn analyze_files(
        file_paths: &[PathBuf],
        config: &StreamingOptimizedConfig,
    ) -> Result<Vec<ChunkMetadata>> {
        let mut chunks = Vec::new();

        for file_path in file_paths {
            if !file_path.exists() {
                return Err(TensorError::invalid_argument(format!(
                    "File does not exist: {file_path:?}"
                )));
            }

            let file_size = std::fs::metadata(file_path)
                .map_err(|e| {
                    TensorError::invalid_argument(format!("Failed to read file metadata: {e}"))
                })?
                .len();
            // Ceiling division: number of chunk-sized byte ranges.
            let num_chunks = ((file_size as usize) + config.chunk_size - 1) / config.chunk_size;

            for chunk_idx in 0..num_chunks {
                let start_offset = (chunk_idx * config.chunk_size) as u64;
                let end_offset =
                    ((chunk_idx + 1) * config.chunk_size).min(file_size as usize) as u64;

                // Heuristic: ~100 bytes per sample. NOTE(review): the last
                // (possibly short) chunk gets the same estimate as a full
                // one — confirm this over-estimate is acceptable.
                let estimated_samples = config.chunk_size / 100;
                chunks.push(ChunkMetadata {
                    file_path: file_path.clone(),
                    start_offset,
                    end_offset,
                    num_samples: estimated_samples,
                    compressed: matches!(
                        config.compression_type,
                        CompressionType::Gzip | CompressionType::Lz4 | CompressionType::Zstd
                    ),
                });
            }
        }

        Ok(chunks)
    }

    /// Loads chunk `chunk_idx`, serving from the in-memory cache when
    /// possible and recording cache/throughput statistics.
    fn load_chunk(&self, chunk_idx: usize) -> Result<Vec<(Tensor<T>, Tensor<T>)>> {
        if chunk_idx >= self.chunks.len() {
            return Err(TensorError::invalid_argument(format!(
                "Chunk index {chunk_idx} out of bounds"
            )));
        }

        // Fast path: cached chunk. A poisoned cache lock silently skips the
        // cache rather than erroring.
        if let Ok(cache) = self.cache.lock() {
            if let Some(cached_data) = cache.get(&chunk_idx) {
                if let Ok(mut stats) = self.stats.write() {
                    stats.cache_hits += 1;
                }
                return Ok(cached_data.clone());
            }
        }

        let chunk = &self.chunks[chunk_idx];
        let start_time = std::time::Instant::now();

        let samples = self.load_chunk_from_disk(chunk)?;

        if let Ok(mut stats) = self.stats.write() {
            stats.cache_misses += 1;
            stats.bytes_read += chunk.end_offset - chunk.start_offset;
            stats.samples_processed += samples.len() as u64;
            // NOTE(review): overwrites rather than averages the processing
            // time, despite the field's "avg" name — confirm intent.
            stats.avg_processing_time_ms = start_time.elapsed().as_millis() as f64;
        }

        // Cache the chunk only when the memory monitor can account for it.
        if let Ok(mut cache) = self.cache.lock() {
            let data_size = self.estimate_sample_size(&samples);
            if let Ok(mut monitor) = self.memory_monitor.lock() {
                if monitor.allocate(data_size) {
                    cache.insert(chunk_idx, samples.clone());
                }
            }
        }

        Ok(samples)
    }

    /// Produces the samples for one chunk.
    ///
    /// Placeholder implementation: generates `chunk.num_samples` synthetic
    /// samples (10 default-valued features, 1 default-valued label) and
    /// never actually reads `chunk.file_path`.
    fn load_chunk_from_disk(&self, chunk: &ChunkMetadata) -> Result<Vec<(Tensor<T>, Tensor<T>)>> {
        let mut samples = Vec::new();

        for _i in 0..chunk.num_samples {
            let features = Tensor::from_vec(vec![T::default(); 10], &[10])?;
            let label = Tensor::from_vec(vec![T::default()], &[1])?;
            samples.push((features, label));
        }

        Ok(samples)
    }

    /// Rough byte estimate for a chunk: 10 feature elements plus 1 label
    /// element per sample (matching `load_chunk_from_disk`); tensor
    /// bookkeeping overhead is ignored.
    fn estimate_sample_size(&self, samples: &[(Tensor<T>, Tensor<T>)]) -> usize {
        samples.len() * (std::mem::size_of::<T>() * 11)
    }

    /// Spawns a background thread intended to prefetch the next chunk.
    ///
    /// Currently a stub: the thread computes the next chunk index and exits
    /// without loading anything.
    /// NOTE(review): `% chunks.len()` panics when there are no chunks —
    /// confirm this is never called on an empty dataset.
    pub fn prefetch_background(&self) {
        let _buffer = Arc::clone(&self.buffer);
        let chunks = self.chunks.clone();
        let current_chunk = self.current_chunk;
        let _config = self.config.clone();

        std::thread::spawn(move || {
            let _next_chunk = (current_chunk + 1) % chunks.len();
        });
    }

    /// Returns a snapshot of the current streaming statistics.
    pub fn get_stats(&self) -> Result<StreamingStats> {
        Ok(self
            .stats
            .read()
            .map_err(|_| TensorError::invalid_argument("Failed to read stats".to_string()))?
            .clone())
    }

    /// Like `load_chunk`, but moves each tensor pair to the configured GPU
    /// device. Falls back to the plain CPU path when GPU acceleration is
    /// disabled or no device is configured.
    pub fn load_chunk_gpu(&self, chunk_index: usize) -> Result<Vec<(Tensor<T>, Tensor<T>)>>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + scirs2_core::numeric::One
            + Send
            + Sync
            + 'static
            + bytemuck::Pod,
    {
        if !self.config.gpu_acceleration || self.config.device.is_none() {
            return self.load_chunk(chunk_index);
        }

        let chunk_data = self.load_chunk(chunk_index)?;
        // The is_none() check above makes this error effectively
        // unreachable; kept for defensive clarity.
        let device = self.config.device.as_ref().ok_or_else(|| {
            TensorError::invalid_argument(
                "GPU device not configured for streaming optimization".to_string(),
            )
        })?;

        let mut gpu_data = Vec::new();
        for (features, labels) in chunk_data {
            let gpu_features = features.to_device(*device)?;
            let gpu_labels = labels.to_device(*device)?;
            gpu_data.push((gpu_features, gpu_labels));
        }

        Ok(gpu_data)
    }

    /// Warms the cache for the given chunk indices.
    ///
    /// NOTE(review): even with `parallel_loading` enabled the chunks are
    /// loaded sequentially here, and the `samples_processed`/`cache_hits`
    /// increments do not reflect real cache behavior — confirm intent.
    /// Also note the two paths differ: the non-parallel path does not
    /// bounds-check indices before loading.
    pub fn prefetch_chunks_parallel(&self, chunk_indices: &[usize]) -> Result<()> {
        if !self.config.parallel_loading {
            for &index in chunk_indices {
                self.load_chunk(index)?;
            }
            return Ok(());
        }

        for &index in chunk_indices {
            if index < self.chunks.len() {
                self.load_chunk(index)?;

                if let Ok(mut stats) = self.stats.write() {
                    stats.samples_processed += 1;
                    stats.cache_hits += 1;
                }
            }
        }

        Ok(())
    }

    /// Aggregates stats and memory usage into a single metrics snapshot.
    pub fn get_performance_metrics(&self) -> Result<StreamingPerformanceMetrics> {
        let stats = self.get_stats()?;
        // A poisoned monitor lock degrades to reporting zero usage.
        let memory_usage = if let Ok(monitor) = self.memory_monitor.lock() {
            monitor.current_usage
        } else {
            0
        };

        let cache_hit_rate = if stats.cache_hits + stats.cache_misses > 0 {
            stats.cache_hits as f64 / (stats.cache_hits + stats.cache_misses) as f64
        } else {
            0.0
        };

        Ok(StreamingPerformanceMetrics {
            throughput_samples_per_second: stats.throughput_samples_per_second,
            memory_usage_bytes: memory_usage,
            cache_hit_rate,
            buffer_utilization: memory_usage as f64 / self.config.max_memory_bytes as f64,
            // Heuristic: assumes ~1000 samples per chunk; not exact.
            chunks_loaded: (stats.samples_processed / 1000) as usize,
            gpu_acceleration_active: self.config.gpu_acceleration,
            parallel_loading_active: self.config.parallel_loading,
        })
    }

    /// Resets all streaming statistics to their zero defaults.
    pub fn reset_stats(&self) -> Result<()> {
        let mut stats = self
            .stats
            .write()
            .map_err(|_| TensorError::invalid_argument("Failed to write stats".to_string()))?;
        *stats = StreamingStats::default();
        Ok(())
    }

    /// Returns `(current_usage, peak_usage, usage_ratio)` from the memory
    /// monitor.
    pub fn memory_usage(&self) -> Result<(usize, usize, f64)> {
        let monitor = self.memory_monitor.lock().map_err(|_| {
            TensorError::invalid_argument("Failed to lock memory monitor".to_string())
        })?;
        Ok((
            monitor.current_usage,
            monitor.peak_usage,
            monitor.usage_ratio(),
        ))
    }

    /// Clears the chunk cache and releases its accounted memory.
    ///
    /// NOTE(review): frees a flat 1000 bytes per cached chunk, which does
    /// not match the `estimate_sample_size` accounting used on insert, so
    /// the monitor's usage can drift upward — confirm.
    pub fn gc(&self) -> Result<()> {
        let mut cache = self
            .cache
            .lock()
            .map_err(|_| TensorError::invalid_argument("Failed to lock cache".to_string()))?;

        let mut monitor = self.memory_monitor.lock().map_err(|_| {
            TensorError::invalid_argument("Failed to lock memory monitor".to_string())
        })?;

        let freed_bytes = cache.len() * 1000;
        cache.clear();
        monitor.deallocate(freed_bytes);

        Ok(())
    }
}
551
552impl<T> Dataset<T> for StreamingOptimizedDataset<T>
553where
554 T: Clone + Default + Send + Sync + 'static,
555{
556 fn len(&self) -> usize {
557 self.sample_indices.len()
558 }
559
560 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
561 if index >= self.len() {
562 return Err(TensorError::invalid_argument(format!(
563 "Index {} out of bounds for dataset of length {}",
564 index,
565 self.len()
566 )));
567 }
568
569 let actual_index = self.sample_indices[index];
570
571 let mut cumulative_samples = 0;
573 let mut chunk_idx = 0;
574 let mut sample_in_chunk = actual_index;
575
576 for (i, chunk) in self.chunks.iter().enumerate() {
577 if actual_index < cumulative_samples + chunk.num_samples {
578 chunk_idx = i;
579 sample_in_chunk = actual_index - cumulative_samples;
580 break;
581 }
582 cumulative_samples += chunk.num_samples;
583 }
584
585 let chunk_data = self.load_chunk(chunk_idx)?;
587
588 if sample_in_chunk >= chunk_data.len() {
589 return Err(TensorError::invalid_argument(format!(
590 "Sample index {sample_in_chunk} out of bounds in chunk"
591 )));
592 }
593
594 Ok(chunk_data[sample_in_chunk].clone())
595 }
596}
597
/// Iterator over a shared `StreamingOptimizedDataset`, yielding samples in
/// logical index order with optional periodic background prefetch.
pub struct StreamingOptimizedIterator<T> {
    /// Shared dataset handle.
    dataset: Arc<StreamingOptimizedDataset<T>>,
    /// Next logical index to yield.
    current_index: usize,
    /// Whether to trigger `prefetch_background` every 1000 samples.
    prefetch_enabled: bool,
}
604
605impl<T> StreamingOptimizedIterator<T>
606where
607 T: Clone + Default + Send + Sync + 'static,
608{
609 pub fn new(dataset: Arc<StreamingOptimizedDataset<T>>) -> Self {
610 Self {
611 dataset,
612 current_index: 0,
613 prefetch_enabled: true,
614 }
615 }
616
617 pub fn with_prefetch(mut self, enabled: bool) -> Self {
618 self.prefetch_enabled = enabled;
619 self
620 }
621}
622
623impl<T> Iterator for StreamingOptimizedIterator<T>
624where
625 T: Clone + Default + Send + Sync + 'static,
626{
627 type Item = Result<(Tensor<T>, Tensor<T>)>;
628
629 fn next(&mut self) -> Option<Self::Item> {
630 if self.current_index >= self.dataset.len() {
631 return None;
632 }
633
634 if self.prefetch_enabled && self.current_index % 1000 == 0 {
636 self.dataset.prefetch_background();
637 }
638
639 let result = self.dataset.get(self.current_index);
640 self.current_index += 1;
641 Some(result)
642 }
643}
644
/// Fluent builder for [`StreamingOptimizedDataset`].
pub struct StreamingOptimizedDatasetBuilder<T> {
    /// Input files to stream from.
    file_paths: Vec<PathBuf>,
    /// Configuration accumulated by the setter methods.
    config: StreamingOptimizedConfig,
    /// Ties the builder to the element type without storing one.
    _phantom: std::marker::PhantomData<T>,
}
651
652impl<T> StreamingOptimizedDatasetBuilder<T>
653where
654 T: Clone + Default + Send + Sync + 'static,
655{
656 pub fn new() -> Self {
657 Self {
658 file_paths: Vec::new(),
659 config: StreamingOptimizedConfig::default(),
660 _phantom: std::marker::PhantomData,
661 }
662 }
663
664 pub fn add_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
665 self.file_paths.push(path.into());
666 self
667 }
668
669 pub fn add_files<P: Into<PathBuf>>(mut self, paths: Vec<P>) -> Self {
670 self.file_paths.extend(paths.into_iter().map(|p| p.into()));
671 self
672 }
673
674 pub fn buffer_size(mut self, size: usize) -> Self {
675 self.config.buffer_size = size;
676 self
677 }
678
679 pub fn num_workers(mut self, workers: usize) -> Self {
680 self.config.num_workers = workers;
681 self
682 }
683
684 pub fn max_memory(mut self, bytes: usize) -> Self {
685 self.config.max_memory_bytes = bytes;
686 self
687 }
688
689 pub fn chunk_size(mut self, size: usize) -> Self {
690 self.config.chunk_size = size;
691 self
692 }
693
694 pub fn shuffle(mut self, enabled: bool) -> Self {
695 self.config.shuffle_chunks = enabled;
696 self
697 }
698
699 pub fn seed(mut self, seed: u64) -> Self {
700 self.config.seed = Some(seed);
701 self
702 }
703
704 pub fn compression(mut self, compression: CompressionType) -> Self {
705 self.config.compression_type = compression;
706 self
707 }
708
709 pub fn adaptive_buffering(mut self, enabled: bool) -> Self {
710 self.config.adaptive_buffering = enabled;
711 self
712 }
713
714 pub fn build(self) -> Result<StreamingOptimizedDataset<T>> {
715 if self.file_paths.is_empty() {
716 return Err(TensorError::invalid_argument(
717 "No file paths provided".to_string(),
718 ));
719 }
720
721 StreamingOptimizedDataset::from_files(self.file_paths, self.config)
722 }
723}
724
impl<T> Default for StreamingOptimizedDatasetBuilder<T>
where
    T: Clone + Default + Send + Sync + 'static,
{
    /// Equivalent to [`StreamingOptimizedDatasetBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
733
/// Point-in-time snapshot of streaming pipeline performance.
///
/// `Default` is derived: the previous hand-written impl produced exactly
/// the derived all-zero/false value.
#[derive(Debug, Clone, Default)]
pub struct StreamingPerformanceMetrics {
    /// Observed sample throughput.
    pub throughput_samples_per_second: f64,
    /// Bytes currently accounted by the memory monitor.
    pub memory_usage_bytes: usize,
    /// cache_hits / (cache_hits + cache_misses), or 0.0 with no lookups.
    pub cache_hit_rate: f64,
    /// Fraction of the configured memory budget in use.
    pub buffer_utilization: f64,
    /// Estimated number of chunks loaded so far (heuristic).
    pub chunks_loaded: usize,
    /// Whether GPU acceleration is enabled in the config.
    pub gpu_acceleration_active: bool,
    /// Whether parallel loading is enabled in the config.
    pub parallel_loading_active: bool,
}
766
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Push/pop round-trip through the adaptive buffer.
    #[test]
    fn test_adaptive_buffer() {
        let mut buf = AdaptiveBuffer::<f32>::new(100);

        let item = (
            Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("test: tensor creation should succeed"),
            Tensor::from_vec(vec![0.0], &[1]).expect("test: tensor creation should succeed"),
        );

        assert!(buf.push(item.clone()));
        assert_eq!(buf.len(), 1);

        let _popped = buf.pop().expect("test: operation should succeed");
        assert_eq!(buf.len(), 0);
    }

    /// Builder accumulates files and settings without touching the data.
    #[test]
    fn test_streaming_dataset_builder() {
        let tmp = TempDir::new().expect("test: temp dir creation should succeed");
        let data_file = tmp.path().join("test.dat");
        fs::write(&data_file, b"dummy data").expect("test: write should succeed");

        let builder = StreamingOptimizedDatasetBuilder::<f32>::new()
            .add_file(data_file)
            .buffer_size(50)
            .chunk_size(1000)
            .shuffle(true);

        assert_eq!(builder.file_paths.len(), 1);
    }

    /// Budget enforcement and saturating release in the memory monitor.
    #[test]
    fn test_memory_monitor() {
        let mut monitor = MemoryMonitor::new(1000);

        assert!(monitor.allocate(500));
        assert_eq!(monitor.current_usage, 500);
        assert!(monitor.allocate(400));
        assert_eq!(monitor.current_usage, 900);

        // 900 + 200 would exceed the 1000-byte budget.
        assert!(!monitor.allocate(200));

        monitor.deallocate(300);
        assert_eq!(monitor.current_usage, 600);
    }

    /// Default stats start at zero.
    #[test]
    fn test_streaming_stats() {
        let defaults = StreamingStats::default();
        assert_eq!(defaults.samples_processed, 0);
        assert_eq!(defaults.cache_hits, 0);
    }
}
829}