use crate::AgentToolDispatcher;
use crate::agent::{ExternalToolNotice, ExternalToolUpdate};
use crate::error::ToolError;
#[cfg(target_arch = "wasm32")]
use crate::tokio;
use crate::types::{ToolCallView, ToolDef, ToolResult};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Shared, thread-safe predicate returning `true` while the gated tools may be used.
pub type AvailabilityCheck = Arc<dyn Fn() -> bool + Send + Sync>;

/// Gate deciding whether something is currently usable.
///
/// The default is [`Availability::Always`].
#[derive(Clone, Default)]
pub enum Availability {
    /// No gating: always reported as available.
    #[default]
    Always,
    /// Available only while `check` returns `true`.
    When {
        /// Predicate re-evaluated on every query; its result is not cached here.
        check: AvailabilityCheck,
        /// Explanation returned by [`Availability::unavailable_reason`] while
        /// `check` reports `false`.
        reason: String,
    },
}

// Written by hand because the boxed closure in `When` has no `Debug` impl;
// only the human-readable reason is printed.
impl std::fmt::Debug for Availability {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Always => f.write_str("Availability::Always"),
            Self::When { reason, .. } => {
                write!(f, "Availability::When {{ reason: {reason:?} }}")
            }
        }
    }
}

impl Availability {
    /// Builds a conditional gate.
    ///
    /// `reason` is the text reported while `check` returns `false`.
    pub fn when(reason: impl Into<String>, check: AvailabilityCheck) -> Self {
        let reason = reason.into();
        Self::When { check, reason }
    }

    /// Returns `true` when the gate is open right now.
    ///
    /// For [`Availability::When`] this invokes the check exactly once.
    pub fn is_available(&self) -> bool {
        self.unavailable_reason().is_none()
    }

    /// Returns the configured reason while the gate is closed, `None` otherwise.
    ///
    /// For [`Availability::When`] this invokes the check exactly once.
    pub fn unavailable_reason(&self) -> Option<&str> {
        match self {
            Self::When { check, reason } if !check() => Some(reason.as_str()),
            Self::Always | Self::When { .. } => None,
        }
    }
}
/// One registered dispatcher paired with the gate that controls it.
struct DispatcherEntry {
/// Dispatcher that receives the calls routed to this entry.
dispatcher: Arc<dyn AgentToolDispatcher>,
/// Gate consulted before this entry's tools are listed or dispatched to.
availability: Availability,
}
/// Combines several tool dispatchers behind a single `AgentToolDispatcher`,
/// routing each call by tool name and hiding tools whose dispatcher is
/// currently unavailable.
pub struct ToolGateway {
/// Every tool of every registered dispatcher, in registration order.
all_tools: Vec<Arc<ToolDef>>,
/// Parallel to `all_tools`: `tool_entry[i]` is the index into `entries` of
/// the dispatcher that contributed `all_tools[i]`.
tool_entry: Vec<usize>,
/// Tool name -> index into `entries` of the dispatcher that handles it.
route: HashMap<String, usize>,
/// Registered dispatchers with their availability gates.
entries: Vec<DispatcherEntry>,
/// Availability snapshot and visible tool list from the last computation.
cache: RwLock<ToolGatewayCache>,
}
/// Result of the most recent visible-tools computation.
#[derive(Debug)]
struct ToolGatewayCache {
/// Availability of each dispatcher entry (by entry index) when the list was built.
entry_available: Vec<bool>,
/// Tools whose owning entry was available at that time.
visible_tools: Arc<[Arc<ToolDef>]>,
}
// Prints only the tool names and the routed names; the dispatchers and the
// cache are omitted, which `finish_non_exhaustive` signals in the output.
impl std::fmt::Debug for ToolGateway {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tool_names: Vec<&str> = self
            .all_tools
            .iter()
            .map(|tool| tool.name.as_str())
            .collect();
        let routed_names: Vec<&String> = self.route.keys().collect();

