use crate::artifact::delta::apply_delta;
use crate::error::{FossilError, Result};
use crate::repo::Repository;
use quinn::{ClientConfig, Endpoint, RecvStream, SendStream, ServerConfig};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use std::collections::HashSet;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Maximum number of hashes carried by a single `HashBatch` / `NeedBatch` frame.
const HASH_CHUNK_SIZE: usize = 1_000;
/// Converts any displayable error (I/O, TLS, QUIC, address parsing, ...) into a
/// `FossilError::SyncError` carrying its message; intended for use with `map_err`.
fn err<E: std::fmt::Display>(e: E) -> FossilError {
    let message = ToString::to_string(&e);
    FossilError::SyncError(message)
}
/// One-byte frame tags used on the sync stream.
///
/// The discriminants are the on-the-wire values and must stay in step with the
/// `TryFrom<u8>` conversion below.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(u8)]
pub enum MsgType {
    /// Client greeting, followed by a length-prefixed project code.
    Hello = 1,
    /// Server accepted the greeting.
    HelloAck = 2,
    /// Server rejected the greeting (project codes differ).
    HelloNak = 3,
    /// A batch of length-prefixed artifact hashes the sender holds.
    HashBatch = 4,
    /// End of the hash list.
    HashEnd = 5,
    /// A batch of length-prefixed hashes the sender wants.
    NeedBatch = 6,
    /// End of the need list.
    NeedEnd = 7,
    /// Full artifact content, LZ4-compressed with its size prepended.
    Artifact = 8,
    /// LZ4-compressed delta, applied by the receiver to a locally stored base artifact.
    DeltaArtifact = 9,
    /// End of the artifact stream.
    Done = 10,
    /// Stored blob content sent verbatim, together with its size.
    RawBlob = 11,
    /// Stored delta content sent verbatim, together with its delta-source hash and size.
    RawDeltaBlob = 12,
}
impl TryFrom<u8> for MsgType {
type Error = FossilError;
fn try_from(v: u8) -> Result<Self> {
match v {
1 => Ok(MsgType::Hello),
2 => Ok(MsgType::HelloAck),
3 => Ok(MsgType::HelloNak),
4 => Ok(MsgType::HashBatch),
5 => Ok(MsgType::HashEnd),
6 => Ok(MsgType::NeedBatch),
7 => Ok(MsgType::NeedEnd),
8 => Ok(MsgType::Artifact),
9 => Ok(MsgType::DeltaArtifact),
10 => Ok(MsgType::Done),
11 => Ok(MsgType::RawBlob),
12 => Ok(MsgType::RawDeltaBlob),
_ => Err(FossilError::SyncError(format!("Unknown msg type: {}", v))),
}
}
}
/// Counters accumulated over one sync session.
#[derive(Default, Debug)]
pub struct SyncStats {
    /// Artifacts successfully written to the peer.
    pub artifacts_sent: usize,
    /// Artifacts read from the peer and stored locally.
    pub artifacts_received: usize,
    /// Protocol bytes written, as counted by the framing helpers.
    pub bytes_sent: usize,
    /// Protocol bytes read, as counted by the framing helpers.
    pub bytes_received: usize,
    /// Requested artifacts whose send failed and was skipped.
    pub skipped: usize,
}
/// Boxed, thread-safe progress sink. The sync entry points in this module take a generic
/// `F: Fn(&str)`; this alias is presumably for callers that need to store a callback.
pub type ProgressCallback = Box<dyn Fn(&str) + Send + Sync>;

