Skip to main content

rustauth_oidc/
discovery.rs

1use serde::{Deserialize, Serialize};
2use url::Url;
3
4use crate::options::{OidcConfig, TokenEndpointAuthentication};
5
/// Fields that must be present and non-empty in a valid OIDC discovery
/// document. [`validate_discovery_document`] checks and reports them in this
/// order.
///
/// Matches Better Auth `REQUIRED_DISCOVERY_FIELDS` in `@better-auth/sso`.
pub const REQUIRED_DISCOVERY_FIELDS: &[&str] =
    &["issuer", "authorization_endpoint", "token_endpoint", "jwks_uri"];
15
/// OIDC discovery document (OpenID Provider Metadata) as returned by a
/// provider's discovery endpoint.
///
/// The four required fields default to an empty string when absent. This lets
/// deserialization succeed so that [`validate_discovery_document`] can report
/// every missing field at once. Optional metadata is `None` when absent.
/// Endpoint URLs may be relative until they are resolved against the issuer.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OidcDiscoveryDocument {
    /// Issuer identifier. Compared with the expected issuer, ignoring a single
    /// trailing slash.
    #[serde(default)]
    pub issuer: String,
    /// Authorization endpoint URL (required).
    #[serde(default)]
    pub authorization_endpoint: String,
    /// Token endpoint URL (required).
    #[serde(default)]
    pub token_endpoint: String,
    /// JWKS document URL (required).
    #[serde(default)]
    pub jwks_uri: String,
    /// UserInfo endpoint URL, if advertised.
    pub userinfo_endpoint: Option<String>,
    /// Token revocation endpoint URL, if advertised.
    pub revocation_endpoint: Option<String>,
    /// End-session (logout) endpoint URL, if advertised.
    pub end_session_endpoint: Option<String>,
    /// Token introspection endpoint URL, if advertised.
    pub introspection_endpoint: Option<String>,
    /// Client authentication methods accepted at the token endpoint. Consulted
    /// by [`select_token_endpoint_authentication`].
    pub token_endpoint_auth_methods_supported: Option<Vec<String>>,
    /// Advertised scopes. Copied into the hydrated discovery result.
    pub scopes_supported: Option<Vec<String>>,
    // NOTE(review): the remaining metadata fields are deserialized but not read
    // anywhere in the visible part of this module — confirm other callers use them.
    pub response_types_supported: Option<Vec<String>>,
    pub subject_types_supported: Option<Vec<String>>,
    pub id_token_signing_alg_values_supported: Option<Vec<String>>,
    pub claims_supported: Option<Vec<String>>,
    pub code_challenge_methods_supported: Option<Vec<String>>,
}
38
/// Reports whether an optional endpoint URL is actually configured.
///
/// Both `None` and `Some("")` count as "not configured". This mirrors Better
/// Auth's runtime-discovery check, where `!config.tokenEndpoint` is true for `""`.
pub fn is_configured_oidc_endpoint(endpoint: Option<&str>) -> bool {
    matches!(endpoint, Some(url) if !url.is_empty())
}
46
/// Prefer a configured, non-empty endpoint over the discovered one.
fn merge_required_endpoint(existing: Option<&str>, discovered: String) -> String {
    match existing {
        Some(configured) if !configured.is_empty() => configured.to_owned(),
        _ => discovered,
    }
}
53
/// Prefer a configured, non-empty endpoint. Otherwise keep whatever discovery
/// produced, which may also be `None`.
fn merge_optional_endpoint(existing: Option<&str>, discovered: Option<String>) -> Option<String> {
    existing
        .filter(|configured| !configured.is_empty())
        .map(str::to_owned)
        .or(discovered)
}
60
/// Collapse `Some("")` to `None` and pass everything else through unchanged.
fn non_empty_endpoint(endpoint: Option<&str>) -> Option<&str> {
    match endpoint {
        Some("") | None => None,
        present => present,
    }
}
64
/// Build the well-known OIDC discovery URL for `issuer`.
///
/// All trailing slashes are removed before appending the suffix, so any issuer
/// path is preserved. For example, `https://idp.example.com/tenant/v1/` yields
/// `https://idp.example.com/tenant/v1/.well-known/openid-configuration`.
pub fn compute_discovery_url(issuer: &str) -> String {
    const WELL_KNOWN_SUFFIX: &str = "/.well-known/openid-configuration";
    let base = issuer.trim_end_matches('/');
    let mut url = String::with_capacity(base.len() + WELL_KNOWN_SUFFIX.len());
    url.push_str(base);
    url.push_str(WELL_KNOWN_SUFFIX);
    url
}
71
72pub fn normalize_url(value: &str) -> Result<String, url::ParseError> {
73    Url::parse(value).map(|url| url.to_string())
74}
75
76/// Normalize and validate an absolute HTTP(S) URL.
77///
78/// This is stricter than [`normalize_url`], which is retained for backward
79/// compatibility and only parses the URL.
80pub fn normalize_absolute_http_url(
81    field: &'static str,
82    value: &str,
83) -> Result<String, OidcDiscoveryError> {
84    validate_trusted_url(field, value, &|_| true)?;
85    Url::parse(value)
86        .map(|url| url.to_string())
87        .map_err(|source| OidcDiscoveryError::InvalidUrl {
88            field,
89            reason: source.to_string(),
90        })
91}
92
93/// Normalize an OIDC endpoint URL, resolving relative endpoints against the
94/// issuer origin and path.
95pub fn normalize_endpoint_url(
96    field: &'static str,
97    endpoint: &str,
98    issuer: &str,
99) -> Result<String, OidcDiscoveryError> {
100    normalize_endpoint(field, endpoint, issuer)
101}
102
103pub fn validate_issuer_url(value: &str) -> Result<String, openidconnect::url::ParseError> {
104    openidconnect::IssuerUrl::new(value.to_owned()).map(|issuer| issuer.to_string())
105}
106
/// Result of runtime OIDC discovery: the caller's configured values merged
/// with the normalized endpoints from the provider's discovery document.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HydratedOidcDiscovery {
    /// Issuer identifier: the configured issuer when one is supplied,
    /// otherwise the document's `issuer`.
    pub issuer: String,
    /// Discovery document URL that was fetched.
    pub discovery_endpoint: String,
    /// Authorization endpoint URL.
    pub authorization_endpoint: String,
    /// Token endpoint URL.
    pub token_endpoint: String,
    /// JWKS URL (the document's `jwks_uri`).
    pub jwks_endpoint: String,
    /// UserInfo endpoint URL, if configured or advertised.
    pub user_info_endpoint: Option<String>,
    /// Token revocation endpoint URL, if configured or advertised.
    pub revocation_endpoint: Option<String>,
    /// End-session endpoint URL, if configured or advertised.
    pub end_session_endpoint: Option<String>,
    /// Token introspection endpoint URL, if configured or advertised.
    pub introspection_endpoint: Option<String>,
    /// Token endpoint client authentication: the configured method, or one
    /// chosen by [`select_token_endpoint_authentication`].
    pub token_endpoint_authentication: TokenEndpointAuthentication,
    /// Scopes advertised by the provider (`scopes_supported`), if any.
    pub scopes_supported: Option<Vec<String>>,
}
121
122pub async fn discover_oidc_config(
123    issuer: &str,
124    discovery_endpoint: Option<&str>,
125    existing: PartialOidcDiscoveryConfig<'_>,
126    client: &reqwest::Client,
127) -> Result<HydratedOidcDiscovery, OidcDiscoveryError> {
128    discover_oidc_config_with_origin_validator(
129        issuer,
130        discovery_endpoint,
131        existing,
132        |_| true,
133        client,
134    )
135    .await
136}
137
138pub async fn discover_oidc_config_with_origin_validator<F>(
139    issuer: &str,
140    discovery_endpoint: Option<&str>,
141    existing: PartialOidcDiscoveryConfig<'_>,
142    is_trusted_origin: F,
143    client: &reqwest::Client,
144) -> Result<HydratedOidcDiscovery, OidcDiscoveryError>
145where
146    F: Fn(&str) -> bool,
147{
148    let discovery_endpoint = discovery_endpoint
149        .map(str::to_owned)
150        .or_else(|| existing.discovery_endpoint.map(str::to_owned))
151        .unwrap_or_else(|| compute_discovery_url(issuer));
152    validate_trusted_url(
153        "discovery_endpoint",
154        &discovery_endpoint,
155        &is_trusted_origin,
156    )?;
157    let document = fetch_discovery_document(&discovery_endpoint, client).await?;
158    validate_discovery_document(&document, issuer)?;
159    let normalized = normalize_discovery_document(document, issuer)?;
160    let token_endpoint_authentication =
161        select_token_endpoint_authentication(&normalized, existing.token_endpoint_authentication);
162
163    let hydrated = HydratedOidcDiscovery {
164        issuer: existing
165            .issuer
166            .map(str::to_owned)
167            .unwrap_or(normalized.issuer),
168        discovery_endpoint,
169        authorization_endpoint: merge_required_endpoint(
170            existing.authorization_endpoint,
171            normalized.authorization_endpoint,
172        ),
173        token_endpoint: merge_required_endpoint(existing.token_endpoint, normalized.token_endpoint),
174        jwks_endpoint: merge_required_endpoint(existing.jwks_endpoint, normalized.jwks_uri),
175        user_info_endpoint: merge_optional_endpoint(
176            existing.user_info_endpoint,
177            normalized.userinfo_endpoint,
178        ),
179        revocation_endpoint: merge_optional_endpoint(
180            existing.revocation_endpoint,
181            normalized.revocation_endpoint,
182        ),
183        end_session_endpoint: merge_optional_endpoint(
184            existing.end_session_endpoint,
185            normalized.end_session_endpoint,
186        ),
187        introspection_endpoint: merge_optional_endpoint(
188            existing.introspection_endpoint,
189            normalized.introspection_endpoint,
190        ),
191        token_endpoint_authentication,
192        scopes_supported: normalized.scopes_supported,
193    };
194    validate_trusted_url(
195        "authorization_endpoint",
196        &hydrated.authorization_endpoint,
197        &is_trusted_origin,
198    )?;
199    validate_trusted_url(
200        "token_endpoint",
201        &hydrated.token_endpoint,
202        &is_trusted_origin,
203    )?;
204    validate_trusted_url("jwks_uri", &hydrated.jwks_endpoint, &is_trusted_origin)?;
205    if let Some(user_info_endpoint) = &hydrated.user_info_endpoint {
206        validate_trusted_url("userinfo_endpoint", user_info_endpoint, &is_trusted_origin)?;
207    }
208    if let Some(revocation_endpoint) = &hydrated.revocation_endpoint {
209        validate_trusted_url(
210            "revocation_endpoint",
211            revocation_endpoint,
212            &is_trusted_origin,
213        )?;
214    }
215    if let Some(end_session_endpoint) = &hydrated.end_session_endpoint {
216        validate_trusted_url(
217            "end_session_endpoint",
218            end_session_endpoint,
219            &is_trusted_origin,
220        )?;
221    }
222    if let Some(introspection_endpoint) = &hydrated.introspection_endpoint {
223        validate_trusted_url(
224            "introspection_endpoint",
225            introspection_endpoint,
226            &is_trusted_origin,
227        )?;
228    }
229    Ok(hydrated)
230}
231
/// Read-only access to a provider's configured OIDC endpoint URLs.
///
/// Used by [`validate_configured_oidc_endpoint_origins`] to check every
/// configured URL against an origin policy. The optional endpoints return
/// `None` when they are not configured.
pub trait OidcEndpointConfig {
    /// Discovery document URL (always present).
    fn discovery_endpoint(&self) -> &str;
    /// Authorization endpoint URL, if configured.
    fn authorization_endpoint(&self) -> Option<&str>;
    /// Token endpoint URL, if configured.
    fn token_endpoint(&self) -> Option<&str>;
    /// UserInfo endpoint URL, if configured.
    fn user_info_endpoint(&self) -> Option<&str>;
    /// JWKS URL, if configured.
    fn jwks_endpoint(&self) -> Option<&str>;
    /// Token revocation endpoint URL, if configured.
    fn revocation_endpoint(&self) -> Option<&str>;
    /// End-session endpoint URL, if configured.
    fn end_session_endpoint(&self) -> Option<&str>;
    /// Token introspection endpoint URL, if configured.
    fn introspection_endpoint(&self) -> Option<&str>;
}
242
243impl OidcEndpointConfig for OidcConfig {
244    fn discovery_endpoint(&self) -> &str {
245        &self.discovery_endpoint
246    }
247
248    fn authorization_endpoint(&self) -> Option<&str> {
249        self.authorization_endpoint.as_deref()
250    }
251
252    fn token_endpoint(&self) -> Option<&str> {
253        self.token_endpoint.as_deref()
254    }
255
256    fn user_info_endpoint(&self) -> Option<&str> {
257        self.user_info_endpoint.as_deref()
258    }
259
260    fn jwks_endpoint(&self) -> Option<&str> {
261        self.jwks_endpoint.as_deref()
262    }
263
264    fn revocation_endpoint(&self) -> Option<&str> {
265        self.revocation_endpoint.as_deref()
266    }
267
268    fn end_session_endpoint(&self) -> Option<&str> {
269        self.end_session_endpoint.as_deref()
270    }
271
272    fn introspection_endpoint(&self) -> Option<&str> {
273        self.introspection_endpoint.as_deref()
274    }
275}
276
277pub fn validate_configured_oidc_endpoint_origins<C, F>(
278    config: &C,
279    is_trusted_origin: F,
280) -> Result<(), OidcDiscoveryError>
281where
282    C: OidcEndpointConfig + ?Sized,
283    F: Fn(&str) -> bool,
284{
285    validate_trusted_url(
286        "discovery_endpoint",
287        config.discovery_endpoint(),
288        &is_trusted_origin,
289    )?;
290    if let Some(endpoint) = config.authorization_endpoint() {
291        validate_trusted_url("authorization_endpoint", endpoint, &is_trusted_origin)?;
292    }
293    if let Some(endpoint) = config.token_endpoint() {
294        validate_trusted_url("token_endpoint", endpoint, &is_trusted_origin)?;
295    }
296    if let Some(endpoint) = config.user_info_endpoint() {
297        validate_trusted_url("userinfo_endpoint", endpoint, &is_trusted_origin)?;
298    }
299    if let Some(endpoint) = config.jwks_endpoint() {
300        validate_trusted_url("jwks_uri", endpoint, &is_trusted_origin)?;
301    }
302    if let Some(endpoint) = config.revocation_endpoint() {
303        validate_trusted_url("revocation_endpoint", endpoint, &is_trusted_origin)?;
304    }
305    if let Some(endpoint) = config.end_session_endpoint() {
306        validate_trusted_url("end_session_endpoint", endpoint, &is_trusted_origin)?;
307    }
308    if let Some(endpoint) = config.introspection_endpoint() {
309        validate_trusted_url("introspection_endpoint", endpoint, &is_trusted_origin)?;
310    }
311    Ok(())
312}
313
/// Values a caller already has configured, borrowed for the duration of a
/// discovery call.
///
/// `None` means "not configured". Configured values take precedence over the
/// ones read from the discovery document when
/// [`discover_oidc_config_with_origin_validator`] builds a
/// [`HydratedOidcDiscovery`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct PartialOidcDiscoveryConfig<'a> {
    /// Configured issuer identifier.
    pub issuer: Option<&'a str>,
    /// Configured discovery URL. Used when no explicit discovery endpoint is
    /// passed to the discovery function.
    pub discovery_endpoint: Option<&'a str>,
    /// Configured authorization endpoint URL.
    pub authorization_endpoint: Option<&'a str>,
    /// Configured token endpoint URL.
    pub token_endpoint: Option<&'a str>,
    /// Configured UserInfo endpoint URL.
    pub user_info_endpoint: Option<&'a str>,
    /// Configured JWKS URL.
    pub jwks_endpoint: Option<&'a str>,
    /// Configured token revocation endpoint URL.
    pub revocation_endpoint: Option<&'a str>,
    /// Configured end-session endpoint URL.
    pub end_session_endpoint: Option<&'a str>,
    /// Configured token introspection endpoint URL.
    pub introspection_endpoint: Option<&'a str>,
    /// Configured token endpoint auth method. When set, it is used as-is
    /// instead of being selected from discovery metadata.
    pub token_endpoint_authentication: Option<TokenEndpointAuthentication>,
}
327
/// The runtime flow that needs a usable OIDC provider configuration.
///
/// Both variants currently require the same endpoints; see
/// [`OidcRuntimeRequirement::is_satisfied`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OidcRuntimeRequirement {
    /// Starting an OIDC sign-in.
    SignIn,
    /// Handling the OIDC callback.
    Callback,
}
333
334impl OidcRuntimeRequirement {
335    pub fn is_satisfied(self, config: &OidcConfig) -> bool {
336        // Better Auth performs runtime discovery unless the provider has the
337        // complete OIDC endpoint set needed across sign-in and callback.
338        // Preserve the enum for API clarity while keeping both modes aligned
339        // with that upstream contract.
340        let _ = self;
341        is_configured_oidc_endpoint(config.authorization_endpoint.as_deref())
342            && is_configured_oidc_endpoint(config.token_endpoint.as_deref())
343            && is_configured_oidc_endpoint(config.jwks_endpoint.as_deref())
344    }
345}
346
347pub fn needs_runtime_discovery(config: &OidcConfig, requirement: OidcRuntimeRequirement) -> bool {
348    !requirement.is_satisfied(config)
349}
350
351pub async fn ensure_runtime_oidc_config_with_origin_validator<F>(
352    issuer: &str,
353    config: OidcConfig,
354    requirement: OidcRuntimeRequirement,
355    is_trusted_origin: F,
356    validate_configured_origins: bool,
357    client: &reqwest::Client,
358) -> Result<OidcConfig, OidcDiscoveryError>
359where
360    F: Fn(&str) -> bool,
361{
362    if !needs_runtime_discovery(&config, requirement) {
363        if validate_configured_origins {
364            validate_configured_oidc_endpoint_origins(&config, &is_trusted_origin)?;
365        }
366        return Ok(config);
367    }
368
369    let hydrated = discover_oidc_config_with_origin_validator(
370        issuer,
371        (!config.discovery_endpoint.is_empty()).then_some(config.discovery_endpoint.as_str()),
372        PartialOidcDiscoveryConfig {
373            issuer: Some(config.issuer.as_str()),
374            discovery_endpoint: (!config.discovery_endpoint.is_empty())
375                .then_some(config.discovery_endpoint.as_str()),
376            authorization_endpoint: non_empty_endpoint(config.authorization_endpoint.as_deref()),
377            token_endpoint: non_empty_endpoint(config.token_endpoint.as_deref()),
378            user_info_endpoint: non_empty_endpoint(config.user_info_endpoint.as_deref()),
379            jwks_endpoint: non_empty_endpoint(config.jwks_endpoint.as_deref()),
380            revocation_endpoint: non_empty_endpoint(config.revocation_endpoint.as_deref()),
381            end_session_endpoint: non_empty_endpoint(config.end_session_endpoint.as_deref()),
382            introspection_endpoint: non_empty_endpoint(config.introspection_endpoint.as_deref()),
383            token_endpoint_authentication: config.token_endpoint_authentication,
384        },
385        &is_trusted_origin,
386        client,
387    )
388    .await?;
389
390    let hydrated_config = OidcConfig {
391        issuer: hydrated.issuer,
392        pkce: config.pkce,
393        client_id: config.client_id,
394        client_secret: config.client_secret,
395        discovery_endpoint: hydrated.discovery_endpoint,
396        authorization_endpoint: Some(hydrated.authorization_endpoint),
397        token_endpoint: Some(hydrated.token_endpoint),
398        user_info_endpoint: hydrated.user_info_endpoint,
399        jwks_endpoint: Some(hydrated.jwks_endpoint),
400        revocation_endpoint: hydrated.revocation_endpoint,
401        end_session_endpoint: hydrated.end_session_endpoint,
402        introspection_endpoint: hydrated.introspection_endpoint,
403        token_endpoint_authentication: Some(hydrated.token_endpoint_authentication),
404        scopes: config.scopes,
405        mapping: config.mapping,
406        override_user_info: config.override_user_info,
407    };
408
409    if validate_configured_origins {
410        validate_configured_oidc_endpoint_origins(&hydrated_config, &is_trusted_origin)?;
411    }
412    Ok(hydrated_config)
413}
414
/// Errors produced while fetching, validating, or normalizing OIDC discovery
/// metadata.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum OidcDiscoveryError {
    /// HTTP/transport failure not covered by a more specific variant. Holds
    /// the `reqwest` error message.
    #[error("OIDC discovery request failed: {0}")]
    Request(String),
    /// The discovery endpoint answered HTTP 404.
    #[error("OIDC discovery endpoint not found")]
    NotFound,
    /// The request timed out client-side, or the endpoint answered HTTP 408.
    #[error("OIDC discovery request timed out")]
    Timeout,
    /// The response body could not be deserialized as a discovery document.
    #[error("OIDC discovery endpoint returned invalid JSON: {0}")]
    InvalidJson(String),
    /// The caller-supplied trust predicate rejected the URL for `field`.
    #[error("OIDC discovery document contains untrusted URL for `{field}`: {url}")]
    UntrustedOrigin { field: &'static str, url: String },
    /// Exactly one required field is missing or empty.
    #[error("OIDC discovery document is missing required field `{0}`")]
    MissingField(&'static str),
    /// Several required fields are missing or empty.
    #[error("OIDC discovery document is missing required fields: {0:?}")]
    MissingFields(Vec<&'static str>),
    /// The document's issuer differs from the expected issuer, ignoring a
    /// single trailing slash.
    #[error("OIDC discovery issuer mismatch")]
    IssuerMismatch,
    /// A URL could not be parsed or resolved, or uses a scheme other than
    /// `http`/`https`.
    #[error("OIDC discovery document contains invalid URL for `{field}`: {reason}")]
    InvalidUrl { field: &'static str, reason: String },
}
436
437impl OidcDiscoveryError {
438    pub fn code(&self) -> &'static str {
439        match self {
440            Self::Timeout => "discovery_timeout",
441            Self::NotFound => "discovery_not_found",
442            Self::InvalidJson(_) => "discovery_invalid_json",
443            Self::InvalidUrl { .. } => "discovery_invalid_url",
444            Self::UntrustedOrigin { .. } => "discovery_untrusted_origin",
445            Self::IssuerMismatch => "issuer_mismatch",
446            Self::MissingField(_) | Self::MissingFields(_) => "discovery_incomplete",
447            Self::Request(_) => "discovery_unexpected_error",
448        }
449    }
450
451    pub fn status(&self) -> http::StatusCode {
452        match self {
453            Self::Timeout | Self::Request(_) => http::StatusCode::BAD_GATEWAY,
454            Self::NotFound
455            | Self::InvalidJson(_)
456            | Self::InvalidUrl { .. }
457            | Self::UntrustedOrigin { .. }
458            | Self::IssuerMismatch
459            | Self::MissingField(_)
460            | Self::MissingFields(_) => http::StatusCode::BAD_REQUEST,
461        }
462    }
463}
464
465/// Validate a discovery URL before fetching.
466pub fn validate_discovery_url<F>(url: &str, is_trusted_origin: F) -> Result<(), OidcDiscoveryError>
467where
468    F: Fn(&str) -> bool,
469{
470    validate_trusted_url("discovery_endpoint", url, &is_trusted_origin)
471}
472
473/// Fetch the OIDC discovery document from the IdP.
474pub async fn fetch_discovery_document(
475    discovery_endpoint: &str,
476    client: &reqwest::Client,
477) -> Result<OidcDiscoveryDocument, OidcDiscoveryError> {
478    let response = client
479        .get(discovery_endpoint)
480        .header("accept", "application/json")
481        .timeout(std::time::Duration::from_secs(10))
482        .send()
483        .await
484        .map_err(classify_reqwest_error)?;
485    let status = response.status();
486    if status == http::StatusCode::NOT_FOUND {
487        return Err(OidcDiscoveryError::NotFound);
488    }
489    if status == http::StatusCode::REQUEST_TIMEOUT {
490        return Err(OidcDiscoveryError::Timeout);
491    }
492    let response = response
493        .error_for_status()
494        .map_err(classify_reqwest_error)?;
495    response
496        .json::<OidcDiscoveryDocument>()
497        .await
498        .map_err(|error| OidcDiscoveryError::InvalidJson(error.to_string()))
499}
500
501fn classify_reqwest_error(error: reqwest::Error) -> OidcDiscoveryError {
502    if error.is_timeout() {
503        return OidcDiscoveryError::Timeout;
504    }
505    if error.status() == Some(http::StatusCode::NOT_FOUND) {
506        return OidcDiscoveryError::NotFound;
507    }
508    OidcDiscoveryError::Request(error.to_string())
509}
510
511/// Validate a discovery document for required fields and issuer match.
512pub fn validate_discovery_document(
513    document: &OidcDiscoveryDocument,
514    issuer: &str,
515) -> Result<(), OidcDiscoveryError> {
516    let mut missing = Vec::new();
517    for field in REQUIRED_DISCOVERY_FIELDS {
518        let is_empty = match *field {
519            "issuer" => document.issuer.is_empty(),
520            "authorization_endpoint" => document.authorization_endpoint.is_empty(),
521            "token_endpoint" => document.token_endpoint.is_empty(),
522            "jwks_uri" => document.jwks_uri.is_empty(),
523            _ => false,
524        };
525        if is_empty {
526            missing.push(*field);
527        }
528    }
529    if !missing.is_empty() {
530        return Err(if missing.len() == 1 {
531            OidcDiscoveryError::MissingField(missing[0])
532        } else {
533            OidcDiscoveryError::MissingFields(missing)
534        });
535    }
536    if trim_trailing_slash(&document.issuer) != trim_trailing_slash(issuer) {
537        return Err(OidcDiscoveryError::IssuerMismatch);
538    }
539    Ok(())
540}
541
542/// Normalize discovery document URLs and validate each endpoint origin.
543pub fn normalize_discovery_urls<F>(
544    document: OidcDiscoveryDocument,
545    issuer: &str,
546    is_trusted_origin: F,
547) -> Result<OidcDiscoveryDocument, OidcDiscoveryError>
548where
549    F: Fn(&str) -> bool,
550{
551    let normalized = normalize_discovery_document(document, issuer)?;
552    validate_trusted_url(
553        "authorization_endpoint",
554        &normalized.authorization_endpoint,
555        &is_trusted_origin,
556    )?;
557    validate_trusted_url(
558        "token_endpoint",
559        &normalized.token_endpoint,
560        &is_trusted_origin,
561    )?;
562    validate_trusted_url("jwks_uri", &normalized.jwks_uri, &is_trusted_origin)?;
563    if let Some(userinfo_endpoint) = &normalized.userinfo_endpoint {
564        validate_trusted_url("userinfo_endpoint", userinfo_endpoint, &is_trusted_origin)?;
565    }
566    if let Some(revocation_endpoint) = &normalized.revocation_endpoint {
567        validate_trusted_url(
568            "revocation_endpoint",
569            revocation_endpoint,
570            &is_trusted_origin,
571        )?;
572    }
573    if let Some(end_session_endpoint) = &normalized.end_session_endpoint {
574        validate_trusted_url(
575            "end_session_endpoint",
576            end_session_endpoint,
577            &is_trusted_origin,
578        )?;
579    }
580    if let Some(introspection_endpoint) = &normalized.introspection_endpoint {
581        validate_trusted_url(
582            "introspection_endpoint",
583            introspection_endpoint,
584            &is_trusted_origin,
585        )?;
586    }
587    Ok(normalized)
588}
589
590fn normalize_discovery_document(
591    mut document: OidcDiscoveryDocument,
592    issuer: &str,
593) -> Result<OidcDiscoveryDocument, OidcDiscoveryError> {
594    document.authorization_endpoint = normalize_endpoint(
595        "authorization_endpoint",
596        &document.authorization_endpoint,
597        issuer,
598    )?;
599    document.token_endpoint =
600        normalize_endpoint("token_endpoint", &document.token_endpoint, issuer)?;
601    document.jwks_uri = normalize_endpoint("jwks_uri", &document.jwks_uri, issuer)?;
602    document.userinfo_endpoint = document
603        .userinfo_endpoint
604        .as_deref()
605        .map(|endpoint| normalize_endpoint("userinfo_endpoint", endpoint, issuer))
606        .transpose()?;
607    document.revocation_endpoint = document
608        .revocation_endpoint
609        .as_deref()
610        .map(|endpoint| normalize_endpoint("revocation_endpoint", endpoint, issuer))
611        .transpose()?;
612    document.end_session_endpoint = document
613        .end_session_endpoint
614        .as_deref()
615        .map(|endpoint| normalize_endpoint("end_session_endpoint", endpoint, issuer))
616        .transpose()?;
617    document.introspection_endpoint = document
618        .introspection_endpoint
619        .as_deref()
620        .map(|endpoint| normalize_endpoint("introspection_endpoint", endpoint, issuer))
621        .transpose()?;
622    Ok(document)
623}
624
625fn normalize_endpoint(
626    field: &'static str,
627    endpoint: &str,
628    issuer: &str,
629) -> Result<String, OidcDiscoveryError> {
630    if let Ok(url) = Url::parse(endpoint) {
631        ensure_supported_url_scheme(field, &url)?;
632        return Ok(url.to_string());
633    }
634
635    let issuer_url = Url::parse(issuer).map_err(|source| OidcDiscoveryError::InvalidUrl {
636        field,
637        reason: source.to_string(),
638    })?;
639    let origin = issuer_url.origin().ascii_serialization();
640    let base_path = issuer_url.path().trim_end_matches('/');
641    let endpoint_path = endpoint.trim_start_matches('/');
642    let url = Url::parse(&format!("{origin}{base_path}/{endpoint_path}")).map_err(|source| {
643        OidcDiscoveryError::InvalidUrl {
644            field,
645            reason: source.to_string(),
646        }
647    })?;
648    ensure_supported_url_scheme(field, &url)?;
649    Ok(url.to_string())
650}
651
652fn validate_trusted_url<F>(
653    field: &'static str,
654    value: &str,
655    is_trusted_origin: &F,
656) -> Result<(), OidcDiscoveryError>
657where
658    F: Fn(&str) -> bool,
659{
660    let url = Url::parse(value).map_err(|source| OidcDiscoveryError::InvalidUrl {
661        field,
662        reason: source.to_string(),
663    })?;
664    ensure_supported_url_scheme(field, &url)?;
665    if !is_trusted_origin(value) {
666        return Err(OidcDiscoveryError::UntrustedOrigin {
667            field,
668            url: value.to_owned(),
669        });
670    }
671    Ok(())
672}
673
674fn ensure_supported_url_scheme(field: &'static str, url: &Url) -> Result<(), OidcDiscoveryError> {
675    if matches!(url.scheme(), "http" | "https") {
676        return Ok(());
677    }
678    Err(OidcDiscoveryError::InvalidUrl {
679        field,
680        reason: format!("unsupported URL scheme `{}`", url.scheme()),
681    })
682}
683
684/// Select the token endpoint authentication method from discovery metadata.
685pub fn select_token_endpoint_authentication(
686    document: &OidcDiscoveryDocument,
687    existing: Option<TokenEndpointAuthentication>,
688) -> TokenEndpointAuthentication {
689    if let Some(existing) = existing {
690        return existing;
691    }
692    let Some(supported) = &document.token_endpoint_auth_methods_supported else {
693        return TokenEndpointAuthentication::ClientSecretBasic;
694    };
695    if supported
696        .iter()
697        .any(|method| method == "client_secret_basic")
698    {
699        return TokenEndpointAuthentication::ClientSecretBasic;
700    }
701    if supported
702        .iter()
703        .any(|method| method == "client_secret_post")
704    {
705        return TokenEndpointAuthentication::ClientSecretPost;
706    }
707    TokenEndpointAuthentication::ClientSecretBasic
708}
709
/// Drop at most one trailing `/`: `"a//"` becomes `"a/"`.
fn trim_trailing_slash(value: &str) -> &str {
    let mut chars = value.chars();
    if chars.next_back() == Some('/') {
        chars.as_str()
    } else {
        value
    }
}
713
714#[cfg(test)]
715mod tests {
716    use super::*;
717
718    fn discovery_document(issuer: &str) -> OidcDiscoveryDocument {
719        OidcDiscoveryDocument {
720            issuer: issuer.to_owned(),
721            authorization_endpoint: format!("{issuer}/authorize"),
722            token_endpoint: format!("{issuer}/token"),
723            jwks_uri: format!("{issuer}/keys"),
724            userinfo_endpoint: Some(format!("{issuer}/userinfo")),
725            revocation_endpoint: None,
726            end_session_endpoint: None,
727            introspection_endpoint: None,
728            token_endpoint_auth_methods_supported: None,
729            scopes_supported: None,
730            response_types_supported: None,
731            subject_types_supported: None,
732            id_token_signing_alg_values_supported: None,
733            claims_supported: None,
734            code_challenge_methods_supported: None,
735        }
736    }
737
738    #[test]
739    fn normalizes_relative_discovery_endpoints_against_issuer_path(
740    ) -> Result<(), OidcDiscoveryError> {
741        assert_eq!(
742            normalize_endpoint(
743                "token_endpoint",
744                "oauth/token",
745                "https://idp.example.com/tenant"
746            )?,
747            "https://idp.example.com/tenant/oauth/token"
748        );
749        assert_eq!(
750            normalize_endpoint("jwks_uri", "/keys", "https://idp.example.com/tenant")?,
751            "https://idp.example.com/tenant/keys"
752        );
753        let document = normalize_discovery_document(
754            OidcDiscoveryDocument {
755                issuer: "https://idp.example.com/tenant".to_owned(),
756                authorization_endpoint: "authorize".to_owned(),
757                token_endpoint: "token".to_owned(),
758                jwks_uri: "keys".to_owned(),
759                userinfo_endpoint: Some("userinfo".to_owned()),
760                revocation_endpoint: Some("revoke".to_owned()),
761                end_session_endpoint: Some("endsession".to_owned()),
762                introspection_endpoint: Some("introspect".to_owned()),
763                token_endpoint_auth_methods_supported: None,
764                scopes_supported: None,
765                response_types_supported: None,
766                subject_types_supported: None,
767                id_token_signing_alg_values_supported: None,
768                claims_supported: None,
769                code_challenge_methods_supported: None,
770            },
771            "https://idp.example.com/tenant",
772        )?;
773        assert_eq!(
774            document.revocation_endpoint.as_deref(),
775            Some("https://idp.example.com/tenant/revoke")
776        );
777        assert_eq!(
778            document.end_session_endpoint.as_deref(),
779            Some("https://idp.example.com/tenant/endsession")
780        );
781        assert_eq!(
782            document.introspection_endpoint.as_deref(),
783            Some("https://idp.example.com/tenant/introspect")
784        );
785        Ok(())
786    }
787
788    #[test]
789    fn discovery_url_preserves_issuer_path() {
790        assert_eq!(
791            compute_discovery_url("https://idp.example.com/tenant/v1/"),
792            "https://idp.example.com/tenant/v1/.well-known/openid-configuration"
793        );
794    }
795
796    #[test]
797    fn absolute_http_url_api_rejects_relative_and_non_http_values() -> Result<(), OidcDiscoveryError>
798    {
799        assert!(normalize_absolute_http_url("discovery_endpoint", "/relative").is_err());
800        assert!(
801            normalize_absolute_http_url("discovery_endpoint", "ftp://idp.example.com").is_err()
802        );
803        assert_eq!(
804            normalize_absolute_http_url("discovery_endpoint", "https://idp.example.com")?,
805            "https://idp.example.com/"
806        );
807        Ok::<(), OidcDiscoveryError>(())
808    }
809
810    #[test]
811    fn normalize_endpoint_resolves_relative_urls_with_duplicate_slashes(
812    ) -> Result<(), OidcDiscoveryError> {
813        assert_eq!(
814            normalize_endpoint(
815                "token_endpoint",
816                "//oauth2/token",
817                "https://idp.example.com/base//",
818            )?,
819            "https://idp.example.com/base/oauth2/token"
820        );
821        assert_eq!(
822            normalize_endpoint(
823                "token_endpoint",
824                "oauth2/token",
825                "https://idp.example.com/base/"
826            )?,
827            "https://idp.example.com/base/oauth2/token"
828        );
829        Ok(())
830    }
831
832    #[test]
833    fn endpoint_url_api_resolves_relative_values_against_issuer_path(
834    ) -> Result<(), OidcDiscoveryError> {
835        assert_eq!(
836            normalize_endpoint_url(
837                "authorization_endpoint",
838                "/oauth2/authorize",
839                "https://idp.example.com/tenant/",
840            )?,
841            "https://idp.example.com/tenant/oauth2/authorize"
842        );
843        assert!(normalize_endpoint_url(
844            "authorization_endpoint",
845            "ftp://idp.example.com/authorize",
846            "https://idp.example.com/tenant/",
847        )
848        .is_err());
849        Ok::<(), OidcDiscoveryError>(())
850    }
851
852    #[test]
853    fn is_configured_oidc_endpoint_treats_empty_string_as_missing() {
854        assert!(!is_configured_oidc_endpoint(None));
855        assert!(!is_configured_oidc_endpoint(Some("")));
856        assert!(is_configured_oidc_endpoint(Some(
857            "https://idp.example.com/oauth2/v1/authorize"
858        )));
859    }
860
861    #[test]
862    fn runtime_discovery_treats_empty_string_endpoints_as_missing() {
863        let config = OidcConfig {
864            issuer: "https://idp.example.com".to_owned(),
865            pkce: true,
866            client_id: "client".to_owned(),
867            client_secret: "secret".into(),
868            discovery_endpoint: compute_discovery_url("https://idp.example.com"),
869            authorization_endpoint: Some(String::new()),
870            token_endpoint: Some("https://idp.example.com/token".to_owned()),
871            user_info_endpoint: None,
872            jwks_endpoint: Some("https://idp.example.com/keys".to_owned()),
873            revocation_endpoint: None,
874            end_session_endpoint: None,
875            introspection_endpoint: None,
876            token_endpoint_authentication: None,
877            scopes: None,
878            mapping: None,
879            override_user_info: false,
880        };
881
882        assert!(needs_runtime_discovery(
883            &config,
884            OidcRuntimeRequirement::SignIn
885        ));
886        assert!(needs_runtime_discovery(
887            &config,
888            OidcRuntimeRequirement::Callback
889        ));
890        assert!(!is_configured_oidc_endpoint(
891            config.authorization_endpoint.as_deref()
892        ));
893    }
894
895    #[test]
896    fn runtime_discovery_requirements_match_sign_in_and_callback_needs() {
897        let mut config = OidcConfig {
898            issuer: "https://idp.example.com".to_owned(),
899            pkce: true,
900            client_id: "client".to_owned(),
901            client_secret: "secret".into(),
902            discovery_endpoint: compute_discovery_url("https://idp.example.com"),
903            authorization_endpoint: None,
904            token_endpoint: Some("https://idp.example.com/token".to_owned()),
905            user_info_endpoint: Some("https://idp.example.com/userinfo".to_owned()),
906            jwks_endpoint: None,
907            revocation_endpoint: None,
908            end_session_endpoint: None,
909            introspection_endpoint: None,
910            token_endpoint_authentication: None,
911            scopes: None,
912            mapping: None,
913            override_user_info: false,
914        };
915
916        assert!(needs_runtime_discovery(
917            &config,
918            OidcRuntimeRequirement::SignIn
919        ));
920        assert!(needs_runtime_discovery(
921            &config,
922            OidcRuntimeRequirement::Callback
923        ));
924
925        config.authorization_endpoint = Some("https://idp.example.com/authorize".to_owned());
926        assert!(needs_runtime_discovery(
927            &config,
928            OidcRuntimeRequirement::SignIn
929        ));
930        assert!(needs_runtime_discovery(
931            &config,
932            OidcRuntimeRequirement::Callback
933        ));
934
935        config.user_info_endpoint = None;
936        config.jwks_endpoint = Some("https://idp.example.com/keys".to_owned());
937        assert!(!needs_runtime_discovery(
938            &config,
939            OidcRuntimeRequirement::SignIn
940        ));
941        assert!(!needs_runtime_discovery(
942            &config,
943            OidcRuntimeRequirement::Callback
944        ));
945    }
946
947    #[test]
948    fn discovery_errors_expose_stable_codes_and_statuses() {
949        assert_eq!(
950            OidcDiscoveryError::MissingField("issuer").code(),
951            "discovery_incomplete"
952        );
953        assert_eq!(
954            OidcDiscoveryError::MissingFields(vec!["issuer", "jwks_uri"]).code(),
955            "discovery_incomplete"
956        );
957        assert_eq!(OidcDiscoveryError::IssuerMismatch.code(), "issuer_mismatch");
958        assert_eq!(
959            OidcDiscoveryError::InvalidUrl {
960                field: "authorization_endpoint",
961                reason: "bad URL".to_owned(),
962            }
963            .code(),
964            "discovery_invalid_url"
965        );
966        assert_eq!(
967            OidcDiscoveryError::Timeout.status(),
968            http::StatusCode::BAD_GATEWAY
969        );
970        assert_eq!(
971            OidcDiscoveryError::InvalidJson("bad".to_owned()).status(),
972            http::StatusCode::BAD_REQUEST
973        );
974    }
975
976    #[test]
977    fn discovery_validation_reports_all_missing_required_fields(
978    ) -> Result<(), Box<dyn std::error::Error>> {
979        let document: OidcDiscoveryDocument = serde_json::from_str(
980            r#"{
981                "issuer":"https://idp.example.com"
982            }"#,
983        )?;
984
985        let error = match validate_discovery_document(&document, "https://idp.example.com") {
986            Ok(()) => return Err("expected incomplete discovery document".into()),
987            Err(error) => error,
988        };
989
990        assert_eq!(error.code(), "discovery_incomplete");
991        assert!(matches!(
992            error,
993            OidcDiscoveryError::MissingFields(fields)
994                if fields == vec!["authorization_endpoint", "token_endpoint", "jwks_uri"]
995        ));
996        Ok(())
997    }
998
999    #[test]
1000    fn discovery_validation_reports_each_missing_required_field() {
1001        for (field, document) in [
1002            (
1003                "issuer",
1004                OidcDiscoveryDocument {
1005                    issuer: String::new(),
1006                    ..discovery_document("https://idp.example.com")
1007                },
1008            ),
1009            (
1010                "authorization_endpoint",
1011                OidcDiscoveryDocument {
1012                    authorization_endpoint: String::new(),
1013                    ..discovery_document("https://idp.example.com")
1014                },
1015            ),
1016            (
1017                "token_endpoint",
1018                OidcDiscoveryDocument {
1019                    token_endpoint: String::new(),
1020                    ..discovery_document("https://idp.example.com")
1021                },
1022            ),
1023            (
1024                "jwks_uri",
1025                OidcDiscoveryDocument {
1026                    jwks_uri: String::new(),
1027                    ..discovery_document("https://idp.example.com")
1028                },
1029            ),
1030        ] {
1031            assert!(matches!(
1032                validate_discovery_document(&document, "https://idp.example.com"),
1033                Err(OidcDiscoveryError::MissingField(missing)) if missing == field
1034            ));
1035        }
1036    }
1037
1038    #[test]
1039    fn discovery_validation_normalizes_issuer_trailing_slash() {
1040        let document = discovery_document("https://idp.example.com/");
1041        assert!(validate_discovery_document(&document, "https://idp.example.com").is_ok());
1042        let document = discovery_document("https://idp.example.com");
1043        assert!(validate_discovery_document(&document, "https://idp.example.com/").is_ok());
1044    }
1045
1046    #[test]
1047    fn discovery_validation_rejects_issuer_mismatch() {
1048        let document = discovery_document("https://evil.example.com");
1049        assert!(matches!(
1050            validate_discovery_document(&document, "https://idp.example.com"),
1051            Err(OidcDiscoveryError::IssuerMismatch)
1052        ));
1053    }
1054
1055    #[test]
1056    fn required_discovery_fields_match_upstream_contract() {
1057        assert_eq!(
1058            REQUIRED_DISCOVERY_FIELDS,
1059            &[
1060                "issuer",
1061                "authorization_endpoint",
1062                "token_endpoint",
1063                "jwks_uri",
1064            ]
1065        );
1066    }
1067
1068    #[test]
1069    fn validate_discovery_url_rejects_invalid_and_untrusted_urls() {
1070        assert!(matches!(
1071            validate_discovery_url("not-a-url", |_| true),
1072            Err(OidcDiscoveryError::InvalidUrl { .. })
1073        ));
1074        assert!(matches!(
1075            validate_discovery_url("ftp://idp.example.com/config", |_| true),
1076            Err(OidcDiscoveryError::InvalidUrl { .. })
1077        ));
1078        assert!(matches!(
1079            validate_discovery_url(
1080                "https://untrusted.example.com/.well-known/openid-configuration",
1081                |_| false
1082            ),
1083            Err(OidcDiscoveryError::UntrustedOrigin { .. })
1084        ));
1085        assert!(validate_discovery_url(
1086            "https://idp.example.com/.well-known/openid-configuration",
1087            |_| true
1088        )
1089        .is_ok());
1090    }
1091
1092    #[test]
1093    fn normalize_discovery_urls_rejects_untrusted_required_endpoints(
1094    ) -> Result<(), Box<dyn std::error::Error>> {
1095        let document = OidcDiscoveryDocument {
1096            issuer: "https://idp.example.com".to_owned(),
1097            authorization_endpoint: "/oauth2/authorize".to_owned(),
1098            token_endpoint: "/oauth2/token".to_owned(),
1099            jwks_uri: "/.well-known/jwks.json".to_owned(),
1100            userinfo_endpoint: Some("/userinfo".to_owned()),
1101            revocation_endpoint: Some("/revoke".to_owned()),
1102            end_session_endpoint: Some("/endsession".to_owned()),
1103            introspection_endpoint: Some("/introspection".to_owned()),
1104            token_endpoint_auth_methods_supported: None,
1105            scopes_supported: None,
1106            response_types_supported: None,
1107            subject_types_supported: None,
1108            id_token_signing_alg_values_supported: None,
1109            claims_supported: None,
1110            code_challenge_methods_supported: None,
1111        };
1112
1113        for (suffix, field_hint) in [
1114            ("/oauth2/token", "token_endpoint"),
1115            ("/oauth2/authorize", "authorization_endpoint"),
1116            ("/.well-known/jwks.json", "jwks_uri"),
1117            ("/userinfo", "userinfo_endpoint"),
1118            ("/revoke", "revocation_endpoint"),
1119            ("/endsession", "end_session_endpoint"),
1120            ("/introspection", "introspection_endpoint"),
1121        ] {
1122            let error =
1123                match normalize_discovery_urls(document.clone(), "https://idp.example.com", |url| {
1124                    !url.ends_with(suffix)
1125                }) {
1126                    Ok(_) => return Err(format!("expected untrusted {field_hint}").into()),
1127                    Err(error) => error,
1128                };
1129            assert_eq!(error.code(), "discovery_untrusted_origin");
1130            assert!(error.to_string().contains(field_hint));
1131        }
1132        Ok(())
1133    }
1134
1135    #[test]
1136    fn token_endpoint_authentication_prefers_existing_config_value() {
1137        let document = discovery_document("https://idp.example.com");
1138        assert_eq!(
1139            select_token_endpoint_authentication(
1140                &document,
1141                Some(TokenEndpointAuthentication::ClientSecretPost)
1142            ),
1143            TokenEndpointAuthentication::ClientSecretPost
1144        );
1145    }
1146
1147    #[test]
1148    fn token_endpoint_authentication_prefers_client_secret_basic_when_both_supported() {
1149        let mut document = discovery_document("https://idp.example.com");
1150        document.token_endpoint_auth_methods_supported = Some(vec![
1151            "client_secret_post".to_owned(),
1152            "client_secret_basic".to_owned(),
1153        ]);
1154        assert_eq!(
1155            select_token_endpoint_authentication(&document, None),
1156            TokenEndpointAuthentication::ClientSecretBasic
1157        );
1158    }
1159
1160    #[test]
1161    fn token_endpoint_authentication_selects_client_secret_post_when_only_supported() {
1162        let mut document = discovery_document("https://idp.example.com");
1163        document.token_endpoint_auth_methods_supported =
1164            Some(vec!["client_secret_post".to_owned()]);
1165        assert_eq!(
1166            select_token_endpoint_authentication(&document, None),
1167            TokenEndpointAuthentication::ClientSecretPost
1168        );
1169    }
1170
1171    #[test]
1172    fn normalize_absolute_http_url_accepts_http_and_https() -> Result<(), OidcDiscoveryError> {
1173        assert_eq!(
1174            normalize_absolute_http_url("discovery_endpoint", "http://idp.example.com/path")?,
1175            "http://idp.example.com/path"
1176        );
1177        assert_eq!(
1178            normalize_absolute_http_url("discovery_endpoint", "https://idp.example.com/path")?,
1179            "https://idp.example.com/path"
1180        );
1181        Ok(())
1182    }
1183
1184    #[test]
1185    fn token_endpoint_authentication_defaults_for_empty_or_unsupported_methods() {
1186        let mut document = discovery_document("https://idp.example.com");
1187        document.token_endpoint_auth_methods_supported = Some(Vec::new());
1188        assert_eq!(
1189            select_token_endpoint_authentication(&document, None),
1190            TokenEndpointAuthentication::ClientSecretBasic
1191        );
1192
1193        document.token_endpoint_auth_methods_supported = Some(vec![
1194            "private_key_jwt".to_owned(),
1195            "tls_client_auth".to_owned(),
1196        ]);
1197        assert_eq!(
1198            select_token_endpoint_authentication(&document, None),
1199            TokenEndpointAuthentication::ClientSecretBasic
1200        );
1201    }
1202
1203    #[test]
1204    fn discovery_validation_accepts_document_without_optional_metadata(
1205    ) -> Result<(), Box<dyn std::error::Error>> {
1206        let document: OidcDiscoveryDocument = serde_json::from_str(
1207            r#"{
1208                "issuer":"https://idp.example.com",
1209                "authorization_endpoint":"https://idp.example.com/authorize",
1210                "token_endpoint":"https://idp.example.com/token",
1211                "jwks_uri":"https://idp.example.com/keys"
1212            }"#,
1213        )?;
1214
1215        validate_discovery_document(&document, "https://idp.example.com")?;
1216        assert_eq!(document.userinfo_endpoint, None);
1217        assert_eq!(document.response_types_supported, None);
1218        assert_eq!(document.subject_types_supported, None);
1219        assert_eq!(document.id_token_signing_alg_values_supported, None);
1220        assert_eq!(document.claims_supported, None);
1221        assert_eq!(document.code_challenge_methods_supported, None);
1222        Ok(())
1223    }
1224
1225    #[tokio::test]
1226    async fn fetch_discovery_document_classifies_http_and_json_errors(
1227    ) -> Result<(), Box<dyn std::error::Error>> {
1228        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1229        let address = listener.local_addr()?;
1230        tokio::spawn(async move {
1231            while let Ok((mut stream, _)) = listener.accept().await {
1232                tokio::spawn(async move {
1233                    let mut buffer = [0_u8; 1024];
1234                    let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1235                    else {
1236                        return;
1237                    };
1238                    let request = String::from_utf8_lossy(&buffer[..read]);
1239                    let (status, body) = if request.starts_with("GET /missing ") {
1240                        ("404 Not Found", "not found")
1241                    } else if request.starts_with("GET /server-error ") {
1242                        ("500 Internal Server Error", "server error")
1243                    } else if request.starts_with("GET /timeout-status ") {
1244                        ("408 Request Timeout", "timeout")
1245                    } else if request.starts_with("GET /empty ") {
1246                        ("200 OK", "")
1247                    } else {
1248                        ("200 OK", "not-json")
1249                    };
1250                    let response = format!(
1251                        "HTTP/1.1 {status}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1252                        body.len()
1253                    );
1254                    let _ =
1255                        tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1256                });
1257            }
1258        });
1259
1260        let client = reqwest::Client::new();
1261        let missing_error =
1262            match fetch_discovery_document(&format!("http://{address}/missing"), &client).await {
1263                Ok(_) => return Err("expected missing discovery document to fail".into()),
1264                Err(error) => error,
1265            };
1266        assert_eq!(missing_error.code(), "discovery_not_found");
1267
1268        let server_error = match fetch_discovery_document(
1269            &format!("http://{address}/server-error"),
1270            &client,
1271        )
1272        .await
1273        {
1274            Ok(_) => return Err("expected server error discovery document to fail".into()),
1275            Err(error) => error,
1276        };
1277        assert_eq!(server_error.code(), "discovery_unexpected_error");
1278
1279        let timeout_error =
1280            match fetch_discovery_document(&format!("http://{address}/timeout-status"), &client)
1281                .await
1282            {
1283                Ok(_) => return Err("expected timeout discovery document to fail".into()),
1284                Err(error) => error,
1285            };
1286        assert_eq!(timeout_error.code(), "discovery_timeout");
1287
1288        let empty_response_error =
1289            match fetch_discovery_document(&format!("http://{address}/empty"), &client).await {
1290                Ok(_) => return Err("expected empty discovery document to fail".into()),
1291                Err(error) => error,
1292            };
1293        assert_eq!(empty_response_error.code(), "discovery_invalid_json");
1294
1295        let invalid_json_error = match fetch_discovery_document(
1296            &format!("http://{address}/invalid-json"),
1297            &client,
1298        )
1299        .await
1300        {
1301            Ok(_) => return Err("expected invalid JSON discovery document to fail".into()),
1302            Err(error) => error,
1303        };
1304        assert_eq!(invalid_json_error.code(), "discovery_invalid_json");
1305        Ok(())
1306    }
1307
1308    #[tokio::test]
1309    async fn discovery_rejects_untrusted_discovered_endpoint_origins(
1310    ) -> Result<(), Box<dyn std::error::Error>> {
1311        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1312        let address = listener.local_addr()?;
1313        let base_url = format!("http://{address}");
1314        let server_base_url = base_url.clone();
1315        tokio::spawn(async move {
1316            while let Ok((mut stream, _)) = listener.accept().await {
1317                let server_base_url = server_base_url.clone();
1318                tokio::spawn(async move {
1319                    let mut buffer = [0_u8; 1024];
1320                    let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1321                    else {
1322                        return;
1323                    };
1324                    let request = String::from_utf8_lossy(&buffer[..read]);
1325                    let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1326                        format!(
1327                            r#"{{
1328                                "issuer":"{server_base_url}",
1329                                "authorization_endpoint":"{server_base_url}/authorize",
1330                                "token_endpoint":"https://untrusted.example.com/token",
1331                                "jwks_uri":"{server_base_url}/keys",
1332                                "userinfo_endpoint":"{server_base_url}/userinfo"
1333                            }}"#
1334                        )
1335                    } else {
1336                        r#"{"error":"not_found"}"#.to_owned()
1337                    };
1338                    let response = format!(
1339                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1340                        body.len()
1341                    );
1342                    let _ =
1343                        tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1344                });
1345            }
1346        });
1347
1348        let error = match discover_oidc_config_with_origin_validator(
1349            &base_url,
1350            None,
1351            PartialOidcDiscoveryConfig::default(),
1352            |url| url.starts_with(&base_url),
1353            &reqwest::Client::new(),
1354        )
1355        .await
1356        {
1357            Ok(_) => return Err("expected untrusted discovered endpoint to fail".into()),
1358            Err(error) => error,
1359        };
1360        assert_eq!(error.code(), "discovery_untrusted_origin");
1361        Ok(())
1362    }
1363
1364    #[tokio::test]
1365    async fn discovery_rejects_untrusted_optional_endpoint_origins(
1366    ) -> Result<(), Box<dyn std::error::Error>> {
1367        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1368        let address = listener.local_addr()?;
1369        let base_url = format!("http://{address}");
1370        let server_base_url = base_url.clone();
1371        tokio::spawn(async move {
1372            while let Ok((mut stream, _)) = listener.accept().await {
1373                let server_base_url = server_base_url.clone();
1374                tokio::spawn(async move {
1375                    let mut buffer = [0_u8; 1024];
1376                    let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1377                    else {
1378                        return;
1379                    };
1380                    let request = String::from_utf8_lossy(&buffer[..read]);
1381                    let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1382                        format!(
1383                            r#"{{
1384                                "issuer":"{server_base_url}",
1385                                "authorization_endpoint":"{server_base_url}/authorize",
1386                                "token_endpoint":"{server_base_url}/token",
1387                                "jwks_uri":"{server_base_url}/keys",
1388                                "revocation_endpoint":"https://untrusted.example.com/revoke"
1389                            }}"#
1390                        )
1391                    } else {
1392                        r#"{"error":"not_found"}"#.to_owned()
1393                    };
1394                    let response = format!(
1395                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1396                        body.len()
1397                    );
1398                    let _ =
1399                        tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1400                });
1401            }
1402        });
1403
1404        let error = match discover_oidc_config_with_origin_validator(
1405            &base_url,
1406            None,
1407            PartialOidcDiscoveryConfig::default(),
1408            |url| url.starts_with(&base_url),
1409            &reqwest::Client::new(),
1410        )
1411        .await
1412        {
1413            Ok(_) => return Err("expected untrusted optional endpoint to fail".into()),
1414            Err(error) => error,
1415        };
1416        assert_eq!(error.code(), "discovery_untrusted_origin");
1417        assert!(error.to_string().contains("revocation_endpoint"));
1418        Ok(())
1419    }
1420
1421    #[tokio::test]
1422    async fn discover_ignores_empty_existing_endpoint_overrides(
1423    ) -> Result<(), Box<dyn std::error::Error>> {
1424        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1425        let address = listener.local_addr()?;
1426        let base_url = format!("http://{address}");
1427        let server_base_url = base_url.clone();
1428        tokio::spawn(async move {
1429            while let Ok((mut stream, _)) = listener.accept().await {
1430                let server_base_url = server_base_url.clone();
1431                tokio::spawn(async move {
1432                    let mut buffer = [0_u8; 1024];
1433                    let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1434                    else {
1435                        return;
1436                    };
1437                    let request = String::from_utf8_lossy(&buffer[..read]);
1438                    let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1439                        format!(
1440                            r#"{{
1441                                "issuer":"{server_base_url}",
1442                                "authorization_endpoint":"{server_base_url}/authorize",
1443                                "token_endpoint":"{server_base_url}/token",
1444                                "jwks_uri":"{server_base_url}/keys"
1445                            }}"#
1446                        )
1447                    } else {
1448                        r#"{"error":"not_found"}"#.to_owned()
1449                    };
1450                    let response = format!(
1451                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1452                        body.len()
1453                    );
1454                    let _ =
1455                        tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1456                });
1457            }
1458        });
1459
1460        let hydrated = discover_oidc_config_with_origin_validator(
1461            &base_url,
1462            None,
1463            PartialOidcDiscoveryConfig {
1464                authorization_endpoint: Some(""),
1465                ..PartialOidcDiscoveryConfig::default()
1466            },
1467            |url| url.starts_with(&base_url),
1468            &reqwest::Client::new(),
1469        )
1470        .await?;
1471
1472        assert_eq!(
1473            hydrated.authorization_endpoint,
1474            format!("{base_url}/authorize")
1475        );
1476        Ok(())
1477    }
1478
1479    #[tokio::test]
1480    async fn discovery_preserves_user_supplied_endpoints_over_discovered_values(
1481    ) -> Result<(), Box<dyn std::error::Error>> {
1482        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1483        let address = listener.local_addr()?;
1484        let base_url = format!("http://{address}");
1485        let server_base_url = base_url.clone();
1486        tokio::spawn(async move {
1487            while let Ok((mut stream, _)) = listener.accept().await {
1488                let server_base_url = server_base_url.clone();
1489                tokio::spawn(async move {
1490                    let mut buffer = [0_u8; 1024];
1491                    let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1492                    else {
1493                        return;
1494                    };
1495                    let request = String::from_utf8_lossy(&buffer[..read]);
1496                    let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1497                        format!(
1498                            r#"{{
1499                                "issuer":"{server_base_url}",
1500                                "authorization_endpoint":"{server_base_url}/discovered/authorize",
1501                                "token_endpoint":"{server_base_url}/discovered/token",
1502                                "jwks_uri":"{server_base_url}/discovered/keys",
1503                                "userinfo_endpoint":"{server_base_url}/discovered/userinfo",
1504                                "token_endpoint_auth_methods_supported":["client_secret_post"]
1505                            }}"#
1506                        )
1507                    } else {
1508                        r#"{"error":"not_found"}"#.to_owned()
1509                    };
1510                    let response = format!(
1511                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1512                        body.len()
1513                    );
1514                    let _ =
1515                        tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1516                });
1517            }
1518        });
1519
1520        let custom_authorization_endpoint = format!("{base_url}/custom/authorize");
1521        let custom_token_endpoint = format!("{base_url}/custom/token");
1522        let custom_user_info_endpoint = format!("{base_url}/custom/userinfo");
1523        let custom_jwks_endpoint = format!("{base_url}/custom/keys");
1524        let existing = PartialOidcDiscoveryConfig {
1525            authorization_endpoint: Some(&custom_authorization_endpoint),
1526            token_endpoint: Some(&custom_token_endpoint),
1527            user_info_endpoint: Some(&custom_user_info_endpoint),
1528            jwks_endpoint: Some(&custom_jwks_endpoint),
1529            token_endpoint_authentication: Some(TokenEndpointAuthentication::ClientSecretBasic),
1530            ..PartialOidcDiscoveryConfig::default()
1531        };
1532
1533        let hydrated = discover_oidc_config_with_origin_validator(
1534            &base_url,
1535            None,
1536            existing,
1537            |url| url.starts_with(&base_url),
1538            &reqwest::Client::new(),
1539        )
1540        .await?;
1541
1542        assert_eq!(
1543            hydrated.authorization_endpoint,
1544            custom_authorization_endpoint
1545        );
1546        assert_eq!(hydrated.token_endpoint, custom_token_endpoint);
1547        assert_eq!(hydrated.jwks_endpoint, custom_jwks_endpoint);
1548        assert_eq!(
1549            hydrated.user_info_endpoint.as_deref(),
1550            Some(custom_user_info_endpoint.as_str())
1551        );
1552        assert_eq!(
1553            hydrated.token_endpoint_authentication,
1554            TokenEndpointAuthentication::ClientSecretBasic
1555        );
1556        Ok(())
1557    }
1558
1559    #[tokio::test]
1560    async fn discover_uses_custom_and_existing_discovery_endpoints(
1561    ) -> Result<(), Box<dyn std::error::Error>> {
1562        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1563        let address = listener.local_addr()?;
1564        let base_url = format!("http://{address}");
1565        let server_base_url = base_url.clone();
1566        tokio::spawn(async move {
1567            while let Ok((mut stream, _)) = listener.accept().await {
1568                let server_base_url = server_base_url.clone();
1569                tokio::spawn(async move {
1570                    let mut buffer = [0_u8; 4096];
1571                    let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1572                    else {
1573                        return;
1574                    };
1575                    let request = String::from_utf8_lossy(&buffer[..read]);
1576                    let body = if request.contains("GET /custom/.well-known/openid-configuration ")
1577                    {
1578                        format!(
1579                            r#"{{
1580                                "issuer":"{server_base_url}",
1581                                "authorization_endpoint":"{server_base_url}/authorize",
1582                                "token_endpoint":"{server_base_url}/token",
1583                                "jwks_uri":"{server_base_url}/keys"
1584                            }}"#
1585                        )
1586                    } else if request.contains("GET /tenant/.well-known/openid-configuration ") {
1587                        format!(
1588                            r#"{{
1589                                "issuer":"{server_base_url}",
1590                                "authorization_endpoint":"{server_base_url}/tenant/authorize",
1591                                "token_endpoint":"{server_base_url}/tenant/token",
1592                                "jwks_uri":"{server_base_url}/tenant/keys"
1593                            }}"#
1594                        )
1595                    } else {
1596                        r#"{"error":"not_found"}"#.to_owned()
1597                    };
1598                    let response = format!(
1599                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1600                        body.len()
1601                    );
1602                    let _ =
1603                        tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1604                });
1605            }
1606        });
1607
1608        let custom_endpoint = format!("{base_url}/custom/.well-known/openid-configuration");
1609        let custom = discover_oidc_config_with_origin_validator(
1610            &base_url,
1611            Some(&custom_endpoint),
1612            PartialOidcDiscoveryConfig::default(),
1613            |url| url.starts_with(&base_url),
1614            &reqwest::Client::new(),
1615        )
1616        .await?;
1617        assert_eq!(custom.discovery_endpoint, custom_endpoint);
1618
1619        let existing_endpoint = format!("{base_url}/tenant/.well-known/openid-configuration");
1620        let existing = discover_oidc_config_with_origin_validator(
1621            &base_url,
1622            None,
1623            PartialOidcDiscoveryConfig {
1624                discovery_endpoint: Some(&existing_endpoint),
1625                ..PartialOidcDiscoveryConfig::default()
1626            },
1627            |url| url.starts_with(&base_url),
1628            &reqwest::Client::new(),
1629        )
1630        .await?;
1631        assert_eq!(existing.discovery_endpoint, existing_endpoint);
1632        assert_eq!(
1633            existing.authorization_endpoint,
1634            format!("{base_url}/tenant/authorize")
1635        );
1636        Ok(())
1637    }
1638
1639    #[tokio::test]
1640    async fn discover_includes_scopes_supported_and_ignores_unknown_fields(
1641    ) -> Result<(), Box<dyn std::error::Error>> {
1642        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1643        let address = listener.local_addr()?;
1644        let base_url = format!("http://{address}");
1645        let server_base_url = base_url.clone();
1646        tokio::spawn(async move {
1647            while let Ok((mut stream, _)) = listener.accept().await {
1648                let server_base_url = server_base_url.clone();
1649                tokio::spawn(async move {
1650                    let mut buffer = [0_u8; 1024];
1651                    let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1652                    else {
1653                        return;
1654                    };
1655                    let request = String::from_utf8_lossy(&buffer[..read]);
1656                    let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1657                        format!(
1658                            r#"{{
1659                                "issuer":"{server_base_url}",
1660                                "authorization_endpoint":"{server_base_url}/authorize",
1661                                "token_endpoint":"{server_base_url}/token",
1662                                "jwks_uri":"{server_base_url}/keys",
1663                                "scopes_supported":["openid","profile","email","custom"],
1664                                "x-vendor-feature":true,
1665                                "custom_logout_endpoint":"{server_base_url}/logout"
1666                            }}"#
1667                        )
1668                    } else {
1669                        r#"{"error":"not_found"}"#.to_owned()
1670                    };
1671                    let response = format!(
1672                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1673                        body.len()
1674                    );
1675                    let _ =
1676                        tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1677                });
1678            }
1679        });
1680
1681        let hydrated = discover_oidc_config_with_origin_validator(
1682            &base_url,
1683            None,
1684            PartialOidcDiscoveryConfig::default(),
1685            |url| url.starts_with(&base_url),
1686            &reqwest::Client::new(),
1687        )
1688        .await?;
1689
1690        assert_eq!(
1691            hydrated.scopes_supported,
1692            Some(vec![
1693                "openid".to_owned(),
1694                "profile".to_owned(),
1695                "email".to_owned(),
1696                "custom".to_owned()
1697            ])
1698        );
1699        assert_eq!(hydrated.user_info_endpoint, None);
1700        Ok(())
1701    }
1702
1703    #[tokio::test]
1704    async fn discover_rejects_untrusted_main_discovery_url(
1705    ) -> Result<(), Box<dyn std::error::Error>> {
1706        let error = match discover_oidc_config_with_origin_validator(
1707            "https://idp.example.com",
1708            None,
1709            PartialOidcDiscoveryConfig::default(),
1710            |_| false,
1711            &reqwest::Client::new(),
1712        )
1713        .await
1714        {
1715            Ok(_) => return Err("expected untrusted discovery URL to fail".into()),
1716            Err(error) => error,
1717        };
1718        assert_eq!(error.code(), "discovery_untrusted_origin");
1719        assert!(error.to_string().contains("discovery_endpoint"));
1720        Ok(())
1721    }
1722
1723    #[tokio::test]
1724    async fn ensure_runtime_returns_unchanged_config_when_discovery_not_needed(
1725    ) -> Result<(), Box<dyn std::error::Error>> {
1726        let config = OidcConfig {
1727            issuer: "https://idp.example.com".to_owned(),
1728            pkce: true,
1729            client_id: "client-id".to_owned(),
1730            client_secret: "client-secret".into(),
1731            discovery_endpoint: compute_discovery_url("https://idp.example.com"),
1732            authorization_endpoint: Some("https://idp.example.com/authorize".to_owned()),
1733            token_endpoint: Some("https://idp.example.com/token".to_owned()),
1734            user_info_endpoint: Some("https://idp.example.com/userinfo".to_owned()),
1735            jwks_endpoint: Some("https://idp.example.com/keys".to_owned()),
1736            revocation_endpoint: None,
1737            end_session_endpoint: None,
1738            introspection_endpoint: None,
1739            token_endpoint_authentication: None,
1740            scopes: Some(vec!["openid".to_owned()]),
1741            mapping: None,
1742            override_user_info: false,
1743        };
1744
1745        let unchanged = ensure_runtime_oidc_config_with_origin_validator(
1746            "https://idp.example.com",
1747            config.clone(),
1748            OidcRuntimeRequirement::Callback,
1749            |_| true,
1750            false,
1751            &reqwest::Client::new(),
1752        )
1753        .await?;
1754
1755        assert_eq!(unchanged.client_id, config.client_id);
1756        assert_eq!(
1757            unchanged.client_secret.expose_secret(),
1758            config.client_secret.expose_secret()
1759        );
1760        assert_eq!(unchanged.pkce, config.pkce);
1761        assert_eq!(unchanged.scopes, config.scopes);
1762        assert_eq!(
1763            unchanged.authorization_endpoint,
1764            config.authorization_endpoint
1765        );
1766        Ok(())
1767    }
1768
1769    #[tokio::test]
1770    async fn ensure_runtime_throws_when_discovery_fails() -> Result<(), Box<dyn std::error::Error>>
1771    {
1772        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1773        let address = listener.local_addr()?;
1774        let base_url = format!("http://{address}");
1775        tokio::spawn(async move {
1776            while let Ok((mut stream, _)) = listener.accept().await {
1777                tokio::spawn(async move {
1778                    let mut buffer = [0_u8; 1024];
1779                    let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await;
1780                    let response =
1781                        "HTTP/1.1 404 Not Found\r\ncontent-type: application/json\r\ncontent-length: 2\r\nconnection: close\r\n\r\n{}";
1782                    let _ =
1783                        tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1784                });
1785            }
1786        });
1787
1788        let config = OidcConfig {
1789            issuer: base_url.clone(),
1790            pkce: true,
1791            client_id: "client-id".to_owned(),
1792            client_secret: "client-secret".into(),
1793            discovery_endpoint: compute_discovery_url(&base_url),
1794            authorization_endpoint: None,
1795            token_endpoint: None,
1796            user_info_endpoint: None,
1797            jwks_endpoint: None,
1798            revocation_endpoint: None,
1799            end_session_endpoint: None,
1800            introspection_endpoint: None,
1801            token_endpoint_authentication: None,
1802            scopes: None,
1803            mapping: None,
1804            override_user_info: false,
1805        };
1806
1807        let error = match ensure_runtime_oidc_config_with_origin_validator(
1808            &base_url,
1809            config,
1810            OidcRuntimeRequirement::SignIn,
1811            |_| true,
1812            false,
1813            &reqwest::Client::new(),
1814        )
1815        .await
1816        {
1817            Ok(_) => return Err("expected runtime discovery failure".into()),
1818            Err(error) => error,
1819        };
1820        assert_eq!(error.code(), "discovery_not_found");
1821        Ok(())
1822    }
1823
1824    #[tokio::test]
1825    async fn runtime_discovery_preserves_only_explicit_request_scopes(
1826    ) -> Result<(), Box<dyn std::error::Error>> {
1827        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
1828        let address = listener.local_addr()?;
1829        let base_url = format!("http://{address}");
1830        let server_base_url = base_url.clone();
1831        tokio::spawn(async move {
1832            while let Ok((mut stream, _)) = listener.accept().await {
1833                let server_base_url = server_base_url.clone();
1834                tokio::spawn(async move {
1835                    let mut buffer = [0_u8; 1024];
1836                    let Ok(read) = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await
1837                    else {
1838                        return;
1839                    };
1840                    let request = String::from_utf8_lossy(&buffer[..read]);
1841                    let body = if request.starts_with("GET /.well-known/openid-configuration ") {
1842                        format!(
1843                            r#"{{
1844                                "issuer":"{server_base_url}",
1845                                "authorization_endpoint":"{server_base_url}/authorize",
1846                                "token_endpoint":"{server_base_url}/token",
1847                                "jwks_uri":"{server_base_url}/keys",
1848                                "scopes_supported":["openid","profile"]
1849                            }}"#
1850                        )
1851                    } else {
1852                        r#"{"error":"not_found"}"#.to_owned()
1853                    };
1854                    let response = format!(
1855                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
1856                        body.len()
1857                    );
1858                    let _ =
1859                        tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()).await;
1860                });
1861            }
1862        });
1863
1864        let config = OidcConfig {
1865            issuer: base_url.clone(),
1866            pkce: true,
1867            client_id: "client".to_owned(),
1868            client_secret: "secret".into(),
1869            discovery_endpoint: compute_discovery_url(&base_url),
1870            authorization_endpoint: None,
1871            token_endpoint: None,
1872            user_info_endpoint: None,
1873            jwks_endpoint: None,
1874            revocation_endpoint: None,
1875            end_session_endpoint: None,
1876            introspection_endpoint: None,
1877            token_endpoint_authentication: None,
1878            scopes: None,
1879            mapping: None,
1880            override_user_info: false,
1881        };
1882
1883        let hydrated = ensure_runtime_oidc_config_with_origin_validator(
1884            &base_url,
1885            config,
1886            OidcRuntimeRequirement::SignIn,
1887            |url| url.starts_with(&base_url),
1888            false,
1889            &reqwest::Client::new(),
1890        )
1891        .await?;
1892
1893        assert_eq!(hydrated.scopes, None);
1894
1895        let explicit_config = OidcConfig {
1896            scopes: Some(vec!["openid".to_owned(), "email".to_owned()]),
1897            authorization_endpoint: None,
1898            token_endpoint: None,
1899            jwks_endpoint: None,
1900            ..hydrated
1901        };
1902        let explicit_hydrated = ensure_runtime_oidc_config_with_origin_validator(
1903            &base_url,
1904            explicit_config,
1905            OidcRuntimeRequirement::SignIn,
1906            |url| url.starts_with(&base_url),
1907            false,
1908            &reqwest::Client::new(),
1909        )
1910        .await?;
1911
1912        assert_eq!(
1913            explicit_hydrated.scopes,
1914            Some(vec!["openid".to_owned(), "email".to_owned()])
1915        );
1916        Ok(())
1917    }
1918}