use std::{
collections::{HashMap, HashSet},
sync::{Arc, OnceLock, RwLock},
};
use futures::future::BoxFuture;
use tokio::sync::Mutex;
use crate::{
Arguments, Error, Result, ServerCtx,
schema::{CallToolResult, Cursor, ListToolsResult, ServerNotification, Tool, ToolSchema},
};
/// Read-only view of group activation state handed to [`Visibility::When`] predicates.
///
/// The view borrows an owned snapshot copied out of the group registry, so a predicate sees a
/// consistent picture even if groups are toggled while it runs.
pub struct ToolSetView<'snap> {
    groups: &'snap HashMap<String, GroupSnapshot>,
}

/// Boxed, `Send` future produced by tool handlers.
#[doc(hidden)]
pub type ToolFuture<'a> = BoxFuture<'a, Result<CallToolResult>>;

impl ToolSetView<'_> {
    /// Returns `true` when `name` is registered and active and every ancestor group is active.
    pub fn is_group_active(&self, name: &str) -> bool {
        let snapshot = self.groups;
        is_group_active_snapshot(snapshot, name)
    }
}
/// Rule deciding whether a registered tool is listed by `ToolSet::list_tools` and callable
/// through `ToolSet::call_tool`.
#[derive(Clone)]
pub enum Visibility {
    /// Visible regardless of group state.
    Always,
    /// Visible only while the named group and all of its ancestors are active.
    Group(String),
    /// Visible whenever the predicate returns `true` for the current group snapshot.
    When(Arc<dyn Fn(&ToolSetView) -> bool + Send + Sync>),
}

impl Visibility {
    /// Evaluates this rule against a snapshot of group state.
    fn is_visible(&self, view: &ToolSetView) -> bool {
        match self {
            Visibility::Group(group) => view.is_group_active(group),
            Visibility::When(check) => check(view),
            Visibility::Always => true,
        }
    }
}
/// Summary of one registered group, as returned by `ToolSet::list_groups`.
///
/// Derives the common traits so callers can log, compare, and copy listings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GroupInfo {
    /// Group name (dotted path for nested groups).
    pub name: String,
    /// Description taken from the group's configuration.
    pub description: String,
    /// The group's own activation flag; ancestor groups are not consulted.
    pub active: bool,
    /// Parent group name, or `None` for a top-level group.
    pub parent: Option<String>,
    /// Number of registered tools whose group path (text before the last `.`) is exactly
    /// this group, including auto-generated activate/deactivate tools.
    pub tool_count: usize,
}
/// Configuration supplied when registering a tool group.
pub struct GroupConfig {
    /// Group name. `Group::register` treats it as a single path segment and replaces it with
    /// the qualified dotted path before the group is registered.
    pub name: String,
    /// Human-readable description, also used in the auto-generated activation tool text.
    pub description: String,
    /// Parent group path, or `None` for a top-level group.
    pub parent: Option<String>,
    /// Hook awaited before the group is marked active.
    pub on_activate: Option<ActivationHook>,
    /// Hook awaited before the group is marked inactive.
    pub on_deactivate: Option<ActivationHook>,
    /// When `true`, a `<group>.deactivate` tool is generated alongside `<group>.activate`.
    pub show_deactivator: bool,
}

/// Async callback run on group activation/deactivation. An error stops the remaining hooks
/// and leaves group state unchanged.
pub type ActivationHook =
    Box<dyn Fn(&ServerCtx) -> BoxFuture<'static, Result<()>> + Sync + Send>;
pub trait Group: GroupDispatch {
fn config(&self) -> GroupConfig;
fn register(&self, toolset: &ToolSet, parent: Option<&str>) -> Result<()>
where
Self: Sized,
{
GroupRegistration::register_with_override(self, toolset, parent, None)
}
}
/// Tool-level operations of a group: registering its tools and dispatching calls to them.
#[doc(hidden)]
pub trait GroupDispatch {
    /// Registers this group's tools in `toolset` under the qualified group path `group_name`.
    fn register_tools(&self, toolset: &ToolSet, group_name: &str) -> Result<()>;

    /// Dispatches a call for tool `name` with `arguments`.
    ///
    /// NOTE(review): whether `name` is the qualified or base tool name is up to the
    /// implementation — confirm against the generator.
    fn call_tool<'call>(
        &'call self,
        ctx: &'call ServerCtx,
        name: &'call str,
        arguments: Option<Arguments>,
    ) -> ToolFuture<'call>;
}

