//!
//! **numeric_literals** is a Rust library that provides procedural attribute macros for replacing
//! numeric literals with arbitrary expressions.
//!
//! While Rust's explicitness is generally a boon, it is a major pain when writing numeric
//! code that is intended to be generic over a scalar type. As an example, consider
//! writing a function that returns the golden ratio for any type `T: num::Float`.
//! An implementation might look like the following.
//!
//! ```rust
//! use num::Float;
//!
//! fn golden_ratio<T: Float>() -> T {
//!     (T::one() + T::sqrt(T::from(5).unwrap())) / T::from(2).unwrap()
//! }
//! ```
//!
//! This is arguably very messy for such a simple task. With `numeric_literals`, we may
//! instead write:
//!
//! ```rust
//! # use num::Float;
//! use numeric_literals::replace_numeric_literals;
//!
//! #[replace_numeric_literals(T::from(literal).unwrap())]
//! fn golden_ratio<T: Float>() -> T {
//!     (1 + 5.sqrt()) / 2
//! }
//! ```
//!
//! The above two code segments do essentially the same thing
//! (apart from using `T::from(1)` instead of `T::one()`). However, in the latter example,
//! the `replace_numeric_literals` attribute replaces any numeric literal with the expression
//! `T::from(literal).unwrap()`, where `literal` is a placeholder that stands for each individual
//! literal: `5`, for instance, becomes `T::from(5).unwrap()`.
//!
//! There is no magic involved: the code is still explicit about what it does to numeric literals.
//! The difference is that we can declare this behavior once for all numeric literals. Moreover,
//! we move the conversion behavior away from where the literals are needed, enhancing readability
//! by reducing the noise imposed by being explicit about the exact types involved.
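//!
//! As an illustration, the following sketch shows roughly what the annotated `golden_ratio`
//! looks like after expansion. So that it only needs `std`, it swaps `num::Float` for a
//! hypothetical `Scalar` trait and `T::from` for a hypothetical `Scalar::from_i32`, i.e. it
//! corresponds to the attribute `#[replace_numeric_literals(T::from_i32(literal).unwrap())]`.
//!
//! ```rust
//! use std::ops::{Add, Div};
//!
//! // Hypothetical stand-in for `num::Float`; not part of this crate.
//! trait Scalar: Copy + Add<Output = Self> + Div<Output = Self> {
//!     fn from_i32(n: i32) -> Option<Self>;
//!     fn sqrt(self) -> Self;
//! }
//!
//! impl Scalar for f64 {
//!     fn from_i32(n: i32) -> Option<Self> {
//!         Some(f64::from(n))
//!     }
//!     fn sqrt(self) -> Self {
//!         f64::sqrt(self) // the inherent method, not `Scalar::sqrt`
//!     }
//! }
//!
//! impl Scalar for f32 {
//!     fn from_i32(n: i32) -> Option<Self> {
//!         Some(n as f32)
//!     }
//!     fn sqrt(self) -> Self {
//!         f32::sqrt(self)
//!     }
//! }
//!
//! // Expanded form of `(1 + 5.sqrt()) / 2`: each integer literal has taken
//! // the place of `literal` in the replacement expression.
//! fn golden_ratio<T: Scalar>() -> T {
//!     (T::from_i32(1).unwrap() + T::from_i32(5).unwrap().sqrt()) / T::from_i32(2).unwrap()
//! }
//! ```
//!
//! With this sketch, `golden_ratio::<f64>()` evaluates to approximately `1.618`.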
//!
//! Float and integer literal replacement
//! -------------------------------------
//!
//! An issue with the replacement of numeric literals is that there is no way to distinguish
//! literals that are used for e.g. indexing from those that are part of a numerical computation.
//! In the example above, if you also needed to index into an array with a constant index
//! such as `array[0]`, the macro would try to convert the index `0` into the scalar type `T`,
//! which would clearly fail. Fortunately, in most cases such code fails to compile outright
//! because of a type mismatch. One possible resolution to this problem is to use the separate
//! macros `replace_float_literals` and `replace_int_literals`, which work in exactly the same way,
//! but only trigger on float or integer literals, respectively. Integer literals with an `f32` or
//! `f64` suffix, such as `2f64`, count as float literals. Below is an example from
//! finite element code that uses float literal replacement to improve the readability of numerical
//! constants in generic code.
//!
//! ```ignore
//! #[replace_float_literals(T::from_f64(literal).expect("Literal must fit in T"))]
//! pub fn assemble_element_mass<T>(quad: &Quad2d<T>) -> MatrixN<T, U8>
//! where
//!     T: RealField
//! {
//!     // Only the float literals are replaced; the indices in `xi[0]` and `xi[1]` are left as-is.
//!     let phi = |alpha, beta, xi: &Vector2<T>| (1.0 + alpha * xi[0]) * (1.0 + beta * xi[1]) / 4.0;
//!     let phi_grad = |alpha, beta, xi: &Vector2<T>| {
//!         Vector2::new(
//!             alpha * (1.0 + beta * xi[1]) / 4.0,
//!             beta * (1.0 + alpha * xi[0]) / 4.0,
//!         )
//!     };
//!     let alphas = [-1.0, 1.0, 1.0, -1.0];
//!     let betas = [-1.0, -1.0, 1.0, 1.0];
//!
//!     // And so on...
//! }
//! ```
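//!
//! To see why restricting the replacement to float literals helps, consider the following
//! sketch of what `#[replace_float_literals(T::from_f64(literal))]` would produce for a small
//! function. The trait `FromF64` and the function `midpoint` are hypothetical and use only `std`.
//! Because `from_f64` accepts only `f64`, wrapping the integer indices as well would not
//! type-check; with float-only replacement they remain plain `usize` indices.
//!
//! ```rust
//! use std::ops::{Add, Mul};
//!
//! // Hypothetical conversion trait that accepts only `f64` values.
//! trait FromF64 {
//!     fn from_f64(x: f64) -> Self;
//! }
//!
//! impl FromF64 for f64 {
//!     fn from_f64(x: f64) -> Self {
//!         x
//!     }
//! }
//!
//! impl FromF64 for f32 {
//!     fn from_f64(x: f64) -> Self {
//!         x as f32
//!     }
//! }
//!
//! // Expanded form of `0.5 * xs[0] + 0.5 * xs[1]`: the float literals `0.5`
//! // are wrapped, while the integer indices `0` and `1` are left untouched.
//! fn midpoint<T>(xs: &[T]) -> T
//! where
//!     T: Copy + FromF64 + Add<Output = T> + Mul<Output = T>,
//! {
//!     T::from_f64(0.5) * xs[0] + T::from_f64(0.5) * xs[1]
//! }
//! ```
//!
//! In this sketch, `midpoint(&[1.0_f64, 3.0])` returns `2.0`.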
//!
//! In general, **the macros should be used with caution**. It is recommended to keep the macro close to
//! the region in which the literals are being used, so as to avoid confusion for readers of the code.
//! The code before macro expansion is usually not valid Rust (because of the lack of explicit
//! type conversion), and without the context of the attribute, it is simply not clear why this
//! code still compiles.
//!
//! An option for the future would be to apply the attribute only to very local blocks of code that
//! are heavy on numerical constants. However, at present, stable Rust does not allow attribute macros
//! to apply to blocks or single expressions.
//!
//! Replacement in macro invocations
//! --------------------------------
//!
//! By default, the macros of this crate will also replace literals inside of macro invocations,
//! provided that the macro's arguments parse as a single expression, or as a list of expressions
//! separated by `,` or `;` (as in `assert!`, `assert_eq!` or `vec!`). The tokens of other macro
//! invocations are left untouched. This allows code such as the following to compile:
//!
//! ```rust
//! use num::Float;
//! use numeric_literals::replace_numeric_literals;
//!
//! #[replace_numeric_literals(T::from(literal).unwrap())]
//! fn zeros<T: Float>(n: usize) -> Vec<T> {
//!     vec![0.0; n]
//! }
//! ```
//!
//! If this behavior is unwanted, replacement inside of macros can be disabled by passing
//! `visit_macros = false` as an additional argument after the replacement expression:
//!
//! ```ignore
//! #[replace_numeric_literals(T::from(literal).unwrap(), visit_macros = false)]
//! ```
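//!
//! As a sketch of what happens inside the `vec!` invocation of `zeros` above, here is its
//! approximate expansion. It uses the hypothetical replacement expression `T::from(literal)`,
//! based on `std`'s `From<f32>`, in place of `T::from(literal).unwrap()` and `num::Float`.
//! Only the literal `0.0` is rewritten; the length `n` is not a literal and is left as-is.
//! With `visit_macros = false`, the `0.0` would remain a bare float literal and fail to
//! type-check against the generic element type `T`.
//!
//! ```rust
//! // Expanded form of `vec![0.0; n]` under the hypothetical attribute
//! // `#[replace_numeric_literals(T::from(literal))]`.
//! fn zeros<T: From<f32> + Clone>(n: usize) -> Vec<T> {
//!     vec![T::from(0.0); n]
//! }
//! ```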
//!
//! Literals with suffixes
//! ----------------------
//!
//! In Rust, literal suffixes can be used to disambiguate the type of a literal. For example, the suffix `_f64`
//! in the expression `1_f64.sqrt()` makes it clear that the value `1` is of type `f64`. Suffixed literals are
//! also replaced by the macros of this crate, for all floating-point and integer suffixes, and the literal is
//! inserted into the replacement expression with its suffix intact. For example, in the following function
//! `5f32` is replaced with `T::from(5f32).unwrap()`:
//!
//! ```rust
//! use num::Float;
//! use numeric_literals::replace_numeric_literals;
//!
//! #[replace_numeric_literals(T::from(literal).unwrap())]
//! fn golden_ratio<T: Float>() -> T {
//!     (1.0_f64 + 5f32.sqrt()) / 2.0
//! }
//! ```
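//!
//! Because the suffix is kept, it also determines the type that is converted from, which can
//! matter for precision. The following sketch illustrates this with a hypothetical function
//! `tenth`, hand-expanded for the replacement expression `T::from(literal)` using `std`'s
//! `From<f32>` (an assumption, in place of `num::Float`):
//!
//! ```rust
//! // Expanded form of a body consisting of the literal `0.1_f32` under the hypothetical
//! // attribute `#[replace_float_literals(T::from(literal))]`.
//! fn tenth<T: From<f32>>() -> T {
//!     // The suffixed literal is kept verbatim, so the value passes through `f32`.
//!     T::from(0.1_f32)
//! }
//! ```
//!
//! Here `tenth::<f32>()` is exactly `0.1_f32`, whereas `tenth::<f64>()` is that `f32` value widened
//! to `f64`, which differs from `0.1_f64`.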

extern crate proc_macro;
use proc_macro::TokenStream;

use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::visit::Visit;
use syn::visit_mut::{visit_expr_mut, VisitMut};
use syn::{
    parse_macro_input, Expr, ExprAssign, ExprLit, ExprPath, Item, Lit, LitBool, Macro, Token,
};

use quote::{quote, ToTokens};

/// Visits an expression and replaces any numeric literal with the corresponding
/// (float or integer) replacement expression, in which the placeholder identifier
/// is replaced with the numeric literal.
struct NumericLiteralVisitor<'a> {
    pub parameters: MacroParameters,
    pub placeholder: &'a str,
    pub float_replacement: &'a Expr,
    pub int_replacement: &'a Expr,
}

/// Like `NumericLiteralVisitor`, but only replaces float literals
/// (including integer literals with an `f32` or `f64` suffix).
struct FloatLiteralVisitor<'a> {
    pub parameters: MacroParameters,
    pub placeholder: &'a str,
    pub replacement: &'a Expr,
}

/// Like `NumericLiteralVisitor`, but only replaces integer literals.
struct IntLiteralVisitor<'a> {
    pub parameters: MacroParameters,
    pub placeholder: &'a str,
    pub replacement: &'a Expr,
}

/// Represents the classes of primitive types relevant to this crate.
enum PrimitiveClass {
    Float,
    Int,
    Other,
}

/// Returns the class of primitive type represented by this literal expression,
/// e.g. `20f64 -> Float`, `20 -> Int`.
fn determine_primitive_class(lit_expr: &ExprLit) -> PrimitiveClass {
    match &lit_expr.lit {
        // Parsed float literals are always floats
        Lit::Float(_) => PrimitiveClass::Float,
        // Literals like `20f64` are parsed as `LitInt`s
        Lit::Int(int_lit) if matches!(int_lit.suffix(), "f32" | "f64") => PrimitiveClass::Float,
        // All other integer literals should be actual integers
        Lit::Int(_) => PrimitiveClass::Int,
        _ => PrimitiveClass::Other,
    }
}

/// Replaces every occurrence of `placeholder` in the replacement expression `expr`
/// with `literal`.
fn replace_literal(expr: &mut Expr, placeholder: &str, literal: &ExprLit) {
    let mut replacer = ReplacementExpressionVisitor {
        placeholder,
        literal,
    };
    replacer.visit_expr_mut(expr);
}

/// Tries to parse the body of `mac` as a list of expressions using `parser`.
/// On success, visits each expression, writes the modified tokens back into the
/// macro and returns `true`; otherwise leaves the macro untouched and returns `false`.
fn try_parse_punctuated_macro<P: ToTokens, V: VisitMut, F: Parser<Output = Punctuated<Expr, P>>>(
    visitor: &mut V,
    mac: &mut Macro,
    parser: F,
) -> bool {
    if let Ok(mut exprs) = mac.parse_body_with(parser) {
        exprs
            .iter_mut()
            .for_each(|expr| visitor.visit_expr_mut(expr));
        mac.tokens = exprs.into_token_stream();
        return true;
    }
    false
}

/// Visits the tokens of a macro invocation, provided they can be parsed as a single
/// expression or as a `,`- or `;`-separated list of expressions. The tokens of any
/// other macro invocation are left untouched.
fn visit_macros_mut<V: VisitMut>(visitor: &mut V, mac: &mut Macro) {
    // Handle expression based macros (e.g. assert)
    if let Ok(mut expr) = mac.parse_body::<Expr>() {
        visitor.visit_expr_mut(&mut expr);
        mac.tokens = expr.into_token_stream();
        return;
    }

    // Handle , punctuation based macros (e.g. vec with list, assert_eq)
    let parser_comma = Punctuated::<Expr, Token![,]>::parse_terminated;
    if try_parse_punctuated_macro(visitor, mac, parser_comma) {
        return;
    }

    // Handle ; punctuation based macros (e.g. vec with repeat).
    // If this also fails, the macro is left as-is.
    let parser_semicolon = Punctuated::<Expr, Token![;]>::parse_terminated;
    try_parse_punctuated_macro(visitor, mac, parser_semicolon);
}

impl<'a> VisitMut for FloatLiteralVisitor<'a> {
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if let Expr::Lit(lit_expr) = expr {
            if let PrimitiveClass::Float = determine_primitive_class(lit_expr) {
                let mut adapted_replacement = self.replacement.clone();
                replace_literal(&mut adapted_replacement, self.placeholder, lit_expr);
                *expr = adapted_replacement;
                return;
            }
        }
        visit_expr_mut(self, expr)
    }

    fn visit_macro_mut(&mut self, mac: &mut Macro) {
        if self.parameters.visit_macros {
            visit_macros_mut(self, mac);
        }
    }
}

impl<'a> VisitMut for IntLiteralVisitor<'a> {
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if let Expr::Lit(lit_expr) = expr {
            if let PrimitiveClass::Int = determine_primitive_class(lit_expr) {
                let mut adapted_replacement = self.replacement.clone();
                replace_literal(&mut adapted_replacement, self.placeholder, lit_expr);
                *expr = adapted_replacement;
                return;
            }
        }
        visit_expr_mut(self, expr)
    }

    fn visit_macro_mut(&mut self, mac: &mut Macro) {
        if self.parameters.visit_macros {
            visit_macros_mut(self, mac);
        }
    }
}

impl<'a> VisitMut for NumericLiteralVisitor<'a> {
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if let Expr::Lit(lit_expr) = expr {
            // TODO: Currently we cannot correctly treat integers that don't fit in 64
            //  bits. For this we'd have to deal with verbatim literals and manually
            //  parse the string

            match determine_primitive_class(lit_expr) {
                PrimitiveClass::Float => {
                    let mut visitor = FloatLiteralVisitor {
                        parameters: self.parameters,
                        placeholder: self.placeholder,
                        replacement: self.float_replacement,
                    };
                    visitor.visit_expr_mut(expr);
                    return;
                }
                PrimitiveClass::Int => {
                    let mut visitor = IntLiteralVisitor {
                        parameters: self.parameters,
                        placeholder: self.placeholder,
                        replacement: self.int_replacement,
                    };
                    visitor.visit_expr_mut(expr);
                    return;
                }
                _ => {}
            }
        }
        visit_expr_mut(self, expr)
    }

    fn visit_macro_mut(&mut self, mac: &mut Macro) {
        if self.parameters.visit_macros {
            visit_macros_mut(self, mac);
        }
    }
}

/// Visits the replacement expression and replaces every path expression whose last
/// segment is the placeholder identifier with the given literal.
struct ReplacementExpressionVisitor<'a> {
    pub placeholder: &'a str,
    pub literal: &'a ExprLit,
}

impl<'a> VisitMut for ReplacementExpressionVisitor<'a> {
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if let Expr::Path(path_expr) = expr {
            if let Some(last_segment) = path_expr.path.segments.last() {
                if last_segment.ident == self.placeholder {
                    *expr = Expr::Lit(self.literal.clone());
                    return;
                }
            }
        }
        visit_expr_mut(self, expr)
    }
}

/// Extracts a `name = value` pair from an additional macro argument,
/// such as `visit_macros = false`.
struct MacroParameterVisitor {
    pub name: Option<String>,
    pub value: Option<ParameterValue>,
}

impl MacroParameterVisitor {
    /// Returns the parameter name and value if `expr` contains both, and `None` otherwise.
    fn parse_flag(expr: &Expr) -> Option<(String, ParameterValue)> {
        let mut visitor = MacroParameterVisitor {
            name: None,
            value: None,
        };
        visitor.visit_expr(expr);
        match (visitor.name, visitor.value) {
            (Some(name), Some(value)) => Some((name, value)),
            _ => None,
        }
    }
}

impl<'ast> Visit<'ast> for MacroParameterVisitor {
    fn visit_expr_assign(&mut self, expr: &'ast ExprAssign) {
        self.visit_expr(&expr.left);
        self.visit_expr(&expr.right);
    }

    fn visit_expr_path(&mut self, expr: &'ast ExprPath) {
        // Reconstruct the full path (e.g. `visit_macros` or `a::b`) as a string
        let mut name = Vec::new();
        if expr.path.leading_colon.is_some() {
            name.push(String::from("::"));
        }
        for p in expr.path.segments.pairs() {
            match p {
                syn::punctuated::Pair::Punctuated(ps, _sep) => {
                    name.push(ps.ident.to_string());
                    name.push(String::from("::"));
                }
                syn::punctuated::Pair::End(ps) => {
                    name.push(ps.ident.to_string());
                }
            }
        }
        self.name = Some(name.concat());
    }

    fn visit_lit_bool(&mut self, expr: &'ast LitBool) {
        self.value = Some(ParameterValue::Bool(expr.value));
    }
}

/// The value of a user-supplied macro parameter.
enum ParameterValue {
    Bool(bool),
}

/// Parameters that control the behavior of the macros.
#[derive(Copy, Clone)]
struct MacroParameters {
    /// Whether to replace literals inside of macro invocations (default: `true`).
    pub visit_macros: bool,
}

impl Default for MacroParameters {
    fn default() -> Self {
        Self { visit_macros: true }
    }
}

impl MacroParameters {
    /// Sets the parameter `name` to `value`. Unknown parameter names are ignored.
    fn set(&mut self, name: &str, value: ParameterValue) {
        match name {
            "visit_macros" => match value {
                ParameterValue::Bool(v) => self.visit_macros = v,
            },
            _ => {}
        }
    }
}

/// Obtains the replacement expression and parameters from the macro attribute token stream.
fn parse_macro_attribute(attr: TokenStream) -> Result<(Expr, MacroParameters), syn::Error> {
    // The first argument is the replacement expression; any further arguments are
    // parameters of the form `name = value`.
    let parser = Punctuated::<Expr, Token![,]>::parse_separated_nonempty;
    let attributes = parser.parse(attr)?;

    let mut attr_iter = attributes.into_iter();
    let replacement = attr_iter.next().expect("No replacement provided");

    let user_parameters: Vec<_> = attr_iter
        .filter_map(|expr| MacroParameterVisitor::parse_flag(&expr))
        .collect();
    let mut parameters = MacroParameters::default();
    for (name, value) in user_parameters {
        parameters.set(&name, value);
    }

    Ok((replacement, parameters))
}

/// Replace any numeric literal with custom transformation code.
///
/// Refer to the documentation at the root of the crate for usage instructions.
#[proc_macro_attribute]
pub fn replace_numeric_literals(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as Item);
    let (replacement, parameters) = match parse_macro_attribute(attr) {
        Ok(res) => res,
        Err(err) => return TokenStream::from(err.to_compile_error()),
    };

    let mut replacer = NumericLiteralVisitor {
        parameters,
        placeholder: "literal",
        int_replacement: &replacement,
        float_replacement: &replacement,
    };
    replacer.visit_item_mut(&mut input);

    let expanded = quote! { #input };

    TokenStream::from(expanded)
}

/// Replace any float literal with custom transformation code.
///
/// Refer to the documentation at the root of the crate for usage instructions.
#[proc_macro_attribute]
pub fn replace_float_literals(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as Item);
    let (replacement, parameters) = match parse_macro_attribute(attr) {
        Ok(res) => res,
        Err(err) => return TokenStream::from(err.to_compile_error()),
    };

    let mut replacer = FloatLiteralVisitor {
        parameters,
        placeholder: "literal",
        replacement: &replacement,
    };
    replacer.visit_item_mut(&mut input);

    let expanded = quote! { #input };

    TokenStream::from(expanded)
}

/// Replace any integer literal with custom transformation code.
///
/// Refer to the documentation at the root of the crate for usage instructions.
#[proc_macro_attribute]
pub fn replace_int_literals(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as Item);
    let (replacement, parameters) = match parse_macro_attribute(attr) {
        Ok(res) => res,
        Err(err) => return TokenStream::from(err.to_compile_error()),
    };

    let mut replacer = IntLiteralVisitor {
        parameters,
        placeholder: "literal",
        replacement: &replacement,
    };
    replacer.visit_item_mut(&mut input);

    let expanded = quote! { #input };

    TokenStream::from(expanded)
}