use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use crate::error::BoltError;
use crate::server::auth::AuthValidator;
use crate::server::backend::BoltBackend;
use crate::server::connection::Connection;
use crate::server::handshake::server_handshake;
use crate::server::session_manager::SessionManager;
#[cfg(feature = "tls")]
use std::io::BufReader;
#[cfg(feature = "tls")]
use tokio_rustls::TlsAcceptor;
/// TLS settings for the Bolt server; only compiled when the `tls` feature is enabled.
///
/// Construct one with `TlsConfig::from_pem` and hand it to `BoltServer::tls`.
#[cfg(feature = "tls")]
pub struct TlsConfig {
// Acceptor that `BoltServer::serve` wraps in an `Arc` and shares with every accepted connection.
acceptor: TlsAcceptor,
}
#[cfg(feature = "tls")]
impl TlsConfig {
    /// Builds a server-side TLS configuration from PEM-encoded material.
    ///
    /// `cert_pem` must contain at least one certificate (the chain, in the order
    /// it should be presented) and `key_pem` must contain a private key. Client
    /// certificate authentication is not requested.
    ///
    /// # Errors
    ///
    /// Returns `BoltError::Io` wrapping an `InvalidData` I/O error when the PEM
    /// data cannot be parsed, when it contains no certificate or no private key,
    /// or when rustls rejects the certificate/key pair.
    pub fn from_pem(cert_pem: &[u8], key_pem: &[u8]) -> Result<Self, BoltError> {
        // Every failure in this constructor is reported the same way, so the
        // wrapping lives in one place instead of being repeated per call site.
        fn invalid_data<E>(err: E) -> BoltError
        where
            E: Into<Box<dyn std::error::Error + Send + Sync>>,
        {
            BoltError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, err))
        }

        let certs: Vec<_> = rustls_pemfile::certs(&mut BufReader::new(cert_pem))
            .collect::<Result<_, _>>()
            .map_err(invalid_data)?;
        // Mirror the private-key check below: report an empty chain here with a
        // specific message rather than leaving it for rustls to deal with later.
        if certs.is_empty() {
            return Err(invalid_data("no certificates found in PEM data"));
        }

        let key = rustls_pemfile::private_key(&mut BufReader::new(key_pem))
            .map_err(invalid_data)?
            .ok_or_else(|| invalid_data("no private key found in PEM data"))?;

        let config = tokio_rustls::rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .map_err(invalid_data)?;

        Ok(Self {
            acceptor: TlsAcceptor::from(Arc::new(config)),
        })
    }
}
/// Bolt protocol server, configured through chained setters and started with `serve`.
pub struct BoltServer<B: BoltBackend> {
/// Backend that `serve` wraps in an `Arc` and shares with every connection and the idle reaper.
backend: B,
/// Optional validator handed to each connection; `None` until `auth` is called.
auth_validator: Option<Arc<dyn AuthValidator>>,
/// When set, `serve` spawns a task that periodically calls `SessionManager::reap_idle` with it.
idle_timeout: Option<Duration>,
/// Limit passed to `SessionManager::new`; `None` until `max_sessions` is called.
max_sessions: Option<usize>,
/// Future whose completion makes `serve` leave its accept loop and return.
shutdown: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
/// TLS settings; when present, accepted TCP streams go through a TLS handshake first.
#[cfg(feature = "tls")]
tls_config: Option<TlsConfig>,
}
impl<B: BoltBackend> BoltServer<B> {
pub fn builder(backend: B) -> Self {
Self {
backend,
auth_validator: None,
idle_timeout: None,
max_sessions: None,
shutdown: None,
#[cfg(feature = "tls")]
tls_config: None,
}
}
pub fn auth(mut self, validator: impl AuthValidator) -> Self {
self.auth_validator = Some(Arc::new(validator));
self
}
#[cfg(feature = "tls")]
pub fn tls(mut self, config: TlsConfig) -> Self {
self.tls_config = Some(config);
self
}
pub fn idle_timeout(mut self, timeout: Duration) -> Self {
self.idle_timeout = Some(timeout);
self
}
pub fn max_sessions(mut self, limit: usize) -> Self {
self.max_sessions = Some(limit);
self
}
pub fn shutdown(mut self, signal: impl Future<Output = ()> + Send + 'static) -> Self {
self.shutdown = Some(Box::pin(signal));
self
}
pub async fn serve(self, addr: SocketAddr) -> Result<(), BoltError> {
let listener = TcpListener::bind(addr).await?;
let backend = Arc::new(self.backend);
let session_manager = Arc::new(SessionManager::new(self.max_sessions));
let auth_validator = self.auth_validator;
#[cfg(feature = "tls")]
let tls_acceptor = self.tls_config.map(|c| Arc::new(c.acceptor));
#[cfg(not(feature = "tls"))]
let tls_acceptor: Option<()> = None;
let reaper_handle = if let Some(timeout) = self.idle_timeout {
let sm = session_manager.clone();
let be = backend.clone();
let handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(timeout / 2);
loop {
interval.tick().await;
let expired = sm.reap_idle(timeout);
for id in &expired {
let handle = crate::server::SessionHandle(id.clone());
let _ = be.close_session(&handle).await;
tracing::debug!(session_id = %id, "reaped idle Bolt session");
}
}
});
Some(handle)
} else {
None
};
let tls_label = if tls_acceptor.is_some() { " (TLS)" } else { "" };
tracing::info!(%addr, "Bolt server listening{}", tls_label);
let shutdown = self.shutdown;
let accept_result = if let Some(shutdown_signal) = shutdown {
tokio::pin!(shutdown_signal);
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, peer_addr)) => {
spawn_connection(
stream,
peer_addr,
backend.clone(),
session_manager.clone(),
auth_validator.clone(),
tls_acceptor.clone(),
);
}
Err(e) => {
tracing::warn!(error = %e, "accept error");
}
}
}
() = &mut shutdown_signal => {
tracing::info!("Bolt server shutting down");
break;
}
}
}
Ok(())
} else {
loop {
match listener.accept().await {
Ok((stream, peer_addr)) => {
spawn_connection(
stream,
peer_addr,
backend.clone(),
session_manager.clone(),
auth_validator.clone(),
tls_acceptor.clone(),
);
}
Err(e) => {
tracing::warn!(error = %e, "accept error");
}
}
}
};
if let Some(handle) = reaper_handle {
handle.abort();
}
tracing::info!("Bolt server stopped");
accept_result
}
}
/// Spawns a detached Tokio task that serves one accepted TCP connection.
///
/// With the `tls` feature enabled and an acceptor supplied, the stream first
/// goes through a TLS handshake; a failed handshake is logged at debug level
/// and the connection is dropped. In every other case the plain TCP stream is
/// passed straight to `run_handshake_and_connection`.
fn spawn_connection<B: BoltBackend>(
    stream: tokio::net::TcpStream,
    peer_addr: SocketAddr,
    backend: Arc<B>,
    session_manager: Arc<SessionManager>,
    auth_validator: Option<Arc<dyn AuthValidator>>,
    #[cfg(feature = "tls")] tls_acceptor: Option<Arc<TlsAcceptor>>,
    #[cfg(not(feature = "tls"))] _tls_acceptor: Option<()>,
) {
    let task = async move {
        #[cfg(feature = "tls")]
        if let Some(acceptor) = tls_acceptor {
            let secured = match acceptor.accept(stream).await {
                Ok(secured) => secured,
                Err(err) => {
                    tracing::debug!(%peer_addr, error = %err, "TLS handshake failed");
                    return;
                }
            };
            run_handshake_and_connection(
                secured,
                peer_addr,
                backend,
                session_manager,
                auth_validator,
            )
            .await;
            return;
        }

        // Plaintext path: no TLS acceptor configured (or the feature is disabled).
        run_handshake_and_connection(stream, peer_addr, backend, session_manager, auth_validator)
            .await;
    };
    tokio::spawn(task);
}
/// Performs the Bolt version handshake on `stream` and then runs the connection.
///
/// The handshake runs on the stream directly. Only once it succeeds is the
/// stream split into read and write halves for `Connection`. Handshake failures
/// and connection errors are logged at debug level; nothing is returned to the
/// caller.
async fn run_handshake_and_connection<S, B>(
    mut stream: S,
    peer_addr: SocketAddr,
    backend: Arc<B>,
    session_manager: Arc<SessionManager>,
    auth_validator: Option<Arc<dyn AuthValidator>>,
) where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
    B: BoltBackend,
{
    // The handshake needs the whole duplex stream, so it runs before the split.
    let version = match server_handshake(&mut stream).await {
        Ok(version) => version,
        Err(e) => {
            tracing::debug!(%peer_addr, error = %e, "Bolt handshake failed");
            return;
        }
    };
    tracing::debug!(%peer_addr, ?version, "Bolt handshake complete");

    let (rh, wh) = tokio::io::split(stream);
    let mut conn = Connection::new(rh, wh, backend, session_manager, auth_validator, peer_addr);
    if let Err(e) = conn.run().await {
        tracing::debug!(%peer_addr, error = %e, "Bolt connection closed");
    }
}