prax_query/connection/
env.rs

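//! Environment variable expansion.
//!
//! Provides [`EnvExpander`] for substituting `$VAR` / `${VAR}` references in
//! strings such as connection URLs, with values resolved through a pluggable
//! [`EnvSource`].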
2
3#![allow(dead_code)]
4
5use super::{ConnectionError, ConnectionResult};
6use std::collections::HashMap;
7
/// A provider that resolves environment variable names to values.
pub trait EnvSource: Send + Sync {
    /// Look up `name`, returning `None` when it is not defined.
    fn get(&self, name: &str) -> Option<String>;

    /// Return `true` when [`get`](EnvSource::get) yields a value for `name`.
    fn contains(&self, name: &str) -> bool {
        matches!(self.get(name), Some(_))
    }
}
18
19/// Default environment source using std::env.
20#[derive(Debug, Clone, Copy, Default)]
21pub struct StdEnvSource;
22
23impl EnvSource for StdEnvSource {
24    fn get(&self, name: &str) -> Option<String> {
25        std::env::var(name).ok()
26    }
27}
28
29/// Environment source backed by a HashMap.
30#[derive(Debug, Clone, Default)]
31pub struct MapEnvSource {
32    vars: HashMap<String, String>,
33}
34
35impl MapEnvSource {
36    /// Create a new map-based environment source.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Add a variable.
42    pub fn set(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
43        self.vars.insert(name.into(), value.into());
44        self
45    }
46
47    /// Add multiple variables.
48    pub fn with_vars(mut self, vars: HashMap<String, String>) -> Self {
49        self.vars.extend(vars);
50        self
51    }
52}
53
54impl EnvSource for MapEnvSource {
55    fn get(&self, name: &str) -> Option<String> {
56        self.vars.get(name).cloned()
57    }
58}
59
60/// Expands environment variables in strings.
61///
62/// Supported syntax:
63/// - `${VAR}` - Required variable
64/// - `${VAR:-default}` - Variable with default value
65/// - `${VAR:?error message}` - Required with custom error
66/// - `$VAR` - Simple variable reference
67#[derive(Debug, Clone)]
68pub struct EnvExpander<S: EnvSource = StdEnvSource> {
69    source: S,
70}
71
72impl EnvExpander<StdEnvSource> {
73    /// Create a new expander using the standard environment.
74    pub fn new() -> Self {
75        Self {
76            source: StdEnvSource,
77        }
78    }
79}
80
81impl Default for EnvExpander<StdEnvSource> {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl<S: EnvSource> EnvExpander<S> {
88    /// Create an expander with a custom environment source.
89    pub fn with_source(source: S) -> Self {
90        Self { source }
91    }
92
93    /// Expand environment variables in a string.
94    ///
95    /// # Examples
96    ///
97    /// ```rust
98    /// use prax_query::connection::EnvExpander;
99    ///
100    /// // SAFETY: This is for documentation purposes only
101    /// unsafe { std::env::set_var("PRAX_TEST_HOST", "localhost") };
102    /// let expander = EnvExpander::new();
103    /// let result = expander.expand("postgres://${PRAX_TEST_HOST}/db").unwrap();
104    /// assert_eq!(result, "postgres://localhost/db");
105    /// unsafe { std::env::remove_var("PRAX_TEST_HOST") };
106    /// ```
107    pub fn expand(&self, input: &str) -> ConnectionResult<String> {
108        let mut result = String::with_capacity(input.len());
109        let mut chars = input.chars().peekable();
110
111        while let Some(c) = chars.next() {
112            if c == '$' {
113                if chars.peek() == Some(&'{') {
114                    // ${VAR} syntax
115                    chars.next(); // consume '{'
116                    let expanded = self.expand_braced(&mut chars)?;
117                    result.push_str(&expanded);
118                } else if chars
119                    .peek()
120                    .is_some_and(|c| c.is_alphabetic() || *c == '_')
121                {
122                    // $VAR syntax
123                    let expanded = self.expand_simple(&mut chars)?;
124                    result.push_str(&expanded);
125                } else {
126                    // Literal $
127                    result.push(c);
128                }
129            } else {
130                result.push(c);
131            }
132        }
133
134        Ok(result)
135    }
136
137    fn expand_braced(
138        &self,
139        chars: &mut std::iter::Peekable<std::str::Chars>,
140    ) -> ConnectionResult<String> {
141        let mut name = String::new();
142        let mut modifier = None;
143        let mut modifier_value = String::new();
144
145        while let Some(&c) = chars.peek() {
146            if c == '}' {
147                chars.next();
148                break;
149            } else if c == ':' && modifier.is_none() {
150                chars.next();
151                // Check for modifier type
152                if let Some(&next) = chars.peek() {
153                    modifier = Some(next);
154                    chars.next();
155                }
156            } else if modifier.is_some() {
157                modifier_value.push(c);
158                chars.next();
159            } else {
160                name.push(c);
161                chars.next();
162            }
163        }
164
165        if name.is_empty() {
166            return Err(ConnectionError::InvalidEnvValue {
167                name: "".to_string(),
168                message: "Empty variable name".to_string(),
169            });
170        }
171
172        match self.source.get(&name) {
173            Some(value) if !value.is_empty() => Ok(value),
174            _ => {
175                match modifier {
176                    Some('-') => Ok(modifier_value),
177                    Some('?') => Err(ConnectionError::InvalidEnvValue {
178                        name: name.clone(),
179                        message: if modifier_value.is_empty() {
180                            format!("Required variable '{}' is not set", name)
181                        } else {
182                            modifier_value
183                        },
184                    }),
185                    Some('+') => {
186                        // ${VAR:+value} - value if VAR is set, empty otherwise
187                        Ok(String::new())
188                    }
189                    _ => Err(ConnectionError::EnvNotFound(name)),
190                }
191            }
192        }
193    }
194
195    fn expand_simple(
196        &self,
197        chars: &mut std::iter::Peekable<std::str::Chars>,
198    ) -> ConnectionResult<String> {
199        let mut name = String::new();
200
201        while let Some(&c) = chars.peek() {
202            if c.is_alphanumeric() || c == '_' {
203                name.push(c);
204                chars.next();
205            } else {
206                break;
207            }
208        }
209
210        self.source
211            .get(&name)
212            .ok_or(ConnectionError::EnvNotFound(name))
213    }
214
215    /// Expand a connection URL.
216    pub fn expand_url(&self, url: &str) -> ConnectionResult<String> {
217        self.expand(url)
218    }
219
220    /// Check if a string contains environment variable references.
221    pub fn has_variables(input: &str) -> bool {
222        input.contains('$')
223    }
224}
225
226/// Expand environment variables using the standard environment.
227pub fn expand_env(input: &str) -> ConnectionResult<String> {
228    EnvExpander::new().expand(input)
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Variables available to every test; `EMPTY` is deliberately set to "".
    fn test_source() -> MapEnvSource {
        const VARS: [(&str, &str); 5] = [
            ("HOST", "localhost"),
            ("PORT", "5432"),
            ("USER", "testuser"),
            ("PASS", "secret"),
            ("EMPTY", ""),
        ];
        VARS.iter()
            .fold(MapEnvSource::new(), |source, &(name, value)| source.set(name, value))
    }

    /// Expander backed by `test_source`.
    fn expander() -> EnvExpander<MapEnvSource> {
        EnvExpander::with_source(test_source())
    }

    #[test]
    fn test_expand_simple() {
        let expanded = expander().expand("postgres://$HOST/db").unwrap();
        assert_eq!(expanded, "postgres://localhost/db");
    }

    #[test]
    fn test_expand_braced() {
        let expanded = expander()
            .expand("postgres://${HOST}:${PORT}/db")
            .unwrap();
        assert_eq!(expanded, "postgres://localhost:5432/db");
    }

    #[test]
    fn test_expand_default() {
        let subject = expander();
        let cases = [
            // Set: the variable wins over the default.
            ("${HOST:-default}", "localhost"),
            // Unset: the default is used.
            ("${MISSING:-default}", "default"),
            // Set but empty: treated like unset.
            ("${EMPTY:-default}", "default"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(subject.expand(input).unwrap(), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_expand_required() {
        let subject = expander();

        // Present: the value is returned and the message is ignored.
        let host = subject.expand("${HOST:?Host is required}").unwrap();
        assert_eq!(host, "localhost");

        // Absent: the custom message surfaces in the error.
        let message = subject
            .expand("${MISSING:?Missing is required}")
            .unwrap_err()
            .to_string();
        assert!(message.contains("Missing is required"));
    }

    #[test]
    fn test_expand_missing() {
        let outcome = expander().expand("${MISSING}");
        assert!(matches!(outcome, Err(ConnectionError::EnvNotFound(_))));
    }

    #[test]
    fn test_expand_full_url() {
        let template = "postgres://${USER}:${PASS}@${HOST}:${PORT}/mydb?sslmode=require";
        let expected = "postgres://testuser:secret@localhost:5432/mydb?sslmode=require";
        assert_eq!(expander().expand(template).unwrap(), expected);
    }

    #[test]
    fn test_has_variables() {
        type Std = EnvExpander<StdEnvSource>;
        assert!(Std::has_variables("${VAR}"));
        assert!(Std::has_variables("$VAR"));
        assert!(!Std::has_variables("no variables"));
    }

    #[test]
    fn test_literal_dollar() {
        // `$` followed by a digit is not a variable reference.
        let expanded = expander().expand("cost: $5").unwrap();
        assert_eq!(expanded, "cost: $5");
    }
}