use std::{borrow::Cow, path::PathBuf};
use clap::{Args, Parser, ValueEnum};
use figment::{
providers::{Env, Format, Serialized, Toml, Yaml},
Figment, Provider,
};
use serde::{Deserialize, Deserializer, Serialize};
use tracing::level_filters::LevelFilter;
/// Let's Encrypt settings (the `lets_encrypt` section of the configuration).
///
/// NOTE(review): how these fields are consumed is not visible in this file;
/// the per-field descriptions are inferred from the field names.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ConfigLetsEncrypt {
    /// Contact e-mail address (presumably used for the ACME account).
    pub email: Cow<'static, str>,
    /// Whether Let's Encrypt support is turned on.
    pub enabled: Option<bool>,
    /// Whether to use the Let's Encrypt staging environment.
    pub staging: Option<bool>,
}

impl Default for ConfigLetsEncrypt {
    /// Placeholder e-mail `contact@example.com`, with `enabled` and `staging`
    /// both set to `Some(true)`.
    fn default() -> Self {
        let email = Cow::Borrowed("contact@example.com");
        Self {
            email,
            enabled: Some(true),
            staging: Some(true),
        }
    }
}
/// Filesystem locations (the `paths` section of the configuration).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ConfigPath {
    /// Location for Let's Encrypt data; defaults to `/etc/proksi/letsencrypt`.
    pub lets_encrypt: PathBuf,
}

impl Default for ConfigPath {
    fn default() -> Self {
        Self {
            lets_encrypt: "/etc/proksi/letsencrypt".into(),
        }
    }
}
/// A header to add to a route, given as a `name` / `value` pair.
#[derive(Debug, Serialize, Deserialize)]
pub struct ConfigRouteHeaderAdd {
    /// Header name, e.g. `X-Api-Version`.
    pub name: Cow<'static, str>,
    /// Header value, e.g. `1.0`.
    pub value: Cow<'static, str>,
}
/// A header to remove from a route, identified only by its name.
#[derive(Debug, Serialize, Deserialize)]
pub struct ConfigRouteHeaderRemove {
    /// Header name, e.g. `Server`.
    pub name: Cow<'static, str>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ConfigRouteHeader {
pub add: Vec<ConfigRouteHeaderAdd>,
pub remove: Vec<ConfigRouteHeaderRemove>,
}
/// One upstream entry of a `ConfigRoute`.
///
/// NOTE(review): `network` and `weight` are not interpreted in this file;
/// their semantics depend on the consumer.
#[derive(Debug, Serialize, Deserialize)]
pub struct ConfigRouteUpstream {
    /// Upstream address as written in the configuration. The fixtures use
    /// CIDR-style values such as `10.0.1.3/25`.
    pub ip: Cow<'static, str>,
    /// Port number. `u16` accepts every valid port (0-65535). A signed 16-bit
    /// type would reject ports above 32767 and accept negative values.
    pub port: u16,
    /// Optional network name, e.g. `public`.
    pub network: Option<String>,
    /// Optional weight.
    pub weight: Option<i8>,
}
/// A single entry of the configuration's `routes` list.
///
/// NOTE(review): the matching and forwarding logic is not visible in this
/// file; the field descriptions are inferred from names and test fixtures.
#[derive(Debug, Serialize, Deserialize)]
pub struct ConfigRoute {
    /// Host this route applies to, e.g. `example.com`.
    pub host: Cow<'static, str>,
    /// Optional header additions/removals for this route.
    pub headers: Option<ConfigRouteHeader>,
    /// Optional path suffix for this route.
    pub path_suffix: Option<Cow<'static, str>>,
    /// Optional path prefix for this route, e.g. `/api`.
    pub path_prefix: Option<Cow<'static, str>>,
    /// Upstreams for this route (required in the configuration).
    pub upstreams: Vec<ConfigRouteUpstream>,
}
/// Log verbosity, selectable with `--level` on the CLI or in configuration.
///
/// Converted to a `tracing` [`LevelFilter`] through `From<&LogLevel>`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy, ValueEnum)]
pub enum LogLevel {
    // Declaration order is the order clap lists the possible values in.
    // Plain `//` comments are used so clap attaches no per-value help text.
    Debug,
    Info,
    Warn,
    Error,
}
impl From<&LogLevel> for tracing::level_filters::LevelFilter {
fn from(val: &LogLevel) -> Self {
match val {
LogLevel::Debug => LevelFilter::DEBUG,
LogLevel::Info => LevelFilter::INFO,
LogLevel::Warn => LevelFilter::WARN,
LogLevel::Error => LevelFilter::ERROR,
}
}
}
// Logging options.
//
// This struct is flattened into the `Config` CLI as `--level`,
// `--access-logs-enabled` and `--error-logs-enabled`. It is also deserialized
// from config files and env vars. Plain `//` comments are used because clap
// turns `///` doc comments into `--help` text.
#[derive(Debug, Serialize, Deserialize, Clone, Args)]
// No `requires = "level"` here. `level` always has a default, so the constraint
// adds nothing. Clap 4 also only counts explicitly passed args for `requires`,
// so it would force `--level` whenever another logging flag is given.
// NOTE(review): confirm against the pinned clap version.
#[group(id = "logging")]
pub struct ConfigLogging {
    // Log level, converted to a tracing `LevelFilter`. Files and env vars accept
    // any ASCII casing (via `log_level_deser`). The CLI takes the clap
    // value-enum spelling, e.g. `info`.
    #[serde(deserialize_with = "log_level_deser")]
    #[arg(long, required = false, value_enum, default_value = "info")]
    pub level: LogLevel,

