use std::sync::Arc;
use std::task::{Context, Poll};
use async_trait::async_trait;
use futures::future::BoxFuture;
use serde_json::{Value, json};
use tower::{Layer, Service};
use entelix_core::PendingApprovalDecisions;
use entelix_core::TenantId;
use entelix_core::error::{Error, Result};
use entelix_core::interruption::InterruptionKind;
use entelix_core::service::ToolInvocation;
use entelix_core::tools::ToolEffect;
use crate::agent::approver::{ApprovalDecision, ApprovalRequest, Approver};
use crate::agent::event::AgentEvent;
use crate::agent::sink::AgentEventSink;
/// Observer notified by the tool-approval layer once an approval decision resolves.
///
/// Implementations are attached to a dispatch through [`ToolApprovalEventSinkHandle`],
/// which `ApprovalService` looks up as an `ExecutionContext` extension. Both methods
/// return `()`, so an implementation has no way to fail the dispatch it observes.
/// No event is emitted when the decision is deferred (`AwaitExternal`).
#[async_trait]
pub trait ToolApprovalEventSink: Send + Sync + 'static {
    /// Called after a tool call was approved, before the tool itself is dispatched.
    ///
    /// `run_id` is the empty string when the execution context carries no run id.
    async fn record_approved(
        &self,
        tenant_id: &TenantId,
        run_id: &str,
        tool_use_id: &str,
        tool: &str,
    );
    /// Called after a tool call was rejected; `reason` is the rejection reason carried
    /// by the `ApprovalDecision::Reject` that was applied.
    ///
    /// `run_id` is the empty string when the execution context carries no run id.
    async fn record_denied(
        &self,
        tenant_id: &TenantId,
        run_id: &str,
        tool_use_id: &str,
        tool: &str,
        reason: &str,
    );
}
/// Cloneable, type-erased handle to a [`ToolApprovalEventSink`].
///
/// `ApprovalService` looks this type up as an `ExecutionContext` extension and, when it
/// is present, reports approve/deny outcomes through it.
#[derive(Clone)]
pub struct ToolApprovalEventSinkHandle {
    sink: Arc<dyn ToolApprovalEventSink>,
}

impl ToolApprovalEventSinkHandle {
    /// Wraps a concrete sink implementation behind a shared, type-erased pointer.
    pub fn new<E>(sink: E) -> Self
    where
        E: ToolApprovalEventSink,
    {
        let erased: Arc<dyn ToolApprovalEventSink> = Arc::new(sink);
        Self { sink: erased }
    }

    /// Bridges an [`AgentEventSink`] so approval outcomes are published as
    /// `AgentEvent::ToolCallApproved` / `AgentEvent::ToolCallDenied`.
    ///
    /// Delivery failures reported by the agent sink are ignored.
    pub fn for_agent_sink<S>(sink: Arc<dyn AgentEventSink<S>>) -> Self
    where
        S: Clone + Send + Sync + 'static,
    {
        Self::new(SinkAdapter { sink })
    }

    /// Borrows the shared, type-erased sink.
    pub fn inner(&self) -> &Arc<dyn ToolApprovalEventSink> {
        &self.sink
    }
}
/// Adapter that re-publishes approval outcomes as [`AgentEvent`]s on an
/// [`AgentEventSink`].
struct SinkAdapter<S> {
    sink: Arc<dyn AgentEventSink<S>>,
}

impl<S> SinkAdapter<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Forwards `event` to the wrapped agent sink.
    ///
    /// The send outcome is discarded on purpose: the `ToolApprovalEventSink` methods
    /// return `()`, so there is no channel to surface a delivery failure through.
    async fn emit(&self, event: AgentEvent<S>) {
        let _ = self.sink.send(event).await;
    }
}

