claude_agent/agent/options/
builder.rs1use std::path::PathBuf;
27use std::sync::Arc;
28use std::time::Duration;
29
30use crate::auth::{Credential, OAuthConfig};
31use crate::budget::TenantBudgetManager;
32use crate::client::{CloudProvider, FallbackConfig, ModelConfig, ProviderConfig};
33use crate::common::IndexRegistry;
34use crate::context::{LeveledMemoryProvider, RuleIndex};
35use crate::hooks::{Hook, HookManager};
36use crate::output_style::OutputStyle;
37use crate::permissions::{PermissionMode, PermissionPolicy, PermissionRule};
38use crate::skills::SkillIndex;
39use crate::subagents::{SubagentIndex, builtin_subagents};
40use crate::tools::{Tool, ToolAccess};
41
42use crate::agent::config::{AgentConfig, SystemPromptMode};
43
/// Default count of messages kept when the conversation is compacted.
///
/// NOTE(review): presumably the default behind `compact_keep_messages`
/// in the execution config — confirm against `AgentConfig`'s `Default`.
pub const DEFAULT_COMPACT_KEEP_MESSAGES: usize = 4;
46
47#[derive(Default)]
51pub struct AgentBuilder {
52 pub(super) config: AgentConfig,
53 pub(super) credential: Option<Credential>,
54 pub(super) auth_type: Option<crate::auth::Auth>,
55 pub(super) oauth_config: Option<OAuthConfig>,
56 pub(super) cloud_provider: Option<CloudProvider>,
57 pub(super) model_config: Option<ModelConfig>,
58 pub(super) provider_config: Option<ProviderConfig>,
59 pub(super) skill_registry: Option<IndexRegistry<SkillIndex>>,
60 pub(super) subagent_registry: Option<IndexRegistry<SubagentIndex>>,
61 pub(super) rule_indices: Vec<RuleIndex>,
62 pub(super) hooks: HookManager,
63 pub(super) custom_tools: Vec<Arc<dyn Tool>>,
64 pub(super) memory_provider: Option<LeveledMemoryProvider>,
65 pub(super) sandbox_settings: Option<crate::config::SandboxSettings>,
66 pub(super) initial_messages: Option<Vec<crate::types::Message>>,
67 pub(super) resume_session_id: Option<String>,
68 pub(super) resumed_session: Option<crate::session::Session>,
69 pub(super) tenant_budget_manager: Option<TenantBudgetManager>,
70 pub(super) fallback_config: Option<FallbackConfig>,
71 pub(super) output_style_name: Option<String>,
72 pub(super) mcp_configs: std::collections::HashMap<String, crate::mcp::McpServerConfig>,
73 pub(super) mcp_manager: Option<std::sync::Arc<crate::mcp::McpManager>>,
74 pub(super) mcp_toolset_registry: Option<crate::mcp::McpToolsetRegistry>,
75 pub(super) tool_search_config: Option<crate::tools::ToolSearchConfig>,
76 pub(super) tool_search_manager: Option<std::sync::Arc<crate::tools::ToolSearchManager>>,
77 pub(super) session_manager: Option<crate::session::SessionManager>,
78
79 pub(super) load_enterprise: bool,
82 pub(super) load_user: bool,
83 pub(super) load_project: bool,
84 pub(super) load_local: bool,
85
86 #[cfg(feature = "aws")]
87 pub(super) aws_region: Option<String>,
88 #[cfg(feature = "gcp")]
89 pub(super) gcp_project: Option<String>,
90 #[cfg(feature = "gcp")]
91 pub(super) gcp_region: Option<String>,
92 #[cfg(feature = "azure")]
93 pub(super) azure_resource: Option<String>,
94}
95
96impl AgentBuilder {
97 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn agent_config(mut self, config: AgentConfig) -> Self {
108 self.config = config;
109 self
110 }
111
112 pub fn provider_config(mut self, config: ProviderConfig) -> Self {
114 self.provider_config = Some(config);
115 self
116 }
117
118 pub async fn auth(mut self, auth: impl Into<crate::auth::Auth>) -> crate::Result<Self> {
142 let auth = auth.into();
143
144 #[allow(unreachable_patterns)]
145 match &auth {
146 #[cfg(feature = "aws")]
147 crate::auth::Auth::Bedrock { region } => {
148 self.cloud_provider = Some(CloudProvider::Bedrock);
149 self.aws_region = Some(region.clone());
150 self.model_config = Some(ModelConfig::bedrock());
151 self = self.apply_provider_models();
152 }
153 #[cfg(feature = "gcp")]
154 crate::auth::Auth::Vertex { project, region } => {
155 self.cloud_provider = Some(CloudProvider::Vertex);
156 self.gcp_project = Some(project.clone());
157 self.gcp_region = Some(region.clone());
158 self.model_config = Some(ModelConfig::vertex());
159 self = self.apply_provider_models();
160 }
161 #[cfg(feature = "azure")]
162 crate::auth::Auth::Foundry { resource } => {
163 self.cloud_provider = Some(CloudProvider::Foundry);
164 self.azure_resource = Some(resource.clone());
165 self.model_config = Some(ModelConfig::foundry());
166 self = self.apply_provider_models();
167 }
168 _ => {}
169 }
170
171 let credential = auth.resolve().await?;
172 if !credential.is_default() {
173 self.credential = Some(credential);
174 }
175
176 self.auth_type = Some(auth);
177
178 if self.supports_server_tools() {
179 self.config.server_tools = crate::agent::config::ServerToolsConfig::all();
180 }
181
182 Ok(self)
183 }
184
185 pub fn oauth_config(mut self, config: OAuthConfig) -> Self {
187 self.oauth_config = Some(config);
188 self
189 }
190
191 pub fn supports_server_tools(&self) -> bool {
193 self.auth_type
194 .as_ref()
195 .map(|a| a.supports_server_tools())
196 .unwrap_or(true)
197 }
198
199 pub fn models(mut self, config: ModelConfig) -> Self {
205 self.model_config = Some(config.clone());
206 self.config.model.primary = config.primary;
207 self.config.model.small = config.small;
208 self
209 }
210
211 #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
212 fn apply_provider_models(mut self) -> Self {
213 if let Some(ref config) = self.model_config {
214 if self.config.model.primary
215 == crate::agent::config::AgentModelConfig::default().primary
216 {
217 self.config.model.primary = config.primary.clone();
218 }
219 if self.config.model.small == crate::agent::config::AgentModelConfig::default().small {
220 self.config.model.small = config.small.clone();
221 }
222 }
223 self
224 }
225
226 pub fn model(mut self, model: impl Into<String>) -> Self {
230 self.config.model.primary = model.into();
231 self
232 }
233
234 pub fn small_model(mut self, model: impl Into<String>) -> Self {
238 self.config.model.small = model.into();
239 self
240 }
241
242 pub fn max_tokens(mut self, tokens: u32) -> Self {
247 self.config.model.max_tokens = tokens;
248 self
249 }
250
251 pub fn extended_context(mut self, enabled: bool) -> Self {
256 self.config.model.extended_context = enabled;
257 self
258 }
259
260 pub fn tools(mut self, access: ToolAccess) -> Self {
272 self.config.security.tool_access = access;
273 self
274 }
275
276 pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
278 self.custom_tools.push(Arc::new(tool));
279 self
280 }
281
282 pub fn working_dir(mut self, path: impl Into<PathBuf>) -> Self {
288 self.config.working_dir = Some(path.into());
289 self
290 }
291
292 pub fn max_iterations(mut self, max: usize) -> Self {
296 self.config.execution.max_iterations = max;
297 self
298 }
299
300 pub fn timeout(mut self, timeout: Duration) -> Self {
304 self.config.execution.timeout = Some(timeout);
305 self
306 }
307
308 pub fn chunk_timeout(mut self, timeout: Duration) -> Self {
318 self.config.execution.chunk_timeout = timeout;
319 self
320 }
321
322 pub fn auto_compact(mut self, enabled: bool) -> Self {
326 self.config.execution.auto_compact = enabled;
327 self
328 }
329
330 pub fn compact_keep_messages(mut self, count: usize) -> Self {
334 self.config.execution.compact_keep_messages = count;
335 self
336 }
337
338 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
344 self.config.prompt.system_prompt = Some(prompt.into());
345 self
346 }
347
348 pub fn system_prompt_mode(mut self, mode: SystemPromptMode) -> Self {
350 self.config.prompt.system_prompt_mode = mode;
351 self
352 }
353
354 pub fn append_system_prompt(mut self, prompt: impl Into<String>) -> Self {
356 self.config.prompt.system_prompt_mode = SystemPromptMode::Append;
357 self.config.prompt.system_prompt = Some(prompt.into());
358 self
359 }
360
361 pub fn output_style(mut self, style: OutputStyle) -> Self {
363 self.config.prompt.output_style = Some(style);
364 self
365 }
366
367 pub fn output_style_name(mut self, name: impl Into<String>) -> Self {
369 self.output_style_name = Some(name.into());
370 self
371 }
372
373 pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
375 self.config.prompt.output_schema = Some(schema);
376 self
377 }
378
379 pub fn structured_output<T: schemars::JsonSchema>(mut self) -> Self {
381 let schema = schemars::schema_for!(T);
382 self.config.prompt.output_schema = serde_json::to_value(schema).ok();
383 self
384 }
385
386 pub fn permission_policy(mut self, policy: PermissionPolicy) -> Self {
392 self.config.security.permission_policy = policy;
393 self
394 }
395
396 pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
398 self.config.security.permission_policy.mode = mode;
399 self
400 }
401
402 pub fn allow_tool(mut self, pattern: impl Into<String>) -> Self {
404 let pattern = pattern.into();
405 let rule = if pattern.contains('(') && pattern.contains(')') {
406 PermissionRule::allow_scoped(&pattern)
407 } else {
408 PermissionRule::allow(&pattern)
409 };
410 self.config.security.permission_policy.rules.push(rule);
411 self
412 }
413
414 pub fn deny_tool(mut self, pattern: impl Into<String>) -> Self {
416 let pattern = pattern.into();
417 let rule = if pattern.contains('(') && pattern.contains(')') {
418 PermissionRule::deny_scoped(&pattern)
419 } else {
420 PermissionRule::deny(&pattern)
421 };
422 self.config.security.permission_policy.rules.push(rule);
423 self
424 }
425
426 pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
432 self.config.security.env.insert(key.into(), value.into());
433 self
434 }
435
436 pub fn envs(
438 mut self,
439 vars: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
440 ) -> Self {
441 for (k, v) in vars {
442 self.config.security.env.insert(k.into(), v.into());
443 }
444 self
445 }
446
447 pub fn allow_domain(mut self, domain: impl Into<String>) -> Self {
453 self.sandbox_settings
454 .get_or_insert_with(crate::config::SandboxSettings::default)
455 .network
456 .allowed_domains
457 .insert(domain.into());
458 self
459 }
460
461 pub fn deny_domain(mut self, domain: impl Into<String>) -> Self {
463 self.sandbox_settings
464 .get_or_insert_with(crate::config::SandboxSettings::default)
465 .network
466 .blocked_domains
467 .insert(domain.into());
468 self
469 }
470
471 pub fn sandbox_enabled(mut self, enabled: bool) -> Self {
473 self.sandbox_settings
474 .get_or_insert_with(crate::config::SandboxSettings::default)
475 .enabled = enabled;
476 self
477 }
478
479 pub fn exclude_command(mut self, command: impl Into<String>) -> Self {
481 self.sandbox_settings
482 .get_or_insert_with(crate::config::SandboxSettings::default)
483 .excluded_commands
484 .push(command.into());
485 self
486 }
487
488 pub fn max_budget_usd(mut self, amount: f64) -> Self {
494 self.config.budget.max_cost_usd = Some(amount);
495 self
496 }
497
498 pub fn tenant_id(mut self, id: impl Into<String>) -> Self {
500 self.config.budget.tenant_id = Some(id.into());
501 self
502 }
503
504 pub fn tenant_budget_manager(mut self, manager: TenantBudgetManager) -> Self {
506 self.tenant_budget_manager = Some(manager);
507 self
508 }
509
510 pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
512 self.config.budget.fallback_model = Some(model.into());
513 self
514 }
515
516 pub fn fallback(mut self, config: FallbackConfig) -> Self {
518 self.fallback_config = Some(config);
519 self
520 }
521
522 pub fn session_manager(mut self, manager: crate::session::SessionManager) -> Self {
528 self.session_manager = Some(manager);
529 self
530 }
531
532 pub async fn fork_session(mut self, session_id: impl Into<String>) -> crate::Result<Self> {
534 let manager = self.session_manager.take().unwrap_or_default();
535 let session_id_str: String = session_id.into();
536 let original_id = crate::session::SessionId::from(session_id_str);
537 let forked = manager
538 .fork(&original_id)
539 .await
540 .map_err(|e| crate::Error::Session(e.to_string()))?;
541
542 self.initial_messages = Some(forked.to_api_messages());
543 self.resume_session_id = Some(forked.id.to_string());
544 self.session_manager = Some(manager);
545 Ok(self)
546 }
547
548 pub async fn resume_session(mut self, session_id: impl Into<String>) -> crate::Result<Self> {
550 let session_id_str: String = session_id.into();
551 let id = crate::session::SessionId::from(session_id_str);
552 let manager = self.session_manager.take().unwrap_or_default();
553 let session = manager.get(&id).await?;
554
555 let messages: Vec<crate::types::Message> = session
556 .messages
557 .iter()
558 .map(|m| crate::types::Message {
559 role: m.role,
560 content: m.content.clone(),
561 })
562 .collect();
563
564 self.initial_messages = Some(messages);
565 self.resume_session_id = Some(id.to_string());
566 self.resumed_session = Some(session);
567 self.session_manager = Some(manager);
568 Ok(self)
569 }
570
571 pub fn messages(mut self, messages: Vec<crate::types::Message>) -> Self {
573 self.initial_messages = Some(messages);
574 self
575 }
576
577 pub fn mcp_server(
583 mut self,
584 name: impl Into<String>,
585 config: crate::mcp::McpServerConfig,
586 ) -> Self {
587 self.mcp_configs.insert(name.into(), config);
588 self
589 }
590
591 pub fn mcp_stdio(
593 mut self,
594 name: impl Into<String>,
595 command: impl Into<String>,
596 args: Vec<String>,
597 ) -> Self {
598 self.mcp_configs.insert(
599 name.into(),
600 crate::mcp::McpServerConfig::Stdio {
601 command: command.into(),
602 args,
603 env: std::collections::HashMap::new(),
604 },
605 );
606 self
607 }
608
609 pub fn mcp_manager(mut self, manager: crate::mcp::McpManager) -> Self {
611 self.mcp_manager = Some(std::sync::Arc::new(manager));
612 self
613 }
614
615 pub fn shared_mcp_manager(mut self, manager: std::sync::Arc<crate::mcp::McpManager>) -> Self {
617 self.mcp_manager = Some(manager);
618 self
619 }
620
621 pub fn mcp_toolset(mut self, toolset: crate::mcp::McpToolset) -> Self {
623 self.mcp_toolset_registry
624 .get_or_insert_with(crate::mcp::McpToolsetRegistry::new)
625 .register(toolset);
626 self
627 }
628
629 pub fn with_tool_search(mut self) -> Self {
635 self.tool_search_config = Some(crate::tools::ToolSearchConfig::default());
636 self
637 }
638
639 pub fn tool_search_config(mut self, config: crate::tools::ToolSearchConfig) -> Self {
641 self.tool_search_config = Some(config);
642 self
643 }
644
645 pub fn tool_search_threshold(mut self, threshold: f64) -> Self {
647 let config = self
648 .tool_search_config
649 .get_or_insert_with(crate::tools::ToolSearchConfig::default);
650 config.threshold = threshold.clamp(0.0, 1.0);
651 self
652 }
653
654 pub fn tool_search_mode(mut self, mode: crate::tools::SearchMode) -> Self {
656 let config = self
657 .tool_search_config
658 .get_or_insert_with(crate::tools::ToolSearchConfig::default);
659 config.search_mode = mode;
660 self
661 }
662
663 pub fn always_load_tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
665 let config = self
666 .tool_search_config
667 .get_or_insert_with(crate::tools::ToolSearchConfig::default);
668 config.always_load = tools.into_iter().map(Into::into).collect();
669 self
670 }
671
672 pub fn shared_tool_search_manager(
674 mut self,
675 manager: std::sync::Arc<crate::tools::ToolSearchManager>,
676 ) -> Self {
677 self.tool_search_manager = Some(manager);
678 self
679 }
680
681 pub fn skill_registry(mut self, registry: IndexRegistry<SkillIndex>) -> Self {
687 self.skill_registry = Some(registry);
688 self
689 }
690
691 pub fn skill(mut self, skill: SkillIndex) -> Self {
693 self.skill_registry
694 .get_or_insert_with(IndexRegistry::new)
695 .register(skill);
696 self
697 }
698
699 pub fn rule_index(mut self, index: RuleIndex) -> Self {
701 self.rule_indices.push(index);
702 self
703 }
704
705 pub fn memory_content(mut self, content: impl Into<String>) -> Self {
707 self.memory_provider
708 .get_or_insert_with(LeveledMemoryProvider::new)
709 .add_content(content);
710 self
711 }
712
713 pub fn local_memory_content(mut self, content: impl Into<String>) -> Self {
715 self.memory_provider
716 .get_or_insert_with(LeveledMemoryProvider::new)
717 .add_local_content(content);
718 self
719 }
720
721 pub fn subagent_registry(mut self, registry: IndexRegistry<SubagentIndex>) -> Self {
727 self.subagent_registry = Some(registry);
728 self
729 }
730
731 pub fn subagent(mut self, subagent: SubagentIndex) -> Self {
733 self.subagent_registry
734 .get_or_insert_with(|| {
735 let mut registry = IndexRegistry::new();
736 registry.register_all(builtin_subagents());
737 registry
738 })
739 .register(subagent);
740 self
741 }
742
743 pub fn hook<H: Hook + 'static>(mut self, hook: H) -> Self {
749 self.hooks.register(hook);
750 self
751 }
752}
753
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::DEFAULT_MAX_TOKENS;

    /// Each `ToolAccess` constructor allows or denies the expected tools.
    #[test]
    fn test_tool_access() {
        let cases = [
            (ToolAccess::all(), "Read", true),
            (ToolAccess::none(), "Read", false),
            (ToolAccess::only(["Read", "Write"]), "Read", true),
            (ToolAccess::only(["Read", "Write"]), "Bash", false),
            (ToolAccess::except(["Bash"]), "Bash", false),
            (ToolAccess::except(["Bash"]), "Read", true),
        ];
        for (access, tool, expected) in cases {
            assert_eq!(access.is_allowed(tool), expected);
        }
    }

    /// A fresh builder starts from the client's default token limit.
    #[test]
    fn test_max_tokens_default() {
        let max_tokens = AgentBuilder::new().config.model.max_tokens;
        assert_eq!(max_tokens, DEFAULT_MAX_TOKENS);
    }

    /// `max_tokens` overrides the default limit.
    #[test]
    fn test_max_tokens_custom() {
        let limit = 16384;
        let builder = AgentBuilder::new().max_tokens(limit);
        assert_eq!(builder.config.model.max_tokens, limit);
    }
}
780}