1mod readiness;
66
// Environment variables read by `AdapterOptions::default()` (the `AWS_LWA_*` namespace).

/// Port the web application listens on; also the default readiness-check port.
const ENV_PORT: &str = "AWS_LWA_PORT";
/// Host the web application listens on.
const ENV_HOST: &str = "AWS_LWA_HOST";
/// Port probed by the readiness check.
const ENV_READINESS_CHECK_PORT: &str = "AWS_LWA_READINESS_CHECK_PORT";
/// Path probed by the HTTP readiness check.
const ENV_READINESS_CHECK_PATH: &str = "AWS_LWA_READINESS_CHECK_PATH";
/// Readiness-check protocol (`http` or `tcp`, case-insensitive).
const ENV_READINESS_CHECK_PROTOCOL: &str = "AWS_LWA_READINESS_CHECK_PROTOCOL";
/// Comma-separated status codes / ranges (e.g. `100-499`) that count as "ready".
const ENV_READINESS_CHECK_HEALTHY_STATUS: &str = "AWS_LWA_READINESS_CHECK_HEALTHY_STATUS";
/// Deprecated: when the healthy-status variable is unset, statuses `100..value` count as healthy.
const ENV_READINESS_CHECK_MIN_UNHEALTHY_STATUS: &str = "AWS_LWA_READINESS_CHECK_MIN_UNHEALTHY_STATUS";
/// Prefix removed from incoming request paths before forwarding.
const ENV_REMOVE_BASE_PATH: &str = "AWS_LWA_REMOVE_BASE_PATH";
/// Path that pass-through POST events are forwarded to (default `/events`).
const ENV_PASS_THROUGH_PATH: &str = "AWS_LWA_PASS_THROUGH_PATH";
/// `true`/`false`: bound the init readiness wait and re-check on the first request.
const ENV_ASYNC_INIT: &str = "AWS_LWA_ASYNC_INIT";
/// `true`/`false`: compress responses (buffered invoke mode only).
const ENV_ENABLE_COMPRESSION: &str = "AWS_LWA_ENABLE_COMPRESSION";
/// `buffered` or `response_stream` (case-insensitive).
const ENV_INVOKE_MODE: &str = "AWS_LWA_INVOKE_MODE";
/// Name of a request header whose value is moved into `authorization`.
const ENV_AUTHORIZATION_SOURCE: &str = "AWS_LWA_AUTHORIZATION_SOURCE";
/// Status codes / ranges from the app that are turned into Lambda errors.
const ENV_ERROR_STATUS_CODES: &str = "AWS_LWA_ERROR_STATUS_CODES";
/// When set, copied into `AWS_LAMBDA_RUNTIME_API` by `Adapter::apply_runtime_proxy_config`.
const ENV_LAMBDA_RUNTIME_API_PROXY: &str = "AWS_LWA_LAMBDA_RUNTIME_API_PROXY";

// Legacy un-prefixed names, consulted only when the matching `AWS_LWA_*` variable
// is unset. All except `PORT` log a deprecation warning when they are used.
const ENV_PORT_DEPRECATED: &str = "PORT";
const ENV_HOST_DEPRECATED: &str = "HOST";
const ENV_READINESS_CHECK_PORT_DEPRECATED: &str = "READINESS_CHECK_PORT";
const ENV_READINESS_CHECK_PATH_DEPRECATED: &str = "READINESS_CHECK_PATH";
const ENV_READINESS_CHECK_PROTOCOL_DEPRECATED: &str = "READINESS_CHECK_PROTOCOL";
const ENV_REMOVE_BASE_PATH_DEPRECATED: &str = "REMOVE_BASE_PATH";
const ENV_ASYNC_INIT_DEPRECATED: &str = "ASYNC_INIT";

