use std::{collections::HashMap, sync::Arc};
use schemars::{JsonSchema, Schema, schema_for};
use tokio::sync::RwLock;
use crate::{
agent::prompt_request::hooks::PromptHook,
completion::{CompletionModel, Document},
message::ToolChoice,
tool::{
Tool, ToolDyn, ToolSet,
server::{ToolServer, ToolServerHandle},
},
vector_store::VectorStoreIndexDyn,
};
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
use crate::tool::rmcp::McpTool as RmcpTool;
use super::Agent;
/// Marker tool state: neither builder-registered tools nor an external tool
/// server handle have been configured yet.
///
/// This is the default `ToolState` of `AgentBuilder`.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoToolConfig;
/// Tool state for a builder that was given an existing tool server handle
/// via `tool_server_handle`.
///
/// The handle is passed through unchanged to the built `Agent`.
pub struct WithToolServerHandle {
    /// Handle supplied to `AgentBuilder::tool_server_handle`.
    handle: ToolServerHandle,
}
/// Tool state accumulated through the builder's own tool methods
/// (`tool`, `tools`, `dynamic_tools` and the `rmcp` variants).
///
/// At `build` time all three collections are registered with a freshly
/// created `ToolServer`.
pub struct WithBuilderTools {
    /// Every tool implementation registered on the builder so far.
    tools: ToolSet,
    /// Vector-store indices registered via `dynamic_tools`, each paired with
    /// the `sample` value it was added with.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Names of the tools registered through the static-tool methods.
    static_tools: Vec<String>,
}
/// Consuming builder for `Agent`.
///
/// * `M` is the completion model the agent will use.
/// * `P` is the prompt hook type; it is `()` for a builder created by `new`
///   until `hook` is called.
/// * `ToolState` records how tools are supplied: `NoToolConfig`,
///   `WithToolServerHandle` or `WithBuilderTools`. Choosing one style moves the
///   builder into the matching state.
///
/// Every method takes `self` by value and returns the (possibly re-typed)
/// builder. Discarding the return value therefore throws the configuration away.
#[must_use = "builder methods return a new builder; dropping it discards the configuration"]
pub struct AgentBuilder<M, P = (), ToolState = NoToolConfig>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Optional agent name, set via `name`.
    name: Option<String>,
    /// Optional agent description, set via `description`.
    description: Option<String>,
    /// Completion model; wrapped in an `Arc` when the agent is built.
    model: M,
    /// Preamble text set via `preamble` / `append_preamble`.
    preamble: Option<String>,
    /// Documents added via `context`, with generated `static_doc_{n}` ids.
    static_context: Vec<Document>,
    /// Extra JSON parameters passed through to the agent.
    additional_params: Option<serde_json::Value>,
    /// Token limit passed through to the agent.
    max_tokens: Option<u64>,
    /// `(sample, index)` pairs added via `dynamic_context`.
    dynamic_context: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Sampling temperature passed through to the agent.
    temperature: Option<f64>,
    /// Tool choice passed through to the agent.
    tool_choice: Option<ToolChoice>,
    /// Default turn limit passed through to the agent.
    default_max_turns: Option<usize>,
    /// Tool configuration; see the `ToolState` type parameter.
    tool_state: ToolState,
    /// Prompt hook set via `hook`, if any.
    hook: Option<P>,
    /// Output schema set via `output_schema` / `output_schema_raw`.
    output_schema: Option<schemars::Schema>,
}
impl<M, P, ToolState> AgentBuilder<M, P, ToolState>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Sets the agent's name.
    pub fn name(mut self, name: &str) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Sets the agent's description.
    pub fn description(mut self, description: &str) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Replaces any existing preamble with `preamble`.
    pub fn preamble(mut self, preamble: &str) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

    /// Clears any preamble set so far.
    pub fn without_preamble(mut self) -> Self {
        self.preamble = None;
        self
    }

    /// Appends `doc` to the preamble, separated from existing text by a newline.
    ///
    /// If no preamble has been set yet, `doc` becomes the preamble as-is. It is
    /// not prefixed with an empty line.
    pub fn append_preamble(mut self, doc: &str) -> Self {
        self.preamble = Some(match self.preamble.take() {
            Some(existing) => format!("{existing}\n{doc}"),
            None => doc.to_owned(),
        });
        self
    }

    /// Adds a static context document.
    ///
    /// The document id is `static_doc_{n}`, where `n` is the number of static
    /// documents added before this one.
    pub fn context(mut self, doc: &str) -> Self {
        self.static_context.push(Document {
            id: format!("static_doc_{}", self.static_context.len()),
            text: doc.into(),
            additional_props: HashMap::new(),
        });
        self
    }

    /// Registers a vector-store index as dynamic context, stored together with
    /// `sample`.
    ///
    /// NOTE(review): `sample` presumably caps how many documents are retrieved
    /// per prompt; confirm against the agent implementation.
    pub fn dynamic_context(
        mut self,
        sample: usize,
        dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
    ) -> Self {
        self.dynamic_context
            .push((sample, Arc::new(dynamic_context)));
        self
    }

    /// Sets the tool choice passed to the agent.
    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
        self.tool_choice = Some(tool_choice);
        self
    }

    /// Sets the agent's default maximum number of turns.
    pub fn default_max_turns(mut self, default_max_turns: usize) -> Self {
        self.default_max_turns = Some(default_max_turns);
        self
    }

    /// Sets the sampling temperature.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the maximum number of tokens.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets extra JSON parameters that are passed through to the agent.
    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
        self.additional_params = Some(params);
        self
    }

    /// Sets the output schema to the JSON schema that `schema_for!` generates for `T`.
    pub fn output_schema<T>(mut self) -> Self
    where
        T: JsonSchema,
    {
        self.output_schema = Some(schema_for!(T));
        self
    }

    /// Sets the output schema from an already-built `Schema`.
    pub fn output_schema_raw(mut self, schema: Schema) -> Self {
        self.output_schema = Some(schema);
        self
    }
}
impl<M> AgentBuilder<M, (), NoToolConfig>
where
M: CompletionModel,
{
pub fn new(model: M) -> Self {
Self {
name: None,
description: None,
model,
preamble: None,
static_context: vec![],
temperature: None,
max_tokens: None,
additional_params: None,
dynamic_context: vec![],
tool_choice: None,
default_max_turns: None,
tool_state: NoToolConfig,
hook: None,
output_schema: None,
}
}
}
impl<M, P> AgentBuilder<M, P, NoToolConfig>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Moves every setting into a builder whose tool state is `tool_state`.
    ///
    /// All transitions out of `NoToolConfig` share this helper, so the field
    /// list is spelled out once instead of once per method.
    fn into_tool_state<T>(self, tool_state: T) -> AgentBuilder<M, P, T> {
        AgentBuilder {
            name: self.name,
            description: self.description,
            model: self.model,
            preamble: self.preamble,
            static_context: self.static_context,
            additional_params: self.additional_params,
            max_tokens: self.max_tokens,
            dynamic_context: self.dynamic_context,
            temperature: self.temperature,
            tool_choice: self.tool_choice,
            default_max_turns: self.default_max_turns,
            tool_state,
            hook: self.hook,
            output_schema: self.output_schema,
        }
    }

    /// Uses an existing tool server `handle` instead of tools registered on this
    /// builder. The handle is passed through unchanged to the built agent.
    pub fn tool_server_handle(
        self,
        handle: ToolServerHandle,
    ) -> AgentBuilder<M, P, WithToolServerHandle> {
        self.into_tool_state(WithToolServerHandle { handle })
    }

    /// Registers a single static tool and switches to builder-managed tools.
    pub fn tool(self, tool: impl Tool + 'static) -> AgentBuilder<M, P, WithBuilderTools> {
        let toolname = tool.name();
        self.into_tool_state(WithBuilderTools {
            static_tools: vec![toolname],
            tools: ToolSet::from_tools(vec![tool]),
            dynamic_tools: vec![],
        })
    }

    /// Registers several boxed static tools and switches to builder-managed tools.
    pub fn tools(self, tools: Vec<Box<dyn ToolDyn>>) -> AgentBuilder<M, P, WithBuilderTools> {
        // Names must be read before the boxes are moved into the tool set.
        let static_tools: Vec<String> = tools.iter().map(|tool| tool.name()).collect();
        let tools = ToolSet::from_tools_boxed(tools);
        self.into_tool_state(WithBuilderTools {
            static_tools,
            tools,
            dynamic_tools: vec![],
        })
    }

    /// Registers one MCP tool served by `client` as a static tool.
    #[cfg(feature = "rmcp")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
    pub fn rmcp_tool(
        self,
        tool: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    ) -> AgentBuilder<M, P, WithBuilderTools> {
        let toolname = tool.name.to_string();
        let tools = ToolSet::from_tools(vec![RmcpTool::from_mcp_server(tool, client)]);
        self.into_tool_state(WithBuilderTools {
            static_tools: vec![toolname],
            tools,
            dynamic_tools: vec![],
        })
    }

    /// Registers several MCP tools, all served by `client`, as static tools.
    #[cfg(feature = "rmcp")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
    pub fn rmcp_tools(
        self,
        tools: Vec<rmcp::model::Tool>,
        client: rmcp::service::ServerSink,
    ) -> AgentBuilder<M, P, WithBuilderTools> {
        let (static_tools, tools): (Vec<String>, Vec<RmcpTool>) = tools
            .into_iter()
            .map(|tool| {
                let tool_name = tool.name.to_string();
                (tool_name, RmcpTool::from_mcp_server(tool, client.clone()))
            })
            .unzip();
        let tools = ToolSet::from_tools(tools);
        self.into_tool_state(WithBuilderTools {
            static_tools,
            tools,
            dynamic_tools: vec![],
        })
    }

    /// Registers `toolset` together with a vector-store index (and its `sample`
    /// value) used for dynamic tools, then switches to builder-managed tools.
    pub fn dynamic_tools(
        self,
        sample: usize,
        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
        toolset: ToolSet,
    ) -> AgentBuilder<M, P, WithBuilderTools> {
        self.into_tool_state(WithBuilderTools {
            static_tools: vec![],
            tools: toolset,
            dynamic_tools: vec![(sample, Arc::new(dynamic_tools))],
        })
    }

    /// Sets the prompt hook, changing the builder's hook type to `P2`.
    pub fn hook<P2>(self, hook: P2) -> AgentBuilder<M, P2, NoToolConfig>
    where
        P2: PromptHook<M>,
    {
        AgentBuilder {
            name: self.name,
            description: self.description,
            model: self.model,
            preamble: self.preamble,
            static_context: self.static_context,
            additional_params: self.additional_params,
            max_tokens: self.max_tokens,
            dynamic_context: self.dynamic_context,
            temperature: self.temperature,
            tool_choice: self.tool_choice,
            default_max_turns: self.default_max_turns,
            tool_state: self.tool_state,
            hook: Some(hook),
            output_schema: self.output_schema,
        }
    }

    /// Builds the agent. Its tool server comes from `ToolServer::new().run()`,
    /// and this builder registers no tools on it.
    pub fn build(self) -> Agent<M, P> {
        let tool_server_handle = ToolServer::new().run();
        Agent {
            name: self.name,
            description: self.description,
            model: Arc::new(self.model),
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_choice: self.tool_choice,
            dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
            tool_server_handle,
            default_max_turns: self.default_max_turns,
            hook: self.hook,
            output_schema: self.output_schema,
        }
    }
}
impl<M, P> AgentBuilder<M, P, WithToolServerHandle>
where
M: CompletionModel,
P: PromptHook<M>,
{
pub fn build(self) -> Agent<M, P> {
Agent {
name: self.name,
description: self.description,
model: Arc::new(self.model),
preamble: self.preamble,
static_context: self.static_context,
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
tool_choice: self.tool_choice,
dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
tool_server_handle: self.tool_state.handle,
default_max_turns: self.default_max_turns,
hook: self.hook,
output_schema: self.output_schema,
}
}
}
impl<M, P> AgentBuilder<M, P, WithBuilderTools>
where
M: CompletionModel,
P: PromptHook<M>,
{
pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
let toolname = tool.name();
self.tool_state.tools.add_tool(tool);
self.tool_state.static_tools.push(toolname);
self
}
pub fn tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
let toolnames: Vec<String> = tools.iter().map(|tool| tool.name()).collect();
let tools = ToolSet::from_tools_boxed(tools);
self.tool_state.tools.add_tools(tools);
self.tool_state.static_tools.extend(toolnames);
self
}
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub fn rmcp_tools(
mut self,
tools: Vec<rmcp::model::Tool>,
client: rmcp::service::ServerSink,
) -> Self {
for tool in tools {
let tool_name = tool.name.to_string();
let tool = RmcpTool::from_mcp_server(tool, client.clone());
self.tool_state.static_tools.push(tool_name);
self.tool_state.tools.add_tool(tool);
}
self
}
pub fn dynamic_tools(
mut self,
sample: usize,
dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
toolset: ToolSet,
) -> Self {
self.tool_state
.dynamic_tools
.push((sample, Arc::new(dynamic_tools)));
self.tool_state.tools.add_tools(toolset);
self
}
pub fn build(self) -> Agent<M, P> {
let tool_server_handle = ToolServer::new()
.static_tool_names(self.tool_state.static_tools)
.add_tools(self.tool_state.tools)
.add_dynamic_tools(self.tool_state.dynamic_tools)
.run();
Agent {
name: self.name,
description: self.description,
model: Arc::new(self.model),
preamble: self.preamble,
static_context: self.static_context,
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
tool_choice: self.tool_choice,
dynamic_context: Arc::new(RwLock::new(self.dynamic_context)),
tool_server_handle,
default_max_turns: self.default_max_turns,
hook: self.hook,
output_schema: self.output_schema,
}
}
}