        let mut out = f.debug_struct("ToolGateway");
        out.field("all_tools", &tool_names);
        out.field("routes", &routed_names);
        out.finish_non_exhaustive()
    }
}
impl ToolGateway {
    /// Builds a gateway over `base` and, when present, `overlay`, both always
    /// available.
    ///
    /// # Errors
    ///
    /// Returns the builder's error when the two dispatchers expose a tool with
    /// the same name.
    pub fn new(
        base: Arc<dyn AgentToolDispatcher>,
        overlay: Option<Arc<dyn AgentToolDispatcher>>,
    ) -> Result<Self, ToolError> {
        ToolGatewayBuilder::new()
            .add_dispatcher(base)
            .maybe_add_dispatcher(overlay)
            .build()
    }
}
/// Collects dispatchers, each with an availability gate, and assembles them
/// into a `ToolGateway`.
pub struct ToolGatewayBuilder {
/// Registered dispatchers and their gates, in registration order.
dispatchers: Vec<(Arc<dyn AgentToolDispatcher>, Availability)>,
}
impl Default for ToolGatewayBuilder {
fn default() -> Self {
Self::new()
}
}
impl ToolGatewayBuilder {
    /// Creates a builder with no dispatchers registered.
    pub fn new() -> Self {
        let dispatchers = Vec::new();
        Self { dispatchers }
    }

    /// Registers `dispatcher` with no gate ([`Availability::Always`]).
    pub fn add_dispatcher(self, dispatcher: Arc<dyn AgentToolDispatcher>) -> Self {
        self.add_dispatcher_with_availability(dispatcher, Availability::Always)
    }

    /// Registers `dispatcher`, gated by `availability`.
    pub fn add_dispatcher_with_availability(
        self,
        dispatcher: Arc<dyn AgentToolDispatcher>,
        availability: Availability,
    ) -> Self {
        let mut dispatchers = self.dispatchers;
        dispatchers.push((dispatcher, availability));
        Self { dispatchers }
    }

    /// Registers `dispatcher` without a gate when it is `Some`; otherwise
    /// returns the builder unchanged.
    pub fn maybe_add_dispatcher(self, dispatcher: Option<Arc<dyn AgentToolDispatcher>>) -> Self {
        self.maybe_add_dispatcher_with_availability(dispatcher, Availability::Always)
    }

    /// Registers `dispatcher` gated by `availability` when it is `Some`;
    /// otherwise returns the builder unchanged.
    pub fn maybe_add_dispatcher_with_availability(
        self,
        dispatcher: Option<Arc<dyn AgentToolDispatcher>>,
        availability: Availability,
    ) -> Self {
        if let Some(present) = dispatcher {
            self.add_dispatcher_with_availability(present, availability)
        } else {
            self
        }
    }

