use std::{
collections::BTreeMap,
fs,
path::{Path, PathBuf},
sync::OnceLock,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::info;
use crate::{controller::middleware, environment::Environment, logger, scheduler, Error, Result};
static DEFAULT_FOLDER: OnceLock<PathBuf> = OnceLock::new();
/// Returns the folder that [`Config::new`] reads configuration files from.
///
/// The value is the relative path `config`, created lazily on first use and
/// then shared for the rest of the process lifetime.
fn get_default_folder() -> &'static PathBuf {
    DEFAULT_FOLDER.get_or_init(|| Path::new("config").to_path_buf())
}
/// Top-level application configuration, deserialized from one YAML file per
/// environment (see [`Config::from_folder`]).
///
/// Fields without `#[serde(default)]` that are not `Option` must be present
/// in the YAML file; `Option` fields deserialize to `None` when omitted.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
/// Logger settings; required.
pub logger: Logger,
/// Server settings; required.
pub server: Server,
/// Database settings; the field only exists (and is then required) when the
/// crate is built with the `with-db` feature.
#[cfg(feature = "with-db")]
pub database: Database,
/// Cache backend; falls back to `CacheConfig::Null` (the enum's `#[default]`
/// variant) when the key is omitted.
#[serde(default)]
pub cache: CacheConfig,
/// Queue backend; `None` when omitted.
pub queue: Option<QueueConfig>,
/// Authentication settings; `None` when omitted.
pub auth: Option<Auth>,
/// Worker settings; falls back to `Workers::default()` when omitted.
#[serde(default)]
pub workers: Workers,
/// Mailer settings; `None` when omitted.
pub mailer: Option<Mailer>,
/// Per-initializer settings as raw JSON values; `None` when omitted.
pub initializers: Option<Initializers>,
/// Arbitrary extra settings kept as an untyped JSON value; `None` when omitted.
#[serde(default)]
pub settings: Option<serde_json::Value>,
/// Scheduler settings; `None` when omitted.
pub scheduler: Option<scheduler::Config>,
}
/// Settings for the `logger` section of [`Config`].
///
/// NOTE(review): deriving `Default` does not make missing YAML keys optional;
/// only `pretty_backtrace` (and the `Option` fields) may be omitted.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Logger {
/// Whether the logger is enabled; required.
pub enable: bool,
/// Defaults to `false` when omitted; presumably toggles prettified
/// backtraces — confirm in the logger module.
#[serde(default)]
pub pretty_backtrace: bool,
/// Log level; required.
pub level: logger::LogLevel,
/// Output format; required.
pub format: logger::Format,
/// Optional filter string; presumably overrides `level` when set — TODO
/// confirm against the logger initialization code.
pub override_filter: Option<String>,
/// Optional file output settings; `None` when omitted.
pub file_appender: Option<LoggerFileAppender>,
}
/// Settings for writing logs to files (`logger.file_appender`).
///
/// All non-`Option` fields except `non_blocking` are required when the
/// section is present.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct LoggerFileAppender {
/// Whether file logging is enabled.
pub enable: bool,
/// Defaults to `false` when omitted; presumably selects a non-blocking
/// writer — confirm in the logger module.
#[serde(default)]
pub non_blocking: bool,
/// Log level for the file output.
pub level: logger::LogLevel,
/// Output format for the file output.
pub format: logger::Format,
/// File rotation policy.
pub rotation: logger::Rotation,
/// Optional target directory; `None` when omitted.
pub dir: Option<String>,
/// Optional file name prefix; `None` when omitted.
pub filename_prefix: Option<String>,
/// Optional file name suffix; `None` when omitted.
pub filename_suffix: Option<String>,
/// Maximum number of log files to keep (per the name; retention is enforced
/// by the consumer, not here).
pub max_log_files: usize,
}
/// Database connection settings (the `database` section, `with-db` only).
#[derive(Debug, Clone, Deserialize, Serialize)]
// Each bool maps 1:1 to a YAML key; folding them into an enum would change
// the configuration file format, so the lint is silenced deliberately.
#[allow(clippy::struct_excessive_bools)]
pub struct Database {
/// Connection URI; required.
pub uri: String,
/// Whether database logging is enabled; required.
pub enable_logging: bool,
/// Minimum number of connections; required.
pub min_connections: u32,
/// Maximum number of connections; required.
pub max_connections: u32,
/// Connect timeout; required. The unit is defined by the consumer, not here.
pub connect_timeout: u64,
/// Idle timeout; required. The unit is defined by the consumer, not here.
pub idle_timeout: u64,
/// Optional acquire timeout; `None` when omitted.
pub acquire_timeout: Option<u64>,
/// Defaults to `false` when omitted.
#[serde(default)]
pub auto_migrate: bool,
/// Defaults to `false` when omitted. Destructive, as the name warns.
#[serde(default)]
pub dangerously_truncate: bool,
/// Defaults to `false` when omitted. Destructive, as the name warns.
#[serde(default)]
pub dangerously_recreate: bool,
/// Optional string; presumably SQL executed at startup — TODO confirm.
pub run_on_start: Option<String>,
}
/// Cache backend selection, internally tagged by the `kind` key
/// (e.g. `kind: Null`); variant names are used verbatim.
///
/// `InMem` and `Redis` only exist when their cargo features are enabled;
/// `Null` is always available and is the `Default`.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(tag = "kind")]
pub enum CacheConfig {
#[cfg(feature = "cache_inmem")]
InMem(InMemCacheConfig),
#[cfg(feature = "cache_redis")]
Redis(RedisCacheConfig),
#[default]
Null,
}
/// Settings for the in-memory cache backend (`kind: InMem`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InMemCacheConfig {
/// Capacity limit; defaults to `cache_in_mem_max_capacity()` (32 * 1024 * 1024)
/// when omitted. The unit is defined by the cache implementation, not here.
#[serde(default = "cache_in_mem_max_capacity")]
pub max_capacity: u64,
}
/// Serde default for `InMemCacheConfig::max_capacity`: 32 * 1024 * 1024.
fn cache_in_mem_max_capacity() -> u64 {
    const MEBI: u64 = 1024 * 1024;
    32 * MEBI
}
/// Settings for the Redis cache backend (`kind: Redis`); both fields required.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RedisCacheConfig {
/// Redis connection URI.
pub uri: String,
/// Size limit; presumably the connection pool size — TODO confirm.
pub max_size: u32,
}
/// Queue backend selection, internally tagged by the `kind` key
/// (`kind: Redis`, `kind: Postgres` or `kind: Sqlite`). There is no default
/// variant; the whole section is optional in [`Config`].
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "kind")]
pub enum QueueConfig {
Redis(RedisQueueConfig),
Postgres(PostgresQueueConfig),
Sqlite(SqliteQueueConfig),
}
/// Settings for the Redis queue backend (`kind: Redis`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RedisQueueConfig {
/// Redis connection URI; required.
pub uri: String,
/// Defaults to `false` when omitted. Destructive, as the name warns.
#[serde(default)]
pub dangerously_flush: bool,
/// Optional list of queue names; `None` when omitted.
pub queues: Option<Vec<String>>,
/// Number of workers; defaults to `num_workers()` (2) when omitted.
#[serde(default = "num_workers")]
pub num_workers: u32,
}
/// Settings for the Postgres queue backend (`kind: Postgres`).
///
/// Only `uri` is required; every other field has a serde default.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PostgresQueueConfig {
/// Connection URI; required.
pub uri: String,
/// Defaults to `false`. Destructive, as the name warns.
#[serde(default)]
pub dangerously_flush: bool,
/// Defaults to `false`.
#[serde(default)]
pub enable_logging: bool,
/// Defaults to `db_max_conn()` (20).
#[serde(default = "db_max_conn")]
pub max_connections: u32,
/// Defaults to `db_min_conn()` (1).
#[serde(default = "db_min_conn")]
pub min_connections: u32,
/// Defaults to `db_connect_timeout()` (500); unit defined by the consumer.
#[serde(default = "db_connect_timeout")]
pub connect_timeout: u64,
/// Defaults to `db_idle_timeout()` (500); unit defined by the consumer.
#[serde(default = "db_idle_timeout")]
pub idle_timeout: u64,
/// Poll interval in seconds (per the field name); defaults to
/// `pgq_poll_interval()` (1).
#[serde(default = "pgq_poll_interval")]
pub poll_interval_sec: u32,
/// Defaults to `num_workers()` (2).
#[serde(default = "num_workers")]
pub num_workers: u32,
}
/// Settings for the SQLite queue backend (`kind: Sqlite`).
///
/// Only `uri` is required; every other field has a serde default.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SqliteQueueConfig {
/// Connection URI; required.
pub uri: String,
/// Defaults to `false`. Destructive, as the name warns.
#[serde(default)]
pub dangerously_flush: bool,
/// Defaults to `false`.
#[serde(default)]
pub enable_logging: bool,
/// Defaults to `db_max_conn()` (20).
#[serde(default = "db_max_conn")]
pub max_connections: u32,
/// Defaults to `db_min_conn()` (1).
#[serde(default = "db_min_conn")]
pub min_connections: u32,
/// Defaults to `db_connect_timeout()` (500); unit defined by the consumer.
#[serde(default = "db_connect_timeout")]
pub connect_timeout: u64,
/// Defaults to `db_idle_timeout()` (500); unit defined by the consumer.
#[serde(default = "db_idle_timeout")]
pub idle_timeout: u64,
/// Poll interval in seconds (per the field name); defaults to
/// `sqlt_poll_interval()` (1).
#[serde(default = "sqlt_poll_interval")]
pub poll_interval_sec: u32,
/// Defaults to `num_workers()` (2).
#[serde(default = "num_workers")]
pub num_workers: u32,
}
/// Serde default for queue `min_connections`.
const fn db_min_conn() -> u32 {
    1
}
/// Serde default for queue `max_connections`.
const fn db_max_conn() -> u32 {
    20
}
/// Serde default for queue `connect_timeout` (unit defined by the consumer).
const fn db_connect_timeout() -> u64 {
    500
}
/// Serde default for queue `idle_timeout` (unit defined by the consumer).
const fn db_idle_timeout() -> u64 {
    500
}
/// Serde default for `PostgresQueueConfig::poll_interval_sec`.
const fn pgq_poll_interval() -> u32 {
    1
}
/// Serde default for `SqliteQueueConfig::poll_interval_sec`.
const fn sqlt_poll_interval() -> u32 {
    1
}
/// Serde default for `num_workers` in every queue backend config.
const fn num_workers() -> u32 {
    2
}
/// Authentication settings (the `auth` section of [`Config`]).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Auth {
/// Optional JWT settings; `None` when omitted.
pub jwt: Option<JWT>,
}
/// JWT settings (`auth.jwt`).
///
/// NOTE(review): `secret` is included in the derived `Debug` output and in
/// `Config`'s YAML `Display`; avoid logging the config as a whole.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct JWT {
/// Where to look for the token; `None` when omitted.
pub location: Option<JWTLocationConfig>,
/// Token secret; required.
pub secret: String,
/// Token expiration; required. The unit is defined by the consumer, not here.
pub expiration: u64,
}
/// A single token source, internally tagged by the `from` key:
/// `from: Bearer`, or `from: Query` / `from: Cookie` together with `name`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "from")]
pub enum JWTLocation {
/// Presumably the `Authorization: Bearer` header — confirm in the auth
/// extractor.
Bearer,
/// Query-string parameter called `name`.
Query { name: String },
/// Cookie called `name`.
Cookie { name: String },
}
/// Accepts either one token location or a list of them.
///
/// Untagged: serde tries `Single` first, then `Multiple`, so a mapping parses
/// as `Single` and a sequence as `Multiple`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum JWTLocationConfig {
Single(JWTLocation),
Multiple(Vec<JWTLocation>),
}
/// Server settings (the `server` section of [`Config`]).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Server {
/// Address to bind to; defaults to `"localhost"` when omitted.
#[serde(default = "default_binding")]
pub binding: String,
/// Port; required.
pub port: i32,
/// Host used by `full_url`; required.
pub host: String,
/// Optional identifier string; `None` when omitted.
pub ident: Option<String>,
/// Middleware settings; falls back to `middleware::Config::default()` when omitted.
#[serde(default)]
pub middlewares: middleware::Config,
}
/// Serde default for `Server::binding`.
fn default_binding() -> String {
    String::from("localhost")
}
impl Server {
    /// Joins `host` and `port` as `"{host}:{port}"`.
    ///
    /// `binding` is not used; `host` is emitted verbatim, so any scheme it
    /// carries is preserved.
    #[must_use]
    pub fn full_url(&self) -> String {
        let Self { host, port, .. } = self;
        format!("{host}:{port}")
    }
}
/// Worker settings (the `workers` section of [`Config`]).
///
/// NOTE(review): the whole section may be omitted (`Config::workers` has
/// `#[serde(default)]`), but if it is present `mode` is required because the
/// field itself carries no `#[serde(default)]`.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Workers {
/// How jobs are executed; `WorkerMode::BackgroundQueue` when defaulted.
pub mode: WorkerMode,
}
/// Job execution mode. There are no rename attributes, so YAML uses the variant
/// names verbatim (e.g. `mode: BackgroundAsync`).
#[derive(Clone, Default, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub enum WorkerMode {
/// Default mode; presumably enqueues jobs to the configured queue — confirm.
#[default]
BackgroundQueue,
/// Presumably runs jobs inline, blocking the caller — confirm.
ForegroundBlocking,
/// Presumably runs jobs as in-process async tasks — confirm.
BackgroundAsync,
}
/// Mailer settings (the `mailer` section of [`Config`]).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Mailer {
/// Optional SMTP transport settings; `None` when omitted.
pub smtp: Option<SmtpMailer>,
/// Defaults to `false` when omitted; presumably selects a stub transport
/// instead of real delivery — confirm.
#[serde(default)]
pub stub: bool,
}
/// Raw per-initializer settings, presumably keyed by initializer name; values
/// are left as untyped JSON. `BTreeMap` keeps keys in sorted order.
pub type Initializers = BTreeMap<String, serde_json::Value>;
/// SMTP transport settings (`mailer.smtp`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SmtpMailer {
/// Whether SMTP delivery is enabled; required.
pub enable: bool,
/// SMTP server host; required.
pub host: String,
/// SMTP server port; required.
pub port: u16,
/// Presumably toggles TLS for the connection — confirm; required.
pub secure: bool,
/// Optional credentials; `None` when omitted.
pub auth: Option<MailerAuth>,
/// Optional name to send in the SMTP greeting (per the field name).
pub hello_name: Option<String>,
}
/// SMTP credentials.
///
/// NOTE(review): the derived `Debug` prints `password` in plain text, and
/// `Config`'s YAML `Display` includes it too; consider a redacting `Debug`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MailerAuth {
pub user: String,
pub password: String,
}
impl Config {
pub fn new(env: &Environment) -> Result<Self> {
let config = Self::from_folder(env, get_default_folder().as_path())?;
Ok(config)
}
pub fn from_folder(env: &Environment, path: &Path) -> Result<Self> {
let files = [
path.join(format!("{env}.local.yaml")),
path.join(format!("{env}.yaml")),
];
let selected_path = files.iter().find(|p| p.exists()).ok_or_else(|| {
Error::Message(format!(
"no configuration file found in folder: {}",
path.display()
))
})?;
info!(selected_path =? selected_path, "loading environment from");
let content = fs::read_to_string(selected_path)?;
let rendered = crate::tera::render_string(&content, &json!({}))?;
serde_yaml::from_str(&rendered)
.map_err(|err| Error::YAMLFile(err, selected_path.to_string_lossy().to_string()))
}
pub fn get_jwt_config(&self) -> Result<&JWT> {
self.auth
.as_ref()
.and_then(|auth| auth.jwt.as_ref())
.map_or_else(
|| Err(Error::Any("no JWT config found".to_string().into())),
Ok,
)
}
}
impl std::fmt::Display for Config {
    /// Renders the configuration as YAML.
    ///
    /// A serialization failure is deliberately not surfaced: returning
    /// `fmt::Error` would make `to_string()` panic. In that case nothing is
    /// written and formatting still succeeds.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let yaml = match serde_yaml::to_string(self) {
            Ok(text) => text,
            Err(_) => String::new(),
        };
        f.write_str(&yaml)
    }
}