use std::sync::Arc;
use std::task::{Context, Poll};
use futures::future::BoxFuture;
use serde_json::Value;
use tower::{Layer, Service};
use crate::context::ExecutionContext;
use crate::error::{Error, Result};
use crate::service::ToolInvocation;
/// Hook that decorates the future of every tool dispatch routed through a
/// [`ScopedToolService`].
///
/// Implementations are stored as `Arc<dyn ToolDispatchScope>` and that single
/// instance is shared by the [`ScopedToolLayer`] and every service it builds.
pub trait ToolDispatchScope: Send + Sync + 'static {
/// Wraps the future of a single tool dispatch; invoked once per `call`.
///
/// * `ctx` — a clone of the dispatched invocation's [`ExecutionContext`].
/// * `fut` — the future already returned by the inner service's `call`.
///   The inner `call` has returned before `wrap` runs, so any synchronous work
///   it performed is not covered by this hook; only what `wrap` does with the
///   returned future is.
///
/// Returns the future handed back to the service's caller. Returning `fut`
/// unchanged is a valid (no-op) implementation.
fn wrap(
&self,
ctx: ExecutionContext,
fut: BoxFuture<'static, Result<Value>>,
) -> BoxFuture<'static, Result<Value>>;
}
/// [`tower::Layer`] that passes every tool-dispatch future through a
/// [`ToolDispatchScope`].
///
/// The scope lives behind an [`Arc`]; cloning the layer, or building services
/// from it, shares that one instance rather than duplicating it.
pub struct ScopedToolLayer {
/// Scope applied by every [`ScopedToolService`] this layer builds.
wrapper: Arc<dyn ToolDispatchScope>,
}
impl ScopedToolLayer {
    /// Name this layer reports through [`crate::NamedLayer::layer_name`].
    pub const NAME: &'static str = "tool_scope";

    /// Creates a layer that takes ownership of `wrapper` and moves it behind an
    /// [`Arc`] so it can be shared by every service the layer builds.
    ///
    /// Marked `#[must_use]` for consistency with [`ScopedToolLayer::from_arc`]:
    /// constructing a layer and discarding it is always a mistake.
    #[must_use]
    pub fn new<W>(wrapper: W) -> Self
    where
        W: ToolDispatchScope,
    {
        // Single construction path: the `Arc<W>` unsizes to `Arc<dyn ToolDispatchScope>`.
        Self::from_arc(Arc::new(wrapper))
    }

    /// Creates a layer from a scope that is already shared, e.g. one that is
    /// also installed elsewhere. No new allocation is made.
    #[must_use]
    pub fn from_arc(wrapper: Arc<dyn ToolDispatchScope>) -> Self {
        Self { wrapper }
    }
}
impl Clone for ScopedToolLayer {
    /// Returns another handle to the same scope; only the [`Arc`] refcount grows.
    fn clone(&self) -> Self {
        let wrapper = Arc::clone(&self.wrapper);
        Self { wrapper }
    }
}
impl<S> Layer<S> for ScopedToolLayer {
    type Service = ScopedToolService<S>;

