Skip to main content

cdk_common/auth/
oidc.rs

//! OpenID Connect (OIDC) client: provider discovery, JWKS caching and token verification
2
3use std::collections::HashMap;
4use std::ops::Deref;
5use std::sync::Arc;
6
7use jsonwebtoken::jwk::{AlgorithmParameters, JwkSet};
8use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
9use serde::Deserialize;
10#[cfg(feature = "wallet")]
11use serde::Serialize;
12use thiserror::Error;
13use tokio::sync::RwLock;
14use tracing::instrument;
15
16use crate::HttpClient;
17
18/// OIDC Error
19#[derive(Debug, Error)]
20pub enum Error {
21    /// From HTTP error
22    #[error(transparent)]
23    Http(#[from] crate::HttpError),
24    /// From JWT error
25    #[error(transparent)]
26    Jwt(#[from] jsonwebtoken::errors::Error),
27    /// Missing kid header
28    #[error("Missing kid header")]
29    MissingKidHeader,
30    /// Missing jwk header
31    #[error("Missing jwk")]
32    MissingJwkHeader,
33    /// Unsupported Algo
34    #[error("Unsupported signing algo")]
35    UnsupportedSigningAlgo,
36    /// Invalid Client ID
37    #[error("Invalid Client ID")]
38    InvalidClientId,
39}
40
41impl From<Error> for crate::error::Error {
42    fn from(value: Error) -> Self {
43        tracing::debug!("Clear auth verification failed: {}", value);
44        crate::error::Error::ClearAuthFailed
45    }
46}
47
48/// Open Id Config
49#[derive(Debug, Clone, Deserialize)]
50pub struct OidcConfig {
51    /// URI for the JSON Web Key Set
52    pub jwks_uri: String,
53    /// Token issuer identifier
54    pub issuer: String,
55    /// Token endpoint URL
56    pub token_endpoint: String,
57    /// Device authorization endpoint URL
58    pub device_authorization_endpoint: String,
59}
60
61/// Http Client
62#[derive(Debug, Clone)]
63pub struct OidcClient {
64    client: HttpClient,
65    openid_discovery: String,
66    client_id: Option<String>,
67    oidc_config: Arc<RwLock<Option<OidcConfig>>>,
68    jwks_set: Arc<RwLock<Option<JwkSet>>>,
69}
70
/// OAuth2 grant type sent to the token endpoint
///
/// Variants serialize in `snake_case`.
#[cfg(feature = "wallet")]
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
    /// Exchange a refresh token for a new access token
    RefreshToken,
}
79
/// Body of a refresh-token request to the token endpoint
#[cfg(feature = "wallet")]
#[derive(Clone, Debug, Serialize)]
pub struct RefreshTokenRequest {
    /// Grant type of the request
    pub grant_type: GrantType,
    /// Identifier of the OAuth2 client making the request
    pub client_id: String,
    /// Refresh token being exchanged
    pub refresh_token: String,
}
91
/// Token endpoint response
#[cfg(feature = "wallet")]
#[derive(Clone, Debug, Deserialize)]
pub struct TokenResponse {
    /// Access token issued by the authorization server
    pub access_token: String,
    /// Refresh token for obtaining further access tokens, when one is returned
    pub refresh_token: Option<String>,
    /// Access token lifetime in seconds, when one is returned
    pub expires_in: Option<i64>,
    /// Type of the issued token (typically "Bearer")
    pub token_type: String,
}
105
106impl OidcClient {
107    /// Create new [`OidcClient`]
108    pub fn new(openid_discovery: String, client_id: Option<String>) -> Self {
109        Self {
110            client: HttpClient::new(),
111            openid_discovery,
112            client_id,
113            oidc_config: Arc::new(RwLock::new(None)),
114            jwks_set: Arc::new(RwLock::new(None)),
115        }
116    }
117
118    /// Get client id
119    pub fn client_id(&self) -> Option<String> {
120        self.client_id.clone()
121    }
122
123    /// Get config from oidc server
124    #[instrument(skip(self))]
125    pub async fn get_oidc_config(&self) -> Result<OidcConfig, Error> {
126        tracing::debug!("Getting oidc config");
127        let oidc_config: OidcConfig = self.client.fetch(&self.openid_discovery).await?;
128
129        let mut current_config = self.oidc_config.write().await;
130
131        *current_config = Some(oidc_config.clone());
132
133        Ok(oidc_config)
134    }
135
136    /// Get jwk set
137    #[instrument(skip(self))]
138    pub async fn get_jwkset(&self, jwks_uri: &str) -> Result<JwkSet, Error> {
139        tracing::debug!("Getting jwks set");
140        let jwks_set: JwkSet = self.client.fetch(jwks_uri).await?;
141
142        let mut current_set = self.jwks_set.write().await;
143
144        *current_set = Some(jwks_set.clone());
145
146        Ok(jwks_set)
147    }
148
149    /// Verify cat token
150    #[instrument(skip_all)]
151    pub async fn verify_cat(&self, cat_jwt: &str) -> Result<(), Error> {
152        tracing::debug!("Verifying cat");
153        let header = decode_header(cat_jwt)?;
154
155        let kid = header.kid.ok_or(Error::MissingKidHeader)?;
156
157        let oidc_config = {
158            let locked = self.oidc_config.read().await;
159            match locked.deref() {
160                Some(config) => config.clone(),
161                None => {
162                    drop(locked);
163                    self.get_oidc_config().await?
164                }
165            }
166        };
167
168        let jwks = {
169            let locked = self.jwks_set.read().await;
170            match locked.deref() {
171                Some(set) => set.clone(),
172                None => {
173                    drop(locked);
174                    self.get_jwkset(&oidc_config.jwks_uri).await?
175                }
176            }
177        };
178
179        let jwk = match jwks.find(&kid) {
180            Some(jwk) => jwk.clone(),
181            None => {
182                let refreshed_jwks = self.get_jwkset(&oidc_config.jwks_uri).await?;
183                refreshed_jwks
184                    .find(&kid)
185                    .ok_or(Error::MissingKidHeader)?
186                    .clone()
187            }
188        };
189
190        let decoding_key = match &jwk.algorithm {
191            AlgorithmParameters::RSA(rsa) => DecodingKey::from_rsa_components(&rsa.n, &rsa.e)?,
192            AlgorithmParameters::EllipticCurve(ecdsa) => {
193                DecodingKey::from_ec_components(&ecdsa.x, &ecdsa.y)?
194            }
195            _ => return Err(Error::UnsupportedSigningAlgo),
196        };
197
198        let validation = {
199            let mut validation = Validation::new(header.alg);
200            validation.validate_exp = true;
201            validation.validate_aud = false;
202            validation.set_issuer(&[oidc_config.issuer]);
203            validation
204        };
205
206        match decode::<HashMap<String, serde_json::Value>>(cat_jwt, &decoding_key, &validation) {
207            Ok(claims) => {
208                tracing::debug!("Successfully verified cat");
209                if let Some(client_id) = &self.client_id {
210                    if let Some(token_client_id) = claims.claims.get("client_id") {
211                        if let Some(token_client_id_value) = token_client_id.as_str() {
212                            if token_client_id_value != client_id {
213                                tracing::warn!(
214                                    "Client ID mismatch: expected {}, got {}",
215                                    client_id,
216                                    token_client_id_value
217                                );
218                                return Err(Error::InvalidClientId);
219                            }
220                        }
221                    } else if let Some(azp) = claims.claims.get("azp") {
222                        if let Some(azp_value) = azp.as_str() {
223                            if azp_value != client_id {
224                                tracing::warn!(
225                                    "Client ID (azp) mismatch: expected {}, got {}",
226                                    client_id,
227                                    azp_value
228                                );
229                                return Err(Error::InvalidClientId);
230                            }
231                        }
232                    }
233                }
234            }
235            Err(err) => {
236                tracing::debug!("Could not verify cat: {}", err);
237                return Err(err.into());
238            }
239        }
240
241        Ok(())
242    }
243
244    /// Get new access token using refresh token
245    #[cfg(feature = "wallet")]
246    pub async fn refresh_access_token(
247        &self,
248        client_id: String,
249        refresh_token: String,
250    ) -> Result<TokenResponse, Error> {
251        let token_url = self.get_oidc_config().await?.token_endpoint;
252
253        let request = RefreshTokenRequest {
254            grant_type: GrantType::RefreshToken,
255            client_id,
256            refresh_token,
257        };
258
259        let response: TokenResponse = self.client.post_form(&token_url, &request).await?;
260
261        Ok(response)
262    }
263}