1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use serde_json::Value;
9use synaptic_core::{ChatModel, ChatRequest, Message, SynapticError, Tool, ToolDefinition};
10use synaptic_macros::traceable;
11use synaptic_middleware::{AgentMiddleware, BaseChatModelCaller, MiddlewareChain, ModelRequest};
12use synaptic_store::Store;
13use synaptic_tools::SerialToolExecutor;
14
15use crate::builder::StateGraph;
16use crate::checkpoint::Checkpointer;
17use crate::command::NodeOutput;
18use crate::compiled::CompiledGraph;
19use crate::node::Node;
20use crate::state::MessageState;
21use crate::tool_node::ToolNode;
22use crate::END;
23
24pub type PreModelHook = Arc<
30 dyn Fn(
31 &mut MessageState,
32 ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
33 + Send
34 + Sync,
35>;
36
37pub type PostModelHook = Arc<
39 dyn Fn(
40 &mut MessageState,
41 ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
42 + Send
43 + Sync,
44>;
45
/// Graph node that performs one model turn for the agents built in this module.
///
/// `create_agent` builds it with the caller's middleware, hooks and response
/// format. `create_supervisor` builds it with handoff tools, an empty
/// middleware chain and no hooks.
struct ChatModelNode {
    /// Model called through `middleware` on every turn; also called directly
    /// when re-asking for structured output.
    model: Arc<dyn ChatModel>,
    /// Tool definitions sent with every model request.
    tool_defs: Vec<ToolDefinition>,
    /// System prompt passed in each `ModelRequest`.
    system_prompt: Option<String>,
    /// Middleware chain that wraps the model call and provides the
    /// before-/after-agent hooks.
    middleware: Arc<MiddlewareChain>,
    /// While `true`, the next `process` call runs `run_before_agent` and
    /// clears the flag. NOTE(review): nothing shown here resets it, so this
    /// fires once per node instance, not once per graph invocation.
    is_first_call: AtomicBool,
    /// Optional hook run on the state before the model call.
    pre_model_hook: Option<PreModelHook>,
    /// Optional hook run on the state after the model reply is appended.
    post_model_hook: Option<PostModelHook>,
    /// Optional JSON schema. When set, a final (tool-call-free) answer is
    /// regenerated with an instruction to emit JSON matching it.
    response_format: Option<Value>,
}
62
63#[async_trait]
64impl Node<MessageState> for ChatModelNode {
65 async fn process(
66 &self,
67 mut state: MessageState,
68 ) -> Result<NodeOutput<MessageState>, SynapticError> {
69 if self.is_first_call.swap(false, Ordering::SeqCst) {
71 self.middleware
72 .run_before_agent(&mut state.messages)
73 .await?;
74 }
75
76 if let Some(ref hook) = self.pre_model_hook {
78 hook(&mut state).await?;
79 }
80
81 let request = ModelRequest {
82 messages: state.messages.clone(),
83 tools: self.tool_defs.clone(),
84 tool_choice: None,
85 system_prompt: self.system_prompt.clone(),
86 };
87
88 let base_caller = BaseChatModelCaller::new(self.model.clone());
89 let response = self.middleware.call_model(request, &base_caller).await?;
90
91 state.messages.push(response.message.clone());
92
93 if let Some(ref hook) = self.post_model_hook {
95 hook(&mut state).await?;
96 }
97
98 if response.message.tool_calls().is_empty() {
100 if let Some(ref schema) = self.response_format {
102 let instruction = format!(
103 "You MUST respond with valid JSON matching this schema:\n{}\n\n\
104 Do not include any text outside the JSON object. \
105 Do not use markdown code blocks.",
106 schema
107 );
108 let mut structured_messages = vec![Message::system(instruction)];
109 structured_messages.extend(state.messages.clone());
110
111 let structured_request = ChatRequest::new(structured_messages);
112 let structured_response = self.model.chat(structured_request).await?;
113 state.messages.pop();
115 state.messages.push(structured_response.message);
116 }
117
118 self.middleware.run_after_agent(&mut state.messages).await?;
119 }
120
121 Ok(state.into())
122 }
123}
124
#[derive(Default)]
/// Options for `create_react_agent_with_options`.
///
/// This is a subset of `AgentOptions`; every other agent option keeps its
/// default value.
pub struct ReactAgentOptions {
    /// Checkpointer attached to the compiled graph, if any.
    pub checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Forwarded to `StateGraph::interrupt_before` when non-empty
    /// (presumably node names such as `"agent"` / `"tools"`).
    pub interrupt_before: Vec<String>,
    /// Forwarded to `StateGraph::interrupt_after` when non-empty.
    pub interrupt_after: Vec<String>,
    /// System prompt sent with every model request.
    pub system_prompt: Option<String>,
}
141
142pub fn create_react_agent(
144 model: Arc<dyn ChatModel>,
145 tools: Vec<Arc<dyn Tool>>,
146) -> Result<CompiledGraph<MessageState>, SynapticError> {
147 create_react_agent_with_options(model, tools, ReactAgentOptions::default())
148}
149
150pub fn create_react_agent_with_options(
152 model: Arc<dyn ChatModel>,
153 tools: Vec<Arc<dyn Tool>>,
154 options: ReactAgentOptions,
155) -> Result<CompiledGraph<MessageState>, SynapticError> {
156 create_agent(
157 model,
158 tools,
159 AgentOptions {
160 checkpointer: options.checkpointer,
161 interrupt_before: options.interrupt_before,
162 interrupt_after: options.interrupt_after,
163 system_prompt: options.system_prompt,
164 ..Default::default()
165 },
166 )
167}
168
#[derive(Default)]
/// Full set of options accepted by `create_agent`.
pub struct AgentOptions {
    /// Checkpointer for the compiled graph. When `None` but `store` is set,
    /// `create_agent` attaches a `StoreCheckpointer` over that store instead.
    pub checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Forwarded to `StateGraph::interrupt_before` when non-empty.
    pub interrupt_before: Vec<String>,
    /// Forwarded to `StateGraph::interrupt_after` when non-empty.
    pub interrupt_after: Vec<String>,
    /// System prompt sent with every model request.
    pub system_prompt: Option<String>,
    /// Middleware shared by the model node and the tool node.
    pub middleware: Vec<Arc<dyn AgentMiddleware>>,
    /// Store handed to the tool node (`ToolNode::with_store`); also used as
    /// the checkpointer fallback described on `checkpointer`.
    pub store: Option<Arc<dyn Store>>,
    /// Agent name. NOTE(review): `create_agent` does not read this field.
    pub name: Option<String>,
    /// Hook run on the state before each model call.
    pub pre_model_hook: Option<PreModelHook>,
    /// Hook run on the state after each model reply is appended.
    pub post_model_hook: Option<PostModelHook>,
    /// JSON schema the final answer is regenerated to match, if set.
    pub response_format: Option<Value>,
}
188
189#[traceable(skip = "model,tools,options")]
191pub fn create_agent(
192 model: Arc<dyn ChatModel>,
193 tools: Vec<Arc<dyn Tool>>,
194 options: AgentOptions,
195) -> Result<CompiledGraph<MessageState>, SynapticError> {
196 let tool_defs: Vec<ToolDefinition> = tools.iter().map(|t| t.as_tool_definition()).collect();
197
198 let registry = synaptic_tools::ToolRegistry::new();
199 for tool in tools {
200 registry.register(tool)?;
201 }
202 let executor = SerialToolExecutor::new(registry);
203
204 let middleware_chain = Arc::new(MiddlewareChain::new(options.middleware));
205
206 let agent_node = ChatModelNode {
207 model,
208 tool_defs,
209 system_prompt: options.system_prompt,
210 middleware: middleware_chain.clone(),
211 is_first_call: AtomicBool::new(true),
212 pre_model_hook: options.pre_model_hook,
213 post_model_hook: options.post_model_hook,
214 response_format: options.response_format,
215 };
216
217 let mut tool_node = ToolNode::with_middleware(executor, middleware_chain);
218 if let Some(ref store) = options.store {
219 tool_node = tool_node.with_store(store.clone());
220 }
221
222 let mut builder = StateGraph::new()
223 .add_node("agent", agent_node)
224 .add_node("tools", tool_node)
225 .set_entry_point("agent")
226 .add_conditional_edges_with_path_map(
227 "agent",
228 |state: &MessageState| {
229 if let Some(last) = state.last_message() {
230 if !last.tool_calls().is_empty() {
231 return "tools".to_string();
232 }
233 }
234 END.to_string()
235 },
236 HashMap::from([
237 ("tools".to_string(), "tools".to_string()),
238 (END.to_string(), END.to_string()),
239 ]),
240 )
241 .add_edge("tools", "agent");
242
243 if !options.interrupt_before.is_empty() {
244 builder = builder.interrupt_before(options.interrupt_before);
245 }
246 if !options.interrupt_after.is_empty() {
247 builder = builder.interrupt_after(options.interrupt_after);
248 }
249
250 let mut graph = builder.compile()?;
251
252 let checkpointer: Option<Arc<dyn Checkpointer>> = match (&options.store, options.checkpointer) {
254 (_, Some(ckpt)) => Some(ckpt),
255 (Some(store), None) => Some(Arc::new(crate::StoreCheckpointer::new(store.clone()))),
256 (None, None) => None,
257 };
258
259 if let Some(checkpointer) = checkpointer {
260 graph = graph.with_checkpointer(checkpointer);
261 }
262
263 Ok(graph)
264}
265
266struct HandoffTool {
271 target_agent: String,
272 tool_description: String,
273}
274
275#[async_trait]
276impl Tool for HandoffTool {
277 fn name(&self) -> &'static str {
278 Box::leak(format!("transfer_to_{}", self.target_agent).into_boxed_str())
279 }
280
281 fn description(&self) -> &'static str {
282 Box::leak(self.tool_description.clone().into_boxed_str())
283 }
284
285 async fn call(&self, _args: Value) -> Result<Value, SynapticError> {
286 Ok(Value::String(format!(
287 "Transferring to agent '{}'.",
288 self.target_agent
289 )))
290 }
291}
292
293pub fn create_handoff_tool(agent_name: &str, description: &str) -> Arc<dyn Tool> {
295 Arc::new(HandoffTool {
296 target_agent: agent_name.to_string(),
297 tool_description: description.to_string(),
298 })
299}
300
#[derive(Default)]
/// Options for `create_supervisor`.
pub struct SupervisorOptions {
    /// Checkpointer attached to the compiled supervisor graph, if any.
    pub checkpointer: Option<Arc<dyn Checkpointer>>,
    /// NOTE(review): `create_supervisor` does not currently read this field.
    pub store: Option<Arc<dyn Store>>,
    /// Overrides the default supervisor prompt, which lists the agents and
    /// tells the model to use the transfer tools.
    pub system_prompt: Option<String>,
}
312
/// Graph node that runs a whole compiled sub-agent graph as a single step of
/// the supervisor graph.
struct SubAgentNode {
    /// The sub-agent's compiled graph; invoked with the supervisor's state.
    graph: CompiledGraph<MessageState>,
}
317
318#[async_trait]
319impl Node<MessageState> for SubAgentNode {
320 async fn process(
321 &self,
322 state: MessageState,
323 ) -> Result<NodeOutput<MessageState>, SynapticError> {
324 let result = self.graph.invoke(state).await?;
325 Ok(result.into_state().into())
326 }
327}
328
329#[traceable(skip = "model,agents,options")]
331pub fn create_supervisor(
332 model: Arc<dyn ChatModel>,
333 agents: Vec<(String, CompiledGraph<MessageState>)>,
334 options: SupervisorOptions,
335) -> Result<CompiledGraph<MessageState>, SynapticError> {
336 let agent_names: Vec<String> = agents.iter().map(|(name, _)| name.clone()).collect();
337
338 let handoff_tools: Vec<Arc<dyn Tool>> = agent_names
340 .iter()
341 .map(|name| {
342 create_handoff_tool(
343 name,
344 &format!("Transfer the conversation to the '{name}' agent."),
345 )
346 })
347 .collect();
348
349 let handoff_tool_defs: Vec<ToolDefinition> = handoff_tools
350 .iter()
351 .map(|t| ToolDefinition {
352 name: t.name().to_string(),
353 description: t.description().to_string(),
354 parameters: serde_json::json!({}),
355 extras: None,
356 })
357 .collect();
358
359 let default_prompt = format!(
360 "You are a supervisor managing these agents: {}. \
361 Use the transfer tools to delegate tasks to the appropriate agent. \
362 When the task is complete, respond directly to the user.",
363 agent_names.join(", ")
364 );
365 let system_prompt = options.system_prompt.unwrap_or(default_prompt);
366
367 let supervisor_node = ChatModelNode {
368 model,
369 tool_defs: handoff_tool_defs.clone(),
370 system_prompt: Some(system_prompt),
371 middleware: Arc::new(MiddlewareChain::new(vec![])),
372 is_first_call: AtomicBool::new(false),
373 pre_model_hook: None,
374 post_model_hook: None,
375 response_format: None,
376 };
377
378 let mut builder = StateGraph::new()
379 .add_node("supervisor", supervisor_node)
380 .set_entry_point("supervisor");
381
382 for (name, graph) in agents {
383 builder = builder
384 .add_node(&name, SubAgentNode { graph })
385 .add_edge(&name, "supervisor");
386 }
387
388 let agent_names_for_router = agent_names.clone();
389 builder = builder.add_conditional_edges("supervisor", move |state: &MessageState| {
390 if let Some(last) = state.last_message() {
391 for tc in last.tool_calls() {
392 for agent_name in &agent_names_for_router {
393 if tc.name == format!("transfer_to_{agent_name}") {
394 return agent_name.clone();
395 }
396 }
397 }
398 }
399 END.to_string()
400 });
401
402 let mut graph = builder.compile()?;
403
404 if let Some(checkpointer) = options.checkpointer {
405 graph = graph.with_checkpointer(checkpointer);
406 }
407
408 Ok(graph)
409}
410
#[derive(Default)]
/// Options for `create_swarm`.
pub struct SwarmOptions {
    /// Checkpointer attached to the compiled swarm graph, if any.
    pub checkpointer: Option<Arc<dyn Checkpointer>>,
    /// NOTE(review): `create_swarm` does not currently read this field.
    pub store: Option<Arc<dyn Store>>,
}
421
/// Graph node for one swarm agent.
///
/// Unlike `ChatModelNode`, it calls the model directly, with no middleware
/// or hooks.
struct SwarmAgentNode {
    /// Model answering for this agent.
    model: Arc<dyn ChatModel>,
    /// The agent's own tools plus handoff tools for every other agent.
    tool_defs: Vec<ToolDefinition>,
    /// Prepended as a system message on every call when set.
    system_prompt: Option<String>,
}
428
429#[async_trait]
430impl Node<MessageState> for SwarmAgentNode {
431 async fn process(
432 &self,
433 mut state: MessageState,
434 ) -> Result<NodeOutput<MessageState>, SynapticError> {
435 let mut messages = Vec::new();
436 if let Some(ref prompt) = self.system_prompt {
437 messages.push(Message::system(prompt));
438 }
439 messages.extend(state.messages.clone());
440
441 let request = ChatRequest::new(messages).with_tools(self.tool_defs.clone());
442 let response = self.model.chat(request).await?;
443 state.messages.push(response.message);
444 Ok(state.into())
445 }
446}
447
/// Shared tool-execution node for a swarm.
struct SwarmToolNode {
    /// Executes the regular tools of every swarm agent (one shared registry).
    executor: SerialToolExecutor,
    /// `transfer_to_*` names. Calls with these names are not executed; they
    /// are answered with a fixed acknowledgement instead.
    handoff_tool_names: Vec<String>,
}
453
454#[async_trait]
455impl Node<MessageState> for SwarmToolNode {
456 async fn process(
457 &self,
458 mut state: MessageState,
459 ) -> Result<NodeOutput<MessageState>, SynapticError> {
460 let last = state
461 .last_message()
462 .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
463
464 let tool_calls = last.tool_calls().to_vec();
465 for call in &tool_calls {
466 if self.handoff_tool_names.contains(&call.name) {
467 state.messages.push(Message::tool(
468 "Transferring to agent.".to_string(),
469 &call.id,
470 ));
471 } else {
472 let result = self
473 .executor
474 .execute(&call.name, call.arguments.clone())
475 .await?;
476 state
477 .messages
478 .push(Message::tool(result.to_string(), &call.id));
479 }
480 }
481
482 Ok(state.into())
483 }
484}
485
/// Definition of one peer agent passed to `create_swarm`.
pub struct SwarmAgent {
    /// Node name of the agent; also the target of its `transfer_to_<name>`
    /// handoff tool.
    pub name: String,
    /// Model answering for this agent.
    pub model: Arc<dyn ChatModel>,
    /// The agent's own tools. They are registered in the swarm's shared tool
    /// registry.
    pub tools: Vec<Arc<dyn Tool>>,
    /// Optional system prompt prepended to every request of this agent.
    pub system_prompt: Option<String>,
}
493
494#[traceable(skip = "agents,options")]
496pub fn create_swarm(
497 agents: Vec<SwarmAgent>,
498 options: SwarmOptions,
499) -> Result<CompiledGraph<MessageState>, SynapticError> {
500 if agents.is_empty() {
501 return Err(SynapticError::Graph(
502 "swarm requires at least one agent".to_string(),
503 ));
504 }
505
506 let agent_names: Vec<String> = agents.iter().map(|a| a.name.clone()).collect();
507 let entry_agent = agent_names[0].clone();
508
509 let all_handoff_tools: HashMap<String, Arc<dyn Tool>> = agent_names
510 .iter()
511 .map(|name| {
512 (
513 name.clone(),
514 create_handoff_tool(
515 name,
516 &format!("Transfer the conversation to the '{name}' agent."),
517 ),
518 )
519 })
520 .collect();
521
522 let handoff_tool_names: Vec<String> = all_handoff_tools
523 .values()
524 .map(|t| t.name().to_string())
525 .collect();
526
527 let mut builder = StateGraph::new();
528
529 let global_registry = synaptic_tools::ToolRegistry::new();
530
531 for agent in agents {
532 let SwarmAgent {
533 name,
534 model,
535 tools,
536 system_prompt,
537 } = agent;
538
539 let mut tool_defs: Vec<ToolDefinition> = tools
540 .iter()
541 .map(|t| ToolDefinition {
542 name: t.name().to_string(),
543 description: t.description().to_string(),
544 parameters: serde_json::json!({}),
545 extras: None,
546 })
547 .collect();
548
549 for tool in &tools {
550 let _ = global_registry.register(tool.clone());
551 }
552
553 for other_name in &agent_names {
554 if other_name != &name {
555 if let Some(ht) = all_handoff_tools.get(other_name) {
556 tool_defs.push(ToolDefinition {
557 name: ht.name().to_string(),
558 description: ht.description().to_string(),
559 parameters: serde_json::json!({}),
560 extras: None,
561 });
562 }
563 }
564 }
565
566 let agent_node = SwarmAgentNode {
567 model,
568 tool_defs,
569 system_prompt,
570 };
571
572 builder = builder.add_node(&name, agent_node);
573 }
574
575 let executor = SerialToolExecutor::new(global_registry);
576 let swarm_tool_node = SwarmToolNode {
577 executor,
578 handoff_tool_names: handoff_tool_names.clone(),
579 };
580 builder = builder.add_node("tools", swarm_tool_node);
581
582 builder = builder.set_entry_point(&entry_agent);
583
584 for agent_name in &agent_names {
585 builder = builder.add_conditional_edges(agent_name, |state: &MessageState| {
586 if let Some(last) = state.last_message() {
587 if !last.tool_calls().is_empty() {
588 return "tools".to_string();
589 }
590 }
591 END.to_string()
592 });
593 }
594
595 let all_agent_names = agent_names.clone();
596 builder = builder.add_conditional_edges("tools", move |state: &MessageState| {
597 for msg in state.messages.iter().rev() {
598 if msg.is_ai() && !msg.tool_calls().is_empty() {
599 for tc in msg.tool_calls() {
600 for agent_name in &all_agent_names {
601 if tc.name == format!("transfer_to_{agent_name}") {
602 return agent_name.clone();
603 }
604 }
605 }
606 return all_agent_names[0].clone();
607 }
608 }
609 all_agent_names[0].clone()
610 });
611
612 let mut graph = builder.compile()?;
613
614 if let Some(checkpointer) = options.checkpointer {
615 graph = graph.with_checkpointer(checkpointer);
616 }
617
618 Ok(graph)
619}