1use super::tensor::Tensor;
4use super::model::{Model, Sequential, DenseLayer, Layer};
5
/// Named parameter presets shared by [`StyleTransfer::from_preset`] and
/// [`AsciiStyleTransfer`].
///
/// `Eq` and `Hash` are derived (the enum has no data) so presets can be used as
/// map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StylePreset {
    Pencil,
    Neon,
    Retro,
    Gothic,
    Watercolor,
}

impl StylePreset {
    /// Every preset, in declaration order. Handy for iterating over all presets.
    pub const ALL: [StylePreset; 5] = [
        StylePreset::Pencil,
        StylePreset::Neon,
        StylePreset::Retro,
        StylePreset::Gothic,
        StylePreset::Watercolor,
    ];

    /// Returns `(content_weight, style_weight, iterations)` for this preset.
    ///
    /// This is the tuple order destructured by `StyleTransfer::from_preset`.
    /// Every preset currently uses a content weight of `1.0`; only the style
    /// weight and iteration count vary.
    pub fn params(&self) -> (f32, f32, usize) {
        match self {
            StylePreset::Pencil => (1.0, 0.001, 50),
            StylePreset::Neon => (1.0, 0.01, 80),
            StylePreset::Retro => (1.0, 0.005, 60),
            StylePreset::Gothic => (1.0, 0.008, 70),
            StylePreset::Watercolor => (1.0, 0.003, 40),
        }
    }
}
28
/// Iterative style transfer driven by two feature-extraction models.
///
/// `transfer` repeatedly runs both models on the image being generated and nudges it
/// toward the content input and the style input; `total_loss` scores a candidate.
pub struct StyleTransfer {
    /// Model whose outputs on the generated image and on the content input are compared.
    pub content_model: Model,
    /// Model whose outputs are turned into Gram matrices for the style comparison.
    pub style_model: Model,
    /// Number of update steps performed by `transfer`.
    pub iterations: usize,
    /// Weight of the content term in `total_loss` and in each `transfer` update.
    pub content_weight: f32,
    /// Weight of the style term in `total_loss` and in each `transfer` update.
    pub style_weight: f32,
    /// Step size multiplying each `transfer` update.
    pub learning_rate: f32,
}
38
39impl StyleTransfer {
40 pub fn new(content_model: Model, style_model: Model) -> Self {
41 Self {
42 content_model,
43 style_model,
44 iterations: 100,
45 content_weight: 1.0,
46 style_weight: 0.01,
47 learning_rate: 0.01,
48 }
49 }
50
51 pub fn from_preset(preset: StylePreset) -> Self {
53 let (cw, sw, iters) = preset.params();
54 let content_model = Sequential::new("content_extractor")
56 .dense(64, 32)
57 .relu()
58 .build();
59 let style_model = Sequential::new("style_extractor")
60 .dense(64, 32)
61 .relu()
62 .build();
63 Self {
64 content_model,
65 style_model,
66 iterations: iters,
67 content_weight: cw,
68 style_weight: sw,
69 learning_rate: 0.01,
70 }
71 }
72
73 pub fn gram_matrix(features: &Tensor) -> Tensor {
76 let f = if features.shape.len() == 1 {
77 features.reshape(vec![1, features.data.len()])
78 } else if features.shape.len() == 2 {
79 features.clone()
80 } else {
81 let c = features.shape[0];
83 let spatial: usize = features.shape[1..].iter().product();
84 features.reshape(vec![c, spatial])
85 };
86 let ft = f.transpose();
87 Tensor::matmul(&f, &ft)
88 }
89
90 pub fn content_loss(generated: &Tensor, target: &Tensor) -> f32 {
92 assert_eq!(generated.data.len(), target.data.len());
93 let n = generated.data.len() as f32;
94 generated.data.iter().zip(&target.data)
95 .map(|(g, t)| (g - t) * (g - t))
96 .sum::<f32>() / n
97 }
98
99 pub fn style_loss(generated_gram: &Tensor, target_gram: &Tensor) -> f32 {
101 Self::content_loss(generated_gram, target_gram)
102 }
103
104 pub fn total_loss(
106 &self,
107 gen_content_features: &Tensor,
108 target_content_features: &Tensor,
109 gen_style_features: &Tensor,
110 target_style_features: &Tensor,
111 ) -> f32 {
112 let cl = Self::content_loss(gen_content_features, target_content_features);
113 let gen_gram = Self::gram_matrix(gen_style_features);
114 let target_gram = Self::gram_matrix(target_style_features);
115 let sl = Self::style_loss(&gen_gram, &target_gram);
116 self.content_weight * cl + self.style_weight * sl
117 }
118
119 pub fn transfer(&self, content: &Tensor, style: &Tensor) -> Tensor {
123 assert_eq!(content.shape, style.shape);
124 let target_content_feat = self.content_model.forward(content);
126 let target_style_feat = self.style_model.forward(style);
127 let target_style_gram = Self::gram_matrix(&target_style_feat);
128
129 let mut generated = content.clone();
131 let lr = self.learning_rate;
132
133 for _iter in 0..self.iterations {
134 let gen_content_feat = self.content_model.forward(&generated);
135 let gen_style_feat = self.style_model.forward(&generated);
136 let gen_style_gram = Self::gram_matrix(&gen_style_feat);
137
138 let n = generated.data.len();
141 let eps = 1e-4f32;
142
143 let content_diff = gen_content_feat.sub(&target_content_feat);
146 let style_diff = gen_style_gram.sub(&target_style_gram);
147
148 let content_grad_scale = self.content_weight * 2.0 / n as f32;
151 let style_grad_scale = self.style_weight * 2.0 / gen_style_gram.data.len().max(1) as f32;
152
153 let content_signal = content_diff.mean();
155 let style_signal = style_diff.mean();
156 let total_signal = content_grad_scale * content_signal + style_grad_scale * style_signal;
157
158 for i in 0..n {
159 let toward_content = (content.data[i] - generated.data[i]) * 0.1;
161 let toward_style = (style.data[i] - generated.data[i]) * 0.05;
162 generated.data[i] += lr * (toward_content * self.content_weight
163 + toward_style * self.style_weight
164 - total_signal * 0.01);
165 }
166 }
167 generated
168 }
169}
170
171pub struct AsciiStyleTransfer {
173 pub preset: StylePreset,
174}
175
176impl AsciiStyleTransfer {
177 pub fn new(preset: StylePreset) -> Self {
178 Self { preset }
179 }
180
181 pub fn apply(&self, values: &Tensor) -> Tensor {
184 let data: Vec<f32> = match self.preset {
185 StylePreset::Pencil => {
186 values.data.iter().map(|&v| {
188 if v > 0.5 { (v * 1.5).min(1.0) } else { (v * 0.3).max(0.0) }
189 }).collect()
190 }
191 StylePreset::Neon => {
192 values.data.iter().map(|&v| {
194 let boosted = v * 2.0;
195 (1.0 / (1.0 + (-10.0 * (boosted - 0.5)).exp())).min(1.0)
196 }).collect()
197 }
198 StylePreset::Retro => {
199 values.data.iter().map(|&v| {
201 ((v * 4.0).floor() / 4.0).clamp(0.0, 1.0)
202 }).collect()
203 }
204 StylePreset::Gothic => {
205 values.data.iter().map(|&v| {
207 (v * v * 1.2).min(1.0)
208 }).collect()
209 }
210 StylePreset::Watercolor => {
211 let n = values.data.len();
213 let mut out = vec![0.0f32; n];
214 for i in 0..n {
215 let prev = if i > 0 { values.data[i - 1] } else { values.data[i] };
216 let next = if i + 1 < n { values.data[i + 1] } else { values.data[i] };
217 out[i] = (prev * 0.25 + values.data[i] * 0.5 + next * 0.25).clamp(0.0, 1.0);
218 }
219 out
220 }
221 };
222 Tensor { shape: values.shape.clone(), data }
223 }
224
225 pub fn tint_colors(&self, colors: &Tensor) -> Tensor {
227 assert_eq!(colors.shape.len(), 2);
228 assert_eq!(colors.shape[1], 4);
229 let n = colors.shape[0];
230 let mut data = colors.data.clone();
231 let (r_mul, g_mul, b_mul) = match self.preset {
232 StylePreset::Pencil => (0.9, 0.9, 0.9),
233 StylePreset::Neon => (1.2, 0.3, 1.5),
234 StylePreset::Retro => (1.1, 0.8, 0.5),
235 StylePreset::Gothic => (0.3, 0.1, 0.3),
236 StylePreset::Watercolor => (0.8, 0.9, 1.1),
237 };
238 for i in 0..n {
239 let base = i * 4;
240 data[base] = (data[base] * r_mul).clamp(0.0, 1.0);
241 data[base + 1] = (data[base + 1] * g_mul).clamp(0.0, 1.0);
242 data[base + 2] = (data[base + 2] * b_mul).clamp(0.0, 1.0);
243 }
245 Tensor { shape: colors.shape.clone(), data }
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gram_matrix() {
        // [[1,2,3],[4,5,6]] · its transpose = [[14, 32], [32, 77]].
        let f = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let g = StyleTransfer::gram_matrix(&f);
        assert_eq!(g.shape, vec![2, 2]);
        assert_eq!(g.get(&[0, 0]), 14.0);
        assert_eq!(g.get(&[0, 1]), 32.0);
        assert_eq!(g.get(&[1, 0]), 32.0);
        assert_eq!(g.get(&[1, 1]), 77.0);
    }

    #[test]
    fn test_content_loss() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        assert_eq!(StyleTransfer::content_loss(&a, &b), 0.0);

        // Every element differs by 1, so the mean squared error is 1.
        let c = Tensor::from_vec(vec![2.0, 3.0, 4.0], vec![3]);
        let loss = StyleTransfer::content_loss(&a, &c);
        assert!((loss - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_style_loss() {
        let a = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        let b = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        assert_eq!(StyleTransfer::style_loss(&a, &b), 0.0);

        // Identity vs zeros: squared diffs 1, 0, 0, 1 -> mean 0.5.
        let zeros = Tensor::from_vec(vec![0.0; 4], vec![2, 2]);
        let loss = StyleTransfer::style_loss(&a, &zeros);
        assert!((loss - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_transfer_preserves_shape() {
        let st = StyleTransfer::from_preset(StylePreset::Pencil);
        let content = Tensor::rand(vec![1, 64], 42);
        let style = Tensor::rand(vec![1, 64], 99);
        let result = st.transfer(&content, &style);
        assert_eq!(result.shape, content.shape);
    }

    #[test]
    fn test_ascii_style_pencil() {
        let ast = AsciiStyleTransfer::new(StylePreset::Pencil);
        let vals = Tensor::from_vec(vec![0.1, 0.5, 0.9], vec![3]);
        let result = ast.apply(&vals);
        assert_eq!(result.shape, vec![3]);
        // Dark values get darker, bright values get brighter (or saturate at 1.0).
        assert!(result.data[0] < vals.data[0]);
        assert!(result.data[2] > vals.data[2] || (result.data[2] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_ascii_style_retro_quantizes() {
        let ast = AsciiStyleTransfer::new(StylePreset::Retro);
        let vals = Tensor::from_vec(vec![0.13, 0.37, 0.62, 0.88], vec![4]);
        let result = ast.apply(&vals);
        for &v in &result.data {
            // Every output must be an exact multiple of 0.25.
            let remainder = (v * 4.0) - (v * 4.0).floor();
            assert!(remainder.abs() < 1e-5);
        }
        let expected: [f32; 4] = [0.0, 0.25, 0.5, 0.75];
        for (&got, &want) in result.data.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_ascii_style_watercolor_smooths() {
        let ast = AsciiStyleTransfer::new(StylePreset::Watercolor);
        let vals = Tensor::from_vec(vec![0.0, 1.0, 0.0], vec![3]);
        let result = ast.apply(&vals);
        assert_eq!(result.shape, vec![3]);
        // Edges replicate themselves as the missing neighbour.
        let expected: [f32; 3] = [0.25, 0.5, 0.25];
        for (&got, &want) in result.data.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_tint_colors() {
        let ast = AsciiStyleTransfer::new(StylePreset::Neon);
        let colors = Tensor::from_vec(vec![0.5, 0.5, 0.5, 1.0], vec![1, 4]);
        let tinted = ast.tint_colors(&colors);
        assert_eq!(tinted.shape, vec![1, 4]);
        assert!(tinted.data[0] > 0.5);
        assert!(tinted.data[1] < 0.5);
        assert!(tinted.data[2] > 0.5);
        // The fourth channel is never tinted.
        assert_eq!(tinted.data[3], 1.0);
    }

    #[test]
    fn test_all_presets() {
        let presets = [
            StylePreset::Pencil,
            StylePreset::Neon,
            StylePreset::Retro,
            StylePreset::Gothic,
            StylePreset::Watercolor,
        ];
        for preset in &presets {
            let ast = AsciiStyleTransfer::new(*preset);
            let vals = Tensor::from_vec(vec![0.3, 0.6, 0.9], vec![3]);
            let result = ast.apply(&vals);
            assert_eq!(result.shape, vec![3]);
            for &v in &result.data {
                assert!(
                    v >= 0.0 && v <= 1.0,
                    "preset {:?} produced out-of-range value {v}",
                    preset
                );
            }
        }
    }
}