use img_parts::{webp::WebP, Bytes, ImageEXIF, ImageICC};
/// Container formats that the signature sniffer in this file can tell apart.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ImageFormat {
    /// Input begins with the JPEG start-of-image bytes.
    Jpeg,
    /// Input begins with the eight-byte PNG signature.
    Png,
    /// Input is a RIFF container tagged `WEBP`.
    WebP,
    /// None of the known signatures matched.
    Unknown,
}
/// Guesses the image container from the leading bytes of `data`.
///
/// Only the signature is inspected; nothing past it is validated:
/// - JPEG: the first two bytes are `FF D8`.
/// - PNG: the first eight bytes are `89 50 4E 47 0D 0A 1A 0A`.
/// - WebP: bytes `0..4` are `RIFF` and bytes `8..12` are `WEBP`
///   (bytes `4..8` are ignored).
///
/// Anything else, including input too short to hold a signature, yields
/// `ImageFormat::Unknown`.
fn detect_format(data: &[u8]) -> ImageFormat {
    match data {
        [0xFF, 0xD8, ..] => ImageFormat::Jpeg,
        [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, ..] => ImageFormat::Png,
        // Four don't-care bytes sit between the two tags.
        [b'R', b'I', b'F', b'F', _, _, _, _, b'W', b'E', b'B', b'P', ..] => ImageFormat::WebP,
        _ => ImageFormat::Unknown,
    }
}
/// Removes embedded metadata from a JPEG, PNG or WebP image.
///
/// The container is chosen by `detect_format`, then re-encoded without:
/// - **JPEG** — the EXIF block, the ICC profile, and every segment whose
///   marker `is_stripped_jpeg_marker` flags.
/// - **PNG** — the `tEXt`, `zTXt`, `iTXt` and `eXIf` chunks.
/// - **WebP** — the EXIF block, the ICC profile, and XMP chunks.
///
/// # Returns
/// - `None` when the signature is not recognised.
/// - `Some(bytes)` otherwise. If the signature matched but the container
///   could not be parsed, `bytes` is an unmodified copy of `data`.
pub fn strip_metadata(data: &[u8]) -> Option<Vec<u8>> {
    /// PNG chunk types removed from the output.
    const PNG_METADATA_CHUNKS: [&[u8; 4]; 4] = [b"tEXt", b"zTXt", b"iTXt", b"eXIf"];

    // Each arm yields `Some(cleaned)` on success or `None` when parsing failed;
    // the passthrough fallback is applied once after the match.
    let cleaned: Option<Vec<u8>> = match detect_format(data) {
        ImageFormat::Unknown => return None,
        ImageFormat::Jpeg => img_parts::jpeg::Jpeg::from_bytes(Bytes::copy_from_slice(data))
            .ok()
            .map(|mut jpeg| {
                jpeg.set_exif(None);
                jpeg.set_icc_profile(None);
                jpeg.segments_mut()
                    .retain(|segment| !is_stripped_jpeg_marker(segment.marker()));
                jpeg.encoder().bytes().to_vec()
            }),
        ImageFormat::Png => img_parts::png::Png::from_bytes(Bytes::copy_from_slice(data))
            .ok()
            .map(|mut png| {
                png.chunks_mut().retain(|chunk| {
                    let kind = chunk.kind();
                    PNG_METADATA_CHUNKS.iter().all(|name| kind != **name)
                });
                png.encoder().bytes().to_vec()
            }),
        ImageFormat::WebP => WebP::from_bytes(Bytes::copy_from_slice(data))
            .ok()
            .map(|mut webp| {
                webp.set_exif(None);
                webp.set_icc_profile(None);
                webp.chunks_mut()
                    .retain(|chunk| chunk.id() != img_parts::webp::CHUNK_XMP);
                webp.encoder().bytes().to_vec()
            }),
    };

    // Signature matched but the parser rejected the payload: hand the input back untouched.
    Some(cleaned.unwrap_or_else(|| data.to_vec()))
}
/// Reports whether a JPEG segment with the given marker byte should be
/// discarded by the segment filter in `strip_metadata`.
///
/// Markers treated as metadata:
/// - `0xE1` (APP1) — EXIF / XMP
/// - `0xE2` (APP2) — ICC profile
/// - `0xED` (APP13) — Photoshop / IPTC
/// - `0xEE` (APP14) — Adobe
/// - `0xFE` (COM) — free-text comment
///
/// Every other marker (for example `0xE0` JFIF or `0xDA` start-of-scan)
/// returns `false`.
///
/// NOTE(review): decoders may consult the APP14 Adobe segment's colour
/// transform flag, so dropping it could change how CMYK/YCCK files render —
/// confirm this is intended.
fn is_stripped_jpeg_marker(marker: u8) -> bool {
    const APP1: u8 = 0xE1;
    const APP2: u8 = 0xE2;
    const APP13: u8 = 0xED;
    const APP14: u8 = 0xEE;
    // Comment segments carry arbitrary text (authoring tool, user notes), so
    // they are stripped just like the textual chunks are for PNG.
    const COM: u8 = 0xFE;

    matches!(marker, APP1 | APP2 | APP13 | APP14 | COM)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight-byte PNG file signature shared by the PNG tests.
    const PNG_MAGIC: [u8; 8] = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];

    /// Zero-filled buffer of `len` bytes carrying only the `RIFF` / `WEBP`
    /// tags — enough to pass signature sniffing, nothing more.
    fn riff_webp_stub(len: usize) -> Vec<u8> {
        let mut bytes = vec![0u8; len];
        bytes[..4].copy_from_slice(b"RIFF");
        bytes[8..12].copy_from_slice(b"WEBP");
        bytes
    }

    // --- signature detection -------------------------------------------------

    #[test]
    fn detect_format_jpeg() {
        let bytes = [0xFF, 0xD8, 0x00];
        assert_eq!(detect_format(&bytes), ImageFormat::Jpeg);
    }

    #[test]
    fn detect_format_png() {
        let mut bytes = PNG_MAGIC.to_vec();
        bytes.push(0x00);
        assert_eq!(detect_format(&bytes), ImageFormat::Png);
    }

    #[test]
    fn detect_format_webp() {
        let header = riff_webp_stub(12);
        assert_eq!(detect_format(&header), ImageFormat::WebP);
    }

    #[test]
    fn detect_format_unknown() {
        assert_eq!(detect_format(b"not an image"), ImageFormat::Unknown);
        assert_eq!(detect_format(&[]), ImageFormat::Unknown);
    }

    // --- unrecognised input --------------------------------------------------

    #[test]
    fn strip_unknown_bytes_returns_none() {
        assert!(strip_metadata(b"not an image").is_none());
    }

    #[test]
    fn strip_empty_slice_returns_none() {
        let empty: &[u8] = &[];
        assert!(strip_metadata(empty).is_none());
    }

    // --- JPEG marker filter --------------------------------------------------

    #[test]
    fn marker_classification() {
        assert!(is_stripped_jpeg_marker(0xE1));
        assert!(is_stripped_jpeg_marker(0xE2));
        assert!(is_stripped_jpeg_marker(0xED));
        assert!(is_stripped_jpeg_marker(0xEE));

        assert!(!is_stripped_jpeg_marker(0xE0));
        assert!(!is_stripped_jpeg_marker(0xDA));
    }

    // --- round trip on a real image -------------------------------------------

    #[test]
    fn strip_webp_does_not_corrupt_image() {
        // Encode a small solid-colour image so there is a genuine WebP to strip.
        let pixels = image::RgbaImage::from_pixel(8, 8, image::Rgba([100u8, 150, 200, 255]));
        let mut encoded = Vec::new();
        let mut sink = std::io::Cursor::new(&mut encoded);
        pixels
            .write_to(&mut sink, image::ImageFormat::WebP)
            .expect("test WebP encode");

        let stripped = strip_metadata(&encoded).expect("WebP stripping should succeed");

        let reparsed = WebP::from_bytes(Bytes::copy_from_slice(&stripped));
        assert!(reparsed.is_ok(), "stripped WebP is not parseable");

        let decoded = image::load_from_memory(&stripped).expect("stripped WebP should decode");
        assert_eq!(decoded.width(), 8);
        assert_eq!(decoded.height(), 8);
    }

    // --- detected-but-unparseable input is passed through ---------------------

    #[test]
    fn malformed_jpeg_returns_some_passthrough() {
        let data = [0xFF, 0xD8, 0xDE, 0xAD, 0xBE, 0xEF, 0x00];
        let passthrough = strip_metadata(&data)
            .expect("malformed-but-detected JPEG should return Some, not None");
        assert_eq!(passthrough, data, "passthrough should be identical to input");
    }

    #[test]
    fn malformed_png_returns_some_passthrough() {
        let data = [&PNG_MAGIC[..], &b"GARBAGE"[..]].concat();
        let passthrough = strip_metadata(&data)
            .expect("malformed-but-detected PNG should return Some, not None");
        assert_eq!(passthrough, data, "passthrough should be identical to input");
    }

    #[test]
    fn malformed_webp_returns_some_passthrough() {
        let data = riff_webp_stub(20);
        let passthrough = strip_metadata(&data)
            .expect("malformed-but-detected WebP should return Some, not None");
        assert_eq!(passthrough, data, "passthrough should be identical to input");
    }
}