/// Group registration with an optional replacement for the group's own name segment.
///
/// Blanket-implemented for every type that is both [`Group`] and [`GroupDispatch`].
#[doc(hidden)]
pub trait GroupRegistration: Group + GroupDispatch {
    /// Registers the group under `parent`, using `segment_override` (or the configured name
    /// when `None`) as its last path segment, then registers its tools.
    fn register_with_override(
        &self,
        toolset: &ToolSet,
        parent: Option<&str>,
        segment_override: Option<&str>,
    ) -> Result<()>;
}
impl<G> GroupRegistration for G
where
    G: Group + GroupDispatch,
{
    /// Validates the segment, builds the qualified name `parent.segment` (or just `segment`
    /// at top level), registers the group, then registers its tools under that name.
    fn register_with_override(
        &self,
        toolset: &ToolSet,
        parent: Option<&str>,
        segment_override: Option<&str>,
    ) -> Result<()> {
        let mut config = self.config();
        let qualified = {
            let segment = match segment_override {
                Some(segment) => segment,
                None => config.name.as_str(),
            };
            validate_group_segment(segment)?;
            match parent {
                Some(parent) => format!("{parent}.{segment}"),
                None => segment.to_owned(),
            }
        };
        // The caller-supplied parent wins over whatever the config declared.
        config.name = qualified.clone();
        config.parent = parent.map(str::to_owned);
        toolset.register_group(config)?;
        self.register_tools(toolset, &qualified)
    }
}
/// Registry of tools and tool groups.
///
/// Cloning is cheap: every field is an `Arc` (or wraps one), so all clones share the same
/// tools, groups, one-time registration flag, and activation lock.
#[derive(Clone)]
pub struct ToolSet {
    tools: Arc<RwLock<HashMap<String, ToolEntry>>>,
    groups: ToolGroups,
    registration: Arc<OnceLock<()>>,
    // Async mutex: held across hook execution so activations/deactivations don't interleave.
    activation_lock: Arc<Mutex<()>>,
}

