use std::{collections::HashMap, fs::read_to_string, net::SocketAddr, str::FromStr};

use anyhow::{Context, Result};
use clap::Parser;
use serde::{Deserialize, Serialize};

use crate::service::{InterfaceAddr, Transport, session::ports::PortRange};
/// Key material for a TLS-enabled listener or endpoint (the `ssl` table).
///
/// Keys are spelled in kebab-case in the config file.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct Ssl {
    /// Config key `private-key`. NOTE(review): presumably a filesystem path to the private
    /// key — confirm with the code that loads it.
    pub private_key: String,
    /// Config key `certificate-chain`. NOTE(review): presumably a filesystem path to the
    /// certificate chain — confirm with the code that loads it.
    pub certificate_chain: String,
}
/// One entry of the `interfaces` list, selected by its `transport` key.
///
/// `rename_all = "kebab-case"` on an enum renames the *variants* only, so the tag values
/// are `"tcp"` and `"udp"`. It does not rename the fields inside the variants, which makes
/// the canonical key for the idle timeout `idle_timeout`. Each such field also accepts the
/// `idle-timeout` spelling through `alias`, matching the kebab-case keys used by every other
/// section. Serialization still emits `idle_timeout`.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "transport", rename_all = "kebab-case")]
pub enum Interface {
    /// A TCP interface (`transport = "tcp"`).
    Tcp {
        /// Listen address; copied into `InterfaceAddr::addr`.
        listen: SocketAddr,
        /// External address paired with `listen`; copied into `InterfaceAddr::external`.
        external: SocketAddr,
        /// Defaults to `Interface::idle_timeout()` (20) when omitted.
        /// NOTE(review): the unit is not visible here.
        #[serde(default = "Interface::idle_timeout", alias = "idle-timeout")]
        idle_timeout: u32,
        /// Optional key material; `None` when the `ssl` table is absent.
        #[serde(default)]
        ssl: Option<Ssl>,
    },
    /// A UDP interface (`transport = "udp"`).
    Udp {
        /// Listen address; copied into `InterfaceAddr::addr`.
        listen: SocketAddr,
        /// External address paired with `listen`; copied into `InterfaceAddr::external`.
        external: SocketAddr,
        /// Defaults to `Interface::idle_timeout()` (20) when omitted.
        /// NOTE(review): the unit is not visible here.
        #[serde(default = "Interface::idle_timeout", alias = "idle-timeout")]
        idle_timeout: u32,
        /// Defaults to `Interface::mtu()` (1500) when omitted.
        #[serde(default = "Interface::mtu")]
        mtu: usize,
    },
}
/// Serde fallback values for omitted `Interface` fields.
impl Interface {
    /// MTU used for a UDP interface whose `mtu` key is missing.
    const DEFAULT_MTU: usize = 1500;
    /// Idle timeout used for any interface whose idle-timeout key is missing.
    const DEFAULT_IDLE_TIMEOUT: u32 = 20;

    /// Returns [`Self::DEFAULT_MTU`]; referenced by `#[serde(default = "Interface::mtu")]`.
    fn mtu() -> usize {
        Self::DEFAULT_MTU
    }

    /// Returns [`Self::DEFAULT_IDLE_TIMEOUT`]; referenced by the serde `default` attribute.
    fn idle_timeout() -> u32 {
        Self::DEFAULT_IDLE_TIMEOUT
    }
}
/// The `[server]` section of the configuration file.
///
/// Every key is optional; omitted keys fall back to the `Server::*` default functions.
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct Server {
    /// Key `port-range`; defaults to `PortRange::default()` via `Server::port_range`.
    #[serde(default = "Server::port_range")]
    pub port_range: PortRange,
    /// Key `max-threads`; defaults to `num_cpus::get()` via `Server::max_threads`.
    #[serde(default = "Server::max_threads")]
    pub max_threads: usize,
    /// Key `realm`; defaults to `"localhost"` via `Server::realm`.
    #[serde(default = "Server::realm")]
    pub realm: String,
    /// Key `interfaces`; defaults to an empty list. Each entry is an [`Interface`] tagged by
    /// its `transport` key.
    #[serde(default)]
    pub interfaces: Vec<Interface>,
}
impl Server {
    /// Converts every configured [`Interface`] into an [`InterfaceAddr`].
    ///
    /// The output preserves the order of `interfaces`. `listen` becomes `addr`, `external`
    /// is copied as-is, and the variant selects the [`Transport`]. Transport-specific
    /// settings (idle timeout, ssl, mtu) are not carried over.
    pub fn get_interface_addrs(&self) -> Vec<InterfaceAddr> {
        let mut addrs = Vec::with_capacity(self.interfaces.len());

        for interface in &self.interfaces {
            let (addr, external, transport) = match interface {
                Interface::Tcp { listen, external, .. } => (*listen, *external, Transport::Tcp),
                Interface::Udp { listen, external, .. } => (*listen, *external, Transport::Udp),
            };

            addrs.push(InterfaceAddr {
                addr,
                external,
                transport,
            });
        }

        addrs
    }
}
/// Serde fallback values for omitted `[server]` keys.
impl Server {
    /// Realm used when the `realm` key is missing.
    const DEFAULT_REALM: &'static str = "localhost";

