prax_query/connection/
env.rs1#![allow(dead_code)]
4
5use super::{ConnectionError, ConnectionResult};
6use std::collections::HashMap;
7
/// A backend that resolves variable names to string values.
///
/// `Send + Sync` is required so a source can be shared across threads.
pub trait EnvSource: Send + Sync {
    /// Looks up `name`, returning `None` when the source does not define it.
    fn get(&self, name: &str) -> Option<String>;

    /// Returns `true` when `name` is defined, even if its value is the empty string.
    fn contains(&self, name: &str) -> bool {
        let lookup = self.get(name);
        lookup.is_some()
    }
}
18
19#[derive(Debug, Clone, Copy, Default)]
21pub struct StdEnvSource;
22
23impl EnvSource for StdEnvSource {
24 fn get(&self, name: &str) -> Option<String> {
25 std::env::var(name).ok()
26 }
27}
28
/// In-memory [`EnvSource`] backed by a `HashMap`, handy for tests and explicit configuration.
#[derive(Debug, Clone, Default)]
pub struct MapEnvSource {
    vars: HashMap<String, String>,
}

impl MapEnvSource {
    /// Creates a source with no variables defined.
    pub fn new() -> Self {
        Self {
            vars: HashMap::new(),
        }
    }

    /// Builder-style setter: defines `name` as `value`, replacing any previous value.
    pub fn set(self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let mut builder = self;
        builder.vars.insert(name.into(), value.into());
        builder
    }

