Skip to main content

nimble_http/
router.rs

1//! router.rs
2
3use crate::cors::Cors;
4use crate::response::File;
5use crate::response::IntoResponse;
6use hyper::body::to_bytes;
7use hyper::{Body, Method, Request, Response, StatusCode};
8use serde::de::DeserializeOwned;
9use serde_json::from_slice as serde_from_json;
10use serde_urlencoded::{from_bytes as serde_from_bytes, from_str as serde_from_str};
11use std::collections::HashMap;
12use std::fs;
13use std::future::Future;
14use std::net::SocketAddr;
15use std::path::Path;
16use std::pin::Pin;
17use std::sync::Arc;
18
19type HandlerFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
20type HandlerFn = Box<dyn Fn(Request<Body>) -> HandlerFuture + Send + Sync>;
21
22/// 路由处理器
23pub struct Router {
24	routes: HashMap<(Method, String), HandlerFn>,
25	cors: Option<Cors>,
26	pub enable_logger: bool,
27}
28
29impl Router {
30	/// 创建新的路由器
31	pub fn new() -> Self {
32		Self {
33			routes: HashMap::new(),
34			cors: None,
35			enable_logger: false,
36		}
37		.visit_dirs(Path::new("./static"))
38	}
39
40	// 新增:设置 CORS
41	pub fn with_cors(mut self, cors: Cors) -> Self {
42		self.cors = Some(cors);
43		self
44	}
45
46	// 新增:使用默认 CORS
47	pub fn cors(mut self) -> Self {
48		self.cors = Some(Cors::default());
49		self
50	}
51
52	fn visit_dirs(mut self, dir: &Path) -> Self {
53		if dir.is_dir() {
54			if let Ok(entries) = fs::read_dir(dir) {
55				for entry in entries {
56					if let Ok(entry) = entry {
57						let path = entry.path();
58						if path.is_dir() {
59							self = self.visit_dirs(&path); // 递归子目录
60						} else {
61							// 在这里处理文件,例如打印路径
62							let p = path
63								.to_str()
64								.unwrap_or("")
65								.to_string()
66								.chars()
67								.skip(2)
68								.collect::<String>();
69							self = self.file(format!("/{p}").as_str(), p);
70						}
71					}
72				}
73			}
74		}
75		self
76	}
77
78	/// 添加路由
79	pub fn route(mut self, path: &str, method_handler: MethodHandler) -> Self {
80		self.routes.insert(
81			(method_handler.method, path.to_string()),
82			method_handler.handler,
83		);
84		self
85	}
86
87	pub fn file(self, path: &str, file_path: String) -> Self {
88		let file_path = Arc::new(file_path); // 在外面创建 Arc
89
90		self.route(
91			path,
92			get(move |query| {
93				let file_path = Arc::clone(&file_path);
94				async move {
95					match query
96						.get("download")
97						.unwrap_or(&"false".to_string())
98						.as_str()
99					{
100						"1" | "" | "true" => File(file_path.to_string(), true),
101						_ => File(file_path.to_string(), false),
102					}
103				}
104			}),
105		)
106	}
107
108	/// 处理传入的请求
109	pub(crate) async fn handle_request(&self, req: Request<Body>) -> Response<Body> {
110		// 提前取出需要的信息
111		let method = req.method().clone();
112		let path = req.uri().path().to_string();
113		let client_ip = req
114			.extensions()
115			.get::<SocketAddr>()
116			.map(|addr| addr.ip().to_string())
117			.unwrap_or_else(|| "unknown".to_string());
118		let origin = req
119			.headers()
120			.get("origin")
121			.and_then(|h| h.to_str().ok())
122			.map(|s| s.to_owned()); // 提前取出 origin
123
124		// 处理 OPTIONS 预检请求
125		if method == Method::OPTIONS {
126			if let Some(cors) = &self.cors {
127				return cors.handle_preflight(origin.as_deref());
128			}
129		}
130
131		match self.routes.get(&(method.clone(), path.clone())) {
132			Some(handler) => {
133				let mut response = handler(req).await;
134				let status = response.status();
135
136				if let Some(cors) = &self.cors {
137					cors.apply_headers(&mut response, origin.as_deref());
138				}
139				if self.enable_logger {
140					self.request_log(client_ip, &method, path, status);
141				}
142				response
143			}
144			None => {
145				self.request_log(client_ip, &method, path, StatusCode::NOT_FOUND);
146				Response::builder()
147					.status(StatusCode::NOT_FOUND)
148					.body(Body::from("404 Not Found"))
149					.unwrap()
150			}
151		}
152	}
153}
154
155impl Default for Router {
156	fn default() -> Self {
157		Self::new()
158	}
159}
160
161/// 方法处理器构建器
162pub struct MethodHandler {
163	pub(crate) method: Method,
164	pub(crate) handler: HandlerFn,
165}
166
167/// 创建 GET 路由
168pub fn get<F, Fut, Res>(handler: F) -> MethodHandler
169where
170	F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
171	Fut: Future<Output = Res> + Send + 'static,
172	Res: IntoResponse + 'static,
173{
174	let handler_fn =
175		Box::new(
176			move |req: Request<Body>| match serde_from_str(req.uri().query().unwrap_or("")) {
177				Ok(query) => {
178					let fut = handler(query);
179					Box::pin(async move { fut.await.into_response() }) as HandlerFuture
180				}
181				Err(e) => Box::pin(async move {
182					Response::new(Body::from(format!("{:?}<br />Invalid query parameters", e)))
183				}) as HandlerFuture,
184			},
185		);
186
187	MethodHandler {
188		method: Method::GET,
189		handler: handler_fn,
190	}
191}
192
193/// 创建 POST 路由
194pub fn post<F, Fut, Res>(handler: F) -> MethodHandler
195where
196	F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
197	Fut: Future<Output = Res> + Send + 'static,
198	Res: IntoResponse + 'static,
199{
200	let handler = Arc::new(handler);
201	let handler_fn = Box::new(move |req: Request<Body>| {
202		let handler = Arc::clone(&handler);
203		Box::pin(async move {
204			handler(
205				serde_from_bytes(&to_bytes(req.into_body()).await.unwrap_or_default())
206					.unwrap_or_default(),
207			)
208			.await
209			.into_response()
210		}) as HandlerFuture
211	});
212
213	MethodHandler {
214		method: Method::POST,
215		handler: handler_fn,
216	}
217}
218
219/// 创建 POST JSON 路由
220pub fn post_json<F, Fut, Res, T>(handler: F) -> MethodHandler
221where
222	F: Fn(T) -> Fut + Send + Sync + 'static,
223	Fut: Future<Output = Res> + Send + 'static,
224	Res: IntoResponse + 'static,
225	T: DeserializeOwned + Send + 'static,
226{
227	let handler = Arc::new(handler);
228	let handler_fn = Box::new(move |req: Request<Body>| {
229		let handler = Arc::clone(&handler);
230		Box::pin(async move {
231			// 读取 body
232			let body_bytes = match to_bytes(req.into_body()).await {
233				Ok(bytes) => bytes,
234				Err(e) => {
235					return Response::builder()
236						.status(StatusCode::BAD_REQUEST)
237						.body(Body::from(format!("读取请求体失败: {}", e)))
238						.unwrap();
239				}
240			};
241
242			// 解析 JSON
243			match serde_from_json::<T>(&body_bytes) {
244				Ok(data) => handler(data).await.into_response(),
245				Err(e) => Response::builder()
246					.status(StatusCode::BAD_REQUEST)
247					.body(Body::from(format!("无效的 JSON 格式: {}", e)))
248					.unwrap(),
249			}
250		}) as HandlerFuture
251	});
252
253	MethodHandler {
254		method: Method::POST,
255		handler: handler_fn,
256	}
257}