lambda_runtime_client/
client.rs

1use crate::error::{ApiError, ApiErrorKind, ErrorResponse};
2use failure::ResultExt;
3use hyper::{
4    client::HttpConnector,
5    header::{self, HeaderMap, HeaderValue},
6    rt::{Future, Stream},
7    Body, Client, Method, Request, Uri,
8};
9use log::*;
10use serde_derive::*;
11use serde_json;
12use std::{collections::HashMap, fmt};
13use tokio::runtime::Runtime;
14
/// Version segment interpolated into every Runtime API URI this client builds.
const RUNTIME_API_VERSION: &str = "2018-06-01";
/// `Content-Type` value sent when posting an event response.
const API_CONTENT_TYPE: &str = "application/json";
/// `Content-Type` value sent when posting an error payload.
const API_ERROR_CONTENT_TYPE: &str = "application/vnd.aws.lambda.error+json";
/// Name of the header that carries the error type on error requests.
const RUNTIME_ERROR_HEADER: &str = "Lambda-Runtime-Function-Error-Type";
// TODO: Perhaps use a macro to generate this
/// User agent used when `RuntimeClient::new` is not given one.
const DEFAULT_AGENT: &str = "AWS_Lambda_Rust";
21
/// Enum of the headers returned by Lambda's `/next` API call.
///
/// The enum is a field-less tag, so it derives the inexpensive standard
/// traits: values can be copied, compared, hashed and debug-printed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LambdaHeaders {
    /// The AWS request ID
    RequestId,
    /// The ARN of the Lambda function being invoked
    FunctionArn,
    /// The X-Ray trace ID generated for this invocation
    TraceId,
    /// The deadline for the function execution in milliseconds
    Deadline,
    /// The client context header. This field is populated when the function
    /// is invoked from a mobile client.
    ClientContext,
    /// The Cognito Identity context for the invocation. This field is populated
    /// when the function is invoked with AWS credentials obtained from Cognito
    /// Identity.
    CognitoIdentity,
}
40
41impl LambdaHeaders {
42    /// Returns the `str` representation of the header.
43    fn as_str(&self) -> &'static str {
44        match self {
45            LambdaHeaders::RequestId => "Lambda-Runtime-Aws-Request-Id",
46            LambdaHeaders::FunctionArn => "Lambda-Runtime-Invoked-Function-Arn",
47            LambdaHeaders::TraceId => "Lambda-Runtime-Trace-Id",
48            LambdaHeaders::Deadline => "Lambda-Runtime-Deadline-Ms",
49            LambdaHeaders::ClientContext => "Lambda-Runtime-Client-Context",
50            LambdaHeaders::CognitoIdentity => "Lambda-Runtime-Cognito-Identity",
51        }
52    }
53}
54
55impl fmt::Display for LambdaHeaders {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.write_str(self.as_str())
58    }
59}
60
61/// AWS Moble SDK client properties
62#[derive(Deserialize, Clone)]
63pub struct ClientApplication {
64    /// The mobile app installation id
65    #[serde(rename = "installationId")]
66    pub installation_id: String,
67    /// The app title for the mobile app as registered with AWS' mobile services.
68    #[serde(rename = "appTitle")]
69    pub app_title: String,
70    /// The version name of the application as registered with AWS' mobile services.
71    #[serde(rename = "appVersionName")]
72    pub app_version_name: String,
73    /// The app version code.
74    #[serde(rename = "appVersionCode")]
75    pub app_version_code: String,
76    /// The package name for the mobile application invoking the function
77    #[serde(rename = "appPackageName")]
78    pub app_package_name: String,
79}
80
81/// Client context sent by the AWS Mobile SDK.
82#[derive(Deserialize, Clone)]
83pub struct ClientContext {
84    /// Information about the mobile application invoking the function.
85    pub client: ClientApplication,
86    /// Custom properties attached to the mobile event context.
87    pub custom: HashMap<String, String>,
88    /// Environment settings from the mobile client.
89    pub environment: HashMap<String, String>,
90}
91
92#[derive(Deserialize, Clone)]
93/// Cognito identity information sent with the event
94pub struct CognitoIdentity {
95    /// The unique identity id for the Cognito credentials invoking the function.
96    pub identity_id: String,
97    /// The identity pool id the caller is "registered" with.
98    pub identity_pool_id: String,
99}
100
/// The Lambda function execution context. The values in this struct
/// are populated using the [Lambda environment variables](https://docs.aws.amazon.com/lambda/latest/dg/current-supported-versions.html)
/// and the headers returned by the poll request to the Runtime APIs.
/// A new instance of `EventContext` is built for every event returned by the
/// poll request.
#[derive(Clone)]
pub struct EventContext {
    /// The ARN of the Lambda function being invoked.
    pub invoked_function_arn: String,
    /// The AWS request ID generated by the Lambda service.
    pub aws_request_id: String,
    /// The X-Ray trace ID for the current invocation. `None` when the header
    /// is absent or cannot be read as a string.
    pub xray_trace_id: Option<String>,
    /// The execution deadline for the current invocation in milliseconds.
    pub deadline: i64,
    /// The client context object sent by the AWS mobile SDK. This field is
    /// empty unless the function is invoked using an AWS mobile SDK.
    pub client_context: Option<ClientContext>,
    /// The Cognito identity that invoked the function. This field is empty
    /// unless the invocation request to the Lambda APIs was made using AWS
    /// credentials issued by Amazon Cognito Identity Pools.
    pub identity: Option<CognitoIdentity>,
}
123
/// Used by the Runtime to communicate with the internal endpoint.
pub struct RuntimeClient {
    /// Tokio runtime whose executor was handed to `http_client`; stored here so
    /// the client owns it.
    _runtime: Runtime,
    /// Hyper client used for every call to the Runtime APIs.
    http_client: Client<HttpConnector, Body>,
    /// Pre-parsed URI of the `/runtime/invocation/next` endpoint.
    next_endpoint: Uri,
    /// Value sent in the `User-Agent` header of POST requests.
    runtime_agent: String,
    /// Host (and port) of the Runtime API, as given to `new`.
    host: String,
}
132
133impl<'ev> RuntimeClient {
134    /// Creates a new instance of the Runtime APIclient SDK. The http client has timeouts disabled and
135    /// will always send a `Connection: keep-alive` header. Optionally, the runtime client can receive
136    /// a user agent string. This string is used to make requests to the runtime APIs and is used to
137    /// identify the runtime being used by the function. For example, the `lambda_runtime_core` crate
138    /// uses `AWS_Lambda_Rust/0.1.0 (rustc/1.31.1-stable)`. The runtime client can also receive an
139    /// instance of Tokio Runtime to use.
140    pub fn new(host: &str, agent: Option<String>, runtime: Option<Runtime>) -> Result<Self, ApiError> {
141        debug!("Starting new HttpRuntimeClient for {}", host);
142        let runtime_agent = match agent {
143            Some(a) => a,
144            None => DEFAULT_AGENT.to_owned(),
145        };
146
147        // start a tokio core main event loop for hyper
148        let runtime = match runtime {
149            Some(r) => r,
150            None => Runtime::new().context(ApiErrorKind::Unrecoverable("Could not initialize runtime".to_string()))?,
151        };
152
153        let http_client = Client::builder().executor(runtime.executor()).build_http();
154        // we cached the parsed Uri since this never changes.
155        let next_endpoint = format!("http://{}/{}/runtime/invocation/next", host, RUNTIME_API_VERSION)
156            .parse::<Uri>()
157            .context(ApiErrorKind::Unrecoverable("Could not parse API uri".to_string()))?;
158
159        Ok(RuntimeClient {
160            _runtime: runtime,
161            http_client,
162            next_endpoint,
163            runtime_agent,
164            host: host.to_owned(),
165        })
166    }
167}
168
169impl<'ev> RuntimeClient {
170    /// Polls for new events to the Runtime APIs.
171    pub fn next_event(&self) -> Result<(Vec<u8>, EventContext), ApiError> {
172        trace!("Polling for next event");
173
174        // We wait instead of processing the future asynchronously because AWS Lambda
175        // itself enforces only one event per container at a time. No point in taking on
176        // the additional complexity.
177        let resp = self
178            .http_client
179            .get(self.next_endpoint.clone())
180            .wait()
181            .context(ApiErrorKind::Unrecoverable("Could not fetch next event".to_string()))?;
182
183        if resp.status().is_client_error() {
184            error!(
185                "Runtime API returned client error when polling for new events: {}",
186                resp.status()
187            );
188            Err(ApiErrorKind::Recoverable(format!(
189                "Error {} when polling for events",
190                resp.status()
191            )))?;
192        }
193        if resp.status().is_server_error() {
194            error!(
195                "Runtime API returned server error when polling for new events: {}",
196                resp.status()
197            );
198            Err(ApiErrorKind::Unrecoverable(
199                "Server error when polling for new events".to_string(),
200            ))?;
201        }
202        let ctx = self.get_event_context(&resp.headers())?;
203        let out = resp
204            .into_body()
205            .concat2()
206            .wait()
207            .context(ApiErrorKind::Recoverable("Could not read event boxy".to_string()))?;
208        let buf = out.into_bytes().to_vec();
209
210        trace!(
211            "Received new event for request id {}. Event length {} bytes",
212            ctx.aws_request_id,
213            buf.len()
214        );
215        Ok((buf, ctx))
216    }
217
218    /// Calls the Lambda Runtime APIs to submit a response to an event. In this function we treat
219    /// all errors from the API as an unrecoverable error. This is because the API returns
220    /// 4xx errors for responses that are too long. In that case, we simply log the output and fail.
221    ///
222    /// # Arguments
223    ///
224    /// * `request_id` The request id associated with the event we are serving the response for.
225    ///                This is returned as a header from the poll (`/next`) API.
226    /// * `output` The object be sent back to the Runtime APIs as a response.
227    ///
228    /// # Returns
229    /// A `Result` object containing a bool return value for the call or an `error::ApiError` instance.
230    pub fn event_response(&self, request_id: &str, output: &[u8]) -> Result<(), ApiError> {
231        trace!(
232            "Posting response for request {} to Runtime API. Response length {} bytes",
233            request_id,
234            output.len()
235        );
236        let uri = format!(
237            "http://{}/{}/runtime/invocation/{}/response",
238            self.host, RUNTIME_API_VERSION, request_id
239        )
240        .parse::<Uri>()
241        .context(ApiErrorKind::Unrecoverable(
242            "Could not generate response uri".to_owned(),
243        ))?;
244        let req = self.get_runtime_post_request(&uri, output);
245
246        let resp = self
247            .http_client
248            .request(req)
249            .wait()
250            .context(ApiErrorKind::Recoverable("Could not post event response".to_string()))?;
251        if !resp.status().is_success() {
252            error!(
253                "Error from Runtime API when posting response for request {}: {}",
254                request_id,
255                resp.status()
256            );
257            Err(ApiErrorKind::Recoverable(format!(
258                "Error {} while sending response",
259                resp.status()
260            )))?;
261        }
262        trace!("Posted response to Runtime API for request {}", request_id);
263        Ok(())
264    }
265
266    /// Calls Lambda's Runtime APIs to send an error generated by the `Handler`. Because it's rust,
267    /// the error type for lambda is always `handled`.
268    ///
269    /// # Arguments
270    ///
271    /// * `request_id` The request id associated with the event we are serving the error for.
272    /// * `e` An instance of `errors::HandlerError` generated by the handler function.
273    ///
274    /// # Returns
275    /// A `Result` object containing a bool return value for the call or an `error::ApiError` instance.
276    pub fn event_error(&self, request_id: &str, e: &ErrorResponse) -> Result<(), ApiError> {
277        trace!(
278            "Posting error to runtime API for request {}: {}",
279            request_id,
280            e.error_message
281        );
282        let uri = format!(
283            "http://{}/{}/runtime/invocation/{}/error",
284            self.host, RUNTIME_API_VERSION, request_id
285        )
286        .parse::<Uri>()
287        .context(ApiErrorKind::Unrecoverable(
288            "Could not generate response uri".to_owned(),
289        ))?;
290        let req = self.get_runtime_error_request(&uri, &e);
291
292        let resp = self.http_client.request(req).wait().context(ApiErrorKind::Recoverable(
293            "Could not post event error response".to_string(),
294        ))?;
295        if !resp.status().is_success() {
296            error!(
297                "Error from Runtime API when posting error response for request {}: {}",
298                request_id,
299                resp.status()
300            );
301            Err(ApiErrorKind::Recoverable(format!(
302                "Error {} while sending response",
303                resp.status()
304            )))?;
305        }
306        trace!("Posted error response for request id {}", request_id);
307        Ok(())
308    }
309
310    /// Calls the Runtime APIs to report a failure during the init process.
311    /// The contents of the error (`e`) parmeter are passed to the Runtime APIs
312    /// using the private `to_response()` method.
313    ///
314    /// # Arguments
315    ///
316    /// * `e` An instance of `errors::RuntimeError`.
317    ///
318    /// # Panics
319    /// If it cannot send the init error. In this case we panic to force the runtime
320    /// to restart.
321    pub fn fail_init(&self, e: &ErrorResponse) {
322        error!("Calling fail_init Runtime API: {}", e.error_message);
323        let uri = format!("http://{}/{}/runtime/init/error", self.host, RUNTIME_API_VERSION)
324            .parse::<Uri>()
325            .map_err(|e| {
326                error!("Could not parse fail init URI: {}", e);
327                panic!("Killing runtime");
328            });
329        let req = self.get_runtime_error_request(&uri.unwrap(), &e);
330
331        self.http_client
332            .request(req)
333            .wait()
334            .map_err(|e| {
335                error!("Error while sending init failed message: {}", e);
336                panic!("Error while sending init failed message: {}", e);
337            })
338            .map(|resp| {
339                info!("Successfully sent error response to the runtime API: {:?}", resp);
340            })
341            .expect("Could not complete init_fail request");
342    }
343
344    /// Returns the endpoint configured for this HTTP Runtime client.
345    pub fn get_endpoint(&self) -> &str {
346        &self.host
347    }
348
349    /// Creates a Hyper `Request` object for the given `Uri` and `Body`. Sets the
350    /// HTTP method to `POST` and the `Content-Type` header value to `application/json`.
351    ///
352    /// # Arguments
353    ///
354    /// * `uri` A `Uri` reference target for the request
355    /// * `body` The content of the post request. This parameter must not be null
356    ///
357    /// # Returns
358    /// A Populated Hyper `Request` object.
359    fn get_runtime_post_request(&self, uri: &Uri, body: &[u8]) -> Request<Body> {
360        Request::builder()
361            .method(Method::POST)
362            .uri(uri.clone())
363            .header(header::CONTENT_TYPE, header::HeaderValue::from_static(API_CONTENT_TYPE))
364            .header(header::USER_AGENT, self.runtime_agent.clone())
365            .body(Body::from(body.to_owned()))
366            .unwrap()
367    }
368
369    fn get_runtime_error_request(&self, uri: &Uri, e: &ErrorResponse) -> Request<Body> {
370        let body = serde_json::to_vec(&e).expect("Could not turn error object into response JSON");
371        Request::builder()
372            .method(Method::POST)
373            .uri(uri.clone())
374            .header(
375                header::CONTENT_TYPE,
376                header::HeaderValue::from_static(API_ERROR_CONTENT_TYPE),
377            )
378            .header(header::USER_AGENT, self.runtime_agent.clone())
379            // this header is static for the runtime APIs and it's likely to go away in the future.
380            .header(RUNTIME_ERROR_HEADER, HeaderValue::from_static("Unhandled"))
381            .body(Body::from(body))
382            .unwrap()
383    }
384
385    /// Creates an `EventContext` object based on the response returned by the Runtime
386    /// API `/next` endpoint.
387    ///
388    /// # Arguments
389    ///
390    /// * `resp` The response returned by the Runtime APIs endpoint.
391    ///
392    /// # Returns
393    /// A `Result` containing the populated `EventContext` or an `ApiError` if the required headers
394    /// were not present or the client context and cognito identity could not be parsed from the
395    /// JSON string.
396    fn get_event_context(&self, headers: &HeaderMap<HeaderValue>) -> Result<EventContext, ApiError> {
397        // let headers = resp.headers();
398
399        let aws_request_id = header_string(
400            headers.get(LambdaHeaders::RequestId.as_str()),
401            &LambdaHeaders::RequestId,
402        )?;
403        let invoked_function_arn = header_string(
404            headers.get(LambdaHeaders::FunctionArn.as_str()),
405            &LambdaHeaders::FunctionArn,
406        )?;
407        let xray_trace_id = match headers.get(LambdaHeaders::TraceId.as_str()) {
408            Some(trace_id) => match trace_id.to_str() {
409                Ok(trace_str) => Some(trace_str.to_owned()),
410                Err(e) => {
411                    // we do not fail on this error.
412                    error!("Could not parse X-Ray trace id as string: {}", e);
413                    None
414                }
415            },
416            None => None,
417        };
418        let deadline = header_string(headers.get(LambdaHeaders::Deadline.as_str()), &LambdaHeaders::Deadline)?
419            .parse::<i64>()
420            .context(ApiErrorKind::Recoverable(
421                "Could not parse deadline header value to int".to_string(),
422            ))?;
423
424        let mut ctx = EventContext {
425            aws_request_id,
426            invoked_function_arn,
427            xray_trace_id,
428            deadline,
429            client_context: Option::default(),
430            identity: Option::default(),
431        };
432
433        if let Some(ctx_json) = headers.get(LambdaHeaders::ClientContext.as_str()) {
434            let ctx_json = ctx_json.to_str().context(ApiErrorKind::Recoverable(
435                "Could not convert context header content to string".to_string(),
436            ))?;
437            trace!("Found Client Context in response headers: {}", ctx_json);
438            let ctx_value: ClientContext = serde_json::from_str(&ctx_json).context(ApiErrorKind::Recoverable(
439                "Could not parse client context value as json object".to_string(),
440            ))?;
441            ctx.client_context = Option::from(ctx_value);
442        };
443
444        if let Some(cognito_json) = headers.get(LambdaHeaders::CognitoIdentity.as_str()) {
445            let cognito_json = cognito_json.to_str().context(ApiErrorKind::Recoverable(
446                "Could not convert congnito context header content to string".to_string(),
447            ))?;
448            trace!("Found Cognito Identity in response headers: {}", cognito_json);
449            let identity_value: CognitoIdentity = serde_json::from_str(&cognito_json).context(
450                ApiErrorKind::Recoverable("Could not parse cognito context value as json object".to_string()),
451            )?;
452            ctx.identity = Option::from(identity_value);
453        };
454
455        Ok(ctx)
456    }
457}
458
459fn header_string(value: Option<&HeaderValue>, header_type: &LambdaHeaders) -> Result<String, ApiError> {
460    match value {
461        Some(value_str) => Ok(value_str
462            .to_str()
463            .context(ApiErrorKind::Recoverable(format!(
464                "Could not parse {} header",
465                header_type
466            )))?
467            .to_owned()),
468        None => {
469            error!("Response headers do not contain {} header", header_type);
470            Err(ApiErrorKind::Recoverable(format!("Missing {} header", header_type)))?
471        }
472    }
473}
474
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use chrono::{Duration, Utc};

    /// Builds a header map holding the request id, function ARN, trace id and
    /// a deadline ten seconds in the future.
    fn get_headers() -> HeaderMap<HeaderValue> {
        let deadline_ms = (Utc::now() + Duration::seconds(10)).timestamp_millis().to_string();

        let mut headers: HeaderMap<HeaderValue> = HeaderMap::new();
        headers.insert(LambdaHeaders::RequestId.as_str(), HeaderValue::from_static("req_id"));
        headers.insert(LambdaHeaders::FunctionArn.as_str(), HeaderValue::from_static("func_arn"));
        headers.insert(LambdaHeaders::TraceId.as_str(), HeaderValue::from_static("trace"));
        headers.insert(
            LambdaHeaders::Deadline.as_str(),
            HeaderValue::from_str(&deadline_ms).unwrap(),
        );
        headers
    }

    /// A missing trace id header must not fail context creation.
    #[test]
    fn get_event_context_with_empty_trace_id() {
        let client = RuntimeClient::new("localhost:8081", None, None).expect("Could not initialize runtime client");
        let mut headers = get_headers();
        headers.remove(LambdaHeaders::TraceId.as_str());

        let outcome = client.get_event_context(&headers);
        assert!(!outcome.is_err());

        let ctx = outcome.unwrap();
        assert!(ctx.xray_trace_id.is_none());
        assert_eq!("req_id", ctx.aws_request_id);
    }

    /// A present trace id header is copied into the context.
    #[test]
    fn get_event_context_populates_trace_id_when_present() {
        let client = RuntimeClient::new("localhost:8081", None, None).expect("Could not initialize runtime client");
        let headers = get_headers();

        let outcome = client.get_event_context(&headers);
        assert!(!outcome.is_err());
        assert_eq!(Some("trace".to_owned()), outcome.unwrap().xray_trace_id);
    }
}