/// Default progress reporter: prints each message on its own line to stderr.
fn default_progress(msg: &str) {
    eprintln!("{msg}");
}
fn generate_self_signed_cert() -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).map_err(err)?;
let key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()).into();
let cert_der = CertificateDer::from(cert.cert.der().to_vec());
Ok((vec![cert_der], key))
}
/// Builds the server QUIC configuration around a freshly generated self-signed
/// certificate. Clients are not asked for certificates.
fn make_server_config() -> Result<ServerConfig> {
    let (cert_chain, key_der) = generate_self_signed_cert()?;
    let tls = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert_chain, key_der)
        .map_err(err)?;
    let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(tls).map_err(err)?;
    Ok(ServerConfig::with_crypto(Arc::new(quic_tls)))
}
/// Builds the client QUIC configuration.
///
/// SECURITY: the server certificate is not verified at all (`SkipServerVerification`
/// accepts anything), so the channel is encrypted but not authenticated and a
/// man-in-the-middle cannot be detected. Verification is skipped because the server
/// presents a freshly generated self-signed certificate.
fn make_client_config() -> Result<ClientConfig> {
    let tls = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
        .with_no_client_auth();
    let quic_tls = quinn::crypto::rustls::QuicClientConfig::try_from(tls).map_err(err)?;
    Ok(ClientConfig::new(Arc::new(quic_tls)))
}
#[derive(Debug)]
struct SkipServerVerification;
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
fn verify_server_cert(
&self,
_: &CertificateDer<'_>,
_: &[CertificateDer<'_>],
_: &rustls::pki_types::ServerName<'_>,
_: &[u8],
_: rustls::pki_types::UnixTime,
) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_: &[u8],
_: &CertificateDer<'_>,
_: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_: &[u8],
_: &CertificateDer<'_>,
_: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ED25519,
]
}
}
/// Collects the `uuid` of every row in the `blob` table.
///
/// # Errors
/// Fails if the query cannot be prepared or if any row cannot be read. An unreadable row
/// is reported rather than skipped. Silently omitting a hash would make this side
/// advertise less than it holds, leaving the sync incomplete without any error.
fn get_all_hashes(repo: &Repository) -> Result<HashSet<String>> {
    let conn = repo.database().connection();
    let mut stmt = conn.prepare("SELECT uuid FROM blob")?;
    let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
    let mut hashes = HashSet::new();
    for hash in rows {
        hashes.insert(hash?);
    }
    Ok(hashes)
}
fn decompress(data: &[u8]) -> Result<Vec<u8>> {
lz4_flex::decompress_size_prepended(data).map_err(err)
}
async fn send_hashes(send: &mut SendStream, hashes: &HashSet<String>) -> Result<usize> {
let hash_vec: Vec<&String> = hashes.iter().collect();
let mut bytes = 0;
for chunk in hash_vec.chunks(HASH_CHUNK_SIZE) {
send.write_u8(MsgType::HashBatch as u8).await.map_err(err)?;
send.write_u32(chunk.len() as u32).await.map_err(err)?;
bytes += 5;
for hash in chunk {
let hash_bytes = hash.as_bytes();
send.write_u8(hash_bytes.len() as u8).await.map_err(err)?;
send.write_all(hash_bytes).await.map_err(err)?;
bytes += 1 + hash_bytes.len();
}
}
send.write_u8(MsgType::HashEnd as u8).await.map_err(err)?;
bytes += 1;
Ok(bytes)
}
/// Reads the peer's `HashBatch*` ... `HashEnd` sequence into a set.
///
/// Returns the received hashes together with the number of bytes consumed. Any frame
/// other than `HashBatch` or `HashEnd` aborts the exchange with a `SyncError`.
async fn receive_hashes(recv: &mut RecvStream) -> Result<(HashSet<String>, usize)> {
    let mut received: HashSet<String> = HashSet::new();
    let mut consumed: usize = 0;
    loop {
        let tag = recv.read_u8().await.map_err(err)?;
        consumed += 1;
        match MsgType::try_from(tag)? {
            MsgType::HashEnd => return Ok((received, consumed)),
            MsgType::HashBatch => {
                let mut remaining = recv.read_u32().await.map_err(err)?;
                consumed += 4;
                while remaining > 0 {
                    let len = usize::from(recv.read_u8().await.map_err(err)?);
                    let mut raw = vec![0u8; len];
                    recv.read_exact(&mut raw).await.map_err(err)?;
                    received.insert(String::from_utf8_lossy(&raw).into_owned());
                    consumed += 1 + len;
                    remaining -= 1;
                }
            }
            other => {
                return Err(FossilError::SyncError(format!(
                    "Unexpected msg in hash exchange: {:?}",
                    other
                )))
            }
        }
    }
}
async fn send_needs(send: &mut SendStream, needs: &[String]) -> Result<usize> {
let mut bytes = 0;
for chunk in needs.chunks(HASH_CHUNK_SIZE) {
send.write_u8(MsgType::NeedBatch as u8).await.map_err(err)?;
send.write_u32(chunk.len() as u32).await.map_err(err)?;
bytes += 5;
for hash in chunk {
let hash_bytes = hash.as_bytes();
send.write_u8(hash_bytes.len() as u8).await.map_err(err)?;
send.write_all(hash_bytes).await.map_err(err)?;
bytes += 1 + hash_bytes.len();
}
}
send.write_u8(MsgType::NeedEnd as u8).await.map_err(err)?;
bytes += 1;
Ok(bytes)
}
/// Reads the peer's `NeedBatch*` ... `NeedEnd` sequence, preserving arrival order.
///
/// Returns the requested hashes and the number of bytes consumed. Any other frame type
/// aborts with a `SyncError`.
async fn receive_needs(recv: &mut RecvStream) -> Result<(Vec<String>, usize)> {
    let mut requested: Vec<String> = Vec::new();
    let mut consumed: usize = 0;
    loop {
        let tag = MsgType::try_from(recv.read_u8().await.map_err(err)?)?;
        consumed += 1;
        if tag == MsgType::NeedEnd {
            break;
        }
        if tag != MsgType::NeedBatch {
            return Err(FossilError::SyncError(format!(
                "Unexpected msg in need exchange: {:?}",
                tag
            )));
        }
        let entries = recv.read_u32().await.map_err(err)?;
        consumed += 4;
        for _ in 0..entries {
            let len = usize::from(recv.read_u8().await.map_err(err)?);
            let mut raw = vec![0u8; len];
            recv.read_exact(&mut raw).await.map_err(err)?;
            requested.push(String::from_utf8_lossy(&raw).into_owned());
            consumed += 1 + len;
        }
    }
    Ok((requested, consumed))
}
async fn send_artifact(
send: &mut SendStream,
repo: &Repository,
hash: &str,
_remote_hashes: &HashSet<String>,
) -> Result<usize> {
let (raw_content, size, delta_source) = repo.database().get_blob_for_sync(hash)?;
let hash_bytes = hash.as_bytes();
if let Some(src_hash) = delta_source {
let src_hash_bytes = src_hash.as_bytes();
send.write_u8(MsgType::RawDeltaBlob as u8)
.await
.map_err(err)?;
send.write_u8(hash_bytes.len() as u8).await.map_err(err)?;
send.write_all(hash_bytes).await.map_err(err)?;
send.write_u8(src_hash_bytes.len() as u8)
.await
.map_err(err)?;
send.write_all(src_hash_bytes).await.map_err(err)?;
send.write_i64(size).await.map_err(err)?;
send.write_u32(raw_content.len() as u32)
.await
.map_err(err)?;
send.write_all(&raw_content).await.map_err(err)?;
Ok(1 + 1 + hash_bytes.len() + 1 + src_hash_bytes.len() + 8 + 4 + raw_content.len())
} else {
send.write_u8(MsgType::RawBlob as u8).await.map_err(err)?;
send.write_u8(hash_bytes.len() as u8).await.map_err(err)?;
send.write_all(hash_bytes).await.map_err(err)?;
send.write_i64(size).await.map_err(err)?;
send.write_u32(raw_content.len() as u32)
.await
.map_err(err)?;
send.write_all(&raw_content).await.map_err(err)?;
Ok(1 + 1 + hash_bytes.len() + 8 + 4 + raw_content.len())
}
}
async fn receive_artifact(
recv: &mut RecvStream,
repo: &Repository,
msg_type: MsgType,
) -> Result<(String, usize)> {
let hash_len = recv.read_u8().await.map_err(err)? as usize;
let mut hash_buf = vec![0u8; hash_len];
recv.read_exact(&mut hash_buf).await.map_err(err)?;
let hash = String::from_utf8_lossy(&hash_buf).to_string();
match msg_type {
MsgType::RawBlob => {
let size = recv.read_i64().await.map_err(err)?;
let content_len = recv.read_u32().await.map_err(err)? as usize;
let mut raw_content = vec![0u8; content_len];
recv.read_exact(&mut raw_content).await.map_err(err)?;
repo.database().insert_raw_blob(&raw_content, &hash, size)?;
Ok((hash, 1 + hash_len + 8 + 4 + content_len))
}
MsgType::RawDeltaBlob => {
let src_hash_len = recv.read_u8().await.map_err(err)? as usize;
let mut src_hash_buf = vec![0u8; src_hash_len];
recv.read_exact(&mut src_hash_buf).await.map_err(err)?;
let src_hash = String::from_utf8_lossy(&src_hash_buf).to_string();
let size = recv.read_i64().await.map_err(err)?;
let content_len = recv.read_u32().await.map_err(err)? as usize;
let mut raw_content = vec![0u8; content_len];
recv.read_exact(&mut raw_content).await.map_err(err)?;
let rid = repo.database().insert_raw_blob(&raw_content, &hash, size)?;
if let Ok(src_rid) = repo.database().get_rid_by_hash(&src_hash) {
repo.database().insert_delta(rid, src_rid)?;
}
Ok((hash, 1 + hash_len + 1 + src_hash_len + 8 + 4 + content_len))
}
MsgType::DeltaArtifact => {
let base_hash_len = recv.read_u8().await.map_err(err)? as usize;
let mut base_buf = vec![0u8; base_hash_len];
recv.read_exact(&mut base_buf).await.map_err(err)?;
let base_hash = String::from_utf8_lossy(&base_buf).to_string();
let size = recv.read_u32().await.map_err(err)?;
let mut compressed = vec![0u8; size as usize];
recv.read_exact(&mut compressed).await.map_err(err)?;
let delta = decompress(&compressed)?;
let base = crate::artifact::blob::get_artifact_by_hash(repo.database(), &base_hash)?;
let content = apply_delta(&base, &delta)?;
let computed = if hash.len() == 40 {
crate::hash::sha1_hex(&content)
} else {
crate::hash::sha3_256_hex(&content)
};
if computed != hash {
return Err(FossilError::SyncError(format!(
"Hash mismatch: expected {} got {}",
hash, computed
)));
}
let blob_compressed = crate::artifact::blob::compress(&content)?;
repo.database()
.insert_blob(&blob_compressed, &hash, content.len() as i64)?;
Ok((hash, 1 + hash_len + 1 + base_hash_len + 4 + size as usize))
}
MsgType::Artifact => {
let size = recv.read_u32().await.map_err(err)?;
let mut compressed = vec![0u8; size as usize];
recv.read_exact(&mut compressed).await.map_err(err)?;
let content = decompress(&compressed)?;
let computed = if hash.len() == 40 {
crate::hash::sha1_hex(&content)
} else {
crate::hash::sha3_256_hex(&content)
};
if computed != hash {
return Err(FossilError::SyncError(format!(
"Hash mismatch: expected {} got {}",
hash, computed
)));
}
let blob_compressed = crate::artifact::blob::compress(&content)?;
repo.database()
.insert_blob(&blob_compressed, &hash, content.len() as i64)?;
Ok((hash, 1 + hash_len + 4 + size as usize))
}
_ => Err(FossilError::SyncError(format!(
"Unexpected message type in receive_artifact: {:?}",
msg_type
))),
}
}
/// Server side of the QUIC sync protocol.
///
/// Holds a quinn [`Endpoint`] configured with a freshly generated self-signed certificate.
/// Each call to [`QuicServer::handle_sync`] (or its `_with_progress` variant) accepts and
/// services a single connection.
pub struct QuicServer {
    endpoint: Endpoint,
}

