1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::atomic::AtomicU64;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::RwLock;
16
17use crate::conversation::{BoxedConversationManager, SlidingWindowConversationManager};
18use crate::permission::{GrantStore, ToolAuthorizationPolicy, ToolCallAuthorizer};
19use crate::provider::ModelProvider;
20use crate::tool::{box_tool, DynTool, Tool};
21
22use super::context::{ContextConfig, ContextSource};
23use super::types::{DEFAULT_MAX_CONCURRENT_TOOLS, DEFAULT_PERMISSION_TIMEOUT};
24use super::Agent;
25
26#[cfg(feature = "session")]
27use crate::session::SessionStore;
28
29#[cfg(feature = "bedrock")]
30use crate::model::BedrockModel;
31#[cfg(feature = "bedrock")]
32use crate::provider::BedrockProvider;
33
34#[cfg(feature = "anthropic")]
35use crate::model::AnthropicModel;
36#[cfg(feature = "anthropic")]
37use crate::provider::AnthropicProvider;
38
39type ProviderFactory = Box<
41 dyn FnOnce()
42 -> Pin<Box<dyn Future<Output = crate::error::Result<Arc<dyn ModelProvider>>> + Send>>
43 + Send,
44>;
45
46pub struct AgentBuilder {
72 provider_factory: Option<ProviderFactory>,
73 tools: Vec<Box<dyn DynTool>>,
74 system_prompt: Option<String>,
75 max_concurrent_tools: usize,
76 pub(super) grant_store: Option<Box<dyn GrantStore>>,
78 pub(super) authorization_policy: ToolAuthorizationPolicy,
80 pub(super) authorization_timeout: Duration,
82 trusted_tools: Vec<String>,
84 conversation_manager: Option<BoxedConversationManager>,
85 #[cfg(feature = "session")]
86 session_store: Option<Arc<dyn SessionStore>>,
87 #[cfg(feature = "mcp")]
89 pub(super) mcp_servers: Vec<crate::mcp::McpServerConfig>,
90 #[cfg(feature = "mcp")]
91 pub(super) mcp_config_files: Vec<std::path::PathBuf>,
92 context_sources: Vec<ContextSource>,
95 context_config: ContextConfig,
97}
98
99impl Default for AgentBuilder {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl AgentBuilder {
106 pub fn new() -> Self {
108 Self {
109 provider_factory: None,
110 tools: Vec::new(),
111 system_prompt: None,
112 max_concurrent_tools: DEFAULT_MAX_CONCURRENT_TOOLS,
113 grant_store: None,
114 authorization_policy: ToolAuthorizationPolicy::default(), authorization_timeout: DEFAULT_PERMISSION_TIMEOUT,
116 trusted_tools: Vec::new(),
117 conversation_manager: None,
118 #[cfg(feature = "session")]
119 session_store: None,
120 #[cfg(feature = "mcp")]
121 mcp_servers: Vec::new(),
122 #[cfg(feature = "mcp")]
123 mcp_config_files: Vec::new(),
124 context_sources: Vec::new(),
125 context_config: ContextConfig::default(),
126 }
127 }
128
129 #[cfg(feature = "bedrock")]
143 pub fn bedrock(mut self, model: impl BedrockModel + 'static) -> Self {
144 self.provider_factory = Some(Box::new(move || {
145 Box::pin(async move {
146 let provider = BedrockProvider::new(model).await?;
147 Ok(Arc::new(provider) as Arc<dyn ModelProvider>)
148 })
149 }));
150 self
151 }
152
153 #[cfg(feature = "anthropic")]
164 pub fn anthropic(
165 mut self,
166 model: impl AnthropicModel + 'static,
167 api_key: impl Into<String>,
168 ) -> Self {
169 let api_key = api_key.into();
170 self.provider_factory = Some(Box::new(move || {
171 Box::pin(async move {
172 let provider = AnthropicProvider::new(api_key, model)?;
173 Ok(Arc::new(provider) as Arc<dyn ModelProvider>)
174 })
175 }));
176 self
177 }
178
179 #[cfg(feature = "anthropic")]
192 pub fn anthropic_from_env(mut self, model: impl AnthropicModel + 'static) -> Self {
193 self.provider_factory = Some(Box::new(move || {
194 Box::pin(async move {
195 let provider = AnthropicProvider::from_env(model)?;
196 Ok(Arc::new(provider) as Arc<dyn ModelProvider>)
197 })
198 }));
199 self
200 }
201
202 pub fn provider(mut self, provider: impl ModelProvider + 'static) -> Self {
220 let provider = Arc::new(provider) as Arc<dyn ModelProvider>;
221 self.provider_factory = Some(Box::new(move || Box::pin(async move { Ok(provider) })));
222 self
223 }
224
225 pub fn add_tool(mut self, tool: impl Tool + 'static) -> Self {
238 self.tools.push(box_tool(tool));
239 self
240 }
241
242 pub fn add_trusted_tool(mut self, tool: impl Tool + 'static) -> Self {
258 let tool_name = tool.name().to_string();
259 self.tools.push(box_tool(tool));
260 self.trusted_tools.push(tool_name);
261 self
262 }
263
264 pub fn add_tools(mut self, tools: impl IntoIterator<Item = Box<dyn DynTool>>) -> Self {
288 self.tools.extend(tools);
289 self
290 }
291
292 pub fn add_trusted_tools(mut self, tools: impl IntoIterator<Item = Box<dyn DynTool>>) -> Self {
310 for tool in tools {
311 let tool_name = tool.name().to_string();
312 self.tools.push(tool);
313 self.trusted_tools.push(tool_name);
314 }
315 self
316 }
317
318 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
320 self.system_prompt = Some(prompt.into());
321 self
322 }
323
324 pub fn with_max_concurrent_tools(mut self, max: usize) -> Self {
326 self.max_concurrent_tools = max;
327 self
328 }
329
330 pub fn with_conversation_manager(
336 mut self,
337 manager: impl crate::conversation::ConversationManager + 'static,
338 ) -> Self {
339 self.conversation_manager = Some(Box::new(manager));
340 self
341 }
342
343 #[cfg(feature = "session")]
345 pub fn with_session_store(mut self, store: impl SessionStore + 'static) -> Self {
346 self.session_store = Some(Arc::new(store));
347 self
348 }
349
350 pub fn add_context(mut self, content: impl Into<String>) -> Self {
366 self.context_sources.push(ContextSource::Content {
367 content: content.into(),
368 });
369 self
370 }
371
372 pub fn add_context_file(mut self, path: impl Into<String>) -> Self {
393 self.context_sources.push(ContextSource::File {
394 path: path.into(),
395 required: true,
396 });
397 self
398 }
399
400 pub fn add_optional_context_file(mut self, path: impl Into<String>) -> Self {
414 self.context_sources.push(ContextSource::File {
415 path: path.into(),
416 required: false,
417 });
418 self
419 }
420
421 pub fn add_context_files(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
435 self.context_sources.push(ContextSource::Files {
436 paths: paths.into_iter().map(|p| p.into()).collect(),
437 required: true,
438 });
439 self
440 }
441
442 pub fn add_optional_context_files(
456 mut self,
457 paths: impl IntoIterator<Item = impl Into<String>>,
458 ) -> Self {
459 self.context_sources.push(ContextSource::Files {
460 paths: paths.into_iter().map(|p| p.into()).collect(),
461 required: false,
462 });
463 self
464 }
465
466 pub fn add_context_files_glob(mut self, pattern: impl Into<String>) -> Self {
482 self.context_sources.push(ContextSource::Glob {
483 pattern: pattern.into(),
484 });
485 self
486 }
487
488 pub fn with_context_config(mut self, config: ContextConfig) -> Self {
505 self.context_config = config;
506 self
507 }
508
509 pub async fn build(self) -> crate::error::Result<Agent> {
532 let provider_factory = self
533 .provider_factory
534 .ok_or_else(|| crate::error::Error::Config(
535 "No provider configured. Call .bedrock(), .anthropic(), or .provider() before .build()".to_string()
536 ))?;
537
538 let provider = provider_factory().await?;
539
540 let conversation_manager = self
541 .conversation_manager
542 .unwrap_or_else(|| Box::new(SlidingWindowConversationManager::new()));
543
544 let authorizer = match self.grant_store {
547 Some(store) => ToolCallAuthorizer::with_boxed_store(store),
548 None => ToolCallAuthorizer::new(),
549 }
550 .with_authorization_policy(self.authorization_policy);
551
552 for tool_name in &self.trusted_tools {
554 authorizer.grant_tool(tool_name).await?;
555 }
556
557 #[allow(unused_mut)]
558 let mut agent = Agent {
559 provider,
560 system_prompt: self.system_prompt,
561 max_concurrent_tools: self.max_concurrent_tools,
562 tools: self.tools,
563 hooks: Arc::new(parking_lot::RwLock::new(HashMap::new())),
564 next_hook_id: AtomicU64::new(0),
565 authorizer: Arc::new(RwLock::new(authorizer)),
566 authorization_timeout: self.authorization_timeout,
567 pending_authorizations: Arc::new(RwLock::new(HashMap::new())),
568 #[cfg(feature = "mcp")]
569 mcp_clients: Vec::new(),
570 conversation_manager: parking_lot::RwLock::new(conversation_manager),
571 #[cfg(feature = "session")]
572 session_store: self.session_store,
573 context_sources: self.context_sources,
575 context_config: self.context_config,
576 last_context_result: parking_lot::RwLock::new(None),
577 };
578
579 #[cfg(feature = "mcp")]
581 {
582 super::mcp::connect_mcp_servers(&mut agent, self.mcp_servers, self.mcp_config_files)
583 .await?;
584 }
585
586 Ok(agent)
587 }
588}
589
590impl Agent {
591 pub fn builder() -> AgentBuilder {
612 AgentBuilder::new()
613 }
614
615 }
618
#[cfg(test)]
mod tests {
    use super::*;
    use crate::box_tools;
    use crate::conversation::SimpleConversationManager;
    use crate::provider::{ModelProvider, ProviderError};
    use crate::types::{ContentBlock, Message, Role, StopReason, ToolDefinition};
    use crate::ModelResponse;

