1#![doc = include_str!("readme.md")]
2#![warn(missing_docs)]
3
4pub mod error;
5pub mod extract;
6pub mod middleware;
7pub mod response;
8pub mod router;
9pub mod template;
10pub mod tls;
11
12pub use wae_session as session;
13
14pub use response::{Attachment, Html, JsonResponse, Redirect, StreamResponse};
15pub use router::{MethodRouter, RouterBuilder, delete, get, head, options, patch, post, put, trace};
16
17use http::{Response, StatusCode, header};
18use http_body_util::Full;
19use hyper::body::Bytes;
20use std::{net::SocketAddr, path::Path, sync::Arc, time::Duration};
21use tokio::net::TcpListener;
22use tracing::info;
23
24pub use wae_types::{WaeError, WaeResult};
25
/// HTTP body type used by every response produced by this crate: a single,
/// fully-buffered chunk of bytes (`http_body_util::Full<Bytes>`).
pub type Body = Full<Bytes>;
28
29pub fn empty_body() -> Body {
31 Full::new(Bytes::new())
32}
33
34pub fn full_body<B: Into<Bytes>>(data: B) -> Body {
36 Full::new(data.into())
37}
38
/// Result type used by the HTTPS server API; an alias of [`WaeResult`].
pub type HttpsResult<T> = WaeResult<T>;

