use std::{
future::Future,
net::SocketAddr,
pin::Pin,
sync::Arc,
};
use serde::{de::DeserializeOwned, Serialize};
pub use super::SerializationKey;
pub use toolkit_zero_macros::mechanism;
/// Heap-allocated, type-erased future that is `Send` and `'static`.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
/// Type-erased request handler stored in every `SocketType`: it receives the
/// buffered request and returns a boxed future resolving to the full response.
type HandlerFn = Arc<dyn Fn(IncomingRequest) -> BoxFuture<http::Response<bytes::Bytes>> + Send + Sync>;
/// Buffered view of an inbound request, handed to a route's `HandlerFn`.
pub(crate) struct IncomingRequest {
    /// Entire request body, collected into memory by the dispatcher.
    body: bytes::Bytes,
    /// Raw query string without the leading `?`; empty when the URI has none.
    query: String,
    // No handler in this file reads the headers yet; they are carried along
    // for future use, hence the lint allowance.
    #[allow(dead_code)]
    headers: http::HeaderMap,
}
/// Failure outcome of a handler: an HTTP status sent with an empty body.
pub struct Rejection {
    status: http::StatusCode,
}

impl Rejection {
    /// Wraps an arbitrary status code.
    fn new(status: http::StatusCode) -> Self {
        Rejection { status }
    }

    /// `403 Forbidden`.
    pub fn forbidden() -> Self {
        Rejection::new(http::StatusCode::FORBIDDEN)
    }

    /// `400 Bad Request`.
    pub fn bad_request() -> Self {
        Rejection::new(http::StatusCode::BAD_REQUEST)
    }

    /// `500 Internal Server Error`.
    pub fn internal() -> Self {
        Rejection::new(http::StatusCode::INTERNAL_SERVER_ERROR)
    }

    /// Builds the response: the stored status, no headers, empty body.
    fn into_response(self) -> http::Response<bytes::Bytes> {
        let mut response = http::Response::new(bytes::Bytes::new());
        *response.status_mut() = self.status;
        response
    }
}
/// Success value a handler may return; the framework turns it into an HTTP
/// response with a `bytes::Bytes` body.
pub trait Reply: Send {
    /// Consumes the reply and builds the response to send to the client.
    fn into_response(self) -> http::Response<bytes::Bytes>;
}
/// Success reply carrying no content: `200 OK` with an empty body.
pub struct EmptyReply;

impl Reply for EmptyReply {
    fn into_response(self) -> http::Response<bytes::Bytes> {
        // A freshly constructed `Response` already has status 200 and no headers.
        http::Response::new(bytes::Bytes::new())
    }
}
/// HTML reply: `200 OK` with `Content-Type: text/html; charset=utf-8`.
pub struct HtmlReply {
    body: String,
}

impl Reply for HtmlReply {
    fn into_response(self) -> http::Response<bytes::Bytes> {
        let payload = bytes::Bytes::from(self.body.into_bytes());
        http::Response::builder()
            .header(http::header::CONTENT_TYPE, "text/html; charset=utf-8")
            .status(http::StatusCode::OK)
            .body(payload)
            .unwrap()
    }
}

/// Wraps `content` (anything convertible to `String`) in an [`HtmlReply`].
pub fn html_reply(content: impl Into<String>) -> HtmlReply {
    let body: String = content.into();
    HtmlReply { body }
}
/// Pre-serialized JSON reply sent with `Content-Type: application/json`.
struct JsonReply {
    body: bytes::Bytes,
    status: http::StatusCode,
}

impl Reply for JsonReply {
    fn into_response(self) -> http::Response<bytes::Bytes> {
        let JsonReply { body, status } = self;
        http::Response::builder()
            .header(http::header::CONTENT_TYPE, "application/json")
            .status(status)
            .body(body)
            .unwrap()
    }
}

/// A fully built response is already a reply and is passed through untouched.
impl Reply for http::Response<bytes::Bytes> {
    fn into_response(self) -> http::Response<bytes::Bytes> {
        self
    }
}
/// A finished route: HTTP method, path pattern and type-erased handler.
///
/// Cloning is cheap because the handler is shared through an `Arc`.
#[derive(Clone)]
pub struct SocketType {
    pub(crate) method: http::Method,
    pub(crate) path: String,
    pub(crate) handler: HandlerFn,
}
#[derive(Clone, Copy, Debug)]
enum HttpMethod {
Get, Post, Put, Delete, Patch, Head, Options,
}
impl HttpMethod {
fn to_http(self) -> http::Method {
match self {
HttpMethod::Get => http::Method::GET,
HttpMethod::Post => http::Method::POST,
HttpMethod::Put => http::Method::PUT,
HttpMethod::Delete => http::Method::DELETE,
HttpMethod::Patch => http::Method::PATCH,
HttpMethod::Head => http::Method::HEAD,
HttpMethod::Options => http::Method::OPTIONS,
}
}
}
/// Returns `true` when `pattern` and `actual_path` consist of the same
/// `/`-separated segments.
///
/// Empty segments are ignored, so leading, trailing and repeated slashes are
/// insignificant (`"/a//b/"` matches `"a/b"`). Segments are compared literally;
/// there are no wildcards or path parameters.
fn path_matches(pattern: &str, actual_path: &str) -> bool {
    fn segments(path: &str) -> impl Iterator<Item = &str> + '_ {
        path.split('/').filter(|segment| !segment.is_empty())
    }
    // Lazy element-wise comparison; no intermediate vectors are allocated.
    segments(pattern).eq(segments(actual_path))
}
/// Starting point of the route builder: an HTTP method plus a path pattern.
///
/// Create one with [`ServerMechanism::get`], [`ServerMechanism::post`], and so on.
/// Optionally attach state or a body/query expectation. Finish with `onconnect`
/// (or `onconnect_sync`) to obtain a [`SocketType`].
pub struct ServerMechanism {
    method: HttpMethod,
    path: String,
}

