1use crate::response::File;
4use crate::response::IntoResponse;
5use crate::cors::Cors;
6use hyper::body::to_bytes;
7use hyper::{Body, Method, Request, Response, StatusCode};
8use serde::de::DeserializeOwned;
9use serde_json::from_slice as serde_from_json;
10use serde_urlencoded::{from_bytes as serde_from_bytes, from_str as serde_from_str};
11use std::collections::HashMap;
12use std::fs;
13use std::future::Future;
14use std::path::Path;
15use std::pin::Pin;
16use std::sync::Arc;
17
/// Boxed, pinned, `Send` future that resolves to the response for one request.
type HandlerFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
/// Type-erased route handler: consumes the request and returns a [`HandlerFuture`].
/// `Send + Sync` so a handler can be called through a shared reference from any thread.
type HandlerFn = Box<dyn Fn(Request<Body>) -> HandlerFuture + Send + Sync>;
20
21pub struct Router {
23 routes: HashMap<(Method, String), HandlerFn>,
24 cors: Option<Cors>, }
26
27impl Router {
28 pub fn new() -> Self {
30 Self {
31 routes: HashMap::new(),
32 cors: None, }
34 .visit_dirs(Path::new("./static"))
35 }
36
37 pub fn with_cors(mut self, cors: Cors) -> Self {
39 self.cors = Some(cors);
40 self
41 }
42
43 pub fn cors(mut self) -> Self {
45 self.cors = Some(Cors::default());
46 self
47 }
48
49 fn visit_dirs(mut self, dir: &Path) -> Self {
50 if dir.is_dir() {
51 if let Ok(entries) = fs::read_dir(dir) {
52 for entry in entries {
53 if let Ok(entry) = entry {
54 let path = entry.path();
55 if path.is_dir() {
56 self = self.visit_dirs(&path); } else {
58 let p = path
60 .to_str()
61 .unwrap_or("")
62 .to_string()
63 .chars()
64 .skip(2)
65 .collect::<String>();
66 self = self.file(format!("/{p}").as_str(), p);
67 }
68 }
69 }
70 }
71 }
72 self
73 }
74
75 pub fn route(mut self, path: &str, method_handler: MethodHandler) -> Self {
77 self.routes.insert(
78 (method_handler.method, path.to_string()),
79 method_handler.handler,
80 );
81 self
82 }
83
84 pub fn file(self, path: &str, file_path: String) -> Self {
85 let file_path = Arc::new(file_path); self.route(
88 path,
89 get(move |query| {
90 let file_path = Arc::clone(&file_path);
91 async move {
92 match query
93 .get("download")
94 .unwrap_or(&"false".to_string())
95 .as_str()
96 {
97 "1" | "" | "true" => File(file_path.to_string(), true),
98 _ => File(file_path.to_string(), false),
99 }
100 }
101 }),
102 )
103 }
104
105 pub(crate) async fn handle_request(&self, req: Request<Body>) -> Response<Body> {
107 let method = req.method().clone();
109 let path = req.uri().path().to_string();
110 let origin = req.headers()
111 .get("origin")
112 .and_then(|h| h.to_str().ok())
113 .map(|s| s.to_owned()); if method == Method::OPTIONS {
117 if let Some(cors) = &self.cors {
118 return cors.handle_preflight(origin.as_deref());
119 }
120 }
121
122 let key = (method, path);
123
124 match self.routes.get(&key) {
125 Some(handler) => {
126 let mut response = handler(req).await; if let Some(cors) = &self.cors {
129 cors.apply_headers(&mut response, origin.as_deref());
130 }
131 response
132 }
133 None => Response::builder()
134 .status(StatusCode::NOT_FOUND)
135 .body(Body::from("404 Not Found"))
136 .unwrap(),
137 }
138}
139}
140
141impl Default for Router {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147pub struct MethodHandler {
149 pub(crate) method: Method,
150 pub(crate) handler: HandlerFn,
151}
152
153pub fn get<F, Fut, Res>(handler: F) -> MethodHandler
155where
156 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
157 Fut: Future<Output = Res> + Send + 'static,
158 Res: IntoResponse + 'static,
159{
160 let handler_fn =
161 Box::new(
162 move |req: Request<Body>| match serde_from_str(req.uri().query().unwrap_or("")) {
163 Ok(query) => {
164 let fut = handler(query);
165 Box::pin(async move { fut.await.into_response() }) as HandlerFuture
166 }
167 Err(e) => Box::pin(async move {
168 Response::new(Body::from(format!("{:?}<br />Invalid query parameters", e)))
169 }) as HandlerFuture,
170 },
171 );
172
173 MethodHandler {
174 method: Method::GET,
175 handler: handler_fn,
176 }
177}
178
179pub fn post<F, Fut, Res>(handler: F) -> MethodHandler
181where
182 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
183 Fut: Future<Output = Res> + Send + 'static,
184 Res: IntoResponse + 'static,
185{
186 let handler = Arc::new(handler);
187 let handler_fn = Box::new(move |req: Request<Body>| {
188 let handler = Arc::clone(&handler);
189 Box::pin(async move {
190 handler(
191 serde_from_bytes(&to_bytes(req.into_body()).await.unwrap_or_default())
192 .unwrap_or_default(),
193 )
194 .await
195 .into_response()
196 }) as HandlerFuture
197 });
198
199 MethodHandler {
200 method: Method::POST,
201 handler: handler_fn,
202 }
203}
204
205pub fn post_json<F, Fut, Res, T>(handler: F) -> MethodHandler
207where
208 F: Fn(T) -> Fut + Send + Sync + 'static,
209 Fut: Future<Output = Res> + Send + 'static,
210 Res: IntoResponse + 'static,
211 T: DeserializeOwned + Send + 'static,
212{
213 let handler = Arc::new(handler);
214 let handler_fn = Box::new(move |req: Request<Body>| {
215 let handler = Arc::clone(&handler);
216 Box::pin(async move {
217 let body_bytes = match to_bytes(req.into_body()).await {
219 Ok(bytes) => bytes,
220 Err(e) => {
221 return Response::builder()
222 .status(StatusCode::BAD_REQUEST)
223 .body(Body::from(format!("读取请求体失败: {}", e)))
224 .unwrap();
225 }
226 };
227
228 match serde_from_json::<T>(&body_bytes) {
230 Ok(data) => handler(data).await.into_response(),
231 Err(e) => Response::builder()
232 .status(StatusCode::BAD_REQUEST)
233 .body(Body::from(format!("无效的 JSON 格式: {}", e)))
234 .unwrap(),
235 }
236 }) as HandlerFuture
237 });
238
239 MethodHandler {
240 method: Method::POST,
241 handler: handler_fn,
242 }
243}