otel_instrument/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::quote;
5use std::collections::HashSet;
6use syn::{
7    Expr, Ident, ItemFn, Token,
8    ext::IdentExt,
9    parse::{Parse, ParseStream},
10    parse_macro_input,
11    spanned::Spanned,
12};
13
14#[derive(Default)]
15struct InstrumentArgs {
16    skip: HashSet<String>,
17    skip_all: bool,
18    fields: Vec<(String, Expr)>,
19    ret: bool,
20    err: Option<Expr>,
21    name: Option<String>,
22    parent: Option<Expr>,
23}
24
25impl Parse for InstrumentArgs {
26    fn parse(input: ParseStream) -> syn::Result<Self> {
27        let mut args = InstrumentArgs::default();
28
29        while !input.is_empty() {
30            let ident: Ident = input.parse()?;
31            match ident.to_string().as_str() {
32                "skip_all" => {
33                    args.skip_all = true;
34                }
35                "skip" => {
36                    let content;
37                    syn::parenthesized!(content in input);
38                    let names = content.parse_terminated(Ident::parse_any, Token![,])?;
39                    args.skip = names.into_iter().map(|i| i.to_string()).collect();
40                }
41                "fields" => {
42                    let content;
43                    syn::parenthesized!(content in input);
44                    while !content.is_empty() {
45                        let field_name: Ident = content.parse()?;
46                        let field_expr = if content.peek(Token![=]) {
47                            content.parse::<Token![=]>()?;
48                            content.parse::<Expr>()?
49                        } else {
50                            // Fallback to name = name shorthand
51                            syn::parse_quote!(#field_name)
52                        };
53                        args.fields.push((field_name.to_string(), field_expr));
54                        if !content.is_empty() {
55                            content.parse::<Token![,]>()?;
56                        }
57                    }
58                }
59                "ret" => {
60                    args.ret = true;
61                }
62                "err" => {
63                    if input.peek(Token![=]) {
64                        input.parse::<Token![=]>()?;
65                        let err_expr: Expr = input.parse()?;
66                        args.err = Some(err_expr);
67                    } else {
68                        args.err = Some(syn::parse_quote!(e));
69                    }
70                }
71                "name" => {
72                    input.parse::<Token![=]>()?;
73                    let name_str: syn::LitStr = input.parse()?;
74                    args.name = Some(name_str.value());
75                }
76                "parent" => {
77                    input.parse::<Token![=]>()?;
78                    let parent_expr: Expr = input.parse()?;
79                    args.parent = Some(parent_expr);
80                }
81                _ => {
82                    return Err(syn::Error::new_spanned(ident, "Unknown attribute"));
83                }
84            }
85
86            if !input.is_empty() {
87                input.parse::<Token![,]>()?;
88            }
89        }
90
91        Ok(args)
92    }
93}
94
95/// Define the global tracer name for instrumentation.
96/// If not called, defaults to "otel-instrument".
97///
98/// # Example
99/// ```rust
100/// use otel_instrument::tracer_name;
101///
102/// tracer_name!("my-service");
103/// ```
104#[proc_macro]
105pub fn tracer_name(input: TokenStream) -> TokenStream {
106    let tracer_name = if input.is_empty() {
107        "otel-instrument".to_string()
108    } else {
109        let literal: syn::LitStr = parse_macro_input!(input as syn::LitStr);
110        literal.value()
111    };
112
113    let expanded = quote! {
114        pub(crate) const _OTEL_TRACER_NAME: &str = #tracer_name;
115    };
116
117    expanded.into()
118}
119
120/// See crate level documentation for usage.
121#[proc_macro_attribute]
122pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
123    let input_fn = parse_macro_input!(input as ItemFn);
124    let args = if args.is_empty() {
125        InstrumentArgs::default()
126    } else {
127        parse_macro_input!(args as InstrumentArgs)
128    };
129
130    match instrument_impl(args, input_fn) {
131        Ok(tokens) => tokens.into(),
132        Err(err) => err.to_compile_error().into(),
133    }
134}
135
136fn extract_ident_from_pattern(pat: &syn::Pat) -> Option<Ident> {
137    match pat {
138        syn::Pat::Ident(ident) => Some(ident.ident.clone()),
139        syn::Pat::TupleStruct(tuple_struct) => {
140            // Handle patterns like State(state): State<AppState>
141            // Extract the first inner pattern if it's an identifier
142            if let Some(first_pattern) = tuple_struct.elems.first() {
143                extract_ident_from_pattern(first_pattern)
144            } else {
145                None
146            }
147        }
148        syn::Pat::Tuple(tuple) => {
149            // Handle tuple destructuring like (a, b): (i32, i32)
150            // For now, we'll take the first element
151            if let Some(first_pattern) = tuple.elems.first() {
152                extract_ident_from_pattern(first_pattern)
153            } else {
154                None
155            }
156        }
157        syn::Pat::Struct(struct_pat) => {
158            // Handle struct destructuring like User { name, age }: User
159            // For now, we'll take the first field
160            if let Some(first_field) = struct_pat.fields.first() {
161                if let syn::Member::Named(ident) = &first_field.member {
162                    Some(ident.clone())
163                } else {
164                    None
165                }
166            } else {
167                None
168            }
169        }
170        _ => None,
171    }
172}
173
174fn instrument_impl(
175    args: InstrumentArgs,
176    mut input_fn: ItemFn,
177) -> Result<proc_macro2::TokenStream, syn::Error> {
178    let fn_name = &input_fn.sig.ident;
179    let fn_name_str = fn_name.to_string();
180    let span_name = args.name.unwrap_or(fn_name_str.clone());
181
182    // Check if function is async
183    let is_async = input_fn.sig.asyncness.is_some();
184
185    // Extract function parameters for span attributes and function calls
186    let mut self_ident = None;
187    let mut param_names = Vec::new();
188    let mut param_patterns = Vec::new();
189    
190    for arg in &input_fn.sig.inputs {
191        match arg {
192            syn::FnArg::Typed(pat_type) => {
193                param_patterns.push(pat_type.pat.clone());
194                if let Some(ident) = extract_ident_from_pattern(pat_type.pat.as_ref()) {
195                    param_names.push(ident);
196                }
197            }
198            syn::FnArg::Receiver(recv) => {
199                self_ident = Some(Ident::new("self", recv.span()));
200            }
201        }
202    }
203
204    // Generate span attributes from parameters (respecting skip and skip_all)
205    let span_attrs: Vec<_> = if args.skip_all {
206        Vec::new()
207    } else {
208        param_names.iter()
209            .filter(|name| !args.skip.contains(&name.to_string()))
210            .map(|name| {
211                let name_str = name.to_string();
212                quote! {
213                    span.set_attribute(::opentelemetry::KeyValue::new(#name_str, format!("{:?}", #name)));
214                }
215            })
216            .collect()
217    };
218
219    // Generate custom field attributes
220    let field_attrs = args.fields.iter().map(|(name, expr)| {
221        quote! {
222            span.set_attribute(::opentelemetry::KeyValue::new(#name, format!("{:?}", #expr)));
223        }
224    });
225
226    // Generate return value capture if requested
227    let ret_capture = args
228        .ret
229        .then_some(quote! {
230            if let Ok(ref ret_val) = result {
231                ::opentelemetry::trace::get_active_span(|span| {
232                    span.set_attribute(
233                        ::opentelemetry::KeyValue::new("return", format!("{:?}", ret_val))
234                    );
235                });
236            }
237        })
238        .unwrap_or_default();
239
240    // Generate error capture if requested (enhanced version)
241    let err_capture = if let Some(err_expr) = &args.err {
242        quote! {
243            match &result {
244                Ok(_) => {
245                    ::opentelemetry::trace::get_active_span(|span| {
246                        span.set_status(::opentelemetry::trace::Status::Ok);
247                    });
248                }
249                Err(e) => {
250                    ::opentelemetry::trace::get_active_span(|span| {
251                        span.set_attribute(::opentelemetry::KeyValue::new("error", format!("{:?}", e)));
252                        span.set_status(::opentelemetry::trace::Status::error(format!("{:?}", e)));
253                        let err = #err_expr;
254                        span.record_error(err);
255                    });
256                }
257            }
258        }
259    } else {
260        quote! {
261            if let Ok(_) = result {
262               ::opentelemetry::trace::get_active_span(|span| {
263                   span.set_status(::opentelemetry::trace::Status::Ok);
264               });
265            }
266        }
267    };
268
269    // Generate span creation code based on whether parent is specified
270    let span_creation = if let Some(parent_expr) = &args.parent {
271        quote! {
272            use ::opentelemetry::Context;
273            // The parent_value should implement Into<Context> or be a Context
274            // This allows for flexibility in what users can pass:
275            // - Context directly
276            // - Span (which can be converted to Context)
277            // - SpanContext (which can be used to create Context)
278            let parent_ctx = #parent_expr.clone().into();
279            let mut span = tracer.start_with_context(#span_name, &parent_ctx);
280        }
281    } else {
282        quote! { let mut span = tracer.start(#span_name); }
283    };
284
285    let mut original_fn = input_fn.clone();
286    original_fn.sig.ident = syn::Ident::new(
287        &(input_fn.sig.ident.to_string() + "original"),
288        input_fn.sig.span(),
289    );
290    let original_ident = original_fn.sig.ident.clone();
291    let call = if let Some(ident) = self_ident {
292        quote! {
293            #ident.#original_ident(#(#param_patterns),*)
294        }
295    } else {
296        quote! {
297            #original_ident(#(#param_patterns),*)
298        }
299    };
300
301    // Generate the result execution block based on whether function is async or sync
302    let result_block = if is_async {
303        quote! {
304            use ::opentelemetry::{context::FutureExt, trace::TraceContextExt};
305            let result = async move {
306                let result = #call.await;
307                #ret_capture
308                #err_capture
309                result
310            }
311            .with_context(::opentelemetry::Context::current_with_span(span))
312            .await;
313        }
314    } else {
315        quote! {
316            let _guard = ::opentelemetry::trace::mark_span_as_active(span);
317            let result = #call;
318            #ret_capture
319            #err_capture
320        }
321    };
322
323    // Create the instrumented function body
324    let instrumented_body = quote! {
325        {
326            use ::opentelemetry::{trace::{Tracer, Span}, global};
327
328            let tracer = global::tracer(_OTEL_TRACER_NAME);
329            #span_creation
330            #(#span_attrs)*
331            #(#field_attrs)*
332            #result_block
333            result
334        }
335    };
336
337    // Replace the function body
338    input_fn.block = syn::parse2(instrumented_body)?;
339
340    Ok(quote! {
341        #[doc(hidden)]
342        #original_fn
343        #input_fn
344    })
345}