use burn_core as burn;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn_nn::conv::{Conv2d, Conv2dConfig};
use burn_nn::loss::Reduction;
use super::alexnet::AlexFeatureExtractor;
use super::squeezenet::SqueezeFeatureExtractor;
use super::vgg::VggFeatureExtractor;
/// Backbone feature extractor used by the LPIPS metric.
#[derive(Config, Debug, Copy, PartialEq, Eq)]
pub enum LpipsNet {
    /// VGG-based feature extractor.
    Vgg,
    /// AlexNet-based feature extractor.
    Alex,
    /// SqueezeNet-based feature extractor.
    Squeeze,
}
/// Configuration for building an LPIPS perceptual-distance module.
#[derive(Config, Debug)]
pub struct LpipsConfig {
    /// Which backbone network to extract features with (defaults to VGG).
    #[config(default = "LpipsNet::Vgg")]
    pub net: LpipsNet,
    /// When `true`, inputs are rescaled from `[0, 1]` to `[-1, 1]` before
    /// feature extraction (see `preprocess_inputs`).
    #[config(default = true)]
    pub normalize: bool,
}
impl LpipsConfig {
pub fn init_pretrained<B: Backend>(&self, device: &B::Device) -> Lpips<B> {
let lpips = self.init(device);
super::weights::load_pretrained_weights(lpips, self.net)
}
pub fn init<B: Backend>(&self, device: &B::Device) -> Lpips<B> {
match self.net {
LpipsNet::Vgg => {
Lpips::Vgg(LpipsVgg {
extractor: VggFeatureExtractor::new(device),
lin0: Conv2dConfig::new([64, 1], [1, 1])
.with_bias(false)
.init(device),
lin1: Conv2dConfig::new([128, 1], [1, 1])
.with_bias(false)
.init(device),
lin2: Conv2dConfig::new([256, 1], [1, 1])
.with_bias(false)
.init(device),
lin3: Conv2dConfig::new([512, 1], [1, 1])
.with_bias(false)
.init(device),
lin4: Conv2dConfig::new([512, 1], [1, 1])
.with_bias(false)
.init(device),
normalize: self.normalize,
})
}
LpipsNet::Alex => {
Lpips::Alex(LpipsAlex {
extractor: AlexFeatureExtractor::new(device),
lin0: Conv2dConfig::new([64, 1], [1, 1])
.with_bias(false)
.init(device),
lin1: Conv2dConfig::new([192, 1], [1, 1])
.with_bias(false)
.init(device),
lin2: Conv2dConfig::new([384, 1], [1, 1])
.with_bias(false)
.init(device),
lin3: Conv2dConfig::new([256, 1], [1, 1])
.with_bias(false)
.init(device),
lin4: Conv2dConfig::new([256, 1], [1, 1])
.with_bias(false)
.init(device),
normalize: self.normalize,
})
}
LpipsNet::Squeeze => {
Lpips::Squeeze(LpipsSqueeze {
extractor: SqueezeFeatureExtractor::new(device),
lin0: Conv2dConfig::new([64, 1], [1, 1])
.with_bias(false)
.init(device),
lin1: Conv2dConfig::new([128, 1], [1, 1])
.with_bias(false)
.init(device),
lin2: Conv2dConfig::new([256, 1], [1, 1])
.with_bias(false)
.init(device),
lin3: Conv2dConfig::new([384, 1], [1, 1])
.with_bias(false)
.init(device),
lin4: Conv2dConfig::new([384, 1], [1, 1])
.with_bias(false)
.init(device),
lin5: Conv2dConfig::new([512, 1], [1, 1])
.with_bias(false)
.init(device),
lin6: Conv2dConfig::new([512, 1], [1, 1])
.with_bias(false)
.init(device),
normalize: self.normalize,
})
}
}
}
}
/// LPIPS (Learned Perceptual Image Patch Similarity) module.
///
/// Wraps one of the supported backbone variants; `forward` dispatches to the
/// inner implementation.
#[derive(Module, Debug)]
#[allow(clippy::large_enum_variant)]
#[module(custom_display)]
pub enum Lpips<B: Backend> {
    /// VGG backbone variant.
    Vgg(LpipsVgg<B>),
    /// AlexNet backbone variant.
    Alex(LpipsAlex<B>),
    /// SqueezeNet backbone variant.
    Squeeze(LpipsSqueeze<B>),
}
/// LPIPS with a VGG feature extractor.
///
/// `lin0`..`lin4` are bias-free 1x1 convolutions that weight the squared
/// feature differences of the extractor's five output layers.
#[derive(Module, Debug)]
pub struct LpipsVgg<B: Backend> {
    pub(crate) extractor: VggFeatureExtractor<B>,
    pub(crate) lin0: Conv2d<B>,
    pub(crate) lin1: Conv2d<B>,
    pub(crate) lin2: Conv2d<B>,
    pub(crate) lin3: Conv2d<B>,
    pub(crate) lin4: Conv2d<B>,
    // When true, inputs in [0, 1] are rescaled to [-1, 1] before extraction.
    pub(crate) normalize: bool,
}
/// LPIPS with an AlexNet feature extractor.
///
/// `lin0`..`lin4` are bias-free 1x1 convolutions that weight the squared
/// feature differences of the extractor's five output layers.
#[derive(Module, Debug)]
pub struct LpipsAlex<B: Backend> {
    pub(crate) extractor: AlexFeatureExtractor<B>,
    pub(crate) lin0: Conv2d<B>,
    pub(crate) lin1: Conv2d<B>,
    pub(crate) lin2: Conv2d<B>,
    pub(crate) lin3: Conv2d<B>,
    pub(crate) lin4: Conv2d<B>,
    // When true, inputs in [0, 1] are rescaled to [-1, 1] before extraction.
    pub(crate) normalize: bool,
}
/// LPIPS with a SqueezeNet feature extractor.
///
/// `lin0`..`lin6` are bias-free 1x1 convolutions that weight the squared
/// feature differences of the extractor's seven output layers.
#[derive(Module, Debug)]
pub struct LpipsSqueeze<B: Backend> {
    pub(crate) extractor: SqueezeFeatureExtractor<B>,
    pub(crate) lin0: Conv2d<B>,
    pub(crate) lin1: Conv2d<B>,
    pub(crate) lin2: Conv2d<B>,
    pub(crate) lin3: Conv2d<B>,
    pub(crate) lin4: Conv2d<B>,
    pub(crate) lin5: Conv2d<B>,
    pub(crate) lin6: Conv2d<B>,
    // When true, inputs in [0, 1] are rescaled to [-1, 1] before extraction.
    pub(crate) normalize: bool,
}
impl<B: Backend> LpipsVgg<B> {
    /// Compute the per-sample LPIPS distance between `input` and `target`
    /// (shape `[batch, 3, h, w]`), returning one value per batch element.
    pub fn forward_no_reduction(&self, input: Tensor<B, 4>, target: Tensor<B, 4>) -> Tensor<B, 1> {
        let (input, target) = preprocess_inputs(input, target, self.normalize);
        let input_feats = self.extractor.forward(input);
        let target_feats = self.extractor.forward(target);

        // Weight each layer's squared feature difference with its 1x1 conv.
        let lins = [&self.lin0, &self.lin1, &self.lin2, &self.lin3, &self.lin4];
        let per_layer: Vec<Tensor<B, 2>> = lins
            .into_iter()
            .enumerate()
            .map(|(i, lin)| {
                compute_layer_distance(&input_feats[i], &target_feats[i], lin).unsqueeze_dim(1)
            })
            .collect();

        // Total distance is the sum of the per-layer distances.
        Tensor::cat(per_layer, 1).sum_dim(1).squeeze_dim::<1>(1)
    }
}
impl<B: Backend> LpipsAlex<B> {
    /// Compute the per-sample LPIPS distance between `input` and `target`
    /// (shape `[batch, 3, h, w]`), returning one value per batch element.
    pub fn forward_no_reduction(&self, input: Tensor<B, 4>, target: Tensor<B, 4>) -> Tensor<B, 1> {
        let (input, target) = preprocess_inputs(input, target, self.normalize);
        let input_feats = self.extractor.forward(input);
        let target_feats = self.extractor.forward(target);

        // Weight each layer's squared feature difference with its 1x1 conv.
        let lins = [&self.lin0, &self.lin1, &self.lin2, &self.lin3, &self.lin4];
        let per_layer: Vec<Tensor<B, 2>> = lins
            .into_iter()
            .enumerate()
            .map(|(i, lin)| {
                compute_layer_distance(&input_feats[i], &target_feats[i], lin).unsqueeze_dim(1)
            })
            .collect();

        // Total distance is the sum of the per-layer distances.
        Tensor::cat(per_layer, 1).sum_dim(1).squeeze_dim::<1>(1)
    }
}
impl<B: Backend> LpipsSqueeze<B> {
    /// Compute the per-sample LPIPS distance between `input` and `target`
    /// (shape `[batch, 3, h, w]`), returning one value per batch element.
    pub fn forward_no_reduction(&self, input: Tensor<B, 4>, target: Tensor<B, 4>) -> Tensor<B, 1> {
        let (input, target) = preprocess_inputs(input, target, self.normalize);
        let input_feats = self.extractor.forward(input);
        let target_feats = self.extractor.forward(target);

        // Weight each layer's squared feature difference with its 1x1 conv.
        let lins = [
            &self.lin0, &self.lin1, &self.lin2, &self.lin3, &self.lin4, &self.lin5, &self.lin6,
        ];
        let per_layer: Vec<Tensor<B, 2>> = lins
            .into_iter()
            .enumerate()
            .map(|(i, lin)| {
                compute_layer_distance(&input_feats[i], &target_feats[i], lin).unsqueeze_dim(1)
            })
            .collect();

        // Total distance is the sum of the per-layer distances.
        Tensor::cat(per_layer, 1).sum_dim(1).squeeze_dim::<1>(1)
    }
}
impl<B: Backend> ModuleDisplay for Lpips<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        // Keep attributes on a single line when displayed.
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        // Report which backbone is active and whether input rescaling is on.
        let (net_name, normalize) = match self {
            Self::Vgg(module) => ("Vgg", module.normalize),
            Self::Alex(module) => ("Alex", module.normalize),
            Self::Squeeze(module) => ("Squeeze", module.normalize),
        };
        content
            .add("net", &net_name.to_string())
            .add("normalize", &normalize.to_string())
            .optional()
    }
}
impl<B: Backend> Lpips<B> {
    /// Compute the LPIPS distance between `input` and `target` and reduce the
    /// per-sample distances according to `reduction`.
    pub fn forward(
        &self,
        input: Tensor<B, 4>,
        target: Tensor<B, 4>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let per_sample = self.forward_no_reduction(input, target);
        match reduction {
            Reduction::Sum => per_sample.sum(),
            Reduction::Mean | Reduction::Auto | Reduction::BatchMean => per_sample.mean(),
        }
    }

