use std::path::PathBuf;
use anyhow::{Context, Result};
use bincode_next::{Decode, Encode};
#[cfg(test)]
use bon::Builder;
use config::{Config, Environment, File, FileFormat, Source};
use dirs2::config_dir;
use getset::{CopyGetters, Getters, Setters};
use serde::{Deserialize, Serialize};
use tracing::Level;
use tracing_subscriber_init::{TracingConfig, get_effective_level};
#[cfg(test)]
use crate::utils::Mock;
use crate::{TlsConfig, TracingConfigExt, error::Error, utils::to_path_buf};
/// Application-specific names and paths used to locate the configuration
/// file and to build the environment-variable source in [`load`].
pub trait PathDefaults {
    /// Prefix for environment variables read by [`load`]. Variables are
    /// matched as `<PREFIX>_<KEY>`, with `__` separating nested key segments.
    fn env_prefix(&self) -> String;
    /// Explicit config-file path. When `Some`, it is converted to a `PathBuf`
    /// and used instead of the default location.
    fn config_absolute_path(&self) -> Option<String>;
    /// Path segment(s) appended to the platform configuration directory to
    /// form the directory of the default config file.
    fn default_file_path(&self) -> String;
    /// Base name of the default config file; the extension is set to `toml`.
    fn default_file_name(&self) -> String;
    /// Explicit tracing output path, if any.
    /// NOTE(review): not consumed in this module — confirm semantics at the use site.
    fn tracing_absolute_path(&self) -> Option<String>;
    /// Default tracing output directory.
    /// NOTE(review): not consumed in this module — confirm semantics at the use site.
    fn default_tracing_path(&self) -> String;
    /// Default tracing output file name.
    /// NOTE(review): not consumed in this module — confirm semantics at the use site.
    fn default_tracing_file_name(&self) -> String;
}
/// Tracing configuration, split into a stdout section and a file section.
#[derive(Clone, Debug, Default, Deserialize, Eq, Getters, PartialEq, Serialize)]
pub struct Tracing {
    /// Formatting options for the stdout layer.
    #[getset(get = "pub")]
    stdout: Layer,
    /// Verbosity counters and formatting options for the file layer.
    #[getset(get = "pub")]
    file: FileLayer,
}
/// Settings for the file tracing layer: quiet/verbose counters that determine
/// its effective level, plus the shared [`Layer`] formatting options.
// NOTE(review): no field carries a `#[getset(...)]` attribute, so the
// `CopyGetters` and `Setters` derives appear to generate no methods — confirm
// whether they are still needed.
#[derive(Clone, CopyGetters, Debug, Default, Deserialize, Eq, PartialEq, Serialize, Setters)]
pub struct FileLayer {
    /// Quiet counter; passed to `get_effective_level` when computing the level.
    quiet: u8,
    /// Verbose counter; passed to `get_effective_level` when computing the level.
    verbose: u8,
    /// Formatting options (ANSI output is always disabled for this layer).
    layer: Layer,
}
impl TracingConfig for FileLayer {
    /// Quiet counter configured for the file layer.
    fn quiet(&self) -> u8 {
        self.quiet
    }

    /// Verbose counter configured for the file layer.
    fn verbose(&self) -> u8 {
        self.verbose
    }

    /// File output never carries ANSI escape codes.
    fn with_ansi(&self) -> bool {
        false
    }

    // The remaining toggles are delegated to the wrapped `Layer` through its
    // generated copy-getters.
    fn with_target(&self) -> bool {
        self.layer.with_target()
    }

    fn with_thread_ids(&self) -> bool {
        self.layer.with_thread_ids()
    }

    fn with_thread_names(&self) -> bool {
        self.layer.with_thread_names()
    }

    fn with_line_number(&self) -> bool {
        self.layer.with_line_number()
    }

