Skip to main content

nimble_http/
router.rs

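//! router.rs: dispatches HTTP requests to async handlers keyed by exact (method, path), with optional CORS and auto-registered `./static` file routes.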
2
3use crate::response::File;
4use crate::response::IntoResponse;
5use crate::cors::Cors;
6use hyper::body::to_bytes;
7use hyper::{Body, Method, Request, Response, StatusCode};
8use serde::de::DeserializeOwned;
9use serde_json::from_slice as serde_from_json;
10use serde_urlencoded::{from_bytes as serde_from_bytes, from_str as serde_from_str};
11use std::collections::HashMap;
12use std::fs;
13use std::future::Future;
14use std::path::Path;
15use std::pin::Pin;
16use std::sync::Arc;
17
18type HandlerFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
19type HandlerFn = Box<dyn Fn(Request<Body>) -> HandlerFuture + Send + Sync>;
20
21/// 路由处理器
22pub struct Router {
23    routes: HashMap<(Method, String), HandlerFn>,
24    cors: Option<Cors>,  // 新增:可选的 CORS 配置
25}
26
27impl Router {
28    /// 创建新的路由器
29    pub fn new() -> Self {
30        Self {
31            routes: HashMap::new(),
32            cors: None,  // 默认没有 CORS
33        }
34        .visit_dirs(Path::new("./static"))
35    }
36    
37    // 新增:设置 CORS
38    pub fn with_cors(mut self, cors: Cors) -> Self {
39        self.cors = Some(cors);
40        self
41    }
42    
43    // 新增:使用默认 CORS
44    pub fn cors(mut self) -> Self {
45        self.cors = Some(Cors::default());
46        self
47    }
48
49    fn visit_dirs(mut self, dir: &Path) -> Self {
50        if dir.is_dir() {
51            if let Ok(entries) = fs::read_dir(dir) {
52                for entry in entries {
53                    if let Ok(entry) = entry {
54                        let path = entry.path();
55                        if path.is_dir() {
56                            self = self.visit_dirs(&path); // 递归子目录
57                        } else {
58                            // 在这里处理文件,例如打印路径
59                            let p = path
60                                .to_str()
61                                .unwrap_or("")
62                                .to_string()
63                                .chars()
64                                .skip(2)
65                                .collect::<String>();
66                            self = self.file(format!("/{p}").as_str(), p);
67                        }
68                    }
69                }
70            }
71        }
72        self
73    }
74
75    /// 添加路由
76    pub fn route(mut self, path: &str, method_handler: MethodHandler) -> Self {
77        self.routes.insert(
78            (method_handler.method, path.to_string()),
79            method_handler.handler,
80        );
81        self
82    }
83
84    pub fn file(self, path: &str, file_path: String) -> Self {
85        let file_path = Arc::new(file_path); // 在外面创建 Arc
86
87        self.route(
88            path,
89            get(move |query| {
90                let file_path = Arc::clone(&file_path);
91                async move {
92                    match query
93                        .get("download")
94                        .unwrap_or(&"false".to_string())
95                        .as_str()
96                    {
97                        "1" | "" | "true" => File(file_path.to_string(), true),
98                        _ => File(file_path.to_string(), false),
99                    }
100                }
101            }),
102        )
103    }
104
105    /// 处理传入的请求
106pub(crate) async fn handle_request(&self, req: Request<Body>) -> Response<Body> {
107    // 提前取出需要的信息
108    let method = req.method().clone();
109    let path = req.uri().path().to_string();
110    let origin = req.headers()
111        .get("origin")
112        .and_then(|h| h.to_str().ok())
113        .map(|s| s.to_owned());  // 提前取出 origin
114    
115    // 处理 OPTIONS 预检请求
116    if method == Method::OPTIONS {
117        if let Some(cors) = &self.cors {
118            return cors.handle_preflight(origin.as_deref());
119        }
120    }
121    
122    let key = (method, path);
123
124    match self.routes.get(&key) {
125        Some(handler) => {
126            let mut response = handler(req).await;  // req 在这里被消耗
127            // 用之前取出的 origin
128            if let Some(cors) = &self.cors {
129                cors.apply_headers(&mut response, origin.as_deref());
130            }
131            response
132        }
133        None => Response::builder()
134            .status(StatusCode::NOT_FOUND)
135            .body(Body::from("404 Not Found"))
136            .unwrap(),
137    }
138}
139}
140
141impl Default for Router {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147/// 方法处理器构建器
148pub struct MethodHandler {
149    pub(crate) method: Method,
150    pub(crate) handler: HandlerFn,
151}
152
153/// 创建 GET 路由
154pub fn get<F, Fut, Res>(handler: F) -> MethodHandler
155where
156    F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
157    Fut: Future<Output = Res> + Send + 'static,
158    Res: IntoResponse + 'static,
159{
160    let handler_fn =
161        Box::new(
162            move |req: Request<Body>| match serde_from_str(req.uri().query().unwrap_or("")) {
163                Ok(query) => {
164                    let fut = handler(query);
165                    Box::pin(async move { fut.await.into_response() }) as HandlerFuture
166                }
167                Err(e) => Box::pin(async move {
168                    Response::new(Body::from(format!("{:?}<br />Invalid query parameters", e)))
169                }) as HandlerFuture,
170            },
171        );
172
173    MethodHandler {
174        method: Method::GET,
175        handler: handler_fn,
176    }
177}
178
179/// 创建 POST 路由
180pub fn post<F, Fut, Res>(handler: F) -> MethodHandler
181where
182    F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
183    Fut: Future<Output = Res> + Send + 'static,
184    Res: IntoResponse + 'static,
185{
186    let handler = Arc::new(handler);
187    let handler_fn = Box::new(move |req: Request<Body>| {
188        let handler = Arc::clone(&handler);
189        Box::pin(async move {
190            handler(
191                serde_from_bytes(&to_bytes(req.into_body()).await.unwrap_or_default())
192                    .unwrap_or_default(),
193            )
194            .await
195            .into_response()
196        }) as HandlerFuture
197    });
198
199    MethodHandler {
200        method: Method::POST,
201        handler: handler_fn,
202    }
203}
204
205/// 创建 POST JSON 路由
206pub fn post_json<F, Fut, Res, T>(handler: F) -> MethodHandler
207where
208    F: Fn(T) -> Fut + Send + Sync + 'static,
209    Fut: Future<Output = Res> + Send + 'static,
210    Res: IntoResponse + 'static,
211    T: DeserializeOwned + Send + 'static,
212{
213    let handler = Arc::new(handler);
214    let handler_fn = Box::new(move |req: Request<Body>| {
215        let handler = Arc::clone(&handler);
216        Box::pin(async move {
217            // 读取 body
218            let body_bytes = match to_bytes(req.into_body()).await {
219                Ok(bytes) => bytes,
220                Err(e) => {
221                    return Response::builder()
222                        .status(StatusCode::BAD_REQUEST)
223                        .body(Body::from(format!("读取请求体失败: {}", e)))
224                        .unwrap();
225                }
226            };
227
228            // 解析 JSON
229            match serde_from_json::<T>(&body_bytes) {
230                Ok(data) => handler(data).await.into_response(),
231                Err(e) => Response::builder()
232                    .status(StatusCode::BAD_REQUEST)
233                    .body(Body::from(format!("无效的 JSON 格式: {}", e)))
234                    .unwrap(),
235            }
236        }) as HandlerFuture
237    });
238
239    MethodHandler {
240        method: Method::POST,
241        handler: handler_fn,
242    }
243}