Skip to main content

dial9_trace_format_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
5fn derive_trace_event_impl(input: DeriveInput) -> proc_macro2::TokenStream {
6    let name = &input.ident;
7    let vis = &input.vis;
8    let name_str = name.to_string();
9    let ref_name = format_ident!("{}Ref", name);
10
11    let fields = match &input.data {
12        Data::Struct(data) => match &data.fields {
13            Fields::Named(f) => &f.named,
14            _ => panic!("TraceEvent only supports named fields"),
15        },
16        _ => panic!("TraceEvent can only be derived for structs"),
17    };
18
19    // Find the field marked with #[traceevent(timestamp)]
20    let mut timestamp_field_name = None;
21    for field in fields.iter() {
22        for attr in &field.attrs {
23            if attr.path().is_ident("traceevent") {
24                let _ = attr.parse_nested_meta(|meta| {
25                    if meta.path.is_ident("timestamp") {
26                        timestamp_field_name = Some(field.ident.as_ref().unwrap().clone());
27                    }
28                    Ok(())
29                });
30            }
31        }
32    }
33
34    let mut field_def_tokens = Vec::new();
35    let mut encode_tokens = Vec::new();
36    let mut ref_fields = Vec::new();
37    let mut decode_tokens = Vec::new();
38    let mut decode_idx = 0usize;
39
40    for field in fields.iter() {
41        let field_name = field.ident.as_ref().unwrap();
42        let ty = &field.ty;
43
44        // Skip the timestamp field in schema/encode — it's in the event header
45        if timestamp_field_name.as_ref() == Some(field_name) {
46            continue;
47        }
48
49        let field_name_str = field_name.to_string();
50        field_def_tokens.push(quote! {
51            ::dial9_trace_format::schema::FieldDef {
52                name: #field_name_str.to_string(),
53                field_type: <#ty as ::dial9_trace_format::TraceField>::field_type(),
54            }
55        });
56        encode_tokens.push(quote! {
57            <#ty as ::dial9_trace_format::TraceField>::encode(&self.#field_name, enc)?;
58        });
59
60        ref_fields.push(quote! {
61            pub #field_name: <#ty as ::dial9_trace_format::TraceField>::Ref<'a>
62        });
63        let idx = decode_idx;
64        decode_tokens.push(quote! {
65            #field_name: <#ty as ::dial9_trace_format::TraceField>::decode_ref(fields.get(#idx)?)?
66        });
67        decode_idx += 1;
68    }
69
70    let timestamp_impl = if let Some(ref ts_field) = timestamp_field_name {
71        quote! {
72            fn timestamp(&self) -> u64 { self.#ts_field }
73        }
74    } else {
75        panic!("TraceEvent requires a field marked with #[traceevent(timestamp)]");
76    };
77
78    let has_timestamp_impl = quote! {};
79
80    // For the Ref struct, include the timestamp field if present — populated from the decode parameter
81    let ref_timestamp_field = if let Some(ref ts_field) = timestamp_field_name {
82        quote! { pub #ts_field: u64, }
83    } else {
84        quote! {}
85    };
86    let decode_timestamp_init = if let Some(ref ts_field) = timestamp_field_name {
87        quote! { #ts_field: timestamp_ns?, }
88    } else {
89        quote! {}
90    };
91
92    let phantom_field =
93        if fields.is_empty() || (fields.len() == 1 && timestamp_field_name.is_some()) {
94            quote! { _marker: ::std::marker::PhantomData<&'a ()>, }
95        } else {
96            quote! {}
97        };
98    let phantom_init = if fields.is_empty() || (fields.len() == 1 && timestamp_field_name.is_some())
99    {
100        quote! { _marker: ::std::marker::PhantomData, }
101    } else {
102        quote! {}
103    };
104
105    quote! {
106        #[derive(Debug, Clone)]
107        #vis struct #ref_name<'a> {
108            #ref_timestamp_field
109            #(#ref_fields,)*
110            #phantom_field
111        }
112
113        impl ::dial9_trace_format::TraceEvent for #name {
114            type Ref<'a> = #ref_name<'a>;
115
116            fn event_name() -> &'static str { #name_str }
117            fn field_defs() -> Vec<::dial9_trace_format::schema::FieldDef> {
118                vec![#(#field_def_tokens),*]
119            }
120            #timestamp_impl
121            #has_timestamp_impl
122            fn encode_fields<W: ::std::io::Write>(&self, enc: &mut ::dial9_trace_format::EventEncoder<'_, W>) -> ::std::io::Result<()> {
123                #(#encode_tokens)*
124                Ok(())
125            }
126            fn decode<'a>(timestamp_ns: Option<u64>, fields: &[::dial9_trace_format::types::FieldValueRef<'a>]) -> Option<Self::Ref<'a>> {
127                Some(#ref_name {
128                    #decode_timestamp_init
129                    #(#decode_tokens,)*
130                    #phantom_init
131                })
132            }
133        }
134    }
135}
136
137#[proc_macro_derive(TraceEvent, attributes(traceevent))]
138pub fn derive_trace_event(input: TokenStream) -> TokenStream {
139    let input = parse_macro_input!(input as DeriveInput);
140    TokenStream::from(derive_trace_event_impl(input))
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use insta::assert_snapshot;

    /// Runs the derive on `input` and renders the expansion as a string.
    ///
    /// When the expansion parses as a Rust file it is pretty-printed with
    /// `prettyplease`; otherwise the raw token stream text is returned.
    fn expand_to_string(input: proc_macro2::TokenStream) -> String {
        let parsed = syn::parse2::<DeriveInput>(input).unwrap();
        let expanded = derive_trace_event_impl(parsed);
        if let Ok(file) = syn::parse2::<syn::File>(expanded.clone()) {
            prettyplease::unparse(&file)
        } else {
            expanded.to_string()
        }
    }

    // NOTE: test function names and the `quote!` inputs below must stay
    // unchanged; insta derives snapshot file names from the function names.

    #[test]
    fn simple_event() {
        assert_snapshot!(expand_to_string(quote! {
            struct SimpleEvent {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                value: u32,
            }
        }));
    }

    #[test]
    fn empty_event() {
        assert_snapshot!(expand_to_string(quote! {
            struct EmptyEvent {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
            }
        }));
    }

    #[test]
    fn all_field_types() {
        assert_snapshot!(expand_to_string(quote! {
            struct AllFieldTypes {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                a_u8: u8,
                b_u16: u16,
                c_u32: u32,
                d_u64: u64,
                e_i64: i64,
                f_f64: f64,
                g_bool: bool,
                h_string: String,
                i_bytes: Vec<u8>,
                j_interned: InternedString,
                k_frames: StackFrames,
                l_map: Vec<(String, String)>,
            }
        }));
    }

    #[test]
    fn timestamp_attribute() {
        assert_snapshot!(expand_to_string(quote! {
            struct PollStart {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                worker_id: u64,
                task_id: u64,
            }
        }));
    }
}