openraft 0.10.0-alpha.18

Advanced Raft consensus
Documentation
use std::time::Duration;

use display_more::DisplayOptionExt;
use futures_util::FutureExt;

use crate::RaftNetworkFactory;
use crate::RaftTypeConfig;
use crate::StorageError;
use crate::async_runtime::MpscSender;
use crate::async_runtime::watch::WatchReceiver;
use crate::core::notification::Notification;
use crate::core::sm::handle::SnapshotReader;
use crate::errors::HigherVote;
use crate::errors::RPCError;
use crate::errors::ReplicationClosed;
use crate::errors::ReplicationError;
use crate::network::Backoff;
use crate::network::NetBackoff;
use crate::network::NetSnapshot;
use crate::network::RPCOption;
use crate::progress::inflight_id::InflightId;
use crate::replication::Progress;
use crate::replication::replication_context::ReplicationContext;
use crate::replication::response::ReplicationResult;
use crate::replication::snapshot_transmitter_handle::SnapshotTransmitterHandle;
use crate::type_config::TypeConfigExt;
use crate::type_config::alias::InstantOf;
use crate::type_config::alias::SnapshotOf;
use crate::type_config::alias::VoteOf;
use crate::type_config::alias::WatchSenderOf;
use crate::vote::raft_vote::RaftVoteExt;

/// Task that transmits a snapshot to a follower.
///
/// Spawned by `RaftCore` when log replication falls too far behind and a snapshot
/// is needed. Runs independently, retrying on transient failures with backoff,
/// and notifies `RaftCore` of progress or errors via the notification channel.
///
/// `SM` is forwarded to [`SnapshotReader`]; presumably it identifies the state machine
/// type the snapshot is read from (defaults to `()`) — TODO confirm.
pub(crate) struct SnapshotTransmitter<C, N, SM = ()>
where
    C: RaftTypeConfig,
    N: RaftNetworkFactory<C>,
{
    /// Replication state for the target follower: target id, leader vote, config,
    /// notification channel to `RaftCore`, and the cancellation receiver.
    pub(crate) replication_context: ReplicationContext<C>,

    /// Identifier attached to the `ReplicationProgress` notification sent after a
    /// successful transmission.
    inflight_id: InflightId,

    /// Network connection for snapshot replication.
    ///
    /// Snapshot transmitting is a long-running task and is processed in a separate task.
    network: N::Network,

    /// The backoff policy used after an [`RPCError::Unreachable`] error.
    ///
    /// It is created from the network on the first `Unreachable` error and kept across
    /// consecutive `Unreachable` errors. Any other RPC error (timeout, network, remote)
    /// resets it to `None`. A successful transmission ends the task.
    backoff: Option<Backoff>,

    /// The handle to get a snapshot directly from the state machine.
    snapshot_reader: SnapshotReader<C, SM>,
}

