extern crate proc_macro;
use proc_macro::TokenStream;
use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::visit::Visit;
use syn::visit_mut::{visit_expr_mut, VisitMut};
use syn::{
parse_macro_input, Expr, ExprAssign, ExprLit, ExprPath, Item, Lit, LitBool, Macro, Token,
};
use quote::{quote, ToTokens};
/// Visitor that rewrites both float and integer literals, using a separate replacement
/// template for each class of literal.
struct NumericLiteralVisitor<'a> {
    /// Options parsed from the attribute arguments (e.g. whether macro bodies are visited).
    pub parameters: MacroParameters,
    /// Identifier that, inside a template, stands for the literal being replaced.
    pub placeholder: &'a str,
    /// Template substituted for integer literals.
    pub int_replacement: &'a Expr,
    /// Template substituted for float literals.
    pub float_replacement: &'a Expr,
}
/// Visitor that rewrites only float literals (including integer-looking literals with an
/// `f32`/`f64` suffix) using a single replacement template.
struct FloatLiteralVisitor<'a> {
    /// Template substituted for each float literal.
    pub replacement: &'a Expr,
    /// Identifier that, inside the template, stands for the literal being replaced.
    pub placeholder: &'a str,
    /// Options parsed from the attribute arguments (e.g. whether macro bodies are visited).
    pub parameters: MacroParameters,
}
/// Visitor that rewrites only integer literals (those without an `f32`/`f64` suffix) using a
/// single replacement template.
struct IntLiteralVisitor<'a> {
    /// Template substituted for each integer literal.
    pub replacement: &'a Expr,
    /// Identifier that, inside the template, stands for the literal being replaced.
    pub placeholder: &'a str,
    /// Options parsed from the attribute arguments (e.g. whether macro bodies are visited).
    pub parameters: MacroParameters,
}
/// Coarse classification of a literal expression, deciding which replacement applies.
enum PrimitiveClass {
    /// A floating-point literal, or an integer literal with an `f32`/`f64` suffix.
    Float,
    /// An integer literal without a float suffix.
    Int,
    /// Any other literal (string, char, bool, byte, ...); left untouched.
    Other,
}
/// Classifies a literal expression as float, integer, or something else.
///
/// An integer literal carrying an `f32` or `f64` suffix is classified as a float.
fn determine_primitive_class(lit_expr: &ExprLit) -> PrimitiveClass {
    match &lit_expr.lit {
        Lit::Float(_) => PrimitiveClass::Float,
        Lit::Int(int_lit) => {
            let suffix = int_lit.suffix();
            if suffix == "f32" || suffix == "f64" {
                PrimitiveClass::Float
            } else {
                PrimitiveClass::Int
            }
        }
        _ => PrimitiveClass::Other,
    }
}
/// Substitutes `literal` for every occurrence of the `placeholder` identifier inside `expr`
/// (typically a freshly cloned replacement template).
fn replace_literal(expr: &mut Expr, placeholder: &str, literal: &ExprLit) {
    ReplacementExpressionVisitor { placeholder, literal }.visit_expr_mut(expr);
}
/// Attempts to parse a macro body as a list of expressions separated by `P`.
///
/// On success, every expression is passed through `visitor`, the macro's tokens are replaced
/// by the rewritten list, and `true` is returned. If the body does not parse, the macro is
/// left untouched and `false` is returned.
fn try_parse_punctuated_macro<P: ToTokens, V: VisitMut, F: Parser<Output = Punctuated<Expr, P>>>(
    visitor: &mut V,
    mac: &mut Macro,
    parser: F,
) -> bool {
    match mac.parse_body_with(parser) {
        Ok(mut list) => {
            for item in list.iter_mut() {
                visitor.visit_expr_mut(item);
            }
            mac.tokens = list.into_token_stream();
            true
        }
        Err(_) => false,
    }
}
/// Best-effort visit of a macro invocation's body.
///
/// The body is tried, in order, as a single expression, as comma-separated expressions, and
/// as semicolon-separated expressions. The first interpretation that parses is visited and
/// written back into the macro. Bodies matching none of these shapes are left unchanged.
fn visit_macros_mut<V: VisitMut>(visitor: &mut V, mac: &mut Macro) {
    match mac.parse_body::<Expr>() {
        Ok(mut single) => {
            visitor.visit_expr_mut(&mut single);
            mac.tokens = single.into_token_stream();
        }
        Err(_) => {
            let comma_list = Punctuated::<Expr, Token![,]>::parse_terminated;
            let semicolon_list = Punctuated::<Expr, Token![;]>::parse_terminated;
            if !try_parse_punctuated_macro(visitor, mac, comma_list) {
                try_parse_punctuated_macro(visitor, mac, semicolon_list);
            }
        }
    }
}
impl<'a> VisitMut for FloatLiteralVisitor<'a> {
    /// Replaces a float literal with a copy of the template (placeholder filled in with the
    /// literal). Any other expression is recursed into.
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        let rewritten = match expr {
            Expr::Lit(lit_expr) => match determine_primitive_class(lit_expr) {
                PrimitiveClass::Float => {
                    let mut candidate = self.replacement.clone();
                    replace_literal(&mut candidate, self.placeholder, lit_expr);
                    Some(candidate)
                }
                PrimitiveClass::Int | PrimitiveClass::Other => None,
            },
            _ => None,
        };
        match rewritten {
            Some(new_expr) => *expr = new_expr,
            None => visit_expr_mut(self, expr),
        }
    }

    /// Visits macro bodies only when the `visit_macros` parameter is enabled.
    fn visit_macro_mut(&mut self, mac: &mut Macro) {
        if !self.parameters.visit_macros {
            return;
        }
        visit_macros_mut(self, mac);
    }
}
impl<'a> VisitMut for IntLiteralVisitor<'a> {
    /// Replaces an integer literal with a copy of the template (placeholder filled in with the
    /// literal). Any other expression is recursed into.
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        let rewritten = match expr {
            Expr::Lit(lit_expr) => match determine_primitive_class(lit_expr) {
                PrimitiveClass::Int => {
                    let mut candidate = self.replacement.clone();
                    replace_literal(&mut candidate, self.placeholder, lit_expr);
                    Some(candidate)
                }
                PrimitiveClass::Float | PrimitiveClass::Other => None,
            },
            _ => None,
        };
        match rewritten {
            Some(new_expr) => *expr = new_expr,
            None => visit_expr_mut(self, expr),
        }
    }

    /// Visits macro bodies only when the `visit_macros` parameter is enabled.
    fn visit_macro_mut(&mut self, mac: &mut Macro) {
        if !self.parameters.visit_macros {
            return;
        }
        visit_macros_mut(self, mac);
    }
}
impl<'a> VisitMut for NumericLiteralVisitor<'a> {
fn visit_expr_mut(&mut self, expr: &mut Expr) {
if let Expr::Lit(lit_expr) = expr {
match determine_primitive_class(&lit_expr) {
PrimitiveClass::Float => {
let mut visitor = FloatLiteralVisitor {
parameters: self.parameters,
placeholder: self.placeholder,
replacement: self.float_replacement,
};
visitor.visit_expr_mut(expr);
return;
}
PrimitiveClass::Int => {
let mut visitor = IntLiteralVisitor {
parameters: self.parameters,
placeholder: self.placeholder,
replacement: self.int_replacement,
};
visitor.visit_expr_mut(expr);
return;
}
_ => {}
}
}
visit_expr_mut(self, expr)
}
fn visit_macro_mut(&mut self, mac: &mut Macro) {
if self.parameters.visit_macros {
visit_macros_mut(self, mac);
}
}
}
/// Visitor that fills a replacement template by substituting a literal for the placeholder.
struct ReplacementExpressionVisitor<'a> {
    /// The literal inserted wherever the placeholder appears.
    pub literal: &'a ExprLit,
    /// Identifier that marks where the literal goes.
    pub placeholder: &'a str,
}
impl<'a> VisitMut for ReplacementExpressionVisitor<'a> {
    /// Replaces every bare `placeholder` identifier expression with a copy of the literal and
    /// recurses into all other expressions.
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if let Expr::Path(path_expr) = expr {
            // Only an unqualified, single-segment identifier (e.g. `literal`) is the
            // placeholder. Matching just the last path segment would also rewrite paths such
            // as `Wrapper::literal`, `::literal` or `<T>::literal` and silently drop their
            // prefix (a template `Wrapper::literal(literal)` would become `1.0(1.0)`).
            if path_expr.qself.is_none() && path_expr.path.is_ident(self.placeholder) {
                *expr = Expr::Lit(self.literal.clone());
                return;
            }
        }
        visit_expr_mut(self, expr)
    }
}
/// Read-only visitor that extracts a `name = value` parameter from an attribute argument.
struct MacroParameterVisitor {
    /// The last boolean literal encountered, if any.
    pub value: Option<ParameterValue>,
    /// Text of the last path expression encountered (segments joined by `::`), if any.
    pub name: Option<String>,
}
impl MacroParameterVisitor {
    /// Parses one attribute argument such as `visit_macros = false`.
    ///
    /// Returns the parameter name and value only when both a path and a boolean literal were
    /// found in `expr`; otherwise returns `None`.
    fn parse_flag(expr: &Expr) -> Option<(String, ParameterValue)> {
        let mut visitor = MacroParameterVisitor {
            name: None,
            value: None,
        };
        visitor.visit_expr(expr);
        // The visitor is owned here, so both fields can be moved out directly.
        Some((visitor.name?, visitor.value?))
    }
}
impl<'ast> Visit<'ast> for MacroParameterVisitor {
    /// Visits both sides of an assignment: the left side supplies the name and the right side
    /// the value. A path on the right overwrites the name recorded from the left.
    fn visit_expr_assign(&mut self, expr: &'ast ExprAssign) {
        self.visit_expr(&expr.left);
        self.visit_expr(&expr.right);
    }

