1use std::{
51 convert::Infallible,
52 io::Error as IoError,
53 net::SocketAddr,
54 sync::{Arc, Mutex},
55 time::Duration,
56};
57
58use http_body_util::Full;
59use hyper::{
60 Method, Request, Response, StatusCode,
61 body::{Bytes, Incoming},
62 server::conn::http1,
63 service::service_fn,
64};
65use hyper_util::rt::TokioIo;
66use serde::{Serialize, de::DeserializeOwned};
67use tokio::{net::TcpListener, sync::oneshot, task::JoinHandle, time::sleep};
68use tracing::debug;
69
/// Settings for the local OAuth callback server run by [`listen`].
///
/// Start from [`Config::new`] (or [`Config::default`], which is identical) and
/// adjust it with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// TCP port to bind on `127.0.0.1`.
    port: u16,
    /// Request path the callback must arrive on; always starts with `/`.
    path: String,
    /// How long to wait for the callback before giving up.
    duration: Duration,
    /// Message sent back to the browser after a successful callback.
    message: String,
}

impl Default for Config {
    /// Same as [`Config::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Config {
    /// Creates a configuration with the defaults: port `3000`, callback path
    /// `/`, a 30 second timeout and a generic success message.
    #[must_use]
    pub fn new() -> Self {
        Self {
            port: 3000,
            path: "/".to_string(),
            duration: Duration::from_secs(30),
            message: "Authorization successful! You can close this window.".to_string(),
        }
    }

    /// Sets the request path on which the callback is expected.
    ///
    /// The path of a `GET` request seen by the server always starts with `/`,
    /// so a missing leading slash is added (`"callback"` becomes `"/callback"`,
    /// `""` becomes `"/"`). Without this, such a path could never match and
    /// every callback would be rejected as an unexpected path.
    #[must_use]
    pub fn with_callback_path(mut self, path: impl Into<String>) -> Self {
        let mut path = path.into();
        if !path.starts_with('/') {
            path.insert(0, '/');
        }
        self.path = path;
        self
    }

    /// Sets the port to bind on `127.0.0.1`.
    #[must_use]
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets how long to wait for the callback before giving up with a timeout.
    #[must_use]
    pub fn with_duration(mut self, duration: Duration) -> Self {
        self.duration = duration;
        self
    }

