1use crate::tree::Tree;
8use core::hash::Hash;
9use indextree::NodeId;
10use std::collections::{HashMap, HashSet};
11
12#[derive(Debug, Default)]
14pub struct Matching {
15 a_to_b: HashMap<NodeId, NodeId>,
17 b_to_a: HashMap<NodeId, NodeId>,
19}
20
21impl Matching {
22 pub fn new() -> Self {
24 Self::default()
25 }
26
27 pub fn add(&mut self, a: NodeId, b: NodeId) {
29 self.a_to_b.insert(a, b);
30 self.b_to_a.insert(b, a);
31 }
32
33 pub fn contains_a(&self, a: NodeId) -> bool {
35 self.a_to_b.contains_key(&a)
36 }
37
38 pub fn contains_b(&self, b: NodeId) -> bool {
40 self.b_to_a.contains_key(&b)
41 }
42
43 pub fn get_b(&self, a: NodeId) -> Option<NodeId> {
45 self.a_to_b.get(&a).copied()
46 }
47
48 pub fn get_a(&self, b: NodeId) -> Option<NodeId> {
50 self.b_to_a.get(&b).copied()
51 }
52
53 pub fn pairs(&self) -> impl Iterator<Item = (NodeId, NodeId)> + '_ {
55 self.a_to_b.iter().map(|(&a, &b)| (a, b))
56 }
57
58 pub fn len(&self) -> usize {
60 self.a_to_b.len()
61 }
62
63 pub fn is_empty(&self) -> bool {
65 self.a_to_b.is_empty()
66 }
67}
68
/// Tunable parameters for [`compute_matching`].
#[derive(Clone, Debug)]
pub struct MatchingConfig {
    /// Minimum dice score an inner-node pair must reach in the bottom-up
    /// phase to be accepted as a match.
    pub similarity_threshold: f64,

