Skip to main content

synaptic_graph/
prebuilt.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::{ChatModel, ChatRequest, Message, SynapseError, Tool, ToolDefinition};
6use synaptic_tools::{SerialToolExecutor, ToolRegistry};
7
8use crate::builder::StateGraph;
9use crate::checkpoint::Checkpointer;
10use crate::compiled::CompiledGraph;
11use crate::node::Node;
12use crate::state::MessageState;
13use crate::tool_node::ToolNode;
14use crate::END;
15
16/// Prebuilt node that calls a ChatModel with the current messages.
17struct ChatModelNode {
18    model: Arc<dyn ChatModel>,
19    tool_defs: Vec<ToolDefinition>,
20    system_prompt: Option<String>,
21}
22
23#[async_trait]
24impl Node<MessageState> for ChatModelNode {
25    async fn process(&self, mut state: MessageState) -> Result<MessageState, SynapseError> {
26        let mut messages = Vec::new();
27
28        // If a system prompt is configured, prepend it
29        if let Some(ref prompt) = self.system_prompt {
30            messages.push(Message::system(prompt));
31        }
32
33        messages.extend(state.messages.clone());
34
35        let request = ChatRequest::new(messages).with_tools(self.tool_defs.clone());
36        let response = self.model.chat(request).await?;
37        state.messages.push(response.message);
38        Ok(state)
39    }
40}
41
42/// Options for creating a ReAct agent with `create_react_agent_with_options`.
43///
44/// All fields are optional and have sensible defaults. Use `Default::default()`
45/// for the simplest configuration, which behaves identically to `create_react_agent`.
46#[derive(Default)]
47pub struct ReactAgentOptions {
48    /// Optional checkpointer for state persistence across invocations.
49    pub checkpointer: Option<Arc<dyn Checkpointer>>,
50    /// Node names that should interrupt BEFORE execution (human-in-the-loop).
51    pub interrupt_before: Vec<String>,
52    /// Node names that should interrupt AFTER execution (human-in-the-loop).
53    pub interrupt_after: Vec<String>,
54    /// Optional system prompt to prepend to messages before calling the model.
55    pub system_prompt: Option<String>,
56}
57
58/// Create a prebuilt ReAct agent graph.
59///
60/// The graph has two nodes:
61/// - "agent": calls the ChatModel with messages and tool definitions
62/// - "tools": executes any tool calls from the agent's response
63///
64/// Routing: if the agent returns tool calls, route to "tools"; otherwise route to END.
65pub fn create_react_agent(
66    model: Arc<dyn ChatModel>,
67    tools: Vec<Arc<dyn Tool>>,
68) -> Result<CompiledGraph<MessageState>, SynapseError> {
69    create_react_agent_with_options(model, tools, ReactAgentOptions::default())
70}
71
72/// Create a prebuilt ReAct agent graph with additional configuration options.
73///
74/// This is the extended version of [`create_react_agent`] that accepts
75/// a [`ReactAgentOptions`] struct for configuring checkpointing,
76/// interrupts, and system prompts.
77///
78/// The graph has two nodes:
79/// - "agent": calls the ChatModel with messages and tool definitions
80/// - "tools": executes any tool calls from the agent's response
81///
82/// Routing: if the agent returns tool calls, route to "tools"; otherwise route to END.
83///
84/// # Example
85///
86/// ```ignore
87/// use std::sync::Arc;
88/// use synaptic_graph::{create_react_agent_with_options, ReactAgentOptions, MemorySaver};
89///
90/// let options = ReactAgentOptions {
91///     checkpointer: Some(Arc::new(MemorySaver::new())),
92///     system_prompt: Some("You are a helpful assistant.".to_string()),
93///     interrupt_before: vec!["tools".to_string()],
94///     ..Default::default()
95/// };
96///
97/// let graph = create_react_agent_with_options(model, tools, options)?;
98/// ```
99pub fn create_react_agent_with_options(
100    model: Arc<dyn ChatModel>,
101    tools: Vec<Arc<dyn Tool>>,
102    options: ReactAgentOptions,
103) -> Result<CompiledGraph<MessageState>, SynapseError> {
104    let tool_defs: Vec<ToolDefinition> = tools
105        .iter()
106        .map(|t| ToolDefinition {
107            name: t.name().to_string(),
108            description: t.description().to_string(),
109            parameters: serde_json::json!({}),
110        })
111        .collect();
112
113    let registry = ToolRegistry::new();
114    for tool in tools {
115        registry.register(tool)?;
116    }
117    let executor = SerialToolExecutor::new(registry);
118
119    let agent_node = ChatModelNode {
120        model,
121        tool_defs,
122        system_prompt: options.system_prompt,
123    };
124    let tool_node = ToolNode::new(executor);
125
126    let mut builder = StateGraph::new()
127        .add_node("agent", agent_node)
128        .add_node("tools", tool_node)
129        .set_entry_point("agent")
130        .add_conditional_edges_with_path_map(
131            "agent",
132            |state: &MessageState| {
133                if let Some(last) = state.last_message() {
134                    if !last.tool_calls().is_empty() {
135                        return "tools".to_string();
136                    }
137                }
138                END.to_string()
139            },
140            HashMap::from([
141                ("tools".to_string(), "tools".to_string()),
142                (END.to_string(), END.to_string()),
143            ]),
144        )
145        .add_edge("tools", "agent");
146
147    if !options.interrupt_before.is_empty() {
148        builder = builder.interrupt_before(options.interrupt_before);
149    }
150    if !options.interrupt_after.is_empty() {
151        builder = builder.interrupt_after(options.interrupt_after);
152    }
153
154    let mut graph = builder.compile()?;
155
156    if let Some(checkpointer) = options.checkpointer {
157        graph = graph.with_checkpointer(checkpointer);
158    }
159
160    Ok(graph)
161}