// Source: yscv_model/transform.rs

1use std::sync::atomic::{AtomicU64, Ordering};
2
3use super::error::ModelError;
4use yscv_tensor::Tensor;
5
6/// Trait for deterministic tensor transforms (preprocessing).
7pub trait Transform: Send + Sync {
8    fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError>;
9}
10
11/// Chains multiple transforms sequentially.
12pub struct Compose {
13    transforms: Vec<Box<dyn Transform>>,
14}
15
16impl Default for Compose {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl Compose {
23    pub fn new() -> Self {
24        Self {
25            transforms: Vec::new(),
26        }
27    }
28
29    pub fn add<T: Transform + 'static>(mut self, t: T) -> Self {
30        self.transforms.push(Box::new(t));
31        self
32    }
33}
34
35impl Transform for Compose {
36    fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
37        let mut current = input.clone();
38        for t in &self.transforms {
39            current = t.apply(&current)?;
40        }
41        Ok(current)
42    }
43}
44
45/// Normalize channels: `(x - mean) / std`
46pub struct Normalize {
47    pub mean: Vec<f32>,
48    pub std: Vec<f32>,
49}
50
51impl Normalize {
52    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
53        Self { mean, std }
54    }
55}
56
57impl Transform for Normalize {
58    fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
59        // Apply per-channel normalization on last dim
60        let data = input.data();
61        let c = self.mean.len();
62        let mut out = data.to_vec();
63        for (i, val) in out.iter_mut().enumerate() {
64            let ch = i % c;
65            *val = (*val - self.mean[ch]) / self.std[ch];
66        }
67        Ok(Tensor::from_vec(input.shape().to_vec(), out)?)
68    }
69}
70
71/// Scale f32 values by a constant factor.
72pub struct ScaleValues {
73    pub factor: f32,
74}
75
76impl ScaleValues {
77    pub fn new(factor: f32) -> Self {
78        Self { factor }
79    }
80}
81
82impl Transform for ScaleValues {
83    fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
84        Ok(input.scale(self.factor))
85    }
86}
87
88/// Permute dimensions.
89pub struct PermuteDims {
90    pub order: Vec<usize>,
91}
92
93impl PermuteDims {
94    pub fn new(order: Vec<usize>) -> Self {
95        Self { order }
96    }
97}
98
99impl Transform for PermuteDims {
100    fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
101        Ok(input.permute(&self.order)?)
102    }
103}
104
105/// Resize image tensor to target height and width using bilinear interpolation.
106/// Input shape: `[H, W, C]`.
107pub struct Resize {
108    pub height: usize,
109    pub width: usize,
110}
111
112impl Resize {
113    pub fn new(height: usize, width: usize) -> Self {
114        Self { height, width }
115    }
116}
117
118impl Transform for Resize {
119    fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
120        let shape = input.shape();
121        if shape.len() != 3 {
122            return Err(ModelError::InvalidInputShape {
123                expected_features: 3,
124                got: shape.to_vec(),
125            });
126        }
127        let (in_h, in_w, c) = (shape[0], shape[1], shape[2]);
128        let data = input.data();
129        let out_h = self.height;
130        let out_w = self.width;
131        let mut out = vec![0.0f32; out_h * out_w * c];
132
133        for oh in 0..out_h {
134            for ow in 0..out_w {
135                // Map output pixel center to input coordinates
136                let sy = if out_h > 1 {
137                    oh as f32 * (in_h as f32 - 1.0) / (out_h as f32 - 1.0)
138                } else {
139                    (in_h as f32 - 1.0) / 2.0
140                };
141                let sx = if out_w > 1 {
142                    ow as f32 * (in_w as f32 - 1.0) / (out_w as f32 - 1.0)
143                } else {
144                    (in_w as f32 - 1.0) / 2.0
145                };
146
147                let y0 = sy.floor() as usize;
148                let x0 = sx.floor() as usize;
149                let y1 = (y0 + 1).min(in_h - 1);
150                let x1 = (x0 + 1).min(in_w - 1);
151                let fy = sy - sy.floor();
152                let fx = sx - sx.floor();
153
154                for ch in 0..c {
155                    let v00 = data[(y0 * in_w + x0) * c + ch];
156                    let v01 = data[(y0 * in_w + x1) * c + ch];
157                    let v10 = data[(y1 * in_w + x0) * c + ch];
158                    let v11 = data[(y1 * in_w + x1) * c + ch];
159                    let val = v00 * (1.0 - fy) * (1.0 - fx)
160                        + v01 * (1.0 - fy) * fx
161                        + v10 * fy * (1.0 - fx)
162                        + v11 * fy * fx;
163                    out[(oh * out_w + ow) * c + ch] = val;
164                }
165            }
166        }
167        Ok(Tensor::from_vec(vec![out_h, out_w, c], out)?)
168    }
169}
170
171/// Crop the center region of an image tensor.
172/// Input shape: `[H, W, C]`. Output shape: `[size, size, C]`.
173pub struct CenterCrop {
174    pub size: usize,
175}
176
177impl CenterCrop {
178    pub fn new(size: usize) -> Self {
179        Self { size }
180    }
181}
182
183impl Transform for CenterCrop {
184    fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
185        let shape = input.shape();
186        if shape.len() != 3 {
187            return Err(ModelError::InvalidInputShape {
188                expected_features: 3,
189                got: shape.to_vec(),
190            });
191        }
192        let (h, w) = (shape[0], shape[1]);
193        let start_h = (h.saturating_sub(self.size)) / 2;
194        let start_w = (w.saturating_sub(self.size)) / 2;
195        let cropped = input.narrow(0, start_h, self.size)?;
196        let cropped = cropped.narrow(1, start_w, self.size)?;
197        Ok(cropped)
198    }
199}
200
/// Randomly flip an image tensor horizontally with probability `p`.
///
/// Randomness comes from an xorshift64 PRNG seeded at construction, so a given seed
/// yields a reproducible sequence of draws. Input shape: `[H, W, C]`.
pub struct RandomHorizontalFlip {
    p: f32,
    seed: AtomicU64,
}

