hermod_api/services/
jwt.rs

//! Contains code for serializing and deserializing JSON Web Tokens.
2use crate::{
3    db::{get_user_by_id, User},
4    handlers::ApplicationError,
5    services::auth::AuthenticationError,
6};
7use actix_web::HttpRequest;
8use chrono::{Duration, Utc};
9use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
10use sqlx::PgPool;
11use uuid::Uuid;
12
13/// Service to manage JWT tokens
14pub struct JwtClient {
15    auth_key: String,
16    pool: PgPool,
17}
18
19impl JwtClient {
20    /// Create a new JWT Client
21    pub fn new(auth_key: String, pool: PgPool) -> Self {
22        Self { auth_key, pool }
23    }
24
25    /// Encode a JWT token for the given user
26    #[tracing::instrument(name = "services::jwt::encode_token", skip(self))]
27    pub fn encode_token(&self, user_id: Uuid) -> Result<String, AuthenticationError> {
28        Ok(self.encode_token_with_exp(user_id, 60 * 60)?)
29    }
30
31    /// Encode a JWT token with a custom expiration time
32    #[tracing::instrument(name = "services::jwt::encode_token_with_exp", skip(self))]
33    pub fn encode_token_with_exp(
34        &self,
35        user_id: Uuid,
36        exp_offset: i64,
37    ) -> Result<String, anyhow::Error> {
38        let my_iat = Utc::now().timestamp();
39        let my_exp = Utc::now()
40            .checked_add_signed(Duration::seconds(exp_offset))
41            .expect("invalid timestamp")
42            .timestamp();
43
44        let my_claims = Claims {
45            sub: user_id.to_string(),
46            iat: my_iat as usize,
47            exp: my_exp as usize,
48        };
49
50        Ok(encode(
51            &Header::default(),
52            &my_claims,
53            &EncodingKey::from_secret(self.auth_key.as_bytes()),
54        )?)
55    }
56
57    /// Decode a JWT token and provide its claims if it is valid
58    #[tracing::instrument(name = "services::jwt::decode_token", skip(self))]
59    pub fn decode_token(&self, token: &str) -> Result<Claims, anyhow::Error> {
60        Ok(decode::<Claims>(
61            token,
62            &DecodingKey::from_secret(self.auth_key.as_bytes()),
63            &Validation::default(),
64        )?
65        .claims)
66    }
67
68    /// Given an incoming HTTP request, return the user currently logged in. If there is no
69    /// user logged in, generate a `403 Unauthorized` error response.
70    #[tracing::instrument(name = "services::jwt::user_or_403", skip(self, request))]
71    pub async fn user_or_403(&self, request: HttpRequest) -> Result<User, ApplicationError> {
72        let auth_header = request
73            .headers()
74            .get("Authorization")
75            .ok_or(AuthenticationError::Unauthorized)?;
76        let token = auth_header
77            .to_str()
78            .map_err(|e| AuthenticationError::UnexpectedError(anyhow::anyhow!(e)))?;
79        let claims = self
80            .decode_token(token)
81            .map_err(|e| AuthenticationError::UnexpectedError(anyhow::anyhow!(e)))?;
82        let user = get_user_by_id(claims.sub, &self.pool).await?;
83        tracing::Span::current().record("username", &tracing::field::display(&user.username));
84        tracing::Span::current().record("user_id", &tracing::field::display(&user.id));
85        Ok(user)
86    }
87}
88
89/// Claims represents the JWT payload.
90#[derive(serde::Deserialize, serde::Serialize, Debug)]
91pub struct Claims {
92    pub sub: String,
93    pub iat: usize,
94    pub exp: usize,
95}