use std::{
convert::Infallible,
io::Error as IoError,
net::SocketAddr,
sync::{Arc, Mutex},
time::Duration,
};
use http_body_util::Full;
use hyper::{
Method, Request, Response, StatusCode,
body::{Bytes, Incoming},
server::conn::http1,
service::service_fn,
};
use hyper_util::rt::TokioIo;
use serde::{Serialize, de::DeserializeOwned};
use tokio::{net::TcpListener, sync::oneshot, task::JoinHandle, time::sleep};
use tracing::debug;
/// Settings for the local OAuth callback server started by `listen`.
///
/// Build one with [`Config::new`] (or [`Default`]) and customise it through
/// the chained `with_*` methods.
#[derive(Debug, Clone)]
pub struct Config {
    /// TCP port the server binds to on the loopback interface.
    port: u16,
    /// Request path that counts as the OAuth callback; always starts with `/`.
    path: String,
    /// How long to wait for the callback before giving up.
    duration: Duration,
    /// Text returned to the browser after a successful callback.
    message: String,
}

impl Default for Config {
    /// Same as [`Config::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Config {
    /// Creates the default configuration: port `3000`, callback path `/`,
    /// a 30 second wait and a generic success message.
    #[must_use]
    pub fn new() -> Self {
        Self {
            port: 3000,
            path: "/".to_string(),
            duration: Duration::from_secs(30),
            message: "Authorization successful! You can close this window.".to_string(),
        }
    }

    /// Sets the path the callback request must use.
    ///
    /// Request paths always begin with `/`, so a value without a leading
    /// slash (for example `"callback"`) could never match an incoming
    /// request. Such values are normalised by prepending `/`; values that
    /// already start with `/` are stored unchanged.
    #[must_use]
    pub fn with_callback_path(mut self, path: impl Into<String>) -> Self {
        let mut path = path.into();
        if !path.starts_with('/') {
            path.insert(0, '/');
        }
        self.path = path;
        self
    }

    /// Sets the loopback port to bind to.
    #[must_use]
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets how long to wait for the callback before timing out.
    #[must_use]
    pub fn with_duration(mut self, duration: Duration) -> Self {
        self.duration = duration;
        self
    }

    /// Sets the message returned to the browser after a successful callback.
    #[must_use]
    pub fn with_message(mut self, message: impl Into<String>) -> Self {
        self.message = message.into();
        self
    }
}
pub async fn listen<T>(config: Config) -> Result<T, ServerError>
where
T: DeserializeOwned + Send + 'static,
{
let (tx, rx) = oneshot::channel::<Result<T, ServerError>>();
let state = Arc::new(AppState {
tx: Arc::new(Mutex::new(Some(tx))),
path: config.path,
message: config.message,
});
let addr = SocketAddr::from(([127, 0, 0, 1], config.port));
debug!("Starting OAuth callback server on {}", addr);
let listener = TcpListener::bind(&addr)
.await
.map_err(|e| ServerError::BindFailed {
addr: addr.to_string(),
source: e,
})?;
let server_handle: JoinHandle<Result<(), ServerError>> = tokio::spawn(async move {
loop {
let (stream, remote_addr) = listener.accept().await?;
debug!("Accepted connection from {}", remote_addr);
let io = TokioIo::new(stream);
let state = state.clone();
tokio::spawn(async move {
let service = service_fn(|req| handle_request::<T>(req, state.clone()));
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
debug!("Error serving connection: {:?}", err);
}
});
}
});
tokio::select! {
result = rx => {
debug!("Shutdown OAuth callback server");
server_handle.abort();
match result {
Ok(Ok(callback)) => Ok(callback),
Ok(Err(e)) => Err(e),
Err(_) => Err(ServerError::Shutdown),
}
}
_ = sleep(config.duration) => {
debug!("OAuth callback server timed out");
server_handle.abort();
Err(ServerError::Timeout)
}
_ = tokio::signal::ctrl_c() => {
debug!("OAuth callback server received shutdown signal");
server_handle.abort();
Err(ServerError::Shutdown)
}
}
}
/// One-shot sender used to hand the callback outcome back to the waiting caller.
type CallbackSender<T> = oneshot::Sender<Result<T, ServerError>>;

