use crate::metric::{
Metric, MetricAttributes, MetricMetadata, MetricName, Numeric, NumericAttributes, NumericEntry,
SerializedEntry,
state::{FormatOptions, NumericMetricState},
};
use burn_core::{
prelude::{Backend, Tensor},
tensor::{ElementConversion, module::conv2d, ops::ConvOptions},
};
use core::marker::PhantomData;
/// Input for [`SsimMetric`]: a batch of images paired with the reference images
/// they are compared against.
///
/// Both tensors are 4-D and must have identical shapes; [`SsimMetric`] reads the
/// dimensions as `[batch, channels, height, width]`.
#[derive(Debug, Clone)]
pub struct SsimInput<B: Backend> {
    /// Images being evaluated.
    outputs: Tensor<B, 4>,
    /// Reference images the `outputs` are compared against.
    targets: Tensor<B, 4>,
}

impl<B: Backend> SsimInput<B> {
    /// Pairs `outputs` with `targets`.
    ///
    /// # Panics
    ///
    /// Panics if the two tensors do not have exactly the same shape.
    pub fn new(outputs: Tensor<B, 4>, targets: Tensor<B, 4>) -> Self {
        assert!(
            outputs.dims() == targets.dims(),
            "Shape mismatch: outputs {:?}, targets {:?}",
            outputs.dims(),
            targets.dims()
        );
        Self { outputs, targets }
    }
}
/// Configuration for the SSIM metric.
///
/// The defaults set by [`SsimMetricConfig::new`] (`k1 = 0.01`, `k2 = 0.03`, an
/// 11-tap Gaussian window with `sigma = 1.5`) are the values commonly used for
/// SSIM (Wang et al., 2004).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SsimMetricConfig {
    /// Dynamic range of the pixel values, e.g. `1.0` for normalized images or
    /// `255.0` for 8-bit images.
    pub pixel_range: f32,
    /// Luminance stabilizer factor: `C1 = (k1 * pixel_range)^2`.
    pub k1: f32,
    /// Contrast/structure stabilizer factor: `C2 = (k2 * pixel_range)^2`.
    pub k2: f32,
    /// Side length of the Gaussian window; always odd.
    pub kernel_size: usize,
    /// Standard deviation of the Gaussian window.
    pub sigma: f32,
}

impl SsimMetricConfig {
    /// Creates a configuration for `pixel_range` with the default constants
    /// `k1 = 0.01`, `k2 = 0.03`, `kernel_size = 11` and `sigma = 1.5`.
    ///
    /// # Panics
    ///
    /// Panics if `pixel_range` is not strictly positive (NaN is rejected too).
    #[must_use]
    pub fn new(pixel_range: f32) -> Self {
        assert!(pixel_range > 0.0, "pixel_range must be positive");
        Self {
            pixel_range,
            k1: 0.01,
            k2: 0.03,
            kernel_size: 11,
            sigma: 1.5,
        }
    }

    /// Overrides the stabilizer factors `k1` and `k2`.
    ///
    /// # Panics
    ///
    /// Panics if either factor is not strictly positive.
    #[must_use]
    pub fn with_k1_k2(mut self, k1: f32, k2: f32) -> Self {
        assert!(k1 > 0.0, "k1 must be positive");
        assert!(k2 > 0.0, "k2 must be positive");
        self.k1 = k1;
        self.k2 = k2;
        self
    }

    /// Overrides the Gaussian window size.
    ///
    /// # Panics
    ///
    /// Panics if `kernel_size` is zero or even; an odd size gives the window a
    /// well-defined center tap.
    #[must_use]
    pub fn with_kernel_size(mut self, kernel_size: usize) -> Self {
        assert!(
            kernel_size > 0 && kernel_size % 2 == 1,
            "kernel_size must be positive and an odd number"
        );
        self.kernel_size = kernel_size;
        self
    }

