use crate::error::{Result, VisionError};
use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
/// Computes the normalised Gram matrix of a `(channels, height, width)` feature map.
///
/// The feature map is flattened to a `(channels, height * width)` matrix `F`;
/// entry `(i, j)` of the result is the dot product of rows `i` and `j` of `F`
/// multiplied by `1 / (height * width)`. When there are no spatial positions
/// the scale factor is `1.0` instead, so no division by zero takes place.
///
/// # Returns
///
/// A `(channels, channels)` matrix.
pub fn gram_matrix(features: &Array3<f64>) -> Array2<f64> {
    let (channels, height, width) = features.dim();
    let positions = height * width;

    let flattened: Array2<f64> = match features.to_shape((channels, positions)) {
        Ok(reshaped) => reshaped.into_owned(),
        // Fallback used only if the reshape reports an error: copy the
        // elements one by one in row-major order. The closure is never called
        // when `positions == 0`, so `width` is non-zero whenever it divides.
        Err(_) => Array2::from_shape_fn((channels, positions), |(ch, pos)| {
            features[[ch, pos / width, pos % width]]
        }),
    };

    let scale = match positions {
        0 => 1.0,
        count => 1.0 / count as f64,
    };

    Array2::from_shape_fn((channels, channels), |(i, j)| {
        let lhs = flattened.row(i);
        let rhs = flattened.row(j);
        let dot: f64 = lhs.iter().zip(rhs.iter()).map(|(a, b)| a * b).sum();
        dot * scale
    })
}
/// Style loss between two Gram matrices.
///
/// Returns the sum of squared element-wise differences divided by `4 * c²`,
/// where `c` is the number of rows of `generated_gram`. For an empty matrix
/// (`c == 0`) the divisor is `1.0`.
///
/// # Panics
///
/// In debug builds, panics if the two matrices have different shapes. Release
/// builds perform no shape check.
pub fn style_loss(generated_gram: &Array2<f64>, style_gram: &Array2<f64>) -> f64 {
    debug_assert_eq!(
        generated_gram.dim(),
        style_gram.dim(),
        "Gram matrices must have identical shapes"
    );

    let channels = generated_gram.dim().0;
    let normaliser = if channels == 0 {
        1.0
    } else {
        4.0 * (channels * channels) as f64
    };

    let squared_error: f64 = generated_gram
        .iter()
        .zip(style_gram.iter())
        .map(|(&g, &s)| (g - s) * (g - s))
        .sum();

    squared_error / normaliser
}
/// Content loss: mean squared error between two feature maps.
///
/// Returns `0.0` when `generated` has no elements.
///
/// # Panics
///
/// In debug builds, panics if the two feature maps have different shapes.
/// Release builds perform no shape check.
pub fn content_loss(generated: &Array3<f64>, content: &Array3<f64>) -> f64 {
    debug_assert_eq!(
        generated.dim(),
        content.dim(),
        "Feature maps must have identical shapes"
    );

    match generated.len() {
        0 => 0.0,
        count => {
            let squared_error: f64 = generated
                .iter()
                .zip(content.iter())
                .map(|(&g, &c)| {
                    let delta = g - c;
                    delta * delta
                })
                .sum();
            squared_error / count as f64
        }
    }
}
/// Anisotropic total-variation loss of a `(channels, height, width)` image.
///
/// For every pixel that has both a lower and a right neighbour, the absolute
/// difference to each of those two neighbours is accumulated; the sum is then
/// divided by `channels * (height - 1) * (width - 1)`. Pixels in the last row
/// or last column do not start a pair, so horizontal differences along the
/// last row and vertical differences along the last column are not counted.
///
/// Returns `0.0` when the image has fewer than two rows or columns, or no
/// channels.
pub fn total_variation_loss(image: &Array3<f64>) -> f64 {
    let (channels, height, width) = image.dim();
    if height < 2 || width < 2 {
        return 0.0;
    }

    let pair_count = channels * (height - 1) * (width - 1);
    if pair_count == 0 {
        return 0.0;
    }

    let mut accumulated = 0.0_f64;
    for ch in 0..channels {
        for r in 0..height - 1 {
            for c in 0..width - 1 {
                let here = image[[ch, r, c]];
                let below = image[[ch, r + 1, c]];
                let right = image[[ch, r, c + 1]];
                accumulated += (below - here).abs() + (right - here).abs();
            }
        }
    }

    accumulated / pair_count as f64
}
/// Relative weights of the three terms of the style-transfer objective.
#[derive(Debug, Clone)]
pub struct StyleTransferWeights {
    /// Multiplier for the content term.
    pub content_weight: f64,
    /// Multiplier for the style term.
    pub style_weight: f64,
    /// Multiplier for the total-variation term.
    pub tv_weight: f64,
}

