synaptic_graph/
prebuilt.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::{ChatModel, ChatRequest, Message, SynapseError, Tool, ToolDefinition};
6use synaptic_tools::{SerialToolExecutor, ToolRegistry};
7
8use crate::builder::StateGraph;
9use crate::checkpoint::Checkpointer;
10use crate::compiled::CompiledGraph;
11use crate::node::Node;
12use crate::state::MessageState;
13use crate::tool_node::ToolNode;
14use crate::END;
15
16struct ChatModelNode {
18 model: Arc<dyn ChatModel>,
19 tool_defs: Vec<ToolDefinition>,
20 system_prompt: Option<String>,
21}
22
23#[async_trait]
24impl Node<MessageState> for ChatModelNode {
25 async fn process(&self, mut state: MessageState) -> Result<MessageState, SynapseError> {
26 let mut messages = Vec::new();
27
28 if let Some(ref prompt) = self.system_prompt {
30 messages.push(Message::system(prompt));
31 }
32
33 messages.extend(state.messages.clone());
34
35 let request = ChatRequest::new(messages).with_tools(self.tool_defs.clone());
36 let response = self.model.chat(request).await?;
37 state.messages.push(response.message);
38 Ok(state)
39 }
40}
41
42#[derive(Default)]
47pub struct ReactAgentOptions {
48 pub checkpointer: Option<Arc<dyn Checkpointer>>,
50 pub interrupt_before: Vec<String>,
52 pub interrupt_after: Vec<String>,
54 pub system_prompt: Option<String>,
56}
57
58pub fn create_react_agent(
66 model: Arc<dyn ChatModel>,
67 tools: Vec<Arc<dyn Tool>>,
68) -> Result<CompiledGraph<MessageState>, SynapseError> {
69 create_react_agent_with_options(model, tools, ReactAgentOptions::default())
70}
71
72pub fn create_react_agent_with_options(
100 model: Arc<dyn ChatModel>,
101 tools: Vec<Arc<dyn Tool>>,
102 options: ReactAgentOptions,
103) -> Result<CompiledGraph<MessageState>, SynapseError> {
104 let tool_defs: Vec<ToolDefinition> = tools
105 .iter()
106 .map(|t| ToolDefinition {
107 name: t.name().to_string(),
108 description: t.description().to_string(),
109 parameters: serde_json::json!({}),
110 })
111 .collect();
112
113 let registry = ToolRegistry::new();
114 for tool in tools {
115 registry.register(tool)?;
116 }
117 let executor = SerialToolExecutor::new(registry);
118
119 let agent_node = ChatModelNode {
120 model,
121 tool_defs,
122 system_prompt: options.system_prompt,
123 };
124 let tool_node = ToolNode::new(executor);
125
126 let mut builder = StateGraph::new()
127 .add_node("agent", agent_node)
128 .add_node("tools", tool_node)
129 .set_entry_point("agent")
130 .add_conditional_edges_with_path_map(
131 "agent",
132 |state: &MessageState| {
133 if let Some(last) = state.last_message() {
134 if !last.tool_calls().is_empty() {
135 return "tools".to_string();
136 }
137 }
138 END.to_string()
139 },
140 HashMap::from([
141 ("tools".to_string(), "tools".to_string()),
142 (END.to_string(), END.to_string()),
143 ]),
144 )
145 .add_edge("tools", "agent");
146
147 if !options.interrupt_before.is_empty() {
148 builder = builder.interrupt_before(options.interrupt_before);
149 }
150 if !options.interrupt_after.is_empty() {
151 builder = builder.interrupt_after(options.interrupt_after);
152 }
153
154 let mut graph = builder.compile()?;
155
156 if let Some(checkpointer) = options.checkpointer {
157 graph = graph.with_checkpointer(checkpointer);
158 }
159
160 Ok(graph)
161}