    /// Sets the message returned to the browser after a successful callback.
    #[must_use]
    pub fn with_message(mut self, message: impl Into<String>) -> Self {
        self.message = message.into();
        self
    }
}
123
124pub async fn listen<T>(config: Config) -> Result<T, ServerError>
144where
145 T: DeserializeOwned + Send + 'static,
146{
147 let (tx, rx) = oneshot::channel::<Result<T, ServerError>>();
148
149 let state = Arc::new(AppState {
150 tx: Arc::new(Mutex::new(Some(tx))),
151 path: config.path,
152 message: config.message,
153 });
154
155 let addr = SocketAddr::from(([127, 0, 0, 1], config.port));
156 debug!("Starting OAuth callback server on {}", addr);
157
158 let listener = TcpListener::bind(&addr)
159 .await
160 .map_err(|e| ServerError::BindFailed {
161 addr: addr.to_string(),
162 source: e,
163 })?;
164
165 let server_handle: JoinHandle<Result<(), ServerError>> = tokio::spawn(async move {
166 loop {
167 let (stream, remote_addr) = listener.accept().await?;
168 debug!("Accepted connection from {}", remote_addr);
169
170 let io = TokioIo::new(stream);
171 let state = state.clone();
172
173 tokio::spawn(async move {
174 let service = service_fn(|req| handle_request::<T>(req, state.clone()));
175
176 if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
177 debug!("Error serving connection: {:?}", err);
178 }
179 });
180 }
181 });
182
183 tokio::select! {
184 result = rx => {
185 debug!("Shutdown OAuth callback server");
186 server_handle.abort();
187 match result {
188 Ok(Ok(callback)) => Ok(callback),
189 Ok(Err(e)) => Err(e),
190 Err(_) => Err(ServerError::Shutdown),
191 }
192 }
193 _ = sleep(config.duration) => {
194 debug!("OAuth callback server timed out");
195 server_handle.abort();
196 Err(ServerError::Timeout)
197 }
198 _ = tokio::signal::ctrl_c() => {
199 debug!("OAuth callback server received shutdown signal");
200 server_handle.abort();
201 Err(ServerError::Shutdown)
202 }
203 }
204}
205struct AppState<T> {
206 #[allow(clippy::type_complexity)]
207 tx: Arc<Mutex<Option<oneshot::Sender<Result<T, ServerError>>>>>,
208 path: String,
209 message: String,
210}
211
212#[derive(serde::Serialize)]
213struct CallbackResponse {
214 message: String,
215}
216
217async fn handle_request<T>(
218 req: Request<Incoming>,
219 state: Arc<AppState<T>>,
220) -> Result<Response<Full<Bytes>>, Infallible>
221where
222 T: DeserializeOwned + Send + 'static,
223{
224 let method = req.method();
225 let path = req.uri().path();
226 let query = req.uri().query().unwrap_or("");
227
228 debug!("Received request: {} {} (query: {})", method, path, query);
229
230 if method != Method::GET {
231 debug!("Unexpected HTTP method: expected GET, got {}", method);
232
233 if let Some(sender) = state.tx.lock().unwrap().take() {
234 let _ = sender.send(Err(ServerError::UnexpectedMethod {
235 method: method.clone(),
236 }));
237 }
238
239 return Ok(error_response(
240 StatusCode::METHOD_NOT_ALLOWED,
241 "Method not allowed",
242 ));
243 }
244
245 if path != state.path {
246 debug!("Unexpected path: expected '{}', got '{}'", state.path, path);
247
248 if let Some(sender) = state.tx.lock().unwrap().take() {
249 let _ = sender.send(Err(ServerError::UnexpectedPath {
250 expected: state.path.to_string(),
251 actual: path.to_string(),
252 }));
253 }
254
255 return Ok(error_response(StatusCode::NOT_FOUND, "Not found"));
256 }
257
258 let params: T = match serde_urlencoded::from_str(query) {
259 Ok(p) => {
260 debug!("Successfully parsed OAuth callback parameters");
261 p
262 }
263 Err(e) => {
264 let error_msg = e.to_string();
265 debug!("Failed to parse OAuth callback query `{}`: {}", query, e);
266
267 if let Some(sender) = state.tx.lock().unwrap().take() {
268 let _ = sender.send(Err(ServerError::InvalidQuery {
269 query: query.to_string(),
270 source: e,
271 }));
272 }
273
274 return Ok(error_response(StatusCode::BAD_REQUEST, &error_msg));
275 }
276 };
277
278 if let Some(sender) = state.tx.lock().unwrap().take() {
279 let _ = sender.send(Ok(params));
280 }
281
282 let response = CallbackResponse {
283 message: state.message.clone(),
284 };
285
286 Ok(json_response(StatusCode::OK, &response))
287}
288
289fn json_response<T: Serialize>(status: StatusCode, body: &T) -> Response<Full<Bytes>> {
290 let json = serde_json::to_vec(body).unwrap();
291 Response::builder()
292 .status(status)
293 .header("Content-Type", "application/json")
294 .body(Full::new(Bytes::from(json)))
295 .unwrap()
296}
297
298fn error_response(status: StatusCode, message: &str) -> Response<Full<Bytes>> {
299 let error = serde_json::json!({ "error": message });
300 json_response(status, &error)
301}
302
303#[derive(Debug, thiserror::Error)]
304pub enum ServerError {
305 #[error("failed to bind to address `{addr}`: {source}")]
306 BindFailed { addr: String, source: IoError },
307 #[error(transparent)]
308 Io(#[from] IoError),
309 #[error("invalid OAuth callback query `{query}`: {source}")]
310 InvalidQuery {
311 query: String,
312 #[source]
313 source: serde_urlencoded::de::Error,
314 },
315 #[error("unexpected HTTP method: expected `GET`, got {method}")]
316 UnexpectedMethod { method: Method },
317 #[error("unexpected path: expected `{expected}`, got `{actual}`")]
318 UnexpectedPath { expected: String, actual: String },
319 #[error("server received shutdown signal")]
320 Shutdown,
321 #[error("timeout waiting for OAuth authorization callback")]
322 Timeout,
323}
324
325impl ServerError {
326 pub fn is_timeout(&self) -> bool {
327 matches!(self, Self::Timeout)
328 }
329
330 pub fn is_invalid_query(&self) -> bool {
331 matches!(self, Self::InvalidQuery { .. })
332 }
333
334 pub fn is_unexpected_method(&self) -> bool {
335 matches!(self, Self::UnexpectedMethod { .. })
336 }
337
338 pub fn is_unexpected_path(&self) -> bool {
339 matches!(self, Self::UnexpectedPath { .. })
340 }
341
342 pub fn is_shutdown(&self) -> bool {
343 matches!(self, Self::Shutdown)
344 }
345
346 pub fn is_bind_failed(&self) -> bool {
347 matches!(self, Self::BindFailed { .. })
348 }
349
350 pub fn is_io(&self) -> bool {
351 matches!(self, Self::Io(_))
352 }
353
354 pub fn query(&self) -> Option<&str> {
356 match self {
357 Self::InvalidQuery { query, source: _ } => Some(query),
358 _ => None,
359 }
360 }
361
362 pub fn method(&self) -> Option<&Method> {
364 match self {
365 Self::UnexpectedMethod { method } => Some(method),
366 _ => None,
367 }
368 }
369
370 pub fn path(&self) -> Option<(&str, &str)> {
372 match self {
373 Self::UnexpectedPath { expected, actual } => Some((expected, actual)),
374 _ => None,
375 }
376 }
377}