prax_query/connection/
env.rs

1//! Environment variable expansion.
2
3use super::{ConnectionError, ConnectionResult};
4use std::collections::HashMap;
5
/// Abstraction over where environment variables are looked up.
///
/// Implementors must be `Send + Sync`, so a source can be shared between threads.
pub trait EnvSource: Send + Sync {
    /// Look up `name`, returning its value or `None` when it is absent.
    fn get(&self, name: &str) -> Option<String>;

    /// Report whether `name` is present in this source.
    ///
    /// The default implementation delegates to [`EnvSource::get`].
    fn contains(&self, name: &str) -> bool {
        matches!(self.get(name), Some(_))
    }
}
16
17/// Default environment source using std::env.
18#[derive(Debug, Clone, Copy, Default)]
19pub struct StdEnvSource;
20
21impl EnvSource for StdEnvSource {
22    fn get(&self, name: &str) -> Option<String> {
23        std::env::var(name).ok()
24    }
25}
26
/// [`EnvSource`] backed by an in-memory `HashMap`.
///
/// Useful when variables should come from somewhere other than the process
/// environment, e.g. in tests.
#[derive(Debug, Clone, Default)]
pub struct MapEnvSource {
    vars: HashMap<String, String>,
}

impl MapEnvSource {
    /// Create an empty map-based environment source.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a variable, replacing any previous value stored under `name`.
    ///
    /// Consumes and returns `self` so calls can be chained. The result must be
    /// kept; otherwise the source, including the new variable, is dropped.
    #[must_use]
    pub fn set(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.vars.insert(name.into(), value.into());
        self
    }

