Skip to main content

securitydept_oauth_provider/
runtime.rs

1use std::time::Instant;
2
3use openidconnect::{
4    AccessToken, AuthUrl, ClientId, ClientSecret, DeviceAuthorizationUrl, IntrospectionUrl,
5    IssuerUrl, JsonWebKeySet, JsonWebKeySetUrl, ResponseTypes, RevocationUrl, TokenUrl,
6    UserInfoUrl,
7    core::{
8        CoreClient, CoreJsonWebKeySet, CoreResponseType, CoreSubjectIdentifierType,
9        CoreTokenIntrospectionResponse,
10    },
11    reqwest,
12};
13use tokio::sync::RwLock;
14
15use crate::{
16    OAuthProviderConfig, OAuthProviderError, OAuthProviderMetadata, OAuthProviderResult,
17    ProviderMetadataWithExtra, config::default_id_token_signing_alg_values_supported,
18    models::ExtraProviderMetadata,
19};
20
21struct ProviderState {
22    metadata: OAuthProviderMetadata,
23    metadata_fetched_at: Instant,
24    jwks_fetched_at: Instant,
25}
26
27pub struct OAuthProviderRuntime {
28    config: OAuthProviderConfig,
29    http_client: reqwest::Client,
30    state: RwLock<ProviderState>,
31}
32
33impl OAuthProviderRuntime {
34    pub async fn from_config(config: OAuthProviderConfig) -> OAuthProviderResult<Self> {
35        config.validate()?;
36
37        let http_client =
38            reqwest::Client::builder()
39                .build()
40                .map_err(|e| OAuthProviderError::HttpClient {
41                    message: format!("Failed to build HTTP client: {e}"),
42                })?;
43
44        let metadata = fetch_metadata(&config, &http_client).await?;
45        Ok(Self {
46            config,
47            http_client,
48            state: RwLock::new(ProviderState {
49                metadata,
50                metadata_fetched_at: Instant::now(),
51                jwks_fetched_at: Instant::now(),
52            }),
53        })
54    }
55
56    pub fn http_client(&self) -> &reqwest::Client {
57        &self.http_client
58    }
59
60    pub async fn metadata(&self) -> OAuthProviderResult<OAuthProviderMetadata> {
61        self.ensure_metadata_and_jwks_fresh().await?;
62        Ok(self.state.read().await.metadata.clone())
63    }
64
65    pub async fn oidc_provider_metadata(&self) -> OAuthProviderResult<ProviderMetadataWithExtra> {
66        let metadata = self.metadata().await?;
67        to_oidc_provider_metadata(&metadata)
68    }
69
70    pub async fn jwks(&self) -> OAuthProviderResult<CoreJsonWebKeySet> {
71        Ok(self.metadata().await?.jwks)
72    }
73
74    pub async fn refresh_jwks(&self) -> OAuthProviderResult<OAuthProviderMetadata> {
75        let jwks_uri = { self.state.read().await.metadata.jwks_uri.clone() };
76        let jwks = fetch_jwks(&jwks_uri, &self.http_client).await?;
77
78        let mut state = self.state.write().await;
79        state.metadata.jwks = jwks;
80        state.jwks_fetched_at = Instant::now();
81        Ok(state.metadata.clone())
82    }
83
84    pub async fn refresh_metadata(&self) -> OAuthProviderResult<OAuthProviderMetadata> {
85        if self.config.remote.well_known_url.is_none() {
86            return Ok(self.state.read().await.metadata.clone());
87        }
88
89        let metadata = fetch_metadata(&self.config, &self.http_client).await?;
90        let mut state = self.state.write().await;
91        state.metadata = metadata;
92        state.metadata_fetched_at = Instant::now();
93        state.jwks_fetched_at = Instant::now();
94        Ok(state.metadata.clone())
95    }
96
97    pub async fn introspect(
98        &self,
99        client_id: &str,
100        client_secret: Option<&str>,
101        token: &str,
102        token_type_hint: Option<&str>,
103    ) -> OAuthProviderResult<CoreTokenIntrospectionResponse> {
104        let metadata = self.metadata().await?;
105        let introspection_url = metadata.introspection_endpoint.clone().ok_or_else(|| {
106            OAuthProviderError::InvalidConfig {
107                message: "introspection endpoint is not configured and was not discovered"
108                    .to_string(),
109            }
110        })?;
111
112        let client = if let Some(client_secret) = client_secret {
113            CoreClient::new(
114                ClientId::new(client_id.to_string()),
115                metadata.issuer.clone(),
116                CoreJsonWebKeySet::new(vec![]),
117            )
118            .set_client_secret(ClientSecret::new(client_secret.to_string()))
119            .set_introspection_url(introspection_url)
120        } else {
121            CoreClient::new(
122                ClientId::new(client_id.to_string()),
123                metadata.issuer.clone(),
124                CoreJsonWebKeySet::new(vec![]),
125            )
126            .set_introspection_url(introspection_url)
127        };
128
129        let access_token = AccessToken::new(token.to_string());
130        let mut request = client.introspect(&access_token);
131        if let Some(token_type_hint) = token_type_hint {
132            request = request.set_token_type_hint(token_type_hint);
133        }
134
135        request.request_async(&self.http_client).await.map_err(|e| {
136            OAuthProviderError::Introspection {
137                message: format!("Opaque token introspection failed: {e}"),
138            }
139        })
140    }
141
142    async fn ensure_metadata_and_jwks_fresh(&self) -> OAuthProviderResult<()> {
143        self.ensure_metadata_fresh().await?;
144        self.ensure_jwks_fresh().await
145    }
146
147    async fn ensure_metadata_fresh(&self) -> OAuthProviderResult<()> {
148        if self.config.remote.metadata_refresh_interval.is_zero()
149            || self.config.remote.well_known_url.is_none()
150        {
151            return Ok(());
152        }
153
154        let should_refresh = {
155            let state = self.state.read().await;
156            state.metadata_fetched_at.elapsed() >= self.config.remote.metadata_refresh_interval
157        };
158
159        if should_refresh {
160            let _ = self.refresh_metadata().await?;
161        }
162
163        Ok(())
164    }
165
166    async fn ensure_jwks_fresh(&self) -> OAuthProviderResult<()> {
167        if self.config.remote.jwks_refresh_interval.is_zero() {
168            return Ok(());
169        }
170
171        let should_refresh = {
172            let state = self.state.read().await;
173            state.jwks_fetched_at.elapsed() >= self.config.remote.jwks_refresh_interval
174        };
175
176        if should_refresh {
177            let _ = self.refresh_jwks().await?;
178        }
179
180        Ok(())
181    }
182}
183
184async fn fetch_metadata(
185    config: &OAuthProviderConfig,
186    http_client: &reqwest::Client,
187) -> OAuthProviderResult<OAuthProviderMetadata> {
188    if let Some(well_known_url) = config.remote.well_known_url.as_deref() {
189        let response = http_client.get(well_known_url).send().await.map_err(|e| {
190            OAuthProviderError::Metadata {
191                message: format!("Failed to fetch discovery document: {e}"),
192            }
193        })?;
194        let body = response
195            .bytes()
196            .await
197            .map_err(|e| OAuthProviderError::Metadata {
198                message: format!("Failed to read discovery document: {e}"),
199            })?;
200
201        let mut metadata: ProviderMetadataWithExtra =
202            serde_json::from_slice(&body).map_err(|e| OAuthProviderError::Metadata {
203                message: format!("Failed to parse discovery document: {e}"),
204            })?;
205
206        if let Some(issuer_url) = config.remote.issuer_url.as_ref() {
207            metadata = metadata.set_issuer(IssuerUrl::new(issuer_url.clone()).map_err(|e| {
208                OAuthProviderError::Metadata {
209                    message: format!("Invalid issuer_url: {e}"),
210                }
211            })?);
212        }
213        if let Some(authorization_endpoint) = config.oidc.authorization_endpoint.as_ref() {
214            metadata = metadata.set_authorization_endpoint(
215                AuthUrl::new(authorization_endpoint.clone()).map_err(|e| {
216                    OAuthProviderError::Metadata {
217                        message: format!("Invalid authorization_endpoint: {e}"),
218                    }
219                })?,
220            );
221        }
222        if let Some(token_endpoint) = config.oidc.token_endpoint.as_ref() {
223            metadata =
224                metadata.set_token_endpoint(Some(TokenUrl::new(token_endpoint.clone()).map_err(
225                    |e| OAuthProviderError::Metadata {
226                        message: format!("Invalid token_endpoint: {e}"),
227                    },
228                )?));
229        }
230        if let Some(userinfo_endpoint) = config.oidc.userinfo_endpoint.as_ref() {
231            metadata = metadata.set_userinfo_endpoint(Some(
232                UserInfoUrl::new(userinfo_endpoint.clone()).map_err(|e| {
233                    OAuthProviderError::Metadata {
234                        message: format!("Invalid userinfo_endpoint: {e}"),
235                    }
236                })?,
237            ));
238        }
239        if let Some(jwks_uri) = config.remote.jwks_uri.as_ref() {
240            metadata =
241                metadata.set_jwks_uri(JsonWebKeySetUrl::new(jwks_uri.clone()).map_err(|e| {
242                    OAuthProviderError::Metadata {
243                        message: format!("Invalid jwks_uri: {e}"),
244                    }
245                })?);
246        }
247        if let Some(introspection_endpoint) = config.oidc.introspection_endpoint.as_ref() {
248            metadata.additional_metadata_mut().introspection_endpoint =
249                Some(introspection_endpoint.clone());
250        }
251        if let Some(revocation_endpoint) = config.oidc.revocation_endpoint.as_ref() {
252            metadata.additional_metadata_mut().revocation_endpoint =
253                Some(revocation_endpoint.clone());
254        }
255        if let Some(device_authorization_endpoint) =
256            config.oidc.device_authorization_endpoint.as_ref()
257        {
258            metadata
259                .additional_metadata_mut()
260                .device_authorization_endpoint = Some(device_authorization_endpoint.clone());
261        }
262        if let Some(token_endpoint_auth_methods_supported) =
263            config.oidc.token_endpoint_auth_methods_supported.as_ref()
264        {
265            metadata = metadata.set_token_endpoint_auth_methods_supported(Some(
266                token_endpoint_auth_methods_supported.clone(),
267            ));
268        }
269        if let Some(id_token_signing_alg_values_supported) =
270            config.oidc.id_token_signing_alg_values_supported.as_ref()
271        {
272            metadata = metadata.set_id_token_signing_alg_values_supported(
273                id_token_signing_alg_values_supported.clone(),
274            );
275        }
276        if let Some(userinfo_signing_alg_values_supported) =
277            config.oidc.userinfo_signing_alg_values_supported.as_ref()
278        {
279            metadata = metadata.set_userinfo_signing_alg_values_supported(Some(
280                userinfo_signing_alg_values_supported.clone(),
281            ));
282        }
283
284        let jwks = fetch_jwks(metadata.jwks_uri(), http_client).await?;
285        return from_provider_metadata(metadata.set_jwks(jwks));
286    }
287
288    let issuer =
289        IssuerUrl::new(config.remote.issuer_url.clone().unwrap_or_default()).map_err(|e| {
290            OAuthProviderError::Metadata {
291                message: format!("Invalid issuer_url: {e}"),
292            }
293        })?;
294    let jwks_uri = JsonWebKeySetUrl::new(config.remote.jwks_uri.clone().unwrap_or_default())
295        .map_err(|e| OAuthProviderError::Metadata {
296            message: format!("Invalid jwks_uri: {e}"),
297        })?;
298    let jwks = fetch_jwks(&jwks_uri, http_client).await?;
299
300    Ok(OAuthProviderMetadata {
301        issuer,
302        authorization_endpoint: config
303            .oidc
304            .authorization_endpoint
305            .as_ref()
306            .map(|value| AuthUrl::new(value.clone()))
307            .transpose()
308            .map_err(|e| OAuthProviderError::Metadata {
309                message: format!("Invalid authorization_endpoint: {e}"),
310            })?,
311        token_endpoint: config
312            .oidc
313            .token_endpoint
314            .as_ref()
315            .map(|value| TokenUrl::new(value.clone()))
316            .transpose()
317            .map_err(|e| OAuthProviderError::Metadata {
318                message: format!("Invalid token_endpoint: {e}"),
319            })?,
320        userinfo_endpoint: config
321            .oidc
322            .userinfo_endpoint
323            .as_ref()
324            .map(|value| UserInfoUrl::new(value.clone()))
325            .transpose()
326            .map_err(|e| OAuthProviderError::Metadata {
327                message: format!("Invalid userinfo_endpoint: {e}"),
328            })?,
329        introspection_endpoint: config
330            .oidc
331            .introspection_endpoint
332            .as_ref()
333            .map(|value| IntrospectionUrl::new(value.clone()))
334            .transpose()
335            .map_err(|e| OAuthProviderError::Metadata {
336                message: format!("Invalid introspection_endpoint: {e}"),
337            })?,
338        revocation_endpoint: config
339            .oidc
340            .revocation_endpoint
341            .as_ref()
342            .map(|value| RevocationUrl::new(value.clone()))
343            .transpose()
344            .map_err(|e| OAuthProviderError::Metadata {
345                message: format!("Invalid revocation_endpoint: {e}"),
346            })?,
347        device_authorization_endpoint: config
348            .oidc
349            .device_authorization_endpoint
350            .as_ref()
351            .map(|value| DeviceAuthorizationUrl::new(value.clone()))
352            .transpose()
353            .map_err(|e| OAuthProviderError::Metadata {
354                message: format!("Invalid device_authorization_endpoint: {e}"),
355            })?,
356        jwks_uri,
357        jwks,
358        token_endpoint_auth_methods_supported: config
359            .oidc
360            .token_endpoint_auth_methods_supported
361            .clone(),
362        response_types_supported: vec![ResponseTypes::new(vec![CoreResponseType::Code])],
363        subject_types_supported: vec![CoreSubjectIdentifierType::Public],
364        id_token_signing_alg_values_supported: config
365            .oidc
366            .id_token_signing_alg_values_supported
367            .clone()
368            .unwrap_or_else(default_id_token_signing_alg_values_supported),
369        userinfo_signing_alg_values_supported: config
370            .oidc
371            .userinfo_signing_alg_values_supported
372            .clone(),
373        additional_metadata: ExtraProviderMetadata {
374            introspection_endpoint: config.oidc.introspection_endpoint.clone(),
375            revocation_endpoint: config.oidc.revocation_endpoint.clone(),
376            device_authorization_endpoint: config.oidc.device_authorization_endpoint.clone(),
377            extra: Default::default(),
378        },
379    })
380}
381
382fn from_provider_metadata(
383    metadata: ProviderMetadataWithExtra,
384) -> OAuthProviderResult<OAuthProviderMetadata> {
385    Ok(OAuthProviderMetadata {
386        issuer: metadata.issuer().clone(),
387        authorization_endpoint: Some(metadata.authorization_endpoint().clone()),
388        token_endpoint: metadata.token_endpoint().cloned(),
389        userinfo_endpoint: metadata.userinfo_endpoint().cloned(),
390        introspection_endpoint: metadata
391            .additional_metadata()
392            .introspection_endpoint
393            .as_ref()
394            .map(|value| IntrospectionUrl::new(value.clone()))
395            .transpose()
396            .map_err(|e| OAuthProviderError::Metadata {
397                message: format!("Invalid introspection_endpoint: {e}"),
398            })?,
399        revocation_endpoint: metadata
400            .additional_metadata()
401            .revocation_endpoint
402            .as_ref()
403            .map(|value| RevocationUrl::new(value.clone()))
404            .transpose()
405            .map_err(|e| OAuthProviderError::Metadata {
406                message: format!("Invalid revocation_endpoint: {e}"),
407            })?,
408        device_authorization_endpoint: metadata
409            .additional_metadata()
410            .device_authorization_endpoint
411            .as_ref()
412            .map(|value| DeviceAuthorizationUrl::new(value.clone()))
413            .transpose()
414            .map_err(|e| OAuthProviderError::Metadata {
415                message: format!("Invalid device_authorization_endpoint: {e}"),
416            })?,
417        jwks_uri: metadata.jwks_uri().clone(),
418        jwks: metadata.jwks().clone(),
419        token_endpoint_auth_methods_supported: metadata
420            .token_endpoint_auth_methods_supported()
421            .cloned(),
422        response_types_supported: metadata.response_types_supported().clone(),
423        subject_types_supported: metadata.subject_types_supported().clone(),
424        id_token_signing_alg_values_supported: metadata
425            .id_token_signing_alg_values_supported()
426            .clone(),
427        userinfo_signing_alg_values_supported: metadata
428            .userinfo_signing_alg_values_supported()
429            .cloned(),
430        additional_metadata: metadata.additional_metadata().clone(),
431    })
432}
433
434fn to_oidc_provider_metadata(
435    metadata: &OAuthProviderMetadata,
436) -> OAuthProviderResult<ProviderMetadataWithExtra> {
437    let authorization_endpoint = metadata.authorization_endpoint.clone().ok_or_else(|| {
438        OAuthProviderError::InvalidConfig {
439            message: "authorization_endpoint is required to build an OIDC client".to_string(),
440        }
441    })?;
442
443    Ok(ProviderMetadataWithExtra::new(
444        metadata.issuer.clone(),
445        authorization_endpoint,
446        metadata.jwks_uri.clone(),
447        metadata.response_types_supported.clone(),
448        metadata.subject_types_supported.clone(),
449        metadata.id_token_signing_alg_values_supported.clone(),
450        metadata.additional_metadata.clone(),
451    )
452    .set_jwks(metadata.jwks.clone())
453    .set_token_endpoint(metadata.token_endpoint.clone())
454    .set_userinfo_endpoint(metadata.userinfo_endpoint.clone())
455    .set_token_endpoint_auth_methods_supported(
456        metadata.token_endpoint_auth_methods_supported.clone(),
457    )
458    .set_userinfo_signing_alg_values_supported(
459        metadata.userinfo_signing_alg_values_supported.clone(),
460    ))
461}
462
463async fn fetch_jwks(
464    jwks_uri: &JsonWebKeySetUrl,
465    http_client: &reqwest::Client,
466) -> OAuthProviderResult<CoreJsonWebKeySet> {
467    JsonWebKeySet::fetch_async(jwks_uri, http_client)
468        .await
469        .map_err(|e| OAuthProviderError::Metadata {
470            message: format!("Failed to fetch JWKS: {e}"),
471        })
472}
473
#[cfg(test)]
mod tests {
    use openidconnect::{
        AuthUrl, IssuerUrl, JsonWebKeySetUrl, ProviderMetadata, ResponseTypes,
        core::{
            CoreJsonWebKeySet, CoreJwsSigningAlgorithm, CoreResponseType, CoreSubjectIdentifierType,
        },
    };

