use {
super::{Config, Snapshot, SnapshotStateMachine, SnapshotSync, protocol::*},
crate::{
PeerId,
groups::{Cursor, Index, StateSyncSession, SyncSessionContext, Term},
primitives::{AsyncWorkQueue, Pretty, Short, UnboundedChannel},
},
core::{
ops::Range,
pin::Pin,
task::{Context, Poll, Waker},
},
futures::StreamExt,
std::{
collections::{BTreeMap, HashMap, HashSet},
time::Instant,
},
tokio::time::{Sleep, sleep},
};
/// State of one in-progress snapshot synchronization.
///
/// The session asks the leader for a snapshot, collects snapshot offers keyed
/// by anchor cursor, picks an anchor, downloads the snapshot items in chunks
/// from the peers that offered that anchor, and finally installs the
/// accumulated snapshot and appends the log entries buffered in the meantime.
pub struct SnapshotSyncSession<M: SnapshotStateMachine> {
/// Copy of the sync configuration (timeouts and fetch batch size are read
/// from it).
config: Config,
/// Leader peer id as reported by the context when the session was created;
/// it is ordered last when picking idle peers to fetch from.
leader: PeerId,
/// Index of the position that was passed to `new`.
trigger_position: Index,
/// The currently selected snapshot anchor, `None` until one is selected.
anchor: Option<Cursor>,
/// `items_count` announced for the selected anchor.
snapshot_len: u64,
/// Item offsets that have not been drained into `accumulated` yet; an empty
/// range together with a selected anchor means the download is complete.
gap: Option<Range<u64>>,
/// Snapshot offers received so far: anchor -> (offer info, peers that
/// offered this anchor).
anchor_candidates: HashMap<Cursor, (SnapshotInfo, HashSet<PeerId>)>,
/// Log entries received while syncing, as `(index, term, command)`; they are
/// appended to the log after the snapshot is installed.
buffered: Vec<(Index, Term, M::Command)>,
/// Channel carrying the ids of peers whose bond has terminated.
terminations: UnboundedChannel<PeerId>,
/// Peer ids of the bonds observed on the most recent poll.
known_bonds: HashSet<PeerId>,
/// Peers that offered the current anchor and have not timed out or
/// terminated.
available_peers: HashSet<PeerId>,
/// Outstanding fetch request per peer (at most one per peer).
inflight: HashMap<PeerId, PendingFetch>,
/// Received chunks waiting to be drained in offset order:
/// start offset -> (end offset, items).
#[allow(clippy::type_complexity)]
fetched: BTreeMap<u64, (u64, Vec<<M::Snapshot as Snapshot>::Item>)>,
/// Snapshot being assembled from the drained chunks.
accumulated: M::Snapshot,
/// Retry timer for the initial snapshot request; cleared once any snapshot
/// offer arrives.
snapshot_request_timer: Option<Pin<Box<Sleep>>>,
/// Background futures driven from `poll` (bond termination and bond change
/// watchers).
tasks: AsyncWorkQueue,
/// Wakers collected from `poll`, woken (and drained) by `wake_all`.
wakers: Vec<Waker>,
/// Number of items expected for the current anchor (progress reporting).
total: usize,
/// Number of items drained so far for the current anchor.
downloaded: usize,
/// Time at which the session was created (reported on completion).
started_at: Instant,
/// Every peer a fetch request was ever sent to (reported on completion).
unique_peers: HashSet<PeerId>,
}
impl<M: SnapshotStateMachine> SnapshotSyncSession<M> {
/// Creates a new sync session and immediately asks the leader for a
/// snapshot.
///
/// * `config` - sync configuration; a clone is stored in the session.
/// * `cx` - session context used to reach bonds and send messages.
/// * `position` - cursor that triggered the sync; `entries` are assumed to
///   start at the index right after it.
/// * `_leader_commit` - currently unused.
/// * `entries` - log entries to buffer until the snapshot is installed.
pub(super) fn new(
    config: &Config,
    cx: &mut dyn SyncSessionContext<SnapshotSync<M>>,
    position: Cursor,
    _leader_commit: Index,
    entries: Vec<(M::Command, Term)>,
) -> Self {
    let tasks = AsyncWorkQueue::default();
    let terminations = UnboundedChannel::default();

    // One watcher per existing bond: report the peer id once the bond
    // terminates.
    for bond in cx.bonds().iter() {
        let notify = terminations.sender().clone();
        tasks.enqueue(async move {
            bond.terminated().await;
            let _ = notify.send(*bond.peer().id());
        });
    }

    // Never-completing task that keeps awaiting bond-set changes.
    // NOTE(review): presumably this exists so the session gets re-polled
    // whenever the bond set changes - confirm.
    let bonds = cx.bonds();
    tasks.enqueue(async move {
        loop {
            bonds.changed().await;
        }
    });

    cx.send_to(cx.leader(), SnapshotSyncMessage::RequestSnapshot);

    // Number the initial entries consecutively, starting right after
    // `position`.
    let first_index = position.index().next();
    let mut buffered = Vec::with_capacity(entries.len());
    for (offset, (command, term)) in entries.into_iter().enumerate() {
        buffered.push((first_index + offset, term, command));
    }

    Self {
        leader: cx.leader(),
        trigger_position: position.index(),
        anchor: None,
        snapshot_len: 0,
        gap: None,
        anchor_candidates: HashMap::new(),
        config: config.clone(),
        accumulated: M::Snapshot::default(),
        terminations,
        known_bonds: HashSet::new(),
        available_peers: HashSet::new(),
        inflight: HashMap::new(),
        fetched: BTreeMap::new(),
        snapshot_request_timer: Some(Box::pin(sleep(
            config.snapshot_request_timeout,
        ))),
        tasks,
        wakers: Vec::new(),
        total: 0,
        downloaded: 0,
        started_at: Instant::now(),
        unique_peers: HashSet::new(),
        buffered,
    }
}
fn wake_all(&mut self) {
for waker in self.wakers.drain(..) {
waker.wake();
}
}
/// Returns the index of the first buffered entry, or `None` when nothing is
/// buffered.
fn earliest_buffered(&self) -> Option<Index> {
    let (index, _, _) = self.buffered.first()?;
    Some(*index)
}
/// Picks the snapshot anchor to download, if none is selected yet.
///
/// When an anchor is already selected this defers to
/// `try_upgrade_anchor`. Otherwise the candidate with the highest index
/// that is not below the earliest buffered entry (any candidate when
/// nothing is buffered) becomes the anchor, the download range is set to
/// `0..items_count`, and the peers that offered it become available for
/// fetching.
///
/// Returns `true` when an anchor was selected (or upgraded).
fn try_select_anchor(
    &mut self,
    cx: &dyn SyncSessionContext<SnapshotSync<M>>,
) -> bool {
    if self.anchor.is_some() {
        return self.try_upgrade_anchor(cx);
    }

    let floor = self.earliest_buffered();

    // Highest index first, then take the first candidate that is compatible
    // with the buffered entries.
    let mut ranked: Vec<_> = self.anchor_candidates.iter().collect();
    ranked.sort_by(|lhs, rhs| rhs.0.index().cmp(&lhs.0.index()));
    let chosen = ranked.into_iter().find(|(cursor, _)| {
        floor.is_none_or(|earliest| cursor.index() >= earliest)
    });

    let Some((cursor, (info, peers))) = chosen else {
        return false;
    };

    tracing::info!(
        anchor = %cursor,
        items = info.items_count,
        group = %Short(cx.group_id()),
        network = %Short(cx.network_id()),
        "selected snapshot anchor"
    );

    self.anchor = Some(*cursor);
    self.snapshot_len = info.items_count;
    self.total = info.items_count as usize;
    // An empty snapshot yields the empty range `0..0`.
    self.gap = Some(0..info.items_count);
    self.available_peers.extend(peers.iter().copied());
    true
}
/// Switches to a newer snapshot anchor when a better candidate is known.
///
/// A candidate qualifies when its index is strictly above the current
/// anchor's and not below the earliest buffered entry. The qualifying
/// candidate with the highest index wins. Upgrading discards all download
/// progress made for the previous anchor.
///
/// Returns `true` when the anchor was replaced, `false` otherwise
/// (including when no anchor is selected yet).
fn try_upgrade_anchor(
    &mut self,
    cx: &dyn SyncSessionContext<SnapshotSync<M>>,
) -> bool {
    let Some(current) = self.anchor else {
        return false;
    };

    let floor = self.earliest_buffered();
    let upgrade = self
        .anchor_candidates
        .iter()
        .filter(|(cursor, _)| {
            let index = cursor.index();
            index > current.index()
                && floor.is_none_or(|earliest| index >= earliest)
        })
        .max_by_key(|(cursor, _)| cursor.index());

    let Some((&new_anchor, (info, peers))) = upgrade else {
        return false;
    };

    tracing::info!(
        old_anchor = %current,
        new_anchor = %new_anchor,
        len = info.items_count,
        group = %Short(cx.group_id()),
        network = %Short(cx.network_id()),
        "upgrading snapshot anchor"
    );

    // Restart the download from scratch against the new anchor.
    self.anchor = Some(new_anchor);
    self.snapshot_len = info.items_count;
    self.total = info.items_count as usize;
    self.downloaded = 0;
    self.accumulated = M::Snapshot::default();
    self.inflight.clear();
    self.fetched.clear();
    self.gap = Some(0..info.items_count);

    // Only peers that offered the new anchor can serve it.
    self.available_peers.clear();
    self.available_peers.extend(peers.iter().copied());
    true
}
/// Handles a snapshot offer from `peer`.
///
/// Records the peer under the offered anchor (the first offer's info for an
/// anchor is the one that is kept), stops the snapshot request retry timer,
/// starts watching the peer's bond for termination, and then re-runs anchor
/// selection and wakes the session.
fn on_snapshot_ready(
    &mut self,
    peer: PeerId,
    info: SnapshotInfo,
    cx: &dyn SyncSessionContext<SnapshotSync<M>>,
) {
    let offered_anchor = info.anchor;

    tracing::debug!(
        from = %Short(peer),
        anchor = %offered_anchor,
        items = info.items_count,
        group = %Short(cx.group_id()),
        network = %Short(cx.network_id()),
        "snapshot available"
    );

    // At least one offer arrived, so the request no longer needs retrying.
    self.snapshot_request_timer = None;

    let (_, offerers) = self
        .anchor_candidates
        .entry(offered_anchor)
        .or_insert_with(|| (info, HashSet::new()));
    offerers.insert(peer);

    // A peer offering the anchor already being downloaded can serve
    // fetches right away.
    if self.anchor == Some(offered_anchor) {
        self.available_peers.insert(peer);
    }

    if let Some(bond) = cx.bonds().get(&peer) {
        let notify = self.terminations.sender().clone();
        self.tasks.enqueue(async move {
            bond.terminated().await;
            let _ = notify.send(peer);
        });
    }

    self.try_select_anchor(cx);
    self.wake_all();
}
fn on_fetch_data_response(
&mut self,
peer: PeerId,
response: FetchDataResponse<<M::Snapshot as Snapshot>::Item>,
) {
if self.anchor != Some(response.anchor) {
return;
}
self.inflight.remove(&peer);
let start = response.offset;
let end = start + response.items.len() as u64;
self.fetched.insert(start, (end, response.items));
self.wake_all();
}
/// Processes the bond terminations that are currently queued.
///
/// Every terminated peer is removed from the available peers, the inflight
/// requests, the known bonds and from every anchor candidate's peer set;
/// the session is woken for each processed peer.
fn poll_terminations(&mut self, cx: &mut Context<'_>) {
    if self.terminations.is_empty() {
        return;
    }

    let expected = self.terminations.len();
    let mut gone = Vec::with_capacity(expected);
    let received = self.terminations.poll_recv_many(cx, &mut gone, expected);
    if received.is_pending() {
        return;
    }

    for peer in gone {
        self.available_peers.remove(&peer);
        self.inflight.remove(&peer);
        self.known_bonds.remove(&peer);
        self.anchor_candidates
            .values_mut()
            .for_each(|(_, offerers)| {
                offerers.remove(&peer);
            });
        self.wake_all();
    }
}
fn poll_timeouts(
&mut self,
poll_cx: &mut Context<'_>,
sync_cx: &dyn SyncSessionContext<SnapshotSync<M>>,
) {
let mut timed_out = Vec::new();
for (peer, pending) in &mut self.inflight {
if pending.timeout.as_mut().poll(poll_cx).is_ready() {
timed_out.push((*peer, pending.range.clone()));
}
}
for (peer, range) in timed_out {
tracing::warn!(
peer = %Short(peer),
range = ?range,
group = %Short(sync_cx.group_id()),
network = %Short(sync_cx.network_id()),
"snapshot fetch request timed out"
);
self.inflight.remove(&peer);
self.available_peers.remove(&peer);
self.wake_all();
}
}
/// Re-sends the snapshot request to the leader when the retry timer fires.
///
/// Does nothing when the timer is disarmed or has not elapsed yet. After a
/// retry the timer is re-armed and polled once so that the current task is
/// registered for its expiry.
fn poll_snapshot_request_timeout(
    &mut self,
    poll_cx: &mut Context<'_>,
    sync_cx: &mut dyn SyncSessionContext<SnapshotSync<M>>,
) {
    let Some(timer) = &mut self.snapshot_request_timer else {
        return;
    };
    if timer.as_mut().poll(poll_cx).is_pending() {
        return;
    }

    tracing::debug!(
        leader = %Short(sync_cx.leader()),
        group = %Short(sync_cx.group_id()),
        network = %Short(sync_cx.network_id()),
        "retrying snapshot request"
    );
    sync_cx.send_to(sync_cx.leader(), SnapshotSyncMessage::RequestSnapshot);

    // Arm the next retry; polling it once registers the waker.
    let mut next_retry = Box::pin(sleep(self.config.snapshot_request_timeout));
    let _ = next_retry.as_mut().poll(poll_cx);
    self.snapshot_request_timer = Some(next_retry);
}
/// Replaces `known_bonds` with the peer ids of the bonds that currently
/// exist.
fn poll_new_bonds(&mut self, cx: &dyn SyncSessionContext<SnapshotSync<M>>) {
    self.known_bonds = cx
        .bonds()
        .iter()
        .map(|bond| *bond.peer().id())
        .collect::<HashSet<PeerId>>();
}
/// Returns the available peers that have no inflight request, with the
/// leader (if it is among them) placed last.
fn idle_peers_sorted(&self) -> Vec<PeerId> {
    let idle = self
        .available_peers
        .iter()
        .copied()
        .filter(|peer| !self.inflight.contains_key(peer));

    // Non-leaders keep their relative order; the leader goes to the back.
    let (mut ordered, leader): (Vec<PeerId>, Vec<PeerId>) =
        idle.partition(|peer| *peer != self.leader);
    ordered.extend(leader);
    ordered
}
/// Hands out the not-yet-assigned parts of the remaining download range to
/// idle peers, one chunk per peer.
///
/// A range counts as assigned when it is covered by an inflight request or
/// by a chunk that was received but not drained yet. Each new request
/// starts at the first unassigned offset and ends at the earliest of:
/// the batch size, the end of the download range, or the start of the next
/// assigned range. Clamping to the next assigned range guarantees that
/// requests never overlap data that is already on its way or already
/// received, so chunks can always be drained in order.
///
/// Does nothing when there is no anchor or nothing left to download.
fn schedule_fetches(
    &mut self,
    sync_cx: &mut dyn SyncSessionContext<SnapshotSync<M>>,
) {
    let gap = match &self.gap {
        Some(g) if !g.is_empty() => g.clone(),
        _ => return,
    };
    let Some(anchor) = self.anchor else { return };

    // A zero batch size would only produce empty requests.
    let chunk_size = self.config.fetch_batch_size.max(1);
    let gap_end = gap.end;

    // Everything that is already requested or received, ordered by start.
    let mut assigned: Vec<Range<u64>> = self
        .inflight
        .values()
        .map(|pending| pending.range.clone())
        .chain(self.fetched.iter().map(|(&start, (end, _))| start..*end))
        .collect();
    assigned.sort_by_key(|range| range.start);

    // Moves `cursor` past every assigned range that covers it.
    let skip_assigned = |mut cursor: u64| -> u64 {
        for range in &assigned {
            if range.start <= cursor && range.end > cursor {
                cursor = range.end;
            }
        }
        cursor
    };

    let mut cursor = skip_assigned(gap.start);
    for peer in self.idle_peers_sorted() {
        if cursor >= gap_end {
            break;
        }

        // Do not run into the next range that is already taken care of.
        let next_assigned_start = assigned
            .iter()
            .filter(|range| !range.is_empty() && range.start > cursor)
            .map(|range| range.start)
            .min()
            .unwrap_or(gap_end);
        let chunk_end = cursor
            .saturating_add(chunk_size)
            .min(gap_end)
            .min(next_assigned_start);

        self.send_fetch_request(peer, anchor, cursor..chunk_end, sync_cx);
        cursor = skip_assigned(chunk_end);
    }
}
/// Sends a fetch request for `range` of the snapshot at `anchor` to `peer`
/// and records it as inflight with a fresh timeout.
///
/// The peer is also added to the set of peers that were ever asked for
/// data.
fn send_fetch_request(
    &mut self,
    peer: PeerId,
    anchor: Cursor,
    range: Range<u64>,
    sync_cx: &mut dyn SyncSessionContext<SnapshotSync<M>>,
) {
    tracing::trace!(
        peer = %Short(peer),
        range = ?range,
        anchor = %anchor,
        group = %Short(sync_cx.group_id()),
        network = %Short(sync_cx.network_id()),
        "syncing from"
    );

    let request = FetchDataRequest {
        anchor,
        range: range.clone(),
    };
    sync_cx.send_to(peer, SnapshotSyncMessage::FetchDataRequest(request));

    let pending = PendingFetch {
        range,
        timeout: Box::pin(sleep(self.config.fetch_timeout)),
    };
    self.inflight.insert(peer, pending);
    self.unique_peers.insert(peer);
}
/// Moves received chunks into the accumulated snapshot, strictly in offset
/// order, and shrinks the remaining download range accordingly.
///
/// Chunks are consumed from the lowest start offset upwards as long as they
/// connect to the current position. A chunk that lies entirely before the
/// current position (for example a late response to a request that timed
/// out and was served by another peer) is discarded, and a chunk that
/// starts before the current position but extends past it is trimmed to
/// its new part. Without this, such a chunk would stay at the front of
/// `fetched` forever and block every later chunk from being drained.
///
/// When the download range becomes empty the current task is woken so the
/// next poll can finalize the sync.
fn drain_fetched_items(
    &mut self,
    poll_cx: &Context<'_>,
    sync_cx: &dyn SyncSessionContext<SnapshotSync<M>>,
) {
    let gap = match &self.gap {
        Some(g) if !g.is_empty() => g.clone(),
        _ => return,
    };

    let mut cursor = gap.start;
    while let Some((&start, _)) = self.fetched.first_key_value() {
        if start > cursor {
            // There is a hole before the next chunk; wait for it to fill.
            break;
        }
        let Some((end, mut items)) = self.fetched.remove(&start) else {
            break;
        };

        if start < cursor {
            if end <= cursor {
                // Entirely consumed already: stale duplicate.
                continue;
            }
            // Keep only the part that has not been consumed yet.
            let consumed = usize::try_from(cursor - start)
                .map_or(items.len(), |n| n.min(items.len()));
            drop(items.drain(..consumed));
        }

        self.downloaded += items.len();
        if !items.is_empty() {
            self.wake_all();
        }
        self.accumulated.append(items);

        let progress = self.downloaded as f64 / self.total.max(1) as f64 * 100.0;
        tracing::debug!(
            range = ?(cursor..end),
            total = self.total,
            downloaded = self.downloaded,
            group = %Short(sync_cx.group_id()),
            network = %Short(sync_cx.network_id()),
            "syncing snapshot {progress:.1}%",
        );
        cursor = end;
    }

    if cursor != gap.start {
        self.gap = Some(cursor..gap.end);
    }
    if self.gap.as_ref().is_some_and(|g| g.is_empty()) {
        // Download finished: make sure we get polled again to finalize.
        poll_cx.waker().wake_by_ref();
    }
}
/// Installs the downloaded snapshot and replays the buffered log entries.
///
/// The state machine receives the accumulated snapshot, the log is reset to
/// the anchor and the anchor is marked committed. Buffered entries at or
/// below the anchor are dropped; the rest are appended to the log.
///
/// # Panics
///
/// Panics when no anchor is selected, or when the remaining buffered
/// entries do not form a consecutive sequence starting right after the
/// anchor.
fn finalize_sync(
    &mut self,
    cx: &mut dyn SyncSessionContext<SnapshotSync<M>>,
) {
    let anchor = self.anchor.expect("finalize_sync called without an anchor");

    let snapshot = core::mem::take(&mut self.accumulated);
    cx.state_machine_mut().install_snapshot(snapshot);
    cx.log_mut().reset_to(anchor);
    cx.set_committed(anchor);

    // Entries covered by the snapshot are no longer needed.
    let anchor_index = anchor.index();
    self.buffered.retain(|(index, _, _)| *index > anchor_index);
    if self.buffered.is_empty() {
        return;
    }

    let mut expected = anchor_index.next();
    tracing::trace!(
        pos = %expected,
        count = self.buffered.len(),
        group = %Short(cx.group_id()),
        network = %Short(cx.network_id()),
        "applying buffered entries after snapshot"
    );
    for (index, term, command) in self.buffered.drain(..) {
        assert_eq!(index, expected);
        cx.log_mut().append(command, term);
        expected = expected.next();
    }
}
/// Returns `true` once an anchor is selected and nothing is left to
/// download for it.
fn is_sync_complete(&self) -> bool {
    match (&self.anchor, &self.gap) {
        (Some(_), Some(remaining)) => remaining.is_empty(),
        _ => false,
    }
}
}
impl<M: SnapshotStateMachine> StateSyncSession for SnapshotSyncSession<M> {
type Owner = SnapshotSync<M>;
/// Drives the sync session.
///
/// Returns `Poll::Ready` with the last log cursor once the snapshot has
/// been fully downloaded and installed; otherwise advances all internal
/// activities (background tasks, terminations, request retries, anchor
/// selection, draining, timeouts, scheduling) and returns `Poll::Pending`.
fn poll(
    &mut self,
    poll_cx: &mut Context<'_>,
    sync_cx: &mut dyn SyncSessionContext<Self::Owner>,
) -> Poll<Cursor> {
    // Remember the waker for `wake_all`, but only once per task: the list
    // is drained only on wake-up, so pushing on every poll would let it
    // grow without bound.
    if !self
        .wakers
        .iter()
        .any(|known| known.will_wake(poll_cx.waker()))
    {
        self.wakers.push(poll_cx.waker().clone());
    }

    if self.is_sync_complete() {
        self.finalize_sync(sync_cx);
        tracing::debug!(
            group = %Short(sync_cx.group_id()),
            network = %Short(sync_cx.network_id()),
            "snapshot sync completed: {} items in {:?} from {} peers",
            self.downloaded,
            self.started_at.elapsed(),
            self.unique_peers.len(),
        );
        return Poll::Ready(sync_cx.log().last());
    }

    let _ = self.tasks.poll_next_unpin(poll_cx);
    self.poll_terminations(poll_cx);
    self.poll_snapshot_request_timeout(poll_cx, sync_cx);
    self.try_select_anchor(sync_cx);
    self.drain_fetched_items(poll_cx, sync_cx);
    self.poll_timeouts(poll_cx, sync_cx);
    self.poll_new_bonds(sync_cx);
    self.schedule_fetches(sync_cx);

    // `schedule_fetches` creates new timeout timers. A timer only wakes the
    // task after it has been polled once, so poll the timeouts again;
    // otherwise a request to an unresponsive peer might never time out.
    self.poll_timeouts(poll_cx, sync_cx);

    Poll::Pending
}
/// Dispatches a sync message received from `sender`.
///
/// Only snapshot offers and fetch responses are expected here.
///
/// # Panics
///
/// Panics on any other message kind.
fn receive(
    &mut self,
    message: SnapshotSyncMessage<<M::Snapshot as Snapshot>::Item>,
    sender: PeerId,
    cx: &mut dyn SyncSessionContext<Self::Owner>,
) {
    match message {
        SnapshotSyncMessage::FetchDataResponse(chunk) => {
            self.on_fetch_data_response(sender, chunk);
        }
        SnapshotSyncMessage::SnapshotOffer(offer) => {
            self.on_snapshot_ready(sender, offer, cx);
        }
        _ => unreachable!("handled at the provider level"),
    }
}
/// Buffers log entries that arrive while the snapshot is being synced.
///
/// The entries are numbered consecutively starting at the index right
/// after `position` and appended to the already buffered ones.
fn buffer(
    &mut self,
    position: Cursor,
    entries: Vec<(M::Command, Term)>,
    _cx: &mut dyn SyncSessionContext<Self::Owner>,
) {
    let first_index = position.index().next();
    self.buffered.reserve(entries.len());
    for (offset, (command, term)) in entries.into_iter().enumerate() {
        self.buffered.push((first_index + offset, term, command));
    }
}
}
/// A fetch request that was sent to a peer and has not been answered yet.
struct PendingFetch {
/// Item offsets that were requested.
range: Range<u64>,
/// Timer that fires when the request has been outstanding for the
/// configured fetch timeout.
timeout: Pin<Box<Sleep>>,
}