use crate::{Result, VisionError};
use image::{DynamicImage, GenericImageView};
use torsh_tensor::{creation, creation::zeros_mut, Tensor};
/// Converts a 3D tensor in (C, H, W) layout into a `DynamicImage`.
///
/// A 3-channel tensor becomes an 8-bit RGB image; a 1-channel tensor
/// becomes an 8-bit grayscale image. When `normalize` is true the values
/// are min-max rescaled into [0, 1] first; either way each value is then
/// scaled by 255 and clamped into [0, 255] before quantization to `u8`.
///
/// # Errors
/// Returns `VisionError::InvalidShape` when the tensor is not 3D or when
/// its channel count is neither 1 nor 3.
pub fn tensor_to_image(tensor: &Tensor<f32>, normalize: bool) -> Result<DynamicImage> {
    let shape = tensor.shape();
    let dims = shape.dims();
    if dims.len() != 3 {
        return Err(VisionError::InvalidShape(format!(
            "Expected 3D tensor (C, H, W), got {}D",
            dims.len()
        )));
    }
    let (channels, height, width) = (dims[0], dims[1], dims[2]);
    if channels != 3 && channels != 1 {
        return Err(VisionError::InvalidShape(format!(
            "Expected 1 or 3 channels, got {}",
            channels
        )));
    }
    let mut pixels = tensor.clone();
    if normalize {
        // Min-max rescale into [0, 1]; skipped for constant tensors to
        // avoid dividing by zero.
        // NOTE(review): `min()` and `max(None, false)` have asymmetric
        // signatures — assumed to both be full reductions in the
        // torsh_tensor API; confirm.
        let lo = pixels.min()?.to_vec()?[0];
        let hi = pixels.max(None, false)?.to_vec()?[0];
        if hi > lo {
            pixels.sub_scalar_(lo)?;
            pixels = pixels.div_scalar(hi - lo)?;
        }
    }
    // Quantize one tensor value to a u8 color component.
    let quantize = |v: f32| (v * 255.0).clamp(0.0, 255.0) as u8;
    if channels == 3 {
        let mut rgb = image::RgbImage::new(width as u32, height as u32);
        for row in 0..height {
            for col in 0..width {
                let r = quantize(pixels.get(&[0, row, col])?);
                let g = quantize(pixels.get(&[1, row, col])?);
                let b = quantize(pixels.get(&[2, row, col])?);
                rgb.put_pixel(col as u32, row as u32, image::Rgb([r, g, b]));
            }
        }
        Ok(DynamicImage::ImageRgb8(rgb))
    } else {
        let mut gray = image::GrayImage::new(width as u32, height as u32);
        for row in 0..height {
            for col in 0..width {
                let value = quantize(pixels.get(&[0, row, col])?);
                gray.put_pixel(col as u32, row as u32, image::Luma([value]));
            }
        }
        Ok(DynamicImage::ImageLuma8(gray))
    }
}
pub fn image_to_tensor(image: &DynamicImage) -> Result<Tensor<f32>> {
let (width, height) = image.dimensions();
match image {
DynamicImage::ImageRgb8(rgb_img) => {
let tensor = zeros_mut(&[3, height as usize, width as usize]);
for y in 0..height {
for x in 0..width {
let pixel = rgb_img.get_pixel(x, y);
let r = pixel[0] as f32 / 255.0;
let g = pixel[1] as f32 / 255.0;
let b = pixel[2] as f32 / 255.0;
tensor.set(&[0, y as usize, x as usize], r)?;
tensor.set(&[1, y as usize, x as usize], g)?;
tensor.set(&[2, y as usize, x as usize], b)?;
}
}
Ok(tensor)
}
DynamicImage::ImageLuma8(gray_img) => {
let tensor = zeros_mut(&[1, height as usize, width as usize]);
for y in 0..height {
for x in 0..width {
let pixel = gray_img.get_pixel(x, y);
let gray = pixel[0] as f32 / 255.0;
tensor.set(&[0, y as usize, x as usize], gray)?;
}
}
Ok(tensor)
}
_ => {
let rgb_image = image.to_rgb8();
let tensor = zeros_mut(&[3, height as usize, width as usize]);
for y in 0..height {
for x in 0..width {
let pixel = rgb_image.get_pixel(x, y);
let r = pixel[0] as f32 / 255.0;
let g = pixel[1] as f32 / 255.0;
let b = pixel[2] as f32 / 255.0;
tensor.set(&[0, y as usize, x as usize], r)?;
tensor.set(&[1, y as usize, x as usize], g)?;
tensor.set(&[2, y as usize, x as usize], b)?;
}
}
Ok(tensor)
}
}
}
/// Reverses a per-channel normalization: `out[c] = in[c] * std[c] + mean[c]`.
///
/// The input must be a 3D tensor in (C, H, W) layout, and `mean`/`std`
/// must each supply exactly one value per channel.
///
/// # Errors
/// - `VisionError::InvalidShape` when the tensor is not 3D.
/// - `VisionError::InvalidArgument` when `mean`/`std` lengths do not match
///   the channel count, or when any std entry is zero.
pub fn denormalize(tensor: &Tensor<f32>, mean: &[f32], std: &[f32]) -> Result<Tensor<f32>> {
    let dims = tensor.shape().dims().to_vec();
    if dims.len() != 3 {
        return Err(VisionError::InvalidShape(format!(
            "Expected 3D tensor (C, H, W), got {}D",
            dims.len()
        )));
    }
    let (channels, rows, cols) = (dims[0], dims[1], dims[2]);
    if mean.len() != channels || std.len() != channels {
        return Err(VisionError::InvalidArgument(format!(
            "Mean and std must have same length as number of channels. Got {} channels, {} mean values, {} std values",
            channels, mean.len(), std.len()
        )));
    }
    // NOTE(review): `set` is called through an immutable binding, so the
    // tensor type presumably uses interior mutability — confirm.
    let output = tensor.clone();
    for (c, (&channel_mean, &channel_std)) in mean.iter().zip(std.iter()).enumerate() {
        // A zero std keeps denormalize symmetric with normalization, where
        // dividing by std would be undefined.
        if channel_std == 0.0 {
            return Err(VisionError::InvalidArgument(
                "Standard deviation cannot be zero".to_string(),
            ));
        }
        for row in 0..rows {
            for col in 0..cols {
                let normalized = output.get(&[c, row, col])?;
                output.set(&[c, row, col], normalized * channel_std + channel_mean)?;
            }
        }
    }
    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation;

    #[test]
    fn test_tensor_to_image_rgb() {
        // A 3-channel tensor must come back as an RGB image of the same size.
        let tensor = creation::ones(&[3, 10, 10]).expect("creation should succeed");
        let image = tensor_to_image(&tensor, false).expect("tensor to image should succeed");
        let DynamicImage::ImageRgb8(rgb_img) = image else {
            panic!("Expected RGB image");
        };
        assert_eq!(rgb_img.width(), 10);
        assert_eq!(rgb_img.height(), 10);
    }

    #[test]
    fn test_tensor_to_image_grayscale() {
        // A 1-channel tensor must come back as an 8-bit grayscale image.
        let tensor = creation::ones(&[1, 10, 10]).expect("creation should succeed");
        let image = tensor_to_image(&tensor, false).expect("tensor to image should succeed");
        let DynamicImage::ImageLuma8(gray_img) = image else {
            panic!("Expected grayscale image");
        };
        assert_eq!(gray_img.width(), 10);
        assert_eq!(gray_img.height(), 10);
    }

    #[test]
    fn test_image_to_tensor_roundtrip() {
        // The (C, H, W) shape must survive a tensor -> image -> tensor round trip.
        let original = creation::rand(&[3, 5, 5]).expect("creation should succeed");
        let image = tensor_to_image(&original, false).expect("tensor to image should succeed");
        let restored = image_to_tensor(&image).expect("image to tensor should succeed");
        assert_eq!(restored.shape().dims(), &[3, 5, 5]);
    }

    #[test]
    fn test_denormalize() {
        // Zeros denormalized with mean 0.5 / std 0.2 become exactly the mean.
        let tensor = creation::zeros(&[3, 2, 2]).expect("creation should succeed");
        let result = denormalize(&tensor, &[0.5, 0.5, 0.5], &[0.2, 0.2, 0.2])
            .expect("denormalize should succeed");
        for c in 0..3 {
            for y in 0..2 {
                for x in 0..2 {
                    let val = result
                        .get(&[c, y, x])
                        .expect("element retrieval should succeed for valid index");
                    assert!((val - 0.5).abs() < 1e-6);
                }
            }
        }
    }

    #[test]
    fn test_denormalize_invalid_channels() {
        // Two mean values for a three-channel tensor must be rejected.
        let tensor = creation::zeros(&[3, 2, 2]).expect("creation should succeed");
        assert!(denormalize(&tensor, &[0.5, 0.5], &[0.2, 0.2, 0.2]).is_err());
    }

    #[test]
    fn test_denormalize_zero_std() {
        // Any zero std entry must be rejected.
        let tensor = creation::zeros(&[3, 2, 2]).expect("creation should succeed");
        assert!(denormalize(&tensor, &[0.5, 0.5, 0.5], &[0.2, 0.0, 0.2]).is_err());
    }
}