Skip to main content

cdk_common/auth/
oidc.rs

1//! Open Id Connect
2
3use std::collections::HashMap;
4use std::ops::Deref;
5use std::sync::Arc;
6
7use jsonwebtoken::jwk::{AlgorithmParameters, JwkSet};
8use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
9use serde::Deserialize;
10#[cfg(feature = "wallet")]
11use serde::Serialize;
12use thiserror::Error;
13use tokio::sync::RwLock;
14use tracing::instrument;
15
16use crate::HttpClient;
17
18fn validate_client_id_claim(
19    claim_name: &str,
20    claim_value: &serde_json::Value,
21    client_id: &str,
22) -> Result<(), Error> {
23    let Some(token_client_id) = claim_value.as_str() else {
24        tracing::warn!("{} claim is not a string", claim_name);
25        return Err(Error::InvalidClientId);
26    };
27
28    if token_client_id != client_id {
29        tracing::warn!(
30            "Client ID ({}) mismatch: expected {}, got {}",
31            claim_name,
32            client_id,
33            token_client_id
34        );
35        return Err(Error::InvalidClientId);
36    }
37
38    Ok(())
39}
40
41fn validate_client_id_claims(
42    claims: &HashMap<String, serde_json::Value>,
43    client_id: &str,
44) -> Result<(), Error> {
45    match claims.get("client_id") {
46        Some(token_client_id) => validate_client_id_claim("client_id", token_client_id, client_id),
47        None => match claims.get("azp") {
48            Some(azp) => validate_client_id_claim("azp", azp, client_id),
49            None => {
50                tracing::warn!("CAT missing client_id or azp claim for configured client ID");
51                Err(Error::InvalidClientId)
52            }
53        },
54    }
55}
56
57/// OIDC Error
58#[derive(Debug, Error)]
59pub enum Error {
60    /// From HTTP error
61    #[error(transparent)]
62    Http(#[from] crate::HttpError),
63    /// From JWT error
64    #[error(transparent)]
65    Jwt(#[from] jsonwebtoken::errors::Error),
66    /// Missing kid header
67    #[error("Missing kid header")]
68    MissingKidHeader,
69    /// Missing jwk header
70    #[error("Missing jwk")]
71    MissingJwkHeader,
72    /// Unsupported Algo
73    #[error("Unsupported signing algo")]
74    UnsupportedSigningAlgo,
75    /// Invalid Client ID
76    #[error("Invalid Client ID")]
77    InvalidClientId,
78}
79
80impl From<Error> for crate::error::Error {
81    fn from(value: Error) -> Self {
82        tracing::debug!("Clear auth verification failed: {}", value);
83        crate::error::Error::ClearAuthFailed
84    }
85}
86
87/// Open Id Config
88#[derive(Debug, Clone, Deserialize)]
89pub struct OidcConfig {
90    /// URI for the JSON Web Key Set
91    pub jwks_uri: String,
92    /// Token issuer identifier
93    pub issuer: String,
94    /// Token endpoint URL
95    pub token_endpoint: String,
96    /// Device authorization endpoint URL
97    pub device_authorization_endpoint: String,
98}
99
100/// Http Client
101#[derive(Debug, Clone)]
102pub struct OidcClient {
103    client: HttpClient,
104    openid_discovery: String,
105    client_id: Option<String>,
106    oidc_config: Arc<RwLock<Option<OidcConfig>>>,
107    jwks_set: Arc<RwLock<Option<JwkSet>>>,
108}
109
/// OAuth2 grant type sent to the token endpoint
///
/// Variants serialize in `snake_case`, so [`GrantType::RefreshToken`] is
/// written as `refresh_token`.
#[cfg(feature = "wallet")]
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
    /// Exchange a refresh token for a new access token
    RefreshToken,
}
118
/// Body of a token-endpoint request that exchanges a refresh token for a
/// new access token
#[cfg(feature = "wallet")]
#[derive(Debug, Clone, Serialize)]
pub struct RefreshTokenRequest {
    /// Grant type of the request
    pub grant_type: GrantType,
    /// Identifier of the OAuth2 client making the request
    pub client_id: String,
    /// Refresh token being exchanged
    pub refresh_token: String,
}
130
/// Reply returned by the token endpoint
#[cfg(feature = "wallet")]
#[derive(Debug, Clone, Deserialize)]
pub struct TokenResponse {
    /// Access token issued by the authorization server
    pub access_token: String,
    /// Refresh token for obtaining further access tokens, when one is issued
    pub refresh_token: Option<String>,
    /// Lifetime of the access token in seconds, when provided
    pub expires_in: Option<i64>,
    /// Type of the issued token (typically "Bearer")
    pub token_type: String,
}
144
145impl OidcClient {
146    /// Create new [`OidcClient`]
147    pub fn new(openid_discovery: String, client_id: Option<String>) -> Self {
148        Self {
149            client: HttpClient::new(),
150            openid_discovery,
151            client_id,
152            oidc_config: Arc::new(RwLock::new(None)),
153            jwks_set: Arc::new(RwLock::new(None)),
154        }
155    }
156
157    /// Get client id
158    pub fn client_id(&self) -> Option<String> {
159        self.client_id.clone()
160    }
161
162    /// Get config from oidc server
163    #[instrument(skip(self))]
164    pub async fn get_oidc_config(&self) -> Result<OidcConfig, Error> {
165        tracing::debug!("Getting oidc config");
166        let oidc_config: OidcConfig = self.client.fetch(&self.openid_discovery).await?;
167
168        let mut current_config = self.oidc_config.write().await;
169
170        *current_config = Some(oidc_config.clone());
171
172        Ok(oidc_config)
173    }
174
175    /// Get jwk set
176    #[instrument(skip(self))]
177    pub async fn get_jwkset(&self, jwks_uri: &str) -> Result<JwkSet, Error> {
178        tracing::debug!("Getting jwks set");
179        let jwks_set: JwkSet = self.client.fetch(jwks_uri).await?;
180
181        let mut current_set = self.jwks_set.write().await;
182
183        *current_set = Some(jwks_set.clone());
184
185        Ok(jwks_set)
186    }
187
188    /// Verify cat token
189    #[instrument(skip_all)]
190    pub async fn verify_cat(&self, cat_jwt: &str) -> Result<(), Error> {
191        tracing::debug!("Verifying cat");
192        let header = decode_header(cat_jwt)?;
193
194        let kid = header.kid.ok_or(Error::MissingKidHeader)?;
195
196        let oidc_config = {
197            let locked = self.oidc_config.read().await;
198            match locked.deref() {
199                Some(config) => config.clone(),
200                None => {
201                    drop(locked);
202                    self.get_oidc_config().await?
203                }
204            }
205        };
206
207        let jwks = {
208            let locked = self.jwks_set.read().await;
209            match locked.deref() {
210                Some(set) => set.clone(),
211                None => {
212                    drop(locked);
213                    self.get_jwkset(&oidc_config.jwks_uri).await?
214                }
215            }
216        };
217
218        let jwk = match jwks.find(&kid) {
219            Some(jwk) => jwk.clone(),
220            None => {
221                let refreshed_jwks = self.get_jwkset(&oidc_config.jwks_uri).await?;
222                refreshed_jwks
223                    .find(&kid)
224                    .ok_or(Error::MissingKidHeader)?
225                    .clone()
226            }
227        };
228
229        let decoding_key = match &jwk.algorithm {
230            AlgorithmParameters::RSA(rsa) => DecodingKey::from_rsa_components(&rsa.n, &rsa.e)?,
231            AlgorithmParameters::EllipticCurve(ecdsa) => {
232                DecodingKey::from_ec_components(&ecdsa.x, &ecdsa.y)?
233            }
234            _ => return Err(Error::UnsupportedSigningAlgo),
235        };
236
237        let validation = {
238            let mut validation = Validation::new(header.alg);
239            validation.validate_exp = true;
240            validation.validate_aud = false;
241            validation.set_issuer(&[oidc_config.issuer]);
242            validation
243        };
244
245        match decode::<HashMap<String, serde_json::Value>>(cat_jwt, &decoding_key, &validation) {
246            Ok(claims) => {
247                tracing::debug!("Successfully verified cat");
248                if let Some(client_id) = &self.client_id {
249                    validate_client_id_claims(&claims.claims, client_id)?;
250                }
251            }
252            Err(err) => {
253                tracing::debug!("Could not verify cat: {}", err);
254                return Err(err.into());
255            }
256        }
257
258        Ok(())
259    }
260
261    /// Get new access token using refresh token
262    #[cfg(feature = "wallet")]
263    pub async fn refresh_access_token(
264        &self,
265        client_id: String,
266        refresh_token: String,
267    ) -> Result<TokenResponse, Error> {
268        let token_url = self.get_oidc_config().await?.token_endpoint;
269
270        let request = RefreshTokenRequest {
271            grant_type: GrantType::RefreshToken,
272            client_id,
273            refresh_token,
274        };
275
276        let response: TokenResponse = self.client.post_form(&token_url, &request).await?;
277
278        Ok(response)
279    }
280}
281
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Client ID every test validates against.
    const EXPECTED: &str = "expected-client";

    /// Turns a JSON object into the claim map the validator takes.
    fn claim_map(value: serde_json::Value) -> HashMap<String, serde_json::Value> {
        serde_json::from_value(value).expect("claims should be an object")
    }

    /// Asserts that the given claims are refused with `Error::InvalidClientId`.
    #[track_caller]
    fn assert_rejected(value: serde_json::Value) {
        let outcome = validate_client_id_claims(&claim_map(value), EXPECTED);

        assert!(matches!(outcome, Err(Error::InvalidClientId)));
    }

    #[test]
    fn validate_client_id_claims_accepts_client_id() {
        // A matching `client_id` wins even though `azp` names another client.
        let token_claims = claim_map(json!({
            "client_id": "expected-client",
            "azp": "other-client",
        }));

        assert!(validate_client_id_claims(&token_claims, EXPECTED).is_ok());
    }

    #[test]
    fn validate_client_id_claims_accepts_azp_fallback() {
        let token_claims = claim_map(json!({
            "azp": "expected-client",
        }));

        assert!(validate_client_id_claims(&token_claims, EXPECTED).is_ok());
    }

    #[test]
    fn validate_client_id_claims_rejects_missing_claims() {
        assert_rejected(json!({
            "sub": "user",
        }));
    }

    #[test]
    fn validate_client_id_claims_rejects_non_string_client_id() {
        // A present but unusable `client_id` is not rescued by a valid `azp`.
        assert_rejected(json!({
            "client_id": null,
            "azp": "expected-client",
        }));
    }

    #[test]
    fn validate_client_id_claims_rejects_non_string_azp() {
        assert_rejected(json!({
            "azp": 42,
        }));
    }

    #[test]
    fn validate_client_id_claims_rejects_mismatch() {
        assert_rejected(json!({
            "client_id": "other-client",
        }));
    }
}