use std::io::Write;
use std::sync::{Arc, Mutex};
use tiny_http::{Header, Request, Response, Server};
use tokio::sync::oneshot;
use url::form_urlencoded;
use crate::{OAuthClient, OpenAIAuthError, Result, types::SessionData};
/// Successful result sent from the blocking HTTP server thread to the awaiting caller
/// once the browser has reached `/success`.
#[derive(Debug)]
struct CallbackData {
    /// Tokens obtained from the authorization-code exchange.
    tokens: crate::TokenSet,
    // Filled in by the `/success` handler but never read: the awaiting caller only
    // returns `tokens`, so the unused-field lint is silenced here.
    #[allow(dead_code)]
    session_data: Option<SessionData>,
}
/// State shared between the async caller and the blocking HTTP server thread for a
/// single authorization attempt.
struct ServerState {
    // Fixed configuration for this attempt.
    /// `state` value the authorization server must echo back (CSRF protection).
    expected_state: String,
    /// PKCE code verifier sent together with the authorization code in the token exchange.
    pkce_verifier: String,
    /// Client used to exchange the authorization code for tokens.
    oauth_client: OAuthClient,
    /// Renders the HTML page shown in the browser for each [`CallbackEvent`].
    html_responder: Arc<dyn Fn(CallbackEvent) -> String + Send + Sync>,

