//! `oxide_cli/auth/server.rs` — one-shot local HTTP server that receives the
//! browser's OAuth callback during CLI login.

1use std::{collections::HashMap, sync::Arc};
2
3use anyhow::{Result, anyhow};
4use axum::{
5  Router,
6  extract::{Query, State},
7  response::Redirect,
8  routing::get,
9};
10use serde::{Deserialize, Serialize};
11use tokio::{
12  sync::{Mutex, Notify, oneshot},
13  time::Duration,
14};
15
16use crate::FRONTEND_URL;
17
18type SharedTx = Arc<Mutex<Option<oneshot::Sender<User>>>>;
19type AppState = (SharedTx, String);
20
21#[derive(Serialize, Deserialize)]
22pub struct User {
23  pub token: String,
24  pub name: String,
25}
26
27/// Starts a one-shot local HTTP server on 127.0.0.1:8080 that waits for the
28/// OAuth callback redirect.  `expected_state` is the CSRF nonce generated
29/// by the caller; the callback validates it before accepting credentials.
30pub async fn run_local_auth_server(expected_state: String) -> Result<User> {
31  let notify = Arc::new(Notify::new());
32  let notify_clone = notify.clone();
33  let (tx, rx) = oneshot::channel::<User>();
34
35  let shared_tx: SharedTx = Arc::new(Mutex::new(Some(tx)));
36  let state: AppState = (shared_tx, expected_state);
37
38  let app = Router::new()
39    .route("/callback", get(callback))
40    .with_state(state);
41
42  let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?;
43
44  let server = axum::serve(listener, app).with_graceful_shutdown(async move {
45    notify_clone.notified().await;
46  });
47
48  tokio::select! {
49    result = server => {
50      result?;
51      Err(anyhow!("Server stopped unexpectedly"))
52    }
53    user = rx => {
54      notify.notify_one();
55      Ok(user?)
56    }
57    _ = tokio::time::sleep(Duration::from_secs(300)) => {
58      notify.notify_one();
59      Err(anyhow!("Login timed out after 5 minutes. Please try again."))
60    }
61  }
62}
63
64async fn callback(
65  State((shared_tx, expected_state)): State<AppState>,
66  Query(params): Query<HashMap<String, String>>,
67) -> Redirect {
68  // Validate CSRF state token.  The backend must forward the `?state=`
69  // query param it received at /auth/cli-login through to this redirect.
70  match params.get("state") {
71    Some(state) if state == &expected_state => {}
72    Some(_) => return Redirect::to(&format!("{}/cli/error?reason=invalid_state", FRONTEND_URL)),
73    None => return Redirect::to(&format!("{}/cli/error?reason=missing_state", FRONTEND_URL)),
74  }
75
76  if let Some(token) = params.get("token")
77    && let Some(user_name) = params.get("name")
78  {
79    let mut guard = shared_tx.lock().await;
80
81    if let Some(tx) = guard.take() {
82      let _ = tx.send(User {
83        name: user_name.to_string(),
84        token: token.to_string(),
85      });
86    }
87
88    Redirect::to(&format!("{}/cli/success", FRONTEND_URL))
89  } else {
90    Redirect::to(&format!("{}/cli/error", FRONTEND_URL))
91  }
92}