// graph_oauth/identity/allowed_host_validator.rs
use std::collections::HashSet;
use std::hash::Hash;

use url::{Host, Url};

/// Outcome of checking a URL's host against a set of allowed hosts.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum HostIs {
    /// The host matched an allowed host.
    Valid,
    /// The host did not match, or the input could not be checked at all
    /// (for example, a string that failed to parse as a URL).
    Invalid,
}

12pub trait ValidateHosts<RHS = Self> {
13 fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs;
14}
15
16impl ValidateHosts for Url {
17 fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
18 if valid_hosts.is_empty() {
19 return HostIs::Invalid;
20 }
21
22 let size_before = valid_hosts.len();
23 let hosts: Vec<Host<&str>> = valid_hosts.iter().flat_map(|url| url.host()).collect();
24 assert_eq!(size_before, hosts.len());
25
26 if let Some(host) = self.host() {
27 if hosts.contains(&host) {
28 return HostIs::Valid;
29 }
30 }
31
32 for value in valid_hosts.iter() {
33 if !value.scheme().eq("https") {
34 return HostIs::Invalid;
35 }
36 }
37
38 HostIs::Invalid
39 }
40}
41
42impl ValidateHosts for String {
43 fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
44 if let Ok(url) = Url::parse(self) {
45 return url.validate_hosts(valid_hosts);
46 }
47
48 HostIs::Invalid
49 }
50}
51
52impl ValidateHosts for &str {
53 fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
54 if let Ok(url) = Url::parse(self) {
55 return url.validate_hosts(valid_hosts);
56 }
57
58 HostIs::Invalid
59 }
60}
61
62#[derive(Clone, Debug)]
63pub struct AllowedHostValidator {
64 allowed_hosts: HashSet<Url>,
65}
66
67impl AllowedHostValidator {
68 pub fn new(allowed_hosts: HashSet<Url>) -> AllowedHostValidator {
69 for url in allowed_hosts.iter() {
70 if !url.scheme().eq("https") {
71 panic!("Requires https scheme");
72 }
73 }
74
75 AllowedHostValidator { allowed_hosts }
76 }
77
78 pub fn validate_str(&self, url_str: &str) -> HostIs {
79 if let Ok(url) = Url::parse(url_str) {
80 return self.validate_hosts(&[url]);
81 }
82
83 HostIs::Invalid
84 }
85
86 pub fn validate_url(&self, url: &Url) -> HostIs {
87 self.validate_hosts(&[url.clone()])
88 }
89}
90
91impl From<&[Url]> for AllowedHostValidator {
92 fn from(value: &[Url]) -> Self {
93 let hash_set = HashSet::from_iter(value.iter().cloned());
94 AllowedHostValidator::new(hash_set)
95 }
96}
97
98impl ValidateHosts for AllowedHostValidator {
99 fn validate_hosts(&self, valid_hosts: &[Url]) -> HostIs {
100 if valid_hosts.is_empty() {
101 return HostIs::Invalid;
102 }
103
104 let urls: Vec<Url> = self.allowed_hosts.iter().cloned().collect();
105 for url in valid_hosts.iter() {
106 if url.validate_hosts(urls.as_slice()).eq(&HostIs::Invalid) {
107 return HostIs::Invalid;
108 }
109 }
110
111 HostIs::Valid
112 }
113}
114
115impl Default for AllowedHostValidator {
116 fn default() -> Self {
117 let urls: HashSet<Url> = [
118 "https://graph.microsoft.com",
119 "https://graph.microsoft.us",
120 "https://dod-graph.microsoft.us",
121 "https://graph.microsoft.de",
122 "https://microsoftgraph.chinacloudapi.cn",
123 "https://canary.graph.microsoft.com",
124 ]
125 .iter()
126 .flat_map(|url_str| Url::parse(url_str))
127 .collect();
128 assert_eq!(6, urls.len());
129
130 AllowedHostValidator::new(urls)
131 }
132}
133
#[cfg(test)]
mod test {
    use super::*;

    /// Allow-list hosts used throughout these tests, written without a scheme.
    const GRAPH_HOSTS: [&str; 6] = [
        "graph.microsoft.com",
        "graph.microsoft.us",
        "dod-graph.microsoft.us",
        "graph.microsoft.de",
        "microsoftgraph.chinacloudapi.cn",
        "canary.graph.microsoft.com",
    ];

    /// Hosts that the `GRAPH_HOSTS` allow-list must reject.
    const NON_GRAPH_HOSTS: [&str; 4] = [
        "graph.on.microsoft.com",
        "microsoft.com",
        "windows.net",
        "example.org",
    ];

    /// Builds an `https://<host>` URL for each host, skipping any that fail to parse.
    /// Callers assert the resulting length, so a skipped entry still fails the test.
    fn https_urls(hosts: &[&str]) -> Vec<Url> {
        hosts
            .iter()
            .filter_map(|host| Url::parse(&format!("https://{host}")).ok())
            .collect()
    }

    #[test]
    fn test_valid_hosts() {
        let allow_list = https_urls(&GRAPH_HOSTS);
        assert_eq!(6, allow_list.len());

        for url in &allow_list {
            assert_eq!(HostIs::Valid, url.validate_hosts(&allow_list));
        }
    }

    #[test]
    fn test_invalid_hosts() {
        let allow_list = https_urls(&GRAPH_HOSTS);
        assert_eq!(6, allow_list.len());

        let rejected = https_urls(&NON_GRAPH_HOSTS);
        assert_eq!(4, rejected.len());

        for url in &rejected {
            assert_eq!(HostIs::Invalid, url.validate_hosts(allow_list.as_slice()));
        }
    }

    #[test]
    fn test_allowed_host_validator() {
        let allow_list = https_urls(&GRAPH_HOSTS);
        assert_eq!(6, allow_list.len());

        let validator = AllowedHostValidator::from(allow_list.as_slice());
        for url in &allow_list {
            assert_eq!(HostIs::Valid, validator.validate_url(url));
        }
    }
}