use serde::{Deserialize, Serialize};
use url::Url;
use super::error::DiscoveryError;
/// Module-local result alias whose error type is always [`DiscoveryError`].
type Result<T> = std::result::Result<T, DiscoveryError>;
/// The subset of an OIDC discovery document that this module deserializes.
///
/// The document is fetched from `<issuer>/.well-known/openid-configuration`
/// and decoded from JSON; fields not listed here are not captured.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub(crate) struct Discovery {
/// Issuer URL declared by the document itself.
pub issuer: Url,
/// URL of the authorization endpoint.
pub authorization_endpoint: Url,
/// URL of the token endpoint.
pub token_endpoint: Url,
/// URL of the JSON Web Key Set (presumably the provider's signing keys;
/// verify against the consumer of this field).
pub jwks_uri: Url,
/// Scopes the document advertises. When the field is absent from the
/// document, `discovery_default_scopes` supplies the value instead.
#[serde(default = "discovery_default_scopes")]
pub scopes_supported: Vec<String>,
}
/// Scope that a discovery document must list in `scopes_supported` to be
/// accepted.
pub(crate) const DISCOVERY_REQUIRED_SCOPE: &str = "openid";

/// Serde fallback for `Discovery::scopes_supported`: a list holding only
/// [`DISCOVERY_REQUIRED_SCOPE`].
fn discovery_default_scopes() -> Vec<String> {
    vec![String::from(DISCOVERY_REQUIRED_SCOPE)]
}
impl Discovery {
    /// Builds a discovery document for tests, with every endpoint derived
    /// from `issuer` and the default scope list.
    ///
    /// # Panics
    ///
    /// Panics if one of the fixed endpoint paths cannot be joined onto
    /// `issuer`.
    #[cfg(test)]
    pub(crate) fn new_for_test(issuer: Url) -> Self {
        // All endpoints are absolute paths resolved against the issuer.
        let endpoint = |path: &str| issuer.join(path).unwrap();

        let authorization_endpoint = endpoint("/v1/authorize");
        let token_endpoint = endpoint("/v1/token");
        let jwks_uri = endpoint("/.well-known/jwks.json");
        let scopes_supported = discovery_default_scopes();

        Self {
            issuer,
            authorization_endpoint,
            token_endpoint,
            jwks_uri,
            scopes_supported,
        }
    }
}
pub(crate) async fn fetch_discovery(
http: &reqwest::Client,
issuer: impl AsRef<str>,
) -> Result<Discovery> {
let issuer: Url = issuer.as_ref().parse()?;
fetch_discovery_impl(
http,
issuer,
#[cfg(not(any(test, feature = "_insecure-issuer-validation")))]
ValidationStrategy::strict(),
#[cfg(any(test, feature = "_insecure-issuer-validation"))]
ValidationStrategy::insecure(),
)
.await
}
/// Fetches the discovery document for `issuer` and validates it.
///
/// Steps, in order:
/// 1. `issuer` is checked with `validator`.
/// 2. The document is requested (as JSON) from
///    `<issuer>/.well-known/openid-configuration`. A terminating `/` on the
///    issuer path is removed before the suffix is appended.
/// 3. The issuer declared by the document is checked with `validator` and
///    must be equal to the queried `issuer`.
/// 4. The document's `scopes_supported` must contain
///    [`DISCOVERY_REQUIRED_SCOPE`].
///
/// # Errors
///
/// - [`DiscoveryError::InvalidIssuer`] if `issuer` cannot carry path segments,
///   or whatever error `validator` reports for either issuer.
/// - [`DiscoveryError::IssuerMismatch`] if the document declares a different
///   issuer than the one queried.
/// - [`DiscoveryError::InvalidScopes`] if the required scope is not listed.
/// - The `DiscoveryError` that `?` converts request, HTTP status, or JSON
///   decoding failures into.
async fn fetch_discovery_impl(
    http: &reqwest::Client,
    issuer: Url,
    validator: impl ValidateIssuer,
) -> Result<Discovery> {
    validator.validate_issuer(&issuer)?;

    let mut discovery_url = issuer.clone();
    discovery_url
        .path_segments_mut()
        // `path_segments_mut` only fails for cannot-be-a-base URLs.
        .map_err(|()| DiscoveryError::InvalidIssuer {
            issuer: issuer.to_string(),
            reason: "the issuer URL is not a valid URL".to_string(),
        })?
        // OpenID Connect Discovery 1.0 §4.1: any terminating `/` of the issuer
        // path must be removed before appending the well-known suffix.
        // Without this, `https://host/tenant/` would be requested as
        // `https://host/tenant//.well-known/openid-configuration`.
        .pop_if_empty()
        .extend([".well-known", "openid-configuration"]);

    #[cfg(feature = "tracing")]
    tracing::info!(
        discovery_url = discovery_url.as_str(),
        "Fetching OIDC discovery document."
    );

    let discovery: Discovery = http
        .get(discovery_url)
        .header(reqwest::header::ACCEPT, "application/json")
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;

    // The document's own issuer must pass the same checks and must be equal
    // to the issuer that was queried.
    validator.validate_issuer(&discovery.issuer)?;
    if issuer != discovery.issuer {
        return Err(DiscoveryError::IssuerMismatch {
            document: discovery.issuer.into(),
            query: issuer.to_string(),
        });
    }

    let advertises_required_scope = discovery
        .scopes_supported
        .iter()
        .any(|scope| scope == DISCOVERY_REQUIRED_SCOPE);
    if !advertises_required_scope {
        return Err(DiscoveryError::InvalidScopes(discovery.scopes_supported));
    }

    Ok(discovery)
}
/// Decides whether an issuer URL is acceptable for discovery.
trait ValidateIssuer {
/// Returns `Ok(())` when `issuer` is acceptable, otherwise an error
/// describing why it was rejected.
fn validate_issuer(&self, issuer: &Url) -> Result<()>;
}
/// Set of independent switches selecting which issuer checks are enforced.
///
/// Each flag enables one check in the `ValidateIssuer` implementation; a
/// `false` flag skips that check.
#[derive(Debug, Clone, Copy)]
// One bool per independent check is intentional here.
#[expect(clippy::struct_excessive_bools)]
struct ValidationStrategy {
/// Require the `https` scheme.
scheme: bool,
/// Require a non-empty host.
host: bool,
/// Reject issuers that carry a query component.
query: bool,
/// Reject issuers that carry a fragment component.
fragment: bool,
}
impl Default for ValidationStrategy {
fn default() -> Self {
Self::strict()
}
}
impl ValidationStrategy {
    /// Strategy with every check enabled.
    const fn strict() -> Self {
        Self::with_scheme_check(true)
    }

