use std::sync::Arc;
use serde::{Deserialize, Serialize};
use vernier_mask::Rle;
use crate::error::EvalError;
/// A segmentation as it appears in COCO-style annotation JSON: either a list
/// of polygons or a run-length encoding with a declared size.
///
/// Because of `#[serde(untagged)]`, serde tries the variants in declaration
/// order (`Polygons`, then `Rle`) and keeps the first that deserializes
/// successfully, so variant order is significant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Segmentation {
/// One or more polygons, each a flat list of coordinates (presumably
/// alternating x, y pairs — TODO confirm against `Rle::from_polygons`).
Polygons(Vec<Vec<f64>>),
/// A run-length-encoded mask together with its declared size.
Rle(SegmentationRle),
}
/// Run-length-encoded segmentation: a declared mask size plus its run counts.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SegmentationRle {
/// Declared mask size as `[height, width]`. `Segmentation::to_rle` returns an
/// error unless this matches the requested image dimensions exactly.
pub size: [u32; 2],
/// Run counts, either as a compressed string or as a plain integer array.
pub counts: SegmentationRleCounts,
}
/// The `counts` payload of an RLE segmentation.
///
/// Untagged: a string deserializes as `Compressed`, and a sequence of `u32`
/// deserializes as `Uncompressed` (variants are tried in declaration order).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SegmentationRleCounts {
/// Compressed counts string; `to_rle` decodes it with `Rle::from_string_bytes`.
Compressed(String),
/// Raw run lengths, (de)serialized as a plain sequence via `arc_u32_serde`.
/// Held in an `Arc` so `to_rle` can share them with the resulting `Rle`
/// through a refcount bump instead of copying.
Uncompressed(#[serde(with = "arc_u32_serde")] Arc<[u32]>),
}
/// Serde adapter for `Arc<[u32]>` fields, referenced via
/// `#[serde(with = "arc_u32_serde")]`.
///
/// On the wire the value looks exactly like a `Vec<u32>`: a plain sequence.
mod arc_u32_serde {
    use std::sync::Arc;

    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Writes the shared slice as a sequence of its elements.
    pub(super) fn serialize<S: Serializer>(value: &Arc<[u32]>, serializer: S) -> Result<S::Ok, S::Error> {
        // Deref-coerce the Arc to the slice it owns; `[u32]` implements Serialize.
        let elements: &[u32] = value;
        elements.serialize(serializer)
    }

