lambda_web/
request.rs

1// SPDX-License-Identifier: MIT
2//!
3//! Lambda event deserialize
4//!
5use serde::Deserialize;
6use std::borrow::Cow;
7use std::collections::HashMap;
8
9#[derive(Deserialize, Debug)]
10#[serde(untagged)]
11pub(crate) enum LambdaHttpEvent<'a> {
12    ApiGatewayHttpV2(ApiGatewayHttpV2Event<'a>),
13    ApiGatewayRestOrAlb(ApiGatewayRestEvent<'a>),
14}
15
16impl LambdaHttpEvent<'_> {
17    /// HTTP request method
18    pub fn method<'a>(&'a self) -> &'a str {
19        match self {
20            Self::ApiGatewayHttpV2(event) => &event.request_context.http.method,
21            Self::ApiGatewayRestOrAlb(event) => &event.http_method,
22        }
23    }
24
25    /// Host name
26    #[allow(dead_code)]
27    pub fn hostname<'a>(&'a self) -> Option<&'a str> {
28        match self {
29            Self::ApiGatewayHttpV2(event) => Some(&event.request_context.domain_name),
30            Self::ApiGatewayRestOrAlb(event) => {
31                if let RestOrAlbRequestContext::Rest(context) = &event.request_context {
32                    Some(&context.domain_name)
33                } else if let Some(host_headers) = event.multi_value_headers.get("host") {
34                    host_headers.first().map(|h| h as &str)
35                } else {
36                    None
37                }
38            }
39        }
40    }
41
42    /// URL encoded path?query
43    pub fn path_query(&self) -> String {
44        match self {
45            Self::ApiGatewayHttpV2(event) => {
46                let path = encode_path_query(&event.raw_path);
47                let query = &event.raw_query_string as &str;
48                if query.is_empty() {
49                    // No query string
50                    path.into_owned()
51                } else {
52                    // With query string
53                    format!("{}?{}", path, query)
54                }
55            }
56            Self::ApiGatewayRestOrAlb(event) => {
57                let path = if let RestOrAlbRequestContext::Rest(context) = &event.request_context {
58                    // API Gateway REST, request_contest.path contains stage prefix
59                    &context.path
60                } else {
61                    // ALB
62                    &event.path
63                };
64                if let Some(query_string_parameters) = &event.multi_value_query_string_parameters {
65                    // With query string
66                    let querystr = query_string_parameters
67                        .iter()
68                        .flat_map(|(k, vec)| {
69                            let k_enc = encode_path_query(&k);
70                            vec.iter()
71                                .map(move |v| format!("{}={}", k_enc, encode_path_query(&v)))
72                        })
73                        .collect::<Vec<_>>()
74                        .join("&");
75                    format!("{}?{}", path, querystr)
76                } else {
77                    // No query string
78                    path.clone()
79                }
80            }
81        }
82    }
83
84    /// HTTP headers
85    pub fn headers<'a>(&'a self) -> Vec<(&'a str, Cow<'a, str>)> {
86        match self {
87            Self::ApiGatewayHttpV2(event) => {
88                let mut headers: Vec<(&'a str, Cow<'a, str>)> = event
89                    .headers
90                    .iter()
91                    .map(|(k, v)| (k as &str, Cow::from(v as &str)))
92                    .collect();
93
94                // Add cookie header
95                if let Some(cookies) = &event.cookies {
96                    let cookie_value = cookies.join("; ");
97                    headers.push(("cookie", Cow::from(cookie_value)));
98                }
99
100                headers
101            }
102            Self::ApiGatewayRestOrAlb(event) => event
103                .multi_value_headers
104                .iter()
105                .flat_map(|(k, vec)| vec.iter().map(move |v| (k as &str, Cow::from(v as &str))))
106                .collect(),
107        }
108    }
109
110    /// Cookies
111    /// percent encoded "key=val"
112    #[allow(dead_code)]
113    pub fn cookies<'a>(&'a self) -> Vec<&'a str> {
114        match self {
115            Self::ApiGatewayHttpV2(event) => {
116                if let Some(cookies) = &event.cookies {
117                    cookies.iter().map(|c| c.as_str()).collect()
118                } else {
119                    Vec::new()
120                }
121            }
122            Self::ApiGatewayRestOrAlb(event) => {
123                if let Some(cookie_headers) = event.multi_value_headers.get("cookie") {
124                    cookie_headers
125                        .iter()
126                        .flat_map(|v| v.split(";"))
127                        .map(|c| c.trim())
128                        .collect()
129                } else {
130                    Vec::new()
131                }
132            }
133        }
134    }
135
136    /// Check if HTTP client supports Brotli compression.
137    /// ( Accept-Encoding contains "br" )
138    #[cfg(feature = "br")]
139    pub fn client_supports_brotli(&self) -> bool {
140        match self {
141            Self::ApiGatewayHttpV2(event) => {
142                if let Some(header_val) = event.headers.get("accept-encoding") {
143                    for elm in header_val.to_ascii_lowercase().split(',') {
144                        if let Some(algo_name) = elm.split(';').next() {
145                            // first part of elm, contains 'br', 'gzip', etc.
146                            if algo_name.trim() == "br" {
147                                // HTTP client support Brotli compression
148                                return true;
149                            }
150                        }
151                    }
152                    // No "br" in accept-encoding header
153                    false
154                } else {
155                    // No accept-encoding header
156                    false
157                }
158            }
159            Self::ApiGatewayRestOrAlb(event) => {
160                if let Some(header_vals) = event.multi_value_headers.get("accept-encoding") {
161                    for header_val in header_vals {
162                        for elm in header_val.to_ascii_lowercase().split(',') {
163                            if let Some(algo_name) = elm.split(';').next() {
164                                // first part of elm, contains 'br', 'gzip', etc.
165                                if algo_name.trim() == "br" {
166                                    // HTTP client support Brotli compression
167                                    return true;
168                                }
169                            }
170                        }
171                    }
172                    // No "br" in accept-encoding header
173                    false
174                } else {
175                    // No accept-encoding header
176                    false
177                }
178            }
179        }
180    }
181
182    // Without Brotli support, always returns false
183    #[cfg(not(feature = "br"))]
184    pub fn client_supports_brotli(&self) -> bool {
185        false
186    }
187
188    /// Is request & response use multi-value-header
189    pub fn multi_value(&self) -> bool {
190        match self {
191            Self::ApiGatewayHttpV2(_) => false,
192            Self::ApiGatewayRestOrAlb(_) => true,
193        }
194    }
195
196    /// Request body
197    pub fn body(self) -> Result<Vec<u8>, base64::DecodeError> {
198        let (body, b64_encoded) = match self {
199            Self::ApiGatewayHttpV2(event) => (event.body, event.is_base64_encoded),
200            Self::ApiGatewayRestOrAlb(event) => (event.body, event.is_base64_encoded),
201        };
202
203        if let Some(body) = body {
204            if b64_encoded {
205                // base64 decode
206                base64::decode(&body as &str)
207            } else {
208                // string
209                Ok(body.into_owned().into_bytes())
210            }
211        } else {
212            // empty body (GET, OPTION, etc. methods)
213            Ok(Vec::new())
214        }
215    }
216
217    /// Source IP address
218    #[allow(dead_code)]
219    pub fn source_ip(&self) -> Option<std::net::IpAddr> {
220        use std::net::IpAddr;
221        use std::str::FromStr;
222        match self {
223            Self::ApiGatewayHttpV2(event) => {
224                IpAddr::from_str(&event.request_context.http.source_ip).ok()
225            }
226            Self::ApiGatewayRestOrAlb(event) => {
227                if let RestOrAlbRequestContext::Rest(context) = &event.request_context {
228                    IpAddr::from_str(&context.identity.source_ip).ok()
229                } else {
230                    None
231                }
232            }
233        }
234    }
235}
236
237/// API Gateway HTTP API payload format version 2.0
238/// https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html
239#[derive(Deserialize, Debug)]
240#[serde(rename_all = "camelCase")]
241pub(crate) struct ApiGatewayHttpV2Event<'a> {
242    #[allow(dead_code)]
243    version: String,
244    raw_path: String,
245    raw_query_string: String,
246    cookies: Option<Vec<String>>,
247    headers: HashMap<String, String>,
248    //#[serde(borrow)]
249    body: Option<Cow<'a, str>>,
250    #[serde(default)]
251    is_base64_encoded: bool,
252    request_context: ApiGatewayV2RequestContext,
253    // route_key: Cow<'a, str>,
254    // #[serde(default)]
255    // query_string_parameters: StrMap,
256    // #[serde(default)]
257    // path_parameters: StrMap,
258    // #[serde(default)]
259    // stage_variables: StrMap,
260}
261
262#[derive(Deserialize, Debug, Clone)]
263#[serde(rename_all = "camelCase")]
264struct ApiGatewayV2RequestContext {
265    /// The full domain name used to invoke the API. This should be the same as the incoming Host header.
266    domain_name: String,
267    /// The HTTP method used.
268    http: Http,
269    // The API owner's AWS account ID.
270    // pub account_id: String,
271    // The identifier API Gateway assigns to your API.
272    // pub api_id: String,
273    // The stringified value of the specified key-value pair of the context map returned from an API Gateway Lambda authorizer function.
274    // #[serde(default)]
275    // pub authorizer: HashMap<String, serde_json::Value>,
276    // The first label of the $context.domainName. This is often used as a caller/customer identifier.
277    // pub domain_prefix: String,
278    // The ID that API Gateway assigns to the API request.
279    // pub request_id: String,
280    // Undocumented, could be resourcePath
281    // pub route_key: String,
282    // The deployment stage of the API request (for example, Beta or Prod).
283    // pub stage: String,
284    // Undocumented, could be requestTime
285    // pub time: String,
286    // Undocumented, could be requestTimeEpoch
287    // pub time_epoch: usize,
288}
289
290#[derive(Deserialize, Debug, Default, Clone)]
291#[serde(rename_all = "camelCase")]
292struct Http {
293    /// The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.
294    method: String,
295    /// The source IP address of the TCP connection making the request to API Gateway.
296    source_ip: String,
297    // The request path. For example, for a non-proxy request URL of
298    // `https://{rest-api-id.execute-api.{region}.amazonaws.com/{stage}/root/child`,
299    // the $context.path value is `/{stage}/root/child`.
300    // pub path: String,
301    // The request protocol, for example, HTTP/1.1.
302    // pub protocol: String,
303    // The User-Agent header of the API caller.
304    // pub user_agent: String,
305}
306
307/// API Gateway REST API, ALB payload format
308/// https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format
309///
310/// In case of ALB, you must explicitly enable multi-value headers setting.
311///
312#[derive(Deserialize, Debug)]
313#[serde(rename_all = "camelCase")]
314pub(crate) struct ApiGatewayRestEvent<'a> {
315    // path without stage
316    path: String,
317    http_method: String,
318    //#[serde(borrow)]
319    body: Option<Cow<'a, str>>,
320    #[serde(default)]
321    is_base64_encoded: bool,
322    multi_value_headers: HashMap<String, Vec<String>>,
323    #[serde(default)]
324    multi_value_query_string_parameters: Option<HashMap<String, Vec<String>>>,
325    // request_context = None when called from ALB
326    request_context: RestOrAlbRequestContext,
327    // headers: HashMap<String, String>,
328    // path_parameters: HashMap<String, String>,
329    // query_string_parameters: HashMap<String, String>,
330    // stage_variables: HashMap<String, String>,
331}
332
333#[derive(Deserialize, Debug)]
334#[serde(untagged)]
335enum RestOrAlbRequestContext {
336    Rest(ApiGatewayRestRequestContext),
337    Alb(AlbRequestContext),
338}
339
340/// API Gateway REST API request context
341#[derive(Deserialize, Debug)]
342#[serde(rename_all = "camelCase")]
343struct ApiGatewayRestRequestContext {
344    domain_name: String,
345    identity: ApiGatewayRestIdentity,
346    // Path with stage
347    path: String,
348    // account_id: String,
349    // api_id: String,
350    // authorizer: HashMap<String, Value>,
351    // domain_prefix: String,
352    // http_method: String,
353    // protocol: String,
354    // request_id: String,
355    // request_time: String,
356    // request_time_epoch: i64,
357    // resource_id: String,
358    // resource_path: String,
359    // stage: String,
360}
361
362/// API Gateway REST API identity
363#[derive(Deserialize, Debug)]
364#[serde(rename_all = "camelCase")]
365struct ApiGatewayRestIdentity {
366    #[allow(dead_code)]
367    access_key: Option<String>,
368    source_ip: String,
369}
370
371/// ALB Request context
372#[derive(Deserialize, Debug)]
373#[serde(rename_all = "camelCase")]
374struct AlbRequestContext {}
375
376// raw_path in API Gateway HTTP API V2 payload is percent decoded.
377// Path containing space or UTF-8 char is
378// required to percent encoded again before passed to web frameworks
379// See RFC3986 3.3 Path for valid chars.
380const RFC3986_PATH_ESCAPE_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS
381    .add(b' ')
382    .add(b'"')
383    .add(b'#')
384    .add(b'%')
385    .add(b'+')
386    .add(b':')
387    .add(b'<')
388    .add(b'>')
389    .add(b'?')
390    .add(b'@')
391    .add(b'[')
392    .add(b'\\')
393    .add(b']')
394    .add(b'^')
395    .add(b'`')
396    .add(b'{')
397    .add(b'|')
398    .add(b'}');
399
400fn encode_path_query<'a>(pathstr: &'a str) -> Cow<'a, str> {
401    Cow::from(percent_encoding::utf8_percent_encode(
402        pathstr,
403        &RFC3986_PATH_ESCAPE_SET,
404    ))
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_consts::*;

    /// Each fixture must parse as its concrete payload type and through the untagged
    /// `LambdaHttpEvent` enum.
    #[test]
    fn test_decode() {
        serde_json::from_str::<ApiGatewayHttpV2Event>(API_GATEWAY_V2_GET_ROOT_NOQUERY).unwrap();
        serde_json::from_str::<LambdaHttpEvent>(API_GATEWAY_V2_GET_ROOT_NOQUERY).unwrap();
        serde_json::from_str::<ApiGatewayRestEvent>(API_GATEWAY_REST_GET_ROOT_NOQUERY).unwrap();
        serde_json::from_str::<LambdaHttpEvent>(API_GATEWAY_REST_GET_ROOT_NOQUERY).unwrap();
    }

    /// Two cookies are reported the same way for the HTTP API v2 and REST payloads.
    #[test]
    fn test_cookie() {
        for json in &[API_GATEWAY_V2_GET_TWO_COOKIES, API_GATEWAY_REST_GET_TWO_COOKIES] {
            let event: LambdaHttpEvent = serde_json::from_str(json).unwrap();
            assert_eq!(event.cookies(), vec!["cookie1=value1", "cookie2=value2"]);
        }
    }

    /// Plain path characters pass through. Space, `?`, `%`, `+` and non-ASCII UTF-8
    /// are percent-encoded.
    #[test]
    fn test_encode_path_query() {
        assert_eq!(encode_path_query("/users/alice"), "/users/alice");
        assert_eq!(encode_path_query("/foo bar/ü?x"), "/foo%20bar/%C3%BC%3Fx");
        assert_eq!(encode_path_query("100%+1"), "100%25%2B1");
    }
}