1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::RwLock;
15
16use crate::conversation::{BoxedConversationManager, SlidingWindowConversationManager};
17use crate::permission::{GrantStore, ToolAuthorizationPolicy, ToolCallAuthorizer};
18use crate::provider::ModelProvider;
19use crate::tool::{box_tool, DynTool, Tool};
20
21use super::context::{ContextConfig, ContextSource};
22use super::types::{DEFAULT_MAX_CONCURRENT_TOOLS, DEFAULT_PERMISSION_TIMEOUT};
23use super::Agent;
24
25#[cfg(feature = "session")]
26use crate::session::SessionStore;
27
28#[cfg(feature = "bedrock")]
29use crate::model::BedrockModel;
30#[cfg(feature = "bedrock")]
31use crate::provider::BedrockProvider;
32
33#[cfg(feature = "anthropic")]
34use crate::model::AnthropicModel;
35#[cfg(feature = "anthropic")]
36use crate::provider::AnthropicProvider;
37
38type ProviderFactory = Box<
40 dyn FnOnce()
41 -> Pin<Box<dyn Future<Output = crate::error::Result<Arc<dyn ModelProvider>>> + Send>>
42 + Send,
43>;
44
45pub struct AgentBuilder {
71 provider_factory: Option<ProviderFactory>,
72 tools: Vec<Box<dyn DynTool>>,
73 system_prompt: Option<String>,
74 max_concurrent_tools: usize,
75 pub(super) grant_store: Option<Box<dyn GrantStore>>,
77 pub(super) authorization_policy: ToolAuthorizationPolicy,
79 pub(super) authorization_timeout: Duration,
81 conversation_manager: Option<BoxedConversationManager>,
82 #[cfg(feature = "session")]
83 session_store: Option<Arc<dyn SessionStore>>,
84 #[cfg(feature = "mcp")]
86 pub(super) mcp_servers: Vec<crate::mcp::McpServerConfig>,
87 #[cfg(feature = "mcp")]
88 pub(super) mcp_config_files: Vec<std::path::PathBuf>,
89 context_sources: Vec<ContextSource>,
92 context_config: ContextConfig,
94}
95
96impl Default for AgentBuilder {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102impl AgentBuilder {
103 pub fn new() -> Self {
105 Self {
106 provider_factory: None,
107 tools: Vec::new(),
108 system_prompt: None,
109 max_concurrent_tools: DEFAULT_MAX_CONCURRENT_TOOLS,
110 grant_store: None,
111 authorization_policy: ToolAuthorizationPolicy::default(), authorization_timeout: DEFAULT_PERMISSION_TIMEOUT,
113 conversation_manager: None,
114 #[cfg(feature = "session")]
115 session_store: None,
116 #[cfg(feature = "mcp")]
117 mcp_servers: Vec::new(),
118 #[cfg(feature = "mcp")]
119 mcp_config_files: Vec::new(),
120 context_sources: Vec::new(),
121 context_config: ContextConfig::default(),
122 }
123 }
124
125 #[cfg(feature = "bedrock")]
139 pub fn bedrock(mut self, model: impl BedrockModel + 'static) -> Self {
140 self.provider_factory = Some(Box::new(move || {
141 Box::pin(async move {
142 let provider = BedrockProvider::new(model).await?;
143 Ok(Arc::new(provider) as Arc<dyn ModelProvider>)
144 })
145 }));
146 self
147 }
148
149 #[cfg(feature = "anthropic")]
160 pub fn anthropic(
161 mut self,
162 model: impl AnthropicModel + 'static,
163 api_key: impl Into<String>,
164 ) -> Self {
165 let api_key = api_key.into();
166 self.provider_factory = Some(Box::new(move || {
167 Box::pin(async move {
168 let provider = AnthropicProvider::new(api_key, model)?;
169 Ok(Arc::new(provider) as Arc<dyn ModelProvider>)
170 })
171 }));
172 self
173 }
174
175 #[cfg(feature = "anthropic")]
188 pub fn anthropic_from_env(mut self, model: impl AnthropicModel + 'static) -> Self {
189 self.provider_factory = Some(Box::new(move || {
190 Box::pin(async move {
191 let provider = AnthropicProvider::from_env(model)?;
192 Ok(Arc::new(provider) as Arc<dyn ModelProvider>)
193 })
194 }));
195 self
196 }
197
198 pub fn provider(mut self, provider: impl ModelProvider + 'static) -> Self {
216 let provider = Arc::new(provider) as Arc<dyn ModelProvider>;
217 self.provider_factory = Some(Box::new(move || Box::pin(async move { Ok(provider) })));
218 self
219 }
220
221 pub fn add_tool(mut self, tool: impl Tool + 'static) -> Self {
234 self.tools.push(box_tool(tool));
235 self
236 }
237
238 pub fn add_tools(mut self, tools: impl IntoIterator<Item = Box<dyn DynTool>>) -> Self {
262 self.tools.extend(tools);
263 self
264 }
265
266 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
268 self.system_prompt = Some(prompt.into());
269 self
270 }
271
272 pub fn with_max_concurrent_tools(mut self, max: usize) -> Self {
274 self.max_concurrent_tools = max;
275 self
276 }
277
278 pub fn with_conversation_manager(
284 mut self,
285 manager: impl crate::conversation::ConversationManager + 'static,
286 ) -> Self {
287 self.conversation_manager = Some(Box::new(manager));
288 self
289 }
290
291 #[cfg(feature = "session")]
293 pub fn with_session_store(mut self, store: impl SessionStore + 'static) -> Self {
294 self.session_store = Some(Arc::new(store));
295 self
296 }
297
298 pub fn add_context(mut self, content: impl Into<String>) -> Self {
314 self.context_sources.push(ContextSource::Content {
315 content: content.into(),
316 });
317 self
318 }
319
320 pub fn add_context_file(mut self, path: impl Into<String>) -> Self {
341 self.context_sources.push(ContextSource::File {
342 path: path.into(),
343 required: true,
344 });
345 self
346 }
347
348 pub fn add_optional_context_file(mut self, path: impl Into<String>) -> Self {
362 self.context_sources.push(ContextSource::File {
363 path: path.into(),
364 required: false,
365 });
366 self
367 }
368
369 pub fn add_context_files(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
383 self.context_sources.push(ContextSource::Files {
384 paths: paths.into_iter().map(|p| p.into()).collect(),
385 required: true,
386 });
387 self
388 }
389
390 pub fn add_optional_context_files(
404 mut self,
405 paths: impl IntoIterator<Item = impl Into<String>>,
406 ) -> Self {
407 self.context_sources.push(ContextSource::Files {
408 paths: paths.into_iter().map(|p| p.into()).collect(),
409 required: false,
410 });
411 self
412 }
413
414 pub fn add_context_files_glob(mut self, pattern: impl Into<String>) -> Self {
430 self.context_sources.push(ContextSource::Glob {
431 pattern: pattern.into(),
432 });
433 self
434 }
435
436 pub fn with_context_config(mut self, config: ContextConfig) -> Self {
453 self.context_config = config;
454 self
455 }
456
457 pub async fn build(self) -> crate::error::Result<Agent> {
480 let provider_factory = self
481 .provider_factory
482 .ok_or_else(|| crate::error::Error::Config(
483 "No provider configured. Call .bedrock(), .anthropic(), or .provider() before .build()".to_string()
484 ))?;
485
486 let provider = provider_factory().await?;
487
488 let conversation_manager = self
489 .conversation_manager
490 .unwrap_or_else(|| Box::new(SlidingWindowConversationManager::new()));
491
492 let authorizer = match self.grant_store {
495 Some(store) => ToolCallAuthorizer::with_boxed_store(store),
496 None => ToolCallAuthorizer::new(),
497 }
498 .with_authorization_policy(self.authorization_policy);
499
500 #[allow(unused_mut)]
501 let mut agent = Agent {
502 provider,
503 system_prompt: self.system_prompt,
504 max_concurrent_tools: self.max_concurrent_tools,
505 tools: self.tools,
506 hooks: Arc::new(parking_lot::RwLock::new(Vec::new())),
507 authorizer: Arc::new(RwLock::new(authorizer)),
508 authorization_timeout: self.authorization_timeout,
509 pending_authorizations: Arc::new(RwLock::new(HashMap::new())),
510 #[cfg(feature = "mcp")]
511 mcp_clients: Vec::new(),
512 conversation_manager: parking_lot::RwLock::new(conversation_manager),
513 #[cfg(feature = "session")]
514 session_store: self.session_store,
515 context_sources: self.context_sources,
517 context_config: self.context_config,
518 last_context_result: parking_lot::RwLock::new(None),
519 };
520
521 #[cfg(feature = "mcp")]
523 {
524 super::mcp::connect_mcp_servers(&mut agent, self.mcp_servers, self.mcp_config_files)
525 .await?;
526 }
527
528 Ok(agent)
529 }
530}
531
532impl Agent {
533 pub fn builder() -> AgentBuilder {
554 AgentBuilder::new()
555 }
556
557 }
560
#[cfg(test)]
mod tests {
    use super::*;
    use crate::box_tools;
    use crate::conversation::SimpleConversationManager;
    use crate::provider::{ModelProvider, ProviderError};
    use crate::types::{ContentBlock, Message, Role, StopReason, ToolDefinition};
    use crate::ModelResponse;

