Skip to main content

dial9_trace_format_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
5fn derive_trace_event_impl(input: DeriveInput) -> proc_macro2::TokenStream {
6    let name = &input.ident;
7    let vis = &input.vis;
8    let name_str = name.to_string();
9    let ref_name = format_ident!("{}Ref", name);
10    let struct_doc_attrs: Vec<_> = input
11        .attrs
12        .iter()
13        .filter(|a| a.path().is_ident("doc"))
14        .collect();
15
16    let fields = match &input.data {
17        Data::Struct(data) => match &data.fields {
18            Fields::Named(f) => &f.named,
19            _ => panic!("TraceEvent only supports named fields"),
20        },
21        _ => panic!("TraceEvent can only be derived for structs"),
22    };
23
24    // Find the field marked with #[traceevent(timestamp)]
25    let mut timestamp_field_name = None;
26    let mut timestamp_doc_attrs = Vec::new();
27    for field in fields.iter() {
28        for attr in &field.attrs {
29            if attr.path().is_ident("traceevent") {
30                let _ = attr.parse_nested_meta(|meta| {
31                    if meta.path.is_ident("timestamp") {
32                        timestamp_field_name = Some(field.ident.as_ref().unwrap().clone());
33                        timestamp_doc_attrs = field
34                            .attrs
35                            .iter()
36                            .filter(|a| a.path().is_ident("doc"))
37                            .cloned()
38                            .collect();
39                    }
40                    Ok(())
41                });
42            }
43        }
44    }
45
46    let mut field_def_tokens = Vec::new();
47    let mut encode_tokens = Vec::new();
48    let mut ref_fields = Vec::new();
49    let mut decode_tokens = Vec::new();
50
51    for field in fields.iter() {
52        let field_name = field.ident.as_ref().unwrap();
53        let ty = &field.ty;
54
55        // Skip the timestamp field in schema/encode — it's in the event header
56        if timestamp_field_name.as_ref() == Some(field_name) {
57            continue;
58        }
59
60        let field_name_str = field_name.to_string();
61        field_def_tokens.push(quote! {
62            ::dial9_trace_format::schema::FieldDef::new(
63                #field_name_str,
64                <#ty as ::dial9_trace_format::TraceField>::field_type(),
65            )
66        });
67        encode_tokens.push(quote! {
68            <#ty as ::dial9_trace_format::TraceField>::encode(&self.#field_name, enc)?;
69        });
70
71        let doc_attrs: Vec<_> = field
72            .attrs
73            .iter()
74            .filter(|a| a.path().is_ident("doc"))
75            .collect();
76        ref_fields.push(quote! {
77            #(#doc_attrs)*
78            pub #field_name: <#ty as ::dial9_trace_format::TraceField>::Ref<'a>
79        });
80        decode_tokens.push(quote! {
81            #field_name: {
82                let val = field_defs.iter().zip(fields.iter())
83                    .find(|(f, _)| f.name() == #field_name_str)
84                    .map(|(_, v)| v);
85                match val {
86                    Some(v) => <#ty as ::dial9_trace_format::TraceField>::decode_ref(v)?,
87                    None => <#ty as ::dial9_trace_format::TraceField>::decode_missing()?,
88                }
89            }
90        });
91    }
92
93    let timestamp_impl = if let Some(ref ts_field) = timestamp_field_name {
94        quote! {
95            fn timestamp(&self) -> u64 { self.#ts_field }
96        }
97    } else {
98        panic!("TraceEvent requires a field marked with #[traceevent(timestamp)]");
99    };
100
101    let has_timestamp_impl = quote! {};
102
103    // For the Ref struct, include the timestamp field if present — populated from the decode parameter
104    let ref_timestamp_field = if let Some(ref ts_field) = timestamp_field_name {
105        let ts_docs = &timestamp_doc_attrs;
106        quote! { #(#ts_docs)* pub #ts_field: u64, }
107    } else {
108        quote! {}
109    };
110    let decode_timestamp_init = if let Some(ref ts_field) = timestamp_field_name {
111        quote! { #ts_field: timestamp_ns?, }
112    } else {
113        quote! {}
114    };
115
116    let phantom_field =
117        if fields.is_empty() || (fields.len() == 1 && timestamp_field_name.is_some()) {
118            quote! { _marker: ::std::marker::PhantomData<&'a ()>, }
119        } else {
120            quote! {}
121        };
122    let phantom_init = if fields.is_empty() || (fields.len() == 1 && timestamp_field_name.is_some())
123    {
124        quote! { _marker: ::std::marker::PhantomData, }
125    } else {
126        quote! {}
127    };
128
129    quote! {
130        #(#struct_doc_attrs)*
131        #[derive(Debug, Clone)]
132        #vis struct #ref_name<'a> {
133            #ref_timestamp_field
134            #(#ref_fields,)*
135            #phantom_field
136        }
137
138        impl ::dial9_trace_format::TraceEvent for #name {
139            type Ref<'a> = #ref_name<'a>;
140
141            fn event_name() -> &'static str { #name_str }
142            fn field_defs() -> Vec<::dial9_trace_format::schema::FieldDef> {
143                vec![#(#field_def_tokens),*]
144            }
145            #timestamp_impl
146            #has_timestamp_impl
147            fn encode_fields<W: ::std::io::Write>(&self, enc: &mut ::dial9_trace_format::EventEncoder<'_, W>) -> ::std::io::Result<()> {
148                #(#encode_tokens)*
149                Ok(())
150            }
151            fn decode<'a>(timestamp_ns: Option<u64>, fields: &[::dial9_trace_format::types::FieldValueRef<'a>], field_defs: &[::dial9_trace_format::schema::FieldDef]) -> Option<Self::Ref<'a>> {
152                Some(#ref_name {
153                    #decode_timestamp_init
154                    #(#decode_tokens,)*
155                    #phantom_init
156                })
157            }
158        }
159    }
160}
161
162#[proc_macro_derive(TraceEvent, attributes(traceevent))]
163pub fn derive_trace_event(input: TokenStream) -> TokenStream {
164    let input = parse_macro_input!(input as DeriveInput);
165    TokenStream::from(derive_trace_event_impl(input))
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use insta::assert_snapshot;

    /// Runs the derive on `input` and renders the expansion for snapshotting.
    ///
    /// Output that parses as a Rust file is pretty-printed with `prettyplease`. Anything
    /// else falls back to the raw token string, so the snapshot still shows what was emitted.
    fn expand_to_string(input: proc_macro2::TokenStream) -> String {
        let parsed: DeriveInput = syn::parse2(input).unwrap();
        let expanded = derive_trace_event_impl(parsed);
        syn::parse2::<syn::File>(expanded.clone())
            .map(|file| prettyplease::unparse(&file))
            .unwrap_or_else(|_| expanded.to_string())
    }

    /// Timestamp plus a single payload field.
    #[test]
    fn simple_event() {
        assert_snapshot!(expand_to_string(quote! {
            struct SimpleEvent {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                value: u32,
            }
        }));
    }

    /// Only a timestamp field: the Ref struct needs its `PhantomData` marker.
    #[test]
    fn empty_event() {
        assert_snapshot!(expand_to_string(quote! {
            struct EmptyEvent {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
            }
        }));
    }

    /// One field of each supported field type.
    #[test]
    fn all_field_types() {
        assert_snapshot!(expand_to_string(quote! {
            struct AllFieldTypes {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                a_u8: u8,
                b_u16: u16,
                c_u32: u32,
                d_u64: u64,
                e_i64: i64,
                f_f64: f64,
                g_bool: bool,
                h_string: String,
                i_bytes: Vec<u8>,
                j_interned: InternedString,
                k_frames: StackFrames,
                l_map: Vec<(String, String)>,
            }
        }));
    }

    /// Struct- and field-level `///` docs must appear on the generated Ref struct.
    #[test]
    fn doc_comments_copied_to_ref_fields() {
        assert_snapshot!(expand_to_string(quote! {
            /// Root documentation
            struct DocEvent {
                #[traceevent(timestamp)]
                /// Event timestamp in nanoseconds.
                timestamp_ns: u64,
                /// The worker thread ID.
                worker_id: u64,
                /// Number of items in the local queue.
                local_queue: u8,
            }
        }));
    }

    /// The marked timestamp field is excluded from the schema and encoder.
    #[test]
    fn timestamp_attribute() {
        assert_snapshot!(expand_to_string(quote! {
            struct PollStart {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                worker_id: u64,
                task_id: u64,
            }
        }));
    }
}