use std::sync::Arc;
use std::time::Duration;
use openraft::BasicNode;
use openraft::error::{InstallSnapshotError, NetworkError, RPCError, RaftError, RemoteError};
use openraft::network::RPCOption;
use openraft::network::{RaftNetwork, RaftNetworkFactory};
use openraft::raft::{
AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
VoteRequest, VoteResponse,
};
use super::types::*;
/// Default HTTP deadline: the client-wide timeout and the lower bound for
/// `append_entries` and `vote` requests.
const RPC_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Lower bound on the HTTP deadline used for `install_snapshot` requests.
const SNAPSHOT_RPC_TIMEOUT: Duration = Duration::from_secs(60);
/// Chooses the per-request HTTP timeout for one RPC.
///
/// Returns the longer of openraft's hard TTL (`option.hard_ttl()`) and
/// `default`, so the HTTP-level timeout is never shorter than `default`, even
/// when openraft requests a tighter deadline.
fn effective_timeout(option: &RPCOption, default: Duration) -> Duration {
    let requested = option.hard_ttl();
    if requested > default {
        requested
    } else {
        default
    }
}
/// HTTP header that carries the raft transport secret on outgoing raft RPCs.
pub const RAFT_TRANSPORT_TOKEN_HEADER: &str = "x-hirnd-raft-token";

/// Produces [`HirnRaftNetwork`] handles that all share one pooled HTTP client.
pub struct HirnRaftNetworkFactory {
    /// Connection pool handed (by clone) to every per-peer network handle.
    client: reqwest::Client,
    /// Optional secret sent in [`RAFT_TRANSPORT_TOKEN_HEADER`]; `None` sends no header.
    transport_secret: Option<Arc<str>>,
}

impl HirnRaftNetworkFactory {
    /// Creates a factory backed by a freshly built `reqwest` client.
    ///
    /// The client uses [`RPC_TIMEOUT`] as its default request timeout. It keeps
    /// at most two idle connections per host and closes idle connections after
    /// 30 seconds.
    ///
    /// # Errors
    ///
    /// Returns the `reqwest` error if the HTTP client cannot be built.
    pub fn new(transport_secret: Option<&str>) -> reqwest::Result<Self> {
        let transport_secret: Option<Arc<str>> = transport_secret.map(Arc::from);
        let client = reqwest::Client::builder()
            .pool_max_idle_per_host(2)
            .pool_idle_timeout(Duration::from_secs(30))
            .timeout(RPC_TIMEOUT)
            .build()?;
        Ok(Self {
            client,
            transport_secret,
        })
    }
}
impl RaftNetworkFactory<TypeConfig> for HirnRaftNetworkFactory {
type Network = HirnRaftNetwork;
async fn new_client(&mut self, target: NodeId, node: &BasicNode) -> Self::Network {
HirnRaftNetwork {
target,
addr: node.addr.clone(),
client: self.client.clone(),
transport_secret: self.transport_secret.clone(),
}
}
}
/// Per-peer raft transport that sends each RPC as a JSON HTTP POST.
pub struct HirnRaftNetwork {
    /// Raft node id of the peer; used as the origin of remote errors.
    target: NodeId,
    /// Base URL of the peer, taken from `BasicNode::addr`.
    addr: String,
    /// HTTP client shared with the factory.
    client: reqwest::Client,
    /// Optional secret attached as [`RAFT_TRANSPORT_TOKEN_HEADER`].
    transport_secret: Option<Arc<str>>,
}

impl HirnRaftNetwork {
    /// Peer base URL with trailing slashes removed, so a path can be appended after one `/`.
    fn endpoint(&self) -> &str {
        self.addr.trim_end_matches('/')
    }

    /// Starts a POST request to `path` on the peer, attaching the transport token if configured.
    ///
    /// Leading slashes on `path` are ignored, so `"raft/vote"` and `"/raft/vote"`
    /// address the same URL instead of producing a `//` segment.
    ///
    /// The token header value is flagged as sensitive. It is then redacted in
    /// `Debug` output and is not added to HTTP/2 header-compression tables. This
    /// mirrors what reqwest does for its own `bearer_auth`/`basic_auth` helpers.
    fn post(&self, path: &str) -> reqwest::RequestBuilder {
        let url = format!("{}/{}", self.endpoint(), path.trim_start_matches('/'));
        let builder = self.client.post(url);
        let Some(secret) = self.transport_secret.as_deref() else {
            return builder;
        };
        match reqwest::header::HeaderValue::from_str(secret) {
            Ok(mut value) => {
                value.set_sensitive(true);
                builder.header(RAFT_TRANSPORT_TOKEN_HEADER, value)
            }
            // Not a valid header value: hand the raw string to reqwest, which
            // records the error and reports it from `send()`, exactly as before.
            Err(_) => builder.header(RAFT_TRANSPORT_TOKEN_HEADER, secret),
        }
    }
}
impl RaftNetwork<TypeConfig> for HirnRaftNetwork {
async fn append_entries(
&mut self,
rpc: AppendEntriesRequest<TypeConfig>,
option: RPCOption,
) -> Result<AppendEntriesResponse<NodeId>, RPCError<NodeId, BasicNode, RaftError<NodeId>>> {
let timeout = effective_timeout(&option, RPC_TIMEOUT);
let resp = self
.post("raft/append")
.json(&rpc)
.timeout(timeout)
.send()
.await
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
let result: Result<AppendEntriesResponse<NodeId>, RaftError<NodeId>> = resp
.json()
.await
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
result.map_err(|e| RPCError::RemoteError(RemoteError::new(self.target, e)))
}
async fn install_snapshot(
&mut self,
rpc: InstallSnapshotRequest<TypeConfig>,
option: RPCOption,
) -> Result<
InstallSnapshotResponse<NodeId>,
RPCError<NodeId, BasicNode, RaftError<NodeId, InstallSnapshotError>>,
> {
let timeout = effective_timeout(&option, SNAPSHOT_RPC_TIMEOUT);
let resp = self
.post("raft/snapshot")
.json(&rpc)
.timeout(timeout)
.send()
.await
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
let result: Result<
InstallSnapshotResponse<NodeId>,
RaftError<NodeId, InstallSnapshotError>,
> = resp
.json()
.await
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
result.map_err(|e| RPCError::RemoteError(RemoteError::new(self.target, e)))
}
async fn vote(
&mut self,
rpc: VoteRequest<NodeId>,
option: RPCOption,
) -> Result<VoteResponse<NodeId>, RPCError<NodeId, BasicNode, RaftError<NodeId>>> {
let timeout = effective_timeout(&option, RPC_TIMEOUT);
let resp = self
.post("raft/vote")
.json(&rpc)
.timeout(timeout)
.send()
.await
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
let result: Result<VoteResponse<NodeId>, RaftError<NodeId>> = resp
.json()
.await
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
result.map_err(|e| RPCError::RemoteError(RemoteError::new(self.target, e)))
}
}