1use std::{fmt::Debug, future::Future, marker::PhantomData, pin::Pin};
4
5use viz_core::{
6 http::{
7 header::{HeaderValue, WWW_AUTHENTICATE},
8 StatusCode,
9 },
10 Context, Middleware, Response, Result,
11};
12
13use viz_utils::tracing;
14
15#[cfg(feature = "jwt-header")]
16use viz_core::http::headers::{
17 authorization::{Authorization, Bearer},
18 HeaderMapExt,
19};
20
21#[cfg(all(
22 feature = "jwt-query",
23 not(all(feature = "jwt-header", feature = "jwt-param", feature = "jwt-cookie"))
24))]
25use std::collections::HashMap;
26
27use jsonwebtoken::{decode, DecodingKey, Validation};
28use serde::de::DeserializeOwned;
29
30pub use jsonwebtoken;
31
32#[derive(Debug)]
34pub struct Jwt<T>
35where
36 T: Debug,
37{
38 #[cfg(all(
39 not(feature = "jwt-header"),
40 any(feature = "jwt-query", feature = "jwt-param", feature = "jwt-cookie")
41 ))]
42 n: String,
43 s: String,
44 v: Validation,
45 t: PhantomData<T>,
46}
47
48impl<T> Jwt<T>
49where
50 T: DeserializeOwned + Sync + Send + 'static + Debug,
51{
52 pub fn new() -> Self {
54 Self {
55 #[cfg(all(
56 not(feature = "jwt-header"),
57 any(feature = "jwt-query", feature = "jwt-param", feature = "jwt-cookie")
58 ))]
59 n: "token".to_owned(),
60 s: "secret".to_owned(),
61 v: Validation::default(),
62 t: PhantomData,
63 }
64 }
65
66 pub fn secret(mut self, secret: &str) -> Self {
68 self.s = secret.to_owned();
69 self
70 }
71
72 pub fn validation(mut self, validation: Validation) -> Self {
74 self.v = validation;
75 self
76 }
77
78 #[cfg(all(
80 not(feature = "jwt-header"),
81 any(feature = "jwt-query", feature = "jwt-param", feature = "jwt-cookie")
82 ))]
83 pub fn name(mut self, name: &str) -> Self {
84 self.n = name.to_owned();
85 self
86 }
87
88 async fn run(&self, cx: &mut Context) -> Result<Response> {
89 let (status, error) = if let Some(val) = self.get(cx) {
90 match decode::<T>(&val, &DecodingKey::from_secret(self.s.as_ref()), &self.v) {
91 Ok(token) => {
92 cx.extensions_mut().insert(token);
93 return cx.next().await;
94 }
95 Err(e) => {
96 tracing::error!("JWT error: {}", e);
97 (StatusCode::UNAUTHORIZED, "Invalid or expired JWT")
98 }
99 }
100 } else {
101 (StatusCode::BAD_REQUEST, "Missing or malformed JWT")
102 };
103
104 let mut res: Response = status.into();
105 res.headers_mut().insert(WWW_AUTHENTICATE, HeaderValue::from_str(error)?);
106
107 Ok(res)
108 }
109
110 #[allow(unused_variables)]
111 fn get(&self, cx: &mut Context) -> Option<String> {
113 cfg_if::cfg_if! {
114 if #[cfg(feature = "jwt-header")] {
115 cx.headers()
116 .typed_get::<Authorization<Bearer>>()
117 .map(|auth| auth.0.token().to_owned())
118 } else if #[cfg(feature = "jwt-query")] {
119 cx.query::<HashMap<String, String>>()
120 .ok()?
121 .get(&self.n)
122 .cloned()
123 } else if #[cfg(feature = "jwt-param")] {
124 cx.param(&self.n).ok()
125 } else if #[cfg(feature = "jwt-cookie")] {
126 cx.cookie(&self.n).map(std::string::ToString::to_string)
127 } else {
128 None
129 }
130 }
131 }
132}
133
134impl<T> Default for Jwt<T>
135where
136 T: DeserializeOwned + Sync + Send + 'static + Debug,
137{
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143impl<'a, T> Middleware<'a, Context> for Jwt<T>
144where
145 T: DeserializeOwned + Sync + Send + 'static + Debug,
146{
147 type Output = Result<Response>;
148
149 #[must_use]
150 fn call(
151 &'a self,
152 cx: &'a mut Context,
153 ) -> Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>> {
154 Box::pin(self.run(cx))
155 }
156}