langgraph_prebuilt/
chat_agent.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use serde_json::Value as JsonValue;
5use langgraph_checkpoint::config::RunnableConfig;
6use langgraph::graph::GraphError;
7use langgraph::runnable::{Runnable, RunnableError};
8use langgraph::channels::{BinaryOperatorAggregate, Channel};
9use langgraph::constants::{START, END};
10use langgraph::graph::StateGraph;
11
12use crate::traits::{BaseChatModel, BaseTool, ToolDef};
13use crate::types::{Message, add_messages};
14use crate::tool_node::ToolNode;
15use crate::tools_condition::tools_condition;
16
/// Settings for [`create_react_agent`].
///
/// `Default` yields no system prompt, `max_steps = Some(25)`, and
/// `handle_tool_errors = true`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReActAgentConfig {
    /// Optional system prompt, prepended as a system message to the
    /// conversation on every call the `agent` node makes to the model.
    pub system_prompt: Option<String>,
    /// Step cap forwarded to the compiled graph as its recursion limit.
    /// `None` leaves the graph builder's own default in place.
    pub max_steps: Option<usize>,
    /// Forwarded to `ToolNode::with_error_handling` when the `tools` node is
    /// built (exact semantics are defined by `ToolNode`).
    pub handle_tool_errors: bool,
}
26
27impl Default for ReActAgentConfig {
28 fn default() -> Self {
29 Self {
30 system_prompt: None,
31 max_steps: Some(25),
32 handle_tool_errors: true,
33 }
34 }
35}
36
37pub struct ReActAgent {
45 graph: Box<dyn Runnable>,
46}
47
48impl ReActAgent {
49 pub fn invoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
51 self.graph.invoke(input, config)
52 }
53
54 pub async fn ainvoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
56 self.graph.ainvoke(input, config).await
57 }
58}
59
60fn messages_reducer(current: &JsonValue, update: &JsonValue) -> JsonValue {
62 add_messages(current.clone(), update.clone())
63}
64
65pub fn create_react_agent(
81 model: Arc<dyn BaseChatModel>,
82 tools: Vec<Arc<dyn BaseTool>>,
83 config: Option<ReActAgentConfig>,
84) -> Result<ReActAgent, GraphError> {
85 let config = config.unwrap_or_default();
86
87 let tool_defs: Vec<ToolDef> = tools.iter().map(|t| t.to_tool_def()).collect();
89
90 let bound_model: Arc<dyn BaseChatModel> = Arc::from(model.bind_tools(tool_defs));
92
93 let tool_node = Arc::new(
95 ToolNode::new(tools).with_error_handling(config.handle_tool_errors)
96 );
97
98 let mut channels: HashMap<String, Box<dyn Channel>> = HashMap::new();
105 channels.insert(
106 "messages".to_string(),
107 Box::new(BinaryOperatorAggregate::new("messages", messages_reducer)),
108 );
109
110 let mut graph = StateGraph::new(channels);
111
112 let agent_model = bound_model;
114 let system_prompt = config.system_prompt.clone();
115
116 graph.add_node("agent", move |input: JsonValue, _config: RunnableConfig| {
117 let model = agent_model.clone();
118 let prompt = system_prompt.clone();
119 async move {
120 let messages = match input.get("messages") {
121 Some(JsonValue::Array(arr)) => arr.clone(),
122 _ => vec![],
123 };
124
125 let mut typed_messages: Vec<Message> = Vec::new();
126
127 if let Some(ref p) = prompt {
128 typed_messages.push(Message::system(p.clone()));
129 }
130
131 for msg in &messages {
132 if let Ok(m) = serde_json::from_value::<Message>(msg.clone()) {
133 typed_messages.push(m);
134 }
135 }
136
137 let response = model.invoke(&typed_messages, &RunnableConfig::new())
138 .map_err(|e| RunnableError::Node(e.to_string()))?;
139 let response_json = serde_json::to_value(response)
140 .map_err(|e: serde_json::Error| RunnableError::Node(e.to_string()))?;
141
142 Ok(serde_json::json!({
143 "messages": [response_json]
144 }))
145 }
146 })?;
147
148 let tools_arc = tool_node.clone();
150 graph.add_node("tools", move |input: JsonValue, config: RunnableConfig| {
151 let tn = tools_arc.clone();
152 async move {
153 tn.ainvoke(&input, &config).await
154 }
155 })?;
156
157 graph.add_conditional_edges(
159 "agent",
160 |input: JsonValue, _config: RunnableConfig| async move {
161 let route = tools_condition(&input);
162 Ok(JsonValue::String(route))
163 },
164 Some({
165 let mut map = HashMap::new();
166 map.insert("tools".to_string(), "tools".to_string());
167 map.insert(END.to_string(), END.to_string());
168 map
169 }),
170 )?;
171
172 graph.add_edge("tools", "agent")?;
174
175 graph.add_edge(START, "agent")?;
177
178 let mut builder = graph.compile_builder();
180 if let Some(steps) = config.max_steps {
181 builder = builder.recursion_limit(steps as u64);
182 }
183 let compiled = builder.build()?;
184
185 Ok(ReActAgent {
186 graph: Box::new(compiled),
187 })
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Merging two message arrays through the `messages` reducer keeps both
    /// messages.
    #[test]
    fn test_merge_state() {
        let current = serde_json::json!([
            {"type": "human", "content": "Hi"}
        ]);
        let update = serde_json::json!([
            {"type": "ai", "content": "Hello"}
        ]);

        let merged = messages_reducer(&current, &update);
        let messages = merged
            .as_array()
            .expect("merging two message arrays should yield an array");
        assert_eq!(messages.len(), 2);
    }

    /// Smoke test: object-shaped (non-array) operands must not panic the
    /// reducer.
    /// NOTE(review): the result is not asserted because how `add_messages`
    /// treats non-array input is not visible here — confirm and pin it.
    #[test]
    fn test_merge_state_new_key() {
        let current = serde_json::json!({
            "messages": []
        });
        let update = serde_json::json!({
            "result": "done"
        });

        let _merged = messages_reducer(&current, &update);
    }
}