1pub mod upload;
10pub mod validation;
11
12use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16
17use axum::http::{Request, StatusCode};
18use axum::routing::get as axum_get;
19use axum::{Router as AxumRouter, body::Body};
20use schemars::JsonSchema;
21use serde::de::DeserializeOwned;
22use serde_json::Value;
23#[cfg(feature = "di")]
24use spikard_core::di;
25pub use spikard_http::{
26 CompressionConfig, CorsConfig, LifecycleHook, LifecycleHooks, LifecycleHooksBuilder, Method, RateLimitConfig,
27 ServerConfig, StaticFilesConfig,
28 cors::{add_cors_headers, handle_preflight, validate_cors_request},
29 handler_response::HandlerResponse,
30 handler_trait::HandlerResult,
31 lifecycle::{HookResult, request_hook, response_hook},
32 sse::{SseEvent, SseEventProducer},
33 websocket::WebSocketHandler,
34};
35use spikard_http::{
36 Route, RouteMetadata, SchemaRegistry, Server,
37 handler_trait::{Handler, RequestData},
38 sse::{SseState, sse_handler},
39 websocket::{WebSocketState, websocket_handler},
40};
41pub use upload::UploadFile;
42
43pub mod testing {
44 use super::{App, AppError};
45 use axum::Router as AxumRouter;
46 use axum::body::Body;
47 use axum::http::Request;
48 use axum_test::{TestServer as AxumTestServer, TestServerConfig, Transport};
49 pub use spikard_http::testing::{
50 MultipartFilePart, ResponseSnapshot, SnapshotError, SseEvent, SseStream, WebSocketConnection, WebSocketMessage,
51 build_multipart_body, encode_urlencoded_body,
52 };
53
54 pub struct TestServer {
59 mock_server: AxumTestServer,
60 http_server: AxumTestServer,
61 }
62
63 impl TestServer {
64 pub fn from_app(app: App) -> Result<Self, AppError> {
66 let router = app.into_router()?;
67 Self::from_router(router)
68 }
69
70 pub fn from_router(router: AxumRouter) -> Result<Self, AppError> {
72 let mock_server = AxumTestServer::new(router.clone()).map_err(|err| AppError::Server(err.to_string()))?;
73 let config = TestServerConfig {
74 transport: Some(Transport::HttpRandomPort),
75 ..Default::default()
76 };
77 let http_server =
78 AxumTestServer::new_with_config(router, config).map_err(|err| AppError::Server(err.to_string()))?;
79 Ok(Self {
80 mock_server,
81 http_server,
82 })
83 }
84
85 pub async fn call(&self, request: Request<Body>) -> Result<ResponseSnapshot, SnapshotError> {
87 let response = spikard_http::testing::call_test_server(&self.mock_server, request).await;
88 spikard_http::testing::snapshot_response(response).await
89 }
90
91 pub async fn connect_websocket(&self, path: &str) -> WebSocketConnection {
93 spikard_http::testing::connect_websocket(&self.http_server, path).await
94 }
95 }
96}
97
98pub struct App {
100 config: ServerConfig,
101 registry: SchemaRegistry,
102 routes: Vec<(Route, Arc<dyn Handler>)>,
103 metadata: Vec<RouteMetadata>,
104 attached_routers: Vec<AxumRouter>,
105}
106
107impl App {
108 pub fn new() -> Self {
110 Self {
111 config: ServerConfig::default(),
112 registry: SchemaRegistry::new(),
113 routes: Vec::new(),
114 metadata: Vec::new(),
115 attached_routers: Vec::new(),
116 }
117 }
118
119 pub fn config(mut self, config: ServerConfig) -> Self {
121 self.config = config;
122 self
123 }
124
125 pub fn route<H>(&mut self, builder: RouteBuilder, handler: H) -> std::result::Result<&mut Self, AppError>
127 where
128 H: IntoHandler + 'static,
129 {
130 let metadata = builder.into_metadata();
131 let route = Route::from_metadata(metadata.clone(), &self.registry).map_err(AppError::Route)?;
132 let handler = handler.into_handler();
133 self.routes.push((route, handler));
134 self.metadata.push(metadata);
135 Ok(self)
136 }
137
138 pub fn websocket<H>(&mut self, path: impl Into<String>, handler: H) -> &mut Self
140 where
141 H: WebSocketHandler + Send + Sync + 'static,
142 {
143 let _ = self.websocket_with_schemas(path, handler, None, None);
144 self
145 }
146
147 pub fn websocket_with_schemas<H>(
149 &mut self,
150 path: impl Into<String>,
151 handler: H,
152 message_schema: Option<serde_json::Value>,
153 response_schema: Option<serde_json::Value>,
154 ) -> std::result::Result<&mut Self, AppError>
155 where
156 H: WebSocketHandler + Send + Sync + 'static,
157 {
158 let state = if message_schema.is_some() || response_schema.is_some() {
159 WebSocketState::with_schemas(handler, message_schema, response_schema).map_err(AppError::Route)?
160 } else {
161 WebSocketState::new(handler)
162 };
163
164 self.register_stateful_ws_route(path, state)
165 }
166
167 pub fn sse<P>(&mut self, path: impl Into<String>, producer: P) -> &mut Self
169 where
170 P: SseEventProducer + Send + Sync + 'static,
171 {
172 let _ = self.sse_with_schema(path, producer, None);
173 self
174 }
175
176 pub fn sse_with_schema<P>(
178 &mut self,
179 path: impl Into<String>,
180 producer: P,
181 event_schema: Option<serde_json::Value>,
182 ) -> std::result::Result<&mut Self, AppError>
183 where
184 P: SseEventProducer + Send + Sync + 'static,
185 {
186 let state = if let Some(schema) = event_schema {
187 SseState::with_schema(producer, Some(schema)).map_err(AppError::Route)?
188 } else {
189 SseState::new(producer)
190 };
191
192 self.register_stateful_sse_route(path, state)
193 }
194
195 fn register_stateful_ws_route<H: WebSocketHandler + Send + Sync + 'static>(
197 &mut self,
198 path: impl Into<String>,
199 state: WebSocketState<H>,
200 ) -> std::result::Result<&mut Self, AppError> {
201 let path = normalize_path(path.into());
202 let router = AxumRouter::new().route(&path, axum_get(websocket_handler::<H>).with_state(state));
203 self.attached_routers.push(router);
204 Ok(self)
205 }
206
207 fn register_stateful_sse_route<P: SseEventProducer + Send + Sync + 'static>(
209 &mut self,
210 path: impl Into<String>,
211 state: SseState<P>,
212 ) -> std::result::Result<&mut Self, AppError> {
213 let path = normalize_path(path.into());
214 let router = AxumRouter::new().route(&path, axum_get(sse_handler::<P>).with_state(state));
215 self.attached_routers.push(router);
216 Ok(self)
217 }
218
219 pub fn merge_axum_router(mut self, router: AxumRouter) -> Self {
221 self.attached_routers.push(router);
222 self
223 }
224
225 pub fn attach_axum_router(&mut self, router: AxumRouter) -> &mut Self {
227 self.attached_routers.push(router);
228 self
229 }
230
231 pub fn into_router(self) -> std::result::Result<axum::Router, AppError> {
233 let App {
234 config,
235 routes,
236 metadata,
237 attached_routers,
238 ..
239 } = self;
240 let mut router = Server::with_handlers_and_metadata(config, routes, metadata).map_err(AppError::Server)?;
241 for extra in attached_routers {
242 router = router.merge(extra);
243 }
244 Ok(router)
245 }
246
247 pub async fn run(self) -> std::result::Result<(), AppError> {
249 let App {
250 config,
251 routes,
252 metadata,
253 attached_routers,
254 ..
255 } = self;
256 let mut router =
257 Server::with_handlers_and_metadata(config.clone(), routes, metadata).map_err(AppError::Server)?;
258 for extra in attached_routers {
259 router = router.merge(extra);
260 }
261 Server::run_with_config(router, config)
262 .await
263 .map_err(|err| AppError::Server(err.to_string()))
264 }
265}
266
267impl Default for App {
268 fn default() -> Self {
269 Self::new()
270 }
271}
272
273pub struct RouteBuilder {
275 method: Method,
276 path: String,
277 handler_name: String,
278 request_schema: Option<Value>,
279 response_schema: Option<Value>,
280 parameter_schema: Option<Value>,
281 file_params: Option<Value>,
282 cors: Option<CorsConfig>,
283 is_async: bool,
284}
285
286impl RouteBuilder {
287 pub fn new(method: Method, path: impl Into<String>) -> Self {
289 let path = path.into();
290 let handler_name = default_handler_name(&method, &path);
291 Self {
292 method,
293 path,
294 handler_name,
295 request_schema: None,
296 response_schema: None,
297 parameter_schema: None,
298 file_params: None,
299 cors: None,
300 is_async: true,
301 }
302 }
303
304 pub fn handler_name(mut self, name: impl Into<String>) -> Self {
306 self.handler_name = name.into();
307 self
308 }
309
310 pub fn request_body<T: JsonSchema>(mut self) -> Self {
312 self.request_schema = Some(schema_for::<T>());
313 self
314 }
315
316 pub fn response_body<T: JsonSchema>(mut self) -> Self {
318 self.response_schema = Some(schema_for::<T>());
319 self
320 }
321
322 pub fn params<T: JsonSchema>(mut self) -> Self {
324 self.parameter_schema = Some(schema_for::<T>());
325 self
326 }
327
328 pub fn request_schema_json(mut self, schema: Value) -> Self {
330 self.request_schema = Some(schema);
331 self
332 }
333
334 pub fn response_schema_json(mut self, schema: Value) -> Self {
336 self.response_schema = Some(schema);
337 self
338 }
339
340 pub fn params_schema_json(mut self, schema: Value) -> Self {
342 self.parameter_schema = Some(schema);
343 self
344 }
345
346 pub fn file_params_json(mut self, schema: Value) -> Self {
348 self.file_params = Some(schema);
349 self
350 }
351
352 pub fn cors(mut self, cors: CorsConfig) -> Self {
354 self.cors = Some(cors);
355 self
356 }
357
358 pub fn sync(mut self) -> Self {
360 self.is_async = false;
361 self
362 }
363
364 fn into_metadata(self) -> RouteMetadata {
365 #[cfg(feature = "di")]
366 {
367 RouteMetadata {
368 method: self.method.to_string(),
369 path: self.path,
370 handler_name: self.handler_name,
371 request_schema: self.request_schema,
372 response_schema: self.response_schema,
373 parameter_schema: self.parameter_schema,
374 file_params: self.file_params,
375 is_async: self.is_async,
376 cors: self.cors,
377 body_param_name: None,
378 handler_dependencies: None,
379 jsonrpc_method: None,
380 }
381 }
382 #[cfg(not(feature = "di"))]
383 {
384 RouteMetadata {
385 method: self.method.to_string(),
386 path: self.path,
387 handler_name: self.handler_name,
388 request_schema: self.request_schema,
389 response_schema: self.response_schema,
390 parameter_schema: self.parameter_schema,
391 file_params: self.file_params,
392 is_async: self.is_async,
393 cors: self.cors,
394 body_param_name: None,
395 jsonrpc_method: None,
396 }
397 }
398 }
399}
400
// Generates a public free function `$fn_name(path)` that returns a `RouteBuilder`
// for the HTTP method expression `$verb`. Any attributes (including doc comments)
// placed before the function name are forwarded onto the generated function.
macro_rules! http_method {
    (
        $(#[$attr:meta])*
        $fn_name:ident,
        $verb:expr
    ) => {
        $(#[$attr])*
        pub fn $fn_name(path: impl Into<String>) -> RouteBuilder {
            RouteBuilder::new($verb, path)
        }
    };
}
413
414http_method!(
415 get,
417 Method::Get
418);
419
420http_method!(
421 post,
423 Method::Post
424);
425
426http_method!(
427 put,
429 Method::Put
430);
431
432http_method!(
433 patch,
435 Method::Patch
436);
437
438http_method!(
439 delete,
441 Method::Delete
442);
443
444fn default_handler_name(method: &Method, path: &str) -> String {
445 let prefix = method.as_str().to_lowercase();
446 let suffix = sanitize_identifier(path);
447 format!("{}_{}", prefix, suffix)
448}
449
/// Turn an arbitrary string into a lowercase, underscore-separated identifier.
///
/// ASCII letters and digits are kept (lowercased). Every other character acts as a
/// separator: runs of separators collapse into a single `_`, and separators at the
/// start or end are dropped. An input without any ASCII alphanumerics yields `""`.
fn sanitize_identifier(input: &str) -> String {
    // Single pass: the output is never longer than the input, so one allocation suffices.
    let mut ident = String::with_capacity(input.len());
    for c in input.chars() {
        if c.is_ascii_alphanumeric() {
            ident.push(c.to_ascii_lowercase());
        } else if !ident.is_empty() && !ident.ends_with('_') {
            // Emit at most one separator, and never a leading one.
            ident.push('_');
        }
    }
    // At most one trailing separator can remain; drop it.
    if ident.ends_with('_') {
        ident.pop();
    }
    ident
}
466
467fn schema_for<T: JsonSchema>() -> Value {
468 let root = schemars::schema_for!(T);
469 match serde_json::to_value(root) {
470 Ok(value) => value.get("schema").cloned().unwrap_or(value),
471 Err(e) => {
472 eprintln!("warning: failed to serialize schema: {}, returning null", e);
473 Value::Null
474 }
475 }
476}
477
/// Ensure `path` begins with `/`, prepending one when it is missing.
fn normalize_path(path: String) -> String {
    let mut normalized = path;
    if !normalized.starts_with('/') {
        normalized.insert(0, '/');
    }
    normalized
}
485
486#[derive(Debug, thiserror::Error)]
488pub enum AppError {
489 #[error("Failed to register route: {0}")]
491 Route(String),
492 #[error("Failed to build server: {0}")]
494 Server(String),
495 #[error("Failed to decode payload: {0}")]
497 Decode(String),
498}
499
500impl From<AppError> for (StatusCode, String) {
501 fn from(err: AppError) -> Self {
502 match err {
503 AppError::Route(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
504 AppError::Server(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
505 AppError::Decode(msg) => (StatusCode::BAD_REQUEST, msg),
506 }
507 }
508}
509
510pub struct RequestContext {
512 request: Request<Body>,
513 data: RequestData,
514}
515
516impl RequestContext {
517 fn new(request: Request<Body>, data: RequestData) -> Self {
518 Self { request, data }
519 }
520
521 pub fn request(&self) -> &Request<Body> {
523 &self.request
524 }
525
526 pub fn json<T: DeserializeOwned>(&self) -> std::result::Result<T, AppError> {
528 if !self.data.body.is_null() {
529 serde_json::from_value(self.data.body.clone()).map_err(|err| AppError::Decode(err.to_string()))
530 } else if let Some(raw_bytes) = &self.data.raw_body {
531 serde_json::from_slice(raw_bytes).map_err(|err| AppError::Decode(err.to_string()))
532 } else {
533 serde_json::from_value(self.data.body.clone()).map_err(|err| AppError::Decode(err.to_string()))
534 }
535 }
536
537 pub fn query<T: DeserializeOwned>(&self) -> std::result::Result<T, AppError> {
539 serde_json::from_value(self.data.query_params.clone()).map_err(|err| AppError::Decode(err.to_string()))
540 }
541
542 pub fn query_value(&self) -> &Value {
544 &self.data.query_params
545 }
546
547 pub fn raw_query_params(&self) -> &HashMap<String, Vec<String>> {
549 &self.data.raw_query_params
550 }
551
552 pub fn path<T: DeserializeOwned>(&self) -> std::result::Result<T, AppError> {
554 let value = serde_json::to_value(&*self.data.path_params).map_err(|err| AppError::Decode(err.to_string()))?;
555 serde_json::from_value(value).map_err(|err| AppError::Decode(err.to_string()))
556 }
557
558 pub fn path_params(&self) -> &HashMap<String, String> {
560 &self.data.path_params
561 }
562
563 pub fn path_param(&self, name: &str) -> Option<&str> {
565 self.data.path_params.get(name).map(|s| s.as_str())
566 }
567
568 pub fn header(&self, name: &str) -> Option<&str> {
570 self.data.headers.get(&name.to_ascii_lowercase()).map(|s| s.as_str())
571 }
572
573 pub fn headers_map(&self) -> &HashMap<String, String> {
575 &self.data.headers
576 }
577
578 pub fn cookie(&self, name: &str) -> Option<&str> {
580 self.data.cookies.get(name).map(|s| s.as_str())
581 }
582
583 pub fn cookies_map(&self) -> &HashMap<String, String> {
585 &self.data.cookies
586 }
587
588 pub fn body_value(&self) -> &Value {
590 &self.data.body
591 }
592
593 #[cfg(feature = "di")]
595 pub fn dependencies(&self) -> Option<Arc<di::ResolvedDependencies>> {
596 self.data.dependencies.as_ref().map(Arc::clone)
597 }
598
599 pub fn method(&self) -> &str {
601 &self.data.method
602 }
603
604 pub fn path_str(&self) -> &str {
606 &self.data.path
607 }
608}
609
610pub trait IntoHandler {
612 fn into_handler(self) -> Arc<dyn Handler>;
613}
614
615impl<F, Fut> IntoHandler for F
616where
617 F: Send + Sync + 'static + Fn(RequestContext) -> Fut,
618 Fut: Future<Output = HandlerResult> + Send + 'static,
619{
620 fn into_handler(self) -> Arc<dyn Handler> {
621 Arc::new(FnHandler { inner: self })
622 }
623}
624
625struct FnHandler<F> {
626 inner: F,
627}
628
629impl<F, Fut> Handler for FnHandler<F>
630where
631 F: Send + Sync + 'static + Fn(RequestContext) -> Fut,
632 Fut: Future<Output = HandlerResult> + Send + 'static,
633{
634 fn call(&self, req: Request<Body>, data: RequestData) -> Pin<Box<dyn Future<Output = HandlerResult> + Send + '_>> {
635 let ctx = RequestContext::new(req, data);
636 Box::pin((self.inner)(ctx))
637 }
638}
639
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, StatusCode};
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tower::util::ServiceExt;

    /// Payload type used for schema generation and JSON decoding in the route test.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)]
    struct Greeting {
        message: String,
    }

    #[test]
    fn sanitize_identifier_handles_complex_path() {
        let sanitized = sanitize_identifier("/api/v2/{resource}-{id}/action");
        assert_eq!(sanitized, "api_v2_resource_id_action");
    }

    #[test]
    fn normalize_path_adds_leading_slash() {
        for input in ["users", "/users"] {
            assert_eq!(normalize_path(input.to_string()), "/users");
        }
    }

    #[test]
    fn default_handler_name_includes_method_prefix() {
        let with_params = default_handler_name(&Method::Get, "/items/{id}");
        assert_eq!(with_params, "get_items_id");

        let without_slash = default_handler_name(&Method::Post, "items");
        assert_eq!(without_slash, "post_items");
    }

    #[test]
    fn schema_for_returns_embedded_schema_object() {
        #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
        struct Payload {
            message: String,
        }

        let generated = schema_for::<Payload>();
        assert!(generated.is_object());
        let schema_type = generated.get("type").and_then(|v| v.as_str());
        assert_eq!(schema_type, Some("object"));
        assert!(generated.get("properties").is_some());
    }

    #[test]
    fn route_builder_sets_defaults_and_metadata() {
        let metadata = post("items").sync().into_metadata();

        assert_eq!(metadata.method, "POST");
        assert_eq!(metadata.path, "items");
        assert_eq!(metadata.handler_name, "post_items");
        assert!(!metadata.is_async);
        assert!(metadata.request_schema.is_none());
        assert!(metadata.response_schema.is_none());
    }

    #[test]
    fn app_error_maps_to_status_code_and_message() {
        let converted: (StatusCode, String) = AppError::Decode("bad json".to_string()).into();
        let (status, msg) = converted;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(msg, "bad json");
    }

    #[tokio::test]
    async fn registers_route_with_schema() {
        let mut app = App::new();
        let builder = post("/hello").request_body::<Greeting>().response_body::<Greeting>();
        app.route(builder, |ctx: RequestContext| async move {
            let body: Greeting = ctx.json()?;
            let response = serde_json::to_value(body).unwrap();
            Ok(axum::http::Response::builder()
                .status(StatusCode::OK)
                .header("content-type", "application/json")
                .body(Body::from(response.to_string()))
                .unwrap())
        })
        .unwrap();

        assert_eq!(app.metadata.len(), 1);
        let registered = &app.metadata[0];
        assert!(registered.request_schema.is_some());
        assert!(registered.response_schema.is_some());
        assert!(registered.parameter_schema.is_none());
    }

    #[test]
    fn request_context_extracts_and_accesses_all_fields() {
        let headers = HashMap::from([
            ("content-type".to_string(), "application/json".to_string()),
            ("authorization".to_string(), "Bearer token123".to_string()),
        ]);
        let cookies = HashMap::from([("session_id".to_string(), "abc123".to_string())]);
        let path_params = HashMap::from([("id".to_string(), "123".to_string())]);

        let request = Request::builder()
            .uri("http://localhost/users/123")
            .body(Body::empty())
            .unwrap();

        let data = RequestData {
            method: "POST".to_string(),
            path: "/users/{id}".to_string(),
            headers: std::sync::Arc::new(headers),
            cookies: std::sync::Arc::new(cookies),
            query_params: Value::Object(Default::default()),
            validated_params: None,
            raw_query_params: std::sync::Arc::new(HashMap::new()),
            path_params: std::sync::Arc::new(path_params),
            body: Value::Null,
            raw_body: None,
            #[cfg(feature = "di")]
            dependencies: None,
        };

        let ctx = RequestContext::new(request, data);

        // Header lookup lowercases the requested name, so mixed case must also match.
        assert_eq!(ctx.header("content-type"), Some("application/json"));
        assert_eq!(ctx.header("Content-Type"), Some("application/json"));
        assert_eq!(ctx.header("authorization"), Some("Bearer token123"));

        assert_eq!(ctx.cookie("session_id"), Some("abc123"));
        assert_eq!(ctx.cookie("nonexistent"), None);

        assert_eq!(ctx.path_param("id"), Some("123"));
        assert_eq!(ctx.path_param("missing"), None);

        assert_eq!(ctx.method(), "POST");
        assert_eq!(ctx.path_str(), "/users/{id}");
    }

    /// WebSocket handler that replies with whatever message it receives.
    struct EchoWebSocket;

    impl WebSocketHandler for EchoWebSocket {
        async fn handle_message(&self, message: serde_json::Value) -> Option<serde_json::Value> {
            Some(message)
        }
    }

    #[tokio::test]
    async fn websocket_routes_are_registered() {
        let mut app = App::new();
        app.websocket("/ws", EchoWebSocket);
        let router = app.into_router().unwrap();

        let upgrade_request = Request::builder()
            .uri("http://localhost/ws")
            .header("connection", "Upgrade")
            .header("upgrade", "websocket")
            .header("sec-websocket-version", "13")
            .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
            .body(Body::empty())
            .unwrap();

        let status = router.oneshot(upgrade_request).await.unwrap().status();
        let accepted = [StatusCode::SWITCHING_PROTOCOLS, StatusCode::UPGRADE_REQUIRED];
        assert!(accepted.contains(&status));
    }

    /// SSE producer that always yields the same JSON event.
    struct DummyProducer;

    impl SseEventProducer for DummyProducer {
        async fn next_event(&self) -> Option<SseEvent> {
            let payload = json!({
                "message": "hello"
            });
            Some(SseEvent::new(payload))
        }
    }

    #[tokio::test]
    async fn sse_routes_are_registered() {
        let mut app = App::new();
        app.sse("/events", DummyProducer);
        let router = app.into_router().unwrap();

        let events_request = Request::builder()
            .uri("http://localhost/events")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(events_request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}