use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::{error, info, warn};
use crate::discovery::PeerInfo;
use crate::protocol::{
generate_message_id, BlocksMessage, GetBlocksMessage, GetHeadersMessage, GetSnapshotMessage,
HeadersMessage, Message, MessageId, PeerId, SnapshotMessage,
};
use moloch_core::block::{Block, BlockHash, BlockHeader};
use moloch_core::crypto::Hash;
/// Strategy used to bring the local chain up to date with peers.
///
/// `Fast` is the default (`#[default]`). Serialized with serde in snake_case,
/// so `CatchUp` is encoded as `"catch_up"` — note that its `Display`
/// rendering is `"catch-up"`, so the two textual forms differ.
/// NOTE(review): the exact semantics of each mode are not implemented in this
/// file; `SyncManager` does not read `SyncConfig::mode` here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SyncMode {
    /// Displayed as `"full"`.
    Full,
    /// Displayed as `"fast"`; the default mode.
    #[default]
    Fast,
    /// Displayed as `"snap"`.
    Snap,
    /// Displayed as `"catch-up"` (serde name `"catch_up"`).
    CatchUp,
    /// Displayed as `"warp"`.
    Warp,
}
impl std::fmt::Display for SyncMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SyncMode::Full => write!(f, "full"),
SyncMode::Fast => write!(f, "fast"),
SyncMode::Snap => write!(f, "snap"),
SyncMode::CatchUp => write!(f, "catch-up"),
SyncMode::Warp => write!(f, "warp"),
}
}
}
/// Externally visible snapshot of synchronization progress.
///
/// `SyncManager` stores one of these internally. The rate/ETA fields are only
/// filled in on the copy returned by `SyncManager::status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyncStatus {
    /// Sync strategy reported in the status.
    pub mode: SyncMode,
    /// Current phase of the sync state machine.
    pub state: SyncState,
    /// Local chain height; set by `SyncManager::start_sync` and `handle_snapshot`.
    pub local_height: Option<u64>,
    /// Height being synced towards; set by `start_sync` and by
    /// `update_peer_height` (highest height reported by any known peer).
    pub target_height: Option<u64>,
    /// Blocks received per second since the last `start_sync`.
    pub blocks_per_second: f64,
    /// Estimated seconds until `target_height` is reached, when computable.
    pub eta_seconds: Option<u64>,
    /// Number of peers participating in the sync.
    /// NOTE(review): not written anywhere in this file — confirm who maintains it.
    pub sync_peers: usize,
    /// Wall-clock time at which the last `start_sync` was called.
    pub started_at: Option<DateTime<Utc>>,
    /// Completion percentage computed by `calculate_progress` from
    /// `local_height / target_height`.
    pub progress: f64,
}
impl Default for SyncStatus {
fn default() -> Self {
Self {
mode: SyncMode::default(),
state: SyncState::Idle,
local_height: None,
target_height: None,
blocks_per_second: 0.0,
eta_seconds: None,
sync_peers: 0,
started_at: None,
progress: 0.0,
}
}
}
impl SyncStatus {
    /// Returns `true` while headers or blocks are actively being fetched or processed.
    ///
    /// `DownloadingHeaders` counts as syncing: it is an active transfer phase,
    /// and excluding it would make `SyncManager::next_request` stop issuing
    /// requests while headers are being downloaded.
    pub fn is_syncing(&self) -> bool {
        matches!(
            self.state,
            SyncState::DownloadingHeaders
                | SyncState::Downloading
                | SyncState::Verifying
                | SyncState::Applying
        )
    }

    /// Returns `true` once the state machine has reached [`SyncState::Synced`].
    pub fn is_synced(&self) -> bool {
        self.state == SyncState::Synced
    }

