use std::collections::HashMap;
use panproto_schema::Edge;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ContractionRecord {
pub original_parent: u32,
pub children: SmallVec<u32, 4>,
pub original_edge: Edge,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ContractionTracker {
contracted: HashMap<u32, ContractionRecord>,
absorptions: HashMap<u32, Vec<u32>>,
}
impl ContractionTracker {
#[must_use]
pub fn new() -> Self {
Self {
contracted: HashMap::new(),
absorptions: HashMap::new(),
}
}
pub fn contract(&mut self, node_id: u32, record: ContractionRecord) {
let surviving = self.nearest_surviving(record.original_parent);
self.absorptions.entry(surviving).or_default().push(node_id);
self.contracted.insert(node_id, record);
}
pub fn expand(&mut self, node_id: u32) -> Option<ContractionRecord> {
let record = self.contracted.remove(&node_id)?;
let surviving = self.nearest_surviving(record.original_parent);
if let Some(absorbed) = self.absorptions.get_mut(&surviving) {
if let Some(pos) = absorbed.iter().position(|&n| n == node_id) {
absorbed.remove(pos);
}
if absorbed.is_empty() {
self.absorptions.remove(&surviving);
}
}
Some(record)
}
#[must_use]
pub fn contracted_into(&self, surviving: u32) -> &[u32] {
self.absorptions.get(&surviving).map_or(&[], Vec::as_slice)
}
#[must_use]
pub fn is_contracted(&self, node_id: u32) -> bool {
self.contracted.contains_key(&node_id)
}
#[must_use]
pub fn original_parent(&self, node_id: u32) -> Option<u32> {
self.contracted.get(&node_id).map(|r| r.original_parent)
}
fn nearest_surviving(&self, mut node: u32) -> u32 {
while let Some(record) = self.contracted.get(&node) {
node = record.original_parent;
}
node
}
}
impl Default for ContractionTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[allow(clippy::expect_used)]
mod tests {
use panproto_gat::Name;
use panproto_schema::Edge;
use smallvec::SmallVec;
use super::*;
fn test_edge() -> Edge {
Edge {
src: Name::from("a"),
tgt: Name::from("b"),
kind: Name::from("prop"),
name: None,
}
}
fn make_record(parent: u32, children: &[u32]) -> ContractionRecord {
ContractionRecord {
original_parent: parent,
children: children.iter().copied().collect::<SmallVec<u32, 4>>(),
original_edge: test_edge(),
}
}
#[test]
fn contract_records_children() {
let mut tracker = ContractionTracker::new();
tracker.contract(5, make_record(1, &[10, 11]));
assert!(tracker.is_contracted(5));
let record = tracker.contracted.get(&5).unwrap();
assert_eq!(record.children.as_slice(), &[10, 11]);
assert_eq!(record.original_parent, 1);
}
#[test]
fn expand_undoes_contraction() {
let mut tracker = ContractionTracker::new();
tracker.contract(5, make_record(1, &[10, 11]));
assert!(tracker.is_contracted(5));
let record = tracker.expand(5).unwrap();
assert_eq!(record.original_parent, 1);
assert!(!tracker.is_contracted(5));
assert!(tracker.contracted_into(1).is_empty());
}
#[test]
fn contracted_into_tracks_absorptions() {
let mut tracker = ContractionTracker::new();
tracker.contract(5, make_record(1, &[10]));
tracker.contract(6, make_record(1, &[11]));
let absorbed = tracker.contracted_into(1);
assert!(absorbed.contains(&5));
assert!(absorbed.contains(&6));
assert_eq!(absorbed.len(), 2);
}
#[test]
fn is_contracted_checks_correctly() {
let mut tracker = ContractionTracker::new();
tracker.contract(5, make_record(1, &[10]));
assert!(tracker.is_contracted(5));
assert!(!tracker.is_contracted(1));
assert!(!tracker.is_contracted(10));
assert!(!tracker.is_contracted(999));
}
#[test]
fn multiple_contractions() {
let mut tracker = ContractionTracker::new();
tracker.contract(3, make_record(1, &[30, 31]));
tracker.contract(4, make_record(2, &[40]));
tracker.contract(5, make_record(2, &[50, 51, 52]));
assert!(tracker.is_contracted(3));
assert!(tracker.is_contracted(4));
assert!(tracker.is_contracted(5));
assert_eq!(tracker.contracted_into(1), &[3]);
let into_2 = tracker.contracted_into(2);
assert!(into_2.contains(&4));
assert!(into_2.contains(&5));
assert_eq!(tracker.original_parent(3), Some(1));
assert_eq!(tracker.original_parent(4), Some(2));
assert_eq!(tracker.original_parent(5), Some(2));
assert_eq!(tracker.original_parent(99), None);
let record = tracker.expand(4).unwrap();
assert_eq!(record.original_parent, 2);
assert!(!tracker.is_contracted(4));
assert_eq!(tracker.contracted_into(2), &[5]);
}
}