use std::sync::Arc;
use tokio::sync::RwLock;
use crate::{
completion::{CompletionError, ToolDefinition},
tool::{Tool, ToolDyn, ToolSet, ToolSetError},
vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
};
/// State shared by every clone of a [`ToolServerHandle`].
///
/// It lives behind an `Arc<RwLock<_>>`, so reads (calling tools, listing
/// definitions) can proceed in parallel while mutations take the write lock.
struct ToolServerState {
    /// Names of tools that are advertised on every definitions request.
    static_tool_names: Vec<String>,
    /// `(sample count, index)` pairs searched with the prompt to select extra tools.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Implementations of all registered tools, looked up by tool name.
    toolset: ToolSet,
}
/// Builder for a tool server.
///
/// It collects static tools, vector-store-backed dynamic tool sources and the
/// tool implementations. [`ToolServer::run`] then turns it into a cloneable
/// [`ToolServerHandle`].
pub struct ToolServer {
    /// Names of tools that are advertised on every definitions request.
    static_tool_names: Vec<String>,
    /// `(sample count, index)` pairs searched with the prompt to select extra tools.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Implementations of all registered tools, looked up by tool name.
    toolset: ToolSet,
}
impl Default for ToolServer {
fn default() -> Self {
Self::new()
}
}
impl ToolServer {
    /// Creates an empty builder: no static tools, no dynamic tool sources and an
    /// empty toolset.
    pub fn new() -> Self {
        Self {
            static_tool_names: Vec::new(),
            dynamic_tools: Vec::new(),
            toolset: ToolSet::default(),
        }
    }

