use crate::config::Config;
use crate::errors::Result;
use crate::graph::{StateGraph, CompiledGraph};
use crate::llm::ChatModel;
use crate::nodes::Node;
use crate::prebuilt::{Tool, ToolNode, tools_condition};
use crate::pregel::BranchResult;
use crate::state::{MessagesState, Message};
use async_trait::async_trait;
use std::sync::Arc;
/// Builds a ReAct-style agent graph over [`MessagesState`].
///
/// The compiled graph has two nodes:
/// - `"agent"`: calls `model` with the full message history and emits the
///   model's reply as a single-message state update. The reply is presumably
///   appended to the history by the `MessagesState` merge logic; confirm this
///   in `crate::state`.
/// - `"tools"`: a [`ToolNode`] wrapping `tools`.
///
/// Execution starts at `"agent"`. After each agent step, [`tools_condition`]
/// decides whether to run `"tools"` or to end. `"tools"` always loops back
/// to `"agent"`.
///
/// # Errors
///
/// Returns any error produced while compiling the graph.
pub fn create_react_agent<M>(
    model: M,
    tools: Vec<Tool>,
) -> Result<CompiledGraph<MessagesState>>
where
    M: ChatModel + 'static,
{
    let mut graph = StateGraph::new();

    // Shared so that each invocation of the node closure can hand an owned
    // handle to the future it returns.
    let model = Arc::new(model);
    graph.add_node("agent", move |state: MessagesState, _config: &Config| {
        let model = Arc::clone(&model);
        async move {
            let response = model.invoke(&state.messages).await?;
            Ok(MessagesState {
                messages: vec![response],
            })
        }
    });

    graph.add_node("tools", ToolNode::new(tools));
    graph.set_entry_point("agent");

    graph.add_conditional_edges("agent", |state: &MessagesState| {
        // Decide the route now, while `state` is borrowed. The returned future
        // then captures only a `bool` instead of a clone of the whole message
        // history on every step.
        let goes_to_tools = tools_condition(&state.messages) == "tools";
        async move {
            if goes_to_tools {
                Ok(BranchResult::single("tools"))
            } else {
                Ok(BranchResult::end())
            }
        }
    });

    graph.add_edge("tools", "agent");
    graph.compile(None)
}
/// Graph node that sends the conversation so far to a chat model.
///
/// The model is held behind an [`Arc`]. The type itself carries no trait
/// bound; the `impl` blocks that need `ChatModel` require it themselves.
pub struct AgentNode<M> {
    model: Arc<M>,
}
impl<M: ChatModel> AgentNode<M> {
    /// Creates a node that takes ownership of `model` and wraps it in an
    /// [`Arc`].
    pub fn new(model: M) -> Self {
        let shared = Arc::new(model);
        Self { model: shared }
    }
}
#[async_trait]
impl<M: ChatModel + 'static> Node<MessagesState> for AgentNode<M> {
    /// Sends the current message history to the model. Returns a state update
    /// that holds only the model's reply.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the model's `invoke`.
    async fn invoke(&self, state: MessagesState, _config: &Config) -> Result<MessagesState> {
        let history = &state.messages;
        let reply = self.model.invoke(history).await?;
        let messages = vec![reply];
        Ok(MessagesState { messages })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Chat-model stub: ignores its input and always replies with a fixed
    /// assistant message.
    #[derive(Clone)]
    struct MockModel;

    #[async_trait]
    impl ChatModel for MockModel {
        async fn invoke(&self, _messages: &[Message]) -> Result<Message> {
            Ok(Message::assistant("Mock response"))
        }

        fn clone_box(&self) -> Box<dyn ChatModel> {
            Box::new(self.clone())
        }
    }

    /// Building the agent graph with a single tool should succeed.
    #[tokio::test]
    async fn test_create_react_agent() {
        let model = MockModel;
        let tool = Tool::new(
            "test",
            "Test tool",
            |_| async move { Ok(serde_json::json!({"result": "ok"})) },
        );
        // `expect` reports the actual error on failure; `assert!(is_ok())`
        // would discard it.
        let _agent = create_react_agent(model, vec![tool])
            .expect("building a ReAct agent with one tool should succeed");
    }

    /// The agent node should return exactly the model's reply.
    #[tokio::test]
    async fn test_agent_node() {
        let model = MockModel;
        let node = AgentNode::new(model);
        let state = MessagesState {
            messages: vec![Message::user("Hello")],
        };
        let result = node.invoke(state, &Config::default()).await.unwrap();
        assert_eq!(result.messages.len(), 1);
        assert_eq!(result.messages[0].content, "Mock response");
    }
}