1use crate::error::{ApiError, ApiErrorKind, ErrorResponse};
2use failure::ResultExt;
3use hyper::{
4 client::HttpConnector,
5 header::{self, HeaderMap, HeaderValue},
6 rt::{Future, Stream},
7 Body, Client, Method, Request, Uri,
8};
9use log::*;
10use serde_derive::*;
11use serde_json;
12use std::{collections::HashMap, fmt};
13use tokio::runtime::Runtime;
14
15const RUNTIME_API_VERSION: &str = "2018-06-01";
16const API_CONTENT_TYPE: &str = "application/json";
17const API_ERROR_CONTENT_TYPE: &str = "application/vnd.aws.lambda.error+json";
18const RUNTIME_ERROR_HEADER: &str = "Lambda-Runtime-Function-Error-Type";
19const DEFAULT_AGENT: &str = "AWS_Lambda_Rust";
21
/// Response headers returned by the Lambda runtime API that this client reads
/// when building an [`EventContext`].
///
/// Use [`LambdaHeaders::as_str`] (or the `Display` impl) to obtain the header
/// name as it appears on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LambdaHeaders {
    /// `Lambda-Runtime-Aws-Request-Id`: identifier of the current invocation.
    RequestId,
    /// `Lambda-Runtime-Invoked-Function-Arn`: ARN used to invoke the function.
    FunctionArn,
    /// `Lambda-Runtime-Trace-Id`: X-Ray trace id; treated as optional by the client.
    TraceId,
    /// `Lambda-Runtime-Deadline-Ms`: invocation deadline, parsed as an `i64`.
    Deadline,
    /// `Lambda-Runtime-Client-Context`: JSON-encoded client context; optional.
    ClientContext,
    /// `Lambda-Runtime-Cognito-Identity`: JSON-encoded Cognito identity; optional.
    CognitoIdentity,
}
40
41impl LambdaHeaders {
42 fn as_str(&self) -> &'static str {
44 match self {
45 LambdaHeaders::RequestId => "Lambda-Runtime-Aws-Request-Id",
46 LambdaHeaders::FunctionArn => "Lambda-Runtime-Invoked-Function-Arn",
47 LambdaHeaders::TraceId => "Lambda-Runtime-Trace-Id",
48 LambdaHeaders::Deadline => "Lambda-Runtime-Deadline-Ms",
49 LambdaHeaders::ClientContext => "Lambda-Runtime-Client-Context",
50 LambdaHeaders::CognitoIdentity => "Lambda-Runtime-Cognito-Identity",
51 }
52 }
53}
54
55impl fmt::Display for LambdaHeaders {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 f.write_str(self.as_str())
58 }
59}
60
61#[derive(Deserialize, Clone)]
63pub struct ClientApplication {
64 #[serde(rename = "installationId")]
66 pub installation_id: String,
67 #[serde(rename = "appTitle")]
69 pub app_title: String,
70 #[serde(rename = "appVersionName")]
72 pub app_version_name: String,
73 #[serde(rename = "appVersionCode")]
75 pub app_version_code: String,
76 #[serde(rename = "appPackageName")]
78 pub app_package_name: String,
79}
80
81#[derive(Deserialize, Clone)]
83pub struct ClientContext {
84 pub client: ClientApplication,
86 pub custom: HashMap<String, String>,
88 pub environment: HashMap<String, String>,
90}
91
92#[derive(Deserialize, Clone)]
93pub struct CognitoIdentity {
95 pub identity_id: String,
97 pub identity_pool_id: String,
99}
100
101#[derive(Clone)]
106pub struct EventContext {
107 pub invoked_function_arn: String,
109 pub aws_request_id: String,
111 pub xray_trace_id: Option<String>,
113 pub deadline: i64,
115 pub client_context: Option<ClientContext>,
118 pub identity: Option<CognitoIdentity>,
122}
123
124pub struct RuntimeClient {
126 _runtime: Runtime,
127 http_client: Client<HttpConnector, Body>,
128 next_endpoint: Uri,
129 runtime_agent: String,
130 host: String,
131}
132
133impl<'ev> RuntimeClient {
134 pub fn new(host: &str, agent: Option<String>, runtime: Option<Runtime>) -> Result<Self, ApiError> {
141 debug!("Starting new HttpRuntimeClient for {}", host);
142 let runtime_agent = match agent {
143 Some(a) => a,
144 None => DEFAULT_AGENT.to_owned(),
145 };
146
147 let runtime = match runtime {
149 Some(r) => r,
150 None => Runtime::new().context(ApiErrorKind::Unrecoverable("Could not initialize runtime".to_string()))?,
151 };
152
153 let http_client = Client::builder().executor(runtime.executor()).build_http();
154 let next_endpoint = format!("http://{}/{}/runtime/invocation/next", host, RUNTIME_API_VERSION)
156 .parse::<Uri>()
157 .context(ApiErrorKind::Unrecoverable("Could not parse API uri".to_string()))?;
158
159 Ok(RuntimeClient {
160 _runtime: runtime,
161 http_client,
162 next_endpoint,
163 runtime_agent,
164 host: host.to_owned(),
165 })
166 }
167}
168
169impl<'ev> RuntimeClient {
170 pub fn next_event(&self) -> Result<(Vec<u8>, EventContext), ApiError> {
172 trace!("Polling for next event");
173
174 let resp = self
178 .http_client
179 .get(self.next_endpoint.clone())
180 .wait()
181 .context(ApiErrorKind::Unrecoverable("Could not fetch next event".to_string()))?;
182
183 if resp.status().is_client_error() {
184 error!(
185 "Runtime API returned client error when polling for new events: {}",
186 resp.status()
187 );
188 Err(ApiErrorKind::Recoverable(format!(
189 "Error {} when polling for events",
190 resp.status()
191 )))?;
192 }
193 if resp.status().is_server_error() {
194 error!(
195 "Runtime API returned server error when polling for new events: {}",
196 resp.status()
197 );
198 Err(ApiErrorKind::Unrecoverable(
199 "Server error when polling for new events".to_string(),
200 ))?;
201 }
202 let ctx = self.get_event_context(&resp.headers())?;
203 let out = resp
204 .into_body()
205 .concat2()
206 .wait()
207 .context(ApiErrorKind::Recoverable("Could not read event boxy".to_string()))?;
208 let buf = out.into_bytes().to_vec();
209
210 trace!(
211 "Received new event for request id {}. Event length {} bytes",
212 ctx.aws_request_id,
213 buf.len()
214 );
215 Ok((buf, ctx))
216 }
217
218 pub fn event_response(&self, request_id: &str, output: &[u8]) -> Result<(), ApiError> {
231 trace!(
232 "Posting response for request {} to Runtime API. Response length {} bytes",
233 request_id,
234 output.len()
235 );
236 let uri = format!(
237 "http://{}/{}/runtime/invocation/{}/response",
238 self.host, RUNTIME_API_VERSION, request_id
239 )
240 .parse::<Uri>()
241 .context(ApiErrorKind::Unrecoverable(
242 "Could not generate response uri".to_owned(),
243 ))?;
244 let req = self.get_runtime_post_request(&uri, output);
245
246 let resp = self
247 .http_client
248 .request(req)
249 .wait()
250 .context(ApiErrorKind::Recoverable("Could not post event response".to_string()))?;
251 if !resp.status().is_success() {
252 error!(
253 "Error from Runtime API when posting response for request {}: {}",
254 request_id,
255 resp.status()
256 );
257 Err(ApiErrorKind::Recoverable(format!(
258 "Error {} while sending response",
259 resp.status()
260 )))?;
261 }
262 trace!("Posted response to Runtime API for request {}", request_id);
263 Ok(())
264 }
265
266 pub fn event_error(&self, request_id: &str, e: &ErrorResponse) -> Result<(), ApiError> {
277 trace!(
278 "Posting error to runtime API for request {}: {}",
279 request_id,
280 e.error_message
281 );
282 let uri = format!(
283 "http://{}/{}/runtime/invocation/{}/error",
284 self.host, RUNTIME_API_VERSION, request_id
285 )
286 .parse::<Uri>()
287 .context(ApiErrorKind::Unrecoverable(
288 "Could not generate response uri".to_owned(),
289 ))?;
290 let req = self.get_runtime_error_request(&uri, &e);
291
292 let resp = self.http_client.request(req).wait().context(ApiErrorKind::Recoverable(
293 "Could not post event error response".to_string(),
294 ))?;
295 if !resp.status().is_success() {
296 error!(
297 "Error from Runtime API when posting error response for request {}: {}",
298 request_id,
299 resp.status()
300 );
301 Err(ApiErrorKind::Recoverable(format!(
302 "Error {} while sending response",
303 resp.status()
304 )))?;
305 }
306 trace!("Posted error response for request id {}", request_id);
307 Ok(())
308 }
309
310 pub fn fail_init(&self, e: &ErrorResponse) {
322 error!("Calling fail_init Runtime API: {}", e.error_message);
323 let uri = format!("http://{}/{}/runtime/init/error", self.host, RUNTIME_API_VERSION)
324 .parse::<Uri>()
325 .map_err(|e| {
326 error!("Could not parse fail init URI: {}", e);
327 panic!("Killing runtime");
328 });
329 let req = self.get_runtime_error_request(&uri.unwrap(), &e);
330
331 self.http_client
332 .request(req)
333 .wait()
334 .map_err(|e| {
335 error!("Error while sending init failed message: {}", e);
336 panic!("Error while sending init failed message: {}", e);
337 })
338 .map(|resp| {
339 info!("Successfully sent error response to the runtime API: {:?}", resp);
340 })
341 .expect("Could not complete init_fail request");
342 }
343
344 pub fn get_endpoint(&self) -> &str {
346 &self.host
347 }
348
349 fn get_runtime_post_request(&self, uri: &Uri, body: &[u8]) -> Request<Body> {
360 Request::builder()
361 .method(Method::POST)
362 .uri(uri.clone())
363 .header(header::CONTENT_TYPE, header::HeaderValue::from_static(API_CONTENT_TYPE))
364 .header(header::USER_AGENT, self.runtime_agent.clone())
365 .body(Body::from(body.to_owned()))
366 .unwrap()
367 }
368
369 fn get_runtime_error_request(&self, uri: &Uri, e: &ErrorResponse) -> Request<Body> {
370 let body = serde_json::to_vec(&e).expect("Could not turn error object into response JSON");
371 Request::builder()
372 .method(Method::POST)
373 .uri(uri.clone())
374 .header(
375 header::CONTENT_TYPE,
376 header::HeaderValue::from_static(API_ERROR_CONTENT_TYPE),
377 )
378 .header(header::USER_AGENT, self.runtime_agent.clone())
379 .header(RUNTIME_ERROR_HEADER, HeaderValue::from_static("Unhandled"))
381 .body(Body::from(body))
382 .unwrap()
383 }
384
385 fn get_event_context(&self, headers: &HeaderMap<HeaderValue>) -> Result<EventContext, ApiError> {
397 let aws_request_id = header_string(
400 headers.get(LambdaHeaders::RequestId.as_str()),
401 &LambdaHeaders::RequestId,
402 )?;
403 let invoked_function_arn = header_string(
404 headers.get(LambdaHeaders::FunctionArn.as_str()),
405 &LambdaHeaders::FunctionArn,
406 )?;
407 let xray_trace_id = match headers.get(LambdaHeaders::TraceId.as_str()) {
408 Some(trace_id) => match trace_id.to_str() {
409 Ok(trace_str) => Some(trace_str.to_owned()),
410 Err(e) => {
411 error!("Could not parse X-Ray trace id as string: {}", e);
413 None
414 }
415 },
416 None => None,
417 };
418 let deadline = header_string(headers.get(LambdaHeaders::Deadline.as_str()), &LambdaHeaders::Deadline)?
419 .parse::<i64>()
420 .context(ApiErrorKind::Recoverable(
421 "Could not parse deadline header value to int".to_string(),
422 ))?;
423
424 let mut ctx = EventContext {
425 aws_request_id,
426 invoked_function_arn,
427 xray_trace_id,
428 deadline,
429 client_context: Option::default(),
430 identity: Option::default(),
431 };
432
433 if let Some(ctx_json) = headers.get(LambdaHeaders::ClientContext.as_str()) {
434 let ctx_json = ctx_json.to_str().context(ApiErrorKind::Recoverable(
435 "Could not convert context header content to string".to_string(),
436 ))?;
437 trace!("Found Client Context in response headers: {}", ctx_json);
438 let ctx_value: ClientContext = serde_json::from_str(&ctx_json).context(ApiErrorKind::Recoverable(
439 "Could not parse client context value as json object".to_string(),
440 ))?;
441 ctx.client_context = Option::from(ctx_value);
442 };
443
444 if let Some(cognito_json) = headers.get(LambdaHeaders::CognitoIdentity.as_str()) {
445 let cognito_json = cognito_json.to_str().context(ApiErrorKind::Recoverable(
446 "Could not convert congnito context header content to string".to_string(),
447 ))?;
448 trace!("Found Cognito Identity in response headers: {}", cognito_json);
449 let identity_value: CognitoIdentity = serde_json::from_str(&cognito_json).context(
450 ApiErrorKind::Recoverable("Could not parse cognito context value as json object".to_string()),
451 )?;
452 ctx.identity = Option::from(identity_value);
453 };
454
455 Ok(ctx)
456 }
457}
458
459fn header_string(value: Option<&HeaderValue>, header_type: &LambdaHeaders) -> Result<String, ApiError> {
460 match value {
461 Some(value_str) => Ok(value_str
462 .to_str()
463 .context(ApiErrorKind::Recoverable(format!(
464 "Could not parse {} header",
465 header_type
466 )))?
467 .to_owned()),
468 None => {
469 error!("Response headers do not contain {} header", header_type);
470 Err(ApiErrorKind::Recoverable(format!("Missing {} header", header_type)))?
471 }
472 }
473}
474
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use chrono::{Duration, Utc};

    /// Builds a header map holding the request id, function ARN, trace id and
    /// a deadline ten seconds from now (as epoch milliseconds).
    fn get_headers() -> HeaderMap<HeaderValue> {
        let deadline_ms = (Utc::now() + Duration::seconds(10)).timestamp_millis().to_string();

        let mut headers = HeaderMap::new();
        headers.insert(LambdaHeaders::RequestId.as_str(), HeaderValue::from_static("req_id"));
        headers.insert(LambdaHeaders::FunctionArn.as_str(), HeaderValue::from_static("func_arn"));
        headers.insert(LambdaHeaders::TraceId.as_str(), HeaderValue::from_static("trace"));
        headers.insert(
            LambdaHeaders::Deadline.as_str(),
            HeaderValue::from_str(&deadline_ms).unwrap(),
        );
        headers
    }

    /// A missing trace id header must not fail context creation.
    #[test]
    fn get_event_context_with_empty_trace_id() {
        let client = RuntimeClient::new("localhost:8081", None, None).expect("Could not initialize runtime client");
        let mut headers = get_headers();
        headers.remove(LambdaHeaders::TraceId.as_str());

        let headers_result = client.get_event_context(&headers);
        assert!(headers_result.is_ok());

        let ctx = headers_result.unwrap();
        assert_eq!(None, ctx.xray_trace_id);
        assert_eq!("req_id", ctx.aws_request_id);
    }

    /// A present trace id header is copied into the context verbatim.
    #[test]
    fn get_event_context_populates_trace_id_when_present() {
        let client = RuntimeClient::new("localhost:8081", None, None).expect("Could not initialize runtime client");
        let headers = get_headers();

        let headers_result = client.get_event_context(&headers);
        assert!(headers_result.is_ok());
        assert_eq!(Some("trace".to_owned()), headers_result.unwrap().xray_trace_id);
    }
}