1use std::net::{IpAddr, Ipv4Addr, SocketAddr};
2
3use oauth2::Scope;
4
5use crate::ConfigError::{CannotBindAddress, InvalidServerConfig};
6use crate::*;
7
8#[derive(Debug)]
12pub struct CliOAuthBuilder {
13 port_range: PortRange,
14 ip_address: IpAddr,
15 socket_address: Option<SocketAddr>,
16 timeout: u64,
17 scopes: Vec<Scope>,
18}
19
20impl CliOAuthBuilder {
21 pub(crate) fn new() -> Self {
22 CliOAuthBuilder {
23 port_range: DEFAULT_PORT_MIN..DEFAULT_PORT_MAX,
24 ip_address: IpAddr::V4(Ipv4Addr::LOCALHOST),
25 socket_address: None,
26 timeout: DEFAULT_TIMEOUT,
27 scopes: Default::default(),
28 }
29 }
30
31 fn resolve_address(&self) -> ConfigResult<SocketAddr> {
32 match self.socket_address {
33 Some(socket_addr) if is_address_available(socket_addr) => Ok(socket_addr),
34 Some(socket_addr) => Err(CannotBindAddress {
35 addr: socket_addr.ip(),
36 port_range: socket_addr.port()..socket_addr.port(),
37 }),
38 None => find_available_port(self.ip_address, self.port_range.clone()),
39 }
40 }
41
42 fn validate(&self) -> ConfigResult<()> {
43 if self.port_range.start < PORT_MIN {
44 return Err(InvalidServerConfig {
45 expected: format!("port >= {}", PORT_MIN),
46 found: format!("{}", self.port_range.start),
47 });
48 }
49 Ok(())
50 }
51
52 pub fn port(mut self, port: u16) -> Self {
56 self.port_range = port..(port + 1);
57 self
58 }
59
60 pub fn port_range(mut self, ports: PortRange) -> Self {
67 self.port_range = ports;
68 self
69 }
70
71 pub fn ip_address(mut self, ip_address: impl Into<IpAddr>) -> Self {
76 self.ip_address = ip_address.into();
77 self
78 }
79
80 pub fn socket_address(mut self, address: SocketAddr) -> Self {
85 self.socket_address = Some(address);
86 self
87 }
88
89 pub fn timeout(mut self, timeout: u64) -> Self {
96 self.timeout = timeout;
97 self
98 }
99
100 pub fn scope(mut self, scope: Scope) -> Self {
102 self.scopes.push(scope);
103 self
104 }
105
106 pub fn scopes<S>(mut self, scopes: S) -> Self
108 where
109 S: IntoIterator<Item = Scope>,
110 {
111 self.scopes.extend(scopes);
112 self
113 }
114
115 pub fn build(self) -> ConfigResult<CliOAuth> {
117 self.validate()?;
118 let socket_addr = self.resolve_address()?;
119 Ok(CliOAuth {
120 address: socket_addr,
121 timeout: self.timeout,
122 scopes: self.scopes,
123 auth_context: None,
124 auth_result: None,
125 })
126 }
127}
128
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener};
    use std::str::FromStr;

    use oauth2::Scope;
    use rstest::rstest;

    use crate::tests::{next_ports, LOCALHOST};
    use crate::{DEFAULT_PORT_MAX, DEFAULT_PORT_MIN, DEFAULT_TIMEOUT, PORT_MIN};

    use super::CliOAuthBuilder;

    #[rstest]
    fn all_defaults() {
        let builder = CliOAuthBuilder::new();
        assert_eq!(builder.port_range, DEFAULT_PORT_MIN..DEFAULT_PORT_MAX);
        assert_eq!(builder.ip_address, LOCALHOST);
        assert_eq!(builder.socket_address, None);
        assert_eq!(builder.timeout, DEFAULT_TIMEOUT);
        assert!(builder.scopes.is_empty());
        builder.validate().expect("builder should be valid");
    }

    #[rstest]
    fn set_single_port() {
        let port = 2048;
        let builder = CliOAuthBuilder::new().port(port);
        assert!(builder.port_range.contains(&port));
        assert_eq!(builder.port_range, port..port + 1);
        builder.validate().expect("builder should be valid");
    }

    #[rstest]
    #[case::one_less_than_min(PORT_MIN - 1)]
    #[case::one(1)]
    #[case::zero(0)]
    fn set_single_invalid_port(#[case] port: u16) {
        let builder = CliOAuthBuilder::new().port(port);
        let error = builder.validate().expect_err("Port should be invalid");
        assert_eq!(
            format!("{error}"),
            format!("Invalid server config (expected port >= 1024, found {port})")
        );
    }

    #[rstest]
    fn set_port_range() {
        let port_range = 2048..4096;
        let builder = CliOAuthBuilder::new().port_range(port_range.clone());
        assert_eq!(builder.port_range, port_range);
        builder.validate().expect("builder should be valid");
    }

    #[rstest]
    #[case::one_less_than_min(PORT_MIN - 1)]
    #[case::one(1)]
    #[case::zero(0)]
    fn set_invalid_port_range(#[case] lower_port: u16) {
        let builder = CliOAuthBuilder::new().port_range(lower_port..PORT_MIN);
        let error = builder
            .validate()
            .expect_err("Port range should be invalid");
        assert_eq!(
            format!("{error}"),
            format!("Invalid server config (expected port >= 1024, found {lower_port})")
        );
    }

    #[rstest]
    fn set_ip_address() {
        let builder = CliOAuthBuilder::new()
            .ip_address(IpAddr::V4(Ipv4Addr::from_str("192.168.0.20").unwrap()));
        assert_eq!(builder.ip_address, Ipv4Addr::from([192, 168, 0, 20]));
        builder.validate().expect("builder should be valid");
    }

    #[rstest]
    fn set_socket_address() {
        let addr = SocketAddr::from_str("192.168.0.20:4096").unwrap();
        let builder = CliOAuthBuilder::new().socket_address(addr);
        assert_eq!(builder.socket_address, Some(addr));
        builder.validate().expect("builder should be valid");
    }

    #[rstest]
    fn socket_address_overrides_ip_and_port() {
        let (start_port, end_port) = next_ports(5);
        // Take the explicit address's port from `next_ports`, like the other
        // tests, rather than hard-coding one: `resolve_address` checks that the
        // address is available, so a fixed port would fail whenever it is
        // already in use on the host.
        let (socket_port, _) = next_ports(1);
        let socket_addr = SocketAddr::new(LOCALHOST, socket_port);

        let builder = CliOAuthBuilder::new()
            .ip_address(LOCALHOST)
            .port_range(start_port..end_port)
            .socket_address(socket_addr);
        let resolved_address = builder.resolve_address().unwrap();
        assert_eq!(resolved_address, socket_addr);
    }

    #[rstest]
    fn socket_address_from_ip_and_port_range() {
        let (port, _) = next_ports(1);
        let builder = CliOAuthBuilder::new().ip_address(LOCALHOST).port(port);
        let resolved_address = builder.resolve_address().unwrap();
        assert_eq!(resolved_address.port(), port);
        assert_eq!(resolved_address.ip(), LOCALHOST);
    }

    #[rstest]
    fn set_timeout() {
        let builder = CliOAuthBuilder::new().timeout(120);
        assert_eq!(builder.timeout, 120);
    }

    #[rstest]
    fn add_scope() {
        let builder = CliOAuthBuilder::new().scope(Scope::new(String::from("test_scope")));
        assert_eq!(builder.scopes, vec![Scope::new(String::from("test_scope"))]);
    }

    #[rstest]
    fn add_scopes() {
        let scopes = vec![
            Scope::new(String::from("scope:1")),
            Scope::new(String::from("scope:2")),
        ];
        let builder = CliOAuthBuilder::new().scopes(scopes);
        assert_eq!(
            builder.scopes,
            vec![
                Scope::new(String::from("scope:1")),
                Scope::new(String::from("scope:2"))
            ]
        );
    }

    #[rstest]
    fn build_valid_struct() {
        let (port, _) = next_ports(1);
        let builder = CliOAuthBuilder::new().port(port).timeout(30);
        let auth = builder.build().expect("valid struct should be built");
        assert_eq!(auth.address, SocketAddr::new(LOCALHOST, port));
        assert_eq!(auth.timeout, 30);
    }

    #[rstest]
    fn build_struct_with_invalid_ports() {
        let port = 26;
        let builder = CliOAuthBuilder::new().port(port);
        let error = builder.build().expect_err("error should be returned");
        assert_eq!(
            format!("{error}"),
            format!("Invalid server config (expected port >= 1024, found {port})")
        );
    }

    #[rstest]
    fn build_struct_with_unavailable_ports() {
        let (test_port, open_port) = next_ports(1);
        let _socket =
            TcpListener::bind(SocketAddr::new(LOCALHOST, open_port)).expect("port is already open");
        let builder = CliOAuthBuilder::new().port(test_port);
        let error = builder.build().expect_err("error should be returned");
        assert_eq!(
            format!("{error}"),
            format!("Cannot bind to 127.0.0.1 on any port from {test_port}-{test_port}")
        );
    }
}