use std::any::Any;
use std::backtrace::{Backtrace, BacktraceStatus};
use std::cell::RefCell;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::{Arc, Once};
use std::task::{Context, Poll};
use axum::extract::MatchedPath;
use axum::http::{Request, StatusCode};
use axum::response::{IntoResponse, Response};
use futures::FutureExt;
use pin_project_lite::pin_project;
use tower::{Layer, Service};
use crate::middleware::RequestId;
use crate::middleware::exception_filter::AutumnErrorInfo;
/// Boxed, `Send` future returned by [`ErrorReporter::report`].
///
/// The `'a` lifetime lets the future borrow both the reporter and the event
/// it is reporting.
pub type ReportFuture<'a> = Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
/// A server-side failure observed by the reporting middleware and handed to
/// every registered [`ErrorReporter`].
///
/// In this module, events are built for responses whose status is a 5xx and
/// for handler panics (which are converted into a 500 response).
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ErrorEvent {
/// HTTP status of the response; always `500 Internal Server Error` for panics.
pub status: StatusCode,
/// Human-readable message: the `AutumnErrorInfo` message when the response
/// carried one, otherwise the status's canonical reason, or the formatted
/// panic payload for panics.
pub message: String,
/// Problem-type identifier copied from `AutumnErrorInfo`, when present.
pub problem_type: Option<String>,
/// Value of the request's `RequestId` extension, when present.
pub request_id: Option<String>,
/// Route pattern from axum's `MatchedPath` extension, when present.
pub route: Option<String>,
/// HTTP method of the request.
pub method: Option<String>,
/// Panic details; `Some` only for events produced by a caught panic.
pub panic: Option<PanicInfo>,
}
/// Details of a caught handler panic, attached to an [`ErrorEvent`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PanicInfo {
/// Panic payload rendered as text. `&str` and `String` payloads are used
/// verbatim; any other payload type becomes `"handler panicked"`.
pub payload: String,
/// Formatted backtrace. Present only when `Backtrace::capture` actually
/// captured one, i.e. when backtraces are enabled through the
/// `RUST_LIB_BACKTRACE` / `RUST_BACKTRACE` environment variables.
pub backtrace: Option<String>,
}
/// A sink for [`ErrorEvent`]s.
///
/// Reporters are run one after another, in registration order, from a task
/// spawned on the current Tokio runtime. The request path therefore only pays
/// for spawning that task, not for the reporting itself.
///
/// A panic raised while building the returned future, or while it is polled,
/// is caught and logged as a warning. The remaining reporters still run. If no
/// Tokio runtime is current when an event is dispatched, the event is dropped.
pub trait ErrorReporter: Send + Sync + 'static {
/// Report `event`. The returned future is awaited to completion before the
/// next reporter in the chain is invoked.
fn report<'a>(&'a self, event: &'a ErrorEvent) -> ReportFuture<'a>;
}
/// Default [`ErrorReporter`] that emits each event as a `tracing::error!` record.
///
/// `ReportingLayer::new` installs it automatically when no reporters are supplied.
#[derive(Debug, Clone, Default)]
pub struct LogReporter;
impl ErrorReporter for LogReporter {
    /// Logs `event` at `error` level.
    ///
    /// Missing method/route/request-id values are rendered as `"-"`.
    /// - Panic events also carry the captured backtrace, or a hint on how to
    ///   enable capture.
    /// - All other events carry the problem type.
    fn report<'a>(&'a self, event: &'a ErrorEvent) -> ReportFuture<'a> {
        Box::pin(async move {
            let method = event.method.as_deref().unwrap_or("-");
            let route = event.route.as_deref().unwrap_or("-");
            let request_id = event.request_id.as_deref().unwrap_or("-");

            match &event.panic {
                Some(panic) => {
                    let backtrace = panic
                        .backtrace
                        .as_deref()
                        .unwrap_or("(set RUST_BACKTRACE=1 to capture)");
                    tracing::error!(
                        status = %event.status,
                        method = method,
                        route = route,
                        request_id = request_id,
                        backtrace = backtrace,
                        "handler panic captured: {}",
                        panic.payload
                    );
                }
                None => {
                    let problem_type = event.problem_type.as_deref().unwrap_or("-");
                    tracing::error!(
                        status = %event.status,
                        method = method,
                        route = route,
                        request_id = request_id,
                        problem_type = problem_type,
                        "server error captured: {}",
                        event.message
                    );
                }
            }
        })
    }
}
/// Crate-internal collection of registered [`ErrorReporter`]s.
// NOTE(review): where this is populated and consumed is not visible in this
// file — presumably it is what gets passed to `ReportingLayer::new`; confirm.
#[derive(Clone, Default)]
pub(crate) struct RegisteredReporters(pub(crate) Vec<Arc<dyn ErrorReporter>>);
/// The configured reporters plus the gating settings checked before any of
/// them runs. Shared through an `Arc` by the layer, services and futures.
struct ReporterChain {
/// Reporters invoked sequentially for every dispatched event.
reporters: Vec<Arc<dyn ErrorReporter>>,
/// When `false`, `dispatch` drops every event.
enabled: bool,
/// Fraction of events delivered: `>= 1.0` keeps all, `<= 0.0` drops all,
/// values in between are sampled randomly (see `sampled`).
sample_rate: f64,
}
impl ReporterChain {
fn dispatch(self: &Arc<Self>, event: ErrorEvent) {
if !self.enabled || !sampled(self.sample_rate) {
return;
}
if let Ok(handle) = tokio::runtime::Handle::try_current() {
let chain = Arc::clone(self);
handle.spawn(async move {
chain.report_all(&event).await;
});
}
}
async fn report_all(&self, event: &ErrorEvent) {
for reporter in &self.reporters {
match std::panic::catch_unwind(AssertUnwindSafe(|| reporter.report(event))) {
Ok(future) => {
if AssertUnwindSafe(future).catch_unwind().await.is_err() {
tracing::warn!("error reporter panicked while reporting; ignoring");
}
}
Err(_panic) => {
tracing::warn!("error reporter panicked constructing report future; ignoring");
}
}
}
}
}
// Per-thread state of the xorshift generator behind `next_u64`. Each thread
// seeds it lazily from `seed_rng` on first access, which never yields 0.
thread_local! {
static RNG_STATE: std::cell::Cell<u64> = std::cell::Cell::new(seed_rng());
}
/// Produces a non-zero seed for the per-thread sampling RNG.
///
/// Reads 8 bytes of OS randomness through `getrandom`. Falls back to a fixed
/// non-zero constant when that call fails or happens to return all zeros:
/// an xorshift state of zero would stay zero forever.
fn seed_rng() -> u64 {
    const FALLBACK_SEED: u64 = 0x5555_5555_5555_5555;
    let mut bytes = [0u8; 8];
    match getrandom::getrandom(&mut bytes) {
        Ok(_) => match u64::from_ne_bytes(bytes) {
            0 => FALLBACK_SEED,
            seed => seed,
        },
        Err(_) => FALLBACK_SEED,
    }
}
/// Advances this thread's xorshift state (shift triple 13, 7, 17) and returns
/// the new value.
///
/// Not cryptographically secure. In this file it only feeds `sampled`.
fn next_u64() -> u64 {
    RNG_STATE.with(|state| {
        let prev = state.get();
        let mixed_left = prev ^ (prev << 13);
        let mixed_right = mixed_left ^ (mixed_left >> 7);
        let next = mixed_right ^ (mixed_right << 17);
        state.set(next);
        next
    })
}
/// Decides whether an event is kept under sampling `rate`.
///
/// - `rate >= 1.0` always keeps and `rate <= 0.0` always drops; neither
///   consumes a random draw.
/// - Any other rate keeps the event with probability `rate`, using a uniform
///   value in `[0, 1)` built from the top 53 bits of `next_u64`.
/// - A NaN rate fails both bounds checks and its comparison, so it never keeps.
// `next_u64() >> 11` has at most 53 significant bits and 2^53 is exactly
// representable, so these `as f64` casts are lossless; the lint is a false positive.
#[allow(clippy::cast_precision_loss)]
fn sampled(rate: f64) -> bool {
    match rate {
        r if r >= 1.0 => true,
        r if r <= 0.0 => false,
        r => {
            let numerator = (next_u64() >> 11) as f64;
            let denominator = (1u64 << 53) as f64;
            numerator / denominator < r
        }
    }
}
// Most recent panic observed on this thread, written by the hook installed in
// `ensure_panic_hook` and consumed (via `take`) by `handle_panic`. Because the
// hook is process-wide, panics caught elsewhere also land here until the next
// panic on this thread overwrites them.
thread_local! {
static LAST_PANIC: RefCell<Option<CapturedPanic>> = const { RefCell::new(None) };
}
/// What the panic hook records about a panic; the payload itself is obtained
/// from `catch_unwind`, not from here.
struct CapturedPanic {
/// Formatted backtrace, or `None` when backtrace capture is disabled.
backtrace: Option<String>,
}
/// Guards `ensure_panic_hook` so the hook is installed at most once per process.
static HOOK_INSTALLED: Once = Once::new();
/// Installs, at most once per process, a panic hook that records the
/// panicking thread's backtrace in `LAST_PANIC`.
///
/// After recording, the hook delegates to whatever hook was installed before
/// it (by default, the standard one that prints the panic message). The hook
/// is global, so it runs for every panic in the process, not only handler
/// panics.
///
/// The backtrace string is stored only when `Backtrace::capture` actually
/// captured one (backtraces enabled via `RUST_LIB_BACKTRACE` /
/// `RUST_BACKTRACE`); otherwise `None` is stored.
fn ensure_panic_hook() {
    HOOK_INSTALLED.call_once(|| {
        let previous = std::panic::take_hook();
        std::panic::set_hook(Box::new(move |info| {
            let backtrace = Backtrace::capture();
            let backtrace =
                (backtrace.status() == BacktraceStatus::Captured).then(|| backtrace.to_string());
            // A panic raised *inside* a panic hook aborts the whole process. So
            // never use the panicking accessors here: `with` panics if the
            // thread-local is being or has been destroyed (e.g. a panic in a TLS
            // destructor), and `borrow_mut` panics if the panic happened while the
            // slot was borrowed. In those cases the backtrace is simply not recorded.
            let _ = LAST_PANIC.try_with(|cell| {
                if let Ok(mut slot) = cell.try_borrow_mut() {
                    *slot = Some(CapturedPanic { backtrace });
                }
            });
            previous(info);
        }));
    });
}
/// Renders a panic payload as text.
///
/// `panic!` produces either a `&'static str` (literal message) or a `String`
/// (formatted message). Both are returned verbatim; any other payload type
/// yields the generic `"handler panicked"`.
fn format_panic_payload(payload: &(dyn Any + Send)) -> String {
    if let Some(text) = payload.downcast_ref::<&str>() {
        return (*text).to_string();
    }
    match payload.downcast_ref::<String>() {
        Some(text) => text.clone(),
        None => String::from("handler panicked"),
    }
}
/// Request metadata captured in `ReportingService::call` before the request is
/// moved into the inner service, so it is still available when reporting.
#[derive(Clone)]
struct RequestContext {
/// Request method as text (e.g. `GET`).
method: String,
/// Route pattern from axum's `MatchedPath` extension, if present.
route: Option<String>,
/// Request id from the `RequestId` extension, if present.
request_id: Option<String>,
}
/// Tower [`Layer`] that wraps services in a [`ReportingService`].
///
/// Cloning the layer, or layering it onto several services, shares a single
/// reporter chain through an `Arc`.
#[derive(Clone)]
pub struct ReportingLayer {
/// Reporters and gating settings shared with every produced service.
chain: Arc<ReporterChain>,
}
impl ReportingLayer {
    /// Builds the layer from the configured reporters and gating settings.
    ///
    /// Also makes sure the process-wide panic hook that captures backtraces is
    /// installed (idempotent). When `reporters` is empty, a single
    /// [`LogReporter`] is used so events are never silently discarded.
    #[must_use]
    pub(crate) fn new(
        mut reporters: Vec<Arc<dyn ErrorReporter>>,
        enabled: bool,
        sample_rate: f64,
    ) -> Self {
        ensure_panic_hook();
        if reporters.is_empty() {
            reporters.push(Arc::new(LogReporter));
        }
        let chain = ReporterChain {
            reporters,
            enabled,
            sample_rate,
        };
        Self {
            chain: Arc::new(chain),
        }
    }
}
impl<S> Layer<S> for ReportingLayer {
    type Service = ReportingService<S>;

