use crate::{Result, VisionError};
use image::DynamicImage;
use torsh_tensor::Tensor;
use super::image_to_tensor;
/// Running per-channel statistics accumulated in `f64` with Welford's online
/// algorithm, so no pixel buffer and no large `f32` running sum is needed.
#[derive(Debug, Default, Clone, Copy)]
struct ChannelAccumulator {
    /// Number of samples folded in so far.
    count: u64,
    /// Running mean of the samples.
    mean: f64,
    /// Running sum of squared deviations from the current mean.
    m2: f64,
}

impl ChannelAccumulator {
    /// Folds one sample into the running mean and sum of squared deviations.
    fn push(&mut self, value: f32) {
        let value = f64::from(value);
        self.count += 1;
        let delta = value - self.mean;
        self.mean += delta / self.count as f64;
        // Uses the *updated* mean; this is what keeps Welford numerically stable.
        self.m2 += delta * (value - self.mean);
    }

    /// Returns `(mean, std)` as `f32` using the population variance (divide by N).
    ///
    /// An empty channel yields `(0.0, 1.0)`. The std is floored at `1e-8` so
    /// callers can safely divide by it when normalizing.
    fn finish(&self) -> (f32, f32) {
        if self.count == 0 {
            return (0.0, 1.0);
        }
        let variance = self.m2 / self.count as f64;
        (self.mean as f32, (variance.sqrt() as f32).max(1e-8))
    }
}

