1use crate::agent::Agent;
10use crate::backend_trait::LlmBackend;
11use crate::observer::{NoOpObserver, Observer};
12use crate::store_trait::MessageStore;
13use crate::tool::Tool;
14use std::sync::Arc;
15
16pub struct AgentBuilder<B: LlmBackend> {
33 backend: B,
34 system: String,
35 tools: Vec<Box<dyn Tool>>,
36 max_steps: Option<usize>,
37 max_window: Option<usize>,
38 max_tool_result_bytes: Option<usize>,
39 store: Option<Arc<dyn MessageStore>>,
40 session: Option<String>,
41 observer: Option<Arc<dyn Observer>>,
42 on_token: Option<Box<dyn FnMut(&str) + Send>>,
43 max_step_duration: Option<std::time::Duration>,
44}
45
46impl<B: LlmBackend> AgentBuilder<B> {
47 pub fn new(backend: B) -> Self {
50 Self {
51 backend,
52 system: String::new(),
53 tools: Vec::new(),
54 max_steps: None,
55 max_window: None,
56 max_tool_result_bytes: None,
57 store: None,
58 session: None,
59 observer: None,
60 on_token: None,
61 max_step_duration: None,
62 }
63 }
64
65 pub fn system(mut self, system: impl Into<String>) -> Self {
67 self.system = system.into();
68 self
69 }
70
71 pub fn tool(mut self, tool: Box<dyn Tool>) -> Self {
73 self.tools.push(tool);
74 self
75 }
76
77 pub fn tools(mut self, tools: Vec<Box<dyn Tool>>) -> Self {
79 self.tools.extend(tools);
80 self
81 }
82
83 pub fn max_steps(mut self, n: usize) -> Self {
85 self.max_steps = Some(n);
86 self
87 }
88
89 pub fn max_window(mut self, n: usize) -> Self {
91 self.max_window = Some(n);
92 self
93 }
94
95 pub fn max_tool_result_bytes(mut self, n: usize) -> Self {
97 self.max_tool_result_bytes = Some(n);
98 self
99 }
100
101 pub fn store(mut self, store: Arc<dyn MessageStore>, session: impl Into<String>) -> Self {
103 self.store = Some(store);
104 self.session = Some(session.into());
105 self
106 }
107
108 pub fn observer(mut self, observer: Arc<dyn Observer>) -> Self {
110 self.observer = Some(observer);
111 self
112 }
113
114 pub fn on_token(mut self, sink: Box<dyn FnMut(&str) + Send>) -> Self {
117 self.on_token = Some(sink);
118 self
119 }
120
121 pub fn max_step_duration(mut self, d: std::time::Duration) -> Self {
126 self.max_step_duration = Some(d);
127 self
128 }
129
130 pub fn build(self) -> Result<Agent<B>, String> {
135 let mut agent = Agent::new(self.backend, &self.system);
136 for tool in self.tools {
137 agent.tools.register(tool);
138 }
139 if let Some(n) = self.max_steps {
140 agent.max_steps = n;
141 }
142 if let Some(n) = self.max_window {
143 agent.max_window = n;
144 }
145 if let Some(n) = self.max_tool_result_bytes {
146 agent.max_tool_result_bytes = n;
147 }
148 if let Some(obs) = self.observer {
149 agent.observer = obs;
150 } else {
151 agent.observer = Arc::new(NoOpObserver);
152 }
153 agent.on_token = self.on_token;
154 agent.max_step_duration = self.max_step_duration;
155 if let Some(store) = self.store {
156 let session = self.session.unwrap_or_else(|| "default".into());
157 agent.attach_store(store, &session)?;
158 }
159 Ok(agent)
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend_trait::BackendError;
    use crate::message::Message;
    use serde_json::Value;

    /// Backend stub that always replies with a fixed assistant message.
    struct MockBackend;
    impl LlmBackend for MockBackend {
        fn model(&self) -> &str {
            "mock"
        }
        fn chat(
            &self,
            _messages: &[Message],
            _tools: &Value,
            _on_token: Option<&mut dyn FnMut(&str)>,
        ) -> Result<Message, BackendError> {
            Ok(Message {
                role: "assistant".into(),
                content: Some("ok".into()),
                tool_calls: None,
                tool_call_id: None,
                name: None,
            })
        }
    }

    #[test]
    fn builder_sets_defaults() {
        let agent = AgentBuilder::new(MockBackend)
            .system("sys")
            .build()
            .unwrap();
        assert_eq!(agent.max_steps, 10);
        assert_eq!(agent.max_window, 40);
        assert_eq!(agent.messages[0].role, "system");
        assert_eq!(agent.messages[0].content.as_deref(), Some("sys"));
    }

    #[test]
    fn builder_overrides() {
        let agent = AgentBuilder::new(MockBackend)
            .system("sys")
            .max_steps(3)
            .max_window(5)
            .max_tool_result_bytes(1024)
            .build()
            .unwrap();
        assert_eq!(agent.max_steps, 3);
        assert_eq!(agent.max_window, 5);
        assert_eq!(agent.max_tool_result_bytes, 1024);
    }

    #[test]
    fn builder_leaves_token_sink_and_step_duration_unset_by_default() {
        // `build` copies these two fields verbatim, so an unset builder yields `None`.
        let agent = AgentBuilder::new(MockBackend).build().unwrap();
        assert!(agent.on_token.is_none());
        assert!(agent.max_step_duration.is_none());
    }

    #[test]
    fn builder_forwards_token_sink_and_step_duration() {
        let limit = std::time::Duration::from_millis(250);
        let agent = AgentBuilder::new(MockBackend)
            .on_token(Box::new(|_token: &str| {}))
            .max_step_duration(limit)
            .build()
            .unwrap();
        assert!(agent.on_token.is_some());
        assert_eq!(agent.max_step_duration, Some(limit));
    }

    #[test]
    fn builder_accepts_multiple_tools() {
        // `Tool` is already in scope through `use super::*`.
        struct Dummy(&'static str);
        impl Tool for Dummy {
            fn name(&self) -> &str {
                self.0
            }
            fn description(&self) -> &str {
                "dummy"
            }
            fn schema(&self) -> Value {
                serde_json::json!({ "type": "object" })
            }
            fn call(&self, _args: Value) -> Result<String, String> {
                Ok("".into())
            }
        }
        let agent = AgentBuilder::new(MockBackend)
            .tool(Box::new(Dummy("a")))
            .tool(Box::new(Dummy("b")))
            .tools(vec![Box::new(Dummy("c")), Box::new(Dummy("d"))])
            .build()
            .unwrap();
        assert_eq!(agent.tools.names(), vec!["a", "b", "c", "d"]);
    }
}