1use super::protocol::{Card, Message};
8use crate::artifact::blob;
9use crate::error::{FossilError, Result};
10use crate::hash;
11use crate::repo::Repository;
12use std::collections::HashSet;
13
14pub struct SyncClient<'a> {
16 repo: &'a Repository,
17 url: String,
18 projectcode: String,
19 servercode: String,
20 cookie: Option<String>,
21 phantoms: HashSet<String>,
22 received_artifacts: usize,
23 sent_artifacts: usize,
24}
25
26impl<'a> SyncClient<'a> {
27 pub fn new(repo: &'a Repository, url: &str) -> Result<Self> {
34 let projectcode = repo.project_code()?;
35 let servercode = hash::sha3_256_hex(uuid::Uuid::new_v4().to_string().as_bytes());
36
37 Ok(Self {
38 repo,
39 url: url.trim_end_matches('/').to_string(),
40 projectcode,
41 servercode,
42 cookie: None,
43 phantoms: HashSet::new(),
44 received_artifacts: 0,
45 sent_artifacts: 0,
46 })
47 }
48
49 pub fn pull(&mut self, username: &str, password: &str) -> Result<SyncStats> {
51 let mut total_received = 0;
52 let mut rounds = 0;
53
54 loop {
55 rounds += 1;
56 let response = self.pull_round(username, password)?;
57
58 let received_this_round = response
59 .cards
60 .iter()
61 .filter(|c| matches!(c, Card::File { .. } | Card::CFile { .. }))
62 .count();
63
64 total_received += received_this_round;
65
66 for card in &response.cards {
67 if let Card::Error { message } = card {
68 return Err(FossilError::SyncError(message.clone()));
69 }
70 }
71
72 let mut new_phantoms = 0;
73 for card in &response.cards {
74 if let Card::Igot {
75 artifact_id,
76 is_private,
77 } = card
78 {
79 if !*is_private && !self.has_artifact(artifact_id)? {
80 self.phantoms.insert(artifact_id.clone());
81 new_phantoms += 1;
82 }
83 }
84 }
85
86 if new_phantoms == 0 && received_this_round == 0 {
87 break;
88 }
89
90 if rounds > 100 {
91 return Err(FossilError::SyncError("Too many sync rounds".to_string()));
92 }
93 }
94
95 Ok(SyncStats {
96 received: total_received,
97 sent: 0,
98 rounds,
99 })
100 }
101
102 fn pull_round(&mut self, username: &str, password: &str) -> Result<Message> {
103 let mut request = Message::new();
104 let mut payload = Message::new();
105
106 payload.add(Card::Pragma {
107 name: "client-version".to_string(),
108 values: vec!["25000".to_string()],
109 });
110
111 payload.add(Card::Pull {
112 servercode: self.servercode.clone(),
113 projectcode: self.projectcode.clone(),
114 });
115
116 if let Some(ref cookie) = self.cookie {
117 payload.add(Card::Cookie {
118 payload: cookie.clone(),
119 });
120 }
121
122 for phantom in self.phantoms.iter().take(200) {
123 payload.add(Card::Gimme {
124 artifact_id: phantom.clone(),
125 });
126 }
127
128 let payload_text = payload.to_text()?;
129 request.add(Message::create_login(username, password, &payload_text));
130
131 for card in payload.cards {
132 request.add(card);
133 }
134
135 let response = self.send_request(&request)?;
136
137 for card in &response.cards {
138 match card {
139 Card::File {
140 artifact_id,
141 delta_source,
142 content,
143 } => {
144 self.store_artifact(artifact_id, delta_source.as_deref(), content, false)?;
145 self.phantoms.remove(artifact_id);
146 self.received_artifacts += 1;
147 }
148 Card::CFile {
149 artifact_id,
150 delta_source,
151 content,
152 ..
153 } => {
154 self.store_artifact(artifact_id, delta_source.as_deref(), content, true)?;
155 self.phantoms.remove(artifact_id);
156 self.received_artifacts += 1;
157 }
158 Card::Cookie { payload } => {
159 self.cookie = Some(payload.clone());
160 }
161 _ => {}
162 }
163 }
164
165 Ok(response)
166 }
167
168 pub fn push(&mut self, username: &str, password: &str) -> Result<SyncStats> {
170 let mut total_sent = 0;
171 let mut rounds = 0;
172 let mut gimme_queue: HashSet<String> = HashSet::new();
173
174 let unclustered = self.get_unclustered_artifacts()?;
175
176 loop {
177 rounds += 1;
178 let (response, sent_this_round) =
179 self.push_round(username, password, &unclustered, &gimme_queue)?;
180
181 total_sent += sent_this_round;
182
183 for card in &response.cards {
184 if let Card::Error { message } = card {
185 return Err(FossilError::SyncError(message.clone()));
186 }
187 }
188
189 gimme_queue.clear();
190 for card in &response.cards {
191 if let Card::Gimme { artifact_id } = card {
192 gimme_queue.insert(artifact_id.clone());
193 }
194 }
195
196 if gimme_queue.is_empty() && sent_this_round == 0 {
197 break;
198 }
199
200 if rounds > 100 {
201 return Err(FossilError::SyncError("Too many sync rounds".to_string()));
202 }
203 }
204
205 Ok(SyncStats {
206 received: 0,
207 sent: total_sent,
208 rounds,
209 })
210 }
211
212 fn push_round(
213 &mut self,
214 username: &str,
215 password: &str,
216 unclustered: &[String],
217 gimme_queue: &HashSet<String>,
218 ) -> Result<(Message, usize)> {
219 let mut request = Message::new();
220 let mut sent_count = 0;
221
222 let mut payload = Message::new();
223
224 payload.add(Card::Pragma {
225 name: "client-version".to_string(),
226 values: vec!["25000".to_string()],
227 });
228
229 payload.add(Card::Push {
230 servercode: self.servercode.clone(),
231 projectcode: self.projectcode.clone(),
232 });
233
234 if let Some(ref cookie) = self.cookie {
235 payload.add(Card::Cookie {
236 payload: cookie.clone(),
237 });
238 }
239
240 let mut total_size = 0;
241 let max_size = 1024 * 1024;
242
243 for artifact_id in gimme_queue {
244 if total_size > max_size {
245 break;
246 }
247
248 if let Ok(content) = self.get_artifact_content(artifact_id) {
249 payload.add(Card::File {
250 artifact_id: artifact_id.clone(),
251 delta_source: None,
252 content: content.clone(),
253 });
254 total_size += content.len();
255 sent_count += 1;
256 }
257 }
258
259 for artifact_id in unclustered.iter().take(500) {
260 payload.add(Card::Igot {
261 artifact_id: artifact_id.clone(),
262 is_private: false,
263 });
264 }
265
266 let payload_text = payload.to_text()?;
267
268 request.add(Message::create_login(username, password, &payload_text));
269
270 for card in payload.cards {
271 request.add(card);
272 }
273
274 let response = self.send_request(&request)?;
275
276 for card in &response.cards {
277 if let Card::Cookie { payload } = card {
278 self.cookie = Some(payload.clone());
279 }
280 }
281
282 self.sent_artifacts += sent_count;
283
284 Ok((response, sent_count))
285 }
286
287 pub fn sync(&mut self, username: &str, password: &str) -> Result<SyncStats> {
289 let pull_stats = self.pull(username, password)?;
290 let push_stats = self.push(username, password)?;
291
292 Ok(SyncStats {
293 received: pull_stats.received,
294 sent: push_stats.sent,
295 rounds: pull_stats.rounds + push_stats.rounds,
296 })
297 }
298
299 fn send_request(&self, request: &Message) -> Result<Message> {
301 use std::io::Write;
302 use std::process::{Command, Stdio};
303
304 let body = request.encode()?;
305
306 let path = self.url.strip_prefix("file://").ok_or_else(|| {
308 FossilError::SyncError(
309 "SyncClient only supports file:// URLs. Use QUIC sync for network sync."
310 .to_string(),
311 )
312 })?;
313
314 let http_req = format!(
315 "POST /xfer HTTP/1.0\r\nContent-Type: application/x-heroforge\r\nContent-Length: {}\r\n\r\n",
316 body.len()
317 );
318
319 let mut full_req = http_req.into_bytes();
320 full_req.extend_from_slice(&body);
321
322 let mut child = Command::new("heroforge")
323 .args(["http", path])
324 .stdin(Stdio::piped())
325 .stdout(Stdio::piped())
326 .spawn()
327 .map_err(|e| FossilError::SyncError(format!("Failed to run heroforge http: {}", e)))?;
328
329 if let Some(mut stdin) = child.stdin.take() {
330 stdin.write_all(&full_req)?;
331 }
332
333 let output = child.wait_with_output()?;
334
335 if let Some(pos) = output.stdout.windows(4).position(|w| w == b"\r\n\r\n") {
336 let body_start = pos + 4;
337 let body_bytes = &output.stdout[body_start..];
338 Message::decode(body_bytes)
339 } else {
340 Err(FossilError::SyncError("Invalid HTTP response".to_string()))
341 }
342 }
343
344 fn has_artifact(&self, hash: &str) -> Result<bool> {
345 match self.repo.database().get_rid_by_hash(hash) {
346 Ok(_) => Ok(true),
347 Err(_) => Ok(false),
348 }
349 }
350
351 fn store_artifact(
352 &self,
353 artifact_id: &str,
354 delta_source: Option<&str>,
355 content: &[u8],
356 is_compressed: bool,
357 ) -> Result<()> {
358 if delta_source.is_some() {
359 return Ok(());
360 }
361
362 let data = if is_compressed {
363 blob::decompress(content)?
364 } else {
365 content.to_vec()
366 };
367
368 let computed_hash = hash::sha3_256_hex(&data);
369 if !artifact_id.starts_with(&computed_hash[..artifact_id.len().min(computed_hash.len())]) {
370 return Ok(());
371 }
372
373 let compressed = blob::compress(&data)?;
374 self.repo
375 .database()
376 .insert_blob(&compressed, &computed_hash, data.len() as i64)?;
377
378 Ok(())
379 }
380
381 fn get_artifact_content(&self, hash: &str) -> Result<Vec<u8>> {
382 blob::get_artifact_by_hash(self.repo.database(), hash)
383 }
384
385 fn get_unclustered_artifacts(&self) -> Result<Vec<String>> {
386 let mut stmt = self
387 .repo
388 .database()
389 .connection()
390 .prepare("SELECT uuid FROM blob WHERE rid IN (SELECT rid FROM unclustered)")?;
391
392 let hashes: Vec<String> = stmt
393 .query_map([], |row| row.get(0))?
394 .filter_map(|r| r.ok())
395 .collect();
396
397 Ok(hashes)
398 }
399}
400
/// Summary of a `pull`, `push` or `sync` run.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SyncStats {
    /// Number of artifacts received from the server.
    pub received: usize,
    /// Number of artifacts sent to the server.
    pub sent: usize,
    /// Number of request/response rounds performed.
    pub rounds: usize,
}
411
412impl std::fmt::Display for SyncStats {
413 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
414 write!(
415 f,
416 "Received: {}, Sent: {}, Rounds: {}",
417 self.received, self.sent, self.rounds
418 )
419 }
420}