strands_agents/agent/
builder.rs1use std::sync::Arc;
4
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7
8use crate::conversation::ConversationManager;
9use crate::hooks::HookRegistry;
10use crate::models::Model;
11use crate::tools::structured_output::StructuredOutputContext;
12use crate::tools::{AgentTool, ToolRegistry};
13use crate::types::content::Messages;
14use crate::types::errors::{Result, StrandsError};
15
16use super::{Agent, AgentState};
17
18pub struct AgentBuilder {
20 model: Option<Arc<dyn Model>>,
21 messages: Messages,
22 system_prompt: Option<String>,
23 tool_registry: ToolRegistry,
24 agent_name: Option<String>,
25 agent_id: String,
26 description: Option<String>,
27 state: AgentState,
28 hooks: HookRegistry,
29 conversation_manager: Option<Box<dyn ConversationManager>>,
30 record_direct_tool_call: bool,
31 trace_attributes: std::collections::HashMap<String, String>,
32 max_tool_calls: Option<usize>,
33 structured_output_context: Option<StructuredOutputContext>,
34}
35
36impl Default for AgentBuilder {
37 fn default() -> Self { Self::new() }
38}
39
40impl AgentBuilder {
41 pub fn new() -> Self {
42 Self {
43 model: None,
44 messages: Vec::new(),
45 system_prompt: None,
46 tool_registry: ToolRegistry::new(),
47 agent_name: None,
48 agent_id: "default".to_string(),
49 description: None,
50 state: AgentState::new(),
51 hooks: HookRegistry::new(),
52 conversation_manager: None,
53 record_direct_tool_call: false,
54 trace_attributes: std::collections::HashMap::new(),
55 max_tool_calls: None,
56 structured_output_context: None,
57 }
58 }
59
60 pub fn model(mut self, model: impl Model + 'static) -> Self {
62 self.model = Some(Arc::new(model));
63 self
64 }
65
66 pub fn model_arc(mut self, model: Arc<dyn Model>) -> Self {
68 self.model = Some(model);
69 self
70 }
71
72 pub fn messages(mut self, messages: Messages) -> Self {
74 self.messages = messages;
75 self
76 }
77
78 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
80 self.system_prompt = Some(prompt.into());
81 self
82 }
83
84 pub fn tool(mut self, tool: impl AgentTool + 'static) -> Result<Self> {
86 self.tool_registry.register_typed(tool)?;
87 Ok(self)
88 }
89
90 pub fn tools(mut self, tools: impl IntoIterator<Item = impl AgentTool + 'static>) -> Result<Self> {
92 for tool in tools {
93 self.tool_registry.register_typed(tool)?;
94 }
95 Ok(self)
96 }
97
98 pub fn tool_registry(mut self, registry: ToolRegistry) -> Self {
100 self.tool_registry = registry;
101 self
102 }
103
104 pub fn name(mut self, name: impl Into<String>) -> Self {
106 self.agent_name = Some(name.into());
107 self
108 }
109
110 pub fn agent_id(mut self, id: impl Into<String>) -> Self {
112 self.agent_id = id.into();
113 self
114 }
115
116 pub fn description(mut self, description: impl Into<String>) -> Self {
118 self.description = Some(description.into());
119 self
120 }
121
122 pub fn state(mut self, state: AgentState) -> Self {
124 self.state = state;
125 self
126 }
127
128 pub fn hooks(mut self, hooks: HookRegistry) -> Self {
130 self.hooks = hooks;
131 self
132 }
133
134 pub fn conversation_manager(mut self, manager: impl ConversationManager + 'static) -> Self {
136 self.conversation_manager = Some(Box::new(manager));
137 self
138 }
139
140 pub fn record_direct_tool_call(mut self, record: bool) -> Self {
142 self.record_direct_tool_call = record;
143 self
144 }
145
146 pub fn trace_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
148 self.trace_attributes.insert(key.into(), value.into());
149 self
150 }
151
152 pub fn trace_attributes(mut self, attrs: std::collections::HashMap<String, String>) -> Self {
154 self.trace_attributes = attrs;
155 self
156 }
157
158 pub fn max_tool_calls(mut self, max: usize) -> Self {
160 self.max_tool_calls = Some(max);
161 self
162 }
163
164 pub fn structured_output_model<T: JsonSchema + DeserializeOwned + 'static>(mut self) -> Self {
188 let context = StructuredOutputContext::with_type::<T>();
189 self.structured_output_context = Some(context);
190 self
191 }
192
193 pub fn structured_output_context(mut self, context: StructuredOutputContext) -> Self {
195 self.structured_output_context = Some(context);
196 self
197 }
198
199 pub fn build(self) -> Result<Agent> {
201 let model = self.model.ok_or_else(|| StrandsError::ConfigurationError {
202 message: "Model is required".to_string(),
203 })?;
204
205 Ok(Agent {
206 model,
207 messages: self.messages,
208 system_prompt: self.system_prompt,
209 tool_registry: self.tool_registry,
210 agent_name: self.agent_name,
211 agent_id: self.agent_id,
212 description: self.description,
213 state: self.state,
214 hooks: self.hooks,
215 conversation_manager: self.conversation_manager.unwrap_or_else(|| {
216 Box::new(crate::conversation::SlidingWindowConversationManager::default())
217 }),
218 interrupt_state: crate::types::interrupt::InterruptState::new(),
219 record_direct_tool_call: self.record_direct_tool_call,
220 trace_attributes: self.trace_attributes,
221 max_tool_calls: self.max_tool_calls,
222 structured_output_context: self.structured_output_context,
223 })
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::BedrockModel;

    /// A builder with a model produces an agent exposing the configured name and prompt.
    #[test]
    fn test_builder_basic() {
        let agent = Agent::builder()
            .model(BedrockModel::default())
            .system_prompt("Test prompt")
            .name("TestAgent")
            .build()
            .expect("builder with a model should build");

        assert_eq!(agent.name(), Some(&"TestAgent".to_string()));
        assert_eq!(agent.system_prompt(), Some("Test prompt"));
    }

    /// Building without a model must fail with the specific configuration error, not
    /// merely some error.
    #[test]
    fn test_builder_no_model() {
        let result = Agent::builder().build();
        assert!(
            matches!(
                &result,
                Err(StrandsError::ConfigurationError { message, .. }) if message == "Model is required"
            ),
            "expected ConfigurationError(\"Model is required\")"
        );
    }

    /// `new()` starts with no model and the documented defaults.
    #[test]
    fn test_builder_defaults() {
        let builder = AgentBuilder::new();
        assert!(builder.model.is_none());
        assert!(builder.messages.is_empty());
        assert!(builder.system_prompt.is_none());
        assert_eq!(builder.agent_id, "default");
        assert!(!builder.record_direct_tool_call);
        assert!(builder.trace_attributes.is_empty());
        assert!(builder.max_tool_calls.is_none());
        assert!(builder.conversation_manager.is_none());
        assert!(builder.structured_output_context.is_none());
    }

    /// Scalar setters store their values, and a later call overrides an earlier one.
    #[test]
    fn test_builder_setters_overwrite() {
        let builder = AgentBuilder::new()
            .agent_id("first")
            .agent_id("second")
            .description("desc")
            .max_tool_calls(3)
            .record_direct_tool_call(true);
        assert_eq!(builder.agent_id, "second");
        assert_eq!(builder.description.as_deref(), Some("desc"));
        assert_eq!(builder.max_tool_calls, Some(3));
        assert!(builder.record_direct_tool_call);
    }

    /// `trace_attribute` accumulates pairs, while `trace_attributes` replaces the map.
    #[test]
    fn test_builder_trace_attributes() {
        let builder = AgentBuilder::new()
            .trace_attribute("a", "1")
            .trace_attribute("b", "2");
        assert_eq!(builder.trace_attributes.len(), 2);
        assert_eq!(builder.trace_attributes.get("a").map(String::as_str), Some("1"));

        let mut replacement = std::collections::HashMap::new();
        replacement.insert("c".to_string(), "3".to_string());
        let builder = builder.trace_attributes(replacement);
        assert_eq!(builder.trace_attributes.len(), 1);
        assert!(builder.trace_attributes.contains_key("c"));
        assert!(!builder.trace_attributes.contains_key("a"));
    }
}