use std::collections::HashMap;
use std::sync::Arc;
use serde_json::Value as JsonValue;
use langgraph_checkpoint::config::RunnableConfig;
use langgraph::graph::GraphError;
use langgraph::runnable::{Runnable, RunnableError};
use langgraph::channels::{BinaryOperatorAggregate, Channel};
use langgraph::constants::{START, END};
use langgraph::graph::StateGraph;
use crate::traits::{BaseChatModel, BaseTool, ToolDef};
use crate::types::{Message, add_messages};
use crate::tool_node::ToolNode;
use crate::tools_condition::tools_condition;
/// Tunable options consumed by `create_react_agent`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReActAgentConfig {
    /// When set, pushed as a system message ahead of the conversation on
    /// every model call made by the agent node.
    pub system_prompt: Option<String>,
    /// When `Some`, applied as the compiled graph's recursion limit; when
    /// `None`, the graph builder's own setting is left untouched.
    pub max_steps: Option<usize>,
    /// Forwarded to `ToolNode::with_error_handling`.
    pub handle_tool_errors: bool,
}

impl Default for ReActAgentConfig {
    /// No system prompt, a limit of 25 steps, and tool error handling enabled.
    fn default() -> Self {
        Self {
            system_prompt: None,
            max_steps: Some(25),
            handle_tool_errors: true,
        }
    }
}
/// Handle around the compiled agent graph produced by `create_react_agent`.
pub struct ReActAgent {
    /// The compiled graph, type-erased behind the `Runnable` trait.
    graph: Box<dyn Runnable>,
}

impl ReActAgent {
    /// Runs the wrapped graph synchronously by delegating to its `invoke`.
    ///
    /// # Errors
    ///
    /// Returns whatever `RunnableError` the wrapped graph reports.
    pub fn invoke(
        &self,
        input: &JsonValue,
        config: &RunnableConfig,
    ) -> Result<JsonValue, RunnableError> {
        self.graph.invoke(input, config)
    }

    /// Runs the wrapped graph asynchronously by delegating to its `ainvoke`.
    ///
    /// # Errors
    ///
    /// Returns whatever `RunnableError` the wrapped graph reports.
    pub async fn ainvoke(
        &self,
        input: &JsonValue,
        config: &RunnableConfig,
    ) -> Result<JsonValue, RunnableError> {
        let pending = self.graph.ainvoke(input, config);
        pending.await
    }
}
/// Reducer for the `"messages"` channel.
///
/// Takes owned copies of the current channel value and the incoming update
/// and hands both to `add_messages`, returning its result unchanged.
fn messages_reducer(current: &JsonValue, update: &JsonValue) -> JsonValue {
    let existing = current.clone();
    let incoming = update.clone();
    add_messages(existing, incoming)
}
pub fn create_react_agent(
model: Arc<dyn BaseChatModel>,
tools: Vec<Arc<dyn BaseTool>>,
config: Option<ReActAgentConfig>,
) -> Result<ReActAgent, GraphError> {
let config = config.unwrap_or_default();
let tool_defs: Vec<ToolDef> = tools.iter().map(|t| t.to_tool_def()).collect();
let bound_model: Arc<dyn BaseChatModel> = Arc::from(model.bind_tools(tool_defs));
let tool_node = Arc::new(
ToolNode::new(tools).with_error_handling(config.handle_tool_errors)
);
let mut channels: HashMap<String, Box<dyn Channel>> = HashMap::new();
channels.insert(
"messages".to_string(),
Box::new(BinaryOperatorAggregate::new("messages", messages_reducer)),
);
let mut graph = StateGraph::new(channels);
let agent_model = bound_model;
let system_prompt = config.system_prompt.clone();
graph.add_node("agent", move |input: JsonValue, _config: RunnableConfig| {
let model = agent_model.clone();
let prompt = system_prompt.clone();
async move {
let messages = match input.get("messages") {
Some(JsonValue::Array(arr)) => arr.clone(),
_ => vec![],
};
let mut typed_messages: Vec<Message> = Vec::new();
if let Some(ref p) = prompt {
typed_messages.push(Message::system(p.clone()));
}
for msg in &messages {
if let Ok(m) = serde_json::from_value::<Message>(msg.clone()) {
typed_messages.push(m);
}
}
let response = model.invoke(&typed_messages, &RunnableConfig::new())
.map_err(|e| RunnableError::Node(e.to_string()))?;
let response_json = serde_json::to_value(response)
.map_err(|e: serde_json::Error| RunnableError::Node(e.to_string()))?;
Ok(serde_json::json!({
"messages": [response_json]
}))
}
})?;
let tools_arc = tool_node.clone();
graph.add_node("tools", move |input: JsonValue, config: RunnableConfig| {
let tn = tools_arc.clone();
async move {
tn.ainvoke(&input, &config).await
}
})?;
graph.add_conditional_edges(
"agent",
|input: JsonValue, _config: RunnableConfig| async move {
let route = tools_condition(&input);
Ok(JsonValue::String(route))
},
Some({
let mut map = HashMap::new();
map.insert("tools".to_string(), "tools".to_string());
map.insert(END.to_string(), END.to_string());
map
}),
)?;
graph.add_edge("tools", "agent")?;
graph.add_edge(START, "agent")?;
let mut builder = graph.compile_builder();
if let Some(steps) = config.max_steps {
builder = builder.recursion_limit(steps as u64);
}
let compiled = builder.build()?;
Ok(ReActAgent {
graph: Box::new(compiled),
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reducing a one-message history with a one-message update must yield
    /// an array of two messages.
    #[test]
    fn test_merge_state() {
        let current = serde_json::json!([
            {"type": "human", "content": "Hi"}
        ]);
        let update = serde_json::json!([
            {"type": "ai", "content": "Hello"}
        ]);
        let merged = messages_reducer(&current, &update);
        let messages = merged.as_array().unwrap();
        assert_eq!(messages.len(), 2);
    }

    /// Smoke test: feeding object-shaped values to the reducer must not panic.
    /// The result is deliberately not inspected.
    #[test]
    fn test_merge_state_new_key() {
        let current = serde_json::json!({
            "messages": []
        });
        let update = serde_json::json!({
            "result": "done"
        });
        let _merged = messages_reducer(&current, &update);
    }
}