use std::sync::Arc;
use std::task::{Context, Poll};
use async_trait::async_trait;
use futures::future::BoxFuture;
use serde_json::Value;
use tower::{Layer, Service, ServiceExt};
use entelix_core::context::ExecutionContext;
use entelix_core::error::{Error, Result};
use entelix_core::service::ToolInvocation;
/// Verdict returned by [`ToolHook::before_tool`] for a pending tool call.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum ToolHookDecision {
    /// Leave the current input untouched and move on to the next hook.
    Continue,
    /// Replace the tool input; later hooks and the wrapped tool service see
    /// the new value.
    ReplaceInput(Value),
    /// Abort the call before the wrapped tool service runs.
    Reject {
        /// Explanation for the rejection; `ToolHookRegistry` surfaces it to
        /// the caller as an invalid-request error.
        reason: String,
    },
}
/// Snapshot of a tool call that is handed by reference to every
/// [`ToolHook`] callback.
///
/// Built from the incoming `ToolInvocation`; `input` reflects any
/// replacements made by earlier `before_tool` hooks.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ToolHookRequest {
    /// Run identifier taken from the execution context, if one is set.
    pub run_id: Option<String>,
    /// Identifier of this tool use, copied from the invocation.
    pub tool_use_id: String,
    /// Tool name from the tool's metadata.
    pub tool_name: String,
    /// Tool version from the tool's metadata, when one is declared.
    pub tool_version: Option<String>,
    /// Tool input as currently seen by hooks (after earlier replacements).
    pub input: Value,
}
impl ToolHookRequest {
    /// Captures the hook-visible view of `invocation`.
    ///
    /// Every field is copied out, so later changes to `invocation` are not
    /// reflected in the returned request.
    fn from_invocation(invocation: &ToolInvocation) -> Self {
        Self {
            run_id: invocation.ctx.run_id().map(str::to_owned),
            tool_use_id: invocation.tool_use_id.clone(),
            tool_name: invocation.metadata.name.clone(),
            tool_version: invocation.metadata.version.clone(),
            input: invocation.input.clone(),
        }
    }

