use std::collections::BTreeMap;
use serde_json::Value;
use super::builder::ConfigBuilder;
use super::keys;
use super::map::ConfigMap;
use super::profile::Profile;
use super::schema::{CacheConfig, DatabaseConfig, MessagingConfig, ServerConfig, StorageConfig};
use crate::error::ConfigError;
/// Resolved application configuration.
///
/// Couples the [`Profile`] the configuration was built for with the
/// [`ConfigMap`] holding its values. Instances are produced by
/// [`ConfigBuilder`] (see `Config::builder` / `Config::load`).
#[derive(Clone, Debug)]
pub struct Config {
    /// Profile this configuration was built for.
    profile: Profile,
    /// Key/value store backing every lookup on this configuration.
    values: ConfigMap,
}
impl Config {
    /// Assembles a [`Config`] from an already-resolved profile and value map.
    pub(crate) fn new(profile: Profile, values: ConfigMap) -> Self {
        Self { profile, values }
    }

    /// Returns a fresh [`ConfigBuilder`] targeting `profile`.
    pub fn builder(profile: Profile) -> ConfigBuilder {
        ConfigBuilder::new(profile)
    }

    /// Builds a configuration for the profile reported by [`Profile::detect`].
    ///
    /// # Errors
    ///
    /// Propagates any [`ConfigError`] returned by the builder's `build` step.
    pub async fn load() -> Result<Self, ConfigError> {
        Self::builder(Profile::detect()).build().await
    }

    /// Alias for [`Config::load`].
    ///
    /// # Errors
    ///
    /// Same as [`Config::load`].
    pub async fn auto_load() -> Result<Self, ConfigError> {
        Self::load().await
    }

    /// Profile this configuration was built for.
    pub fn profile(&self) -> &Profile {
        &self.profile
    }

    /// Underlying value store.
    pub fn values(&self) -> &ConfigMap {
        &self.values
    }

    /// Borrowed view of the stored values as an ordered `String` -> [`Value`] map.
    pub fn as_map(&self) -> &BTreeMap<String, Value> {
        self.values.as_map()
    }

    /// Iterates over the keys of the underlying [`ConfigMap`].
    pub fn keys(&self) -> impl Iterator<Item = &str> {
        self.values.keys()
    }

    /// Returns `true` when [`Config::get_raw`] resolves `key`, including dotted
    /// paths into nested objects.
    pub fn contains_key(&self, key: &str) -> bool {
        self.get_raw(key).is_some()
    }

    /// Deserializes the value at `key` into `T`.
    ///
    /// # Errors
    ///
    /// * [`ConfigError::MissingRequired`] if `key` does not resolve.
    /// * [`ConfigError::Deserialization`] if the value does not fit `T`.
    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<T, ConfigError> {
        let value =
            self.get_raw(key).ok_or_else(|| ConfigError::MissingRequired(key.to_owned()))?;
        Self::deserialize_at(key, value)
    }

    /// Deserializes the value at `key` into `T`, or returns `Ok(None)` when the
    /// key does not resolve.
    ///
    /// # Errors
    ///
    /// [`ConfigError::Deserialization`] if the key is present but its value
    /// does not fit `T`.
    pub fn get_optional<T: serde::de::DeserializeOwned>(
        &self,
        key: &str,
    ) -> Result<Option<T>, ConfigError> {
        self.get_raw(key).map(|value| Self::deserialize_at(key, value)).transpose()
    }

    /// Returns the value at `key` as a string.
    ///
    /// JSON strings are returned verbatim (without quotes); any other value is
    /// rendered as its compact JSON text (e.g. the number `10` becomes `"10"`).
    pub fn get_string(&self, key: &str) -> Option<String> {
        self.get_raw(key).map(|value| match value {
            Value::String(s) => s.clone(),
            other => other.to_string(),
        })
    }

    /// Looks up `key`, resolving dotted paths into nested values.
    ///
    /// Resolution order:
    /// 1. `key` stored verbatim (flat keys may themselves contain dots).
    /// 2. For each `.` in `key`, left to right: the prefix before it as a
    ///    stored key, then the remainder walked as a nested path inside that
    ///    value. The first split that resolves wins, so the shortest stored
    ///    prefix takes precedence.
    pub fn get_raw(&self, key: &str) -> Option<&Value> {
        if let Some(value) = self.values.get(key) {
            return Some(value);
        }
        // Trying every split point (not just the first) lets a flat dotted key
        // such as "a.b" that holds an object be reached as "a.b.c". Because the
        // first-segment split is still tried first, every lookup that resolved
        // before keeps resolving to the same value.
        key.match_indices('.').find_map(|(dot, _)| {
            let (head, tail) = (&key[..dot], &key[dot + 1..]);
            self.values.get(head).and_then(|parent| get_nested(parent, tail))
        })
    }

