use crate::config::extended_hashmap::ExtendedMap;
use crate::config::tree::{parse_conf, ConfigNode};
use crate::logger::LogLevel;
use crate::proxy::{EqMutex, LoadBalancer};
use crate::rand::Lcg;
use std::collections::HashMap;
use std::env::{args, var};
use std::fs::File;
use std::io::Read;
use std::net::IpAddr;
use std::path::Path;
use std::time::Duration;
/// Complete server configuration, as produced by [`Config::load`] or [`Config::from_tree`].
#[derive(PartialEq, Debug)]
pub struct Config {
    /// Where the configuration was loaded from.
    pub source: ConfigSource,
    /// Address to bind to (`server.address`; `from_tree` defaults it to `0.0.0.0`).
    pub address: String,
    /// Port to listen on (`server.port`; `from_tree` defaults it to 80).
    pub port: u16,
    /// Number of threads (`server.threads`; `from_tree` defaults it to 32 and rejects 0).
    pub threads: usize,
    /// TLS settings; `from_tree` only fills this in when a certificate and key file are set.
    #[cfg(feature = "tls")]
    pub tls_config: Option<TlsConfig>,
    /// Server-wide WebSocket proxy target (`server.websocket`), if any.
    pub default_websocket_proxy: Option<String>,
    /// Host-specific configurations, one per host section of the tree.
    pub hosts: Vec<HostConfig>,
    /// Host built from the routes at the root of the tree; its `matches` is `*`.
    pub default_host: HostConfig,
    /// Plugins to load, one per plugin section of the tree.
    #[cfg(feature = "plugins")]
    pub plugins: Vec<PluginConfig>,
    /// Logging settings (`server.log.*`).
    pub logging: LoggingConfig,
    /// Cache settings (`server.cache.*`).
    pub cache: CacheConfig,
    /// IP blacklist settings (`server.blacklist.*`).
    pub blacklist: BlacklistConfig,
    /// Connection timeout from `server.timeout` (seconds); `None` when unset or 0.
    pub connection_timeout: Option<Duration>,
}
/// A host section: a host pattern together with the routes declared for it.
#[derive(PartialEq, Debug)]
pub struct HostConfig {
    /// Pattern the host is matched against (`*` for the default host).
    pub matches: String,
    /// Routes of this host, in the order they were parsed.
    pub routes: Vec<RouteConfig>,
}
/// The kind of handler a route was configured with; chosen by `parse_route` from the route's keys.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum RouteType {
    /// Configured with a `file` key; the route's `path` holds its value.
    File,
    /// Configured with a `directory` key; the route's `path` holds its value.
    Directory,
    /// Configured with a `proxy` key; the route carries a load balancer over the listed targets.
    Proxy,
    /// Configured with a `redirect` key; the route's `path` holds the redirect target.
    Redirect,
    /// Configured with only a `websocket` key and none of the keys above.
    ExclusiveWebSocket,
}
/// A single route: one matcher and what to do with requests that match it.
#[derive(PartialEq, Debug)]
pub struct RouteConfig {
    /// The kind of handler this route uses.
    pub route_type: RouteType,
    /// The (whitespace-trimmed) matcher this route was declared with.
    pub matches: String,
    /// File or directory path, or redirect target; `None` for proxy and WebSocket-only routes.
    pub path: Option<String>,
    /// Load balancer over the proxy targets; only set for [`RouteType::Proxy`] routes.
    pub load_balancer: Option<EqMutex<LoadBalancer>>,
    /// Value of the route's `websocket` key, if present.
    pub websocket_proxy: Option<String>,
}
/// Logging settings read from the `server.log.*` keys.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct LoggingConfig {
    /// Minimum level (`server.log.level`; `from_tree` defaults it to `Warn`).
    pub level: LogLevel,
    /// Whether console logging is enabled (`server.log.console`; `from_tree` defaults it to `true`).
    pub console: bool,
    /// Log file path (`server.log.file`), if any.
    pub file: Option<String>,
}
/// Cache limits read from `server.cache.size` and `server.cache.time`.
#[derive(Default, PartialEq, Eq, Debug, Clone, Copy)]
pub struct CacheConfig {
    /// Value of `server.cache.size` (0 when unset). NOTE(review): units presumably bytes — confirm.
    pub size_limit: usize,
    /// Value of `server.cache.time` (0 when unset). NOTE(review): units presumably seconds — confirm.
    pub time_limit: usize,
}
/// IP blacklist settings read from the `server.blacklist.*` keys.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct BlacklistConfig {
    /// Addresses parsed from the file named by `server.blacklist.file` (empty when unset).
    pub list: Vec<IpAddr>,
    /// How blacklisted clients are handled (`server.blacklist.mode`, default `block`).
    pub mode: BlacklistMode,
}
/// TLS settings read from the `server.tls.*` keys.
#[cfg(feature = "tls")]
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct TlsConfig {
    /// Certificate file path (`server.tls.cert_file`).
    pub cert_file: String,
    /// Private key file path (`server.tls.key_file`).
    pub key_file: String,
    /// Whether HTTPS redirects are forced (`server.tls.force` set to `true`).
    pub force: bool,
}
/// One plugin section of the configuration.
#[cfg(feature = "plugins")]
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct PluginConfig {
    /// Name of the plugin section.
    pub name: String,
    /// Value of the section's compulsory `library` key.
    pub library: String,
    /// Every other key/value pair of the section, passed through as strings.
    pub config: HashMap<String, String>,
}
/// Target-selection strategy for proxy routes, from a route's `load_balancer_mode` key.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum LoadBalancerMode {
    /// `round-robin` (the default). NOTE(review): presumably cycles through targets in order.
    RoundRobin,
    /// `random`. NOTE(review): presumably picks targets using the balancer's LCG.
    Random,
}
/// How blacklisted clients are handled, from `server.blacklist.mode`.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum BlacklistMode {
    /// `block` (the default). NOTE(review): presumably drops the connection — confirm.
    Block,
    /// `forbidden`. NOTE(review): presumably answers with a 403 response — confirm.
    Forbidden,
}
/// Where a [`Config`] came from, as decided by [`Config::load`].
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum ConfigSource {
    /// The path was given as the first command-line argument.
    Argument,
    /// The path was read from the `HUMPHREY_CONF` environment variable.
    EnvironmentVariable,
    /// `humphrey.conf` was found in the current directory.
    CurrentDirectory,
    /// `Config::load` located no file; `Config::from_tree` also initialises `source` to this.
    Default,
}
impl Config {
    /// Locates, reads and parses the server configuration file.
    ///
    /// The file is chosen by the first rule that applies:
    /// 1. the path given as the first command-line argument ([`ConfigSource::Argument`]);
    /// 2. `humphrey.conf` in the current directory ([`ConfigSource::CurrentDirectory`]);
    /// 3. the path in the `HUMPHREY_CONF` environment variable
    ///    ([`ConfigSource::EnvironmentVariable`]).
    ///
    /// If no rule applies, or the chosen file cannot be opened or read, `Config::default()` is
    /// returned instead.
    ///
    /// # Errors
    ///
    /// Returns a message if the file was read but could not be parsed, or if it contains an
    /// invalid setting (see [`Config::from_tree`]).
    pub fn load() -> Result<Self, String> {
        let (path, source) = if let Some(arg_path) = args().nth(1) {
            (arg_path, ConfigSource::Argument)
        } else if Path::new("humphrey.conf").exists() {
            ("humphrey.conf".into(), ConfigSource::CurrentDirectory)
        } else if let Ok(env_path) = var("HUMPHREY_CONF") {
            (env_path, ConfigSource::EnvironmentVariable)
        } else {
            ("".into(), ConfigSource::Default)
        };

        // An unreadable file (including the empty path used for `Default`) falls back to the
        // default configuration rather than failing.
        if let Ok((filename, config_string)) = load_config_file(path) {
            let tree = parse_conf(&config_string, &filename).map_err(|e| e.to_string())?;
            let mut config = Self::from_tree(tree)?;
            config.source = source;
            Ok(config)
        } else {
            Ok(Config::default())
        }
    }

