use crate::metric::{
Metric, MetricAttributes, MetricMetadata, MetricName, Numeric, NumericAttributes, NumericEntry,
SerializedEntry,
state::{FormatOptions, NumericMetricState},
};
use burn_core::{
prelude::{Backend, Tensor},
tensor::ElementConversion,
};
use core::marker::PhantomData;
use std::f64::consts::LN_10;
/// Input consumed by [`PsnrMetric`]: a pair of rank-4 tensors with identical shapes.
///
/// The metric treats axis 0 as the batch axis and averages the squared error over
/// axes 1, 2 and 3 for each batch element.
pub struct PsnrInput<B: Backend> {
    /// Values produced by the model.
    outputs: Tensor<B, 4>,
    /// Reference values the outputs are compared against.
    targets: Tensor<B, 4>,
}

impl<B: Backend> PsnrInput<B> {
    /// Bundles `outputs` and `targets` into a metric input.
    ///
    /// # Panics
    ///
    /// Panics with a "Shape mismatch" message when the two tensors do not have
    /// exactly the same dimensions.
    pub fn new(outputs: Tensor<B, 4>, targets: Tensor<B, 4>) -> Self {
        let output_shape = outputs.dims();
        let target_shape = targets.dims();
        assert!(
            output_shape == target_shape,
            "Shape mismatch: outputs {:?}, targets {:?}",
            output_shape,
            target_shape
        );

        Self { outputs, targets }
    }
}
/// Configuration for [`PsnrMetric`].
#[derive(Debug, Clone, Copy)]
pub struct PsnrMetricConfig {
    /// Peak signal value; the metric squares it to form the PSNR numerator.
    pub max_pixel_val: f64,
    /// Lower bound applied to each per-image mean squared error before it is
    /// inverted, which keeps the PSNR finite when outputs equal targets.
    pub epsilon: f64,
}

impl PsnrMetricConfig {
    /// Creates a configuration with the given peak value and an `epsilon` of `1e-10`.
    ///
    /// # Panics
    ///
    /// Panics if `max_pixel_val` is not a finite, strictly positive number
    /// (this includes `NaN` and infinity).
    pub fn new(max_pixel_val: f64) -> Self {
        // An infinite peak would make every PSNR +inf, so reject it up front
        // together with zero, negative and NaN values.
        assert!(
            max_pixel_val.is_finite() && max_pixel_val > 0.0,
            "max_pixel_val must be positive and finite, got {}",
            max_pixel_val
        );
        Self {
            max_pixel_val,
            epsilon: 1e-10,
        }
    }

    /// Replaces the MSE lower bound and returns the updated configuration.
    ///
    /// # Panics
    ///
    /// Panics if `epsilon` is not a finite, strictly positive number
    /// (this includes `NaN` and infinity).
    pub fn with_epsilon(mut self, epsilon: f64) -> Self {
        // An infinite clamp would force every MSE to +inf and every PSNR to -inf.
        assert!(
            epsilon.is_finite() && epsilon > 0.0,
            "epsilon must be positive and finite, got {}",
            epsilon
        );
        self.epsilon = epsilon;
        self
    }
}
/// Peak signal-to-noise ratio metric, reported in decibels.
///
/// For every image in a batch the metric evaluates
/// `10 * log10(max_pixel_val^2 / max(MSE, epsilon))` and then averages the
/// results over the batch.
#[derive(Clone)]
pub struct PsnrMetric<B: Backend> {
    /// Name under which the metric is reported.
    name: MetricName,
    /// Accumulator holding the latest and the running values.
    state: NumericMetricState,
    /// Ties the metric to a backend without storing any backend data.
    _b: PhantomData<B>,
    /// Peak value and MSE lower bound used by the computation.
    config: PsnrMetricConfig,
}

impl<B: Backend> PsnrMetric<B> {
    /// Builds a metric from `config`.
    ///
    /// The default name is `PSNR@` followed by the configured `max_pixel_val`.
    pub fn new(config: PsnrMetricConfig) -> Self {
        let default_name = format!("PSNR@{}", config.max_pixel_val);

        Self {
            name: MetricName::new(default_name),
            state: Default::default(),
            config,
            _b: PhantomData,
        }
    }