    /// Returns this request with `input` swapped in.
    ///
    /// Consumes `self` so the remaining fields are moved rather than cloned,
    /// and the superseded input is simply dropped instead of being
    /// deep-copied first.
    fn with_input(self, input: Value) -> Self {
        Self { input, ..self }
    }
}
/// Extension point for observing and steering tool calls.
///
/// Every method has a no-op default, so implementors override only the
/// stages they care about. Hooks registered in a [`ToolHookRegistry`] run in
/// registration order.
#[async_trait]
pub trait ToolHook: Send + Sync + 'static {
    /// Called before the wrapped tool service runs.
    ///
    /// Return [`ToolHookDecision::ReplaceInput`] to rewrite the input seen by
    /// later hooks and the tool, or [`ToolHookDecision::Reject`] to stop the
    /// call. Returning `Err` also stops the call; that error is handed back to
    /// the caller and `on_tool_error` hooks are not run for it.
    async fn before_tool(
        &self,
        _request: &ToolHookRequest,
        _ctx: &ExecutionContext,
    ) -> Result<ToolHookDecision> {
        Ok(ToolHookDecision::Continue)
    }
    /// Called with the tool's output after a successful call.
    ///
    /// Returning `Err` replaces the successful result: the caller receives
    /// this error instead of the output, and later `after_tool` hooks are
    /// skipped.
    async fn after_tool(
        &self,
        _request: &ToolHookRequest,
        _output: &Value,
        _ctx: &ExecutionContext,
    ) -> Result<()> {
        Ok(())
    }
    /// Called when the wrapped tool service returns an error.
    ///
    /// Returning `Err` does not change the outcome: the hook failure is logged
    /// and the original tool error is still returned to the caller.
    async fn on_tool_error(
        &self,
        _request: &ToolHookRequest,
        _error: &Error,
        _ctx: &ExecutionContext,
    ) -> Result<()> {
        Ok(())
    }
}
/// Ordered collection of [`ToolHook`]s applied around each tool call.
///
/// The hook list sits behind an `Arc`, so cloning a registry shares the
/// hooks rather than copying them.
#[derive(Clone, Default)]
pub struct ToolHookRegistry {
    hooks: Arc<Vec<Arc<dyn ToolHook>>>,
}
impl ToolHookRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Number of registered hooks.
    #[must_use]
    pub fn len(&self) -> usize {
        self.hooks.len()
    }

    /// Returns `true` when no hook has been registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.hooks.is_empty()
    }

    /// Appends `hook` and returns the extended registry.
    ///
    /// Hooks run in the order they were registered.
    #[must_use]
    pub fn register<H>(self, hook: H) -> Self
    where
        H: ToolHook,
    {
        self.register_arc(Arc::new(hook))
    }

    /// Appends an already shared `hook` and returns the extended registry.
    ///
    /// Clones of this registry taken before the call are unaffected:
    /// `Arc::make_mut` copies the hook list only while it is shared, and
    /// otherwise extends the uniquely owned list in place instead of
    /// reallocating it on every registration.
    #[must_use]
    pub fn register_arc(mut self, hook: Arc<dyn ToolHook>) -> Self {
        Arc::make_mut(&mut self.hooks).push(hook);
        self
    }

    /// Runs every `before_tool` hook in registration order.
    ///
    /// Replacements are threaded forward: each hook sees the input produced
    /// by the previous ones, and `invocation.input` is kept in sync so the
    /// wrapped service dispatches the final value. Returns the request as
    /// seen by the last hook, for use by the after/error hooks.
    ///
    /// # Errors
    ///
    /// Stops at the first hook that returns `Err` (propagated as-is) or
    /// [`ToolHookDecision::Reject`] (mapped to `Error::invalid_request`).
    async fn apply_before(&self, invocation: &mut ToolInvocation) -> Result<ToolHookRequest> {
        let mut request = ToolHookRequest::from_invocation(invocation);
        for hook in self.hooks.iter() {
            match hook.before_tool(&request, &invocation.ctx).await? {
                ToolHookDecision::Continue => {}
                ToolHookDecision::ReplaceInput(input) => {
                    // Both the dispatched invocation and the hook-facing
                    // request must carry the new input.
                    invocation.input = input.clone();
                    request = request.with_input(input);
                }
                ToolHookDecision::Reject { reason } => {
                    return Err(Error::invalid_request(reason));
                }
            }
        }
        Ok(request)
    }

    /// Runs every `after_tool` hook in registration order.
    ///
    /// # Errors
    ///
    /// Returns the first hook error; the remaining hooks are skipped.
    async fn apply_after(
        &self,
        request: &ToolHookRequest,
        output: &Value,
        ctx: &ExecutionContext,
    ) -> Result<()> {
        for hook in self.hooks.iter() {
            hook.after_tool(request, output, ctx).await?;
        }
        Ok(())
    }

    /// Runs every `on_tool_error` hook in registration order, best-effort.
    ///
    /// A failing hook is logged and skipped, so every hook still runs and
    /// the caller keeps seeing the original tool error.
    async fn apply_error(&self, request: &ToolHookRequest, error: &Error, ctx: &ExecutionContext) {
        for hook in self.hooks.iter() {
            if let Err(hook_error) = hook.on_tool_error(request, error, ctx).await {
                tracing::warn!(
                    tool = %request.tool_name,
                    tool_use_id = %request.tool_use_id,
                    error = %hook_error,
                    "tool error hook failed; preserving original tool error"
                );
            }
        }
    }
}
/// Tower layer that wraps a tool service with the hooks of a
/// [`ToolHookRegistry`].
#[derive(Clone, Default)]
pub struct ToolHookLayer {
    hooks: ToolHookRegistry,
}
impl ToolHookLayer {
pub const NAME: &'static str = "tool_hook";
#[must_use]
pub const fn new(hooks: ToolHookRegistry) -> Self {
Self { hooks }
}
#[must_use]
pub const fn hooks(&self) -> &ToolHookRegistry {
&self.hooks
}
}
impl<Inner> Layer<Inner> for ToolHookLayer {
    type Service = ToolHookService<Inner>;

