Skip to main content

shrew_data/
dataset.rs

// Dataset trait — unified interface for any data source

/// A single sample: a pair of (input features, label/target).
///
/// Both halves are stored as flattened `Vec<f64>` buffers together with their
/// shapes so they can be batched into tensors later. Nothing in this type
/// checks that `features.len()` matches the product of `feature_shape` (or that
/// `target.len()` matches the product of `target_shape`); keep them consistent
/// when constructing a sample.
// `PartialEq` but not `Eq`: the `f64` fields have no total equality (NaN).
#[derive(Debug, Clone, PartialEq)]
pub struct Sample {
    /// Input feature vector (flattened).
    pub features: Vec<f64>,
    /// Shape of the feature tensor (e.g. `[784]` for MNIST, `[3,32,32]` for CIFAR).
    pub feature_shape: Vec<usize>,
    /// Target / label value(s) (flattened).  For classification this is typically
    /// a single-element vec holding the class index as `f64`.
    pub target: Vec<f64>,
    /// Shape of the target tensor (e.g. `[1]` for a class index, `[10]` for one-hot).
    pub target_shape: Vec<usize>,
}

20/// A dataset is an indexed collection of samples.
21///
22/// Implementations must be `Send + Sync` so DataLoader can read from multiple
23/// threads when parallel prefetching is enabled.
24pub trait Dataset: Send + Sync {
25    /// Total number of samples in the dataset.
26    fn len(&self) -> usize;
27
28    /// Whether the dataset is empty.
29    fn is_empty(&self) -> bool {
30        self.len() == 0
31    }
32
33    /// Retrieve the sample at position `index`.
34    ///
35    /// # Panics
36    /// May panic if `index >= self.len()`.
37    fn get(&self, index: usize) -> Sample;
38
39    /// The shape of a single feature sample (without batch dim).
40    fn feature_shape(&self) -> &[usize];
41
42    /// The shape of a single target sample (without batch dim).
43    fn target_shape(&self) -> &[usize];
44
45    /// Optional human-readable name.
46    fn name(&self) -> &str {
47        "dataset"
48    }
49}