use std::{collections::HashMap, sync::Arc};
use schemars::{JsonSchema, Schema, schema_for};
use crate::{
agent::prompt_request::hooks::PromptHook,
completion::{CompletionModel, Document},
memory::ConversationMemory,
message::ToolChoice,
tool::{
Tool, ToolDyn, ToolSet,
server::{ToolServer, ToolServerHandle},
},
vector_store::VectorStoreIndexDyn,
};
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
use crate::tool::rmcp::McpTool as RmcpTool;
use super::Agent;
#[cfg(feature = "rmcp")]
fn build_rmcp_tools(
tools: Vec<rmcp::model::Tool>,
client: rmcp::service::ServerSink,
timeout: Option<std::time::Duration>,
) -> Vec<(String, RmcpTool)> {
tools
.into_iter()
.map(|tool| {
let name = tool.name.to_string();
let rmcp_tool = RmcpTool::from_mcp_server(tool, client.clone()).with_timeout(timeout);
(name, rmcp_tool)
})
.collect()
}
/// Tool-configuration typestate: neither tools nor a tool server handle have
/// been supplied yet. `AgentBuilder::new` starts in this state.
#[derive(Default)]
pub struct NoToolConfig;
/// Tool-configuration typestate: an existing [`ToolServerHandle`] was supplied
/// via `tool_server_handle`, and `build` hands it to the agent as-is instead of
/// starting a new tool server.
pub struct WithToolServerHandle {
    // Handle given to `tool_server_handle`; moved into `Agent::tool_server_handle` by `build`.
    handle: ToolServerHandle,
}
/// Tool-configuration typestate: tools were registered directly on the builder.
/// `build` loads everything collected here into a freshly started `ToolServer`.
pub struct WithBuilderTools {
    // Names of tools registered via `tool`/`tools`/`rmcp_*`; passed to
    // `ToolServer::static_tool_names` on build.
    static_tools: Vec<String>,
    // All registered tool implementations, including toolsets supplied to `dynamic_tools`.
    tools: ToolSet,
    // `(sample, index)` pairs from `dynamic_tools`; passed to `ToolServer::add_dynamic_tools`.
    // NOTE(review): presumably `sample` is how many tools are retrieved per prompt — confirm.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
}
/// Builder for [`Agent`].
///
/// `ToolState` is a typestate recording how tools are provided:
/// [`NoToolConfig`] (initial), [`WithToolServerHandle`] (external tool server),
/// or [`WithBuilderTools`] (tools registered on the builder). Each state has its
/// own `build`. `P` is the prompt-hook type and defaults to `()` until `hook` is called.
pub struct AgentBuilder<M, P = (), ToolState = NoToolConfig>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    // Optional agent name, set by `name`.
    name: Option<String>,
    // Optional agent description, set by `description`.
    description: Option<String>,
    // Completion model; wrapped in an `Arc` when the agent is built.
    model: M,
    // Preamble text; set/cleared/extended by `preamble`, `without_preamble`, `append_preamble`.
    preamble: Option<String>,
    // Static documents added by `context`, with ids `static_doc_<index>`.
    static_context: Vec<Document>,
    // Extra request parameters, set by `additional_params`.
    additional_params: Option<serde_json::Value>,
    // Max-token setting, set by `max_tokens`.
    max_tokens: Option<u64>,
    // `(sample, index)` pairs added by `dynamic_context`.
    dynamic_context: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    // Temperature setting, set by `temperature`.
    temperature: Option<f64>,
    // Tool-choice setting, set by `tool_choice`.
    tool_choice: Option<ToolChoice>,
    // Default turn limit, set by `default_max_turns`.
    default_max_turns: Option<usize>,
    // Typestate payload describing how tools are supplied.
    tool_state: ToolState,
    // Prompt hook, set by `hook` (which also changes `P`).
    hook: Option<P>,
    // Output JSON schema, set by `output_schema` / `output_schema_raw`.
    output_schema: Option<schemars::Schema>,
    // Conversation memory backend, set by `memory`.
    memory: Option<Arc<dyn ConversationMemory>>,
    // Default conversation id, set by `conversation_id`.
    default_conversation_id: Option<String>,
}
impl<M, P, ToolState> AgentBuilder<M, P, ToolState>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Set the agent's name.
    pub fn name(mut self, name: &str) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Set the agent's description.
    pub fn description(mut self, description: &str) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Set the agent's preamble, replacing any previously set value.
    pub fn preamble(mut self, preamble: &str) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

    /// Remove any preamble set so far.
    pub fn without_preamble(mut self) -> Self {
        self.preamble = None;
        self
    }

    /// Append `doc` to the preamble, separated from the existing text by `'\n'`.
    ///
    /// If no preamble has been set yet, `doc` becomes the preamble verbatim
    /// instead of being prefixed with a stray leading newline.
    pub fn append_preamble(mut self, doc: &str) -> Self {
        let preamble = match self.preamble.take() {
            Some(existing) => format!("{existing}\n{doc}"),
            None => doc.to_owned(),
        };
        self.preamble = Some(preamble);
        self
    }

    /// Add a static context document.
    ///
    /// Documents are given sequential ids (`static_doc_0`, `static_doc_1`, ...)
    /// in insertion order and no additional properties.
    pub fn context(mut self, doc: &str) -> Self {
        let id = format!("static_doc_{}", self.static_context.len());
        self.static_context.push(Document {
            id,
            text: doc.into(),
            additional_props: HashMap::new(),
        });
        self
    }

    /// Register a vector index as a source of dynamic context.
    ///
    /// NOTE(review): `sample` is presumably the number of documents fetched from
    /// the index per prompt; the lookup itself happens in `Agent`, not here.
    pub fn dynamic_context(
        mut self,
        sample: usize,
        dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
    ) -> Self {
        self.dynamic_context.push((sample, Arc::new(dynamic_context)));
        self
    }

    /// Set the tool-choice setting forwarded to the agent.
    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
        self.tool_choice = Some(tool_choice);
        self
    }

    /// Set the agent's default maximum number of turns.
    pub fn default_max_turns(mut self, default_max_turns: usize) -> Self {
        self.default_max_turns = Some(default_max_turns);
        self
    }

    /// Set the sampling temperature.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Set the maximum number of tokens.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Set extra provider-specific parameters, replacing any previous value.
    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
        self.additional_params = Some(params);
        self
    }

    /// Derive the output schema from `T` via `schemars::schema_for!`.
    pub fn output_schema<T>(mut self) -> Self
    where
        T: JsonSchema,
    {
        self.output_schema = Some(schema_for!(T));
        self
    }

    /// Use an already-constructed JSON schema as the output schema.
    pub fn output_schema_raw(mut self, schema: Schema) -> Self {
        self.output_schema = Some(schema);
        self
    }

    /// Attach a conversation memory backend (stored behind an `Arc`).
    pub fn memory<B>(mut self, memory: B) -> Self
    where
        B: ConversationMemory + 'static,
    {
        self.memory = Some(Arc::new(memory));
        self
    }

    /// Set the default conversation id.
    pub fn conversation_id(mut self, id: impl Into<String>) -> Self {
        self.default_conversation_id = Some(id.into());
        self
    }

    /// Install a prompt hook, changing the builder's hook type to `P2`.
    ///
    /// Available in every tool state, so it can be called before or after tools
    /// are configured. Any previously installed hook is discarded.
    pub fn hook<P2>(self, hook: P2) -> AgentBuilder<M, P2, ToolState>
    where
        P2: PromptHook<M>,
    {
        AgentBuilder {
            name: self.name,
            description: self.description,
            model: self.model,
            preamble: self.preamble,
            static_context: self.static_context,
            additional_params: self.additional_params,
            max_tokens: self.max_tokens,
            dynamic_context: self.dynamic_context,
            temperature: self.temperature,
            tool_choice: self.tool_choice,
            default_max_turns: self.default_max_turns,
            tool_state: self.tool_state,
            hook: Some(hook),
            output_schema: self.output_schema,
            memory: self.memory,
            default_conversation_id: self.default_conversation_id,
        }
    }
}
impl<M> AgentBuilder<M, (), NoToolConfig>
where
    M: CompletionModel,
{
    /// Create a builder around `model`.
    ///
    /// Every optional setting starts unset, the context lists start empty, no
    /// hook is installed (`P = ()`), and no tools are configured (`NoToolConfig`).
    pub fn new(model: M) -> Self {
        Self {
            model,
            tool_state: NoToolConfig,
            hook: None,
            name: None,
            description: None,
            preamble: None,
            static_context: Vec::new(),
            dynamic_context: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
            tool_choice: None,
            default_max_turns: None,
            output_schema: None,
            memory: None,
            default_conversation_id: None,
        }
    }
}
impl<M, P> AgentBuilder<M, P, NoToolConfig>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Move every builder field into a builder whose tool state is `tool_state`.
    ///
    /// All typestate transitions out of `NoToolConfig` go through this helper,
    /// so the list of carried-over fields is maintained in one place only.
    fn with_tool_state<T>(self, tool_state: T) -> AgentBuilder<M, P, T> {
        AgentBuilder {
            name: self.name,
            description: self.description,
            model: self.model,
            preamble: self.preamble,
            static_context: self.static_context,
            additional_params: self.additional_params,
            max_tokens: self.max_tokens,
            dynamic_context: self.dynamic_context,
            temperature: self.temperature,
            tool_choice: self.tool_choice,
            default_max_turns: self.default_max_turns,
            tool_state,
            hook: self.hook,
            output_schema: self.output_schema,
            memory: self.memory,
            default_conversation_id: self.default_conversation_id,
        }
    }

    /// Use an existing tool server instead of registering tools on the builder.
    ///
    /// The resulting `WithToolServerHandle` builder only offers `build` (plus the
    /// state-agnostic setters), which passes `handle` to the agent unchanged.
    pub fn tool_server_handle(
        self,
        handle: ToolServerHandle,
    ) -> AgentBuilder<M, P, WithToolServerHandle> {
        self.with_tool_state(WithToolServerHandle { handle })
    }

    /// Register the first tool; its name is recorded as a static tool.
    pub fn tool(self, tool: impl Tool + 'static) -> AgentBuilder<M, P, WithBuilderTools> {
        let toolname = tool.name();
        self.with_tool_state(WithBuilderTools {
            static_tools: vec![toolname],
            tools: ToolSet::from_tools(vec![tool]),
            dynamic_tools: vec![],
        })
    }

    /// Register a batch of boxed tools; every name is recorded as a static tool.
    pub fn tools(self, tools: Vec<Box<dyn ToolDyn>>) -> AgentBuilder<M, P, WithBuilderTools> {
        // Names must be collected before `from_tools_boxed` consumes the vector.
        let static_tools = tools.iter().map(|tool| tool.name()).collect();
        let tools = ToolSet::from_tools_boxed(tools);
        self.with_tool_state(WithBuilderTools {
            static_tools,
            tools,
            dynamic_tools: vec![],
        })
    }

    /// Register one MCP tool served by `client`, using `DEFAULT_MCP_TOOL_TIMEOUT`.
    #[cfg(feature = "rmcp")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
    pub fn rmcp_tool(
        self,
        tool: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    ) -> AgentBuilder<M, P, WithBuilderTools> {
        self.rmcp_tool_with_timeout(tool, client, crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT)
    }

    /// Register one MCP tool served by `client` with an explicit per-call
    /// timeout (`None` is forwarded as-is to the tool wrapper).
    #[cfg(feature = "rmcp")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
    pub fn rmcp_tool_with_timeout(
        self,
        tool: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
        timeout: impl Into<Option<std::time::Duration>>,
    ) -> AgentBuilder<M, P, WithBuilderTools> {
        self.with_rmcp_toolset(build_rmcp_tools(vec![tool], client, timeout.into()))
    }

    /// Register several MCP tools served by `client`, using `DEFAULT_MCP_TOOL_TIMEOUT`.
    #[cfg(feature = "rmcp")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
    pub fn rmcp_tools(
        self,
        tools: Vec<rmcp::model::Tool>,
        client: rmcp::service::ServerSink,
    ) -> AgentBuilder<M, P, WithBuilderTools> {
        self.rmcp_tools_with_timeout(tools, client, crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT)
    }

    /// Register several MCP tools served by `client` with an explicit per-call timeout.
    #[cfg(feature = "rmcp")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
    pub fn rmcp_tools_with_timeout(
        self,
        tools: Vec<rmcp::model::Tool>,
        client: rmcp::service::ServerSink,
        timeout: impl Into<Option<std::time::Duration>>,
    ) -> AgentBuilder<M, P, WithBuilderTools> {
        self.with_rmcp_toolset(build_rmcp_tools(tools, client, timeout.into()))
    }

    /// Enter the `WithBuilderTools` state from already-wrapped MCP tools,
    /// recording each tool's name as a static tool.
    #[cfg(feature = "rmcp")]
    fn with_rmcp_toolset(
        self,
        built: Vec<(String, RmcpTool)>,
    ) -> AgentBuilder<M, P, WithBuilderTools> {
        let (static_tools, toolset): (Vec<String>, Vec<RmcpTool>) = built.into_iter().unzip();
        self.with_tool_state(WithBuilderTools {
            static_tools,
            tools: ToolSet::from_tools(toolset),
            dynamic_tools: vec![],
        })
    }

    /// Register a vector index for dynamic tool selection together with the
    /// `toolset` that holds the tool implementations. No static tool names are recorded.
    ///
    /// NOTE(review): `sample` is presumably the number of tools retrieved from the
    /// index per prompt; it is forwarded to `ToolServer::add_dynamic_tools` on build.
    pub fn dynamic_tools(
        self,
        sample: usize,
        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
        toolset: ToolSet,
    ) -> AgentBuilder<M, P, WithBuilderTools> {
        self.with_tool_state(WithBuilderTools {
            static_tools: vec![],
            tools: toolset,
            dynamic_tools: vec![(sample, Arc::new(dynamic_tools))],
        })
    }

    /// Build an agent without tools.
    ///
    /// An empty `ToolServer` is started and its handle given to the agent,
    /// exactly as if that handle had been passed to `tool_server_handle`.
    pub fn build(self) -> Agent<M, P> {
        self.tool_server_handle(ToolServer::new().run()).build()
    }
}
impl<M, P> AgentBuilder<M, P, WithToolServerHandle>
where
M: CompletionModel,
P: PromptHook<M>,
{
pub fn build(self) -> Agent<M, P> {
Agent {
name: self.name,
description: self.description,
model: Arc::new(self.model),
preamble: self.preamble,
static_context: self.static_context,
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
tool_choice: self.tool_choice,
dynamic_context: Arc::new(self.dynamic_context),
tool_server_handle: self.tool_state.handle,
default_max_turns: self.default_max_turns,
hook: self.hook,
output_schema: self.output_schema,
memory: self.memory,
default_conversation_id: self.default_conversation_id,
}
}
}
impl<M, P> AgentBuilder<M, P, WithBuilderTools>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Register another tool; its name is recorded as a static tool.
    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
        let toolname = tool.name();
        self.tool_state.tools.add_tool(tool);
        self.tool_state.static_tools.push(toolname);
        self
    }

    /// Register a batch of boxed tools; every name is recorded as a static tool.
    pub fn tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
        // Names must be collected before `from_tools_boxed` consumes the vector.
        let toolnames: Vec<String> = tools.iter().map(|tool| tool.name()).collect();
        let tools = ToolSet::from_tools_boxed(tools);
        self.tool_state.tools.add_tools(tools);
        self.tool_state.static_tools.extend(toolnames);
        self
    }

    /// Register one more MCP tool served by `client`, using `DEFAULT_MCP_TOOL_TIMEOUT`.
    ///
    /// Mirrors `rmcp_tool` in the `NoToolConfig` state, so a single MCP tool can
    /// still be added after other tools have been registered.
    #[cfg(feature = "rmcp")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
    pub fn rmcp_tool(self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
        self.rmcp_tool_with_timeout(tool, client, crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT)
    }

    /// Register one more MCP tool served by `client` with an explicit per-call
    /// timeout (`None` is forwarded as-is to the tool wrapper).
    #[cfg(feature = "rmcp")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
    pub fn rmcp_tool_with_timeout(
        self,
        tool: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
        timeout: impl Into<Option<std::time::Duration>>,
    ) -> Self {
        self.add_rmcp_tools(build_rmcp_tools(vec![tool], client, timeout.into()))
    }

    /// Register several MCP tools served by `client`, using `DEFAULT_MCP_TOOL_TIMEOUT`.
    #[cfg(feature = "rmcp")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
    pub fn rmcp_tools(
        self,
        tools: Vec<rmcp::model::Tool>,
        client: rmcp::service::ServerSink,
    ) -> Self {
        self.rmcp_tools_with_timeout(tools, client, crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT)
    }

    /// Register several MCP tools served by `client` with an explicit per-call timeout.
    #[cfg(feature = "rmcp")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
    pub fn rmcp_tools_with_timeout(
        self,
        tools: Vec<rmcp::model::Tool>,
        client: rmcp::service::ServerSink,
        timeout: impl Into<Option<std::time::Duration>>,
    ) -> Self {
        self.add_rmcp_tools(build_rmcp_tools(tools, client, timeout.into()))
    }

    /// Append already-wrapped MCP tools, recording each name as a static tool.
    #[cfg(feature = "rmcp")]
    fn add_rmcp_tools(mut self, built: Vec<(String, RmcpTool)>) -> Self {
        for (name, tool) in built {
            self.tool_state.static_tools.push(name);
            self.tool_state.tools.add_tool(tool);
        }
        self
    }

    /// Add another vector index for dynamic tool selection and merge `toolset`
    /// into the registered tools. No static tool names are recorded.
    ///
    /// NOTE(review): `sample` is presumably the number of tools retrieved from the
    /// index per prompt; it is forwarded to `ToolServer::add_dynamic_tools` on build.
    pub fn dynamic_tools(
        mut self,
        sample: usize,
        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
        toolset: ToolSet,
    ) -> Self {
        self.tool_state
            .dynamic_tools
            .push((sample, Arc::new(dynamic_tools)));
        self.tool_state.tools.add_tools(toolset);
        self
    }

    /// Build the agent: start a new `ToolServer` preloaded with the static tool
    /// names, all registered tools and the dynamic-tool indices, and hand its
    /// handle to the agent.
    pub fn build(self) -> Agent<M, P> {
        let tool_server_handle = ToolServer::new()
            .static_tool_names(self.tool_state.static_tools)
            .add_tools(self.tool_state.tools)
            .add_dynamic_tools(self.tool_state.dynamic_tools)
            .run();
        Agent {
            name: self.name,
            description: self.description,
            model: Arc::new(self.model),
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_choice: self.tool_choice,
            dynamic_context: Arc::new(self.dynamic_context),
            tool_server_handle,
            default_max_turns: self.default_max_turns,
            hook: self.hook,
            output_schema: self.output_schema,
            memory: self.memory,
            default_conversation_id: self.default_conversation_id,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{MockAddTool, MockCompletionModel};

    /// No-op hook used to check that `hook` stays callable after tools are added.
    #[derive(Clone)]
    struct BuilderHook;

    impl PromptHook<MockCompletionModel> for BuilderHook {}

    /// `hook` is defined for every tool state, so it must still be available
    /// once the builder has moved into `WithBuilderTools`.
    #[test]
    fn hook_can_be_set_after_tool_configuration() {
        let with_tools = AgentBuilder::new(MockCompletionModel::text("ok")).tool(MockAddTool);
        let _agent = with_tools.hook(BuilderHook).build();
    }

    /// `build_rmcp_tools` must attach the requested timeout to each wrapped
    /// tool, and a short timeout must turn a never-answered call into an error.
    #[cfg(feature = "rmcp")]
    #[tokio::test]
    async fn build_rmcp_tools_threads_timeout_into_built_tools() {
        use crate::tool::ToolDyn;
        use crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT;
        use rmcp::model::{
            CallToolRequestParams, CallToolResult, ClientInfo, ErrorData, Implementation,
            ProtocolVersion, ServerCapabilities, ServerInfo, Tool,
        };
        use rmcp::service::RequestContext;
        use rmcp::{RoleServer, ServerHandler, ServiceExt};
        use std::sync::Arc;
        use std::time::Duration;

        /// MCP server whose `call_tool` never completes.
        #[derive(Clone)]
        struct HangingServer;

        impl ServerHandler for HangingServer {
            fn get_info(&self) -> ServerInfo {
                ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                    .with_protocol_version(ProtocolVersion::LATEST)
                    .with_server_info(Implementation::new("builder-timeout-test", "0.1.0"))
            }

            async fn call_tool(
                &self,
                _request: CallToolRequestParams,
                _context: RequestContext<RoleServer>,
            ) -> Result<CallToolResult, ErrorData> {
                // Never resolves: only the client-side timeout can end the call.
                std::future::pending::<Result<CallToolResult, ErrorData>>().await
            }
        }

        /// Minimal tool definition: given name, empty description, empty schema.
        fn mcp_tool(tool_name: &str) -> Tool {
            Tool::new(
                tool_name.to_string(),
                String::new(),
                Arc::new(serde_json::Map::new()),
            )
        }

        // Two in-memory pipes wired back to back: client -> server and server -> client.
        let (client_tx, server_rx) = tokio::io::duplex(8192);
        let (server_tx, client_rx) = tokio::io::duplex(8192);

        let server_task = tokio::spawn(async move {
            let running = HangingServer
                .serve((server_rx, server_tx))
                .await
                .expect("server start");
            running.waiting().await.expect("server error");
        });

        let client = ClientInfo::default()
            .serve((client_rx, client_tx))
            .await
            .expect("client connect");
        let peer = client.peer().clone();

        // The default timeout is carried through verbatim.
        let with_default = build_rmcp_tools(
            vec![mcp_tool("a")],
            peer.clone(),
            Some(DEFAULT_MCP_TOOL_TIMEOUT),
        );
        assert_eq!(with_default[0].1.timeout(), Some(DEFAULT_MCP_TOOL_TIMEOUT));

        // "No timeout" is carried through verbatim too.
        let without_timeout = build_rmcp_tools(vec![mcp_tool("b")], peer.clone(), None);
        assert_eq!(without_timeout[0].1.timeout(), None);

        // A 200 ms timeout must fail the hanging call well inside the 5 s safety net.
        let short = build_rmcp_tools(
            vec![mcp_tool("hang_forever")],
            peer,
            Some(Duration::from_millis(200)),
        );
        assert_eq!(short.len(), 1);
        let (short_name, short_tool) = &short[0];
        assert_eq!(*short_name, "hang_forever");

        let outcome =
            tokio::time::timeout(Duration::from_secs(5), short_tool.call("{}".to_string())).await;
        let err = outcome
            .expect("built tool hung past the safety timeout")
            .expect_err("call should time out");
        assert!(err.to_string().contains("timed out"), "got: {err}");

        drop(client);
        server_task.abort();
    }
}