use std::collections::HashMap;
use std::net::TcpListener;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
use anyhow::Context as _;
use axum::Router;
use axum::body::Bytes;
use axum::extract::State;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::response::IntoResponse;
use axum::routing::get;
use futures::{SinkExt, StreamExt};
use tokio::sync::{broadcast, mpsc, oneshot};
use tower_http::services::ServeDir;
use super::codec::{CONTROL_TOPIC, CodecKind, ControlMessage, Envelope, WIRE_PROTOCOL_VERSION};
/// Capacity of each per-topic `broadcast` channel that fans server frames out
/// to subscribed WebSocket connections (created in `get_or_create_pub_topic`).
pub(crate) const PUBLISH_BROADCAST_CAPACITY: usize = 1 << 10;
/// Capacity of the per-connection outbound `mpsc` queue that the connection's
/// writer task drains into the WebSocket sink.
pub(crate) const CONNECTION_OUTBOUND_CAPACITY: usize = 1 << 10;
/// Capacity intended for the `mpsc` channels whose senders are registered via
/// `register_sub_sender`. Not used by the code in this module — presumably the
/// creator of those channels reads it (confirm against callers).
pub(crate) const SUBSCRIBE_MPSC_CAPACITY: usize = 1 << 10;
/// State shared between the `WebServer` handle and every connection task.
pub(crate) struct WebServerInner {
    /// Codec every connection uses to encode and decode `Envelope` frames.
    pub(crate) codec: CodecKind,
    /// Server → client topics. A connection that subscribes to a topic takes
    /// its own receiver from that topic's broadcast sender.
    pub(crate) pub_topics: Mutex<HashMap<String, broadcast::Sender<Bytes>>>,
    /// Client → server topics. Payloads a client publishes on a topic are
    /// delivered to every sender registered for it.
    pub(crate) sub_topics: Mutex<HashMap<String, Vec<mpsc::Sender<Bytes>>>>,
}

impl WebServerInner {
    /// Creates empty topic tables for the given codec.
    fn new(codec: CodecKind) -> Self {
        Self {
            codec,
            pub_topics: Mutex::new(HashMap::new()),
            sub_topics: Mutex::new(HashMap::new()),
        }
    }

    /// Returns the broadcast sender for `topic`, creating a channel with
    /// `PUBLISH_BROADCAST_CAPACITY` slots the first time the topic is seen.
    ///
    /// # Panics
    /// Panics if the `pub_topics` mutex is poisoned.
    pub(crate) fn get_or_create_pub_topic(&self, topic: &str) -> broadcast::Sender<Bytes> {
        let mut guard = self.pub_topics.lock().expect("pub_topics lock poisoned");
        // Look up by `&str` first so an existing topic costs no `String`
        // allocation; this is called for every topic in a client Subscribe.
        if let Some(tx) = guard.get(topic) {
            return tx.clone();
        }
        // The channel's initial receiver is dropped immediately; connections
        // obtain their own receivers via `Sender::subscribe`.
        let tx = broadcast::channel(PUBLISH_BROADCAST_CAPACITY).0;
        guard.insert(topic.to_string(), tx.clone());
        tx
    }

    /// Registers `tx` to receive payloads that clients publish on `topic`.
    ///
    /// # Panics
    /// Panics if the `sub_topics` mutex is poisoned.
    pub(crate) fn register_sub_sender(&self, topic: &str, tx: mpsc::Sender<Bytes>) {
        let mut guard = self.sub_topics.lock().expect("sub_topics lock poisoned");
        guard.entry(topic.to_string()).or_default().push(tx);
    }

