pub mod config;
pub mod connection;
pub mod executor;
pub mod performance;
pub mod pool;
pub mod protocol;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub mod tls;
pub use config::{Http2Config, ServerConfig};
pub use connection::*;
pub use executor::TokioExecutor;
pub use performance::{OptimizedTcpListener, PerformanceConfig, PerformanceMetrics};
pub use pool::{ObjectPools, PoolConfig};
pub use protocol::{HttpProtocol, ProtocolDetector};
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub use tls::{TlsConfig, TlsError, TlsVersion};
use crate::{Request, Response, Router};
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::server::conn::{http1, http2};
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use parking_lot::RwLock;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::time::{interval, timeout};
use tracing::{debug, info, warn};
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
use tokio_rustls::TlsAcceptor;
/// An HTTP server that serves a shared [`Router`] on a single socket address.
///
/// Constructed with `Server::new` / `Server::with_config` / `Server::max_rps`, adjusted with the
/// `with_*` builder methods, and run with `Server::ignitia`.
pub struct Server {
    /// Router shared (via `Arc`) by every connection task.
    router: Arc<Router>,
    /// Address the listener binds to.
    addr: SocketAddr,
    /// Protocol, redirect and body-limit settings; cloned into each connection task.
    config: ServerConfig,
    /// Listener tuning passed to `OptimizedTcpListener::bind`.
    perf_config: PerformanceConfig,
    /// Request metrics, exposed to callers through `Server::metrics`.
    metrics: Arc<PerformanceMetrics>,
    /// Flags and counters shared with the accept loop, connection tasks and metrics task.
    state: Arc<ServerState>,
    /// When `Some`, accepted connections perform a TLS handshake before HTTP is served.
    #[cfg(feature = "tls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
    tls_acceptor: Option<TlsAcceptor>,
}
#[derive(Debug)]
struct ServerState {
running: AtomicBool,
active_connections: AtomicUsize,
total_requests: AtomicU64,
start_time: RwLock<Option<Instant>>,
}
impl ServerState {
fn new() -> Self {
Self {
running: AtomicBool::new(false),
active_connections: AtomicUsize::new(0),
total_requests: AtomicU64::new(0),
start_time: RwLock::new(None),
}
}
}
impl Server {
pub fn new(router: Router, addr: SocketAddr) -> Self {
let perf_config = PerformanceConfig::max_rps();
let metrics = PerformanceMetrics::new();
Self {
router: Arc::new(router),
addr,
config: ServerConfig::default(),
perf_config,
metrics,
state: Arc::new(ServerState::new()),
#[cfg(feature = "tls")]
tls_acceptor: None,
}
}
pub fn with_config(router: Router, addr: SocketAddr, config: ServerConfig) -> Self {
let mut server = Self::new(router, addr);
server.config = config;
server
}
pub fn max_rps(router: Router, addr: SocketAddr) -> Self {
let perf_config = PerformanceConfig::max_rps();
let metrics = PerformanceMetrics::new();
Self {
router: Arc::new(router),
addr,
config: ServerConfig::default(),
perf_config,
metrics,
state: Arc::new(ServerState::new()),
#[cfg(feature = "tls")]
tls_acceptor: None,
}
}
pub fn with_performance_config(mut self, config: PerformanceConfig) -> Self {
self.perf_config = config;
self
}
pub fn with_server_config(mut self, config: ServerConfig) -> Self {
self.config = config;
self
}
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub fn with_tls(mut self, tls_config: tls::TlsConfig) -> Result<Self, tls::TlsError> {
let acceptor = tls_config.build()?;
self.tls_acceptor = Some(acceptor);
self.config.tls = Some(tls_config);
Ok(self)
}
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub fn enable_https(
self,
cert_file: impl Into<String>,
key_file: impl Into<String>,
) -> Result<Self, tls::TlsError> {
let tls_config = tls::TlsConfig::new(cert_file, key_file);
self.with_tls(tls_config)
}
#[cfg(all(feature = "tls", feature = "self-signed"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "tls", feature = "self-signed"))))]
pub fn with_self_signed_cert(self, domain: &str) -> Result<Self, tls::TlsError> {
let (cert_pem, key_pem) = tls::TlsConfig::generate_self_signed(domain)?;
let tls_config = tls::TlsConfig::new("self_signed_cert.pem", "self_signed_key.pem");
self.with_tls(tls_config)
}
pub fn redirect_to_https(mut self, https_port: u16) -> Self {
self.config = self.config.redirect_to_https(https_port);
self
}
pub fn metrics(&self) -> Arc<PerformanceMetrics> {
Arc::clone(&self.metrics)
}
pub fn uptime(&self) -> Option<Duration> {
self.state.start_time.read().map(|start| start.elapsed())
}
pub fn active_connections(&self) -> usize {
self.state.active_connections.load(Ordering::Relaxed)
}
pub fn total_requests(&self) -> u64 {
self.state.total_requests.load(Ordering::Relaxed)
}
pub async fn ignitia(self) -> Result<(), Box<dyn std::error::Error>> {
self.state.running.store(true, Ordering::Relaxed);
*self.state.start_time.write() = Some(Instant::now());
let listener = OptimizedTcpListener::bind(self.addr, self.perf_config.clone()).await?;
let listener_metrics = listener.metrics();
#[cfg(feature = "tls")]
let protocol_info = if self.tls_acceptor.is_some() {
if self.config.http2.enabled && self.config.http1_enabled {
"HTTPS (HTTP/1.1 + HTTP/2)"
} else if self.config.http2.enabled {
"HTTPS (HTTP/2)"
} else {
"HTTPS (HTTP/1.1)"
}
} else if self.config.http2.enabled && self.config.http1_enabled {
"HTTP (HTTP/1.1 + HTTP/2)"
} else if self.config.http2.enabled {
"HTTP (HTTP/2)"
} else {
"HTTP (HTTP/1.1)"
};
#[cfg(not(feature = "tls"))]
let protocol_info = if self.config.http2.enabled && self.config.http1_enabled {
"HTTP (HTTP/1.1 + HTTP/2)"
} else if self.config.http2.enabled {
"HTTP (HTTP/2)"
} else {
"HTTP (HTTP/1.1)"
};
#[cfg(feature = "tls")]
let scheme = if self.tls_acceptor.is_some() {
"https"
} else {
"http"
};
#[cfg(not(feature = "tls"))]
let scheme = "http";
info!(
"🔥 Ignitia server blazing on {}://{} ({})",
scheme, self.addr, protocol_info
);
self.start_metrics_collection(Arc::clone(&listener_metrics));
loop {
if !self.state.running.load(Ordering::Relaxed) {
info!("🛑 Server shutdown requested");
break;
}
let (stream, addr) = match listener.accept().await {
Ok(conn) => conn,
Err(e) => {
warn!("❌ Failed to accept connection: {}", e);
continue;
}
};
let router = Arc::clone(&self.router);
let config = self.config.clone();
let metrics = Arc::clone(&self.metrics);
let state = Arc::clone(&self.state);
#[cfg(feature = "tls")]
let tls_acceptor = self.tls_acceptor.clone();
tokio::spawn(async move {
state.active_connections.fetch_add(1, Ordering::Relaxed);
let connection_start = Instant::now();
let result = {
#[cfg(feature = "tls")]
if let Some(acceptor) = tls_acceptor {
handle_tls_connection(
stream,
router,
config,
acceptor,
addr,
metrics.clone(),
state.clone(),
)
.await
} else {
let io = TokioIo::new(stream);
handle_connection(io, router, config, addr, metrics.clone(), state.clone())
.await
}
#[cfg(not(feature = "tls"))]
{
let io = TokioIo::new(stream);
handle_connection(io, router, config, addr, metrics.clone(), state.clone())
.await
}
};
if let Err(err) = result {
debug!("🔌 Connection error from {}: {}", addr, err);
}
let connection_duration = connection_start.elapsed();
metrics.record_request(connection_duration);
state.active_connections.fetch_sub(1, Ordering::Relaxed);
});
}
Ok(())
}
fn start_metrics_collection(&self, _listener_metrics: Arc<PerformanceMetrics>) {
let server_state = Arc::clone(&self.state);
let server_metrics = Arc::clone(&self.metrics);
tokio::spawn(async move {
let mut interval = interval(Duration::from_secs(1));
let mut last_requests = 0u64;
loop {
interval.tick().await;
if !server_state.running.load(Ordering::Relaxed) {
break;
}
let current_requests = server_state.total_requests.load(Ordering::Relaxed);
let rps = current_requests.saturating_sub(last_requests);
last_requests = current_requests;
server_metrics
.requests_per_second
.store(rps, Ordering::Relaxed);
if current_requests % 10000 == 0 && current_requests > 0 {
info!(
"📈 Performance: {} RPS, {} active connections, {} total requests",
rps,
server_state.active_connections.load(Ordering::Relaxed),
current_requests
);
}
}
});
}
pub async fn shutdown(&self) {
info!("🛑 Initiating graceful shutdown...");
self.state.running.store(false, Ordering::Relaxed);
let shutdown_timeout = Duration::from_secs(30);
let start = Instant::now();
while self.state.active_connections.load(Ordering::Relaxed) > 0
&& start.elapsed() < shutdown_timeout
{
tokio::time::sleep(Duration::from_millis(100)).await;
}
let remaining = self.state.active_connections.load(Ordering::Relaxed);
if remaining > 0 {
warn!("⚠️ Forcing shutdown with {} active connections", remaining);
} else {
info!("✅ All connections closed gracefully");
}
info!("🏁 Server shutdown complete");
}
}
/// Completes a TLS handshake on `stream` and serves HTTP over the encrypted connection.
///
/// The protocol is chosen from the ALPN result via `ProtocolDetector::detect_from_alpn`. HTTP/2
/// is served with `serve_http2_connection`. HTTP/1.1 (and the `Auto` fallback) is always served
/// with upgrades enabled. Previously upgrades were only enabled when HTTP/2 was also configured,
/// which made `hyper::upgrade::on` — and therefore WebSocket handshakes — fail on HTTP/1.1-only
/// servers.
///
/// # Errors
/// Returns TLS handshake errors and errors from serving the connection.
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
async fn handle_tls_connection(
    stream: tokio::net::TcpStream,
    router: Arc<Router>,
    config: ServerConfig,
    acceptor: TlsAcceptor,
    addr: SocketAddr,
    metrics: Arc<PerformanceMetrics>,
    state: Arc<ServerState>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let tls_stream = acceptor.accept(stream).await?;
    // Copy the negotiated ALPN protocol out before `tls_stream` is moved into the IO adapter.
    let alpn_protocol = {
        let (_, connection_info) = tls_stream.get_ref();
        connection_info.alpn_protocol().map(|p| p.to_vec())
    };
    debug!(
        "🔐 TLS connection from {} - ALPN: {:?}",
        addr,
        alpn_protocol.as_ref().map(|p| String::from_utf8_lossy(p))
    );
    let protocol = ProtocolDetector::detect_from_alpn(alpn_protocol.as_deref());
    let io = TokioIo::new(tls_stream);

    // Copy the limit out so the service closure does not capture `config`.
    let max_body_size = config.max_request_body_size;
    let service = service_fn(move |req| {
        let router = Arc::clone(&router);
        let metrics = Arc::clone(&metrics);
        let state = Arc::clone(&state);
        async move { handle_request(router, req, max_body_size, metrics, state).await }
    });

    match protocol {
        HttpProtocol::Http2 => {
            serve_http2_connection(io, service, config.http2).await?;
        }
        HttpProtocol::Http1 | HttpProtocol::Auto => {
            let mut builder = http1::Builder::new();
            builder.half_close(true);
            builder.timer(hyper_util::rt::TokioTimer::new());
            // `with_upgrades` is required for `hyper::upgrade::on` (WebSocket) to succeed.
            builder
                .serve_connection(io, service)
                .with_upgrades()
                .await?;
        }
    }
    Ok(())
}
/// Serves HTTP on an already-established plaintext connection.
///
/// * If `redirect_http_to_https` is set, every request is answered with a redirect (see
///   `handle_http_redirect`).
/// * If auto detection is on and both HTTP/1.1 and HTTP/2 are enabled, serves HTTP/1.1 with
///   half-close and upgrades.
/// * Otherwise, if HTTP/2 is enabled, serves HTTP/2 (the prior-knowledge and plain HTTP/2 cases
///   were previously two identical branches).
/// * Otherwise serves HTTP/1.1. Upgrades are now enabled here too, which WebSocket handshakes
///   require.
///
/// # Errors
/// Returns errors from serving the connection.
async fn handle_connection<I>(
    io: TokioIo<I>,
    router: Arc<Router>,
    config: ServerConfig,
    _addr: SocketAddr,
    metrics: Arc<PerformanceMetrics>,
    state: Arc<ServerState>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    if config.redirect_http_to_https {
        return handle_http_redirect(io, config.https_port.unwrap_or(443)).await;
    }

    // Copy the limit out so the service closure does not capture `config`.
    let max_body_size = config.max_request_body_size;
    let service = service_fn(move |req| {
        let router = Arc::clone(&router);
        let metrics = Arc::clone(&metrics);
        let state = Arc::clone(&state);
        async move { handle_request(router, req, max_body_size, metrics, state).await }
    });

    if config.auto_protocol_detection && config.http1_enabled && config.http2.enabled {
        // NOTE(review): despite the "auto detection" name, this branch only speaks HTTP/1.1; an
        // HTTP/2 prior-knowledge client will fail here. True detection would need something
        // like hyper_util's `server::conn::auto` builder — TODO confirm intent.
        let mut builder = http1::Builder::new();
        builder.half_close(true);
        builder.timer(hyper_util::rt::TokioTimer::new());
        builder
            .serve_connection(io, service)
            .with_upgrades()
            .await?;
    } else if config.http2.enabled {
        serve_http2_connection(io, service, config.http2).await?;
    } else {
        let mut builder = http1::Builder::new();
        builder.timer(hyper_util::rt::TokioTimer::new());
        builder
            .serve_connection(io, service)
            .with_upgrades()
            .await?;
    }
    Ok(())
}
/// Serves an HTTP/1.1 connection whose only job is to redirect every request to HTTPS.
///
/// Each request gets a `301` with `Location: https://<host>[:<https_port>]<path-and-query>`.
/// Any port in the incoming `Host` header is dropped first. That port belongs to the plain-HTTP
/// listener, and keeping it produced URLs such as `https://example.com:8080:8443/`. The port is
/// omitted when `https_port` is 443. A missing, non-ASCII or empty `Host` falls back to
/// `localhost`.
///
/// # Errors
/// Returns errors from serving the connection.
async fn handle_http_redirect<I>(
    io: TokioIo<I>,
    https_port: u16,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    /// Removes a trailing `:port` from a `Host` value, keeping bracketed IPv6 literals intact
    /// (`[::1]:8080` → `[::1]`, `example.com:80` → `example.com`).
    fn host_without_port(host: &str) -> &str {
        if let Some(rest) = host.strip_prefix('[') {
            match rest.find(']') {
                // `]` sits at `host` index `end + 1`; keep everything up to and including it.
                Some(end) => &host[..end + 2],
                None => host,
            }
        } else {
            host.rsplit_once(':').map_or(host, |(name, _port)| name)
        }
    }

    let service = service_fn(move |req| async move {
        let host = req
            .headers()
            .get("host")
            .and_then(|h| h.to_str().ok())
            .map(host_without_port)
            .filter(|h| !h.is_empty())
            .unwrap_or("localhost");
        let path_and_query = req
            .uri()
            .path_and_query()
            .map(|pq| pq.as_str())
            .unwrap_or("");
        let redirect_url = if https_port == 443 {
            format!("https://{}{}", host, path_and_query)
        } else {
            format!("https://{}:{}{}", host, https_port, path_and_query)
        };
        Ok::<_, hyper::Error>(
            hyper::Response::builder()
                .status(301)
                .header("Location", redirect_url)
                .body(Full::new(Bytes::from("Redirecting to HTTPS")))
                // The host came from `HeaderValue::to_str` (visible ASCII) and the path from a
                // parsed URI, so the Location value is always a valid header value.
                .expect("redirect response is always valid"),
        )
    });
    let mut builder = http1::Builder::new();
    builder.timer(hyper_util::rt::TokioTimer::new());
    builder.serve_connection(io, service).await?;
    Ok(())
}
/// Serves one connection as HTTP/2 using `service`, forwarding every tuning knob set in `config`
/// to hyper's HTTP/2 builder.
///
/// Options left as `None` (or `adaptive_window == false`) are never passed to the builder, so
/// hyper's defaults stay in effect for them. The setters are applied in a fixed order: hyper
/// documents that enabling `adaptive_window` overrides the explicit window sizes, so the
/// sequence matters.
///
/// # Errors
/// Returns any error raised while serving the connection.
async fn serve_http2_connection<S, I>(
    io: TokioIo<I>,
    service: S,
    config: Http2Config,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
    S: hyper::service::Service<
            hyper::Request<hyper::body::Incoming>,
            Response = hyper::Response<Full<Bytes>>,
        > + Clone
        + Send
        + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: Send,
{
    let mut h2 = http2::Builder::new(TokioExecutor);
    h2.timer(hyper_util::rt::TokioTimer::new());

    // Stream limits.
    if let Some(limit) = config.max_concurrent_streams {
        h2.max_concurrent_streams(limit);
    }
    // Flow-control windows (keep these before `adaptive_window`).
    if let Some(size) = config.initial_connection_window_size {
        h2.initial_connection_window_size(size);
    }
    if let Some(size) = config.initial_stream_window_size {
        h2.initial_stream_window_size(size);
    }
    // Framing.
    if let Some(size) = config.max_frame_size {
        h2.max_frame_size(size);
    }
    // Keep-alive pings.
    if let Some(every) = config.keep_alive_interval {
        h2.keep_alive_interval(every);
    }
    if let Some(wait) = config.keep_alive_timeout {
        h2.keep_alive_timeout(wait);
    }
    if config.adaptive_window {
        h2.adaptive_window(true);
    }
    // Header limits.
    if let Some(limit) = config.max_header_list_size {
        h2.max_header_list_size(limit);
    }

    debug!("🌐 Serving HTTP/2 connection with config: {:?}", config);
    h2.serve_connection(io, service).await?;
    Ok(())
}
/// Per-request entry point used by every connection's service.
///
/// Counts the request in `state.total_requests`, then dispatches:
/// * WebSocket upgrade requests (only with the `websocket` feature) go to
///   `handle_websocket_upgrade`;
/// * everything else goes to `handle_regular_http_request`, which receives the instant captured
///   on entry so it can report end-to-end latency.
async fn handle_request(
    router: Arc<Router>,
    req: hyper::Request<hyper::body::Incoming>,
    max_body_size: usize,
    metrics: Arc<PerformanceMetrics>,
    state: Arc<ServerState>,
) -> Result<hyper::Response<Full<Bytes>>, hyper::Error> {
    let started_at = Instant::now();
    state.total_requests.fetch_add(1, Ordering::Relaxed);

    #[cfg(feature = "websocket")]
    {
        if is_websocket_upgrade(&req) {
            return handle_websocket_upgrade(router, req).await;
        }
    }

    handle_regular_http_request(router, req, max_body_size, metrics, started_at).await
}
/// Returns `true` if `req` looks like an HTTP/1.1 WebSocket opening handshake.
///
/// Requires all of:
/// * HTTP/1.1;
/// * a `Connection` header containing the token `upgrade`;
/// * an `Upgrade` header containing the token `websocket`;
/// * a `Sec-WebSocket-Key` header.
///
/// Tokens are matched as comma-separated list items, compared ASCII-case-insensitively across
/// *all* values of the header. The previous substring test on lowercased copies accepted tokens
/// like `upgradefoo`, ignored repeated headers, and allocated twice per request. An optional
/// `/version` suffix on a token (as allowed for `Upgrade` protocol names) is ignored.
#[cfg(feature = "websocket")]
fn is_websocket_upgrade(req: &hyper::Request<hyper::body::Incoming>) -> bool {
    use hyper::header::{HeaderName, CONNECTION, SEC_WEBSOCKET_KEY, UPGRADE};

    let header_has_token = |name: HeaderName, wanted: &str| -> bool {
        req.headers()
            .get_all(name)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|value| value.split(','))
            .map(str::trim)
            .any(|token| {
                let protocol = token.split('/').next().unwrap_or(token);
                protocol.eq_ignore_ascii_case(wanted)
            })
    };

    req.version() == http::Version::HTTP_11
        && header_has_token(CONNECTION, "upgrade")
        && header_has_token(UPGRADE, "websocket")
        && req.headers().contains_key(SEC_WEBSOCKET_KEY)
}
/// Answers a WebSocket upgrade request and hands the upgraded connection to the handler
/// registered for the request path.
///
/// Steps:
/// 1. Look up the handler for the request path in `router.get_websocket_handlers()`; respond
///    `404` if none is registered.
/// 2. Rebuild a framework [`Request`] (with an empty body) from the hyper request's method, URI,
///    version and headers, and attach a clone of the router's extensions.
/// 3. Re-validate with `crate::websocket::is_websocket_request` (`400` if rejected) and compute
///    the handshake response with `crate::websocket::upgrade_connection` (the error's own status
///    on failure).
/// 4. Spawn a task that awaits `hyper::upgrade::on(hyper_req)` and runs the handler on the
///    upgraded IO, then return the handshake response for hyper to send.
///
/// Every return path is `Ok(..)`; failures are reported as HTTP responses.
/// NOTE(review): the `.unwrap()` calls on the response builders panic if a status or header
/// supplied by `upgrade_connection` / the error type is invalid for `http` — confirm those are
/// always valid.
#[cfg(feature = "websocket")]
async fn handle_websocket_upgrade(
    router: Arc<Router>,
    hyper_req: hyper::Request<hyper::body::Incoming>,
) -> Result<hyper::Response<Full<Bytes>>, hyper::Error> {
    let path = hyper_req.uri().path().to_string();
    let websocket_handlers = router.get_websocket_handlers();
    // Clone the handler `Arc` so it can be moved into the spawned task below.
    let handler = match websocket_handlers.get(&path) {
        Some(handler) => Arc::clone(handler.value()),
        None => {
            debug!("🔍 No WebSocket handler found for path: {}", path);
            return Ok(hyper::Response::builder()
                .status(404)
                .body(Full::new(Bytes::from("WebSocket endpoint not found")))
                .unwrap());
        }
    };
    // Copy the request head now: `hyper_req` itself is moved into `hyper::upgrade::on` later.
    let method = hyper_req.method().clone();
    let uri = hyper_req.uri().clone();
    let version = hyper_req.version();
    let headers = hyper_req.headers().clone();
    // Hold the router's read lock only long enough to clone its extensions.
    let router_extensions = {
        let inner = router.inner.read();
        inner.extensions.clone()
    };
    let mut framework_req = Request::new(method, uri, version, headers, Bytes::new());
    framework_req.extensions = router_extensions;
    if !crate::websocket::is_websocket_request(&framework_req) {
        debug!("❌ Invalid WebSocket upgrade request");
        return Ok(hyper::Response::builder()
            .status(400)
            .body(Full::new(Bytes::from("Invalid WebSocket upgrade request")))
            .unwrap());
    }
    let upgrade_response = match crate::websocket::upgrade_connection(&framework_req) {
        Ok(resp) => resp,
        Err(e) => {
            debug!("❌ WebSocket upgrade failed: {}", e);
            return Ok(hyper::Response::builder()
                .status(e.status_code())
                .body(Full::new(Bytes::from(e.to_string())))
                .unwrap());
        }
    };
    // Translate the framework handshake response into a hyper response (status + headers).
    let mut response_builder = hyper::Response::builder().status(upgrade_response.status);
    for (key, value) in upgrade_response.headers.iter() {
        response_builder = response_builder.header(key, value);
    }
    let response = response_builder.body(Full::new(Bytes::new())).unwrap();
    // hyper only completes the upgrade after this response has been written, so the handler must
    // run in a separate task rather than be awaited before returning. (hyper also requires the
    // connection to be served with `with_upgrades()` for this future to succeed.)
    tokio::spawn(async move {
        match hyper::upgrade::on(hyper_req).await {
            Ok(upgraded) => {
                let response =
                    crate::websocket::handle_websocket_upgrade(framework_req, upgraded, handler)
                        .await;
                // Only inspected in debug builds; in release builds the value is unused.
                #[cfg(debug_assertions)]
                if !response.status.is_success() {
                    debug!(
                        "🔌 WebSocket handler returned error status: {}",
                        response.status
                    );
                }
            }
            Err(e) => {
                debug!("🔌 WebSocket upgrade failed: {}", e);
            }
        }
    });
    Ok(response)
}
async fn handle_regular_http_request(
router: Arc<Router>,
req: hyper::Request<hyper::body::Incoming>,
max_body_size: usize,
metrics: Arc<PerformanceMetrics>,
request_start: Instant,
) -> Result<hyper::Response<Full<Bytes>>, hyper::Error> {
let (parts, body) = req.into_parts();
let body_bytes = match timeout(
Duration::from_secs(10), body.collect(),
)
.await
{
Ok(Ok(collected)) => {
let bytes = collected.to_bytes();
if bytes.len() > max_body_size {
let mut response =
hyper::Response::new(Full::new(Bytes::from("Request too large")));
*response.status_mut() = http::StatusCode::PAYLOAD_TOO_LARGE;
return Ok(response);
}
bytes
}
Ok(Err(_)) | Err(_) => Bytes::new(),
};
let request = Request::new(
parts.method,
parts.uri,
parts.version,
parts.headers,
body_bytes,
);
let response = match router.handle(request).await {
Ok(res) => res,
Err(err) => {
let status = err.status_code();
let mut res = Response::new(status);
res.body = Bytes::from(err.to_string());
res
}
};
let request_duration = request_start.elapsed();
metrics.record_request(request_duration);
let mut builder = hyper::Response::builder().status(response.status);
for (key, value) in response.headers.iter() {
builder = builder.header(key, value);
}
Ok(builder.body(Full::new(response.body)).unwrap())
}