#![allow(dead_code)]
use super::{ConnectionError, ConnectionResult};
use std::collections::HashMap;
/// A lookup source for environment-style variables.
///
/// Implementors must be `Send + Sync`.
pub trait EnvSource: Send + Sync {
    /// Returns the value of `name`, or `None` when it is not defined.
    fn get(&self, name: &str) -> Option<String>;

    /// Returns `true` when `name` is defined.
    ///
    /// The default implementation asks [`EnvSource::get`], so a variable whose value is
    /// the empty string still counts as defined.
    fn contains(&self, name: &str) -> bool {
        matches!(self.get(name), Some(_))
    }
}
/// [`EnvSource`] backed by the current process environment via `std::env::var`.
///
/// Any `std::env::var` failure (variable unset, or value not valid Unicode) is reported
/// as `None`.
#[derive(Debug, Clone, Copy, Default)]
pub struct StdEnvSource;

impl EnvSource for StdEnvSource {
    fn get(&self, name: &str) -> Option<String> {
        match std::env::var(name) {
            Ok(value) => Some(value),
            Err(_) => None,
        }
    }
}
/// In-memory variable source backed by a `HashMap` (used by the tests below as a
/// deterministic stand-in for the process environment).
#[derive(Debug, Clone, Default)]
pub struct MapEnvSource {
    /// Variable name -> value.
    vars: HashMap<String, String>,
}

impl MapEnvSource {
    /// Creates a source with no variables defined.
    pub fn new() -> Self {
        Self {
            vars: HashMap::new(),
        }
    }

    /// Builder-style setter: defines `name` as `value`, overwriting any previous value.
    pub fn set(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, val) = (name.into(), value.into());
        self.vars.insert(key, val);
        self
    }

    /// Merges `vars` into this source; entries from `vars` overwrite existing names.
    pub fn with_vars(mut self, vars: HashMap<String, String>) -> Self {
        for (key, val) in vars {
            self.vars.insert(key, val);
        }
        self
    }
}
impl EnvSource for MapEnvSource {
    /// Returns an owned copy of the value stored for `name`, if any.
    fn get(&self, name: &str) -> Option<String> {
        self.vars.get(name).cloned()
    }

    /// Checks presence directly in the map.
    ///
    /// Overrides the trait default, which would go through `get` and clone the stored
    /// value only to discard it. Like the default, an empty value still counts as present.
    fn contains(&self, name: &str) -> bool {
        self.vars.contains_key(name)
    }
}
/// Expands shell-style variable references, looking values up in an [`EnvSource`].
///
/// The source type defaults to [`StdEnvSource`] (the process environment).
#[derive(Debug, Clone)]
pub struct EnvExpander<S: EnvSource = StdEnvSource> {
    /// Where variable values are looked up.
    source: S,
}

impl EnvExpander<StdEnvSource> {
    /// Creates an expander that reads from the process environment.
    pub fn new() -> Self {
        EnvExpander {
            source: StdEnvSource,
        }
    }
}

