Skip to main content

axum_security/jwt/
builder.rs

1use std::{
2    borrow::Cow, convert::Infallible, error::Error, fmt::Display, marker::PhantomData, sync::Arc,
3};
4
5use axum::{
6    extract::{FromRef, FromRequestParts},
7    http::{HeaderMap, HeaderName, header::AUTHORIZATION, request::Parts},
8};
9use jsonwebtoken::{DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
10use serde::{Serialize, de::DeserializeOwned};
11
12use super::JwtError;
13use crate::utils::get_env;
14
/// Prefix the default extractor expects in the `Authorization` header (`Bearer <token>`).
const PREFIX_BEARER: &str = "Bearer ";
/// Empty prefix: the whole header value is taken as the token.
const PREFIX_NONE: &str = "";
17
18pub struct JwtContext<T>(Arc<JwtContextInner<T>>);
19
20struct JwtContextInner<T> {
21    encoding_key: EncodingKey,
22    decoding_key: DecodingKey,
23    jwt_header: Header,
24    validation: Validation,
25    data: PhantomData<T>,
26    extract: ExtractFrom,
27}
28
29impl JwtContext<()> {
30    pub fn builder() -> JwtContextBuilder {
31        JwtContextBuilder::new()
32    }
33}
34
35impl<T: Serialize> JwtContext<T> {
36    pub fn encode_token(&self, data: &T) -> jsonwebtoken::errors::Result<String> {
37        encode(&self.0.jwt_header, data, &self.0.encoding_key)
38    }
39}
40
41impl<T: DeserializeOwned> JwtContext<T> {
42    pub fn decode(&self, jwt: impl AsRef<[u8]>) -> Result<TokenData<T>, JwtError> {
43        decode(jwt.as_ref(), &self.0.decoding_key, &self.0.validation)
44    }
45
46    pub(crate) fn decode_from_headers(&self, headers: &HeaderMap) -> Option<T> {
47        let result = match &self.0.extract {
48            #[cfg(feature = "cookie")]
49            ExtractFrom::Cookie(cookie_name) => {
50                let jar = cookie_monster::CookieJar::from_headers(headers);
51                let cookie = jar.get(cookie_name)?;
52
53                self.decode(cookie.value())
54            }
55            ExtractFrom::Header { header, prefix } => {
56                let authorization_header = headers.get(header)?.to_str().ok()?;
57
58                let jwt = jwt_from_header_value(authorization_header, prefix)?;
59                self.decode(jwt)
60            }
61        };
62
63        result.ok().map(|t| t.claims)
64    }
65}
66
/// Strips `prefix` (e.g. `"Bearer "`) from the start of a header value and returns the
/// remainder, which is treated as the raw token.
///
/// The prefix is compared ASCII case-insensitively. Returns `None` when `header` is
/// shorter than `prefix` or does not start with it. An empty `prefix` returns `header`
/// unchanged.
fn jwt_from_header_value<'a>(header: &'a str, prefix: &str) -> Option<&'a str> {
    // `get` returns `None` instead of panicking when `prefix.len()` is past the end of
    // `header` or does not fall on a UTF-8 character boundary.
    let scheme = header.get(..prefix.len())?;

    if !scheme.eq_ignore_ascii_case(prefix) {
        return None;
    }

    // `get` above already proved `prefix.len()` is a valid boundary, so this cannot panic.
    Some(&header[prefix.len()..])
}
80
81impl<S, U> FromRequestParts<S> for JwtContext<U>
82where
83    JwtContext<U>: FromRef<S>,
84    S: Send + Sync,
85{
86    type Rejection = Infallible;
87
88    async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
89        Ok(Self::from_ref(state))
90    }
91}
92
93pub struct JwtContextBuilder {
94    encoding_key: Option<EncodingKey>,
95    decoding_key: Option<DecodingKey>,
96    jwt_header: Header,
97    validation: Validation,
98    extract: ExtractFrom,
99}
100
101impl Default for JwtContextBuilder {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl JwtContextBuilder {
108    pub fn new() -> Self {
109        JwtContextBuilder {
110            encoding_key: None,
111            decoding_key: None,
112            jwt_header: Header::default(),
113            validation: Validation::default(),
114            extract: ExtractFrom::header_with_prefix(AUTHORIZATION, PREFIX_BEARER),
115        }
116    }
117
118    pub fn encoding_key(mut self, encoding_key: EncodingKey) -> Self {
119        self.encoding_key = Some(encoding_key);
120        self
121    }
122
123    pub fn decoding_key(mut self, decoding_key: DecodingKey) -> Self {
124        self.decoding_key = Some(decoding_key);
125        self
126    }
127
128    pub fn jwt_secret(self, jwt_secret: impl AsRef<[u8]>) -> Self {
129        let jwt_secret = jwt_secret.as_ref();
130        self.encoding_key(EncodingKey::from_secret(jwt_secret))
131            .decoding_key(DecodingKey::from_secret(jwt_secret))
132    }
133
134    pub fn jwt_secret_env(self, name: &str) -> Self {
135        self.jwt_secret(get_env(name))
136    }
137
138    pub fn validation(mut self, validation: Validation) -> Self {
139        self.validation = validation;
140        self
141    }
142
143    pub fn jwt_header(mut self, header: Header) -> Self {
144        self.jwt_header = header;
145        self
146    }
147
148    pub fn extract_header_with_prefix(
149        mut self,
150        header: impl AsRef<[u8]>,
151        prefix: impl Into<Cow<'static, str>>,
152    ) -> Self {
153        self.extract = ExtractFrom::header_with_prefix(
154            HeaderName::from_bytes(header.as_ref())
155                .expect("header value contains invalid characters"),
156            prefix.into(),
157        );
158        self
159    }
160
161    pub fn extract_header(mut self, header: impl AsRef<str>) -> Self {
162        self.extract = ExtractFrom::header_with_prefix(
163            HeaderName::from_bytes(header.as_ref().as_bytes())
164                .expect("header value contains invalid characters"),
165            PREFIX_NONE,
166        );
167        self
168    }
169
170    #[cfg(feature = "cookie")]
171    pub fn extract_cookie(mut self, cookie_name: impl Into<Cow<'static, str>>) -> Self {
172        self.extract = ExtractFrom::cookie(cookie_name.into());
173        self
174    }
175
176    pub fn try_build<T>(self) -> Result<JwtContext<T>, JwtBuilderError> {
177        let encoding_key = self
178            .encoding_key
179            .ok_or(JwtBuilderError::EncodingKeyMissing)?;
180
181        let decoding_key = self
182            .decoding_key
183            .ok_or(JwtBuilderError::DecodingKeyMissing)?;
184
185        Ok(JwtContext(Arc::new(JwtContextInner {
186            encoding_key,
187            decoding_key,
188            jwt_header: self.jwt_header,
189            validation: self.validation,
190            extract: self.extract,
191            data: PhantomData,
192        })))
193    }
194
195    pub fn build<T>(self) -> JwtContext<T> {
196        self.try_build().unwrap()
197    }
198}
199
/// Error returned by [`JwtContextBuilder::try_build`] when a required key was not set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwtBuilderError {
    /// No encoding key was provided.
    EncodingKeyMissing,
    /// No decoding key was provided.
    DecodingKeyMissing,
}

