Skip to main content

nimble_http/
router.rs

1//! router.rs
2
3use crate::cors::Cors;
4use crate::response::File;
5use crate::response::IntoResponse;
6use hyper::body::to_bytes;
7use hyper::{Body, Method, Request, Response, StatusCode};
8use serde::de::DeserializeOwned;
9use serde_json::from_slice as serde_from_json;
10use serde_urlencoded::{from_bytes as serde_from_bytes, from_str as serde_from_str};
11use std::collections::HashMap;
12use std::fs;
13use std::future::Future;
14use std::net::SocketAddr;
15use std::path::Path;
16use std::pin::Pin;
17use std::sync::{Arc, Mutex};
18
19type HandlerFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
20type HandlerFn = Box<dyn Fn(Request<Body>) -> HandlerFuture + Send + Sync>;
21
22/// 路由处理器
23pub struct Router {
24	routes: HashMap<(Method, String), HandlerFn>,
25	cors: Option<Cors>,
26	pub(crate) enable_logger: bool,
27	pub(crate) outputs: Arc<Mutex<String>>,
28}
29
30impl Router {
31	/// 创建新的路由器
32	pub fn new() -> Self {
33		Self {
34			routes: HashMap::new(),
35			cors: None,
36			enable_logger: false,
37			outputs: Arc::new(Mutex::new(String::new())),
38		}
39		.visit_dirs(Path::new("./static"))
40	}
41
42	// 新增:设置 CORS
43	pub fn with_cors(mut self, cors: Cors) -> Self {
44		self.cors = Some(cors);
45		self
46	}
47
48	// 新增:使用默认 CORS
49	pub fn cors(mut self) -> Self {
50		self.cors = Some(Cors::default());
51		self
52	}
53
54	fn visit_dirs(mut self, dir: &Path) -> Self {
55		if dir.is_dir() {
56			if let Ok(entries) = fs::read_dir(dir) {
57				for entry in entries {
58					if let Ok(entry) = entry {
59						let path = entry.path();
60						if path.is_dir() {
61							self = self.visit_dirs(&path); // 递归子目录
62						} else {
63							// 在这里处理文件,例如打印路径
64							let p = path
65								.to_str()
66								.unwrap_or("")
67								.to_string()
68								.chars()
69								.skip(2)
70								.collect::<String>();
71							self = self.file(format!("/{p}").as_str(), p);
72						}
73					}
74				}
75			}
76		}
77		self
78	}
79
80	/// 添加路由
81	pub fn route(mut self, path: &str, method_handler: MethodHandler) -> Self {
82		self.routes.insert(
83			(method_handler.method, path.to_string()),
84			method_handler.handler,
85		);
86		self
87	}
88
89	pub fn file(self, path: &str, file_path: String) -> Self {
90		let file_path = Arc::new(file_path); // 在外面创建 Arc
91
92		self.route(
93			path,
94			get(move |query| {
95				let file_path = Arc::clone(&file_path);
96				async move {
97					match query
98						.get("download")
99						.unwrap_or(&"false".to_string())
100						.as_str()
101					{
102						"1" | "" | "true" => File(file_path.to_string(), true),
103						_ => File(file_path.to_string(), false),
104					}
105				}
106			}),
107		)
108	}
109
110	/// 处理传入的请求
111	pub(crate) async fn handle_request(&mut self, req: Request<Body>) -> Response<Body> {
112		// 提前取出需要的信息
113		let method = req.method().clone();
114		let path = req.uri().path().to_string();
115		let client_ip = req
116			.extensions()
117			.get::<SocketAddr>()
118			.map(|addr| addr.ip().to_string())
119			.unwrap_or_else(|| "unknown".to_string());
120		let origin = req
121			.headers()
122			.get("origin")
123			.and_then(|h| h.to_str().ok())
124			.map(|s| s.to_owned()); // 提前取出 origin
125
126		// 处理 OPTIONS 预检请求
127		if method == Method::OPTIONS {
128			if let Some(cors) = &self.cors {
129				return cors.handle_preflight(origin.as_deref());
130			}
131		}
132
133		match self.routes.get(&(method.clone(), path.clone())) {
134			Some(handler) => {
135				let mut response = handler(req).await;
136				let status = response.status();
137
138				if let Some(cors) = &self.cors {
139					cors.apply_headers(&mut response, origin.as_deref());
140				}
141				if self.enable_logger {
142					self.request_log(client_ip, &method, path, status);
143				}
144				response
145			}
146			None => {
147				self.request_log(client_ip, &method, path, StatusCode::NOT_FOUND);
148				Response::builder()
149					.status(StatusCode::NOT_FOUND)
150					.body(Body::from("404 Not Found"))
151					.unwrap()
152			}
153		}
154	}
155}
156
157impl Default for Router {
158	fn default() -> Self {
159		Self::new()
160	}
161}
162
163/// 方法处理器构建器
164pub struct MethodHandler {
165	pub(crate) method: Method,
166	pub(crate) handler: HandlerFn,
167}
168
169/// 创建 GET 路由
170pub fn get<F, Fut, Res>(handler: F) -> MethodHandler
171where
172	F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
173	Fut: Future<Output = Res> + Send + 'static,
174	Res: IntoResponse + 'static,
175{
176	let handler_fn =
177		Box::new(
178			move |req: Request<Body>| match serde_from_str(req.uri().query().unwrap_or("")) {
179				Ok(query) => {
180					let fut = handler(query);
181					Box::pin(async move { fut.await.into_response() }) as HandlerFuture
182				}
183				Err(e) => Box::pin(async move {
184					Response::builder()
185						.status(StatusCode::BAD_REQUEST)
186						.body(Body::from(format!("{:?}<br />Invalid query parameters", e)))
187						.unwrap()
188				}) as HandlerFuture,
189			},
190		);
191
192	MethodHandler {
193		method: Method::GET,
194		handler: handler_fn,
195	}
196}
197
198/// 创建 POST 路由
199pub fn post<F, Fut, Res>(handler: F) -> MethodHandler
200where
201	F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
202	Fut: Future<Output = Res> + Send + 'static,
203	Res: IntoResponse + 'static,
204{
205	let handler = Arc::new(handler);
206	let handler_fn = Box::new(move |req: Request<Body>| {
207		let handler = Arc::clone(&handler);
208		Box::pin(async move {
209			handler(
210				serde_from_bytes(&to_bytes(req.into_body()).await.unwrap_or_default())
211					.unwrap_or_default(),
212			)
213			.await
214			.into_response()
215		}) as HandlerFuture
216	});
217
218	MethodHandler {
219		method: Method::POST,
220		handler: handler_fn,
221	}
222}
223
224/// 创建 POST JSON 路由
225pub fn post_json<F, Fut, Res, T>(handler: F) -> MethodHandler
226where
227	F: Fn(T) -> Fut + Send + Sync + 'static,
228	Fut: Future<Output = Res> + Send + 'static,
229	Res: IntoResponse + 'static,
230	T: DeserializeOwned + Send + 'static,
231{
232	let handler = Arc::new(handler);
233	let handler_fn = Box::new(move |req: Request<Body>| {
234		let handler = Arc::clone(&handler);
235		Box::pin(async move {
236			// 读取 body
237			let body_bytes = match to_bytes(req.into_body()).await {
238				Ok(bytes) => bytes,
239				Err(e) => {
240					return Response::builder()
241						.status(StatusCode::BAD_REQUEST)
242						.body(Body::from(format!("读取请求体失败: {}", e)))
243						.unwrap();
244				}
245			};
246
247			// 解析 JSON
248			match serde_from_json::<T>(&body_bytes) {
249				Ok(data) => handler(data).await.into_response(),
250				Err(e) => Response::builder()
251					.status(StatusCode::BAD_REQUEST)
252					.body(Body::from(format!("无效的 JSON 格式: {}", e)))
253					.unwrap(),
254			}
255		}) as HandlerFuture
256	});
257
258	MethodHandler {
259		method: Method::POST,
260		handler: handler_fn,
261	}
262}