Skip to main content

commonware_runtime_macros/
lib.rs

1//! Augment the development of [`commonware-runtime`](https://docs.rs/commonware-runtime) with procedural macros.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7
8use proc_macro::TokenStream;
9use proc_macro2::Span;
10use proc_macro_crate::{crate_name, FoundCrate};
11use quote::quote;
12use syn::{parse_quote, DeriveInput, Ident};
13
14fn found_crate_path(found: FoundCrate) -> proc_macro2::TokenStream {
15    match found {
16        FoundCrate::Itself => quote!(crate),
17        FoundCrate::Name(name) => {
18            let ident = Ident::new(&name, Span::call_site());
19            quote!(::#ident)
20        }
21    }
22}
23
24// The upstream derive crate hardcodes `::prometheus_client::encoding` in the
25// generated impls.
26//
27// Source: https://github.com/prometheus/client_rust/blob/7844d8617926a6f29b772d195860cf118051d019/derive-encode/src/lib.rs#L14-L133
28//
29// Commonware resolves through `commonware-runtime::telemetry::metrics::encoding`
30// first so downstream crates can derive metric labels without a direct
31// `prometheus-client` dependency.
32fn metrics_encoding_path() -> proc_macro2::TokenStream {
33    if let Ok(found) = crate_name("commonware-runtime") {
34        let runtime = found_crate_path(found);
35        return quote!(#runtime::telemetry::metrics::encoding);
36    }
37    if let Ok(found) = crate_name("prometheus-client") {
38        let prometheus = found_crate_path(found);
39        return quote!(#prometheus::encoding);
40    }
41    quote!(::prometheus_client::encoding)
42}
43
44// Adapted from client_rust's `EncodeLabelSet` derive and extended to support
45// Commonware's `EncodeStruct` variant.
46//
47// Source: https://github.com/prometheus/client_rust/blob/7844d8617926a6f29b772d195860cf118051d019/derive-encode/src/lib.rs#L14-L87
48#[proc_macro_derive(EncodeLabelSet, attributes(prometheus))]
49pub fn derive_encode_label_set(input: TokenStream) -> TokenStream {
50    derive_label_set_impl(input, false)
51}
52
53#[proc_macro_derive(EncodeStruct)]
54pub fn derive_encode_struct(input: TokenStream) -> TokenStream {
55    derive_label_set_impl(input, true)
56}
57
58fn derive_label_set_impl(input: TokenStream, display: bool) -> TokenStream {
59    let ast: DeriveInput = syn::parse(input).unwrap();
60    let name = &ast.ident;
61    let encoding = metrics_encoding_path();
62
63    let fields = match ast.clone().data {
64        syn::Data::Struct(s) => match s.fields {
65            syn::Fields::Named(syn::FieldsNamed { named, .. }) => named,
66            syn::Fields::Unnamed(_) => {
67                panic!("Can not derive Encode for struct with unnamed fields.")
68            }
69            syn::Fields::Unit => panic!("Can not derive Encode for struct with unit field."),
70        },
71        syn::Data::Enum(syn::DataEnum { .. }) => panic!("Can not derive Encode for enum."),
72        syn::Data::Union(_) => panic!("Can not derive Encode for union."),
73    };
74
75    let fields_vec: Vec<_> = fields.into_iter().collect();
76    let body: proc_macro2::TokenStream = fields_vec
77        .iter()
78        .cloned()
79        .map(|f| {
80            let attribute = f
81                .attrs
82                .iter()
83                .find(|a| a.path().is_ident("prometheus"))
84                .map(|a| a.parse_args::<Ident>().unwrap().to_string());
85            let flatten = match attribute.as_deref() {
86                Some("flatten") => true,
87                Some(other) => {
88                    panic!("Provided field attribute '{other}', but only 'flatten' is supported")
89                }
90                None => false,
91            };
92            let ident = f.ident.unwrap();
93            if flatten {
94                quote! {
95                    #encoding::EncodeLabelSet::encode(&self.#ident, encoder)?;
96                }
97            } else {
98                let ident_string = KEYWORD_IDENTIFIERS
99                    .iter()
100                    .find(|pair| ident == pair.1)
101                    .map(|pair| pair.0.to_string())
102                    .unwrap_or_else(|| ident.to_string());
103
104                let encode_value = if display {
105                    quote! {
106                        ::core::write!(&mut label_value_encoder, "{}", self.#ident)?;
107                    }
108                } else {
109                    quote! {
110                        EncodeLabelValue::encode(&self.#ident, &mut label_value_encoder)?;
111                    }
112                };
113
114                quote! {
115                    let mut label_encoder = encoder.encode_label();
116                    let mut label_key_encoder = label_encoder.encode_label_key()?;
117                    EncodeLabelKey::encode(&#ident_string, &mut label_key_encoder)?;
118
119                    let mut label_value_encoder = label_key_encoder.encode_label_value()?;
120                    #encode_value
121
122                    label_value_encoder.finish()?;
123                }
124            }
125        })
126        .collect();
127
128    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
129
130    let single_field_impls = if display && fields_vec.len() == 1 {
131        let field = &fields_vec[0];
132        let field_ident = field.ident.as_ref().unwrap();
133        let field_ty = &field.ty;
134        // Preserve the wrapper's own predicates and add `Clone` on the field type,
135        // so wrappers that write their bounds in a `where` clause (not inline)
136        // still get a well-formed `From<&T>` impl.
137        let mut from_generics = ast.generics.clone();
138        from_generics
139            .make_where_clause()
140            .predicates
141            .push(parse_quote!(#field_ty: ::core::clone::Clone));
142        let (from_impl_generics, from_ty_generics, from_where_clause) =
143            from_generics.split_for_impl();
144        quote! {
145            impl #impl_generics ::core::borrow::Borrow<#field_ty> for #name #ty_generics #where_clause {
146                fn borrow(&self) -> &#field_ty {
147                    &self.#field_ident
148                }
149            }
150
151            impl #from_impl_generics ::core::convert::From<&#field_ty> for #name #from_ty_generics #from_where_clause {
152                fn from(value: &#field_ty) -> Self {
153                    Self { #field_ident: value.clone() }
154                }
155            }
156        }
157    } else {
158        quote!()
159    };
160
161    quote! {
162        impl #impl_generics #encoding::EncodeLabelSet for #name #ty_generics #where_clause {
163            fn encode(&self, encoder: &mut #encoding::LabelSetEncoder) -> ::core::result::Result<(), ::core::fmt::Error> {
164                use #encoding::EncodeLabel;
165                use #encoding::EncodeLabelKey;
166                use #encoding::EncodeLabelValue;
167                use ::core::fmt::Write as _;
168
169                #body
170
171                ::core::result::Result::Ok(())
172            }
173        }
174
175        #single_field_impls
176    }
177    .into()
178}
179
180// Adapted from client_rust's `EncodeLabelValue` derive so the generated impls
181// resolve through `metrics_encoding_path()` instead of a hardcoded crate path.
182//
183// Source: https://github.com/prometheus/client_rust/blob/7844d8617926a6f29b772d195860cf118051d019/derive-encode/src/lib.rs#L90-L133
184#[proc_macro_derive(EncodeLabelValue)]
185pub fn derive_encode_label_value(input: TokenStream) -> TokenStream {
186    let ast: DeriveInput = syn::parse(input).unwrap();
187    let name = &ast.ident;
188    let encoding = metrics_encoding_path();
189
190    let body = match ast.clone().data {
191        syn::Data::Struct(_) => panic!("Can not derive EncodeLabel for struct."),
192        syn::Data::Enum(syn::DataEnum { variants, .. }) => {
193            let match_arms: proc_macro2::TokenStream = variants
194                .into_iter()
195                .map(|v| {
196                    let ident = v.ident;
197                    quote! {
198                        #name::#ident => encoder.write_str(stringify!(#ident))?,
199                    }
200                })
201                .collect();
202
203            quote! {
204                match self {
205                    #match_arms
206                }
207            }
208        }
209        syn::Data::Union(_) => panic!("Can not derive Encode for union."),
210    };
211
212    quote! {
213        impl #encoding::EncodeLabelValue for #name {
214            fn encode(&self, encoder: &mut #encoding::LabelValueEncoder) -> ::core::result::Result<(), ::core::fmt::Error> {
215                use ::core::fmt::Write;
216
217                #body
218
219                ::core::result::Result::Ok(())
220            }
221        }
222    }
223    .into()
224}
225
// Copied from client_rust's keyword table, which in turn cites Askama, and extended
// with the reserved keywords `try`, `yield` and `gen` that the upstream table omits.
//
// Source: https://github.com/prometheus/client_rust/blob/7844d8617926a6f29b772d195860cf118051d019/derive-encode/src/lib.rs#L135-L184
/// Maps each Rust keyword (first column) to its raw-identifier spelling (second column).
///
/// `derive_label_set_impl` looks up a field's identifier in the second column, so a
/// field declared as `r#type` emits the label key `type`. An identifier without an
/// entry here is stringified as-is, which would leak the `r#` prefix into the
/// Prometheus label name.
///
/// Rust rejects `r#crate`, `r#self`, `r#super` and `r#Self`, so those rows never
/// match; they are kept for parity with the upstream table.
static KEYWORD_IDENTIFIERS: [(&str, &str); 52] = [
    ("as", "r#as"),
    ("break", "r#break"),
    ("const", "r#const"),
    ("continue", "r#continue"),
    ("crate", "r#crate"),
    ("else", "r#else"),
    ("enum", "r#enum"),
    ("extern", "r#extern"),
    ("false", "r#false"),
    ("fn", "r#fn"),
    ("for", "r#for"),
    ("if", "r#if"),
    ("impl", "r#impl"),
    ("in", "r#in"),
    ("let", "r#let"),
    ("loop", "r#loop"),
    ("match", "r#match"),
    ("mod", "r#mod"),
    ("move", "r#move"),
    ("mut", "r#mut"),
    ("pub", "r#pub"),
    ("ref", "r#ref"),
    ("return", "r#return"),
    ("self", "r#self"),
    ("Self", "r#Self"),
    ("static", "r#static"),
    ("struct", "r#struct"),
    ("super", "r#super"),
    ("trait", "r#trait"),
    ("true", "r#true"),
    ("type", "r#type"),
    ("unsafe", "r#unsafe"),
    ("use", "r#use"),
    ("where", "r#where"),
    ("while", "r#while"),
    ("async", "r#async"),
    ("await", "r#await"),
    ("dyn", "r#dyn"),
    ("try", "r#try"),
    ("abstract", "r#abstract"),
    ("become", "r#become"),
    ("box", "r#box"),
    ("do", "r#do"),
    ("final", "r#final"),
    ("macro", "r#macro"),
    ("override", "r#override"),
    ("priv", "r#priv"),
    ("typeof", "r#typeof"),
    ("unsized", "r#unsized"),
    ("virtual", "r#virtual"),
    ("yield", "r#yield"),
    ("gen", "r#gen"),
];