    /// Add every entry of `vars`, replacing existing values with the same name.
    #[must_use]
    pub fn with_vars(mut self, vars: HashMap<String, String>) -> Self {
        self.vars.extend(vars);
        self
    }
}
51
52impl EnvSource for MapEnvSource {
53    fn get(&self, name: &str) -> Option<String> {
54        self.vars.get(name).cloned()
55    }
56}
57
58/// Expands environment variables in strings.
59///
60/// Supported syntax:
61/// - `${VAR}` - Required variable
62/// - `${VAR:-default}` - Variable with default value
63/// - `${VAR:?error message}` - Required with custom error
64/// - `$VAR` - Simple variable reference
65#[derive(Debug, Clone)]
66pub struct EnvExpander<S: EnvSource = StdEnvSource> {
67    source: S,
68}
69
70impl EnvExpander<StdEnvSource> {
71    /// Create a new expander using the standard environment.
72    pub fn new() -> Self {
73        Self {
74            source: StdEnvSource,
75        }
76    }
77}
78
79impl Default for EnvExpander<StdEnvSource> {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl<S: EnvSource> EnvExpander<S> {
86    /// Create an expander with a custom environment source.
87    pub fn with_source(source: S) -> Self {
88        Self { source }
89    }
90
91    /// Expand environment variables in a string.
92    ///
93    /// # Examples
94    ///
95    /// ```rust
96    /// use prax_query::connection::EnvExpander;
97    ///
98    /// // SAFETY: This is for documentation purposes only
99    /// unsafe { std::env::set_var("PRAX_TEST_HOST", "localhost") };
100    /// let expander = EnvExpander::new();
101    /// let result = expander.expand("postgres://${PRAX_TEST_HOST}/db").unwrap();
102    /// assert_eq!(result, "postgres://localhost/db");
103    /// unsafe { std::env::remove_var("PRAX_TEST_HOST") };
104    /// ```
105    pub fn expand(&self, input: &str) -> ConnectionResult<String> {
106        let mut result = String::with_capacity(input.len());
107        let mut chars = input.chars().peekable();
108
109        while let Some(c) = chars.next() {
110            if c == '$' {
111                if chars.peek() == Some(&'{') {
112                    // ${VAR} syntax
113                    chars.next(); // consume '{'
114                    let expanded = self.expand_braced(&mut chars)?;
115                    result.push_str(&expanded);
116                } else if chars.peek().map_or(false, |c| c.is_alphabetic() || *c == '_') {
117                    // $VAR syntax
118                    let expanded = self.expand_simple(&mut chars)?;
119                    result.push_str(&expanded);
120                } else {
121                    // Literal $
122                    result.push(c);
123                }
124            } else {
125                result.push(c);
126            }
127        }
128
129        Ok(result)
130    }
131
132    fn expand_braced(&self, chars: &mut std::iter::Peekable<std::str::Chars>) -> ConnectionResult<String> {
133        let mut name = String::new();
134        let mut modifier = None;
135        let mut modifier_value = String::new();
136
137        while let Some(&c) = chars.peek() {
138            if c == '}' {
139                chars.next();
140                break;
141            } else if c == ':' && modifier.is_none() {
142                chars.next();
143                // Check for modifier type
144                if let Some(&next) = chars.peek() {
145                    modifier = Some(next);
146                    chars.next();
147                }
148            } else if modifier.is_some() {
149                modifier_value.push(c);
150                chars.next();
151            } else {
152                name.push(c);
153                chars.next();
154            }
155        }
156
157        if name.is_empty() {
158            return Err(ConnectionError::InvalidEnvValue {
159                name: "".to_string(),
160                message: "Empty variable name".to_string(),
161            });
162        }
163
164        match self.source.get(&name) {
165            Some(value) if !value.is_empty() => Ok(value),
166            _ => {
167                match modifier {
168                    Some('-') => Ok(modifier_value),
169                    Some('?') => Err(ConnectionError::InvalidEnvValue {
170                        name: name.clone(),
171                        message: if modifier_value.is_empty() {
172                            format!("Required variable '{}' is not set", name)
173                        } else {
174                            modifier_value
175                        },
176                    }),
177                    Some('+') => {
178                        // ${VAR:+value} - value if VAR is set, empty otherwise
179                        Ok(String::new())
180                    }
181                    _ => Err(ConnectionError::EnvNotFound(name)),
182                }
183            }
184        }
185    }
186
187    fn expand_simple(&self, chars: &mut std::iter::Peekable<std::str::Chars>) -> ConnectionResult<String> {
188        let mut name = String::new();
189
190        while let Some(&c) = chars.peek() {
191            if c.is_alphanumeric() || c == '_' {
192                name.push(c);
193                chars.next();
194            } else {
195                break;
196            }
197        }
198
199        self.source.get(&name).ok_or_else(|| ConnectionError::EnvNotFound(name))
200    }
201
202    /// Expand a connection URL.
203    pub fn expand_url(&self, url: &str) -> ConnectionResult<String> {
204        self.expand(url)
205    }
206
207    /// Check if a string contains environment variable references.
208    pub fn has_variables(input: &str) -> bool {
209        input.contains('$')
210    }
211}
212
213/// Expand environment variables using the standard environment.
214pub fn expand_env(input: &str) -> ConnectionResult<String> {
215    EnvExpander::new().expand(input)
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed variable set shared by the tests; `EMPTY` is set but empty.
    fn test_source() -> MapEnvSource {
        MapEnvSource::new()
            .set("HOST", "localhost")
            .set("PORT", "5432")
            .set("USER", "testuser")
            .set("PASS", "secret")
            .set("EMPTY", "")
    }

    #[test]
    fn test_expand_simple() {
        let expander = EnvExpander::with_source(test_source());

        assert_eq!(
            expander.expand("postgres://$HOST/db").unwrap(),
            "postgres://localhost/db"
        );
    }

    #[test]
    fn test_expand_braced() {
        let expander = EnvExpander::with_source(test_source());

        assert_eq!(
            expander.expand("postgres://${HOST}:${PORT}/db").unwrap(),
            "postgres://localhost:5432/db"
        );
    }

    #[test]
    fn test_expand_default() {
        let expander = EnvExpander::with_source(test_source());

        // Variable exists
        assert_eq!(expander.expand("${HOST:-default}").unwrap(), "localhost");

        // Variable doesn't exist
        assert_eq!(expander.expand("${MISSING:-default}").unwrap(), "default");

        // Empty variable
        assert_eq!(expander.expand("${EMPTY:-default}").unwrap(), "default");
    }

    #[test]
    fn test_expand_required() {
        let expander = EnvExpander::with_source(test_source());

        // Variable exists
        assert_eq!(
            expander.expand("${HOST:?Host is required}").unwrap(),
            "localhost"
        );

        // Variable doesn't exist
        let result = expander.expand("${MISSING:?Missing is required}");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Missing is required"));
    }

    #[test]
    fn test_expand_required_default_message() {
        let expander = EnvExpander::with_source(test_source());

        // With no custom message, a generic one naming the variable is used.
        match expander.expand("${MISSING:?}") {
            Err(ConnectionError::InvalidEnvValue { name, message }) => {
                assert_eq!(name, "MISSING");
                assert_eq!(message, "Required variable 'MISSING' is not set");
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[test]
    fn test_expand_missing() {
        let expander = EnvExpander::with_source(test_source());

        let result = expander.expand("${MISSING}");
        assert!(matches!(result, Err(ConnectionError::EnvNotFound(_))));
    }

    #[test]
    fn test_empty_braced_name_is_error() {
        let expander = EnvExpander::with_source(test_source());

        assert!(matches!(
            expander.expand("${}"),
            Err(ConnectionError::InvalidEnvValue { .. })
        ));
    }

    #[test]
    fn test_simple_name_boundaries() {
        let expander = EnvExpander::with_source(test_source());

        // Characters outside [A-Za-z0-9_] end an unbraced reference.
        assert_eq!(
            expander.expand("$HOST.example:$PORT").unwrap(),
            "localhost.example:5432"
        );

        // Underscores and digits extend the name, so this looks up `HOST_2`.
        assert!(matches!(
            expander.expand("$HOST_2"),
            Err(ConnectionError::EnvNotFound(ref n)) if n == "HOST_2"
        ));
    }

    #[test]
    fn test_expand_full_url() {
        let expander = EnvExpander::with_source(test_source());

        let url = "postgres://${USER}:${PASS}@${HOST}:${PORT}/mydb?sslmode=require";
        let expanded = expander.expand(url).unwrap();

        assert_eq!(
            expanded,
            "postgres://testuser:secret@localhost:5432/mydb?sslmode=require"
        );
    }

    #[test]
    fn test_has_variables() {
        assert!(EnvExpander::<StdEnvSource>::has_variables("${VAR}"));
        assert!(EnvExpander::<StdEnvSource>::has_variables("$VAR"));
        assert!(!EnvExpander::<StdEnvSource>::has_variables("no variables"));
    }

    #[test]
    fn test_literal_dollar() {
        let expander = EnvExpander::with_source(test_source());

        // Dollar followed by non-variable character
        assert_eq!(expander.expand("cost: $5").unwrap(), "cost: $5");
    }

    #[test]
    fn test_dollar_edge_cases() {
        let expander = EnvExpander::with_source(test_source());

        // A trailing `$` has nothing to reference and stays literal.
        assert_eq!(expander.expand("price$").unwrap(), "price$");
        // The first `$` is literal; the second starts `$HOST`.
        assert_eq!(expander.expand("$$HOST").unwrap(), "$localhost");
        assert_eq!(expander.expand("").unwrap(), "");
        // A set-but-empty variable expands to nothing in the unbraced form.
        assert_eq!(expander.expand("$EMPTY").unwrap(), "");
    }
}
329