prax_query/connection/
env.rs1#![allow(dead_code)]
4
5use super::{ConnectionError, ConnectionResult};
6use std::collections::HashMap;
7
/// A lookup table of environment-style variables, queried by name.
///
/// Implementors must be `Send + Sync` so a source can be shared across threads.
pub trait EnvSource: Send + Sync {
    /// Returns the value bound to `name`, or `None` when `name` is not defined.
    fn get(&self, name: &str) -> Option<String>;

    /// Reports whether `name` is defined at all; a defined-but-empty value
    /// still counts as present. Implemented on top of [`EnvSource::get`].
    fn contains(&self, name: &str) -> bool {
        let looked_up = self.get(name);
        looked_up.is_some()
    }
}
18
19#[derive(Debug, Clone, Copy, Default)]
21pub struct StdEnvSource;
22
23impl EnvSource for StdEnvSource {
24 fn get(&self, name: &str) -> Option<String> {
25 std::env::var(name).ok()
26 }
27}
28
29#[derive(Debug, Clone, Default)]
31pub struct MapEnvSource {
32 vars: HashMap<String, String>,
33}
34
35impl MapEnvSource {
36 pub fn new() -> Self {
38 Self::default()
39 }
40
41 pub fn set(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
43 self.vars.insert(name.into(), value.into());
44 self
45 }
46
47 pub fn with_vars(mut self, vars: HashMap<String, String>) -> Self {
49 self.vars.extend(vars);
50 self
51 }
52}
53
54impl EnvSource for MapEnvSource {
55 fn get(&self, name: &str) -> Option<String> {
56 self.vars.get(name).cloned()
57 }
58}
59
60#[derive(Debug, Clone)]
68pub struct EnvExpander<S: EnvSource = StdEnvSource> {
69 source: S,
70}
71
72impl EnvExpander<StdEnvSource> {
73 pub fn new() -> Self {
75 Self {
76 source: StdEnvSource,
77 }
78 }
79}
80
81impl Default for EnvExpander<StdEnvSource> {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl<S: EnvSource> EnvExpander<S> {
88 pub fn with_source(source: S) -> Self {
90 Self { source }
91 }
92
93 pub fn expand(&self, input: &str) -> ConnectionResult<String> {
108 let mut result = String::with_capacity(input.len());
109 let mut chars = input.chars().peekable();
110
111 while let Some(c) = chars.next() {
112 if c == '$' {
113 if chars.peek() == Some(&'{') {
114 chars.next(); let expanded = self.expand_braced(&mut chars)?;
117 result.push_str(&expanded);
118 } else if chars.peek().is_some_and(|c| c.is_alphabetic() || *c == '_') {
119 let expanded = self.expand_simple(&mut chars)?;
121 result.push_str(&expanded);
122 } else {
123 result.push(c);
125 }
126 } else {
127 result.push(c);
128 }
129 }
130
131 Ok(result)
132 }
133
134 fn expand_braced(
135 &self,
136 chars: &mut std::iter::Peekable<std::str::Chars>,
137 ) -> ConnectionResult<String> {
138 let mut name = String::new();
139 let mut modifier = None;
140 let mut modifier_value = String::new();
141
142 while let Some(&c) = chars.peek() {
143 if c == '}' {
144 chars.next();
145 break;
146 } else if c == ':' && modifier.is_none() {
147 chars.next();
148 if let Some(&next) = chars.peek() {
150 modifier = Some(next);
151 chars.next();
152 }
153 } else if modifier.is_some() {
154 modifier_value.push(c);
155 chars.next();
156 } else {
157 name.push(c);
158 chars.next();
159 }
160 }
161
162 if name.is_empty() {
163 return Err(ConnectionError::InvalidEnvValue {
164 name: "".to_string(),
165 message: "Empty variable name".to_string(),
166 });
167 }
168
169 match self.source.get(&name) {
170 Some(value) if !value.is_empty() => Ok(value),
171 _ => {
172 match modifier {
173 Some('-') => Ok(modifier_value),
174 Some('?') => Err(ConnectionError::InvalidEnvValue {
175 name: name.clone(),
176 message: if modifier_value.is_empty() {
177 format!("Required variable '{}' is not set", name)
178 } else {
179 modifier_value
180 },
181 }),
182 Some('+') => {
183 Ok(String::new())
185 }
186 _ => Err(ConnectionError::EnvNotFound(name)),
187 }
188 }
189 }
190 }
191
192 fn expand_simple(
193 &self,
194 chars: &mut std::iter::Peekable<std::str::Chars>,
195 ) -> ConnectionResult<String> {
196 let mut name = String::new();
197
198 while let Some(&c) = chars.peek() {
199 if c.is_alphanumeric() || c == '_' {
200 name.push(c);
201 chars.next();
202 } else {
203 break;
204 }
205 }
206
207 self.source
208 .get(&name)
209 .ok_or(ConnectionError::EnvNotFound(name))
210 }
211
212 pub fn expand_url(&self, url: &str) -> ConnectionResult<String> {
214 self.expand(url)
215 }
216
217 pub fn has_variables(input: &str) -> bool {
219 input.contains('$')
220 }
221}
222
223pub fn expand_env(input: &str) -> ConnectionResult<String> {
225 EnvExpander::new().expand(input)
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed variable set shared by every test. `EMPTY` is set to "".
    fn test_source() -> MapEnvSource {
        let pairs = [
            ("HOST", "localhost"),
            ("PORT", "5432"),
            ("USER", "testuser"),
            ("PASS", "secret"),
            ("EMPTY", ""),
        ];
        pairs
            .iter()
            .fold(MapEnvSource::new(), |source, &(name, value)| {
                source.set(name, value)
            })
    }

    /// Expander over [`test_source`].
    fn expander() -> EnvExpander<MapEnvSource> {
        EnvExpander::with_source(test_source())
    }

    #[test]
    fn test_expand_simple() {
        let out = expander().expand("postgres://$HOST/db").unwrap();
        assert_eq!(out, "postgres://localhost/db");
    }

    #[test]
    fn test_expand_braced() {
        let out = expander().expand("postgres://${HOST}:${PORT}/db").unwrap();
        assert_eq!(out, "postgres://localhost:5432/db");
    }

    #[test]
    fn test_expand_default() {
        let e = expander();
        let cases = [
            // Set variable wins over the default.
            ("${HOST:-default}", "localhost"),
            // Unset variable falls back to the default.
            ("${MISSING:-default}", "default"),
            // Empty variable is treated like an unset one.
            ("${EMPTY:-default}", "default"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(e.expand(input).unwrap(), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_expand_required() {
        let e = expander();

        // A set variable satisfies the requirement.
        assert_eq!(e.expand("${HOST:?Host is required}").unwrap(), "localhost");

        // A missing variable surfaces the custom message.
        let err = e
            .expand("${MISSING:?Missing is required}")
            .expect_err("missing required variable must fail");
        assert!(err.to_string().contains("Missing is required"));
    }

    #[test]
    fn test_expand_missing() {
        let outcome = expander().expand("${MISSING}");
        assert!(matches!(outcome, Err(ConnectionError::EnvNotFound(_))));
    }

    #[test]
    fn test_expand_full_url() {
        let template = "postgres://${USER}:${PASS}@${HOST}:${PORT}/mydb?sslmode=require";
        assert_eq!(
            expander().expand(template).unwrap(),
            "postgres://testuser:secret@localhost:5432/mydb?sslmode=require"
        );
    }

    #[test]
    fn test_has_variables() {
        type Std = EnvExpander<StdEnvSource>;
        assert!(Std::has_variables("${VAR}"));
        assert!(Std::has_variables("$VAR"));
        assert!(!Std::has_variables("no variables"));
    }

    #[test]
    fn test_literal_dollar() {
        // `$` followed by a digit is not a variable reference.
        assert_eq!(expander().expand("cost: $5").unwrap(), "cost: $5");
    }
}