lambda_lw_http_router_core/
router.rs

1use crate::{RoutableHttpEvent, RouteContext};
2use lambda_runtime::{Error, LambdaEvent};
3use lazy_static::lazy_static;
4use opentelemetry::{global, trace::Status, Context as OtelContext};
5use opentelemetry_http::HeaderExtractor;
6use regex::Regex;
7use serde_json::{json, Value as JsonValue};
8use std::any::{Any, TypeId};
9use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::sync::Mutex;
14use tracing_opentelemetry::OpenTelemetrySpanExt;
15
16lazy_static! {
17    static ref ROUTE_REGISTRY: Mutex<HashMap<(TypeId, TypeId), Box<dyn Any + Send + Sync>>> =
18        Mutex::new(HashMap::new());
19}
20
21/// The main router type that handles HTTP requests and dispatches them to handlers.
22///
23/// The router matches incoming requests against registered routes and executes
24/// the corresponding handlers. It supports path parameters and different HTTP methods.
25///
26/// # Type Parameters
27///
28/// * `State` - The type of the application state shared across handlers
29/// * `E` - The Lambda event type (must implement `RoutableHttpEvent`)
30///
31/// # Examples
32///
33/// ```rust
34/// use lambda_lw_http_router_core::{Router, RouteContext};
35/// use serde_json::{json, Value};
36/// use lambda_runtime::Error;
37/// use aws_lambda_events::apigw::ApiGatewayV2httpRequest;
38///
39/// #[derive(Clone)]
40/// struct State {}
41///
42/// let mut router: Router<State, ApiGatewayV2httpRequest> = Router::new();
43/// router.register_route("GET", "/hello/{name}", |ctx| async move {
44///     let name = ctx.params.get("name").map_or("World", String::as_str);
45///     Ok(json!({ "message": format!("Hello, {}!", name) }))
46/// });
47/// ```
48pub struct Router<State, E>
49where
50    State: Send + Sync + Clone + 'static,
51    E: RoutableHttpEvent,
52{
53    routes: HashMap<
54        String,
55        (
56            Arc<
57                dyn Fn(
58                        RouteContext<State, E>,
59                    )
60                        -> Pin<Box<dyn Future<Output = Result<JsonValue, Error>> + Send>>
61                    + Send
62                    + Sync,
63            >,
64            Regex,
65        ),
66    >,
67}
68
69impl<State, E: RoutableHttpEvent> Router<State, E>
70where
71    State: Send + Sync + Clone + 'static,
72{
73    pub fn new() -> Self {
74        Self {
75            routes: HashMap::new(),
76        }
77    }
78
79    pub fn register_route<F, Fut>(&mut self, method: &str, path: &str, handler: F)
80    where
81        F: Fn(RouteContext<State, E>) -> Fut + Send + Sync + 'static,
82        Fut: Future<Output = Result<JsonValue, Error>> + Send + 'static,
83    {
84        let regex_pattern = path
85            .split('/')
86            .map(|segment| {
87                if segment.starts_with('{') && segment.ends_with('}') {
88                    let param_name = segment[1..segment.len() - 1].trim_end_matches('+');
89                    if segment.ends_with("+}") {
90                        // Greedy match for proxy+ style parameters
91                        format!("(?P<{param_name}>.*)")
92                    } else {
93                        // Normal parameter match (non-greedy, no slashes)
94                        format!("(?P<{param_name}>[^/]+)")
95                    }
96                } else {
97                    regex::escape(segment) // Escape regular segments
98                }
99            })
100            .collect::<Vec<_>>()
101            .join("/");
102
103        let regex = Regex::new(&format!("^{regex_pattern}$")).expect("Invalid route pattern");
104
105        let handler = Arc::new(move |ctx| {
106            Box::pin(handler(ctx)) as Pin<Box<dyn Future<Output = Result<JsonValue, Error>> + Send>>
107        });
108
109        let key = format!("{} {}", method.to_uppercase(), path);
110        self.routes.insert(key, (handler, regex));
111    }
112
113    // Helper method to handle the response and set span attributes
114    fn handle_response(span: &tracing::Span, response: JsonValue) -> Result<JsonValue, Error> {
115        // Set response attributes
116        if let Some(status) = response.get("statusCode").and_then(|s| s.as_i64()) {
117            span.set_attribute("http.response.status_code", status);
118
119            // For server spans:
120            // - Leave status unset for 1xx-4xx
121            // - Set error only for 5xx
122            let otel_status = if (500..600).contains(&(status as u16)) {
123                Status::error(format!("Server error {status}"))
124            } else {
125                Status::Unset
126            };
127            span.set_status(otel_status);
128        }
129
130        Ok(response)
131    }
132
133    // Helper method to create context and execute handler
134    async fn execute_handler(
135        &self,
136        handler_fn: &Arc<
137            dyn Fn(
138                    RouteContext<State, E>,
139                )
140                    -> Pin<Box<dyn Future<Output = Result<JsonValue, Error>> + Send>>
141                + Send
142                + Sync,
143        >,
144        params: HashMap<String, String>,
145        route_pattern: String,
146        base_ctx: RouteContext<State, E>,
147        parent_cx: OtelContext,
148        payload: &E,
149    ) -> Result<JsonValue, Error> {
150        let span = tracing::Span::current();
151        span.set_parent(parent_cx);
152        payload.set_otel_http_attributes(&span, &route_pattern, &base_ctx.lambda_context);
153
154        let ctx = RouteContext {
155            params,
156            route_pattern,
157            ..base_ctx
158        };
159
160        let response = handler_fn(ctx).await?;
161        Self::handle_response(&span, response)
162    }
163
164    pub async fn handle_request(
165        &self,
166        event: LambdaEvent<E>,
167        state: Arc<State>,
168    ) -> Result<JsonValue, Error> {
169        let (payload, lambda_context) = event.into_parts();
170
171        // Extract parent context from headers
172        let parent_cx = if let Some(headers) = payload.http_headers() {
173            global::get_text_map_propagator(|propagator| {
174                propagator.extract(&HeaderExtractor(headers))
175            })
176        } else {
177            OtelContext::current()
178        };
179
180        let raw_path = payload.path();
181        let path = raw_path.as_deref().unwrap_or("/").to_string();
182        let method = payload.http_method().to_uppercase();
183
184        let base_ctx = RouteContext {
185            path: path.clone(),
186            method: method.clone(),
187            params: HashMap::new(),
188            state,
189            event: payload.clone(),
190            lambda_context,
191            route_pattern: String::new(),
192        };
193
194        // Check if we have a matching route
195        for (route_key, (handler_fn, regex)) in &self.routes {
196            // Extract method and path from route_key
197            let parts: Vec<&str> = route_key.split_whitespace().collect();
198            let (route_method, route_path) = match parts.as_slice() {
199                [method, path] => (method.to_uppercase(), path),
200                _ => continue, // Invalid route key format
201            };
202
203            // Check if methods match
204            if method != route_method {
205                continue;
206            }
207
208            // If we have both resource and path_parameters, validate against API Gateway configuration
209            if let (Some(resource), Some(path_params)) =
210                (payload.route(), payload.path_parameters())
211            {
212                if resource == *route_path {
213                    return self
214                        .execute_handler(
215                            handler_fn,
216                            path_params.clone(),
217                            route_path.to_string(),
218                            base_ctx,
219                            parent_cx.clone(),
220                            &payload,
221                        )
222                        .await;
223                }
224            }
225
226            // Fall back to our own path parameter extraction
227            if let Some(captures) = regex.captures(&path) {
228                let mut params = HashMap::new();
229                for name in regex.capture_names().flatten() {
230                    if let Some(value) = captures.name(name) {
231                        params.insert(name.to_string(), value.as_str().to_string());
232                    }
233                }
234
235                return self
236                    .execute_handler(
237                        handler_fn,
238                        params,
239                        route_path.to_string(),
240                        base_ctx,
241                        parent_cx.clone(),
242                        &payload,
243                    )
244                    .await;
245            }
246        }
247
248        Ok(json!({
249            "statusCode": 404,
250            "headers": {"Content-Type": "text/plain"},
251            "body": "Not Found"
252        }))
253    }
254}
255
256impl<State, E: RoutableHttpEvent> Default for Router<State, E>
257where
258    State: Send + Sync + Clone + 'static,
259{
260    fn default() -> Self {
261        Self::new()
262    }
263}
264
265/// A builder for constructing routers with a fluent API.
266///
267/// This type provides a more ergonomic way to create and configure routers
268/// compared to using `Router` directly. It supports method chaining and
269/// handles all the type complexity internally.
270///
271/// # Examples
272///
273/// ```rust
274/// use lambda_lw_http_router_core::{RouterBuilder, RouteContext};
275/// use serde_json::{json, Value};
276/// use lambda_runtime::Error;
277/// use aws_lambda_events::apigw::ApiGatewayV2httpRequest;
278///
279/// #[derive(Clone)]
280/// struct State {}
281///
282/// async fn get_users(ctx: RouteContext<State, ApiGatewayV2httpRequest>) -> Result<Value, Error> {
283///     Ok(json!({ "users": [] }))
284/// }
285///
286/// async fn create_user(ctx: RouteContext<State, ApiGatewayV2httpRequest>) -> Result<Value, Error> {
287///     Ok(json!({ "status": "created" }))
288/// }
289///
290/// let router = RouterBuilder::<State, ApiGatewayV2httpRequest>::new()
291///     .route("GET", "/users", get_users)
292///     .route("POST", "/users", create_user)
293///     .build();
294/// ```
295pub struct RouterBuilder<State, E: RoutableHttpEvent>
296where
297    State: Send + Sync + Clone + 'static,
298{
299    routes: Vec<(
300        String,
301        String,
302        Box<
303            dyn Fn(
304                    RouteContext<State, E>,
305                )
306                    -> Pin<Box<dyn Future<Output = Result<JsonValue, Error>> + Send>>
307                + Send
308                + Sync,
309        >,
310    )>,
311}
312
313impl<State, E: RoutableHttpEvent> RouterBuilder<State, E>
314where
315    State: Send + Sync + Clone + 'static,
316{
317    pub fn new() -> Self {
318        Self { routes: Vec::new() }
319    }
320
321    pub fn route<F, Fut>(mut self, method: &str, path: &str, handler: F) -> Self
322    where
323        F: Fn(RouteContext<State, E>) -> Fut + Send + Sync + 'static,
324        Fut: Future<Output = Result<JsonValue, Error>> + Send + 'static,
325    {
326        let handler = Box::new(move |ctx| {
327            Box::pin(handler(ctx)) as Pin<Box<dyn Future<Output = Result<JsonValue, Error>> + Send>>
328        });
329        self.routes
330            .push((method.to_string(), path.to_string(), handler));
331        self
332    }
333
334    pub fn build(self) -> Router<State, E> {
335        let mut router = Router::new();
336        for (method, path, handler) in self.routes {
337            let handler = move |ctx: RouteContext<State, E>| (handler)(ctx);
338            router.register_route(&method, &path, handler);
339        }
340        router
341    }
342
343    pub fn from_registry() -> Self {
344        let mut builder = Self::new();
345
346        let routes = {
347            let registry = ROUTE_REGISTRY.lock().unwrap();
348            registry
349                .get(&(TypeId::of::<State>(), TypeId::of::<E>()))
350                .and_then(|routes| routes.downcast_ref::<Vec<RouteEntry<State, E>>>())
351                .map(|routes| {
352                    routes
353                        .iter()
354                        .map(|entry| (entry.method, entry.path, Arc::clone(&entry.handler)))
355                        .collect::<Vec<_>>()
356                })
357                .unwrap_or_default()
358        };
359
360        for (method, path, handler) in routes {
361            let handler = Box::new(move |ctx: RouteContext<State, E>| (handler)(ctx));
362            builder
363                .routes
364                .push((method.to_string(), path.to_string(), handler));
365        }
366
367        builder
368    }
369}
370
371impl<State, E: RoutableHttpEvent> Default for RouterBuilder<State, E>
372where
373    State: Send + Sync + Clone + 'static,
374{
375    fn default() -> Self {
376        Self::new()
377    }
378}
379
380type BoxedHandler<State, E> = dyn Fn(RouteContext<State, E>) -> Pin<Box<dyn Future<Output = Result<JsonValue, Error>> + Send>>
381    + Send
382    + Sync;
383
384#[derive(Clone)]
385struct RouteEntry<State: Clone, E: RoutableHttpEvent> {
386    method: &'static str,
387    path: &'static str,
388    handler: Arc<BoxedHandler<State, E>>,
389}
390
391pub fn register_route<State, E: RoutableHttpEvent>(
392    method: &'static str,
393    path: &'static str,
394    handler: impl Fn(RouteContext<State, E>) -> Pin<Box<dyn Future<Output = Result<JsonValue, Error>> + Send>>
395        + Send
396        + Sync
397        + 'static,
398) where
399    State: Send + Sync + Clone + 'static,
400{
401    let state_type_id = TypeId::of::<State>();
402    let event_type_id = TypeId::of::<E>();
403    let handler = Arc::new(handler) as Arc<BoxedHandler<State, E>>;
404    let entry = RouteEntry {
405        method,
406        path,
407        handler,
408    };
409
410    let mut registry = ROUTE_REGISTRY.lock().unwrap();
411    let routes = registry
412        .entry((state_type_id, event_type_id))
413        .or_insert_with(|| Box::new(Vec::<RouteEntry<State, E>>::new()));
414
415    let routes = routes
416        .downcast_mut::<Vec<RouteEntry<State, E>>>()
417        .expect("Registry type mismatch");
418    routes.push(entry);
419}