Skip to main content

lambda_runtime/
types.rs

1use crate::{Error, RefConfig};
2use base64::prelude::*;
3use bytes::Bytes;
4use http::{header::ToStrError, HeaderMap, HeaderValue, StatusCode};
5use lambda_runtime_api_client::body::Body;
6use serde::{Deserialize, Serialize};
7use std::{
8    collections::HashMap,
9    fmt::Debug,
10    time::{Duration, SystemTime},
11};
12use tokio_stream::Stream;
13
14/// Client context sent by the AWS Mobile SDK.
15#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
16pub struct ClientContext {
17    /// Information about the mobile application invoking the function.
18    #[serde(default)]
19    pub client: ClientApplication,
20    /// Custom properties attached to the mobile event context.
21    #[serde(default)]
22    pub custom: HashMap<String, String>,
23    /// Environment settings from the mobile client.
24    #[serde(default)]
25    pub environment: HashMap<String, String>,
26}
27
28/// AWS Mobile SDK client fields.
29#[derive(Serialize, Deserialize, Default, Clone, Debug, Eq, PartialEq)]
30#[serde(rename_all = "camelCase")]
31pub struct ClientApplication {
32    /// The mobile app installation id
33    #[serde(alias = "installation_id")]
34    pub installation_id: String,
35    /// The app title for the mobile app as registered with AWS' mobile services.
36    #[serde(alias = "app_title")]
37    pub app_title: String,
38    /// The version name of the application as registered with AWS' mobile services.
39    #[serde(alias = "app_version_name")]
40    pub app_version_name: String,
41    /// The app version code.
42    #[serde(alias = "app_version_code")]
43    pub app_version_code: String,
44    /// The package name for the mobile application invoking the function
45    #[serde(alias = "app_package_name")]
46    pub app_package_name: String,
47}
48
49/// Cognito identity information sent with the event
50#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
51#[serde(rename_all = "camelCase")]
52pub struct CognitoIdentity {
53    /// The unique identity id for the Cognito credentials invoking the function.
54    #[serde(alias = "cognitoIdentityId", alias = "identity_id")]
55    pub identity_id: String,
56    /// The identity pool id the caller is "registered" with.
57    #[serde(alias = "cognitoIdentityPoolId", alias = "identity_pool_id")]
58    pub identity_pool_id: String,
59}
60
61/// The Lambda function execution context. The values in this struct
62/// are populated using the [Lambda environment variables](https://docs.aws.amazon.com/lambda/latest/dg/current-supported-versions.html)
63/// and [the headers returned by the poll request to the Runtime APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html#runtimes-api-next).
64#[non_exhaustive]
65#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
66pub struct Context {
67    /// The AWS request ID generated by the Lambda service.
68    pub request_id: String,
69    /// The execution deadline for the current invocation in milliseconds.
70    pub deadline: u64,
71    /// The ARN of the Lambda function being invoked.
72    pub invoked_function_arn: String,
73    /// The X-Ray trace ID for the current invocation.
74    pub xray_trace_id: Option<String>,
75    /// The client context object sent by the AWS mobile SDK. This field is
76    /// empty unless the function is invoked using an AWS mobile SDK.
77    pub client_context: Option<ClientContext>,
78    /// The Cognito identity that invoked the function. This field is empty
79    /// unless the invocation request to the Lambda APIs was made using AWS
80    /// credentials issues by Amazon Cognito Identity Pools.
81    pub identity: Option<CognitoIdentity>,
82    /// The tenant ID for the current invocation.
83    pub tenant_id: Option<String>,
84    /// Lambda function configuration from the local environment variables.
85    /// Includes information such as the function name, memory allocation,
86    /// version, and log streams.
87    pub env_config: RefConfig,
88}
89
90impl Default for Context {
91    fn default() -> Context {
92        Context {
93            request_id: "".to_owned(),
94            deadline: 0,
95            invoked_function_arn: "".to_owned(),
96            xray_trace_id: None,
97            client_context: None,
98            identity: None,
99            tenant_id: None,
100            env_config: std::sync::Arc::new(crate::Config::default()),
101        }
102    }
103}
104
105impl Context {
106    /// Create a new [Context] struct based on the function configuration
107    /// and the incoming request data.
108    pub fn new(request_id: &str, env_config: RefConfig, headers: &HeaderMap) -> Result<Self, Error> {
109        let client_context: Option<ClientContext> = if let Some(value) = headers.get("lambda-runtime-client-context") {
110            let raw = value.to_str()?;
111            if raw.is_empty() {
112                None
113            } else {
114                Some(serde_json::from_str(raw)?)
115            }
116        } else {
117            None
118        };
119
120        let identity: Option<CognitoIdentity> = if let Some(value) = headers.get("lambda-runtime-cognito-identity") {
121            let raw = value.to_str()?;
122            if raw.is_empty() {
123                None
124            } else {
125                Some(serde_json::from_str(raw)?)
126            }
127        } else {
128            None
129        };
130
131        let ctx = Context {
132            request_id: request_id.to_owned(),
133            deadline: headers
134                .get("lambda-runtime-deadline-ms")
135                .expect("missing lambda-runtime-deadline-ms header")
136                .to_str()?
137                .parse::<u64>()?,
138            invoked_function_arn: headers
139                .get("lambda-runtime-invoked-function-arn")
140                .unwrap_or(&HeaderValue::from_static(
141                    "No header lambda-runtime-invoked-function-arn found.",
142                ))
143                .to_str()?
144                .to_owned(),
145            xray_trace_id: headers
146                .get("lambda-runtime-trace-id")
147                .map(|v| String::from_utf8_lossy(v.as_bytes()).to_string()),
148            client_context,
149            identity,
150            tenant_id: headers
151                .get("lambda-runtime-aws-tenant-id")
152                .map(|v| String::from_utf8_lossy(v.as_bytes()).to_string()),
153            env_config,
154        };
155
156        Ok(ctx)
157    }
158
159    /// The execution deadline for the current invocation.
160    pub fn deadline(&self) -> SystemTime {
161        SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline)
162    }
163}
164
165/// Extract the invocation request id from the incoming request.
166pub(crate) fn invoke_request_id(headers: &HeaderMap) -> Result<&str, ToStrError> {
167    headers
168        .get("lambda-runtime-aws-request-id")
169        .expect("missing lambda-runtime-aws-request-id header")
170        .to_str()
171}
172
173/// Incoming Lambda request containing the event payload and context.
174#[derive(Clone, Debug)]
175pub struct LambdaEvent<T> {
176    /// Event payload.
177    pub payload: T,
178    /// Invocation context.
179    pub context: Context,
180}
181
182impl<T> LambdaEvent<T> {
183    /// Creates a new Lambda request
184    pub fn new(payload: T, context: Context) -> Self {
185        Self { payload, context }
186    }
187
188    /// Split the Lambda event into its payload and context.
189    pub fn into_parts(self) -> (T, Context) {
190        (self.payload, self.context)
191    }
192}
193
194/// Metadata prelude for a stream response.
195#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
196#[serde(rename_all = "camelCase")]
197pub struct MetadataPrelude {
198    #[serde(with = "http_serde::status_code")]
199    /// The HTTP status code.
200    pub status_code: StatusCode,
201    #[serde(with = "http_serde::header_map")]
202    /// The HTTP headers.
203    pub headers: HeaderMap,
204    /// The HTTP cookies.
205    pub cookies: Vec<String>,
206}
207
/// Formats an error as the trailer text emitted when a streaming response fails.
pub trait ToStreamErrorTrailer {
    /// Convert the error into a stream error trailer.
    ///
    /// NOTE: the method name reads `tailer` rather than `trailer`; it is part of the public
    /// trait, so it is kept as-is for compatibility.
    fn to_tailer(&self) -> String;
}
212
213impl ToStreamErrorTrailer for Error {
214    fn to_tailer(&self) -> String {
215        format!(
216            "Lambda-Runtime-Function-Error-Type: Runtime.StreamError\r\nLambda-Runtime-Function-Error-Body: {}\r\n",
217            BASE64_STANDARD.encode(self.to_string())
218        )
219    }
220}
221
222/// A streaming response that contains the metadata prelude and the stream of bytes that will be
223/// sent to the client.
224#[derive(Debug)]
225pub struct StreamResponse<S> {
226    ///  The metadata prelude.
227    pub metadata_prelude: MetadataPrelude,
228    /// The stream of bytes that will be sent to the client.
229    pub stream: S,
230}
231
232/// An enum representing the response of a function that can return either a buffered
233/// response of type `B` or a streaming response of type `S`.
234pub enum FunctionResponse<B, S> {
235    /// A buffered response containing the entire payload of the response. This is useful
236    /// for responses that can be processed quickly and have a relatively small payload size(<= 6MB).
237    BufferedResponse(B),
238    /// A streaming response that delivers the payload incrementally. This is useful for
239    /// large payloads(> 6MB) or responses that take a long time to generate. The client can start
240    /// processing the response as soon as the first chunk is available, without waiting
241    /// for the entire payload to be generated.
242    StreamingResponse(StreamResponse<S>),
243}
244
245/// a trait that can be implemented for any type that can be converted into a FunctionResponse.
246/// This allows us to use the `into` method to convert a type into a FunctionResponse.
247pub trait IntoFunctionResponse<B, S> {
248    /// Convert the type into a FunctionResponse.
249    fn into_response(self) -> FunctionResponse<B, S>;
250}
251
252impl<B, S> IntoFunctionResponse<B, S> for FunctionResponse<B, S> {
253    fn into_response(self) -> FunctionResponse<B, S> {
254        self
255    }
256}
257
258impl<B> IntoFunctionResponse<B, Body> for B
259where
260    B: Serialize,
261{
262    fn into_response(self) -> FunctionResponse<B, Body> {
263        FunctionResponse::BufferedResponse(self)
264    }
265}
266
267impl<S, D, E> IntoFunctionResponse<(), S> for StreamResponse<S>
268where
269    S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
270    D: Into<Bytes> + Send,
271    E: Into<Error> + Send + Debug,
272{
273    fn into_response(self) -> FunctionResponse<(), S> {
274        FunctionResponse::StreamingResponse(self)
275    }
276}
277
278impl<S, D, E> From<S> for StreamResponse<S>
279where
280    S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
281    D: Into<Bytes> + Send,
282    E: Into<Error> + Send + Debug,
283{
284    fn from(value: S) -> Self {
285        StreamResponse {
286            metadata_prelude: Default::default(),
287            stream: value,
288        }
289    }
290}
291
#[cfg(test)]
mod test {
    use super::*;
    use crate::Config;
    use std::sync::Arc;

