use axum::{
body::Bytes,
extract::{Json, State},
http::{header, HeaderMap, StatusCode},
response::IntoResponse,
routing::{get, post},
Router,
};
use openraft::BasicNode;
use rivven_cluster::{
hash_node_id, ClusterCoordinator, MetadataCommand, MetadataResponse, RaftNode, RaftTypeConfig,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info};
/// MIME type reported for postcard-encoded (binary) response bodies.
const CONTENT_TYPE_BINARY: &str = "application/octet-stream";
/// MIME type reported for JSON bodies; also the fallback when no usable
/// `Content-Type`/`Accept` header is present.
const CONTENT_TYPE_JSON: &str = "application/json";
fn deserialize_request<T: DeserializeOwned>(
headers: &HeaderMap,
body: &Bytes,
) -> Result<T, String> {
let content_type = headers
.get(header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or(CONTENT_TYPE_JSON);
if content_type.contains("octet-stream") {
postcard::from_bytes(body).map_err(|e| format!("postcard deserialize error: {}", e))
} else {
serde_json::from_slice(body).map_err(|e| format!("json deserialize error: {}", e))
}
}
/// Encodes `data` for the response, mirroring the caller's preferred format.
///
/// The `Accept` header is consulted first, then `Content-Type`; a value
/// containing `octet-stream` selects postcard, otherwise JSON is produced.
///
/// # Returns
/// The encoded bytes together with the content-type string to send back.
///
/// # Errors
/// Returns a human-readable message naming the codec that failed.
fn serialize_response<T: Serialize>(
    headers: &HeaderMap,
    data: &T,
) -> Result<(Bytes, &'static str), String> {
    let negotiated = headers
        .get(header::ACCEPT)
        .or_else(|| headers.get(header::CONTENT_TYPE));

    // A header that is absent or not valid visible ASCII defaults to JSON.
    let preferred: &str = match negotiated.map(|raw| raw.to_str()) {
        Some(Ok(text)) => text,
        _ => CONTENT_TYPE_JSON,
    };

    if preferred.contains("octet-stream") {
        postcard::to_allocvec(data)
            .map(|buf| (Bytes::from(buf), CONTENT_TYPE_BINARY))
            .map_err(|err| format!("postcard serialize error: {}", err))
    } else {
        serde_json::to_vec(data)
            .map(|buf| (Bytes::from(buf), CONTENT_TYPE_JSON))
            .map_err(|err| format!("json serialize error: {}", err))
    }
}
/// TLS settings shared by the API server and the outbound HTTP client.
///
/// The derived `Default` leaves TLS disabled with no paths set.
#[derive(Clone, Default)]
pub struct TlsConfig {
/// When `true`, the server is started with rustls and the outbound client
/// is built on the TLS code path.
pub enabled: bool,
/// PEM certificate file for the server; required when `enabled` is `true`.
pub cert_path: Option<PathBuf>,
/// PEM private-key file for the server; required when `enabled` is `true`.
pub key_path: Option<PathBuf>,
/// Optional PEM CA certificate added as a trusted root of the outbound client.
pub ca_path: Option<PathBuf>,
/// NOTE(review): not read anywhere in this file — presumably intended to
/// request client certificates (mTLS); confirm where it is enforced.
pub verify_client: bool,
}
/// Shared state handed to every HTTP handler of the Raft API.
#[derive(Clone)]
pub struct RaftApiState {
/// The local Raft node, guarded by an async read/write lock.
pub raft_node: Arc<RwLock<RaftNode>>,
/// Optional cluster coordinator; stored here but not read by the handlers
/// in this file.
pub coordinator: Option<Arc<RwLock<ClusterCoordinator>>>,
/// Outbound HTTP client, used to forward proposals to the leader.
pub http_client: reqwest::Client,
/// Maps numeric node ids to the address string used as the base URL when
/// forwarding requests to that node.
pub node_addresses: Arc<RwLock<std::collections::HashMap<u64, String>>>,
/// When set, protected routes require `Authorization: Bearer <token>`
/// matching this value.
pub cluster_auth_token: Option<Arc<String>>,
}
impl RaftApiState {
    /// Builds the API state with TLS disabled (the default [`TlsConfig`]).
    ///
    /// # Panics
    /// Panics if the HTTP client cannot be constructed.
    pub fn new(
        raft_node: Arc<RwLock<RaftNode>>,
        coordinator: Option<Arc<RwLock<ClusterCoordinator>>>,
    ) -> Self {
        let plaintext = TlsConfig::default();
        Self::with_tls(raft_node, coordinator, &plaintext)
    }

    /// Builds the API state, configuring the outbound HTTP client from
    /// `tls_config`.
    ///
    /// The client always uses a 10 second request timeout. When TLS is enabled
    /// and a CA path is supplied, that certificate is added as a trusted root.
    ///
    /// # Panics
    /// Panics if the CA file cannot be read or parsed, or if the HTTP client
    /// cannot be constructed.
    pub fn with_tls(
        raft_node: Arc<RwLock<RaftNode>>,
        coordinator: Option<Arc<RwLock<ClusterCoordinator>>>,
        tls_config: &TlsConfig,
    ) -> Self {
        let request_timeout = std::time::Duration::from_secs(10);

        // Only the TLS path ever loads a custom root certificate.
        let trusted_ca = match (tls_config.enabled, tls_config.ca_path.as_ref()) {
            (true, Some(ca_path)) => {
                let pem = std::fs::read(ca_path).expect("Failed to read CA certificate");
                let parsed = reqwest::Certificate::from_pem(&pem)
                    .expect("Failed to parse CA certificate");
                Some(parsed)
            }
            _ => None,
        };

        let build_failure = if tls_config.enabled {
            "Failed to create TLS HTTP client"
        } else {
            "Failed to create HTTP client"
        };

        let mut builder = reqwest::Client::builder().timeout(request_timeout);
        if let Some(root) = trusted_ca {
            builder = builder.add_root_certificate(root);
        }
        let http_client = builder.build().expect(build_failure);

        let node_addresses = Arc::new(RwLock::new(std::collections::HashMap::new()));
        Self {
            raft_node,
            coordinator,
            http_client,
            node_addresses,
            cluster_auth_token: None,
        }
    }

    /// Sets (or clears, with `None`) the bearer token required on protected
    /// routes, returning the updated state.
    pub fn with_cluster_auth_token(mut self, token: Option<String>) -> Self {
        self.cluster_auth_token = match token {
            Some(secret) => Some(Arc::new(secret)),
            None => None,
        };
        self
    }

    /// Records the HTTP address of `node_id`, replacing any previous entry.
    pub async fn register_node_address(&self, node_id: u64, http_addr: String) {
        let mut table = self.node_addresses.write().await;
        table.insert(node_id, http_addr);
    }

    /// Returns a copy of the registered HTTP address of `node_id`, if any.
    pub async fn get_node_address(&self, node_id: u64) -> Option<String> {
        let table = self.node_addresses.read().await;
        table.get(&node_id).cloned()
    }
}
/// JSON body returned by the `/health` endpoint.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
/// One of `"healthy"`, `"degraded"` or `"unhealthy"`.
pub status: String,
/// Numeric id of the responding node.
pub node_id: u64,
/// String form of the responding node's id.
pub node_id_str: String,
/// Whether the responding node currently reports itself as leader.
pub is_leader: bool,
/// Id of the leader known to the responding node, if any.
pub leader_id: Option<u64>,
/// Set by the health handler to `true` only when a leader is known and it
/// is a different node than the responder.
pub cluster_mode: bool,
}
/// JSON body returned by the `/metrics/json` endpoint.
///
/// The optional fields are `None` when the node exposes no metrics.
#[derive(Debug, Serialize)]
pub struct MetricsResponse {
/// Numeric id of the responding node.
pub node_id: u64,
/// Whether the responding node currently reports itself as leader.
pub is_leader: bool,
/// Current Raft term.
pub current_term: Option<u64>,
/// Index of the last log entry.
pub last_log_index: Option<u64>,
/// Populated by the handler from the last-applied index (same source as
/// `applied_index`).
pub commit_index: Option<u64>,
/// Index of the last applied log entry.
pub applied_index: Option<u64>,
/// Number of voters in the current membership configuration.
pub membership_size: Option<usize>,
}
/// JSON body returned by the `/membership` endpoint.
#[derive(Debug, Serialize)]
pub struct MembershipResponse {
/// Every node in the current membership configuration.
pub nodes: Vec<NodeInfo>,
}
/// One cluster member as listed in [`MembershipResponse`].
#[derive(Debug, Serialize)]
pub struct NodeInfo {
/// Numeric node id.
pub id: u64,
/// Address recorded for the node in the membership configuration.
pub addr: String,
/// `true` when the node id appears among the membership's voters.
pub is_voter: bool,
}
/// Generic JSON error body used by the handlers in this file.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
/// Human-readable description of the failure.
pub error: String,
/// Numeric code; the handlers set it to the HTTP status they respond with.
pub code: u16,
}
/// Request body of `/api/v1/propose`: a single metadata command to propose.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProposalRequest {
/// The command handed to the Raft node's `propose`.
pub command: MetadataCommand,
}
/// Response body of `/api/v1/propose`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProposalResponse {
/// `true` when the proposal was accepted and applied.
pub success: bool,
/// Result of the command on success.
pub response: Option<MetadataResponse>,
/// Failure description when `success` is `false`.
pub error: Option<String>,
/// Id of the known leader when this node is not the leader and could not
/// complete the proposal itself.
pub redirect_to: Option<u64>,
}
/// Request body of `/api/v1/bootstrap`.
#[derive(Debug, Serialize, Deserialize)]
pub struct BootstrapRequest {
/// Initial cluster members to bootstrap with.
pub members: Vec<BootstrapMember>,
}
/// One initial member listed in a [`BootstrapRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BootstrapMember {
/// String node id; converted to the numeric Raft id with `hash_node_id`.
pub node_id: String,
/// Address stored in the membership and registered for request forwarding.
pub addr: String,
}
/// Response body of `/api/v1/bootstrap`.
#[derive(Debug, Serialize)]
pub struct BootstrapResponse {
/// Whether the bootstrap call succeeded.
pub success: bool,
/// Success notice or failure description.
pub message: String,
/// Leader known right after a successful bootstrap; `None` on failure.
pub leader_id: Option<u64>,
}
/// Request body of `/api/v1/add-learner`.
#[derive(Debug, Serialize, Deserialize)]
pub struct AddLearnerRequest {
/// String node id; converted to the numeric Raft id with `hash_node_id`.
pub node_id: String,
/// Address of the new learner, also registered for request forwarding.
pub addr: String,
}
/// Request body of `/api/v1/change-membership`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChangeMembershipRequest {
/// String ids of the desired voter set; each is hashed with `hash_node_id`.
pub voters: Vec<String>,
}
/// Request body of `/api/v1/propose/batch`.
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchProposalRequest {
/// Commands handed together to the Raft node's `propose_batch`.
pub commands: Vec<MetadataCommand>,
}
/// Response body of `/api/v1/propose/batch`.
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchProposalResponse {
/// `true` when the whole batch was accepted.
pub success: bool,
/// Per-command results on success.
pub responses: Option<Vec<MetadataResponse>>,
/// Failure description when `success` is `false`.
pub error: Option<String>,
/// Number of commands in the request; `0` when the node is not the leader.
pub count: usize,
}
/// Middleware that admits a request only if it carries the expected bearer
/// token.
///
/// Responds `401 Unauthorized` when the `Authorization` header is missing,
/// unreadable, lacks the `Bearer ` prefix, or carries a different token.
/// The token bytes are compared with a constant-time equality check.
async fn cluster_auth_middleware(
    expected_token: Arc<String>,
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    use subtle::ConstantTimeEq;

    let bearer = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.strip_prefix("Bearer "));

    let Some(provided) = bearer else {
        return (
            StatusCode::UNAUTHORIZED,
            "Missing Authorization: Bearer <token>",
        )
            .into_response();
    };

    let wanted = expected_token.as_bytes();
    let given = provided.as_bytes();
    let token_matches = wanted.len() == given.len() && bool::from(wanted.ct_eq(given));

    if !token_matches {
        return (StatusCode::UNAUTHORIZED, "Invalid cluster token").into_response();
    }
    next.run(req).await
}
/// Builds the complete Raft API router with `state` attached, ready to serve.
pub fn create_raft_router(state: RaftApiState) -> Router {
    let routes = create_raft_router_base::<RaftApiState>(state.clone());
    routes.with_state(state)
}
/// Assembles the public and protected route groups and binds `state` to them.
///
/// Health, metrics and membership routes are always open. Raft RPC and
/// management routes are wrapped in [`cluster_auth_middleware`] when
/// `state.cluster_auth_token` is set; otherwise they are left unguarded.
fn create_raft_router_base<S: Clone + Send + Sync + 'static>(state: RaftApiState) -> Router<S> {
    let open_routes = Router::new()
        .route("/health", get(health_handler))
        .route("/metrics", get(prometheus_metrics_handler))
        .route("/metrics/json", get(metrics_handler))
        .route("/membership", get(membership_handler));

    let guarded_routes = Router::new()
        .route("/raft/append", post(append_entries_handler))
        .route("/raft/snapshot", post(install_snapshot_handler))
        .route("/raft/vote", post(vote_handler))
        .route("/api/v1/bootstrap", post(bootstrap_handler))
        .route("/api/v1/add-learner", post(add_learner_handler))
        .route("/api/v1/change-membership", post(change_membership_handler))
        .route(
            "/api/v1/transfer-leadership",
            post(transfer_leadership_handler),
        )
        .route("/api/v1/propose", post(propose_handler))
        .route("/api/v1/propose/batch", post(batch_propose_handler))
        .route("/api/v1/metadata", get(metadata_handler))
        .route(
            "/api/v1/metadata/linearizable",
            get(linearizable_metadata_handler),
        );

    // The auth layer is only installed when a token has been configured.
    let guarded_routes = match state.cluster_auth_token.clone() {
        Some(secret) => guarded_routes.layer(axum::middleware::from_fn(move |req, next| {
            cluster_auth_middleware(Arc::clone(&secret), req, next)
        })),
        None => guarded_routes,
    };

    open_routes.merge(guarded_routes).with_state(state)
}
/// `GET /health` — reports node liveness and leadership.
///
/// * `healthy` (200): a leader is known.
/// * `degraded` (503): no leader is known, but the node exposes metrics.
/// * `unhealthy` (503): no leader is known and no metrics are available.
async fn health_handler(State(state): State<RaftApiState>) -> impl IntoResponse {
    let node = state.raft_node.read().await;
    let snapshot = node.metrics();
    let leader_known = node.leader().is_some();

    let (label, http_status) = match (leader_known, snapshot.is_some()) {
        (true, _) => ("healthy", StatusCode::OK),
        (false, true) => ("degraded", StatusCode::SERVICE_UNAVAILABLE),
        (false, false) => ("unhealthy", StatusCode::SERVICE_UNAVAILABLE),
    };

    let body = HealthResponse {
        status: label.to_string(),
        node_id: node.node_id(),
        node_id_str: node.node_id_str().to_string(),
        is_leader: node.is_leader(),
        leader_id: node.leader(),
        // NOTE(review): this is `false` on the leader itself — confirm that is
        // the intended meaning of `cluster_mode`.
        cluster_mode: node.leader().is_some_and(|leader| leader != node.node_id()),
    };
    (http_status, Json(body))
}
/// `GET /metrics/json` — returns a JSON snapshot of the node's Raft metrics.
///
/// Optional fields are `None` when the node exposes no metrics. The commit
/// index is reported from the last-applied index, the same value used for
/// `applied_index`.
async fn metrics_handler(State(state): State<RaftApiState>) -> impl IntoResponse {
    let node = state.raft_node.read().await;
    let snapshot = node.metrics();
    let view = snapshot.as_ref();

    let applied = view.and_then(|m| m.last_applied.map(|log_id| log_id.index));
    let voters = view.map(|m| m.membership_config.membership().voter_ids().count());

    let body = MetricsResponse {
        node_id: node.node_id(),
        is_leader: node.is_leader(),
        current_term: view.map(|m| m.current_term),
        last_log_index: view.and_then(|m| m.last_log_index),
        commit_index: applied,
        applied_index: applied,
        membership_size: voters,
    };
    (StatusCode::OK, Json(body))
}
/// `GET /metrics` — renders the node's Raft metrics in the Prometheus text
/// exposition format.
///
/// Node id and leadership are always emitted; term, log/applied indices and
/// voter/learner counts are emitted only when the node exposes metrics. A
/// `rivven_info` series carrying the crate version is appended last.
async fn prometheus_metrics_handler(State(state): State<RaftApiState>) -> impl IntoResponse {
    /// Appends the HELP, TYPE and sample lines for one gauge.
    fn push_gauge(out: &mut String, name: &str, help: &str, value: impl std::fmt::Display) {
        out.push_str(&format!("# HELP {} {}\n", name, help));
        out.push_str(&format!("# TYPE {} gauge\n", name));
        out.push_str(&format!("{} {}\n", name, value));
    }

    let node = state.raft_node.read().await;
    let snapshot = node.metrics();
    let mut text = String::new();

    push_gauge(
        &mut text,
        "rivven_raft_node_id",
        "Raft node identifier",
        node.node_id(),
    );
    push_gauge(
        &mut text,
        "rivven_raft_is_leader",
        "Whether this node is the Raft leader",
        u8::from(node.is_leader()),
    );

    if let Some(ref m) = snapshot {
        push_gauge(
            &mut text,
            "rivven_raft_current_term",
            "Current Raft term",
            m.current_term,
        );
        if let Some(index) = m.last_log_index {
            push_gauge(
                &mut text,
                "rivven_raft_last_log_index",
                "Index of last log entry",
                index,
            );
        }
        if let Some(log_id) = m.last_applied {
            push_gauge(
                &mut text,
                "rivven_raft_applied_index",
                "Index of last applied entry",
                log_id.index,
            );
        }

        let membership = m.membership_config.membership();
        let voter_count = membership.voter_ids().count();
        let learner_count = membership.learner_ids().count();
        push_gauge(
            &mut text,
            "rivven_raft_cluster_voters",
            "Number of voter nodes in cluster",
            voter_count,
        );
        push_gauge(
            &mut text,
            "rivven_raft_cluster_learners",
            "Number of learner nodes in cluster",
            learner_count,
        );
    }

    text.push_str("\n# Rivven core metrics\n");
    text.push_str("# HELP rivven_info Build information\n");
    text.push_str("# TYPE rivven_info gauge\n");
    let info_line = format!(
        "rivven_info{{version=\"{}\",node_id_str=\"{}\"}} 1\n",
        env!("CARGO_PKG_VERSION"),
        node.node_id_str()
    );
    text.push_str(&info_line);

    let content_type = [(
        header::CONTENT_TYPE,
        "text/plain; version=0.0.4; charset=utf-8",
    )];
    (StatusCode::OK, content_type, text)
}
async fn membership_handler(State(state): State<RaftApiState>) -> impl IntoResponse {
let raft = state.raft_node.read().await;
let nodes = if let Some(metrics) = raft.metrics() {
metrics
.membership_config
.membership()
.nodes()
.map(|(id, node)| NodeInfo {
id: *id,
addr: node.addr.clone(),
is_voter: metrics
.membership_config
.membership()
.voter_ids()
.any(|vid| vid == *id),
})
.collect()
} else {
vec![NodeInfo {
id: raft.node_id(),
addr: "localhost".to_string(),
is_voter: true,
}]
};
let response = MembershipResponse { nodes };
(StatusCode::OK, Json(response))
}
async fn append_entries_handler(
State(state): State<RaftApiState>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
let req: openraft::raft::AppendEntriesRequest<RaftTypeConfig> =
match deserialize_request(&headers, &body) {
Ok(r) => r,
Err(e) => {
error!("Failed to deserialize AppendEntries: {}", e);
return (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: e,
code: 400,
}),
)
.into_response();
}
};
debug!(
leader = req.vote.leader_id().node_id,
term = req.vote.leader_id().term,
entries = req.entries.len(),
"AppendEntries RPC"
);
let raft = state.raft_node.read().await;
match raft.handle_append_entries(req).await {
Ok(response) => {
match serialize_response(&headers, &response) {
Ok((bytes, content_type)) => (
StatusCode::OK,
[(header::CONTENT_TYPE, content_type)],
bytes,
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: e,
code: 500,
}),
)
.into_response(),
}
}
Err(e) => {
error!("AppendEntries failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: e.to_string(),
code: 500,
}),
)
.into_response()
}
}
}
async fn install_snapshot_handler(
State(state): State<RaftApiState>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
let req: openraft::raft::InstallSnapshotRequest<RaftTypeConfig> =
match deserialize_request(&headers, &body) {
Ok(r) => r,
Err(e) => {
error!("Failed to deserialize InstallSnapshot: {}", e);
return (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: e,
code: 400,
}),
)
.into_response();
}
};
debug!(
leader = req.vote.leader_id().node_id,
snapshot_size = req.data.len(),
"InstallSnapshot RPC"
);
let raft = state.raft_node.read().await;
match raft.handle_install_snapshot(req).await {
Ok(response) => match serialize_response(&headers, &response) {
Ok((bytes, content_type)) => (
StatusCode::OK,
[(header::CONTENT_TYPE, content_type)],
bytes,
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: e,
code: 500,
}),
)
.into_response(),
},
Err(e) => {
error!("InstallSnapshot failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: e.to_string(),
code: 500,
}),
)
.into_response()
}
}
}
async fn vote_handler(
State(state): State<RaftApiState>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
let req: openraft::raft::VoteRequest<u64> = match deserialize_request(&headers, &body) {
Ok(r) => r,
Err(e) => {
error!("Failed to deserialize Vote: {}", e);
return (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: e,
code: 400,
}),
)
.into_response();
}
};
debug!(
candidate = req.vote.leader_id().node_id,
term = req.vote.leader_id().term,
"Vote RPC"
);
let raft = state.raft_node.read().await;
match raft.handle_vote(req).await {
Ok(response) => match serialize_response(&headers, &response) {
Ok((bytes, content_type)) => (
StatusCode::OK,
[(header::CONTENT_TYPE, content_type)],
bytes,
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: e,
code: 500,
}),
)
.into_response(),
},
Err(e) => {
error!("Vote failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: e.to_string(),
code: 500,
}),
)
.into_response()
}
}
}
async fn bootstrap_handler(
State(state): State<RaftApiState>,
Json(req): Json<BootstrapRequest>,
) -> impl IntoResponse {
let raft = state.raft_node.read().await;
let members: std::collections::BTreeMap<u64, BasicNode> = req
.members
.iter()
.map(|m| {
let node_id = hash_node_id(&m.node_id);
(
node_id,
BasicNode {
addr: m.addr.clone(),
},
)
})
.collect();
drop(raft);
for member in &req.members {
let node_id = hash_node_id(&member.node_id);
state
.register_node_address(node_id, member.addr.clone())
.await;
}
let raft = state.raft_node.read().await;
info!(member_count = members.len(), "Bootstrapping cluster");
match raft.bootstrap(members).await {
Ok(_) => (
StatusCode::OK,
Json(BootstrapResponse {
success: true,
message: "Cluster bootstrapped successfully".to_string(),
leader_id: raft.leader(),
}),
)
.into_response(),
Err(e) => {
error!("Bootstrap failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(BootstrapResponse {
success: false,
message: format!("Bootstrap failed: {}", e),
leader_id: None,
}),
)
.into_response()
}
}
}
async fn add_learner_handler(
State(state): State<RaftApiState>,
Json(req): Json<AddLearnerRequest>,
) -> impl IntoResponse {
let raft = state.raft_node.read().await;
if !raft.is_leader() {
return (
StatusCode::TEMPORARY_REDIRECT,
Json(ErrorResponse {
error: "Not leader".to_string(),
code: 307,
}),
)
.into_response();
}
let node_id = hash_node_id(&req.node_id);
let node = BasicNode {
addr: req.addr.clone(),
};
drop(raft);
state.register_node_address(node_id, req.addr.clone()).await;
let raft = state.raft_node.read().await;
info!(node_id, addr = %req.addr, "Adding learner node");
if let Some(raft_instance) = raft.get_raft() {
match raft_instance.add_learner(node_id, node, true).await {
Ok(_) => (
StatusCode::OK,
Json(serde_json::json!({
"success": true,
"message": "Learner added successfully",
"node_id": node_id
})),
)
.into_response(),
Err(e) => {
error!("Failed to add learner: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to add learner: {}", e),
code: 500,
}),
)
.into_response()
}
}
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorResponse {
error: "Raft not initialized".to_string(),
code: 503,
}),
)
.into_response()
}
}
async fn change_membership_handler(
State(state): State<RaftApiState>,
Json(req): Json<ChangeMembershipRequest>,
) -> impl IntoResponse {
let raft = state.raft_node.read().await;
if !raft.is_leader() {
return (
StatusCode::TEMPORARY_REDIRECT,
Json(ErrorResponse {
error: "Not leader".to_string(),
code: 307,
}),
)
.into_response();
}
let voters: std::collections::BTreeSet<u64> =
req.voters.iter().map(|id| hash_node_id(id)).collect();
info!(voter_count = voters.len(), "Changing cluster membership");
if let Some(raft_instance) = raft.get_raft() {
match raft_instance.change_membership(voters, false).await {
Ok(_) => (
StatusCode::OK,
Json(serde_json::json!({
"success": true,
"message": "Membership changed successfully"
})),
)
.into_response(),
Err(e) => {
error!("Failed to change membership: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to change membership: {}", e),
code: 500,
}),
)
.into_response()
}
}
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorResponse {
error: "Raft not initialized".to_string(),
code: 503,
}),
)
.into_response()
}
}
/// `POST /api/v1/propose` — proposes a single metadata command.
///
/// On the leader the command is proposed directly: 200 on success, 500 on
/// failure. On a follower the request is forwarded to the leader when the
/// leader's address is registered: 200 with the leader's reply, or 503 when
/// forwarding fails. If no leader is known, or its address is not
/// registered, the handler responds 307 with `redirect_to` set to the known
/// leader id (if any).
///
/// The forwarded request carries the configured cluster auth token, because
/// the leader's `/api/v1/propose` route requires it whenever a token is set.
async fn propose_handler(
    State(state): State<RaftApiState>,
    Json(req): Json<ProposalRequest>,
) -> impl IntoResponse {
    let raft = state.raft_node.read().await;
    if !raft.is_leader() {
        let leader_id = raft.leader();
        // Release the lock before doing network I/O.
        drop(raft);
        if let Some(leader) = leader_id {
            if let Some(leader_addr) = state.get_node_address(leader).await {
                info!(leader_id = leader, leader_addr = %leader_addr, "Forwarding proposal to leader");
                let auth_token = state.cluster_auth_token.as_deref().map(String::as_str);
                match forward_to_leader(&state.http_client, &leader_addr, &req, auth_token).await {
                    Ok(response) => return (StatusCode::OK, Json(response)).into_response(),
                    Err(e) => {
                        error!("Failed to forward to leader: {}", e);
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            Json(ProposalResponse {
                                success: false,
                                response: None,
                                error: Some(format!("Leader forwarding failed: {}", e)),
                                redirect_to: Some(leader),
                            }),
                        )
                            .into_response();
                    }
                }
            }
        }
        return (
            StatusCode::TEMPORARY_REDIRECT,
            Json(ProposalResponse {
                success: false,
                response: None,
                error: Some("Not leader".to_string()),
                redirect_to: leader_id,
            }),
        )
            .into_response();
    }
    match raft.propose(req.command).await {
        Ok(response) => (
            StatusCode::OK,
            Json(ProposalResponse {
                success: true,
                response: Some(response),
                error: None,
                redirect_to: None,
            }),
        )
            .into_response(),
        Err(e) => {
            error!("Proposal failed: {}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ProposalResponse {
                    success: false,
                    response: None,
                    error: Some(e.to_string()),
                    redirect_to: None,
                }),
            )
                .into_response()
        }
    }
}

