use std::future::Future;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::response;
use crate::response::{internal_error, malformed};
use futures_channel::mpsc;
use futures_util::future::FutureExt;
use futures_util::stream::{StreamExt, TryStreamExt};
use hyper::header::{HeaderMap, HeaderValue};
use hyper::server::conn::AddrStream;
use hyper::server::{conn::AddrIncoming, Builder as HyperBuilder};
use hyper::service::{make_service_fn, service_fn};
use hyper::{Error as HyperError, Method};
use jsonrpsee_core::error::{Error, GenericTransportError};
use jsonrpsee_core::http_helpers::{self, read_body};
use jsonrpsee_core::middleware::{self, HttpMiddleware as Middleware};
use jsonrpsee_core::server::access_control::AccessControl;
use jsonrpsee_core::server::helpers::{prepare_error, MethodResponse};
use jsonrpsee_core::server::helpers::{BatchResponse, BatchResponseBuilder};
use jsonrpsee_core::server::resource_limiting::Resources;
use jsonrpsee_core::server::rpc_module::{MethodKind, Methods};
use jsonrpsee_core::tracing::{rx_log_from_json, rx_log_from_str, tx_log_from_str, RpcTracing};
use jsonrpsee_core::TEN_MB_SIZE_BYTES;
use jsonrpsee_types::error::{ErrorCode, ErrorObject, BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG};
use jsonrpsee_types::{Id, Notification, Params, Request};
use serde_json::value::RawValue;
use tokio::net::{TcpListener, ToSocketAddrs};
use tracing_futures::Instrument;
/// A JSON-RPC notification whose `params` are kept as raw, unparsed JSON (or absent).
type Notif<'a> = Notification<'a, Option<&'a RawValue>>;
/// Builder used to configure and construct an HTTP JSON-RPC [`Server`].
///
/// The type parameter `M` is the HTTP middleware. It defaults to `()` and is
/// replaced with `set_middleware`.
#[derive(Debug)]
pub struct Builder<M = ()> {
/// Host/origin/header access-control checks applied to every incoming request.
access_control: AccessControl,
/// Resource limits, extended via `register_resource`.
resources: Resources,
/// Maximum accepted request body size in bytes.
max_request_body_size: u32,
/// Size limit in bytes handed to method callbacks and to the batch response builder.
max_response_body_size: u32,
/// Whether JSON-RPC batch requests are accepted.
batch_requests_supported: bool,
/// Runtime the server task is spawned on; `None` means the current tokio runtime at `start`.
tokio_runtime: Option<tokio::runtime::Handle>,
/// HTTP middleware notified via `on_request`, `on_call`, `on_result` and `on_response`.
middleware: M,
/// Maximum length passed to the request/response trace-logging helpers.
max_log_length: u32,
/// Optional `GET` health endpoint configuration.
health_api: Option<HealthApi>,
}
impl Default for Builder {
    /// Default configuration:
    /// - 10 MB (`TEN_MB_SIZE_BYTES`) limits for request and response bodies;
    /// - batch requests enabled;
    /// - default access control and resources;
    /// - log truncation at 4096;
    /// - no custom runtime, no middleware (`()`) and no health API.
    fn default() -> Self {
        let body_limit = TEN_MB_SIZE_BYTES;
        Self {
            middleware: (),
            access_control: AccessControl::default(),
            resources: Resources::default(),
            max_request_body_size: body_limit,
            max_response_body_size: body_limit,
            max_log_length: 4096,
            batch_requests_supported: true,
            tokio_runtime: None,
            health_api: None,
        }
    }
}
impl Builder {
    /// Creates a builder holding the default configuration (identical to `Builder::default()`).
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
impl<M> Builder<M> {
pub fn set_middleware<T: Middleware>(self, middleware: T) -> Builder<T> {
Builder {
access_control: self.access_control,
max_request_body_size: self.max_request_body_size,
max_response_body_size: self.max_response_body_size,
batch_requests_supported: self.batch_requests_supported,
resources: self.resources,
tokio_runtime: self.tokio_runtime,
middleware,
max_log_length: self.max_log_length,
health_api: self.health_api,
}
}
pub fn max_request_body_size(mut self, size: u32) -> Self {
self.max_request_body_size = size;
self
}
pub fn max_response_body_size(mut self, size: u32) -> Self {
self.max_response_body_size = size;
self
}
pub fn set_access_control(mut self, acl: AccessControl) -> Self {
self.access_control = acl;
self
}
pub fn batch_requests_supported(mut self, supported: bool) -> Self {
self.batch_requests_supported = supported;
self
}
pub fn register_resource(mut self, label: &'static str, capacity: u16, default: u16) -> Result<Self, Error> {
self.resources.register(label, capacity, default)?;
Ok(self)
}
pub fn custom_tokio_runtime(mut self, rt: tokio::runtime::Handle) -> Self {
self.tokio_runtime = Some(rt);
self
}
pub fn health_api(mut self, path: impl Into<String>, method: impl Into<String>) -> Result<Self, Error> {
let path = path.into();
if !path.starts_with('/') {
return Err(Error::Custom(format!("Health endpoint path must start with `/` to work, got: {}", path)));
}
self.health_api = Some(HealthApi { path, method: method.into() });
Ok(self)
}
pub fn build_from_hyper(
self,
listener: hyper::server::Builder<AddrIncoming>,
local_addr: SocketAddr,
) -> Result<Server<M>, Error> {
Ok(Server {
access_control: self.access_control,
listener,
local_addr: Some(local_addr),
max_request_body_size: self.max_request_body_size,
max_response_body_size: self.max_response_body_size,
batch_requests_supported: self.batch_requests_supported,
resources: self.resources,
tokio_runtime: self.tokio_runtime,
middleware: self.middleware,
max_log_length: self.max_log_length,
health_api: self.health_api,
})
}
pub fn build_from_tcp(self, listener: impl Into<StdTcpListener>) -> Result<Server<M>, Error> {
let listener = listener.into();
let local_addr = listener.local_addr().ok();
let listener = hyper::Server::from_tcp(listener)?;
Ok(Server {
listener,
local_addr,
access_control: self.access_control,
max_request_body_size: self.max_request_body_size,
max_response_body_size: self.max_response_body_size,
batch_requests_supported: self.batch_requests_supported,
resources: self.resources,
tokio_runtime: self.tokio_runtime,
middleware: self.middleware,
max_log_length: self.max_log_length,
health_api: self.health_api,
})
}
pub async fn build(self, addrs: impl ToSocketAddrs) -> Result<Server<M>, Error> {
let listener = TcpListener::bind(addrs).await?.into_std()?;
let local_addr = listener.local_addr().ok();
let listener = hyper::Server::from_tcp(listener)?.tcp_nodelay(true);
Ok(Server {
listener,
local_addr,
access_control: self.access_control,
max_request_body_size: self.max_request_body_size,
max_response_body_size: self.max_response_body_size,
batch_requests_supported: self.batch_requests_supported,
resources: self.resources,
tokio_runtime: self.tokio_runtime,
middleware: self.middleware,
max_log_length: self.max_log_length,
health_api: self.health_api,
})
}
}
/// Configuration of the optional HTTP health endpoint.
#[derive(Debug, Clone)]
struct HealthApi {
/// URI path matched on `GET` requests; `Builder::health_api` rejects paths not starting with `/`.
path: String,
/// Name of the registered RPC method called (without params) to answer health requests.
method: String,
}
/// Handle to a running server.
///
/// Awaiting it waits for the spawned server task to finish; `stop` requests a
/// graceful shutdown.
#[derive(Debug)]
pub struct ServerHandle {
/// Sending `()` on this channel triggers the server's graceful shutdown.
stop_sender: mpsc::Sender<()>,
/// Join handle of the spawned server task; `None` once it has been taken (e.g. by `stop`).
pub(crate) handle: Option<tokio::task::JoinHandle<()>>,
}
impl ServerHandle {
pub fn stop(mut self) -> Result<tokio::task::JoinHandle<()>, Error> {
let stop = self.stop_sender.try_send(()).map(|_| self.handle.take());
match stop {
Ok(Some(handle)) => Ok(handle),
_ => Err(Error::AlreadyStopped),
}
}
}
impl Future for ServerHandle {
    type Output = ();

