#![doc = include_str!("readme.md")]
#![warn(missing_docs)]
pub mod error;
pub mod extract;
pub mod middleware;
pub mod response;
pub mod router;
pub mod template;
pub mod tls;
pub use wae_session as session;
use http::{Response, StatusCode, header};
use http_body_util::Full;
use hyper::body::Bytes;
use std::{net::SocketAddr, path::Path, time::Duration};
use tokio::net::TcpListener;
use tracing::info;
pub use wae_types::{WaeError, WaeResult};
/// HTTP body type used by this crate: a [`Full`] body holding [`Bytes`].
pub type Body = Full<Bytes>;
/// Creates a [`Body`] built from an empty [`Bytes`] buffer.
pub fn empty_body() -> Body {
    let no_bytes = Bytes::new();
    Full::new(no_bytes)
}
/// Wraps `data` in a [`Body`].
///
/// `data` may be any value that converts into [`Bytes`]; the conversion is
/// performed once and the resulting buffer becomes the body's content.
pub fn full_body<B>(data: B) -> Body
where
    B: Into<Bytes>,
{
    let payload: Bytes = data.into();
    Full::new(payload)
}
/// Result type for this crate's fallible operations; an alias for [`WaeResult`].
pub type HttpsResult<T> = WaeResult<T>;
/// Error type for this crate; an alias for [`WaeError`].
pub type HttpsError = WaeError;
/// Router handle handed to the HTTPS server.
///
/// The type currently has no fields, so it does not record any routes or
/// services; the methods below only thread the value through.
#[derive(Clone, Default)]
pub struct Router {}

impl Router {
    /// Creates an empty router.
    pub fn new() -> Self {
        Router {}
    }

    /// Merges `_other` into this router.
    ///
    /// `_other` is currently discarded and `self` is returned unchanged.
    pub fn merge(self, _other: Self) -> Self {
        self
    }

    /// Mounts `_service` under `_prefix`.
    ///
    /// Both arguments are currently discarded and `self` is returned unchanged.
    pub fn nest_service<S>(self, _prefix: &str, _service: S) -> Self {
        self
    }
}
/// HTTP protocol version(s) selected for the server.
#[derive(Debug, Clone, Copy, Default)]
pub enum HttpVersion {
/// HTTP/1.1 only.
Http1Only,
/// HTTP/2 only.
Http2Only,
/// Both HTTP/1.1 and HTTP/2; this is the default variant.
#[default]
Both,
/// HTTP/3.
Http3,
}
/// HTTP/2 tuning options.
///
/// The field names follow the HTTP/2 settings vocabulary. How the values are
/// applied to live connections is not visible in this file.
#[derive(Debug, Clone)]
pub struct Http2Config {
    /// Whether HTTP/2 is enabled. Defaults to `true`.
    pub enabled: bool,
    /// Server-push toggle. Defaults to `false`.
    pub enable_push: bool,
    /// Limit on concurrent streams. Defaults to `256`.
    pub max_concurrent_streams: u32,
    /// Initial stream window size (presumably in bytes). Defaults to `65535`.
    pub initial_stream_window_size: u32,
    /// Maximum frame size (presumably in bytes). Defaults to `16384`.
    pub max_frame_size: u32,
    /// CONNECT-protocol toggle. Defaults to `false`.
    pub enable_connect_protocol: bool,
    /// Idle timeout for a stream. Defaults to 60 seconds.
    pub stream_idle_timeout: Duration,
}

impl Default for Http2Config {
    /// Returns the configuration with HTTP/2 enabled and the defaults listed
    /// on each field.
    fn default() -> Self {
        Http2Config {
            enabled: true,
            enable_push: false,
            max_concurrent_streams: 256,
            initial_stream_window_size: 65_535,
            max_frame_size: 16_384,
            enable_connect_protocol: false,
            stream_idle_timeout: Duration::from_secs(60),
        }
    }
}

impl Http2Config {
    /// Creates the default configuration; equivalent to [`Http2Config::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a configuration with `enabled` set to `false` and every other
    /// field at its default.
    pub fn disabled() -> Self {
        Http2Config { enabled: false, ..Default::default() }
    }

    /// Returns the configuration with `enable_push` replaced by `enable`.
    pub fn with_enable_push(self, enable: bool) -> Self {
        Self { enable_push: enable, ..self }
    }

    /// Returns the configuration with `max_concurrent_streams` replaced by `max`.
    pub fn with_max_concurrent_streams(self, max: u32) -> Self {
        Self { max_concurrent_streams: max, ..self }
    }

    /// Returns the configuration with `initial_stream_window_size` replaced by `size`.
    pub fn with_initial_stream_window_size(self, size: u32) -> Self {
        Self { initial_stream_window_size: size, ..self }
    }

    /// Returns the configuration with `max_frame_size` replaced by `size`.
    ///
    /// The value is stored as given; no range check is applied here.
    pub fn with_max_frame_size(self, size: u32) -> Self {
        Self { max_frame_size: size, ..self }
    }

