1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
//! `chain_cmp` lets you chain comparison operators like

//! you would in mathematics, using the [`chmp!`] macro.


use proc_macro::TokenStream;
use quote::ToTokens;
use syn::{parse_macro_input, spanned::Spanned, Expr, ExprBinary, Token};

/// Use the `chmp` macro to chain comparison operators.

///

/// You can use all of these operators: `<`, `<=`, `>`, `>=`, `==`, `!=`.

///

/// # Examples

///

/// ## Basic usage

///

/// ```

/// use chain_cmp::chmp;

///

/// let (a, b, c) = (1, 2, 3);

///

/// let verbose = a < b && b <= c;

/// let concise = chmp!(a < b <= c);

/// assert_eq!(concise, verbose);

///

/// // You can use equality operators as well:

/// assert!(chmp!(a != b != c));

///

/// // And you can even chain more than three operators:

/// assert!(chmp!(a != b != c != a)); // making sure these values are pairwise distinct

///

/// // And of course mix and match operators:

/// assert!(chmp!(a < b <= c != a == a));

/// ```

///

/// ## Short-circuiting

///

/// `chmp` will short-circuit to evaluate the fewest expressions

/// possible.

///

/// ```

/// # use chain_cmp::chmp;

/// fn panics() -> i32 {

///     panic!();

/// }

///

/// assert!(!chmp!(i32::MAX < i32::MIN < panics())); // this **won't** panic

/// ```

///

/// ## Comparing arbitrary expressions

///

/// As long as the comparison operators have the lowest precedence,

/// `chmp` will evaluate any expression, like variables, blocks,

/// function calls, etc.

///

/// ```

/// # use chain_cmp::chmp;

/// const ANSWER: u32 = 42;

///

/// assert!(chmp!({

///     println!("Life, the Universe, and Everything");

///     ANSWER

/// } != 6 * 9 == 54));

/// ```

#[proc_macro]
pub fn chmp(tokens: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(tokens as ExprBinary);
    match cmp_tree_to_conjunction_tree(ast) {
        Ok(expr) => expr.into_token_stream(),
        Err(err) => err.to_compile_error(),
    }
    .into()
}

/// `cmp_tree_to_conjunction_tree` rewrites a tree of chained comparisons,
/// which on its own is not valid Rust, into valid Rust that is equivalent
/// to a short-circuiting `&&` chain of two-operand comparisons.
///
/// That is, something like `a < b < c` becomes the equivalent of
/// `a < b && b < c`.
///
/// # Errors
///
/// Propagates the error from `flatten_tree` when the outermost operator of
/// `cmp_tree` is not a comparison.
fn cmp_tree_to_conjunction_tree(cmp_tree: ExprBinary) -> Result<Expr, syn::Error> {
    let mut comparisons = Vec::new();
    flatten_tree(cmp_tree, &mut comparisons)?;
    Ok(build_conjunction_tree(comparisons))
}

/// `is_comparison_op` returns `true` if `op` is one of

/// `<`, `<=`, `>`, `>=`, `==`, `!=`.

fn is_comparison_op(op: &syn::BinOp) -> bool {
    use syn::BinOp::*;
    matches!(op, Ne(_) | Eq(_) | Le(_) | Ge(_) | Lt(_) | Gt(_))
}

/// `is_comparison` returns `true` if `expr` is a

/// comparison, i.e. any operation that is supported

/// by types that implement `PartialEq` or `PartialOrd`.

fn is_comparison(expr: &ExprBinary) -> bool {
    is_comparison_op(&expr.op)
}

/// `flatten_tree` takes a `tree` of binary expressions and flattens it,

/// appending each individual expression to `container`.

///

/// For example, this tree of comparison expressions

/// (where `en` is an arbitrary expression)

///

/// ```nocompile

///          <

///         / \

///        <=  e4

///       /  \

///      <=   e3

///     /  \

///    e1  e2

/// ```

///

/// becomes this flattened list:

///

/// ```nocompile

/// [e4, e3, e2, e1]

/// ```

fn flatten_tree(mut tree: ExprBinary, container: &mut Vec<Expr>) -> Result<(), syn::Error> {
    let op = tree.op;
    if !is_comparison_op(&op) {
        let err = syn::Error::new_spanned(
            op,
            format!(
                "Expected one of `<`, `<=`, `>`, `>=`, `==`, `!=`, found: `{}`",
                op.to_token_stream()
            ),
        );
        return Err(err);
    }

    match &*tree.left {
        Expr::Binary(rest) if is_comparison(rest) => {
            let lhs = rest.right.clone();
            let rest = match *std::mem::replace(&mut tree.left, lhs) {
                Expr::Binary(expr) => expr,
                _ => unreachable!(),
            };
            container.push(into_expr(tree));
            flatten_tree(rest, container)
        }
        _ => {
            container.push(into_expr(tree));
            Ok(())
        }
    }
}

/// `build_conjunction_tree` turns a list of `Expr`s into

/// a tree of conjunctions where the last element of the list

/// is the root of the resulting tree.

///

/// For example, this list of four expressions

///

/// ```nocompile

/// [e4, e3, e2, e1]

/// ```

///

/// becomes this tree of conjunctions:

///

/// ```nocompile

///     &&

///    /  \

///   e1   &&

///       /  \

///      e2   &&

///          /  \

///         e3   e4

/// ```

fn build_conjunction_tree(mut exprs: Vec<Expr>) -> Expr {
    let expr = exprs
        .pop()
        .expect("need at least one expression to build tree");

    if exprs.is_empty() {
        expr
    } else {
        into_expr(new_conjuction(expr, build_conjunction_tree(exprs)))
    }
}

/// `into_expr` wraps a binary expression in the general `Expr` enum.
fn into_expr(bin_expr: ExprBinary) -> Expr {
    Expr::from(bin_expr)
}

/// `new_conjuction` builds the binary expression `left && right`.
///
/// The `&&` token's two spans are taken from `left` and `right`
/// respectively, and the new expression carries no attributes.
fn new_conjuction(left: Expr, right: Expr) -> ExprBinary {
    let and_and = Token![&&]([left.span(), right.span()]);
    ExprBinary {
        attrs: Vec::new(),
        left: Box::new(left),
        op: syn::BinOp::And(and_and),
        right: Box::new(right),
    }
}