use std::sync::Arc;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use crate::conversation::ConversationManager;
use crate::hooks::HookRegistry;
use crate::models::Model;
use crate::tools::structured_output::StructuredOutputContext;
use crate::tools::{AgentTool, ToolRegistry};
use crate::types::content::Messages;
use crate::types::errors::{Result, StrandsError};
use super::{Agent, AgentState};
/// Fluent builder for [`Agent`].
///
/// Create one with [`AgentBuilder::new`] (or `Default`), chain the setter
/// methods, and finish with [`AgentBuilder::build`]. A model is the only
/// required setting: `build` returns a configuration error when none was set.
///
/// Every setter consumes the builder and returns it, so dropping a setter's
/// return value discards the whole configuration; `#[must_use]` turns that
/// mistake into a compiler warning.
#[must_use = "an AgentBuilder does nothing until `.build()` is called"]
pub struct AgentBuilder {
    /// Model used by the agent; `build` fails while this is `None`.
    model: Option<Arc<dyn Model>>,
    /// Initial conversation history handed to the agent.
    messages: Messages,
    /// Optional system prompt.
    system_prompt: Option<String>,
    /// Tools available to the agent.
    tool_registry: ToolRegistry,
    /// Optional human-readable agent name.
    agent_name: Option<String>,
    /// Agent identifier (`"default"` unless overridden).
    agent_id: String,
    /// Optional free-form description of the agent.
    description: Option<String>,
    /// Initial agent state.
    state: AgentState,
    /// Hook registry handed to the agent.
    hooks: HookRegistry,
    /// Custom conversation manager; `build` substitutes
    /// `SlidingWindowConversationManager::default()` when this is `None`.
    conversation_manager: Option<Box<dyn ConversationManager>>,
    /// Flag forwarded to the agent; presumably controls whether direct tool
    /// invocations are recorded in the conversation — TODO confirm.
    record_direct_tool_call: bool,
    /// Key/value attributes forwarded to the agent (presumably attached to
    /// emitted traces — verify against the tracing code).
    trace_attributes: std::collections::HashMap<String, String>,
    /// Optional tool-call limit forwarded to the agent (`None` = not set here).
    max_tool_calls: Option<usize>,
    /// Optional structured-output configuration forwarded to the agent.
    structured_output_context: Option<StructuredOutputContext>,
}
impl Default for AgentBuilder {
fn default() -> Self { Self::new() }
}
impl AgentBuilder {
    /// Creates a builder with nothing configured.
    ///
    /// Starting values:
    /// - no model, no messages, no system prompt;
    /// - a fresh [`ToolRegistry`];
    /// - no name or description, and agent id `"default"`;
    /// - fresh [`AgentState`] and [`HookRegistry`];
    /// - no conversation manager (see [`AgentBuilder::build`] for the fallback);
    /// - `record_direct_tool_call` off;
    /// - no trace attributes, no tool-call limit, and no structured output.
    pub fn new() -> Self {
        Self {
            model: None,
            messages: Vec::new(),
            system_prompt: None,
            tool_registry: ToolRegistry::new(),
            agent_name: None,
            agent_id: "default".to_string(),
            description: None,
            state: AgentState::new(),
            hooks: HookRegistry::new(),
            conversation_manager: None,
            record_direct_tool_call: false,
            trace_attributes: std::collections::HashMap::new(),
            max_tool_calls: None,
            structured_output_context: None,
        }
    }

    /// Sets the model, wrapping it in an [`Arc`].
    ///
    /// Replaces any model set earlier with this method or
    /// [`AgentBuilder::model_arc`].
    pub fn model(mut self, model: impl Model + 'static) -> Self {
        self.model = Some(Arc::new(model));
        self
    }

    /// Sets a model that is already behind an [`Arc`], e.g. one shared
    /// between several agents.
    ///
    /// Replaces any model set earlier with this method or
    /// [`AgentBuilder::model`].
    pub fn model_arc(mut self, model: Arc<dyn Model>) -> Self {
        self.model = Some(model);
        self
    }

