use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use crate::Error;
/// A source of raw configuration values looked up by string path
/// (e.g. `db.pool.max` in this file's tests).
///
/// The `Send + Sync + 'static` bounds allow implementations to be shared as
/// `Arc<dyn Config>`, as `ChainedConfig` does.
#[async_trait]
pub trait Config: Send + Sync + 'static {
/// Returns the raw bytes stored at `path`, or `Ok(None)` if this source has
/// no value for it.
async fn get(&self, path: &str) -> Result<Option<Vec<u8>>, Error>;
/// Returns a `tokio::sync::watch` receiver seeded with the value at `path`.
///
/// Whether and how the receiver is updated afterwards is up to the
/// implementation. `interval` is presumably a polling period (TODO confirm);
/// `EnvConfig` ignores it and never updates its receiver.
fn watch(
&self,
path: &str,
interval: Duration,
) -> tokio::sync::watch::Receiver<Option<Vec<u8>>>;
/// A short static name for the kind of source (e.g. `"env"`, `"chained"`).
fn source(&self) -> &'static str;
}
pub async fn get_typed<T: DeserializeOwned>(
cfg: &(dyn Config + '_),
path: &str,
) -> Result<Option<T>, Error> {
match cfg.get(path).await? {
Some(bytes) => {
let v = serde_json::from_slice(&bytes).map_err(|e| {
Error::CapabilityPermanent(format!("config '{path}' deserialize: {e}"))
})?;
Ok(Some(v))
}
None => Ok(None),
}
}
/// A [`Config`] that consults several sources in priority order.
///
/// Entries earlier in `sources` take precedence over later ones.
pub struct ChainedConfig {
    sources: Vec<Arc<dyn Config>>,
}

impl ChainedConfig {
    /// Builds a chain from `sources`, highest priority first.
    pub fn new(sources: Vec<Arc<dyn Config>>) -> Self {
        ChainedConfig { sources }
    }
}
#[async_trait]
impl Config for ChainedConfig {
    /// Returns the value from the first source, in priority order, that has one.
    ///
    /// Sources are queried one at a time. An error from any source is returned
    /// immediately, and the remaining sources are not consulted.
    async fn get(&self, path: &str) -> Result<Option<Vec<u8>>, Error> {
        for src in &self.sources {
            if let Some(v) = src.get(path).await? {
                return Ok(Some(v));
            }
        }
        Ok(None)
    }

    /// Watches `path` on the source that currently supplies its value.
    ///
    /// Sources are watched in priority order, and the first receiver whose
    /// current value is `Some` is returned. The initial value therefore comes
    /// from the same source [`Config::get`] would read, rather than always
    /// from the first source. If no source has a value yet, the
    /// highest-priority source's receiver is returned.
    ///
    /// Only the chosen source is tracked afterwards. A value that later
    /// appears in a different source is not observed. Receivers of sources
    /// that were probed but not chosen are dropped.
    ///
    /// With no sources, the receiver holds `None` and its sender is already
    /// dropped, so `changed()` reports the channel as closed.
    fn watch(
        &self,
        path: &str,
        interval: Duration,
    ) -> tokio::sync::watch::Receiver<Option<Vec<u8>>> {
        let mut highest_priority = None;
        for src in &self.sources {
            let rx = src.watch(path, interval);
            // Read into a local so the `borrow()` guard is released before `rx` is moved.
            let has_value = rx.borrow().is_some();
            if has_value {
                return rx;
            }
            if highest_priority.is_none() {
                highest_priority = Some(rx);
            }
        }
        highest_priority.unwrap_or_else(|| {
            let (_tx, rx) = tokio::sync::watch::channel(None);
            rx
        })
    }

    fn source(&self) -> &'static str {
        "chained"
    }
}
/// A [`Config`] backed by process environment variables.
///
/// A config path is mapped to a variable name by appending it to `prefix`.
/// Every `.` becomes `_`, and ASCII letters are upper-cased. With the prefix
/// `APP_`, the path `db.pool.max` reads `APP_DB_POOL_MAX`.
pub struct EnvConfig {
    /// Prepended verbatim (not upper-cased) to every translated path.
    pub prefix: String,
}

impl EnvConfig {
    /// Creates a source that looks up variables under `prefix`.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        EnvConfig { prefix }
    }

    /// Translates a dotted config path into the environment variable name to read.
    ///
    /// Non-ASCII characters are copied through unchanged.
    fn env_key(&self, path: &str) -> String {
        let translated = path.chars().map(|ch| match ch {
            '.' => '_',
            other => other.to_ascii_uppercase(),
        });
        // ASCII upper-casing never changes a char's UTF-8 length, so this capacity is exact.
        let mut key = String::with_capacity(self.prefix.len() + path.len());
        key.push_str(&self.prefix);
        key.extend(translated);
        key
    }
}

impl Default for EnvConfig {
    /// An `EnvConfig` with an empty prefix, so translated paths are used as-is.
    fn default() -> Self {
        EnvConfig {
            prefix: String::new(),
        }
    }
}
/// Reads the environment variable `key` as raw bytes.
///
/// Unset variables yield `None`. Variables whose value is not valid Unicode
/// also yield `None`, because `std::env::var` reports them as an error.
fn env_value_bytes(key: &str) -> Option<Vec<u8>> {
    match std::env::var(key) {
        Ok(value) => Some(value.into_bytes()),
        Err(_) => None,
    }
}