/// Lambda Runtime API address (`host:port`) used for extension registration;
/// `127.0.0.1:9001` is assumed when it is unset.
const ENV_LAMBDA_RUNTIME_API: &str = "AWS_LAMBDA_RUNTIME_API";
95
96use http::{
97 header::{HeaderName, HeaderValue},
98 Method, StatusCode,
99};
100use http_body::Body as HttpBody;
101use hyper::body::Incoming;
102use hyper_util::client::legacy::connect::HttpConnector;
103use hyper_util::client::legacy::Client;
104use lambda_http::request::RequestContext;
105pub use lambda_http::tracing;
106use lambda_http::Body;
107pub use lambda_http::Error;
108use lambda_http::{Request, RequestExt, Response};
109use readiness::Checkpoint;
110use std::fmt::Debug;
111use std::{
112 env,
113 future::Future,
114 pin::Pin,
115 sync::{
116 atomic::{AtomicBool, Ordering},
117 Arc,
118 },
119 time::Duration,
120};
121use tokio::{net::TcpStream, time::timeout};
122use tokio_retry::{strategy::FixedInterval, Retry};
123use tower::{Service, ServiceBuilder};
124use tower_http::compression::CompressionLayer;
125use url::Url;
126
/// Protocol used by the readiness check against the web application.
///
/// Parsed case-insensitively via `From<&str>`; unrecognized strings fall back to `Http`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Protocol {
    /// Send an HTTP GET to the readiness URL and accept it when the response status
    /// is in the configured healthy-status list (the default).
    #[default]
    Http,
    /// Treat the app as ready as soon as a TCP connection to the readiness host/port succeeds.
    Tcp,
}
156
157impl From<&str> for Protocol {
158 fn from(value: &str) -> Self {
159 match value.to_lowercase().as_str() {
160 "http" => Protocol::Http,
161 "tcp" => Protocol::Tcp,
162 _ => Protocol::Http,
163 }
164 }
165}
166
/// How the adapter returns the web application's responses to Lambda.
///
/// Parsed case-insensitively via `From<&str>` (`"buffered"` / `"response_stream"`);
/// unrecognized strings fall back to `Buffered`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LambdaInvokeMode {
    /// Served through `lambda_http::run_concurrent` (the default); compression may apply.
    #[default]
    Buffered,
    /// Served through `lambda_http::run_with_streaming_response_concurrent`;
    /// `Adapter::new` disables compression in this mode.
    ResponseStream,
}
206
207impl From<&str> for LambdaInvokeMode {
208 fn from(value: &str) -> Self {
209 match value.to_lowercase().as_str() {
210 "buffered" => LambdaInvokeMode::Buffered,
211 "response_stream" => LambdaInvokeMode::ResponseStream,
212 _ => LambdaInvokeMode::Buffered,
213 }
214 }
215}
216
/// Configuration for `Adapter::new`.
///
/// `AdapterOptions::default()` fills every field from `AWS_LWA_*` environment
/// variables, with fallbacks to legacy names where supported and to built-in defaults.
pub struct AdapterOptions {
    /// Host of the web application, used for both the forwarding URL and the
    /// readiness URL (env default: `127.0.0.1`).
    pub host: String,

    /// Port the web application listens on (env default: `8080`).
    pub port: String,

    /// Port probed by the readiness check (env default: the value of `port`).
    pub readiness_check_port: String,

    /// Path probed by the HTTP readiness check (env default: `/`).
    pub readiness_check_path: String,

    /// Whether readiness is checked with an HTTP request or a bare TCP connect.
    pub readiness_check_protocol: Protocol,

    /// Legacy threshold: statuses below this value counted as healthy.
    ///
    /// NOTE(review): `Adapter::new` does not read this field; only
    /// `readiness_check_healthy_status` influences readiness.
    #[deprecated(since = "1.0.0", note = "Use readiness_check_healthy_status instead")]
    pub readiness_check_min_unhealthy_status: u16,

    /// Status codes the HTTP readiness check accepts as "ready" (env default: `100..500`).
    pub readiness_check_healthy_status: Vec<u16>,

    /// Optional prefix stripped from incoming request paths before forwarding.
    pub base_path: Option<String>,

    /// Path that pass-through POST events are forwarded to (env default: `/events`).
    pub pass_through_path: String,

    /// When `true`, the init readiness wait is capped (about 9.8 s) and, if the app
    /// was not ready by then, readiness is awaited again on the first request.
    pub async_init: bool,

    /// Compress responses; forced off by `Adapter::new` in response-stream mode.
    pub compression: bool,

    /// Buffered or response-streaming invoke mode.
    pub invoke_mode: LambdaInvokeMode,

    /// Name of a request header whose value is moved into the `authorization` header.
    pub authorization_source: Option<String>,

