cdk/
oidc_client.rs

1//! Open Id Connect
2
3use std::collections::HashMap;
4use std::ops::Deref;
5use std::sync::Arc;
6
7use jsonwebtoken::jwk::{AlgorithmParameters, JwkSet};
8use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
9use reqwest::Client;
10use serde::Deserialize;
11#[cfg(feature = "wallet")]
12use serde::Serialize;
13use thiserror::Error;
14use tokio::sync::RwLock;
15use tracing::instrument;
16
17/// OIDC Error
18#[derive(Debug, Error)]
19pub enum Error {
20    /// From Reqwest error
21    #[error(transparent)]
22    Reqwest(#[from] reqwest::Error),
23    /// From Reqwest error
24    #[error(transparent)]
25    Jwt(#[from] jsonwebtoken::errors::Error),
26    /// Missing kid header
27    #[error("Missing kid header")]
28    MissingKidHeader,
29    /// Missing jwk header
30    #[error("Missing jwk")]
31    MissingJwkHeader,
32    /// Unsupported Algo
33    #[error("Unsupported signing algo")]
34    UnsupportedSigningAlgo,
35    /// Access token not returned
36    #[error("Error getting access token")]
37    AccessTokenMissing,
38}
39
40impl From<Error> for cdk_common::error::Error {
41    fn from(value: Error) -> Self {
42        tracing::debug!("Clear auth verification failed: {}", value);
43        cdk_common::error::Error::ClearAuthFailed
44    }
45}
46
47/// Open Id Config
48#[derive(Debug, Clone, Deserialize)]
49pub struct OidcConfig {
50    pub jwks_uri: String,
51    pub issuer: String,
52    pub token_endpoint: String,
53    pub device_authorization_endpoint: String,
54}
55
56/// Http Client
57#[derive(Debug, Clone)]
58pub struct OidcClient {
59    client: Client,
60    openid_discovery: String,
61    oidc_config: Arc<RwLock<Option<OidcConfig>>>,
62    jwks_set: Arc<RwLock<Option<JwkSet>>>,
63}
64
/// OAuth 2.0 `grant_type` value sent to the token endpoint.
#[cfg(feature = "wallet")]
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
    /// Serialized as `"refresh_token"`
    RefreshToken,
}
71
/// Form body for a token request using username/password credentials.
///
/// `Debug` is implemented by hand so the password never ends up in logs.
#[cfg(feature = "wallet")]
#[derive(Clone, Serialize)]
pub struct AccessTokenRequest {
    /// Value sent as `grant_type`
    pub grant_type: GrantType,
    /// OAuth client id
    pub client_id: String,
    /// Resource owner's username
    pub username: String,
    /// Resource owner's password (redacted in `Debug` output)
    pub password: String,
}

#[cfg(feature = "wallet")]
impl std::fmt::Debug for AccessTokenRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AccessTokenRequest")
            .field("grant_type", &self.grant_type)
            .field("client_id", &self.client_id)
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
80
/// Form body for exchanging a refresh token for a new access token.
///
/// `Debug` is implemented by hand so the refresh token never ends up in logs.
#[cfg(feature = "wallet")]
#[derive(Clone, Serialize)]
pub struct RefreshTokenRequest {
    /// Value sent as `grant_type`
    pub grant_type: GrantType,
    /// OAuth client id
    pub client_id: String,
    /// Refresh token to exchange (redacted in `Debug` output)
    pub refresh_token: String,
}

#[cfg(feature = "wallet")]
impl std::fmt::Debug for RefreshTokenRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshTokenRequest")
            .field("grant_type", &self.grant_type)
            .field("client_id", &self.client_id)
            .field("refresh_token", &"<redacted>")
            .finish()
    }
}
88
/// Token endpoint response.
///
/// `Debug` is implemented by hand so the bearer tokens never end up in logs.
#[cfg(feature = "wallet")]
#[derive(Clone, Deserialize)]
pub struct TokenResponse {
    /// Issued access token (redacted in `Debug` output)
    pub access_token: String,
    /// Issued refresh token, if the provider returned one (redacted in `Debug` output)
    pub refresh_token: Option<String>,
    /// `expires_in` value as returned by the provider, if present
    pub expires_in: Option<i64>,
    /// `token_type` value as returned by the provider
    pub token_type: String,
}

