Skip to main content

miden_node_tracing_macro/
lib.rs

1use std::collections::BTreeSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Delimiter, Group, TokenStream as TokenStream2, TokenTree};
5use quote::{ToTokens, quote};
6use syn::parse::{Parse, ParseStream};
7use syn::punctuated::Punctuated;
8use syn::token::Dot;
9use syn::visit::Visit;
10use syn::{Block, Expr, Ident, ItemFn, Macro, Result, Token, parse_macro_input, parse_quote};
11
12const ALLOWED_FIELD_NAMES: &[&str] = &[
13    "account.id",
14    "account.id.network_prefix",
15    "account.ids",
16    "account.ids.count",
17    "account.updated",
18    "batch.id",
19    "batch.account_updates.count",
20    "batch.expires_at",
21    "batch.expiration_height",
22    "batch.input_notes.count",
23    "batch.output_notes.count",
24    "batch.reference_block.commitment",
25    "batch.reference_block.number",
26    "block.batch.ids",
27    "block.batches.count",
28    "block.batches.output_notes.count",
29    "block.commitment",
30    "block.commitments.account",
31    "block.commitments.chain",
32    "block.commitments.kernel",
33    "block.commitments.note",
34    "block.commitments.nullifier",
35    "block.commitments.transaction",
36    "block.erased_note_proofs.count",
37    "block.erased_notes.count",
38    "block.from",
39    "block.nullifiers.count",
40    "block.number",
41    "block.output_notes.count",
42    "block.prev_block_commitment",
43    "block.protocol.version",
44    "block.sub_commitment",
45    "block.timestamp",
46    "block.transactions.ids",
47    "block.transactions.count",
48    "block.updated_accounts.count",
49    "block_range.from",
50    "block_range.to",
51    "db.account_state_forest.size",
52    "db.account_tree.size",
53    "db.block_store.size",
54    "db.nullifier_tree.size",
55    "db.sqlite.size",
56    "db.sqlite.wal.size",
57    "dice_roll",
58    "failure_rate",
59    "mempool.accounts",
60    "mempool.batches.proposed",
61    "mempool.batches.proven",
62    "mempool.nullifiers",
63    "mempool.output_notes",
64    "mempool.transactions.unbatched",
65    "mempool.transactions.uncommitted",
66    "note.id",
67    "notes.count",
68    "prover.kind",
69    "reference_block.number",
70    "request.kind",
71    "script.root",
72    "transaction.id",
73    "transaction.expires_at",
74    "transaction.input_notes.count",
75    "transaction.output_notes.count",
76    "transaction.reference_block.commitment",
77    "transaction.reference_block.number",
78    "tip.number",
79    "transactions.count",
80    "transactions.ids",
81    "transactions.input_notes.count",
82    "transactions.output_notes.count",
83    "transactions.unauthenticated_notes.count",
84    "workers.active",
85    "workers.capacity",
86    "workers.count",
87];
88
89#[proc_macro_attribute]
90pub fn miden_instrument(attr: TokenStream, item: TokenStream) -> TokenStream {
91    let attr = TokenStream2::from(attr);
92    let mut function = parse_macro_input!(item as ItemFn);
93    let fields = collect_recorded_fields(&function);
94    let args = merge_inferred_fields(attr, &fields);
95    let statements = &function.block.stmts;
96    let block: Block = parse_quote! {{
97        #[allow(unused_macros)]
98        macro_rules! __miden_span_record_must_be_used_within_miden_instrument {
99            () => {};
100        }
101
102        #(#statements)*
103    }};
104    *function.block = block;
105
106    let expanded = quote! {
107        #[::tracing::instrument(#args)]
108        #function
109    };
110
111    expanded.into()
112}
113
114fn merge_inferred_fields(attr: TokenStream2, fields: &[FieldPath]) -> TokenStream2 {
115    if fields.is_empty() {
116        return attr;
117    }
118
119    let inferred_fields = quote! { #(#fields = ::tracing::field::Empty),* };
120    if attr.is_empty() {
121        return quote! { fields(#inferred_fields) };
122    }
123
124    let mut merged_existing_fields = false;
125    let args = split_top_level_args(attr)
126        .into_iter()
127        .map(|arg| {
128            if let Some(group) = fields_group(&arg) {
129                merged_existing_fields = true;
130                let existing_fields = group.stream();
131                let merged_fields = if existing_fields.is_empty() {
132                    inferred_fields.clone()
133                } else if ends_with_comma(&existing_fields) {
134                    quote! { #existing_fields #inferred_fields }
135                } else {
136                    quote! { #existing_fields, #inferred_fields }
137                };
138                let mut merged_group = Group::new(Delimiter::Parenthesis, merged_fields);
139                merged_group.set_span(group.span());
140                quote! { fields #merged_group }
141            } else {
142                arg
143            }
144        })
145        .collect::<Vec<_>>();
146
147    if merged_existing_fields {
148        quote! { #(#args),* }
149    } else {
150        quote! { #(#args,)* fields(#inferred_fields) }
151    }
152}
153
154fn split_top_level_args(tokens: TokenStream2) -> Vec<TokenStream2> {
155    let mut args = Vec::new();
156    let mut current = TokenStream2::new();
157
158    for token in tokens {
159        match &token {
160            TokenTree::Punct(punct) if punct.as_char() == ',' => {
161                args.push(current);
162                current = TokenStream2::new();
163            },
164            _ => current.extend([token]),
165        }
166    }
167
168    if !current.is_empty() {
169        args.push(current);
170    }
171
172    args
173}
174
175fn fields_group(arg: &TokenStream2) -> Option<Group> {
176    let mut tokens = arg.clone().into_iter();
177    let Some(TokenTree::Ident(ident)) = tokens.next() else {
178        return None;
179    };
180    if ident != "fields" {
181        return None;
182    }
183
184    let Some(TokenTree::Group(group)) = tokens.next() else {
185        return None;
186    };
187    if group.delimiter() != Delimiter::Parenthesis || tokens.next().is_some() {
188        return None;
189    }
190
191    Some(group)
192}
193
194fn ends_with_comma(tokens: &TokenStream2) -> bool {
195    matches!(
196        tokens.clone().into_iter().last(),
197        Some(TokenTree::Punct(punct)) if punct.as_char() == ','
198    )
199}
200
201#[proc_macro]
202pub fn miden_span_record(input: TokenStream) -> TokenStream {
203    let records = parse_macro_input!(input as RecordFields);
204    let records = records.fields.into_iter().map(|field| {
205        let name = field.path.name();
206        let value = field.value.value_tokens();
207
208        quote! {
209            ::tracing::Span::current().record(#name, #value);
210        }
211    });
212
213    quote! {
214        __miden_span_record_must_be_used_within_miden_instrument!();
215        #(#records)*
216    }
217    .into()
218}
219
220fn validate_field_name(path: &FieldPath) -> Result<()> {
221    let name = path.name();
222
223    if ALLOWED_FIELD_NAMES.contains(&name.as_str()) {
224        Ok(())
225    } else {
226        Err(syn::Error::new_spanned(
227            path,
228            format!(
229                "unsupported tracing field `{name}`; use one of: {}",
230                ALLOWED_FIELD_NAMES.join(", "),
231            ),
232        ))
233    }
234}
235
236fn collect_recorded_fields(function: &ItemFn) -> Vec<FieldPath> {
237    let mut visitor = MacroVisitor::default();
238    visitor.visit_block(&function.block);
239
240    let mut names = BTreeSet::new();
241    visitor.fields.into_iter().filter(|field| names.insert(field.name())).collect()
242}
243
244#[derive(Default)]
245struct MacroVisitor {
246    fields: Vec<FieldPath>,
247}
248
249impl<'ast> Visit<'ast> for MacroVisitor {
250    fn visit_macro(&mut self, mac: &'ast Macro) {
251        if mac
252            .path
253            .segments
254            .last()
255            .is_some_and(|segment| segment.ident == "miden_span_record")
256        {
257            if let Ok(records) = syn::parse2::<RecordFields>(mac.tokens.clone()) {
258                self.fields.extend(records.fields.into_iter().map(|field| field.path));
259            }
260        }
261
262        syn::visit::visit_macro(self, mac);
263    }
264}
265
266struct RecordFields {
267    fields: Punctuated<RecordField, Token![,]>,
268}
269
270impl Parse for RecordFields {
271    fn parse(input: ParseStream<'_>) -> Result<Self> {
272        Ok(Self {
273            fields: Punctuated::parse_terminated(input)?,
274        })
275    }
276}
277
278struct RecordField {
279    path: FieldPath,
280    value: RecordValue,
281}
282
283impl Parse for RecordField {
284    fn parse(input: ParseStream<'_>) -> Result<Self> {
285        let path = input.parse()?;
286        validate_field_name(&path)?;
287        input.parse::<Token![=]>()?;
288        let value = input.parse()?;
289
290        Ok(Self { path, value })
291    }
292}
293
294struct FieldPath {
295    first: Ident,
296    rest: Vec<(Dot, Ident)>,
297}
298
299impl FieldPath {
300    fn name(&self) -> String {
301        std::iter::once(&self.first)
302            .chain(self.rest.iter().map(|(_, ident)| ident))
303            .map(ToString::to_string)
304            .collect::<Vec<_>>()
305            .join(".")
306    }
307}
308
309impl Parse for FieldPath {
310    fn parse(input: ParseStream<'_>) -> Result<Self> {
311        let first = input.parse()?;
312        let mut rest = Vec::new();
313
314        while input.peek(Token![.]) {
315            rest.push((input.parse()?, input.parse()?));
316        }
317
318        Ok(Self { first, rest })
319    }
320}
321
322impl ToTokens for FieldPath {
323    fn to_tokens(&self, tokens: &mut TokenStream2) {
324        self.first.to_tokens(tokens);
325        for (dot, ident) in &self.rest {
326            dot.to_tokens(tokens);
327            ident.to_tokens(tokens);
328        }
329    }
330}
331
332struct RecordValue {
333    formatter: Formatter,
334    expr: Expr,
335}
336
337impl RecordValue {
338    fn value_tokens(&self) -> TokenStream2 {
339        let expr = &self.expr;
340
341        match self.formatter {
342            Formatter::Display => quote! { &::tracing::field::display(#expr) },
343            Formatter::Debug => quote! { &::tracing::field::debug(#expr) },
344            Formatter::Plain => quote! { &#expr },
345        }
346    }
347}
348
349impl Parse for RecordValue {
350    fn parse(input: ParseStream<'_>) -> Result<Self> {
351        let formatter = if input.peek(Token![%]) {
352            input.parse::<Token![%]>()?;
353            Formatter::Display
354        } else if input.peek(Token![?]) {
355            input.parse::<Token![?]>()?;
356            Formatter::Debug
357        } else {
358            Formatter::Plain
359        };
360        let expr = input.parse()?;
361
362        Ok(Self { formatter, expr })
363    }
364}
365
366enum Formatter {
367    Display,
368    Debug,
369    Plain,
370}