use {
super::{Config, LogReplaySync, LogReplaySyncMessage},
crate::{
PeerId,
groups::{
Cursor,
Index,
StateMachine,
StateSyncSession,
SyncSessionContext,
Term,
},
primitives::{AsyncWorkQueue, Pretty, Short, UnboundedChannel},
},
core::{
ops::RangeInclusive,
pin::Pin,
task::{Context, Poll, Waker},
},
futures::StreamExt,
std::{
collections::{BTreeMap, HashMap, HashSet},
time::Instant,
},
tokio::time::{Sleep, sleep},
};
/// An in-flight log-replay sync session.
///
/// Downloads the log entries between the local log head and a remote
/// target position from bonded peers (fetching disjoint ranges in
/// parallel), then replays any buffered entries that arrived on top of
/// the target while the session was running.
pub struct LogReplaySession<M: StateMachine> {
    // Tuning knobs for the session (batch size, fetch timeout, ...).
    config: Config,
    // Current leader; de-prioritized when assigning fetches.
    leader: PeerId,
    // Inclusive range of log indices still missing locally; the start
    // advances as contiguous entries are applied.
    gap: RangeInclusive<Index>,
    // Entries received alongside/after the sync target, replayed once
    // the gap is closed. Stored as (index, term, command).
    buffered: Vec<(Index, Term, M::Command)>,
    // Receives the ids of peers whose bonds have terminated.
    terminations: UnboundedChannel<PeerId>,
    // Peers we have already sent an availability probe to.
    known_bonds: HashSet<PeerId>,
    // Advertised index ranges each peer can serve.
    availability: HashMap<PeerId, RangeInclusive<Index>>,
    // Outstanding fetch requests, one at most per peer.
    inflight: HashMap<PeerId, PendingFetch>,
    // Fetched-but-not-yet-applied chunks keyed by their start index:
    // start -> (end, source peer, entries).
    #[allow(clippy::type_complexity)]
    fetched: BTreeMap<Index, (Index, PeerId, Vec<(M::Command, Term)>)>,
    // Background tasks watching bond terminations / bond-set changes.
    tasks: AsyncWorkQueue,
    // Wakers of callers parked on this session; drained on progress.
    wakers: Vec<Waker>,
    // Total number of entries the session must download (for progress).
    total: usize,
    // Entries downloaded so far.
    downloaded: usize,
    // Session start time, for the completion log line.
    started_at: Instant,
    // Distinct peers that served at least one fetch (stats only).
    unique_peers: HashSet<PeerId>,
}
impl<M: StateMachine> core::fmt::Debug for LogReplaySession<M> {
    /// Shows only the remaining `gap`; all other fields are elided via
    /// `finish_non_exhaustive`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("LogReplaySession");
        builder.field("gap", &self.gap);
        builder.finish_non_exhaustive()
    }
}
impl<M: StateMachine> LogReplaySession<M> {
    /// Builds a new replay session.
    ///
    /// `position` is the remote log head we must catch up to; `entries`
    /// are the commands that arrived together with it and are buffered
    /// for replay once the gap below them has been filled.
    ///
    /// # Panics
    /// Panics if the local log is already at (or past) `position` —
    /// callers must only open a session when genuinely behind.
    pub(super) fn new(
        config: &Config,
        cx: &dyn SyncSessionContext<LogReplaySync<M>>,
        position: Cursor,
        entries: Vec<(M::Command, Term)>,
    ) -> Self {
        let local_pos = cx.log().last();
        // Indices we still have to download: the first missing index up
        // to the remote head, inclusive.
        let gap = local_pos.index().next()..=position.index();
        let total = (position.index() - local_pos.index()).as_usize();
        assert_ne!(
            total, 0,
            "no need to create a sync session if we're not behind, local: \
            {local_pos}, target: {position}"
        );
        let tasks = AsyncWorkQueue::default();
        let terminations = UnboundedChannel::default();
        // One watcher task per existing bond: report the peer id on the
        // terminations channel when that bond goes away.
        for bond in cx.bonds().iter() {
            let tx = terminations.sender().clone();
            tasks.enqueue(async move {
                bond.terminated().await;
                let _ = tx.send(*bond.peer().id());
            });
        }
        // A never-completing task whose only job is to wake the work
        // queue whenever the bond set changes, so `poll` gets a chance
        // to probe newly bonded peers.
        let watcher = cx.bonds();
        tasks.enqueue(async move {
            loop {
                watcher.changed().await;
            }
        });
        // Buffered entries start right after the sync target position.
        let pos = position.index().next();
        Self {
            gap,
            tasks,
            terminations,
            leader: cx.leader(),
            config: config.clone(),
            buffered: entries
                .into_iter()
                .enumerate()
                .map(|(i, (cmd, term))| (pos + i, term, cmd))
                .collect(),
            known_bonds: HashSet::new(),
            availability: HashMap::new(),
            inflight: HashMap::new(),
            fetched: BTreeMap::new(),
            wakers: Vec::new(),
            downloaded: 0,
            started_at: Instant::now(),
            unique_peers: HashSet::new(),
            total,
        }
    }
fn wake_all(&mut self) {
for waker in self.wakers.drain(..) {
waker.wake();
}
}
fn poll_terminations(&mut self, cx: &mut Context<'_>) {
if self.terminations.is_empty() {
return;
}
let count = self.terminations.len();
let mut terminated = Vec::with_capacity(count);
if self
.terminations
.poll_recv_many(cx, &mut terminated, count)
.is_ready()
{
for peer in terminated {
self.availability.remove(&peer);
self.inflight.remove(&peer);
self.known_bonds.remove(&peer);
self.wake_all();
}
}
}
fn poll_timeouts(
&mut self,
poll_cx: &mut Context<'_>,
sync_cx: &mut dyn SyncSessionContext<LogReplaySync<M>>,
) {
let mut timed_out = Vec::new();
for (peer, pending) in &mut self.inflight {
if pending.timeout_fut.as_mut().poll(poll_cx).is_ready() {
timed_out.push((*peer, pending.range.clone()));
}
}
for (peer, range) in timed_out {
tracing::warn!(
peer = %Short(peer),
range = %Pretty(&range),
group = %Short(sync_cx.group_id()),
network = %Short(sync_cx.network_id()),
"fetch request timed out"
);
self.inflight.remove(&peer);
self.availability.remove(&peer);
sync_cx
.send_to(peer, LogReplaySyncMessage::AvailabilityRequest)
.expect("infallible serialization");
self.wake_all();
}
}
fn poll_new_bonds(
&mut self,
cx: &mut dyn SyncSessionContext<LogReplaySync<M>>,
) {
let current: HashSet<PeerId> =
cx.bonds().iter().map(|b| *b.peer().id()).collect();
let new_peers: Vec<PeerId> =
current.difference(&self.known_bonds).copied().collect();
for peer in &new_peers {
cx.send_to(*peer, LogReplaySyncMessage::AvailabilityRequest)
.expect("infallible serialization");
}
self.known_bonds = current;
}
fn idle_peers_sorted(&self) -> Vec<PeerId> {
let mut peers: Vec<PeerId> = self
.availability
.keys()
.filter(|p| !self.inflight.contains_key(*p))
.copied()
.collect();
peers.sort_by_key(|p| i32::from(*p == self.leader));
peers
}
fn schedule_fetches(
&mut self,
sync_cx: &mut dyn SyncSessionContext<LogReplaySync<M>>,
) {
let chunk_size = self.config.batch_size;
let mut assigned: Vec<RangeInclusive<Index>> = Vec::new();
for pending in self.inflight.values() {
assigned.push(pending.range.clone());
}
for (&start, &(end, _, _)) in &self.fetched {
assigned.push(start..=end);
}
assigned.sort_by_key(|r| *r.start());
let mut cursor = *self.gap.start();
let gap_end = *self.gap.end();
for range in &assigned {
if cursor > gap_end {
break;
}
if *range.start() <= cursor && *range.end() >= cursor {
cursor = range.end().next();
}
}
let mut idle_peers = self.idle_peers_sorted();
while cursor <= gap_end && !idle_peers.is_empty() {
let chunk_end = (cursor + (chunk_size - 1)).min(gap_end);
if let Some(idx) = idle_peers.iter().position(|p| {
self
.availability
.get(p)
.is_some_and(|a| *a.start() <= cursor && *a.end() >= cursor)
}) {
let peer = idle_peers.remove(idx);
let avail = &self.availability[&peer];
let effective_end = chunk_end.min(*avail.end());
let effective_chunk = cursor..=effective_end;
self.send_fetch_request(peer, effective_chunk, sync_cx);
cursor = effective_end.next();
} else {
break;
}
for range in &assigned {
if *range.start() <= cursor && *range.end() >= cursor {
cursor = range.end().next();
}
}
}
}
fn send_fetch_request(
&mut self,
peer: PeerId,
range: RangeInclusive<Index>,
sync_cx: &mut dyn SyncSessionContext<LogReplaySync<M>>,
) {
tracing::trace!(
peer = %Short(peer),
range = %Pretty(&range),
group = %Short(sync_cx.group_id()),
network = %Short(sync_cx.network_id()),
"requesting state from"
);
sync_cx
.send_to(
peer,
LogReplaySyncMessage::FetchEntriesRequest(range.clone()),
)
.expect("infallible serialization");
let timeout_duration = self.config.fetch_timeout;
self.inflight.insert(peer, PendingFetch {
range,
timeout_fut: Box::pin(sleep(timeout_duration)),
});
self.unique_peers.insert(peer);
}
fn on_availability_response(
&mut self,
peer: PeerId,
available: RangeInclusive<Index>,
cx: &dyn SyncSessionContext<LogReplaySync<M>>,
) {
self.availability.insert(peer, available);
if let Some(bond) = cx.bonds().get(&peer) {
let tx = self.terminations.sender().clone();
self.tasks.enqueue(async move {
bond.terminated().await;
let _ = tx.send(peer);
});
}
self.wake_all();
}
fn on_fetch_entries_response(
&mut self,
peer: PeerId,
range: RangeInclusive<Index>,
entries: Vec<(M::Command, Term)>,
) {
self.inflight.remove(&peer);
self
.fetched
.insert(*range.start(), (*range.end(), peer, entries));
self.wake_all();
}
fn drain_fetched_entries(
&mut self,
poll_cx: &Context<'_>,
sync_cx: &mut dyn SyncSessionContext<LogReplaySync<M>>,
) {
let mut cursor = *self.gap.start();
while let Some(entry) = self.fetched.first_key_value() {
let (&start, _) = entry;
if start != cursor {
break; }
let (_, (end, peer, entries)) =
self.fetched.remove_entry(&start).unwrap();
assert_eq!(start, sync_cx.log().last().index().next());
self.downloaded += entries.len();
if !entries.is_empty() {
self.wake_all();
}
for (command, term) in entries {
sync_cx.log_mut().append(command, term);
}
let synced_range = start..=sync_cx.log().last().index();
let progress = self.downloaded as f64 / self.total as f64 * 100.0;
tracing::debug!(
range = %Pretty(&synced_range),
from = %Short(peer),
group = %Short(sync_cx.group_id()),
network = %Short(sync_cx.network_id()),
"syncing state {progress:.1}% complete",
);
cursor = end.next();
}
if cursor != *self.gap.start() {
self.gap = cursor..=*self.gap.end();
}
if self.gap.is_empty() {
poll_cx.waker().wake_by_ref();
}
}
fn finalize_sync(
&mut self,
cx: &mut dyn SyncSessionContext<LogReplaySync<M>>,
) {
let mut pos = self.gap.end().next();
if !self.buffered.is_empty() {
tracing::trace!(
pos = %pos,
count = self.buffered.len(),
group = %Short(cx.group_id()),
network = %Short(cx.network_id()),
"applying buffered state"
);
for (index, term, command) in self.buffered.drain(..) {
assert_eq!(index, pos);
cx.log_mut().append(command, term);
pos = pos.next();
}
}
}
}
impl<M: StateMachine> StateSyncSession for LogReplaySession<M> {
    type Owner = LogReplaySync<M>;
    /// Drives the session one step: applies contiguous fetched entries,
    /// handles timeouts and bond terminations, probes new bonds, and
    /// schedules further fetches. Returns `Ready` with the final log
    /// cursor once the gap is closed and buffered entries are replayed.
    fn poll(
        &mut self,
        poll_cx: &mut Context<'_>,
        sync_cx: &mut dyn SyncSessionContext<LogReplaySync<M>>,
    ) -> Poll<Cursor> {
        // Remember the caller's waker so message handlers can re-arm
        // this poll via `wake_all`.
        self.wakers.push(poll_cx.waker().clone());
        if self.gap.is_empty() {
            self.finalize_sync(sync_cx);
            tracing::debug!(
                group = %Short(sync_cx.group_id()),
                network = %Short(sync_cx.network_id()),
                "synced {} entries in {:?} from {} peers",
                self.downloaded,
                self.started_at.elapsed(),
                self.unique_peers.len(),
            );
            return Poll::Ready(sync_cx.log().last());
        }
        // Drive the background watcher tasks (bond terminations and
        // bond-set changes) so their notifications land below.
        let _ = self.tasks.poll_next_unpin(poll_cx);
        self.poll_terminations(poll_cx);
        self.drain_fetched_entries(poll_cx, sync_cx);
        self.poll_timeouts(poll_cx, sync_cx);
        self.poll_new_bonds(sync_cx);
        self.schedule_fetches(sync_cx);
        Poll::Pending
    }
fn receive(
&mut self,
message: LogReplaySyncMessage<M::Command>,
sender: crate::PeerId,
cx: &mut dyn SyncSessionContext<LogReplaySync<M>>,
) {
match message {
LogReplaySyncMessage::AvailabilityResponse(range) => {
self.on_availability_response(sender, range, cx);
}
LogReplaySyncMessage::FetchEntriesResponse { range, entries } => {
self.on_fetch_entries_response(
sender,
range,
entries
.into_iter()
.map(|(enc_cmd, term)| (enc_cmd.0, term))
.collect(),
);
}
_ => {} }
}
fn buffer(
&mut self,
position: Cursor,
entries: Vec<(M::Command, Term)>,
_: &mut dyn SyncSessionContext<LogReplaySync<M>>,
) {
let pos = position.index().next();
self.buffered.extend(
entries
.into_iter()
.enumerate()
.map(|(i, (cmd, term))| (pos + i, term, cmd)),
);
}
}
/// Bookkeeping for one outstanding `FetchEntriesRequest`.
#[derive(Debug)]
struct PendingFetch {
    // The inclusive index range requested from the peer.
    range: RangeInclusive<Index>,
    // Deadline timer; once it fires the request is considered lost and
    // the range becomes eligible for reassignment.
    timeout_fut: Pin<Box<Sleep>>,
}