use crate::AgentToolDispatcher;
use crate::agent::{DetachedOpCompletion, ExternalToolUpdate};
use crate::error::ToolError;
use crate::event::ExternalToolDelta;
#[cfg(target_arch = "wasm32")]
use crate::tokio;
use crate::tool_catalog::{ToolCatalogCapabilities, ToolCatalogEntry};
#[cfg(test)]
use crate::types::ToolResult;
use crate::types::{ToolCallView, ToolDef};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Shared, thread-safe predicate reporting whether a dispatcher may currently
/// be used. It is invoked every time availability is queried.
pub type AvailabilityCheck = Arc<dyn Fn() -> bool + Send + Sync>;

/// Gate deciding whether a dispatcher's tools are visible and callable.
#[derive(Clone, Default)]
pub enum Availability {
    /// Unconditionally available (the default).
    #[default]
    Always,
    /// Available only while `check` returns `true`.
    When {
        /// Predicate evaluated on every availability query.
        check: AvailabilityCheck,
        /// Human-readable explanation reported while `check` returns `false`.
        reason: String,
    },
}

impl std::fmt::Debug for Availability {
    /// Prints the variant and, for `When`, its reason. The check closure is
    /// opaque, so it is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Availability::When { reason, .. } = self else {
            return f.write_str("Availability::Always");
        };
        write!(f, "Availability::When {{ reason: {reason:?} }}")
    }
}

impl Availability {
    /// Builds a conditional gate; `reason` is surfaced whenever `check`
    /// returns `false`.
    pub fn when(reason: impl Into<String>, check: AvailabilityCheck) -> Self {
        let reason = reason.into();
        Self::When { check, reason }
    }

    /// Evaluates the gate right now (`When` calls its closure each time).
    pub fn is_available(&self) -> bool {
        match self {
            Self::When { check, .. } => check(),
            Self::Always => true,
        }
    }

    /// Returns the configured reason while the gate is closed, otherwise
    /// `None`. The check closure is invoked at most once.
    pub fn unavailable_reason(&self) -> Option<&str> {
        let Self::When { check, reason } = self else {
            return None;
        };
        (!check()).then_some(reason.as_str())
    }
}
/// A registered dispatcher paired with the gate that decides whether its
/// tools are currently visible and callable through the gateway.
struct DispatcherEntry {
    /// The underlying dispatcher that actually executes the calls.
    dispatcher: Arc<dyn AgentToolDispatcher>,
    /// Gate checked both when listing tools and when dispatching.
    availability: Availability,
}
/// Routes tool calls to one of several dispatchers by tool name, hiding and
/// refusing the tools of any dispatcher whose [`Availability`] gate is closed.
pub struct ToolGateway {
    /// Every tool from every dispatcher, in registration order. Indexed in
    /// parallel with `catalog_entries` and `tool_entry`.
    all_tools: Vec<Arc<ToolDef>>,
    /// Catalog entry (captured at build time) for the tool at the same index.
    catalog_entries: Vec<ToolCatalogEntry>,
    /// Index into `entries` of the dispatcher owning the tool at this index.
    tool_entry: Vec<usize>,
    /// Tool name -> index into the three parallel vectors above.
    route: HashMap<String, usize>,
    /// Registered dispatchers with their availability gates.
    entries: Vec<DispatcherEntry>,
    /// Last computed visibility snapshot, reused by `tools()`.
    cache: RwLock<ToolGatewayCache>,
}
/// Snapshot of the most recent visibility computation. It is seeded in
/// `ToolGatewayBuilder::build` and refreshed by `tools()`.
#[derive(Debug)]
struct ToolGatewayCache {
    /// Availability of each dispatcher entry when `visible_tools` was built.
    entry_available: Vec<bool>,
    /// Tools that were visible under `entry_available`.
    visible_tools: Arc<[Arc<ToolDef>]>,
}
impl std::fmt::Debug for ToolGateway {
    /// Shows only tool names and route names; every other field is elided
    /// (rendered as `..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tool_names: Vec<&str> = self
            .all_tools
            .iter()
            .map(|tool| tool.name.as_str())
            .collect();
        let route_names: Vec<&String> = self.route.keys().collect();
        f.debug_struct("ToolGateway")
            .field("all_tools", &tool_names)
            .field("routes", &route_names)
            .finish_non_exhaustive()
    }
}
impl ToolGateway {
    /// Convenience constructor for the common "base plus optional overlay"
    /// setup. Both dispatchers are always available; `base` tools come first.
    ///
    /// # Errors
    /// Fails if any tool name appears more than once across `base` and
    /// `overlay` (see [`ToolGatewayBuilder::build`]).
    pub fn new(
        base: Arc<dyn AgentToolDispatcher>,
        overlay: Option<Arc<dyn AgentToolDispatcher>>,
    ) -> Result<Self, ToolError> {
        ToolGatewayBuilder::new()
            .add_dispatcher(base)
            .maybe_add_dispatcher(overlay)
            .build()
    }
}
/// Collects dispatchers, each paired with an [`Availability`] gate, and
/// assembles them into a [`ToolGateway`].
pub struct ToolGatewayBuilder {
    /// Registered dispatchers in insertion order.
    dispatchers: Vec<(Arc<dyn AgentToolDispatcher>, Availability)>,
}
impl Default for ToolGatewayBuilder {
fn default() -> Self {
Self::new()
}
}
impl ToolGatewayBuilder {
    /// Creates a builder with no dispatchers registered yet.
    pub fn new() -> Self {
        Self {
            dispatchers: Vec::new(),
        }
    }