    /// Provider stub that always answers with a single "ok" text block.
    #[derive(Clone)]
    struct MockProvider;

    #[async_trait::async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "MockProvider"
        }

        fn max_context_tokens(&self) -> usize {
            200_000
        }

        fn max_output_tokens(&self) -> usize {
            8_192
        }

        async fn generate(
            &self,
            _messages: Vec<Message>,
            _tools: Vec<ToolDefinition>,
            _system_prompt: Option<String>,
        ) -> Result<ModelResponse, ProviderError> {
            Ok(ModelResponse {
                message: Message {
                    role: Role::Assistant,
                    content: vec![ContentBlock::Text("ok".to_string())],
                },
                stop_reason: StopReason::EndTurn,
                usage: None,
            })
        }
    }

    #[test]
    fn test_builder_creation() {
        let builder = Agent::builder();
        assert!(builder.provider_factory.is_none());
        assert!(builder.tools.is_empty());
        assert!(builder.system_prompt.is_none());
    }

    #[test]
    fn test_builder_default() {
        let builder = AgentBuilder::default();
        assert!(builder.provider_factory.is_none());
        assert_eq!(builder.max_concurrent_tools, DEFAULT_MAX_CONCURRENT_TOOLS);
        assert_eq!(builder.authorization_timeout, DEFAULT_PERMISSION_TIMEOUT);
        assert!(builder.trusted_tools.is_empty());
        assert!(builder.context_sources.is_empty());
        assert!(builder.conversation_manager.is_none());
    }

    #[test]
    fn test_builder_system_prompt() {
        let builder = Agent::builder().with_system_prompt("Test prompt");
        assert_eq!(builder.system_prompt, Some("Test prompt".to_string()));
    }

    #[test]
    fn test_builder_max_concurrent_tools() {
        let builder = Agent::builder().with_max_concurrent_tools(4);
        assert_eq!(builder.max_concurrent_tools, 4);
    }

    #[test]
    fn test_builder_conversation_manager() {
        let builder =
            Agent::builder().with_conversation_manager(SimpleConversationManager::new(100));
        assert!(builder.conversation_manager.is_some());
    }

    #[tokio::test]
    async fn test_build_with_provider() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .build()
            .await
            .unwrap();

        assert_eq!(agent.provider.name(), "MockProvider");
    }

    #[tokio::test]
    async fn test_build_with_system_prompt() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .with_system_prompt("Be helpful")
            .build()
            .await
            .unwrap();

        assert_eq!(agent.system_prompt, Some("Be helpful".to_string()));
    }

    #[tokio::test]
    async fn test_build_with_conversation_manager() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .with_conversation_manager(SimpleConversationManager::new(100))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.provider.name(), "MockProvider");
    }

    #[tokio::test]
    async fn test_build_without_provider_fails() {
        let result = Agent::builder().build().await;
        match result {
            Err(err) => assert!(err.is_config()),
            Ok(_) => panic!("Expected error when building without provider"),
        }
    }

    #[tokio::test]
    async fn test_builder_chaining() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .with_system_prompt("Test")
            .with_max_concurrent_tools(8)
            .with_authorization_timeout(Duration::from_secs(60))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.system_prompt, Some("Test".to_string()));
        assert_eq!(agent.max_concurrent_tools, 8);
        assert_eq!(agent.authorization_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_builder_add_tool_single() {
        use crate::tool::{Tool, ToolError, ToolResult};
        use schemars::JsonSchema;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Deserialize, Serialize, JsonSchema)]
        #[allow(dead_code)] // field exists only to give the tool a schema
        struct TestInput {
            value: String,
        }

        struct TestTool;

        impl Tool for TestTool {
            type Input = TestInput;
            fn name(&self) -> &str {
                "test_tool"
            }
            fn description(&self) -> &str {
                "A test tool"
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text("result"))
            }
        }

        let builder = Agent::builder().add_tool(TestTool);
        assert_eq!(builder.tools.len(), 1);
        assert_eq!(builder.tools[0].name(), "test_tool");
        assert!(builder.trusted_tools.is_empty());
    }

    #[test]
    fn test_builder_add_tools_multiple() {
        use crate::tool::{Tool, ToolError, ToolResult};
        use schemars::JsonSchema;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Deserialize, Serialize, JsonSchema)]
        #[allow(dead_code)] // field exists only to give the tool a schema
        struct TestInput {
            value: String,
        }

        #[derive(Clone)]
        struct TestTool {
            name: &'static str,
            description: &'static str,
        }

        impl Tool for TestTool {
            type Input = TestInput;
            fn name(&self) -> &str {
                self.name
            }
            fn description(&self) -> &str {
                self.description
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text(self.name))
            }
        }

        let builder = Agent::builder().add_tools(box_tools![
            TestTool {
                name: "tool1",
                description: "First tool",
            },
            TestTool {
                name: "tool2",
                description: "Second tool",
            },
            TestTool {
                name: "tool3",
                description: "Third tool",
            },
        ]);

        assert_eq!(builder.tools.len(), 3);
        assert_eq!(builder.tools[0].name(), "tool1");
        assert_eq!(builder.tools[1].name(), "tool2");
        assert_eq!(builder.tools[2].name(), "tool3");
    }

    #[test]
    fn test_builder_add_tools_empty() {
        let builder = Agent::builder().add_tools(box_tools![]);

        assert!(builder.tools.is_empty());
    }

    #[test]
    fn test_builder_add_tool_and_add_tools_chaining() {
        use crate::tool::{Tool, ToolError, ToolResult};
        use schemars::JsonSchema;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Deserialize, Serialize, JsonSchema)]
        struct TestInput {}

        struct Tool1;
        impl Tool for Tool1 {
            type Input = TestInput;
            fn name(&self) -> &str {
                "tool1"
            }
            fn description(&self) -> &str {
                "First"
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text("1"))
            }
        }

        #[derive(Clone)]
        struct Tool2;
        impl Tool for Tool2 {
            type Input = TestInput;
            fn name(&self) -> &str {
                "tool2"
            }
            fn description(&self) -> &str {
                "Second"
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text("2"))
            }
        }

        let builder = Agent::builder()
            .add_tool(Tool1)
            .add_tools(box_tools![Tool2, Tool2]);

        assert_eq!(builder.tools.len(), 3);
        assert_eq!(builder.tools[0].name(), "tool1");
        assert_eq!(builder.tools[1].name(), "tool2");
        assert_eq!(builder.tools[2].name(), "tool2");
    }

    #[test]
    fn test_builder_add_trusted_tool_records_name() {
        use crate::tool::{Tool, ToolError, ToolResult};
        use schemars::JsonSchema;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Deserialize, Serialize, JsonSchema)]
        struct TestInput {}

        struct TrustedTool;
        impl Tool for TrustedTool {
            type Input = TestInput;
            fn name(&self) -> &str {
                "trusted_tool"
            }
            fn description(&self) -> &str {
                "A trusted tool"
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text("trusted"))
            }
        }

        let builder = Agent::builder().add_trusted_tool(TrustedTool);
        assert_eq!(builder.tools.len(), 1);
        assert_eq!(builder.tools[0].name(), "trusted_tool");
        assert_eq!(builder.trusted_tools, vec!["trusted_tool".to_string()]);
    }

    #[test]
    fn test_builder_add_trusted_tools_records_all_names_in_order() {
        use crate::tool::{Tool, ToolError, ToolResult};
        use schemars::JsonSchema;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Deserialize, Serialize, JsonSchema)]
        struct TestInput {}

        #[derive(Clone)]
        struct NamedTool {
            tool_name: &'static str,
        }

        impl Tool for NamedTool {
            type Input = TestInput;
            fn name(&self) -> &str {
                self.tool_name
            }
            fn description(&self) -> &str {
                "Named tool"
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text(self.tool_name))
            }
        }

        let builder = Agent::builder().add_trusted_tools(box_tools![
            NamedTool { tool_name: "alpha" },
            NamedTool { tool_name: "beta" },
        ]);

        assert_eq!(builder.tools.len(), 2);
        assert_eq!(builder.tools[0].name(), "alpha");
        assert_eq!(builder.tools[1].name(), "beta");
        assert_eq!(
            builder.trusted_tools,
            vec!["alpha".to_string(), "beta".to_string()]
        );
    }

    #[test]
    fn test_builder_context_sources_preserve_order_and_flags() {
        let builder = Agent::builder()
            .add_context("inline")
            .add_context_file("required.md")
            .add_optional_context_file("optional.md")
            .add_context_files(vec!["a.md", "b.md"])
            .add_optional_context_files(vec!["c.md"])
            .add_context_files_glob("docs/*.md");

        assert_eq!(builder.context_sources.len(), 6);
        assert!(matches!(
            &builder.context_sources[0],
            ContextSource::Content { content } if content == "inline"
        ));
        assert!(matches!(
            &builder.context_sources[1],
            ContextSource::File { path, required: true } if path == "required.md"
        ));
        assert!(matches!(
            &builder.context_sources[2],
            ContextSource::File { path, required: false } if path == "optional.md"
        ));
        assert!(matches!(
            &builder.context_sources[3],
            ContextSource::Files { required: true, .. }
        ));
        assert!(matches!(
            &builder.context_sources[4],
            ContextSource::Files { required: false, .. }
        ));
        assert!(matches!(
            &builder.context_sources[5],
            ContextSource::Glob { pattern } if pattern == "docs/*.md"
        ));
    }

    #[tokio::test]
    async fn test_build_with_add_tools() {
        use crate::tool::{Tool, ToolError, ToolResult};
        use schemars::JsonSchema;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Deserialize, Serialize, JsonSchema)]
        struct TestInput {}

        #[derive(Clone)]
        struct NamedTool {
            tool_name: &'static str,
            tool_desc: &'static str,
        }

        impl Tool for NamedTool {
            type Input = TestInput;
            fn name(&self) -> &str {
                self.tool_name
            }
            fn description(&self) -> &str {
                self.tool_desc
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text(self.tool_name))
            }
        }

        let agent = Agent::builder()
            .provider(MockProvider)
            .add_tools(box_tools![
                NamedTool {
                    tool_name: "calculator",
                    tool_desc: "Calculates things",
                },
                NamedTool {
                    tool_name: "weather",
                    tool_desc: "Gets weather",
                },
            ])
            .build()
            .await
            .unwrap();

        let tools = agent.list_tools();
        assert_eq!(tools.len(), 2);

        let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"calculator"));
        assert!(names.contains(&"weather"));
    }
}