viz_middleware/
jwt.rs

1//! JSON Web Token Middleware
2
3use std::{fmt::Debug, future::Future, marker::PhantomData, pin::Pin};
4
5use viz_core::{
6    http::{
7        header::{HeaderValue, WWW_AUTHENTICATE},
8        StatusCode,
9    },
10    Context, Middleware, Response, Result,
11};
12
13use viz_utils::tracing;
14
15#[cfg(feature = "jwt-header")]
16use viz_core::http::headers::{
17    authorization::{Authorization, Bearer},
18    HeaderMapExt,
19};
20
21#[cfg(all(
22    feature = "jwt-query",
23    not(all(feature = "jwt-header", feature = "jwt-param", feature = "jwt-cookie"))
24))]
25use std::collections::HashMap;
26
27use jsonwebtoken::{decode, DecodingKey, Validation};
28use serde::de::DeserializeOwned;
29
30pub use jsonwebtoken;
31
32/// JWT Middleware
33#[derive(Debug)]
34pub struct Jwt<T>
35where
36    T: Debug,
37{
38    #[cfg(all(
39        not(feature = "jwt-header"),
40        any(feature = "jwt-query", feature = "jwt-param", feature = "jwt-cookie")
41    ))]
42    n: String,
43    s: String,
44    v: Validation,
45    t: PhantomData<T>,
46}
47
48impl<T> Jwt<T>
49where
50    T: DeserializeOwned + Sync + Send + 'static + Debug,
51{
52    /// Creates JWT
53    pub fn new() -> Self {
54        Self {
55            #[cfg(all(
56                not(feature = "jwt-header"),
57                any(feature = "jwt-query", feature = "jwt-param", feature = "jwt-cookie")
58            ))]
59            n: "token".to_owned(),
60            s: "secret".to_owned(),
61            v: Validation::default(),
62            t: PhantomData,
63        }
64    }
65
66    /// Creates JWT Middleware with a secret
67    pub fn secret(mut self, secret: &str) -> Self {
68        self.s = secret.to_owned();
69        self
70    }
71
72    /// Creates JWT Middleware with an validation
73    pub fn validation(mut self, validation: Validation) -> Self {
74        self.v = validation;
75        self
76    }
77
78    /// Creates JWT Middleware with a name
79    #[cfg(all(
80        not(feature = "jwt-header"),
81        any(feature = "jwt-query", feature = "jwt-param", feature = "jwt-cookie")
82    ))]
83    pub fn name(mut self, name: &str) -> Self {
84        self.n = name.to_owned();
85        self
86    }
87
88    async fn run(&self, cx: &mut Context) -> Result<Response> {
89        let (status, error) = if let Some(val) = self.get(cx) {
90            match decode::<T>(&val, &DecodingKey::from_secret(self.s.as_ref()), &self.v) {
91                Ok(token) => {
92                    cx.extensions_mut().insert(token);
93                    return cx.next().await;
94                }
95                Err(e) => {
96                    tracing::error!("JWT error: {}", e);
97                    (StatusCode::UNAUTHORIZED, "Invalid or expired JWT")
98                }
99            }
100        } else {
101            (StatusCode::BAD_REQUEST, "Missing or malformed JWT")
102        };
103
104        let mut res: Response = status.into();
105        res.headers_mut().insert(WWW_AUTHENTICATE, HeaderValue::from_str(error)?);
106
107        Ok(res)
108    }
109
110    #[allow(unused_variables)]
111    /// Gets token via Header|Query|Param|Cookie.
112    fn get(&self, cx: &mut Context) -> Option<String> {
113        cfg_if::cfg_if! {
114            if #[cfg(feature = "jwt-header")] {
115                cx.headers()
116                    .typed_get::<Authorization<Bearer>>()
117                    .map(|auth| auth.0.token().to_owned())
118            } else if #[cfg(feature = "jwt-query")] {
119                cx.query::<HashMap<String, String>>()
120                    .ok()?
121                    .get(&self.n)
122                    .cloned()
123            } else if #[cfg(feature = "jwt-param")] {
124                cx.param(&self.n).ok()
125            }  else if #[cfg(feature = "jwt-cookie")] {
126                cx.cookie(&self.n).map(std::string::ToString::to_string)
127            } else {
128                None
129            }
130        }
131    }
132}
133
134impl<T> Default for Jwt<T>
135where
136    T: DeserializeOwned + Sync + Send + 'static + Debug,
137{
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143impl<'a, T> Middleware<'a, Context> for Jwt<T>
144where
145    T: DeserializeOwned + Sync + Send + 'static + Debug,
146{
147    type Output = Result<Response>;
148
149    #[must_use]
150    fn call(
151        &'a self,
152        cx: &'a mut Context,
153    ) -> Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>> {
154        Box::pin(self.run(cx))
155    }
156}