use super::protocol::{Card, Message};
use crate::artifact::blob;
use crate::error::{FossilError, Result};
use crate::hash;
use crate::repo::Repository;
use std::collections::HashSet;
/// Client side of the sync exchange for a single repository.
///
/// Holds the identifiers sent with every request, the most recent cookie
/// returned by the server, the set of artifacts still to be requested, and
/// running transfer counters.
pub struct SyncClient<'a> {
// Repository that artifacts are looked up in and stored into.
repo: &'a Repository,
// Remote URL with trailing `/` characters removed (see `new`).
url: String,
// Project code read from the repository when the client is created.
projectcode: String,
// Hex SHA3-256 digest of a random v4 UUID string, generated per client.
servercode: String,
// Payload of the last `Cookie` card received; echoed back on later requests.
cookie: Option<String>,
// Artifact ids announced by the server via `igot` that are not present
// locally; requested with `gimme` cards during pull rounds.
phantoms: HashSet<String>,
// Incremented once per `File`/`CFile` card processed by a pull round.
received_artifacts: usize,
// Incremented by the number of `File` cards placed in each push round.
sent_artifacts: usize,
}
impl<'a> SyncClient<'a> {
/// Creates a client for `repo` that talks to `url`.
///
/// Trailing `/` characters are stripped from `url`. The project code is read
/// from the repository, and a fresh server code is derived by SHA3-256 hashing
/// the string form of a newly generated v4 UUID. The cookie, phantom set and
/// counters all start empty / zero.
///
/// # Errors
///
/// Propagates any error returned by `Repository::project_code`.
pub fn new(repo: &'a Repository, url: &str) -> Result<Self> {
    // Read the project code first so a failure aborts before anything else happens.
    let projectcode = repo.project_code()?;
    let random_seed = uuid::Uuid::new_v4().to_string();
    let client = Self {
        repo,
        url: url.trim_end_matches('/').to_string(),
        projectcode,
        servercode: hash::sha3_256_hex(random_seed.as_bytes()),
        cookie: None,
        phantoms: HashSet::new(),
        received_artifacts: 0,
        sent_artifacts: 0,
    };
    Ok(client)
}
pub fn pull(&mut self, username: &str, password: &str) -> Result<SyncStats> {
let mut total_received = 0;
let mut rounds = 0;
loop {
rounds += 1;
let response = self.pull_round(username, password)?;
let received_this_round = response
.cards
.iter()
.filter(|c| matches!(c, Card::File { .. } | Card::CFile { .. }))
.count();
total_received += received_this_round;
for card in &response.cards {
if let Card::Error { message } = card {
return Err(FossilError::SyncError(message.clone()));
}
}
let mut new_phantoms = 0;
for card in &response.cards {
if let Card::Igot {
artifact_id,
is_private,
} = card
{
if !*is_private && !self.has_artifact(artifact_id)? {
self.phantoms.insert(artifact_id.clone());
new_phantoms += 1;
}
}
}
if new_phantoms == 0 && received_this_round == 0 {
break;
}
if rounds > 100 {
return Err(FossilError::SyncError("Too many sync rounds".to_string()));
}
}
Ok(SyncStats {
received: total_received,
sent: 0,
rounds,
})
}
/// Performs one pull request/response exchange and stores whatever arrives.
///
/// The request payload carries a `client-version` pragma, a `pull` card, the
/// stored cookie (if any) and a `gimme` card for up to 200 entries of the
/// phantom set. A login card built from the serialized payload is placed ahead
/// of the payload cards in the request.
///
/// From the response, every `File`/`CFile` card is handed to `store_artifact`,
/// removed from the phantom set and counted; a `Cookie` card replaces the
/// stored cookie. The full response is returned so the caller can inspect the
/// remaining cards.
///
/// NOTE(review): the phantom is dropped and the counter bumped even when
/// `store_artifact` returns `Ok(())` without writing anything (delta cards,
/// hash mismatch) — confirm this is intended.
///
/// # Errors
///
/// Propagates errors from serializing the payload, `send_request` and
/// `store_artifact`.
fn pull_round(&mut self, username: &str, password: &str) -> Result<Message> {
    let mut body = Message::new();
    body.add(Card::Pragma {
        name: "client-version".to_string(),
        values: vec!["25000".to_string()],
    });
    body.add(Card::Pull {
        servercode: self.servercode.clone(),
        projectcode: self.projectcode.clone(),
    });
    if let Some(cookie) = self.cookie.as_ref() {
        body.add(Card::Cookie {
            payload: cookie.clone(),
        });
    }
    // Ask for at most 200 outstanding artifacts per round.
    for wanted in self.phantoms.iter().take(200) {
        body.add(Card::Gimme {
            artifact_id: wanted.clone(),
        });
    }

    // The login card is computed over the payload text and must precede the payload.
    let signed_text = body.to_text()?;
    let mut request = Message::new();
    request.add(Message::create_login(username, password, &signed_text));
    for card in body.cards {
        request.add(card);
    }

    let response = self.send_request(&request)?;
    for card in &response.cards {
        match card {
            Card::File {
                artifact_id,
                delta_source,
                content,
            } => {
                self.ingest_file_card(artifact_id, delta_source.as_deref(), content, false)?;
            }
            Card::CFile {
                artifact_id,
                delta_source,
                content,
                ..
            } => {
                self.ingest_file_card(artifact_id, delta_source.as_deref(), content, true)?;
            }
            Card::Cookie { payload } => self.cookie = Some(payload.clone()),
            _ => {}
        }
    }
    Ok(response)
}

/// Stores one received file card, clears it from the phantom set and bumps
/// the received counter.
fn ingest_file_card(
    &mut self,
    artifact_id: &str,
    delta_source: Option<&str>,
    content: &[u8],
    is_compressed: bool,
) -> Result<()> {
    self.store_artifact(artifact_id, delta_source, content, is_compressed)?;
    self.phantoms.remove(artifact_id);
    self.received_artifacts += 1;
    Ok(())
}
pub fn push(&mut self, username: &str, password: &str) -> Result<SyncStats> {
let mut total_sent = 0;
let mut rounds = 0;
let mut gimme_queue: HashSet<String> = HashSet::new();
let unclustered = self.get_unclustered_artifacts()?;
loop {
rounds += 1;
let (response, sent_this_round) =
self.push_round(username, password, &unclustered, &gimme_queue)?;
total_sent += sent_this_round;
for card in &response.cards {
if let Card::Error { message } = card {
return Err(FossilError::SyncError(message.clone()));
}
}
gimme_queue.clear();
for card in &response.cards {
if let Card::Gimme { artifact_id } = card {
gimme_queue.insert(artifact_id.clone());
}
}
if gimme_queue.is_empty() && sent_this_round == 0 {
break;
}
if rounds > 100 {
return Err(FossilError::SyncError("Too many sync rounds".to_string()));
}
}
Ok(SyncStats {
received: 0,
sent: total_sent,
rounds,
})
}
/// Performs one push request/response exchange.
///
/// The request payload carries a `client-version` pragma, a `push` card, the
/// stored cookie (if any), a `File` card for each requested artifact that can
/// be read locally, and an `igot` card (marked non-private) for up to the first
/// 500 entries of `unclustered`. A login card built from the serialized payload
/// is placed ahead of the payload cards in the request.
///
/// Artifacts are added until the accumulated content size exceeds 1 MiB; the
/// check runs before each artifact, so the artifact that crosses the limit is
/// still included. Requested artifacts whose content cannot be loaded are
/// skipped without error.
///
/// Returns the server response and the number of `File` cards sent. A `Cookie`
/// card in the response replaces the stored cookie, and `sent_artifacts` is
/// increased by the number sent.
///
/// # Errors
///
/// Propagates errors from serializing the payload and from `send_request`.
fn push_round(
    &mut self,
    username: &str,
    password: &str,
    unclustered: &[String],
    gimme_queue: &HashSet<String>,
) -> Result<(Message, usize)> {
    let mut request = Message::new();
    let mut sent_count = 0;
    let mut payload = Message::new();
    payload.add(Card::Pragma {
        name: "client-version".to_string(),
        values: vec!["25000".to_string()],
    });
    payload.add(Card::Push {
        servercode: self.servercode.clone(),
        projectcode: self.projectcode.clone(),
    });
    if let Some(ref cookie) = self.cookie {
        payload.add(Card::Cookie {
            payload: cookie.clone(),
        });
    }
    let mut total_size = 0;
    let max_size = 1024 * 1024;
    for artifact_id in gimme_queue {
        if total_size > max_size {
            break;
        }
        if let Ok(content) = self.get_artifact_content(artifact_id) {
            // Record the size first so the buffer can be moved into the card
            // instead of being cloned.
            total_size += content.len();
            payload.add(Card::File {
                artifact_id: artifact_id.clone(),
                delta_source: None,
                content,
            });
            sent_count += 1;
        }
    }
    for artifact_id in unclustered.iter().take(500) {
        payload.add(Card::Igot {
            artifact_id: artifact_id.clone(),
            is_private: false,
        });
    }
    // The login card is computed over the payload text and must precede the payload.
    let payload_text = payload.to_text()?;
    request.add(Message::create_login(username, password, &payload_text));
    for card in payload.cards {
        request.add(card);
    }
    let response = self.send_request(&request)?;
    for card in &response.cards {
        if let Card::Cookie { payload } = card {
            self.cookie = Some(payload.clone());
        }
    }
    self.sent_artifacts += sent_count;
    Ok((response, sent_count))
}
/// Runs a full pull followed by a full push and merges their statistics.
///
/// The result reports the pull's `received`, the push's `sent`, and the sum of
/// both round counts. The push is not attempted if the pull fails.
///
/// # Errors
///
/// Propagates the first error returned by `pull` or `push`.
pub fn sync(&mut self, username: &str, password: &str) -> Result<SyncStats> {
    let SyncStats {
        received,
        rounds: pull_rounds,
        ..
    } = self.pull(username, password)?;
    let SyncStats {
        sent,
        rounds: push_rounds,
        ..
    } = self.push(username, password)?;
    Ok(SyncStats {
        received,
        sent,
        rounds: pull_rounds + push_rounds,
    })
}
/// Sends `request` to the remote and decodes the reply.
///
/// Only `file://` URLs are supported. The encoded request is wrapped in an
/// HTTP/1.0 `POST /xfer` envelope and piped to the standard input of a
/// `heroforge http <path>` child process; everything after the first blank
/// line (`\r\n\r\n`) of the child's standard output is decoded as the response
/// message.
///
/// # Errors
///
/// * `FossilError::SyncError` if the URL does not start with `file://`, if the
///   child process cannot be spawned, or if the output contains no header
///   terminator.
/// * Errors from encoding the request, from pipe I/O, and from decoding the
///   response body are propagated.
fn send_request(&self, request: &Message) -> Result<Message> {
    use std::io::Write;
    use std::process::{Command, Stdio};
    let body = request.encode()?;
    let path = self.url.strip_prefix("file://").ok_or_else(|| {
        FossilError::SyncError(
            "SyncClient only supports file:// URLs. Use QUIC sync for network sync."
                .to_string(),
        )
    })?;
    let http_req = format!(
        "POST /xfer HTTP/1.0\r\nContent-Type: application/x-heroforge\r\nContent-Length: {}\r\n\r\n",
        body.len()
    );
    let mut full_req = http_req.into_bytes();
    full_req.extend_from_slice(&body);
    let mut child = Command::new("heroforge")
        .args(["http", path])
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()
        .map_err(|e| FossilError::SyncError(format!("Failed to run heroforge http: {}", e)))?;
    // NOTE(review): the whole request is written before any output is read; if the
    // child starts answering before it has consumed its input and the answer
    // exceeds the pipe buffer, both sides could block — confirm the child reads
    // all input first.
    let write_result = match child.stdin.take() {
        // The handle is dropped at the end of this arm, closing the pipe so the
        // child sees end-of-input.
        Some(mut stdin) => stdin.write_all(&full_req),
        None => Ok(()),
    };
    if write_result.is_err() {
        // A `Child` is not reaped on drop: terminate and wait for it before
        // returning so a failed write does not leave a stray process behind.
        // Failures here are ignored because the write error is what gets reported.
        let _ = child.kill();
        let _ = child.wait();
    }
    write_result?;
    let output = child.wait_with_output()?;
    // Skip the HTTP headers: the message starts right after the first blank line.
    if let Some(pos) = output.stdout.windows(4).position(|w| w == b"\r\n\r\n") {
        let body_start = pos + 4;
        let body_bytes = &output.stdout[body_start..];
        Message::decode(body_bytes)
    } else {
        Err(FossilError::SyncError("Invalid HTTP response".to_string()))
    }
}
/// Reports whether the repository database resolves `hash` to a record id.
///
/// Any error from the lookup is treated as "not present", so this function
/// never actually returns `Err`.
fn has_artifact(&self, hash: &str) -> Result<bool> {
    let lookup = self.repo.database().get_rid_by_hash(hash);
    Ok(lookup.is_ok())
}
/// Verifies a received artifact and writes it to the repository database.
///
/// * `artifact_id` — hash announced by the sender.
/// * `delta_source` — when present, the card is a delta and is skipped.
/// * `content` — card payload; decompressed first when `is_compressed` is set.
///
/// The SHA3-256 hex digest of the (decompressed) data is compared with
/// `artifact_id` over the length of the shorter of the two strings. On a match
/// the data is compressed and inserted under the computed digest together with
/// its uncompressed length.
///
/// NOTE(review): delta cards and hash mismatches both return `Ok(())` without
/// storing anything, so callers cannot tell that the artifact was dropped.
/// The comparison is prefix-only, so an empty or abbreviated `artifact_id`
/// always passes — confirm both behaviours are intended.
///
/// # Errors
///
/// Propagates errors from decompression, compression and the blob insert.
fn store_artifact(
    &self,
    artifact_id: &str,
    delta_source: Option<&str>,
    content: &[u8],
    is_compressed: bool,
) -> Result<()> {
    // Delta-encoded artifacts are not handled here.
    if delta_source.is_some() {
        return Ok(());
    }

    let data = match is_compressed {
        true => blob::decompress(content)?,
        false => content.to_vec(),
    };

    let digest = hash::sha3_256_hex(&data);
    let compared_len = artifact_id.len().min(digest.len());
    let hash_matches = artifact_id.starts_with(&digest[..compared_len]);
    if hash_matches {
        let compressed = blob::compress(&data)?;
        self.repo
            .database()
            .insert_blob(&compressed, &digest, data.len() as i64)?;
    }
    Ok(())
}
/// Loads the content of the artifact identified by `hash` from the repository
/// database via `blob::get_artifact_by_hash`.
///
/// # Errors
///
/// Propagates whatever error the blob lookup returns.
fn get_artifact_content(&self, hash: &str) -> Result<Vec<u8>> {
    let database = self.repo.database();
    blob::get_artifact_by_hash(database, hash)
}
/// Returns the `uuid` of every blob whose `rid` appears in the `unclustered`
/// table.
///
/// # Errors
///
/// Propagates errors from preparing or running the query, and from reading
/// any individual row; a row that fails to decode aborts the call rather than
/// being silently left out of the result.
fn get_unclustered_artifacts(&self) -> Result<Vec<String>> {
    let mut stmt = self
        .repo
        .database()
        .connection()
        .prepare("SELECT uuid FROM blob WHERE rid IN (SELECT rid FROM unclustered)")?;
    // Collecting into a `Result` stops at the first failing row and surfaces
    // its error instead of dropping the row.
    let hashes = stmt
        .query_map([], |row| row.get(0))?
        .collect::<std::result::Result<Vec<String>, _>>()?;
    Ok(hashes)
}
}
/// Summary of a pull, push or combined sync run.
#[derive(Debug, Clone)]
pub struct SyncStats {
    /// Number of artifacts received.
    pub received: usize,
    /// Number of artifacts sent.
    pub sent: usize,
    /// Number of request/response rounds performed.
    pub rounds: usize,
}

/// Renders the stats as `Received: R, Sent: S, Rounds: N`.
impl std::fmt::Display for SyncStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            received,
            sent,
            rounds,
        } = self;
        write!(
            f,
            "Received: {}, Sent: {}, Rounds: {}",
            received, sent, rounds
        )
    }
}