    /// Delivers `payload` to every listener registered for `topic`.
    ///
    /// Uses `try_send`, so it never waits while holding the lock. If a
    /// listener's queue is full, the frame is dropped for that listener with a
    /// warning and the listener stays registered. A closed listener is
    /// unregistered. Topics with no registered listeners are ignored.
    ///
    /// # Panics
    /// Panics if the `sub_topics` mutex is poisoned.
    fn dispatch_client_payload(&self, topic: &str, payload: Bytes) {
        let mut guard = self.sub_topics.lock().expect("sub_topics lock poisoned");
        if let Some(senders) = guard.get_mut(topic) {
            senders.retain(|tx| match tx.try_send(payload.clone()) {
                Ok(()) => true,
                Err(mpsc::error::TrySendError::Full(_)) => {
                    log::warn!("web_sub: topic '{topic}' listener overloaded — dropping frame");
                    true
                }
                Err(mpsc::error::TrySendError::Closed(_)) => false,
            });
        }
    }
}
#[cfg(feature = "web-tls")]
/// Locations of the PEM files used to terminate TLS, as given to
/// `WebServerBuilder::tls` and read by `load_tls_config`.
struct TlsPaths {
    /// PEM file holding the certificate chain (must contain at least one certificate).
    cert_path: PathBuf,
    /// PEM file holding the private key.
    key_path: PathBuf,
}
pub struct WebServer {
pub(crate) inner: Arc<WebServerInner>,
port: u16,
shutdown_tx: Option<oneshot::Sender<()>>,
thread: Option<JoinHandle<()>>,
historical_noop: bool,
tls: bool,
}
impl WebServer {
pub fn bind(addr: impl Into<String>) -> WebServerBuilder {
WebServerBuilder {
addr: addr.into(),
codec: CodecKind::Bincode,
static_dir: None,
#[cfg(feature = "web-tls")]
tls: None,
}
}
pub fn port(&self) -> u16 {
self.port
}
pub fn codec(&self) -> CodecKind {
self.inner.codec
}
pub fn is_historical_noop(&self) -> bool {
self.historical_noop
}
pub fn is_tls(&self) -> bool {
self.tls
}
pub fn stop(&mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
if let Some(handle) = self.thread.take() {
let _ = handle.join();
}
}
}
impl Drop for WebServer {
fn drop(&mut self) {
self.stop();
}
}
pub struct WebServerBuilder {
addr: String,
codec: CodecKind,
static_dir: Option<PathBuf>,
#[cfg(feature = "web-tls")]
tls: Option<TlsPaths>,
}
impl WebServerBuilder {
pub fn codec(mut self, codec: CodecKind) -> Self {
self.codec = codec;
self
}
pub fn serve_static(mut self, dir: impl Into<PathBuf>) -> Self {
self.static_dir = Some(dir.into());
self
}
#[cfg(feature = "web-tls")]
pub fn tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
self.tls = Some(TlsPaths {
cert_path: cert_path.into(),
key_path: key_path.into(),
});
self
}
pub fn start(self) -> anyhow::Result<WebServer> {
let listener =
TcpListener::bind(&self.addr).with_context(|| format!("web: bind to {}", self.addr))?;
let port = listener.local_addr().context("web: local_addr")?.port();
#[cfg(feature = "web-tls")]
let tls_config = match self.tls {
Some(paths) => Some(load_tls_config(&paths)?),
None => None,
};
#[cfg(feature = "web-tls")]
let is_tls = tls_config.is_some();
#[cfg(not(feature = "web-tls"))]
let is_tls = false;
let inner = Arc::new(WebServerInner::new(self.codec));
let inner_clone = inner.clone();
let static_dir = self.static_dir.clone();
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let handle = std::thread::Builder::new()
.name("wingfoil-web".to_string())
.spawn(move || {
let rt = match tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
{
Ok(rt) => rt,
Err(e) => {
log::error!("web: failed to build runtime: {e}");
return;
}
};
rt.block_on(async move {
let app = build_router(inner_clone, static_dir);
#[cfg(feature = "web-tls")]
if let Some(cfg) = tls_config {
serve_tls(listener, app, cfg, shutdown_rx).await;
return;
}
listener
.set_nonblocking(true)
.expect("listener set_nonblocking");
let tokio_listener =
tokio::net::TcpListener::from_std(listener).expect("web: from_std");
if let Err(e) = axum::serve(tokio_listener, app)
.with_graceful_shutdown(async move {
let _ = shutdown_rx.await;
})
.await
{
log::warn!("web: axum serve exited with error: {e}");
}
});
})
.context("web: spawn server thread")?;
Ok(WebServer {
inner,
port,
shutdown_tx: Some(shutdown_tx),
thread: Some(handle),
historical_noop: false,
tls: is_tls,
})
}
pub fn start_historical(self) -> anyhow::Result<WebServer> {
Ok(WebServer {
inner: Arc::new(WebServerInner::new(self.codec)),
port: 0,
shutdown_tx: None,
thread: None,
historical_noop: true,
tls: false,
})
}
}
/// Builds the HTTP router.
///
/// `GET /ws` performs the WebSocket upgrade. When `static_dir` is given, any
/// other request falls through to a `ServeDir` rooted at that directory.
fn build_router(inner: Arc<WebServerInner>, static_dir: Option<PathBuf>) -> Router {
    let base = Router::new()
        .route("/ws", get(ws_handler))
        .with_state(inner);
    match static_dir {
        Some(root) => base.fallback_service(ServeDir::new(root)),
        None => base,
    }
}
async fn ws_handler(
ws: WebSocketUpgrade,
State(inner): State<Arc<WebServerInner>>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, inner))
}
/// Drives one upgraded WebSocket connection until the client goes away.
///
/// Every outbound frame (the initial `Hello` and all per-topic forwarders)
/// goes through one bounded mpsc queue, drained by a dedicated writer task.
/// Each inbound frame is decoded as an `Envelope`:
/// - envelopes on `CONTROL_TOPIC` carry control messages, handled by `apply_control`;
/// - any other topic goes to `WebServerInner::dispatch_client_payload`.
///
/// On exit all forwarders are aborted and the writer is awaited.
async fn handle_socket(socket: WebSocket, inner: Arc<WebServerInner>) {
    let codec = inner.codec;
    let (mut sink, mut stream) = socket.split();
    let (out_tx, mut out_rx) = mpsc::channel::<Bytes>(CONNECTION_OUTBOUND_CAPACITY);

    // Sole owner of the sink. It stops on the first failed write or once every
    // clone of `out_tx` is dropped, then attempts a clean close.
    let writer = tokio::spawn(async move {
        loop {
            let Some(frame) = out_rx.recv().await else {
                break;
            };
            if sink.send(Message::Binary(frame)).await.is_err() {
                break;
            }
        }
        let _ = sink.close().await;
    });

    // Announce codec and protocol version before anything else is queued.
    let greeting = ControlMessage::Hello {
        codec,
        version: WIRE_PROTOCOL_VERSION,
    };
    let greeting_frame = match encode_control_frame(codec, &greeting) {
        Ok(frame) => frame,
        Err(e) => {
            log::error!("web: encode hello failed: {e}");
            writer.abort();
            return;
        }
    };
    if out_tx.send(greeting_frame).await.is_err() {
        return;
    }

    // One forwarding task per subscribed server topic.
    let mut forwarders: HashMap<String, tokio::task::JoinHandle<()>> = HashMap::new();
    while let Some(incoming) = stream.next().await {
        let msg = match incoming {
            Ok(msg) => msg,
            Err(e) => {
                log::debug!("web: ws recv error: {e}");
                break;
            }
        };
        let raw: Bytes = match msg {
            Message::Binary(b) => b,
            Message::Text(t) => Bytes::copy_from_slice(t.as_bytes()),
            Message::Close(_) => break,
            Message::Ping(_) | Message::Pong(_) => continue,
        };
        let env: Envelope = match codec.decode(&raw) {
            Ok(env) => env,
            Err(e) => {
                log::warn!("web: bad envelope from client: {e}");
                continue;
            }
        };
        if env.topic != CONTROL_TOPIC {
            // Ordinary data: hand it to server-side listeners of that topic.
            inner.dispatch_client_payload(&env.topic, Bytes::from(env.payload));
            continue;
        }
        let ctrl: ControlMessage = match codec.decode(&env.payload) {
            Ok(ctrl) => ctrl,
            Err(e) => {
                log::warn!("web: bad control payload: {e}");
                continue;
            }
        };
        apply_control(ctrl, &inner, &out_tx, &mut forwarders);
    }

    for task in forwarders.into_values() {
        task.abort();
    }
    drop(out_tx);
    let _ = writer.await;
}

