#![allow(clippy::must_use_candidate)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::needless_lifetimes)]
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
pub use behavior::{Loggable, Outcome, Retryable, Timeoutable};
pub use domain::{LlmCall, ToolExec, ToolRequest, ToolResponse};
/// Describes a kind of operation that can be routed through an
/// [`InterceptorStack`].
///
/// Implementors act as type-level tags: the associated types name the value
/// handed to the operation and the value it produces.
pub trait Interceptable: Send + Sync + 'static {
    /// Value passed by reference to every interceptor and to the operation.
    type Input: Send;

    /// Value produced by the operation and returned back through the chain.
    type Output: Send;
}
/// Middleware wrapped around an [`Interceptable`] operation.
///
/// An implementation receives the shared input together with a [`Next`]
/// handle. It can await `next.run(input)` to continue the chain, call it more
/// than once, or produce an output itself without calling it.
pub trait Interceptor<T: Interceptable>: Send + Sync {
    /// Handles `input`, delegating to `next` when the chain should continue.
    ///
    /// The returned boxed future may borrow `self` and `input` for `'a`.
    fn intercept<'a>(
        &'a self,
        input: &'a T::Input,
        next: Next<'a, T>,
    ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>>;
}
/// Handle to the remainder of an interceptor chain.
///
/// Borrows the interceptors that have not run yet plus the terminal
/// operation. The handle is `Copy`, so an interceptor can invoke the rest of
/// the chain several times.
pub struct Next<'a, T: Interceptable> {
    /// Interceptors still to be invoked, in call order.
    interceptors: &'a [Arc<dyn Interceptor<T>>],
    /// Operation executed once no interceptors remain.
    operation: &'a dyn Operation<T>,
}

// `Clone`/`Copy` are written by hand because deriving them would add
// `T: Clone` / `T: Copy` bounds, while only references are stored here.
impl<T: Interceptable> Clone for Next<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T: Interceptable> Copy for Next<'_, T> {}
impl<T: Interceptable> Next<'_, T>
where
    T::Input: Sync,
{
    /// Runs the rest of the chain against `input`.
    ///
    /// The first remaining interceptor is called with a `Next` that covers
    /// the interceptors after it; when none remain, the terminal operation is
    /// executed instead.
    pub async fn run(self, input: &T::Input) -> T::Output {
        match self.interceptors.split_first() {
            Some((head, tail)) => {
                let remaining = Next {
                    interceptors: tail,
                    operation: self.operation,
                };
                head.intercept(input, remaining).await
            }
            None => self.operation.execute(input).await,
        }
    }
}
/// The terminal unit of work that runs once every interceptor has delegated.
pub trait Operation<T: Interceptable>: Send + Sync {
    /// Executes the operation against `input`.
    ///
    /// The returned boxed future may borrow `self` and `input` for `'a`.
    fn execute<'a>(
        &'a self,
        input: &'a T::Input,
    ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>>
    where
        T::Input: Sync;
}
/// Adapter that turns a closure returning a boxed future into an
/// [`Operation`].
pub struct FnOperation<T, F>
where
    T: Interceptable,
    F: Fn(&T::Input) -> Pin<Box<dyn Future<Output = T::Output> + Send + '_>> + Send + Sync,
{
    /// The wrapped closure.
    f: F,
    /// Ties the adapter to the operation kind `T` without storing a `T`.
    _marker: PhantomData<T>,
}

impl<T, F> FnOperation<T, F>
where
    T: Interceptable,
    F: Fn(&T::Input) -> Pin<Box<dyn Future<Output = T::Output> + Send + '_>> + Send + Sync,
{
    /// Wraps `f` so it can be used wherever an [`Operation`] is expected.
    pub fn new(f: F) -> Self {
        let _marker = PhantomData;
        Self { f, _marker }
    }
}
impl<T, F> Operation<T> for FnOperation<T, F>
where
    T: Interceptable,
    F: Fn(&T::Input) -> Pin<Box<dyn Future<Output = T::Output> + Send + '_>> + Send + Sync,
{
    /// Calls the wrapped closure with `input` and returns its future as-is.
    fn execute<'a>(
        &'a self,
        input: &'a T::Input,
    ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>>
    where
        T::Input: Sync,
    {
        let callback = &self.f;
        callback(input)
    }
}
/// Ordered collection of interceptors applied around an operation.
///
/// Layers run in insertion order: the first one added is the outermost.
pub struct InterceptorStack<T: Interceptable> {
    /// Interceptors in call order.
    layers: Vec<Arc<dyn Interceptor<T>>>,
}