#[async_trait]
impl<S> ToolApprovalEventSink for SinkAdapter<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Publishes `AgentEvent::ToolCallApproved` with owned copies of the arguments.
    async fn record_approved(
        &self,
        tenant_id: &TenantId,
        run_id: &str,
        tool_use_id: &str,
        tool: &str,
    ) {
        self.emit(AgentEvent::ToolCallApproved {
            run_id: String::from(run_id),
            tenant_id: tenant_id.clone(),
            tool_use_id: String::from(tool_use_id),
            tool: String::from(tool),
        })
        .await;
    }

    /// Publishes `AgentEvent::ToolCallDenied` with owned copies of the arguments.
    async fn record_denied(
        &self,
        tenant_id: &TenantId,
        run_id: &str,
        tool_use_id: &str,
        tool: &str,
        reason: &str,
    ) {
        self.emit(AgentEvent::ToolCallDenied {
            run_id: String::from(run_id),
            tenant_id: tenant_id.clone(),
            tool_use_id: String::from(tool_use_id),
            tool: String::from(tool),
            reason: String::from(reason),
        })
        .await;
    }
}
/// Policy selecting which tool calls the approval layer sends to its approver,
/// based on the tool's declared [`ToolEffect`].
///
/// A pre-recorded `PendingApprovalDecisions` entry for a call is honoured by
/// `ApprovalService` regardless of the gate.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum EffectGate {
    /// Every tool call requires approval (the default).
    #[default]
    Always,
    /// Only calls whose effect is `ToolEffect::Destructive` require approval.
    DestructiveOnly,
    /// Calls whose effect is `ToolEffect::Mutating` or `ToolEffect::Destructive`
    /// require approval.
    MutatingAndAbove,
}

impl EffectGate {
    /// Returns `true` when a tool whose declared effect is `effect` must be routed
    /// through the approver under this gate.
    #[must_use]
    pub const fn requires_approval(self, effect: ToolEffect) -> bool {
        let is_destructive = matches!(effect, ToolEffect::Destructive);
        let is_mutating = matches!(effect, ToolEffect::Mutating);
        match self {
            Self::Always => true,
            Self::DestructiveOnly => is_destructive,
            Self::MutatingAndAbove => is_mutating || is_destructive,
        }
    }
}
/// Tower [`Layer`] that puts an [`Approver`] in front of tool dispatch.
///
/// Each produced [`ApprovalService`] resolves an approval decision (from a
/// `PendingApprovalDecisions` context entry or from the approver) before forwarding a
/// `ToolInvocation` to the wrapped service. An [`EffectGate`] decides which tool
/// effects need approval at all; the default gate is [`EffectGate::Always`].
#[derive(Clone)]
pub struct ApprovalLayer {
    approver: Arc<dyn Approver>,
    gate: EffectGate,
}

impl ApprovalLayer {
    /// Layer name reported through [`entelix_core::NamedLayer`].
    pub const NAME: &'static str = "tool_approval";

    /// Creates a layer that consults `approver` under the default [`EffectGate`].
    #[must_use]
    pub fn new(approver: Arc<dyn Approver>) -> Self {
        Self {
            approver,
            gate: EffectGate::default(),
        }
    }

    /// Replaces the effect gate that decides which tool calls need approval.
    #[must_use]
    pub const fn with_effect_gate(mut self, gate: EffectGate) -> Self {
        self.gate = gate;
        self
    }
}

impl<S> Layer<S> for ApprovalLayer {
    type Service = ApprovalService<S>;

    /// Wraps `inner`, sharing this layer's approver and copying its gate.
    fn layer(&self, inner: S) -> Self::Service {
        ApprovalService {
            inner,
            approver: Arc::clone(&self.approver),
            gate: self.gate.clone(),
        }
    }
}

