use crate::*;
use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
/// Transcode a serializable value into another serde type by round-tripping
/// it through a JSON string.
///
/// This is how module config types are converted to and from the internal
/// [`ConfigMap`] tree.
///
/// # Errors
///
/// Returns an error tagged "encode config" if `s` cannot be serialized to
/// JSON, or "decode config" if that JSON cannot be deserialized as `D`.
fn tc<S: serde::Serialize, D: serde::de::DeserializeOwned>(
    s: &S,
) -> K2Result<D> {
    let json = match serde_json::to_string(s) {
        Ok(json) => json,
        Err(e) => return Err(K2Error::other_src("encode config", e)),
    };
    serde_json::from_str::<D>(&json)
        .map_err(|e| K2Error::other_src("decode config", e))
}
/// Callback notified with the new JSON value of a config entry.
///
/// It is invoked once with the entry's current value when registered via
/// [`Config::register_entry_update_cb`], and again each time
/// [`Config::set_module_config`] writes a value to that entry.
pub type ConfigUpdateCb =
    Arc<dyn Fn(serde_json::Value) + Send + Sync + 'static>;

/// A leaf of the config tree: a raw JSON value plus an optional update
/// callback.
///
/// Serializes transparently as just `value`. The callback is never
/// serialized and is `None` after deserialization.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
#[serde(transparent, rename_all = "camelCase")]
struct ConfigEntry {
    /// The current JSON value of this entry.
    pub value: serde_json::Value,
    /// Callback to notify when `value` is set, if one was registered.
    #[serde(skip, default)]
    pub update_cb: Option<ConfigUpdateCb>,
}

impl std::fmt::Debug for ConfigEntry {
    /// Prints the value and whether a callback is attached; the callback
    /// itself is an opaque closure and is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_update_cb = self.update_cb.is_some();
        let mut out = f.debug_struct("ConfigEntry");
        out.field("value", &self.value);
        out.field("has_update_cb", &has_update_cb);
        out.finish()
    }
}
/// A node of the config tree: either a map of named children or a leaf
/// [`ConfigEntry`].
///
/// Deserialization is untagged and tries the variants in order. A JSON
/// object therefore becomes a `ConfigMap`, and any other JSON value becomes
/// a `ConfigEntry`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(untagged, rename_all = "camelCase")]
enum ConfigMap {
    /// Inner node keyed by parameter name (sorted, via `BTreeMap`).
    ConfigMap(BTreeMap<String, Box<Self>>),
    /// Leaf value.
    ConfigEntry(ConfigEntry),
}

impl Default for ConfigMap {
    /// An empty inner node.
    fn default() -> Self {
        let children: BTreeMap<String, Box<Self>> = BTreeMap::default();
        Self::ConfigMap(children)
    }
}
/// State guarded by [`Config`]'s mutex.
#[derive(Clone)]
struct Inner {
    /// Set by [`Config::mark_defaults_set`]. Once true, setting a key that
    /// does not already exist logs a "may be unused" warning.
    are_defaults_set: bool,
    /// Set by [`Config::mark_validated`].
    did_validate: bool,
    /// Set by [`Config::mark_runtime`]. Once true, writing an entry that has
    /// no update callback logs a warning.
    is_runtime: bool,
    /// The config tree itself.
    map: ConfigMap,
}