    /// Assembles the gateway.
    ///
    /// Tools are collected from every dispatcher in registration order, and the
    /// initial visible list is computed from each gate's current state.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::Other` when two dispatchers expose the same tool
    /// name. Gates are not consulted for this check, so a currently
    /// unavailable dispatcher still collides.
    pub fn build(self) -> Result<ToolGateway, ToolError> {
        let registered = self.dispatchers;
        let mut route: HashMap<String, usize> = HashMap::new();
        let mut all_tools: Vec<Arc<ToolDef>> = Vec::new();
        let mut tool_entry: Vec<usize> = Vec::new();
        let mut entries: Vec<DispatcherEntry> = Vec::with_capacity(registered.len());

        // Exactly one entry is pushed per iteration, so the enumeration index
        // is the entry's position in `entries`.
        for (entry_idx, (dispatcher, availability)) in registered.into_iter().enumerate() {
            for tool in dispatcher.tools().iter() {
                match route.entry(tool.name.clone()) {
                    std::collections::hash_map::Entry::Occupied(_) => {
                        return Err(ToolError::Other(format!(
                            "tool name collision in gateway: '{}'",
                            tool.name
                        )));
                    }
                    std::collections::hash_map::Entry::Vacant(slot) => {
                        slot.insert(entry_idx);
                    }
                }
                all_tools.push(Arc::clone(tool));
                tool_entry.push(entry_idx);
            }
            entries.push(DispatcherEntry {
                dispatcher,
                availability,
            });
        }

        // Seed the cache with the gates' state at build time.
        let entry_available: Vec<bool> = entries
            .iter()
            .map(|entry| entry.availability.is_available())
            .collect();
        let visible_tools: Arc<[Arc<ToolDef>]> = all_tools
            .iter()
            .zip(&tool_entry)
            .filter(|&(_, &owner)| entry_available[owner])
            .map(|(tool, _)| Arc::clone(tool))
            .collect();

        let cache = RwLock::new(ToolGatewayCache {
            entry_available,
            visible_tools,
        });
        Ok(ToolGateway {
            all_tools,
            tool_entry,
            route,
            entries,
            cache,
        })
    }
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentToolDispatcher for ToolGateway {
fn tools(&self) -> Arc<[Arc<ToolDef>]> {
if let Ok(cache) = self.cache.try_read() {
let changed = self.entries.iter().enumerate().any(|(idx, entry)| {
cache.entry_available[idx] != entry.availability.is_available()
});
if !changed {
return Arc::clone(&cache.visible_tools);
}
}
let entry_available: Vec<bool> = self
.entries
.iter()
.map(|entry| entry.availability.is_available())
.collect();
let mut visible = Vec::with_capacity(self.all_tools.len());
for (tool, &idx) in self.all_tools.iter().zip(self.tool_entry.iter()) {
if entry_available[idx] {
visible.push(Arc::clone(tool));
}
}
let visible_tools: Arc<[Arc<ToolDef>]> = visible.into();
if let Ok(mut cache) = self.cache.try_write() {
cache.entry_available = entry_available;
cache.visible_tools = Arc::clone(&visible_tools);
}
visible_tools
}
async fn dispatch(&self, call: ToolCallView<'_>) -> Result<ToolResult, ToolError> {
let idx = self
.route
.get(call.name)
.ok_or_else(|| ToolError::not_found(call.name))?;
let entry = &self.entries[*idx];
if let Some(reason) = entry.availability.unavailable_reason() {
return Err(ToolError::unavailable(call.name, reason));
}
entry.dispatcher.dispatch(call).await
}
async fn poll_external_updates(&self) -> ExternalToolUpdate {
let mut all_notices: Vec<ExternalToolNotice> = Vec::new();
let mut all_pending: Vec<String> = Vec::new();
let mut seen_pending: std::collections::HashSet<String> = std::collections::HashSet::new();
let mut seen_notices: std::collections::HashSet<(String, String, String)> =
std::collections::HashSet::new();
for entry in &self.entries {
let update = entry.dispatcher.poll_external_updates().await;
for notice in update.notices {
let key = (
notice.server.clone(),
format!("{:?}", notice.operation),
notice.status.clone(),
);
if seen_notices.insert(key) {
all_notices.push(notice);
}
}
for pending in update.pending {
if seen_pending.insert(pending.clone()) {
all_pending.push(pending);
}
}
}
ExternalToolUpdate {
notices: all_notices,
pending: all_pending,
}
}
}
// Unit tests for `Availability`, `ToolGatewayBuilder`, and `ToolGateway`,
// driven by an in-memory mock dispatcher.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
use serde_json::Value;
use serde_json::json;
use std::sync::atomic::{AtomicBool, Ordering};
// Dispatches tool `name` with `args` through `gateway` and parses the result
// content as JSON; a parse failure is mapped to `ToolError::execution_failed`.
async fn dispatch_json(
gateway: &ToolGateway,
name: &str,
args: serde_json::Value,
) -> Result<Value, ToolError> {
let args_raw =
serde_json::value::RawValue::from_string(args.to_string()).expect("valid args json");
let call = ToolCallView {
id: "test-1",
name,
args: &args_raw,
};
let result = gateway.dispatch(call).await?;
serde_json::from_str(&result.content)
.map_err(|e| ToolError::execution_failed(e.to_string()))
}
// Builds the JSON value {"type": "object", "properties": {}, "required": []}.
fn empty_object_schema() -> Value {
let mut obj = serde_json::Map::new();
obj.insert("type".to_string(), Value::String("object".to_string()));
obj.insert(
"properties".to_string(),
Value::Object(serde_json::Map::new()),
);
obj.insert("required".to_string(), Value::Array(Vec::new()));
Value::Object(obj)
}
// Dispatcher with a fixed tool list; a successful dispatch echoes `prefix`
// as "source" and the called tool name as "tool".
struct MockDispatcher {
tools: Arc<[Arc<ToolDef>]>,
prefix: String,
}
impl MockDispatcher {
// Creates one `ToolDef` per entry of `tool_names`, each with an empty object schema.
fn new(prefix: &str, tool_names: &[&str]) -> Self {
let tools: Arc<[Arc<ToolDef>]> = tool_names
.iter()
.map(|name| {
Arc::new(ToolDef {
name: name.to_string(),
description: format!("{prefix} tool: {name}"),
input_schema: empty_object_schema(),
})
})
.collect::<Vec<_>>()
.into();
Self {
tools,
prefix: prefix.to_string(),
}
}
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentToolDispatcher for MockDispatcher {
fn tools(&self) -> Arc<[Arc<ToolDef>]> {
Arc::clone(&self.tools)
}
// Succeeds only for tools this mock owns; anything else is `not_found`.
async fn dispatch(&self, call: ToolCallView<'_>) -> Result<ToolResult, ToolError> {
if self.tools.iter().any(|t| t.name == call.name) {
Ok(ToolResult {
tool_use_id: call.id.to_string(),
content: json!({"source": self.prefix, "tool": call.name}).to_string(),
is_error: false,
})
} else {
Err(ToolError::not_found(call.name))
}
}
}
// Base and overlay tools are both listed by the gateway.
#[test]
fn test_gateway_merges_tools() {
let base = Arc::new(MockDispatcher::new("base", &["task_create", "task_list"]));
let overlay = Arc::new(MockDispatcher::new("comms", &["send", "peers"]));
let gateway = ToolGateway::new(base, Some(overlay)).unwrap();
let tools = gateway.tools();
let tool_names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
assert_eq!(tool_names.len(), 4);
assert!(tool_names.contains(&"task_create"));
assert!(tool_names.contains(&"task_list"));
assert!(tool_names.contains(&"send"));
assert!(tool_names.contains(&"peers"));
}
// Without an overlay only the base tools are listed.
#[test]
fn test_gateway_no_overlay() {
let base = Arc::new(MockDispatcher::new("base", &["task_create", "task_list"]));
let gateway = ToolGateway::new(base, None).unwrap();
assert_eq!(gateway.tools().len(), 2);
}
// A tool name shared by base and overlay makes construction fail.
#[test]
fn test_gateway_collision_error() {
let base = Arc::new(MockDispatcher::new("base", &["task_create", "send"]));
let overlay = Arc::new(MockDispatcher::new("comms", &["send", "peers"]));
let result = ToolGateway::new(base, Some(overlay));
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("send"));
assert!(err.to_string().contains("collision"));
}
// A base tool is dispatched to the base dispatcher.
#[tokio::test]
async fn test_gateway_routes_to_base() {
let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
let overlay = Arc::new(MockDispatcher::new("comms", &["send"]));
let gateway = ToolGateway::new(base, Some(overlay)).unwrap();
let result = dispatch_json(&gateway, "task_create", json!({}))
.await
.unwrap();
assert_eq!(result["source"], "base");
assert_eq!(result["tool"], "task_create");
}
// An overlay tool is dispatched to the overlay dispatcher.
#[tokio::test]
async fn test_gateway_routes_to_overlay() {
let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
let overlay = Arc::new(MockDispatcher::new("comms", &["send"]));
let gateway = ToolGateway::new(base, Some(overlay)).unwrap();
let result = dispatch_json(&gateway, "send", json!({})).await.unwrap();
assert_eq!(result["source"], "comms");
assert_eq!(result["tool"], "send");
}
// An unregistered tool name yields `ToolError::NotFound`.
#[tokio::test]
async fn test_gateway_not_found() {
let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
let gateway = ToolGateway::new(base, None).unwrap();
let result = dispatch_json(&gateway, "unknown_tool", json!({})).await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), ToolError::NotFound { .. }));
}
// The builder accepts more than two dispatchers.
#[test]
fn test_builder_multiple_dispatchers() {
let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
let comms = Arc::new(MockDispatcher::new("comms", &["send"]));
let shell = Arc::new(MockDispatcher::new("shell", &["run_command"]));
let gateway = ToolGatewayBuilder::new()
.add_dispatcher(base)
.add_dispatcher(comms)
.add_dispatcher(shell)
.build()
.unwrap();
assert_eq!(gateway.tools().len(), 3);
}
// `Always` is available and reports no reason.
#[test]
fn test_availability_always() {
let avail = Availability::Always;
assert!(avail.is_available());
assert!(avail.unavailable_reason().is_none());
}
// A passing check is available and reports no reason.
#[test]
fn test_availability_when_true() {
let avail = Availability::when("no peers", Arc::new(|| true));
assert!(avail.is_available());
assert!(avail.unavailable_reason().is_none());
}
// A failing check is unavailable and reports the configured reason.
#[test]
fn test_availability_when_false() {
let avail = Availability::when("no peers configured", Arc::new(|| false));
assert!(!avail.is_available());
assert_eq!(avail.unavailable_reason(), Some("no peers configured"));
}
// The check is re-evaluated on every query, so flipping the flag is seen immediately.
#[test]
fn test_availability_dynamic() {
let flag = Arc::new(AtomicBool::new(false));
let flag_clone = flag.clone();
let avail = Availability::when(
"no peers",
Arc::new(move || flag_clone.load(Ordering::SeqCst)),
);
assert!(!avail.is_available());
flag.store(true, Ordering::SeqCst);
assert!(avail.is_available());
flag.store(false, Ordering::SeqCst);
assert!(!avail.is_available());
}
// The visible tool list follows the gate as it opens and closes.
#[test]
fn test_gateway_conditional_visibility() {
let flag = Arc::new(AtomicBool::new(false));
let flag_clone = flag.clone();
let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
let comms = Arc::new(MockDispatcher::new("comms", &["send"]));
let gateway = ToolGatewayBuilder::new()
.add_dispatcher(base)
.add_dispatcher_with_availability(
comms,
Availability::when(
"no peers",
Arc::new(move || flag_clone.load(Ordering::SeqCst)),
),
)
.build()
.unwrap();
let tools = gateway.tools();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].name, "task_create");
flag.store(true, Ordering::SeqCst);
let tools = gateway.tools();
assert_eq!(tools.len(), 2);
flag.store(false, Ordering::SeqCst);
let tools = gateway.tools();
assert_eq!(tools.len(), 1);
}
// Dispatching to a gated tool fails with `Unavailable` and the reason while
// the gate is closed, and succeeds once it opens.
#[tokio::test]
async fn test_gateway_unavailable_dispatch() {
let flag = Arc::new(AtomicBool::new(false));
let flag_clone = flag.clone();
let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
let comms = Arc::new(MockDispatcher::new("comms", &["send"]));
let gateway = ToolGatewayBuilder::new()
.add_dispatcher(base)
.add_dispatcher_with_availability(
comms,
Availability::when(
"no peers configured",
Arc::new(move || flag_clone.load(Ordering::SeqCst)),
),
)
.build()
.unwrap();
let result = dispatch_json(&gateway, "send", json!({})).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, ToolError::Unavailable { .. }));
assert!(err.to_string().contains("no peers configured"));
flag.store(true, Ordering::SeqCst);
let result = dispatch_json(&gateway, "send", json!({})).await;
assert!(result.is_ok());
}
// A name collision is reported even when the colliding dispatcher's gate is closed.
#[test]
fn test_collision_detection_ignores_availability() {
let flag = Arc::new(AtomicBool::new(false));
let base = Arc::new(MockDispatcher::new("base", &["send"]));
let comms = Arc::new(MockDispatcher::new("comms", &["send"]));
let result = ToolGatewayBuilder::new()
.add_dispatcher(base)
.add_dispatcher_with_availability(
comms,
Availability::when("no peers", Arc::new(move || flag.load(Ordering::SeqCst))),
)
.build();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("collision"));
}
// The hand-written `Debug` output names the variant and includes the reason.
#[test]
fn test_availability_debug() {
let always = Availability::Always;
assert_eq!(format!("{always:?}"), "Availability::Always");
let when = Availability::when("test reason", Arc::new(|| true));
assert!(format!("{when:?}").contains("test reason"));
}
// `maybe_add_dispatcher` is a no-op for `None` and registers for `Some`.
#[test]
fn test_builder_maybe_add() {
let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
let gateway = ToolGatewayBuilder::new()
.add_dispatcher(base.clone())
.maybe_add_dispatcher(None)
.build()
.unwrap();
assert_eq!(gateway.tools().len(), 1);
let overlay = Arc::new(MockDispatcher::new("comms", &["send"]));
let gateway = ToolGatewayBuilder::new()
.add_dispatcher(base)
.maybe_add_dispatcher(Some(overlay))
.build()
.unwrap();
assert_eq!(gateway.tools().len(), 2);
}
// A gate applies to the whole dispatcher: its tools appear and vanish together.
#[test]
fn test_dispatcher_all_or_nothing_visibility() {
let flag = Arc::new(AtomicBool::new(false));
let flag_clone = flag.clone();
let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
let comms = Arc::new(MockDispatcher::new(
"comms",
&["send", "send_request", "send_response", "peers"],
));
let gateway = ToolGatewayBuilder::new()
.add_dispatcher(base)
.add_dispatcher_with_availability(
comms,
Availability::when(
"no peers",
Arc::new(move || flag_clone.load(Ordering::SeqCst)),
),
)
.build()
.unwrap();
let tools = gateway.tools();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].name, "task_create");
flag.store(true, Ordering::SeqCst);
let tools = gateway.tools();
// One base tool plus the four comms tools.
assert_eq!(tools.len(), 5); let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
assert!(names.contains(&"task_create"));
assert!(names.contains(&"send"));
assert!(names.contains(&"send_request"));
assert!(names.contains(&"send_response"));
assert!(names.contains(&"peers"));
flag.store(false, Ordering::SeqCst);
let tools = gateway.tools();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].name, "task_create");
}
}