use std::{
net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener},
path::PathBuf,
sync::Arc,
time::Duration,
};
use anyhow::Result;
use axum::{extract::DefaultBodyLimit, routing::get, Router};
use axum_server::Handle;
use tokio::sync::Mutex;
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use url::Url;
use super::inbox;
#[cfg(feature = "link-compat")]
use super::link;
use super::persistence::EntryRepository;
use super::waiting_list::WaitingList;
/// Default for `Config::link_timeout`: ten minutes.
const DEFAULT_LINK_TIMEOUT: Duration = Duration::from_secs(600);
/// Default for `Config::inbox_timeout`: twenty-five seconds.
const DEFAULT_INBOX_TIMEOUT: Duration = Duration::from_secs(25);
/// Default for `Config::inbox_cache_ttl`: five minutes.
const DEFAULT_INBOX_CACHE_TTL: Duration = Duration::from_secs(300);
/// Default for `Config::max_body_size`, in bytes (2 KiB).
const DEFAULT_MAX_BODY_SIZE: usize = 2_048;
/// Default for `Config::max_entries`.
const DEFAULT_MAX_ENTRIES: usize = 10_000;
/// State shared with every request handler.
///
/// Cloning copies the configuration and shares the waiting list through its `Arc`,
/// so all clones observe the same pending entries.
#[derive(Clone)]
pub(crate) struct AppState {
    /// Configuration the relay was started with.
    pub config: Config,
    /// Pending entries, guarded by an async (tokio) mutex.
    pub pending_list: Arc<Mutex<WaitingList>>,
}
impl AppState {
    /// Builds the shared state from `config`.
    ///
    /// Creates an `EntryRepository` from `config.persist_db` and `config.max_entries`
    /// (SQLite-backed when a path is given, in-memory otherwise, per the log messages
    /// below) and wraps it in a `WaitingList`.
    ///
    /// # Errors
    /// Propagates any error returned by `EntryRepository::new`.
    pub fn new(config: Config) -> anyhow::Result<Self> {
        match config.persist_db.as_deref() {
            Some(db_path) => {
                tracing::info!(path = %db_path.display(), "Persistence enabled with SQLite")
            }
            None => tracing::debug!("Using in-memory storage (no persistence)"),
        }
        let repository = EntryRepository::new(config.persist_db.as_deref(), config.max_entries)?;
        Ok(Self {
            pending_list: Arc::new(Mutex::new(WaitingList::new(repository))),
            config,
        })
    }
}
/// Runtime configuration of the relay.
///
/// Normally assembled through `HttpRelayBuilder`; any field not set there keeps the
/// value produced by [`Config::default`].
#[derive(Debug, Clone)]
pub(crate) struct Config {
    /// IP address the HTTP listener binds to.
    pub bind_address: IpAddr,
    /// TCP port to listen on; `0` asks the OS for any free port.
    pub http_port: u16,
    /// Timeout for the `/link` endpoints (used by the `link` module).
    pub link_timeout: Duration,
    /// Timeout for the `/inbox` endpoints (used by the `inbox` module).
    pub inbox_timeout: Duration,
    /// Cache time-to-live for inbox entries (used by the `inbox` module).
    pub inbox_cache_ttl: Duration,
    /// Maximum request body size in bytes, enforced through axum's `DefaultBodyLimit`.
    pub max_body_size: usize,
    /// Entry limit handed to `EntryRepository::new`.
    pub max_entries: usize,
    /// SQLite database path; `None` keeps entries in memory only.
    pub persist_db: Option<PathBuf>,
    /// When `true`, a permissive CORS layer is added to every route.
    pub cors_allow_all: bool,
}
impl Default for Config {
    /// Loopback listener on an OS-assigned port, in-memory storage, CORS disabled,
    /// and the module's `DEFAULT_*` limits and timeouts.
    fn default() -> Self {
        let loopback = IpAddr::from(Ipv4Addr::LOCALHOST);
        Self {
            bind_address: loopback,
            http_port: 0,
            link_timeout: DEFAULT_LINK_TIMEOUT,
            inbox_timeout: DEFAULT_INBOX_TIMEOUT,
            inbox_cache_ttl: DEFAULT_INBOX_CACHE_TTL,
            max_body_size: DEFAULT_MAX_BODY_SIZE,
            max_entries: DEFAULT_MAX_ENTRIES,
            persist_db: None,
            cors_allow_all: false,
        }
    }
}
#[derive(Debug, Default)]
pub struct HttpRelayBuilder(Config);
impl HttpRelayBuilder {
pub fn bind_address(mut self, addr: IpAddr) -> Self {
self.0.bind_address = addr;
self
}
pub fn http_port(mut self, port: u16) -> Self {
self.0.http_port = port;
self
}
pub fn link_timeout(mut self, timeout: Duration) -> Self {
self.0.link_timeout = timeout;
self
}
pub fn inbox_timeout(mut self, timeout: Duration) -> Self {
self.0.inbox_timeout = timeout;
self
}
pub fn inbox_cache_ttl(mut self, ttl: Duration) -> Self {
self.0.inbox_cache_ttl = ttl;
self
}
pub fn max_body_size(mut self, size: usize) -> Self {
self.0.max_body_size = size;
self
}
pub fn max_entries(mut self, max: usize) -> Self {
self.0.max_entries = max;
self
}
pub fn persist_db(mut self, path: Option<PathBuf>) -> Self {
self.0.persist_db = path;
self
}
pub fn cors_allow_all(mut self, cors_allow_all: bool) -> Self {
self.0.cors_allow_all = cors_allow_all;
self
}
pub async fn run(self) -> Result<HttpRelay> {
HttpRelay::start(self.0).await
}
}
pub struct HttpRelay {
pub(crate) http_handle: Handle<SocketAddr>,
http_address: SocketAddr,
}
impl HttpRelay {
fn build_router(state: AppState) -> Router {
let max_body_size = state.config.max_body_size;
let router = Router::new()
.route("/", get(|| async { "Http Relay" }))
.route("/inbox/{id}/ack", get(inbox::ack_handler))
.route("/inbox/{id}/await", get(inbox::await_handler))
.route(
"/inbox/{id}",
get(inbox::get_handler)
.post(inbox::post_handler)
.delete(inbox::delete_handler),
);
#[cfg(feature = "link-compat")]
let router = router.route(
"/link/{id}",
get(link::get_handler).post(link::post_handler),
);
let router = router
.layer(DefaultBodyLimit::max(max_body_size))
.layer(TraceLayer::new_for_http());
let router = if state.config.cors_allow_all {
router.layer(CorsLayer::very_permissive())
} else {
router
};
router.with_state(state)
}
#[cfg(test)]
pub(crate) fn create_app(config: Config) -> Result<(Router, AppState)> {
let app_state = AppState::new(config)?;
let app = Self::build_router(app_state.clone());
Ok((app, app_state))
}
#[cfg(test)]
pub(crate) fn create_test_server(config: Config) -> (axum_test::TestServer, AppState) {
let (app, state) = Self::create_app(config).unwrap();
let server = axum_test::TestServer::new(app).unwrap();
(server, state)
}
async fn start(config: Config) -> Result<Self> {
let app_state = AppState::new(config.clone())?;
let app = Self::build_router(app_state.clone());
let http_handle = Handle::new();
let shutdown_handle = http_handle.clone();
let http_listener =
TcpListener::bind(SocketAddr::new(config.bind_address, config.http_port))?;
http_listener.set_nonblocking(true)?;
let http_address = http_listener.local_addr()?;
let server = axum_server::from_tcp(http_listener)?;
tokio::spawn(async move {
server
.handle(http_handle.clone())
.serve(app.into_make_service())
.await
.map_err(|error| tracing::error!(?error, "HttpRelay http server error"))
});
let cleanup_interval = Duration::from_secs(15);
let pending_list = app_state.pending_list.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(cleanup_interval).await;
let removed = pending_list.lock().await.cleanup_expired();
if removed > 0 {
tracing::debug!(removed, "Cleaned up expired entries");
}
}
});
Ok(Self {
http_handle: shutdown_handle,
http_address,
})
}
pub fn builder() -> HttpRelayBuilder {
HttpRelayBuilder::default()
}
pub fn http_address(&self) -> SocketAddr {
self.http_address
}
pub fn local_url(&self) -> Url {
Url::parse(&format!("http://localhost:{}", self.http_address.port()))
.expect("hardcoded URL scheme and localhost are always valid")
}
#[cfg(feature = "link-compat")]
pub fn local_link_url(&self) -> Url {
let mut url = self.local_url();
let mut segments = url
.path_segments_mut()
.expect("http URLs always have path segments");
segments.push("link");
drop(segments);
url
}
pub async fn shutdown(self) -> anyhow::Result<()> {
self.http_handle
.graceful_shutdown(Some(Duration::from_secs(1)));
Ok(())
}
}
impl Drop for HttpRelay {
fn drop(&mut self) {
self.http_handle.shutdown();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;

    /// Sends `GET /` carrying a cross-origin `Origin` header.
    async fn get_root_from_foreign_origin(
        server: &axum_test::TestServer,
    ) -> axum_test::TestResponse {
        server
            .get("/")
            .add_header("Origin", "https://example.com")
            .await
    }

    #[tokio::test]
    async fn test_root_returns_http_relay() {
        let (server, _app_state) = HttpRelay::create_test_server(Config::default());
        let root = server.get("/").await;
        assert_eq!(root.status_code(), 200);
        assert_eq!(root.text(), "Http Relay");
    }

    #[tokio::test]
    async fn test_no_cors_headers_by_default() {
        let (server, _app_state) = HttpRelay::create_test_server(Config::default());
        let response = get_root_from_foreign_origin(&server).await;
        let allow_origin = response.maybe_header("access-control-allow-origin");
        assert!(allow_origin.is_none());
    }

    #[tokio::test]
    async fn test_cors_allow_all_adds_headers() {
        let mut config = Config::default();
        config.cors_allow_all = true;
        let (server, _app_state) = HttpRelay::create_test_server(config);
        let response = get_root_from_foreign_origin(&server).await;
        let allow_origin = response.maybe_header("access-control-allow-origin");
        assert!(allow_origin.is_some());
    }

    #[tokio::test]
    async fn test_start_and_shutdown() {
        let relay = HttpRelay::start(Config::default()).await.unwrap();
        let mut stream = TcpStream::connect(relay.http_address()).await.unwrap();
        let request: &[u8] = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        stream.write_all(request).await.unwrap();
        let mut raw_response = String::new();
        stream.read_to_string(&mut raw_response).await.unwrap();
        assert!(raw_response.starts_with("HTTP/1.1 200"));
        assert!(raw_response.contains("Http Relay"));
        relay.shutdown().await.unwrap();
    }
}