    /// Builds a configuration from a parsed configuration tree.
    ///
    /// The returned value has `source` set to [`ConfigSource::Default`]; [`Config::load`]
    /// overwrites it with the actual source.
    ///
    /// # Errors
    ///
    /// Returns a static message when a value cannot be parsed (port, threads, timeout, log
    /// settings, cache limits, blacklist entries, routes, plugins) or when settings conflict
    /// (fewer than one thread, incomplete TLS files, forced HTTPS with too few threads or a
    /// port other than 443).
    pub fn from_tree(tree: ConfigNode) -> Result<Self, &'static str> {
        // Flatten the tree so that settings can be looked up by dotted keys like `server.port`.
        let mut hashmap: HashMap<String, ConfigNode> = HashMap::new();
        tree.flatten(&mut hashmap, &Vec::new());

        let address = hashmap.get_optional("server.address", "0.0.0.0".into());
        let port: u16 = hashmap.get_optional_parsed("server.port", 80, "Invalid port")?;
        let threads: usize =
            hashmap.get_optional_parsed("server.threads", 32, "Invalid number of threads")?;
        let default_websocket_proxy = hashmap.get_owned("server.websocket");
        let connection_timeout_seconds: u64 =
            hashmap.get_optional_parsed("server.timeout", 0, "Invalid connection timeout")?;
        // `server.timeout` is in seconds; 0 (the default) means no timeout is configured.
        let connection_timeout = if connection_timeout_seconds > 0 {
            Some(Duration::from_secs(connection_timeout_seconds))
        } else {
            None
        };

