use std::collections::HashMap;
use std::net::{SocketAddr, ToSocketAddrs};
use std::path::Path;
use std::sync::Arc;
use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};
use serde::Deserialize;
use tokio::sync::{RwLock, mpsc};
use crate::error::{Error, Result};
#[derive(Debug, Clone, Deserialize, Default, PartialEq, Eq)]
#[serde(default)]
pub struct TlsFileConfig {
pub cert: Option<String>,
pub key: Option<String>,
pub self_signed: bool,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ServerFileConfig {
pub listen: String,
pub default_target: Option<String>,
#[serde(default)]
pub routes: HashMap<String, String>,
#[serde(default)]
pub tls: TlsFileConfig,
}
impl ServerFileConfig {
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
let content = std::fs::read_to_string(path.as_ref()).map_err(|e| {
Error::config(format!(
"failed to read config file '{}': {}",
path.as_ref().display(),
e
))
})?;
Self::parse(&content)
}
pub fn parse(content: &str) -> Result<Self> {
let config: ServerFileConfig =
toml::from_str(content).map_err(|e| Error::config(format!("invalid config: {}", e)))?;
config.validate()?;
Ok(config)
}
fn validate(&self) -> Result<()> {
resolve_addr(&self.listen).map_err(|e| {
Error::config(format!("invalid listen address '{}': {}", self.listen, e))
})?;
if self.tls.self_signed && (self.tls.cert.is_some() || self.tls.key.is_some()) {
return Err(Error::config("cannot use self_signed with cert/key files"));
}
if self.tls.cert.is_some() != self.tls.key.is_some() {
return Err(Error::config("both cert and key must be provided"));
}
if self.routes.is_empty() && self.default_target.is_none() {
return Err(Error::config(
"at least one route or default_target is required",
));
}
Ok(())
}
pub fn has_tls(&self) -> bool {
self.tls.self_signed || self.tls.cert.is_some()
}
pub fn only_routing_changed(&self, other: &Self) -> bool {
self.listen == other.listen && self.tls == other.tls
}
}
fn resolve_addr(addr: impl ToSocketAddrs) -> std::io::Result<SocketAddr> {
addr.to_socket_addrs()?.next().ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::NotFound, "could not resolve address")
})
}
#[derive(Debug, Clone)]
pub struct ResolvedConfig {
pub listen_addr: SocketAddr,
pub default_target: Option<String>,
pub routes: HashMap<String, String>,
}
impl ResolvedConfig {
pub fn from_file_config(config: &ServerFileConfig) -> Result<Self> {
let listen_addr = resolve_addr(&config.listen)
.map_err(|e| Error::config(format!("invalid listen address: {}", e)))?;
Ok(Self {
listen_addr,
default_target: config.default_target.clone(),
routes: config.routes.clone(),
})
}
}
pub type SharedRoutingConfig = Arc<RwLock<ResolvedConfig>>;
pub fn create_shared_config(config: &ServerFileConfig) -> Result<SharedRoutingConfig> {
let resolved = ResolvedConfig::from_file_config(config)?;
Ok(Arc::new(RwLock::new(resolved)))
}
#[derive(Debug, Clone)]
pub enum ConfigChange {
RoutingOnly(ServerFileConfig),
FullRestart(ServerFileConfig),
Error(String),
}
pub struct ConfigWatcher {
_watcher: RecommendedWatcher,
receiver: mpsc::Receiver<ConfigChange>,
}
impl ConfigWatcher {
pub fn new(config_path: impl AsRef<Path>, current_config: ServerFileConfig) -> Result<Self> {
let path = config_path.as_ref().to_path_buf();
let (tx, rx) = mpsc::channel(16);
let current = Arc::new(std::sync::Mutex::new(current_config));
let tx_clone = tx.clone();
let path_clone = path.clone();
let current_clone = Arc::clone(¤t);
let mut watcher = notify::recommended_watcher(move |res: std::result::Result<Event, _>| {
if let Ok(event) = res {
if !matches!(
event.kind,
notify::EventKind::Modify(_) | notify::EventKind::Create(_)
) {
return;
}
if !event.paths.iter().any(|p| p.ends_with(&path_clone)) {
return;
}
std::thread::sleep(std::time::Duration::from_millis(50));
match ServerFileConfig::load(&path_clone) {
Ok(new_config) => {
let old_config = current_clone.lock().unwrap();
let change = if old_config.only_routing_changed(&new_config) {
ConfigChange::RoutingOnly(new_config.clone())
} else {
ConfigChange::FullRestart(new_config.clone())
};
drop(old_config);
*current_clone.lock().unwrap() = new_config;
let _ = tx_clone.blocking_send(change);
}
Err(e) => {
let _ = tx_clone.blocking_send(ConfigChange::Error(e.to_string()));
}
}
}
})
.map_err(|e| Error::config(format!("failed to create file watcher: {}", e)))?;
let watch_path = path.parent().unwrap_or(&path);
watcher
.watch(watch_path, RecursiveMode::NonRecursive)
.map_err(|e| Error::config(format!("failed to watch config file: {}", e)))?;
Ok(Self {
_watcher: watcher,
receiver: rx,
})
}
pub async fn recv(&mut self) -> Option<ConfigChange> {
self.receiver.recv().await
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_minimal_config() {
let config = ServerFileConfig::parse(
r#"
listen = "0.0.0.0:8080"
default_target = "127.0.0.1:22"
"#,
)
.unwrap();
assert_eq!(config.listen, "0.0.0.0:8080");
assert_eq!(config.default_target, Some("127.0.0.1:22".to_string()));
assert!(config.routes.is_empty());
}
#[test]
fn test_parse_full_config() {
let config = ServerFileConfig::parse(
r#"
listen = "0.0.0.0:8443"
default_target = "127.0.0.1:22"
[routes]
"/ssh" = "127.0.0.1:22"
"/db" = "127.0.0.1:5432"
[tls]
cert = "cert.pem"
key = "key.pem"
"#,
)
.unwrap();
assert_eq!(config.listen, "0.0.0.0:8443");
assert_eq!(config.routes.get("/ssh"), Some(&"127.0.0.1:22".to_string()));
assert_eq!(
config.routes.get("/db"),
Some(&"127.0.0.1:5432".to_string())
);
assert!(config.has_tls());
}
#[test]
fn test_parse_self_signed_tls() {
let config = ServerFileConfig::parse(
r#"
listen = "0.0.0.0:8443"
default_target = "127.0.0.1:22"
[tls]
self_signed = true
"#,
)
.unwrap();
assert!(config.tls.self_signed);
assert!(config.has_tls());
}
#[test]
fn test_invalid_listen_address() {
let result = ServerFileConfig::parse(
r#"
listen = "invalid"
default_target = "127.0.0.1:22"
"#,
);
assert!(result.is_err());
}
#[test]
fn test_no_routes_or_default() {
let result = ServerFileConfig::parse(
r#"
listen = "0.0.0.0:8080"
"#,
);
assert!(result.is_err());
}
#[test]
fn test_conflicting_tls_config() {
let result = ServerFileConfig::parse(
r#"
listen = "0.0.0.0:8080"
default_target = "127.0.0.1:22"
[tls]
cert = "cert.pem"
key = "key.pem"
self_signed = true
"#,
);
assert!(result.is_err());
}
#[test]
fn test_only_routing_changed() {
let config1 = ServerFileConfig::parse(
r#"
listen = "0.0.0.0:8080"
default_target = "127.0.0.1:22"
"#,
)
.unwrap();
let config2 = ServerFileConfig::parse(
r#"
listen = "0.0.0.0:8080"
default_target = "127.0.0.1:23"
"#,
)
.unwrap();
let config3 = ServerFileConfig::parse(
r#"
listen = "0.0.0.0:9090"
default_target = "127.0.0.1:22"
"#,
)
.unwrap();
assert!(config1.only_routing_changed(&config2));
assert!(!config1.only_routing_changed(&config3));
}
}