1use std::collections::HashMap;
4use std::ops::Deref;
5use std::sync::Arc;
6
7use jsonwebtoken::jwk::{AlgorithmParameters, JwkSet};
8use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
9use reqwest::Client;
10use serde::Deserialize;
11#[cfg(feature = "wallet")]
12use serde::Serialize;
13use thiserror::Error;
14use tokio::sync::RwLock;
15use tracing::instrument;
16
17#[derive(Debug, Error)]
19pub enum Error {
20 #[error(transparent)]
22 Reqwest(#[from] reqwest::Error),
23 #[error(transparent)]
25 Jwt(#[from] jsonwebtoken::errors::Error),
26 #[error("Missing kid header")]
28 MissingKidHeader,
29 #[error("Missing jwk")]
31 MissingJwkHeader,
32 #[error("Unsupported signing algo")]
34 UnsupportedSigningAlgo,
35 #[error("Error getting access token")]
37 AccessTokenMissing,
38}
39
40impl From<Error> for cdk_common::error::Error {
41 fn from(value: Error) -> Self {
42 tracing::debug!("Clear auth verification failed: {}", value);
43 cdk_common::error::Error::ClearAuthFailed
44 }
45}
46
47#[derive(Debug, Clone, Deserialize)]
49pub struct OidcConfig {
50 pub jwks_uri: String,
51 pub issuer: String,
52 pub token_endpoint: String,
53 pub device_authorization_endpoint: String,
54}
55
56#[derive(Debug, Clone)]
58pub struct OidcClient {
59 client: Client,
60 openid_discovery: String,
61 oidc_config: Arc<RwLock<Option<OidcConfig>>>,
62 jwks_set: Arc<RwLock<Option<JwkSet>>>,
63}
64
/// Value of the `grant_type` field sent to the token endpoint.
///
/// Variants are serialized in snake case, so `RefreshToken` becomes `refresh_token`.
#[cfg(feature = "wallet")]
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
    /// Exchange a refresh token for a new access token.
    RefreshToken,
}
71
/// Request body for obtaining an access token with a username and password.
// NOTE(review): `GrantType` only offers `RefreshToken`; confirm how this request is
// meant to express a password grant.
#[cfg(feature = "wallet")]
#[derive(Clone, Debug, Serialize)]
pub struct AccessTokenRequest {
    /// Grant type sent with the request.
    pub grant_type: GrantType,
    /// Client identifier sent with the request.
    pub client_id: String,
    /// User name of the account.
    pub username: String,
    /// Password of the account.
    pub password: String,
}
80
/// Request body for exchanging a refresh token for a new access token.
#[cfg(feature = "wallet")]
#[derive(Clone, Debug, Serialize)]
pub struct RefreshTokenRequest {
    /// Grant type sent with the request.
    pub grant_type: GrantType,
    /// Client identifier sent with the request.
    pub client_id: String,
    /// Refresh token being exchanged.
    pub refresh_token: String,
}
88
/// Response body deserialized from the token endpoint.
#[cfg(feature = "wallet")]
#[derive(Clone, Debug, Deserialize)]
pub struct TokenResponse {
    /// The issued access token.
    pub access_token: String,
    /// Replacement refresh token, if the response contains one.
    pub refresh_token: Option<String>,
    /// Value of the `expires_in` field, if the response contains one.
    pub expires_in: Option<i64>,
    /// Value of the `token_type` field.
    pub token_type: String,
}
97
98impl OidcClient {
99 pub fn new(openid_discovery: String) -> Self {
101 Self {
102 client: Client::new(),
103 openid_discovery,
104 oidc_config: Arc::new(RwLock::new(None)),
105 jwks_set: Arc::new(RwLock::new(None)),
106 }
107 }
108
109 #[instrument(skip(self))]
111 pub async fn get_oidc_config(&self) -> Result<OidcConfig, Error> {
112 tracing::debug!("Getting oidc config");
113 let oidc_config = self
114 .client
115 .get(&self.openid_discovery)
116 .send()
117 .await?
118 .json::<OidcConfig>()
119 .await?;
120
121 let mut current_config = self.oidc_config.write().await;
122
123 *current_config = Some(oidc_config.clone());
124
125 Ok(oidc_config)
126 }
127
128 #[instrument(skip(self))]
130 pub async fn get_jwkset(&self, jwks_uri: &str) -> Result<JwkSet, Error> {
131 tracing::debug!("Getting jwks set");
132 let jwks_set = self
133 .client
134 .get(jwks_uri)
135 .send()
136 .await?
137 .json::<JwkSet>()
138 .await?;
139
140 let mut current_set = self.jwks_set.write().await;
141
142 *current_set = Some(jwks_set.clone());
143
144 Ok(jwks_set)
145 }
146
147 #[instrument(skip_all)]
149 pub async fn verify_cat(&self, cat_jwt: &str) -> Result<(), Error> {
150 tracing::debug!("Verifying cat");
151 let header = decode_header(cat_jwt)?;
152
153 let kid = header.kid.ok_or(Error::MissingKidHeader)?;
154
155 let oidc_config = {
156 let locked = self.oidc_config.read().await;
157 match locked.deref() {
158 Some(config) => config.clone(),
159 None => {
160 drop(locked);
161 self.get_oidc_config().await?
162 }
163 }
164 };
165
166 let jwks = {
167 let locked = self.jwks_set.read().await;
168 match locked.deref() {
169 Some(set) => set.clone(),
170 None => {
171 drop(locked);
172 self.get_jwkset(&oidc_config.jwks_uri).await?
173 }
174 }
175 };
176
177 let jwk = match jwks.find(&kid) {
178 Some(jwk) => jwk.clone(),
179 None => {
180 let refreshed_jwks = self.get_jwkset(&oidc_config.jwks_uri).await?;
181 refreshed_jwks
182 .find(&kid)
183 .ok_or(Error::MissingKidHeader)?
184 .clone()
185 }
186 };
187
188 let decoding_key = match &jwk.algorithm {
189 AlgorithmParameters::RSA(rsa) => DecodingKey::from_rsa_components(&rsa.n, &rsa.e)?,
190 AlgorithmParameters::EllipticCurve(ecdsa) => {
191 DecodingKey::from_ec_components(&ecdsa.x, &ecdsa.y)?
192 }
193 _ => return Err(Error::UnsupportedSigningAlgo),
194 };
195
196 let validation = {
197 let mut validation = Validation::new(header.alg);
198 validation.validate_exp = true;
199 validation.validate_aud = false;
200 validation.set_issuer(&[oidc_config.issuer]);
201 validation
202 };
203
204 if let Err(err) =
205 decode::<HashMap<String, serde_json::Value>>(cat_jwt, &decoding_key, &validation)
206 {
207 tracing::debug!("Could not verify cat: {}", err);
208 return Err(err.into());
209 }
210
211 Ok(())
212 }
213
214 #[cfg(feature = "wallet")]
216 pub async fn refresh_access_token(
217 &self,
218 client_id: String,
219 refresh_token: String,
220 ) -> Result<TokenResponse, Error> {
221 let token_url = self.get_oidc_config().await?.token_endpoint;
222
223 let request = RefreshTokenRequest {
224 grant_type: GrantType::RefreshToken,
225 client_id,
226 refresh_token,
227 };
228
229 let response = self
230 .client
231 .post(token_url)
232 .form(&request)
233 .send()
234 .await?
235 .json::<TokenResponse>()
236 .await?;
237
238 Ok(response)
239 }
240}