#![doc = include_str!("readme.md")]
#![warn(missing_docs)]
pub mod error;
pub mod extract;
pub mod middleware;
pub mod response;
pub mod router;
pub mod template;
pub mod tls;
pub use wae_session as session;
pub use response::{Attachment, Html, JsonResponse, Redirect, StreamResponse};
pub use router::{MethodRouter, RouterBuilder, delete, get, head, options, patch, post, put, trace};
use http::{Response, StatusCode, header};
use http_body_util::Full;
use hyper::body::Bytes;
use std::{net::SocketAddr, path::Path, sync::Arc, time::Duration};
use tokio::net::TcpListener;
use tracing::info;
pub use wae_types::{WaeError, WaeResult};
/// Response body type used throughout this crate: a fully buffered [`Full`] body of [`Bytes`].
pub type Body = Full<Bytes>;
/// Creates a [`Body`] that carries no data.
pub fn empty_body() -> Body {
    let no_data = Bytes::new();
    Full::new(no_data)
}
/// Creates a [`Body`] holding `data`, after converting it into [`Bytes`].
pub fn full_body<B: Into<Bytes>>(data: B) -> Body {
    let payload: Bytes = data.into();
    Full::new(payload)
}
/// Result type returned by the server APIs in this crate; an alias for [`WaeResult`].
pub type HttpsResult<T> = WaeResult<T>;
/// Error type used by the server APIs in this crate; an alias for [`WaeError`].
pub type HttpsError = WaeError;
/// Conversion of a value into an HTTP response with this crate's [`Body`] type.
pub trait IntoResponse {
/// Consumes `self` and produces the response to send to the client.
fn into_response(self) -> Response<Body>;
}
/// A ready-made response is returned unchanged.
impl IntoResponse for Response<Body> {
fn into_response(self) -> Response<Body> {
self
}
}
/// A static string slice becomes a `200 OK` response with a UTF-8 plain-text content type.
impl IntoResponse for &'static str {
    fn into_response(self) -> Response<Body> {
        // `Response::new` starts out as `200 OK` with no headers; the content type is a
        // compile-time constant, so it is inserted directly instead of going through the
        // fallible builder.
        let mut response = Response::new(full_body(self));
        *response.status_mut() = StatusCode::OK;
        response
            .headers_mut()
            .insert(header::CONTENT_TYPE, header::HeaderValue::from_static("text/plain; charset=utf-8"));
        response
    }
}
/// An owned string becomes a `200 OK` response with a UTF-8 plain-text content type.
impl IntoResponse for String {
    fn into_response(self) -> Response<Body> {
        let content_type = header::HeaderValue::from_static("text/plain; charset=utf-8");
        // The string's bytes become the whole response body.
        let mut response = Response::new(full_body(self));
        *response.status_mut() = StatusCode::OK;
        response.headers_mut().insert(header::CONTENT_TYPE, content_type);
        response
    }
}
/// A `(status, value)` pair converts `value` into a response and then overrides its status code.
impl<T: IntoResponse> IntoResponse for (StatusCode, T) {
    fn into_response(self) -> Response<Body> {
        let (status, inner) = self;
        let mut response = inner.into_response();
        // Headers and body produced by `inner` are kept; only the status is replaced.
        *response.status_mut() = status;
        response
    }
}
/// Type-erased, shareable route handler: takes the request parts and a state value of type `S`
/// and returns the response.
type RouteHandlerFn<S> = Arc<dyn Fn(crate::extract::RequestParts, S) -> Response<Body> + Send + Sync + 'static>;
/// HTTP request router that dispatches by method and path and carries a state value of type `S`.
pub struct Router<S = ()> {
/// One path matcher per HTTP method.
routes: std::collections::HashMap<http::Method, matchit::Router<RouteHandlerFn<S>>>,
/// Flat list of registered routes in registration order; used to rebuild the matchers when
/// the router is cloned, merged or nested.
raw_routes: Vec<RouteEntry<S>>,
/// State value handed (by clone) to handlers.
state: S,
}
/// A single registered route: the method, the path pattern and the handler bound to them.
struct RouteEntry<S> {
/// HTTP method this route answers to.
method: http::Method,
/// Path pattern as passed to the `matchit` router.
path: String,
/// Handler invoked when the route matches.
handler: RouteHandlerFn<S>,
}
impl<S: Clone> Clone for RouteEntry<S> {
fn clone(&self) -> Self {
Self { method: self.method.clone(), path: self.path.clone(), handler: self.handler.clone() }
}
}
/// Cloning a router rebuilds every per-method matcher from the recorded route list.
impl<S: Clone> Clone for Router<S> {
    fn clone(&self) -> Self {
        // Start with an empty matcher for every method the original knows about.
        let routes = self.routes.keys().map(|method| (method.clone(), matchit::Router::new())).collect();
        let mut copy = Self { routes, raw_routes: Vec::new(), state: self.state.clone() };

        // Replay the registrations in their original order.
        for entry in self.raw_routes.iter() {
            let table = copy.routes.entry(entry.method.clone()).or_insert_with(matchit::Router::new);
            // Insert failures are deliberately ignored here.
            let _ = table.insert(entry.path.clone(), entry.handler.clone());
            copy.raw_routes.push(entry.clone());
        }
        copy
    }
}
/// The default router has no routes and unit state.
impl Default for Router<()> {
    fn default() -> Self {
        Self { routes: std::collections::HashMap::new(), raw_routes: Vec::new(), state: () }
    }
}
impl Router<()> {
    /// Creates an empty router with unit state.
    pub fn new() -> Self {
        Self::with_state(())
    }
}
impl<S> Router<S> {
/// Creates an empty router that carries `state`.
pub fn with_state(state: S) -> Self {
Self { routes: std::collections::HashMap::new(), raw_routes: Vec::new(), state }
}
/// Returns a shared reference to the router state.
pub fn state(&self) -> &S {
&self.state
}
/// Returns a mutable reference to the router state.
pub fn state_mut(&mut self) -> &mut S {
&mut self.state
}
/// Currently a no-op: all three arguments are ignored and no route is registered.
///
/// NOTE(review): looks like an unfinished hook for type-erased registration — confirm
/// whether callers rely on it before implementing or removing it.
pub fn add_route_inner(
&mut self,
_method: http::Method,
_path: String,
_handler: Box<dyn std::any::Any + Send + Sync + 'static>,
) {
}
}
impl<S> Router<S>
where
S: Clone + Send + Sync + 'static,
{
pub fn add_route<H, T>(&mut self, method: http::Method, path: &str, handler: H)
where
H: Fn(T) -> Response<Body> + Clone + Send + Sync + 'static,
T: crate::extract::FromRequestParts<S, Error = crate::extract::ExtractorError> + 'static,
{
let handler_fn: RouteHandlerFn<S> = Arc::new(move |parts, state| {
let handler = handler.clone();
match T::from_request_parts(&parts, &state) {
Ok(t) => handler(t),
Err(e) => {
let error_msg = e.to_string();
Response::builder()
.status(StatusCode::BAD_REQUEST)
.header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
.body(full_body(error_msg))
.unwrap()
}
}
});
let entry = RouteEntry { method: method.clone(), path: path.to_string(), handler: handler_fn.clone() };
self.raw_routes.push(entry);
let router = self.routes.entry(method).or_insert_with(matchit::Router::new);
let _ = router.insert(path, handler_fn);
}
pub fn merge(mut self, other: Router<S>) -> Self {
for entry in other.raw_routes {
let router = self.routes.entry(entry.method.clone()).or_insert_with(matchit::Router::new);
let _ = router.insert(entry.path.clone(), entry.handler.clone());
self.raw_routes.push(entry);
}
self
}
pub fn nest_service<T>(mut self, prefix: &str, service: T) -> Self
where
T: Into<Router<S>>,
{
let other = service.into();
for entry in other.raw_routes {
let new_path = format!("{}{}", prefix.trim_end_matches('/'), entry.path);
let router = self.routes.entry(entry.method.clone()).or_insert_with(matchit::Router::new);
let _ = router.insert(new_path.clone(), entry.handler.clone());
self.raw_routes.push(RouteEntry { method: entry.method, path: new_path, handler: entry.handler });
}
self
}
}
/// HTTP protocol version(s) a server is configured to offer.
#[derive(Debug, Clone, Copy, Default)]
pub enum HttpVersion {
/// HTTP/1.1 only.
Http1Only,
/// HTTP/2 only.
Http2Only,
/// Both HTTP/1.1 and HTTP/2 (the default).
#[default]
Both,
/// HTTP/3.
// NOTE(review): no HTTP/3-specific listener is visible in this file — confirm where this
// variant is acted on.
Http3,
}
/// HTTP/2 tuning options.
#[derive(Debug, Clone)]
pub struct Http2Config {
    /// Whether HTTP/2 is enabled.
    pub enabled: bool,
    /// Whether server push is enabled.
    pub enable_push: bool,
    /// Maximum number of concurrent streams.
    pub max_concurrent_streams: u32,
    /// Initial flow-control window size of a stream.
    pub initial_stream_window_size: u32,
    /// Maximum frame size.
    pub max_frame_size: u32,
    /// Whether the CONNECT protocol option is enabled.
    pub enable_connect_protocol: bool,
    /// Idle timeout of a stream.
    pub stream_idle_timeout: Duration,
}