    /// Registers `dispatcher` so that it is always available.
    pub fn add_dispatcher(self, dispatcher: Arc<dyn AgentToolDispatcher>) -> Self {
        self.add_dispatcher_with_availability(dispatcher, Availability::Always)
    }

    /// Registers `dispatcher` behind `availability`. Registration order is
    /// preserved and determines the order of the gateway's tool list.
    pub fn add_dispatcher_with_availability(
        mut self,
        dispatcher: Arc<dyn AgentToolDispatcher>,
        availability: Availability,
    ) -> Self {
        self.dispatchers.push((dispatcher, availability));
        self
    }

    /// Like [`Self::add_dispatcher`], but `None` leaves the builder unchanged.
    pub fn maybe_add_dispatcher(self, dispatcher: Option<Arc<dyn AgentToolDispatcher>>) -> Self {
        if let Some(d) = dispatcher {
            self.add_dispatcher(d)
        } else {
            self
        }
    }

    /// Like [`Self::add_dispatcher_with_availability`], but `None` leaves the
    /// builder unchanged.
    pub fn maybe_add_dispatcher_with_availability(
        self,
        dispatcher: Option<Arc<dyn AgentToolDispatcher>>,
        availability: Availability,
    ) -> Self {
        if let Some(d) = dispatcher {
            self.add_dispatcher_with_availability(d, availability)
        } else {
            self
        }
    }