    /// App response status codes that are turned into Lambda errors.
    pub error_status_codes: Option<Vec<u16>>,
}
351
352fn get_env_with_deprecation(new_name: &str, old_name: &str, default: &str) -> String {
354 if let Ok(val) = env::var(new_name) {
355 return val;
356 }
357 if let Ok(val) = env::var(old_name) {
358 tracing::warn!(
359 "Environment variable '{}' is deprecated and will be removed in version 2.0. Please use '{}' instead.",
360 old_name,
361 new_name
362 );
363 return val;
364 }
365 default.to_string()
366}
367
368fn get_optional_env_with_deprecation(new_name: &str, old_name: &str) -> Option<String> {
370 if let Ok(val) = env::var(new_name) {
371 return Some(val);
372 }
373 if let Ok(val) = env::var(old_name) {
374 tracing::warn!(
375 "Environment variable '{}' is deprecated and will be removed in version 2.0. Please use '{}' instead.",
376 old_name,
377 new_name
378 );
379 return Some(val);
380 }
381 None
382}
383
384impl Default for AdapterOptions {
385 #[allow(deprecated)]
386 fn default() -> Self {
387 let port = env::var(ENV_PORT)
388 .or_else(|_| env::var(ENV_PORT_DEPRECATED))
389 .unwrap_or_else(|_| "8080".to_string());
390
391 let readiness_check_healthy_status = if let Ok(val) = env::var(ENV_READINESS_CHECK_HEALTHY_STATUS) {
394 parse_status_codes(&val)
395 } else if let Ok(val) = env::var(ENV_READINESS_CHECK_MIN_UNHEALTHY_STATUS) {
396 tracing::warn!(
397 "Environment variable '{}' is deprecated. \
398 Please use '{}' instead (e.g., '100-499').",
399 ENV_READINESS_CHECK_MIN_UNHEALTHY_STATUS,
400 ENV_READINESS_CHECK_HEALTHY_STATUS
401 );
402 let min_unhealthy: u16 = val.parse().unwrap_or(500);
403 (100..min_unhealthy).collect()
404 } else {
405 (100..500).collect()
407 };
408
409 let readiness_check_min_unhealthy_status = env::var(ENV_READINESS_CHECK_MIN_UNHEALTHY_STATUS)
411 .unwrap_or_else(|_| "500".to_string())
412 .parse()
413 .unwrap_or(500);
414
415 AdapterOptions {
416 host: get_env_with_deprecation(ENV_HOST, ENV_HOST_DEPRECATED, "127.0.0.1"),
417 port: port.clone(),
418 readiness_check_port: get_env_with_deprecation(
419 ENV_READINESS_CHECK_PORT,
420 ENV_READINESS_CHECK_PORT_DEPRECATED,
421 &port,
422 ),
423 readiness_check_min_unhealthy_status,
424 readiness_check_healthy_status,
425 readiness_check_path: get_env_with_deprecation(
426 ENV_READINESS_CHECK_PATH,
427 ENV_READINESS_CHECK_PATH_DEPRECATED,
428 "/",
429 ),
430 readiness_check_protocol: get_env_with_deprecation(
431 ENV_READINESS_CHECK_PROTOCOL,
432 ENV_READINESS_CHECK_PROTOCOL_DEPRECATED,
433 "HTTP",
434 )
435 .as_str()
436 .into(),
437 base_path: get_optional_env_with_deprecation(ENV_REMOVE_BASE_PATH, ENV_REMOVE_BASE_PATH_DEPRECATED),
438 pass_through_path: env::var(ENV_PASS_THROUGH_PATH).unwrap_or_else(|_| "/events".to_string()),
439 async_init: get_env_with_deprecation(ENV_ASYNC_INIT, ENV_ASYNC_INIT_DEPRECATED, "false")
440 .parse()
441 .unwrap_or(false),
442 compression: env::var(ENV_ENABLE_COMPRESSION)
443 .unwrap_or_else(|_| "false".to_string())
444 .parse()
445 .unwrap_or(false),
446 invoke_mode: env::var(ENV_INVOKE_MODE)
447 .unwrap_or_else(|_| "buffered".to_string())
448 .as_str()
449 .into(),
450 authorization_source: env::var(ENV_AUTHORIZATION_SOURCE).ok(),
451 error_status_codes: env::var(ENV_ERROR_STATUS_CODES)
452 .ok()
453 .map(|codes| parse_status_codes(&codes)),
454 }
455 }
456}
457
458fn parse_status_codes(input: &str) -> Vec<u16> {
467 input
468 .split(',')
469 .flat_map(|part| {
470 let part = part.trim();
471 if part.contains('-') {
472 let range: Vec<&str> = part.split('-').collect();
473 if range.len() == 2 {
474 if let (Ok(start), Ok(end)) = (range[0].parse::<u16>(), range[1].parse::<u16>()) {
475 return (start..=end).collect::<Vec<_>>();
476 }
477 }
478 tracing::warn!("Failed to parse status code range: {}", part);
479 vec![]
480 } else {
481 part.parse::<u16>().map_or_else(
482 |_| {
483 if !part.is_empty() {
484 tracing::warn!("Failed to parse status code: {}", part);
485 }
486 vec![]
487 },
488 |code| vec![code],
489 )
490 }
491 })
492 .collect()
493}
494
/// Lambda `Service` that forwards each event to a local web application over HTTP.
///
/// `C` is the connector and `B` the request body type of the shared hyper client.
/// Clones share the client and the `ready_at_init` flag through `Arc`.
#[derive(Clone)]
pub struct Adapter<C, B> {
    /// Shared hyper client used for readiness checks and forwarded requests.
    client: Arc<Client<C, B>>,
    /// Full URL probed by the readiness check.
    healthcheck_url: Url,
    /// HTTP or TCP readiness check.
    healthcheck_protocol: Protocol,
    /// Status codes the HTTP readiness check accepts as "ready".
    healthcheck_healthy_status: Vec<u16>,
    /// Whether the init readiness wait is time-bounded and re-checked on first request.
    async_init: bool,
    /// Result of the init readiness check; set to `true` once a later wait succeeds.
    ready_at_init: Arc<AtomicBool>,
    /// Base URL (`http://host:port`) that request paths and queries are applied to.
    domain: Url,
    /// Optional prefix stripped from incoming request paths.
    base_path: Option<String>,
    /// Path that pass-through POST events are forwarded to.
    pass_through_path: String,
    /// Whether responses are compressed (only honored in buffered mode).
    compression: bool,
    /// Buffered or response-streaming invoke mode.
    invoke_mode: LambdaInvokeMode,
    /// Header whose value is moved into `authorization`, if configured.
    authorization_source: Option<String>,
    /// App status codes turned into Lambda errors, if configured.
    error_status_codes: Option<Vec<u16>>,
}
542
543impl Adapter<HttpConnector, Body> {
544 pub fn new(options: &AdapterOptions) -> Result<Adapter<HttpConnector, Body>, Error> {
573 let client = Client::builder(hyper_util::rt::TokioExecutor::new())
574 .pool_idle_timeout(Duration::from_secs(4))
575 .build(HttpConnector::new());
576
577 let schema = "http";
578
579 let healthcheck_url: Url = format!(
580 "{}://{}:{}{}",
581 schema, options.host, options.readiness_check_port, options.readiness_check_path
582 )
583 .parse()
584 .map_err(|e| {
585 Error::from(format!(
586 "Invalid healthcheck URL configuration (host={}, port={}, path={}): {}",
587 options.host, options.readiness_check_port, options.readiness_check_path, e
588 ))
589 })?;
590
591 let domain: Url = format!("{}://{}:{}", schema, options.host, options.port)
592 .parse()
593 .map_err(|e| {
594 Error::from(format!(
595 "Invalid domain URL configuration (host={}, port={}): {}",
596 options.host, options.port, e
597 ))
598 })?;
599
600 if options.readiness_check_protocol == Protocol::Tcp {
602 if healthcheck_url.host().is_none() {
603 return Err(Error::from("TCP readiness check requires a valid host in the URL"));
604 }
605 if healthcheck_url.port().is_none() {
606 return Err(Error::from("TCP readiness check requires a port in the URL"));
607 }
608 }
609
610 let compression = if options.compression && options.invoke_mode == LambdaInvokeMode::ResponseStream {
611 tracing::warn!("Compression is not supported with response streaming. Disabling compression.");
612 false
613 } else {
614 options.compression
615 };
616
617 Ok(Adapter {
618 client: Arc::new(client),
619 healthcheck_url,
620 healthcheck_protocol: options.readiness_check_protocol,
621 healthcheck_healthy_status: options.readiness_check_healthy_status.clone(),
622 domain,
623 base_path: options.base_path.clone(),
624 pass_through_path: options.pass_through_path.clone(),
625 async_init: options.async_init,
626 ready_at_init: Arc::new(AtomicBool::new(false)),
627 compression,
628 invoke_mode: options.invoke_mode,
629 authorization_source: options.authorization_source.clone(),
630 error_status_codes: options.error_status_codes.clone(),
631 })
632 }
633}
634
635impl Adapter<HttpConnector, Body> {
636 pub fn register_default_extension(&self) {
651 tokio::task::spawn(async move {
653 if let Err(e) = Self::register_extension_internal().await {
654 tracing::error!(error = %e, "Extension registration failed - terminating process");
655 std::process::exit(1);
656 }
657 });
658 }
659
660 async fn register_extension_internal() -> Result<(), Error> {
665 let aws_lambda_runtime_api: String =
666 env::var(ENV_LAMBDA_RUNTIME_API).unwrap_or_else(|_| "127.0.0.1:9001".to_string());
667 let client = Client::builder(hyper_util::rt::TokioExecutor::new()).build(HttpConnector::new());
668
669 let register_req = hyper::Request::builder()
670 .method(Method::POST)
671 .uri(format!("http://{aws_lambda_runtime_api}/2020-01-01/extension/register"))
672 .header("Lambda-Extension-Name", "lambda-adapter")
673 .body(Body::from("{ \"events\": [] }"))?;
674
675 let register_res = client.request(register_req).await?;
676
677 if register_res.status() != StatusCode::OK {
678 return Err(Error::from(format!(
679 "Extension registration failed with status: {}",
680 register_res.status()
681 )));
682 }
683
684 let extension_id = register_res
685 .headers()
686 .get("Lambda-Extension-Identifier")
687 .ok_or_else(|| Error::from("Missing Lambda-Extension-Identifier header"))?;
688
689 let next_req = hyper::Request::builder()
690 .method(Method::GET)
691 .uri(format!(
692 "http://{aws_lambda_runtime_api}/2020-01-01/extension/event/next"
693 ))
694 .header("Lambda-Extension-Identifier", extension_id)
695 .body(Body::Empty)?;
696
697 client.request(next_req).await?;
698
699 Ok(())
700 }
701
702 pub async fn check_init_health(&mut self) {
730 let ready_at_init = if self.async_init {
731 timeout(Duration::from_secs_f32(9.8), self.check_readiness())
732 .await
733 .unwrap_or_default()
734 } else {
735 self.check_readiness().await
736 };
737 self.ready_at_init.store(ready_at_init, Ordering::SeqCst);
738 }
739
740 async fn check_readiness(&self) -> bool {
742 let url = self.healthcheck_url.clone();
743 let protocol = self.healthcheck_protocol;
744 self.is_web_ready(&url, &protocol).await
745 }
746
747 async fn is_web_ready(&self, url: &Url, protocol: &Protocol) -> bool {
752 let mut checkpoint = Checkpoint::new();
753 Retry::spawn(FixedInterval::from_millis(10), || {
754 if checkpoint.lapsed() {
755 tracing::info!(url = %url.to_string(), "app is not ready after {}ms", checkpoint.next_ms());
756 checkpoint.increment();
757 }
758 self.check_web_readiness(url, protocol)
759 })
760 .await
761 .is_ok()
762 }
763
764 async fn check_web_readiness(&self, url: &Url, protocol: &Protocol) -> Result<(), i8> {
769 match protocol {
770 Protocol::Http => {
771 let uri: http::Uri = url
774 .as_str()
775 .parse()
776 .expect("BUG: healthcheck_url should be valid - validated in Adapter::new()");
777
778 match self.client.get(uri).await {
779 Ok(response) if self.healthcheck_healthy_status.contains(&response.status().as_u16()) => {
780 tracing::debug!("app is ready");
781 Ok(())
782 }
783 _ => {
784 tracing::trace!("app is not ready");
785 Err(-1)
786 }
787 }
788 }
789 Protocol::Tcp => {
790 let host = url
793 .host_str()
794 .expect("BUG: healthcheck_url should have host - validated in Adapter::new()");
795 let port = url
796 .port()
797 .expect("BUG: healthcheck_url should have port - validated in Adapter::new()");
798
799 match TcpStream::connect(format!("{}:{}", host, port)).await {
800 Ok(_) => Ok(()),
801 Err(_) => Err(-1),
802 }
803 }
804 }
805 }
806
807 pub async fn run(self) -> Result<(), Error> {
834 match (self.compression, self.invoke_mode) {
835 (true, LambdaInvokeMode::Buffered) => {
836 let svc = ServiceBuilder::new().layer(CompressionLayer::new()).service(self);
837 lambda_http::run_concurrent(svc).await
838 }
839 (_, LambdaInvokeMode::Buffered) => lambda_http::run_concurrent(self).await,
840 (_, LambdaInvokeMode::ResponseStream) => lambda_http::run_with_streaming_response_concurrent(self).await,
841 }
842 }
843
844 pub fn apply_runtime_proxy_config() {
880 if let Ok(runtime_proxy) = env::var(ENV_LAMBDA_RUNTIME_API_PROXY) {
881 env::set_var(ENV_LAMBDA_RUNTIME_API, runtime_proxy);
887 }
888 }
889
890 async fn fetch_response(&self, event: Request) -> Result<Response<Incoming>, Error> {
900 if self.async_init && !self.ready_at_init.load(Ordering::SeqCst) {
901 self.is_web_ready(&self.healthcheck_url, &self.healthcheck_protocol)
902 .await;
903 self.ready_at_init.store(true, Ordering::SeqCst);
904 }
905
906 let request_context = event.request_context();
907 let lambda_context = event.lambda_context();
908 let path = event.raw_http_path().to_string();
909 let mut path = path.as_str();
910 let (parts, body) = event.into_parts();
911
912 if let Some(base_path) = self.base_path.as_deref() {
914 path = path.trim_start_matches(base_path);
915 }
916
917 if matches!(request_context, RequestContext::PassThrough) && parts.method == Method::POST {
918 path = self.pass_through_path.as_str();
919 }
920
921 let mut req_headers = parts.headers;
922
923 req_headers.insert(
925 HeaderName::from_static("x-amzn-request-context"),
926 HeaderValue::from_bytes(serde_json::to_string(&request_context)?.as_bytes())?,
927 );
928
929 req_headers.insert(
931 HeaderName::from_static("x-amzn-lambda-context"),
932 HeaderValue::from_bytes(serde_json::to_string(&lambda_context)?.as_bytes())?,
933 );
934
935 if let Some(ref tenant_id) = lambda_context.tenant_id {
937 if let Ok(value) = HeaderValue::from_str(tenant_id) {
938 req_headers.insert(HeaderName::from_static("x-amz-tenant-id"), value);
939 tracing::debug!(tenant_id = %tenant_id, "propagating tenant_id header");
940 } else {
941 tracing::warn!(tenant_id = %tenant_id, "tenant_id contains invalid header characters, skipping");
942 }
943 }
944
945 if let Some(authorization_source) = self.authorization_source.as_deref() {
946 if let Some(original) = req_headers.remove(authorization_source) {
947 req_headers.insert("authorization", original);
948 } else {
949 tracing::warn!("\"{}\" header not found in request headers", authorization_source);
950 }
951 }
952
953 let mut app_url = self.domain.clone();
954 app_url.set_path(path);
955 app_url.set_query(parts.uri.query());
956
957 tracing::debug!(app_url = %app_url, req_headers = ?req_headers, "sending request to app server");
958
959 let mut builder = hyper::Request::builder().method(parts.method).uri(app_url.to_string());
960 if let Some(headers) = builder.headers_mut() {
961 headers.extend(req_headers);
962 }
963
964 let body_bytes = match body {
966 Body::Empty => Vec::new(),
967 Body::Text(s) => s.into_bytes(),
968 Body::Binary(b) => b,
969 _ => body.to_vec(),
971 };
972 let request = builder.body(Body::Binary(body_bytes))?;
973
974 let mut app_response = self.client.request(request).await?;
975
976 if let Some(error_codes) = &self.error_status_codes {
978 let status = app_response.status().as_u16();
979 if error_codes.contains(&status) {
980 return Err(Error::from(format!(
981 "Request failed with configured error status code: {}",
982 status
983 )));
984 }
985 }
986
987 app_response.headers_mut().remove("transfer-encoding");
989
990 tracing::debug!(status = %app_response.status(), body_size = ?app_response.body().size_hint().lower(),
991 app_headers = ?app_response.headers().clone(), "responding to lambda event");
992
993 Ok(app_response)
994 }
995}
996
997impl Service<Request> for Adapter<HttpConnector, Body> {
1002 type Response = Response<Incoming>;
1003 type Error = Error;
1004 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1005
1006 fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
1007 core::task::Poll::Ready(Ok(()))
1008 }
1009
1010 fn call(&mut self, event: Request) -> Self::Future {
1011 let adapter = self.clone();
1012 Box::pin(async move { adapter.fetch_response(event).await })
1013 }
1014}
1015
#[cfg(test)]
mod tests {
    use super::*;
    use httpmock::{Method::GET, MockServer};