/// Forwards a proposal to the leader's `/api/v1/propose` endpoint as JSON.
///
/// `leader_addr` is used verbatim as the URL prefix. When `auth_token` is
/// `Some`, it is sent as `Authorization: Bearer <token>` so the request
/// passes the leader's cluster auth middleware.
///
/// # Errors
/// Returns a message when the request cannot be sent, the leader answers
/// with a non-success status, or the reply body cannot be parsed.
async fn forward_to_leader(
    client: &reqwest::Client,
    leader_addr: &str,
    req: &ProposalRequest,
    auth_token: Option<&str>,
) -> Result<ProposalResponse, String> {
    let url = format!("{}/api/v1/propose", leader_addr);
    let mut request = client.post(&url).json(req);
    if let Some(token) = auth_token {
        request = request.bearer_auth(token);
    }
    let response = request
        .send()
        .await
        .map_err(|e| format!("HTTP request failed: {}", e))?;
    if !response.status().is_success() {
        return Err(format!("Leader returned error: {}", response.status()));
    }
    response
        .json::<ProposalResponse>()
        .await
        .map_err(|e| format!("Failed to parse response: {}", e))
}
async fn batch_propose_handler(
State(state): State<RaftApiState>,
Json(req): Json<BatchProposalRequest>,
) -> impl IntoResponse {
let raft = state.raft_node.read().await;
if !raft.is_leader() {
let leader_id = raft.leader();
return (
StatusCode::TEMPORARY_REDIRECT,
Json(BatchProposalResponse {
success: false,
responses: None,
error: Some(format!("Not leader, redirect to {:?}", leader_id)),
count: 0,
}),
)
.into_response();
}
let count = req.commands.len();
info!(count, "Processing batch proposal");
match raft.propose_batch(req.commands).await {
Ok(responses) => (
StatusCode::OK,
Json(BatchProposalResponse {
success: true,
responses: Some(responses),
error: None,
count,
}),
)
.into_response(),
Err(e) => {
error!("Batch proposal failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(BatchProposalResponse {
success: false,
responses: None,
error: Some(e.to_string()),
count,
}),
)
.into_response()
}
}
}
/// `GET /api/v1/metadata` — returns a summary of the node's local metadata.
///
/// The read is served from local state without a linearizability check and
/// is labelled `"eventual"` in the response.
async fn metadata_handler(State(state): State<RaftApiState>) -> impl IntoResponse {
    let node = state.raft_node.read().await;
    let metadata = node.metadata().await;

    let topic_names = metadata.topics.keys().collect::<Vec<_>>();
    let summary = serde_json::json!({
        "topics": topic_names,
        "topic_count": metadata.topics.len(),
        "node_count": metadata.nodes.len(),
        "last_applied_index": metadata.last_applied_index,
        "consistency": "eventual",
    });
    (StatusCode::OK, Json(summary))
}
async fn linearizable_metadata_handler(State(state): State<RaftApiState>) -> impl IntoResponse {
let raft = state.raft_node.read().await;
match raft.ensure_linearizable_read().await {
Ok(_) => {
let metadata = raft.metadata().await;
let response = serde_json::json!({
"topics": metadata.topics.keys().collect::<Vec<_>>(),
"topic_count": metadata.topics.len(),
"node_count": metadata.nodes.len(),
"last_applied_index": metadata.last_applied_index,
"consistency": "linearizable",
});
(StatusCode::OK, Json(response)).into_response()
}
Err(e) => {
error!("Linearizable read failed: {}", e);
(
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorResponse {
error: format!("Linearizable read failed: {}", e),
code: 503,
}),
)
.into_response()
}
}
}
/// Request body of `/api/v1/transfer-leadership`.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferLeadershipRequest {
/// Preferred new leader; may be omitted. The handler in this file only
/// logs it and does not steer the election towards it.
#[serde(default)]
pub target_node_id: Option<String>,
}
async fn transfer_leadership_handler(
State(state): State<RaftApiState>,
Json(req): Json<TransferLeadershipRequest>,
) -> impl IntoResponse {
let raft = state.raft_node.read().await;
if !raft.is_leader() {
return (
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: "Not the leader - cannot transfer leadership".to_string(),
code: 400,
}),
)
.into_response();
}
if let Some(raft_instance) = raft.get_raft() {
if let Some(ref target) = req.target_node_id {
info!(
target = %target,
"Leadership transfer requested with preferred target (best-effort)"
);
}
info!("Initiating leadership step-down for graceful transfer");
match raft_instance.trigger().elect().await {
Ok(_) => (
StatusCode::OK,
Json(serde_json::json!({
"success": true,
"message": "Election triggered - leadership may transfer",
"note": "Check /health endpoint to see new leader"
})),
)
.into_response(),
Err(e) => {
error!("Election trigger failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Election trigger failed: {}", e),
code: 500,
}),
)
.into_response()
}
}
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorResponse {
error: "Raft not initialized".to_string(),
code: 503,
}),
)
.into_response()
}
}
/// Serves the Raft API on `bind_addr` over plain HTTP (TLS disabled).
///
/// # Errors
/// Propagates any error from binding the listener or running the server.
pub async fn start_raft_api_server(
    bind_addr: std::net::SocketAddr,
    state: RaftApiState,
) -> anyhow::Result<()> {
    let plaintext = TlsConfig::default();
    start_raft_api_server_with_tls(bind_addr, state, &plaintext).await
}
/// Settings and shared handles for the optional web dashboard.
///
/// Only compiled when the `dashboard` feature is enabled.
#[cfg(feature = "dashboard")]
pub struct DashboardConfig {
/// When `true`, the dashboard routes are merged into the API router.
pub enabled: bool,
/// Server statistics handle passed on to the dashboard state.
pub stats: std::sync::Arc<crate::cluster_server::ServerStats>,
/// Topic manager passed on to the dashboard state.
pub topic_manager: rivven_core::TopicManager,
/// Offset manager passed on to the dashboard state.
pub offset_manager: rivven_core::OffsetManager,
}
/// Serves the Raft API on `bind_addr`, over TLS when `tls_config.enabled`.
///
/// With TLS enabled, the certificate and key are loaded from the configured
/// PEM files and the server runs on rustls; otherwise a plain TCP listener is
/// used. The call returns only when the server stops.
///
/// # Errors
/// Fails when TLS is enabled but the certificate or key path is missing, when
/// the PEM files cannot be loaded, when the address cannot be bound, or when
/// the server itself errors.
pub async fn start_raft_api_server_with_tls(
    bind_addr: std::net::SocketAddr,
    state: RaftApiState,
    tls_config: &TlsConfig,
) -> anyhow::Result<()> {
    let app = create_raft_router(state);

    if !tls_config.enabled {
        info!("Starting Raft API server on {}", bind_addr);
        let listener = tokio::net::TcpListener::bind(bind_addr).await?;
        axum::serve(listener, app).await?;
        return Ok(());
    }

    let Some(cert_path) = tls_config.cert_path.as_ref() else {
        return Err(anyhow::anyhow!(
            "TLS enabled but no certificate path provided"
        ));
    };
    let Some(key_path) = tls_config.key_path.as_ref() else {
        return Err(anyhow::anyhow!("TLS enabled but no key path provided"));
    };

    info!("Starting Raft API server with TLS on {}", bind_addr);
    let rustls =
        axum_server::tls_rustls::RustlsConfig::from_pem_file(cert_path, key_path).await?;
    axum_server::bind_rustls(bind_addr, rustls)
        .serve(app.into_make_service())
        .await?;
    Ok(())
}
/// Serves the Raft API on `bind_addr`, optionally merged with the dashboard.
///
/// When the dashboard is enabled its routes are merged into the API router
/// and a CORS layer allowing any origin, method and header is applied to the
/// combined router. TLS handling matches [`start_raft_api_server_with_tls`].
///
/// # Errors
/// Fails when TLS is enabled but the certificate or key path is missing, when
/// the PEM files cannot be loaded, when the address cannot be bound, or when
/// the server itself errors.
#[cfg(feature = "dashboard")]
pub async fn start_raft_api_server_with_dashboard(
    bind_addr: std::net::SocketAddr,
    state: RaftApiState,
    tls_config: &TlsConfig,
    dashboard_config: DashboardConfig,
) -> anyhow::Result<()> {
    use tower_http::cors::{Any, CorsLayer};

    let DashboardConfig {
        enabled: dashboard_enabled,
        stats,
        topic_manager,
        offset_manager,
    } = dashboard_config;

    let dashboard_state = crate::dashboard::DashboardState {
        raft_state: state.clone(),
        stats,
        topic_manager,
        offset_manager,
    };

    let app = match dashboard_enabled {
        true => {
            info!("Dashboard enabled at http://{}/", bind_addr);
            let permissive_cors = CorsLayer::new()
                .allow_origin(Any)
                .allow_methods(Any)
                .allow_headers(Any);
            let dashboard_routes = crate::dashboard::create_dashboard_router(dashboard_state);
            create_raft_router(state)
                .merge(dashboard_routes)
                .layer(permissive_cors)
        }
        false => create_raft_router(state),
    };

    if !tls_config.enabled {
        info!("Starting API server on {}", bind_addr);
        let listener = tokio::net::TcpListener::bind(bind_addr).await?;
        axum::serve(listener, app).await?;
        return Ok(());
    }

    let Some(cert_path) = tls_config.cert_path.as_ref() else {
        return Err(anyhow::anyhow!(
            "TLS enabled but no certificate path provided"
        ));
    };
    let Some(key_path) = tls_config.key_path.as_ref() else {
        return Err(anyhow::anyhow!("TLS enabled but no key path provided"));
    };

    info!("Starting API server with TLS on {}", bind_addr);
    let rustls =
        axum_server::tls_rustls::RustlsConfig::from_pem_file(cert_path, key_path).await?;
    axum_server::bind_rustls(bind_addr, rustls)
        .serve(app.into_make_service())
        .await?;
    Ok(())
}
// Unit-test module for the Raft API; it contains no tests as seen here.
#[cfg(test)]
mod tests {
}