        if threads < 1 {
            return Err("You cannot specify less than 1 thread");
        }

        let blacklist = {
            // One IP address per entry of the list file named by `server.blacklist.file`.
            let list = load_list_file(hashmap.get_owned("server.blacklist.file"))?
                .iter()
                .map(|ip| ip.parse::<IpAddr>())
                .collect::<Result<Vec<IpAddr>, _>>()
                .map_err(|_| "Could not parse IP address in blacklist file")?;
            let blacklist_mode = hashmap.get_optional("server.blacklist.mode", "block".into());
            let mode = match blacklist_mode.as_ref() {
                "block" => BlacklistMode::Block,
                "forbidden" => BlacklistMode::Forbidden,
                _ => return Err("Invalid blacklist mode"),
            };
            BlacklistConfig { list, mode }
        };

        #[cfg(feature = "tls")]
        let tls_config = {
            let cert_file = hashmap.get_owned("server.tls.cert_file");
            let key_file = hashmap.get_owned("server.tls.key_file");
            let force_setting = hashmap.get_optional("server.tls.force", "false".into());
            let force = force_setting == "true";
            if force && threads < 2 {
                return Err("A minimum of two threads are required to force HTTPS");
            }
            if force && port != 443 {
                return Err("Forcing HTTPS redirects requires the port to be 443");
            }
            match (cert_file, key_file) {
                (Some(cert_file), Some(key_file)) => Some(TlsConfig {
                    cert_file,
                    key_file,
                    force,
                }),
                (Some(_), None) => return Err("Missing key file for TLS"),
                // A key without a certificate is as incomplete as the reverse; reject it instead
                // of silently starting without TLS.
                (None, Some(_)) => return Err("Missing certificate file for TLS"),
                (None, None) => None,
            }
        };

        let logging = {
            let log_level = hashmap.get_optional_parsed(
                "server.log.level",
                LogLevel::Warn,
                "Invalid log level",
            )?;
            let log_file = hashmap.get_owned("server.log.file");
            let log_console = hashmap.get_optional_parsed(
                "server.log.console",
                true,
                "server.log.console must be a boolean",
            )?;
            LoggingConfig {
                level: log_level,
                console: log_console,
                file: log_file,
            }
        };

