graph_oauth/identity/
allowed_host_validator.rs

1use std::collections::HashSet;
2use std::hash::Hash;
3
4use url::{Host, Url};
5
/// Outcome of checking a URL's host against a set of allowed hosts.
///
/// Variant order is significant: the derived `Ord` places `Valid` before `Invalid`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum HostIs {
    /// The host matched one of the allowed hosts.
    Valid,
    /// The host did not match, or the input could not be parsed/checked.
    Invalid,
}
11
12pub trait ValidateHosts<RHS = Self> {
13    fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs;
14}
15
16impl ValidateHosts for Url {
17    fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
18        if valid_hosts.is_empty() {
19            return HostIs::Invalid;
20        }
21
22        let size_before = valid_hosts.len();
23        let hosts: Vec<Host<&str>> = valid_hosts.iter().flat_map(|url| url.host()).collect();
24        assert_eq!(size_before, hosts.len());
25
26        if let Some(host) = self.host() {
27            if hosts.contains(&host) {
28                return HostIs::Valid;
29            }
30        }
31
32        for value in valid_hosts.iter() {
33            if !value.scheme().eq("https") {
34                return HostIs::Invalid;
35            }
36        }
37
38        HostIs::Invalid
39    }
40}
41
42impl ValidateHosts for String {
43    fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
44        if let Ok(url) = Url::parse(self) {
45            return url.validate_hosts(valid_hosts);
46        }
47
48        HostIs::Invalid
49    }
50}
51
52impl ValidateHosts for &str {
53    fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
54        if let Ok(url) = Url::parse(self) {
55            return url.validate_hosts(valid_hosts);
56        }
57
58        HostIs::Invalid
59    }
60}
61
62#[derive(Clone, Debug)]
63pub struct AllowedHostValidator {
64    allowed_hosts: HashSet<Url>,
65}
66
67impl AllowedHostValidator {
68    pub fn new(allowed_hosts: HashSet<Url>) -> AllowedHostValidator {
69        for url in allowed_hosts.iter() {
70            if !url.scheme().eq("https") {
71                panic!("Requires https scheme");
72            }
73        }
74
75        AllowedHostValidator { allowed_hosts }
76    }
77
78    pub fn validate_str(&self, url_str: &str) -> HostIs {
79        if let Ok(url) = Url::parse(url_str) {
80            return self.validate_hosts(&[url]);
81        }
82
83        HostIs::Invalid
84    }
85
86    pub fn validate_url(&self, url: &Url) -> HostIs {
87        self.validate_hosts(&[url.clone()])
88    }
89}
90
91impl From<&[Url]> for AllowedHostValidator {
92    fn from(value: &[Url]) -> Self {
93        let hash_set = HashSet::from_iter(value.iter().cloned());
94        AllowedHostValidator::new(hash_set)
95    }
96}
97
98impl ValidateHosts for AllowedHostValidator {
99    fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
100        if valid_hosts.is_empty() {
101            return HostIs::Invalid;
102        }
103
104        let urls: Vec<Url> = self.allowed_hosts.iter().cloned().collect();
105        for url in valid_hosts.iter() {
106            if url.validate_hosts(urls.as_slice()).eq(&HostIs::Invalid) {
107                return HostIs::Invalid;
108            }
109        }
110
111        HostIs::Valid
112    }
113}
114
115impl Default for AllowedHostValidator {
116    fn default() -> Self {
117        let urls: HashSet<Url> = [
118            "https://graph.microsoft.com",
119            "https://graph.microsoft.us",
120            "https://dod-graph.microsoft.us",
121            "https://graph.microsoft.de",
122            "https://microsoftgraph.chinacloudapi.cn",
123            "https://canary.graph.microsoft.com",
124        ]
125        .iter()
126        .flat_map(|url_str| Url::parse(url_str))
127        .collect();
128        assert_eq!(6, urls.len());
129
130        AllowedHostValidator::new(urls)
131    }
132}
133
#[cfg(test)]
mod test {
    use super::*;

    /// Graph hosts that every check below expects to be accepted.
    const GRAPH_HOSTS: [&str; 6] = [
        "graph.microsoft.com",
        "graph.microsoft.us",
        "dod-graph.microsoft.us",
        "graph.microsoft.de",
        "microsoftgraph.chinacloudapi.cn",
        "canary.graph.microsoft.com",
    ];

    /// Look-alike or unrelated hosts that must be rejected.
    const NON_GRAPH_HOSTS: [&str; 4] = [
        "graph.on.microsoft.com",
        "microsoft.com",
        "windows.net",
        "example.org",
    ];

    /// Builds `https://<host>` URLs, dropping any that fail to parse; callers assert on
    /// the resulting length.
    fn https_urls(hosts: &[&str]) -> Vec<Url> {
        hosts
            .iter()
            .map(|host| format!("https://{host}"))
            .flat_map(|raw| Url::parse(&raw))
            .collect()
    }

    #[test]
    fn test_valid_hosts() {
        let host_urls = https_urls(&GRAPH_HOSTS);
        assert_eq!(6, host_urls.len());

        for url in &host_urls {
            assert_eq!(HostIs::Valid, url.validate_hosts(&host_urls));
        }
    }

    #[test]
    fn test_invalid_hosts() {
        let valid_hosts = https_urls(&GRAPH_HOSTS);
        assert_eq!(6, valid_hosts.len());

        let host_urls = https_urls(&NON_GRAPH_HOSTS);
        assert_eq!(4, host_urls.len());

        for url in &host_urls {
            assert_eq!(HostIs::Invalid, url.validate_hosts(valid_hosts.as_slice()));
        }
    }

    #[test]
    fn test_allowed_host_validator() {
        let host_urls = https_urls(&GRAPH_HOSTS);
        assert_eq!(6, host_urls.len());

        let allowed_host_validator = AllowedHostValidator::from(host_urls.as_slice());

        for url in &host_urls {
            assert_eq!(HostIs::Valid, allowed_host_validator.validate_url(url));
        }
    }
}