1use crate::cors::Cors;
4use crate::response::File;
5use crate::response::IntoResponse;
6use hyper::body::to_bytes;
7use hyper::{Body, Method, Request, Response, StatusCode};
8use serde::de::DeserializeOwned;
9use serde_json::from_slice as serde_from_json;
10use serde_urlencoded::{from_bytes as serde_from_bytes, from_str as serde_from_str};
11use std::collections::HashMap;
12use std::fs;
13use std::future::Future;
14use std::net::SocketAddr;
15use std::path::Path;
16use std::pin::Pin;
17use std::sync::Arc;
18
/// Boxed, type-erased future returned by a route handler; resolves to the HTTP response.
type HandlerFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
/// Type-erased route handler stored in the router: consumes the request and
/// returns a [`HandlerFuture`].
type HandlerFn = Box<dyn Fn(Request<Body>) -> HandlerFuture + Send + Sync>;
21
/// HTTP router mapping exact `(method, path)` pairs to handlers, with optional
/// CORS handling and request logging.
pub struct Router {
    /// Registered handlers keyed by HTTP method and request path; lookups use the
    /// exact path string (no pattern or wildcard matching).
    routes: HashMap<(Method, String), HandlerFn>,
    /// CORS configuration; `None` disables preflight handling and CORS headers.
    cors: Option<Cors>,
    /// When `true`, requests that match a route are logged via `request_log`.
    pub enable_logger: bool,
}
28
29impl Router {
30 pub fn new() -> Self {
32 Self {
33 routes: HashMap::new(),
34 cors: None,
35 enable_logger: false,
36 }
37 .visit_dirs(Path::new("./static"))
38 }
39
40 pub fn with_cors(mut self, cors: Cors) -> Self {
42 self.cors = Some(cors);
43 self
44 }
45
46 pub fn cors(mut self) -> Self {
48 self.cors = Some(Cors::default());
49 self
50 }
51
52 fn visit_dirs(mut self, dir: &Path) -> Self {
53 if dir.is_dir() {
54 if let Ok(entries) = fs::read_dir(dir) {
55 for entry in entries {
56 if let Ok(entry) = entry {
57 let path = entry.path();
58 if path.is_dir() {
59 self = self.visit_dirs(&path); } else {
61 let p = path
63 .to_str()
64 .unwrap_or("")
65 .to_string()
66 .chars()
67 .skip(2)
68 .collect::<String>();
69 self = self.file(format!("/{p}").as_str(), p);
70 }
71 }
72 }
73 }
74 }
75 self
76 }
77
78 pub fn route(mut self, path: &str, method_handler: MethodHandler) -> Self {
80 self.routes.insert(
81 (method_handler.method, path.to_string()),
82 method_handler.handler,
83 );
84 self
85 }
86
87 pub fn file(self, path: &str, file_path: String) -> Self {
88 let file_path = Arc::new(file_path); self.route(
91 path,
92 get(move |query| {
93 let file_path = Arc::clone(&file_path);
94 async move {
95 match query
96 .get("download")
97 .unwrap_or(&"false".to_string())
98 .as_str()
99 {
100 "1" | "" | "true" => File(file_path.to_string(), true),
101 _ => File(file_path.to_string(), false),
102 }
103 }
104 }),
105 )
106 }
107
108 pub(crate) async fn handle_request(&self, req: Request<Body>) -> Response<Body> {
110 let method = req.method().clone();
112 let path = req.uri().path().to_string();
113 let client_ip = req
114 .extensions()
115 .get::<SocketAddr>()
116 .map(|addr| addr.ip().to_string())
117 .unwrap_or_else(|| "unknown".to_string());
118 let origin = req
119 .headers()
120 .get("origin")
121 .and_then(|h| h.to_str().ok())
122 .map(|s| s.to_owned()); if method == Method::OPTIONS {
126 if let Some(cors) = &self.cors {
127 return cors.handle_preflight(origin.as_deref());
128 }
129 }
130
131 match self.routes.get(&(method.clone(), path.clone())) {
132 Some(handler) => {
133 let mut response = handler(req).await;
134 let status = response.status();
135
136 if let Some(cors) = &self.cors {
137 cors.apply_headers(&mut response, origin.as_deref());
138 }
139 if self.enable_logger {
140 self.request_log(client_ip, &method, path, status);
141 }
142 response
143 }
144 None => {
145 self.request_log(client_ip, &method, path, StatusCode::NOT_FOUND);
146 Response::builder()
147 .status(StatusCode::NOT_FOUND)
148 .body(Body::from("404 Not Found"))
149 .unwrap()
150 }
151 }
152 }
153}
154
155impl Default for Router {
156 fn default() -> Self {
157 Self::new()
158 }
159}
160
/// An HTTP method paired with its type-erased handler, as built by [`get`],
/// [`post`] and [`post_json`] and consumed by [`Router::route`].
pub struct MethodHandler {
    /// Method the handler is registered under.
    pub(crate) method: Method,
    /// Boxed handler invoked for requests matching the method and route path.
    pub(crate) handler: HandlerFn,
}
166
167pub fn get<F, Fut, Res>(handler: F) -> MethodHandler
169where
170 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
171 Fut: Future<Output = Res> + Send + 'static,
172 Res: IntoResponse + 'static,
173{
174 let handler_fn =
175 Box::new(
176 move |req: Request<Body>| match serde_from_str(req.uri().query().unwrap_or("")) {
177 Ok(query) => {
178 let fut = handler(query);
179 Box::pin(async move { fut.await.into_response() }) as HandlerFuture
180 }
181 Err(e) => Box::pin(async move {
182 Response::new(Body::from(format!("{:?}<br />Invalid query parameters", e)))
183 }) as HandlerFuture,
184 },
185 );
186
187 MethodHandler {
188 method: Method::GET,
189 handler: handler_fn,
190 }
191}
192
193pub fn post<F, Fut, Res>(handler: F) -> MethodHandler
195where
196 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
197 Fut: Future<Output = Res> + Send + 'static,
198 Res: IntoResponse + 'static,
199{
200 let handler = Arc::new(handler);
201 let handler_fn = Box::new(move |req: Request<Body>| {
202 let handler = Arc::clone(&handler);
203 Box::pin(async move {
204 handler(
205 serde_from_bytes(&to_bytes(req.into_body()).await.unwrap_or_default())
206 .unwrap_or_default(),
207 )
208 .await
209 .into_response()
210 }) as HandlerFuture
211 });
212
213 MethodHandler {
214 method: Method::POST,
215 handler: handler_fn,
216 }
217}
218
219pub fn post_json<F, Fut, Res, T>(handler: F) -> MethodHandler
221where
222 F: Fn(T) -> Fut + Send + Sync + 'static,
223 Fut: Future<Output = Res> + Send + 'static,
224 Res: IntoResponse + 'static,
225 T: DeserializeOwned + Send + 'static,
226{
227 let handler = Arc::new(handler);
228 let handler_fn = Box::new(move |req: Request<Body>| {
229 let handler = Arc::clone(&handler);
230 Box::pin(async move {
231 let body_bytes = match to_bytes(req.into_body()).await {
233 Ok(bytes) => bytes,
234 Err(e) => {
235 return Response::builder()
236 .status(StatusCode::BAD_REQUEST)
237 .body(Body::from(format!("读取请求体失败: {}", e)))
238 .unwrap();
239 }
240 };
241
242 match serde_from_json::<T>(&body_bytes) {
244 Ok(data) => handler(data).await.into_response(),
245 Err(e) => Response::builder()
246 .status(StatusCode::BAD_REQUEST)
247 .body(Body::from(format!("无效的 JSON 格式: {}", e)))
248 .unwrap(),
249 }
250 }) as HandlerFuture
251 });
252
253 MethodHandler {
254 method: Method::POST,
255 handler: handler_fn,
256 }
257}