    fn with_level(&self) -> bool {
        self.layer.with_level()
    }
}
impl TracingConfigExt for FileLayer {
fn enable_stdout(&self) -> bool {
false
}
fn directives(&self) -> Option<&String> {
self.layer.directives.as_ref()
}
fn level(&self) -> Level {
get_effective_level(self.quiet(), self.verbose())
}
}
/// Formatting toggles and optional filter directives for a tracing layer.
// Each bool is an independent on/off formatter option, so clippy's
// `struct_excessive_bools` suggestion (collapse into a state enum) does not fit.
#[allow(clippy::struct_excessive_bools)]
#[derive(Clone, CopyGetters, Debug, Default, Deserialize, Eq, Getters, PartialEq, Serialize)]
pub struct Layer {
    /// Whether the event target is included (reported as `with_target`).
    #[getset(get_copy = "pub")]
    with_target: bool,
    /// Whether thread IDs are included (reported as `with_thread_ids`).
    #[getset(get_copy = "pub")]
    with_thread_ids: bool,
    /// Whether thread names are included (reported as `with_thread_names`).
    #[getset(get_copy = "pub")]
    with_thread_names: bool,
    /// Whether line numbers are included (reported as `with_line_number`).
    #[getset(get_copy = "pub")]
    with_line_number: bool,
    /// Whether the level is included (reported as `with_level`).
    #[getset(get_copy = "pub")]
    with_level: bool,
    /// Optional filter directives string.
    /// NOTE(review): presumably `EnvFilter`-style directive syntax — confirm at the use site.
    #[getset(get = "pub")]
    directives: Option<String>,
}
/// A single command, stored as one string.
/// NOTE(review): how `cmd` is executed (shell vs. direct exec) is not visible here.
#[derive(Clone, Debug, Default, Deserialize, Eq, Getters, PartialEq, Serialize)]
#[getset(get = "pub")]
pub struct Command {
    /// The command text.
    cmd: String,
}
/// HTTP server settings (the name suggests an actix-web server — confirm at the use site).
#[derive(Clone, Debug, Default, Deserialize, Eq, Getters, PartialEq, Serialize)]
#[getset(get = "pub")]
pub struct Actix {
    /// Worker count.
    workers: u8,
    /// IP address (presumably the bind address — confirm).
    ip: String,
    /// Port (presumably the bind port — confirm).
    port: u16,
    /// Optional TLS settings.
    tls: Option<Tls>,
}
/// TLS settings: address, certificate/key file paths and an optional client CA.
#[derive(Clone, CopyGetters, Debug, Default, Deserialize, Eq, Getters, PartialEq, Serialize)]
pub struct Tls {
    /// IP address for the TLS endpoint (presumably the bind address — confirm).
    #[getset(get = "pub")]
    ip: String,
    /// Port for the TLS endpoint.
    #[getset(get_copy = "pub")]
    port: u16,
    /// Path to the certificate file; exposed via `TlsConfig::cert_file_path`.
    #[getset(get = "pub")]
    cert_file_path: String,
    /// Path to the private key file; exposed via `TlsConfig::key_file_path`.
    #[getset(get = "pub")]
    key_file_path: String,
    /// Optional client CA certificate path, `None` when absent from the config;
    /// exposed via `TlsConfig::client_ca_cert_path`.
    /// NOTE(review): presumably enables client-certificate verification — confirm.
    #[getset(get = "pub")]
    #[serde(default)]
    client_ca_cert: Option<PathBuf>,
}
impl TlsConfig for Tls {
fn cert_file_path(&self) -> &str {
&self.cert_file_path
}
fn key_file_path(&self) -> &str {
&self.key_file_path
}
fn client_ca_cert_path(&self) -> Option<&std::path::Path> {
self.client_ca_cert.as_deref()
}
}
/// Connection settings for a Bartos server.
#[derive(Clone, CopyGetters, Debug, Default, Deserialize, Eq, Getters, PartialEq, Serialize)]
pub struct Bartos {
    /// URL scheme prefix (the unit tests use `"wss"`).
    #[getset(get = "pub")]
    prefix: String,
    /// Server host name.
    #[getset(get = "pub")]
    host: String,
    /// Server port.
    #[getset(get_copy = "pub")]
    port: u16,
    /// Optional CA certificate path; `None` when absent from the config.
    /// NOTE(review): presumably used to verify the server certificate — confirm.
    #[getset(get = "pub")]
    #[serde(default)]
    ca_cert: Option<PathBuf>,
    /// Optional client certificate path; `None` when absent from the config.
    #[getset(get = "pub")]
    #[serde(default)]
    client_cert: Option<PathBuf>,
    /// Optional client private key path; `None` when absent from the config.
    #[getset(get = "pub")]
    #[serde(default)]
    client_key: Option<PathBuf>,
    /// Optional API key; `None` when absent from the config.
    #[getset(get = "pub")]
    #[serde(default)]
    api_key: Option<String>,
}
/// MariaDB connection settings.
#[derive(Clone, CopyGetters, Debug, Default, Deserialize, Eq, Getters, PartialEq, Serialize)]
pub struct Mariadb {
    /// Database host name or address.
    host: String,
    /// Port; the connection-string builders fall back to `3306` when unset.
    port: Option<u16>,
    /// Login user name.
    username: String,
    /// Login password (masked as `****` by `disp_connection_string`).
    password: String,
    /// Database name placed in the URL path.
    database: String,
    /// Raw query string appended verbatim after `?` in the connection URL.
    options: Option<String>,
    // Which output table to use (`output` or `output_test`); defaults to
    // `OutputTableName::Output` when absent from the config.
    #[doc(hidden)]
    #[getset(get_copy = "pub")]
    #[serde(default = "OutputTableName::default")]
    output_table: OutputTableName,
    // Which status table to use (`status` or `status_test`); defaults to
    // `StatusTableName::Status` when absent from the config.
    #[doc(hidden)]
    #[getset(get_copy = "pub")]
    #[serde(default = "StatusTableName::default")]
    status_table: StatusTableName,
}
impl Mariadb {
    /// Builds the `mariadb://user:password@host:port/database[?options]` URL.
    ///
    /// The port falls back to `3306` when unset, and `options` is appended
    /// verbatim after `?` when present. Credentials are inserted as-is
    /// (no percent-encoding).
    #[must_use]
    pub fn connection_string(&self) -> String {
        self.render_url(&self.password)
    }

