mod background_snapshot_transfer;
#[cfg(test)]
mod background_snapshot_transfer_test;
pub(crate) use background_snapshot_transfer::*;
use d_engine_proto::server::cluster::ClusterConfChangeRequest;
use d_engine_proto::server::cluster::ClusterConfUpdateResponse;
use d_engine_proto::server::election::VoteRequest;
use d_engine_proto::server::election::VoteResponse;
use d_engine_proto::server::replication::AppendEntriesRequest;
use d_engine_proto::server::replication::AppendEntriesResponse;
use d_engine_proto::server::storage::SnapshotChunk;
#[cfg(any(test, feature = "__test_support"))]
use mockall::automock;
use tonic::async_trait;
use crate::BackoffPolicy;
use crate::NetworkError;
use crate::Result;
use crate::RetryPolicies;
use crate::TypeConfig;
/// Aggregated outcome of one append-entries round, summarizing how far
/// each cluster member has replicated and whether commit can advance.
#[derive(Debug, Clone)]
pub struct AppendResults {
    // True when enough peers acknowledged the entries for the leader to
    // advance its commit index (quorum reached).
    pub commit_quorum_achieved: bool,
    // Replication progress per peer, keyed by node id (u32).
    pub peer_updates: HashMap<u32, PeerUpdate>,
    // Last matched log index per learner, keyed by node id; `None` when
    // no index has been confirmed for that learner yet.
    pub learner_progress: HashMap<u32, Option<u64>>,
}
/// Per-peer replication progress produced after an append-entries round.
#[derive(Debug, Clone, PartialEq)]
pub struct PeerUpdate {
    pub match_index: Option<u64>,
    pub next_index: u64,
    pub success: bool,
}

impl PeerUpdate {
    /// Update for a peer that acknowledged replication: its log matches
    /// ours up to `match_index`, and the next batch starts at `next_index`.
    #[allow(unused)]
    pub fn success(
        match_index: u64,
        next_index: u64,
    ) -> Self {
        Self {
            match_index: Some(match_index),
            next_index,
            success: true,
        }
    }