    /// Compute one LPIPS distance per batch element, delegating to the
    /// active backbone variant.
    pub fn forward_no_reduction(&self, input: Tensor<B, 4>, target: Tensor<B, 4>) -> Tensor<B, 1> {
        match self {
            Self::Vgg(module) => module.forward_no_reduction(input, target),
            Self::Alex(module) => module.forward_no_reduction(input, target),
            Self::Squeeze(module) => module.forward_no_reduction(input, target),
        }
    }
}
/// Normalize feature maps to unit L2 norm along the channel dimension,
/// clamping the norm to avoid division by zero.
fn normalize_tensor<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {
    let squared = x.clone() * x.clone();
    let channel_norm = squared.sum_dim(1).sqrt().clamp_min(1e-10);
    x / channel_norm
}
/// Apply the per-channel shift/scale normalization expected by the
/// pre-trained feature extractors.
fn scaling_layer<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {
    let device = x.device();
    let [batch, _, height, width] = x.dims();

    // Build a per-channel constant broadcast to the full input shape.
    let broadcast = |values: [f32; 3]| {
        Tensor::<B, 1>::from_floats(values, &device)
            .reshape([1, 3, 1, 1])
            .expand([batch, 3, height, width])
    };
    let shift = broadcast([-0.030, -0.088, -0.188]);
    let scale = broadcast([0.458, 0.448, 0.450]);

    (x - shift) / scale
}
/// Distance contribution of a single feature layer: channel-normalize both
/// feature maps, square their difference, weight it with the layer's 1x1
/// conv, then spatially average to one value per batch element.
fn compute_layer_distance<B: Backend>(
    feat0: &Tensor<B, 4>,
    feat1: &Tensor<B, 4>,
    lin: &Conv2d<B>,
) -> Tensor<B, 1> {
    let diff = normalize_tensor(feat0.clone()) - normalize_tensor(feat1.clone());
    let weighted = lin.forward(diff.clone() * diff);

    // Spatial (and channel) average -> [batch].
    let [batch, channels, height, width] = weighted.dims();
    weighted
        .reshape([batch, channels * height * width])
        .mean_dim(1)
        .squeeze_dim::<1>(1)
}
/// Optionally rescale both images from `[0, 1]` to `[-1, 1]`, then apply the
/// backbone's per-channel scaling layer to each.
fn preprocess_inputs<B: Backend>(
    input: Tensor<B, 4>,
    target: Tensor<B, 4>,
    normalize: bool,
) -> (Tensor<B, 4>, Tensor<B, 4>) {
    // Maps [0, 1] -> [-1, 1].
    let to_signed = |t: Tensor<B, 4>| t.mul_scalar(2.0).sub_scalar(1.0);

    if normalize {
        (
            scaling_layer(to_signed(input)),
            scaling_layer(to_signed(target)),
        )
    } else {
        (scaling_layer(input), scaling_layer(target))
    }
}
// Unit tests run on the NdArray CPU backend; tests that require downloading
// pre-trained weights are `#[ignore]`d by default.
#[cfg(test)]
mod tests {
    use super::*;
    use burn_core::tensor::{TensorData, Tolerance, ops::FloatElem};
    use burn_ndarray::NdArray;

