// beep_auth/infrastructure/keycloak_repository.rs

use std::sync::Arc;
use std::time::Duration;

use crate::domain::{
    models::{claims::Claims, errors::AuthError, identity::Identity},
    ports::AuthRepository,
};
use chrono::Utc;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, errors::ErrorKind};
use reqwest::Client;
use serde::{Deserialize, Serialize};

12#[derive(Debug, Serialize, Deserialize)]
13struct Jwks {
14    keys: Vec<Jwk>,
15}
16
17#[derive(Debug, Serialize, Deserialize)]
18struct Jwk {
19    kid: String,
20    n: String,
21    e: String,
22}
23
24#[derive(Clone)]
25pub struct KeycloakAuthRepository {
26    pub http: Arc<Client>,
27    pub issuer: String,
28    pub audience: Option<String>,
29}
30
31impl KeycloakAuthRepository {
32    pub fn new(issuer: impl Into<String>, audience: Option<String>) -> Self {
33        Self {
34            http: Arc::new(Client::new()),
35            issuer: issuer.into(),
36            audience,
37        }
38    }
39
40    async fn fetch_jwks(&self) -> Result<Jwks, AuthError> {
41        let url = format!("{}/protocol/openid-connect/certs", self.issuer);
42
43        let resp = self
44            .http
45            .get(url)
46            .send()
47            .await
48            .map_err(|e| AuthError::Network {
49                message: e.to_string(),
50            })?;
51
52        if resp.status().is_client_error() || resp.status().is_server_error() {
53            return Err(AuthError::Network {
54                message: format!("failed to fetch jwks: {}", resp.status()),
55            });
56        }
57
58        let bytes = resp.bytes().await.map_err(|e| AuthError::Network {
59            message: e.to_string(),
60        })?;
61
62        let jwks: Jwks = serde_json::from_slice(&bytes).map_err(|e| AuthError::Network {
63            message: e.to_string(),
64        })?;
65
66        Ok(jwks)
67    }
68}
69
70impl AuthRepository for KeycloakAuthRepository {
71    async fn validate_token(
72        &self,
73        token: &str,
74    ) -> Result<crate::domain::models::claims::Claims, AuthError> {
75        let header = decode_header(token).map_err(|e| AuthError::InvalidToken {
76            message: e.to_string(),
77        })?;
78
79        let kid = header.kid.ok_or_else(|| AuthError::InvalidToken {
80            message: "missing kind".into(),
81        })?;
82
83        let jwks = self.fetch_jwks().await?;
84
85        let keys = jwks.keys;
86
87        let key = keys
88            .iter()
89            .find(|k| k.kid == kid)
90            .ok_or_else(|| AuthError::KeyNotFound { key: kid.clone() })?;
91
92        let decoding_key =
93            DecodingKey::from_rsa_components(&key.n, &key.e).map_err(|e| AuthError::Internal {
94                message: e.to_string(),
95            })?;
96
97        let mut validation = Validation::new(Algorithm::RS256);
98
99        validation.validate_aud = false;
100
101        let data = decode::<Claims>(token, &decoding_key, &validation).map_err(|e| {
102            AuthError::InvalidToken {
103                message: e.to_string(),
104            }
105        })?;
106
107        let claims = data.claims;
108
109        let now = Utc::now().timestamp();
110
111        if claims.exp.unwrap_or(0) < now {
112            return Err(AuthError::Expired);
113        }
114
115        Ok(claims)
116    }
117
118    async fn identify(
119        &self,
120        token: &str,
121    ) -> Result<crate::domain::models::identity::Identity, AuthError> {
122        let claims = self.validate_token(token).await?;
123
124        Ok(Identity::from(claims))
125    }
126}