//!
//! **numeric_literals** is a Rust library that provides procedural attribute macros for replacing
//! numeric literals with arbitrary expressions.
//!
//! While Rust's explicitness is generally a boon, it is a major pain when writing numeric
//! code that is intended to be generic over a scalar type. As an example, consider
//! writing a function that returns the golden ratio for any type `T` that implements `num::Float`.
//! An implementation might look like the following.
//!
//! ```rust
//! extern crate num;
//! use num::Float;
//!
//! fn golden_ratio<T: Float>() -> T {
//!     (T::one() + T::sqrt(T::from(5).unwrap())) / T::from(2).unwrap()
//! }
//! ```
//!
//! This is arguably very messy for such a simple task. With `numeric_literals`, we may
//! instead write:
//!
//! ```rust
//! # use num::Float;
//! use numeric_literals::replace_numeric_literals;
//!
//! #[replace_numeric_literals(T::from(literal).unwrap())]
//! fn golden_ratio<T: Float>() -> T {
//!     (1 + 5.sqrt()) / 2
//! }
//! ```
//!
//! The above two code segments do essentially the same thing
//! (apart from using `T::from(1)` instead of `T::one()`). However, in the latter example,
//! the `replace_numeric_literals` attribute replaces any numeric literal with the expression
//! `T::from(literal).unwrap()`, where `literal` is a placeholder for each individual literal.
//!
//! There is no magic involved: the code is still explicit about what it does to numeric literals.
//! The difference is that we can declare this behavior once for all numeric literals. Moreover,
//! we move the conversion behavior away from where the literals are needed, enhancing readability
//! by reducing the noise imposed by being explicit about the exact types involved.
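//!
//! To make the "no magic" point concrete, here is a minimal, std-only sketch of the rewrite on a
//! hypothetical miniature expression type. `MiniExpr`, `rewrite`, `substitute` and
//! `replace_literals` are illustrative names that do not exist in this crate. The real macros
//! operate on `syn` syntax trees through `syn::visit_mut::VisitMut`, but they follow the same two
//! steps: find each literal, then instantiate the replacement expression with the placeholder
//! swapped for that literal.
//!
//! ```rust
//! use std::fmt;
//!
//! /// Hypothetical miniature expression tree, standing in for `syn::Expr`.
//! #[derive(Clone, Debug, PartialEq)]
//! enum MiniExpr {
//!     Lit(String),                                // numeric literal, e.g. `5`
//!     Path(String),                               // path, e.g. `literal` or `T::from`
//!     Call(Box<MiniExpr>, Vec<MiniExpr>),         // `callee(args...)`
//!     Method(Box<MiniExpr>, String),              // `receiver.method()`
//!     Binary(Box<MiniExpr>, char, Box<MiniExpr>), // `(lhs op rhs)`
//! }
//!
//! impl fmt::Display for MiniExpr {
//!     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
//!         match self {
//!             MiniExpr::Lit(s) | MiniExpr::Path(s) => write!(f, "{}", s),
//!             MiniExpr::Call(callee, args) => {
//!                 let args: Vec<String> = args.iter().map(|a| a.to_string()).collect();
//!                 write!(f, "{}({})", callee, args.join(", "))
//!             }
//!             MiniExpr::Method(recv, name) => write!(f, "{}.{}()", recv, name),
//!             MiniExpr::Binary(l, op, r) => write!(f, "({} {} {})", l, op, r),
//!         }
//!     }
//! }
//!
//! /// Rebuilds `e`, offering each node to `f` before its children. If `f` returns `Some`,
//! /// that value replaces the node (similar to overriding `visit_expr_mut`).
//! fn rewrite(e: &MiniExpr, f: &dyn Fn(&MiniExpr) -> Option<MiniExpr>) -> MiniExpr {
//!     if let Some(replaced) = f(e) {
//!         return replaced;
//!     }
//!     match e {
//!         MiniExpr::Lit(_) | MiniExpr::Path(_) => e.clone(),
//!         MiniExpr::Call(c, args) => MiniExpr::Call(
//!             Box::new(rewrite(c, f)),
//!             args.iter().map(|a| rewrite(a, f)).collect(),
//!         ),
//!         MiniExpr::Method(r, m) => MiniExpr::Method(Box::new(rewrite(r, f)), m.clone()),
//!         MiniExpr::Binary(l, op, r) => {
//!             MiniExpr::Binary(Box::new(rewrite(l, f)), *op, Box::new(rewrite(r, f)))
//!         }
//!     }
//! }
//!
//! /// Replaces the placeholder path in `template` with the literal `lit`.
//! fn substitute(template: &MiniExpr, placeholder: &str, lit: &MiniExpr) -> MiniExpr {
//!     rewrite(template, &|e: &MiniExpr| match e {
//!         MiniExpr::Path(p) if p == placeholder => Some(lit.clone()),
//!         _ => None,
//!     })
//! }
//!
//! /// Replaces every literal in `expr` with its own instance of `template`.
//! fn replace_literals(expr: &MiniExpr, placeholder: &str, template: &MiniExpr) -> MiniExpr {
//!     rewrite(expr, &|e: &MiniExpr| match e {
//!         MiniExpr::Lit(_) => Some(substitute(template, placeholder, e)),
//!         _ => None,
//!     })
//! }
//! ```
//!
//! Applied to `(1 + 5.sqrt()) / 2` with the template `T::from(literal).unwrap()` and the
//! placeholder `"literal"`, `replace_literals` yields
//! `((T::from(1).unwrap() + T::from(5).unwrap().sqrt()) / T::from(2).unwrap())`, which is
//! essentially the hand-written version from the first example.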
//!
//! Float and integer literal replacement
//! -------------------------------------
//!
//! An issue with replacing all numeric literals is that the macro cannot distinguish
//! literals used for, e.g., indexing from those that are part of a numerical computation.
//! In the example above, if you additionally needed to index into an array with a constant index
//! such as `array[0]`, the macro would try to convert the index `0` to the float type `T`, which
//! is clearly not what was intended. Fortunately, such code usually fails to compile outright
//! because of a type mismatch. One way to resolve this problem is to use the separate
//! macros `replace_float_literals` and `replace_int_literals`, which work in exactly the same way
//! but only trigger on float or integer literals, respectively. Below is an example from
//! finite element code that uses float literal replacement to improve the readability of numerical
//! constants in generic code.
//!
//! ```ignore
//! #[replace_float_literals(T::from_f64(literal).expect("Literal must fit in T"))]
//! pub fn assemble_element_mass<T>(quad: &Quad2d<T>) -> MatrixN<T, U8>
//! where
//!     T: RealField
//! {
//!     let phi = |alpha, beta, xi: &Vector2<T>| (1.0 + alpha * xi[0]) * (1.0 + beta * xi[1]) / 4.0;
//!     let phi_grad = |alpha, beta, xi: &Vector2<T>| {
//!         Vector2::new(
//!             alpha * (1.0 + beta * xi[1]) / 4.0,
//!             beta * (1.0 + alpha * xi[0]) / 4.0,
//!         )
//!     };
//!     let alphas = [-1.0, 1.0, 1.0, -1.0];
//!     let betas = [-1.0, -1.0, 1.0, 1.0];
//!
//!     // And so on...
//! }
//! ```
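//!
//! As an illustration of what float-only replacement produces, the std-only sketch below spells
//! out by hand its effect on a bilinear shape function like `phi` above. `Scalar` is a
//! hypothetical minimal trait standing in for `RealField`, and `phi` is written as a plain
//! function over a `[T; 2]` array for simplicity. Only the float literals `1.0` and `4.0` are
//! wrapped in the conversion; the integer indices `0` and `1` stay ordinary `usize` literals,
//! which is what lets the indexing type-check.
//!
//! ```rust
//! use std::ops::{Add, Div, Mul};
//!
//! /// Hypothetical minimal scalar trait (a stand-in for `RealField`).
//! trait Scalar: Copy + Add<Output = Self> + Mul<Output = Self> + Div<Output = Self> {
//!     fn from_f64(x: f64) -> Option<Self>;
//! }
//!
//! impl Scalar for f32 {
//!     fn from_f64(x: f64) -> Option<Self> {
//!         Some(x as f32)
//!     }
//! }
//!
//! impl Scalar for f64 {
//!     fn from_f64(x: f64) -> Option<Self> {
//!         Some(x)
//!     }
//! }
//!
//! /// Hand-expanded form of `(1.0 + alpha * xi[0]) * (1.0 + beta * xi[1]) / 4.0` under
//! /// `#[replace_float_literals(T::from_f64(literal).expect("Literal must fit in T"))]`.
//! fn phi<T: Scalar>(alpha: T, beta: T, xi: [T; 2]) -> T {
//!     (T::from_f64(1.0).expect("Literal must fit in T") + alpha * xi[0])
//!         * (T::from_f64(1.0).expect("Literal must fit in T") + beta * xi[1])
//!         / T::from_f64(4.0).expect("Literal must fit in T")
//! }
//! ```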
//!
//! In general, **the macros should be used with caution**. It is recommended to keep the macro close to
//! the region in which the literals are being used, so as to avoid confusing readers of the code.
//! The code before macro expansion usually would not type-check on its own (because it lacks
//! explicit type conversions), and without the context of the attribute it is not clear why
//! this code compiles at all.
//!
//! An option for the future would be to apply the attribute only to small, local blocks of code
//! that are heavy on numerical constants. However, stable Rust does not currently allow attribute
//! macros to be applied to blocks or single expressions.
//!
//! Replacement in macro invocations
//! --------------------------------
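//! By default, the macros of this crate also replace literals inside macro invocations, as long as
//! the macro's arguments parse as a single expression or as a list of expressions separated by
//! commas or semicolons (as in `assert!`, `assert_eq!` or `vec!`). Other macro bodies are left
//! unchanged. This allows code such as the following to compile: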
//!
//! ```rust
//! use num::Float;
//! use numeric_literals::replace_numeric_literals;
//!
//! #[replace_numeric_literals(T::from(literal).unwrap())]
//! fn zeros<T: Float>(n: usize) -> Vec<T> {
//!     vec![0.0; n]
//! }
//! ```
//!
//! If this behavior is unwanted, replacement inside macro invocations can be disabled by passing
//! the `visit_macros` parameter after the replacement expression:
//!
//! ```ignore
//! #[replace_numeric_literals(T::from(literal).unwrap(), visit_macros = false)]
//! ```
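//!
//! For illustration, the std-only sketch below writes out by hand roughly what the `zeros`
//! example expands to when macro visiting is enabled: the literal `0.0` inside `vec![...; n]` is
//! wrapped in the conversion, while the length `n`, which is not a literal, is untouched.
//! `FromF64` is a hypothetical stand-in for `num::Float` (and `from_f64` for `T::from`). With
//! `visit_macros = false` the element would remain a plain `0.0`, and `zeros` would no longer
//! type-check for a generic `T`.
//!
//! ```rust
//! /// Hypothetical conversion trait, standing in for `num::Float`.
//! trait FromF64: Clone {
//!     fn from_f64(x: f64) -> Option<Self>;
//! }
//!
//! impl FromF64 for f32 {
//!     fn from_f64(x: f64) -> Option<Self> {
//!         Some(x as f32)
//!     }
//! }
//!
//! impl FromF64 for f64 {
//!     fn from_f64(x: f64) -> Option<Self> {
//!         Some(x)
//!     }
//! }
//!
//! /// Hand-expanded form of `vec![0.0; n]` after replacement inside the macro body.
//! /// `vec![elem; n]` clones `elem`, hence the `Clone` supertrait.
//! fn zeros<T: FromF64>(n: usize) -> Vec<T> {
//!     vec![T::from_f64(0.0).unwrap(); n]
//! }
//! ```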
//!
//! Literals with suffixes
//! ----------------------
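//! In Rust, literal suffixes can be used to disambiguate the type of a literal. For example, the suffix `_f64`
//! in the expression `1_f64.sqrt()` makes it clear that the value `1` is of type `f64`. The macros of this
//! crate also replace literals with any floating point or integer suffix, and the literal is substituted
//! for the placeholder together with its suffix. Integer literals with a float suffix, such as `5f32`,
//! count as float literals for `replace_float_literals`. For example: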
//!
//! ```rust
//! use num::Float;
//! use numeric_literals::replace_numeric_literals;
//!
//! #[replace_numeric_literals(T::from(literal).unwrap())]
//! fn golden_ratio<T: Float>() -> T {
//!     (1.0_f64 + 5f32.sqrt()) / 2.0
//! }
//! ```
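//!
//! The rule that decides which literals count as floats can be approximated by the std-only
//! sketch below, which classifies the source text of a literal. `LitClass` and `classify` are
//! hypothetical names (the crate itself inspects parsed `syn::Lit` values in
//! `determine_primitive_class`), and the sketch handles decimal literals only; hexadecimal, octal
//! and binary literals are out of scope.
//!
//! ```rust
//! /// Hypothetical classification result, mirroring the crate's float/int distinction.
//! #[derive(Debug, PartialEq)]
//! enum LitClass {
//!     Float,
//!     Int,
//! }
//!
//! /// Classifies a decimal numeric literal such as `2.0`, `5f32` or `20u8`.
//! fn classify(lit: &str) -> LitClass {
//!     // Digit separators do not affect the classification.
//!     let lit = lit.replace('_', "");
//!     // An `f32`/`f64` suffix makes any literal a float, even `5f32`.
//!     if lit.ends_with("f32") || lit.ends_with("f64") {
//!         return LitClass::Float;
//!     }
//!     // Strip an integer suffix (`u8`, `i64`, `usize`, ...), which starts with `u` or `i`.
//!     let body = match lit.find(|c: char| c == 'u' || c == 'i') {
//!         Some(pos) => &lit[..pos],
//!         None => &lit[..],
//!     };
//!     // Otherwise a decimal point or an exponent marks a float literal.
//!     if body.contains('.') || body.contains('e') || body.contains('E') {
//!         LitClass::Float
//!     } else {
//!         LitClass::Int
//!     }
//! }
//! ```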

extern crate proc_macro;
use proc_macro::TokenStream;

use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::visit::Visit;
use syn::visit_mut::{visit_expr_mut, VisitMut};
use syn::{
    parse_macro_input, Expr, ExprAssign, ExprLit, ExprPath, Item, Lit, LitBool, Macro, Token,
};

use quote::{quote, ToTokens};

/// Visits an expression and replaces every numeric literal with the matching replacement
/// expression (`float_replacement` or `int_replacement`), in which the placeholder identifier
/// is replaced with the literal itself.
struct NumericLiteralVisitor<'a> {
    pub parameters: MacroParameters,
    pub placeholder: &'a str,
    pub float_replacement: &'a Expr,
    pub int_replacement: &'a Expr,
}

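/// Visits an expression and replaces only float literals (including integer literals with an
/// `f32` or `f64` suffix, such as `5f32`) with the replacement expression.
struct FloatLiteralVisitor<'a> {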
    pub parameters: MacroParameters,
    pub placeholder: &'a str,
    pub replacement: &'a Expr,
}

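/// Visits an expression and replaces only integer literals (excluding those with an `f32` or
/// `f64` suffix) with the replacement expression.
struct IntLiteralVisitor<'a> {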
    pub parameters: MacroParameters,
    pub placeholder: &'a str,
    pub replacement: &'a Expr,
}

/// Represents classes of primitive types relevant to the crate
enum PrimitiveClass {
    Float,
    Int,
    Other,
}

/// Returns what class of primitive types is represented by this literal expression, e.g. `20f64 -> Float`, `20 -> Int`
fn determine_primitive_class(lit_expr: &ExprLit) -> PrimitiveClass {
    match &lit_expr.lit {
        // Parsed float literals are always floats
        Lit::Float(_) => PrimitiveClass::Float,
        // Literals like `20f64` are parsed as `LitInt`s
        Lit::Int(int_lit) if matches!(int_lit.suffix(), "f32" | "f64") => PrimitiveClass::Float,
        // All other integer literals should be actual integers
        Lit::Int(_) => PrimitiveClass::Int,
        _ => PrimitiveClass::Other,
    }
}

/// Substitutes `literal` for every path in `expr` whose last segment is `placeholder`.
fn replace_literal(expr: &mut Expr, placeholder: &str, literal: &ExprLit) {
    let mut replacer = ReplacementExpressionVisitor {
        placeholder,
        literal,
    };
    replacer.visit_expr_mut(expr);
}

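/// Tries to parse the body of `mac` as a list of expressions separated by `P` and, on success,
/// visits each expression and writes the result back into the macro tokens.
/// Returns `false`, leaving `mac` untouched, if the body cannot be parsed this way.
fn try_parse_punctuated_macro<P: ToTokens, V: VisitMut, F: Parser<Output = Punctuated<Expr, P>>>(
    visitor: &mut V,
    mac: &mut Macro,
    parser: F,
) -> bool {
    match mac.parse_body_with(parser) {
        Ok(mut exprs) => {
            exprs
                .iter_mut()
                .for_each(|expr| visitor.visit_expr_mut(expr));
            mac.tokens = exprs.into_token_stream();
            true
        }
        Err(_) => false,
    }
}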

/// Replaces literals inside a macro invocation, provided its body parses as a single
/// expression or as a `,`- or `;`-separated list of expressions. Macro bodies that match
/// none of these forms are left untouched.
fn visit_macros_mut<V: VisitMut>(visitor: &mut V, mac: &mut Macro) {
    // Handle expression-based macros (e.g. `assert!(cond)`)
    if let Ok(mut expr) = mac.parse_body::<Expr>() {
        visitor.visit_expr_mut(&mut expr);
        mac.tokens = expr.into_token_stream();
        return;
    }

    // Handle comma-separated macros (e.g. `vec!` with a list, `assert_eq!`)
    let parser_comma = Punctuated::<Expr, Token![,]>::parse_terminated;
    if try_parse_punctuated_macro(visitor, mac, parser_comma) {
        return;
    }

    // Handle semicolon-separated macros (e.g. `vec!` with a repeat count)
    let parser_semicolon = Punctuated::<Expr, Token![;]>::parse_terminated;
    try_parse_punctuated_macro(visitor, mac, parser_semicolon);
}

impl<'a> VisitMut for FloatLiteralVisitor<'a> {
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if let Expr::Lit(lit_expr) = expr {
            if let PrimitiveClass::Float = determine_primitive_class(&lit_expr) {
                let mut adapted_replacement = self.replacement.clone();
                replace_literal(&mut adapted_replacement, self.placeholder, lit_expr);
                *expr = adapted_replacement;
                return;
            }
        }
        visit_expr_mut(self, expr)
    }

    fn visit_macro_mut(&mut self, mac: &mut Macro) {
        if self.parameters.visit_macros {
            visit_macros_mut(self, mac);
        }
    }
}

impl<'a> VisitMut for IntLiteralVisitor<'a> {
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if let Expr::Lit(lit_expr) = expr {
            if let PrimitiveClass::Int = determine_primitive_class(&lit_expr) {
                let mut adapted_replacement = self.replacement.clone();
                replace_literal(&mut adapted_replacement, self.placeholder, lit_expr);
                *expr = adapted_replacement;
                return;
            }
        }
        visit_expr_mut(self, expr)
    }

    fn visit_macro_mut(&mut self, mac: &mut Macro) {
        if self.parameters.visit_macros {
            visit_macros_mut(self, mac);
        }
    }
}

impl<'a> VisitMut for NumericLiteralVisitor<'a> {
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if let Expr::Lit(lit_expr) = expr {
            // TODO: Currently we cannot correctly treat integers that don't fit in 64
            //  bits. For this we'd have to deal with verbatim literals and manually
            //  parse the string

            match determine_primitive_class(&lit_expr) {
                PrimitiveClass::Float => {
                    let mut visitor = FloatLiteralVisitor {
                        parameters: self.parameters,
                        placeholder: self.placeholder,
                        replacement: self.float_replacement,
                    };
                    visitor.visit_expr_mut(expr);
                    return;
                }
                PrimitiveClass::Int => {
                    let mut visitor = IntLiteralVisitor {
                        parameters: self.parameters,
                        placeholder: self.placeholder,
                        replacement: self.int_replacement,
                    };
                    visitor.visit_expr_mut(expr);
                    return;
                }
                _ => {}
            }
        }
        visit_expr_mut(self, expr)
    }

    fn visit_macro_mut(&mut self, mac: &mut Macro) {
        if self.parameters.visit_macros {
            visit_macros_mut(self, mac);
        }
    }
}

/// Visits the "replacement expression", which replaces a placeholder identifier
/// with the given literal.
struct ReplacementExpressionVisitor<'a> {
    pub placeholder: &'a str,
    pub literal: &'a ExprLit,
}

impl<'a> VisitMut for ReplacementExpressionVisitor<'a> {
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if let Expr::Path(path_expr) = expr {
            if let Some(last_segment) = path_expr.path.segments.last() {
                if last_segment.ident == self.placeholder {
                    *expr = Expr::Lit(self.literal.clone());
                    return;
                }
            }
        }
        visit_expr_mut(self, expr)
    }
}

/// Extracts a `name = <bool>` flag (e.g. `visit_macros = false`) from an attribute argument.
struct MacroParameterVisitor {
    pub name: Option<String>,
    pub value: Option<ParameterValue>,
}

impl MacroParameterVisitor {
    /// Parses a flag of the form `name = <bool>`. Returns `None` unless the expression
    /// contains both a path (the name) and a boolean literal (the value).
    fn parse_flag(expr: &Expr) -> Option<(String, ParameterValue)> {
        let mut visitor = MacroParameterVisitor {
            name: None,
            value: None,
        };
        visitor.visit_expr(expr);
        visitor.name.zip(visitor.value)
    }
}

impl<'ast> Visit<'ast> for MacroParameterVisitor {
    fn visit_expr_assign(&mut self, expr: &'ast ExprAssign) {
        self.visit_expr(&expr.left);
        self.visit_expr(&expr.right);
    }

    fn visit_expr_path(&mut self, expr: &'ast ExprPath) {
        let mut name = Vec::new();
        if expr.path.leading_colon.is_some() {
            name.push(String::from("::"));
        }
        for p in expr.path.segments.pairs() {
            match p {
                syn::punctuated::Pair::Punctuated(ps, _sep) => {
                    name.push(ps.ident.to_string());
                    name.push(String::from("::"));
                }
                syn::punctuated::Pair::End(ps) => {
                    name.push(ps.ident.to_string());
                }
            }
        }
        self.name = Some(name.concat());
    }

    fn visit_lit_bool(&mut self, expr: &'ast LitBool) {
        self.value = Some(ParameterValue::Bool(expr.value));
    }
}

/// The value of a macro parameter. Currently only boolean values are supported.
enum ParameterValue {
    Bool(bool),
}

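/// Optional parameters that may follow the replacement expression in the attribute,
/// e.g. `visit_macros = false`.
#[derive(Copy, Clone)]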
struct MacroParameters {
    pub visit_macros: bool,
}

impl Default for MacroParameters {
    fn default() -> Self {
        Self { visit_macros: true }
    }
}

impl MacroParameters {
    /// Sets the parameter `name` to `value`. Unknown parameter names are silently ignored.
    fn set(&mut self, name: &str, value: ParameterValue) {
        match name {
            "visit_macros" => match value {
                ParameterValue::Bool(v) => self.visit_macros = v,
            },
            _ => {}
        }
    }
}

/// Obtain the replacement expression and parameters from the macro attr token stream.
fn parse_macro_attribute(attr: TokenStream) -> Result<(Expr, MacroParameters), syn::Error> {
    let parser = Punctuated::<Expr, Token![,]>::parse_separated_nonempty;
    let attributes = parser.parse(attr)?;

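    let mut attr_iter = attributes.into_iter();
    // `parse_separated_nonempty` rejects empty input, so a replacement is always present.
    let replacement = attr_iter.next().expect("No replacement provided");

    // Any further arguments are parsed as `name = <bool>` flags; arguments that lack either
    // a path or a boolean literal are ignored.
    let user_parameters: Vec<_> = attr_iter
        .filter_map(|expr| MacroParameterVisitor::parse_flag(&expr))
        .collect();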
    let mut parameters = MacroParameters::default();
    for (name, value) in user_parameters {
        parameters.set(&name, value);
    }

    Ok((replacement, parameters))
}

/// Replace any numeric literal with custom transformation code.
///
/// Refer to the documentation at the root of the crate for usage instructions.
#[proc_macro_attribute]
pub fn replace_numeric_literals(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as Item);
    let (replacement, parameters) = match parse_macro_attribute(attr) {
        Ok(res) => res,
        Err(err) => return TokenStream::from(err.to_compile_error()),
    };

    let mut replacer = NumericLiteralVisitor {
        parameters,
        placeholder: "literal",
        int_replacement: &replacement,
        float_replacement: &replacement,
    };
    replacer.visit_item_mut(&mut input);

    let expanded = quote! { #input };

    TokenStream::from(expanded)
}

/// Replace any float literal with custom transformation code.
///
/// Refer to the documentation at the root of the crate for usage instructions.
#[proc_macro_attribute]
pub fn replace_float_literals(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as Item);
    let (replacement, parameters) = match parse_macro_attribute(attr) {
        Ok(res) => res,
        Err(err) => return TokenStream::from(err.to_compile_error()),
    };

    let mut replacer = FloatLiteralVisitor {
        parameters,
        placeholder: "literal",
        replacement: &replacement,
    };
    replacer.visit_item_mut(&mut input);

    let expanded = quote! { #input };

    TokenStream::from(expanded)
}

/// Replace any integer literal with custom transformation code.
///
/// Refer to the documentation at the root of the crate for usage instructions.
#[proc_macro_attribute]
pub fn replace_int_literals(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as Item);
    let (replacement, parameters) = match parse_macro_attribute(attr) {
        Ok(res) => res,
        Err(err) => return TokenStream::from(err.to_compile_error()),
    };

    let mut replacer = IntLiteralVisitor {
        parameters,
        placeholder: "literal",
        replacement: &replacement,
    };
    replacer.visit_item_mut(&mut input);

    let expanded = quote! { #input };

    TokenStream::from(expanded)
}