use std::{borrow::Cow, collections::HashMap, path::PathBuf};
use clap::{Args, Parser, ValueEnum};
use figment::{
providers::{Env, Format, Serialized, Yaml},
Figment, Provider,
};
use hcl::Hcl;
use serde::{Deserialize, Deserializer, Serialize};
use tracing::level_filters::LevelFilter;
mod hcl;
mod validate;
/// Serde default helper that always yields `true`.
///
/// Referenced by name from `#[serde(default = "bool_true")]` attributes.
fn bool_true() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
/// Serde default for `RouteSsl::max_proto`: the highest protocol version
/// this module knows about.
fn default_proto_version() -> ProtoVersion {
    let newest = ProtoVersion::V1_3;
    newest
}
/// Serde default for `RouteSsl::min_proto`.
fn default_proto_version_min() -> ProtoVersion {
    let floor = ProtoVersion::V1_2;
    floor
}
/// Serde default (in seconds) shared by `RouteCache::stale_if_error_secs`
/// and `RouteCache::stale_while_revalidate_secs`.
fn default_stale_secs() -> u32 {
    const DEFAULT_STALE_SECS: u32 = 60;
    DEFAULT_STALE_SECS
}
/// Serde default for `RouteCache::expires_in_secs`: one hour, in seconds.
fn default_cache_expire_secs() -> u64 {
    const SECS_PER_MINUTE: u64 = 60;
    const MINUTES_PER_HOUR: u64 = 60;
    SECS_PER_MINUTE * MINUTES_PER_HOUR
}
/// Serde default for `RouteCache::cache_type`.
fn default_cache_type() -> RouteCacheType {
    let fallback = RouteCacheType::MemCache;
    fallback
}
// Docker discovery mode. Selectable on the CLI through `--docker.mode` (clap
// `ValueEnum`) and read case-insensitively from config sources by `docker_mode_deser`.
// Plain `//` comments are used on purpose: `///` on clap-derived items would be
// picked up as `--help` text.
#[derive(Debug, Serialize, Deserialize, Clone, ValueEnum)]
pub(crate) enum DockerServiceMode {
// Presumably Docker Swarm services -- confirm against the Docker integration.
Swarm,
// Presumably standalone containers; this is the default mode.
Container,
}
// Docker integration settings. Exposed as `--docker.*` CLI flags (clap `Args`)
// and as the `docker` section of the configuration (serde).
// Plain `//` comments are used on purpose: `///` on clap-derived fields would
// become `--help` text and change the CLI output.
// NOTE(review): the group `requires` the `level` argument, which is declared on
// `Logging` and always has a default -- confirm this coupling is intentional.
#[derive(Debug, Serialize, Deserialize, Clone, Args)]
#[group(id = "docker", requires = "level")]
pub struct Docker {
// Interval in seconds (CLI default 15); presumably the polling period -- TODO confirm.
#[arg(
long = "docker.interval_secs",
required = false,
value_parser,
default_value = "15"
)]
pub interval_secs: Option<u64>,
// Docker endpoint; defaults to the local unix socket.
#[arg(
long = "docker.endpoint",
required = false,
value_parser,
default_value = "unix:///var/run/docker.sock"
)]
pub endpoint: Option<Cow<'static, str>>,
// Enables the Docker integration (CLI default: false). The explicit `id`
// keeps the argument id distinct from `log.enabled` on `Logging`.
#[arg(
long = "docker.enabled",
required = false,
value_parser,
default_value = "false",
id = "docker.enabled"
)]
pub enabled: Option<bool>,
// Discovery mode; config sources are parsed case-insensitively by `docker_mode_deser`.
#[serde(deserialize_with = "docker_mode_deser")]
#[arg(
long = "docker.mode",
required = false,
value_enum,
default_value = "container"
)]
pub mode: DockerServiceMode,
}
impl Default for Docker {
fn default() -> Self {
Self {
interval_secs: Some(15),
endpoint: Some(Cow::Borrowed("unix:///var/run/docker.sock")),
enabled: Some(false),
mode: DockerServiceMode::Container,
}
}
}
/// Let's Encrypt settings (the `lets_encrypt` section of the configuration).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LetsEncrypt {
/// Contact e-mail address; the only field without an `Option` wrapper.
pub email: Cow<'static, str>,
/// Whether the integration is enabled.
pub enabled: Option<bool>,
/// Presumably selects the Let's Encrypt staging environment -- TODO confirm.
pub staging: Option<bool>,
}
impl Default for LetsEncrypt {
fn default() -> Self {
Self {
email: Cow::Borrowed("contact@example.com"),
enabled: Some(true),
staging: Some(true),
}
}
}
/// Filesystem locations used by the service (the `paths` section of the configuration).
///
/// Note that this type shadows `std::path::Path` inside this module.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Path {
/// Directory for Let's Encrypt data; presumably certificates and account
/// material -- TODO confirm.
pub lets_encrypt: PathBuf,
}
impl Default for Path {
fn default() -> Self {
Self {
lets_encrypt: PathBuf::from("/etc/proksi/letsencrypt"),
}
}
}
/// One entry of a `headers.add` list: a header name and the value to set.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RouteHeaderAdd {
/// Header name.
pub name: Cow<'static, str>,
/// Header value.
pub value: Cow<'static, str>,
}
/// One entry of a `headers.remove` list, identified by header name only.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RouteHeaderRemove {
/// Header name.
pub name: Cow<'static, str>,
}
/// Header manipulation rules; used by both `Route::headers` and `RouteUpstream::headers`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RouteHeader {
/// Headers to add, if any.
pub add: Option<Vec<RouteHeaderAdd>>,
/// Headers to remove, if any.
pub remove: Option<Vec<RouteHeaderRemove>>,
}
/// A single upstream (backend) entry of a route.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RouteUpstream {
/// Upstream address, kept as free-form text (the tests in this file use
/// values such as `10.0.1.3/25`).
pub ip: Cow<'static, str>,
/// Upstream port.
pub port: u16,
/// Optional network name (e.g. `public` in the tests); semantics are defined elsewhere.
pub network: Option<String>,
/// Optional weight; presumably used for load balancing -- TODO confirm.
pub weight: Option<i8>,
/// Optional SNI value; presumably used for TLS connections to the upstream -- TODO confirm.
pub sni: Option<String>,
/// Optional header rules specific to this upstream.
pub headers: Option<RouteHeader>,
}
impl Default for RouteUpstream {
fn default() -> Self {
RouteUpstream {
ip: Cow::Borrowed("127.0.0.1"),
port: 80,
network: None,
weight: None,
sni: None,
headers: None,
}
}
}
/// Per-route certificate options (the `ssl_certificate` key of a route).
#[derive(Debug, Serialize, Deserialize)]
pub struct RouteSslCertificate {
/// Presumably requests a self-signed certificate when issuance fails -- TODO confirm.
pub self_signed_on_failure: Option<bool>,
}
/// Path-based matching rules for a route.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RoutePathMatcher {
/// Path patterns, e.g. `/api/v1/:entity/:action*` in the tests; the pattern
/// syntax is interpreted elsewhere.
pub patterns: Vec<Cow<'static, str>>,
}
/// Request matching rules for a route (the `match_with` key).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RouteMatcher {
/// Optional path matcher.
pub path: Option<RoutePathMatcher>,
}
/// A plugin attached to a route, with optional free-form configuration.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RoutePlugin {
/// Plugin name (e.g. `cors` in the tests).
pub name: Cow<'static, str>,
/// Arbitrary key/value configuration kept as raw JSON values.
pub config: Option<HashMap<Cow<'static, str>, serde_json::Value>>,
}
/// Locations of the TLS files configured for a route (`ssl.path`).
#[derive(Debug, Serialize, Deserialize)]
pub struct RouteSslPath {
/// Path to the key file.
pub key: PathBuf,
/// Path to the PEM file.
pub pem: PathBuf,
}
/// Protocol versions selectable for a route's `min_proto` / `max_proto`.
///
/// Config values are written as `v1.1`, `v1.2` or `v1.3` and parsed by
/// `proto_version_deser`; the derived `Serialize` emits the variant names instead.
#[derive(Debug, Serialize, Deserialize)]
pub enum ProtoVersion {
V1_1,
V1_2,
V1_3,
}
impl From<pingora::tls::ssl::SslVersion> for ProtoVersion {
    /// Maps an `SslVersion` onto the configuration-level enum.
    ///
    /// Only `TLS1_1` and `TLS1_2` are matched explicitly; every other value
    /// (not just TLS 1.3) falls through to `V1_3`.
    fn from(v: pingora::tls::ssl::SslVersion) -> Self {
        use pingora::tls::ssl::SslVersion;

        match v {
            SslVersion::TLS1_1 => Self::V1_1,
            SslVersion::TLS1_2 => Self::V1_2,
            _ => Self::V1_3,
        }
    }
}
/// TLS settings of a route (the `ssl` key).
#[derive(Debug, Serialize, Deserialize)]
pub struct RouteSsl {
/// Optional key/PEM file locations.
pub path: Option<RouteSslPath>,
/// Lowest accepted protocol version; defaults to `V1_2` when omitted.
#[serde(
default = "default_proto_version_min",
deserialize_with = "proto_version_deser"
)]
pub min_proto: ProtoVersion,
/// Highest accepted protocol version; defaults to `V1_3` when omitted.
#[serde(
default = "default_proto_version",
deserialize_with = "proto_version_deser"
)]
pub max_proto: ProtoVersion,
/// Defaults to `true` when omitted; presumably falls back to a self-signed
/// certificate -- TODO confirm against the TLS setup code.
#[serde(default = "bool_true")]
pub self_signed_fallback: bool,
}
/// Cache backend of a route; parsed case-insensitively from `disk` / `memcache`
/// by `deserialize_cache_type`.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
pub enum RouteCacheType {
Disk,
MemCache,
}
/// Cache settings of a route (the `cache` key).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RouteCache {
/// Whether caching is enabled for the route.
pub enabled: Option<bool>,
/// Cache backend; defaults to `MemCache` when omitted.
#[serde(
default = "default_cache_type",
deserialize_with = "deserialize_cache_type"
)]
pub cache_type: RouteCacheType,
/// Expiry in seconds; defaults to 3600.
#[serde(default = "default_cache_expire_secs")]
pub expires_in_secs: u64,
/// Stale-if-error window in seconds; defaults to 60.
#[serde(default = "default_stale_secs")]
pub stale_if_error_secs: u32,
/// Stale-while-revalidate window in seconds; defaults to 60.
#[serde(default = "default_stale_secs")]
pub stale_while_revalidate_secs: u32,
/// Required (no serde default); presumably the storage location for the
/// `Disk` backend -- TODO confirm.
pub path: PathBuf,
}
/// One entry of the `routes` list. Only `host` and `upstreams` are mandatory.
#[derive(Debug, Serialize, Deserialize)]
pub struct Route {
/// Host name of the route (e.g. `example.com`).
pub host: Cow<'static, str>,
/// Optional cache settings.
pub cache: Option<RouteCache>,
/// Optional list of plugins.
pub plugins: Option<Vec<RoutePlugin>>,
/// Optional certificate options.
pub ssl_certificate: Option<RouteSslCertificate>,
/// Optional TLS settings.
pub ssl: Option<RouteSsl>,
/// Optional header rules.
pub headers: Option<RouteHeader>,
/// Upstreams of the route.
pub upstreams: Vec<RouteUpstream>,
/// Optional request matching rules.
pub match_with: Option<RouteMatcher>,
}
// Log verbosity. Selectable via `--log.level` (clap `ValueEnum`) and parsed
// case-insensitively from config sources by `log_level_deser`.
// Plain `//` comments are used on purpose: `///` on clap-derived items would be
// picked up as `--help` text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, ValueEnum)]
pub enum LogLevel {
Debug,
Info,
Warn,
Error,
Trace,
}
impl From<&LogLevel> for tracing::level_filters::LevelFilter {
fn from(val: &LogLevel) -> Self {
match val {
LogLevel::Debug => LevelFilter::DEBUG,
LogLevel::Info => LevelFilter::INFO,
LogLevel::Warn => LevelFilter::WARN,
LogLevel::Error => LevelFilter::ERROR,
LogLevel::Trace => LevelFilter::TRACE,
}
}
}
// Log output format. Selectable via `--log.format` (clap `ValueEnum`) and parsed
// case-insensitively from config sources by `log_format_deser`.
// Plain `//` comments are used on purpose: `///` on clap-derived items would be
// picked up as `--help` text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, ValueEnum)]
pub enum LogFormat {
Json,
Pretty,
}
// Logging settings. Exposed as `--log.*` CLI flags (clap `Args`) and as the
// `logging` section of the configuration (serde).
// Plain `//` comments are used on purpose: `///` on clap-derived fields would
// become `--help` text and change the CLI output.
#[derive(Debug, Serialize, Deserialize, Clone, Args)]
#[group(id = "logging", requires = "level")]
pub struct Logging {
// Master switch (CLI default: true). The explicit `id` keeps the argument id
// distinct from `docker.enabled` on `Docker`.
#[arg(
long = "log.enabled",
required = false,
value_parser,
default_value = "true",
id = "log.enabled"
)]
pub enabled: bool,
// Verbosity (CLI default: info); config sources go through `log_level_deser`.
#[serde(deserialize_with = "log_level_deser")]
#[arg(
long = "log.level",
required = false,
value_enum,
default_value = "info"
)]
pub level: LogLevel,
// Access log switch (CLI default: true).
#[arg(
long = "log.access_logs_enabled",
required = false,
value_parser,
default_value = "true"
)]
pub access_logs_enabled: bool,
// Error log switch (CLI default: true).
#[arg(
long = "log.error_logs_enabled",
required = false,
value_parser,
default_value = "true"
)]
pub error_logs_enabled: bool,
// Output format (CLI default: json); config sources go through `log_format_deser`.
#[serde(deserialize_with = "log_format_deser")]
#[arg(
long = "log.format",
required = false,
value_enum,
default_value = "json"
)]
pub format: LogFormat,
}
// Top-level configuration. The same struct is the clap CLI definition (`Parser`)
// and the serde model extracted from the merged figment providers in `load`.
// Plain `//` comments are used on purpose: `///` on clap-derived items would
// become `--help` text and change the CLI output.
#[derive(Debug, Serialize, Deserialize, Parser)]
#[command(name = "Proksi")]
#[command(version, about, long_about = None)]
pub(crate) struct Config {
// Service name (CLI default: "proksi"); `serde(default)` lets config sources omit it.
#[serde(default)]
#[clap(short, long, default_value = "proksi")]
pub service_name: Cow<'static, str>,
// Daemon switch (CLI default: false); presumably runs the process in the
// background -- TODO confirm.
#[clap(short, long, default_value = "false")]
pub daemon: bool,
// Number of worker threads (CLI default: 1).
#[clap(short, long, default_value = "1")]
pub worker_threads: Option<usize>,
// Directory searched for `proksi.yml` / `proksi.yaml` / `proksi.hcl` (CLI default "./").
// CLI-only: `serde(skip)` keeps it out of every config source and of serialization.
#[serde(skip)]
#[clap(short, long, default_value = "./")]
#[allow(clippy::struct_field_names)]
pub config_path: Cow<'static, str>,
// Logging flags, flattened into the top-level CLI.
#[command(flatten)]
pub logging: Logging,
// Docker flags, flattened into the top-level CLI.
#[command(flatten)]
pub docker: Docker,
// The remaining fields have no CLI flags (`clap(skip)` fills them with
// `Default`); they are populated from config files or environment variables.
#[clap(skip)]
pub lets_encrypt: LetsEncrypt,
#[clap(skip)]
pub paths: Path,
#[clap(skip)]
pub routes: Vec<Route>,
}
impl Default for Config {
    /// Baseline configuration, merged as the lowest-priority layer by `load`.
    ///
    /// The logging values mirror the CLI defaults declared on `Logging`. In
    /// particular `error_logs_enabled` is `true`, matching the
    /// `--log.error_logs_enabled` default. It used to be `false` here, which
    /// disagreed with the value `load` actually produces when nothing is
    /// configured (the CLI layer always overrides this one).
    fn default() -> Self {
        Config {
            config_path: Cow::Borrowed("/etc/proksi/config"),
            service_name: Cow::Borrowed("proksi"),
            worker_threads: Some(1),
            daemon: false,
            docker: Docker::default(),
            lets_encrypt: LetsEncrypt::default(),
            routes: vec![],
            logging: Logging {
                enabled: true,
                level: LogLevel::Info,
                access_logs_enabled: true,
                // Kept in sync with the clap default (`default_value = "true"`).
                error_logs_enabled: true,
                format: LogFormat::Json,
            },
            paths: Path::default(),
        }
    }
}
impl Provider for Config {
    /// Names this provider `proksi` in figment's metadata.
    fn metadata(&self) -> figment::Metadata {
        figment::Metadata::named("proksi")
    }