    #[test]
    fn context_with_expected_values_and_types_resolves() {
        let config = Arc::new(Config::default());

        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        headers.insert(
            "lambda-runtime-invoked-function-arn",
            HeaderValue::from_static("arn::myarn"),
        );
        headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
        let tried = Context::new("id", config, &headers);
        assert!(tried.is_ok());
    }

    #[test]
    fn context_with_certain_missing_headers_still_resolves() {
        let config = Arc::new(Config::default());

        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        let tried = Context::new("id", config, &headers);
        assert!(tried.is_ok());
    }

    #[test]
    fn context_with_client_context_resolves() {
        let mut custom = HashMap::new();
        custom.insert("key".to_string(), "value".to_string());
        let mut environment = HashMap::new();
        environment.insert("key".to_string(), "value".to_string());
        let client_context = ClientContext {
            client: ClientApplication {
                installation_id: String::new(),
                app_title: String::new(),
                app_version_name: String::new(),
                app_version_code: String::new(),
                app_package_name: String::new(),
            },
            custom,
            environment,
        };
        let client_context_str = serde_json::to_string(&client_context).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        headers.insert(
            "lambda-runtime-client-context",
            HeaderValue::from_str(&client_context_str).unwrap(),
        );

        let config = Arc::new(Config::default());
        let tried = Context::new("id", config, &headers);
        assert!(tried.is_ok());
        let tried = tried.unwrap();
        assert!(tried.client_context.is_some());
        assert_eq!(tried.client_context.unwrap(), client_context);
    }

