rustack_apigatewayv2_http/
service.rs1use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc};
4
5use bytes::Bytes;
6use http_body_util::BodyExt;
7use hyper::body::Incoming;
8use rustack_apigatewayv2_model::error::ApiGatewayV2Error;
9
10use crate::{
11 body::ApiGatewayV2ResponseBody,
12 dispatch::{ApiGatewayV2Handler, dispatch_operation},
13 response::{CONTENT_TYPE, error_to_response},
14 router::resolve_operation,
15};
16
17#[derive(Clone)]
19pub struct ApiGatewayV2HttpConfig {
20 pub skip_signature_validation: bool,
22 pub region: String,
24 pub credential_provider: Option<Arc<dyn rustack_auth::CredentialProvider>>,
26}
27
28impl std::fmt::Debug for ApiGatewayV2HttpConfig {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("ApiGatewayV2HttpConfig")
31 .field("skip_signature_validation", &self.skip_signature_validation)
32 .field("region", &self.region)
33 .field(
34 "credential_provider",
35 &self.credential_provider.as_ref().map(|_| "..."),
36 )
37 .finish()
38 }
39}
40
41impl Default for ApiGatewayV2HttpConfig {
42 fn default() -> Self {
43 Self {
44 skip_signature_validation: true,
45 region: "us-east-1".to_owned(),
46 credential_provider: None,
47 }
48 }
49}
50
51#[derive(Debug)]
57pub struct ApiGatewayV2HttpService<H: ApiGatewayV2Handler> {
58 handler: Arc<H>,
59 config: Arc<ApiGatewayV2HttpConfig>,
60}
61
62impl<H: ApiGatewayV2Handler> ApiGatewayV2HttpService<H> {
63 pub fn new(handler: Arc<H>, config: ApiGatewayV2HttpConfig) -> Self {
65 Self {
66 handler,
67 config: Arc::new(config),
68 }
69 }
70}
71
72impl<H: ApiGatewayV2Handler> Clone for ApiGatewayV2HttpService<H> {
73 fn clone(&self) -> Self {
74 Self {
75 handler: Arc::clone(&self.handler),
76 config: Arc::clone(&self.config),
77 }
78 }
79}
80
81impl<H: ApiGatewayV2Handler> hyper::service::Service<http::Request<Incoming>>
82 for ApiGatewayV2HttpService<H>
83{
84 type Response = http::Response<ApiGatewayV2ResponseBody>;
85 type Error = Infallible;
86 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
87
88 fn call(&self, req: http::Request<Incoming>) -> Self::Future {
89 let handler = Arc::clone(&self.handler);
90 let config = Arc::clone(&self.config);
91 let request_id = uuid::Uuid::new_v4().to_string();
92
93 Box::pin(async move {
94 let response = process_request(req, handler.as_ref(), &config, &request_id).await;
95 let response = add_common_headers(response, &request_id);
96 Ok(response)
97 })
98 }
99}
100
101async fn process_request<H: ApiGatewayV2Handler>(
103 req: http::Request<Incoming>,
104 handler: &H,
105 config: &ApiGatewayV2HttpConfig,
106 _request_id: &str,
107) -> http::Response<ApiGatewayV2ResponseBody> {
108 let (parts, incoming) = req.into_parts();
109
110 let path = parts.uri.path();
112 let (op, path_params, success_status) = match resolve_operation(&parts.method, path) {
113 Ok(result) => result,
114 Err(err) => return wrap_error_response(&err),
115 };
116
117 let query = parts.uri.query().unwrap_or("").to_owned();
119
120 let body = match collect_body(incoming).await {
122 Ok(body) => body,
123 Err(err) => return wrap_error_response(&err),
124 };
125
126 if !config.skip_signature_validation {
128 if let Some(ref cred_provider) = config.credential_provider {
129 let body_hash = rustack_auth::hash_payload(&body);
130 if let Err(auth_err) =
131 rustack_auth::verify_sigv4(&parts, &body_hash, cred_provider.as_ref())
132 {
133 let err = ApiGatewayV2Error::with_message(
134 rustack_apigatewayv2_model::error::ApiGatewayV2ErrorCode::AccessDeniedException,
135 auth_err.to_string(),
136 );
137 return wrap_error_response(&err);
138 }
139 }
140 }
141
142 match dispatch_operation(handler, op, path_params, query, parts.headers, body).await {
144 Ok(mut response) => {
145 if response.status() == http::StatusCode::OK && success_status != 200 {
147 *response.status_mut() =
148 http::StatusCode::from_u16(success_status).unwrap_or(http::StatusCode::OK);
149 }
150 response
151 }
152 Err(err) => wrap_error_response(&err),
153 }
154}
155
156fn wrap_error_response(error: &ApiGatewayV2Error) -> http::Response<ApiGatewayV2ResponseBody> {
161 if let Ok(bytes_response) = error_to_response(error) {
162 let (parts, body) = bytes_response.into_parts();
163 http::Response::from_parts(parts, ApiGatewayV2ResponseBody::from_bytes(body))
164 } else {
165 let (parts, body) = http::Response::builder()
167 .status(http::StatusCode::INTERNAL_SERVER_ERROR)
168 .body(Bytes::from(r#"{"message":"Internal error"}"#))
169 .unwrap_or_default()
170 .into_parts();
171 http::Response::from_parts(parts, ApiGatewayV2ResponseBody::from_bytes(body))
172 }
173}
174
175async fn collect_body(incoming: Incoming) -> Result<Bytes, ApiGatewayV2Error> {
177 incoming
178 .collect()
179 .await
180 .map(http_body_util::Collected::to_bytes)
181 .map_err(|e| ApiGatewayV2Error::internal_error(format!("Failed to read request body: {e}")))
182}
183
184fn add_common_headers(
186 mut response: http::Response<ApiGatewayV2ResponseBody>,
187 request_id: &str,
188) -> http::Response<ApiGatewayV2ResponseBody> {
189 let is_no_content = response.status() == http::StatusCode::NO_CONTENT;
190 let headers = response.headers_mut();
191
192 if let Ok(hv) = http::HeaderValue::from_str(request_id) {
193 headers.entry("x-amzn-requestid").or_insert(hv);
194 }
195
196 if !is_no_content {
198 headers
199 .entry("content-type")
200 .or_insert(http::HeaderValue::from_static(CONTENT_TYPE));
201 }
202
203 headers.insert("server", http::HeaderValue::from_static("Rustack"));
204
205 headers.insert(
207 "access-control-allow-origin",
208 http::HeaderValue::from_static("*"),
209 );
210
211 response
212}