use anyhow::Result;
use candle_core::{DType, Device, Tensor};
/// Target value range for 8-bit pixel normalization in [`decode_source_image`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormalizeRange {
    /// Map each channel `v` in `0..=255` to `2 * (v / 255) - 1`, i.e. onto `[-1.0, 1.0]`.
    MinusOneToOne,
    /// Map each channel `v` in `0..=255` to `v / 255`, i.e. onto `[0.0, 1.0]`.
    ZeroToOne,
}
/// Decodes an encoded image into a `(1, 3, target_h, target_w)` tensor.
///
/// The image is resized to exactly `target_w` x `target_h` with Lanczos3 and converted to
/// 8-bit RGB. Each channel is scaled from `0..=255` to `[0, 1]`. With
/// [`NormalizeRange::MinusOneToOne`] it is further remapped to `[-1, 1]` via `2x - 1`.
/// The tensor is assembled on the CPU as `f32`, then converted to `dtype` and moved to `device`.
///
/// # Errors
/// Returns an error if `bytes` cannot be decoded as an image or if any tensor operation fails.
pub fn decode_source_image(
    bytes: &[u8],
    target_w: u32,
    target_h: u32,
    range: NormalizeRange,
    device: &Device,
    dtype: DType,
) -> Result<Tensor> {
    let decoded = image::load_from_memory(bytes)
        .map_err(|e| anyhow::anyhow!("failed to decode source image: {e}"))?;
    let rgb = decoded
        .resize_exact(target_w, target_h, image::imageops::FilterType::Lanczos3)
        .to_rgb8();
    let height = rgb.height() as usize;
    let width = rgb.width() as usize;
    let unit: Vec<f32> = rgb
        .as_raw()
        .iter()
        .map(|&byte| f32::from(byte) / 255.0)
        .collect();
    // The RGB buffer is interleaved (H, W, 3); reorder it to planar (3, H, W).
    let chw = Tensor::from_vec(unit, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    let normalized = match range {
        NormalizeRange::ZeroToOne => chw,
        // A single affine op computes x * 2 + (-1), mapping [0, 1] onto [-1, 1].
        NormalizeRange::MinusOneToOne => chw.affine(2.0, -1.0)?,
    };
    let batched = normalized.unsqueeze(0)?.to_dtype(dtype)?.to_device(device)?;
    Ok(batched)
}
pub fn decode_mask_image(
bytes: &[u8],
latent_height: usize,
latent_width: usize,
device: &Device,
dtype: DType,
) -> Result<Tensor> {
let img = image::load_from_memory(bytes)
.map_err(|e| anyhow::anyhow!("failed to decode mask image: {e}"))?;
let img = img.resize_exact(
latent_width as u32,
latent_height as u32,
image::imageops::FilterType::Lanczos3,
);
let gray = img.to_luma8();
let data: Vec<f32> = gray.as_raw().iter().map(|&v| v as f32 / 255.0).collect();
let tensor = Tensor::from_vec(data, (1, 1, latent_height, latent_width), &Device::Cpu)?;
let tensor = tensor.to_dtype(dtype)?.to_device(device)?;
Ok(tensor)
}
#[cfg(test)]
mod normalization_tests {
    //! Exact-value checks for each `NormalizeRange` variant on a single known pixel.
    use super::*;
    use image::{DynamicImage, ImageBuffer, ImageFormat, Rgb};
    use std::io::Cursor;

    /// Maximum absolute difference tolerated between a decoded channel and its expectation.
    const EPS: f32 = 1e-6;

    /// Encodes a 1x1 RGB image with the given channel values as PNG bytes.
    fn encode_test_png(pixel: [u8; 3]) -> Vec<u8> {
        let single: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_pixel(1, 1, Rgb(pixel));
        let mut sink = Cursor::new(Vec::new());
        DynamicImage::ImageRgb8(single)
            .write_to(&mut sink, ImageFormat::Png)
            .expect("encode PNG");
        sink.into_inner()
    }

    /// Decodes `bytes` at 1x1 into an `f32` CPU tensor and returns its channels in CHW order.
    fn decode_values(bytes: &[u8], range: NormalizeRange) -> Vec<f32> {
        let tensor = decode_source_image(bytes, 1, 1, range, &Device::Cpu, DType::F32)
            .expect("decode source image");
        let flat = tensor.flatten_all().expect("flatten decoded tensor");
        flat.to_vec1::<f32>().expect("decoded values")
    }

    /// Asserts that `actual` has the same length as `expected` and matches it element-wise
    /// within `EPS`.
    fn assert_all_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected) {
            assert!((got - want).abs() < EPS);
        }
    }

    #[test]
    fn zero_to_one_normalization_preserves_unit_interval() {
        let bytes = encode_test_png([0, 128, 255]);
        let expected: [f32; 3] = [0.0, 128.0 / 255.0, 1.0];
        assert_all_close(&decode_values(&bytes, NormalizeRange::ZeroToOne), &expected);
    }

    #[test]
    fn minus_one_to_one_normalization_centers_and_scales_pixels() {
        let bytes = encode_test_png([0, 128, 255]);
        let expected: [f32; 3] = [-1.0, (128.0 / 255.0) * 2.0 - 1.0, 1.0];
        assert_all_close(&decode_values(&bytes, NormalizeRange::MinusOneToOne), &expected);
    }
}
/// Groups the three tensors that travel together through inpainting (per the type name).
///
/// NOTE(review): the field descriptions below are inferred from field names only. Confirm
/// them against the code that constructs and consumes this struct.
pub struct InpaintContext {
    /// Presumably the latent representation of the unmodified source image. TODO confirm.
    pub original_latents: Tensor,
    /// Presumably an inpainting mask. `decode_mask_image` in this file yields
    /// `(1, 1, latent_height, latent_width)` tensors that may be stored here. TODO confirm.
    pub mask: Tensor,
    /// Presumably the noise tensor associated with `original_latents`. TODO confirm.
    pub noise: Tensor,
}
#[cfg(test)]
mod tests {
    //! Shape and value-range checks for `decode_source_image` and `decode_mask_image`.
    use super::*;

