1use std::{collections::HashMap, sync::Arc};
2
3use anyhow::{Result, anyhow};
4use axum::{
5 Router,
6 extract::{Query, State},
7 response::Redirect,
8 routing::get,
9};
10use serde::{Deserialize, Serialize};
11use tokio::{
12 sync::{Mutex, Notify, oneshot},
13 time::Duration,
14};
15
16type SharedTx = Arc<Mutex<Option<oneshot::Sender<User>>>>;
17type AppState = (SharedTx, String, String);
18
19#[derive(Serialize, Deserialize)]
20pub struct User {
21 pub token: String,
22 pub name: String,
23}
24
25pub async fn run_local_auth_server(expected_state: String, frontend_url: &str) -> Result<User> {
29 let notify = Arc::new(Notify::new());
30 let notify_clone = notify.clone();
31 let (tx, rx) = oneshot::channel::<User>();
32
33 let shared_tx: SharedTx = Arc::new(Mutex::new(Some(tx)));
34 let state: AppState = (shared_tx, expected_state, frontend_url.to_string());
35
36 let app = Router::new()
37 .route("/callback", get(callback))
38 .with_state(state);
39
40 let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?;
41
42 let server = axum::serve(listener, app).with_graceful_shutdown(async move {
43 notify_clone.notified().await;
44 });
45
46 tokio::select! {
47 result = server => {
48 result?;
49 Err(anyhow!("Server stopped unexpectedly"))
50 }
51 user = rx => {
52 notify.notify_one();
53 Ok(user?)
54 }
55 _ = tokio::time::sleep(Duration::from_secs(600)) => {
56 notify.notify_one();
57 Err(anyhow!("Login timed out after 10 minutes. Please try again."))
58 }
59 }
60}
61
62async fn callback(
63 State((shared_tx, expected_state, frontend_url)): State<AppState>,
64 Query(params): Query<HashMap<String, String>>,
65) -> Redirect {
66 match params.get("state") {
69 Some(state) if state == &expected_state => {}
70 Some(_) => return Redirect::to(&format!("{}/cli/error?reason=invalid_state", frontend_url)),
71 None => return Redirect::to(&format!("{}/cli/error?reason=missing_state", frontend_url)),
72 }
73
74 if let Some(token) = params.get("token")
75 && let Some(user_name) = params.get("name")
76 {
77 let mut guard = shared_tx.lock().await;
78
79 if let Some(tx) = guard.take() {
80 let _ = tx.send(User {
81 name: user_name.to_string(),
82 token: token.to_string(),
83 });
84 }
85
86 Redirect::to(&format!("{}/cli/success", frontend_url))
87 } else {
88 Redirect::to(&format!("{}/cli/error", frontend_url))
89 }
90}