financial_ops_macros/lib.rs
1//! Procedural macros for [`financial-ops`](https://docs.rs/financial-ops).
2//!
//! This crate provides the [`checked!`] macro, which recursively rewrites an
//! ordinary arithmetic expression into a chain of the standard library's
//! checked arithmetic methods (`checked_add`, `checked_sub`, `checked_mul`,
//! `checked_div`, `checked_rem`), preserving the operator precedence and
//! grouping of the original expression.
8//!
9//! You normally do not depend on this crate directly; the macro is re-exported
10//! from `financial-ops` as `financial_ops::checked`.
11
12use proc_macro::TokenStream;
13use proc_macro2::{Spacing, Span, TokenStream as TokenStream2, TokenTree};
14use quote::{format_ident, quote};
15use syn::{BinOp, Expr};
16
17/// Recursively rewrites an arithmetic expression into checked arithmetic.
18///
19/// The macro walks the parsed expression tree, so it honors normal Rust
20/// operator precedence and parentheses. Every binary operator is replaced by
21/// its checked counterpart and the operands are threaded together so each
22/// sub-expression is evaluated exactly once, in left-to-right order.
23///
24/// # Result vs. Option
25///
26/// * Without an error, the macro evaluates to an `Option<T>` (it is `None` as
27/// soon as any step overflows or divides by zero):
28///
29/// ```ignore
30/// let total: Option<u64> = checked! { a + b * c };
31/// ```
32///
33/// * With a trailing `@ <error expression>`, the macro evaluates to a
34/// `Result<T, E>`, mapping the first failing step to the given error:
35///
36/// ```ignore
37/// let total = checked! { a + b * c @ MyError::Overflow }?;
38/// ```
39///
40/// # Supported operators
41///
42/// `+` → `checked_add`, `-` → `checked_sub`, `*` → `checked_mul`,
43/// `/` → `checked_div`, `%` → `checked_rem`.
44///
45/// # Examples
46///
47/// ```ignore
48/// use financial_ops::checked;
49///
50/// // Respects precedence: this is `a + (b * c)`, fully checked.
51/// let value = checked! { 2u64 + 3 * 4 };
52/// assert_eq!(value, Some(14));
53///
54/// // Overflow short-circuits to `None`.
55/// assert_eq!(checked! { u8::MAX + 1u8 }, None);
56/// ```
57#[proc_macro]
58pub fn checked(input: TokenStream) -> TokenStream {
59 let input2: TokenStream2 = input.into();
60 match parse_input(input2) {
61 Ok((expr, error)) => match expand(&expr, error.as_ref(), 0) {
62 Ok(tokens) => tokens.into(),
63 Err(err) => err.to_compile_error().into(),
64 },
65 Err(err) => err.to_compile_error().into(),
66 }
67}
68
69/// Splits the macro input into the arithmetic expression and an optional error
70/// expression separated by a top-level `@`.
71fn parse_input(tokens: TokenStream2) -> syn::Result<(Expr, Option<Expr>)> {
72 let mut expr_tokens = TokenStream2::new();
73 let mut error_tokens = TokenStream2::new();
74 let mut found_at = false;
75
76 for tt in tokens {
77 if !found_at {
78 if let TokenTree::Punct(ref punct) = tt {
79 if punct.as_char() == '@' && punct.spacing() == Spacing::Alone {
80 found_at = true;
81 continue;
82 }
83 }
84 expr_tokens.extend(std::iter::once(tt));
85 } else {
86 error_tokens.extend(std::iter::once(tt));
87 }
88 }
89
90 if expr_tokens.is_empty() {
91 return Err(syn::Error::new(
92 Span::call_site(),
93 "expected an arithmetic expression",
94 ));
95 }
96
97 let expr: Expr = syn::parse2(expr_tokens)?;
98
99 let error = if found_at {
100 if error_tokens.is_empty() {
101 return Err(syn::Error::new(
102 Span::call_site(),
103 "expected an error expression after `@`",
104 ));
105 }
106 Some(syn::parse2::<Expr>(error_tokens)?)
107 } else {
108 None
109 };
110
111 Ok((expr, error))
112}
113
114/// Recursively expands `expr` into a checked-arithmetic token stream.
115///
116/// `depth` is used to generate non-colliding binding identifiers: a node at
117/// depth `d` binds `__l{d}` / `__r{d}`, and its children recurse at `d + 1`, so
118/// a child's bindings can never shadow an ancestor's.
119fn expand(expr: &Expr, error: Option<&Expr>, depth: usize) -> syn::Result<TokenStream2> {
120 match expr {
121 Expr::Binary(binary) => {
122 let method = method_name(&binary.op)?;
123 let lhs = expand(&binary.left, error, depth + 1)?;
124 let rhs = expand(&binary.right, error, depth + 1)?;
125
126 let l = format_ident!("__l{}", depth);
127 let r = format_ident!("__r{}", depth);
128
129 let combine = match error {
130 Some(err) => quote! { #l.#method(#r).ok_or(#err) },
131 None => quote! { #l.#method(#r) },
132 };
133
134 Ok(quote! {
135 (#lhs).and_then(|#l| (#rhs).and_then(|#r| #combine))
136 })
137 }
138 // Transparent wrappers: keep recursing without consuming a depth level.
139 Expr::Paren(paren) => expand(&paren.expr, error, depth),
140 Expr::Group(group) => expand(&group.expr, error, depth),
141 // Any other expression is a leaf value to be lifted into the chain.
142 leaf => Ok(match error {
143 Some(_) => quote! { ::core::result::Result::<_, _>::Ok(#leaf) },
144 None => quote! { ::core::option::Option::Some(#leaf) },
145 }),
146 }
147}
148
149/// Maps a binary operator to its checked-method identifier.
150fn method_name(op: &BinOp) -> syn::Result<proc_macro2::Ident> {
151 let name = match op {
152 BinOp::Add(_) => "checked_add",
153 BinOp::Sub(_) => "checked_sub",
154 BinOp::Mul(_) => "checked_mul",
155 BinOp::Div(_) => "checked_div",
156 BinOp::Rem(_) => "checked_rem",
157 other => {
158 return Err(syn::Error::new_spanned(
159 other,
160 "`checked!` only supports the `+`, `-`, `*`, `/`, and `%` operators",
161 ));
162 }
163 };
164 Ok(format_ident!("{}", name))
165}