use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use crate::BoxError;
/// Deployment environment the application is resolved for.
///
/// Serialized as the lowercase variant name (`local`, `development`, `test`,
/// `staging`, `production`), except `DevPrivate`, which is `dev-private`.
/// The `Display` impl writes the same canonical names.
/// Defaults to [`Environment::Local`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum Environment {
/// Serialized as `local`; the default variant.
#[default]
Local,
/// Serialized as `dev-private`; `devprivate` is also accepted when deserializing.
#[serde(rename = "dev-private", alias = "devprivate")]
DevPrivate,
/// Serialized as `development`; `dev` is also accepted when deserializing.
#[serde(alias = "dev")]
Development,
/// Serialized as `test`.
Test,
/// Serialized as `staging`.
Staging,
/// Serialized as `production`; `prod` is also accepted when deserializing.
#[serde(alias = "prod")]
Production,
}
impl std::fmt::Display for Environment {
    /// Writes the canonical lowercase name of the environment. The names match the
    /// serde representation, so `to_string()` round-trips through deserialization.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let canonical = match self {
            Self::Local => "local",
            Self::DevPrivate => "dev-private",
            Self::Development => "development",
            Self::Test => "test",
            Self::Staging => "staging",
            Self::Production => "production",
        };
        f.write_str(canonical)
    }
}
/// Server host/port settings.
///
/// Fields missing from a deserialized config fall back to environment variables
/// (`HOST`, `PORT`) and then to built-in defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ServerConfig {
/// Host name or address; defaults to `$HOST`, else `0.0.0.0`.
#[serde(default = "default_host")]
pub host: String,
/// Port number; defaults to `$PORT` if it parses as `u16`, else `8443`.
#[serde(default = "default_port")]
pub port: u16,
}
impl Default for ServerConfig {
    /// Builds the same values serde would use for a missing `server` table:
    /// `$HOST`/`$PORT` when available, otherwise `0.0.0.0:8443`.
    fn default() -> Self {
        Self::new(default_host(), default_port())
    }
}
impl ServerConfig {
    /// Creates a server config for `host` and `port`.
    #[must_use]
    pub fn new(host: impl Into<String>, port: u16) -> Self {
        let host: String = host.into();
        Self { host, port }
    }
}
/// Default server host: the `HOST` environment variable when it is set to valid
/// Unicode, otherwise `"0.0.0.0"`.
fn default_host() -> String {
    match std::env::var("HOST") {
        Ok(host) => host,
        Err(_) => String::from("0.0.0.0"),
    }
}
/// Default server port: the `PORT` environment variable when it parses as a `u16`,
/// otherwise `8443`. An unset, non-Unicode, or unparsable value all use the fallback.
fn default_port() -> u16 {
    const FALLBACK_PORT: u16 = 8443;
    match std::env::var("PORT").map(|raw| raw.parse::<u16>()) {
        Ok(Ok(port)) => port,
        _ => FALLBACK_PORT,
    }
}
/// Unresolved application configuration as read from a TOML or YAML file.
///
/// Call `resolve` to pick an [`Environment`] and merge its overrides into the
/// free-form sections, producing an [`AppConfig`]. `Debug` is implemented by hand
/// so that `environments` and `extra` print only their keys, never their values.
#[derive(Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct AppConfigDefinition {
/// Application name; defaults to `"rusty-gasket-app"` when missing.
#[serde(default = "default_app_name")]
pub name: String,
/// Environment name from the file; `resolve` gives `GASKET_ENV` precedence over it.
#[serde(default)]
pub env: Option<String>,
/// Server settings; missing fields use the `HOST`/`PORT`-aware defaults.
#[serde(default)]
pub server: ServerConfig,
/// Environment name -> table of section overrides, applied by `resolve`.
#[serde(default)]
pub environments: HashMap<String, serde_json::Value>,
/// Every other top-level key (e.g. `[database]`), captured via `flatten`.
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
impl Default for AppConfigDefinition {
    /// A definition with the default name, no explicit environment, default server
    /// settings, and no sections or environment overrides.
    fn default() -> Self {
        let name = default_app_name();
        let server = ServerConfig::default();
        Self {
            name,
            env: None,
            server,
            environments: HashMap::default(),
            extra: HashMap::default(),
        }
    }
}
impl std::fmt::Debug for AppConfigDefinition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut env_names: Vec<&str> = self.environments.keys().map(String::as_str).collect();
env_names.sort_unstable();
let mut extra_names: Vec<&str> = self.extra.keys().map(String::as_str).collect();
extra_names.sort_unstable();
f.debug_struct("AppConfigDefinition")
.field("name", &self.name)
.field("env", &self.env)
.field("server", &self.server)
.field("environments", &env_names)
.field("extra", &extra_names)
.finish()
}
}
/// Name used when a config file does not specify `name`.
fn default_app_name() -> String {
    String::from("rusty-gasket-app")
}
impl AppConfigDefinition {
#[must_use]
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
..Self::default()
}
}
#[must_use]
pub fn server(mut self, server: ServerConfig) -> Self {
self.server = server;
self
}
#[must_use]
pub fn env(mut self, env: impl Into<String>) -> Self {
self.env = Some(env.into());
self
}
pub fn from_toml(contents: &str) -> Result<Self, BoxError> {
Ok(toml::from_str(contents)?)
}
pub fn from_yaml(contents: &str) -> Result<Self, BoxError> {
Ok(serde_yaml_ng::from_str(contents)?)
}
pub fn from_file(path: impl AsRef<Path>) -> Result<Self, BoxError> {
let path = path.as_ref();
let contents = std::fs::read_to_string(path)
.map_err(|e| format!("Failed to read config file '{}': {e}", path.display()))?;
Self::parse_with_extension(path, &contents)
}
pub fn from_file_optional(path: impl AsRef<Path>) -> Result<Option<Self>, BoxError> {
let path = path.as_ref();
let contents = match std::fs::read_to_string(path) {
Ok(s) => s,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(e) => {
return Err(format!("Failed to read config file '{}': {e}", path.display()).into());
}
};
Self::parse_with_extension(path, &contents).map(Some)
}
fn parse_with_extension(path: &Path, contents: &str) -> Result<Self, BoxError> {
let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
match ext {
"toml" => Self::from_toml(contents),
"yaml" | "yml" => Self::from_yaml(contents),
_ => Self::from_toml(contents).or_else(|_| Self::from_yaml(contents)),
}
}
pub fn resolve(self) -> Result<AppConfig, BoxError> {
let env_str = std::env::var("GASKET_ENV")
.ok()
.or(self.env)
.unwrap_or_else(|| "local".to_string());
let env: Environment = serde_json::from_value(serde_json::Value::String(env_str.clone()))
.map_err(|_| {
format!(
"Unknown environment '{env_str}'. \
Valid values: local, dev-private, development, dev, test, staging, production, prod"
)
})?;
let mut sections = self.extra;
let canonical_env = env.to_string();
let env_overrides = self
.environments
.get(&env_str)
.or_else(|| self.environments.get(&canonical_env));
if let Some(env_overrides) = env_overrides
&& let Some(obj) = env_overrides.as_object()
{
for (key, value) in obj {
if let Some(existing) = sections.get(key) {
sections.insert(key.clone(), deep_merge_json(existing, value));
} else {
sections.insert(key.clone(), value.clone());
}
}
}
Ok(AppConfig {
name: self.name,
env,
server: self.server,
sections,
})
}
}
/// Resolved application configuration, produced by `AppConfigDefinition::resolve`.
///
/// Named sections (with any environment overrides already merged in) are private
/// and are read through typed accessors such as `section`.
#[derive(Clone)]
pub struct AppConfig {
/// Application name.
pub name: String,
/// Environment this configuration was resolved for.
pub env: Environment,
/// Server host/port settings.
pub server: ServerConfig,
/// Section name -> raw JSON value. The `Debug` impl prints only the names.
sections: HashMap<String, serde_json::Value>,
}
impl std::fmt::Debug for AppConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut section_names: Vec<&str> = self.sections.keys().map(String::as_str).collect();
section_names.sort_unstable();
f.debug_struct("AppConfig")
.field("name", &self.name)
.field("env", &self.env)
.field("server", &self.server)
.field("sections", §ion_names)
.finish()
}
}
impl AppConfig {
pub fn section<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<T, BoxError> {
let value = self
.sections
.get(key)
.ok_or_else(|| format!("Config section '{key}' not found"))?;
Ok(serde_json::from_value(value.clone())?)
}
pub fn section_or_default<T: serde::de::DeserializeOwned + Default>(
&self,
key: &str,
) -> Result<T, BoxError> {
match self.sections.get(key) {
Some(v) => Ok(serde_json::from_value(v.clone())
.map_err(|e| format!("Config section '{key}' is invalid: {e}"))?),
None => Ok(T::default()),
}
}
#[must_use]
pub fn has_section(&self, key: &str) -> bool {
self.sections.contains_key(key)
}
pub fn set_section(&mut self, key: &str, value: serde_json::Value) {
self.sections.insert(key.to_string(), value);
}
}
/// Recursively merges `overlay` on top of `base` and returns the result.
///
/// When both values are JSON objects, keys present in both are merged recursively,
/// keys only in `overlay` are added, and keys only in `base` are kept. In every
/// other case (arrays, scalars, `null`), `overlay` replaces `base` wholesale.
fn deep_merge_json(base: &serde_json::Value, overlay: &serde_json::Value) -> serde_json::Value {
    match (base, overlay) {
        (serde_json::Value::Object(_), serde_json::Value::Object(_)) => {
            // Clone `base` once and merge in place, instead of cloning each nested
            // object again at every recursion level.
            let mut merged = base.clone();
            merge_json_in_place(&mut merged, overlay);
            merged
        }
        _ => overlay.clone(),
    }
}

