dial9_trace_format_derive/
lib.rs1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
5fn derive_trace_event_impl(input: DeriveInput) -> proc_macro2::TokenStream {
6 let name = &input.ident;
7 let vis = &input.vis;
8 let name_str = name.to_string();
9 let ref_name = format_ident!("{}Ref", name);
10 let struct_doc_attrs: Vec<_> = input
11 .attrs
12 .iter()
13 .filter(|a| a.path().is_ident("doc"))
14 .collect();
15
16 let fields = match &input.data {
17 Data::Struct(data) => match &data.fields {
18 Fields::Named(f) => &f.named,
19 _ => panic!("TraceEvent only supports named fields"),
20 },
21 _ => panic!("TraceEvent can only be derived for structs"),
22 };
23
24 let mut timestamp_field_name = None;
26 let mut timestamp_doc_attrs = Vec::new();
27 for field in fields.iter() {
28 for attr in &field.attrs {
29 if attr.path().is_ident("traceevent") {
30 let _ = attr.parse_nested_meta(|meta| {
31 if meta.path.is_ident("timestamp") {
32 timestamp_field_name = Some(field.ident.as_ref().unwrap().clone());
33 timestamp_doc_attrs = field
34 .attrs
35 .iter()
36 .filter(|a| a.path().is_ident("doc"))
37 .cloned()
38 .collect();
39 }
40 Ok(())
41 });
42 }
43 }
44 }
45
46 let mut field_def_tokens = Vec::new();
47 let mut encode_tokens = Vec::new();
48 let mut ref_fields = Vec::new();
49 let mut decode_tokens = Vec::new();
50
51 for field in fields.iter() {
52 let field_name = field.ident.as_ref().unwrap();
53 let ty = &field.ty;
54
55 if timestamp_field_name.as_ref() == Some(field_name) {
57 continue;
58 }
59
60 let field_name_str = field_name.to_string();
61 field_def_tokens.push(quote! {
62 ::dial9_trace_format::schema::FieldDef::new(
63 #field_name_str,
64 <#ty as ::dial9_trace_format::TraceField>::field_type(),
65 )
66 });
67 encode_tokens.push(quote! {
68 <#ty as ::dial9_trace_format::TraceField>::encode(&self.#field_name, enc)?;
69 });
70
71 let doc_attrs: Vec<_> = field
72 .attrs
73 .iter()
74 .filter(|a| a.path().is_ident("doc"))
75 .collect();
76 ref_fields.push(quote! {
77 #(#doc_attrs)*
78 pub #field_name: <#ty as ::dial9_trace_format::TraceField>::Ref<'a>
79 });
80 decode_tokens.push(quote! {
81 #field_name: {
82 let val = field_defs.iter().zip(fields.iter())
83 .find(|(f, _)| f.name() == #field_name_str)
84 .map(|(_, v)| v);
85 match val {
86 Some(v) => <#ty as ::dial9_trace_format::TraceField>::decode_ref(v)?,
87 None => <#ty as ::dial9_trace_format::TraceField>::decode_missing()?,
88 }
89 }
90 });
91 }
92
93 let timestamp_impl = if let Some(ref ts_field) = timestamp_field_name {
94 quote! {
95 fn timestamp(&self) -> u64 { self.#ts_field }
96 }
97 } else {
98 panic!("TraceEvent requires a field marked with #[traceevent(timestamp)]");
99 };
100
101 let has_timestamp_impl = quote! {};
102
103 let ref_timestamp_field = if let Some(ref ts_field) = timestamp_field_name {
105 let ts_docs = ×tamp_doc_attrs;
106 quote! { #(#ts_docs)* pub #ts_field: u64, }
107 } else {
108 quote! {}
109 };
110 let decode_timestamp_init = if let Some(ref ts_field) = timestamp_field_name {
111 quote! { #ts_field: timestamp_ns?, }
112 } else {
113 quote! {}
114 };
115
116 let phantom_field =
117 if fields.is_empty() || (fields.len() == 1 && timestamp_field_name.is_some()) {
118 quote! { _marker: ::std::marker::PhantomData<&'a ()>, }
119 } else {
120 quote! {}
121 };
122 let phantom_init = if fields.is_empty() || (fields.len() == 1 && timestamp_field_name.is_some())
123 {
124 quote! { _marker: ::std::marker::PhantomData, }
125 } else {
126 quote! {}
127 };
128
129 quote! {
130 #(#struct_doc_attrs)*
131 #[derive(Debug, Clone)]
132 #vis struct #ref_name<'a> {
133 #ref_timestamp_field
134 #(#ref_fields,)*
135 #phantom_field
136 }
137
138 impl ::dial9_trace_format::TraceEvent for #name {
139 type Ref<'a> = #ref_name<'a>;
140
141 fn event_name() -> &'static str { #name_str }
142 fn field_defs() -> Vec<::dial9_trace_format::schema::FieldDef> {
143 vec![#(#field_def_tokens),*]
144 }
145 #timestamp_impl
146 #has_timestamp_impl
147 fn encode_fields<W: ::std::io::Write>(&self, enc: &mut ::dial9_trace_format::EventEncoder<'_, W>) -> ::std::io::Result<()> {
148 #(#encode_tokens)*
149 Ok(())
150 }
151 fn decode<'a>(timestamp_ns: Option<u64>, fields: &[::dial9_trace_format::types::FieldValueRef<'a>], field_defs: &[::dial9_trace_format::schema::FieldDef]) -> Option<Self::Ref<'a>> {
152 Some(#ref_name {
153 #decode_timestamp_init
154 #(#decode_tokens,)*
155 #phantom_init
156 })
157 }
158 }
159 }
160}
161
162#[proc_macro_derive(TraceEvent, attributes(traceevent))]
163pub fn derive_trace_event(input: TokenStream) -> TokenStream {
164 let input = parse_macro_input!(input as DeriveInput);
165 TokenStream::from(derive_trace_event_impl(input))
166}
167
#[cfg(test)]
mod tests {
    //! Snapshot tests for the `TraceEvent` expansion.
    //!
    //! insta names each unnamed snapshot after the enclosing test function. Renaming a test or
    //! editing its fixture therefore changes which stored snapshot it is compared against.

