use burn_core as burn;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn_nn::loss::Reduction;
use super::vgg16_l2pool::Vgg16L2PoolExtractor;
/// Channel counts of the feature maps DISTS compares: the raw RGB input (3)
/// followed by the five VGG16 stages used by the extractor.
const CHANNELS: [usize; 6] = [3, 64, 128, 256, 512, 512];
/// Numerical-stability constant for the structure (mean) similarity term.
const C1: f32 = 1e-6;
/// Numerical-stability constant for the texture (variance/covariance) term.
const C2: f32 = 1e-6;
/// Per-channel mean of ImageNet training data (RGB order).
const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
/// Per-channel standard deviation of ImageNet training data (RGB order).
const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];
/// Per-channel input normalizer holding broadcastable mean/std tensors.
#[derive(Module, Debug)]
pub struct Normalizer<B: Backend> {
    // Channel means, shape [1, 3, 1, 1] so they broadcast over NCHW input.
    pub mean: Tensor<B, 4>,
    // Channel standard deviations, shape [1, 3, 1, 1].
    pub std: Tensor<B, 4>,
}
impl<B: Backend> Normalizer<B> {
    /// Builds a normalizer with the standard ImageNet statistics.
    ///
    /// The constants are loaded as rank-1 tensors and reshaped to
    /// `[1, 3, 1, 1]` so they broadcast over `[batch, channel, h, w]` input.
    pub fn imagenet(device: &B::Device) -> Self {
        let mean = Tensor::<B, 1>::from_floats(IMAGENET_MEAN, device).reshape([1, 3, 1, 1]);
        let std = Tensor::<B, 1>::from_floats(IMAGENET_STD, device).reshape([1, 3, 1, 1]);
        Self { mean, std }
    }

    /// Applies `(x - mean) / std` element-wise, broadcasting over channels.
    pub fn normalize(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        (x - self.mean.clone()) / self.std.clone()
    }
}
/// Configuration for building a [`Dists`] module.
#[derive(Config, Debug)]
pub struct DistsConfig {
    // When true, inputs are normalized with ImageNet statistics before
    // feature extraction; set to false if inputs are already normalized.
    #[config(default = true)]
    pub normalize: bool,
}
impl DistsConfig {
    /// Initializes a [`Dists`] module with uniform (untrained) channel weights.
    ///
    /// `alpha` and `beta` are flat per-channel weight vectors with one entry
    /// per feature channel across all stages (`CHANNELS.iter().sum()` total).
    /// They start at a constant 0.1 and are meant to be learned or replaced
    /// by pretrained values.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Dists<B> {
        let total_channels: usize = CHANNELS.iter().sum();
        // vec![v; n] is the idiomatic constant-filled vector (replaces the
        // former (0..n).map(|_| 0.1).collect()).
        let alpha_data: Vec<f32> = vec![0.1; total_channels];
        let beta_data: Vec<f32> = vec![0.1; total_channels];
        let normalizer = if self.normalize {
            Some(Normalizer::imagenet(device))
        } else {
            None
        };
        Dists {
            extractor: Vgg16L2PoolExtractor::new(device),
            alpha: Param::from_tensor(Tensor::from_floats(alpha_data.as_slice(), device)),
            beta: Param::from_tensor(Tensor::from_floats(beta_data.as_slice(), device)),
            normalizer,
        }
    }

    /// Initializes a [`Dists`] module and loads the published pretrained
    /// `alpha`/`beta` weights (delegates to `super::weights`, which may
    /// download them).
    pub fn init_pretrained<B: Backend>(&self, device: &B::Device) -> Dists<B> {
        let dists = self.init(device);
        super::weights::load_pretrained_weights(dists)
    }
}
/// DISTS (Deep Image Structure and Texture Similarity) distance module.
///
/// Compares VGG16 features of two images with learned per-channel weights
/// for a structure term and a texture term.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Dists<B: Backend> {
    // VGG16 feature extractor with L2 pooling.
    pub(crate) extractor: Vgg16L2PoolExtractor<B>,
    // Per-channel weights for the structure term (one entry per channel
    // across all stages).
    pub(crate) alpha: Param<Tensor<B, 1>>,
    // Per-channel weights for the texture term.
    pub(crate) beta: Param<Tensor<B, 1>>,
    // Optional ImageNet input normalization (None if inputs are pre-normalized).
    pub(crate) normalizer: Option<Normalizer<B>>,
}
impl<B: Backend> ModuleDisplay for Dists<B> {
    /// Keeps module attributes on a single line in the display output.
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let settings = DisplaySettings::new().with_new_line_after_attribute(false);
        settings.optional()
    }

    /// Reports the backbone name and whether input normalization is enabled.
    fn custom_content(&self, content: Content) -> Option<Content> {
        let backbone = "VGG16-L2Pool".to_string();
        let normalize = self.normalizer.is_some().to_string();
        content
            .add("backbone", &backbone)
            .add("normalize", &normalize)
            .optional()
    }
}
impl<B: Backend> Dists<B> {
    /// Computes the DISTS distance between `input` and `target`, reduced over
    /// the batch according to `reduction`.
    pub fn forward(
        &self,
        input: Tensor<B, 4>,
        target: Tensor<B, 4>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let per_sample = self.forward_no_reduction(input, target);
        match reduction {
            Reduction::Sum => per_sample.sum(),
            Reduction::Mean | Reduction::Auto | Reduction::BatchMean => per_sample.mean(),
        }
    }

