use crate::config::PluginConfig;
use crate::plugin::{PluginHandler, Registry};
#[cfg(feature = "doh")]
use crate::server::DohServer;
#[cfg(feature = "doq")]
use crate::server::DoqServer;
#[cfg(feature = "dot")]
use crate::server::DotServer;
#[cfg(feature = "metrics")]
use crate::server::MonitoringServer;
#[cfg(any(feature = "doh", feature = "dot"))]
use crate::server::TlsConfig;
#[cfg(feature = "admin")]
use crate::server::admin::{AdminServer, AdminState};
use crate::server::{ServerConfig, TcpServer, UdpServer};
use serde_yaml::Value;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::RwLock;
#[cfg(any(feature = "admin", feature = "metrics"))]
use tracing::info;
use tracing::{error, warn};
/// Expands the `":PORT"` listen shorthand into an explicit IPv4 wildcard address.
///
/// A string beginning with `:` gets `0.0.0.0` prepended (`":53"` becomes `"0.0.0.0:53"`);
/// every other string is returned unchanged. No validation is performed here.
fn normalize_listen_addr(listen: &str) -> String {
    match listen.strip_prefix(':') {
        Some(port) => format!("0.0.0.0:{port}"),
        None => String::from(listen),
    }
}
/// Starts the listener servers (UDP, TCP, DoH, DoT, DoQ) configured as plugins, plus the
/// optional admin API and monitoring servers.
pub struct ServerLauncher {
    // Plugin registry shared (via `Arc::clone`) with every `PluginHandler` the launcher
    // builds and with the admin API state.
registry: Arc<Registry>,
}
impl ServerLauncher {
    /// Creates a launcher that hands `registry` to every server handler and admin state it
    /// builds.
    pub fn new(registry: Arc<Registry>) -> Self {
        Self { registry }
    }

    /// Starts every server plugin found in `plugins`, in order.
    ///
    /// Recognised `plugin_type`s are `udp_server`, `tcp_server`, `doh_server`, `dot_server`
    /// and `doq_server`. Any other entry is skipped without a log message, presumably because
    /// the same list also carries non-server plugins. A server that cannot be configured or
    /// started is logged by its launcher and left out, so the result may contain fewer
    /// receivers than there are server entries.
    ///
    /// Each receiver is the startup signal produced by `spawn_with_startup_signal` (see its
    /// docs for exactly what the signal does and does not guarantee).
    pub async fn launch_all(
        &self,
        plugins: &[PluginConfig],
    ) -> Vec<tokio::sync::oneshot::Receiver<()>> {
        let mut receivers = Vec::new();
        for plugin_config in plugins {
            let receiver = match plugin_config.plugin_type.as_str() {
                "udp_server" => self.launch_udp_server(plugin_config).await,
                "tcp_server" => self.launch_tcp_server(plugin_config).await,
                "doh_server" => self.launch_doh_server(plugin_config).await,
                "dot_server" => self.launch_dot_server(plugin_config).await,
                "doq_server" => self.launch_doq_server(plugin_config).await,
                // Not a server plugin: nothing to launch.
                _ => continue,
            };
            receivers.extend(receiver);
        }
        receivers
    }

    /// Resolves the `listen` argument into a socket address.
    ///
    /// Falls back to `default` when `listen` is absent or is not a YAML string, expands the
    /// `":PORT"` shorthand via `normalize_listen_addr`, then parses the result as a literal
    /// `IP:PORT` (`SocketAddr` parsing does not resolve hostnames). Logs an error and returns
    /// `None` when parsing fails.
    fn parse_listen_addr(
        &self,
        args: &HashMap<String, Value>,
        default: &str,
    ) -> Option<SocketAddr> {
        let listen_str = args
            .get("listen")
            .and_then(|v| v.as_str())
            .unwrap_or(default);
        let normalized = normalize_listen_addr(listen_str);
        match normalized.parse::<SocketAddr>() {
            Ok(addr) => Some(addr),
            Err(e) => {
                error!("Failed to parse listen address '{}': {}", listen_str, e);
                None
            }
        }
    }

    /// Returns the entry plugin name for a server's handler, taken from the `entry` argument.
    ///
    /// Defaults to `"main_sequence"` when `entry` is absent or is not a YAML string.
    fn get_entry(&self, args: &HashMap<String, Value>) -> String {
        args.get("entry")
            .and_then(|v| v.as_str())
            .unwrap_or("main_sequence")
            .to_string()
    }