        let cache = {
            let cache_size =
                hashmap.get_optional_parsed("server.cache.size", 0_usize, "Invalid cache size")?;
            let cache_time =
                hashmap.get_optional_parsed("server.cache.time", 0_usize, "Invalid cache time")?;
            CacheConfig {
                size_limit: cache_size,
                time_limit: cache_time,
            }
        };

        // Routes at the root of the tree form the catch-all host.
        let default_host = parse_host("*", &tree)?;
        let hosts = {
            let hosts_map = tree.get_hosts();
            let mut hosts: Vec<HostConfig> = Vec::with_capacity(hosts_map.len());
            for (host, conf) in hosts_map {
                hosts.push(parse_host(&host, &conf)?);
            }
            hosts
        };

        #[cfg(feature = "plugins")]
        let plugins = {
            let plugins_map = tree.get_plugins();
            let mut plugins: Vec<PluginConfig> = Vec::new();
            for (name, conf) in plugins_map {
                let library = conf.get_compulsory("library", "Plugin library not specified")?;
                // Every key other than `library` is forwarded to the plugin as a string; a
                // non-value entry is a configuration error rather than a panic.
                let additional_config = conf
                    .iter()
                    .filter(|(key, _)| key.as_str() != "library")
                    .map(|(key, value)| {
                        value
                            .get_string()
                            .map(|value| (key.clone(), value))
                            .ok_or("Plugin configuration values must be strings")
                    })
                    .collect::<Result<HashMap<String, String>, &'static str>>()?;
                plugins.push(PluginConfig {
                    name,
                    library,
                    config: additional_config,
                });
            }
            plugins
        };

