prax_query/connection/
env.rs

1//! Environment variable expansion.
2
3#![allow(dead_code)]
4
5use super::{ConnectionError, ConnectionResult};
6use std::collections::HashMap;
7
/// A lookup source for environment variables.
///
/// Implementors must be `Send + Sync`.
pub trait EnvSource: Send + Sync {
    /// Returns the value of `name`, or `None` when it is not defined.
    fn get(&self, name: &str) -> Option<String>;

    /// Returns `true` when `name` is defined.
    ///
    /// The default implementation delegates to [`EnvSource::get`], so a variable
    /// whose value is the empty string still counts as present.
    fn contains(&self, name: &str) -> bool {
        matches!(self.get(name), Some(_))
    }
}
18
19/// Default environment source using std::env.
20#[derive(Debug, Clone, Copy, Default)]
21pub struct StdEnvSource;
22
23impl EnvSource for StdEnvSource {
24    fn get(&self, name: &str) -> Option<String> {
25        std::env::var(name).ok()
26    }
27}
28
29/// Environment source backed by a HashMap.
30#[derive(Debug, Clone, Default)]
31pub struct MapEnvSource {
32    vars: HashMap<String, String>,
33}
34
35impl MapEnvSource {
36    /// Create a new map-based environment source.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Add a variable.
42    pub fn set(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
43        self.vars.insert(name.into(), value.into());
44        self
45    }
46
47    /// Add multiple variables.
48    pub fn with_vars(mut self, vars: HashMap<String, String>) -> Self {
49        self.vars.extend(vars);
50        self
51    }
52}
53
54impl EnvSource for MapEnvSource {
55    fn get(&self, name: &str) -> Option<String> {
56        self.vars.get(name).cloned()
57    }
58}
59
60/// Expands environment variables in strings.
61///
62/// Supported syntax:
63/// - `${VAR}` - Required variable
64/// - `${VAR:-default}` - Variable with default value
65/// - `${VAR:?error message}` - Required with custom error
66/// - `$VAR` - Simple variable reference
67#[derive(Debug, Clone)]
68pub struct EnvExpander<S: EnvSource = StdEnvSource> {
69    source: S,
70}
71
72impl EnvExpander<StdEnvSource> {
73    /// Create a new expander using the standard environment.
74    pub fn new() -> Self {
75        Self {
76            source: StdEnvSource,
77        }
78    }
79}
80
81impl Default for EnvExpander<StdEnvSource> {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl<S: EnvSource> EnvExpander<S> {
88    /// Create an expander with a custom environment source.
89    pub fn with_source(source: S) -> Self {
90        Self { source }
91    }
92
93    /// Expand environment variables in a string.
94    ///
95    /// # Examples
96    ///
97    /// ```rust
98    /// use prax_query::connection::EnvExpander;
99    ///
100    /// // SAFETY: This is for documentation purposes only
101    /// unsafe { std::env::set_var("PRAX_TEST_HOST", "localhost") };
102    /// let expander = EnvExpander::new();
103    /// let result = expander.expand("postgres://${PRAX_TEST_HOST}/db").unwrap();
104    /// assert_eq!(result, "postgres://localhost/db");
105    /// unsafe { std::env::remove_var("PRAX_TEST_HOST") };
106    /// ```
107    pub fn expand(&self, input: &str) -> ConnectionResult<String> {
108        let mut result = String::with_capacity(input.len());
109        let mut chars = input.chars().peekable();
110
111        while let Some(c) = chars.next() {
112            if c == '$' {
113                if chars.peek() == Some(&'{') {
114                    // ${VAR} syntax
115                    chars.next(); // consume '{'
116                    let expanded = self.expand_braced(&mut chars)?;
117                    result.push_str(&expanded);
118                } else if chars.peek().is_some_and(|c| c.is_alphabetic() || *c == '_') {
119                    // $VAR syntax
120                    let expanded = self.expand_simple(&mut chars)?;
121                    result.push_str(&expanded);
122                } else {
123                    // Literal $
124                    result.push(c);
125                }
126            } else {
127                result.push(c);
128            }
129        }
130
131        Ok(result)
132    }
133
134    fn expand_braced(
135        &self,
136        chars: &mut std::iter::Peekable<std::str::Chars>,
137    ) -> ConnectionResult<String> {
138        let mut name = String::new();
139        let mut modifier = None;
140        let mut modifier_value = String::new();
141
142        while let Some(&c) = chars.peek() {
143            if c == '}' {
144                chars.next();
145                break;
146            } else if c == ':' && modifier.is_none() {
147                chars.next();
148                // Check for modifier type
149                if let Some(&next) = chars.peek() {
150                    modifier = Some(next);
151                    chars.next();
152                }
153            } else if modifier.is_some() {
154                modifier_value.push(c);
155                chars.next();
156            } else {
157                name.push(c);
158                chars.next();
159            }
160        }
161
162        if name.is_empty() {
163            return Err(ConnectionError::InvalidEnvValue {
164                name: "".to_string(),
165                message: "Empty variable name".to_string(),
166            });
167        }
168
169        match self.source.get(&name) {
170            Some(value) if !value.is_empty() => Ok(value),
171            _ => {
172                match modifier {
173                    Some('-') => Ok(modifier_value),
174                    Some('?') => Err(ConnectionError::InvalidEnvValue {
175                        name: name.clone(),
176                        message: if modifier_value.is_empty() {
177                            format!("Required variable '{}' is not set", name)
178                        } else {
179                            modifier_value
180                        },
181                    }),
182                    Some('+') => {
183                        // ${VAR:+value} - value if VAR is set, empty otherwise
184                        Ok(String::new())
185                    }
186                    _ => Err(ConnectionError::EnvNotFound(name)),
187                }
188            }
189        }
190    }
191
192    fn expand_simple(
193        &self,
194        chars: &mut std::iter::Peekable<std::str::Chars>,
195    ) -> ConnectionResult<String> {
196        let mut name = String::new();
197
198        while let Some(&c) = chars.peek() {
199            if c.is_alphanumeric() || c == '_' {
200                name.push(c);
201                chars.next();
202            } else {
203                break;
204            }
205        }
206
207        self.source
208            .get(&name)
209            .ok_or(ConnectionError::EnvNotFound(name))
210    }
211
212    /// Expand a connection URL.
213    pub fn expand_url(&self, url: &str) -> ConnectionResult<String> {
214        self.expand(url)
215    }
216
217    /// Check if a string contains environment variable references.
218    pub fn has_variables(input: &str) -> bool {
219        input.contains('$')
220    }
221}
222
223/// Expand environment variables using the standard environment.
224pub fn expand_env(input: &str) -> ConnectionResult<String> {
225    EnvExpander::new().expand(input)
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed set of variables shared by every test in this module.
    fn test_source() -> MapEnvSource {
        const VARS: [(&str, &str); 5] = [
            ("HOST", "localhost"),
            ("PORT", "5432"),
            ("USER", "testuser"),
            ("PASS", "secret"),
            ("EMPTY", ""),
        ];
        VARS.iter()
            .fold(MapEnvSource::new(), |source, &(name, value)| {
                source.set(name, value)
            })
    }

    /// An expander that resolves names against `test_source()`.
    fn expander() -> EnvExpander<MapEnvSource> {
        EnvExpander::with_source(test_source())
    }

    #[test]
    fn test_expand_simple() {
        let url = expander().expand("postgres://$HOST/db").unwrap();
        assert_eq!(url, "postgres://localhost/db");
    }

    #[test]
    fn test_expand_braced() {
        let url = expander().expand("postgres://${HOST}:${PORT}/db").unwrap();
        assert_eq!(url, "postgres://localhost:5432/db");
    }

    #[test]
    fn test_expand_default() {
        let env = expander();
        let cases = [
            // Variable exists
            ("${HOST:-default}", "localhost"),
            // Variable doesn't exist
            ("${MISSING:-default}", "default"),
            // Empty variable
            ("${EMPTY:-default}", "default"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(env.expand(input).unwrap(), *expected, "input: {}", input);
        }
    }

    #[test]
    fn test_expand_required() {
        let env = expander();

        // Variable exists
        let host = env.expand("${HOST:?Host is required}").unwrap();
        assert_eq!(host, "localhost");

        // Variable doesn't exist: the custom message surfaces in the error
        let message = env
            .expand("${MISSING:?Missing is required}")
            .expect_err("a missing required variable must fail")
            .to_string();
        assert!(message.contains("Missing is required"));
    }

    #[test]
    fn test_expand_missing() {
        let outcome = expander().expand("${MISSING}");
        assert!(matches!(outcome, Err(ConnectionError::EnvNotFound(_))));
    }

    #[test]
    fn test_expand_full_url() {
        let template = "postgres://${USER}:${PASS}@${HOST}:${PORT}/mydb?sslmode=require";
        let url = expander().expand(template).unwrap();
        assert_eq!(
            url,
            "postgres://testuser:secret@localhost:5432/mydb?sslmode=require"
        );
    }

    #[test]
    fn test_has_variables() {
        type Expander = EnvExpander<StdEnvSource>;
        assert!(Expander::has_variables("${VAR}"));
        assert!(Expander::has_variables("$VAR"));
        assert!(!Expander::has_variables("no variables"));
    }

    #[test]
    fn test_literal_dollar() {
        // Dollar followed by non-variable character
        let text = expander().expand("cost: $5").unwrap();
        assert_eq!(text, "cost: $5");
    }
}