    /// Computes one DISTS distance per batch element (shape `[batch]`).
    ///
    /// For every feature stage, a structure term (based on spatial means) and
    /// a texture term (based on variances/covariance) are accumulated with the
    /// learned per-channel weights; each accumulator is then normalized by its
    /// own weight sum.
    ///
    /// NOTE(review): normalizing the alpha- and beta-weighted sums separately
    /// differs from the reference DISTS implementation, which divides both by
    /// `alpha.sum() + beta.sum()` — confirm this matches the convention of the
    /// pretrained weights.
    pub fn forward_no_reduction(&self, input: Tensor<B, 4>, target: Tensor<B, 4>) -> Tensor<B, 1> {
        let [batch, _c, _h, _w] = input.dims();
        let (input, target) = self.preprocess(input, target);

        let feats_x = self.extractor.forward(input);
        let feats_y = self.extractor.forward(target);

        let alpha = self.alpha.val();
        let beta = self.beta.val();
        let alpha_total = alpha.clone().sum();
        let beta_total = beta.clone().sum();

        let device = feats_x[0].device();
        let mut structure_acc = Tensor::<B, 1>::zeros([batch], &device);
        let mut texture_acc = Tensor::<B, 1>::zeros([batch], &device);

        // Walk the stages, slicing the flat weight vectors as we go.
        let mut offset = 0;
        for (fx, fy) in feats_x.iter().zip(feats_y.iter()) {
            let channels = fx.dims()[1];
            let alpha_slice = alpha.clone().narrow(0, offset, channels);
            let beta_slice = beta.clone().narrow(0, offset, channels);
            let (s, t) =
                self.compute_stage_similarity(fx.clone(), fy.clone(), alpha_slice, beta_slice);
            structure_acc = structure_acc.add(s);
            texture_acc = texture_acc.add(t);
            offset += channels;
        }

        structure_acc
            .div(alpha_total)
            .add(texture_acc.div(beta_total))
    }

