prax_query/connection/
env.rs

1//! Environment variable expansion.
2
3use super::{ConnectionError, ConnectionResult};
4use std::collections::HashMap;
5
/// A lookup table of environment variables consulted by [`EnvExpander`].
///
/// Implementors must be `Send + Sync`, so a single source can be shared across threads.
pub trait EnvSource: Send + Sync {
    /// Return the value of the variable `name`, or `None` when it is not defined.
    fn get(&self, name: &str) -> Option<String>;

    /// Return `true` when the variable `name` is defined (an empty value still counts).
    ///
    /// The default implementation delegates to [`get`](Self::get) and discards the value.
    fn contains(&self, name: &str) -> bool {
        let looked_up = self.get(name);
        looked_up.is_some()
    }
}
16
17/// Default environment source using std::env.
18#[derive(Debug, Clone, Copy, Default)]
19pub struct StdEnvSource;
20
21impl EnvSource for StdEnvSource {
22    fn get(&self, name: &str) -> Option<String> {
23        std::env::var(name).ok()
24    }
25}
26
27/// Environment source backed by a HashMap.
28#[derive(Debug, Clone, Default)]
29pub struct MapEnvSource {
30    vars: HashMap<String, String>,
31}
32
33impl MapEnvSource {
34    /// Create a new map-based environment source.
35    pub fn new() -> Self {
36        Self::default()
37    }
38
39    /// Add a variable.
40    pub fn set(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
41        self.vars.insert(name.into(), value.into());
42        self
43    }
44
45    /// Add multiple variables.
46    pub fn with_vars(mut self, vars: HashMap<String, String>) -> Self {
47        self.vars.extend(vars);
48        self
49    }
50}
51
52impl EnvSource for MapEnvSource {
53    fn get(&self, name: &str) -> Option<String> {
54        self.vars.get(name).cloned()
55    }
56}
57
58/// Expands environment variables in strings.
59///
60/// Supported syntax:
61/// - `${VAR}` - Required variable
62/// - `${VAR:-default}` - Variable with default value
63/// - `${VAR:?error message}` - Required with custom error
64/// - `$VAR` - Simple variable reference
65#[derive(Debug, Clone)]
66pub struct EnvExpander<S: EnvSource = StdEnvSource> {
67    source: S,
68}
69
70impl EnvExpander<StdEnvSource> {
71    /// Create a new expander using the standard environment.
72    pub fn new() -> Self {
73        Self {
74            source: StdEnvSource,
75        }
76    }
77}
78
79impl Default for EnvExpander<StdEnvSource> {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl<S: EnvSource> EnvExpander<S> {
86    /// Create an expander with a custom environment source.
87    pub fn with_source(source: S) -> Self {
88        Self { source }
89    }
90
91    /// Expand environment variables in a string.
92    ///
93    /// # Examples
94    ///
95    /// ```rust
96    /// use prax_query::connection::EnvExpander;
97    ///
98    /// // SAFETY: This is for documentation purposes only
99    /// unsafe { std::env::set_var("PRAX_TEST_HOST", "localhost") };
100    /// let expander = EnvExpander::new();
101    /// let result = expander.expand("postgres://${PRAX_TEST_HOST}/db").unwrap();
102    /// assert_eq!(result, "postgres://localhost/db");
103    /// unsafe { std::env::remove_var("PRAX_TEST_HOST") };
104    /// ```
105    pub fn expand(&self, input: &str) -> ConnectionResult<String> {
106        let mut result = String::with_capacity(input.len());
107        let mut chars = input.chars().peekable();
108
109        while let Some(c) = chars.next() {
110            if c == '$' {
111                if chars.peek() == Some(&'{') {
112                    // ${VAR} syntax
113                    chars.next(); // consume '{'
114                    let expanded = self.expand_braced(&mut chars)?;
115                    result.push_str(&expanded);
116                } else if chars
117                    .peek()
118                    .map_or(false, |c| c.is_alphabetic() || *c == '_')
119                {
120                    // $VAR syntax
121                    let expanded = self.expand_simple(&mut chars)?;
122                    result.push_str(&expanded);
123                } else {
124                    // Literal $
125                    result.push(c);
126                }
127            } else {
128                result.push(c);
129            }
130        }
131
132        Ok(result)
133    }
134
135    fn expand_braced(
136        &self,
137        chars: &mut std::iter::Peekable<std::str::Chars>,
138    ) -> ConnectionResult<String> {
139        let mut name = String::new();
140        let mut modifier = None;
141        let mut modifier_value = String::new();
142
143        while let Some(&c) = chars.peek() {
144            if c == '}' {
145                chars.next();
146                break;
147            } else if c == ':' && modifier.is_none() {
148                chars.next();
149                // Check for modifier type
150                if let Some(&next) = chars.peek() {
151                    modifier = Some(next);
152                    chars.next();
153                }
154            } else if modifier.is_some() {
155                modifier_value.push(c);
156                chars.next();
157            } else {
158                name.push(c);
159                chars.next();
160            }
161        }
162
163        if name.is_empty() {
164            return Err(ConnectionError::InvalidEnvValue {
165                name: "".to_string(),
166                message: "Empty variable name".to_string(),
167            });
168        }
169
170        match self.source.get(&name) {
171            Some(value) if !value.is_empty() => Ok(value),
172            _ => {
173                match modifier {
174                    Some('-') => Ok(modifier_value),
175                    Some('?') => Err(ConnectionError::InvalidEnvValue {
176                        name: name.clone(),
177                        message: if modifier_value.is_empty() {
178                            format!("Required variable '{}' is not set", name)
179                        } else {
180                            modifier_value
181                        },
182                    }),
183                    Some('+') => {
184                        // ${VAR:+value} - value if VAR is set, empty otherwise
185                        Ok(String::new())
186                    }
187                    _ => Err(ConnectionError::EnvNotFound(name)),
188                }
189            }
190        }
191    }
192
193    fn expand_simple(
194        &self,
195        chars: &mut std::iter::Peekable<std::str::Chars>,
196    ) -> ConnectionResult<String> {
197        let mut name = String::new();
198
199        while let Some(&c) = chars.peek() {
200            if c.is_alphanumeric() || c == '_' {
201                name.push(c);
202                chars.next();
203            } else {
204                break;
205            }
206        }
207
208        self.source
209            .get(&name)
210            .ok_or_else(|| ConnectionError::EnvNotFound(name))
211    }
212
213    /// Expand a connection URL.
214    pub fn expand_url(&self, url: &str) -> ConnectionResult<String> {
215        self.expand(url)
216    }
217
218    /// Check if a string contains environment variable references.
219    pub fn has_variables(input: &str) -> bool {
220        input.contains('$')
221    }
222}
223
224/// Expand environment variables using the standard environment.
225pub fn expand_env(input: &str) -> ConnectionResult<String> {
226    EnvExpander::new().expand(input)
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Expander over a fixed in-memory environment shared by every test.
    fn fixture() -> EnvExpander<MapEnvSource> {
        let source = MapEnvSource::new()
            .set("HOST", "localhost")
            .set("PORT", "5432")
            .set("USER", "testuser")
            .set("PASS", "secret")
            .set("EMPTY", "");
        EnvExpander::with_source(source)
    }

    #[test]
    fn test_expand_simple() {
        let expanded = fixture().expand("postgres://$HOST/db").unwrap();
        assert_eq!(expanded, "postgres://localhost/db");
    }

    #[test]
    fn test_expand_braced() {
        let expanded = fixture().expand("postgres://${HOST}:${PORT}/db").unwrap();
        assert_eq!(expanded, "postgres://localhost:5432/db");
    }

    #[test]
    fn test_expand_default() {
        let expander = fixture();
        // (input, expected) for a set, an unset and a set-but-empty variable.
        let cases = [
            ("${HOST:-default}", "localhost"),
            ("${MISSING:-default}", "default"),
            ("${EMPTY:-default}", "default"),
        ];
        for (input, expected) in cases {
            assert_eq!(expander.expand(input).unwrap(), expected, "input: {input}");
        }
    }

    #[test]
    fn test_expand_required() {
        let expander = fixture();

        // A set variable satisfies the requirement.
        let found = expander.expand("${HOST:?Host is required}").unwrap();
        assert_eq!(found, "localhost");

        // An unset variable fails and the error carries the custom message.
        let err = expander
            .expand("${MISSING:?Missing is required}")
            .unwrap_err();
        assert!(err.to_string().contains("Missing is required"));
    }

    #[test]
    fn test_expand_missing() {
        let outcome = fixture().expand("${MISSING}");
        assert!(matches!(outcome, Err(ConnectionError::EnvNotFound(_))));
    }

    #[test]
    fn test_expand_full_url() {
        let template = "postgres://${USER}:${PASS}@${HOST}:${PORT}/mydb?sslmode=require";
        assert_eq!(
            fixture().expand(template).unwrap(),
            "postgres://testuser:secret@localhost:5432/mydb?sslmode=require"
        );
    }

    #[test]
    fn test_has_variables() {
        type Std = EnvExpander<StdEnvSource>;
        assert!(Std::has_variables("${VAR}"));
        assert!(Std::has_variables("$VAR"));
        assert!(!Std::has_variables("no variables"));
    }

    #[test]
    fn test_literal_dollar() {
        // `$` followed by a character that cannot start a name stays literal.
        assert_eq!(fixture().expand("cost: $5").unwrap(), "cost: $5");
    }
}