1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::RwLock;
15
16use crate::conversation::{BoxedConversationManager, SlidingWindowConversationManager};
17use crate::permission::{GrantStore, ToolAuthorizationPolicy, ToolCallAuthorizer};
18use crate::provider::ModelProvider;
19use crate::tool::{box_tool, DynTool, Tool};
20
21use super::context::{ContextConfig, ContextSource};
22use super::types::{DEFAULT_MAX_CONCURRENT_TOOLS, DEFAULT_PERMISSION_TIMEOUT};
23use super::Agent;
24
25#[cfg(feature = "session")]
26use crate::session::SessionStore;
27
28#[cfg(feature = "bedrock")]
29use crate::model::BedrockModel;
30#[cfg(feature = "bedrock")]
31use crate::provider::BedrockProvider;
32
33#[cfg(feature = "anthropic")]
34use crate::model::AnthropicModel;
35#[cfg(feature = "anthropic")]
36use crate::provider::AnthropicProvider;
37
38type ProviderFactory = Box<
40 dyn FnOnce()
41 -> Pin<Box<dyn Future<Output = crate::error::Result<Arc<dyn ModelProvider>>> + Send>>
42 + Send,
43>;
44
45pub struct AgentBuilder {
71 provider_factory: Option<ProviderFactory>,
72 tools: Vec<Box<dyn DynTool>>,
73 system_prompt: Option<String>,
74 max_concurrent_tools: usize,
75 pub(super) grant_store: Option<Box<dyn GrantStore>>,
77 pub(super) authorization_policy: ToolAuthorizationPolicy,
79 pub(super) authorization_timeout: Duration,
81 trusted_tools: Vec<String>,
83 conversation_manager: Option<BoxedConversationManager>,
84 #[cfg(feature = "session")]
85 session_store: Option<Arc<dyn SessionStore>>,
86 #[cfg(feature = "mcp")]
88 pub(super) mcp_servers: Vec<crate::mcp::McpServerConfig>,
89 #[cfg(feature = "mcp")]
90 pub(super) mcp_config_files: Vec<std::path::PathBuf>,
91 context_sources: Vec<ContextSource>,
94 context_config: ContextConfig,
96}
97
98impl Default for AgentBuilder {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104impl AgentBuilder {
105 pub fn new() -> Self {
107 Self {
108 provider_factory: None,
109 tools: Vec::new(),
110 system_prompt: None,
111 max_concurrent_tools: DEFAULT_MAX_CONCURRENT_TOOLS,
112 grant_store: None,
113 authorization_policy: ToolAuthorizationPolicy::default(), authorization_timeout: DEFAULT_PERMISSION_TIMEOUT,
115 trusted_tools: Vec::new(),
116 conversation_manager: None,
117 #[cfg(feature = "session")]
118 session_store: None,
119 #[cfg(feature = "mcp")]
120 mcp_servers: Vec::new(),
121 #[cfg(feature = "mcp")]
122 mcp_config_files: Vec::new(),
123 context_sources: Vec::new(),
124 context_config: ContextConfig::default(),
125 }
126 }
127
128 #[cfg(feature = "bedrock")]
142 pub fn bedrock(mut self, model: impl BedrockModel + 'static) -> Self {
143 self.provider_factory = Some(Box::new(move || {
144 Box::pin(async move {
145 let provider = BedrockProvider::new(model).await?;
146 Ok(Arc::new(provider) as Arc<dyn ModelProvider>)
147 })
148 }));
149 self
150 }
151
152 #[cfg(feature = "anthropic")]
163 pub fn anthropic(
164 mut self,
165 model: impl AnthropicModel + 'static,
166 api_key: impl Into<String>,
167 ) -> Self {
168 let api_key = api_key.into();
169 self.provider_factory = Some(Box::new(move || {
170 Box::pin(async move {
171 let provider = AnthropicProvider::new(api_key, model)?;
172 Ok(Arc::new(provider) as Arc<dyn ModelProvider>)
173 })
174 }));
175 self
176 }
177
178 #[cfg(feature = "anthropic")]
191 pub fn anthropic_from_env(mut self, model: impl AnthropicModel + 'static) -> Self {
192 self.provider_factory = Some(Box::new(move || {
193 Box::pin(async move {
194 let provider = AnthropicProvider::from_env(model)?;
195 Ok(Arc::new(provider) as Arc<dyn ModelProvider>)
196 })
197 }));
198 self
199 }
200
201 pub fn provider(mut self, provider: impl ModelProvider + 'static) -> Self {
219 let provider = Arc::new(provider) as Arc<dyn ModelProvider>;
220 self.provider_factory = Some(Box::new(move || Box::pin(async move { Ok(provider) })));
221 self
222 }
223
224 pub fn add_tool(mut self, tool: impl Tool + 'static) -> Self {
237 self.tools.push(box_tool(tool));
238 self
239 }
240
241 pub fn add_trusted_tool(mut self, tool: impl Tool + 'static) -> Self {
257 let tool_name = tool.name().to_string();
258 self.tools.push(box_tool(tool));
259 self.trusted_tools.push(tool_name);
260 self
261 }
262
263 pub fn add_tools(mut self, tools: impl IntoIterator<Item = Box<dyn DynTool>>) -> Self {
287 self.tools.extend(tools);
288 self
289 }
290
291 pub fn add_trusted_tools(mut self, tools: impl IntoIterator<Item = Box<dyn DynTool>>) -> Self {
309 for tool in tools {
310 let tool_name = tool.name().to_string();
311 self.tools.push(tool);
312 self.trusted_tools.push(tool_name);
313 }
314 self
315 }
316
317 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
319 self.system_prompt = Some(prompt.into());
320 self
321 }
322
323 pub fn with_max_concurrent_tools(mut self, max: usize) -> Self {
325 self.max_concurrent_tools = max;
326 self
327 }
328
329 pub fn with_conversation_manager(
335 mut self,
336 manager: impl crate::conversation::ConversationManager + 'static,
337 ) -> Self {
338 self.conversation_manager = Some(Box::new(manager));
339 self
340 }
341
342 #[cfg(feature = "session")]
344 pub fn with_session_store(mut self, store: impl SessionStore + 'static) -> Self {
345 self.session_store = Some(Arc::new(store));
346 self
347 }
348
349 pub fn add_context(mut self, content: impl Into<String>) -> Self {
365 self.context_sources.push(ContextSource::Content {
366 content: content.into(),
367 });
368 self
369 }
370
371 pub fn add_context_file(mut self, path: impl Into<String>) -> Self {
392 self.context_sources.push(ContextSource::File {
393 path: path.into(),
394 required: true,
395 });
396 self
397 }
398
399 pub fn add_optional_context_file(mut self, path: impl Into<String>) -> Self {
413 self.context_sources.push(ContextSource::File {
414 path: path.into(),
415 required: false,
416 });
417 self
418 }
419
420 pub fn add_context_files(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
434 self.context_sources.push(ContextSource::Files {
435 paths: paths.into_iter().map(|p| p.into()).collect(),
436 required: true,
437 });
438 self
439 }
440
441 pub fn add_optional_context_files(
455 mut self,
456 paths: impl IntoIterator<Item = impl Into<String>>,
457 ) -> Self {
458 self.context_sources.push(ContextSource::Files {
459 paths: paths.into_iter().map(|p| p.into()).collect(),
460 required: false,
461 });
462 self
463 }
464
465 pub fn add_context_files_glob(mut self, pattern: impl Into<String>) -> Self {
481 self.context_sources.push(ContextSource::Glob {
482 pattern: pattern.into(),
483 });
484 self
485 }
486
487 pub fn with_context_config(mut self, config: ContextConfig) -> Self {
504 self.context_config = config;
505 self
506 }
507
508 pub async fn build(self) -> crate::error::Result<Agent> {
531 let provider_factory = self
532 .provider_factory
533 .ok_or_else(|| crate::error::Error::Config(
534 "No provider configured. Call .bedrock(), .anthropic(), or .provider() before .build()".to_string()
535 ))?;
536
537 let provider = provider_factory().await?;
538
539 let conversation_manager = self
540 .conversation_manager
541 .unwrap_or_else(|| Box::new(SlidingWindowConversationManager::new()));
542
543 let authorizer = match self.grant_store {
546 Some(store) => ToolCallAuthorizer::with_boxed_store(store),
547 None => ToolCallAuthorizer::new(),
548 }
549 .with_authorization_policy(self.authorization_policy);
550
551 for tool_name in &self.trusted_tools {
553 authorizer.grant_tool(tool_name).await?;
554 }
555
556 #[allow(unused_mut)]
557 let mut agent = Agent {
558 provider,
559 system_prompt: self.system_prompt,
560 max_concurrent_tools: self.max_concurrent_tools,
561 tools: self.tools,
562 hooks: Arc::new(parking_lot::RwLock::new(Vec::new())),
563 authorizer: Arc::new(RwLock::new(authorizer)),
564 authorization_timeout: self.authorization_timeout,
565 pending_authorizations: Arc::new(RwLock::new(HashMap::new())),
566 #[cfg(feature = "mcp")]
567 mcp_clients: Vec::new(),
568 conversation_manager: parking_lot::RwLock::new(conversation_manager),
569 #[cfg(feature = "session")]
570 session_store: self.session_store,
571 context_sources: self.context_sources,
573 context_config: self.context_config,
574 last_context_result: parking_lot::RwLock::new(None),
575 };
576
577 #[cfg(feature = "mcp")]
579 {
580 super::mcp::connect_mcp_servers(&mut agent, self.mcp_servers, self.mcp_config_files)
581 .await?;
582 }
583
584 Ok(agent)
585 }
586}
587
588impl Agent {
589 pub fn builder() -> AgentBuilder {
610 AgentBuilder::new()
611 }
612
613 }
616
#[cfg(test)]
mod tests {
    use super::*;
    use crate::box_tools;
    use crate::conversation::SimpleConversationManager;
    use crate::provider::{ModelProvider, ProviderError};
    use crate::tool::{Tool, ToolError, ToolResult};
    use crate::types::{ContentBlock, Message, Role, StopReason, ToolDefinition};
    use crate::ModelResponse;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    /// Provider stub that always answers "ok" and ends the turn.
    #[derive(Clone)]
    struct MockProvider;

