agent_chain_core/utils/
env.rs

//! Utilities for reading configuration values from environment variables.
//!
//! Adapted from `langchain_core/utils/env.py`.
4
5use std::collections::HashMap;
6use std::env;
7
/// Check whether an environment variable is set to a truthy value.
///
/// A variable counts as set when it is present in the environment and its value
/// is not one of the falsy spellings `""`, `"0"`, `"false"` or `"False"`. The
/// comparison is exact and case-sensitive, so `"FALSE"` counts as set. A value
/// that is not valid UTF-8 also counts as set, because it can never equal one of
/// the falsy spellings.
///
/// # Arguments
///
/// * `env_var` - The name of the environment variable.
///
/// # Returns
///
/// `true` if the environment variable is present and not falsy, `false` otherwise.
pub fn env_var_is_set(env_var: &str) -> bool {
    // `var_os` rather than `var`: `env::var` reports a present-but-non-UTF-8
    // value as an error, which would wrongly make such a variable look unset.
    match env::var_os(env_var) {
        Some(value) => !matches!(value.to_str(), Some("" | "0" | "false" | "False")),
        None => false,
    }
}
23
24/// Get a value from a dictionary or an environment variable.
25///
26/// # Arguments
27///
28/// * `data` - The dictionary to look up the key in.
29/// * `keys` - The keys to look up in the dictionary. This can be multiple keys to try in order.
30/// * `env_key` - The environment variable to look up if the key is not in the dictionary.
31/// * `default` - The default value to return if the key is not in the dictionary or the environment.
32///
33/// # Returns
34///
35/// The dict value or the environment variable value, or an error if not found.
36///
37/// # Example
38///
39/// ```
40/// use std::collections::HashMap;
41/// use agent_chain_core::utils::env::get_from_dict_or_env;
42///
43/// let mut data = HashMap::new();
44/// data.insert("api_key".to_string(), "my_key".to_string());
45///
46/// let result = get_from_dict_or_env(&data, &["api_key"], "API_KEY", None);
47/// assert_eq!(result.unwrap(), "my_key");
48/// ```
49pub fn get_from_dict_or_env(
50    data: &HashMap<String, String>,
51    keys: &[&str],
52    env_key: &str,
53    default: Option<&str>,
54) -> Result<String, EnvError> {
55    for key in keys {
56        if let Some(value) = data.get(*key)
57            && !value.is_empty()
58        {
59            return Ok(value.clone());
60        }
61    }
62
63    let key_for_err = keys.first().copied().unwrap_or(env_key);
64    get_from_env(key_for_err, env_key, default)
65}
66
67/// Get a value from an environment variable.
68///
69/// # Arguments
70///
71/// * `key` - The key name (used in error messages).
72/// * `env_key` - The environment variable to look up.
73/// * `default` - The default value to return if the environment variable is not set.
74///
75/// # Returns
76///
77/// The value of the environment variable, or an error if not found and no default provided.
78///
79/// # Errors
80///
81/// Returns `EnvError::NotFound` if the environment variable is not set and no default value is provided.
82///
83/// # Example
84///
85/// ```
86/// use agent_chain_core::utils::env::get_from_env;
87/// use std::env;
88///
89/// // SAFETY: This is a single-threaded doc test
90/// unsafe { env::set_var("MY_TEST_VAR", "test_value"); }
91/// let result = get_from_env("my_test", "MY_TEST_VAR", None);
92/// assert_eq!(result.unwrap(), "test_value");
93/// // SAFETY: This is a single-threaded doc test
94/// unsafe { env::remove_var("MY_TEST_VAR"); }
95/// ```
96pub fn get_from_env(key: &str, env_key: &str, default: Option<&str>) -> Result<String, EnvError> {
97    if let Ok(value) = env::var(env_key)
98        && !value.is_empty()
99    {
100        return Ok(value);
101    }
102
103    if let Some(default_val) = default {
104        return Ok(default_val.to_string());
105    }
106
107    Err(EnvError::NotFound {
108        key: key.to_string(),
109        env_key: env_key.to_string(),
110    })
111}
112
113/// Create a factory function that gets a value from an environment variable.
114///
115/// # Arguments
116///
117/// * `key` - The environment variable to look up. If multiple keys are provided,
118///   the first key found in the environment will be used.
119/// * `default` - The default value to return if the environment variable is not set.
120/// * `error_message` - The error message to raise if the key is not found and no default is provided.
121///
122/// # Returns
123///
124/// A closure that will look up the value from the environment.
125pub fn from_env<'a>(
126    keys: &'a [&'a str],
127    default: Option<&'a str>,
128    error_message: Option<&'a str>,
129) -> impl Fn() -> Result<String, EnvError> + 'a {
130    move || {
131        for key in keys {
132            if let Ok(value) = env::var(key)
133                && !value.is_empty()
134            {
135                return Ok(value);
136            }
137        }
138
139        if let Some(default_val) = default {
140            return Ok(default_val.to_string());
141        }
142
143        if let Some(msg) = error_message {
144            return Err(EnvError::Custom(msg.to_string()));
145        }
146
147        let keys_str = keys.join(", ");
148        Err(EnvError::NotFound {
149            key: keys_str.clone(),
150            env_key: keys_str,
151        })
152    }
153}
154
155/// Create a factory function that gets a secret value from an environment variable.
156///
157/// This is similar to `from_env` but is intended for sensitive values like API keys.
158///
159/// # Arguments
160///
161/// * `key` - The environment variable to look up.
162/// * `default` - The default value to return if the environment variable is not set.
163/// * `error_message` - The error message to raise if the key is not found and no default is provided.
164///
165/// # Returns
166///
167/// A closure that will look up the secret from the environment.
168pub fn secret_from_env<'a>(
169    keys: &'a [&'a str],
170    default: Option<&'a str>,
171    error_message: Option<&'a str>,
172) -> impl Fn() -> Result<SecretString, EnvError> + 'a {
173    let get_value = from_env(keys, default, error_message);
174    move || get_value().map(SecretString::new)
175}
176
/// A string whose `Debug` and `Display` output never reveal its contents.
///
/// Intended for sensitive values such as API keys. The wrapped value is only
/// reachable through [`SecretString::expose_secret`].
#[derive(Clone)]
pub struct SecretString {
    value: String,
}

