1use std::collections::HashSet;
2
3use crate::patch::CopyFromPatch;
4use crate::tree::{SizedType, StateTreeSkeleton};
5
6pub fn take_diff<T: SizedType>(
7 old_skeleton: &StateTreeSkeleton<T>,
8 new_skeleton: &StateTreeSkeleton<T>,
9) -> HashSet<CopyFromPatch> {
10 let patchset = build_patches_recursive(old_skeleton, new_skeleton, vec![], vec![]);
11 patchset.into_iter().collect()
12}
13
/// One step of an alignment produced by [`lcs_by_score`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiffResult {
    /// `old[old_index]` is paired with `new[new_index]`.
    Common { old_index: usize, new_index: usize },
    /// `old[old_index]` has no counterpart in `new`.
    Delete { old_index: usize },
    /// `new[new_index]` has no counterpart in `old`.
    Insert { new_index: usize },
}

/// Aligns `old` against `new` so that the summed score of paired elements is
/// maximal (a weighted longest-common-subsequence).
///
/// `score_fn(a, b)` rates how well `a` (from `old`) pairs with `b` (from `new`).
/// Only strictly positive scores can form a [`DiffResult::Common`] pair. Pairs
/// never cross, so the returned steps walk both slices in increasing index
/// order, and every index of `old` and of `new` appears in exactly one step.
///
/// When several alignments reach the same total, the choice made at each cell
/// (walking backwards from the ends) prefers `Common`, then `Insert`, then
/// `Delete`.
///
/// `score_fn` is invoked exactly once for every `(old, new)` element pair.
pub fn lcs_by_score<T>(
    old: &[T],
    new: &[T],
    mut score_fn: impl FnMut(&T, &T) -> f64,
) -> Vec<DiffResult> {
    /// The move that produced the optimum for one DP cell.
    #[derive(Clone, Copy)]
    enum Step {
        Common,
        Delete,
        Insert,
    }

    let old_len = old.len();
    let new_len = new.len();

    // `steps[i][j]` is the optimal last move for the prefixes `old[..=i]` and
    // `new[..=j]`. Recording it during the forward pass (instead of greedily
    // re-deriving it while backtracking) guarantees the backtrack follows an
    // alignment that actually achieves the optimal score.
    let mut steps = vec![vec![Step::Insert; new_len]; old_len];
    // Only two DP rows are needed because the moves are kept in `steps`.
    // Column 0 of every row stays 0.0 (aligning against an empty `new` prefix).
    let mut prev = vec![0.0_f64; new_len + 1];
    let mut curr = vec![0.0_f64; new_len + 1];

    for (i, old_item) in old.iter().enumerate() {
        for (j, new_item) in new.iter().enumerate() {
            let up = prev[j + 1]; // skip `old_item` (Delete)
            let left = curr[j]; // skip `new_item` (Insert)
            let (mut best, mut step) = if left >= up {
                (left, Step::Insert)
            } else {
                (up, Step::Delete)
            };
            let score = score_fn(old_item, new_item);
            if score > 0.0 && prev[j] + score >= best {
                best = prev[j] + score;
                step = Step::Common;
            }
            curr[j + 1] = best;
            steps[i][j] = step;
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    let mut results = Vec::with_capacity(old_len + new_len);
    let (mut i, mut j) = (old_len, new_len);
    while i > 0 || j > 0 {
        // Once either slice is exhausted, only the other kind of step remains.
        let step = if i == 0 {
            Step::Insert
        } else if j == 0 {
            Step::Delete
        } else {
            steps[i - 1][j - 1]
        };
        match step {
            Step::Common => {
                i -= 1;
                j -= 1;
                results.push(DiffResult::Common {
                    old_index: i,
                    new_index: j,
                });
            }
            Step::Delete => {
                i -= 1;
                results.push(DiffResult::Delete { old_index: i });
            }
            Step::Insert => {
                j -= 1;
                results.push(DiffResult::Insert { new_index: j });
            }
        }
    }

    // Steps were collected from the end backwards.
    results.reverse();
    results
}
86
87fn nodes_match<T: SizedType>(old: &StateTreeSkeleton<T>, new: &StateTreeSkeleton<T>) -> bool {
88 match (old, new) {
89 (StateTreeSkeleton::Delay { len: len1 }, StateTreeSkeleton::Delay { len: len2 }) => {
90 len1 == len2
91 }
92 (StateTreeSkeleton::Mem(t1), StateTreeSkeleton::Mem(t2)) => {
93 t1.word_size() == t2.word_size()
94 }
95 (StateTreeSkeleton::Feed(t1), StateTreeSkeleton::Feed(t2)) => {
96 t1.word_size() == t2.word_size()
97 }
98 (StateTreeSkeleton::FnCall(c1), StateTreeSkeleton::FnCall(c2)) => {
99 c1.len() == c2.len() && c1.iter().zip(c2.iter()).all(|(a, b)| nodes_match(a, b))
100 }
101 _ => false,
102 }
103}
104
105fn build_patches_recursive<T: SizedType>(
106 old_node: &StateTreeSkeleton<T>,
107 new_node: &StateTreeSkeleton<T>,
108 old_path: Vec<usize>,
109 new_path: Vec<usize>,
110) -> HashSet<CopyFromPatch> {
111 if nodes_match(old_node, new_node) {
113 return [CopyFromPatch { old_path, new_path }].into_iter().collect();
114 }
115
116 match (old_node, new_node) {
117 (StateTreeSkeleton::FnCall(old_children), StateTreeSkeleton::FnCall(new_children)) => {
118 let mut child_patches_map = Vec::new();
120 for (old_idx, old_child) in old_children.iter().enumerate() {
121 for (new_idx, new_child) in new_children.iter().enumerate() {
122 let child_old_path = [old_path.clone(), vec![old_idx]].concat();
123 let child_new_path = [new_path.clone(), vec![new_idx]].concat();
124 let patches = build_patches_recursive(
125 old_child,
126 new_child,
127 child_old_path,
128 child_new_path,
129 );
130 let score = if patches.is_empty() {
131 0.0
132 } else {
133 patches.len() as f64
134 };
135 child_patches_map.push(((old_idx, new_idx), patches, score));
136 }
137 }
138
139 let old_c_with_id: Vec<_> = old_children.iter().enumerate().collect();
141 let new_c_with_id: Vec<_> = new_children.iter().enumerate().collect();
142
143 let lcs_results = lcs_by_score(
144 &old_c_with_id,
145 &new_c_with_id,
146 |(oid, _old), (nid, _new)| {
147 child_patches_map
148 .iter()
149 .find(|((o, n), _, _)| o == oid && n == nid)
150 .map(|(_, _, score)| *score)
151 .unwrap_or(0.0)
152 },
153 );
154
155 let mut c_patches = HashSet::new();
157 for result in &lcs_results {
158 if let DiffResult::Common {
159 old_index,
160 new_index,
161 } = result
162 && let Some((_, patches, _)) = child_patches_map
163 .iter()
164 .find(|((o, n), _, _)| o == old_index && n == new_index)
165 {
166 c_patches.extend(patches.iter().cloned());
167 }
168 }
169
170 c_patches
171 }
172 _ => HashSet::new(),
173 }
174}