use regex::Regex;
use serde_json::map::Entry;
use serde_json::{Map, Number, Value};
use std::sync::OnceLock;
use qubit_value::{MultiValues, ValueError};
use super::{ConfigError, ConfigReader, ConfigResult, Property};
/// Lazily-initialized storage for the `${name}` placeholder regex.
static VARIABLE_PATTERN: OnceLock<Regex> = OnceLock::new();

/// Returns the shared regex that matches `${name}` placeholders.
///
/// Capture group 1 holds the variable name, i.e. everything between `${` and the
/// first following `}`. The regex is compiled on the first call and reused afterwards.
///
/// # Panics
///
/// Panics if the built-in pattern fails to compile, which would be a programming error.
fn get_variable_pattern() -> &'static Regex {
    fn compile_variable_pattern() -> Regex {
        Regex::new(r"\$\{([^}]+)\}").expect("Failed to compile variable pattern regex")
    }

    VARIABLE_PATTERN.get_or_init(compile_variable_pattern)
}
/// Converts a [`ValueError`] raised while reading the property `key` into a [`ConfigError`].
///
/// Every resulting error that has room for it carries `key`. The one exception is
/// [`ValueError::IndexOutOfBounds`]: it maps onto [`ConfigError::IndexOutOfBounds`],
/// which holds only the index and length. Conversion and JSON (de)serialization
/// failures are all reported as conversion errors with a descriptive message.
pub(crate) fn map_value_error(key: &str, err: ValueError) -> ConfigError {
    match err {
        ValueError::NoValue => ConfigError::PropertyHasNoValue(key.to_owned()),
        ValueError::TypeMismatch { expected, actual } => {
            ConfigError::type_mismatch_at(key, expected, actual)
        }
        ValueError::IndexOutOfBounds { index, len } => {
            ConfigError::IndexOutOfBounds { index, len }
        }
        ValueError::ConversionError(msg) => ConfigError::conversion_error_at(key, msg),
        ValueError::ConversionFailed { from, to } => {
            let detail = format!("From {} to {}", from, to);
            ConfigError::conversion_error_at(key, detail)
        }
        ValueError::JsonSerializationError(msg) => {
            let detail = format!("JSON serialization error: {}", msg);
            ConfigError::conversion_error_at(key, detail)
        }
        ValueError::JsonDeserializationError(msg) => {
            let detail = format!("JSON deserialization error: {}", msg);
            ConfigError::conversion_error_at(key, detail)
        }
    }
}
pub(crate) fn substitute_variables<R: ConfigReader + ?Sized>(
value: &str,
config: &R,
max_depth: usize,
) -> ConfigResult<String> {
if value.is_empty() {
return Ok(value.to_string());
}
let pattern = get_variable_pattern();
let mut result = value.to_string();
let mut depth = 0;
loop {
if !pattern.is_match(&result) {
break;
}
if depth >= max_depth {
return Err(ConfigError::SubstitutionDepthExceeded(max_depth));
}
let mut first_error: Option<ConfigError> = None;
let replaced = pattern.replace_all(&result, |caps: ®ex::Captures| {
let var_name = caps.get(1).map(|m| m.as_str()).unwrap_or_default();
match find_variable_value(var_name, config) {
Ok(v) => v,
Err(err) => {
if first_error.is_none() {
first_error = Some(err);
}
caps.get(0)
.map(|m| m.as_str().to_string())
.unwrap_or_default()
}
}
});
if let Some(err) = first_error {
return Err(err);
}
result = replaced.into_owned();
depth += 1;
}
Ok(result)
}
/// Resolves a single placeholder name to its replacement text.
///
/// The configuration is consulted first, reading the property as a `String`. The
/// process environment is used only when the property is missing or has no value.
/// Any other configuration error (e.g. a type mismatch) is returned as-is, without
/// falling back to the environment.
///
/// # Errors
///
/// Returns [`ConfigError::SubstitutionError`] when the name cannot be read from the
/// environment either (unset or not valid Unicode).
fn find_variable_value<R: ConfigReader + ?Sized>(
    var_name: &str,
    config: &R,
) -> ConfigResult<String> {
    config.get::<String>(var_name).or_else(|lookup_err| match lookup_err {
        ConfigError::PropertyNotFound(_) | ConfigError::PropertyHasNoValue(_) => {
            std::env::var(var_name).map_err(|_| {
                ConfigError::SubstitutionError(format!("Cannot resolve variable: {}", var_name))
            })
        }
        other => Err(other),
    })
}
/// Inserts `value` into `root` under `key`, expanding dotted keys into nested objects.
///
/// A key without dots is stored as-is, replacing any existing entry. A dotted key
/// such as `"server.port"` is stored as `root["server"]["port"]`. If that is not
/// possible, the whole dotted key is stored literally instead, replacing any
/// existing entry with that exact name. Nesting fails when a segment is empty, an
/// intermediate value is not an object, or the leaf already exists.
pub(crate) fn insert_deserialize_value(root: &mut Map<String, Value>, key: &str, value: Value) {
    let is_dotted = !key.is_empty() && key.contains('.');
    if is_dotted {
        // Nested insertion consumes `value`, so keep a copy for the flat-key fallback.
        let backup = value.clone();
        if try_insert_nested_json_value(root, key, value).is_ok() {
            return;
        }
        root.insert(key.to_string(), backup);
    } else {
        root.insert(key.to_string(), value);
    }
}
/// Inserts `value` at the nested path described by the dotted `key`.
///
/// For example, `"a.b.c"` stores `value` at `root["a"]["b"]["c"]`. Missing
/// intermediate objects are created as needed.
///
/// Returns `Err(())` in three cases:
/// - any segment is empty (e.g. `"a..b"`, `".a"`, `"a."`, `""`);
/// - an intermediate segment already holds a non-object value;
/// - the leaf key already exists.
///
/// On failure `root` is left unmodified, so the caller can safely fall back to
/// another representation.
fn try_insert_nested_json_value(
    root: &mut Map<String, Value>,
    key: &str,
    value: Value,
) -> Result<(), ()> {
    // Validate every segment before touching `root`. Otherwise a key such as "a..b"
    // would leave an empty `"a": {}` object behind after failing.
    if key.split('.').any(str::is_empty) {
        return Err(());
    }
    let (parents, leaf) = match key.rsplit_once('.') {
        Some((parents, leaf)) => (Some(parents), leaf),
        None => (None, key),
    };
    let mut current = root;
    for part in parents.into_iter().flat_map(|path| path.split('.')) {
        // Once an object is created here, every deeper level is fresh and vacant. So the
        // only failures below happen before any creation, keeping `root` untouched on error.
        match current
            .entry(part.to_string())
            .or_insert_with(|| Value::Object(Map::new()))
        {
            Value::Object(child) => current = child,
            _ => return Err(()),
        }
    }
    match current.entry(leaf.to_string()) {
        Entry::Vacant(slot) => {
            slot.insert(value);
            Ok(())
        }
        Entry::Occupied(_) => Err(()),
    }
}
/// Converts a property's stored values into a JSON value suitable for deserialization.
///
/// [`MultiValues::Empty`] becomes `null`. For every typed variant, exactly one stored
/// value becomes a bare JSON value; any other count (including zero) becomes an array.
/// The per-element mapping is:
///
/// - booleans and integers up to 64 bits → JSON booleans/numbers;
/// - floats → JSON numbers, with non-finite values (NaN, ±inf) → `null`;
/// - string maps → JSON objects with string values;
/// - embedded JSON → cloned as-is;
/// - durations → strings of the form `"<n>ms"`, using `as_millis()`
///   (NOTE(review): for `std::time::Duration` this truncates sub-millisecond precision);
/// - strings, URLs, chars, 128-bit integers, big numbers and date/time values → JSON
///   strings via `to_string()`.
pub(crate) fn property_to_json_value(prop: &Property) -> Value {
    // Non-finite floats have no JSON representation and are emitted as `null`.
    fn float_to_json(x: f64) -> Value {
        Number::from_f64(x).map_or(Value::Null, Value::Number)
    }

    match prop.value() {
        MultiValues::Empty(_) => Value::Null,
        MultiValues::Bool(flags) => scalar_or_array(flags, |b| Value::Bool(*b)),
        MultiValues::Int8(nums) => scalar_or_array(nums, |n| Value::Number(Number::from(*n))),
        MultiValues::Int16(nums) => scalar_or_array(nums, |n| Value::Number(Number::from(*n))),
        MultiValues::Int32(nums) => scalar_or_array(nums, |n| Value::Number(Number::from(*n))),
        MultiValues::Int64(nums) => scalar_or_array(nums, |n| Value::Number(Number::from(*n))),
        MultiValues::IntSize(nums) => {
            scalar_or_array(nums, |n| Value::Number(Number::from(*n as i64)))
        }
        MultiValues::UInt8(nums) => scalar_or_array(nums, |n| Value::Number(Number::from(*n))),
        MultiValues::UInt16(nums) => scalar_or_array(nums, |n| Value::Number(Number::from(*n))),
        MultiValues::UInt32(nums) => scalar_or_array(nums, |n| Value::Number(Number::from(*n))),
        MultiValues::UInt64(nums) => scalar_or_array(nums, |n| Value::Number(Number::from(*n))),
        MultiValues::UIntSize(nums) => {
            scalar_or_array(nums, |n| Value::Number(Number::from(*n as u64)))
        }
        MultiValues::Float32(nums) => scalar_or_array(nums, |n| float_to_json(*n as f64)),
        MultiValues::Float64(nums) => scalar_or_array(nums, |n| float_to_json(*n)),
        MultiValues::String(texts) => scalar_or_array(texts, |s| Value::String(s.clone())),
        MultiValues::Duration(spans) => {
            scalar_or_array(spans, |d| Value::String(format!("{}ms", d.as_millis())))
        }
        MultiValues::Url(urls) => scalar_or_array(urls, |u| Value::String(u.to_string())),
        MultiValues::StringMap(maps) => scalar_or_array(maps, |entries| {
            Value::Object(
                entries
                    .iter()
                    .map(|(name, text)| (name.clone(), Value::String(text.clone())))
                    .collect(),
            )
        }),
        MultiValues::Json(docs) => scalar_or_array(docs, |doc| doc.clone()),
        MultiValues::Char(chars) => scalar_or_array(chars, |c| Value::String(c.to_string())),
        MultiValues::BigInteger(nums) => scalar_or_array(nums, |n| Value::String(n.to_string())),
        MultiValues::BigDecimal(nums) => scalar_or_array(nums, |n| Value::String(n.to_string())),
        MultiValues::DateTime(stamps) => {
            scalar_or_array(stamps, |t| Value::String(t.to_string()))
        }
        MultiValues::Date(dates) => scalar_or_array(dates, |d| Value::String(d.to_string())),
        MultiValues::Time(times) => scalar_or_array(times, |t| Value::String(t.to_string())),
        MultiValues::Instant(stamps) => {
            scalar_or_array(stamps, |t| Value::String(t.to_string()))
        }
        MultiValues::Int128(nums) => scalar_or_array(nums, |n| Value::String(n.to_string())),
        MultiValues::UInt128(nums) => scalar_or_array(nums, |n| Value::String(n.to_string())),
    }
}
/// Converts each element of `v` to JSON with `f`, unwrapping a one-element slice.
///
/// Exactly one element yields `f(&v[0])` directly. Zero or several elements yield a
/// JSON array of the converted elements in their original order, so an empty slice
/// becomes `[]`.
fn scalar_or_array<T, F>(v: &[T], f: F) -> Value
where
    F: Fn(&T) -> Value,
{
    match v {
        [single] => f(single),
        many => Value::Array(many.iter().map(f).collect()),
    }
}
#[cfg(test)]
mod substitute_variable_tests {
    use super::substitute_variables;
    use crate::{Config, ConfigError};

