//! Shell-style environment variable expansion (`$NAME`, `${NAME:-default}`, ...) for
//! connection strings.
//!
//! Source file: `prax_query/connection/env.rs`.

use super::{ConnectionError, ConnectionResult};
use std::collections::HashMap;

/// A lookup table of environment-variable values that variable references are resolved against.
///
/// Implementors must be `Send + Sync`.
pub trait EnvSource: Send + Sync {
    /// Looks up `name`, returning its value or `None` when this source does not define it.
    fn get(&self, name: &str) -> Option<String>;

    /// Reports whether `name` is defined in this source, even if its value is empty.
    ///
    /// The provided implementation simply checks whether [`EnvSource::get`] yields a value.
    fn contains(&self, name: &str) -> bool {
        let found = self.get(name);
        found.is_some()
    }
}

17#[derive(Debug, Clone, Copy, Default)]
19pub struct StdEnvSource;
20
21impl EnvSource for StdEnvSource {
22 fn get(&self, name: &str) -> Option<String> {
23 std::env::var(name).ok()
24 }
25}
26
27#[derive(Debug, Clone, Default)]
29pub struct MapEnvSource {
30 vars: HashMap<String, String>,
31}
32
33impl MapEnvSource {
34 pub fn new() -> Self {
36 Self::default()
37 }
38
39 pub fn set(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
41 self.vars.insert(name.into(), value.into());
42 self
43 }
44
45 pub fn with_vars(mut self, vars: HashMap<String, String>) -> Self {
47 self.vars.extend(vars);
48 self
49 }
50}
51
52impl EnvSource for MapEnvSource {
53 fn get(&self, name: &str) -> Option<String> {
54 self.vars.get(name).cloned()
55 }
56}
57
58#[derive(Debug, Clone)]
66pub struct EnvExpander<S: EnvSource = StdEnvSource> {
67 source: S,
68}
69
70impl EnvExpander<StdEnvSource> {
71 pub fn new() -> Self {
73 Self {
74 source: StdEnvSource,
75 }
76 }
77}
78
79impl Default for EnvExpander<StdEnvSource> {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl<S: EnvSource> EnvExpander<S> {
86 pub fn with_source(source: S) -> Self {
88 Self { source }
89 }
90
91 pub fn expand(&self, input: &str) -> ConnectionResult<String> {
106 let mut result = String::with_capacity(input.len());
107 let mut chars = input.chars().peekable();
108
109 while let Some(c) = chars.next() {
110 if c == '$' {
111 if chars.peek() == Some(&'{') {
112 chars.next(); let expanded = self.expand_braced(&mut chars)?;
115 result.push_str(&expanded);
116 } else if chars.peek().map_or(false, |c| c.is_alphabetic() || *c == '_') {
117 let expanded = self.expand_simple(&mut chars)?;
119 result.push_str(&expanded);
120 } else {
121 result.push(c);
123 }
124 } else {
125 result.push(c);
126 }
127 }
128
129 Ok(result)
130 }
131
132 fn expand_braced(&self, chars: &mut std::iter::Peekable<std::str::Chars>) -> ConnectionResult<String> {
133 let mut name = String::new();
134 let mut modifier = None;
135 let mut modifier_value = String::new();
136
137 while let Some(&c) = chars.peek() {
138 if c == '}' {
139 chars.next();
140 break;
141 } else if c == ':' && modifier.is_none() {
142 chars.next();
143 if let Some(&next) = chars.peek() {
145 modifier = Some(next);
146 chars.next();
147 }
148 } else if modifier.is_some() {
149 modifier_value.push(c);
150 chars.next();
151 } else {
152 name.push(c);
153 chars.next();
154 }
155 }
156
157 if name.is_empty() {
158 return Err(ConnectionError::InvalidEnvValue {
159 name: "".to_string(),
160 message: "Empty variable name".to_string(),
161 });
162 }
163
164 match self.source.get(&name) {
165 Some(value) if !value.is_empty() => Ok(value),
166 _ => {
167 match modifier {
168 Some('-') => Ok(modifier_value),
169 Some('?') => Err(ConnectionError::InvalidEnvValue {
170 name: name.clone(),
171 message: if modifier_value.is_empty() {
172 format!("Required variable '{}' is not set", name)
173 } else {
174 modifier_value
175 },
176 }),
177 Some('+') => {
178 Ok(String::new())
180 }
181 _ => Err(ConnectionError::EnvNotFound(name)),
182 }
183 }
184 }
185 }
186
187 fn expand_simple(&self, chars: &mut std::iter::Peekable<std::str::Chars>) -> ConnectionResult<String> {
188 let mut name = String::new();
189
190 while let Some(&c) = chars.peek() {
191 if c.is_alphanumeric() || c == '_' {
192 name.push(c);
193 chars.next();
194 } else {
195 break;
196 }
197 }
198
199 self.source.get(&name).ok_or_else(|| ConnectionError::EnvNotFound(name))
200 }
201
202 pub fn expand_url(&self, url: &str) -> ConnectionResult<String> {
204 self.expand(url)
205 }
206
207 pub fn has_variables(input: &str) -> bool {
209 input.contains('$')
210 }
211}
212
213pub fn expand_env(input: &str) -> ConnectionResult<String> {
215 EnvExpander::new().expand(input)
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed variable set shared by the tests; `EMPTY` is defined but has an empty value.
    fn test_source() -> MapEnvSource {
        [
            ("HOST", "localhost"),
            ("PORT", "5432"),
            ("USER", "testuser"),
            ("PASS", "secret"),
            ("EMPTY", ""),
        ]
        .iter()
        .fold(MapEnvSource::new(), |src, &(name, value)| src.set(name, value))
    }

    /// An expander that resolves names against `test_source()`.
    fn expander() -> EnvExpander<MapEnvSource> {
        EnvExpander::with_source(test_source())
    }

    #[test]
    fn test_expand_simple() {
        let out = expander().expand("postgres://$HOST/db").unwrap();
        assert_eq!(out, "postgres://localhost/db");
    }

    #[test]
    fn test_expand_braced() {
        let out = expander().expand("postgres://${HOST}:${PORT}/db").unwrap();
        assert_eq!(out, "postgres://localhost:5432/db");
    }

    #[test]
    fn test_expand_default() {
        let exp = expander();
        let cases = [
            // A set variable wins over the default.
            ("${HOST:-default}", "localhost"),
            // An unset variable falls back to the default.
            ("${MISSING:-default}", "default"),
            // An empty variable also falls back to the default.
            ("${EMPTY:-default}", "default"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(exp.expand(input).unwrap(), expected);
        }
    }

    #[test]
    fn test_expand_required() {
        let exp = expander();

        // A set variable satisfies the requirement.
        assert_eq!(exp.expand("${HOST:?Host is required}").unwrap(), "localhost");

        // An unset variable yields an error carrying the custom message.
        let err = exp
            .expand("${MISSING:?Missing is required}")
            .expect_err("unset required variable must fail");
        assert!(err.to_string().contains("Missing is required"));
    }

    #[test]
    fn test_expand_missing() {
        let outcome = expander().expand("${MISSING}");
        assert!(matches!(outcome, Err(ConnectionError::EnvNotFound(_))));
    }

    #[test]
    fn test_expand_full_url() {
        let template = "postgres://${USER}:${PASS}@${HOST}:${PORT}/mydb?sslmode=require";
        let expected = "postgres://testuser:secret@localhost:5432/mydb?sslmode=require";
        assert_eq!(expander().expand(template).unwrap(), expected);
    }

    #[test]
    fn test_has_variables() {
        type Std = EnvExpander<StdEnvSource>;
        assert!(Std::has_variables("${VAR}"));
        assert!(Std::has_variables("$VAR"));
        assert!(!Std::has_variables("no variables"));
    }

    #[test]
    fn test_literal_dollar() {
        // A `$` followed by a digit is not a variable reference.
        assert_eq!(expander().expand("cost: $5").unwrap(), "cost: $5");
    }
}