1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Data, Fields, Attribute, Lit, ItemFn, ReturnType};
4
5#[proc_macro_derive(StateGraph, attributes(channel))]
15pub fn derive_state_graph(input: TokenStream) -> TokenStream {
16 let input = parse_macro_input!(input as DeriveInput);
17 impl_state_graph(&input)
18}
19
20
21#[proc_macro_attribute]
35pub fn langgraph_state(_attr: TokenStream, item: TokenStream) -> TokenStream {
36 let mut input = parse_macro_input!(item as syn::ItemStruct);
37
38 input.attrs.push(syn::parse_quote! {
40 #[derive(serde::Serialize, serde::Deserialize, Clone, Default, langgraph_derive::StateGraph)]
41 });
42
43 if let syn::Fields::Named(fields) = &mut input.fields {
45 for field in &mut fields.named {
46 let mut has_default = false;
47 for attr in &field.attrs {
48 if attr.path().is_ident("serde") {
49 let _ = attr.parse_nested_meta(|meta| {
50 if meta.path.is_ident("default") {
51 has_default = true;
52 }
53 Ok(())
54 });
55 }
56 }
57
58 if !has_default {
59 field.attrs.push(syn::parse_quote! {
60 #[serde(default)]
61 });
62 }
63 }
64 }
65
66 let expanded = quote! {
67 #input
68 };
69
70 TokenStream::from(expanded)
71}
72
73fn impl_state_graph(input: &DeriveInput) -> TokenStream {
74 let name = &input.ident;
75 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
76
77 let fields = match &input.data {
78 Data::Struct(data) => match &data.fields {
79 Fields::Named(fields) => &fields.named,
80 _ => panic!("StateGraph can only be derived for structs with named fields"),
81 },
82 _ => panic!("StateGraph can only be derived for structs"),
83 };
84
85 for field in fields {
88 let field_name = field.ident.as_ref().unwrap();
89
90 let mut has_serde_default = false;
91 for attr in &field.attrs {
92 if attr.path().is_ident("serde") {
93 let _ = attr.parse_nested_meta(|meta| {
94 if meta.path.is_ident("default") {
95 has_serde_default = true;
96 }
97 Ok(())
98 });
99 }
100 }
101
102 if !has_serde_default {
103 let error_msg = format!(
104 "Field `{}` in `{}` is missing `#[serde(default)]`. \
105 LangGraph states require this attribute on all fields to prevent \
106 state loss during resume operations. Please add `#[serde(default)]` \
107 to this field.",
108 field_name, name
109 );
110 return syn::Error::new_spanned(field, error_msg).to_compile_error().into();
111 }
112 }
113
114 let channel_registrations: Vec<proc_macro2::TokenStream> = fields
115 .iter()
116 .map(|field| {
117 let field_name = field.ident.as_ref().unwrap();
118 let field_name_str = field_name.to_string();
119
120 let reducer = get_channel_reducer(&field.attrs);
122
123 match reducer {
124 Some(ReducerSpec::Named(fn_name)) => {
125 let fn_ident = syn::Ident::new(&fn_name, proc_macro2::Span::call_site());
126 quote! {
127 channels.insert(
128 #field_name_str.to_string(),
129 Box::new(langgraph::channels::BinaryOperatorAggregate::new(
130 #field_name_str,
131 #fn_ident,
132 )) as Box<dyn langgraph::channels::Channel>
133 );
134 }
135 }
136 Some(ReducerSpec::Messages) => {
137 quote! {
138 channels.insert(
139 #field_name_str.to_string(),
140 Box::new(langgraph::channels::BinaryOperatorAggregate::new(
141 #field_name_str,
142 langgraph_prebuilt::add_messages_ref,
143 )) as Box<dyn langgraph::channels::Channel>
144 );
145 }
146 }
147 None => {
148 quote! {
149 channels.insert(
150 #field_name_str.to_string(),
151 Box::new(langgraph::channels::LastValue::new(#field_name_str)) as Box<dyn langgraph::channels::Channel>
152 );
153 }
154 }
155 }
156 })
157 .collect();
158
159 let expanded = quote! {
160 impl #impl_generics #name #ty_generics #where_clause {
161 pub fn create_channels() -> std::collections::HashMap<String, Box<dyn langgraph::channels::Channel>> {
162 let mut channels = std::collections::HashMap::new();
163 #(#channel_registrations)*
164 channels
165 }
166 }
167 };
168
169 TokenStream::from(expanded)
170}
171
/// How a state field's channel combines successive updates, as selected by
/// the field's `#[channel(...)]` attribute.
enum ReducerSpec {
    /// `#[channel(reducer = "...")]`: the string names the reducer function.
    Named(String),
    /// `#[channel(messages)]`: use the prebuilt message-list reducer.
    Messages,
}
179
180fn get_channel_reducer(attrs: &[Attribute]) -> Option<ReducerSpec> {
181 for attr in attrs {
182 if !attr.path().is_ident("channel") {
183 continue;
184 }
185
186 let mut result = None;
187
188 attr.parse_nested_meta(|meta| {
189 if meta.path.is_ident("reducer") {
190 let value = meta.value()?;
191 let lit: Lit = value.parse()?;
192 if let Lit::Str(s) = lit {
193 result = Some(ReducerSpec::Named(s.value()));
194 }
195 Ok(())
196 } else if meta.path.is_ident("messages") {
197 result = Some(ReducerSpec::Messages);
198 Ok(())
199 } else {
200 Err(meta.error("unknown channel attribute"))
201 }
202 })
203 .ok();
204
205 return result;
206 }
207 None
208}
209
210#[proc_macro_attribute]
216pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
217 let func = parse_macro_input!(item as ItemFn);
218 let args = parse_macro_input!(attr as ToolMacroArgs);
219 impl_tool_macro(&args.name, &args.description, &func)
220}
221
222struct ToolMacroArgs {
223 name: Option<Lit>,
224 description: Option<Lit>,
225}
226
227impl syn::parse::Parse for ToolMacroArgs {
228 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
229 if input.is_empty() {
230 return Ok(Self { name: None, description: None });
231 }
232 let name: Lit = input.parse()?;
233 let description = if input.peek(syn::Token![,]) {
234 input.parse::<syn::Token![,]>()?;
235 Some(input.parse()?)
236 } else {
237 None
238 };
239 Ok(Self { name: Some(name), description })
240 }
241}
242
243fn impl_tool_macro(name_lit: &Option<Lit>, desc_lit: &Option<Lit>, func: &ItemFn) -> TokenStream {
244 let fn_name = &func.sig.ident;
245 let fn_name_str = fn_name.to_string();
246
247 let tool_name = if let Some(Lit::Str(s)) = name_lit {
248 s.value()
249 } else {
250 fn_name_str.clone()
251 };
252
253 let param_descs = extract_param_descs(func);
255
256 let description = if let Some(desc) = desc_lit {
257 match desc {
258 Lit::Str(s) => s.value(),
259 _ => panic!("description must be a string literal"),
260 }
261 } else {
262 let mut extracted_desc = String::new();
263 for attr in &func.attrs {
264 if attr.path().is_ident("doc") {
265 if let syn::Meta::NameValue(nv) = &attr.meta {
266 if let syn::Expr::Lit(expr_lit) = &nv.value {
267 if let syn::Lit::Str(lit_str) = &expr_lit.lit {
268 let doc_str = lit_str.value();
269 let trimmed = doc_str.trim();
270 if trimmed.starts_with("@param ") {
272 continue;
273 }
274 if !extracted_desc.is_empty() {
275 extracted_desc.push_str(" ");
276 }
277 extracted_desc.push_str(trimmed);
278 }
279 }
280 }
281 }
282 }
283 extracted_desc
284 };
285
286 let struct_name_str = to_camel_case(&fn_name_str);
287 let struct_name = syn::Ident::new(&struct_name_str, fn_name.span());
288
289 let params: Vec<_> = func.sig.inputs.iter().filter_map(|arg| {
290 if let syn::FnArg::Typed(pat_type) = arg {
291 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
292 return Some((pat_ident.ident.clone(), (*pat_type.ty).clone()));
293 }
294 }
295 None
296 }).collect();
297
298 let properties: Vec<proc_macro2::TokenStream> = params.iter().map(|(name, ty)| {
299 let name_str = name.to_string();
300 let actual_ty = if is_option(ty) { extract_type_from_option(ty) } else { ty };
301 let json_type = rust_type_to_json_type(actual_ty);
302 if let Some(d) = param_descs.get(&name_str) {
303 quote! {
304 (#name_str, serde_json::json!({"type": #json_type, "description": #d}))
305 }
306 } else {
307 quote! {
308 (#name_str, serde_json::json!({"type": #json_type}))
309 }
310 }
311 }).collect();
312
313 let required: Vec<String> = params.iter()
314 .filter(|(_, ty)| !is_option(ty))
315 .map(|(name, _)| name.to_string())
316 .collect();
317
318 let extractions: Vec<proc_macro2::TokenStream> = params.iter().map(|(name, ty)| {
319 let name_str = name.to_string();
320 let err_invalid = format!("invalid parameter '{}': {{}}", name_str);
321
322 if is_option(ty) {
323 quote! {
324 let #name: #ty = match args.get(#name_str) {
325 Some(v) => serde_json::from_value(v.clone())
326 .map_err(|e| langgraph_prebuilt::ToolError::InvalidArgs(format!(#err_invalid, e)))?,
327 None => None,
328 };
329 }
330 } else {
331 let err_missing = format!("missing required parameter '{}'", name_str);
332 quote! {
333 let #name: #ty = serde_json::from_value(
334 args.get(#name_str)
335 .cloned()
336 .ok_or_else(|| langgraph_prebuilt::ToolError::InvalidArgs(#err_missing.to_string()))?
337 ).map_err(|e| langgraph_prebuilt::ToolError::InvalidArgs(
338 format!(#err_invalid, e)
339 ))?;
340 }
341 }
342 }).collect();
343
344 let param_names: Vec<_> = params.iter().map(|(name, _)| name.clone()).collect();
345
346 let is_result_return = match &func.sig.output {
347 ReturnType::Type(_, ty) => {
348 if let syn::Type::Path(type_path) = ty.as_ref() {
349 type_path.path.segments.last()
350 .map(|s| s.ident == "Result")
351 .unwrap_or(false)
352 } else {
353 false
354 }
355 }
356 _ => false,
357 };
358
359 let is_async = func.sig.asyncness.is_some();
360
361 let await_tokens = if is_async {
362 quote! { .await }
363 } else {
364 quote! {}
365 };
366
367 let invoke_body = if is_result_return {
368 quote! {
369 #(#extractions)*
370 let result = #fn_name(#(#param_names),*)#await_tokens;
371 result
372 .map_err(|e| {
373 let tool_err: langgraph_prebuilt::ToolError = e.into();
374 tool_err
375 })
376 .and_then(|r| serde_json::to_value(r).map_err(|e| langgraph_prebuilt::ToolError::Execution(
377 format!("failed to serialize result: {}", e)
378 )))
379 }
380 } else {
381 quote! {
382 #(#extractions)*
383 let result = #fn_name(#(#param_names),*)#await_tokens;
384 serde_json::to_value(result).map_err(|e| langgraph_prebuilt::ToolError::Execution(
385 format!("failed to serialize result: {}", e)
386 ))
387 }
388 };
389
390 let trait_methods = if is_async {
391 quote! {
392 fn invoke(
393 &self,
394 _args: &serde_json::Value,
395 _config: &langgraph_checkpoint::config::RunnableConfig,
396 ) -> Result<serde_json::Value, langgraph_prebuilt::ToolError> {
397 Err(langgraph_prebuilt::ToolError::Execution(
398 "This tool is asynchronous and must be invoked with ainvoke".to_string()
399 ))
400 }
401
402 async fn ainvoke(
403 &self,
404 args: &serde_json::Value,
405 _config: &langgraph_checkpoint::config::RunnableConfig,
406 ) -> Result<serde_json::Value, langgraph_prebuilt::ToolError> {
407 #invoke_body
408 }
409 }
410 } else {
411 quote! {
412 fn invoke(
413 &self,
414 args: &serde_json::Value,
415 _config: &langgraph_checkpoint::config::RunnableConfig,
416 ) -> Result<serde_json::Value, langgraph_prebuilt::ToolError> {
417 #invoke_body
418 }
419 }
420 };
421
422 let expanded = quote! {
423 #func
424 pub struct #struct_name;
425 impl #struct_name {
426 pub fn new() -> Self { Self }
427 }
428 impl Default for #struct_name {
429 fn default() -> Self { Self }
430 }
431 #[async_trait::async_trait]
432 impl langgraph_prebuilt::BaseTool for #struct_name {
433 fn name(&self) -> &str { #tool_name }
434 fn description(&self) -> &str { #description }
435 fn parameters(&self) -> Option<&serde_json::Value> {
436 use std::sync::OnceLock;
437 static SCHEMA: OnceLock<serde_json::Value> = OnceLock::new();
438 Some(SCHEMA.get_or_init(|| {
439 let mut properties = serde_json::Map::new();
440 #(
441 {
442 let (k, v) = #properties;
443 properties.insert(k.to_string(), v);
444 }
445 )*
446 serde_json::json!({
447 "type": "object",
448 "properties": properties,
449 "required": [#(#required),*]
450 })
451 }))
452 }
453 #trait_methods
454 }
455 };
456
457 TokenStream::from(expanded)
458}
459
/// Converts a `snake_case` name to `UpperCamelCase`.
///
/// The input is split on `_`; for each non-empty piece the first character is
/// upper-cased and the remainder lower-cased. Empty pieces (from leading,
/// trailing or repeated underscores) contribute nothing.
fn to_camel_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for piece in s.split('_') {
        let mut rest = piece.chars();
        if let Some(first) = rest.next() {
            out.extend(first.to_uppercase());
            out.push_str(&rest.as_str().to_lowercase());
        }
    }
    out
}
471
472fn rust_type_to_json_type(ty: &syn::Type) -> &'static str {
473 if let syn::Type::Path(type_path) = ty {
474 let type_name = type_path.path.segments.last()
475 .map(|s| s.ident.to_string())
476 .unwrap_or_default();
477
478 match type_name.as_str() {
479 "String" | "str" => "string",
480 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "isize" | "usize" => "integer",
481 "f32" | "f64" => "number",
482 "bool" => "boolean",
483 _ => "string", }
485 } else {
486 "string"
487 }
488}
489
490#[proc_macro_derive(Traceable)]
494pub fn derive_traceable(input: TokenStream) -> TokenStream {
495 let input = parse_macro_input!(input as DeriveInput);
496 impl_traceable(&input)
497}
498
499fn impl_traceable(input: &DeriveInput) -> TokenStream {
500 let name = &input.ident;
501 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
502 let expanded = quote! {
503 impl #impl_generics #name #ty_generics #where_clause {
504 pub fn tracing_context() -> langgraph_tracing::TracingContext {
505 langgraph_tracing::TracingContext::new()
506 }
507 }
508 };
509 TokenStream::from(expanded)
510}
511
512fn is_option(ty: &syn::Type) -> bool {
513 if let syn::Type::Path(type_path) = ty {
514 if let Some(segment) = type_path.path.segments.last() {
515 return segment.ident == "Option";
516 }
517 }
518 false
519}
520
521fn extract_type_from_option(ty: &syn::Type) -> &syn::Type {
522 if let syn::Type::Path(type_path) = ty {
523 if let Some(segment) = type_path.path.segments.last() {
524 if segment.ident == "Option" {
525 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
526 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
527 return inner_ty;
528 }
529 }
530 }
531 }
532 }
533 ty
534}
535
536fn extract_param_descs(func: &ItemFn) -> std::collections::HashMap<String, String> {
537 let mut descs = std::collections::HashMap::new();
538 for attr in &func.attrs {
539 if !attr.path().is_ident("doc") {
540 continue;
541 }
542 if let syn::Meta::NameValue(nv) = &attr.meta {
543 if let syn::Expr::Lit(expr_lit) = &nv.value {
544 if let syn::Lit::Str(lit_str) = &expr_lit.lit {
545 let line = lit_str.value();
546 let trimmed = line.trim();
547 if let Some(rest) = trimmed.strip_prefix("@param ") {
549 let rest = rest.trim_start();
550 if let Some(space_idx) = rest.find(char::is_whitespace) {
551 let name = rest[..space_idx].to_string();
552 let desc = rest[space_idx..].trim().to_string();
553 if !desc.is_empty() {
554 descs.insert(name, desc);
555 }
556 }
557 }
558 }
559 }
560 }
561 }
562 descs
563}