impl Default for ToolSet {
    fn default() -> Self {
        let tools = Arc::new(RwLock::new(HashMap::new()));
        let activation_lock = Arc::new(Mutex::new(()));
        Self {
            tools,
            groups: Default::default(),
            registration: Arc::default(),
            activation_lock,
        }
    }
}
impl ToolSet {
pub fn register_group(&self, config: GroupConfig) -> Result<()> {
validate_group_path(&config.name)?;
let auto_config = AutoGroupConfig::from(&config);
self.groups.register_group(config)?;
if let Err(error) = self.register_auto_tools(&auto_config) {
self.groups.unregister_group(&auto_config.name);
return Err(error);
}
Ok(())
}
pub fn register_exclusion(&self, groups: &[&str]) -> Result<()> {
self.groups.register_exclusion(groups)
}
pub fn qualified_name(group: &str, base: &str) -> String {
if group.is_empty() {
base.to_string()
} else {
format!("{group}.{base}")
}
}
pub async fn activate_group(&self, name: &str, ctx: &ServerCtx) -> Result<bool> {
let _guard = self.activation_lock.lock().await;
let plan = self.groups.plan_activation(name)?;
run_activation_hooks(&plan.hooks, ctx).await?;
let changed = self.groups.apply_activation(plan)?;
if changed {
self.notify_list_changed(ctx)?;
}
Ok(changed)
}
pub async fn deactivate_group(&self, name: &str, ctx: &ServerCtx) -> Result<bool> {
let _guard = self.activation_lock.lock().await;
let plan = self.groups.plan_deactivation(name)?;
run_activation_hooks(&plan.hooks, ctx).await?;
let changed = self.groups.apply_deactivation(plan)?;
if changed {
self.notify_list_changed(ctx)?;
}
Ok(changed)
}
pub fn is_group_active(&self, name: &str) -> bool {
self.groups.is_active(name)
}
pub fn list_groups(&self) -> Vec<GroupInfo> {
let tool_counts = self.group_tool_counts();
let mut groups = self.groups.list_groups(&tool_counts);
groups.sort_by(|a, b| a.name.cmp(&b.name));
groups
}
pub fn register<F>(&self, name: &str, tool: Tool, handler: F) -> Result<()>
where
F: for<'a> Fn(&'a ServerCtx, Option<Arguments>) -> ToolFuture<'a> + Send + Sync + 'static,
{
self.register_with_visibility(name, tool, Visibility::Always, handler)
}
pub fn register_with_visibility<F>(
&self,
name: &str,
tool: Tool,
visibility: Visibility,
handler: F,
) -> Result<()>
where
F: for<'a> Fn(&'a ServerCtx, Option<Arguments>) -> ToolFuture<'a> + Send + Sync + 'static,
{
let handler: ToolHandler = Arc::new(handler);
self.register_entry(name, tool, visibility, Some(handler), ToolOrigin::Explicit)
}
pub fn unregister(&self, name: &str) -> Option<Tool> {
self.tools
.write()
.unwrap_or_else(|err| err.into_inner())
.remove(name)
.map(|entry| entry.tool)
}
pub fn notify_list_changed(&self, ctx: &ServerCtx) -> Result<()> {
ctx.notify(ServerNotification::tool_list_changed())
}
pub fn list_tools(&self, cursor: Option<Cursor>) -> Result<ListToolsResult> {
let snapshot = self.groups.snapshot();
let view = ToolSetView { groups: &snapshot };
let entries = self
.tools
.read()
.unwrap_or_else(|err| err.into_inner())
.values()
.cloned()
.collect::<Vec<_>>();
let mut tools = entries
.into_iter()
.filter(|entry| entry.visibility.is_visible(&view))
.map(|entry| entry.tool)
.collect::<Vec<_>>();
tools.sort_by(|a, b| a.name.cmp(&b.name));
let offset = cursor.map(parse_cursor_offset).transpose()?.unwrap_or(0);
if offset > tools.len() {
return Err(Error::InvalidParams("cursor out of range".to_string()));
}
let tools = tools.into_iter().skip(offset).collect();
Ok(ListToolsResult {
tools,
next_cursor: None,
})
}
pub async fn call_tool(
&self,
ctx: &ServerCtx,
name: &str,
arguments: Option<Arguments>,
) -> Result<CallToolResult> {
if !self.is_tool_visible(name) {
return Err(Error::ToolNotFound(name.to_string()));
}
let handler = self.tool_handler(name);
let handler = handler.ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
handler(ctx, arguments).await
}
pub async fn call_tool_with<H, F>(
&self,
handler: &H,
ctx: &ServerCtx,
name: &str,
arguments: Option<Arguments>,
dispatch: F,
) -> Result<CallToolResult>
where
F: for<'a> Fn(&'a H, &'a ServerCtx, &'a str, Option<Arguments>) -> ToolFuture<'a>,
{
if !self.is_tool_visible(name) {
return Err(Error::ToolNotFound(name.to_string()));
}
if let Some(tool_handler) = self.tool_handler(name) {
return tool_handler(ctx, arguments).await;
}
dispatch(handler, ctx, name, arguments).await
}
#[doc(hidden)]
pub fn register_schema(&self, name: &str, tool: Tool, visibility: Visibility) -> Result<()> {
self.register_entry(name, tool, visibility, None, ToolOrigin::Explicit)
}
#[doc(hidden)]
pub fn ensure_registered<F>(&self, register: F)
where
F: FnOnce(),
{
let _ = self.registration.get_or_init(|| {
register();
});
}
#[doc(hidden)]
pub fn is_tool_visible(&self, name: &str) -> bool {
let Some(entry) = self.tool_entry(name) else {
return false;
};
let snapshot = self.groups.snapshot();
let view = ToolSetView { groups: &snapshot };
entry.visibility.is_visible(&view)
}
#[doc(hidden)]
pub async fn call_dynamic_tool(
&self,
ctx: &ServerCtx,
name: &str,
arguments: Option<Arguments>,
) -> Result<CallToolResult> {
self.call_tool(ctx, name, arguments).await
}
fn tool_group_name(tool_name: &str) -> Option<String> {
tool_name
.rsplit_once('.')
.map(|x| x.0)
.map(|group| group.to_string())
}
fn group_tool_counts(&self) -> HashMap<String, usize> {
let tools = self.tools.read().unwrap_or_else(|err| err.into_inner());
let mut counts = HashMap::new();
for name in tools.keys() {
if let Some(group) = Self::tool_group_name(name) {
*counts.entry(group).or_insert(0) += 1;
}
}
counts
}
fn register_entry(
&self,
name: &str,
mut tool: Tool,
visibility: Visibility,
handler: Option<ToolHandler>,
origin: ToolOrigin,
) -> Result<()> {
self.validate_tool_registration(name, &visibility)?;
tool.name = name.to_string();
let mut tools = self.tools.write().unwrap_or_else(|err| err.into_inner());
if let Some(existing) = tools.get(name)
&& (existing.origin != ToolOrigin::AutoGroup || origin != ToolOrigin::Explicit)
{
return Err(Error::InvalidConfiguration(format!(
"tool already registered: {name}"
)));
}
tools.insert(
name.to_string(),
ToolEntry {
tool,
visibility,
handler,
origin,
},
);
Ok(())
}
fn validate_tool_registration(&self, name: &str, visibility: &Visibility) -> Result<()> {
let (group_path, _base) = split_tool_name(name)?;
match visibility {
Visibility::Group(group_name) => {
if group_path != Some(group_name.as_str()) {
return Err(Error::InvalidConfiguration(format!(
"tool '{name}' must be prefixed with group '{group_name}'"
)));
}
self.groups.ensure_group_exists(group_name)?;
}
_ => {
if let Some(group) = group_path {
self.groups.ensure_group_exists(group)?;
}
}
}
Ok(())
}
fn tool_entry(&self, name: &str) -> Option<ToolEntry> {
self.tools
.read()
.unwrap_or_else(|err| err.into_inner())
.get(name)
.cloned()
}
fn tool_handler(&self, name: &str) -> Option<ToolHandler> {
self.tool_entry(name).and_then(|entry| entry.handler)
}
fn register_auto_tools(&self, config: &AutoGroupConfig) -> Result<()> {
let group = config.name.clone();
let activate_name = Self::qualified_name(&group, "activate");
let activate_tool = activation_tool(&activate_name, &config.description, true);
let toolset = self.clone();
let handler_group = group.clone();
let activate_handler: ToolHandler =
Arc::new(move |ctx: &ServerCtx, _args: Option<Arguments>| {
let toolset = toolset.clone();
let handler_group = handler_group.clone();
let ctx = ctx.clone();
Box::pin(async move {
let _ = toolset.activate_group(&handler_group, &ctx).await?;
Ok(CallToolResult::new())
})
});
let mut entries = vec![(activate_name, activate_tool, Some(activate_handler))];
if config.show_deactivator {
let deactivate_name = Self::qualified_name(&group, "deactivate");
let deactivate_tool = activation_tool(&deactivate_name, &config.description, false);
let toolset = self.clone();
let handler_group = group;
let deactivate_handler: ToolHandler =
Arc::new(move |ctx: &ServerCtx, _args: Option<Arguments>| {
let toolset = toolset.clone();
let handler_group = handler_group.clone();
let ctx = ctx.clone();
Box::pin(async move {
let _ = toolset.deactivate_group(&handler_group, &ctx).await?;
Ok(CallToolResult::new())
})
});
entries.push((deactivate_name, deactivate_tool, Some(deactivate_handler)));
}
let visibility = Visibility::Always;
for (name, tool, _handler) in &mut entries {
self.validate_tool_registration(name, &visibility)?;
tool.name = name.clone();
}
let mut tools = self.tools.write().unwrap_or_else(|err| err.into_inner());
for (name, _, _) in &entries {
if tools.contains_key(name) {
return Err(Error::InvalidConfiguration(format!(
"tool already registered: {name}"
)));
}
}
for (name, tool, handler) in entries {
tools.insert(
name,
ToolEntry {
tool,
visibility: visibility.clone(),
handler,
origin: ToolOrigin::AutoGroup,
},
);
}
Ok(())
}
}
/// Type-erased async handler for a tool registered with a closure.
type ToolHandler =
    Arc<dyn for<'ctx> Fn(&'ctx ServerCtx, Option<Arguments>) -> ToolFuture<'ctx> + Send + Sync>;

