pineapple_core/im/
mask.rs

1// Copyright (c) 2025, Tom Ouellette
2// Licensed under the BSD 3-Clause License
3
4use std::collections::BTreeSet;
5use std::path::Path;
6
7use image::{DynamicImage, ImageBuffer, Luma, open as open_dynamic};
8use npyz::{self, DType, NpyFile, TypeChar, WriterBuilder};
9
10use crate::constant;
11use crate::cv::{connected_components, find_labeled_contours};
12use crate::error::PineappleError;
13use crate::im::{Polygons, PineappleBuffer, PineappleViewBuffer};
14
/// A row-major container storing mask pixels
///
/// Mask pixels are always stored as `u32`. The loaders in this module accept
/// 8-bit and 16-bit inputs (and 32-bit unsigned numpy arrays) and widen them
/// to `u32` for consistency. The length of the container must be equal to the
/// product of `w` * `h`; the third argument to `new` is the channel count,
/// which is `1` for every mask built in this module.
///
/// # Examples
///
/// ```
/// use pineapple_core::im::PineappleMask;
///
/// let width = 10;
/// let height = 10;
/// let buffer = vec![0u32; (width * height) as usize];
/// let buffer = PineappleMask::new(width, height, 1, buffer);
///
/// assert_eq!(buffer.unwrap().len(), (width * height) as usize);
/// ```
///
/// ```
/// use pineapple_core::im::PineappleMask;
///
/// let width = 10;
/// let height = 10;
/// let buffer = vec![0u32; (width * height * 10) as usize];
/// let buffer = PineappleMask::new(width, height, 1, buffer);
///
/// assert!(buffer.is_err()); // Buffer size does not match dimensions
/// ```
pub type PineappleMask = PineappleBuffer<u32, Vec<u32>>;
45
46// >>> I/O METHODS
47
48impl PineappleMask {
49    /// Open a new mask from a provided path
50    ///
51    /// # Arguments
52    ///
53    /// * `path` - A path to an image with a valid extension
54    ///
55    /// ```no_run
56    /// use pineapple_core::im::PineappleMask;
57    /// let image = PineappleMask::open("mask.png");
58    /// ```
59    pub fn open<P: AsRef<Path>>(path: P) -> Result<PineappleMask, PineappleError> {
60        let extension = path
61            .as_ref()
62            .extension()
63            .and_then(|s| s.to_str())
64            .map(|s| s.to_lowercase());
65
66        if let Some(ext) = extension {
67            if ext == "npy" {
68                if let Ok(bytes) = std::fs::read(&path)
69                    && let Ok(npy) = NpyFile::new(&bytes[..]) {
70                    return Self::new_from_numpy(npy);
71                }
72
73                return Err(PineappleError::ImageReadError);
74            }
75
76            if constant::IMAGE_DYNAMIC_FORMATS.iter().any(|e| e == &ext) {
77                if let Ok(image) = open_dynamic(&path) {
78                    return Self::new_from_dynamic(image);
79                }
80
81                return Err(PineappleError::ImageReadError);
82            }
83        }
84
85        Err(PineappleError::ImageExtensionError)
86    }
87
88    /// Initialize a new mask from a DynamicImage
89    ///
90    /// # Arguments
91    ///
92    /// * `image` - An 8 or 16-bit grayscale DynamicImage
93    ///
94    /// # Examples
95    ///
96    /// ```
97    /// use image::{GrayImage, DynamicImage};
98    /// use pineapple_core::im::PineappleMask;
99    ///
100    /// let gray = GrayImage::new(10, 10);
101    /// let dynamic = DynamicImage::ImageLuma8(gray);
102    /// let image = PineappleMask::new_from_dynamic(dynamic);
103    /// ```
104    pub fn new_from_dynamic(mask: DynamicImage) -> Result<PineappleMask, PineappleError> {
105        let width = mask.width();
106        let height = mask.height();
107
108        match mask {
109            DynamicImage::ImageLuma8(buffer) => Ok(PineappleMask::new(
110                width,
111                height,
112                1,
113                buffer
114                    .into_raw()
115                    .into_iter()
116                    .map(|pixel| pixel as u32)
117                    .collect(),
118            )?),
119            DynamicImage::ImageLumaA8(buffer) => Ok(PineappleMask::new(
120                width,
121                height,
122                1,
123                buffer
124                    .into_raw()
125                    .chunks_exact(2)
126                    .map(|pixel| pixel[0] as u32)
127                    .collect(),
128            )?),
129            DynamicImage::ImageLuma16(buffer) => Ok(PineappleMask::new(
130                width,
131                height,
132                1,
133                buffer
134                    .into_raw()
135                    .into_iter()
136                    .map(|pixel| pixel as u32)
137                    .collect(),
138            )?),
139            DynamicImage::ImageLumaA16(buffer) => Ok(PineappleMask::new(
140                width,
141                height,
142                1,
143                buffer
144                    .into_raw()
145                    .chunks_exact(2)
146                    .map(|pixel| pixel[0] as u32)
147                    .collect(),
148            )?),
149            _ => Err(PineappleError::MaskError(
150                "A dynamic image mask with a valid data type was not detected.",
151            )),
152        }
153    }
154
155    /// Initialize a new image from a numpy array buffer
156    ///
157    /// # Arguments
158    ///
159    /// * `npy` - A (height, width, channel) shaped numpy array buffer
160    ///
161    /// # Examples
162    ///
163    /// ```no_run
164    /// use npyz::NpyFile;
165    /// use pineapple_core::im::PineappleMask;
166    ///
167    /// let bytes = std::fs::read("mask.npy").unwrap();
168    /// let npy = NpyFile::new(&bytes[..]).unwrap();
169    /// let image = PineappleMask::new_from_numpy(npy);
170    /// ```
171    pub fn new_from_numpy(npy: NpyFile<&[u8]>) -> Result<PineappleMask, PineappleError> {
172        let shape = npy.shape().to_vec();
173
174        let (h, w, c) = match shape.len() {
175            2 => (shape[0] as u32, shape[1] as u32, 1u32),
176            3 => (shape[0] as u32, shape[1] as u32, shape[2] as u32),
177            _ => {
178                return Err(PineappleError::MaskError(
179                    "Numpy array masks must have an (H, W) shape.",
180                ));
181            }
182        };
183
184        if c != 1 {
185            return Err(PineappleError::MaskFormatError);
186        }
187
188        match npy.dtype() {
189            DType::Plain(x) => match (x.type_char(), x.size_field()) {
190                (TypeChar::Uint, 1) => Ok(PineappleMask::new(
191                    w,
192                    h,
193                    1,
194                    npy.into_vec()
195                        .unwrap()
196                        .into_iter()
197                        .map(|pixel: u8| pixel as u32)
198                        .collect(),
199                )?),
200                (TypeChar::Uint, 2) => Ok(PineappleMask::new(
201                    w,
202                    h,
203                    1,
204                    npy.into_vec()
205                        .unwrap()
206                        .into_iter()
207                        .map(|pixel: u16| pixel as u32)
208                        .collect(),
209                )?),
210                (TypeChar::Uint, 4) => Ok(PineappleMask::new(w, h, 1, npy.into_vec().unwrap())?),
211                _ => Err(PineappleError::MaskError(
212                    "A numpy mask array with a valid data type was not detected.",
213                )),
214            },
215            _ => Err(PineappleError::MaskError(
216                "Only plain numpy mask arrays are currentled supported.",
217            )),
218        }
219    }
220}
221
222// <<< I/O METHODS
223
224// >>> TRANSFORM METHODS
225
226impl PineappleMask {
227    /// Re-label the mask using connected components and return unique labels
228    ///
229    /// # Notes
230    ///
231    /// Re-labelling is guaranteed to assign the correct number of labels when
232    /// assuming 8-connectivity. However, the labels are not guaranteed to be
233    /// incremental (e.g. 1, 2, 3, ..). This should be taken into account when
234    /// iterating over objects.
235    pub fn label(&mut self) -> Vec<u32> {
236        let mut labels: Vec<u32> = self
237            .as_raw()
238            .iter()
239            .filter(|&&x| x != 0)
240            .cloned()
241            .collect::<BTreeSet<u32>>()
242            .into_iter()
243            .collect();
244
245        // Currently, we only re-label binary masks and assume any mask
246        // with more than one unique label is an integer-labeled mask.
247        if labels.len() == 1 {
248            self.buffer = connected_components(self.width(), self.height(), &self.buffer);
249            labels = self
250                .as_raw()
251                .iter()
252                .filter(|&&x| x != 0)
253                .cloned()
254                .collect::<BTreeSet<u32>>()
255                .into_iter()
256                .collect();
257        }
258
259        labels
260    }
261
262    /// Extract polygons from a segmentation mask
263    pub fn polygons(&mut self) -> Result<(Vec<u32>, Polygons), PineappleError> {
264        let labels = self.label();
265        let (labels, contours) =
266            find_labeled_contours(self.width(), self.height(), &self.buffer, &labels);
267
268        Ok((labels, Polygons::new(contours)?))
269    }
270
271    /// Crops image while only including pixels with a specified label
272    ///
273    /// # Arguments
274    ///
275    /// * `x` - Minimum x-coordinate (left)
276    /// * `y` - Minimum y-coordinate (bottom)
277    /// * `w` - Width of crop
278    /// * `h` - Height of crop
279    /// * `label` - Only include mask pixels equal to this label
280    pub fn crop_binary(
281        &self,
282        x: u32,
283        y: u32,
284        w: u32,
285        h: u32,
286        label: u32,
287    ) -> Result<PineappleMask, PineappleError> {
288        if x + w > self.width() || y + h > self.height() {
289            return Err(PineappleError::MaskError("Cropping coordinates out of bounds"));
290        }
291
292        let c = self.channels() as usize;
293        let orig_w = self.width() as usize;
294        let orig_buffer: &[u32] = self.buffer.as_ref();
295
296        let mut new_buffer = Vec::with_capacity((w * h * self.channels()) as usize);
297
298        for row in y..y + h {
299            let start = ((row as usize) * orig_w + (x as usize)) * c;
300            let end = start + (w as usize) * c;
301
302            new_buffer.extend(
303                orig_buffer[start..end]
304                    .iter()
305                    .map(|&v| if v == label { 1 } else { 0 }),
306            );
307        }
308
309        PineappleMask::new(w, h, self.channels(), new_buffer)
310    }
311}
312
313// <<< TRANSFORM METHODS
314
/// A view type over mask pixels
///
/// Alias of [`PineappleViewBuffer`] specialised to `u32` pixels with a
/// `Vec<u32>` container, i.e. the same pixel and container types used by
/// `PineappleMask`.
pub type PineappleMaskView<'a> = PineappleViewBuffer<'a, u32, Vec<u32>>;
317
// >>> I/O METHODS
319
320impl<'a> PineappleMaskView<'a> {
321    /// Save an object
322    ///
323    /// # Arguments
324    ///
325    /// * `path` - A path to an image with a valid extension
326    ///
327    /// ```no_run
328    /// use pineapple_core::im::PineappleImage;
329    /// let image = PineappleImage::open("image.png");
330    /// ```
331    pub fn save<P: AsRef<Path>>(&'a self, path: P, label: &u32) -> Result<(), PineappleError> {
332        let extension = path
333            .as_ref()
334            .extension()
335            .and_then(|s| s.to_str())
336            .map(|s| s.to_lowercase());
337
338        if let Some(ext) = extension {
339            if ext == "npy" {
340                let mut buffer = vec![];
341                let mut writer = npyz::WriteOptions::<u8>::new()
342                    .default_dtype()
343                    .shape(&[self.height() as u64, self.width() as u64])
344                    .writer(&mut buffer)
345                    .begin_nd()
346                    .map_err(|_| PineappleError::ImageWriteError)?;
347
348                for d in self.iter() {
349                    if d == label {
350                        writer.push(&255u8).unwrap();
351                    } else {
352                        writer.push(&0u8).unwrap();
353                    };
354                }
355
356                writer.finish().map_err(|_| PineappleError::ImageWriteError)?;
357                std::fs::write(&path, buffer).map_err(|_| PineappleError::ImageWriteError)?;
358
359                return Ok(());
360            }
361
362            if constant::IMAGE_DYNAMIC_FORMATS.iter().any(|e| e == &ext) {
363                ImageBuffer::<Luma<u8>, Vec<u8>>::from_raw(
364                    self.width() as u32,
365                    self.height() as u32,
366                    self.iter()
367                        .map(|p| if p == label { 255u8 } else { 0u8 })
368                        .collect(),
369                )
370                .unwrap()
371                .save(path)
372                .map_err(|_| PineappleError::ImageWriteError)?;
373
374                return Ok(());
375            }
376        }
377
378        Err(PineappleError::ImageExtensionError)
379    }
380}
381
382// <<< I/O METHODS
383
/// Type of masking style to use
///
/// NOTE(review): nothing in this file consumes the enum, so the exact effect
/// of each variant is defined by its callers — confirm there before relying
/// on the variant descriptions below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaskingStyle {
    /// Foreground masking style (presumably targets the labelled object
    /// pixels — TODO confirm against callers).
    Foreground,
    /// Background masking style (presumably targets the pixels outside the
    /// labelled object — TODO confirm against callers).
    Background,
}
389
#[cfg(test)]
mod test {