// Hand-written so that `T` itself does not need to be `Clone`; cloning only
// copies the `Arc` handles, so clones share the same interceptor instances.
impl<T: Interceptable> Clone for InterceptorStack<T> {
    fn clone(&self) -> Self {
        let layers = self.layers.clone();
        Self { layers }
    }
}
impl<T: Interceptable> InterceptorStack<T> {
    /// Creates a stack with no interceptors.
    pub fn new() -> Self {
        Self { layers: Vec::new() }
    }

    /// Appends `interceptor` as the new innermost layer and returns the stack.
    #[must_use]
    pub fn with<I: Interceptor<T> + 'static>(self, interceptor: I) -> Self {
        self.with_shared(Arc::new(interceptor))
    }

    /// Appends an already shared interceptor as the new innermost layer.
    #[must_use]
    pub fn with_shared(self, interceptor: Arc<dyn Interceptor<T>>) -> Self {
        let Self { mut layers } = self;
        layers.push(interceptor);
        Self { layers }
    }

    /// Returns `true` when the stack holds no interceptors.
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }

    /// Returns the number of interceptors in the stack.
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    /// Runs `operation` on `input` through every layer of the stack.
    ///
    /// With an empty stack the operation is executed directly.
    pub async fn execute<'a, O>(&'a self, input: &'a T::Input, operation: &'a O) -> T::Output
    where
        T::Input: Sync,
        O: Operation<T>,
    {
        let chain = Next {
            interceptors: self.layers.as_slice(),
            operation,
        };
        chain.run(input).await
    }

    /// Like [`execute`](Self::execute), but takes the operation as a closure
    /// that returns a boxed future.
    pub async fn execute_fn<'a, F>(&'a self, input: &'a T::Input, f: F) -> T::Output
    where
        T::Input: Sync,
        F: Fn(&T::Input) -> Pin<Box<dyn Future<Output = T::Output> + Send + '_>> + Send + Sync,
    {
        let operation = FnOperation::<T, F>::new(f);
        self.execute(input, &operation).await
    }
}
impl<T: Interceptable> Default for InterceptorStack<T> {
    /// Returns an empty stack, the same value `InterceptorStack::new` builds.
    fn default() -> Self {
        Self { layers: Vec::new() }
    }
}
pub mod domain {
use super::Interceptable;
use crate::ChatResponse;
use crate::error::LlmError;
use crate::provider::ChatParams;
use serde_json::Value;
use std::marker::PhantomData;
/// Tag type for LLM chat calls: takes [`ChatParams`] and yields either a
/// [`ChatResponse`] or an [`LlmError`].
pub struct LlmCall;

impl Interceptable for LlmCall {
    type Input = ChatParams;
    type Output = Result<ChatResponse, LlmError>;
}
/// Tag type for tool executions: takes a [`ToolRequest`] and yields a
/// [`ToolResponse`].
///
/// `Ctx` is a phantom parameter (presumably the tool context type; confirm
/// against callers). Using `fn() -> Ctx` keeps the marker `Send + Sync`
/// regardless of `Ctx`.
pub struct ToolExec<Ctx = ()>(PhantomData<fn() -> Ctx>);

impl<Ctx: Send + Sync + 'static> Interceptable for ToolExec<Ctx> {
    type Input = ToolRequest;
    type Output = ToolResponse;
}
/// Input handed to tool-execution interceptors and operations.
#[derive(Debug, Clone)]
pub struct ToolRequest {
    /// Name of the tool being called.
    pub name: String,
    /// Identifier of this particular call.
    pub call_id: String,
    /// Arguments for the call, as a JSON value.
    pub arguments: Value,
}
/// Result of a tool execution as seen by interceptors.
#[derive(Debug, Clone)]
pub struct ToolResponse {
    /// Text payload of the response (output or error message).
    pub content: String,
    /// `true` when the response represents a failure.
    pub is_error: bool,
}

