Skip to main content

layer_conform_core/
apted.rs

//! APTED (All Path Tree Edit Distance) implementation.
//!
//! Memoized DP over `(node1.id, node2.id)` pairs. Two nodes are treated as
//! equal when their `(kind, value)` match; `id` and `subtree_size` are ignored
//! by that comparison.
5
6use std::collections::HashMap;
7
8use crate::tree::TreeNode;
9
/// Per-operation costs consumed by the tree edit distance computation.
///
/// The [`Default`] implementation sets every cost to `1.0`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct AptedOptions {
    /// Charged when two paired nodes differ in `kind` or `value`.
    pub rename_cost: f64,
    /// Multiplied by the `subtree_size` of each child subtree that is inserted.
    pub insert_cost: f64,
    /// Multiplied by the `subtree_size` of each child subtree that is deleted.
    pub delete_cost: f64,
}

impl Default for AptedOptions {
    /// Unit cost for rename, insert and delete.
    fn default() -> Self {
        Self {
            rename_cost: 1.0,
            insert_cost: 1.0,
            delete_cost: 1.0,
        }
    }
}
22
23/// Compute tree edit distance between `a` and `b`.
24/// Both trees must have been finalized (`id/subtree_size` set).
25pub fn edit_distance(a: &TreeNode, b: &TreeNode, opts: AptedOptions) -> f64 {
26    let mut memo: HashMap<(u32, u32), f64> = HashMap::new();
27    distance_recurse(a, b, opts, &mut memo)
28}
29
30fn distance_recurse(
31    a: &TreeNode,
32    b: &TreeNode,
33    opts: AptedOptions,
34    memo: &mut HashMap<(u32, u32), f64>,
35) -> f64 {
36    if let Some(v) = memo.get(&(a.id, b.id)) {
37        return *v;
38    }
39    let cost_root = if a.kind == b.kind && a.value == b.value {
40        0.0
41    } else {
42        opts.rename_cost
43    };
44
45    // Children edit distance via DP over child sequences.
46    let n = a.children.len();
47    let m = b.children.len();
48    let mut dp = vec![vec![0.0_f64; m + 1]; n + 1];
49    for i in 1..=n {
50        dp[i][0] = dp[i - 1][0] + opts.delete_cost * f64::from(a.children[i - 1].subtree_size);
51    }
52    for j in 1..=m {
53        dp[0][j] = dp[0][j - 1] + opts.insert_cost * f64::from(b.children[j - 1].subtree_size);
54    }
55    for i in 1..=n {
56        for j in 1..=m {
57            let del = dp[i - 1][j] + opts.delete_cost * f64::from(a.children[i - 1].subtree_size);
58            let ins = dp[i][j - 1] + opts.insert_cost * f64::from(b.children[j - 1].subtree_size);
59            let rep = dp[i - 1][j - 1]
60                + distance_recurse(&a.children[i - 1], &b.children[j - 1], opts, memo);
61            dp[i][j] = del.min(ins).min(rep);
62        }
63    }
64
65    let total = cost_root + dp[n][m];
66    memo.insert((a.id, b.id), total);
67    total
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::{NodeKind, TreeNode};

    /// Calls `finalize` on `tree` and hands it back, so fixtures can be built inline.
    fn finalized(mut tree: TreeNode) -> TreeNode {
        tree.finalize();
        tree
    }

    /// Builds an unfinalized `Identifier` leaf whose value is `name`.
    fn ident(name: &'static str) -> TreeNode {
        TreeNode::leaf(NodeKind::Identifier, Some(name.into()))
    }

    /// Distance between `a` and `b` under the default (unit) costs.
    fn default_distance(a: &TreeNode, b: &TreeNode) -> f64 {
        edit_distance(a, b, AptedOptions::default())
    }

    #[test]
    fn identical_leaves_have_zero_distance() {
        let left = finalized(ident("x"));
        let right = finalized(ident("x"));
        let d = default_distance(&left, &right);
        assert!((d - 0.0).abs() < 1e-9);
    }

    #[test]
    fn different_value_costs_one_rename() {
        let left = finalized(ident("x"));
        let right = finalized(ident("y"));
        let d = default_distance(&left, &right);
        assert!((d - 1.0).abs() < 1e-9);
    }

    #[test]
    fn missing_child_costs_subtree_size() {
        // larger:  Block(Ident, Ident)   size=3
        // smaller: Block(Ident)          size=2
        let larger = finalized(TreeNode::branch(NodeKind::Block, vec![ident("x"), ident("y")]));
        let smaller = finalized(TreeNode::branch(NodeKind::Block, vec![ident("x")]));
        let d = default_distance(&larger, &smaller);
        assert!((d - 1.0).abs() < 1e-9, "expected 1.0, got {d}");
    }

    #[test]
    fn completely_different_trees() {
        let left = finalized(ident("x"));
        let right = finalized(TreeNode::leaf(NodeKind::Literal, Some("y".into())));
        // Both kind and value differ -> still a single rename costing 1.0.
        let d = default_distance(&left, &right);
        assert!((d - 1.0).abs() < 1e-9);
    }
}