use std::future::Future;
use aws_lambda_events::encodings::Base64Data;
use bytes::Bytes;
use http::{HeaderMap, Method, Request, Uri};
use http_body_util::{BodyExt, Full};
use lambda_runtime::LambdaEvent;
use lambda_runtime::service_fn;
use lambda_runtime::tower::ServiceExt;
use serde::{Deserialize, Serialize};
use crate::endpoint::{Endpoint, HandleOptions, ProtocolMode};
/// Incoming Lambda event payload as this module deserializes it.
///
/// Keys are matched in camelCase (`httpMethod`, `isBase64Encoded`, ...).
#[derive(Clone, Debug, Default, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct LambdaRequest {
/// HTTP method of the request, parsed through `http_serde::method`.
#[serde(with = "http_serde::method")]
pub http_method: Method,
/// Request path, parsed through `http_serde::uri`; falls back to
/// `Uri::default()` when the key is absent.
#[serde(default)]
#[serde(with = "http_serde::uri")]
pub path: Uri,
/// Request headers, parsed through `http_serde::header_map`; an absent key
/// yields an empty map.
#[serde(with = "http_serde::header_map", default)]
pub headers: HeaderMap,
/// Flag from the event stating whether `body` is base64 encoded.
// NOTE(review): `body` is typed `Base64Data` regardless of this flag, and no
// serde default is declared for it — confirm the caller always sends it and
// always base64-encodes the body.
pub is_base64_encoded: bool,
/// Request payload, if any.
pub body: Option<Base64Data>,
}
/// Response payload returned to the Lambda runtime.
///
/// Serialized with camelCase keys (`statusCode`, `isBase64Encoded`, ...).
// NOTE(review): this type derives only `Serialize`; the `default` attributes
// below look like deserialization-side settings — confirm whether they are
// still needed.
#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
struct LambdaResponse {
/// Numeric HTTP status code.
pub status_code: u16,
/// Optional textual description accompanying `status_code`.
#[serde(default)]
pub status_description: Option<String>,
/// Response headers, written through `http_serde::header_map`.
#[serde(with = "http_serde::header_map", default)]
pub headers: HeaderMap,
/// Response payload; the key is omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub body: Option<Base64Data>,
/// Whether `body` is base64 encoded.
#[serde(default)]
pub is_base64_encoded: bool,
}
impl LambdaResponse {
    /// Returns a [`LambdaResponseBuilder`] with its starting values:
    /// status code `200`, no status description, no headers and no body.
    fn builder() -> LambdaResponseBuilder {
        LambdaResponseBuilder {
            body: None,
            headers: HeaderMap::new(),
            status_description: None,
            status_code: 200,
        }
    }
}
/// Accumulates the pieces of a `LambdaResponse` before it is built.
struct LambdaResponseBuilder {
/// Numeric HTTP status code for the response.
status_code: u16,
/// Textual description stored alongside `status_code`, if any.
status_description: Option<String>,
/// Headers to attach to the response.
headers: HeaderMap,
/// Response payload, if one has been set.
body: Option<Base64Data>,
}
impl LambdaResponseBuilder {
pub fn status_code(mut self, status_code: u16) -> Self {
self.status_code = status_code;
self.status_description = http::StatusCode::from_u16(status_code)
.map(|s| s.to_string())
.ok();
self
}
pub fn body(mut self, body: Bytes) -> Self {
let mut data = Base64Data::default();
data.0 = body.into();
self.body = Some(data);
self
}
pub fn build(self) -> LambdaResponse {
LambdaResponse {
status_code: self.status_code,
status_description: self.status_description,
headers: self.headers,
body: self.body,
is_base64_encoded: true,
}
}
}
/// Unit type that namespaces the Lambda entry point; see its `run` function.
#[derive(Clone)]
pub struct LambdaEndpoint;
impl LambdaEndpoint {
pub fn run(endpoint: Endpoint) -> impl Future<Output = Result<(), lambda_runtime::Error>> {
let svc = service_fn(handle);
let svc = svc.map_request(move |req| {
let endpoint = endpoint.clone();
LambdaEventWithEndpoint {
inner: req,
endpoint,
}
});
lambda_runtime::run(svc)
}
}
/// A Lambda event bundled with the `Endpoint` that should process it.
struct LambdaEventWithEndpoint {
/// The event as received from the Lambda runtime.
inner: LambdaEvent<LambdaRequest>,
/// The endpoint that handles the request carried by `inner`.
endpoint: Endpoint,
}
async fn handle(req: LambdaEventWithEndpoint) -> Result<LambdaResponse, lambda_runtime::Error> {
let (request, _) = req.inner.into_parts();
let mut http_request = Request::builder()
.method(request.http_method)
.uri(request.path)
.body(request.body.map(|b| Full::from(b.0)).unwrap_or_default())
.expect("to build");
http_request.headers_mut().extend(request.headers);
let response = req.endpoint.handle_with_options(
http_request,
HandleOptions {
protocol_mode: ProtocolMode::RequestResponse,
},
);
let (parts, body) = response.into_parts();
let body = body.collect().await?.to_bytes();
let mut builder = LambdaResponse::builder().status_code(parts.status.as_u16());
builder.headers.extend(parts.headers);
Ok(builder.body(body).build())
}