/// Shareable form of [`ActivationHook`], so hooks can be cloned into a [`GroupChangePlan`]
/// and awaited after the registry lock has been released.
type SharedActivationHook =
    Arc<dyn Fn(&ServerCtx) -> BoxFuture<'static, Result<()>> + Send + Sync>;

/// A registered tool: schema, visibility rule, optional handler, and provenance.
#[derive(Clone)]
struct ToolEntry {
    tool: Tool,
    visibility: Visibility,
    /// `None` for schema-only registrations, which `ToolSet::call_tool_with` forwards to its
    /// `dispatch` callback.
    handler: Option<ToolHandler>,
    origin: ToolOrigin,
}

/// Where a [`ToolEntry`] came from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ToolOrigin {
    /// Generated `activate`/`deactivate` tool; an explicit registration may replace it.
    AutoGroup,
    /// Registered through one of the `ToolSet::register*` methods.
    Explicit,
}
struct AutoGroupConfig {
name: String,
description: String,
show_deactivator: bool,
}
impl From<&GroupConfig> for AutoGroupConfig {
fn from(config: &GroupConfig) -> Self {
Self {
name: config.name.clone(),
description: config.description.clone(),
show_deactivator: config.show_deactivator,
}
}
}
/// Copy of one group's activation flag and parent link, used to evaluate visibility after
/// the registry lock has been released.
#[derive(Clone)]
struct GroupSnapshot {
    active: bool,
    parent: Option<String>,
}