    /// Computes the weighted structure and texture distance contributions of
    /// one feature stage; both returned tensors have shape `[batch]`.
    fn compute_stage_similarity(
        &self,
        feat_x: Tensor<B, 4>,
        feat_y: Tensor<B, 4>,
        alpha: Tensor<B, 1>,
        beta: Tensor<B, 1>,
    ) -> (Tensor<B, 1>, Tensor<B, 1>) {
        let [batch, channels, height, width] = feat_x.dims();
        let device = feat_x.device();

        // Flatten the spatial dimensions: [batch, channels, h*w].
        let x = feat_x.reshape([batch, channels, height * width]);
        let y = feat_y.reshape([batch, channels, height * width]);

        // Spatial means, [batch, channels].
        let mean_x = x.clone().mean_dim(2).squeeze_dim::<2>(2);
        let mean_y = y.clone().mean_dim(2).squeeze_dim::<2>(2);

        // Structure similarity: (2*mx*my + c1) / (mx^2 + my^2 + c1).
        let c1 = Tensor::<B, 2>::full([batch, channels], C1, &device);
        let s_numer = mean_x
            .clone()
            .mul(mean_y.clone())
            .mul_scalar(2.0)
            .add(c1.clone());
        let s_denom = mean_x
            .clone()
            .mul(mean_x.clone())
            .add(mean_y.clone().mul(mean_y.clone()))
            .add(c1);
        let structure_sim = s_numer.div(s_denom);

        // Biased spatial variances E[v^2] - mean^2, clamped at zero to absorb
        // tiny negatives from floating-point cancellation.
        let var_x = x
            .clone()
            .mul(x.clone())
            .mean_dim(2)
            .squeeze_dim::<2>(2)
            .sub(mean_x.clone().mul(mean_x.clone()))
            .clamp_min(0.0);
        let var_y = y
            .clone()
            .mul(y.clone())
            .mean_dim(2)
            .squeeze_dim::<2>(2)
            .sub(mean_y.clone().mul(mean_y.clone()))
            .clamp_min(0.0);
        // Spatial covariance E[xy] - mx*my (may legitimately be negative).
        let cov_xy = x
            .mul(y)
            .mean_dim(2)
            .squeeze_dim::<2>(2)
            .sub(mean_x.mul(mean_y));

        // Texture similarity: (2*cov + c2) / (var_x + var_y + c2).
        let c2 = Tensor::<B, 2>::full([batch, channels], C2, &device);
        let texture_sim = cov_xy
            .mul_scalar(2.0)
            .add(c2.clone())
            .div(var_x.add(var_y).add(c2));

        // Turn similarities (1 = identical) into distances, weight per channel,
        // and sum over channels to get [batch].
        let ones = Tensor::<B, 2>::ones([batch, channels], &device);
        let structure_dist = ones.clone().sub(structure_sim);
        let texture_dist = ones.sub(texture_sim);
        let weighted_structure = structure_dist
            .mul(alpha.unsqueeze_dim::<2>(0))
            .sum_dim(1)
            .squeeze_dim::<1>(1);
        let weighted_texture = texture_dist
            .mul(beta.unsqueeze_dim::<2>(0))
            .sum_dim(1)
            .squeeze_dim::<1>(1);
        (weighted_structure, weighted_texture)
    }

    /// Applies ImageNet normalization to both images when configured;
    /// otherwise passes them through unchanged.
    fn preprocess(
        &self,
        input: Tensor<B, 4>,
        target: Tensor<B, 4>,
    ) -> (Tensor<B, 4>, Tensor<B, 4>) {
        if let Some(normalizer) = &self.normalizer {
            (normalizer.normalize(input), normalizer.normalize(target))
        } else {
            (input, target)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_core::tensor::{TensorData, Tolerance, ops::FloatElem};
    use burn_flex::Flex;

    type TestBackend = Flex;
    type FT = FloatElem<TestBackend>;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;

    /// An image compared with itself has (approximately) zero distance.
    #[test]
    fn test_dists_identical_images_zero_distance() {
        let device = Default::default();
        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);
        let img = TestTensor::<4>::random(
            [1, 3, 64, 64],
            burn_core::tensor::Distribution::Uniform(0.0, 1.0),
            &device,
        );
        let distance = dists.forward(img.clone(), img, Reduction::Mean);
        distance
            .into_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
    }

    /// Clearly different images must yield a non-zero distance.
    #[test]
    fn test_dists_different_images_nonzero_distance() {
        let device = Default::default();
        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);
        let black = TestTensor::<4>::zeros([1, 3, 64, 64], &device);
        let white = TestTensor::<4>::ones([1, 3, 64, 64], &device);
        let value = dists
            .forward(black, white, Reduction::Mean)
            .into_data()
            .to_vec::<f32>()
            .unwrap()[0];
        assert!(
            value.abs() > 1e-6,
            "DISTS should be != 0 for different images"
        );
    }

