pub mod auth;
pub mod config;
pub mod connection;
pub mod crypto;
pub mod event;
#[cfg(feature = "grpc")]
pub mod rpc;
pub mod server;
pub mod status;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, info_span};
pub use config::Config;
pub use event::{EventManager, Phase, PlayerInfo, Plugin, PostOrder};
pub use server::{Server, ServerId, ServerRegistry};
use config::ConfigError;
/// Errors produced while building or running a [`Proxy`].
#[derive(Debug, thiserror::Error)]
pub enum ProxyError {
    /// The configuration could not be loaded or did not pass validation.
    #[error("configuration error: {0}")]
    Config(#[from] ConfigError),

    /// Generating the proxy's RSA key pair failed; the underlying cause is the source.
    #[error("RSA key generation failed")]
    KeyGeneration(#[source] Box<dyn std::error::Error + Send + Sync>),

    /// The shared HTTP client could not be constructed.
    #[error("HTTP client initialization failed")]
    HttpClient(#[from] reqwest::Error),

    /// Two bootstrap servers were registered with the same ID (carried as the payload).
    #[error("duplicate server ID: {0}")]
    DuplicateServer(String),

    /// The Minecraft listener could not bind to its configured address.
    #[error("failed to bind listener")]
    Bind(#[source] std::io::Error),
}
use connection::ConnectionError;
use connection::MinecraftConnection;
use crypto::ServerKeyPair;
#[cfg(feature = "grpc")]
use rpc::DeepslateService;
#[cfg(feature = "grpc")]
use rpc::proto::deepslate_server::DeepslateServer;
/// A fully configured proxy: create it with [`Proxy::builder`] and start it with
/// [`Proxy::run`].
///
/// Shared state is held behind `Arc` (or is itself cheaply clonable), so each connection task
/// receives its own handle without copying the underlying data.
pub struct Proxy {
    /// Validated configuration the proxy was built with.
    config: Arc<Config>,
    /// Registered backend servers, together with the try order and forced-host mapping.
    registry: Arc<ServerRegistry>,
    /// RSA key pair generated once during `ProxyBuilder::build`.
    key_pair: Arc<ServerKeyPair>,
    /// Event handlers registered by plugins during `ProxyBuilder::build`.
    event_manager: Arc<EventManager>,
    /// HTTP client handed to every client-connection handler.
    http_client: reqwest::Client,
}
impl Proxy {
#[must_use]
pub fn builder() -> ProxyBuilder {
ProxyBuilder {
config: None,
config_overrides: ConfigOverrides::default(),
bootstrap_servers: Vec::new(),
try_servers: None,
forced_hosts: Vec::new(),
plugins: Vec::new(),
}
}
#[must_use]
pub const fn config(&self) -> &Arc<Config> {
&self.config
}
#[must_use]
pub const fn registry(&self) -> &Arc<ServerRegistry> {
&self.registry
}
#[must_use]
pub const fn event_manager(&self) -> &Arc<EventManager> {
&self.event_manager
}
#[allow(clippy::too_many_lines)]
pub async fn run(&self) -> Result<(), ProxyError> {
let listener = TcpListener::bind(self.config.listen_addr)
.await
.map_err(ProxyError::Bind)?;
info!(addr = %self.config.listen_addr, "proxy listening");
#[cfg(feature = "grpc")]
let grpc_addr = self.config.grpc_addr;
#[cfg(feature = "grpc")]
let grpc_registry = Arc::clone(&self.registry);
#[cfg(feature = "grpc")]
let mut grpc_handle = tokio::spawn(async move {
let grpc_service = DeepslateService::new(grpc_registry);
#[cfg(feature = "grpc-reflection")]
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(rpc::proto::FILE_DESCRIPTOR_SET)
.build_v1()
.expect("failed to build reflection service");
info!(addr = %grpc_addr, "gRPC control plane listening");
let builder =
tonic::transport::Server::builder().add_service(DeepslateServer::new(grpc_service));
#[cfg(feature = "grpc-reflection")]
let builder = builder.add_service(reflection_service);
builder.serve(grpc_addr).await
});
let shutdown = CancellationToken::new();
let mut connections = JoinSet::new();
let session_counter = AtomicUsize::new(0);
loop {
#[cfg(feature = "grpc")]
let grpc_fut = &mut grpc_handle;
#[cfg(not(feature = "grpc"))]
let grpc_fut = std::future::pending::<Result<(), String>>();
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, addr)) => {
if let Err(e) = stream.set_nodelay(true) {
tracing::warn!(error = %e, "failed to set TCP_NODELAY on client stream");
}
let session_id = session_counter.fetch_add(1, Ordering::Relaxed);
let config = Arc::clone(&self.config);
let key_pair = Arc::clone(&self.key_pair);
let registry = Arc::clone(&self.registry);
let http_client = self.http_client.clone();
let event_manager = Arc::clone(&self.event_manager);
let conn_shutdown = shutdown.clone();
connections.spawn(
async move {
let conn = MinecraftConnection::new(
stream,
libdeflater::CompressionLvl::new(config.compression_level)
.expect("validated in Config::validate"),
Duration::from_millis(config.read_timeout_ms),
);
if let Err(e) = Box::pin(connection::client::handle_client(
conn,
addr,
config,
key_pair,
registry,
http_client,
event_manager,
conn_shutdown,
))
.await
{
match &e {
ConnectionError::Io(io_err) => {
tracing::debug!(error = %io_err, "client I/O error");
}
ConnectionError::Protocol(proto_err) => {
tracing::debug!(error = %proto_err, "protocol error");
}
ConnectionError::Auth(auth_err) => {
tracing::debug!(error = %auth_err, "authentication error");
}
ConnectionError::BackendFailed { reason } => {
tracing::debug!(reason, "backend connection failed");
}
ConnectionError::Timeout => {
tracing::debug!("connection timed out");
}
}
}
}
.instrument(info_span!(
"conn",
sid = session_id,
ip = %addr.ip(),
port = addr.port()
)),
);
}
Err(e) => {
error!(error = %e, "failed to accept connection");
}
}
}
Some(result) = connections.join_next(), if !connections.is_empty() => {
if let Err(e) = result {
tracing::debug!(error = %e, "connection task panicked");
}
}
_ = tokio::signal::ctrl_c() => {
info!("received shutdown signal");
break;
}
result = grpc_fut => {
if let Err(e) = result {
error!(error = %e, "gRPC server error");
}
break;
}
}
}
info!("shutting down");
shutdown.cancel();
let drain = self.config.shutdown_drain;
if !connections.is_empty() {
info!(
connections = connections.len(),
drain_ms = u64::try_from(drain.as_millis()).unwrap_or(u64::MAX),
"draining active connections"
);
let _ = tokio::time::timeout(drain, async {
while connections.join_next().await.is_some() {}
})
.await;
let remaining = connections.len();
if remaining > 0 {
info!(
remaining,
"drain timeout reached, dropping remaining connections"
);
connections.abort_all();
}
}
Ok(())
}
}
/// Builder for [`Proxy`]; obtain one with [`Proxy::builder`] and finish with
/// [`ProxyBuilder::build`].
pub struct ProxyBuilder {
    /// Explicit base configuration. When `None`, `build` starts from `Config::default()` if
    /// anything else was configured in code, and from `Config::from_env()` otherwise.
    config: Option<Config>,
    /// Individual settings copied onto the base configuration before validation.
    config_overrides: ConfigOverrides,
    /// Servers registered with the registry during `build`.
    bootstrap_servers: Vec<ServerId>,
    /// Try order set in code; when `Some`, it replaces the configured try order.
    try_servers: Option<Vec<&'static str>>,
    /// `(lowercased hostname, server IDs)` pairs; when non-empty, they replace the
    /// configured forced hosts.
    forced_hosts: Vec<(String, Vec<&'static str>)>,
    /// Plugins whose handlers are registered into the event manager during `build`.
    plugins: Vec<Box<dyn Plugin>>,
}
impl ProxyBuilder {
    /// Uses `config` as the base configuration instead of the default or the environment.
    ///
    /// Individual setters on this builder still override the matching fields.
    #[must_use]
    pub fn config(mut self, config: Config) -> Self {
        self.config = Some(config);
        self
    }

