Skip to main content

edgefirst_client/
mask.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.

//! Raster mask data with PNG encode/decode support.
//!
//! [`MaskData`] wraps PNG-encoded bytes for storing raster masks in Arrow
//! `Binary` columns. It supports fast header-only reads (width, height,
//! bit_depth) without decoding pixels, and can encode from raw pixels at
//! various bit depths (1-bit binary, 8-bit scores, 16-bit precision).

/// The 8-byte signature that opens every valid PNG file:
/// `0x89`, ASCII `PNG`, CR LF, `0x1A` (SUB), LF.
const PNG_SIGNATURE: [u8; 8] = *b"\x89PNG\r\n\x1a\n";
13
/// Smallest byte count that can hold the PNG signature followed by a complete
/// IHDR chunk. The terms are, in order: signature (8), chunk length (4),
/// chunk type `IHDR` (4), IHDR payload (13) and chunk CRC (4), for 33 in total.
const MIN_PNG_LEN: usize = 8 + 4 + 4 + 13 + 4;
17
/// A raster mask stored as PNG-encoded bytes.
///
/// `MaskData` provides zero-copy access to PNG metadata (width, height,
/// bit_depth) by reading the IHDR chunk directly, and full encode/decode
/// for pixel data at 1-bit, 8-bit, and 16-bit depths.
///
/// Equality and hashing compare the encoded bytes, not the decoded pixels.
/// Two masks with identical pixels but different encodings (for example,
/// different compression output) therefore compare unequal.
///
/// # PNG layout reference
///
/// ```text
/// [0..8]   PNG signature
/// [8..12]  IHDR chunk length (always 13)
/// [12..16] IHDR chunk type ("IHDR")
/// [16..20] width  (big-endian u32)
/// [20..24] height (big-endian u32)
/// [24]     bit_depth
/// [25]     color_type
/// ...
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MaskData {
    /// The PNG-encoded image bytes.
    png: Vec<u8>,
}
40
41impl MaskData {
42    /// Creates a `MaskData` from raw PNG bytes.
43    ///
44    /// The caller is responsible for ensuring the bytes represent a valid PNG.
45    /// For validated construction, use [`from_png_checked`](Self::from_png_checked).
46    pub fn from_png(png: Vec<u8>) -> Self {
47        Self { png }
48    }
49
50    /// Creates a `MaskData` from raw PNG bytes with validation.
51    ///
52    /// Validates that the bytes represent a valid grayscale PNG:
53    /// - Length >= 33 bytes (signature + IHDR chunk including CRC)
54    /// - PNG magic bytes at offset 0
55    /// - Color type byte (offset 25) is 0 (grayscale)
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if the bytes are not a valid grayscale PNG.
60    pub fn from_png_checked(png: Vec<u8>) -> Result<Self, crate::Error> {
61        if png.len() < MIN_PNG_LEN {
62            return Err(crate::Error::InvalidParameters(format!(
63                "PNG data too short: {} bytes, minimum {} required",
64                png.len(),
65                MIN_PNG_LEN
66            )));
67        }
68        if png[..8] != PNG_SIGNATURE {
69            return Err(crate::Error::InvalidParameters(
70                "invalid PNG signature: not a PNG file".to_string(),
71            ));
72        }
73        let color_type = png[25];
74        if color_type != 0 {
75            return Err(crate::Error::InvalidParameters(format!(
76                "PNG color type must be 0 (grayscale), got {}",
77                color_type
78            )));
79        }
80
81        // Try to parse PNG header to catch truncated/malformed data
82        let decoder = png::Decoder::new(std::io::Cursor::new(&png));
83        if decoder.read_info().is_err() {
84            return Err(crate::Error::InvalidParameters(
85                "PNG data is malformed or truncated".to_string(),
86            ));
87        }
88
89        Ok(Self { png })
90    }
91
92    /// Returns `true` if the underlying bytes contain a valid PNG signature
93    /// and are long enough to have an IHDR chunk.
94    pub fn is_valid(&self) -> bool {
95        self.png.len() >= MIN_PNG_LEN && self.png[..8] == PNG_SIGNATURE
96    }
97
98    /// Returns a reference to the underlying PNG bytes.
99    pub fn as_bytes(&self) -> &[u8] {
100        &self.png
101    }
102
103    /// Consumes the `MaskData` and returns the underlying PNG bytes.
104    pub fn into_bytes(self) -> Vec<u8> {
105        self.png
106    }
107
108    /// Returns the image width by reading the PNG IHDR chunk (bytes 16..20).
109    ///
110    /// Returns 0 if the PNG data is too short or invalid.
111    pub fn width(&self) -> u32 {
112        self.png
113            .get(16..20)
114            .and_then(|b| b.try_into().ok())
115            .map(u32::from_be_bytes)
116            .unwrap_or(0)
117    }
118
119    /// Returns the image height by reading the PNG IHDR chunk (bytes 20..24).
120    ///
121    /// Returns 0 if the PNG data is too short or invalid.
122    pub fn height(&self) -> u32 {
123        self.png
124            .get(20..24)
125            .and_then(|b| b.try_into().ok())
126            .map(u32::from_be_bytes)
127            .unwrap_or(0)
128    }
129
130    /// Returns the bit depth by reading the PNG IHDR chunk (byte 24).
131    ///
132    /// Returns 0 if the PNG data is too short or invalid.
133    pub fn bit_depth(&self) -> u8 {
134        self.png.get(24).copied().unwrap_or(0)
135    }
136
137    /// Encodes raw 8-bit grayscale pixels into a PNG.
138    ///
139    /// For `bit_depth == 1`, pixel values must be `0` or `1` and will be packed
140    /// into 1-bit-per-pixel PNG rows (MSB first, 8 pixels per byte, with
141    /// zero-padding on the last byte if `width` is not a multiple of 8).
142    ///
143    /// For `bit_depth == 8`, pixels are encoded directly as 8-bit grayscale.
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if `bit_depth` is not 1 or 8, or if `pixels.len()`
148    /// does not equal `width * height`.
149    pub fn encode(
150        pixels: &[u8],
151        width: u32,
152        height: u32,
153        bit_depth: u8,
154    ) -> Result<Self, crate::Error> {
155        if bit_depth != 1 && bit_depth != 8 {
156            return Err(crate::Error::InvalidParameters(format!(
157                "bit_depth must be 1 or 8, got {}",
158                bit_depth
159            )));
160        }
161        let expected = (width as usize) * (height as usize);
162        if pixels.len() != expected {
163            return Err(crate::Error::InvalidParameters(format!(
164                "pixel count mismatch: expected {}, got {}",
165                expected,
166                pixels.len()
167            )));
168        }
169
170        let mut buf = Vec::new();
171        {
172            let mut encoder = png::Encoder::new(&mut buf, width, height);
173            encoder.set_color(png::ColorType::Grayscale);
174            encoder.set_depth(match bit_depth {
175                1 => png::BitDepth::One,
176                8 => png::BitDepth::Eight,
177                _ => unreachable!(),
178            });
179
180            let mut writer = encoder.write_header().map_err(|e| {
181                crate::Error::InvalidParameters(format!("PNG header write failed: {}", e))
182            })?;
183
184            match bit_depth {
185                1 => {
186                    let bytes_per_row = (width as usize).div_ceil(8);
187                    let mut packed = vec![0u8; bytes_per_row * height as usize];
188                    for y in 0..height as usize {
189                        for x in 0..width as usize {
190                            if pixels[y * width as usize + x] != 0 {
191                                packed[y * bytes_per_row + x / 8] |= 0x80 >> (x % 8);
192                            }
193                        }
194                    }
195                    writer.write_image_data(&packed).map_err(|e| {
196                        crate::Error::InvalidParameters(format!(
197                            "PNG image data write failed: {}",
198                            e
199                        ))
200                    })?;
201                }
202                8 => {
203                    writer.write_image_data(pixels).map_err(|e| {
204                        crate::Error::InvalidParameters(format!(
205                            "PNG image data write failed: {}",
206                            e
207                        ))
208                    })?;
209                }
210                _ => unreachable!(),
211            }
212        }
213        Ok(Self { png: buf })
214    }
215
216    /// Encodes raw 16-bit grayscale pixels into a PNG.
217    ///
218    /// Each `u16` value is written as two bytes in big-endian order, matching
219    /// the PNG 16-bit grayscale format.
220    ///
221    /// # Errors
222    ///
223    /// Returns an error if `pixels.len()` does not equal `width * height`.
224    pub fn encode_16bit(pixels: &[u16], width: u32, height: u32) -> Result<Self, crate::Error> {
225        let expected = (width as usize) * (height as usize);
226        if pixels.len() != expected {
227            return Err(crate::Error::InvalidParameters(format!(
228                "pixel count mismatch: expected {}, got {}",
229                expected,
230                pixels.len()
231            )));
232        }
233
234        let mut buf = Vec::new();
235        {
236            let mut encoder = png::Encoder::new(&mut buf, width, height);
237            encoder.set_color(png::ColorType::Grayscale);
238            encoder.set_depth(png::BitDepth::Sixteen);
239
240            let mut writer = encoder.write_header().map_err(|e| {
241                crate::Error::InvalidParameters(format!("PNG header write failed: {}", e))
242            })?;
243
244            let raw: Vec<u8> = pixels.iter().flat_map(|&v| v.to_be_bytes()).collect();
245            writer.write_image_data(&raw).map_err(|e| {
246                crate::Error::InvalidParameters(format!("PNG image data write failed: {}", e))
247            })?;
248        }
249        Ok(Self { png: buf })
250    }
251
252    /// Decodes the PNG image to raw pixel bytes.
253    ///
254    /// For 1-bit PNGs, each pixel is unpacked to a single byte (`0` or `1`).
255    /// For 8-bit PNGs, pixel bytes are returned directly.
256    /// For 16-bit PNGs, each pixel yields two bytes in big-endian order.
257    ///
258    /// # Errors
259    ///
260    /// Returns an error if the PNG data is malformed or cannot be decoded.
261    pub fn decode(&self) -> Result<Vec<u8>, crate::Error> {
262        let decoder = png::Decoder::new(self.png.as_slice());
263        let mut reader = decoder
264            .read_info()
265            .map_err(|e| crate::Error::InvalidParameters(format!("PNG info read failed: {}", e)))?;
266
267        // Guard against decompression bombs
268        let info = reader.info();
269        let total_pixels = info.width as u64 * info.height as u64;
270        const MAX_PIXELS: u64 = 100_000_000; // 100 megapixels
271        if total_pixels > MAX_PIXELS {
272            return Err(crate::Error::InvalidParameters(format!(
273                "PNG dimensions {}x{} exceed maximum of {} pixels",
274                info.width, info.height, MAX_PIXELS
275            )));
276        }
277
278        let mut raw = vec![0u8; reader.output_buffer_size()];
279        let info = reader.next_frame(&mut raw).map_err(|e| {
280            crate::Error::InvalidParameters(format!("PNG frame read failed: {}", e))
281        })?;
282        raw.truncate(info.buffer_size());
283
284        if info.bit_depth == png::BitDepth::One {
285            let width = info.width as usize;
286            let height = info.height as usize;
287            let bytes_per_row = width.div_ceil(8);
288            let mut unpacked = Vec::with_capacity(width * height);
289            for y in 0..height {
290                for x in 0..width {
291                    let byte = raw[y * bytes_per_row + x / 8];
292                    let bit = (byte >> (7 - (x % 8))) & 1;
293                    unpacked.push(bit);
294                }
295            }
296            Ok(unpacked)
297        } else {
298            Ok(raw)
299        }
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_decode_8bit() {
        // 3x3 image with varied grayscale values
        let pixels: Vec<u8> = vec![0, 64, 128, 192, 255, 1, 100, 200, 50];
        let mask = MaskData::encode(&pixels, 3, 3, 8).unwrap();

        assert_eq!(mask.width(), 3);
        assert_eq!(mask.height(), 3);
        assert_eq!(mask.bit_depth(), 8);

        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_encode_decode_1bit() {
        // 8x2 image, byte-aligned width
        let pixels: Vec<u8> = vec![
            1, 0, 1, 0, 1, 0, 1, 0, // row 0
            0, 1, 0, 1, 0, 1, 0, 1, // row 1
        ];
        let mask = MaskData::encode(&pixels, 8, 2, 1).unwrap();

        assert_eq!(mask.width(), 8);
        assert_eq!(mask.height(), 2);
        assert_eq!(mask.bit_depth(), 1);

        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_encode_decode_16bit() {
        // 2x2 image with u16 values
        let pixels: Vec<u16> = vec![0, 256, 65535, 1024];
        let mask = MaskData::encode_16bit(&pixels, 2, 2).unwrap();

        assert_eq!(mask.width(), 2);
        assert_eq!(mask.height(), 2);
        assert_eq!(mask.bit_depth(), 16);

        let decoded = mask.decode().unwrap();
        // 16-bit PNG decodes to big-endian byte pairs
        let expected: Vec<u8> = pixels.iter().flat_map(|&v| v.to_be_bytes()).collect();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_header_read_without_decode() {
        // 640x480 all-zeros: verify header reads work and PNG compresses well
        let width = 640u32;
        let height = 480u32;
        let pixels = vec![0u8; (width * height) as usize];
        let mask = MaskData::encode(&pixels, width, height, 8).unwrap();

        assert_eq!(mask.width(), width);
        assert_eq!(mask.height(), height);
        assert_eq!(mask.bit_depth(), 8);

        // PNG compression of all-zeros should be much smaller than raw pixels
        let raw_size = (width * height) as usize;
        assert!(
            mask.as_bytes().len() < raw_size,
            "PNG ({} bytes) should be smaller than raw ({} bytes)",
            mask.as_bytes().len(),
            raw_size,
        );
    }

    #[test]
    fn test_from_png_bytes() {
        // Encode, extract bytes, reconstruct, verify roundtrip
        let pixels: Vec<u8> = vec![10, 20, 30, 40, 50, 60];
        let original = MaskData::encode(&pixels, 3, 2, 8).unwrap();

        let bytes = original.into_bytes();
        let reconstructed = MaskData::from_png(bytes);

        assert_eq!(reconstructed.width(), 3);
        assert_eq!(reconstructed.height(), 2);
        assert_eq!(reconstructed.bit_depth(), 8);
        assert_eq!(reconstructed.decode().unwrap(), pixels);
    }

    #[test]
    fn test_1bit_non_aligned_width() {
        // 5x3 image: width not a multiple of 8
        let pixels: Vec<u8> = vec![
            1, 0, 1, 1, 0, // row 0
            0, 1, 0, 0, 1, // row 1
            1, 1, 1, 0, 0, // row 2
        ];
        let mask = MaskData::encode(&pixels, 5, 3, 1).unwrap();

        assert_eq!(mask.width(), 5);
        assert_eq!(mask.height(), 3);
        assert_eq!(mask.bit_depth(), 1);

        let decoded = mask.decode().unwrap();
        assert_eq!(decoded, pixels);
    }

    #[test]
    fn test_1bit_multi_byte_rows() {
        // 10x2: each row spans two packed bytes with 6 padding bits
        let pixels: Vec<u8> = (0..20).map(|i| u8::from(i % 3 == 0)).collect();
        let mask = MaskData::encode(&pixels, 10, 2, 1).unwrap();

        assert_eq!(mask.width(), 10);
        assert_eq!(mask.height(), 2);
        assert_eq!(mask.decode().unwrap(), pixels);
    }

    #[test]
    fn test_1bit_nonzero_values_become_one() {
        // Any nonzero input is stored as a set bit and decodes back to 1
        let pixels: Vec<u8> = vec![0, 2, 255, 0, 7, 1, 0, 128];
        let mask = MaskData::encode(&pixels, 8, 1, 1).unwrap();
        assert_eq!(mask.decode().unwrap(), vec![0, 1, 1, 0, 1, 1, 0, 1]);
    }

    // =========================================================================
    // from_png_checked validation tests
    // =========================================================================

    #[test]
    fn test_from_png_empty_bytes() {
        let result = MaskData::from_png_checked(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_truncated() {
        // Just the magic bytes, no IHDR
        let result = MaskData::from_png_checked(PNG_SIGNATURE.to_vec());
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_garbage() {
        let result = MaskData::from_png_checked(vec![0u8; 64]);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_wrong_color_type() {
        // Build a valid-length buffer with correct signature but wrong color type
        let mut fake_png = vec![0u8; MIN_PNG_LEN];
        fake_png[..8].copy_from_slice(&PNG_SIGNATURE);
        fake_png[25] = 2; // RGB instead of grayscale
        let result = MaskData::from_png_checked(fake_png);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_png_checked_truncated_after_header() {
        // Passes the length, signature and color-type checks but stops right
        // after the IHDR chunk, so only the decoder can reject it.
        let pixels: Vec<u8> = vec![0, 128, 255, 64];
        let mut bytes = MaskData::encode(&pixels, 2, 2, 8).unwrap().into_bytes();
        bytes.truncate(MIN_PNG_LEN);
        assert!(MaskData::from_png_checked(bytes).is_err());
    }

    #[test]
    fn test_from_png_checked_valid() {
        let pixels: Vec<u8> = vec![0, 128, 255, 64];
        let mask = MaskData::encode(&pixels, 2, 2, 8).unwrap();
        let bytes = mask.into_bytes();
        let result = MaskData::from_png_checked(bytes);
        assert!(result.is_ok());
    }

    #[test]
    fn test_is_valid() {
        let pixels: Vec<u8> = vec![0, 128, 255, 64];
        let mask = MaskData::encode(&pixels, 2, 2, 8).unwrap();
        assert!(mask.is_valid());

        let invalid = MaskData::from_png(vec![1, 2, 3]);
        assert!(!invalid.is_valid());
    }

    // =========================================================================
    // Header reads on invalid data return 0 instead of panicking
    // =========================================================================

    #[test]
    fn test_width_height_bit_depth_short_data() {
        let mask = MaskData::from_png(vec![]);
        assert_eq!(mask.width(), 0);
        assert_eq!(mask.height(), 0);
        assert_eq!(mask.bit_depth(), 0);

        let mask2 = MaskData::from_png(vec![0; 10]);
        assert_eq!(mask2.width(), 0);
        assert_eq!(mask2.height(), 0);
        assert_eq!(mask2.bit_depth(), 0);
    }

    #[test]
    fn test_decode_invalid_data_returns_error() {
        let mask = MaskData::from_png(vec![1, 2, 3]);
        assert!(mask.decode().is_err());
    }

    #[test]
    fn test_decode_truncated_png_returns_error() {
        // Cut the stream a few bytes into the image data
        let pixels = vec![7u8; 64];
        let mut bytes = MaskData::encode(&pixels, 8, 8, 8).unwrap().into_bytes();
        bytes.truncate(MIN_PNG_LEN + 10);
        assert!(MaskData::from_png(bytes).decode().is_err());
    }

    #[test]
    fn test_encode_invalid_bit_depth() {
        let result = MaskData::encode(&[0; 4], 2, 2, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_pixel_count_mismatch() {
        let result = MaskData::encode(&[0; 3], 2, 2, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_16bit_pixel_count_mismatch() {
        let result = MaskData::encode_16bit(&[0; 3], 2, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_zero_dimensions() {
        // PNG forbids zero width or height
        assert!(MaskData::encode(&[], 0, 4, 8).is_err());
        assert!(MaskData::encode(&[], 4, 0, 1).is_err());
        assert!(MaskData::encode_16bit(&[], 0, 0).is_err());
    }
}