#![allow(dead_code)]
use super::collector::Collector;
use lambda_runtime::{LambdaEvent, LambdaInvocation};
use pin_project::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tower::Layer;
/// Tower [`Layer`] that wraps a service in a [`MetricsService`] reporting to `collector`.
pub struct MetricsLayer {
    /// Collector handed to every service this layer produces.
    pub(crate) collector: &'static Collector,
}
impl MetricsLayer {
pub fn new(collector: &'static Collector) -> Self {
Self { collector }
}
}
impl<S> Layer<S> for MetricsLayer {
    type Service = MetricsService<S>;

    /// Wraps `inner` in a [`MetricsService`] that shares this layer's collector.
    fn layer(&self, inner: S) -> Self::Service {
        let metrics = self.collector;
        MetricsService { inner, metrics }
    }
}
/// Service that records invocation properties/metrics in `metrics` around calls to `inner`.
pub struct MetricsService<S> {
    /// Collector that receives the properties and metrics of each invocation.
    metrics: &'static Collector,
    /// Wrapped service that actually handles the invocation.
    inner: S,
}
impl<S> MetricsService<S> {
pub fn new<Request, Response>(metrics: &'static Collector, inner: S) -> MetricsService<S>
where
S: tower::Service<LambdaEvent<Request>>,
{
Self { metrics, inner }
}
}
impl<S> tower::Service<LambdaInvocation> for MetricsService<S>
where
    S: tower::Service<LambdaInvocation>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = MetricsServiceFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Records the configured per-invocation properties and forwards `req` to the
    /// inner service.
    ///
    /// On the first call in the process (and only when `lambda_cold_start` is
    /// configured), the collector's cold-start span is taken and the cold-start
    /// counter is written to stdout inside that span. The span is then carried by
    /// the returned future and closed when that invocation completes.
    ///
    /// # Panics
    ///
    /// Panics if the cold-start metric cannot be written to stdout.
    fn call(&mut self, req: LambdaInvocation) -> Self::Future {
        if let Some(prop_name) = self.metrics.config.lambda_request_id {
            self.metrics.set_property(prop_name, req.context.request_id.clone());
        }
        if let Some(prop_name) = self.metrics.config.lambda_xray_trace_id {
            self.metrics.set_property(prop_name, req.context.xray_trace_id.clone());
        }

        let mut cold_start_span = None;
        if let Some(counter_name) = self.metrics.config.lambda_cold_start {
            // A static inside a generic impl is a single item shared by every
            // instantiation, so this fires at most once per process.
            static COLD_START_BEGIN: std::sync::Once = std::sync::Once::new();
            COLD_START_BEGIN.call_once(|| {
                cold_start_span = self.metrics.take_cold_start_span();
                let _entered = cold_start_span.as_ref().map(|span| span.enter());
                self.metrics
                    .write_single(counter_name, Some(metrics::Unit::Count), 1, std::io::stdout())
                    .expect("failed to flush cold start metric");
            });
        }

        // Enter the cold-start span only for the synchronous part of the inner
        // `call`, so anything it creates nests under it. An entered guard must not
        // outlive this function: the future is polled later, possibly interleaved
        // with other tasks on this thread.
        let inner = {
            let _entered = cold_start_span.as_ref().map(|span| span.enter());
            self.inner.call(req)
        };

        MetricsServiceFuture {
            metrics: self.metrics,
            inner,
            cold_start_span,
        }
    }
}

/// Future returned by [`MetricsService`].
///
/// It drives the inner future and flushes the collector to stdout once the
/// invocation resolves. On the first invocation it also holds the cold-start span.
/// That span is re-entered on every poll rather than kept entered, and it is closed
/// when the invocation completes.
#[pin_project]
#[doc(hidden)]
pub struct MetricsServiceFuture<F> {
    // A `&'static` reference is `Unpin`; it needs no structural pinning.
    metrics: &'static Collector,
    #[pin]
    inner: F,
    // Stored un-entered so the future never leaves a span entered across polls.
    cold_start_span: Option<tracing::Span>,
}

impl<F, Response, Error> Future for MetricsServiceFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
{
    type Output = Result<Response, Error>;

