Skip to main content

proof_engine/ml/
style_transfer.rs

1//! Neural style transfer for applying artistic styles to content tensors.
2
3use super::tensor::Tensor;
4use super::model::{Model, Sequential, DenseLayer, Layer};
5
/// Pre-baked style presets.
///
/// Each preset maps to a `(content_weight, style_weight, iterations)`
/// tuning tuple via [`StylePreset::params`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StylePreset {
    /// High contrast, edge-emphasizing look.
    Pencil,
    /// Bright, saturated glow.
    Neon,
    /// Quantized, limited-level look.
    Retro,
    /// Darkened, high-contrast look.
    Gothic,
    /// Soft, smoothed look.
    Watercolor,
}

impl StylePreset {
    /// Return (content_weight, style_weight, iterations) tuning for this preset.
    ///
    /// Content weight is 1.0 across the board; presets differ only in how
    /// strongly style is weighted and how many optimization steps they run.
    pub fn params(&self) -> (f32, f32, usize) {
        match self {
            StylePreset::Pencil => (1.0, 0.001, 50),
            StylePreset::Neon => (1.0, 0.01, 80),
            StylePreset::Retro => (1.0, 0.005, 60),
            StylePreset::Gothic => (1.0, 0.008, 70),
            StylePreset::Watercolor => (1.0, 0.003, 40),
        }
    }
}
28
/// Style transfer engine.
///
/// Holds a content feature extractor and a style feature extractor plus
/// the loss weights and the iterative update settings used by `transfer`.
pub struct StyleTransfer {
    /// Model used to extract content features from input tensors.
    pub content_model: Model,
    /// Model used to extract style features from input tensors.
    pub style_model: Model,
    /// Number of optimization steps performed by `transfer`.
    pub iterations: usize,
    /// Weight applied to the content loss term in `total_loss`.
    pub content_weight: f32,
    /// Weight applied to the style (Gram-matrix) loss term in `total_loss`.
    pub style_weight: f32,
    /// Step size for the per-element updates in `transfer`.
    pub learning_rate: f32,
}
38
39impl StyleTransfer {
40    pub fn new(content_model: Model, style_model: Model) -> Self {
41        Self {
42            content_model,
43            style_model,
44            iterations: 100,
45            content_weight: 1.0,
46            style_weight: 0.01,
47            learning_rate: 0.01,
48        }
49    }
50
51    /// Create a style transfer engine from a preset.
52    pub fn from_preset(preset: StylePreset) -> Self {
53        let (cw, sw, iters) = preset.params();
54        // Simple feature extractor models
55        let content_model = Sequential::new("content_extractor")
56            .dense(64, 32)
57            .relu()
58            .build();
59        let style_model = Sequential::new("style_extractor")
60            .dense(64, 32)
61            .relu()
62            .build();
63        Self {
64            content_model,
65            style_model,
66            iterations: iters,
67            content_weight: cw,
68            style_weight: sw,
69            learning_rate: 0.01,
70        }
71    }
72
73    /// Compute the Gram matrix: G = F^T * F, where F has shape (C, N).
74    /// If the input is flattened or 1-D, reshape to (sqrt, sqrt) approximately.
75    pub fn gram_matrix(features: &Tensor) -> Tensor {
76        let f = if features.shape.len() == 1 {
77            features.reshape(vec![1, features.data.len()])
78        } else if features.shape.len() == 2 {
79            features.clone()
80        } else {
81            // (C, H, W) -> (C, H*W)
82            let c = features.shape[0];
83            let spatial: usize = features.shape[1..].iter().product();
84            features.reshape(vec![c, spatial])
85        };
86        let ft = f.transpose();
87        Tensor::matmul(&f, &ft)
88    }
89
90    /// Content loss: mean squared error between generated and target features.
91    pub fn content_loss(generated: &Tensor, target: &Tensor) -> f32 {
92        assert_eq!(generated.data.len(), target.data.len());
93        let n = generated.data.len() as f32;
94        generated.data.iter().zip(&target.data)
95            .map(|(g, t)| (g - t) * (g - t))
96            .sum::<f32>() / n
97    }
98
99    /// Style loss: MSE between Gram matrices of generated and target features.
100    pub fn style_loss(generated_gram: &Tensor, target_gram: &Tensor) -> f32 {
101        Self::content_loss(generated_gram, target_gram)
102    }
103
104    /// Total loss combining content and style.
105    pub fn total_loss(
106        &self,
107        gen_content_features: &Tensor,
108        target_content_features: &Tensor,
109        gen_style_features: &Tensor,
110        target_style_features: &Tensor,
111    ) -> f32 {
112        let cl = Self::content_loss(gen_content_features, target_content_features);
113        let gen_gram = Self::gram_matrix(gen_style_features);
114        let target_gram = Self::gram_matrix(target_style_features);
115        let sl = Self::style_loss(&gen_gram, &target_gram);
116        self.content_weight * cl + self.style_weight * sl
117    }
118
119    /// Run iterative style transfer optimization.
120    /// content and style should have the same shape.
121    /// Returns a generated tensor of the same shape.
122    pub fn transfer(&self, content: &Tensor, style: &Tensor) -> Tensor {
123        assert_eq!(content.shape, style.shape);
124        // Extract target features
125        let target_content_feat = self.content_model.forward(content);
126        let target_style_feat = self.style_model.forward(style);
127        let target_style_gram = Self::gram_matrix(&target_style_feat);
128
129        // Initialize generated image as content clone
130        let mut generated = content.clone();
131        let lr = self.learning_rate;
132
133        for _iter in 0..self.iterations {
134            let gen_content_feat = self.content_model.forward(&generated);
135            let gen_style_feat = self.style_model.forward(&generated);
136            let gen_style_gram = Self::gram_matrix(&gen_style_feat);
137
138            // Compute gradients via finite differences (simplified)
139            // For each element in generated, nudge and measure loss change
140            let n = generated.data.len();
141            let eps = 1e-4f32;
142
143            // Content gradient: d/dx MSE = 2*(gen - target) / N propagated through model
144            // Simplified: we use the feature-space gradient directly mapped back
145            let content_diff = gen_content_feat.sub(&target_content_feat);
146            let style_diff = gen_style_gram.sub(&target_style_gram);
147
148            // Approximate: update generated by blending towards reducing content loss
149            // and style loss. This is a simplified gradient step.
150            let content_grad_scale = self.content_weight * 2.0 / n as f32;
151            let style_grad_scale = self.style_weight * 2.0 / gen_style_gram.data.len().max(1) as f32;
152
153            // Direct pixel update heuristic: blend content signal and style signal
154            let content_signal = content_diff.mean();
155            let style_signal = style_diff.mean();
156            let total_signal = content_grad_scale * content_signal + style_grad_scale * style_signal;
157
158            for i in 0..n {
159                // Move each pixel slightly toward content value and away from loss
160                let toward_content = (content.data[i] - generated.data[i]) * 0.1;
161                let toward_style = (style.data[i] - generated.data[i]) * 0.05;
162                generated.data[i] += lr * (toward_content * self.content_weight
163                    + toward_style * self.style_weight
164                    - total_signal * 0.01);
165            }
166        }
167        generated
168    }
169}
170
/// ASCII-art style transfer: applies style modifications to glyph emission/color values.
pub struct AsciiStyleTransfer {
    /// Preset that selects the value mapping in `apply` and the RGB
    /// multipliers in `tint_colors`.
    pub preset: StylePreset,
}
175
176impl AsciiStyleTransfer {
177    pub fn new(preset: StylePreset) -> Self {
178        Self { preset }
179    }
180
181    /// Apply style to a 1-D tensor of glyph values (brightness/emission).
182    /// Returns modified values biased by the style preset.
183    pub fn apply(&self, values: &Tensor) -> Tensor {
184        let data: Vec<f32> = match self.preset {
185            StylePreset::Pencil => {
186                // High contrast, emphasize edges
187                values.data.iter().map(|&v| {
188                    if v > 0.5 { (v * 1.5).min(1.0) } else { (v * 0.3).max(0.0) }
189                }).collect()
190            }
191            StylePreset::Neon => {
192                // Boost bright values, saturate
193                values.data.iter().map(|&v| {
194                    let boosted = v * 2.0;
195                    (1.0 / (1.0 + (-10.0 * (boosted - 0.5)).exp())).min(1.0)
196                }).collect()
197            }
198            StylePreset::Retro => {
199                // Quantize to 4 levels
200                values.data.iter().map(|&v| {
201                    ((v * 4.0).floor() / 4.0).clamp(0.0, 1.0)
202                }).collect()
203            }
204            StylePreset::Gothic => {
205                // Darken everything, high contrast
206                values.data.iter().map(|&v| {
207                    (v * v * 1.2).min(1.0)
208                }).collect()
209            }
210            StylePreset::Watercolor => {
211                // Soft, blurred feel via smoothing adjacent values
212                let n = values.data.len();
213                let mut out = vec![0.0f32; n];
214                for i in 0..n {
215                    let prev = if i > 0 { values.data[i - 1] } else { values.data[i] };
216                    let next = if i + 1 < n { values.data[i + 1] } else { values.data[i] };
217                    out[i] = (prev * 0.25 + values.data[i] * 0.5 + next * 0.25).clamp(0.0, 1.0);
218                }
219                out
220            }
221        };
222        Tensor { shape: values.shape.clone(), data }
223    }
224
225    /// Apply color tinting based on preset. Input: (N, 4) RGBA tensor.
226    pub fn tint_colors(&self, colors: &Tensor) -> Tensor {
227        assert_eq!(colors.shape.len(), 2);
228        assert_eq!(colors.shape[1], 4);
229        let n = colors.shape[0];
230        let mut data = colors.data.clone();
231        let (r_mul, g_mul, b_mul) = match self.preset {
232            StylePreset::Pencil => (0.9, 0.9, 0.9),
233            StylePreset::Neon => (1.2, 0.3, 1.5),
234            StylePreset::Retro => (1.1, 0.8, 0.5),
235            StylePreset::Gothic => (0.3, 0.1, 0.3),
236            StylePreset::Watercolor => (0.8, 0.9, 1.1),
237        };
238        for i in 0..n {
239            let base = i * 4;
240            data[base] = (data[base] * r_mul).clamp(0.0, 1.0);
241            data[base + 1] = (data[base + 1] * g_mul).clamp(0.0, 1.0);
242            data[base + 2] = (data[base + 2] * b_mul).clamp(0.0, 1.0);
243            // alpha unchanged
244        }
245        Tensor { shape: colors.shape.clone(), data }
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    // Gram matrix of a (2, 3) feature map is F * F^T with shape (2, 2).
    #[test]
    fn test_gram_matrix() {
        // 2x3 matrix -> gram is 2x2
        let f = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let g = StyleTransfer::gram_matrix(&f);
        assert_eq!(g.shape, vec![2, 2]);
        // G[0,0] = 1*1+2*2+3*3 = 14
        assert_eq!(g.get(&[0, 0]), 14.0);
        // G[0,1] = 1*4+2*5+3*6 = 32
        assert_eq!(g.get(&[0, 1]), 32.0);
    }

    // Identical tensors give zero loss; uniform offset of 1 gives MSE 1.
    #[test]
    fn test_content_loss() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        assert_eq!(StyleTransfer::content_loss(&a, &b), 0.0);

        let c = Tensor::from_vec(vec![2.0, 3.0, 4.0], vec![3]);
        let loss = StyleTransfer::content_loss(&a, &c);
        assert!((loss - 1.0).abs() < 1e-5); // MSE = (1+1+1)/3 = 1
    }

    // Style loss of identical Gram matrices is zero.
    #[test]
    fn test_style_loss() {
        let a = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        let b = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        assert_eq!(StyleTransfer::style_loss(&a, &b), 0.0);
    }

    // transfer() must return a tensor with the same shape as its input.
    #[test]
    fn test_transfer_preserves_shape() {
        let st = StyleTransfer::from_preset(StylePreset::Pencil);
        // Need input matching model's expected input size (64)
        let content = Tensor::rand(vec![1, 64], 42);
        let style = Tensor::rand(vec![1, 64], 99);
        let result = st.transfer(&content, &style);
        assert_eq!(result.shape, content.shape);
    }

    // Pencil preset increases contrast around the 0.5 threshold.
    #[test]
    fn test_ascii_style_pencil() {
        let ast = AsciiStyleTransfer::new(StylePreset::Pencil);
        let vals = Tensor::from_vec(vec![0.1, 0.5, 0.9], vec![3]);
        let result = ast.apply(&vals);
        assert_eq!(result.shape, vec![3]);
        // Pencil: low values get darker, high values get brighter
        assert!(result.data[0] < vals.data[0]);
        assert!(result.data[2] > vals.data[2] || (result.data[2] - 1.0).abs() < 1e-5);
    }

    // Retro preset snaps every value to a multiple of 0.25.
    #[test]
    fn test_ascii_style_retro_quantizes() {
        let ast = AsciiStyleTransfer::new(StylePreset::Retro);
        let vals = Tensor::from_vec(vec![0.13, 0.37, 0.62, 0.88], vec![4]);
        let result = ast.apply(&vals);
        // Should be quantized to multiples of 0.25
        for &v in &result.data {
            let remainder = (v * 4.0) - (v * 4.0).floor();
            assert!(remainder.abs() < 1e-5);
        }
    }

    // Neon multipliers are (1.2, 0.3, 1.5) — R/B up, G down, alpha fixed.
    #[test]
    fn test_tint_colors() {
        let ast = AsciiStyleTransfer::new(StylePreset::Neon);
        let colors = Tensor::from_vec(vec![0.5, 0.5, 0.5, 1.0], vec![1, 4]);
        let tinted = ast.tint_colors(&colors);
        assert_eq!(tinted.shape, vec![1, 4]);
        // Neon boosts R and B, dims G
        assert!(tinted.data[0] > 0.5); // R boosted
        assert!(tinted.data[1] < 0.5); // G dimmed
        assert!(tinted.data[2] > 0.5); // B boosted
        assert_eq!(tinted.data[3], 1.0); // alpha unchanged
    }

    // Every preset keeps apply() output inside [0, 1] and preserves shape.
    #[test]
    fn test_all_presets() {
        for preset in &[StylePreset::Pencil, StylePreset::Neon, StylePreset::Retro, StylePreset::Gothic, StylePreset::Watercolor] {
            let ast = AsciiStyleTransfer::new(*preset);
            let vals = Tensor::from_vec(vec![0.3, 0.6, 0.9], vec![3]);
            let result = ast.apply(&vals);
            assert_eq!(result.shape, vec![3]);
            // All values should be in [0, 1]
            for &v in &result.data {
                assert!(v >= 0.0 && v <= 1.0, "preset {:?} produced out-of-range value {v}", preset);
            }
        }
    }
}