    #[test]
    fn test_parse_status_codes() {
        assert_eq!(parse_status_codes("500,502-504,422"), vec![500, 502, 503, 504, 422]);
        assert_eq!(
            parse_status_codes("500, 502-504, 422"),
            vec![500, 502, 503, 504, 422]
        );
        assert_eq!(parse_status_codes("500"), vec![500]);
        assert_eq!(parse_status_codes("500-502"), vec![500, 501, 502]);
        assert_eq!(parse_status_codes("invalid"), Vec::<u16>::new());
        assert_eq!(parse_status_codes("500-invalid"), Vec::<u16>::new());
        assert_eq!(parse_status_codes(""), Vec::<u16>::new());
    }

    #[tokio::test]
    async fn test_status_200_is_ok() {
        let app_server = MockServer::start();
        let healthcheck = app_server.mock(|when, then| {
            when.method(GET).path("/healthcheck");
            then.status(200).body("OK");
        });

        let options = AdapterOptions {
            host: app_server.host(),
            port: app_server.port().to_string(),
            readiness_check_port: app_server.port().to_string(),
            readiness_check_path: "/healthcheck".to_string(),
            ..Default::default()
        };

        let adapter = Adapter::new(&options).expect("Failed to create adapter");

        let url = adapter.healthcheck_url.clone();
        let protocol = adapter.healthcheck_protocol;

        assert!(adapter.check_web_readiness(&url, &protocol).await.is_ok());

        healthcheck.assert();
    }

