use crate::error::{Result, VisionError};
use scirs2_core::ndarray::{Array2, Array3};
use std::collections::HashMap;
/// One segment of a panoptic id map: its id, category, thing/stuff flag,
/// pixel area and an optional display colour.
#[derive(Debug, Clone, PartialEq)]
pub struct PanopticSegment {
    /// Id of the segment as written into the panoptic map.
    pub id: u64,
    /// Semantic category index.
    pub category: usize,
    /// `true` for thing classes, `false` for stuff classes.
    pub is_thing: bool,
    /// Number of pixels belonging to the segment.
    pub area: usize,
    /// Optional RGB colour; `None` unless set through [`PanopticSegment::with_color`].
    pub color: Option<(u8, u8, u8)>,
}

impl PanopticSegment {
    /// Creates a segment with no colour assigned.
    pub fn new(id: u64, category: usize, is_thing: bool, area: usize) -> Self {
        PanopticSegment {
            color: None,
            area,
            is_thing,
            category,
            id,
        }
    }

    /// Builder-style setter: returns the segment with `color` attached.
    pub fn with_color(self, color: (u8, u8, u8)) -> Self {
        PanopticSegment {
            color: Some(color),
            ..self
        }
    }
}
/// A prediction/ground-truth pair accepted as a match by `panoptic_quality`.
///
/// Only the IoU is kept because it is all that the SQ and PQ sums consume.
#[derive(Debug, Clone)]
struct MatchedPair {
/// Intersection-over-union of the two matched segments.
iou: f64,
}
/// Computes panoptic quality `(PQ, SQ, RQ)` between a predicted and a
/// ground-truth panoptic id map.
///
/// Pixel conventions:
/// * `0` in `gt_map` is void. Such pixels are ignored entirely, so a prediction
///   painted over void neither helps nor hurts.
/// * `0` in `pred_map` is "unlabelled". It belongs to no predicted segment but
///   still counts toward the area of the ground-truth segment beneath it, so
///   missed pixels lower that segment's IoU.
///
/// Matching:
/// * A predicted segment can only pair with the ground-truth segment that
///   carries the *same id*.
/// * The pair is considered only when `pred_segs` and `gt_segs` list that id
///   with the same `category`.
/// * A pair is a true positive when its IoU is strictly greater than 0.5.
/// * Listed segments left unmatched are false positives (prediction) or false
///   negatives (ground truth). Id `0` is never counted.
/// * All categories are pooled into a single score.
///
/// NOTE(review): standard PQ pairs segments by category regardless of id. The
/// id-based correspondence kept here only makes sense when both maps share an
/// id scheme — confirm this is intended.
///
/// # Returns
/// `(pq, sq, rq)`. With no TP, FP or FN at all, every value is `1.0`. With no
/// TP but some errors, `sq` is `0.0`, following the COCO panoptic reference.
///
/// # Errors
/// [`VisionError::InvalidParameter`] if the two maps differ in shape.
pub fn panoptic_quality(
    pred_map: &Array2<i64>,
    gt_map: &Array2<i64>,
    pred_segs: &[PanopticSegment],
    gt_segs: &[PanopticSegment],
) -> Result<(f64, f64, f64)> {
    let (ph, pw) = pred_map.dim();
    let (gh, gw) = gt_map.dim();
    if ph != gh || pw != gw {
        return Err(VisionError::InvalidParameter(format!(
            "panoptic_quality: pred_map shape {}×{} != gt_map shape {}×{}",
            ph, pw, gh, gw
        )));
    }

    // IoU must be *strictly* above this for a match. That strictness is what
    // guarantees each segment matches at most once in the PQ definition.
    const MATCH_THRESHOLD: f64 = 0.5;

    // id -> category for every listed segment; the key sets double as the id sets.
    let pred_cat: HashMap<i64, usize> =
        pred_segs.iter().map(|s| (s.id as i64, s.category)).collect();
    let gt_cat: HashMap<i64, usize> =
        gt_segs.iter().map(|s| (s.id as i64, s.category)).collect();

    let mut intersect: HashMap<(i64, i64), u64> = HashMap::new();
    let mut pred_counts: HashMap<i64, u64> = HashMap::new();
    let mut gt_counts: HashMap<i64, u64> = HashMap::new();
    // The shapes are equal, so the two logical (row-major) iterations line up.
    for (&p, &g) in pred_map.iter().zip(gt_map.iter()) {
        if g == 0 {
            // Ground-truth void: excluded from every area and intersection.
            continue;
        }
        *gt_counts.entry(g).or_insert(0) += 1;
        if p == 0 {
            // Unlabelled prediction: a miss that still enlarges the gt area.
            continue;
        }
        *pred_counts.entry(p).or_insert(0) += 1;
        *intersect.entry((p, g)).or_insert(0) += 1;
    }

    let mut candidate_pairs: Vec<((i64, i64), f64)> = intersect
        .iter()
        .filter_map(|(&(p, g), &inter)| {
            if p != g {
                return None;
            }
            // Both ids must be listed and must agree on category.
            match (pred_cat.get(&p), gt_cat.get(&g)) {
                (Some(pc), Some(gc)) if pc == gc => {}
                _ => return None,
            }
            let pred_area = pred_counts.get(&p).copied().unwrap_or(0) as f64;
            let gt_area = gt_counts.get(&g).copied().unwrap_or(0) as f64;
            let union = pred_area + gt_area - inter as f64;
            if union <= 0.0 {
                return None;
            }
            Some(((p, g), inter as f64 / union))
        })
        .collect();
    // Highest IoU first, so the greedy pass below can stop at the threshold.
    candidate_pairs.sort_by(|a, b| b.1.total_cmp(&a.1));

    let mut matches: Vec<MatchedPair> = Vec::new();
    let mut matched_pred: std::collections::HashSet<i64> = std::collections::HashSet::new();
    let mut matched_gt: std::collections::HashSet<i64> = std::collections::HashSet::new();
    for ((p, g), iou) in candidate_pairs {
        if iou <= MATCH_THRESHOLD {
            break;
        }
        if matched_pred.contains(&p) || matched_gt.contains(&g) {
            continue;
        }
        matched_pred.insert(p);
        matched_gt.insert(g);
        matches.push(MatchedPair { iou });
    }

    let tp = matches.len() as f64;
    let fp = pred_cat
        .keys()
        .filter(|&&id| id != 0 && !matched_pred.contains(&id))
        .count() as f64;
    let fn_ = gt_cat
        .keys()
        .filter(|&&id| id != 0 && !matched_gt.contains(&id))
        .count() as f64;
    let iou_sum: f64 = matches.iter().map(|m| m.iou).sum();
    let denom = tp + 0.5 * fp + 0.5 * fn_;
    if denom <= 0.0 {
        // Nothing to evaluate on either side: report a perfect score.
        return Ok((1.0, 1.0, 1.0));
    }
    let pq = iou_sum / denom;
    let rq = tp / denom;
    // With no true positives SQ is undefined; 0.0 avoids suggesting perfect quality.
    let sq = if tp > 0.0 { iou_sum / tp } else { 0.0 };
    Ok((pq, sq, rq))
}
/// Fuses a semantic label map and an instance id map into a panoptic id map.
///
/// Let `c` be a pixel's semantic label and `i` its instance id:
/// * The pixel gets panoptic id `c * 1000 + i` when `c` is in `thing_classes`
///   and `i` is positive.
/// * Otherwise (stuff classes, or thing pixels with no positive instance id) it
///   gets `c * 1000`.
///
/// As a result, stuff class `0` maps to id `0`.
///
/// # Returns
/// The panoptic map plus one [`PanopticSegment`] per distinct id, sorted by id.
/// Each segment carries its pixel area, category and thing flag.
///
/// # Errors
/// [`VisionError::InvalidParameter`] if:
/// * the maps differ in shape;
/// * a semantic label is negative (it has no valid `usize` category);
/// * a thing pixel carries an instance id ≥ 1000, which would collide with
///   the id range of the next category.
pub fn merge_semantic_instance(
    semantic_map: &Array2<i32>,
    instance_map: &Array2<i32>,
    thing_classes: &[usize],
) -> Result<(Array2<i64>, Vec<PanopticSegment>)> {
    let (sh, sw) = semantic_map.dim();
    let (ih, iw) = instance_map.dim();
    if sh != ih || sw != iw {
        return Err(VisionError::InvalidParameter(format!(
            "merge_semantic_instance: semantic_map shape {}×{} != instance_map shape {}×{}",
            sh, sw, ih, iw
        )));
    }
    const INSTANCE_MULTIPLIER: i64 = 1000;
    let thing_set: std::collections::HashSet<usize> = thing_classes.iter().copied().collect();
    let mut panoptic_map = Array2::<i64>::zeros((sh, sw));
    // pan_id -> (area, category, is_thing). Category and flag are fixed by the id.
    let mut segment_info: HashMap<i64, (usize, usize, bool)> = HashMap::new();
    for ((y, x), &label) in semantic_map.indexed_iter() {
        if label < 0 {
            return Err(VisionError::InvalidParameter(format!(
                "merge_semantic_instance: negative semantic label {} at ({}, {})",
                label, y, x
            )));
        }
        let cat = label as usize;
        let inst = i64::from(instance_map[[y, x]]);
        let is_thing = thing_set.contains(&cat);
        let pan_id = if is_thing && inst > 0 {
            if inst >= INSTANCE_MULTIPLIER {
                return Err(VisionError::InvalidParameter(format!(
                    "merge_semantic_instance: instance id {} at ({}, {}) must be < {}",
                    inst, y, x, INSTANCE_MULTIPLIER
                )));
            }
            cat as i64 * INSTANCE_MULTIPLIER + inst
        } else {
            cat as i64 * INSTANCE_MULTIPLIER
        };
        panoptic_map[[y, x]] = pan_id;
        segment_info.entry(pan_id).or_insert((0, cat, is_thing)).0 += 1;
    }
    let mut segments: Vec<PanopticSegment> = segment_info
        .into_iter()
        .map(|(pan_id, (area, cat, is_thing))| {
            PanopticSegment::new(pan_id as u64, cat, is_thing, area)
        })
        .collect();
    segments.sort_by_key(|s| s.id);
    Ok((panoptic_map, segments))
}
/// A single instance hypothesis: a mask, its category, a confidence score and
/// a thing/stuff flag.
#[derive(Debug, Clone)]
pub struct InstancePrediction {
    /// Per-pixel membership; non-zero entries belong to the instance.
    pub mask: Array2<u8>,
    /// Semantic category index.
    pub category: usize,
    /// Confidence score, used to rank overlapping instances.
    pub score: f32,
    /// `true` for thing classes, `false` for stuff classes.
    pub is_thing: bool,
}

