use std::{collections::BTreeMap, net::SocketAddr, time::Duration};
use serde::Deserialize;
use crate::core::{
ConfigFeatureWarning, CoreError, CoreResult, DatabaseSection, LogConfig, LogSection,
RpcClientSection, ServiceConfig, dependency_feature_warnings, load_config,
};
/// Deserializable configuration for a REST service.
///
/// `#[serde(default)]` fills every field missing from the input from this type's
/// [`Default`] impl, and `deny_unknown_fields` rejects keys not declared here.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct RestServiceConfig {
/// Service name; passed to `LogSection::to_log_config` and copied into the runtime REST config.
pub name: String,
/// Service mode string, defaulted from `ServiceConfig`; accepted values are not defined in this file.
pub mode: String,
/// Listen host/port and request limits.
pub server: RestServerSection,
/// Logging settings, converted by `log_config`.
pub log: LogSection,
/// Optional JWT auth settings; `None` means no auth section was configured.
pub auth: Option<RestAuthSection>,
/// Middleware on/off switches.
pub middlewares: RestMiddlewaresSection,
/// RPC client sections keyed by the name later passed to `rpc_client`.
pub rpc_clients: BTreeMap<String, RpcClientSection>,
/// Optional database section.
pub database: Option<DatabaseSection>,
}
impl Default for RestServiceConfig {
fn default() -> Self {
let service = ServiceConfig::default();
Self {
name: service.name,
mode: service.mode,
server: RestServerSection::default(),
log: service.log,
auth: None,
middlewares: RestMiddlewaresSection::default(),
rpc_clients: BTreeMap::new(),
database: None,
}
}
}
impl RestServiceConfig {
    /// Loads the configuration through [`load_config`] with the given file
    /// `basename` and environment-variable `env_prefix`.
    ///
    /// # Errors
    /// Returns whatever [`config::ConfigError`] `load_config` produces.
    pub fn load(basename: &str, env_prefix: &str) -> Result<Self, config::ConfigError> {
        load_config(basename, env_prefix)
    }
    /// Builds the REST listen address from `server.host` and `server.port`.
    ///
    /// `host` must be an IP literal; host names are not resolved. IPv6 literals
    /// may be written bare (`::1`) or bracketed (`[::1]`).
    ///
    /// # Errors
    /// Returns [`CoreError::Config`] when host and port do not form a valid
    /// socket address.
    pub fn addr(&self) -> CoreResult<SocketAddr> {
        let host = &self.server.host;
        let port = self.server.port;
        // Socket-address syntax requires IPv6 literals in brackets, so a bare
        // `::1` would otherwise yield the unparseable `::1:8080`. IPv4 and
        // already-bracketed hosts are used unchanged.
        let candidate = if host.contains(':') && !host.starts_with('[') {
            format!("[{host}]:{port}")
        } else {
            format!("{host}:{port}")
        };
        candidate.parse::<SocketAddr>().map_err(|error| {
            CoreError::Config(config::ConfigError::Message(format!(
                "invalid REST listen address: {error}"
            )))
        })
    }
    /// Converts the `log` section into a [`LogConfig`] tagged with this service's name.
    pub fn log_config(&self) -> LogConfig {
        self.log.to_log_config(&self.name)
    }
    /// Builds the runtime [`crate::rest::RestConfig`] from this configuration.
    ///
    /// The base is `RestConfig::production_defaults` when either middleware is
    /// enabled, otherwise `RestConfig::default()` with this service's name.
    /// Server timeout/body limit, the metrics switch and auth are then applied
    /// on top; when resilience is disabled its settings are reset to
    /// `RestResilienceConfig::default()`.
    pub fn rest_config(&self) -> crate::rest::RestConfig {
        let mut config = if self.middlewares.resilience || self.middlewares.metrics {
            crate::rest::RestConfig::production_defaults(self.name.clone())
        } else {
            crate::rest::RestConfig {
                name: self.name.clone(),
                ..crate::rest::RestConfig::default()
            }
        };
        config.timeout = Duration::from_millis(self.server.timeout_ms);
        config.max_body_bytes = self.server.max_body_bytes;
        config.middlewares.metrics.enabled = self.middlewares.metrics;
        if !self.middlewares.resilience {
            config.middlewares.resilience = crate::rest::RestResilienceConfig::default();
        }
        config.auth = self.auth.as_ref().and_then(RestAuthSection::auth_config);
        config
    }
    /// Lists options that are switched on but whose cargo feature was not
    /// compiled in, followed by the warnings from
    /// [`dependency_feature_warnings`] for the RPC client and database sections.
    pub fn validate_features(&self) -> Vec<ConfigFeatureWarning> {
        let mut warnings = Vec::new();
        if self.middlewares.metrics && !cfg!(feature = "observability") {
            warnings.push(ConfigFeatureWarning::ignored(
                "middlewares.metrics",
                "observability",
            ));
        }
        if self.middlewares.resilience && !cfg!(feature = "resil") {
            warnings.push(ConfigFeatureWarning::ignored(
                "middlewares.resilience",
                "resil",
            ));
        }
        warnings.extend(dependency_feature_warnings(
            &self.rpc_clients,
            self.database.as_ref(),
        ));
        warnings
    }
    /// Effective JWT expiry when an auth section exists, `None` otherwise.
    ///
    /// The section's `jwt_expires` may be overridden by the `JWT_AUTH_EXPIRES`
    /// environment variable.
    pub fn jwt_expires(&self) -> Option<u64> {
        self.auth.as_ref().map(RestAuthSection::jwt_expires)
    }
    /// Looks up the RPC client section registered under `name`.
    ///
    /// # Errors
    /// Returns [`CoreError::Config`] when no such section exists.
    pub fn rpc_client(&self, name: &str) -> CoreResult<&RpcClientSection> {
        self.rpc_clients.get(name).ok_or_else(|| {
            CoreError::Config(config::ConfigError::Message(format!(
                "missing rpc client config: {name}"
            )))
        })
    }
    /// Converts the RPC client section `name` into a runtime client config.
    ///
    /// # Errors
    /// Fails when the section is missing or its conversion fails.
    #[cfg(feature = "rpc")]
    pub fn rpc_client_config(&self, name: &str) -> CoreResult<crate::rpc::RpcClientConfig> {
        self.rpc_client(name)?.to_rpc_client_config()
    }
    /// Converts the optional database section into a runtime database config.
    #[cfg(feature = "db")]
    pub fn database_config(&self) -> Option<crate::db::DatabaseConfig> {
        self.database
            .as_ref()
            .map(DatabaseSection::to_database_config)
    }
}
/// Listen address and request limits for the REST server, read from the
/// `server` key of `RestServiceConfig`.
///
/// Missing fields fall back to [`Default`]; unknown keys are rejected.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct RestServerSection {
/// IP literal to listen on; combined with `port` by `RestServiceConfig::addr`.
pub host: String,
/// Listen port.
pub port: u16,
/// Timeout in milliseconds, copied into the runtime config via `Duration::from_millis`.
pub timeout_ms: u64,
/// Maximum body size in bytes, copied into the runtime config's `max_body_bytes`.
pub max_body_bytes: usize,
}
impl Default for RestServerSection {
    /// Listens on the IPv4 loopback address, port 8080, with a 5000 ms timeout
    /// and a 1 MiB body limit.
    fn default() -> Self {
        const ONE_MIB: usize = 1 << 20;
        Self {
            host: String::from("127.0.0.1"),
            port: 8080,
            timeout_ms: 5_000,
            max_body_bytes: ONE_MIB,
        }
    }
}
/// JWT authentication settings for the REST service.
///
/// Missing fields fall back to [`Default`]; unknown keys are rejected.
/// `Debug` is implemented by hand so the secret never ends up in logs.
#[derive(Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct RestAuthSection {
    /// JWT secret; may be overridden by the `JWT_AUTH_SECRET` environment variable.
    pub jwt_secret: String,
    /// Token lifetime; may be overridden by `JWT_AUTH_EXPIRES`.
    /// NOTE(review): presumably seconds — confirm against `crate::rest`.
    pub jwt_expires: u64,
    /// Paths copied into `AuthConfig::public_paths`
    /// (presumably exempt from authentication — confirm in `crate::rest`).
    pub public_paths: Vec<String>,
}
impl std::fmt::Debug for RestAuthSection {
    /// Formats like the derived impl, except `jwt_secret` is replaced by a
    /// placeholder that only reveals whether a secret is configured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let secret = if self.jwt_secret.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("RestAuthSection")
            .field("jwt_secret", &secret)
            .field("jwt_expires", &self.jwt_expires)
            .field("public_paths", &self.public_paths)
            .finish()
    }
}
impl Default for RestAuthSection {
    /// No secret and no public paths, with `jwt_expires` set to 7200.
    ///
    /// With an empty secret (and no `JWT_AUTH_SECRET` override) no runtime auth
    /// config is produced.
    fn default() -> Self {
        Self {
            jwt_secret: String::default(),
            jwt_expires: 7_200,
            public_paths: Vec::default(),
        }
    }
}
impl RestAuthSection {
    /// Resolves the JWT secret.
    ///
    /// A non-empty `JWT_AUTH_SECRET` environment variable wins over
    /// `jwt_secret`. An empty or non-UTF-8 variable is ignored. Otherwise a
    /// blank override (e.g. an unset shell interpolation) would replace a
    /// configured secret with `""`, and `auth_config` would silently return `None`.
    fn secret(&self) -> String {
        std::env::var("JWT_AUTH_SECRET")
            .ok()
            .filter(|value| !value.is_empty())
            .unwrap_or_else(|| self.jwt_secret.clone())
    }
    /// Resolves the token lifetime.
    ///
    /// Uses `JWT_AUTH_EXPIRES` when it parses as `u64` (surrounding whitespace
    /// is ignored); otherwise falls back to `jwt_expires`.
    fn jwt_expires(&self) -> u64 {
        std::env::var("JWT_AUTH_EXPIRES")
            .ok()
            .and_then(|value| value.trim().parse().ok())
            .unwrap_or(self.jwt_expires)
    }
    /// Builds the runtime auth config, or `None` when the resolved secret is
    /// empty.
    fn auth_config(&self) -> Option<crate::rest::AuthConfig> {
        let secret = self.secret();
        (!secret.is_empty()).then(|| crate::rest::AuthConfig {
            secret,
            public_paths: self.public_paths.clone(),
        })
    }
}
/// On/off switches for the optional REST middlewares.
///
/// Missing fields fall back to [`Default`]; unknown keys are rejected.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct RestMiddlewaresSection {
/// Enables the metrics middleware; `validate_features` warns when the `observability` feature is not compiled in.
pub metrics: bool,
/// Keeps the resilience settings from the production defaults (reset to defaults when false); needs the `resil` feature.
pub resilience: bool,
}
impl Default for RestMiddlewaresSection {
    /// Both middlewares are opt-out: metrics and resilience start enabled.
    fn default() -> Self {
        let enabled = true;
        Self {
            metrics: enabled,
            resilience: enabled,
        }
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use super::RestServiceConfig;
    /// The default file config should map straight onto the runtime REST config.
    #[test]
    fn maps_runtime_values() {
        let runtime = RestServiceConfig::default().rest_config();
        assert_eq!(runtime.name, "rs-zero");
        assert_eq!(runtime.timeout, Duration::from_millis(5000));
        assert!(runtime.middlewares.metrics.enabled);
    }
    /// A feature warning must be emitted exactly when that feature is compiled out.
    #[test]
    fn validate_features_reflects_compile_time_features() {
        let warnings = RestServiceConfig::default().validate_features();
        let warns_about = |feature: &'static str| {
            warnings
                .iter()
                .any(|warning| warning.required_feature == feature)
        };
        assert_eq!(warns_about("observability"), !cfg!(feature = "observability"));
        assert_eq!(warns_about("resil"), !cfg!(feature = "resil"));
    }
}