1use crate::error::{Error, Result};
2use crate::types::{OxiditeRequest, OxiditeResponse};
3use hyper::Method;
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use tower_service::Service;
10use regex::Regex;
11
12pub trait Handler: Send + Sync + 'static {
13 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
14}
15
16impl<F, Fut> Handler for F
17where
18 F: Fn(OxiditeRequest) -> Fut + Send + Sync + 'static,
19 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
20{
21 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
22 Box::pin(self(req))
23 }
24}
25
26struct Route {
27 pattern: Regex,
28 param_names: Vec<String>,
29 handler: Arc<dyn Handler>,
30}
31
32#[derive(Clone)]
33pub struct Router {
34 routes: HashMap<Method, Vec<Arc<Route>>>,
35}
36
37impl Router {
38 pub fn new() -> Self {
39 Self {
40 routes: HashMap::new(),
41 }
42 }
43
44 pub fn get<H>(&mut self, path: &str, handler: H)
45 where
46 H: Handler,
47 {
48 self.add_route(Method::GET, path, handler);
49 }
50
51 pub fn post<H>(&mut self, path: &str, handler: H)
52 where
53 H: Handler,
54 {
55 self.add_route(Method::POST, path, handler);
56 }
57
58 pub fn put<H>(&mut self, path: &str, handler: H)
59 where
60 H: Handler,
61 {
62 self.add_route(Method::PUT, path, handler);
63 }
64
65 pub fn delete<H>(&mut self, path: &str, handler: H)
66 where
67 H: Handler,
68 {
69 self.add_route(Method::DELETE, path, handler);
70 }
71
72 pub fn patch<H>(&mut self, path: &str, handler: H)
73 where
74 H: Handler,
75 {
76 self.add_route(Method::PATCH, path, handler);
77 }
78
79 fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
80 where
81 H: Handler,
82 {
83 let (pattern, param_names) = compile_path(path);
84 let route = Arc::new(Route {
85 pattern,
86 param_names,
87 handler: Arc::new(handler),
88 });
89
90 self.routes
91 .entry(method)
92 .or_insert_with(Vec::new)
93 .push(route);
94 }
95
96 pub async fn handle(&self, mut req: OxiditeRequest) -> Result<OxiditeResponse> {
97 let method = req.method().clone();
98 let path = req.uri().path().to_string();
99 let path_for_error = path.clone();
100
101 let try_match = |target_method: &Method, req: &mut OxiditeRequest| -> Option<Arc<Route>> {
103 if let Some(routes) = self.routes.get(target_method) {
104 for route in routes {
105 if let Some(captures) = route.pattern.captures(&path) {
106 let mut params = serde_json::Map::new();
108 for (i, name) in route.param_names.iter().enumerate() {
109 if let Some(value) = captures.get(i + 1) {
110 params.insert(
111 name.clone(),
112 serde_json::Value::String(value.as_str().to_string()),
113 );
114 }
115 }
116
117 if !params.is_empty() {
119 req.extensions_mut().insert(crate::extract::PathParams(
120 serde_json::Value::Object(params),
121 ));
122 }
123
124 return Some(route.clone());
125 }
126 }
127 }
128 None
129 };
130
131 if let Some(route) = try_match(&method, &mut req) {
133 return route.handler.call(req).await;
134 }
135
136 if method == Method::HEAD {
138 if let Some(route) = try_match(&Method::GET, &mut req) {
139 return route.handler.call(req).await;
142 }
143 }
144
145 eprintln!("🔍 Route not found: {} {}", method, path_for_error);
147 Err(Error::NotFound("Route not found".to_string()))
148 }
149}
150
151impl Service<OxiditeRequest> for Router {
152 type Response = OxiditeResponse;
153 type Error = Error;
154 type Future = Pin<Box<dyn Future<Output = Result<Self::Response>> + Send>>;
155
156 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
157 Poll::Ready(Ok(()))
158 }
159
160 fn call(&mut self, req: OxiditeRequest) -> Self::Future {
161 let router = self.clone();
162 Box::pin(async move {
163 router.handle(req).await
164 })
165 }
166}
167
168impl Default for Router {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
174fn compile_path(path: &str) -> (Regex, Vec<String>) {
177 let mut pattern = String::from("^");
178 let mut param_names = Vec::new();
179 let mut chars = path.chars().peekable();
180
181 while let Some(ch) = chars.next() {
182 match ch {
183 ':' => {
184 let mut param_name = String::new();
186 while let Some(&next_ch) = chars.peek() {
187 if next_ch.is_alphanumeric() || next_ch == '_' {
188 param_name.push(next_ch);
189 chars.next();
190 } else {
191 break;
192 }
193 }
194 param_names.push(param_name);
195 pattern.push_str("([^/]+)");
196 }
197 '*' => {
198 pattern.push_str("(.*)");
200 }
201 '.' | '+' | '?' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '\\' => {
202 pattern.push('\\');
204 pattern.push(ch);
205 }
206 _ => {
207 pattern.push(ch);
208 }
209 }
210 }
211
212 pattern.push('$');
213 let regex = Regex::new(&pattern).expect("Invalid route pattern");
214 (regex, param_names)
215}
216
217#[cfg(test)]
218mod tests {
219 use super::*;
220
221 #[test]
222 fn test_compile_path() {
223 let (regex, params) = compile_path("/users/:id");
224 assert_eq!(params, vec!["id"]);
225 assert!(regex.is_match("/users/123"));
226 assert!(!regex.is_match("/users/123/posts"));
227
228 let (regex, params) = compile_path("/users/:user_id/posts/:post_id");
229 assert_eq!(params, vec!["user_id", "post_id"]);
230 assert!(regex.is_match("/users/1/posts/2"));
231 }
232
233 #[test]
234 fn test_exact_match() {
235 let (regex, params) = compile_path("/users");
236 assert_eq!(params.len(), 0);
237 assert!(regex.is_match("/users"));
238 assert!(!regex.is_match("/users/123"));
239 }
240}