impl ServerMechanism {
    /// Shared constructor behind the per-method helpers; logs at debug level.
    fn instance(method: HttpMethod, path: impl Into<String>) -> Self {
        let route = Self { method, path: path.into() };
        log::debug!("Creating {:?} route at '{}'", route.method, route.path);
        route
    }

    /// Starts a `GET` route at `path`.
    pub fn get(path: impl Into<String>) -> Self {
        Self::instance(HttpMethod::Get, path)
    }

    /// Starts a `POST` route at `path`.
    pub fn post(path: impl Into<String>) -> Self {
        Self::instance(HttpMethod::Post, path)
    }

    /// Starts a `PUT` route at `path`.
    pub fn put(path: impl Into<String>) -> Self {
        Self::instance(HttpMethod::Put, path)
    }

    /// Starts a `DELETE` route at `path`.
    pub fn delete(path: impl Into<String>) -> Self {
        Self::instance(HttpMethod::Delete, path)
    }

    /// Starts a `PATCH` route at `path`.
    pub fn patch(path: impl Into<String>) -> Self {
        Self::instance(HttpMethod::Patch, path)
    }

    /// Starts a `HEAD` route at `path`.
    pub fn head(path: impl Into<String>) -> Self {
        Self::instance(HttpMethod::Head, path)
    }

    /// Starts an `OPTIONS` route at `path`.
    pub fn options(path: impl Into<String>) -> Self {
        Self::instance(HttpMethod::Options, path)
    }

