use super::method::Method;
use super::middleware::MiddlewareFn;
use super::response::IntoResponse;
use super::sse::SseWriter;
use super::stream::StreamResponse;
use super::trie::{RouteHandler, TrieNode};
#[cfg(feature = "ws")]
use super::websocket::WsConn;
use super::{Request, Response};
use crate::RuntimeError;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use super::BufferConfig;
#[cfg(feature = "grpc")]
pub use super::dispatch::GrpcRouter;
#[cfg(feature = "ws")]
pub(super) use super::dispatch::WsHandler;
pub(super) use super::dispatch::{
DispatchResult, FrozenRouter, GateCheck, Handler, ServerDispatch, SseHandler, gate_result,
};
impl std::fmt::Debug for Router {
    /// Formats the router's configuration for diagnostics.
    ///
    /// Middleware entries are boxed closures, so only their count is printed. The route trie
    /// is not printed, and with the `grpc` feature only the presence of a gRPC router is shown.
    /// `finish_non_exhaustive` renders a trailing `..` so readers know fields were omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Router");
        out.field("middleware_count", &self.middleware.len())
            .field("buffers", &self.buffers)
            .field(
                "skip_middleware_for_internal",
                &self.skip_middleware_for_internal,
            );
        #[cfg(feature = "grpc")]
        out.field("has_grpc_router", &self.grpc_router.is_some());
        out.finish_non_exhaustive()
    }
}
/// Mutable builder for the route table and its configuration.
///
/// Routes, middleware and options are registered on this type; `freeze` consumes it and
/// produces a `FrozenRouter`.
pub struct Router {
    /// Route trie that every `insert_route` call populates.
    root: TrieNode,
    /// Registered middleware, kept in registration order.
    middleware: Vec<MiddlewareFn>,
    /// Buffer sizing settings (request body, SSE and, with `ws`, WebSocket buffers).
    buffers: BufferConfig,
    /// Flag copied as-is into the frozen router.
    skip_middleware_for_internal: bool,
    /// gRPC router installed through `Router::grpc`, if any.
    #[cfg(feature = "grpc")]
    grpc_router: Option<GrpcRouter>,
}
impl Default for Router {
fn default() -> Self {
Self {
root: TrieNode::new(),
middleware: Vec::new(),
buffers: BufferConfig::default(),
skip_middleware_for_internal: false,
#[cfg(feature = "grpc")]
grpc_router: None,
}
}
}
impl Router {
    /// Creates an empty router; equivalent to `Router::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum request body size, in bytes, in this router's buffer configuration.
    #[must_use]
    pub fn max_request_body(mut self, bytes: usize) -> Self {
        self.buffers = self.buffers.with_max_request_body(bytes);
        self
    }

    /// Sets the server-sent-events buffer size in this router's buffer configuration.
    #[must_use]
    pub fn sse_buffer_size(mut self, size: usize) -> Self {
        self.buffers = self.buffers.with_sse_buffer_size(size);
        self
    }

    /// Sets the WebSocket buffer size in this router's buffer configuration.
    #[cfg(feature = "ws")]
    #[must_use]
    pub fn ws_buffer_size(mut self, size: usize) -> Self {
        self.buffers = self.buffers.with_ws_buffer_size(size);
        self
    }

    /// Returns a copy of the current buffer configuration.
    pub(super) fn buffer_config(&self) -> BufferConfig {
        self.buffers
    }

    /// Sets the `skip_middleware_for_internal` flag that `freeze` passes to the frozen router.
    ///
    /// NOTE(review): which requests count as "internal" is decided by the dispatcher,
    /// not in this module.
    #[must_use]
    pub fn skip_middleware_for_internal(mut self, skip: bool) -> Self {
        self.skip_middleware_for_internal = skip;
        self
    }

