use super::*;
use crate::tensor::{DType, TensorOptions, test_device};
struct SimpleData {
x: Tensor,
y: Tensor,
}
impl DataSet for SimpleData {
fn len(&self) -> usize {
self.x.shape()[0] as usize
}
fn get(&self, index: usize) -> Result<Vec<Tensor>> {
Ok(vec![
self.x.select(0, index as i64)?,
self.y.select(0, index as i64)?,
])
}
}
struct SequentialData {
n: usize,
}
impl DataSet for SequentialData {
fn len(&self) -> usize {
self.n
}
fn get(&self, index: usize) -> Result<Vec<Tensor>> {
Ok(vec![
Tensor::from_f32(&[index as f32], &[1], Device::CPU)?,
])
}
}
struct PairBatch {
x: Tensor,
y: Tensor,
}
impl BatchDataSet for PairBatch {
fn len(&self) -> usize {
self.x.shape()[0] as usize
}
fn get_batch(&self, indices: &[usize]) -> Result<Vec<Tensor>> {
let idx: Vec<i64> = indices.iter().map(|&i| i as i64).collect();
let idx_t = Tensor::from_i64(&idx, &[idx.len() as i64], Device::CPU)?;
Ok(vec![
self.x.index_select(0, &idx_t)?,
self.y.index_select(0, &idx_t)?,
])
}
}
fn make_data(n: usize) -> SimpleData {
let opts = TensorOptions { dtype: DType::Float32, device: Device::CPU };
SimpleData {
x: Tensor::randn(&[n as i64, 4], opts).unwrap(),
y: Tensor::randn(&[n as i64, 2], opts).unwrap(),
}
}
fn make_cpu_data_for_device(n: usize) -> SimpleData {
let opts = TensorOptions { dtype: DType::Float32, device: Device::CPU };
SimpleData {
x: Tensor::randn(&[n as i64, 4], opts).unwrap(),
y: Tensor::randn(&[n as i64, 2], opts).unwrap(),
}
}
#[test]
fn test_basic_epoch_iteration() {
let data = make_data(20);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.build()
.unwrap();
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 4); for b in &batches {
assert_eq!(b.len(), 2); assert_eq!(b[0].shape(), &[5, 4]);
assert_eq!(b[1].shape(), &[5, 2]);
}
}
#[test]
fn test_drop_last_true() {
let data = make_data(22);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.drop_last(true)
.build()
.unwrap();
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 4); }
#[test]
fn test_drop_last_false() {
let data = make_data(22);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.drop_last(false)
.build()
.unwrap();
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 5); assert_eq!(batches[4][0].shape(), &[2, 4]); }
#[test]
fn test_sequential_sampler() {
let mut loader = DataLoader::from_dataset(SequentialData { n: 10 })
.batch_size(3)
.shuffle(false)
.drop_last(false)
.build()
.unwrap();
let e0: Vec<f32> = loader
.epoch(0)
.flat_map(|b| {
let b = b.unwrap();
b[0].to_f32_vec().unwrap()
})
.collect();
let e1: Vec<f32> = loader
.epoch(1)
.flat_map(|b| {
let b = b.unwrap();
b[0].to_f32_vec().unwrap()
})
.collect();
assert_eq!(e0, e1);
assert_eq!(e0, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
}
#[test]
fn test_shuffle_different_epochs() {
let mut loader = DataLoader::from_dataset(SequentialData { n: 20 })
.batch_size(20)
.drop_last(false)
.build()
.unwrap();
let e0: Vec<f32> = loader.epoch(0).next().unwrap().unwrap()[0]
.to_f32_vec()
.unwrap();
let e1: Vec<f32> = loader.epoch(1).next().unwrap().unwrap()[0]
.to_f32_vec()
.unwrap();
assert_ne!(e0, e1);
}
#[test]
fn test_shuffle_reproducible() {
let data1 = SequentialData { n: 20 };
let data2 = SequentialData { n: 20 };
let mut l1 = DataLoader::from_dataset(data1)
.batch_size(20)
.seed(99)
.drop_last(false)
.build()
.unwrap();
let mut l2 = DataLoader::from_dataset(data2)
.batch_size(20)
.seed(99)
.drop_last(false)
.build()
.unwrap();
let e1: Vec<f32> = l1.epoch(3).next().unwrap().unwrap()[0]
.to_f32_vec()
.unwrap();
let e2: Vec<f32> = l2.epoch(3).next().unwrap().unwrap()[0]
.to_f32_vec()
.unwrap();
assert_eq!(e1, e2);
}
#[test]
fn test_all_samples_visited() {
let mut loader = DataLoader::from_dataset(SequentialData { n: 10 })
.batch_size(3)
.drop_last(false)
.build()
.unwrap();
let mut vals: Vec<f32> = loader
.epoch(0)
.flat_map(|b| {
let b = b.unwrap();
b[0].to_f32_vec().unwrap()
})
.collect();
vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
assert_eq!(
vals,
vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
);
}
#[test]
fn test_batch_dataset_path() {
let opts = TensorOptions { dtype: DType::Float32, device: Device::CPU };
let batch_ds = PairBatch {
x: Tensor::randn(&[30, 8], opts).unwrap(),
y: Tensor::randn(&[30, 3], opts).unwrap(),
};
let mut loader = DataLoader::from_batch_dataset(batch_ds)
.batch_size(10)
.build()
.unwrap();
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 3);
assert_eq!(batches[0][0].shape(), &[10, 8]);
assert_eq!(batches[0][1].shape(), &[10, 3]);
}
#[test]
fn test_exact_size_iterator() {
let data = make_data(20);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.build()
.unwrap();
let iter = loader.epoch(0);
assert_eq!(iter.len(), 4);
}
#[test]
fn test_loader_metadata() {
let data = make_data(50);
let loader = DataLoader::from_dataset(data)
.batch_size(8)
.build()
.unwrap();
assert_eq!(loader.len(), 50);
assert_eq!(loader.batch_size(), 8);
assert_eq!(loader.num_batches(), 6); assert!(!loader.is_empty());
assert!(loader.is_resident());
}
#[test]
fn test_empty_dataset_errors() {
struct Empty;
impl DataSet for Empty {
fn len(&self) -> usize { 0 }
fn get(&self, _: usize) -> Result<Vec<Tensor>> { unreachable!() }
}
let result = DataLoader::from_dataset(Empty)
.batch_size(10)
.build();
assert!(result.is_err());
}
#[test]
fn test_zero_batch_size_errors() {
let data = make_data(10);
let result = DataLoader::from_dataset(data)
.batch_size(0)
.build();
assert!(result.is_err());
}
#[test]
fn test_batch_size_larger_than_dataset() {
let data = make_data(5);
let mut loader = DataLoader::from_dataset(data)
.batch_size(100)
.drop_last(false)
.build()
.unwrap();
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 1);
assert_eq!(batches[0][0].shape(), &[5, 4]);
}
#[test]
fn test_batch_size_larger_than_dataset_drop_last() {
let data = make_data(5);
let mut loader = DataLoader::from_dataset(data)
.batch_size(100)
.drop_last(true)
.build()
.unwrap();
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 0);
}
#[test]
fn test_device_aware_loading() {
let data = make_cpu_data_for_device(20);
let dev = test_device();
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.device(dev)
.build()
.unwrap();
assert_eq!(loader.device(), dev);
let b = loader.epoch(0).next().unwrap().unwrap();
assert_eq!(b[0].device(), dev);
assert_eq!(b[1].device(), dev);
}
#[test]
fn test_multi_target_dataset() {
struct FbrlLike {
images: Tensor,
letters: Tensor,
cases: Tensor,
origins: Tensor,
}
impl DataSet for FbrlLike {
fn len(&self) -> usize { self.images.shape()[0] as usize }
fn get(&self, i: usize) -> Result<Vec<Tensor>> {
Ok(vec![
self.images.select(0, i as i64)?,
self.letters.select(0, i as i64)?,
self.cases.select(0, i as i64)?,
self.origins.select(0, i as i64)?,
])
}
}
let opts = TensorOptions { dtype: DType::Float32, device: Device::CPU };
let data = FbrlLike {
images: Tensor::randn(&[16, 3, 8, 8], opts).unwrap(),
letters: Tensor::randn(&[16, 26], opts).unwrap(),
cases: Tensor::randn(&[16, 2], opts).unwrap(),
origins: Tensor::randn(&[16, 5], opts).unwrap(),
};
let mut loader = DataLoader::from_dataset(data)
.batch_size(4)
.build()
.unwrap();
let b = loader.epoch(0).next().unwrap().unwrap();
assert_eq!(b.len(), 4);
assert_eq!(b[0].shape(), &[4, 3, 8, 8]); assert_eq!(b[1].shape(), &[4, 26]); assert_eq!(b[2].shape(), &[4, 2]); assert_eq!(b[3].shape(), &[4, 5]); }
#[test]
fn test_streaming_basic_epoch() {
let data = make_data(20);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.streaming()
.build()
.unwrap();
assert!(!loader.is_resident());
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 4);
for b in &batches {
assert_eq!(b.len(), 2);
assert_eq!(b[0].shape(), &[5, 4]);
assert_eq!(b[1].shape(), &[5, 2]);
}
}
#[test]
fn test_streaming_drop_last() {
let data = make_data(22);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.drop_last(true)
.streaming()
.build()
.unwrap();
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 4); }
#[test]
fn test_streaming_drop_last_false() {
let data = make_data(22);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.drop_last(false)
.streaming()
.build()
.unwrap();
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 5); assert_eq!(batches[4][0].shape(), &[2, 4]);
}
#[test]
fn test_streaming_all_samples_visited() {
let mut loader = DataLoader::from_dataset(SequentialData { n: 10 })
.batch_size(3)
.drop_last(false)
.streaming()
.build()
.unwrap();
let mut vals: Vec<f32> = loader
.epoch(0)
.flat_map(|b| {
let b = b.unwrap();
b[0].to_f32_vec().unwrap()
})
.collect();
vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
assert_eq!(vals, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
}
#[test]
fn test_streaming_multiple_epochs() {
let mut loader = DataLoader::from_dataset(SequentialData { n: 20 })
.batch_size(20)
.drop_last(false)
.streaming()
.build()
.unwrap();
let e0: Vec<f32> = loader.epoch(0).next().unwrap().unwrap()[0]
.to_f32_vec()
.unwrap();
let e1: Vec<f32> = loader.epoch(1).next().unwrap().unwrap()[0]
.to_f32_vec()
.unwrap();
assert_ne!(e0, e1);
assert_eq!(e0.len(), 20);
assert_eq!(e1.len(), 20);
}
#[test]
fn test_streaming_sequential() {
let mut loader = DataLoader::from_dataset(SequentialData { n: 10 })
.batch_size(3)
.shuffle(false)
.drop_last(false)
.streaming()
.build()
.unwrap();
let vals: Vec<f32> = loader
.epoch(0)
.flat_map(|b| b.unwrap()[0].to_f32_vec().unwrap())
.collect();
assert_eq!(vals, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
}
#[test]
fn test_streaming_multi_target() {
struct Multi {
a: Tensor,
b: Tensor,
c: Tensor,
}
impl DataSet for Multi {
fn len(&self) -> usize { self.a.shape()[0] as usize }
fn get(&self, i: usize) -> Result<Vec<Tensor>> {
Ok(vec![
self.a.select(0, i as i64)?,
self.b.select(0, i as i64)?,
self.c.select(0, i as i64)?,
])
}
}
let opts = TensorOptions { dtype: DType::Float32, device: Device::CPU };
let data = Multi {
a: Tensor::randn(&[12, 4], opts).unwrap(),
b: Tensor::randn(&[12, 8], opts).unwrap(),
c: Tensor::randn(&[12, 2], opts).unwrap(),
};
let mut loader = DataLoader::from_dataset(data)
.batch_size(4)
.streaming()
.build()
.unwrap();
let b = loader.epoch(0).next().unwrap().unwrap();
assert_eq!(b.len(), 3);
assert_eq!(b[0].shape(), &[4, 4]);
assert_eq!(b[1].shape(), &[4, 8]);
assert_eq!(b[2].shape(), &[4, 2]);
}
#[test]
fn test_streaming_drop_mid_epoch() {
let data = make_data(100);
let mut loader = DataLoader::from_dataset(data)
.batch_size(10)
.streaming()
.build()
.unwrap();
{
let mut iter = loader.epoch(0);
let _ = iter.next().unwrap().unwrap();
let _ = iter.next().unwrap().unwrap();
}
let batches: Vec<Batch> = loader.epoch(1).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 10);
}
#[test]
fn test_named_batch_via_loader() {
let data = make_data(20);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.names(&["input", "target"])
.build()
.unwrap();
let b = loader.epoch(0).next().unwrap().unwrap();
assert_eq!(b.names(), &["input", "target"]);
assert_eq!(b["input"].shape(), &[5, 4]);
assert_eq!(b["target"].shape(), &[5, 2]);
assert!(b.has("input"));
assert!(b.has("target"));
assert!(!b.has("missing"));
}
#[test]
fn test_named_batch_streaming() {
let data = make_data(20);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.names(&["x", "y"])
.streaming()
.build()
.unwrap();
let b = loader.epoch(0).next().unwrap().unwrap();
assert_eq!(b.names(), &["x", "y"]);
assert_eq!(b["x"].shape(), &[5, 4]);
assert_eq!(b["y"].shape(), &[5, 2]);
}
#[test]
fn test_names_count_mismatch_errors() {
let data = make_data(10);
let result = DataLoader::from_dataset(data)
.batch_size(5)
.names(&["only_one"])
.build();
assert!(result.is_err());
}
#[test]
fn test_auto_names_when_unspecified() {
let data = make_data(10);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.build()
.unwrap();
assert_eq!(loader.names(), &["0", "1"]);
let b = loader.epoch(0).next().unwrap().unwrap();
assert_eq!(b["0"].shape(), &[5, 4]);
assert_eq!(b["1"].shape(), &[5, 2]);
}
#[test]
fn test_graph_set_data_loader_single_gpu() {
use crate::graph::FlowBuilder;
use crate::nn::{Adam, Linear, Module, ReLU, mse_loss};
let model = FlowBuilder::from(Linear::new(4, 8).unwrap())
.through(ReLU::new())
.through(Linear::new(8, 2).unwrap())
.build()
.unwrap();
let opts = TensorOptions { dtype: DType::Float32, device: Device::CPU };
struct TrainData { x: Tensor, y: Tensor }
impl super::DataSet for TrainData {
fn len(&self) -> usize { self.x.shape()[0] as usize }
fn get(&self, i: usize) -> Result<Vec<Tensor>> {
Ok(vec![
self.x.select(0, i as i64)?,
self.y.select(0, i as i64)?,
])
}
}
let data = TrainData {
x: Tensor::randn(&[20, 4], opts).unwrap(),
y: Tensor::randn(&[20, 2], opts).unwrap(),
};
let loader = DataLoader::from_dataset(data)
.batch_size(5)
.names(&["input", "target"])
.build()
.unwrap();
model.set_data_loader(loader, "input").unwrap();
model.set_optimizer(|p| Adam::new(p, 0.01));
model.set_training(true);
let params_before: Vec<f32> = model
.parameters()
.iter()
.flat_map(|p| p.variable.data().to_f32_vec().unwrap())
.collect();
let iter = model.epoch(0);
let active = iter.activate();
let mut batch_count = 0;
for batch_result in active {
let b = batch_result.unwrap();
assert!(b.has("input"));
assert!(b.has("target"));
let out = model.forward_batch(&b).unwrap();
let target = crate::autograd::Variable::new(b["target"].clone(), false);
let loss = mse_loss(&out, &target).unwrap();
loss.backward().unwrap();
model.step().unwrap();
batch_count += 1;
}
assert_eq!(batch_count, 4);
let params_after: Vec<f32> = model
.parameters()
.iter()
.flat_map(|p| p.variable.data().to_f32_vec().unwrap())
.collect();
let changed = params_before
.iter()
.zip(¶ms_after)
.any(|(a, b)| (a - b).abs() > 1e-8);
assert!(changed, "parameters should change after training");
}
#[test]
fn test_graph_data_num_batches() {
use crate::graph::FlowBuilder;
use crate::nn::Linear;
let model = FlowBuilder::from(Linear::new(4, 2).unwrap())
.build()
.unwrap();
let data = make_data(20);
let loader = DataLoader::from_dataset(data)
.batch_size(5)
.names(&["x", "y"])
.build()
.unwrap();
model.set_data_loader(loader, "x").unwrap();
assert_eq!(model.data_num_batches(), 4);
assert_eq!(model.data_batch_size(), 5);
}
#[test]
fn test_set_data_loader_invalid_input_name() {
use crate::graph::FlowBuilder;
use crate::nn::Linear;
let model = FlowBuilder::from(Linear::new(4, 2).unwrap())
.build()
.unwrap();
let data = make_data(10);
let loader = DataLoader::from_dataset(data)
.batch_size(5)
.names(&["x", "y"])
.build()
.unwrap();
let result = model.set_data_loader(loader, "missing");
assert!(result.is_err());
}
#[test]
fn test_scatter_fallback_without_data_loader() {
use crate::graph::FlowBuilder;
use crate::nn::{Linear, Module};
let model = FlowBuilder::from(Linear::new(4, 2).unwrap())
.build()
.unwrap();
let x = crate::autograd::Variable::new(
Tensor::randn(&[3, 4], Default::default()).unwrap(),
false,
);
let out = model.forward(&x).unwrap();
assert_eq!(out.shape(), &[3, 2]);
}
#[test]
fn test_prefetch_depth_from_vram_cpu() {
let depth = prefetch_depth_from_vram(100, 32, Device::CPU, 0.90, 0);
assert_eq!(depth, 2);
}
#[test]
fn test_prefetch_depth_from_vram_zero_batch() {
let depth = prefetch_depth_from_vram(0, 32, Device::CPU, 0.90, 0);
assert_eq!(depth, 2);
}
#[test]
fn test_prefetch_depth_from_vram_zero_bytes() {
let depth = prefetch_depth_from_vram(100, 0, Device::CPU, 0.90, 0);
assert_eq!(depth, 2);
}
#[test]
fn test_streaming_prefetch_depth_and_resize() {
let data = SequentialData { n: 100 };
let mut loader = DataLoader::from_dataset(data)
.batch_size(10)
.streaming()
.build()
.unwrap();
assert!(!loader.is_resident());
let initial = loader.prefetch_depth();
assert!(initial >= 2, "initial depth should be >= 2, got {initial}");
loader.set_prefetch_depth(42);
assert_eq!(loader.prefetch_depth(), 42);
loader.set_prefetch_depth(4);
assert_eq!(loader.prefetch_depth(), 4);
}
#[test]
fn test_resident_prefetch_depth_is_zero() {
let data = SequentialData { n: 20 };
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.build()
.unwrap();
assert!(loader.is_resident());
assert_eq!(loader.prefetch_depth(), 0);
loader.set_prefetch_depth(100);
assert_eq!(loader.prefetch_depth(), 0);
let depth = loader.auto_resize();
assert_eq!(depth, 0);
}
#[test]
fn test_streaming_auto_resize_cpu() {
let data = SequentialData { n: 100 };
let mut loader = DataLoader::from_dataset(data)
.batch_size(10)
.streaming()
.build()
.unwrap();
let depth = loader.auto_resize();
assert_eq!(depth, 2);
}
#[test]
fn test_streaming_epoch_after_resize() {
let data = SequentialData { n: 50 };
let mut loader = DataLoader::from_dataset(data)
.batch_size(10)
.streaming()
.build()
.unwrap();
loader.set_prefetch_depth(8);
let mut count = 0;
for batch in loader.epoch(0) {
let b = batch.unwrap();
assert_eq!(b[0].shape(), &[10, 1]);
count += 1;
}
assert_eq!(count, 5);
loader.set_prefetch_depth(2);
count = 0;
for batch in loader.epoch(1) {
batch.unwrap();
count += 1;
}
assert_eq!(count, 5);
}
#[test]
fn test_vram_max_usage_builder() {
let data = SequentialData { n: 100 };
let loader = DataLoader::from_dataset(data)
.batch_size(10)
.vram_max_usage(0.80) .streaming()
.build()
.unwrap();
assert!(!loader.is_resident());
assert!(loader.prefetch_depth() >= 2);
}
#[test]
fn test_vram_max_usage_clamped() {
let data = SequentialData { n: 100 };
let loader = DataLoader::from_dataset(data)
.batch_size(10)
.vram_max_usage(0.10) .streaming()
.build()
.unwrap();
assert!(!loader.is_resident());
}
#[test]
fn test_el_che_counts_cell_roundtrip() {
let cell: std::cell::Cell<Option<Vec<usize>>> = std::cell::Cell::new(None);
assert!(cell.take().is_none());
cell.set(Some(vec![10, 23]));
let val = cell.take();
assert_eq!(val, Some(vec![10, 23]));
assert!(cell.take().is_none());
}
#[test]
fn test_el_che_batches_cell_roundtrip() {
let cell: std::cell::Cell<Option<Vec<Vec<Vec<Tensor>>>>> = std::cell::Cell::new(None);
assert!(cell.take().is_none());
let t = Tensor::zeros(&[2, 3], Default::default()).unwrap();
let batches = vec![vec![vec![t.clone()]], vec![vec![t]]];
cell.set(Some(batches));
let val = cell.take();
assert!(val.is_some());
let batches = val.unwrap();
assert_eq!(batches.len(), 2); assert_eq!(batches[0].len(), 1); assert_eq!(batches[1].len(), 1); }
#[test]
fn test_el_che_clamping_proportional() {
let counts = [10usize, 23];
let total: usize = counts.iter().sum(); let remaining = 20usize;
let scale = remaining as f64 / total as f64;
let mut clamped: Vec<usize> = counts.iter()
.map(|&c| (c as f64 * scale).floor() as usize)
.collect();
let clamped_total: usize = clamped.iter().sum();
let mut deficit = remaining.saturating_sub(clamped_total);
for c in &mut clamped {
if deficit == 0 { break; }
*c += 1;
deficit -= 1;
}
let final_total: usize = clamped.iter().sum();
assert_eq!(final_total, remaining);
assert!(clamped[0] < clamped[1], "fast device should still get more");
}
#[test]
fn test_single_item_dataset() {
let dev = test_device();
let opts = TensorOptions { dtype: DType::Float32, device: Device::CPU };
let data = SimpleData {
x: Tensor::randn(&[1, 4], opts).unwrap(),
y: Tensor::randn(&[1, 2], opts).unwrap(),
};
let mut loader = DataLoader::from_dataset(data)
.batch_size(1)
.device(dev)
.drop_last(false)
.build()
.unwrap();
assert_eq!(loader.len(), 1);
assert_eq!(loader.num_batches(), 1);
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 1);
assert_eq!(batches[0][0].shape(), &[1, 4]);
assert_eq!(batches[0][1].shape(), &[1, 2]);
assert_eq!(batches[0][0].device(), dev);
}
#[test]
fn test_dataset_smaller_than_batch_no_drop() {
let dev = test_device();
let opts = TensorOptions { dtype: DType::Float32, device: Device::CPU };
let data = SimpleData {
x: Tensor::randn(&[3, 4], opts).unwrap(),
y: Tensor::randn(&[3, 2], opts).unwrap(),
};
let mut loader = DataLoader::from_dataset(data)
.batch_size(10)
.device(dev)
.drop_last(false)
.build()
.unwrap();
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 1);
assert_eq!(batches[0][0].shape(), &[3, 4]);
assert_eq!(batches[0][1].shape(), &[3, 2]);
}
#[test]
fn test_dataset_smaller_than_batch_drop_last() {
let opts = TensorOptions { dtype: DType::Float32, device: Device::CPU };
let data = SimpleData {
x: Tensor::randn(&[3, 4], opts).unwrap(),
y: Tensor::randn(&[3, 2], opts).unwrap(),
};
let mut loader = DataLoader::from_dataset(data)
.batch_size(10)
.drop_last(true)
.build()
.unwrap();
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 0);
}
#[test]
fn test_drop_last_exact_division() {
let data = make_data(100);
let mut loader = DataLoader::from_dataset(data)
.batch_size(10)
.drop_last(true)
.build()
.unwrap();
assert_eq!(loader.num_batches(), 10);
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 10);
for b in &batches {
assert_eq!(b[0].shape(), &[10, 4]);
}
}
#[test]
fn test_drop_last_with_remainder() {
let data = make_data(105);
let mut loader = DataLoader::from_dataset(data)
.batch_size(10)
.drop_last(true)
.build()
.unwrap();
assert_eq!(loader.num_batches(), 10);
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 10);
for b in &batches {
assert_eq!(b[0].shape(), &[10, 4]);
}
}
#[test]
fn test_two_epoch_consistency() {
let n = 25;
let mut loader = DataLoader::from_dataset(SequentialData { n })
.batch_size(7)
.drop_last(false)
.build()
.unwrap();
for epoch in 0..2 {
let mut vals: Vec<f32> = loader
.epoch(epoch)
.flat_map(|b| b.unwrap()[0].to_f32_vec().unwrap())
.collect();
assert_eq!(vals.len(), n, "epoch {epoch}: should see all {n} items");
vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
let expected: Vec<f32> = (0..n).map(|i| i as f32).collect();
assert_eq!(vals, expected, "epoch {epoch}: no data lost or duplicated");
}
}
#[test]
fn test_sequential_sampler_batch_ordering() {
let mut loader = DataLoader::from_dataset(SequentialData { n: 12 })
.batch_size(4)
.shuffle(false)
.drop_last(false)
.build()
.unwrap();
let batches: Vec<Vec<f32>> = loader
.epoch(0)
.map(|b| b.unwrap()[0].to_f32_vec().unwrap())
.collect();
assert_eq!(batches.len(), 3);
assert_eq!(batches[0], vec![0.0, 1.0, 2.0, 3.0]);
assert_eq!(batches[1], vec![4.0, 5.0, 6.0, 7.0]);
assert_eq!(batches[2], vec![8.0, 9.0, 10.0, 11.0]);
}
#[test]
fn test_empty_iteration_no_leak() {
let data = make_data(20);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.build()
.unwrap();
{
let _iter = loader.epoch(0);
}
let batches: Vec<Batch> = loader.epoch(1).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 4);
}
#[test]
fn test_named_and_positional_access() {
let data = make_data(10);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.names(&["features", "labels"])
.build()
.unwrap();
let b = loader.epoch(0).next().unwrap().unwrap();
let by_pos = b[0].shape().to_vec();
let by_name = b["features"].shape().to_vec();
assert_eq!(by_pos, by_name);
let by_pos_1 = b[1].shape().to_vec();
let by_name_1 = b["labels"].shape().to_vec();
assert_eq!(by_pos_1, by_name_1);
assert!(b.get_named("features").is_some());
assert!(b.get_named("nonexistent").is_none());
}
#[test]
fn test_multiple_tensors_per_sample() {
struct TripleData {
input: Tensor,
target: Tensor,
mask: Tensor,
}
impl DataSet for TripleData {
fn len(&self) -> usize { self.input.shape()[0] as usize }
fn get(&self, i: usize) -> Result<Vec<Tensor>> {
Ok(vec![
self.input.select(0, i as i64)?,
self.target.select(0, i as i64)?,
self.mask.select(0, i as i64)?,
])
}
}
let dev = test_device();
let opts = TensorOptions { dtype: DType::Float32, device: Device::CPU };
let data = TripleData {
input: Tensor::randn(&[16, 10], opts).unwrap(),
target: Tensor::randn(&[16, 5], opts).unwrap(),
mask: Tensor::ones(&[16, 10], opts).unwrap(),
};
let mut loader = DataLoader::from_dataset(data)
.batch_size(4)
.device(dev)
.names(&["input", "target", "mask"])
.build()
.unwrap();
assert_eq!(loader.names(), &["input", "target", "mask"]);
let batches: Vec<Batch> = loader.epoch(0).map(|b| b.unwrap()).collect();
assert_eq!(batches.len(), 4);
for b in &batches {
assert_eq!(b.len(), 3);
assert_eq!(b["input"].shape(), &[4, 10]);
assert_eq!(b["target"].shape(), &[4, 5]);
assert_eq!(b["mask"].shape(), &[4, 10]);
assert_eq!(b["input"].device(), dev);
assert_eq!(b["target"].device(), dev);
assert_eq!(b["mask"].device(), dev);
}
}
#[test]
fn test_exact_size_iterator_with_drop_last() {
let data = make_data(23);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.drop_last(true)
.build()
.unwrap();
let iter = loader.epoch(0);
assert_eq!(iter.len(), 4); }
#[test]
fn test_exact_size_iterator_no_drop_last() {
let data = make_data(23);
let mut loader = DataLoader::from_dataset(data)
.batch_size(5)
.drop_last(false)
.build()
.unwrap();
let iter = loader.epoch(0);
assert_eq!(iter.len(), 5); }