    /// Attaches state that is cloned into every invocation of the final handler.
    pub fn state<S: Clone + Send + Sync + 'static>(self, state: S) -> StatefulSocketBuilder<S> {
        log::trace!("Attaching state to {:?} route at '{}'", self.method, self.path);
        StatefulSocketBuilder { base: self, state }
    }

    /// Expects a JSON body deserialized into `T`; invalid JSON yields `400`.
    pub fn json<T: DeserializeOwned + Send>(self) -> JsonSocketBuilder<T> {
        log::trace!("Attaching JSON body expectation to {:?} route at '{}'", self.method, self.path);
        JsonSocketBuilder { base: self, _phantom: std::marker::PhantomData }
    }

    /// Expects URL-encoded query parameters deserialized into `T`; failure yields `400`.
    pub fn query<T: DeserializeOwned + Send>(self) -> QuerySocketBuilder<T> {
        log::trace!("Attaching query parameter expectation to {:?} route at '{}'", self.method, self.path);
        QuerySocketBuilder { base: self, _phantom: std::marker::PhantomData }
    }

    /// Expects a body sealed with `key`; failure to open it yields `403`.
    pub fn encryption<T>(self, key: SerializationKey) -> EncryptedBodyBuilder<T> {
        log::trace!("Attaching encrypted body to {:?} route at '{}'", self.method, self.path);
        EncryptedBodyBuilder { base: self, key, _phantom: std::marker::PhantomData }
    }

    /// Expects a `data=<base64url>` query parameter sealed with `key`.
    ///
    /// A missing parameter yields `400`; bad base64 or failed decryption yields `403`.
    pub fn encrypted_query<T>(self, key: SerializationKey) -> EncryptedQueryBuilder<T> {
        log::trace!("Attaching encrypted query to {:?} route at '{}'", self.method, self.path);
        EncryptedQueryBuilder { base: self, key, _phantom: std::marker::PhantomData }
    }

    /// Finishes the route with an async handler that takes no arguments.
    ///
    /// `Ok` values become the reply's response; `Err` becomes the rejection's status.
    pub fn onconnect<F, Fut, Re>(self, handler: F) -> SocketType
    where
        F: Fn() -> Fut + Clone + Send + Sync + 'static,
        Fut: Future<Output = Result<Re, Rejection>> + Send,
        Re: Reply + Send,
    {
        log::debug!("Finalising async {:?} route at '{}' (no args)", self.method, self.path);
        let Self { method, path } = self;
        SocketType {
            method: method.to_http(),
            path,
            handler: Arc::new(move |_req: IncomingRequest| {
                let callback = handler.clone();
                Box::pin(async move {
                    callback()
                        .await
                        .map_or_else(Rejection::into_response, Reply::into_response)
                })
            }),
        }
    }

    /// Finishes the route with a synchronous handler run via `tokio::task::spawn_blocking`.
    ///
    /// A panic in the handler is logged and answered with `500`.
    ///
    /// # Safety
    ///
    /// The body performs no memory-unsafe operations; `unsafe` acts as an explicit
    /// opt-in. As the warning logged at registration states, callers are expected to
    /// apply rate-limiting externally.
    pub unsafe fn onconnect_sync<F, Re>(self, handler: F) -> SocketType
    where
        F: Fn() -> Result<Re, Rejection> + Clone + Send + Sync + 'static,
        Re: Reply + Send + 'static,
    {
        log::warn!(
            "Registering sync handler on {:?} '{}' — ensure rate-limiting is applied externally",
            self.method, self.path
        );
        let Self { method, path } = self;
        SocketType {
            method: method.to_http(),
            path,
            handler: Arc::new(move |_req: IncomingRequest| {
                let callback = handler.clone();
                Box::pin(async move {
                    let joined = tokio::task::spawn_blocking(callback).await;
                    match joined {
                        Ok(outcome) => {
                            outcome.map_or_else(Rejection::into_response, Reply::into_response)
                        }
                        Err(_) => {
                            log::warn!("Sync handler panicked; returning 500");
                            Rejection::internal().into_response()
                        }
                    }
                })
            }),
        }
    }
}
pub struct JsonSocketBuilder<T> {
base: ServerMechanism,
_phantom: std::marker::PhantomData<T>,
}
impl<T: DeserializeOwned + Send + 'static> JsonSocketBuilder<T> {
pub fn onconnect<F, Fut, Re>(self, handler: F) -> SocketType
where
F: Fn(T) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<Re, Rejection>> + Send,
Re: Reply + Send,
{
log::debug!(
"Finalising async {:?} route at '{}' (JSON body)",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
Box::pin(async move {
let body: T = match serde_json::from_slice(&req.body) {
Ok(v) => v,
Err(e) => {
log::debug!("JSON body parse failed: {}", e);
return Rejection::bad_request().into_response();
}
};
match h(body).await {
Ok(r) => r.into_response(),
Err(e) => e.into_response(),
}
})
}),
}
}
pub unsafe fn onconnect_sync<F, Re>(self, handler: F) -> SocketType
where
F: Fn(T) -> Result<Re, Rejection> + Clone + Send + Sync + 'static,
Re: Reply + Send + 'static,
{
log::warn!(
"Registering sync handler on {:?} '{}' (JSON body) — ensure rate-limiting is applied externally",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
Box::pin(async move {
let body: T = match serde_json::from_slice(&req.body) {
Ok(v) => v,
Err(e) => {
log::debug!("JSON body parse failed (sync): {}", e);
return Rejection::bad_request().into_response();
}
};
match tokio::task::spawn_blocking(move || h(body)).await {
Ok(Ok(r)) => r.into_response(),
Ok(Err(e)) => e.into_response(),
Err(_) => {
log::warn!("Sync handler (JSON body) panicked; returning 500");
Rejection::internal().into_response()
}
}
})
}),
}
}
pub fn state<S: Clone + Send + Sync + 'static>(
self, state: S,
) -> StatefulJsonSocketBuilder<T, S> {
StatefulJsonSocketBuilder {
base: self.base,
state,
_phantom: std::marker::PhantomData,
}
}
}
pub struct QuerySocketBuilder<T> {
base: ServerMechanism,
_phantom: std::marker::PhantomData<T>,
}
impl<T: DeserializeOwned + Send + 'static> QuerySocketBuilder<T> {
pub fn onconnect<F, Fut, Re>(self, handler: F) -> SocketType
where
F: Fn(T) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<Re, Rejection>> + Send,
Re: Reply + Send,
{
log::debug!(
"Finalising async {:?} route at '{}' (query params)",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
Box::pin(async move {
let params: T = match serde_urlencoded::from_str(&req.query) {
Ok(v) => v,
Err(e) => {
log::debug!("query param parse failed: {}", e);
return Rejection::bad_request().into_response();
}
};
match h(params).await {
Ok(r) => r.into_response(),
Err(e) => e.into_response(),
}
})
}),
}
}
pub unsafe fn onconnect_sync<F, Re>(self, handler: F) -> SocketType
where
F: Fn(T) -> Result<Re, Rejection> + Clone + Send + Sync + 'static,
Re: Reply + Send + 'static,
{
log::warn!(
"Registering sync handler on {:?} '{}' (query params) — ensure rate-limiting is applied externally",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
Box::pin(async move {
let params: T = match serde_urlencoded::from_str(&req.query) {
Ok(v) => v,
Err(e) => {
log::debug!("query param parse failed (sync): {}", e);
return Rejection::bad_request().into_response();
}
};
match tokio::task::spawn_blocking(move || h(params)).await {
Ok(Ok(r)) => r.into_response(),
Ok(Err(e)) => e.into_response(),
Err(_) => {
log::warn!("Sync handler (query params) panicked; returning 500");
Rejection::internal().into_response()
}
}
})
}),
}
}
pub fn state<S: Clone + Send + Sync + 'static>(
self, state: S,
) -> StatefulQuerySocketBuilder<T, S> {
StatefulQuerySocketBuilder {
base: self.base,
state,
_phantom: std::marker::PhantomData,
}
}
}
pub struct StatefulSocketBuilder<S> {
base: ServerMechanism,
state: S,
}
impl<S: Clone + Send + Sync + 'static> StatefulSocketBuilder<S> {
pub fn json<T: DeserializeOwned + Send>(self) -> StatefulJsonSocketBuilder<T, S> {
StatefulJsonSocketBuilder {
base: self.base,
state: self.state,
_phantom: std::marker::PhantomData,
}
}
pub fn query<T: DeserializeOwned + Send>(self) -> StatefulQuerySocketBuilder<T, S> {
StatefulQuerySocketBuilder {
base: self.base,
state: self.state,
_phantom: std::marker::PhantomData,
}
}
pub fn encryption<T>(self, key: SerializationKey) -> StatefulEncryptedBodyBuilder<T, S> {
StatefulEncryptedBodyBuilder {
base: self.base,
key,
state: self.state,
_phantom: std::marker::PhantomData,
}
}
pub fn encrypted_query<T>(self, key: SerializationKey) -> StatefulEncryptedQueryBuilder<T, S> {
StatefulEncryptedQueryBuilder {
base: self.base,
key,
state: self.state,
_phantom: std::marker::PhantomData,
}
}
pub fn onconnect<F, Fut, Re>(self, handler: F) -> SocketType
where
F: Fn(S) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<Re, Rejection>> + Send,
Re: Reply + Send,
{
log::debug!(
"Finalising async {:?} route at '{}' (state)",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
let state = self.state;
SocketType {
method,
path,
handler: Arc::new(move |_req: IncomingRequest| {
let h = handler.clone();
let s = state.clone();
Box::pin(async move {
match h(s).await {
Ok(r) => r.into_response(),
Err(e) => e.into_response(),
}
})
}),
}
}
pub unsafe fn onconnect_sync<F, Re>(self, handler: F) -> SocketType
where
F: Fn(S) -> Result<Re, Rejection> + Clone + Send + Sync + 'static,
Re: Reply + Send + 'static,
{
log::warn!(
"Registering sync handler on {:?} '{}' (state) — ensure rate-limiting and lock-free state are in place",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
let state = self.state;
SocketType {
method,
path,
handler: Arc::new(move |_req: IncomingRequest| {
let h = handler.clone();
let s = state.clone();
Box::pin(async move {
match tokio::task::spawn_blocking(move || h(s)).await {
Ok(Ok(r)) => r.into_response(),
Ok(Err(e)) => e.into_response(),
Err(_) => {
log::warn!("Sync handler (state) panicked; returning 500");
Rejection::internal().into_response()
}
}
})
}),
}
}
}
pub struct StatefulJsonSocketBuilder<T, S> {
base: ServerMechanism,
state: S,
_phantom: std::marker::PhantomData<T>,
}
impl<T: DeserializeOwned + Send + 'static, S: Clone + Send + Sync + 'static>
StatefulJsonSocketBuilder<T, S>
{
pub fn onconnect<F, Fut, Re>(self, handler: F) -> SocketType
where
F: Fn(S, T) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<Re, Rejection>> + Send,
Re: Reply + Send,
{
log::debug!(
"Finalising async {:?} route at '{}' (state + JSON body)",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
let state = self.state;
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
let s = state.clone();
Box::pin(async move {
let body: T = match serde_json::from_slice(&req.body) {
Ok(v) => v,
Err(e) => {
log::debug!("JSON body parse failed (state): {}", e);
return Rejection::bad_request().into_response();
}
};
match h(s, body).await {
Ok(r) => r.into_response(),
Err(e) => e.into_response(),
}
})
}),
}
}
pub unsafe fn onconnect_sync<F, Re>(self, handler: F) -> SocketType
where
F: Fn(S, T) -> Result<Re, Rejection> + Clone + Send + Sync + 'static,
Re: Reply + Send + 'static,
{
log::warn!(
"Registering sync handler on {:?} '{}' (state + JSON body) — ensure rate-limiting and lock-free state are in place",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
let state = self.state;
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
let s = state.clone();
Box::pin(async move {
let body: T = match serde_json::from_slice(&req.body) {
Ok(v) => v,
Err(e) => {
log::debug!("JSON body parse failed (state+sync): {}", e);
return Rejection::bad_request().into_response();
}
};
match tokio::task::spawn_blocking(move || h(s, body)).await {
Ok(Ok(r)) => r.into_response(),
Ok(Err(e)) => e.into_response(),
Err(_) => {
log::warn!("Sync handler (state + JSON body) panicked; returning 500");
Rejection::internal().into_response()
}
}
})
}),
}
}
}
pub struct StatefulQuerySocketBuilder<T, S> {
base: ServerMechanism,
state: S,
_phantom: std::marker::PhantomData<T>,
}
impl<T: DeserializeOwned + Send + 'static, S: Clone + Send + Sync + 'static>
StatefulQuerySocketBuilder<T, S>
{
pub fn onconnect<F, Fut, Re>(self, handler: F) -> SocketType
where
F: Fn(S, T) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<Re, Rejection>> + Send,
Re: Reply + Send,
{
log::debug!(
"Finalising async {:?} route at '{}' (state + query params)",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
let state = self.state;
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
let s = state.clone();
Box::pin(async move {
let params: T = match serde_urlencoded::from_str(&req.query) {
Ok(v) => v,
Err(e) => {
log::debug!("query param parse failed (state): {}", e);
return Rejection::bad_request().into_response();
}
};
match h(s, params).await {
Ok(r) => r.into_response(),
Err(e) => e.into_response(),
}
})
}),
}
}
pub unsafe fn onconnect_sync<F, Re>(self, handler: F) -> SocketType
where
F: Fn(S, T) -> Result<Re, Rejection> + Clone + Send + Sync + 'static,
Re: Reply + Send + 'static,
{
log::warn!(
"Registering sync handler on {:?} '{}' (state + query params) — ensure rate-limiting and lock-free state are in place",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
let state = self.state;
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
let s = state.clone();
Box::pin(async move {
let params: T = match serde_urlencoded::from_str(&req.query) {
Ok(v) => v,
Err(e) => {
log::debug!("query param parse failed (state+sync): {}", e);
return Rejection::bad_request().into_response();
}
};
match tokio::task::spawn_blocking(move || h(s, params)).await {
Ok(Ok(r)) => r.into_response(),
Ok(Err(e)) => e.into_response(),
Err(_) => {
log::warn!("Sync handler (state + query params) panicked; returning 500");
Rejection::internal().into_response()
}
}
})
}),
}
}
}
pub struct EncryptedBodyBuilder<T> {
base: ServerMechanism,
key: SerializationKey,
_phantom: std::marker::PhantomData<T>,
}
impl<T> EncryptedBodyBuilder<T>
where
T: bincode::Decode<()> + Send + 'static,
{
pub fn onconnect<F, Fut, Re>(self, handler: F) -> SocketType
where
F: Fn(T) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<Re, Rejection>> + Send,
Re: Reply + Send,
{
log::debug!(
"Finalising async {:?} route at '{}' (encrypted body)",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
let key = self.key;
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
let key = key.clone();
Box::pin(async move {
let value: T = match decode_body(&req.body, &key) {
Ok(v) => v,
Err(e) => return e.into_response(),
};
match h(value).await {
Ok(r) => r.into_response(),
Err(e) => e.into_response(),
}
})
}),
}
}
pub unsafe fn onconnect_sync<F, Re>(self, handler: F) -> SocketType
where
F: Fn(T) -> Result<Re, Rejection> + Clone + Send + Sync + 'static,
Re: Reply + Send + 'static,
{
log::warn!(
"Registering sync handler on {:?} '{}' (encrypted body) — ensure rate-limiting is applied externally",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
let key = self.key;
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
let key = key.clone();
Box::pin(async move {
let value: T = match decode_body(&req.body, &key) {
Ok(v) => v,
Err(e) => return e.into_response(),
};
match tokio::task::spawn_blocking(move || h(value)).await {
Ok(Ok(r)) => r.into_response(),
Ok(Err(e)) => e.into_response(),
Err(_) => {
log::warn!("Sync encrypted handler panicked; returning 500");
Rejection::internal().into_response()
}
}
})
}),
}
}
pub fn state<S: Clone + Send + Sync + 'static>(
self, state: S,
) -> StatefulEncryptedBodyBuilder<T, S> {
StatefulEncryptedBodyBuilder {
base: self.base,
key: self.key,
state,
_phantom: std::marker::PhantomData,
}
}
}
pub struct EncryptedQueryBuilder<T> {
base: ServerMechanism,
key: SerializationKey,
_phantom: std::marker::PhantomData<T>,
}
impl<T> EncryptedQueryBuilder<T>
where
T: bincode::Decode<()> + Send + 'static,
{
pub fn onconnect<F, Fut, Re>(self, handler: F) -> SocketType
where
F: Fn(T) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<Re, Rejection>> + Send,
Re: Reply + Send,
{
log::debug!(
"Finalising async {:?} route at '{}' (encrypted query)",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
let key = self.key;
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
let key = key.clone();
Box::pin(async move {
let value: T = match decode_query(&req.query, &key) {
Ok(v) => v,
Err(e) => return e.into_response(),
};
match h(value).await {
Ok(r) => r.into_response(),
Err(e) => e.into_response(),
}
})
}),
}
}
pub fn state<S: Clone + Send + Sync + 'static>(
self, state: S,
) -> StatefulEncryptedQueryBuilder<T, S> {
StatefulEncryptedQueryBuilder {
base: self.base,
key: self.key,
state,
_phantom: std::marker::PhantomData,
}
}
}
pub struct StatefulEncryptedBodyBuilder<T, S> {
base: ServerMechanism,
key: SerializationKey,
state: S,
_phantom: std::marker::PhantomData<T>,
}
impl<T, S> StatefulEncryptedBodyBuilder<T, S>
where
T: bincode::Decode<()> + Send + 'static,
S: Clone + Send + Sync + 'static,
{
pub fn onconnect<F, Fut, Re>(self, handler: F) -> SocketType
where
F: Fn(S, T) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<Re, Rejection>> + Send,
Re: Reply + Send,
{
log::debug!(
"Finalising async {:?} route at '{}' (state + encrypted body)",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
let key = self.key;
let state = self.state;
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
let key = key.clone();
let s = state.clone();
Box::pin(async move {
let value: T = match decode_body(&req.body, &key) {
Ok(v) => v,
Err(e) => return e.into_response(),
};
match h(s, value).await {
Ok(r) => r.into_response(),
Err(e) => e.into_response(),
}
})
}),
}
}
}
pub struct StatefulEncryptedQueryBuilder<T, S> {
base: ServerMechanism,
key: SerializationKey,
state: S,
_phantom: std::marker::PhantomData<T>,
}
impl<T, S> StatefulEncryptedQueryBuilder<T, S>
where
T: bincode::Decode<()> + Send + 'static,
S: Clone + Send + Sync + 'static,
{
pub fn onconnect<F, Fut, Re>(self, handler: F) -> SocketType
where
F: Fn(S, T) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<Re, Rejection>> + Send,
Re: Reply + Send,
{
log::debug!(
"Finalising async {:?} route at '{}' (state + encrypted query)",
self.base.method, self.base.path
);
let method = self.base.method.to_http();
let path = self.base.path.clone();
let key = self.key;
let state = self.state;
SocketType {
method,
path,
handler: Arc::new(move |req: IncomingRequest| {
let h = handler.clone();
let key = key.clone();
let s = state.clone();
Box::pin(async move {
let value: T = match decode_query(&req.query, &key) {
Ok(v) => v,
Err(e) => return e.into_response(),
};
match h(s, value).await {
Ok(r) => r.into_response(),
Err(e) => e.into_response(),
}
})
}),
}
}
}
/// Opens a sealed request body with `key` and decodes it into `T`.
///
/// Any failure from `crate::serialization::open` is logged at debug level and
/// mapped to `403 Forbidden`.
fn decode_body<T: bincode::Decode<()>>(
    raw: &bytes::Bytes,
    key: &SerializationKey,
) -> Result<T, Rejection> {
    match crate::serialization::open(raw, key.veil_key()) {
        Ok(value) => Ok(value),
        Err(err) => {
            log::debug!("body decryption failed (key mismatch or corrupt body): {}", err);
            Err(Rejection::forbidden())
        }
    }
}
/// Extracts the `data` query parameter, base64url-decodes it (no padding) and
/// opens it with `key` into `T`.
///
/// # Errors
/// * `400 Bad Request` — the query string has no decodable `data` parameter.
/// * `403 Forbidden` — `data` is not valid base64url, or decryption fails.
fn decode_query<T: bincode::Decode<()>>(
    raw_query: &str,
    key: &SerializationKey,
) -> Result<T, Rejection> {
    use base64::Engine;

    #[derive(serde::Deserialize)]
    struct DataParam {
        data: String,
    }

    let DataParam { data } = match serde_urlencoded::from_str::<DataParam>(raw_query) {
        Ok(param) => param,
        Err(_) => {
            log::debug!("encrypted query missing `data` parameter");
            return Err(Rejection::bad_request());
        }
    };
    let ciphertext = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(&data) {
        Ok(bytes) => bytes,
        Err(err) => {
            log::debug!("base64url decode failed: {}", err);
            return Err(Rejection::forbidden());
        }
    };
    match crate::serialization::open(&ciphertext, key.veil_key()) {
        Ok(value) => Ok(value),
        Err(err) => {
            log::debug!("query decryption failed: {}", err);
            Err(Rejection::forbidden())
        }
    }
}

