use anyhow::{Context, Result, bail};
use image::ImageFormat;
/// Upper bound, in bytes, on the length of input accepted by [`validate`]
/// (20 * 1024 * 1024 bytes). Longer inputs are rejected with an error.
const MAX_FILE_SIZE: usize = 20 * 1024 * 1024;
/// Largest width or height, in pixels, that [`process`] leaves as-is.
/// When either side of a decoded image exceeds this value, the image is
/// resized with this value passed as both the width and the height bound.
const MAX_DIMENSION: u32 = 2048;
/// Image formats that [`detect_format`] is able to recognise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DetectedFormat {
    /// JPEG image.
    Jpeg,
    /// PNG image.
    Png,
    /// GIF image.
    Gif,
    /// WebP image.
    WebP,
}

impl DetectedFormat {
    /// Returns the MIME content type string associated with this format.
    pub fn content_type(self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            DetectedFormat::Jpeg => "image/jpeg",
            DetectedFormat::Png => "image/png",
            DetectedFormat::Gif => "image/gif",
            DetectedFormat::WebP => "image/webp",
        }
    }
}
/// Sniffs the image format from the leading bytes of `data`.
///
/// Only the prefix is inspected; the remainder of the buffer is not checked,
/// so a `Some` result does not guarantee that the data decodes successfully.
///
/// Returns `None` when `data` holds fewer than four bytes or does not start
/// with any of the recognised signatures.
pub fn detect_format(data: &[u8]) -> Option<DetectedFormat> {
    match data {
        // PNG: starts with `\x89PNG`.
        [0x89, b'P', b'N', b'G', ..] => Some(DetectedFormat::Png),
        // JPEG: starts with `FF D8 FF`. The fourth byte is unconstrained but
        // must exist, which keeps the four-byte minimum length for every format.
        [0xff, 0xd8, 0xff, _, ..] => Some(DetectedFormat::Jpeg),
        // GIF: starts with `GIF8`.
        [b'G', b'I', b'F', b'8', ..] => Some(DetectedFormat::Gif),
        // WebP: `RIFF` at bytes 0..4 and `WEBP` at bytes 8..12; bytes 4..8
        // are not inspected.
        [b'R', b'I', b'F', b'F', _, _, _, _, b'W', b'E', b'B', b'P', ..] => {
            Some(DetectedFormat::WebP)
        }
        _ => None,
    }
}
/// Checks the size of `data` and identifies its image format.
///
/// This looks only at the buffer length and the leading signature bytes; it
/// does not attempt to decode the image.
///
/// # Errors
///
/// Returns an error when `data` is longer than `MAX_FILE_SIZE` bytes, or when
/// [`detect_format`] does not recognise its leading bytes.
pub fn validate(data: &[u8]) -> Result<DetectedFormat> {
    // Size is checked first so oversized input is rejected before sniffing.
    if data.len() > MAX_FILE_SIZE {
        bail!("file too large (max 20MB)");
    }

    match detect_format(data) {
        Some(format) => Ok(format),
        None => bail!("unsupported image format (accepts JPEG, PNG, WebP, GIF)"),
    }
}
/// Decodes `data` as the given `format`, shrinks it if it is too large, and
/// returns the image re-encoded as PNG bytes.
///
/// The output is always PNG, whatever the input format was. When the decoded
/// width or height exceeds `MAX_DIMENSION`, the image is resized with the
/// Lanczos3 filter, passing `MAX_DIMENSION` as both the width and height bound.
///
/// # Errors
///
/// Returns an error when `data` cannot be decoded as `format`, or when PNG
/// encoding fails.
pub fn process(data: &[u8], format: DetectedFormat) -> Result<Vec<u8>> {
    use image::imageops::FilterType;

    let decoder_format = match format {
        DetectedFormat::Jpeg => ImageFormat::Jpeg,
        DetectedFormat::Png => ImageFormat::Png,
        DetectedFormat::Gif => ImageFormat::Gif,
        DetectedFormat::WebP => ImageFormat::WebP,
    };
    let decoded = image::load_from_memory_with_format(data, decoder_format)
        .context("failed to process image")?;

    // Either side over the limit means the longest side is over the limit.
    let longest_side = decoded.width().max(decoded.height());
    let fitted = if longest_side > MAX_DIMENSION {
        decoded.resize(MAX_DIMENSION, MAX_DIMENSION, FilterType::Lanczos3)
    } else {
        decoded
    };

    // The encoder needs a seekable writer, so encode into an owned cursor and
    // take the buffer back out afterwards.
    let mut encoded = std::io::Cursor::new(Vec::new());
    fitted
        .write_to(&mut encoded, ImageFormat::Png)
        .context("failed to encode as PNG")?;
    Ok(encoded.into_inner())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds PNG bytes for a solid red, fully opaque RGBA image of `w` x `h`.
    fn create_test_png(w: u32, h: u32) -> Vec<u8> {
        let red = image::Rgba([255u8, 0, 0, 255]);
        let mut encoded = std::io::Cursor::new(Vec::new());
        image::RgbaImage::from_pixel(w, h, red)
            .write_to(&mut encoded, ImageFormat::Png)
            .unwrap();
        encoded.into_inner()
    }

    // --- detect_format -----------------------------------------------------

    #[test]
    fn detect_png() {
        assert_eq!(
            detect_format(b"\x89PNG\r\n\x1a\n extra"),
            Some(DetectedFormat::Png)
        );
    }

    #[test]
    fn detect_jpeg() {
        assert_eq!(
            detect_format(b"\xff\xd8\xff\xe0 extra"),
            Some(DetectedFormat::Jpeg)
        );
    }

    #[test]
    fn detect_gif() {
        assert_eq!(detect_format(b"GIF89a extra"), Some(DetectedFormat::Gif));
    }

    #[test]
    fn detect_webp() {
        assert_eq!(
            detect_format(b"RIFF\x00\x00\x00\x00WEBP extra"),
            Some(DetectedFormat::WebP)
        );
    }

    #[test]
    fn detect_unknown() {
        assert_eq!(detect_format(b"hello world"), None);
    }

    // --- validate ----------------------------------------------------------

    #[test]
    fn validate_rejects_oversize() {
        // One byte over the limit is enough to trip the size check.
        let oversized = vec![0u8; MAX_FILE_SIZE + 1];
        assert!(validate(&oversized).is_err());
    }

    #[test]
    fn validate_rejects_non_image() {
        assert!(validate(b"not an image at all").is_err());
    }

    // --- process -----------------------------------------------------------

    #[test]
    fn process_valid_png() {
        let input = create_test_png(1, 1);
        let kind = detect_format(&input).unwrap();
        let output = process(&input, kind).unwrap();
        assert!(output.starts_with(b"\x89PNG"));
    }

    #[test]
    fn process_resizes_large_image() {
        // Width is twice the limit; height sits exactly on it.
        let input = create_test_png(4096, 2048);
        let kind = detect_format(&input).unwrap();
        let output = process(&input, kind).unwrap();

        let decoded = image::load_from_memory(&output).unwrap();
        let (width, height) = (decoded.width(), decoded.height());
        assert!(width <= MAX_DIMENSION);
        assert!(height <= MAX_DIMENSION);
    }
}