1use crate::agent::Agent;
10use crate::backend_trait::LlmBackend;
11use crate::observer::{NoOpObserver, Observer};
12use crate::store_trait::MessageStore;
13use crate::tool::Tool;
14use std::sync::Arc;
15
16pub struct AgentBuilder<B: LlmBackend> {
33 backend: B,
34 system: String,
35 tools: Vec<Box<dyn Tool>>,
36 max_steps: Option<usize>,
37 max_window: Option<usize>,
38 max_tool_result_bytes: Option<usize>,
39 store: Option<Arc<dyn MessageStore>>,
40 session: Option<String>,
41 observer: Option<Arc<dyn Observer>>,
42 on_token: Option<Box<dyn FnMut(&str) + Send>>,
43}
44
45impl<B: LlmBackend> AgentBuilder<B> {
46 pub fn new(backend: B) -> Self {
49 Self {
50 backend,
51 system: String::new(),
52 tools: Vec::new(),
53 max_steps: None,
54 max_window: None,
55 max_tool_result_bytes: None,
56 store: None,
57 session: None,
58 observer: None,
59 on_token: None,
60 }
61 }
62
63 pub fn system(mut self, system: impl Into<String>) -> Self {
65 self.system = system.into();
66 self
67 }
68
69 pub fn tool(mut self, tool: Box<dyn Tool>) -> Self {
71 self.tools.push(tool);
72 self
73 }
74
75 pub fn tools(mut self, tools: Vec<Box<dyn Tool>>) -> Self {
77 self.tools.extend(tools);
78 self
79 }
80
81 pub fn max_steps(mut self, n: usize) -> Self {
83 self.max_steps = Some(n);
84 self
85 }
86
87 pub fn max_window(mut self, n: usize) -> Self {
89 self.max_window = Some(n);
90 self
91 }
92
93 pub fn max_tool_result_bytes(mut self, n: usize) -> Self {
95 self.max_tool_result_bytes = Some(n);
96 self
97 }
98
99 pub fn store(mut self, store: Arc<dyn MessageStore>, session: impl Into<String>) -> Self {
101 self.store = Some(store);
102 self.session = Some(session.into());
103 self
104 }
105
106 pub fn observer(mut self, observer: Arc<dyn Observer>) -> Self {
108 self.observer = Some(observer);
109 self
110 }
111
112 pub fn on_token(mut self, sink: Box<dyn FnMut(&str) + Send>) -> Self {
115 self.on_token = Some(sink);
116 self
117 }
118
119 pub fn build(self) -> Result<Agent<B>, String> {
124 let mut agent = Agent::new(self.backend, &self.system);
125 for tool in self.tools {
126 agent.tools.register(tool);
127 }
128 if let Some(n) = self.max_steps {
129 agent.max_steps = n;
130 }
131 if let Some(n) = self.max_window {
132 agent.max_window = n;
133 }
134 if let Some(n) = self.max_tool_result_bytes {
135 agent.max_tool_result_bytes = n;
136 }
137 if let Some(obs) = self.observer {
138 agent.observer = obs;
139 } else {
140 agent.observer = Arc::new(NoOpObserver);
141 }
142 agent.on_token = self.on_token;
143 if let Some(store) = self.store {
144 let session = self.session.unwrap_or_else(|| "default".into());
145 agent.attach_store(store, &session)?;
146 }
147 Ok(agent)
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend_trait::BackendError;
    use crate::message::Message;
    use serde_json::Value;

    /// Backend stub that names itself "mock" and answers every chat call with
    /// the same assistant message.
    struct MockBackend;

    impl LlmBackend for MockBackend {
        fn model(&self) -> &str {
            "mock"
        }

        fn chat(
            &self,
            _messages: &[Message],
            _tools: &Value,
            _on_token: Option<&mut dyn FnMut(&str)>,
        ) -> Result<Message, BackendError> {
            let reply = Message {
                role: "assistant".into(),
                content: Some("ok".into()),
                tool_calls: None,
                tool_call_id: None,
                name: None,
            };
            Ok(reply)
        }
    }

    /// Shorthand for a fresh builder over the mock backend.
    fn mock_builder() -> AgentBuilder<MockBackend> {
        AgentBuilder::new(MockBackend)
    }

    /// With nothing but a system text, the agent keeps its own limits and
    /// starts with that text as its first message.
    #[test]
    fn builder_sets_defaults() {
        let built = mock_builder().system("sys").build();
        let agent = built.unwrap();

        assert_eq!(agent.max_steps, 10);
        assert_eq!(agent.max_window, 40);

        let first = &agent.messages[0];
        assert_eq!(first.role, "system");
        assert_eq!(first.content.as_deref(), Some("sys"));
    }

    /// Each explicitly set limit ends up on the built agent.
    #[test]
    fn builder_overrides() {
        let configured = mock_builder()
            .system("sys")
            .max_steps(3)
            .max_window(5)
            .max_tool_result_bytes(1024);
        let agent = configured.build().unwrap();

        assert_eq!(agent.max_steps, 3);
        assert_eq!(agent.max_window, 5);
        assert_eq!(agent.max_tool_result_bytes, 1024);
    }

    /// Tools added one at a time and in bulk are registered in insertion order.
    #[test]
    fn builder_accepts_multiple_tools() {
        use crate::tool::Tool;

        /// Minimal tool whose only distinguishing feature is its name.
        struct NamedStub(&'static str);

        impl Tool for NamedStub {
            fn name(&self) -> &str {
                self.0
            }

            fn description(&self) -> &str {
                "dummy"
            }

            fn schema(&self) -> Value {
                serde_json::json!({ "type": "object" })
            }

            fn call(&self, _args: Value) -> Result<String, String> {
                Ok(String::new())
            }
        }

        let bulk: Vec<Box<dyn Tool>> = vec![Box::new(NamedStub("c")), Box::new(NamedStub("d"))];
        let agent = mock_builder()
            .tool(Box::new(NamedStub("a")))
            .tool(Box::new(NamedStub("b")))
            .tools(bulk)
            .build()
            .unwrap();

        assert_eq!(agent.tools.names(), vec!["a", "b", "c", "d"]);
    }
}