/// The fixed 8-byte magic sequence that opens every PNG file (`\x89PNG\r\n\x1a\n`).
const PNG_SIGNATURE: [u8; 8] = *b"\x89PNG\r\n\x1a\n";

/// Smallest byte count that can hold the signature plus a complete IHDR chunk:
/// signature (8) + chunk length (4) + chunk type (4) + IHDR payload (13) + CRC (4).
const MIN_PNG_LEN: usize = 8 + 4 + 4 + 13 + 4;
/// A mask image held in its encoded PNG form.
///
/// The bytes are stored exactly as given or as produced by the encoders; header fields
/// are read straight out of the buffer, and pixels are only materialised on `decode`.
#[derive(Debug, Clone)]
pub struct MaskData {
    /// Complete PNG file contents.
    png: Vec<u8>,
}

41impl MaskData {
42 pub fn from_png(png: Vec<u8>) -> Self {
47 Self { png }
48 }
49
50 pub fn from_png_checked(png: Vec<u8>) -> Result<Self, crate::Error> {
61 if png.len() < MIN_PNG_LEN {
62 return Err(crate::Error::InvalidParameters(format!(
63 "PNG data too short: {} bytes, minimum {} required",
64 png.len(),
65 MIN_PNG_LEN
66 )));
67 }
68 if png[..8] != PNG_SIGNATURE {
69 return Err(crate::Error::InvalidParameters(
70 "invalid PNG signature: not a PNG file".to_string(),
71 ));
72 }
73 let color_type = png[25];
74 if color_type != 0 {
75 return Err(crate::Error::InvalidParameters(format!(
76 "PNG color type must be 0 (grayscale), got {}",
77 color_type
78 )));
79 }
80
81 let decoder = png::Decoder::new(std::io::Cursor::new(&png));
83 if decoder.read_info().is_err() {
84 return Err(crate::Error::InvalidParameters(
85 "PNG data is malformed or truncated".to_string(),
86 ));
87 }
88
89 Ok(Self { png })
90 }
91
92 pub fn is_valid(&self) -> bool {
95 self.png.len() >= MIN_PNG_LEN && self.png[..8] == PNG_SIGNATURE
96 }
97
98 pub fn as_bytes(&self) -> &[u8] {
100 &self.png
101 }
102
103 pub fn into_bytes(self) -> Vec<u8> {
105 self.png
106 }
107
108 pub fn width(&self) -> u32 {
112 self.png
113 .get(16..20)
114 .and_then(|b| b.try_into().ok())
115 .map(u32::from_be_bytes)
116 .unwrap_or(0)
117 }
118
119 pub fn height(&self) -> u32 {
123 self.png
124 .get(20..24)
125 .and_then(|b| b.try_into().ok())
126 .map(u32::from_be_bytes)
127 .unwrap_or(0)
128 }
129
130 pub fn bit_depth(&self) -> u8 {
134 self.png.get(24).copied().unwrap_or(0)
135 }
136
137 pub fn encode(
150 pixels: &[u8],
151 width: u32,
152 height: u32,
153 bit_depth: u8,
154 ) -> Result<Self, crate::Error> {
155 if bit_depth != 1 && bit_depth != 8 {
156 return Err(crate::Error::InvalidParameters(format!(
157 "bit_depth must be 1 or 8, got {}",
158 bit_depth
159 )));
160 }
161 let expected = (width as usize) * (height as usize);
162 if pixels.len() != expected {
163 return Err(crate::Error::InvalidParameters(format!(
164 "pixel count mismatch: expected {}, got {}",
165 expected,
166 pixels.len()
167 )));
168 }
169
170 let mut buf = Vec::new();
171 {
172 let mut encoder = png::Encoder::new(&mut buf, width, height);
173 encoder.set_color(png::ColorType::Grayscale);
174 encoder.set_depth(match bit_depth {
175 1 => png::BitDepth::One,
176 8 => png::BitDepth::Eight,
177 _ => unreachable!(),
178 });
179
180 let mut writer = encoder.write_header().map_err(|e| {
181 crate::Error::InvalidParameters(format!("PNG header write failed: {}", e))
182 })?;
183
184 match bit_depth {
185 1 => {
186 let bytes_per_row = (width as usize).div_ceil(8);
187 let mut packed = vec![0u8; bytes_per_row * height as usize];
188 for y in 0..height as usize {
189 for x in 0..width as usize {
190 if pixels[y * width as usize + x] != 0 {
191 packed[y * bytes_per_row + x / 8] |= 0x80 >> (x % 8);
192 }
193 }
194 }
195 writer.write_image_data(&packed).map_err(|e| {
196 crate::Error::InvalidParameters(format!(
197 "PNG image data write failed: {}",
198 e
199 ))
200 })?;
201 }
202 8 => {
203 writer.write_image_data(pixels).map_err(|e| {
204 crate::Error::InvalidParameters(format!(
205 "PNG image data write failed: {}",
206 e
207 ))
208 })?;
209 }
210 _ => unreachable!(),
211 }
212 }
213 Ok(Self { png: buf })
214 }
215
216 pub fn encode_16bit(pixels: &[u16], width: u32, height: u32) -> Result<Self, crate::Error> {
225 let expected = (width as usize) * (height as usize);
226 if pixels.len() != expected {
227 return Err(crate::Error::InvalidParameters(format!(
228 "pixel count mismatch: expected {}, got {}",
229 expected,
230 pixels.len()
231 )));
232 }
233
234 let mut buf = Vec::new();
235 {
236 let mut encoder = png::Encoder::new(&mut buf, width, height);
237 encoder.set_color(png::ColorType::Grayscale);
238 encoder.set_depth(png::BitDepth::Sixteen);
239
240 let mut writer = encoder.write_header().map_err(|e| {
241 crate::Error::InvalidParameters(format!("PNG header write failed: {}", e))
242 })?;
243
244 let raw: Vec<u8> = pixels.iter().flat_map(|&v| v.to_be_bytes()).collect();
245 writer.write_image_data(&raw).map_err(|e| {
246 crate::Error::InvalidParameters(format!("PNG image data write failed: {}", e))
247 })?;
248 }
249 Ok(Self { png: buf })
250 }
251
252 pub fn decode(&self) -> Result<Vec<u8>, crate::Error> {
262 let decoder = png::Decoder::new(self.png.as_slice());
263 let mut reader = decoder
264 .read_info()
265 .map_err(|e| crate::Error::InvalidParameters(format!("PNG info read failed: {}", e)))?;
266
267 let info = reader.info();
269 let total_pixels = info.width as u64 * info.height as u64;
270 const MAX_PIXELS: u64 = 100_000_000; if total_pixels > MAX_PIXELS {
272 return Err(crate::Error::InvalidParameters(format!(
273 "PNG dimensions {}x{} exceed maximum of {} pixels",
274 info.width, info.height, MAX_PIXELS
275 )));
276 }
277
278 let mut raw = vec![0u8; reader.output_buffer_size()];
279 let info = reader.next_frame(&mut raw).map_err(|e| {
280 crate::Error::InvalidParameters(format!("PNG frame read failed: {}", e))
281 })?;
282 raw.truncate(info.buffer_size());
283
284 if info.bit_depth == png::BitDepth::One {
285 let width = info.width as usize;
286 let height = info.height as usize;
287 let bytes_per_row = width.div_ceil(8);
288 let mut unpacked = Vec::with_capacity(width * height);
289 for y in 0..height {
290 for x in 0..width {
291 let byte = raw[y * bytes_per_row + x / 8];
292 let bit = (byte >> (7 - (x % 8))) & 1;
293 unpacked.push(bit);
294 }
295 }
296 Ok(unpacked)
297 } else {
298 Ok(raw)
299 }
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_decode_8bit() {
        let pixels: Vec<u8> = vec![0, 64, 128, 192, 255, 1, 100, 200, 50];
        let mask = MaskData::encode(&pixels, 3, 3, 8).unwrap();

        assert_eq!(mask.width(), 3);
        assert_eq!(mask.height(), 3);
        assert_eq!(mask.bit_depth(), 8);

        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_encode_decode_1bit() {
        let pixels: Vec<u8> = vec![
            1, 0, 1, 0, 1, 0, 1, 0, //
            0, 1, 0, 1, 0, 1, 0, 1,
        ];
        let mask = MaskData::encode(&pixels, 8, 2, 1).unwrap();

        assert_eq!(mask.width(), 8);
        assert_eq!(mask.height(), 2);
        assert_eq!(mask.bit_depth(), 1);

        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_encode_decode_16bit() {
        let pixels: Vec<u16> = vec![0, 256, 65535, 1024];
        let mask = MaskData::encode_16bit(&pixels, 2, 2).unwrap();

        assert_eq!(mask.width(), 2);
        assert_eq!(mask.height(), 2);
        assert_eq!(mask.bit_depth(), 16);

        let decoded = mask.decode().unwrap();
        // 16-bit samples come back as big-endian byte pairs.
        let expected: Vec<u8> = pixels.iter().flat_map(|&v| v.to_be_bytes()).collect();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_header_read_without_decode() {
        let width = 640u32;
        let height = 480u32;
        let pixels = vec![0u8; (width * height) as usize];
        let mask = MaskData::encode(&pixels, width, height, 8).unwrap();

        assert_eq!(mask.width(), width);
        assert_eq!(mask.height(), height);
        assert_eq!(mask.bit_depth(), 8);

        // A uniform image should compress well below its raw size.
        let raw_size = (width * height) as usize;
        assert!(
            mask.as_bytes().len() < raw_size,
            "PNG ({} bytes) should be smaller than raw ({} bytes)",
            mask.as_bytes().len(),
            raw_size,
        );
    }

    #[test]
    fn test_from_png_bytes() {
        let pixels: Vec<u8> = vec![10, 20, 30, 40, 50, 60];
        let original = MaskData::encode(&pixels, 3, 2, 8).unwrap();

        let bytes = original.into_bytes();
        let reconstructed = MaskData::from_png(bytes);

        assert_eq!(reconstructed.width(), 3);
        assert_eq!(reconstructed.height(), 2);
        assert_eq!(reconstructed.bit_depth(), 8);
        assert_eq!(reconstructed.decode().unwrap(), pixels);
    }

    #[test]
    fn test_1bit_non_aligned_width() {
        let pixels: Vec<u8> = vec![
            1, 0, 1, 1, 0, //
            0, 1, 0, 0, 1, //
            1, 1, 1, 0, 0,
        ];
        let mask = MaskData::encode(&pixels, 5, 3, 1).unwrap();

        assert_eq!(mask.width(), 5);
        assert_eq!(mask.height(), 3);
        assert_eq!(mask.bit_depth(), 1);

        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_encode_1bit_treats_nonzero_as_set() {
        let mask = MaskData::encode(&[0, 5, 255, 0], 4, 1, 1).unwrap();
        assert_eq!(mask.decode().unwrap(), vec![0, 1, 1, 0]);
    }

    #[test]
    fn test_from_png_empty_bytes() {
        let result = MaskData::from_png_checked(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_truncated() {
        // Signature only: shorter than the minimum header length.
        let result = MaskData::from_png_checked(PNG_SIGNATURE.to_vec());
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_garbage() {
        let result = MaskData::from_png_checked(vec![0u8; 64]);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_checked_rejects_missing_ihdr() {
        // Valid signature followed by zeros: no IHDR chunk where one is required.
        let mut fake_png = vec![0u8; 64];
        fake_png[..8].copy_from_slice(&PNG_SIGNATURE);
        assert!(MaskData::from_png_checked(fake_png).is_err());
    }

    #[test]
    fn test_from_png_wrong_color_type() {
        // Signature plus a well-formed IHDR length/type so the color-type check is the
        // one that fires; color type 2 is RGB.
        let mut fake_png = vec![0u8; MIN_PNG_LEN];
        fake_png[..8].copy_from_slice(&PNG_SIGNATURE);
        fake_png[8..12].copy_from_slice(&13u32.to_be_bytes());
        fake_png[12..16].copy_from_slice(b"IHDR");
        fake_png[25] = 2;

        match MaskData::from_png_checked(fake_png) {
            Err(crate::Error::InvalidParameters(msg)) => {
                assert!(msg.contains("color type"), "unexpected message: {}", msg)
            }
            other => panic!(
                "expected a color-type InvalidParameters error, got ok={}",
                other.is_ok()
            ),
        }
    }

    #[test]
    fn test_from_png_checked_valid() {
        let pixels: Vec<u8> = vec![0, 128, 255, 64];
        let mask = MaskData::encode(&pixels, 2, 2, 8).unwrap();
        let bytes = mask.into_bytes();
        let result = MaskData::from_png_checked(bytes);
        assert!(result.is_ok());
    }

    #[test]
    fn test_is_valid() {
        let pixels: Vec<u8> = vec![0, 128, 255, 64];
        let mask = MaskData::encode(&pixels, 2, 2, 8).unwrap();
        assert!(mask.is_valid());

        let invalid = MaskData::from_png(vec![1, 2, 3]);
        assert!(!invalid.is_valid());
    }

    #[test]
    fn test_width_height_bit_depth_short_data() {
        let mask = MaskData::from_png(vec![]);
        assert_eq!(mask.width(), 0);
        assert_eq!(mask.height(), 0);
        assert_eq!(mask.bit_depth(), 0);

        let mask2 = MaskData::from_png(vec![0; 10]);
        assert_eq!(mask2.width(), 0);
        assert_eq!(mask2.height(), 0);
        assert_eq!(mask2.bit_depth(), 0);
    }

    #[test]
    fn test_decode_invalid_data_returns_error() {
        let mask = MaskData::from_png(vec![1, 2, 3]);
        assert!(mask.decode().is_err());
    }

    #[test]
    fn test_encode_invalid_bit_depth() {
        let result = MaskData::encode(&[0; 4], 2, 2, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_pixel_count_mismatch() {
        let result = MaskData::encode(&[0; 3], 2, 2, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_16bit_pixel_count_mismatch() {
        let result = MaskData::encode_16bit(&[0; 3], 2, 2);
        assert!(result.is_err());
    }
}