impl entelix_core::NamedLayer for ApprovalLayer {
    fn layer_name(&self) -> &'static str {
        Self::NAME
    }
}
/// Service produced by [`ApprovalLayer`]: gates each `ToolInvocation` on an approval
/// decision before forwarding it to the wrapped service.
///
/// `Clone` is derived, so it requires `S: Clone` — the same bound the hand-written
/// impl had and the `Service` impl already needs.
#[derive(Clone)]
pub struct ApprovalService<S> {
    /// Wrapped tool service that runs approved (or ungated) invocations.
    inner: S,
    /// Decision source consulted when no override exists and the gate requires approval.
    approver: Arc<dyn Approver>,
    /// Selects which tool effects are routed through `approver`.
    gate: EffectGate,
}
impl<S> Service<ToolInvocation> for ApprovalService<S>
where
S: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = Value;
type Error = Error;
type Future = BoxFuture<'static, Result<Value>>;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, invocation: ToolInvocation) -> Self::Future {
let approver = Arc::clone(&self.approver);
let gate = self.gate.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
let override_decision = invocation
.ctx
.extension::<PendingApprovalDecisions>()
.and_then(|o| o.get(&invocation.tool_use_id).cloned());
if override_decision.is_none() && !gate.requires_approval(invocation.metadata.effect) {
return inner.call(invocation).await;
}
let decision = if let Some(d) = override_decision {
d
} else {
let request = ApprovalRequest::new(
invocation.tool_use_id.clone(),
invocation.metadata.name.clone(),
invocation.input.clone(),
);
approver.decide(&request, &invocation.ctx).await?
};
let sink = invocation.ctx.extension::<ToolApprovalEventSinkHandle>();
let tenant_id = invocation.ctx.tenant_id().clone();
let run_id = invocation.ctx.run_id().unwrap_or("").to_owned();
let tool_use_id = invocation.tool_use_id.clone();
let tool_name = invocation.metadata.name.clone();
let input = invocation.input.clone();
match decision {
ApprovalDecision::Approve => {
if let Some(handle) = sink.as_deref() {
handle
.inner()
.record_approved(&tenant_id, &run_id, &tool_use_id, &tool_name)
.await;
}
inner.call(invocation).await
}
ApprovalDecision::Reject { reason } => {
if let Some(handle) = sink.as_deref() {
handle
.inner()
.record_denied(&tenant_id, &run_id, &tool_use_id, &tool_name, &reason)
.await;
}
Err(Error::invalid_request(format!(
"approver rejected tool '{tool_name}' dispatch: {reason}"
)))
}
ApprovalDecision::AwaitExternal => {
Err(Error::Interrupted {
kind: InterruptionKind::ApprovalPending {
tool_use_id: tool_use_id.clone(),
},
payload: json!({
"run_id": run_id,
"tool_use_id": tool_use_id,
"tool": tool_name,
"input": input,
}),
})
}
_ => Err(Error::config(format!(
"ApprovalLayer received an unsupported `ApprovalDecision` variant for tool '{tool_name}'; \
update the layer to handle the new variant"
))),
}
})
}
}
// Tests drive `ApprovalLayer` end-to-end through a `ToolRegistry` with a single `echo`
// tool. They cover approve / reject / await-external decisions, per-call decision
// overrides (`PendingApprovalDecisions`), the optional approval event sink, and
// composition with a `ScopedToolLayer`.
// `unwrap` and indexing are acceptable here: a panic is simply a test failure.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use entelix_core::AgentContext;
    use entelix_core::ExecutionContext;
    use entelix_core::tools::{Tool, ToolMetadata, ToolRegistry};
    use serde_json::json;
    use super::*;
    use crate::agent::approver::{AlwaysApprove, ApprovalDecision, ApprovalRequest};

    /// Tool fixture named `echo` that returns its input unchanged.
    ///
    /// Every registry below uses `ApprovalLayer::new`, i.e. the default
    /// `EffectGate::Always`, so these tests do not depend on the tool's declared effect.
    struct EchoTool {
        metadata: ToolMetadata,
    }
    impl EchoTool {
        fn new() -> Self {
            Self {
                metadata: ToolMetadata::function(
                    "echo",
                    "Echo input verbatim.",
                    json!({ "type": "object" }),
                ),
            }
        }
    }
    #[async_trait]
    impl Tool for EchoTool {
        fn metadata(&self) -> &ToolMetadata {
            &self.metadata
        }
        async fn execute(&self, input: Value, _ctx: &AgentContext<()>) -> Result<Value> {
            Ok(input)
        }
    }

    /// Approver fixture that rejects every request with a fixed `reason`.
    struct AlwaysReject {
        reason: String,
    }
    #[async_trait]
    impl Approver for AlwaysReject {
        async fn decide(
            &self,
            _request: &ApprovalRequest,
            _ctx: &ExecutionContext,
        ) -> Result<ApprovalDecision> {
            Ok(ApprovalDecision::Reject {
                reason: self.reason.clone(),
            })
        }
    }

    /// Event-sink fixture that counts approve / deny notifications.
    struct CountingApprovalSink {
        approved: Arc<AtomicUsize>,
        denied: Arc<AtomicUsize>,
    }
    #[async_trait]
    impl ToolApprovalEventSink for CountingApprovalSink {
        async fn record_approved(
            &self,
            _tenant_id: &TenantId,
            _run_id: &str,
            _tool_use_id: &str,
            _tool: &str,
        ) {
            self.approved.fetch_add(1, Ordering::SeqCst);
        }
        async fn record_denied(
            &self,
            _tenant_id: &TenantId,
            _run_id: &str,
            _tool_use_id: &str,
            _tool: &str,
            _reason: &str,
        ) {
            self.denied.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// An `Approve` decision lets the call through and the tool's output is returned.
    #[tokio::test]
    async fn approver_approve_dispatches_inner_tool() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysApprove);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let ctx = ExecutionContext::new();
        let result = registry
            .dispatch("", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap();
        assert_eq!(result, json!({"x": 1}));
    }

    /// A `Reject` decision surfaces as `Error::InvalidRequest` naming the tool and reason.
    #[tokio::test]
    async fn approver_reject_short_circuits_dispatch() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysReject {
            reason: "policy violation".to_owned(),
        });
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let ctx = ExecutionContext::new();
        let err = registry
            .dispatch("", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap_err();
        match err {
            Error::InvalidRequest(msg) => {
                assert!(msg.contains("approver rejected tool 'echo'"), "got: {msg}");
                assert!(msg.contains("policy violation"), "got: {msg}");
            }
            other => panic!("expected InvalidRequest, got {other:?}"),
        }
    }

    /// The sink attached to the context sees one approval, then one denial, when the
    /// same context is reused across an approving and a rejecting registry.
    #[tokio::test]
    async fn approval_sink_records_both_decisions() {
        let approved = Arc::new(AtomicUsize::new(0));
        let denied = Arc::new(AtomicUsize::new(0));
        let sink = CountingApprovalSink {
            approved: Arc::clone(&approved),
            denied: Arc::clone(&denied),
        };
        let handle = ToolApprovalEventSinkHandle::new(sink);
        let ctx = ExecutionContext::new().add_extension(handle);
        let approver_ok: Arc<dyn Approver> = Arc::new(AlwaysApprove);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver_ok))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        registry
            .dispatch("", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap();
        assert_eq!(approved.load(Ordering::SeqCst), 1);
        assert_eq!(denied.load(Ordering::SeqCst), 0);
        let approver_no: Arc<dyn Approver> = Arc::new(AlwaysReject {
            reason: "no".into(),
        });
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver_no))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        // The dispatch result is irrelevant here; only the sink counters are checked.
        let _ = registry.dispatch("", "echo", json!({"x": 1}), &ctx).await;
        assert_eq!(approved.load(Ordering::SeqCst), 1);
        assert_eq!(denied.load(Ordering::SeqCst), 1);
    }

    /// The event sink is optional: dispatch succeeds without one in the context.
    #[tokio::test]
    async fn approval_layer_runs_without_sink_attached() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysApprove);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let result = registry
            .dispatch("", "echo", json!({"x": 1}), &ExecutionContext::new())
            .await
            .unwrap();
        assert_eq!(result, json!({"x": 1}));
    }

    /// Approver fixture that always defers the decision (`AwaitExternal`).
    struct AlwaysAwait;
    #[async_trait]
    impl Approver for AlwaysAwait {
        async fn decide(
            &self,
            _request: &ApprovalRequest,
            _ctx: &ExecutionContext,
        ) -> Result<ApprovalDecision> {
            Ok(ApprovalDecision::AwaitExternal)
        }
    }

    /// `AwaitExternal` raises `Error::Interrupted` with `ApprovalPending` for the
    /// dispatched tool_use_id and a payload describing the pending call.
    #[tokio::test]
    async fn await_external_raises_interrupted_with_payload() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysAwait);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let err = registry
            .dispatch("tu-1", "echo", json!({"x": 1}), &ExecutionContext::new())
            .await
            .unwrap_err();
        match err {
            Error::Interrupted { kind, payload } => {
                assert_eq!(
                    kind,
                    InterruptionKind::ApprovalPending {
                        tool_use_id: "tu-1".into()
                    }
                );
                assert_eq!(payload["tool_use_id"].as_str(), Some("tu-1"));
                assert_eq!(payload["tool"].as_str(), Some("echo"));
                assert_eq!(payload["input"], json!({"x": 1}));
            }
            other => panic!("expected Interrupted, got {other:?}"),
        }
    }

    /// A pending `Approve` override for the dispatched id bypasses the (deferring)
    /// approver and runs the tool.
    #[tokio::test]
    async fn approval_decision_overrides_short_circuit_approver() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysAwait);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let overrides = {
            let mut p = PendingApprovalDecisions::new();
            p.insert("tu-1", ApprovalDecision::Approve);
            p
        };
        let ctx = ExecutionContext::new().add_extension(overrides);
        let result = registry
            .dispatch("tu-1", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap();
        assert_eq!(result, json!({"x": 1}));
    }

    /// A pending `Reject` override is applied with its own reason.
    #[tokio::test]
    async fn approval_decision_overrides_propagate_reject_decision() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysAwait);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let mut overrides = PendingApprovalDecisions::new();
        overrides.insert(
            "tu-1",
            ApprovalDecision::Reject {
                reason: "operator declined out-of-band".to_owned(),
            },
        );
        let ctx = ExecutionContext::new().add_extension(overrides);
        let err = registry
            .dispatch("tu-1", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap_err();
        match err {
            Error::InvalidRequest(msg) => {
                assert!(
                    msg.contains("operator declined out-of-band"),
                    "expected override reason, got: {msg}"
                );
            }
            other => panic!("expected InvalidRequest from override, got {other:?}"),
        }
    }

    /// Overrides keyed by a different tool_use_id are ignored; the approver decides.
    #[tokio::test]
    async fn approval_decision_overrides_only_apply_to_matching_tool_use_id() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysAwait);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let mut overrides = PendingApprovalDecisions::new();
        overrides.insert("a-different-id", ApprovalDecision::Approve);
        let ctx = ExecutionContext::new().add_extension(overrides);
        let err = registry
            .dispatch("tu-1", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Interrupted { .. }));
    }

    /// With an approving approver, a `ScopedToolLayer` registered alongside still wraps
    /// the dispatch exactly once.
    #[tokio::test]
    async fn approval_layer_composes_under_outer_layer() {
        use entelix_core::tools::{ScopedToolLayer, ToolDispatchScope};
        use futures::future::BoxFuture;
        /// Scope fixture that counts `wrap` calls and passes the future through.
        struct ApproveAfterScope {
            scope_wraps: Arc<AtomicUsize>,
        }
        impl ToolDispatchScope for ApproveAfterScope {
            fn wrap(
                &self,
                _ctx: ExecutionContext,
                fut: BoxFuture<'static, Result<Value>>,
            ) -> BoxFuture<'static, Result<Value>> {
                self.scope_wraps.fetch_add(1, Ordering::SeqCst);
                fut
            }
        }
        let scope_wraps = Arc::new(AtomicUsize::new(0));
        let scope = ApproveAfterScope {
            scope_wraps: Arc::clone(&scope_wraps),
        };
        let approver: Arc<dyn Approver> = Arc::new(AlwaysApprove);
        let registry = ToolRegistry::new()
            .layer(ScopedToolLayer::new(scope))
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        registry
            .dispatch("", "echo", json!({"x": 1}), &ExecutionContext::new())
            .await
            .unwrap();
        assert_eq!(scope_wraps.load(Ordering::SeqCst), 1);
    }

    /// A rejection must stop dispatch before the `ScopedToolLayer` wraps anything.
    ///
    /// NOTE(review): per the assertion message below, `ApprovalLayer` (registered
    /// second) is expected to sit outside `ScopedToolLayer`; confirm against
    /// `ToolRegistry::layer` ordering.
    #[tokio::test]
    async fn approval_reject_short_circuits_before_inner_scope() {
        use entelix_core::tools::{ScopedToolLayer, ToolDispatchScope};
        use futures::future::BoxFuture;
        /// Scope fixture that counts `wrap` calls and passes the future through.
        struct CountScope {
            wraps: Arc<AtomicUsize>,
        }
        impl ToolDispatchScope for CountScope {
            fn wrap(
                &self,
                _ctx: ExecutionContext,
                fut: BoxFuture<'static, Result<Value>>,
            ) -> BoxFuture<'static, Result<Value>> {
                self.wraps.fetch_add(1, Ordering::SeqCst);
                fut
            }
        }
        let wraps = Arc::new(AtomicUsize::new(0));
        let scope = CountScope {
            wraps: Arc::clone(&wraps),
        };
        let approver: Arc<dyn Approver> = Arc::new(AlwaysReject {
            reason: "no".into(),
        });
        let registry = ToolRegistry::new()
            .layer(ScopedToolLayer::new(scope))
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        // The rejection error itself is covered elsewhere; only the wrap count matters.
        let _ = registry
            .dispatch("", "echo", json!({"x": 1}), &ExecutionContext::new())
            .await;
        assert_eq!(
            wraps.load(Ordering::SeqCst),
            0,
            "scope wrap must not fire when the outer ApprovalLayer rejects"
        );
    }
}