/// Shorthand for [`Rejection::forbidden`] (`403` with an empty body).
pub fn forbidden() -> Rejection {
    Rejection::forbidden()
}
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub enum Status {
Ok,
Created,
Accepted,
NoContent,
MovedPermanently,
Found,
NotModified,
TemporaryRedirect,
PermanentRedirect,
BadRequest,
Unauthorized,
Forbidden,
NotFound,
MethodNotAllowed,
Conflict,
Gone,
UnprocessableEntity,
TooManyRequests,
InternalServerError,
NotImplemented,
BadGateway,
ServiceUnavailable,
GatewayTimeout,
}
impl From<Status> for http::StatusCode {
fn from(s: Status) -> Self {
match s {
Status::Ok => http::StatusCode::OK,
Status::Created => http::StatusCode::CREATED,
Status::Accepted => http::StatusCode::ACCEPTED,
Status::NoContent => http::StatusCode::NO_CONTENT,
Status::MovedPermanently => http::StatusCode::MOVED_PERMANENTLY,
Status::Found => http::StatusCode::FOUND,
Status::NotModified => http::StatusCode::NOT_MODIFIED,
Status::TemporaryRedirect => http::StatusCode::TEMPORARY_REDIRECT,
Status::PermanentRedirect => http::StatusCode::PERMANENT_REDIRECT,
Status::BadRequest => http::StatusCode::BAD_REQUEST,
Status::Unauthorized => http::StatusCode::UNAUTHORIZED,
Status::Forbidden => http::StatusCode::FORBIDDEN,
Status::NotFound => http::StatusCode::NOT_FOUND,
Status::MethodNotAllowed => http::StatusCode::METHOD_NOT_ALLOWED,
Status::Conflict => http::StatusCode::CONFLICT,
Status::Gone => http::StatusCode::GONE,
Status::UnprocessableEntity => http::StatusCode::UNPROCESSABLE_ENTITY,
Status::TooManyRequests => http::StatusCode::TOO_MANY_REQUESTS,
Status::InternalServerError => http::StatusCode::INTERNAL_SERVER_ERROR,
Status::NotImplemented => http::StatusCode::NOT_IMPLEMENTED,
Status::BadGateway => http::StatusCode::BAD_GATEWAY,
Status::ServiceUnavailable => http::StatusCode::SERVICE_UNAVAILABLE,
Status::GatewayTimeout => http::StatusCode::GATEWAY_TIMEOUT,
}
}
}
pub struct Server {
mechanisms: Vec<SocketType>,
}
impl Default for Server {
fn default() -> Self {
Self::new()
}
}
impl Server {
fn new() -> Self {
Self { mechanisms: Vec::new() }
}
pub fn mechanism(&mut self, mech: SocketType) -> &mut Self {
self.mechanisms.push(mech);
log::debug!("Route registered (total: {})", self.mechanisms.len());
self
}
pub fn serve(self, addr: impl Into<SocketAddr>) -> ServerFuture {
let addr = addr.into();
let routes = Arc::new(tokio::sync::RwLock::new(self.mechanisms));
ServerFuture::new(async move {
log::info!("Server binding to {}", addr);
run_hyper_server(routes, addr, std::future::pending::<()>()).await;
})
}
pub fn serve_with_graceful_shutdown(
self,
addr: impl Into<std::net::SocketAddr>,
shutdown: impl std::future::Future<Output = ()> + Send + 'static,
) -> ServerFuture {
let addr = addr.into();
let routes = Arc::new(tokio::sync::RwLock::new(self.mechanisms));
ServerFuture::new(async move {
log::info!("Server binding to {} (graceful shutdown enabled)", addr);
run_hyper_server(routes, addr, shutdown).await;
})
}
pub fn serve_from_listener(
self,
listener: tokio::net::TcpListener,
shutdown: impl std::future::Future<Output = ()> + Send + 'static,
) -> ServerFuture {
let routes = Arc::new(tokio::sync::RwLock::new(self.mechanisms));
ServerFuture::new(async move {
log::info!(
"Server running on {} (graceful shutdown enabled)",
listener.local_addr().map(|a| a.to_string()).unwrap_or_else(|_| "?".into())
);
run_hyper_server_inner(routes, listener, shutdown).await;
})
}
pub fn rebind(self, addr: impl Into<std::net::SocketAddr>) -> ServerFuture {
let addr = addr.into();
let routes = Arc::new(tokio::sync::RwLock::new(self.mechanisms));
ServerFuture::new(async move {
log::info!("Server binding to {} (graceful shutdown on Ctrl+C)", addr);
run_hyper_server(
routes,
addr,
async {
tokio::signal::ctrl_c().await.ok();
log::info!("Interrupt received — draining in-flight connections");
},
).await;
})
}
pub fn serve_managed(self, addr: impl Into<std::net::SocketAddr>) -> Result<BackgroundServer, std::io::Error> {
let addr = addr.into();
let std_listener = std::net::TcpListener::bind(addr)?;
std_listener.set_nonblocking(true)?;
let routes = Arc::new(tokio::sync::RwLock::new(self.mechanisms));
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
let routes_ref = Arc::clone(&routes);
let handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::from_std(std_listener)
.expect("from_std: listener must be non-blocking");
log::info!("server bound to {}", addr);
run_hyper_server_inner(routes_ref, listener, async { rx.await.ok(); }).await;
});
Ok(BackgroundServer {
routes,
addr,
shutdown_tx: Some(tx),
handle: Some(handle),
})
}
}
/// Boxed server future returned by the `Server::serve*` methods.
///
/// `.await` it to run the server on the current task, or call
/// [`ServerFuture::background`] to spawn it.
pub struct ServerFuture(Pin<Box<dyn Future<Output = ()> + Send + 'static>>);

impl ServerFuture {
    /// Boxes and pins `fut`.
    fn new(fut: impl Future<Output = ()> + Send + 'static) -> Self {
        let pinned: Pin<Box<dyn Future<Output = ()> + Send + 'static>> = Box::pin(fut);
        ServerFuture(pinned)
    }

    /// Spawns the server on the Tokio runtime and returns its join handle.
    pub fn background(self) -> tokio::task::JoinHandle<()> {
        let ServerFuture(inner) = self;
        tokio::spawn(inner)
    }
}

impl std::future::IntoFuture for ServerFuture {
    type Output = ();
    type IntoFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;

    /// Unwraps the inner boxed future so `.await` drives it directly.
    fn into_future(self) -> Self::IntoFuture {
        let ServerFuture(inner) = self;
        inner
    }
}
async fn dispatch(
routes: &Arc<tokio::sync::RwLock<Vec<SocketType>>>,
req: hyper::Request<hyper::body::Incoming>,
) -> http::Response<bytes::Bytes> {
use http_body_util::BodyExt;
let (parts, body) = req.into_parts();
let path = parts.uri.path().to_owned();
let query = parts.uri.query().unwrap_or("").to_owned();
let method = parts.method.clone();
let headers = parts.headers.clone();
let body_bytes = match body.collect().await {
Ok(c) => c.to_bytes(),
Err(e) => {
log::debug!("failed to read request body: {}", e);
return http::Response::builder()
.status(http::StatusCode::BAD_REQUEST)
.body(bytes::Bytes::new())
.unwrap();
}
};
let handler = {
let guard = routes.read().await;
guard
.iter()
.find(|s| s.method == method && path_matches(&s.path, &path))
.map(|s| Arc::clone(&s.handler))
};
match handler {
Some(h) => {
h(IncomingRequest { body: body_bytes, query, headers }).await
}
None => {
log::debug!("No route matched {} {}", method, path);
http::Response::builder()
.status(http::StatusCode::NOT_FOUND)
.body(bytes::Bytes::new())
.unwrap()
}
}
}
/// Accept loop shared by every serving entry point.
///
/// Each TCP connection accepted from `listener` is served over HTTP/1 on its own
/// spawned task, with requests routed through `dispatch`. When `shutdown`
/// resolves, the loop stops accepting and closes the listener. It then waits for
/// the connections registered with the graceful-shutdown watcher to finish.
async fn run_hyper_server_inner(
    routes: Arc<tokio::sync::RwLock<Vec<SocketType>>>,
    listener: tokio::net::TcpListener,
    shutdown: impl Future<Output = ()> + Send + 'static,
) {
    use hyper_util::server::graceful::GracefulShutdown;
    use hyper_util::rt::TokioIo;
    use hyper::server::conn::http1;
    // Every connection is registered here (via `watch`) so it can be drained on shutdown.
    let graceful = GracefulShutdown::new();
    // Pin once so `select!` can poll the same shutdown future on every iteration.
    let mut shutdown = std::pin::pin!(shutdown);
    loop {
        tokio::select! {
            result = listener.accept() => {
                let (stream, remote) = match result {
                    Ok(pair) => pair,
                    Err(e) => {
                        // NOTE(review): this retries immediately with no backoff; a
                        // persistent accept error would make the loop spin and flood
                        // the log — consider a short delay here.
                        log::warn!("accept error: {}", e);
                        continue;
                    }
                };
                log::trace!("accepted connection from {}", remote);
                let routes_ref = Arc::clone(&routes);
                let conn = http1::Builder::new().serve_connection(
                    TokioIo::new(stream),
                    hyper::service::service_fn(move |req| {
                        let r = Arc::clone(&routes_ref);
                        async move {
                            let resp = dispatch(&r, req).await;
                            // hyper needs a `Body` type; wrap the buffered bytes in `Full`.
                            let (parts, body) = resp.into_parts();
                            Ok::<_, std::convert::Infallible>(
                                http::Response::from_parts(
                                    parts,
                                    http_body_util::Full::new(body),
                                )
                            )
                        }
                    }),
                );
                let fut = graceful.watch(conn);
                // Each connection runs on its own task; errors are only logged.
                tokio::spawn(async move {
                    if let Err(e) = fut.await {
                        log::debug!("connection error: {}", e);
                    }
                });
            }
            _ = &mut shutdown => {
                log::info!("shutdown signal received — draining in-flight connections");
                break;
            }
        }
    }
    // Close the listening socket before draining so no new connections arrive.
    drop(listener);
    graceful.shutdown().await;
    log::info!("all connections drained");
}
/// Binds `addr` and runs the accept loop until `shutdown` resolves.
///
/// # Panics
/// Panics (after logging at error level) if the address cannot be bound.
async fn run_hyper_server(
    routes: Arc<tokio::sync::RwLock<Vec<SocketType>>>,
    addr: SocketAddr,
    shutdown: impl Future<Output = ()> + Send + 'static,
) {
    let bound = tokio::net::TcpListener::bind(addr).await;
    let listener = bound.unwrap_or_else(|e| {
        log::error!("failed to bind {}: {}", addr, e);
        panic!("server bind failed: {}", e);
    });
    log::info!("server bound to {}", addr);
    run_hyper_server_inner(routes, listener, shutdown).await;
}
/// Handle to a server running on a background Tokio task (see `Server::serve_managed`).
///
/// Routes can be added while it runs, and it can be moved to a new address or
/// stopped. Dropping the handle does not stop the task.
#[must_use = "dropping BackgroundServer leaves the background task running; call .stop().await to shut it down"]
pub struct BackgroundServer {
    routes: Arc<tokio::sync::RwLock<Vec<SocketType>>>,
    addr: std::net::SocketAddr,
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
    handle: Option<tokio::task::JoinHandle<()>>,
}

impl BackgroundServer {
    /// Address the server is currently bound to, with port `0` resolved to the
    /// port actually assigned.
    pub fn addr(&self) -> std::net::SocketAddr {
        self.addr
    }

    /// Signals shutdown and waits for the server task to drain and exit.
    pub async fn stop(mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(());
        }
        if let Some(h) = self.handle.take() {
            let _ = h.await;
        }
    }

    /// Moves the server to `addr`, keeping the current route table.
    ///
    /// The new socket is bound and registered with Tokio *before* the old server is
    /// stopped, so on error the old server keeps running. Consequently the new
    /// address must not be the one currently in use.
    ///
    /// # Errors
    /// Returns the I/O error from binding, switching to non-blocking mode, querying
    /// the local address, or registering the listener.
    pub async fn rebind(&mut self, addr: impl Into<std::net::SocketAddr>) -> Result<(), std::io::Error> {
        let requested = addr.into();
        let std_listener = std::net::TcpListener::bind(requested)?;
        std_listener.set_nonblocking(true)?;
        // Record the real address (resolves port 0 to the assigned port).
        let new_addr = std_listener.local_addr()?;
        // Convert now so failure is reported here instead of panicking in the task.
        let listener = tokio::net::TcpListener::from_std(std_listener)?;
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(());
        }
        if let Some(h) = self.handle.take() {
            let _ = h.await;
        }
        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        self.shutdown_tx = Some(tx);
        self.addr = new_addr;
        let routes = Arc::clone(&self.routes);
        self.handle = Some(tokio::spawn(async move {
            run_hyper_server_inner(routes, listener, async { rx.await.ok(); }).await;
        }));
        log::info!("Server rebound to {}", new_addr);
        Ok(())
    }

    /// Adds a route to the running server; it applies to requests dispatched afterwards.
    pub async fn mechanism(&mut self, mech: SocketType) -> &mut Self {
        // Read the count under the same write guard instead of re-locking.
        let total = {
            let mut routes = self.routes.write().await;
            routes.push(mech);
            routes.len()
        };
        log::debug!("mechanism: live route added (total = {})", total);
        self
    }
}
pub fn reply_with_status(
status: Status,
reply: impl Reply,
) -> Result<http::Response<bytes::Bytes>, Rejection> {
let mut resp = reply.into_response();
*resp.status_mut() = status.into();
Ok(resp)
}
pub fn reply() -> Result<impl Reply, Rejection> {
Ok::<_, Rejection>(EmptyReply)
}
pub fn reply_with_json<T: Serialize>(
json: &T,
) -> Result<impl Reply + use<T>, Rejection> {
let bytes = serde_json::to_vec(json).map_err(|_| Rejection::internal())?;
Ok::<_, Rejection>(JsonReply {
body: bytes::Bytes::from(bytes),
status: http::StatusCode::OK,
})
}
pub fn reply_with_status_and_json<T: Serialize>(
status: Status,
json: &T,
) -> Result<impl Reply + use<T>, Rejection> {
let bytes = serde_json::to_vec(json).map_err(|_| Rejection::internal())?;
Ok::<_, Rejection>(JsonReply {
body: bytes::Bytes::from(bytes),
status: status.into(),
})
}
/// Seals `value` with `key` into a `200 OK` `application/octet-stream` response.
///
/// # Errors
/// A sealing failure is mapped to `500 Internal Server Error`.
pub fn reply_sealed<T: bincode::Encode>(
    value: &T,
    key: SerializationKey,
) -> Result<http::Response<bytes::Bytes>, Rejection> {
    let default_status: Option<Status> = None;
    sealed_response(value, key, default_status)
}

