1use std::net::SocketAddr;
51use std::sync::Arc;
52
53use axum::{
54 extract::State,
55 http::StatusCode,
56 response::IntoResponse,
57 routing::{get, post},
58 Json, Router,
59};
60use dig_rpc_types::envelope::{JsonRpcRequest, JsonRpcResponse};
61use dig_service::{RpcApi, ShutdownToken};
62
63use crate::dispatch::dispatch_envelope;
64use crate::error::RpcServerError;
65use crate::method::MethodRegistry;
66use crate::middleware::RateLimitState;
67use crate::role::RoleMap;
68use crate::tls::TlsConfig;
69
70#[derive(Clone)]
72pub enum RpcServerMode {
73 Internal {
75 bind: SocketAddr,
77 tls: TlsConfig,
79 role_map: Arc<RoleMap>,
81 },
82 Public {
84 bind: SocketAddr,
86 tls: TlsConfig,
88 },
89 PlainText {
92 bind: SocketAddr,
94 },
95}
96
97impl RpcServerMode {
98 pub fn public_plaintext(bind: SocketAddr) -> Self {
100 Self::PlainText { bind }
101 }
102
103 pub fn bind(&self) -> SocketAddr {
105 match self {
106 Self::Internal { bind, .. } => *bind,
107 Self::Public { bind, .. } => *bind,
108 Self::PlainText { bind } => *bind,
109 }
110 }
111}
112
113impl std::fmt::Debug for RpcServerMode {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 match self {
116 Self::Internal { bind, .. } => f.debug_struct("Internal").field("bind", bind).finish(),
117 Self::Public { bind, .. } => f.debug_struct("Public").field("bind", bind).finish(),
118 Self::PlainText { bind } => f.debug_struct("PlainText").field("bind", bind).finish(),
119 }
120 }
121}
122
123pub struct RpcServer<R: RpcApi + ?Sized> {
125 api: Arc<R>,
126 registry: Arc<MethodRegistry>,
127 mode: RpcServerMode,
128 rate_limit: RateLimitState,
129}
130
131impl<R: RpcApi + ?Sized> RpcServer<R> {
132 pub fn new(api: Arc<R>, registry: MethodRegistry, mode: RpcServerMode) -> Self {
135 Self {
136 api,
137 registry: Arc::new(registry),
138 mode,
139 rate_limit: RateLimitState::new(crate::middleware::RateLimitConfig::defaults()),
140 }
141 }
142
143 pub fn with_rate_limit_state(mut self, state: RateLimitState) -> Self {
145 self.rate_limit = state;
146 self
147 }
148
149 pub fn bind_addr(&self) -> SocketAddr {
151 self.mode.bind()
152 }
153}
154
155impl<R: RpcApi> RpcServer<R> {
156 pub async fn serve(self, shutdown: ShutdownToken) -> Result<(), RpcServerError> {
158 let app_state = AppState {
159 api: self.api,
160 registry: self.registry,
161 rate_limit: self.rate_limit,
162 };
163 let router = build_router::<R>(app_state);
164
165 let bind = self.mode.bind();
166 match self.mode {
167 RpcServerMode::PlainText { .. } => {
168 let listener = tokio::net::TcpListener::bind(bind).await.map_err(|e| {
169 RpcServerError::BindFailed {
170 addr: bind,
171 source: Arc::new(e),
172 }
173 })?;
174 axum::serve(listener, router)
175 .with_graceful_shutdown(async move { shutdown.cancelled().await })
176 .await
177 .map_err(|e| {
178 RpcServerError::Fatal(Arc::new(anyhow::anyhow!("axum::serve: {e}")))
179 })
180 }
181 RpcServerMode::Internal { tls, .. } | RpcServerMode::Public { tls, .. } => {
182 let rustls = axum_server::tls_rustls::RustlsConfig::from_config(tls.server_config);
183 axum_server::bind_rustls(bind, rustls)
184 .serve(router.into_make_service())
185 .await
186 .map_err(|e| {
187 RpcServerError::Fatal(Arc::new(anyhow::anyhow!("axum-server: {e}")))
188 })
189 }
190 }
191 }
192}
193
194struct AppState<R: RpcApi + ?Sized> {
200 api: Arc<R>,
201 registry: Arc<MethodRegistry>,
202 #[allow(dead_code)] rate_limit: RateLimitState,
204}
205
206impl<R: RpcApi + ?Sized> Clone for AppState<R> {
207 fn clone(&self) -> Self {
208 Self {
209 api: self.api.clone(),
210 registry: self.registry.clone(),
211 rate_limit: self.rate_limit.clone(),
212 }
213 }
214}
215
216fn build_router<R: RpcApi>(state: AppState<R>) -> Router {
217 Router::new()
218 .route("/", post(handle_rpc::<R>))
219 .route("/healthz", get(handle_healthz::<R>))
220 .with_state(state)
221}
222
223async fn handle_rpc<R: RpcApi>(
224 State(state): State<AppState<R>>,
225 Json(req): Json<JsonRpcRequest<serde_json::Value>>,
226) -> Json<JsonRpcResponse<serde_json::Value>> {
227 let resp = dispatch_envelope(req, &*state.api, &state.registry).await;
228 Json(resp)
229}
230
231async fn handle_healthz<R: RpcApi>(State(state): State<AppState<R>>) -> impl IntoResponse {
232 match state.api.healthz().await {
233 Ok(()) => (StatusCode::OK, "OK"),
234 Err(_) => (StatusCode::SERVICE_UNAVAILABLE, "unavailable"),
235 }
236}