// clio_auth/builder.rs

1use std::net::{IpAddr, Ipv4Addr, SocketAddr};
2
3use oauth2::Scope;
4
5use crate::ConfigError::{CannotBindAddress, InvalidServerConfig};
6use crate::*;
7
8/// A builder for [`CliOAuth`] structs.
9///
10/// Not constructed directly. See [`CliOAuth::builder()`].
11#[derive(Debug)]
12pub struct CliOAuthBuilder {
13    port_range: PortRange,
14    ip_address: IpAddr,
15    socket_address: Option<SocketAddr>,
16    timeout: u64,
17    scopes: Vec<Scope>,
18}
19
20impl CliOAuthBuilder {
21    pub(crate) fn new() -> Self {
22        CliOAuthBuilder {
23            port_range: DEFAULT_PORT_MIN..DEFAULT_PORT_MAX,
24            ip_address: IpAddr::V4(Ipv4Addr::LOCALHOST),
25            socket_address: None,
26            timeout: DEFAULT_TIMEOUT,
27            scopes: Default::default(),
28        }
29    }
30
31    fn resolve_address(&self) -> ConfigResult<SocketAddr> {
32        match self.socket_address {
33            Some(socket_addr) if is_address_available(socket_addr) => Ok(socket_addr),
34            Some(socket_addr) => Err(CannotBindAddress {
35                addr: socket_addr.ip(),
36                port_range: socket_addr.port()..socket_addr.port(),
37            }),
38            None => find_available_port(self.ip_address, self.port_range.clone()),
39        }
40    }
41
42    fn validate(&self) -> ConfigResult<()> {
43        if self.port_range.start < PORT_MIN {
44            return Err(InvalidServerConfig {
45                expected: format!("port >= {}", PORT_MIN),
46                found: format!("{}", self.port_range.start),
47            });
48        }
49        Ok(())
50    }
51
52    /// Configures a single port for the web server to attempt to bind to.
53    ///
54    /// For simplicity, must be a non-privileged port (greater than or equal to `1024`).
55    pub fn port(mut self, port: u16) -> Self {
56        self.port_range = port..(port + 1);
57        self
58    }
59
60    /// Configures a range of ports for the web server to attempt to bind to.
61    ///
62    /// When the `CliOAuth` instance is constructed, each of these ports will be tried in order. The
63    /// first open one will be used.
64    ///
65    /// The default range is `3456..3465`.
66    pub fn port_range(mut self, ports: PortRange) -> Self {
67        self.port_range = ports;
68        self
69    }
70
71    /// Configures the local IP address for the web server to listen on.
72    ///
73    /// Address must be valid on the system. The default is "localhost" (`127.0.0.1`), which works
74    /// fine in most cases.
75    pub fn ip_address(mut self, ip_address: impl Into<IpAddr>) -> Self {
76        self.ip_address = ip_address.into();
77        self
78    }
79
80    /// Configures a socket address (IP address and port) for the web server to listen on.
81    ///
82    /// If provided, it overrides the [`ip_address`][`Self::ip_address()`],
83    /// [`port`][`Self::port()`], and [`port_range`][`Self::port_range()`] settings.
84    pub fn socket_address(mut self, address: SocketAddr) -> Self {
85        self.socket_address = Some(address);
86        self
87    }
88
89    /// Configures the number of seconds the server will wait for an authorization code.
90    ///
91    /// If the server has not received a request containing a valid authorization code, it will
92    /// shut itself down, and the token exchange will not be possible.
93    ///
94    /// The default is `60` seconds.
95    pub fn timeout(mut self, timeout: u64) -> Self {
96        self.timeout = timeout;
97        self
98    }
99
100    /// Adds a scope to include with the authorization request.
101    pub fn scope(mut self, scope: Scope) -> Self {
102        self.scopes.push(scope);
103        self
104    }
105
106    /// Adds scopes to include with the authorization request.
107    pub fn scopes<S>(mut self, scopes: S) -> Self
108    where
109        S: IntoIterator<Item = Scope>,
110    {
111        self.scopes.extend(scopes);
112        self
113    }
114
115    /// Constructs the [`CliOAuth`] instance with the configuration captured in this builder.
116    pub fn build(self) -> ConfigResult<CliOAuth> {
117        self.validate()?;
118        let socket_addr = self.resolve_address()?;
119        Ok(CliOAuth {
120            address: socket_addr,
121            timeout: self.timeout,
122            scopes: self.scopes,
123            auth_context: None,
124            auth_result: None,
125        })
126    }
127}
128
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener};
    use std::str::FromStr;

    use oauth2::Scope;
    use rstest::rstest;

    use crate::tests::{next_ports, LOCALHOST};
    use crate::{DEFAULT_PORT_MAX, DEFAULT_PORT_MIN, DEFAULT_TIMEOUT, PORT_MIN};

    use super::CliOAuthBuilder;

    #[rstest]
    fn all_defaults() {
        let builder = CliOAuthBuilder::new();
        // `assert_eq!` compares by reference, so no clones are needed here.
        assert_eq!(builder.port_range, DEFAULT_PORT_MIN..DEFAULT_PORT_MAX);
        assert_eq!(builder.ip_address, LOCALHOST);
        assert_eq!(builder.socket_address, None);
        assert_eq!(builder.timeout, DEFAULT_TIMEOUT);
        builder.validate().expect("builder should be valid");
    }

    #[rstest]
    fn set_single_port() {
        let port = 2048;
        let builder = CliOAuthBuilder::new().port(port);
        assert!(builder.port_range.contains(&port));
        builder.validate().expect("builder should be valid");
    }

    #[rstest]
    #[case::one_less_than_min(PORT_MIN - 1)]
    #[case::one(1)]
    #[case::zero(0)]
    fn set_single_invalid_port(#[case] port: u16) {
        let builder = CliOAuthBuilder::new().port(port);
        let error = builder.validate().expect_err("Port should be invalid");
        assert_eq!(
            format!("{error}"),
            format!("Invalid server config (expected port >= 1024, found {port})")
        );
    }

    #[rstest]
    fn set_port_range() {
        let port_range = 2048..4096;
        let builder = CliOAuthBuilder::new().port_range(port_range.clone());
        assert_eq!(builder.port_range, port_range);
        builder.validate().expect("builder should be valid");
    }

    #[rstest]
    #[case::one_less_than_min(PORT_MIN - 1)]
    #[case::one(1)]
    #[case::zero(0)]
    fn set_invalid_port_range(#[case] lower_port: u16) {
        let builder = CliOAuthBuilder::new().port_range(lower_port..PORT_MIN);
        let error = builder
            .validate()
            .expect_err("Port range should be invalid");
        assert_eq!(
            format!("{error}"),
            format!("Invalid server config (expected port >= 1024, found {lower_port})")
        );
    }

    #[rstest]
    fn set_ip_address() {
        let builder = CliOAuthBuilder::new()
            .ip_address(IpAddr::V4(Ipv4Addr::from_str("192.168.0.20").unwrap()));
        assert_eq!(builder.ip_address, Ipv4Addr::from([192, 168, 0, 20]));
        builder.validate().expect("builder should be valid");
    }

    #[rstest]
    fn set_socket_address() {
        let addr = SocketAddr::from_str("192.168.0.20:4096").unwrap();
        let builder = CliOAuthBuilder::new().socket_address(addr);
        assert_eq!(builder.socket_address.unwrap(), addr);
        builder.validate().expect("builder should be valid");
    }

    #[rstest]
    fn socket_address_overrides_ip_and_port() {
        let (start_port, end_port) = next_ports(5);
        let port_range = start_port..end_port;
        // Take the socket port from `next_ports`, like the other tests, instead of hard-coding
        // one that may already be in use on the machine running the tests.
        let (socket_port, _) = next_ports(1);
        let socket_addr = SocketAddr::new(LOCALHOST, socket_port);

        let builder = CliOAuthBuilder::new()
            .ip_address(LOCALHOST)
            .port_range(port_range)
            .socket_address(socket_addr);
        let resolved_address = builder.resolve_address().unwrap();
        assert_eq!(resolved_address, socket_addr);
    }

    #[rstest]
    fn socket_address_from_ip_and_port_range() {
        let (port, _) = next_ports(1);
        let builder = CliOAuthBuilder::new().ip_address(LOCALHOST).port(port);
        let resolved_address = builder.resolve_address().unwrap();
        assert_eq!(resolved_address.port(), port);
        assert_eq!(resolved_address.ip(), LOCALHOST);
    }

    #[rstest]
    fn set_timeout() {
        let builder = CliOAuthBuilder::new().timeout(120);
        assert_eq!(builder.timeout, 120);
    }

    #[rstest]
    fn add_scope() {
        let builder = CliOAuthBuilder::new().scope(Scope::new(String::from("test_scope")));
        assert_eq!(builder.scopes, vec![Scope::new(String::from("test_scope"))]);
    }

    #[rstest]
    fn add_scopes() {
        let scopes = vec![
            Scope::new(String::from("scope:1")),
            Scope::new(String::from("scope:2")),
        ];
        let builder = CliOAuthBuilder::new().scopes(scopes);
        assert_eq!(
            builder.scopes,
            vec![
                Scope::new(String::from("scope:1")),
                Scope::new(String::from("scope:2"))
            ]
        );
    }

    #[rstest]
    fn build_valid_struct() {
        let (port, _) = next_ports(1);
        let builder = CliOAuthBuilder::new().port(port).timeout(30);
        let res = builder.build();
        let auth = res.expect("valid struct should be built");
        let built_addr = auth.address;
        assert_eq!(built_addr, SocketAddr::new(LOCALHOST, port));
        assert_eq!(auth.timeout, 30);
    }

    #[rstest]
    fn build_struct_with_invalid_ports() {
        let port = 26;
        let builder = CliOAuthBuilder::new().port(port);
        let res = builder.build();
        let error = res.expect_err("error should be returned");
        assert_eq!(
            format!("{error}"),
            format!("Invalid server config (expected port >= 1024, found {port})")
        );
    }

    #[rstest]
    fn build_struct_with_unavailable_ports() {
        let (test_port, open_port) = next_ports(1);
        let _socket =
            TcpListener::bind(SocketAddr::new(LOCALHOST, open_port)).expect("port is already open");
        let builder = CliOAuthBuilder::new().port(test_port);
        let res = builder.build();
        let error = res.expect_err("error should be returned");
        assert_eq!(
            format!("{error}"),
            format!("Cannot bind to 127.0.0.1 on any port from {test_port}-{test_port}")
        );
    }
}
305}