Skip to main content

llm_stack/
intercept.rs

1// Interceptor methods are chainable builders, not pure functions
2#![allow(clippy::must_use_candidate)]
3// Explicit casts for duration conversions are clearer than try_into
4#![allow(clippy::cast_possible_truncation)]
5#![allow(clippy::cast_sign_loss)]
6#![allow(clippy::cast_precision_loss)]
7#![allow(clippy::cast_possible_wrap)]
8// Lifetime names are kept explicit for clarity in trait impls
9#![allow(clippy::needless_lifetimes)]
10
11//! Unified interceptor system for LLM calls and tool executions.
12//!
13//! This module provides a composable middleware-like system that works across
14//! different domains. The core abstraction is the [`Interceptor`] trait, which
15//! wraps operations and can inspect, modify, or short-circuit them.
16//!
17//! # Architecture
18//!
19//! ```text
20//! InterceptorStack::new()
21//!     .with(Logging::default())      // outermost: sees request first
22//!     .with(Retry::default())        // retries wrap everything inside
23//!     .with(Timeout::new(30s))       // timeout on inner operation
24//!     .execute(&input, operation)
25//! ```
26//!
27//! # Domains
28//!
29//! The system supports multiple domains via marker types:
30//!
31//! - [`ToolExec<Ctx>`] - Tool executions (integrated with [`ToolRegistry`](crate::ToolRegistry))
32//! - [`LlmCall`] - LLM provider requests (reserved for future use, not yet integrated)
33//!
34//! # Generic Interceptors
35//!
36//! Interceptors like [`Retry`], [`Timeout`], and [`Logging`] are generic over
37//! any domain that implements the required behavior traits ([`Retryable`],
38//! [`Timeoutable`], [`Loggable`]).
39//!
40//! # Example
41//!
42//! ```rust,ignore
43//! use llm_stack::ToolRegistry;
44//! use llm_stack::intercept::{InterceptorStack, Retry, Timeout, ToolExec};
45//! use std::time::Duration;
46//!
47//! let registry: ToolRegistry<()> = ToolRegistry::new()
48//!     .with_interceptors(
49//!         InterceptorStack::<ToolExec<()>>::new()
50//!             .with(Retry::default())
51//!             .with(Timeout::new(Duration::from_secs(30)))
52//!     );
53//! ```
54
55use std::future::Future;
56use std::marker::PhantomData;
57use std::pin::Pin;
58use std::sync::Arc;
59
60// Re-export core types at module level for convenience
61pub use behavior::{Loggable, Outcome, Retryable, Timeoutable};
62pub use domain::{LlmCall, ToolExec, ToolRequest, ToolResponse};
63
64// Note: Interceptable, Interceptor, InterceptorStack, Next, Operation, FnOperation
65// are defined at module root and don't need re-exporting
66
67/// An operation that can be intercepted.
68///
69/// This trait defines the input and output types for an interceptable operation.
70/// Implement this for marker types that represent different domains (e.g., LLM calls,
71/// tool executions).
72pub trait Interceptable: Send + Sync + 'static {
73    /// Input to the operation.
74    type Input: Send;
75
76    /// Output from the operation.
77    type Output: Send;
78}
79
80/// Wraps an interceptable operation.
81///
82/// Interceptors form a chain. Each interceptor receives the input and a [`Next`]
83/// handle. It can:
84/// - Pass through: call `next.run(input).await`
85/// - Modify input: transform input, then call `next.run(&modified).await`
86/// - Short-circuit: return early without calling `next`
87/// - Retry: call `next.clone().run(input).await` multiple times
88/// - Wrap output: call `next`, then transform the result
89///
90/// # Implementing
91///
92/// ```rust,ignore
93/// use llm_stack::intercept::{Interceptor, Interceptable, Next};
94/// use std::future::Future;
95/// use std::pin::Pin;
96///
97/// struct MyInterceptor;
98///
99/// impl<T: Interceptable> Interceptor<T> for MyInterceptor
100/// where
101///     T::Input: Sync,
102/// {
103///     fn intercept<'a>(
104///         &'a self,
105///         input: &'a T::Input,
106///         next: Next<'a, T>,
107///     ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
108///         Box::pin(async move {
109///             // Do something before
110///             let result = next.run(input).await;
111///             // Do something after
112///             result
113///         })
114///     }
115/// }
116/// ```
117pub trait Interceptor<T: Interceptable>: Send + Sync {
118    /// Intercept the operation.
119    ///
120    /// Call `next.run(input)` to continue the chain, or return early to short-circuit.
121    fn intercept<'a>(
122        &'a self,
123        input: &'a T::Input,
124        next: Next<'a, T>,
125    ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>>;
126}
127
128/// Handle to invoke the next interceptor in the chain (or the final operation).
129///
130/// `Next` is [`Clone`], which is essential for retry interceptors that need to
131/// call the chain multiple times. Cloning is cheap - it only clones references.
132pub struct Next<'a, T: Interceptable> {
133    interceptors: &'a [Arc<dyn Interceptor<T>>],
134    operation: &'a dyn Operation<T>,
135}
136
137impl<T: Interceptable> Clone for Next<'_, T> {
138    fn clone(&self) -> Self {
139        *self
140    }
141}
142
143// Copy is valid: both fields are references
144impl<T: Interceptable> Copy for Next<'_, T> {}
145
146impl<T: Interceptable> Next<'_, T>
147where
148    T::Input: Sync,
149{
150    /// Run the operation through the remaining chain.
151    ///
152    /// This consumes `self`, but since `Next` is `Copy`, you can call it multiple
153    /// times by copying first (e.g., for retry logic).
154    pub async fn run(self, input: &T::Input) -> T::Output {
155        if let Some((first, rest)) = self.interceptors.split_first() {
156            let next = Next {
157                interceptors: rest,
158                operation: self.operation,
159            };
160            first.intercept(input, next).await
161        } else {
162            self.operation.execute(input).await
163        }
164    }
165}
166
167/// The final operation to execute after all interceptors.
168///
169/// This is object-safe to allow storing different operation types.
170pub trait Operation<T: Interceptable>: Send + Sync {
171    /// Execute the operation.
172    fn execute<'a>(
173        &'a self,
174        input: &'a T::Input,
175    ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>>
176    where
177        T::Input: Sync;
178}
179
180/// Wrap a closure as an [`Operation`].
181pub struct FnOperation<T, F>
182where
183    T: Interceptable,
184    F: Fn(&T::Input) -> Pin<Box<dyn Future<Output = T::Output> + Send + '_>> + Send + Sync,
185{
186    f: F,
187    _marker: PhantomData<T>,
188}
189
190impl<T, F> FnOperation<T, F>
191where
192    T: Interceptable,
193    F: Fn(&T::Input) -> Pin<Box<dyn Future<Output = T::Output> + Send + '_>> + Send + Sync,
194{
195    /// Create a new operation from a closure.
196    pub fn new(f: F) -> Self {
197        Self {
198            f,
199            _marker: PhantomData,
200        }
201    }
202}
203
204impl<T, F> Operation<T> for FnOperation<T, F>
205where
206    T: Interceptable,
207    F: Fn(&T::Input) -> Pin<Box<dyn Future<Output = T::Output> + Send + '_>> + Send + Sync,
208{
209    fn execute<'a>(
210        &'a self,
211        input: &'a T::Input,
212    ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>>
213    where
214        T::Input: Sync,
215    {
216        (self.f)(input)
217    }
218}
219
220/// A composable stack of interceptors.
221///
222/// Interceptors are executed in the order they are added:
223/// - First added = outermost = sees request first, sees response last
224/// - Last added = innermost = sees request last, sees response first
225///
226/// # Example
227///
228/// ```rust,ignore
229/// use llm_stack::intercept::{InterceptorStack, Retry, ToolExec};
230///
231/// let stack = InterceptorStack::<ToolExec<()>>::new()
232///     .with(Retry::default());
233/// ```
234pub struct InterceptorStack<T: Interceptable> {
235    layers: Vec<Arc<dyn Interceptor<T>>>,
236}
237
238impl<T: Interceptable> Clone for InterceptorStack<T> {
239    fn clone(&self) -> Self {
240        Self {
241            layers: self.layers.clone(),
242        }
243    }
244}
245
246impl<T: Interceptable> InterceptorStack<T> {
247    /// Create an empty interceptor stack.
248    pub fn new() -> Self {
249        Self { layers: Vec::new() }
250    }
251
252    /// Add an interceptor to the stack.
253    ///
254    /// Interceptors are executed in the order added (first = outermost).
255    #[must_use]
256    pub fn with<I: Interceptor<T> + 'static>(mut self, interceptor: I) -> Self {
257        self.layers.push(Arc::new(interceptor));
258        self
259    }
260
261    /// Add a shared interceptor instance.
262    ///
263    /// Useful when the same interceptor instance needs to be used across
264    /// multiple stacks (e.g., for shared metrics collection).
265    #[must_use]
266    pub fn with_shared(mut self, interceptor: Arc<dyn Interceptor<T>>) -> Self {
267        self.layers.push(interceptor);
268        self
269    }
270
271    /// Check if the stack has any interceptors.
272    pub fn is_empty(&self) -> bool {
273        self.layers.is_empty()
274    }
275
276    /// Get the number of interceptors in the stack.
277    pub fn len(&self) -> usize {
278        self.layers.len()
279    }
280
281    /// Execute an operation through the interceptor stack.
282    pub async fn execute<'a, O>(&'a self, input: &'a T::Input, operation: &'a O) -> T::Output
283    where
284        T::Input: Sync,
285        O: Operation<T>,
286    {
287        let next = Next {
288            interceptors: &self.layers,
289            operation,
290        };
291        next.run(input).await
292    }
293
294    /// Execute with a closure as the operation.
295    pub async fn execute_fn<'a, F>(&'a self, input: &'a T::Input, f: F) -> T::Output
296    where
297        T::Input: Sync,
298        F: Fn(&T::Input) -> Pin<Box<dyn Future<Output = T::Output> + Send + '_>> + Send + Sync,
299    {
300        let op = FnOperation::<T, F>::new(f);
301        self.execute(input, &op).await
302    }
303}
304
305impl<T: Interceptable> Default for InterceptorStack<T> {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
311// ============================================================================
312// Domain markers
313// ============================================================================
314
315/// Domain-specific marker types and their Interceptable implementations.
316pub mod domain {
317    use super::Interceptable;
318    use crate::ChatResponse;
319    use crate::error::LlmError;
320    use crate::provider::ChatParams;
321    use serde_json::Value;
322    use std::marker::PhantomData;
323
324    /// Marker for LLM provider calls.
325    ///
326    /// - Input: [`ChatParams`]
327    /// - Output: `Result<ChatResponse, LlmError>`
328    ///
329    /// # Status: Reserved
330    ///
331    /// This domain marker is defined for future use but **not yet integrated**.
332    /// Unlike [`ToolExec`] (which is wired into [`ToolRegistry`](crate::ToolRegistry)),
333    /// there is currently no `Provider` wrapper that executes through an
334    /// `InterceptorStack<LlmCall>`.
335    ///
336    /// The marker exists so that:
337    /// 1. Generic interceptors (`Retry`, `Timeout`, `Logging`) already work with it
338    /// 2. Future provider-level interception can be added without breaking changes
339    ///
340    /// To use `LlmCall` today, you would need to build your own wrapper that
341    /// implements `Provider` and delegates through an `InterceptorStack<LlmCall>`.
342    pub struct LlmCall;
343
344    impl Interceptable for LlmCall {
345        type Input = ChatParams;
346        type Output = Result<ChatResponse, LlmError>;
347    }
348
349    /// Marker for tool executions.
350    ///
351    /// The `Ctx` type parameter matches the context type used by `ToolRegistry<Ctx>`.
352    ///
353    /// - Input: [`ToolRequest`]
354    /// - Output: [`ToolResponse`]
355    pub struct ToolExec<Ctx = ()>(PhantomData<fn() -> Ctx>);
356
357    impl<Ctx: Send + Sync + 'static> Interceptable for ToolExec<Ctx> {
358        type Input = ToolRequest;
359        type Output = ToolResponse;
360    }
361
362    /// Input for tool execution.
363    #[derive(Debug, Clone)]
364    pub struct ToolRequest {
365        /// Name of the tool being called.
366        pub name: String,
367
368        /// Unique ID of this tool call.
369        pub call_id: String,
370
371        /// Arguments passed to the tool (JSON).
372        pub arguments: Value,
373    }
374
375    /// Output from tool execution.
376    #[derive(Debug, Clone)]
377    pub struct ToolResponse {
378        /// The tool's output content.
379        pub content: String,
380
381        /// Whether the execution resulted in an error.
382        pub is_error: bool,
383    }
384
385    impl ToolResponse {
386        /// Create a successful response.
387        pub fn success(content: impl Into<String>) -> Self {
388            Self {
389                content: content.into(),
390                is_error: false,
391            }
392        }
393
394        /// Create an error response.
395        pub fn error(content: impl Into<String>) -> Self {
396            Self {
397                content: content.into(),
398                is_error: true,
399            }
400        }
401    }
402}
403
404// ============================================================================
405// Behavior traits
406// ============================================================================
407
408/// Behavior traits that interceptors can require.
409pub mod behavior {
410    use crate::ChatResponse;
411    use crate::error::LlmError;
412    use crate::provider::ChatParams;
413    use std::time::Duration;
414
415    use super::domain::{ToolRequest, ToolResponse};
416
417    /// Output that can indicate whether retry is appropriate.
418    pub trait Retryable {
419        /// Returns true if the operation should be retried.
420        fn should_retry(&self) -> bool;
421    }
422
423    impl Retryable for Result<ChatResponse, LlmError> {
424        fn should_retry(&self) -> bool {
425            match self {
426                Ok(_) => false,
427                Err(e) => e.is_retryable(),
428            }
429        }
430    }
431
432    impl Retryable for ToolResponse {
433        fn should_retry(&self) -> bool {
434            // By default, tool errors are not retried
435            // Users can implement custom retry logic via interceptors
436            false
437        }
438    }
439
440    /// Output that can represent a timeout.
441    pub trait Timeoutable: Sized {
442        /// Create a timeout error.
443        fn timeout_error(duration: Duration) -> Self;
444    }
445
446    impl Timeoutable for Result<ChatResponse, LlmError> {
447        fn timeout_error(duration: Duration) -> Self {
448            Err(LlmError::Timeout {
449                elapsed_ms: duration.as_millis() as u64,
450            })
451        }
452    }
453
454    impl Timeoutable for ToolResponse {
455        fn timeout_error(duration: Duration) -> Self {
456            ToolResponse {
457                content: format!("Tool execution timed out after {duration:?}"),
458                is_error: true,
459            }
460        }
461    }
462
463    /// Input that can describe itself for logging.
464    pub trait Loggable {
465        /// Return a description of the operation for logging.
466        fn log_description(&self) -> String;
467    }
468
469    impl Loggable for ChatParams {
470        fn log_description(&self) -> String {
471            let tool_count = self.tools.as_ref().map_or(0, Vec::len);
472            format!(
473                "LLM request: {} messages, {} tools",
474                self.messages.len(),
475                tool_count
476            )
477        }
478    }
479
480    impl Loggable for ToolRequest {
481        fn log_description(&self) -> String {
482            format!("Tool call: {} ({})", self.name, self.call_id)
483        }
484    }
485
486    /// Output that can report success/failure for logging.
487    ///
488    /// Separate from `Retryable` because logging success != retry decision.
489    /// A successful response might still be retryable (e.g., partial results),
490    /// and a failed response might not be retryable (e.g., auth error).
491    pub trait Outcome {
492        /// Returns true if the operation succeeded.
493        fn is_success(&self) -> bool;
494    }
495
496    impl Outcome for Result<ChatResponse, LlmError> {
497        fn is_success(&self) -> bool {
498            self.is_ok()
499        }
500    }
501
502    impl Outcome for ToolResponse {
503        fn is_success(&self) -> bool {
504            !self.is_error
505        }
506    }
507}
508
509// ============================================================================
510// Built-in interceptors
511// ============================================================================
512
513/// Built-in interceptors for common cross-cutting concerns.
514pub mod interceptors {
515    #[cfg(feature = "tracing")]
516    use super::behavior::{Loggable, Outcome};
517    use super::behavior::{Retryable, Timeoutable};
518    use super::{Interceptable, Interceptor, Next};
519    use std::future::Future;
520    use std::pin::Pin;
521    use std::time::Duration;
522
523    /// Retry interceptor with exponential backoff.
524    ///
525    /// Retries the operation when the output indicates failure via [`Retryable`].
526    ///
527    /// # Example
528    ///
529    /// ```rust,ignore
530    /// use llm_stack::intercept::{InterceptorStack, Retry, ToolExec};
531    /// use std::time::Duration;
532    ///
533    /// let stack = InterceptorStack::<ToolExec<()>>::new()
534    ///     .with(Retry::new(3, Duration::from_millis(100)));
535    /// ```
536    #[derive(Debug, Clone)]
537    pub struct Retry {
538        /// Maximum number of attempts (including the first).
539        pub max_attempts: u32,
540
541        /// Initial delay before first retry.
542        pub initial_delay: Duration,
543
544        /// Maximum delay between retries.
545        pub max_delay: Duration,
546
547        /// Multiplier for exponential backoff.
548        pub multiplier: f64,
549    }
550
551    impl Default for Retry {
552        fn default() -> Self {
553            Self {
554                max_attempts: 3,
555                initial_delay: Duration::from_millis(500),
556                max_delay: Duration::from_secs(30),
557                multiplier: 2.0,
558            }
559        }
560    }
561
562    impl Retry {
563        /// Create a retry interceptor with the given attempts and initial delay.
564        pub fn new(max_attempts: u32, initial_delay: Duration) -> Self {
565            Self {
566                max_attempts,
567                initial_delay,
568                ..Default::default()
569            }
570        }
571
572        fn delay_for_attempt(&self, attempt: u32) -> Duration {
573            let delay_ms = self.initial_delay.as_millis() as f64
574                * self.multiplier.powi(attempt.saturating_sub(1) as i32);
575            let delay = Duration::from_millis(delay_ms as u64);
576            std::cmp::min(delay, self.max_delay)
577        }
578    }
579
580    impl<T> Interceptor<T> for Retry
581    where
582        T: Interceptable,
583        T::Input: Sync,
584        T::Output: Retryable,
585    {
586        fn intercept<'a>(
587            &'a self,
588            input: &'a T::Input,
589            next: Next<'a, T>,
590        ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
591            Box::pin(async move {
592                let mut last_result: Option<T::Output> = None;
593
594                for attempt in 1..=self.max_attempts {
595                    let result = next.run(input).await;
596
597                    if !result.should_retry() || attempt == self.max_attempts {
598                        return result;
599                    }
600
601                    // Sleep before retry
602                    let delay = self.delay_for_attempt(attempt);
603                    tokio::time::sleep(delay).await;
604
605                    last_result = Some(result);
606                }
607
608                // Should not reach here, but return last result if we do
609                last_result.expect("at least one attempt should have been made")
610            })
611        }
612    }
613
614    /// Timeout interceptor.
615    ///
616    /// Wraps the operation with a timeout. If the timeout expires, returns
617    /// an error via [`Timeoutable`].
618    ///
619    /// # Example
620    ///
621    /// ```rust,ignore
622    /// use llm_stack::intercept::{InterceptorStack, Timeout, ToolExec};
623    /// use std::time::Duration;
624    ///
625    /// let stack = InterceptorStack::<ToolExec<()>>::new()
626    ///     .with(Timeout::new(Duration::from_secs(30)));
627    /// ```
628    #[derive(Debug, Clone)]
629    pub struct Timeout {
630        /// Maximum duration for the operation.
631        pub duration: Duration,
632    }
633
634    impl Timeout {
635        /// Create a timeout interceptor with the given duration.
636        pub fn new(duration: Duration) -> Self {
637            Self { duration }
638        }
639    }
640
641    impl<T> Interceptor<T> for Timeout
642    where
643        T: Interceptable,
644        T::Input: Sync,
645        T::Output: Timeoutable,
646    {
647        fn intercept<'a>(
648            &'a self,
649            input: &'a T::Input,
650            next: Next<'a, T>,
651        ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
652            let duration = self.duration;
653            Box::pin(async move {
654                match tokio::time::timeout(duration, next.run(input)).await {
655                    Ok(result) => result,
656                    Err(_) => T::Output::timeout_error(duration),
657                }
658            })
659        }
660    }
661
662    /// Pass-through interceptor that does nothing.
663    ///
664    /// Useful for testing and as a placeholder.
665    #[derive(Debug, Clone, Default)]
666    pub struct NoOp;
667
668    impl<T> Interceptor<T> for NoOp
669    where
670        T: Interceptable,
671        T::Input: Sync,
672    {
673        fn intercept<'a>(
674            &'a self,
675            input: &'a T::Input,
676            next: Next<'a, T>,
677        ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
678            Box::pin(next.run(input))
679        }
680    }
681
682    /// Logging interceptor using tracing.
683    ///
684    /// Logs operation start/completion with configurable verbosity.
685    /// Requires the `tracing` feature.
686    ///
687    /// # Example
688    ///
689    /// ```rust,ignore
690    /// use llm_stack::intercept::{InterceptorStack, Logging, LogLevel, ToolExec};
691    ///
692    /// let stack = InterceptorStack::<ToolExec<()>>::new()
693    ///     .with(Logging::new(LogLevel::Debug));
694    /// ```
695    #[cfg(feature = "tracing")]
696    #[derive(Debug, Clone)]
697    pub struct Logging {
698        /// Verbosity level for log output.
699        pub level: LogLevel,
700    }
701
702    /// Verbosity level for logging interceptor.
703    #[cfg(feature = "tracing")]
704    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
705    pub enum LogLevel {
706        /// Basic logging: operation type and duration.
707        #[default]
708        Info,
709        /// Verbose logging: includes success/failure status.
710        Debug,
711        /// Trace logging: includes input descriptions.
712        Trace,
713    }
714
715    #[cfg(feature = "tracing")]
716    impl Default for Logging {
717        fn default() -> Self {
718            Self {
719                level: LogLevel::Info,
720            }
721        }
722    }
723
724    #[cfg(feature = "tracing")]
725    impl Logging {
726        /// Create a logging interceptor with the given level.
727        pub fn new(level: LogLevel) -> Self {
728            Self { level }
729        }
730    }
731
732    #[cfg(feature = "tracing")]
733    impl<T> Interceptor<T> for Logging
734    where
735        T: Interceptable,
736        T::Input: Sync + Loggable,
737        T::Output: Outcome,
738    {
739        fn intercept<'a>(
740            &'a self,
741            input: &'a T::Input,
742            next: Next<'a, T>,
743        ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
744            let description = input.log_description();
745            let level = self.level;
746
747            Box::pin(async move {
748                let start = std::time::Instant::now();
749
750                if level == LogLevel::Trace {
751                    tracing::debug!(description = %description, "operation starting");
752                }
753
754                let result = next.run(input).await;
755                let duration = start.elapsed();
756                let success = result.is_success();
757
758                match level {
759                    LogLevel::Info => {
760                        tracing::info!(
761                            duration_ms = duration.as_millis() as u64,
762                            "operation completed"
763                        );
764                    }
765                    LogLevel::Debug | LogLevel::Trace => {
766                        tracing::debug!(
767                            duration_ms = duration.as_millis() as u64,
768                            success,
769                            "operation completed"
770                        );
771                    }
772                }
773
774                result
775            })
776        }
777    }
778}
779
780// Re-export interceptors at module level
781#[cfg(feature = "tracing")]
782pub use interceptors::{LogLevel, Logging};
783pub use interceptors::{NoOp, Retry, Timeout};
784
785// ============================================================================
786// Domain-specific interceptors
787// ============================================================================
788
789/// Tool-specific interceptors.
790pub mod tool_interceptors {
791    use super::{
792        Interceptor, Next,
793        domain::{ToolExec, ToolRequest, ToolResponse},
794    };
795    use serde_json::Value;
796    use std::future::Future;
797    use std::pin::Pin;
798
799    /// Decision returned by an approval check function.
800    #[derive(Debug, Clone)]
801    pub enum ApprovalDecision {
802        /// Allow the tool call to proceed.
803        Allow,
804        /// Deny the tool call with an error message.
805        Deny(String),
806        /// Modify the tool call arguments before proceeding.
807        Modify(Value),
808    }
809
810    /// Approval gate interceptor for tool calls.
811    ///
812    /// Runs a check function before each tool execution. The function can:
813    /// - Allow the call to proceed unchanged
814    /// - Deny the call with an error message
815    /// - Modify the arguments before proceeding
816    ///
817    /// # Example
818    ///
819    /// ```rust,ignore
820    /// use llm_stack::intercept::{InterceptorStack, ToolExec, Approval, ApprovalDecision};
821    ///
822    /// let stack = InterceptorStack::<ToolExec<()>>::new()
823    ///     .with(Approval::new(|req| {
824    ///         if req.name == "delete_file" {
825    ///             ApprovalDecision::Deny("Destructive operations not allowed".into())
826    ///         } else {
827    ///             ApprovalDecision::Allow
828    ///         }
829    ///     }));
830    /// ```
831    pub struct Approval<F> {
832        check: F,
833    }
834
835    impl<F> Approval<F>
836    where
837        F: Fn(&ToolRequest) -> ApprovalDecision + Send + Sync,
838    {
839        /// Create an approval interceptor with the given check function.
840        pub fn new(check: F) -> Self {
841            Self { check }
842        }
843    }
844
845    impl<F> std::fmt::Debug for Approval<F> {
846        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
847            f.debug_struct("Approval").finish_non_exhaustive()
848        }
849    }
850
851    impl<Ctx, F> Interceptor<ToolExec<Ctx>> for Approval<F>
852    where
853        Ctx: Send + Sync + 'static,
854        F: Fn(&ToolRequest) -> ApprovalDecision + Send + Sync,
855    {
856        fn intercept<'a>(
857            &'a self,
858            input: &'a ToolRequest,
859            next: Next<'a, ToolExec<Ctx>>,
860        ) -> Pin<Box<dyn Future<Output = ToolResponse> + Send + 'a>> {
861            Box::pin(async move {
862                match (self.check)(input) {
863                    ApprovalDecision::Allow => next.run(input).await,
864                    ApprovalDecision::Deny(reason) => ToolResponse {
865                        content: reason,
866                        is_error: true,
867                    },
868                    ApprovalDecision::Modify(new_args) => {
869                        let modified = ToolRequest {
870                            name: input.name.clone(),
871                            call_id: input.call_id.clone(),
872                            arguments: new_args,
873                        };
874                        next.run(&modified).await
875                    }
876                }
877            })
878        }
879    }
880}
881
882pub use tool_interceptors::{Approval, ApprovalDecision};
883
884#[cfg(test)]
885mod tests {
886    use super::*;
887    use std::sync::atomic::{AtomicU32, Ordering};
888    use std::time::Duration;
889
890    // Simple test domain
891    struct TestOp;
892
893    impl Interceptable for TestOp {
894        type Input = String;
895        type Output = Result<String, String>;
896    }
897
898    impl behavior::Retryable for Result<String, String> {
899        fn should_retry(&self) -> bool {
900            self.is_err()
901        }
902    }
903
904    impl behavior::Timeoutable for Result<String, String> {
905        fn timeout_error(duration: Duration) -> Self {
906            Err(format!("timeout after {duration:?}"))
907        }
908    }
909
910    struct EchoOp;
911
912    impl Operation<TestOp> for EchoOp {
913        fn execute<'a>(
914            &'a self,
915            input: &'a String,
916        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
917            Box::pin(async move { Ok(format!("echo: {input}")) })
918        }
919    }
920
921    struct FailOp {
922        failures: AtomicU32,
923        max_failures: u32,
924    }
925
926    impl FailOp {
927        fn new(max_failures: u32) -> Self {
928            Self {
929                failures: AtomicU32::new(0),
930                max_failures,
931            }
932        }
933    }
934
935    impl Operation<TestOp> for FailOp {
936        fn execute<'a>(
937            &'a self,
938            input: &'a String,
939        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
940            Box::pin(async move {
941                let count = self.failures.fetch_add(1, Ordering::SeqCst);
942                if count < self.max_failures {
943                    let failure_num = count + 1;
944                    Err(format!("failure {failure_num}"))
945                } else {
946                    Ok(format!("success after {count} failures: {input}"))
947                }
948            })
949        }
950    }
951
952    #[tokio::test]
953    async fn empty_stack_passthrough() {
954        let stack = InterceptorStack::<TestOp>::new();
955        let input = "hello".to_string();
956        let result = stack.execute(&input, &EchoOp).await;
957        assert_eq!(result, Ok("echo: hello".to_string()));
958    }
959
960    #[tokio::test]
961    async fn noop_interceptor_passthrough() {
962        let stack = InterceptorStack::<TestOp>::new().with(NoOp);
963        let input = "test".to_string();
964        let result = stack.execute(&input, &EchoOp).await;
965        assert_eq!(result, Ok("echo: test".to_string()));
966    }
967
968    #[tokio::test]
969    async fn multiple_noop_interceptors() {
970        let stack = InterceptorStack::<TestOp>::new()
971            .with(NoOp)
972            .with(NoOp)
973            .with(NoOp);
974        let input = "multi".to_string();
975        let result = stack.execute(&input, &EchoOp).await;
976        assert_eq!(result, Ok("echo: multi".to_string()));
977    }
978
979    #[tokio::test]
980    async fn retry_succeeds_after_failures() {
981        let stack = InterceptorStack::<TestOp>::new().with(Retry::new(3, Duration::from_millis(1)));
982
983        let op = FailOp::new(2); // Fail twice, then succeed
984        let input = "retry-test".to_string();
985        let result = stack.execute(&input, &op).await;
986
987        assert!(result.is_ok());
988        assert!(result.unwrap().contains("success after 2 failures"));
989    }
990
991    #[tokio::test]
992    async fn retry_exhausted() {
993        let stack = InterceptorStack::<TestOp>::new().with(Retry::new(2, Duration::from_millis(1)));
994
995        let op = FailOp::new(10); // Always fail
996        let input = "exhaust".to_string();
997        let result = stack.execute(&input, &op).await;
998
999        assert!(result.is_err());
1000        assert!(result.unwrap_err().contains("failure"));
1001    }
1002
1003    #[tokio::test]
1004    async fn timeout_success() {
1005        let stack = InterceptorStack::<TestOp>::new().with(Timeout::new(Duration::from_secs(1)));
1006        let input = "fast".to_string();
1007        let result = stack.execute(&input, &EchoOp).await;
1008        assert_eq!(result, Ok("echo: fast".to_string()));
1009    }
1010
1011    #[tokio::test]
1012    async fn timeout_expires() {
1013        struct SlowOp;
1014
1015        impl Operation<TestOp> for SlowOp {
1016            fn execute<'a>(
1017                &'a self,
1018                _input: &'a String,
1019            ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
1020                Box::pin(async {
1021                    tokio::time::sleep(Duration::from_secs(10)).await;
1022                    Ok("should not reach".to_string())
1023                })
1024            }
1025        }
1026
1027        let stack = InterceptorStack::<TestOp>::new().with(Timeout::new(Duration::from_millis(10)));
1028        let input = "slow".to_string();
1029        let result = stack.execute(&input, &SlowOp).await;
1030
1031        assert!(result.is_err());
1032        assert!(result.unwrap_err().contains("timeout"));
1033    }
1034
1035    #[tokio::test]
1036    async fn interceptor_ordering() {
1037        use std::sync::Mutex;
1038
1039        struct RecordingInterceptor {
1040            name: &'static str,
1041            log: Arc<Mutex<Vec<String>>>,
1042        }
1043
1044        impl Interceptor<TestOp> for RecordingInterceptor {
1045            fn intercept<'a>(
1046                &'a self,
1047                input: &'a String,
1048                next: Next<'a, TestOp>,
1049            ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
1050                let name = self.name;
1051                let log = Arc::clone(&self.log);
1052                Box::pin(async move {
1053                    log.lock().unwrap().push(format!("{name}-before"));
1054                    let result = next.run(input).await;
1055                    log.lock().unwrap().push(format!("{name}-after"));
1056                    result
1057                })
1058            }
1059        }
1060
1061        let log = Arc::new(Mutex::new(Vec::new()));
1062
1063        let stack = InterceptorStack::<TestOp>::new()
1064            .with(RecordingInterceptor {
1065                name: "A",
1066                log: Arc::clone(&log),
1067            })
1068            .with(RecordingInterceptor {
1069                name: "B",
1070                log: Arc::clone(&log),
1071            });
1072
1073        let input = "order".to_string();
1074        let _ = stack.execute(&input, &EchoOp).await;
1075
1076        let recorded = log.lock().unwrap().clone();
1077        assert_eq!(recorded, vec!["A-before", "B-before", "B-after", "A-after"]);
1078    }
1079
1080    #[tokio::test]
1081    async fn short_circuit_interceptor() {
1082        struct ShortCircuit;
1083
1084        impl Interceptor<TestOp> for ShortCircuit {
1085            fn intercept<'a>(
1086                &'a self,
1087                _input: &'a String,
1088                _next: Next<'a, TestOp>,
1089            ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
1090                Box::pin(async { Err("short-circuited".to_string()) })
1091            }
1092        }
1093
1094        let stack = InterceptorStack::<TestOp>::new()
1095            .with(ShortCircuit)
1096            .with(NoOp); // This should never run
1097
1098        let input = "blocked".to_string();
1099        let result = stack.execute(&input, &EchoOp).await;
1100
1101        assert_eq!(result, Err("short-circuited".to_string()));
1102    }
1103
1104    #[tokio::test]
1105    async fn execute_with_closure() {
1106        let stack = InterceptorStack::<TestOp>::new().with(NoOp);
1107
1108        let input = "closure-test".to_string();
1109        let result = stack
1110            .execute_fn(&input, |i| Box::pin(async move { Ok(format!("fn: {i}")) }))
1111            .await;
1112
1113        assert_eq!(result, Ok("fn: closure-test".to_string()));
1114    }
1115
1116    #[tokio::test]
1117    async fn next_is_copy() {
1118        // Test that Next can be used multiple times (for retry)
1119        struct MultiCallInterceptor {
1120            calls: AtomicU32,
1121        }
1122
1123        impl Interceptor<TestOp> for MultiCallInterceptor {
1124            fn intercept<'a>(
1125                &'a self,
1126                input: &'a String,
1127                next: Next<'a, TestOp>,
1128            ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
1129                Box::pin(async move {
1130                    // Call next twice to verify Copy works
1131                    let _ = next.run(input).await;
1132                    self.calls.fetch_add(1, Ordering::SeqCst);
1133                    next.run(input).await
1134                })
1135            }
1136        }
1137
1138        let interceptor = MultiCallInterceptor {
1139            calls: AtomicU32::new(0),
1140        };
1141
1142        let stack = InterceptorStack::<TestOp>::new().with(interceptor);
1143        let input = "copy-test".to_string();
1144        let result = stack.execute(&input, &EchoOp).await;
1145
1146        assert_eq!(result, Ok("echo: copy-test".to_string()));
1147    }
1148
1149    #[tokio::test]
1150    async fn shared_interceptor() {
1151        let shared: Arc<dyn Interceptor<TestOp>> = Arc::new(NoOp);
1152
1153        let stack1 = InterceptorStack::<TestOp>::new().with_shared(Arc::clone(&shared));
1154
1155        let stack2 = InterceptorStack::<TestOp>::new().with_shared(Arc::clone(&shared));
1156
1157        let input = "shared".to_string();
1158        let r1 = stack1.execute(&input, &EchoOp).await;
1159        let r2 = stack2.execute(&input, &EchoOp).await;
1160
1161        assert_eq!(r1, Ok("echo: shared".to_string()));
1162        assert_eq!(r2, Ok("echo: shared".to_string()));
1163    }
1164
1165    #[test]
1166    fn stack_len_and_is_empty() {
1167        let empty: InterceptorStack<TestOp> = InterceptorStack::new();
1168        assert!(empty.is_empty());
1169        assert_eq!(empty.len(), 0);
1170
1171        let one = InterceptorStack::<TestOp>::new().with(NoOp);
1172        assert!(!one.is_empty());
1173        assert_eq!(one.len(), 1);
1174
1175        let two = InterceptorStack::<TestOp>::new().with(NoOp).with(NoOp);
1176        assert_eq!(two.len(), 2);
1177    }
1178
1179    // =========================================================================
1180    // Approval interceptor tests
1181    // =========================================================================
1182
1183    mod approval_tests {
1184        use super::*;
1185        use crate::intercept::domain::{ToolExec, ToolRequest, ToolResponse};
1186        use crate::intercept::tool_interceptors::{Approval, ApprovalDecision};
1187        use serde_json::json;
1188
1189        struct EchoToolOp;
1190
1191        impl Operation<ToolExec<()>> for EchoToolOp {
1192            fn execute<'a>(
1193                &'a self,
1194                input: &'a ToolRequest,
1195            ) -> Pin<Box<dyn Future<Output = ToolResponse> + Send + 'a>> {
1196                Box::pin(async move {
1197                    ToolResponse {
1198                        content: format!("executed: {} with {:?}", input.name, input.arguments),
1199                        is_error: false,
1200                    }
1201                })
1202            }
1203        }
1204
1205        #[tokio::test]
1206        async fn approval_allow() {
1207            let stack = InterceptorStack::<ToolExec<()>>::new()
1208                .with(Approval::new(|_| ApprovalDecision::Allow));
1209
1210            let input = ToolRequest {
1211                name: "test_tool".into(),
1212                call_id: "call_1".into(),
1213                arguments: json!({"x": 1}),
1214            };
1215
1216            let result = stack.execute(&input, &EchoToolOp).await;
1217            assert!(!result.is_error);
1218            assert!(result.content.contains("test_tool"));
1219        }
1220
1221        #[tokio::test]
1222        async fn approval_deny() {
1223            let stack = InterceptorStack::<ToolExec<()>>::new().with(Approval::new(|req| {
1224                if req.name == "dangerous" {
1225                    ApprovalDecision::Deny("Not allowed".into())
1226                } else {
1227                    ApprovalDecision::Allow
1228                }
1229            }));
1230
1231            let input = ToolRequest {
1232                name: "dangerous".into(),
1233                call_id: "call_2".into(),
1234                arguments: json!({}),
1235            };
1236
1237            let result = stack.execute(&input, &EchoToolOp).await;
1238            assert!(result.is_error);
1239            assert_eq!(result.content, "Not allowed");
1240        }
1241
1242        #[tokio::test]
1243        async fn approval_modify() {
1244            let stack = InterceptorStack::<ToolExec<()>>::new().with(Approval::new(|req| {
1245                // Always add a "modified" field
1246                let mut args = req.arguments.clone();
1247                args["modified"] = json!(true);
1248                ApprovalDecision::Modify(args)
1249            }));
1250
1251            let input = ToolRequest {
1252                name: "my_tool".into(),
1253                call_id: "call_3".into(),
1254                arguments: json!({"original": "value"}),
1255            };
1256
1257            let result = stack.execute(&input, &EchoToolOp).await;
1258            assert!(!result.is_error);
1259            assert!(result.content.contains("modified"));
1260            assert!(result.content.contains("true"));
1261        }
1262
1263        #[tokio::test]
1264        async fn approval_debug() {
1265            let approval = Approval::new(|_: &ToolRequest| ApprovalDecision::Allow);
1266            let debug_str = format!("{approval:?}");
1267            assert!(debug_str.contains("Approval"));
1268        }
1269    }
1270}