use crate::cli::output::{
AppEmoji, finish_spinner_error, finish_spinner_success, print_step, start_spinner,
};
use crate::services::authentication::authenticator::Authenticator;
use crate::services::authentication::urls::IdentityProvider;
use crate::services::http_client::client::create_http_client;
use crate::services::local_server::server::start_local_server;
use miette::{Context, IntoDiagnostic, Result};
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope,
TokenResponse, TokenUrl, basic::BasicClient,
};
use tokio::sync::mpsc::Receiver;
/// Interactive OAuth 2.0 authorization code flow (with a PKCE S256 challenge)
/// that receives the provider's redirect on a local callback server.
pub struct AuthorizationCodeFlow {
    /// Identity provider supplying the authorization URL, token URL and an
    /// optional `audience` value.
    pub provider: IdentityProvider,
    /// OAuth client identifier sent to the provider.
    pub client_id: String,
    /// Scopes added to the authorization request, one `Scope` per entry.
    pub scopes: Vec<String>,
    /// Port for the local callback server; the redirect URI is
    /// `http://localhost:{port}/callback`.
    pub port: u16,
}
impl Authenticator for AuthorizationCodeFlow {
    /// Runs the interactive authorization code flow with PKCE and returns the
    /// access token.
    ///
    /// The flow:
    /// 1. Builds the authorization URL with a PKCE S256 challenge, a random CSRF
    ///    state, the optional `audience` parameter and the configured scopes.
    /// 2. Starts the local callback server on `self.port`.
    /// 3. Opens the browser, or prints the URL if the browser cannot be opened.
    /// 4. Waits for the authorization code, then shuts the callback server down.
    /// 5. Exchanges the code plus the PKCE verifier at the provider's token endpoint.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the authorization, token or redirect URL is invalid;
    /// - the local server cannot be started;
    /// - no authorization code is received (timeout, server-reported error or
    ///   closed channel);
    /// - the HTTP client cannot be created;
    /// - the token exchange fails.
    async fn get_token(&self) -> Result<String> {
        let auth_uri = AuthUrl::new(self.provider.auth_url())
            .into_diagnostic()
            .wrap_err("Invalid authorization URL")?;
        let token_uri = TokenUrl::new(self.provider.token_url())
            .into_diagnostic()
            .wrap_err("Invalid token URL")?;
        let redirect_url = RedirectUrl::new(format!("http://localhost:{}/callback", self.port))
            .into_diagnostic()
            .wrap_err("Invalid redirect URL")?;
        let client = BasicClient::new(ClientId::new(self.client_id.clone()))
            .set_auth_uri(auth_uri)
            .set_token_uri(token_uri)
            .set_redirect_uri(redirect_url);

        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
        let mut auth_req = client
            .authorize_url(CsrfToken::new_random)
            .set_pkce_challenge(pkce_challenge);
        if let Some(audience) = self.provider.audience() {
            auth_req = auth_req.add_extra_param("audience", audience);
        }
        for scope in &self.scopes {
            auth_req = auth_req.add_scope(Scope::new(scope.clone()));
        }
        // NOTE(review): the CSRF state is never compared with the `state` query
        // parameter on the callback, because the local server's channel only
        // carries the code. Forwarding `state` from the server would allow
        // verifying it (RFC 6749 §10.12).
        let (authorize_url, _csrf_state) = auth_req.url();

        let (mut rx, server_handle) = start_local_server(self.port).await?;
        print_step(AppEmoji::Rocket, "Opening browser...");
        if webbrowser::open(authorize_url.as_str()).is_err() {
            // No usable browser: let the user open the URL by hand.
            println!("Please open: {}", authorize_url);
        }

        // Abort the callback server on every outcome, not just on success.
        // Propagating a timeout or browser error with `?` before this call would
        // skip the abort; a dropped tokio `JoinHandle` only detaches its task.
        let code_result = self.wait_for_code(&mut rx).await;
        server_handle.abort();
        let code = code_result?;

        let http_client = create_http_client()?;
        let token_result = client
            .exchange_code(AuthorizationCode::new(code))
            .set_pkce_verifier(pkce_verifier)
            .request_async(&http_client)
            .await
            .into_diagnostic()
            .wrap_err("Failed to exchange Authorization Code for Access Token")?;
        Ok(token_result.access_token().secret().clone())
    }
}
impl AuthorizationCodeFlow {
    /// Waits for the local callback server to deliver the result of the browser
    /// login, showing a spinner in the meantime.
    ///
    /// `rx` yields `Ok(code)` with the authorization code, or `Err(message)`
    /// describing an error reported on the callback. The spinner is finished
    /// with a success or error state on every exit path.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - nothing arrives within 120 seconds;
    /// - the channel is closed before a message is received;
    /// - the server reports an error (its message is included in the returned error).
    async fn wait_for_code(&self, rx: &mut Receiver<Result<String, String>>) -> Result<String> {
        let spinner = start_spinner("Waiting for authentication...")?;

        // Collapse every outcome into one Result before touching the spinner.
        // Previously the timeout and closed-channel cases returned early via `?`
        // and never finished the spinner.
        let received =
            tokio::time::timeout(std::time::Duration::from_secs(120), rx.recv()).await;
        let outcome = match received {
            Err(_elapsed) => Err(miette::miette!(
                help = "Check your browser and try again",
                "Authentication timed out after 120 seconds"
            )),
            Ok(None) => Err(miette::miette!(
                "Failed to receive communication from local server"
            )),
            Ok(Some(Ok(code))) => Ok(code),
            Ok(Some(Err(err_msg))) => {
                Err(miette::miette!("Browser authentication error: {}", err_msg))
            }
        };

        match outcome {
            Ok(code) => {
                finish_spinner_success(&spinner, "Authentication successful!");
                Ok(code)
            }
            Err(err) => {
                finish_spinner_error(&spinner, "Authentication failed!");
                Err(err)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // Exercises `wait_for_code` with an in-memory channel standing in for the
    // local callback server.
    use super::*;
    use tokio::sync::mpsc;

    /// Turns off console colour output before a test runs.
    fn disable_colors() {
        console::set_colors_enabled(false);
    }

    /// Builds a flow with placeholder settings; `wait_for_code` never reads them.
    fn create_dummy_flow() -> AuthorizationCodeFlow {
        AuthorizationCodeFlow {
            provider: IdentityProvider::Microsoft {
                tenant_id: String::from("common"),
            },
            client_id: String::from("dummy_client"),
            scopes: Vec::new(),
            port: 3000,
        }
    }

    #[tokio::test]
    async fn test_wait_for_code_success() {
        disable_colors();
        let flow = create_dummy_flow();
        let (sender, mut receiver) = mpsc::channel(1);
        sender
            .send(Ok(String::from("valid_auth_code_123")))
            .await
            .unwrap();

        let outcome = flow.wait_for_code(&mut receiver).await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "valid_auth_code_123");
    }

    #[tokio::test]
    async fn test_wait_for_code_server_error() {
        disable_colors();
        let flow = create_dummy_flow();
        let (sender, mut receiver) = mpsc::channel(1);
        sender
            .send(Err(String::from("access_denied")))
            .await
            .unwrap();

        let outcome = flow.wait_for_code(&mut receiver).await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("access_denied"));
    }

    #[tokio::test]
    async fn test_wait_for_code_channel_dropped_prematurely() {
        disable_colors();
        let flow = create_dummy_flow();
        let (sender, mut receiver) = mpsc::channel::<Result<String, String>>(1);
        // Closing the only sender makes `recv` yield `None` immediately.
        drop(sender);

        let outcome = flow.wait_for_code(&mut receiver).await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Failed to receive communication"));
    }

    #[tokio::test]
    async fn test_wait_for_code_timeout() {
        disable_colors();
        let flow = create_dummy_flow();
        // The sender must stay alive; otherwise the channel closes instead of
        // timing out.
        let (_keep_alive_sender, mut receiver) = mpsc::channel::<Result<String, String>>(1);
        // Paused time lets the 120 s timeout elapse instantly.
        tokio::time::pause();

        let outcome = flow.wait_for_code(&mut receiver).await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("timed out"));
    }
}