impl Display for JwtBuilderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            JwtBuilderError::EncodingKeyMissing => "Encoding key is missing",
            JwtBuilderError::DecodingKeyMissing => "Decoding key is missing",
        };
        // `pad` (unlike `write_str`) honours width/alignment flags such as `{:>30}`.
        f.pad(message)
    }
}

impl Error for JwtBuilderError {}
216
217pub(crate) enum ExtractFrom {
218    #[cfg(feature = "cookie")]
219    Cookie(Cow<'static, str>),
220    Header {
221        header: HeaderName,
222        prefix: Cow<'static, str>,
223    },
224}
225
226impl ExtractFrom {
227    #[cfg(feature = "cookie")]
228    fn cookie(name: Cow<'static, str>) -> Self {
229        ExtractFrom::Cookie(name)
230    }
231
232    fn header_with_prefix(header: HeaderName, prefix: impl Into<Cow<'static, str>>) -> Self {
233        ExtractFrom::Header {
234            header,
235            prefix: prefix.into(),
236        }
237    }
238}
239
240impl<T> Clone for JwtContext<T> {
241    fn clone(&self) -> Self {
242        JwtContext(self.0.clone())
243    }
244}
245
#[cfg(test)]
mod jwt_builder {
    use crate::jwt::{DecodingKey, EncodingKey, JwtBuilderError, JwtContext};

    /// With only a decoding key set, `try_build` reports the missing encoding key.
    #[test]
    fn encoding_key_missing() {
        let builder = JwtContext::builder().decoding_key(DecodingKey::from_secret(b"test"));
        let outcome = builder.try_build::<()>();

        assert!(matches!(outcome, Err(JwtBuilderError::EncodingKeyMissing)));
    }

    /// With only an encoding key set, `try_build` reports the missing decoding key.
    #[test]
    fn decoding_key_missing() {
        let builder = JwtContext::builder().encoding_key(EncodingKey::from_secret(b"test"));
        let outcome = builder.try_build::<()>();

        assert!(matches!(outcome, Err(JwtBuilderError::DecodingKeyMissing)));
    }
}