Skip to main content

chain_cmp/
lib.rs

1//! `chain_cmp` lets you chain comparison operators like
2//! you would in mathematics using the [`chmp!`] macro.
3
4use proc_macro::TokenStream;
5use quote::ToTokens;
6use syn::{parse_macro_input, spanned::Spanned, Expr, ExprBinary, Token};
7
8/// Use the `chmp` macro to chain comparison operators.
9///
10/// You can use all of these operators: `<`, `<=`, `>`, `>=`, `==`, `!=`.
11///
12/// # Examples
13///
14/// ## Basic usage
15///
16/// ```
17/// use chain_cmp::chmp;
18///
19/// let (a, b, c) = (1, 2, 3);
20///
21/// let verbose = a < b && b <= c;
22/// let concise = chmp!(a < b <= c);
23/// assert_eq!(concise, verbose);
24///
25/// // You can use equality operators as well:
26/// assert!(chmp!(a != b != c));
27///
28/// // And you can even chain more than three operators:
29/// assert!(chmp!(a != b != c != a)); // making sure these values are pairwise distinct
30///
31/// // And of course mix and match operators:
32/// assert!(chmp!(a < b <= c != a == a));
33/// ```
34///
35/// ## Short-circuiting
36///
37/// `chmp` will short-circuit to evaluate the fewest expressions
38/// possible.
39///
40/// ```
41/// # use chain_cmp::chmp;
42/// fn panics() -> i32 {
43///     panic!();
44/// }
45///
46/// assert!(!chmp!(i32::MAX < i32::MIN < panics())); // this **won't** panic
47/// ```
48///
49/// ## Comparing arbitrary expressions
50///
51/// As long as the comparison operators have the lowest precedence,
52/// `chmp` will evaluate any expression, like variables, blocks,
53/// function calls, etc.
54///
55/// ```
56/// # use chain_cmp::chmp;
57/// const ANSWER: u32 = 42;
58///
59/// assert!(chmp!({
60///     println!("Life, the Universe, and Everything");
61///     ANSWER
62/// } != 6 * 9 == 54));
63/// ```
64#[proc_macro]
65pub fn chmp(tokens: TokenStream) -> TokenStream {
66    let ast = parse_macro_input!(tokens as ExprBinary);
67    match cmp_tree_to_conjunction_tree(ast) {
68        Ok(expr) => expr.into_token_stream(),
69        Err(err) => err.to_compile_error(),
70    }
71    .into()
72}
73
74/// `cmp_tree_to_conjunction_tree` turns a tree of chained
75/// comparisons that would normally not be valid rust into
76/// a valid tree of conjunctions.
77///
78/// In other words, it turns something like `a < b < c` into
79/// `a < b && b < c`.
80fn cmp_tree_to_conjunction_tree(cmp_tree: ExprBinary) -> Result<Expr, syn::Error> {
81    let mut exprs = Vec::new();
82    flatten_tree(cmp_tree, &mut exprs).map(|_| build_conjunction_tree(exprs))
83}
84
85/// `is_comparison_op` returns `true` if `op` is one of
86/// `<`, `<=`, `>`, `>=`, `==`, `!=`.
87fn is_comparison_op(op: &syn::BinOp) -> bool {
88    use syn::BinOp::*;
89    matches!(op, Ne(_) | Eq(_) | Le(_) | Ge(_) | Lt(_) | Gt(_))
90}
91
92/// `is_comparison` returns `true` if `expr` is a
93/// comparison, i.e. any operation that is supported
94/// by types that implement `PartialEq` or `PartialOrd`.
95fn is_comparison(expr: &ExprBinary) -> bool {
96    is_comparison_op(&expr.op)
97}
98
99/// `flatten_tree` takes a `tree` of binary expressions and flattens it,
100/// appending each individual expression to `container`.
101///
102/// For example, this tree of comparison expressions
103/// (where `en` is an arbitrary expression)
104///
105/// ```nocompile
106///          <
107///         / \
108///        <=  e4
109///       /  \
110///      <=   e3
111///     /  \
112///    e1  e2
113/// ```
114///
115/// becomes this flattened list:
116///
117/// ```nocompile
118/// [e4, e3, e2, e1]
119/// ```
120fn flatten_tree(mut tree: ExprBinary, container: &mut Vec<Expr>) -> Result<(), syn::Error> {
121    let op = tree.op;
122    if !is_comparison_op(&op) {
123        let err = syn::Error::new_spanned(
124            op,
125            format!(
126                "Expected one of `<`, `<=`, `>`, `>=`, `==`, `!=`, found: `{}`",
127                op.to_token_stream()
128            ),
129        );
130        return Err(err);
131    }
132
133    match &*tree.left {
134        Expr::Binary(rest) if is_comparison(rest) => {
135            let lhs = rest.right.clone();
136            let rest = match *std::mem::replace(&mut tree.left, lhs) {
137                Expr::Binary(expr) => expr,
138                _ => unreachable!(),
139            };
140            container.push(into_expr(tree));
141            flatten_tree(rest, container)
142        }
143        _ => {
144            container.push(into_expr(tree));
145            Ok(())
146        }
147    }
148}
149
150/// `build_conjunction_tree` turns a list of `Expr`s into
151/// a tree of conjunctions where the last element of the list
152/// is the root of the resulting tree.
153///
154/// For example, this list of four expressions
155///
156/// ```nocompile
157/// [e4, e3, e2, e1]
158/// ```
159///
160/// becomes this tree of conjunctions:
161///
162/// ```nocompile
163///     &&
164///    /  \
165///   e1   &&
166///       /  \
167///      e2   &&
168///          /  \
169///         e3   e4
170/// ```
171fn build_conjunction_tree(mut exprs: Vec<Expr>) -> Expr {
172    let expr = exprs
173        .pop()
174        .expect("need at least one expression to build tree");
175
176    if exprs.is_empty() {
177        expr
178    } else {
179        into_expr(new_conjuction(expr, build_conjunction_tree(exprs)))
180    }
181}
182
183fn into_expr(bin_expr: ExprBinary) -> Expr {
184    Expr::Binary(bin_expr)
185}
186
187/// `new_conjunction` returns a new binary expression of the form
188/// `left && right`.
189fn new_conjuction(left: Expr, right: Expr) -> ExprBinary {
190    let (left_span, right_span) = (left.span(), right.span());
191
192    ExprBinary {
193        attrs: vec![],
194        left: Box::new(left),
195        op: syn::BinOp::And(Token![&&]([left_span, right_span])),
196        right: Box::new(right),
197    }
198}