cinereus/
matching.rs

1//! GumTree node matching algorithm.
2//!
3//! Implements two-phase matching:
4//! 1. Top-down: Match identical subtrees by hash
5//! 2. Bottom-up: Match remaining nodes by structural similarity
6
7use crate::tree::Tree;
8use core::hash::Hash;
9use indextree::NodeId;
10use std::collections::{HashMap, HashSet};
11
12/// A bidirectional mapping between nodes in two trees.
13#[derive(Debug, Default)]
14pub struct Matching {
15    /// Map from tree A node to tree B node
16    a_to_b: HashMap<NodeId, NodeId>,
17    /// Map from tree B node to tree A node
18    b_to_a: HashMap<NodeId, NodeId>,
19}
20
21impl Matching {
22    /// Create a new empty matching.
23    pub fn new() -> Self {
24        Self::default()
25    }
26
27    /// Add a match between two nodes.
28    pub fn add(&mut self, a: NodeId, b: NodeId) {
29        self.a_to_b.insert(a, b);
30        self.b_to_a.insert(b, a);
31    }
32
33    /// Check if a node from tree A is matched.
34    pub fn contains_a(&self, a: NodeId) -> bool {
35        self.a_to_b.contains_key(&a)
36    }
37
38    /// Check if a node from tree B is matched.
39    pub fn contains_b(&self, b: NodeId) -> bool {
40        self.b_to_a.contains_key(&b)
41    }
42
43    /// Get the match for a node from tree A.
44    pub fn get_b(&self, a: NodeId) -> Option<NodeId> {
45        self.a_to_b.get(&a).copied()
46    }
47
48    /// Get the match for a node from tree B.
49    pub fn get_a(&self, b: NodeId) -> Option<NodeId> {
50        self.b_to_a.get(&b).copied()
51    }
52
53    /// Get all matched pairs.
54    pub fn pairs(&self) -> impl Iterator<Item = (NodeId, NodeId)> + '_ {
55        self.a_to_b.iter().map(|(&a, &b)| (a, b))
56    }
57
58    /// Get the number of matched pairs.
59    pub fn len(&self) -> usize {
60        self.a_to_b.len()
61    }
62
63    /// Check if there are no matches.
64    pub fn is_empty(&self) -> bool {
65        self.a_to_b.is_empty()
66    }
67}
68
/// Tuning knobs for the two matching phases run by `compute_matching`.
#[derive(Debug, Clone)]
pub struct MatchingConfig {
    /// Lowest Dice coefficient accepted when the bottom-up phase pairs
    /// internal nodes; candidates scoring below it stay unmatched.
    /// Default: `0.5`.
    pub similarity_threshold: f64,

    /// Subtrees whose height (as reported by `Tree::height`) is below this
    /// value are skipped by the top-down phase and left for the bottom-up
    /// phase. Default: `1`.
    pub min_height: usize,
}

