1use std::collections::BTreeSet;
5use std::path::Path;
6
7use image::{DynamicImage, ImageBuffer, Luma, open as open_dynamic};
8use npyz::{self, DType, NpyFile, TypeChar, WriterBuilder};
9
10use crate::constant;
11use crate::cv::{connected_components, find_labeled_contours};
12use crate::error::PineappleError;
13use crate::im::{Polygons, PineappleBuffer, PineappleViewBuffer};
14
/// Owned mask buffer: a `PineappleBuffer` whose pixel values are `u32` and whose
/// backing storage is a `Vec<u32>`.
///
/// The constructors defined in this file (`open`, `new_from_dynamic`,
/// `new_from_numpy`) always build it with a single channel.
pub type PineappleMask = PineappleBuffer<u32, Vec<u32>>;
45
46impl PineappleMask {
49 pub fn open<P: AsRef<Path>>(path: P) -> Result<PineappleMask, PineappleError> {
60 let extension = path
61 .as_ref()
62 .extension()
63 .and_then(|s| s.to_str())
64 .map(|s| s.to_lowercase());
65
66 if let Some(ext) = extension {
67 if ext == "npy" {
68 if let Ok(bytes) = std::fs::read(&path)
69 && let Ok(npy) = NpyFile::new(&bytes[..]) {
70 return Self::new_from_numpy(npy);
71 }
72
73 return Err(PineappleError::ImageReadError);
74 }
75
76 if constant::IMAGE_DYNAMIC_FORMATS.iter().any(|e| e == &ext) {
77 if let Ok(image) = open_dynamic(&path) {
78 return Self::new_from_dynamic(image);
79 }
80
81 return Err(PineappleError::ImageReadError);
82 }
83 }
84
85 Err(PineappleError::ImageExtensionError)
86 }
87
88 pub fn new_from_dynamic(mask: DynamicImage) -> Result<PineappleMask, PineappleError> {
105 let width = mask.width();
106 let height = mask.height();
107
108 match mask {
109 DynamicImage::ImageLuma8(buffer) => Ok(PineappleMask::new(
110 width,
111 height,
112 1,
113 buffer
114 .into_raw()
115 .into_iter()
116 .map(|pixel| pixel as u32)
117 .collect(),
118 )?),
119 DynamicImage::ImageLumaA8(buffer) => Ok(PineappleMask::new(
120 width,
121 height,
122 1,
123 buffer
124 .into_raw()
125 .chunks_exact(2)
126 .map(|pixel| pixel[0] as u32)
127 .collect(),
128 )?),
129 DynamicImage::ImageLuma16(buffer) => Ok(PineappleMask::new(
130 width,
131 height,
132 1,
133 buffer
134 .into_raw()
135 .into_iter()
136 .map(|pixel| pixel as u32)
137 .collect(),
138 )?),
139 DynamicImage::ImageLumaA16(buffer) => Ok(PineappleMask::new(
140 width,
141 height,
142 1,
143 buffer
144 .into_raw()
145 .chunks_exact(2)
146 .map(|pixel| pixel[0] as u32)
147 .collect(),
148 )?),
149 _ => Err(PineappleError::MaskError(
150 "A dynamic image mask with a valid data type was not detected.",
151 )),
152 }
153 }
154
155 pub fn new_from_numpy(npy: NpyFile<&[u8]>) -> Result<PineappleMask, PineappleError> {
172 let shape = npy.shape().to_vec();
173
174 let (h, w, c) = match shape.len() {
175 2 => (shape[0] as u32, shape[1] as u32, 1u32),
176 3 => (shape[0] as u32, shape[1] as u32, shape[2] as u32),
177 _ => {
178 return Err(PineappleError::MaskError(
179 "Numpy array masks must have an (H, W) shape.",
180 ));
181 }
182 };
183
184 if c != 1 {
185 return Err(PineappleError::MaskFormatError);
186 }
187
188 match npy.dtype() {
189 DType::Plain(x) => match (x.type_char(), x.size_field()) {
190 (TypeChar::Uint, 1) => Ok(PineappleMask::new(
191 w,
192 h,
193 1,
194 npy.into_vec()
195 .unwrap()
196 .into_iter()
197 .map(|pixel: u8| pixel as u32)
198 .collect(),
199 )?),
200 (TypeChar::Uint, 2) => Ok(PineappleMask::new(
201 w,
202 h,
203 1,
204 npy.into_vec()
205 .unwrap()
206 .into_iter()
207 .map(|pixel: u16| pixel as u32)
208 .collect(),
209 )?),
210 (TypeChar::Uint, 4) => Ok(PineappleMask::new(w, h, 1, npy.into_vec().unwrap())?),
211 _ => Err(PineappleError::MaskError(
212 "A numpy mask array with a valid data type was not detected.",
213 )),
214 },
215 _ => Err(PineappleError::MaskError(
216 "Only plain numpy mask arrays are currentled supported.",
217 )),
218 }
219 }
220}
221
222impl PineappleMask {
227 pub fn label(&mut self) -> Vec<u32> {
236 let mut labels: Vec<u32> = self
237 .as_raw()
238 .iter()
239 .filter(|&&x| x != 0)
240 .cloned()
241 .collect::<BTreeSet<u32>>()
242 .into_iter()
243 .collect();
244
245 if labels.len() == 1 {
248 self.buffer = connected_components(self.width(), self.height(), &self.buffer);
249 labels = self
250 .as_raw()
251 .iter()
252 .filter(|&&x| x != 0)
253 .cloned()
254 .collect::<BTreeSet<u32>>()
255 .into_iter()
256 .collect();
257 }
258
259 labels
260 }
261
262 pub fn polygons(&mut self) -> Result<(Vec<u32>, Polygons), PineappleError> {
264 let labels = self.label();
265 let (labels, contours) =
266 find_labeled_contours(self.width(), self.height(), &self.buffer, &labels);
267
268 Ok((labels, Polygons::new(contours)?))
269 }
270
271 pub fn crop_binary(
281 &self,
282 x: u32,
283 y: u32,
284 w: u32,
285 h: u32,
286 label: u32,
287 ) -> Result<PineappleMask, PineappleError> {
288 if x + w > self.width() || y + h > self.height() {
289 return Err(PineappleError::MaskError("Cropping coordinates out of bounds"));
290 }
291
292 let c = self.channels() as usize;
293 let orig_w = self.width() as usize;
294 let orig_buffer: &[u32] = self.buffer.as_ref();
295
296 let mut new_buffer = Vec::with_capacity((w * h * self.channels()) as usize);
297
298 for row in y..y + h {
299 let start = ((row as usize) * orig_w + (x as usize)) * c;
300 let end = start + (w as usize) * c;
301
302 new_buffer.extend(
303 orig_buffer[start..end]
304 .iter()
305 .map(|&v| if v == label { 1 } else { 0 }),
306 );
307 }
308
309 PineappleMask::new(w, h, self.channels(), new_buffer)
310 }
311}
312
/// Borrowed view over mask data: a `PineappleViewBuffer` with `u32` pixel values
/// over `Vec<u32>` storage, tied to the lifetime `'a`.
pub type PineappleMaskView<'a> = PineappleViewBuffer<'a, u32, Vec<u32>>;
317
318impl<'a> PineappleMaskView<'a> {
321 pub fn save<P: AsRef<Path>>(&'a self, path: P, label: &u32) -> Result<(), PineappleError> {
332 let extension = path
333 .as_ref()
334 .extension()
335 .and_then(|s| s.to_str())
336 .map(|s| s.to_lowercase());
337
338 if let Some(ext) = extension {
339 if ext == "npy" {
340 let mut buffer = vec![];
341 let mut writer = npyz::WriteOptions::<u8>::new()
342 .default_dtype()
343 .shape(&[self.height() as u64, self.width() as u64])
344 .writer(&mut buffer)
345 .begin_nd()
346 .map_err(|_| PineappleError::ImageWriteError)?;
347
348 for d in self.iter() {
349 if d == label {
350 writer.push(&255u8).unwrap();
351 } else {
352 writer.push(&0u8).unwrap();
353 };
354 }
355
356 writer.finish().map_err(|_| PineappleError::ImageWriteError)?;
357 std::fs::write(&path, buffer).map_err(|_| PineappleError::ImageWriteError)?;
358
359 return Ok(());
360 }
361
362 if constant::IMAGE_DYNAMIC_FORMATS.iter().any(|e| e == &ext) {
363 ImageBuffer::<Luma<u8>, Vec<u8>>::from_raw(
364 self.width() as u32,
365 self.height() as u32,
366 self.iter()
367 .map(|p| if p == label { 255u8 } else { 0u8 })
368 .collect(),
369 )
370 .unwrap()
371 .save(path)
372 .map_err(|_| PineappleError::ImageWriteError)?;
373
374 return Ok(());
375 }
376 }
377
378 Err(PineappleError::ImageExtensionError)
379 }
380}
381
/// Selects between the two masking modes, foreground and background.
///
/// NOTE(review): how each variant is interpreted is decided by the code that
/// consumes this enum, which is not visible in this file — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaskingStyle {
    /// Foreground masking mode.
    Foreground,
    /// Background masking mode.
    Background,
}
389
#[cfg(test)]
mod test {