    /// Consumes the builder and produces a [`ToolGateway`].
    ///
    /// Dispatchers reporting an exact catalog contribute their full catalog,
    /// including entries that are not currently callable. Other dispatchers
    /// contribute a snapshot of `tools()`, with every entry marked callable.
    ///
    /// # Errors
    /// Returns [`ToolError::Other`] when two catalog entries, from the same
    /// or different dispatchers, share a tool name. Availability is ignored
    /// here, so a currently-hidden dispatcher still collides.
    pub fn build(self) -> Result<ToolGateway, ToolError> {
        use std::collections::hash_map::Entry;

        let mut route: HashMap<String, usize> = HashMap::new();
        let mut all_tools: Vec<Arc<ToolDef>> = Vec::new();
        let mut catalog_entries: Vec<ToolCatalogEntry> = Vec::new();
        let mut tool_entry: Vec<usize> = Vec::new();
        let mut entries: Vec<DispatcherEntry> = Vec::with_capacity(self.dispatchers.len());

        for (dispatcher, availability) in self.dispatchers {
            let owner = entries.len();
            let exact = dispatcher.tool_catalog_capabilities().exact_catalog;
            let catalog: Vec<ToolCatalogEntry> = if exact {
                dispatcher.tool_catalog().to_vec()
            } else {
                dispatcher
                    .tools()
                    .iter()
                    .map(|tool| ToolCatalogEntry::session_inline(Arc::clone(tool), true))
                    .collect()
            };

            for catalog_entry in catalog {
                let tool_idx = all_tools.len();
                // Single hash lookup: claim the name or report the collision.
                match route.entry(catalog_entry.tool.name.clone()) {
                    Entry::Occupied(_) => {
                        return Err(ToolError::Other(format!(
                            "tool name collision in gateway: '{}'",
                            catalog_entry.tool.name
                        )));
                    }
                    Entry::Vacant(slot) => {
                        slot.insert(tool_idx);
                    }
                }
                all_tools.push(Arc::clone(&catalog_entry.tool));
                catalog_entries.push(catalog_entry);
                tool_entry.push(owner);
            }

            entries.push(DispatcherEntry {
                dispatcher,
                availability,
            });
        }

        // Seed the visibility cache with the availability observed right now.
        let entry_available: Vec<bool> = entries
            .iter()
            .map(|entry| entry.availability.is_available())
            .collect();
        let visible_tools: Arc<[Arc<ToolDef>]> = all_tools
            .iter()
            .zip(&catalog_entries)
            .zip(&tool_entry)
            .filter(|((_, catalog_entry), owner)| {
                entry_available[**owner] && catalog_entry.currently_callable
            })
            .map(|((tool, _), _)| Arc::clone(tool))
            .collect();

        Ok(ToolGateway {
            all_tools,
            catalog_entries,
            tool_entry,
            route,
            entries,
            cache: RwLock::new(ToolGatewayCache {
                entry_available,
                visible_tools,
            }),
        })
    }
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentToolDispatcher for ToolGateway {
    /// Tools visible right now: those whose owning dispatcher is available
    /// and whose catalog entry is marked currently callable.
    ///
    /// Each availability gate is evaluated exactly once per call. The
    /// resulting snapshot is compared against the cache, so the cached list
    /// is reused only when it was computed from identical availability, and a
    /// recomputed list always matches the snapshot it was derived from.
    fn tools(&self) -> Arc<[Arc<ToolDef>]> {
        let entry_available: Vec<bool> = self
            .entries
            .iter()
            .map(|entry| entry.availability.is_available())
            .collect();
        // Non-blocking fast path: under lock contention, recompute instead of
        // waiting.
        if let Some(cache) = self
            .cache
            .try_read()
            .ok()
            .filter(|cache| cache.entry_available == entry_available)
        {
            return Arc::clone(&cache.visible_tools);
        }
        let mut visible = Vec::with_capacity(self.all_tools.len());
        for ((tool, entry), &idx) in self
            .all_tools
            .iter()
            .zip(self.catalog_entries.iter())
            .zip(self.tool_entry.iter())
        {
            if entry_available[idx] && entry.currently_callable {
                visible.push(Arc::clone(tool));
            }
        }
        let visible_tools: Arc<[Arc<ToolDef>]> = visible.into();
        // Best-effort refresh: skipped if another caller holds the lock.
        if let Ok(mut cache) = self.cache.try_write() {
            cache.entry_available = entry_available;
            cache.visible_tools = Arc::clone(&visible_tools);
        }
        visible_tools
    }

    /// Routes `call` to the dispatcher that owns the tool name.
    ///
    /// # Errors
    /// - `NotFound` if no dispatcher registered the name.
    /// - `Unavailable` with the gate's reason if the owning dispatcher's gate
    ///   is closed.
    /// - `Unavailable` if the build-time catalog entry is not currently
    ///   callable.
    ///
    /// Otherwise returns whatever the owning dispatcher returns.
    async fn dispatch(
        &self,
        call: ToolCallView<'_>,
    ) -> Result<crate::ops::ToolDispatchOutcome, ToolError> {
        let tool_idx = self
            .route
            .get(call.name)
            .ok_or_else(|| ToolError::not_found(call.name))?;
        let entry = &self.entries[self.tool_entry[*tool_idx]];
        if let Some(reason) = entry.availability.unavailable_reason() {
            return Err(ToolError::unavailable(call.name, reason));
        }
        if !self.catalog_entries[*tool_idx].currently_callable {
            return Err(ToolError::unavailable(
                call.name,
                "tool is not currently callable",
            ));
        }
        entry.dispatcher.dispatch(call).await
    }

    /// Exact only if every child is exact; may need the catalog control plane
    /// if any child may.
    fn tool_catalog_capabilities(&self) -> ToolCatalogCapabilities {
        ToolCatalogCapabilities {
            exact_catalog: self
                .entries
                .iter()
                .all(|entry| entry.dispatcher.tool_catalog_capabilities().exact_catalog),
            may_require_catalog_control_plane: self.entries.iter().any(|entry| {
                entry
                    .dispatcher
                    .tool_catalog_capabilities()
                    .may_require_catalog_control_plane
            }),
        }
    }

    /// Sorted, de-duplicated union of every child's pending catalog sources.
    fn pending_catalog_sources(&self) -> Arc<[String]> {
        let mut pending = std::collections::BTreeSet::new();
        for entry in &self.entries {
            let sources = entry.dispatcher.pending_catalog_sources();
            pending.extend(sources.iter().cloned());
        }
        pending.into_iter().collect::<Vec<_>>().into()
    }

    /// Every build-time catalog entry, including hidden ones. An entry's
    /// `currently_callable` flag is additionally cleared when its owning
    /// dispatcher is unavailable right now.
    fn tool_catalog(&self) -> Arc<[ToolCatalogEntry]> {
        let entry_available: Vec<bool> = self
            .entries
            .iter()
            .map(|entry| entry.availability.is_available())
            .collect();
        self.catalog_entries
            .iter()
            .zip(self.tool_entry.iter())
            .map(|(entry, entry_idx)| {
                let mut entry = entry.clone();
                entry.currently_callable &= entry_available[*entry_idx];
                entry
            })
            .collect::<Vec<_>>()
            .into()
    }

    /// ORs the `ops_lifecycle` capability across all children.
    fn capabilities(&self) -> crate::agent::DispatcherCapabilities {
        let mut caps = crate::agent::DispatcherCapabilities::default();
        for entry in &self.entries {
            let c = entry.dispatcher.capabilities();
            caps.ops_lifecycle |= c.ops_lifecycle;
        }
        caps
    }

    /// Rebuilds the gateway with every lifecycle-capable child bound to
    /// `registry`.
    ///
    /// - The gateway must be solely owned; otherwise `SharedOwnership` is
    ///   returned.
    /// - Only children that advertise `ops_lifecycle` and are not shared are
    ///   rebound; the others are carried over unchanged.
    /// - Availability gates are preserved.
    /// - Returns `Bound` if at least one child reported a binding, otherwise
    ///   `Skipped`.
    fn bind_ops_lifecycle(
        self: Arc<Self>,
        registry: Arc<dyn crate::ops_lifecycle::OpsLifecycleRegistry>,
        owner_session_id: crate::types::SessionId,
    ) -> Result<crate::agent::BindOutcome, crate::agent::OpsLifecycleBindError> {
        let owned = Arc::try_unwrap(self)
            .map_err(|_| crate::agent::OpsLifecycleBindError::SharedOwnership)?;
        let mut builder = ToolGatewayBuilder::new();
        let mut any_bound = false;
        for entry in owned.entries {
            if entry.dispatcher.capabilities().ops_lifecycle
                && Arc::strong_count(&entry.dispatcher) == 1
            {
                let outcome = entry
                    .dispatcher
                    .bind_ops_lifecycle(Arc::clone(&registry), owner_session_id.clone())?;
                if outcome.was_bound() {
                    any_bound = true;
                }
                builder = builder.add_dispatcher_with_availability(
                    outcome.into_dispatcher(),
                    entry.availability,
                );
            } else {
                builder =
                    builder.add_dispatcher_with_availability(entry.dispatcher, entry.availability);
            }
        }
        let gateway = builder
            .build()
            .map_err(|_| crate::agent::OpsLifecycleBindError::Unsupported)?;
        let d: Arc<dyn AgentToolDispatcher> = Arc::new(gateway);
        Ok(if any_bound {
            crate::agent::BindOutcome::Bound(d)
        } else {
            crate::agent::BindOutcome::Skipped(d)
        })
    }

    /// First completion-enrichment provider offered by any child, in
    /// registration order.
    fn completion_enrichment(
        &self,
    ) -> Option<Arc<dyn crate::completion_feed::CompletionEnrichmentProvider>> {
        self.entries
            .iter()
            .find_map(|e| e.dispatcher.completion_enrichment())
    }

    /// Merges every child's external updates in registration order and drops
    /// duplicates:
    /// - notices, keyed by target, operation, status text, persisted flag
    ///   and applied turn;
    /// - pending strings;
    /// - background completions, keyed by `job_id`.
    async fn poll_external_updates(&self) -> ExternalToolUpdate {
        let mut all_notices: Vec<ExternalToolDelta> = Vec::new();
        let mut all_pending: Vec<String> = Vec::new();
        let mut seen_pending: std::collections::HashSet<String> = std::collections::HashSet::new();
        let mut seen_notices: std::collections::HashSet<(
            String,
            String,
            String,
            bool,
            Option<u32>,
        )> = std::collections::HashSet::new();
        let mut seen_bg_job_ids: std::collections::HashSet<String> =
            std::collections::HashSet::new();
        let mut all_bg_completions: Vec<DetachedOpCompletion> = Vec::new();
        for entry in &self.entries {
            let update = entry.dispatcher.poll_external_updates().await;
            for notice in update.notices {
                let key = (
                    notice.target.clone(),
                    format!("{:?}", notice.operation),
                    notice.status_text(),
                    notice.persisted,
                    notice.applied_at_turn,
                );
                if seen_notices.insert(key) {
                    all_notices.push(notice);
                }
            }
            for pending in update.pending {
                if seen_pending.insert(pending.clone()) {
                    all_pending.push(pending);
                }
            }
            for bg in update.background_completions {
                if seen_bg_job_ids.insert(bg.job_id.clone()) {
                    all_bg_completions.push(bg);
                }
            }
        }
        ExternalToolUpdate {
            notices: all_notices,
            pending: all_pending,
            background_completions: all_bg_completions,
        }
    }
}
/// Dispatcher that fans out over child dispatchers, re-querying them on every
/// call. Tool-name overlaps are resolved in favour of the earliest child
/// rather than rejected.
pub struct DynamicToolComposite {
    /// Children in priority order (earliest wins on name overlap).
    dispatchers: Vec<Arc<dyn AgentToolDispatcher>>,
}
impl DynamicToolComposite {
pub fn new(dispatchers: Vec<Arc<dyn AgentToolDispatcher>>) -> Self {
Self { dispatchers }
}
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentToolDispatcher for DynamicToolComposite {
fn tools(&self) -> Arc<[Arc<ToolDef>]> {
if self.tool_catalog_capabilities().exact_catalog {
return self
.tool_catalog()
.iter()
.filter(|entry| entry.currently_callable)
.map(|entry| Arc::clone(&entry.tool))
.collect::<Vec<_>>()
.into();
}
let mut seen = std::collections::HashSet::new();
let mut result = Vec::new();
for d in &self.dispatchers {
for t in d.tools().iter() {
if seen.insert(t.name.clone()) {
result.push(Arc::clone(t));
}
}
}
result.into()
}
async fn dispatch(
&self,
call: crate::types::ToolCallView<'_>,
) -> Result<crate::ops::ToolDispatchOutcome, crate::error::ToolError> {
if self.tool_catalog_capabilities().exact_catalog {
for d in &self.dispatchers {
if let Some(entry) = d
.tool_catalog()
.iter()
.find(|entry| entry.tool.name == call.name)
{
if !entry.currently_callable {
return Err(crate::error::ToolError::unavailable(
call.name,
"tool is not currently callable",
));
}
return d.dispatch(call).await;
}
}
return Err(crate::error::ToolError::not_found(call.name));
}
for d in &self.dispatchers {
if d.tools().iter().any(|t| t.name == call.name) {
return d.dispatch(call).await;
}
}
Err(crate::error::ToolError::not_found(call.name))
}
async fn poll_external_updates(&self) -> ExternalToolUpdate {
let mut all_notices = Vec::new();
let mut all_pending = Vec::new();
for d in &self.dispatchers {
let update = d.poll_external_updates().await;
all_notices.extend(update.notices);
all_pending.extend(update.pending);
}
ExternalToolUpdate {
notices: all_notices,
pending: all_pending,
background_completions: Vec::new(),
}
}
fn capabilities(&self) -> crate::agent::DispatcherCapabilities {
let mut caps = crate::agent::DispatcherCapabilities::default();
for d in &self.dispatchers {
let c = d.capabilities();
caps.ops_lifecycle |= c.ops_lifecycle;
}
caps
}
fn completion_enrichment(
&self,
) -> Option<Arc<dyn crate::completion_feed::CompletionEnrichmentProvider>> {
self.dispatchers
.iter()
.find_map(|d| d.completion_enrichment())
}
fn tool_catalog_capabilities(&self) -> ToolCatalogCapabilities {
ToolCatalogCapabilities {
exact_catalog: self
.dispatchers
.iter()
.all(|dispatcher| dispatcher.tool_catalog_capabilities().exact_catalog),
may_require_catalog_control_plane: self.dispatchers.iter().any(|dispatcher| {
dispatcher
.tool_catalog_capabilities()
.may_require_catalog_control_plane
}),
}
}
fn pending_catalog_sources(&self) -> Arc<[String]> {
let mut pending = std::collections::BTreeSet::new();
for dispatcher in &self.dispatchers {
let sources = dispatcher.pending_catalog_sources();
pending.extend(sources.iter().cloned());
}
pending.into_iter().collect::<Vec<_>>().into()
}
fn tool_catalog(&self) -> Arc<[ToolCatalogEntry]> {
if !self.tool_catalog_capabilities().exact_catalog {
return self
.tools()
.iter()
.map(|tool| ToolCatalogEntry::session_inline(Arc::clone(tool), true))
.collect::<Vec<_>>()
.into();
}
let mut seen = std::collections::HashSet::new();
let mut result = Vec::new();
for dispatcher in &self.dispatchers {
for entry in dispatcher.tool_catalog().iter() {
if seen.insert(entry.tool.name.clone()) {
result.push(entry.clone());
}
}
}
result.into()
}
fn bind_ops_lifecycle(
self: Arc<Self>,
registry: Arc<dyn crate::ops_lifecycle::OpsLifecycleRegistry>,
owner_session_id: crate::types::SessionId,
) -> Result<crate::agent::BindOutcome, crate::agent::OpsLifecycleBindError> {
let owned = Arc::try_unwrap(self)
.map_err(|_| crate::agent::OpsLifecycleBindError::SharedOwnership)?;
let mut rebound = Vec::with_capacity(owned.dispatchers.len());
let mut any_bound = false;
for d in owned.dispatchers {
if d.capabilities().ops_lifecycle && Arc::strong_count(&d) == 1 {
let outcome =
d.bind_ops_lifecycle(Arc::clone(®istry), owner_session_id.clone())?;
if outcome.was_bound() {
any_bound = true;
}
rebound.push(outcome.into_dispatcher());
} else {
rebound.push(d);
}
}
let d: Arc<dyn AgentToolDispatcher> = Arc::new(DynamicToolComposite::new(rebound));
Ok(if any_bound {
crate::agent::BindOutcome::Bound(d)
} else {
crate::agent::BindOutcome::Skipped(d)
})
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use serde_json::Value;
    use serde_json::json;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Dispatches `name` through `gateway` and parses the textual tool result
    /// back into JSON.
    async fn dispatch_json(
        gateway: &ToolGateway,
        name: &str,
        args: serde_json::Value,
    ) -> Result<Value, ToolError> {
        let args_raw =
            serde_json::value::RawValue::from_string(args.to_string()).expect("valid args json");
        let call = ToolCallView {
            id: "test-1",
            name,
            args: &args_raw,
        };
        let outcome = gateway.dispatch(call).await?;
        serde_json::from_str(&outcome.result.text_content())
            .map_err(|e| ToolError::execution_failed(e.to_string()))
    }

    /// JSON schema for an object with no properties.
    fn empty_object_schema() -> Value {
        let mut obj = serde_json::Map::new();
        obj.insert("type".to_string(), Value::String("object".to_string()));
        obj.insert(
            "properties".to_string(),
            Value::Object(serde_json::Map::new()),
        );
        obj.insert("required".to_string(), Value::Array(Vec::new()));
        Value::Object(obj)
    }

    /// Non-exact dispatcher that echoes `{source, tool}` for any tool it owns.
    struct MockDispatcher {
        tools: Arc<[Arc<ToolDef>]>,
        prefix: String,
    }

    impl MockDispatcher {
        fn new(prefix: &str, tool_names: &[&str]) -> Self {
            let tools: Arc<[Arc<ToolDef>]> = tool_names
                .iter()
                .map(|name| {
                    Arc::new(ToolDef {
                        name: name.to_string(),
                        description: format!("{prefix} tool: {name}"),
                        input_schema: empty_object_schema(),
                        provenance: None,
                    })
                })
                .collect::<Vec<_>>()
                .into();
            Self {
                tools,
                prefix: prefix.to_string(),
            }
        }
    }

    /// Exact-catalog dispatcher whose entries carry explicit callability.
    struct ExactMockDispatcher {
        tools: Arc<[Arc<ToolDef>]>,
        catalog: Arc<[crate::ToolCatalogEntry]>,
        prefix: String,
    }

    impl ExactMockDispatcher {
        fn with_callability(prefix: &str, entries: &[(&str, bool)]) -> Self {
            let catalog: Vec<crate::ToolCatalogEntry> = entries
                .iter()
                .map(|(name, currently_callable)| {
                    crate::ToolCatalogEntry::session_inline(
                        Arc::new(ToolDef {
                            name: (*name).to_string(),
                            description: format!("{prefix} tool: {name}"),
                            input_schema: empty_object_schema(),
                            provenance: None,
                        }),
                        *currently_callable,
                    )
                })
                .collect();
            let tools: Arc<[Arc<ToolDef>]> = catalog
                .iter()
                .filter(|entry| entry.currently_callable)
                .map(|entry| Arc::clone(&entry.tool))
                .collect::<Vec<_>>()
                .into();
            Self {
                tools,
                catalog: catalog.into(),
                prefix: prefix.to_string(),
            }
        }
    }

    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    impl AgentToolDispatcher for ExactMockDispatcher {
        fn tools(&self) -> Arc<[Arc<ToolDef>]> {
            Arc::clone(&self.tools)
        }

        fn tool_catalog_capabilities(&self) -> crate::ToolCatalogCapabilities {
            crate::ToolCatalogCapabilities {
                exact_catalog: true,
                may_require_catalog_control_plane: false,
            }
        }

        fn tool_catalog(&self) -> Arc<[crate::ToolCatalogEntry]> {
            Arc::clone(&self.catalog)
        }

        async fn dispatch(
            &self,
            call: ToolCallView<'_>,
        ) -> Result<crate::ops::ToolDispatchOutcome, ToolError> {
            let Some(entry) = self
                .catalog
                .iter()
                .find(|entry| entry.tool.name == call.name)
            else {
                return Err(ToolError::not_found(call.name));
            };
            if !entry.currently_callable {
                return Err(ToolError::unavailable(
                    call.name,
                    "tool is not currently callable",
                ));
            }
            Ok(ToolResult::new(
                call.id.to_string(),
                json!({"source": self.prefix, "tool": call.name}).to_string(),
                false,
            )
            .into())
        }
    }

    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    impl AgentToolDispatcher for MockDispatcher {
        fn tools(&self) -> Arc<[Arc<ToolDef>]> {
            Arc::clone(&self.tools)
        }

        async fn dispatch(
            &self,
            call: ToolCallView<'_>,
        ) -> Result<crate::ops::ToolDispatchOutcome, ToolError> {
            if self.tools.iter().any(|t| t.name == call.name) {
                Ok(ToolResult::new(
                    call.id.to_string(),
                    json!({"source": self.prefix, "tool": call.name}).to_string(),
                    false,
                )
                .into())
            } else {
                Err(ToolError::not_found(call.name))
            }
        }
    }

    #[test]
    fn test_gateway_merges_tools() {
        let base = Arc::new(MockDispatcher::new("base", &["task_create", "task_list"]));
        let overlay = Arc::new(MockDispatcher::new("comms", &["send", "peers"]));
        let gateway = ToolGateway::new(base, Some(overlay)).unwrap();
        let tools = gateway.tools();
        let tool_names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(tool_names.len(), 4);
        assert!(tool_names.contains(&"task_create"));
        assert!(tool_names.contains(&"task_list"));
        assert!(tool_names.contains(&"send"));
        assert!(tool_names.contains(&"peers"));
    }

    #[test]
    fn test_gateway_no_overlay() {
        let base = Arc::new(MockDispatcher::new("base", &["task_create", "task_list"]));
        let gateway = ToolGateway::new(base, None).unwrap();
        assert_eq!(gateway.tools().len(), 2);
    }

    #[test]
    fn test_gateway_collision_error() {
        let base = Arc::new(MockDispatcher::new("base", &["task_create", "send"]));
        let overlay = Arc::new(MockDispatcher::new("comms", &["send", "peers"]));
        let result = ToolGateway::new(base, Some(overlay));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("send"));
        assert!(err.to_string().contains("collision"));
    }

    #[tokio::test]
    async fn test_gateway_routes_to_base() {
        let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
        let overlay = Arc::new(MockDispatcher::new("comms", &["send"]));
        let gateway = ToolGateway::new(base, Some(overlay)).unwrap();
        let result = dispatch_json(&gateway, "task_create", json!({}))
            .await
            .unwrap();
        assert_eq!(result["source"], "base");
        assert_eq!(result["tool"], "task_create");
    }

    #[tokio::test]
    async fn test_gateway_routes_to_overlay() {
        let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
        let overlay = Arc::new(MockDispatcher::new("comms", &["send"]));
        let gateway = ToolGateway::new(base, Some(overlay)).unwrap();
        let result = dispatch_json(&gateway, "send", json!({})).await.unwrap();
        assert_eq!(result["source"], "comms");
        assert_eq!(result["tool"], "send");
    }

    #[tokio::test]
    async fn test_gateway_not_found() {
        let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
        let gateway = ToolGateway::new(base, None).unwrap();
        let result = dispatch_json(&gateway, "unknown_tool", json!({})).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolError::NotFound { .. }));
    }