    #[tokio::test]
    async fn test_status_500_is_bad() {
        let app_server = MockServer::start();
        let healthcheck = app_server.mock(|when, then| {
            when.method(GET).path("/healthcheck");
            then.status(500).body("OK");
        });

        let options = AdapterOptions {
            host: app_server.host(),
            port: app_server.port().to_string(),
            readiness_check_port: app_server.port().to_string(),
            readiness_check_path: "/healthcheck".to_string(),
            ..Default::default()
        };

        let adapter = Adapter::new(&options).expect("Failed to create adapter");

        let url = adapter.healthcheck_url.clone();
        let protocol = adapter.healthcheck_protocol;

        assert!(adapter.check_web_readiness(&url, &protocol).await.is_err());

        healthcheck.assert();
    }

    #[tokio::test]
    async fn test_status_403_is_bad_when_configured() {
        let app_server = MockServer::start();
        let healthcheck = app_server.mock(|when, then| {
            when.method(GET).path("/healthcheck");
            then.status(403).body("OK");
        });

        // Sets the deprecated threshold field explicitly alongside the new status list.
        #[allow(deprecated)]
        let options = AdapterOptions {
            host: app_server.host(),
            port: app_server.port().to_string(),
            readiness_check_port: app_server.port().to_string(),
            readiness_check_path: "/healthcheck".to_string(),
            readiness_check_min_unhealthy_status: 400,
            readiness_check_healthy_status: (200..400).collect(),
            ..Default::default()
        };

        let adapter = Adapter::new(&options).expect("Failed to create adapter");

        let url = adapter.healthcheck_url.clone();
        let protocol = adapter.healthcheck_protocol;

        assert!(adapter.check_web_readiness(&url, &protocol).await.is_err());

        healthcheck.assert();
    }

