miden-node-tracing-macro 0.15.1

Procedural macros for Miden node tracing
Documentation
use std::collections::BTreeSet;

use proc_macro::TokenStream;
use proc_macro2::{Delimiter, Group, TokenStream as TokenStream2, TokenTree};
use quote::{ToTokens, quote};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::token::Dot;
use syn::visit::Visit;
use syn::{Block, Expr, Ident, ItemFn, Macro, Result, Token, parse_macro_input, parse_quote};

const ALLOWED_FIELD_NAMES: &[&str] = &[
    "account.id",
    "account.id.network_prefix",
    "account.ids",
    "account.ids.count",
    "account.updated",
    "batch.id",
    "batch.account_updates.count",
    "batch.expires_at",
    "batch.expiration_height",
    "batch.input_notes.count",
    "batch.output_notes.count",
    "batch.reference_block.commitment",
    "batch.reference_block.number",
    "block.batch.ids",
    "block.batches.count",
    "block.batches.output_notes.count",
    "block.commitment",
    "block.commitments.account",
    "block.commitments.chain",
    "block.commitments.kernel",
    "block.commitments.note",
    "block.commitments.nullifier",
    "block.commitments.transaction",
    "block.erased_note_proofs.count",
    "block.erased_notes.count",
    "block.from",
    "block.nullifiers.count",
    "block.number",
    "block.output_notes.count",
    "block.prev_block_commitment",
    "block.protocol.version",
    "block.sub_commitment",
    "block.timestamp",
    "block.transactions.ids",
    "block.transactions.count",
    "block.updated_accounts.count",
    "block_range.from",
    "block_range.to",
    "db.account_state_forest.size",
    "db.account_tree.size",
    "db.block_store.size",
    "db.nullifier_tree.size",
    "db.sqlite.size",
    "db.sqlite.wal.size",
    "dice_roll",
    "failure_rate",
    "mempool.accounts",
    "mempool.batches.proposed",
    "mempool.batches.proven",
    "mempool.nullifiers",
    "mempool.output_notes",
    "mempool.transactions.unbatched",
    "mempool.transactions.uncommitted",
    "note.id",
    "notes.count",
    "prover.kind",
    "reference_block.number",
    "request.kind",
    "script.root",
    "transaction.id",
    "transaction.expires_at",
    "transaction.input_notes.count",
    "transaction.output_notes.count",
    "transaction.reference_block.commitment",
    "transaction.reference_block.number",
    "tip.number",
    "transactions.count",
    "transactions.ids",
    "transactions.input_notes.count",
    "transactions.output_notes.count",
    "transactions.unauthenticated_notes.count",
    "workers.active",
    "workers.capacity",
    "workers.count",
];

#[proc_macro_attribute]
pub fn miden_instrument(attr: TokenStream, item: TokenStream) -> TokenStream {
    let attr = TokenStream2::from(attr);
    let mut function = parse_macro_input!(item as ItemFn);
    let fields = collect_recorded_fields(&function);
    let args = merge_inferred_fields(attr, &fields);
    let statements = &function.block.stmts;
    let block: Block = parse_quote! {{
        #[allow(unused_macros)]
        macro_rules! __miden_span_record_must_be_used_within_miden_instrument {
            () => {};
        }

        #(#statements)*
    }};
    *function.block = block;

    let expanded = quote! {
        #[::tracing::instrument(#args)]
        #function
    };

    expanded.into()
}

