use std::{collections::HashMap, sync::Arc};

use anyhow::{Context, Result, anyhow};
use axum::{
    Router,
    extract::{Query, State},
    response::Redirect,
    routing::get,
};
use serde::{Deserialize, Serialize};
use tokio::{
    sync::{Mutex, Notify, oneshot},
    time::Duration,
};

use crate::FRONTEND_URL;
17
/// One-shot sender used by the `/callback` handler to hand the received [`User`]
/// back to [`run_local_auth_server`].
///
/// The sender sits in an `Option` so the handler can `take()` it exactly once,
/// and in `Arc<Mutex<…>>` so it can be shared through axum's router state.
type SharedTx = Arc<Mutex<Option<oneshot::Sender<User>>>>;
/// Router state: the shared user sender plus the value the callback's `state`
/// query parameter must equal (presumably an anti-CSRF nonce generated by the
/// caller — confirm against the login flow).
type AppState = (SharedTx, String);
20
/// Credentials delivered to the CLI by the browser login callback.
// Field order is kept as-is: it determines the serde serialization order.
#[derive(Serialize, Deserialize)]
pub struct User {
    /// Value of the callback's `token` query parameter (presumably an auth/API
    /// token issued by the backend — confirm).
    pub token: String,
    /// Value of the callback's `name` query parameter (the user's display name).
    pub name: String,
}
26
27pub async fn run_local_auth_server(expected_state: String) -> Result<User> {
31 let notify = Arc::new(Notify::new());
32 let notify_clone = notify.clone();
33 let (tx, rx) = oneshot::channel::<User>();
34
35 let shared_tx: SharedTx = Arc::new(Mutex::new(Some(tx)));
36 let state: AppState = (shared_tx, expected_state);
37
38 let app = Router::new()
39 .route("/callback", get(callback))
40 .with_state(state);
41
42 let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?;
43
44 let server = axum::serve(listener, app).with_graceful_shutdown(async move {
45 notify_clone.notified().await;
46 });
47
48 tokio::select! {
49 result = server => {
50 result?;
51 Err(anyhow!("Server stopped unexpectedly"))
52 }
53 user = rx => {
54 notify.notify_one();
55 Ok(user?)
56 }
57 _ = tokio::time::sleep(Duration::from_secs(300)) => {
58 notify.notify_one();
59 Err(anyhow!("Login timed out after 5 minutes. Please try again."))
60 }
61 }
62}
63
64async fn callback(
65 State((shared_tx, expected_state)): State<AppState>,
66 Query(params): Query<HashMap<String, String>>,
67) -> Redirect {
68 match params.get("state") {
71 Some(state) if state == &expected_state => {}
72 Some(_) => return Redirect::to(&format!("{}/cli/error?reason=invalid_state", FRONTEND_URL)),
73 None => return Redirect::to(&format!("{}/cli/error?reason=missing_state", FRONTEND_URL)),
74 }
75
76 if let Some(token) = params.get("token")
77 && let Some(user_name) = params.get("name")
78 {
79 let mut guard = shared_tx.lock().await;
80
81 if let Some(tx) = guard.take() {
82 let _ = tx.send(User {
83 name: user_name.to_string(),
84 token: token.to_string(),
85 });
86 }
87
88 Redirect::to(&format!("{}/cli/success", FRONTEND_URL))
89 } else {
90 Redirect::to(&format!("{}/cli/error", FRONTEND_URL))
91 }
92}