    /// Wraps `inner`, sharing this layer's reporter chain with the new service.
    fn layer(&self, inner: S) -> Self::Service {
        let chain = Arc::clone(&self.chain);
        ReportingService { inner, chain }
    }
}
/// Tower [`Service`] that reports 5xx responses to the reporter chain and
/// converts panics in the wrapped service into `500` responses.
#[derive(Clone)]
pub struct ReportingService<S> {
/// The wrapped service.
inner: S,
/// Reporters and gating settings shared with the originating layer.
chain: Arc<ReporterChain>,
}
impl<S, ReqBody> Service<Request<ReqBody>> for ReportingService<S>
where
S: Service<Request<ReqBody>, Response = Response>,
{
type Response = Response;
type Error = S::Error;
type Future = ReportingFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let method = req.method().as_str().to_owned();
let route = req
.extensions()
.get::<MatchedPath>()
.map(|m| m.as_str().to_owned());
let request_id = req
.extensions()
.get::<RequestId>()
.map(std::string::ToString::to_string);
let context = Some(RequestContext {
method,
route,
request_id,
});
let inner = &mut self.inner;
match std::panic::catch_unwind(AssertUnwindSafe(|| inner.call(req))) {
Ok(future) => ReportingFuture {
inner: Some(future),
pending_panic: None,
context,
chain: Arc::clone(&self.chain),
},
Err(panic) => ReportingFuture {
inner: None,
pending_panic: Some(panic),
context,
chain: Arc::clone(&self.chain),
},
}
}
}
pin_project! {
/// Response future produced by [`ReportingService`].
///
/// Resolves to the inner service's result. Along the way it reports 5xx
/// responses and converts panics (raised in `call` or while polling) into
/// `500` responses.
pub struct ReportingFuture<F> {
// Future returned by the inner service; `None` when the inner `call` panicked.
#[pin]
inner: Option<F>,
// Panic payload caught in `call`, consumed on the first poll.
pending_panic: Option<Box<dyn Any + Send>>,
// Request metadata, moved out (`take`) when the future yields a response or
// catches a panic, so it is used at most once.
context: Option<RequestContext>,
// Reporter chain that events are dispatched to.
chain: Arc<ReporterChain>,
}
}
impl<F, E> Future for ReportingFuture<F>
where
    F: Future<Output = Result<Response, E>>,
{
    type Output = Result<Response, E>;

    /// Drives the inner future and catches panics from it.
    ///
    /// - A panic stored by `call` becomes a `500` response on the first poll.
    /// - A successful response is passed to `report_response` (which only
    ///   reports 5xx statuses) and returned unchanged.
    /// - An `Err` from the inner service is propagated without reporting.
    /// - A panic while polling becomes a reported `500` response.
    ///
    /// Once the stored-panic case has completed there is no inner future, so
    /// any further poll returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();

        if let Some(payload) = projected.pending_panic.take() {
            let response = handle_panic(&*payload, projected.context.take(), projected.chain);
            return Poll::Ready(Ok(response));
        }

        let inner = match projected.inner.as_pin_mut() {
            Some(inner) => inner,
            None => return Poll::Pending,
        };

        let polled = std::panic::catch_unwind(AssertUnwindSafe(move || inner.poll(cx)));
        match polled {
            Err(payload) => {
                let response = handle_panic(&*payload, projected.context.take(), projected.chain);
                Poll::Ready(Ok(response))
            }
            Ok(Poll::Pending) => Poll::Pending,
            Ok(Poll::Ready(Err(error))) => Poll::Ready(Err(error)),
            Ok(Poll::Ready(Ok(response))) => {
                if let Some(context) = projected.context.take() {
                    report_response(&response, context, projected.chain);
                }
                Poll::Ready(Ok(response))
            }
        }
    }
}
fn report_response(response: &Response, context: RequestContext, chain: &Arc<ReporterChain>) {
if !response.status().is_server_error() {
return;
}
let info = response.extensions().get::<AutumnErrorInfo>();
let (message, problem_type) = info.map_or_else(
|| {
(
response
.status()
.canonical_reason()
.unwrap_or("server error")
.to_owned(),
None,
)
},
|info| (info.message.clone(), info.problem_type.map(str::to_owned)),
);
chain.dispatch(ErrorEvent {
status: response.status(),
message,
problem_type,
request_id: context.request_id,
route: context.route,
method: Some(context.method),
panic: None,
});
}
fn handle_panic(
payload: &(dyn Any + Send),
context: Option<RequestContext>,
chain: &Arc<ReporterChain>,
) -> Response {
let message = format_panic_payload(payload);
let backtrace = LAST_PANIC
.with(|cell| cell.borrow_mut().take())
.and_then(|captured| captured.backtrace);
if let Some(context) = context {
chain.dispatch(ErrorEvent {
status: StatusCode::INTERNAL_SERVER_ERROR,
message: message.clone(),
problem_type: None,
request_id: context.request_id,
route: context.route,
method: Some(context.method),
panic: Some(PanicInfo {
payload: message,
backtrace,
}),
});
}
crate::error::AutumnError::internal_server_error_msg("Internal server error").into_response()
}
// Unit tests for sampling, payload formatting, reporter-chain gating and
// panic isolation, and the layer's panic-to-500 conversion.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
// Rates at or beyond the bounds short-circuit without consulting the RNG.
#[test]
fn sampled_extremes_are_deterministic() {
assert!(sampled(1.0));
assert!(sampled(2.0));
assert!(!sampled(0.0));
assert!(!sampled(-1.0));
}
#[test]
fn sampled_full_rate_always_true_over_many_draws() {
for _ in 0..1000 {
assert!(sampled(1.0));
}
}
// `&str` and `String` payloads are rendered verbatim; other types get a generic message.
#[test]
fn format_panic_payload_handles_str_and_string() {
let s: &str = "boom";
assert_eq!(format_panic_payload(&s), "boom");
let owned: String = "kaboom".to_owned();
assert_eq!(format_panic_payload(&owned), "kaboom");
let other: u32 = 7;
assert_eq!(format_panic_payload(&other), "handler panicked");
}
// An empty reporter list is replaced by a single `LogReporter`.
#[test]
fn log_reporter_is_the_default_when_empty() {
let layer = ReportingLayer::new(Vec::new(), true, 1.0);
assert_eq!(layer.chain.reporters.len(), 1);
}
// A panic raised synchronously inside the inner service's `call` must surface
// as a 500 response, not propagate out of the middleware.
#[tokio::test]
async fn panic_in_inner_call_is_caught_as_500() {
use axum::body::Body;
use std::convert::Infallible;
use tower::ServiceExt;
#[derive(Clone)]
struct PanicInCall;
impl Service<Request<Body>> for PanicInCall {
type Response = Response;
type Error = Infallible;
type Future = std::future::Ready<Result<Response, Infallible>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: Request<Body>) -> Self::Future {
panic!("boom in call");
}
}
let service = ReportingLayer::new(Vec::new(), true, 1.0).layer(PanicInCall);
let response = service
.oneshot(Request::new(Body::empty()))
.await
.expect("panic in call must be converted to a response, not propagated");
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
// A disabled chain must not spawn any reporting. The sleep gives a wrongly
// spawned task time to run before the counter is checked.
#[tokio::test]
async fn disabled_chain_does_not_dispatch() {
#[derive(Clone)]
struct Counter(Arc<Mutex<u32>>);
impl ErrorReporter for Counter {
fn report<'a>(&'a self, _event: &'a ErrorEvent) -> ReportFuture<'a> {
let count = self.0.clone();
Box::pin(async move {
*count.lock().unwrap() += 1;
})
}
}
let count = Arc::new(Mutex::new(0));
let chain = Arc::new(ReporterChain {
reporters: vec![Arc::new(Counter(count.clone()))],
enabled: false,
sample_rate: 1.0,
});
chain.dispatch(ErrorEvent {
status: StatusCode::INTERNAL_SERVER_ERROR,
message: "x".into(),
problem_type: None,
request_id: None,
route: None,
method: None,
panic: None,
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert_eq!(*count.lock().unwrap(), 0);
}
// Fixture: a fully populated non-panic 5xx event.
fn server_error_event() -> ErrorEvent {
ErrorEvent {
status: StatusCode::INTERNAL_SERVER_ERROR,
message: "boom".into(),
problem_type: Some("https://autumn.dev/problems/x".into()),
request_id: Some("req-1".into()),
route: Some("/x".into()),
method: Some("GET".into()),
panic: None,
}
}
// Fixture: a panic event with a placeholder backtrace and no request metadata.
fn panic_event() -> ErrorEvent {
ErrorEvent {
status: StatusCode::INTERNAL_SERVER_ERROR,
message: "kaboom".into(),
problem_type: None,
request_id: None,
route: None,
method: None,
panic: Some(PanicInfo {
payload: "kaboom".into(),
backtrace: Some("<backtrace>".into()),
}),
}
}
// Smoke test: both logging branches run without panicking (output is not asserted).
#[tokio::test]
async fn log_reporter_reports_both_event_kinds() {
let reporter = LogReporter;
reporter.report(&server_error_event()).await;
reporter.report(&panic_event()).await;
}
// With rate 0.5 over 10k draws, both outcomes should appear; this is
// probabilistic but failure is astronomically unlikely.
#[test]
fn sampled_fractional_uses_prng_and_varies() {
let mut trues = 0;
for _ in 0..10_000 {
if sampled(0.5) {
trues += 1;
}
}
assert!(
trues > 0 && trues < 10_000,
"fractional sampling should produce a mix of decisions, got {trues}"
);
}
// A reporter that panics before returning its future must not take down `report_all`.
#[tokio::test]
async fn reporter_panicking_while_constructing_future_is_swallowed() {
struct PanicOnConstruct;
impl ErrorReporter for PanicOnConstruct {
fn report<'a>(&'a self, _event: &'a ErrorEvent) -> ReportFuture<'a> {
panic!("panic before returning the future");
}
}
let chain = ReporterChain {
reporters: vec![Arc::new(PanicOnConstruct)],
enabled: true,
sample_rate: 1.0,
};
chain.report_all(&server_error_event()).await;
}
// A plain `#[test]` has no Tokio runtime, so `Handle::try_current` fails and
// `dispatch` must return quietly instead of panicking.
#[test]
fn dispatch_without_a_runtime_is_a_noop() {
let chain = Arc::new(ReporterChain {
reporters: vec![Arc::new(LogReporter)],
enabled: true,
sample_rate: 1.0,
});
chain.dispatch(server_error_event());
}
}