use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{FnArg, Ident, ImplItemFn, ItemFn, Pat, ReturnType, Signature, Type};
fn returns_result(sig: &Signature) -> bool {
let ReturnType::Type(_, ty) = &sig.output else {
return false;
};
matches!(&**ty, Type::Path(tp)
if tp.path.segments.last().is_some_and(|seg| seg.ident == "Result"))
}
fn capture_args(sig: &mut Signature) -> Vec<Ident> {
let mut captured = Vec::new();
for arg in sig.inputs.iter_mut() {
let FnArg::Typed(pat_type) = arg else {
continue;
};
let skip = pat_type.attrs.iter().any(|a| a.path().is_ident("trace"));
pat_type.attrs.retain(|a| !a.path().is_ident("trace"));
if skip {
continue;
}
if let Pat::Ident(pat_ident) = &*pat_type.pat {
captured.push(pat_ident.ident.clone());
}
}
captured
}
fn expand(kind: TokenStream2, item: TokenStream) -> TokenStream {
let item2 = TokenStream2::from(item);
let parsed = syn::parse2::<ItemFn>(item2.clone())
.map(|f| (f.attrs, f.vis, quote!(), f.sig, *f.block))
.or_else(|_| {
syn::parse2::<ImplItemFn>(item2).map(|m| {
let defaultness = m.defaultness;
(m.attrs, m.vis, quote!(#defaultness), m.sig, m.block)
})
});
let (attrs, vis, defaultness, mut sig, block) = match parsed {
Ok(parts) => parts,
Err(err) => return err.to_compile_error().into(),
};
let captured = capture_args(&mut sig);
let returns_result = returns_result(&sig);
let name = &sig.ident;
let input_capture = if captured.is_empty() {
quote!()
} else {
let inserts = captured.iter().map(|id| {
let key = id.to_string();
quote! {
__input.insert(
#key.to_string(),
trace_weft::serde_json::to_value(&#id)
.unwrap_or(trace_weft::serde_json::Value::Null),
);
}
});
quote! {
if trace_weft::capture_enabled() {
let mut __input = trace_weft::serde_json::Map::new();
#(#inserts)*
_span.input_ref = trace_weft::capture_json(
"application/json",
trace_weft::serde_json::Value::Object(__input),
).await;
}
}
};
let output_capture = if returns_result {
quote! {
if trace_weft::capture_enabled() {
if let Ok(__ok) = &result {
_span.output_ref = trace_weft::capture_json(
"application/json",
trace_weft::serde_json::to_value(__ok)
.unwrap_or(trace_weft::serde_json::Value::Null),
).await;
}
}
}
} else {
quote! {
if trace_weft::capture_enabled() {
_span.output_ref = trace_weft::capture_json(
"application/json",
trace_weft::serde_json::to_value(&result)
.unwrap_or(trace_weft::serde_json::Value::Null),
).await;
}
}
};
let status_update = if returns_result {
quote! {
match &result {
Ok(_) => { _span.status = trace_weft::SpanStatus::Ok; }
Err(__e) => {
_span.status = trace_weft::SpanStatus::Error;
_span.error_type = Some(format!("{:?}", __e));
_span.error_message_redacted = Some(format!("{}", __e));
}
}
}
} else {
quote! { _span.status = trace_weft::SpanStatus::Ok; }
};
let expanded = quote! {
#(#attrs)*
#vis #defaultness #sig {
let mut _span = trace_weft::SpanRecord {
trace_id: trace_weft::TraceId(trace_weft::uuid::Uuid::now_v7()),
span_id: trace_weft::SpanId(trace_weft::uuid::Uuid::now_v7()),
parent_span_id: None,
run_id: trace_weft::RunId(trace_weft::uuid::Uuid::now_v7()),
session_id: None,
user_id_hash: None,
project_id: None,
span_kind: trace_weft::TraceWeftSpanKind::#kind,
name: stringify!(#name).to_string(),
start_time: std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64,
end_time: None,
status: trace_weft::SpanStatus::InProgress,
status_message: None,
error_type: None,
error_message_redacted: None,
attributes: std::collections::HashMap::new(),
otel_attributes: std::collections::HashMap::new(),
openinference_attributes: std::collections::HashMap::new(),
memory_state: None,
input_ref: None,
output_ref: None,
prompt_template_id: None,
prompt_version: None,
model_provider: None,
model_name: None,
tool_name: None,
tool_schema_hash: None,
retrieval_query_hash: None,
retrieved_document_refs: vec![],
token_usage: None,
cost_estimate: None,
latency_ms: None,
retry_count: None,
cache_hit: None,
redaction_policy: trace_weft::CapturePolicy::MetadataOnly,
schema_version: "1.0".to_string(),
};
if let Some(__parent) = trace_weft::current_span_context() {
_span.trace_id = __parent.trace_id;
_span.run_id = __parent.run_id;
_span.parent_span_id = Some(__parent.span_id);
}
#input_capture
let __ctx = trace_weft::SpanContext {
trace_id: _span.trace_id,
run_id: _span.run_id,
span_id: _span.span_id,
};
let result = trace_weft::scope_current(__ctx, async move #block).await;
_span.end_time = Some(std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64);
_span.latency_ms = Some(_span.end_time.unwrap() - _span.start_time);
#status_update
#output_capture
trace_weft::record_span(_span).await;
result
}
};
TokenStream::from(expanded)
}
/// Instruments an `async fn` (free function or `impl` method) so each call is
/// recorded as a `trace_weft::TraceWeftSpanKind::Agent` span.
///
/// Parameters marked with a `#[trace]` attribute are excluded from input
/// capture. Arguments given to `#[agent(...)]` itself are currently ignored.
#[proc_macro_attribute]
pub fn agent(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let span_kind = quote! { Agent };
    expand(span_kind, item)
}
/// Instruments an `async fn` (free function or `impl` method) so each call is
/// recorded as a `trace_weft::TraceWeftSpanKind::Tool` span.
///
/// Parameters marked with a `#[trace]` attribute are excluded from input
/// capture. Arguments given to `#[tool(...)]` itself are currently ignored.
#[proc_macro_attribute]
pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let span_kind = quote! { Tool };
    expand(span_kind, item)
}
/// Instruments an `async fn` (free function or `impl` method) so each call is
/// recorded as a `trace_weft::TraceWeftSpanKind::LlmCall` span.
///
/// Parameters marked with a `#[trace]` attribute are excluded from input
/// capture. Arguments given to `#[llm_call(...)]` itself are currently ignored.
#[proc_macro_attribute]
pub fn llm_call(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let span_kind = quote! { LlmCall };
    expand(span_kind, item)
}