1use std::marker::PhantomData;
6
7use axum::extract::FromRequestParts;
8use axum::http::request::Parts;
9use axum::http::StatusCode;
10
11use crate::headers::DocumentedHeader;
12
13pub struct Header<H: DocumentedHeader>(pub String, pub PhantomData<H>);
43
44impl<S, H> FromRequestParts<S> for Header<H>
45where
46 S: Send + Sync,
47 H: DocumentedHeader,
48{
49 type Rejection = (StatusCode, &'static str);
50
51 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
52 let value = parts
53 .headers
54 .get(H::name())
55 .ok_or((StatusCode::BAD_REQUEST, "missing required header"))?
56 .to_str()
57 .map_err(|_| (StatusCode::BAD_REQUEST, "header is not valid UTF-8"))?
58 .to_string();
59 Ok(Self(value, PhantomData))
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    use axum::http::Request;

    /// Marker type binding the extractor to the `X-Api-Key` header.
    struct XApiKey;

    impl DocumentedHeader for XApiKey {
        fn name() -> &'static str {
            "X-Api-Key"
        }
    }

    /// Result type produced by the extractor under test.
    type Extracted = Result<Header<XApiKey>, (StatusCode, &'static str)>;

    /// Builds the head of a `/x` request carrying exactly one header.
    fn parts_with_header(name: &str, value: &str) -> Parts {
        let request = Request::builder()
            .uri("/x")
            .header(name, value)
            .body(())
            .unwrap();
        let (head, _body) = request.into_parts();
        head
    }

    /// Runs the `Header<XApiKey>` extractor against `head` with unit state.
    async fn extract(mut head: Parts) -> Extracted {
        Header::<XApiKey>::from_request_parts(&mut head, &()).await
    }

    /// Unwraps the rejection status, failing the test if extraction succeeded.
    fn rejection_status(outcome: Extracted) -> StatusCode {
        let Err((status, _message)) = outcome else {
            panic!("extractor unexpectedly succeeded");
        };
        status
    }

    #[tokio::test]
    async fn header_extractor_returns_value_when_header_present() {
        let head = parts_with_header("X-Api-Key", "ak_live_42");
        let Header(value, _) = extract(head).await.expect("present");
        assert_eq!(value, "ak_live_42");
    }

    #[tokio::test]
    async fn header_extractor_returns_400_when_header_missing() {
        // Request with no headers at all.
        let (head, _body) = Request::builder().uri("/x").body(()).unwrap().into_parts();
        let status = rejection_status(extract(head).await);
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn header_extractor_returns_400_when_header_not_utf8() {
        // Raw bytes that are not a valid UTF-8 sequence.
        let invalid: &[u8] = &[0xff, 0xfe];
        let (head, _body) = Request::builder()
            .uri("/x")
            .header("X-Api-Key", invalid)
            .body(())
            .unwrap()
            .into_parts();
        let status = rejection_status(extract(head).await);
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn header_extractor_lookup_is_case_insensitive() {
        // The request spells the name in lower case while `XApiKey::name()`
        // uses mixed case; the lookup is still expected to find it.
        let head = parts_with_header("x-api-key", "lower");
        let Header(value, _) = extract(head).await.expect("present");
        assert_eq!(value, "lower");
    }
}