1use std::{
2 borrow::Cow, convert::Infallible, error::Error, fmt::Display, marker::PhantomData, sync::Arc,
3};
4
5use axum::{
6 extract::{FromRef, FromRequestParts},
7 http::{HeaderMap, HeaderName, header::AUTHORIZATION, request::Parts},
8};
9use jsonwebtoken::{DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
10use serde::{Serialize, de::DeserializeOwned};
11
12use super::JwtError;
13use crate::utils::get_env;
14
/// Default prefix (auth scheme plus separator) expected before the token in the
/// `Authorization` header.
const PREFIX_BEARER: &str = "Bearer ";
/// Empty prefix: the entire header value is treated as the token.
const PREFIX_NONE: &str = "";
17
18pub struct JwtContext<T>(Arc<JwtContextInner<T>>);
19
20struct JwtContextInner<T> {
21 encoding_key: EncodingKey,
22 decoding_key: DecodingKey,
23 jwt_header: Header,
24 validation: Validation,
25 data: PhantomData<T>,
26 extract: ExtractFrom,
27}
28
29impl JwtContext<()> {
30 pub fn builder() -> JwtContextBuilder {
31 JwtContextBuilder::new()
32 }
33}
34
35impl<T: Serialize> JwtContext<T> {
36 pub fn encode_token(&self, data: &T) -> jsonwebtoken::errors::Result<String> {
37 encode(&self.0.jwt_header, data, &self.0.encoding_key)
38 }
39}
40
41impl<T: DeserializeOwned> JwtContext<T> {
42 pub fn decode(&self, jwt: impl AsRef<[u8]>) -> Result<TokenData<T>, JwtError> {
43 decode(jwt.as_ref(), &self.0.decoding_key, &self.0.validation)
44 }
45
46 pub(crate) fn decode_from_headers(&self, headers: &HeaderMap) -> Option<T> {
47 let result = match &self.0.extract {
48 #[cfg(feature = "cookie")]
49 ExtractFrom::Cookie(cookie_name) => {
50 let jar = cookie_monster::CookieJar::from_headers(headers);
51 let cookie = jar.get(cookie_name)?;
52
53 self.decode(cookie.value())
54 }
55 ExtractFrom::Header { header, prefix } => {
56 let authorization_header = headers.get(header)?.to_str().ok()?;
57
58 let jwt = jwt_from_header_value(authorization_header, prefix)?;
59 self.decode(jwt)
60 }
61 };
62
63 result.ok().map(|t| t.claims)
64 }
65}
66
/// Strips `prefix` from a header value and returns the remaining token.
///
/// The prefix (e.g. the `Bearer ` auth scheme) is compared ASCII case-insensitively.
/// Whitespace between the prefix and the token is skipped, because RFC 6750 allows one
/// or more spaces after the `Bearer` scheme.
///
/// Returns `None` in any of these cases:
/// - the value does not start with `prefix`;
/// - nothing but whitespace follows it;
/// - `prefix.len()` does not fall on a UTF-8 character boundary of `header`.
///
/// The last case is reported as `None` instead of panicking on the slice.
fn jwt_from_header_value<'a>(header: &'a str, prefix: &str) -> Option<&'a str> {
    let split = prefix.len();

    // `str::get` yields `None` both when `header` is shorter than `prefix` and when
    // `split` would cut a multi-byte character in half, where `&header[..split]` panics.
    let scheme = header.get(..split)?;
    if !scheme.eq_ignore_ascii_case(prefix) {
        return None;
    }

    let token = header.get(split..)?.trim_start();
    (!token.is_empty()).then_some(token)
}
80
81impl<S, U> FromRequestParts<S> for JwtContext<U>
82where
83 JwtContext<U>: FromRef<S>,
84 S: Send + Sync,
85{
86 type Rejection = Infallible;
87
88 async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
89 Ok(Self::from_ref(state))
90 }
91}
92
93pub struct JwtContextBuilder {
94 encoding_key: Option<EncodingKey>,
95 decoding_key: Option<DecodingKey>,
96 jwt_header: Header,
97 validation: Validation,
98 extract: ExtractFrom,
99}
100
101impl Default for JwtContextBuilder {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107impl JwtContextBuilder {
108 pub fn new() -> Self {
109 JwtContextBuilder {
110 encoding_key: None,
111 decoding_key: None,
112 jwt_header: Header::default(),
113 validation: Validation::default(),
114 extract: ExtractFrom::header_with_prefix(AUTHORIZATION, PREFIX_BEARER),
115 }
116 }
117
118 pub fn encoding_key(mut self, encoding_key: EncodingKey) -> Self {
119 self.encoding_key = Some(encoding_key);
120 self
121 }
122
123 pub fn decoding_key(mut self, decoding_key: DecodingKey) -> Self {
124 self.decoding_key = Some(decoding_key);
125 self
126 }
127
128 pub fn jwt_secret(self, jwt_secret: impl AsRef<[u8]>) -> Self {
129 let jwt_secret = jwt_secret.as_ref();
130 self.encoding_key(EncodingKey::from_secret(jwt_secret))
131 .decoding_key(DecodingKey::from_secret(jwt_secret))
132 }
133
134 pub fn jwt_secret_env(self, name: &str) -> Self {
135 self.jwt_secret(get_env(name))
136 }
137
138 pub fn validation(mut self, validation: Validation) -> Self {
139 self.validation = validation;
140 self
141 }
142
143 pub fn jwt_header(mut self, header: Header) -> Self {
144 self.jwt_header = header;
145 self
146 }
147
148 pub fn extract_header_with_prefix(
149 mut self,
150 header: impl AsRef<[u8]>,
151 prefix: impl Into<Cow<'static, str>>,
152 ) -> Self {
153 self.extract = ExtractFrom::header_with_prefix(
154 HeaderName::from_bytes(header.as_ref())
155 .expect("header value contains invalid characters"),
156 prefix.into(),
157 );
158 self
159 }
160
161 pub fn extract_header(mut self, header: impl AsRef<str>) -> Self {
162 self.extract = ExtractFrom::header_with_prefix(
163 HeaderName::from_bytes(header.as_ref().as_bytes())
164 .expect("header value contains invalid characters"),
165 PREFIX_NONE,
166 );
167 self
168 }
169
170 #[cfg(feature = "cookie")]
171 pub fn extract_cookie(mut self, cookie_name: impl Into<Cow<'static, str>>) -> Self {
172 self.extract = ExtractFrom::cookie(cookie_name.into());
173 self
174 }
175
176 pub fn try_build<T>(self) -> Result<JwtContext<T>, JwtBuilderError> {
177 let encoding_key = self
178 .encoding_key
179 .ok_or(JwtBuilderError::EncodingKeyMissing)?;
180
181 let decoding_key = self
182 .decoding_key
183 .ok_or(JwtBuilderError::DecodingKeyMissing)?;
184
185 Ok(JwtContext(Arc::new(JwtContextInner {
186 encoding_key,
187 decoding_key,
188 jwt_header: self.jwt_header,
189 validation: self.validation,
190 extract: self.extract,
191 data: PhantomData,
192 })))
193 }
194
195 pub fn build<T>(self) -> JwtContext<T> {
196 self.try_build().unwrap()
197 }
198}
199
/// Error returned by `JwtContextBuilder::try_build` when a required key is not set.
///
/// Fieldless, so it is cheap to copy and can be compared directly (e.g. with
/// `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JwtBuilderError {
    /// No encoding (signing) key was configured.
    EncodingKeyMissing,
    /// No decoding (verification) key was configured.
    DecodingKeyMissing,
}

