use std::vec::Vec;
use indexmap::IndexMap;
use crate::config_value::{ConfigValue, Sourced};
use crate::provenance::{Override, Provenance};
/// Outcome of merging one or more configuration values.
///
/// Bundles the merged value with the [`Override`] records collected while
/// merging.
#[derive(Debug)]
pub struct MergeResult {
    /// The merged configuration tree.
    pub value: ConfigValue,
    /// Overrides recorded during the merge, in the order they were pushed.
    pub overrides: Vec<Override>,
}

impl MergeResult {
    /// Wraps `value` in a result that has no recorded overrides.
    pub fn new(value: ConfigValue) -> Self {
        let overrides = Vec::default();
        MergeResult { value, overrides }
    }
}
/// Merges `upper` on top of `lower`, treating `path` as the dotted path of the
/// values being merged (`""` for the root).
///
/// Returns the merged value along with every override recorded on the way.
pub fn merge(lower: ConfigValue, upper: ConfigValue, path: &str) -> MergeResult {
    let mut recorded = Vec::new();
    MergeResult {
        value: merge_inner(lower, upper, path, &mut recorded),
        overrides: recorded,
    }
}
pub fn merge_layers<I>(layers: I) -> MergeResult
where
I: IntoIterator<Item = ConfigValue>,
{
let mut iter = layers.into_iter();
let Some(first) = iter.next() else {
return MergeResult::new(ConfigValue::Object(Sourced::new(IndexMap::default())));
};
let mut result = MergeResult::new(first);
for upper in iter {
let merged = merge(result.value, upper, "");
result.value = merged.value;
result.overrides.extend(merged.overrides);
}
result
}
fn merge_inner(
lower: ConfigValue,
upper: ConfigValue,
path: &str,
overrides: &mut Vec<Override>,
) -> ConfigValue {
match (lower, upper) {
(ConfigValue::Object(mut lower_obj), ConfigValue::Object(upper_obj)) => {
for (key, upper_value) in upper_obj.value {
let key_path = if path.is_empty() {
key.clone()
} else {
format!("{path}.{key}")
};
if let Some(lower_value) = lower_obj.value.shift_remove(&key) {
let merged = merge_inner(lower_value, upper_value, &key_path, overrides);
lower_obj.value.insert(key, merged);
} else {
lower_obj.value.insert(key, upper_value);
}
}
if upper_obj.provenance.is_some() {
lower_obj.provenance = upper_obj.provenance;
}
ConfigValue::Object(lower_obj)
}
(lower, upper) => {
if let (Some(lower_prov), Some(upper_prov)) =
(get_provenance(&lower), get_provenance(&upper))
{
overrides.push(Override::new(path, upper_prov.clone(), lower_prov.clone()));
}
upper
}
}
}
/// Returns the provenance attached to `value`, whatever its variant, or `None`
/// when the value carries none.
fn get_provenance(value: &ConfigValue) -> Option<&Provenance> {
    let provenance = match value {
        ConfigValue::Null(sourced) => &sourced.provenance,
        ConfigValue::Bool(sourced) => &sourced.provenance,
        ConfigValue::Integer(sourced) => &sourced.provenance,
        ConfigValue::Float(sourced) => &sourced.provenance,
        ConfigValue::String(sourced) => &sourced.provenance,
        ConfigValue::Array(sourced) => &sourced.provenance,
        ConfigValue::Object(sourced) => &sourced.provenance,
        ConfigValue::Enum(sourced) => &sourced.provenance,
    };
    provenance.as_ref()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use crate::provenance::ConfigFile;

    /// Builds a string value carrying `prov` and no span.
    fn string_with_prov(value: &str, prov: Provenance) -> ConfigValue {
        ConfigValue::String(Sourced {
            value: value.to_string(),
            span: None,
            provenance: Some(prov),
        })
    }

    /// Builds an integer value carrying `prov` and no span.
    fn int_with_prov(value: i64, prov: Provenance) -> ConfigValue {
        ConfigValue::Integer(Sourced {
            value,
            span: None,
            provenance: Some(prov),
        })
    }

    /// Builds an object (via `Sourced::new`) whose entries keep the given order.
    fn object(entries: Vec<(&str, ConfigValue)>) -> ConfigValue {
        let map: IndexMap<String, ConfigValue, std::hash::RandomState> = entries
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect();
        ConfigValue::Object(Sourced::new(map))
    }

    #[test]
    fn test_merge_disjoint_objects() {
        let lower = object(vec![("a", int_with_prov(1, Provenance::Default))]);
        let upper = object(vec![("b", int_with_prov(2, Provenance::Default))]);
        let result = merge(lower, upper, "");
        let ConfigValue::Object(obj) = result.value else {
            panic!("expected object");
        };
        assert_eq!(obj.value.len(), 2);
        assert!(obj.value.contains_key("a"));
        assert!(obj.value.contains_key("b"));
        assert!(result.overrides.is_empty());
    }

    #[test]
    fn test_merge_overlapping_objects() {
        let file = Arc::new(ConfigFile::new("config.json", "{}"));
        let file_prov = Provenance::file(file, "port", 0, 4);
        let env_prov = Provenance::env("REEF__PORT", "9000");
        let lower = object(vec![("port", int_with_prov(8080, file_prov))]);
        let upper = object(vec![("port", int_with_prov(9000, env_prov))]);
        let result = merge(lower, upper, "");
        let ConfigValue::Object(obj) = result.value else {
            panic!("expected object");
        };
        assert_eq!(obj.value.len(), 1);
        let Some(ConfigValue::Integer(port)) = obj.value.get("port") else {
            panic!("expected integer");
        };
        assert_eq!(port.value, 9000);
        assert_eq!(result.overrides.len(), 1);
        assert_eq!(result.overrides[0].path, "port");
    }

    #[test]
    fn test_merge_nested_objects() {
        let file_prov = || Provenance::Default;
        let lower = object(vec![(
            "smtp",
            object(vec![
                ("host", string_with_prov("mail.example.com", file_prov())),
                ("port", int_with_prov(587, file_prov())),
            ]),
        )]);
        let upper = object(vec![(
            "smtp",
            object(vec![(
                "host",
                string_with_prov("override.com", file_prov()),
            )]),
        )]);
        let result = merge(lower, upper, "");
        let ConfigValue::Object(obj) = result.value else {
            panic!("expected object");
        };
        let Some(ConfigValue::Object(smtp)) = obj.value.get("smtp") else {
            panic!("expected smtp object");
        };
        assert_eq!(smtp.value.len(), 2);
        let Some(ConfigValue::String(host)) = smtp.value.get("host") else {
            panic!("expected smtp.host string");
        };
        assert_eq!(host.value, "override.com");
        let Some(ConfigValue::Integer(port)) = smtp.value.get("port") else {
            panic!("expected smtp.port integer");
        };
        assert_eq!(port.value, 587);
        // Only the overridden leaf is recorded, under its full dotted path.
        assert_eq!(result.overrides.len(), 1);
        assert_eq!(result.overrides[0].path, "smtp.host");
    }

    #[test]
    fn test_merge_scalar_replaces() {
        let lower = int_with_prov(1, Provenance::Default);
        let upper = int_with_prov(2, Provenance::env("VAR", "2"));
        let result = merge(lower, upper, "value");
        let ConfigValue::Integer(i) = result.value else {
            panic!("expected integer");
        };
        assert_eq!(i.value, 2);
        assert_eq!(result.overrides.len(), 1);
    }

    #[test]
    fn test_merge_layers_empty() {
        let result = merge_layers(Vec::<ConfigValue>::new());
        let ConfigValue::Object(obj) = result.value else {
            panic!("expected empty object");
        };
        assert!(obj.value.is_empty());
        assert!(result.overrides.is_empty());
    }

    #[test]
    fn test_merge_layers_single() {
        let layer = object(vec![("port", int_with_prov(8080, Provenance::Default))]);
        let result = merge_layers(vec![layer]);
        let ConfigValue::Object(obj) = result.value else {
            panic!("expected object");
        };
        assert_eq!(obj.value.len(), 1);
        assert!(result.overrides.is_empty());
    }

    #[test]
    fn test_merge_layers_multiple() {
        let file_prov = Provenance::Default;
        let env_prov = Provenance::env("REEF__PORT", "9000");
        let cli_prov = Provenance::cli("--config.port", "8080");
        let file_layer = object(vec![
            ("port", int_with_prov(80, file_prov.clone())),
            ("host", string_with_prov("file.com", file_prov)),
        ]);
        let env_layer = object(vec![("port", int_with_prov(9000, env_prov))]);
        let cli_layer = object(vec![("port", int_with_prov(8080, cli_prov))]);
        let result = merge_layers(vec![file_layer, env_layer, cli_layer]);
        let ConfigValue::Object(obj) = result.value else {
            panic!("expected object");
        };
        let Some(ConfigValue::Integer(port)) = obj.value.get("port") else {
            panic!("expected port integer");
        };
        assert_eq!(port.value, 8080);
        let Some(ConfigValue::String(host)) = obj.value.get("host") else {
            panic!("expected host string");
        };
        assert_eq!(host.value, "file.com");
        // file -> env and env -> cli each record one override for `port`.
        assert_eq!(result.overrides.len(), 2);
    }

    #[test]
    fn test_merge_object_over_scalar() {
        let lower = object(vec![(
            "smtp",
            string_with_prov("legacy", Provenance::Default),
        )]);
        let upper = object(vec![(
            "smtp",
            object(vec![(
                "host",
                string_with_prov("mail.com", Provenance::Default),
            )]),
        )]);
        let result = merge(lower, upper, "");
        let ConfigValue::Object(obj) = result.value else {
            panic!("expected object");
        };
        let Some(ConfigValue::Object(smtp)) = obj.value.get("smtp") else {
            panic!("expected smtp to be object");
        };
        assert_eq!(smtp.value.len(), 1);
        assert!(smtp.value.contains_key("host"));
    }

    #[test]
    fn test_merge_scalar_over_object() {
        let lower = object(vec![(
            "smtp",
            object(vec![(
                "host",
                string_with_prov("mail.com", Provenance::Default),
            )]),
        )]);
        let upper = object(vec![(
            "smtp",
            string_with_prov("disabled", Provenance::Default),
        )]);
        let result = merge(lower, upper, "");
        let ConfigValue::Object(obj) = result.value else {
            panic!("expected object");
        };
        let Some(ConfigValue::String(smtp)) = obj.value.get("smtp") else {
            panic!("expected smtp to be string");
        };
        assert_eq!(smtp.value, "disabled");
    }
}