    #[async_trait::async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "MockProvider"
        }

        fn max_context_tokens(&self) -> usize {
            200_000
        }

        fn max_output_tokens(&self) -> usize {
            8_192
        }

        async fn generate(
            &self,
            _messages: Vec<Message>,
            _tools: Vec<ToolDefinition>,
            _system_prompt: Option<String>,
        ) -> Result<ModelResponse, ProviderError> {
            Ok(ModelResponse {
                message: Message {
                    role: Role::Assistant,
                    content: vec![ContentBlock::Text("ok".to_string())],
                },
                stop_reason: StopReason::EndTurn,
                usage: None,
            })
        }
    }

    /// Input type shared by every test tool. The builder tests never execute
    /// tools, so it carries no fields.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EmptyInput {}

    /// Tool whose name and description are fixed at construction; shared by
    /// all builder tests instead of redefining a tool type per test.
    #[derive(Clone)]
    struct NamedTool {
        name: &'static str,
        description: &'static str,
    }

    impl NamedTool {
        fn new(name: &'static str, description: &'static str) -> Self {
            Self { name, description }
        }
    }

    impl Tool for NamedTool {
        type Input = EmptyInput;

        fn name(&self) -> &str {
            self.name
        }

        fn description(&self) -> &str {
            self.description
        }

        async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::text(self.name))
        }
    }

    #[test]
    fn test_builder_creation() {
        let builder = Agent::builder();
        assert!(builder.provider_factory.is_none());
        assert!(builder.tools.is_empty());
        assert!(builder.system_prompt.is_none());
        assert!(builder.trusted_tools.is_empty());
        assert!(builder.context_sources.is_empty());
    }

    #[test]
    fn test_builder_default() {
        let builder = AgentBuilder::default();
        assert!(builder.provider_factory.is_none());
        assert_eq!(builder.max_concurrent_tools, DEFAULT_MAX_CONCURRENT_TOOLS);
        assert_eq!(builder.authorization_timeout, DEFAULT_PERMISSION_TIMEOUT);
    }

    #[test]
    fn test_builder_system_prompt() {
        let builder = Agent::builder().with_system_prompt("Test prompt");
        assert_eq!(builder.system_prompt, Some("Test prompt".to_string()));
    }

    #[test]
    fn test_builder_max_concurrent_tools() {
        let builder = Agent::builder().with_max_concurrent_tools(4);
        assert_eq!(builder.max_concurrent_tools, 4);
    }

    #[test]
    fn test_builder_conversation_manager() {
        let builder =
            Agent::builder().with_conversation_manager(SimpleConversationManager::new(100));
        assert!(builder.conversation_manager.is_some());
    }

    #[tokio::test]
    async fn test_build_with_provider() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .build()
            .await
            .unwrap();

        assert_eq!(agent.provider.name(), "MockProvider");
    }

    #[tokio::test]
    async fn test_build_with_system_prompt() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .with_system_prompt("Be helpful")
            .build()
            .await
            .unwrap();

        assert_eq!(agent.system_prompt, Some("Be helpful".to_string()));
    }

    #[tokio::test]
    async fn test_build_with_conversation_manager() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .with_conversation_manager(SimpleConversationManager::new(100))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.provider.name(), "MockProvider");
    }

    #[tokio::test]
    async fn test_build_without_provider_fails() {
        let result = Agent::builder().build().await;
        match result {
            Err(err) => assert!(err.is_config()),
            Ok(_) => panic!("Expected error when building without provider"),
        }
    }

    #[tokio::test]
    async fn test_builder_chaining() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .with_system_prompt("Test")
            .with_max_concurrent_tools(8)
            .with_authorization_timeout(Duration::from_secs(60))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.system_prompt, Some("Test".to_string()));
        assert_eq!(agent.max_concurrent_tools, 8);
        assert_eq!(agent.authorization_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_builder_add_tool_single() {
        let builder = Agent::builder().add_tool(NamedTool::new("test_tool", "A test tool"));

        assert_eq!(builder.tools.len(), 1);
        assert_eq!(builder.tools[0].name(), "test_tool");
        // Plain `add_tool` must not pre-grant anything.
        assert!(builder.trusted_tools.is_empty());
    }

    #[test]
    fn test_builder_add_tools_multiple() {
        let builder = Agent::builder().add_tools(box_tools![
            NamedTool::new("tool1", "First tool"),
            NamedTool::new("tool2", "Second tool"),
            NamedTool::new("tool3", "Third tool"),
        ]);

        assert_eq!(builder.tools.len(), 3);
        assert_eq!(builder.tools[0].name(), "tool1");
        assert_eq!(builder.tools[1].name(), "tool2");
        assert_eq!(builder.tools[2].name(), "tool3");
        assert!(builder.trusted_tools.is_empty());
    }

    #[test]
    fn test_builder_add_tools_empty() {
        let builder = Agent::builder().add_tools(box_tools![]);

        assert_eq!(builder.tools.len(), 0);
    }

    #[test]
    fn test_builder_add_tool_and_add_tools_chaining() {
        let builder = Agent::builder()
            .add_tool(NamedTool::new("tool1", "First"))
            .add_tools(box_tools![
                NamedTool::new("tool2", "Second"),
                NamedTool::new("tool2", "Second"),
            ]);

        assert_eq!(builder.tools.len(), 3);
        assert_eq!(builder.tools[0].name(), "tool1");
        assert_eq!(builder.tools[1].name(), "tool2");
        assert_eq!(builder.tools[2].name(), "tool2");
    }

    #[test]
    fn test_builder_add_trusted_tool_records_name() {
        let builder = Agent::builder()
            .add_tool(NamedTool::new("untrusted", "Not pre-granted"))
            .add_trusted_tool(NamedTool::new("trusted", "Pre-granted"));

        assert_eq!(builder.tools.len(), 2);
        assert_eq!(builder.tools[1].name(), "trusted");
        assert_eq!(builder.trusted_tools, ["trusted"]);
    }

    #[test]
    fn test_builder_add_trusted_tools_records_all_names() {
        let builder = Agent::builder().add_trusted_tools(box_tools![
            NamedTool::new("alpha", "First"),
            NamedTool::new("beta", "Second"),
        ]);

        assert_eq!(builder.tools.len(), 2);
        assert_eq!(builder.trusted_tools, ["alpha", "beta"]);
    }

    #[test]
    fn test_builder_context_sources_preserve_order_and_flags() {
        let builder = Agent::builder()
            .add_context("inline")
            .add_context_file("AGENTS.md")
            .add_optional_context_file("LOCAL.md")
            .add_context_files(["a.md", "b.md"])
            .add_optional_context_files(vec!["c.md".to_string()])
            .add_context_files_glob("docs/*.md");

        let sources = &builder.context_sources;
        assert_eq!(sources.len(), 6);
        assert!(matches!(
            &sources[0],
            ContextSource::Content { content } if content == "inline"
        ));
        assert!(matches!(
            &sources[1],
            ContextSource::File { path, required: true } if path == "AGENTS.md"
        ));
        assert!(matches!(
            &sources[2],
            ContextSource::File { path, required: false } if path == "LOCAL.md"
        ));
        assert!(matches!(
            &sources[3],
            ContextSource::Files { paths, required: true } if paths.len() == 2
        ));
        assert!(matches!(
            &sources[4],
            ContextSource::Files { paths, required: false } if paths.len() == 1
        ));
        assert!(matches!(
            &sources[5],
            ContextSource::Glob { pattern } if pattern == "docs/*.md"
        ));
    }

    #[tokio::test]
    async fn test_build_with_add_tools() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .add_tools(box_tools![
                NamedTool::new("calculator", "Calculates things"),
                NamedTool::new("weather", "Gets weather"),
            ])
            .build()
            .await
            .unwrap();

        let tools = agent.list_tools();
        assert_eq!(tools.len(), 2);

        let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"calculator"));
        assert!(names.contains(&"weather"));
    }

    #[tokio::test]
    async fn test_build_with_trusted_tool() {
        // Exercises the `grant_tool` path in `build()` with the default grant store.
        let agent = Agent::builder()
            .provider(MockProvider)
            .add_trusted_tool(NamedTool::new("trusted", "Pre-granted"))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.list_tools().len(), 1);
    }
}