    use super::*;

    const TEST_MASK: &str = "../data/tests/test_mask";
    const TEST_BLOB: &str = "../data/tests/test_mask_binary_blobs.png";

    /// Each fixture encoding must open as a 621x621 single-channel mask.
    #[test]
    fn test_mask_open() {
        const SUFFIXES: [&str; 8] = [
            "_binary.png",
            "_binary_1.npy",
            "_binary_1_u16.npy",
            "_binary_255.npy",
            "_binary_255_u16.npy",
            "_integer.png",
            "_integer.npy",
            "_integer_u16.npy",
        ];

        for suffix in SUFFIXES {
            let opened = PineappleMask::open(format!("{}{}", TEST_MASK, suffix));
            assert!(opened.is_ok(), "{}", suffix);

            let mask = opened.unwrap();
            assert_eq!(mask.width(), 621);
            assert_eq!(mask.height(), 621);
            assert_eq!(mask.channels(), 1, "{}", suffix);
        }
    }

    /// A saved binary view must round-trip through both PNG and NPY output.
    #[test]
    fn test_mask_save() {
        const TEST_DEFAULT: &str = "TEST_SAVE_DEFAULT_MASK.png";
        const TEST_NUMPY: &str = "TEST_SAVE_NUMPY_MASK.npy";

        let original = PineappleMask::new(2, 2, 1, vec![0, 255, 0, 0]).unwrap();

        original
            .crop_view(0, 0, 2, 2)
            .save(TEST_DEFAULT, &255)
            .unwrap();
        original
            .crop_view(0, 0, 2, 2)
            .save(TEST_NUMPY, &255)
            .unwrap();

        let from_png = PineappleMask::open(TEST_DEFAULT).unwrap();
        let from_npy = PineappleMask::open(TEST_NUMPY).unwrap();

        assert_eq!(original.as_raw(), from_png.as_raw());
        assert_eq!(original.as_raw(), from_npy.as_raw());

        std::fs::remove_file(TEST_DEFAULT).unwrap();
        std::fs::remove_file(TEST_NUMPY).unwrap();
    }