    /// Same URL as [`Self::connection_string`], but with the password replaced
    /// by `****` so it can be logged or displayed safely.
    #[must_use]
    pub fn disp_connection_string(&self) -> String {
        self.render_url("****")
    }

    /// Shared URL renderer; `secret` fills the password slot.
    fn render_url(&self, secret: &str) -> String {
        const DEFAULT_PORT: u16 = 3306;

        let port = self.port.unwrap_or(DEFAULT_PORT);
        let base = format!(
            "mariadb://{}:{}@{}:{}/{}",
            self.username, secret, self.host, port, self.database
        );
        match &self.options {
            Some(query) => format!("{base}?{query}"),
            None => base,
        }
    }
}
/// Output table selector: `Output` (default) maps to `"output"`,
/// `OutputTest` maps to `"output_test"` via `From<OutputTableName> for &'static str`.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub enum OutputTableName {
    #[default]
    Output,
    OutputTest,
}
impl From<OutputTableName> for &'static str {
fn from(value: OutputTableName) -> Self {
match value {
OutputTableName::Output => "output",
OutputTableName::OutputTest => "output_test",
}
}
}
/// Status table selector: `Status` (default) maps to `"status"`,
/// `StatusTest` maps to `"status_test"` via `From<StatusTableName> for &'static str`.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub enum StatusTableName {
    #[default]
    Status,
    StatusTest,
}
impl From<StatusTableName> for &'static str {
fn from(value: StatusTableName) -> Self {
match value {
StatusTableName::Status => "status",
StatusTableName::StatusTest => "status_test",
}
}
}
/// Policy for handling missed ticks; `Burst` is the default.
/// NOTE(review): the variant names mirror `tokio::time::MissedTickBehavior`
/// (`Burst`, `Delay`, `Skip`) — confirm the mapping at the use site.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub enum MissedTick {
    #[default]
    Burst,
    Delay,
    Skip,
}
/// A list of [`Schedule`]s; serde- and bincode-(de)serializable.
// A `bon` builder is derived only in test builds (used by the `Mock` impl).
#[derive(Clone, Debug, Decode, Deserialize, Encode, Eq, Getters, PartialEq, Serialize)]
#[cfg_attr(test, derive(Builder))]
#[getset(get = "pub")]
pub struct Schedules {
    /// The individual schedules.
    schedules: Vec<Schedule>,
}
#[cfg(test)]
impl Mock for Schedules {
    /// Two identical [`Schedule::mock`] entries.
    fn mock() -> Self {
        let schedules: Vec<Schedule> = std::iter::repeat_with(Schedule::mock).take(2).collect();
        Self::builder().schedules(schedules).build()
    }
}
/// A named schedule: a calendar expression plus a list of commands.
// The derived bincode `Encode`/`Decode` depend on field declaration order;
// reordering fields changes the binary encoding.
#[derive(Clone, Debug, Decode, Default, Deserialize, Encode, Eq, Getters, PartialEq, Serialize)]
#[cfg_attr(test, derive(Builder))]
#[getset(get = "pub")]
pub struct Schedule {
    /// Schedule name.
    name: String,
    /// Calendar expression. The test mock uses the cron-style `* * * * *`.
    /// NOTE(review): confirm the accepted syntax with whatever parses this field.
    on_calendar: String,
    /// Commands associated with this schedule (presumably run when it fires — confirm).
    cmds: Vec<String>,
}
#[cfg(test)]
impl Mock for Schedule {
    /// A schedule named `mock_schedule` on `* * * * *` with a single echo command.
    fn mock() -> Self {
        let name = String::from("mock_schedule");
        let on_calendar = String::from("* * * * *");
        let cmds = vec![String::from("echo 'Hello, World!'")];

        Self::builder()
            .name(name)
            .on_calendar(on_calendar)
            .cmds(cmds)
            .build()
    }
}
pub fn load<'a, S, T, D>(cli: &S, defaults: &D) -> Result<T>
where
T: Deserialize<'a>,
S: Source + Clone + Send + Sync + 'static,
D: PathDefaults,
{
let config_file_path = config_file_path(defaults)?;
let config = Config::builder()
.add_source(
Environment::with_prefix(&defaults.env_prefix())
.prefix_separator("_")
.separator("__")
.try_parsing(true),
)
.add_source(cli.clone())
.add_source(File::from(config_file_path).format(FileFormat::Toml))
.build()
.with_context(|| Error::ConfigBuild)?;
config
.try_deserialize::<T>()
.with_context(|| Error::ConfigDeserialize)
}
/// Returns the configuration file path that [`load`] reads.
///
/// This is the explicit `config_absolute_path` when one is supplied, otherwise
/// `<platform config dir>/<default_file_path>/<default_file_name>.toml`.
///
/// # Errors
///
/// Fails when the explicit path cannot be converted, or when no explicit path
/// is given and the platform configuration directory cannot be determined.
pub fn resolve_config_path<D: PathDefaults>(defaults: &D) -> Result<PathBuf> {
    config_file_path(defaults)
}
fn config_file_path<D>(defaults: &D) -> Result<PathBuf>
where
D: PathDefaults,
{
let default_fn = || -> Result<PathBuf> { default_config_file_path(defaults) };
defaults
.config_absolute_path()
.as_ref()
.map_or_else(default_fn, to_path_buf)
}
fn default_config_file_path<D>(defaults: &D) -> Result<PathBuf>
where
D: PathDefaults,
{
let mut config_file_path = config_dir().ok_or(Error::ConfigDir)?;
config_file_path.push(defaults.default_file_path());
config_file_path.push(defaults.default_file_name());
let _ = config_file_path.set_extension("toml");
Ok(config_file_path)
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::TlsConfig;

    use super::{Bartos, PathDefaults, Tls, load};

    #[test]
    fn test_bartos_client_cert_key_default_none() {
        let bartos = Bartos::default();
        assert!(bartos.client_cert().is_none());
        assert!(bartos.client_key().is_none());
        assert!(bartos.ca_cert().is_none());
        assert!(bartos.api_key().is_none());
    }

    #[test]
    fn test_tls_client_ca_cert_default_none() {
        let tls = Tls::default();
        assert!(tls.client_ca_cert().is_none());
        assert!(tls.client_ca_cert_path().is_none());
    }

    #[test]
    fn test_tls_client_ca_cert_path_returns_some() {
        let tls = Tls {
            ip: "0.0.0.0".to_string(),
            port: 8443,
            cert_file_path: "cert.pem".to_string(),
            key_file_path: "key.pem".to_string(),
            client_ca_cert: Some(PathBuf::from("/etc/bartos/client-ca.pem")),
        };
        assert_eq!(
            tls.client_ca_cert_path(),
            Some(std::path::Path::new("/etc/bartos/client-ca.pem"))
        );
    }

    #[test]
    fn test_bartos_client_cert_key_returns_some() {
        let bartos = Bartos {
            prefix: "wss".to_string(),
            host: "localhost".to_string(),
            port: 8443,
            ca_cert: None,
            client_cert: Some(PathBuf::from("/etc/bartoc/client.pem")),
            client_key: Some(PathBuf::from("/etc/bartoc/client.key")),
            api_key: None,
        };
        assert_eq!(
            bartos.client_cert().as_deref(),
            Some(std::path::Path::new("/etc/bartoc/client.pem"))
        );
        assert_eq!(
            bartos.client_key().as_deref(),
            Some(std::path::Path::new("/etc/bartoc/client.key"))
        );
    }

    /// `load` must map a flat environment variable whose key contains a single
    /// underscore (`LBCFGTEST_MY_FIELD`) onto the top-level `my_field` key,
    /// because only `__` is configured as the nesting separator.
    #[test]
    #[cfg_attr(nightly, allow(unsafe_code))]
    fn test_load_flat_env_var_with_underscore() {
        use std::env;

        use config::{ConfigError, Map, Source, Value};
        use serde::Deserialize;
        use tempfile::NamedTempFile;

        const VAR: &str = "LBCFGTEST_MY_FIELD";

        #[derive(Debug, Deserialize, PartialEq)]
        struct TestCfg {
            my_field: Option<String>,
        }

        /// CLI stand-in that contributes no keys.
        #[derive(Debug, Clone)]
        struct NoOpSource;

        impl Source for NoOpSource {
            fn clone_into_box(&self) -> Box<dyn Source + Send + Sync> {
                Box::new(self.clone())
            }

            fn collect(&self) -> Result<Map<String, Value>, ConfigError> {
                Ok(Map::new())
            }
        }

        /// Points `load` at an explicit (temporary) config file.
        struct TestDefaults {
            path: String,
        }

        impl PathDefaults for TestDefaults {
            fn env_prefix(&self) -> String {
                "LBCFGTEST".to_string()
            }
            fn config_absolute_path(&self) -> Option<String> {
                Some(self.path.clone())
            }
            fn default_file_path(&self) -> String {
                String::new()
            }
            fn default_file_name(&self) -> String {
                String::new()
            }
            fn tracing_absolute_path(&self) -> Option<String> {
                None
            }
            fn default_tracing_path(&self) -> String {
                String::new()
            }
            fn default_tracing_file_name(&self) -> String {
                String::new()
            }
        }

        let toml = NamedTempFile::new().unwrap();
        std::fs::write(toml.path(), b"").unwrap();
        let defaults = TestDefaults {
            path: toml.path().to_str().unwrap().to_string(),
        };

        // SAFETY: `set_var`/`remove_var` are only unsound if another thread
        // reads or writes the environment concurrently. The variable name is
        // unique to this test and no other test in this module touches the
        // environment. NOTE(review): tests elsewhere in the crate must not read
        // or mutate the environment concurrently either.
        unsafe { env::set_var(VAR, "flatval") };
        let loaded: anyhow::Result<TestCfg> = load(&NoOpSource, &defaults);
        // Remove the variable *before* unwrapping so a failing `load` does not
        // leave it set for the rest of the test process.
        // SAFETY: same invariant as the `set_var` call above.
        unsafe { env::remove_var(VAR) };

        let cfg = loaded.unwrap();
        assert_eq!(cfg.my_field.as_deref(), Some("flatval"));
    }
}