Skip to main content

nimble_http/
router.rs

1//! router.rs
2
3use crate::response::File;
4use crate::response::IntoResponse;
5use hyper::body::to_bytes;
6use hyper::{Body, Method, Request, Response, StatusCode};
7use serde::de::DeserializeOwned;
8use serde_json::from_slice as serde_from_json;
9use serde_urlencoded::{from_bytes as serde_from_bytes, from_str as serde_from_str};
10use std::collections::HashMap;
11use std::fs;
12use std::future::Future;
13use std::path::Path;
14use std::pin::Pin;
15use std::sync::Arc;
16
17type HandlerFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
18type HandlerFn = Box<dyn Fn(Request<Body>) -> HandlerFuture + Send + Sync>;
19
20/// 路由处理器
21pub struct Router {
22	routes: HashMap<(Method, String), HandlerFn>,
23}
24
25impl Router {
26	/// 创建新的路由器
27	pub fn new() -> Self {
28		Self {
29			routes: HashMap::new(),
30		}
31		.visit_dirs(Path::new("./static"))
32	}
33
34	fn visit_dirs(mut self, dir: &Path) -> Self {
35		if dir.is_dir() {
36			if let Ok(entries) = fs::read_dir(dir) {
37				for entry in entries {
38					if let Ok(entry) = entry {
39						let path = entry.path();
40						if path.is_dir() {
41							self = self.visit_dirs(&path); // 递归子目录
42						} else {
43							// 在这里处理文件,例如打印路径
44							let p = path
45								.to_str()
46								.unwrap_or("")
47								.to_string()
48								.chars()
49								.skip(2)
50								.collect::<String>();
51							self = self.file(format!("/{p}").as_str(), p);
52						}
53					}
54				}
55			}
56		}
57		self
58	}
59
60	/// 添加路由
61	pub fn route(mut self, path: &str, method_handler: MethodHandler) -> Self {
62		self.routes.insert(
63			(method_handler.method, path.to_string()),
64			method_handler.handler,
65		);
66		self
67	}
68
69	pub fn file(self, path: &str, file_path: String) -> Self {
70		let file_path = Arc::new(file_path); // 在外面创建 Arc
71
72		self.route(
73			path,
74			get(move |query| {
75				let file_path = Arc::clone(&file_path);
76				async move {
77					match query
78						.get("download")
79						.unwrap_or(&"false".to_string())
80						.as_str()
81					{
82						"1" | "" | "true" => File(file_path.to_string(), true),
83						_ => File(file_path.to_string(), false),
84					}
85				}
86			}),
87		)
88	}
89
90	/// 处理传入的请求
91	pub(crate) async fn handle_request(&self, req: Request<Body>) -> Response<Body> {
92		let key = (req.method().clone(), req.uri().path().to_string());
93
94		match self.routes.get(&key) {
95			Some(handler) => handler(req).await,
96			None => Response::builder()
97				.status(StatusCode::NOT_FOUND)
98				.body(Body::from("404 Not Found"))
99				.unwrap(),
100		}
101	}
102}
103
104impl Default for Router {
105	fn default() -> Self {
106		Self::new()
107	}
108}
109
110/// 方法处理器构建器
111pub struct MethodHandler {
112	pub(crate) method: Method,
113	pub(crate) handler: HandlerFn,
114}
115
116/// 创建 GET 路由
117pub fn get<F, Fut, Res>(handler: F) -> MethodHandler
118where
119	F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
120	Fut: Future<Output = Res> + Send + 'static,
121	Res: IntoResponse + 'static,
122{
123	let handler_fn =
124		Box::new(
125			move |req: Request<Body>| match serde_from_str(req.uri().query().unwrap_or("")) {
126				Ok(query) => {
127					let fut = handler(query);
128					Box::pin(async move { fut.await.into_response() }) as HandlerFuture
129				}
130				Err(e) => Box::pin(async move {
131					Response::new(Body::from(format!("{:?}<br />Invalid query parameters", e)))
132				}) as HandlerFuture,
133			},
134		);
135
136	MethodHandler {
137		method: Method::GET,
138		handler: handler_fn,
139	}
140}
141
142/// 创建 POST 路由
143pub fn post<F, Fut, Res>(handler: F) -> MethodHandler
144where
145	F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
146	Fut: Future<Output = Res> + Send + 'static,
147	Res: IntoResponse + 'static,
148{
149	let handler = Arc::new(handler);
150	let handler_fn = Box::new(move |req: Request<Body>| {
151		let handler = Arc::clone(&handler);
152		Box::pin(async move {
153			handler(
154				serde_from_bytes(&to_bytes(req.into_body()).await.unwrap_or_default())
155					.unwrap_or_default(),
156			)
157			.await
158			.into_response()
159		}) as HandlerFuture
160	});
161
162	MethodHandler {
163		method: Method::POST,
164		handler: handler_fn,
165	}
166}
167
168/// 创建 POST JSON 路由
169pub fn post_json<F, Fut, Res, T>(handler: F) -> MethodHandler
170where
171	F: Fn(T) -> Fut + Send + Sync + 'static,
172	Fut: Future<Output = Res> + Send + 'static,
173	Res: IntoResponse + 'static,
174	T: DeserializeOwned + Send + 'static,
175{
176	let handler = Arc::new(handler);
177	let handler_fn = Box::new(move |req: Request<Body>| {
178		let handler = Arc::clone(&handler);
179		Box::pin(async move {
180			// 读取 body
181			let body_bytes = match to_bytes(req.into_body()).await {
182				Ok(bytes) => bytes,
183				Err(e) => {
184					return Response::builder()
185						.status(StatusCode::BAD_REQUEST)
186						.body(Body::from(format!("读取请求体失败: {}", e)))
187						.unwrap();
188				}
189			};
190
191			// 解析 JSON
192			match serde_from_json::<T>(&body_bytes) {
193				Ok(data) => handler(data).await.into_response(),
194				Err(e) => Response::builder()
195					.status(StatusCode::BAD_REQUEST)
196					.body(Body::from(format!("无效的 JSON 格式: {}", e)))
197					.unwrap(),
198			}
199		}) as HandlerFuture
200	});
201
202	MethodHandler {
203		method: Method::POST,
204		handler: handler_fn,
205	}
206}