use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use vti_common::error::AppError;
use vti_common::store::KeyspaceHandle;
use crate::config::AppConfig;
/// Byte prefix that namespaces every config-override entry inside the keyspace.
const STORAGE_PREFIX: &[u8] = b"config:override:";

/// Builds the raw storage key for a config override: [`STORAGE_PREFIX`]
/// immediately followed by the UTF-8 bytes of `key`.
fn storage_key(key: &str) -> Vec<u8> {
    [STORAGE_PREFIX, key.as_bytes()].concat()
}
/// Identifies which configuration layer supplied an effective value.
///
/// `compute_effective_config` consults the layers in the order
/// `Env`, `Db`, `Toml`, `Default` and tags each field with the first
/// layer that produced a value. Variant names are (de)serialized in
/// `snake_case` (`"env"`, `"db"`, `"toml"`, `"default"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConfigSource {
/// Value came from a process environment variable (see `env_layer_value`).
Env,
/// Value came from an override stored in the `ConfigStore`.
Db,
/// Value came from the loaded `AppConfig` and differs from the built-in default.
Toml,
/// No layer supplied a value; the built-in default is in effect.
Default,
}
/// The value shape a registered config key accepts.
///
/// `validate_value` dispatches on this to decide whether a JSON value is
/// acceptable for a key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigKeyKind {
/// Any JSON string.
String,
/// A JSON number for which `Value::as_u64` returns `Some`.
U64,
/// A JSON string that must equal one of the listed values.
StringEnum(&'static [&'static str]),
/// A JSON string that must begin with one of the listed path prefixes
/// (see `validate_value` for the exact matching rule).
PathAllowlist(&'static [&'static str]),
}
/// Static description of one runtime-configurable key; entries live in `REGISTRY`.
#[derive(Debug, Clone, Copy)]
pub struct ConfigKeyDef {
/// Dotted key name, e.g. `"server.host"`.
pub key: &'static str,
/// Value shape enforced by `validate_value`.
pub kind: ConfigKeyKind,
/// Copied verbatim into `EffectiveField::requires_restart`.
pub requires_restart: bool,
/// NOTE(review): not read anywhere in this file; presumably marks values
/// that should be redacted when displayed — confirm against callers.
pub sensitive: bool,
}
/// Every config key this module knows about.
///
/// `lookup` searches this table by name and `compute_effective_config`
/// emits exactly one `EffectiveField` per entry, in this order. The layer
/// helpers (`env_layer_value`, `toml_layer_value`, `default_layer_value`)
/// match on the same key strings, so a key added here needs a matching arm
/// in each of them to resolve to anything other than `Value::Null`.
pub const REGISTRY: &[ConfigKeyDef] = &[
ConfigKeyDef {
key: "server.host",
kind: ConfigKeyKind::String,
requires_restart: true,
sensitive: false,
},
ConfigKeyDef {
key: "server.port",
kind: ConfigKeyKind::U64,
requires_restart: true,
sensitive: false,
},
ConfigKeyDef {
key: "log.level",
kind: ConfigKeyKind::StringEnum(&["trace", "debug", "info", "warn", "error"]),
requires_restart: false,
sensitive: false,
},
];
/// Finds the registry entry whose `key` equals `key` exactly.
///
/// Returns `None` when the key is not present in [`REGISTRY`]. If the table
/// ever contained duplicates, the first match in declaration order wins.
pub fn lookup(key: &str) -> Option<&'static ConfigKeyDef> {
    for def in REGISTRY {
        if def.key == key {
            return Some(def);
        }
    }
    None
}
/// Accessor for config overrides persisted in a keyspace.
///
/// Every entry is stored under `STORAGE_PREFIX` followed by the config key
/// name (see `storage_key`).
#[derive(Clone)]
pub struct ConfigStore {
/// Keyspace that holds the override entries.
ks: KeyspaceHandle,
}
impl ConfigStore {
pub fn new(ks: KeyspaceHandle) -> Self {
Self { ks }
}
pub async fn get(&self, key: &str) -> Result<Option<Value>, AppError> {
self.ks.get(storage_key(key)).await
}
pub async fn put(&self, key: &str, value: &Value) -> Result<(), AppError> {
self.ks.insert(storage_key(key), value).await
}
pub async fn delete(&self, key: &str) -> Result<(), AppError> {
self.ks.remove(storage_key(key)).await
}
pub async fn snapshot(&self) -> Result<HashMap<String, Value>, AppError> {
let pairs = self.ks.prefix_iter_raw(STORAGE_PREFIX.to_vec()).await?;
let prefix_len = STORAGE_PREFIX.len();
let mut out = HashMap::with_capacity(pairs.len());
for (key_bytes, value_bytes) in pairs {
let Some(rest) = key_bytes.get(prefix_len..) else {
continue;
};
let Ok(name) = String::from_utf8(rest.to_vec()) else {
tracing::warn!("skipping non-UTF-8 config-override key");
continue;
};
let Ok(value) = serde_json::from_slice::<Value>(&value_bytes) else {
tracing::warn!(key = %name, "skipping unparseable config-override value");
continue;
};
out.insert(name, value);
}
Ok(out)
}
}
/// One resolved configuration entry as produced by `compute_effective_config`.
///
/// Field names are serialized in `camelCase` (e.g. `requiresRestart`).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EffectiveField {
/// Config key name, copied from the registry entry.
pub key: String,
/// The value taken from the winning layer.
pub value: Value,
/// Which layer supplied `value`.
pub source: ConfigSource,
/// Copied from the registry entry's `requires_restart` flag.
pub requires_restart: bool,
}
/// The full resolved configuration: one `EffectiveField` per `REGISTRY`
/// entry, in registry order.
#[derive(Debug, Clone, Serialize)]
pub struct EffectiveConfig {
/// Resolved fields, built by `compute_effective_config`.
pub fields: Vec<EffectiveField>,
}
/// Returns the value the loaded `AppConfig` contributes for `key`, or `None`
/// when the config holds the built-in default (or the key is unknown).
///
/// The check is a plain comparison against the default, so a config that
/// explicitly sets a field to its default value is indistinguishable from
/// one that omits it; both yield `None`.
fn toml_layer_value(key: &str, cfg: &AppConfig) -> Option<Value> {
    use crate::config::default_host_value;
    use crate::config::default_port_value;

    match key {
        "server.host" => (cfg.server.host != default_host_value())
            .then(|| Value::String(cfg.server.host.clone())),
        "server.port" => (cfg.server.port != default_port_value())
            .then(|| Value::Number(serde_json::Number::from(cfg.server.port))),
        "log.level" => (cfg.log.level != "info").then(|| Value::String(cfg.log.level.clone())),
        _ => None,
    }
}
/// Returns the built-in default for `key` as a JSON value.
///
/// Keys without a dedicated arm yield `Value::Null`.
fn default_layer_value(key: &str) -> Value {
    use crate::config::default_host_value;
    use crate::config::default_port_value;

    match key {
        "server.host" => {
            let host = default_host_value();
            Value::String(host)
        }
        "server.port" => {
            let port = serde_json::Number::from(default_port_value());
            Value::Number(port)
        }
        "log.level" => Value::String(String::from("info")),
        _ => Value::Null,
    }
}
/// Returns the value the process environment contributes for `key`.
///
/// Recognized variables are `VTC_SERVER_HOST`, `VTC_SERVER_PORT` and
/// `VTC_LOG_LEVEL`. Yields `None` when the key has no associated variable,
/// when the variable cannot be read as a string, or — for the port — when
/// its contents do not parse as a `u64` (the bad value is ignored silently
/// and lower layers take over).
fn env_layer_value(key: &str) -> Option<Value> {
    let var_name = match key {
        "server.host" => "VTC_SERVER_HOST",
        "server.port" => "VTC_SERVER_PORT",
        "log.level" => "VTC_LOG_LEVEL",
        _ => return None,
    };
    let raw = std::env::var(var_name).ok()?;

    if key == "server.port" {
        // Only the port is numeric; everything else is passed through as text.
        let port: u64 = raw.parse().ok()?;
        Some(Value::Number(serde_json::Number::from(port)))
    } else {
        Some(Value::String(raw))
    }
}
pub async fn compute_effective_config(
cfg: &AppConfig,
db: &ConfigStore,
) -> Result<EffectiveConfig, AppError> {
let db_snapshot = db.snapshot().await?;
let mut fields = Vec::with_capacity(REGISTRY.len());
for def in REGISTRY {
let (value, source) = if let Some(v) = env_layer_value(def.key) {
(v, ConfigSource::Env)
} else if let Some(v) = db_snapshot.get(def.key) {
(v.clone(), ConfigSource::Db)
} else if let Some(v) = toml_layer_value(def.key, cfg) {
(v, ConfigSource::Toml)
} else {
(default_layer_value(def.key), ConfigSource::Default)
};
fields.push(EffectiveField {
key: def.key.into(),
value,
source,
requires_restart: def.requires_restart,
});
}
Ok(EffectiveConfig { fields })
}
/// Checks that `value` is acceptable for the key described by `def`.
///
/// Rules per [`ConfigKeyKind`]:
/// * `String` — any JSON string.
/// * `U64` — a JSON number representable as `u64`.
/// * `StringEnum` — a JSON string equal to one of the allowed values.
/// * `PathAllowlist` — a JSON string naming a path that lies under one of
///   the allowed prefixes. Matching is done on whole path components, so
///   `/var/database` does not satisfy the prefix `/var/data`, and any path
///   containing a `..` component is rejected because it could climb back
///   out of the allowed directory.
///
/// # Errors
/// Returns `AppError::Validation` with a message naming the key when the
/// value has the wrong type or falls outside the allowed set.
pub fn validate_value(def: &ConfigKeyDef, value: &Value) -> Result<(), AppError> {
    use std::path::{Component, Path};

    match def.kind {
        ConfigKeyKind::String => {
            if !value.is_string() {
                return Err(AppError::Validation(format!(
                    "{} must be a string",
                    def.key
                )));
            }
            Ok(())
        }
        ConfigKeyKind::U64 => match value.as_u64() {
            Some(_) => Ok(()),
            None => Err(AppError::Validation(format!(
                "{} must be an unsigned integer",
                def.key
            ))),
        },
        ConfigKeyKind::StringEnum(allowed) => {
            let s = value
                .as_str()
                .ok_or_else(|| AppError::Validation(format!("{} must be a string", def.key)))?;
            if allowed.contains(&s) {
                Ok(())
            } else {
                Err(AppError::Validation(format!(
                    "{} must be one of {:?}, got {:?}",
                    def.key, allowed, s
                )))
            }
        }
        ConfigKeyKind::PathAllowlist(prefixes) => {
            let s = value
                .as_str()
                .ok_or_else(|| AppError::Validation(format!("{} must be a string", def.key)))?;
            let path = Path::new(s);

            // A `..` component can escape an allowed prefix even though the
            // textual path still begins with it, so refuse it outright.
            if path.components().any(|c| matches!(c, Component::ParentDir)) {
                return Err(AppError::Validation(format!(
                    "{} must not contain `..` path components",
                    def.key
                )));
            }

            // `Path::starts_with` compares whole components, unlike
            // `str::starts_with`, which would let `/var/database` match the
            // prefix `/var/data`.
            if prefixes.iter().any(|p| path.starts_with(p)) {
                Ok(())
            } else {
                Err(AppError::Validation(format!(
                    "{} must start with one of {:?}",
                    def.key, prefixes
                )))
            }
        }
    }
}
// Unit tests for the override store, value validation, and layer precedence.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use vti_common::config::StoreConfig;
use vti_common::store::Store;
// Opens a store in a fresh temporary directory and wraps its
// "config-test" keyspace. The `TempDir` is returned alongside the store so
// the caller can keep it alive for the duration of the test.
fn temp_store() -> (ConfigStore, tempfile::TempDir) {
let dir = tempfile::tempdir().expect("tempdir");
let cfg = StoreConfig {
data_dir: dir.path().to_path_buf(),
};
let store = Store::open(&cfg).expect("store");
let ks = store.keyspace("config-test").expect("ks");
(ConfigStore::new(ks), dir)
}
// Builds an `AppConfig` by deserializing an empty TOML document, i.e. with
// every field left at whatever its deserialization default is.
fn default_app_config() -> AppConfig {
toml::from_str("").expect("empty TOML parses")
}
// A value written with `put` is returned unchanged by `get`.
#[tokio::test]
async fn put_then_get_returns_value() {
let (store, _dir) = temp_store();
store.put("log.level", &json!("debug")).await.unwrap();
let got = store.get("log.level").await.unwrap();
assert_eq!(got, Some(json!("debug")));
}
// After `delete`, `get` reports no value for the key.
#[tokio::test]
async fn delete_removes_value() {
let (store, _dir) = temp_store();
store.put("log.level", &json!("debug")).await.unwrap();
store.delete("log.level").await.unwrap();
let got = store.get("log.level").await.unwrap();
assert!(got.is_none());
}
// `snapshot` returns every stored override keyed by its un-prefixed name.
#[tokio::test]
async fn snapshot_returns_all_overrides() {
let (store, _dir) = temp_store();
store.put("log.level", &json!("debug")).await.unwrap();
store.put("server.host", &json!("10.0.0.1")).await.unwrap();
let snap = store.snapshot().await.unwrap();
assert_eq!(snap.len(), 2);
assert_eq!(snap.get("log.level"), Some(&json!("debug")));
assert_eq!(snap.get("server.host"), Some(&json!("10.0.0.1")));
}
// `String` keys accept strings and reject numbers.
#[test]
fn validate_string_kind() {
let def = lookup("server.host").unwrap();
assert!(validate_value(def, &json!("0.0.0.0")).is_ok());
assert!(validate_value(def, &json!(42)).is_err());
}
// `U64` keys accept non-negative integers and reject negatives and strings.
#[test]
fn validate_u64_kind() {
let def = lookup("server.port").unwrap();
assert!(validate_value(def, &json!(8200)).is_ok());
assert!(validate_value(def, &json!(-1)).is_err());
assert!(validate_value(def, &json!("8200")).is_err());
}
// `StringEnum` keys accept each listed value and reject anything else.
#[test]
fn validate_string_enum_kind() {
let def = lookup("log.level").unwrap();
for lvl in ["trace", "debug", "info", "warn", "error"] {
assert!(
validate_value(def, &json!(lvl)).is_ok(),
"{lvl} should pass"
);
}
assert!(validate_value(def, &json!("verbose")).is_err());
assert!(validate_value(def, &json!(42)).is_err());
}
// With an empty store and a default config, every field resolves to the
// `Default` layer.
// NOTE(review): this relies on VTC_SERVER_HOST, VTC_SERVER_PORT and
// VTC_LOG_LEVEL being unset in the test process, because the env layer
// takes precedence over all others.
#[tokio::test]
async fn effective_returns_defaults_when_no_overrides() {
let (store, _dir) = temp_store();
let cfg = default_app_config();
let eff = compute_effective_config(&cfg, &store).await.unwrap();
let by_key: HashMap<_, _> = eff.fields.iter().map(|f| (&*f.key, f)).collect();
assert_eq!(by_key["server.host"].source, ConfigSource::Default);
assert_eq!(by_key["server.host"].value, json!("0.0.0.0"));
assert_eq!(by_key["server.port"].source, ConfigSource::Default);
assert_eq!(by_key["server.port"].value, json!(8200));
assert_eq!(by_key["log.level"].source, ConfigSource::Default);
assert_eq!(by_key["log.level"].value, json!("info"));
}
// A stored override wins over a non-default value in the loaded config.
// Like the defaults test, this assumes VTC_LOG_LEVEL is not set.
#[tokio::test]
async fn db_layer_beats_toml() {
let (store, _dir) = temp_store();
let mut cfg = default_app_config();
cfg.log.level = "debug".into();
store.put("log.level", &json!("warn")).await.unwrap();
let eff = compute_effective_config(&cfg, &store).await.unwrap();
let f = eff.fields.iter().find(|f| f.key == "log.level").unwrap();
assert_eq!(f.source, ConfigSource::Db);
assert_eq!(f.value, json!("warn"));
}
// Without a stored override, a non-default config value is reported as `Toml`.
// Like the defaults test, this assumes VTC_LOG_LEVEL is not set.
#[tokio::test]
async fn toml_layer_used_when_no_db_override() {
let (store, _dir) = temp_store();
let mut cfg = default_app_config();
cfg.log.level = "debug".into();
let eff = compute_effective_config(&cfg, &store).await.unwrap();
let f = eff.fields.iter().find(|f| f.key == "log.level").unwrap();
assert_eq!(f.source, ConfigSource::Toml);
assert_eq!(f.value, json!("debug"));
}
// `requires_restart` on each effective field mirrors the registry entry.
#[tokio::test]
async fn requires_restart_flag_propagates_from_registry() {
let (store, _dir) = temp_store();
let cfg = default_app_config();
let eff = compute_effective_config(&cfg, &store).await.unwrap();
let host = eff.fields.iter().find(|f| f.key == "server.host").unwrap();
assert!(host.requires_restart);
let level = eff.fields.iter().find(|f| f.key == "log.level").unwrap();
assert!(!level.requires_restart);
}
}