    /// Builds a `PluginHandler` over the shared registry that starts at plugin `entry`.
    fn create_handler(&self, entry: String) -> Arc<PluginHandler> {
        Arc::new(PluginHandler {
            registry: Arc::clone(&self.registry),
            entry,
        })
    }

    /// Spawns `serve` as a detached Tokio task and returns its startup signal.
    ///
    /// The signal is sent when the task is first polled, *before* `serve` makes any
    /// progress. It therefore confirms that the task is running, not that the server's
    /// listener is bound.
    /// NOTE(review): for UDP/TCP the socket is presumably bound inside the awaited `new()`;
    /// for DoH/DoT/DoQ, binding appears to happen in `run()`. Confirm before relying on the
    /// signal as "ready".
    fn spawn_with_startup_signal<F>(serve: F) -> tokio::sync::oneshot::Receiver<()>
    where
        F: std::future::Future<Output = ()> + Send + 'static,
    {
        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            // The caller may have dropped the receiver; that is not an error.
            let _ = tx.send(());
            serve.await;
        });
        rx
    }

    /// Reads the `cert_file` / `key_file` string arguments required by the TLS-based servers.
    ///
    /// Logs a warning naming `plugin_type` and returns `None` when either argument is missing
    /// or is not a string.
    #[cfg(any(feature = "doh", feature = "dot", feature = "doq"))]
    fn tls_file_paths<'a>(
        args: &'a HashMap<String, Value>,
        plugin_type: &str,
    ) -> Option<(&'a str, &'a str)> {
        let cert_path = args.get("cert_file").and_then(Value::as_str);
        let key_path = args.get("key_file").and_then(Value::as_str);
        match (cert_path, key_path) {
            (Some(cert_path), Some(key_path)) => Some((cert_path, key_path)),
            _ => {
                warn!("{} plugin configured without cert_file/key_file", plugin_type);
                None
            }
        }
    }

    /// Binds and spawns a UDP server (default listen address `0.0.0.0:53`).
    ///
    /// Returns `None`, after logging, when the address is invalid or `UdpServer::new` fails.
    async fn launch_udp_server(
        &self,
        plugin_config: &PluginConfig,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        let args = plugin_config.effective_args();
        let addr = self.parse_listen_addr(&args, "0.0.0.0:53")?;
        let config = ServerConfig {
            udp_addr: Some(addr),
            ..Default::default()
        };
        let handler = self.create_handler(self.get_entry(&args));
        let server = match UdpServer::new(config, handler).await {
            Ok(server) => server,
            Err(e) => {
                error!("Failed to start UDP server on {}: {}", addr, e);
                return None;
            }
        };
        Some(Self::spawn_with_startup_signal(async move {
            if let Err(e) = server.run().await {
                error!("UDP server error: {}", e);
            }
        }))
    }

    /// Binds and spawns a TCP server (default listen address `0.0.0.0:53`).
    ///
    /// Returns `None`, after logging, when the address is invalid or `TcpServer::new` fails.
    async fn launch_tcp_server(
        &self,
        plugin_config: &PluginConfig,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        let args = plugin_config.effective_args();
        let addr = self.parse_listen_addr(&args, "0.0.0.0:53")?;
        let config = ServerConfig {
            tcp_addr: Some(addr),
            ..Default::default()
        };
        let handler = self.create_handler(self.get_entry(&args));
        let server = match TcpServer::new(config, handler).await {
            Ok(server) => server,
            Err(e) => {
                error!("Failed to start TCP server on {}: {}", addr, e);
                return None;
            }
        };
        Some(Self::spawn_with_startup_signal(async move {
            if let Err(e) = server.run().await {
                error!("TCP server error: {}", e);
            }
        }))
    }

    /// Spawns a DNS-over-HTTPS server (default listen address `0.0.0.0:443`).
    ///
    /// Requires `cert_file` and `key_file`. The TLS material is loaded up front, so a bad
    /// certificate or key makes this return `None` (after logging) instead of spawning.
    #[cfg(feature = "doh")]
    async fn launch_doh_server(
        &self,
        plugin_config: &PluginConfig,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        let args = plugin_config.effective_args();
        let addr = self.parse_listen_addr(&args, "0.0.0.0:443")?;
        let (cert_path, key_path) = Self::tls_file_paths(&args, "doh_server")?;
        let tls = match TlsConfig::from_files(cert_path, key_path) {
            Ok(tls) => tls,
            Err(e) => {
                error!("Failed to load TLS config for DoH: {}", e);
                return None;
            }
        };
        let handler = self.create_handler(self.get_entry(&args));
        let server = DohServer::new(addr.to_string(), tls, handler);
        Some(Self::spawn_with_startup_signal(async move {
            if let Err(e) = server.run().await {
                error!("DoH server error: {}", e);
            }
        }))
    }

    /// Stub used when the `doh` cargo feature is disabled: logs a warning and launches nothing.
    #[cfg(not(feature = "doh"))]
    async fn launch_doh_server(
        &self,
        _plugin_config: &PluginConfig,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        // Name the actual cargo feature so the user knows what to enable.
        warn!("DoH server requested but doh feature is not enabled");
        None
    }

    /// Spawns a DNS-over-TLS server (default listen address `0.0.0.0:853`).
    ///
    /// Requires `cert_file` and `key_file`. The TLS material is loaded up front, so a bad
    /// certificate or key makes this return `None` (after logging) instead of spawning.
    #[cfg(feature = "dot")]
    async fn launch_dot_server(
        &self,
        plugin_config: &PluginConfig,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        let args = plugin_config.effective_args();
        let addr = self.parse_listen_addr(&args, "0.0.0.0:853")?;
        let (cert_path, key_path) = Self::tls_file_paths(&args, "dot_server")?;
        let tls = match TlsConfig::from_files(cert_path, key_path) {
            Ok(tls) => tls,
            Err(e) => {
                error!("Failed to load TLS config for DoT: {}", e);
                return None;
            }
        };
        let handler = self.create_handler(self.get_entry(&args));
        let server = DotServer::new(addr.to_string(), tls, handler);
        Some(Self::spawn_with_startup_signal(async move {
            if let Err(e) = server.run().await {
                error!("DoT server error: {}", e);
            }
        }))
    }

    /// Stub used when the `dot` cargo feature is disabled: logs a warning and launches nothing.
    #[cfg(not(feature = "dot"))]
    async fn launch_dot_server(
        &self,
        _plugin_config: &PluginConfig,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        // Name the actual cargo feature so the user knows what to enable.
        warn!("DoT server requested but dot feature is not enabled");
        None
    }

    /// Spawns a DNS-over-QUIC server (default listen address `0.0.0.0:784`).
    ///
    /// Requires `cert_file` and `key_file`. Unlike DoH/DoT, the paths are passed straight to
    /// `DoqServer::new` without being loaded here.
    /// NOTE(review): certificate errors presumably surface only from `run()`. Confirm.
    #[cfg(feature = "doq")]
    async fn launch_doq_server(
        &self,
        plugin_config: &PluginConfig,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        let args = plugin_config.effective_args();
        let addr = self.parse_listen_addr(&args, "0.0.0.0:784")?;
        let (cert_path, key_path) = Self::tls_file_paths(&args, "doq_server")?;
        let handler = self.create_handler(self.get_entry(&args));
        let server = DoqServer::new(addr.to_string(), cert_path, key_path, handler);
        Some(Self::spawn_with_startup_signal(async move {
            if let Err(e) = server.run().await {
                error!("DoQ server error: {}", e);
            }
        }))
    }

    /// Stub used when the `doq` cargo feature is disabled: logs a warning and launches nothing.
    #[cfg(not(feature = "doq"))]
    async fn launch_doq_server(
        &self,
        _plugin_config: &PluginConfig,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        warn!("DoQ server requested but DoQ feature is not enabled");
        None
    }

    /// Starts the admin API server when `config.admin.enabled` is set.
    ///
    /// Returns `None` when the admin server is disabled. Otherwise the returned receiver's
    /// sender is handed to `AdminServer::run_with_signal`, so when (and whether) it fires is
    /// decided by that server.
    #[cfg(feature = "admin")]
    pub async fn launch_admin_server(
        &self,
        config: Arc<RwLock<crate::config::Config>>,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        // Read what we need, then release the lock before handing `config` to the server.
        let addr = {
            let cfg = config.read().await;
            if !cfg.admin.enabled {
                return None;
            }
            normalize_listen_addr(&cfg.admin.addr)
        };
        info!("Starting admin API server on {}", addr);
        let state = AdminState::new(Arc::clone(&config), Arc::clone(&self.registry));
        let server = AdminServer::new(addr, state);
        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            info!("Admin server task started");
            if let Err(e) = server.run_with_signal(Some(tx), None).await {
                error!("Admin server error: {}", e);
            }
            info!("Admin server task finished");
        });
        Some(rx)
    }

    /// Stub used when the `admin` cargo feature is disabled.
    ///
    /// Warns if the config asks for the admin server anyway. Always returns `None`.
    #[cfg(not(feature = "admin"))]
    pub async fn launch_admin_server(
        &self,
        config: Arc<RwLock<crate::config::Config>>,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        if config.read().await.admin.enabled {
            warn!("Admin server requested but admin feature is not enabled");
        }
        None
    }

    /// Starts the monitoring server when `config.monitoring.enabled` is set.
    ///
    /// Returns `None` when monitoring is disabled. Otherwise the returned receiver's sender is
    /// handed to `MonitoringServer::run_with_signal`, so when (and whether) it fires is
    /// decided by that server.
    #[cfg(feature = "metrics")]
    pub async fn launch_monitoring_server(
        &self,
        config: Arc<RwLock<crate::config::Config>>,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        // Read what we need and release the lock immediately.
        let addr = {
            let cfg = config.read().await;
            if !cfg.monitoring.enabled {
                return None;
            }
            normalize_listen_addr(&cfg.monitoring.addr)
        };
        info!("Starting monitoring server on {}", addr);
        let server = MonitoringServer::new(addr);
        let (startup_tx, startup_rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            info!("Monitoring server task started");
            if let Err(e) = server.run_with_signal(Some(startup_tx), None).await {
                error!("Monitoring server error: {}", e);
            }
            info!("Monitoring server task finished");
        });
        Some(startup_rx)
    }

    /// Stub used when the `metrics` cargo feature is disabled.
    ///
    /// Warns if the config asks for monitoring anyway. Always returns `None`.
    #[cfg(not(feature = "metrics"))]
    pub async fn launch_monitoring_server(
        &self,
        config: Arc<RwLock<crate::config::Config>>,
    ) -> Option<tokio::sync::oneshot::Receiver<()>> {
        if config.read().await.monitoring.enabled {
            warn!("Monitoring server requested but metrics feature is not enabled");
        }
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugin::{Plugin, Registry};
    use async_trait::async_trait;
    use serde_yaml::Value;
    use std::collections::HashMap;

    /// Minimal plugin used to populate a registry for handler tests.
    #[derive(Debug)]
    struct MockPlugin;

    #[async_trait]
    impl Plugin for MockPlugin {
        async fn execute(&self, _ctx: &mut crate::plugin::Context) -> crate::Result<()> {
            Ok(())
        }

        fn name(&self) -> &str {
            "mock_plugin"
        }
    }

    /// Builds a launcher over an empty registry.
    fn new_launcher() -> ServerLauncher {
        ServerLauncher::new(Arc::new(Registry::new()))
    }

    /// Runs `launch_all` on a fresh runtime and returns how many startup receivers it
    /// produced. Spawned servers are torn down when the runtime is dropped on return.
    fn launch_count(plugins: &[PluginConfig]) -> usize {
        let launcher = new_launcher();
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(launcher.launch_all(plugins)).len()
    }

    /// A server plugin config with the given `listen` address.
    fn server_plugin(plugin_type: &str, listen: &str) -> PluginConfig {
        PluginConfig::new(plugin_type.to_string())
            .with_arg("listen".to_string(), Value::String(listen.to_string()))
    }

    /// A TLS server plugin config whose certificate and key paths do not exist.
    fn tls_plugin_with_missing_files(plugin_type: &str) -> PluginConfig {
        server_plugin(plugin_type, "127.0.0.1:0")
            .with_arg(
                "cert_file".to_string(),
                Value::String("/nonexistent/cert.pem".to_string()),
            )
            .with_arg(
                "key_file".to_string(),
                Value::String("/nonexistent/key.pem".to_string()),
            )
    }

    #[test]
    fn test_normalize_listen_addr() {
        assert_eq!(normalize_listen_addr(":5353"), "0.0.0.0:5353");
        assert_eq!(normalize_listen_addr("127.0.0.1:8080"), "127.0.0.1:8080");
        assert_eq!(normalize_listen_addr("0.0.0.0:53"), "0.0.0.0:53");
        assert_eq!(normalize_listen_addr("[::1]:53"), "[::1]:53");
        assert_eq!(normalize_listen_addr("localhost:8080"), "localhost:8080");
    }

    #[test]
    fn test_parse_listen_addr_with_ipv6() {
        let launcher = new_launcher();
        let mut args = HashMap::new();
        args.insert(
            "listen".to_string(),
            Value::String("[::1]:5353".to_string()),
        );
        let addr = launcher.parse_listen_addr(&args, "0.0.0.0:53");
        assert_eq!(addr, Some("[::1]:5353".parse().unwrap()));
    }

    #[test]
    fn test_parse_listen_addr_rejects_hostname() {
        // SocketAddr parsing accepts only IP literals; hostnames are not resolved.
        let launcher = new_launcher();
        let mut args = HashMap::new();
        args.insert(
            "listen".to_string(),
            Value::String("localhost:8080".to_string()),
        );
        let addr = launcher.parse_listen_addr(&args, "0.0.0.0:53");
        assert!(addr.is_none());
    }

    #[test]
    fn test_get_entry_with_non_string_value() {
        let launcher = new_launcher();
        let mut args = HashMap::new();
        args.insert(
            "entry".to_string(),
            Value::Number(serde_yaml::Number::from(42)),
        );
        let entry = launcher.get_entry(&args);
        assert_eq!(entry, "main_sequence");
    }

    #[test]
    fn test_launch_all_with_multiple_plugins() {
        // The unknown plugin type is skipped; both real servers should start.
        let plugins = vec![
            server_plugin("udp_server", "127.0.0.1:0"),
            PluginConfig::new("unknown_plugin".to_string()),
            server_plugin("tcp_server", "127.0.0.1:0"),
        ];
        assert_eq!(launch_count(&plugins), 2);
    }

    #[test]
    fn test_launch_all_with_tcp_server_config() {
        let plugins = vec![server_plugin("tcp_server", "127.0.0.1:0")];
        assert_eq!(launch_count(&plugins), 1);
    }

    #[cfg(feature = "metrics")]
    #[tokio::test]
    async fn test_launch_monitoring_server() {
        let launcher = new_launcher();
        let mut cfg = crate::config::Config::new();
        cfg.monitoring.enabled = true;
        cfg.monitoring.addr = "127.0.0.1:0".to_string();
        let cfg_arc = Arc::new(RwLock::new(cfg));
        let Some(startup_rx) = launcher
            .launch_monitoring_server(Arc::clone(&cfg_arc))
            .await
        else {
            panic!("Expected monitoring server to be launched");
        };
        let started = tokio::time::timeout(std::time::Duration::from_secs(2), startup_rx).await;
        assert!(started.is_ok());
    }

    #[cfg(feature = "metrics")]
    #[tokio::test]
    async fn test_launch_monitoring_server_with_shorthand_addr() {
        let launcher = new_launcher();
        let mut cfg = crate::config::Config::new();
        cfg.monitoring.enabled = true;
        cfg.monitoring.addr = ":0".to_string();
        let cfg_arc = Arc::new(RwLock::new(cfg));
        let Some(startup_rx) = launcher
            .launch_monitoring_server(Arc::clone(&cfg_arc))
            .await
        else {
            panic!("Expected monitoring server to be launched with shorthand address");
        };
        let started = tokio::time::timeout(std::time::Duration::from_secs(2), startup_rx).await;
        assert!(
            started.is_ok(),
            "Monitoring server should start with shorthand address"
        );
    }

    #[cfg(feature = "admin")]
    #[tokio::test]
    async fn test_launch_admin_server_with_shorthand_addr() {
        let launcher = new_launcher();
        let mut cfg = crate::config::Config::new();
        cfg.admin.enabled = true;
        cfg.admin.addr = ":0".to_string();
        let cfg_arc = Arc::new(RwLock::new(cfg));
        let Some(startup_rx) = launcher.launch_admin_server(Arc::clone(&cfg_arc)).await else {
            panic!("Expected admin server to be launched with shorthand address");
        };
        let started = tokio::time::timeout(std::time::Duration::from_secs(2), startup_rx).await;
        assert!(
            started.is_ok(),
            "Admin server should start with shorthand address"
        );
    }

    #[cfg(feature = "doh")]
    #[test]
    fn test_launch_all_with_doh_server_config() {
        // TLS material is loaded before spawning, so missing files mean no server.
        let plugins = vec![tls_plugin_with_missing_files("doh_server")];
        assert_eq!(launch_count(&plugins), 0);
    }

    #[cfg(feature = "dot")]
    #[test]
    fn test_launch_all_with_dot_server_config() {
        // TLS material is loaded before spawning, so missing files mean no server.
        let plugins = vec![tls_plugin_with_missing_files("dot_server")];
        assert_eq!(launch_count(&plugins), 0);
    }

    #[cfg(feature = "doq")]
    #[test]
    fn test_launch_all_with_doq_server_config() {
        // DoQ passes the paths through unvalidated; this only checks that launching
        // does not panic.
        let plugins = vec![tls_plugin_with_missing_files("doq_server")];
        let _ = launch_count(&plugins);
    }

    #[test]
    fn test_server_launcher_creation() {
        let launcher = new_launcher();
        let _ = launcher;
    }

    #[test]
    fn test_parse_listen_addr_with_valid_address() {
        let launcher = new_launcher();
        let mut args = HashMap::new();
        args.insert(
            "listen".to_string(),
            Value::String("127.0.0.1:5353".to_string()),
        );
        let addr = launcher.parse_listen_addr(&args, "0.0.0.0:53");
        assert_eq!(addr, Some("127.0.0.1:5353".parse().unwrap()));
    }

    #[test]
    fn test_parse_listen_addr_with_shorthand() {
        let launcher = new_launcher();
        let mut args = HashMap::new();
        args.insert("listen".to_string(), Value::String(":8080".to_string()));
        let addr = launcher.parse_listen_addr(&args, "0.0.0.0:53");
        assert_eq!(addr, Some("0.0.0.0:8080".parse().unwrap()));
    }

    #[test]
    fn test_parse_listen_addr_with_default() {
        let launcher = new_launcher();
        let args = HashMap::new();
        let addr = launcher.parse_listen_addr(&args, "0.0.0.0:53");
        assert_eq!(addr, Some("0.0.0.0:53".parse().unwrap()));
    }

    #[test]
    fn test_parse_listen_addr_with_invalid_address() {
        let launcher = new_launcher();
        let mut args = HashMap::new();
        args.insert(
            "listen".to_string(),
            Value::String("invalid:address".to_string()),
        );
        let addr = launcher.parse_listen_addr(&args, "0.0.0.0:53");
        assert!(addr.is_none());
    }

    #[test]
    fn test_get_entry_with_custom_entry() {
        let launcher = new_launcher();
        let mut args = HashMap::new();
        args.insert(
            "entry".to_string(),
            Value::String("custom_sequence".to_string()),
        );
        let entry = launcher.get_entry(&args);
        assert_eq!(entry, "custom_sequence");
    }

    #[test]
    fn test_get_entry_with_default() {
        let launcher = new_launcher();
        let args = HashMap::new();
        let entry = launcher.get_entry(&args);
        assert_eq!(entry, "main_sequence");
    }

    #[test]
    fn test_create_handler() {
        let mut registry = Registry::new();
        registry.register(Arc::new(MockPlugin)).unwrap();
        let registry = Arc::new(registry);
        let launcher = ServerLauncher::new(Arc::clone(&registry));
        let handler = launcher.create_handler("mock_plugin".to_string());
        assert_eq!(handler.entry, "mock_plugin");
        assert!(Arc::ptr_eq(&handler.registry, &registry));
    }

    #[test]
    fn test_launch_all_with_empty_plugins() {
        assert_eq!(launch_count(&[]), 0);
    }

    #[test]
    fn test_launch_all_with_unknown_plugin_type() {
        let plugins = vec![PluginConfig::new("unknown_server".to_string())];
        assert_eq!(launch_count(&plugins), 0);
    }

    #[test]
    fn test_launch_all_with_udp_server_config() {
        let plugins = vec![server_plugin("udp_server", "127.0.0.1:0")];
        assert_eq!(launch_count(&plugins), 1);
    }
}