use super::tensor::Tensor;
use super::model::{Model, Sequential, DenseLayer, Layer};
/// Named artistic preset shared by `StyleTransfer::from_preset` and `AsciiStyleTransfer`.
///
/// `Eq` and `Hash` are derived as well so presets can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StylePreset {
    Pencil,
    Neon,
    Retro,
    Gothic,
    Watercolor,
}

impl StylePreset {
    /// Every preset, in declaration order.
    pub const ALL: [StylePreset; 5] = [
        StylePreset::Pencil,
        StylePreset::Neon,
        StylePreset::Retro,
        StylePreset::Gothic,
        StylePreset::Watercolor,
    ];

    /// Optimisation hyper-parameters for this preset.
    ///
    /// Returns `(content_weight, style_weight, iterations)`. Every preset uses a
    /// content weight of `1.0`; they differ only in style weight and iteration count.
    pub fn params(&self) -> (f32, f32, usize) {
        match self {
            StylePreset::Pencil => (1.0, 0.001, 50),
            StylePreset::Neon => (1.0, 0.01, 80),
            StylePreset::Retro => (1.0, 0.005, 60),
            StylePreset::Gothic => (1.0, 0.008, 70),
            StylePreset::Watercolor => (1.0, 0.003, 40),
        }
    }
}
/// Iterative style transfer driven by two feature-extractor models.
///
/// `transfer` repeatedly updates a copy of the content tensor using the
/// weights and step size stored here; `total_loss` reports the weighted sum of
/// the content and style losses.
pub struct StyleTransfer {
/// Model whose outputs are compared in the content term.
pub content_model: Model,
/// Model whose outputs' Gram matrices are compared in the style term.
pub style_model: Model,
/// Number of update steps performed by `transfer`.
pub iterations: usize,
/// Weight applied to the content term.
pub content_weight: f32,
/// Weight applied to the style term.
pub style_weight: f32,
/// Step size used for every update in `transfer`.
pub learning_rate: f32,
}
impl StyleTransfer {
    /// Creates a transfer with caller-supplied extractors and default
    /// hyper-parameters: 100 iterations, content weight 1.0, style weight 0.01
    /// and learning rate 0.01.
    pub fn new(content_model: Model, style_model: Model) -> Self {
        Self {
            content_model,
            style_model,
            iterations: 100,
            content_weight: 1.0,
            style_weight: 0.01,
            learning_rate: 0.01,
        }
    }

    /// Creates a transfer configured from `preset`.
    ///
    /// Both extractors are `dense(64, 32)` + ReLU networks, so inputs given to
    /// `transfer` must be accepted by such a model (NOTE(review): the exact
    /// input-shape contract lives in `Model::forward`, not visible here). The
    /// preset supplies content weight, style weight and iteration count; every
    /// other field keeps the `new` default.
    pub fn from_preset(preset: StylePreset) -> Self {
        let (content_weight, style_weight, iterations) = preset.params();
        Self {
            iterations,
            content_weight,
            style_weight,
            ..Self::new(
                Self::build_extractor("content_extractor"),
                Self::build_extractor("style_extractor"),
            )
        }
    }

    /// Builds the `dense(64, 32)` + ReLU extractor used by `from_preset`.
    fn build_extractor(name: &'static str) -> Model {
        Sequential::new(name).dense(64, 32).relu().build()
    }

    /// Returns the Gram matrix `F · Fᵀ` of `features`.
    ///
    /// Rank 0 and 1 inputs are treated as a single row (`[1, len]`), rank 2 is
    /// used as-is, and higher ranks are flattened to `[shape[0], product(rest)]`.
    /// The result is not normalised.
    pub fn gram_matrix(features: &Tensor) -> Tensor {
        let f = match features.shape.len() {
            // Rank 0 used to panic on `shape[0]`; a scalar is just a 1x1 row.
            0 | 1 => features.reshape(vec![1, features.data.len()]),
            2 => features.clone(),
            _ => {
                let channels = features.shape[0];
                let spatial: usize = features.shape[1..].iter().product();
                features.reshape(vec![channels, spatial])
            }
        };
        Tensor::matmul(&f, &f.transpose())
    }