    /// Deserializes the section stored under [`keys::DATABASE`].
    ///
    /// # Errors
    ///
    /// Same as [`Config::get`].
    pub fn database(&self) -> Result<DatabaseConfig, ConfigError> {
        self.get(keys::DATABASE)
    }

    /// Deserializes the section stored under [`keys::CACHE`].
    ///
    /// # Errors
    ///
    /// Same as [`Config::get`].
    pub fn cache(&self) -> Result<CacheConfig, ConfigError> {
        self.get(keys::CACHE)
    }

    /// Deserializes the section stored under [`keys::MESSAGING`].
    ///
    /// # Errors
    ///
    /// Same as [`Config::get`].
    pub fn messaging(&self) -> Result<MessagingConfig, ConfigError> {
        self.get(keys::MESSAGING)
    }

    /// Deserializes the section stored under [`keys::STORAGE`].
    ///
    /// # Errors
    ///
    /// Same as [`Config::get`].
    pub fn storage(&self) -> Result<StorageConfig, ConfigError> {
        self.get(keys::STORAGE)
    }

    /// Deserializes the section stored under [`keys::SERVER`].
    ///
    /// # Errors
    ///
    /// Same as [`Config::get`].
    pub fn server(&self) -> Result<ServerConfig, ConfigError> {
        self.get(keys::SERVER)
    }

    /// Deserializes `value` (found under `key`) into `T`, tagging failures with `key`.
    fn deserialize_at<T: serde::de::DeserializeOwned>(
        key: &str,
        value: &Value,
    ) -> Result<T, ConfigError> {
        // `&Value` implements `Deserializer`, so deserializing from the borrow
        // avoids deep-copying the subtree as `from_value(value.clone())` would.
        T::deserialize(value)
            .map_err(|source| ConfigError::Deserialization { key: key.to_owned(), source })
    }
}
fn get_nested<'a>(value: &'a Value, path: &str) -> Option<&'a Value> {
let mut parts = path.splitn(2, '.');
let head = parts.next()?;
let child = value.get(head)?;
match parts.next() {
Some(tail) => get_nested(child, tail),
None => Some(child),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider::MemoryProvider;
    use serde::Deserialize;
    use serde_json::json;

    /// Expected shape of the `database` section seeded by [`sample_config`].
    #[derive(Deserialize, Debug, PartialEq)]
    struct DbFixture {
        url: String,
        pool_size: u32,
    }

    /// Builds a test-profile [`Config`] from an in-memory provider holding an
    /// `app_name` string, a `debug` flag and a nested `database` object.
    async fn sample_config() -> Config {
        Config::builder(Profile::Test)
            .with_provider(
                MemoryProvider::new()
                    .set("app_name", "TestApp")
                    .set("debug", true)
                    .set("database", json!({ "url": "postgres://localhost", "pool_size": 10 })),
            )
            .build()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn get_deserializes_nested_section() {
        let config = sample_config().await;
        let expected = DbFixture { url: "postgres://localhost".into(), pool_size: 10 };
        let db: DbFixture = config.get("database").unwrap();
        assert_eq!(db, expected);
    }

    #[tokio::test]
    async fn dotted_keys_resolve_into_nested_objects() {
        let config = sample_config().await;
        let url = config.get_raw("database.url");
        assert_eq!(url, Some(&json!("postgres://localhost")));
        let pool_size = config.get_string("database.pool_size");
        assert_eq!(pool_size.as_deref(), Some("10"));
    }

    #[tokio::test]
    async fn missing_vs_optional() {
        let config = sample_config().await;
        let required = config.get::<DbFixture>("nope");
        assert!(matches!(required, Err(ConfigError::MissingRequired(_))));
        let optional = config.get_optional::<DbFixture>("nope").unwrap();
        assert_eq!(optional, None);
        assert!(!config.contains_key("nope"));
        assert!(config.contains_key("database.url"));
    }
}