    #[test]
    fn context_with_empty_client_context_resolves() {
        let config = Arc::new(Config::default());
        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        headers.insert("lambda-runtime-client-context", HeaderValue::from_static("{}"));
        let tried = Context::new("id", config, &headers);
        assert!(tried.is_ok());
        assert!(tried.unwrap().client_context.is_some());
    }

    #[test]
    fn context_with_identity_resolves() {
        let config = Arc::new(Config::default());

        let cognito_identity = CognitoIdentity {
            identity_id: String::new(),
            identity_pool_id: String::new(),
        };
        let cognito_identity_str = serde_json::to_string(&cognito_identity).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        headers.insert(
            "lambda-runtime-cognito-identity",
            HeaderValue::from_str(&cognito_identity_str).unwrap(),
        );
        let tried = Context::new("id", config, &headers);
        assert!(tried.is_ok());
        let tried = tried.unwrap();
        assert!(tried.identity.is_some());
        assert_eq!(tried.identity.unwrap(), cognito_identity);
    }

    #[test]
    fn context_with_bad_deadline_type_is_err() {
        let config = Arc::new(Config::default());

        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert(
            "lambda-runtime-deadline-ms",
            HeaderValue::from_static("BAD-Type,not <u64>"),
        );
        headers.insert(
            "lambda-runtime-invoked-function-arn",
            HeaderValue::from_static("arn::myarn"),
        );
        headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
        let tried = Context::new("id", config, &headers);
        assert!(tried.is_err());
    }

    #[test]
    fn context_with_bad_client_context_is_err() {
        let config = Arc::new(Config::default());

        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        headers.insert(
            "lambda-runtime-client-context",
            HeaderValue::from_static("BAD-Type,not JSON"),
        );
        let tried = Context::new("id", config, &headers);
        assert!(tried.is_err());
    }

