state_tree/
tree_diff.rs

1use std::collections::HashSet;
2
3use crate::patch::CopyFromPatch;
4use crate::tree::{SizedType, StateTreeSkeleton};
5
6pub fn take_diff<T: SizedType>(
7    old_skeleton: &StateTreeSkeleton<T>,
8    new_skeleton: &StateTreeSkeleton<T>,
9) -> HashSet<CopyFromPatch> {
10    build_patches_recursive(old_skeleton, new_skeleton, vec![], vec![])
11}
12
/// One step of the edit script produced by [`lcs_by_score`].
///
/// All indices are zero-based positions into the `old` / `new` slices that
/// were handed to the alignment function.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiffResult {
    /// The element at `old_index` was paired with the element at `new_index`.
    Common { old_index: usize, new_index: usize },
    /// The element exists only in the old sequence (deleted).
    Delete { old_index: usize },
    /// The element exists only in the new sequence (inserted).
    Insert { new_index: usize },
}
23
24/// Compare two slices and return optimal matching based on scores.
25/// The `score_fn` closure calculates the match score (0.0 to 1.0).
26pub fn lcs_by_score<T>(
27    old: &[T],
28    new: &[T],
29    mut score_fn: impl FnMut(&T, &T) -> f64,
30) -> Vec<DiffResult> {
31    let old_len = old.len();
32    let new_len = new.len();
33
34    // DP table: dp[i][j] = cumulative score
35    let mut dp = vec![vec![0.0; new_len + 1]; old_len + 1];
36
37    for i in 1..=old_len {
38        for j in 1..=new_len {
39            let score = score_fn(&old[i - 1], &new[j - 1]);
40
41            if score > 0.0 {
42                // If matched: diagonal value + match score
43                dp[i][j] = (dp[i - 1][j - 1] + score).max(dp[i - 1][j].max(dp[i][j - 1]));
44            } else {
45                // If not matched: take the maximum value
46                dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
47            }
48        }
49    }
50    // Backtrack to restore the result
51    let mut results = Vec::new();
52    let (mut i, mut j) = (old_len, new_len);
53
54    while i > 0 || j > 0 {
55        if i > 0 && j > 0 {
56            let score = score_fn(&old[i - 1], &new[j - 1]);
57
58            if score > 0.0 {
59                // Likely matched
60                results.push(DiffResult::Common {
61                    old_index: i - 1,
62                    new_index: j - 1,
63                });
64                i -= 1;
65                j -= 1;
66            } else if j > 0 && (i == 0 || dp[i][j - 1] >= dp[i - 1][j]) {
67                results.push(DiffResult::Insert { new_index: j - 1 });
68                j -= 1;
69            } else if i > 0 {
70                results.push(DiffResult::Delete { old_index: i - 1 });
71                i -= 1;
72            }
73        } else if j > 0 {
74            results.push(DiffResult::Insert { new_index: j - 1 });
75            j -= 1;
76        } else if i > 0 {
77            results.push(DiffResult::Delete { old_index: i - 1 });
78            i -= 1;
79        }
80    }
81
82    results.reverse(); // Reverse to get the correct order
83    results
84}
85
86fn nodes_match<T: SizedType>(old: &StateTreeSkeleton<T>, new: &StateTreeSkeleton<T>) -> bool {
87    match (old, new) {
88        (StateTreeSkeleton::Delay { len: len1 }, StateTreeSkeleton::Delay { len: len2 }) => {
89            len1 == len2
90        }
91        (StateTreeSkeleton::Mem(t1), StateTreeSkeleton::Mem(t2)) => {
92            t1.word_size() == t2.word_size()
93        }
94        (StateTreeSkeleton::Feed(t1), StateTreeSkeleton::Feed(t2)) => {
95            t1.word_size() == t2.word_size()
96        }
97        (StateTreeSkeleton::FnCall(c1), StateTreeSkeleton::FnCall(c2)) => {
98            c1.len() == c2.len() && c1.iter().zip(c2.iter()).all(|(a, b)| nodes_match(a, b))
99        }
100        _ => false,
101    }
102}
103
104/// Retrieve a node from a Skeleton using a path
105fn get_node_at_path<'a, T: SizedType>(
106    skeleton: &'a StateTreeSkeleton<T>,
107    path: &[usize],
108) -> Option<&'a StateTreeSkeleton<T>> {
109    if path.is_empty() {
110        return Some(skeleton);
111    }
112    
113    match skeleton {
114        StateTreeSkeleton::FnCall(children) => {
115            let child = children.get(path[0])?;
116            get_node_at_path(child, &path[1..])
117        }
118        _ => None,
119    }
120}
121
122fn build_patches_recursive<T: SizedType>(
123    old_skeleton: &StateTreeSkeleton<T>,
124    new_skeleton: &StateTreeSkeleton<T>,
125    old_path: Vec<usize>,
126    new_path: Vec<usize>,
127) -> HashSet<CopyFromPatch> {
128    // Retrieve the current node from the path
129    let old_node = get_node_at_path(old_skeleton, &old_path).expect("Invalid old_path");
130    let new_node = get_node_at_path(new_skeleton, &new_path).expect("Invalid new_path");
131    
132    // If the nodes are completely matched, return a single patch
133    if nodes_match(old_node, new_node) {
134        // Convert path to address
135        let (src_addr, size) = old_skeleton
136            .path_to_address(&old_path)
137            .expect("Invalid old_path");
138        let (dst_addr, dst_size) = new_skeleton
139            .path_to_address(&new_path)
140            .expect("Invalid new_path");
141
142        debug_assert_eq!(
143            size, dst_size,
144            "Size mismatch between matched nodes at old_path {old_path:?} and new_path {new_path:?}"
145        );
146
147        return [CopyFromPatch {
148            src_addr,
149            dst_addr,
150            size,
151        }]
152        .into_iter()
153        .collect();
154    }
155
156    match (old_node, new_node) {
157        (StateTreeSkeleton::FnCall(old_children), StateTreeSkeleton::FnCall(new_children)) => {
158            // First, calculate patches for all child nodes (to avoid side effects in score calculation)
159            let mut child_patches_map = Vec::new();
160            for old_idx in 0..old_children.len() {
161                for new_idx in 0..new_children.len() {
162                    let child_old_path = [old_path.clone(), vec![old_idx]].concat();
163                    let child_new_path = [new_path.clone(), vec![new_idx]].concat();
164                    let patches = build_patches_recursive(
165                        old_skeleton,
166                        new_skeleton,
167                        child_old_path,
168                        child_new_path,
169                    );
170                    let score = if patches.is_empty() {
171                        0.0
172                    } else {
173                        patches.len() as f64
174                    };
175                    child_patches_map.push(((old_idx, new_idx), patches, score));
176                }
177            }
178
179            // Find matching using LCS
180            let old_c_with_id: Vec<_> = old_children.iter().enumerate().collect();
181            let new_c_with_id: Vec<_> = new_children.iter().enumerate().collect();
182
183            let lcs_results = lcs_by_score(
184                &old_c_with_id,
185                &new_c_with_id,
186                |(oid, _old), (nid, _new)| {
187                    child_patches_map
188                        .iter()
189                        .find(|((o, n), _, _)| o == oid && n == nid)
190                        .map(|(_, _, score)| *score)
191                        .unwrap_or(0.0)
192                },
193            );
194
195            // Collect patches based on LCS results
196            let mut c_patches = HashSet::new();
197            for result in &lcs_results {
198                if let DiffResult::Common {
199                    old_index,
200                    new_index,
201                } = result
202                    && let Some((_, patches, _)) = child_patches_map
203                        .iter()
204                        .find(|((o, n), _, _)| o == old_index && n == new_index)
205                {
206                    c_patches.extend(patches.iter().cloned());
207                }
208            }
209
210            c_patches
211        }
212        _ => HashSet::new(),
213    }
214}