#[async_trait]
impl Config for EnvConfig {
    /// Looks up the variable named by `env_key(path)`; this never returns `Err`.
    async fn get(&self, path: &str) -> Result<Option<Vec<u8>>, Error> {
        let key = self.env_key(path);
        Ok(env_value_bytes(&key))
    }

    /// Returns a receiver holding a one-time snapshot of the variable.
    ///
    /// `_interval` is ignored and the variable is never re-read, so the value
    /// never changes. The sender is deliberately leaked (`mem::forget`). This
    /// keeps the channel open, so `changed()` stays pending instead of
    /// reporting a closed channel; presumably that is the intent (TODO
    /// confirm). Each call leaks one small channel allocation.
    fn watch(
        &self,
        path: &str,
        _interval: Duration,
    ) -> tokio::sync::watch::Receiver<Option<Vec<u8>>> {
        let snapshot = env_value_bytes(&self.env_key(path));
        let (sender, receiver) = tokio::sync::watch::channel(snapshot);
        std::mem::forget(sender);
        receiver
    }

    fn source(&self) -> &'static str {
        "env"
    }
}
#[cfg(test)]
// Environment-touching tests hold `ENV_LOCK` across `.await`. Each
// `#[tokio::test]` runs on its own single-threaded runtime, so no other task
// on that runtime can be waiting for the lock.
#[allow(clippy::await_holding_lock)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes every test in this module that reads or writes process
    /// environment variables. `set_var`/`remove_var` are only sound when no
    /// other thread accesses the environment concurrently, and the test
    /// harness runs tests on several threads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, ignoring poisoning so one failed test does not
    /// cascade into every later environment test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets an environment variable and removes it again when dropped.
    ///
    /// The variable is cleaned up even if an assertion panics. Create it only
    /// while holding the guard from `lock_env`. Bind it after that guard, so
    /// it is dropped (and the variable removed) before the lock is released.
    struct ScopedEnvVar(&'static str);

    impl ScopedEnvVar {
        fn set(key: &'static str, value: &str) -> Self {
            // SAFETY: the caller holds `ENV_LOCK`, and every test in this module
            // that touches the environment acquires that lock first.
            unsafe { std::env::set_var(key, value) };
            Self(key)
        }
    }

    impl Drop for ScopedEnvVar {
        fn drop(&mut self) {
            // SAFETY: same invariant as `set`; the lock guard outlives this value.
            unsafe { std::env::remove_var(self.0) };
        }
    }

    #[test]
    fn env_key_translation_is_correct() {
        let c = EnvConfig::new("APP_");
        assert_eq!(c.env_key("db.pool.max"), "APP_DB_POOL_MAX");
        assert_eq!(c.env_key("feature.flag"), "APP_FEATURE_FLAG");
        assert_eq!(c.env_key(""), "APP_");
    }

    #[tokio::test]
    async fn env_get_returns_value_or_none() {
        let _env = lock_env();
        let _var = ScopedEnvVar::set("TONIN_TEST_CONFIG_VALUE_42", "hello");
        let c = EnvConfig::new("TONIN_TEST_CONFIG_");
        assert_eq!(c.get("value.42").await.unwrap(), Some(b"hello".to_vec()));
        assert!(c.get("missing.key").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn chained_falls_through_to_next_source() {
        let _env = lock_env();
        let _var = ScopedEnvVar::set("TONIN_TEST_CHAINED_FALLBACK", "from-env");
        let first = Arc::new(EnvConfig::new("TONIN_TEST_EMPTY_"));
        let second = Arc::new(EnvConfig::new("TONIN_TEST_CHAINED_"));
        let chain = ChainedConfig::new(vec![first, second]);
        assert_eq!(
            chain.get("fallback").await.unwrap(),
            Some(b"from-env".to_vec())
        );
        assert_eq!(chain.source(), "chained");
    }

    #[tokio::test]
    async fn empty_chain_has_no_values() {
        let chain = ChainedConfig::new(Vec::new());
        assert!(chain.get("anything").await.unwrap().is_none());
        assert!(chain
            .watch("anything", Duration::from_secs(1))
            .borrow()
            .is_none());
    }

    #[tokio::test]
    async fn get_typed_decodes_json_and_rejects_garbage() {
        let _env = lock_env();
        let c = EnvConfig::new("TONIN_TEST_TYPED_");
        {
            let _var = ScopedEnvVar::set("TONIN_TEST_TYPED_NUM", "42");
            assert_eq!(get_typed::<u32>(&c, "num").await.unwrap(), Some(42));
        }
        {
            let _var = ScopedEnvVar::set("TONIN_TEST_TYPED_NUM", "not json");
            assert!(get_typed::<u32>(&c, "num").await.is_err());
        }
        assert_eq!(get_typed::<u32>(&c, "num").await.unwrap(), None);
    }
}