/// Cheaply clonable handle to the shared, lock-protected group registry.
#[derive(Clone, Default)]
struct ToolGroups {
    registry: Arc<RwLock<GroupRegistry>>,
}
impl ToolGroups {
fn register_group(&self, config: GroupConfig) -> Result<()> {
let mut registry = self.registry.write().unwrap_or_else(|err| err.into_inner());
if registry.groups.contains_key(&config.name) {
return Err(Error::InvalidConfiguration(format!(
"group already registered: {}",
config.name
)));
}
if let Some(parent) = &config.parent {
if !registry.groups.contains_key(parent) {
return Err(Error::GroupNotFound(parent.clone()));
}
if !config.name.starts_with(&format!("{parent}.")) {
return Err(Error::InvalidConfiguration(format!(
"group '{}' must be nested under parent '{}'",
config.name, parent
)));
}
}
let state = GroupState {
description: config.description.clone(),
active: false,
parent: config.parent.clone(),
on_activate: config.on_activate.map(Arc::from),
on_deactivate: config.on_deactivate.map(Arc::from),
};
registry.groups.insert(config.name, state);
Ok(())
}
fn unregister_group(&self, name: &str) {
let mut registry = self.registry.write().unwrap_or_else(|err| err.into_inner());
registry.groups.remove(name);
registry
.exclusions
.retain(|exclusion| !exclusion.iter().any(|group| group == name));
}
fn register_exclusion(&self, groups: &[&str]) -> Result<()> {
if groups.len() < 2 {
return Err(Error::InvalidConfiguration(
"exclusion sets require at least two groups".to_string(),
));
}
let mut unique = HashSet::new();
let mut entries = Vec::new();
let registry = self.registry.read().unwrap_or_else(|err| err.into_inner());
for group in groups {
if !registry.groups.contains_key(*group) {
return Err(Error::GroupNotFound((*group).to_string()));
}
if unique.insert(*group) {
entries.push((*group).to_string());
}
}
drop(registry);
let mut registry = self.registry.write().unwrap_or_else(|err| err.into_inner());
registry.exclusions.push(entries);
Ok(())
}
fn ensure_group_exists(&self, name: &str) -> Result<()> {
let registry = self.registry.read().unwrap_or_else(|err| err.into_inner());
if registry.groups.contains_key(name) {
Ok(())
} else {
Err(Error::GroupNotFound(name.to_string()))
}
}
fn is_active(&self, name: &str) -> bool {
let snapshot = self.snapshot();
is_group_active_snapshot(&snapshot, name)
}
fn snapshot(&self) -> HashMap<String, GroupSnapshot> {
let registry = self.registry.read().unwrap_or_else(|err| err.into_inner());
registry
.groups
.iter()
.map(|(name, state)| {
(
name.clone(),
GroupSnapshot {
active: state.active,
parent: state.parent.clone(),
},
)
})
.collect()
}
fn list_groups(&self, tool_counts: &HashMap<String, usize>) -> Vec<GroupInfo> {
let registry = self.registry.read().unwrap_or_else(|err| err.into_inner());
registry
.groups
.iter()
.map(|(name, state)| GroupInfo {
name: name.clone(),
description: state.description.clone(),
active: state.active,
parent: state.parent.clone(),
tool_count: *tool_counts.get(name).unwrap_or(&0),
})
.collect()
}
fn plan_activation(&self, name: &str) -> Result<GroupChangePlan> {
let registry = self.registry.read().unwrap_or_else(|err| err.into_inner());
let snapshot = snapshot_from_registry(®istry);
let target = registry
.groups
.get(name)
.ok_or_else(|| Error::GroupNotFound(name.to_string()))?;
if let Some(parent) = &target.parent
&& !is_group_active_snapshot(&snapshot, parent)
{
return Err(Error::GroupInactive {
group: name.to_string(),
parent: parent.clone(),
});
}
let mut to_deactivate = Vec::new();
let mut deactivation_set = HashSet::new();
for exclusion in ®istry.exclusions {
if exclusion.iter().any(|group| group == name) {
for other in exclusion {
if other != name && is_group_active_snapshot(&snapshot, other) {
for group in collect_descendants_including_self(®istry.groups, other) {
if deactivation_set.insert(group.clone()) {
to_deactivate.push(group);
}
}
}
}
}
}
let mut hooks = Vec::new();
for group in &to_deactivate {
if let Some(state) = registry.groups.get(group)
&& state.active
&& let Some(hook) = &state.on_deactivate
{
hooks.push(hook.clone());
}
}
if !target.active
&& let Some(hook) = &target.on_activate
{
hooks.push(hook.clone());
}
let changed = !to_deactivate.is_empty() || !target.active;
Ok(GroupChangePlan {
target: name.to_string(),
deactivate: to_deactivate,
activate: !target.active,
hooks,
changed,
})
}
fn apply_activation(&self, plan: GroupChangePlan) -> Result<bool> {
if !plan.changed {
return Ok(false);
}
let mut registry = self.registry.write().unwrap_or_else(|err| err.into_inner());
if !registry.groups.contains_key(&plan.target) {
return Err(Error::GroupNotFound(plan.target));
}
for group in plan.deactivate {
if let Some(state) = registry.groups.get_mut(&group) {
state.active = false;
}
}
if plan.activate
&& let Some(state) = registry.groups.get_mut(&plan.target)
{
state.active = true;
}
Ok(true)
}
fn plan_deactivation(&self, name: &str) -> Result<GroupChangePlan> {
let registry = self.registry.read().unwrap_or_else(|err| err.into_inner());
let target = registry
.groups
.get(name)
.ok_or_else(|| Error::GroupNotFound(name.to_string()))?;
let mut to_deactivate = Vec::new();
for group in collect_descendants_including_self(®istry.groups, name) {
to_deactivate.push(group);
}
let mut hooks = Vec::new();
for group in &to_deactivate {
if let Some(state) = registry.groups.get(group)
&& state.active
&& let Some(hook) = &state.on_deactivate
{
hooks.push(hook.clone());
}
}
let changed = target.active
|| to_deactivate.iter().any(|group| {
registry
.groups
.get(group)
.map(|state| state.active)
.unwrap_or(false)
});
Ok(GroupChangePlan {
target: name.to_string(),
deactivate: to_deactivate,
activate: false,
hooks,
changed,
})
}
fn apply_deactivation(&self, plan: GroupChangePlan) -> Result<bool> {
if !plan.changed {
return Ok(false);
}
let mut registry = self.registry.write().unwrap_or_else(|err| err.into_inner());
if !registry.groups.contains_key(&plan.target) {
return Err(Error::GroupNotFound(plan.target));
}
for group in plan.deactivate {
if let Some(state) = registry.groups.get_mut(&group) {
state.active = false;
}
}
Ok(true)
}
}
/// Group state guarded by the lock inside [`ToolGroups`].
#[derive(Default)]
struct GroupRegistry {
    /// Registered groups keyed by name (dotted path for nested groups).
    groups: HashMap<String, GroupState>,
    /// Sets of group names where activating one deactivates the other active members.
    exclusions: Vec<Vec<String>>,
}

