use std::cmp;
use fast_image_resize::{ResizeAlg, ResizeOptions, Resizer};
use image::{
imageops::{self, FilterType},
DynamicImage, GenericImageView as _, GrayAlphaImage, GrayImage, ImageBuffer, Pixel as _,
Rgb32FImage, RgbImage, Rgba32FImage, RgbaImage,
};
use ndarray::{s, Array, ArrayD, Axis, ShapeError};
use ort::TensorValueType;
use crate::Variant;
/// Rescales its argument from the `[0, 1]` range to the `[-1, 1]` range by
/// computing `value * 2 - 1`.
///
/// Written as a macro so it works on anything that supports multiplying by and
/// subtracting a float literal (plain floats as well as arrays).
macro_rules! convert_from_0_1_to_neg1_1 {
    ($unit_range:expr) => {
        $unit_range * 2. - 1.
    };
}
/// Rescales its argument from the `[-1, 1]` range to the `[0, 1]` range by
/// computing `(value + 1) / 2`.
///
/// Written as a macro so it works on anything that supports adding and
/// dividing by a float literal (plain floats as well as arrays).
macro_rules! convert_from_neg1_1_to_0_1 {
    ($signed_range:expr) => {
        ($signed_range + 1.) / 2.
    };
}
/// An image bundled with the parameters used when converting it to or from a
/// model tensor.
pub(super) struct ModelImage(
    /// Edge length, in pixels, of the square the image is resized to when it is
    /// turned into a tensor.
    pub(super) u32,
    /// Model variant; consulted when deciding whether the image is center-cropped.
    pub(super) Variant,
    /// The image itself.
    pub(super) DynamicImage,
);
/// Errors produced while converting between images and model tensors.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error reported by the `ort` runtime.
    #[error("onnx error: {0}")]
    Ort(#[from] ort::Error),

    /// An `ndarray` reshape failed.
    #[error("shape error: {0}")]
    Shape(#[from] ShapeError),

    /// A tensor did not have the expected `[1, 3, height, width]` shape.
    #[error("invalid shape")]
    InvalidShape,

    /// An image could not be built from the pixel data, or its pixel format is
    /// not supported for resizing.
    #[error("invalid image")]
    Image,

    /// An error reported by `fast_image_resize`.
    #[error("resize error: {0}")]
    Resize(#[from] fast_image_resize::ResizeError),
}
/// Converts a [`ModelImage`] into a `[1, 3, size, size]` `f32` tensor.
impl TryFrom<ModelImage> for ort::Value<TensorValueType<f32>> {
    type Error = Error;

    /// Crops (when [`center_crop_size_and_offset`] asks for it), resizes the image to
    /// `size` x `size` with bilinear interpolation, rescales the RGB samples with
    /// `convert_from_0_1_to_neg1_1!` and rearranges them into a `[1, 3, size, size]` tensor.
    ///
    /// # Errors
    ///
    /// Fails when the image cannot be resized, the pixel buffer cannot be reshaped, or
    /// `ort` rejects the array.
    fn try_from(ModelImage(size, variant, img): ModelImage) -> Result<Self, Self::Error> {
        let side = size as usize;

        // The crop is handed to the resizer so cropping and scaling happen in one pass.
        let (crop_w, crop_h, crop_x, crop_y) = center_crop_size_and_offset(variant, &img);
        let bilinear = ResizeAlg::Interpolation(fast_image_resize::FilterType::Bilinear);
        let options = ResizeOptions::new()
            .crop(crop_x as f64, crop_y as f64, crop_w as f64, crop_h as f64)
            .resize_alg(bilinear);
        let resized = resize_img(&img, size, size, options)?;

        // Flat RGB sample buffer, rescaled for the model.
        let samples = Array::from(resized.into_rgb32f().into_vec());
        let samples = convert_from_0_1_to_neg1_1!(samples);

        // [size, size, 3] -> [size, size, 3, 1] -> (reversed) [1, 3, size, size],
        // then swap the two spatial axes back into their original order.
        let mut tensor = samples
            .to_shape([side, side, 3])?
            .insert_axis(Axis(3))
            .reversed_axes();
        tensor.swap_axes(2, 3);
        assert_eq!(tensor.shape(), &[1, 3, side, side]);

        Ok(ort::Value::from_array(&tensor)?)
    }
}
impl TryFrom<(u32, Variant, ArrayD<f32>)> for ModelImage {
type Error = Error;
fn try_from(
(size, variant, mut array): (u32, Variant, ArrayD<f32>),
) -> Result<Self, Self::Error> {
let &[1, 3, height, width] = &array.shape().to_owned()[..] else {
return Err(Error::InvalidShape);
};
array.swap_axes(2, 3);
let array = array.reversed_axes().remove_axis(Axis(3));
let array = array.to_shape([width * height * 3])?;
let array = convert_from_neg1_1_to_0_1!(array);
let image = Rgb32FImage::from_vec(width as u32, height as u32, array.to_vec())
.ok_or(Error::Image)?;
Ok(Self(size, variant, image.into()))
}
}
/// Adds `residual` onto `input` and returns the combined image.
///
/// Both images are processed as RGBA `f32`. The residual is first resized (triangle
/// filter) to the dimensions of `input`. For every colour channel the two samples are
/// mapped through `convert_from_0_1_to_neg1_1!`, summed, clamped to `[-1, 1]` and mapped
/// back with `convert_from_neg1_1_to_0_1!`. The alpha channel of `input` is carried over
/// unchanged.
///
/// The result is an `Rgba32F` image when `input` has an alpha channel and an `Rgb32F`
/// image otherwise.
pub(super) fn apply_residual(input: DynamicImage, residual: DynamicImage) -> DynamicImage {
    let has_alpha = input.color().has_alpha();
    let (w, h) = input.dimensions();

    let applied = {
        // One RGBA copy of the input is enough: it is updated in place below.
        let mut target = input.to_rgba32f();
        let residual = residual
            .resize_exact(w, h, FilterType::Triangle)
            .into_rgba32f();

        for (pixel, delta) in target.pixels_mut().zip(residual.pixels()) {
            // `apply2` touches all four channels, so remember alpha and restore it after.
            let alpha = pixel[3];
            pixel.apply2(delta, |x, y| {
                let x = convert_from_0_1_to_neg1_1!(x);
                let y = convert_from_0_1_to_neg1_1!(y);
                // Clamp on both sides so the result always maps back into [0, 1].
                // `min` then `max` (rather than `clamp`) keeps mapping a NaN sum to 1.0.
                convert_from_neg1_1_to_0_1!(f32::max(f32::min(x + y, 1.0), -1.0))
            });
            pixel[3] = alpha;
        }
        target
    };

    // Write the result back in the colour layout matching the input (with or without alpha).
    if has_alpha {
        let mut output = input.into_rgba32f();
        imageops::replace(&mut output, &applied, 0, 0);
        output.into()
    } else {
        let mut output = input.into_rgb32f();
        let applied = DynamicImage::ImageRgba32F(applied).into_rgb32f();
        imageops::replace(&mut output, &applied, 0, 0);
        output.into()
    }
}
/// Computes the crop rectangle to apply to `img` before it is resized.
///
/// Returns `(crop_width, crop_height, x_offset, y_offset)`.
///
/// The image is cropped to a centered square whose side is the shorter image edge when
/// one edge is more than twice the other, or when `variant` is [`Variant::P`]. Otherwise
/// the whole image is kept and the offsets are zero.
fn center_crop_size_and_offset(variant: Variant, img: &DynamicImage) -> (u32, u32, u32, u32) {
    let (width, height) = img.dimensions();

    let too_elongated = height > width * 2 || width > height * 2;
    if !(too_elongated || variant == Variant::P) {
        return (width, height, 0, 0);
    }

    let side = cmp::min(height, width);
    let offset = (cmp::max(height, width) - side) / 2;

    // Only the longer axis is shifted; the shorter one is kept whole.
    let (xpos, ypos) = if height > width {
        (0, offset)
    } else {
        (offset, 0)
    };
    (side, side, xpos, ypos)
}
fn resize_img(
img: &DynamicImage,
width: u32,
height: u32,
options: ResizeOptions,
) -> Result<DynamicImage, Error> {
let mut modified_img = match img {
DynamicImage::ImageLuma8(_) => DynamicImage::ImageLuma8(GrayImage::new(width, height)),
DynamicImage::ImageLumaA8(_) => {
DynamicImage::ImageLumaA8(GrayAlphaImage::new(width, height))
}
DynamicImage::ImageRgb8(_) => DynamicImage::ImageRgb8(RgbImage::new(width, height)),
DynamicImage::ImageRgba8(_) => DynamicImage::ImageRgba8(RgbaImage::new(width, height)),
DynamicImage::ImageLuma16(_) => DynamicImage::ImageLuma16(ImageBuffer::new(width, height)),
DynamicImage::ImageLumaA16(_) => {
DynamicImage::ImageLumaA16(ImageBuffer::new(width, height))
}
DynamicImage::ImageRgb16(_) => DynamicImage::ImageRgb16(ImageBuffer::new(width, height)),
DynamicImage::ImageRgba16(_) => DynamicImage::ImageRgba16(ImageBuffer::new(width, height)),
DynamicImage::ImageRgb32F(_) => DynamicImage::ImageRgb32F(Rgb32FImage::new(width, height)),
DynamicImage::ImageRgba32F(_) => {
DynamicImage::ImageRgba32F(Rgba32FImage::new(width, height))
}
_ => return Err(Error::Image),
};
Resizer::new().resize(img, &mut modified_img, &options)?;
Ok(modified_img)
}
/// Flattens the outer border of `residual` to the per-channel mean and embeds the result
/// in the center of a mean-filled array stretched along the longer image axis.
///
/// `residual` is assumed to have shape `[1, 3, 256, 256]` (it is assigned into a
/// 256-wide window below). `(width, height)` are presumably the dimensions of the
/// original image; only their ratio is used. `_variant` is currently unused.
///
/// The returned array has shape `[1, 3, 256, n]` when `width > height` and
/// `[1, 3, n, 256]` otherwise, where `n` is `256` scaled by the long/short edge ratio
/// and truncated to an integer.
///
/// # Panics
///
/// Panics if a channel mean cannot be computed (empty array).
pub(super) fn remove_boundary_artifact(
    mut residual: ArrayD<f32>,
    (width, height): (usize, usize),
    _variant: Variant,
) -> ArrayD<f32> {
    let is_landscape = width > height;

    // Means are taken before the border is overwritten.
    let channel_means =
        [0_usize, 1, 2].map(|channel| residual.slice(s![.., channel, .., ..]).mean().unwrap());

    // Length of the stretched axis of the padded output.
    let (long_edge, short_edge) = if is_landscape {
        (width, height)
    } else {
        (height, width)
    };
    let other_dim = ((long_edge as f32 / short_edge as f32) * 256.0) as usize;
    let padded_shape = if is_landscape {
        [1, 3, 256_usize, other_dim]
    } else {
        [1, 3, other_dim, 256]
    };
    let mut mean_padded: ndarray::Array4<f32> = ndarray::Array4::zeros(padded_shape);

    let border = 2;
    for (channel, &mean) in channel_means.iter().enumerate() {
        // Overwrite a `border`-wide frame (top, bottom, right, left) with the channel mean.
        residual.slice_mut(s![0, channel, ..border, ..]).fill(mean);
        residual.slice_mut(s![0, channel, -border.., ..]).fill(mean);
        residual.slice_mut(s![0, channel, .., -border..]).fill(mean);
        residual.slice_mut(s![0, channel, .., ..border]).fill(mean);
        // The padding starts out as a flat field of the same mean.
        mean_padded.slice_mut(s![0, channel, .., ..]).fill(mean);
    }

    // Drop the residual into the middle of the stretched axis.
    let leftover = (other_dim - 256) / 2;
    if is_landscape {
        mean_padded
            .slice_mut(s![.., .., .., leftover..(leftover + 256)])
            .assign(&residual);
    } else {
        mean_padded
            .slice_mut(s![.., .., leftover..(leftover + 256), ..])
            .assign(&residual);
    }

    mean_padded.into_dyn()
}
#[cfg(test)]
mod tests {
    use ndarray::Array4;

    use super::*;

    /// Runs the crop computation on a blank 8-bit grayscale image of the given size.
    fn crop_for(variant: Variant, width: u32, height: u32) -> (u32, u32, u32, u32) {
        let blank = DynamicImage::new(width, height, image::ColorType::L8);
        center_crop_size_and_offset(variant, &blank)
    }

    /// `[0, 1]` -> `[-1, 1]` on a few sample values.
    #[test]
    fn renormalize_from_0_1() {
        for (input, expected) in [(0., -1.), (0.5, 0.), (0.99, 0.98)] {
            assert_eq!(convert_from_0_1_to_neg1_1!(input), expected);
        }
    }

    /// `[-1, 1]` -> `[0, 1]` on a few sample values.
    #[test]
    fn renormalize_from_neg1_1() {
        for (input, expected) in [(-1., 0.), (0.5, 0.75), (-0.1, 0.45)] {
            assert_eq!(convert_from_neg1_1_to_0_1!(input), expected);
        }
    }

    /// A nearly square image is left uncropped for variant Q.
    #[test]
    fn normal_image() {
        assert_eq!(crop_for(Variant::Q, 100, 110), (100, 110, 0, 0));
    }

    /// A very tall image is cropped to a vertically centered square.
    #[test]
    fn skinny_image() {
        assert_eq!(crop_for(Variant::Q, 10, 100), (10, 10, 0, 45));
    }

    /// A very wide image is cropped to a horizontally centered square.
    #[test]
    fn wide_image() {
        assert_eq!(crop_for(Variant::Q, 101, 10), (10, 10, 45, 0));
    }

    /// Variant P is cropped to a square even when the image is nearly square.
    #[test]
    fn always_crop_p() {
        assert_eq!(crop_for(Variant::P, 100, 110), (100, 100, 0, 5));
    }

    /// A portrait target stretches the padded output along the height axis.
    #[test]
    fn remove_boundary_artifact_tall() {
        let residual: Array4<f32> = Array4::ones([1, 3, 256, 256]);
        let (width, height) = (256, 298);
        let padded = remove_boundary_artifact(residual.into_dyn(), (width, height), Variant::P);
        assert_eq!(padded.shape(), &[1, 3, 298, 256]);
    }
}