clio_auth/lib.rs
1//! OAuth 2.0 helper for CLI and desktop applications.
2//!
3//! This package facilitates the [OAuth 2.0 Authorization Code with PKCE][1] flow for command line
4//! and desktop GUI applications. It works hand-in-hand with the [oauth2][2] crate by providing the
5//! "missing pieces" for the flow: a web server to handle the authorization callback, and opening
6//! the browser with the authorization link.
7//!
8//! # Usage
9//!
10//! General usage is as follows:
11//!
12//! 1. Configure a [`CliOAuthBuilder`] and build a [`CliOAuth`] helper
13//! 1. Configure an [`oauth2::Client`]
14//! 1. Start the [authorization flow](CliOAuth::authorize)
15//! 1. [Validate and obtain](CliOAuth::validate) the authorization code
16//! 1. [Exchange the code](oauth2::Client::exchange_code) for a token
17//!
18//! # Example
19//!
20//! This example is adapted directly from the [`oauth2`] package documentation ("Asynchronous API"),
21//! and demonstrates how `CliOAuth` fills in the gaps.
22//!
23//! ```no_run
//! use anyhow;
//! use clio_auth::AuthContext;
//! use log::{info, warn};
25//! use oauth2::{
26//! AuthorizationCode,
27//! AuthUrl,
28//! ClientId,
29//! ClientSecret,
30//! CsrfToken,
31//! PkceCodeChallenge,
32//! RedirectUrl,
33//! Scope,
34//! TokenResponse,
35//! TokenUrl
36//! };
37//! use oauth2::basic::BasicClient;
38//! # #[cfg(feature = "reqwest")]
39//! use oauth2::reqwest::async_http_client;
40//! use url::Url;
41//!
42//! # #[cfg(feature = "reqwest")]
43//! # async fn err_wrapper() -> Result<(), anyhow::Error> {
44//! // CliOAuth: Build helper with default options
45//! let mut auth = clio_auth::CliOAuth::builder().build().unwrap(); // (1)
46//! // Create an OAuth2 client by specifying the client ID, client secret, authorization URL and
47//! // token URL.
48//! let client =
49//! BasicClient::new(
50//! ClientId::new("client_id".to_string()),
51//! Some(ClientSecret::new("client_secret".to_string())),
52//! AuthUrl::new("http://authorize".to_string())?,
53//! Some(TokenUrl::new("http://token".to_string())?)
54//! )
55//! // CliOAuth: Use the local redirect URL
56//! .set_redirect_uri(auth.redirect_url()); // (2)
57//!
58//! // CliOAuth: The PKCE challenge is handled internally. Just authorize... (3)
//! match auth.authorize(&client).await {
60//! Ok(()) => info!("authorized successfully"),
61//! Err(e) => warn!("uh oh! {:?}", e),
62//! };
63//! // CliOAuth: The browser is opened to the authorization URL (3)
64//!
65//! // Once the user has been redirected to the redirect URL, you'll have access to the
66//! // authorization code. For security reasons, your code should verify that the `state`
67//! // parameter returned by the server matches `csrf_state`.
68//! // CliOAuth: Validation must be performed to acquire the authorization code. CliOAuth handles
69//! // the CSRF verification.
70//! match auth.validate() { // (4)
71//! Ok(AuthContext {
72//! auth_code,
73//! pkce_verifier,
74//! state: _,
75//! }) => {
76//! // Now you can trade it for an access token.
77//! let token_result = client
78//! .exchange_code(auth_code) // (5)
79//! // Set the PKCE code verifier.
80//! .set_pkce_verifier(pkce_verifier)
81//! .request_async(async_http_client)
82//! .await?;
83//! // Unwrapping token_result will either produce a Token or a RequestTokenError.
84//! },
85//! Err(e) => warn!("uh oh! {:?}", e),
86//! }
87//!
88//! # Ok(())
89//! # }
90//! ```
91//!
92//! _Breaking it down..._
93//!
94//! 1. `CliOAuth` construction starts with a [builder](CliOAuthBuilder), which allows you to
95//! customize the way the authorization helper is configured. See the builder doc for more details
96//! about configuration.
//! 2. `CliOAuth` constructs the redirect URL from the local address & port its callback server
//! listens on. That URL is passed to the [`oauth2::Client`] during construction.
99//! 3. Invoking the [`CliOAuth::authorize`] method will do the following things:
100//! - Launch a local web server
101//! - Generate the CSRF protection token (`state` parameter)
102//! - Open the user's browser with the URL to initiate the authorization flow
103//! - Receive the redirect from the IdP that contains the incoming authorization code
104//! - Shutdown the local web server
105//! 4. Invoking the [`CliOAuth::validate`] method will verify that an auth code was received and
106//! that the `state` parameter matches the expected value. If validation succeeds, the auth code and
107//! PKCE verifier will be returned to the caller.
108//! 5. The auth code and PKCE verifier are provided to the
109//! [exchange code](oauth2::Client::exchange_code) flow.
110//!
111//! [1]: https://www.rfc-editor.org/rfc/rfc7636
112//! [2]: https://crates.io/crates/oauth2
113
114use std::fmt::{Debug, Formatter};
115use std::net::{IpAddr, SocketAddr, TcpListener};
116use std::ops::Range;
117use std::sync::{Arc, Mutex};
118use std::time::Duration;
119
120use log::debug;
121use oauth2::{
122 AuthorizationCode, CsrfToken, ErrorResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
123 RevocableToken, Scope, TokenIntrospectionResponse, TokenResponse, TokenType,
124};
125use tokio::runtime::Handle;
126use url::Url;
127
128pub use crate::builder::CliOAuthBuilder;
129pub use crate::error::{AuthError, ConfigError, ServerError};
130use crate::server::launch;
131use crate::ConfigError::CannotBindAddress;
132
133mod builder;
134mod error;
135mod server;
136
/// A range of candidate TCP ports. Being a [`Range`], it is half-open: `start` is included and
/// `end` is excluded.
pub(crate) type PortRange = Range<u16>;
/// A shortcut [`Result`] using an error of [`ConfigError`].
pub type ConfigResult<T> = Result<T, ConfigError>;
// NOTE(review): not referenced in the visible part of this file; presumably shared with the
// `server` module so the callback handler can hand the redirect values back — confirm.
/// A shared, mutex-guarded, optionally-filled slot holding an [`AuthorizationResult`].
type AuthorizationResultHolder = Arc<Mutex<Option<AuthorizationResult>>>;
141
/// The CLI OAuth helper.
///
/// Create one with [`CliOAuth::builder`], pass [`CliOAuth::redirect_url`] to the
/// [`oauth2::Client`], then call [`CliOAuth::authorize`] followed by [`CliOAuth::validate`].
#[derive(Debug)]
pub struct CliOAuth {
    /// Local socket address the callback server is launched on; the redirect URL is built
    /// from it.
    address: SocketAddr,
    /// Passed to the callback server as `Duration::from_secs(timeout)`; presumably how long to
    /// wait for the redirect — confirm against the `server` module.
    timeout: u64,
    /// Scopes added to the authorization URL by [`CliOAuth::authorize`].
    scopes: Vec<Scope>,
    /// Code, CSRF token and PKCE verifier recorded by `authorize`; taken by `validate`.
    auth_context: Option<AuthContext>,
    /// Raw `code`/`state` values received on the redirect; recorded by `authorize` and taken by
    /// `validate` to check the CSRF token.
    auth_result: Option<AuthorizationResult>,
}
151
152impl CliOAuth {
153 /// Constructs a new builder struct for configuration.
154 pub fn builder() -> CliOAuthBuilder {
155 CliOAuthBuilder::new()
156 }
157
158 /// Generates the redirect URL that will sent in the authorization URL to the identity
159 /// provider.
160 ///
161 /// Pass the result of this method to [`oauth2::Client::set_redirect_uri`] while building the
162 /// client.
163 pub fn redirect_url(&self) -> RedirectUrl {
164 let url = format!("http://{}", self.address);
165 RedirectUrl::from_url(Url::parse(&url).unwrap())
166 }
167
168 /// Initiates the Authorization Code flow.
169 ///
170 /// The PKCE challenge and verifier are generated. The challenge is used in the authorization
171 /// URL, and the verifier is saved for the validation step.
172 ///
173 /// The user's browser is then opened to the authorization URL, and the authorization code (`code`) and CSRF token
174 /// (`state`) are extracted from the redirect request and recorded . These values will also be used in the
175 /// validation step, and then returned to the caller for the token exchange.
176 #[cfg(not(tarpaulin_include))]
177 pub async fn authorize<TE, TR, TT, TIR, RT, TRE>(
178 &mut self,
179 oauth_client: &oauth2::Client<TE, TR, TT, TIR, RT, TRE>,
180 ) -> Result<(), ServerError>
181 where
182 TE: ErrorResponse + 'static,
183 TR: TokenResponse<TT>,
184 TT: TokenType,
185 TIR: TokenIntrospectionResponse<TT>,
186 RT: RevocableToken,
187 TRE: ErrorResponse + 'static,
188 {
189 let scopes: Vec<Scope> = self.scopes.to_vec();
190 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
191 let (auth_url, state) = oauth_client
192 .authorize_url(CsrfToken::new_random)
193 .add_scopes(scopes)
194 .set_pkce_challenge(pkce_challenge)
195 .url();
196
197 // Acquire handle to Tokio runtime
198 let handle = Handle::try_current()?;
199 let server = handle.spawn(launch(self.address, Duration::from_secs(self.timeout)));
200
201 debug!("🔑 authorization URL: {}", auth_url);
202 open::that(auth_url.as_str())?;
203
204 let result = server.await?;
205
206 match result {
207 Ok(auth_result) => {
208 self.auth_result = Some(auth_result.clone());
209 let auth_ctx = AuthContext {
210 auth_code: AuthorizationCode::new(auth_result.auth_code.clone()),
211 state,
212 pkce_verifier,
213 };
214 self.auth_context = Some(auth_ctx);
215 Ok(())
216 }
217 Err(e) => Err(e),
218 }
219 }
220
221 /// Validates the authorization code and CSRF token (`state`).
222 ///
223 /// If validation is successful, then the code and PKCE verifier are returned to the caller in
224 /// order to build the [exchange code](oauth2::Client::exchange_code) request.
225 ///
226 /// This method *must* be called after [`CliOAuth::authorize`] completes successfully.
227 pub fn validate(&mut self) -> Result<AuthContext, AuthError> {
228 let expected_state = self
229 .auth_result
230 .take()
231 .ok_or(AuthError::InvalidAuthState)?
232 .state;
233 match self.auth_context.take() {
234 Some(auth_ctx) if auth_ctx.state.secret() == &expected_state => Ok(auth_ctx),
235 Some(_) => Err(AuthError::CsrfMismatch),
236 None => Err(AuthError::InvalidAuthState),
237 }
238 }
239}
240
/// Holds intermediate values needed to complete the authorization flow.
///
/// These values are generated during the [authorize](CliOAuth::authorize) step, and
/// provided to the caller after [validation](CliOAuth::validate). They can then be used for the
/// [code exchange](oauth2::Client::exchange_code).
#[derive(Debug)]
pub struct AuthContext {
    /// The authorization code obtained from the Authorize step.
    pub auth_code: AuthorizationCode,
    /// The CSRF token (`state` parameter) generated for the authorization URL. Validation
    /// compares it with the `state` value received on the redirect.
    pub state: CsrfToken,
    /// The PKCE verifier that will be supplied to the Exchange Code step.
    pub pkce_verifier: PkceCodeVerifier,
}
254
/// The raw authorization code (`code`) and CSRF token (`state`) received on the redirect.
#[derive(Clone)]
struct AuthorizationResult {
    /// Authorization code as received from the identity provider.
    pub auth_code: String,
    /// `state` value as received from the identity provider.
    pub state: String,
}