fn merge_inferred_fields(attr: TokenStream2, fields: &[FieldPath]) -> TokenStream2 {
    if fields.is_empty() {
        return attr;
    }

    let inferred_fields = quote! { #(#fields = ::tracing::field::Empty),* };
    if attr.is_empty() {
        return quote! { fields(#inferred_fields) };
    }

    let mut merged_existing_fields = false;
    let args = split_top_level_args(attr)
        .into_iter()
        .map(|arg| {
            if let Some(group) = fields_group(&arg) {
                merged_existing_fields = true;
                let existing_fields = group.stream();
                let merged_fields = if existing_fields.is_empty() {
                    inferred_fields.clone()
                } else if ends_with_comma(&existing_fields) {
                    quote! { #existing_fields #inferred_fields }
                } else {
                    quote! { #existing_fields, #inferred_fields }
                };
                let mut merged_group = Group::new(Delimiter::Parenthesis, merged_fields);
                merged_group.set_span(group.span());
                quote! { fields #merged_group }
            } else {
                arg
            }
        })
        .collect::<Vec<_>>();

    if merged_existing_fields {
        quote! { #(#args),* }
    } else {
        quote! { #(#args,)* fields(#inferred_fields) }
    }
}

fn split_top_level_args(tokens: TokenStream2) -> Vec<TokenStream2> {
    let mut args = Vec::new();
    let mut current = TokenStream2::new();

    for token in tokens {
        match &token {
            TokenTree::Punct(punct) if punct.as_char() == ',' => {
                args.push(current);
                current = TokenStream2::new();
            },
            _ => current.extend([token]),
        }
    }

    if !current.is_empty() {
        args.push(current);
    }

    args
}

fn fields_group(arg: &TokenStream2) -> Option<Group> {
    let mut tokens = arg.clone().into_iter();
    let Some(TokenTree::Ident(ident)) = tokens.next() else {
        return None;
    };
    if ident != "fields" {
        return None;
    }

    let Some(TokenTree::Group(group)) = tokens.next() else {
        return None;
    };
    if group.delimiter() != Delimiter::Parenthesis || tokens.next().is_some() {
        return None;
    }

    Some(group)
}

fn ends_with_comma(tokens: &TokenStream2) -> bool {
    matches!(
        tokens.clone().into_iter().last(),
        Some(TokenTree::Punct(punct)) if punct.as_char() == ','
    )
}

#[proc_macro]
pub fn miden_span_record(input: TokenStream) -> TokenStream {
    let records = parse_macro_input!(input as RecordFields);
    let records = records.fields.into_iter().map(|field| {
        let name = field.path.name();
        let value = field.value.value_tokens();

        quote! {
            ::tracing::Span::current().record(#name, #value);
        }
    });

    quote! {
        __miden_span_record_must_be_used_within_miden_instrument!();
        #(#records)*
    }
    .into()
}

fn validate_field_name(path: &FieldPath) -> Result<()> {
    let name = path.name();

    if ALLOWED_FIELD_NAMES.contains(&name.as_str()) {
        Ok(())
    } else {
        Err(syn::Error::new_spanned(
            path,
            format!(
                "unsupported tracing field `{name}`; use one of: {}",
                ALLOWED_FIELD_NAMES.join(", "),
            ),
        ))
    }
}

fn collect_recorded_fields(function: &ItemFn) -> Vec<FieldPath> {
    let mut visitor = MacroVisitor::default();
    visitor.visit_block(&function.block);

    let mut names = BTreeSet::new();
    visitor.fields.into_iter().filter(|field| names.insert(field.name())).collect()
}

#[derive(Default)]
struct MacroVisitor {
    fields: Vec<FieldPath>,
}

impl<'ast> Visit<'ast> for MacroVisitor {
    fn visit_macro(&mut self, mac: &'ast Macro) {
        if mac
            .path
            .segments
            .last()
            .is_some_and(|segment| segment.ident == "miden_span_record")
        {
            if let Ok(records) = syn::parse2::<RecordFields>(mac.tokens.clone()) {
                self.fields.extend(records.fields.into_iter().map(|field| field.path));
            }
        }

        syn::visit::visit_macro(self, mac);
    }
}

struct RecordFields {
    fields: Punctuated<RecordField, Token![,]>,
}

impl Parse for RecordFields {
    fn parse(input: ParseStream<'_>) -> Result<Self> {
        Ok(Self {
            fields: Punctuated::parse_terminated(input)?,
        })
    }
}

struct RecordField {
    path: FieldPath,
    value: RecordValue,
}

impl Parse for RecordField {
    fn parse(input: ParseStream<'_>) -> Result<Self> {
        let path = input.parse()?;
        validate_field_name(&path)?;
        input.parse::<Token![=]>()?;
        let value = input.parse()?;

        Ok(Self { path, value })
    }
}

struct FieldPath {
    first: Ident,
    rest: Vec<(Dot, Ident)>,
}

impl FieldPath {
    fn name(&self) -> String {
        std::iter::once(&self.first)
            .chain(self.rest.iter().map(|(_, ident)| ident))
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(".")
    }
}

impl Parse for FieldPath {
    fn parse(input: ParseStream<'_>) -> Result<Self> {
        let first = input.parse()?;
        let mut rest = Vec::new();

        while input.peek(Token![.]) {
            rest.push((input.parse()?, input.parse()?));
        }

        Ok(Self { first, rest })
    }
}

impl ToTokens for FieldPath {
    fn to_tokens(&self, tokens: &mut TokenStream2) {
        self.first.to_tokens(tokens);
        for (dot, ident) in &self.rest {
            dot.to_tokens(tokens);
            ident.to_tokens(tokens);
        }
    }
}

struct RecordValue {
    formatter: Formatter,
    expr: Expr,
}

impl RecordValue {
    fn value_tokens(&self) -> TokenStream2 {
        let expr = &self.expr;

        match self.formatter {
            Formatter::Display => quote! { &::tracing::field::display(#expr) },
            Formatter::Debug => quote! { &::tracing::field::debug(#expr) },
            Formatter::Plain => quote! { &#expr },
        }
    }
}

impl Parse for RecordValue {
    fn parse(input: ParseStream<'_>) -> Result<Self> {
        let formatter = if input.peek(Token![%]) {
            input.parse::<Token![%]>()?;
            Formatter::Display
        } else if input.peek(Token![?]) {
            input.parse::<Token![?]>()?;
            Formatter::Debug
        } else {
            Formatter::Plain
        };
        let expr = input.parse()?;

        Ok(Self { formatter, expr })
    }
}

enum Formatter {
    Display,
    Debug,
    Plain,
}