impl ToolResponse {
    /// Builds a response with `is_error` set to `false`.
    pub fn success(content: impl Into<String>) -> Self {
        let content = content.into();
        Self {
            content,
            is_error: false,
        }
    }

    /// Builds a response with `is_error` set to `true`.
    pub fn error(content: impl Into<String>) -> Self {
        let content = content.into();
        Self {
            content,
            is_error: true,
        }
    }
}
}
pub mod behavior {
use crate::ChatResponse;
use crate::error::LlmError;
use crate::provider::ChatParams;
use std::time::Duration;
use super::domain::{ToolRequest, ToolResponse};
/// Implemented by outputs that can tell a retrying interceptor whether the
/// operation should be attempted again.
pub trait Retryable {
    /// Returns `true` when this outcome warrants another attempt.
    fn should_retry(&self) -> bool;
}

impl Retryable for Result<ChatResponse, LlmError> {
    /// Successes are never retried; errors defer to `LlmError::is_retryable`.
    fn should_retry(&self) -> bool {
        matches!(self, Err(e) if e.is_retryable())
    }
}

impl Retryable for ToolResponse {
    /// Tool responses are never retried.
    fn should_retry(&self) -> bool {
        false
    }
}
/// Implemented by outputs that can represent "the operation timed out".
pub trait Timeoutable: Sized {
    /// Builds the output returned when `duration` elapsed before completion.
    fn timeout_error(duration: Duration) -> Self;
}

impl Timeoutable for Result<ChatResponse, LlmError> {
    /// Produces `LlmError::Timeout` carrying the limit in milliseconds.
    fn timeout_error(duration: Duration) -> Self {
        let elapsed_ms = duration.as_millis() as u64;
        Err(LlmError::Timeout { elapsed_ms })
    }
}

impl Timeoutable for ToolResponse {
    /// Produces an error response whose content mentions the limit.
    fn timeout_error(duration: Duration) -> Self {
        let content = format!("Tool execution timed out after {duration:?}");
        ToolResponse {
            content,
            is_error: true,
        }
    }
}
/// Implemented by inputs that can describe themselves for log output.
pub trait Loggable {
    /// Returns a short human-readable summary of the input.
    fn log_description(&self) -> String;
}

impl Loggable for ChatParams {
    /// Reports the number of messages and tools (0 when `tools` is `None`).
    fn log_description(&self) -> String {
        let tool_count = match self.tools.as_ref() {
            Some(tools) => tools.len(),
            None => 0,
        };
        format!(
            "LLM request: {} messages, {} tools",
            self.messages.len(),
            tool_count
        )
    }
}

impl Loggable for ToolRequest {
    /// Reports the tool name followed by the call id in parentheses.
    fn log_description(&self) -> String {
        let Self { name, call_id, .. } = self;
        format!("Tool call: {} ({})", name, call_id)
    }
}
/// Implemented by outputs that can report whether the operation succeeded.
pub trait Outcome {
    /// Returns `true` when the output represents a success.
    fn is_success(&self) -> bool;
}

impl Outcome for Result<ChatResponse, LlmError> {
    /// `Ok` counts as success.
    fn is_success(&self) -> bool {
        self.is_ok()
    }
}

impl Outcome for ToolResponse {
    /// A response is successful when its `is_error` flag is unset.
    fn is_success(&self) -> bool {
        let failed = self.is_error;
        !failed
    }
}
}
pub mod interceptors {
#[cfg(feature = "tracing")]
use super::behavior::{Loggable, Outcome};
use super::behavior::{Retryable, Timeoutable};
use super::{Interceptable, Interceptor, Next};
use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
/// Interceptor configuration for retrying with exponential backoff.
#[derive(Debug, Clone)]
pub struct Retry {
    /// Total number of attempts, including the first one.
    pub max_attempts: u32,
    /// Delay applied after the first failed attempt.
    pub initial_delay: Duration,
    /// Upper bound for any single delay.
    pub max_delay: Duration,
    /// Factor by which the delay grows after each further failed attempt.
    pub multiplier: f64,
}

impl Default for Retry {
    /// Three attempts, starting at 500 ms, doubling, capped at 30 s.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(30),
            multiplier: 2.0,
        }
    }
}