    /// Mean squared error between two equally sized tensors.
    ///
    /// Returns `0.0` for empty tensors (previously `0/0 = NaN`).
    ///
    /// # Panics
    /// Panics if the tensors hold a different number of elements.
    pub fn content_loss(generated: &Tensor, target: &Tensor) -> f32 {
        assert_eq!(generated.data.len(), target.data.len());
        if generated.data.is_empty() {
            return 0.0;
        }
        let sum_sq: f32 = generated
            .data
            .iter()
            .zip(&target.data)
            .map(|(g, t)| (g - t) * (g - t))
            .sum();
        sum_sq / generated.data.len() as f32
    }

    /// Mean squared error between two Gram matrices (same metric as `content_loss`).
    pub fn style_loss(generated_gram: &Tensor, target_gram: &Tensor) -> f32 {
        Self::content_loss(generated_gram, target_gram)
    }

    /// Weighted loss: `content_weight * content_loss + style_weight * style_loss`,
    /// where the style term compares Gram matrices of the style features.
    pub fn total_loss(
        &self,
        gen_content_features: &Tensor,
        target_content_features: &Tensor,
        gen_style_features: &Tensor,
        target_style_features: &Tensor,
    ) -> f32 {
        let cl = Self::content_loss(gen_content_features, target_content_features);
        let gen_gram = Self::gram_matrix(gen_style_features);
        let target_gram = Self::gram_matrix(target_style_features);
        let sl = Self::style_loss(&gen_gram, &target_gram);
        self.content_weight * cl + self.style_weight * sl
    }

    /// Produces a stylised tensor with the same shape as `content`.
    ///
    /// This is a heuristic, not gradient descent on `total_loss`. Each of the
    /// `iterations` steps:
    /// * moves every element towards the matching `content` element (factor 0.1,
    ///   scaled by `content_weight`) and the matching `style` element (factor
    ///   0.05, scaled by `style_weight`);
    /// * subtracts a uniform shift derived from the mean content-feature
    ///   difference and the mean Gram-matrix difference.
    ///
    /// The whole step is multiplied by `learning_rate`.
    ///
    /// # Panics
    /// Panics if `content` and `style` have different shapes.
    pub fn transfer(&self, content: &Tensor, style: &Tensor) -> Tensor {
        assert_eq!(content.shape, style.shape);
        let target_content_feat = self.content_model.forward(content);
        let target_style_gram = Self::gram_matrix(&self.style_model.forward(style));

        let mut generated = content.clone();
        let lr = self.learning_rate;
        // `generated` never changes length, so this scale is loop-invariant.
        let content_grad_scale = self.content_weight * 2.0 / generated.data.len() as f32;

        for _ in 0..self.iterations {
            let gen_content_feat = self.content_model.forward(&generated);
            let gen_style_gram = Self::gram_matrix(&self.style_model.forward(&generated));

            let style_grad_scale =
                self.style_weight * 2.0 / gen_style_gram.data.len().max(1) as f32;
            let content_signal = gen_content_feat.sub(&target_content_feat).mean();
            let style_signal = gen_style_gram.sub(&target_style_gram).mean();
            // Scalar shift applied identically to every element.
            let total_signal =
                content_grad_scale * content_signal + style_grad_scale * style_signal;

            for ((g, &c), &s) in generated
                .data
                .iter_mut()
                .zip(&content.data)
                .zip(&style.data)
            {
                let toward_content = (c - *g) * 0.1;
                let toward_style = (s - *g) * 0.05;
                *g += lr
                    * (toward_content * self.content_weight
                        + toward_style * self.style_weight
                        - total_signal * 0.01);
            }
        }
        generated
    }
}
pub struct AsciiStyleTransfer {
pub preset: StylePreset,
}
impl AsciiStyleTransfer {
    /// Creates a stylizer for `preset`.
    pub fn new(preset: StylePreset) -> Self {
        Self { preset }
    }

