state_tree/
tree_diff.rs

1use std::collections::HashSet;
2
3use crate::patch::CopyFromPatch;
4use crate::tree::{SizedType, StateTreeSkeleton};
5
6pub fn take_diff<T: SizedType>(
7    old_skeleton: &StateTreeSkeleton<T>,
8    new_skeleton: &StateTreeSkeleton<T>,
9) -> HashSet<CopyFromPatch> {
10    let patchset = build_patches_recursive(old_skeleton, new_skeleton, vec![], vec![]);
11    patchset.into_iter().collect()
12}
13
14/// LCSアルゴリズムの結果を表すEnum
15#[derive(Debug)]
16pub enum DiffResult {
17    /// 両方のシーケンスに共通して存在する要素
18    Common { old_index: usize, new_index: usize },
19    /// 古いシーケンスにのみ存在する要素(削除された)
20    Delete { old_index: usize },
21    /// 新しいシーケンスにのみ存在する要素(挿入された)
22    Insert { new_index: usize },
23}
24
25/// 2つのスライスを比較し、スコアに基づいて最適なマッチングを返す
26/// `score_fn`クロージャでマッチスコア(0.0~1.0)を計算する
27pub fn lcs_by_score<T>(
28    old: &[T],
29    new: &[T],
30    mut score_fn: impl FnMut(&T, &T) -> f64,
31) -> Vec<DiffResult> {
32    let old_len = old.len();
33    let new_len = new.len();
34
35    // DP テーブル: dp[i][j] = 累積スコア
36    let mut dp = vec![vec![0.0; new_len + 1]; old_len + 1];
37
38    for i in 1..=old_len {
39        for j in 1..=new_len {
40            let score = score_fn(&old[i - 1], &new[j - 1]);
41
42            if score > 0.0 {
43                // マッチした場合: 対角線の値 + マッチスコア
44                dp[i][j] = (dp[i - 1][j - 1] + score).max(dp[i - 1][j].max(dp[i][j - 1]));
45            } else {
46                // マッチしなかった場合: 最大値を取る
47                dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
48            }
49        }
50    }
51    // バックトラックして結果を復元
52    let mut results = Vec::new();
53    let (mut i, mut j) = (old_len, new_len);
54
55    while i > 0 || j > 0 {
56        if i > 0 && j > 0 {
57            let score = score_fn(&old[i - 1], &new[j - 1]);
58
59            if score > 0.0 {
60                // マッチしている可能性が高い
61                results.push(DiffResult::Common {
62                    old_index: i - 1,
63                    new_index: j - 1,
64                });
65                i -= 1;
66                j -= 1;
67            } else if j > 0 && (i == 0 || dp[i][j - 1] >= dp[i - 1][j]) {
68                results.push(DiffResult::Insert { new_index: j - 1 });
69                j -= 1;
70            } else if i > 0 {
71                results.push(DiffResult::Delete { old_index: i - 1 });
72                i -= 1;
73            }
74        } else if j > 0 {
75            results.push(DiffResult::Insert { new_index: j - 1 });
76            j -= 1;
77        } else if i > 0 {
78            results.push(DiffResult::Delete { old_index: i - 1 });
79            i -= 1;
80        }
81    }
82
83    results.reverse(); // 結果を正しい順序にする
84    results
85}
86
87fn nodes_match<T: SizedType>(old: &StateTreeSkeleton<T>, new: &StateTreeSkeleton<T>) -> bool {
88    match (old, new) {
89        (StateTreeSkeleton::Delay { len: len1 }, StateTreeSkeleton::Delay { len: len2 }) => {
90            len1 == len2
91        }
92        (StateTreeSkeleton::Mem(t1), StateTreeSkeleton::Mem(t2)) => {
93            t1.word_size() == t2.word_size()
94        }
95        (StateTreeSkeleton::Feed(t1), StateTreeSkeleton::Feed(t2)) => {
96            t1.word_size() == t2.word_size()
97        }
98        (StateTreeSkeleton::FnCall(c1), StateTreeSkeleton::FnCall(c2)) => {
99            c1.len() == c2.len() && c1.iter().zip(c2.iter()).all(|(a, b)| nodes_match(a, b))
100        }
101        _ => false,
102    }
103}
104
105fn build_patches_recursive<T: SizedType>(
106    old_node: &StateTreeSkeleton<T>,
107    new_node: &StateTreeSkeleton<T>,
108    old_path: Vec<usize>,
109    new_path: Vec<usize>,
110) -> HashSet<CopyFromPatch> {
111    // ノードが完全に一致する場合は、単一のパッチを返す
112    if nodes_match(old_node, new_node) {
113        return [CopyFromPatch { old_path, new_path }].into_iter().collect();
114    }
115
116    match (old_node, new_node) {
117        (StateTreeSkeleton::FnCall(old_children), StateTreeSkeleton::FnCall(new_children)) => {
118            // 最初に全ての子ノードのパッチを計算(スコア計算の副作用を避けるため)
119            let mut child_patches_map = Vec::new();
120            for (old_idx, old_child) in old_children.iter().enumerate() {
121                for (new_idx, new_child) in new_children.iter().enumerate() {
122                    let child_old_path = [old_path.clone(), vec![old_idx]].concat();
123                    let child_new_path = [new_path.clone(), vec![new_idx]].concat();
124                    let patches = build_patches_recursive(
125                        old_child,
126                        new_child,
127                        child_old_path,
128                        child_new_path,
129                    );
130                    let score = if patches.is_empty() {
131                        0.0
132                    } else {
133                        patches.len() as f64
134                    };
135                    child_patches_map.push(((old_idx, new_idx), patches, score));
136                }
137            }
138
139            // LCSでマッチングを見つける
140            let old_c_with_id: Vec<_> = old_children.iter().enumerate().collect();
141            let new_c_with_id: Vec<_> = new_children.iter().enumerate().collect();
142
143            let lcs_results = lcs_by_score(
144                &old_c_with_id,
145                &new_c_with_id,
146                |(oid, _old), (nid, _new)| {
147                    child_patches_map
148                        .iter()
149                        .find(|((o, n), _, _)| o == oid && n == nid)
150                        .map(|(_, _, score)| *score)
151                        .unwrap_or(0.0)
152                },
153            );
154
155            // LCS結果に基づいてパッチを収集
156            let mut c_patches = HashSet::new();
157            for result in &lcs_results {
158                if let DiffResult::Common {
159                    old_index,
160                    new_index,
161                } = result
162                    && let Some((_, patches, _)) = child_patches_map
163                        .iter()
164                        .find(|((o, n), _, _)| o == old_index && n == new_index)
165                {
166                    c_patches.extend(patches.iter().cloned());
167                }
168            }
169
170            c_patches
171        }
172        _ => HashSet::new(),
173    }
174}