1use std::collections::HashMap;
4use std::ops::Deref;
5use std::sync::Arc;
6
7use jsonwebtoken::jwk::{AlgorithmParameters, JwkSet};
8use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
9use serde::Deserialize;
10#[cfg(feature = "wallet")]
11use serde::Serialize;
12use thiserror::Error;
13use tokio::sync::RwLock;
14use tracing::instrument;
15
16use crate::HttpClient;
17
18fn validate_client_id_claim(
19 claim_name: &str,
20 claim_value: &serde_json::Value,
21 client_id: &str,
22) -> Result<(), Error> {
23 let Some(token_client_id) = claim_value.as_str() else {
24 tracing::warn!("{} claim is not a string", claim_name);
25 return Err(Error::InvalidClientId);
26 };
27
28 if token_client_id != client_id {
29 tracing::warn!(
30 "Client ID ({}) mismatch: expected {}, got {}",
31 claim_name,
32 client_id,
33 token_client_id
34 );
35 return Err(Error::InvalidClientId);
36 }
37
38 Ok(())
39}
40
41fn validate_client_id_claims(
42 claims: &HashMap<String, serde_json::Value>,
43 client_id: &str,
44) -> Result<(), Error> {
45 match claims.get("client_id") {
46 Some(token_client_id) => validate_client_id_claim("client_id", token_client_id, client_id),
47 None => match claims.get("azp") {
48 Some(azp) => validate_client_id_claim("azp", azp, client_id),
49 None => {
50 tracing::warn!("CAT missing client_id or azp claim for configured client ID");
51 Err(Error::InvalidClientId)
52 }
53 },
54 }
55}
56
/// Errors produced while verifying clear auth tokens (CATs) and while talking
/// to the OIDC provider.
#[derive(Debug, Error)]
pub enum Error {
    /// An HTTP request to the OIDC provider (discovery document, JWKS, or
    /// token endpoint) failed.
    #[error(transparent)]
    Http(#[from] crate::HttpError),
    /// An error reported by the `jsonwebtoken` crate: header decoding, building
    /// the verification key, or validating the token.
    #[error(transparent)]
    Jwt(#[from] jsonwebtoken::errors::Error),
    /// The JWT header carries no `kid` (key ID).
    #[error("Missing kid header")]
    MissingKidHeader,
    /// A required JWK could not be found.
    #[error("Missing jwk")]
    MissingJwkHeader,
    /// The JWK's key type is not one a verification key can be built from here.
    #[error("Unsupported signing algo")]
    UnsupportedSigningAlgo,
    /// The token's `client_id`/`azp` claim is missing, is not a string, or does
    /// not match the configured client ID.
    #[error("Invalid Client ID")]
    InvalidClientId,
}
79
80impl From<Error> for crate::error::Error {
81 fn from(value: Error) -> Self {
82 tracing::debug!("Clear auth verification failed: {}", value);
83 crate::error::Error::ClearAuthFailed
84 }
85}
86
/// The subset of the OpenID Connect discovery document used by this module.
#[derive(Debug, Clone, Deserialize)]
pub struct OidcConfig {
    /// URL of the provider's JSON Web Key Set, used to find token signing keys.
    pub jwks_uri: String,
    /// Issuer identifier; tokens must carry this value in their `iss` claim.
    pub issuer: String,
    /// Token endpoint, used when refreshing access tokens.
    pub token_endpoint: String,
    /// Device authorization endpoint (not used by the code in this file).
    pub device_authorization_endpoint: String,
}
99
/// Client for an OIDC provider that verifies clear auth tokens and, with the
/// `wallet` feature, refreshes access tokens.
///
/// Clones share the cached discovery document and key set, because both caches
/// are held behind `Arc`.
#[derive(Debug, Clone)]
pub struct OidcClient {
    /// HTTP client used for all requests to the provider.
    client: HttpClient,
    /// URL of the OpenID discovery document.
    openid_discovery: String,
    /// When set, verified tokens must carry a matching `client_id`/`azp` claim.
    client_id: Option<String>,
    /// Cached discovery document; `None` until first fetched.
    oidc_config: Arc<RwLock<Option<OidcConfig>>>,
    /// Cached JWKS; `None` until first fetched.
    jwks_set: Arc<RwLock<Option<JwkSet>>>,
}
109
#[cfg(feature = "wallet")]
/// OAuth 2.0 grant types sent to the token endpoint.
///
/// Serialized in snake_case, so `RefreshToken` becomes `"refresh_token"`.
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
    /// Exchange a refresh token for a new access token.
    RefreshToken,
}
118
#[cfg(feature = "wallet")]
/// Request body sent to the token endpoint to refresh an access token.
#[derive(Debug, Clone, Serialize)]
pub struct RefreshTokenRequest {
    /// Grant type; `refresh_token` for this request.
    pub grant_type: GrantType,
    /// Client ID the refresh token was issued to.
    pub client_id: String,
    /// The refresh token being exchanged.
    pub refresh_token: String,
}
130
#[cfg(feature = "wallet")]
/// Token endpoint response to a refresh request.
#[derive(Debug, Clone, Deserialize)]
pub struct TokenResponse {
    /// The newly issued access token.
    pub access_token: String,
    /// A replacement refresh token, if the provider issued one.
    pub refresh_token: Option<String>,
    /// Access token lifetime as reported by the provider, if present
    /// (presumably seconds, per OAuth 2.0 — confirm against the provider).
    pub expires_in: Option<i64>,
    /// Token type as reported by the provider.
    pub token_type: String,
}
144
145impl OidcClient {
146 pub fn new(openid_discovery: String, client_id: Option<String>) -> Self {
148 Self {
149 client: HttpClient::new(),
150 openid_discovery,
151 client_id,
152 oidc_config: Arc::new(RwLock::new(None)),
153 jwks_set: Arc::new(RwLock::new(None)),
154 }
155 }
156
157 pub fn client_id(&self) -> Option<String> {
159 self.client_id.clone()
160 }
161
162 #[instrument(skip(self))]
164 pub async fn get_oidc_config(&self) -> Result<OidcConfig, Error> {
165 tracing::debug!("Getting oidc config");
166 let oidc_config: OidcConfig = self.client.fetch(&self.openid_discovery).await?;
167
168 let mut current_config = self.oidc_config.write().await;
169
170 *current_config = Some(oidc_config.clone());
171
172 Ok(oidc_config)
173 }
174
175 #[instrument(skip(self))]
177 pub async fn get_jwkset(&self, jwks_uri: &str) -> Result<JwkSet, Error> {
178 tracing::debug!("Getting jwks set");
179 let jwks_set: JwkSet = self.client.fetch(jwks_uri).await?;
180
181 let mut current_set = self.jwks_set.write().await;
182
183 *current_set = Some(jwks_set.clone());
184
185 Ok(jwks_set)
186 }
187
188 #[instrument(skip_all)]
190 pub async fn verify_cat(&self, cat_jwt: &str) -> Result<(), Error> {
191 tracing::debug!("Verifying cat");
192 let header = decode_header(cat_jwt)?;
193
194 let kid = header.kid.ok_or(Error::MissingKidHeader)?;
195
196 let oidc_config = {
197 let locked = self.oidc_config.read().await;
198 match locked.deref() {
199 Some(config) => config.clone(),
200 None => {
201 drop(locked);
202 self.get_oidc_config().await?
203 }
204 }
205 };
206
207 let jwks = {
208 let locked = self.jwks_set.read().await;
209 match locked.deref() {
210 Some(set) => set.clone(),
211 None => {
212 drop(locked);
213 self.get_jwkset(&oidc_config.jwks_uri).await?
214 }
215 }
216 };
217
218 let jwk = match jwks.find(&kid) {
219 Some(jwk) => jwk.clone(),
220 None => {
221 let refreshed_jwks = self.get_jwkset(&oidc_config.jwks_uri).await?;
222 refreshed_jwks
223 .find(&kid)
224 .ok_or(Error::MissingKidHeader)?
225 .clone()
226 }
227 };
228
229 let decoding_key = match &jwk.algorithm {
230 AlgorithmParameters::RSA(rsa) => DecodingKey::from_rsa_components(&rsa.n, &rsa.e)?,
231 AlgorithmParameters::EllipticCurve(ecdsa) => {
232 DecodingKey::from_ec_components(&ecdsa.x, &ecdsa.y)?
233 }
234 _ => return Err(Error::UnsupportedSigningAlgo),
235 };
236
237 let validation = {
238 let mut validation = Validation::new(header.alg);
239 validation.validate_exp = true;
240 validation.validate_aud = false;
241 validation.set_issuer(&[oidc_config.issuer]);
242 validation
243 };
244
245 match decode::<HashMap<String, serde_json::Value>>(cat_jwt, &decoding_key, &validation) {
246 Ok(claims) => {
247 tracing::debug!("Successfully verified cat");
248 if let Some(client_id) = &self.client_id {
249 validate_client_id_claims(&claims.claims, client_id)?;
250 }
251 }
252 Err(err) => {
253 tracing::debug!("Could not verify cat: {}", err);
254 return Err(err.into());
255 }
256 }
257
258 Ok(())
259 }
260
261 #[cfg(feature = "wallet")]
263 pub async fn refresh_access_token(
264 &self,
265 client_id: String,
266 refresh_token: String,
267 ) -> Result<TokenResponse, Error> {
268 let token_url = self.get_oidc_config().await?.token_endpoint;
269
270 let request = RefreshTokenRequest {
271 grant_type: GrantType::RefreshToken,
272 client_id,
273 refresh_token,
274 };
275
276 let response: TokenResponse = self.client.post_form(&token_url, &request).await?;
277
278 Ok(response)
279 }
280}
281
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Build a claims map from a JSON object literal.
    fn claims(value: serde_json::Value) -> HashMap<String, serde_json::Value> {
        serde_json::from_value(value).expect("claims should be an object")
    }