impl RandomHorizontalFlip {
    /// Non-zero state substituted for a zero seed (xorshift maps 0 to 0 forever).
    const ZERO_SEED_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Create a flip transform that fires with probability `p`.
    ///
    /// A `seed` of 0 is replaced by a fixed non-zero constant. Otherwise every draw would be
    /// 0.0 and the image would be flipped on every call.
    pub fn new(p: f32, seed: u64) -> Self {
        let seed = if seed == 0 {
            Self::ZERO_SEED_REPLACEMENT
        } else {
            seed
        };
        Self {
            p,
            seed: AtomicU64::new(seed),
        }
    }

    /// One xorshift64 step with shift triple 13/7/17.
    fn step(mut s: u64) -> u64 {
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        s
    }

    /// Advance the PRNG and return a uniform sample in `[0, 1)`.
    fn next_rand(&self) -> f32 {
        // `fetch_update` makes load-step-store a single atomic operation, so concurrent
        // callers never reuse a state. Relaxed suffices: no other memory is published here.
        let prev = self
            .seed
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| Some(Self::step(s)))
            .unwrap_or_else(|current| current);
        let s = Self::step(prev);
        // The top 24 bits fit exactly in an f32 mantissa, so the quotient is strictly < 1.0.
        (s >> 40) as f32 / (1u64 << 24) as f32
    }
}
227
228impl Transform for RandomHorizontalFlip {
229    fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
230        let shape = input.shape();
231        if shape.len() != 3 {
232            return Err(ModelError::InvalidInputShape {
233                expected_features: 3,
234                got: shape.to_vec(),
235            });
236        }
237        if self.next_rand() >= self.p {
238            return Ok(input.clone());
239        }
240        let (h, w, c) = (shape[0], shape[1], shape[2]);
241        let data = input.data();
242        let mut out = vec![0.0f32; h * w * c];
243        for row in 0..h {
244            for col in 0..w {
245                let src = (row * w + (w - 1 - col)) * c;
246                let dst = (row * w + col) * c;
247                out[dst..dst + c].copy_from_slice(&data[src..src + c]);
248            }
249        }
250        Ok(Tensor::from_vec(shape.to_vec(), out)?)
251    }
252}
253
/// Apply Gaussian blur to an image tensor.
/// Input shape: `[H, W, C]`.
pub struct GaussianBlur {
    pub kernel_size: usize,
    pub sigma: f32,
}