impl Default for StyleTransferWeights {
    /// Defaults: content `1.0`, style `1e5`, total variation `1e-4`.
    fn default() -> Self {
        let content_weight = 1.0;
        let style_weight = 1e5;
        let tv_weight = 1e-4;
        StyleTransferWeights {
            content_weight,
            style_weight,
            tv_weight,
        }
    }
}
/// Evaluates the combined style-transfer objective for a fixed set of weights.
#[derive(Debug, Clone)]
pub struct StyleTransferLoss {
    /// Weights applied to the content, style and total-variation terms.
    pub weights: StyleTransferWeights,
}

impl StyleTransferLoss {
    /// Creates a loss evaluator that uses `weights`.
    pub fn new(weights: StyleTransferWeights) -> Self {
        StyleTransferLoss { weights }
    }

    /// Weighted sum of the content, style and total-variation losses.
    ///
    /// The content loss compares `generated` with `content`, the style loss
    /// compares `generated_gram` with `style_gram`, and the total-variation
    /// loss is evaluated on `generated`.
    pub fn total(
        &self,
        generated_gram: &Array2<f64>,
        style_gram: &Array2<f64>,
        generated: &Array3<f64>,
        content: &Array3<f64>,
    ) -> f64 {
        let (content_term, style_term, tv_term) =
            self.components(generated_gram, style_gram, generated, content);
        let w = &self.weights;
        w.content_weight * content_term + w.style_weight * style_term + w.tv_weight * tv_term
    }