    #[tokio::test]
    async fn test_tcp_readiness_check_success() {
        // Keep the listener alive for the whole test so the connect succeeds.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        #[allow(deprecated)]
        let options = AdapterOptions {
            host: "127.0.0.1".to_string(),
            port: port.to_string(),
            readiness_check_port: port.to_string(),
            readiness_check_path: "/".to_string(),
            readiness_check_protocol: Protocol::Tcp,
            ..Default::default()
        };

        let adapter = Adapter::new(&options).expect("Failed to create adapter");
        let url = adapter.healthcheck_url.clone();
        let protocol = adapter.healthcheck_protocol;

        assert_eq!(protocol, Protocol::Tcp);
        assert!(adapter.check_web_readiness(&url, &protocol).await.is_ok());
    }

    #[tokio::test]
    async fn test_tcp_readiness_check_failure() {
        // Reserve an OS-assigned port and release it immediately, so nothing is
        // listening there; a hard-coded port could be in use on the test machine.
        let port = {
            let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };

        #[allow(deprecated)]
        let options = AdapterOptions {
            host: "127.0.0.1".to_string(),
            port: port.to_string(),
            readiness_check_port: port.to_string(),
            readiness_check_path: "/".to_string(),
            readiness_check_protocol: Protocol::Tcp,
            ..Default::default()
        };

        let adapter = Adapter::new(&options).expect("Failed to create adapter");
        let url = adapter.healthcheck_url.clone();
        let protocol = adapter.healthcheck_protocol;

        assert!(adapter.check_web_readiness(&url, &protocol).await.is_err());
    }

