lambda_runtime/
types.rs

1use crate::{Error, RefConfig};
2use base64::prelude::*;
3use bytes::Bytes;
4use http::{header::ToStrError, HeaderMap, HeaderValue, StatusCode};
5use lambda_runtime_api_client::body::Body;
6use serde::{Deserialize, Serialize};
7use std::{
8    collections::HashMap,
9    fmt::Debug,
10    time::{Duration, SystemTime},
11};
12use tokio_stream::Stream;
13
14/// Client context sent by the AWS Mobile SDK.
15#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
16pub struct ClientContext {
17    /// Information about the mobile application invoking the function.
18    #[serde(default)]
19    pub client: ClientApplication,
20    /// Custom properties attached to the mobile event context.
21    #[serde(default)]
22    pub custom: HashMap<String, String>,
23    /// Environment settings from the mobile client.
24    #[serde(default)]
25    pub environment: HashMap<String, String>,
26}
27
28/// AWS Mobile SDK client fields.
29#[derive(Serialize, Deserialize, Default, Clone, Debug, Eq, PartialEq)]
30#[serde(rename_all = "camelCase")]
31pub struct ClientApplication {
32    /// The mobile app installation id
33    #[serde(alias = "installation_id")]
34    pub installation_id: String,
35    /// The app title for the mobile app as registered with AWS' mobile services.
36    #[serde(alias = "app_title")]
37    pub app_title: String,
38    /// The version name of the application as registered with AWS' mobile services.
39    #[serde(alias = "app_version_name")]
40    pub app_version_name: String,
41    /// The app version code.
42    #[serde(alias = "app_version_code")]
43    pub app_version_code: String,
44    /// The package name for the mobile application invoking the function
45    #[serde(alias = "app_package_name")]
46    pub app_package_name: String,
47}
48
49/// Cognito identity information sent with the event
50#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
51#[serde(rename_all = "camelCase")]
52pub struct CognitoIdentity {
53    /// The unique identity id for the Cognito credentials invoking the function.
54    #[serde(alias = "cognitoIdentityId", alias = "identity_id")]
55    pub identity_id: String,
56    /// The identity pool id the caller is "registered" with.
57    #[serde(alias = "cognitoIdentityPoolId", alias = "identity_pool_id")]
58    pub identity_pool_id: String,
59}
60
61/// The Lambda function execution context. The values in this struct
62/// are populated using the [Lambda environment variables](https://docs.aws.amazon.com/lambda/latest/dg/current-supported-versions.html)
63/// and [the headers returned by the poll request to the Runtime APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html#runtimes-api-next).
64#[non_exhaustive]
65#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
66pub struct Context {
67    /// The AWS request ID generated by the Lambda service.
68    pub request_id: String,
69    /// The execution deadline for the current invocation in milliseconds.
70    pub deadline: u64,
71    /// The ARN of the Lambda function being invoked.
72    pub invoked_function_arn: String,
73    /// The X-Ray trace ID for the current invocation.
74    pub xray_trace_id: Option<String>,
75    /// The client context object sent by the AWS mobile SDK. This field is
76    /// empty unless the function is invoked using an AWS mobile SDK.
77    pub client_context: Option<ClientContext>,
78    /// The Cognito identity that invoked the function. This field is empty
79    /// unless the invocation request to the Lambda APIs was made using AWS
80    /// credentials issues by Amazon Cognito Identity Pools.
81    pub identity: Option<CognitoIdentity>,
82    /// Lambda function configuration from the local environment variables.
83    /// Includes information such as the function name, memory allocation,
84    /// version, and log streams.
85    pub env_config: RefConfig,
86}
87
88impl Default for Context {
89    fn default() -> Context {
90        Context {
91            request_id: "".to_owned(),
92            deadline: 0,
93            invoked_function_arn: "".to_owned(),
94            xray_trace_id: None,
95            client_context: None,
96            identity: None,
97            env_config: std::sync::Arc::new(crate::Config::default()),
98        }
99    }
100}
101
102impl Context {
103    /// Create a new [Context] struct based on the function configuration
104    /// and the incoming request data.
105    pub fn new(request_id: &str, env_config: RefConfig, headers: &HeaderMap) -> Result<Self, Error> {
106        let client_context: Option<ClientContext> = if let Some(value) = headers.get("lambda-runtime-client-context") {
107            serde_json::from_str(value.to_str()?)?
108        } else {
109            None
110        };
111
112        let identity: Option<CognitoIdentity> = if let Some(value) = headers.get("lambda-runtime-cognito-identity") {
113            serde_json::from_str(value.to_str()?)?
114        } else {
115            None
116        };
117
118        let ctx = Context {
119            request_id: request_id.to_owned(),
120            deadline: headers
121                .get("lambda-runtime-deadline-ms")
122                .expect("missing lambda-runtime-deadline-ms header")
123                .to_str()?
124                .parse::<u64>()?,
125            invoked_function_arn: headers
126                .get("lambda-runtime-invoked-function-arn")
127                .unwrap_or(&HeaderValue::from_static(
128                    "No header lambda-runtime-invoked-function-arn found.",
129                ))
130                .to_str()?
131                .to_owned(),
132            xray_trace_id: headers
133                .get("lambda-runtime-trace-id")
134                .map(|v| String::from_utf8_lossy(v.as_bytes()).to_string()),
135            client_context,
136            identity,
137            env_config,
138        };
139
140        Ok(ctx)
141    }
142
143    /// The execution deadline for the current invocation.
144    pub fn deadline(&self) -> SystemTime {
145        SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline)
146    }
147}
148
149/// Extract the invocation request id from the incoming request.
150pub(crate) fn invoke_request_id(headers: &HeaderMap) -> Result<&str, ToStrError> {
151    headers
152        .get("lambda-runtime-aws-request-id")
153        .expect("missing lambda-runtime-aws-request-id header")
154        .to_str()
155}
156
157/// Incoming Lambda request containing the event payload and context.
158#[derive(Clone, Debug)]
159pub struct LambdaEvent<T> {
160    /// Event payload.
161    pub payload: T,
162    /// Invocation context.
163    pub context: Context,
164}
165
166impl<T> LambdaEvent<T> {
167    /// Creates a new Lambda request
168    pub fn new(payload: T, context: Context) -> Self {
169        Self { payload, context }
170    }
171
172    /// Split the Lambda event into its payload and context.
173    pub fn into_parts(self) -> (T, Context) {
174        (self.payload, self.context)
175    }
176}
177
178/// Metadata prelude for a stream response.
179#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
180#[serde(rename_all = "camelCase")]
181pub struct MetadataPrelude {
182    #[serde(with = "http_serde::status_code")]
183    /// The HTTP status code.
184    pub status_code: StatusCode,
185    #[serde(with = "http_serde::header_map")]
186    /// The HTTP headers.
187    pub headers: HeaderMap,
188    /// The HTTP cookies.
189    pub cookies: Vec<String>,
190}
191
192pub trait ToStreamErrorTrailer {
193    /// Convert the hyper error into a stream error trailer.
194    fn to_tailer(&self) -> String;
195}
196
197impl ToStreamErrorTrailer for Error {
198    fn to_tailer(&self) -> String {
199        format!(
200            "Lambda-Runtime-Function-Error-Type: Runtime.StreamError\r\nLambda-Runtime-Function-Error-Body: {}\r\n",
201            BASE64_STANDARD.encode(self.to_string())
202        )
203    }
204}
205
206/// A streaming response that contains the metadata prelude and the stream of bytes that will be
207/// sent to the client.
208#[derive(Debug)]
209pub struct StreamResponse<S> {
210    ///  The metadata prelude.
211    pub metadata_prelude: MetadataPrelude,
212    /// The stream of bytes that will be sent to the client.
213    pub stream: S,
214}
215
216/// An enum representing the response of a function that can return either a buffered
217/// response of type `B` or a streaming response of type `S`.
218pub enum FunctionResponse<B, S> {
219    /// A buffered response containing the entire payload of the response. This is useful
220    /// for responses that can be processed quickly and have a relatively small payload size(<= 6MB).
221    BufferedResponse(B),
222    /// A streaming response that delivers the payload incrementally. This is useful for
223    /// large payloads(> 6MB) or responses that take a long time to generate. The client can start
224    /// processing the response as soon as the first chunk is available, without waiting
225    /// for the entire payload to be generated.
226    StreamingResponse(StreamResponse<S>),
227}
228
229/// a trait that can be implemented for any type that can be converted into a FunctionResponse.
230/// This allows us to use the `into` method to convert a type into a FunctionResponse.
231pub trait IntoFunctionResponse<B, S> {
232    /// Convert the type into a FunctionResponse.
233    fn into_response(self) -> FunctionResponse<B, S>;
234}
235
236impl<B, S> IntoFunctionResponse<B, S> for FunctionResponse<B, S> {
237    fn into_response(self) -> FunctionResponse<B, S> {
238        self
239    }
240}
241
242impl<B> IntoFunctionResponse<B, Body> for B
243where
244    B: Serialize,
245{
246    fn into_response(self) -> FunctionResponse<B, Body> {
247        FunctionResponse::BufferedResponse(self)
248    }
249}
250
251impl<S, D, E> IntoFunctionResponse<(), S> for StreamResponse<S>
252where
253    S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
254    D: Into<Bytes> + Send,
255    E: Into<Error> + Send + Debug,
256{
257    fn into_response(self) -> FunctionResponse<(), S> {
258        FunctionResponse::StreamingResponse(self)
259    }
260}
261
262impl<S, D, E> From<S> for StreamResponse<S>
263where
264    S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
265    D: Into<Bytes> + Send,
266    E: Into<Error> + Send + Debug,
267{
268    fn from(value: S) -> Self {
269        StreamResponse {
270            metadata_prelude: Default::default(),
271            stream: value,
272        }
273    }
274}
275
#[cfg(test)]
mod test {
    use super::*;
    use crate::Config;
    use std::sync::Arc;

