Skip to main content

spikard_http/
lifecycle.rs

1use axum::{
2    body::Body,
3    http::{Request, Response},
4};
5use std::sync::Arc;
6
7pub mod adapter;
8
9pub use spikard_core::lifecycle::{HookResult, LifecycleHook};
10
/// Lifecycle hook collection from `spikard_core`, specialised to axum's
/// `Request<Body>` and `Response<Body>` types.
pub type LifecycleHooks = spikard_core::lifecycle::LifecycleHooks<Request<Body>, Response<Body>>;
/// Builder counterpart of [`LifecycleHooks`], specialised to the same axum
/// request and response types.
pub type LifecycleHooksBuilder = spikard_core::lifecycle::LifecycleHooksBuilder<Request<Body>, Response<Body>>;
13
14/// Create a request hook for the current target.
15#[cfg(not(target_arch = "wasm32"))]
16pub fn request_hook<F, Fut>(name: impl Into<String>, func: F) -> Arc<dyn LifecycleHook<Request<Body>, Response<Body>>>
17where
18    F: Fn(Request<Body>) -> Fut + Send + Sync + 'static,
19    Fut: std::future::Future<Output = Result<HookResult<Request<Body>, Response<Body>>, String>> + Send + 'static,
20{
21    spikard_core::lifecycle::request_hook::<Request<Body>, Response<Body>, _, _>(name, func)
22}
23
/// Wrap an async closure as a request-phase lifecycle hook (`wasm32` targets).
///
/// Same delegation to [`spikard_core::lifecycle::request_hook`] as the native
/// variant, except that the future returned by `func` is not required to be
/// `Send`.
///
/// * `name` - label for the hook, forwarded unchanged to the core constructor.
/// * `func` - callback receiving the request; its future resolves to a
///   [`HookResult`] or a `String` error.
#[cfg(target_arch = "wasm32")]
pub fn request_hook<F, Fut>(name: impl Into<String>, func: F) -> Arc<dyn LifecycleHook<Request<Body>, Response<Body>>>
where
    F: Fn(Request<Body>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<HookResult<Request<Body>, Response<Body>>, String>> + 'static,
{
    use spikard_core::lifecycle::request_hook as core_request_hook;

    core_request_hook::<Request<Body>, Response<Body>, _, _>(name, func)
}
33
34/// Create a response hook for the current target.
35#[cfg(not(target_arch = "wasm32"))]
36pub fn response_hook<F, Fut>(name: impl Into<String>, func: F) -> Arc<dyn LifecycleHook<Request<Body>, Response<Body>>>
37where
38    F: Fn(Response<Body>) -> Fut + Send + Sync + 'static,
39    Fut: std::future::Future<Output = Result<HookResult<Response<Body>, Response<Body>>, String>> + Send + 'static,
40{
41    spikard_core::lifecycle::response_hook::<Request<Body>, Response<Body>, _, _>(name, func)
42}
43
/// Wrap an async closure as a response-phase lifecycle hook (`wasm32` targets).
///
/// Same delegation to [`spikard_core::lifecycle::response_hook`] as the native
/// variant, except that the future returned by `func` is not required to be
/// `Send`.
///
/// * `name` - label for the hook, forwarded unchanged to the core constructor.
/// * `func` - callback receiving the response; its future resolves to a
///   [`HookResult`] carrying a response, or a `String` error.
#[cfg(target_arch = "wasm32")]
pub fn response_hook<F, Fut>(name: impl Into<String>, func: F) -> Arc<dyn LifecycleHook<Request<Body>, Response<Body>>>
where
    F: Fn(Response<Body>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<HookResult<Response<Body>, Response<Body>>, String>> + 'static,
{
    use spikard_core::lifecycle::response_hook as core_response_hook;

    core_response_hook::<Request<Body>, Response<Body>, _, _>(name, func)
}
53
54#[cfg(test)]
55mod tests {
56    use super::*;
57    use axum::body::Body;
58    use axum::http::{Request, Response, StatusCode};
59    use std::future::Future;
60    use std::pin::Pin;
61
62    /// Test hook that always continues
63    struct ContinueHook {
64        name: String,
65    }
66
67    impl LifecycleHook<Request<Body>, Response<Body>> for ContinueHook {
68        fn name(&self) -> &str {
69            &self.name
70        }
71
72        fn execute_request<'a>(
73            &self,
74            req: Request<Body>,
75        ) -> Pin<Box<dyn Future<Output = Result<HookResult<Request<Body>, Response<Body>>, String>> + Send + 'a>>
76        {
77            Box::pin(async move { Ok(HookResult::Continue(req)) })
78        }
79
80        fn execute_response<'a>(
81            &self,
82            resp: Response<Body>,
83        ) -> Pin<Box<dyn Future<Output = Result<HookResult<Response<Body>, Response<Body>>, String>> + Send + 'a>>
84        {
85            Box::pin(async move { Ok(HookResult::Continue(resp)) })
86        }
87    }
88
89    /// Test hook that short-circuits with a 401 response
90    struct ShortCircuitHook {
91        name: String,
92    }
93
94    impl LifecycleHook<Request<Body>, Response<Body>> for ShortCircuitHook {
95        fn name(&self) -> &str {
96            &self.name
97        }
98
99        fn execute_request<'a>(
100            &self,
101            _req: Request<Body>,
102        ) -> Pin<Box<dyn Future<Output = Result<HookResult<Request<Body>, Response<Body>>, String>> + Send + 'a>>
103        {
104            Box::pin(async move {
105                let response = Response::builder()
106                    .status(StatusCode::UNAUTHORIZED)
107                    .body(Body::from("Unauthorized"))
108                    .unwrap();
109                Ok(HookResult::ShortCircuit(response))
110            })
111        }
112
113        fn execute_response<'a>(
114            &self,
115            _resp: Response<Body>,
116        ) -> Pin<Box<dyn Future<Output = Result<HookResult<Response<Body>, Response<Body>>, String>> + Send + 'a>>
117        {
118            Box::pin(async move {
119                let response = Response::builder()
120                    .status(StatusCode::UNAUTHORIZED)
121                    .body(Body::from("Unauthorized"))
122                    .unwrap();
123                Ok(HookResult::ShortCircuit(response))
124            })
125        }
126    }
127
128    #[tokio::test]
129    async fn test_empty_hooks_fast_path() {
130        let hooks = LifecycleHooks::new();
131        assert!(hooks.is_empty());
132
133        let req = Request::builder().body(Body::empty()).unwrap();
134        let result = hooks.execute_on_request(req).await.unwrap();
135        assert!(matches!(result, HookResult::Continue(_)));
136    }
137
138    #[tokio::test]
139    async fn test_on_request_continue() {
140        let mut hooks = LifecycleHooks::new();
141        hooks.add_on_request(Arc::new(ContinueHook {
142            name: "test".to_string(),
143        }));
144
145        let req = Request::builder().body(Body::empty()).unwrap();
146        let result = hooks.execute_on_request(req).await.unwrap();
147        assert!(matches!(result, HookResult::Continue(_)));
148    }
149
150    #[tokio::test]
151    async fn test_on_request_short_circuit() {
152        let mut hooks = LifecycleHooks::new();
153        hooks.add_on_request(Arc::new(ShortCircuitHook {
154            name: "auth_check".to_string(),
155        }));
156
157        let req = Request::builder().body(Body::empty()).unwrap();
158        let result = hooks.execute_on_request(req).await.unwrap();
159
160        match result {
161            HookResult::ShortCircuit(resp) => {
162                assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
163            }
164            HookResult::Continue(_) => panic!("Expected ShortCircuit, got Continue"),
165        }
166    }
167
168    #[tokio::test]
169    async fn test_multiple_hooks_in_order() {
170        let mut hooks = LifecycleHooks::new();
171
172        hooks.add_on_request(Arc::new(ContinueHook {
173            name: "first".to_string(),
174        }));
175        hooks.add_on_request(Arc::new(ContinueHook {
176            name: "second".to_string(),
177        }));
178
179        let req = Request::builder().body(Body::empty()).unwrap();
180        let result = hooks.execute_on_request(req).await.unwrap();
181        assert!(matches!(result, HookResult::Continue(_)));
182    }
183
184    #[tokio::test]
185    async fn test_short_circuit_stops_execution() {
186        let mut hooks = LifecycleHooks::new();
187
188        hooks.add_on_request(Arc::new(ShortCircuitHook {
189            name: "short_circuit".to_string(),
190        }));
191        hooks.add_on_request(Arc::new(ContinueHook {
192            name: "never_executed".to_string(),
193        }));
194
195        let req = Request::builder().body(Body::empty()).unwrap();
196        let result = hooks.execute_on_request(req).await.unwrap();
197
198        match result {
199            HookResult::ShortCircuit(_) => {}
200            HookResult::Continue(_) => panic!("Expected ShortCircuit, got Continue"),
201        }
202    }
203
204    #[tokio::test]
205    async fn test_on_response_hooks() {
206        let mut hooks = LifecycleHooks::new();
207        hooks.add_on_response(Arc::new(ContinueHook {
208            name: "response_hook".to_string(),
209        }));
210
211        let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
212
213        let result = hooks.execute_on_response(resp).await.unwrap();
214        assert_eq!(result.status(), StatusCode::OK);
215    }
216
217    #[tokio::test]
218    async fn test_request_hook_builder() {
219        let hook = request_hook("test", |req| async move { Ok(HookResult::Continue(req)) });
220
221        let req = Request::builder().body(Body::empty()).unwrap();
222        let result = hook.execute_request(req).await.unwrap();
223
224        assert!(matches!(result, HookResult::Continue(_)));
225    }
226
227    #[tokio::test]
228    async fn test_request_hook_with_modification() {
229        let hook = request_hook("add_header", |mut req| async move {
230            req.headers_mut()
231                .insert("X-Custom-Header", axum::http::HeaderValue::from_static("test-value"));
232            Ok(HookResult::Continue(req))
233        });
234
235        let req = Request::builder().body(Body::empty()).unwrap();
236        let result = hook.execute_request(req).await.unwrap();
237
238        match result {
239            HookResult::Continue(req) => {
240                assert_eq!(req.headers().get("X-Custom-Header").unwrap(), "test-value");
241            }
242            HookResult::ShortCircuit(_) => panic!("Expected Continue"),
243        }
244    }
245
246    #[tokio::test]
247    async fn test_request_hook_short_circuit() {
248        let hook = request_hook("auth", |_req| async move {
249            let response = Response::builder()
250                .status(StatusCode::UNAUTHORIZED)
251                .body(Body::from("Unauthorized"))
252                .unwrap();
253            Ok(HookResult::ShortCircuit(response))
254        });
255
256        let req = Request::builder().body(Body::empty()).unwrap();
257        let result = hook.execute_request(req).await.unwrap();
258
259        match result {
260            HookResult::ShortCircuit(resp) => {
261                assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
262            }
263            HookResult::Continue(_) => panic!("Expected ShortCircuit"),
264        }
265    }
266
267    #[tokio::test]
268    async fn test_response_hook_builder() {
269        let hook = response_hook("security", |mut resp| async move {
270            resp.headers_mut()
271                .insert("X-Frame-Options", axum::http::HeaderValue::from_static("DENY"));
272            Ok(HookResult::Continue(resp))
273        });
274
275        let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
276
277        let result = hook.execute_response(resp).await.unwrap();
278
279        match result {
280            HookResult::Continue(resp) => {
281                assert_eq!(resp.headers().get("X-Frame-Options").unwrap(), "DENY");
282                assert_eq!(resp.status(), StatusCode::OK);
283            }
284            HookResult::ShortCircuit(_) => panic!("Expected Continue"),
285        }
286    }
287
288    #[tokio::test]
289    async fn test_builder_pattern() {
290        let hooks = LifecycleHooks::builder()
291            .on_request(request_hook(
292                "logger",
293                |req| async move { Ok(HookResult::Continue(req)) },
294            ))
295            .pre_handler(request_hook("auth", |req| async move { Ok(HookResult::Continue(req)) }))
296            .on_response(response_hook("security", |resp| async move {
297                Ok(HookResult::Continue(resp))
298            }))
299            .build();
300
301        assert!(!hooks.is_empty());
302
303        let req = Request::builder().body(Body::empty()).unwrap();
304        let result = hooks.execute_on_request(req).await.unwrap();
305        assert!(matches!(result, HookResult::Continue(_)));
306    }
307
308    #[tokio::test]
309    async fn test_builder_with_multiple_hooks() {
310        let hooks = LifecycleHooks::builder()
311            .on_request(request_hook("first", |mut req| async move {
312                req.headers_mut()
313                    .insert("X-First", axum::http::HeaderValue::from_static("1"));
314                Ok(HookResult::Continue(req))
315            }))
316            .on_request(request_hook("second", |mut req| async move {
317                req.headers_mut()
318                    .insert("X-Second", axum::http::HeaderValue::from_static("2"));
319                Ok(HookResult::Continue(req))
320            }))
321            .build();
322
323        let req = Request::builder().body(Body::empty()).unwrap();
324        let result = hooks.execute_on_request(req).await.unwrap();
325
326        match result {
327            HookResult::Continue(req) => {
328                assert_eq!(req.headers().get("X-First").unwrap(), "1");
329                assert_eq!(req.headers().get("X-Second").unwrap(), "2");
330            }
331            HookResult::ShortCircuit(_) => panic!("Expected Continue"),
332        }
333    }
334
335    #[tokio::test]
336    async fn test_builder_short_circuit_stops_chain() {
337        let hooks = LifecycleHooks::builder()
338            .on_request(request_hook(
339                "first",
340                |req| async move { Ok(HookResult::Continue(req)) },
341            ))
342            .on_request(request_hook("short_circuit", |_req| async move {
343                let response = Response::builder()
344                    .status(StatusCode::FORBIDDEN)
345                    .body(Body::from("Blocked"))
346                    .unwrap();
347                Ok(HookResult::ShortCircuit(response))
348            }))
349            .on_request(request_hook("never_called", |mut req| async move {
350                req.headers_mut()
351                    .insert("X-Should-Not-Exist", axum::http::HeaderValue::from_static("value"));
352                Ok(HookResult::Continue(req))
353            }))
354            .build();
355
356        let req = Request::builder().body(Body::empty()).unwrap();
357        let result = hooks.execute_on_request(req).await.unwrap();
358
359        match result {
360            HookResult::ShortCircuit(resp) => {
361                assert_eq!(resp.status(), StatusCode::FORBIDDEN);
362            }
363            HookResult::Continue(_) => panic!("Expected ShortCircuit"),
364        }
365    }
366
367    #[tokio::test]
368    async fn test_all_hook_types() {
369        let hooks = LifecycleHooks::builder()
370            .on_request(request_hook("on_request", |req| async move {
371                Ok(HookResult::Continue(req))
372            }))
373            .pre_validation(request_hook("pre_validation", |req| async move {
374                Ok(HookResult::Continue(req))
375            }))
376            .pre_handler(request_hook("pre_handler", |req| async move {
377                Ok(HookResult::Continue(req))
378            }))
379            .on_response(response_hook("on_response", |resp| async move {
380                Ok(HookResult::Continue(resp))
381            }))
382            .on_error(response_hook("on_error", |resp| async move {
383                Ok(HookResult::Continue(resp))
384            }))
385            .build();
386
387        assert!(!hooks.is_empty());
388
389        let req = Request::builder().body(Body::empty()).unwrap();
390        assert!(matches!(
391            hooks.execute_on_request(req).await.unwrap(),
392            HookResult::Continue(_)
393        ));
394
395        let req = Request::builder().body(Body::empty()).unwrap();
396        assert!(matches!(
397            hooks.execute_pre_validation(req).await.unwrap(),
398            HookResult::Continue(_)
399        ));
400
401        let req = Request::builder().body(Body::empty()).unwrap();
402        assert!(matches!(
403            hooks.execute_pre_handler(req).await.unwrap(),
404            HookResult::Continue(_)
405        ));
406
407        let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
408        let result = hooks.execute_on_response(resp).await.unwrap();
409        assert_eq!(result.status(), StatusCode::OK);
410
411        let resp = Response::builder()
412            .status(StatusCode::INTERNAL_SERVER_ERROR)
413            .body(Body::empty())
414            .unwrap();
415        let result = hooks.execute_on_error(resp).await.unwrap();
416        assert_eq!(result.status(), StatusCode::INTERNAL_SERVER_ERROR);
417    }
418
419    #[tokio::test]
420    async fn test_empty_builder() {
421        let hooks = LifecycleHooks::builder().build();
422        assert!(hooks.is_empty());
423
424        let req = Request::builder().body(Body::empty()).unwrap();
425        let result = hooks.execute_on_request(req).await.unwrap();
426        assert!(matches!(result, HookResult::Continue(_)));
427    }
428
429    #[tokio::test]
430    async fn test_hook_chaining_modifies_request_sequentially() {
431        let hooks = LifecycleHooks::builder()
432            .on_request(request_hook("add_header_1", |mut req| async move {
433                req.headers_mut()
434                    .insert("X-Chain-1", axum::http::HeaderValue::from_static("first"));
435                Ok(HookResult::Continue(req))
436            }))
437            .on_request(request_hook("add_header_2", |mut req| async move {
438                req.headers_mut()
439                    .insert("X-Chain-2", axum::http::HeaderValue::from_static("second"));
440                Ok(HookResult::Continue(req))
441            }))
442            .on_request(request_hook("add_header_3", |mut req| async move {
443                req.headers_mut()
444                    .insert("X-Chain-3", axum::http::HeaderValue::from_static("third"));
445                Ok(HookResult::Continue(req))
446            }))
447            .build();
448
449        let req = Request::builder().body(Body::empty()).unwrap();
450        let result = hooks.execute_on_request(req).await.unwrap();
451
452        match result {
453            HookResult::Continue(req) => {
454                assert_eq!(req.headers().get("X-Chain-1").unwrap(), "first");
455                assert_eq!(req.headers().get("X-Chain-2").unwrap(), "second");
456                assert_eq!(req.headers().get("X-Chain-3").unwrap(), "third");
457            }
458            HookResult::ShortCircuit(_) => panic!("Expected Continue"),
459        }
460    }
461
462    #[tokio::test]
463    async fn test_response_hook_chaining_modifies_status_and_headers() {
464        let hooks = LifecycleHooks::builder()
465            .on_response(response_hook("add_security_header", |mut resp| async move {
466                resp.headers_mut().insert(
467                    "X-Content-Type-Options",
468                    axum::http::HeaderValue::from_static("nosniff"),
469                );
470                Ok(HookResult::Continue(resp))
471            }))
472            .on_response(response_hook("add_cache_header", |mut resp| async move {
473                resp.headers_mut()
474                    .insert("Cache-Control", axum::http::HeaderValue::from_static("no-cache"));
475                Ok(HookResult::Continue(resp))
476            }))
477            .on_response(response_hook("add_custom_header", |mut resp| async move {
478                resp.headers_mut()
479                    .insert("X-Custom", axum::http::HeaderValue::from_static("value"));
480                Ok(HookResult::Continue(resp))
481            }))
482            .build();
483
484        let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
485
486        let result = hooks.execute_on_response(resp).await.unwrap();
487
488        assert_eq!(result.status(), StatusCode::OK);
489        assert_eq!(result.headers().get("X-Content-Type-Options").unwrap(), "nosniff");
490        assert_eq!(result.headers().get("Cache-Control").unwrap(), "no-cache");
491        assert_eq!(result.headers().get("X-Custom").unwrap(), "value");
492    }
493
494    #[tokio::test]
495    async fn test_pre_validation_and_pre_handler_chaining() {
496        let hooks = LifecycleHooks::builder()
497            .pre_validation(request_hook("validate_auth", |mut req| async move {
498                req.headers_mut()
499                    .insert("X-Validated", axum::http::HeaderValue::from_static("true"));
500                Ok(HookResult::Continue(req))
501            }))
502            .pre_handler(request_hook("prepare_handler", |mut req| async move {
503                req.headers_mut()
504                    .insert("X-Prepared", axum::http::HeaderValue::from_static("true"));
505                Ok(HookResult::Continue(req))
506            }))
507            .build();
508
509        let req = Request::builder().body(Body::empty()).unwrap();
510        let result = hooks.execute_pre_validation(req).await.unwrap();
511
512        match result {
513            HookResult::Continue(req) => {
514                assert_eq!(req.headers().get("X-Validated").unwrap(), "true");
515                assert!(!req.headers().contains_key("X-Prepared"));
516            }
517            HookResult::ShortCircuit(_) => panic!("Expected Continue"),
518        }
519
520        let req = Request::builder()
521            .header("X-Validated", "true")
522            .body(Body::empty())
523            .unwrap();
524        let result = hooks.execute_pre_handler(req).await.unwrap();
525
526        match result {
527            HookResult::Continue(req) => {
528                assert_eq!(req.headers().get("X-Prepared").unwrap(), "true");
529            }
530            HookResult::ShortCircuit(_) => panic!("Expected Continue"),
531        }
532    }
533
534    #[tokio::test]
535    async fn test_hook_chain_with_state_passing() {
536        let hooks = LifecycleHooks::builder()
537            .on_request(request_hook("add_user_id", |mut req| async move {
538                req.headers_mut()
539                    .insert("X-User-ID", axum::http::HeaderValue::from_static("123"));
540                Ok(HookResult::Continue(req))
541            }))
542            .on_request(request_hook("add_session_id", |mut req| async move {
543                if let Some(user_id) = req.headers().get("X-User-ID") {
544                    if user_id == "123" {
545                        req.headers_mut()
546                            .insert("X-Session-ID", axum::http::HeaderValue::from_static("session_abc"));
547                    }
548                }
549                Ok(HookResult::Continue(req))
550            }))
551            .build();
552
553        let req = Request::builder().body(Body::empty()).unwrap();
554        let result = hooks.execute_on_request(req).await.unwrap();
555
556        match result {
557            HookResult::Continue(req) => {
558                assert_eq!(req.headers().get("X-User-ID").unwrap(), "123");
559                assert_eq!(req.headers().get("X-Session-ID").unwrap(), "session_abc");
560            }
561            HookResult::ShortCircuit(_) => panic!("Expected Continue"),
562        }
563    }
564
565    #[tokio::test]
566    async fn test_pre_validation_short_circuit_stops_subsequent_hooks() {
567        let hooks = LifecycleHooks::builder()
568            .on_request(request_hook("on_request", |req| async move {
569                println!("on_request executed");
570                Ok(HookResult::Continue(req))
571            }))
572            .pre_validation(request_hook("pre_validation_abort", |_req| async move {
573                println!("pre_validation executed - short circuiting");
574                let response = Response::builder()
575                    .status(StatusCode::BAD_REQUEST)
576                    .body(Body::from("Validation failed"))
577                    .unwrap();
578                Ok(HookResult::ShortCircuit(response))
579            }))
580            .pre_handler(request_hook("pre_handler", |req| async move {
581                println!("pre_handler executed - should NOT happen");
582                Ok(HookResult::Continue(req))
583            }))
584            .build();
585
586        let req = Request::builder().body(Body::empty()).unwrap();
587        let result = hooks.execute_pre_validation(req).await.unwrap();
588
589        match result {
590            HookResult::ShortCircuit(resp) => {
591                assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
592            }
593            HookResult::Continue(_) => panic!("Expected ShortCircuit"),
594        }
595    }
596
597    #[tokio::test]
598    async fn test_pre_handler_short_circuit_returns_early_response() {
599        let hooks = LifecycleHooks::builder()
600            .pre_validation(request_hook("pre_validation", |req| async move {
601                Ok(HookResult::Continue(req))
602            }))
603            .pre_handler(request_hook("rate_limit_check", |_req| async move {
604                let response = Response::builder()
605                    .status(StatusCode::TOO_MANY_REQUESTS)
606                    .body(Body::from("Rate limit exceeded"))
607                    .unwrap();
608                Ok(HookResult::ShortCircuit(response))
609            }))
610            .build();
611
612        let req = Request::builder().body(Body::empty()).unwrap();
613        let result = hooks.execute_pre_handler(req).await.unwrap();
614
615        match result {
616            HookResult::ShortCircuit(resp) => {
617                assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
618            }
619            HookResult::Continue(_) => panic!("Expected ShortCircuit"),
620        }
621    }
622
623    #[tokio::test]
624    async fn test_short_circuit_in_middle_of_chain() {
625        let hooks = LifecycleHooks::builder()
626            .on_request(request_hook("hook_1", |mut req| async move {
627                req.headers_mut()
628                    .insert("X-Executed-1", axum::http::HeaderValue::from_static("yes"));
629                Ok(HookResult::Continue(req))
630            }))
631            .on_request(request_hook("hook_2_abort", |_req| async move {
632                let response = Response::builder()
633                    .status(StatusCode::FORBIDDEN)
634                    .body(Body::from("Access denied"))
635                    .unwrap();
636                Ok(HookResult::ShortCircuit(response))
637            }))
638            .on_request(request_hook("hook_3", |mut req| async move {
639                req.headers_mut()
640                    .insert("X-Executed-3", axum::http::HeaderValue::from_static("yes"));
641                Ok(HookResult::Continue(req))
642            }))
643            .build();
644
645        let req = Request::builder().body(Body::empty()).unwrap();
646        let result = hooks.execute_on_request(req).await.unwrap();
647
648        match result {
649            HookResult::ShortCircuit(resp) => {
650                assert_eq!(resp.status(), StatusCode::FORBIDDEN);
651            }
652            HookResult::Continue(_) => panic!("Expected ShortCircuit"),
653        }
654    }
655
656    #[tokio::test]
657    async fn test_short_circuit_with_custom_response_headers() {
658        let hooks = LifecycleHooks::builder()
659            .pre_validation(request_hook("auth_check", |_req| async move {
660                let response = Response::builder()
661                    .status(StatusCode::UNAUTHORIZED)
662                    .header("WWW-Authenticate", "Bearer realm=\"api\"")
663                    .body(Body::from("Authorization required"))
664                    .unwrap();
665                Ok(HookResult::ShortCircuit(response))
666            }))
667            .build();
668
669        let req = Request::builder().body(Body::empty()).unwrap();
670        let result = hooks.execute_pre_validation(req).await.unwrap();
671
672        match result {
673            HookResult::ShortCircuit(resp) => {
674                assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
675                assert_eq!(resp.headers().get("WWW-Authenticate").unwrap(), "Bearer realm=\"api\"");
676            }
677            HookResult::Continue(_) => panic!("Expected ShortCircuit"),
678        }
679    }
680
681    #[tokio::test]
682    async fn test_hook_error_propagates_through_chain() {
683        let hooks = LifecycleHooks::builder()
684            .on_request(request_hook("good_hook", |mut req| async move {
685                req.headers_mut()
686                    .insert("X-Good", axum::http::HeaderValue::from_static("yes"));
687                Ok(HookResult::Continue(req))
688            }))
689            .on_request(request_hook("bad_hook", |_req| async move {
690                Err("Something went wrong in hook".to_string())
691            }))
692            .build();
693
694        let req = Request::builder().body(Body::empty()).unwrap();
695        let result = hooks.execute_on_request(req).await;
696
697        assert!(result.is_err());
698        assert_eq!(result.unwrap_err(), "Something went wrong in hook");
699    }
700
701    #[tokio::test]
702    async fn test_error_in_pre_validation_stops_chain() {
703        let hooks = LifecycleHooks::builder()
704            .pre_validation(request_hook("validation_hook", |_req| async move {
705                Err("Validation error: invalid input".to_string())
706            }))
707            .pre_handler(request_hook("handler_prep", |req| async move {
708                Ok(HookResult::Continue(req))
709            }))
710            .build();
711
712        let req = Request::builder().body(Body::empty()).unwrap();
713        let result = hooks.execute_pre_validation(req).await;
714
715        assert!(result.is_err());
716        assert!(result.unwrap_err().contains("Validation error"));
717    }
718
719    #[tokio::test]
720    async fn test_on_error_hook_transforms_response() {
721        let hooks = LifecycleHooks::builder()
722            .on_error(response_hook("transform_error", |mut resp| async move {
723                resp.headers_mut()
724                    .insert("X-Error-Handled", axum::http::HeaderValue::from_static("true"));
725
726                let _status = resp.status();
727                Ok(HookResult::Continue(resp))
728            }))
729            .build();
730
731        let resp = Response::builder()
732            .status(StatusCode::INTERNAL_SERVER_ERROR)
733            .body(Body::empty())
734            .unwrap();
735
736        let result = hooks.execute_on_error(resp).await.unwrap();
737
738        assert_eq!(result.headers().get("X-Error-Handled").unwrap(), "true");
739    }
740
741    #[tokio::test]
742    async fn test_response_hook_error_propagates() {
743        let hooks = LifecycleHooks::builder()
744            .on_response(response_hook("good_response_hook", |mut resp| async move {
745                resp.headers_mut()
746                    .insert("X-Processed", axum::http::HeaderValue::from_static("yes"));
747                Ok(HookResult::Continue(resp))
748            }))
749            .on_response(response_hook("bad_response_hook", |_resp| async move {
750                Err("Error processing response".to_string())
751            }))
752            .build();
753
754        let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
755
756        let result = hooks.execute_on_response(resp).await;
757
758        assert!(result.is_err());
759        assert_eq!(result.unwrap_err(), "Error processing response");
760    }
761
762    #[tokio::test]
763    async fn test_error_hook_error_propagates() {
764        let hooks = LifecycleHooks::builder()
765            .on_error(response_hook("error_hook_1", |mut resp| async move {
766                resp.headers_mut()
767                    .insert("X-Error-Processed", axum::http::HeaderValue::from_static("1"));
768                Ok(HookResult::Continue(resp))
769            }))
770            .on_error(response_hook("error_hook_2_fails", |_resp| async move {
771                Err("Error in error hook".to_string())
772            }))
773            .build();
774
775        let resp = Response::builder()
776            .status(StatusCode::INTERNAL_SERVER_ERROR)
777            .body(Body::empty())
778            .unwrap();
779
780        let result = hooks.execute_on_error(resp).await;
781
782        assert!(result.is_err());
783        assert_eq!(result.unwrap_err(), "Error in error hook");
784    }
785
786    #[tokio::test]
787    async fn test_on_request_adds_multiple_headers() {
788        let hooks = LifecycleHooks::builder()
789            .on_request(request_hook("add_request_headers", |mut req| async move {
790                req.headers_mut()
791                    .insert("X-Request-ID", axum::http::HeaderValue::from_static("req_123"));
792                req.headers_mut()
793                    .insert("X-Timestamp", axum::http::HeaderValue::from_static("2025-01-01"));
794                req.headers_mut()
795                    .insert("X-Processed", axum::http::HeaderValue::from_static("true"));
796                Ok(HookResult::Continue(req))
797            }))
798            .build();
799
800        let req = Request::builder().body(Body::empty()).unwrap();
801        let result = hooks.execute_on_request(req).await.unwrap();
802
803        match result {
804            HookResult::Continue(req) => {
805                assert_eq!(req.headers().get("X-Request-ID").unwrap(), "req_123");
806                assert_eq!(req.headers().get("X-Timestamp").unwrap(), "2025-01-01");
807                assert_eq!(req.headers().get("X-Processed").unwrap(), "true");
808            }
809            HookResult::ShortCircuit(_) => panic!("Expected Continue"),
810        }
811    }
812
813    #[tokio::test]
814    async fn test_on_response_adds_security_headers() {
815        let hooks = LifecycleHooks::builder()
816            .on_response(response_hook("add_security_headers", |mut resp| async move {
817                resp.headers_mut()
818                    .insert("X-Frame-Options", axum::http::HeaderValue::from_static("DENY"));
819                resp.headers_mut().insert(
820                    "X-Content-Type-Options",
821                    axum::http::HeaderValue::from_static("nosniff"),
822                );
823                resp.headers_mut().insert(
824                    "Strict-Transport-Security",
825                    axum::http::HeaderValue::from_static("max-age=31536000"),
826                );
827                Ok(HookResult::Continue(resp))
828            }))
829            .build();
830
831        let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
832
833        let result = hooks.execute_on_response(resp).await.unwrap();
834
835        assert_eq!(result.headers().get("X-Frame-Options").unwrap(), "DENY");
836        assert_eq!(result.headers().get("X-Content-Type-Options").unwrap(), "nosniff");
837        assert_eq!(
838            result.headers().get("Strict-Transport-Security").unwrap(),
839            "max-age=31536000"
840        );
841    }
842
843    #[tokio::test]
844    async fn test_pre_handler_modifies_request_before_execution() {
845        let hooks = LifecycleHooks::builder()
846            .pre_handler(request_hook("inject_context", |mut req| async move {
847                req.headers_mut().insert(
848                    "X-Handler-Context",
849                    axum::http::HeaderValue::from_static("context_data"),
850                );
851                req.headers_mut()
852                    .insert("X-Injected", axum::http::HeaderValue::from_static("true"));
853                Ok(HookResult::Continue(req))
854            }))
855            .build();
856
857        let req = Request::builder().body(Body::empty()).unwrap();
858        let result = hooks.execute_pre_handler(req).await.unwrap();
859
860        match result {
861            HookResult::Continue(req) => {
862                assert_eq!(req.headers().get("X-Handler-Context").unwrap(), "context_data");
863                assert_eq!(req.headers().get("X-Injected").unwrap(), "true");
864            }
865            HookResult::ShortCircuit(_) => panic!("Expected Continue"),
866        }
867    }
868
869    #[tokio::test]
870    async fn test_register_multiple_hooks_different_types() {
871        let mut hooks = LifecycleHooks::new();
872
873        hooks.add_on_request(request_hook("on_request_1", |req| async move {
874            Ok(HookResult::Continue(req))
875        }));
876
877        hooks.add_pre_validation(request_hook("pre_validation_1", |req| async move {
878            Ok(HookResult::Continue(req))
879        }));
880
881        hooks.add_pre_handler(request_hook("pre_handler_1", |req| async move {
882            Ok(HookResult::Continue(req))
883        }));
884
885        hooks.add_on_response(response_hook("on_response_1", |resp| async move {
886            Ok(HookResult::Continue(resp))
887        }));
888
889        hooks.add_on_error(response_hook("on_error_1", |resp| async move {
890            Ok(HookResult::Continue(resp))
891        }));
892
893        assert!(!hooks.is_empty());
894    }
895
896    #[tokio::test]
897    async fn test_builder_composition_with_request_and_response_hooks() {
898        let hooks = LifecycleHooks::builder()
899            .on_request(request_hook("req_1", |mut req| async move {
900                req.headers_mut()
901                    .insert("X-R1", axum::http::HeaderValue::from_static("1"));
902                Ok(HookResult::Continue(req))
903            }))
904            .on_request(request_hook("req_2", |mut req| async move {
905                req.headers_mut()
906                    .insert("X-R2", axum::http::HeaderValue::from_static("2"));
907                Ok(HookResult::Continue(req))
908            }))
909            .on_response(response_hook("resp_1", |mut resp| async move {
910                resp.headers_mut()
911                    .insert("X-Resp1", axum::http::HeaderValue::from_static("resp1"));
912                Ok(HookResult::Continue(resp))
913            }))
914            .on_response(response_hook("resp_2", |mut resp| async move {
915                resp.headers_mut()
916                    .insert("X-Resp2", axum::http::HeaderValue::from_static("resp2"));
917                Ok(HookResult::Continue(resp))
918            }))
919            .build();
920
921        let req = Request::builder().body(Body::empty()).unwrap();
922        let req_result = hooks.execute_on_request(req).await.unwrap();
923
924        match req_result {
925            HookResult::Continue(req) => {
926                assert_eq!(req.headers().get("X-R1").unwrap(), "1");
927                assert_eq!(req.headers().get("X-R2").unwrap(), "2");
928            }
929            HookResult::ShortCircuit(_) => panic!("Expected Continue"),
930        }
931
932        let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
933        let resp_result = hooks.execute_on_response(resp).await.unwrap();
934
935        assert_eq!(resp_result.headers().get("X-Resp1").unwrap(), "resp1");
936        assert_eq!(resp_result.headers().get("X-Resp2").unwrap(), "resp2");
937    }
938
939    #[tokio::test]
940    async fn test_multiple_hooks_accumulate_state() {
941        let hooks = LifecycleHooks::builder()
942            .on_request(request_hook("init_counter", |mut req| async move {
943                req.headers_mut()
944                    .insert("X-Count", axum::http::HeaderValue::from_static("0"));
945                Ok(HookResult::Continue(req))
946            }))
947            .on_request(request_hook("increment_1", |mut req| async move {
948                if let Some(count_header) = req.headers().get("X-Count") {
949                    if count_header == "0" {
950                        req.headers_mut()
951                            .insert("X-Count", axum::http::HeaderValue::from_static("1"));
952                    }
953                }
954                Ok(HookResult::Continue(req))
955            }))
956            .on_request(request_hook("increment_2", |mut req| async move {
957                if let Some(count_header) = req.headers().get("X-Count") {
958                    if count_header == "1" {
959                        req.headers_mut()
960                            .insert("X-Count", axum::http::HeaderValue::from_static("2"));
961                    }
962                }
963                Ok(HookResult::Continue(req))
964            }))
965            .build();
966
967        let req = Request::builder().body(Body::empty()).unwrap();
968        let result = hooks.execute_on_request(req).await.unwrap();
969
970        match result {
971            HookResult::Continue(req) => {
972                assert_eq!(req.headers().get("X-Count").unwrap(), "2");
973            }
974            HookResult::ShortCircuit(_) => panic!("Expected Continue"),
975        }
976    }
977
978    #[tokio::test]
979    async fn test_first_hook_short_circuits_second_continues() {
980        let hooks = LifecycleHooks::builder()
981            .on_request(request_hook("early_exit", |_req| async move {
982                let response = Response::builder()
983                    .status(StatusCode::FORBIDDEN)
984                    .body(Body::from("Early exit"))
985                    .unwrap();
986                Ok(HookResult::ShortCircuit(response))
987            }))
988            .on_request(request_hook("never_runs", |req| async move {
989                Ok(HookResult::Continue(req))
990            }))
991            .build();
992
993        let req = Request::builder().body(Body::empty()).unwrap();
994        let result = hooks.execute_on_request(req).await.unwrap();
995
996        match result {
997            HookResult::ShortCircuit(resp) => {
998                assert_eq!(resp.status(), StatusCode::FORBIDDEN);
999            }
1000            HookResult::Continue(_) => panic!("Expected ShortCircuit"),
1001        }
1002    }
1003
1004    #[tokio::test]
1005    async fn test_all_hook_phases_in_sequence() {
1006        let hooks = LifecycleHooks::builder()
1007            .on_request(request_hook("on_request", |req| async move {
1008                Ok(HookResult::Continue(req))
1009            }))
1010            .pre_validation(request_hook("pre_validation", |req| async move {
1011                Ok(HookResult::Continue(req))
1012            }))
1013            .pre_handler(request_hook("pre_handler", |req| async move {
1014                Ok(HookResult::Continue(req))
1015            }))
1016            .on_response(response_hook("on_response", |resp| async move {
1017                Ok(HookResult::Continue(resp))
1018            }))
1019            .on_error(response_hook("on_error", |resp| async move {
1020                Ok(HookResult::Continue(resp))
1021            }))
1022            .build();
1023
1024        let req = Request::builder().body(Body::empty()).unwrap();
1025        let _ = hooks.execute_on_request(req).await;
1026
1027        let req = Request::builder().body(Body::empty()).unwrap();
1028        let _ = hooks.execute_pre_validation(req).await;
1029
1030        let req = Request::builder().body(Body::empty()).unwrap();
1031        let _ = hooks.execute_pre_handler(req).await;
1032
1033        let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
1034        let _ = hooks.execute_on_response(resp).await;
1035
1036        let resp = Response::builder()
1037            .status(StatusCode::INTERNAL_SERVER_ERROR)
1038            .body(Body::empty())
1039            .unwrap();
1040        let _ = hooks.execute_on_error(resp).await;
1041    }
1042
1043    #[tokio::test]
1044    async fn test_hook_with_complex_header_manipulation() {
1045        let hooks = LifecycleHooks::builder()
1046            .on_request(request_hook("parse_auth", |mut req| async move {
1047                let has_auth = req.headers().contains_key("Authorization");
1048                let auth_status = if has_auth { "authenticated" } else { "anonymous" };
1049                req.headers_mut()
1050                    .insert("X-Auth-Status", axum::http::HeaderValue::from_static(auth_status));
1051                Ok(HookResult::Continue(req))
1052            }))
1053            .pre_validation(request_hook("validate_auth", |req| async move {
1054                if let Some(auth_header) = req.headers().get("X-Auth-Status") {
1055                    if auth_header == "anonymous" {
1056                        let response = Response::builder()
1057                            .status(StatusCode::UNAUTHORIZED)
1058                            .body(Body::from("Authentication required"))
1059                            .unwrap();
1060                        return Ok(HookResult::ShortCircuit(response));
1061                    }
1062                }
1063                Ok(HookResult::Continue(req))
1064            }))
1065            .build();
1066
1067        let auth_req = Request::builder()
1068            .header("Authorization", "Bearer token123")
1069            .body(Body::empty())
1070            .unwrap();
1071
1072        let result = hooks.execute_on_request(auth_req).await.unwrap();
1073        assert!(matches!(result, HookResult::Continue(_)));
1074
1075        let anon_req = Request::builder().body(Body::empty()).unwrap();
1076        let on_req_result = hooks.execute_on_request(anon_req).await.unwrap();
1077
1078        match on_req_result {
1079            HookResult::Continue(req) => {
1080                assert_eq!(req.headers().get("X-Auth-Status").unwrap(), "anonymous");
1081
1082                let val_result = hooks.execute_pre_validation(req).await.unwrap();
1083                assert!(matches!(val_result, HookResult::ShortCircuit(_)));
1084            }
1085            HookResult::ShortCircuit(_) => panic!("Expected Continue from on_request"),
1086        }
1087    }
1088
1089    #[tokio::test]
1090    async fn test_empty_hooks_no_overhead() {
1091        let hooks = LifecycleHooks::new();
1092        assert!(hooks.is_empty());
1093
1094        let req = Request::builder().body(Body::empty()).unwrap();
1095        let result = hooks.execute_on_request(req).await.unwrap();
1096        assert!(matches!(result, HookResult::Continue(_)));
1097
1098        let req = Request::builder().body(Body::empty()).unwrap();
1099        let result = hooks.execute_pre_validation(req).await.unwrap();
1100        assert!(matches!(result, HookResult::Continue(_)));
1101
1102        let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
1103        let result = hooks.execute_on_response(resp).await.unwrap();
1104        assert_eq!(result.status(), StatusCode::OK);
1105    }
1106
1107    #[tokio::test]
1108    async fn test_response_hook_short_circuit_treated_as_continue() {
1109        let hooks = LifecycleHooks::builder()
1110            .on_response(response_hook("hook_with_short_circuit", |mut resp| async move {
1111                resp.headers_mut()
1112                    .insert("X-Processed", axum::http::HeaderValue::from_static("yes"));
1113                Ok(HookResult::ShortCircuit(resp))
1114            }))
1115            .on_response(response_hook("second_hook", |mut resp| async move {
1116                resp.headers_mut()
1117                    .insert("X-Second", axum::http::HeaderValue::from_static("yes"));
1118                Ok(HookResult::Continue(resp))
1119            }))
1120            .build();
1121
1122        let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
1123
1124        let result = hooks.execute_on_response(resp).await.unwrap();
1125
1126        assert_eq!(result.headers().get("X-Processed").unwrap(), "yes");
1127        assert_eq!(result.headers().get("X-Second").unwrap(), "yes");
1128    }
1129
1130    #[tokio::test]
1131    async fn test_complex_pre_validation_flow_with_auth_and_content_check() {
1132        let hooks = LifecycleHooks::builder()
1133            .pre_validation(request_hook("check_auth", |req| async move {
1134                if !req.headers().contains_key("Authorization") {
1135                    return Ok(HookResult::ShortCircuit(
1136                        Response::builder()
1137                            .status(StatusCode::UNAUTHORIZED)
1138                            .body(Body::from("Missing auth"))
1139                            .unwrap(),
1140                    ));
1141                }
1142                Ok(HookResult::Continue(req))
1143            }))
1144            .pre_validation(request_hook("check_content_type", |req| async move {
1145                if req.method() == axum::http::Method::POST {
1146                    if !req.headers().contains_key("Content-Type") {
1147                        return Ok(HookResult::ShortCircuit(
1148                            Response::builder()
1149                                .status(StatusCode::BAD_REQUEST)
1150                                .body(Body::from("Missing Content-Type"))
1151                                .unwrap(),
1152                        ));
1153                    }
1154                }
1155                Ok(HookResult::Continue(req))
1156            }))
1157            .build();
1158
1159        let req = Request::builder().body(Body::empty()).unwrap();
1160        let result = hooks.execute_pre_validation(req).await.unwrap();
1161
1162        match result {
1163            HookResult::ShortCircuit(resp) => {
1164                assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1165            }
1166            HookResult::Continue(_) => panic!("Expected ShortCircuit for missing auth"),
1167        }
1168
1169        let req = Request::builder()
1170            .method(axum::http::Method::POST)
1171            .header("Authorization", "Bearer token")
1172            .body(Body::empty())
1173            .unwrap();
1174        let result = hooks.execute_pre_validation(req).await.unwrap();
1175
1176        match result {
1177            HookResult::ShortCircuit(resp) => {
1178                assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
1179            }
1180            HookResult::Continue(_) => panic!("Expected ShortCircuit for missing content type"),
1181        }
1182
1183        let req = Request::builder()
1184            .method(axum::http::Method::POST)
1185            .header("Authorization", "Bearer token")
1186            .header("Content-Type", "application/json")
1187            .body(Body::empty())
1188            .unwrap();
1189        let result = hooks.execute_pre_validation(req).await.unwrap();
1190
1191        assert!(matches!(result, HookResult::Continue(_)));
1192    }
1193}