    #[test]
    fn test_builder_multiple_dispatchers() {
        let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
        let comms = Arc::new(MockDispatcher::new("comms", &["send"]));
        let shell = Arc::new(MockDispatcher::new("shell", &["run_command"]));
        let gateway = ToolGatewayBuilder::new()
            .add_dispatcher(base)
            .add_dispatcher(comms)
            .add_dispatcher(shell)
            .build()
            .unwrap();
        assert_eq!(gateway.tools().len(), 3);
    }

    #[test]
    fn test_availability_always() {
        let avail = Availability::Always;
        assert!(avail.is_available());
        assert!(avail.unavailable_reason().is_none());
    }

    #[test]
    fn test_availability_when_true() {
        let avail = Availability::when("no peers", Arc::new(|| true));
        assert!(avail.is_available());
        assert!(avail.unavailable_reason().is_none());
    }

    #[test]
    fn test_availability_when_false() {
        let avail = Availability::when("no peers configured", Arc::new(|| false));
        assert!(!avail.is_available());
        assert_eq!(avail.unavailable_reason(), Some("no peers configured"));
    }

    #[test]
    fn test_availability_dynamic() {
        let flag = Arc::new(AtomicBool::new(false));
        let flag_clone = flag.clone();
        let avail = Availability::when(
            "no peers",
            Arc::new(move || flag_clone.load(Ordering::SeqCst)),
        );
        assert!(!avail.is_available());
        flag.store(true, Ordering::SeqCst);
        assert!(avail.is_available());
        flag.store(false, Ordering::SeqCst);
        assert!(!avail.is_available());
    }

