1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::sync::Arc;
10
11use axum::extract::{FromRequestParts, MatchedPath, State};
12use axum::http::request::Parts;
13use axum::http::{Request, StatusCode};
14use axum::middleware::Next;
15use axum::response::IntoResponse;
16use tailtriage_core::{Outcome, OwnedRequestHandle, RequestOptions, Tailtriage};
17
/// Returns this crate's package name, `"tailtriage-axum"`.
///
/// Because it is a `const fn`, the name can also be used in constant contexts.
#[must_use]
pub const fn crate_name() -> &'static str {
    const NAME: &str = "tailtriage-axum";
    NAME
}
23
24pub async fn middleware(
29 State(tailtriage): State<Arc<Tailtriage>>,
30 mut request: Request<axum::body::Body>,
31 next: Next,
32) -> axum::response::Response {
33 let route = request_route_label(&request);
34 let started = tailtriage.begin_request_with_owned(route, RequestOptions::new().kind("http"));
35
36 request
37 .extensions_mut()
38 .insert(TailtriageRequest(started.handle.clone()));
39
40 let response = next.run(request).await;
41 let status = response.status();
42
43 started.completion.finish(status_to_outcome(status));
44 response
45}
46
47#[derive(Debug, Clone)]
49pub struct TailtriageRequest(pub OwnedRequestHandle);
50
51impl TailtriageRequest {
52 #[must_use]
54 pub fn into_inner(self) -> OwnedRequestHandle {
55 self.0
56 }
57}
58
59impl<S> FromRequestParts<S> for TailtriageRequest
60where
61 S: Send + Sync,
62{
63 type Rejection = TailtriageExtractorError;
64
65 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
66 parts
67 .extensions
68 .get::<TailtriageRequest>()
69 .cloned()
70 .ok_or(TailtriageExtractorError)
71 }
72}
73
/// Rejection produced by the [`TailtriageRequest`] extractor when the request carries no
/// [`TailtriageRequest`] extension, which usually means [`middleware`] did not run for it.
///
/// Besides converting into an HTTP response, it implements [`std::fmt::Display`] and
/// [`std::error::Error`], so it can be logged with `{}` or boxed as a `dyn Error`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TailtriageExtractorError;

impl std::fmt::Display for TailtriageExtractorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep in sync with the body text produced by the `IntoResponse` impl for this type.
        f.write_str("tailtriage extractor missing. Add tailtriage_axum::middleware.")
    }
}

impl std::error::Error for TailtriageExtractorError {}
77
78impl IntoResponse for TailtriageExtractorError {
79 fn into_response(self) -> axum::response::Response {
80 (
81 StatusCode::INTERNAL_SERVER_ERROR,
82 "tailtriage extractor missing. Add tailtriage_axum::middleware.",
83 )
84 .into_response()
85 }
86}
87
88fn request_route_label(request: &Request<axum::body::Body>) -> String {
89 request
90 .extensions()
91 .get::<MatchedPath>()
92 .map_or_else(|| request.uri().path(), MatchedPath::as_str)
93 .to_owned()
94}
95
96fn status_to_outcome(status: StatusCode) -> Outcome {
97 if status.is_server_error() {
98 Outcome::Error
99 } else {
100 Outcome::Ok
101 }
102}
103
#[cfg(test)]
mod tests {
    use axum::http::StatusCode;
    use axum::response::IntoResponse;
    use tailtriage_core::Outcome;

    use super::{crate_name, status_to_outcome, TailtriageExtractorError};

    #[test]
    fn crate_name_is_stable() {
        assert_eq!(crate_name(), "tailtriage-axum");
    }

    #[test]
    fn maps_server_errors_to_error_outcome() {
        for status in [
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::BAD_GATEWAY,
            StatusCode::SERVICE_UNAVAILABLE,
        ] {
            assert_eq!(status_to_outcome(status), Outcome::Error, "{status}");
        }
    }

    #[test]
    fn maps_non_server_errors_to_ok_outcome() {
        // Client errors are deliberately not counted as failures of the service.
        for status in [
            StatusCode::OK,
            StatusCode::NO_CONTENT,
            StatusCode::MOVED_PERMANENTLY,
            StatusCode::BAD_REQUEST,
            StatusCode::NOT_FOUND,
        ] {
            assert_eq!(status_to_outcome(status), Outcome::Ok, "{status}");
        }
    }

    #[test]
    fn server_error_range_boundaries() {
        let code = |n| StatusCode::from_u16(n).expect("status code within 100..=999");
        assert_eq!(status_to_outcome(code(499)), Outcome::Ok);
        assert_eq!(status_to_outcome(code(500)), Outcome::Error);
        assert_eq!(status_to_outcome(code(599)), Outcome::Error);
        assert_eq!(status_to_outcome(code(600)), Outcome::Ok);
    }

    #[test]
    fn missing_extension_rejection_is_internal_server_error() {
        let response = TailtriageExtractorError.into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}