    #[test]
    fn test_protocol_from_str() {
        assert_eq!(Protocol::from("http"), Protocol::Http);
        assert_eq!(Protocol::from("HTTP"), Protocol::Http);
        assert_eq!(Protocol::from("tcp"), Protocol::Tcp);
        assert_eq!(Protocol::from("TCP"), Protocol::Tcp);
        assert_eq!(Protocol::from("unknown"), Protocol::Http);
        assert_eq!(Protocol::from(""), Protocol::Http);
    }

    #[test]
    fn test_invoke_mode_from_str() {
        assert_eq!(LambdaInvokeMode::from("buffered"), LambdaInvokeMode::Buffered);
        assert_eq!(LambdaInvokeMode::from("BUFFERED"), LambdaInvokeMode::Buffered);
        assert_eq!(
            LambdaInvokeMode::from("response_stream"),
            LambdaInvokeMode::ResponseStream
        );
        assert_eq!(
            LambdaInvokeMode::from("RESPONSE_STREAM"),
            LambdaInvokeMode::ResponseStream
        );
        assert_eq!(LambdaInvokeMode::from("unknown"), LambdaInvokeMode::Buffered);
        assert_eq!(LambdaInvokeMode::from(""), LambdaInvokeMode::Buffered);
    }

    #[test]
    fn test_adapter_new_invalid_host() {
        #[allow(deprecated)]
        let options = AdapterOptions {
            host: "invalid host with spaces".to_string(),
            port: "8080".to_string(),
            readiness_check_port: "8080".to_string(),
            readiness_check_path: "/".to_string(),
            ..Default::default()
        };

        let result = Adapter::new(&options);
        assert!(result.is_err());
    }

    #[test]
    fn test_adapter_new_valid_config() {
        #[allow(deprecated)]
        let options = AdapterOptions {
            host: "127.0.0.1".to_string(),
            port: "3000".to_string(),
            readiness_check_port: "3000".to_string(),
            readiness_check_path: "/health".to_string(),
            readiness_check_protocol: Protocol::Http,
            ..Default::default()
        };

        let adapter = Adapter::new(&options);
        assert!(adapter.is_ok());
    }

    #[test]
    fn test_parse_status_codes_single_range() {
        let codes = parse_status_codes("200-204");
        assert_eq!(codes, vec![200, 201, 202, 203, 204]);
    }

    #[test]
    fn test_parse_status_codes_mixed_with_spaces() {
        let codes = parse_status_codes("200, 301-303, 404");
        assert_eq!(codes, vec![200, 301, 302, 303, 404]);
    }

    #[test]
    fn test_parse_status_codes_invalid_range_format() {
        let codes = parse_status_codes("200-300-400");
        assert!(codes.is_empty());
    }