    /// Overrides the reported name and returns the metric.
    pub fn with_name(mut self, name: &str) -> Self {
        let custom_name = String::from(name);
        self.name = MetricName::new(custom_name);
        self
    }
}
impl<B: Backend> Metric for PsnrMetric<B> {
type Input = PsnrInput<B>;
fn name(&self) -> MetricName {
self.name.clone()
}
fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {
let dims = item.outputs.dims();
let batch_size = dims[0];
let outputs = item.outputs.clone();
let targets = item.targets.clone();
let diff = outputs.sub(targets);
let mse_per_image = diff.powi_scalar(2).mean_dims(&[1, 2, 3]);
let mse_flat = mse_per_image.flatten::<1>(0, 3);
let mse_clamped = mse_flat.clamp_min(self.config.epsilon);
let max_squared = self.config.max_pixel_val * self.config.max_pixel_val;
let psnr_per_image = mse_clamped
.recip()
.mul_scalar(max_squared)
.log()
.mul_scalar(10.0 / LN_10);
let avg_psnr = psnr_per_image.mean().into_scalar().elem::<f64>();
self.state.update(
avg_psnr,
batch_size,
FormatOptions::new(self.name()).unit("dB").precision(2),
)
}
fn clear(&mut self) {
self.state.reset();
}
fn attributes(&self) -> MetricAttributes {
NumericAttributes {
unit: Some("dB".to_string()),
higher_is_better: true,
}
.into()
}
}
impl<B: Backend> Numeric for PsnrMetric<B> {
    /// Returns the state's current value (what `current_value` reports).
    fn value(&self) -> NumericEntry {
        self.state.current_value()
    }