    /// Overrides the address the Minecraft listener binds to.
    #[must_use]
    pub const fn listen_addr(mut self, listen_addr: std::net::SocketAddr) -> Self {
        self.config_overrides.listen_addr = Some(listen_addr);
        self
    }

    /// Overrides the address the gRPC control plane serves on.
    #[cfg(feature = "grpc")]
    #[must_use]
    pub const fn grpc_addr(mut self, grpc_addr: std::net::SocketAddr) -> Self {
        self.config_overrides.grpc_addr = Some(grpc_addr);
        self
    }

    /// Overrides `Config::online_mode`.
    #[must_use]
    pub const fn online_mode(mut self, online_mode: bool) -> Self {
        self.config_overrides.online_mode = Some(online_mode);
        self
    }

    /// Overrides the forwarding secret; the bytes are copied into an owned buffer.
    #[must_use]
    pub fn forwarding_secret(mut self, forwarding_secret: impl AsRef<[u8]>) -> Self {
        self.config_overrides.forwarding_secret = Some(forwarding_secret.as_ref().to_vec());
        self
    }

    /// Overrides `Config::compression_threshold`.
    #[must_use]
    pub const fn compression_threshold(mut self, compression_threshold: i32) -> Self {
        self.config_overrides.compression_threshold = Some(compression_threshold);
        self
    }

