Skip to main content

langgraph_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Data, Fields, Attribute, Lit, ItemFn, ReturnType};
4
5/// Derive macro for StateGraph state types.
6///
7/// Annotate fields with `#[channel(reducer = "fn_name")]` to specify
8/// a reducer function for that channel. Fields without the attribute
9/// use LastValue (default).
10///
11/// **Robustness Check**: This macro enforces that every field must have 
12/// `#[serde(default)]` (or be an `Option` which handles missing keys gracefully).
13/// This prevents silent state loss during graph resume operations.
14#[proc_macro_derive(StateGraph, attributes(channel))]
15pub fn derive_state_graph(input: TokenStream) -> TokenStream {
16    let input = parse_macro_input!(input as DeriveInput);
17    impl_state_graph(&input)
18}
19
20
21/// This attribute macro:
22/// 1. Automatically adds `#[derive(serde::Serialize, serde::Deserialize, Clone, Default, StateGraph)]`.
23/// 2. Automatically injects `#[serde(default)]` on every field to ensure robustness.
24/// 
25/// Usage:
26/// ```rust,ignore
27/// #[langgraph_state]
28/// struct MyState {
29///     #[channel(messages)]
30///     messages: Vec<Message>,
31///     other_field: String,
32/// }
33/// ```
34#[proc_macro_attribute]
35pub fn langgraph_state(_attr: TokenStream, item: TokenStream) -> TokenStream {
36    let mut input = parse_macro_input!(item as syn::ItemStruct);
37    
38    // 1. Add the "big bunch" of derives
39    input.attrs.push(syn::parse_quote! {
40        #[derive(serde::Serialize, serde::Deserialize, Clone, Default, langgraph::StateGraph)]
41    });
42
43    // 2. Walk fields and ensure #[serde(default)] exists
44    if let syn::Fields::Named(fields) = &mut input.fields {
45        for field in &mut fields.named {
46            let mut has_default = false;
47            for attr in &field.attrs {
48                if attr.path().is_ident("serde") {
49                    let _ = attr.parse_nested_meta(|meta| {
50                        if meta.path.is_ident("default") {
51                            has_default = true;
52                        }
53                        Ok(())
54                    });
55                }
56            }
57
58            if !has_default {
59                field.attrs.push(syn::parse_quote! {
60                    #[serde(default)]
61                });
62            }
63        }
64    }
65
66    let expanded = quote! {
67        #input
68    };
69
70    TokenStream::from(expanded)
71}
72
73fn impl_state_graph(input: &DeriveInput) -> TokenStream {
74    let name = &input.ident;
75    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
76
77    let fields = match &input.data {
78        Data::Struct(data) => match &data.fields {
79            Fields::Named(fields) => &fields.named,
80            _ => panic!("StateGraph can only be derived for structs with named fields"),
81        },
82        _ => panic!("StateGraph can only be derived for structs"),
83    };
84
85    // ── ROBUSTNESS CHECK ─────────────────────────────────────────────────────
86    // For every field, ensure it has #[serde(default)]
87    for field in fields {
88        let field_name = field.ident.as_ref().unwrap();
89        
90        let mut has_serde_default = false;
91        for attr in &field.attrs {
92            if attr.path().is_ident("serde") {
93                let _ = attr.parse_nested_meta(|meta| {
94                    if meta.path.is_ident("default") {
95                        has_serde_default = true;
96                    }
97                    Ok(())
98                });
99            }
100        }
101
102        if !has_serde_default {
103            let error_msg = format!(
104                "Field `{}` in `{}` is missing `#[serde(default)]`. \
105                 LangGraph states require this attribute on all fields to prevent \
106                 state loss during resume operations. Please add `#[serde(default)]` \
107                 to this field.",
108                field_name, name
109            );
110            return syn::Error::new_spanned(field, error_msg).to_compile_error().into();
111        }
112    }
113
114    let channel_registrations: Vec<proc_macro2::TokenStream> = fields
115        .iter()
116        .map(|field| {
117            let field_name = field.ident.as_ref().unwrap();
118            let field_name_str = field_name.to_string();
119
120            // Check for channel attribute
121            let reducer = get_channel_reducer(&field.attrs);
122
123            match reducer {
124                Some(ReducerSpec::Named(fn_name)) => {
125                    let fn_ident = syn::Ident::new(&fn_name, proc_macro2::Span::call_site());
126                    quote! {
127                        channels.insert(
128                            #field_name_str.to_string(),
129                            Box::new(langgraph::channels::BinaryOperatorAggregate::new(
130                                #field_name_str,
131                                #fn_ident,
132                            )) as Box<dyn langgraph::channels::Channel>
133                        );
134                    }
135                }
136                Some(ReducerSpec::Messages) => {
137                    quote! {
138                        channels.insert(
139                            #field_name_str.to_string(),
140                            Box::new(langgraph::channels::BinaryOperatorAggregate::new(
141                                #field_name_str,
142                                langgraph::prebuilt::add_messages_ref,
143                            )) as Box<dyn langgraph::channels::Channel>
144                        );
145                    }
146                }
147                None => {
148                    quote! {
149                        channels.insert(
150                            #field_name_str.to_string(),
151                            Box::new(langgraph::channels::LastValue::new(#field_name_str)) as Box<dyn langgraph::channels::Channel>
152                        );
153                    }
154                }
155            }
156        })
157        .collect();
158
159    let expanded = quote! {
160        impl #impl_generics #name #ty_generics #where_clause {
161            pub fn create_channels() -> std::collections::HashMap<String, Box<dyn langgraph::channels::Channel>> {
162                let mut channels = std::collections::HashMap::new();
163                #(#channel_registrations)*
164                channels
165            }
166        }
167    };
168
169    TokenStream::from(expanded)
170}
171
172/// The type of reducer for a channel.
173enum ReducerSpec {
174    /// A named reducer function: `#[channel(reducer = "fn_name")]`
175    Named(String),
176    /// The built-in messages reducer: `#[channel(messages)]`
177    Messages,
178}
179
180fn get_channel_reducer(attrs: &[Attribute]) -> Option<ReducerSpec> {
181    for attr in attrs {
182        if !attr.path().is_ident("channel") {
183            continue;
184        }
185
186        let mut result = None;
187
188        attr.parse_nested_meta(|meta| {
189            if meta.path.is_ident("reducer") {
190                let value = meta.value()?;
191                let lit: Lit = value.parse()?;
192                if let Lit::Str(s) = lit {
193                    result = Some(ReducerSpec::Named(s.value()));
194                }
195                Ok(())
196            } else if meta.path.is_ident("messages") {
197                result = Some(ReducerSpec::Messages);
198                Ok(())
199            } else {
200                Err(meta.error("unknown channel attribute"))
201            }
202        })
203        .ok();
204
205        return result;
206    }
207    None
208}
209
210// ============================================================================
211// #[tool] attribute macro
212// ============================================================================
213
214/// Attribute macro to define a tool from a function.
215#[proc_macro_attribute]
216pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
217    let func = parse_macro_input!(item as ItemFn);
218    let args = parse_macro_input!(attr as ToolMacroArgs);
219    impl_tool_macro(&args.name, &args.description, &func)
220}
221
222struct ToolMacroArgs {
223    name: Option<Lit>,
224    description: Option<Lit>,
225}
226
227impl syn::parse::Parse for ToolMacroArgs {
228    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
229        if input.is_empty() {
230            return Ok(Self { name: None, description: None });
231        }
232        let name: Lit = input.parse()?;
233        let description = if input.peek(syn::Token![,]) {
234            input.parse::<syn::Token![,]>()?;
235            Some(input.parse()?)
236        } else {
237            None
238        };
239        Ok(Self { name: Some(name), description })
240    }
241}
242
243fn impl_tool_macro(name_lit: &Option<Lit>, desc_lit: &Option<Lit>, func: &ItemFn) -> TokenStream {
244    let fn_name = &func.sig.ident;
245    let fn_name_str = fn_name.to_string();
246
247    let tool_name = if let Some(Lit::Str(s)) = name_lit {
248        s.value()
249    } else {
250        fn_name_str.clone()
251    };
252
253    // Extract parameter descriptions from @param lines in doc comments.
254    let param_descs = extract_param_descs(func);
255
256    let description = if let Some(desc) = desc_lit {
257        match desc {
258            Lit::Str(s) => s.value(),
259            _ => panic!("description must be a string literal"),
260        }
261    } else {
262        let mut extracted_desc = String::new();
263        for attr in &func.attrs {
264            if attr.path().is_ident("doc") {
265                if let syn::Meta::NameValue(nv) = &attr.meta {
266                    if let syn::Expr::Lit(expr_lit) = &nv.value {
267                        if let syn::Lit::Str(lit_str) = &expr_lit.lit {
268                            let doc_str = lit_str.value();
269                            let trimmed = doc_str.trim();
270                            // Skip @param lines — they are for schema, not description.
271                            if trimmed.starts_with("@param ") {
272                                continue;
273                            }
274                            if !extracted_desc.is_empty() {
275                                extracted_desc.push_str(" ");
276                            }
277                            extracted_desc.push_str(trimmed);
278                        }
279                    }
280                }
281            }
282        }
283        extracted_desc
284    };
285
286    let struct_name_str = to_camel_case(&fn_name_str);
287    let struct_name = syn::Ident::new(&struct_name_str, fn_name.span());
288
289    let params: Vec<_> = func.sig.inputs.iter().filter_map(|arg| {
290        if let syn::FnArg::Typed(pat_type) = arg {
291            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
292                return Some((pat_ident.ident.clone(), (*pat_type.ty).clone()));
293            }
294        }
295        None
296    }).collect();
297
298    let properties: Vec<proc_macro2::TokenStream> = params.iter().map(|(name, ty)| {
299        let name_str = name.to_string();
300        let actual_ty = if is_option(ty) { extract_type_from_option(ty) } else { ty };
301        let json_type = rust_type_to_json_type(actual_ty);
302        if let Some(d) = param_descs.get(&name_str) {
303            quote! {
304                (#name_str, serde_json::json!({"type": #json_type, "description": #d}))
305            }
306        } else {
307            quote! {
308                (#name_str, serde_json::json!({"type": #json_type}))
309            }
310        }
311    }).collect();
312
313    let required: Vec<String> = params.iter()
314        .filter(|(_, ty)| !is_option(ty))
315        .map(|(name, _)| name.to_string())
316        .collect();
317
318    let extractions: Vec<proc_macro2::TokenStream> = params.iter().map(|(name, ty)| {
319        let name_str = name.to_string();
320        let err_invalid = format!("invalid parameter '{}': {{}}", name_str);
321        
322        if is_option(ty) {
323            quote! {
324                let #name: #ty = match args.get(#name_str) {
325                    Some(v) => serde_json::from_value(v.clone())
326                        .map_err(|e| langgraph::prebuilt::ToolError::InvalidArgs(format!(#err_invalid, e)))?,
327                    None => None,
328                };
329            }
330        } else {
331            let err_missing = format!("missing required parameter '{}'", name_str);
332            quote! {
333                let #name: #ty = serde_json::from_value(
334                    args.get(#name_str)
335                        .cloned()
336                        .ok_or_else(|| langgraph::prebuilt::ToolError::InvalidArgs(#err_missing.to_string()))?
337                ).map_err(|e| langgraph::prebuilt::ToolError::InvalidArgs(
338                    format!(#err_invalid, e)
339                ))?;
340            }
341        }
342    }).collect();
343
344    let param_names: Vec<_> = params.iter().map(|(name, _)| name.clone()).collect();
345
346    let is_result_return = match &func.sig.output {
347        ReturnType::Type(_, ty) => {
348            if let syn::Type::Path(type_path) = ty.as_ref() {
349                type_path.path.segments.last()
350                    .map(|s| s.ident == "Result")
351                    .unwrap_or(false)
352            } else {
353                false
354            }
355        }
356        _ => false,
357    };
358
359    let is_async = func.sig.asyncness.is_some();
360
361    let await_tokens = if is_async {
362        quote! { .await }
363    } else {
364        quote! {}
365    };
366
367    let invoke_body = if is_result_return {
368        quote! {
369            #(#extractions)*
370            let result = #fn_name(#(#param_names),*)#await_tokens;
371            result
372                .map_err(|e| {
373                    let tool_err: langgraph::prebuilt::ToolError = e.into();
374                    tool_err
375                })
376                .and_then(|r| serde_json::to_value(r).map_err(|e| langgraph::prebuilt::ToolError::Execution(
377                    format!("failed to serialize result: {}", e)
378                )))
379        }
380    } else {
381        quote! {
382            #(#extractions)*
383            let result = #fn_name(#(#param_names),*)#await_tokens;
384            serde_json::to_value(result).map_err(|e| langgraph::prebuilt::ToolError::Execution(
385                format!("failed to serialize result: {}", e)
386            ))
387        }
388    };
389
390    let trait_methods = if is_async {
391        quote! {
392            fn invoke(
393                &self,
394                _args: &serde_json::Value,
395                _config: &langgraph::checkpoint::config::RunnableConfig,
396            ) -> Result<serde_json::Value, langgraph::prebuilt::ToolError> {
397                Err(langgraph::prebuilt::ToolError::Execution(
398                    "This tool is asynchronous and must be invoked with ainvoke".to_string()
399                ))
400            }
401
402            async fn ainvoke(
403                &self,
404                args: &serde_json::Value,
405                _config: &langgraph::checkpoint::config::RunnableConfig,
406            ) -> Result<serde_json::Value, langgraph::prebuilt::ToolError> {
407                #invoke_body
408            }
409        }
410    } else {
411        quote! {
412            fn invoke(
413                &self,
414                args: &serde_json::Value,
415                _config: &langgraph::checkpoint::config::RunnableConfig,
416            ) -> Result<serde_json::Value, langgraph::prebuilt::ToolError> {
417                #invoke_body
418            }
419        }
420    };
421
422    let expanded = quote! {
423        #func
424        pub struct #struct_name;
425        impl #struct_name {
426            pub fn new() -> Self { Self }
427        }
428        impl Default for #struct_name {
429            fn default() -> Self { Self }
430        }
431        #[async_trait::async_trait]
432        impl langgraph::prebuilt::BaseTool for #struct_name {
433            fn name(&self) -> &str { #tool_name }
434            fn description(&self) -> &str { #description }
435            fn parameters(&self) -> Option<&serde_json::Value> {
436                use std::sync::OnceLock;
437                static SCHEMA: OnceLock<serde_json::Value> = OnceLock::new();
438                Some(SCHEMA.get_or_init(|| {
439                    let mut properties = serde_json::Map::new();
440                    #(
441                        {
442                            let (k, v) = #properties;
443                            properties.insert(k.to_string(), v);
444                        }
445                    )*
446                    serde_json::json!({
447                        "type": "object",
448                        "properties": properties,
449                        "required": [#(#required),*]
450                    })
451                }))
452            }
453            #trait_methods
454        }
455    };
456
457    TokenStream::from(expanded)
458}
459
/// Converts a `snake_case` name into `UpperCamelCase`.
///
/// The input is split on `_`; empty segments are dropped. In each remaining
/// segment the first character is uppercased and the rest is lowercased.
fn to_camel_case(s: &str) -> String {
    let mut camel = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(first) = rest.next() {
            camel.extend(first.to_uppercase());
            camel.push_str(&rest.as_str().to_lowercase());
        }
    }
    camel
}
471
472fn rust_type_to_json_type(ty: &syn::Type) -> &'static str {
473    if let syn::Type::Path(type_path) = ty {
474        let type_name = type_path.path.segments.last()
475            .map(|s| s.ident.to_string())
476            .unwrap_or_default();
477
478        match type_name.as_str() {
479            "String" | "str" => "string",
480            "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "isize" | "usize" => "integer",
481            "f32" | "f64" => "number",
482            "bool" => "boolean",
483            _ => "string", // fallback
484        }
485    } else {
486        "string"
487    }
488}
489
490// ============================================================================
491// #[derive(Traceable)]
492// ============================================================================
493#[proc_macro_derive(Traceable)]
494pub fn derive_traceable(input: TokenStream) -> TokenStream {
495    let input = parse_macro_input!(input as DeriveInput);
496    impl_traceable(&input)
497}
498
499fn impl_traceable(input: &DeriveInput) -> TokenStream {
500    let name = &input.ident;
501    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
502    let expanded = quote! {
503        impl #impl_generics #name #ty_generics #where_clause {
504            pub fn tracing_context() -> langgraph_tracing::TracingContext {
505                langgraph_tracing::TracingContext::new()
506            }
507        }
508    };
509    TokenStream::from(expanded)
510}
511
512fn is_option(ty: &syn::Type) -> bool {
513    if let syn::Type::Path(type_path) = ty {
514        if let Some(segment) = type_path.path.segments.last() {
515            return segment.ident == "Option";
516        }
517    }
518    false
519}
520
521fn extract_type_from_option(ty: &syn::Type) -> &syn::Type {
522    if let syn::Type::Path(type_path) = ty {
523        if let Some(segment) = type_path.path.segments.last() {
524            if segment.ident == "Option" {
525                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
526                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
527                        return inner_ty;
528                    }
529                }
530            }
531        }
532    }
533    ty
534}
535
536fn extract_param_descs(func: &ItemFn) -> std::collections::HashMap<String, String> {
537    let mut descs = std::collections::HashMap::new();
538    for attr in &func.attrs {
539        if !attr.path().is_ident("doc") {
540            continue;
541        }
542        if let syn::Meta::NameValue(nv) = &attr.meta {
543            if let syn::Expr::Lit(expr_lit) = &nv.value {
544                if let syn::Lit::Str(lit_str) = &expr_lit.lit {
545                    let line = lit_str.value();
546                    let trimmed = line.trim();
547                    // Parse "@param name description"
548                    if let Some(rest) = trimmed.strip_prefix("@param ") {
549                        let rest = rest.trim_start();
550                        if let Some(space_idx) = rest.find(char::is_whitespace) {
551                            let name = rest[..space_idx].to_string();
552                            let desc = rest[space_idx..].trim().to_string();
553                            if !desc.is_empty() {
554                                descs.insert(name, desc);
555                            }
556                        }
557                    }
558                }
559            }
560        }
561    }
562    descs
563}