clio_auth/
lib.rs

1//! OAuth 2.0 helper for CLI and desktop applications.
2//!
3//! This package facilitates the [OAuth 2.0 Authorization Code with PKCE][1] flow for command line
4//! and desktop GUI applications. It works hand-in-hand with the [oauth2][2] crate by providing the
5//! "missing pieces" for the flow: a web server to handle the authorization callback, and opening
6//! the browser with the authorization link.
7//!
8//! # Usage
9//!
10//! General usage is as follows:
11//!
12//! 1. Configure a [`CliOAuthBuilder`] and build a [`CliOAuth`] helper
13//! 1. Configure an [`oauth2::Client`]
14//! 1. Start the [authorization flow](CliOAuth::authorize)
15//! 1. [Validate and obtain](CliOAuth::validate) the authorization code
16//! 1. [Exchange the code](oauth2::Client::exchange_code) for a token
17//!
18//! # Example
19//!
20//! This example is adapted directly from the [`oauth2`] package documentation ("Asynchronous API"),
21//! and demonstrates how `CliOAuth` fills in the gaps.
22//!
23//! ```no_run
24//! use anyhow;
25//! use oauth2::{
26//!     AuthorizationCode,
27//!     AuthUrl,
28//!     ClientId,
29//!     ClientSecret,
30//!     CsrfToken,
31//!     PkceCodeChallenge,
32//!     RedirectUrl,
33//!     Scope,
34//!     TokenResponse,
35//!     TokenUrl
36//! };
37//! use oauth2::basic::BasicClient;
38//! # #[cfg(feature = "reqwest")]
39//! use oauth2::reqwest::async_http_client;
//! use url::Url;
//! use clio_auth::AuthContext;
//! use log::{info, warn};
//!
//! # #[cfg(feature = "reqwest")]
//! # async fn err_wrapper() -> Result<(), anyhow::Error> {
//! // CliOAuth: Build helper with default options
//! let mut auth = clio_auth::CliOAuth::builder().build().unwrap();        // (1)
//! // Create an OAuth2 client by specifying the client ID, client secret, authorization URL and
//! // token URL.
//! let client =
//!     BasicClient::new(
//!         ClientId::new("client_id".to_string()),
//!         Some(ClientSecret::new("client_secret".to_string())),
//!         AuthUrl::new("http://authorize".to_string())?,
//!         Some(TokenUrl::new("http://token".to_string())?)
//!     )
//!     // CliOAuth: Use the local redirect URL
//!     .set_redirect_uri(auth.redirect_url());                           // (2)
//!
//! // CliOAuth: The PKCE challenge is handled internally. Just authorize... (3)
//! match auth.authorize(&client).await {
60//!     Ok(()) => info!("authorized successfully"),
61//!     Err(e) => warn!("uh oh! {:?}", e),
62//! };
63//! // CliOAuth: The browser is opened to the authorization URL              (3)
64//!
65//! // Once the user has been redirected to the redirect URL, you'll have access to the
66//! // authorization code. For security reasons, your code should verify that the `state`
67//! // parameter returned by the server matches `csrf_state`.
68//! // CliOAuth: Validation must be performed to acquire the authorization code. CliOAuth handles
69//! // the CSRF verification.
70//! match auth.validate() {                                              // (4)
71//!     Ok(AuthContext {
72//!         auth_code,
73//!         pkce_verifier,
74//!         state: _,
75//!     }) => {
76//!         // Now you can trade it for an access token.
77//!         let token_result = client
78//!             .exchange_code(auth_code)                                // (5)
79//!             // Set the PKCE code verifier.
80//!             .set_pkce_verifier(pkce_verifier)
81//!             .request_async(async_http_client)
82//!             .await?;
83//!         // Unwrapping token_result will either produce a Token or a RequestTokenError.
84//!     },
85//!     Err(e) => warn!("uh oh! {:?}", e),
86//! }
87//!
88//! # Ok(())
89//! # }
90//! ```
91//!
92//! _Breaking it down..._
93//!
94//! 1. `CliOAuth` construction starts with a [builder](CliOAuthBuilder), which allows you to
95//! customize the way the authorization helper is configured. See the builder doc for more details
96//! about configuration.
//! 2. `CliOAuth` constructs the [redirect URL](CliOAuth::redirect_url) based on the address & port
//! that its local callback server listens on. The URL is provided to the [`oauth2::Client`] during
//! construction.
99//! 3. Invoking the [`CliOAuth::authorize`] method will do the following things:
100//!    - Launch a local web server
101//!    - Generate the CSRF protection token (`state` parameter)
102//!    - Open the user's browser with the URL to initiate the authorization flow
103//!    - Receive the redirect from the IdP that contains the incoming authorization code
104//!    - Shutdown the local web server
105//! 4. Invoking the [`CliOAuth::validate`] method will verify that an auth code was received and
106//! that the `state` parameter matches the expected value. If validation succeeds, the auth code and
107//! PKCE verifier will be returned to the caller.
108//! 5. The auth code and PKCE verifier are provided to the
109//! [exchange code](oauth2::Client::exchange_code) flow.
110//!
111//! [1]: https://www.rfc-editor.org/rfc/rfc7636
112//! [2]: https://crates.io/crates/oauth2
113
114use std::fmt::{Debug, Formatter};
115use std::net::{IpAddr, SocketAddr, TcpListener};
116use std::ops::Range;
117use std::sync::{Arc, Mutex};
118use std::time::Duration;
119
120use log::debug;
121use oauth2::{
122    AuthorizationCode, CsrfToken, ErrorResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
123    RevocableToken, Scope, TokenIntrospectionResponse, TokenResponse, TokenType,
124};
125use tokio::runtime::Handle;
126use url::Url;
127
128pub use crate::builder::CliOAuthBuilder;
129pub use crate::error::{AuthError, ConfigError, ServerError};
130use crate::server::launch;
131use crate::ConfigError::CannotBindAddress;
132
133mod builder;
134mod error;
135mod server;
136
137pub(crate) type PortRange = Range<u16>;
138/// A shortcut [`Result`] using an error of [`ConfigError`].
139pub type ConfigResult<T> = Result<T, ConfigError>;
140type AuthorizationResultHolder = Arc<Mutex<Option<AuthorizationResult>>>;
141
/// The CLI OAuth helper.
///
/// Holds the configuration for the local callback server and the intermediate state of the flow
/// between [`CliOAuth::authorize`] and [`CliOAuth::validate`]. Create one with
/// [`CliOAuth::builder`].
#[derive(Debug)]
pub struct CliOAuth {
    /// Socket address the local callback server is launched on; also the host and port of the
    /// redirect URL.
    address: SocketAddr,
    /// Timeout, in seconds, handed to the callback server when it is launched.
    timeout: u64,
    /// Scopes added to the authorization URL.
    scopes: Vec<Scope>,
    /// Code, CSRF token, and PKCE verifier recorded by `authorize`; handed out by `validate`.
    auth_context: Option<AuthContext>,
    /// Raw `code`/`state` values received on the redirect, recorded by `authorize`.
    auth_result: Option<AuthorizationResult>,
}
151
152impl CliOAuth {
153    /// Constructs a new builder struct for configuration.
154    pub fn builder() -> CliOAuthBuilder {
155        CliOAuthBuilder::new()
156    }
157
158    /// Generates the redirect URL that will sent in the authorization URL to the identity
159    /// provider.
160    ///
161    /// Pass the result of this method to [`oauth2::Client::set_redirect_uri`] while building the
162    /// client.
163    pub fn redirect_url(&self) -> RedirectUrl {
164        let url = format!("http://{}", self.address);
165        RedirectUrl::from_url(Url::parse(&url).unwrap())
166    }
167
168    /// Initiates the Authorization Code flow.
169    ///
170    /// The PKCE challenge and verifier are generated. The challenge is used in the authorization
171    /// URL, and the verifier is saved for the validation step.
172    ///
173    /// The user's browser is then opened to the authorization URL, and the authorization code (`code`) and CSRF token
174    /// (`state`) are extracted from the redirect request and recorded . These values will also be used in the
175    /// validation step, and then returned to the caller for the token exchange.
176    #[cfg(not(tarpaulin_include))]
177    pub async fn authorize<TE, TR, TT, TIR, RT, TRE>(
178        &mut self,
179        oauth_client: &oauth2::Client<TE, TR, TT, TIR, RT, TRE>,
180    ) -> Result<(), ServerError>
181    where
182        TE: ErrorResponse + 'static,
183        TR: TokenResponse<TT>,
184        TT: TokenType,
185        TIR: TokenIntrospectionResponse<TT>,
186        RT: RevocableToken,
187        TRE: ErrorResponse + 'static,
188    {
189        let scopes: Vec<Scope> = self.scopes.to_vec();
190        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
191        let (auth_url, state) = oauth_client
192            .authorize_url(CsrfToken::new_random)
193            .add_scopes(scopes)
194            .set_pkce_challenge(pkce_challenge)
195            .url();
196
197        // Acquire handle to Tokio runtime
198        let handle = Handle::try_current()?;
199        let server = handle.spawn(launch(self.address, Duration::from_secs(self.timeout)));
200
201        debug!("🔑 authorization URL: {}", auth_url);
202        open::that(auth_url.as_str())?;
203
204        let result = server.await?;
205
206        match result {
207            Ok(auth_result) => {
208                self.auth_result = Some(auth_result.clone());
209                let auth_ctx = AuthContext {
210                    auth_code: AuthorizationCode::new(auth_result.auth_code.clone()),
211                    state,
212                    pkce_verifier,
213                };
214                self.auth_context = Some(auth_ctx);
215                Ok(())
216            }
217            Err(e) => Err(e),
218        }
219    }
220
221    /// Validates the authorization code and CSRF token (`state`).
222    ///
223    /// If validation is successful, then the code and PKCE verifier are returned to the caller in
224    /// order to build the [exchange code](oauth2::Client::exchange_code) request.
225    ///
226    /// This method *must* be called after [`CliOAuth::authorize`] completes successfully.
227    pub fn validate(&mut self) -> Result<AuthContext, AuthError> {
228        let expected_state = self
229            .auth_result
230            .take()
231            .ok_or(AuthError::InvalidAuthState)?
232            .state;
233        match self.auth_context.take() {
234            Some(auth_ctx) if auth_ctx.state.secret() == &expected_state => Ok(auth_ctx),
235            Some(_) => Err(AuthError::CsrfMismatch),
236            None => Err(AuthError::InvalidAuthState),
237        }
238    }
239}
240
/// Holds intermediate values needed to complete the authorization flow.
///
/// These values are generated during the [authorize](CliOAuth::authorize) step, and
/// provided to the caller after [validation](CliOAuth::validate). They can then be used for the
/// [code exchange](oauth2::Client::exchange_code).
#[derive(Debug)]
pub struct AuthContext {
    /// The authorization code obtained from the Authorize step.
    pub auth_code: AuthorizationCode,
    /// The CSRF token that was sent as the `state` parameter of the authorization URL.
    /// [`CliOAuth::validate`] only returns this context after comparing it to the `state`
    /// value received on the redirect.
    pub state: CsrfToken,
    /// The PKCE verifier that will be supplied to the Exchange Code step.
    pub pkce_verifier: PkceCodeVerifier,
}
254
/// The raw values extracted from the authorization redirect request.
///
/// Both values are sensitive; the hand-written `Debug` implementation masks all but their first
/// few characters.
#[derive(Clone)]
struct AuthorizationResult {
    /// The authorization code (`code`) taken from the redirect.
    pub(crate) auth_code: String,
    /// The CSRF token (`state`) echoed back on the redirect.
    pub(crate) state: String,
}
260
261impl Debug for AuthorizationResult {
262    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263        f.write_fmt(format_args!(
264            "auth code={}*****, state={}*****",
265            self.auth_code.chars().take(3).collect::<String>(),
266            self.state.chars().take(3).collect::<String>(),
267        ))
268    }
269}
270
/// Lowest port number considered usable. Ports below 1024 are privileged on Unix-like systems.
/// NOTE(review): presumably enforced by the builder — confirm in `builder.rs`.
const PORT_MIN: u16 = 1_024;
/// First port of the default candidate port range.
const DEFAULT_PORT_MIN: u16 = 3_456;
/// Upper bound of the default candidate port range, ten ports above its start.
const DEFAULT_PORT_MAX: u16 = DEFAULT_PORT_MIN + 10;
/// Default callback-server timeout; presumably seconds, matching `CliOAuth::timeout`.
const DEFAULT_TIMEOUT: u64 = 60;
275
276/// Finds an available port within the give range.
277///
278/// Each port will be tried in ascending order. The first port that can successfully bind will be
279/// used, and the resulting socket address will be returned. An error will be returned if no ports
280/// in the range are available.
281///
282/// Note that this function **cannot guarantee** that the address/port combination will be usable by
283/// the server, since any other process on the system could bind to it before this process does.
284fn find_available_port(ip_addr: IpAddr, port_range: PortRange) -> ConfigResult<SocketAddr> {
285    for port in port_range.clone() {
286        let socket_addr = SocketAddr::new(ip_addr, port);
287        if is_address_available(socket_addr) {
288            return Ok(socket_addr);
289        }
290    }
291    Err(CannotBindAddress {
292        addr: ip_addr,
293        port_range,
294    })
295}
296
/// Reports whether this process can currently bind a TCP listener to `socket_addr`.
///
/// The probe listener is closed again immediately; the port is not reserved.
fn is_address_available(socket_addr: SocketAddr) -> bool {
    match TcpListener::bind(socket_addr) {
        Ok(_probe) => true,
        Err(_) => false,
    }
}
301
302// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
303// NOTE! The tests below all use different ports/port ranges, because the order of the tests
304// cannot be guaranteed. If the ports overlap, then tests will fail randomly. Make sure that any
305// future tests use their own unique port values. The best way to do that is with the `next_ports`
306// function to acquire a range of ports for the test.
307// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
308
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener};
    use std::sync::atomic::AtomicU16;
    use std::sync::atomic::Ordering::AcqRel;

    use rstest::{fixture, rstest};

    use crate::{find_available_port, is_address_available, PortRange};

    pub(crate) static LOCALHOST: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
    pub(crate) static PORT_GENERATOR: AtomicU16 = AtomicU16::new(8000);

    // Acquires a range of port numbers for a test.
    //
    // Any test that needs to perform testing with network ports should call this method at the
    // beginning to get the next start and end ports for the test:
    //
    // ```
    // let (port_start, port_end) = next_ports(5);
    // ```
    //
    // Both returned ports are inclusive: the caller owns the `count` ports `start..=end`.
    // `count` should be at least 1; with 0 the returned `end` is smaller than `start`.
    //
    // The function is backed by an atomic integer, so each test is guaranteed to get a unique
    // range.
    pub(crate) fn next_ports(count: u16) -> (u16, u16) {
        let start = PORT_GENERATOR.fetch_add(count, AcqRel);
        let end = start + count - 1;
        (start, end)
    }

    /// Acquires a range of port numbers for a test.
    ///
    /// This is an alternative to [`next_ports`]. The returned `start..end` range is half-open,
    /// so it contains only `count - 1` ports. The port `end` is still reserved for the caller,
    /// because [`next_ports`] allocated it, but it lies outside the range.
    pub(crate) fn port_range(count: u16) -> PortRange {
        let (start, end) = next_ports(count);
        start..end
    }

    #[fixture]
    fn one_port() -> PortRange {
        port_range(1)
    }

    #[fixture]
    fn two_ports() -> PortRange {
        port_range(2)
    }

    #[fixture]
    fn three_ports() -> PortRange {
        port_range(3)
    }

    #[rstest]
    fn find_available_port_with_open_port(three_ports: PortRange) {
        let res = find_available_port(LOCALHOST, three_ports.clone());
        match res {
            Ok(addr) => assert!(three_ports.contains(&addr.port())),
            Err(e) => panic!("error finding available port: {:?}", e),
        }
    }

    #[rstest]
    fn find_available_port_with_no_open_port(two_ports: PortRange) {
        // Acquire sockets on both ports we need
        let _s1 = TcpListener::bind(SocketAddr::new(LOCALHOST, two_ports.start)).unwrap();
        let _s2 = TcpListener::bind(SocketAddr::new(LOCALHOST, two_ports.end)).unwrap();
        let res = find_available_port(LOCALHOST, two_ports);
        res.expect_err("ports should not be available");
    }

    #[rstest]
    fn check_address_is_available_when_port_is_open(two_ports: PortRange) {
        let control_port = two_ports.end;
        // `expect` does not interpolate, so build the panic message explicitly.
        let _sock = TcpListener::bind(SocketAddr::new(LOCALHOST, control_port))
            .unwrap_or_else(|e| panic!("control port {} is already in use: {}", control_port, e));
        let address = SocketAddr::new(LOCALHOST, two_ports.start);
        assert!(is_address_available(address));
    }

    #[rstest]
    fn check_address_is_not_available_when_port_is_used(one_port: PortRange) {
        let _socket = TcpListener::bind(SocketAddr::new(LOCALHOST, one_port.end)).expect(
            "port is already \
            open",
        );
        let address = SocketAddr::new(LOCALHOST, one_port.start);
        assert!(!is_address_available(address));
    }

    mod cli_oauth {
        use crate::{AuthContext, AuthError, AuthorizationResult, CliOAuth};
        use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier};
        use rstest::{fixture, rstest};

        #[fixture]
        fn auth() -> CliOAuth {
            CliOAuth {
                address: ([127, 0, 0, 1], 8080).into(),
                timeout: 30,
                scopes: vec![],
                auth_context: None,
                auth_result: None,
            }
        }

        #[fixture]
        fn auth_context() -> AuthContext {
            AuthContext {
                state: CsrfToken::new(String::from("state")),
                auth_code: AuthorizationCode::new(String::from("code")),
                pkce_verifier: PkceCodeVerifier::new(String::from("pkce")),
            }
        }

        #[fixture]
        fn auth_result() -> AuthorizationResult {
            AuthorizationResult {
                auth_code: String::from("code"),
                state: String::from("state"),
            }
        }

        #[rstest]
        fn redirect_url_valid(auth: CliOAuth) {
            let url = auth.redirect_url();
            assert_eq!("http://127.0.0.1:8080/", url.as_str());
        }

        #[rstest]
        fn validate_with_nothing_recorded(mut auth: CliOAuth) {
            match auth.validate() {
                Err(AuthError::InvalidAuthState) => (),
                Err(e) => panic!("InvalidAuthState error should be raised, but was {:?}", e),
                Ok(_) => panic!("Validation should fail"),
            };
        }

        #[rstest]
        fn validate_with_no_context(mut auth: CliOAuth, auth_result: AuthorizationResult) {
            auth.auth_result = Some(auth_result);
            assert!(auth.validate().is_err());
        }

        #[rstest]
        fn validate_with_no_result(mut auth: CliOAuth, auth_context: AuthContext) {
            auth.auth_context = Some(auth_context);
            assert!(auth.validate().is_err());
        }

        #[rstest]
        fn validate_success(
            mut auth: CliOAuth,
            auth_result: AuthorizationResult,
            auth_context: AuthContext,
        ) {
            auth.auth_result = Some(auth_result);
            auth.auth_context = Some(auth_context);
            let ctx = auth.validate().expect("matching state should validate");
            assert_eq!(ctx.auth_code.secret(), "code");
            assert_eq!(ctx.pkce_verifier.secret(), "pkce");
            // The recorded state is consumed; it cannot be validated a second time.
            assert!(auth.validate().is_err());
        }

        #[rstest]
        fn validate_state_mismatch(
            mut auth: CliOAuth,
            mut auth_result: AuthorizationResult,
            auth_context: AuthContext,
        ) {
            auth_result.state = String::from("other_state");
            auth.auth_result = Some(auth_result);
            auth.auth_context = Some(auth_context);
            match auth.validate() {
                Err(AuthError::CsrfMismatch) => (),
                Err(e) => panic!("CsrfMismatch error should be raised, but was {:?}", e),
                Ok(_) => panic!("Validation should fail"),
            };
        }
    }
}