use anyhow::{Context, Result, anyhow};
use axum::Router;
use axum::extract::Query;
use axum::response::Html;
use axum::routing::get;
use reqwest::Client;
use std::sync::Arc;
use systemprompt_logging::CliService;
use tokio::sync::{Mutex, oneshot};
use crate::OAuthProvider;
use crate::constants::oauth::{CALLBACK_PORT, CALLBACK_TIMEOUT_SECS};
/// Query-string parameters delivered to the local `/callback` route.
///
/// Either `access_token` (success) or `error` plus an optional
/// `error_description` (failure) is expected; every field is optional so a
/// malformed callback still deserializes and can be reported as an error.
/// `Debug` is intentionally not derived so the token cannot leak into logs.
#[derive(serde::Deserialize)]
struct CallbackParams {
    /// Token handed back on a successful login.
    access_token: Option<String>,
    /// Error code reported on failure.
    error: Option<String>,
    /// Optional human-readable detail accompanying `error`.
    error_description: Option<String>,
}
/// JSON body returned by the API's OAuth-initiation endpoint.
#[derive(serde::Deserialize)]
struct AuthorizeResponse {
    /// Provider URL the user's browser should be sent to.
    authorize_url: String,
}
/// Static HTML pages served to the browser by the local OAuth callback server.
///
/// The pages are `&'static str` so they can be embedded at compile time
/// (e.g. via `include_str!`) and served without allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OAuthTemplates {
    /// Page shown after a token was received and handed to the waiting flow.
    pub success_html: &'static str,
    /// Page shown when the callback carried an error, no token, or arrived
    /// after the flow had already received a result.
    pub error_html: &'static str,
}
pub async fn run_oauth_flow(
api_url: &str,
provider: OAuthProvider,
templates: OAuthTemplates,
) -> Result<String> {
let (tx, rx) = oneshot::channel::<Result<String>>();
let tx = Arc::new(Mutex::new(Some(tx)));
let success_html = templates.success_html.to_string();
let error_html = templates.error_html.to_string();
let callback_handler = {
let tx = Arc::clone(&tx);
let success_html = success_html.clone();
let error_html = error_html.clone();
move |Query(params): Query<CallbackParams>| {
let tx = Arc::clone(&tx);
let success_html = success_html.clone();
let error_html = error_html.clone();
async move {
let result = if let Some(error) = params.error {
let desc = params
.error_description
.unwrap_or_else(|| "(no description provided)".into());
Err(anyhow!("OAuth error: {} - {}", error, desc))
} else if let Some(token) = params.access_token {
Ok(token)
} else {
Err(anyhow!("No token received in callback"))
};
let sender = tx.lock().await.take();
if let Some(sender) = sender {
let is_success = result.is_ok();
if sender.send(result).is_err() {
tracing::warn!("OAuth result receiver dropped before result could be sent");
}
if is_success {
Html(success_html)
} else {
Html(error_html)
}
} else {
Html(error_html)
}
}
}
};
let app = Router::new().route("/callback", get(callback_handler));
let addr = format!("127.0.0.1:{CALLBACK_PORT}");
let listener = tokio::net::TcpListener::bind(&addr).await?;
CliService::info(&format!("Starting authentication server on http://{addr}"));
let redirect_uri = format!("http://127.0.0.1:{CALLBACK_PORT}/callback");
CliService::info("Fetching authorization URL...");
let client = Client::new();
let oauth_endpoint = format!(
"{}/api/v1/auth/oauth/{}?redirect_uri={}",
api_url,
provider.as_str(),
urlencoding::encode(&redirect_uri)
);
let response = client
.get(&oauth_endpoint)
.send()
.await
.context("Failed to connect to API")?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_else(|e| {
tracing::warn!(error = %e, "Failed to read OAuth error response body");
format!("(body unreadable: {})", e)
});
return Err(anyhow!(
"Failed to get authorization URL ({}): {}",
status,
body
));
}
let auth_response: AuthorizeResponse = response
.json()
.await
.context("Failed to parse authorization response")?;
let auth_url = auth_response.authorize_url;
CliService::info(&format!(
"Opening browser for {} authentication...",
provider.display_name()
));
CliService::info(&format!("URL: {auth_url}"));
if let Err(e) = open::that(&auth_url) {
CliService::warning(&format!("Could not open browser automatically: {e}"));
CliService::info("Please open this URL manually:");
CliService::key_value("URL", &auth_url);
}
CliService::info("Waiting for authentication...");
CliService::info(&format!("(timeout in {CALLBACK_TIMEOUT_SECS} seconds)"));
let server = axum::serve(listener, app);
tokio::select! {
result = rx => {
result.map_err(|_| anyhow!("Authentication cancelled"))?
}
_ = server => {
Err(anyhow!("Server stopped unexpectedly"))
}
() = tokio::time::sleep(std::time::Duration::from_secs(CALLBACK_TIMEOUT_SECS)) => {
Err(anyhow!("Authentication timed out after {CALLBACK_TIMEOUT_SECS} seconds"))
}
}
}