    #[test]
    fn test_apply_runtime_proxy_config_sets_env() {
        // Start from a clean environment for both variables.
        env::remove_var(ENV_LAMBDA_RUNTIME_API_PROXY);
        env::remove_var(ENV_LAMBDA_RUNTIME_API);

        // Without a proxy configured, nothing is set.
        Adapter::apply_runtime_proxy_config();
        assert!(env::var(ENV_LAMBDA_RUNTIME_API).is_err());

        // With a proxy configured, it is copied into the runtime API variable.
        env::set_var(ENV_LAMBDA_RUNTIME_API_PROXY, "127.0.0.1:9002");
        Adapter::apply_runtime_proxy_config();
        assert_eq!(env::var(ENV_LAMBDA_RUNTIME_API).unwrap(), "127.0.0.1:9002");

        // Restore the environment for other tests.
        env::remove_var(ENV_LAMBDA_RUNTIME_API_PROXY);
        env::remove_var(ENV_LAMBDA_RUNTIME_API);
    }

    #[test]
    fn test_compression_disabled_with_response_stream() {
        #[allow(deprecated)]
        let options = AdapterOptions {
            compression: true,
            invoke_mode: LambdaInvokeMode::ResponseStream,
            ..Default::default()
        };

        let adapter = Adapter::new(&options).expect("Failed to create adapter");
        assert!(
            !adapter.compression,
            "Compression should be disabled when invoke mode is ResponseStream"
        );
    }

    #[test]
    fn test_compression_enabled_with_buffered() {
        #[allow(deprecated)]
        let options = AdapterOptions {
            compression: true,
            invoke_mode: LambdaInvokeMode::Buffered,
            ..Default::default()
        };

        let adapter = Adapter::new(&options).expect("Failed to create adapter");
        assert!(
            adapter.compression,
            "Compression should remain enabled when invoke mode is Buffered"
        );
    }

    /// Builds a Lambda context from synthetic runtime headers, optionally with a tenant id.
    fn make_lambda_context(tenant_id: Option<&str>) -> lambda_http::Context {
        use lambda_http::lambda_runtime::Config;
        let mut headers = http::HeaderMap::new();
        headers.insert("lambda-runtime-aws-request-id", "test-id".parse().unwrap());
        headers.insert("lambda-runtime-deadline-ms", "123".parse().unwrap());
        headers.insert("lambda-runtime-client-context", "{}".parse().unwrap());
        if let Some(tid) = tenant_id {
            headers.insert("lambda-runtime-aws-tenant-id", tid.parse().unwrap());
        }
        let conf = Config {
            function_name: "test_function".into(),
            memory: 128,
            version: "latest".into(),
            log_stream: "/aws/lambda/test_function".into(),
            log_group: "2023/09/15/[$LATEST]ab831cef03e94457a94b6efcbe22406a".into(),
        };
        lambda_http::Context::new("test-id", Arc::new(conf), &headers).unwrap()
    }

    #[tokio::test]
    async fn test_tenant_id_header_propagated() {
        let app_server = MockServer::start();
        app_server.mock(|when, then| {
            when.method(GET).path("/hello").header("x-amz-tenant-id", "tenant-abc");
            then.status(200).body("OK");
        });

        let options = AdapterOptions {
            host: app_server.host(),
            port: app_server.port().to_string(),
            readiness_check_port: app_server.port().to_string(),
            readiness_check_path: "/".to_string(),
            ..Default::default()
        };

        let adapter = Adapter::new(&options).expect("Failed to create adapter");

        let alb_req = lambda_http::request::LambdaRequest::Alb({
            let mut req = lambda_http::aws_lambda_events::alb::AlbTargetGroupRequest::default();
            req.http_method = Method::GET;
            req.path = Some("/hello".into());
            req
        });
        let mut request = Request::from(alb_req);
        request.extensions_mut().insert(make_lambda_context(Some("tenant-abc")));

        let response = adapter.fetch_response(request).await.expect("Request failed");
        assert_eq!(200, response.status().as_u16());
    }

    #[tokio::test]
    async fn test_tenant_id_header_absent_when_no_tenant() {
        let app_server = MockServer::start();
        app_server.mock(|when, then| {
            when.method(GET)
                .path("/hello")
                .is_true(|req| !req.headers().iter().any(|(k, _)| k == "x-amz-tenant-id"));
            then.status(200).body("OK");
        });

        let options = AdapterOptions {
            host: app_server.host(),
            port: app_server.port().to_string(),
            readiness_check_port: app_server.port().to_string(),
            readiness_check_path: "/".to_string(),
            ..Default::default()
        };

        let adapter = Adapter::new(&options).expect("Failed to create adapter");

        let alb_req = lambda_http::request::LambdaRequest::Alb({
            let mut req = lambda_http::aws_lambda_events::alb::AlbTargetGroupRequest::default();
            req.http_method = Method::GET;
            req.path = Some("/hello".into());
            req
        });
        let mut request = Request::from(alb_req);
        request.extensions_mut().insert(make_lambda_context(None));

        let response = adapter.fetch_response(request).await.expect("Request failed");
        assert_eq!(200, response.status().as_u16());
    }
}