prax_query/connection/
env.rs1use super::{ConnectionError, ConnectionResult};
4use std::collections::HashMap;
5
/// A lookup table of named string variables consulted when expanding references.
///
/// Implementors must be `Send + Sync`, so a single source can be shared between threads.
pub trait EnvSource: Send + Sync {
    /// Returns the value bound to `name`, or `None` if this source does not define it.
    fn get(&self, name: &str) -> Option<String>;

    /// Returns `true` if `name` is defined, even when its value is the empty string.
    ///
    /// The default implementation delegates to [`get`](Self::get).
    fn contains(&self, name: &str) -> bool {
        matches!(self.get(name), Some(_))
    }
}
16
17#[derive(Debug, Clone, Copy, Default)]
19pub struct StdEnvSource;
20
21impl EnvSource for StdEnvSource {
22 fn get(&self, name: &str) -> Option<String> {
23 std::env::var(name).ok()
24 }
25}
26
/// An in-memory [`EnvSource`] holding variables in a `HashMap`.
///
/// Useful for tests, or for supplying variables without touching the process environment.
#[derive(Debug, Clone, Default)]
pub struct MapEnvSource {
    vars: HashMap<String, String>,
}

impl MapEnvSource {
    /// Creates a source with no variables defined.
    pub fn new() -> Self {
        Self {
            vars: HashMap::new(),
        }
    }

    /// Builder-style setter: defines `name` as `value`, replacing any previous value.
    pub fn set(self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let mut this = self;
        let key: String = name.into();
        let val: String = value.into();
        this.vars.insert(key, val);
        this
    }