    use super::*;
    use insta::assert_snapshot;

    /// Runs the derive on `input` and renders the result for snapshotting.
    ///
    /// Output that parses as a Rust file is pretty-printed with `prettyplease`. Anything else
    /// falls back to the raw token string, so the snapshot still shows what was produced.
    fn expand_to_string(input: proc_macro2::TokenStream) -> String {
        let derive_input = syn::parse2::<DeriveInput>(input).unwrap();
        let expanded = derive_trace_event_impl(derive_input);
        syn::parse2::<syn::File>(expanded.clone())
            .map(|file| prettyplease::unparse(&file))
            .unwrap_or_else(|_| expanded.to_string())
    }

    #[test]
    fn simple_event() {
        assert_snapshot!(expand_to_string(quote! {
            struct SimpleEvent {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                value: u32,
            }
        }));
    }

    #[test]
    fn empty_event() {
        assert_snapshot!(expand_to_string(quote! {
            struct EmptyEvent {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
            }
        }));
    }

    #[test]
    fn all_field_types() {
        assert_snapshot!(expand_to_string(quote! {
            struct AllFieldTypes {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                a_u8: u8,
                b_u16: u16,
                c_u32: u32,
                d_u64: u64,
                e_i64: i64,
                f_f64: f64,
                g_bool: bool,
                h_string: String,
                i_bytes: Vec<u8>,
                j_interned: InternedString,
                k_frames: StackFrames,
                l_map: Vec<(String, String)>,
            }
        }));
    }

    // NOTE(review): despite its name, this fixture carries no `///` comments as written, so it
    // does not exercise doc-copying. Confirm against the stored snapshot whether doc lines were
    // lost here.
    #[test]
    fn doc_comments_copied_to_ref_fields() {
        assert_snapshot!(expand_to_string(quote! {
            struct DocEvent {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                worker_id: u64,
                local_queue: u8,
            }
        }));
    }

    #[test]
    fn timestamp_attribute() {
        assert_snapshot!(expand_to_string(quote! {
            struct PollStart {
                #[traceevent(timestamp)]
                timestamp_ns: u64,
                worker_id: u64,
                task_id: u64,
            }
        }));
    }
}
252}