// hermod_api/services/jwt.rs
use std::convert::TryFrom;

use actix_web::HttpRequest;
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use sqlx::PgPool;
use uuid::Uuid;

use crate::{
    db::{get_user_by_id, User},
    handlers::ApplicationError,
    services::auth::AuthenticationError,
};
12
13pub struct JwtClient {
15 auth_key: String,
16 pool: PgPool,
17}
18
19impl JwtClient {
20 pub fn new(auth_key: String, pool: PgPool) -> Self {
22 Self { auth_key, pool }
23 }
24
25 #[tracing::instrument(name = "services::jwt::encode_token", skip(self))]
27 pub fn encode_token(&self, user_id: Uuid) -> Result<String, AuthenticationError> {
28 Ok(self.encode_token_with_exp(user_id, 60 * 60)?)
29 }
30
31 #[tracing::instrument(name = "services::jwt::encode_token_with_exp", skip(self))]
33 pub fn encode_token_with_exp(
34 &self,
35 user_id: Uuid,
36 exp_offset: i64,
37 ) -> Result<String, anyhow::Error> {
38 let my_iat = Utc::now().timestamp();
39 let my_exp = Utc::now()
40 .checked_add_signed(Duration::seconds(exp_offset))
41 .expect("invalid timestamp")
42 .timestamp();
43
44 let my_claims = Claims {
45 sub: user_id.to_string(),
46 iat: my_iat as usize,
47 exp: my_exp as usize,
48 };
49
50 Ok(encode(
51 &Header::default(),
52 &my_claims,
53 &EncodingKey::from_secret(self.auth_key.as_bytes()),
54 )?)
55 }
56
57 #[tracing::instrument(name = "services::jwt::decode_token", skip(self))]
59 pub fn decode_token(&self, token: &str) -> Result<Claims, anyhow::Error> {
60 Ok(decode::<Claims>(
61 token,
62 &DecodingKey::from_secret(self.auth_key.as_bytes()),
63 &Validation::default(),
64 )?
65 .claims)
66 }
67
68 #[tracing::instrument(name = "services::jwt::user_or_403", skip(self, request))]
71 pub async fn user_or_403(&self, request: HttpRequest) -> Result<User, ApplicationError> {
72 let auth_header = request
73 .headers()
74 .get("Authorization")
75 .ok_or(AuthenticationError::Unauthorized)?;
76 let token = auth_header
77 .to_str()
78 .map_err(|e| AuthenticationError::UnexpectedError(anyhow::anyhow!(e)))?;
79 let claims = self
80 .decode_token(token)
81 .map_err(|e| AuthenticationError::UnexpectedError(anyhow::anyhow!(e)))?;
82 let user = get_user_by_id(claims.sub, &self.pool).await?;
83 tracing::Span::current().record("username", &tracing::field::display(&user.username));
84 tracing::Span::current().record("user_id", &tracing::field::display(&user.id));
85 Ok(user)
86 }
87}
88
89#[derive(serde::Deserialize, serde::Serialize, Debug)]
91pub struct Claims {
92 pub sub: String,
93 pub iat: usize,
94 pub exp: usize,
95}