/// In-place worker for `deep_merge_json`: applies `overlay` onto `target` using the
/// same rules (objects merge key-by-key, anything else replaces).
fn merge_json_in_place(target: &mut serde_json::Value, overlay: &serde_json::Value) {
    match (target, overlay) {
        (serde_json::Value::Object(target_map), serde_json::Value::Object(overlay_map)) => {
            for (key, overlay_val) in overlay_map {
                if let Some(existing) = target_map.get_mut(key) {
                    merge_json_in_place(existing, overlay_val);
                } else {
                    target_map.insert(key.clone(), overlay_val.clone());
                }
            }
        }
        (target, overlay) => *target = overlay.clone(),
    }
}
/// Asynchronous source of named secrets.
///
/// Implementors must be `Send + Sync + 'static`. The future returned by
/// `get_secret` must be `Send` and may borrow both `self` and `key`.
pub trait SecretsProvider: Send + Sync + 'static {
/// Looks up the secret named `key`.
///
/// `Ok(None)` presumably means "secret not configured" (that is how
/// `EnvSecretsProvider` uses it); `Err` reports a failed lookup. TODO confirm for
/// other implementors.
fn get_secret<'ctx>(
&'ctx self,
key: &'ctx str,
) -> impl Future<Output = Result<Option<SecretValue>, BoxError>> + Send + 'ctx;
}
/// A secret string whose plaintext is reachable only via `expose`.
///
/// Its `Debug` output is the fixed text `SecretValue(***)`.
pub struct SecretValue {
/// Wrapped value; `secrecy` keeps it out of accidental formatting.
inner: secrecy::SecretString,
}
impl SecretValue {
    /// Wraps `value` so that it is only reachable through [`SecretValue::expose`].
    #[must_use]
    pub fn new(value: String) -> Self {
        let inner: secrecy::SecretString = value.into();
        Self { inner }
    }

