use burn_core as burn;
use super::vgg19::Vgg19;
use super::weights::load_vgg19_weights;
use crate::loss::Reduction;
use burn::config::Config;
use burn::module::Module;
use burn::tensor::{Tensor, backend::Backend};
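/// Configuration for creating a [Gram matrix loss](GramMatrixLoss) based on VGG19 features.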
#[cfg_attr(docsrs, doc(cfg(feature = "pretrained")))]
#[derive(Config, Debug)]
pub struct GramMatrixLossConfig {
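/// Weights applied to the losses of the five VGG19 feature layers.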
pub layer_weights: Vec<f32>,
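/// Whether the VGG19 feature extractor should use average pooling instead of max pooling.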
#[config(default = "false")]
pub use_avg_pool: bool,
}
impl GramMatrixLossConfig {
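/// Initialize a new [GramMatrixLoss] module, loading pre-trained VGG19 weights.
///
/// # Panics
///
/// Panics if `layer_weights` does not contain exactly five non-negative values.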
pub fn init<B: Backend>(&self, device: &B::Device) -> GramMatrixLoss<B> {
self.assertions();
let vgg19 = Vgg19::new(self.use_avg_pool, device);
let pretrained_vgg19 = load_vgg19_weights(vgg19).no_grad();
GramMatrixLoss {
layer_weights: self.layer_weights.clone(),
feat_extractor: pretrained_vgg19,
}
}
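/// Validates the configuration, panicking if the layer weights are invalid.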
fn assertions(&self) {
assert!(
self.layer_weights.len() == 5,
"The layer_weights vector must contain exactly 5 elements"
);
assert!(
self.layer_weights.iter().all(|&w| w >= 0.0),
"All layer weights must be non-negative"
);
}
}
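/// Gram matrix loss computed from VGG19 feature maps, commonly used as a style loss.
///
/// For each of the five feature layers, the Gram matrices of the predicted and target
/// features are compared with a squared error that is normalized by `4 * C^2 * (H * W)^2`
/// and scaled by the corresponding layer weight.
///
/// Typically created using [GramMatrixLossConfig], which loads pre-trained VGG19 weights.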
#[cfg_attr(docsrs, doc(cfg(feature = "pretrained")))]
#[derive(Module, Debug)]
pub struct GramMatrixLoss<B: Backend> {
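/// Weights applied to the losses of the five VGG19 feature layers.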
pub layer_weights: Vec<f32>,
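/// VGG19 network used to extract the feature maps.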
pub feat_extractor: Vgg19<B>,
}
impl<B: Backend> GramMatrixLoss<B> {
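/// Compute the loss between the predictions and the targets, applying the given reduction.
///
/// # Shapes
///
/// - predictions: `[batch_size, 3, height, width]`
/// - targets: `[batch_size, 3, height, width]`
/// - output: `[1]`
///
/// # Panics
///
/// Panics if the inputs do not have exactly 3 channels.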
pub fn forward(
&self,
predictions: Tensor<B, 4>,
targets: Tensor<B, 4>,
reduction: Reduction,
) -> Tensor<B, 1> {
let unreduced_loss = self.forward_no_reduction(predictions, targets);
match reduction {
Reduction::Mean | Reduction::Auto => unreduced_loss.mean(),
Reduction::Sum => unreduced_loss.sum(),
other => panic!("{other:?} reduction is not supported"),
}
}
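/// Compute the per-sample loss without reduction.
///
/// # Shapes
///
/// - predictions: `[batch_size, 3, height, width]`
/// - targets: `[batch_size, 3, height, width]`
/// - output: `[batch_size]`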
pub fn forward_no_reduction(
&self,
predictions: Tensor<B, 4>,
targets: Tensor<B, 4>,
) -> Tensor<B, 1> {
let pred_processed = self.preprocess_input(predictions);
let target_processed = self.preprocess_input(targets);
let pred_features = self.feat_extractor.forward(pred_processed);
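// Normalization factor 4 * C^2 * (H * W)^2 for each feature layer, following the
// usual Gram-matrix style loss formulation.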
let mut pred_normalization_factors = Vec::with_capacity(5);
for feature_tensor in &pred_features {
let [_, c, h_times_w] = feature_tensor.dims();
let (c_f, hw_f) = (c as f32, h_times_w as f32);
pred_normalization_factors.push(4.0 * c_f * c_f * hw_f * hw_f);
}
let target_features = self.feat_extractor.forward(target_processed);
let mut loss_tensors = Vec::with_capacity(pred_features.len());
for (pred_f, target_f) in pred_features.into_iter().zip(target_features) {
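// Batched Gram matrices G = F * F^T of the flattened feature maps, shape [batch_size, C, C].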
let pred_gram_matrices = pred_f.clone().matmul(pred_f.clone().transpose());
let target_gram_matrices = target_f.clone().matmul(target_f.clone().transpose());
let gram_matrices_diff = pred_gram_matrices - target_gram_matrices;
let gram_matrices_diff_squared = gram_matrices_diff.powi_scalar(2);
let loss = gram_matrices_diff_squared
.sum_dims(&[1, 2])
.squeeze_dims::<1>(&[1, 2]);
loss_tensors.push(loss);
}
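// Scale each per-layer loss by its normalization factor and layer weight.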
let scaled_loss_tensors: Vec<Tensor<B, 1>> = loss_tensors
.into_iter()
.zip(pred_normalization_factors)
.zip(self.layer_weights.clone())
.map(|((loss_tensor, norm_factor), weight)| {
loss_tensor.div_scalar(norm_factor).mul_scalar(weight)
})
.collect();
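// Sum the weighted per-layer losses into a single per-sample loss.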
let stacked_loss_tensors = Tensor::stack::<2>(scaled_loss_tensors, 1);
stacked_loss_tensors.sum_dim(1).squeeze_dim(1)
}
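/// Normalize the input with the standard ImageNet per-channel mean and standard deviation.
///
/// # Panics
///
/// Panics if the input does not have exactly 3 channels.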
fn preprocess_input(&self, tensor: Tensor<B, 4>) -> Tensor<B, 4> {
let device = &tensor.device();
let channels = tensor.dims()[1];
assert!(
channels == 3,
"Expected input tensor to have exactly 3 channels, but got {}",
channels
);
let mean = Tensor::<B, 1>::from_floats([0.485, 0.456, 0.406], device).reshape([1, 3, 1, 1]);
let std = Tensor::<B, 1>::from_floats([0.229, 0.224, 0.225], device).reshape([1, 3, 1, 1]);
(tensor - mean) / std
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::Distribution;
#[test]
#[should_panic(expected = "The layer_weights vector must contain exactly 5 elements")]
fn test_gram_matrix_loss_config_invalid_length() {
let device = Default::default();
GramMatrixLossConfig::new(vec![1.0, 1.0]).init::<TestBackend>(&device);
}
#[test]
#[should_panic(expected = "All layer weights must be non-negative")]
fn test_gram_matrix_loss_config_negative_weights() {
let device = Default::default();
GramMatrixLossConfig::new(vec![1.0, -1.0, 1.0, 1.0, 1.0]).init::<TestBackend>(&device);
}
#[test]
#[ignore = "downloads pre-trained weights"]
fn test_gram_matrix_loss_config_valid_weights() {
let device = Default::default();
let layer_weights = vec![0.0, 0.2, 0.2, 0.25, 0.4];
let loss_fn = GramMatrixLossConfig::new(layer_weights.clone()).init::<TestBackend>(&device);
assert_eq!(
loss_fn.layer_weights, layer_weights,
"Expected layer weights vector {:?}, got {:?}",
loss_fn.layer_weights, layer_weights
);
}
#[test]
#[should_panic(expected = "Expected input tensor to have exactly 3 channels, but got 1")]
fn test_gram_matrix_loss_1_channel_panic() {
let device = Default::default();
let loss_fn = GramMatrixLoss {
layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],
feat_extractor: Vgg19::new(false, &device),
};
let tensor1: Tensor<TestBackend, 4> =
Tensor::random([2, 1, 16, 16], Distribution::Default, &device);
let tensor2 = tensor1.clone();
let _ = loss_fn.forward(tensor1, tensor2, Reduction::Mean);
}
#[test]
#[should_panic(expected = "Expected input tensor to have exactly 3 channels, but got 4")]
fn test_gram_matrix_loss_4_channel_panic() {
let device = Default::default();
let loss_fn = GramMatrixLoss {
layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],
feat_extractor: Vgg19::new(false, &device),
};
let tensor1: Tensor<TestBackend, 4> =
Tensor::random([2, 4, 16, 16], Distribution::Default, &device);
let tensor2 = tensor1.clone();
let _ = loss_fn.forward(tensor1, tensor2, Reduction::Mean);
}
#[test]
fn test_gram_matrix_loss_zero_for_identical_inputs() {
let device = Default::default();
let loss_fn = GramMatrixLoss {
layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],
feat_extractor: Vgg19::new(false, &device),
};
let tensor1: Tensor<TestBackend, 4> =
Tensor::random([2, 3, 16, 16], Distribution::Default, &device);
let tensor2 = tensor1.clone();
let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean);
let loss_val: f32 = loss.into_scalar();
assert!(
loss_val.abs() < 1e-4,
"Loss should be zero for identical inputs"
);
}
#[test]
fn test_gram_matrix_loss_greater_than_zero_for_different_inputs() {
let device = Default::default();
let loss_fn = GramMatrixLoss {
layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],
feat_extractor: Vgg19::new(false, &device),
};
let tensor1: Tensor<TestBackend, 4> = Tensor::ones([2, 3, 16, 16], &device);
let tensor2: Tensor<TestBackend, 4> = Tensor::zeros([2, 3, 16, 16], &device);
let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean);
let loss_val: f32 = loss.into_scalar();
assert!(
loss_val > 0.0,
"Loss should be positive for different inputs"
);
}
#[test]
fn test_gram_matrix_loss_forward_no_reduction_shape() {
let device = Default::default();
let loss_fn = GramMatrixLoss {
layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],
feat_extractor: Vgg19::new(false, &device),
};
let batch_size = 4;
let tensor1: Tensor<TestBackend, 4> =
Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device);
let tensor2: Tensor<TestBackend, 4> =
Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device);
let unreduced_loss = loss_fn.forward_no_reduction(tensor1, tensor2);
assert_eq!(unreduced_loss.dims(), [batch_size]);
}
#[test]
fn test_gram_matrix_loss_reduction_sum_vs_mean() {
let device = Default::default();
let loss_fn = GramMatrixLoss {
layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],
feat_extractor: Vgg19::new(false, &device),
};
let batch_size = 4;
let tensor1: Tensor<TestBackend, 4> =
Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device);
let tensor2: Tensor<TestBackend, 4> =
Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device);
let loss_mean: f32 = loss_fn
.forward(tensor1.clone(), tensor2.clone(), Reduction::Mean)
.into_scalar();
let loss_sum: f32 = loss_fn
.forward(tensor1, tensor2, Reduction::Sum)
.into_scalar();
let expected_sum = loss_mean * (batch_size as f32);
let diff = (loss_sum - expected_sum).abs();
assert!(
diff < 1e-4,
"Sum reduction should equal batch_size * Mean reduction"
);
}
#[test]
fn test_gram_matrix_loss_with_avg_pool() {
let device = Default::default();
let loss_fn = GramMatrixLoss {
layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],
feat_extractor: Vgg19::new(true, &device),
};
let batch_size = 4;
let tensor1: Tensor<TestBackend, 4> = Tensor::ones([batch_size, 3, 16, 16], &device);
let tensor2: Tensor<TestBackend, 4> = Tensor::zeros([batch_size, 3, 16, 16], &device);
let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean);
let loss_val: f32 = loss.into_scalar();
assert!(
loss_val > 0.0,
"Loss should be positive for different inputs using avg pooling"
);
}
#[test]
fn test_gram_matrix_loss_autodiff() {
use crate::TestAutodiffBackend;
let device = Default::default();
let loss_fn = GramMatrixLoss {
layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],
feat_extractor: Vgg19::<TestAutodiffBackend>::new(false, &device).no_grad(),
};
let predictions: Tensor<TestAutodiffBackend, 4> =
Tensor::ones([2, 3, 16, 16], &device).require_grad();
let targets: Tensor<TestAutodiffBackend, 4> = Tensor::zeros([2, 3, 16, 16], &device);
let loss = loss_fn.forward(predictions.clone(), targets, Reduction::Mean);
let grads = loss.backward();
let pred_grad = predictions.grad(&grads);
assert!(
pred_grad.is_some(),
"Gradients should be computed for the predictions tensor"
);
let conv1_1_weight_grad = loss_fn.feat_extractor.conv1_1.weight.val().grad(&grads);
assert!(
conv1_1_weight_grad.is_none(),
"Gradients should not be computed for VGG19 parameters"
);
}
#[test]
#[ignore = "downloads pre-trained weights"]
fn test_gram_matrix_loss_pretrained_weights_identical_inputs() {
let device = Default::default();
let loss_fn =
GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::<TestBackend>(&device);
let tensor1: Tensor<TestBackend, 4> =
Tensor::random([2, 3, 16, 16], Distribution::Default, &device);
let tensor2 = tensor1.clone();
let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean);
let loss_val: f32 = loss.into_scalar();
assert!(
loss_val.abs() < 1e-4,
"Loss should be zero for identical inputs"
);
}
#[test]
#[ignore = "downloads pre-trained weights"]
fn test_gram_matrix_loss_pretrained_weights_different_inputs() {
let device = Default::default();
let loss_fn =
GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::<TestBackend>(&device);
let tensor1: Tensor<TestBackend, 4> = Tensor::ones([2, 3, 16, 16], &device);
let tensor2: Tensor<TestBackend, 4> = Tensor::zeros([2, 3, 16, 16], &device);
let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean);
let loss_val: f32 = loss.into_scalar();
assert!(
loss_val > 0.0,
"Loss should be positive for different inputs"
);
}
}