    /// Sets the initial conversation history, replacing any set earlier.
    pub fn messages(mut self, messages: Messages) -> Self {
        self.messages = messages;
        self
    }

    /// Sets the system prompt, replacing any set earlier.
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Registers one tool in the builder's tool registry.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `ToolRegistry::register_typed`; the
    /// builder is consumed and dropped in that case.
    pub fn tool(mut self, tool: impl AgentTool + 'static) -> Result<Self> {
        self.tool_registry.register_typed(tool)?;
        Ok(self)
    }

    /// Registers every tool yielded by `tools`, in iteration order.
    ///
    /// # Errors
    ///
    /// Stops at the first tool whose registration fails and returns that
    /// error. The builder, including any tools registered before the failure,
    /// is dropped.
    pub fn tools(mut self, tools: impl IntoIterator<Item = impl AgentTool + 'static>) -> Result<Self> {
        for tool in tools {
            self.tool_registry.register_typed(tool)?;
        }
        Ok(self)
    }

    /// Replaces the tool registry wholesale.
    ///
    /// Tools registered earlier via [`AgentBuilder::tool`] or
    /// [`AgentBuilder::tools`] are discarded. Tools registered afterwards are
    /// added to `registry`.
    pub fn tool_registry(mut self, registry: ToolRegistry) -> Self {
        self.tool_registry = registry;
        self
    }

