const_env_impl/
lib.rs

1use std::collections::HashMap;
2
3use proc_macro2::TokenStream;
4use quote::{ToTokens, quote_spanned};
5use syn::{Expr, ExprLit, Lit};
6use syn::spanned::Spanned;
7
/// Source of environment-variable values consulted by the macros in this crate.
///
/// Returning `None` makes the macros keep the default written in the source
/// code; returning `Some(value)` makes them substitute `value`.
pub trait ReadEnv {
    /// Looks up `var_name` and returns its value, or `None` when it is unset.
    fn read_env(&self, var_name: &String) -> Option<String>;
}
11
12pub struct TestEnv {
13    env_vars: HashMap<String, String>
14}
15
16impl TestEnv {
17    pub fn builder() -> TestEnvBuilder {
18        TestEnvBuilder {
19            env_vars: HashMap::new()
20        }
21    }
22}
23
24impl ReadEnv for TestEnv {
25    fn read_env(&self, var_name: &String) -> Option<String> {
26        self.env_vars.get(var_name).cloned()
27    }
28}
29
30pub struct TestEnvBuilder {
31    env_vars: HashMap<String, String>
32}
33
34impl TestEnvBuilder {
35    pub fn set(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
36        self.env_vars.insert(name.into(), value.into());
37        self
38    }
39
40    pub fn build(self) -> TestEnv {
41        TestEnv {
42            env_vars: self.env_vars
43        }
44    }
45}
46
47struct MacroInput {
48    env_var_name: syn::LitStr,
49    default_value: syn::Expr,
50}
51
52impl syn::parse::Parse for MacroInput {
53    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
54        let args = syn::punctuated::Punctuated::<syn::Expr, syn::token::Comma>::parse_terminated(input)?;
55        if args.len() != 2 {
56            return Err(syn::Error::new(input.span(), "Exactly 2 arguments expected"));
57        }
58        let env_var_name = match args.first().unwrap() {
59            Expr::Lit(ExprLit { lit: syn::Lit::Str(lit_str), .. }) => {
60                lit_str.clone()
61            },
62            otherwise => return Err(syn::Error::new(otherwise.span(), "Expected first argument to be a string literal"))
63        };
64        let default_value = args.last().unwrap().clone();
65        Ok(Self {
66            env_var_name,
67            default_value
68        })
69    }
70}
71
72/// Include environment variable contents as a Rust literal.
73pub fn env_lit(tokens: TokenStream, read_env: impl ReadEnv) -> TokenStream {
74    let input: MacroInput = match syn::parse2(tokens) {
75        Ok(input) => input,
76        Err(err) => return err.to_compile_error()
77    };
78    let env_var_value = match read_env.read_env(&input.env_var_name.value()) {
79        Some(env_var_value) => env_var_value,
80        None => return input.default_value.into_token_stream()
81    };
82    let env_var_value_tokens = match env_var_value.parse::<TokenStream>() {
83        Ok(tokens) => tokens,
84        Err(err) => return syn::Error::new(input.env_var_name.span(), format!("{}", err)).to_compile_error()
85    };
86    // Special case logic for quoted literals such as strings. We want to allow users not
87    // to need to quote their environment variable values for strings, even though this is
88    // technically inconsistent with other literals. So here we handle string-like literals
89    // specially by auto-adding quotes. Note that we only do this for top-level string literals -
90    // other instances of literal strings, such as elements of an array, require the user to
91    // manually quote them.
92    match input.default_value {
93        syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(_), ..}) => {
94            let quoted = format!("\"{}\"", env_var_value);
95            match syn::parse_str::<syn::LitStr>(&quoted) {
96                Ok(literal) => literal.to_token_stream(),
97                Err(err) => syn::Error::new(input.env_var_name.span(), format!("Invalid string literal contents: {}", err)).to_compile_error()
98            }
99        }
100        syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::ByteStr(_), ..}) => {
101            let quoted = format!("b\"{}\"", env_var_value);
102            match syn::parse_str::<syn::LitByteStr>(&quoted) {
103                Ok(literal) => literal.to_token_stream(),
104                Err(err) => syn::Error::new(input.env_var_name.span(), format!("Invalid byte string literal contents: {}", err)).to_compile_error()
105            }
106        }
107        syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Char(_), ..}) => {
108            let quoted = format!("'{}'", env_var_value);
109            match syn::parse_str::<syn::LitChar>(&quoted) {
110                Ok(literal) => literal.to_token_stream(),
111                Err(err) => syn::Error::new(input.env_var_name.span(), format!("Invalid char literal contents: {}", err)).to_compile_error()
112            }
113        }
114        syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Byte(_), ..}) => {
115            let quoted = format!("b'{}'", env_var_value);
116            match syn::parse_str::<syn::LitByte>(&quoted) {
117                Ok(literal) => literal.to_token_stream(),
118                Err(err) => syn::Error::new(input.env_var_name.span(), format!("Invalid byte literal contents: {}", err)).to_compile_error()
119            }
120        }
121        _ => env_var_value_tokens
122    }
123}
124
125/// Inner implementation details of `const_env::env_item`.
126pub fn env_item(attr: TokenStream, item: TokenStream, read_env: impl ReadEnv) -> TokenStream {
127    match try_env_item(attr, item, read_env) {
128        Ok(tokens) => tokens,
129        Err(err) => err.into_compile_error()
130    }
131}
132
133fn try_env_item(attr: TokenStream, item: TokenStream, read_env: impl ReadEnv) -> Result<TokenStream, syn::Error> {
134    if let Ok(mut item_const) = syn::parse2::<syn::ItemConst>(item.clone()) {
135        let default_var_name = format!("{}", item_const.ident);
136        let var_name = extract_var_name(attr, default_var_name)?;
137        let var_value = match read_env.read_env(&var_name) {
138            Some(val) => val,
139            None => return Ok(item)
140        };
141        let new_expr = value_to_literal(&var_value, &item_const.expr)?;
142        let span = item_const.span();
143        item_const.expr = Box::new(new_expr);
144        Ok(quote_spanned!(span => #item_const))
145    } else if let Ok(mut item_static) = syn::parse2::<syn::ItemStatic>(item.clone()) {
146        let default_var_name = format!("{}", item_static.ident);
147        let var_name = extract_var_name(attr, default_var_name)?;
148        let var_value = match read_env.read_env(&var_name) {
149            Some(val) => val,
150            None => return Ok(item)
151        };
152        let new_expr = value_to_literal(&var_value, &item_static.expr)?;
153        let span = item_static.span();
154        item_static.expr = Box::new(new_expr);
155        Ok(quote_spanned!(span => #item_static))
156    } else {
157        Err(syn::Error::new(attr.span(), "Macro is only valid on const or static items"))
158    }
159}
160
161fn extract_var_name(attr: TokenStream, default: String) -> Result<String, syn::Error> {
162    if attr.is_empty() {
163        return Ok(default);
164    }
165    let span = attr.span();
166    let expr: Expr = syn::parse2(attr)
167        .map_err(|_| syn::Error::new(span,"Unable to parse attribute args as expression"))?;
168    extract_var_name_from_expr(&expr)
169}
170
171fn extract_var_name_from_expr(expr: &Expr) -> Result<String, syn::Error> {
172    match expr {
173        Expr::Lit(literal) => {
174            match &literal.lit {
175                Lit::Str(lit_str) => {
176                    Ok(lit_str.value())
177                },
178                _ => Err(syn::Error::new_spanned(expr, "Attribute arguments are not a valid string literal"))
179            }
180        },
181        Expr::Paren(paren) => {
182            extract_var_name_from_expr(&paren.expr)
183        },
184        _ => {
185            Err(syn::Error::new_spanned(expr, "Attribute arguments are not a valid string literal expression"))
186        }
187    }
188}
189
190fn value_to_literal(value: &str, original_expr: &Expr) -> Result<Expr, syn::Error> {
191    Ok(match original_expr {
192        Expr::Array(array) => {
193            syn::Expr::Array(syn::parse_str::<syn::ExprArray>(value)
194                .map_err(|_| syn::Error::new_spanned(array, "Failed to parse environment variable contents as valid array"))?)
195        },
196        Expr::Unary(unary) => {
197            // A unary sign indicates this is a numeric literal which doesn't need any
198            // escaping, so we can parse it directly.
199            let new: Expr = syn::parse_str(value)
200                .map_err(|_| syn::Error::new_spanned(unary, "Failed to parse environment variable contents as valid expression"))?;
201            return Ok(new);
202        },
203        Expr::Lit(literal) => {
204            let new_lit = match &literal.lit {
205                Lit::Str(original) => {
206                    let mut new: syn::LitStr = syn::parse_str(&format!("\"{}\"", value))
207                        .map_err(|_| syn::Error::new_spanned(original, "Failed to parse environment variable contents as literal string"))?;
208                    new.set_span(original.span());
209                    Lit::Str(new)
210                },
211                Lit::ByteStr(original) => {
212                    let mut new: syn::LitByteStr = syn::parse_str(&format!("b\"{}\"", value))
213                        .map_err(|_| syn::Error::new_spanned(original, "Failed to parse environment variable contents as literal byte string"))?;
214                    new.set_span(original.span());
215                    Lit::ByteStr(new)
216                },
217                Lit::Byte(original) => {
218                    let mut new: syn::LitByte = syn::parse_str(&format!("b'{}'", value))
219                        .map_err(|_| syn::Error::new_spanned(original, "Failed to parse environment variable contents as literal byte"))?;
220                    new.set_span(original.span());
221                    Lit::Byte(new)
222                },
223                Lit::Char(original) => {
224                    let mut new: syn::LitChar = syn::parse_str(&format!("'{}'", value))
225                        .map_err(|_| syn::Error::new_spanned(original, "Failed to parse environment variable contents as literal character"))?;
226                    new.set_span(original.span());
227                    Lit::Char(new)
228                },
229                // These variants do not need any escaping and can be parsed as an expression
230                // directly.
231                Lit::Bool(_) | Lit::Int(_) | Lit::Float(_) | Lit::Verbatim(_) => {
232                    let new: Expr = syn::parse_str(value)
233                        .map_err(|_| syn::Error::new_spanned(original_expr, "Failed to parse environment variable contents as valid expression"))?;
234                    return Ok(new);
235                },
236                unhandled => {
237                    return Err(syn::Error::new_spanned(unhandled, "Unsupported literal type"));
238                }
239            };
240            ExprLit {
241                attrs: literal.attrs.clone(),
242                lit: new_lit
243            }.into()
244        },
245        Expr::Struct(_) => {
246            return Ok(syn::parse_str(value)?)
247        }
248        expr => {
249            return Err(syn::Error::new_spanned(expr, "Original const expression was not a recognized literal expression"));
250        }
251    })
252}