    /// Recomputes `progress` as the percentage of `target_height` reached.
    ///
    /// The result is clamped to `100.0` so that a local height above the
    /// target (e.g. a target that was lowered mid-sync) never reports more
    /// than 100%. It is `0.0` when either height is unknown or the target is 0.
    pub fn calculate_progress(&mut self) {
        self.progress = match (self.local_height, self.target_height) {
            (Some(local), Some(target)) if target > 0 => {
                ((local as f64 / target as f64) * 100.0).min(100.0)
            }
            _ => 0.0,
        };
    }
}
/// Phase of the synchronization state machine.
///
/// Serialized in snake_case, which matches the `Display` output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SyncState {
    /// Nothing in progress; the initial state of `SyncStatus::default`.
    Idle,
    /// NOTE(review): not set in this file; presumably waiting for peers.
    FindingPeers,
    /// NOTE(review): not set in this file; presumably a header-only download phase.
    DownloadingHeaders,
    /// Set by `SyncManager::start_sync` and `resume_sync`.
    Downloading,
    /// NOTE(review): not set in this file.
    Verifying,
    /// NOTE(review): not set in this file.
    Applying,
    /// Set by `SyncManager::complete_range` once no ranges remain.
    Synced,
    /// NOTE(review): not set in this file.
    Failed,
    /// Set by `SyncManager::pause_sync`.
    Paused,
}
impl std::fmt::Display for SyncState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SyncState::Idle => write!(f, "idle"),
SyncState::FindingPeers => write!(f, "finding_peers"),
SyncState::DownloadingHeaders => write!(f, "downloading_headers"),
SyncState::Downloading => write!(f, "downloading"),
SyncState::Verifying => write!(f, "verifying"),
SyncState::Applying => write!(f, "applying"),
SyncState::Synced => write!(f, "synced"),
SyncState::Failed => write!(f, "failed"),
SyncState::Paused => write!(f, "paused"),
}
}
}
/// Tuning knobs for [`SyncManager`].
#[derive(Debug, Clone)]
pub struct SyncConfig {
    /// Requested sync strategy.
    /// NOTE(review): not read by `SyncManager` in this file.
    pub mode: SyncMode,
    /// Maximum number of heights covered by one range/request (used by `start_sync`).
    pub batch_size: u32,
    /// Cap on outstanding requests; `next_request` returns `None` at this limit.
    pub max_concurrent_requests: usize,
    /// Age after which `get_timed_out_requests` reports a pending request.
    pub request_timeout: Duration,
    /// Retry count above which `handle_timeout` logs a warning for a range.
    /// The range is still eligible for re-request.
    pub max_retries: u32,
    /// Minimum number of peers required to sync.
    /// NOTE(review): not read in this file.
    pub min_peers: usize,
    /// How far the best peer must be ahead of the local height before
    /// `needs_sync` reports `true`.
    pub sync_threshold: u64,
    /// Optional checkpoint to sync from.
    /// NOTE(review): not read in this file.
    pub checkpoint: Option<Checkpoint>,
    /// When `true`, `next_request` sends `GetHeaders`; otherwise `GetBlocks`.
    pub header_first: bool,
}
impl Default for SyncConfig {
fn default() -> Self {
Self {
mode: SyncMode::Fast,
batch_size: 100,
max_concurrent_requests: 4,
request_timeout: Duration::from_secs(30),
max_retries: 3,
min_peers: 1,
sync_threshold: 10,
checkpoint: None,
header_first: true,
}
}
}
/// A recorded chain position: block height, block hash and MMR root.
///
/// NOTE(review): presumably a trusted starting point for sync; it is carried
/// in `SyncConfig::checkpoint` but not consulted anywhere in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Height of the checkpointed block.
    pub height: u64,
    /// Hash of the checkpointed block.
    pub hash: BlockHash,
    /// MMR root at the checkpointed height.
    pub mmr_root: Hash,
    /// When the checkpoint was created.
    pub created_at: DateTime<Utc>,
}
/// Bookkeeping for a request that has been sent but not yet answered.
#[derive(Debug)]
// `id`, `kind` and `retries` are recorded but not read yet, so silence the lint.
#[allow(dead_code)]
struct PendingRequest {
    /// Message id the response is expected to echo back.
    id: MessageId,
    /// Peer the request was sent to.
    peer: PeerId,
    /// What was requested.
    kind: RequestKind,
    /// Send time, compared against `SyncConfig::request_timeout`.
    sent_at: Instant,
    /// Retry count of the range at the time of sending.
    retries: u32,
}
/// Kind and parameters of an outstanding request.
#[derive(Debug, Clone)]
// Payload fields are never read and `Snapshot` is never constructed in this file.
#[allow(dead_code)]
enum RequestKind {
    /// `count` headers starting at height `start`.
    Headers { start: u64, count: u32 },
    /// `count` blocks starting at height `start`.
    Blocks { start: u64, count: u32 },
    /// A snapshot, optionally at a specific height.
    Snapshot { height: Option<u64> },
}
/// A contiguous batch of heights still to be synced.
///
/// Requests for a range ask for `end - start` items beginning at `start`,
/// and consecutive ranges share a boundary, so `end` is exclusive.
#[derive(Debug, Clone)]
struct SyncRange {
    /// First height in the range.
    start: u64,
    /// One past the last height in the range.
    end: u64,
    /// Peer currently assigned to fetch this range, if any.
    peer: Option<PeerId>,
    /// Id of the in-flight request for this range, if any.
    request_id: Option<MessageId>,
    /// Number of times a request for this range has timed out.
    retries: u32,
}
/// Coordinates block/header download from peers.
///
/// Each piece of mutable state sits behind its own `tokio::sync::RwLock`.
#[derive(Debug)]
pub struct SyncManager {
    /// Settings passed to `new`.
    config: SyncConfig,
    /// Stored status; progress/rate are computed on the copy returned by `status()`.
    status: RwLock<SyncStatus>,
    /// Outstanding requests keyed by message id.
    pending_requests: RwLock<HashMap<MessageId, PendingRequest>>,
    /// Height ranges not yet completed; filled by `start_sync`, drained by `complete_range`.
    sync_ranges: RwLock<VecDeque<SyncRange>>,
    /// Received blocks keyed by height until taken by `get_ready_blocks`.
    block_buffer: RwLock<HashMap<u64, Block>>,
    /// Received headers keyed by height.
    /// NOTE(review): never drained in this file.
    header_buffer: RwLock<HashMap<u64, BlockHeader>>,
    /// Last height reported by each peer via `update_peer_height`.
    peer_heights: RwLock<HashMap<PeerId, u64>>,
    /// Blocks received since the last `start_sync` (incremented by `handle_blocks`).
    synced_count: std::sync::atomic::AtomicU64,
    /// When the last `start_sync` ran; used to compute the transfer rate.
    sync_start: RwLock<Option<Instant>>,
}
impl SyncManager {
pub fn new(config: SyncConfig) -> Self {
Self {
config,
status: RwLock::new(SyncStatus::default()),
pending_requests: RwLock::new(HashMap::new()),
sync_ranges: RwLock::new(VecDeque::new()),
block_buffer: RwLock::new(HashMap::new()),
header_buffer: RwLock::new(HashMap::new()),
peer_heights: RwLock::new(HashMap::new()),
synced_count: std::sync::atomic::AtomicU64::new(0),
sync_start: RwLock::new(None),
}
}
pub fn config(&self) -> &SyncConfig {
&self.config
}
pub async fn status(&self) -> SyncStatus {
let mut status = self.status.read().await.clone();
status.calculate_progress();
self.update_rate(&mut status).await;
status
}
async fn update_rate(&self, status: &mut SyncStatus) {
let sync_start = self.sync_start.read().await;
if let Some(start) = *sync_start {
let elapsed = start.elapsed().as_secs_f64();
if elapsed > 0.0 {
let synced = self.synced_count.load(std::sync::atomic::Ordering::Relaxed);
status.blocks_per_second = synced as f64 / elapsed;
if let (Some(local), Some(target)) = (status.local_height, status.target_height) {
if status.blocks_per_second > 0.0 && target > local {
let remaining = target - local;
status.eta_seconds =
Some((remaining as f64 / status.blocks_per_second) as u64);
}
}
}
}
}
pub async fn needs_sync(&self, local_height: Option<u64>) -> bool {
let peer_heights = self.peer_heights.read().await;
if peer_heights.is_empty() {
return false;
}
let max_peer_height = peer_heights.values().copied().max().unwrap_or(0);
let local = local_height.unwrap_or(0);
max_peer_height > local + self.config.sync_threshold
}
pub async fn update_peer_height(&self, peer: PeerId, height: u64) {
let mut heights = self.peer_heights.write().await;
heights.insert(peer, height);
let max_height = heights.values().copied().max().unwrap_or(0);
let mut status = self.status.write().await;
status.target_height = Some(max_height);
}
pub async fn remove_peer(&self, peer: &PeerId) {
self.peer_heights.write().await.remove(peer);
let mut pending = self.pending_requests.write().await;
let to_remove: Vec<_> = pending
.iter()
.filter(|(_, req)| &req.peer == peer)
.map(|(id, _)| *id)
.collect();
for id in to_remove {
pending.remove(&id);
}
let mut ranges = self.sync_ranges.write().await;
for range in ranges.iter_mut() {
if range.peer.as_ref() == Some(peer) {
range.peer = None;
range.request_id = None;
}
}
}
pub async fn start_sync(&self, from_height: u64, to_height: u64) {
info!("Starting sync from {} to {}", from_height, to_height);
let mut status = self.status.write().await;
status.state = SyncState::Downloading;
status.local_height = Some(from_height);
status.target_height = Some(to_height);
status.started_at = Some(Utc::now());
drop(status);
let mut ranges = self.sync_ranges.write().await;
ranges.clear();
let batch_size = self.config.batch_size as u64;
let mut start = from_height;
while start < to_height {
let end = (start + batch_size).min(to_height);
ranges.push_back(SyncRange {
start,
end,
peer: None,
request_id: None,
retries: 0,
});
start = end;
}
*self.sync_start.write().await = Some(Instant::now());
self.synced_count
.store(0, std::sync::atomic::Ordering::SeqCst);
}
pub async fn pause_sync(&self) {
let mut status = self.status.write().await;
if status.is_syncing() {
status.state = SyncState::Paused;
}
}
pub async fn resume_sync(&self) {
let mut status = self.status.write().await;
if status.state == SyncState::Paused {
status.state = SyncState::Downloading;
}
}
pub async fn next_request(&self, available_peers: &[PeerInfo]) -> Option<(PeerId, Message)> {
let status = self.status.read().await;
if !status.is_syncing() {
return None;
}
drop(status);
let pending = self.pending_requests.read().await;
if pending.len() >= self.config.max_concurrent_requests {
return None;
}
drop(pending);
let mut ranges = self.sync_ranges.write().await;
for range in ranges.iter_mut() {
if range.peer.is_some() {
continue; }
let heights = self.peer_heights.read().await;
let suitable_peer = available_peers
.iter()
.find(|p| heights.get(&p.id).map(|h| *h >= range.end).unwrap_or(false));
if let Some(peer) = suitable_peer {
let message_id = generate_message_id();
let message = if self.config.header_first {
Message::GetHeaders(GetHeadersMessage {
id: message_id,
start_height: range.start,
count: (range.end - range.start) as u32,
})
} else {
Message::GetBlocks(GetBlocksMessage {
id: message_id,
start_height: range.start,
count: (range.end - range.start) as u32,
})
};
range.peer = Some(peer.id.clone());
range.request_id = Some(message_id);
let mut pending = self.pending_requests.write().await;
pending.insert(
message_id,
PendingRequest {
id: message_id,
peer: peer.id.clone(),
kind: if self.config.header_first {
RequestKind::Headers {
start: range.start,
count: (range.end - range.start) as u32,
}
} else {
RequestKind::Blocks {
start: range.start,
count: (range.end - range.start) as u32,
}
},
sent_at: Instant::now(),
retries: range.retries,
},
);
return Some((peer.id.clone(), message));
}
}
None
}
pub async fn handle_blocks(&self, response: BlocksMessage) -> Result<Vec<Block>, SyncError> {
let mut pending = self.pending_requests.write().await;
let request = pending.remove(&response.request_id);
drop(pending);
if request.is_none() {
return Err(SyncError::UnexpectedResponse(response.request_id));
}
let mut buffer = self.block_buffer.write().await;
let mut received = Vec::new();
for block in response.blocks {
let height = block.header.height;
buffer.insert(height, block.clone());
received.push(block);
}
self.synced_count
.fetch_add(received.len() as u64, std::sync::atomic::Ordering::SeqCst);
Ok(received)
}
pub async fn handle_headers(
&self,
response: HeadersMessage,
) -> Result<Vec<BlockHeader>, SyncError> {
let mut pending = self.pending_requests.write().await;
let request = pending.remove(&response.request_id);
drop(pending);
if request.is_none() {
return Err(SyncError::UnexpectedResponse(response.request_id));
}
let mut buffer = self.header_buffer.write().await;
let mut received = Vec::new();
for header in response.headers {
let height = header.height;
buffer.insert(height, header.clone());
received.push(header);
}
Ok(received)
}
pub async fn handle_snapshot(&self, response: SnapshotMessage) -> Result<(), SyncError> {
let mut pending = self.pending_requests.write().await;
let request = pending.remove(&response.request_id);
drop(pending);
if request.is_none() {
return Err(SyncError::UnexpectedResponse(response.request_id));
}
let mut status = self.status.write().await;
status.local_height = Some(response.height);
info!(
"Received snapshot at height {} with {} events",
response.height, response.event_count
);
Ok(())
}
pub async fn get_ready_blocks(&self, current_height: u64) -> Vec<Block> {
let mut buffer = self.block_buffer.write().await;
let mut ready = Vec::new();
let mut next_height = current_height + 1;
while let Some(block) = buffer.remove(&next_height) {
ready.push(block);
next_height += 1;
}
ready
}
pub async fn complete_range(&self, start: u64, end: u64) {
let mut ranges = self.sync_ranges.write().await;
ranges.retain(|r| !(r.start == start && r.end == end));
if ranges.is_empty() {
drop(ranges);
let mut status = self.status.write().await;
status.state = SyncState::Synced;
info!("Sync complete");
}
}
pub async fn handle_timeout(&self, request_id: MessageId) {
let mut pending = self.pending_requests.write().await;
if let Some(request) = pending.remove(&request_id) {
warn!("Request {} to {} timed out", request_id, request.peer);
let mut ranges = self.sync_ranges.write().await;
for range in ranges.iter_mut() {
if range.request_id == Some(request_id) {
range.peer = None;
range.request_id = None;
range.retries += 1;
if range.retries > self.config.max_retries {
warn!("Range {}-{} exceeded max retries", range.start, range.end);
}
break;
}
}
}
}
pub async fn get_timed_out_requests(&self) -> Vec<MessageId> {
let pending = self.pending_requests.read().await;
let now = Instant::now();
pending
.iter()
.filter(|(_, req)| now.duration_since(req.sent_at) > self.config.request_timeout)
.map(|(id, _)| *id)
.collect()
}
pub fn create_snapshot_request(&self, height: Option<u64>) -> (MessageId, Message) {
let id = generate_message_id();
let msg = Message::GetSnapshot(GetSnapshotMessage { id, height });
(id, msg)
}
pub fn create_blocks_request(&self, start: u64, count: u32) -> (MessageId, Message) {
let id = generate_message_id();
let msg = Message::GetBlocks(GetBlocksMessage {
id,
start_height: start,
count,
});
(id, msg)
}
pub fn create_headers_request(&self, start: u64, count: u32) -> (MessageId, Message) {
let id = generate_message_id();
let msg = Message::GetHeaders(GetHeadersMessage {
id,
start_height: start,
count,
});
(id, msg)
}
pub async fn stats(&self) -> SyncStats {
let status = self.status.read().await;
let pending = self.pending_requests.read().await;
let ranges = self.sync_ranges.read().await;
let block_buffer = self.block_buffer.read().await;
let header_buffer = self.header_buffer.read().await;
let peer_heights = self.peer_heights.read().await;
SyncStats {
state: status.state,
mode: status.mode,
local_height: status.local_height,
target_height: status.target_height,
pending_requests: pending.len(),
remaining_ranges: ranges.len(),
buffered_blocks: block_buffer.len(),
buffered_headers: header_buffer.len(),
known_peers: peer_heights.len(),
blocks_synced: self.synced_count.load(std::sync::atomic::Ordering::Relaxed),
blocks_per_second: status.blocks_per_second,
}
}
}
/// Point-in-time counters produced by `SyncManager::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyncStats {
    /// Current sync state.
    pub state: SyncState,
    /// Sync mode recorded in the status.
    pub mode: SyncMode,
    /// Local chain height, if known.
    pub local_height: Option<u64>,
    /// Target height, if known.
    pub target_height: Option<u64>,
    /// Number of outstanding requests.
    pub pending_requests: usize,
    /// Number of ranges not yet completed.
    pub remaining_ranges: usize,
    /// Number of blocks waiting in the block buffer.
    pub buffered_blocks: usize,
    /// Number of headers waiting in the header buffer.
    pub buffered_headers: usize,
    /// Number of peers with a reported height.
    pub known_peers: usize,
    /// Blocks received since the last `start_sync`.
    pub blocks_synced: u64,
    /// Transfer rate as reported by `SyncManager::stats`.
    pub blocks_per_second: f64,
}
/// Errors produced by the sync subsystem.
///
/// Only `UnexpectedResponse` is constructed in this file. The other variants
/// are presumably raised by callers — NOTE(review): confirm.
#[derive(Debug, thiserror::Error)]
pub enum SyncError {
    /// A response arrived whose request id is not pending.
    #[error("unexpected response for request {0}")]
    UnexpectedResponse(MessageId),
    /// A request was not answered in time.
    #[error("request timed out: {0}")]
    Timeout(MessageId),
    /// Fewer peers are available than required.
    #[error("not enough peers: have {have}, need {need}")]
    NotEnoughPeers { have: usize, need: usize },
    /// A block at the given height was rejected.
    #[error("invalid block at height {0}")]
    InvalidBlock(u64),
    /// The received chain diverges from the expected one at the given height.
    #[error("chain mismatch at height {0}")]
    ChainMismatch(u64),
    /// The sync was cancelled.
    #[error("sync cancelled")]
    Cancelled,
    /// A peer-related failure with a free-form description.
    #[error("peer error: {0}")]
    PeerError(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use moloch_core::crypto::SecretKey;

    /// Returns a fresh random peer identity.
    fn test_peer_id() -> PeerId {
        crate::protocol::PeerId::new(SecretKey::generate().public_key())
    }

    /// Returns a connected peer that reports the given chain height.
    fn test_peer_info(height: u64) -> PeerInfo {
        use crate::discovery::{DiscoverySource, PeerMetadata, PeerScore, PeerState};
        PeerInfo {
            id: test_peer_id(),
            addresses: vec!["127.0.0.1:8000".parse().unwrap()],
            state: PeerState::Connected,
            score: PeerScore::default(),
            first_seen: Utc::now(),
            last_seen: Some(Utc::now()),
            connection_successes: 1,
            connection_failures: 0,
            source: DiscoverySource::Static,
            metadata: PeerMetadata {
                height: Some(height),
                ..Default::default()
            },
        }
    }

    #[test]
    fn test_sync_mode_display() {
        assert_eq!(format!("{}", SyncMode::Fast), "fast");
        assert_eq!(format!("{}", SyncMode::Snap), "snap");
    }

    #[test]
    fn test_sync_state_display() {
        assert_eq!(format!("{}", SyncState::Downloading), "downloading");
        assert_eq!(format!("{}", SyncState::Synced), "synced");
    }

    #[test]
    fn test_sync_status_default() {
        let status = SyncStatus::default();
        assert!(!status.is_syncing());
        assert!(!status.is_synced());
        assert_eq!(status.progress, 0.0);
    }

    #[test]
    fn test_sync_status_progress() {
        let mut status = SyncStatus {
            local_height: Some(50),
            target_height: Some(100),
            ..Default::default()
        };
        status.calculate_progress();
        assert_eq!(status.progress, 50.0);
    }

    #[tokio::test]
    async fn test_sync_manager_creation() {
        let config = SyncConfig::default();
        let manager = SyncManager::new(config);
        let status = manager.status().await;
        assert_eq!(status.state, SyncState::Idle);
    }

    #[tokio::test]
    async fn test_sync_manager_peer_heights() {
        let config = SyncConfig::default();
        let manager = SyncManager::new(config);
        let peer1 = test_peer_id();
        let peer2 = test_peer_id();
        manager.update_peer_height(peer1.clone(), 100).await;
        manager.update_peer_height(peer2.clone(), 200).await;
        let status = manager.status().await;
        assert_eq!(status.target_height, Some(200));
        manager.remove_peer(&peer2).await;
        let status = manager.status().await;
        // `remove_peer` does not recompute the target, so the previous maximum is kept.
        assert_eq!(status.target_height, Some(200));
    }

    #[tokio::test]
    async fn test_sync_manager_needs_sync() {
        let config = SyncConfig {
            sync_threshold: 10,
            ..Default::default()
        };
        let manager = SyncManager::new(config);
        assert!(!manager.needs_sync(Some(50)).await);
        let peer = test_peer_id();
        manager.update_peer_height(peer, 100).await;
        assert!(manager.needs_sync(Some(50)).await);
        assert!(!manager.needs_sync(Some(95)).await);
    }

    #[tokio::test]
    async fn test_sync_manager_start_sync() {
        let config = SyncConfig::default();
        let manager = SyncManager::new(config);
        manager.start_sync(0, 1000).await;
        let status = manager.status().await;
        assert_eq!(status.state, SyncState::Downloading);
        assert_eq!(status.local_height, Some(0));
        assert_eq!(status.target_height, Some(1000));
        let ranges = manager.sync_ranges.read().await;
        assert!(!ranges.is_empty());
    }

    #[tokio::test]
    async fn test_start_sync_splits_into_batches() {
        let config = SyncConfig {
            batch_size: 100,
            ..Default::default()
        };
        let manager = SyncManager::new(config);
        manager.start_sync(0, 250).await;
        let ranges = manager.sync_ranges.read().await;
        let bounds: Vec<(u64, u64)> = ranges.iter().map(|r| (r.start, r.end)).collect();
        assert_eq!(bounds, vec![(0, 100), (100, 200), (200, 250)]);
        assert!(ranges.iter().all(|r| r.peer.is_none() && r.retries == 0));
    }

    #[tokio::test]
    async fn test_sync_manager_pause_resume() {
        let config = SyncConfig::default();
        let manager = SyncManager::new(config);
        manager.start_sync(0, 100).await;
        manager.pause_sync().await;
        let status = manager.status().await;
        assert_eq!(status.state, SyncState::Paused);
        manager.resume_sync().await;
        let status = manager.status().await;
        assert_eq!(status.state, SyncState::Downloading);
    }

    #[tokio::test]
    async fn test_pause_and_resume_are_noops_when_idle() {
        let manager = SyncManager::new(SyncConfig::default());
        manager.pause_sync().await;
        assert_eq!(manager.status().await.state, SyncState::Idle);
        manager.resume_sync().await;
        assert_eq!(manager.status().await.state, SyncState::Idle);
    }

    #[tokio::test]
    async fn test_sync_manager_next_request() {
        let config = SyncConfig::default();
        let manager = SyncManager::new(config);
        let request = manager.next_request(&[]).await;
        assert!(request.is_none());
        manager.start_sync(0, 100).await;
        let request = manager.next_request(&[]).await;
        assert!(request.is_none());
        let peer = test_peer_info(100);
        manager.update_peer_height(peer.id.clone(), 100).await;
        let request = manager.next_request(&[peer]).await;
        assert!(request.is_some());
    }

    #[tokio::test]
    async fn test_next_request_respects_concurrency_limit() {
        let config = SyncConfig {
            max_concurrent_requests: 1,
            ..Default::default()
        };
        let manager = SyncManager::new(config);
        manager.start_sync(0, 300).await;
        let peer = test_peer_info(300);
        manager.update_peer_height(peer.id.clone(), 300).await;
        let peers = [peer];
        assert!(manager.next_request(&peers).await.is_some());
        // Two ranges are still unassigned, but the single slot is in use.
        assert!(manager.next_request(&peers).await.is_none());
        assert_eq!(manager.pending_requests.read().await.len(), 1);
    }

    #[tokio::test]
    async fn test_next_request_message_kind_follows_header_first() {
        for (header_first, expect_headers) in [(true, true), (false, false)] {
            let config = SyncConfig {
                header_first,
                ..Default::default()
            };
            let manager = SyncManager::new(config);
            manager.start_sync(0, 100).await;
            let peer = test_peer_info(100);
            manager.update_peer_height(peer.id.clone(), 100).await;
            let (_, message) = manager
                .next_request(&[peer])
                .await
                .expect("a request should be produced");
            if expect_headers {
                assert!(matches!(message, Message::GetHeaders(_)));
            } else {
                assert!(matches!(message, Message::GetBlocks(_)));
            }
        }
    }

    #[tokio::test]
    async fn test_remove_peer_releases_assignments() {
        let manager = SyncManager::new(SyncConfig::default());
        manager.start_sync(0, 100).await;
        let peer = test_peer_info(100);
        let peer_id = peer.id.clone();
        manager.update_peer_height(peer_id.clone(), 100).await;
        assert!(manager.next_request(&[peer]).await.is_some());
        assert_eq!(manager.sync_ranges.read().await[0].peer, Some(peer_id.clone()));

        manager.remove_peer(&peer_id).await;

        assert!(manager.pending_requests.read().await.is_empty());
        let ranges = manager.sync_ranges.read().await;
        assert!(ranges[0].peer.is_none());
        assert!(ranges[0].request_id.is_none());
    }

    #[tokio::test]
    async fn test_sync_manager_handle_blocks() {
        let config = SyncConfig::default();
        let manager = SyncManager::new(config);
        let response = BlocksMessage {
            request_id: 999,
            blocks: vec![],
            has_more: false,
        };
        let result = manager.handle_blocks(response).await;
        assert!(matches!(result, Err(SyncError::UnexpectedResponse(999))));
    }

    #[tokio::test]
    async fn test_get_ready_blocks_empty_buffer() {
        let manager = SyncManager::new(SyncConfig::default());
        assert!(manager.get_ready_blocks(0).await.is_empty());
    }

    #[tokio::test]
    async fn test_sync_manager_create_requests() {
        let config = SyncConfig::default();
        let manager = SyncManager::new(config);
        let (id1, msg1) = manager.create_blocks_request(0, 100);
        assert!(matches!(msg1, Message::GetBlocks(_)));
        let (id2, msg2) = manager.create_headers_request(100, 50);
        assert!(matches!(msg2, Message::GetHeaders(_)));
        let (id3, msg3) = manager.create_snapshot_request(Some(500));
        assert!(matches!(msg3, Message::GetSnapshot(_)));
        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_ne!(id1, id3);
    }

    #[tokio::test]
    async fn test_sync_manager_stats() {
        let config = SyncConfig::default();
        let manager = SyncManager::new(config);
        manager.start_sync(0, 100).await;
        let stats = manager.stats().await;
        assert_eq!(stats.state, SyncState::Downloading);
        assert!(stats.remaining_ranges > 0);
        assert_eq!(stats.buffered_blocks, 0);
    }

    #[tokio::test]
    async fn test_sync_manager_timeout() {
        let config = SyncConfig::default();
        let manager = SyncManager::new(config);
        let request_id = generate_message_id();
        {
            let mut pending = manager.pending_requests.write().await;
            pending.insert(
                request_id,
                PendingRequest {
                    id: request_id,
                    peer: test_peer_id(),
                    kind: RequestKind::Blocks {
                        start: 0,
                        count: 100,
                    },
                    sent_at: Instant::now() - Duration::from_secs(60),
                    retries: 0,
                },
            );
        }
        let timed_out = manager.get_timed_out_requests().await;
        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], request_id);
        manager.handle_timeout(request_id).await;
        let pending = manager.pending_requests.read().await;
        assert!(!pending.contains_key(&request_id));
    }

    #[tokio::test]
    async fn test_handle_timeout_unassigns_range_and_counts_retry() {
        let manager = SyncManager::new(SyncConfig::default());
        manager.start_sync(0, 100).await;
        let peer = test_peer_info(100);
        manager.update_peer_height(peer.id.clone(), 100).await;
        assert!(manager.next_request(&[peer]).await.is_some());
        let request_id = manager.sync_ranges.read().await[0]
            .request_id
            .expect("range should be assigned");

        manager.handle_timeout(request_id).await;

        let ranges = manager.sync_ranges.read().await;
        assert!(ranges[0].peer.is_none());
        assert!(ranges[0].request_id.is_none());
        assert_eq!(ranges[0].retries, 1);
    }

    #[tokio::test]
    async fn test_sync_manager_complete_range() {
        let config = SyncConfig::default();
        let manager = SyncManager::new(config);
        manager.start_sync(0, 100).await;
        let initial_count = manager.sync_ranges.read().await.len();
        assert!(initial_count > 0);
        manager.complete_range(0, 100).await;
        let status = manager.status().await;
        assert_eq!(status.state, SyncState::Synced);
    }

    #[test]
    fn test_checkpoint() {
        let checkpoint = Checkpoint {
            height: 1000,
            hash: moloch_core::block::BlockHash(moloch_core::crypto::hash(b"block")),
            mmr_root: moloch_core::crypto::hash(b"mmr"),
            created_at: Utc::now(),
        };
        assert_eq!(checkpoint.height, 1000);
    }
}