impl QuicServer {
    /// Binds a server endpoint to `addr`, which must parse as a [`SocketAddr`]
    /// (for example `"0.0.0.0:4433"`).
    ///
    /// # Errors
    /// Fails if `addr` does not parse, the TLS/QUIC configuration cannot be built, or the
    /// UDP socket cannot be bound.
    pub fn bind(addr: &str) -> Result<Self> {
        let addr: SocketAddr = addr.parse().map_err(err)?;
        let config = make_server_config()?;
        let endpoint = Endpoint::server(config, addr)?;
        Ok(Self { endpoint })
    }

    /// Returns the address the endpoint is actually bound to (useful after binding port 0).
    pub fn local_addr(&self) -> Result<SocketAddr> {
        Ok(self.endpoint.local_addr()?)
    }

    /// Runs [`QuicServer::handle_sync_with_progress`], printing progress to stderr.
    pub async fn handle_sync(&self, repo: &Repository, _repo_path: &PathBuf) -> Result<SyncStats> {
        self.handle_sync_with_progress(repo, _repo_path, default_progress)
            .await
    }

    /// Accepts one connection and performs a full bidirectional sync over one stream.
    ///
    /// Order of operations:
    /// 1. Receive `Hello` and the project code. Reply `HelloAck`, or `HelloNak` and fail on
    ///    a mismatch.
    /// 2. Send our hashes, then receive the client's.
    /// 3. Send the hashes we need, then receive the hashes the client needs.
    /// 4. Send the requested artifacts, followed by `Done`.
    /// 5. Receive the client's artifacts until `Done`.
    ///
    /// Artifacts that fail to send are counted in `SyncStats::skipped` rather than aborting
    /// the sync. `_repo_path` is currently unused.
    ///
    /// # Errors
    /// Connection or stream failures, protocol violations, a project-code mismatch, or a
    /// failure storing a received artifact.
    pub async fn handle_sync_with_progress<F: Fn(&str)>(
        &self,
        repo: &Repository,
        _repo_path: &PathBuf,
        progress: F,
    ) -> Result<SyncStats> {
        let mut stats = SyncStats::default();
        let project_code = repo.project_code()?;
        progress("Waiting for connection...");
        let conn = self
            .endpoint
            .accept()
            .await
            .ok_or_else(|| FossilError::SyncError("No connection".into()))?
            .await
            .map_err(err)?;
        progress("Connection accepted, waiting for stream...");
        let (mut send, mut recv) = conn.accept_bi().await.map_err(err)?;
        progress("Stream opened, receiving HELLO...");
        let msg = recv.read_u8().await.map_err(err)?;
        if MsgType::try_from(msg)? != MsgType::Hello {
            return Err(FossilError::SyncError("Expected HELLO".into()));
        }
        let code_len = recv.read_u8().await.map_err(err)? as usize;
        let mut code_buf = vec![0u8; code_len];
        recv.read_exact(&mut code_buf).await.map_err(err)?;
        let client_project_code = String::from_utf8_lossy(&code_buf).to_string();
        stats.bytes_received += 2 + code_len;
        if client_project_code != project_code {
            progress(&format!(
                "Project mismatch: client={} server={}",
                client_project_code, project_code
            ));
            send.write_u8(MsgType::HelloNak as u8).await.map_err(err)?;
            send.finish().map_err(err)?;
            // Returning drops the last handle to the connection, which closes it at once
            // and may discard the NAK before it is transmitted. Wait (best effort) until the
            // client acknowledges the stream or the connection goes away, so the client sees
            // the rejection instead of a bare connection error.
            let _ = send.stopped().await;
            return Err(FossilError::SyncError("Project mismatch".into()));
        }
        send.write_u8(MsgType::HelloAck as u8).await.map_err(err)?;
        stats.bytes_sent += 1;

        progress("HELLO accepted, collecting hashes...");
        let our_hashes = get_all_hashes(repo)?;
        progress(&format!("Sending {} hashes...", our_hashes.len()));
        stats.bytes_sent += send_hashes(&mut send, &our_hashes).await?;
        progress("Receiving client hashes...");
        let (client_hashes, recv_bytes) = receive_hashes(&mut recv).await?;
        stats.bytes_received += recv_bytes;
        progress(&format!(
            "Received {} hashes from client",
            client_hashes.len()
        ));

        let need_from_client: Vec<String> =
            client_hashes.difference(&our_hashes).cloned().collect();
        let client_needs: Vec<String> = our_hashes.difference(&client_hashes).cloned().collect();
        progress(&format!(
            "Need {} from client, client needs {} from us",
            need_from_client.len(),
            client_needs.len()
        ));
        stats.bytes_sent += send_needs(&mut send, &need_from_client).await?;
        let (confirmed_client_needs, recv_bytes) = receive_needs(&mut recv).await?;
        stats.bytes_received += recv_bytes;

        let total_to_send = confirmed_client_needs.len();
        progress(&format!("Sending {} artifacts to client...", total_to_send));
        for (i, hash) in confirmed_client_needs.iter().enumerate() {
            match send_artifact(&mut send, repo, hash, &client_hashes).await {
                Ok(bytes) => {
                    stats.bytes_sent += bytes;
                    stats.artifacts_sent += 1;
                }
                Err(_) => {
                    stats.skipped += 1;
                }
            }
            if (i + 1) % 100 == 0 || i + 1 == total_to_send {
                progress(&format!(
                    " Sent {}/{} artifacts ({} skipped, {} KB)",
                    stats.artifacts_sent + stats.skipped,
                    total_to_send,
                    stats.skipped,
                    stats.bytes_sent / 1024
                ));
            }
        }
        send.write_u8(MsgType::Done as u8).await.map_err(err)?;
        stats.bytes_sent += 1;

        progress("Receiving artifacts from client...");
        loop {
            let b = recv.read_u8().await.map_err(err)?;
            let msg_type = MsgType::try_from(b)?;
            stats.bytes_received += 1;
            if msg_type == MsgType::Done {
                break;
            }
            if !matches!(
                msg_type,
                MsgType::Artifact
                    | MsgType::DeltaArtifact
                    | MsgType::RawBlob
                    | MsgType::RawDeltaBlob
            ) {
                return Err(FossilError::SyncError(format!(
                    "Expected artifact, got {:?}",
                    msg_type
                )));
            }
            let (_, art_bytes) = receive_artifact(&mut recv, repo, msg_type).await?;
            stats.bytes_received += art_bytes;
            stats.artifacts_received += 1;
            if stats.artifacts_received % 100 == 0 {
                progress(&format!(
                    " Received {} artifacts ({} KB)",
                    stats.artifacts_received,
                    stats.bytes_received / 1024
                ));
            }
        }
        progress("Sync complete");
        send.finish().map_err(err)?;
        Ok(stats)
    }
}
pub struct QuicClient;
impl QuicClient {
pub async fn sync(repo: &Repository, _repo_path: &PathBuf, addr: &str) -> Result<SyncStats> {
Self::sync_with_progress(repo, _repo_path, addr, default_progress).await
}
pub async fn sync_with_progress<F: Fn(&str)>(
repo: &Repository,
_repo_path: &PathBuf,
addr: &str,
progress: F,
) -> Result<SyncStats> {
let mut stats = SyncStats::default();
let project_code = repo.project_code()?;
progress(&format!("Connecting to {}...", addr));
let addr: SocketAddr = addr.parse().map_err(err)?;
let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap())?;
endpoint.set_default_client_config(make_client_config()?);
let conn = endpoint
.connect(addr, "localhost")
.map_err(err)?
.await
.map_err(err)?;
progress("Connected, opening stream...");
let (mut send, mut recv) = conn.open_bi().await.map_err(err)?;
progress("Sending HELLO...");
send.write_u8(MsgType::Hello as u8).await.map_err(err)?;
let code_bytes = project_code.as_bytes();
send.write_u8(code_bytes.len() as u8).await.map_err(err)?;
send.write_all(code_bytes).await.map_err(err)?;
stats.bytes_sent += 2 + code_bytes.len();
let msg = recv.read_u8().await.map_err(err)?;
stats.bytes_received += 1;
let msg_type = MsgType::try_from(msg)?;
if msg_type == MsgType::HelloNak {
progress("Server rejected: project mismatch");
return Err(FossilError::SyncError("Project mismatch".into()));
}
if msg_type != MsgType::HelloAck {
return Err(FossilError::SyncError("Expected HELLO-ACK".into()));
}
progress("HELLO accepted, receiving server hashes...");
let (server_hashes, recv_bytes) = receive_hashes(&mut recv).await?;
stats.bytes_received += recv_bytes;
progress(&format!(
"Received {} hashes from server",
server_hashes.len()
));
let our_hashes = get_all_hashes(repo)?;
progress(&format!("Sending {} hashes...", our_hashes.len()));
stats.bytes_sent += send_hashes(&mut send, &our_hashes).await?;
let need_from_server: Vec<String> =
server_hashes.difference(&our_hashes).cloned().collect();
let (confirmed_server_needs, recv_bytes) = receive_needs(&mut recv).await?;
stats.bytes_received += recv_bytes;
progress(&format!(
"Need {} from server, server needs {} from us",
need_from_server.len(),
confirmed_server_needs.len()
));
stats.bytes_sent += send_needs(&mut send, &need_from_server).await?;
let total_to_receive = need_from_server.len();
progress(&format!(
"Receiving {} artifacts from server...",
total_to_receive
));
loop {
let b = recv.read_u8().await.map_err(err)?;
let msg_type = MsgType::try_from(b)?;
stats.bytes_received += 1;
if msg_type == MsgType::Done {
break;
}
if !matches!(
msg_type,
MsgType::Artifact
| MsgType::DeltaArtifact
| MsgType::RawBlob
| MsgType::RawDeltaBlob
) {
return Err(FossilError::SyncError(format!(
"Expected artifact, got {:?}",
msg_type
)));
}
let (_, art_bytes) = receive_artifact(&mut recv, repo, msg_type).await?;
stats.bytes_received += art_bytes;
stats.artifacts_received += 1;
if stats.artifacts_received % 100 == 0 {
progress(&format!(
" Received {}/{} artifacts ({} KB)",
stats.artifacts_received,
total_to_receive,
stats.bytes_received / 1024
));
}
}
progress(&format!(
"Received {} artifacts, sending {} to server...",
stats.artifacts_received,
confirmed_server_needs.len()
));
let total_to_send = confirmed_server_needs.len();
for (i, hash) in confirmed_server_needs.iter().enumerate() {
match send_artifact(&mut send, repo, hash, &server_hashes).await {
Ok(bytes) => {
stats.bytes_sent += bytes;
stats.artifacts_sent += 1;
}
Err(_) => {
stats.skipped += 1;
}
}
if (i + 1) % 100 == 0 || i + 1 == total_to_send {
progress(&format!(
" Sent {}/{} artifacts ({} skipped, {} KB)",
stats.artifacts_sent + stats.skipped,
total_to_send,
stats.skipped,
stats.bytes_sent / 1024
));
}
}
send.write_u8(MsgType::Done as u8).await.map_err(err)?;
stats.bytes_sent += 1;
progress("Sync complete");
send.finish().map_err(err)?;
Ok(stats)
}
}
/// Blocking wrapper around [`QuicClient::sync`]. Creates a Tokio runtime, drives the sync
/// to completion on it, and returns the resulting statistics.
///
/// # Errors
/// Fails if the runtime cannot be created or if the sync itself fails.
pub fn sync_blocking(repo: &Repository, repo_path: &PathBuf, addr: &str) -> Result<SyncStats> {
    tokio::runtime::Runtime::new()
        .map_err(err)
        .and_then(|runtime| runtime.block_on(QuicClient::sync(repo, repo_path, addr)))
}