// asknothingx2_util/oauth/oneshot.rs

1//! Oneshot OAuth callback server for development and testing.
2//!
3//! A lightweight HTTP server that listens for a single OAuth callback and automatically
4//! shuts down. Useful for testing OAuth flows from various providers on localhost.
5//!
6//! ```no_run
7//! use std::time::Duration;
8//!
9//! use asknothingx2_util::oauth::oneshot::{self, Config, ServerError};
10//! use serde::Deserialize;
11//!
12//! #[derive(Deserialize)]
13//! struct Callback {
14//!     pub code: String,
15//!     pub state: String,
16//! }
17//!
18//! async fn callback() -> Result<Callback, ServerError> {
19//!     let config = Config::new()
20//!         .with_port(8080)
21//!         .with_callback_path("/auth/callback")
22//!         .with_duration(Duration::from_secs(10));
23//!
24//!     oneshot::listen(config).await
25//! }
26//! ```
27//!
//! # Error Handling
//!
//! ```no_run
//! # use asknothingx2_util::oauth::oneshot::{self, Config};
//! # #[derive(serde::Deserialize)]
//! # struct Callback { code: String }
//! # async fn run(config: Config) -> Option<Callback> {
//! match oneshot::listen::<Callback>(config).await {
//!     Ok(callback) => Some(callback),
//!     Err(e) => {
//!         if e.is_timeout() {
//!             eprintln!("Timeout");
//!         } else if e.is_invalid_query() {
//!             eprintln!("Query: {}", e.query().unwrap());
//!         } else if e.is_unexpected_path() {
//!             let (expected, actual) = e.path().unwrap();
//!             eprintln!("Expected: {}", expected);
//!             eprintln!("Received: {}", actual);
//!         } else if e.is_shutdown() {
//!             eprintln!("Shutdown");
//!         } else {
//!             eprintln!("Error: {e}");
//!         }
//!         None
//!     }
//! }
//! # }
//! ```
50use std::{
51    convert::Infallible,
52    io::Error as IoError,
53    net::SocketAddr,
54    sync::{Arc, Mutex},
55    time::Duration,
56};
57
58use http_body_util::Full;
59use hyper::{
60    Method, Request, Response, StatusCode,
61    body::{Bytes, Incoming},
62    server::conn::http1,
63    service::service_fn,
64};
65use hyper_util::rt::TokioIo;
66use serde::{Serialize, de::DeserializeOwned};
67use tokio::{net::TcpListener, sync::oneshot, task::JoinHandle, time::sleep};
68use tracing::debug;
69
/// Configuration for the oneshot OAuth callback server.
///
/// Create one with [`Config::new`] (or [`Default::default`]) and adjust it with the
/// `with_*` builder methods.
///
/// # Defaults
///
/// - **port**: `3000`
/// - **path**: `"/"`
/// - **duration**: `30 seconds`
/// - **message**: `"Authorization successful! You can close this window."`
#[derive(Debug, Clone)]
pub struct Config {
    /// Port the server binds on `127.0.0.1`.
    port: u16,
    /// Path the callback must be requested on; always starts with `/`.
    path: String,
    /// How long to wait for the callback before giving up.
    duration: Duration,
    /// Message placed in the JSON body returned for a successful callback.
    message: String,
}

impl Default for Config {
    fn default() -> Self {
        Self::new()
    }
}