#[cfg(feature = "wallet")]
impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("access_token", &"<redacted>")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .field("expires_in", &self.expires_in)
            .field("token_type", &self.token_type)
            .finish()
    }
}
97
98impl OidcClient {
99    /// Create new [`OidcClient`]
100    pub fn new(openid_discovery: String) -> Self {
101        Self {
102            client: Client::new(),
103            openid_discovery,
104            oidc_config: Arc::new(RwLock::new(None)),
105            jwks_set: Arc::new(RwLock::new(None)),
106        }
107    }
108
109    /// Get config from oidc server
110    #[instrument(skip(self))]
111    pub async fn get_oidc_config(&self) -> Result<OidcConfig, Error> {
112        tracing::debug!("Getting oidc config");
113        let oidc_config = self
114            .client
115            .get(&self.openid_discovery)
116            .send()
117            .await?
118            .json::<OidcConfig>()
119            .await?;
120
121        let mut current_config = self.oidc_config.write().await;
122
123        *current_config = Some(oidc_config.clone());
124
125        Ok(oidc_config)
126    }
127
128    /// Get jwk set
129    #[instrument(skip(self))]
130    pub async fn get_jwkset(&self, jwks_uri: &str) -> Result<JwkSet, Error> {
131        tracing::debug!("Getting jwks set");
132        let jwks_set = self
133            .client
134            .get(jwks_uri)
135            .send()
136            .await?
137            .json::<JwkSet>()
138            .await?;
139
140        let mut current_set = self.jwks_set.write().await;
141
142        *current_set = Some(jwks_set.clone());
143
144        Ok(jwks_set)
145    }
146
147    /// Verify cat token
148    #[instrument(skip_all)]
149    pub async fn verify_cat(&self, cat_jwt: &str) -> Result<(), Error> {
150        tracing::debug!("Verifying cat");
151        let header = decode_header(cat_jwt)?;
152
153        let kid = header.kid.ok_or(Error::MissingKidHeader)?;
154
155        let oidc_config = {
156            let locked = self.oidc_config.read().await;
157            match locked.deref() {
158                Some(config) => config.clone(),
159                None => {
160                    drop(locked);
161                    self.get_oidc_config().await?
162                }
163            }
164        };
165
166        let jwks = {
167            let locked = self.jwks_set.read().await;
168            match locked.deref() {
169                Some(set) => set.clone(),
170                None => {
171                    drop(locked);
172                    self.get_jwkset(&oidc_config.jwks_uri).await?
173                }
174            }
175        };
176
177        let jwk = match jwks.find(&kid) {
178            Some(jwk) => jwk.clone(),
179            None => {
180                let refreshed_jwks = self.get_jwkset(&oidc_config.jwks_uri).await?;
181                refreshed_jwks
182                    .find(&kid)
183                    .ok_or(Error::MissingKidHeader)?
184                    .clone()
185            }
186        };
187
188        let decoding_key = match &jwk.algorithm {
189            AlgorithmParameters::RSA(rsa) => DecodingKey::from_rsa_components(&rsa.n, &rsa.e)?,
190            AlgorithmParameters::EllipticCurve(ecdsa) => {
191                DecodingKey::from_ec_components(&ecdsa.x, &ecdsa.y)?
192            }
193            _ => return Err(Error::UnsupportedSigningAlgo),
194        };
195
196        let validation = {
197            let mut validation = Validation::new(header.alg);
198            validation.validate_exp = true;
199            validation.validate_aud = false;
200            validation.set_issuer(&[oidc_config.issuer]);
201            validation
202        };
203
204        if let Err(err) =
205            decode::<HashMap<String, serde_json::Value>>(cat_jwt, &decoding_key, &validation)
206        {
207            tracing::debug!("Could not verify cat: {}", err);
208            return Err(err.into());
209        }
210
211        Ok(())
212    }
213
214    /// Get new access token using refresh token
215    #[cfg(feature = "wallet")]
216    pub async fn refresh_access_token(
217        &self,
218        client_id: String,
219        refresh_token: String,
220    ) -> Result<TokenResponse, Error> {
221        let token_url = self.get_oidc_config().await?.token_endpoint;
222
223        let request = RefreshTokenRequest {
224            grant_type: GrantType::RefreshToken,
225            client_id,
226            refresh_token,
227        };
228
229        let response = self
230            .client
231            .post(token_url)
232            .form(&request)
233            .send()
234            .await?
235            .json::<TokenResponse>()
236            .await?;
237
238        Ok(response)
239    }
240}