/// Error type used by the HTTPS server API; an alias of [`WaeError`].
pub type HttpsError = WaeError;
44
/// Conversion of a value (typically a handler's return value) into an HTTP response.
///
/// This module implements it for `Response<Body>`, `&'static str`, `String`,
/// `(StatusCode, T)` and `ApiResponse<T>`.
pub trait IntoResponse {
    /// Consumes `self` and produces the HTTP response to send to the client.
    fn into_response(self) -> Response<Body>;
}
52
53impl IntoResponse for Response<Body> {
54 fn into_response(self) -> Response<Body> {
55 self
56 }
57}
58
59impl IntoResponse for &'static str {
60 fn into_response(self) -> Response<Body> {
61 Response::builder()
62 .status(StatusCode::OK)
63 .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
64 .body(full_body(self))
65 .unwrap()
66 }
67}
68
69impl IntoResponse for String {
70 fn into_response(self) -> Response<Body> {
71 Response::builder()
72 .status(StatusCode::OK)
73 .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
74 .body(full_body(self))
75 .unwrap()
76 }
77}
78
79impl<T: IntoResponse> IntoResponse for (StatusCode, T) {
80 fn into_response(self) -> Response<Body> {
81 let mut res = self.1.into_response();
82 *res.status_mut() = self.0;
83 res
84 }
85}
86
/// Type-erased route handler: receives the request parts and the router state by value
/// and synchronously returns the response. Stored behind an `Arc` so one handler can be
/// shared by several matchers and route entries.
type RouteHandlerFn<S> = Arc<dyn Fn(crate::extract::RequestParts, S) -> Response<Body> + Send + Sync + 'static>;
89
/// Request router mapping `(HTTP method, path pattern)` pairs to handlers that share
/// state of type `S`.
pub struct Router<S = ()> {
    // One `matchit` path matcher per HTTP method; used for request dispatch.
    routes: std::collections::HashMap<http::Method, matchit::Router<RouteHandlerFn<S>>>,
    // Registered routes in registration order. The matchers are rebuilt from this list
    // when the router is cloned, merged into another router, or nested.
    raw_routes: Vec<RouteEntry<S>>,
    // State passed (by clone) to every handler invocation.
    state: S,
}
96
/// One registered route, kept so the per-method matchers can be rebuilt.
struct RouteEntry<S> {
    // HTTP method the route answers to.
    method: http::Method,
    // Path pattern as given to the `matchit` matcher.
    path: String,
    // Type-erased handler, shared (via `Arc`) with the matcher.
    handler: RouteHandlerFn<S>,
}
103
104impl<S: Clone> Clone for RouteEntry<S> {
105 fn clone(&self) -> Self {
106 Self { method: self.method.clone(), path: self.path.clone(), handler: self.handler.clone() }
107 }
108}
109
110impl<S: Clone> Clone for Router<S> {
111 fn clone(&self) -> Self {
112 let mut routes = std::collections::HashMap::new();
113 for (method, _) in &self.routes {
114 let new_router = matchit::Router::new();
115 routes.insert(method.clone(), new_router);
116 }
117 let mut new_router = Self { routes, raw_routes: Vec::new(), state: self.state.clone() };
118 for entry in &self.raw_routes {
119 let router = new_router.routes.entry(entry.method.clone()).or_insert_with(matchit::Router::new);
120 let _ = router.insert(entry.path.clone(), entry.handler.clone());
121 new_router.raw_routes.push(entry.clone());
122 }
123 new_router
124 }
125}
126
127impl Default for Router<()> {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133impl Router<()> {
134 pub fn new() -> Self {
136 Self { routes: std::collections::HashMap::new(), raw_routes: Vec::new(), state: () }
137 }
138}
139
140impl<S> Router<S> {
141 pub fn with_state(state: S) -> Self {
143 Self { routes: std::collections::HashMap::new(), raw_routes: Vec::new(), state }
144 }
145
146 pub fn state(&self) -> &S {
148 &self.state
149 }
150
151 pub fn state_mut(&mut self) -> &mut S {
153 &mut self.state
154 }
155
156 pub fn add_route_inner(
158 &mut self,
159 _method: http::Method,
160 _path: String,
161 _handler: Box<dyn std::any::Any + Send + Sync + 'static>,
162 ) {
163 }
164}
165
166impl<S> Router<S>
167where
168 S: Clone + Send + Sync + 'static,
169{
170 pub fn add_route<H, T>(&mut self, method: http::Method, path: &str, handler: H)
172 where
173 H: Fn(T) -> Response<Body> + Clone + Send + Sync + 'static,
174 T: crate::extract::FromRequestParts<S, Error = crate::extract::ExtractorError> + 'static,
175 {
176 let handler_fn: RouteHandlerFn<S> = Arc::new(move |parts, state| {
177 let handler = handler.clone();
178 match T::from_request_parts(&parts, &state) {
179 Ok(t) => handler(t),
180 Err(e) => {
181 let error_msg = e.to_string();
182 Response::builder()
183 .status(StatusCode::BAD_REQUEST)
184 .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
185 .body(full_body(error_msg))
186 .unwrap()
187 }
188 }
189 });
190
191 let entry = RouteEntry { method: method.clone(), path: path.to_string(), handler: handler_fn.clone() };
192 self.raw_routes.push(entry);
193
194 let router = self.routes.entry(method).or_insert_with(matchit::Router::new);
195 let _ = router.insert(path, handler_fn);
196 }
197
198 pub fn merge(mut self, other: Router<S>) -> Self {
200 for entry in other.raw_routes {
201 let router = self.routes.entry(entry.method.clone()).or_insert_with(matchit::Router::new);
202 let _ = router.insert(entry.path.clone(), entry.handler.clone());
203 self.raw_routes.push(entry);
204 }
205 self
206 }
207
208 pub fn nest_service<T>(mut self, prefix: &str, service: T) -> Self
210 where
211 T: Into<Router<S>>,
212 {
213 let other = service.into();
214 for entry in other.raw_routes {
215 let new_path = format!("{}{}", prefix.trim_end_matches('/'), entry.path);
216 let router = self.routes.entry(entry.method.clone()).or_insert_with(matchit::Router::new);
217 let _ = router.insert(new_path.clone(), entry.handler.clone());
218 self.raw_routes.push(RouteEntry { method: entry.method, path: new_path, handler: entry.handler });
219 }
220 self
221 }
222}
223
/// HTTP protocol version(s) the server is configured to speak.
#[derive(Debug, Clone, Copy, Default)]
pub enum HttpVersion {
    /// HTTP/1.1 only.
    Http1Only,
    /// HTTP/2 only.
    Http2Only,
    /// HTTP/1.1 and HTTP/2 (the default).
    #[default]
    Both,
    /// HTTP/3.
    ///
    /// NOTE(review): the server in this module only listens on TCP (`TcpListener`); no
    /// QUIC transport is visible here.
    Http3,
}
239
/// HTTP/2 connection settings.
///
/// NOTE(review): `enable_push` and `stream_idle_timeout` are not read by the connection
/// setup in this module — confirm whether they are used elsewhere.
#[derive(Debug, Clone)]
pub struct Http2Config {
    /// Whether HTTP/2 is enabled (`false` for [`Http2Config::disabled`]).
    pub enabled: bool,
    /// Whether HTTP/2 server push is enabled.
    pub enable_push: bool,
    /// Maximum number of concurrent streams per connection (`SETTINGS_MAX_CONCURRENT_STREAMS`).
    pub max_concurrent_streams: u32,
    /// Initial per-stream flow-control window in bytes (`SETTINGS_INITIAL_WINDOW_SIZE`).
    pub initial_stream_window_size: u32,
    /// Maximum frame payload size in bytes (`SETTINGS_MAX_FRAME_SIZE`).
    pub max_frame_size: u32,
    /// Whether the extended CONNECT protocol (RFC 8441) is enabled.
    pub enable_connect_protocol: bool,
    /// Idle timeout for streams.
    pub stream_idle_timeout: Duration,
}
260
261impl Default for Http2Config {
262 fn default() -> Self {
263 Self {
264 enabled: true,
265 enable_push: false,
266 max_concurrent_streams: 256,
267 initial_stream_window_size: 65535,
268 max_frame_size: 16384,
269 enable_connect_protocol: false,
270 stream_idle_timeout: Duration::from_secs(60),
271 }
272 }
273}
274
275impl Http2Config {
276 pub fn new() -> Self {
278 Self::default()
279 }
280
281 pub fn disabled() -> Self {
283 Self { enabled: false, ..Self::default() }
284 }
285
286 pub fn with_enable_push(mut self, enable: bool) -> Self {
288 self.enable_push = enable;
289 self
290 }
291
292 pub fn with_max_concurrent_streams(mut self, max: u32) -> Self {
294 self.max_concurrent_streams = max;
295 self
296 }
297
298 pub fn with_initial_stream_window_size(mut self, size: u32) -> Self {
300 self.initial_stream_window_size = size;
301 self
302 }
303
304 pub fn with_max_frame_size(mut self, size: u32) -> Self {
306 self.max_frame_size = size;
307 self
308 }
309
310 pub fn with_enable_connect_protocol(mut self, enable: bool) -> Self {
312 self.enable_connect_protocol = enable;
313 self
314 }
315
316 pub fn with_stream_idle_timeout(mut self, timeout: Duration) -> Self {
318 self.stream_idle_timeout = timeout;
319 self
320 }
321}
322
/// Locations of the TLS certificate and private-key files.
///
/// The expected file format is whatever `crate::tls::create_tls_acceptor_with_http2`
/// accepts — NOTE(review): confirm (presumably PEM).
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path to the certificate file.
    pub cert_path: String,
    /// Path to the private-key file.
    pub key_path: String,
}
333
334impl TlsConfig {
335 pub fn new(cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
342 Self { cert_path: cert_path.into(), key_path: key_path.into() }
343 }
344}
345
/// HTTP/3 settings. HTTP/3 is disabled by default.
#[derive(Debug, Clone, Default)]
pub struct Http3Config {
    /// Whether HTTP/3 is enabled.
    pub enabled: bool,
}
354
355impl Http3Config {
356 pub fn new() -> Self {
358 Self::default()
359 }
360
361 pub fn enabled() -> Self {
363 Self { enabled: true }
364 }
365}
366
/// Complete configuration of an [`HttpsServer`].
#[derive(Debug, Clone)]
pub struct HttpsServerConfig {
    /// Socket address the server binds to.
    pub addr: SocketAddr,
    /// Service name included in the startup log message.
    pub service_name: String,
    /// HTTP protocol version(s) to serve.
    pub http_version: HttpVersion,
    /// HTTP/2 settings.
    pub http2_config: Http2Config,
    /// HTTP/3 settings.
    ///
    /// NOTE(review): not read by `HttpsServer` in this module.
    pub http3_config: Http3Config,
    /// TLS certificate/key configuration; `None` serves plain-text HTTP.
    pub tls_config: Option<TlsConfig>,
}
385
386impl Default for HttpsServerConfig {
387 fn default() -> Self {
388 Self {
389 addr: "0.0.0.0:3000".parse().unwrap(),
390 service_name: "wae-https-service".to_string(),
391 http_version: HttpVersion::Both,
392 http2_config: Http2Config::default(),
393 http3_config: Http3Config::default(),
394 tls_config: None,
395 }
396 }
397}
398
/// Builder for an [`HttpsServer`]: collects the configuration and router, then
/// `build` produces the server.
pub struct HttpsServerBuilder<S = ()> {
    // Configuration accumulated by the builder methods.
    config: HttpsServerConfig,
    // Router handed to the built server.
    router: Router<S>,
    // NOTE(review): `router` already uses `S`, so this marker looks redundant.
    _marker: std::marker::PhantomData<S>,
}
407
408impl HttpsServerBuilder<()> {
409 pub fn new() -> Self {
411 Self { config: HttpsServerConfig::default(), router: Router::new(), _marker: std::marker::PhantomData }
412 }
413}
414
415impl Default for HttpsServerBuilder<()> {
416 fn default() -> Self {
417 Self::new()
418 }
419}
420
421impl<S> HttpsServerBuilder<S>
422where
423 S: Clone + Send + Sync + 'static,
424{
425 pub fn addr(mut self, addr: SocketAddr) -> Self {
427 self.config.addr = addr;
428 self
429 }
430
431 pub fn service_name(mut self, name: impl Into<String>) -> Self {
433 self.config.service_name = name.into();
434 self
435 }
436
437 pub fn router<T>(mut self, router: T) -> Self
439 where
440 T: Into<Router<S>>,
441 {
442 self.router = router.into();
443 self
444 }
445
446 pub fn merge_router(mut self, router: Router<S>) -> Self {
448 self.router = self.router.merge(router);
449 self
450 }
451
452 pub fn http_version(mut self, version: HttpVersion) -> Self {
454 self.config.http_version = version;
455 self
456 }
457
458 pub fn http2_config(mut self, config: Http2Config) -> Self {
460 self.config.http2_config = config;
461 self
462 }
463
464 pub fn http3_config(mut self, config: Http3Config) -> Self {
466 self.config.http3_config = config;
467 self
468 }
469
470 pub fn tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
477 self.config.tls_config = Some(TlsConfig::new(cert_path, key_path));
478 self
479 }
480
481 pub fn tls_config(mut self, config: TlsConfig) -> Self {
483 self.config.tls_config = Some(config);
484 self
485 }
486
487 pub fn build(self) -> HttpsServer<S> {
489 HttpsServer { config: self.config, router: self.router, _marker: std::marker::PhantomData }
490 }
491}
492
/// HTTP(S) server produced by [`HttpsServerBuilder::build`]; run it with `serve`.
pub struct HttpsServer<S = ()> {
    // Server configuration (address, protocol versions, TLS, ...).
    config: HttpsServerConfig,
    // Router used to answer requests on every accepted connection.
    router: Router<S>,
    // NOTE(review): `router` already uses `S`, so this marker looks redundant.
    _marker: std::marker::PhantomData<S>,
}
501
502impl<S> HttpsServer<S>
503where
504 S: Clone + Send + Sync + 'static,
505{
506 pub async fn serve(self) -> HttpsResult<()> {
508 let addr = self.config.addr;
509 let service_name = self.config.service_name.clone();
510 let protocol_info = self.get_protocol_info();
511 let tls_config = self.config.tls_config.clone();
512
513 let listener =
514 TcpListener::bind(addr).await.map_err(|e| WaeError::internal(format!("Failed to bind address: {}", e)))?;
515
516 info!("{} {} server starting on {}", service_name, protocol_info, addr);
517
518 match tls_config {
519 Some(tls_config) => self.serve_tls(listener, &tls_config).await,
520 None => self.serve_plain(listener).await,
521 }
522 }
523
524 async fn serve_plain(self, listener: TcpListener) -> HttpsResult<()> {
526 loop {
527 let (stream, _addr) = listener.accept().await.map_err(|e| WaeError::internal(format!("Accept error: {}", e)))?;
528
529 let router = self.router.clone();
530 tokio::spawn(async move {
531 let service = RouterService::new(router);
532 let io = hyper_util::rt::tokio::TokioIo::new(stream);
533 let _ = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
534 .serve_connection(io, service)
535 .await;
536 });
537 }
538 }
539
540 async fn serve_tls(self, listener: TcpListener, tls_config: &TlsConfig) -> HttpsResult<()> {
542 let enable_http2 = matches!(self.config.http_version, HttpVersion::Http2Only | HttpVersion::Both);
543
544 let acceptor = crate::tls::create_tls_acceptor_with_http2(&tls_config.cert_path, &tls_config.key_path, enable_http2)?;
545
546 loop {
547 let (stream, _addr) = listener.accept().await.map_err(|e| WaeError::internal(format!("Accept error: {}", e)))?;
548
549 let acceptor = acceptor.clone();
550 let router = self.router.clone();
551
552 tokio::spawn(async move {
553 let tls_stream = match acceptor.accept(stream).await {
554 Ok(s) => s,
555 Err(e) => {
556 tracing::error!("TLS handshake error: {}", e);
557 return;
558 }
559 };
560
561 let service = RouterService::new(router);
562 let io = hyper_util::rt::tokio::TokioIo::new(tls_stream);
563 let _ = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
564 .serve_connection(io, service)
565 .await;
566 });
567 }
568 }
569
570 fn get_protocol_info(&self) -> String {
571 let tls_info = if self.config.tls_config.is_some() { "S" } else { "" };
572 let version_info = match self.config.http_version {
573 HttpVersion::Http1Only => "HTTP/1.1",
574 HttpVersion::Http2Only => "HTTP/2",
575 HttpVersion::Both => "HTTP/1.1+HTTP/2",
576 HttpVersion::Http3 => "HTTP/3",
577 };
578 format!("{}{}", version_info, tls_info)
579 }
580}
581
/// JSON envelope for API responses.
///
/// It is serialized with serde as an object containing the fields below. No
/// `skip_serializing_if` is configured, so `None` fields appear as `null`.
#[derive(Debug, serde::Serialize)]
pub struct ApiResponse<T> {
    /// `true` for envelopes built with the `success*` constructors, `false` for `error*`.
    pub success: bool,
    /// Payload; set by the `success*` constructors.
    pub data: Option<T>,
    /// Error details; set by the `error*` constructors.
    pub error: Option<ApiErrorBody>,
    /// Optional trace identifier; set by the `*_with_trace` constructors.
    pub trace_id: Option<String>,
}
596
/// Error details carried in [`ApiResponse::error`].
#[derive(Debug, serde::Serialize)]
pub struct ApiErrorBody {
    /// Error code chosen by the caller.
    pub code: String,
    /// Error message.
    pub message: String,
}
607
608impl<T: serde::Serialize> ApiResponse<T> {
609 pub fn into_response(self) -> Response<Body> {
611 let status = if self.success { StatusCode::OK } else { StatusCode::BAD_REQUEST };
612 let body = serde_json::to_string(&self).unwrap_or_default();
613 Response::builder()
614 .status(status)
615 .header(header::CONTENT_TYPE, "application/json")
616 .body(Full::new(Bytes::from(body)))
617 .unwrap()
618 }
619}
620
621impl<T> IntoResponse for ApiResponse<T>
622where
623 T: serde::Serialize,
624{
625 fn into_response(self) -> Response<Body> {
626 self.into_response()
627 }
628}
629
630impl<T> ApiResponse<T>
631where
632 T: serde::Serialize,
633{
634 pub fn success(data: T) -> Self {
640 Self { success: true, data: Some(data), error: None, trace_id: None }
641 }
642
643 pub fn success_with_trace(data: T, trace_id: impl Into<String>) -> Self {
650 Self { success: true, data: Some(data), error: None, trace_id: Some(trace_id.into()) }
651 }
652
653 pub fn error(code: impl Into<String>, message: impl Into<String>) -> Self {
660 Self {
661 success: false,
662 data: None,
663 error: Some(ApiErrorBody { code: code.into(), message: message.into() }),
664 trace_id: None,
665 }
666 }
667
668 pub fn error_with_trace(code: impl Into<String>, message: impl Into<String>, trace_id: impl Into<String>) -> Self {
676 Self {
677 success: false,
678 data: None,
679 error: Some(ApiErrorBody { code: code.into(), message: message.into() }),
680 trace_id: Some(trace_id.into()),
681 }
682 }
683}
684
685pub fn static_files_router(base_path: impl AsRef<Path>, prefix: &str) -> Router {
694 let mut router = Router::new();
695 let base_path = base_path.as_ref().to_path_buf();
696 let prefix = prefix.to_string();
697
698 router.add_route(http::Method::GET, &format!("{}/*path", prefix), move |parts: crate::extract::RequestParts| {
699 let path = parts.uri.path();
700 let file_path = if let Some(stripped) = path.strip_prefix(&prefix) {
701 base_path.join(stripped.trim_start_matches('/'))
702 }
703 else {
704 return Response::builder().status(StatusCode::NOT_FOUND).body(empty_body()).unwrap();
705 };
706
707 if !file_path.exists() || !file_path.is_file() {
708 return Response::builder().status(StatusCode::NOT_FOUND).body(empty_body()).unwrap();
709 }
710
711 let content = match std::fs::read(&file_path) {
712 Ok(c) => c,
713 Err(_) => {
714 return Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(empty_body()).unwrap();
715 }
716 };
717
718 let mime_type = mime_guess::from_path(&file_path).first_or_octet_stream().to_string();
719
720 Response::builder().status(StatusCode::OK).header(header::CONTENT_TYPE, mime_type).body(full_body(content)).unwrap()
721 });
722
723 router
724}
725
726pub struct RouterService<S = ()> {
728 router: Router<S>,
729}
730
731impl<S: Clone> Clone for RouterService<S> {
732 fn clone(&self) -> Self {
733 Self { router: self.router.clone() }
734 }
735}
736
737impl<S> From<Router<S>> for RouterService<S> {
738 fn from(router: Router<S>) -> Self {
739 Self { router }
740 }
741}
742
743impl<S> RouterService<S>
744where
745 S: Clone + Send + Sync + 'static,
746{
747 pub fn new(router: Router<S>) -> Self {
749 Self { router }
750 }
751
752 pub async fn handle_request(&self, request: hyper::Request<hyper::body::Incoming>) -> http::Response<Body> {
754 let (parts, _body) = request.into_parts();
755 let method = parts.method.clone();
756 let uri = parts.uri.clone();
757 let version = parts.version;
758 let headers = parts.headers.clone();
759
760 let mut request_parts = crate::extract::RequestParts::new(method.clone(), uri.clone(), version, headers);
761
762 let path = uri.path();
763
764 let Some(method_router) = self.router.routes.get(&method)
765 else {
766 return Response::builder().status(StatusCode::METHOD_NOT_ALLOWED).body(empty_body()).unwrap();
767 };
768
769 let match_result = method_router.at(path);
770
771 let Ok(matched) = match_result
772 else {
773 return Response::builder().status(StatusCode::NOT_FOUND).body(empty_body()).unwrap();
774 };
775
776 request_parts.path_params = matched.params.iter().map(|(k, v)| (k.to_string(), v.to_string())).collect();
777
778 let handler = matched.value;
779 let state = self.router.state.clone();
780
781 handler(request_parts, state)
782 }
783}
784
785impl<S> hyper::service::Service<hyper::Request<hyper::body::Incoming>> for RouterService<S>
786where
787 S: Clone + Send + Sync + 'static,
788{
789 type Response = http::Response<Body>;
790 type Error = std::convert::Infallible;
791 type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
792
793 fn call(&self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
794 let this = self.clone();
795 Box::pin(async move {
796 let response = this.handle_request(req).await;
797 Ok(response)
798 })
799 }
800}