    /// Minimal provider that always answers with a single "ok" text block.
    #[derive(Clone)]
    struct MockProvider;

    #[async_trait::async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "MockProvider"
        }

        fn max_context_tokens(&self) -> usize {
            200_000
        }

        fn max_output_tokens(&self) -> usize {
            8_192
        }

        async fn generate(
            &self,
            _messages: Vec<Message>,
            _tools: Vec<ToolDefinition>,
            _system_prompt: Option<String>,
        ) -> Result<ModelResponse, ProviderError> {
            Ok(ModelResponse {
                message: Message {
                    role: Role::Assistant,
                    content: vec![ContentBlock::Text("ok".to_string())],
                },
                stop_reason: StopReason::EndTurn,
                usage: None,
            })
        }
    }

    #[test]
    fn test_builder_creation() {
        let builder = Agent::builder();
        assert!(builder.provider_factory.is_none());
        assert!(builder.tools.is_empty());
        assert!(builder.system_prompt.is_none());
    }

    #[test]
    fn test_builder_default() {
        let builder = AgentBuilder::default();
        assert!(builder.provider_factory.is_none());
        assert_eq!(builder.max_concurrent_tools, DEFAULT_MAX_CONCURRENT_TOOLS);
        assert_eq!(builder.authorization_timeout, DEFAULT_PERMISSION_TIMEOUT);
    }

    #[test]
    fn test_builder_system_prompt() {
        let builder = Agent::builder().with_system_prompt("Test prompt");
        assert_eq!(builder.system_prompt, Some("Test prompt".to_string()));
    }

    #[test]
    fn test_builder_max_concurrent_tools() {
        let builder = Agent::builder().with_max_concurrent_tools(4);
        assert_eq!(builder.max_concurrent_tools, 4);
    }

    #[test]
    fn test_builder_conversation_manager() {
        let builder =
            Agent::builder().with_conversation_manager(SimpleConversationManager::new(100));
        assert!(builder.conversation_manager.is_some());
    }

    #[test]
    fn test_builder_context_sources_preserve_order_and_flags() {
        let builder = Agent::builder()
            .add_context("inline")
            .add_context_file("required.md")
            .add_optional_context_file("optional.md")
            .add_context_files(["a.md", "b.md"])
            .add_optional_context_files(vec!["c.md".to_string()])
            .add_context_files_glob("docs/*.md");

        let sources = &builder.context_sources;
        assert_eq!(sources.len(), 6);
        assert!(matches!(
            &sources[0],
            ContextSource::Content { content } if content == "inline"
        ));
        assert!(matches!(
            &sources[1],
            ContextSource::File { path, required: true } if path == "required.md"
        ));
        assert!(matches!(
            &sources[2],
            ContextSource::File { path, required: false } if path == "optional.md"
        ));
        assert!(matches!(
            &sources[3],
            ContextSource::Files { paths, required: true }
                if paths.iter().eq(["a.md", "b.md"].iter())
        ));
        assert!(matches!(
            &sources[4],
            ContextSource::Files { paths, required: false }
                if paths.iter().eq(["c.md"].iter())
        ));
        assert!(matches!(
            &sources[5],
            ContextSource::Glob { pattern } if pattern == "docs/*.md"
        ));
    }

    #[tokio::test]
    async fn test_build_with_provider() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .build()
            .await
            .unwrap();

        assert_eq!(agent.provider.name(), "MockProvider");
    }

    #[tokio::test]
    async fn test_build_with_system_prompt() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .with_system_prompt("Be helpful")
            .build()
            .await
            .unwrap();

        assert_eq!(agent.system_prompt, Some("Be helpful".to_string()));
    }

    #[tokio::test]
    async fn test_build_with_conversation_manager() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .with_conversation_manager(SimpleConversationManager::new(100))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.provider.name(), "MockProvider");
    }

    #[tokio::test]
    async fn test_build_without_provider_fails() {
        let result = Agent::builder().build().await;
        match result {
            Err(err) => assert!(err.is_config()),
            Ok(_) => panic!("Expected error when building without provider"),
        }
    }

    #[tokio::test]
    async fn test_builder_chaining() {
        let agent = Agent::builder()
            .provider(MockProvider)
            .with_system_prompt("Test")
            .with_max_concurrent_tools(8)
            .with_authorization_timeout(Duration::from_secs(60))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.system_prompt, Some("Test".to_string()));
        assert_eq!(agent.max_concurrent_tools, 8);
        assert_eq!(agent.authorization_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_builder_add_tool_single() {
        use crate::tool::{Tool, ToolError, ToolResult};
        use schemars::JsonSchema;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Deserialize, Serialize, JsonSchema)]
        #[allow(dead_code)]
        struct TestInput {
            value: String,
        }

        struct TestTool;

        impl Tool for TestTool {
            type Input = TestInput;
            fn name(&self) -> &str {
                "test_tool"
            }
            fn description(&self) -> &str {
                "A test tool"
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text("result"))
            }
        }

        let builder = Agent::builder().add_tool(TestTool);
        assert_eq!(builder.tools.len(), 1);
        assert_eq!(builder.tools[0].name(), "test_tool");
    }

    #[test]
    fn test_builder_add_tools_multiple() {
        use crate::tool::{Tool, ToolError, ToolResult};
        use schemars::JsonSchema;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Deserialize, Serialize, JsonSchema)]
        #[allow(dead_code)]
        struct TestInput {
            value: String,
        }

        #[derive(Clone)]
        struct TestTool {
            name: &'static str,
            description: &'static str,
        }

        impl Tool for TestTool {
            type Input = TestInput;
            fn name(&self) -> &str {
                self.name
            }
            fn description(&self) -> &str {
                self.description
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text(self.name))
            }
        }

        let builder = Agent::builder().add_tools(box_tools![
            TestTool {
                name: "tool1",
                description: "First tool",
            },
            TestTool {
                name: "tool2",
                description: "Second tool",
            },
            TestTool {
                name: "tool3",
                description: "Third tool",
            },
        ]);

        assert_eq!(builder.tools.len(), 3);
        assert_eq!(builder.tools[0].name(), "tool1");
        assert_eq!(builder.tools[1].name(), "tool2");
        assert_eq!(builder.tools[2].name(), "tool3");
    }

    #[test]
    fn test_builder_add_tools_empty() {
        // An empty batch must leave the tool list untouched; no tool type is needed.
        let builder = Agent::builder().add_tools(box_tools![]);

        assert!(builder.tools.is_empty());
    }

    #[test]
    fn test_builder_add_tool_and_add_tools_chaining() {
        use crate::tool::{Tool, ToolError, ToolResult};
        use schemars::JsonSchema;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Deserialize, Serialize, JsonSchema)]
        struct TestInput {}

        struct Tool1;
        impl Tool for Tool1 {
            type Input = TestInput;
            fn name(&self) -> &str {
                "tool1"
            }
            fn description(&self) -> &str {
                "First"
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text("1"))
            }
        }

        #[derive(Clone)]
        struct Tool2;
        impl Tool for Tool2 {
            type Input = TestInput;
            fn name(&self) -> &str {
                "tool2"
            }
            fn description(&self) -> &str {
                "Second"
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text("2"))
            }
        }

        let builder = Agent::builder()
            .add_tool(Tool1)
            .add_tools(box_tools![Tool2, Tool2]);

        assert_eq!(builder.tools.len(), 3);
        assert_eq!(builder.tools[0].name(), "tool1");
        assert_eq!(builder.tools[1].name(), "tool2");
        assert_eq!(builder.tools[2].name(), "tool2");
    }

    #[tokio::test]
    async fn test_build_with_add_tools() {
        use crate::tool::{Tool, ToolError, ToolResult};
        use schemars::JsonSchema;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Deserialize, Serialize, JsonSchema)]
        struct TestInput {}

        #[derive(Clone)]
        struct NamedTool {
            tool_name: &'static str,
            tool_desc: &'static str,
        }

        impl Tool for NamedTool {
            type Input = TestInput;
            fn name(&self) -> &str {
                self.tool_name
            }
            fn description(&self) -> &str {
                self.tool_desc
            }
            async fn execute(&self, _input: Self::Input) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::text(self.tool_name))
            }
        }

        let agent = Agent::builder()
            .provider(MockProvider)
            .add_tools(box_tools![
                NamedTool {
                    tool_name: "calculator",
                    tool_desc: "Calculates things",
                },
                NamedTool {
                    tool_name: "weather",
                    tool_desc: "Gets weather",
                },
            ])
            .build()
            .await
            .unwrap();

        let tools = agent.list_tools();
        assert_eq!(tools.len(), 2);

        let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"calculator"));
        assert!(names.contains(&"weather"));
    }
}
919}