graph-oauth 3.0.0

Rust SDK Client for Microsoft Identity Platform
use std::collections::HashSet;
use std::hash::Hash;

use url::{Host, Url};

#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum HostIs {
    Valid,
    Invalid,
}

pub trait ValidateHosts<RHS = Self> {
    fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs;
}

impl ValidateHosts for Url {
    fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
        if valid_hosts.is_empty() {
            return HostIs::Invalid;
        }

        let size_before = valid_hosts.len();
        let hosts: Vec<Host<&str>> = valid_hosts.iter().flat_map(|url| url.host()).collect();
        assert_eq!(size_before, hosts.len());

        if let Some(host) = self.host() {
            if hosts.contains(&host) {
                return HostIs::Valid;
            }
        }

        for value in valid_hosts.iter() {
            if !value.scheme().eq("https") {
                return HostIs::Invalid;
            }
        }

        HostIs::Invalid
    }
}

impl ValidateHosts for String {
    fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
        if let Ok(url) = Url::parse(self) {
            return url.validate_hosts(valid_hosts);
        }

        HostIs::Invalid
    }
}

impl ValidateHosts for &str {
    fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
        if let Ok(url) = Url::parse(self) {
            return url.validate_hosts(valid_hosts);
        }

        HostIs::Invalid
    }
}

#[derive(Clone, Debug)]
pub struct AllowedHostValidator {
    allowed_hosts: HashSet<Url>,
}

impl AllowedHostValidator {
    pub fn new(allowed_hosts: HashSet<Url>) -> AllowedHostValidator {
        for url in allowed_hosts.iter() {
            if !url.scheme().eq("https") {
                panic!("Requires https scheme");
            }
        }

        AllowedHostValidator { allowed_hosts }
    }

    pub fn validate_str(&self, url_str: &str) -> HostIs {
        if let Ok(url) = Url::parse(url_str) {
            return self.validate_hosts(&[url]);
        }

        HostIs::Invalid
    }

    pub fn validate_url(&self, url: &Url) -> HostIs {
        self.validate_hosts(&[url.clone()])
    }
}

impl From<&[Url]> for AllowedHostValidator {
    fn from(value: &[Url]) -> Self {
        let hash_set = HashSet::from_iter(value.iter().cloned());
        AllowedHostValidator::new(hash_set)
    }
}

impl ValidateHosts for AllowedHostValidator {
    fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
        if valid_hosts.is_empty() {
            return HostIs::Invalid;
        }

        let urls: Vec<Url> = self.allowed_hosts.iter().cloned().collect();
        for url in valid_hosts.iter() {
            if url.validate_hosts(urls.as_slice()).eq(&HostIs::Invalid) {
                return HostIs::Invalid;
            }
        }

        HostIs::Valid
    }
}

impl Default for AllowedHostValidator {
    fn default() -> Self {
        let urls: HashSet<Url> = [
            "https://graph.microsoft.com",
            "https://graph.microsoft.us",
            "https://dod-graph.microsoft.us",
            "https://graph.microsoft.de",
            "https://microsoftgraph.chinacloudapi.cn",
            "https://canary.graph.microsoft.com",
        ]
        .iter()
        .flat_map(|url_str| Url::parse(url_str))
        .collect();
        assert_eq!(6, urls.len());

        AllowedHostValidator::new(urls)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_valid_hosts() {
        let valid_hosts: Vec<String> = [
            "graph.microsoft.com",
            "graph.microsoft.us",
            "dod-graph.microsoft.us",
            "graph.microsoft.de",
            "microsoftgraph.chinacloudapi.cn",
            "canary.graph.microsoft.com",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        let host_urls: Vec<Url> = valid_hosts
            .iter()
            .map(|s| format!("https://{s}"))
            .flat_map(|s| Url::parse(&s))
            .collect();

        assert_eq!(6, host_urls.len());

        for url in host_urls.iter() {
            assert_eq!(HostIs::Valid, url.validate_hosts(&host_urls));
        }
    }

    #[test]
    fn test_invalid_hosts() {
        let invalid_hosts = [
            "graph.on.microsoft.com",
            "microsoft.com",
            "windows.net",
            "example.org",
        ];

        let valid_hosts: Vec<Url> = [
            "graph.microsoft.com",
            "graph.microsoft.us",
            "dod-graph.microsoft.us",
            "graph.microsoft.de",
            "microsoftgraph.chinacloudapi.cn",
            "canary.graph.microsoft.com",
        ]
        .iter()
        .map(|s| Url::parse(&format!("https://{s}")).unwrap())
        .collect();
        assert_eq!(6, valid_hosts.len());

        let host_urls: Vec<Url> = invalid_hosts
            .iter()
            .map(|s| format!("https://{s}"))
            .flat_map(|s| Url::parse(&s))
            .collect();

        assert_eq!(4, host_urls.len());

        for url in host_urls.iter() {
            assert_eq!(HostIs::Invalid, url.validate_hosts(valid_hosts.as_slice()));
        }
    }

    #[test]
    fn test_allowed_host_validator() {
        let valid_hosts: Vec<String> = [
            "graph.microsoft.com",
            "graph.microsoft.us",
            "dod-graph.microsoft.us",
            "graph.microsoft.de",
            "microsoftgraph.chinacloudapi.cn",
            "canary.graph.microsoft.com",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        let host_urls: Vec<Url> = valid_hosts
            .iter()
            .map(|s| format!("https://{s}"))
            .flat_map(|s| Url::parse(&s))
            .collect();

        assert_eq!(6, host_urls.len());

        let allowed_host_validator = AllowedHostValidator::from(host_urls.as_slice());

        for url in host_urls.iter() {
            assert_eq!(HostIs::Valid, allowed_host_validator.validate_url(url));
        }
    }
}