use aws_lambda_events::event::alb::AlbTargetGroupRequest;
use aws_lambda_events::event::apigw::{ApiGatewayProxyRequest, ApiGatewayV2httpRequest};
use bon::Builder;
use lambda_runtime::Context;
use opentelemetry::trace::{Link, Status};
use opentelemetry::Value;
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::fmt::{self, Display};
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use urlencoding;
/// Category of event that started a Lambda invocation.
///
/// The extractors in this module store the lowercase `Display` form of a variant
/// (e.g. `"http"`) in `SpanAttributes::trigger`. `Other` is the default variant.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum TriggerType {
    /// Displayed as `datasource`.
    Datasource,
    /// Displayed as `http`.
    Http,
    /// Displayed as `pubsub`.
    PubSub,
    /// Displayed as `timer`.
    Timer,
    /// Displayed as `other`; this is the `Default`.
    #[default]
    Other,
}

impl Display for TriggerType {
    /// Writes the lowercase name of the trigger category.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Datasource => "datasource",
            Self::Http => "http",
            Self::PubSub => "pubsub",
            Self::Timer => "timer",
            Self::Other => "other",
        };
        f.write_str(label)
    }
}
/// Span metadata produced by a `SpanAttributesExtractor` for one Lambda event.
///
/// Build it with `SpanAttributes::builder()` (generated by `bon`). `attributes`, `links`
/// and `trigger` have builder defaults and the `Option` fields may be left unset, so
/// `SpanAttributes::builder().build()` is valid on its own.
#[derive(Builder)]
pub struct SpanAttributes {
/// Optional span kind, as a string. None of the extractors in this module set it.
/// NOTE(review): presumably interpreted by whoever creates the span — confirm.
pub kind: Option<String>,
/// Span name. The HTTP extractors set `"{method} {path-or-route}"`; the raw JSON
/// extractor leaves it `None`.
pub span_name: Option<String>,
/// Attributes keyed by OpenTelemetry attribute name (e.g. `http.request.method`).
#[builder(default)]
pub attributes: HashMap<String, Value>,
/// Links to attach to the span; empty by default and not populated by the extractors here.
#[builder(default)]
pub links: Vec<Link>,
/// Header name -> value pairs copied from the incoming event. The tests check that
/// `traceparent` and `x-amzn-trace-id` end up here, so this is presumably the
/// trace-context propagation carrier — confirm against the caller.
pub carrier: Option<HashMap<String, String>>,
/// Trigger category in `TriggerType`'s `Display` form (e.g. `"http"`); defaults to `"other"`.
#[builder(default = TriggerType::Other.to_string())]
pub trigger: String,
}
impl Default for SpanAttributes {
fn default() -> Self {
Self {
kind: None,
span_name: None,
attributes: HashMap::new(),
links: Vec::new(),
carrier: None,
trigger: TriggerType::Other.to_string(),
}
}
}
/// Reads the integer `statusCode` field from a handler's JSON response.
///
/// Returns `None` when `response` is not a JSON object, has no `statusCode` key, or
/// the value is not an integer accepted by `as_i64` (e.g. a string or a float).
pub fn get_status_code(response: &JsonValue) -> Option<i64> {
    let object = response.as_object()?;
    let raw_code = object.get("statusCode")?;
    raw_code.as_i64()
}
pub fn set_response_attributes(span: &Span, response: &JsonValue) {
if let Some(status_code) = get_status_code(response) {
span.set_attribute("http.status_code", status_code.to_string());
if status_code >= 500 {
span.set_status(Status::error(format!("HTTP {status_code} response")));
} else {
span.set_status(Status::Ok);
}
span.set_attribute("http.response.status_code", status_code.to_string());
}
}
/// Records invocation-level attributes from the Lambda `context` on `span`.
///
/// Always sets `faas.invocation_id` (the request id) and `cloud.resource_id` (the
/// invoked function ARN). Sets `faas.coldstart = true` only when `is_cold_start` is
/// true. The ARN is split on `:`; if present, field index 4 is recorded as
/// `cloud.account.id` and field index 3 as `cloud.region`.
pub fn set_common_attributes(span: &Span, context: &Context, is_cold_start: bool) {
    let arn: &str = &context.invoked_function_arn;
    span.set_attribute("faas.invocation_id", context.request_id.to_string());
    span.set_attribute("cloud.resource_id", arn.to_string());
    if is_cold_start {
        span.set_attribute("faas.coldstart", true);
    }
    let arn_fields: Vec<&str> = arn.split(':').collect();
    if let Some(account_id) = arn_fields.get(4) {
        span.set_attribute("cloud.account.id", account_id.to_string());
    }
    if let Some(region) = arn_fields.get(3) {
        span.set_attribute("cloud.region", region.to_string());
    }
}
/// Produces `SpanAttributes` describing an incoming Lambda event.
///
/// This module implements it for API Gateway HTTP API (v2) and REST proxy (v1)
/// requests, ALB target-group requests, and raw `serde_json::Value` events.
pub trait SpanAttributesExtractor {
/// Builds the span attributes, carrier and trigger (and, where applicable, span name)
/// for this event.
fn extract_span_attributes(&self) -> SpanAttributes;
}
impl SpanAttributesExtractor for ApiGatewayV2httpRequest {
    /// Extracts span attributes from an API Gateway HTTP API (payload v2) request.
    ///
    /// Records, when the source field is present:
    /// - `http.request.method` (always) from `request_context.http.method`
    /// - `url.path` from `raw_path`
    /// - `url.query` from `raw_query_string` when non-empty
    /// - `network.protocol.version` from `request_context.http.protocol` when it starts
    ///   with `http/` (case-insensitive), e.g. `HTTP/1.1` -> `1.1`
    /// - `url.scheme` (always) as `https`
    /// - `http.route` from `route_key`
    /// - `client.address` from `request_context.http.source_ip`
    /// - `user_agent.original` from the `user-agent` header
    /// - `server.address` from `request_context.domain_name`
    ///
    /// Every header whose value converts with `to_str()` is copied into the carrier.
    /// The span is named `"{method} {raw_path}"` (using `/` when `raw_path` is absent)
    /// and the trigger is `http`.
    fn extract_span_attributes(&self) -> SpanAttributes {
        let mut attributes = HashMap::new();
        let method = self.request_context.http.method.to_string();
        let path = self.raw_path.as_deref().unwrap_or("/");

        attributes.insert(
            "http.request.method".to_string(),
            Value::String(method.clone().into()),
        );
        if let Some(raw_path) = &self.raw_path {
            attributes.insert(
                "url.path".to_string(),
                Value::String(raw_path.to_string().into()),
            );
        }
        if let Some(query) = self.raw_query_string.as_deref().filter(|q| !q.is_empty()) {
            attributes.insert(
                "url.query".to_string(),
                Value::String(query.to_string().into()),
            );
        }
        if let Some(protocol) = &self.request_context.http.protocol {
            let protocol_lower = protocol.to_lowercase();
            if let Some(version) = protocol_lower.strip_prefix("http/") {
                attributes.insert(
                    "network.protocol.version".to_string(),
                    Value::String(version.to_string().into()),
                );
            }
        }
        // API Gateway only exposes HTTPS endpoints, so the scheme is recorded regardless
        // of whether the request context carried a protocol string.
        attributes.insert(
            "url.scheme".to_string(),
            Value::String("https".to_string().into()),
        );
        if let Some(route_key) = &self.route_key {
            attributes.insert(
                "http.route".to_string(),
                Value::String(route_key.to_string().into()),
            );
        }

        // Headers with values that are not representable as `&str` are skipped.
        let carrier: HashMap<String, String> = self
            .headers
            .iter()
            .filter_map(|(k, v)| v.to_str().ok().map(|v| (k.to_string(), v.to_string())))
            .collect();

        if let Some(source_ip) = &self.request_context.http.source_ip {
            attributes.insert(
                "client.address".to_string(),
                Value::String(source_ip.to_string().into()),
            );
        }
        if let Some(user_agent) = self.headers.get("user-agent").and_then(|h| h.to_str().ok()) {
            attributes.insert(
                "user_agent.original".to_string(),
                Value::String(user_agent.to_string().into()),
            );
        }
        if let Some(domain_name) = &self.request_context.domain_name {
            attributes.insert(
                "server.address".to_string(),
                Value::String(domain_name.to_string().into()),
            );
        }

        SpanAttributes::builder()
            .attributes(attributes)
            .carrier(carrier)
            .span_name(format!("{method} {path}"))
            .trigger(TriggerType::Http.to_string())
            .build()
    }
}
impl SpanAttributesExtractor for ApiGatewayProxyRequest {
    /// Extracts span attributes from an API Gateway REST proxy (payload v1) request.
    ///
    /// Records, when the source field is present:
    /// - `http.request.method` (always) from `http_method`
    /// - `url.path` from `path`
    /// - `url.query` rebuilt from `multi_value_query_string_parameters`: each key/value
    ///   pair is percent-encoded and emitted exactly once, joined with `&`
    /// - `network.protocol.version` from `request_context.protocol` when it starts with
    ///   `http/` (case-insensitive), e.g. `HTTP/1.1` -> `1.1`
    /// - `url.scheme` (always) as `https`
    /// - `http.route` (always) from `resource`, or `/` when absent
    /// - `client.address` from `request_context.identity.source_ip`
    /// - `user_agent.original` from the `user-agent` header
    /// - `server.address` from `request_context.domain_name`
    ///
    /// Every header whose value converts with `to_str()` is copied into the carrier.
    /// The span is named `"{method} {route}"` and the trigger is `http`.
    fn extract_span_attributes(&self) -> SpanAttributes {
        let mut attributes = HashMap::new();
        let method = self.http_method.to_string();
        let route = self.resource.as_deref().unwrap_or("/");

        attributes.insert(
            "http.request.method".to_string(),
            Value::String(method.clone().into()),
        );
        if let Some(path) = &self.path {
            attributes.insert(
                "url.path".to_string(),
                Value::String(path.to_string().into()),
            );
        }
        if !self.multi_value_query_string_parameters.is_empty() {
            // `iter()` already yields one `(key, value)` pair per value. Re-querying each
            // yielded key with `all()` would emit every value of a multi-valued key once
            // per value (`a=1&a=2` -> `a=1&a=2&a=1&a=2`), so encode the pairs directly.
            let query = self
                .multi_value_query_string_parameters
                .iter()
                .map(|(key, value)| {
                    format!("{}={}", urlencoding::encode(key), urlencoding::encode(value))
                })
                .collect::<Vec<_>>()
                .join("&");
            if !query.is_empty() {
                attributes.insert("url.query".to_string(), Value::String(query.into()));
            }
        }
        if let Some(protocol) = &self.request_context.protocol {
            let protocol_lower = protocol.to_lowercase();
            if let Some(version) = protocol_lower.strip_prefix("http/") {
                attributes.insert(
                    "network.protocol.version".to_string(),
                    Value::String(version.to_string().into()),
                );
            }
        }
        // API Gateway only exposes HTTPS endpoints, so the scheme is recorded regardless
        // of whether the request context carried a protocol string.
        attributes.insert(
            "url.scheme".to_string(),
            Value::String("https".to_string().into()),
        );
        attributes.insert(
            "http.route".to_string(),
            Value::String(route.to_string().into()),
        );

        // Headers with values that are not representable as `&str` are skipped.
        let carrier: HashMap<String, String> = self
            .headers
            .iter()
            .filter_map(|(k, v)| v.to_str().ok().map(|v| (k.to_string(), v.to_string())))
            .collect();

        if let Some(source_ip) = &self.request_context.identity.source_ip {
            attributes.insert(
                "client.address".to_string(),
                Value::String(source_ip.to_string().into()),
            );
        }
        if let Some(user_agent) = self.headers.get("user-agent").and_then(|h| h.to_str().ok()) {
            attributes.insert(
                "user_agent.original".to_string(),
                Value::String(user_agent.to_string().into()),
            );
        }
        if let Some(domain_name) = &self.request_context.domain_name {
            attributes.insert(
                "server.address".to_string(),
                Value::String(domain_name.to_string().into()),
            );
        }

        SpanAttributes::builder()
            .attributes(attributes)
            .carrier(carrier)
            .span_name(format!("{method} {route}"))
            .trigger(TriggerType::Http.to_string())
            .build()
    }
}
impl SpanAttributesExtractor for AlbTargetGroupRequest {
    /// Extracts span attributes from an Application Load Balancer target-group request.
    ///
    /// Records, when the source field is present:
    /// - `http.request.method` (always) from `http_method`
    /// - `url.path` from `path`
    /// - `url.query` rebuilt from `multi_value_query_string_parameters`: each key/value
    ///   pair is percent-encoded and emitted exactly once, joined with `&`
    /// - `url.scheme` (always) from the first value of the `x-forwarded-proto` header,
    ///   trimmed and lowercased, or `http` when that header is missing or empty
    /// - `network.protocol.version` (always) as `1.1`
    /// - `alb.target_group_arn` from `request_context.elb.target_group_arn`
    /// - `client.address` from the first comma-separated entry of `x-forwarded-for`
    /// - `user_agent.original` from the `user-agent` header
    /// - `server.address` from the `host` header
    ///
    /// Every header whose value converts with `to_str()` is copied into the carrier.
    /// The span is named `"{method} {path}"` (using `/` when `path` is absent) and the
    /// trigger is `http`.
    fn extract_span_attributes(&self) -> SpanAttributes {
        let mut attributes = HashMap::new();
        let method = self.http_method.to_string();
        let route = self.path.as_deref().unwrap_or("/");

        attributes.insert(
            "http.request.method".to_string(),
            Value::String(method.clone().into()),
        );
        if let Some(path) = &self.path {
            attributes.insert(
                "url.path".to_string(),
                Value::String(path.to_string().into()),
            );
        }
        if !self.multi_value_query_string_parameters.is_empty() {
            // `iter()` already yields one `(key, value)` pair per value. Re-querying each
            // yielded key with `all()` would emit every value of a multi-valued key once
            // per value (`a=1&a=2` -> `a=1&a=2&a=1&a=2`), so encode the pairs directly.
            let query = self
                .multi_value_query_string_parameters
                .iter()
                .map(|(key, value)| {
                    format!("{}={}", urlencoding::encode(key), urlencoding::encode(value))
                })
                .collect::<Vec<_>>()
                .join("&");
            if !query.is_empty() {
                attributes.insert("url.query".to_string(), Value::String(query.into()));
            }
        }

        // ALB listeners can terminate either HTTP or HTTPS, and the load balancer reports
        // the client-side protocol in `x-forwarded-proto`; hard-coding `http` mislabels
        // HTTPS traffic. Fall back to `http` only when the header is unavailable.
        let scheme = self
            .headers
            .get("x-forwarded-proto")
            .and_then(|h| h.to_str().ok())
            .and_then(|proto| proto.split(',').next())
            .map(|proto| proto.trim().to_ascii_lowercase())
            .filter(|proto| !proto.is_empty())
            .unwrap_or_else(|| "http".to_string());
        attributes.insert("url.scheme".to_string(), Value::String(scheme.into()));
        attributes.insert(
            "network.protocol.version".to_string(),
            Value::String("1.1".to_string().into()),
        );
        if let Some(target_group_arn) = &self.request_context.elb.target_group_arn {
            attributes.insert(
                "alb.target_group_arn".to_string(),
                Value::String(target_group_arn.to_string().into()),
            );
        }

        // Headers with values that are not representable as `&str` are skipped.
        let carrier: HashMap<String, String> = self
            .headers
            .iter()
            .filter_map(|(k, v)| v.to_str().ok().map(|v| (k.to_string(), v.to_string())))
            .collect();

        // The first `x-forwarded-for` entry is the left-most address in the chain.
        if let Some(forwarded_for) = self
            .headers
            .get("x-forwarded-for")
            .and_then(|h| h.to_str().ok())
        {
            if let Some(client_ip) = forwarded_for.split(',').next() {
                attributes.insert(
                    "client.address".to_string(),
                    Value::String(client_ip.trim().to_string().into()),
                );
            }
        }
        if let Some(user_agent) = self.headers.get("user-agent").and_then(|h| h.to_str().ok()) {
            attributes.insert(
                "user_agent.original".to_string(),
                Value::String(user_agent.to_string().into()),
            );
        }
        if let Some(host) = self.headers.get("host").and_then(|h| h.to_str().ok()) {
            attributes.insert(
                "server.address".to_string(),
                Value::String(host.to_string().into()),
            );
        }

        SpanAttributes::builder()
            .attributes(attributes)
            .carrier(carrier)
            .span_name(format!("{method} {route}"))
            .trigger(TriggerType::Http.to_string())
            .build()
    }
}
impl SpanAttributesExtractor for serde_json::Value {
fn extract_span_attributes(&self) -> SpanAttributes {
let carrier = self
.get("headers")
.and_then(|headers| headers.as_object())
.map(|obj| {
obj.iter()
.filter_map(|(k, v)| v.as_str().map(|v| (k.to_string(), v.to_string())))
.collect()
})
.unwrap_or_default();
SpanAttributes::builder()
.carrier(carrier)
.trigger(TriggerType::Other.to_string())
.build()
}
}
// Unit tests for `SpanAttributes` defaults and the event-specific extractors.
#[cfg(test)]
mod tests {
use super::*;
use aws_lambda_events::http::Method;
// `Default`, an explicit builder trigger, and the builder's own default must agree.
#[test]
fn test_trigger_types() {
let attrs = SpanAttributes::default();
assert_eq!(attrs.trigger, TriggerType::Other.to_string());
let attrs = SpanAttributes::builder()
.trigger(TriggerType::Http.to_string())
.build();
assert_eq!(attrs.trigger, TriggerType::Http.to_string());
let attrs = SpanAttributes::builder().build();
assert_eq!(attrs.trigger, TriggerType::Other.to_string());
}
// API Gateway v2: method, raw path, route key, https scheme and "HTTP/1.1" -> "1.1".
#[test]
fn test_apigw_v2_extraction() {
let mut request = ApiGatewayV2httpRequest::default();
request.raw_path = Some("/test".to_string());
request.route_key = Some("GET /test".to_string());
request.headers = aws_lambda_events::http::HeaderMap::new();
let mut http =
aws_lambda_events::apigw::ApiGatewayV2httpRequestContextHttpDescription::default();
http.method = Method::GET;
http.path = Some("/test".to_string());
http.protocol = Some("HTTP/1.1".to_string());
let mut request_context =
aws_lambda_events::apigw::ApiGatewayV2httpRequestContext::default();
request_context.http = http;
request.request_context = request_context;
let attrs = request.extract_span_attributes();
assert_eq!(
attrs.attributes.get("http.request.method"),
Some(&Value::String("GET".to_string().into()))
);
assert_eq!(
attrs.attributes.get("url.path"),
Some(&Value::String("/test".to_string().into()))
);
assert_eq!(
attrs.attributes.get("http.route"),
Some(&Value::String("GET /test".to_string().into()))
);
assert_eq!(
attrs.attributes.get("url.scheme"),
Some(&Value::String("https".to_string().into()))
);
assert_eq!(
attrs.attributes.get("network.protocol.version"),
Some(&Value::String("1.1".to_string().into()))
);
}
// API Gateway v1: `url.path` comes from `path`, `http.route` from `resource`.
#[test]
fn test_apigw_v1_extraction() {
let mut request = ApiGatewayProxyRequest::default();
request.path = Some("/test".to_string());
request.http_method = Method::GET;
request.resource = Some("/test".to_string());
request.headers = aws_lambda_events::http::HeaderMap::new();
let mut request_context =
aws_lambda_events::apigw::ApiGatewayProxyRequestContext::default();
request_context.protocol = Some("HTTP/1.1".to_string());
request.request_context = request_context;
let attrs = request.extract_span_attributes();
assert_eq!(
attrs.attributes.get("http.request.method"),
Some(&Value::String("GET".to_string().into()))
);
assert_eq!(
attrs.attributes.get("url.path"),
Some(&Value::String("/test".to_string().into()))
);
assert_eq!(
attrs.attributes.get("http.route"),
Some(&Value::String("/test".to_string().into()))
);
assert_eq!(
attrs.attributes.get("url.scheme"),
Some(&Value::String("https".to_string().into()))
);
assert_eq!(
attrs.attributes.get("network.protocol.version"),
Some(&Value::String("1.1".to_string().into()))
);
}
// ALB with no headers: scheme `http`, protocol version `1.1`, and the target group ARN.
#[test]
fn test_alb_extraction() {
let mut request = AlbTargetGroupRequest::default();
request.path = Some("/test".to_string());
request.http_method = Method::GET;
request.headers = aws_lambda_events::http::HeaderMap::new();
let mut elb = aws_lambda_events::alb::ElbContext::default();
elb.target_group_arn = Some("arn:aws:elasticloadbalancing:...".to_string());
let mut request_context = aws_lambda_events::alb::AlbTargetGroupRequestContext::default();
request_context.elb = elb;
request.request_context = request_context;
let attrs = request.extract_span_attributes();
assert_eq!(
attrs.attributes.get("http.request.method"),
Some(&Value::String("GET".to_string().into()))
);
assert_eq!(
attrs.attributes.get("url.path"),
Some(&Value::String("/test".to_string().into()))
);
assert_eq!(
attrs.attributes.get("url.scheme"),
Some(&Value::String("http".to_string().into()))
);
assert_eq!(
attrs.attributes.get("network.protocol.version"),
Some(&Value::String("1.1".to_string().into()))
);
assert_eq!(
attrs.attributes.get("alb.target_group_arn"),
Some(&Value::String(
"arn:aws:elasticloadbalancing:...".to_string().into()
))
);
}
// X-Ray and W3C `traceparent` headers are copied verbatim into the v2 carrier.
#[test]
fn test_xray_header_extraction() {
let mut headers = aws_lambda_events::http::HeaderMap::new();
let xray_header =
"Root=1-58406520-a006649127e371903a2de979;Parent=4c721bf33e3caf8f;Sampled=1";
headers.insert(
"x-amzn-trace-id",
aws_lambda_events::http::HeaderValue::from_str(xray_header).unwrap(),
);
let traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
headers.insert(
"traceparent",
aws_lambda_events::http::HeaderValue::from_str(traceparent).unwrap(),
);
let mut request = ApiGatewayV2httpRequest::default();
request.headers = headers;
request.raw_path = Some("/test".to_string());
request.route_key = Some("GET /test".to_string());
let mut http =
aws_lambda_events::apigw::ApiGatewayV2httpRequestContextHttpDescription::default();
http.method = Method::GET;
http.path = Some("/test".to_string());
http.protocol = Some("HTTP/1.1".to_string());
let mut request_context =
aws_lambda_events::apigw::ApiGatewayV2httpRequestContext::default();
request_context.http = http;
request.request_context = request_context;
let attrs = request.extract_span_attributes();
assert!(attrs.carrier.is_some());
let carrier = attrs.carrier.unwrap();
assert!(carrier.contains_key("x-amzn-trace-id"));
assert_eq!(carrier.get("x-amzn-trace-id").unwrap(), xray_header);
assert!(carrier.contains_key("traceparent"));
assert_eq!(carrier.get("traceparent").unwrap(), traceparent);
}
// Raw JSON events: string-valued entries under `headers` become the carrier.
#[test]
fn test_json_extractor_with_xray_headers() {
let json_value = serde_json::json!({
"headers": {
"x-amzn-trace-id": "Root=1-58406520-a006649127e371903a2de979;Parent=4c721bf33e3caf8f;Sampled=1",
"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
"content-type": "application/json"
},
"body": "{\"message\":\"Hello World\"}",
"requestContext": {
"requestId": "12345"
}
});
let attrs = json_value.extract_span_attributes();
assert!(attrs.carrier.is_some());
let carrier = attrs.carrier.unwrap();
assert!(carrier.contains_key("x-amzn-trace-id"));
assert_eq!(
carrier.get("x-amzn-trace-id").unwrap(),
"Root=1-58406520-a006649127e371903a2de979;Parent=4c721bf33e3caf8f;Sampled=1"
);
assert!(carrier.contains_key("traceparent"));
assert_eq!(
carrier.get("traceparent").unwrap(),
"00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
);
}
}