// langgraph_prebuilt/chat_agent.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use serde_json::Value as JsonValue;
5use langgraph_checkpoint::config::RunnableConfig;
6use langgraph::graph::GraphError;
7use langgraph::runnable::{Runnable, RunnableError};
8use langgraph::channels::{BinaryOperatorAggregate, Channel};
9use langgraph::constants::{START, END};
10use langgraph::graph::StateGraph;
11
12use crate::traits::{BaseChatModel, BaseTool, ToolDef};
13use crate::types::{Message, add_messages};
14use crate::tool_node::ToolNode;
15use crate::tools_condition::tools_condition;
16
/// Configuration for creating a ReAct agent with `create_react_agent`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReActAgentConfig {
    /// Optional system prompt, placed before the stored conversation on every model call.
    pub system_prompt: Option<String>,
    /// Step budget for a single run, applied as the compiled graph's recursion limit.
    /// `None` leaves the graph builder's own recursion-limit setting untouched.
    pub max_steps: Option<usize>,
    /// Whether tool failures are handled gracefully; forwarded to the `ToolNode`'s
    /// `with_error_handling` switch.
    pub handle_tool_errors: bool,
}
26
27impl Default for ReActAgentConfig {
28    fn default() -> Self {
29        Self {
30            system_prompt: None,
31            max_steps: Some(25),
32            handle_tool_errors: true,
33        }
34    }
35}
36
37/// A compiled ReAct agent graph.
38///
39/// This is a prebuilt graph that implements the ReAct (Reasoning + Acting) pattern:
40/// 1. The model receives the conversation history and decides what to do
41/// 2. If the model calls tools, they are executed and the results are added to the history
42/// 3. The model sees the tool results and decides what to do next
43/// 4. This continues until the model responds without tool calls (or max steps is reached)
44pub struct ReActAgent {
45    graph: Box<dyn Runnable>,
46}
47
48impl ReActAgent {
49    /// Invoke the agent synchronously.
50    pub fn invoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
51        self.graph.invoke(input, config)
52    }
53
54    /// Invoke the agent asynchronously.
55    pub async fn ainvoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
56        self.graph.ainvoke(input, config).await
57    }
58}
59
60/// Reducer for messages channel: appends new messages to existing ones.
61fn messages_reducer(current: &JsonValue, update: &JsonValue) -> JsonValue {
62    add_messages(current.clone(), update.clone())
63}
64
65/// Create a ReAct agent with the given model and tools.
66///
67/// This builds a graph with the following structure:
68/// ```text
69/// START → agent → [tools_condition] → tools → agent (loop)
70///                                → END
71/// ```
72///
73/// # Arguments
74/// * `model` - The chat model to use for generating responses
75/// * `tools` - The tools available to the agent
76/// * `config` - Optional configuration for the agent
77///
78/// # Returns
79/// A compiled agent graph that can be invoked.
80pub fn create_react_agent(
81    model: Arc<dyn BaseChatModel>,
82    tools: Vec<Arc<dyn BaseTool>>,
83    config: Option<ReActAgentConfig>,
84) -> Result<ReActAgent, GraphError> {
85    let config = config.unwrap_or_default();
86
87    // Get tool definitions for the model
88    let tool_defs: Vec<ToolDef> = tools.iter().map(|t| t.to_tool_def()).collect();
89
90    // Bind tools to the model (wrap in Arc for sharing across closures)
91    let bound_model: Arc<dyn BaseChatModel> = Arc::from(model.bind_tools(tool_defs));
92
93    // Create the ToolNode (wrapped in Arc for sharing across closures)
94    let tool_node = Arc::new(
95        ToolNode::new(tools).with_error_handling(config.handle_tool_errors)
96    );
97
98    // -------------------------------------------------------
99    // Build graph: START → agent → [should_continue] → tools → agent (loop) → END
100    // (same structure as Python create_react_agent)
101    // -------------------------------------------------------
102
103    // Create channels with reducers
104    let mut channels: HashMap<String, Box<dyn Channel>> = HashMap::new();
105    channels.insert(
106        "messages".to_string(),
107        Box::new(BinaryOperatorAggregate::new("messages", messages_reducer)),
108    );
109
110    let mut graph = StateGraph::new(channels);
111
112    // --- Agent node: calls the LLM ---
113    let agent_model = bound_model;
114    let system_prompt = config.system_prompt.clone();
115
116    graph.add_node("agent", move |input: JsonValue, _config: RunnableConfig| {
117        let model = agent_model.clone();
118        let prompt = system_prompt.clone();
119        async move {
120            let messages = match input.get("messages") {
121                Some(JsonValue::Array(arr)) => arr.clone(),
122                _ => vec![],
123            };
124
125            let mut typed_messages: Vec<Message> = Vec::new();
126
127            if let Some(ref p) = prompt {
128                typed_messages.push(Message::system(p.clone()));
129            }
130
131            for msg in &messages {
132                if let Ok(m) = serde_json::from_value::<Message>(msg.clone()) {
133                    typed_messages.push(m);
134                }
135            }
136
137            let response = model.invoke(&typed_messages, &RunnableConfig::new())
138                .map_err(|e| RunnableError::Node(e.to_string()))?;
139            let response_json = serde_json::to_value(response)
140                .map_err(|e: serde_json::Error| RunnableError::Node(e.to_string()))?;
141
142            Ok(serde_json::json!({
143                "messages": [response_json]
144            }))
145        }
146    })?;
147
148    // --- Tools node: executes tool calls ---
149    let tools_arc = tool_node.clone();
150    graph.add_node("tools", move |input: JsonValue, config: RunnableConfig| {
151        let tn = tools_arc.clone();
152        async move {
153            tn.ainvoke(&input, &config).await
154        }
155    })?;
156
157    // --- Conditional edge: agent → tools or END ---
158    graph.add_conditional_edges(
159        "agent",
160        |input: JsonValue, _config: RunnableConfig| async move {
161            let route = tools_condition(&input);
162            Ok(JsonValue::String(route))
163        },
164        Some({
165            let mut map = HashMap::new();
166            map.insert("tools".to_string(), "tools".to_string());
167            map.insert(END.to_string(), END.to_string());
168            map
169        }),
170    )?;
171
172    // --- Edge: tools → agent (loop back) ---
173    graph.add_edge("tools", "agent")?;
174
175    // --- Entry point ---
176    graph.add_edge(START, "agent")?;
177
178    // --- Compile ---
179    let mut builder = graph.compile_builder();
180    if let Some(steps) = config.max_steps {
181        builder = builder.recursion_limit(steps as u64);
182    }
183    let compiled = builder.build()?;
184
185    Ok(ReActAgent {
186        graph: Box::new(compiled),
187    })
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_merge_state() {
        // The reducer receives raw channel values (arrays), not objects with a "messages" key.
        let current = serde_json::json!([
            {"type": "human", "content": "Hi"}
        ]);
        let update = serde_json::json!([
            {"type": "ai", "content": "Hello"}
        ]);

        let merged = messages_reducer(&current, &update);
        let messages = merged
            .as_array()
            .expect("merging two message arrays should yield an array");
        assert_eq!(messages.len(), 2);
        // The update is appended after the existing history, preserving order.
        assert_eq!(messages[0]["content"], "Hi");
        assert_eq!(messages[1]["content"], "Hello");
    }

    #[test]
    fn test_merge_state_new_key() {
        // Smoke test: the reducer must not panic when handed objects instead of message
        // arrays. The merged value is intentionally not pinned here.
        // NOTE(review): the result for non-array inputs depends on `add_messages`; add an
        // assertion once that behavior is specified.
        let current = serde_json::json!({
            "messages": []
        });
        let update = serde_json::json!({
            "result": "done"
        });

        let _merged = messages_reducer(&current, &update);
    }
}