    #[test]
    fn validate_client_id_claims_accepts_client_id() {
        let claims = claims(json!({
            "client_id": "expected-client",
            "azp": "other-client",
        }));

        assert!(validate_client_id_claims(&claims, "expected-client").is_ok());
    }

    #[test]
    fn validate_client_id_claims_accepts_azp_fallback() {
        let claims = claims(json!({
            "azp": "expected-client",
        }));

        assert!(validate_client_id_claims(&claims, "expected-client").is_ok());
    }

    #[test]
    fn validate_client_id_claims_rejects_missing_claims() {
        let claims = claims(json!({
            "sub": "user",
        }));

        assert!(matches!(
            validate_client_id_claims(&claims, "expected-client"),
            Err(Error::InvalidClientId)
        ));
    }

    #[test]
    fn validate_client_id_claims_rejects_non_string_client_id() {
        let claims = claims(json!({
            "client_id": null,
            "azp": "expected-client",
        }));

        assert!(matches!(
            validate_client_id_claims(&claims, "expected-client"),
            Err(Error::InvalidClientId)
        ));
    }

    #[test]
    fn validate_client_id_claims_rejects_non_string_azp() {
        let claims = claims(json!({
            "azp": 42,
        }));

        assert!(matches!(
            validate_client_id_claims(&claims, "expected-client"),
            Err(Error::InvalidClientId)
        ));
    }

    #[test]
    fn validate_client_id_claims_rejects_mismatch() {
        let claims = claims(json!({
            "client_id": "other-client",
        }));

        assert!(matches!(
            validate_client_id_claims(&claims, "expected-client"),
            Err(Error::InvalidClientId)
        ));
    }

    #[test]
    fn validate_client_id_claims_rejects_azp_mismatch() {
        let claims = claims(json!({
            "azp": "other-client",
        }));

        assert!(matches!(
            validate_client_id_claims(&claims, "expected-client"),
            Err(Error::InvalidClientId)
        ));
    }

    #[test]
    fn validate_client_id_claims_prefers_client_id_over_azp() {
        // A mismatching `client_id` must not be rescued by a matching `azp`.
        let claims = claims(json!({
            "client_id": "other-client",
            "azp": "expected-client",
        }));

        assert!(matches!(
            validate_client_id_claims(&claims, "expected-client"),
            Err(Error::InvalidClientId)
        ));
    }
}