use {
super::{
Config,
Snapshot,
SnapshotStateMachine,
SnapshotSync,
SyncInitCommand,
protocol::*,
},
crate::{
PeerId,
collections::sync::PendingRequest,
groups::*,
primitives::Short,
},
chrono::Utc,
core::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
time::Duration,
},
std::collections::{HashMap, VecDeque, hash_map::Entry},
tokio::{
sync::{
broadcast,
mpsc::{UnboundedReceiver, UnboundedSender},
},
time::{Instant, Sleep, sleep},
},
};
/// Provider half of snapshot sync.
///
/// Offers snapshots to the peers that asked for them, serves ranges of
/// snapshot items on demand, and drops an offered snapshot once its
/// time-to-live timer completes.
pub struct SnapshotSyncProvider<M: SnapshotStateMachine> {
/// Settings read by the provider: `snapshot_ttl`, `fetch_batch_size` and the
/// `is_expired` check applied to incoming pending requests.
config: Config,
/// Callable invoked with a `SnapshotRequest`; its result is handed to
/// `feed_command` on the sync context.
sync_init_cmd: SyncInitCommand<M>,
/// Receiving end of the channel that delivers `PendingRequest`s, each
/// carrying the originating request, a snapshot and its position.
requests_rx: UnboundedReceiver<PendingRequest<M>>,
/// Snapshots currently on offer, keyed by the `position` of the pending
/// request that introduced them (sent to peers as the `anchor`).
available: HashMap<Cursor, AvailableSnapshot<M>>,
}
impl<M: SnapshotStateMachine> SnapshotSyncProvider<M> {
    /// Builds a provider that starts out with no snapshots on offer.
    ///
    /// `config`, `sync_init_cmd` and `requests_rx` are stored as given.
    pub(super) fn new(
        config: Config,
        sync_init_cmd: SyncInitCommand<M>,
        requests_rx: UnboundedReceiver<PendingRequest<M>>,
    ) -> Self {
        // Nothing is offered until the first pending request is polled.
        let available = HashMap::new();
        Self {
            config,
            sync_init_cmd,
            requests_rx,
            available,
        }
    }
}
impl<M: SnapshotStateMachine> StateSyncProvider for SnapshotSyncProvider<M> {
    type Owner = SnapshotSync<M>;

    /// Drives the provider.
    ///
    /// Queued snapshot requests are drained first. Only when none were
    /// received during this call are the expiry timers of the offered
    /// snapshots polled and the elapsed ones dropped.
    ///
    /// Returns `Poll::Ready(())` if at least one pending request was taken
    /// off the queue, `Poll::Pending` otherwise.
    fn poll(
        &mut self,
        task_cx: &mut Context<'_>,
        sync_cx: &mut dyn SyncProviderContext<Self::Owner>,
    ) -> Poll<()> {
        match self.poll_pending_requests(task_cx, sync_cx) {
            Poll::Ready(()) => Poll::Ready(()),
            Poll::Pending => {
                self.prune_expired_snapshots(task_cx);
                Poll::Pending
            }
        }
    }

    /// Handles the provider-side messages of the snapshot sync protocol.
    ///
    /// `RequestSnapshot` and `FetchDataRequest` are consumed and yield
    /// `Ok(())`; any other message is handed back untouched as `Err` so the
    /// caller can route it elsewhere.
    fn receive(
        &mut self,
        message: SnapshotSyncMessage<<M::Snapshot as Snapshot>::Item>,
        sender: PeerId,
        cx: &mut dyn SyncProviderContext<Self::Owner>,
    ) -> Result<(), SnapshotSyncMessage<<M::Snapshot as Snapshot>::Item>> {
        match message {
            SnapshotSyncMessage::RequestSnapshot => {
                // Stamp the request with the sender and the receive time.
                let request = SnapshotRequest {
                    requested_by: sender,
                    requested_at: Utc::now(),
                };
                self.on_snapshot_request(request, cx);
            }
            SnapshotSyncMessage::FetchDataRequest(request) => {
                self.on_fetch_data_request(&request, sender, cx);
            }
            unhandled => return Err(unhandled),
        }
        Ok(())
    }