    /// Nodes of tree A whose height is below this value are skipped by the
    /// top-down phase.
    pub min_height: usize,
}
80
81impl Default for MatchingConfig {
82 fn default() -> Self {
83 Self {
84 similarity_threshold: 0.5,
85 min_height: 1,
86 }
87 }
88}
89
90pub fn compute_matching<K, L>(
92 tree_a: &Tree<K, L>,
93 tree_b: &Tree<K, L>,
94 config: &MatchingConfig,
95) -> Matching
96where
97 K: Clone + Eq + Hash,
98 L: Clone,
99{
100 let mut matching = Matching::new();
101
102 top_down_phase(tree_a, tree_b, &mut matching, config);
104
105 bottom_up_phase(tree_a, tree_b, &mut matching, config);
107
108 matching
109}
110
111fn top_down_phase<K, L>(
117 tree_a: &Tree<K, L>,
118 tree_b: &Tree<K, L>,
119 matching: &mut Matching,
120 config: &MatchingConfig,
121) where
122 K: Clone + Eq + Hash,
123 L: Clone,
124{
125 let mut b_by_hash: HashMap<u64, Vec<NodeId>> = HashMap::new();
127 for b_id in tree_b.iter() {
128 let hash = tree_b.get(b_id).hash;
129 b_by_hash.entry(hash).or_default().push(b_id);
130 }
131
132 let mut candidates: Vec<(NodeId, NodeId)> = vec![(tree_a.root, tree_b.root)];
135
136 candidates.sort_by(|a, b| {
138 let ha = tree_a.height(a.0);
139 let hb = tree_a.height(b.0);
140 hb.cmp(&ha)
141 });
142
143 while let Some((a_id, b_id)) = candidates.pop() {
144 if matching.contains_a(a_id) || matching.contains_b(b_id) {
146 continue;
147 }
148
149 let a_data = tree_a.get(a_id);
150 let b_data = tree_b.get(b_id);
151
152 if tree_a.height(a_id) < config.min_height {
154 continue;
155 }
156
157 if a_data.hash == b_data.hash && a_data.kind == b_data.kind {
159 match_subtrees(tree_a, tree_b, a_id, b_id, matching);
160 } else {
161 for a_child in tree_a.children(a_id) {
163 let a_child_data = tree_a.get(a_child);
164
165 if let Some(b_candidates) = b_by_hash.get(&a_child_data.hash) {
167 for &b_candidate in b_candidates {
168 if !matching.contains_b(b_candidate) {
169 candidates.push((a_child, b_candidate));
170 }
171 }
172 }
173
174 for b_child in tree_b.children(b_id) {
176 if !matching.contains_b(b_child) {
177 let b_child_data = tree_b.get(b_child);
178 if a_child_data.kind == b_child_data.kind {
179 candidates.push((a_child, b_child));
180 }
181 }
182 }
183 }
184 }
185 }
186}
187
188fn match_subtrees<K, L>(
190 tree_a: &Tree<K, L>,
191 tree_b: &Tree<K, L>,
192 a_id: NodeId,
193 b_id: NodeId,
194 matching: &mut Matching,
195) where
196 K: Clone + Eq + Hash,
197 L: Clone,
198{
199 matching.add(a_id, b_id);
200
201 let a_children: Vec<_> = tree_a.children(a_id).collect();
203 let b_children: Vec<_> = tree_b.children(b_id).collect();
204
205 for (a_child, b_child) in a_children.into_iter().zip(b_children.into_iter()) {
206 match_subtrees(tree_a, tree_b, a_child, b_child, matching);
207 }
208}
209
210fn bottom_up_phase<K, L>(
216 tree_a: &Tree<K, L>,
217 tree_b: &Tree<K, L>,
218 matching: &mut Matching,
219 config: &MatchingConfig,
220) where
221 K: Clone + Eq + Hash,
222 L: Clone,
223{
224 let mut b_by_kind: HashMap<K, Vec<NodeId>> = HashMap::new();
226 let mut b_by_kind_hash: HashMap<(K, u64), Vec<NodeId>> = HashMap::new();
227
228 for b_id in tree_b.iter() {
229 if !matching.contains_b(b_id) {
230 let b_data = tree_b.get(b_id);
231 let kind = b_data.kind.clone();
232 b_by_kind.entry(kind.clone()).or_default().push(b_id);
233
234 if tree_b.child_count(b_id) == 0 {
236 b_by_kind_hash
237 .entry((kind, b_data.hash))
238 .or_default()
239 .push(b_id);
240 }
241 }
242 }
243
244 for a_id in tree_a.post_order() {
246 if matching.contains_a(a_id) {
247 continue;
248 }
249
250 let a_data = tree_a.get(a_id);
251 let is_leaf = tree_a.child_count(a_id) == 0;
252
253 if is_leaf {
254 let key = (a_data.kind.clone(), a_data.hash);
256 if let Some(candidates) = b_by_kind_hash.get(&key) {
257 for &b_id in candidates {
258 if !matching.contains_b(b_id) {
259 matching.add(a_id, b_id);
260 break; }
262 }
263 }
264 } else {
265 let candidates = b_by_kind.get(&a_data.kind).cloned().unwrap_or_default();
267
268 let mut best: Option<(NodeId, f64)> = None;
269 for b_id in candidates {
270 if matching.contains_b(b_id) {
271 continue;
272 }
273
274 if tree_b.child_count(b_id) == 0 {
276 continue;
277 }
278
279 let score = dice_coefficient(tree_a, tree_b, a_id, b_id, matching);
280 if score >= config.similarity_threshold
281 && (best.is_none() || score > best.unwrap().1)
282 {
283 best = Some((b_id, score));
284 }
285 }
286
287 if let Some((b_id, _)) = best {
288 matching.add(a_id, b_id);
289 }
290 }
291 }
292}
293
294fn dice_coefficient<K, L>(
298 tree_a: &Tree<K, L>,
299 tree_b: &Tree<K, L>,
300 a_id: NodeId,
301 b_id: NodeId,
302 matching: &Matching,
303) -> f64
304where
305 K: Clone + Eq + Hash,
306 L: Clone,
307{
308 let desc_a: HashSet<_> = tree_a.descendants(a_id).collect();
309 let desc_b: HashSet<_> = tree_b.descendants(b_id).collect();
310
311 let common = desc_a
312 .iter()
313 .filter(|&&a| {
314 matching
315 .get_b(a)
316 .map(|b| desc_b.contains(&b))
317 .unwrap_or(false)
318 })
319 .count();
320
321 if desc_a.is_empty() && desc_b.is_empty() {
322 1.0 } else {
324 2.0 * common as f64 / (desc_a.len() + desc_b.len()) as f64
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::NodeData;

    /// Two trees built from the same root and the same two leaves must
    /// produce one pair per node.
    #[test]
    fn test_identical_trees() {
        let mut left: Tree<&str, String> = Tree::new(NodeData::new(100, "root"));
        left.add_child(left.root, NodeData::leaf(1, "leaf", "a".to_string()));
        left.add_child(left.root, NodeData::leaf(2, "leaf", "b".to_string()));

        let mut right: Tree<&str, String> = Tree::new(NodeData::new(100, "root"));
        right.add_child(right.root, NodeData::leaf(1, "leaf", "a".to_string()));
        right.add_child(right.root, NodeData::leaf(2, "leaf", "b".to_string()));

        let config = MatchingConfig::default();
        let result = compute_matching(&left, &right, &config);

        // Root plus two leaves.
        assert_eq!(result.len(), 3);
    }

    /// The first leaf is the same on both sides and must be paired with its
    /// counterpart; the second leaf differs between the trees.
    #[test]
    fn test_partial_match() {
        let mut left: Tree<&str, String> = Tree::new(NodeData::new(100, "root"));
        let same_left = left.add_child(left.root, NodeData::leaf(1, "leaf", "same".to_string()));
        let _other_left =
            left.add_child(left.root, NodeData::leaf(2, "leaf", "diff_a".to_string()));

        let mut right: Tree<&str, String> = Tree::new(NodeData::new(100, "root"));
        let same_right =
            right.add_child(right.root, NodeData::leaf(1, "leaf", "same".to_string()));
        let _other_right =
            right.add_child(right.root, NodeData::leaf(3, "leaf", "diff_b".to_string()));

        let config = MatchingConfig::default();
        let result = compute_matching(&left, &right, &config);

        assert!(
            result.contains_a(same_left),
            "Identical leaves should match"
        );
        assert_eq!(result.get_b(same_left), Some(same_right));
    }
}