/// Computes per-channel mean and standard deviation over a set of images.
///
/// Each image is converted with `image_to_tensor`. Three-channel `(3, H, W)`
/// tensors contribute each channel to the matching statistic. Single-channel
/// `(1, H, W)` tensors contribute their one channel to all three statistics
/// (grayscale replicated as RGB). Tensors of any other shape, including other
/// channel counts, are skipped.
///
/// The results are population statistics over every contributing pixel, so
/// larger images weigh more than smaller ones. They are accumulated in a single
/// streaming pass in `f64`: memory use does not grow with the number of pixels,
/// and precision is not lost on large datasets.
///
/// # Returns
///
/// `(means, stds)`, each of length 3. A channel that received no pixels reports
/// mean `0.0` and std `1.0`. Standard deviations are never below `1e-8`.
///
/// # Errors
///
/// - `VisionError::InvalidArgument` if `images` is empty.
/// - Any error from `image_to_tensor` or from reading tensor elements.
pub fn calculate_stats(images: &[DynamicImage]) -> Result<(Vec<f32>, Vec<f32>)> {
    if images.is_empty() {
        return Err(VisionError::InvalidArgument(
            "No images provided".to_string(),
        ));
    }

    let mut accumulators = [ChannelAccumulator::default(); 3];
    for image in images {
        let tensor = image_to_tensor(image)?;
        let shape = tensor.shape();
        // Slice patterns check the rank and the channel count together, so an
        // unexpected rank is skipped instead of panicking on out-of-range indexing.
        match shape.dims()[..] {
            [3, height, width] => {
                for (channel, acc) in accumulators.iter_mut().enumerate() {
                    for y in 0..height {
                        for x in 0..width {
                            acc.push(tensor.get(&[channel, y, x])?);
                        }
                    }
                }
            }
            [1, height, width] => {
                for y in 0..height {
                    for x in 0..width {
                        let value = tensor.get(&[0, y, x])?;
                        for acc in accumulators.iter_mut() {
                            acc.push(value);
                        }
                    }
                }
            }
            _ => {}
        }
    }

    let (means, stds): (Vec<f32>, Vec<f32>) = accumulators
        .iter()
        .map(ChannelAccumulator::finish)
        .unzip();
    Ok((means, stds))
}
pub fn psnr(image1: &Tensor<f32>, image2: &Tensor<f32>, max_val: Option<f32>) -> Result<f32> {
let shape1 = image1.shape();
let shape2 = image2.shape();
if shape1.dims().len() != 3 {
return Err(VisionError::InvalidShape(format!(
"Expected 3D tensor (C, H, W) for image1, got {}D",
shape1.dims().len()
)));
}
if shape2.dims().len() != 3 {
return Err(VisionError::InvalidShape(format!(
"Expected 3D tensor (C, H, W) for image2, got {}D",
shape2.dims().len()
)));
}
if shape1.dims() != shape2.dims() {
return Err(VisionError::InvalidArgument(
"Input tensors must have the same shape".to_string(),
));
}
let max_val = max_val.unwrap_or(1.0);
let diff = image1
.sub(image2)
.map_err(|e| VisionError::TensorError(e))?;
let squared_diff = diff.mul(&diff).map_err(|e| VisionError::TensorError(e))?;
let mse = squared_diff.mean(None, false)?.item()?;
if mse < 1e-10 {
return Ok(f32::INFINITY); }
let psnr_value = 20.0 * (max_val / mse.sqrt()).log10();
Ok(psnr_value)
}
/// Structural similarity index (SSIM) between two `(C, H, W)` images.
///
/// The SSIM formula is evaluated on every `window_size x window_size` window
/// of every channel (stride 1, no padding). Local statistics are weighted
/// uniformly, not with the Gaussian window of the reference SSIM definition,
/// so values may differ slightly from other implementations. Window scores are
/// averaged per channel, then channel scores are averaged.
///
/// The stabilizing constants are `C1 = (k1 * L)^2` and `C2 = (k2 * L)^2`, with
/// the data range fixed at `L = 1.0`. Inputs are therefore expected to be
/// scaled to `[0, 1]`.
///
/// # Parameters
///
/// - `window_size`: side length of the square window, default `11`. It must be
///   non-zero and no larger than the image height or width.
/// - `k1`, `k2`: stabilizing coefficients, defaults `0.01` and `0.03`.
///
/// # Returns
///
/// The mean SSIM, or `0.0` if no window produced a positive denominator.
///
/// # Errors
///
/// - `VisionError::InvalidShape` if either tensor is not 3-dimensional.
/// - `VisionError::InvalidArgument` if the shapes differ, or if `window_size` is
///   zero or larger than the image.
/// - Any error from reading tensor elements.
pub fn ssim(
    image1: &Tensor<f32>,
    image2: &Tensor<f32>,
    window_size: Option<usize>,
    k1: Option<f32>,
    k2: Option<f32>,
) -> Result<f32> {
    let shape1 = image1.shape();
    let shape2 = image2.shape();
    if shape1.dims().len() != 3 {
        return Err(VisionError::InvalidShape(format!(
            "Expected 3D tensor (C, H, W) for image1, got {}D",
            shape1.dims().len()
        )));
    }
    if shape2.dims().len() != 3 {
        return Err(VisionError::InvalidShape(format!(
            "Expected 3D tensor (C, H, W) for image2, got {}D",
            shape2.dims().len()
        )));
    }
    if shape1.dims() != shape2.dims() {
        return Err(VisionError::InvalidArgument(
            "Input tensors must have the same shape".to_string(),
        ));
    }

    let window_size = window_size.unwrap_or(11);
    // A zero-sized window would divide by zero in every window, yielding NaN
    // statistics that are then silently discarded (returning 0.0).
    if window_size == 0 {
        return Err(VisionError::InvalidArgument(
            "Window size must be greater than zero".to_string(),
        ));
    }
    let k1 = k1.unwrap_or(0.01);
    let k2 = k2.unwrap_or(0.03);
    let data_range = 1.0;
    let c1 = (k1 * data_range).powi(2);
    let c2 = (k2 * data_range).powi(2);

    let (channels, height, width) = (shape1.dims()[0], shape1.dims()[1], shape1.dims()[2]);
    if window_size > height || window_size > width {
        return Err(VisionError::InvalidArgument(format!(
            "Window size ({}) too large for image dimensions ({}x{})",
            window_size, height, width
        )));
    }

    let mut ssim_total = 0.0;
    // Number of channels that produced at least one usable window.
    let mut valid_channels = 0;
    for c in 0..channels {
        let mut channel_ssim = 0.0;
        let mut channel_windows = 0;
        // Both subtractions are safe: window_size <= height, width was checked above.
        for y in 0..=(height - window_size) {
            for x in 0..=(width - window_size) {
                let (mu1, mu2, sigma1_sq, sigma2_sq, sigma12) =
                    calculate_window_statistics(image1, image2, c, y, x, window_size)?;
                let numerator = (2.0 * mu1 * mu2 + c1) * (2.0 * sigma12 + c2);
                let denominator = (mu1 * mu1 + mu2 * mu2 + c1) * (sigma1_sq + sigma2_sq + c2);
                // Skip degenerate windows (only possible when c1 or c2 is zero).
                if denominator > 0.0 {
                    channel_ssim += numerator / denominator;
                    channel_windows += 1;
                }
            }
        }
        if channel_windows > 0 {
            ssim_total += channel_ssim / channel_windows as f32;
            valid_channels += 1;
        }
    }

    Ok(if valid_channels > 0 {
        ssim_total / valid_channels as f32
    } else {
        0.0
    })
}
fn calculate_window_statistics(
image1: &Tensor<f32>,
image2: &Tensor<f32>,
channel: usize,
start_y: usize,
start_x: usize,
window_size: usize,
) -> Result<(f32, f32, f32, f32, f32)> {
let mut sum1 = 0.0;
let mut sum2 = 0.0;
let mut sum1_sq = 0.0;
let mut sum2_sq = 0.0;
let mut sum12 = 0.0;
let n = (window_size * window_size) as f32;
for y in start_y..(start_y + window_size) {
for x in start_x..(start_x + window_size) {
let val1 = image1.get(&[channel, y, x])?;
let val2 = image2.get(&[channel, y, x])?;
sum1 += val1;
sum2 += val2;
sum1_sq += val1 * val1;
sum2_sq += val2 * val2;
sum12 += val1 * val2;
}
}
let mu1 = sum1 / n;
let mu2 = sum2 / n;
let sigma1_sq = (sum1_sq / n) - (mu1 * mu1);
let sigma2_sq = (sum2_sq / n) - (mu2 * mu2);
let sigma12 = (sum12 / n) - (mu1 * mu2);
Ok((mu1, mu2, sigma1_sq, sigma2_sq, sigma12))
}
/// Mean squared error between two images of identical `(C, H, W)` shape.
///
/// Every element contributes equally: the result is `mean((image1 - image2)^2)`
/// taken over all channels and pixels.
///
/// # Errors
///
/// - `VisionError::InvalidShape` if either tensor is not 3-dimensional
///   (`image1` is checked first).
/// - `VisionError::InvalidArgument` if the two shapes differ.
/// - Any error raised by the underlying tensor arithmetic.
pub fn mse(image1: &Tensor<f32>, image2: &Tensor<f32>) -> Result<f32> {
    let (first, second) = (image1.shape(), image2.shape());

    for (label, rank) in [("image1", first.dims().len()), ("image2", second.dims().len())] {
        if rank != 3 {
            return Err(VisionError::InvalidShape(format!(
                "Expected 3D tensor (C, H, W) for {}, got {}D",
                label, rank
            )));
        }
    }
    if first.dims() != second.dims() {
        return Err(VisionError::InvalidArgument(
            "Input tensors must have the same shape".to_string(),
        ));
    }

    let residual = image1.sub(image2).map_err(VisionError::TensorError)?;
    let mean_square = residual
        .mul(&residual)
        .map_err(VisionError::TensorError)?
        .mean(None, false)?
        .item()?;
    Ok(mean_square)
}
/// Mean absolute error between two images of identical `(C, H, W)` shape.
///
/// The result is `mean(|image1 - image2|)` over all channels and pixels.
///
/// # Errors
///
/// - `VisionError::InvalidShape` if either tensor is not 3-dimensional
///   (`image1` is reported first when both are wrong).
/// - `VisionError::InvalidArgument` if the two shapes differ.
/// - Any error raised by the underlying tensor arithmetic.
pub fn mae(image1: &Tensor<f32>, image2: &Tensor<f32>) -> Result<f32> {
    let shape_a = image1.shape();
    let shape_b = image2.shape();

    let rank_error = |which: &str, rank: usize| {
        VisionError::InvalidShape(format!(
            "Expected 3D tensor (C, H, W) for {}, got {}D",
            which, rank
        ))
    };
    match (shape_a.dims().len(), shape_b.dims().len()) {
        (3, 3) => {}
        (3, rank_b) => return Err(rank_error("image2", rank_b)),
        (rank_a, _) => return Err(rank_error("image1", rank_a)),
    }
    if shape_a.dims() != shape_b.dims() {
        return Err(VisionError::InvalidArgument(
            "Input tensors must have the same shape".to_string(),
        ));
    }

    let magnitude = image1
        .sub(image2)
        .and_then(|delta| delta.abs())
        .map_err(VisionError::TensorError)?;
    let mean_abs = magnitude.mean(None, false)?.item()?;
    Ok(mean_abs)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation;

    /// Builds an all-ones tensor of the given shape.
    fn ones(dims: &[usize]) -> torsh_tensor::Tensor<f32> {
        creation::ones(dims).expect("creation should succeed")
    }

    /// Builds an all-zeros tensor as `ones - ones`, relying only on the tensor
    /// arithmetic the metrics themselves use.
    fn zeros(dims: &[usize]) -> torsh_tensor::Tensor<f32> {
        let base = ones(dims);
        base.sub(&base).expect("subtraction should succeed")
    }

    #[test]
    fn test_mse_identical_images() {
        let tensor = ones(&[3, 32, 32]);
        let result = mse(&tensor, &tensor).expect("mse should succeed");
        assert!(result.abs() < 1e-7);
    }

    #[test]
    fn test_mae_identical_images() {
        let tensor = ones(&[3, 32, 32]);
        let result = mae(&tensor, &tensor).expect("mae should succeed");
        assert!(result.abs() < 1e-7);
    }

    #[test]
    fn test_psnr_identical_images() {
        let tensor = ones(&[3, 32, 32]);
        let result = psnr(&tensor, &tensor, Some(1.0)).expect("operation should succeed");
        assert!(result.is_infinite());
    }

    #[test]
    fn test_ssim_identical_images() {
        let tensor = ones(&[3, 32, 32]);
        let result = ssim(&tensor, &tensor, None, None, None).expect("ssim should succeed");
        assert!((result - 1.0).abs() < 1e-7);
    }

    // Identical-image tests alone would pass for a metric returning a constant;
    // these pin the value for a known, non-zero difference.
    #[test]
    fn test_metrics_known_difference() {
        let a = ones(&[3, 4, 4]);
        let b = zeros(&[3, 4, 4]);
        let mse_value = mse(&a, &b).expect("mse should succeed");
        assert!((mse_value - 1.0).abs() < 1e-6);
        let mae_value = mae(&a, &b).expect("mae should succeed");
        assert!((mae_value - 1.0).abs() < 1e-6);
        // max_val / sqrt(1.0) == 1.0, and log10(1.0) == 0.
        let psnr_value = psnr(&a, &b, Some(1.0)).expect("psnr should succeed");
        assert!(psnr_value.abs() < 1e-5);
    }

    #[test]
    fn test_ssim_dissimilar_images() {
        let a = ones(&[3, 8, 8]);
        let b = zeros(&[3, 8, 8]);
        // Constant windows with means 1 and 0 give c1 / (1 + c1), roughly 1e-4.
        let result = ssim(&a, &b, Some(3), None, None).expect("ssim should succeed");
        assert!(result > 0.0 && result < 1e-2);
    }

    #[test]
    fn test_ssim_window_too_large() {
        let tensor = ones(&[3, 8, 8]);
        // The default window (11) does not fit in an 8x8 image.
        assert!(ssim(&tensor, &tensor, None, None, None).is_err());
    }

    #[test]
    fn test_invalid_tensor_shapes() {
        let tensor_2d = ones(&[32, 32]);
        let tensor_3d = ones(&[3, 32, 32]);
        assert!(mse(&tensor_2d, &tensor_3d).is_err());
        assert!(mae(&tensor_2d, &tensor_3d).is_err());
        assert!(psnr(&tensor_2d, &tensor_3d, None).is_err());
        assert!(ssim(&tensor_2d, &tensor_3d, None, None, None).is_err());
    }

    #[test]
    fn test_mismatched_shapes() {
        let large = ones(&[3, 16, 16]);
        let small = ones(&[3, 8, 8]);
        assert!(mse(&large, &small).is_err());
        assert!(mae(&large, &small).is_err());
        assert!(psnr(&large, &small, None).is_err());
        assert!(ssim(&large, &small, Some(3), None, None).is_err());
    }
}