/// Applies one decoded control message to a connection's forwarder set.
///
/// - `Subscribe` spawns a `forward_broadcast` task for each topic that is not
///   already being forwarded.
/// - `Unsubscribe` aborts the matching tasks.
/// - A `Hello` sent by the client is ignored.
fn apply_control(
    ctrl: ControlMessage,
    inner: &WebServerInner,
    outbound_tx: &mpsc::Sender<Bytes>,
    forwarders: &mut HashMap<String, tokio::task::JoinHandle<()>>,
) {
    match ctrl {
        ControlMessage::Subscribe { topics } => {
            for topic in topics {
                if forwarders.contains_key(&topic) {
                    continue;
                }
                let rx = inner.get_or_create_pub_topic(&topic).subscribe();
                let task = tokio::spawn(forward_broadcast(topic.clone(), rx, outbound_tx.clone()));
                forwarders.insert(topic, task);
            }
        }
        ControlMessage::Unsubscribe { topics } => {
            for task in topics
                .into_iter()
                .filter_map(|topic| forwarders.remove(&topic))
            {
                task.abort();
            }
        }
        ControlMessage::Hello { .. } => {}
    }
}
async fn forward_broadcast(
topic: String,
mut rx: broadcast::Receiver<Bytes>,
out: mpsc::Sender<Bytes>,
) {
loop {
match rx.recv().await {
Ok(bytes) => match out.try_send(bytes) {
Ok(()) => {}
Err(mpsc::error::TrySendError::Full(_)) => {
log::warn!("web_pub: client outbound full, dropping frame on '{topic}'");
}
Err(mpsc::error::TrySendError::Closed(_)) => break,
},
Err(broadcast::error::RecvError::Lagged(n)) => {
log::warn!("web_pub: client lagged by {n} frames on '{topic}'");
}
Err(broadcast::error::RecvError::Closed) => break,
}
}
}
/// Encodes a control message as a complete wire frame.
///
/// `ctrl` is encoded with `codec` and used as the payload of an `Envelope` on
/// `CONTROL_TOPIC` with `time_ns` 0. That envelope is then encoded with the
/// same codec.
///
/// # Errors
/// Propagates a failure from either `codec.encode` call.
fn encode_control_frame(codec: CodecKind, ctrl: &ControlMessage) -> anyhow::Result<Bytes> {
    let envelope = Envelope {
        topic: CONTROL_TOPIC.to_string(),
        time_ns: 0,
        payload: codec.encode(ctrl)?,
    };
    let wire = codec.encode(&envelope)?;
    Ok(Bytes::from(wire))
}
#[cfg(feature = "web-tls")]
/// Reads the PEM certificate chain and private key named by `paths` and builds
/// a rustls server configuration.
///
/// The configuration uses the `ring` crypto provider, the safe default
/// protocol versions and no client authentication.
///
/// # Errors
/// Fails if either file cannot be opened or parsed, if the certificate file
/// contains no certificates, if the key file contains no private key, or if
/// rustls rejects the configuration.
fn load_tls_config(paths: &TlsPaths) -> anyhow::Result<axum_server::tls_rustls::RustlsConfig> {
    use std::fs::File;
    use std::io::BufReader;
    use rustls::ServerConfig;
    use rustls::pki_types::{CertificateDer, PrivateKeyDer};

    let cert_path = &paths.cert_path;
    let key_path = &paths.key_path;

    let mut cert_reader = BufReader::new(
        File::open(cert_path)
            .with_context(|| format!("web-tls: open cert {}", cert_path.display()))?,
    );
    let certs = rustls_pemfile::certs(&mut cert_reader)
        .collect::<Result<Vec<CertificateDer<'static>>, _>>()
        .with_context(|| format!("web-tls: parse cert {}", cert_path.display()))?;
    anyhow::ensure!(
        !certs.is_empty(),
        "web-tls: no certificates found in {}",
        cert_path.display()
    );

    let mut key_reader = BufReader::new(
        File::open(key_path)
            .with_context(|| format!("web-tls: open key {}", key_path.display()))?,
    );
    let parsed_key = rustls_pemfile::private_key(&mut key_reader)
        .with_context(|| format!("web-tls: parse key {}", key_path.display()))?;
    let key: PrivateKeyDer<'static> = match parsed_key {
        Some(key) => key,
        None => anyhow::bail!("web-tls: no private key found in {}", key_path.display()),
    };

    let provider = rustls::crypto::ring::default_provider();
    let server_config = ServerConfig::builder_with_provider(provider.into())
        .with_safe_default_protocol_versions()
        .context("web-tls: rustls protocol versions")?
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .context("web-tls: build rustls ServerConfig")?;
    let shared = Arc::new(server_config);
    Ok(axum_server::tls_rustls::RustlsConfig::from_config(shared))
}
#[cfg(feature = "web-tls")]
/// Serves `app` over TLS on the already-bound `listener` until shutdown.
///
/// A watcher task turns the one-shot shutdown signal (or the sender being
/// dropped) into an axum-server graceful shutdown with a 5-second grace period.
/// A server error is logged, not returned.
async fn serve_tls(
    listener: TcpListener,
    app: Router,
    config: axum_server::tls_rustls::RustlsConfig,
    shutdown_rx: oneshot::Receiver<()>,
) {
    let server_handle = axum_server::Handle::new();
    tokio::spawn({
        let server_handle = server_handle.clone();
        async move {
            let _ = shutdown_rx.await;
            server_handle.graceful_shutdown(Some(std::time::Duration::from_secs(5)));
        }
    });
    let outcome = axum_server::from_tcp_rustls(listener, config)
        .handle(server_handle)
        .serve(app.into_make_service())
        .await;
    if let Err(e) = outcome {
        log::warn!("web: axum-server (TLS) exited with error: {e}");
    }
}