ez_token/services/authentication/
pkce.rs1use crate::cli::output::{
2 AppEmoji, finish_spinner_error, finish_spinner_success, print_step, start_spinner,
3};
4use crate::services::authentication::authenticator::Authenticator;
5use crate::services::authentication::urls::IdentityProvider;
6use crate::services::http_client::client::create_http_client;
7use crate::services::local_server::server::start_local_server;
8use miette::{Context, IntoDiagnostic, Result};
9use oauth2::{
10 AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope,
11 TokenResponse, TokenUrl, basic::BasicClient,
12};
13use tokio::sync::mpsc::Receiver;
14
15pub struct AuthorizationCodeFlow {
37 pub provider: IdentityProvider,
39
40 pub client_id: String,
42
43 pub scopes: Vec<String>,
48
49 pub port: u16,
54}
55
56impl Authenticator for AuthorizationCodeFlow {
57 async fn get_token(&self) -> Result<String> {
77 let auth_uri = AuthUrl::new(self.provider.auth_url())
78 .into_diagnostic()
79 .wrap_err("Invalid authorization URL")?;
80
81 let token_uri = TokenUrl::new(self.provider.token_url())
82 .into_diagnostic()
83 .wrap_err("Invalid token URL")?;
84
85 let redirect_url = RedirectUrl::new(format!("http://localhost:{}/callback", self.port))
86 .into_diagnostic()?;
87
88 let client = BasicClient::new(ClientId::new(self.client_id.clone()))
89 .set_auth_uri(auth_uri)
90 .set_token_uri(token_uri)
91 .set_redirect_uri(redirect_url);
92
93 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
94
95 let mut auth_req = client
96 .authorize_url(CsrfToken::new_random)
97 .set_pkce_challenge(pkce_challenge);
98
99 if let Some(audience) = self.provider.audience() {
100 auth_req = auth_req.add_extra_param("audience", audience);
101 }
102
103 for scope in &self.scopes {
104 auth_req = auth_req.add_scope(Scope::new(scope.clone()));
105 }
106
107 let (authorize_url, _) = auth_req.url();
108 let (mut rx, server_handle) = start_local_server(self.port).await?;
109
110 print_step(AppEmoji::Rocket, "Opening browser...");
111 if webbrowser::open(authorize_url.as_str()).is_err() {
112 println!("Please open: {}", authorize_url);
113 }
114
115 let code = self.wait_for_code(&mut rx).await?;
116 server_handle.abort();
117
118 let http_client = create_http_client()?;
119 let token_result = client
120 .exchange_code(AuthorizationCode::new(code))
121 .set_pkce_verifier(pkce_verifier)
122 .request_async(&http_client)
123 .await
124 .into_diagnostic()
125 .wrap_err("Failed to exchange Authorization Code for Access Token")?;
126
127 Ok(token_result.access_token().secret().clone())
128 }
129}
130
131impl AuthorizationCodeFlow {
132 async fn wait_for_code(&self, rx: &mut Receiver<Result<String, String>>) -> Result<String> {
146 let spinner = start_spinner("Waiting for authentication...")?;
147
148 let result = tokio::time::timeout(std::time::Duration::from_secs(120), rx.recv())
149 .await
150 .map_err(|_| {
151 miette::miette!(
152 help = "Check your browser and try again",
153 "Authentication timed out after 120 seconds"
154 )
155 })?
156 .ok_or_else(|| miette::miette!("Failed to receive communication from local server"))?;
157
158 match result {
159 Ok(code) => {
160 finish_spinner_success(&spinner, "Authentication successful!");
161 Ok(code)
162 }
163 Err(err_msg) => {
164 finish_spinner_error(&spinner, "Authentication failed!");
165 Err(miette::miette!("Browser authentication error: {}", err_msg))
166 }
167 }
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Flow with placeholder settings; `wait_for_code` does not read any of them.
    fn create_dummy_flow() -> AuthorizationCodeFlow {
        let provider = IdentityProvider::Microsoft {
            tenant_id: String::from("common"),
        };
        AuthorizationCodeFlow {
            provider,
            client_id: String::from("dummy_client"),
            scopes: Vec::new(),
            port: 3000,
        }
    }

    /// Turns off terminal colours for spinner output, then builds the dummy flow.
    fn quiet_flow() -> AuthorizationCodeFlow {
        console::set_colors_enabled(false);
        create_dummy_flow()
    }

    /// Single-slot channel shaped like the one the local callback server feeds.
    fn callback_channel() -> (
        mpsc::Sender<Result<String, String>>,
        mpsc::Receiver<Result<String, String>>,
    ) {
        mpsc::channel(1)
    }

    /// Unwraps the error side of `result` as its display text, panicking on `Ok`.
    fn error_text(result: miette::Result<String>) -> String {
        match result {
            Ok(code) => panic!("expected an error, got code {code:?}"),
            Err(report) => report.to_string(),
        }
    }

    #[tokio::test]
    async fn test_wait_for_code_success() {
        let flow = quiet_flow();
        let (tx, mut rx) = callback_channel();
        tx.send(Ok(String::from("valid_auth_code_123")))
            .await
            .unwrap();

        let code = flow
            .wait_for_code(&mut rx)
            .await
            .expect("authorization code should be delivered");
        assert_eq!(code, "valid_auth_code_123");
    }

    #[tokio::test]
    async fn test_wait_for_code_server_error() {
        let flow = quiet_flow();
        let (tx, mut rx) = callback_channel();
        tx.send(Err(String::from("access_denied"))).await.unwrap();

        let message = error_text(flow.wait_for_code(&mut rx).await);
        assert!(message.contains("access_denied"));
    }

    #[tokio::test]
    async fn test_wait_for_code_channel_dropped_prematurely() {
        let flow = quiet_flow();
        let (tx, mut rx) = callback_channel();
        // Dropping the only sender closes the channel, so `recv()` yields `None`.
        drop(tx);

        let message = error_text(flow.wait_for_code(&mut rx).await);
        assert!(message.contains("Failed to receive communication"));
    }

    #[tokio::test]
    async fn test_wait_for_code_timeout() {
        let flow = quiet_flow();
        // Keep the sender alive: the channel stays open, so only the timeout can fire.
        let (_tx, mut rx) = callback_channel();

        // With the clock paused, tokio auto-advances to the next timer when idle,
        // so the timeout elapses without real waiting.
        tokio::time::pause();

        let message = error_text(flow.wait_for_code(&mut rx).await);
        assert!(message.contains("timed out"));
    }
}