    /// Overrides the standard deviation of the Gaussian window.
    ///
    /// # Panics
    ///
    /// Panics if `sigma` is not strictly positive.
    #[must_use]
    pub fn with_sigma(mut self, sigma: f32) -> Self {
        assert!(sigma > 0.0, "sigma must be positive");
        self.sigma = sigma;
        self
    }
}
/// Structural Similarity Index (SSIM) metric for batches of 4-D image tensors
/// laid out as `[batch, channels, height, width]`.
///
/// Local statistics are gathered with a separable Gaussian window described by
/// [`SsimMetricConfig`]; each update reports the mean SSIM of the batch.
#[derive(Clone)]
pub struct SsimMetric<B: Backend> {
    name: MetricName,
    state: NumericMetricState,
    _b: PhantomData<B>,
    config: SsimMetricConfig,
}

impl<B: Backend> SsimMetric<B> {
    /// Creates the metric with a descriptive default name such as
    /// `"SSIM (dr=1, w=11, σ=1.5)"` (dynamic range, window size, sigma).
    pub fn new(config: SsimMetricConfig) -> Self {
        Self {
            name: MetricName::new(format!(
                "SSIM (dr={}, w={}, σ={})",
                config.pixel_range, config.kernel_size, config.sigma,
            )),
            state: NumericMetricState::default(),
            config,
            _b: PhantomData,
        }
    }

    /// Replaces the display name of the metric.
    pub fn with_name(mut self, name: &str) -> Self {
        self.name = MetricName::new(name.to_string());
        self
    }

    /// Builds a 1-D Gaussian of length `kernel_size`, centred on the middle tap
    /// and normalized so its weights sum to 1.
    fn create_1d_gaussian_kernel(&self) -> Vec<f32> {
        let sigma = self.config.sigma;
        let center = (self.config.kernel_size / 2) as f32;
        let two_sigma_sq = 2.0 * sigma * sigma;

        let mut kernel: Vec<f32> = (0..self.config.kernel_size)
            .map(|i| {
                let x = i as f32 - center;
                (-(x * x) / two_sigma_sq).exp()
            })
            .collect();

        let sum: f32 = kernel.iter().sum();
        kernel.iter_mut().for_each(|v| *v /= sum);
        kernel
    }