    type TestBackend = NdArray<f32>;
    type FT = FloatElem<TestBackend>;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;

    // Identical inputs must yield a distance of ~0 (VGG backbone).
    #[test]
    fn test_lpips_identical_images_zero_distance() {
        let device = Default::default();
        let image = TestTensor::<4>::ones([1, 3, 32, 32], &device);
        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);
        let distance = lpips.forward(image.clone(), image, Reduction::Mean);
        let expected = TensorData::from([0.0]);
        distance
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // Different inputs must yield a non-zero distance (VGG backbone).
    #[test]
    fn test_lpips_different_images_nonzero_distance() {
        let device = Default::default();
        let image1 = TestTensor::<4>::zeros([1, 3, 32, 32], &device);
        let image2 = TestTensor::<4>::ones([1, 3, 32, 32], &device);
        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);
        let distance = lpips.forward(image1, image2, Reduction::Mean);
        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value.abs() > 1e-6,
            "LPIPS should be != 0 for different images"
        );
    }

    // d(a, b) == d(b, a).
    #[test]
    fn test_lpips_symmetry() {
        let device = Default::default();
        let image1 = TestTensor::<4>::zeros([1, 3, 32, 32], &device);
        let image2 = TestTensor::<4>::ones([1, 3, 32, 32], &device);
        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);
        let distance_forward = lpips.forward(image1.clone(), image2.clone(), Reduction::Mean);
        let distance_reverse = lpips.forward(image2, image1, Reduction::Mean);
        distance_forward
            .into_data()
            .assert_approx_eq::<FT>(&distance_reverse.into_data(), Tolerance::default());
    }

    // Mean reduction collapses the batch to a single scalar tensor.
    #[test]
    fn test_lpips_forward_mean_reduction() {
        let device = Default::default();
        let image1 = TestTensor::<4>::zeros([2, 3, 32, 32], &device);
        let image2 = TestTensor::<4>::ones([2, 3, 32, 32], &device);
        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);
        let distance = lpips.forward(image1, image2, Reduction::Mean);
        assert_eq!(distance.dims(), [1]);
    }

    // Without reduction, one distance per batch element is returned.
    #[test]
    fn test_lpips_forward_no_reduction() {
        let device = Default::default();
        let batch_size = 4;
        let image1 = TestTensor::<4>::zeros([batch_size, 3, 32, 32], &device);
        let image2 = TestTensor::<4>::ones([batch_size, 3, 32, 32], &device);
        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);
        let distance = lpips.forward_no_reduction(image1, image2);
        assert_eq!(distance.dims(), [batch_size]);
    }

    // Identical inputs must yield a distance of ~0 (Alex backbone).
    #[test]
    fn test_lpips_alex_identical_images_zero_distance() {
        let device = Default::default();
        let image = TestTensor::<4>::ones([1, 3, 64, 64], &device);
        let lpips: Lpips<TestBackend> = LpipsConfig::new().with_net(LpipsNet::Alex).init(&device);
        let distance = lpips.forward(image.clone(), image, Reduction::Mean);
        let expected = TensorData::from([0.0]);
        distance
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // Different inputs must yield a non-zero distance (Alex backbone).
    #[test]
    fn test_lpips_alex_different_images_nonzero_distance() {
        let device = Default::default();
        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);
        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);
        let lpips: Lpips<TestBackend> = LpipsConfig::new().with_net(LpipsNet::Alex).init(&device);
        let distance = lpips.forward(image1, image2, Reduction::Mean);
        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value.abs() > 1e-6,
            "LPIPS (Alex) should be != 0 for different images"
        );
    }

    // Identical inputs must yield a distance of ~0 (Squeeze backbone).
    #[test]
    fn test_lpips_squeeze_identical_images_zero_distance() {
        let device = Default::default();
        let image = TestTensor::<4>::ones([1, 3, 64, 64], &device);
        let lpips: Lpips<TestBackend> =
            LpipsConfig::new().with_net(LpipsNet::Squeeze).init(&device);
        let distance = lpips.forward(image.clone(), image, Reduction::Mean);
        let expected = TensorData::from([0.0]);
        distance
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // Different inputs must yield a non-zero distance (Squeeze backbone).
    #[test]
    fn test_lpips_squeeze_different_images_nonzero_distance() {
        let device = Default::default();
        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);
        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);
        let lpips: Lpips<TestBackend> =
            LpipsConfig::new().with_net(LpipsNet::Squeeze).init(&device);
        let distance = lpips.forward(image1, image2, Reduction::Mean);
        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value.abs() > 1e-6,
            "LPIPS (Squeeze) should be != 0 for different images"
        );
    }

    // Display output (custom ModuleDisplay) names the module and its backbone.
    #[test]
    fn display_vgg() {
        let device = Default::default();
        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);
        let display_str = format!("{lpips}");
        assert!(display_str.contains("Lpips"));
        assert!(display_str.contains("Vgg"));
    }

    #[test]
    fn display_alex() {
        let device = Default::default();
        let lpips: Lpips<TestBackend> = LpipsConfig::new().with_net(LpipsNet::Alex).init(&device);
        let display_str = format!("{lpips}");
        assert!(display_str.contains("Lpips"));
        assert!(display_str.contains("Alex"));
    }

    #[test]
    fn display_squeeze() {
        let device = Default::default();
        let lpips: Lpips<TestBackend> =
            LpipsConfig::new().with_net(LpipsNet::Squeeze).init(&device);
        let display_str = format!("{lpips}");
        assert!(display_str.contains("Lpips"));
        assert!(display_str.contains("Squeeze"));
    }

    // Sanity checks against downloaded pre-trained weights (network access
    // required, hence ignored by default).
    #[test]
    #[ignore = "downloads pre-trained weights"]
    fn test_lpips_pretrained_vgg() {
        let device = Default::default();
        let lpips: Lpips<TestBackend> = LpipsConfig::new()
            .with_net(LpipsNet::Vgg)
            .init_pretrained(&device);
        let image = TestTensor::<4>::ones([1, 3, 64, 64], &device);
        let distance = lpips.forward(image.clone(), image, Reduction::Mean);
        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value.abs() < 1e-5,
            "Pretrained LPIPS (VGG) should be ~0 for identical images, got {}",
            distance_value
        );
        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);
        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);
        let distance = lpips.forward(image1, image2, Reduction::Mean);
        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value > 0.0,
            "Pretrained LPIPS (VGG) should be > 0 for different images, got {}",
            distance_value
        );
    }

    #[test]
    #[ignore = "downloads pre-trained weights"]
    fn test_lpips_pretrained_alex() {
        let device = Default::default();
        let lpips: Lpips<TestBackend> = LpipsConfig::new()
            .with_net(LpipsNet::Alex)
            .init_pretrained(&device);
        let image = TestTensor::<4>::ones([1, 3, 64, 64], &device);
        let distance = lpips.forward(image.clone(), image, Reduction::Mean);
        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value.abs() < 1e-5,
            "Pretrained LPIPS (Alex) should be ~0 for identical images, got {}",
            distance_value
        );
        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);
        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);
        let distance = lpips.forward(image1, image2, Reduction::Mean);
        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value > 0.0,
            "Pretrained LPIPS (Alex) should be > 0 for different images"
        );
    }

    #[test]
    #[ignore = "downloads pre-trained weights"]
    fn test_lpips_pretrained_squeeze() {
        let device = Default::default();
        let lpips: Lpips<TestBackend> = LpipsConfig::new()
            .with_net(LpipsNet::Squeeze)
            .init_pretrained(&device);
        let image = TestTensor::<4>::ones([1, 3, 64, 64], &device);
        let distance = lpips.forward(image.clone(), image, Reduction::Mean);
        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value.abs() < 1e-5,
            "Pretrained LPIPS (Squeeze) should be ~0 for identical images, got {}",
            distance_value
        );
        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);
        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);
        let distance = lpips.forward(image1, image2, Reduction::Mean);
        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value > 0.0,
            "Pretrained LPIPS (Squeeze) should be > 0 for different images, got {}",
            distance_value
        );
    }
}