    use super::*;

    const TEST_MASK: &str = "../data/tests/test_mask";
    const TEST_BLOB: &str = "../data/tests/test_mask_binary_blobs.png";

    /// Every supported fixture (png / npy, u8 / u16, binary / integer) opens
    /// as a single-channel 621x621 mask.
    #[test]
    fn test_mask_open() {
        let extensions = [
            "_binary.png",
            "_binary_1.npy",
            "_binary_1_u16.npy",
            "_binary_255.npy",
            "_binary_255_u16.npy",
            "_integer.png",
            "_integer.npy",
            "_integer_u16.npy",
        ];

        for ext in extensions {
            let img = PineappleMask::open(format!("{}{}", TEST_MASK, ext));
            assert!(img.is_ok(), "{}", ext);

            let img = img.unwrap();
            assert_eq!(img.width(), 621);
            assert_eq!(img.height(), 621);
            assert_eq!(img.channels(), 1, "{}", ext);
        }
    }

    /// Saving a view and re-opening it round-trips the pixel data for both
    /// the image and numpy writers.
    #[test]
    fn test_mask_save() {
        // Write into the system temp directory with a per-process name so the
        // source tree is never polluted and parallel runs do not collide.
        let dir = std::env::temp_dir();
        let pid = std::process::id();
        let path_default = dir.join(format!("pineapple_test_save_default_mask_{pid}.png"));
        let path_numpy = dir.join(format!("pineapple_test_save_numpy_mask_{pid}.npy"));

        let mask = PineappleMask::new(2, 2, 1, vec![0, 255, 0, 0]).unwrap();

        mask.crop_view(0, 0, 2, 2).save(&path_default, &255).unwrap();
        mask.crop_view(0, 0, 2, 2).save(&path_numpy, &255).unwrap();

        let mask_default = PineappleMask::open(&path_default);
        let mask_numpy = PineappleMask::open(&path_numpy);

        // Clean up before asserting so a failed assertion leaves no files.
        std::fs::remove_file(&path_default).unwrap();
        std::fs::remove_file(&path_numpy).unwrap();

        let mask_default = mask_default.unwrap();
        let mask_numpy = mask_numpy.unwrap();

        assert_eq!(mask.as_raw(), mask_default.as_raw());
        assert_eq!(mask.as_raw(), mask_numpy.as_raw());
    }