/// Hierarchical configuration stored as a tree of JSON values behind a
/// mutex.
///
/// Values are written and read as arbitrary serde types via
/// [`Config::set_module_config`] and [`Config::get_module_config`].
pub struct Config(Mutex<Inner>);
impl serde::Serialize for Config {
    /// Serializes the current config tree. Update callbacks are not part of
    /// the output.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let guard = self.0.lock().unwrap();
        serde::Serialize::serialize(&guard.map, serializer)
    }
}
impl std::fmt::Debug for Config {
    /// Formats just the config tree, using its `Debug` representation.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let guard = self.0.lock().unwrap();
        std::fmt::Debug::fmt(&guard.map, f)
    }
}
impl Default for Config {
    /// An empty config tree with every lifecycle flag cleared.
    fn default() -> Self {
        let inner = Inner {
            map: ConfigMap::default(),
            are_defaults_set: false,
            did_validate: false,
            is_runtime: false,
        };
        Self(Mutex::new(inner))
    }
}
impl Clone for Config {
fn clone(&self) -> Self {
let lock = self.0.lock().expect("config mutex poisoned");
Self(Mutex::new(Inner {
map: lock.map.clone(),
are_defaults_set: lock.are_defaults_set,
did_validate: lock.did_validate,
is_runtime: lock.is_runtime,
}))
}
}
impl Config {
pub fn mark_defaults_set(&self) {
self.0.lock().unwrap().are_defaults_set = true;
}
pub fn mark_validated(&self) -> bool {
let mut lock = self.0.lock().unwrap();
let out = lock.did_validate;
lock.did_validate = true;
out
}
pub fn mark_runtime(&self) {
self.0.lock().unwrap().is_runtime = true;
}
pub fn get_module_config<D: serde::de::DeserializeOwned>(
&self,
) -> K2Result<D> {
let lock = self.0.lock().unwrap();
tc(&lock.map)
}
pub fn set_module_config<S: serde::Serialize>(
&self,
config: &S,
) -> K2Result<()> {
let in_map: ConfigMap = tc(config)?;
let debug_path = format!("{in_map:?}");
let mut updates = Vec::new();
{
let mut lock = self.0.lock().unwrap();
let are_defaults_set = lock.are_defaults_set;
let is_runtime = lock.is_runtime;
let old_map: &mut ConfigMap = &mut lock.map;
let new_map: &ConfigMap = &in_map;
fn apply_map(
debug_path: &str,
are_defaults_set: bool,
is_runtime: bool,
updates: &mut Vec<(ConfigUpdateCb, serde_json::Value)>,
old_map: &mut ConfigMap,
new_map: &ConfigMap,
) -> K2Result<()> {
match new_map {
ConfigMap::ConfigMap(new_map) => match old_map {
ConfigMap::ConfigMap(old_map) => {
for (key, new_map) in new_map.iter() {
if are_defaults_set
&& !old_map.contains_key(key)
{
tracing::warn!(
debug_path,
"this config parameter may be unused"
);
}
let old_map =
old_map.entry(key.clone()).or_default();
apply_map(
debug_path,
are_defaults_set,
is_runtime,
updates,
old_map,
new_map,
)?;
}
}
ConfigMap::ConfigEntry(_) => {
return Err(K2Error::other(format!(
"{debug_path} attempted to insert a map where an entry exists",
)));
}
},
ConfigMap::ConfigEntry(new_entry) => match old_map {
ConfigMap::ConfigMap(m) => {
if !m.is_empty() {
return Err(K2Error::other(format!(
"{debug_path} attempted to insert an entry where a map exists",
)));
}
*old_map =
ConfigMap::ConfigEntry(new_entry.clone());
if is_runtime {
tracing::warn!(
debug_path,
"no update callback for runtime config alteration"
);
}
}
ConfigMap::ConfigEntry(old_entry) => {
old_entry.value = new_entry.value.clone();
if let Some(update_cb) = &old_entry.update_cb {
updates.push((
update_cb.clone(),
new_entry.value.clone(),
));
} else if is_runtime {
tracing::warn!(
debug_path,
"no update callback for runtime config alteration"
);
}
}
},
}
Ok(())
}
apply_map(
&debug_path,
are_defaults_set,
is_runtime,
&mut updates,
old_map,
new_map,
)?;
}
for (update_cb, value) in updates {
update_cb(value);
}
Ok(())
}
pub fn register_entry_update_cb<D: std::fmt::Display>(
&self,
path: &[D],
update_cb: ConfigUpdateCb,
) -> K2Result<()> {
let value = {
let mut lock = self.0.lock().unwrap();
let mut cur: &mut ConfigMap = &mut lock.map;
for path in path {
let key = path.to_string();
match cur {
ConfigMap::ConfigMap(m) => cur = m.entry(key).or_default(),
ConfigMap::ConfigEntry(_) => {
return Err(K2Error::other(
"attempted to insert a map where an entry exists",
));
}
}
}
match cur {
ConfigMap::ConfigMap(m) => {
if !m.is_empty() {
return Err(K2Error::other(
"attempted to insert an entry where a map exists",
));
}
*cur = ConfigMap::ConfigEntry(ConfigEntry {
value: serde_json::Value::Null,
update_cb: Some(update_cb.clone()),
});
serde_json::Value::Null
}
ConfigMap::ConfigEntry(e) => {
e.update_cb = Some(update_cb.clone());
e.value.clone()
}
}
};
update_cb(value);
Ok(())
}
pub fn merge_config_overrides(
self,
config_overrides: &Self,
) -> K2Result<Self> {
{
let lock_overrides =
config_overrides.0.lock().expect("config mutex poisoned");
let mut lock = self.0.lock().expect("config mutex poisoned");
lock.map.merge_overrides(&lock_overrides.map);
}
Ok(self)
}
}
impl ConfigMap {
    /// Deep-merge `overrides` into `self`.
    ///
    /// When both nodes are maps, keys are merged recursively. Keys present
    /// only in `overrides` are copied in, and keys present in both are
    /// merged. In every other combination (leaf/leaf, map/leaf, leaf/map),
    /// `self` is replaced wholesale by a clone of `overrides`.
    fn merge_overrides(&mut self, overrides: &Self) {
        if let (Self::ConfigMap(base), Self::ConfigMap(incoming)) =
            (&mut *self, overrides)
        {
            for (key, incoming_value) in incoming {
                if let Some(existing) = base.get_mut(key) {
                    existing.merge_overrides(incoming_value);
                } else {
                    base.insert(key.clone(), incoming_value.clone());
                }
            }
            return;
        }
        *self = overrides.clone();
    }
}
// Unit tests for `Config`: warning paths, round-tripping a module config,
// entry update callbacks, cloning, and merging overrides.
#[cfg(test)]
mod test {
use super::*;
// Exercises the "this config parameter may be unused" warning: once defaults
// are marked set, updating the existing "apples" key should not warn, but
// setting the new "bananas" key should. Output is only visible via tracing;
// the test itself asserts nothing beyond the calls succeeding.
#[test]
fn warns_unused() {
enable_tracing();
let c = Config::default();
c.set_module_config(&serde_json::json!({"apples": "red"}))
.unwrap();
c.mark_defaults_set();
c.set_module_config(&serde_json::json!({"apples": "green"}))
.unwrap();
c.set_module_config(&serde_json::json!({"bananas": 42}))
.unwrap();
}
// Exercises the "no update callback for runtime config alteration" warning:
// after `mark_runtime`, both overwriting an existing entry without a callback
// and creating a new entry should warn. Like `warns_unused`, only the calls
// succeeding is asserted.
#[test]
fn warns_no_runtime_cb() {
enable_tracing();
let c = Config::default();
c.set_module_config(&serde_json::json!({"apples": "red"}))
.unwrap();
c.mark_runtime();
c.set_module_config(&serde_json::json!({"apples": "green"}))
.unwrap();
c.set_module_config(&serde_json::json!({"bananas": 42}))
.unwrap();
}
// End-to-end usage: write a typed module config, read it back unchanged, then
// register an update callback on a nested entry and check that a later
// partial update delivers the new value to it.
#[test]
fn config_usage_example() {
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct SubConfig {
pub apples: String,
pub bananas: u32,
}
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct ModConfig {
pub my_module: SubConfig,
}
let c = Config::default();
let expect = ModConfig {
my_module: SubConfig {
apples: "red".to_string(),
bananas: 42,
},
};
c.set_module_config(&expect).unwrap();
println!("{}", serde_json::to_string_pretty(&c).unwrap());
let resp: ModConfig = c.get_module_config().unwrap();
assert_eq!(expect, resp);
use std::sync::atomic::*;
// Shared cell the callback writes into so the test can observe it.
let update = Arc::new(AtomicU32::new(0));
let update2 = update.clone();
// Path segments use the camelCase keys produced by serde.
c.register_entry_update_cb(
&["myModule", "bananas"],
Arc::new(move |v| {
let v: u32 =
serde_json::from_str(&serde_json::to_string(&v).unwrap())
.unwrap();
update2.store(v, Ordering::SeqCst);
}),
)
.unwrap();
// A partial update touching only the watched entry.
c.set_module_config(&serde_json::json!({
"myModule": {
"bananas": 99,
}
}))
.unwrap();
assert_eq!(99, update.load(Ordering::SeqCst));
}
// Installs a fmt subscriber writing through the test harness, defaulting to
// DEBUG but overridable via the env filter. `try_init` errors if a global
// subscriber is already set; that error is deliberately ignored so several
// tests can call this.
fn enable_tracing() {
let _ = tracing_subscriber::fmt()
.with_test_writer()
.with_env_filter(
tracing_subscriber::EnvFilter::builder()
.with_default_directive(tracing::Level::DEBUG.into())
.from_env_lossy(),
)
.try_init();
}
// Fixture: a transport module config, used to check that merging overrides
// brings in a module absent from the base config.
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Tx5TransportConfig {
pub signal_allow_plain_text: bool,
pub server_url: String,
pub timeout_s: u32,
pub webrtc_connect_timeout_s: u32,
pub tracing_enabled: bool,
pub ephemeral_udp_port_min: Option<u16>,
pub ephemeral_udp_port_max: Option<u16>,
pub danger_force_signal_relay: bool,
}
// Wrapper placing `Tx5TransportConfig` under the "tx5Transport" key.
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Tx5TransportModConfig {
pub tx5_transport: Tx5TransportConfig,
}
// Fixture: a bootstrap module config, present in both the base config and
// the overrides in the merge test.
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct CoreBootstrapConfig {
pub server_url: Option<String>,
pub backoff_min_ms: u32,
pub backoff_max_ms: u32,
}
// Wrapper placing `CoreBootstrapConfig` under the "coreBootstrap" key.
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct CoreBootstrapModConfig {
pub core_bootstrap: CoreBootstrapConfig,
}
// A clone must expose the same module config as the original.
#[test]
fn test_should_clone_config() {
let mod_config = CoreBootstrapModConfig {
core_bootstrap: CoreBootstrapConfig {
server_url: Some("https://example.com".to_string()),
backoff_min_ms: 1000,
backoff_max_ms: 60000,
},
};
let c = Config::default();
c.set_module_config(&mod_config).unwrap();
let clone = c.clone();
let v: CoreBootstrapModConfig =
clone.get_module_config().expect("failed to get");
assert_eq!(v, mod_config);
}
// Merging overrides must (a) replace leaves of a module present on both
// sides, including turning `null` into a value, and (b) add a module that
// exists only in the overrides.
#[test]
fn test_should_merge_config_override() {
let config = Config::default();
let src_bootstrap = CoreBootstrapModConfig {
core_bootstrap: CoreBootstrapConfig {
server_url: None,
backoff_min_ms: 1000,
backoff_max_ms: 60000,
},
};
let overrides = Config::default();
let override_bootstrap = CoreBootstrapModConfig {
core_bootstrap: CoreBootstrapConfig {
server_url: Some("https://override.com".to_string()),
backoff_min_ms: 2000,
backoff_max_ms: 120000,
},
};
config
.set_module_config(&src_bootstrap)
.expect("failed to set source config");
let tx5 = Tx5TransportModConfig {
tx5_transport: Tx5TransportConfig {
signal_allow_plain_text: true,
server_url: "wss://signal.example.com".to_string(),
timeout_s: 30,
webrtc_connect_timeout_s: 10,
tracing_enabled: false,
ephemeral_udp_port_min: Some(10000),
ephemeral_udp_port_max: Some(20000),
danger_force_signal_relay: false,
},
};
overrides
.set_module_config(&override_bootstrap)
.expect("failed to set override config");
overrides
.set_module_config(&tx5)
.expect("failed to set override config");
let merged = config
.merge_config_overrides(&overrides)
.expect("failed to merge config overrides");
let new_bootstrap: CoreBootstrapModConfig = merged
.get_module_config()
.expect("failed to get merged config");
assert_eq!(
new_bootstrap.core_bootstrap.server_url,
Some("https://override.com".to_string())
);
assert_eq!(new_bootstrap.core_bootstrap.backoff_min_ms, 2000);
assert_eq!(new_bootstrap.core_bootstrap.backoff_max_ms, 120000);
let new_tx5: Tx5TransportModConfig = merged
.get_module_config()
.expect("failed to get merged config");
assert_eq!(new_tx5, tx5);
}
}