    /// Polls the inner future inside the cold-start span, if one is held.
    ///
    /// When the inner future is ready, the collector is flushed to stdout (still
    /// inside the span). The span is then dropped, which ends the cold-start period,
    /// and the inner result is returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics if flushing the metrics to stdout fails.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let result = {
            // Guard lives only for this poll; it is dropped on `Pending` too.
            let _entered = this.cold_start_span.as_ref().map(|span| span.enter());
            let result = match this.inner.poll(cx) {
                Poll::Ready(result) => result,
                Poll::Pending => return Poll::Pending,
            };
            this.metrics.flush(std::io::stdout()).expect("failed to flush metrics");
            result
        };
        // Only the future created by the first invocation ever holds `Some` here.
        drop(this.cold_start_span.take());
        Poll::Ready(result)
    }
}
pub mod service {
    use core::fmt::Debug;
    use futures::Stream;
    use lambda_runtime::{layers::TracingLayer, Diagnostic, IntoFunctionResponse};
    use serde::{Deserialize, Serialize};
    use tower::Service;

    use super::*;

    /// Runs the Lambda runtime loop for the tower service `handler`.
    ///
    /// The runtime is built around `handler`. Lambda's `TracingLayer` is added first,
    /// then a [`MetricsLayer`] reporting to `metrics`. This resolves when the
    /// runtime's `run` future resolves.
    pub async fn run<A, F, R, B, S, D, E>(metrics: &'static Collector, handler: F) -> Result<(), lambda_runtime::Error>
    where
        F: Service<LambdaEvent<A>, Response = R>,
        F::Future: Future<Output = Result<R, F::Error>>,
        F::Error: Into<Diagnostic> + std::fmt::Debug,
        A: for<'de> Deserialize<'de>,
        R: IntoFunctionResponse<B, S>,
        B: Serialize,
        S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
        D: Into<bytes::Bytes> + Send,
        E: Into<lambda_runtime::Error> + Send + Debug,
    {
        let tracing = TracingLayer::new();
        let metrics_layer = MetricsLayer::new(metrics);
        lambda_runtime::Runtime::new(handler)
            .layer(tracing)
            .layer(metrics_layer)
            .run()
            .await
    }

    /// Runs the Lambda runtime loop for an HTTP tower service.
    ///
    /// `handler` is adapted with `lambda_http::Adapter` and passed to [`run`].
    pub async fn run_http<'a, R, S, E>(metrics: &'static Collector, handler: S) -> Result<(), lambda_runtime::Error>
    where
        S: Service<lambda_http::Request, Response = R, Error = E>,
        S::Future: Send + 'a,
        R: lambda_http::IntoResponse,
        E: std::fmt::Debug + Into<Diagnostic>,
    {
        let adapter = lambda_http::Adapter::from(handler);
        run(metrics, adapter).await
    }
}
pub mod handler {
    use lambda_http::service_fn;

    use super::*;

    /// Runs the Lambda runtime loop for a plain async handler function.
    ///
    /// `handler` is turned into a service with `lambda_runtime::service_fn` and
    /// passed to `super::service::run`.
    pub async fn run<T, F, Request, Response>(
        metrics: &'static Collector,
        handler: T,
    ) -> Result<(), lambda_runtime::Error>
    where
        T: FnMut(LambdaEvent<Request>) -> F,
        F: Future<Output = Result<Response, lambda_runtime::Error>>,
        Request: for<'de> serde::Deserialize<'de>,
        Response: serde::Serialize,
    {
        let service = lambda_runtime::service_fn(handler);
        super::service::run(metrics, service).await
    }

    /// Runs the Lambda runtime loop for an async HTTP handler function.
    ///
    /// `handler` is wrapped with `lambda_http::service_fn`, adapted with
    /// `lambda_http::Adapter`, and passed to `super::service::run`.
    pub async fn run_http<'a, T, F, Response>(
        metrics: &'static Collector,
        handler: T,
    ) -> Result<(), lambda_runtime::Error>
    where
        T: FnMut(lambda_http::Request) -> F,
        F: Future<Output = Result<Response, lambda_runtime::Error>> + Send + 'a,
        Response: lambda_http::IntoResponse,
    {
        let service = service_fn(handler);
        let adapter = lambda_http::Adapter::from(service);
        super::service::run(metrics, adapter).await
    }
}