Skip to main content

panproto_inst/
contraction.rs

//! Incremental tracker for ancestor contraction.
//!
//! Records which nodes have been contracted (absorbed) into their nearest
//! surviving ancestor, so that individual contractions can be undone
//! without recomputing the entire contraction map.
6
7use std::collections::HashMap;
8
9use panproto_schema::Edge;
10use serde::{Deserialize, Serialize};
11use smallvec::SmallVec;
12
13/// Record of a single node contraction.
14#[derive(Clone, Debug, Serialize, Deserialize)]
15pub struct ContractionRecord {
16    /// The original parent of the contracted node.
17    pub original_parent: u32,
18    /// Children that the contracted node had before contraction.
19    pub children: SmallVec<u32, 4>,
20    /// The original edge connecting the contracted node to its parent.
21    pub original_edge: Edge,
22}
23
24/// Incremental tracker for ancestor contractions in the edit lens pipeline.
25///
26/// When a node is contracted, it is absorbed into the nearest surviving
27/// ancestor. This tracker records those contractions and supports undoing
28/// them individually.
29#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct ContractionTracker {
31    contracted: HashMap<u32, ContractionRecord>,
32    absorptions: HashMap<u32, Vec<u32>>,
33}
34
35impl ContractionTracker {
36    /// Create a new, empty contraction tracker.
37    #[must_use]
38    pub fn new() -> Self {
39        Self {
40            contracted: HashMap::new(),
41            absorptions: HashMap::new(),
42        }
43    }
44
45    /// Record a contraction of `node_id`.
46    ///
47    /// The node is absorbed into its nearest surviving ancestor, which is
48    /// determined by walking up from `record.original_parent` through any
49    /// already-contracted nodes.
50    pub fn contract(&mut self, node_id: u32, record: ContractionRecord) {
51        let surviving = self.nearest_surviving(record.original_parent);
52        self.absorptions.entry(surviving).or_default().push(node_id);
53        self.contracted.insert(node_id, record);
54    }
55
56    /// Undo a contraction, removing the record and cleaning up absorptions.
57    ///
58    /// Returns the record if the node was contracted, or `None` otherwise.
59    pub fn expand(&mut self, node_id: u32) -> Option<ContractionRecord> {
60        let record = self.contracted.remove(&node_id)?;
61
62        // Remove from absorptions of the surviving ancestor
63        let surviving = self.nearest_surviving(record.original_parent);
64        if let Some(absorbed) = self.absorptions.get_mut(&surviving) {
65            if let Some(pos) = absorbed.iter().position(|&n| n == node_id) {
66                absorbed.remove(pos);
67            }
68            if absorbed.is_empty() {
69                self.absorptions.remove(&surviving);
70            }
71        }
72
73        Some(record)
74    }
75
76    /// Which contracted nodes were absorbed by the given surviving node.
77    #[must_use]
78    pub fn contracted_into(&self, surviving: u32) -> &[u32] {
79        self.absorptions.get(&surviving).map_or(&[], Vec::as_slice)
80    }
81
82    /// Check whether a node has been contracted.
83    #[must_use]
84    pub fn is_contracted(&self, node_id: u32) -> bool {
85        self.contracted.contains_key(&node_id)
86    }
87
88    /// Return the original parent of a contracted node.
89    #[must_use]
90    pub fn original_parent(&self, node_id: u32) -> Option<u32> {
91        self.contracted.get(&node_id).map(|r| r.original_parent)
92    }
93
94    /// Walk up from a node through any contracted ancestors to find the
95    /// nearest surviving (non-contracted) ancestor.
96    fn nearest_surviving(&self, mut node: u32) -> u32 {
97        while let Some(record) = self.contracted.get(&node) {
98            node = record.original_parent;
99        }
100        node
101    }
102}
103
104impl Default for ContractionTracker {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110#[cfg(test)]
111#[allow(clippy::unwrap_used)]
112#[allow(clippy::expect_used)]
113mod tests {
114    use panproto_gat::Name;
115    use panproto_schema::Edge;
116    use smallvec::SmallVec;
117
118    use super::*;
119
120    fn test_edge() -> Edge {
121        Edge {
122            src: Name::from("a"),
123            tgt: Name::from("b"),
124            kind: Name::from("prop"),
125            name: None,
126        }
127    }
128
129    fn make_record(parent: u32, children: &[u32]) -> ContractionRecord {
130        ContractionRecord {
131            original_parent: parent,
132            children: children.iter().copied().collect::<SmallVec<u32, 4>>(),
133            original_edge: test_edge(),
134        }
135    }
136
137    #[test]
138    fn contract_records_children() {
139        let mut tracker = ContractionTracker::new();
140        tracker.contract(5, make_record(1, &[10, 11]));
141
142        assert!(tracker.is_contracted(5));
143        let record = tracker.contracted.get(&5).unwrap();
144        assert_eq!(record.children.as_slice(), &[10, 11]);
145        assert_eq!(record.original_parent, 1);
146    }
147
148    #[test]
149    fn expand_undoes_contraction() {
150        let mut tracker = ContractionTracker::new();
151        tracker.contract(5, make_record(1, &[10, 11]));
152
153        assert!(tracker.is_contracted(5));
154        let record = tracker.expand(5).unwrap();
155        assert_eq!(record.original_parent, 1);
156
157        assert!(!tracker.is_contracted(5));
158        assert!(tracker.contracted_into(1).is_empty());
159    }
160
161    #[test]
162    fn contracted_into_tracks_absorptions() {
163        let mut tracker = ContractionTracker::new();
164        // Node 5 is contracted into surviving node 1
165        tracker.contract(5, make_record(1, &[10]));
166        // Node 6 is also contracted into surviving node 1
167        tracker.contract(6, make_record(1, &[11]));
168
169        let absorbed = tracker.contracted_into(1);
170        assert!(absorbed.contains(&5));
171        assert!(absorbed.contains(&6));
172        assert_eq!(absorbed.len(), 2);
173    }
174
175    #[test]
176    fn is_contracted_checks_correctly() {
177        let mut tracker = ContractionTracker::new();
178        tracker.contract(5, make_record(1, &[10]));
179
180        assert!(tracker.is_contracted(5));
181        assert!(!tracker.is_contracted(1));
182        assert!(!tracker.is_contracted(10));
183        assert!(!tracker.is_contracted(999));
184    }
185
186    #[test]
187    fn multiple_contractions() {
188        let mut tracker = ContractionTracker::new();
189
190        // Contract 3 into 1, 4 into 2, 5 into 2
191        tracker.contract(3, make_record(1, &[30, 31]));
192        tracker.contract(4, make_record(2, &[40]));
193        tracker.contract(5, make_record(2, &[50, 51, 52]));
194
195        assert!(tracker.is_contracted(3));
196        assert!(tracker.is_contracted(4));
197        assert!(tracker.is_contracted(5));
198
199        assert_eq!(tracker.contracted_into(1), &[3]);
200        let into_2 = tracker.contracted_into(2);
201        assert!(into_2.contains(&4));
202        assert!(into_2.contains(&5));
203
204        assert_eq!(tracker.original_parent(3), Some(1));
205        assert_eq!(tracker.original_parent(4), Some(2));
206        assert_eq!(tracker.original_parent(5), Some(2));
207        assert_eq!(tracker.original_parent(99), None);
208
209        // Expand one
210        let record = tracker.expand(4).unwrap();
211        assert_eq!(record.original_parent, 2);
212        assert!(!tracker.is_contracted(4));
213        assert_eq!(tracker.contracted_into(2), &[5]);
214    }
215}