    /// Reads a sequence of `u32` and moves it into a new `Arc<[u32]>`.
    pub(super) fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Arc<[u32]>, D::Error> {
        let elements = Vec::<u32>::deserialize(deserializer)?;
        Ok(Arc::from(elements))
    }
}
impl Segmentation {
    /// Converts this segmentation into an [`Rle`] mask with `h` rows and `w` columns.
    ///
    /// - Polygons are rasterized with `Rle::from_polygons` at the requested size.
    /// - RLE segmentations must declare `size == [h, w]`. Compressed counts are
    ///   decoded with `Rle::from_string_bytes`. Uncompressed counts are shared
    ///   with the result (an `Arc` refcount bump, no copy) once they are checked
    ///   to cover exactly `h * w` pixels.
    ///
    /// # Errors
    ///
    /// - [`EvalError::DimensionMismatch`] if an RLE's declared size differs from
    ///   `[h, w]`, or if uncompressed counts do not sum to `h * w`.
    /// - Any error from polygon rasterization or compressed-string decoding,
    ///   converted into [`EvalError`] via `?`.
    pub fn to_rle(&self, h: u32, w: u32) -> Result<Rle, EvalError> {
        match self {
            Self::Polygons(polys) => Ok(Rle::from_polygons(polys, h, w)?),
            Self::Rle(rle) => {
                let [rh, rw] = rle.size;
                if rh != h || rw != w {
                    return Err(EvalError::DimensionMismatch {
                        detail: format!(
                            "segmentation declares size [{rh}, {rw}] but image is [{h}, {w}]"
                        ),
                    });
                }
                match &rle.counts {
                    SegmentationRleCounts::Compressed(s) => {
                        Ok(Rle::from_string_bytes(s.as_bytes(), h, w)?)
                    }
                    SegmentationRleCounts::Uncompressed(counts) => {
                        // Nothing else validates these counts on this path. A run total
                        // other than h * w would describe a mask of a different size than
                        // the `Rle { h, w, .. }` we hand back claims. Sum in u64 so large
                        // run lengths cannot overflow.
                        let covered: u64 = counts.iter().map(|&c| u64::from(c)).sum();
                        let expected = u64::from(h) * u64::from(w);
                        if covered != expected {
                            return Err(EvalError::DimensionMismatch {
                                detail: format!(
                                    "uncompressed RLE counts cover {covered} pixels but size [{h}, {w}] requires {expected}"
                                ),
                            });
                        }
                        Ok(Rle {
                            h,
                            w,
                            counts: Arc::clone(counts),
                        })
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
// Unit tests for untagged-serde shape detection of `Segmentation` and for
// `Segmentation::to_rle` on each variant.
use super::*;
// Deserializes a `Segmentation` from JSON, panicking if parsing fails.
fn parse(json: &str) -> Segmentation {
serde_json::from_str(json).unwrap()
}
// An array of coordinate arrays selects the `Polygons` variant.
#[test]
fn parses_polygon_shape() {
let s = parse("[[0.0, 0.0, 2.0, 0.0, 2.0, 2.0, 0.0, 2.0]]");
match s {
Segmentation::Polygons(p) => {
assert_eq!(p.len(), 1);
assert_eq!(p[0].len(), 8);
}
other => panic!("expected Polygons, got {other:?}"),
}
}
// A `{size, counts}` object whose counts are a string selects `Rle` with
// `Compressed` counts.
#[test]
fn parses_compressed_rle_shape() {
let s = parse(r#"{"size": [10, 10], "counts": "PPYo`0"}"#);
match s {
Segmentation::Rle(rle) => {
assert_eq!(rle.size, [10, 10]);
assert!(matches!(rle.counts, SegmentationRleCounts::Compressed(_)));
}
other => panic!("expected Rle, got {other:?}"),
}
}
// Counts given as an integer array select `Uncompressed` and keep their values.
#[test]
fn parses_uncompressed_rle_shape() {
let s = parse(r#"{"size": [2, 2], "counts": [0, 4]}"#);
match s {
Segmentation::Rle(rle) => {
assert_eq!(rle.size, [2, 2]);
match rle.counts {
SegmentationRleCounts::Uncompressed(c) => {
assert_eq!(&c[..], &[0u32, 4][..]);
}
other => panic!("expected Uncompressed, got {other:?}"),
}
}
other => panic!("expected Rle, got {other:?}"),
}
}
// Two disjoint polygons in one segmentation are rasterized together; the test
// expects a combined area of 8. NOTE(review): the `k2` suffix presumably refers
// to an external spec/design item — confirm.
#[test]
fn polygon_to_rle_rasterizes_and_unions_k2() {
let json = r#"[
[0.0, 0.0, 2.0, 0.0, 2.0, 2.0, 0.0, 2.0],
[3.0, 0.0, 5.0, 0.0, 5.0, 2.0, 3.0, 2.0]
]"#;
let s: Segmentation = serde_json::from_str(json).unwrap();
let rle = s.to_rle(8, 8).unwrap();
assert_eq!(rle.area(), 8);
}
// Encoding an `Rle` with `to_string_bytes` and decoding it through `to_rle`
// reproduces the original mask.
#[test]
fn compressed_rle_to_rle_round_trips() {
let original = Rle {
h: 4,
w: 4,
counts: vec![0u32, 4, 4, 4, 4].into(),
};
let counts = String::from_utf8(original.to_string_bytes()).unwrap();
let json = format!(r#"{{"size": [4, 4], "counts": "{counts}"}}"#);
let s: Segmentation = serde_json::from_str(&json).unwrap();
let rle = s.to_rle(4, 4).unwrap();
assert_eq!(rle, original);
}
// Uncompressed counts pass through unchanged, with the requested h and w.
#[test]
fn uncompressed_rle_to_rle_uses_counts_verbatim() {
let s = parse(r#"{"size": [2, 2], "counts": [0, 4]}"#);
let rle = s.to_rle(2, 2).unwrap();
assert_eq!(rle.h, 2);
assert_eq!(rle.w, 2);
assert_eq!(&rle.counts[..], &[0u32, 4][..]);
assert_eq!(rle.area(), 4);
}
// A declared RLE size that differs from the requested image size yields
// `DimensionMismatch`, and the message names both sizes. NOTE(review):
// `h2_corrected` presumably refers to an external design note — confirm.
#[test]
fn rle_size_mismatch_errors_h2_corrected() {
let s = parse(r#"{"size": [10, 10], "counts": [0, 100]}"#);
let err = s.to_rle(20, 20).unwrap_err();
match err {
EvalError::DimensionMismatch { detail } => {
assert!(detail.contains("[10, 10]"));
assert!(detail.contains("[20, 20]"));
}
other => panic!("expected DimensionMismatch, got {other:?}"),
}
}
// An empty polygon list still yields a mask at the requested size, with zero area.
#[test]
fn empty_polygon_list_yields_all_background_at_requested_shape() {
let s = parse("[]");
let rle = s.to_rle(4, 4).unwrap();
assert_eq!(rle.h, 4);
assert_eq!(rle.w, 4);
assert_eq!(rle.area(), 0);
}
// A polygon with only 4 coordinates is rejected, and the error surfaces as
// `EvalError::Mask`. NOTE(review): `k1` presumably names an external
// spec/design item — confirm.
#[test]
fn polygon_with_too_few_vertices_propagates_k1_error() {
let s = parse("[[0.0, 0.0, 1.0, 1.0]]");
let err = s.to_rle(8, 8).unwrap_err();
assert!(matches!(err, EvalError::Mask(_)));
}
}