/// The eight bytes every PNG file begins with (`\x89PNG\r\n\x1a\n`).
const PNG_SIGNATURE: [u8; 8] = *b"\x89PNG\r\n\x1a\n";
/// Smallest length able to hold the signature plus a complete IHDR chunk:
/// signature (8) + chunk length (4) + chunk type (4) + IHDR data (13) + CRC (4).
const MIN_PNG_LEN: usize = 8 + 4 + 4 + 13 + 4;
/// Mask image held as encoded PNG bytes.
///
/// The bytes are stored exactly as supplied (or as produced by the encoder); header
/// fields are read on demand and pixel data is only decompressed by `decode`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MaskData {
    /// Complete PNG file contents.
    png: Vec<u8>,
}
impl MaskData {
pub fn from_png(png: Vec<u8>) -> Self {
Self { png }
}
pub fn from_png_checked(png: Vec<u8>) -> Result<Self, crate::Error> {
if png.len() < MIN_PNG_LEN {
return Err(crate::Error::InvalidParameters(format!(
"PNG data too short: {} bytes, minimum {} required",
png.len(),
MIN_PNG_LEN
)));
}
if png[..8] != PNG_SIGNATURE {
return Err(crate::Error::InvalidParameters(
"invalid PNG signature: not a PNG file".to_string(),
));
}
let color_type = png[25];
if color_type != 0 {
return Err(crate::Error::InvalidParameters(format!(
"PNG color type must be 0 (grayscale), got {}",
color_type
)));
}
let decoder = png::Decoder::new(std::io::Cursor::new(&png));
if decoder.read_info().is_err() {
return Err(crate::Error::InvalidParameters(
"PNG data is malformed or truncated".to_string(),
));
}
Ok(Self { png })
}
pub fn is_valid(&self) -> bool {
self.png.len() >= MIN_PNG_LEN && self.png[..8] == PNG_SIGNATURE
}
pub fn as_bytes(&self) -> &[u8] {
&self.png
}
pub fn into_bytes(self) -> Vec<u8> {
self.png
}
pub fn width(&self) -> u32 {
self.png
.get(16..20)
.and_then(|b| b.try_into().ok())
.map(u32::from_be_bytes)
.unwrap_or(0)
}
pub fn height(&self) -> u32 {
self.png
.get(20..24)
.and_then(|b| b.try_into().ok())
.map(u32::from_be_bytes)
.unwrap_or(0)
}
pub fn bit_depth(&self) -> u8 {
self.png.get(24).copied().unwrap_or(0)
}
pub fn encode(
pixels: &[u8],
width: u32,
height: u32,
bit_depth: u8,
) -> Result<Self, crate::Error> {
if bit_depth != 1 && bit_depth != 8 {
return Err(crate::Error::InvalidParameters(format!(
"bit_depth must be 1 or 8, got {}",
bit_depth
)));
}
let expected = (width as usize) * (height as usize);
if pixels.len() != expected {
return Err(crate::Error::InvalidParameters(format!(
"pixel count mismatch: expected {}, got {}",
expected,
pixels.len()
)));
}
let mut buf = Vec::new();
{
let mut encoder = png::Encoder::new(&mut buf, width, height);
encoder.set_color(png::ColorType::Grayscale);
encoder.set_depth(match bit_depth {
1 => png::BitDepth::One,
8 => png::BitDepth::Eight,
_ => unreachable!(),
});
let mut writer = encoder.write_header().map_err(|e| {
crate::Error::InvalidParameters(format!("PNG header write failed: {}", e))
})?;
match bit_depth {
1 => {
let bytes_per_row = (width as usize).div_ceil(8);
let mut packed = vec![0u8; bytes_per_row * height as usize];
for y in 0..height as usize {
for x in 0..width as usize {
if pixels[y * width as usize + x] != 0 {
packed[y * bytes_per_row + x / 8] |= 0x80 >> (x % 8);
}
}
}
writer.write_image_data(&packed).map_err(|e| {
crate::Error::InvalidParameters(format!(
"PNG image data write failed: {}",
e
))
})?;
}
8 => {
writer.write_image_data(pixels).map_err(|e| {
crate::Error::InvalidParameters(format!(
"PNG image data write failed: {}",
e
))
})?;
}
_ => unreachable!(),
}
}
Ok(Self { png: buf })
}
pub fn encode_16bit(pixels: &[u16], width: u32, height: u32) -> Result<Self, crate::Error> {
let expected = (width as usize) * (height as usize);
if pixels.len() != expected {
return Err(crate::Error::InvalidParameters(format!(
"pixel count mismatch: expected {}, got {}",
expected,
pixels.len()
)));
}
let mut buf = Vec::new();
{
let mut encoder = png::Encoder::new(&mut buf, width, height);
encoder.set_color(png::ColorType::Grayscale);
encoder.set_depth(png::BitDepth::Sixteen);
let mut writer = encoder.write_header().map_err(|e| {
crate::Error::InvalidParameters(format!("PNG header write failed: {}", e))
})?;
let raw: Vec<u8> = pixels.iter().flat_map(|&v| v.to_be_bytes()).collect();
writer.write_image_data(&raw).map_err(|e| {
crate::Error::InvalidParameters(format!("PNG image data write failed: {}", e))
})?;
}
Ok(Self { png: buf })
}
pub fn decode(&self) -> Result<Vec<u8>, crate::Error> {
let decoder = png::Decoder::new(std::io::Cursor::new(self.png.as_slice()));
let mut reader = decoder
.read_info()
.map_err(|e| crate::Error::InvalidParameters(format!("PNG info read failed: {}", e)))?;
let info = reader.info();
let total_pixels = info.width as u64 * info.height as u64;
const MAX_PIXELS: u64 = 100_000_000; if total_pixels > MAX_PIXELS {
return Err(crate::Error::InvalidParameters(format!(
"PNG dimensions {}x{} exceed maximum of {} pixels",
info.width, info.height, MAX_PIXELS
)));
}
let buffer_size = reader.output_buffer_size().ok_or_else(|| {
crate::Error::InvalidParameters("PNG output buffer size unavailable".to_string())
})?;
let mut raw = vec![0u8; buffer_size];
let info = reader.next_frame(&mut raw).map_err(|e| {
crate::Error::InvalidParameters(format!("PNG frame read failed: {}", e))
})?;
raw.truncate(info.buffer_size());
if info.bit_depth == png::BitDepth::One {
let width = info.width as usize;
let height = info.height as usize;
let bytes_per_row = width.div_ceil(8);
let mut unpacked = Vec::with_capacity(width * height);
for y in 0..height {
for x in 0..width {
let byte = raw[y * bytes_per_row + x / 8];
let bit = (byte >> (7 - (x % 8))) & 1;
unpacked.push(bit);
}
}
Ok(unpacked)
} else {
Ok(raw)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a well-formed 1x1 8-bit RGB PNG, so that color-type rejection is tested on
    /// data whose header is otherwise valid.
    fn rgb_png() -> Vec<u8> {
        let mut buf = Vec::new();
        {
            let mut encoder = png::Encoder::new(&mut buf, 1, 1);
            encoder.set_color(png::ColorType::Rgb);
            encoder.set_depth(png::BitDepth::Eight);
            let mut writer = encoder.write_header().unwrap();
            writer.write_image_data(&[255, 0, 0]).unwrap();
        }
        buf
    }

    #[test]
    fn test_encode_decode_8bit() {
        let pixels: Vec<u8> = vec![0, 64, 128, 192, 255, 1, 100, 200, 50];
        let mask = MaskData::encode(&pixels, 3, 3, 8).unwrap();
        assert_eq!(mask.width(), 3);
        assert_eq!(mask.height(), 3);
        assert_eq!(mask.bit_depth(), 8);
        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_encode_decode_1bit() {
        let pixels: Vec<u8> = vec![
            1, 0, 1, 0, 1, 0, 1, 0, //
            0, 1, 0, 1, 0, 1, 0, 1,
        ];
        let mask = MaskData::encode(&pixels, 8, 2, 1).unwrap();
        assert_eq!(mask.width(), 8);
        assert_eq!(mask.height(), 2);
        assert_eq!(mask.bit_depth(), 1);
        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_encode_1bit_treats_nonzero_as_set() {
        let pixels: Vec<u8> = vec![0, 255, 7, 0];
        let mask = MaskData::encode(&pixels, 2, 2, 1).unwrap();
        assert_eq!(mask.decode().unwrap(), vec![0, 1, 1, 0]);
    }

    #[test]
    fn test_encode_decode_16bit() {
        let pixels: Vec<u16> = vec![0, 256, 65535, 1024];
        let mask = MaskData::encode_16bit(&pixels, 2, 2).unwrap();
        assert_eq!(mask.width(), 2);
        assert_eq!(mask.height(), 2);
        assert_eq!(mask.bit_depth(), 16);
        let decoded = mask.decode().unwrap();
        let expected: Vec<u8> = pixels.iter().flat_map(|&v| v.to_be_bytes()).collect();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_header_read_without_decode() {
        let width = 640u32;
        let height = 480u32;
        let pixels = vec![0u8; (width * height) as usize];
        let mask = MaskData::encode(&pixels, width, height, 8).unwrap();
        assert_eq!(mask.width(), width);
        assert_eq!(mask.height(), height);
        assert_eq!(mask.bit_depth(), 8);
        let raw_size = (width * height) as usize;
        assert!(
            mask.as_bytes().len() < raw_size,
            "PNG ({} bytes) should be smaller than raw ({} bytes)",
            mask.as_bytes().len(),
            raw_size,
        );
    }

    #[test]
    fn test_from_png_bytes() {
        let pixels: Vec<u8> = vec![10, 20, 30, 40, 50, 60];
        let original = MaskData::encode(&pixels, 3, 2, 8).unwrap();
        let bytes = original.into_bytes();
        let reconstructed = MaskData::from_png(bytes);
        assert_eq!(reconstructed.width(), 3);
        assert_eq!(reconstructed.height(), 2);
        assert_eq!(reconstructed.bit_depth(), 8);
        assert_eq!(reconstructed.decode().unwrap(), pixels);
    }

    #[test]
    fn test_1bit_non_aligned_width() {
        let pixels: Vec<u8> = vec![
            1, 0, 1, 1, 0, //
            0, 1, 0, 0, 1, //
            1, 1, 1, 0, 0,
        ];
        let mask = MaskData::encode(&pixels, 5, 3, 1).unwrap();
        assert_eq!(mask.width(), 5);
        assert_eq!(mask.height(), 3);
        assert_eq!(mask.bit_depth(), 1);
        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_from_png_empty_bytes() {
        let result = MaskData::from_png_checked(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_truncated() {
        let result = MaskData::from_png_checked(PNG_SIGNATURE.to_vec());
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_garbage() {
        let result = MaskData::from_png_checked(vec![0u8; 64]);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_wrong_color_type() {
        // A genuine RGB PNG: the header is valid, so only the color-type check can reject it.
        let result = MaskData::from_png_checked(rgb_png());
        assert!(matches!(
            result,
            Err(crate::Error::InvalidParameters(ref msg)) if msg.contains("color type")
        ));
    }

    #[test]
    fn test_from_png_corrupt_ihdr() {
        // Signature followed by zeros: passes the length check, but is not a valid IHDR.
        let mut fake_png = vec![0u8; MIN_PNG_LEN];
        fake_png[..8].copy_from_slice(&PNG_SIGNATURE);
        fake_png[25] = 2;
        let result = MaskData::from_png_checked(fake_png);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_checked_valid() {
        let pixels: Vec<u8> = vec![0, 128, 255, 64];
        let mask = MaskData::encode(&pixels, 2, 2, 8).unwrap();
        let bytes = mask.into_bytes();
        let result = MaskData::from_png_checked(bytes);
        assert!(result.is_ok());
    }

    #[test]
    fn test_from_png_checked_valid_1bit() {
        let mask = MaskData::encode(&[1, 0, 0, 1], 2, 2, 1).unwrap();
        assert!(MaskData::from_png_checked(mask.into_bytes()).is_ok());
    }

    #[test]
    fn test_is_valid() {
        let pixels: Vec<u8> = vec![0, 128, 255, 64];
        let mask = MaskData::encode(&pixels, 2, 2, 8).unwrap();
        assert!(mask.is_valid());
        let invalid = MaskData::from_png(vec![1, 2, 3]);
        assert!(!invalid.is_valid());
    }

    #[test]
    fn test_width_height_bit_depth_short_data() {
        let mask = MaskData::from_png(vec![]);
        assert_eq!(mask.width(), 0);
        assert_eq!(mask.height(), 0);
        assert_eq!(mask.bit_depth(), 0);
        let mask2 = MaskData::from_png(vec![0; 10]);
        assert_eq!(mask2.width(), 0);
        assert_eq!(mask2.height(), 0);
        assert_eq!(mask2.bit_depth(), 0);
    }

    #[test]
    fn test_decode_invalid_data_returns_error() {
        let mask = MaskData::from_png(vec![1, 2, 3]);
        assert!(mask.decode().is_err());
    }

    #[test]
    fn test_encode_invalid_bit_depth() {
        let result = MaskData::encode(&[0; 4], 2, 2, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_pixel_count_mismatch() {
        let result = MaskData::encode(&[0; 3], 2, 2, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_16bit_pixel_count_mismatch() {
        let result = MaskData::encode_16bit(&[0; 3], 2, 2);
        assert!(result.is_err());
    }
}