salvo_jwt_auth/
lib.rs

1//! Provides JWT (JSON Web Token) authentication support for the Salvo web framework.
2//!
3//! This crate helps you implement JWT-based authentication in your Salvo web applications.
4//! It offers flexible token extraction from various sources (headers, query parameters, cookies, etc.)
5//! and multiple decoding strategies.
6//!
7//! # Features
8//!
9//! - Extract JWT tokens from multiple sources (headers, query parameters, cookies, forms)
10//! - Configurable token validation
11//! - OpenID Connect support (behind the `oidc` feature flag)
12//! - Seamless integration with Salvo's middleware system
13//!
14//! # Example:
15//!
16//! ```no_run
17//! use jsonwebtoken::{self, EncodingKey};
18//! use salvo::http::{Method, StatusError};
19//! use salvo::jwt_auth::{ConstDecoder, QueryFinder};
20//! use salvo::prelude::*;
21//! use serde::{Deserialize, Serialize};
22//! use time::{Duration, OffsetDateTime};
23//!
24//! const SECRET_KEY: &str = "YOUR_SECRET_KEY"; // In production, use a secure key management solution
25//!
26//! #[derive(Serialize, Deserialize, Clone, Debug)]
27//! pub struct JwtClaims {
28//!     username: String,
29//!     exp: i64,
30//! }
31//!
32//! #[tokio::main]
33//! async fn main() {
34//!     let auth_handler: JwtAuth<JwtClaims, _> = JwtAuth::new(ConstDecoder::from_secret(SECRET_KEY.as_bytes()))
35//!         .finders(vec![
36//!             // Box::new(HeaderFinder::new()),
37//!             Box::new(QueryFinder::new("jwt_token")),
38//!             // Box::new(CookieFinder::new("jwt_token")),
39//!         ])
40//!         .force_passed(true);
41//!
42//!     let acceptor = TcpListener::new("0.0.0.0:8698").bind().await;
43//!     Server::new(acceptor)
44//!         .serve(Router::with_hoop(auth_handler).goal(index))
45//!         .await;
46//! }
47//! #[handler]
48//! async fn index(req: &mut Request, depot: &mut Depot, res: &mut Response) -> anyhow::Result<()> {
49//!     if req.method() == Method::POST {
50//!         let (username, password) = (
51//!             req.form::<String>("username").await.unwrap_or_default(),
52//!             req.form::<String>("password").await.unwrap_or_default(),
53//!         );
54//!         if !validate(&username, &password) {
55//!             res.render(Text::Html(LOGIN_HTML));
56//!             return Ok(());
57//!         }
58//!         let exp = OffsetDateTime::now_utc() + Duration::days(14);
59//!         let claim = JwtClaims {
60//!             username,
61//!             exp: exp.unix_timestamp(),
62//!         };
63//!         let token = jsonwebtoken::encode(
64//!             &jsonwebtoken::Header::default(),
65//!             &claim,
66//!             &EncodingKey::from_secret(SECRET_KEY.as_bytes()),
67//!         )?;
68//!         res.render(Redirect::other(format!("/?jwt_token={token}")));
69//!     } else {
70//!         match depot.jwt_auth_state() {
71//!             JwtAuthState::Authorized => {
72//!                 let data = depot.jwt_auth_data::<JwtClaims>().unwrap();
73//!                 res.render(Text::Plain(format!(
74//!                     "Hi {}, you have logged in successfully!",
75//!                     data.claims.username
76//!                 )));
77//!             }
78//!             JwtAuthState::Unauthorized => {
79//!                 res.render(Text::Html(LOGIN_HTML));
80//!             }
81//!             JwtAuthState::Forbidden => {
82//!                 res.render(StatusError::forbidden());
83//!             }
84//!         }
85//!     }
86//!     Ok(())
87//! }
88//!
89//! fn validate(username: &str, password: &str) -> bool {
90//!     // In a real application, use secure password verification
91//!     username == "root" && password == "pwd"
92//! }
93//!
94//! static LOGIN_HTML: &str = r#"<!DOCTYPE html>
95//! <html>
96//!     <head>
97//!         <title>JWT Auth Demo</title>
98//!     </head>
99//!     <body>
100//!         <h1>JWT Auth</h1>
101//!         <form action="/" method="post">
102//!         <label for="username"><b>Username</b></label>
103//!         <input type="text" placeholder="Enter Username" name="username" required>
104//!
105//!         <label for="password"><b>Password</b></label>
106//!         <input type="password" placeholder="Enter Password" name="password" required>
107//!
108//!         <button type="submit">Login</button>
109//!     </form>
110//!     </body>
111//! </html>
112//! "#;
113//! ```
114
115#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
116#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
117#![cfg_attr(docsrs, feature(doc_cfg))]
118
119use std::fmt::{self, Debug, Formatter};
120use std::marker::PhantomData;
121
122#[doc(no_inline)]
123pub use jsonwebtoken::{
124    Algorithm, DecodingKey, TokenData, Validation, decode, errors::Error as JwtError,
125};
126use serde::de::DeserializeOwned;
127use thiserror::Error;
128
129use salvo_core::http::{Method, Request, Response, StatusError};
130use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
131
132mod finder;
133pub use finder::{CookieFinder, FormFinder, HeaderFinder, JwtTokenFinder, QueryFinder};
134
135mod decoder;
136pub use decoder::{ConstDecoder, JwtAuthDecoder};
137
138#[macro_use]
139mod cfg;
140
141cfg_feature! {
142    #![feature = "oidc"]
143    pub mod oidc;
144    pub use oidc::OidcDecoder;
145}
146
147/// key used to insert auth decoded data to depot.
148pub const JWT_AUTH_DATA_KEY: &str = "::salvo::jwt_auth::auth_data";
149/// key used to insert auth state data to depot.
150pub const JWT_AUTH_STATE_KEY: &str = "::salvo::jwt_auth::auth_state";
151/// key used to insert auth token data to depot.
152pub const JWT_AUTH_TOKEN_KEY: &str = "::salvo::jwt_auth::auth_token";
153/// key used to insert auth error to depot.
154pub const JWT_AUTH_ERROR_KEY: &str = "::salvo::jwt_auth::auth_error";
155
156const ALL_METHODS: [Method; 9] = [
157    Method::GET,
158    Method::POST,
159    Method::PUT,
160    Method::DELETE,
161    Method::HEAD,
162    Method::OPTIONS,
163    Method::CONNECT,
164    Method::PATCH,
165    Method::TRACE,
166];
167
168/// JwtAuthError
169#[derive(Debug, Error)]
170pub enum JwtAuthError {
171    /// HTTP client error
172    #[cfg(feature = "oidc")]
173    #[cfg_attr(docsrs, doc(cfg(feature = "oidc")))]
174    #[error("ClientError")]
175    ClientError(#[from] hyper_util::client::legacy::Error),
176
177    /// Error occurred in hyper.
178    #[cfg(feature = "oidc")]
179    #[cfg_attr(docsrs, doc(cfg(feature = "oidc")))]
180    #[error("HyperError")]
181    Hyper(#[from] salvo_core::hyper::Error),
182
183    /// InvalidUri
184    #[error("InvalidUri")]
185    InvalidUri(#[from] salvo_core::http::uri::InvalidUri),
186    /// Serde error
187    #[error("Serde error")]
188    SerdeError(#[from] serde_json::Error),
189    /// Failed to discover OIDC configuration
190    #[error("Failed to discover OIDC configuration")]
191    DiscoverError,
192    /// Decoding of JWKS error
193    #[error("Decoding of JWKS error")]
194    DecodeError(#[from] base64::DecodeError),
195    /// JWT is missing kid, alg, or decoding components
196    #[error("JWT is missing kid, alg, or decoding components")]
197    InvalidJwk,
198    /// Issuer URL invalid
199    #[error("Issuer URL invalid")]
200    IssuerParseError,
201    /// Failure of validating the token. See [jsonwebtoken::errors::ErrorKind] for possible reasons this value could be returned
202    /// Would typically result in a 401 HTTP Status code
203    #[error("JWT Is Invalid")]
204    ValidationFailed(#[from] jsonwebtoken::errors::Error),
205    /// Failure to re-validate the JWKS.
206    /// Would typically result in a 401 or 500 status code depending on preference
207    #[error("Token was unable to be validated due to cache expiration")]
208    CacheError,
209    /// Token did not contain a kid in its header and would be impossible to validate
210    /// Would typically result in a 401 HTTP Status code
211    #[error("Token did not contain a KID field")]
212    MissingKid,
213}
214
215/// Possible states of JWT authentication.
216///
217/// The middleware sets this state in the depot after processing a request.
218/// You can access it via `depot.jwt_auth_state()`.
219#[derive(Copy, Clone, Eq, PartialEq, Debug)]
220pub enum JwtAuthState {
221    /// Authentication was successful and the token was valid.
222    Authorized,
223    /// No token was provided in the request.
224    /// Usually results in a 401 Unauthorized response unless `force_passed` is true.
225    Unauthorized,
226    /// A token was provided but it failed validation.
227    /// Usually results in a 403 Forbidden response unless `force_passed` is true.
228    Forbidden,
229}
230
231/// Extension trait for accessing JWT authentication data from the depot.
232///
233/// This trait provides convenient methods to retrieve JWT authentication information
234/// that was previously stored in the depot by the `JwtAuth` middleware.
235pub trait JwtAuthDepotExt {
236    /// Gets the JWT token string from the depot.
237    fn jwt_auth_token(&self) -> Option<&str>;
238
239    /// Gets the decoded JWT claims data from the depot.
240    ///
241    /// The generic parameter `C` should be the same type used when configuring the `JwtAuth` middleware.
242    fn jwt_auth_data<C>(&self) -> Option<&TokenData<C>>
243    where
244        C: DeserializeOwned + Send + Sync + 'static;
245
246    /// Gets the current JWT authentication state from the depot.
247    ///
248    /// Returns `JwtAuthState::Unauthorized` if no state is present in the depot.
249    fn jwt_auth_state(&self) -> JwtAuthState;
250
251    /// Gets the JWT error if authentication failed.
252    fn jwt_auth_error(&self) -> Option<&JwtError>;
253}
254
255impl JwtAuthDepotExt for Depot {
256    #[inline]
257    fn jwt_auth_token(&self) -> Option<&str> {
258        self.get::<String>(JWT_AUTH_TOKEN_KEY).map(|v| &**v).ok()
259    }
260
261    #[inline]
262    fn jwt_auth_data<C>(&self) -> Option<&TokenData<C>>
263    where
264        C: DeserializeOwned + Send + Sync + 'static,
265    {
266        self.get(JWT_AUTH_DATA_KEY).ok()
267    }
268
269    #[inline]
270    fn jwt_auth_state(&self) -> JwtAuthState {
271        self.get(JWT_AUTH_STATE_KEY)
272            .ok()
273            .cloned()
274            .unwrap_or(JwtAuthState::Unauthorized)
275    }
276
277    #[inline]
278    fn jwt_auth_error(&self) -> Option<&JwtError> {
279        self.get(JWT_AUTH_ERROR_KEY).ok()
280    }
281}
282
283/// JWT Authentication middleware for Salvo.
284///
285/// `JwtAuth` extracts and validates JWT tokens from incoming requests based on the configured
286/// token finders and decoder. If valid, it stores the decoded data in the depot for later use.
287///
288/// # Type Parameters
289///
290/// * `C` - The claims type that will be deserialized from the JWT payload.
291/// * `D` - The decoder implementation used to validate and decode the JWT token.
292#[non_exhaustive]
293pub struct JwtAuth<C, D> {
294    /// When set to `true`, the middleware will allow the request to proceed even if
295    /// authentication fails, storing only the authentication state in the depot.
296    ///
297    /// When set to `false` (default), requests with invalid or missing tokens will be
298    /// immediately rejected with appropriate status codes.
299    pub force_passed: bool,
300    _claims: PhantomData<C>,
301    /// The decoder used to validate and decode the JWT token.
302    pub decoder: D,
303    /// A list of token finders that will be used to extract the token from the request.
304    /// Finders are tried in order until one returns a token.
305    pub finders: Vec<Box<dyn JwtTokenFinder>>,
306}
307impl<C, D> Debug for JwtAuth<C, D>
308where
309    C: DeserializeOwned + Send + Sync + 'static,
310    D: JwtAuthDecoder + Send + Sync + 'static,
311{
312    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
313        f.debug_struct("JwtAuth")
314            .field("force_passed", &self.force_passed)
315            .finish()
316    }
317}
318
319impl<C, D> JwtAuth<C, D>
320where
321    C: DeserializeOwned + Send + Sync + 'static,
322    D: JwtAuthDecoder + Send + Sync + 'static,
323{
324    /// Create new `JwtAuth`.
325    #[inline]
326    #[must_use]
327    pub fn new(decoder: D) -> Self {
328        Self {
329            force_passed: false,
330            decoder,
331            _claims: PhantomData::<C>,
332            finders: vec![Box::new(HeaderFinder::new())],
333        }
334    }
335    /// Sets force_passed value and return Self.
336    #[inline]
337    #[must_use]
338    pub fn force_passed(mut self, force_passed: bool) -> Self {
339        self.force_passed = force_passed;
340        self
341    }
342
343    /// Get decoder mutable reference.
344    #[inline]
345    pub fn decoder_mut(&mut self) -> &mut D {
346        &mut self.decoder
347    }
348
349    /// Gets a mutable reference to the extractor list.
350    #[inline]
351    pub fn finders_mut(&mut self) -> &mut Vec<Box<dyn JwtTokenFinder>> {
352        &mut self.finders
353    }
354    /// Sets extractor list with new value and return Self.
355    #[inline]
356    #[must_use]
357    pub fn finders(mut self, finders: Vec<Box<dyn JwtTokenFinder>>) -> Self {
358        self.finders = finders;
359        self
360    }
361
362    async fn find_token(&self, req: &mut Request) -> Option<String> {
363        for finder in &self.finders {
364            if let Some(token) = finder.find_token(req).await {
365                return Some(token);
366            }
367        }
368        None
369    }
370}
371
372#[async_trait]
373impl<C, D> Handler for JwtAuth<C, D>
374where
375    C: DeserializeOwned + Clone + Send + Sync + 'static,
376    D: JwtAuthDecoder + Send + Sync + 'static,
377{
378    async fn handle(
379        &self,
380        req: &mut Request,
381        depot: &mut Depot,
382        res: &mut Response,
383        ctrl: &mut FlowCtrl,
384    ) {
385        let token = self.find_token(req).await;
386        if let Some(token) = token {
387            match self.decoder.decode::<C>(&token, depot).await {
388                Ok(data) => {
389                    depot.insert(JWT_AUTH_DATA_KEY, data);
390                    depot.insert(JWT_AUTH_STATE_KEY, JwtAuthState::Authorized);
391                    depot.insert(JWT_AUTH_TOKEN_KEY, token);
392                }
393                Err(e) => {
394                    tracing::info!(error = ?e, "jwt auth error");
395                    depot.insert(JWT_AUTH_STATE_KEY, JwtAuthState::Forbidden);
396                    depot.insert(JWT_AUTH_ERROR_KEY, e);
397                    if !self.force_passed {
398                        res.render(StatusError::forbidden());
399                        ctrl.skip_rest();
400                    }
401                }
402            }
403        } else {
404            depot.insert(JWT_AUTH_STATE_KEY, JwtAuthState::Unauthorized);
405            if !self.force_passed {
406                res.render(StatusError::unauthorized());
407                ctrl.skip_rest();
408            }
409        }
410    }
411}
412
413#[cfg(test)]
414mod tests {
415    use jsonwebtoken::EncodingKey;
416    use salvo_core::prelude::*;
417    use salvo_core::test::{ResponseExt, TestClient};
418    use serde::{Deserialize, Serialize};
419    use time::{Duration, OffsetDateTime};
420
421    use super::*;
422
423    #[derive(Serialize, Deserialize, Clone, Debug)]
424    struct JwtClaims {
425        user: String,
426        exp: i64,
427    }
428    #[tokio::test]
429    async fn test_jwt_auth() {
430        let auth_handler: JwtAuth<JwtClaims, ConstDecoder> =
431            JwtAuth::new(ConstDecoder::from_secret(b"ABCDEF")).finders(vec![
432                Box::new(HeaderFinder::new()),
433                Box::new(QueryFinder::new("jwt_token")),
434                Box::new(CookieFinder::new("jwt_token")),
435            ]);
436
437        #[handler]
438        async fn hello() -> &'static str {
439            "hello"
440        }
441
442        let router = Router::new()
443            .hoop(auth_handler)
444            .push(Router::with_path("hello").get(hello));
445        let service = Service::new(router);
446
447        async fn access(service: &Service, token: &str) -> String {
448            TestClient::get("http://127.0.0.1:5801/hello")
449                .add_header("Authorization", format!("Bearer {token}"), true)
450                .send(service)
451                .await
452                .take_string()
453                .await
454                .unwrap()
455        }
456
457        let claim = JwtClaims {
458            user: "root".into(),
459            exp: (OffsetDateTime::now_utc() + Duration::days(1)).unix_timestamp(),
460        };
461
462        let token = jsonwebtoken::encode(
463            &jsonwebtoken::Header::default(),
464            &claim,
465            &EncodingKey::from_secret(b"ABCDEF"),
466        )
467        .unwrap();
468        let content = access(&service, &token).await;
469        assert!(content.contains("hello"));
470
471        let content = TestClient::get(format!("http://127.0.0.1:5801/hello?jwt_token={token}"))
472            .send(&service)
473            .await
474            .take_string()
475            .await
476            .unwrap();
477        assert!(content.contains("hello"));
478        let content = TestClient::get("http://127.0.0.1:5801/hello")
479            .add_header("Cookie", format!("jwt_token={token}"), true)
480            .send(&service)
481            .await
482            .take_string()
483            .await
484            .unwrap();
485        assert!(content.contains("hello"));
486
487        let token = jsonwebtoken::encode(
488            &jsonwebtoken::Header::default(),
489            &claim,
490            &EncodingKey::from_secret(b"ABCDEFG"),
491        )
492        .unwrap();
493        let content = access(&service, &token).await;
494        assert!(content.contains("Forbidden"));
495    }
496}