    // The boolean flags use `ArgAction::Set` with an optional value. A bare
    // `--access-logs-enabled` still means `true`, and
    // `--access-logs-enabled=false` (or `... false`) turns it off. Clap's
    // implicit `SetTrue` action for `bool` cannot express `false`, which
    // combined with `default_value = "true"` left the flag stuck on.
    #[arg(
        long,
        required = false,
        value_parser,
        default_value = "true",
        action = clap::ArgAction::Set,
        num_args = 0..=1,
        default_missing_value = "true"
    )]
    pub access_logs_enabled: bool,

    #[arg(
        long,
        required = false,
        value_parser,
        default_value = "false",
        action = clap::ArgAction::Set,
        num_args = 0..=1,
        default_missing_value = "true"
    )]
    pub error_logs_enabled: bool,
}
#[derive(Debug, Serialize, Deserialize, Parser)]
#[command(name = "Proksi")]
#[command(version, about, long_about = None)]
pub(crate) struct Config {
#[serde(default)]
#[clap(short, long, default_value = "proksi")]
pub service_name: Cow<'static, str>,
#[clap(short, long, default_value = "1")]
pub worker_threads: Option<usize>,
#[serde(skip)]
#[clap(short, long, default_value = "./")]
pub config_path: Cow<'static, str>,
#[command(flatten)]
pub logging: ConfigLogging,
#[clap(skip)]
pub lets_encrypt: ConfigLetsEncrypt,
#[clap(skip)]
pub paths: ConfigPath,
#[clap(skip)]
pub routes: Vec<ConfigRoute>,
}
impl Default for Config {
    /// Built-in defaults. This is the first (lowest-precedence) layer merged by
    /// `load_proxy_config`.
    fn default() -> Self {
        let logging = ConfigLogging {
            level: LogLevel::Info,
            access_logs_enabled: true,
            error_logs_enabled: false,
        };

        Self {
            service_name: Cow::Borrowed("proksi"),
            worker_threads: Some(1),
            // `config_path` is `#[serde(skip)]`, so this value is never
            // serialized into figment.
            config_path: Cow::Borrowed("/etc/proksi/config"),
            logging,
            lets_encrypt: ConfigLetsEncrypt::default(),
            paths: ConfigPath::default(),
            routes: Vec::new(),
        }
    }
}
impl Provider for Config {
    /// Labels every value supplied by this provider with the name "proksi".
    fn metadata(&self) -> figment::Metadata {
        figment::Metadata::named("proksi")
    }