    /// Overrides the compression level; out-of-range values are rejected by
    /// [`build`](Self::build) during validation.
    #[must_use]
    pub const fn compression_level(mut self, compression_level: i32) -> Self {
        self.config_overrides.compression_level = Some(compression_level);
        self
    }

    /// Overrides `Config::motd`.
    #[must_use]
    pub fn motd(mut self, motd: impl Into<String>) -> Self {
        self.config_overrides.motd = Some(motd.into());
        self
    }

    /// Overrides `Config::max_players`.
    #[must_use]
    pub const fn max_players(mut self, max_players: i32) -> Self {
        self.config_overrides.max_players = Some(max_players);
        self
    }

    /// Overrides the per-connection read timeout, in milliseconds.
    #[must_use]
    pub const fn read_timeout_ms(mut self, read_timeout_ms: u64) -> Self {
        self.config_overrides.read_timeout_ms = Some(read_timeout_ms);
        self
    }

    /// Sets the try order by server ID, replacing any try order from the configuration.
    ///
    /// The IDs are handed to the registry as-is by [`build`](Self::build).
    #[must_use]
    pub fn try_servers<I>(mut self, try_servers: I) -> Self
    where
        I: IntoIterator<Item = &'static ServerId>,
    {
        self.try_servers = Some(try_servers.into_iter().map(|s| s.id).collect());
        self
    }

    /// Maps `hostname` (lowercased) to an ordered list of servers.
    ///
    /// Any forced host set here replaces the configured forced hosts entirely. If the same
    /// hostname is added more than once, the last call wins.
    #[must_use]
    pub fn forced_host<I>(mut self, hostname: impl Into<String>, servers: I) -> Self
    where
        I: IntoIterator<Item = &'static ServerId>,
    {
        let ids: Vec<&'static str> = servers.into_iter().map(|s| s.id).collect();
        self.forced_hosts
            .push((hostname.into().to_lowercase(), ids));
        self
    }

    /// Overrides the tracing filter directive used by `init_tracing`.
    #[must_use]
    pub fn log_level(mut self, log_level: impl Into<String>) -> Self {
        self.config_overrides.log_level = Some(log_level.into());
        self
    }

    /// Overrides whether `init_tracing` emits JSON-formatted logs.
    #[must_use]
    pub const fn log_json(mut self, log_json: bool) -> Self {
        self.config_overrides.log_json = Some(log_json);
        self
    }

    /// Overrides how long [`Proxy::run`] waits for active connections on shutdown.
    #[must_use]
    pub const fn shutdown_drain(mut self, shutdown_drain: std::time::Duration) -> Self {
        self.config_overrides.shutdown_drain = Some(shutdown_drain);
        self
    }

    /// Adds one bootstrap server to register during [`build`](Self::build).
    #[must_use]
    pub fn server(mut self, server: &'static ServerId) -> Self {
        self.bootstrap_servers.push(*server);
        self
    }

    /// Adds several bootstrap servers to register during [`build`](Self::build).
    #[must_use]
    pub fn servers<I>(mut self, servers: I) -> Self
    where
        I: IntoIterator<Item = &'static ServerId>,
    {
        self.bootstrap_servers.extend(servers.into_iter().copied());
        self
    }

    /// Adds a plugin; its handlers are registered when [`build`](Self::build) runs.
    #[must_use]
    pub fn plugin(mut self, plugin: impl Plugin) -> Self {
        self.plugins.push(Box::new(plugin));
        self
    }

