use crate::data::Dataset;
use crate::error::{NeuralError, Result};
use scirs2_core::chunking::{ChunkConfig, ChunkStrategy, ChunkingUtils};
use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::NumAssign;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
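/// A single assembled batch: `(inputs, targets)` with the batch axis first.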
type BatchPair<F> = (Array<F, IxDyn>, Array<F, IxDyn>);
/// Configuration for [`OptimizedDataLoader`].
#[derive(Debug, Clone)]
pub struct OptimizedLoaderConfig {
/// Number of samples per batch.
pub batch_size: usize,
/// Maximum number of batches the prefetch queue may hold.
pub prefetch_size: usize,
/// Reserved for future multi-worker loading; prefetching currently uses a single background thread.
pub num_workers: usize,
/// Drop the final batch when the dataset size is not a multiple of `batch_size`.
pub drop_last: bool,
/// Shuffle sample indices on every [`OptimizedDataLoader::reset`].
pub shuffle: bool,
/// Reserved for future pinned-memory transfers; currently unused.
pub pin_memory: bool,
/// Cache assembled batches for reuse across epochs.
pub cache_batches: bool,
/// Memory budget for the batch cache in bytes; `0` means unlimited.
pub max_cache_memory: usize,
}
impl Default for OptimizedLoaderConfig {
fn default() -> Self {
Self {
batch_size: 32,
prefetch_size: 2,
num_workers: 0,
drop_last: false,
shuffle: true,
pin_memory: false,
cache_batches: false,
max_cache_memory: 0,
}
}
}
/// Aggregate timing and cache statistics collected while loading batches.
#[derive(Debug, Clone, Default)]
pub struct LoadingStats {
/// Number of batches materialized so far.
pub batches_loaded: usize,
/// Number of samples contained in those batches.
pub samples_loaded: usize,
/// Cumulative wall-clock time spent assembling batches.
pub total_load_time: Duration,
/// `total_load_time / batches_loaded`, refreshed after every batch.
pub avg_batch_time: Duration,
/// Batches served straight from the in-memory cache.
pub cache_hits: usize,
/// Batches that had to be assembled from the dataset.
pub cache_misses: usize,
/// Time consumers spent blocked on the prefetch queue.
pub prefetch_wait_time: Duration,
}
/// The fallible result of loading one [`BatchPair`].
pub type BatchResult<F> = Result<BatchPair<F>>;
struct BatchCache<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> {
cache: Vec<Option<BatchPair<F>>>,
max_memory: usize,
memory_usage: usize,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> BatchCache<F> {
fn new(max_batches: usize, max_memory: usize) -> Self {
Self {
cache: vec![None; max_batches],
max_memory,
memory_usage: 0,
}
}
fn get(&self, index: usize) -> Option<&BatchPair<F>> {
self.cache.get(index).and_then(|slot| slot.as_ref())
}
fn insert(&mut self, index: usize, batch: BatchPair<F>) {
if index >= self.cache.len() {
return;
}
// Release the memory attributed to any batch being overwritten so the
// running total stays accurate across repeated inserts.
if let Some(old) = self.cache[index].take() {
self.memory_usage -= estimate_array_memory(&old.0) + estimate_array_memory(&old.1);
}
let batch_size = estimate_array_memory(&batch.0) + estimate_array_memory(&batch.1);
// A `max_memory` of 0 means "no limit"; otherwise skip caching rather
// than exceed the budget.
if self.max_memory > 0 && self.memory_usage + batch_size > self.max_memory {
return;
}
self.memory_usage += batch_size;
self.cache[index] = Some(batch);
}
fn clear(&mut self) {
self.cache.iter_mut().for_each(|b| *b = None);
self.memory_usage = 0;
}
}
fn estimate_array_memory<F: Float + NumAssign>(array: &Array<F, IxDyn>) -> usize {
array.len() * std::mem::size_of::<F>()
}
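// Bounded, index-tagged handoff queue shared by one prefetch worker
// (producer) and one iterator (consumer). The `stop` flag lets either side
// shut the pipeline down without stranding the other in a busy wait.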
struct PrefetchQueue<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> {
queue: Mutex<VecDeque<(usize, BatchResult<F>)>>,
max_size: usize,
size: AtomicUsize,
stop: AtomicBool,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> PrefetchQueue<F> {
fn new(max_size: usize) -> Self {
Self {
queue: Mutex::new(VecDeque::with_capacity(max_size)),
max_size,
size: AtomicUsize::new(0),
stop: AtomicBool::new(false),
}
}
fn push(&self, index: usize, batch: BatchResult<F>) -> bool {
if self.stop.load(Ordering::Relaxed) {
return false;
}
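// Spin (with a short sleep) while the queue is at capacity; the consumer
// popping an item, or a stop signal, unblocks the producer.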
while self.size.load(Ordering::Relaxed) >= self.max_size {
if self.stop.load(Ordering::Relaxed) {
return false;
}
thread::sleep(Duration::from_micros(100));
}
let mut queue = match self.queue.lock() {
Ok(q) => q,
Err(_) => return false,
};
queue.push_back((index, batch));
self.size.fetch_add(1, Ordering::Relaxed);
true
}
fn pop(&self) -> Option<(usize, BatchResult<F>)> {
let mut queue = match self.queue.lock() {
Ok(q) => q,
Err(_) => return None,
};
let result = queue.pop_front();
if result.is_some() {
self.size.fetch_sub(1, Ordering::Relaxed);
}
result
}
fn stop(&self) {
self.stop.store(true, Ordering::Relaxed);
}
fn is_empty(&self) -> bool {
self.size.load(Ordering::Relaxed) == 0
}
}
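/// Batch loader over a [`Dataset`] with optional index shuffling, whole-batch
/// caching, and background prefetching via [`OptimizedDataLoader::prefetch_iter`].
///
/// A minimal usage sketch; `my_dataset` stands in for any [`Dataset`]
/// implementor (illustrative only, hence `ignore`):
///
/// ```ignore
/// let config = OptimizedLoaderConfig {
///     batch_size: 64,
///     shuffle: true,
///     ..Default::default()
/// };
/// let mut loader = OptimizedDataLoader::new(my_dataset, config);
/// loader.reset(); // reshuffle (if enabled) and rewind to batch 0
/// while let Some(batch) = loader.next_batch() {
///     let (x, y) = batch?;
///     // run one training step on (x, y)
/// }
/// ```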
pub struct OptimizedDataLoader<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync,
D: Dataset<F> + Send + Sync + Clone + 'static,
> {
dataset: Arc<D>,
config: OptimizedLoaderConfig,
indices: Vec<usize>,
position: AtomicUsize,
num_batches: usize,
cache: Option<Mutex<BatchCache<F>>>,
stats: Mutex<LoadingStats>,
_phantom: PhantomData<F>,
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> OptimizedDataLoader<F, D>
{
pub fn new(dataset: D, config: OptimizedLoaderConfig) -> Self {
let dataset_len = dataset.len();
let batch_size = config.batch_size;
let drop_last = config.drop_last;
let num_batches = if drop_last {
dataset_len / batch_size
} else {
dataset_len.div_ceil(batch_size)
};
let indices: Vec<usize> = (0..dataset_len).collect();
let cache = if config.cache_batches {
Some(Mutex::new(BatchCache::new(num_batches, config.max_cache_memory)))
} else {
None
};
Self {
dataset: Arc::new(dataset),
config,
indices,
position: AtomicUsize::new(0),
num_batches,
cache,
stats: Mutex::new(LoadingStats::default()),
_phantom: PhantomData,
}
}
pub fn reset(&mut self) {
if self.config.shuffle {
let mut rng = scirs2_core::random::rng();
self.indices.shuffle(&mut rng);
// Cached batches are keyed by batch index over the previous ordering,
// so shuffling invalidates them.
if let Some(ref cache) = self.cache {
if let Ok(mut cache_guard) = cache.lock() {
cache_guard.clear();
}
}
}
self.position.store(0, Ordering::Relaxed);
}
pub fn num_batches(&self) -> usize {
self.num_batches
}
pub fn len(&self) -> usize {
self.dataset.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn stats(&self) -> LoadingStats {
self.stats
.lock()
.map_or_else(|_| LoadingStats::default(), |s| s.clone())
}
fn load_batch(&self, batch_idx: usize) -> BatchResult<F> {
let start = batch_idx * self.config.batch_size;
let end = (start + self.config.batch_size).min(self.indices.len());
if start >= self.indices.len() {
return Err(NeuralError::TrainingError(
"Batch index out of range".to_string(),
));
}
let batch_indices: Vec<usize> = self.indices[start..end].to_vec();
if batch_indices.is_empty() {
return Err(NeuralError::TrainingError("Empty batch".to_string()));
}
let (first_x, first_y) = self.dataset.get(batch_indices[0])?;
let batch_x_shape: Vec<usize> = std::iter::once(batch_indices.len())
.chain(first_x.shape().iter().copied())
.collect();
let batch_y_shape: Vec<usize> = std::iter::once(batch_indices.len())
.chain(first_y.shape().iter().copied())
.collect();
let mut batch_x = Array::zeros(IxDyn(&batch_x_shape));
let mut batch_y = Array::zeros(IxDyn(&batch_y_shape));
for (i, &idx) in batch_indices.iter().enumerate() {
let (x, y) = self.dataset.get(idx)?;
// `index_axis_mut` selects sample slot `i` for samples of any rank,
// whereas a fixed two-axis `s![i, ..]` slice panics whenever a sample has
// more than one trailing dimension.
batch_x.index_axis_mut(Axis(0), i).assign(&x);
batch_y.index_axis_mut(Axis(0), i).assign(&y);
}
Ok((batch_x, batch_y))
}
pub fn next_batch(&self) -> Option<BatchResult<F>> {
let batch_idx = self.position.fetch_add(1, Ordering::Relaxed);
if batch_idx >= self.num_batches {
return None;
}
if let Some(ref cache) = self.cache {
if let Ok(cache_guard) = cache.lock() {
if let Some(batch) = cache_guard.get(batch_idx) {
if let Ok(mut stats) = self.stats.lock() {
stats.cache_hits += 1;
}
return Some(Ok((batch.0.clone(), batch.1.clone())));
}
}
}
let start = Instant::now();
let result = self.load_batch(batch_idx);
let load_time = start.elapsed();
if let Ok(mut stats) = self.stats.lock() {
stats.batches_loaded += 1;
stats.samples_loaded += self.config.batch_size.min(
self.indices
.len()
.saturating_sub(batch_idx * self.config.batch_size),
);
stats.total_load_time += load_time;
stats.avg_batch_time = stats.total_load_time / stats.batches_loaded as u32;
stats.cache_misses += 1;
}
if let Some(ref cache) = self.cache {
if let Ok(ref batch) = result {
if let Ok(mut cache_guard) = cache.lock() {
cache_guard.insert(batch_idx, (batch.0.clone(), batch.1.clone()));
}
}
}
Some(result)
}
pub fn prefetch_iter(self) -> PrefetchingIterator<F, D> {
PrefetchingIterator::new(self)
}
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> Iterator for OptimizedDataLoader<F, D>
{
type Item = BatchResult<F>;
fn next(&mut self) -> Option<Self::Item> {
self.next_batch()
}
}
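/// Iterator that overlaps batch assembly with consumption: a single worker
/// thread loads batches in order onto a bounded queue while the consumer pops
/// them, blocking only when the queue runs dry.
///
/// A sketch of the intended call pattern (illustrative only):
///
/// ```ignore
/// let loader = OptimizedDataLoader::new(my_dataset, OptimizedLoaderConfig::default());
/// for batch in loader.prefetch_iter() {
///     let (x, y) = batch?;
///     // consume (x, y) while the worker prepares the next batch
/// }
/// ```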
pub struct PrefetchingIterator<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> {
loader: Arc<OptimizedDataLoader<F, D>>,
queue: Arc<PrefetchQueue<F>>,
worker_handle: Option<thread::JoinHandle<()>>,
expected_idx: usize,
buffer: VecDeque<(usize, BatchResult<F>)>,
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> PrefetchingIterator<F, D>
{
fn new(loader: OptimizedDataLoader<F, D>) -> Self {
let prefetch_size = loader.config.prefetch_size;
let loader = Arc::new(loader);
let queue = Arc::new(PrefetchQueue::new(prefetch_size));
let worker_loader = Arc::clone(&loader);
let worker_queue = Arc::clone(&queue);
let worker_handle = thread::spawn(move || {
let mut batch_idx = 0;
loop {
if worker_queue.stop.load(Ordering::Relaxed) {
break;
}
if batch_idx >= worker_loader.num_batches {
break;
}
let result = worker_loader.load_batch(batch_idx);
if !worker_queue.push(batch_idx, result) {
break;
}
batch_idx += 1;
}
});
Self {
loader,
queue,
worker_handle: Some(worker_handle),
expected_idx: 0,
buffer: VecDeque::new(),
}
}
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> Iterator for PrefetchingIterator<F, D>
{
type Item = BatchResult<F>;
fn next(&mut self) -> Option<Self::Item> {
if self.expected_idx >= self.loader.num_batches {
return None;
}
if let Some(pos) = self
.buffer
.iter()
.position(|(idx, _)| *idx == self.expected_idx)
{
let (_, result) = self.buffer.remove(pos).expect("Position was just found");
self.expected_idx += 1;
return Some(result);
}
let wait_start = Instant::now();
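// Drain the queue until the batch with the expected index arrives; batches
// arriving out of order are parked in `buffer` until their turn.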
loop {
if let Some((idx, result)) = self.queue.pop() {
if idx == self.expected_idx {
self.expected_idx += 1;
if let Ok(mut stats) = self.loader.stats.lock() {
stats.prefetch_wait_time += wait_start.elapsed();
}
return Some(result);
} else {
self.buffer.push_back((idx, result));
}
} else if self.queue.is_empty() && self.queue.stop.load(Ordering::Relaxed) {
return None;
} else {
thread::sleep(Duration::from_micros(10));
}
}
}
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> Drop for PrefetchingIterator<F, D>
{
fn drop(&mut self) {
self.queue.stop();
if let Some(handle) = self.worker_handle.take() {
let _ = handle.join();
}
}
}
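/// Outcome of a [`BatchSizeOptimizer`] sweep: the recommended batch size plus
/// the per-size throughput (samples/second) and memory (bytes per batch)
/// measurements gathered along the way.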
#[derive(Debug, Clone)]
pub struct BatchSizeOptimizationResult {
pub recommended_batch_size: usize,
pub throughput_results: Vec<(usize, f64)>,
pub memory_results: Vec<(usize, usize)>,
pub memory_limited: bool,
}
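/// Sweeps candidate batch sizes (doubling upward from `min_batch_size`) and
/// times a few batches at each size to recommend the highest-throughput
/// setting, optionally stopping early at a per-batch memory ceiling.
///
/// A sketch of a typical sweep (illustrative only):
///
/// ```ignore
/// let result = BatchSizeOptimizer::new()
///     .with_range(16, 256)
///     .with_max_memory(256 * 1024 * 1024) // 256 MiB per batch
///     .find_optimal(my_dataset)?;
/// println!("use batch_size = {}", result.recommended_batch_size);
/// ```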
pub struct BatchSizeOptimizer {
min_batch_size: usize,
max_batch_size: usize,
warmup_batches: usize,
timing_batches: usize,
max_memory: usize,
}
impl Default for BatchSizeOptimizer {
fn default() -> Self {
Self {
min_batch_size: 8,
max_batch_size: 512,
warmup_batches: 2,
timing_batches: 5,
max_memory: 0,
}
}
}
impl BatchSizeOptimizer {
pub fn new() -> Self {
Self::default()
}
pub fn with_range(mut self, min: usize, max: usize) -> Self {
self.min_batch_size = min;
self.max_batch_size = max;
self
}
pub fn with_max_memory(mut self, max_memory: usize) -> Self {
self.max_memory = max_memory;
self
}
pub fn find_optimal<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
>(
&self,
dataset: D,
) -> Result<BatchSizeOptimizationResult> {
let mut throughput_results = Vec::new();
let mut memory_results = Vec::new();
let mut best_throughput = 0.0;
let mut best_batch_size = self.min_batch_size;
let mut memory_limited = false;
let mut batch_size = self.min_batch_size;
while batch_size <= self.max_batch_size && batch_size <= dataset.len() {
let config = OptimizedLoaderConfig {
batch_size,
shuffle: false,
drop_last: true,
..Default::default()
};
let mut loader = OptimizedDataLoader::new(dataset.clone(), config);
loader.reset();
for _ in 0..self.warmup_batches {
if loader.next_batch().is_none() {
break;
}
}
let start = Instant::now();
let mut batches_processed = 0;
let mut total_memory = 0;
for _ in 0..self.timing_batches {
match loader.next_batch() {
Some(Ok((x, y))) => {
batches_processed += 1;
total_memory += estimate_array_memory(&x) + estimate_array_memory(&y);
}
Some(Err(_)) => break,
None => break,
}
}
if batches_processed == 0 {
break;
}
// Guard against a zero-length measurement yielding an infinite rate.
let elapsed = start.elapsed().as_secs_f64().max(f64::EPSILON);
let samples_per_second = (batches_processed * batch_size) as f64 / elapsed;
let avg_memory = total_memory / batches_processed;
throughput_results.push((batch_size, samples_per_second));
memory_results.push((batch_size, avg_memory));
if self.max_memory > 0 && avg_memory > self.max_memory {
memory_limited = true;
break;
}
if samples_per_second > best_throughput {
best_throughput = samples_per_second;
best_batch_size = batch_size;
}
// Double the batch size each iteration; capping at `max_batch_size + 1`
// (rather than `max_batch_size`) guarantees the loop condition eventually
// fails instead of re-testing the cap forever.
batch_size = (batch_size * 2).min(self.max_batch_size + 1);
}
Ok(BatchSizeOptimizationResult {
recommended_batch_size: best_batch_size,
throughput_results,
memory_results,
memory_limited,
})
}
}
/// Configuration for [`MemoryAwareDataLoader`].
#[derive(Debug, Clone)]
pub struct MemoryAwareConfig {
/// Fraction of the estimated available memory one batch may occupy.
pub target_memory_fraction: f64,
/// Per-sample footprint in bytes; measured from sample 0 when `None`.
pub bytes_per_sample: Option<usize>,
/// Lower bound on the adaptive batch size.
pub min_batch_size: usize,
/// Upper bound on the adaptive batch size.
pub max_batch_size: usize,
/// Shuffle sample indices on every [`MemoryAwareDataLoader::reset`].
pub shuffle: bool,
/// Drop the final partial batch.
pub drop_last: bool,
/// Maximum number of batches the prefetch worker may queue ahead.
pub prefetch_ahead: usize,
}
impl Default for MemoryAwareConfig {
fn default() -> Self {
Self {
target_memory_fraction: 0.25,
bytes_per_sample: None,
min_batch_size: 4,
max_batch_size: 4096,
shuffle: true,
drop_last: false,
prefetch_ahead: 2,
}
}
}
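// Best-effort estimate of available system memory: on Linux this reads
// `MemAvailable` from /proc/meminfo; elsewhere (or if the read/parse fails)
// it falls back to a conservative 512 MiB.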
fn estimate_available_memory_bytes() -> usize {
#[cfg(target_os = "linux")]
{
if let Ok(contents) = std::fs::read_to_string("/proc/meminfo") {
for line in contents.lines() {
if line.starts_with("MemAvailable:") {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 2 {
if let Ok(kb) = parts[1].parse::<usize>() {
return kb * 1024;
}
}
}
}
}
}
512 * 1024 * 1024
}
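// The adaptive batch size is the smaller of (a) the chunking hint from
// scirs2_core and (b) how many samples fit in the configured fraction of
// available memory, clamped to [min_batch_size, max_batch_size]. For example,
// with 1 GiB available, target_memory_fraction = 0.25 and 1 KiB per sample,
// the memory budget alone would allow 268_435_456 / 1024 = 262_144 samples,
// so the chunking hint and the max_batch_size clamp usually dominate.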
fn compute_adaptive_batch_size(
dataset_len: usize,
bytes_per_sample: usize,
config: &MemoryAwareConfig,
) -> usize {
let chunk_cfg = ChunkConfig {
strategy: ChunkStrategy::Adaptive,
min_chunk_size: config.min_batch_size,
max_chunk_size: config.max_batch_size,
..ChunkConfig::default()
};
let chunking_hint = ChunkingUtils::optimal_chunk_size(dataset_len, &chunk_cfg);
let available = estimate_available_memory_bytes();
let budget_bytes = ((available as f64) * config.target_memory_fraction) as usize;
let budget_samples = budget_bytes
.checked_div(bytes_per_sample)
.map(|v| v.max(1))
.unwrap_or(config.max_batch_size);
let raw = chunking_hint.min(budget_samples);
raw.max(config.min_batch_size).min(config.max_batch_size)
}
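/// Loader that chooses its batch size at construction time from the estimated
/// available memory (see `compute_adaptive_batch_size`) rather than taking a
/// fixed `batch_size`.
///
/// A minimal sketch (illustrative only):
///
/// ```ignore
/// let config = MemoryAwareConfig {
///     target_memory_fraction: 0.25,
///     ..Default::default()
/// };
/// let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(my_dataset, config)?;
/// println!("adaptive batch size: {}", loader.adaptive_batch_size());
/// loader.reset();
/// while let Some(batch) = loader.next_batch() {
///     let (x, y) = batch?;
///     // training step
/// }
/// ```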
pub struct MemoryAwareDataLoader<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> {
dataset: Arc<D>,
config: MemoryAwareConfig,
indices: Vec<usize>,
position: AtomicUsize,
batch_size: usize,
num_batches: usize,
stats: Mutex<LoadingStats>,
_phantom: PhantomData<F>,
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> MemoryAwareDataLoader<F, D>
{
pub fn new_adaptive(dataset: D, config: MemoryAwareConfig) -> Result<Self> {
let dataset_len = dataset.len();
if dataset_len == 0 {
return Err(NeuralError::TrainingError(
"Cannot create MemoryAwareDataLoader from an empty dataset".to_string(),
));
}
let bytes_per_sample = match config.bytes_per_sample {
Some(b) => b,
None => {
let (x0, y0) = dataset.get(0)?;
(x0.len() + y0.len()) * std::mem::size_of::<F>()
}
};
let batch_size = compute_adaptive_batch_size(dataset_len, bytes_per_sample, &config);
let num_batches = if config.drop_last {
dataset_len / batch_size
} else {
dataset_len.div_ceil(batch_size)
};
let indices: Vec<usize> = (0..dataset_len).collect();
Ok(Self {
dataset: Arc::new(dataset),
config,
indices,
position: AtomicUsize::new(0),
batch_size,
num_batches,
stats: Mutex::new(LoadingStats::default()),
_phantom: PhantomData,
})
}
pub fn refresh_batch_size(&mut self) -> Result<usize> {
let dataset_len = self.dataset.len();
let bytes_per_sample = match self.config.bytes_per_sample {
Some(b) => b,
None => {
let (x0, y0) = self.dataset.get(0)?;
(x0.len() + y0.len()) * std::mem::size_of::<F>()
}
};
let new_batch_size =
compute_adaptive_batch_size(dataset_len, bytes_per_sample, &self.config);
self.batch_size = new_batch_size;
self.num_batches = if self.config.drop_last {
dataset_len / new_batch_size
} else {
dataset_len.div_ceil(new_batch_size)
};
Ok(new_batch_size)
}
pub fn adaptive_batch_size(&self) -> usize {
self.batch_size
}
pub fn num_batches(&self) -> usize {
self.num_batches
}
pub fn len(&self) -> usize {
self.dataset.len()
}
pub fn is_empty(&self) -> bool {
self.dataset.len() == 0
}
pub fn stats(&self) -> LoadingStats {
self.stats
.lock()
.map_or_else(|_| LoadingStats::default(), |s| s.clone())
}
pub fn reset(&mut self) {
if self.config.shuffle {
let mut rng = scirs2_core::random::rng();
self.indices.shuffle(&mut rng);
}
self.position.store(0, Ordering::Relaxed);
}
fn load_batch(&self, batch_idx: usize) -> BatchResult<F> {
let start = batch_idx * self.batch_size;
let end = (start + self.batch_size).min(self.indices.len());
if start >= self.indices.len() {
return Err(NeuralError::TrainingError(
"Batch index out of range".to_string(),
));
}
let batch_indices: Vec<usize> = self.indices[start..end].to_vec();
if batch_indices.is_empty() {
return Err(NeuralError::TrainingError("Empty batch".to_string()));
}
let (first_x, first_y) = self.dataset.get(batch_indices[0])?;
let batch_x_shape: Vec<usize> = std::iter::once(batch_indices.len())
.chain(first_x.shape().iter().copied())
.collect();
let batch_y_shape: Vec<usize> = std::iter::once(batch_indices.len())
.chain(first_y.shape().iter().copied())
.collect();
let mut batch_x = Array::zeros(IxDyn(&batch_x_shape));
let mut batch_y = Array::zeros(IxDyn(&batch_y_shape));
for (i, &idx) in batch_indices.iter().enumerate() {
let (x, y) = self.dataset.get(idx)?;
// As in `OptimizedDataLoader::load_batch`, index the batch axis so samples
// of any rank are copied correctly.
batch_x.index_axis_mut(Axis(0), i).assign(&x);
batch_y.index_axis_mut(Axis(0), i).assign(&y);
}
Ok((batch_x, batch_y))
}
pub fn next_batch(&self) -> Option<BatchResult<F>> {
let batch_idx = self.position.fetch_add(1, Ordering::Relaxed);
if batch_idx >= self.num_batches {
return None;
}
let start_time = Instant::now();
let result = self.load_batch(batch_idx);
let elapsed = start_time.elapsed();
if let Ok(mut stats) = self.stats.lock() {
stats.batches_loaded += 1;
stats.samples_loaded += self.batch_size.min(
self.indices
.len()
.saturating_sub(batch_idx * self.batch_size),
);
stats.total_load_time += elapsed;
stats.avg_batch_time = stats.total_load_time / stats.batches_loaded as u32;
stats.cache_misses += 1;
}
Some(result)
}
pub fn into_prefetch_iter(self) -> MemoryAwarePrefetchIter<F, D> {
MemoryAwarePrefetchIter::new(self)
}
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> Iterator for MemoryAwareDataLoader<F, D>
{
type Item = BatchResult<F>;
fn next(&mut self) -> Option<Self::Item> {
self.next_batch()
}
}
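/// Prefetching counterpart of [`PrefetchingIterator`] for
/// [`MemoryAwareDataLoader`], obtained via
/// [`MemoryAwareDataLoader::into_prefetch_iter`].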
pub struct MemoryAwarePrefetchIter<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> {
loader: Arc<MemoryAwareDataLoader<F, D>>,
queue: Arc<PrefetchQueue<F>>,
worker: Option<thread::JoinHandle<()>>,
expected_idx: usize,
out_of_order: VecDeque<(usize, BatchResult<F>)>,
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> MemoryAwarePrefetchIter<F, D>
{
fn new(loader: MemoryAwareDataLoader<F, D>) -> Self {
let prefetch_ahead = loader.config.prefetch_ahead;
let num_batches = loader.num_batches;
let loader = Arc::new(loader);
let queue = Arc::new(PrefetchQueue::new(prefetch_ahead));
let worker_loader = Arc::clone(&loader);
let worker_queue = Arc::clone(&queue);
let worker = thread::spawn(move || {
for batch_idx in 0..num_batches {
if worker_queue.stop.load(Ordering::Relaxed) {
break;
}
let result = worker_loader.load_batch(batch_idx);
if !worker_queue.push(batch_idx, result) {
break;
}
}
});
Self {
loader,
queue,
worker: Some(worker),
expected_idx: 0,
out_of_order: VecDeque::new(),
}
}
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> Iterator for MemoryAwarePrefetchIter<F, D>
{
type Item = BatchResult<F>;
fn next(&mut self) -> Option<Self::Item> {
if self.expected_idx >= self.loader.num_batches {
return None;
}
if let Some(pos) = self
.out_of_order
.iter()
.position(|(idx, _)| *idx == self.expected_idx)
{
let (_, result) = self
.out_of_order
.remove(pos)
.expect("position was just found in out_of_order buffer");
self.expected_idx += 1;
return Some(result);
}
let wait_start = Instant::now();
loop {
if let Some((idx, result)) = self.queue.pop() {
if idx == self.expected_idx {
if let Ok(mut stats) = self.loader.stats.lock() {
stats.prefetch_wait_time += wait_start.elapsed();
}
self.expected_idx += 1;
return Some(result);
}
self.out_of_order.push_back((idx, result));
} else if self.queue.is_empty() && self.queue.stop.load(Ordering::Relaxed) {
return None;
} else {
thread::sleep(Duration::from_micros(10));
}
}
}
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
D: Dataset<F> + Send + Sync + Clone + 'static,
> Drop for MemoryAwarePrefetchIter<F, D>
{
fn drop(&mut self) {
self.queue.stop();
if let Some(handle) = self.worker.take() {
let _ = handle.join();
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::data::InMemoryDataset;
fn create_test_dataset() -> InMemoryDataset<f64> {
let features = Array::zeros(IxDyn(&[100, 10]));
let labels = Array::zeros(IxDyn(&[100, 2]));
InMemoryDataset::new(features, labels).expect("Failed to create test dataset")
}
#[test]
fn test_optimized_loader_config_default() {
let config = OptimizedLoaderConfig::default();
assert_eq!(config.batch_size, 32);
assert_eq!(config.prefetch_size, 2);
assert_eq!(config.num_workers, 0);
assert!(!config.drop_last);
assert!(config.shuffle);
}
#[test]
fn test_optimized_dataloader_creation() {
let dataset = create_test_dataset();
let config = OptimizedLoaderConfig {
batch_size: 10,
shuffle: false,
..Default::default()
};
let loader = OptimizedDataLoader::new(dataset, config);
assert_eq!(loader.len(), 100);
assert_eq!(loader.num_batches(), 10);
}
#[test]
fn test_optimized_dataloader_iteration() {
let dataset = create_test_dataset();
let config = OptimizedLoaderConfig {
batch_size: 10,
shuffle: false,
drop_last: true,
..Default::default()
};
let mut loader = OptimizedDataLoader::new(dataset, config);
loader.reset();
let mut batch_count = 0;
while let Some(result) = loader.next_batch() {
let (x, y) = result.expect("Failed to load batch");
assert_eq!(x.shape()[0], 10);
assert_eq!(y.shape()[0], 10);
batch_count += 1;
}
assert_eq!(batch_count, 10);
}
#[test]
fn test_optimized_dataloader_stats() {
let dataset = create_test_dataset();
let config = OptimizedLoaderConfig {
batch_size: 20,
shuffle: false,
..Default::default()
};
let mut loader = OptimizedDataLoader::new(dataset, config);
loader.reset();
while loader.next_batch().is_some() {}
let stats = loader.stats();
assert_eq!(stats.batches_loaded, 5);
assert_eq!(stats.samples_loaded, 100);
}
#[test]
fn test_batch_cache() {
let mut cache: BatchCache<f64> = BatchCache::new(10, 0);
let batch1 = (Array::zeros(IxDyn(&[5, 10])), Array::zeros(IxDyn(&[5, 2])));
cache.insert(0, batch1.clone());
let cached = cache.get(0);
assert!(cached.is_some());
assert_eq!(cached.map(|b| b.0.shape()[0]), Some(5));
assert!(cache.get(1).is_none());
cache.clear();
assert!(cache.get(0).is_none());
}
#[test]
fn test_prefetch_queue() {
let queue: PrefetchQueue<f64> = PrefetchQueue::new(3);
let batch = Ok((Array::zeros(IxDyn(&[5, 10])), Array::zeros(IxDyn(&[5, 2]))));
assert!(queue.push(0, batch));
assert!(!queue.is_empty());
let popped = queue.pop();
assert!(popped.is_some());
assert_eq!(popped.map(|(idx, _)| idx), Some(0));
assert!(queue.is_empty());
queue.stop();
let batch2 = Ok((Array::zeros(IxDyn(&[5, 10])), Array::zeros(IxDyn(&[5, 2]))));
assert!(!queue.push(1, batch2));
}
#[test]
fn test_loading_stats_default() {
let stats = LoadingStats::default();
assert_eq!(stats.batches_loaded, 0);
assert_eq!(stats.samples_loaded, 0);
assert_eq!(stats.cache_hits, 0);
assert_eq!(stats.cache_misses, 0);
}
#[test]
fn test_estimate_array_memory() {
let array: Array<f64, IxDyn> = Array::zeros(IxDyn(&[10, 20]));
let memory = estimate_array_memory(&array);
assert_eq!(memory, 10 * 20 * std::mem::size_of::<f64>());
}
#[test]
fn test_batch_size_optimizer_default() {
let optimizer = BatchSizeOptimizer::default();
assert_eq!(optimizer.min_batch_size, 8);
assert_eq!(optimizer.max_batch_size, 512);
}
#[test]
fn test_batch_size_optimizer_with_range() {
let optimizer = BatchSizeOptimizer::new()
.with_range(16, 256)
.with_max_memory(1024 * 1024);
assert_eq!(optimizer.min_batch_size, 16);
assert_eq!(optimizer.max_batch_size, 256);
assert_eq!(optimizer.max_memory, 1024 * 1024);
}
#[test]
fn test_find_optimal_batch_size() {
let dataset = create_test_dataset();
let optimizer = BatchSizeOptimizer::new().with_range(10, 50);
let result = optimizer.find_optimal(dataset);
assert!(result.is_ok());
let result = result.expect("Optimization should succeed");
assert!(result.recommended_batch_size >= 10);
assert!(result.recommended_batch_size <= 50);
assert!(!result.throughput_results.is_empty());
}
#[test]
fn test_dataloader_with_caching() {
let dataset = create_test_dataset();
let config = OptimizedLoaderConfig {
batch_size: 10,
shuffle: false,
cache_batches: true,
..Default::default()
};
let mut loader = OptimizedDataLoader::new(dataset, config);
loader.reset();
while loader.next_batch().is_some() {}
let stats = loader.stats();
assert_eq!(stats.cache_misses, 10);
assert_eq!(stats.cache_hits, 0);
}
#[test]
fn test_iterator_trait() {
let dataset = create_test_dataset();
let config = OptimizedLoaderConfig {
batch_size: 25,
shuffle: false,
drop_last: true,
..Default::default()
};
let mut loader = OptimizedDataLoader::new(dataset, config);
loader.reset();
let batches: Vec<_> = loader.collect();
assert_eq!(batches.len(), 4);
}
#[test]
fn test_memory_aware_config_default() {
let cfg = MemoryAwareConfig::default();
assert!(
cfg.target_memory_fraction > 0.0 && cfg.target_memory_fraction <= 1.0,
"target_memory_fraction must be in (0, 1]"
);
assert!(cfg.min_batch_size >= 1);
assert!(cfg.max_batch_size >= cfg.min_batch_size);
}
#[test]
fn test_memory_aware_loader_creation() {
let dataset = create_test_dataset();
let config = MemoryAwareConfig {
shuffle: false,
drop_last: false,
..Default::default()
};
let loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
.expect("loader creation must succeed");
let bs = loader.adaptive_batch_size();
assert!(bs >= 4, "batch_size ({bs}) must be >= min_batch_size (4)");
assert!(
bs <= 4096,
"batch_size ({bs}) must be <= max_batch_size (4096)"
);
assert!(loader.num_batches() >= 1);
assert_eq!(loader.len(), 100);
assert!(!loader.is_empty());
}
#[test]
fn test_memory_aware_loader_iteration_all_samples() {
let dataset = create_test_dataset();
let config = MemoryAwareConfig {
shuffle: false,
drop_last: false,
min_batch_size: 10,
max_batch_size: 10,
target_memory_fraction: 1.0,
..Default::default()
};
let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
.expect("loader creation must succeed");
loader.reset();
let mut total_samples = 0usize;
let mut batch_count = 0usize;
while let Some(result) = loader.next_batch() {
let (x, _y) = result.expect("batch load must succeed");
total_samples += x.shape()[0];
batch_count += 1;
}
assert_eq!(total_samples, 100, "all 100 samples must be yielded");
assert_eq!(batch_count, 10, "100 samples / batch_size 10 = 10 batches");
}
#[test]
fn test_memory_aware_loader_drop_last() {
let dataset = create_test_dataset();
let config = MemoryAwareConfig {
shuffle: false,
drop_last: true,
min_batch_size: 32,
max_batch_size: 32,
target_memory_fraction: 1.0,
..Default::default()
};
let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
.expect("loader creation must succeed");
loader.reset();
let batches: Vec<_> = loader.collect();
assert_eq!(batches.len(), 3, "drop_last: 100/32 = 3 full batches");
}
#[test]
fn test_memory_aware_loader_refresh_batch_size() {
let dataset = create_test_dataset();
let config = MemoryAwareConfig {
shuffle: false,
..Default::default()
};
let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
.expect("loader creation must succeed");
let new_bs = loader.refresh_batch_size().expect("refresh must succeed");
assert!(new_bs >= loader.config.min_batch_size);
assert!(new_bs <= loader.config.max_batch_size);
assert_eq!(new_bs, loader.adaptive_batch_size());
}
#[test]
fn test_memory_aware_loader_stats() {
let dataset = create_test_dataset();
let config = MemoryAwareConfig {
shuffle: false,
drop_last: false,
min_batch_size: 10,
max_batch_size: 10,
target_memory_fraction: 1.0,
..Default::default()
};
let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
.expect("loader creation must succeed");
loader.reset();
while loader.next_batch().is_some() {}
let stats = loader.stats();
assert_eq!(stats.batches_loaded, 10);
assert_eq!(stats.samples_loaded, 100);
}
#[test]
fn test_memory_aware_prefetch_iter() {
let dataset = create_test_dataset();
let config = MemoryAwareConfig {
shuffle: false,
drop_last: false,
min_batch_size: 10,
max_batch_size: 10,
target_memory_fraction: 1.0,
prefetch_ahead: 2,
..Default::default()
};
let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
.expect("loader creation must succeed");
loader.reset();
let iter = loader.into_prefetch_iter();
let batches: Vec<_> = iter.collect();
for batch_result in &batches {
let (x, _y) = batch_result
.as_ref()
.expect("prefetch batch must not be an error");
assert_eq!(x.shape()[0], 10);
}
assert_eq!(batches.len(), 10);
}
#[test]
fn test_estimate_available_memory_is_positive() {
let mem = estimate_available_memory_bytes();
assert!(mem > 0, "available memory estimate must be > 0");
}
#[test]
fn test_compute_adaptive_batch_size_bounds() {
let config = MemoryAwareConfig {
min_batch_size: 8,
max_batch_size: 64,
target_memory_fraction: 0.1,
bytes_per_sample: Some(1024),
..Default::default()
};
let bs = compute_adaptive_batch_size(1000, 1024, &config);
assert!(bs >= 8, "must respect min_batch_size");
assert!(bs <= 64, "must respect max_batch_size");
}
}