Skip to main content

shrew_data/
combinators.rs

// Dataset combinators — subset, concatenate, and lazily map datasets, plus an in-memory dataset and a train/val/test split helper
2
3use crate::dataset::{Dataset, Sample};
4use crate::transform::Transform;
5
6// SubsetDataset — view of selected indices
7
/// A view over another dataset that exposes only the samples at the given indices.
///
/// Position `i` of the subset maps to position `indices[i]` of the wrapped
/// dataset, so the order (and any repetition) of `indices` is preserved.
/// This is useful for train/val/test splitting.
///
/// The `D: Dataset` requirement is stated on the `impl` blocks rather than on
/// the type itself, so the type can be named without repeating the bound.
/// `Clone` and `Debug` are available whenever `D` provides them; `Clone` in
/// particular lets a subset be split again by helpers that need
/// `D: Dataset + Clone`.
#[derive(Debug, Clone)]
pub struct SubsetDataset<D> {
    /// The wrapped dataset that actually provides the samples.
    inner: D,
    /// Indices into `inner`, in the order this view exposes them.
    indices: Vec<usize>,
}
15
16impl<D: Dataset> SubsetDataset<D> {
17    /// Create a subset of `inner` containing only the samples at `indices`.
18    ///
19    /// # Panics
20    /// Panics (lazily, at `get` time) if any index is out of range.
21    pub fn new(inner: D, indices: Vec<usize>) -> Self {
22        Self { inner, indices }
23    }
24}
25
26impl<D: Dataset> Dataset for SubsetDataset<D> {
27    fn len(&self) -> usize {
28        self.indices.len()
29    }
30
31    fn get(&self, index: usize) -> Sample {
32        self.inner.get(self.indices[index])
33    }
34
35    fn feature_shape(&self) -> &[usize] {
36        self.inner.feature_shape()
37    }
38
39    fn target_shape(&self) -> &[usize] {
40        self.inner.target_shape()
41    }
42
43    fn name(&self) -> &str {
44        self.inner.name()
45    }
46}
47
48// ConcatDataset — concatenate multiple datasets
49
/// Concatenate one or more datasets end-to-end.
///
/// Global indices run through the samples of the first dataset, then those of
/// the second, and so on.
///
/// All datasets must share the same feature and target shapes; the shapes
/// reported by this type are those of the first dataset.
pub struct ConcatDataset {
    /// The concatenated datasets, in the order they were supplied.
    datasets: Vec<Box<dyn Dataset>>,
    /// Running totals of the dataset lengths: entry `k` is the combined length
    /// of datasets `0..=k`, so the last entry is the total length.
    cumulative_sizes: Vec<usize>,
    /// Feature shape copied from the first dataset at construction.
    feature_shape: Vec<usize>,
    /// Target shape copied from the first dataset at construction.
    target_shape: Vec<usize>,
}
59
60impl ConcatDataset {
61    /// Create a concatenation of the given datasets.
62    ///
63    /// # Panics
64    /// Panics if `datasets` is empty.
65    pub fn new(datasets: Vec<Box<dyn Dataset>>) -> Self {
66        assert!(
67            !datasets.is_empty(),
68            "ConcatDataset: need at least one dataset"
69        );
70
71        let feature_shape = datasets[0].feature_shape().to_vec();
72        let target_shape = datasets[0].target_shape().to_vec();
73
74        let mut cumulative_sizes = Vec::with_capacity(datasets.len());
75        let mut total = 0;
76        for ds in &datasets {
77            total += ds.len();
78            cumulative_sizes.push(total);
79        }
80
81        Self {
82            datasets,
83            cumulative_sizes,
84            feature_shape,
85            target_shape,
86        }
87    }
88
89    /// Locate which dataset and local index a global index maps to.
90    fn locate(&self, index: usize) -> (usize, usize) {
91        for (ds_idx, &cum) in self.cumulative_sizes.iter().enumerate() {
92            if index < cum {
93                let offset = if ds_idx == 0 {
94                    0
95                } else {
96                    self.cumulative_sizes[ds_idx - 1]
97                };
98                return (ds_idx, index - offset);
99            }
100        }
101        panic!(
102            "ConcatDataset: index {} out of range (total {})",
103            index,
104            self.cumulative_sizes.last().unwrap_or(&0)
105        );
106    }
107}
108
109impl Dataset for ConcatDataset {
110    fn len(&self) -> usize {
111        *self.cumulative_sizes.last().unwrap_or(&0)
112    }
113
114    fn get(&self, index: usize) -> Sample {
115        let (ds_idx, local_idx) = self.locate(index);
116        self.datasets[ds_idx].get(local_idx)
117    }
118
119    fn feature_shape(&self) -> &[usize] {
120        &self.feature_shape
121    }
122
123    fn target_shape(&self) -> &[usize] {
124        &self.target_shape
125    }
126
127    fn name(&self) -> &str {
128        "concat"
129    }
130}
131
132// MapDataset — apply a transform lazily
133
134/// Wraps a dataset and applies a `Transform` lazily on each `get()`.
135pub struct MapDataset<D: Dataset> {
136    inner: D,
137    transform: Box<dyn Transform>,
138    /// Feature shape after transform (caller-provided).
139    feat_shape: Vec<usize>,
140    /// Target shape after transform (caller-provided).
141    tgt_shape: Vec<usize>,
142}
143
144impl<D: Dataset> MapDataset<D> {
145    /// Create a MapDataset.
146    ///
147    /// `feat_shape` and `tgt_shape` describe the shapes *after* the
148    /// transform is applied.  If the transform doesn't change shapes,
149    /// you can clone them from the inner dataset.
150    pub fn new(
151        inner: D,
152        transform: Box<dyn Transform>,
153        feat_shape: Vec<usize>,
154        tgt_shape: Vec<usize>,
155    ) -> Self {
156        Self {
157            inner,
158            transform,
159            feat_shape,
160            tgt_shape,
161        }
162    }
163
164    /// Convenience: create a MapDataset whose shapes are unchanged.
165    pub fn same_shape(inner: D, transform: Box<dyn Transform>) -> Self {
166        let feat_shape = inner.feature_shape().to_vec();
167        let tgt_shape = inner.target_shape().to_vec();
168        Self {
169            inner,
170            transform,
171            feat_shape,
172            tgt_shape,
173        }
174    }
175}
176
177impl<D: Dataset> Dataset for MapDataset<D> {
178    fn len(&self) -> usize {
179        self.inner.len()
180    }
181
182    fn get(&self, index: usize) -> Sample {
183        let sample = self.inner.get(index);
184        self.transform.apply(sample)
185    }
186
187    fn feature_shape(&self) -> &[usize] {
188        &self.feat_shape
189    }
190
191    fn target_shape(&self) -> &[usize] {
192        &self.tgt_shape
193    }
194
195    fn name(&self) -> &str {
196        self.inner.name()
197    }
198}
199
200// VecDataset — in-memory dataset from raw vectors
201
202/// A simple in-memory dataset backed by a `Vec<Sample>`.
203///
204/// Useful for building datasets programmatically or loading from CSV/JSON.
205pub struct VecDataset {
206    samples: Vec<Sample>,
207    feature_shape: Vec<usize>,
208    target_shape: Vec<usize>,
209    dataset_name: String,
210}
211
212impl VecDataset {
213    /// Create a VecDataset from a vector of samples.
214    ///
215    /// # Panics
216    /// Panics if `samples` is empty.
217    pub fn new(samples: Vec<Sample>, name: &str) -> Self {
218        assert!(!samples.is_empty(), "VecDataset: need at least one sample");
219        let feature_shape = samples[0].feature_shape.clone();
220        let target_shape = samples[0].target_shape.clone();
221        Self {
222            samples,
223            feature_shape,
224            target_shape,
225            dataset_name: name.to_string(),
226        }
227    }
228
229    /// Build from feature/target matrices.
230    ///
231    /// `features`: `[n_samples, n_features]` row-major
232    /// `targets`:  `[n_samples, n_targets]` row-major
233    pub fn from_flat(
234        features: &[f64],
235        feature_shape: &[usize],
236        targets: &[f64],
237        target_shape: &[usize],
238        name: &str,
239    ) -> Self {
240        let feat_per_sample: usize = feature_shape.iter().product();
241        let tgt_per_sample: usize = target_shape.iter().product();
242        let n = features.len() / feat_per_sample;
243        assert_eq!(features.len(), n * feat_per_sample);
244        assert_eq!(targets.len(), n * tgt_per_sample);
245
246        let samples: Vec<Sample> = (0..n)
247            .map(|i| Sample {
248                features: features[i * feat_per_sample..(i + 1) * feat_per_sample].to_vec(),
249                feature_shape: feature_shape.to_vec(),
250                target: targets[i * tgt_per_sample..(i + 1) * tgt_per_sample].to_vec(),
251                target_shape: target_shape.to_vec(),
252            })
253            .collect();
254
255        Self {
256            samples,
257            feature_shape: feature_shape.to_vec(),
258            target_shape: target_shape.to_vec(),
259            dataset_name: name.to_string(),
260        }
261    }
262}
263
264impl Dataset for VecDataset {
265    fn len(&self) -> usize {
266        self.samples.len()
267    }
268
269    fn get(&self, index: usize) -> Sample {
270        self.samples[index].clone()
271    }
272
273    fn feature_shape(&self) -> &[usize] {
274        &self.feature_shape
275    }
276
277    fn target_shape(&self) -> &[usize] {
278        &self.target_shape
279    }
280
281    fn name(&self) -> &str {
282        &self.dataset_name
283    }
284}
285
286// Train / Validation / Test Split
287
288/// Split a dataset into (train, val) or (train, val, test) subsets.
289///
290/// Returns `SubsetDataset` views over the original dataset.
291///
292/// # Arguments
293/// * `dataset` — the source dataset
294/// * `ratios` — slice of 2 or 3 floats that sum to 1.0, e.g. `[0.8, 0.2]`
295///   or `[0.7, 0.15, 0.15]`
296/// * `seed` — random seed for reproducible shuffling of indices
297pub fn train_test_split<D>(dataset: D, ratios: &[f64], seed: u64) -> Vec<SubsetDataset<D>>
298where
299    D: Dataset + Clone,
300{
301    use rand::rngs::StdRng;
302    use rand::seq::SliceRandom;
303    use rand::SeedableRng;
304
305    assert!(
306        ratios.len() >= 2 && ratios.len() <= 3,
307        "train_test_split: ratios must have 2 or 3 elements"
308    );
309    let sum: f64 = ratios.iter().sum();
310    assert!(
311        (sum - 1.0).abs() < 1e-6,
312        "train_test_split: ratios must sum to 1.0, got {}",
313        sum
314    );
315
316    let n = dataset.len();
317    let mut indices: Vec<usize> = (0..n).collect();
318    let mut rng = StdRng::seed_from_u64(seed);
319    indices.shuffle(&mut rng);
320
321    let mut splits = Vec::new();
322    let mut offset = 0;
323    for (i, &ratio) in ratios.iter().enumerate() {
324        let count = if i == ratios.len() - 1 {
325            n - offset // give remainder to last split
326        } else {
327            (n as f64 * ratio).round() as usize
328        };
329        let end = (offset + count).min(n);
330        splits.push(SubsetDataset::new(
331            dataset.clone(),
332            indices[offset..end].to_vec(),
333        ));
334        offset = end;
335    }
336
337    splits
338}
339
340// Tests
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny helper dataset for testing: sample `i` has feature `[i]` and
    /// target `[i % 3]`, so a sample's first feature identifies its index.
    #[derive(Clone)]
    struct TinyDataset {
        n: usize,
    }

    impl Dataset for TinyDataset {
        fn len(&self) -> usize {
            self.n
        }
        fn get(&self, idx: usize) -> Sample {
            Sample {
                features: vec![idx as f64],
                feature_shape: vec![1],
                target: vec![(idx % 3) as f64],
                target_shape: vec![1],
            }
        }
        fn feature_shape(&self) -> &[usize] {
            &[1]
        }
        fn target_shape(&self) -> &[usize] {
            &[1]
        }
    }

    /// First feature of every sample in `ds`, in order, as an integer.
    /// For views over `TinyDataset` this is the list of source indices.
    fn first_features<D: Dataset>(ds: &D) -> Vec<usize> {
        (0..ds.len())
            .map(|i| ds.get(i).features[0] as usize)
            .collect()
    }

    #[test]
    fn subset_dataset() {
        let ds = TinyDataset { n: 10 };
        let sub = SubsetDataset::new(ds, vec![2, 5, 7]);
        assert_eq!(sub.len(), 3);
        assert_eq!(sub.get(0).features[0], 2.0);
        assert_eq!(sub.get(1).features[0], 5.0);
        assert_eq!(sub.get(2).features[0], 7.0);
        // Shapes are forwarded from the wrapped dataset.
        assert_eq!(sub.feature_shape(), &[1usize][..]);
        assert_eq!(sub.target_shape(), &[1usize][..]);
    }

    #[test]
    fn subset_dataset_preserves_order_and_repeats() {
        let sub = SubsetDataset::new(TinyDataset { n: 10 }, vec![7, 2, 7]);
        assert_eq!(first_features(&sub), vec![7, 2, 7]);
    }

    #[test]
    fn concat_dataset() {
        let ds1 = TinyDataset { n: 5 };
        let ds2 = TinyDataset { n: 3 };
        let concat = ConcatDataset::new(vec![Box::new(ds1), Box::new(ds2)]);
        assert_eq!(concat.len(), 8);
        // First 5 come from ds1, next 3 from ds2
        assert_eq!(concat.get(0).features[0], 0.0);
        assert_eq!(concat.get(4).features[0], 4.0);
        assert_eq!(concat.get(5).features[0], 0.0); // ds2 index 0
        assert_eq!(concat.get(7).features[0], 2.0); // ds2 index 2
        assert_eq!(first_features(&concat), vec![0, 1, 2, 3, 4, 0, 1, 2]);
        assert_eq!(concat.feature_shape(), &[1usize][..]);
        assert_eq!(concat.target_shape(), &[1usize][..]);
        assert_eq!(concat.name(), "concat");
    }

    #[test]
    #[should_panic(expected = "out of range")]
    fn concat_dataset_out_of_range_panics() {
        let ds1 = TinyDataset { n: 5 };
        let ds2 = TinyDataset { n: 3 };
        let concat = ConcatDataset::new(vec![Box::new(ds1), Box::new(ds2)]);
        let _ = concat.get(8);
    }

    #[test]
    #[should_panic(expected = "need at least one dataset")]
    fn concat_dataset_rejects_empty_input() {
        let _ = ConcatDataset::new(Vec::new());
    }

    #[test]
    fn map_dataset() {
        use crate::transform::Normalize;
        let ds = TinyDataset { n: 4 };
        let mapped = MapDataset::same_shape(ds, Box::new(Normalize::new(10.0)));
        assert_eq!(mapped.len(), 4);
        let s = mapped.get(2);
        assert!((s.features[0] - 0.2).abs() < 1e-10);
        // `same_shape` copies the shapes of the wrapped dataset.
        assert_eq!(mapped.feature_shape(), &[1usize][..]);
        assert_eq!(mapped.target_shape(), &[1usize][..]);
    }

    #[test]
    fn vec_dataset() {
        let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let targets = vec![0.0, 1.0, 0.0];
        let ds = VecDataset::from_flat(&features, &[2], &targets, &[1], "test");
        assert_eq!(ds.len(), 3);
        assert_eq!(ds.get(0).features, vec![1.0, 2.0]);
        assert_eq!(ds.get(1).features, vec![3.0, 4.0]);
        assert_eq!(ds.get(2).features, vec![5.0, 6.0]);
        assert_eq!(ds.get(1).target, vec![1.0]);
        assert_eq!(ds.get(2).target, vec![0.0]);
        assert_eq!(ds.feature_shape(), &[2usize][..]);
        assert_eq!(ds.target_shape(), &[1usize][..]);
        assert_eq!(ds.name(), "test");
    }

    #[test]
    fn train_test_split_two_way() {
        let ds = TinyDataset { n: 100 };
        let splits = train_test_split(ds, &[0.8, 0.2], 42);
        assert_eq!(splits.len(), 2);
        assert_eq!(splits[0].len() + splits[1].len(), 100);
        assert_eq!(splits[0].len(), 80);
        assert_eq!(splits[1].len(), 20);
    }

    #[test]
    fn train_test_split_three_way() {
        let ds = TinyDataset { n: 100 };
        let splits = train_test_split(ds, &[0.7, 0.15, 0.15], 42);
        assert_eq!(splits.len(), 3);
        assert_eq!(splits[0].len() + splits[1].len() + splits[2].len(), 100);
        assert_eq!(splits[0].len(), 70);
        assert_eq!(splits[1].len(), 15);
        assert_eq!(splits[2].len(), 15);
    }

    #[test]
    fn train_test_split_partitions_all_indices() {
        // A size that does not divide evenly, to exercise the rounding path.
        let splits = train_test_split(TinyDataset { n: 37 }, &[0.6, 0.2, 0.2], 7);
        let mut seen: Vec<usize> = splits.iter().flat_map(|s| first_features(s)).collect();
        assert_eq!(seen.len(), 37);
        seen.sort_unstable();
        // Every index appears exactly once across the splits.
        assert_eq!(seen, (0..37).collect::<Vec<usize>>());
    }

    #[test]
    fn train_test_split_reproducible() {
        let ds1 = TinyDataset { n: 50 };
        let ds2 = TinyDataset { n: 50 };
        let s1 = train_test_split(ds1, &[0.8, 0.2], 123);
        let s2 = train_test_split(ds2, &[0.8, 0.2], 123);
        // Same seed → same indices → same samples
        for i in 0..s1[0].len() {
            assert_eq!(s1[0].get(i).features, s2[0].get(i).features);
        }
        // ...in every split, not just the first.
        assert_eq!(s1.len(), s2.len());
        for (a, b) in s1.iter().zip(s2.iter()) {
            assert_eq!(first_features(a), first_features(b));
        }
    }

    #[test]
    #[should_panic(expected = "ratios must sum to 1.0")]
    fn train_test_split_rejects_bad_ratio_sum() {
        let _ = train_test_split(TinyDataset { n: 10 }, &[0.5, 0.2], 0);
    }

    #[test]
    #[should_panic(expected = "ratios must have 2 or 3 elements")]
    fn train_test_split_rejects_single_ratio() {
        let _ = train_test_split(TinyDataset { n: 10 }, &[1.0], 0);
    }
}