    /// Returns an owned copy of [`Self::DEFAULT_REALM`].
    fn realm() -> String {
        Self::DEFAULT_REALM.to_owned()
    }

    /// Returns `PortRange::default()`.
    fn port_range() -> PortRange {
        PortRange::default()
    }

    /// Returns the CPU count reported by `num_cpus::get()`.
    fn max_threads() -> usize {
        num_cpus::get()
    }
}
impl Default for Server {
fn default() -> Self {
Self {
realm: Self::realm(),
interfaces: Default::default(),
port_range: Self::port_range(),
max_threads: Self::max_threads(),
}
}
}
/// The optional `[hooks]` section.
///
/// `endpoint` is the only required key. Every other key falls back to a default.
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct Hooks {
    /// Key `max-channel-size`; defaults to 1024 via `Hooks::max_channel_size`.
    #[serde(default = "Hooks::max_channel_size")]
    pub max_channel_size: usize,
    /// Key `endpoint`; required (no serde default), so deserialization fails without it.
    pub endpoint: String,
    /// Optional key material; `None` when the `ssl` table is absent.
    #[serde(default)]
    pub ssl: Option<Ssl>,
    /// Key `timeout`; defaults to 5 via `Hooks::timeout`. NOTE(review): the unit is not
    /// visible here.
    #[serde(default = "Hooks::timeout")]
    pub timeout: u32,
}
/// Serde fallback values for omitted `[hooks]` keys.
impl Hooks {
    /// Channel size used when `max-channel-size` is missing.
    const DEFAULT_MAX_CHANNEL_SIZE: usize = 1024;
    /// Timeout used when `timeout` is missing.
    const DEFAULT_TIMEOUT: u32 = 5;

    /// Returns [`Self::DEFAULT_MAX_CHANNEL_SIZE`].
    fn max_channel_size() -> usize {
        Self::DEFAULT_MAX_CHANNEL_SIZE
    }