/// State shared by every request handler of the callback server.
struct AppState<T> {
    /// Sender for the outcome; `None` once an outcome has been reported.
    tx: Arc<Mutex<Option<CallbackSender<T>>>>,
    /// Request path that is accepted as the callback.
    path: String,
    /// Message returned in the body of a successful callback response.
    message: String,
}
#[derive(serde::Serialize)]
struct CallbackResponse {
message: String,
}
async fn handle_request<T>(
req: Request<Incoming>,
state: Arc<AppState<T>>,
) -> Result<Response<Full<Bytes>>, Infallible>
where
T: DeserializeOwned + Send + 'static,
{
let method = req.method();
let path = req.uri().path();
let query = req.uri().query().unwrap_or("");
debug!("Received request: {} {} (query: {})", method, path, query);
if method != Method::GET {
debug!("Unexpected HTTP method: expected GET, got {}", method);
if let Some(sender) = state.tx.lock().unwrap().take() {
let _ = sender.send(Err(ServerError::UnexpectedMethod {
method: method.clone(),
}));
}
return Ok(error_response(
StatusCode::METHOD_NOT_ALLOWED,
"Method not allowed",
));
}
if path != state.path {
debug!("Unexpected path: expected '{}', got '{}'", state.path, path);
if let Some(sender) = state.tx.lock().unwrap().take() {
let _ = sender.send(Err(ServerError::UnexpectedPath {
expected: state.path.to_string(),
actual: path.to_string(),
}));
}
return Ok(error_response(StatusCode::NOT_FOUND, "Not found"));
}
let params: T = match serde_urlencoded::from_str(query) {
Ok(p) => {
debug!("Successfully parsed OAuth callback parameters");
p
}
Err(e) => {
let error_msg = e.to_string();
debug!("Failed to parse OAuth callback query `{}`: {}", query, e);
if let Some(sender) = state.tx.lock().unwrap().take() {
let _ = sender.send(Err(ServerError::InvalidQuery {
query: query.to_string(),
source: e,
}));
}
return Ok(error_response(StatusCode::BAD_REQUEST, &error_msg));
}
};
if let Some(sender) = state.tx.lock().unwrap().take() {
let _ = sender.send(Ok(params));
}
let response = CallbackResponse {
message: state.message.clone(),
};
Ok(json_response(StatusCode::OK, &response))
}
fn json_response<T: Serialize>(status: StatusCode, body: &T) -> Response<Full<Bytes>> {
let json = serde_json::to_vec(body).unwrap();
Response::builder()
.status(status)
.header("Content-Type", "application/json")
.body(Full::new(Bytes::from(json)))
.unwrap()
}
fn error_response(status: StatusCode, message: &str) -> Response<Full<Bytes>> {
let error = serde_json::json!({ "error": message });
json_response(status, &error)
}
/// Errors produced while waiting for the OAuth callback.
#[derive(Debug, thiserror::Error)]
pub enum ServerError {
    /// The listener could not be bound to `addr`.
    #[error("failed to bind to address `{addr}`: {source}")]
    BindFailed {
        addr: String,
        #[source]
        source: IoError,
    },

    /// An I/O error, converted automatically from [`std::io::Error`].
    #[error(transparent)]
    Io(#[from] IoError),

    /// The callback query string could not be deserialized.
    #[error("invalid OAuth callback query `{query}`: {source}")]
    InvalidQuery {
        query: String,
        #[source]
        source: serde_urlencoded::de::Error,
    },

    /// A request used an HTTP method other than `GET`.
    #[error("unexpected HTTP method: expected `GET`, got {method}")]
    UnexpectedMethod { method: Method },

    /// A request targeted a path other than the configured callback path.
    #[error("unexpected path: expected `{expected}`, got `{actual}`")]
    UnexpectedPath { expected: String, actual: String },

    /// The server stopped before a callback arrived.
    #[error("server received shutdown signal")]
    Shutdown,

    /// No callback arrived within the configured duration.
    #[error("timeout waiting for OAuth authorization callback")]
    Timeout,
}
impl ServerError {
    /// Returns `true` for [`ServerError::Timeout`].
    pub fn is_timeout(&self) -> bool {
        matches!(self, Self::Timeout)
    }

    /// Returns `true` for [`ServerError::InvalidQuery`].
    pub fn is_invalid_query(&self) -> bool {
        self.query().is_some()
    }

    /// Returns `true` for [`ServerError::UnexpectedMethod`].
    pub fn is_unexpected_method(&self) -> bool {
        self.method().is_some()
    }

    /// Returns `true` for [`ServerError::UnexpectedPath`].
    pub fn is_unexpected_path(&self) -> bool {
        self.path().is_some()
    }

    /// Returns `true` for [`ServerError::Shutdown`].
    pub fn is_shutdown(&self) -> bool {
        matches!(self, Self::Shutdown)
    }

    /// Returns `true` for [`ServerError::BindFailed`].
    pub fn is_bind_failed(&self) -> bool {
        matches!(self, Self::BindFailed { .. })
    }

    /// Returns `true` for [`ServerError::Io`].
    pub fn is_io(&self) -> bool {
        matches!(self, Self::Io(_))
    }

    /// The rejected query string, if this is [`ServerError::InvalidQuery`].
    pub fn query(&self) -> Option<&str> {
        if let Self::InvalidQuery { query, .. } = self {
            Some(query.as_str())
        } else {
            None
        }
    }

    /// The rejected HTTP method, if this is [`ServerError::UnexpectedMethod`].
    pub fn method(&self) -> Option<&Method> {
        if let Self::UnexpectedMethod { method } = self {
            Some(method)
        } else {
            None
        }
    }

    /// The `(expected, actual)` paths, if this is [`ServerError::UnexpectedPath`].
    pub fn path(&self) -> Option<(&str, &str)> {
        if let Self::UnexpectedPath { expected, actual } = self {
            Some((expected.as_str(), actual.as_str()))
        } else {
            None
        }
    }
}