    /// Wraps `inner` in a [`ScopedToolService`] that shares this layer's scope.
    fn layer(&self, inner: S) -> Self::Service {
        let wrapper = Arc::clone(&self.wrapper);
        ScopedToolService { wrapper, inner }
    }
}
impl crate::NamedLayer for ScopedToolLayer {
fn layer_name(&self) -> &'static str {
Self::NAME
}
}
/// Service built by [`ScopedToolLayer`]: forwards each [`ToolInvocation`] to
/// `inner` and passes the resulting future through the shared
/// [`ToolDispatchScope`].
pub struct ScopedToolService<S> {
/// The wrapped tool-dispatch service.
inner: S,
/// Scope shared with the originating layer and with every clone of this service.
wrapper: Arc<dyn ToolDispatchScope>,
}
impl<S: Clone> Clone for ScopedToolService<S> {
    /// Clones the inner service but shares (rather than duplicates) the scope.
    fn clone(&self) -> Self {
        let inner = self.inner.clone();
        let wrapper = Arc::clone(&self.wrapper);
        Self { inner, wrapper }
    }
}
impl<S> Service<ToolInvocation> for ScopedToolService<S>
where
    S: Service<ToolInvocation, Response = Value, Error = Error> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = Value;
    type Error = Error;
    type Future = BoxFuture<'static, Result<Value>>;

    /// Readiness is delegated verbatim to the inner service; the scope adds no
    /// back-pressure of its own.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
        self.inner.poll_ready(cx)
    }

    /// Dispatches `invocation` to the inner service, then hands the boxed
    /// future, together with a copy of the invocation's context, to the scope.
    ///
    /// The inner `call` completes before `wrap` is invoked. Any synchronous
    /// work it does therefore happens outside the scope; only what `wrap` does
    /// with the returned future is scoped.
    fn call(&mut self, invocation: ToolInvocation) -> Self::Future {
        // `invocation` is moved into the inner call, so grab its context first.
        let scope_ctx = invocation.ctx.clone();
        let dispatched: BoxFuture<'static, Result<Value>> =
            Box::pin(self.inner.call(invocation));
        self.wrapper.wrap(scope_ctx, dispatched)
    }
}
#[cfg(test)]
// `unwrap` keeps the assertions terse; a panic is exactly the desired failure mode in tests.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serde_json::json;

    use super::*;
    use crate::tools::{Tool, ToolMetadata, ToolRegistry};
    use async_trait::async_trait;

    /// Scope that leaves the future untouched and only counts how often `wrap` runs.
    struct CountingScope {
        wraps: Arc<AtomicUsize>,
    }

    impl ToolDispatchScope for CountingScope {
        fn wrap(
            &self,
            _ctx: ExecutionContext,
            fut: BoxFuture<'static, Result<Value>>,
        ) -> BoxFuture<'static, Result<Value>> {
            self.wraps.fetch_add(1, Ordering::SeqCst);
            fut
        }
    }

    /// Tool named `"echo"` that returns its input unchanged.
    struct EchoTool {
        metadata: ToolMetadata,
    }

    impl EchoTool {
        fn new() -> Self {
            let metadata = ToolMetadata::function(
                "echo",
                "Echo input verbatim.",
                json!({ "type": "object" }),
            );
            Self { metadata }
        }
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn metadata(&self) -> &ToolMetadata {
            &self.metadata
        }

        async fn execute(&self, input: Value, _ctx: &crate::AgentContext<()>) -> Result<Value> {
            Ok(input)
        }
    }

    /// Builds a layer backed by a fresh [`CountingScope`] and returns it with
    /// the counter that scope increments.
    fn counting_layer() -> (ScopedToolLayer, Arc<AtomicUsize>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let layer = ScopedToolLayer::new(CountingScope {
            wraps: Arc::clone(&counter),
        });
        (layer, counter)
    }

    #[tokio::test]
    async fn scope_wrap_fires_on_dispatch() {
        let (layer, wraps) = counting_layer();
        let registry = ToolRegistry::new()
            .layer(layer)
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let ctx = ExecutionContext::new();

        let output = registry
            .dispatch("", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap();

        assert_eq!(output, json!({"x": 1}));
        assert_eq!(wraps.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn scope_wrap_fires_per_dispatch() {
        let (layer, wraps) = counting_layer();
        let registry = ToolRegistry::new()
            .layer(layer)
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let ctx = ExecutionContext::new();

        for _round in 0..3 {
            registry
                .dispatch("", "echo", json!({"x": 1}), &ctx)
                .await
                .unwrap();
        }

        assert_eq!(wraps.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn scope_wrap_inherited_by_narrowed_view() {
        let (layer, wraps) = counting_layer();
        let parent = ToolRegistry::new()
            .layer(layer)
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let narrowed = parent.restricted_to(&["echo"]).unwrap();
        let ctx = ExecutionContext::new();

        narrowed
            .dispatch("", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap();

        assert_eq!(wraps.load(Ordering::SeqCst), 1);
    }
}