    /// Sets the agent's human-readable name.
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.agent_name = Some(name.into());
        self
    }

    /// Sets the agent identifier (defaults to `"default"`).
    pub fn agent_id(mut self, id: impl Into<String>) -> Self {
        self.agent_id = id.into();
        self
    }

    /// Sets a free-form description of the agent.
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Replaces the initial agent state.
    pub fn state(mut self, state: AgentState) -> Self {
        self.state = state;
        self
    }

    /// Replaces the hook registry handed to the agent.
    pub fn hooks(mut self, hooks: HookRegistry) -> Self {
        self.hooks = hooks;
        self
    }

    /// Sets a custom conversation manager.
    ///
    /// If this is never called, [`AgentBuilder::build`] uses
    /// `SlidingWindowConversationManager::default()`.
    pub fn conversation_manager(mut self, manager: impl ConversationManager + 'static) -> Self {
        self.conversation_manager = Some(Box::new(manager));
        self
    }

    /// Sets the `record_direct_tool_call` flag forwarded to the agent
    /// (default `false`).
    pub fn record_direct_tool_call(mut self, record: bool) -> Self {
        self.record_direct_tool_call = record;
        self
    }

    /// Adds one trace attribute, overwriting any existing value for `key`.
    pub fn trace_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.trace_attributes.insert(key.into(), value.into());
        self
    }

    /// Replaces *all* trace attributes with `attrs`.
    ///
    /// Attributes added earlier via [`AgentBuilder::trace_attribute`] are
    /// discarded, not merged. Call `trace_attribute` after this method to add
    /// more entries on top of `attrs`.
    pub fn trace_attributes(mut self, attrs: std::collections::HashMap<String, String>) -> Self {
        self.trace_attributes = attrs;
        self
    }

    /// Sets the tool-call limit forwarded to the agent (unset by default).
    pub fn max_tool_calls(mut self, max: usize) -> Self {
        self.max_tool_calls = Some(max);
        self
    }

    /// Configures structured output for type `T`.
    ///
    /// Uses `StructuredOutputContext::with_type::<T>()` and replaces any
    /// context set earlier.
    pub fn structured_output_model<T: JsonSchema + DeserializeOwned + 'static>(mut self) -> Self {
        let context = StructuredOutputContext::with_type::<T>();
        self.structured_output_context = Some(context);
        self
    }

    /// Sets an explicit structured-output context, replacing any set earlier.
    pub fn structured_output_context(mut self, context: StructuredOutputContext) -> Self {
        self.structured_output_context = Some(context);
        self
    }

    /// Consumes the builder and creates the [`Agent`].
    ///
    /// If no conversation manager was set,
    /// `SlidingWindowConversationManager::default()` is used. The agent always
    /// starts with a fresh `InterruptState`.
    ///
    /// # Errors
    ///
    /// Returns [`StrandsError::ConfigurationError`] if no model was set.
    pub fn build(self) -> Result<Agent> {
        // Destructure exhaustively (deliberately no `..`): adding a field to
        // the builder without wiring it into `Agent` becomes a compile error
        // instead of a setting that is silently ignored.
        let Self {
            model,
            messages,
            system_prompt,
            tool_registry,
            agent_name,
            agent_id,
            description,
            state,
            hooks,
            conversation_manager,
            record_direct_tool_call,
            trace_attributes,
            max_tool_calls,
            structured_output_context,
        } = self;

        let model = model.ok_or_else(|| StrandsError::ConfigurationError {
            message: "Model is required".to_string(),
        })?;

        let conversation_manager = conversation_manager.unwrap_or_else(|| {
            Box::new(crate::conversation::SlidingWindowConversationManager::default())
        });

        Ok(Agent {
            model,
            messages,
            system_prompt,
            tool_registry,
            agent_name,
            agent_id,
            description,
            state,
            hooks,
            conversation_manager,
            interrupt_state: crate::types::interrupt::InterruptState::new(),
            record_direct_tool_call,
            trace_attributes,
            max_tool_calls,
            structured_output_context,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::BedrockModel;

    #[test]
    fn test_builder_basic() {
        let agent = Agent::builder()
            .model(BedrockModel::default())
            .system_prompt("Test prompt")
            .name("TestAgent")
            .build()
            .expect("a builder with a model should build");
        assert_eq!(agent.name(), Some(&"TestAgent".to_string()));
        assert_eq!(agent.system_prompt(), Some("Test prompt"));
    }

    #[test]
    fn test_builder_no_model() {
        let result = Agent::builder().build();
        // Pin the specific failure instead of accepting any error.
        assert!(matches!(
            &result,
            Err(StrandsError::ConfigurationError { message, .. }) if message == "Model is required"
        ));
    }

    // The tests below inspect the builder's private fields directly, which a
    // child module is allowed to do; they need no model instance.

    #[test]
    fn test_builder_defaults() {
        let builder = AgentBuilder::default();
        assert!(builder.model.is_none());
        assert!(builder.messages.is_empty());
        assert!(builder.system_prompt.is_none());
        assert!(builder.agent_name.is_none());
        assert_eq!(builder.agent_id, "default");
        assert!(builder.description.is_none());
        assert!(builder.conversation_manager.is_none());
        assert!(!builder.record_direct_tool_call);
        assert!(builder.trace_attributes.is_empty());
        assert!(builder.max_tool_calls.is_none());
        assert!(builder.structured_output_context.is_none());
    }

    #[test]
    fn test_setters_store_values() {
        let builder = AgentBuilder::new()
            .name("n")
            .agent_id("id-1")
            .description("d")
            .max_tool_calls(3)
            .record_direct_tool_call(true);
        assert_eq!(builder.agent_name.as_deref(), Some("n"));
        assert_eq!(builder.agent_id, "id-1");
        assert_eq!(builder.description.as_deref(), Some("d"));
        assert_eq!(builder.max_tool_calls, Some(3));
        assert!(builder.record_direct_tool_call);
    }

    #[test]
    fn test_trace_attributes_replace_earlier_entries() {
        let mut attrs = std::collections::HashMap::new();
        attrs.insert("b".to_string(), "2".to_string());
        let builder = AgentBuilder::new()
            .trace_attribute("a", "1")
            .trace_attributes(attrs)
            .trace_attribute("c", "3");
        assert_eq!(builder.trace_attributes.len(), 2);
        assert!(!builder.trace_attributes.contains_key("a"));
        assert_eq!(builder.trace_attributes.get("b").map(String::as_str), Some("2"));
        assert_eq!(builder.trace_attributes.get("c").map(String::as_str), Some("3"));
    }
}