impl Default for Http2Config {
    /// HTTP/2 enabled, push and CONNECT protocol off, 256 concurrent streams, a 65535 stream
    /// window, 16384 maximum frame size and a 60 second stream idle timeout.
    fn default() -> Self {
        Self {
            enabled: true,
            enable_push: false,
            max_concurrent_streams: 256,
            initial_stream_window_size: 65535,
            max_frame_size: 16384,
            enable_connect_protocol: false,
            stream_idle_timeout: Duration::from_secs(60),
        }
    }
}

impl Http2Config {
    /// Creates a configuration with the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a default configuration with HTTP/2 switched off.
    pub fn disabled() -> Self {
        let mut config = Self::default();
        config.enabled = false;
        config
    }

    /// Sets whether server push is enabled.
    pub fn with_enable_push(self, enable: bool) -> Self {
        Self { enable_push: enable, ..self }
    }

    /// Sets the maximum number of concurrent streams.
    pub fn with_max_concurrent_streams(self, max: u32) -> Self {
        Self { max_concurrent_streams: max, ..self }
    }

    /// Sets the initial stream window size.
    pub fn with_initial_stream_window_size(self, size: u32) -> Self {
        Self { initial_stream_window_size: size, ..self }
    }

    /// Sets the maximum frame size.
    pub fn with_max_frame_size(self, size: u32) -> Self {
        Self { max_frame_size: size, ..self }
    }