    /// Serializes *this* configuration into figment's default profile.
    ///
    /// The previous implementation ignored `self` and always emitted
    /// `Config::default()`, so any non-default `Config` used as a provider
    /// silently lost its values. `load` merges `Config::default()`, for which
    /// the output is unchanged.
    ///
    /// # Errors
    /// Returns the `figment::Error` produced when the value cannot be serialized.
    fn data(
        &self,
    ) -> Result<figment::value::Map<figment::Profile, figment::value::Dict>, figment::Error> {
        Serialized::defaults(self).data()
    }
}
pub fn load(fallback: &str) -> Result<Config, figment::Error> {
let parsed_commands = Config::parse();
let path_with_fallback = if parsed_commands.config_path.is_empty() {
fallback
} else {
&parsed_commands.config_path
};
let config: Config = Figment::new()
.merge(Config::default())
.merge(Serialized::defaults(&parsed_commands))
.merge(Yaml::file(format!("{path_with_fallback}/proksi.yml")))
.merge(Yaml::file(format!("{path_with_fallback}/proksi.yaml")))
.merge(Hcl::file(format!("{path_with_fallback}/proksi.hcl")))
.merge(Env::prefixed("PROKSI_").split("__"))
.extract()?;
validate::check_config(&config).map_err(|err| figment::Error::from(err.to_string()))?;
Ok(config)
}
fn log_level_deser<'de, D>(deserializer: D) -> Result<LogLevel, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
match s.to_lowercase().as_str() {
"debug" => Ok(LogLevel::Debug),
"info" => Ok(LogLevel::Info),
"warn" => Ok(LogLevel::Warn),
"error" => Ok(LogLevel::Error),
"trace" => Ok(LogLevel::Trace),
_ => Err(serde::de::Error::custom(
"expected one of DEBUG, INFO, WARN, ERROR, TRACE",
)),
}
}
fn docker_mode_deser<'de, D>(deserializer: D) -> Result<DockerServiceMode, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
match s.to_lowercase().as_str() {
"swarm" => Ok(DockerServiceMode::Swarm),
"container" => Ok(DockerServiceMode::Container),
_ => Err(serde::de::Error::custom(
"expected one of: Swarm, Container",
)),
}
}
fn log_format_deser<'de, D>(deserializer: D) -> Result<LogFormat, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
match s.to_lowercase().as_str() {
"json" => Ok(LogFormat::Json),
"pretty" => Ok(LogFormat::Pretty),
_ => Err(serde::de::Error::custom("expected one of: json, pretty")),
}
}
/// Deserializes a `ProtoVersion` from a string, ignoring case.
///
/// Accepts the documented spellings `v1.1`, `v1.2`, `v1.3` and also the
/// underscore form `v1_1`, `v1_2`, `v1_3`. The derived `Serialize` impl writes
/// the variant names (`V1_1`, ...), so without the second form a serialized
/// `RouteSsl` could not be read back through this function.
///
/// # Errors
/// Fails with a custom serde error when the text matches neither spelling.
fn proto_version_deser<'de, D>(deserializer: D) -> Result<ProtoVersion, D::Error>
where
    D: Deserializer<'de>,
{
    let raw = String::deserialize(deserializer)?;
    match raw.to_lowercase().as_str() {
        "v1.1" | "v1_1" => Ok(ProtoVersion::V1_1),
        "v1.2" | "v1_2" => Ok(ProtoVersion::V1_2),
        "v1.3" | "v1_3" => Ok(ProtoVersion::V1_3),
        _ => Err(serde::de::Error::custom(
            "expected one of: v1.1, v1.2, v1.3",
        )),
    }
}
fn deserialize_cache_type<'de, D>(deserializer: D) -> Result<RouteCacheType, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
match s.to_lowercase().as_str() {
"disk" => Ok(RouteCacheType::Disk),
"memcache" => Ok(RouteCacheType::MemCache),
_ => Err(serde::de::Error::custom("expected one of: disk, memcache")),
}
}
// Tests for `load`. Each test runs inside `figment::Jail`, which supplies a
// scratch directory (`jail.directory()`) and scoped environment variables.
// NOTE(review): `load` calls `Config::parse()`, which reads the process
// arguments of the test binary; extra harness arguments may be rejected by
// clap -- verify how these tests are invoked.
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
// Shared YAML fixture: one route with a plugin, header rules and one upstream.
fn helper_config_file() -> &'static str {
r#"
service_name: "proksi"
lets_encrypt:
email: "user@domain.net"
logging:
level: "INFO"
access_logs_enabled: true
error_logs_enabled: false
paths:
lets_encrypt: "/test/letsencrypt"
routes:
- host: "example.com"
plugins:
- name: "cors"
config:
allowed_origins: ["*"]
headers:
add:
- name: "X-Forwarded-For"
value: "<value>"
- name: "X-Api-Version"
value: "1.0"
remove:
- name: "Server"
upstreams:
- ip: "10.0.1.3/25"
port: 3000
network: "public"
"#
}
// A `proksi.yaml` placed in the jail directory is loaded.
#[test]
fn test_load_config_from_yaml() {
figment::Jail::expect_with(|jail| {
let tmp_dir = jail.directory().to_string_lossy();
jail.create_file(format!("{}/proksi.yaml", tmp_dir), helper_config_file())?;
let config = load(&tmp_dir);
let proxy_config = config.unwrap();
assert_eq!(proxy_config.service_name, "proksi");
Ok(())
});
}
// `PROKSI_*` environment variables override values from the YAML file, while
// keys that are not overridden (e.g. `paths.lets_encrypt`) keep the YAML value.
#[test]
fn test_load_config_from_yaml_and_env_vars() {
figment::Jail::expect_with(|jail| {
jail.create_file(
format!("{}/proksi.yaml", jail.directory().to_str().unwrap()),
helper_config_file(),
)?;
jail.set_env("PROKSI_SERVICE_NAME", "new_name");
jail.set_env("PROKSI_LOGGING__ENABLED", "false");
jail.set_env("PROKSI_LOGGING__LEVEL", "warn");
jail.set_env("PROKSI_DOCKER__ENABLED", "true");
jail.set_env("PROKSI_DOCKER__INTERVAL_SECS", "30");
jail.set_env("PROKSI_DOCKER__ENDPOINT", "http://localhost:2375");
jail.set_env("PROKSI_LETS_ENCRYPT__STAGING", "false");
jail.set_env("PROKSI_LETS_ENCRYPT__EMAIL", "my-real-email@domain.com");
// The whole `routes` list is replaced through a single structured env value.
jail.set_env(
"PROKSI_ROUTES",
r#"[{
host="changed.example.com",
match_with={ path={ patterns=["/api/v1/:entity/:action*"] } },
plugins=[{ name="cors", config={ allowed_origins=["*"] } }],
upstreams=[{ ip="10.0.1.2/24", port=3000, weight=1 }] }]
"#,
);
let config = load(jail.directory().to_str().unwrap());
let proxy_config = config.unwrap();
assert_eq!(proxy_config.service_name, "new_name");
assert!(!proxy_config.logging.enabled);
assert_eq!(proxy_config.logging.level, LogLevel::Warn);
assert_eq!(proxy_config.docker.enabled, Some(true));
assert_eq!(proxy_config.docker.interval_secs, Some(30));
assert_eq!(
proxy_config.docker.endpoint,
Some(Cow::Borrowed("http://localhost:2375"))
);
assert_eq!(proxy_config.lets_encrypt.staging, Some(false));
assert_eq!(proxy_config.lets_encrypt.email, "my-real-email@domain.com");
assert_eq!(proxy_config.routes[0].host, "changed.example.com");
assert_eq!(proxy_config.routes[0].upstreams[0].ip, "10.0.1.2/24");
let matcher = proxy_config.routes[0].match_with.as_ref().unwrap();
assert_eq!(
matcher.path.as_ref().unwrap().patterns,
vec![Cow::Borrowed("/api/v1/:entity/:action*")]
);
assert_eq!(
proxy_config.paths.lets_encrypt,
PathBuf::from("/test/letsencrypt")
);
Ok(())
});
}
// With no config file at all, the built-in and CLI defaults apply.
// The `fallback` argument is only consulted when the CLI `config_path` is empty.
#[test]
fn test_load_config_with_defaults_only() {
figment::Jail::expect_with(|jail| {
jail.set_env("PROKSI_LETS_ENCRYPT__EMAIL", "my-real-email@domain.com");
let config = load("/non-existent");
let proxy_config = config.unwrap();
let logging = proxy_config.logging;
assert_eq!(proxy_config.service_name, "proksi");
assert_eq!(logging.level, LogLevel::Info);
assert!(logging.access_logs_enabled);
assert!(logging.error_logs_enabled);
assert_eq!(proxy_config.routes.len(), 0);
Ok(())
})
}
// A partial YAML file is layered on top of the defaults: unspecified sections
// keep their default values, and serde field defaults (e.g.
// `self_signed_fallback`) are filled in.
#[test]
fn test_load_config_with_defaults_and_yaml() {
figment::Jail::expect_with(|jail| {
let tmp_dir = jail.directory().to_string_lossy();
jail.create_file(
format!("{}/proksi.yaml", tmp_dir),
r#"
lets_encrypt:
email: "domain@valid.com"
routes:
- host: "example.com"
upstreams:
- ip: "10.1.2.24/24"
port: 3000
plugins:
- name: "cors"
config:
allowed_origins: ["*"]
ssl:
path:
key: "/etc/proksi/certs/my-host.key"
pem: "/etc/proksi/certs/my-host.pem"
"#,
)?;
let config = load(&tmp_dir);
let proxy_config = config.unwrap();
let logging = proxy_config.logging;
let paths = proxy_config.paths;
let letsencrypt = proxy_config.lets_encrypt;
assert_eq!(proxy_config.service_name, "proksi");
assert_eq!(logging.level, LogLevel::Info);
assert!(logging.access_logs_enabled);
assert!(logging.error_logs_enabled);
assert_eq!(proxy_config.routes.len(), 1);
assert_eq!(proxy_config.docker.enabled, Some(false));
assert_eq!(proxy_config.docker.interval_secs, Some(15));
assert_eq!(
proxy_config.docker.endpoint,
Some(Cow::Borrowed("unix:///var/run/docker.sock"))
);
assert_eq!(letsencrypt.email, "domain@valid.com");
assert_eq!(letsencrypt.enabled, Some(true));
assert_eq!(letsencrypt.staging, Some(true));
assert_eq!(paths.lets_encrypt.as_os_str(), "/etc/proksi/letsencrypt");
let route = &proxy_config.routes[0];
let plugins = route.plugins.as_ref().unwrap();
let plugin_config = plugins[0].config.as_ref().unwrap();
assert_eq!(plugins[0].name, "cors");
assert_eq!(plugin_config.get("allowed_origins"), Some(&json!(["*"])));
let ssl = route.ssl.as_ref().unwrap();
let path = ssl.path.as_ref().unwrap();
assert_eq!(ssl.self_signed_fallback, true);
assert_eq!(path.key.as_os_str(), "/etc/proksi/certs/my-host.key");
assert_eq!(path.pem.as_os_str(), "/etc/proksi/certs/my-host.pem");
Ok(())
});
}
// A `proksi.hcl` file in the config directory is loaded through the `Hcl` provider.
#[test]
fn test_load_config_from_hcl() {
figment::Jail::expect_with(|jail| {
let tmp_dir = jail.directory().to_string_lossy();
jail.create_file(
format!("{}/proksi.hcl", tmp_dir),
r#"
service_name = "hcl-service"
worker_threads = 8
docker {
enabled = true
interval_secs = 30
endpoint = "unix:///var/run/docker.sock"
}
lets_encrypt {
email = "domain@valid.com"
enabled = true
staging = false
}
paths {
lets_encrypt = "/etc/proksi/letsencrypt"
}
"#,
)?;
let config = load(&tmp_dir);
let proxy_config = config.unwrap();
assert_eq!(proxy_config.service_name, "hcl-service");
assert_eq!(proxy_config.worker_threads, Some(8));
assert_eq!(proxy_config.docker.enabled, Some(true));
assert_eq!(proxy_config.docker.interval_secs, Some(30));
assert_eq!(
proxy_config.docker.endpoint,
Some(Cow::Borrowed("unix:///var/run/docker.sock"))
);
assert_eq!(proxy_config.lets_encrypt.email, "domain@valid.com");
assert_eq!(proxy_config.lets_encrypt.enabled, Some(true));
assert_eq!(proxy_config.lets_encrypt.staging, Some(false));
Ok(())
});
}
}