use crate::agent::Agent;
use crate::backend_trait::LlmBackend;
use crate::observer::{NoOpObserver, Observer};
use crate::store_trait::MessageStore;
use crate::tool::Tool;
use std::sync::Arc;
/// Fluent builder for [`Agent`].
///
/// Every setter consumes the builder and returns it, so calls chain and end with
/// [`AgentBuilder::build`].
///
/// Unset limits (`max_steps`, `max_window`, `max_tool_result_bytes`) leave whatever
/// `Agent::new` chose in place. The observer falls back to [`NoOpObserver`]. The token
/// sink and step-duration limit are always written to the agent, as `None` when unset.
#[must_use = "builder methods return the updated builder; call `.build()` to obtain an Agent"]
pub struct AgentBuilder<B: LlmBackend> {
    /// Backend handed to `Agent::new`.
    backend: B,
    /// System prompt handed to `Agent::new`; empty unless `system` is called.
    system: String,
    /// Tools registered on the agent in the order they were added.
    tools: Vec<Box<dyn Tool>>,
    max_steps: Option<usize>,
    max_window: Option<usize>,
    max_tool_result_bytes: Option<usize>,
    /// Message store together with its session name. Stored as one pair so a store
    /// can never be configured without a session (or a session without a store).
    store: Option<(Arc<dyn MessageStore>, String)>,
    observer: Option<Arc<dyn Observer>>,
    on_token: Option<Box<dyn FnMut(&str) + Send>>,
    max_step_duration: Option<std::time::Duration>,
}

impl<B: LlmBackend> AgentBuilder<B> {
    /// Starts a builder around `backend` with an empty system prompt, no tools and
    /// no overrides.
    pub fn new(backend: B) -> Self {
        Self {
            backend,
            system: String::new(),
            tools: Vec::new(),
            max_steps: None,
            max_window: None,
            max_tool_result_bytes: None,
            store: None,
            observer: None,
            on_token: None,
            max_step_duration: None,
        }
    }

    /// Sets the system prompt passed to `Agent::new`; a later call replaces an earlier one.
    pub fn system(mut self, system: impl Into<String>) -> Self {
        self.system = system.into();
        self
    }

    /// Adds a single tool. Tools are registered in the order they were added.
    pub fn tool(mut self, tool: Box<dyn Tool>) -> Self {
        self.tools.push(tool);
        self
    }

    /// Appends several tools after any already added.
    pub fn tools(mut self, tools: Vec<Box<dyn Tool>>) -> Self {
        self.tools.extend(tools);
        self
    }

    /// Overrides the agent's `max_steps` field.
    // NOTE(review): presumably the cap on model/tool iterations per run — confirm in `Agent`.
    pub fn max_steps(mut self, n: usize) -> Self {
        self.max_steps = Some(n);
        self
    }

    /// Overrides the agent's `max_window` field.
    // NOTE(review): presumably the number of recent messages kept in context — confirm in `Agent`.
    pub fn max_window(mut self, n: usize) -> Self {
        self.max_window = Some(n);
        self
    }

    /// Overrides the agent's `max_tool_result_bytes` field.
    // NOTE(review): presumably a size cap enforced by `Agent` on tool output; nothing here enforces it.
    pub fn max_tool_result_bytes(mut self, n: usize) -> Self {
        self.max_tool_result_bytes = Some(n);
        self
    }

    /// Configures a message store under the given session name.
    ///
    /// The store is attached during [`AgentBuilder::build`], which fails if attaching
    /// fails. Calling this again replaces both the store and the session.
    pub fn store(mut self, store: Arc<dyn MessageStore>, session: impl Into<String>) -> Self {
        self.store = Some((store, session.into()));
        self
    }

    /// Installs an observer. Without one, `build` installs [`NoOpObserver`].
    pub fn observer(mut self, observer: Arc<dyn Observer>) -> Self {
        self.observer = Some(observer);
        self
    }

    /// Sets the callback stored as the agent's `on_token` sink.
    // NOTE(review): presumably invoked with streamed text chunks — confirm in `Agent`.
    pub fn on_token(mut self, sink: Box<dyn FnMut(&str) + Send>) -> Self {
        self.on_token = Some(sink);
        self
    }

    /// Sets the agent's `max_step_duration` to `Some(d)`; when never called it is `None`.
    pub fn max_step_duration(mut self, d: std::time::Duration) -> Self {
        self.max_step_duration = Some(d);
        self
    }