    /// Returns [`Self::DEFAULT_TIMEOUT`].
    fn timeout() -> u32 {
        Self::DEFAULT_TIMEOUT
    }
}
/// The optional `[api]` section.
///
/// Every key is optional.
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct Api {
    /// Key `listen`; defaults to `127.0.0.1:3000` via `Api::bind`.
    #[serde(default = "Api::bind")]
    pub listen: SocketAddr,
    /// Optional key material; `None` when the `ssl` table is absent.
    #[serde(default)]
    pub ssl: Option<Ssl>,
    /// Key `timeout`; defaults to 5 via `Api::timeout`. NOTE(review): the unit is not
    /// visible here.
    #[serde(default = "Api::timeout")]
    pub timeout: u32,
}
/// Serde fallback values for omitted `[api]` keys.
impl Api {
    /// Loopback address on port 3000, built directly rather than parsed from a string.
    fn bind() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 3000))
    }

    /// Fallback `timeout` value.
    fn timeout() -> u32 {
        const DEFAULT_TIMEOUT: u32 = 5;
        DEFAULT_TIMEOUT
    }
}
impl Default for Api {
fn default() -> Self {
Self {
timeout: Self::timeout(),
listen: Self::bind(),
ssl: None,
}
}
}
/// The optional `[prometheus]` section.
///
/// Every key is optional.
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct Prometheus {
    /// Key `listen`; defaults to `127.0.0.1:9184` via `Prometheus::bind`.
    #[serde(default = "Prometheus::bind")]
    pub listen: SocketAddr,
    /// Optional key material; `None` when the `ssl` table is absent.
    #[serde(default)]
    pub ssl: Option<Ssl>,
}
/// Serde fallback values for omitted `[prometheus]` keys.
impl Prometheus {
    /// Loopback address on port 9184, built directly rather than parsed from a string.
    fn bind() -> SocketAddr {
        let loopback = [127, 0, 0, 1];
        let port = 9184;
        SocketAddr::from((loopback, port))
    }
}
impl Default for Prometheus {
fn default() -> Self {
Self {
listen: Self::bind(),
ssl: None,
}
}
}
/// Log verbosity.
///
/// It is deserialized from the lowercase names `error`, `warn`, `info`, `debug` and `trace`.
#[derive(Deserialize, Debug, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
    Error,
    Warn,
    Info,
    Debug,
    Trace,
}
impl FromStr for LogLevel {
type Err = String;
fn from_str(value: &str) -> Result<Self, Self::Err> {
Ok(match value {
"trace" => Self::Trace,
"debug" => Self::Debug,
"info" => Self::Info,
"warn" => Self::Warn,
"error" => Self::Error,
_ => return Err(format!("unknown log level: {value}")),
})
}
}
impl Default for LogLevel {
fn default() -> Self {
Self::Info
}
}
impl From<LogLevel> for log::LevelFilter {
fn from(val: LogLevel) -> Self {
match val {
LogLevel::Error => log::LevelFilter::Error,
LogLevel::Debug => log::LevelFilter::Debug,
LogLevel::Trace => log::LevelFilter::Trace,
LogLevel::Warn => log::LevelFilter::Warn,
LogLevel::Info => log::LevelFilter::Info,
}
}
}
#[derive(Deserialize, Debug, Default, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct Log {
#[serde(default)]
pub level: LogLevel,
#[serde(default = "Log::stdout")]
pub stdout: bool,
#[serde(default)]
pub file_directory: Option<String>,
}
impl Log {
fn stdout() -> bool {
true
}
}
/// The `[auth]` section.
///
/// Every key is optional, and an absent table yields `Auth::default()`.
#[derive(Deserialize, Debug, Default, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct Auth {
    /// Key `static-credentials`; an empty map when omitted. NOTE(review): presumably
    /// username -> password — confirm with the authentication code.
    #[serde(default)]
    pub static_credentials: HashMap<String, String>,
    /// Key `static-auth-secret`. It has no explicit `#[serde(default)]`, but serde's derive
    /// treats a missing `Option` field as `None`.
    pub static_auth_secret: Option<String>,
    /// Key `enable-hooks-auth`; `false` when omitted.
    #[serde(default)]
    pub enable_hooks_auth: bool,
}
/// Root of the TOML configuration file.
///
/// Every top-level table is optional.
#[derive(Deserialize, Debug, Default, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct Config {
    /// `[server]`; falls back to `Server::default()`.
    #[serde(default)]
    pub server: Server,
    /// `[api]`; `None` when absent. NOTE(review): presumably `None` disables the API
    /// listener — confirm with the caller.
    #[serde(default)]
    pub api: Option<Api>,
    /// `[prometheus]`; `None` when absent. NOTE(review): presumably `None` disables metrics
    /// export — confirm with the caller.
    #[serde(default)]
    pub prometheus: Option<Prometheus>,
    /// `[hooks]`; `None` when absent.
    #[serde(default)]
    pub hooks: Option<Hooks>,
    /// `[log]`; falls back to `Log::default()`.
    #[serde(default)]
    pub log: Log,
    /// `[auth]`; falls back to `Auth::default()`.
    #[serde(default)]
    pub auth: Auth,
}
// Command-line arguments for the binary.
// Plain `//` comments are deliberate here: clap's derive turns `///` doc comments on the
// struct into `about`/`long_about`, and `about` is already set explicitly below.
#[derive(Parser, Debug)]
#[command(
    about = env!("CARGO_PKG_DESCRIPTION"),
    version = env!("CARGO_PKG_VERSION"),
    author = env!("CARGO_PKG_AUTHORS"),
)]
struct Cli {
    /// Path to the TOML configuration file.
    #[arg(long, short)]
    config: String,
}
impl Config {
pub fn load() -> Result<Self> {
Ok(toml::from_str::<Self>(&read_to_string(
&Cli::parse().config,
)?)?)
}
}