use std::any::Any;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
use http::Response;
use http::StatusCode;
use http::header;
use http_body_util::Full;
use hyper_util::rt::TokioExecutor;
use hyper_util::rt::TokioIo;
use hyper_util::server::conn::auto::Builder as AutoBuilder;
use tokio::net::TcpListener;
use tower::Service;
use tower::ServiceBuilder;
use tower_http::catch_panic::CatchPanicLayer;
use crate::codec::content_type;
use crate::dispatcher::Dispatcher;
use crate::error::ConnectError;
use crate::error::ErrorCode;
use crate::router::Router;
use crate::service::ConnectRpcService;
/// Remote socket address of the connection a request arrived on.
///
/// The server stores a value of this type in each request's extensions, so
/// handlers can look it up by type.
///
/// The wrapped [`SocketAddr`] is `Copy`, `Eq` and `Hash`, so this newtype
/// derives the same traits; it can be copied freely, compared, and used as a
/// map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PeerAddr(pub SocketAddr);
/// Certificate chain presented by the client during the TLS handshake.
///
/// Only available with the `server-tls` feature. The server stores it in a
/// request's extensions when the TLS session reports peer certificates; the
/// chain is shared behind an `Arc`, so cloning does not copy certificate bytes.
#[cfg(feature = "server-tls")]
#[derive(Clone, Debug)]
pub struct PeerCerts(pub Arc<[rustls::pki_types::CertificateDer<'static>]>);
/// Per-connection peer details, captured once when the connection is accepted
/// and then published into the extensions of every request served on it.
#[derive(Clone, Debug)]
struct PeerInfo {
/// Remote address returned by the listener's `accept`.
addr: SocketAddr,
/// Client certificate chain, if the TLS session reported one; always `None`
/// for plaintext connections.
#[cfg(feature = "server-tls")]
certs: Option<Arc<[rustls::pki_types::CertificateDer<'static>]>>,
}
impl PeerInfo {
    /// Publishes this connection's peer details into a request's extension map.
    ///
    /// A [`PeerAddr`] is always stored. With the `server-tls` feature enabled,
    /// a [`PeerCerts`] is stored as well, but only when a certificate chain was
    /// recorded for the connection.
    fn insert_into(&self, ext: &mut http::Extensions) {
        let remote = self.addr;
        ext.insert(PeerAddr(remote));

        // Sharing the chain is a reference-count bump, not a copy of the DER bytes.
        #[cfg(feature = "server-tls")]
        if let Some(peer_certs) = self.certs.as_ref().map(Arc::clone).map(PeerCerts) {
            ext.insert(peer_certs);
        }
    }
}
/// Default limit on how long a TLS handshake may take (10 seconds).
///
/// Used as the initial `tls_handshake_timeout` of both [`Server`] and
/// [`BoundServer`] until overridden through their builder methods.
#[cfg(feature = "server-tls")]
pub const DEFAULT_TLS_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
/// A ConnectRPC server that owns its service but has not bound a socket yet.
///
/// Configure it with the builder methods, then call `serve` with the address
/// to listen on.
pub struct Server {
/// Service every accepted connection is dispatched to.
service: ConnectRpcService,
/// Whether HTTP/1.1 keep-alive is enabled on served connections.
http1_keep_alive: bool,
/// TLS configuration; `None` means connections are served in plaintext.
#[cfg(feature = "server-tls")]
tls_config: Option<Arc<rustls::ServerConfig>>,
/// Maximum time allowed for a TLS handshake to complete.
#[cfg(feature = "server-tls")]
tls_handshake_timeout: std::time::Duration,
}
impl Server {
    /// Creates a server that dispatches requests through `router`.
    ///
    /// HTTP/1.1 keep-alive starts enabled and TLS starts disabled.
    pub fn new(router: Router) -> Self {
        Self::from_service(ConnectRpcService::new(router))
    }

    /// Creates a server around an already constructed [`ConnectRpcService`].
    ///
    /// HTTP/1.1 keep-alive starts enabled and TLS starts disabled.
    pub fn from_service(service: ConnectRpcService) -> Self {
        Self {
            service,
            http1_keep_alive: true,
            #[cfg(feature = "server-tls")]
            tls_config: None,
            #[cfg(feature = "server-tls")]
            tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
        }
    }

    /// Serves connections over TLS using `config`.
    #[cfg(feature = "server-tls")]
    #[must_use]
    pub fn with_tls(mut self, config: Arc<rustls::ServerConfig>) -> Self {
        self.tls_config = Some(config);
        self
    }

    /// Sets the maximum time a TLS handshake may take before it is abandoned.
    ///
    /// Defaults to [`DEFAULT_TLS_HANDSHAKE_TIMEOUT`].
    #[cfg(feature = "server-tls")]
    #[must_use]
    pub fn tls_handshake_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.tls_handshake_timeout = timeout;
        self
    }

    /// Enables or disables HTTP/1.1 keep-alive on served connections.
    ///
    /// Enabled by default. Mirrors [`BoundServer::http1_keep_alive`] so the
    /// setting can be changed on either server type.
    #[must_use]
    pub fn http1_keep_alive(mut self, enabled: bool) -> Self {
        self.http1_keep_alive = enabled;
        self
    }

    /// Returns the [`Router`] held by this server's service.
    pub fn router(&self) -> &Router {
        self.service.dispatcher()
    }

    /// Binds `addr` and serves connections on it.
    ///
    /// No shutdown signal is installed, so the future only completes if the
    /// accept loop fails.
    ///
    /// # Errors
    ///
    /// Returns an error if binding `addr` fails or if accepting a connection
    /// fails with a non-transient error.
    pub async fn serve(
        self,
        addr: SocketAddr,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let listener = TcpListener::bind(addr).await?;
        // Log the address that was actually bound: when the caller asks for
        // port 0 the requested address does not name the real port. Fall back
        // to the requested address if the socket cannot report its own.
        let addr = listener.local_addr().unwrap_or(addr);

        #[cfg(feature = "server-tls")]
        let tls_acceptor = self.tls_config.map(tokio_rustls::TlsAcceptor::from);
        #[cfg(not(feature = "server-tls"))]
        let tls_acceptor: Option<()> = None;

        let scheme = if tls_acceptor.is_some() {
            "https"
        } else {
            "http"
        };
        tracing::info!("ConnectRPC server listening on {scheme}://{addr}");

        serve_with_listener(
            listener,
            self.service,
            tls_acceptor,
            self.http1_keep_alive,
            #[cfg(feature = "server-tls")]
            self.tls_handshake_timeout,
            None,
        )
        .await
    }

    /// Wraps an already bound `listener` in a [`BoundServer`] with default
    /// settings (keep-alive enabled, TLS disabled).
    #[must_use]
    pub fn from_listener(listener: TcpListener) -> BoundServer {
        BoundServer {
            listener,
            http1_keep_alive: true,
            #[cfg(feature = "server-tls")]
            tls_config: None,
            #[cfg(feature = "server-tls")]
            tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
        }
    }

    /// Binds `addr` and returns a [`BoundServer`] with default settings.
    ///
    /// Binding first lets the caller read the chosen port through
    /// [`BoundServer::local_addr`] before serving starts.
    ///
    /// # Errors
    ///
    /// Returns an error if the address cannot be bound.
    pub async fn bind(
        addr: impl tokio::net::ToSocketAddrs,
    ) -> Result<BoundServer, Box<dyn std::error::Error + Send + Sync>> {
        let listener = TcpListener::bind(addr).await?;
        Ok(Self::from_listener(listener))
    }
}
/// A server whose listening socket is already bound but which has not been
/// given a service yet; the router or service is supplied to one of the
/// `serve*` methods.
pub struct BoundServer {
/// The bound listening socket.
listener: TcpListener,
/// Whether HTTP/1.1 keep-alive is enabled on served connections.
http1_keep_alive: bool,
/// TLS configuration; `None` means connections are served in plaintext.
#[cfg(feature = "server-tls")]
tls_config: Option<Arc<rustls::ServerConfig>>,
/// Maximum time allowed for a TLS handshake to complete.
#[cfg(feature = "server-tls")]
tls_handshake_timeout: std::time::Duration,
}
impl BoundServer {
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.listener.local_addr()
}
#[cfg(feature = "server-tls")]
#[must_use]
pub fn with_tls(mut self, config: Arc<rustls::ServerConfig>) -> Self {
self.tls_config = Some(config);
self
}
#[cfg(feature = "server-tls")]
#[must_use]
pub fn tls_handshake_timeout(mut self, timeout: std::time::Duration) -> Self {
self.tls_handshake_timeout = timeout;
self
}
#[must_use]
pub fn http1_keep_alive(mut self, enabled: bool) -> Self {
self.http1_keep_alive = enabled;
self
}
pub async fn serve(
self,
router: Router,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self.serve_with_service(ConnectRpcService::new(router))
.await
}
pub async fn serve_with_graceful_shutdown<F>(
self,
router: Router,
signal: F,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
F: Future<Output = ()> + Send + 'static,
{
self.serve_with_service_and_shutdown(ConnectRpcService::new(router), signal)
.await
}
pub async fn serve_with_service<D: Dispatcher>(
self,
service: ConnectRpcService<D>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
#[cfg(feature = "server-tls")]
let tls_acceptor = self.tls_config.map(tokio_rustls::TlsAcceptor::from);
#[cfg(not(feature = "server-tls"))]
let tls_acceptor: Option<()> = None;
serve_with_listener(
self.listener,
service,
tls_acceptor,
self.http1_keep_alive,
#[cfg(feature = "server-tls")]
self.tls_handshake_timeout,
None,
)
.await
}
pub async fn serve_with_service_and_shutdown<D, F>(
self,
service: ConnectRpcService<D>,
signal: F,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
D: Dispatcher,
F: Future<Output = ()> + Send + 'static,
{
#[cfg(feature = "server-tls")]
let tls_acceptor = self.tls_config.map(tokio_rustls::TlsAcceptor::from);
#[cfg(not(feature = "server-tls"))]
let tls_acceptor: Option<()> = None;
serve_with_listener(
self.listener,
service,
tls_acceptor,
self.http1_keep_alive,
#[cfg(feature = "server-tls")]
self.tls_handshake_timeout,
Some(Box::pin(signal)),
)
.await
}
}
/// The service type actually served on each connection: a `ConnectRpcService`
/// wrapped in tower-http's `CatchPanic`. The panic handler is spelled as a
/// plain function pointer so the whole type can be named.
type WrappedService<D> = tower_http::catch_panic::CatchPanic<
ConnectRpcService<D>,
fn(Box<dyn Any + Send>) -> Response<Full<Bytes>>,
>;
/// Serves HTTP on one accepted (and, for TLS, already handshaken) stream until
/// the connection ends.
///
/// Every request on the connection has `peer` published into its extensions
/// before it is handed to a fresh clone of `service`. The outcome is only
/// traced; connection-level errors are not propagated to the caller.
async fn serve_accepted_stream<D, S>(
    io: S,
    peer: PeerInfo,
    service: Arc<WrappedService<D>>,
    http1_keep_alive: bool,
) where
    D: Dispatcher,
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    // The address is `Copy`; keep it for tracing so `peer` can move into the
    // per-request closure.
    let remote = peer.addr;
    tracing::trace!(remote_addr = %remote, "Accepted new connection");

    let per_request = hyper::service::service_fn(move |mut req| {
        peer.insert_into(req.extensions_mut());
        // `Service::call` needs `&mut self`, so each request gets its own clone.
        let mut handler = (*service).clone();
        async move { handler.call(req).await }
    });

    let mut http = AutoBuilder::new(TokioExecutor::new());
    http.http1().keep_alive(http1_keep_alive);

    let outcome = http.serve_connection(TokioIo::new(io), per_request).await;
    if let Err(err) = outcome {
        tracing::trace!(
            remote_addr = %remote,
            error = %err,
            "Connection ended with error",
        );
        return;
    }
    tracing::trace!(remote_addr = %remote, "Connection completed normally");
}
/// Optional TLS acceptor. With the `server-tls` feature this carries a real
/// `tokio_rustls::TlsAcceptor`; without it the payload is `()` so the accept
/// loop keeps a single signature in both configurations.
#[cfg(feature = "server-tls")]
type MaybeTlsAcceptor = Option<tokio_rustls::TlsAcceptor>;
#[cfg(not(feature = "server-tls"))]
type MaybeTlsAcceptor = Option<()>;
/// Optional boxed future whose completion asks the accept loop to stop;
/// `None` means the loop never stops on its own.
type ShutdownSignal = Option<Pin<Box<dyn Future<Output = ()> + Send>>>;
async fn serve_with_listener<D: Dispatcher>(
listener: TcpListener,
service: ConnectRpcService<D>,
tls_acceptor: MaybeTlsAcceptor,
http1_keep_alive: bool,
#[cfg(feature = "server-tls")] tls_handshake_timeout: std::time::Duration,
shutdown: ShutdownSignal,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let service: WrappedService<D> = ServiceBuilder::new()
.layer(CatchPanicLayer::custom(panic_handler as fn(_) -> _))
.service(service);
let service = Arc::new(service);
#[cfg(feature = "server-tls")]
let tls_acceptor = tls_acceptor.map(Arc::new);
#[cfg(not(feature = "server-tls"))]
let _ = tls_acceptor;
let tracker = tokio_util::task::TaskTracker::new();
let mut shutdown = shutdown.unwrap_or_else(|| Box::pin(std::future::pending()));
loop {
let (stream, remote_addr) = tokio::select! {
biased;
_ = &mut shutdown => {
tracing::info!("Shutdown signal received; draining connections");
break;
}
accept_result = listener.accept() => match accept_result {
Ok(conn) => conn,
Err(err) => {
if is_transient_accept_error(&err) {
tracing::warn!("Transient accept error (continuing): {}", err);
continue;
}
return Err(err.into());
}
},
};
if let Err(e) = stream.set_nodelay(true) {
tracing::warn!("failed to set TCP_NODELAY: {e}");
}
let service = Arc::clone(&service);
#[cfg(feature = "server-tls")]
let tls_acceptor = tls_acceptor.clone();
tracker.spawn(async move {
#[cfg(feature = "server-tls")]
if let Some(acceptor) = tls_acceptor {
match tokio::time::timeout(tls_handshake_timeout, acceptor.accept(stream)).await {
Ok(Ok(tls_stream)) => {
let (_, conn) = tls_stream.get_ref();
let certs = conn.peer_certificates().map(|chain| -> Arc<[_]> {
chain.iter().map(|c| c.clone().into_owned()).collect()
});
let peer = PeerInfo {
addr: remote_addr,
certs,
};
serve_accepted_stream(tls_stream, peer, service, http1_keep_alive).await;
}
Ok(Err(err)) => {
tracing::debug!(
remote_addr = %remote_addr,
error = ?err,
"TLS handshake failed: {err}",
);
}
Err(_) => {
tracing::warn!(
remote_addr = %remote_addr,
"TLS handshake timed out after {tls_handshake_timeout:?}",
);
}
}
return;
}
let peer = PeerInfo {
addr: remote_addr,
#[cfg(feature = "server-tls")]
certs: None,
};
serve_accepted_stream(stream, peer, service, http1_keep_alive).await;
});
}
drop(listener);
tracker.close();
tracker.wait().await;
tracing::info!("All connections drained; shutdown complete");
Ok(())
}
fn panic_handler(err: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
let backtrace = std::backtrace::Backtrace::capture();
let message = if let Some(s) = err.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = err.downcast_ref::<&str>() {
(*s).to_string()
} else {
"handler panicked".to_string()
};
match backtrace.status() {
std::backtrace::BacktraceStatus::Captured => {
tracing::error!(
"Request handler panicked: {}\n\nBacktrace:\n{}",
message,
backtrace
);
}
_ => {
tracing::error!(
"Request handler panicked: {} (set RUST_BACKTRACE=1 for backtrace)",
message
);
}
}
let error = ConnectError::new(ErrorCode::Internal, "internal server error");
let body = error.to_json();
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(header::CONTENT_TYPE, content_type::JSON)
.body(Full::new(body))
.unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Full::new(Bytes::new()))
.unwrap()
})
}
/// Decides whether a failed `accept` should be retried instead of ending the
/// accept loop.
///
/// Transient errors are the `WouldBlock`, `Interrupted`, `ConnectionAborted`
/// and `ConnectionReset` kinds, plus the OS error codes `EMFILE` and `ENFILE`.
fn is_transient_accept_error(err: &std::io::Error) -> bool {
    use std::io::ErrorKind::{ConnectionAborted, ConnectionReset, Interrupted, WouldBlock};

    match err.kind() {
        WouldBlock | Interrupted | ConnectionAborted | ConnectionReset => true,
        // Not identified by kind: fall back to the raw OS error code.
        _ => matches!(
            err.raw_os_error(),
            Some(code) if code == libc::EMFILE || code == libc::ENFILE
        ),
    }
}
// Integration-style tests that run the server on an ephemeral loopback port
// and talk to it over raw TCP (or TLS) sockets.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
use std::time::Duration;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
// Minimal HTTP/1.1 request for the `svc/Echo` route with an empty body.
// `Connection: close` makes the server close the socket after responding, so
// the clients below can use `read_to_end`.
const ECHO_REQ: &[u8] = concat!(
"POST /svc/Echo HTTP/1.1\r\n",
"Host: localhost\r\n",
"Content-Type: application/proto\r\n",
"Content-Length: 0\r\n",
"Connection: close\r\n",
"\r\n",
)
.as_bytes();
// Constructing a server from an empty router must not panic.
#[test]
fn test_server_creation() {
let router = Router::new();
let _server = Server::new(router);
}
// With no connections open, firing the shutdown signal makes `serve` return
// `Ok` promptly.
#[tokio::test]
async fn test_graceful_shutdown_immediate() {
let bound = Server::bind("127.0.0.1:0").await.unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
rx.await.ok();
})
.await
});
tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(5), serve)
.await
.expect("server did not shut down in time")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
// Shutdown must wait for an open connection: `serve` stays pending after the
// signal and only completes once the client drops its socket.
#[tokio::test]
async fn test_graceful_shutdown_with_inflight_connection() {
let bound = Server::bind("127.0.0.1:0").await.unwrap();
let addr = bound.local_addr().unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
rx.await.ok();
})
.await
});
let conn = tokio::net::TcpStream::connect(addr).await.unwrap();
// Give the accept loop time to pick up the connection before shutting down.
tokio::time::sleep(Duration::from_millis(50)).await;
tx.send(()).unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(
!serve.is_finished(),
"server shut down before connection closed"
);
drop(conn);
let result = tokio::time::timeout(Duration::from_secs(5), serve)
.await
.expect("server did not shut down after connection dropped")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
// After shutdown completes the listening socket is closed, so a new connection
// attempt to the same address fails.
#[tokio::test]
async fn test_graceful_shutdown_rejects_new_connections() {
let bound = Server::bind("127.0.0.1:0").await.unwrap();
let addr = bound.local_addr().unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
rx.await.ok();
})
.await
});
tokio::time::sleep(Duration::from_millis(20)).await;
tx.send(()).unwrap();
tokio::time::timeout(Duration::from_secs(5), serve)
.await
.unwrap()
.unwrap()
.unwrap();
let connect_result = tokio::net::TcpStream::connect(addr).await;
assert!(
connect_result.is_err(),
"expected connection refused after shutdown"
);
}
// The `PeerAddr` a handler reads from its context extensions must equal the
// client socket's local address.
#[tokio::test]
async fn peer_addr_reaches_handler() {
// Slot the handler writes into so the test can inspect what it saw.
let captured: Arc<Mutex<Option<PeerAddr>>> = Arc::new(Mutex::new(None));
let handler_captured = Arc::clone(&captured);
let router = Router::new().route(
"svc",
"Echo",
crate::handler_fn(move |ctx: crate::Context, _req: buffa_types::Empty| {
let cap = Arc::clone(&handler_captured);
async move {
*cap.lock().unwrap() = ctx.extensions.get::<PeerAddr>().cloned();
Ok((buffa_types::Empty::default(), ctx))
}
}),
);
let bound = Server::bind("127.0.0.1:0").await.unwrap();
let addr = bound.local_addr().unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
rx.await.ok();
})
.await
});
let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
let client_local = stream.local_addr().unwrap();
stream.write_all(ECHO_REQ).await.unwrap();
let mut resp = Vec::new();
stream.read_to_end(&mut resp).await.unwrap();
assert!(
resp.starts_with(b"HTTP/1.1 2"),
"expected 2xx, got: {}",
String::from_utf8_lossy(&resp[..resp.len().min(80)])
);
tx.send(()).unwrap();
tokio::time::timeout(Duration::from_secs(5), serve)
.await
.unwrap()
.unwrap()
.unwrap();
let peer = captured
.lock()
.unwrap()
.take()
.expect("handler should have captured PeerAddr");
assert_eq!(peer.0, client_local);
}
// With mutual TLS, the certificate the client presents must reach the handler
// as `PeerCerts`.
#[cfg(feature = "server-tls")]
#[tokio::test]
async fn peer_certs_reach_handler() {
// Builds a throwaway CA plus one server and one client leaf certificate, and
// returns the server config, the client config and the client certificate.
fn pki() -> (
Arc<rustls::ServerConfig>,
Arc<rustls::ClientConfig>,
rustls::pki_types::CertificateDer<'static>,
) {
use rcgen::CertificateParams;
use rcgen::KeyPair;
use rcgen::SanType;
use rustls::pki_types::CertificateDer;
use rustls::pki_types::PrivatePkcs8KeyDer;
// The result is ignored: installing fails when a provider is already installed.
let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
let ca_key = KeyPair::generate().unwrap();
let mut ca_params = CertificateParams::default();
ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
let ca = rcgen::CertifiedIssuer::self_signed(ca_params, ca_key).unwrap();
// Issues a leaf certificate signed by the CA with the given subject alt names.
let issue = |sans: &[SanType]| {
let k = KeyPair::generate().unwrap();
let mut p = CertificateParams::default();
p.subject_alt_names = sans.to_vec();
let c = p.signed_by(&k, &ca).unwrap();
(
CertificateDer::from(c.der().to_vec()),
PrivatePkcs8KeyDer::from(k.serialized_der().to_vec()).into(),
)
};
let (srv_cert, srv_key) = issue(&[SanType::DnsName("localhost".try_into().unwrap())]);
let (cli_cert, cli_key) = issue(&[]);
let mut roots = rustls::RootCertStore::empty();
roots.add(CertificateDer::from(ca.der().to_vec())).unwrap();
let roots = Arc::new(roots);
// The server verifies client certificates against the same CA.
let cv = rustls::server::WebPkiClientVerifier::builder(Arc::clone(&roots))
.build()
.unwrap();
let server = rustls::ServerConfig::builder()
.with_client_cert_verifier(cv)
.with_single_cert(vec![srv_cert], srv_key)
.unwrap();
let client = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_client_auth_cert(vec![cli_cert.clone()], cli_key)
.unwrap();
(Arc::new(server), Arc::new(client), cli_cert)
}
let (server_cfg, client_cfg, expected_client_der) = pki();
// Slot the handler writes into so the test can inspect what it saw.
let captured: Arc<Mutex<Option<PeerCerts>>> = Arc::new(Mutex::new(None));
let handler_captured = Arc::clone(&captured);
let router = Router::new().route(
"svc",
"Echo",
crate::handler_fn(move |ctx: crate::Context, _req: buffa_types::Empty| {
let cap = Arc::clone(&handler_captured);
async move {
*cap.lock().unwrap() = ctx.extensions.get::<PeerCerts>().cloned();
Ok((buffa_types::Empty::default(), ctx))
}
}),
);
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_tls(server_cfg);
let addr = bound.local_addr().unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let connector = tokio_rustls::TlsConnector::from(client_cfg);
let sni = rustls::pki_types::ServerName::try_from("localhost").unwrap();
let mut tls = connector.connect(sni, tcp).await.unwrap();
tls.write_all(ECHO_REQ).await.unwrap();
let mut resp = Vec::new();
tls.read_to_end(&mut resp).await.unwrap();
assert!(
resp.starts_with(b"HTTP/1.1 2"),
"expected 2xx, got: {}",
String::from_utf8_lossy(&resp[..resp.len().min(80)])
);
tx.send(()).unwrap();
tokio::time::timeout(Duration::from_secs(5), serve)
.await
.unwrap()
.unwrap()
.unwrap();
let certs = captured
.lock()
.unwrap()
.take()
.expect("handler should have captured PeerCerts");
// The client sent a single leaf certificate; it must arrive byte-for-byte.
assert_eq!(certs.0.len(), 1);
assert_eq!(certs.0[0].as_ref(), expected_client_der.as_ref());
}
}