#![cfg(feature = "server")]
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use axum::Router;
use axum::body::Bytes;
use axum::extract::{Path, State};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response as AxumResponse};
use axum::routing::post;
use crate::codec::{
HEADER_AUTHORIZATION, HEADER_WIRE_VERSION, PATH_PREFIX, Response, WIRE_VERSION,
};
/// Type-erased async handler for a single RPC method.
///
/// The handler receives the raw request body and resolves to a [`Response`]
/// envelope. It is reference-counted (`Arc`) so the router can cheaply clone it
/// out of its method table for each request. The `Send + Sync` bounds allow it
/// to live in shared router state and be invoked from any thread; the returned
/// future must also be `Send`.
pub type MethodHandler =
    Arc<dyn Fn(Bytes) -> Pin<Box<dyn Future<Output = Response> + Send>> + Send + Sync + 'static>;
/// Builder that collects the handlers of one RPC interface and converts them
/// into an axum [`Router`].
pub struct RpcRouter {
    interface: String,
    methods: HashMap<String, MethodHandler>,
}

impl RpcRouter {
    /// Creates an empty router that hosts the interface named `interface`.
    pub fn new(interface: impl Into<String>) -> Self {
        let interface: String = interface.into();
        Self {
            methods: HashMap::new(),
            interface,
        }
    }

    /// Registers `handler` under the name `method` and returns the builder so
    /// calls can be chained.
    ///
    /// Registering the same name twice keeps only the most recent handler.
    pub fn add_method(mut self, method: impl Into<String>, handler: MethodHandler) -> Self {
        let name: String = method.into();
        self.methods.insert(name, handler);
        self
    }

    /// Consumes the builder and produces an axum [`Router`] with a single
    /// `POST {PATH_PREFIX}{interface}/{method}` route.
    ///
    /// `secret`, when `Some`, is stored in the shared state so that every
    /// request must present it as a Bearer token.
    pub fn into_axum_router(self, secret: Option<String>) -> Router {
        let Self { interface, methods } = self;
        let shared = Arc::new(RouterState {
            interface,
            methods,
            secret,
        });
        // Doubled braces emit literal `{interface}` / `{method}` path captures.
        let route_path = format!("{PATH_PREFIX}{{interface}}/{{method}}");
        let handler_route = post(rpc_handler);
        Router::new().route(&route_path, handler_route).with_state(shared)
    }
}
/// State shared (via `Arc`) by every invocation of `rpc_handler`.
struct RouterState {
    /// Optional shared secret; when set, requests must send `Bearer <secret>`.
    secret: Option<String>,
    /// Name of the single interface this router serves.
    interface: String,
    /// Registered handlers, keyed by method name.
    methods: HashMap<String, MethodHandler>,
}
/// Axum handler for `POST {PATH_PREFIX}{interface}/{method}`.
///
/// The checks below run in order. Each one short-circuits with an error envelope:
/// 1. The `HEADER_WIRE_VERSION` header must equal `WIRE_VERSION`
///    (400 `BadVersion`).
/// 2. If the router was built with a secret, `HEADER_AUTHORIZATION` must be
///    `Bearer <secret>` (401 `Unauthorized`).
/// 3. The `{interface}` path segment must equal the hosted interface
///    (404 `InterfaceMismatch`).
/// 4. The `{method}` path segment must name a registered handler
///    (404 `UnknownMethod`).
///
/// The raw request body is then passed to the handler. A `Response::Ok`
/// result is returned with 200 and a `Response::Err` result with 500.
async fn rpc_handler(
    State(state): State<Arc<RouterState>>,
    Path((interface, method)): Path<(String, String)>,
    headers: HeaderMap,
    body: Bytes,
) -> AxumResponse {
    // A missing or non-visible-ASCII header is treated as an empty version.
    let version = headers
        .get(HEADER_WIRE_VERSION)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("");
    if version != WIRE_VERSION {
        return reply(
            StatusCode::BAD_REQUEST,
            Response::err(
                "BadVersion",
                format!("expected {HEADER_WIRE_VERSION}: {WIRE_VERSION}; got {version:?}"),
            ),
        );
    }
    if let Some(expected) = state.secret.as_deref() {
        if !bearer_token_matches(&headers, expected) {
            return reply(
                StatusCode::UNAUTHORIZED,
                Response::err("Unauthorized", "missing or invalid Bearer token"),
            );
        }
    }
    if interface != state.interface {
        return reply(
            StatusCode::NOT_FOUND,
            Response::err(
                "InterfaceMismatch",
                format!("server hosts {:?}, got {interface:?}", state.interface),
            ),
        );
    }
    // Clone the Arc'd handler so the future does not borrow from `state`.
    let Some(handler) = state.methods.get(&method).cloned() else {
        return reply(
            StatusCode::NOT_FOUND,
            Response::err(
                "UnknownMethod",
                format!("{} has no method {method:?}", state.interface),
            ),
        );
    };
    let result = handler(body).await;
    let status = match &result {
        Response::Ok(_) => StatusCode::OK,
        Response::Err { .. } => StatusCode::INTERNAL_SERVER_ERROR,
    };
    reply(status, result)
}

/// Returns `true` when the `HEADER_AUTHORIZATION` header is exactly
/// `Bearer <expected>`.
///
/// A missing header, a non-visible-ASCII header or a missing `Bearer ` prefix
/// all yield `false`. The token itself is compared with `constant_time_eq` so
/// that response latency does not reveal how many leading bytes of a guessed
/// token are correct.
fn bearer_token_matches(headers: &HeaderMap, expected: &str) -> bool {
    headers
        .get(HEADER_AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
        .and_then(|auth| auth.strip_prefix("Bearer "))
        .is_some_and(|token| constant_time_eq(token.as_bytes(), expected.as_bytes()))
}

/// Compares two byte strings without stopping at the first differing byte.
///
/// The length check still returns early, so timing can reveal whether the
/// lengths match but not the content. This is best-effort: `black_box`
/// discourages the optimizer from reintroducing an early exit in the fold.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (x, y)| std::hint::black_box(acc | (x ^ y)));
    diff == 0
}
/// Builds the HTTP response for an RPC envelope.
///
/// The body is `envelope.to_json_bytes()`, sent with `status`. The response
/// also carries the `HEADER_WIRE_VERSION: WIRE_VERSION` header and an
/// `application/json` content type.
fn reply(status: StatusCode, envelope: Response) -> AxumResponse {
    let payload = envelope.to_json_bytes();
    let mut response = (status, payload).into_response();
    let response_headers = response.headers_mut();
    response_headers.insert(
        HEADER_WIRE_VERSION,
        axum::http::HeaderValue::from_static(WIRE_VERSION),
    );
    // `insert` replaces any content type chosen by the body's IntoResponse impl.
    response_headers.insert(
        axum::http::header::CONTENT_TYPE,
        axum::http::HeaderValue::from_static("application/json"),
    );
    response
}
/// Binds a TCP listener on `addr` and serves `router` until the server stops.
///
/// # Errors
///
/// Returns the I/O error from binding the listener or from running the server.
pub async fn serve_forever(addr: std::net::SocketAddr, router: Router) -> std::io::Result<()> {
    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(listener, router).await
}