    /// Consumes the builder and produces a configured [`Agent`].
    ///
    /// # Errors
    ///
    /// Returns the error string from `Agent::attach_store` when a store was configured
    /// and attaching it fails.
    pub fn build(self) -> Result<Agent<B>, String> {
        let mut agent = Agent::new(self.backend, &self.system);

        for tool in self.tools {
            agent.tools.register(tool);
        }

        // Only override the limits the caller actually set; otherwise keep Agent's defaults.
        if let Some(n) = self.max_steps {
            agent.max_steps = n;
        }
        if let Some(n) = self.max_window {
            agent.max_window = n;
        }
        if let Some(n) = self.max_tool_result_bytes {
            agent.max_tool_result_bytes = n;
        }

        agent.observer = self.observer.unwrap_or_else(|| Arc::new(NoOpObserver));
        agent.on_token = self.on_token;
        agent.max_step_duration = self.max_step_duration;

        // Attach the store last, after every other setting is in place.
        if let Some((store, session)) = self.store {
            agent.attach_store(store, &session)?;
        }

        Ok(agent)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend_trait::BackendError;
    use crate::message::Message;
    use serde_json::Value;
    use std::time::Duration;

    /// Backend that ignores its inputs and always answers with a fixed assistant message.
    struct MockBackend;

    impl LlmBackend for MockBackend {
        fn model(&self) -> &str {
            "mock"
        }

        fn chat(
            &self,
            _messages: &[Message],
            _tools: &Value,
            _on_token: Option<&mut dyn FnMut(&str)>,
        ) -> Result<Message, BackendError> {
            Ok(Message {
                role: "assistant".into(),
                content: Some("ok".into()),
                tool_calls: None,
                tool_call_id: None,
                name: None,
            })
        }
    }

    #[test]
    fn builder_sets_defaults() {
        let agent = AgentBuilder::new(MockBackend)
            .system("sys")
            .build()
            .unwrap();
        assert_eq!(agent.max_steps, 10);
        assert_eq!(agent.max_window, 40);
        assert_eq!(agent.messages[0].role, "system");
        assert_eq!(agent.messages[0].content.as_deref(), Some("sys"));
        // `build` always writes these two fields, so unset means `None`.
        assert!(agent.on_token.is_none());
        assert_eq!(agent.max_step_duration, None);
    }

    #[test]
    fn builder_overrides() {
        let agent = AgentBuilder::new(MockBackend)
            .system("sys")
            .max_steps(3)
            .max_window(5)
            .max_tool_result_bytes(1024)
            .build()
            .unwrap();
        assert_eq!(agent.max_steps, 3);
        assert_eq!(agent.max_window, 5);
        assert_eq!(agent.max_tool_result_bytes, 1024);
    }

    #[test]
    fn builder_system_last_call_wins() {
        let agent = AgentBuilder::new(MockBackend)
            .system("first")
            .system("second")
            .build()
            .unwrap();
        assert_eq!(agent.messages[0].content.as_deref(), Some("second"));
    }

    #[test]
    fn builder_sets_token_sink_and_step_duration() {
        let agent = AgentBuilder::new(MockBackend)
            .on_token(Box::new(|_tok: &str| {}))
            .max_step_duration(Duration::from_millis(250))
            .build()
            .unwrap();
        assert!(agent.on_token.is_some());
        assert_eq!(agent.max_step_duration, Some(Duration::from_millis(250)));
    }

    #[test]
    fn builder_accepts_multiple_tools() {
        // `Tool` is already in scope through `use super::*`.
        struct Dummy(&'static str);

        impl Tool for Dummy {
            fn name(&self) -> &str {
                self.0
            }
            fn description(&self) -> &str {
                "dummy"
            }
            fn schema(&self) -> Value {
                serde_json::json!({ "type": "object" })
            }
            fn call(&self, _args: Value) -> Result<String, String> {
                Ok("".into())
            }
        }

        let agent = AgentBuilder::new(MockBackend)
            .tool(Box::new(Dummy("a")))
            .tool(Box::new(Dummy("b")))
            .tools(vec![Box::new(Dummy("c")), Box::new(Dummy("d"))])
            .build()
            .unwrap();
        assert_eq!(agent.tools.names(), vec!["a", "b", "c", "d"]);
    }
}