use super::BlockHeader;
use crate::Result;
use std::collections::VecDeque;
/// Phase of the header-synchronisation state machine driven by `SyncManager`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncState {
    /// Initial state: no sync has been requested yet.
    Idle,
    /// A sync target is set and more headers are still needed to reach it.
    DownloadingHeaders,
    /// A batch of headers is being appended to the chain.
    Verifying,
    /// The chain tip has reached (or passed) the sync target.
    Synced,
}
/// Bounded, in-memory window of recently imported block headers.
///
/// Headers are kept oldest (front) to newest (back). Each pushed header must
/// link to the current tip via `parent_hash` and carry the next block number.
/// The window is capped at `max_headers` entries, with the oldest evicted
/// first. It is also trimmed whenever a new finalized block is recorded.
pub struct HeaderChain {
    /// Headers ordered oldest (front) to newest (back).
    headers: VecDeque<BlockHeader>,
    /// Upper bound on `headers.len()`; older headers are evicted beyond it.
    max_headers: usize,
    /// Last block number passed to `set_finalized` (0 until then).
    finalized: u64,
}

impl HeaderChain {
    /// How many block numbers below the finalized block `set_finalized`
    /// keeps in the window.
    const FINALIZED_RETENTION: u64 = 10;

    /// Creates an empty chain that holds at most `max_headers` headers.
    pub fn new(max_headers: usize) -> Self {
        Self {
            headers: VecDeque::new(),
            max_headers,
            finalized: 0,
        }
    }

    /// Returns the newest header (the chain tip), if any.
    pub fn latest(&self) -> Option<&BlockHeader> {
        self.headers.back()
    }

    /// Returns the last block number recorded via `set_finalized` (0 initially).
    pub fn finalized(&self) -> u64 {
        self.finalized
    }

    /// Appends `header` as the new chain tip.
    ///
    /// When the chain is non-empty, `header` must satisfy two conditions:
    /// its `parent_hash` equals the current tip's hash, and its number is
    /// exactly one greater than the tip's. If the window then exceeds
    /// `max_headers`, the oldest headers are dropped.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::NodeSync` when the parent hash or the block
    /// number does not follow the current tip. The chain is left unchanged
    /// in that case.
    pub fn push(&mut self, header: BlockHeader) -> Result<()> {
        if let Some(latest) = self.headers.back() {
            if header.parent_hash != latest.hash {
                return Err(crate::Error::NodeSync(
                    "Parent hash mismatch".to_string(),
                ));
            }
            // A parent-linked header must also advance the height by exactly
            // one. Otherwise the number-based lookups (`get`, `contains`) and
            // finality pruning would operate on an inconsistent window.
            // `checked_add` avoids overflow at u64::MAX.
            if latest.number.checked_add(1) != Some(header.number) {
                return Err(crate::Error::NodeSync(
                    "Block number mismatch".to_string(),
                ));
            }
        }
        self.headers.push_back(header);
        while self.headers.len() > self.max_headers {
            self.headers.pop_front();
        }
        Ok(())
    }

    /// Records `block_number` as finalized and prunes old headers.
    ///
    /// Headers whose number is more than `FINALIZED_RETENTION` below
    /// `block_number` are removed from the front of the window.
    pub fn set_finalized(&mut self, block_number: u64) {
        self.finalized = block_number;
        // Loop-invariant cutoff; saturating so small block numbers prune nothing.
        let cutoff = block_number.saturating_sub(Self::FINALIZED_RETENTION);
        while let Some(oldest) = self.headers.front() {
            if oldest.number >= cutoff {
                break;
            }
            self.headers.pop_front();
        }
    }

    /// Returns the header with block number `number`, if it is still in the window.
    pub fn get(&self, number: u64) -> Option<&BlockHeader> {
        self.headers.iter().find(|h| h.number == number)
    }

    /// Returns the header whose hash equals `hash`, if it is still in the window.
    pub fn get_by_hash(&self, hash: &str) -> Option<&BlockHeader> {
        self.headers.iter().find(|h| h.hash == hash)
    }

    /// Returns `true` if a header with block number `number` is in the window.
    pub fn contains(&self, number: u64) -> bool {
        self.get(number).is_some()
    }

    /// Number of headers currently held.
    pub fn len(&self) -> usize {
        self.headers.len()
    }

    /// Returns `true` if the window holds no headers.
    pub fn is_empty(&self) -> bool {
        self.headers.is_empty()
    }
}
/// Drives header synchronisation towards a target block and reports progress.
pub struct SyncManager {
    /// Window of headers received so far.
    chain: HeaderChain,
    /// Current phase of the sync state machine.
    state: SyncState,
    /// Block number the current sync aims for (0 until `start_sync` is called).
    target: u64,
}

impl SyncManager {
    /// Header window size used by `SyncManager::new`.
    pub const DEFAULT_MAX_HEADERS: usize = 1000;

    /// Creates an idle manager with a header window of `DEFAULT_MAX_HEADERS`.
    pub fn new() -> Self {
        Self::with_max_headers(Self::DEFAULT_MAX_HEADERS)
    }

    /// Creates an idle manager whose header window holds at most `max_headers` headers.
    pub fn with_max_headers(max_headers: usize) -> Self {
        Self {
            chain: HeaderChain::new(max_headers),
            state: SyncState::Idle,
            target: 0,
        }
    }

    /// Returns the current sync state.
    pub fn state(&self) -> SyncState {
        self.state
    }

