use std::sync::Arc;
use std::time::Duration;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::post;
use axum::{Json, Router};
use openraft::error::{
Fatal, InstallSnapshotError, NetworkError, RPCError, RaftError, RemoteError, ReplicationClosed,
StreamingError, Unreachable,
};
use openraft::network::{Backoff, RPCOption, RaftNetwork, RaftNetworkFactory};
use openraft::raft::{
AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
SnapshotResponse, VoteRequest, VoteResponse,
};
use openraft::{AnyError, Raft, Snapshot, SnapshotMeta, Vote};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use super::types::{YantrikNode, YantrikNodeId, YantrikRaftTypeConfig};
#[derive(Clone)]
pub struct HttpRaftNetwork {
client: Client,
target: YantrikNodeId,
target_addr: String,
request_timeout: Duration,
}
impl HttpRaftNetwork {
fn url(&self, path: &str) -> String {
let base = self.target_addr.trim_end_matches('/');
format!("{base}{path}")
}
async fn json_post<Req, Resp>(&self, path: &str, body: &Req) -> Result<Resp, NetworkError>
where
Req: Serialize + ?Sized,
Resp: for<'de> Deserialize<'de>,
{
let url = self.url(path);
let res = self
.client
.post(&url)
.timeout(self.request_timeout)
.json(body)
.send()
.await
.map_err(|e| {
NetworkError::new(&AnyError::error(format!(
"raft RPC POST {url}: transport: {e}"
)))
})?;
let status = res.status();
if !status.is_success() {
let body = res.text().await.unwrap_or_default();
return Err(NetworkError::new(&AnyError::error(format!(
"raft RPC POST {url}: HTTP {status}: {body}"
))));
}
let parsed = res.json::<Resp>().await.map_err(|e| {
NetworkError::new(&AnyError::error(format!(
"raft RPC POST {url}: response decode: {e}"
)))
})?;
Ok(parsed)
}
}
/// JSON body of `POST /v1/raft/install_full_snapshot`.
///
/// Built by `HttpRaftNetwork::full_snapshot` from the sender's vote and snapshot, and
/// decoded by `handle_install_full_snapshot` back into an openraft `Snapshot`.
// NOTE(review): `data` is a plain `Vec<u8>`, which serde serializes as a JSON array of
// numbers, so the wire size is several times the snapshot size. Switching to a
// compact encoding (e.g. base64) is a wire-format change that all nodes must adopt
// together. Field order here is also part of the serialized output.
#[derive(Debug, Serialize, Deserialize)]
struct InstallFullSnapshotWire {
    /// Vote of the sender; passed to `Raft::install_full_snapshot` on the receiver.
    vote: Vote<YantrikNodeId>,
    /// Snapshot metadata, forwarded unchanged.
    meta: SnapshotMeta<YantrikNodeId, YantrikNode>,
    /// Raw snapshot bytes, unwrapped from the sender's snapshot cursor.
    data: Vec<u8>,
}
/// openraft transport that carries every RPC as an HTTP `POST` with a JSON body.
///
/// Each RPC targets a fixed `/v1/raft/...` path served by [`raft_receive_router`], and
/// the reply body is a JSON-encoded `Result<Response, Error>`. An `Err` payload is the
/// peer's own Raft error and becomes a `RemoteError`. Anything that prevents obtaining
/// a decoded reply (transport failure, non-2xx status, undecodable body) comes back
/// from `json_post` as a `NetworkError`.
impl RaftNetwork<YantrikRaftTypeConfig> for HttpRaftNetwork {
    /// Sends an `AppendEntries` RPC to the peer.
    ///
    /// Any `NetworkError` is reported as `RPCError::Unreachable`. The `RPCOption` TTL
    /// is not consulted; the factory's per-request timeout applies instead.
    async fn append_entries(
        &mut self,
        rpc: AppendEntriesRequest<YantrikRaftTypeConfig>,
        _option: RPCOption,
    ) -> Result<
        AppendEntriesResponse<YantrikNodeId>,
        RPCError<YantrikNodeId, YantrikNode, RaftError<YantrikNodeId>>,
    > {
        match self
            .json_post::<_, Result<AppendEntriesResponse<YantrikNodeId>, RaftError<YantrikNodeId>>>(
                "/v1/raft/append_entries",
                &rpc,
            )
            .await
        {
            Ok(Ok(resp)) => Ok(resp),
            Ok(Err(remote_err)) => Err(RPCError::RemoteError(RemoteError::new(
                self.target,
                remote_err,
            ))),
            Err(net_err) => Err(RPCError::Unreachable(Unreachable::new(&net_err))),
        }
    }