    /// Wraps `inner`; the new service gets its own handle to this layer's
    /// registry (an `Arc`-backed clone, so the hooks themselves are shared).
    fn layer(&self, inner: Inner) -> Self::Service {
        let registry = self.hooks.clone();
        ToolHookService {
            inner,
            hooks: registry,
        }
    }
}
impl entelix_core::NamedLayer for ToolHookLayer {
fn layer_name(&self) -> &'static str {
Self::NAME
}
}
/// Tool service produced by [`ToolHookLayer`]: runs the registry's hooks
/// around each call to `inner`.
pub struct ToolHookService<Inner> {
    /// Wrapped tool service that performs the actual dispatch.
    inner: Inner,
    /// Hooks applied before, after, and on failure of each dispatch.
    hooks: ToolHookRegistry,
}
/// Written by hand so the only bound is `Inner: Clone`; the registry is
/// always cheaply clonable.
impl<Inner: Clone> Clone for ToolHookService<Inner> {
    fn clone(&self) -> Self {
        let Self { inner, hooks } = self;
        Self {
            inner: inner.clone(),
            hooks: hooks.clone(),
        }
    }
}
/// Dispatches tool calls through the hook pipeline.
///
/// Per call: `before_tool` hooks (which may rewrite or reject the input),
/// then the wrapped service, then `after_tool` hooks on success or
/// `on_tool_error` hooks on failure.
impl<Inner> Service<ToolInvocation> for ToolHookService<Inner>
where
    Inner: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
    Inner::Future: Send + 'static,
{
    type Response = Value;
    type Error = Error;
    type Future = BoxFuture<'static, Result<Value>>;

    /// Readiness is exactly the wrapped service's readiness.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
        self.inner.poll_ready(cx)
    }

    /// Runs the hooks around one dispatch of `invocation`.
    ///
    /// # Errors
    ///
    /// The returned future fails with:
    /// - a `before_tool` hook's error, or an invalid-request error for a
    ///   [`ToolHookDecision::Reject`]; the wrapped service is not called;
    /// - the first `after_tool` hook error, which replaces the output;
    /// - otherwise the wrapped service's own error, after the
    ///   `on_tool_error` hooks have observed it.
    fn call(&mut self, mut invocation: ToolInvocation) -> Self::Future {
        let hooks = self.hooks.clone();
        // `poll_ready` was driven on `self.inner`, so that instance must serve
        // this request. Leave a fresh clone behind for the next call and move
        // the ready one into the future. Dispatching through a new clone
        // instead would leave whatever `poll_ready` reserved (for example a
        // concurrency-limit permit) stranded on `self.inner`.
        let fresh = self.inner.clone();
        let ready = std::mem::replace(&mut self.inner, fresh);
        Box::pin(async move {
            let request = hooks.apply_before(&mut invocation).await?;
            let ctx = invocation.ctx.clone();
            // `oneshot` polls readiness again before calling. For a service
            // that is already ready this is expected to resolve immediately,
            // and it keeps callers that skip `poll_ready` working.
            match ready.oneshot(invocation).await {
                Ok(output) => {
                    hooks.apply_after(&request, &output, &ctx).await?;
                    Ok(output)
                }
                Err(error) => {
                    hooks.apply_error(&request, &error, &ctx).await;
                    Err(error)
                }
            }
        })
    }
}
#[cfg(test)]
// Unwrapping is deliberate in tests: a panic is the failure signal.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicBool, Ordering};

    use serde_json::json;

    use super::*;
    use entelix_core::tools::ToolMetadata;

    /// Inner tool service that records every input it is called with and
    /// answers with a canned [`RecordingOutput`].
    #[derive(Clone)]
    struct RecordingService {
        seen: Arc<Mutex<Vec<Value>>>,
        output: RecordingOutput,
    }

    /// Canned result returned by [`RecordingService`].
    #[derive(Clone)]
    enum RecordingOutput {
        Ok(Value),
        InvalidRequest(&'static str),
    }

    impl Service<ToolInvocation> for RecordingService {
        type Response = Value;
        type Error = Error;
        type Future = BoxFuture<'static, Result<Value>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, invocation: ToolInvocation) -> Self::Future {
            self.seen.lock().unwrap().push(invocation.input);
            let output = match self.output.clone() {
                RecordingOutput::Ok(value) => Ok(value),
                RecordingOutput::InvalidRequest(message) => Err(Error::invalid_request(message)),
            };
            Box::pin(async move { output })
        }
    }

    /// Replaces any input with `{"value": 2}`.
    struct ReplaceHook;

    #[async_trait]
    impl ToolHook for ReplaceHook {
        async fn before_tool(
            &self,
            _request: &ToolHookRequest,
            _ctx: &ExecutionContext,
        ) -> Result<ToolHookDecision> {
            Ok(ToolHookDecision::ReplaceInput(json!({"value": 2})))
        }
    }

    /// Records the input it sees, then replaces it with `replacement`.
    struct ReplaceWithHook {
        replacement: Value,
        seen_inputs: Arc<Mutex<Vec<Value>>>,
    }

    #[async_trait]
    impl ToolHook for ReplaceWithHook {
        async fn before_tool(
            &self,
            request: &ToolHookRequest,
            _ctx: &ExecutionContext,
        ) -> Result<ToolHookDecision> {
            self.seen_inputs.lock().unwrap().push(request.input.clone());
            Ok(ToolHookDecision::ReplaceInput(self.replacement.clone()))
        }
    }

    /// Rejects every call with the reason `"blocked"`.
    struct RejectHook;

    #[async_trait]
    impl ToolHook for RejectHook {
        async fn before_tool(
            &self,
            _request: &ToolHookRequest,
            _ctx: &ExecutionContext,
        ) -> Result<ToolHookDecision> {
            Ok(ToolHookDecision::Reject {
                reason: "blocked".to_owned(),
            })
        }
    }

    /// Records every successful output.
    struct AfterHook {
        outputs: Arc<Mutex<Vec<Value>>>,
    }

    #[async_trait]
    impl ToolHook for AfterHook {
        async fn after_tool(
            &self,
            _request: &ToolHookRequest,
            output: &Value,
            _ctx: &ExecutionContext,
        ) -> Result<()> {
            self.outputs.lock().unwrap().push(output.clone());
            Ok(())
        }
    }

    /// Records the request input observed by after hooks.
    struct AfterRequestHook {
        inputs: Arc<Mutex<Vec<Value>>>,
    }

    #[async_trait]
    impl ToolHook for AfterRequestHook {
        async fn after_tool(
            &self,
            request: &ToolHookRequest,
            _output: &Value,
            _ctx: &ExecutionContext,
        ) -> Result<()> {
            self.inputs.lock().unwrap().push(request.input.clone());
            Ok(())
        }
    }

    /// Flags that `on_tool_error` ran.
    struct ErrorHook {
        fired: Arc<AtomicBool>,
    }

    #[async_trait]
    impl ToolHook for ErrorHook {
        async fn on_tool_error(
            &self,
            _request: &ToolHookRequest,
            _error: &Error,
            _ctx: &ExecutionContext,
        ) -> Result<()> {
            self.fired.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// `on_tool_error` hook that always fails.
    struct FailingErrorHook;

    #[async_trait]
    impl ToolHook for FailingErrorHook {
        async fn on_tool_error(
            &self,
            _request: &ToolHookRequest,
            _error: &Error,
            _ctx: &ExecutionContext,
        ) -> Result<()> {
            Err(Error::invalid_request("hook failed"))
        }
    }

    fn invocation(input: Value) -> ToolInvocation {
        ToolInvocation::new(
            "call_1".to_owned(),
            Arc::new(ToolMetadata::function(
                "echo",
                "echo input",
                json!({"type": "object"}),
            )),
            input,
            ExecutionContext::new(),
        )
    }

    /// Builds a [`RecordingService`] plus a handle to the inputs it records.
    fn recording_service(output: RecordingOutput) -> (RecordingService, Arc<Mutex<Vec<Value>>>) {
        let seen = Arc::new(Mutex::new(Vec::new()));
        let service = RecordingService {
            seen: Arc::clone(&seen),
            output,
        };
        (service, seen)
    }

    /// Sends one invocation through `service` wrapped with `hooks`.
    ///
    /// Uses `oneshot` so `poll_ready` is driven before `call`, as the tower
    /// `Service` contract requires.
    async fn dispatch(
        hooks: ToolHookRegistry,
        service: RecordingService,
        input: Value,
    ) -> Result<Value> {
        ToolHookLayer::new(hooks)
            .layer(service)
            .oneshot(invocation(input))
            .await
    }

    #[tokio::test]
    async fn before_hook_replaces_input_before_inner_service() {
        let (service, seen) = recording_service(RecordingOutput::Ok(json!({"ok": true})));
        let hooks = ToolHookRegistry::new().register(ReplaceHook);
        let output = dispatch(hooks, service, json!({"value": 1})).await.unwrap();
        assert_eq!(output, json!({"ok": true}));
        assert_eq!(*seen.lock().unwrap(), vec![json!({"value": 2})]);
    }

    #[tokio::test]
    async fn before_hooks_run_in_registration_order_and_pass_replacements_forward() {
        let (service, inner_seen) = recording_service(RecordingOutput::Ok(json!({"ok": true})));
        let hook_seen = Arc::new(Mutex::new(Vec::new()));
        let after_seen = Arc::new(Mutex::new(Vec::new()));
        let hooks = ToolHookRegistry::new()
            .register(ReplaceWithHook {
                replacement: json!({"value": 2}),
                seen_inputs: Arc::clone(&hook_seen),
            })
            .register(ReplaceWithHook {
                replacement: json!({"value": 3}),
                seen_inputs: Arc::clone(&hook_seen),
            })
            .register(AfterRequestHook {
                inputs: Arc::clone(&after_seen),
            });
        let output = dispatch(hooks, service, json!({"value": 1})).await.unwrap();
        assert_eq!(output, json!({"ok": true}));
        assert_eq!(
            *hook_seen.lock().unwrap(),
            vec![json!({"value": 1}), json!({"value": 2})],
            "each before hook must see the input produced by the previous hook"
        );
        assert_eq!(*inner_seen.lock().unwrap(), vec![json!({"value": 3})]);
        assert_eq!(
            *after_seen.lock().unwrap(),
            vec![json!({"value": 3})],
            "after hooks must observe the final dispatch request"
        );
    }

    #[tokio::test]
    async fn before_hook_rejects_without_calling_inner_service() {
        let (service, seen) = recording_service(RecordingOutput::Ok(json!({"ok": true})));
        let hooks = ToolHookRegistry::new().register(RejectHook);
        let error = dispatch(hooks, service, json!({})).await.unwrap_err();
        assert!(matches!(error, Error::InvalidRequest(_)));
        assert!(format!("{error}").contains("blocked"));
        assert!(seen.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn reject_short_circuits_later_before_hooks() {
        let (service, inner_seen) = recording_service(RecordingOutput::Ok(json!({"ok": true})));
        let hook_seen = Arc::new(Mutex::new(Vec::new()));
        let hooks = ToolHookRegistry::new()
            .register(RejectHook)
            .register(ReplaceWithHook {
                replacement: json!({"value": 2}),
                seen_inputs: Arc::clone(&hook_seen),
            });
        let error = dispatch(hooks, service, json!({})).await.unwrap_err();
        assert!(matches!(error, Error::InvalidRequest(_)));
        assert!(
            hook_seen.lock().unwrap().is_empty(),
            "hooks registered after a rejecting hook must not run"
        );
        assert!(inner_seen.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn after_hook_observes_successful_output() {
        let outputs = Arc::new(Mutex::new(Vec::new()));
        let (service, _) = recording_service(RecordingOutput::Ok(json!({"answer": 42})));
        let hooks = ToolHookRegistry::new().register(AfterHook {
            outputs: Arc::clone(&outputs),
        });
        let _ = dispatch(hooks, service, json!({})).await.unwrap();
        assert_eq!(*outputs.lock().unwrap(), vec![json!({"answer": 42})]);
    }

    #[tokio::test]
    async fn error_hook_observes_inner_error_without_masking_it() {
        let fired = Arc::new(AtomicBool::new(false));
        let (service, _) = recording_service(RecordingOutput::InvalidRequest("inner failed"));
        let hooks = ToolHookRegistry::new().register(ErrorHook {
            fired: Arc::clone(&fired),
        });
        let error = dispatch(hooks, service, json!({})).await.unwrap_err();
        assert!(fired.load(Ordering::SeqCst));
        assert!(format!("{error}").contains("inner failed"));
    }

    #[tokio::test]
    async fn failing_error_hook_neither_masks_error_nor_skips_later_hooks() {
        let fired = Arc::new(AtomicBool::new(false));
        let (service, _) = recording_service(RecordingOutput::InvalidRequest("inner failed"));
        let hooks = ToolHookRegistry::new()
            .register(FailingErrorHook)
            .register(ErrorHook {
                fired: Arc::clone(&fired),
            });
        let error = dispatch(hooks, service, json!({})).await.unwrap_err();
        assert!(
            fired.load(Ordering::SeqCst),
            "error hooks after a failing one must still run"
        );
        let message = format!("{error}");
        assert!(message.contains("inner failed"));
        assert!(!message.contains("hook failed"));
    }

    #[test]
    fn registry_extension_leaves_existing_clones_untouched() {
        let empty = ToolHookRegistry::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let base = empty.register(RejectHook);
        let extended = base.clone().register(ReplaceHook);
        assert_eq!(base.len(), 1);
        assert_eq!(extended.len(), 2);
        assert!(!extended.is_empty());
    }
}