impl Retry {
    /// Creates a policy with the given attempt count and initial delay; the
    /// remaining fields take their [`Default`] values.
    pub fn new(max_attempts: u32, initial_delay: Duration) -> Self {
        Self {
            max_attempts,
            initial_delay,
            ..Default::default()
        }
    }

    /// Returns the pause to take after the 1-based `attempt` has failed:
    /// `initial_delay * multiplier^(attempt - 1)`, capped at `max_delay`.
    fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // Clamp before narrowing so very large attempt numbers cannot wrap
        // into a negative exponent (which would shrink the delay).
        let exponent = attempt.saturating_sub(1).min(i32::MAX as u32) as i32;
        let factor = self.multiplier.powi(exponent);

        // Work in nanoseconds so sub-millisecond initial delays are not
        // truncated to zero. The float-to-int cast saturates and maps NaN
        // and negative values to 0.
        let delay_ns = self.initial_delay.as_nanos() as f64 * factor;
        Duration::from_nanos(delay_ns as u64).min(self.max_delay)
    }
}
impl<T> Interceptor<T> for Retry
where
    T: Interceptable,
    T::Input: Sync,
    T::Output: Retryable,
{
    /// Runs the rest of the chain until it yields an output that should not
    /// be retried or the attempt budget is used up, sleeping between
    /// attempts according to `delay_for_attempt`.
    ///
    /// The chain is always run at least once: a `max_attempts` of 0 is
    /// treated as 1 instead of panicking. The output of the final attempt is
    /// returned.
    fn intercept<'a>(
        &'a self,
        input: &'a T::Input,
        next: Next<'a, T>,
    ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
        Box::pin(async move {
            let max_attempts = self.max_attempts.max(1);
            let mut attempt: u32 = 1;
            loop {
                let result = next.run(input).await;
                if !result.should_retry() || attempt >= max_attempts {
                    return result;
                }
                // The failed output is no longer needed; release it before
                // backing off.
                drop(result);
                tokio::time::sleep(self.delay_for_attempt(attempt)).await;
                attempt += 1;
            }
        })
    }
}
/// Interceptor configuration that bounds how long the rest of the chain may
/// take.
#[derive(Debug, Clone)]
pub struct Timeout {
    /// Maximum time allowed for the wrapped chain.
    pub duration: Duration,
}

impl Timeout {
    /// Creates a timeout with the given limit.
    pub fn new(duration: Duration) -> Self {
        Timeout { duration }
    }
}
impl<T> Interceptor<T> for Timeout
where
    T: Interceptable,
    T::Input: Sync,
    T::Output: Timeoutable,
{
    /// Runs the rest of the chain under `tokio::time::timeout`; if the limit
    /// elapses first, returns `T::Output::timeout_error(duration)` instead.
    fn intercept<'a>(
        &'a self,
        input: &'a T::Input,
        next: Next<'a, T>,
    ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
        let limit = self.duration;
        Box::pin(async move {
            tokio::time::timeout(limit, next.run(input))
                .await
                .unwrap_or_else(|_elapsed| T::Output::timeout_error(limit))
        })
    }
}
/// Interceptor that forwards to the rest of the chain unchanged.
#[derive(Debug, Clone, Default)]
pub struct NoOp;

impl<T> Interceptor<T> for NoOp
where
    T: Interceptable,
    T::Input: Sync,
{
    /// Delegates straight to `next` and returns whatever it produces.
    fn intercept<'a>(
        &'a self,
        input: &'a T::Input,
        next: Next<'a, T>,
    ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
        let passthrough = next.run(input);
        Box::pin(passthrough)
    }
}
/// Interceptor configuration that emits `tracing` events around the chain.
#[cfg(feature = "tracing")]
#[derive(Debug, Clone)]
pub struct Logging {
    /// Verbosity used for the emitted events.
    pub level: LogLevel,
}

/// Verbosity selector for [`Logging`].
#[cfg(feature = "tracing")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum LogLevel {
    /// Default verbosity.
    #[default]
    Info,
    /// More verbose than `Info`.
    Debug,
    /// Most verbose setting.
    Trace,
}

#[cfg(feature = "tracing")]
impl Default for Logging {
    /// Logging at [`LogLevel::Info`].
    fn default() -> Self {
        Self::new(LogLevel::Info)
    }
}