    /// Sets a process environment variable for the guard's lifetime and removes it on drop.
    ///
    /// The variable is therefore cleaned up even when an assertion in the test panics.
    struct EnvVarGuard {
        name: &'static str,
    }

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            // SAFETY: std requires that no other thread reads or writes the environment
            // concurrently through anything other than `std::env`'s own functions. The test
            // harness runs tests on several threads, so this relies on tests touching the
            // environment only via `std::env`. NOTE(review): confirm no test reaches libc/FFI
            // code that reads the environment. Each test uses its own variable name, so tests
            // do not observe each other's values.
            unsafe {
                std::env::set_var(name, value);
            }
            Self { name }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same requirements as documented in `EnvVarGuard::set`.
            unsafe {
                std::env::remove_var(self.name);
            }
        }
    }

    #[test]
    fn test_substitute_simple() {
        let mut config = Config::new();
        config.set("name", "world").unwrap();
        let result = substitute_variables("Hello, ${name}!", &config, 10).unwrap();
        assert_eq!(result, "Hello, world!");
    }

    #[test]
    fn test_substitute_multiple() {
        let mut config = Config::new();
        config.set("host", "localhost").unwrap();
        config.set("port", "8080").unwrap();
        let result = substitute_variables("http://${host}:${port}/api", &config, 10).unwrap();
        assert_eq!(result, "http://localhost:8080/api");
    }

    #[test]
    fn test_substitute_repeated_placeholder() {
        let mut config = Config::new();
        config.set("name", "world").unwrap();
        let result = substitute_variables("${name}-${name}-${name}", &config, 10).unwrap();
        assert_eq!(result, "world-world-world");
    }

    #[test]
    fn test_substitute_recursive() {
        let mut config = Config::new();
        config.set("a", "value_a").unwrap();
        config.set("b", "${a}_b").unwrap();
        config.set("c", "${b}_c").unwrap();
        let result = substitute_variables("${c}", &config, 10).unwrap();
        assert_eq!(result, "value_a_b_c");
    }

    #[test]
    fn test_substitute_depth_exceeded() {
        let mut config = Config::new();
        config.set("a", "${b}").unwrap();
        config.set("b", "${a}").unwrap();
        let result = substitute_variables("${a}", &config, 5);
        assert!(matches!(
            result,
            Err(ConfigError::SubstitutionDepthExceeded(5))
        ));
    }

    #[test]
    fn test_substitute_env_var() {
        let _env = EnvVarGuard::set("TEST_VAR", "test_value");
        let config = Config::new();
        let result = substitute_variables("Value: ${TEST_VAR}", &config, 10).unwrap();
        assert_eq!(result, "Value: test_value");
    }

    #[test]
    fn test_substitute_empty_string() {
        let config = Config::new();
        let result = substitute_variables("", &config, 10).unwrap();
        assert_eq!(result, "");
    }

    #[test]
    fn test_substitute_zero_depth_without_placeholders_should_succeed() {
        let config = Config::new();
        let result = substitute_variables("plain text", &config, 0).unwrap();
        assert_eq!(result, "plain text");
    }

    #[test]
    fn test_substitute_variable_not_found() {
        // Assumes NONEXISTENT_VAR is not set in the environment running the tests.
        let config = Config::new();
        let result = substitute_variables("${NONEXISTENT_VAR}", &config, 10);
        assert!(matches!(result, Err(ConfigError::SubstitutionError(_))));
        if let Err(ConfigError::SubstitutionError(msg)) = result {
            assert!(msg.contains("Cannot resolve variable: NONEXISTENT_VAR"));
        }
    }

    #[test]
    fn test_substitute_no_variables() {
        let config = Config::new();
        let result = substitute_variables("Plain text with no variables", &config, 10).unwrap();
        assert_eq!(result, "Plain text with no variables");
    }

    #[test]
    fn test_substitute_mixed_sources() {
        let _env = EnvVarGuard::set("ENV_VAR", "from_env");
        let mut config = Config::new();
        config.set("CONFIG_VAR", "from_config").unwrap();
        let result = substitute_variables("${CONFIG_VAR} and ${ENV_VAR}", &config, 10).unwrap();
        assert_eq!(result, "from_config and from_env");
    }

    #[test]
    fn test_substitute_config_priority_over_env() {
        let _env = EnvVarGuard::set("SHARED_VAR", "from_env");
        let mut config = Config::new();
        config.set("SHARED_VAR", "from_config").unwrap();
        let result = substitute_variables("${SHARED_VAR}", &config, 10).unwrap();
        assert_eq!(result, "from_config");
    }

    #[test]
    fn test_substitute_does_not_fallback_to_env_on_config_type_error() {
        let _env = EnvVarGuard::set("STRICT_VAR", "from_env");
        let mut config = Config::new();
        config.set("STRICT_VAR", 8080i32).unwrap();
        let result = substitute_variables("${STRICT_VAR}", &config, 10);
        assert!(matches!(
            result,
            Err(ConfigError::TypeMismatch { .. }) | Err(ConfigError::ConversionError { .. })
        ));
    }
}