Skip to main content

financial_ops_macros/
lib.rs

1//! Procedural macros for [`financial-ops`](https://docs.rs/financial-ops).
2//!
3//! This crate provides the [`checked!`] macro, which rewrites an ordinary
4//! arithmetic expression into a chain of the standard library's checked
5//! arithmetic methods (`checked_add`, `checked_sub`, `checked_mul`,
6//! `checked_div`, `checked_rem`), recursively, while preserving the operator
7//! precedence and grouping of the original expression.
8//!
9//! You normally do not depend on this crate directly; the macro is re-exported
10//! from `financial-ops` as `financial_ops::checked`.
11
12use proc_macro::TokenStream;
13use proc_macro2::{Spacing, Span, TokenStream as TokenStream2, TokenTree};
14use quote::{format_ident, quote};
15use syn::{BinOp, Expr};
16
17/// Recursively rewrites an arithmetic expression into checked arithmetic.
18///
19/// The macro walks the parsed expression tree, so it honors normal Rust
20/// operator precedence and parentheses. Every binary operator is replaced by
21/// its checked counterpart and the operands are threaded together so each
22/// sub-expression is evaluated exactly once, in left-to-right order.
23///
24/// # Result vs. Option
25///
26/// * Without an error, the macro evaluates to an `Option<T>` (it is `None` as
27///   soon as any step overflows or divides by zero):
28///
29/// ```ignore
30/// let total: Option<u64> = checked! { a + b * c };
31/// ```
32///
33/// * With a trailing `@ <error expression>`, the macro evaluates to a
34///   `Result<T, E>`, mapping the first failing step to the given error:
35///
36/// ```ignore
37/// let total = checked! { a + b * c @ MyError::Overflow }?;
38/// ```
39///
40/// # Supported operators
41///
42/// `+` → `checked_add`, `-` → `checked_sub`, `*` → `checked_mul`,
43/// `/` → `checked_div`, `%` → `checked_rem`.
44///
45/// # Examples
46///
47/// ```ignore
48/// use financial_ops::checked;
49///
50/// // Respects precedence: this is `a + (b * c)`, fully checked.
51/// let value = checked! { 2u64 + 3 * 4 };
52/// assert_eq!(value, Some(14));
53///
54/// // Overflow short-circuits to `None`.
55/// assert_eq!(checked! { u8::MAX + 1u8 }, None);
56/// ```
57#[proc_macro]
58pub fn checked(input: TokenStream) -> TokenStream {
59    let input2: TokenStream2 = input.into();
60    match parse_input(input2) {
61        Ok((expr, error)) => match expand(&expr, error.as_ref(), 0) {
62            Ok(tokens) => tokens.into(),
63            Err(err) => err.to_compile_error().into(),
64        },
65        Err(err) => err.to_compile_error().into(),
66    }
67}
68
69/// Splits the macro input into the arithmetic expression and an optional error
70/// expression separated by a top-level `@`.
71fn parse_input(tokens: TokenStream2) -> syn::Result<(Expr, Option<Expr>)> {
72    let mut expr_tokens = TokenStream2::new();
73    let mut error_tokens = TokenStream2::new();
74    let mut found_at = false;
75
76    for tt in tokens {
77        if !found_at {
78            if let TokenTree::Punct(ref punct) = tt {
79                if punct.as_char() == '@' && punct.spacing() == Spacing::Alone {
80                    found_at = true;
81                    continue;
82                }
83            }
84            expr_tokens.extend(std::iter::once(tt));
85        } else {
86            error_tokens.extend(std::iter::once(tt));
87        }
88    }
89
90    if expr_tokens.is_empty() {
91        return Err(syn::Error::new(
92            Span::call_site(),
93            "expected an arithmetic expression",
94        ));
95    }
96
97    let expr: Expr = syn::parse2(expr_tokens)?;
98
99    let error = if found_at {
100        if error_tokens.is_empty() {
101            return Err(syn::Error::new(
102                Span::call_site(),
103                "expected an error expression after `@`",
104            ));
105        }
106        Some(syn::parse2::<Expr>(error_tokens)?)
107    } else {
108        None
109    };
110
111    Ok((expr, error))
112}
113
114/// Recursively expands `expr` into a checked-arithmetic token stream.
115///
116/// `depth` is used to generate non-colliding binding identifiers: a node at
117/// depth `d` binds `__l{d}` / `__r{d}`, and its children recurse at `d + 1`, so
118/// a child's bindings can never shadow an ancestor's.
119fn expand(expr: &Expr, error: Option<&Expr>, depth: usize) -> syn::Result<TokenStream2> {
120    match expr {
121        Expr::Binary(binary) => {
122            let method = method_name(&binary.op)?;
123            let lhs = expand(&binary.left, error, depth + 1)?;
124            let rhs = expand(&binary.right, error, depth + 1)?;
125
126            let l = format_ident!("__l{}", depth);
127            let r = format_ident!("__r{}", depth);
128
129            let combine = match error {
130                Some(err) => quote! { #l.#method(#r).ok_or(#err) },
131                None => quote! { #l.#method(#r) },
132            };
133
134            Ok(quote! {
135                (#lhs).and_then(|#l| (#rhs).and_then(|#r| #combine))
136            })
137        }
138        // Transparent wrappers: keep recursing without consuming a depth level.
139        Expr::Paren(paren) => expand(&paren.expr, error, depth),
140        Expr::Group(group) => expand(&group.expr, error, depth),
141        // Any other expression is a leaf value to be lifted into the chain.
142        leaf => Ok(match error {
143            Some(_) => quote! { ::core::result::Result::<_, _>::Ok(#leaf) },
144            None => quote! { ::core::option::Option::Some(#leaf) },
145        }),
146    }
147}
148
149/// Maps a binary operator to its checked-method identifier.
150fn method_name(op: &BinOp) -> syn::Result<proc_macro2::Ident> {
151    let name = match op {
152        BinOp::Add(_) => "checked_add",
153        BinOp::Sub(_) => "checked_sub",
154        BinOp::Mul(_) => "checked_mul",
155        BinOp::Div(_) => "checked_div",
156        BinOp::Rem(_) => "checked_rem",
157        other => {
158            return Err(syn::Error::new_spanned(
159                other,
160                "`checked!` only supports the `+`, `-`, `*`, `/`, and `%` operators",
161            ));
162        }
163    };
164    Ok(format_ident!("{}", name))
165}