ghostflow_data/
transforms.rs1use ghostflow_core::Tensor;
4use rand::Rng;
5
6pub trait Transform: Send + Sync {
8 fn apply(&self, tensor: &Tensor) -> Tensor;
9}
10
/// Per-channel standardisation: every element of channel `c` becomes
/// `(x - mean[c]) / std[c]`.
#[derive(Debug, Clone, PartialEq)]
pub struct Normalize {
    /// Value subtracted from each element, indexed by channel.
    mean: Vec<f32>,
    /// Divisor applied after centring, indexed by channel.
    std: Vec<f32>,
}

impl Normalize {
    /// Builds a normaliser from explicit per-channel statistics.
    ///
    /// The vectors are stored as given; no length or zero-divisor validation
    /// happens at construction time.
    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
        Self { mean, std }
    }

    /// Three-channel preset using the mean/std values commonly quoted for
    /// ImageNet.
    pub fn imagenet() -> Self {
        Self::new(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
    }
}
30
31impl Transform for Normalize {
32 fn apply(&self, tensor: &Tensor) -> Tensor {
33 let dims = tensor.dims();
34 let data = tensor.data_f32();
35
36 let channels = if dims.len() == 3 { dims[0] } else { dims[1] };
38 let spatial_size: usize = dims[dims.len()-2..].iter().product();
39
40 let mut result = data.clone();
41
42 if dims.len() == 3 {
43 for c in 0..channels {
45 let start = c * spatial_size;
46 let end = start + spatial_size;
47 for i in start..end {
48 result[i] = (result[i] - self.mean[c]) / self.std[c];
49 }
50 }
51 } else {
52 let batch_size = dims[0];
54 let batch_stride = channels * spatial_size;
55
56 for b in 0..batch_size {
57 for c in 0..channels {
58 let start = b * batch_stride + c * spatial_size;
59 let end = start + spatial_size;
60 for i in start..end {
61 result[i] = (result[i] - self.mean[c]) / self.std[c];
62 }
63 }
64 }
65 }
66
67 Tensor::from_slice(&result, dims).unwrap()
68 }
69}
70
/// Randomly mirrors a tensor along its width axis.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RandomHorizontalFlip {
    /// Probability threshold compared against a uniform sample to decide
    /// whether the flip is applied.
    p: f32,
}

impl RandomHorizontalFlip {
    /// Creates a flip transform that fires with probability `p`.
    ///
    /// `p` is stored as given; it is not range-checked.
    pub fn new(p: f32) -> Self {
        Self { p }
    }
}

impl Default for RandomHorizontalFlip {
    /// Flip with probability `0.5`.
    fn default() -> Self {
        Self::new(0.5)
    }
}
87
88impl Transform for RandomHorizontalFlip {
89 fn apply(&self, tensor: &Tensor) -> Tensor {
90 if rand::thread_rng().gen::<f32>() > self.p {
91 return tensor.clone();
92 }
93
94 let dims = tensor.dims();
95 let data = tensor.data_f32();
96
97 let (height, width) = if dims.len() == 3 {
99 (dims[1], dims[2])
100 } else {
101 (dims[2], dims[3])
102 };
103
104 let mut result = data.clone();
105
106 if dims.len() == 3 {
108 let channels = dims[0];
109 for c in 0..channels {
110 for h in 0..height {
111 for w in 0..width / 2 {
112 let idx1 = c * height * width + h * width + w;
113 let idx2 = c * height * width + h * width + (width - 1 - w);
114 result.swap(idx1, idx2);
115 }
116 }
117 }
118 }
119
120 Tensor::from_slice(&result, dims).unwrap()
121 }
122}
123
/// Extracts a randomly positioned window of fixed size from an image tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RandomCrop {
    /// Output window as `(height, width)`.
    size: (usize, usize),
}

