Skip to main content

dig_rpc/
server.rs

1//! The [`RpcServer`] — Axum-based JSON-RPC server with lifecycle
2//! integration for [`dig_service::ShutdownToken`].
3//!
4//! # Responsibilities
5//!
6//! - Build an Axum `Router` with:
7//!   - `POST /` — JSON-RPC dispatch.
8//!   - `GET /healthz` — liveness.
9//!   - `GET /metrics` — Prometheus (behind the `metrics` feature, wired by the binary).
//! - Attach the middleware stack: request-id, panic-catch, audit,
//!   rate-limit, allow-list. (Not yet done in v0.1: `build_router` adds no
//!   layers.)
12//! - Drive `axum-server` with TLS configured per [`RpcServerMode`].
13//! - Exit `serve` when the supplied [`ShutdownToken`] fires.
14//!
15//! # v0.1 scope
16//!
//! The v0.1 server implements the full JSON-RPC dispatch pipeline. It carries
//! per-request rate-limit state, but the router does not enforce it yet.
//! mTLS client-cert verification is wired via rustls's `WebPkiClientVerifier`
//! at TLS-handshake time. Binaries that want to resolve the authenticated cert
//! to a [`Role`](crate::role::Role) in middleware must currently supply their
//! own extractor (a pluggable extractor is a planned v0.2 enhancement).
22//!
23//! # Minimal example
24//!
25//! ```no_run
26//! use std::sync::Arc;
27//! use dig_rpc::{RpcServer, RpcServerMode, MethodRegistry, MethodMeta, RateBucket};
28//! use dig_rpc::role::Role;
29//! # struct MyApi;
30//! # #[async_trait::async_trait]
31//! # impl dig_service::RpcApi for MyApi {
32//! #     async fn dispatch(&self, _m: &str, _p: serde_json::Value)
33//! #         -> Result<serde_json::Value, dig_rpc_types::envelope::JsonRpcError>
34//! #     { Ok(serde_json::Value::Null) }
35//! # }
36//! # async fn example() -> dig_rpc::RpcServerError {
37//! let api: Arc<MyApi> = Arc::new(MyApi);
38//! let registry = MethodRegistry::new();
39//! registry.register(MethodMeta::read("healthz", Role::Explorer, RateBucket::ReadLight));
40//!
41//! let server = RpcServer::new(api, registry, RpcServerMode::public_plaintext("127.0.0.1:9447".parse().unwrap()));
42//! let shutdown = dig_service::ShutdownToken::new();
43//! match server.serve(shutdown).await {
44//!     Ok(()) => unreachable!(),
45//!     Err(e) => return e,
46//! }
47//! # }
48//! ```
49
50use std::net::SocketAddr;
51use std::sync::Arc;
52
53use axum::{
54    extract::State,
55    http::StatusCode,
56    response::IntoResponse,
57    routing::{get, post},
58    Json, Router,
59};
60use dig_rpc_types::envelope::{JsonRpcRequest, JsonRpcResponse};
61use dig_service::{RpcApi, ShutdownToken};
62
63use crate::dispatch::dispatch_envelope;
64use crate::error::RpcServerError;
65use crate::method::MethodRegistry;
66use crate::middleware::RateLimitState;
67use crate::role::RoleMap;
68use crate::tls::TlsConfig;
69
/// Server deployment mode: which transport [`RpcServer::serve`] uses and the
/// address it binds.
///
/// `PlainText` is served with `axum::serve` over plain TCP. `Internal` and
/// `Public` are both served through `axum-server`, using the rustls config
/// taken from their `tls` field. The router, and therefore the exposed method
/// surface, is built from the `MethodRegistry` alone; the mode does not
/// filter methods itself.
#[derive(Clone)]
pub enum RpcServerMode {
    /// Internal mTLS server: private CA, full method surface.
    Internal {
        /// Bind address.
        bind: SocketAddr,
        /// TLS configuration (mTLS).
        tls: TlsConfig,
        /// Role map for resolving client certs.
        ///
        /// NOTE(review): `RpcServer::serve` does not read this field in v0.1.
        role_map: Arc<RoleMap>,
    },
    /// Public HTTPS server: public CA, read-only subset.
    Public {
        /// Bind address.
        bind: SocketAddr,
        /// TLS configuration.
        tls: TlsConfig,
    },
    /// Plain-text HTTP (no TLS). Intended for localhost-only dev / testing.
    /// DO NOT use in production.
    ///
    /// The loopback restriction is a convention only: nothing checks that
    /// `bind` is a loopback address.
    PlainText {
        /// Bind address.
        bind: SocketAddr,
    },
}
96
97impl RpcServerMode {
98    /// Convenience constructor for a plain-text dev mode (loopback only).
99    pub fn public_plaintext(bind: SocketAddr) -> Self {
100        Self::PlainText { bind }
101    }
102
103    /// The bind address regardless of mode.
104    pub fn bind(&self) -> SocketAddr {
105        match self {
106            Self::Internal { bind, .. } => *bind,
107            Self::Public { bind, .. } => *bind,
108            Self::PlainText { bind } => *bind,
109        }
110    }
111}
112
113impl std::fmt::Debug for RpcServerMode {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        match self {
116            Self::Internal { bind, .. } => f.debug_struct("Internal").field("bind", bind).finish(),
117            Self::Public { bind, .. } => f.debug_struct("Public").field("bind", bind).finish(),
118            Self::PlainText { bind } => f.debug_struct("PlainText").field("bind", bind).finish(),
119        }
120    }
121}
122
/// The JSON-RPC server itself.
///
/// Created with [`RpcServer::new`] and consumed by `serve`. The type allows an
/// unsized `R` for construction, but `serve` is only available when `R` is
/// sized (its `impl` block is `impl<R: RpcApi>`).
pub struct RpcServer<R: RpcApi + ?Sized> {
    /// Application handler that `dispatch_envelope` routes requests to.
    api: Arc<R>,
    /// Method metadata registry, shared with the router state.
    registry: Arc<MethodRegistry>,
    /// Transport / TLS mode, including the bind address.
    mode: RpcServerMode,
    /// Rate-limit state. It is moved into the router state by `serve` but is
    /// not yet enforced there in v0.1.
    rate_limit: RateLimitState,
}
130
131impl<R: RpcApi + ?Sized> RpcServer<R> {
132    /// Construct a server. Use default rate-limit config;
133    /// customise via [`with_rate_limit_state`](Self::with_rate_limit_state).
134    pub fn new(api: Arc<R>, registry: MethodRegistry, mode: RpcServerMode) -> Self {
135        Self {
136            api,
137            registry: Arc::new(registry),
138            mode,
139            rate_limit: RateLimitState::new(crate::middleware::RateLimitConfig::defaults()),
140        }
141    }
142
143    /// Replace the rate-limit state.
144    pub fn with_rate_limit_state(mut self, state: RateLimitState) -> Self {
145        self.rate_limit = state;
146        self
147    }
148
149    /// The bind address for this server.
150    pub fn bind_addr(&self) -> SocketAddr {
151        self.mode.bind()
152    }
153}
154
155impl<R: RpcApi> RpcServer<R> {
156    /// Start serving; return when `shutdown` fires or the listener dies.
157    pub async fn serve(self, shutdown: ShutdownToken) -> Result<(), RpcServerError> {
158        let app_state = AppState {
159            api: self.api,
160            registry: self.registry,
161            rate_limit: self.rate_limit,
162        };
163        let router = build_router::<R>(app_state);
164
165        let bind = self.mode.bind();
166        match self.mode {
167            RpcServerMode::PlainText { .. } => {
168                let listener = tokio::net::TcpListener::bind(bind).await.map_err(|e| {
169                    RpcServerError::BindFailed {
170                        addr: bind,
171                        source: Arc::new(e),
172                    }
173                })?;
174                axum::serve(listener, router)
175                    .with_graceful_shutdown(async move { shutdown.cancelled().await })
176                    .await
177                    .map_err(|e| {
178                        RpcServerError::Fatal(Arc::new(anyhow::anyhow!("axum::serve: {e}")))
179                    })
180            }
181            RpcServerMode::Internal { tls, .. } | RpcServerMode::Public { tls, .. } => {
182                let rustls = axum_server::tls_rustls::RustlsConfig::from_config(tls.server_config);
183                axum_server::bind_rustls(bind, rustls)
184                    .serve(router.into_make_service())
185                    .await
186                    .map_err(|e| {
187                        RpcServerError::Fatal(Arc::new(anyhow::anyhow!("axum-server: {e}")))
188                    })
189            }
190        }
191    }
192}
193
/// Shared state held inside the Axum router.
///
/// `Clone` is implemented by hand because `R: ?Sized`. `#[derive(Clone)]` would
/// add an `R: Clone` bound, which an unsized `R` cannot satisfy. Two fields are
/// `Arc`s, so cloning them only bumps a refcount.
/// NOTE(review): this assumes `RateLimitState::clone` is also cheap (e.g.
/// Arc-backed); confirm in the `middleware` module.
struct AppState<R: RpcApi + ?Sized> {
    /// Application handler invoked by the request handlers.
    api: Arc<R>,
    /// Method metadata registry passed to `dispatch_envelope`.
    registry: Arc<MethodRegistry>,
    #[allow(dead_code)] // v0.1 does not yet wire the rate limiter from the tower stack
    rate_limit: RateLimitState,
}
205
206impl<R: RpcApi + ?Sized> Clone for AppState<R> {
207    fn clone(&self) -> Self {
208        Self {
209            api: self.api.clone(),
210            registry: self.registry.clone(),
211            rate_limit: self.rate_limit.clone(),
212        }
213    }
214}
215
216fn build_router<R: RpcApi>(state: AppState<R>) -> Router {
217    Router::new()
218        .route("/", post(handle_rpc::<R>))
219        .route("/healthz", get(handle_healthz::<R>))
220        .with_state(state)
221}
222
223async fn handle_rpc<R: RpcApi>(
224    State(state): State<AppState<R>>,
225    Json(req): Json<JsonRpcRequest<serde_json::Value>>,
226) -> Json<JsonRpcResponse<serde_json::Value>> {
227    let resp = dispatch_envelope(req, &*state.api, &state.registry).await;
228    Json(resp)
229}
230
231async fn handle_healthz<R: RpcApi>(State(state): State<AppState<R>>) -> impl IntoResponse {
232    match state.api.healthz().await {
233        Ok(()) => (StatusCode::OK, "OK"),
234        Err(_) => (StatusCode::SERVICE_UNAVAILABLE, "unavailable"),
235    }
236}