aleph_alpha_api/
image_processing.rs

// Code copied from the official Aleph Alpha Rust client: https://github.com/Aleph-Alpha/aleph-alpha-client-rs/blob/main/src/image_preprocessing.rs
2use image::{
3    imageops::FilterType::CatmullRom, DynamicImage, GenericImageView, ImageError, ImageFormat,
4};
5use std::{
6    cmp::min,
7    fs::File,
8    io::{self, BufReader, Cursor},
9    path::Path,
10};
11use thiserror::Error as ThisError;
12
/// Edge length, in pixels, of the square images produced by [`preprocess_image`].
///
/// Images are shrunk on the server side before they are sent to the model anyway, so we might as
/// well do it on the client side and save the bandwidth.
const DESIRED_IMAGE_SIZE: u32 = 384;
16
17pub fn from_image_path(path: &Path) -> Result<Vec<u8>, LoadImageError> {
18    let file = BufReader::new(File::open(path).map_err(LoadImageError::Io)?);
19    let format = ImageFormat::from_path(path).map_err(LoadImageError::UnknownImageFormat)?;
20    let image = image::load(file, format).map_err(LoadImageError::InvalidImageEncoding)?;
21
22    let bytes = preprocess_image(&image);
23    Ok(bytes)
24}
25
26pub fn preprocess_image(org_image: &DynamicImage) -> Vec<u8> {
27    let center_cropped = center_cropped(org_image);
28    let resized = center_cropped.resize_exact(DESIRED_IMAGE_SIZE, DESIRED_IMAGE_SIZE, CatmullRom);
29    let buf = Vec::new();
30    let mut out = Cursor::new(buf);
31    resized.write_to(&mut out, ImageFormat::Png).unwrap();
32    out.into_inner()
33}
34
35fn center_cropped(image: &DynamicImage) -> DynamicImage {
36    let (height, width) = image.dimensions();
37    let size = min(height, width);
38    let x = (height - size) / 2;
39    let y = (width - size) / 2;
40    image.crop_imm(x, y, width, height)
41}
42
43/// Errors returned by the Aleph Alpha Client
44#[derive(ThisError, Debug)]
45pub enum LoadImageError {
46    #[error("Error decoding input image")]
47    InvalidImageEncoding(#[source] ImageError),
48    #[error("Failed to guess image format from path")]
49    UnknownImageFormat(#[source] ImageError),
50    #[error("Error opening input image file.")]
51    Io(#[source] io::Error),
52}