Skip to main content

nimble_http/
router.rs

1//! router.rs
2
3#[cfg(feature = "cors")]
4use crate::cors::Cors;
5
6use crate::response::File;
7use crate::response::IntoResponse;
8use hyper::body::to_bytes;
9use hyper::{Body, Method, Request, Response, StatusCode};
10use serde::de::DeserializeOwned;
11use serde_json::from_slice as serde_from_json;
12use serde_urlencoded::{from_bytes as serde_from_bytes, from_str as serde_from_str};
13use std::collections::HashMap;
14use std::fs;
15use std::future::Future;
16use std::net::SocketAddr;
17use std::path::Path;
18use std::pin::Pin;
19use std::sync::{Arc, Mutex};
20
21type HandlerFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
22type HandlerFn = Box<dyn Fn(Request<Body>) -> HandlerFuture + Send + Sync>;
23
24/// 路由处理器
25pub struct Router {
26	routes: HashMap<(Method, String), HandlerFn>,
27	
28	#[cfg(feature = "cors")]
29	cors: Option<Cors>,
30	
31	pub(crate) enable_logger: bool,
32	pub(crate) outputs: Arc<Mutex<String>>,
33}
34
35impl Router {
36	/// 创建新的路由器
37	pub fn new() -> Self {
38		Self {
39			routes: HashMap::new(),
40			
41			#[cfg(feature = "cors")]
42			cors: None,
43			
44			enable_logger: false,
45			outputs: Arc::new(Mutex::new(String::new())),
46		}
47		.visit_dirs(Path::new("./static"))
48	}
49
50	#[cfg(feature = "cors")]
51	pub fn with_cors(mut self, cors: Cors) -> Self {
52		self.cors = Some(cors);
53		self
54	}
55
56	#[cfg(feature = "cors")]
57	pub fn cors(mut self) -> Self {
58		self.cors = Some(Cors::default());
59		self
60	}
61
62	fn visit_dirs(mut self, dir: &Path) -> Self {
63		if dir.is_dir() {
64			if let Ok(entries) = fs::read_dir(dir) {
65				for entry in entries {
66					if let Ok(entry) = entry {
67						let path = entry.path();
68						if path.is_dir() {
69							self = self.visit_dirs(&path); // 递归子目录
70						} else {
71							// 在这里处理文件,例如打印路径
72							let p = path
73								.to_str()
74								.unwrap_or("")
75								.to_string()
76								.chars()
77								.skip(2)
78								.collect::<String>();
79							self = self.file(format!("/{p}").as_str(), p);
80						}
81					}
82				}
83			}
84		}
85		self
86	}
87
88	/// 添加路由
89	pub fn route(mut self, path: &str, method_handler: MethodHandler) -> Self {
90		self.routes.insert(
91			(method_handler.method, path.to_string()),
92			method_handler.handler,
93		);
94		self
95	}
96
97	pub fn file(self, path: &str, file_path: String) -> Self {
98		let file_path = Arc::new(file_path); // 在外面创建 Arc
99
100		self.route(
101			path,
102			get(move |query| {
103				let file_path = Arc::clone(&file_path);
104				async move {
105					match query
106						.get("download")
107						.unwrap_or(&"false".to_string())
108						.as_str()
109					{
110						"1" | "" | "true" => File(file_path.to_string(), true),
111						_ => File(file_path.to_string(), false),
112					}
113				}
114			}),
115		)
116	}
117
118	/// 处理传入的请求
119	pub(crate) async fn handle_request(&mut self, req: Request<Body>) -> Response<Body> {
120		// 提前取出需要的信息
121		let method = req.method().clone();
122		let path = req.uri().path().to_string();
123		let client_ip = req
124			.extensions()
125			.get::<SocketAddr>()
126			.map(|addr| addr.ip().to_string())
127			.unwrap_or_else(|| "unknown".to_string());
128		#[allow(unused_variables)]
129		let origin = req
130			.headers()
131			.get("origin")
132			.and_then(|h| h.to_str().ok())
133			.map(|s| s.to_owned()); // 提前取出 origin
134
135		// 处理 OPTIONS 预检请求
136		if method == Method::OPTIONS {
137			#[cfg(feature = "cors")]
138			{
139				if let Some(cors) = &self.cors {
140					return cors.handle_preflight(origin.as_deref());
141				}
142			}
143		}
144
145		match self.routes.get(&(method.clone(), path.clone())) {
146			Some(handler) => {
147				#[allow(unused_mut)]
148				let mut response = handler(req).await;
149				let status = response.status();
150				#[cfg(feature = "cors")]
151				{
152					if let Some(cors) = &self.cors {
153						cors.apply_headers(&mut response, origin.as_deref());
154					}
155				}
156				if self.enable_logger {
157					self.request_log(client_ip, &method, path, status);
158				}
159				response
160			}
161			None => {
162				self.request_log(client_ip, &method, path, StatusCode::NOT_FOUND);
163				Response::builder()
164					.status(StatusCode::NOT_FOUND)
165					.body(Body::from("404 Not Found"))
166					.unwrap()
167			}
168		}
169	}
170}
171
172impl Default for Router {
173	fn default() -> Self {
174		Self::new()
175	}
176}
177
178/// 方法处理器构建器
179pub struct MethodHandler {
180	pub(crate) method: Method,
181	pub(crate) handler: HandlerFn,
182}
183
184/// 创建 GET 路由
185pub fn get<F, Fut, Res>(handler: F) -> MethodHandler
186where
187	F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
188	Fut: Future<Output = Res> + Send + 'static,
189	Res: IntoResponse + 'static,
190{
191	let handler_fn =
192		Box::new(
193			move |req: Request<Body>| match serde_from_str(req.uri().query().unwrap_or("")) {
194				Ok(query) => {
195					let fut = handler(query);
196					Box::pin(async move { fut.await.into_response() }) as HandlerFuture
197				}
198				Err(e) => Box::pin(async move {
199					Response::builder()
200						.status(StatusCode::BAD_REQUEST)
201						.body(Body::from(format!("{:?}<br />Invalid query parameters", e)))
202						.unwrap()
203				}) as HandlerFuture,
204			},
205		);
206
207	MethodHandler {
208		method: Method::GET,
209		handler: handler_fn,
210	}
211}
212
213/// 创建 POST 路由
214pub fn post<F, Fut, Res>(handler: F) -> MethodHandler
215where
216	F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
217	Fut: Future<Output = Res> + Send + 'static,
218	Res: IntoResponse + 'static,
219{
220	let handler = Arc::new(handler);
221	let handler_fn = Box::new(move |req: Request<Body>| {
222		let handler = Arc::clone(&handler);
223		Box::pin(async move {
224			handler(
225				serde_from_bytes(&to_bytes(req.into_body()).await.unwrap_or_default())
226					.unwrap_or_default(),
227			)
228			.await
229			.into_response()
230		}) as HandlerFuture
231	});
232
233	MethodHandler {
234		method: Method::POST,
235		handler: handler_fn,
236	}
237}
238
239/// 创建 POST JSON 路由
240pub fn post_json<F, Fut, Res, T>(handler: F) -> MethodHandler
241where
242	F: Fn(T) -> Fut + Send + Sync + 'static,
243	Fut: Future<Output = Res> + Send + 'static,
244	Res: IntoResponse + 'static,
245	T: DeserializeOwned + Send + 'static,
246{
247	let handler = Arc::new(handler);
248	let handler_fn = Box::new(move |req: Request<Body>| {
249		let handler = Arc::clone(&handler);
250		Box::pin(async move {
251			// 读取 body
252			let body_bytes = match to_bytes(req.into_body()).await {
253				Ok(bytes) => bytes,
254				Err(e) => {
255					return Response::builder()
256						.status(StatusCode::BAD_REQUEST)
257						.body(Body::from(format!("读取请求体失败: {}", e)))
258						.unwrap();
259				}
260			};
261
262			// 解析 JSON
263			match serde_from_json::<T>(&body_bytes) {
264				Ok(data) => handler(data).await.into_response(),
265				Err(e) => Response::builder()
266					.status(StatusCode::BAD_REQUEST)
267					.body(Body::from(format!("无效的 JSON 格式: {}", e)))
268					.unwrap(),
269			}
270		}) as HandlerFuture
271	});
272
273	MethodHandler {
274		method: Method::POST,
275		handler: handler_fn,
276	}
277}