    /// Returns the plaintext secret. Keep the borrow short and never log the result.
    #[must_use]
    pub fn expose(&self) -> &str {
        use secrecy::ExposeSecret as _;
        self.inner.expose_secret()
    }
}
impl std::fmt::Debug for SecretValue {
    /// Always prints a fixed placeholder, so neither the secret nor its length leaks.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "SecretValue(***)")
    }
}
/// `SecretsProvider` that reads secrets from process environment variables.
///
/// Key `db-password` is looked up as `DB_PASSWORD`: the key is upper-cased and
/// `-` becomes `_`.
#[derive(Debug, Default)]
pub struct EnvSecretsProvider;
impl SecretsProvider for EnvSecretsProvider {
    /// Reads the secret from the environment variable derived from `key`: the key is
    /// upper-cased and `-` is replaced with `_` (so `db-password` → `DB_PASSWORD`).
    ///
    /// # Errors
    ///
    /// Returns an error if the variable is set but its value is not valid Unicode.
    /// Previously this case was silently reported as "no secret". An unset variable
    /// still yields `Ok(None)`.
    async fn get_secret(&self, key: &str) -> Result<Option<SecretValue>, BoxError> {
        let env_key = key.to_uppercase().replace('-', "_");
        match std::env::var(&env_key) {
            Ok(value) => Ok(Some(SecretValue::new(value))),
            Err(std::env::VarError::NotPresent) => Ok(None),
            // Deliberately do not include the raw value: it is the secret.
            Err(std::env::VarError::NotUnicode(_)) => Err(format!(
                "Secret environment variable '{env_key}' is set but is not valid Unicode"
            )
            .into()),
        }
    }
}
/// Set of unique strings, e.g. names loaded from a YAML field via
/// `load_yaml_field_optional` and/or from a comma-separated environment variable
/// via `with_env_csv`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StringSet {
/// Backing set; duplicates collapse on insert.
values: HashSet<String>,
}
impl StringSet {
    /// Creates an empty set.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a set from any iterable of string-like values; duplicates collapse.
    #[must_use]
    pub fn from_values(values: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            values: values.into_iter().map(Into::into).collect(),
        }
    }

    /// Loads the sequence of strings stored under the top-level `field` of the YAML
    /// file at `path`.
    ///
    /// The following all yield an empty set:
    /// - a missing file;
    /// - a missing field;
    /// - a field whose value is null, e.g. `clients:` with every entry commented out.
    ///
    /// # Errors
    ///
    /// - [`StringSetError::Io`] if the file exists but cannot be read.
    /// - [`StringSetError::Parse`] if the file is not valid YAML.
    /// - [`StringSetError::InvalidField`] if the field is neither null nor a sequence,
    ///   or if any entry is not a string.
    pub fn load_yaml_field_optional(
        path: impl AsRef<Path>,
        field: &'static str,
    ) -> Result<Self, StringSetError> {
        let path = path.as_ref();
        let contents = match std::fs::read_to_string(path) {
            Ok(contents) => contents,
            Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(Self::new()),
            Err(error) => {
                return Err(StringSetError::Io {
                    path: path.to_path_buf(),
                    error,
                });
            }
        };
        let document: serde_json::Value =
            serde_yaml_ng::from_str(&contents).map_err(|error| StringSetError::Parse {
                path: path.to_path_buf(),
                reason: error.to_string(),
            })?;

        let invalid = |reason: &'static str| StringSetError::InvalidField {
            path: path.to_path_buf(),
            field,
            reason,
        };
        let items = match document.get(field) {
            // A null value is an empty list in YAML terms; treat it like a missing
            // field instead of rejecting it.
            None | Some(serde_json::Value::Null) => return Ok(Self::new()),
            Some(serde_json::Value::Array(items)) => items,
            Some(_) => return Err(invalid("expected a YAML sequence of strings")),
        };

        let mut values = HashSet::with_capacity(items.len());
        for item in items {
            let value = item
                .as_str()
                .ok_or_else(|| invalid("all sequence entries must be strings"))?;
            values.insert(value.to_owned());
        }
        Ok(Self { values })
    }

    /// Adds the comma-separated values of environment variable `env_var`, if it is
    /// set to valid Unicode (otherwise the set is returned unchanged).
    #[must_use]
    pub fn with_env_csv(mut self, env_var: &str) -> Self {
        if let Ok(raw_values) = std::env::var(env_var) {
            self.extend_csv(&raw_values);
        }
        self
    }

    /// Adds each comma-separated entry of `csv`, trimmed; empty entries are skipped.
    pub fn extend_csv(&mut self, csv: &str) {
        self.values.extend(
            csv.split(',')
                .map(str::trim)
                .filter(|value| !value.is_empty())
                .map(str::to_owned),
        );
    }

    /// Returns `true` if `value` is in the set.
    #[must_use]
    pub fn contains(&self, value: &str) -> bool {
        self.values.contains(value)
    }

    /// Number of distinct values.
    #[must_use]
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Returns `true` if the set holds no values.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Iterates over the values in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = &String> {
        self.values.iter()
    }

    /// Borrows the underlying `HashSet`.
    #[must_use]
    pub const fn as_hash_set(&self) -> &HashSet<String> {
        &self.values
    }

    /// Consumes the set and returns the underlying `HashSet`.
    #[must_use]
    pub fn into_hash_set(self) -> HashSet<String> {
        self.values
    }
}
/// Errors returned by `StringSet::load_yaml_field_optional`.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum StringSetError {
/// The file exists but could not be read. A missing file is not an error.
#[error("Failed to read string-set file '{}': {error}", path.display())]
Io {
/// File that failed to read.
path: PathBuf,
/// Underlying I/O error, exposed as the error source.
#[source]
error: std::io::Error,
},
/// The file is not valid YAML. The parser's message is kept as text in `reason`.
#[error("Failed to parse string-set file '{}': {reason}", path.display())]
Parse {
/// File that failed to parse.
path: PathBuf,
/// Parser error message.
reason: String,
},
/// The requested field exists but does not hold a sequence of strings.
#[error("Invalid string-set field '{field}' in '{}': {reason}", path.display())]
InvalidField {
/// File containing the field.
path: PathBuf,
/// Top-level field that was requested.
field: &'static str,
/// Static description of what was wrong.
reason: &'static str,
},
}
// Unit tests covering:
// - environment parsing;
// - config loading and resolution;
// - Debug redaction of section values;
// - loading `StringSet` from YAML.
//
// NOTE(review): the tests that call `resolve()` assume `GASKET_ENV` is unset in the
// test process, because `resolve` prefers that variable over the fixture's `env`
// key. A CI job that exports `GASKET_ENV` would make them fail.
#[cfg(test)]
mod tests {
use super::*;
// Canonical names and aliases deserialize to the expected variants.
#[test]
fn environment_parsing() {
let env: Environment =
serde_json::from_value(serde_json::Value::String("production".to_string()))
.expect("parse");
assert_eq!(env, Environment::Production);
let env: Environment =
serde_json::from_value(serde_json::Value::String("prod".to_string()))
.expect("parse alias");
assert_eq!(env, Environment::Production);
let env: Environment =
serde_json::from_value(serde_json::Value::String("dev-private".to_string()))
.expect("parse alias");
assert_eq!(env, Environment::DevPrivate);
}
// For every variant, the serde name equals `Display` output and round-trips.
#[test]
fn environment_serde_round_trip_matches_display() {
for env in [
Environment::Local,
Environment::DevPrivate,
Environment::Development,
Environment::Test,
Environment::Staging,
Environment::Production,
] {
let value = serde_json::to_value(env).expect("serialize");
let s = value.as_str().expect("string");
assert_eq!(s, env.to_string(), "serde and Display disagree for {env:?}");
let back: Environment = serde_json::from_value(value).expect("round-trip");
assert_eq!(back, env);
}
}
// A TOML file with name/env/server plus a free-form `[database]` section.
#[test]
fn config_from_toml() {
let toml_str = r#"
name = "my-api"
env = "test"
[server]
host = "127.0.0.1"
port = 3000
[database]
url = "postgres://localhost/mydb"
"#;
let def = AppConfigDefinition::from_toml(toml_str).expect("parse toml");
assert_eq!(def.name, "my-api");
let config = def.resolve().expect("resolve");
assert_eq!(config.env, Environment::Test);
assert_eq!(config.server.port, 3000);
assert!(config.has_section("database"));
}
// Matching environment overrides are deep-merged: `url` is replaced and
// `max_connections` from the base section survives.
#[test]
fn config_env_overrides_replace_base_when_resolved_env_matches() {
let toml_str = r#"
name = "my-api"
env = "production"
[database]
url = "postgres://localhost/dev"
max_connections = 10
[environments.production.database]
url = "postgres://prod-host/prod"
"#;
let def = AppConfigDefinition::from_toml(toml_str).expect("parse");
let config = def.resolve().expect("resolve");
assert_eq!(config.env, Environment::Production);
let db: serde_json::Value = config.section("database").expect("database section");
assert_eq!(db["url"], "postgres://prod-host/prod");
assert_eq!(db["max_connections"], 10);
}
// Overrides for a different environment are not applied.
#[test]
fn config_env_overrides_skipped_when_env_does_not_match() {
let toml_str = r#"
name = "my-api"
env = "local"
[database]
url = "postgres://localhost/dev"
[environments.production.database]
url = "postgres://prod-host/prod"
"#;
let def = AppConfigDefinition::from_toml(toml_str).expect("parse");
let config = def.resolve().expect("resolve");
assert_eq!(config.env, Environment::Local);
let db: serde_json::Value = config.section("database").expect("database section");
assert_eq!(db["url"], "postgres://localhost/dev");
}
// Debug on the unresolved definition lists section/environment names but
// never their values (which here contain passwords).
#[test]
fn app_config_definition_debug_omits_section_values() {
let toml_str = r#"
name = "my-api"
env = "local"
[database]
url = "postgres://admin:another_secret_456@db.internal/app"
[environments.production.database]
url = "postgres://prod-admin:prod_secret_789@prod.internal/app"
"#;
let def = AppConfigDefinition::from_toml(toml_str).expect("parse");
let debug = format!("{def:?}");
assert!(
!debug.contains("another_secret_456"),
"AppConfigDefinition Debug must not print extra values: {debug}"
);
assert!(
!debug.contains("prod_secret_789"),
"AppConfigDefinition Debug must not print environments values: {debug}"
);
assert!(
debug.contains("database"),
"section names should still be listed: {debug}"
);
assert!(
debug.contains("production"),
"environment names should still be listed: {debug}"
);
}
// Debug on the resolved config lists section names but not their values.
#[test]
fn app_config_debug_omits_section_values() {
let toml_str = r#"
name = "my-api"
env = "local"
[database]
url = "postgres://admin:super_secret_password_123@db.internal/app"
"#;
let def = AppConfigDefinition::from_toml(toml_str).expect("parse");
let config = def.resolve().expect("resolve");
let debug = format!("{config:?}");
assert!(
!debug.contains("super_secret_password_123"),
"AppConfig Debug must not print section values: {debug}"
);
assert!(
!debug.contains("postgres://"),
"AppConfig Debug must not print section values: {debug}"
);
assert!(
debug.contains("database"),
"section names should still be listed: {debug}"
);
}
// Loads the `clients` sequence from a temporary YAML file.
#[test]
fn string_set_loads_named_yaml_field() {
let dir = tempfile::tempdir().expect("temp dir");
let path = dir.path().join("policy.yaml");
std::fs::write(&path, "clients:\n - svc-a\n - svc-b\n").expect("write yaml");
let set = StringSet::load_yaml_field_optional(&path, "clients").expect("load set");
assert_eq!(set.len(), 2);
assert!(set.contains("svc-a"));
assert!(set.contains("svc-b"));
assert!(!set.contains("svc-c"));
}
// A nonexistent file yields an empty set rather than an error.
#[test]
fn string_set_missing_file_is_empty() {
let set = StringSet::load_yaml_field_optional("/does/not/exist.yaml", "clients")
.expect("missing file");
assert!(set.is_empty());
}
// A non-string entry (`42`) is rejected with `InvalidField`.
#[test]
fn string_set_rejects_non_string_entries() {
let dir = tempfile::tempdir().expect("temp dir");
let path = dir.path().join("policy.yaml");
std::fs::write(&path, "clients:\n - svc-a\n - 42\n").expect("write yaml");
let result = StringSet::load_yaml_field_optional(&path, "clients");
assert!(matches!(result, Err(StringSetError::InvalidField { .. })));
}
}