impl InstancePrediction {
    /// Bundles the four fields into a prediction without any validation.
    pub fn new(mask: Array2<u8>, category: usize, score: f32, is_thing: bool) -> Self {
        InstancePrediction {
            is_thing,
            score,
            category,
            mask,
        }
    }
}
/// Paints instance predictions into a panoptic id map, highest score first.
///
/// Visiting order:
/// * Instances are visited in descending `score` order.
/// * NaN scores are treated as the lowest, and ties keep their input order.
///
/// Id assignment:
/// * Each instance receives id `category * 1000 + rank`.
/// * `rank` counts the instances of that category visited so far, starting at 1.
/// * Instances that end up claiming no pixels still consume a rank.
///
/// An instance claims only the pixels where its mask is non-zero and no
/// earlier (higher-scoring) instance has already claimed the pixel.
///
/// # Returns
/// The panoptic map (`0` = unclaimed) and one [`PanopticSegment`] per instance
/// that claimed at least one pixel, in visiting order.
///
/// # Errors
/// [`VisionError::InvalidParameter`] if:
/// * any mask is not `height × width`;
/// * an instance whose rank within its category is 1000 or more claims
///   pixels (its id would collide with the next category's id range).
pub fn instance_to_panoptic(
    instances: &[InstancePrediction],
    height: usize,
    width: usize,
) -> Result<(Array2<i64>, Vec<PanopticSegment>)> {
    for (i, inst) in instances.iter().enumerate() {
        let (mh, mw) = inst.mask.dim();
        if mh != height || mw != width {
            return Err(VisionError::InvalidParameter(format!(
                "instance_to_panoptic: instance {} mask shape {}×{} != expected {}×{}",
                i, mh, mw, height, width
            )));
        }
    }
    const INSTANCE_MULTIPLIER: i64 = 1000;

    // Map NaN to -inf so the comparator is a genuine total order and
    // NaN-scored instances are painted last. A NaN that compares "Equal" to
    // everything breaks transitivity, which `sort_by` is allowed to panic on.
    let score_of = |i: usize| {
        let s = instances[i].score;
        if s.is_nan() {
            f32::NEG_INFINITY
        } else {
            s
        }
    };
    let mut order: Vec<usize> = (0..instances.len()).collect();
    // Stable sort: equal scores keep their input order.
    order.sort_by(|&a, &b| {
        score_of(b)
            .partial_cmp(&score_of(a))
            .expect("scores are NaN-free after mapping")
    });

    let mut panoptic_map = Array2::<i64>::zeros((height, width));
    let mut segments: Vec<PanopticSegment> = Vec::new();
    let mut cat_counters: HashMap<usize, i64> = HashMap::new();
    for idx in order {
        let inst = &instances[idx];
        let rank = {
            let counter = cat_counters.entry(inst.category).or_insert(0);
            *counter += 1;
            *counter
        };
        let pan_id = inst.category as i64 * INSTANCE_MULTIPLIER + rank;
        let mut area = 0usize;
        // Shapes were validated above, so the logical iteration orders line up.
        for (dst, &m) in panoptic_map.iter_mut().zip(inst.mask.iter()) {
            if m != 0 && *dst == 0 {
                *dst = pan_id;
                area += 1;
            }
        }
        if area == 0 {
            // Empty mask or fully occluded by higher-scoring instances.
            continue;
        }
        if rank >= INSTANCE_MULTIPLIER {
            return Err(VisionError::InvalidParameter(format!(
                "instance_to_panoptic: category {} has more than {} instances; panoptic ids would collide",
                inst.category,
                INSTANCE_MULTIPLIER - 1
            )));
        }
        segments.push(PanopticSegment::new(
            pan_id as u64,
            inst.category,
            inst.is_thing,
            area,
        ));
    }
    Ok((panoptic_map, segments))
}
/// Renders a panoptic id map as an `h × w × 3` RGB image.
///
/// Each non-zero id is scrambled with one 64-bit LCG step (wrapping
/// multiply-add). Bits 16..24, 8..16 and 0..8 of the result become the R, G
/// and B bytes. Pixels with id `0` stay black.
pub fn colorize_panoptic(panoptic_map: &Array2<i64>) -> Array3<u8> {
    let (rows, cols) = panoptic_map.dim();
    let mut rgb = Array3::<u8>::zeros((rows, cols, 3));
    for ((row, col), &seg_id) in panoptic_map.indexed_iter() {
        if seg_id != 0 {
            let mixed = seg_id
                .wrapping_mul(6364136223846793005_i64)
                .wrapping_add(1442695040888963407);
            for (channel, shift) in [16u32, 8, 0].iter().enumerate() {
                rgb[[row, col, channel]] = ((mixed >> *shift) & 0xFF) as u8;
            }
        }
    }
    rgb
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `h × w` id map holding the same value everywhere.
    fn make_map(h: usize, w: usize, val: i64) -> Array2<i64> {
        Array2::from_elem((h, w), val)
    }

    /// True when `a` and `b` agree to within 1e-9.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9
    }

    #[test]
    fn test_panoptic_segment_construction() {
        let built = PanopticSegment::new(42, 3, true, 100).with_color((255, 0, 0));
        let expected = PanopticSegment {
            id: 42,
            category: 3,
            is_thing: true,
            area: 100,
            color: Some((255, 0, 0)),
        };
        assert_eq!(built, expected);
    }

    #[test]
    fn test_pq_perfect_match() {
        // Left half id 1, right half id 2, compared against itself.
        let map = Array2::from_shape_fn((4, 4), |(_, x)| if x < 2 { 1i64 } else { 2i64 });
        let segs = [
            PanopticSegment::new(1, 0, true, 8),
            PanopticSegment::new(2, 1, true, 8),
        ];
        let (pq, sq, rq) = panoptic_quality(&map, &map, &segs, &segs).expect("pq failed");
        assert!(close(pq, 1.0), "PQ={}", pq);
        assert!(close(sq, 1.0), "SQ={}", sq);
        assert!(close(rq, 1.0), "RQ={}", rq);
    }

    #[test]
    fn test_pq_no_overlap() {
        // Prediction and ground truth use different ids, so nothing is matched.
        let pred_map = make_map(4, 4, 1);
        let truth_map = make_map(4, 4, 2);
        let pred_segs = [PanopticSegment::new(1, 0, true, 16)];
        let truth_segs = [PanopticSegment::new(2, 0, true, 16)];
        let (pq, _, rq) = panoptic_quality(&pred_map, &truth_map, &pred_segs, &truth_segs)
            .expect("pq failed");
        assert!(close(pq, 0.0), "PQ should be 0, got {}", pq);
        assert!(close(rq, 0.0), "RQ should be 0, got {}", rq);
    }

    #[test]
    fn test_pq_shape_mismatch() {
        let pred = make_map(4, 4, 1);
        let truth = make_map(5, 4, 1);
        assert!(panoptic_quality(&pred, &truth, &[], &[]).is_err());
    }

    #[test]
    fn test_pq_empty_segs() {
        // All-void maps and no segments: the metric falls back to a perfect score.
        let pred = Array2::<i64>::zeros((4, 4));
        let truth = Array2::<i64>::zeros((4, 4));
        let (pq, _, _) = panoptic_quality(&pred, &truth, &[], &[]).expect("pq failed");
        assert!(close(pq, 1.0), "PQ={}", pq);
    }

    #[test]
    fn test_merge_basic() {
        // One thing class covering the image, split into two instances by column.
        let semantic = Array2::<i32>::from_elem((4, 4), 1);
        let instance = Array2::<i32>::from_shape_fn((4, 4), |(_, x)| if x < 2 { 1 } else { 2 });
        let (pan_map, segs) =
            merge_semantic_instance(&semantic, &instance, &[1]).expect("merge failed");
        assert_eq!(pan_map.dim(), (4, 4));
        assert_eq!(pan_map[[0, 0]], 1001);
        assert_eq!(pan_map[[0, 2]], 1002);
        assert!(!segs.is_empty());
    }

    #[test]
    fn test_merge_stuff_class() {
        // A stuff class ignores instance ids, so every pixel lands in one segment.
        let semantic = Array2::<i32>::from_elem((3, 3), 5);
        let instance = Array2::<i32>::from_shape_fn((3, 3), |(y, _)| y as i32 + 1);
        let (pan_map, segs) =
            merge_semantic_instance(&semantic, &instance, &[]).expect("merge stuff failed");
        pan_map
            .iter()
            .for_each(|v| assert_eq!(*v, 5000, "unexpected stuff pan_id {}", v));
        assert_eq!(segs.len(), 1);
    }

    #[test]
    fn test_merge_shape_mismatch() {
        let semantic = Array2::<i32>::zeros((3, 3));
        let instance = Array2::<i32>::zeros((4, 3));
        assert!(merge_semantic_instance(&semantic, &instance, &[]).is_err());
    }

    #[test]
    fn test_instance_to_panoptic_basic() {
        // Two disjoint two-pixel masks of the same category.
        let top_left = Array2::<u8>::from_shape_fn((4, 4), |(y, x)| (y == 0 && x < 2) as u8);
        let bottom_right = Array2::<u8>::from_shape_fn((4, 4), |(y, x)| (y == 3 && x >= 2) as u8);
        let instances = vec![
            InstancePrediction::new(top_left, 2, 0.9, true),
            InstancePrediction::new(bottom_right, 2, 0.8, true),
        ];
        let (pan_map, segs) = instance_to_panoptic(&instances, 4, 4).expect("i2p failed");
        assert_eq!(pan_map.dim(), (4, 4));
        let first = pan_map[[0, 0]];
        let second = pan_map[[3, 3]];
        assert_ne!(first, 0);
        assert_ne!(second, 0);
        assert_ne!(first, second);
        assert_eq!(segs.len(), 2);
    }

    #[test]
    fn test_instance_to_panoptic_overlap_resolved() {
        // Identical full-image masks: the higher score claims every pixel.
        let full_mask = Array2::<u8>::from_elem((4, 4), 1u8);
        let instances = vec![
            InstancePrediction::new(full_mask.clone(), 0, 0.9, true),
            InstancePrediction::new(full_mask, 0, 0.5, true),
        ];
        let (pan_map, segs) = instance_to_panoptic(&instances, 4, 4).expect("i2p failed");
        assert_eq!(segs.len(), 1, "second instance should have area 0");
        pan_map.iter().for_each(|v| assert_ne!(*v, 0));
    }

    #[test]
    fn test_instance_to_panoptic_mask_size_mismatch() {
        let instances = vec![InstancePrediction::new(
            Array2::<u8>::zeros((3, 3)),
            0,
            0.9,
            true,
        )];
        assert!(instance_to_panoptic(&instances, 4, 4).is_err());
    }

    #[test]
    fn test_colorize_shape() {
        let color = colorize_panoptic(&make_map(6, 8, 1001));
        assert_eq!(color.dim(), (6, 8, 3));
    }

    #[test]
    fn test_colorize_background_black() {
        let color = colorize_panoptic(&Array2::<i64>::zeros((4, 4)));
        assert_eq!(color.dim(), (4, 4, 3));
        assert!(color.iter().all(|&c| c == 0));
    }
}