layer_conform_core/
apted.rs1use std::collections::HashMap;
7
8use crate::tree::TreeNode;
9
/// Per-operation costs used by [`edit_distance`].
///
/// Insertions and deletions are charged per node: adding or removing a whole
/// subtree costs the corresponding field multiplied by the subtree's node
/// count (`subtree_size`).
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct AptedOptions {
    /// Cost of relabelling a node whose `kind` or `value` differs from its counterpart.
    pub rename_cost: f64,
    /// Cost per node of inserting a subtree that exists only in the second tree.
    pub insert_cost: f64,
    /// Cost per node of deleting a subtree that exists only in the first tree.
    pub delete_cost: f64,
}

impl Default for AptedOptions {
    /// Unit costs: every rename, insertion and deletion costs `1.0`.
    fn default() -> Self {
        Self {
            rename_cost: 1.0,
            insert_cost: 1.0,
            delete_cost: 1.0,
        }
    }
}
22
23pub fn edit_distance(a: &TreeNode, b: &TreeNode, opts: AptedOptions) -> f64 {
26 let mut memo: HashMap<(u32, u32), f64> = HashMap::new();
27 distance_recurse(a, b, opts, &mut memo)
28}
29
30fn distance_recurse(
31 a: &TreeNode,
32 b: &TreeNode,
33 opts: AptedOptions,
34 memo: &mut HashMap<(u32, u32), f64>,
35) -> f64 {
36 if let Some(v) = memo.get(&(a.id, b.id)) {
37 return *v;
38 }
39 let cost_root = if a.kind == b.kind && a.value == b.value {
40 0.0
41 } else {
42 opts.rename_cost
43 };
44
45 let n = a.children.len();
47 let m = b.children.len();
48 let mut dp = vec![vec![0.0_f64; m + 1]; n + 1];
49 for i in 1..=n {
50 dp[i][0] = dp[i - 1][0] + opts.delete_cost * f64::from(a.children[i - 1].subtree_size);
51 }
52 for j in 1..=m {
53 dp[0][j] = dp[0][j - 1] + opts.insert_cost * f64::from(b.children[j - 1].subtree_size);
54 }
55 for i in 1..=n {
56 for j in 1..=m {
57 let del = dp[i - 1][j] + opts.delete_cost * f64::from(a.children[i - 1].subtree_size);
58 let ins = dp[i][j - 1] + opts.insert_cost * f64::from(b.children[j - 1].subtree_size);
59 let rep = dp[i - 1][j - 1]
60 + distance_recurse(&a.children[i - 1], &b.children[j - 1], opts, memo);
61 dp[i][j] = del.min(ins).min(rep);
62 }
63 }
64
65 let total = cost_root + dp[n][m];
66 memo.insert((a.id, b.id), total);
67 total
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::{NodeKind, TreeNode};

    /// Calls `finalize` on `t` and returns it, ready to be compared.
    /// NOTE(review): presumably `finalize` fills in the `id`/`subtree_size`
    /// fields the distance computation reads — confirm.
    fn finalized(mut t: TreeNode) -> TreeNode {
        t.finalize();
        t
    }

    /// Builds an (unfinalized) leaf of `kind` carrying `value`.
    fn leaf(kind: NodeKind, value: &str) -> TreeNode {
        TreeNode::leaf(kind, Some(value.into()))
    }

    /// Distance under the default (unit) costs.
    fn distance(a: &TreeNode, b: &TreeNode) -> f64 {
        edit_distance(a, b, AptedOptions::default())
    }

    /// Asserts that `actual` equals `expected` up to floating-point noise.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn identical_leaves_have_zero_distance() {
        let a = finalized(leaf(NodeKind::Identifier, "x"));
        let b = finalized(leaf(NodeKind::Identifier, "x"));
        assert_close(distance(&a, &b), 0.0);
    }

    #[test]
    fn different_value_costs_one_rename() {
        let a = finalized(leaf(NodeKind::Identifier, "x"));
        let b = finalized(leaf(NodeKind::Identifier, "y"));
        assert_close(distance(&a, &b), 1.0);
    }

    #[test]
    fn missing_child_costs_subtree_size() {
        let a = finalized(TreeNode::branch(
            NodeKind::Block,
            vec![
                leaf(NodeKind::Identifier, "x"),
                leaf(NodeKind::Identifier, "y"),
            ],
        ));
        let b = finalized(TreeNode::branch(
            NodeKind::Block,
            vec![leaf(NodeKind::Identifier, "x")],
        ));
        assert_close(distance(&a, &b), 1.0);
    }

    /// Mirror of `missing_child_costs_subtree_size`: exercises the insertion
    /// path of the child alignment instead of the deletion path.
    #[test]
    fn extra_child_costs_subtree_size() {
        let a = finalized(TreeNode::branch(
            NodeKind::Block,
            vec![leaf(NodeKind::Identifier, "x")],
        ));
        let b = finalized(TreeNode::branch(
            NodeKind::Block,
            vec![
                leaf(NodeKind::Identifier, "x"),
                leaf(NodeKind::Identifier, "y"),
            ],
        ));
        assert_close(distance(&a, &b), 1.0);
    }

    #[test]
    fn completely_different_trees() {
        let a = finalized(leaf(NodeKind::Identifier, "x"));
        let b = finalized(leaf(NodeKind::Literal, "y"));
        assert_close(distance(&a, &b), 1.0);
    }
}