impl Display for JwtBuilderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            JwtBuilderError::EncodingKeyMissing => "Encoding key is missing",
            JwtBuilderError::DecodingKeyMissing => "Decoding key is missing",
        };
        f.write_str(message)
    }
}

impl Error for JwtBuilderError {}
216
217pub(crate) enum ExtractFrom {
218 #[cfg(feature = "cookie")]
219 Cookie(Cow<'static, str>),
220 Header {
221 header: HeaderName,
222 prefix: Cow<'static, str>,
223 },
224}
225
226impl ExtractFrom {
227 #[cfg(feature = "cookie")]
228 fn cookie(name: Cow<'static, str>) -> Self {
229 ExtractFrom::Cookie(name)
230 }
231
232 fn header_with_prefix(header: HeaderName, prefix: impl Into<Cow<'static, str>>) -> Self {
233 ExtractFrom::Header {
234 header,
235 prefix: prefix.into(),
236 }
237 }
238}
239
240impl<T> Clone for JwtContext<T> {
241 fn clone(&self) -> Self {
242 JwtContext(self.0.clone())
243 }
244}
245
#[cfg(test)]
mod jwt_builder {
    //! Checks that `try_build` reports which of the two required keys is missing.

    use crate::jwt::{DecodingKey, EncodingKey, JwtBuilderError, JwtContext};

    #[test]
    fn encoding_key_missing() {
        let builder = JwtContext::builder().decoding_key(DecodingKey::from_secret(b"test"));
        let outcome = builder.try_build::<()>();

        assert!(matches!(
            outcome.err(),
            Some(JwtBuilderError::EncodingKeyMissing)
        ));
    }

    #[test]
    fn decoding_key_missing() {
        let builder = JwtContext::builder().encoding_key(EncodingKey::from_secret(b"test"));
        let outcome = builder.try_build::<()>();

        assert!(matches!(
            outcome.err(),
            Some(JwtBuilderError::DecodingKeyMissing)
        ));
    }
}