    /// Resolves when the spawned server task finishes (normally after `stop`).
    ///
    /// The task's join result, including any `JoinError`, is discarded. Once the
    /// task has completed, the handle is cleared, so further polls return
    /// `Ready(())` immediately. Re-polling the finished `JoinHandle` is avoided
    /// because tokio's `JoinHandle` panics when polled after completion.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let handle = match &mut self.handle {
            Some(handle) => handle,
            None => return Poll::Ready(()),
        };
        if handle.poll_unpin(cx).is_pending() {
            return Poll::Pending;
        }
        // The task is done: forget the handle so a later poll never touches it again.
        self.handle = None;
        Poll::Ready(())
    }
}
/// An HTTP JSON-RPC server produced by [`Builder`]; call `start` to begin serving.
#[derive(Debug)]
pub struct Server<M = ()> {
/// Hyper server builder wrapping the bound listener.
listener: HyperBuilder<AddrIncoming>,
/// Local address of the listener, if it could be determined while building.
local_addr: Option<SocketAddr>,
/// Maximum accepted request body size in bytes.
max_request_body_size: u32,
/// Size limit in bytes handed to method callbacks and the batch response builder.
max_response_body_size: u32,
/// Maximum length passed to the request/response trace-logging helpers.
max_log_length: u32,
/// Whether JSON-RPC batch requests are accepted.
batch_requests_supported: bool,
/// Host/origin/header checks applied to every request.
access_control: AccessControl,
/// Resource limits used to initialize the registered methods.
resources: Resources,
/// Runtime the server task is spawned on; `None` means the current tokio runtime.
tokio_runtime: Option<tokio::runtime::Handle>,
/// HTTP middleware notified about requests, calls, results and responses.
middleware: M,
/// Optional `GET` health endpoint configuration.
health_api: Option<HealthApi>,
}
impl<M: Middleware> Server<M> {
pub fn local_addr(&self) -> Result<SocketAddr, Error> {
self.local_addr.ok_or_else(|| Error::Custom("Local address not found".into()))
}
pub fn start(mut self, methods: impl Into<Methods>) -> Result<ServerHandle, Error> {
let max_request_body_size = self.max_request_body_size;
let max_response_body_size = self.max_response_body_size;
let max_log_length = self.max_log_length;
let acl = self.access_control;
let (tx, mut rx) = mpsc::channel(1);
let listener = self.listener;
let resources = self.resources;
let middleware = self.middleware;
let batch_requests_supported = self.batch_requests_supported;
let methods = methods.into().initialize_resources(&resources)?;
let health_api = self.health_api;
let make_service = make_service_fn(move |conn: &AddrStream| {
let remote_addr = conn.remote_addr();
let methods = methods.clone();
let acl = acl.clone();
let resources = resources.clone();
let middleware = middleware.clone();
let health_api = health_api.clone();
async move {
Ok::<_, HyperError>(service_fn(move |request| {
let request_start = middleware.on_request(remote_addr, request.headers());
let methods = methods.clone();
let acl = acl.clone();
let resources = resources.clone();
let middleware = middleware.clone();
let health_api = health_api.clone();
async move {
let keys = request.headers().keys().map(|k| k.as_str());
let cors_request_headers = http_helpers::get_cors_request_headers(request.headers());
let host = match http_helpers::read_header_value(request.headers(), "host") {
Some(origin) => origin,
None => return Ok(malformed()),
};
let maybe_origin = http_helpers::read_header_value(request.headers(), "origin");
if let Err(e) = acl.verify_host(host) {
tracing::warn!("Denied request: {:?}", e);
return Ok(response::host_not_allowed());
}
if let Err(e) = acl.verify_origin(maybe_origin, host) {
tracing::warn!("Denied request: {:?}", e);
return Ok(response::invalid_allow_origin());
}
if let Err(e) = acl.verify_headers(keys, cors_request_headers) {
tracing::warn!("Denied request: {:?}", e);
return Ok(response::invalid_allow_headers());
}
match *request.method() {
Method::OPTIONS => {
let origin = match maybe_origin {
Some(origin) => origin,
None => return Ok(malformed()),
};
let allowed_headers = acl.allowed_headers().to_cors_header_value();
let allowed_header_bytes = allowed_headers.as_bytes();
let res = hyper::Response::builder()
.header("access-control-allow-origin", origin)
.header("access-control-allow-methods", "POST")
.header("access-control-allow-headers", allowed_header_bytes)
.body(hyper::Body::empty())
.unwrap_or_else(|e| {
tracing::error!("Error forming preflight response: {}", e);
internal_error()
});
Ok(res)
}
Method::POST if content_type_is_json(&request) => {
let origin = return_origin_if_different_from_host(request.headers()).cloned();
let mut res = process_validated_request(ProcessValidatedRequest {
request,
middleware,
methods,
resources,
max_request_body_size,
max_response_body_size,
max_log_length,
batch_requests_supported,
request_start,
})
.await?;
if let Some(origin) = origin {
res.headers_mut().insert("access-control-allow-origin", origin);
}
Ok(res)
}
Method::GET => match health_api.as_ref() {
Some(health) if health.path.as_str() == request.uri().path() => {
process_health_request(
health,
middleware,
methods,
max_response_body_size,
request_start,
max_log_length,
)
.await
}
_ => Ok(response::method_not_allowed()),
},
Method::POST => Ok(response::unsupported_content_type()),
_ => Ok(response::method_not_allowed()),
}
}
}))
}
});
let rt = match self.tokio_runtime.take() {
Some(rt) => rt,
None => tokio::runtime::Handle::current(),
};
let handle = rt.spawn(async move {
let server = listener.serve(make_service);
let _ = server.with_graceful_shutdown(async move { rx.next().await.map_or((), |_| ()) }).await;
});
Ok(ServerHandle { handle: Some(handle), stop_sender: tx })
}
}
/// Returns the raw `origin` header when both `origin` and `host` headers are present
/// and their values differ; returns `None` otherwise.
///
/// NOTE(review): `origin` usually includes a scheme (e.g. `http://host:port`) while
/// `host` does not, so in practice this returns the origin whenever both are set.
/// Confirm this is intended.
fn return_origin_if_different_from_host(headers: &HeaderMap) -> Option<&HeaderValue> {
    let origin = headers.get("origin")?;
    let host = headers.get("host")?;
    if origin == host {
        None
    } else {
        Some(origin)
    }
}
/// Returns `true` when the request's `Content-Type` header is accepted as JSON by `is_json`.
fn content_type_is_json(request: &hyper::Request<hyper::Body>) -> bool {
    let content_type = request.headers().get(hyper::header::CONTENT_TYPE);
    is_json(content_type)
}
/// Returns `true` when `content_type` is one of the accepted JSON content types.
///
/// The accepted values are compared ASCII case-insensitively:
/// - `application/json`
/// - `application/json; charset=utf-8`
/// - `application/json;charset=utf-8`
///
/// A missing header, or one that cannot be viewed as `&str`, yields `false`.
fn is_json(content_type: Option<&hyper::header::HeaderValue>) -> bool {
    const ACCEPTED_JSON_TYPES: [&str; 3] =
        ["application/json", "application/json; charset=utf-8", "application/json;charset=utf-8"];

    let content = match content_type.map(|value| value.to_str()) {
        Some(Ok(content)) => content,
        _ => return false,
    };
    ACCEPTED_JSON_TYPES.iter().any(|accepted| content.eq_ignore_ascii_case(accepted))
}
/// Arguments for `process_validated_request`.
struct ProcessValidatedRequest<M: Middleware> {
/// The HTTP request; its body is read and handled as JSON-RPC.
request: hyper::Request<hyper::Body>,
/// Middleware notified of calls, results and the final response.
middleware: M,
/// Registered RPC methods.
methods: Methods,
/// Resource limits claimed by method calls.
resources: Resources,
/// Maximum accepted request body size in bytes.
max_request_body_size: u32,
/// Size limit in bytes handed to method callbacks and the batch response builder.
max_response_body_size: u32,
/// Maximum length passed to the trace-logging helpers.
max_log_length: u32,
/// Whether batch requests are accepted.
batch_requests_supported: bool,
/// Value returned by `Middleware::on_request` for this request.
request_start: M::Instant,
}
async fn process_validated_request<M: Middleware>(
input: ProcessValidatedRequest<M>,
) -> Result<hyper::Response<hyper::Body>, HyperError> {
let ProcessValidatedRequest {
request,
middleware,
methods,
resources,
max_request_body_size,
max_response_body_size,
max_log_length,
batch_requests_supported,
request_start,
} = input;
let (parts, body) = request.into_parts();
let (body, is_single) = match read_body(&parts.headers, body, max_request_body_size).await {
Ok(r) => r,
Err(GenericTransportError::TooLarge) => return Ok(response::too_large(max_request_body_size)),
Err(GenericTransportError::Malformed) => return Ok(response::malformed()),
Err(GenericTransportError::Inner(e)) => {
tracing::error!("Internal error reading request body: {}", e);
return Ok(response::internal_error());
}
};
if is_single {
let call = CallData {
conn_id: 0,
middleware: &middleware,
methods: &methods,
max_response_body_size,
max_log_length,
resources: &resources,
request_start,
};
let response = process_single_request(body, call).await;
middleware.on_response(&response.result, request_start);
Ok(response::ok_response(response.result))
}
else if !batch_requests_supported {
let err = MethodResponse::error(
Id::Null,
ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, &BATCHES_NOT_SUPPORTED_MSG, None),
);
middleware.on_response(&err.result, request_start);
Ok(response::ok_response(err.result))
}
else {
let response = process_batch_request(Batch {
data: body,
call: CallData {
conn_id: 0,
middleware: &middleware,
methods: &methods,
max_response_body_size,
max_log_length,
resources: &resources,
request_start,
},
})
.await;
middleware.on_response(&response.result, request_start);
Ok(response::ok_response(response.result))
}
}
/// Serves the optional HTTP health endpoint.
///
/// Calls `health_api.method` with no parameters and returns only the JSON `result`
/// value of its response as the HTTP body.
///
/// Returns `internal_error()` in any of these cases:
/// - the method is not registered;
/// - the method is a subscription;
/// - the call did not succeed.
///
/// The middleware receives `on_result` and `on_response`. The `Err` variant is never
/// produced here.
async fn process_health_request<M: Middleware>(
    health_api: &HealthApi,
    middleware: M,
    methods: Methods,
    max_response_body_size: u32,
    request_start: M::Instant,
    max_log_length: u32,
) -> Result<hyper::Response<hyper::Body>, HyperError> {
    let trace = RpcTracing::method_call(&health_api.method);
    async {
        // Incoming request -> `rx` (received); outgoing result -> `tx` (sent), the same
        // direction used by `process_single_request` and `execute_call`.
        rx_log_from_str("HTTP health API", max_log_length);
        let response = match methods.method_with_name(&health_api.method) {
            None => MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::MethodNotFound)),
            Some((_name, method_callback)) => match method_callback.inner() {
                MethodKind::Sync(callback) => {
                    (callback)(Id::Number(0), Params::new(None), max_response_body_size as usize)
                }
                // No resource claim is made here; the callback receives `None` as its guard.
                MethodKind::Async(callback) => {
                    (callback)(Id::Number(0), Params::new(None), 0, max_response_body_size as usize, None).await
                }
                MethodKind::Subscription(_) | MethodKind::Unsubscription(_) => {
                    MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::InternalError))
                }
            },
        };

        tx_log_from_str(&response.result, max_log_length);
        middleware.on_result(&health_api.method, response.success, request_start);
        middleware.on_response(&response.result, request_start);

        if response.success {
            #[derive(serde::Deserialize)]
            struct RpcPayload<'a> {
                #[serde(borrow)]
                result: &'a serde_json::value::RawValue,
            }

            let payload: RpcPayload = serde_json::from_str(&response.result)
                .expect("valid JSON-RPC response must have a result field and be valid JSON; qed");
            Ok(response::ok_response(payload.result.to_string()))
        } else {
            Ok(response::internal_error())
        }
    }
    .instrument(trace.into_span())
    .await
}
/// A raw batch request body together with the per-request call context.
#[derive(Debug, Clone)]
struct Batch<'a, M: Middleware> {
/// Raw body bytes, parsed as a JSON array in `process_batch_request`.
data: Vec<u8>,
/// Context cloned for every element of the batch.
call: CallData<'a, M>,
}
/// Per-request context shared by every method call made while handling one HTTP request.
#[derive(Debug, Clone)]
struct CallData<'a, M: Middleware> {
/// Connection id passed to async callbacks (every construction in this file uses `0`).
conn_id: usize,
/// Middleware notified via `on_call` and `on_result`.
middleware: &'a M,
/// Registered RPC methods used for lookup.
methods: &'a Methods,
/// Size limit in bytes handed to method callbacks.
max_response_body_size: u32,
/// Maximum length passed to the trace-logging helpers.
max_log_length: u32,
/// Resources claimed before running sync/async methods.
resources: &'a Resources,
/// Value returned by `Middleware::on_request` for this request.
request_start: M::Instant,
}
/// A single method call extracted from a request (or one element of a batch).
#[derive(Debug, Clone)]
struct Call<'a, M: Middleware> {
/// Call parameters, borrowed from the raw request JSON.
params: Params<'a>,
/// Requested method name.
name: &'a str,
/// Shared per-request context.
call: CallData<'a, M>,
/// Request id echoed back in the response.
id: Id<'a>,
}
/// Handles a batch request body. There are three outcomes:
///
/// - If `data` parses as an array of requests, each element is executed in order via
///   `execute_call`. Responses accumulate in a `BatchResponseBuilder` limited to
///   `max_response_body_size`. The fold stops at the first `Err` from `append`,
///   and that error value becomes the batch response.
/// - If `data` parses as an array of notifications, a non-empty array yields an empty
///   successful response; an empty array yields `InvalidRequest`.
/// - Otherwise, an error response built from `prepare_error` is returned.
async fn process_batch_request<M>(b: Batch<'_, M>) -> BatchResponse
where
M: Middleware,
{
let Batch { data, call } = b;
if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&data) {
let max_response_size = call.max_response_body_size;
// Every element gets its own clone of the shared call context.
let batch = batch.into_iter().map(|req| Ok((req, call.clone())));
let batch_stream = futures_util::stream::iter(batch);
let trace = RpcTracing::batch();
return async {
// Sequentially execute each call and append its response; `try_fold` short-circuits
// on the first `Err` returned by `append`.
let batch_response = batch_stream
.try_fold(
BatchResponseBuilder::new_with_limit(max_response_size as usize),
|batch_response, (req, call)| async move {
let params = Params::new(req.params.map(|params| params.get()));
let response = execute_call(Call { name: &req.method, params, id: req.id, call }).await;
batch_response.append(&response)
},
)
.await;
match batch_response {
Ok(batch) => batch.finish(),
Err(batch_err) => batch_err,
}
}
.instrument(trace.into_span())
.await;
}
if let Ok(batch) = serde_json::from_slice::<Vec<Notif>>(&data) {
// A batch consisting only of notifications gets an empty response body.
// NOTE(review): `[]` already parses as `Vec<Request>` above, so the `InvalidRequest`
// branch below looks unreachable for an empty array. Confirm before relying on it.
return if !batch.is_empty() {
BatchResponse { result: "".to_string(), success: true }
} else {
BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest))
};
}
// Neither requests nor notifications: build an error from whatever id/code can be recovered.
let (id, code) = prepare_error(&data);
BatchResponse::error(id, ErrorObject::from(code))
}
/// Handles a single (non-batch) request body. There are three outcomes:
///
/// - A method call is logged and executed via `execute_call` inside a method-call tracing span.
/// - A notification is only logged (inside a notification span) and answered with an
///   empty successful response.
/// - Anything else produces an error response built from `prepare_error`.
async fn process_single_request<M: Middleware>(data: Vec<u8>, call: CallData<'_, M>) -> MethodResponse {
if let Ok(req) = serde_json::from_slice::<Request>(&data) {
let trace = RpcTracing::method_call(&req.method);
async {
rx_log_from_json(&req, call.max_log_length);
let params = Params::new(req.params.map(|params| params.get()));
let name = &req.method;
let id = req.id;
execute_call(Call { name, params, id, call }).await
}
.instrument(trace.into_span())
.await
} else if let Ok(req) = serde_json::from_slice::<Notif>(&data) {
// No `.await` follows in this branch, so the entered span guard is never held across a suspension point.
let trace = RpcTracing::notification(&req.method);
let span = trace.into_span();
let _enter = span.enter();
rx_log_from_json(&req, call.max_log_length);
MethodResponse { result: String::new(), success: true }
} else {
let (id, code) = prepare_error(&data);
MethodResponse::error(id, ErrorObject::from(code))
}
}
/// Executes one method call and returns its response.
///
/// The middleware gets `on_call` before dispatch and `on_result` after it; the
/// response is trace-logged as sent. Dispatch works as follows:
/// - Unknown methods yield `MethodNotFound`.
/// - Sync and async methods first claim their resources. If the claim fails,
///   `ServerIsBusy` is returned without running the method.
/// - Subscription methods yield `InternalError`, since they are not supported over HTTP.
async fn execute_call<M: Middleware>(c: Call<'_, M>) -> MethodResponse {
let Call { name, id, params, call } = c;
let CallData { resources, methods, middleware, max_response_body_size, max_log_length, conn_id, request_start } =
call;
let response = match methods.method_with_name(name) {
None => {
middleware.on_call(name, params.clone(), middleware::MethodKind::Unknown);
MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound))
}
// Inside this arm `name` is the name returned by the method lookup (it shadows the
// requested name); the `on_result` call after the match uses the outer, requested `name`.
Some((name, method)) => match &method.inner() {
MethodKind::Sync(callback) => {
middleware.on_call(name, params.clone(), middleware::MethodKind::MethodCall);
match method.claim(name, resources) {
Ok(guard) => {
// The resource guard is held only while the sync callback runs.
let r = (callback)(id, params, max_response_body_size as usize);
drop(guard);
r
}
Err(err) => {
tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy))
}
}
}
MethodKind::Async(callback) => {
middleware.on_call(name, params.clone(), middleware::MethodKind::MethodCall);
match method.claim(name, resources) {
Ok(guard) => {
// Async callbacks take owned id/params and receive the guard itself.
let id = id.into_owned();
let params = params.into_owned();
(callback)(id, params, conn_id, max_response_body_size as usize, Some(guard)).await
}
Err(err) => {
tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy))
}
}
}
MethodKind::Subscription(_) | MethodKind::Unsubscription(_) => {
middleware.on_call(name, params.clone(), middleware::MethodKind::Unknown);
tracing::error!("Subscriptions not supported on HTTP");
MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError))
}
},
};
tx_log_from_str(&response.result, max_log_length);
middleware.on_result(name, response.success, request_start);
response
}