Skip to main content

edgefirst_client/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Raster mask data with PNG encode/decode support.
5//!
6//! [`MaskData`] wraps PNG-encoded bytes for storing raster masks in Arrow
7//! `Binary` columns. It supports fast header-only reads (width, height,
8//! bit_depth) without decoding pixels, and can encode from raw pixels at
9//! various bit depths (1-bit binary, 8-bit scores, 16-bit precision).
10
/// The 8-byte signature that opens every PNG file (`\x89PNG\r\n\x1a\n`).
const PNG_SIGNATURE: [u8; 8] = *b"\x89PNG\r\n\x1a\n";

/// Smallest byte length that can hold the signature plus a complete IHDR chunk:
/// signature (8) + chunk length (4) + chunk type "IHDR" (4) + IHDR payload (13) + CRC (4).
const MIN_PNG_LEN: usize = 8 + 4 + 4 + 13 + 4;
17
/// PNG-encoded raster mask.
///
/// Image metadata (width, height, bit depth) is read straight from the
/// fixed-offset IHDR fields without decoding anything, while pixel data can be
/// encoded from, and decoded back to, raw buffers at 1-bit, 8-bit and 16-bit
/// depths.
///
/// # PNG layout reference
///
/// ```text
/// [0..8]   PNG signature
/// [8..12]  IHDR chunk length (always 13)
/// [12..16] IHDR chunk type ("IHDR")
/// [16..20] width  (big-endian u32)
/// [20..24] height (big-endian u32)
/// [24]     bit_depth
/// [25]     color_type
/// ...
/// ```
#[derive(Debug, Clone)]
pub struct MaskData {
    // The encoded PNG byte stream, starting at the signature.
    png: Vec<u8>,
}
40
41impl MaskData {
42    /// Creates a `MaskData` from raw PNG bytes.
43    ///
44    /// The caller is responsible for ensuring the bytes represent a valid PNG.
45    /// For validated construction, use [`from_png_checked`](Self::from_png_checked).
46    pub fn from_png(png: Vec<u8>) -> Self {
47        Self { png }
48    }
49
50    /// Creates a `MaskData` from raw PNG bytes with validation.
51    ///
52    /// Validates that the bytes represent a valid grayscale PNG:
53    /// - Length >= 33 bytes (signature + IHDR chunk including CRC)
54    /// - PNG magic bytes at offset 0
55    /// - Color type byte (offset 25) is 0 (grayscale)
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if the bytes are not a valid grayscale PNG.
60    pub fn from_png_checked(png: Vec<u8>) -> Result<Self, crate::Error> {
61        if png.len() < MIN_PNG_LEN {
62            return Err(crate::Error::InvalidParameters(format!(
63                "PNG data too short: {} bytes, minimum {} required",
64                png.len(),
65                MIN_PNG_LEN
66            )));
67        }
68        if png[..8] != PNG_SIGNATURE {
69            return Err(crate::Error::InvalidParameters(
70                "invalid PNG signature: not a PNG file".to_string(),
71            ));
72        }
73        let color_type = png[25];
74        if color_type != 0 {
75            return Err(crate::Error::InvalidParameters(format!(
76                "PNG color type must be 0 (grayscale), got {}",
77                color_type
78            )));
79        }
80
81        // Try to parse PNG header to catch truncated/malformed data
82        let decoder = png::Decoder::new(std::io::Cursor::new(&png));
83        if decoder.read_info().is_err() {
84            return Err(crate::Error::InvalidParameters(
85                "PNG data is malformed or truncated".to_string(),
86            ));
87        }
88
89        Ok(Self { png })
90    }
91
92    /// Returns `true` if the underlying bytes contain a valid PNG signature
93    /// and are long enough to have an IHDR chunk.
94    pub fn is_valid(&self) -> bool {
95        self.png.len() >= MIN_PNG_LEN && self.png[..8] == PNG_SIGNATURE
96    }
97
98    /// Returns a reference to the underlying PNG bytes.
99    pub fn as_bytes(&self) -> &[u8] {
100        &self.png
101    }
102
103    /// Consumes the `MaskData` and returns the underlying PNG bytes.
104    pub fn into_bytes(self) -> Vec<u8> {
105        self.png
106    }
107
108    /// Returns the image width by reading the PNG IHDR chunk (bytes 16..20).
109    ///
110    /// Returns 0 if the PNG data is too short or invalid.
111    pub fn width(&self) -> u32 {
112        self.png
113            .get(16..20)
114            .and_then(|b| b.try_into().ok())
115            .map(u32::from_be_bytes)
116            .unwrap_or(0)
117    }
118
119    /// Returns the image height by reading the PNG IHDR chunk (bytes 20..24).
120    ///
121    /// Returns 0 if the PNG data is too short or invalid.
122    pub fn height(&self) -> u32 {
123        self.png
124            .get(20..24)
125            .and_then(|b| b.try_into().ok())
126            .map(u32::from_be_bytes)
127            .unwrap_or(0)
128    }
129
130    /// Returns the bit depth by reading the PNG IHDR chunk (byte 24).
131    ///
132    /// Returns 0 if the PNG data is too short or invalid.
133    pub fn bit_depth(&self) -> u8 {
134        self.png.get(24).copied().unwrap_or(0)
135    }
136
137    /// Encodes raw 8-bit grayscale pixels into a PNG.
138    ///
139    /// For `bit_depth == 1`, pixel values must be `0` or `1` and will be packed
140    /// into 1-bit-per-pixel PNG rows (MSB first, 8 pixels per byte, with
141    /// zero-padding on the last byte if `width` is not a multiple of 8).
142    ///
143    /// For `bit_depth == 8`, pixels are encoded directly as 8-bit grayscale.
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if `bit_depth` is not 1 or 8, or if `pixels.len()`
148    /// does not equal `width * height`.
149    pub fn encode(
150        pixels: &[u8],
151        width: u32,
152        height: u32,
153        bit_depth: u8,
154    ) -> Result<Self, crate::Error> {
155        if bit_depth != 1 && bit_depth != 8 {
156            return Err(crate::Error::InvalidParameters(format!(
157                "bit_depth must be 1 or 8, got {}",
158                bit_depth
159            )));
160        }
161        let expected = (width as usize) * (height as usize);
162        if pixels.len() != expected {
163            return Err(crate::Error::InvalidParameters(format!(
164                "pixel count mismatch: expected {}, got {}",
165                expected,
166                pixels.len()
167            )));
168        }
169
170        let mut buf = Vec::new();
171        {
172            let mut encoder = png::Encoder::new(&mut buf, width, height);
173            encoder.set_color(png::ColorType::Grayscale);
174            encoder.set_depth(match bit_depth {
175                1 => png::BitDepth::One,
176                8 => png::BitDepth::Eight,
177                _ => unreachable!(),
178            });
179
180            let mut writer = encoder.write_header().map_err(|e| {
181                crate::Error::InvalidParameters(format!("PNG header write failed: {}", e))
182            })?;
183
184            match bit_depth {
185                1 => {
186                    let bytes_per_row = (width as usize).div_ceil(8);
187                    let mut packed = vec![0u8; bytes_per_row * height as usize];
188                    for y in 0..height as usize {
189                        for x in 0..width as usize {
190                            if pixels[y * width as usize + x] != 0 {
191                                packed[y * bytes_per_row + x / 8] |= 0x80 >> (x % 8);
192                            }
193                        }
194                    }
195                    writer.write_image_data(&packed).map_err(|e| {
196                        crate::Error::InvalidParameters(format!(
197                            "PNG image data write failed: {}",
198                            e
199                        ))
200                    })?;
201                }
202                8 => {
203                    writer.write_image_data(pixels).map_err(|e| {
204                        crate::Error::InvalidParameters(format!(
205                            "PNG image data write failed: {}",
206                            e
207                        ))
208                    })?;
209                }
210                _ => unreachable!(),
211            }
212        }
213        Ok(Self { png: buf })
214    }
215
216    /// Encodes raw 16-bit grayscale pixels into a PNG.
217    ///
218    /// Each `u16` value is written as two bytes in big-endian order, matching
219    /// the PNG 16-bit grayscale format.
220    ///
221    /// # Errors
222    ///
223    /// Returns an error if `pixels.len()` does not equal `width * height`.
224    pub fn encode_16bit(pixels: &[u16], width: u32, height: u32) -> Result<Self, crate::Error> {
225        let expected = (width as usize) * (height as usize);
226        if pixels.len() != expected {
227            return Err(crate::Error::InvalidParameters(format!(
228                "pixel count mismatch: expected {}, got {}",
229                expected,
230                pixels.len()
231            )));
232        }
233
234        let mut buf = Vec::new();
235        {
236            let mut encoder = png::Encoder::new(&mut buf, width, height);
237            encoder.set_color(png::ColorType::Grayscale);
238            encoder.set_depth(png::BitDepth::Sixteen);
239
240            let mut writer = encoder.write_header().map_err(|e| {
241                crate::Error::InvalidParameters(format!("PNG header write failed: {}", e))
242            })?;
243
244            let raw: Vec<u8> = pixels.iter().flat_map(|&v| v.to_be_bytes()).collect();
245            writer.write_image_data(&raw).map_err(|e| {
246                crate::Error::InvalidParameters(format!("PNG image data write failed: {}", e))
247            })?;
248        }
249        Ok(Self { png: buf })
250    }
251
252    /// Decodes the PNG image to raw pixel bytes.
253    ///
254    /// For 1-bit PNGs, each pixel is unpacked to a single byte (`0` or `1`).
255    /// For 8-bit PNGs, pixel bytes are returned directly.
256    /// For 16-bit PNGs, each pixel yields two bytes in big-endian order.
257    ///
258    /// # Errors
259    ///
260    /// Returns an error if the PNG data is malformed or cannot be decoded.
261    pub fn decode(&self) -> Result<Vec<u8>, crate::Error> {
262        let decoder = png::Decoder::new(std::io::Cursor::new(self.png.as_slice()));
263        let mut reader = decoder
264            .read_info()
265            .map_err(|e| crate::Error::InvalidParameters(format!("PNG info read failed: {}", e)))?;
266
267        // Guard against decompression bombs
268        let info = reader.info();
269        let total_pixels = info.width as u64 * info.height as u64;
270        const MAX_PIXELS: u64 = 100_000_000; // 100 megapixels
271        if total_pixels > MAX_PIXELS {
272            return Err(crate::Error::InvalidParameters(format!(
273                "PNG dimensions {}x{} exceed maximum of {} pixels",
274                info.width, info.height, MAX_PIXELS
275            )));
276        }
277
278        let buffer_size = reader.output_buffer_size().ok_or_else(|| {
279            crate::Error::InvalidParameters("PNG output buffer size unavailable".to_string())
280        })?;
281        let mut raw = vec![0u8; buffer_size];
282        let info = reader.next_frame(&mut raw).map_err(|e| {
283            crate::Error::InvalidParameters(format!("PNG frame read failed: {}", e))
284        })?;
285        raw.truncate(info.buffer_size());
286
287        if info.bit_depth == png::BitDepth::One {
288            let width = info.width as usize;
289            let height = info.height as usize;
290            let bytes_per_row = width.div_ceil(8);
291            let mut unpacked = Vec::with_capacity(width * height);
292            for y in 0..height {
293                for x in 0..width {
294                    let byte = raw[y * bytes_per_row + x / 8];
295                    let bit = (byte >> (7 - (x % 8))) & 1;
296                    unpacked.push(bit);
297                }
298            }
299            Ok(unpacked)
300        } else {
301            Ok(raw)
302        }
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Gathers the three IHDR header fields so they can be compared at once.
    fn header(mask: &MaskData) -> (u32, u32, u8) {
        (mask.width(), mask.height(), mask.bit_depth())
    }

    /// Asserts that validated construction rejects `bytes`.
    fn assert_rejected(bytes: Vec<u8>) {
        assert!(MaskData::from_png_checked(bytes).is_err());
    }

    /// Encodes a small 2x2 8-bit mask used by several validation tests.
    fn small_8bit_mask() -> MaskData {
        MaskData::encode(&[0, 128, 255, 64], 2, 2, 8).unwrap()
    }

    #[test]
    fn test_encode_decode_8bit() {
        // 3x3 image with a spread of grayscale values.
        let pixels = vec![0u8, 64, 128, 192, 255, 1, 100, 200, 50];
        let mask = MaskData::encode(&pixels, 3, 3, 8).unwrap();

        assert_eq!(header(&mask), (3, 3, 8));
        assert_eq!(mask.decode().unwrap(), pixels);
    }

    #[test]
    fn test_encode_decode_1bit() {
        // 8x2 image: width is exactly one packed byte per row.
        let pixels = vec![
            1u8, 0, 1, 0, 1, 0, 1, 0, // row 0
            0, 1, 0, 1, 0, 1, 0, 1, // row 1
        ];
        let mask = MaskData::encode(&pixels, 8, 2, 1).unwrap();

        assert_eq!(header(&mask), (8, 2, 1));
        assert_eq!(mask.decode().unwrap(), pixels);
    }

    #[test]
    fn test_encode_decode_16bit() {
        // 2x2 image of u16 samples; decoding yields big-endian byte pairs.
        let pixels = vec![0u16, 256, 65535, 1024];
        let mask = MaskData::encode_16bit(&pixels, 2, 2).unwrap();

        assert_eq!(header(&mask), (2, 2, 16));

        let expected: Vec<u8> = pixels
            .iter()
            .flat_map(|v| v.to_be_bytes())
            .collect();
        assert_eq!(mask.decode().unwrap(), expected);
    }

    #[test]
    fn test_header_read_without_decode() {
        // 640x480 all-zero image: header fields readable, and it compresses well.
        let (width, height) = (640u32, 480u32);
        let raw_size = (width * height) as usize;
        let mask = MaskData::encode(&vec![0u8; raw_size], width, height, 8).unwrap();

        assert_eq!(header(&mask), (width, height, 8));

        let png_size = mask.as_bytes().len();
        assert!(
            png_size < raw_size,
            "PNG ({} bytes) should be smaller than raw ({} bytes)",
            png_size,
            raw_size,
        );
    }

    #[test]
    fn test_from_png_bytes() {
        // Bytes extracted from an encoded mask must round-trip through from_png.
        let pixels = vec![10u8, 20, 30, 40, 50, 60];
        let bytes = MaskData::encode(&pixels, 3, 2, 8).unwrap().into_bytes();
        let rebuilt = MaskData::from_png(bytes);

        assert_eq!(header(&rebuilt), (3, 2, 8));
        assert_eq!(rebuilt.decode().unwrap(), pixels);
    }

    #[test]
    fn test_1bit_non_aligned_width() {
        // 5x3 image: each row needs a zero-padded trailing byte.
        let pixels = vec![
            1u8, 0, 1, 1, 0, // row 0
            0, 1, 0, 0, 1, // row 1
            1, 1, 1, 0, 0, // row 2
        ];
        let mask = MaskData::encode(&pixels, 5, 3, 1).unwrap();

        assert_eq!(header(&mask), (5, 3, 1));
        assert_eq!(mask.decode().unwrap(), pixels);
    }

    // -------------------------------------------------------------------------
    // Validated construction via from_png_checked
    // -------------------------------------------------------------------------

    #[test]
    fn test_from_png_empty_bytes() {
        assert_rejected(Vec::new());
    }

    #[test]
    fn test_from_png_truncated() {
        // Signature only; the IHDR chunk is missing entirely.
        assert_rejected(PNG_SIGNATURE.to_vec());
    }

    #[test]
    fn test_from_png_garbage() {
        assert_rejected(vec![0u8; 64]);
    }

    #[test]
    fn test_from_png_wrong_color_type() {
        // Long enough, correct signature, but color type 2 (RGB).
        let mut fake = vec![0u8; MIN_PNG_LEN];
        fake[..PNG_SIGNATURE.len()].copy_from_slice(&PNG_SIGNATURE);
        fake[25] = 2;
        assert_rejected(fake);
    }

    #[test]
    fn test_from_png_checked_valid() {
        let bytes = small_8bit_mask().into_bytes();
        assert!(MaskData::from_png_checked(bytes).is_ok());
    }

    #[test]
    fn test_is_valid() {
        assert!(small_8bit_mask().is_valid());
        assert!(!MaskData::from_png(vec![1, 2, 3]).is_valid());
    }

    // -------------------------------------------------------------------------
    // Header reads on bad data yield 0 rather than panicking
    // -------------------------------------------------------------------------

    #[test]
    fn test_width_height_bit_depth_short_data() {
        for bytes in [Vec::new(), vec![0u8; 10]] {
            assert_eq!(header(&MaskData::from_png(bytes)), (0, 0, 0));
        }
    }

    #[test]
    fn test_decode_invalid_data_returns_error() {
        assert!(MaskData::from_png(vec![1, 2, 3]).decode().is_err());
    }

    #[test]
    fn test_encode_invalid_bit_depth() {
        assert!(MaskData::encode(&[0; 4], 2, 2, 4).is_err());
    }

    #[test]
    fn test_encode_pixel_count_mismatch() {
        assert!(MaskData::encode(&[0; 3], 2, 2, 8).is_err());
    }
}