    /// The metric is symmetric: d(a, b) == d(b, a).
    #[test]
    fn test_dists_symmetry() {
        let device = Default::default();
        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);
        let a = TestTensor::<4>::zeros([1, 3, 32, 32], &device);
        let b = TestTensor::<4>::ones([1, 3, 32, 32], &device);
        let forward = dists.forward(a.clone(), b.clone(), Reduction::Mean);
        let reverse = dists.forward(b, a, Reduction::Mean);
        forward
            .into_data()
            .assert_approx_eq::<FT>(&reverse.into_data(), Tolerance::default());
    }

    /// Mean reduction collapses a batch into a single scalar.
    #[test]
    fn test_dists_batch_processing() {
        let device = Default::default();
        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);
        let a = TestTensor::<4>::zeros([2, 3, 32, 32], &device);
        let b = TestTensor::<4>::ones([2, 3, 32, 32], &device);
        let distance = dists.forward(a, b, Reduction::Mean);
        assert_eq!(distance.dims(), [1]);
    }

    /// Without reduction, one distance is produced per batch element.
    #[test]
    fn test_dists_no_reduction() {
        let device = Default::default();
        let batch_size = 4;
        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);
        let a = TestTensor::<4>::zeros([batch_size, 3, 32, 32], &device);
        let b = TestTensor::<4>::ones([batch_size, 3, 32, 32], &device);
        let distance = dists.forward_no_reduction(a, b);
        assert_eq!(distance.dims(), [batch_size]);
    }

    /// The custom display output names the module and its backbone.
    #[test]
    fn display_dists() {
        let device = Default::default();
        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);
        let rendered = format!("{dists}");
        assert!(rendered.contains("Dists"));
        assert!(rendered.contains("VGG16-L2Pool"));
    }

    /// Sanity checks on the pretrained weights (network access required).
    #[test]
    #[ignore = "downloads pre-trained weights"]
    fn test_dists_pretrained() {
        let device = Default::default();
        let dists: Dists<TestBackend> = DistsConfig::new().init_pretrained(&device);

        // Identical images: distance should be essentially zero.
        let img = TestTensor::<4>::random(
            [1, 3, 64, 64],
            burn_core::tensor::Distribution::Uniform(0.0, 1.0),
            &device,
        );
        let value = dists
            .forward(img.clone(), img, Reduction::Mean)
            .into_data()
            .to_vec::<f32>()
            .unwrap()[0];
        assert!(
            value.abs() < 1e-5,
            "Pretrained DISTS should be ~0 for identical images, got {}",
            value
        );

        // Images drawn from disjoint ranges: distance should be positive.
        let dark = TestTensor::<4>::random(
            [1, 3, 64, 64],
            burn_core::tensor::Distribution::Uniform(0.0, 0.3),
            &device,
        );
        let bright = TestTensor::<4>::random(
            [1, 3, 64, 64],
            burn_core::tensor::Distribution::Uniform(0.7, 1.0),
            &device,
        );
        let value = dists
            .forward(dark, bright, Reduction::Mean)
            .into_data()
            .to_vec::<f32>()
            .unwrap()[0];
        assert!(
            value > 0.0,
            "Pretrained DISTS should be > 0 for different images"
        );
    }
}