#[cfg(feature = "tracing")]
impl Logging {
    /// Creates a logging interceptor with the given verbosity.
    pub fn new(level: LogLevel) -> Self {
        Logging { level }
    }
}
#[cfg(feature = "tracing")]
impl<T> Interceptor<T> for Logging
where
    T: Interceptable,
    T::Input: Sync + Loggable,
    T::Output: Outcome,
{
    /// Runs the rest of the chain and emits an "operation completed" event
    /// carrying the elapsed milliseconds and whether the output was a
    /// success.
    ///
    /// At [`LogLevel::Info`] the event is emitted with `tracing::info!`; at
    /// `Debug` and `Trace` with `tracing::debug!`. `Trace` additionally emits
    /// an "operation starting" event with the input's description first.
    fn intercept<'a>(
        &'a self,
        input: &'a T::Input,
        next: Next<'a, T>,
    ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
        let level = self.level;
        Box::pin(async move {
            // The description is only emitted at Trace, so skip building the
            // string at the other levels.
            let description = (level == LogLevel::Trace).then(|| input.log_description());
            let start = std::time::Instant::now();
            if let Some(description) = &description {
                tracing::debug!(description = %description, "operation starting");
            }
            let result = next.run(input).await;
            let duration_ms = start.elapsed().as_millis() as u64;
            let success = result.is_success();
            match level {
                LogLevel::Info => {
                    // Include `success` so failures can be told apart from
                    // successes at the default level too.
                    tracing::info!(duration_ms, success, "operation completed");
                }
                LogLevel::Debug | LogLevel::Trace => {
                    tracing::debug!(duration_ms, success, "operation completed");
                }
            }
            result
        })
    }
}
}
#[cfg(feature = "tracing")]
pub use interceptors::{LogLevel, Logging};
pub use interceptors::{NoOp, Retry, Timeout};
pub mod tool_interceptors {
use super::{
Interceptor, Next,
domain::{ToolExec, ToolRequest, ToolResponse},
};
use serde_json::Value;
use std::future::Future;
use std::pin::Pin;
/// Verdict returned by the check closure of an [`Approval`] interceptor.
#[derive(Debug, Clone)]
pub enum ApprovalDecision {
    /// Run the tool with the request unchanged.
    Allow,
    /// Do not run the tool; the string becomes the content of an error
    /// response.
    Deny(String),
    /// Run the tool with its arguments replaced by this JSON value.
    Modify(Value),
}
/// Tool interceptor that consults a closure before letting a call proceed.
pub struct Approval<F> {
    /// Closure deciding what happens to each request.
    check: F,
}

impl<F> Approval<F>
where
    F: Fn(&ToolRequest) -> ApprovalDecision + Send + Sync,
{
    /// Creates an approval interceptor driven by `check`.
    pub fn new(check: F) -> Self {
        Approval { check }
    }
}

// Manual impl: the closure type usually has no `Debug`, so only the type
// name is printed.
impl<F> std::fmt::Debug for Approval<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Approval");
        out.finish_non_exhaustive()
    }
}
impl<Ctx, F> Interceptor<ToolExec<Ctx>> for Approval<F>
where
    Ctx: Send + Sync + 'static,
    F: Fn(&ToolRequest) -> ApprovalDecision + Send + Sync,
{
    /// Applies the check closure to `input` and acts on its decision:
    /// forwards the request, answers with an error response, or forwards a
    /// copy of the request with replaced arguments.
    fn intercept<'a>(
        &'a self,
        input: &'a ToolRequest,
        next: Next<'a, ToolExec<Ctx>>,
    ) -> Pin<Box<dyn Future<Output = ToolResponse> + Send + 'a>> {
        Box::pin(async move {
            let decision = (self.check)(input);
            match decision {
                ApprovalDecision::Deny(reason) => ToolResponse {
                    content: reason,
                    is_error: true,
                },
                ApprovalDecision::Allow => next.run(input).await,
                ApprovalDecision::Modify(arguments) => {
                    // Name and call id are preserved; only arguments change.
                    let rewritten = ToolRequest {
                        name: input.name.clone(),
                        call_id: input.call_id.clone(),
                        arguments,
                    };
                    next.run(&rewritten).await
                }
            }
        })
    }
}
}
pub use tool_interceptors::{Approval, ApprovalDecision};
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
/// Operation kind used by the tests: `String` in, `Result<String, String>`
/// out.
struct TestOp;

