tower-oauth2-resource-server 0.11.0

Tower middleware that provides JWT authorization against an OpenID Connect (OIDC) Provider
Documentation
use std::{collections::HashSet, error::Error};

use serde::Deserialize;
use url::Url;

#[cfg(test)]
use mockall::automock;

use crate::error::StartupError;

/// Subset of the provider's discovery document that this crate reads.
///
/// Populated by deserializing the JSON body returned from one of the
/// well-known discovery URLs (see `OidcDiscovery::discover`).
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct OidcConfig {
    /// Value of the `jwks_uri` field of the discovery document; must parse as a URL.
    pub jwks_uri: Url,
    /// Value of the `issuer` field of the discovery document.
    pub issuer: String,
    /// Value of the `claims_supported` field; `None` when the document omits it.
    pub claims_supported: Option<Vec<String>>,
}

/// Stateless namespace for the discovery routine.
///
/// It has no fields; its associated functions live in the `impl` block, which
/// is annotated with `automock` in test builds.
// The `dead_code` lint is silenced for test builds only.
#[cfg_attr(test, allow(dead_code))]
pub(crate) struct OidcDiscovery {}

#[cfg_attr(test, automock)]
impl OidcDiscovery {
    /// Fetches the provider configuration for `issuer_url`.
    ///
    /// Every candidate URL produced by `get_paths` is requested in turn, and
    /// the first response whose body deserializes into an [`OidcConfig`] is
    /// returned. The candidates come from a `HashSet`, so the order in which
    /// they are tried is unspecified.
    ///
    /// # Errors
    ///
    /// Returns an error when the candidate URLs cannot be derived from
    /// `issuer_url`, or when no candidate yields a usable document. In the
    /// latter case the message lists each attempted URL together with the
    /// reason it was rejected, so that startup failures can be diagnosed.
    #[cfg_attr(test, allow(dead_code))]
    pub async fn discover(issuer_url: &Url) -> Result<OidcConfig, Box<dyn Error>> {
        let paths = get_paths(issuer_url)?;
        // Reasons are kept as plain strings so nothing non-`Send` lives across an await.
        let mut failures: Vec<String> = Vec::new();
        for path in paths {
            let outcome = match reqwest::get(path.clone()).await {
                Ok(response) => {
                    // Captured before `json` consumes the response; used for diagnostics only.
                    let status = response.status();
                    response
                        .json::<OidcConfig>()
                        .await
                        .map_err(|e| format!("HTTP {}: {}", status, e))
                }
                Err(e) => Err(e.to_string()),
            };
            match outcome {
                Ok(oidc_config) => return Ok(oidc_config),
                Err(reason) => failures.push(format!("{}: {}", path, reason)),
            }
        }
        Err(format!(
            "Failed to fetch OIDC configuration: {}",
            failures.join("; ")
        )
        .into())
    }
}

/// Derives the set of discovery URLs to try for `issuer_url`.
///
/// With `<path>` standing for the issuer's non-empty path segments, the
/// candidates are:
///
/// * `<path>/.well-known/openid-configuration`
/// * `/.well-known/openid-configuration/<path>`
/// * `/.well-known/oauth-authorization-server/<path>`
///
/// Empty segments are dropped, so a trailing slash on the issuer makes no
/// difference. Because the result is a `HashSet`, identical candidates
/// collapse (an issuer without a path yields two URLs) and the iteration
/// order is unspecified.
///
/// # Errors
///
/// Returns `StartupError::InvalidParameter` when the URL does not expose
/// path segments (i.e. `path_segments` / `path_segments_mut` fail).
fn get_paths(issuer_url: &Url) -> Result<HashSet<Url>, Box<dyn Error>> {
    let path_err =
        || StartupError::InvalidParameter(format!("Could not parse issuer: {}", issuer_url));
    // Replaces the whole path of `base` with `segments`, keeping the rest of the URL.
    let build_url = |base: &Url, segments: &[&str]| -> Result<Url, Box<dyn Error>> {
        let mut url = base.clone();
        url.path_segments_mut()
            .map_err(|_| path_err())?
            .clear()
            .extend(segments);
        Ok(url)
    };

    // `ok_or_else` defers building the error message to the failure path.
    let base_segments: Vec<&str> = issuer_url
        .path_segments()
        .ok_or_else(&path_err)?
        .filter(|segment| !segment.is_empty())
        .collect();

    // Issuer path first, well-known suffix appended.
    let mut issuer_then_openid = base_segments.clone();
    issuer_then_openid.extend(&[".well-known", "openid-configuration"]);

    // Well-known prefix first, issuer path appended.
    let mut openid_then_issuer = vec![".well-known", "openid-configuration"];
    openid_then_issuer.extend(&base_segments);

    let mut oauth_then_issuer = vec![".well-known", "oauth-authorization-server"];
    oauth_then_issuer.extend(&base_segments);

    [issuer_then_openid, openid_then_issuer, oauth_then_issuer]
        .iter()
        .map(|segments| build_url(issuer_url, segments))
        .collect()
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use url::Url;

    use super::get_paths;

    /// Parses `issuer`, runs `get_paths` on it and returns the candidates as strings.
    fn discovery_urls(issuer: &str) -> HashSet<String> {
        let issuer_url = issuer.parse::<Url>().unwrap();
        get_paths(&issuer_url)
            .unwrap()
            .into_iter()
            .map(|p| p.to_string())
            .collect::<HashSet<_>>()
    }

    #[test]
    fn test_get_paths_with_path() {
        let paths = discovery_urls("https://authorization-server.com/issuer");

        assert_eq!(paths.len(), 3);
        assert!(
            paths.contains(
                "https://authorization-server.com/issuer/.well-known/openid-configuration"
            )
        );
        assert!(
            paths.contains(
                "https://authorization-server.com/.well-known/openid-configuration/issuer"
            )
        );
        assert!(paths.contains(
            "https://authorization-server.com/.well-known/oauth-authorization-server/issuer"
        ));
    }

    #[test]
    fn test_get_paths_with_path_trailing_slash() {
        let paths = discovery_urls("https://authorization-server.com/issuer/");

        // Same three candidates as without the trailing slash: the empty
        // trailing segment must not produce additional or different URLs.
        assert_eq!(paths.len(), 3);
        assert!(
            paths.contains(
                "https://authorization-server.com/issuer/.well-known/openid-configuration"
            )
        );
        assert!(
            paths.contains(
                "https://authorization-server.com/.well-known/openid-configuration/issuer"
            )
        );
        assert!(paths.contains(
            "https://authorization-server.com/.well-known/oauth-authorization-server/issuer"
        ));
    }

    #[test]
    fn test_get_paths_no_path() {
        let paths = discovery_urls("https://authorization-server.com");

        // Two of the three candidates coincide when the issuer has no path.
        assert_eq!(paths.len(), 2);
        assert!(
            paths.contains("https://authorization-server.com/.well-known/openid-configuration")
        );
        assert!(
            paths.contains(
                "https://authorization-server.com/.well-known/oauth-authorization-server"
            )
        );
    }

    #[test]
    fn test_get_paths_no_path_trailing_slash() {
        let paths = discovery_urls("https://authorization-server.com/");

        assert_eq!(paths.len(), 2);
        assert!(
            paths.contains("https://authorization-server.com/.well-known/openid-configuration")
        );
        assert!(
            paths.contains(
                "https://authorization-server.com/.well-known/oauth-authorization-server"
            )
        );
    }
}