    /// Emits the serialized form of this `Config` under figment's default profile.
    ///
    /// This serializes `self` rather than a fresh `Config::default()`, so a
    /// customised `Config` merged as a provider contributes its own values.
    /// `config_path` is `#[serde(skip)]` and is therefore never emitted.
    fn data(
        &self,
    ) -> Result<figment::value::Map<figment::Profile, figment::value::Dict>, figment::Error> {
        Serialized::defaults(self).data()
    }
}
pub fn load_proxy_config(fallback: &str) -> Result<Config, figment::Error> {
let parsed_commands = Config::parse();
let path_with_fallback = if parsed_commands.config_path.is_empty() {
fallback
} else {
&parsed_commands.config_path
};
let config: Config = Figment::new()
.merge(Config::default())
.merge(Serialized::defaults(&parsed_commands))
.merge(Yaml::file(format!("{}/proksi.yaml", path_with_fallback)))
.merge(Toml::file(format!("{}/proksi.toml", path_with_fallback)))
.merge(Env::prefixed("PROKSI_").split("__"))
.extract()?;
Ok(config)
}
fn log_level_deser<'de, D>(deserializer: D) -> Result<LogLevel, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
match s.to_lowercase().as_str() {
"debug" => Ok(LogLevel::Debug),
"info" => Ok(LogLevel::Info),
"warn" => Ok(LogLevel::Warn),
"error" => Ok(LogLevel::Error),
_ => Err(serde::de::Error::custom(
"expected one of DEBUG, INFO, WARN, ERROR",
)),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// YAML fixture covering every top-level section, including a route with
    /// headers and an upstream.
    fn helper_config_file() -> &'static str {
        r#"
service_name: "proksi"
logging:
  level: "INFO"
  access_logs_enabled: true
  error_logs_enabled: false
paths:
  lets_encrypt: "/test/letsencrypt"
routes:
  - host: "example.com"
    path_prefix: "/api"
    headers:
      add:
        - name: "X-Forwarded-For"
          value: "<value>"
        - name: "X-Api-Version"
          value: "1.0"
      remove:
        - name: "Server"
    upstreams:
      - ip: "10.0.1.3/25"
        port: 3000
        network: "public"
"#
    }

    #[test]
    fn test_load_config_from_yaml() {
        figment::Jail::expect_with(|jail| {
            let tmp_dir = jail.directory().to_string_lossy();
            jail.create_file(format!("{}/proksi.yaml", tmp_dir), helper_config_file())?;

            let proxy_config = load_proxy_config(&tmp_dir).unwrap();

            assert_eq!(proxy_config.service_name, "proksi");
            // "INFO" goes through the case-insensitive level deserializer.
            assert_eq!(proxy_config.logging.level, LogLevel::Info);
            assert_eq!(proxy_config.routes.len(), 1);

            let route = &proxy_config.routes[0];
            assert_eq!(route.host, "example.com");
            assert_eq!(route.path_prefix.as_deref(), Some("/api"));
            let headers = route.headers.as_ref().expect("fixture has a headers section");
            assert_eq!(headers.add.len(), 2);
            assert_eq!(headers.remove.len(), 1);
            assert_eq!(route.upstreams[0].port, 3000);
            assert_eq!(route.upstreams[0].network.as_deref(), Some("public"));
            Ok(())
        });
    }

    #[test]
    fn test_load_config_from_yaml_and_env_vars() {
        figment::Jail::expect_with(|jail| {
            jail.create_file(
                format!("{}/proksi.yaml", jail.directory().to_str().unwrap()),
                helper_config_file(),
            )?;
            jail.set_env("PROKSI_SERVICE_NAME", "new_name");
            jail.set_env("PROKSI_LOGGING__LEVEL", "warn");
            jail.set_env("PROKSI_LETS_ENCRYPT__STAGING", "false");
            jail.set_env("PROKSI_LETS_ENCRYPT__EMAIL", "my-real-email@domain.com");
            jail.set_env(
                "PROKSI_ROUTES",
                r#"[{
host="changed.example.com",
upstreams=[{ ip="10.0.1.2/24", port=3000, weight=1 }] }]
"#,
            );

            let proxy_config = load_proxy_config(jail.directory().to_str().unwrap()).unwrap();

            assert_eq!(proxy_config.service_name, "new_name");
            assert_eq!(proxy_config.logging.level, LogLevel::Warn);
            assert_eq!(proxy_config.lets_encrypt.staging, Some(false));
            assert_eq!(proxy_config.lets_encrypt.email, "my-real-email@domain.com");
            assert_eq!(proxy_config.routes[0].host, "changed.example.com");
            assert_eq!(proxy_config.routes[0].upstreams[0].ip, "10.0.1.2/24");
            assert_eq!(
                proxy_config.paths.lets_encrypt,
                PathBuf::from("/test/letsencrypt")
            );
            Ok(())
        });
    }

    #[test]
    fn test_load_config_with_defaults_only() {
        figment::Jail::expect_with(|_| {
            let proxy_config = load_proxy_config("/non-existent").unwrap();
            let logging = proxy_config.logging;

            assert_eq!(proxy_config.service_name, "proksi");
            assert_eq!(logging.level, LogLevel::Info);
            assert!(logging.access_logs_enabled);
            assert!(!logging.error_logs_enabled);
            assert!(proxy_config.routes.is_empty());
            Ok(())
        });
    }

    #[test]
    fn test_load_config_with_defaults_and_yaml() {
        figment::Jail::expect_with(|jail| {
            let tmp_dir = jail.directory().to_string_lossy();
            jail.create_file(
                format!("{}/proksi.yaml", tmp_dir),
                r#"
routes:
  - host: "example.com"
    upstreams:
      - ip: "10.1.2.24/24"
        port: 3000
"#,
            )?;

            let proxy_config = load_proxy_config(&tmp_dir).unwrap();
            let logging = proxy_config.logging;
            let paths = proxy_config.paths;
            let letsencrypt = proxy_config.lets_encrypt;

            assert_eq!(proxy_config.service_name, "proksi");
            assert_eq!(logging.level, LogLevel::Info);
            assert!(logging.access_logs_enabled);
            assert!(!logging.error_logs_enabled);
            assert_eq!(proxy_config.routes.len(), 1);
            assert_eq!(letsencrypt.email, "contact@example.com");
            assert_eq!(letsencrypt.enabled, Some(true));
            assert_eq!(letsencrypt.staging, Some(true));
            assert_eq!(paths.lets_encrypt.as_os_str(), "/etc/proksi/letsencrypt");
            Ok(())
        });
    }
}