    /// Unweighted `(content, style, total_variation)` losses.
    pub fn components(
        &self,
        generated_gram: &Array2<f64>,
        style_gram: &Array2<f64>,
        generated: &Array3<f64>,
        content: &Array3<f64>,
    ) -> (f64, f64, f64) {
        (
            content_loss(generated, content),
            style_loss(generated_gram, style_gram),
            total_variation_loss(generated),
        )
    }
}
fn style_gradient(features: &Array3<f64>, style_gram: &Array2<f64>) -> Array3<f64> {
let gen_gram = gram_matrix(features);
let (c, h, w) = features.dim();
let spatial = (h * w) as f64;
let denom = 4.0 * (c * c) as f64 * spatial.max(1.0);
let residual = &gen_gram - style_gram;
let mut grad = Array3::zeros((c, h, w));
for row in 0..h {
for col in 0..w {
for ci in 0..c {
let mut acc = 0.0_f64;
for cj in 0..c {
acc += residual[[ci, cj]] * features[[cj, row, col]];
}
grad[[ci, row, col]] = 2.0 * acc / denom;
}
}
}
grad
}
/// Gradient of the mean-squared-error content loss with respect to `generated`.
///
/// Each element is `2 * (generated - content) / n`, where `n` is the number
/// of elements of `generated`, clamped to at least 1.
fn content_gradient(generated: &Array3<f64>, content: &Array3<f64>) -> Array3<f64> {
    let count = generated.len().max(1) as f64;
    let mut gradient = generated - content;
    gradient.mapv_inplace(|delta| 2.0 * delta / count);
    gradient
}
/// Subgradient of `total_variation_loss` with respect to `image`.
///
/// The loss sums `|x[r+1,c] - x[r,c]| + |x[r,c+1] - x[r,c]|` over
/// `r in 0..h-1`, `c in 0..w-1` and divides by `channels * (h-1) * (w-1)`.
/// Accordingly a pixel only receives
///
/// * vertical contributions when its column is not the last one, and
/// * horizontal contributions when its row is not the last one,
///
/// because those are the only pairs the loss counts.
///
/// A zero difference contributes `0`. Note that `f64::signum(0.0)` is `1.0`,
/// which would otherwise give a perfectly flat image a non-zero gradient at
/// its borders.
///
/// Returns an all-zero array when the image has fewer than two rows or columns.
fn tv_gradient(image: &Array3<f64>) -> Array3<f64> {
    /// Sign with `sign(0) == 0`; NaN is propagated unchanged.
    fn sign(diff: f64) -> f64 {
        if diff == 0.0 {
            0.0
        } else {
            diff.signum()
        }
    }

    let (c, h, w) = image.dim();
    if h < 2 || w < 2 {
        return Array3::zeros((c, h, w));
    }
    let n = (c * (h - 1) * (w - 1)).max(1) as f64;
    let mut grad = Array3::zeros((c, h, w));
    for ch in 0..c {
        for row in 0..h {
            for col in 0..w {
                let here = image[[ch, row, col]];
                let mut g = 0.0_f64;

                // Vertical pairs exist in the loss only for columns 0..w-1.
                if col + 1 < w {
                    if row + 1 < h {
                        g -= sign(image[[ch, row + 1, col]] - here);
                    }
                    if row > 0 {
                        g += sign(here - image[[ch, row - 1, col]]);
                    }
                }

                // Horizontal pairs exist in the loss only for rows 0..h-1.
                if row + 1 < h {
                    if col + 1 < w {
                        g -= sign(image[[ch, row, col + 1]] - here);
                    }
                    if col > 0 {
                        g += sign(here - image[[ch, row, col - 1]]);
                    }
                }

                grad[[ch, row, col]] = g / n;
            }
        }
    }
    grad
}
/// Runs plain gradient descent on the weighted style-transfer objective.
///
/// The generated feature map starts as a copy of `content`. On every
/// iteration the content, style and total-variation gradients are combined
/// using `weights` and subtracted from the current estimate. The first
/// iteration uses a step of `0.1 * lr`; all later iterations use `lr`.
///
/// # Arguments
///
/// * `content` - feature map whose content should be preserved.
/// * `style` - feature map whose Gram matrix is the style target; must have
///   the same shape as `content`.
/// * `weights` - multipliers for the three gradient terms.
/// * `n_iters` - number of gradient steps (at least 1).
/// * `lr` - step size; must be positive and finite.
///
/// # Errors
///
/// * `VisionError::InvalidParameter` if `lr` is not a positive finite number
///   (this includes NaN and infinity) or if `n_iters` is zero.
/// * `VisionError::DimensionMismatch` if `content` and `style` differ in shape.
pub fn optimize_style_transfer(
    content: &Array3<f64>,
    style: &Array3<f64>,
    weights: &StyleTransferWeights,
    n_iters: usize,
    lr: f64,
) -> Result<Array3<f64>> {
    // `lr <= 0.0` alone is false for NaN and for +inf, both of which would
    // turn every element of the result into NaN/inf.
    if !lr.is_finite() || lr <= 0.0 {
        return Err(VisionError::InvalidParameter(format!(
            "Learning rate must be positive and finite, got {lr}"
        )));
    }
    if n_iters == 0 {
        return Err(VisionError::InvalidParameter(
            "n_iters must be at least 1".to_string(),
        ));
    }
    if content.dim() != style.dim() {
        return Err(VisionError::DimensionMismatch(format!(
            "Content shape {:?} ≠ style shape {:?}",
            content.dim(),
            style.dim()
        )));
    }

    // The style target is fixed, so its Gram matrix is computed once.
    let style_gram = gram_matrix(style);
    let mut generated = content.to_owned();

    for iter in 0..n_iters {
        // `style_gradient` derives the Gram matrix of `generated` itself, so
        // it is not computed separately here.
        let grad_style = style_gradient(&generated, &style_gram);
        let grad_content = content_gradient(&generated, content);
        let grad_tv = tv_gradient(&generated);
        let grad = grad_content.mapv(|g| g * weights.content_weight)
            + grad_style.mapv(|g| g * weights.style_weight)
            + grad_tv.mapv(|g| g * weights.tv_weight);

        // Smaller first step, then the full learning rate.
        let step = if iter == 0 { lr * 0.1 } else { lr };
        generated = generated - grad.mapv(|g| g * step);
    }
    Ok(generated)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    /// Builds a `(c, h, w)` array whose entries count up from zero in
    /// row-major order.
    fn make_ramp(c: usize, h: usize, w: usize) -> Array3<f64> {
        Array3::from_shape_fn((c, h, w), |(ch, row, col)| {
            (ch * h * w + row * w + col) as f64
        })
    }

    /// The Gram matrix of a `(4, 8, 8)` map is `(4, 4)`.
    #[test]
    fn test_gram_matrix_shape() {
        let features: Array3<f64> = Array3::ones((4, 8, 8));
        assert_eq!(gram_matrix(&features).dim(), (4, 4));
    }

    /// Entry `(i, j)` equals entry `(j, i)` up to rounding.
    #[test]
    fn test_gram_matrix_symmetric() {
        let gram = gram_matrix(&make_ramp(3, 5, 5));
        for i in 0..3 {
            for j in 0..3 {
                let diff = (gram[[i, j]] - gram[[j, i]]).abs();
                assert!(diff < 1e-10, "Gram not symmetric at ({i},{j}): {diff}");
            }
        }
    }

    /// Diagonal entries of a Gram matrix are non-negative.
    #[test]
    fn test_gram_matrix_positive_semidefinite() {
        let gram = gram_matrix(&make_ramp(3, 4, 4));
        for i in 0..3 {
            let diagonal = gram[[i, i]];
            assert!(
                diagonal >= 0.0,
                "Diagonal element ({i},{i}) is negative: {}",
                diagonal
            );
        }
    }

    /// Comparing a Gram matrix with itself gives zero style loss.
    #[test]
    fn test_style_loss_identical() {
        let features: Array3<f64> = Array3::ones((3, 4, 4));
        let gram = gram_matrix(&features);
        assert!(style_loss(&gram, &gram).abs() < 1e-12);
    }

    /// Style loss never goes below zero.
    #[test]
    fn test_style_loss_non_negative() {
        let ramp = make_ramp(3, 4, 4);
        let zeros: Array3<f64> = Array3::zeros((3, 4, 4));
        let loss = style_loss(&gram_matrix(&ramp), &gram_matrix(&zeros));
        assert!(loss >= 0.0);
    }

    /// Comparing a feature map with itself gives zero content loss.
    #[test]
    fn test_content_loss_identical() {
        let features = make_ramp(3, 4, 4);
        let loss = content_loss(&features, &features);
        assert!(loss.abs() < 1e-12);
    }

    /// Content loss never goes below zero.
    #[test]
    fn test_content_loss_non_negative() {
        let ramp = make_ramp(2, 4, 4);
        let zeros: Array3<f64> = Array3::zeros((2, 4, 4));
        assert!(content_loss(&ramp, &zeros) >= 0.0);
    }

    /// A constant image has zero total variation.
    #[test]
    fn test_total_variation_uniform() {
        let flat: Array3<f64> = Array3::from_elem((2, 6, 6), 5.0);
        let loss = total_variation_loss(&flat);
        assert!(loss.abs() < 1e-12);
    }

    /// Total variation never goes below zero.
    #[test]
    fn test_total_variation_non_negative() {
        let ramp = make_ramp(2, 6, 6);
        assert!(total_variation_loss(&ramp) >= 0.0);
    }

    /// A `1 x 1` image has no neighbouring pairs, so the loss is exactly zero.
    #[test]
    fn test_total_variation_small_image() {
        let tiny: Array3<f64> = Array3::ones((2, 1, 1));
        assert_eq!(total_variation_loss(&tiny), 0.0);
    }

    /// The weighted total is non-negative for the default weights.
    #[test]
    fn test_style_transfer_loss_struct() {
        let content = make_ramp(2, 4, 4);
        let style: Array3<f64> = Array3::from_elem((2, 4, 4), 3.0);
        let generated_gram = gram_matrix(&content);
        let style_gram = gram_matrix(&style);
        let evaluator = StyleTransferLoss::new(StyleTransferWeights::default());
        let total = evaluator.total(&generated_gram, &style_gram, &content, &content);
        assert!(total >= 0.0);
    }

    /// Identical inputs give zero content and style components.
    #[test]
    fn test_style_transfer_components() {
        let image: Array3<f64> = Array3::ones((2, 4, 4));
        let gram = gram_matrix(&image);
        let evaluator = StyleTransferLoss::new(StyleTransferWeights::default());
        let (content_term, style_term, tv_term) =
            evaluator.components(&gram, &gram, &image, &image);
        assert!(content_term.abs() < 1e-12);
        assert!(style_term.abs() < 1e-12);
        assert!(tv_term >= 0.0);
    }

    /// The optimiser returns an array shaped like its inputs.
    #[test]
    fn test_optimize_style_transfer_shape() {
        let content: Array3<f64> = Array3::ones((2, 4, 4));
        let style: Array3<f64> = Array3::ones((2, 4, 4));
        let weights = StyleTransferWeights {
            content_weight: 1.0,
            style_weight: 1.0,
            tv_weight: 0.0,
        };
        let outcome = optimize_style_transfer(&content, &style, &weights, 3, 0.01);
        assert!(outcome.is_ok());
        let generated = outcome.expect("Test: result shape");
        assert_eq!(generated.dim(), (2, 4, 4));
    }

    /// Zero and negative learning rates are rejected.
    #[test]
    fn test_optimize_style_transfer_bad_lr() {
        let content: Array3<f64> = Array3::ones((2, 4, 4));
        let style: Array3<f64> = Array3::ones((2, 4, 4));
        let weights = StyleTransferWeights::default();
        let zero_lr = optimize_style_transfer(&content, &style, &weights, 3, 0.0);
        assert!(zero_lr.is_err());
        let negative_lr = optimize_style_transfer(&content, &style, &weights, 3, -1.0);
        assert!(negative_lr.is_err());
    }

    /// Zero iterations is rejected.
    #[test]
    fn test_optimize_style_transfer_zero_iters() {
        let content: Array3<f64> = Array3::ones((2, 4, 4));
        let style: Array3<f64> = Array3::ones((2, 4, 4));
        let weights = StyleTransferWeights::default();
        let outcome = optimize_style_transfer(&content, &style, &weights, 0, 0.01);
        assert!(outcome.is_err());
    }

    /// Content and style maps of different shapes are rejected.
    #[test]
    fn test_optimize_style_transfer_dimension_mismatch() {
        let content: Array3<f64> = Array3::ones((2, 4, 4));
        let style: Array3<f64> = Array3::ones((3, 4, 4));
        let weights = StyleTransferWeights::default();
        let outcome = optimize_style_transfer(&content, &style, &weights, 3, 0.01);
        assert!(outcome.is_err());
    }
}