use std::time::Duration;
use display_more::DisplayOptionExt;
use futures_util::FutureExt;
use crate::RaftNetworkFactory;
use crate::RaftTypeConfig;
use crate::StorageError;
use crate::async_runtime::MpscSender;
use crate::async_runtime::watch::WatchReceiver;
use crate::core::notification::Notification;
use crate::core::sm::handle::SnapshotReader;
use crate::errors::HigherVote;
use crate::errors::RPCError;
use crate::errors::ReplicationClosed;
use crate::errors::ReplicationError;
use crate::network::Backoff;
use crate::network::NetBackoff;
use crate::network::NetSnapshot;
use crate::network::RPCOption;
use crate::progress::inflight_id::InflightId;
use crate::replication::Progress;
use crate::replication::replication_context::ReplicationContext;
use crate::replication::response::ReplicationResult;
use crate::replication::snapshot_transmitter_handle::SnapshotTransmitterHandle;
use crate::type_config::TypeConfigExt;
use crate::type_config::alias::InstantOf;
use crate::type_config::alias::SnapshotOf;
use crate::type_config::alias::VoteOf;
use crate::type_config::alias::WatchSenderOf;
use crate::vote::raft_vote::RaftVoteExt;
/// Sends a full snapshot to a single replication target, retrying on RPC errors.
///
/// `SM` is only forwarded to [`SnapshotReader`] and defaults to `()`.
pub(crate) struct SnapshotTransmitter<C, N, SM = ()>
where
C: RaftTypeConfig,
N: RaftNetworkFactory<C>,
{
/// Shared replication state read by the transmitter: the target, the leader vote, the config,
/// the stream id, the notification sender (`tx_notify`) and the cancel receiver (`cancel_rx`).
pub(crate) replication_context: ReplicationContext<C>,
/// Attached to the `ReplicationProgress` notification sent after a successful transfer.
inflight_id: InflightId,
/// Network client used to send the snapshot and to create a backoff policy.
network: N::Network,
/// Active backoff policy; created on an `Unreachable` RPC error and cleared on any other RPC
/// error. `None` means the next retry is not delayed.
backoff: Option<Backoff>,
/// Source of the snapshot to transmit; queried once per attempt.
snapshot_reader: SnapshotReader<C, SM>,
}
/// Task logic of the snapshot transmitter: read a snapshot, send it to the target, report the
/// outcome through `replication_context.tx_notify`, and retry on RPC errors.
impl<C, N, SM: 'static> SnapshotTransmitter<C, N, SM>
where
    C: RaftTypeConfig,
    N: RaftNetworkFactory<C>,
{
    /// Builds a transmitter and spawns its retry loop as a task via `C::spawn`.
    ///
    /// The returned handle stores the task's join handle together with `cancel_tx`.
    /// NOTE(review): `cancel_tx` is presumably the sending half paired with
    /// `replication_context.cancel_rx`, which the task watches for cancellation — confirm
    /// against the caller.
    pub(crate) fn spawn(
        replication_context: ReplicationContext<C>,
        network: N::Network,
        snapshot_reader: SnapshotReader<C, SM>,
        inflight_id: InflightId,
        cancel_tx: WatchSenderOf<C, ()>,
    ) -> SnapshotTransmitterHandle<C> {
        let snapshot_transmit = Self {
            replication_context,
            inflight_id,
            network,
            backoff: None,
            snapshot_reader,
        };

        let join_handle = C::spawn(snapshot_transmit.stream_snapshot());

        SnapshotTransmitterHandle {
            _join_handle: join_handle,
            _tx_cancel: cancel_tx,
        }
    }

    /// Retry loop of the transmitter task.
    ///
    /// Repeatedly calls [`Self::read_and_send_snapshot`] and returns when:
    /// - the snapshot has been sent successfully;
    /// - the transfer is closed/canceled (`ReplicationError::Closed`);
    /// - a higher vote is seen (reported as `Notification::HigherVote`);
    /// - a storage error occurs (reported as `Notification::StorageError`);
    /// - `cancel_rx` fires while waiting for a backoff period.
    ///
    /// Any `RPCError` causes another attempt; only `Unreachable` errors delay the retry.
    #[tracing::instrument(level = "info", skip_all)]
    async fn stream_snapshot(mut self) {
        tracing::info!("{}", func_name!());

        // Zero-based attempt counter, only used for logging.
        let mut ith: i32 = 0;

        loop {
            let error = match self.read_and_send_snapshot(ith).await {
                Ok(()) => return,
                Err(error) => error,
            };

            tracing::error!("ReplicationError while sending snapshot: {}", error);

            match error {
                ReplicationError::Closed(closed) => {
                    tracing::info!("snapshot transmission canceled: {}", closed);
                    return;
                }
                ReplicationError::HigherVote(h) => {
                    tracing::info!("snapshot transmission aborted, higher vote seen: {}", h);
                    // Best effort: a failed send is ignored because the task ends either way.
                    self.replication_context
                        .tx_notify
                        .send(Notification::HigherVote {
                            // Clone rather than move out of `self`, consistent with the other
                            // notifications and valid for node ids that are not `Copy`.
                            target: self.replication_context.target.clone(),
                            higher: h.higher,
                            leader_vote: self.replication_context.leader_vote.clone(),
                        })
                        .await
                        .ok();
                    return;
                }
                ReplicationError::StorageError(error) => {
                    tracing::error!(
                        "error replication to target: {}, error: {}",
                        self.replication_context.target,
                        error
                    );
                    self.replication_context.tx_notify.send(Notification::StorageError { error }).await.ok();
                    return;
                }
                ReplicationError::RPCError(err) => {
                    match &err {
                        RPCError::Unreachable(_unreachable) => {
                            // Keep an already running backoff so its delays keep progressing.
                            if self.backoff.is_none() {
                                self.backoff = Some(self.network.backoff());
                            }
                        }
                        RPCError::Timeout(_) | RPCError::Network(_) | RPCError::RemoteError(_) => {
                            // The target responded or timed out: retry without delay.
                            self.backoff = None;
                        }
                    };

                    if let Some(b) = &mut self.backoff {
                        let duration = b.next().unwrap_or_else(|| {
                            tracing::warn!("backoff exhausted, using default");
                            Duration::from_millis(500)
                        });

                        let sleep = C::sleep(duration);
                        let recv = self.replication_context.cancel_rx.changed();

                        // Wait for the backoff period, but stop at once when canceled.
                        futures_util::select! {
                            _ = sleep.fuse() => {
                                tracing::debug!("backoff timeout");
                            }
                            _ = recv.fuse() => {
                                tracing::info!("snapshot transmission canceled by RaftCore");
                                return;
                            }
                        }
                    }
                }
            };

            // Saturate instead of overflowing if the loop retries for a very long time.
            ith = ith.saturating_add(1);
        }
    }

    /// Performs one attempt: fetches the current snapshot and sends it to the target.
    ///
    /// `ith` is the zero-based attempt number and is only used for logging.
    ///
    /// # Errors
    ///
    /// - `ReplicationError::Closed` if the snapshot reader fails to return a result.
    /// - A storage error if the reader reports that there is no snapshot.
    /// - Any error returned by [`Self::send_snapshot`].
    async fn read_and_send_snapshot(&mut self, ith: i32) -> Result<(), ReplicationError<C>> {
        let snapshot = self.snapshot_reader.get_snapshot().await.map_err(|reason| {
            tracing::warn!("failed to get snapshot from state machine: {}", reason);
            ReplicationClosed::new(reason)
        })?;

        tracing::info!(
            "{}-th snapshot sending: has read snapshot: meta:{}",
            ith,
            snapshot.as_ref().map(|x| &x.meta).display()
        );

        let snapshot = match snapshot {
            None => {
                let sto_err = StorageError::read_snapshot(None, C::err_from_string("snapshot not found"));
                return Err(sto_err.into());
            }
            Some(x) => x,
        };

        let mut option = RPCOption::new(self.replication_context.config.install_snapshot_timeout());

        // Saturate rather than truncate when the configured size does not fit in `usize`
        // (e.g. on 32-bit targets): a wrapped value would yield an arbitrary small chunk size.
        let chunk_size = usize::try_from(self.replication_context.config.snapshot_max_chunk_size)
            .unwrap_or(usize::MAX);
        option.snapshot_chunk_size = Some(chunk_size);

        self.send_snapshot(snapshot, option).await
    }

    /// Sends `snapshot` with `network.full_snapshot()` and reports the result.
    ///
    /// On success, sends a `HeartbeatProgress` notification carrying the time taken just before
    /// the RPC was issued, followed by a `ReplicationProgress` notification carrying the
    /// snapshot's `last_log_id`.
    ///
    /// # Errors
    ///
    /// - Whatever `full_snapshot()` fails with, converted into `ReplicationError`.
    /// - `ReplicationError::HigherVote` if the response vote is greater than the sender's vote.
    async fn send_snapshot(&mut self, snapshot: SnapshotOf<C>, option: RPCOption) -> Result<(), ReplicationError<C>> {
        // `snapshot` is moved into the RPC; keep the meta for the progress report.
        let meta = snapshot.meta.clone();

        // Future handed to the network layer; it resolves once `cancel_rx` changes or fails.
        let mut c = self.replication_context.cancel_rx.clone();
        let cancel = async move {
            c.changed().await.ok();
            ReplicationClosed::new("RaftCore is dropped")
        };

        let sender_vote: VoteOf<C> = self.replication_context.leader_vote.clone().into_vote();

        let start_time = C::now();

        let resp = self.network.full_snapshot(sender_vote.clone(), snapshot, cancel, option).await?;

        tracing::info!("finished sending full_snapshot, resp: {}", resp);

        if resp.vote.as_ref_vote() > sender_vote.as_ref_vote() {
            return Err(ReplicationError::HigherVote(HigherVote {
                higher: resp.vote,
                sender_vote,
            }));
        }

        self.notify_heartbeat_progress(start_time).await;
        self.notify_progress(ReplicationResult(Ok(meta.last_log_id))).await;

        Ok(())
    }

    /// Sends a `Notification::HeartbeatProgress` for this stream and target with the given
    /// `sending_time`. A failure to send the notification is ignored.
    async fn notify_heartbeat_progress(&mut self, sending_time: InstantOf<C>) {
        self.replication_context
            .tx_notify
            .send({
                Notification::HeartbeatProgress {
                    stream_id: self.replication_context.stream_id,
                    target: self.replication_context.target.clone(),
                    sending_time,
                }
            })
            .await
            .ok();
    }

    /// Sends a `Notification::ReplicationProgress` carrying `replication_result` and this
    /// transmitter's `inflight_id`. A failure to send the notification is ignored.
    async fn notify_progress(&mut self, replication_result: ReplicationResult<C>) {
        // Formatting only borrows its arguments, so no clone is needed for logging.
        tracing::debug!(
            "{}: target: {}, result: {}",
            func_name!(),
            self.replication_context.target,
            replication_result
        );

        self.replication_context
            .tx_notify
            .send({
                Notification::ReplicationProgress {
                    progress: Progress {
                        target: self.replication_context.target.clone(),
                        // The owned result is not used afterwards: move it instead of cloning.
                        result: Ok(replication_result),
                    },
                    inflight_id: Some(self.inflight_id),
                }
            })
            .await
            .ok();
    }
}