    /// Fraction of the target height reached so far, in `[0.0, 1.0]`.
    ///
    /// Returns `1.0` when no target is set (target 0). The ratio is clamped
    /// because a batch passed to `process_headers` may extend the tip past
    /// the target.
    pub fn progress(&self) -> f64 {
        if self.target == 0 {
            return 1.0;
        }
        let current = self.chain.latest().map_or(0, |h| h.number);
        (current as f64 / self.target as f64).min(1.0)
    }

    /// Begins syncing towards block `target`.
    pub fn start_sync(&mut self, target: u64) {
        self.target = target;
        self.state = SyncState::DownloadingHeaders;
    }

    /// Appends `headers` (oldest first) to the chain and updates the state.
    ///
    /// After the call, the state is `SyncState::Synced` if the chain tip has
    /// reached the target. Otherwise it is `SyncState::DownloadingHeaders`,
    /// which includes the case where no header has been received yet.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `HeaderChain::push`. Headers
    /// before the offending one stay in the chain. The state is recomputed
    /// even on error, so the manager is never left stuck in
    /// `SyncState::Verifying`.
    pub fn process_headers(&mut self, headers: Vec<BlockHeader>) -> Result<()> {
        self.state = SyncState::Verifying;
        let outcome = headers
            .into_iter()
            .try_for_each(|header| self.chain.push(header));
        self.refresh_state();
        outcome
    }

    /// Sets the state from the chain tip relative to the target.
    fn refresh_state(&mut self) {
        self.state = match self.chain.latest() {
            Some(tip) if tip.number >= self.target => SyncState::Synced,
            _ => SyncState::DownloadingHeaders,
        };
    }

    /// Forwards a finality notification for `block_number` to the header chain.
    pub fn handle_finality(&mut self, block_number: u64) {
        self.chain.set_finalized(block_number);
    }

    /// Read-only access to the underlying header chain.
    pub fn chain(&self) -> &HeaderChain {
        &self.chain
    }
}
impl Default for SyncManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header whose hash is derived deterministically from `number`.
    fn make_header(number: u64, parent_hash: &str) -> BlockHeader {
        BlockHeader {
            hash: format!("0x{:064x}", number),
            number,
            parent_hash: parent_hash.to_string(),
            state_root: String::new(),
            extrinsics_root: String::new(),
            timestamp: 0,
        }
    }

    /// Builds a parent-linked run of headers for `numbers`.
    /// The first header's parent is the dummy hash "0x0".
    fn make_run(numbers: std::ops::RangeInclusive<u64>) -> Vec<BlockHeader> {
        let mut parent = "0x0".to_string();
        numbers
            .map(|n| {
                let header = make_header(n, &parent);
                parent = header.hash.clone();
                header
            })
            .collect()
    }

    #[test]
    fn test_header_chain() {
        let mut chain = HeaderChain::new(100);
        let h1 = make_header(1, "0x0");
        let h2 = make_header(2, &h1.hash);
        chain.push(h1).unwrap();
        chain.push(h2).unwrap();
        assert_eq!(chain.len(), 2);
        assert_eq!(chain.latest().unwrap().number, 2);
    }

    #[test]
    fn test_push_rejects_parent_mismatch() {
        let mut chain = HeaderChain::new(100);
        chain.push(make_header(1, "0x0")).unwrap();
        assert!(chain.push(make_header(2, "0xdead")).is_err());
        // The rejected header must not have been appended.
        assert_eq!(chain.len(), 1);
        assert_eq!(chain.latest().unwrap().number, 1);
    }

    #[test]
    fn test_max_headers_evicts_oldest() {
        let mut chain = HeaderChain::new(3);
        for header in make_run(1..=5) {
            chain.push(header).unwrap();
        }
        assert_eq!(chain.len(), 3);
        assert!(!chain.contains(2));
        assert!(chain.contains(3));
        assert_eq!(chain.latest().unwrap().number, 5);
    }

    #[test]
    fn test_set_finalized_prunes_old_headers() {
        let mut chain = HeaderChain::new(100);
        for header in make_run(1..=30) {
            chain.push(header).unwrap();
        }
        chain.set_finalized(25);
        assert_eq!(chain.finalized(), 25);
        assert!(!chain.contains(14));
        assert!(chain.contains(15));
        assert_eq!(chain.len(), 16);
    }

    #[test]
    fn test_get_by_hash() {
        let mut chain = HeaderChain::new(100);
        for header in make_run(1..=3) {
            chain.push(header).unwrap();
        }
        let hash = chain.get(2).unwrap().hash.clone();
        assert_eq!(chain.get_by_hash(&hash).unwrap().number, 2);
        assert!(chain.get_by_hash("0xnope").is_none());
    }

    #[test]
    fn test_sync_manager() {
        let mut manager = SyncManager::new();
        assert_eq!(manager.state(), SyncState::Idle);
        manager.start_sync(100);
        assert_eq!(manager.state(), SyncState::DownloadingHeaders);
    }

    #[test]
    fn test_process_headers_reaches_target() {
        let mut manager = SyncManager::new();
        manager.start_sync(3);
        manager.process_headers(make_run(1..=3)).unwrap();
        assert_eq!(manager.state(), SyncState::Synced);
        assert!((manager.progress() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_process_headers_partial_progress() {
        let mut manager = SyncManager::new();
        manager.start_sync(10);
        manager.process_headers(make_run(1..=5)).unwrap();
        assert_eq!(manager.state(), SyncState::DownloadingHeaders);
        assert!((manager.progress() - 0.5).abs() < 1e-12);
    }
}