1use std::collections::HashMap;
4use std::ops::Deref;
5use std::sync::Arc;
6
7use jsonwebtoken::jwk::{AlgorithmParameters, JwkSet};
8use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
9use serde::Deserialize;
10#[cfg(feature = "wallet")]
11use serde::Serialize;
12use thiserror::Error;
13use tokio::sync::RwLock;
14use tracing::instrument;
15
16use crate::HttpClient;
17
18#[derive(Debug, Error)]
20pub enum Error {
21 #[error(transparent)]
23 Http(#[from] crate::HttpError),
24 #[error(transparent)]
26 Jwt(#[from] jsonwebtoken::errors::Error),
27 #[error("Missing kid header")]
29 MissingKidHeader,
30 #[error("Missing jwk")]
32 MissingJwkHeader,
33 #[error("Unsupported signing algo")]
35 UnsupportedSigningAlgo,
36 #[error("Invalid Client ID")]
38 InvalidClientId,
39}
40
41impl From<Error> for crate::error::Error {
42 fn from(value: Error) -> Self {
43 tracing::debug!("Clear auth verification failed: {}", value);
44 crate::error::Error::ClearAuthFailed
45 }
46}
47
48#[derive(Debug, Clone, Deserialize)]
50pub struct OidcConfig {
51 pub jwks_uri: String,
53 pub issuer: String,
55 pub token_endpoint: String,
57 pub device_authorization_endpoint: String,
59}
60
61#[derive(Debug, Clone)]
63pub struct OidcClient {
64 client: HttpClient,
65 openid_discovery: String,
66 client_id: Option<String>,
67 oidc_config: Arc<RwLock<Option<OidcConfig>>>,
68 jwks_set: Arc<RwLock<Option<JwkSet>>>,
69}
70
/// OAuth 2.0 grant types sent to the token endpoint, serialized in snake_case.
#[cfg(feature = "wallet")]
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
    /// Serialized as `refresh_token`.
    RefreshToken,
}
79
/// Form body for a refresh-token request to the token endpoint.
#[cfg(feature = "wallet")]
#[derive(Clone, Debug, Serialize)]
pub struct RefreshTokenRequest {
    /// Grant being requested.
    pub grant_type: GrantType,
    /// Client on whose behalf the refresh is performed.
    pub client_id: String,
    /// Refresh token previously issued by the provider.
    pub refresh_token: String,
}
91
/// Token endpoint response.
#[cfg(feature = "wallet")]
#[derive(Clone, Debug, Deserialize)]
pub struct TokenResponse {
    /// Newly issued access token.
    pub access_token: String,
    /// Replacement refresh token, if the provider rotated it.
    pub refresh_token: Option<String>,
    /// `expires_in` as reported by the provider (RFC 6749 §5.1 defines it in seconds).
    pub expires_in: Option<i64>,
    /// Token type as reported by the provider.
    pub token_type: String,
}
105
106impl OidcClient {
107 pub fn new(openid_discovery: String, client_id: Option<String>) -> Self {
109 Self {
110 client: HttpClient::new(),
111 openid_discovery,
112 client_id,
113 oidc_config: Arc::new(RwLock::new(None)),
114 jwks_set: Arc::new(RwLock::new(None)),
115 }
116 }
117
118 pub fn client_id(&self) -> Option<String> {
120 self.client_id.clone()
121 }
122
123 #[instrument(skip(self))]
125 pub async fn get_oidc_config(&self) -> Result<OidcConfig, Error> {
126 tracing::debug!("Getting oidc config");
127 let oidc_config: OidcConfig = self.client.fetch(&self.openid_discovery).await?;
128
129 let mut current_config = self.oidc_config.write().await;
130
131 *current_config = Some(oidc_config.clone());
132
133 Ok(oidc_config)
134 }
135
136 #[instrument(skip(self))]
138 pub async fn get_jwkset(&self, jwks_uri: &str) -> Result<JwkSet, Error> {
139 tracing::debug!("Getting jwks set");
140 let jwks_set: JwkSet = self.client.fetch(jwks_uri).await?;
141
142 let mut current_set = self.jwks_set.write().await;
143
144 *current_set = Some(jwks_set.clone());
145
146 Ok(jwks_set)
147 }
148
149 #[instrument(skip_all)]
151 pub async fn verify_cat(&self, cat_jwt: &str) -> Result<(), Error> {
152 tracing::debug!("Verifying cat");
153 let header = decode_header(cat_jwt)?;
154
155 let kid = header.kid.ok_or(Error::MissingKidHeader)?;
156
157 let oidc_config = {
158 let locked = self.oidc_config.read().await;
159 match locked.deref() {
160 Some(config) => config.clone(),
161 None => {
162 drop(locked);
163 self.get_oidc_config().await?
164 }
165 }
166 };
167
168 let jwks = {
169 let locked = self.jwks_set.read().await;
170 match locked.deref() {
171 Some(set) => set.clone(),
172 None => {
173 drop(locked);
174 self.get_jwkset(&oidc_config.jwks_uri).await?
175 }
176 }
177 };
178
179 let jwk = match jwks.find(&kid) {
180 Some(jwk) => jwk.clone(),
181 None => {
182 let refreshed_jwks = self.get_jwkset(&oidc_config.jwks_uri).await?;
183 refreshed_jwks
184 .find(&kid)
185 .ok_or(Error::MissingKidHeader)?
186 .clone()
187 }
188 };
189
190 let decoding_key = match &jwk.algorithm {
191 AlgorithmParameters::RSA(rsa) => DecodingKey::from_rsa_components(&rsa.n, &rsa.e)?,
192 AlgorithmParameters::EllipticCurve(ecdsa) => {
193 DecodingKey::from_ec_components(&ecdsa.x, &ecdsa.y)?
194 }
195 _ => return Err(Error::UnsupportedSigningAlgo),
196 };
197
198 let validation = {
199 let mut validation = Validation::new(header.alg);
200 validation.validate_exp = true;
201 validation.validate_aud = false;
202 validation.set_issuer(&[oidc_config.issuer]);
203 validation
204 };
205
206 match decode::<HashMap<String, serde_json::Value>>(cat_jwt, &decoding_key, &validation) {
207 Ok(claims) => {
208 tracing::debug!("Successfully verified cat");
209 if let Some(client_id) = &self.client_id {
210 if let Some(token_client_id) = claims.claims.get("client_id") {
211 if let Some(token_client_id_value) = token_client_id.as_str() {
212 if token_client_id_value != client_id {
213 tracing::warn!(
214 "Client ID mismatch: expected {}, got {}",
215 client_id,
216 token_client_id_value
217 );
218 return Err(Error::InvalidClientId);
219 }
220 }
221 } else if let Some(azp) = claims.claims.get("azp") {
222 if let Some(azp_value) = azp.as_str() {
223 if azp_value != client_id {
224 tracing::warn!(
225 "Client ID (azp) mismatch: expected {}, got {}",
226 client_id,
227 azp_value
228 );
229 return Err(Error::InvalidClientId);
230 }
231 }
232 }
233 }
234 }
235 Err(err) => {
236 tracing::debug!("Could not verify cat: {}", err);
237 return Err(err.into());
238 }
239 }
240
241 Ok(())
242 }
243
244 #[cfg(feature = "wallet")]
246 pub async fn refresh_access_token(
247 &self,
248 client_id: String,
249 refresh_token: String,
250 ) -> Result<TokenResponse, Error> {
251 let token_url = self.get_oidc_config().await?.token_endpoint;
252
253 let request = RefreshTokenRequest {
254 grant_type: GrantType::RefreshToken,
255 client_id,
256 refresh_token,
257 };
258
259 let response: TokenResponse = self.client.post_form(&token_url, &request).await?;
260
261 Ok(response)
262 }
263}