    /// Replaces the list of static tool names.
    ///
    /// Static tools are advertised on every definitions request. Each name is
    /// expected to have an implementation in the toolset; names without one are
    /// skipped with a warning when definitions are collected.
    pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
        self.static_tool_names = names;
        self
    }

    /// Merges `tools` into the server's toolset.
    ///
    /// Tools registered earlier (e.g. via [`ToolServer::tool`]) are kept rather
    /// than discarded, so their recorded static names still resolve to an
    /// implementation.
    pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
        self.toolset.add_tools(tools);
        self
    }

    /// Appends vector-store-backed dynamic tool sources.
    ///
    /// Each entry is `(sample, index)`. When definitions are requested with a
    /// prompt, `index` is searched with that prompt asking for `sample` results.
    /// Sources added earlier are kept.
    pub(crate) fn add_dynamic_tools(
        mut self,
        dyn_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    ) -> Self {
        self.dynamic_tools.extend(dyn_tools);
        self
    }

    /// Registers `tool` as a static tool: it is always advertised and callable.
    ///
    /// Registering a tool whose name is already recorded does not record the name
    /// again, so its definition is reported once.
    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
        let toolname = tool.name();
        self.toolset.add_tool(tool);
        self.register_static_name(toolname);
        self
    }

    /// Registers an MCP tool served by `client` as a static tool.
    ///
    /// The same duplicate-name rule as [`ToolServer::tool`] applies.
    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
    #[cfg(feature = "rmcp")]
    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
        use crate::tool::rmcp::McpTool;
        let toolname = tool.name.clone();
        self.toolset
            .add_tool(McpTool::from_mcp_server(tool, client));
        self.register_static_name(toolname.to_string());
        self
    }

    /// Adds a vector-store-backed source of dynamic tools.
    ///
    /// When definitions are requested with a prompt, `dynamic_tools` is searched
    /// with that prompt asking for `sample` results. The returned ids are looked
    /// up in the toolset. `toolset` supplies those implementations and is merged
    /// into the server's toolset; its tools are not added to the static list.
    pub fn dynamic_tools(
        mut self,
        sample: usize,
        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
        toolset: ToolSet,
    ) -> Self {
        self.dynamic_tools.push((sample, Arc::new(dynamic_tools)));
        self.toolset.add_tools(toolset);
        self
    }

    /// Consumes the builder and returns a cloneable handle to the shared state.
    ///
    /// No background task is spawned; the handle simply shares the collected
    /// state behind an `Arc<RwLock<_>>`.
    pub fn run(self) -> ToolServerHandle {
        ToolServerHandle(Arc::new(RwLock::new(ToolServerState {
            static_tool_names: self.static_tool_names,
            dynamic_tools: self.dynamic_tools,
            toolset: self.toolset,
        })))
    }

    /// Records `name` as a static tool unless it is already recorded.
    ///
    /// A duplicate entry would make definition listing emit the same tool twice.
    fn register_static_name(&mut self, name: String) {
        if !self.static_tool_names.contains(&name) {
            self.static_tool_names.push(name);
        }
    }
}
/// Cheaply cloneable handle to a running tool server.
///
/// Every clone shares the same state through an `Arc<RwLock<_>>`. Tools added or
/// removed through one clone are therefore visible through all the others.
#[derive(Clone)]
pub struct ToolServerHandle(Arc<RwLock<ToolServerState>>);
impl ToolServerHandle {
    /// Registers `tool` as a static tool on the running server.
    ///
    /// The tool is inserted into the toolset and its name is recorded, so
    /// [`ToolServerHandle::get_tool_defs`] always advertises it. Adding a tool
    /// whose name is already recorded does not record the name a second time, so
    /// its definition is reported once rather than once per registration.
    ///
    /// # Errors
    ///
    /// Never returns an error in the current implementation.
    pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
        let name = tool.name();
        let mut state = self.0.write().await;
        if !state.static_tool_names.contains(&name) {
            state.static_tool_names.push(name);
        }
        state.toolset.add_tool_boxed(Box::new(tool));
        Ok(())
    }

    /// Merges `toolset` into the server's toolset.
    ///
    /// The tools become callable through [`ToolServerHandle::call_tool`]. Their
    /// names are *not* added to the static tool list.
    ///
    /// # Errors
    ///
    /// Never returns an error in the current implementation.
    pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
        let mut state = self.0.write().await;
        state.toolset.add_tools(toolset);
        Ok(())
    }

    /// Removes `tool_name` from the static tool list and from the toolset.
    ///
    /// Returns `Ok(())` even when no tool with that name exists. Dynamic tool
    /// indices are not modified; if one later returns this id, the id is skipped
    /// with a warning.
    ///
    /// # Errors
    ///
    /// Never returns an error in the current implementation.
    pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
        let mut state = self.0.write().await;
        state.static_tool_names.retain(|x| *x != tool_name);
        state.toolset.delete_tool(tool_name);
        Ok(())
    }

    /// Invokes the tool named `tool_name` with the raw `args` string.
    ///
    /// The read lock is held only while cloning the tool handle out of the
    /// toolset. Tools therefore run concurrently, and writers are not blocked
    /// while a tool executes.
    ///
    /// # Errors
    ///
    /// - [`ToolSetError::ToolNotFoundError`], wrapped in
    ///   [`ToolServerError::ToolsetError`], if no tool with that name is registered.
    /// - [`ToolSetError::ToolCallError`], wrapped likewise, if the tool call fails.
    pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
        let tool = {
            let state = self.0.read().await;
            state.toolset.get(tool_name).cloned()
        };
        match tool {
            Some(tool) => {
                // `args` is expected to be JSON text. Parse it so the log shows
                // indented JSON instead of a single escaped string literal, and
                // fall back to the raw text when it is not valid JSON.
                tracing::debug!(target: "rig",
                    "Calling tool {tool_name} with args:\n{}",
                    serde_json::from_str::<serde_json::Value>(args)
                        .and_then(|value| serde_json::to_string_pretty(&value))
                        .unwrap_or_else(|_| args.to_string())
                );
                tool.call(args.to_string())
                    .await
                    .map_err(|e| ToolSetError::ToolCallError(e).into())
            }
            None => Err(ToolServerError::ToolsetError(
                ToolSetError::ToolNotFoundError(tool_name.to_string()),
            )),
        }
    }

    /// Collects the tool definitions to advertise for an optional `prompt`.
    ///
    /// When `prompt` is `Some`, every dynamic tool index is searched concurrently
    /// with the prompt. The matching tools' definitions, built with the prompt
    /// text, come first. Definitions of all static tools follow, built with an
    /// empty string.
    ///
    /// Each tool name is described at most once. Ids returned by several indices
    /// are kept only on first occurrence. Ids that are also static tools are left
    /// to the static pass. Names without an implementation in the toolset are
    /// skipped with a warning.
    ///
    /// # Errors
    ///
    /// Returns [`ToolServerError::DefinitionError`] if any vector store search fails.
    pub async fn get_tool_defs(
        &self,
        prompt: Option<String>,
    ) -> Result<Vec<ToolDefinition>, ToolServerError> {
        // Snapshot the configuration and release the lock before awaiting any
        // vector store or tool, so slow lookups never block writers.
        let (static_tool_names, dynamic_tools) = {
            let state = self.0.read().await;
            (state.static_tool_names.clone(), state.dynamic_tools.clone())
        };

        let mut tools = Vec::new();

        if let Some(text) = prompt.as_ref() {
            // Search every index concurrently; the first failure aborts the call.
            let search_futures = dynamic_tools.iter().map(|(num_sample, index)| {
                let text = text.clone();
                let num_sample = *num_sample;
                let index = index.clone();
                async move {
                    let req = VectorSearchRequest::builder()
                        .query(text)
                        .samples(num_sample as u64)
                        .build();
                    let ids = index
                        .as_ref()
                        .top_n_ids(req.map_filter(Filter::interpret))
                        .await?
                        .into_iter()
                        .map(|(_, id)| id)
                        .collect::<Vec<String>>();
                    Ok::<_, VectorStoreError>(ids)
                }
            });
            let found_ids: Vec<String> = futures::future::try_join_all(search_futures)
                .await
                .map_err(|e| {
                    ToolServerError::DefinitionError(CompletionError::RequestError(Box::new(e)))
                })?
                .into_iter()
                .flatten()
                .collect();

            // Keep the first occurrence of each id, and leave static tools to the
            // static pass below, so no tool name is advertised twice.
            let mut seen: std::collections::HashSet<&str> =
                static_tool_names.iter().map(String::as_str).collect();
            let dynamic_tool_ids: Vec<&str> = found_ids
                .iter()
                .map(String::as_str)
                .filter(|id| seen.insert(*id))
                .collect();

            let dynamic_tool_handles: Vec<_> = {
                let state = self.0.read().await;
                dynamic_tool_ids
                    .iter()
                    .filter_map(|id| {
                        let handle = state.toolset.get(*id).cloned();
                        if handle.is_none() {
                            tracing::warn!("Tool implementation not found in toolset: {}", id);
                        }
                        handle
                    })
                    .collect()
            };

            for tool in dynamic_tool_handles {
                tools.push(tool.definition(text.clone()).await);
            }
        }

        let static_tool_handles: Vec<_> = {
            let state = self.0.read().await;
            static_tool_names
                .iter()
                .filter_map(|toolname| {
                    let handle = state.toolset.get(toolname).cloned();
                    if handle.is_none() {
                        tracing::warn!("Tool implementation not found in toolset: {}", toolname);
                    }
                    handle
                })
                .collect()
        };
        for tool in static_tool_handles {
            tools.push(tool.definition(String::new()).await);
        }

        Ok(tools)
    }
}
/// Errors produced by [`ToolServerHandle`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolServerError {
    /// A toolset-level failure, such as an unknown tool name or a failing tool call.
    #[error("Toolset error: {0}")]
    ToolsetError(#[from] ToolSetError),
    /// Tool definitions could not be collected, e.g. because a vector store search failed.
    #[error("Failed to retrieve tool definitions: {0}")]
    DefinitionError(CompletionError),
}
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use crate::{
        test_utils::{
            BarrierMockToolIndex, MockAddTool, MockBarrierTool, MockControlledTool,
            MockSubtractTool, MockToolIndex,
        },
        tool::{ToolSet, server::ToolServer},
    };

    /// A static tool can be added, listed, called and removed through the handle.
    #[tokio::test]
    pub async fn test_toolserver() {
        let server = ToolServer::new();
        let handle = server.run();
        handle.add_tool(MockAddTool).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");
        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 0);
    }

    /// Dynamic tools are only advertised when a prompt is given.
    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools() {
        // `add` is static; `subtract` is only reachable through the mock index.
        let mock_index = MockToolIndex::new(["subtract"]);
        let server = ToolServer::new().tool(MockAddTool).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![MockSubtractTool]),
        );
        let handle = server.run();

        // Without a prompt, no index is searched.
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);
        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    /// An index returning an id with no implementation is skipped, not an error.
    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools_missing_implementation() {
        let mock_index = MockToolIndex::new(["nonexistent_tool"]);
        let server = ToolServer::new()
            .tool(MockAddTool)
            .dynamic_tools(1, mock_index, ToolSet::default());
        let handle = server.run();
        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    /// Tool calls must not serialize on the server lock.
    #[tokio::test]
    pub async fn test_toolserver_concurrent_tool_execution() {
        // Each call waits on a barrier sized to the number of calls, so the calls
        // only complete if all of them are in flight at the same time.
        let num_calls = 3;
        let barrier = Arc::new(tokio::sync::Barrier::new(num_calls));
        let server = ToolServer::new().tool(MockBarrierTool::new(barrier.clone()));
        let handle = server.run();
        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("barrier_tool", "{}"))
            .collect();
        let result =
            tokio::time::timeout(Duration::from_secs(1), futures::future::join_all(futures)).await;
        assert!(
            result.is_ok(),
            "Tool execution deadlocked! Tools are executing sequentially instead of concurrently."
        );
        for res in result.unwrap() {
            assert!(res.is_ok(), "Tool call failed: {:?}", res);
            assert_eq!(res.unwrap(), "done");
        }
    }

    /// Writers must not be blocked while a tool is executing.
    #[tokio::test]
    pub async fn test_toolserver_write_while_tool_running() {
        let started = Arc::new(tokio::sync::Notify::new());
        let allow_finish = Arc::new(tokio::sync::Notify::new());
        let tool = MockControlledTool::new(started.clone(), allow_finish.clone());
        let server = ToolServer::new().tool(tool);
        let handle = server.run();
        let handle_clone = handle.clone();
        let call_task =
            tokio::spawn(async move { handle_clone.call_tool("controlled", "{}").await });

        // Wait until the tool is running, then attempt a write while it is paused.
        started.notified().await;
        let add_result =
            tokio::time::timeout(Duration::from_secs(1), handle.add_tool(MockAddTool)).await;
        assert!(
            add_result.is_ok(),
            "Writing to ToolServer deadlocked! The read lock is being held across tool execution."
        );
        assert!(add_result.unwrap().is_ok());

        allow_finish.notify_one();
        let call_result = call_task.await.unwrap();
        assert_eq!(call_result.unwrap(), "42");
    }

    /// Dynamic tool indices must be searched concurrently.
    #[tokio::test]
    pub async fn test_toolserver_parallel_dynamic_tool_fetching() {
        // Both indices wait on a two-party barrier, so the lookup only finishes
        // if the searches overlap.
        let barrier = Arc::new(tokio::sync::Barrier::new(2));
        let index1 = BarrierMockToolIndex::new(barrier.clone(), "add");
        let index2 = BarrierMockToolIndex::new(barrier.clone(), "subtract");
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockAddTool);
        toolset.add_tool(MockSubtractTool);
        let server = ToolServer::new()
            .dynamic_tools(1, index1, ToolSet::default())
            .dynamic_tools(1, index2, toolset);
        let handle = server.run();
        let get_defs = tokio::time::timeout(
            Duration::from_secs(1),
            handle.get_tool_defs(Some("do math".to_string())),
        )
        .await;
        assert!(
            get_defs.is_ok(),
            "Dynamic tools were fetched sequentially! The first query deadlocked waiting for the second query to start."
        );
        let defs = get_defs.unwrap().unwrap();
        assert_eq!(defs.len(), 2);
        let tool_names: Vec<&str> = defs.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }
}