impl Default for EnvExpander<StdEnvSource> {
    /// Same as [`EnvExpander::new`].
    fn default() -> Self {
        EnvExpander::new()
    }
}
impl<S: EnvSource> EnvExpander<S> {
pub fn with_source(source: S) -> Self {
Self { source }
}
pub fn expand(&self, input: &str) -> ConnectionResult<String> {
let mut result = String::with_capacity(input.len());
let mut chars = input.chars().peekable();
while let Some(c) = chars.next() {
if c == '$' {
if chars.peek() == Some(&'{') {
chars.next(); let expanded = self.expand_braced(&mut chars)?;
result.push_str(&expanded);
} else if chars.peek().is_some_and(|c| c.is_alphabetic() || *c == '_') {
let expanded = self.expand_simple(&mut chars)?;
result.push_str(&expanded);
} else {
result.push(c);
}
} else {
result.push(c);
}
}
Ok(result)
}
fn expand_braced(
&self,
chars: &mut std::iter::Peekable<std::str::Chars>,
) -> ConnectionResult<String> {
let mut name = String::new();
let mut modifier = None;
let mut modifier_value = String::new();
while let Some(&c) = chars.peek() {
if c == '}' {
chars.next();
break;
} else if c == ':' && modifier.is_none() {
chars.next();
if let Some(&next) = chars.peek() {
modifier = Some(next);
chars.next();
}
} else if modifier.is_some() {
modifier_value.push(c);
chars.next();
} else {
name.push(c);
chars.next();
}
}
if name.is_empty() {
return Err(ConnectionError::InvalidEnvValue {
name: "".to_string(),
message: "Empty variable name".to_string(),
});
}
match self.source.get(&name) {
Some(value) if !value.is_empty() => Ok(value),
_ => {
match modifier {
Some('-') => Ok(modifier_value),
Some('?') => Err(ConnectionError::InvalidEnvValue {
name: name.clone(),
message: if modifier_value.is_empty() {
format!("Required variable '{}' is not set", name)
} else {
modifier_value
},
}),
Some('+') => {
Ok(String::new())
}
_ => Err(ConnectionError::EnvNotFound(name)),
}
}
}
}
fn expand_simple(
&self,
chars: &mut std::iter::Peekable<std::str::Chars>,
) -> ConnectionResult<String> {
let mut name = String::new();
while let Some(&c) = chars.peek() {
if c.is_alphanumeric() || c == '_' {
name.push(c);
chars.next();
} else {
break;
}
}
self.source
.get(&name)
.ok_or(ConnectionError::EnvNotFound(name))
}
pub fn expand_url(&self, url: &str) -> ConnectionResult<String> {
self.expand(url)
}
pub fn has_variables(input: &str) -> bool {
input.contains('$')
}
}
pub fn expand_env(input: &str) -> ConnectionResult<String> {
EnvExpander::new().expand(input)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed variable set shared by the tests; `EMPTY` is defined with an empty value.
    fn test_source() -> MapEnvSource {
        const VARS: [(&str, &str); 5] = [
            ("HOST", "localhost"),
            ("PORT", "5432"),
            ("USER", "testuser"),
            ("PASS", "secret"),
            ("EMPTY", ""),
        ];
        VARS.iter()
            .fold(MapEnvSource::new(), |src, &(name, value)| src.set(name, value))
    }

    /// An expander over [`test_source`].
    fn expander() -> EnvExpander<MapEnvSource> {
        EnvExpander::with_source(test_source())
    }

    #[test]
    fn test_expand_simple() {
        let out = expander().expand("postgres://$HOST/db").unwrap();
        assert_eq!(out, "postgres://localhost/db");
    }

    #[test]
    fn test_expand_braced() {
        let out = expander().expand("postgres://${HOST}:${PORT}/db").unwrap();
        assert_eq!(out, "postgres://localhost:5432/db");
    }

    #[test]
    fn test_expand_default() {
        let e = expander();
        let cases = [
            ("${HOST:-default}", "localhost"),
            ("${MISSING:-default}", "default"),
            ("${EMPTY:-default}", "default"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(e.expand(input).unwrap(), *expected);
        }
    }

    #[test]
    fn test_expand_required() {
        let e = expander();
        assert_eq!(e.expand("${HOST:?Host is required}").unwrap(), "localhost");

        let outcome = e.expand("${MISSING:?Missing is required}");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Missing is required"));
    }

    #[test]
    fn test_expand_missing() {
        let outcome = expander().expand("${MISSING}");
        assert!(matches!(outcome, Err(ConnectionError::EnvNotFound(_))));
    }

    #[test]
    fn test_expand_full_url() {
        let template = "postgres://${USER}:${PASS}@${HOST}:${PORT}/mydb?sslmode=require";
        let out = expander().expand(template).unwrap();
        assert_eq!(
            out,
            "postgres://testuser:secret@localhost:5432/mydb?sslmode=require"
        );
    }

    #[test]
    fn test_has_variables() {
        type Std = EnvExpander<StdEnvSource>;
        assert!(Std::has_variables("${VAR}"));
        assert!(Std::has_variables("$VAR"));
        assert!(!Std::has_variables("no variables"));
    }

    #[test]
    fn test_literal_dollar() {
        assert_eq!(expander().expand("cost: $5").unwrap(), "cost: $5");
    }
}