    /// Returns the configuration with `enable_connect_protocol` replaced by `enable`.
    pub fn with_enable_connect_protocol(self, enable: bool) -> Self {
        Self { enable_connect_protocol: enable, ..self }
    }

    /// Returns the configuration with `stream_idle_timeout` replaced by `timeout`.
    pub fn with_stream_idle_timeout(self, timeout: Duration) -> Self {
        Self { stream_idle_timeout: timeout, ..self }
    }
}
/// Locations of the TLS certificate and key.
///
/// The paths are stored as given; nothing in this type opens or validates them.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path of the certificate.
    pub cert_path: String,
    /// Path of the key.
    pub key_path: String,
}

impl TlsConfig {
    /// Creates a TLS configuration from a certificate path and a key path.
    pub fn new(cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
        let cert_path: String = cert_path.into();
        let key_path: String = key_path.into();
        TlsConfig { cert_path, key_path }
    }
}
/// HTTP/3 options.
#[derive(Debug, Clone, Default)]
pub struct Http3Config {
    /// Whether HTTP/3 is enabled. Defaults to `false`.
    pub enabled: bool,
}

impl Http3Config {
    /// Creates the default configuration, in which HTTP/3 is not enabled.
    pub fn new() -> Self {
        Http3Config { enabled: false }
    }

    /// Creates a configuration with HTTP/3 enabled.
    pub fn enabled() -> Self {
        Http3Config { enabled: true }
    }
}
/// Complete configuration of an [`HttpsServer`].
#[derive(Debug, Clone)]
pub struct HttpsServerConfig {
    /// Socket address the server binds to. Defaults to `0.0.0.0:3000`.
    pub addr: SocketAddr,
    /// Service name used in the startup log line. Defaults to `"wae-https-service"`.
    pub service_name: String,
    /// Selected HTTP protocol version(s). Defaults to [`HttpVersion::Both`].
    pub http_version: HttpVersion,
    /// HTTP/2 options.
    pub http2_config: Http2Config,
    /// HTTP/3 options.
    pub http3_config: Http3Config,
    /// TLS certificate/key locations; `None` (the default) means no TLS.
    pub tls_config: Option<TlsConfig>,
}

impl Default for HttpsServerConfig {
    /// Returns a configuration bound to `0.0.0.0:3000` without TLS, with the
    /// default HTTP/2 and HTTP/3 options.
    fn default() -> Self {
        // Built from its parts so no string has to be parsed at runtime.
        let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
        HttpsServerConfig {
            addr,
            service_name: String::from("wae-https-service"),
            http_version: HttpVersion::Both,
            http2_config: Http2Config::default(),
            http3_config: Http3Config::default(),
            tls_config: None,
        }
    }
}
pub struct HttpsServerBuilder {
config: HttpsServerConfig,
router: Router,
}
impl HttpsServerBuilder {
pub fn new() -> Self {
Self { config: HttpsServerConfig::default(), router: Router::new() }
}
pub fn addr(mut self, addr: SocketAddr) -> Self {
self.config.addr = addr;
self
}
pub fn service_name(mut self, name: impl Into<String>) -> Self {
self.config.service_name = name.into();
self
}
pub fn router(mut self, router: Router) -> Self {
self.router = router;
self
}
pub fn merge_router(mut self, router: Router) -> Self {
self.router = self.router.merge(router);
self
}
pub fn http_version(mut self, version: HttpVersion) -> Self {
self.config.http_version = version;
self
}
pub fn http2_config(mut self, config: Http2Config) -> Self {
self.config.http2_config = config;
self
}
pub fn http3_config(mut self, config: Http3Config) -> Self {
self.config.http3_config = config;
self
}
pub fn tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
self.config.tls_config = Some(TlsConfig::new(cert_path, key_path));
self
}
pub fn tls_config(mut self, config: TlsConfig) -> Self {
self.config.tls_config = Some(config);
self
}
pub fn build(self) -> HttpsServer {
HttpsServer { config: self.config, router: self.router }
}
}
impl Default for HttpsServerBuilder {
fn default() -> Self {
Self::new()
}
}
/// HTTPS server produced by [`HttpsServerBuilder::build`].
pub struct HttpsServer {
    config: HttpsServerConfig,
    router: Router,
}

impl HttpsServer {
    /// Binds the configured address, logs a startup line and runs the server.
    ///
    /// The TLS path is taken when a [`TlsConfig`] is present in the
    /// configuration, the plain path otherwise.
    ///
    /// # Errors
    ///
    /// Returns an internal [`WaeError`] when the TCP listener cannot be bound
    /// to the configured address.
    pub async fn serve(self) -> HttpsResult<()> {
        let addr = self.config.addr;
        // Cloned because `self` is moved into `serve_tls` while the TLS
        // settings are still borrowed by that call.
        let tls_config = self.config.tls_config.clone();
        let listener = TcpListener::bind(addr)
            .await
            .map_err(|e| WaeError::internal(format!("Failed to bind address: {}", e)))?;
        info!("{} {} server starting on {}", self.config.service_name, self.get_protocol_info(), addr);
        match tls_config {
            Some(tls_config) => self.serve_tls(listener, &tls_config).await,
            None => self.serve_plain(listener).await,
        }
    }