impl GaussianBlur {
    pub fn new(kernel_size: usize, sigma: f32) -> Self {
        Self { kernel_size, sigma }
    }

    /// Build the normalized `kernel_size x kernel_size` kernel (row-major, summing to 1).
    ///
    /// Taps are centred at `(kernel_size - 1) / 2` on each axis. The kernel is the outer
    /// product of two 1-D Gaussians. Each 1-D weight is taken relative to the tap nearest the
    /// centre, so the largest weight is exactly 1 and the sum can never underflow to 0.
    /// Without this, small `sigma` with even sizes would yield NaN.
    ///
    /// If `2 * sigma^2` is not positive (`sigma == 0` or NaN), this returns the `sigma -> 0`
    /// limit: all weight on the tap(s) nearest the centre. That is one tap for odd sizes and
    /// the central 2x2 for even sizes. A `kernel_size` of 0 yields an empty kernel.
    fn build_kernel(&self) -> Vec<f32> {
        let ks = self.kernel_size;
        if ks == 0 {
            return Vec::new();
        }
        let center = (ks as f32 - 1.0) / 2.0;
        // Distance from the centre to the nearest tap: 0.0 for odd sizes, 0.5 for even ones.
        let nearest = center - center.floor();
        let nearest_sq = nearest * nearest;
        let two_var = 2.0 * self.sigma * self.sigma;
        let weights: Vec<f32> = (0..ks)
            .map(|k| {
                let offset = k as f32 - center;
                let excess = offset * offset - nearest_sq;
                if two_var > 0.0 {
                    (-excess / two_var).exp()
                } else if excess == 0.0 {
                    1.0
                } else {
                    0.0
                }
            })
            .collect();
        let mut kernel = Vec::with_capacity(ks * ks);
        for &wy in &weights {
            for &wx in &weights {
                kernel.push(wy * wx);
            }
        }
        // At least one entry is exactly 1.0, so `sum >= 1` and the division is safe.
        let sum: f32 = kernel.iter().sum();
        for v in kernel.iter_mut() {
            *v /= sum;
        }
        kernel
    }
}
286
287impl Transform for GaussianBlur {
288    fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
289        let shape = input.shape();
290        if shape.len() != 3 {
291            return Err(ModelError::InvalidInputShape {
292                expected_features: 3,
293                got: shape.to_vec(),
294            });
295        }
296        let (h, w, c) = (shape[0], shape[1], shape[2]);
297        let ks = self.kernel_size;
298        let pad = ks / 2;
299        let kernel = self.build_kernel();
300        let data = input.data();
301        let mut out = vec![0.0f32; h * w * c];
302
303        for row in 0..h {
304            for col in 0..w {
305                for ch in 0..c {
306                    let mut acc = 0.0f32;
307                    for ky in 0..ks {
308                        for kx in 0..ks {
309                            let sy = row as isize + ky as isize - pad as isize;
310                            let sx = col as isize + kx as isize - pad as isize;
311                            // Clamp to border
312                            let sy = sy.max(0).min(h as isize - 1) as usize;
313                            let sx = sx.max(0).min(w as isize - 1) as usize;
314                            acc += data[(sy * w + sx) * c + ch] * kernel[ky * ks + kx];
315                        }
316                    }
317                    out[(row * w + col) * c + ch] = acc;
318                }
319            }
320        }
321        Ok(Tensor::from_vec(shape.to_vec(), out)?)
322    }
323}