    /// Applies the preset's value mapping and returns a tensor with the same shape.
    ///
    /// * `Pencil`: values above 0.5 are scaled by 1.5 (capped at 1); the rest by
    ///   0.3 (floored at 0).
    /// * `Neon`: logistic curve `1 / (1 + e^(-10 (2v - 0.5)))`, capped at 1.
    /// * `Retro`: rounded down to a multiple of 0.25, clamped to `[0, 1]`.
    /// * `Gothic`: `1.2 * v²`, capped at 1.
    /// * `Watercolor`: 3-tap `[0.25, 0.5, 0.25]` blur along the last axis, with
    ///   row ends replicated and outputs clamped to `[0, 1]`.
    ///
    /// The Watercolor blur is done per row: tensors are row-major, so a blur
    /// over the flat buffer would mix the end of one row with the start of the
    /// next.
    pub fn apply(&self, values: &Tensor) -> Tensor {
        let data: Vec<f32> = match self.preset {
            StylePreset::Pencil => values
                .data
                .iter()
                .map(|&v| if v > 0.5 { (v * 1.5).min(1.0) } else { (v * 0.3).max(0.0) })
                .collect(),
            StylePreset::Neon => values
                .data
                .iter()
                .map(|&v| {
                    let boosted = v * 2.0;
                    (1.0 / (1.0 + (-10.0 * (boosted - 0.5)).exp())).min(1.0)
                })
                .collect(),
            StylePreset::Retro => values
                .data
                .iter()
                .map(|&v| ((v * 4.0).floor() / 4.0).clamp(0.0, 1.0))
                .collect(),
            StylePreset::Gothic => values
                .data
                .iter()
                .map(|&v| (v * v * 1.2).min(1.0))
                .collect(),
            StylePreset::Watercolor => {
                // Rank-0 (or zero-width) tensors fall back to one row over all data;
                // `.max(1)` keeps `chunks` from panicking on a zero length.
                let row_len = values
                    .shape
                    .last()
                    .copied()
                    .unwrap_or(values.data.len())
                    .max(1);
                values.data.chunks(row_len).flat_map(Self::blur_row).collect()
            }
        };
        Tensor { shape: values.shape.clone(), data }
    }

    /// 3-tap `[0.25, 0.5, 0.25]` blur of one row, replicating the end values and
    /// clamping each output to `[0, 1]`.
    fn blur_row(row: &[f32]) -> Vec<f32> {
        let last = row.len().saturating_sub(1);
        (0..row.len())
            .map(|i| {
                let prev = row[i.saturating_sub(1)];
                let next = row[(i + 1).min(last)];
                (prev * 0.25 + row[i] * 0.5 + next * 0.25).clamp(0.0, 1.0)
            })
            .collect()
    }

