use crate::cold_start::check_cold_start;
use crate::extractor::TraceContextExtractor;
use crate::future::OtelTracingFuture;
use lambda_runtime::LambdaEvent;
use opentelemetry_sdk::logs::SdkLoggerProvider;
use opentelemetry_sdk::trace::SdkTracerProvider;
use opentelemetry_semantic_conventions::attribute::{
CLIENT_ADDRESS, CLOUD_ACCOUNT_ID, CLOUD_PROVIDER, CLOUD_REGION, FAAS_MAX_MEMORY, FAAS_NAME,
FAAS_VERSION, HTTP_REQUEST_METHOD, HTTP_ROUTE, MESSAGING_BATCH_MESSAGE_COUNT,
MESSAGING_DESTINATION_NAME, MESSAGING_MESSAGE_ID, MESSAGING_OPERATION_TYPE, MESSAGING_SYSTEM,
NETWORK_PROTOCOL_VERSION, SERVER_ADDRESS, URL_PATH, URL_QUERY, URL_SCHEME, USER_AGENT_ORIGINAL,
};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tower::Service;
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Service wrapper that runs an inner Lambda handler service inside a
/// per-invocation tracing span.
///
/// `S` is the wrapped service and `E` is the extractor used to derive trace
/// context, span links, the span name, the trigger type and extra span
/// attributes from the event payload.
#[derive(Clone)]
pub struct OtelTracingService<S, E> {
// The wrapped service; readiness polling and calls are forwarded to it.
inner: S,
// Source of the parent context, links, span name, trigger type and
// payload-specific attributes for each invocation.
extractor: E,
// Optional shared tracer provider, cloned into every returned future.
// NOTE(review): presumably used to flush spans — confirm in `OtelTracingFuture`.
tracer_provider: Option<Arc<SdkTracerProvider>>,
// Optional shared logger provider, cloned into every returned future.
// NOTE(review): presumably used to flush logs — confirm in `OtelTracingFuture`.
logger_provider: Option<Arc<SdkLoggerProvider>>,
// Passed through unchanged to `OtelTracingFuture`.
// NOTE(review): presumably enables flushing when an invocation ends — confirm.
flush_on_end: bool,
// Passed through unchanged to `OtelTracingFuture`.
// NOTE(review): presumably bounds how long a flush may take — confirm.
flush_timeout: Duration,
}
impl<S, E> OtelTracingService<S, E> {
pub(crate) fn new(
inner: S,
extractor: E,
tracer_provider: Option<Arc<SdkTracerProvider>>,
logger_provider: Option<Arc<SdkLoggerProvider>>,
flush_on_end: bool,
flush_timeout: Duration,
) -> Self {
Self {
inner,
extractor,
tracer_provider,
logger_provider,
flush_on_end,
flush_timeout,
}
}
}
/// Tower `Service` implementation that opens a `lambda.invoke` span for each
/// event and drives the wrapped service inside it.
impl<S, E, T> Service<LambdaEvent<T>> for OtelTracingService<S, E>
where
    S: Service<LambdaEvent<T>>,
    S::Error: std::fmt::Display,
    E: TraceContextExtractor<T>,
    T: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = OtelTracingFuture<S::Future, S::Response, S::Error>;

    /// Readiness is decided solely by the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Creates the invocation span, attaches parent context, links and
    /// attributes, then calls the wrapped service while the span is entered.
    ///
    /// The inner future is returned wrapped in an `OtelTracingFuture` together
    /// with the span, the optional providers and the flush settings.
    fn call(&mut self, event: LambdaEvent<T>) -> Self::Future {
        // Split the event so the extractor can inspect the payload; it is
        // reassembled below before being passed to the wrapped service.
        let (body, invocation) = event.into_parts();

        let upstream_context = self.extractor.extract_context(&body);
        let span_links = self.extractor.extract_links(&body);
        let cold = check_cold_start();
        let display_name = self.extractor.span_name(&body, &invocation);

        // Fields that are filled in later must be declared up front as
        // `Empty`; `Span::record` cannot add fields that were not declared.
        let span = tracing::info_span!(
            "lambda.invoke",
            otel.name = %display_name,
            otel.kind = "server",
            faas.trigger = %self.extractor.trigger_type(),
            faas.invocation_id = %invocation.request_id,
            faas.coldstart = cold,
            { HTTP_REQUEST_METHOD } = tracing::field::Empty,
            { URL_PATH } = tracing::field::Empty,
            { URL_QUERY } = tracing::field::Empty,
            { URL_SCHEME } = tracing::field::Empty,
            { HTTP_ROUTE } = tracing::field::Empty,
            { USER_AGENT_ORIGINAL } = tracing::field::Empty,
            { CLIENT_ADDRESS } = tracing::field::Empty,
            { SERVER_ADDRESS } = tracing::field::Empty,
            { NETWORK_PROTOCOL_VERSION } = tracing::field::Empty,
            { MESSAGING_SYSTEM } = tracing::field::Empty,
            { MESSAGING_OPERATION_TYPE } = tracing::field::Empty,
            { MESSAGING_DESTINATION_NAME } = tracing::field::Empty,
            { MESSAGING_MESSAGE_ID } = tracing::field::Empty,
            { MESSAGING_BATCH_MESSAGE_COUNT } = tracing::field::Empty,
            { CLOUD_PROVIDER } = tracing::field::Empty,
            { CLOUD_REGION } = tracing::field::Empty,
            { CLOUD_ACCOUNT_ID } = tracing::field::Empty,
            { FAAS_NAME } = tracing::field::Empty,
            { FAAS_VERSION } = tracing::field::Empty,
            { FAAS_MAX_MEMORY } = tracing::field::Empty,
            aws.lambda.invoked_arn = tracing::field::Empty,
        );

        // The outcome of attaching the parent is deliberately ignored.
        let _ = span.set_parent(upstream_context);
        span_links.into_iter().for_each(|link| {
            span.add_link(link.span_context.clone());
        });

        self.extractor.record_attributes(&body, &span);
        record_lambda_context_attributes(&span, &invocation);

        // Only the synchronous `call` runs with the span entered here; the
        // span itself travels with the returned future.
        let rebuilt = LambdaEvent::new(body, invocation);
        let inner_future = span.in_scope(|| self.inner.call(rebuilt));

        OtelTracingFuture::new(
            inner_future,
            span,
            self.tracer_provider.clone(),
            self.logger_provider.clone(),
            self.flush_on_end,
            self.flush_timeout,
        )
    }
}
/// Fills in the cloud and FaaS fields of `span` from the Lambda runtime context.
///
/// Recorded values:
/// - cloud provider, always `"aws"`;
/// - function name and version from `ctx.env_config`;
/// - maximum memory, computed as `env_config.memory * 1024 * 1024`;
/// - cloud region, only when the `AWS_REGION` environment variable is readable;
/// - the invoked function ARN, and the fifth `:`-separated segment of that
///   ARN as the account id when such a segment exists.
fn record_lambda_context_attributes(span: &Span, ctx: &lambda_runtime::Context) {
    let config = &ctx.env_config;
    let arn = ctx.invoked_function_arn.as_str();

    // Widen before multiplying so the scaling is done in i64.
    let max_memory = config.memory as i64 * 1024 * 1024;

    span.record(CLOUD_PROVIDER, "aws")
        .record(FAAS_NAME, config.function_name.as_str())
        .record(FAAS_VERSION, config.version.as_str())
        .record(FAAS_MAX_MEMORY, max_memory);

    // A missing or non-unicode variable leaves the region field unset.
    let region_var = std::env::var("AWS_REGION");
    if let Ok(region) = region_var.as_deref() {
        span.record(CLOUD_REGION, region);
    }

    span.record("aws.lambda.invoked_arn", arn);

    // Segment index 4 of the colon-separated ARN is taken as the account id.
    let account_segment = arn.split(':').nth(4);
    if let Some(account_id) = account_segment {
        span.record(CLOUD_ACCOUNT_ID, account_id);
    }
}