    /// Appends a middleware function. Middleware is stored in registration order.
    ///
    /// The closure's future is boxed so heterogeneous middleware can share one `Vec`.
    pub fn use_middleware<F, Fut>(&mut self, mw: F)
    where
        F: Fn(&Request, super::middleware::Next) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.middleware
            .push(Box::new(move |req, next| Box::pin(mw(req, next))));
    }

    /// Registers an async handler for `GET` requests on `path`.
    pub fn get<F, Fut, R>(&mut self, path: &str, handler: F)
    where
        F: Fn(&Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + 'static,
    {
        self.add(Method::Get, path, handler);
    }

    /// Registers an async handler for `POST` requests on `path`.
    pub fn post<F, Fut, R>(&mut self, path: &str, handler: F)
    where
        F: Fn(&Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + 'static,
    {
        self.add(Method::Post, path, handler);
    }

    /// Registers an async handler for `PUT` requests on `path`.
    pub fn put<F, Fut, R>(&mut self, path: &str, handler: F)
    where
        F: Fn(&Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + 'static,
    {
        self.add(Method::Put, path, handler);
    }

    /// Registers an async handler for `DELETE` requests on `path`.
    pub fn delete<F, Fut, R>(&mut self, path: &str, handler: F)
    where
        F: Fn(&Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + 'static,
    {
        self.add(Method::Delete, path, handler);
    }

    /// Registers an async handler for `PATCH` requests on `path`.
    pub fn patch<F, Fut, R>(&mut self, path: &str, handler: F)
    where
        F: Fn(&Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + 'static,
    {
        self.add(Method::Patch, path, handler);
    }

    /// Registers an async handler for `HEAD` requests on `path`.
    pub fn head<F, Fut, R>(&mut self, path: &str, handler: F)
    where
        F: Fn(&Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + 'static,
    {
        self.add(Method::Head, path, handler);
    }

    /// Registers an async handler for `OPTIONS` requests on `path`.
    pub fn options<F, Fut, R>(&mut self, path: &str, handler: F)
    where
        F: Fn(&Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + 'static,
    {
        self.add(Method::Options, path, handler);
    }

    /// Registers a handler producing a `StreamResponse` for `GET` requests on `path`.
    pub fn get_stream(
        &mut self,
        path: &str,
        handler: impl Fn(&Request) -> Pin<Box<dyn Future<Output = StreamResponse> + Send>>
        + Send
        + Sync
        + 'static,
    ) {
        self.add_stream(Method::Get, path, handler);
    }

    /// Registers a handler producing a `StreamResponse` for `POST` requests on `path`.
    pub fn post_stream(
        &mut self,
        path: &str,
        handler: impl Fn(&Request) -> Pin<Box<dyn Future<Output = StreamResponse> + Send>>
        + Send
        + Sync
        + 'static,
    ) {
        self.add_stream(Method::Post, path, handler);
    }

    /// Registers a server-sent-events handler for `GET` requests on `path`.
    ///
    /// The handler is synchronous: it receives the request and an `SseWriter` and
    /// reports failure through `RuntimeError`.
    pub fn get_sse(
        &mut self,
        path: &str,
        handler: impl Fn(&Request, &mut SseWriter) -> Result<(), RuntimeError> + Send + Sync + 'static,
    ) {
        self.root
            .insert_route(Method::Get, path, RouteHandler::Sse(Arc::new(handler)));
    }

    /// Registers a WebSocket handler on `GET` requests to `path`.
    ///
    /// The handler takes ownership of the `WsConn`.
    #[cfg(feature = "ws")]
    pub fn ws(
        &mut self,
        path: &str,
        handler: impl Fn(&Request, WsConn) -> Result<(), RuntimeError> + Send + Sync + 'static,
    ) {
        self.root.insert_route(
            Method::Get,
            path,
            RouteHandler::WebSocket(Arc::new(handler)),
        );
    }

    /// Installs the gRPC router, replacing any previously installed one.
    #[cfg(feature = "grpc")]
    pub fn grpc(&mut self, grpc_router: super::dispatch::GrpcRouter) {
        self.grpc_router = Some(grpc_router);
    }

    /// Registers buffered proxy routes to `backend` for `prefix` and everything below it.
    pub fn proxy(&mut self, prefix: &str, backend: &str) {
        self.insert_proxy_routes(prefix, backend, None, false);
    }

    /// Like `proxy`, but attaches the `healthy` flag to every route handler.
    ///
    /// NOTE(review): how the flag is consulted (e.g. rejecting requests while it is
    /// `false`) is decided by the dispatcher — confirm there.
    pub fn proxy_checked(&mut self, prefix: &str, backend: &str, healthy: Arc<AtomicBool>) {
        self.insert_proxy_routes(prefix, backend, Some(healthy), false);
    }

    /// Like `proxy`, but registers the streaming proxy handler variant.
    pub fn proxy_stream(&mut self, prefix: &str, backend: &str) {
        self.insert_proxy_routes(prefix, backend, None, true);
    }

    /// Combination of `proxy_checked` and `proxy_stream`.
    pub fn proxy_checked_stream(&mut self, prefix: &str, backend: &str, healthy: Arc<AtomicBool>) {
        self.insert_proxy_routes(prefix, backend, Some(healthy), true);
    }

    /// Registers proxy routes for `prefix` (exact) and `{prefix}/*proxy_path` (wildcard)
    /// under GET, POST, PUT, DELETE, PATCH, HEAD and OPTIONS.
    ///
    /// Each (method, pattern) pair gets its own `RouteHandler`. All handlers share the same
    /// `Arc` allocations for `backend`, `prefix` and `healthy`. `streaming` selects the
    /// streaming variant.
    fn insert_proxy_routes(
        &mut self,
        prefix: &str,
        backend: &str,
        healthy: Option<Arc<AtomicBool>>,
        streaming: bool,
    ) {
        let backend: Arc<str> = backend.into();
        let prefix_owned: Arc<str> = prefix.into();
        let (exact_pattern, wildcard_pattern) = Self::mount_patterns(prefix, "proxy_path");
        let methods = [
            Method::Get,
            Method::Post,
            Method::Put,
            Method::Delete,
            Method::Patch,
            Method::Head,
            Method::Options,
        ];
        for method in methods {
            // For each method the wildcard pattern is inserted before the exact one.
            for pattern in [wildcard_pattern.as_str(), exact_pattern.as_str()] {
                let handler = proxy_route_handler(
                    streaming,
                    Arc::clone(&backend),
                    Arc::clone(&prefix_owned),
                    healthy.as_ref().map(Arc::clone),
                );
                self.root.insert_route(method, pattern, handler);
            }
        }
    }

    /// Serves files from the directory `dir` under `prefix` (GET only).
    ///
    /// - `GET prefix` (or `/` when `prefix` is empty) serves `index.html` from `dir`.
    /// - `GET {prefix}/*filepath` serves the captured `filepath` relative to `dir`. An
    ///   absent capture is passed as `""`.
    ///
    /// Path resolution is delegated to `static_files::serve_file`.
    /// NOTE(review): path-traversal protection (`..` segments) is assumed to live in
    /// `serve_file` — confirm.
    pub fn static_files(&mut self, prefix: &str, dir: &str) {
        // Both handlers share a single allocation of the base directory.
        let base_dir: Arc<std::path::Path> = Arc::from(std::path::Path::new(dir));
        let index_dir = Arc::clone(&base_dir);
        let (exact_pattern, wildcard_pattern) = Self::mount_patterns(prefix, "filepath");
        self.root.insert_route(
            Method::Get,
            &exact_pattern,
            RouteHandler::Async(Box::new(move |_req: &Request| {
                // serve_file runs eagerly when the handler is invoked; the future only
                // yields the already-built response.
                let resp = super::static_files::serve_file(&index_dir, "index.html");
                Box::pin(async move { resp }) as Pin<Box<dyn Future<Output = Response> + Send>>
            })),
        );
        self.root.insert_route(
            Method::Get,
            &wildcard_pattern,
            RouteHandler::Async(Box::new(move |req: &Request| {
                let file_path = req.param("filepath").unwrap_or("");
                let resp = super::static_files::serve_file(&base_dir, file_path);
                Box::pin(async move { resp }) as Pin<Box<dyn Future<Output = Response> + Send>>
            })),
        );
    }

    /// Returns the `(exact, wildcard)` route patterns for mounting something at `prefix`.
    ///
    /// The exact pattern is `prefix` itself, or `/` when `prefix` is empty. The wildcard
    /// pattern is `{prefix}/*{capture}`, which binds the remainder of the path to `capture`.
    fn mount_patterns(prefix: &str, capture: &str) -> (String, String) {
        let exact = if prefix.is_empty() {
            "/".to_owned()
        } else {
            prefix.to_owned()
        };
        (exact, format!("{prefix}/*{capture}"))
    }

    /// Wraps a typed async handler into a type-erased `RouteHandler::Async`.
    ///
    /// `handler` is called with a borrowed request. The returned future is `'static`, so it
    /// cannot borrow the request. Its output is converted with `IntoResponse::into_response`.
    fn add<F, Fut, R>(&mut self, method: Method, path: &str, handler: F)
    where
        F: Fn(&Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + 'static,
    {
        self.root.insert_route(
            method,
            path,
            RouteHandler::Async(Box::new(move |req: &Request| {
                let fut = handler(req);
                Box::pin(async move { fut.await.into_response() })
                    as Pin<Box<dyn Future<Output = Response> + Send>>
            })),
        );
    }

    /// Inserts a streaming handler for `method` on `path`.
    fn add_stream(
        &mut self,
        method: Method,
        path: &str,
        handler: impl Fn(&Request) -> Pin<Box<dyn Future<Output = StreamResponse> + Send>>
        + Send
        + Sync
        + 'static,
    ) {
        self.root
            .insert_route(method, path, RouteHandler::Stream(Box::new(handler)));
    }

    /// Consumes the router and produces the immutable `FrozenRouter`.
    ///
    /// The route trie is frozen and middleware becomes a boxed slice. `buffers` is not
    /// carried into the `FrozenRouter`.
    pub(super) fn freeze(self) -> FrozenRouter {
        FrozenRouter {
            root: self.root.freeze(),
            middleware: self.middleware.into_boxed_slice(),
            skip_middleware_for_internal: self.skip_middleware_for_internal,
            #[cfg(feature = "grpc")]
            grpc_router: self.grpc_router,
        }
    }
}
fn proxy_route_handler(
streaming: bool,
backend: Arc<str>,
prefix: Arc<str>,
healthy: Option<Arc<std::sync::atomic::AtomicBool>>,
) -> RouteHandler {
match streaming {
true => RouteHandler::ProxyStream {
backend,
prefix,
healthy,
},
false => RouteHandler::Proxy {
backend,
prefix,
healthy,
},
}
}