        Ok(Config {
            source: ConfigSource::Default,
            address,
            port,
            threads,
            #[cfg(feature = "tls")]
            tls_config,
            default_websocket_proxy,
            default_host,
            hosts,
            #[cfg(feature = "plugins")]
            plugins,
            logging,
            cache,
            blacklist,
            connection_timeout,
        })
    }

    /// Returns route number `route` of host number `host`.
    ///
    /// Host index `0` refers to [`Config::default_host`]; any other index `n` refers to
    /// `self.hosts[n - 1]`.
    ///
    /// # Panics
    ///
    /// Panics if either index is out of bounds.
    pub fn get_route(&self, host: usize, route: usize) -> &RouteConfig {
        let host_config = if host == 0 {
            &self.default_host
        } else {
            &self.hosts[host - 1]
        };
        &host_config.routes[route]
    }
}
/// Reads the whole file at `path`.
///
/// On success, returns `(path, contents)`. Returns `Err(())` if the file cannot be opened or its
/// contents cannot be read into a `String`; the underlying I/O error is discarded.
fn load_config_file(path: impl AsRef<str>) -> Result<(String, String), ()> {
    let path: &str = path.as_ref();
    let mut contents = String::new();
    File::open(path)
        .and_then(|mut file| file.read_to_string(&mut contents))
        .map_err(|_| ())?;
    Ok((path.to_owned(), contents))
}
/// Reads a newline-separated list from the file at `path`.
///
/// Each line is trimmed of surrounding whitespace and blank lines are skipped, so indentation,
/// trailing spaces or empty lines in the file do not produce bogus entries (which would, for
/// example, fail to parse as IP addresses). Returns an empty list when `path` is `None`.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, or if its contents cannot be read into a
/// `String` (I/O failure or invalid UTF-8).
fn load_list_file(path: Option<String>) -> Result<Vec<String>, &'static str> {
    let path = match path {
        Some(path) => path,
        None => return Ok(Vec::new()),
    };

    let mut file = File::open(path).map_err(|_| "List file could not be opened")?;
    let mut contents = String::new();
    file.read_to_string(&mut contents)
        .map_err(|_| "List file could not be read")?;

    Ok(contents
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(String::from)
        .collect())
}
/// Builds a [`HostConfig`] whose pattern is `wild` from the route sections found in `node`.
///
/// Each route section may expand into several routes (one per comma-separated matcher); all of
/// them are kept, in the order `node.get_routes()` yields the sections.
///
/// # Errors
///
/// Propagates the first error returned while parsing a route section.
fn parse_host(wild: &str, node: &ConfigNode) -> Result<HostConfig, &'static str> {
    let per_section: Vec<Vec<RouteConfig>> = node
        .get_routes()
        .into_iter()
        .map(|(route_wild, route_conf)| parse_route(&route_wild, route_conf))
        .collect::<Result<_, _>>()?;

    Ok(HostConfig {
        matches: wild.to_owned(),
        routes: per_section.into_iter().flatten().collect(),
    })
}
/// Parses one route section into a [`RouteConfig`] per matcher.
///
/// `wild` may contain several matchers separated by commas. Each one is trimmed of surrounding
/// whitespace and gets its own `RouteConfig`, built from the same settings in `conf`. The route
/// type is decided by the first of these keys present in `conf`: `file`, `directory`, `proxy`,
/// `redirect`. A route with none of them is valid only if it has a `websocket` key, in which case
/// it becomes a [`RouteType::ExclusiveWebSocket`] route.
///
/// For proxy routes, `proxy` is a comma-separated list of targets (each trimmed, empty entries
/// ignored). `load_balancer_mode` selects `round-robin` (default) or `random`.
///
/// # Errors
///
/// Returns an error if:
/// - no route type can be determined;
/// - a type key does not hold a plain value;
/// - `load_balancer_mode` is invalid;
/// - `proxy` lists no targets.
fn parse_route(
    wild: &str,
    conf: HashMap<String, ConfigNode>,
) -> Result<Vec<RouteConfig>, &'static str> {
    // The WebSocket proxy is shared by every matcher in `wild`, so read it once.
    let websocket_proxy = conf.get_owned("websocket");
    let mut routes: Vec<RouteConfig> = Vec::new();

    for matches in wild.split(',').map(|s| s.trim()) {
        let (route_type, path, load_balancer) = if conf.contains_key("file") {
            let file = conf.get_compulsory("file", "Invalid value for the `file` field of a route")?;
            (RouteType::File, Some(file), None)
        } else if conf.contains_key("directory") {
            let directory = conf.get_compulsory(
                "directory",
                "Invalid value for the `directory` field of a route",
            )?;
            (RouteType::Directory, Some(directory), None)
        } else if conf.contains_key("proxy") {
            // Trim targets just like the matchers above, so "a:80, b:80" does not produce a
            // target with a leading space; drop empty entries such as one from a trailing comma.
            let targets: Vec<String> = conf
                .get_compulsory("proxy", "Invalid value for the `proxy` field of a route")?
                .split(',')
                .map(|s| s.trim())
                .filter(|s| !s.is_empty())
                .map(|s| s.to_owned())
                .collect();
            if targets.is_empty() {
                return Err("The `proxy` field of a route must list at least one target");
            }

            let load_balancer_mode = conf.get_optional("load_balancer_mode", "round-robin".into());
            let load_balancer_mode = match load_balancer_mode.as_str() {
                "round-robin" => LoadBalancerMode::RoundRobin,
                "random" => LoadBalancerMode::Random,
                _ => {
                    return Err(
                        "Invalid load balancer mode, valid options are `round-robin` or `random`",
                    )
                }
            };

            let load_balancer = EqMutex::new(LoadBalancer {
                targets,
                mode: load_balancer_mode,
                lcg: Lcg::new(),
                index: 0,
            });
            (RouteType::Proxy, None, Some(load_balancer))
        } else if conf.contains_key("redirect") {
            let target =
                conf.get_compulsory("redirect", "Invalid value for the `redirect` field of a route")?;
            (RouteType::Redirect, Some(target), None)
        } else if conf.contains_key("websocket") {
            (RouteType::ExclusiveWebSocket, None, None)
        } else {
            return Err("Invalid route configuration, every route must contain either the `file`, `directory`, `proxy` or `redirect` field, unless it defines a WebSocket proxy with the `websocket` field");
        };

        routes.push(RouteConfig {
            route_type,
            matches: matches.to_string(),
            path,
            load_balancer,
            websocket_proxy: websocket_proxy.clone(),
        });
    }

    Ok(routes)
}