use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use crate::conversation::{BoxedConversationManager, SlidingWindowConversationManager};
use crate::permission::{GrantStore, ToolAuthorizationPolicy, ToolCallAuthorizer};
use crate::provider::ModelProvider;
use crate::tool::{box_tool, DynTool, Tool};
use super::context::{ContextConfig, ContextSource};
use super::types::{DEFAULT_MAX_CONCURRENT_TOOLS, DEFAULT_PERMISSION_TIMEOUT};
use super::Agent;
#[cfg(feature = "session")]
use crate::session::SessionStore;
#[cfg(feature = "bedrock")]
use crate::model::BedrockModel;
#[cfg(feature = "bedrock")]
use crate::provider::BedrockProvider;
#[cfg(feature = "anthropic")]
use crate::model::AnthropicModel;
#[cfg(feature = "anthropic")]
use crate::provider::AnthropicProvider;
/// Deferred, one-shot constructor for the model provider.
///
/// Calling the boxed closure yields a pinned, `Send` future that resolves to a
/// shared `ModelProvider` or to a crate error. `AgentBuilder::build` is the
/// place in this file that invokes it, so provider construction is postponed
/// until the agent is actually built.
type ProviderFactory = Box<
dyn FnOnce()
-> Pin<Box<dyn Future<Output = crate::error::Result<Arc<dyn ModelProvider>>> + Send>>
+ Send,
>;
/// Fluent, consuming builder for [`Agent`].
///
/// Each configuration method takes the builder by value and returns it, so
/// calls can be chained; `build` consumes the builder and produces the agent.
/// Fields marked `pub(super)` are visible to the parent module.
pub struct AgentBuilder {
/// Deferred provider constructor; `None` until one of the provider methods
/// (`bedrock`, `anthropic`, `anthropic_from_env`, `provider`) is called.
/// `build` fails with a configuration error while this is `None`.
provider_factory: Option<ProviderFactory>,
/// Tools registered so far, in insertion order.
tools: Vec<Box<dyn DynTool>>,
/// Optional system prompt handed to the agent unchanged.
system_prompt: Option<String>,
/// Tool concurrency limit handed to the agent unchanged
/// (initially `DEFAULT_MAX_CONCURRENT_TOOLS`).
max_concurrent_tools: usize,
/// Optional grant store; when present `build` creates the authorizer with
/// `ToolCallAuthorizer::with_boxed_store`, otherwise with `ToolCallAuthorizer::new`.
pub(super) grant_store: Option<Box<dyn GrantStore>>,
/// Policy applied to the authorizer via `with_authorization_policy` in `build`.
pub(super) authorization_policy: ToolAuthorizationPolicy,
/// Authorization timeout handed to the agent unchanged
/// (initially `DEFAULT_PERMISSION_TIMEOUT`).
pub(super) authorization_timeout: Duration,
/// Names of tools added through the `add_trusted_*` methods; `build` calls
/// `grant_tool` on the authorizer for each name.
trusted_tools: Vec<String>,
/// Optional conversation manager; `build` falls back to
/// `SlidingWindowConversationManager::new()` when this is `None`.
conversation_manager: Option<BoxedConversationManager>,
/// Optional session store handed to the agent (only with the `session` feature).
#[cfg(feature = "session")]
session_store: Option<Arc<dyn SessionStore>>,
/// MCP server configurations passed to `super::mcp::connect_mcp_servers` in
/// `build` (only with the `mcp` feature).
#[cfg(feature = "mcp")]
pub(super) mcp_servers: Vec<crate::mcp::McpServerConfig>,
/// MCP configuration file paths passed to `super::mcp::connect_mcp_servers`
/// in `build` (only with the `mcp` feature).
#[cfg(feature = "mcp")]
pub(super) mcp_config_files: Vec<std::path::PathBuf>,
/// Context sources registered through the `add_context*` methods, in
/// insertion order.
context_sources: Vec<ContextSource>,
/// Context configuration handed to the agent unchanged.
context_config: ContextConfig,
}
/// `AgentBuilder::default()` is identical to [`AgentBuilder::new`].
impl Default for AgentBuilder {
// Delegates to `new` so there is a single source of truth for the defaults.
fn default() -> Self {
Self::new()
}
}
impl AgentBuilder {
/// Creates a builder with no provider, tools, prompt, or context configured.
///
/// The tool-concurrency limit and the authorization timeout start at
/// `DEFAULT_MAX_CONCURRENT_TOOLS` and `DEFAULT_PERMISSION_TIMEOUT`; the
/// authorization policy and the context configuration start at their
/// `Default` values.
pub fn new() -> Self {
    // Only these two settings need computing; everything else starts out
    // empty, unset, or at a named constant.
    let authorization_policy = ToolAuthorizationPolicy::default();
    let context_config = ContextConfig::default();

    Self {
        // Model, prompt, and conversation handling.
        provider_factory: None,
        system_prompt: None,
        conversation_manager: None,
        // Tools and their authorization.
        tools: Vec::new(),
        trusted_tools: Vec::new(),
        max_concurrent_tools: DEFAULT_MAX_CONCURRENT_TOOLS,
        grant_store: None,
        authorization_policy,
        authorization_timeout: DEFAULT_PERMISSION_TIMEOUT,
        // Feature-gated integrations.
        #[cfg(feature = "session")]
        session_store: None,
        #[cfg(feature = "mcp")]
        mcp_servers: Vec::new(),
        #[cfg(feature = "mcp")]
        mcp_config_files: Vec::new(),
        // Context injection.
        context_sources: Vec::new(),
        context_config,
    }
}
/// Selects a Bedrock-backed provider for `model`.
///
/// Nothing is constructed here: a factory is stored that awaits
/// `BedrockProvider::new(model)` when it is invoked, so any error from that
/// call surfaces when the factory runs rather than when this method is called.
/// A previously configured provider is replaced.
#[cfg(feature = "bedrock")]
pub fn bedrock(self, model: impl BedrockModel + 'static) -> Self {
    let factory: ProviderFactory = Box::new(move || {
        Box::pin(async move {
            let shared: Arc<dyn ModelProvider> = Arc::new(BedrockProvider::new(model).await?);
            Ok(shared)
        })
    });

    Self {
        provider_factory: Some(factory),
        ..self
    }
}
/// Selects an Anthropic-backed provider for `model`, authenticated with `api_key`.
///
/// The key is converted to an owned `String` immediately; the provider itself
/// is created by `AnthropicProvider::new` only when the stored factory is
/// invoked. A previously configured provider is replaced.
#[cfg(feature = "anthropic")]
pub fn anthropic(
    self,
    model: impl AnthropicModel + 'static,
    api_key: impl Into<String>,
) -> Self {
    // Convert eagerly so the closure captures a concrete `String`.
    let api_key = api_key.into();
    let factory: ProviderFactory = Box::new(move || {
        Box::pin(async move {
            let shared: Arc<dyn ModelProvider> =
                Arc::new(AnthropicProvider::new(api_key, model)?);
            Ok(shared)
        })
    });

    Self {
        provider_factory: Some(factory),
        ..self
    }
}
/// Selects an Anthropic-backed provider for `model`, configured via
/// `AnthropicProvider::from_env`.
///
/// `from_env` is called inside the stored factory, i.e. when the factory is
/// invoked rather than when this method is called. A previously configured
/// provider is replaced.
#[cfg(feature = "anthropic")]
pub fn anthropic_from_env(self, model: impl AnthropicModel + 'static) -> Self {
    let factory: ProviderFactory = Box::new(move || {
        Box::pin(async move {
            let shared: Arc<dyn ModelProvider> = Arc::new(AnthropicProvider::from_env(model)?);
            Ok(shared)
        })
    });

    Self {
        provider_factory: Some(factory),
        ..self
    }
}
pub fn provider(mut self, provider: impl ModelProvider + 'static) -> Self {
let provider = Arc::new(provider) as Arc<dyn ModelProvider>;
self.provider_factory = Some(Box::new(move || Box::pin(async move { Ok(provider) })));
self
}
/// Registers a single tool, boxing it with `box_tool` and appending it after
/// any tools added earlier.
pub fn add_tool(mut self, tool: impl Tool + 'static) -> Self {
    let boxed = box_tool(tool);
    self.tools.push(boxed);
    self
}
/// Registers a single tool and records its name as trusted.
///
/// Trusted names are later passed to the authorizer's `grant_tool` by `build`.
pub fn add_trusted_tool(mut self, tool: impl Tool + 'static) -> Self {
    // Capture the name first: boxing consumes the tool.
    self.trusted_tools.push(tool.name().to_string());
    self.tools.push(box_tool(tool));
    self
}
pub fn add_tools(mut self, tools: impl IntoIterator<Item = Box<dyn DynTool>>) -> Self {
self.tools.extend(tools);
self
}
/// Registers several already-boxed tools and records each one's name as
/// trusted.
///
/// Trusted names are later passed to the authorizer's `grant_tool` by `build`.
pub fn add_trusted_tools(mut self, tools: impl IntoIterator<Item = Box<dyn DynTool>>) -> Self {
    for tool in tools {
        // Record the name while the tool is still borrowed, then move it in.
        self.trusted_tools.push(tool.name().to_string());
        self.tools.push(tool);
    }
    self
}
pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = Some(prompt.into());
self
}
pub fn with_max_concurrent_tools(mut self, max: usize) -> Self {
self.max_concurrent_tools = max;
self
}
pub fn with_conversation_manager(
mut self,
manager: impl crate::conversation::ConversationManager + 'static,
) -> Self {
self.conversation_manager = Some(Box::new(manager));
self
}
/// Sets the session store handed to the agent, replacing any store set
/// earlier. Only available with the `session` feature.
#[cfg(feature = "session")]
pub fn with_session_store(self, store: impl SessionStore + 'static) -> Self {
    let shared: Arc<dyn SessionStore> = Arc::new(store);

    Self {
        session_store: Some(shared),
        ..self
    }
}
/// Appends an inline-content context source (`ContextSource::Content`).
pub fn add_context(mut self, content: impl Into<String>) -> Self {
    let source = ContextSource::Content {
        content: content.into(),
    };
    self.context_sources.push(source);
    self
}
/// Appends a single-file context source (`ContextSource::File`) flagged as
/// required.
///
/// This method only records the path and the flag; the file is not read here.
pub fn add_context_file(mut self, path: impl Into<String>) -> Self {
    let source = ContextSource::File {
        path: path.into(),
        required: true,
    };
    self.context_sources.push(source);
    self
}
/// Appends a single-file context source (`ContextSource::File`) flagged as
/// not required.
///
/// This method only records the path and the flag; the file is not read here.
pub fn add_optional_context_file(mut self, path: impl Into<String>) -> Self {
    let source = ContextSource::File {
        path: path.into(),
        required: false,
    };
    self.context_sources.push(source);
    self
}
/// Appends one multi-file context source (`ContextSource::Files`) holding all
/// of `paths`, flagged as required.
///
/// This method only records the paths and the flag; no file is read here.
pub fn add_context_files(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
    let collected = paths.into_iter().map(Into::into).collect();
    let source = ContextSource::Files {
        paths: collected,
        required: true,
    };
    self.context_sources.push(source);
    self
}
/// Appends one multi-file context source (`ContextSource::Files`) holding all
/// of `paths`, flagged as not required.
///
/// This method only records the paths and the flag; no file is read here.
pub fn add_optional_context_files(
    mut self,
    paths: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
    let collected = paths.into_iter().map(Into::into).collect();
    let source = ContextSource::Files {
        paths: collected,
        required: false,
    };
    self.context_sources.push(source);
    self
}
/// Appends a glob-pattern context source (`ContextSource::Glob`).
///
/// The pattern is stored as given; it is not expanded here.
pub fn add_context_files_glob(mut self, pattern: impl Into<String>) -> Self {
    let source = ContextSource::Glob {
        pattern: pattern.into(),
    };
    self.context_sources.push(source);
    self
}
pub fn with_context_config(mut self, config: ContextConfig) -> Self {
self.context_config = config;
self
}
pub async fn build(self) -> crate::error::Result<Agent> {
let provider_factory = self
.provider_factory
.ok_or_else(|| crate::error::Error::Config(
"No provider configured. Call .bedrock(), .anthropic(), or .provider() before .build()".to_string()
))?;
let provider = provider_factory().await?;
let conversation_manager = self
.conversation_manager
.unwrap_or_else(|| Box::new(SlidingWindowConversationManager::new()));
let authorizer = match self.grant_store {
Some(store) => ToolCallAuthorizer::with_boxed_store(store),
None => ToolCallAuthorizer::new(),
}
.with_authorization_policy(self.authorization_policy);
for tool_name in &self.trusted_tools {
authorizer.grant_tool(tool_name).await?;
}
#[allow(unused_mut)]
let mut agent = Agent {
provider,
system_prompt: self.system_prompt,
max_concurrent_tools: self.max_concurrent_tools,
tools: self.tools,
hooks: Arc::new(parking_lot::RwLock::new(HashMap::new())),
next_hook_id: AtomicU64::new(0),
authorizer: Arc::new(RwLock::new(authorizer)),
authorization_timeout: self.authorization_timeout,
pending_authorizations: Arc::new(RwLock::new(HashMap::new())),
#[cfg(feature = "mcp")]
mcp_clients: Vec::new(),
conversation_manager: parking_lot::RwLock::new(conversation_manager),
#[cfg(feature = "session")]
session_store: self.session_store,
context_sources: self.context_sources,
context_config: self.context_config,
last_context_result: parking_lot::RwLock::new(None),
};
#[cfg(feature = "mcp")]
{
super::mcp::connect_mcp_servers(&mut agent, self.mcp_servers, self.mcp_config_files)
.await?;
}
Ok(agent)
}
}
impl Agent {
/// Returns a fresh [`AgentBuilder`]; equivalent to `AgentBuilder::new()`.
pub fn builder() -> AgentBuilder {
AgentBuilder::new()
}
}
// Unit tests for `AgentBuilder`: field defaults, individual setters, tool
// registration, and end-to-end `build()` with a mock provider.
#[cfg(test)]
mod tests {
use super::*;
use crate::box_tools;
use crate::conversation::SimpleConversationManager;
use crate::provider::{ModelProvider, ProviderError};
use crate::types::{ContentBlock, Message, Role, StopReason, ToolDefinition};
use crate::ModelResponse;
/// Minimal provider used to satisfy `build()`; `generate` always answers
/// with a single assistant text block "ok".
#[derive(Clone)]
struct MockProvider;
#[async_trait::async_trait]
impl ModelProvider for MockProvider {
fn name(&self) -> &str {
"MockProvider"
}
fn max_context_tokens(&self) -> usize {
200_000
}
fn max_output_tokens(&self) -> usize {
8_192
}
async fn generate(
&self,
_messages: Vec<Message>,
_tools: Vec<ToolDefinition>,
_system_prompt: Option<String>,
) -> Result<ModelResponse, ProviderError> {
Ok(ModelResponse {
message: Message {
role: Role::Assistant,
content: vec![ContentBlock::Text("ok".to_string())],
},
stop_reason: StopReason::EndTurn,
usage: None,
})
}
}
/// `Agent::builder()` starts with no provider, no tools, and no prompt.
#[test]
fn test_builder_creation() {
let builder = Agent::builder();
assert!(builder.provider_factory.is_none());
assert!(builder.tools.is_empty());
assert!(builder.system_prompt.is_none());
}
/// `AgentBuilder::default()` carries the documented default limits.
#[test]
fn test_builder_default() {
let builder = AgentBuilder::default();
assert!(builder.provider_factory.is_none());
assert_eq!(builder.max_concurrent_tools, DEFAULT_MAX_CONCURRENT_TOOLS);
assert_eq!(builder.authorization_timeout, DEFAULT_PERMISSION_TIMEOUT);
}
/// `with_system_prompt` stores the prompt as an owned string.
#[test]
fn test_builder_system_prompt() {
let builder = Agent::builder().with_system_prompt("Test prompt");
assert_eq!(builder.system_prompt, Some("Test prompt".to_string()));
}
/// `with_max_concurrent_tools` stores the value unchanged.
#[test]
fn test_builder_max_concurrent_tools() {
let builder = Agent::builder().with_max_concurrent_tools(4);
assert_eq!(builder.max_concurrent_tools, 4);
}
/// `with_conversation_manager` populates the optional manager slot.
#[test]
fn test_builder_conversation_manager() {
let builder =
Agent::builder().with_conversation_manager(SimpleConversationManager::new(100));
assert!(builder.conversation_manager.is_some());
}
/// A builder with only a provider builds, and the agent exposes that provider.
#[tokio::test]
async fn test_build_with_provider() {
let agent = Agent::builder()
.provider(MockProvider)
.build()
.await
.unwrap();
assert_eq!(agent.provider.name(), "MockProvider");
}
/// The system prompt set on the builder reaches the built agent.
#[tokio::test]
async fn test_build_with_system_prompt() {
let agent = Agent::builder()
.provider(MockProvider)
.with_system_prompt("Be helpful")
.build()
.await
.unwrap();
assert_eq!(agent.system_prompt, Some("Be helpful".to_string()));
}
/// Building with a custom conversation manager succeeds.
// NOTE(review): only the provider name is asserted; the manager itself is
// not inspected on the built agent.
#[tokio::test]
async fn test_build_with_conversation_manager() {
let agent = Agent::builder()
.provider(MockProvider)
.with_conversation_manager(SimpleConversationManager::new(100))
.build()
.await
.unwrap();
assert_eq!(agent.provider.name(), "MockProvider");
}
/// `build()` without a provider yields a configuration error.
#[tokio::test]
async fn test_build_without_provider_fails() {
let result = Agent::builder().build().await;
match result {
Err(err) => assert!(err.is_config()),
Ok(_) => panic!("Expected error when building without provider"),
}
}
/// Several setters chained together all take effect on the built agent.
// `with_authorization_timeout` is not defined in this file; it is expected
// to come from another `impl AgentBuilder` block in the crate.
#[tokio::test]
async fn test_builder_chaining() {
let agent = Agent::builder()
.provider(MockProvider)
.with_system_prompt("Test")
.with_max_concurrent_tools(8)
.with_authorization_timeout(Duration::from_secs(60))
.build()
.await
.unwrap();
assert_eq!(agent.system_prompt, Some("Test".to_string()));
assert_eq!(agent.max_concurrent_tools, 8);
assert_eq!(agent.authorization_timeout, Duration::from_secs(60));
}
/// `add_tool` boxes and registers exactly one tool under its own name.
#[test]
fn test_builder_add_tool_single() {
use crate::tool::{Tool, ToolError, ToolResult};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
#[allow(dead_code)]
struct TestInput {
value: String,
}
struct TestTool;
impl Tool for TestTool {
type Input = TestInput;
fn name(&self) -> &str {
"test_tool"
}
fn description(&self) -> &str {
"A test tool"
}
async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
Ok(ToolResult::text("result"))
}
}
let builder = Agent::builder().add_tool(TestTool);
assert_eq!(builder.tools.len(), 1);
assert_eq!(builder.tools[0].name(), "test_tool");
}
/// `add_tools` registers every tool from `box_tools!`, preserving order.
#[test]
fn test_builder_add_tools_multiple() {
use crate::tool::{Tool, ToolError, ToolResult};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
#[allow(dead_code)]
struct TestInput {
value: String,
}
#[derive(Clone)]
struct TestTool {
name: &'static str,
description: &'static str,
}
impl Tool for TestTool {
type Input = TestInput;
fn name(&self) -> &str {
self.name
}
fn description(&self) -> &str {
self.description
}
async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
Ok(ToolResult::text(self.name))
}
}
let builder = Agent::builder().add_tools(box_tools![
TestTool {
name: "tool1",
description: "First tool",
},
TestTool {
name: "tool2",
description: "Second tool",
},
TestTool {
name: "tool3",
description: "Third tool",
},
]);
assert_eq!(builder.tools.len(), 3);
assert_eq!(builder.tools[0].name(), "tool1");
assert_eq!(builder.tools[1].name(), "tool2");
assert_eq!(builder.tools[2].name(), "tool3");
}
/// `add_tools` with an empty `box_tools![]` leaves the tool list empty.
#[test]
fn test_builder_add_tools_empty() {
use crate::tool::{Tool, ToolError, ToolResult};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
#[allow(dead_code)]
struct TestInput {
value: String,
}
// This tool type is declared but never instantiated in this test.
#[allow(dead_code)]
struct TestTool;
impl Tool for TestTool {
type Input = TestInput;
fn name(&self) -> &str {
"test"
}
fn description(&self) -> &str {
"Test"
}
async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
Ok(ToolResult::text("ok"))
}
}
let builder = Agent::builder().add_tools(box_tools![]);
assert_eq!(builder.tools.len(), 0);
}
/// `add_tool` followed by `add_tools` appends in call order; duplicate
/// names are kept as separate entries.
#[test]
fn test_builder_add_tool_and_add_tools_chaining() {
use crate::tool::{Tool, ToolError, ToolResult};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
struct TestInput {}
struct Tool1;
impl Tool for Tool1 {
type Input = TestInput;
fn name(&self) -> &str {
"tool1"
}
fn description(&self) -> &str {
"First"
}
async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
Ok(ToolResult::text("1"))
}
}
#[derive(Clone)]
struct Tool2;
impl Tool for Tool2 {
type Input = TestInput;
fn name(&self) -> &str {
"tool2"
}
fn description(&self) -> &str {
"Second"
}
async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
Ok(ToolResult::text("2"))
}
}
let builder = Agent::builder()
.add_tool(Tool1)
.add_tools(box_tools![Tool2, Tool2]);
assert_eq!(builder.tools.len(), 3);
assert_eq!(builder.tools[0].name(), "tool1");
assert_eq!(builder.tools[1].name(), "tool2");
assert_eq!(builder.tools[2].name(), "tool2");
}
/// Tools added via `add_tools` are visible through `Agent::list_tools`
/// after `build()`.
#[tokio::test]
async fn test_build_with_add_tools() {
use crate::tool::{Tool, ToolError, ToolResult};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
struct TestInput {}
#[derive(Clone)]
struct NamedTool {
tool_name: &'static str,
tool_desc: &'static str,
}
impl Tool for NamedTool {
type Input = TestInput;
fn name(&self) -> &str {
self.tool_name
}
fn description(&self) -> &str {
self.tool_desc
}
async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
Ok(ToolResult::text(self.tool_name))
}
}
let agent = Agent::builder()
.provider(MockProvider)
.add_tools(box_tools![
NamedTool {
tool_name: "calculator",
tool_desc: "Calculates things",
},
NamedTool {
tool_name: "weather",
tool_desc: "Gets weather",
},
])
.build()
.await
.unwrap();
let tools = agent.list_tools();
assert_eq!(tools.len(), 2);
// Membership is checked instead of position, so listing order is not pinned.
let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
assert!(names.contains(&"calculator"));
assert!(names.contains(&"weather"));
}
}