    /// Sends a `RequestVote` RPC to the peer.
    ///
    /// Error mapping and timeout handling are the same as for `append_entries`.
    async fn vote(
        &mut self,
        rpc: VoteRequest<YantrikNodeId>,
        _option: RPCOption,
    ) -> Result<
        VoteResponse<YantrikNodeId>,
        RPCError<YantrikNodeId, YantrikNode, RaftError<YantrikNodeId>>,
    > {
        match self
            .json_post::<_, Result<VoteResponse<YantrikNodeId>, RaftError<YantrikNodeId>>>(
                "/v1/raft/vote",
                &rpc,
            )
            .await
        {
            Ok(Ok(resp)) => Ok(resp),
            Ok(Err(remote_err)) => Err(RPCError::RemoteError(RemoteError::new(
                self.target,
                remote_err,
            ))),
            Err(net_err) => Err(RPCError::Unreachable(Unreachable::new(&net_err))),
        }
    }

    /// Chunked snapshot transfer is not supported by this transport.
    ///
    /// Snapshots are sent whole via `full_snapshot`. This method always fails locally
    /// with `RPCError::Unreachable` and never contacts the peer.
    async fn install_snapshot(
        &mut self,
        _rpc: InstallSnapshotRequest<YantrikRaftTypeConfig>,
        _option: RPCOption,
    ) -> Result<
        InstallSnapshotResponse<YantrikNodeId>,
        RPCError<YantrikNodeId, YantrikNode, RaftError<YantrikNodeId, InstallSnapshotError>>,
    > {
        Err(RPCError::Unreachable(Unreachable::new(&NetworkError::new(
            &AnyError::error(
                "install_snapshot is deprecated under generic-snapshot-data; use full_snapshot",
            ),
        ))))
    }

    /// Sends a complete snapshot to the peer in a single request.
    ///
    /// The snapshot bytes are taken out of the snapshot cursor and sent as one JSON
    /// body to `/v1/raft/install_full_snapshot`. The `cancel` future is watched while
    /// the request is in flight. If it resolves first (the replication stream was
    /// closed), the request future is dropped, which aborts the upload. The
    /// `ReplicationClosed` is then returned as a `StreamingError`, so the method does
    /// not wait for the reply or the request timeout.
    ///
    /// A `Fatal` reported by the peer becomes `StreamingError::RemoteError`; any
    /// `NetworkError` becomes `StreamingError::Network`.
    async fn full_snapshot(
        &mut self,
        vote: Vote<YantrikNodeId>,
        snapshot: Snapshot<YantrikRaftTypeConfig>,
        cancel: impl std::future::Future<Output = ReplicationClosed> + Send + 'static,
        _option: RPCOption,
    ) -> Result<
        SnapshotResponse<YantrikNodeId>,
        StreamingError<YantrikRaftTypeConfig, Fatal<YantrikNodeId>>,
    > {
        use std::future::Future;
        use std::task::Poll;

        let wire = InstallFullSnapshotWire {
            vote,
            meta: snapshot.meta,
            data: snapshot.snapshot.into_inner(),
        };

        let send = self
            .json_post::<_, Result<SnapshotResponse<YantrikNodeId>, Fatal<YantrikNodeId>>>(
                "/v1/raft/install_full_snapshot",
                &wire,
            );
        let mut send = std::pin::pin!(send);
        let mut cancel = std::pin::pin!(cancel);

        // Race the request against the cancel signal. `cancel` is polled first, so a
        // closed stream wins over a reply that becomes ready in the same wake-up.
        let raced = std::future::poll_fn(|cx| {
            if let Poll::Ready(closed) = cancel.as_mut().poll(cx) {
                return Poll::Ready(Err(closed));
            }
            send.as_mut().poll(cx).map(Ok)
        })
        .await;

        let reply = match raced {
            Ok(reply) => reply,
            // Returning drops the pinned request future, aborting the HTTP exchange.
            Err(closed) => return Err(closed.into()),
        };

        match reply {
            Ok(Ok(resp)) => Ok(resp),
            Ok(Err(fatal)) => Err(StreamingError::RemoteError(RemoteError::new(
                self.target,
                fatal,
            ))),
            Err(net_err) => Err(StreamingError::Network(net_err)),
        }
    }