    /// Builder-style bulk setter: merges `vars` in, with entries from `vars`
    /// overwriting existing names.
    pub fn with_vars(self, vars: HashMap<String, String>) -> Self {
        let Self { vars: mut merged } = self;
        merged.extend(vars);
        Self { vars: merged }
    }
}
53
54impl EnvSource for MapEnvSource {
55 fn get(&self, name: &str) -> Option<String> {
56 self.vars.get(name).cloned()
57 }
58}
59
60#[derive(Debug, Clone)]
68pub struct EnvExpander<S: EnvSource = StdEnvSource> {
69 source: S,
70}
71
72impl EnvExpander<StdEnvSource> {
73 pub fn new() -> Self {
75 Self {
76 source: StdEnvSource,
77 }
78 }
79}
80
81impl Default for EnvExpander<StdEnvSource> {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl<S: EnvSource> EnvExpander<S> {
88 pub fn with_source(source: S) -> Self {
90 Self { source }
91 }
92
93 pub fn expand(&self, input: &str) -> ConnectionResult<String> {
108 let mut result = String::with_capacity(input.len());
109 let mut chars = input.chars().peekable();
110
111 while let Some(c) = chars.next() {
112 if c == '$' {
113 if chars.peek() == Some(&'{') {
114 chars.next(); let expanded = self.expand_braced(&mut chars)?;
117 result.push_str(&expanded);
118 } else if chars
119 .peek()
120 .is_some_and(|c| c.is_alphabetic() || *c == '_')
121 {
122 let expanded = self.expand_simple(&mut chars)?;
124 result.push_str(&expanded);
125 } else {
126 result.push(c);
128 }
129 } else {
130 result.push(c);
131 }
132 }
133
134 Ok(result)
135 }
136
137 fn expand_braced(
138 &self,
139 chars: &mut std::iter::Peekable<std::str::Chars>,
140 ) -> ConnectionResult<String> {
141 let mut name = String::new();
142 let mut modifier = None;
143 let mut modifier_value = String::new();
144
145 while let Some(&c) = chars.peek() {
146 if c == '}' {
147 chars.next();
148 break;
149 } else if c == ':' && modifier.is_none() {
150 chars.next();
151 if let Some(&next) = chars.peek() {
153 modifier = Some(next);
154 chars.next();
155 }
156 } else if modifier.is_some() {
157 modifier_value.push(c);
158 chars.next();
159 } else {
160 name.push(c);
161 chars.next();
162 }
163 }
164
165 if name.is_empty() {
166 return Err(ConnectionError::InvalidEnvValue {
167 name: "".to_string(),
168 message: "Empty variable name".to_string(),
169 });
170 }
171
172 match self.source.get(&name) {
173 Some(value) if !value.is_empty() => Ok(value),
174 _ => {
175 match modifier {
176 Some('-') => Ok(modifier_value),
177 Some('?') => Err(ConnectionError::InvalidEnvValue {
178 name: name.clone(),
179 message: if modifier_value.is_empty() {
180 format!("Required variable '{}' is not set", name)
181 } else {
182 modifier_value
183 },
184 }),
185 Some('+') => {
186 Ok(String::new())
188 }
189 _ => Err(ConnectionError::EnvNotFound(name)),
190 }
191 }
192 }
193 }
194
195 fn expand_simple(
196 &self,
197 chars: &mut std::iter::Peekable<std::str::Chars>,
198 ) -> ConnectionResult<String> {
199 let mut name = String::new();
200
201 while let Some(&c) = chars.peek() {
202 if c.is_alphanumeric() || c == '_' {
203 name.push(c);
204 chars.next();
205 } else {
206 break;
207 }
208 }
209
210 self.source
211 .get(&name)
212 .ok_or(ConnectionError::EnvNotFound(name))
213 }
214
215 pub fn expand_url(&self, url: &str) -> ConnectionResult<String> {
217 self.expand(url)
218 }
219
220 pub fn has_variables(input: &str) -> bool {
222 input.contains('$')
223 }
224}
225
226pub fn expand_env(input: &str) -> ConnectionResult<String> {
228 EnvExpander::new().expand(input)
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture with a fixed set of variables; `EMPTY` is set but blank.
    fn test_source() -> MapEnvSource {
        let pairs = [
            ("HOST", "localhost"),
            ("PORT", "5432"),
            ("USER", "testuser"),
            ("PASS", "secret"),
            ("EMPTY", ""),
        ];
        pairs
            .iter()
            .fold(MapEnvSource::new(), |source, &(name, value)| source.set(name, value))
    }

    /// Builds an expander over the shared fixture.
    fn expander() -> EnvExpander<MapEnvSource> {
        EnvExpander::with_source(test_source())
    }

    #[test]
    fn test_expand_simple() {
        let out = expander().expand("postgres://$HOST/db").unwrap();
        assert_eq!(out, "postgres://localhost/db");
    }

    #[test]
    fn test_expand_braced() {
        let out = expander().expand("postgres://${HOST}:${PORT}/db").unwrap();
        assert_eq!(out, "postgres://localhost:5432/db");
    }

    #[test]
    fn test_expand_default() {
        let e = expander();
        // set -> value; unset -> default; set-but-empty -> default
        for (input, expected) in [
            ("${HOST:-default}", "localhost"),
            ("${MISSING:-default}", "default"),
            ("${EMPTY:-default}", "default"),
        ] {
            assert_eq!(e.expand(input).unwrap(), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_expand_required() {
        let e = expander();
        assert_eq!(e.expand("${HOST:?Host is required}").unwrap(), "localhost");

        let err = e
            .expand("${MISSING:?Missing is required}")
            .expect_err("an unset required variable must fail");
        assert!(err.to_string().contains("Missing is required"));
    }

    #[test]
    fn test_expand_missing() {
        let outcome = expander().expand("${MISSING}");
        assert!(matches!(outcome, Err(ConnectionError::EnvNotFound(_))));
    }

    #[test]
    fn test_expand_full_url() {
        let template = "postgres://${USER}:${PASS}@${HOST}:${PORT}/mydb?sslmode=require";
        assert_eq!(
            expander().expand(template).unwrap(),
            "postgres://testuser:secret@localhost:5432/mydb?sslmode=require"
        );
    }

    #[test]
    fn test_has_variables() {
        type Std = EnvExpander<StdEnvSource>;
        assert!(Std::has_variables("${VAR}"));
        assert!(Std::has_variables("$VAR"));
        assert!(!Std::has_variables("no variables"));
    }

    #[test]
    fn test_literal_dollar() {
        let out = expander().expand("cost: $5").unwrap();
        assert_eq!(out, "cost: $5");
    }
}