impl<C, N, SM: 'static> SnapshotTransmitter<C, N, SM>
where
    C: RaftTypeConfig,
    N: RaftNetworkFactory<C>,
{
    pub(crate) fn spawn(
        replication_context: ReplicationContext<C>,
        network: N::Network,
        snapshot_reader: SnapshotReader<C, SM>,
        inflight_id: InflightId,
        cancel_tx: WatchSenderOf<C, ()>,
    ) -> SnapshotTransmitterHandle<C> {
        let snapshot_transmit = Self {
            replication_context,
            inflight_id,
            network,
            backoff: None,
            snapshot_reader,
        };

        // TODO: this function should just return join_handle and let the caller build
        //       SnapshotTransmitterHandle
        let join_handle = C::spawn(snapshot_transmit.stream_snapshot());

        SnapshotTransmitterHandle {
            _join_handle: join_handle,
            _tx_cancel: cancel_tx,
        }
    }

    #[tracing::instrument(level = "info", skip_all)]
    async fn stream_snapshot(mut self) {
        tracing::info!("{}", func_name!());

        let mut ith: i32 = -1;
        loop {
            ith += 1;

            let res = self.read_and_send_snapshot(ith).await;

            let error = match res {
                Err(error) => error,
                Ok(_) => {
                    return;
                }
            };

            tracing::error!("ReplicationError while sending snapshot: {}", error);

            match error {
                ReplicationError::Closed(closed) => {
                    tracing::info!("snapshot transmission canceled: {}", closed);
                    return;
                }
                ReplicationError::HigherVote(h) => {
                    tracing::info!("snapshot transmission aborted, higher vote seen: {}", h);
                    self.replication_context
                        .tx_notify
                        .send(Notification::HigherVote {
                            target: self.replication_context.target,
                            higher: h.higher,
                            leader_vote: self.replication_context.leader_vote.clone(),
                        })
                        .await
                        .ok();

                    return;
                }
                ReplicationError::StorageError(error) => {
                    tracing::error!(
                        "error replication to target: {}, error: {}",
                        self.replication_context.target,
                        error
                    );
                    self.replication_context.tx_notify.send(Notification::StorageError { error }).await.ok();
                    return;
                }
                ReplicationError::RPCError(err) => {
                    match &err {
                        RPCError::Unreachable(_unreachable) => {
                            // If there is an [`Unreachable`] error, we will backoff for a
                            // period of time. Backoff will be reset if there is a
                            // successful RPC is sent.
                            if self.backoff.is_none() {
                                self.backoff = Some(self.network.backoff());
                            }
                        }
                        RPCError::Timeout(_) | RPCError::Network(_) | RPCError::RemoteError(_) => {
                            self.backoff = None;
                        }
                    };

                    if let Some(b) = &mut self.backoff {
                        let duration = b.next().unwrap_or_else(|| {
                            tracing::warn!("backoff exhausted, using default");
                            Duration::from_millis(500)
                        });

                        let sleep = C::sleep(duration);
                        let recv = self.replication_context.cancel_rx.changed();

                        futures_util::select! {
                            _ = sleep.fuse() => {
                                tracing::debug!("backoff timeout");
                            }
                            _ = recv.fuse() => {
                                tracing::info!("snapshot transmission canceled by RaftCore");
                                return;
                            }
                        }
                    }
                }
            };
        }
    }

    async fn read_and_send_snapshot(&mut self, ith: i32) -> Result<(), ReplicationError<C>> {
        let snapshot = self.snapshot_reader.get_snapshot().await.map_err(|reason| {
            tracing::warn!("failed to get snapshot from state machine: {}", reason);
            ReplicationClosed::new(reason)
        })?;

        tracing::info!(
            "{}-th snapshot sending: has read snapshot: meta:{}",
            ith,
            snapshot.as_ref().map(|x| &x.meta).display()
        );

        let snapshot = match snapshot {
            None => {
                let sto_err = StorageError::read_snapshot(None, C::err_from_string("snapshot not found"));
                return Err(sto_err.into());
            }
            Some(x) => x,
        };

        let mut option = RPCOption::new(self.replication_context.config.install_snapshot_timeout());
        option.snapshot_chunk_size = Some(self.replication_context.config.snapshot_max_chunk_size as usize);

        self.send_snapshot(snapshot, option).await
    }

    async fn send_snapshot(&mut self, snapshot: SnapshotOf<C>, option: RPCOption) -> Result<(), ReplicationError<C>> {
        let meta = snapshot.meta.clone();

        let mut c = self.replication_context.cancel_rx.clone();
        let cancel = async move {
            c.changed().await.ok();
            ReplicationClosed::new("RaftCore is dropped")
        };

        let sender_vote: VoteOf<C> = self.replication_context.leader_vote.clone().into_vote();

        let start_time = C::now();

        let resp = self.network.full_snapshot(sender_vote.clone(), snapshot, cancel, option).await?;

        tracing::info!("finished sending full_snapshot, resp: {}", resp);

        // Handle response conditions.
        if resp.vote.as_ref_vote() > sender_vote.as_ref_vote() {
            return Err(ReplicationError::HigherVote(HigherVote {
                higher: resp.vote,
                sender_vote,
            }));
        }

        self.notify_heartbeat_progress(start_time).await;
        self.notify_progress(ReplicationResult(Ok(meta.last_log_id))).await;
        Ok(())
    }

    async fn notify_heartbeat_progress(&mut self, sending_time: InstantOf<C>) {
        self.replication_context
            .tx_notify
            .send({
                Notification::HeartbeatProgress {
                    stream_id: self.replication_context.stream_id,
                    target: self.replication_context.target.clone(),
                    sending_time,
                }
            })
            .await
            .ok();
    }

    async fn notify_progress(&mut self, replication_result: ReplicationResult<C>) {
        tracing::debug!(
            "{}: target: {}, result: {}",
            func_name!(),
            self.replication_context.target.clone(),
            replication_result
        );

        self.replication_context
            .tx_notify
            .send({
                Notification::ReplicationProgress {
                    progress: Progress {
                        target: self.replication_context.target.clone(),
                        result: Ok(replication_result.clone()),
                    },
                    inflight_id: Some(self.inflight_id),
                }
            })
            .await
            .ok();
    }
}