1use crate::dataset::Sample;
4
5pub trait Transform: Send + Sync {
7 fn apply(&self, sample: Sample) -> Sample;
9}
10
11#[derive(Debug, Clone)]
17pub struct Normalize {
18 scale: f64,
19}
20
21impl Normalize {
22 pub fn new(scale: f64) -> Self {
23 Self { scale }
24 }
25}
26
27impl Transform for Normalize {
28 fn apply(&self, mut sample: Sample) -> Sample {
29 for v in &mut sample.features {
30 *v /= self.scale;
31 }
32 sample
33 }
34}
35
36#[derive(Debug, Clone)]
38pub struct Standardize {
39 pub mean: f64,
40 pub std: f64,
41}
42
43impl Standardize {
44 pub fn new(mean: f64, std: f64) -> Self {
45 Self { mean, std }
46 }
47}
48
49impl Transform for Standardize {
50 fn apply(&self, mut sample: Sample) -> Sample {
51 for v in &mut sample.features {
52 *v = (*v - self.mean) / self.std;
53 }
54 sample
55 }
56}
57
58#[derive(Debug, Clone)]
60pub struct OneHotEncode {
61 pub num_classes: usize,
62}
63
64impl OneHotEncode {
65 pub fn new(num_classes: usize) -> Self {
66 Self { num_classes }
67 }
68}
69
70impl Transform for OneHotEncode {
71 fn apply(&self, mut sample: Sample) -> Sample {
72 let class_idx = sample.target[0] as usize;
73 let mut one_hot = vec![0.0; self.num_classes];
74 if class_idx < self.num_classes {
75 one_hot[class_idx] = 1.0;
76 }
77 sample.target = one_hot;
78 sample.target_shape = vec![self.num_classes];
79 sample
80 }
81}
82
83pub struct Compose {
85 transforms: Vec<Box<dyn Transform>>,
86}
87
88impl Compose {
89 pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
90 Self { transforms }
91 }
92}
93
94impl Transform for Compose {
95 fn apply(&self, mut sample: Sample) -> Sample {
96 for t in &self.transforms {
97 sample = t.apply(sample);
98 }
99 sample
100 }
101}
102
103#[derive(Debug, Clone)]
114pub struct ReshapeFeatures {
115 pub new_shape: Vec<usize>,
116}
117
118impl ReshapeFeatures {
119 pub fn new(new_shape: Vec<usize>) -> Self {
120 Self { new_shape }
121 }
122}
123
124impl Transform for ReshapeFeatures {
125 fn apply(&self, mut sample: Sample) -> Sample {
126 let old_count: usize = sample.feature_shape.iter().product();
128 let new_count: usize = self.new_shape.iter().product();
129 assert_eq!(
130 old_count, new_count,
131 "ReshapeFeatures: old shape {:?} ({}) != new shape {:?} ({})",
132 sample.feature_shape, old_count, self.new_shape, new_count,
133 );
134 sample.feature_shape = self.new_shape.clone();
135 sample
136 }
137}