use std::convert::Infallible;
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use clap::Command;
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
use rmcp::transport::streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tower_service::Service as TowerService;
use crate::Result;
use crate::config::Config;
use crate::server::BrontesServer;
/// How long [`serve_http`] waits, after cancellation, for in-flight connections to finish
/// before it stops waiting on them.
pub const SHUTDOWN_GRACE: Duration = Duration::from_millis(5_000);
/// Source of inbound connections for the streamable-HTTP server.
///
/// Each successful call yields one connection, already wrapped in [`TokioIo`] so hyper can
/// drive it, together with the peer's socket address. Implementors must be `Send + Sync +
/// 'static`, and the future returned by [`Acceptor::accept`] must be `Send`.
pub trait Acceptor: 'static + Send + Sync {
    /// Produces the next inbound connection and its peer address.
    fn accept(
        &self,
    ) -> impl Send + Future<Output = io::Result<(TokioIo<TcpStream>, SocketAddr)>>;
}
/// [`Acceptor`] backed by a Tokio [`TcpListener`].
pub struct TokioTcpAcceptor {
    listener: TcpListener,
}

impl TokioTcpAcceptor {
    /// Wraps an existing listener; no I/O is performed here.
    pub const fn new(listener: TcpListener) -> Self {
        Self { listener }
    }
}

impl Acceptor for TokioTcpAcceptor {
    /// Waits on the underlying listener and wraps the accepted stream in [`TokioIo`].
    async fn accept(&self) -> io::Result<(TokioIo<TcpStream>, SocketAddr)> {
        self.listener
            .accept()
            .await
            .map(|(tcp, remote)| (TokioIo::new(tcp), remote))
    }
}
pub async fn bind_default_acceptor(addr: SocketAddr) -> Result<TokioTcpAcceptor> {
let listener = TcpListener::bind(addr)
.await
.map_err(|e| crate::Error::Io {
context: format!("bind streamable HTTP listener on {addr}"),
source: e,
})?;
Ok(TokioTcpAcceptor::new(listener))
}
/// Binds `addr` and serves the streamable-HTTP MCP server on it until `cancel` fires.
///
/// This is [`serve_http_with`] using a [`TokioTcpAcceptor`] and the default
/// [`SHUTDOWN_GRACE`] drain window.
///
/// # Errors
///
/// Propagates bind failures from [`bind_default_acceptor`] and any error returned by
/// [`serve_http_with`].
pub async fn serve_http(
    cli: Command,
    cfg: Config,
    addr: SocketAddr,
    cancel: CancellationToken,
    extra_allowed_hosts: Vec<String>,
) -> Result<()> {
    // Bind first, so an unusable address is reported before any server state is built.
    let tcp_acceptor = bind_default_acceptor(addr).await?;
    let grace = SHUTDOWN_GRACE;
    serve_http_with(cli, cfg, tcp_acceptor, cancel, extra_allowed_hosts, grace).await
}
pub async fn serve_http_with<A>(
cli: Command,
cfg: Config,
acceptor: A,
cancel: CancellationToken,
extra_allowed_hosts: Vec<String>,
shutdown_grace: Duration,
) -> Result<()>
where
A: Acceptor,
{
BrontesServer::new(cli.clone(), cfg.clone())?;
let factory_cli = cli;
let factory_cfg = cfg;
let session_manager = Arc::new(LocalSessionManager::default());
let mut allowed_hosts = StreamableHttpServerConfig::default().allowed_hosts;
allowed_hosts.extend(extra_allowed_hosts);
let config = StreamableHttpServerConfig::default()
.with_cancellation_token(cancel.clone())
.with_allowed_hosts(allowed_hosts);
let service: StreamableHttpService<BrontesServer, LocalSessionManager> =
StreamableHttpService::new(
move || {
BrontesServer::new(factory_cli.clone(), factory_cfg.clone()).map_err(|e| {
std::io::Error::other(format!("brontes server construction: {e}"))
})
},
session_manager,
config,
);
let tracker = TaskTracker::new();
loop {
tokio::select! {
biased;
() = cancel.cancelled() => {
tracing::info!("cancellation token fired; stopping accept loop");
break;
}
accepted = acceptor.accept() => {
match accepted {
Ok((io, peer)) => {
let conn_service = service.clone();
let conn_cancel = cancel.clone();
tracker.spawn(async move {
let svc = service_fn(move |req: hyper::Request<Incoming>| {
let mut per_call = conn_service.clone();
let fut = TowerService::call(&mut per_call, req);
async move {
let resp = match fut.await {
Ok(r) => r,
Err(never) => match never {},
};
Ok::<_, Infallible>(resp)
}
});
let conn = hyper::server::conn::http1::Builder::new()
.serve_connection(io, svc);
tokio::pin!(conn);
tokio::select! {
res = conn.as_mut() => {
if let Err(e) = res {
tracing::debug!(error = %e, peer = %peer, "connection ended with error");
}
}
() = conn_cancel.cancelled() => {
conn.as_mut().graceful_shutdown();
if let Err(e) = conn.as_mut().await {
tracing::debug!(error = %e, peer = %peer, "connection shutdown error");
}
}
}
});
}
Err(e) => {
tracing::warn!(error = %e, "accept failed; continuing");
}
}
}
}
}
tracker.close();
if tokio::time::timeout(shutdown_grace, tracker.wait())
.await
.is_ok()
{
tracing::info!("HTTP server drained cleanly");
} else {
tracing::warn!(
grace = ?shutdown_grace,
"HTTP server connections did not drain within {shutdown_grace:?}; abandoning"
);
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::SHUTDOWN_GRACE;
    use std::time::Duration;

    /// Pins the drain window at five seconds; per the test name, this mirrors ophis.
    #[test]
    fn shutdown_grace_matches_ophis_5_seconds() {
        let expected = Duration::from_millis(5_000);
        assert_eq!(SHUTDOWN_GRACE, expected);
    }
}