    #[test]
    fn context_with_empty_identity_is_err() {
        let config = Arc::new(Config::default());

        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        headers.insert("lambda-runtime-cognito-identity", HeaderValue::from_static("{}"));
        let tried = Context::new("id", config, &headers);
        assert!(tried.is_err());
    }

    #[test]
    fn context_with_bad_identity_is_err() {
        let config = Arc::new(Config::default());

        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        headers.insert(
            "lambda-runtime-cognito-identity",
            HeaderValue::from_static("BAD-Type,not JSON"),
        );
        let tried = Context::new("id", config, &headers);
        assert!(tried.is_err());
    }

    #[test]
    #[should_panic]
    fn context_with_missing_deadline_should_panic() {
        let config = Arc::new(Config::default());

        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert(
            "lambda-runtime-invoked-function-arn",
            HeaderValue::from_static("arn::myarn"),
        );
        headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
        let _ = Context::new("id", config, &headers);
    }

    #[test]
    fn context_deadline_is_milliseconds_since_unix_epoch() {
        let config = Arc::new(Config::default());
        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("1500"));

        let context = Context::new("id", config, &headers).unwrap();
        assert_eq!(context.deadline, 1500);
        assert_eq!(
            context.deadline(),
            SystemTime::UNIX_EPOCH + Duration::from_millis(1500)
        );
    }

    #[test]
    fn context_with_empty_json_headers_yields_none() {
        let config = Arc::new(Config::default());
        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        headers.insert("lambda-runtime-client-context", HeaderValue::from_static(""));
        headers.insert("lambda-runtime-cognito-identity", HeaderValue::from_static(""));

        let context = Context::new("id", config, &headers).unwrap();
        assert!(context.client_context.is_none());
        assert!(context.identity.is_none());
    }

    #[test]
    fn context_without_function_arn_uses_placeholder() {
        let config = Arc::new(Config::default());
        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));

        let context = Context::new("id", config, &headers).unwrap();
        assert_eq!(context.request_id, "id");
        assert_eq!(
            context.invoked_function_arn,
            "No header lambda-runtime-invoked-function-arn found."
        );
        assert_eq!(context.xray_trace_id, None);
    }

    #[test]
    fn client_application_accepts_snake_case_keys() {
        let json = r#"{"installation_id":"i","app_title":"t","app_version_name":"n","app_version_code":"c","app_package_name":"p"}"#;
        let app: ClientApplication = serde_json::from_str(json).unwrap();
        assert_eq!(app.installation_id, "i");
        assert_eq!(app.app_title, "t");
        assert_eq!(app.app_version_name, "n");
        assert_eq!(app.app_version_code, "c");
        assert_eq!(app.app_package_name, "p");
    }

    #[test]
    fn error_trailer_base64_encodes_message() {
        let err: Error = "boom".into();
        assert_eq!(
            err.to_tailer(),
            "Lambda-Runtime-Function-Error-Type: Runtime.StreamError\r\nLambda-Runtime-Function-Error-Body: Ym9vbQ==\r\n"
        );
    }

    #[test]
    fn invoke_request_id_should_not_panic() {
        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        headers.insert(
            "lambda-runtime-invoked-function-arn",
            HeaderValue::from_static("arn::myarn"),
        );
        headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));

        assert_eq!(invoke_request_id(&headers).unwrap(), "my-id");
    }

    #[test]
    #[should_panic]
    fn invoke_request_id_should_panic() {
        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        headers.insert(
            "lambda-runtime-invoked-function-arn",
            HeaderValue::from_static("arn::myarn"),
        );
        headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));

        let _ = invoke_request_id(&headers);
    }

    #[test]
    fn serde_metadata_prelude() {
        let metadata_prelude = MetadataPrelude {
            status_code: StatusCode::OK,
            headers: {
                let mut headers = HeaderMap::new();
                headers.insert("key", "val".parse().unwrap());
                headers
            },
            cookies: vec!["cookie".to_string()],
        };

        let serialized = serde_json::to_string(&metadata_prelude).unwrap();
        let deserialized: MetadataPrelude = serde_json::from_str(&serialized).unwrap();

        assert_eq!(metadata_prelude, deserialized);
    }

    #[test]
    fn context_with_tenant_id_resolves() {
        let config = Arc::new(Config::default());
        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
        headers.insert("lambda-runtime-aws-tenant-id", HeaderValue::from_static("tenant-123"));

        let context = Context::new("id", config, &headers).unwrap();
        assert_eq!(context.tenant_id, Some("tenant-123".to_string()));
    }

    #[test]
    fn context_without_tenant_id_resolves() {
        let config = Arc::new(Config::default());
        let mut headers = HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
        headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));

        let context = Context::new("id", config, &headers).unwrap();
        assert_eq!(context.tenant_id, None);
    }
}