1use crate::cors::Cors;
4use crate::response::File;
5use crate::response::IntoResponse;
6use hyper::body::to_bytes;
7use hyper::{Body, Method, Request, Response, StatusCode};
8use serde::de::DeserializeOwned;
9use serde_json::from_slice as serde_from_json;
10use serde_urlencoded::{from_bytes as serde_from_bytes, from_str as serde_from_str};
11use std::collections::HashMap;
12use std::fs;
13use std::future::Future;
14use std::net::SocketAddr;
15use std::path::Path;
16use std::pin::Pin;
17use std::sync::{Arc, Mutex};
18
19type HandlerFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
20type HandlerFn = Box<dyn Fn(Request<Body>) -> HandlerFuture + Send + Sync>;
21
/// HTTP router that maps exact `(method, path)` pairs to boxed async handlers, with an
/// optional CORS policy and an optional request logger.
pub struct Router {
    // Handlers keyed by the exact request method and URI path; lookup is a plain hash-map
    // `get`, so there is no pattern or prefix matching.
    routes: HashMap<(Method, String), HandlerFn>,
    // CORS policy; `None` means OPTIONS preflights are not intercepted and no CORS headers
    // are added to responses.
    cors: Option<Cors>,
    // Logging toggle that `handle_request` checks before calling `request_log`.
    pub(crate) enable_logger: bool,
    // Shared, lock-protected text buffer. NOTE(review): presumably accumulates log output
    // written by `request_log` — confirm at its definition.
    pub(crate) outputs: Arc<Mutex<String>>,
}
29
30impl Router {
31 pub fn new() -> Self {
33 Self {
34 routes: HashMap::new(),
35 cors: None,
36 enable_logger: false,
37 outputs: Arc::new(Mutex::new(String::new())),
38 }
39 .visit_dirs(Path::new("./static"))
40 }
41
42 pub fn with_cors(mut self, cors: Cors) -> Self {
44 self.cors = Some(cors);
45 self
46 }
47
48 pub fn cors(mut self) -> Self {
50 self.cors = Some(Cors::default());
51 self
52 }
53
54 fn visit_dirs(mut self, dir: &Path) -> Self {
55 if dir.is_dir() {
56 if let Ok(entries) = fs::read_dir(dir) {
57 for entry in entries {
58 if let Ok(entry) = entry {
59 let path = entry.path();
60 if path.is_dir() {
61 self = self.visit_dirs(&path); } else {
63 let p = path
65 .to_str()
66 .unwrap_or("")
67 .to_string()
68 .chars()
69 .skip(2)
70 .collect::<String>();
71 self = self.file(format!("/{p}").as_str(), p);
72 }
73 }
74 }
75 }
76 }
77 self
78 }
79
80 pub fn route(mut self, path: &str, method_handler: MethodHandler) -> Self {
82 self.routes.insert(
83 (method_handler.method, path.to_string()),
84 method_handler.handler,
85 );
86 self
87 }
88
89 pub fn file(self, path: &str, file_path: String) -> Self {
90 let file_path = Arc::new(file_path); self.route(
93 path,
94 get(move |query| {
95 let file_path = Arc::clone(&file_path);
96 async move {
97 match query
98 .get("download")
99 .unwrap_or(&"false".to_string())
100 .as_str()
101 {
102 "1" | "" | "true" => File(file_path.to_string(), true),
103 _ => File(file_path.to_string(), false),
104 }
105 }
106 }),
107 )
108 }
109
110 pub(crate) async fn handle_request(&mut self, req: Request<Body>) -> Response<Body> {
112 let method = req.method().clone();
114 let path = req.uri().path().to_string();
115 let client_ip = req
116 .extensions()
117 .get::<SocketAddr>()
118 .map(|addr| addr.ip().to_string())
119 .unwrap_or_else(|| "unknown".to_string());
120 let origin = req
121 .headers()
122 .get("origin")
123 .and_then(|h| h.to_str().ok())
124 .map(|s| s.to_owned()); if method == Method::OPTIONS {
128 if let Some(cors) = &self.cors {
129 return cors.handle_preflight(origin.as_deref());
130 }
131 }
132
133 match self.routes.get(&(method.clone(), path.clone())) {
134 Some(handler) => {
135 let mut response = handler(req).await;
136 let status = response.status();
137
138 if let Some(cors) = &self.cors {
139 cors.apply_headers(&mut response, origin.as_deref());
140 }
141 if self.enable_logger {
142 self.request_log(client_ip, &method, path, status);
143 }
144 response
145 }
146 None => {
147 self.request_log(client_ip, &method, path, StatusCode::NOT_FOUND);
148 Response::builder()
149 .status(StatusCode::NOT_FOUND)
150 .body(Body::from("404 Not Found"))
151 .unwrap()
152 }
153 }
154 }
155}
156
157impl Default for Router {
158 fn default() -> Self {
159 Self::new()
160 }
161}
162
/// An HTTP method paired with the type-erased handler to run for it. Built by [`get`],
/// [`post`] and [`post_json`], and registered with [`Router::route`].
pub struct MethodHandler {
    // Request method this handler is registered under.
    pub(crate) method: Method,
    // Boxed handler that receives the full request and returns the response future.
    pub(crate) handler: HandlerFn,
}
168
169pub fn get<F, Fut, Res>(handler: F) -> MethodHandler
171where
172 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
173 Fut: Future<Output = Res> + Send + 'static,
174 Res: IntoResponse + 'static,
175{
176 let handler_fn =
177 Box::new(
178 move |req: Request<Body>| match serde_from_str(req.uri().query().unwrap_or("")) {
179 Ok(query) => {
180 let fut = handler(query);
181 Box::pin(async move { fut.await.into_response() }) as HandlerFuture
182 }
183 Err(e) => Box::pin(async move {
184 Response::builder()
185 .status(StatusCode::BAD_REQUEST)
186 .body(Body::from(format!("{:?}<br />Invalid query parameters", e)))
187 .unwrap()
188 }) as HandlerFuture,
189 },
190 );
191
192 MethodHandler {
193 method: Method::GET,
194 handler: handler_fn,
195 }
196}
197
198pub fn post<F, Fut, Res>(handler: F) -> MethodHandler
200where
201 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
202 Fut: Future<Output = Res> + Send + 'static,
203 Res: IntoResponse + 'static,
204{
205 let handler = Arc::new(handler);
206 let handler_fn = Box::new(move |req: Request<Body>| {
207 let handler = Arc::clone(&handler);
208 Box::pin(async move {
209 handler(
210 serde_from_bytes(&to_bytes(req.into_body()).await.unwrap_or_default())
211 .unwrap_or_default(),
212 )
213 .await
214 .into_response()
215 }) as HandlerFuture
216 });
217
218 MethodHandler {
219 method: Method::POST,
220 handler: handler_fn,
221 }
222}
223
224pub fn post_json<F, Fut, Res, T>(handler: F) -> MethodHandler
226where
227 F: Fn(T) -> Fut + Send + Sync + 'static,
228 Fut: Future<Output = Res> + Send + 'static,
229 Res: IntoResponse + 'static,
230 T: DeserializeOwned + Send + 'static,
231{
232 let handler = Arc::new(handler);
233 let handler_fn = Box::new(move |req: Request<Body>| {
234 let handler = Arc::clone(&handler);
235 Box::pin(async move {
236 let body_bytes = match to_bytes(req.into_body()).await {
238 Ok(bytes) => bytes,
239 Err(e) => {
240 return Response::builder()
241 .status(StatusCode::BAD_REQUEST)
242 .body(Body::from(format!("读取请求体失败: {}", e)))
243 .unwrap();
244 }
245 };
246
247 match serde_from_json::<T>(&body_bytes) {
249 Ok(data) => handler(data).await.into_response(),
250 Err(e) => Response::builder()
251 .status(StatusCode::BAD_REQUEST)
252 .body(Body::from(format!("无效的 JSON 格式: {}", e)))
253 .unwrap(),
254 }
255 }) as HandlerFuture
256 });
257
258 MethodHandler {
259 method: Method::POST,
260 handler: handler_fn,
261 }
262}