    /// Filters every channel of `input` independently (grouped convolution with
    /// `groups = channels`) with `kernel_1d`, as a horizontal pass followed by a
    /// vertical pass.
    ///
    /// Padding of `len / 2` on each side, with stride and dilation 1, keeps the
    /// spatial size unchanged for the odd-length kernels this type produces.
    fn gaussian_conv_separable(
        &self,
        input: Tensor<B, 4>,
        kernel_1d: &[f32],
        channels: usize,
        device: &B::Device,
    ) -> Tensor<B, 4> {
        // Size the weights from the taps actually supplied so the reshape always matches.
        let size = kernel_1d.len();
        let padding = size / 2;

        // Upload the taps once; both orientations are views of the same data.
        let taps = Tensor::<B, 1>::from_floats(kernel_1d, device);
        // Grouped-conv weights: [out_channels = channels, in_channels / groups = 1, kh, kw].
        let horizontal_kernel = taps
            .clone()
            .reshape([1, 1, 1, size])
            .repeat_dim(0, channels);
        let vertical_kernel = taps.reshape([1, 1, size, 1]).repeat_dim(0, channels);

        let horizontal_conv_options = ConvOptions::new([1, 1], [0, padding], [1, 1], channels);
        let input_after_horizontal_conv =
            conv2d(input, horizontal_kernel, None, horizontal_conv_options);

        let vertical_conv_options = ConvOptions::new([1, 1], [padding, 0], [1, 1], channels);
        conv2d(
            input_after_horizontal_conv,
            vertical_kernel,
            None,
            vertical_conv_options,
        )
    }
}
impl<B: Backend> Metric for SsimMetric<B> {
type Input = SsimInput<B>;
fn name(&self) -> MetricName {
self.name.clone()
}
fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
let dims = item.outputs.dims();
let batch_size = dims[0];
let channels = dims[1];
let device = item.outputs.device();
let img_height = dims[2];
let img_width = dims[3];
assert!(
img_height >= self.config.kernel_size && img_width >= self.config.kernel_size,
"Image dimensions (H={}, W={}) must be >= kernel_size ({})",
img_height,
img_width,
self.config.kernel_size
);
let c1 = (self.config.k1 * self.config.pixel_range).powi(2);
let c2 = (self.config.k2 * self.config.pixel_range).powi(2);
let kernel_1d = self.create_1d_gaussian_kernel();
let x = item.outputs.clone();
let y = item.targets.clone();
let mu_x = self.gaussian_conv_separable(x.clone(), &kernel_1d, channels, &device);
let mu_y = self.gaussian_conv_separable(y.clone(), &kernel_1d, channels, &device);
let mu_x_mu_y = mu_x.clone() * mu_y.clone();
let square_of_mu_x = mu_x.clone() * mu_x.clone();
let square_of_mu_y = mu_y.clone() * mu_y.clone();
let mu_of_x_squared =
self.gaussian_conv_separable(x.clone() * x.clone(), &kernel_1d, channels, &device);
let mu_of_y_squared =
self.gaussian_conv_separable(y.clone() * y.clone(), &kernel_1d, channels, &device);
let var_x = (mu_of_x_squared - square_of_mu_x.clone()).clamp_min(0.0);
let var_y = (mu_of_y_squared - square_of_mu_y.clone()).clamp_min(0.0);
let mu_xy = self.gaussian_conv_separable(x * y, &kernel_1d, channels, &device);
let sigma_xy = mu_xy - mu_x_mu_y.clone();
let numerator = (mu_x_mu_y.mul_scalar(2.0_f32) + c1) * (sigma_xy.mul_scalar(2.0_f32) + c2);
let denominator = (square_of_mu_x + square_of_mu_y + c1) * (var_x + var_y + c2);
let ssim_tensor = numerator / denominator;
let ssim_per_image = ssim_tensor.mean_dims(&[1, 2, 3]);
let avg_ssim = ssim_per_image.mean().into_scalar().elem::<f64>();
self.state.update(
avg_ssim,
batch_size,
FormatOptions::new(self.name()).precision(4),
)
}
fn clear(&mut self) {
self.state.reset();
}
fn attributes(&self) -> MetricAttributes {
NumericAttributes {
unit: None,
higher_is_better: true,
}
.into()
}
}
impl<B: Backend> Numeric for SsimMetric<B> {
fn value(&self) -> NumericEntry {
self.state.current_value()
}
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TestBackend, metric::Numeric};
    use burn_core::tensor::{Distribution, Shape, TensorData};

    /// Small-window config so the 4x4 fixtures below satisfy `H, W >= kernel_size`.
    fn test_config() -> SsimMetricConfig {
        SsimMetricConfig::new(1.0)
            .with_kernel_size(3)
            .with_sigma(1.0)
    }

    #[test]
    fn test_ssim_perfect_similarity() {
        let device = Default::default();
        let outputs = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[[
                [0.1_f32, 0.2, 0.3, 0.4],
                [0.5, 0.6, 0.7, 0.8],
                [0.2, 0.3, 0.4, 0.5],
                [0.6, 0.7, 0.8, 0.9],
            ]]]),
            &device,
        );
        let targets = outputs.clone();
        let mut metric = SsimMetric::<TestBackend>::new(test_config());
        let input = SsimInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
        let ssim = metric.value().current();
        assert!(
            (ssim - 1.0).abs() < 0.001,
            "SSIM for identical images should be 1.0, got {}",
            ssim
        );
    }

    #[test]
    fn test_ssim_completely_different() {
        let device = Default::default();
        let outputs = Tensor::<TestBackend, 4>::zeros([1, 1, 4, 4], &device);
        let targets = Tensor::<TestBackend, 4>::ones([1, 1, 4, 4], &device);
        let mut metric = SsimMetric::<TestBackend>::new(test_config());
        let input = SsimInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
        let ssim = metric.value().current();
        assert!(
            ssim < 0.0001,
            "SSIM for black vs white images should be very low, got {}",
            ssim
        );
    }

    #[test]
    fn test_ssim_similar_images() {
        let device = Default::default();
        let outputs = Tensor::<TestBackend, 4>::full([1, 1, 4, 4], 0.5, &device);
        let targets = Tensor::<TestBackend, 4>::full([1, 1, 4, 4], 0.51, &device);
        let mut metric = SsimMetric::<TestBackend>::new(test_config());
        let input = SsimInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
        let ssim = metric.value().current();
        assert!(
            ssim > 0.99,
            "SSIM for very similar images should be close to 1.0, got {}",
            ssim
        );
    }

    #[test]
    fn test_ssim_batch_averaging() {
        let device = Default::default();
        let outputs = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([
                [[
                    [0.5_f32, 0.5, 0.5, 0.5],
                    [0.5, 0.5, 0.5, 0.5],
                    [0.5, 0.5, 0.5, 0.5],
                    [0.5, 0.5, 0.5, 0.5],
                ]],
                [[
                    [0.0_f32, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0],
                ]],
            ]),
            &device,
        );
        let targets = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([
                [[
                    [0.5_f32, 0.5, 0.5, 0.5],
                    [0.5, 0.5, 0.5, 0.5],
                    [0.5, 0.5, 0.5, 0.5],
                    [0.5, 0.5, 0.5, 0.5],
                ]],
                [[
                    [1.0_f32, 1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0, 1.0],
                ]],
            ]),
            &device,
        );
        let mut metric = SsimMetric::<TestBackend>::new(test_config());
        let input = SsimInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
        let ssim = metric.value().current();
        // One identical pair (~1.0) and one opposite pair (~0.0) average to ~0.5.
        assert!(
            (0.49..0.51).contains(&ssim),
            "Average SSIM should be around 0.5, got {}",
            ssim
        );
    }

    #[test]
    fn test_ssim_multichannel() {
        let device = Default::default();
        let outputs = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[
                [
                    [0.5_f32, 0.6, 0.7, 0.8],
                    [0.4, 0.5, 0.6, 0.7],
                    [0.3, 0.4, 0.5, 0.6],
                    [0.2, 0.3, 0.4, 0.5],
                ],
                [
                    [0.3_f32, 0.4, 0.5, 0.6],
                    [0.2, 0.3, 0.4, 0.5],
                    [0.1, 0.2, 0.3, 0.4],
                    [0.0, 0.1, 0.2, 0.3],
                ],
                [
                    [0.7_f32, 0.8, 0.9, 1.0],
                    [0.6, 0.7, 0.8, 0.9],
                    [0.5, 0.6, 0.7, 0.8],
                    [0.4, 0.5, 0.6, 0.7],
                ],
            ]]),
            &device,
        );
        let targets = outputs.clone();
        let mut metric = SsimMetric::<TestBackend>::new(test_config());
        let input = SsimInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
        let ssim = metric.value().current();
        assert!(
            (ssim - 1.0).abs() < 0.001,
            "SSIM for identical RGB images should be 1.0, got {}",
            ssim
        );
    }

    #[test]
    fn test_ssim_symmetry() {
        let device = Default::default();
        let img1 = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[[
                [0.1_f32, 0.2, 0.3, 0.4],
                [0.5, 0.6, 0.7, 0.8],
                [0.2, 0.3, 0.4, 0.5],
                [0.6, 0.7, 0.8, 0.9],
            ]]]),
            &device,
        );
        let img2 = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[[
                [0.2_f32, 0.3, 0.4, 0.5],
                [0.6, 0.7, 0.8, 0.9],
                [0.3, 0.4, 0.5, 0.6],
                [0.7, 0.8, 0.9, 1.0],
            ]]]),
            &device,
        );
        let config = test_config();
        let mut metric1 = SsimMetric::<TestBackend>::new(config);
        let input1 = SsimInput::new(img1.clone(), img2.clone());
        let _entry = metric1.update(&input1, &MetricMetadata::fake());
        let ssim1 = metric1.value().current();
        let mut metric2 = SsimMetric::<TestBackend>::new(config);
        let input2 = SsimInput::new(img2, img1);
        let _entry = metric2.update(&input2, &MetricMetadata::fake());
        let ssim2 = metric2.value().current();
        assert!(
            (ssim1 - ssim2).abs() < 0.001,
            "SSIM should be symmetric: SSIM(x,y)={} vs SSIM(y,x)={}",
            ssim1,
            ssim2
        );
    }

    #[test]
    fn test_ssim_range() {
        let device = Default::default();
        let shape = Shape::new([1, 1, 11, 11]);
        let distribution = Distribution::Uniform(0.0, 1.0);
        let outputs = Tensor::<TestBackend, 4>::random(shape.clone(), distribution, &device);
        let targets = Tensor::<TestBackend, 4>::random(shape, distribution, &device);
        let mut metric = SsimMetric::<TestBackend>::new(test_config());
        let input = SsimInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
        let ssim = metric.value().current();
        assert!(
            (-1.0..=1.0).contains(&ssim),
            "SSIM should be in range [-1, 1], got {}",
            ssim
        );
    }

    #[test]
    fn test_ssim_running_average() {
        let device = Default::default();
        let mut metric = SsimMetric::<TestBackend>::new(test_config());
        let outputs1 = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[[
                [0.5_f32, 0.6, 0.7, 0.8],
                [0.4, 0.5, 0.6, 0.7],
                [0.3, 0.4, 0.5, 0.6],
                [0.2, 0.3, 0.4, 0.5],
            ]]]),
            &device,
        );
        let targets1 = outputs1.clone();
        let input1 = SsimInput::new(outputs1, targets1);
        let _entry = metric.update(&input1, &MetricMetadata::fake());
        let ssim1 = metric.value().current();
        assert!(
            (ssim1 - 1.0).abs() < 0.001,
            "First update SSIM should be ~1.0, got {}",
            ssim1
        );
        let outputs2 = Tensor::<TestBackend, 4>::zeros([1, 1, 4, 4], &device);
        let targets2 = Tensor::<TestBackend, 4>::ones([1, 1, 4, 4], &device);
        let input2 = SsimInput::new(outputs2, targets2);
        let _entry = metric.update(&input2, &MetricMetadata::fake());
        let running_avg = metric.running_value().current();
        assert!(
            (0.49..0.51).contains(&running_avg),
            "Running average should be around 0.5, got {}",
            running_avg
        );
    }

    #[test]
    fn test_ssim_clear() {
        let device = Default::default();
        let mut metric = SsimMetric::<TestBackend>::new(test_config());
        let outputs = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[[
                [0.5_f32, 0.6, 0.7, 0.8],
                [0.4, 0.5, 0.6, 0.7],
                [0.3, 0.4, 0.5, 0.6],
                [0.2, 0.3, 0.4, 0.5],
            ]]]),
            &device,
        );
        let targets = outputs.clone();
        let input = SsimInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
        let ssim = metric.value().current();
        assert!(
            (ssim - 1.0).abs() < 0.001,
            "Expected SSIM ~1.0, got {}",
            ssim
        );
        metric.clear();
        let ssim = metric.running_value().current();
        assert!(ssim.is_nan(), "Expected NaN after clear, got {}", ssim);
    }

    #[test]
    fn test_ssim_custom_name() {
        let config = SsimMetricConfig::new(1.0);
        let metric = SsimMetric::<TestBackend>::new(config).with_name("CustomSSIM");
        assert_eq!(metric.name().to_string(), "CustomSSIM");
        let metric = SsimMetric::<TestBackend>::new(test_config());
        assert_eq!(metric.name().to_string(), "SSIM (dr=1, w=3, σ=1)");
        let config = SsimMetricConfig::new(255.0);
        let metric = SsimMetric::<TestBackend>::new(config);
        assert_eq!(metric.name().to_string(), "SSIM (dr=255, w=11, σ=1.5)");
    }

    #[test]
    fn test_ssim_pixel_range_255() {
        let device = Default::default();
        let shape = Shape::new([1, 1, 10, 10]);
        let distribution = Distribution::Uniform(0.0, 255.0);
        let outputs = Tensor::<TestBackend, 4>::random(shape.clone(), distribution, &device);
        let targets = outputs.clone();
        let config = SsimMetricConfig::new(255.0).with_kernel_size(3);
        let mut metric = SsimMetric::<TestBackend>::new(config);
        let input = SsimInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
        let ssim = metric.value().current();
        assert!(
            (ssim - 1.0).abs() < 0.001,
            "SSIM for identical 8-bit images should be 1.0, got {}",
            ssim
        );
    }

    #[test]
    fn test_ssim_large_batch() {
        let device = Default::default();
        let shape = Shape::new([20, 3, 30, 30]);
        let distribution = Distribution::Uniform(0.0, 1.0);
        let outputs = Tensor::<TestBackend, 4>::random(shape, distribution, &device);
        let targets = outputs.clone();
        let mut metric = SsimMetric::<TestBackend>::new(test_config());
        let input = SsimInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
        let ssim = metric.value().current();
        assert!(
            (ssim - 1.0).abs() < 0.001,
            "SSIM for identical batch should be 1.0, got {}",
            ssim
        );
    }

    #[test]
    fn test_ssim_default_kernel_size() {
        let device = Default::default();
        // Any image at least as large as the default 11-tap window exercises it;
        // keep it small so the test stays fast.
        let shape = Shape::new([1, 1, 32, 32]);
        let distribution = Distribution::Uniform(0.0, 1.0);
        let outputs = Tensor::<TestBackend, 4>::random(shape, distribution, &device);
        let targets = outputs.clone();
        let config = SsimMetricConfig::new(1.0);
        let mut metric = SsimMetric::<TestBackend>::new(config);
        let input = SsimInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
        let ssim = metric.value().current();
        assert!(
            (ssim - 1.0).abs() < 0.001,
            "SSIM with default window size should work and SSIM should be ~1.0, got {}",
            ssim
        );
    }

    #[test]
    fn test_ssim_attributes() {
        let config = SsimMetricConfig::new(1.0);
        let metric = SsimMetric::<TestBackend>::new(config);
        let attrs = metric.attributes();
        match attrs {
            MetricAttributes::Numeric(numeric_attrs) => {
                assert_eq!(numeric_attrs.unit, None);
                assert!(numeric_attrs.higher_is_better);
            }
            _ => panic!("Expected numeric attributes"),
        }
    }

    #[test]
    #[should_panic(expected = "Shape mismatch")]
    fn test_ssim_shape_mismatch() {
        let device = Default::default();
        let outputs = Tensor::<TestBackend, 4>::zeros([1, 1, 4, 4], &device);
        let targets = Tensor::<TestBackend, 4>::zeros([1, 1, 5, 5], &device);
        let _ = SsimInput::new(outputs, targets);
    }

    #[test]
    #[should_panic(expected = "Image dimensions (H=4, W=4) must be >= kernel_size (11)")]
    fn test_ssim_image_too_small() {
        let device = Default::default();
        let outputs = Tensor::<TestBackend, 4>::zeros([1, 1, 4, 4], &device);
        let targets = outputs.clone();
        let config = SsimMetricConfig::new(1.0);
        let mut metric = SsimMetric::<TestBackend>::new(config);
        let input = SsimInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
    }

    #[test]
    fn test_ssim_valid_k1_k2() {
        let config = SsimMetricConfig::new(1.0).with_k1_k2(0.015, 0.035);
        assert!(
            config.k1 == 0.015 && config.k2 == 0.035,
            "Expected k1=0.015 and k2=0.035, got k1={} and k2={}",
            config.k1,
            config.k2
        );
    }

    #[test]
    #[should_panic(expected = "pixel_range must be positive")]
    fn test_ssim_negative_pixel_range() {
        let _ = SsimMetricConfig::new(-1.0);
    }

    #[test]
    #[should_panic(expected = "pixel_range must be positive")]
    fn test_ssim_zero_pixel_range() {
        let _ = SsimMetricConfig::new(0.0);
    }

    #[test]
    #[should_panic(expected = "k1 must be positive")]
    fn test_ssim_negative_k1() {
        let _ = SsimMetricConfig::new(1.0).with_k1_k2(-0.01, 0.03);
    }

    #[test]
    #[should_panic(expected = "k2 must be positive")]
    fn test_ssim_negative_k2() {
        let _ = SsimMetricConfig::new(1.0).with_k1_k2(0.01, -0.03);
    }

    #[test]
    #[should_panic(expected = "kernel_size must be positive and an odd number")]
    fn test_ssim_even_kernel_size() {
        let _ = SsimMetricConfig::new(1.0).with_kernel_size(10);
    }

    #[test]
    #[should_panic(expected = "kernel_size must be positive and an odd number")]
    fn test_ssim_zero_kernel_size() {
        let _ = SsimMetricConfig::new(1.0).with_kernel_size(0);
    }

    #[test]
    #[should_panic(expected = "sigma must be positive")]
    fn test_ssim_negative_sigma() {
        let _ = SsimMetricConfig::new(1.0).with_sigma(-1.5);
    }

    #[test]
    #[should_panic(expected = "sigma must be positive")]
    fn test_ssim_zero_sigma() {
        let _ = SsimMetricConfig::new(1.0).with_sigma(0.0);
    }
}