Skip to main content

oxide_cli/auth/
server.rs

1use std::{collections::HashMap, sync::Arc};
2
3use anyhow::{Result, anyhow};
4use axum::{
5  Router,
6  extract::{Query, State},
7  response::Redirect,
8  routing::get,
9};
10use serde::{Deserialize, Serialize};
11use tokio::{
12  sync::{Mutex, Notify, oneshot},
13  time::Duration,
14};
15
/// Shared handle to the sending half of the one-shot channel that carries the
/// authenticated [`User`] back to `run_local_auth_server`.  The `Option` lets
/// the callback handler `take()` the sender so it is used at most once.
type SharedTx = Arc<Mutex<Option<oneshot::Sender<User>>>>;
/// Router state handed to the callback handler, in order: the shared sender,
/// the expected CSRF `state` value, and the frontend base URL used for
/// redirects.
type AppState = (SharedTx, String, String);
18
/// Credentials produced by a completed login.  In this file it is built by
/// the `/callback` handler from the redirect's query parameters.
#[derive(Serialize, Deserialize)]
pub struct User {
  /// Filled from the `token` query parameter of the callback.
  pub token: String,
  /// Filled from the `name` query parameter of the callback.
  pub name: String,
}
24
25/// Starts a one-shot local HTTP server on 127.0.0.1:8080 that waits for the
26/// OAuth callback redirect.  `expected_state` is the CSRF nonce generated
27/// by the caller; the callback validates it before accepting credentials.
28pub async fn run_local_auth_server(expected_state: String, frontend_url: &str) -> Result<User> {
29  let notify = Arc::new(Notify::new());
30  let notify_clone = notify.clone();
31  let (tx, rx) = oneshot::channel::<User>();
32
33  let shared_tx: SharedTx = Arc::new(Mutex::new(Some(tx)));
34  let state: AppState = (shared_tx, expected_state, frontend_url.to_string());
35
36  let app = Router::new()
37    .route("/callback", get(callback))
38    .with_state(state);
39
40  let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?;
41
42  let server = axum::serve(listener, app).with_graceful_shutdown(async move {
43    notify_clone.notified().await;
44  });
45
46  tokio::select! {
47    result = server => {
48      result?;
49      Err(anyhow!("Server stopped unexpectedly"))
50    }
51    user = rx => {
52      notify.notify_one();
53      Ok(user?)
54    }
55    _ = tokio::time::sleep(Duration::from_secs(600)) => {
56      notify.notify_one();
57      Err(anyhow!("Login timed out after 10 minutes. Please try again."))
58    }
59  }
60}
61
62async fn callback(
63  State((shared_tx, expected_state, frontend_url)): State<AppState>,
64  Query(params): Query<HashMap<String, String>>,
65) -> Redirect {
66  // Validate CSRF state token.  The backend must forward the `?state=`
67  // query param it received at /auth/cli-login through to this redirect.
68  match params.get("state") {
69    Some(state) if state == &expected_state => {}
70    Some(_) => return Redirect::to(&format!("{}/cli/error?reason=invalid_state", frontend_url)),
71    None => return Redirect::to(&format!("{}/cli/error?reason=missing_state", frontend_url)),
72  }
73
74  if let Some(token) = params.get("token")
75    && let Some(user_name) = params.get("name")
76  {
77    let mut guard = shared_tx.lock().await;
78
79    if let Some(tx) = guard.take() {
80      let _ = tx.send(User {
81        name: user_name.to_string(),
82        token: token.to_string(),
83      });
84    }
85
86    Redirect::to(&format!("{}/cli/success", frontend_url))
87  } else {
88    Redirect::to(&format!("{}/cli/error", frontend_url))
89  }
90}