    #[test]
    fn test_gateway_conditional_visibility() {
        let flag = Arc::new(AtomicBool::new(false));
        let flag_clone = flag.clone();
        let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
        let comms = Arc::new(MockDispatcher::new("comms", &["send"]));
        let gateway = ToolGatewayBuilder::new()
            .add_dispatcher(base)
            .add_dispatcher_with_availability(
                comms,
                Availability::when(
                    "no peers",
                    Arc::new(move || flag_clone.load(Ordering::SeqCst)),
                ),
            )
            .build()
            .unwrap();
        let tools = gateway.tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "task_create");
        flag.store(true, Ordering::SeqCst);
        let tools = gateway.tools();
        assert_eq!(tools.len(), 2);
        flag.store(false, Ordering::SeqCst);
        let tools = gateway.tools();
        assert_eq!(tools.len(), 1);
    }

    #[tokio::test]
    async fn test_gateway_unavailable_dispatch() {
        let flag = Arc::new(AtomicBool::new(false));
        let flag_clone = flag.clone();
        let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
        let comms = Arc::new(MockDispatcher::new("comms", &["send"]));
        let gateway = ToolGatewayBuilder::new()
            .add_dispatcher(base)
            .add_dispatcher_with_availability(
                comms,
                Availability::when(
                    "no peers configured",
                    Arc::new(move || flag_clone.load(Ordering::SeqCst)),
                ),
            )
            .build()
            .unwrap();
        let result = dispatch_json(&gateway, "send", json!({})).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, ToolError::Unavailable { .. }));
        assert!(err.to_string().contains("no peers configured"));
        flag.store(true, Ordering::SeqCst);
        let result = dispatch_json(&gateway, "send", json!({})).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_collision_detection_ignores_availability() {
        let flag = Arc::new(AtomicBool::new(false));
        let base = Arc::new(MockDispatcher::new("base", &["send"]));
        let comms = Arc::new(MockDispatcher::new("comms", &["send"]));
        let result = ToolGatewayBuilder::new()
            .add_dispatcher(base)
            .add_dispatcher_with_availability(
                comms,
                Availability::when("no peers", Arc::new(move || flag.load(Ordering::SeqCst))),
            )
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("collision"));
    }

    #[test]
    fn test_availability_debug() {
        let always = Availability::Always;
        assert_eq!(format!("{always:?}"), "Availability::Always");
        let when = Availability::when("test reason", Arc::new(|| true));
        assert!(format!("{when:?}").contains("test reason"));
    }

    #[test]
    fn test_builder_maybe_add() {
        let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
        let gateway = ToolGatewayBuilder::new()
            .add_dispatcher(base.clone())
            .maybe_add_dispatcher(None)
            .build()
            .unwrap();
        assert_eq!(gateway.tools().len(), 1);
        let overlay = Arc::new(MockDispatcher::new("comms", &["send"]));
        let gateway = ToolGatewayBuilder::new()
            .add_dispatcher(base)
            .maybe_add_dispatcher(Some(overlay))
            .build()
            .unwrap();
        assert_eq!(gateway.tools().len(), 2);
    }

    #[test]
    fn test_dispatcher_all_or_nothing_visibility() {
        let flag = Arc::new(AtomicBool::new(false));
        let flag_clone = flag.clone();
        let base = Arc::new(MockDispatcher::new("base", &["task_create"]));
        let comms = Arc::new(MockDispatcher::new(
            "comms",
            &["send", "send_request", "send_response", "peers"],
        ));
        let gateway = ToolGatewayBuilder::new()
            .add_dispatcher(base)
            .add_dispatcher_with_availability(
                comms,
                Availability::when(
                    "no peers",
                    Arc::new(move || flag_clone.load(Ordering::SeqCst)),
                ),
            )
            .build()
            .unwrap();
        let tools = gateway.tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "task_create");
        flag.store(true, Ordering::SeqCst);
        let tools = gateway.tools();
        assert_eq!(tools.len(), 5);
        let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"task_create"));
        assert!(names.contains(&"send"));
        assert!(names.contains(&"send_request"));
        assert!(names.contains(&"send_response"));
        assert!(names.contains(&"peers"));
        flag.store(false, Ordering::SeqCst);
        let tools = gateway.tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "task_create");
    }

    /// Tool-less dispatcher that always reports the same external update.
    struct MockBgDispatcher {
        update: ExternalToolUpdate,
    }

    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    impl AgentToolDispatcher for MockBgDispatcher {
        fn tools(&self) -> Arc<[Arc<ToolDef>]> {
            Arc::new([])
        }

        async fn dispatch(
            &self,
            _call: ToolCallView<'_>,
        ) -> Result<crate::ops::ToolDispatchOutcome, ToolError> {
            Err(ToolError::not_found(""))
        }

        async fn poll_external_updates(&self) -> ExternalToolUpdate {
            self.update.clone()
        }
    }

    #[tokio::test]
    async fn choke_003_gateway_dedups_background_completions_by_job_id() {
        use crate::agent::DetachedOpCompletion;
        use crate::ops_lifecycle::{OperationKind, OperationStatus};
        let completion = DetachedOpCompletion {
            job_id: "j_123".into(),
            kind: OperationKind::BackgroundToolOp,
            status: OperationStatus::Completed,
            terminal_outcome: None,
            display_name: "sleep 2".into(),
            detail: "exit_code: 0".into(),
            elapsed_ms: Some(2000),
        };
        let update = ExternalToolUpdate {
            notices: Vec::new(),
            pending: Vec::new(),
            background_completions: vec![completion.clone()],
        };
        let d1: Arc<dyn AgentToolDispatcher> = Arc::new(MockBgDispatcher {
            update: update.clone(),
        });
        let d2: Arc<dyn AgentToolDispatcher> = Arc::new(MockBgDispatcher { update });
        let gateway = ToolGatewayBuilder::new()
            .add_dispatcher(d1)
            .add_dispatcher(d2)
            .build()
            .unwrap();
        let result = gateway.poll_external_updates().await;
        assert_eq!(
            result.background_completions.len(),
            1,
            "gateway must dedup background_completions by job_id; got {} entries",
            result.background_completions.len()
        );
        assert_eq!(result.background_completions[0].job_id, "j_123");
    }

    #[test]
    fn gateway_exact_catalog_tracks_unavailable_winners() {
        let base = Arc::new(ExactMockDispatcher::with_callability(
            "base",
            &[("alpha", true)],
        ));
        let overlay = Arc::new(ExactMockDispatcher::with_callability(
            "overlay",
            &[("beta", false)],
        ));
        let gateway = ToolGateway::new(base, Some(overlay)).expect("gateway should build");
        assert!(
            gateway.tool_catalog_capabilities().exact_catalog,
            "gateway should be exact when every child is exact"
        );
        let visible_names: Vec<_> = gateway
            .tools()
            .iter()
            .map(|tool| tool.name.clone())
            .collect();
        assert_eq!(visible_names, vec!["alpha".to_string()]);
        let catalog = gateway.tool_catalog();
        let catalog_names: Vec<_> = catalog
            .iter()
            .map(|entry| entry.tool.name.clone())
            .collect();
        assert_eq!(catalog_names, vec!["alpha".to_string(), "beta".to_string()]);
        assert!(
            !catalog
                .iter()
                .find(|entry| entry.tool.name == "beta")
                .expect("beta catalog entry")
                .currently_callable,
            "exact catalog should retain unavailable winners"
        );
    }

    #[test]
    fn gateway_exact_catalog_is_disabled_by_non_exact_child() {
        let exact = Arc::new(ExactMockDispatcher::with_callability(
            "exact",
            &[("alpha", true)],
        ));
        let non_exact = Arc::new(MockDispatcher::new("legacy", &["beta"]));
        let gateway = ToolGateway::new(exact, Some(non_exact)).expect("gateway should build");
        assert!(
            !gateway.tool_catalog_capabilities().exact_catalog,
            "gateway should disable deferred catalogs when any child is non-exact"
        );
    }

    #[test]
    fn dynamic_tool_composite_exact_catalog_keeps_first_winner_even_when_unavailable() {
        let first = Arc::new(ExactMockDispatcher::with_callability(
            "first",
            &[("shared", false)],
        ));
        let second = Arc::new(ExactMockDispatcher::with_callability(
            "second",
            &[("shared", true), ("other", true)],
        ));
        let composite = DynamicToolComposite::new(vec![first, second]);
        assert!(
            composite.tool_catalog_capabilities().exact_catalog,
            "dynamic composite should be exact when every child is exact"
        );
        let visible_names: Vec<_> = composite
            .tools()
            .iter()
            .map(|tool| tool.name.clone())
            .collect();
        assert_eq!(
            visible_names,
            vec!["other".to_string()],
            "a later visible collision loser must not become the exported winner"
        );
        let catalog = composite.tool_catalog();
        assert_eq!(catalog.len(), 2);
        assert!(
            !catalog
                .iter()
                .find(|entry| entry.tool.name == "shared")
                .expect("shared entry")
                .currently_callable
        );
    }
}