    /// Slack allowed around the expected extreme values of a normalized tensor.
    const TOL: f32 = 0.01;

    /// PNG-encodes `img` into an in-memory buffer.
    fn encode_png(img: image::DynamicImage) -> Vec<u8> {
        let mut buf = std::io::Cursor::new(Vec::new());
        img.write_to(&mut buf, image::ImageFormat::Png).unwrap();
        buf.into_inner()
    }

    /// A 4x4 solid red RGB PNG: every channel is either fully on (255) or fully off (0).
    fn tiny_png() -> Vec<u8> {
        let img = image::RgbImage::from_pixel(4, 4, image::Rgb([255, 0, 0]));
        encode_png(image::DynamicImage::ImageRgb8(img))
    }

    /// A 4x4 single-channel PNG with every pixel set to `level`.
    fn solid_mask_png(level: u8) -> Vec<u8> {
        let img = image::GrayImage::from_pixel(4, 4, image::Luma([level]));
        encode_png(image::DynamicImage::ImageLuma8(img))
    }

    fn white_mask_png() -> Vec<u8> {
        solid_mask_png(255)
    }

    fn black_mask_png() -> Vec<u8> {
        solid_mask_png(0)
    }

    /// Decodes `png` as a source image into an `f32` CPU tensor.
    fn decode_source(png: &[u8], w: u32, h: u32, range: NormalizeRange) -> Tensor {
        decode_source_image(png, w, h, range, &Device::Cpu, DType::F32).unwrap()
    }

    /// Decodes `png` as a mask into an `f32` CPU tensor.
    fn decode_mask(png: &[u8], h: usize, w: usize) -> Tensor {
        decode_mask_image(png, h, w, &Device::Cpu, DType::F32).unwrap()
    }

    /// Returns `(min, max)` over every element of an `f32` tensor.
    fn min_max(t: &Tensor) -> (f32, f32) {
        let min = t.min_all().unwrap().to_scalar::<f32>().unwrap();
        let max = t.max_all().unwrap().to_scalar::<f32>().unwrap();
        (min, max)
    }

    #[test]
    fn decode_source_image_shape() {
        let tensor = decode_source(&tiny_png(), 8, 8, NormalizeRange::ZeroToOne);
        assert_eq!(tensor.dims(), &[1, 3, 8, 8]);
    }

    #[test]
    fn decode_source_image_minus_one_to_one_range() {
        let tensor = decode_source(&tiny_png(), 4, 4, NormalizeRange::MinusOneToOne);
        let (min, max) = min_max(&tensor);
        // Checking only the bounds would also accept an un-remapped [0, 1] tensor.
        // The zero-valued green/blue channels must actually land on -1.
        assert!((min + 1.0).abs() <= TOL, "off channels should map to -1.0, got {min}");
        assert!((max - 1.0).abs() <= TOL, "full channel should map to 1.0, got {max}");
    }

    #[test]
    fn decode_source_image_zero_to_one_range() {
        let tensor = decode_source(&tiny_png(), 4, 4, NormalizeRange::ZeroToOne);
        let (min, max) = min_max(&tensor);
        assert!(min.abs() <= TOL, "off channels should map to 0.0, got {min}");
        assert!((max - 1.0).abs() <= TOL, "full channel should map to 1.0, got {max}");
    }

    #[test]
    fn decode_source_image_resize() {
        let tensor = decode_source(&tiny_png(), 16, 16, NormalizeRange::ZeroToOne);
        assert_eq!(tensor.dims(), &[1, 3, 16, 16]);
    }

    #[test]
    fn decode_mask_shape() {
        let tensor = decode_mask(&white_mask_png(), 8, 8);
        assert_eq!(tensor.dims(), &[1, 1, 8, 8]);
    }

    #[test]
    fn decode_mask_white_is_one() {
        let (min, _) = min_max(&decode_mask(&white_mask_png(), 4, 4));
        assert!(min > 1.0 - TOL, "white mask should be ~1.0, got {min}");
    }

    #[test]
    fn decode_mask_black_is_zero() {
        let (_, max) = min_max(&decode_mask(&black_mask_png(), 4, 4));
        assert!(max < TOL, "black mask should be ~0.0, got {max}");
    }

    #[test]
    fn decode_mask_rgb_converted_to_grayscale() {
        let tensor = decode_mask(&tiny_png(), 4, 4);
        assert_eq!(tensor.dims(), &[1, 1, 4, 4]);
        let (val, _) = min_max(&tensor);
        // The exact luma of pure red depends on the coefficients used by the `image` crate
        // version in use, so only a coarse band is asserted.
        assert!(
            val > 0.1 && val < 0.5,
            "red -> grayscale should land between 0.1 and 0.5, got {val}"
        );
    }
}