Skip to main content

trace_weft_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{FnArg, Ident, ImplItemFn, ItemFn, Pat, ReturnType, Signature, Type};
5
6/// Whether the function's declared return type is a `Result<_, _>` (by the last
7/// path segment, so `Result`, `std::result::Result`, and `anyhow::Result` all
8/// match). Used to decide whether the recorded span can fail.
9fn returns_result(sig: &Signature) -> bool {
10    let ReturnType::Type(_, ty) = &sig.output else {
11        return false;
12    };
13    matches!(&**ty, Type::Path(tp)
14        if tp.path.segments.last().is_some_and(|seg| seg.ident == "Result"))
15}
16
17/// Strip `#[trace(..)]` attributes off the arguments (so the regenerated
18/// signature stays valid) and return the idents of arguments to capture: every
19/// by-name typed argument not marked `#[trace(skip)]`. The receiver (`self`)
20/// and non-ident patterns are never captured.
21fn capture_args(sig: &mut Signature) -> Vec<Ident> {
22    let mut captured = Vec::new();
23    for arg in sig.inputs.iter_mut() {
24        let FnArg::Typed(pat_type) = arg else {
25            continue;
26        };
27        let skip = pat_type.attrs.iter().any(|a| a.path().is_ident("trace"));
28        pat_type.attrs.retain(|a| !a.path().is_ident("trace"));
29        if skip {
30            continue;
31        }
32        if let Pat::Ident(pat_ident) = &*pat_type.pat {
33            captured.push(pat_ident.ident.clone());
34        }
35    }
36    captured
37}
38
39/// Shared expansion for the instrumentation attributes. `kind` is the
40/// `TraceWeftSpanKind` variant ident to stamp on the recorded span.
41///
42/// Accepts both free functions and `impl`/trait-impl methods (including
43/// `&self` receivers). Trait *definitions* carry no body, so the target is the
44/// concrete `impl`. The function must be `async`.
45fn expand(kind: TokenStream2, item: TokenStream) -> TokenStream {
46    let item2 = TokenStream2::from(item);
47
48    // A normal method parses as an `ItemFn`; the `ImplItemFn` fallback covers
49    // `default fn` and other impl-only shapes. We keep `attrs` and any
50    // `defaultness` so doc comments, stacked attributes, and `default` survive.
51    let parsed = syn::parse2::<ItemFn>(item2.clone())
52        .map(|f| (f.attrs, f.vis, quote!(), f.sig, *f.block))
53        .or_else(|_| {
54            syn::parse2::<ImplItemFn>(item2).map(|m| {
55                let defaultness = m.defaultness;
56                (m.attrs, m.vis, quote!(#defaultness), m.sig, m.block)
57            })
58        });
59    let (attrs, vis, defaultness, mut sig, block) = match parsed {
60        Ok(parts) => parts,
61        Err(err) => return err.to_compile_error().into(),
62    };
63    let captured = capture_args(&mut sig);
64    let returns_result = returns_result(&sig);
65    let name = &sig.ident;
66
67    // Serialize captured args into `input_ref` before the body moves them.
68    // Guarded by `capture_enabled()` so a `MetadataOnly` process pays nothing
69    // beyond the (compile-time) `Serialize` bound on captured args.
70    let input_capture = if captured.is_empty() {
71        quote!()
72    } else {
73        let inserts = captured.iter().map(|id| {
74            let key = id.to_string();
75            quote! {
76                __input.insert(
77                    #key.to_string(),
78                    trace_weft::serde_json::to_value(&#id)
79                        .unwrap_or(trace_weft::serde_json::Value::Null),
80                );
81            }
82        });
83        quote! {
84            if trace_weft::capture_enabled() {
85                let mut __input = trace_weft::serde_json::Map::new();
86                #(#inserts)*
87                _span.input_ref = trace_weft::capture_json(
88                    "application/json",
89                    trace_weft::serde_json::Value::Object(__input),
90                ).await;
91            }
92        }
93    };
94
95    // Serialize the successful output into `output_ref`.
96    let output_capture = if returns_result {
97        quote! {
98            if trace_weft::capture_enabled() {
99                if let Ok(__ok) = &result {
100                    _span.output_ref = trace_weft::capture_json(
101                        "application/json",
102                        trace_weft::serde_json::to_value(__ok)
103                            .unwrap_or(trace_weft::serde_json::Value::Null),
104                    ).await;
105                }
106            }
107        }
108    } else {
109        quote! {
110            if trace_weft::capture_enabled() {
111                _span.output_ref = trace_weft::capture_json(
112                    "application/json",
113                    trace_weft::serde_json::to_value(&result)
114                        .unwrap_or(trace_weft::serde_json::Value::Null),
115                ).await;
116            }
117        }
118    };
119
120    // A `Result`-returning body sets Error status on `Err`; everything else
121    // always completes Ok. Only the `Result` arm touches `result` by reference,
122    // so a non-`Result` body never gains a spurious `Debug`/`Display` bound.
123    let status_update = if returns_result {
124        quote! {
125            match &result {
126                Ok(_) => { _span.status = trace_weft::SpanStatus::Ok; }
127                Err(__e) => {
128                    _span.status = trace_weft::SpanStatus::Error;
129                    _span.error_type = Some(std::any::type_name_of_val(__e).to_string());
130                    _span.error_message_redacted = Some(trace_weft::redact_text(&format!("{}", __e)).redacted_text);
131                }
132            }
133        }
134    } else {
135        quote! { _span.status = trace_weft::SpanStatus::Ok; }
136    };
137
138    let expanded = quote! {
139        #(#attrs)*
140        #vis #defaultness #sig {
141            let mut _span = trace_weft::SpanRecord {
142                trace_id: trace_weft::TraceId(trace_weft::uuid::Uuid::now_v7()),
143                span_id: trace_weft::SpanId(trace_weft::uuid::Uuid::now_v7()),
144                parent_span_id: None,
145                run_id: trace_weft::RunId(trace_weft::uuid::Uuid::now_v7()),
146                session_id: None,
147                user_id_hash: None,
148                project_id: None,
149                span_kind: trace_weft::TraceWeftSpanKind::#kind,
150                name: stringify!(#name).to_string(),
151                start_time: std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64,
152                end_time: None,
153                status: trace_weft::SpanStatus::InProgress,
154                status_message: None,
155                error_type: None,
156                error_message_redacted: None,
157                attributes: std::collections::HashMap::new(),
158                otel_attributes: std::collections::HashMap::new(),
159                openinference_attributes: std::collections::HashMap::new(),
160                memory_state: None,
161                input_ref: None,
162                output_ref: None,
163                prompt_template_id: None,
164                prompt_version: None,
165                model_provider: None,
166                model_name: None,
167                tool_name: None,
168                tool_schema_hash: None,
169                retrieval_query_hash: None,
170                retrieved_document_refs: vec![],
171                token_usage: None,
172                cost_estimate: None,
173                latency_ms: None,
174                retry_count: None,
175                cache_hit: None,
176                redaction_policy: trace_weft::capture_policy(),
177                schema_version: "1.0".to_string(),
178            };
179
180            if let Some(__parent) = trace_weft::current_span_context() {
181                _span.trace_id = __parent.trace_id;
182                _span.run_id = __parent.run_id;
183                _span.parent_span_id = Some(__parent.span_id);
184            }
185
186            #input_capture
187
188            let __ctx = trace_weft::SpanContext {
189                trace_id: _span.trace_id,
190                run_id: _span.run_id,
191                span_id: _span.span_id,
192            };
193            let result = trace_weft::scope_current(__ctx, async move #block).await;
194
195            _span.end_time = Some(std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64);
196            _span.latency_ms = Some(_span.end_time.unwrap() - _span.start_time);
197            #status_update
198            #output_capture
199            trace_weft::record_span(_span).await;
200
201            result
202        }
203    };
204
205    TokenStream::from(expanded)
206}
207
208#[proc_macro_attribute]
209pub fn agent(_attr: TokenStream, item: TokenStream) -> TokenStream {
210    expand(quote!(Agent), item)
211}
212
213#[proc_macro_attribute]
214pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
215    expand(quote!(Tool), item)
216}
217
218#[proc_macro_attribute]
219pub fn llm_call(_attr: TokenStream, item: TokenStream) -> TokenStream {
220    expand(quote!(LlmCall), item)
221}