    // Mutable state written while the flow progresses.
    /// Tokens stored by the callback handler until `/success` collects them.
    tokens: Mutex<Option<crate::TokenSet>>,
    /// Session details derived from the tokens, if they could be extracted.
    session_data: Mutex<Option<SessionData>>,
    /// One-shot channel back to the awaiting caller; taken (left `None`) once a result is sent.
    tx: Mutex<Option<oneshot::Sender<Result<CallbackData>>>>,
}
/// Outcome of a callback-server request, handed to the HTML responder so it can render
/// the page shown in the user's browser.
#[derive(Debug, Clone)]
pub enum CallbackEvent {
    /// Tokens were obtained and the browser reached `/success`.
    Success {
        /// NOTE(review): despite its name, this holds a short prefix of the *access token*,
        /// not the authorization code.
        code: String,
        /// Session details derived from the tokens, when they could be extracted.
        session_data: Option<SessionData>,
    },
    /// The authorization server returned an `error` parameter, or the flow failed later
    /// (for example, the token exchange).
    Error {
        /// Human-readable failure text, e.g. the raw `error` query parameter or a
        /// token-exchange error message. It is not HTML-escaped; responders that embed
        /// it in markup should escape it.
        reason: String,
    },
    /// The `state` query parameter was missing or did not match the expected value.
    StateMismatch,
    /// No authorization `code` was received, or `/success` was requested before any
    /// tokens had been stored.
    MissingCode,
}
/// Runs the local OAuth callback server with the built-in HTML pages and waits for the
/// authorization flow to finish.
///
/// This is a convenience wrapper around [`run_callback_server_with_html`] that renders
/// browser pages with this module's default templates.
///
/// # Errors
///
/// Returns whatever error [`run_callback_server_with_html`] reports.
pub async fn run_callback_server(
    port: u16,
    expected_state: &str,
    oauth_client: &OAuthClient,
    pkce_verifier: &str,
) -> Result<crate::TokenSet> {
    let responder = default_callback_html;
    run_callback_server_with_html(port, expected_state, oauth_client, pkce_verifier, responder)
        .await
}
pub async fn run_callback_server_with_html(
port: u16,
expected_state: &str,
oauth_client: &OAuthClient,
pkce_verifier: &str,
html_responder: impl Fn(CallbackEvent) -> String + Send + Sync + 'static,
) -> Result<crate::TokenSet> {
let (tx, rx) = oneshot::channel();
let state = Arc::new(ServerState {
tx: Mutex::new(Some(tx)),
expected_state: expected_state.to_string(),
html_responder: Arc::new(html_responder),
oauth_client: oauth_client.clone(),
pkce_verifier: pkce_verifier.to_string(),
tokens: Mutex::new(None),
session_data: Mutex::new(None),
});
let addr = format!("127.0.0.1:{port}");
tokio::task::spawn_blocking(move || run_sync_server(&addr, state));
match rx.await {
Ok(Ok(callback_data)) => Ok(callback_data.tokens),
Ok(Err(e)) => Err(e),
Err(_) => Err(OpenAIAuthError::CallbackServer(
"Server shut down unexpectedly".to_string(),
)),
}
}
/// Serves the callback endpoints on `addr` until the authorization flow finishes.
///
/// Requests to `/auth/callback` and `/success` are dispatched to their handlers, which
/// decide whether the server should stop. Every other path receives `404 Not Found`.
///
/// The loop also ends once nobody is waiting for the result any more, i.e. the receiving
/// half of `state.tx` was dropped (for example, the caller's future was cancelled). In that
/// case the thread and the port are released instead of blocking forever.
///
/// # Errors
///
/// Returns [`OpenAIAuthError::CallbackServer`] if binding to `addr` or receiving a
/// request fails.
fn run_sync_server(addr: &str, state: Arc<ServerState>) -> Result<()> {
    // How long to wait for a request before re-checking whether the caller still waits.
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(250);

    // Returns `true` while a live receiver is still attached to the result channel.
    fn caller_is_waiting(state: &ServerState) -> bool {
        match state.tx.lock() {
            Ok(slot) => match slot.as_ref() {
                Some(tx) => !tx.is_closed(),
                None => false,
            },
            Err(_) => false,
        }
    }

    let server = Server::http(addr)
        .map_err(|e| OpenAIAuthError::CallbackServer(format!("Failed to bind to {addr}: {e}")))?;

    while caller_is_waiting(&state) {
        let request = match server.recv_timeout(POLL_INTERVAL) {
            Ok(Some(request)) => request,
            Ok(None) => continue,
            Err(e) => {
                return Err(OpenAIAuthError::CallbackServer(format!(
                    "Failed to receive request on {addr}: {e}"
                )));
            }
        };

        let url = request.url();
        let should_stop = if url.starts_with("/auth/callback") {
            handle_callback_request(request, &state)
        } else if url.starts_with("/success") {
            handle_success_request(request, &state)
        } else {
            let _ = request.respond(Response::from_string("Not Found").with_status_code(404));
            false
        };
        if should_stop {
            break;
        }
    }
    Ok(())
}
fn handle_callback_request(request: Request, state: &Arc<ServerState>) -> bool {
let url = request.url();
let query_str = url.split('?').nth(1).unwrap_or("");
let params = querystring::querify(query_str);
let code = params
.iter()
.find(|(k, _)| *k == "code")
.map(|(_, v)| v.to_string());
let received_state = params
.iter()
.find(|(k, _)| *k == "state")
.map(|(_, v)| v.to_string());
let error = params
.iter()
.find(|(k, _)| *k == "error")
.map(|(_, v)| v.to_string());
if let Some(error) = error {
let html = (state.html_responder)(CallbackEvent::Error {
reason: error.clone(),
});
let response = Response::from_string(html).with_header(
Header::from_bytes(&b"Content-Type"[..], &b"text/html; charset=utf-8"[..]).unwrap(),
);
let _ = request.respond(response);
let _ = state
.tx
.lock()
.unwrap()
.take()
.map(|tx| tx.send(Err(OpenAIAuthError::OAuth(format!("OAuth error: {error}")))));
return true;
}
let received_state_str = received_state.as_deref().unwrap_or("");
if received_state_str != state.expected_state {
let html = (state.html_responder)(CallbackEvent::StateMismatch);
let response = Response::from_string(html).with_header(
Header::from_bytes(&b"Content-Type"[..], &b"text/html; charset=utf-8"[..]).unwrap(),
);
let _ = request.respond(response);
let _ = state.tx.lock().unwrap().take().map(|tx| {
tx.send(Err(OpenAIAuthError::OAuth(
"State mismatch - possible CSRF attack".to_string(),
)))
});
return true;
}
let Some(code) = code else {
let html = (state.html_responder)(CallbackEvent::MissingCode);
let response = Response::from_string(html).with_header(
Header::from_bytes(&b"Content-Type"[..], &b"text/html; charset=utf-8"[..]).unwrap(),
);
let _ = request.respond(response);
let _ = state
.tx
.lock()
.unwrap()
.take()
.map(|tx| tx.send(Err(OpenAIAuthError::InvalidAuthorizationCode)));
return true;
};
let runtime = tokio::runtime::Runtime::new().unwrap();
let tokens_result = runtime.block_on(async {
state
.oauth_client
.exchange_code(&code, &state.pkce_verifier)
.await
});
match tokens_result {
Ok(tokens) => {
let session_data =
if let (Some(id_token), access_token) = (&tokens.id_token, &tokens.access_token) {
crate::jwt::extract_session_data(id_token, access_token).ok()
} else {
None
};
*state.tokens.lock().unwrap() = Some(tokens);
*state.session_data.lock().unwrap() = session_data.clone();
let mut serializer = form_urlencoded::Serializer::new(String::new());
if let Some(ref session) = session_data {
if let Some(ref org_id) = session.organization_id {
serializer.append_pair("org_id", org_id);
}
if let Some(ref project_id) = session.project_id {
serializer.append_pair("project_id", project_id);
}
if let Some(ref plan_type) = session.chatgpt_plan_type {
serializer.append_pair("plan_type", plan_type);
}
let needs_setup = !session.completed_platform_onboarding && session.is_org_owner;
serializer.append_pair("needs_setup", &needs_setup.to_string());
}
let query_string = serializer.finish();
let redirect_url = if query_string.is_empty() {
"/success".to_string()
} else {
format!("/success?{query_string}")
};
let response = Response::empty(302).with_header(
Header::from_bytes(&b"Location"[..], redirect_url.as_bytes()).unwrap(),
);
let _ = request.respond(response);
false }
Err(e) => {
let html = (state.html_responder)(CallbackEvent::Error {
reason: format!("Token exchange failed: {e}"),
});
let response = Response::from_string(html).with_header(
Header::from_bytes(&b"Content-Type"[..], &b"text/html; charset=utf-8"[..]).unwrap(),
);
let _ = request.respond(response);
let _ = state.tx.lock().unwrap().take().map(|tx| tx.send(Err(e)));
true
}
}
}
/// Handles the browser's request for `/success` after the callback redirect.
///
/// If the callback handler stored tokens, the success page is rendered and sent first.
/// The tokens and session data are then delivered to the waiting caller.
///
/// Without stored tokens (for example, `/success` was requested before a successful
/// callback), the responder's `MissingCode` page is shown instead. A
/// [`OpenAIAuthError::CallbackServer`] error is then delivered.
///
/// Always returns `true`: either way the flow is finished and the server should stop.
fn handle_success_request(request: Request, state: &Arc<ServerState>) -> bool {
    let tokens = state.tokens.lock().unwrap().take();
    let session_data = state.session_data.lock().unwrap().clone();

    let Some(tokens) = tokens else {
        let html = (state.html_responder)(CallbackEvent::MissingCode);
        send_response_with_disconnect(request, html);
        if let Some(tx) = state.tx.lock().unwrap().take() {
            let _ = tx.send(Err(OpenAIAuthError::CallbackServer(
                "No tokens available".to_string(),
            )));
        }
        return true;
    };

    // Only a short prefix of the access token is exposed to the responder. Taking whole
    // `char`s instead of slicing bytes cannot panic on a multi-byte boundary.
    let token_prefix: String = tokens.access_token.chars().take(20).collect();
    let html = (state.html_responder)(CallbackEvent::Success {
        code: token_prefix,
        session_data: session_data.clone(),
    });
    // Write the page before waking the caller, as the failure branch above does. The page
    // is then delivered even if the caller shuts down as soon as it has the tokens.
    send_response_with_disconnect(request, html);
    if let Some(tx) = state.tx.lock().unwrap().take() {
        let _ = tx.send(Ok(CallbackData {
            tokens,
            session_data,
        }));
    }
    true
}
/// Writes `body` as a complete `200 OK` HTML response straight onto the connection,
/// advertising `Connection: close`.
///
/// Every write is best effort: errors (e.g. the browser already went away) are ignored
/// because there is nobody left to report them to.
fn send_response_with_disconnect(request: Request, body: String) {
    let head = format!(
        "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
        body.len()
    );
    let mut connection = request.into_writer();
    let _ = connection.write_all(head.as_bytes());
    let _ = connection.write_all(body.as_bytes());
    let _ = connection.flush();
}
fn default_callback_html(event: CallbackEvent) -> String {
match event {
CallbackEvent::Success { session_data, .. } => {
let mut info_html = String::new();
if let Some(session) = session_data {
if let Some(org_id) = session.organization_id {
info_html.push_str(&format!("<p>Organization: {}</p>", org_id));
}
if let Some(project_id) = session.project_id {
info_html.push_str(&format!("<p>Project: {}</p>", project_id));
}
if let Some(plan_type) = session.chatgpt_plan_type {
info_html.push_str(&format!("<p>Plan: {}</p>", plan_type));
}
}
format!(
r#"
<html>
<head><title>Authorization Successful</title></head>
<body>
<h1>Authorization Successful!</h1>
<p>You have successfully authorized the application.</p>
{}
<p>You can close this window and return to the terminal.</p>
</body>
</html>
"#,
info_html
)
}
CallbackEvent::Error { reason } => format!(
r#"
<html>
<head><title>Authorization Failed</title></head>
<body>
<h1>Authorization Failed</h1>
<p>Error: {}</p>
<p>You can close this window.</p>
</body>
</html>
"#,
reason
),
CallbackEvent::StateMismatch => r#"
<html>
<head><title>Authorization Failed</title></head>
<body>
<h1>Authorization Failed</h1>
<p>Security validation failed. Please try again.</p>
<p>You can close this window.</p>
</body>
</html>
"#
.to_string(),
CallbackEvent::MissingCode => r#"
<html>
<head><title>Authorization Failed</title></head>
<body>
<h1>Authorization Failed</h1>
<p>No authorization code received.</p>
<p>You can close this window.</p>
</body>
</html>
"#
.to_string(),
}
}