    /// Builds the [`Proxy`].
    ///
    /// The base configuration is, in order of preference:
    /// 1. the value passed to [`config`](Self::config);
    /// 2. `Config::default()` if any override or bootstrap server was set;
    /// 3. otherwise `Config::from_env()`.
    ///
    /// Overrides are applied on top of it, and the result is validated. A try order or
    /// forced hosts set in code replace, rather than merge with, the configured ones.
    ///
    /// # Errors
    ///
    /// - [`ProxyError::Config`] if the configuration cannot be loaded or fails validation.
    /// - [`ProxyError::DuplicateServer`] if two bootstrap servers share an ID.
    /// - [`ProxyError::KeyGeneration`] if the RSA key pair cannot be generated.
    /// - [`ProxyError::HttpClient`] if the HTTP client cannot be constructed.
    pub fn build(self) -> Result<Proxy, ProxyError> {
        let mut config = match self.config {
            Some(config) => config,
            None if self.config_overrides.has_any() || !self.bootstrap_servers.is_empty() => {
                Config::default()
            }
            None => Config::from_env()?,
        };
        self.config_overrides.apply(&mut config);
        let config = config.validate()?;

        // Populate the registry before generating the RSA key pair. A duplicate server ID is
        // a cheap-to-detect caller mistake and should not wait on potentially slow key
        // generation.
        let registry = ServerRegistry::new();
        for server_id in &self.bootstrap_servers {
            let server = Server::from(server_id);
            if !registry.register(&server) {
                return Err(ProxyError::DuplicateServer(server_id.id.to_string()));
            }
        }
        if let Some(try_servers) = self.try_servers {
            let order: Vec<String> = try_servers.into_iter().map(String::from).collect();
            registry.set_try_order(order);
        } else if !config.try_servers.is_empty() {
            registry.set_try_order(config.try_servers.clone());
        }
        if !self.forced_hosts.is_empty() {
            // Collecting into a map means a repeated hostname keeps its last entry.
            let map: std::collections::HashMap<String, Vec<String>> = self
                .forced_hosts
                .into_iter()
                .map(|(host, ids)| (host, ids.into_iter().map(String::from).collect()))
                .collect();
            registry.set_forced_hosts(map);
        } else if !config.forced_hosts.is_empty() {
            registry.set_forced_hosts(config.forced_hosts.clone());
        }

        let key_pair = ServerKeyPair::generate().map_err(ProxyError::KeyGeneration)?;
        info!("generated RSA key pair");

        let mut event_manager = EventManager::new();
        for plugin in &self.plugins {
            plugin.register(&mut event_manager);
        }

        let http_client = reqwest::Client::builder()
            .user_agent(concat!("Deepslate/", env!("CARGO_PKG_VERSION")))
            .connect_timeout(Duration::from_secs(5))
            .timeout(Duration::from_secs(10))
            .build()?;

        Ok(Proxy {
            config: Arc::new(config),
            registry: Arc::new(registry),
            key_pair: Arc::new(key_pair),
            event_manager: Arc::new(event_manager),
            http_client,
        })
    }
}
/// Per-field overrides collected by [`ProxyBuilder`] setters.
///
/// Every field starts as `None` (via `Default`) and becomes `Some` only when the matching
/// setter is called; `apply` copies the `Some` values onto a [`Config`].
#[derive(Default)]
struct ConfigOverrides {
    // Listener addresses.
    listen_addr: Option<std::net::SocketAddr>,
    #[cfg(feature = "grpc")]
    grpc_addr: Option<std::net::SocketAddr>,

    // Client connection handling.
    online_mode: Option<bool>,
    forwarding_secret: Option<Vec<u8>>,
    compression_threshold: Option<i32>,
    compression_level: Option<i32>,
    read_timeout_ms: Option<u64>,

    // Values advertised to clients.
    motd: Option<String>,
    max_players: Option<i32>,

    // Logging and shutdown.
    log_level: Option<String>,
    log_json: Option<bool>,
    shutdown_drain: Option<std::time::Duration>,
}
impl ConfigOverrides {
    /// Returns `true` when at least one override has been set.
    ///
    /// `ProxyBuilder::build` uses this to decide between `Config::default()` and
    /// `Config::from_env()` when no explicit config was given.
    const fn has_any(&self) -> bool {
        #[cfg(feature = "grpc")]
        if self.grpc_addr.is_some() {
            return true;
        }
        self.listen_addr.is_some()
            || self.online_mode.is_some()
            || self.forwarding_secret.is_some()
            || self.compression_threshold.is_some()
            || self.compression_level.is_some()
            || self.motd.is_some()
            || self.max_players.is_some()
            || self.read_timeout_ms.is_some()
            || self.log_level.is_some()
            || self.log_json.is_some()
            || self.shutdown_drain.is_some()
    }