    /// Serves without TLS.
    ///
    /// Placeholder: the listener stays bound for as long as this future is
    /// alive, but no connection is accepted and the future never completes.
    async fn serve_plain(self, _listener: TcpListener) -> HttpsResult<()> {
        loop {
            tokio::time::sleep(Duration::from_secs(3600)).await;
        }
    }

    /// Serves with TLS.
    ///
    /// Placeholder: the listener stays bound for as long as this future is
    /// alive, but `_tls_config` is not read, no connection is accepted and the
    /// future never completes.
    async fn serve_tls(self, _listener: TcpListener, _tls_config: &TlsConfig) -> HttpsResult<()> {
        loop {
            tokio::time::sleep(Duration::from_secs(3600)).await;
        }
    }

    /// Describes the configured protocol for the startup log line, e.g.
    /// `HTTP/2` or `HTTP/1.1+HTTP/2 over TLS`.
    fn get_protocol_info(&self) -> String {
        let version_info = match self.config.http_version {
            HttpVersion::Http1Only => "HTTP/1.1",
            HttpVersion::Http2Only => "HTTP/2",
            HttpVersion::Both => "HTTP/1.1+HTTP/2",
            HttpVersion::Http3 => "HTTP/3",
        };
        // The TLS marker is a separate suffix: gluing a bare "S" onto the
        // version produced unreadable labels such as "HTTP/1.1+HTTP/2S".
        if self.config.tls_config.is_some() {
            format!("{} over TLS", version_info)
        } else {
            version_info.to_string()
        }
    }
}
/// JSON response envelope serialized with `serde`.
///
/// The `success*` constructors fill `data` and leave `error` empty; the
/// `error*` constructors do the opposite.
#[derive(Debug, serde::Serialize)]
pub struct ApiResponse<T> {
/// `true` when built by a success constructor, `false` when built by an error constructor.
pub success: bool,
/// Payload of a successful response.
pub data: Option<T>,
/// Error details of a failed response.
pub error: Option<ApiErrorBody>,
/// Optional trace identifier attached by the `*_with_trace` constructors.
pub trace_id: Option<String>,
}
/// Error details carried in the `error` field of an [`ApiResponse`].
#[derive(Debug, serde::Serialize)]
pub struct ApiErrorBody {
/// Error code string.
pub code: String,
/// Error message string.
pub message: String,
}
impl<T: serde::Serialize> ApiResponse<T> {
    /// Converts this envelope into an HTTP response with a JSON body.
    ///
    /// The status is `200 OK` when `success` is `true` and `400 Bad Request`
    /// otherwise. If the envelope cannot be serialized to JSON, the response
    /// is `500 Internal Server Error` with a fixed JSON error body, instead of
    /// an empty body sent under a success or client-error status.
    pub fn into_response(self) -> Response<Body> {
        // Same field layout as a serialized `ApiResponse`, so clients can
        // parse it like any other error envelope.
        // NOTE(review): the "INTERNAL_ERROR" code is a local choice; align it
        // with the project's error-code scheme if one exists.
        const SERIALIZATION_FAILURE_BODY: &str = r#"{"success":false,"data":null,"error":{"code":"INTERNAL_ERROR","message":"Failed to serialize response body"},"trace_id":null}"#;

        let (status, body) = match serde_json::to_string(&self) {
            Ok(json) => {
                let status = if self.success { StatusCode::OK } else { StatusCode::BAD_REQUEST };
                (status, Bytes::from(json))
            }
            Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, Bytes::from(SERIALIZATION_FAILURE_BODY)),
        };

        Response::builder()
            .status(status)
            .header(header::CONTENT_TYPE, "application/json")
            .body(Full::new(body))
            .expect("a constant status and header always form a valid response")
    }
}
/// Returns a router intended to serve static files from `_path` under `_prefix`.
///
/// Placeholder: both arguments are currently unused and an empty [`Router`]
/// is returned.
pub fn static_files_router(_path: impl AsRef<Path>, _prefix: &str) -> Router {
    Router::default()
}
impl<T> ApiResponse<T>
where
T: serde::Serialize,
{
pub fn success(data: T) -> Self {
Self { success: true, data: Some(data), error: None, trace_id: None }
}
pub fn success_with_trace(data: T, trace_id: impl Into<String>) -> Self {
Self { success: true, data: Some(data), error: None, trace_id: Some(trace_id.into()) }
}
pub fn error(code: impl Into<String>, message: impl Into<String>) -> Self {
Self {
success: false,
data: None,
error: Some(ApiErrorBody { code: code.into(), message: message.into() }),
trace_id: None,
}
}
pub fn error_with_trace(code: impl Into<String>, message: impl Into<String>, trace_id: impl Into<String>) -> Self {
Self {
success: false,
data: None,
error: Some(ApiErrorBody { code: code.into(), message: message.into() }),
trace_id: Some(trace_id.into()),
}
}
}