impl RandomCrop {
    /// Creates a crop producing windows of `height` x `width` elements.
    pub fn new(height: usize, width: usize) -> Self {
        Self {
            size: (height, width),
        }
    }
}
134
135impl Transform for RandomCrop {
136 fn apply(&self, tensor: &Tensor) -> Tensor {
137 let dims = tensor.dims();
138 let data = tensor.data_f32();
139
140 let (channels, in_h, in_w) = if dims.len() == 3 {
141 (dims[0], dims[1], dims[2])
142 } else {
143 panic!("RandomCrop expects 3D tensor [C, H, W]");
144 };
145
146 let (out_h, out_w) = self.size;
147
148 if in_h < out_h || in_w < out_w {
149 panic!("Crop size larger than input");
150 }
151
152 let mut rng = rand::thread_rng();
153 let top = rng.gen_range(0..=in_h - out_h);
154 let left = rng.gen_range(0..=in_w - out_w);
155
156 let mut result = Vec::with_capacity(channels * out_h * out_w);
157
158 for c in 0..channels {
159 for h in 0..out_h {
160 for w in 0..out_w {
161 let src_idx = c * in_h * in_w + (top + h) * in_w + (left + w);
162 result.push(data[src_idx]);
163 }
164 }
165 }
166
167 Tensor::from_slice(&result, &[channels, out_h, out_w]).unwrap()
168 }
169}
170
171pub struct Compose {
173 transforms: Vec<Box<dyn Transform>>,
174}
175
176impl Compose {
177 pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
178 Compose { transforms }
179 }
180}
181
182impl Transform for Compose {
183 fn apply(&self, tensor: &Tensor) -> Tensor {
184 let mut result = tensor.clone();
185 for t in &self.transforms {
186 result = t.apply(&result);
187 }
188 result
189 }
190}
191
/// Marker transform with no configuration.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ToTensor;
194
195impl Transform for ToTensor {
196 fn apply(&self, tensor: &Tensor) -> Tensor {
197 tensor.clone()
198 }
199}
200
/// Randomly overwrites a rectangular region of an image tensor with a
/// constant value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RandomErasing {
    /// Probability threshold compared against a uniform sample to decide
    /// whether erasing is attempted.
    p: f32,
    /// `(min, max)` fraction of the image area (`H * W`) the rectangle
    /// should cover.
    scale: (f32, f32),
    /// `(min, max)` aspect ratio of the rectangle, expressed as
    /// height / width.
    ratio: (f32, f32),
    /// Constant written into the erased region.
    value: f32,
}

impl RandomErasing {
    /// Creates an eraser that fires with probability `p`, using an area
    /// fraction range of `0.02..0.33`, an aspect-ratio range of `0.3..3.3`,
    /// and a fill value of `0.0`.
    pub fn new(p: f32) -> Self {
        Self {
            p,
            scale: (0.02, 0.33),
            ratio: (0.3, 3.3),
            value: 0.0,
        }
    }
}

impl Default for RandomErasing {
    /// Erase with probability `0.5`.
    fn default() -> Self {
        Self::new(0.5)
    }
}
225
226impl Transform for RandomErasing {
227 fn apply(&self, tensor: &Tensor) -> Tensor {
228 if rand::thread_rng().gen::<f32>() > self.p {
229 return tensor.clone();
230 }
231
232 let dims = tensor.dims();
233 let mut data = tensor.data_f32();
234
235 let (channels, height, width) = if dims.len() == 3 {
236 (dims[0], dims[1], dims[2])
237 } else {
238 return tensor.clone();
239 };
240
241 let area = (height * width) as f32;
242 let mut rng = rand::thread_rng();
243
244 for _ in 0..10 {
245 let target_area = rng.gen_range(self.scale.0..self.scale.1) * area;
246 let aspect_ratio = rng.gen_range(self.ratio.0..self.ratio.1);
247
248 let h = (target_area * aspect_ratio).sqrt() as usize;
249 let w = (target_area / aspect_ratio).sqrt() as usize;
250
251 if h < height && w < width {
252 let top = rng.gen_range(0..height - h);
253 let left = rng.gen_range(0..width - w);
254
255 for c in 0..channels {
256 for y in top..top + h {
257 for x in left..left + w {
258 let idx = c * height * width + y * width + x;
259 data[idx] = self.value;
260 }
261 }
262 }
263 break;
264 }
265 }
266
267 Tensor::from_slice(&data, dims).unwrap()
268 }
269}