panproto_inst/
contraction.rs1use std::collections::HashMap;
8
9use panproto_schema::Edge;
10use serde::{Deserialize, Serialize};
11use smallvec::SmallVec;
12
13#[derive(Clone, Debug, Serialize, Deserialize)]
15pub struct ContractionRecord {
16 pub original_parent: u32,
18 pub children: SmallVec<u32, 4>,
20 pub original_edge: Edge,
22}
23
24#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct ContractionTracker {
31 contracted: HashMap<u32, ContractionRecord>,
32 absorptions: HashMap<u32, Vec<u32>>,
33}
34
35impl ContractionTracker {
36 #[must_use]
38 pub fn new() -> Self {
39 Self {
40 contracted: HashMap::new(),
41 absorptions: HashMap::new(),
42 }
43 }
44
45 pub fn contract(&mut self, node_id: u32, record: ContractionRecord) {
51 let surviving = self.nearest_surviving(record.original_parent);
52 self.absorptions.entry(surviving).or_default().push(node_id);
53 self.contracted.insert(node_id, record);
54 }
55
56 pub fn expand(&mut self, node_id: u32) -> Option<ContractionRecord> {
60 let record = self.contracted.remove(&node_id)?;
61
62 let surviving = self.nearest_surviving(record.original_parent);
64 if let Some(absorbed) = self.absorptions.get_mut(&surviving) {
65 if let Some(pos) = absorbed.iter().position(|&n| n == node_id) {
66 absorbed.remove(pos);
67 }
68 if absorbed.is_empty() {
69 self.absorptions.remove(&surviving);
70 }
71 }
72
73 Some(record)
74 }
75
76 #[must_use]
78 pub fn contracted_into(&self, surviving: u32) -> &[u32] {
79 self.absorptions.get(&surviving).map_or(&[], Vec::as_slice)
80 }
81
82 #[must_use]
84 pub fn is_contracted(&self, node_id: u32) -> bool {
85 self.contracted.contains_key(&node_id)
86 }
87
88 #[must_use]
90 pub fn original_parent(&self, node_id: u32) -> Option<u32> {
91 self.contracted.get(&node_id).map(|r| r.original_parent)
92 }
93
94 fn nearest_surviving(&self, mut node: u32) -> u32 {
97 while let Some(record) = self.contracted.get(&node) {
98 node = record.original_parent;
99 }
100 node
101 }
102}
103
104impl Default for ContractionTracker {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[allow(clippy::expect_used)]
mod tests {
    use panproto_gat::Name;
    use panproto_schema::Edge;
    use smallvec::SmallVec;

    use super::*;

    /// Placeholder schema edge; its contents are irrelevant to these tests.
    fn dummy_edge() -> Edge {
        Edge {
            src: Name::from("a"),
            tgt: Name::from("b"),
            kind: Name::from("prop"),
            name: None,
        }
    }

    /// Builds a record with parent `parent` and children `kids`.
    fn record(parent: u32, kids: &[u32]) -> ContractionRecord {
        let children: SmallVec<u32, 4> = kids.iter().copied().collect();
        ContractionRecord {
            original_parent: parent,
            children,
            original_edge: dummy_edge(),
        }
    }

    #[test]
    fn contract_records_children() {
        let mut tracker = ContractionTracker::new();
        tracker.contract(5, record(1, &[10, 11]));
        assert!(tracker.is_contracted(5));

        let stored = tracker.contracted.get(&5).unwrap();
        assert_eq!(stored.original_parent, 1);
        assert_eq!(stored.children.as_slice(), &[10, 11]);
    }

    #[test]
    fn expand_undoes_contraction() {
        let mut tracker = ContractionTracker::new();
        tracker.contract(5, record(1, &[10, 11]));
        assert!(tracker.is_contracted(5));

        let restored = tracker.expand(5).unwrap();
        assert_eq!(restored.original_parent, 1);
        assert!(!tracker.is_contracted(5));
        assert!(tracker.contracted_into(1).is_empty());
    }

    #[test]
    fn contracted_into_tracks_absorptions() {
        let mut tracker = ContractionTracker::new();
        for (node, child) in [(5, 10), (6, 11)] {
            tracker.contract(node, record(1, &[child]));
        }

        let absorbed = tracker.contracted_into(1);
        assert_eq!(absorbed.len(), 2);
        assert!([5, 6].iter().all(|n| absorbed.contains(n)));
    }

    #[test]
    fn is_contracted_checks_correctly() {
        let mut tracker = ContractionTracker::new();
        tracker.contract(5, record(1, &[10]));

        assert!(tracker.is_contracted(5));
        for untouched in [1, 10, 999] {
            assert!(!tracker.is_contracted(untouched));
        }
    }

    #[test]
    fn multiple_contractions() {
        let mut tracker = ContractionTracker::new();
        tracker.contract(3, record(1, &[30, 31]));
        tracker.contract(4, record(2, &[40]));
        tracker.contract(5, record(2, &[50, 51, 52]));

        assert!([3, 4, 5].iter().all(|&n| tracker.is_contracted(n)));

        assert_eq!(tracker.contracted_into(1), &[3]);
        let into_two = tracker.contracted_into(2);
        assert!(into_two.contains(&4));
        assert!(into_two.contains(&5));

        let expected_parents = [(3, Some(1)), (4, Some(2)), (5, Some(2)), (99, None)];
        for (node, parent) in expected_parents {
            assert_eq!(tracker.original_parent(node), parent);
        }

        let restored = tracker.expand(4).unwrap();
        assert_eq!(restored.original_parent, 2);
        assert!(!tracker.is_contracted(4));
        assert_eq!(tracker.contracted_into(2), &[5]);
    }
}