impl SecretString {
    /// Wrap `value` as a secret.
    pub fn new(value: String) -> Self {
        SecretString { value }
    }

    /// Borrow the wrapped secret.
    ///
    /// Every call site is a potential leak point, so keep them to a minimum.
    pub fn expose_secret(&self) -> &str {
        self.value.as_str()
    }
}

impl std::fmt::Debug for SecretString {
    /// Always renders as `SecretString(***)`, whatever the contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SecretString(***)")
    }
}

impl std::fmt::Display for SecretString {
    /// Always renders as `***`, whatever the contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("***")
    }
}

impl From<String> for SecretString {
    fn from(value: String) -> Self {
        SecretString::new(value)
    }
}

impl From<&str> for SecretString {
    fn from(value: &str) -> Self {
        // Delegate to the owned conversion so there is a single construction path.
        SecretString::from(value.to_owned())
    }
}
222
/// Errors returned by the environment lookup helpers in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum EnvError {
    /// No value was found.
    ///
    /// `key` is the logical name shown to the user, and `env_key` is the
    /// environment variable that was consulted.
    NotFound { key: String, env_key: String },
    /// An error carrying a caller-supplied message verbatim.
    Custom(String),
}

impl std::fmt::Display for EnvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Custom(message) => f.write_str(message),
            Self::NotFound { key, env_key } => write!(
                f,
                "Did not find {}, please add an environment variable `{}` which contains it, or pass `{}` as a named parameter.",
                key, env_key, key
            ),
        }
    }
}

