1use std::collections::HashMap;
2
3use proc_macro2::TokenStream;
4use quote::{ToTokens, quote_spanned};
5use syn::{Expr, ExprLit, Lit};
6use syn::spanned::Spanned;
7
/// Source of environment variables consulted during macro expansion.
///
/// [`env_lit`] and [`env_item`] look variables up only through this trait, so any
/// implementation can stand in for the real environment; [`TestEnv`] is an in-memory one.
pub trait ReadEnv {
    /// Returns the value of the variable called `var_name`, or `None` when it is not set.
    ///
    /// On `None` the expansion functions keep the default / original expression.
    fn read_env(&self, var_name: &String) -> Option<String>;
}
11
12pub struct TestEnv {
13 env_vars: HashMap<String, String>
14}
15
16impl TestEnv {
17 pub fn builder() -> TestEnvBuilder {
18 TestEnvBuilder {
19 env_vars: HashMap::new()
20 }
21 }
22}
23
24impl ReadEnv for TestEnv {
25 fn read_env(&self, var_name: &String) -> Option<String> {
26 self.env_vars.get(var_name).cloned()
27 }
28}
29
30pub struct TestEnvBuilder {
31 env_vars: HashMap<String, String>
32}
33
34impl TestEnvBuilder {
35 pub fn set(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
36 self.env_vars.insert(name.into(), value.into());
37 self
38 }
39
40 pub fn build(self) -> TestEnv {
41 TestEnv {
42 env_vars: self.env_vars
43 }
44 }
45}
46
47struct MacroInput {
48 env_var_name: syn::LitStr,
49 default_value: syn::Expr,
50}
51
52impl syn::parse::Parse for MacroInput {
53 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
54 let args = syn::punctuated::Punctuated::<syn::Expr, syn::token::Comma>::parse_terminated(input)?;
55 if args.len() != 2 {
56 return Err(syn::Error::new(input.span(), "Exactly 2 arguments expected"));
57 }
58 let env_var_name = match args.first().unwrap() {
59 Expr::Lit(ExprLit { lit: syn::Lit::Str(lit_str), .. }) => {
60 lit_str.clone()
61 },
62 otherwise => return Err(syn::Error::new(otherwise.span(), "Expected first argument to be a string literal"))
63 };
64 let default_value = args.last().unwrap().clone();
65 Ok(Self {
66 env_var_name,
67 default_value
68 })
69 }
70}
71
72pub fn env_lit(tokens: TokenStream, read_env: impl ReadEnv) -> TokenStream {
74 let input: MacroInput = match syn::parse2(tokens) {
75 Ok(input) => input,
76 Err(err) => return err.to_compile_error()
77 };
78 let env_var_value = match read_env.read_env(&input.env_var_name.value()) {
79 Some(env_var_value) => env_var_value,
80 None => return input.default_value.into_token_stream()
81 };
82 let env_var_value_tokens = match env_var_value.parse::<TokenStream>() {
83 Ok(tokens) => tokens,
84 Err(err) => return syn::Error::new(input.env_var_name.span(), format!("{}", err)).to_compile_error()
85 };
86 match input.default_value {
93 syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(_), ..}) => {
94 let quoted = format!("\"{}\"", env_var_value);
95 match syn::parse_str::<syn::LitStr>("ed) {
96 Ok(literal) => literal.to_token_stream(),
97 Err(err) => syn::Error::new(input.env_var_name.span(), format!("Invalid string literal contents: {}", err)).to_compile_error()
98 }
99 }
100 syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::ByteStr(_), ..}) => {
101 let quoted = format!("b\"{}\"", env_var_value);
102 match syn::parse_str::<syn::LitByteStr>("ed) {
103 Ok(literal) => literal.to_token_stream(),
104 Err(err) => syn::Error::new(input.env_var_name.span(), format!("Invalid byte string literal contents: {}", err)).to_compile_error()
105 }
106 }
107 syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Char(_), ..}) => {
108 let quoted = format!("'{}'", env_var_value);
109 match syn::parse_str::<syn::LitChar>("ed) {
110 Ok(literal) => literal.to_token_stream(),
111 Err(err) => syn::Error::new(input.env_var_name.span(), format!("Invalid char literal contents: {}", err)).to_compile_error()
112 }
113 }
114 syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Byte(_), ..}) => {
115 let quoted = format!("b'{}'", env_var_value);
116 match syn::parse_str::<syn::LitByte>("ed) {
117 Ok(literal) => literal.to_token_stream(),
118 Err(err) => syn::Error::new(input.env_var_name.span(), format!("Invalid byte literal contents: {}", err)).to_compile_error()
119 }
120 }
121 _ => env_var_value_tokens
122 }
123}
124
125pub fn env_item(attr: TokenStream, item: TokenStream, read_env: impl ReadEnv) -> TokenStream {
127 match try_env_item(attr, item, read_env) {
128 Ok(tokens) => tokens,
129 Err(err) => err.into_compile_error()
130 }
131}
132
133fn try_env_item(attr: TokenStream, item: TokenStream, read_env: impl ReadEnv) -> Result<TokenStream, syn::Error> {
134 if let Ok(mut item_const) = syn::parse2::<syn::ItemConst>(item.clone()) {
135 let default_var_name = format!("{}", item_const.ident);
136 let var_name = extract_var_name(attr, default_var_name)?;
137 let var_value = match read_env.read_env(&var_name) {
138 Some(val) => val,
139 None => return Ok(item)
140 };
141 let new_expr = value_to_literal(&var_value, &item_const.expr)?;
142 let span = item_const.span();
143 item_const.expr = Box::new(new_expr);
144 Ok(quote_spanned!(span => #item_const))
145 } else if let Ok(mut item_static) = syn::parse2::<syn::ItemStatic>(item.clone()) {
146 let default_var_name = format!("{}", item_static.ident);
147 let var_name = extract_var_name(attr, default_var_name)?;
148 let var_value = match read_env.read_env(&var_name) {
149 Some(val) => val,
150 None => return Ok(item)
151 };
152 let new_expr = value_to_literal(&var_value, &item_static.expr)?;
153 let span = item_static.span();
154 item_static.expr = Box::new(new_expr);
155 Ok(quote_spanned!(span => #item_static))
156 } else {
157 Err(syn::Error::new(attr.span(), "Macro is only valid on const or static items"))
158 }
159}
160
161fn extract_var_name(attr: TokenStream, default: String) -> Result<String, syn::Error> {
162 if attr.is_empty() {
163 return Ok(default);
164 }
165 let span = attr.span();
166 let expr: Expr = syn::parse2(attr)
167 .map_err(|_| syn::Error::new(span,"Unable to parse attribute args as expression"))?;
168 extract_var_name_from_expr(&expr)
169}
170
171fn extract_var_name_from_expr(expr: &Expr) -> Result<String, syn::Error> {
172 match expr {
173 Expr::Lit(literal) => {
174 match &literal.lit {
175 Lit::Str(lit_str) => {
176 Ok(lit_str.value())
177 },
178 _ => Err(syn::Error::new_spanned(expr, "Attribute arguments are not a valid string literal"))
179 }
180 },
181 Expr::Paren(paren) => {
182 extract_var_name_from_expr(&paren.expr)
183 },
184 _ => {
185 Err(syn::Error::new_spanned(expr, "Attribute arguments are not a valid string literal expression"))
186 }
187 }
188}
189
190fn value_to_literal(value: &str, original_expr: &Expr) -> Result<Expr, syn::Error> {
191 Ok(match original_expr {
192 Expr::Array(array) => {
193 syn::Expr::Array(syn::parse_str::<syn::ExprArray>(value)
194 .map_err(|_| syn::Error::new_spanned(array, "Failed to parse environment variable contents as valid array"))?)
195 },
196 Expr::Unary(unary) => {
197 let new: Expr = syn::parse_str(value)
200 .map_err(|_| syn::Error::new_spanned(unary, "Failed to parse environment variable contents as valid expression"))?;
201 return Ok(new);
202 },
203 Expr::Lit(literal) => {
204 let new_lit = match &literal.lit {
205 Lit::Str(original) => {
206 let mut new: syn::LitStr = syn::parse_str(&format!("\"{}\"", value))
207 .map_err(|_| syn::Error::new_spanned(original, "Failed to parse environment variable contents as literal string"))?;
208 new.set_span(original.span());
209 Lit::Str(new)
210 },
211 Lit::ByteStr(original) => {
212 let mut new: syn::LitByteStr = syn::parse_str(&format!("b\"{}\"", value))
213 .map_err(|_| syn::Error::new_spanned(original, "Failed to parse environment variable contents as literal byte string"))?;
214 new.set_span(original.span());
215 Lit::ByteStr(new)
216 },
217 Lit::Byte(original) => {
218 let mut new: syn::LitByte = syn::parse_str(&format!("b'{}'", value))
219 .map_err(|_| syn::Error::new_spanned(original, "Failed to parse environment variable contents as literal byte"))?;
220 new.set_span(original.span());
221 Lit::Byte(new)
222 },
223 Lit::Char(original) => {
224 let mut new: syn::LitChar = syn::parse_str(&format!("'{}'", value))
225 .map_err(|_| syn::Error::new_spanned(original, "Failed to parse environment variable contents as literal character"))?;
226 new.set_span(original.span());
227 Lit::Char(new)
228 },
229 Lit::Bool(_) | Lit::Int(_) | Lit::Float(_) | Lit::Verbatim(_) => {
232 let new: Expr = syn::parse_str(value)
233 .map_err(|_| syn::Error::new_spanned(original_expr, "Failed to parse environment variable contents as valid expression"))?;
234 return Ok(new);
235 },
236 unhandled => {
237 return Err(syn::Error::new_spanned(unhandled, "Unsupported literal type"));
238 }
239 };
240 ExprLit {
241 attrs: literal.attrs.clone(),
242 lit: new_lit
243 }.into()
244 },
245 Expr::Struct(_) => {
246 return Ok(syn::parse_str(value)?)
247 }
248 expr => {
249 return Err(syn::Error::new_spanned(expr, "Original const expression was not a recognized literal expression"));
250 }
251 })
252}