/// Masked debug output: only the first three characters of each secret are printed, followed
/// by a fixed `*****` marker, so logging this value never exposes the full secrets.
impl Debug for AuthorizationResult {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        fn masked(secret: &str) -> String {
            secret.chars().take(3).collect()
        }

        write!(
            f,
            "auth code={}*****, state={}*****",
            masked(&self.auth_code),
            masked(&self.state),
        )
    }
}
270
// NOTE(review): none of these constants is referenced in the visible part of this file; the
// descriptions below are inferred from their names and should be confirmed against the builder.
/// Presumably the lowest port a user may configure (ports below 1024 are privileged on most
/// Unix systems) — confirm.
const PORT_MIN: u16 = 1024;
/// Presumably the first port of the default search range — confirm.
const DEFAULT_PORT_MIN: u16 = 3456;
/// Ten ports above `DEFAULT_PORT_MIN`; presumably the (excluded, since `PortRange` is a
/// half-open `Range`) end of the default search range — confirm.
const DEFAULT_PORT_MAX: u16 = DEFAULT_PORT_MIN + 10;
/// Presumably the default value of `CliOAuth::timeout`, which is interpreted in seconds via
/// `Duration::from_secs` — confirm.
const DEFAULT_TIMEOUT: u64 = 60;
275
276/// Finds an available port within the give range.
277///
278/// Each port will be tried in ascending order. The first port that can successfully bind will be
279/// used, and the resulting socket address will be returned. An error will be returned if no ports
280/// in the range are available.
281///
282/// Note that this function **cannot guarantee** that the address/port combination will be usable by
283/// the server, since any other process on the system could bind to it before this process does.
284fn find_available_port(ip_addr: IpAddr, port_range: PortRange) -> ConfigResult<SocketAddr> {
285 for port in port_range.clone() {
286 let socket_addr = SocketAddr::new(ip_addr, port);
287 if is_address_available(socket_addr) {
288 return Ok(socket_addr);
289 }
290 }
291 Err(CannotBindAddress {
292 addr: ip_addr,
293 port_range,
294 })
295}
296
/// Reports whether this process can bind a TCP listener to `socket_addr` right now.
///
/// The probe listener is dropped immediately, releasing the address again.
fn is_address_available(socket_addr: SocketAddr) -> bool {
    match TcpListener::bind(socket_addr) {
        Ok(_) => true,
        Err(_) => false,
    }
}
301
302// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
303// NOTE! The tests below all use different ports/port ranges, because the order of the tests
304// cannot be guaranteed. If the ports overlap, then tests will fail randomly. Make sure that any
305// future tests use their own unique port values. The best way to do that is with the `next_ports`
306// function to acquire a range of ports for the test.
307// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
308
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener};
    use std::sync::atomic::AtomicU16;
    use std::sync::atomic::Ordering::AcqRel;

    use rstest::{fixture, rstest};

    use crate::{find_available_port, is_address_available, PortRange};

    pub(crate) static LOCALHOST: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
    pub(crate) static PORT_GENERATOR: AtomicU16 = AtomicU16::new(8000);

    // Acquires a block of `count` consecutive port numbers for a test.
    //
    // Any test that needs to perform testing with network ports should call this method at the
    // beginning to get the first and last (both inclusive) ports reserved for the test:
    //
    // ```
    // let (port_start, port_end) = next_ports(5);
    // ```
    //
    // The function is backed by an atomic integer, so each test is guaranteed to get a unique
    // range. `count` is expected to be at least 1.
    pub(crate) fn next_ports(count: u16) -> (u16, u16) {
        let start = PORT_GENERATOR.fetch_add(count, AcqRel);
        let end = start + count - 1;
        (start, end)
    }

    /// Acquires a range containing exactly `count` port numbers for a test.
    ///
    /// This is an alternative to [`next_ports`]. [`PortRange`] is half-open, so one extra port
    /// is reserved as well: `range.end`, the first port *outside* the range, also belongs to the
    /// calling test and can safely be bound as a control port.
    pub(crate) fn port_range(count: u16) -> PortRange {
        // `next_ports` returns an inclusive end; reserving `count + 1` ports makes the half-open
        // range hold `count` ports while keeping `end` owned by this test.
        let (start, end) = next_ports(count + 1);
        start..end
    }

    #[fixture]
    fn one_port() -> PortRange {
        port_range(1)
    }

    #[fixture]
    fn two_ports() -> PortRange {
        port_range(2)
    }

    #[fixture]
    fn three_ports() -> PortRange {
        port_range(3)
    }

    #[rstest]
    fn port_range_has_requested_length(three_ports: PortRange) {
        assert_eq!(3, three_ports.len());
    }

    #[rstest]
    fn find_available_port_with_open_port(three_ports: PortRange) {
        let res = find_available_port(LOCALHOST, three_ports.clone());
        match res {
            Ok(addr) => assert!(three_ports.contains(&addr.port())),
            Err(e) => panic!("error finding available port: {:?}", e),
        }
    }

    #[rstest]
    fn find_available_port_with_no_open_port(two_ports: PortRange) {
        // Occupy every port in the range so that none can be found.
        let _sockets: Vec<TcpListener> = two_ports
            .clone()
            .map(|port| {
                TcpListener::bind(SocketAddr::new(LOCALHOST, port))
                    .expect("test port could not be bound")
            })
            .collect();
        let res = find_available_port(LOCALHOST, two_ports);
        res.expect_err("ports should not be available");
    }

    #[rstest]
    fn find_available_port_skips_used_port(two_ports: PortRange) {
        let _used = TcpListener::bind(SocketAddr::new(LOCALHOST, two_ports.start))
            .expect("test port could not be bound");
        let addr = find_available_port(LOCALHOST, two_ports.clone())
            .expect("second port in the range should be available");
        assert_eq!(two_ports.start + 1, addr.port());
    }

    #[test]
    fn find_available_port_with_empty_range() {
        let (port, _) = next_ports(1);
        find_available_port(LOCALHOST, port..port).expect_err("an empty range has no ports");
    }

    #[rstest]
    fn check_address_is_available_when_port_is_open(two_ports: PortRange) {
        // Binding an unrelated (control) port must not affect the check.
        let _sock = TcpListener::bind(SocketAddr::new(LOCALHOST, two_ports.end))
            .expect("control port could not be bound");
        let address = SocketAddr::new(LOCALHOST, two_ports.start);
        assert!(is_address_available(address));
    }

    #[rstest]
    fn check_address_is_not_available_when_port_is_used(one_port: PortRange) {
        let address = SocketAddr::new(LOCALHOST, one_port.start);
        let _socket = TcpListener::bind(address).expect("test port could not be bound");
        assert!(!is_address_available(address));
    }

    mod cli_oauth {
        use crate::{AuthContext, AuthError, AuthorizationResult, CliOAuth};
        use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier};
        use rstest::{fixture, rstest};

        #[fixture]
        fn auth() -> CliOAuth {
            CliOAuth {
                address: ([127, 0, 0, 1], 8080).into(),
                timeout: 30,
                scopes: vec![],
                auth_context: None,
                auth_result: None,
            }
        }

        #[fixture]
        fn auth_context() -> AuthContext {
            AuthContext {
                state: CsrfToken::new(String::from("state")),
                auth_code: AuthorizationCode::new(String::from("code")),
                pkce_verifier: PkceCodeVerifier::new(String::from("pkce")),
            }
        }

        #[fixture]
        fn auth_result() -> AuthorizationResult {
            AuthorizationResult {
                auth_code: String::from("code"),
                state: String::from("state"),
            }
        }

        #[rstest]
        fn redirect_url_valid(auth: CliOAuth) {
            let url = auth.redirect_url();
            assert_eq!("http://127.0.0.1:8080/", url.as_str());
        }

        #[rstest]
        fn validate_success(
            mut auth: CliOAuth,
            auth_result: AuthorizationResult,
            auth_context: AuthContext,
        ) {
            auth.auth_result = Some(auth_result);
            auth.auth_context = Some(auth_context);
            let ctx = auth.validate().expect("validation should succeed");
            assert_eq!("code", ctx.auth_code.secret().as_str());
            assert_eq!("state", ctx.state.secret().as_str());
            assert_eq!("pkce", ctx.pkce_verifier.secret().as_str());
            // A successful validation consumes the pending flow.
            assert!(matches!(auth.validate(), Err(AuthError::InvalidAuthState)));
        }

        #[rstest]
        fn validate_with_no_context(mut auth: CliOAuth, auth_result: AuthorizationResult) {
            auth.auth_result = Some(auth_result);
            assert!(matches!(auth.validate(), Err(AuthError::InvalidAuthState)));
        }

        #[rstest]
        fn validate_with_no_result(mut auth: CliOAuth, auth_context: AuthContext) {
            auth.auth_context = Some(auth_context);
            assert!(matches!(auth.validate(), Err(AuthError::InvalidAuthState)));
        }

        #[rstest]
        fn validate_state_mismatch(
            mut auth: CliOAuth,
            mut auth_result: AuthorizationResult,
            auth_context: AuthContext,
        ) {
            auth_result.state = String::from("other_state");
            auth.auth_result = Some(auth_result);
            auth.auth_context = Some(auth_context);
            match auth.validate() {
                Err(AuthError::CsrfMismatch) => (),
                Err(e) => panic!("CsrfMismatch error should be raised, but was {:?}", e),
                Ok(_) => panic!("Validation should fail"),
            };
        }
    }
}