impl Interceptable for TestOp {
    type Input = String;
    type Output = Result<String, String>;
}

/// Every error is considered retryable in tests.
impl behavior::Retryable for Result<String, String> {
    fn should_retry(&self) -> bool {
        self.is_err()
    }
}

/// Timeouts surface as an `Err` whose text contains "timeout".
impl behavior::Timeoutable for Result<String, String> {
    fn timeout_error(duration: Duration) -> Self {
        let message = format!("timeout after {duration:?}");
        Err(message)
    }
}
/// Operation that succeeds with the input prefixed by "echo: ".
struct EchoOp;

impl Operation<TestOp> for EchoOp {
    fn execute<'a>(
        &'a self,
        input: &'a String,
    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
        Box::pin(async move {
            let echoed = format!("echo: {input}");
            Ok(echoed)
        })
    }
}
/// Operation that fails its first `max_failures` executions and succeeds
/// afterwards.
struct FailOp {
    /// Number of executions so far.
    failures: AtomicU32,
    /// How many executions fail before the first success.
    max_failures: u32,
}

impl FailOp {
    /// Creates an operation that has not been executed yet.
    fn new(max_failures: u32) -> Self {
        let failures = AtomicU32::new(0);
        Self {
            failures,
            max_failures,
        }
    }
}