impl std::error::Error for EnvError {}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module that read or mutate the process environment.
    ///
    /// The test harness runs tests on multiple threads. `env::set_var` and
    /// `env::remove_var` are only sound while no other thread accesses the
    /// environment. This lock stops the tests in this module from racing one
    /// another. It does not protect against env-mutating tests elsewhere in the crate.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire `ENV_LOCK`, recovering from poisoning so that one failing test
    /// does not make every later environment test fail too.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Set an environment variable. The guard argument proves `ENV_LOCK` is held.
    fn set_env(_guard: &MutexGuard<'_, ()>, key: &str, value: &str) {
        // SAFETY: the caller holds ENV_LOCK, so no other test in this module
        // touches the environment concurrently (see ENV_LOCK for the limits).
        unsafe { env::set_var(key, value) }
    }

    /// Remove an environment variable. The guard argument proves `ENV_LOCK` is held.
    fn remove_env(_guard: &MutexGuard<'_, ()>, key: &str) {
        // SAFETY: the caller holds ENV_LOCK, as in `set_env`.
        unsafe { env::remove_var(key) }
    }

    #[test]
    fn test_env_var_is_set() {
        let guard = lock_env();

        set_env(&guard, "TEST_VAR_SET", "value");
        assert!(env_var_is_set("TEST_VAR_SET"));
        remove_env(&guard, "TEST_VAR_SET");

        // Exactly these spellings are treated as falsy.
        for falsy in ["", "0", "false", "False"] {
            set_env(&guard, "TEST_VAR_FALSY", falsy);
            assert!(!env_var_is_set("TEST_VAR_FALSY"), "{falsy:?} should be falsy");
        }
        remove_env(&guard, "TEST_VAR_FALSY");

        // The falsy check is case-sensitive.
        set_env(&guard, "TEST_VAR_UPPER_FALSE", "FALSE");
        assert!(env_var_is_set("TEST_VAR_UPPER_FALSE"));
        remove_env(&guard, "TEST_VAR_UPPER_FALSE");

        assert!(!env_var_is_set("NONEXISTENT_VAR_12345"));
    }

    #[test]
    fn test_get_from_dict_or_env() {
        let guard = lock_env();

        let mut data = HashMap::new();
        data.insert("key1".to_string(), "value1".to_string());
        data.insert("empty".to_string(), String::new());

        let result = get_from_dict_or_env(&data, &["key1"], "TEST_DICT_UNUSED_ENV_KEY", None);
        assert_eq!(result.unwrap(), "value1");

        // Keys are tried in order, and empty dictionary values are skipped.
        let result = get_from_dict_or_env(
            &data,
            &["missing", "empty", "key1"],
            "TEST_DICT_UNUSED_ENV_KEY",
            None,
        );
        assert_eq!(result.unwrap(), "value1");

        set_env(&guard, "TEST_ENV_KEY", "env_value");
        let result = get_from_dict_or_env(&data, &["key2"], "TEST_ENV_KEY", None);
        assert_eq!(result.unwrap(), "env_value");
        // An empty dictionary value falls through to the environment.
        let result = get_from_dict_or_env(&data, &["empty"], "TEST_ENV_KEY", None);
        assert_eq!(result.unwrap(), "env_value");
        remove_env(&guard, "TEST_ENV_KEY");

        let result =
            get_from_dict_or_env(&data, &["key3"], "NONEXISTENT_VAR_12345", Some("default"));
        assert_eq!(result.unwrap(), "default");

        // The error names the first dictionary key.
        let result = get_from_dict_or_env(&data, &["key3", "key4"], "NONEXISTENT_VAR_12345", None);
        assert_eq!(
            result,
            Err(EnvError::NotFound {
                key: "key3".to_string(),
                env_key: "NONEXISTENT_VAR_12345".to_string(),
            })
        );
    }

    #[test]
    fn test_get_from_env() {
        let guard = lock_env();

        set_env(&guard, "TEST_GET_FROM_ENV", "test_value");
        let result = get_from_env("test", "TEST_GET_FROM_ENV", None);
        assert_eq!(result.unwrap(), "test_value");
        // A set variable takes precedence over the default.
        let result = get_from_env("test", "TEST_GET_FROM_ENV", Some("default"));
        assert_eq!(result.unwrap(), "test_value");
        remove_env(&guard, "TEST_GET_FROM_ENV");

        // An empty value behaves like an unset variable.
        set_env(&guard, "TEST_GET_FROM_ENV_EMPTY", "");
        let result = get_from_env("test", "TEST_GET_FROM_ENV_EMPTY", Some("default"));
        assert_eq!(result.unwrap(), "default");
        remove_env(&guard, "TEST_GET_FROM_ENV_EMPTY");

        let result = get_from_env("test", "NONEXISTENT_VAR_12345", Some("default"));
        assert_eq!(result.unwrap(), "default");

        let result = get_from_env("test", "NONEXISTENT_VAR_12345", None);
        assert_eq!(
            result,
            Err(EnvError::NotFound {
                key: "test".to_string(),
                env_key: "NONEXISTENT_VAR_12345".to_string(),
            })
        );
    }

    #[test]
    fn test_from_env() {
        let guard = lock_env();

        set_env(&guard, "TEST_FROM_ENV", "test_value");
        let get_value = from_env(&["TEST_FROM_ENV"], None, None);
        assert_eq!(get_value().unwrap(), "test_value");
        remove_env(&guard, "TEST_FROM_ENV");

        let get_value = from_env(&["NONEXISTENT_VAR_12345"], Some("default"), None);
        assert_eq!(get_value().unwrap(), "default");

        // The first variable with a non-empty value wins.
        set_env(&guard, "TEST_FROM_ENV_EMPTY", "");
        set_env(&guard, "TEST_FROM_ENV_SECOND", "second");
        let get_value = from_env(
            &["NONEXISTENT_VAR_12345", "TEST_FROM_ENV_EMPTY", "TEST_FROM_ENV_SECOND"],
            None,
            None,
        );
        assert_eq!(get_value().unwrap(), "second");
        remove_env(&guard, "TEST_FROM_ENV_EMPTY");
        remove_env(&guard, "TEST_FROM_ENV_SECOND");

        let get_value = from_env(&["NONEXISTENT_VAR_12345"], None, Some("custom message"));
        assert_eq!(get_value(), Err(EnvError::Custom("custom message".to_string())));

        let get_value = from_env(&["NONEXISTENT_VAR_12345", "NONEXISTENT_VAR_67890"], None, None);
        let joined = "NONEXISTENT_VAR_12345, NONEXISTENT_VAR_67890".to_string();
        assert_eq!(
            get_value(),
            Err(EnvError::NotFound {
                key: joined.clone(),
                env_key: joined,
            })
        );
    }

    #[test]
    fn test_secret_from_env() {
        let guard = lock_env();

        set_env(&guard, "TEST_SECRET_FROM_ENV", "s3cr3t");
        let get_secret = secret_from_env(&["TEST_SECRET_FROM_ENV"], None, None);
        let secret = get_secret().unwrap();
        assert_eq!(secret.expose_secret(), "s3cr3t");
        assert_eq!(format!("{:?}", secret), "SecretString(***)");
        remove_env(&guard, "TEST_SECRET_FROM_ENV");

        let get_secret = secret_from_env(&["NONEXISTENT_VAR_12345"], Some("fallback"), None);
        assert_eq!(get_secret().unwrap().expose_secret(), "fallback");

        let get_secret = secret_from_env(&["NONEXISTENT_VAR_12345"], None, Some("no secret"));
        assert_eq!(get_secret().unwrap_err(), EnvError::Custom("no secret".to_string()));
    }

    #[test]
    fn test_secret_string() {
        let secret = SecretString::new("my_secret".to_string());
        assert_eq!(secret.expose_secret(), "my_secret");
        assert_eq!(format!("{}", secret), "***");
        assert_eq!(format!("{:?}", secret), "SecretString(***)");
        assert_eq!(secret.clone().expose_secret(), "my_secret");

        assert_eq!(SecretString::from("from_str").expose_secret(), "from_str");
        assert_eq!(
            SecretString::from("from_string".to_string()).expose_secret(),
            "from_string"
        );
    }

    #[test]
    fn test_env_error_display() {
        let err = EnvError::NotFound {
            key: "api_key".to_string(),
            env_key: "API_KEY".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Did not find api_key, please add an environment variable `API_KEY` which contains it, or pass `api_key` as a named parameter."
        );
        assert_eq!(
            EnvError::Custom("custom message".to_string()).to_string(),
            "custom message"
        );
    }
}