    use super::from_provider_metadata;
    use crate::{ExtraProviderMetadata, OAuthProviderError, ProviderMetadataWithExtra};

    /// A discovery document whose extra `introspection_endpoint` is not a URL must be rejected
    /// with a `Metadata` error that names the offending field.
    #[test]
    fn from_provider_metadata_rejects_invalid_discovery_override_urls() {
        let issuer =
            IssuerUrl::new("https://issuer.example.com".to_string()).expect("issuer should parse");
        let auth_url = AuthUrl::new("https://issuer.example.com/authorize".to_string())
            .expect("auth url should parse");
        let jwks_uri = JsonWebKeySetUrl::new("https://issuer.example.com/jwks".to_string())
            .expect("jwks uri should parse");
        let extra = ExtraProviderMetadata {
            introspection_endpoint: Some("not-a-url".to_string()),
            ..Default::default()
        };

        let metadata: ProviderMetadataWithExtra = ProviderMetadata::new(
            issuer,
            auth_url,
            jwks_uri,
            vec![ResponseTypes::new(vec![CoreResponseType::Code])],
            vec![CoreSubjectIdentifierType::Public],
            vec![CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256],
            extra,
        )
        .set_jwks(CoreJsonWebKeySet::new(vec![]));

        match from_provider_metadata(metadata).expect_err("invalid endpoint should fail") {
            OAuthProviderError::Metadata { message } => {
                assert!(message.contains("Invalid introspection_endpoint"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }
}