// shrew_data/transform.rs
1// Transform — data augmentation / preprocessing pipeline
2
3use crate::dataset::Sample;
4
/// A transform applied to each sample before batching.
///
/// Implementations take the [`Sample`] by value, mutate it, and return it,
/// so chained transforms (see [`Compose`]) pass ownership along without
/// intermediate clones. The `Send + Sync` bounds allow a transform to be
/// shared across data-loading worker threads.
pub trait Transform: Send + Sync {
    /// Apply the transform to a sample, returning the modified sample.
    fn apply(&self, sample: Sample) -> Sample;
}
10
11// Built-in transforms
12
/// Normalize features to [0, 1] by dividing by a given scale factor.
///
/// Commonly used for image pixels: `Normalize::new(255.0)`.
///
/// NOTE(review): a `scale` of `0.0` produces `inf`/`NaN` features rather
/// than an error — callers must supply a non-zero scale.
#[derive(Debug, Clone)]
pub struct Normalize {
    // Divisor applied to every feature value.
    scale: f64,
}
20
21impl Normalize {
22    pub fn new(scale: f64) -> Self {
23        Self { scale }
24    }
25}
26
27impl Transform for Normalize {
28    fn apply(&self, mut sample: Sample) -> Sample {
29        for v in &mut sample.features {
30            *v /= self.scale;
31        }
32        sample
33    }
34}
35
/// Standardize features to zero mean and unit variance.
///
/// Applies `(x - mean) / std` to every feature value; `mean` and `std` are
/// typically statistics precomputed over the training set.
///
/// NOTE(review): a `std` of `0.0` produces `inf`/`NaN` — confirm callers
/// guard against constant features.
#[derive(Debug, Clone)]
pub struct Standardize {
    /// Mean subtracted from every feature value.
    pub mean: f64,
    /// Standard deviation used as the divisor.
    pub std: f64,
}
42
43impl Standardize {
44    pub fn new(mean: f64, std: f64) -> Self {
45        Self { mean, std }
46    }
47}
48
49impl Transform for Standardize {
50    fn apply(&self, mut sample: Sample) -> Sample {
51        for v in &mut sample.features {
52            *v = (*v - self.mean) / self.std;
53        }
54        sample
55    }
56}
57
/// One-hot encode the target label into a vector of size `num_classes`.
///
/// The first target value is interpreted as a class index; the target is
/// replaced by a vector with a single `1.0` at that index, and
/// `target_shape` is updated to `[num_classes]`.
#[derive(Debug, Clone)]
pub struct OneHotEncode {
    /// Number of classes, i.e. the length of the encoded target vector.
    pub num_classes: usize,
}
63
64impl OneHotEncode {
65    pub fn new(num_classes: usize) -> Self {
66        Self { num_classes }
67    }
68}
69
70impl Transform for OneHotEncode {
71    fn apply(&self, mut sample: Sample) -> Sample {
72        let class_idx = sample.target[0] as usize;
73        let mut one_hot = vec![0.0; self.num_classes];
74        if class_idx < self.num_classes {
75            one_hot[class_idx] = 1.0;
76        }
77        sample.target = one_hot;
78        sample.target_shape = vec![self.num_classes];
79        sample
80    }
81}
82
/// Chain multiple transforms.
///
/// Transforms run in the order given to [`Compose::new`].
pub struct Compose {
    // `Box<dyn Transform>` implements neither `Debug` nor `Clone`,
    // hence no derives on this struct.
    transforms: Vec<Box<dyn Transform>>,
}
87
88impl Compose {
89    pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
90        Self { transforms }
91    }
92}
93
94impl Transform for Compose {
95    fn apply(&self, mut sample: Sample) -> Sample {
96        for t in &self.transforms {
97            sample = t.apply(sample);
98        }
99        sample
100    }
101}
102
/// Reshape the feature tensor to a different shape (without changing data).
///
/// Useful for converting flat MNIST images `[784]` to 2D `[1, 28, 28]`
/// for convolutional networks.
///
/// # Examples
/// ```ignore
/// // MNIST: [784] → [1, 28, 28] for Conv2d input
/// let reshape = ReshapeFeatures::new(vec![1, 28, 28]);
/// ```
#[derive(Debug, Clone)]
pub struct ReshapeFeatures {
    /// Target shape; its element product must equal that of the sample's
    /// current `feature_shape` (checked in [`Transform::apply`]).
    pub new_shape: Vec<usize>,
}
117
118impl ReshapeFeatures {
119    pub fn new(new_shape: Vec<usize>) -> Self {
120        Self { new_shape }
121    }
122}
123
124impl Transform for ReshapeFeatures {
125    fn apply(&self, mut sample: Sample) -> Sample {
126        // Verify element count matches
127        let old_count: usize = sample.feature_shape.iter().product();
128        let new_count: usize = self.new_shape.iter().product();
129        assert_eq!(
130            old_count, new_count,
131            "ReshapeFeatures: old shape {:?} ({}) != new shape {:?} ({})",
132            sample.feature_shape, old_count, self.new_shape, new_count,
133        );
134        sample.feature_shape = self.new_shape.clone();
135        sample
136    }
137}