/// The fixed 8-byte signature that opens every PNG file (`\x89PNG\r\n\x1a\n`).
const PNG_SIGNATURE: [u8; 8] = *b"\x89PNG\r\n\x1a\n";

/// Smallest byte length able to hold the signature followed by a complete IHDR
/// chunk: 8 (signature) + 4 (chunk length) + 4 (chunk type) + 13 (IHDR data) + 4 (CRC).
const MIN_PNG_LEN: usize = 8 + 4 + 4 + 13 + 4;
17
/// A mask image held in PNG-encoded form.
///
/// The bytes stay encoded: header fields (width, height, bit depth) can be read
/// straight from the buffer, and pixel data is produced on demand by decoding.
/// Two masks compare equal when their encoded bytes are identical.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MaskData {
    /// Raw PNG file bytes. Not validated when constructed via `MaskData::from_png`.
    png: Vec<u8>,
}
40
41impl MaskData {
42 pub fn from_png(png: Vec<u8>) -> Self {
47 Self { png }
48 }
49
50 pub fn from_png_checked(png: Vec<u8>) -> Result<Self, crate::Error> {
61 if png.len() < MIN_PNG_LEN {
62 return Err(crate::Error::InvalidParameters(format!(
63 "PNG data too short: {} bytes, minimum {} required",
64 png.len(),
65 MIN_PNG_LEN
66 )));
67 }
68 if png[..8] != PNG_SIGNATURE {
69 return Err(crate::Error::InvalidParameters(
70 "invalid PNG signature: not a PNG file".to_string(),
71 ));
72 }
73 let color_type = png[25];
74 if color_type != 0 {
75 return Err(crate::Error::InvalidParameters(format!(
76 "PNG color type must be 0 (grayscale), got {}",
77 color_type
78 )));
79 }
80
81 let decoder = png::Decoder::new(std::io::Cursor::new(&png));
83 if decoder.read_info().is_err() {
84 return Err(crate::Error::InvalidParameters(
85 "PNG data is malformed or truncated".to_string(),
86 ));
87 }
88
89 Ok(Self { png })
90 }
91
92 pub fn is_valid(&self) -> bool {
95 self.png.len() >= MIN_PNG_LEN && self.png[..8] == PNG_SIGNATURE
96 }
97
98 pub fn as_bytes(&self) -> &[u8] {
100 &self.png
101 }
102
103 pub fn into_bytes(self) -> Vec<u8> {
105 self.png
106 }
107
108 pub fn width(&self) -> u32 {
112 self.png
113 .get(16..20)
114 .and_then(|b| b.try_into().ok())
115 .map(u32::from_be_bytes)
116 .unwrap_or(0)
117 }
118
119 pub fn height(&self) -> u32 {
123 self.png
124 .get(20..24)
125 .and_then(|b| b.try_into().ok())
126 .map(u32::from_be_bytes)
127 .unwrap_or(0)
128 }
129
130 pub fn bit_depth(&self) -> u8 {
134 self.png.get(24).copied().unwrap_or(0)
135 }
136
137 pub fn encode(
150 pixels: &[u8],
151 width: u32,
152 height: u32,
153 bit_depth: u8,
154 ) -> Result<Self, crate::Error> {
155 if bit_depth != 1 && bit_depth != 8 {
156 return Err(crate::Error::InvalidParameters(format!(
157 "bit_depth must be 1 or 8, got {}",
158 bit_depth
159 )));
160 }
161 let expected = (width as usize) * (height as usize);
162 if pixels.len() != expected {
163 return Err(crate::Error::InvalidParameters(format!(
164 "pixel count mismatch: expected {}, got {}",
165 expected,
166 pixels.len()
167 )));
168 }
169
170 let mut buf = Vec::new();
171 {
172 let mut encoder = png::Encoder::new(&mut buf, width, height);
173 encoder.set_color(png::ColorType::Grayscale);
174 encoder.set_depth(match bit_depth {
175 1 => png::BitDepth::One,
176 8 => png::BitDepth::Eight,
177 _ => unreachable!(),
178 });
179
180 let mut writer = encoder.write_header().map_err(|e| {
181 crate::Error::InvalidParameters(format!("PNG header write failed: {}", e))
182 })?;
183
184 match bit_depth {
185 1 => {
186 let bytes_per_row = (width as usize).div_ceil(8);
187 let mut packed = vec![0u8; bytes_per_row * height as usize];
188 for y in 0..height as usize {
189 for x in 0..width as usize {
190 if pixels[y * width as usize + x] != 0 {
191 packed[y * bytes_per_row + x / 8] |= 0x80 >> (x % 8);
192 }
193 }
194 }
195 writer.write_image_data(&packed).map_err(|e| {
196 crate::Error::InvalidParameters(format!(
197 "PNG image data write failed: {}",
198 e
199 ))
200 })?;
201 }
202 8 => {
203 writer.write_image_data(pixels).map_err(|e| {
204 crate::Error::InvalidParameters(format!(
205 "PNG image data write failed: {}",
206 e
207 ))
208 })?;
209 }
210 _ => unreachable!(),
211 }
212 }
213 Ok(Self { png: buf })
214 }
215
216 pub fn encode_16bit(pixels: &[u16], width: u32, height: u32) -> Result<Self, crate::Error> {
225 let expected = (width as usize) * (height as usize);
226 if pixels.len() != expected {
227 return Err(crate::Error::InvalidParameters(format!(
228 "pixel count mismatch: expected {}, got {}",
229 expected,
230 pixels.len()
231 )));
232 }
233
234 let mut buf = Vec::new();
235 {
236 let mut encoder = png::Encoder::new(&mut buf, width, height);
237 encoder.set_color(png::ColorType::Grayscale);
238 encoder.set_depth(png::BitDepth::Sixteen);
239
240 let mut writer = encoder.write_header().map_err(|e| {
241 crate::Error::InvalidParameters(format!("PNG header write failed: {}", e))
242 })?;
243
244 let raw: Vec<u8> = pixels.iter().flat_map(|&v| v.to_be_bytes()).collect();
245 writer.write_image_data(&raw).map_err(|e| {
246 crate::Error::InvalidParameters(format!("PNG image data write failed: {}", e))
247 })?;
248 }
249 Ok(Self { png: buf })
250 }
251
252 pub fn decode(&self) -> Result<Vec<u8>, crate::Error> {
262 let decoder = png::Decoder::new(std::io::Cursor::new(self.png.as_slice()));
263 let mut reader = decoder
264 .read_info()
265 .map_err(|e| crate::Error::InvalidParameters(format!("PNG info read failed: {}", e)))?;
266
267 let info = reader.info();
269 let total_pixels = info.width as u64 * info.height as u64;
270 const MAX_PIXELS: u64 = 100_000_000; if total_pixels > MAX_PIXELS {
272 return Err(crate::Error::InvalidParameters(format!(
273 "PNG dimensions {}x{} exceed maximum of {} pixels",
274 info.width, info.height, MAX_PIXELS
275 )));
276 }
277
278 let buffer_size = reader.output_buffer_size().ok_or_else(|| {
279 crate::Error::InvalidParameters("PNG output buffer size unavailable".to_string())
280 })?;
281 let mut raw = vec![0u8; buffer_size];
282 let info = reader.next_frame(&mut raw).map_err(|e| {
283 crate::Error::InvalidParameters(format!("PNG frame read failed: {}", e))
284 })?;
285 raw.truncate(info.buffer_size());
286
287 if info.bit_depth == png::BitDepth::One {
288 let width = info.width as usize;
289 let height = info.height as usize;
290 let bytes_per_row = width.div_ceil(8);
291 let mut unpacked = Vec::with_capacity(width * height);
292 for y in 0..height {
293 for x in 0..width {
294 let byte = raw[y * bytes_per_row + x / 8];
295 let bit = (byte >> (7 - (x % 8))) & 1;
296 unpacked.push(bit);
297 }
298 }
299 Ok(unpacked)
300 } else {
301 Ok(raw)
302 }
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_decode_8bit() {
        let pixels: Vec<u8> = vec![0, 64, 128, 192, 255, 1, 100, 200, 50];
        let mask = MaskData::encode(&pixels, 3, 3, 8).unwrap();

        assert_eq!(mask.width(), 3);
        assert_eq!(mask.height(), 3);
        assert_eq!(mask.bit_depth(), 8);

        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_encode_decode_1bit() {
        let pixels: Vec<u8> = vec![
            1, 0, 1, 0, 1, 0, 1, 0, //
            0, 1, 0, 1, 0, 1, 0, 1,
        ];
        let mask = MaskData::encode(&pixels, 8, 2, 1).unwrap();

        assert_eq!(mask.width(), 8);
        assert_eq!(mask.height(), 2);
        assert_eq!(mask.bit_depth(), 1);

        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_encode_decode_16bit() {
        let pixels: Vec<u16> = vec![0, 256, 65535, 1024];
        let mask = MaskData::encode_16bit(&pixels, 2, 2).unwrap();

        assert_eq!(mask.width(), 2);
        assert_eq!(mask.height(), 2);
        assert_eq!(mask.bit_depth(), 16);

        // 16-bit samples come back as big-endian byte pairs.
        let decoded = mask.decode().unwrap();
        let expected: Vec<u8> = pixels.iter().flat_map(|&v| v.to_be_bytes()).collect();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_header_read_without_decode() {
        let width = 640u32;
        let height = 480u32;
        let pixels = vec![0u8; (width * height) as usize];
        let mask = MaskData::encode(&pixels, width, height, 8).unwrap();

        assert_eq!(mask.width(), width);
        assert_eq!(mask.height(), height);
        assert_eq!(mask.bit_depth(), 8);

        // A uniform image should compress well below its raw size.
        let raw_size = (width * height) as usize;
        assert!(
            mask.as_bytes().len() < raw_size,
            "PNG ({} bytes) should be smaller than raw ({} bytes)",
            mask.as_bytes().len(),
            raw_size,
        );
    }

    #[test]
    fn test_from_png_bytes() {
        let pixels: Vec<u8> = vec![10, 20, 30, 40, 50, 60];
        let original = MaskData::encode(&pixels, 3, 2, 8).unwrap();

        let bytes = original.into_bytes();
        let reconstructed = MaskData::from_png(bytes);

        assert_eq!(reconstructed.width(), 3);
        assert_eq!(reconstructed.height(), 2);
        assert_eq!(reconstructed.bit_depth(), 8);
        assert_eq!(reconstructed.decode().unwrap(), pixels);
    }

    #[test]
    fn test_1bit_non_aligned_width() {
        let pixels: Vec<u8> = vec![
            1, 0, 1, 1, 0, //
            0, 1, 0, 0, 1, //
            1, 1, 1, 0, 0,
        ];
        let mask = MaskData::encode(&pixels, 5, 3, 1).unwrap();

        assert_eq!(mask.width(), 5);
        assert_eq!(mask.height(), 3);
        assert_eq!(mask.bit_depth(), 1);

        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_1bit_nonzero_values_become_one() {
        let pixels: Vec<u8> = vec![0, 5, 255, 0];
        let mask = MaskData::encode(&pixels, 2, 2, 1).unwrap();
        assert_eq!(mask.decode().unwrap(), vec![0, 1, 1, 0]);
    }

    #[test]
    fn test_from_png_empty_bytes() {
        let result = MaskData::from_png_checked(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_truncated() {
        let result = MaskData::from_png_checked(PNG_SIGNATURE.to_vec());
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_truncated_after_ihdr() {
        let pixels: Vec<u8> = vec![0, 128, 255, 64];
        let mut bytes = MaskData::encode(&pixels, 2, 2, 8).unwrap().into_bytes();
        // Keep only the signature and IHDR chunk; image data is missing.
        bytes.truncate(MIN_PNG_LEN);
        assert!(MaskData::from_png_checked(bytes).is_err());
    }

    #[test]
    fn test_from_png_garbage() {
        let result = MaskData::from_png_checked(vec![0u8; 64]);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_wrong_color_type() {
        let mut fake_png = vec![0u8; MIN_PNG_LEN];
        fake_png[..8].copy_from_slice(&PNG_SIGNATURE);
        // A real IHDR tag makes byte 25 the color-type field under test.
        fake_png[12..16].copy_from_slice(b"IHDR");
        fake_png[25] = 2; // RGB
        let result = MaskData::from_png_checked(fake_png);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_checked_valid() {
        let pixels: Vec<u8> = vec![0, 128, 255, 64];
        let mask = MaskData::encode(&pixels, 2, 2, 8).unwrap();
        let bytes = mask.into_bytes();
        let result = MaskData::from_png_checked(bytes);
        assert!(result.is_ok());
    }

    #[test]
    fn test_from_png_checked_accepts_1bit() {
        let pixels: Vec<u8> = vec![1, 0, 0, 1];
        let bytes = MaskData::encode(&pixels, 2, 2, 1).unwrap().into_bytes();
        let mask = MaskData::from_png_checked(bytes).unwrap();
        assert_eq!(mask.decode().unwrap(), pixels);
    }

    #[test]
    fn test_is_valid() {
        let pixels: Vec<u8> = vec![0, 128, 255, 64];
        let mask = MaskData::encode(&pixels, 2, 2, 8).unwrap();
        assert!(mask.is_valid());

        let invalid = MaskData::from_png(vec![1, 2, 3]);
        assert!(!invalid.is_valid());
    }

    #[test]
    fn test_width_height_bit_depth_short_data() {
        let mask = MaskData::from_png(vec![]);
        assert_eq!(mask.width(), 0);
        assert_eq!(mask.height(), 0);
        assert_eq!(mask.bit_depth(), 0);

        let mask2 = MaskData::from_png(vec![0; 10]);
        assert_eq!(mask2.width(), 0);
        assert_eq!(mask2.height(), 0);
        assert_eq!(mask2.bit_depth(), 0);
    }

    #[test]
    fn test_decode_invalid_data_returns_error() {
        let mask = MaskData::from_png(vec![1, 2, 3]);
        assert!(mask.decode().is_err());
    }

    #[test]
    fn test_encode_invalid_bit_depth() {
        let result = MaskData::encode(&[0; 4], 2, 2, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_pixel_count_mismatch() {
        let result = MaskData::encode(&[0; 3], 2, 2, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_16bit_pixel_count_mismatch() {
        let result = MaskData::encode_16bit(&[0; 3], 2, 2);
        assert!(result.is_err());
    }
}