    /// The binary blob fixture contains 11 connected components.
    #[test]
    fn test_label_blob() {
        let mut mask = PineappleMask::open(TEST_BLOB).unwrap();
        let labels = mask.label();
        assert_eq!(labels.len(), 11);
    }

    /// A one-row view over an index-valued mask yields the column indices.
    #[test]
    fn test_mask_crop() {
        let width = 10;
        let height = 10;
        let data: Vec<u32> = (0..width * height).collect();

        let buffer = PineappleMask::new(width, height, 1, data).unwrap();
        let crop = buffer.crop_view(0, 0, 10, 1);

        for (i, col) in crop.iter().enumerate() {
            assert_eq!(col, &(i as u32));
        }
    }

    /// Five isolated pixels in a binary mask are re-labelled 1 through 5.
    #[test]
    fn test_mask_label() {
        let width = 10;
        let height = 10;
        let mut data: Vec<u32> = vec![0u32; 100];

        // One pixel in column 5 of every other row: no two are adjacent.
        for index in [5, 25, 45, 65, 85] {
            data[index] = 1u32;
        }

        let mut buffer = PineappleMask::new(width, height, 1, data).unwrap();

        let labels = buffer.label();
        assert_eq!(labels, vec![1, 2, 3, 4, 5]);
    }

    /// Only pixels equal to the requested label survive as `1`.
    #[test]
    fn test_mask_crop_binary() {
        let width = 2;
        let height = 2;
        let data: Vec<u32> = vec![0, 1, 2, 3];

        let buffer = PineappleMask::new(width, height, 1, data).unwrap();

        let binary = buffer.crop_binary(0, 0, 2, 2, 1).unwrap();

        assert_eq!(binary.as_raw(), &[0, 1, 0, 0]);
    }
}