    /// Returns the state's running value, accumulated over the updates seen
    /// since the last `clear`.
    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TestBackend, metric::Numeric};
    use burn_core::tensor::TensorData;

    /// Rank-4 tensor on the test backend, as consumed by `PsnrInput`.
    type Image = Tensor<TestBackend, 4>;
    /// Device type of the test backend.
    type TestDevice = <TestBackend as Backend>::Device;

    /// Absolute tolerance, in dB, used when comparing against expected values.
    const TOLERANCE_DB: f64 = 0.01;

    /// Returns `true` when `actual` is within `TOLERANCE_DB` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE_DB
    }

    /// Builds a `[1, 1, 2, 2]` image whose four pixels all equal `value`.
    fn uniform_2x2(value: f32, device: &TestDevice) -> Image {
        Image::from_data(
            TensorData::from([[[[value, value], [value, value]]]]),
            device,
        )
    }

    /// Pushes one `(outputs, targets)` batch through `metric`.
    fn feed(metric: &mut PsnrMetric<TestBackend>, outputs: Image, targets: Image) {
        let input = PsnrInput::new(outputs, targets);
        let _entry = metric.update(&input, &MetricMetadata::fake());
    }

    /// Runs a single update on a fresh metric and returns the reported PSNR.
    fn psnr_of(config: PsnrMetricConfig, outputs: Image, targets: Image) -> f64 {
        let mut metric = PsnrMetric::<TestBackend>::new(config);
        feed(&mut metric, outputs, targets);
        metric.value().current()
    }

    #[test]
    fn test_psnr_perfect_reconstruction() {
        let device = Default::default();
        let outputs = Image::from_data(
            TensorData::from([[[[1.0_f32, 0.5], [0.25, 0.75]]]]),
            &device,
        );
        let targets = outputs.clone();

        let psnr = psnr_of(PsnrMetricConfig::new(1.0), outputs, targets);

        // MSE is clamped to the default epsilon of 1e-10, i.e. 100 dB.
        assert!(
            psnr >= 99.0,
            "PSNR for perfect reconstruction should be ~100 dB, got {} dB",
            psnr
        );
    }

    #[test]
    fn test_psnr_constant_error() {
        let device = Default::default();
        let outputs = uniform_2x2(0.1, &device);
        let targets = uniform_2x2(0.0, &device);

        let psnr = psnr_of(PsnrMetricConfig::new(1.0), outputs, targets);

        // MSE = 0.01 -> 10 * log10(1 / 0.01) = 20 dB.
        assert!(
            close(psnr, 20.0),
            "Expected PSNR ~20 dB, got {} dB",
            psnr
        );
    }

    #[test]
    fn test_psnr_varying_error() {
        let device = Default::default();
        let outputs = Image::from_data(
            TensorData::from([[[[0.1_f32, 0.2], [0.3, 0.4]]]]),
            &device,
        );
        let targets = uniform_2x2(0.0, &device);

        let psnr = psnr_of(PsnrMetricConfig::new(1.0), outputs, targets);

        // MSE = (0.01 + 0.04 + 0.09 + 0.16) / 4 = 0.075.
        let expected_psnr = 10.0 * (1.0_f64 / 0.075).log10();
        assert!(
            close(psnr, expected_psnr),
            "Expected PSNR ~{:.3} dB, got {} dB",
            expected_psnr,
            psnr
        );
    }

    #[test]
    fn test_psnr_max_pixel_255() {
        let device = Default::default();
        let outputs = uniform_2x2(10.0, &device);
        let targets = uniform_2x2(0.0, &device);

        let psnr = psnr_of(PsnrMetricConfig::new(255.0), outputs, targets);

        // MSE = 100 with a peak of 255.
        let expected_psnr = 10.0 * (255.0_f64 * 255.0 / 100.0).log10();
        assert!(
            close(psnr, expected_psnr),
            "Expected PSNR ~{:.3} dB, got {} dB",
            expected_psnr,
            psnr
        );
    }

    #[test]
    fn test_psnr_batch_averaging() {
        let device = Default::default();
        let outputs = Image::from_data(
            TensorData::from([
                [[[0.1_f32, 0.1], [0.1, 0.1]]],
                [[[0.01_f32, 0.01], [0.01, 0.01]]],
            ]),
            &device,
        );
        let targets = Image::from_data(
            TensorData::from([
                [[[0.0_f32, 0.0], [0.0, 0.0]]],
                [[[0.0_f32, 0.0], [0.0, 0.0]]],
            ]),
            &device,
        );

        let psnr = psnr_of(PsnrMetricConfig::new(1.0), outputs, targets);

        // Images score 20 dB and 40 dB; the batch value is their mean.
        let expected_psnr = 30.0;
        assert!(
            close(psnr, expected_psnr),
            "Expected average PSNR ~{} dB, got {} dB",
            expected_psnr,
            psnr
        );
    }

    #[test]
    fn test_psnr_multichannel() {
        let device = Default::default();
        let outputs = Image::from_data(
            TensorData::from([[
                [[0.1_f32, 0.1], [0.1, 0.1]],
                [[0.1_f32, 0.1], [0.1, 0.1]],
                [[0.1_f32, 0.1], [0.1, 0.1]],
            ]]),
            &device,
        );
        let targets = Image::zeros([1, 3, 2, 2], &device);

        let psnr = psnr_of(PsnrMetricConfig::new(1.0), outputs, targets);

        let expected_psnr = 20.0;
        assert!(
            close(psnr, expected_psnr),
            "Expected PSNR ~{} dB, got {} dB",
            expected_psnr,
            psnr
        );
    }

    #[test]
    fn test_psnr_running_average() {
        let device = Default::default();
        let mut metric = PsnrMetric::<TestBackend>::new(PsnrMetricConfig::new(1.0));

        // First batch: 20 dB.
        let outputs1 = uniform_2x2(0.1, &device);
        let targets1 = Image::zeros([1, 1, 2, 2], &device);
        feed(&mut metric, outputs1, targets1);
        let psnr1 = metric.value().current();
        let expected_psnr1 = 20.0;
        assert!(
            close(psnr1, expected_psnr1),
            "First update PSNR should be ~{} dB, got {} dB",
            expected_psnr1,
            psnr1
        );

        // Second batch: 40 dB, so the running mean over both is 30 dB.
        let outputs2 = uniform_2x2(0.01, &device);
        let targets2 = Image::zeros([1, 1, 2, 2], &device);
        feed(&mut metric, outputs2, targets2);
        let running_avg_psnr = metric.running_value().current();
        let expected_running_avg_psnr = 30.0;
        assert!(
            close(running_avg_psnr, expected_running_avg_psnr),
            "Running average should be ~{} dB, got {} dB",
            expected_running_avg_psnr,
            running_avg_psnr
        );
    }

    #[test]
    fn test_psnr_clear() {
        let device = Default::default();
        let mut metric = PsnrMetric::<TestBackend>::new(PsnrMetricConfig::new(1.0));

        let outputs = uniform_2x2(0.1, &device);
        let targets = Image::zeros([1, 1, 2, 2], &device);
        feed(&mut metric, outputs, targets);
        let psnr = metric.value().current();
        let expected_psnr = 20.0;
        assert!(
            close(psnr, expected_psnr),
            "Expected PSNR ~{} dB, got {} dB",
            expected_psnr,
            psnr
        );

        metric.clear();

        let psnr = metric.running_value().current();
        assert!(psnr.is_nan(), "Expected NaN after clear, got {} dB", psnr)
    }

    #[test]
    fn test_psnr_custom_name() {
        let metric =
            PsnrMetric::<TestBackend>::new(PsnrMetricConfig::new(1.0)).with_name("CustomPSNR");

        assert_eq!(metric.name().to_string(), "CustomPSNR");
    }

    #[test]
    fn test_psnr_custom_epsilon() {
        let device = Default::default();
        let config = PsnrMetricConfig::new(1.0).with_epsilon(0.01);
        let outputs = uniform_2x2(0.5, &device);
        let targets = outputs.clone();

        let psnr = psnr_of(config, outputs, targets);

        // Zero error is clamped to epsilon = 0.01 -> 20 dB.
        let expected_psnr = 20.0;
        assert!(
            close(psnr, expected_psnr),
            "Expected PSNR ~{} dB with epsilon=0.01, got {}",
            expected_psnr,
            psnr
        );
    }

    #[test]
    fn test_psnr_negative_errors() {
        let device = Default::default();
        let outputs = uniform_2x2(0.0, &device);
        let targets = uniform_2x2(0.1, &device);

        let psnr = psnr_of(PsnrMetricConfig::new(1.0), outputs, targets);

        // The error is squared, so its sign must not matter.
        let expected_psnr = 20.0;
        assert!(
            close(psnr, expected_psnr),
            "Expected PSNR ~{} dB, got {}",
            expected_psnr,
            psnr
        );
    }

    #[test]
    fn test_psnr_large_batch() {
        let device = Default::default();
        let batch_size = 8;
        let outputs = Image::full([batch_size, 3, 4, 4], 0.1, &device);
        let targets = Image::zeros([batch_size, 3, 4, 4], &device);

        let psnr = psnr_of(PsnrMetricConfig::new(1.0), outputs, targets);

        let expected_psnr = 20.0;
        assert!(
            close(psnr, expected_psnr),
            "Expected PSNR ~{} dB, got {}",
            expected_psnr,
            psnr
        );
    }

    #[test]
    fn test_psnr_attributes() {
        let metric = PsnrMetric::<TestBackend>::new(PsnrMetricConfig::new(1.0));

        let MetricAttributes::Numeric(numeric_attrs) = metric.attributes() else {
            panic!("Expected numeric attributes");
        };

        assert_eq!(numeric_attrs.unit, Some("dB".to_string()));
        assert!(numeric_attrs.higher_is_better);
    }

    #[test]
    #[should_panic(expected = "Shape mismatch")]
    fn test_psnr_shape_mismatch() {
        let device = Default::default();
        let outputs = Image::zeros([1, 1, 2, 2], &device);
        let targets = Image::zeros([1, 1, 3, 3], &device);

        let _ = PsnrInput::new(outputs, targets);
    }

    #[test]
    #[should_panic(expected = "max_pixel_val must be positive")]
    fn test_psnr_negative_max_pixel_val() {
        let _ = PsnrMetricConfig::new(-1.0);
    }

    #[test]
    #[should_panic(expected = "max_pixel_val must be positive")]
    fn test_psnr_zero_max_pixel_val() {
        let _ = PsnrMetricConfig::new(0.0);
    }

    #[test]
    #[should_panic(expected = "epsilon must be positive")]
    fn test_psnr_negative_epsilon() {
        let _ = PsnrMetricConfig::new(1.0).with_epsilon(-1e-10);
    }

    #[test]
    #[should_panic(expected = "epsilon must be positive")]
    fn test_psnr_zero_epsilon() {
        let _ = PsnrMetricConfig::new(1.0).with_epsilon(0.0);
    }
}