    /// Writes every set override into the same-named field of `config`.
    ///
    /// Fields left as `None` keep whatever value `config` already had.
    fn apply(self, config: &mut Config) {
        // Expands to one `if let Some(v) = src.field { dst.field = v; }` per listed field.
        macro_rules! overwrite_if_set {
            ($src:ident, $dst:ident; $($field:ident),+ $(,)?) => {
                $(
                    if let Some(value) = $src.$field {
                        $dst.$field = value;
                    }
                )+
            };
        }

        #[cfg(feature = "grpc")]
        if let Some(value) = self.grpc_addr {
            config.grpc_addr = value;
        }
        overwrite_if_set!(
            self, config;
            listen_addr,
            online_mode,
            forwarding_secret,
            compression_threshold,
            compression_level,
            motd,
            max_players,
            read_timeout_ms,
            log_level,
            log_json,
            shutdown_drain,
        );
    }
}
pub fn init_tracing(config: &Config) {
use tracing_subscriber::EnvFilter;
let filter = EnvFilter::try_new(&config.log_level).unwrap_or_else(|_| EnvFilter::new("info"));
if config.log_json {
tracing_subscriber::fmt()
.with_env_filter(filter)
.json()
.init();
} else {
tracing_subscriber::fmt().with_env_filter(filter).init();
}
}
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddr};

    use super::*;

    #[test]
    fn builder_uses_code_config_without_env() {
        let proxy = Proxy::builder()
            .forwarding_secret("secret")
            .listen_addr(SocketAddr::from((Ipv4Addr::LOCALHOST, 25_565)))
            .motd("Configured in code")
            .build()
            .unwrap();
        assert_eq!(
            proxy.config.listen_addr,
            SocketAddr::from((Ipv4Addr::LOCALHOST, 25_565))
        );
        assert_eq!(proxy.config.motd, "Configured in code");
    }

    #[cfg(feature = "grpc")]
    #[test]
    fn builder_sets_grpc_addr() {
        let proxy = Proxy::builder()
            .forwarding_secret("secret")
            .grpc_addr(SocketAddr::from((Ipv4Addr::LOCALHOST, 25_577)))
            .build()
            .unwrap();
        assert_eq!(
            proxy.config.grpc_addr,
            SocketAddr::from((Ipv4Addr::LOCALHOST, 25_577))
        );
    }

    #[test]
    fn builder_rejects_invalid_manual_config() {
        let result = Proxy::builder()
            .forwarding_secret("secret")
            .compression_level(13)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_bootstraps_servers_and_try_order() {
        static LOBBY: ServerId = ServerId::new("lobby", "127.0.0.1:25566");
        static SURVIVAL: ServerId = ServerId::new("survival", "127.0.0.1:25567");
        let proxy = Proxy::builder()
            .forwarding_secret("secret")
            .server(&LOBBY)
            .server(&SURVIVAL)
            .try_servers([&SURVIVAL, &LOBBY])
            .build()
            .unwrap();
        assert_eq!(proxy.registry.list().len(), 2);
        assert_eq!(proxy.registry.try_order(), vec!["survival", "lobby"]);
        assert_eq!(proxy.registry.select_initial().unwrap().id, "survival");
    }

    /// Two bootstrap servers sharing an ID must be rejected with that ID in the error.
    /// Also covers the batch `servers` setter.
    #[test]
    fn builder_rejects_duplicate_server_ids() {
        static FIRST: ServerId = ServerId::new("lobby", "127.0.0.1:25566");
        static SECOND: ServerId = ServerId::new("lobby", "127.0.0.1:25567");
        let result = Proxy::builder()
            .forwarding_secret("secret")
            .servers([&FIRST, &SECOND])
            .build();
        assert!(matches!(
            result,
            Err(ProxyError::DuplicateServer(ref id)) if id == "lobby"
        ));
    }

    #[test]
    fn builder_overrides_explicit_config_fields() {
        let proxy = Proxy::builder()
            .config(Config {
                forwarding_secret: b"base-secret".to_vec(),
                motd: "Base MOTD".to_string(),
                ..Config::default()
            })
            .forwarding_secret("override-secret")
            .motd("Override MOTD")
            .build()
            .unwrap();
        assert_eq!(proxy.config.forwarding_secret, b"override-secret".to_vec());
        assert_eq!(proxy.config.motd, "Override MOTD");
    }
}