    /// The binary blob fixture is expected to yield eleven labels.
    #[test]
    fn test_label_blob() {
        let mut blobs = PineappleMask::open(TEST_BLOB).unwrap();
        assert_eq!(blobs.label().len(), 11);
    }

    /// A one-row view across the full width yields the first row in order.
    #[test]
    fn test_mask_crop() {
        let width = 10;
        let height = 10;
        let ramp: Vec<u32> = (0..width * height).collect();

        let mask = PineappleMask::new(width, height, 1, ramp).unwrap();
        let first_row = mask.crop_view(0, 0, 10, 1);

        for (expected, actual) in (0u32..).zip(first_row.iter()) {
            assert_eq!(actual, &expected);
        }
    }

    /// Five isolated pixels sharing one value are relabelled as 1 through 5.
    #[test]
    fn test_mask_label() {
        let mut pixels: Vec<u32> = vec![0u32; 100];
        for index in [5usize, 25, 45, 65, 85] {
            pixels[index] = 1u32;
        }

        let mut mask = PineappleMask::new(10, 10u32, 1, pixels).unwrap();

        let labels = mask.label();
        assert_eq!(labels.len(), 5);

        for (position, expected) in (1u32..=5).enumerate() {
            assert_eq!(labels[position], expected);
        }
    }

    /// Cropping the whole mask for label 1 marks only that pixel.
    #[test]
    fn test_mask_crop_binary() {
        let values: Vec<u32> = vec![0, 1, 2, 3];
        let mask = PineappleMask::new(2, 2, 1, values).unwrap();

        let binary = mask.crop_binary(0, 0, 2, 2, 1).unwrap();

        assert_eq!(binary.as_raw(), &[0, 1, 0, 0]);
    }
}