Skip to main content

hanfei_fa/
comparison.rs

1use std::collections::BTreeMap;
2use serde::{Serialize, Deserialize};
3
4use crate::merkle_tree::{ModelMerkleTree, MerkleTree};
5
6/// Result of comparing two model Merkle trees.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct ComparisonResult {
9    pub changed_params: Vec<String>,
10    pub unchanged_params: Vec<String>,
11    pub changed_layers: Vec<String>,
12    pub total_changed_bytes: usize,
13    pub total_bytes: usize,
14    pub change_percentage: f64,
15    pub chunk_changes: BTreeMap<String, Vec<usize>>,
16    pub hash_comparisons: usize,
17}
18
19impl ComparisonResult {
20    pub fn summary(&self) -> String {
21        format!(
22            "Changed: {} params, {} layers, {:.2}% of model ({} hash comparisons)",
23            self.changed_params.len(),
24            self.changed_layers.len(),
25            self.change_percentage,
26            self.hash_comparisons,
27        )
28    }
29}
30
31/// Compare two model Merkle trees with 3-level pruning.
32pub fn compare_model_trees(
33    tree_a: &ModelMerkleTree,
34    tree_b: &ModelMerkleTree,
35    detailed: bool,
36) -> ComparisonResult {
37    let mut result = ComparisonResult {
38        changed_params: Vec::new(),
39        unchanged_params: Vec::new(),
40        changed_layers: Vec::new(),
41        total_changed_bytes: 0,
42        total_bytes: 0,
43        change_percentage: 0.0,
44        chunk_changes: BTreeMap::new(),
45        hash_comparisons: 0,
46    };
47
48    let mut total_chunks_a = 0usize;
49    let mut total_chunks_b = 0usize;
50    let mut changed_chunks = 0usize;
51
52    let all_layers: Vec<&String> = tree_a.layer_trees.keys()
53        .chain(tree_b.layer_trees.keys())
54        .collect::<std::collections::BTreeSet<_>>()
55        .into_iter().collect();
56
57    for layer_name in &all_layers {
58        let la = tree_a.layer_trees.get(*layer_name);
59        let lb = tree_b.layer_trees.get(*layer_name);
60
61        match (la, lb) {
62            (None, Some(existing)) | (Some(existing), None) => {
63                result.changed_layers.push((*layer_name).clone());
64                for (pname, ptdata) in &existing.param_trees {
65                    result.changed_params.push(pname.clone());
66                    changed_chunks += ptdata.num_chunks;
67                    if la.is_some() { total_chunks_a += ptdata.num_chunks; }
68                    else { total_chunks_b += ptdata.num_chunks; }
69                }
70            }
71            (Some(la_tree), Some(lb_tree)) => {
72                result.hash_comparisons += 1;
73                if la_tree.layer_root == lb_tree.layer_root {
74                    for (pname, ptd) in &la_tree.param_trees {
75                        result.unchanged_params.push(pname.clone());
76                        total_chunks_a += ptd.num_chunks;
77                    }
78                    for (_, ptd) in &lb_tree.param_trees {
79                        total_chunks_b += ptd.num_chunks;
80                    }
81                    continue;
82                }
83
84                result.changed_layers.push((*layer_name).clone());
85
86                let all_params: std::collections::BTreeSet<&String> =
87                    la_tree.param_trees.keys().chain(lb_tree.param_trees.keys()).collect();
88
89                for pname in all_params {
90                    let pa = la_tree.param_trees.get(pname);
91                    let pb = lb_tree.param_trees.get(pname);
92
93                    if let Some(p) = pa { total_chunks_a += p.num_chunks; }
94                    if let Some(p) = pb { total_chunks_b += p.num_chunks; }
95
96                    match (pa, pb) {
97                        (None, _) | (_, None) => {
98                            result.changed_params.push(pname.clone());
99                            let existing = pa.or(pb).unwrap();
100                            changed_chunks += existing.num_chunks;
101                        }
102                        (Some(pa_data), Some(pb_data)) => {
103                            result.hash_comparisons += 1;
104                            if pa_data.root_hash == pb_data.root_hash {
105                                result.unchanged_params.push(pname.clone());
106                                continue;
107                            }
108                            result.changed_params.push(pname.clone());
109
110                            if detailed {
111                                let ta = MerkleTree::from_data(pa_data.clone());
112                                let tb = MerkleTree::from_data(pb_data.clone());
113                                let (indices, comps) = ta.diff_tree(&tb);
114                                result.hash_comparisons += comps;
115                                changed_chunks += indices.len();
116                                if !indices.is_empty() {
117                                    result.chunk_changes.insert(pname.clone(), indices);
118                                }
119                            } else {
120                                changed_chunks += pa_data.num_chunks.max(pb_data.num_chunks);
121                            }
122                        }
123                    }
124                }
125            }
126            (None, None) => {}
127        }
128    }
129
130    let avg_chunk_size = tree_a.layer_trees.values().next()
131        .and_then(|l| l.param_trees.values().next())
132        .map(|p| p.chunk_size)
133        .unwrap_or(16384);
134
135    result.total_bytes = total_chunks_a.max(total_chunks_b) * avg_chunk_size;
136    result.total_changed_bytes = changed_chunks * avg_chunk_size;
137    result.change_percentage = if result.total_bytes > 0 {
138        (result.total_changed_bytes as f64 / result.total_bytes as f64) * 100.0
139    } else { 0.0 };
140
141    result
142}
143
144/// Estimate bandwidth savings from incremental sync.
145pub fn estimate_sync_savings(
146    tree_a: &ModelMerkleTree,
147    tree_b: &ModelMerkleTree,
148) -> SyncSavings {
149    let diff = compare_model_trees(tree_a, tree_b, true);
150    let full = diff.total_bytes;
151    let incremental = diff.total_changed_bytes;
152    let savings = full.saturating_sub(incremental);
153    let pct = if full > 0 { (savings as f64 / full as f64) * 100.0 } else { 0.0 };
154    SyncSavings { full_sync_bytes: full, incremental_sync_bytes: incremental, savings_bytes: savings, savings_percentage: pct }
155}
156
/// Bandwidth comparison between a full sync and an incremental sync.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyncSavings {
    /// Bytes transferred when the whole model is sent.
    pub full_sync_bytes: usize,
    /// Bytes transferred when only changed chunks are sent.
    pub incremental_sync_bytes: usize,
    /// `full_sync_bytes - incremental_sync_bytes` as computed by the producer
    /// (saturating at zero in `estimate_sync_savings`).
    pub savings_bytes: usize,
    /// `savings_bytes` as a percentage of `full_sync_bytes`.
    pub savings_percentage: f64,
}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hasher::HashAlgorithm;
    use crate::merkle_tree::build_model_merkle_tree;
    use std::collections::BTreeMap;

    /// Two 4 KiB weight blobs, filled with the bytes 1 and 2 respectively.
    fn two_layer_blobs() -> BTreeMap<String, Vec<u8>> {
        [("layer.0.weight", 1u8), ("layer.1.weight", 2u8)]
            .iter()
            .map(|&(name, fill)| (name.to_string(), vec![fill; 4096]))
            .collect()
    }

    #[test]
    fn test_identical_trees() {
        let blobs = two_layer_blobs();

        // Same blobs, only the "a"/"b" name argument differs.
        let first = build_model_merkle_tree(&blobs, 1024, "a", HashAlgorithm::Sha256, false);
        let second = build_model_merkle_tree(&blobs, 1024, "b", HashAlgorithm::Sha256, false);

        let cmp = compare_model_trees(&first, &second, true);
        assert!(cmp.changed_params.is_empty());
        assert_eq!(cmp.change_percentage, 0.0);
    }

    #[test]
    fn test_one_param_changed() {
        let before = two_layer_blobs();
        let mut after = before.clone();
        after.insert("layer.1.weight".to_string(), vec![99u8; 4096]);

        let first = build_model_merkle_tree(&before, 1024, "a", HashAlgorithm::Sha256, false);
        let second = build_model_merkle_tree(&after, 1024, "b", HashAlgorithm::Sha256, false);

        let cmp = compare_model_trees(&first, &second, true);
        assert_eq!(cmp.changed_params, vec!["layer.1.weight"]);
        assert!(cmp.change_percentage > 0.0);
    }

    #[test]
    fn test_sync_savings() {
        // Ten layers, each filled with its own index byte.
        let base: BTreeMap<String, Vec<u8>> = (0u8..10)
            .map(|i| (format!("layer.{i}.weight"), vec![i; 4096]))
            .collect();
        let mut edited = base.clone();
        edited.insert("layer.0.weight".to_string(), vec![255u8; 4096]);

        let first = build_model_merkle_tree(&base, 1024, "a", HashAlgorithm::Sha256, false);
        let second = build_model_merkle_tree(&edited, 1024, "b", HashAlgorithm::Sha256, false);

        // Only one of ten layers changed, so most of the transfer is saved.
        let savings = estimate_sync_savings(&first, &second);
        assert!(savings.savings_percentage > 50.0);
    }
}