    /// Records the path as text, e.g. `visit_macros` or `::a::b`.
    fn visit_expr_path(&mut self, expr: &'ast ExprPath) {
        let path = &expr.path;
        let mut name = String::new();
        if path.leading_colon.is_some() {
            name.push_str("::");
        }
        for pair in path.segments.pairs() {
            name.push_str(&pair.value().ident.to_string());
            if pair.punct().is_some() {
                name.push_str("::");
            }
        }
        self.name = Some(name);
    }

    /// Records a `true`/`false` literal as the parameter value.
    fn visit_lit_bool(&mut self, expr: &'ast LitBool) {
        self.value = Some(ParameterValue::Bool(expr.value));
    }
}
/// Value of a parsed attribute parameter; only boolean literals are currently recognised.
enum ParameterValue {
    /// A `true` or `false` literal.
    Bool(bool),
}
/// Options controlling how the literal-replacing visitors traverse an item.
#[derive(Clone, Copy)]
struct MacroParameters {
    /// When `true`, macro invocation bodies are re-parsed as expressions and visited as well.
    pub visit_macros: bool,
}
impl Default for MacroParameters {
fn default() -> Self {
Self { visit_macros: true }
}
}
impl MacroParameters {
    /// Applies a named parameter; names other than `visit_macros` are silently ignored.
    fn set(&mut self, name: &str, value: ParameterValue) {
        if name == "visit_macros" {
            let ParameterValue::Bool(flag) = value;
            self.visit_macros = flag;
        }
    }
}
fn parse_macro_attribute(attr: TokenStream) -> Result<(Expr, MacroParameters), syn::Error> {
let parser = Punctuated::<Expr, Token![,]>::parse_separated_nonempty;
let attributes = parser.parse(attr)?;
let mut attr_iter = attributes.into_iter();
let replacement = attr_iter.next().expect("No replacement provided");
let user_parameters: Vec<_> = attr_iter
.filter_map(|expr| MacroParameterVisitor::parse_flag(&expr))
.collect();
let mut parameters = MacroParameters::default();
for (name, value) in user_parameters {
parameters.set(&name, value);
}
Ok((replacement, parameters))
}
/// Attribute macro that rewrites every integer and float literal in the annotated item.
///
/// The first attribute argument is the replacement expression. Inside it, the identifier
/// `literal` stands for the original literal. Further `name = bool` arguments set parameters
/// (currently `visit_macros`). Integer and float literals share the same replacement.
/// Malformed input is reported as a compile error.
#[proc_macro_attribute]
pub fn replace_numeric_literals(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as Item);
    match parse_macro_attribute(attr) {
        Err(err) => TokenStream::from(err.to_compile_error()),
        Ok((replacement, parameters)) => {
            NumericLiteralVisitor {
                parameters,
                placeholder: "literal",
                float_replacement: &replacement,
                int_replacement: &replacement,
            }
            .visit_item_mut(&mut input);
            TokenStream::from(quote! { #input })
        }
    }
}
/// Attribute macro that rewrites every float literal in the annotated item.
///
/// The first attribute argument is the replacement expression. Inside it, the identifier
/// `literal` stands for the original literal. Further `name = bool` arguments set parameters
/// (currently `visit_macros`). Malformed input is reported as a compile error.
#[proc_macro_attribute]
pub fn replace_float_literals(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as Item);
    match parse_macro_attribute(attr) {
        Err(err) => TokenStream::from(err.to_compile_error()),
        Ok((replacement, parameters)) => {
            FloatLiteralVisitor {
                replacement: &replacement,
                placeholder: "literal",
                parameters,
            }
            .visit_item_mut(&mut input);
            TokenStream::from(quote! { #input })
        }
    }
}
/// Attribute macro that rewrites every integer literal in the annotated item.
///
/// The first attribute argument is the replacement expression. Inside it, the identifier
/// `literal` stands for the original literal. Further `name = bool` arguments set parameters
/// (currently `visit_macros`). Malformed input is reported as a compile error.
#[proc_macro_attribute]
pub fn replace_int_literals(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as Item);
    match parse_macro_attribute(attr) {
        Err(err) => TokenStream::from(err.to_compile_error()),
        Ok((replacement, parameters)) => {
            IntLiteralVisitor {
                replacement: &replacement,
                placeholder: "literal",
                parameters,
            }
            .visit_item_mut(&mut input);
            TokenStream::from(quote! { #input })
        }
    }
}