    /// Build a header map from static `(name, value)` pairs, inserted in order.
    fn headers_of(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, HeaderValue::from_static(value));
        }
        headers
    }

    /// Call `Context::new` with the request id `"id"` and a default config.
    fn build_context(headers: &HeaderMap) -> Result<Context, Error> {
        Context::new("id", Arc::new(Config::default()), headers)
    }

    #[test]
    fn context_with_expected_values_and_types_resolves() {
        let headers = headers_of(&[
            ("lambda-runtime-aws-request-id", "my-id"),
            ("lambda-runtime-deadline-ms", "123"),
            ("lambda-runtime-invoked-function-arn", "arn::myarn"),
            ("lambda-runtime-trace-id", "arn::myarn"),
        ]);

        assert!(build_context(&headers).is_ok());
    }

    #[test]
    fn context_with_certain_missing_headers_still_resolves() {
        let headers = headers_of(&[
            ("lambda-runtime-aws-request-id", "my-id"),
            ("lambda-runtime-deadline-ms", "123"),
        ]);

        assert!(build_context(&headers).is_ok());
    }

    #[test]
    fn context_with_client_context_resolves() {
        let key_value: HashMap<String, String> =
            std::iter::once(("key".to_string(), "value".to_string())).collect();
        let client_context = ClientContext {
            client: ClientApplication {
                installation_id: String::new(),
                app_title: String::new(),
                app_version_name: String::new(),
                app_version_code: String::new(),
                app_package_name: String::new(),
            },
            custom: key_value.clone(),
            environment: key_value,
        };
        let encoded = serde_json::to_string(&client_context).unwrap();

        let mut headers = headers_of(&[
            ("lambda-runtime-aws-request-id", "my-id"),
            ("lambda-runtime-deadline-ms", "123"),
        ]);
        headers.insert(
            "lambda-runtime-client-context",
            HeaderValue::from_str(&encoded).unwrap(),
        );

        let context = build_context(&headers).unwrap();
        assert_eq!(context.client_context, Some(client_context));
    }

    #[test]
    fn context_with_empty_client_context_resolves() {
        let headers = headers_of(&[
            ("lambda-runtime-aws-request-id", "my-id"),
            ("lambda-runtime-deadline-ms", "123"),
            ("lambda-runtime-client-context", "{}"),
        ]);

        let context = build_context(&headers).unwrap();
        assert!(context.client_context.is_some());
    }

    #[test]
    fn context_with_identity_resolves() {
        let cognito_identity = CognitoIdentity {
            identity_id: String::new(),
            identity_pool_id: String::new(),
        };
        let encoded = serde_json::to_string(&cognito_identity).unwrap();

        let mut headers = headers_of(&[
            ("lambda-runtime-aws-request-id", "my-id"),
            ("lambda-runtime-deadline-ms", "123"),
        ]);
        headers.insert(
            "lambda-runtime-cognito-identity",
            HeaderValue::from_str(&encoded).unwrap(),
        );

        let context = build_context(&headers).unwrap();
        assert_eq!(context.identity, Some(cognito_identity));
    }

    #[test]
    fn context_with_bad_deadline_type_is_err() {
        let headers = headers_of(&[
            ("lambda-runtime-aws-request-id", "my-id"),
            ("lambda-runtime-deadline-ms", "BAD-Type,not <u64>"),
            ("lambda-runtime-invoked-function-arn", "arn::myarn"),
            ("lambda-runtime-trace-id", "arn::myarn"),
        ]);

        assert!(build_context(&headers).is_err());
    }

    #[test]
    fn context_with_bad_client_context_is_err() {
        let headers = headers_of(&[
            ("lambda-runtime-aws-request-id", "my-id"),
            ("lambda-runtime-deadline-ms", "123"),
            ("lambda-runtime-client-context", "BAD-Type,not JSON"),
        ]);

        assert!(build_context(&headers).is_err());
    }

    #[test]
    fn context_with_empty_identity_is_err() {
        let headers = headers_of(&[
            ("lambda-runtime-aws-request-id", "my-id"),
            ("lambda-runtime-deadline-ms", "123"),
            ("lambda-runtime-cognito-identity", "{}"),
        ]);

        assert!(build_context(&headers).is_err());
    }

    #[test]
    fn context_with_bad_identity_is_err() {
        let headers = headers_of(&[
            ("lambda-runtime-aws-request-id", "my-id"),
            ("lambda-runtime-deadline-ms", "123"),
            ("lambda-runtime-cognito-identity", "BAD-Type,not JSON"),
        ]);

        assert!(build_context(&headers).is_err());
    }

    #[test]
    #[should_panic]
    fn context_with_missing_deadline_should_panic() {
        // No deadline header: `Context::new` is expected to panic.
        let headers = headers_of(&[
            ("lambda-runtime-aws-request-id", "my-id"),
            ("lambda-runtime-invoked-function-arn", "arn::myarn"),
            ("lambda-runtime-trace-id", "arn::myarn"),
        ]);

        let _ = build_context(&headers);
    }

    #[test]
    fn invoke_request_id_should_not_panic() {
        let headers = headers_of(&[
            ("lambda-runtime-aws-request-id", "my-id"),
            ("lambda-runtime-deadline-ms", "123"),
            ("lambda-runtime-invoked-function-arn", "arn::myarn"),
            ("lambda-runtime-trace-id", "arn::myarn"),
        ]);

        let _ = invoke_request_id(&headers);
    }

    #[test]
    #[should_panic]
    fn invoke_request_id_should_panic() {
        // No request id header: `invoke_request_id` is expected to panic.
        let headers = headers_of(&[
            ("lambda-runtime-deadline-ms", "123"),
            ("lambda-runtime-invoked-function-arn", "arn::myarn"),
            ("lambda-runtime-trace-id", "arn::myarn"),
        ]);

        let _ = invoke_request_id(&headers);
    }

    #[test]
    fn serde_metadata_prelude() {
        let metadata_prelude = MetadataPrelude {
            status_code: StatusCode::OK,
            headers: headers_of(&[("key", "val")]),
            cookies: vec!["cookie".to_string()],
        };

        let serialized = serde_json::to_string(&metadata_prelude).unwrap();
        let round_tripped: MetadataPrelude = serde_json::from_str(&serialized).unwrap();

        assert_eq!(metadata_prelude, round_tripped);
    }
}