impl Default for MatchingConfig {
    fn default() -> Self {
        let similarity_threshold = 0.5;
        let min_height = 1;
        MatchingConfig {
            similarity_threshold,
            min_height,
        }
    }
}
89
90/// Compute the matching between two trees using the GumTree algorithm.
91pub fn compute_matching<K, L>(
92    tree_a: &Tree<K, L>,
93    tree_b: &Tree<K, L>,
94    config: &MatchingConfig,
95) -> Matching
96where
97    K: Clone + Eq + Hash,
98    L: Clone,
99{
100    let mut matching = Matching::new();
101
102    // Phase 1: Top-down matching (identical subtrees by hash)
103    top_down_phase(tree_a, tree_b, &mut matching, config);
104
105    // Phase 2: Bottom-up matching (similar nodes by Dice coefficient)
106    bottom_up_phase(tree_a, tree_b, &mut matching, config);
107
108    matching
109}
110
111/// Phase 1: Top-down matching.
112///
113/// Greedily matches nodes with identical subtree hashes, starting from the roots
114/// and working down. When two nodes have the same hash, their entire subtrees
115/// are identical and can be matched recursively.
116fn top_down_phase<K, L>(
117    tree_a: &Tree<K, L>,
118    tree_b: &Tree<K, L>,
119    matching: &mut Matching,
120    config: &MatchingConfig,
121) where
122    K: Clone + Eq + Hash,
123    L: Clone,
124{
125    // Build hash -> nodes index for tree B
126    let mut b_by_hash: HashMap<u64, Vec<NodeId>> = HashMap::new();
127    for b_id in tree_b.iter() {
128        let hash = tree_b.get(b_id).hash;
129        b_by_hash.entry(hash).or_default().push(b_id);
130    }
131
132    // Priority queue: process nodes by height (descending)
133    // Higher nodes = larger subtrees = more valuable to match first
134    let mut candidates: Vec<(NodeId, NodeId)> = vec![(tree_a.root, tree_b.root)];
135
136    // Sort by height descending
137    candidates.sort_by(|a, b| {
138        let ha = tree_a.height(a.0);
139        let hb = tree_a.height(b.0);
140        hb.cmp(&ha)
141    });
142
143    while let Some((a_id, b_id)) = candidates.pop() {
144        // Skip if already matched
145        if matching.contains_a(a_id) || matching.contains_b(b_id) {
146            continue;
147        }
148
149        let a_data = tree_a.get(a_id);
150        let b_data = tree_b.get(b_id);
151
152        // Skip small subtrees (leave for bottom-up)
153        if tree_a.height(a_id) < config.min_height {
154            continue;
155        }
156
157        // If hashes match, these subtrees are identical
158        if a_data.hash == b_data.hash && a_data.kind == b_data.kind {
159            match_subtrees(tree_a, tree_b, a_id, b_id, matching);
160        } else {
161            // Hashes differ - try to match children
162            for a_child in tree_a.children(a_id) {
163                let a_child_data = tree_a.get(a_child);
164
165                // Look for B nodes with matching hash
166                if let Some(b_candidates) = b_by_hash.get(&a_child_data.hash) {
167                    for &b_candidate in b_candidates {
168                        if !matching.contains_b(b_candidate) {
169                            candidates.push((a_child, b_candidate));
170                        }
171                    }
172                }
173
174                // Also try children of b_id with same kind
175                for b_child in tree_b.children(b_id) {
176                    if !matching.contains_b(b_child) {
177                        let b_child_data = tree_b.get(b_child);
178                        if a_child_data.kind == b_child_data.kind {
179                            candidates.push((a_child, b_child));
180                        }
181                    }
182                }
183            }
184        }
185    }
186}
187
188/// Match two subtrees recursively (when their hashes match).
189fn match_subtrees<K, L>(
190    tree_a: &Tree<K, L>,
191    tree_b: &Tree<K, L>,
192    a_id: NodeId,
193    b_id: NodeId,
194    matching: &mut Matching,
195) where
196    K: Clone + Eq + Hash,
197    L: Clone,
198{
199    matching.add(a_id, b_id);
200
201    // Match children in order (they should be identical if hashes match)
202    let a_children: Vec<_> = tree_a.children(a_id).collect();
203    let b_children: Vec<_> = tree_b.children(b_id).collect();
204
205    for (a_child, b_child) in a_children.into_iter().zip(b_children.into_iter()) {
206        match_subtrees(tree_a, tree_b, a_child, b_child, matching);
207    }
208}
209
210/// Phase 2: Bottom-up matching.
211///
212/// For unmatched nodes, find candidates with the same kind and compute
213/// similarity using the Dice coefficient on matched descendants.
214/// For leaf nodes (no children), we match by hash since Dice is not meaningful.
215fn bottom_up_phase<K, L>(
216    tree_a: &Tree<K, L>,
217    tree_b: &Tree<K, L>,
218    matching: &mut Matching,
219    config: &MatchingConfig,
220) where
221    K: Clone + Eq + Hash,
222    L: Clone,
223{
224    // Build indices for tree B: by kind and by (kind, hash) for leaves
225    let mut b_by_kind: HashMap<K, Vec<NodeId>> = HashMap::new();
226    let mut b_by_kind_hash: HashMap<(K, u64), Vec<NodeId>> = HashMap::new();
227
228    for b_id in tree_b.iter() {
229        if !matching.contains_b(b_id) {
230            let b_data = tree_b.get(b_id);
231            let kind = b_data.kind.clone();
232            b_by_kind.entry(kind.clone()).or_default().push(b_id);
233
234            // For leaves, also index by (kind, hash)
235            if tree_b.child_count(b_id) == 0 {
236                b_by_kind_hash
237                    .entry((kind, b_data.hash))
238                    .or_default()
239                    .push(b_id);
240            }
241        }
242    }
243
244    // Process tree A in post-order (children before parents)
245    for a_id in tree_a.post_order() {
246        if matching.contains_a(a_id) {
247            continue;
248        }
249
250        let a_data = tree_a.get(a_id);
251        let is_leaf = tree_a.child_count(a_id) == 0;
252
253        if is_leaf {
254            // For leaves, match by exact hash (same kind AND same hash)
255            let key = (a_data.kind.clone(), a_data.hash);
256            if let Some(candidates) = b_by_kind_hash.get(&key) {
257                for &b_id in candidates {
258                    if !matching.contains_b(b_id) {
259                        matching.add(a_id, b_id);
260                        break; // Take the first available match
261                    }
262                }
263            }
264        } else {
265            // For internal nodes, use Dice coefficient
266            let candidates = b_by_kind.get(&a_data.kind).cloned().unwrap_or_default();
267
268            let mut best: Option<(NodeId, f64)> = None;
269            for b_id in candidates {
270                if matching.contains_b(b_id) {
271                    continue;
272                }
273
274                // Skip leaves when looking for internal node matches
275                if tree_b.child_count(b_id) == 0 {
276                    continue;
277                }
278
279                let score = dice_coefficient(tree_a, tree_b, a_id, b_id, matching);
280                if score >= config.similarity_threshold
281                    && (best.is_none() || score > best.unwrap().1)
282                {
283                    best = Some((b_id, score));
284                }
285            }
286
287            if let Some((b_id, _)) = best {
288                matching.add(a_id, b_id);
289            }
290        }
291    }
292}
293
294/// Compute the Dice coefficient between two nodes based on matched descendants.
295///
296/// dice(A, B) = 2 × |matched_descendants| / (|descendants_A| + |descendants_B|)
297fn dice_coefficient<K, L>(
298    tree_a: &Tree<K, L>,
299    tree_b: &Tree<K, L>,
300    a_id: NodeId,
301    b_id: NodeId,
302    matching: &Matching,
303) -> f64
304where
305    K: Clone + Eq + Hash,
306    L: Clone,
307{
308    let desc_a: HashSet<_> = tree_a.descendants(a_id).collect();
309    let desc_b: HashSet<_> = tree_b.descendants(b_id).collect();
310
311    let common = desc_a
312        .iter()
313        .filter(|&&a| {
314            matching
315                .get_b(a)
316                .map(|b| desc_b.contains(&b))
317                .unwrap_or(false)
318        })
319        .count();
320
321    if desc_a.is_empty() && desc_b.is_empty() {
322        1.0 // Both are leaves with no descendants
323    } else {
324        2.0 * common as f64 / (desc_a.len() + desc_b.len()) as f64
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::NodeData;

    #[test]
    fn test_identical_trees() {
        let mut tree_a: Tree<&str, String> = Tree::new(NodeData::new(100, "root"));
        let leaf_a1 = tree_a.add_child(tree_a.root, NodeData::leaf(1, "leaf", "a".to_string()));
        let leaf_a2 = tree_a.add_child(tree_a.root, NodeData::leaf(2, "leaf", "b".to_string()));

        let mut tree_b: Tree<&str, String> = Tree::new(NodeData::new(100, "root"));
        let leaf_b1 = tree_b.add_child(tree_b.root, NodeData::leaf(1, "leaf", "a".to_string()));
        let leaf_b2 = tree_b.add_child(tree_b.root, NodeData::leaf(2, "leaf", "b".to_string()));

        let matching = compute_matching(&tree_a, &tree_b, &MatchingConfig::default());

        // All nodes should be matched...
        assert_eq!(matching.len(), 3);
        // ...each one to its own counterpart, not merely to some node.
        assert_eq!(matching.get_b(tree_a.root), Some(tree_b.root));
        assert_eq!(matching.get_b(leaf_a1), Some(leaf_b1));
        assert_eq!(matching.get_b(leaf_a2), Some(leaf_b2));
        // The reverse direction must agree with the forward one.
        assert_eq!(matching.get_a(leaf_b1), Some(leaf_a1));
        assert_eq!(matching.get_a(leaf_b2), Some(leaf_a2));
    }

    #[test]
    fn test_partial_match() {
        // Trees with same structure but one leaf differs
        let mut tree_a: Tree<&str, String> = Tree::new(NodeData::new(100, "root"));
        let child1_a = tree_a.add_child(tree_a.root, NodeData::leaf(1, "leaf", "same".to_string()));
        let _child2_a =
            tree_a.add_child(tree_a.root, NodeData::leaf(2, "leaf", "diff_a".to_string()));

        let mut tree_b: Tree<&str, String> = Tree::new(NodeData::new(100, "root"));
        let child1_b = tree_b.add_child(tree_b.root, NodeData::leaf(1, "leaf", "same".to_string()));
        let _child2_b =
            tree_b.add_child(tree_b.root, NodeData::leaf(3, "leaf", "diff_b".to_string()));

        let matching = compute_matching(&tree_a, &tree_b, &MatchingConfig::default());

        // The identical leaf should be matched, consistently in both directions.
        assert!(
            matching.contains_a(child1_a),
            "Identical leaves should match"
        );
        assert_eq!(matching.get_b(child1_a), Some(child1_b));
        assert_eq!(matching.get_a(child1_b), Some(child1_a));
    }
}