1use std::collections::BTreeMap;
2use serde::{Serialize, Deserialize};
3
4use crate::merkle_tree::{ModelMerkleTree, MerkleTree};
5
/// Outcome of comparing two model Merkle trees with `compare_model_trees`.
///
/// Byte figures are estimates derived from chunk counts multiplied by a single
/// chunk size, not exact tensor sizes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonResult {
    /// Parameters whose root hash differs, or that exist in only one of the trees.
    pub changed_params: Vec<String>,
    /// Parameters whose layer root or own root hash matched in both trees.
    pub unchanged_params: Vec<String>,
    /// Layers whose root differs, or that exist in only one of the trees.
    pub changed_layers: Vec<String>,
    /// Estimated size of the changed chunks, in bytes.
    pub total_changed_bytes: usize,
    /// Estimated model size in bytes (chunk count of the larger tree times the chunk size).
    pub total_bytes: usize,
    /// Changed share of the model, in percent; `0.0` when `total_bytes` is zero.
    pub change_percentage: f64,
    /// Parameter name -> indices of changed chunks. Only filled in detailed mode, and
    /// only for parameters present in both trees whose chunk diff was non-empty.
    pub chunk_changes: BTreeMap<String, Vec<usize>>,
    /// Number of hash comparisons performed: layer roots, parameter roots, plus
    /// the comparison count reported by each chunk-level diff.
    pub hash_comparisons: usize,
}
18
19impl ComparisonResult {
20 pub fn summary(&self) -> String {
21 format!(
22 "Changed: {} params, {} layers, {:.2}% of model ({} hash comparisons)",
23 self.changed_params.len(),
24 self.changed_layers.len(),
25 self.change_percentage,
26 self.hash_comparisons,
27 )
28 }
29}
30
31pub fn compare_model_trees(
33 tree_a: &ModelMerkleTree,
34 tree_b: &ModelMerkleTree,
35 detailed: bool,
36) -> ComparisonResult {
37 let mut result = ComparisonResult {
38 changed_params: Vec::new(),
39 unchanged_params: Vec::new(),
40 changed_layers: Vec::new(),
41 total_changed_bytes: 0,
42 total_bytes: 0,
43 change_percentage: 0.0,
44 chunk_changes: BTreeMap::new(),
45 hash_comparisons: 0,
46 };
47
48 let mut total_chunks_a = 0usize;
49 let mut total_chunks_b = 0usize;
50 let mut changed_chunks = 0usize;
51
52 let all_layers: Vec<&String> = tree_a.layer_trees.keys()
53 .chain(tree_b.layer_trees.keys())
54 .collect::<std::collections::BTreeSet<_>>()
55 .into_iter().collect();
56
57 for layer_name in &all_layers {
58 let la = tree_a.layer_trees.get(*layer_name);
59 let lb = tree_b.layer_trees.get(*layer_name);
60
61 match (la, lb) {
62 (None, Some(existing)) | (Some(existing), None) => {
63 result.changed_layers.push((*layer_name).clone());
64 for (pname, ptdata) in &existing.param_trees {
65 result.changed_params.push(pname.clone());
66 changed_chunks += ptdata.num_chunks;
67 if la.is_some() { total_chunks_a += ptdata.num_chunks; }
68 else { total_chunks_b += ptdata.num_chunks; }
69 }
70 }
71 (Some(la_tree), Some(lb_tree)) => {
72 result.hash_comparisons += 1;
73 if la_tree.layer_root == lb_tree.layer_root {
74 for (pname, ptd) in &la_tree.param_trees {
75 result.unchanged_params.push(pname.clone());
76 total_chunks_a += ptd.num_chunks;
77 }
78 for (_, ptd) in &lb_tree.param_trees {
79 total_chunks_b += ptd.num_chunks;
80 }
81 continue;
82 }
83
84 result.changed_layers.push((*layer_name).clone());
85
86 let all_params: std::collections::BTreeSet<&String> =
87 la_tree.param_trees.keys().chain(lb_tree.param_trees.keys()).collect();
88
89 for pname in all_params {
90 let pa = la_tree.param_trees.get(pname);
91 let pb = lb_tree.param_trees.get(pname);
92
93 if let Some(p) = pa { total_chunks_a += p.num_chunks; }
94 if let Some(p) = pb { total_chunks_b += p.num_chunks; }
95
96 match (pa, pb) {
97 (None, _) | (_, None) => {
98 result.changed_params.push(pname.clone());
99 let existing = pa.or(pb).unwrap();
100 changed_chunks += existing.num_chunks;
101 }
102 (Some(pa_data), Some(pb_data)) => {
103 result.hash_comparisons += 1;
104 if pa_data.root_hash == pb_data.root_hash {
105 result.unchanged_params.push(pname.clone());
106 continue;
107 }
108 result.changed_params.push(pname.clone());
109
110 if detailed {
111 let ta = MerkleTree::from_data(pa_data.clone());
112 let tb = MerkleTree::from_data(pb_data.clone());
113 let (indices, comps) = ta.diff_tree(&tb);
114 result.hash_comparisons += comps;
115 changed_chunks += indices.len();
116 if !indices.is_empty() {
117 result.chunk_changes.insert(pname.clone(), indices);
118 }
119 } else {
120 changed_chunks += pa_data.num_chunks.max(pb_data.num_chunks);
121 }
122 }
123 }
124 }
125 }
126 (None, None) => {}
127 }
128 }
129
130 let avg_chunk_size = tree_a.layer_trees.values().next()
131 .and_then(|l| l.param_trees.values().next())
132 .map(|p| p.chunk_size)
133 .unwrap_or(16384);
134
135 result.total_bytes = total_chunks_a.max(total_chunks_b) * avg_chunk_size;
136 result.total_changed_bytes = changed_chunks * avg_chunk_size;
137 result.change_percentage = if result.total_bytes > 0 {
138 (result.total_changed_bytes as f64 / result.total_bytes as f64) * 100.0
139 } else { 0.0 };
140
141 result
142}
143
144pub fn estimate_sync_savings(
146 tree_a: &ModelMerkleTree,
147 tree_b: &ModelMerkleTree,
148) -> SyncSavings {
149 let diff = compare_model_trees(tree_a, tree_b, true);
150 let full = diff.total_bytes;
151 let incremental = diff.total_changed_bytes;
152 let savings = full.saturating_sub(incremental);
153 let pct = if full > 0 { (savings as f64 / full as f64) * 100.0 } else { 0.0 };
154 SyncSavings { full_sync_bytes: full, incremental_sync_bytes: incremental, savings_bytes: savings, savings_percentage: pct }
155}
156
/// Estimated transfer cost of a full sync versus an incremental sync, as
/// produced by `estimate_sync_savings`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyncSavings {
    /// Estimated bytes to transfer the whole model.
    pub full_sync_bytes: usize,
    /// Estimated bytes to transfer only the changed chunks.
    pub incremental_sync_bytes: usize,
    /// `full_sync_bytes - incremental_sync_bytes`, saturating at zero.
    pub savings_bytes: usize,
    /// `savings_bytes` as a percentage of `full_sync_bytes`; `0.0` when that is zero.
    pub savings_percentage: f64,
}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hasher::HashAlgorithm;
    use crate::merkle_tree::build_model_merkle_tree;
    use std::collections::BTreeMap;

    /// Builds a parameter-name -> blob map where each blob is 4096 copies of one byte.
    fn uniform_blobs(entries: &[(&str, u8)]) -> BTreeMap<String, Vec<u8>> {
        entries
            .iter()
            .map(|&(name, fill)| (name.to_string(), vec![fill; 4096]))
            .collect()
    }

    /// Trees built from identical blobs must report no changes at all.
    #[test]
    fn test_identical_trees() {
        let blobs = uniform_blobs(&[("layer.0.weight", 1), ("layer.1.weight", 2)]);

        let left = build_model_merkle_tree(&blobs, 1024, "a", HashAlgorithm::Sha256, false);
        let right = build_model_merkle_tree(&blobs, 1024, "b", HashAlgorithm::Sha256, false);

        let cmp = compare_model_trees(&left, &right, true);
        assert!(cmp.changed_params.is_empty());
        assert_eq!(cmp.change_percentage, 0.0);
    }

    /// Overwriting a single parameter must flag exactly that parameter.
    #[test]
    fn test_one_param_changed() {
        let before = uniform_blobs(&[("layer.0.weight", 1), ("layer.1.weight", 2)]);
        let mut after = before.clone();
        after.insert("layer.1.weight".to_string(), vec![99u8; 4096]);

        let left = build_model_merkle_tree(&before, 1024, "a", HashAlgorithm::Sha256, false);
        let right = build_model_merkle_tree(&after, 1024, "b", HashAlgorithm::Sha256, false);

        let cmp = compare_model_trees(&left, &right, true);
        assert_eq!(cmp.changed_params, vec!["layer.1.weight"]);
        assert!(cmp.change_percentage > 0.0);
    }

    /// Changing one of ten equally sized parameters should save well over half the transfer.
    #[test]
    fn test_sync_savings() {
        let before: BTreeMap<String, Vec<u8>> = (0..10u8)
            .map(|i| (format!("layer.{i}.weight"), vec![i; 4096]))
            .collect();
        let mut after = before.clone();
        after.insert("layer.0.weight".to_string(), vec![255u8; 4096]);

        let left = build_model_merkle_tree(&before, 1024, "a", HashAlgorithm::Sha256, false);
        let right = build_model_merkle_tree(&after, 1024, "b", HashAlgorithm::Sha256, false);

        let est = estimate_sync_savings(&left, &right);
        assert!(est.savings_percentage > 50.0);
    }
}