impl Operation<TestOp> for FailOp {
    fn execute<'a>(
        &'a self,
        input: &'a String,
    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
        Box::pin(async move {
            // `fetch_add` returns the value before the increment.
            let count = self.failures.fetch_add(1, Ordering::SeqCst);
            if count >= self.max_failures {
                Ok(format!("success after {count} failures: {input}"))
            } else {
                let failure_num = count + 1;
                Err(format!("failure {failure_num}"))
            }
        })
    }
}
#[tokio::test]
async fn empty_stack_passthrough() {
let stack = InterceptorStack::<TestOp>::new();
let input = "hello".to_string();
let result = stack.execute(&input, &EchoOp).await;
assert_eq!(result, Ok("echo: hello".to_string()));
}
/// A single `NoOp` layer leaves the operation's result untouched.
#[tokio::test]
async fn noop_interceptor_passthrough() {
    let input = String::from("test");
    let stack = InterceptorStack::<TestOp>::new().with(NoOp);
    let output = stack.execute(&input, &EchoOp).await;
    assert_eq!(output, Ok(String::from("echo: test")));
}
/// Several `NoOp` layers in a row still pass the result through.
#[tokio::test]
async fn multiple_noop_interceptors() {
    let mut stack = InterceptorStack::<TestOp>::new();
    for _ in 0..3 {
        stack = stack.with(NoOp);
    }
    let input = String::from("multi");
    let output = stack.execute(&input, &EchoOp).await;
    assert_eq!(output, Ok(String::from("echo: multi")));
}
/// Two failures followed by a success fit in a three-attempt budget.
#[tokio::test]
async fn retry_succeeds_after_failures() {
    let policy = Retry::new(3, Duration::from_millis(1));
    let stack = InterceptorStack::<TestOp>::new().with(policy);
    let op = FailOp::new(2);
    let input = String::from("retry-test");
    let outcome = stack.execute(&input, &op).await;
    assert!(outcome.is_ok());
    assert!(outcome.unwrap().contains("success after 2 failures"));
}
/// When every attempt fails, the error of the last attempt is returned.
#[tokio::test]
async fn retry_exhausted() {
    let policy = Retry::new(2, Duration::from_millis(1));
    let stack = InterceptorStack::<TestOp>::new().with(policy);
    let op = FailOp::new(10);
    let input = String::from("exhaust");
    let outcome = stack.execute(&input, &op).await;
    assert!(outcome.is_err());
    assert!(outcome.unwrap_err().contains("failure"));
}
/// An operation that finishes within the limit is returned unchanged.
#[tokio::test]
async fn timeout_success() {
    let limit = Timeout::new(Duration::from_secs(1));
    let stack = InterceptorStack::<TestOp>::new().with(limit);
    let input = String::from("fast");
    let output = stack.execute(&input, &EchoOp).await;
    assert_eq!(output, Ok(String::from("echo: fast")));
}
/// An operation slower than the limit is replaced by the timeout error.
#[tokio::test]
async fn timeout_expires() {
    /// Sleeps far longer than the timeout under test.
    struct SlowOp;

    impl Operation<TestOp> for SlowOp {
        fn execute<'a>(
            &'a self,
            _input: &'a String,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok("should not reach".to_string())
            })
        }
    }

    let limit = Timeout::new(Duration::from_millis(10));
    let stack = InterceptorStack::<TestOp>::new().with(limit);
    let input = String::from("slow");
    let outcome = stack.execute(&input, &SlowOp).await;
    assert!(outcome.is_err());
    assert!(outcome.unwrap_err().contains("timeout"));
}
/// Layers run in insertion order on the way in and in reverse on the way out.
#[tokio::test]
async fn interceptor_ordering() {
    use std::sync::Mutex;

    /// Appends "<name>-before" / "<name>-after" around the rest of the chain.
    struct RecordingInterceptor {
        name: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    impl Interceptor<TestOp> for RecordingInterceptor {
        fn intercept<'a>(
            &'a self,
            input: &'a String,
            next: Next<'a, TestOp>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            let name = self.name;
            let log = Arc::clone(&self.log);
            Box::pin(async move {
                log.lock().unwrap().push(format!("{name}-before"));
                let result = next.run(input).await;
                log.lock().unwrap().push(format!("{name}-after"));
                result
            })
        }
    }

    let events = Arc::new(Mutex::new(Vec::new()));
    let outer = RecordingInterceptor {
        name: "A",
        log: Arc::clone(&events),
    };
    let inner = RecordingInterceptor {
        name: "B",
        log: Arc::clone(&events),
    };
    let stack = InterceptorStack::<TestOp>::new().with(outer).with(inner);

    let input = String::from("order");
    let _ = stack.execute(&input, &EchoOp).await;

    let recorded = events.lock().unwrap().clone();
    let expected = vec!["A-before", "B-before", "B-after", "A-after"];
    assert_eq!(recorded, expected);
}
/// An interceptor may answer on its own without calling `next`.
#[tokio::test]
async fn short_circuit_interceptor() {
    /// Always fails without running the rest of the chain.
    struct ShortCircuit;

    impl Interceptor<TestOp> for ShortCircuit {
        fn intercept<'a>(
            &'a self,
            _input: &'a String,
            _next: Next<'a, TestOp>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            Box::pin(async { Err("short-circuited".to_string()) })
        }
    }

    let stack = InterceptorStack::<TestOp>::new().with(ShortCircuit).with(NoOp);
    let input = String::from("blocked");
    let outcome = stack.execute(&input, &EchoOp).await;
    assert_eq!(outcome, Err(String::from("short-circuited")));
}
/// `execute_fn` accepts a closure in place of an `Operation` value.
#[tokio::test]
async fn execute_with_closure() {
    let input = String::from("closure-test");
    let stack = InterceptorStack::<TestOp>::new().with(NoOp);
    let output = stack
        .execute_fn(&input, |i| Box::pin(async move { Ok(format!("fn: {i}")) }))
        .await;
    assert_eq!(output, Ok(String::from("fn: closure-test")));
}
/// `Next` is `Copy`, so an interceptor can run the rest of the chain twice.
#[tokio::test]
async fn next_is_copy() {
    /// Runs the chain, bumps a counter, then runs the chain again.
    struct MultiCallInterceptor {
        calls: AtomicU32,
    }

    impl Interceptor<TestOp> for MultiCallInterceptor {
        fn intercept<'a>(
            &'a self,
            input: &'a String,
            next: Next<'a, TestOp>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            Box::pin(async move {
                // The first result is discarded on purpose.
                let _ = next.run(input).await;
                self.calls.fetch_add(1, Ordering::SeqCst);
                next.run(input).await
            })
        }
    }

    let twice = MultiCallInterceptor {
        calls: AtomicU32::new(0),
    };
    let stack = InterceptorStack::<TestOp>::new().with(twice);
    let input = String::from("copy-test");
    let output = stack.execute(&input, &EchoOp).await;
    assert_eq!(output, Ok(String::from("echo: copy-test")));
}
/// One `Arc`'d interceptor can be installed in several stacks.
#[tokio::test]
async fn shared_interceptor() {
    let shared: Arc<dyn Interceptor<TestOp>> = Arc::new(NoOp);
    let first = InterceptorStack::<TestOp>::new().with_shared(Arc::clone(&shared));
    let second = InterceptorStack::<TestOp>::new().with_shared(Arc::clone(&shared));
    let input = String::from("shared");
    let from_first = first.execute(&input, &EchoOp).await;
    let from_second = second.execute(&input, &EchoOp).await;
    assert_eq!(from_first, Ok(String::from("echo: shared")));
    assert_eq!(from_second, Ok(String::from("echo: shared")));
}
/// `len` and `is_empty` track the number of layers added.
#[test]
fn stack_len_and_is_empty() {
    let none = InterceptorStack::<TestOp>::new();
    assert!(none.is_empty());
    assert_eq!(none.len(), 0);

    let single = InterceptorStack::<TestOp>::new().with(NoOp);
    assert!(!single.is_empty());
    assert_eq!(single.len(), 1);

    let double = InterceptorStack::<TestOp>::new().with(NoOp).with(NoOp);
    assert_eq!(double.len(), 2);
}
mod approval_tests {
use super::*;
use crate::intercept::domain::{ToolExec, ToolRequest, ToolResponse};
use crate::intercept::tool_interceptors::{Approval, ApprovalDecision};
use serde_json::json;
/// Tool operation that reports the name and arguments it was called with.
struct EchoToolOp;

impl Operation<ToolExec<()>> for EchoToolOp {
    fn execute<'a>(
        &'a self,
        input: &'a ToolRequest,
    ) -> Pin<Box<dyn Future<Output = ToolResponse> + Send + 'a>> {
        Box::pin(async move {
            let content = format!("executed: {} with {:?}", input.name, input.arguments);
            ToolResponse {
                content,
                is_error: false,
            }
        })
    }
}
/// `Allow` lets the request reach the tool unchanged.
#[tokio::test]
async fn approval_allow() {
    let request = ToolRequest {
        name: "test_tool".into(),
        call_id: "call_1".into(),
        arguments: json!({"x": 1}),
    };
    let stack = InterceptorStack::<ToolExec<()>>::new()
        .with(Approval::new(|_| ApprovalDecision::Allow));
    let response = stack.execute(&request, &EchoToolOp).await;
    assert!(!response.is_error);
    assert!(response.content.contains("test_tool"));
}
/// `Deny` yields an error response carrying the reason; the tool is not run.
#[tokio::test]
async fn approval_deny() {
    let request = ToolRequest {
        name: "dangerous".into(),
        call_id: "call_2".into(),
        arguments: json!({}),
    };
    let stack = InterceptorStack::<ToolExec<()>>::new().with(Approval::new(|req| {
        if req.name == "dangerous" {
            ApprovalDecision::Deny("Not allowed".into())
        } else {
            ApprovalDecision::Allow
        }
    }));
    let response = stack.execute(&request, &EchoToolOp).await;
    assert!(response.is_error);
    assert_eq!(response.content, "Not allowed");
}
/// `Modify` makes the tool run with the replacement arguments.
#[tokio::test]
async fn approval_modify() {
    let request = ToolRequest {
        name: "my_tool".into(),
        call_id: "call_3".into(),
        arguments: json!({"original": "value"}),
    };
    let stack = InterceptorStack::<ToolExec<()>>::new().with(Approval::new(|req| {
        let mut args = req.arguments.clone();
        args["modified"] = json!(true);
        ApprovalDecision::Modify(args)
    }));
    let response = stack.execute(&request, &EchoToolOp).await;
    assert!(!response.is_error);
    assert!(response.content.contains("modified"));
    assert!(response.content.contains("true"));
}
/// The `Debug` output of `Approval` names the type.
#[tokio::test]
async fn approval_debug() {
    let approval = Approval::new(|_: &ToolRequest| ApprovalDecision::Allow);
    let rendered = format!("{approval:?}");
    assert!(rendered.contains("Approval"));
}
}
}