    /// Builder-style bulk setter: adds every entry of `vars`, overwriting existing
    /// entries that share a name.
    pub fn with_vars(mut self, vars: HashMap<String, String>) -> Self {
        for (name, value) in vars {
            self.vars.insert(name, value);
        }
        self
    }
}
51
52impl EnvSource for MapEnvSource {
53 fn get(&self, name: &str) -> Option<String> {
54 self.vars.get(name).cloned()
55 }
56}
57
58#[derive(Debug, Clone)]
66pub struct EnvExpander<S: EnvSource = StdEnvSource> {
67 source: S,
68}
69
70impl EnvExpander<StdEnvSource> {
71 pub fn new() -> Self {
73 Self {
74 source: StdEnvSource,
75 }
76 }
77}
78
79impl Default for EnvExpander<StdEnvSource> {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl<S: EnvSource> EnvExpander<S> {
86 pub fn with_source(source: S) -> Self {
88 Self { source }
89 }
90
91 pub fn expand(&self, input: &str) -> ConnectionResult<String> {
106 let mut result = String::with_capacity(input.len());
107 let mut chars = input.chars().peekable();
108
109 while let Some(c) = chars.next() {
110 if c == '$' {
111 if chars.peek() == Some(&'{') {
112 chars.next(); let expanded = self.expand_braced(&mut chars)?;
115 result.push_str(&expanded);
116 } else if chars
117 .peek()
118 .map_or(false, |c| c.is_alphabetic() || *c == '_')
119 {
120 let expanded = self.expand_simple(&mut chars)?;
122 result.push_str(&expanded);
123 } else {
124 result.push(c);
126 }
127 } else {
128 result.push(c);
129 }
130 }
131
132 Ok(result)
133 }
134
135 fn expand_braced(
136 &self,
137 chars: &mut std::iter::Peekable<std::str::Chars>,
138 ) -> ConnectionResult<String> {
139 let mut name = String::new();
140 let mut modifier = None;
141 let mut modifier_value = String::new();
142
143 while let Some(&c) = chars.peek() {
144 if c == '}' {
145 chars.next();
146 break;
147 } else if c == ':' && modifier.is_none() {
148 chars.next();
149 if let Some(&next) = chars.peek() {
151 modifier = Some(next);
152 chars.next();
153 }
154 } else if modifier.is_some() {
155 modifier_value.push(c);
156 chars.next();
157 } else {
158 name.push(c);
159 chars.next();
160 }
161 }
162
163 if name.is_empty() {
164 return Err(ConnectionError::InvalidEnvValue {
165 name: "".to_string(),
166 message: "Empty variable name".to_string(),
167 });
168 }
169
170 match self.source.get(&name) {
171 Some(value) if !value.is_empty() => Ok(value),
172 _ => {
173 match modifier {
174 Some('-') => Ok(modifier_value),
175 Some('?') => Err(ConnectionError::InvalidEnvValue {
176 name: name.clone(),
177 message: if modifier_value.is_empty() {
178 format!("Required variable '{}' is not set", name)
179 } else {
180 modifier_value
181 },
182 }),
183 Some('+') => {
184 Ok(String::new())
186 }
187 _ => Err(ConnectionError::EnvNotFound(name)),
188 }
189 }
190 }
191 }
192
193 fn expand_simple(
194 &self,
195 chars: &mut std::iter::Peekable<std::str::Chars>,
196 ) -> ConnectionResult<String> {
197 let mut name = String::new();
198
199 while let Some(&c) = chars.peek() {
200 if c.is_alphanumeric() || c == '_' {
201 name.push(c);
202 chars.next();
203 } else {
204 break;
205 }
206 }
207
208 self.source
209 .get(&name)
210 .ok_or_else(|| ConnectionError::EnvNotFound(name))
211 }
212
213 pub fn expand_url(&self, url: &str) -> ConnectionResult<String> {
215 self.expand(url)
216 }
217
218 pub fn has_variables(input: &str) -> bool {
220 input.contains('$')
221 }
222}
223
224pub fn expand_env(input: &str) -> ConnectionResult<String> {
226 EnvExpander::new().expand(input)
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed set of variables shared by the tests; `EMPTY` is set but blank.
    fn test_source() -> MapEnvSource {
        [
            ("HOST", "localhost"),
            ("PORT", "5432"),
            ("USER", "testuser"),
            ("PASS", "secret"),
            ("EMPTY", ""),
        ]
        .iter()
        .fold(MapEnvSource::new(), |source, &(name, value)| {
            source.set(name, value)
        })
    }

    /// An expander over [`test_source`], isolated from the real process environment.
    fn expander() -> EnvExpander<MapEnvSource> {
        EnvExpander::with_source(test_source())
    }

    #[test]
    fn test_expand_simple() {
        let out = expander().expand("postgres://$HOST/db").unwrap();
        assert_eq!(out, "postgres://localhost/db");
    }

    #[test]
    fn test_expand_braced() {
        let out = expander().expand("postgres://${HOST}:${PORT}/db").unwrap();
        assert_eq!(out, "postgres://localhost:5432/db");
    }

    #[test]
    fn test_expand_default() {
        let e = expander();
        // Set value wins; unset and empty both fall back to the default.
        for &(input, expected) in &[
            ("${HOST:-default}", "localhost"),
            ("${MISSING:-default}", "default"),
            ("${EMPTY:-default}", "default"),
        ] {
            assert_eq!(e.expand(input).unwrap(), expected);
        }
    }

    #[test]
    fn test_expand_required() {
        let e = expander();
        assert_eq!(
            e.expand("${HOST:?Host is required}").unwrap(),
            "localhost"
        );

        let outcome = e.expand("${MISSING:?Missing is required}");
        let message = match outcome {
            Ok(_) => panic!("expected an error for a missing required variable"),
            Err(err) => err.to_string(),
        };
        assert!(message.contains("Missing is required"));
    }

    #[test]
    fn test_expand_missing() {
        let outcome = expander().expand("${MISSING}");
        assert!(matches!(outcome, Err(ConnectionError::EnvNotFound(_))));
    }

    #[test]
    fn test_expand_full_url() {
        let template = "postgres://${USER}:${PASS}@${HOST}:${PORT}/mydb?sslmode=require";
        let want = "postgres://testuser:secret@localhost:5432/mydb?sslmode=require";
        assert_eq!(expander().expand(template).unwrap(), want);
    }

    #[test]
    fn test_has_variables() {
        type Std = EnvExpander<StdEnvSource>;
        assert!(Std::has_variables("${VAR}"));
        assert!(Std::has_variables("$VAR"));
        assert!(!Std::has_variables("no variables"));
    }

    #[test]
    fn test_literal_dollar() {
        let out = expander().expand("cost: $5").unwrap();
        assert_eq!(out, "cost: $5");
    }
}