    /// Strategy that skips the `https` scheme check but keeps the host, query
    /// and fragment checks. Only available in test builds or with the
    /// `_insecure-issuer-validation` feature.
    #[cfg(any(test, feature = "_insecure-issuer-validation"))]
    const fn insecure() -> Self {
        Self::with_scheme_check(false)
    }

    /// Shared constructor: host, query and fragment checks are always on and
    /// only the scheme check varies.
    const fn with_scheme_check(scheme: bool) -> Self {
        Self {
            scheme,
            host: true,
            query: true,
            fragment: true,
        }
    }
}
impl ValidateIssuer for ValidationStrategy {
    /// Runs the enabled checks in a fixed order (scheme, host, query,
    /// fragment) and reports the first one that fails.
    ///
    /// # Errors
    ///
    /// Returns the error built by `issuer_err` for the first failing check.
    fn validate_issuer(&self, issuer: &Url) -> Result<()> {
        // Pick the first violated rule, if any; disabled checks are skipped.
        let violation = if self.scheme && issuer.scheme() != "https" {
            Some("the issuer scheme is not https")
        } else if self.host && issuer.host_str().is_none_or(str::is_empty) {
            Some("the issuer has no host")
        } else if self.query && issuer.query().is_some() {
            Some("the issuer contains a query")
        } else if self.fragment && issuer.fragment().is_some() {
            Some("the issuer contains a fragment")
        } else {
            None
        };

        match violation {
            Some(reason) => Err(issuer_err(issuer, reason)),
            None => Ok(()),
        }
    }
}
/// Builds a [`DiscoveryError::InvalidIssuer`] carrying the issuer's string
/// form and the human-readable `reason` it was rejected.
fn issuer_err(issuer: &Url, reason: &str) -> DiscoveryError {
    let issuer = String::from(issuer.as_str());
    let reason = reason.to_owned();
    DiscoveryError::InvalidIssuer { issuer, reason }
}
// Unit tests for discovery fetching and issuer validation.
#[cfg(test)]
mod test {
use super::*;
use httpmock::prelude::*;
use rstest::rstest;
// Plain HTTP client with default settings, shared by the tests below.
fn http_client() -> reqwest::Client {
reqwest::Client::builder().build().unwrap()
}
// Fetches a discovery document from a local mock server, once with the
// issuer given without a trailing slash and once with one. The insecure
// strategy is used because the mock server's base URL is not https.
#[tokio::test]
#[rstest]
async fn test_discovery(#[values("", "/")] issuer_suffix: &str) {
let mock_server = MockServer::start_async().await;
let issuer = Url::parse(&mock_server.base_url()).unwrap();
let expected = Discovery::new_for_test(issuer);
// Serve the expected document at the well-known discovery path.
let oidc_mock = mock_server
.mock_async(|when, then| {
when.method(GET).path("/.well-known/openid-configuration");
then.status(200).json_body_obj(&expected);
})
.await;
let actual = fetch_discovery_impl(
&http_client(),
mock_server.url(issuer_suffix).parse().unwrap(),
ValidationStrategy::insecure(),
)
.await
.expect("should fetch discovery document");
assert_eq!(actual.token_endpoint, expected.token_endpoint);
// The mock must have been hit by the request above.
oidc_mock.assert_async().await;
}
// Each issuer below violates exactly one strict rule and must be rejected.
#[rstest]
fn test_validation_strategy_invalid(
#[values(
"http://example.com", // not https
"https://example.com?foo=bar", // has query
"https://example.com#foo=bar", // has fragment
)]
issuer: &str,
) {
let strategy = ValidationStrategy::strict();
let issuer = Url::parse(issuer).unwrap();
assert!(
strategy.validate_issuer(&issuer).is_err(),
"issuer: {issuer:?}"
);
}
// The scheme check is disabled so that a non-https URL can reach the host
// check; the URL used here is expected to parse without a host.
#[test]
fn test_validation_strategy_invalid_host() {
let strategy = ValidationStrategy {
scheme: false,
..ValidationStrategy::strict()
};
let issuer = Url::parse("foo:/./path").unwrap();
assert!(
strategy.validate_issuer(&issuer).is_err(),
"issuer: {issuer:?}"
);
}
// Issuers with an explicit port, a path, or a path ending in a slash are
// all accepted by the strict strategy.
#[rstest]
fn test_validation_strategy_valid(
#[values(
"https://example.com",
"https://example.com:80",
"https://example.com/some/path",
"https://example.com/some/path/with/slash/"
)]
issuer: &str,
) {
let strategy = ValidationStrategy::strict();
let issuer = Url::parse(issuer).unwrap();
assert!(strategy.validate_issuer(&issuer).is_ok());
}
}