1use std::sync::Arc;
63use std::task::{Context, Poll};
64
65use async_trait::async_trait;
66use futures::future::BoxFuture;
67use serde_json::{Value, json};
68use tower::{Layer, Service};
69
70use entelix_core::PendingApprovalDecisions;
71use entelix_core::TenantId;
72use entelix_core::error::{Error, Result};
73use entelix_core::interruption::InterruptionKind;
74use entelix_core::service::ToolInvocation;
75use entelix_core::tools::ToolEffect;
76
77use crate::agent::approver::{ApprovalDecision, ApprovalRequest, Approver};
78use crate::agent::event::AgentEvent;
79use crate::agent::sink::AgentEventSink;
80
81#[async_trait]
89pub trait ToolApprovalEventSink: Send + Sync + 'static {
90 async fn record_approved(
102 &self,
103 tenant_id: &TenantId,
104 run_id: &str,
105 tool_use_id: &str,
106 tool: &str,
107 );
108
109 async fn record_denied(
113 &self,
114 tenant_id: &TenantId,
115 run_id: &str,
116 tool_use_id: &str,
117 tool: &str,
118 reason: &str,
119 );
120}
121
122#[derive(Clone)]
128pub struct ToolApprovalEventSinkHandle {
129 sink: Arc<dyn ToolApprovalEventSink>,
130}
131
132impl ToolApprovalEventSinkHandle {
133 pub fn new<E>(sink: E) -> Self
137 where
138 E: ToolApprovalEventSink,
139 {
140 Self {
141 sink: Arc::new(sink),
142 }
143 }
144
145 pub fn for_agent_sink<S>(sink: Arc<dyn AgentEventSink<S>>) -> Self
151 where
152 S: Clone + Send + Sync + 'static,
153 {
154 Self {
155 sink: Arc::new(SinkAdapter { sink }),
156 }
157 }
158
159 pub fn inner(&self) -> &Arc<dyn ToolApprovalEventSink> {
163 &self.sink
164 }
165}
166
167struct SinkAdapter<S> {
168 sink: Arc<dyn AgentEventSink<S>>,
169}
170
171#[async_trait]
172impl<S> ToolApprovalEventSink for SinkAdapter<S>
173where
174 S: Clone + Send + Sync + 'static,
175{
176 async fn record_approved(
177 &self,
178 tenant_id: &TenantId,
179 run_id: &str,
180 tool_use_id: &str,
181 tool: &str,
182 ) {
183 let event: AgentEvent<S> = AgentEvent::ToolCallApproved {
184 run_id: run_id.to_owned(),
185 tenant_id: tenant_id.clone(),
186 tool_use_id: tool_use_id.to_owned(),
187 tool: tool.to_owned(),
188 };
189 let _ = self.sink.send(event).await;
190 }
191
192 async fn record_denied(
193 &self,
194 tenant_id: &TenantId,
195 run_id: &str,
196 tool_use_id: &str,
197 tool: &str,
198 reason: &str,
199 ) {
200 let event: AgentEvent<S> = AgentEvent::ToolCallDenied {
201 run_id: run_id.to_owned(),
202 tenant_id: tenant_id.clone(),
203 tool_use_id: tool_use_id.to_owned(),
204 tool: tool.to_owned(),
205 reason: reason.to_owned(),
206 };
207 let _ = self.sink.send(event).await;
208 }
209}
210
/// Selects which tool calls [`ApprovalLayer`] sends to its approver, based on
/// the tool's declared [`ToolEffect`].
///
/// Calls the gate does not select are dispatched without asking the approver,
/// unless a pending decision for that `tool_use_id` is already present in the
/// execution context; such a decision is always applied.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum EffectGate {
    /// Every tool call needs approval, whatever its effect. This is the default.
    #[default]
    Always,
    /// Only tools whose effect is [`ToolEffect::Destructive`] need approval.
    DestructiveOnly,
    /// Tools whose effect is [`ToolEffect::Mutating`] or
    /// [`ToolEffect::Destructive`] need approval.
    MutatingAndAbove,
}
236
237impl EffectGate {
238 #[must_use]
241 pub const fn requires_approval(self, effect: ToolEffect) -> bool {
242 match self {
243 Self::Always => true,
244 Self::DestructiveOnly => matches!(effect, ToolEffect::Destructive),
245 Self::MutatingAndAbove => {
246 matches!(effect, ToolEffect::Mutating | ToolEffect::Destructive)
247 }
248 }
249 }
250}
251
252pub struct ApprovalLayer {
257 approver: Arc<dyn Approver>,
258 gate: EffectGate,
259}
260
261impl ApprovalLayer {
262 pub const NAME: &'static str = "tool_approval";
267
268 pub fn new(approver: Arc<dyn Approver>) -> Self {
273 Self {
274 approver,
275 gate: EffectGate::default(),
276 }
277 }
278
279 #[must_use]
286 pub const fn with_effect_gate(mut self, gate: EffectGate) -> Self {
287 self.gate = gate;
288 self
289 }
290}
291
292impl Clone for ApprovalLayer {
293 fn clone(&self) -> Self {
294 Self {
295 approver: Arc::clone(&self.approver),
296 gate: self.gate.clone(),
297 }
298 }
299}
300
301impl<S> Layer<S> for ApprovalLayer {
302 type Service = ApprovalService<S>;
303
304 fn layer(&self, inner: S) -> Self::Service {
305 ApprovalService {
306 inner,
307 approver: Arc::clone(&self.approver),
308 gate: self.gate.clone(),
309 }
310 }
311}
312
313impl entelix_core::NamedLayer for ApprovalLayer {
314 fn layer_name(&self) -> &'static str {
315 Self::NAME
316 }
317}
318
319pub struct ApprovalService<S> {
323 inner: S,
324 approver: Arc<dyn Approver>,
325 gate: EffectGate,
326}
327
328impl<S: Clone> Clone for ApprovalService<S> {
329 fn clone(&self) -> Self {
330 Self {
331 inner: self.inner.clone(),
332 approver: Arc::clone(&self.approver),
333 gate: self.gate.clone(),
334 }
335 }
336}
337
338impl<S> Service<ToolInvocation> for ApprovalService<S>
339where
340 S: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
341 S::Future: Send + 'static,
342{
343 type Response = Value;
344 type Error = Error;
345 type Future = BoxFuture<'static, Result<Value>>;
346
347 #[inline]
348 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
349 self.inner.poll_ready(cx)
350 }
351
352 fn call(&mut self, invocation: ToolInvocation) -> Self::Future {
353 let approver = Arc::clone(&self.approver);
354 let gate = self.gate.clone();
355 let mut inner = self.inner.clone();
356 Box::pin(async move {
357 let override_decision = invocation
366 .ctx
367 .extension::<PendingApprovalDecisions>()
368 .and_then(|o| o.get(&invocation.tool_use_id).cloned());
369 if override_decision.is_none() && !gate.requires_approval(invocation.metadata.effect) {
370 return inner.call(invocation).await;
371 }
372 let decision = if let Some(d) = override_decision {
373 d
374 } else {
375 let request = ApprovalRequest::new(
376 invocation.tool_use_id.clone(),
377 invocation.metadata.name.clone(),
378 invocation.input.clone(),
379 );
380 approver.decide(&request, &invocation.ctx).await?
381 };
382
383 let sink = invocation.ctx.extension::<ToolApprovalEventSinkHandle>();
384 let tenant_id = invocation.ctx.tenant_id().clone();
385 let run_id = invocation.ctx.run_id().unwrap_or("").to_owned();
386 let tool_use_id = invocation.tool_use_id.clone();
387 let tool_name = invocation.metadata.name.clone();
388 let input = invocation.input.clone();
389
390 match decision {
391 ApprovalDecision::Approve => {
392 if let Some(handle) = sink.as_deref() {
393 handle
394 .inner()
395 .record_approved(&tenant_id, &run_id, &tool_use_id, &tool_name)
396 .await;
397 }
398 inner.call(invocation).await
399 }
400 ApprovalDecision::Reject { reason } => {
401 if let Some(handle) = sink.as_deref() {
402 handle
403 .inner()
404 .record_denied(&tenant_id, &run_id, &tool_use_id, &tool_name, &reason)
405 .await;
406 }
407 Err(Error::invalid_request(format!(
408 "approver rejected tool '{tool_name}' dispatch: {reason}"
409 )))
410 }
411 ApprovalDecision::AwaitExternal => {
412 Err(Error::Interrupted {
422 kind: InterruptionKind::ApprovalPending {
423 tool_use_id: tool_use_id.clone(),
424 },
425 payload: json!({
426 "run_id": run_id,
427 "tool_use_id": tool_use_id,
428 "tool": tool_name,
429 "input": input,
430 }),
431 })
432 }
433 _ => Err(Error::config(format!(
438 "ApprovalLayer received an unsupported `ApprovalDecision` variant for tool '{tool_name}'; \
439 update the layer to handle the new variant"
440 ))),
441 }
442 })
443 }
444}
445
#[cfg(test)]
// Tests unwrap freely and index into JSON payloads: a panic there is exactly
// the failure signal we want.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use entelix_core::AgentContext;
    use entelix_core::ExecutionContext;
    use entelix_core::tools::{Tool, ToolMetadata, ToolRegistry};
    use serde_json::json;

    use super::*;
    use crate::agent::approver::{AlwaysApprove, ApprovalDecision, ApprovalRequest};

    /// Tool that returns its input unchanged.
    struct EchoTool {
        metadata: ToolMetadata,
    }

    impl EchoTool {
        fn new() -> Self {
            Self {
                metadata: ToolMetadata::function(
                    "echo",
                    "Echo input verbatim.",
                    json!({ "type": "object" }),
                ),
            }
        }
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn metadata(&self) -> &ToolMetadata {
            &self.metadata
        }

        async fn execute(&self, input: Value, _ctx: &AgentContext<()>) -> Result<Value> {
            Ok(input)
        }
    }

    /// Approver that rejects every request with a fixed reason.
    struct AlwaysReject {
        reason: String,
    }

    #[async_trait]
    impl Approver for AlwaysReject {
        async fn decide(
            &self,
            _request: &ApprovalRequest,
            _ctx: &ExecutionContext,
        ) -> Result<ApprovalDecision> {
            Ok(ApprovalDecision::Reject {
                reason: self.reason.clone(),
            })
        }
    }

    /// Approver that always defers the decision to an external party.
    struct AlwaysAwait;

    #[async_trait]
    impl Approver for AlwaysAwait {
        async fn decide(
            &self,
            _request: &ApprovalRequest,
            _ctx: &ExecutionContext,
        ) -> Result<ApprovalDecision> {
            Ok(ApprovalDecision::AwaitExternal)
        }
    }

    /// Sink that counts approve / deny callbacks.
    struct CountingApprovalSink {
        approved: Arc<AtomicUsize>,
        denied: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl ToolApprovalEventSink for CountingApprovalSink {
        async fn record_approved(
            &self,
            _tenant_id: &TenantId,
            _run_id: &str,
            _tool_use_id: &str,
            _tool: &str,
        ) {
            self.approved.fetch_add(1, Ordering::SeqCst);
        }

        async fn record_denied(
            &self,
            _tenant_id: &TenantId,
            _run_id: &str,
            _tool_use_id: &str,
            _tool: &str,
            _reason: &str,
        ) {
            self.denied.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn effect_gate_defaults_to_always() {
        assert_eq!(EffectGate::default(), EffectGate::Always);
    }

    #[test]
    fn effect_gate_requires_approval_matrix() {
        assert!(EffectGate::Always.requires_approval(ToolEffect::Mutating));
        assert!(EffectGate::Always.requires_approval(ToolEffect::Destructive));
        assert!(!EffectGate::DestructiveOnly.requires_approval(ToolEffect::Mutating));
        assert!(EffectGate::DestructiveOnly.requires_approval(ToolEffect::Destructive));
        assert!(EffectGate::MutatingAndAbove.requires_approval(ToolEffect::Mutating));
        assert!(EffectGate::MutatingAndAbove.requires_approval(ToolEffect::Destructive));
    }

    #[tokio::test]
    async fn approver_approve_dispatches_inner_tool() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysApprove);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let ctx = ExecutionContext::new();
        let result = registry
            .dispatch("", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap();
        assert_eq!(result, json!({"x": 1}));
    }

    #[tokio::test]
    async fn approver_reject_short_circuits_dispatch() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysReject {
            reason: "policy violation".to_owned(),
        });
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let ctx = ExecutionContext::new();
        let err = registry
            .dispatch("", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap_err();
        match err {
            Error::InvalidRequest(msg) => {
                assert!(msg.contains("approver rejected tool 'echo'"), "got: {msg}");
                assert!(msg.contains("policy violation"), "got: {msg}");
            }
            other => panic!("expected InvalidRequest, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn approval_sink_records_both_decisions() {
        let approved = Arc::new(AtomicUsize::new(0));
        let denied = Arc::new(AtomicUsize::new(0));
        let sink = CountingApprovalSink {
            approved: Arc::clone(&approved),
            denied: Arc::clone(&denied),
        };
        let handle = ToolApprovalEventSinkHandle::new(sink);
        let ctx = ExecutionContext::new().add_extension(handle);

        let approver_ok: Arc<dyn Approver> = Arc::new(AlwaysApprove);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver_ok))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        registry
            .dispatch("", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap();
        assert_eq!(approved.load(Ordering::SeqCst), 1);
        assert_eq!(denied.load(Ordering::SeqCst), 0);

        let approver_no: Arc<dyn Approver> = Arc::new(AlwaysReject {
            reason: "no".into(),
        });
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver_no))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let _ = registry.dispatch("", "echo", json!({"x": 1}), &ctx).await;
        assert_eq!(approved.load(Ordering::SeqCst), 1);
        assert_eq!(denied.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn await_external_does_not_record_approval_events() {
        let approved = Arc::new(AtomicUsize::new(0));
        let denied = Arc::new(AtomicUsize::new(0));
        let handle = ToolApprovalEventSinkHandle::new(CountingApprovalSink {
            approved: Arc::clone(&approved),
            denied: Arc::clone(&denied),
        });
        let ctx = ExecutionContext::new().add_extension(handle);

        let approver: Arc<dyn Approver> = Arc::new(AlwaysAwait);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let err = registry
            .dispatch("tu-1", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Interrupted { .. }));
        assert_eq!(approved.load(Ordering::SeqCst), 0);
        assert_eq!(denied.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn approval_decision_override_is_reported_to_sink() {
        let approved = Arc::new(AtomicUsize::new(0));
        let denied = Arc::new(AtomicUsize::new(0));
        let handle = ToolApprovalEventSinkHandle::new(CountingApprovalSink {
            approved: Arc::clone(&approved),
            denied: Arc::clone(&denied),
        });
        let mut overrides = PendingApprovalDecisions::new();
        overrides.insert("tu-1", ApprovalDecision::Approve);
        let ctx = ExecutionContext::new()
            .add_extension(handle)
            .add_extension(overrides);

        let approver: Arc<dyn Approver> = Arc::new(AlwaysAwait);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        registry
            .dispatch("tu-1", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap();
        assert_eq!(approved.load(Ordering::SeqCst), 1);
        assert_eq!(denied.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn approval_layer_runs_without_sink_attached() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysApprove);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let result = registry
            .dispatch("", "echo", json!({"x": 1}), &ExecutionContext::new())
            .await
            .unwrap();
        assert_eq!(result, json!({"x": 1}));
    }

    #[tokio::test]
    async fn await_external_raises_interrupted_with_payload() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysAwait);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let err = registry
            .dispatch("tu-1", "echo", json!({"x": 1}), &ExecutionContext::new())
            .await
            .unwrap_err();
        match err {
            Error::Interrupted { kind, payload } => {
                assert_eq!(
                    kind,
                    InterruptionKind::ApprovalPending {
                        tool_use_id: "tu-1".into()
                    }
                );
                assert_eq!(payload["tool_use_id"].as_str(), Some("tu-1"));
                assert_eq!(payload["tool"].as_str(), Some("echo"));
                assert_eq!(payload["input"], json!({"x": 1}));
            }
            other => panic!("expected Interrupted, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn approval_decision_overrides_short_circuit_approver() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysAwait);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let overrides = {
            let mut p = PendingApprovalDecisions::new();
            p.insert("tu-1", ApprovalDecision::Approve);
            p
        };
        let ctx = ExecutionContext::new().add_extension(overrides);

        let result = registry
            .dispatch("tu-1", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap();
        assert_eq!(result, json!({"x": 1}));
    }

    #[tokio::test]
    async fn approval_decision_overrides_propagate_reject_decision() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysAwait);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let mut overrides = PendingApprovalDecisions::new();
        overrides.insert(
            "tu-1",
            ApprovalDecision::Reject {
                reason: "operator declined out-of-band".to_owned(),
            },
        );
        let ctx = ExecutionContext::new().add_extension(overrides);

        let err = registry
            .dispatch("tu-1", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap_err();
        match err {
            Error::InvalidRequest(msg) => {
                assert!(
                    msg.contains("operator declined out-of-band"),
                    "expected override reason, got: {msg}"
                );
            }
            other => panic!("expected InvalidRequest from override, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn approval_decision_overrides_only_apply_to_matching_tool_use_id() {
        let approver: Arc<dyn Approver> = Arc::new(AlwaysAwait);
        let registry = ToolRegistry::new()
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();
        let mut overrides = PendingApprovalDecisions::new();
        overrides.insert("a-different-id", ApprovalDecision::Approve);
        let ctx = ExecutionContext::new().add_extension(overrides);

        let err = registry
            .dispatch("tu-1", "echo", json!({"x": 1}), &ctx)
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Interrupted { .. }));
    }

    #[tokio::test]
    async fn approval_layer_composes_under_outer_layer() {
        use entelix_core::tools::{ScopedToolLayer, ToolDispatchScope};
        use futures::future::BoxFuture;

        /// Scope that only counts how often it wraps a dispatch future.
        struct ApproveAfterScope {
            scope_wraps: Arc<AtomicUsize>,
        }

        impl ToolDispatchScope for ApproveAfterScope {
            fn wrap(
                &self,
                _ctx: ExecutionContext,
                fut: BoxFuture<'static, Result<Value>>,
            ) -> BoxFuture<'static, Result<Value>> {
                self.scope_wraps.fetch_add(1, Ordering::SeqCst);
                fut
            }
        }

        let scope_wraps = Arc::new(AtomicUsize::new(0));
        let scope = ApproveAfterScope {
            scope_wraps: Arc::clone(&scope_wraps),
        };
        let approver: Arc<dyn Approver> = Arc::new(AlwaysApprove);
        let registry = ToolRegistry::new()
            .layer(ScopedToolLayer::new(scope))
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();

        registry
            .dispatch("", "echo", json!({"x": 1}), &ExecutionContext::new())
            .await
            .unwrap();
        assert_eq!(scope_wraps.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn approval_reject_short_circuits_before_inner_scope() {
        use entelix_core::tools::{ScopedToolLayer, ToolDispatchScope};
        use futures::future::BoxFuture;

        /// Scope that only counts how often it wraps a dispatch future.
        struct CountScope {
            wraps: Arc<AtomicUsize>,
        }

        impl ToolDispatchScope for CountScope {
            fn wrap(
                &self,
                _ctx: ExecutionContext,
                fut: BoxFuture<'static, Result<Value>>,
            ) -> BoxFuture<'static, Result<Value>> {
                self.wraps.fetch_add(1, Ordering::SeqCst);
                fut
            }
        }

        let wraps = Arc::new(AtomicUsize::new(0));
        let scope = CountScope {
            wraps: Arc::clone(&wraps),
        };
        let approver: Arc<dyn Approver> = Arc::new(AlwaysReject {
            reason: "no".into(),
        });
        let registry = ToolRegistry::new()
            .layer(ScopedToolLayer::new(scope))
            .layer(ApprovalLayer::new(approver))
            .register(Arc::new(EchoTool::new()))
            .unwrap();

        let _ = registry
            .dispatch("", "echo", json!({"x": 1}), &ExecutionContext::new())
            .await;
        assert_eq!(
            wraps.load(Ordering::SeqCst),
            0,
            "scope wrap must not fire when the outer ApprovalLayer rejects"
        );
    }
}