1use anyhow::{anyhow, Result};
2use axum::{
3 extract::{Query, State},
4 response::Html,
5 routing::get,
6 Router,
7};
8use serde::Deserialize;
9use std::net::TcpListener;
10use tokio::sync::{mpsc, oneshot};
11
12#[derive(Deserialize)]
13struct AuthCallback {
14 code: String,
15 state: String,
16}
17
18struct AppState {
19 tx: mpsc::Sender<String>,
20 expected_state: String,
21}
22
23pub async fn run_server(
24 listener: TcpListener,
25 expected_state: String,
26) -> Result<oauth2::AuthorizationCode> {
27 let (tx, mut rx) = mpsc::channel(1);
28 let (ready_tx, ready_rx) = oneshot::channel();
29
30 let state = std::sync::Arc::new(AppState { tx, expected_state });
31
32 let app = Router::new()
33 .route("/callback", get(handler.clone()))
34 .route("/oauth2callback", get(handler))
35 .with_state(state);
36
37 listener.set_nonblocking(true)?;
39 let tokio_listener = tokio::net::TcpListener::from_std(listener)?;
40
41 tokio::spawn(async move {
43 let _ = ready_tx.send(());
45
46 if let Err(e) = axum::serve(tokio_listener, app).await {
47 eprintln!("Server error: {}", e);
48 }
49 });
50
51 let _ = ready_rx.await;
53
54 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
56
57 let code_str = rx
59 .recv()
60 .await
61 .ok_or_else(|| anyhow!("Failed to receive auth code"))?;
62
63 Ok(oauth2::AuthorizationCode::new(code_str))
64}
65
66async fn handler(
67 Query(params): Query<AuthCallback>,
68 State(state): State<std::sync::Arc<AppState>>,
69) -> Html<&'static str> {
70 if params.state != state.expected_state {
71 return Html("<h1>Error: Invalid State</h1><p>CSRF check failed.</p>");
72 }
73
74 let _ = state.tx.send(params.code).await;
76
77 Html("<h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p><script>window.close()</script>")
78}