    /// Reports the index of the context's committed position as the prefix
    /// that is safe to prune; always returns `Some`.
    fn safe_to_prune_prefix(
        &self,
        cx: &mut dyn SyncProviderContext<Self::Owner>,
    ) -> Option<Index> {
        let committed = cx.committed();
        Some(committed.index())
    }
}
impl<M: SnapshotStateMachine> SnapshotSyncProvider<M> {
    /// Drains every `PendingRequest` currently queued on `requests_rx` and
    /// offers the accompanying snapshot to the peer that asked for it.
    ///
    /// Requests issued by the local node, or rejected by
    /// `self.config.is_expired`, are dropped. For the rest, the snapshot is
    /// stored under its position (or, if one is already stored there, only
    /// that entry's TTL is renewed and the incoming snapshot is discarded)
    /// and a `SnapshotInfo` is sent to the requester.
    ///
    /// Returns `Poll::Ready(())` if at least one request was dequeued, even
    /// if it was subsequently dropped; `Poll::Pending` otherwise.
    fn poll_pending_requests(
        &mut self,
        task_cx: &mut Context<'_>,
        sync_cx: &mut dyn SyncProviderContext<SnapshotSync<M>>,
    ) -> Poll<()> {
        let mut received = false;
        while let Poll::Ready(Some(pending)) = self.requests_rx.poll_recv(task_cx) {
            received = true;

            if pending.request.requested_by == sync_cx.local_id()
                || self.config.is_expired(&pending.request)
            {
                continue;
            }

            let ttl = self.config.snapshot_ttl;
            // Read the length before the snapshot is moved into the map.
            let len = pending.snapshot.len();
            match self.available.entry(pending.position) {
                Entry::Occupied(mut existing) => {
                    existing.get_mut().revive(ttl);
                }
                Entry::Vacant(place) => {
                    place.insert(AvailableSnapshot::new(pending.snapshot, ttl));
                }
            }

            sync_cx.send_to(
                pending.request.requested_by,
                SnapshotInfo {
                    anchor: pending.position,
                    items_count: len,
                }
                .into(),
            );

            tracing::trace!(
                to = %Short(pending.request.requested_by),
                anchor = %pending.position,
                items = len,
                group = %sync_cx.group_id(),
                network = %sync_cx.network_id(),
                "offering snapshot"
            );
        }

        if received {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }

    /// Polls the expiry timer of every offered snapshot and removes those
    /// whose timer has completed.
    fn prune_expired_snapshots(&mut self, cx: &mut Context<'_>) {
        // `retain` visits every entry, so each surviving timer is polled
        // with `cx` exactly as the explicit loop did.
        self.available
            .retain(|_, snapshot| snapshot.poll_expired(cx).is_pending());
    }

    /// Handles a peer's request to start a snapshot sync.
    ///
    /// Does nothing but log when `cx.is_leader()` is false. Otherwise the
    /// request is converted with `sync_init_cmd` and passed to
    /// `cx.feed_command`; the outcome is only logged.
    fn on_snapshot_request(
        &self,
        request: SnapshotRequest,
        cx: &mut dyn SyncProviderContext<SnapshotSync<M>>,
    ) {
        let sender = request.requested_by;

        if !cx.is_leader() {
            tracing::debug!(
                from = %sender,
                group = %cx.group_id(),
                network = %cx.network_id(),
                "ignoring snapshot request on a non-leader node"
            );
            return;
        }

        let request = (self.sync_init_cmd)(request);
        if cx.feed_command(request).is_err() {
            tracing::debug!(
                from = %sender,
                group = %cx.group_id(),
                network = %cx.network_id(),
                "failed to schedule snapshot request"
            );
        } else {
            tracing::trace!(
                by = %Short(sender),
                group = %cx.group_id(),
                network = %cx.network_id(),
                "snapshot sync requested"
            );
        }
    }

    /// Serves one batch of items from an offered snapshot to `sender`.
    ///
    /// The requested range is clamped to at most `fetch_batch_size` items
    /// starting at `request.range.start`. Nothing is sent when the anchor is
    /// not on offer, the clamped range is empty, or `iter_range` reports the
    /// range as unavailable. Any non-empty request for a known anchor renews
    /// that snapshot's TTL.
    fn on_fetch_data_request(
        &mut self,
        request: &FetchDataRequest,
        sender: PeerId,
        cx: &mut dyn SyncProviderContext<SnapshotSync<M>>,
    ) {
        let Some(available) = self.available.get_mut(&request.anchor) else {
            tracing::debug!(
                from = %Short(sender),
                anchor = %request.anchor,
                group = %cx.group_id(),
                network = %cx.network_id(),
                "requested snapshot not available (expired or unknown)"
            );
            return;
        };

        // `start` is peer-supplied: a value close to the integer maximum
        // must not overflow when the batch size is added, so saturate. The
        // batch cap still holds because the result never exceeds
        // `start + fetch_batch_size`.
        let batch_end = request
            .range
            .start
            .saturating_add(self.config.fetch_batch_size);
        let end = request.range.end.min(batch_end);
        let range = request.range.start..end;
        if range.is_empty() {
            return;
        }

        available.revive(self.config.snapshot_ttl);

        let Some(items) = available.snapshot.iter_range(range) else {
            tracing::warn!(
                from = %Short(sender),
                anchor = %request.anchor,
                range = ?request.range,
                group = %cx.group_id(),
                network = %cx.network_id(),
                "snapshot range out of bounds"
            );
            return;
        };

        let items: std::vec::Vec<_> = items.collect();
        cx.send_to(
            sender,
            SnapshotSyncMessage::FetchDataResponse(FetchDataResponse {
                anchor: request.anchor,
                offset: request.range.start,
                items,
            }),
        );
    }
}
/// A snapshot on offer to peers, paired with the timer that bounds how long
/// it is kept.
struct AvailableSnapshot<M: SnapshotStateMachine> {
/// The snapshot whose items are served through `iter_range`.
snapshot: M::Snapshot,
/// Timer created from the TTL; pushed back by `revive` and polled by
/// `poll_expired` to detect expiry.
expired: Pin<Box<Sleep>>,
}
impl<M: SnapshotStateMachine> AvailableSnapshot<M> {
    /// Wraps `snapshot` together with a timer that completes after `ttl`.
    fn new(snapshot: M::Snapshot, ttl: Duration) -> Self {
        let expired = Box::pin(sleep(ttl));
        Self { snapshot, expired }
    }

    /// Reports whether the expiry timer has already elapsed; takes no task
    /// context, so it does not poll the timer.
    fn is_expired(&self) -> bool {
        let timer = self.expired.as_ref();
        timer.is_elapsed()
    }

    /// Pushes the expiry deadline out to `ttl` from now.
    fn revive(&mut self, ttl: Duration) {
        let deadline = Instant::now() + ttl;
        self.expired.as_mut().reset(deadline);
    }

    /// Polls the expiry timer with `cx`; `Poll::Ready(())` means the
    /// snapshot has expired.
    fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        self.expired.as_mut().poll(cx)
    }
}