// shrew_data/dataset.rs
// Dataset trait — unified interface for any data source
2
/// A single sample: a pair of (input features, label/target).
///
/// Features and target are each stored as a flattened `Vec<f64>` alongside the
/// shape they are meant to be interpreted with, so they can be batched into
/// tensors later. All fields are public, so the type itself does not enforce
/// that a data vector's length agrees with its shape.
///
/// `PartialEq` compares all four fields element-wise. Because the payload is
/// `f64`, a sample containing `NaN` never compares equal (not even to itself),
/// which is also why `Eq` is intentionally not derived.
#[derive(Debug, Clone, PartialEq)]
pub struct Sample {
    /// Input feature vector (flattened).
    pub features: Vec<f64>,
    /// Shape of the feature tensor (e.g. `[784]` for MNIST, `[3,32,32]` for CIFAR).
    pub feature_shape: Vec<usize>,
    /// Target / label value(s) (flattened). For classification this is typically
    /// a single-element vec holding the class index as `f64`.
    pub target: Vec<f64>,
    /// Shape of the target tensor (e.g. `[1]` for a class index, `[10]` for one-hot).
    pub target_shape: Vec<usize>,
}
19
20/// A dataset is an indexed collection of samples.
21///
22/// Implementations must be `Send + Sync` so DataLoader can read from multiple
23/// threads when parallel prefetching is enabled.
24pub trait Dataset: Send + Sync {
25 /// Total number of samples in the dataset.
26 fn len(&self) -> usize;
27
28 /// Whether the dataset is empty.
29 fn is_empty(&self) -> bool {
30 self.len() == 0
31 }
32
33 /// Retrieve the sample at position `index`.
34 ///
35 /// # Panics
36 /// May panic if `index >= self.len()`.
37 fn get(&self, index: usize) -> Sample;
38
39 /// The shape of a single feature sample (without batch dim).
40 fn feature_shape(&self) -> &[usize];
41
42 /// The shape of a single target sample (without batch dim).
43 fn target_shape(&self) -> &[usize];
44
45 /// Optional human-readable name.
46 fn name(&self) -> &str {
47 "dataset"
48 }
49}