/// Like [`reply_sealed`], but sends the response with `status`.
pub fn reply_sealed_with_status<T: bincode::Encode>(
    value: &T,
    key: SerializationKey,
    status: Status,
) -> Result<http::Response<bytes::Bytes>, Rejection> {
    let explicit_status = Some(status);
    sealed_response(value, key, explicit_status)
}

/// Shared implementation: seal, then build the binary response (status defaults to 200).
fn sealed_response<T: bincode::Encode>(
    value: &T,
    key: SerializationKey,
    status: Option<Status>,
) -> Result<http::Response<bytes::Bytes>, Rejection> {
    let code: http::StatusCode = status.map_or(http::StatusCode::OK, Into::into);
    let sealed = match crate::serialization::seal(value, key.veil_key()) {
        Ok(bytes) => bytes,
        Err(_) => return Err(Rejection::internal()),
    };
    let response = http::Response::builder()
        .header(http::header::CONTENT_TYPE, "application/octet-stream")
        .status(code)
        .body(bytes::Bytes::from(sealed))
        .unwrap();
    Ok(response)
}
/// Convenience macro that builds a handler's return value by delegating to the
/// `reply*` helper functions in `$crate::socket::server`.
///
/// Arm order matters: arms are tried top to bottom, so the short `json` and
/// `sealed` forms fail on the trailing `, status => …` and fall through to the
/// longer forms below them.
#[doc(hidden)]
#[macro_export]
macro_rules! reply {
    // Empty `200 OK` reply.
    () => {{
        $crate::socket::server::reply()
    }};
    // Any `Reply` value with its status overridden.
    (message => $message: expr, status => $status: expr) => {{
        $crate::socket::server::reply_with_status($status, $message)
    }};
    // JSON body, `200 OK`.
    (json => $json: expr) => {{
        $crate::socket::server::reply_with_json(&$json)
    }};
    // JSON body with an explicit status.
    (json => $json: expr, status => $status: expr) => {{
        $crate::socket::server::reply_with_status_and_json($status, &$json)
    }};
    // Value sealed with `key`, `200 OK`.
    (sealed => $val: expr, key => $key: expr) => {{
        $crate::socket::server::reply_sealed(&$val, $key)
    }};
    // Value sealed with `key`, explicit status.
    (sealed => $val: expr, key => $key: expr, status => $status: expr) => {{
        $crate::socket::server::reply_sealed_with_status(&$val, $key, $status)
    }};
}
// `#[macro_export]` places the macro at the crate root; re-export it from this
// module as well.
pub use crate::reply;