1#[cfg(feature = "cors")]
4use crate::cors::Cors;
5
6use crate::response::File;
7use crate::response::IntoResponse;
8use hyper::body::to_bytes;
9use hyper::{Body, Method, Request, Response, StatusCode};
10use serde::de::DeserializeOwned;
11use serde_json::from_slice as serde_from_json;
12use serde_urlencoded::{from_bytes as serde_from_bytes, from_str as serde_from_str};
13use std::collections::HashMap;
14use std::fs;
15use std::future::Future;
16use std::net::SocketAddr;
17use std::path::Path;
18use std::pin::Pin;
19use std::sync::{Arc, Mutex};
20
21type HandlerFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
22type HandlerFn = Box<dyn Fn(Request<Body>) -> HandlerFuture + Send + Sync>;
23
24pub struct Router {
26 routes: HashMap<(Method, String), HandlerFn>,
27
28 #[cfg(feature = "cors")]
29 cors: Option<Cors>,
30
31 pub(crate) enable_logger: bool,
32 pub(crate) outputs: Arc<Mutex<String>>,
33}
34
35impl Router {
36 pub fn new() -> Self {
38 Self {
39 routes: HashMap::new(),
40
41 #[cfg(feature = "cors")]
42 cors: None,
43
44 enable_logger: false,
45 outputs: Arc::new(Mutex::new(String::new())),
46 }
47 .visit_dirs(Path::new("./static"))
48 }
49
50 #[cfg(feature = "cors")]
51 pub fn with_cors(mut self, cors: Cors) -> Self {
52 self.cors = Some(cors);
53 self
54 }
55
56 #[cfg(feature = "cors")]
57 pub fn cors(mut self) -> Self {
58 self.cors = Some(Cors::default());
59 self
60 }
61
62 fn visit_dirs(mut self, dir: &Path) -> Self {
63 if dir.is_dir() {
64 if let Ok(entries) = fs::read_dir(dir) {
65 for entry in entries {
66 if let Ok(entry) = entry {
67 let path = entry.path();
68 if path.is_dir() {
69 self = self.visit_dirs(&path); } else {
71 let p = path
73 .to_str()
74 .unwrap_or("")
75 .to_string()
76 .chars()
77 .skip(2)
78 .collect::<String>();
79 self = self.file(format!("/{p}").as_str(), p);
80 }
81 }
82 }
83 }
84 }
85 self
86 }
87
88 pub fn route(mut self, path: &str, method_handler: MethodHandler) -> Self {
90 self.routes.insert(
91 (method_handler.method, path.to_string()),
92 method_handler.handler,
93 );
94 self
95 }
96
97 pub fn file(self, path: &str, file_path: String) -> Self {
98 let file_path = Arc::new(file_path); self.route(
101 path,
102 get(move |query| {
103 let file_path = Arc::clone(&file_path);
104 async move {
105 match query
106 .get("download")
107 .unwrap_or(&"false".to_string())
108 .as_str()
109 {
110 "1" | "" | "true" => File(file_path.to_string(), true),
111 _ => File(file_path.to_string(), false),
112 }
113 }
114 }),
115 )
116 }
117
118 pub(crate) async fn handle_request(&mut self, req: Request<Body>) -> Response<Body> {
120 let method = req.method().clone();
122 let path = req.uri().path().to_string();
123 let client_ip = req
124 .extensions()
125 .get::<SocketAddr>()
126 .map(|addr| addr.ip().to_string())
127 .unwrap_or_else(|| "unknown".to_string());
128 #[allow(unused_variables)]
129 let origin = req
130 .headers()
131 .get("origin")
132 .and_then(|h| h.to_str().ok())
133 .map(|s| s.to_owned()); if method == Method::OPTIONS {
137 #[cfg(feature = "cors")]
138 {
139 if let Some(cors) = &self.cors {
140 return cors.handle_preflight(origin.as_deref());
141 }
142 }
143 }
144
145 match self.routes.get(&(method.clone(), path.clone())) {
146 Some(handler) => {
147 #[allow(unused_mut)]
148 let mut response = handler(req).await;
149 let status = response.status();
150 #[cfg(feature = "cors")]
151 {
152 if let Some(cors) = &self.cors {
153 cors.apply_headers(&mut response, origin.as_deref());
154 }
155 }
156 if self.enable_logger {
157 self.request_log(client_ip, &method, path, status);
158 }
159 response
160 }
161 None => {
162 self.request_log(client_ip, &method, path, StatusCode::NOT_FOUND);
163 Response::builder()
164 .status(StatusCode::NOT_FOUND)
165 .body(Body::from("404 Not Found"))
166 .unwrap()
167 }
168 }
169 }
170}
171
172impl Default for Router {
173 fn default() -> Self {
174 Self::new()
175 }
176}
177
178pub struct MethodHandler {
180 pub(crate) method: Method,
181 pub(crate) handler: HandlerFn,
182}
183
184pub fn get<F, Fut, Res>(handler: F) -> MethodHandler
186where
187 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
188 Fut: Future<Output = Res> + Send + 'static,
189 Res: IntoResponse + 'static,
190{
191 let handler_fn =
192 Box::new(
193 move |req: Request<Body>| match serde_from_str(req.uri().query().unwrap_or("")) {
194 Ok(query) => {
195 let fut = handler(query);
196 Box::pin(async move { fut.await.into_response() }) as HandlerFuture
197 }
198 Err(e) => Box::pin(async move {
199 Response::builder()
200 .status(StatusCode::BAD_REQUEST)
201 .body(Body::from(format!("{:?}<br />Invalid query parameters", e)))
202 .unwrap()
203 }) as HandlerFuture,
204 },
205 );
206
207 MethodHandler {
208 method: Method::GET,
209 handler: handler_fn,
210 }
211}
212
213pub fn post<F, Fut, Res>(handler: F) -> MethodHandler
215where
216 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
217 Fut: Future<Output = Res> + Send + 'static,
218 Res: IntoResponse + 'static,
219{
220 let handler = Arc::new(handler);
221 let handler_fn = Box::new(move |req: Request<Body>| {
222 let handler = Arc::clone(&handler);
223 Box::pin(async move {
224 handler(
225 serde_from_bytes(&to_bytes(req.into_body()).await.unwrap_or_default())
226 .unwrap_or_default(),
227 )
228 .await
229 .into_response()
230 }) as HandlerFuture
231 });
232
233 MethodHandler {
234 method: Method::POST,
235 handler: handler_fn,
236 }
237}
238
239pub fn post_json<F, Fut, Res, T>(handler: F) -> MethodHandler
241where
242 F: Fn(T) -> Fut + Send + Sync + 'static,
243 Fut: Future<Output = Res> + Send + 'static,
244 Res: IntoResponse + 'static,
245 T: DeserializeOwned + Send + 'static,
246{
247 let handler = Arc::new(handler);
248 let handler_fn = Box::new(move |req: Request<Body>| {
249 let handler = Arc::clone(&handler);
250 Box::pin(async move {
251 let body_bytes = match to_bytes(req.into_body()).await {
253 Ok(bytes) => bytes,
254 Err(e) => {
255 return Response::builder()
256 .status(StatusCode::BAD_REQUEST)
257 .body(Body::from(format!("读取请求体失败: {}", e)))
258 .unwrap();
259 }
260 };
261
262 match serde_from_json::<T>(&body_bytes) {
264 Ok(data) => handler(data).await.into_response(),
265 Err(e) => Response::builder()
266 .status(StatusCode::BAD_REQUEST)
267 .body(Body::from(format!("无效的 JSON 格式: {}", e)))
268 .unwrap(),
269 }
270 }) as HandlerFuture
271 });
272
273 MethodHandler {
274 method: Method::POST,
275 handler: handler_fn,
276 }
277}