    /// Multiplies the first three channels of every 4-channel row by the
    /// preset's per-channel factors and clamps them to `[0, 1]`. The fourth
    /// channel is copied unchanged.
    ///
    /// # Panics
    /// Panics unless `colors` is rank 2 with exactly 4 columns.
    pub fn tint_colors(&self, colors: &Tensor) -> Tensor {
        assert_eq!(colors.shape.len(), 2);
        assert_eq!(colors.shape[1], 4);
        let rows = colors.shape[0];
        let (r_mul, g_mul, b_mul) = match self.preset {
            StylePreset::Pencil => (0.9, 0.9, 0.9),
            StylePreset::Neon => (1.2, 0.3, 1.5),
            StylePreset::Retro => (1.1, 0.8, 0.5),
            StylePreset::Gothic => (0.3, 0.1, 0.3),
            StylePreset::Watercolor => (0.8, 0.9, 1.1),
        };
        let mut data = colors.data.clone();
        for px in data.chunks_exact_mut(4).take(rows) {
            px[0] = (px[0] * r_mul).clamp(0.0, 1.0);
            px[1] = (px[1] * g_mul).clamp(0.0, 1.0);
            px[2] = (px[2] * b_mul).clamp(0.0, 1.0);
        }
        Tensor { shape: colors.shape.clone(), data }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gram_matrix() {
        // Rows [1, 2, 3] and [4, 5, 6]; G = F * F^T.
        let f = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let g = StyleTransfer::gram_matrix(&f);
        assert_eq!(g.shape, vec![2, 2]);
        assert_eq!(g.get(&[0, 0]), 14.0);
        assert_eq!(g.get(&[0, 1]), 32.0);
        // A Gram matrix is symmetric.
        assert_eq!(g.get(&[1, 0]), 32.0);
        assert_eq!(g.get(&[1, 1]), 77.0);
    }

    #[test]
    fn test_content_loss() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        assert_eq!(StyleTransfer::content_loss(&a, &b), 0.0);
        let c = Tensor::from_vec(vec![2.0, 3.0, 4.0], vec![3]);
        let loss = StyleTransfer::content_loss(&a, &c);
        assert!((loss - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_style_loss() {
        let a = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        let b = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        assert_eq!(StyleTransfer::style_loss(&a, &b), 0.0);
    }

    #[test]
    fn test_total_loss_zero_for_identical_inputs() {
        let st = StyleTransfer::from_preset(StylePreset::Neon);
        let c = Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]);
        let s = Tensor::from_vec(vec![0.5, 0.6, 0.7, 0.8], vec![2, 2]);
        assert_eq!(st.total_loss(&c, &c, &s, &s), 0.0);
    }

    #[test]
    fn test_preset_params() {
        assert_eq!(StylePreset::Pencil.params(), (1.0, 0.001, 50));
        assert_eq!(StylePreset::Watercolor.params(), (1.0, 0.003, 40));
    }

    #[test]
    fn test_transfer_preserves_shape() {
        let st = StyleTransfer::from_preset(StylePreset::Pencil);
        let content = Tensor::rand(vec![1, 64], 42);
        let style = Tensor::rand(vec![1, 64], 99);
        let result = st.transfer(&content, &style);
        assert_eq!(result.shape, content.shape);
    }

    #[test]
    fn test_ascii_style_pencil() {
        let ast = AsciiStyleTransfer::new(StylePreset::Pencil);
        let vals = Tensor::from_vec(vec![0.1, 0.5, 0.9], vec![3]);
        let result = ast.apply(&vals);
        assert_eq!(result.shape, vec![3]);
        assert!(result.data[0] < vals.data[0]);
        assert!(result.data[2] > vals.data[2] || (result.data[2] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_ascii_style_retro_quantizes() {
        let ast = AsciiStyleTransfer::new(StylePreset::Retro);
        let vals = Tensor::from_vec(vec![0.13, 0.37, 0.62, 0.88], vec![4]);
        let result = ast.apply(&vals);
        for &v in &result.data {
            let remainder = (v * 4.0) - (v * 4.0).floor();
            assert!(remainder.abs() < 1e-5);
        }
    }

    #[test]
    fn test_ascii_style_watercolor_blurs_1d() {
        let ast = AsciiStyleTransfer::new(StylePreset::Watercolor);
        let vals = Tensor::from_vec(vec![0.0, 1.0, 0.0], vec![3]);
        let result = ast.apply(&vals);
        assert_eq!(result.shape, vec![3]);
        let expected = [0.25f32, 0.5, 0.25];
        for (got, want) in result.data.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_tint_colors() {
        let ast = AsciiStyleTransfer::new(StylePreset::Neon);
        let colors = Tensor::from_vec(vec![0.5, 0.5, 0.5, 1.0], vec![1, 4]);
        let tinted = ast.tint_colors(&colors);
        assert_eq!(tinted.shape, vec![1, 4]);
        assert!(tinted.data[0] > 0.5);
        assert!(tinted.data[1] < 0.5);
        assert!(tinted.data[2] > 0.5);
        assert_eq!(tinted.data[3], 1.0);
    }

    #[test]
    fn test_all_presets() {
        let presets = [
            StylePreset::Pencil,
            StylePreset::Neon,
            StylePreset::Retro,
            StylePreset::Gothic,
            StylePreset::Watercolor,
        ];
        for preset in &presets {
            let ast = AsciiStyleTransfer::new(*preset);
            let vals = Tensor::from_vec(vec![0.3, 0.6, 0.9], vec![3]);
            let result = ast.apply(&vals);
            assert_eq!(result.shape, vec![3]);
            for &v in &result.data {
                assert!(
                    v >= 0.0 && v <= 1.0,
                    "preset {:?} produced out-of-range value {v}",
                    preset
                );
            }
        }
    }
}