miden_node_grpc_error_macro/
lib.rs

//! Procedural macro backing `#[derive(GrpcError)]` for error enums.
//!
//! This macro simplifies the creation of gRPC-compatible error enums by automatically:
//! - Generating a companion error enum for gRPC serialization
//! - Implementing `api_code()`, `is_internal()` and `tonic_code()` on the companion enum
//! - Providing an `api_error()` mapping from the original enum to the companion enum
//! - Generating `From<Error> for tonic::Status` conversion
8//!
9//! # Example
10//!
11//! ```rust,ignore
12//! use miden_node_grpc_error_macro::GrpcError;
13//! use thiserror::Error;
14//!
15//! #[derive(Debug, Error, GrpcError)]
16//! pub enum GetNoteScriptByRootError {
17//!     #[error("database error")]
18//!     #[grpc(internal)]
19//!     DatabaseError(#[from] DatabaseError),
20//!     
21//!     #[error("malformed script root")]
22//!     DeserializationFailed,
23//!     
24//!     #[error("script with given root doesn't exist")]
25//!     ScriptNotFound,
26//! }
27//! ```
28
29use proc_macro::TokenStream;
30use quote::quote;
31use syn::{Data, DeriveInput, Fields, Ident, parse_macro_input};
32
33/// Derives the `GrpcError` trait for an error enum.
34///
35/// # Attributes
36///
37/// - `#[grpc(internal)]` - Marks a variant as an internal error (will map to
38///   `tonic::Code::Internal`)
39///
40/// # Generated Code
41///
42/// This macro generates:
43/// 1. A companion `*GrpcError` enum with `#[repr(u8)]` for wire serialization
44/// 2. An implementation of the `GrpcError` trait for the companion enum
45/// 3. A method `api_error()` on the original enum that maps to the companion enum
46/// 4. An implementation of `From<Error> for tonic::Status` for automatic error conversion
47#[proc_macro_derive(GrpcError, attributes(grpc))]
48pub fn derive_grpc_error(input: TokenStream) -> TokenStream {
49    let input = parse_macro_input!(input as DeriveInput);
50
51    let name = &input.ident;
52    let vis = &input.vis;
53    let grpc_name = Ident::new(&format!("{name}GrpcError"), name.span());
54
55    let variants = match &input.data {
56        Data::Enum(data) => &data.variants,
57        _ => {
58            return syn::Error::new_spanned(name, "GrpcError can only be derived for enums")
59                .to_compile_error()
60                .into();
61        },
62    };
63
64    // Build the GrpcError enum variants
65    let mut grpc_variants = Vec::new();
66    let mut api_error_arms = Vec::new();
67
68    // Always add Internal variant (standard practice for gRPC errors)
69    grpc_variants.push(quote! {
70        /// Internal server error
71        Internal = 0
72    });
73    let mut discriminant = 1u8;
74
75    for variant in variants {
76        let variant_name = &variant.ident;
77
78        // Check if this variant is marked as internal
79        let is_internal = variant.attrs.iter().any(|attr| {
80            attr.path().is_ident("grpc")
81                && attr.parse_args::<Ident>().map(|i| i == "internal").unwrap_or(false)
82        });
83
84        // Extract doc comments
85        let docs: Vec<_> =
86            variant.attrs.iter().filter(|attr| attr.path().is_ident("doc")).collect();
87
88        if is_internal {
89            // Map to Internal variant
90            let pattern = match &variant.fields {
91                Fields::Unit => quote! { #name::#variant_name },
92                Fields::Unnamed(_) => quote! { #name::#variant_name(..) },
93                Fields::Named(_) => quote! { #name::#variant_name { .. } },
94            };
95
96            api_error_arms.push(quote! {
97                #pattern => #grpc_name::Internal
98            });
99        } else {
100            // Create a corresponding variant in GrpcError enum
101            grpc_variants.push(quote! {
102                #(#docs)*
103                #variant_name = #discriminant
104            });
105
106            let pattern = match &variant.fields {
107                Fields::Unit => quote! { #name::#variant_name },
108                Fields::Unnamed(_) => quote! { #name::#variant_name(..) },
109                Fields::Named(_) => quote! { #name::#variant_name { .. } },
110            };
111
112            api_error_arms.push(quote! {
113                #pattern => #grpc_name::#variant_name
114            });
115
116            discriminant += 1;
117        }
118    }
119
120    let expanded = quote! {
121        #[derive(Debug, Copy, Clone, PartialEq, Eq)]
122        #[repr(u8)]
123        #vis enum #grpc_name {
124            #(#grpc_variants,)*
125        }
126
127        impl #grpc_name {
128            /// Returns the error code for this gRPC error.
129            pub fn api_code(self) -> u8 {
130                self as u8
131            }
132
133            /// Returns true if this is an internal server error.
134            pub fn is_internal(&self) -> bool {
135                matches!(self, Self::Internal)
136            }
137
138            /// Returns the appropriate tonic code for this error.
139            pub fn tonic_code(&self) -> tonic::Code {
140                if self.is_internal() {
141                    tonic::Code::Internal
142                } else {
143                    tonic::Code::InvalidArgument
144                }
145            }
146        }
147
148        impl #name {
149            /// Maps this error to its gRPC error code representation.
150            pub fn api_error(&self) -> #grpc_name {
151                match self {
152                    #(#api_error_arms,)*
153                }
154            }
155        }
156
157        // Automatically implement From<Error> for tonic::Status
158        impl From<#name> for tonic::Status {
159            fn from(value: #name) -> Self {
160                let api_error = value.api_error();
161
162                let message = if api_error.is_internal() {
163                    "Internal error".to_owned()
164                } else {
165                    // Use ErrorReport trait to get detailed error message
166                    use miden_node_utils::ErrorReport as _;
167                    value.as_report()
168                };
169
170                tonic::Status::with_details(
171                    api_error.tonic_code(),
172                    message,
173                    vec![api_error.api_code()].into(),
174                )
175            }
176        }
177    };
178
179    TokenStream::from(expanded)
180}