/// Per-group registry record.
struct GroupState {
    description: String,
    /// The group's own flag; effective activeness also requires every ancestor to be active.
    active: bool,
    parent: Option<String>,
    on_activate: Option<SharedActivationHook>,
    on_deactivate: Option<SharedActivationHook>,
}

/// Result of planning an activation or deactivation, computed under a read lock and applied
/// later under a write lock.
struct GroupChangePlan {
    /// Group being activated or deactivated.
    target: String,
    /// Groups to mark inactive; each subtree lists descendants before their ancestors.
    deactivate: Vec<String>,
    /// Whether `target` should be marked active.
    activate: bool,
    /// Hooks to await, in order, before applying the plan.
    hooks: Vec<SharedActivationHook>,
    /// Whether applying the plan alters any flag.
    changed: bool,
}
/// Awaits each hook in order, stopping at (and returning) the first error.
async fn run_activation_hooks(hooks: &[SharedActivationHook], ctx: &ServerCtx) -> Result<()> {
    for run_hook in hooks.iter() {
        let pending = run_hook(ctx);
        pending.await?;
    }
    Ok(())
}
fn validate_group_segment(segment: &str) -> Result<()> {
if segment.is_empty() || segment.contains('.') {
return Err(Error::InvalidConfiguration(format!(
"invalid group segment: {segment}"
)));
}
Ok(())
}
fn validate_group_path(group: &str) -> Result<()> {
if group.is_empty() {
return Err(Error::InvalidConfiguration(
"group name is empty".to_string(),
));
}
for segment in group.split('.') {
validate_group_segment(segment)?;
}
Ok(())
}
/// Splits a tool name at its last `.` into `(group_path, base_name)`.
///
/// Undotted names have no group. The base must be non-empty and any group path must be
/// valid.
fn split_tool_name(name: &str) -> Result<(Option<&str>, &str)> {
    let (group, base) = match name.rsplit_once('.') {
        Some((group, base)) => (Some(group), base),
        None => (None, name),
    };
    if base.is_empty() {
        return Err(Error::InvalidConfiguration(
            "tool name is empty".to_string(),
        ));
    }
    if let Some(path) = group {
        validate_group_path(path)?;
    }
    Ok((group, base))
}
fn parse_cursor_offset(cursor: Cursor) -> Result<usize> {
let Cursor(value) = cursor;
value
.parse::<usize>()
.map_err(|_| Error::InvalidParams("invalid cursor".to_string()))
}
/// Returns `true` if `name` and every ancestor exist in `groups` and are active.
///
/// A cycle in the parent chain is treated as inactive rather than looping forever.
fn is_group_active_snapshot(groups: &HashMap<String, GroupSnapshot>, name: &str) -> bool {
    let mut visited = HashSet::new();
    let mut next = Some(name);
    while let Some(current) = next {
        if !visited.insert(current) {
            return false;
        }
        match groups.get(current) {
            Some(group) if group.active => next = group.parent.as_deref(),
            _ => return false,
        }
    }
    true
}
/// Copies each group's activation flag and parent link out of the registry.
fn snapshot_from_registry(registry: &GroupRegistry) -> HashMap<String, GroupSnapshot> {
    let mut snapshot = HashMap::with_capacity(registry.groups.len());
    for (name, state) in &registry.groups {
        let copy = GroupSnapshot {
            active: state.active,
            parent: state.parent.clone(),
        };
        snapshot.insert(name.clone(), copy);
    }
    snapshot
}
/// Returns every descendant of `root` in post-order (children after their own descendants),
/// followed by `root` itself.
fn collect_descendants_including_self(
    groups: &HashMap<String, GroupState>,
    root: &str,
) -> Vec<String> {
    let mut ordered = Vec::new();
    collect_descendants(groups, root, &mut ordered);
    ordered.push(root.to_owned());
    ordered
}

/// Appends the descendants of `root` to `collected` in post-order; `root` is not included.
fn collect_descendants(
    groups: &HashMap<String, GroupState>,
    root: &str,
    collected: &mut Vec<String>,
) {
    let children = groups
        .iter()
        .filter_map(|(name, state)| (state.parent.as_deref() == Some(root)).then_some(name));
    for child in children {
        collect_descendants(groups, child, collected);
        collected.push(child.clone());
    }
}
/// Builds the schema for an auto-generated activation tool with an empty input schema.
///
/// When `description` is non-empty, the tool's description becomes
/// `"Activate <description>"` or `"Deactivate <description>"`.
fn activation_tool(name: &str, description: &str, activate: bool) -> Tool {
    let tool = Tool::new(name.to_string(), ToolSchema::empty());
    if description.is_empty() {
        return tool;
    }
    let verb = if activate { "Activate" } else { "Deactivate" };
    tool.with_description(format!("{verb} {description}"))
}