impl Config {
    /// Creates a configuration populated with the defaults listed on [`Config`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            port: 3000,
            path: "/".to_string(),
            duration: Duration::from_secs(30),
            message: "Authorization successful! You can close this window.".to_string(),
        }
    }

    /// Sets the path the OAuth provider redirects to, e.g. `"/auth/callback"`.
    ///
    /// A request for any other path ends the wait with [`ServerError::UnexpectedPath`].
    /// A missing leading `/` is added, because it is compared against request paths,
    /// which begin with `/`; without it the configured path could never match.
    #[must_use]
    pub fn with_callback_path(mut self, path: impl Into<String>) -> Self {
        let path = path.into();
        self.path = if path.starts_with('/') {
            path
        } else {
            format!("/{path}")
        };
        self
    }

    /// Sets the local port to listen on.
    #[must_use]
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// The server will return [`ServerError::Timeout`] if no callback is received
    /// within this duration.
    #[must_use]
    pub fn with_duration(mut self, duration: Duration) -> Self {
        self.duration = duration;
        self
    }

    /// Sets the message sent back to the browser after a successful callback.
    #[must_use]
    pub fn with_message(mut self, message: impl Into<String>) -> Self {
        self.message = message.into();
        self
    }
}
123
124/// Starts a oneshot HTTP server and waits for an OAuth callback.
125///
126/// This function binds to `127.0.0.1` on the configured port and listens for a single
127/// HTTP GET request. When a valid callback is received, the query parameters are parsed
128/// into type `T` and the server automatically shuts down.
129///
130/// # Type Parameters
131///
132/// * `T` - The callback parameter type that implements [`serde::Deserialize`].
133///
134/// # Server Behavior
135///
136/// The server shuts down immediately when:
137/// - O - A valid callback is received (returns `Ok(T)`)
138/// - X - Query parsing fails (returns Err([ServerError::InvalidQuery]))
139/// - X - Wrong HTTP method is used (returns Err([ServerError::UnexpectedMethod]))
140/// - X - Wrong path is requested (returns Err([ServerError::UnexpectedPath]))
141/// - X - Timeout is reached (returns Err([ServerError::Timeout]))
142/// - X - Ctrl+C is pressed (returns Err([ServerError::Shutdown]))
143pub async fn listen<T>(config: Config) -> Result<T, ServerError>
144where
145    T: DeserializeOwned + Send + 'static,
146{
147    let (tx, rx) = oneshot::channel::<Result<T, ServerError>>();
148
149    let state = Arc::new(AppState {
150        tx: Arc::new(Mutex::new(Some(tx))),
151        path: config.path,
152        message: config.message,
153    });
154
155    let addr = SocketAddr::from(([127, 0, 0, 1], config.port));
156    debug!("Starting OAuth callback server on {}", addr);
157
158    let listener = TcpListener::bind(&addr)
159        .await
160        .map_err(|e| ServerError::BindFailed {
161            addr: addr.to_string(),
162            source: e,
163        })?;
164
165    let server_handle: JoinHandle<Result<(), ServerError>> = tokio::spawn(async move {
166        loop {
167            let (stream, remote_addr) = listener.accept().await?;
168            debug!("Accepted connection from {}", remote_addr);
169
170            let io = TokioIo::new(stream);
171            let state = state.clone();
172
173            tokio::spawn(async move {
174                let service = service_fn(|req| handle_request::<T>(req, state.clone()));
175
176                if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
177                    debug!("Error serving connection: {:?}", err);
178                }
179            });
180        }
181    });
182
183    tokio::select! {
184        result = rx => {
185            debug!("Shutdown OAuth callback server");
186            server_handle.abort();
187            match result {
188                Ok(Ok(callback)) => Ok(callback),
189                Ok(Err(e)) => Err(e),
190                Err(_) => Err(ServerError::Shutdown),
191            }
192        }
193        _ = sleep(config.duration) => {
194            debug!("OAuth callback server timed out");
195            server_handle.abort();
196            Err(ServerError::Timeout)
197        }
198        _ = tokio::signal::ctrl_c() => {
199            debug!("OAuth callback server received shutdown signal");
200            server_handle.abort();
201            Err(ServerError::Shutdown)
202        }
203    }
204}
205struct AppState<T> {
206    #[allow(clippy::type_complexity)]
207    tx: Arc<Mutex<Option<oneshot::Sender<Result<T, ServerError>>>>>,
208    path: String,
209    message: String,
210}
211
212#[derive(serde::Serialize)]
213struct CallbackResponse {
214    message: String,
215}
216
217async fn handle_request<T>(
218    req: Request<Incoming>,
219    state: Arc<AppState<T>>,
220) -> Result<Response<Full<Bytes>>, Infallible>
221where
222    T: DeserializeOwned + Send + 'static,
223{
224    let method = req.method();
225    let path = req.uri().path();
226    let query = req.uri().query().unwrap_or("");
227
228    debug!("Received request: {} {} (query: {})", method, path, query);
229
230    if method != Method::GET {
231        debug!("Unexpected HTTP method: expected GET, got {}", method);
232
233        if let Some(sender) = state.tx.lock().unwrap().take() {
234            let _ = sender.send(Err(ServerError::UnexpectedMethod {
235                method: method.clone(),
236            }));
237        }
238
239        return Ok(error_response(
240            StatusCode::METHOD_NOT_ALLOWED,
241            "Method not allowed",
242        ));
243    }
244
245    if path != state.path {
246        debug!("Unexpected path: expected '{}', got '{}'", state.path, path);
247
248        if let Some(sender) = state.tx.lock().unwrap().take() {
249            let _ = sender.send(Err(ServerError::UnexpectedPath {
250                expected: state.path.to_string(),
251                actual: path.to_string(),
252            }));
253        }
254
255        return Ok(error_response(StatusCode::NOT_FOUND, "Not found"));
256    }
257
258    let params: T = match serde_urlencoded::from_str(query) {
259        Ok(p) => {
260            debug!("Successfully parsed OAuth callback parameters");
261            p
262        }
263        Err(e) => {
264            let error_msg = e.to_string();
265            debug!("Failed to parse OAuth callback query `{}`: {}", query, e);
266
267            if let Some(sender) = state.tx.lock().unwrap().take() {
268                let _ = sender.send(Err(ServerError::InvalidQuery {
269                    query: query.to_string(),
270                    source: e,
271                }));
272            }
273
274            return Ok(error_response(StatusCode::BAD_REQUEST, &error_msg));
275        }
276    };
277
278    if let Some(sender) = state.tx.lock().unwrap().take() {
279        let _ = sender.send(Ok(params));
280    }
281
282    let response = CallbackResponse {
283        message: state.message.clone(),
284    };
285
286    Ok(json_response(StatusCode::OK, &response))
287}
288
289fn json_response<T: Serialize>(status: StatusCode, body: &T) -> Response<Full<Bytes>> {
290    let json = serde_json::to_vec(body).unwrap();
291    Response::builder()
292        .status(status)
293        .header("Content-Type", "application/json")
294        .body(Full::new(Bytes::from(json)))
295        .unwrap()
296}
297
298fn error_response(status: StatusCode, message: &str) -> Response<Full<Bytes>> {
299    let error = serde_json::json!({ "error": message });
300    json_response(status, &error)
301}
302
303#[derive(Debug, thiserror::Error)]
304pub enum ServerError {
305    #[error("failed to bind to address `{addr}`: {source}")]
306    BindFailed { addr: String, source: IoError },
307    #[error(transparent)]
308    Io(#[from] IoError),
309    #[error("invalid OAuth callback query `{query}`: {source}")]
310    InvalidQuery {
311        query: String,
312        #[source]
313        source: serde_urlencoded::de::Error,
314    },
315    #[error("unexpected HTTP method: expected `GET`, got {method}")]
316    UnexpectedMethod { method: Method },
317    #[error("unexpected path: expected `{expected}`, got `{actual}`")]
318    UnexpectedPath { expected: String, actual: String },
319    #[error("server received shutdown signal")]
320    Shutdown,
321    #[error("timeout waiting for OAuth authorization callback")]
322    Timeout,
323}
324
325impl ServerError {
326    pub fn is_timeout(&self) -> bool {
327        matches!(self, Self::Timeout)
328    }
329
330    pub fn is_invalid_query(&self) -> bool {
331        matches!(self, Self::InvalidQuery { .. })
332    }
333
334    pub fn is_unexpected_method(&self) -> bool {
335        matches!(self, Self::UnexpectedMethod { .. })
336    }
337
338    pub fn is_unexpected_path(&self) -> bool {
339        matches!(self, Self::UnexpectedPath { .. })
340    }
341
342    pub fn is_shutdown(&self) -> bool {
343        matches!(self, Self::Shutdown)
344    }
345
346    pub fn is_bind_failed(&self) -> bool {
347        matches!(self, Self::BindFailed { .. })
348    }
349
350    pub fn is_io(&self) -> bool {
351        matches!(self, Self::Io(_))
352    }
353
354    /// Returns the query string if this is an [`InvalidQuery`](ServerError::InvalidQuery) error.
355    pub fn query(&self) -> Option<&str> {
356        match self {
357            Self::InvalidQuery { query, source: _ } => Some(query),
358            _ => None,
359        }
360    }
361
362    /// Returns the HTTP method if this is an [`UnexpectedMethod`](ServerError::UnexpectedMethod) error.
363    pub fn method(&self) -> Option<&Method> {
364        match self {
365            Self::UnexpectedMethod { method } => Some(method),
366            _ => None,
367        }
368    }
369
370    /// Returns a tuple of `(expected, actual)` paths if this is an [`UnexpectedPath`](ServerError::UnexpectedPath) error.
371    pub fn path(&self) -> Option<(&str, &str)> {
372        match self {
373            Self::UnexpectedPath { expected, actual } => Some((expected, actual)),
374            _ => None,
375        }
376    }
377}