    /// Sets whether the CONNECT protocol option is enabled.
    pub fn with_enable_connect_protocol(self, enable: bool) -> Self {
        Self { enable_connect_protocol: enable, ..self }
    }

    /// Sets the stream idle timeout.
    pub fn with_stream_idle_timeout(self, timeout: Duration) -> Self {
        Self { stream_idle_timeout: timeout, ..self }
    }
}
/// Locations of the TLS certificate and private key.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path of the certificate file.
    pub cert_path: String,
    /// Path of the private key file.
    pub key_path: String,
}

impl TlsConfig {
    /// Creates a TLS configuration from a certificate path and a key path.
    pub fn new(cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
        let cert_path: String = cert_path.into();
        let key_path: String = key_path.into();
        Self { cert_path, key_path }
    }
}
/// HTTP/3 options.
#[derive(Debug, Clone, Default)]
pub struct Http3Config {
    /// Whether HTTP/3 is enabled; `false` by default.
    pub enabled: bool,
}

impl Http3Config {
    /// Creates the default configuration (HTTP/3 off).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a configuration with HTTP/3 switched on.
    pub fn enabled() -> Self {
        let mut config = Self::default();
        config.enabled = true;
        config
    }
}
/// Complete configuration of an [`HttpsServer`].
#[derive(Debug, Clone)]
pub struct HttpsServerConfig {
/// Socket address the server binds to.
pub addr: SocketAddr,
/// Service name, used in the startup log line.
pub service_name: String,
/// Protocol version(s) to offer.
pub http_version: HttpVersion,
/// HTTP/2 options.
pub http2_config: Http2Config,
/// HTTP/3 options.
pub http3_config: Http3Config,
/// TLS certificate and key; `None` serves plain TCP.
pub tls_config: Option<TlsConfig>,
}
impl Default for HttpsServerConfig {
fn default() -> Self {
Self {
addr: "0.0.0.0:3000".parse().unwrap(),
service_name: "wae-https-service".to_string(),
http_version: HttpVersion::Both,
http2_config: Http2Config::default(),
http3_config: Http3Config::default(),
tls_config: None,
}
}
}
/// Builder for an [`HttpsServer`] whose router carries state of type `S`.
pub struct HttpsServerBuilder<S = ()> {
/// Configuration accumulated so far.
config: HttpsServerConfig,
/// Router the built server will dispatch to.
router: Router<S>,
/// Marker for the state type.
_marker: std::marker::PhantomData<S>,
}
impl HttpsServerBuilder<()> {
    /// Creates a builder with the default configuration and an empty, stateless router.
    pub fn new() -> Self {
        let config = HttpsServerConfig::default();
        let router = Router::new();
        Self { config, router, _marker: std::marker::PhantomData }
    }
}
/// The default builder is the one returned by [`HttpsServerBuilder::new`].
impl Default for HttpsServerBuilder<()> {
fn default() -> Self {
Self::new()
}
}
impl<S> HttpsServerBuilder<S>
where
S: Clone + Send + Sync + 'static,
{
pub fn addr(mut self, addr: SocketAddr) -> Self {
self.config.addr = addr;
self
}
pub fn service_name(mut self, name: impl Into<String>) -> Self {
self.config.service_name = name.into();
self
}
pub fn router<T>(mut self, router: T) -> Self
where
T: Into<Router<S>>,
{
self.router = router.into();
self
}
pub fn merge_router(mut self, router: Router<S>) -> Self {
self.router = self.router.merge(router);
self
}
pub fn http_version(mut self, version: HttpVersion) -> Self {
self.config.http_version = version;
self
}
pub fn http2_config(mut self, config: Http2Config) -> Self {
self.config.http2_config = config;
self
}
pub fn http3_config(mut self, config: Http3Config) -> Self {
self.config.http3_config = config;
self
}
pub fn tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
self.config.tls_config = Some(TlsConfig::new(cert_path, key_path));
self
}
pub fn tls_config(mut self, config: TlsConfig) -> Self {
self.config.tls_config = Some(config);
self
}
pub fn build(self) -> HttpsServer<S> {
HttpsServer { config: self.config, router: self.router, _marker: std::marker::PhantomData }
}
}
/// HTTP(S) server that serves a [`Router`] over TCP, with optional TLS.
pub struct HttpsServer<S = ()> {
/// Server configuration.
config: HttpsServerConfig,
/// Router used to answer requests.
router: Router<S>,
/// Marker for the state type.
_marker: std::marker::PhantomData<S>,
}
impl<S> HttpsServer<S>
where
S: Clone + Send + Sync + 'static,
{
pub async fn serve(self) -> HttpsResult<()> {
let addr = self.config.addr;
let service_name = self.config.service_name.clone();
let protocol_info = self.get_protocol_info();
let tls_config = self.config.tls_config.clone();
let listener =
TcpListener::bind(addr).await.map_err(|e| WaeError::internal(format!("Failed to bind address: {}", e)))?;
info!("{} {} server starting on {}", service_name, protocol_info, addr);
match tls_config {
Some(tls_config) => self.serve_tls(listener, &tls_config).await,
None => self.serve_plain(listener).await,
}
}
async fn serve_plain(self, listener: TcpListener) -> HttpsResult<()> {
loop {
let (stream, _addr) = listener.accept().await.map_err(|e| WaeError::internal(format!("Accept error: {}", e)))?;
let router = self.router.clone();
tokio::spawn(async move {
let service = RouterService::new(router);
let io = hyper_util::rt::tokio::TokioIo::new(stream);
let _ = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
.serve_connection(io, service)
.await;
});
}
}
async fn serve_tls(self, listener: TcpListener, tls_config: &TlsConfig) -> HttpsResult<()> {
let enable_http2 = matches!(self.config.http_version, HttpVersion::Http2Only | HttpVersion::Both);
let acceptor = crate::tls::create_tls_acceptor_with_http2(&tls_config.cert_path, &tls_config.key_path, enable_http2)?;
loop {
let (stream, _addr) = listener.accept().await.map_err(|e| WaeError::internal(format!("Accept error: {}", e)))?;
let acceptor = acceptor.clone();
let router = self.router.clone();
tokio::spawn(async move {
let tls_stream = match acceptor.accept(stream).await {
Ok(s) => s,
Err(e) => {
tracing::error!("TLS handshake error: {}", e);
return;
}
};
let service = RouterService::new(router);
let io = hyper_util::rt::tokio::TokioIo::new(tls_stream);
let _ = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
.serve_connection(io, service)
.await;
});
}
}
fn get_protocol_info(&self) -> String {
let tls_info = if self.config.tls_config.is_some() { "S" } else { "" };
let version_info = match self.config.http_version {
HttpVersion::Http1Only => "HTTP/1.1",
HttpVersion::Http2Only => "HTTP/2",
HttpVersion::Both => "HTTP/1.1+HTTP/2",
HttpVersion::Http3 => "HTTP/3",
};
format!("{}{}", version_info, tls_info)
}
}
/// JSON envelope for API responses.
#[derive(Debug, serde::Serialize)]
pub struct ApiResponse<T> {
/// `true` for successful responses, `false` for errors.
pub success: bool,
/// Payload of a successful response.
pub data: Option<T>,
/// Error details of a failed response.
pub error: Option<ApiErrorBody>,
/// Optional trace identifier attached to the response.
pub trace_id: Option<String>,
}
/// Error details carried by a failed [`ApiResponse`].
#[derive(Debug, serde::Serialize)]
pub struct ApiErrorBody {
/// Error code.
pub code: String,
/// Error message.
pub message: String,
}
impl<T: serde::Serialize> ApiResponse<T> {
pub fn into_response(self) -> Response<Body> {
let status = if self.success { StatusCode::OK } else { StatusCode::BAD_REQUEST };
let body = serde_json::to_string(&self).unwrap_or_default();
Response::builder()
.status(status)
.header(header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(body)))
.unwrap()
}
}
impl<T> IntoResponse for ApiResponse<T>
where
T: serde::Serialize,
{
fn into_response(self) -> Response<Body> {
self.into_response()
}
}
impl<T> ApiResponse<T>
where
    T: serde::Serialize,
{
    /// Creates a successful response carrying `data`.
    pub fn success(data: T) -> Self {
        Self { success: true, data: Some(data), error: None, trace_id: None }
    }

    /// Creates a successful response carrying `data` and a trace identifier.
    pub fn success_with_trace(data: T, trace_id: impl Into<String>) -> Self {
        let base = Self::success(data);
        Self { trace_id: Some(trace_id.into()), ..base }
    }

    /// Creates an error response with the given code and message.
    pub fn error(code: impl Into<String>, message: impl Into<String>) -> Self {
        let details = ApiErrorBody { code: code.into(), message: message.into() };
        Self { success: false, data: None, error: Some(details), trace_id: None }
    }

    /// Creates an error response with the given code, message and trace identifier.
    pub fn error_with_trace(code: impl Into<String>, message: impl Into<String>, trace_id: impl Into<String>) -> Self {
        let base = Self::error(code, message);
        Self { trace_id: Some(trace_id.into()), ..base }
    }
}
pub fn static_files_router(base_path: impl AsRef<Path>, prefix: &str) -> Router {
let mut router = Router::new();
let base_path = base_path.as_ref().to_path_buf();
let prefix = prefix.to_string();
router.add_route(http::Method::GET, &format!("{}/*path", prefix), move |parts: crate::extract::RequestParts| {
let path = parts.uri.path();
let file_path = if let Some(stripped) = path.strip_prefix(&prefix) {
base_path.join(stripped.trim_start_matches('/'))
}
else {
return Response::builder().status(StatusCode::NOT_FOUND).body(empty_body()).unwrap();
};
if !file_path.exists() || !file_path.is_file() {
return Response::builder().status(StatusCode::NOT_FOUND).body(empty_body()).unwrap();
}
let content = match std::fs::read(&file_path) {
Ok(c) => c,
Err(_) => {
return Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(empty_body()).unwrap();
}
};
let mime_type = mime_guess::from_path(&file_path).first_or_octet_stream().to_string();
Response::builder().status(StatusCode::OK).header(header::CONTENT_TYPE, mime_type).body(full_body(content)).unwrap()
});
router
}
pub struct RouterService<S = ()> {
router: Router<S>,
}
impl<S: Clone> Clone for RouterService<S> {
fn clone(&self) -> Self {
Self { router: self.router.clone() }
}
}
impl<S> From<Router<S>> for RouterService<S> {
fn from(router: Router<S>) -> Self {
Self { router }
}
}
impl<S> RouterService<S>
where
S: Clone + Send + Sync + 'static,
{
pub fn new(router: Router<S>) -> Self {
Self { router }
}
pub async fn handle_request(&self, request: hyper::Request<hyper::body::Incoming>) -> http::Response<Body> {
let (parts, _body) = request.into_parts();
let method = parts.method.clone();
let uri = parts.uri.clone();
let version = parts.version;
let headers = parts.headers.clone();
let mut request_parts = crate::extract::RequestParts::new(method.clone(), uri.clone(), version, headers);
let path = uri.path();
let Some(method_router) = self.router.routes.get(&method)
else {
return Response::builder().status(StatusCode::METHOD_NOT_ALLOWED).body(empty_body()).unwrap();
};
let match_result = method_router.at(path);
let Ok(matched) = match_result
else {
return Response::builder().status(StatusCode::NOT_FOUND).body(empty_body()).unwrap();
};
request_parts.path_params = matched.params.iter().map(|(k, v)| (k.to_string(), v.to_string())).collect();
let handler = matched.value;
let state = self.router.state.clone();
handler(request_parts, state)
}
}
impl<S> hyper::service::Service<hyper::Request<hyper::body::Incoming>> for RouterService<S>
where
S: Clone + Send + Sync + 'static,
{
type Response = http::Response<Body>;
type Error = std::convert::Infallible;
type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
let this = self.clone();
Box::pin(async move {
let response = this.handle_request(req).await;
Ok(response)
})
}
}