1#![allow(clippy::must_use_candidate)]
3#![allow(clippy::cast_possible_truncation)]
5#![allow(clippy::cast_sign_loss)]
6#![allow(clippy::cast_precision_loss)]
7#![allow(clippy::cast_possible_wrap)]
8#![allow(clippy::needless_lifetimes)]
10
11use std::future::Future;
56use std::marker::PhantomData;
57use std::pin::Pin;
58use std::sync::Arc;
59
60pub use behavior::{Loggable, Outcome, Retryable, Timeoutable};
62pub use domain::{LlmCall, ToolExec, ToolRequest, ToolResponse};
63
/// Names one kind of interceptable operation by fixing its input and output types.
///
/// Implementors are used purely at the type level (for example the marker types
/// `LlmCall` and `ToolExec`); interceptors, operations and stacks are generic over them.
/// Both associated types are `Send` so values of them can live inside the `Send`
/// futures produced by interceptors and operations.
pub trait Interceptable: 'static + Send + Sync {
    /// The value passed, by reference, to every interceptor and to the operation.
    type Input: Send;

    /// The value the operation produces and the interceptors hand back outward.
    type Output: Send;
}
79
80pub trait Interceptor<T: Interceptable>: Send + Sync {
118 fn intercept<'a>(
122 &'a self,
123 input: &'a T::Input,
124 next: Next<'a, T>,
125 ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>>;
126}
127
128pub struct Next<'a, T: Interceptable> {
133 interceptors: &'a [Arc<dyn Interceptor<T>>],
134 operation: &'a dyn Operation<T>,
135}
136
137impl<T: Interceptable> Clone for Next<'_, T> {
138 fn clone(&self) -> Self {
139 *self
140 }
141}
142
143impl<T: Interceptable> Copy for Next<'_, T> {}
145
146impl<T: Interceptable> Next<'_, T>
147where
148 T::Input: Sync,
149{
150 pub async fn run(self, input: &T::Input) -> T::Output {
155 if let Some((first, rest)) = self.interceptors.split_first() {
156 let next = Next {
157 interceptors: rest,
158 operation: self.operation,
159 };
160 first.intercept(input, next).await
161 } else {
162 self.operation.execute(input).await
163 }
164 }
165}
166
167pub trait Operation<T: Interceptable>: Send + Sync {
171 fn execute<'a>(
173 &'a self,
174 input: &'a T::Input,
175 ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>>
176 where
177 T::Input: Sync;
178}
179
/// Adapts a closure into an [`Operation`].
///
/// The wrapped closure receives the input by reference and returns a boxed future.
/// The struct definition itself carries no trait bounds; the `Interceptable`/`Fn`
/// requirements live on the `impl` blocks that actually need them, so merely naming
/// the type does not force those bounds on callers.
pub struct FnOperation<T, F> {
    /// The wrapped closure.
    f: F,
    /// Associates the operation kind `T` with this type without storing a `T`.
    _marker: PhantomData<T>,
}
189
190impl<T, F> FnOperation<T, F>
191where
192 T: Interceptable,
193 F: Fn(&T::Input) -> Pin<Box<dyn Future<Output = T::Output> + Send + '_>> + Send + Sync,
194{
195 pub fn new(f: F) -> Self {
197 Self {
198 f,
199 _marker: PhantomData,
200 }
201 }
202}
203
204impl<T, F> Operation<T> for FnOperation<T, F>
205where
206 T: Interceptable,
207 F: Fn(&T::Input) -> Pin<Box<dyn Future<Output = T::Output> + Send + '_>> + Send + Sync,
208{
209 fn execute<'a>(
210 &'a self,
211 input: &'a T::Input,
212 ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>>
213 where
214 T::Input: Sync,
215 {
216 (self.f)(input)
217 }
218}
219
220pub struct InterceptorStack<T: Interceptable> {
235 layers: Vec<Arc<dyn Interceptor<T>>>,
236}
237
238impl<T: Interceptable> Clone for InterceptorStack<T> {
239 fn clone(&self) -> Self {
240 Self {
241 layers: self.layers.clone(),
242 }
243 }
244}
245
246impl<T: Interceptable> InterceptorStack<T> {
247 pub fn new() -> Self {
249 Self { layers: Vec::new() }
250 }
251
252 #[must_use]
256 pub fn with<I: Interceptor<T> + 'static>(mut self, interceptor: I) -> Self {
257 self.layers.push(Arc::new(interceptor));
258 self
259 }
260
261 #[must_use]
266 pub fn with_shared(mut self, interceptor: Arc<dyn Interceptor<T>>) -> Self {
267 self.layers.push(interceptor);
268 self
269 }
270
271 pub fn is_empty(&self) -> bool {
273 self.layers.is_empty()
274 }
275
276 pub fn len(&self) -> usize {
278 self.layers.len()
279 }
280
281 pub async fn execute<'a, O>(&'a self, input: &'a T::Input, operation: &'a O) -> T::Output
283 where
284 T::Input: Sync,
285 O: Operation<T>,
286 {
287 let next = Next {
288 interceptors: &self.layers,
289 operation,
290 };
291 next.run(input).await
292 }
293
294 pub async fn execute_fn<'a, F>(&'a self, input: &'a T::Input, f: F) -> T::Output
296 where
297 T::Input: Sync,
298 F: Fn(&T::Input) -> Pin<Box<dyn Future<Output = T::Output> + Send + '_>> + Send + Sync,
299 {
300 let op = FnOperation::<T, F>::new(f);
301 self.execute(input, &op).await
302 }
303}
304
305impl<T: Interceptable> Default for InterceptorStack<T> {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
311pub mod domain {
317 use super::Interceptable;
318 use crate::ChatResponse;
319 use crate::error::LlmError;
320 use crate::provider::ChatParams;
321 use serde_json::Value;
322 use std::marker::PhantomData;
323
324 pub struct LlmCall;
343
344 impl Interceptable for LlmCall {
345 type Input = ChatParams;
346 type Output = Result<ChatResponse, LlmError>;
347 }
348
349 pub struct ToolExec<Ctx = ()>(PhantomData<fn() -> Ctx>);
356
357 impl<Ctx: Send + Sync + 'static> Interceptable for ToolExec<Ctx> {
358 type Input = ToolRequest;
359 type Output = ToolResponse;
360 }
361
362 #[derive(Debug, Clone)]
364 pub struct ToolRequest {
365 pub name: String,
367
368 pub call_id: String,
370
371 pub arguments: Value,
373 }
374
375 #[derive(Debug, Clone)]
377 pub struct ToolResponse {
378 pub content: String,
380
381 pub is_error: bool,
383 }
384
385 impl ToolResponse {
386 pub fn success(content: impl Into<String>) -> Self {
388 Self {
389 content: content.into(),
390 is_error: false,
391 }
392 }
393
394 pub fn error(content: impl Into<String>) -> Self {
396 Self {
397 content: content.into(),
398 is_error: true,
399 }
400 }
401 }
402}
403
404pub mod behavior {
410 use crate::ChatResponse;
411 use crate::error::LlmError;
412 use crate::provider::ChatParams;
413 use std::time::Duration;
414
415 use super::domain::{ToolRequest, ToolResponse};
416
417 pub trait Retryable {
419 fn should_retry(&self) -> bool;
421 }
422
423 impl Retryable for Result<ChatResponse, LlmError> {
424 fn should_retry(&self) -> bool {
425 match self {
426 Ok(_) => false,
427 Err(e) => e.is_retryable(),
428 }
429 }
430 }
431
432 impl Retryable for ToolResponse {
433 fn should_retry(&self) -> bool {
434 false
437 }
438 }
439
440 pub trait Timeoutable: Sized {
442 fn timeout_error(duration: Duration) -> Self;
444 }
445
446 impl Timeoutable for Result<ChatResponse, LlmError> {
447 fn timeout_error(duration: Duration) -> Self {
448 Err(LlmError::Timeout {
449 elapsed_ms: duration.as_millis() as u64,
450 })
451 }
452 }
453
454 impl Timeoutable for ToolResponse {
455 fn timeout_error(duration: Duration) -> Self {
456 ToolResponse {
457 content: format!("Tool execution timed out after {duration:?}"),
458 is_error: true,
459 }
460 }
461 }
462
463 pub trait Loggable {
465 fn log_description(&self) -> String;
467 }
468
469 impl Loggable for ChatParams {
470 fn log_description(&self) -> String {
471 let tool_count = self.tools.as_ref().map_or(0, Vec::len);
472 format!(
473 "LLM request: {} messages, {} tools",
474 self.messages.len(),
475 tool_count
476 )
477 }
478 }
479
480 impl Loggable for ToolRequest {
481 fn log_description(&self) -> String {
482 format!("Tool call: {} ({})", self.name, self.call_id)
483 }
484 }
485
486 pub trait Outcome {
492 fn is_success(&self) -> bool;
494 }
495
496 impl Outcome for Result<ChatResponse, LlmError> {
497 fn is_success(&self) -> bool {
498 self.is_ok()
499 }
500 }
501
502 impl Outcome for ToolResponse {
503 fn is_success(&self) -> bool {
504 !self.is_error
505 }
506 }
507}
508
509pub mod interceptors {
515 #[cfg(feature = "tracing")]
516 use super::behavior::{Loggable, Outcome};
517 use super::behavior::{Retryable, Timeoutable};
518 use super::{Interceptable, Interceptor, Next};
519 use std::future::Future;
520 use std::pin::Pin;
521 use std::time::Duration;
522
523 #[derive(Debug, Clone)]
537 pub struct Retry {
538 pub max_attempts: u32,
540
541 pub initial_delay: Duration,
543
544 pub max_delay: Duration,
546
547 pub multiplier: f64,
549 }
550
551 impl Default for Retry {
552 fn default() -> Self {
553 Self {
554 max_attempts: 3,
555 initial_delay: Duration::from_millis(500),
556 max_delay: Duration::from_secs(30),
557 multiplier: 2.0,
558 }
559 }
560 }
561
562 impl Retry {
563 pub fn new(max_attempts: u32, initial_delay: Duration) -> Self {
565 Self {
566 max_attempts,
567 initial_delay,
568 ..Default::default()
569 }
570 }
571
572 fn delay_for_attempt(&self, attempt: u32) -> Duration {
573 let delay_ms = self.initial_delay.as_millis() as f64
574 * self.multiplier.powi(attempt.saturating_sub(1) as i32);
575 let delay = Duration::from_millis(delay_ms as u64);
576 std::cmp::min(delay, self.max_delay)
577 }
578 }
579
580 impl<T> Interceptor<T> for Retry
581 where
582 T: Interceptable,
583 T::Input: Sync,
584 T::Output: Retryable,
585 {
586 fn intercept<'a>(
587 &'a self,
588 input: &'a T::Input,
589 next: Next<'a, T>,
590 ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
591 Box::pin(async move {
592 let mut last_result: Option<T::Output> = None;
593
594 for attempt in 1..=self.max_attempts {
595 let result = next.run(input).await;
596
597 if !result.should_retry() || attempt == self.max_attempts {
598 return result;
599 }
600
601 let delay = self.delay_for_attempt(attempt);
603 tokio::time::sleep(delay).await;
604
605 last_result = Some(result);
606 }
607
608 last_result.expect("at least one attempt should have been made")
610 })
611 }
612 }
613
614 #[derive(Debug, Clone)]
629 pub struct Timeout {
630 pub duration: Duration,
632 }
633
634 impl Timeout {
635 pub fn new(duration: Duration) -> Self {
637 Self { duration }
638 }
639 }
640
641 impl<T> Interceptor<T> for Timeout
642 where
643 T: Interceptable,
644 T::Input: Sync,
645 T::Output: Timeoutable,
646 {
647 fn intercept<'a>(
648 &'a self,
649 input: &'a T::Input,
650 next: Next<'a, T>,
651 ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
652 let duration = self.duration;
653 Box::pin(async move {
654 match tokio::time::timeout(duration, next.run(input)).await {
655 Ok(result) => result,
656 Err(_) => T::Output::timeout_error(duration),
657 }
658 })
659 }
660 }
661
662 #[derive(Debug, Clone, Default)]
666 pub struct NoOp;
667
668 impl<T> Interceptor<T> for NoOp
669 where
670 T: Interceptable,
671 T::Input: Sync,
672 {
673 fn intercept<'a>(
674 &'a self,
675 input: &'a T::Input,
676 next: Next<'a, T>,
677 ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
678 Box::pin(next.run(input))
679 }
680 }
681
682 #[cfg(feature = "tracing")]
696 #[derive(Debug, Clone)]
697 pub struct Logging {
698 pub level: LogLevel,
700 }
701
702 #[cfg(feature = "tracing")]
704 #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
705 pub enum LogLevel {
706 #[default]
708 Info,
709 Debug,
711 Trace,
713 }
714
715 #[cfg(feature = "tracing")]
716 impl Default for Logging {
717 fn default() -> Self {
718 Self {
719 level: LogLevel::Info,
720 }
721 }
722 }
723
724 #[cfg(feature = "tracing")]
725 impl Logging {
726 pub fn new(level: LogLevel) -> Self {
728 Self { level }
729 }
730 }
731
732 #[cfg(feature = "tracing")]
733 impl<T> Interceptor<T> for Logging
734 where
735 T: Interceptable,
736 T::Input: Sync + Loggable,
737 T::Output: Outcome,
738 {
739 fn intercept<'a>(
740 &'a self,
741 input: &'a T::Input,
742 next: Next<'a, T>,
743 ) -> Pin<Box<dyn Future<Output = T::Output> + Send + 'a>> {
744 let description = input.log_description();
745 let level = self.level;
746
747 Box::pin(async move {
748 let start = std::time::Instant::now();
749
750 if level == LogLevel::Trace {
751 tracing::debug!(description = %description, "operation starting");
752 }
753
754 let result = next.run(input).await;
755 let duration = start.elapsed();
756 let success = result.is_success();
757
758 match level {
759 LogLevel::Info => {
760 tracing::info!(
761 duration_ms = duration.as_millis() as u64,
762 "operation completed"
763 );
764 }
765 LogLevel::Debug | LogLevel::Trace => {
766 tracing::debug!(
767 duration_ms = duration.as_millis() as u64,
768 success,
769 "operation completed"
770 );
771 }
772 }
773
774 result
775 })
776 }
777 }
778}
779
780#[cfg(feature = "tracing")]
782pub use interceptors::{LogLevel, Logging};
783pub use interceptors::{NoOp, Retry, Timeout};
784
785pub mod tool_interceptors {
791 use super::{
792 Interceptor, Next,
793 domain::{ToolExec, ToolRequest, ToolResponse},
794 };
795 use serde_json::Value;
796 use std::future::Future;
797 use std::pin::Pin;
798
799 #[derive(Debug, Clone)]
801 pub enum ApprovalDecision {
802 Allow,
804 Deny(String),
806 Modify(Value),
808 }
809
810 pub struct Approval<F> {
832 check: F,
833 }
834
835 impl<F> Approval<F>
836 where
837 F: Fn(&ToolRequest) -> ApprovalDecision + Send + Sync,
838 {
839 pub fn new(check: F) -> Self {
841 Self { check }
842 }
843 }
844
845 impl<F> std::fmt::Debug for Approval<F> {
846 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
847 f.debug_struct("Approval").finish_non_exhaustive()
848 }
849 }
850
851 impl<Ctx, F> Interceptor<ToolExec<Ctx>> for Approval<F>
852 where
853 Ctx: Send + Sync + 'static,
854 F: Fn(&ToolRequest) -> ApprovalDecision + Send + Sync,
855 {
856 fn intercept<'a>(
857 &'a self,
858 input: &'a ToolRequest,
859 next: Next<'a, ToolExec<Ctx>>,
860 ) -> Pin<Box<dyn Future<Output = ToolResponse> + Send + 'a>> {
861 Box::pin(async move {
862 match (self.check)(input) {
863 ApprovalDecision::Allow => next.run(input).await,
864 ApprovalDecision::Deny(reason) => ToolResponse {
865 content: reason,
866 is_error: true,
867 },
868 ApprovalDecision::Modify(new_args) => {
869 let modified = ToolRequest {
870 name: input.name.clone(),
871 call_id: input.call_id.clone(),
872 arguments: new_args,
873 };
874 next.run(&modified).await
875 }
876 }
877 })
878 }
879 }
880}
881
882pub use tool_interceptors::{Approval, ApprovalDecision};
883
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// Operation kind used by these tests: string in, string-or-error out.
    struct TestOp;

    impl Interceptable for TestOp {
        type Input = String;
        type Output = Result<String, String>;
    }

    // Every error is retryable so the `Retry` interceptor can be exercised.
    impl behavior::Retryable for Result<String, String> {
        fn should_retry(&self) -> bool {
            self.is_err()
        }
    }

    impl behavior::Timeoutable for Result<String, String> {
        fn timeout_error(duration: Duration) -> Self {
            Err(format!("timeout after {duration:?}"))
        }
    }

    /// Terminal operation that echoes its input.
    struct EchoOp;

    impl Operation<TestOp> for EchoOp {
        fn execute<'a>(
            &'a self,
            input: &'a String,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            Box::pin(async move { Ok(format!("echo: {input}")) })
        }
    }

    /// Terminal operation that fails `max_failures` times and then succeeds.
    ///
    /// `failures` is incremented on every invocation, so it also counts calls.
    struct FailOp {
        failures: AtomicU32,
        max_failures: u32,
    }

    impl FailOp {
        fn new(max_failures: u32) -> Self {
            Self {
                failures: AtomicU32::new(0),
                max_failures,
            }
        }

        /// Number of times the operation has been executed.
        fn calls(&self) -> u32 {
            self.failures.load(Ordering::SeqCst)
        }
    }

    impl Operation<TestOp> for FailOp {
        fn execute<'a>(
            &'a self,
            input: &'a String,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            Box::pin(async move {
                let count = self.failures.fetch_add(1, Ordering::SeqCst);
                if count < self.max_failures {
                    let failure_num = count + 1;
                    Err(format!("failure {failure_num}"))
                } else {
                    Ok(format!("success after {count} failures: {input}"))
                }
            })
        }
    }

    #[tokio::test]
    async fn empty_stack_passthrough() {
        let stack = InterceptorStack::<TestOp>::new();
        let input = "hello".to_string();
        let result = stack.execute(&input, &EchoOp).await;
        assert_eq!(result, Ok("echo: hello".to_string()));
    }

    #[tokio::test]
    async fn noop_interceptor_passthrough() {
        let stack = InterceptorStack::<TestOp>::new().with(NoOp);
        let input = "test".to_string();
        let result = stack.execute(&input, &EchoOp).await;
        assert_eq!(result, Ok("echo: test".to_string()));
    }

    #[tokio::test]
    async fn multiple_noop_interceptors() {
        let stack = InterceptorStack::<TestOp>::new()
            .with(NoOp)
            .with(NoOp)
            .with(NoOp);
        let input = "multi".to_string();
        let result = stack.execute(&input, &EchoOp).await;
        assert_eq!(result, Ok("echo: multi".to_string()));
    }

    #[tokio::test]
    async fn retry_succeeds_after_failures() {
        let stack = InterceptorStack::<TestOp>::new().with(Retry::new(3, Duration::from_millis(1)));

        // Fails twice, then succeeds on the third and final allowed attempt.
        let op = FailOp::new(2);
        let input = "retry-test".to_string();
        let result = stack.execute(&input, &op).await;

        assert_eq!(
            result,
            Ok("success after 2 failures: retry-test".to_string())
        );
        assert_eq!(op.calls(), 3);
    }

    #[tokio::test]
    async fn retry_stops_on_first_success() {
        let stack = InterceptorStack::<TestOp>::new().with(Retry::new(5, Duration::from_millis(1)));

        let op = FailOp::new(0);
        let input = "once".to_string();
        let result = stack.execute(&input, &op).await;

        assert_eq!(result, Ok("success after 0 failures: once".to_string()));
        assert_eq!(op.calls(), 1, "a successful result must not be retried");
    }

    #[tokio::test]
    async fn retry_exhausted() {
        let stack = InterceptorStack::<TestOp>::new().with(Retry::new(2, Duration::from_millis(1)));

        // Would keep failing far longer than the retry budget allows.
        let op = FailOp::new(10);
        let input = "exhaust".to_string();
        let result = stack.execute(&input, &op).await;

        // The final attempt's error is returned and no extra attempt is made.
        assert_eq!(result, Err("failure 2".to_string()));
        assert_eq!(op.calls(), 2);
    }

    #[tokio::test]
    async fn timeout_success() {
        let stack = InterceptorStack::<TestOp>::new().with(Timeout::new(Duration::from_secs(1)));
        let input = "fast".to_string();
        let result = stack.execute(&input, &EchoOp).await;
        assert_eq!(result, Ok("echo: fast".to_string()));
    }

    #[tokio::test]
    async fn timeout_expires() {
        struct SlowOp;

        impl Operation<TestOp> for SlowOp {
            fn execute<'a>(
                &'a self,
                _input: &'a String,
            ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    Ok("should not reach".to_string())
                })
            }
        }

        let stack = InterceptorStack::<TestOp>::new().with(Timeout::new(Duration::from_millis(10)));
        let input = "slow".to_string();
        let result = stack.execute(&input, &SlowOp).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("timeout"));
    }

    #[tokio::test]
    async fn interceptor_ordering() {
        use std::sync::Mutex;

        struct RecordingInterceptor {
            name: &'static str,
            log: Arc<Mutex<Vec<String>>>,
        }

        impl Interceptor<TestOp> for RecordingInterceptor {
            fn intercept<'a>(
                &'a self,
                input: &'a String,
                next: Next<'a, TestOp>,
            ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
                let name = self.name;
                let log = Arc::clone(&self.log);
                Box::pin(async move {
                    log.lock().unwrap().push(format!("{name}-before"));
                    let result = next.run(input).await;
                    log.lock().unwrap().push(format!("{name}-after"));
                    result
                })
            }
        }

        let log = Arc::new(Mutex::new(Vec::new()));

        let stack = InterceptorStack::<TestOp>::new()
            .with(RecordingInterceptor {
                name: "A",
                log: Arc::clone(&log),
            })
            .with(RecordingInterceptor {
                name: "B",
                log: Arc::clone(&log),
            });

        let input = "order".to_string();
        let _ = stack.execute(&input, &EchoOp).await;

        // The first layer added is the outermost.
        let recorded = log.lock().unwrap().clone();
        assert_eq!(recorded, vec!["A-before", "B-before", "B-after", "A-after"]);
    }

    #[tokio::test]
    async fn short_circuit_interceptor() {
        struct ShortCircuit;

        impl Interceptor<TestOp> for ShortCircuit {
            fn intercept<'a>(
                &'a self,
                _input: &'a String,
                _next: Next<'a, TestOp>,
            ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
                Box::pin(async { Err("short-circuited".to_string()) })
            }
        }

        let stack = InterceptorStack::<TestOp>::new()
            .with(ShortCircuit)
            .with(NoOp);

        let op = FailOp::new(0);
        let input = "blocked".to_string();
        let result = stack.execute(&input, &op).await;

        assert_eq!(result, Err("short-circuited".to_string()));
        assert_eq!(op.calls(), 0, "the operation must not run");
    }

    #[tokio::test]
    async fn execute_with_closure() {
        let stack = InterceptorStack::<TestOp>::new().with(NoOp);

        let input = "closure-test".to_string();
        let result = stack
            .execute_fn(&input, |i| Box::pin(async move { Ok(format!("fn: {i}")) }))
            .await;

        assert_eq!(result, Ok("fn: closure-test".to_string()));
    }

    #[tokio::test]
    async fn next_is_copy() {
        // Because `Next` is `Copy`, an interceptor can drive the rest of the chain twice.
        struct MultiCallInterceptor;

        impl Interceptor<TestOp> for MultiCallInterceptor {
            fn intercept<'a>(
                &'a self,
                input: &'a String,
                next: Next<'a, TestOp>,
            ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
                Box::pin(async move {
                    let _ = next.run(input).await;
                    next.run(input).await
                })
            }
        }

        struct CountingEchoOp {
            calls: AtomicU32,
        }

        impl Operation<TestOp> for CountingEchoOp {
            fn execute<'a>(
                &'a self,
                input: &'a String,
            ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
                Box::pin(async move {
                    self.calls.fetch_add(1, Ordering::SeqCst);
                    Ok(format!("echo: {input}"))
                })
            }
        }

        let op = CountingEchoOp {
            calls: AtomicU32::new(0),
        };
        let stack = InterceptorStack::<TestOp>::new().with(MultiCallInterceptor);
        let input = "copy-test".to_string();
        let result = stack.execute(&input, &op).await;

        assert_eq!(result, Ok("echo: copy-test".to_string()));
        assert_eq!(op.calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn shared_interceptor() {
        let shared: Arc<dyn Interceptor<TestOp>> = Arc::new(NoOp);

        let stack1 = InterceptorStack::<TestOp>::new().with_shared(Arc::clone(&shared));

        let stack2 = InterceptorStack::<TestOp>::new().with_shared(Arc::clone(&shared));

        let input = "shared".to_string();
        let r1 = stack1.execute(&input, &EchoOp).await;
        let r2 = stack2.execute(&input, &EchoOp).await;

        assert_eq!(r1, Ok("echo: shared".to_string()));
        assert_eq!(r2, Ok("echo: shared".to_string()));
    }

    #[test]
    fn stack_len_and_is_empty() {
        let empty: InterceptorStack<TestOp> = InterceptorStack::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let one = InterceptorStack::<TestOp>::new().with(NoOp);
        assert!(!one.is_empty());
        assert_eq!(one.len(), 1);

        let two = InterceptorStack::<TestOp>::new().with(NoOp).with(NoOp);
        assert_eq!(two.len(), 2);
    }

    mod approval_tests {
        use super::*;
        use crate::intercept::domain::{ToolExec, ToolRequest, ToolResponse};
        use crate::intercept::tool_interceptors::{Approval, ApprovalDecision};
        use serde_json::json;

        /// Terminal tool operation that reports the name and arguments it received.
        struct EchoToolOp;

        impl Operation<ToolExec<()>> for EchoToolOp {
            fn execute<'a>(
                &'a self,
                input: &'a ToolRequest,
            ) -> Pin<Box<dyn Future<Output = ToolResponse> + Send + 'a>> {
                Box::pin(async move {
                    ToolResponse {
                        content: format!("executed: {} with {:?}", input.name, input.arguments),
                        is_error: false,
                    }
                })
            }
        }

        #[tokio::test]
        async fn approval_allow() {
            let stack = InterceptorStack::<ToolExec<()>>::new()
                .with(Approval::new(|_| ApprovalDecision::Allow));

            let input = ToolRequest {
                name: "test_tool".into(),
                call_id: "call_1".into(),
                arguments: json!({"x": 1}),
            };

            let result = stack.execute(&input, &EchoToolOp).await;
            assert!(!result.is_error);
            assert!(result.content.contains("test_tool"));
        }

        #[tokio::test]
        async fn approval_deny() {
            let stack = InterceptorStack::<ToolExec<()>>::new().with(Approval::new(|req| {
                if req.name == "dangerous" {
                    ApprovalDecision::Deny("Not allowed".into())
                } else {
                    ApprovalDecision::Allow
                }
            }));

            let input = ToolRequest {
                name: "dangerous".into(),
                call_id: "call_2".into(),
                arguments: json!({}),
            };

            let result = stack.execute(&input, &EchoToolOp).await;
            assert!(result.is_error);
            assert_eq!(result.content, "Not allowed");
        }

        #[tokio::test]
        async fn approval_modify() {
            let stack = InterceptorStack::<ToolExec<()>>::new().with(Approval::new(|req| {
                // Keep the original arguments and add one more.
                let mut args = req.arguments.clone();
                args["modified"] = json!(true);
                ApprovalDecision::Modify(args)
            }));

            let input = ToolRequest {
                name: "my_tool".into(),
                call_id: "call_3".into(),
                arguments: json!({"original": "value"}),
            };

            let result = stack.execute(&input, &EchoToolOp).await;
            assert!(!result.is_error);
            assert!(result.content.contains("modified"));
            assert!(result.content.contains("true"));
            assert!(result.content.contains("original"));
        }

        #[tokio::test]
        async fn approval_debug() {
            let approval = Approval::new(|_: &ToolRequest| ApprovalDecision::Allow);
            let debug_str = format!("{approval:?}");
            assert!(debug_str.contains("Approval"));
        }
    }
}