pub mod error;
use std::collections::HashMap;
pub use self::error::VarExpandError;
/// A set of environment variables that [`expand`] resolves references against.
///
/// Name matching is exact on non-Windows targets. On Windows, lookups fall back to an
/// ASCII-case-insensitive match, and names that differ only in ASCII case share a single
/// entry.
#[derive(Debug, Default, Clone)]
pub struct VarEnv {
    /// Variable name (spelled as most recently inserted) -> value.
    inner: HashMap<String, String>,
    /// Windows only: ASCII-lowercased name -> the exact key currently stored in `inner`,
    /// used for case-insensitive lookup and replacement.
    #[cfg(windows)]
    lookup_index: HashMap<String, String>,
}

impl VarEnv {
    /// Creates an empty environment.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: HashMap::new(),
            #[cfg(windows)]
            lookup_index: HashMap::new(),
        }
    }

    /// Captures the current process environment.
    ///
    /// Variables whose name or value is not valid Unicode cannot be stored as `String`s.
    /// They are skipped rather than causing a panic, which is what iterating
    /// `std::env::vars()` would do. A reference to such a variable in [`expand`] is
    /// therefore reported as missing. The result is built with [`VarEnv::from_map`], so the
    /// Windows `HOME` fallback applies.
    #[must_use]
    pub fn from_os() -> Self {
        let map: HashMap<String, String> = std::env::vars_os()
            .filter_map(|(name, value)| {
                Some((name.into_string().ok()?, value.into_string().ok()?))
            })
            .collect();
        Self::from_map(map)
    }

    /// Builds an environment from a name -> value map.
    ///
    /// On Windows, if no `HOME` entry exists (matched case-insensitively), one is added with
    /// the value of `USERPROFILE`, when that is present.
    ///
    /// The map is consumed in arbitrary `HashMap` order. On Windows, if it holds several
    /// names that differ only in ASCII case, which one survives is unspecified.
    #[must_use]
    pub fn from_map(map: HashMap<String, String>) -> Self {
        let mut env = Self::new();
        for (k, v) in map {
            env.insert(k, v);
        }
        #[cfg(windows)]
        {
            if env.get("HOME").is_none() {
                if let Some(userprofile) = env.get("USERPROFILE").map(str::to_owned) {
                    env.insert("HOME", userprofile);
                }
            }
        }
        env
    }

    /// Inserts `value` under `name`, replacing any previous value.
    ///
    /// On Windows, an existing entry whose name differs only in ASCII case is removed
    /// first, so the new spelling of `name` becomes the stored key.
    pub fn insert(&mut self, name: impl Into<String>, value: impl Into<String>) {
        let name = name.into();
        let value = value.into();
        #[cfg(windows)]
        {
            // `HashMap::insert` returns the previously indexed spelling. One lookup therefore
            // updates the index and also identifies the stale key to drop from `inner`.
            if let Some(prior) = self
                .lookup_index
                .insert(name.to_ascii_lowercase(), name.clone())
            {
                if prior != name {
                    self.inner.remove(&prior);
                }
            }
        }
        self.inner.insert(name, value);
    }

    /// Looks up a variable by name.
    ///
    /// An exact match is tried first. On Windows, lookup then falls back to an
    /// ASCII-case-insensitive match.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&str> {
        if let Some(v) = self.inner.get(name) {
            return Some(v.as_str());
        }
        #[cfg(windows)]
        {
            let lower = name.to_ascii_lowercase();
            if let Some(original) = self.lookup_index.get(&lower) {
                return self.inner.get(original).map(String::as_str);
            }
        }
        None
    }
}
/// Expands environment-variable references in `input` using `env`.
///
/// Supported syntax:
/// - `$NAME`: bare reference. The name is the longest run of `[A-Za-z0-9_]` that starts
///   with an ASCII letter or `_`.
/// - `${NAME}`: braced reference.
/// - `%NAME%`: Windows-style reference.
/// - `$$` and `%%`: a literal `$` and `%` respectively. A lone `%` must be written `%%`.
///
/// Substituted values are inserted verbatim and are not expanded again. All other text,
/// including non-ASCII characters, is copied through unchanged.
///
/// # Errors
/// Returns [`VarExpandError`] in these cases:
/// - a referenced variable is missing;
/// - a name is invalid;
/// - a `${` or `%` is never closed;
/// - `${}` is empty.
///
/// Error offsets are the byte index of the `$` or `%` that introduced the failing
/// reference.
pub fn expand(input: &str, env: &VarEnv) -> Result<String, VarExpandError> {
    let bytes = input.as_bytes();
    let mut out = String::with_capacity(input.len());
    let mut i = 0usize;
    while i < bytes.len() {
        match bytes[i] {
            b'$' => i = scan_dollar(bytes, i, env, &mut out)?,
            b'%' => i = scan_percent(bytes, i, env, &mut out)?,
            _ => {
                // Copy the whole literal run up to the next sigil as a `&str` slice. Pushing
                // single bytes as `char`s would turn each byte of a multi-byte UTF-8 sequence
                // into its own Latin-1 code point ("café" -> "cafÃ©").
                //
                // The slice is always on char boundaries: `$` and `%` are ASCII, so they
                // never occur inside a multi-byte sequence. Every index the scanners return
                // sits just past an ASCII byte.
                let run_end = bytes[i..]
                    .iter()
                    .position(|&b| b == b'$' || b == b'%')
                    .map_or(bytes.len(), |rel| i + rel);
                out.push_str(&input[i..run_end]);
                i = run_end;
            }
        }
    }
    Ok(out)
}
/// Handles a `$` at `bytes[start]`, appending any expansion to `out`.
///
/// Recognises `$$` (a literal `$`), `${NAME}` (delegated to [`scan_braced`]) and bare
/// `$NAME`. Returns the index just past the consumed text.
///
/// # Errors
/// Returns `InvalidVariableName` (offset `start`) when `$` is followed by something that
/// cannot start a name, including end of input. Its `got` field holds the run of name
/// characters that follows. That run is empty unless it starts with a digit; for example,
/// `$0FOO` reports `"0FOO"`.
///
/// Errors from [`scan_braced`] and [`resolve`] are propagated.
fn scan_dollar(
    bytes: &[u8],
    start: usize,
    env: &VarEnv,
    out: &mut String,
) -> Result<usize, VarExpandError> {
    debug_assert_eq!(bytes[start], b'$');
    match bytes.get(start + 1).copied() {
        Some(b'$') => {
            out.push('$');
            Ok(start + 2)
        }
        Some(b'{') => scan_braced(bytes, start, env, out),
        Some(b) if is_name_start(b) => {
            let name_start = start + 1;
            let (end, _) = scan_trailing_name(bytes, name_start);
            resolve(&bytes[name_start..end], start, env, out)?;
            Ok(end)
        }
        _ => {
            // Capture any following name characters (the digits in `$0FOO`) so the error
            // shows what was written. A non-name byte or end of input yields "".
            let (name_end, _) = scan_trailing_name(bytes, start + 1);
            Err(VarExpandError::InvalidVariableName {
                got: String::from_utf8_lossy(&bytes[start + 1..name_end]).into_owned(),
                offset: start,
            })
        }
    }
}
/// Measures a run of name-continuation bytes (`[A-Za-z0-9_]`) beginning at `from`.
///
/// Returns `(end, stopped_on_byte)`:
/// - `end` is the index one past the run.
/// - `stopped_on_byte` is `true` when the run ended at a non-name byte rather than at the
///   end of `bytes`.
fn scan_trailing_name(bytes: &[u8], from: usize) -> (usize, bool) {
    let rest: &[u8] = bytes.get(from..).unwrap_or(&[]);
    match rest.iter().position(|&b| !is_name_cont(b)) {
        Some(run_len) => (from + run_len, true),
        None => (from + rest.len(), false),
    }
}
/// Handles a `${NAME}` reference whose `$` is at `bytes[start]`.
///
/// Everything between the `{` and the first following `}` is treated as the name and
/// passed to [`resolve`]. Returns the index just past the closing `}`.
///
/// # Errors
/// - `UnclosedBraceExpansion` if no `}` follows.
/// - `EmptyBraceExpansion` for `${}`.
/// - Whatever [`resolve`] reports for the name.
fn scan_braced(
    bytes: &[u8],
    start: usize,
    env: &VarEnv,
    out: &mut String,
) -> Result<usize, VarExpandError> {
    debug_assert!(bytes[start] == b'$' && bytes[start + 1] == b'{');
    let body_start = start + 2;
    let close = match bytes
        .get(body_start..)
        .and_then(|rest| rest.iter().position(|&c| c == b'}'))
    {
        Some(rel) => body_start + rel,
        None => return Err(VarExpandError::UnclosedBraceExpansion { offset: start }),
    };
    if close == body_start {
        return Err(VarExpandError::EmptyBraceExpansion { offset: start });
    }
    resolve(&bytes[body_start..close], start, env, out)?;
    Ok(close + 1)
}
/// Handles a `%` at `bytes[start]`, appending any expansion to `out`.
///
/// `%%` produces a literal `%`. Otherwise the text up to the next `%` is taken as a
/// variable name and passed to [`resolve`]. Returns the index just past the closing `%`.
///
/// # Errors
/// - `UnclosedPercentExpansion` (offset `start`) if no closing `%` follows. A lone `%`,
///   as in `"50% off"`, must therefore be written `%%`.
/// - Errors from [`resolve`] (invalid or missing name) are propagated.
fn scan_percent(
    bytes: &[u8],
    start: usize,
    env: &VarEnv,
    out: &mut String,
) -> Result<usize, VarExpandError> {
    debug_assert_eq!(bytes[start], b'%');
    let name_start = start + 1;
    let close = match bytes[name_start..].iter().position(|&b| b == b'%') {
        // The closing `%` immediately follows: this is the `%%` escape. It is the only
        // way the name can be empty, so no separate empty-name check is needed.
        Some(0) => {
            out.push('%');
            return Ok(start + 2);
        }
        Some(rel) => name_start + rel,
        None => return Err(VarExpandError::UnclosedPercentExpansion { offset: start }),
    };
    resolve(&bytes[name_start..close], start, env, out)?;
    Ok(close + 1)
}
/// Validates `name` and appends its value from `env` to `out`.
///
/// `offset` is reported in any error; every caller passes the byte index of the
/// introducing `$` or `%`.
///
/// # Errors
/// - `InvalidVariableName` if `name` does not match `[A-Za-z_][A-Za-z0-9_]*`.
/// - `MissingVariable` if `env` has no variable with that name.
fn resolve(
    name: &[u8],
    offset: usize,
    env: &VarEnv,
    out: &mut String,
) -> Result<(), VarExpandError> {
    if is_valid_name(name) {
        // A valid name is pure ASCII, hence always valid UTF-8.
        let key = std::str::from_utf8(name).expect("validated ASCII");
        let value = env
            .get(key)
            .ok_or_else(|| VarExpandError::MissingVariable { name: key.to_owned(), offset })?;
        out.push_str(value);
        Ok(())
    } else {
        Err(VarExpandError::InvalidVariableName {
            got: String::from_utf8_lossy(name).into_owned(),
            offset,
        })
    }
}
/// Reports whether `b` may begin a variable name: an ASCII letter or `_`.
#[inline]
fn is_name_start(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'_')
}
/// Reports whether `b` may appear after the first byte of a variable name: an ASCII
/// letter, an ASCII digit, or `_`.
#[inline]
fn is_name_cont(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'_')
}
/// Reports whether `name` is a well-formed variable name. It must be non-empty, start
/// with an ASCII letter or `_`, and contain only ASCII letters, digits, or `_` after that.
fn is_valid_name(name: &[u8]) -> bool {
    match name.split_first() {
        Some((&head, tail)) => is_name_start(head) && tail.iter().all(|&c| is_name_cont(c)),
        None => false,
    }
}
#[cfg(test)]
mod tests {
// Unit tests for `VarEnv` and `expand`. Error offsets asserted below are the byte index
// of the `$` or `%` that introduces the failing reference.
use super::*;
// Builds a `VarEnv` from literal (name, value) pairs via `VarEnv::insert`.
fn env(pairs: &[(&str, &str)]) -> VarEnv {
let mut e = VarEnv::new();
for (k, v) in pairs {
e.insert(*k, *v);
}
e
}
// --- Successful expansion ---
#[test]
fn expand_noop_no_vars() {
let e = VarEnv::new();
assert_eq!(expand("plain text / no sigils", &e).unwrap(), "plain text / no sigils");
assert_eq!(expand("", &e).unwrap(), "");
}
#[test]
fn expand_posix_bare() {
let e = env(&[("HOME", "/h")]);
assert_eq!(expand("$HOME/foo", &e).unwrap(), "/h/foo");
}
#[test]
fn expand_posix_braced() {
let e = env(&[("USER", "yueyang")]);
assert_eq!(expand("${USER}-log", &e).unwrap(), "yueyang-log");
}
#[test]
fn expand_windows_percent() {
let e = env(&[("USERPROFILE", "C:\\Users\\y")]);
assert_eq!(expand("%USERPROFILE%\\x", &e).unwrap(), "C:\\Users\\y\\x");
}
// `$$` is a literal `$`, so the following `HOME` is plain text, not a reference.
#[test]
fn expand_escape_dollar() {
let e = VarEnv::new();
assert_eq!(expand("$$HOME", &e).unwrap(), "$HOME");
}
// Each `%%` is a literal `%`; the `PATH` between them is plain text.
#[test]
fn expand_escape_percent() {
let e = VarEnv::new();
assert_eq!(expand("%%PATH%%", &e).unwrap(), "%PATH%");
}
// --- Error reporting ---
#[test]
fn expand_missing_var_errors() {
let e = VarEnv::new();
assert_eq!(
expand("$UNDEFINED", &e).unwrap_err(),
VarExpandError::MissingVariable { name: "UNDEFINED".into(), offset: 0 }
);
}
#[test]
fn expand_unclosed_brace() {
let e = VarEnv::new();
assert_eq!(
expand("${FOO", &e).unwrap_err(),
VarExpandError::UnclosedBraceExpansion { offset: 0 }
);
}
#[test]
fn expand_unclosed_percent() {
let e = VarEnv::new();
assert_eq!(
expand("%FOO", &e).unwrap_err(),
VarExpandError::UnclosedPercentExpansion { offset: 0 }
);
}
#[test]
fn expand_empty_brace() {
let e = VarEnv::new();
assert_eq!(
expand("${}", &e).unwrap_err(),
VarExpandError::EmptyBraceExpansion { offset: 0 }
);
}
// A digit cannot start a bare name; the error reports the whole digit-led run.
#[test]
fn expand_invalid_name_digit_led() {
let e = VarEnv::new();
let err = expand("$0FOO", &e).unwrap_err();
match err {
VarExpandError::InvalidVariableName { got, offset } => {
assert_eq!(got, "0FOO");
assert_eq!(offset, 0);
}
other => panic!("expected InvalidVariableName, got {other:?}"),
}
}
// Braced names are taken verbatim up to `}` and then validated, so `-` is rejected.
#[test]
fn expand_invalid_name_hyphen() {
let e = VarEnv::new();
let err = expand("${BAD-NAME}", &e).unwrap_err();
match err {
VarExpandError::InvalidVariableName { got, offset } => {
assert_eq!(got, "BAD-NAME");
assert_eq!(offset, 0);
}
other => panic!("expected InvalidVariableName, got {other:?}"),
}
}
// Substituted values are inserted verbatim; `$B` inside A's value is not expanded.
#[test]
fn expand_no_recursive() {
let e = env(&[("A", "$B"), ("B", "boom")]);
assert_eq!(expand("$A", &e).unwrap(), "$B");
}
// A bare name stops at the first non-name byte (`/`), so adjacent text is preserved.
#[test]
fn expand_boundary_adjacent() {
let e = env(&[("HOME", "/h"), ("USER", "y")]);
assert_eq!(expand("$HOME/path_$USER", &e).unwrap(), "/h/path_y");
}
// A trailing `$` with nothing after it is an invalid (empty) name at the `$` offset.
#[test]
fn expand_dollar_at_end() {
let e = VarEnv::new();
let err = expand("trailing$", &e).unwrap_err();
match err {
VarExpandError::InvalidVariableName { got, offset } => {
assert_eq!(got, "");
assert_eq!(offset, 8);
}
other => panic!("expected InvalidVariableName, got {other:?}"),
}
}
// A lone `%` is not literal text: without a closing `%` it is an unclosed expansion.
#[test]
fn expand_percent_isolated_mid() {
let e = VarEnv::new();
assert_eq!(
expand("50% off", &e).unwrap_err(),
VarExpandError::UnclosedPercentExpansion { offset: 2 }
);
}
#[test]
fn expand_offset_is_sigil_position() {
let e = VarEnv::new();
let err = expand("prefix-${MISSING}", &e).unwrap_err();
match err {
VarExpandError::MissingVariable { name, offset } => {
assert_eq!(name, "MISSING");
assert_eq!(offset, 7);
}
other => panic!("expected MissingVariable, got {other:?}"),
}
}
// --- VarEnv ---
// NOTE(review): depends on the test process having PATH (or Path) set; may fail in a
// sandbox with an empty environment.
#[test]
fn var_env_from_os() {
let e = VarEnv::from_os();
assert!(e.get("PATH").is_some() || e.get("Path").is_some());
}
#[test]
fn var_env_get_and_insert() {
let mut e = VarEnv::new();
assert_eq!(e.get("X"), None);
e.insert("X", "1");
assert_eq!(e.get("X"), Some("1"));
e.insert("X", "2");
assert_eq!(e.get("X"), Some("2"));
}
#[cfg(windows)]
#[test]
fn var_env_windows_case_insensitive_get() {
let mut e = VarEnv::new();
e.insert("PATH", "c:/bin");
assert_eq!(e.get("PATH"), Some("c:/bin"));
assert_eq!(e.get("Path"), Some("c:/bin"));
assert_eq!(e.get("path"), Some("c:/bin"));
}
// `from_map` fills in HOME from USERPROFILE when HOME is absent (Windows only).
#[cfg(windows)]
#[test]
fn var_env_windows_home_fallback_from_userprofile() {
let mut seed = HashMap::new();
seed.insert("USERPROFILE".to_string(), r"C:\Users\y".to_string());
let env = VarEnv::from_map(seed);
assert_eq!(env.get("HOME"), Some(r"C:\Users\y"));
assert_eq!(env.get("home"), Some(r"C:\Users\y"));
}
// The HOME fallback lives only in `from_map`; plain `insert` never synthesizes HOME.
#[cfg(windows)]
#[test]
fn var_env_windows_home_fallback_not_applied_by_insert() {
let mut e = VarEnv::new();
e.insert("USERPROFILE", r"C:\Users\y");
assert_eq!(e.get("HOME"), None);
}
#[cfg(unix)]
#[test]
fn var_env_unix_case_sensitive_still() {
let mut e = VarEnv::new();
e.insert("PATH", "/usr/bin");
assert_eq!(e.get("PATH"), Some("/usr/bin"));
assert_eq!(e.get("Path"), None);
assert_eq!(e.get("path"), None);
}
}