use std::net::SocketAddr;
use std::sync::Arc;
use axum::{
extract::State,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use dig_rpc_types::envelope::{JsonRpcRequest, JsonRpcResponse};
use dig_service::{RpcApi, ShutdownToken};
use crate::dispatch::dispatch_envelope;
use crate::error::RpcServerError;
use crate::method::MethodRegistry;
use crate::middleware::RateLimitState;
use crate::role::RoleMap;
use crate::tls::TlsConfig;
/// Listener configuration for [`RpcServer`]: which address to bind and whether to
/// terminate TLS.
///
/// Every variant carries the socket address to bind. `Internal` and `Public` also carry
/// the TLS configuration that `RpcServer::serve` hands to rustls; `PlainText` serves over
/// plain TCP.
#[derive(Clone)]
pub enum RpcServerMode {
    /// TLS listener. `serve` treats it the same as `Public`.
    Internal {
        bind: SocketAddr,
        tls: TlsConfig,
        // NOTE(review): not read by `serve` in this file; presumably consumed by role-based
        // authorization elsewhere. Confirm.
        role_map: Arc<RoleMap>,
    },
    /// TLS listener without a role map.
    Public {
        bind: SocketAddr,
        tls: TlsConfig,
    },
    /// Plain TCP listener with no TLS.
    PlainText {
        bind: SocketAddr,
    },
}
impl RpcServerMode {
pub fn public_plaintext(bind: SocketAddr) -> Self {
Self::PlainText { bind }
}
pub fn bind(&self) -> SocketAddr {
match self {
Self::Internal { bind, .. } => *bind,
Self::Public { bind, .. } => *bind,
Self::PlainText { bind } => *bind,
}
}
}
/// Debug output shows only the variant name and bind address; the TLS configuration
/// and role map are deliberately left out.
impl std::fmt::Debug for RpcServerMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, addr) = match self {
            Self::Internal { bind, .. } => ("Internal", bind),
            Self::Public { bind, .. } => ("Public", bind),
            Self::PlainText { bind } => ("PlainText", bind),
        };
        f.debug_struct(variant).field("bind", addr).finish()
    }
}
/// JSON-RPC server that dispatches HTTP(S) requests to an [`RpcApi`] implementation.
///
/// Create it with `RpcServer::new`, optionally override rate limiting with
/// `RpcServer::with_rate_limit_state`, then run it with `RpcServer::serve`.
pub struct RpcServer<R: RpcApi + ?Sized> {
    // Service implementation; handlers call into it for dispatch and health checks.
    api: Arc<R>,
    // Method table passed to `dispatch_envelope` on every request.
    registry: Arc<MethodRegistry>,
    // Bind address and transport (plaintext or TLS).
    mode: RpcServerMode,
    // Moved into the router state by `serve`.
    // NOTE(review): nothing in this file enforces it yet.
    rate_limit: RateLimitState,
}
impl<R: RpcApi + ?Sized> RpcServer<R> {
pub fn new(api: Arc<R>, registry: MethodRegistry, mode: RpcServerMode) -> Self {
Self {
api,
registry: Arc::new(registry),
mode,
rate_limit: RateLimitState::new(crate::middleware::RateLimitConfig::defaults()),
}
}
pub fn with_rate_limit_state(mut self, state: RateLimitState) -> Self {
self.rate_limit = state;
self
}
pub fn bind_addr(&self) -> SocketAddr {
self.mode.bind()
}
}
impl<R: RpcApi> RpcServer<R> {
    /// Binds the configured address and serves JSON-RPC until `shutdown` is cancelled.
    ///
    /// The router (see `build_router`) exposes `POST /` for JSON-RPC and `GET /healthz`.
    /// In [`RpcServerMode::PlainText`] the server listens on plain TCP. In
    /// [`RpcServerMode::Internal`] and [`RpcServerMode::Public`] it terminates TLS using the
    /// variant's rustls configuration. In every mode, cancelling `shutdown` stops accepting
    /// new connections and waits, with no deadline, for in-flight connections to finish
    /// before returning.
    ///
    /// NOTE(review): `rate_limit` is moved into the router state, but no middleware in this
    /// file enforces it. `Internal`'s `role_map` is also not consulted here. Confirm both
    /// are applied elsewhere.
    ///
    /// # Errors
    ///
    /// - [`RpcServerError::BindFailed`] if the plaintext listener cannot bind.
    /// - [`RpcServerError::Fatal`] if the HTTP(S) server stops with an error. In TLS modes
    ///   this also covers failing to bind.
    pub async fn serve(self, shutdown: ShutdownToken) -> Result<(), RpcServerError> {
        let app_state = AppState {
            api: self.api,
            registry: self.registry,
            rate_limit: self.rate_limit,
        };
        let router = build_router::<R>(app_state);
        let bind = self.mode.bind();
        match self.mode {
            RpcServerMode::PlainText { .. } => {
                let listener = tokio::net::TcpListener::bind(bind).await.map_err(|e| {
                    RpcServerError::BindFailed {
                        addr: bind,
                        source: Arc::new(e),
                    }
                })?;
                axum::serve(listener, router)
                    .with_graceful_shutdown(async move { shutdown.cancelled().await })
                    .await
                    .map_err(|e| {
                        RpcServerError::Fatal(Arc::new(anyhow::anyhow!("axum::serve: {e}")))
                    })
            }
            RpcServerMode::Internal { tls, .. } | RpcServerMode::Public { tls, .. } => {
                let rustls = axum_server::tls_rustls::RustlsConfig::from_config(tls.server_config);
                // axum-server has no `with_graceful_shutdown`; shutdown is signalled through
                // a `Handle` instead, so the TLS listener honours `shutdown` like plaintext.
                let handle = axum_server::Handle::new();
                let server = axum_server::bind_rustls(bind, rustls)
                    .handle(handle.clone())
                    .serve(router.into_make_service());
                let mut server = std::pin::pin!(server);
                let result = tokio::select! {
                    // The server stopped on its own (e.g. an I/O error) before shutdown.
                    res = &mut server => res,
                    _ = shutdown.cancelled() => {
                        // `None` means no deadline: drain in-flight connections, matching the
                        // plaintext branch's `with_graceful_shutdown` behaviour.
                        handle.graceful_shutdown(None);
                        server.await
                    }
                };
                result.map_err(|e| {
                    RpcServerError::Fatal(Arc::new(anyhow::anyhow!("axum-server: {e}")))
                })
            }
        }
    }
}
/// Router state shared with every axum handler through the `State` extractor.
struct AppState<R: RpcApi + ?Sized> {
    // Service implementation that handlers call into.
    api: Arc<R>,
    // Method table used by `dispatch_envelope`.
    registry: Arc<MethodRegistry>,
    // Carried over from `RpcServer`, but no handler or middleware in this file reads it
    // yet, hence the `dead_code` allowance.
    #[allow(dead_code)]
    rate_limit: RateLimitState,
}
impl<R: RpcApi + ?Sized> Clone for AppState<R> {
fn clone(&self) -> Self {
Self {
api: self.api.clone(),
registry: self.registry.clone(),
rate_limit: self.rate_limit.clone(),
}
}
}
/// Builds the HTTP routes and attaches `state`:
///
/// - `POST /` goes to `handle_rpc` (JSON-RPC envelope dispatch).
/// - `GET /healthz` goes to `handle_healthz`.
fn build_router<R: RpcApi>(state: AppState<R>) -> Router {
    let rpc_endpoint = post(handle_rpc::<R>);
    let health_endpoint = get(handle_healthz::<R>);
    Router::new()
        .route("/", rpc_endpoint)
        .route("/healthz", health_endpoint)
        .with_state(state)
}
async fn handle_rpc<R: RpcApi>(
State(state): State<AppState<R>>,
Json(req): Json<JsonRpcRequest<serde_json::Value>>,
) -> Json<JsonRpcResponse<serde_json::Value>> {
let resp = dispatch_envelope(req, &*state.api, &state.registry).await;
Json(resp)
}
/// `GET /healthz` handler: responds `200 OK` when `RpcApi::healthz` succeeds and
/// `503 unavailable` otherwise. The error detail is discarded.
async fn handle_healthz<R: RpcApi>(State(state): State<AppState<R>>) -> impl IntoResponse {
    let healthy = state.api.healthz().await.is_ok();
    if healthy {
        (StatusCode::OK, "OK")
    } else {
        (StatusCode::SERVICE_UNAVAILABLE, "unavailable")
    }
}