Skip to main content

ez_token/services/authentication/
pkce.rs

1use crate::cli::output::{
2    AppEmoji, finish_spinner_error, finish_spinner_success, print_step, start_spinner,
3};
4use crate::services::authentication::authenticator::Authenticator;
5use crate::services::authentication::urls::IdentityProvider;
6use crate::services::http_client::client::create_http_client;
7use crate::services::local_server::server::start_local_server;
8use miette::{Context, IntoDiagnostic, Result};
9use oauth2::{
10    AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope,
11    TokenResponse, TokenUrl, basic::BasicClient,
12};
13use tokio::sync::mpsc::Receiver;
14
15/// An OAuth2 Authorization Code flow with PKCE for supported identity providers.
16///
17/// This flow is designed for interactive user authentication. It opens the
18/// system browser to the provider's authorization endpoint, starts a local
19/// HTTP server to receive the callback, and exchanges the authorization
20/// code for an access token.
21///
22/// # Provider Requirements
23///
24/// ## Microsoft Entra ID
25/// - The application must have **Delegated permissions**
26/// - The redirect URI `http://localhost:{port}/callback` must be registered
27///   under **Mobile and desktop applications** in the Azure Portal
28/// - Public client flows must be enabled
29///
30/// ## Auth0
31/// - **Application Type** must be set to **Native**
32/// - The redirect URI `http://localhost:{port}/callback` must be listed
33///   under **Allowed Callback URLs**
34/// - **Token Endpoint Authentication Method** must be set to **None**
35/// - **Non-Verifiable Callback URI End-User Confirmation** must be **off**
36pub struct AuthorizationCodeFlow {
37    /// The resolved identity provider with all required endpoint data.
38    pub provider: IdentityProvider,
39
40    /// The Application (Client) ID registered in Entra ID.
41    pub client_id: String,
42
43    /// The list of OAuth2 scopes to request.
44    ///
45    /// For Microsoft use delegated scopes (e.g. `User.Read`).
46    /// For Auth0 use explicit scopes (e.g. `read:ez`).
47    pub scopes: Vec<String>,
48
49    /// The local port to listen on for the OAuth2 redirect callback.
50    ///
51    /// Must match the redirect URI registered with your identity provider
52    /// (e.g. `http://localhost:3000/callback`).
53    pub port: u16,
54}
55
56impl Authenticator for AuthorizationCodeFlow {
57    /// Performs the full Authorization Code + PKCE flow and returns an access token.
58    ///
59    /// **Steps:**
60    /// 1. Builds the authorization URL with a PKCE challenge
61    /// 2. Starts a local HTTP server to receive the callback
62    /// 3. Opens the system browser to the authorization URL
63    /// 4. Waits for the authorization code via the local server (120s timeout)
64    /// 5. Exchanges the code for an access token
65    ///
66    /// For Auth0, an `audience` parameter is included automatically in the
67    /// authorization request.
68    ///
69    /// # Errors
70    ///
71    /// Returns an error if:
72    /// - The provider produces an invalid URL
73    /// - The local server fails to bind to the given port
74    /// - The browser authorization is denied or times out
75    /// - The token exchange with the identity provider fails
76    async fn get_token(&self) -> Result<String> {
77        let auth_uri = AuthUrl::new(self.provider.auth_url())
78            .into_diagnostic()
79            .wrap_err("Invalid authorization URL")?;
80
81        let token_uri = TokenUrl::new(self.provider.token_url())
82            .into_diagnostic()
83            .wrap_err("Invalid token URL")?;
84
85        let redirect_url = RedirectUrl::new(format!("http://localhost:{}/callback", self.port))
86            .into_diagnostic()?;
87
88        let client = BasicClient::new(ClientId::new(self.client_id.clone()))
89            .set_auth_uri(auth_uri)
90            .set_token_uri(token_uri)
91            .set_redirect_uri(redirect_url);
92
93        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
94
95        let mut auth_req = client
96            .authorize_url(CsrfToken::new_random)
97            .set_pkce_challenge(pkce_challenge);
98
99        if let Some(audience) = self.provider.audience() {
100            auth_req = auth_req.add_extra_param("audience", audience);
101        }
102
103        for scope in &self.scopes {
104            auth_req = auth_req.add_scope(Scope::new(scope.clone()));
105        }
106
107        let (authorize_url, _) = auth_req.url();
108        let (mut rx, server_handle) = start_local_server(self.port).await?;
109
110        print_step(AppEmoji::Rocket, "Opening browser...");
111        if webbrowser::open(authorize_url.as_str()).is_err() {
112            println!("Please open: {}", authorize_url);
113        }
114
115        let code = self.wait_for_code(&mut rx).await?;
116        server_handle.abort();
117
118        let http_client = create_http_client()?;
119        let token_result = client
120            .exchange_code(AuthorizationCode::new(code))
121            .set_pkce_verifier(pkce_verifier)
122            .request_async(&http_client)
123            .await
124            .into_diagnostic()
125            .wrap_err("Failed to exchange Authorization Code for Access Token")?;
126
127        Ok(token_result.access_token().secret().clone())
128    }
129}
130
131impl AuthorizationCodeFlow {
132    /// Waits for the authorization code to arrive via the local callback server.
133    ///
134    /// Displays a spinner while waiting. The code is sent over a channel by
135    /// the local HTTP server once the browser completes the authorization.
136    /// Times out after 120 seconds — this covers cases where the identity
137    /// provider shows an error page instead of redirecting back to the callback.
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if:
142    /// - Authentication times out after 120 seconds
143    /// - The channel closes before a result is received
144    /// - The browser authorization is denied or the provider returns an error
145    async fn wait_for_code(&self, rx: &mut Receiver<Result<String, String>>) -> Result<String> {
146        let spinner = start_spinner("Waiting for authentication...")?;
147
148        let result = tokio::time::timeout(std::time::Duration::from_secs(120), rx.recv())
149            .await
150            .map_err(|_| {
151                miette::miette!(
152                    help = "Check your browser and try again",
153                    "Authentication timed out after 120 seconds"
154                )
155            })?
156            .ok_or_else(|| miette::miette!("Failed to receive communication from local server"))?;
157
158        match result {
159            Ok(code) => {
160                finish_spinner_success(&spinner, "Authentication successful!");
161                Ok(code)
162            }
163            Err(err_msg) => {
164                finish_spinner_error(&spinner, "Authentication failed!");
165                Err(miette::miette!("Browser authentication error: {}", err_msg))
166            }
167        }
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Builds a flow with placeholder values; `wait_for_code` only needs a
    /// receiver, so the provider, client ID, scopes and port are irrelevant.
    fn create_dummy_flow() -> AuthorizationCodeFlow {
        AuthorizationCodeFlow {
            provider: IdentityProvider::Microsoft {
                tenant_id: "common".to_string(),
            },
            client_id: "dummy_client".to_string(),
            scopes: vec![],
            port: 3000,
        }
    }

    /// Disables terminal colours and runs `wait_for_code` on a dummy flow.
    async fn run_wait(rx: &mut mpsc::Receiver<Result<String, String>>) -> Result<String> {
        console::set_colors_enabled(false);
        let flow = create_dummy_flow();
        flow.wait_for_code(rx).await
    }

    #[tokio::test]
    async fn test_wait_for_code_success() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.send(Ok("valid_auth_code_123".to_string()))
            .await
            .unwrap();

        let code = run_wait(&mut rx)
            .await
            .expect("a queued authorization code should be returned");

        assert_eq!(code, "valid_auth_code_123");
    }

    #[tokio::test]
    async fn test_wait_for_code_server_error() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.send(Err("access_denied".to_string())).await.unwrap();

        let err = run_wait(&mut rx)
            .await
            .expect_err("a server-reported error should surface as Err");

        assert!(err.to_string().contains("access_denied"));
    }

    #[tokio::test]
    async fn test_wait_for_code_channel_dropped_prematurely() {
        let (tx, mut rx) = mpsc::channel::<Result<String, String>>(1);
        drop(tx);

        let err = run_wait(&mut rx)
            .await
            .expect_err("a closed channel should surface as Err");

        assert!(err.to_string().contains("Failed to receive communication"));
    }

    #[tokio::test]
    async fn test_wait_for_code_timeout() {
        // Keep the sender alive so the only way out is the timeout.
        let (_tx, mut rx) = mpsc::channel::<Result<String, String>>(1);

        // With the clock paused, tokio auto-advances time, so the 120s
        // timeout elapses instantly.
        tokio::time::pause();

        let err = run_wait(&mut rx)
            .await
            .expect_err("waiting with no sender activity should time out");

        assert!(err.to_string().contains("timed out"));
    }
}