dial9_trace_format_derive/
lib.rs1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
5fn derive_trace_event_impl(input: DeriveInput) -> proc_macro2::TokenStream {
6 let name = &input.ident;
7 let vis = &input.vis;
8 let name_str = name.to_string();
9 let ref_name = format_ident!("{}Ref", name);
10
11 let fields = match &input.data {
12 Data::Struct(data) => match &data.fields {
13 Fields::Named(f) => &f.named,
14 _ => panic!("TraceEvent only supports named fields"),
15 },
16 _ => panic!("TraceEvent can only be derived for structs"),
17 };
18
19 let mut timestamp_field_name = None;
21 for field in fields.iter() {
22 for attr in &field.attrs {
23 if attr.path().is_ident("traceevent") {
24 let _ = attr.parse_nested_meta(|meta| {
25 if meta.path.is_ident("timestamp") {
26 timestamp_field_name = Some(field.ident.as_ref().unwrap().clone());
27 }
28 Ok(())
29 });
30 }
31 }
32 }
33
34 let mut field_def_tokens = Vec::new();
35 let mut encode_tokens = Vec::new();
36 let mut ref_fields = Vec::new();
37 let mut decode_tokens = Vec::new();
38 let mut decode_idx = 0usize;
39
40 for field in fields.iter() {
41 let field_name = field.ident.as_ref().unwrap();
42 let ty = &field.ty;
43
44 if timestamp_field_name.as_ref() == Some(field_name) {
46 continue;
47 }
48
49 let field_name_str = field_name.to_string();
50 field_def_tokens.push(quote! {
51 ::dial9_trace_format::schema::FieldDef {
52 name: #field_name_str.to_string(),
53 field_type: <#ty as ::dial9_trace_format::TraceField>::field_type(),
54 }
55 });
56 encode_tokens.push(quote! {
57 <#ty as ::dial9_trace_format::TraceField>::encode(&self.#field_name, enc)?;
58 });
59
60 ref_fields.push(quote! {
61 pub #field_name: <#ty as ::dial9_trace_format::TraceField>::Ref<'a>
62 });
63 let idx = decode_idx;
64 decode_tokens.push(quote! {
65 #field_name: <#ty as ::dial9_trace_format::TraceField>::decode_ref(fields.get(#idx)?)?
66 });
67 decode_idx += 1;
68 }
69
70 let timestamp_impl = if let Some(ref ts_field) = timestamp_field_name {
71 quote! {
72 fn timestamp(&self) -> u64 { self.#ts_field }
73 }
74 } else {
75 panic!("TraceEvent requires a field marked with #[traceevent(timestamp)]");
76 };
77
78 let has_timestamp_impl = quote! {};
79
80 let ref_timestamp_field = if let Some(ref ts_field) = timestamp_field_name {
82 quote! { pub #ts_field: u64, }
83 } else {
84 quote! {}
85 };
86 let decode_timestamp_init = if let Some(ref ts_field) = timestamp_field_name {
87 quote! { #ts_field: timestamp_ns?, }
88 } else {
89 quote! {}
90 };
91
92 let phantom_field =
93 if fields.is_empty() || (fields.len() == 1 && timestamp_field_name.is_some()) {
94 quote! { _marker: ::std::marker::PhantomData<&'a ()>, }
95 } else {
96 quote! {}
97 };
98 let phantom_init = if fields.is_empty() || (fields.len() == 1 && timestamp_field_name.is_some())
99 {
100 quote! { _marker: ::std::marker::PhantomData, }
101 } else {
102 quote! {}
103 };
104
105 quote! {
106 #[derive(Debug, Clone)]
107 #vis struct #ref_name<'a> {
108 #ref_timestamp_field
109 #(#ref_fields,)*
110 #phantom_field
111 }
112
113 impl ::dial9_trace_format::TraceEvent for #name {
114 type Ref<'a> = #ref_name<'a>;
115
116 fn event_name() -> &'static str { #name_str }
117 fn field_defs() -> Vec<::dial9_trace_format::schema::FieldDef> {
118 vec![#(#field_def_tokens),*]
119 }
120 #timestamp_impl
121 #has_timestamp_impl
122 fn encode_fields<W: ::std::io::Write>(&self, enc: &mut ::dial9_trace_format::EventEncoder<'_, W>) -> ::std::io::Result<()> {
123 #(#encode_tokens)*
124 Ok(())
125 }
126 fn decode<'a>(timestamp_ns: Option<u64>, fields: &[::dial9_trace_format::types::FieldValueRef<'a>]) -> Option<Self::Ref<'a>> {
127 Some(#ref_name {
128 #decode_timestamp_init
129 #(#decode_tokens,)*
130 #phantom_init
131 })
132 }
133 }
134 }
135}
136
137#[proc_macro_derive(TraceEvent, attributes(traceevent))]
138pub fn derive_trace_event(input: TokenStream) -> TokenStream {
139 let input = parse_macro_input!(input as DeriveInput);
140 TokenStream::from(derive_trace_event_impl(input))
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use insta::assert_snapshot;

    /// Runs the derive on `input` and renders the expansion for snapshotting.
    ///
    /// If the expansion parses as a Rust file, it is pretty-printed with
    /// `prettyplease`. Otherwise the raw token string is returned, so the
    /// snapshot still shows what was produced.
    fn expand_to_string(input: proc_macro2::TokenStream) -> String {
        let ast: DeriveInput = syn::parse2(input).unwrap();
        let expanded = derive_trace_event_impl(ast);
        syn::parse2::<syn::File>(expanded.clone())
            .map(|file| prettyplease::unparse(&file))
            .unwrap_or_else(|_| expanded.to_string())
    }

    // insta names unnamed snapshots after the enclosing test function, so the
    // test names and snapshot expressions below are kept stable.

    /// A timestamp plus a single payload field.
    #[test]
    fn simple_event() {
        assert_snapshot!(expand_to_string(quote! {
            struct SimpleEvent {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                value: u32,
            }
        }));
    }

    /// Only a timestamp: exercises the `PhantomData` marker path.
    #[test]
    fn empty_event() {
        assert_snapshot!(expand_to_string(quote! {
            struct EmptyEvent {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
            }
        }));
    }

    /// One payload field of every supported `TraceField` type.
    #[test]
    fn all_field_types() {
        assert_snapshot!(expand_to_string(quote! {
            struct AllFieldTypes {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                a_u8: u8,
                b_u16: u16,
                c_u32: u32,
                d_u64: u64,
                e_i64: i64,
                f_f64: f64,
                g_bool: bool,
                h_string: String,
                i_bytes: Vec<u8>,
                j_interned: InternedString,
                k_frames: StackFrames,
                l_map: Vec<(String, String)>,
            }
        }));
    }

    /// The timestamp field is excluded from the payload even when other
    /// `u64` fields are present.
    #[test]
    fn timestamp_attribute() {
        assert_snapshot!(expand_to_string(quote! {
            struct PollStart {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                worker_id: u64,
                task_id: u64,
            }
        }));
    }
}