    /// Fixed 500 ms pause between retries to an unreachable peer, without limit.
    fn backoff(&self) -> Backoff {
        Backoff::new(std::iter::repeat(Duration::from_millis(500)))
    }
}
/// Produces one [`HttpRaftNetwork`] per peer, all sharing a single `reqwest::Client`.
#[derive(Clone)]
pub struct HttpRaftNetworkFactory {
    /// Client cloned into every per-peer network.
    client: Client,
    /// Per-request timeout handed to every per-peer network.
    request_timeout: Duration,
}

impl HttpRaftNetworkFactory {
    /// Builds a factory around a caller-supplied client (e.g. one configured for mTLS).
    pub fn new(client: Client, request_timeout: Duration) -> Self {
        Self {
            client,
            request_timeout,
        }
    }

    /// Builds a factory with a default plaintext-HTTP `reqwest::Client`.
    ///
    /// Logs a warning on every call, because this setup is meant for development only.
    /// `request_timeout` is applied both as the client-wide timeout and, by
    /// `HttpRaftNetwork`, to each request.
    ///
    /// # Panics
    /// Panics if `reqwest` fails to build the client.
    pub fn new_plaintext(request_timeout: Duration) -> Self {
        tracing::warn!(
            "HttpRaftNetworkFactory: plaintext HTTP — DEV ONLY. \
             Production must use new() with mTLS-configured reqwest::Client (RFC 014-A)."
        );
        let plain = Client::builder()
            .timeout(request_timeout)
            .build()
            .expect("plaintext reqwest client must build");
        Self::new(plain, request_timeout)
    }
}
impl RaftNetworkFactory<YantrikRaftTypeConfig> for HttpRaftNetworkFactory {
type Network = HttpRaftNetwork;
async fn new_client(&mut self, target: YantrikNodeId, node: &YantrikNode) -> Self::Network {
HttpRaftNetwork {
client: self.client.clone(),
target,
target_addr: node.addr.clone(),
request_timeout: self.request_timeout,
}
}
}
/// Builds the axum router that serves inbound Raft RPCs for `raft`.
///
/// The paths match the ones [`HttpRaftNetwork`] posts to:
/// - `POST /v1/raft/append_entries`
/// - `POST /v1/raft/vote`
/// - `POST /v1/raft/install_full_snapshot`
pub fn raft_receive_router(raft: Arc<Raft<YantrikRaftTypeConfig>>) -> Router {
    let append_entries = post(handle_append_entries);
    let vote = post(handle_vote);
    let install_full_snapshot = post(handle_install_full_snapshot);

    Router::new()
        .route("/v1/raft/append_entries", append_entries)
        .route("/v1/raft/vote", vote)
        .route("/v1/raft/install_full_snapshot", install_full_snapshot)
        .with_state(raft)
}
/// axum handler for `POST /v1/raft/append_entries`.
///
/// Passes the decoded request to the local Raft node and replies (axum's default
/// 200) with a JSON `Result<AppendEntriesResponse, RaftError>`. The sender,
/// `HttpRaftNetwork::append_entries`, turns an `Err` payload into a remote error.
async fn handle_append_entries(
    State(raft): State<Arc<Raft<YantrikRaftTypeConfig>>>,
    Json(rpc): Json<AppendEntriesRequest<YantrikRaftTypeConfig>>,
) -> Response {
    let reply: Result<AppendEntriesResponse<YantrikNodeId>, RaftError<YantrikNodeId>> =
        raft.append_entries(rpc).await;
    Json(reply).into_response()
}
/// axum handler for `POST /v1/raft/vote`.
///
/// Passes the decoded request to the local Raft node and replies (axum's default
/// 200) with a JSON `Result<VoteResponse, RaftError>`. The sender,
/// `HttpRaftNetwork::vote`, turns an `Err` payload into a remote error.
async fn handle_vote(
    State(raft): State<Arc<Raft<YantrikRaftTypeConfig>>>,
    Json(rpc): Json<VoteRequest<YantrikNodeId>>,
) -> Response {
    let reply: Result<VoteResponse<YantrikNodeId>, RaftError<YantrikNodeId>> =
        raft.vote(rpc).await;
    Json(reply).into_response()
}
/// axum handler for `POST /v1/raft/install_full_snapshot`.
///
/// Rebuilds an openraft `Snapshot` from the wire body and installs it on the local
/// Raft node. The reply is always HTTP 200 carrying a JSON
/// `Result<SnapshotResponse, Fatal>`, matching `handle_append_entries` and
/// `handle_vote`. The sender (`HttpRaftNetwork::full_snapshot`) decodes an `Err`
/// body into `StreamingError::RemoteError`. A non-2xx status would make its
/// `json_post` discard the body as an HTTP error, and the peer's `Fatal` would then
/// be misreported as a transient `StreamingError::Network`.
async fn handle_install_full_snapshot(
    State(raft): State<Arc<Raft<YantrikRaftTypeConfig>>>,
    Json(wire): Json<InstallFullSnapshotWire>,
) -> Response {
    let snapshot = Snapshot {
        meta: wire.meta,
        snapshot: Box::new(std::io::Cursor::new(wire.data)),
    };
    let outcome: Result<SnapshotResponse<YantrikNodeId>, Fatal<YantrikNodeId>> =
        raft.install_full_snapshot(wire.vote, snapshot).await;
    // Errors travel in the body, not the status line, so the client can decode them.
    (StatusCode::OK, Json(outcome)).into_response()
}
// Transport tests. Most of them drive the client half (`HttpRaftNetwork`) over real
// HTTP against the server half (`raft_receive_router`), which wraps a single-node
// Raft served on localhost.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::commit::LocalSqliteCommitter;
    use crate::raft::log_storage::SqliteRaftLogStorage;
    use crate::raft::state_machine::YantrikStateMachine;
    use openraft::Config;
    use std::collections::BTreeMap;
    use std::net::SocketAddr;

    /// Boots node 1 as a single-member cluster and serves its receive router on an
    /// OS-assigned 127.0.0.1 port.
    ///
    /// Returns the Raft handle and the server's base URL (`http://127.0.0.1:<port>`).
    async fn spawn_single_node_raft() -> (Arc<Raft<YantrikRaftTypeConfig>>, String) {
        let local = Arc::new(LocalSqliteCommitter::open_in_memory().unwrap());
        let log_store = SqliteRaftLogStorage::open_in_memory();
        let state_machine = YantrikStateMachine::new(
            local,
            std::sync::Arc::new(crate::commit::LocalApplier::new()),
        );
        // Outbound RPCs from this node go through a stub factory; presumably never
        // exercised, since the cluster has no other members.
        let network = super::super::network::StubRaftNetworkFactory;
        let config = Arc::new(
            Config {
                cluster_name: "yantrikdb-recv-test".into(),
                heartbeat_interval: 100,
                election_timeout_min: 200,
                election_timeout_max: 400,
                ..Default::default()
            }
            .validate()
            .unwrap(),
        );
        let me = YantrikNodeId::new(1);
        let raft = Arc::new(
            Raft::<YantrikRaftTypeConfig>::new(me, config, network, log_store, state_machine)
                .await
                .unwrap(),
        );
        let mut nodes = BTreeMap::new();
        // Placeholder address (port 0); the real server address is only known later.
        nodes.insert(me, YantrikNode::new("http://127.0.0.1:0"));
        raft.initialize(nodes).await.unwrap();
        // Wait up to ~1.5 s (30 x 50 ms) for self-election; falls through silently
        // if no leader emerges.
        for _ in 0..30 {
            if raft.current_leader().await == Some(me) {
                break;
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        let router = raft_receive_router(raft.clone());
        // Port 0: let the OS choose a free port, then read it back for the URL.
        let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
            .await
            .unwrap();
        let bound = listener.local_addr().unwrap();
        // Detached server task; serve errors are ignored.
        tokio::spawn(async move {
            axum::serve(listener, router).await.ok();
        });
        // Short grace period before the first request.
        tokio::time::sleep(Duration::from_millis(50)).await;
        (raft, format!("http://{bound}"))
    }

    // The factory copies the target id and address verbatim into the per-peer network.
    #[tokio::test]
    async fn http_factory_creates_per_peer_clients() {
        let mut f = HttpRaftNetworkFactory::new_plaintext(Duration::from_millis(500));
        let n1 = f
            .new_client(YantrikNodeId::new(1), &YantrikNode::new("http://n1"))
            .await;
        assert_eq!(n1.target, YantrikNodeId::new(1));
        assert_eq!(n1.target_addr, "http://n1");
    }

    // An empty AppendEntries must survive the HTTP round trip. Either a decoded
    // success or a decoded Raft-level error (RemoteError) shows the transport works;
    // Unreachable/other variants fail the test.
    #[tokio::test]
    async fn append_entries_roundtrip_through_http_layer() {
        let (_raft, base) = spawn_single_node_raft().await;
        let mut f = HttpRaftNetworkFactory::new_plaintext(Duration::from_secs(2));
        let mut net = f
            .new_client(YantrikNodeId::new(1), &YantrikNode::new(&base))
            .await;
        let rpc = AppendEntriesRequest::<YantrikRaftTypeConfig> {
            vote: Vote::new_committed(1, YantrikNodeId::new(1)),
            prev_log_id: None,
            entries: Vec::new(),
            leader_commit: None,
        };
        let resp = net
            .append_entries(rpc, RPCOption::new(Duration::from_secs(1)))
            .await;
        match resp {
            Ok(_) => {}
            Err(RPCError::RemoteError(_)) => {}
            Err(other) => panic!("expected Ok or RemoteError, got {other:?}"),
        }
    }

    // A vote request on behalf of node 2 must come back as a decoded VoteResponse.
    // Whether the vote is granted is deliberately not checked.
    #[tokio::test]
    async fn vote_roundtrip_through_http_layer() {
        let (_raft, base) = spawn_single_node_raft().await;
        let mut f = HttpRaftNetworkFactory::new_plaintext(Duration::from_secs(2));
        let mut net = f
            .new_client(YantrikNodeId::new(1), &YantrikNode::new(&base))
            .await;
        let rpc = VoteRequest {
            vote: Vote::new(2, YantrikNodeId::new(2)),
            last_log_id: None,
        };
        let resp = net
            .vote(rpc, RPCOption::new(Duration::from_secs(1)))
            .await
            .unwrap();
        let _ = resp.vote_granted;
    }

    // Binding and then dropping a listener very likely leaves a port with nothing
    // listening. The failed connect must surface as RPCError::Unreachable.
    #[tokio::test]
    async fn unreachable_address_surfaces_as_transport_error() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let dead_addr = listener.local_addr().unwrap();
        drop(listener);
        let mut f = HttpRaftNetworkFactory::new_plaintext(Duration::from_millis(500));
        let mut net = f
            .new_client(
                YantrikNodeId::new(99),
                &YantrikNode::new(&format!("http://{dead_addr}")),
            )
            .await;
        let rpc = VoteRequest {
            vote: Vote::new(1, YantrikNodeId::new(1)),
            last_log_id: None,
        };
        let err = net
            .vote(rpc, RPCOption::new(Duration::from_millis(500)))
            .await
            .unwrap_err();
        assert!(matches!(err, RPCError::Unreachable(_)));
    }

    // The legacy chunked RPC is rejected locally as Unreachable, even though a
    // live server is reachable at `base`.
    #[tokio::test]
    async fn install_snapshot_returns_unreachable_under_generic_data() {
        let (_raft, base) = spawn_single_node_raft().await;
        let mut f = HttpRaftNetworkFactory::new_plaintext(Duration::from_secs(1));
        let mut net = f
            .new_client(YantrikNodeId::new(1), &YantrikNode::new(&base))
            .await;
        let dummy_meta = SnapshotMeta::<YantrikNodeId, YantrikNode>::default();
        let rpc = InstallSnapshotRequest {
            vote: Vote::new(1, YantrikNodeId::new(1)),
            meta: dummy_meta,
            offset: 0,
            data: Vec::new(),
            done: true,
        };
        let err = net
            .install_snapshot(rpc, RPCOption::new(Duration::from_secs(1)))
            .await
            .unwrap_err();
        assert!(matches!(err, RPCError::Unreachable(_)));
    }
}