    /// Update for a peer whose append failed: no confirmed match index,
    /// with `next_index` reset to 1.
    #[allow(unused)]
    pub fn failed() -> Self {
        Self {
            success: false,
            match_index: None,
            next_index: 1,
        }
    }
}
/// Raw per-peer responses collected from broadcasting append-entries
/// requests to the cluster.
#[derive(Debug)]
pub struct AppendResult {
    // Ids of the peers the requests were addressed to.
    pub peer_ids: HashSet<u32>,
    // Collected responses; transport-level failures are kept as `Err`
    // entries instead of aborting the whole broadcast.
    pub responses: Vec<Result<AppendEntriesResponse>>,
}
/// Raw per-peer responses collected from broadcasting a vote request
/// during an election.
#[derive(Debug)]
pub struct VoteResult {
    // Ids of the peers the vote request was addressed to.
    pub peer_ids: HashSet<u32>,
    // Collected vote responses; transport-level failures are kept as
    // `Err` entries instead of aborting the whole broadcast.
    pub responses: Vec<Result<VoteResponse>>,
}
/// Raw per-peer responses collected from broadcasting a cluster
/// configuration change.
#[allow(dead_code)]
#[derive(Debug)]
pub struct ClusterUpdateResult {
    // Ids of the peers the configuration change was addressed to.
    pub peer_ids: HashSet<u32>,
    // Collected responses; transport-level failures are kept as `Err`
    // entries instead of aborting the whole broadcast.
    pub responses: Vec<Result<ClusterConfUpdateResponse>>,
}
/// Network transport abstraction for all peer-to-peer RPCs: replication,
/// elections, cluster membership changes, and snapshot transfer.
///
/// Automatically mocked (`MockTransport`) in tests and under the
/// `__test_support` feature via `mockall::automock`.
#[cfg_attr(any(test, feature = "__test_support"), automock)]
#[async_trait]
pub trait Transport<T>: Send + Sync + 'static
where
    T: TypeConfig,
{
    /// Broadcasts a cluster-configuration change to peers resolved from
    /// `membership`, retrying per `retry`, and returns the per-peer
    /// responses.
    #[allow(dead_code)]
    async fn send_cluster_update(
        &self,
        req: ClusterConfChangeRequest,
        retry: &RetryPolicies,
        membership: std::sync::Arc<crate::alias::MOF<T>>,
    ) -> Result<ClusterUpdateResult>,
    /// Sends one append-entries request per `(peer_id, request)` pair and
    /// gathers the responses. `response_compress_enabled` toggles
    /// compression of peer responses.
    async fn send_append_requests(
        &self,
        requests: Vec<(u32, AppendEntriesRequest)>,
        retry: &RetryPolicies,
        membership: std::sync::Arc<crate::alias::MOF<T>>,
        response_compress_enabled: bool,
    ) -> Result<AppendResult>;
    /// Broadcasts the same vote request to all peers resolved from
    /// `membership` and gathers the responses.
    async fn send_vote_requests(
        &self,
        req: VoteRequest,
        retry: &RetryPolicies,
        membership: std::sync::Arc<crate::alias::MOF<T>>,
    ) -> Result<VoteResult>;
    /// Asks the node identified by `leader_id` to admit this node into the
    /// cluster, retrying with the given backoff policy.
    async fn join_cluster(
        &self,
        leader_id: u32,
        request: d_engine_proto::server::cluster::JoinRequest,
        retry: BackoffPolicy,
        membership: std::sync::Arc<crate::alias::MOF<T>>,
    ) -> Result<d_engine_proto::server::cluster::JoinResponse>;
    /// Queries peers for the current leader, returning one discovery
    /// response per peer that answered.
    async fn discover_leader(
        &self,
        request: d_engine_proto::server::cluster::LeaderDiscoveryRequest,
        rpc_enable_compression: bool,
        membership: std::sync::Arc<crate::alias::MOF<T>>,
    ) -> Result<Vec<d_engine_proto::server::cluster::LeaderDiscoveryResponse>>;
    /// Opens a snapshot stream from the node identified by `leader_id`,
    /// returning the stream of incoming chunks.
    ///
    /// NOTE(review): `ack_tx` is named like a sender but is typed as a
    /// `Receiver` — presumably the receiving end of the caller's ack
    /// channel, forwarded to the leader per chunk; confirm and consider
    /// renaming to `ack_rx`.
    async fn request_snapshot_from_leader(
        &self,
        leader_id: u32,
        ack_tx: tokio::sync::mpsc::Receiver<d_engine_proto::server::storage::SnapshotAck>,
        retry: &crate::InstallSnapshotBackoffPolicy,
        membership: std::sync::Arc<crate::alias::MOF<T>>,
    ) -> Result<Box<tonic::Streaming<SnapshotChunk>>>;
}
use std::collections::HashMap;
use std::collections::HashSet;
use std::time::Duration;
use tokio::time::sleep;
use tokio::time::timeout;
use tonic::Code;
use tracing::debug;
use tracing::warn;
use crate::Error;
/// Runs `task` up to `policy.max_retries` times, bounding each attempt by
/// `policy.timeout_ms` and sleeping with capped exponential backoff between
/// failed attempts.
///
/// The delay starts at `policy.base_delay_ms`, doubles after each failure,
/// and never exceeds `policy.max_delay_ms`. No sleep is performed after the
/// final attempt. `task` is an `FnMut` factory so a fresh future (e.g. a
/// fresh gRPC call) is created per attempt.
///
/// # Errors
/// Returns the error from the last failed attempt:
/// - `NetworkError::ServiceUnavailable` for `Code::Unavailable` statuses,
/// - `NetworkError::TonicStatusError` for any other gRPC status,
/// - `NetworkError::RetryTimeoutError` when the attempt timed out,
/// - `NetworkError::TaskBackoffFailed` when `max_retries` is zero and the
///   task was never attempted.
pub async fn grpc_task_with_timeout_and_exponential_backoff<F, T, U>(
    task_name: &'static str,
    mut task: F,
    policy: BackoffPolicy,
) -> std::result::Result<tonic::Response<U>, Error>
where
    F: FnMut() -> T,
    T: std::future::Future<Output = std::result::Result<tonic::Response<U>, tonic::Status>>
        + Send
        + 'static,
{
    let timeout_duration = Duration::from_millis(policy.timeout_ms);
    let max_delay = Duration::from_millis(policy.max_delay_ms);
    let max_retries = policy.max_retries;
    let mut current_delay = Duration::from_millis(policy.base_delay_ms);
    // Returned as-is when `max_retries == 0` and the loop never runs.
    let mut last_error =
        NetworkError::TaskBackoffFailed("Task failed after max retries".to_string());

    let mut retries = 0;
    while retries < max_retries {
        debug!("[{task_name}] Attempt {} of {}", retries + 1, max_retries);
        match timeout(timeout_duration, task()).await {
            Ok(Ok(response)) => {
                return Ok(response);
            }
            Ok(Err(status)) => {
                last_error = match status.code() {
                    // Unavailable is expected during restarts/leader changes;
                    // classify it separately so callers can react to it.
                    Code::Unavailable => {
                        warn!("[{task_name}] Service unavailable: {}", status.message());
                        NetworkError::ServiceUnavailable(format!(
                            "Service unavailable: {}",
                            status.message()
                        ))
                    }
                    _ => {
                        warn!("[{task_name}] RPC error: {}", status);
                        NetworkError::TonicStatusError(Box::new(status))
                    }
                };
            }
            Err(_elapsed) => {
                warn!("[{task_name}] Task timed out after {:?}", timeout_duration);
                last_error = NetworkError::RetryTimeoutError(timeout_duration);
            }
        }

        retries += 1;
        // Back off before the next attempt, but not after the final one.
        // (Equivalent to the pre-increment check `retries < max_retries - 1`.)
        if retries < max_retries {
            debug!("[{task_name}] Retrying in {:?}...", current_delay);
            sleep(current_delay).await;
            current_delay = (current_delay * 2).min(max_delay);
        }
    }

    // Single terminal log. The original also warned inside the loop on the
    // last attempt ("failed after {retries} retries" with an off-by-one
    // count), producing duplicate, inconsistent failure logs.
    warn!("[{task_name}] Task failed after {} retries", max_retries);
    Err(last_error.into())
}