Skip to main content

agent_base/engine/
builder.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::sync::atomic::AtomicU64;
4
5use tokio::sync::broadcast;
6
7use crate::llm::LlmClient;
8use crate::skill::{Skill, SkillDetailTool, SkillPrompter, LazySkillPrompter};
9use crate::tool::{Tool, ToolPolicy, ToolRegistry};
10use crate::types::{AgentConfig, ResponseFormat, RetryConfig};
11
12use super::approval::ApprovalHandler;
13use super::context::ContextWindowManager;
14use super::middleware::{Middleware, MiddlewareRef};
15use super::recovery::{StopOnError, ToolErrorRecovery};
16use super::session_store::{InMemorySessionStore, SessionStore};
17use super::AgentRuntime;
18
19pub struct AgentBuilder {
20    client: Arc<dyn LlmClient>,
21    config: AgentConfig,
22    tools: ToolRegistry,
23    approval_handler: Option<Arc<dyn ApprovalHandler>>,
24    tool_policy: Option<Arc<dyn ToolPolicy>>,
25    middlewares: Vec<MiddlewareRef>,
26    context_manager: Option<ContextWindowManager>,
27    session_store: Option<Arc<dyn SessionStore>>,
28    skills: Vec<Arc<dyn Skill>>,
29    skill_prompter: Option<Arc<dyn SkillPrompter>>,
30    skill_detail_tool_name: String,
31    disable_skill_prompt_injection: bool,
32    error_recovery: Option<Arc<dyn ToolErrorRecovery>>,
33}
34
35impl AgentBuilder {
36    pub fn new(client: Arc<dyn LlmClient>) -> Self {
37        Self {
38            client,
39            config: AgentConfig::default(),
40            tools: ToolRegistry::default(),
41            approval_handler: None,
42            tool_policy: None,
43            middlewares: Vec::new(),
44            context_manager: None,
45            session_store: None,
46            skills: Vec::new(),
47            skill_prompter: None,
48            skill_detail_tool_name: "get_skill_detail".to_string(),
49            disable_skill_prompt_injection: false,
50            error_recovery: None,
51        }
52    }
53
54    pub fn system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
55        self.config.system_prompt = Some(system_prompt.into());
56        self
57    }
58
59    pub fn enable_thought(mut self, enable: bool) -> Self {
60        self.config.enable_thought = enable;
61        self
62    }
63
64    pub fn enable_thinking(mut self, enable: bool) -> Self {
65        self.config.enable_thinking = Some(enable);
66        self
67    }
68
69    pub fn tool_timeout(mut self, timeout_ms: u64) -> Self {
70        self.config.tool_timeout_ms = Some(timeout_ms);
71        self
72    }
73
74    pub fn max_tool_output_chars(mut self, max_chars: usize) -> Self {
75        self.config.max_tool_output_chars = Some(max_chars);
76        self
77    }
78
79    pub fn register_tool(mut self, tool: impl Tool + 'static) -> Self {
80        self.tools.register(tool);
81        self
82    }
83
84    pub fn register_tool_arc(mut self, tool: Arc<dyn Tool>) -> Self {
85        self.tools.register_arc(tool);
86        self
87    }
88
89    pub fn approval_handler(mut self, handler: Arc<dyn ApprovalHandler>) -> Self {
90        self.approval_handler = Some(handler);
91        self
92    }
93
94    pub fn tool_policy(mut self, policy: Arc<dyn ToolPolicy>) -> Self {
95        self.tool_policy = Some(policy);
96        self
97    }
98
99    pub fn middleware(mut self, mw: impl Middleware + 'static) -> Self {
100        self.middlewares.push(Arc::new(mw));
101        self
102    }
103
104    pub fn context_window(mut self, max_tokens: usize) -> Self {
105        self.context_manager = Some(ContextWindowManager::new(max_tokens));
106        self
107    }
108
109    pub fn context_window_manager(mut self, manager: ContextWindowManager) -> Self {
110        self.context_manager = Some(manager);
111        self
112    }
113
114    pub fn response_format(mut self, format: ResponseFormat) -> Self {
115        self.config.response_format = Some(format);
116        self
117    }
118
119    pub fn llm_retry(mut self, retry: RetryConfig) -> Self {
120        self.config.llm_retry = Some(retry);
121        self
122    }
123
124    pub fn session_store(mut self, store: Arc<dyn SessionStore>) -> Self {
125        self.session_store = Some(store);
126        self
127    }
128
129    pub fn register_skill(mut self, skill: impl Skill + 'static) -> Self {
130        self.skills.push(Arc::new(skill));
131        self
132    }
133
134    pub fn skill_prompter(mut self, prompter: Arc<dyn SkillPrompter>) -> Self {
135        self.skill_prompter = Some(prompter);
136        self
137    }
138
139    pub fn disable_skill_prompt_injection(mut self) -> Self {
140        self.disable_skill_prompt_injection = true;
141        self
142    }
143
144    pub fn skill_detail_tool_name(mut self, name: impl Into<String>) -> Self {
145        self.skill_detail_tool_name = name.into();
146        self
147    }
148
149    pub fn error_recovery(mut self, recovery: Arc<dyn ToolErrorRecovery>) -> Self {
150        self.error_recovery = Some(recovery);
151        self
152    }
153
154    pub fn build(mut self) -> AgentRuntime {
155        let prompter: Arc<dyn SkillPrompter> = self
156            .skill_prompter
157            .unwrap_or_else(|| Arc::new(LazySkillPrompter::new()));
158
159        let mut skill_refs: Vec<Arc<dyn Skill>> = Vec::new();
160        let mut known_tool_names: HashSet<String> = self
161            .tools
162            .definitions()
163            .iter()
164            .filter_map(|d| {
165                d.get("function")
166                    .and_then(|f| f.get("name"))
167                    .and_then(|n| n.as_str())
168                    .map(|s| s.to_string())
169            })
170            .collect();
171
172        for skill in self.skills {
173            for tool in skill.tools() {
174                let tool_name = tool.name().to_string();
175                if known_tool_names.contains(&tool_name) {
176                    panic!(
177                        "Tool name conflict: `{}` (Skill `{}`)",
178                        tool_name,
179                        skill.name()
180                    );
181                }
182                known_tool_names.insert(tool_name.clone());
183                self.tools.register_arc(tool);
184            }
185            skill_refs.push(skill);
186        }
187
188        if !skill_refs.is_empty() && !self.disable_skill_prompt_injection {
189            let skill_prompt = prompter.build_prompt(&skill_refs);
190            if !skill_prompt.is_empty() {
191                let new_prompt = match self.config.system_prompt.take() {
192                    Some(existing) => format!("{}\n\n---\n\n{}", existing, skill_prompt),
193                    None => skill_prompt,
194                };
195                self.config.system_prompt = Some(new_prompt);
196            }
197        }
198
199        if !skill_refs.is_empty() {
200            let detail_tool = SkillDetailTool::new(
201                skill_refs.clone(),
202                std::mem::take(&mut self.skill_detail_tool_name),
203            );
204            self.tools.register(detail_tool);
205        }
206
207        let (event_bus, _) = broadcast::channel(2048);
208        let session_store = self
209            .session_store
210            .unwrap_or_else(|| Arc::new(InMemorySessionStore::new()));
211        let error_recovery = self
212            .error_recovery
213            .unwrap_or_else(|| Arc::new(StopOnError));
214
215        AgentRuntime {
216            client: self.client,
217            config: self.config,
218            tools: self.tools,
219            approval_handler: self.approval_handler,
220            tool_policy: self.tool_policy,
221            middlewares: self.middlewares,
222            event_bus,
223            next_session_id: AtomicU64::new(1),
224            sessions: HashMap::new(),
225            context_manager: self.context_manager,
226            session_store,
227            skills: skill_refs,
228            skill_prompter: prompter,
229            error_recovery,
230        }
231    }
232}