camel_processor/
wire_tap.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use tokio::sync::Semaphore;
7use tower::{Service, ServiceExt};
8
9use camel_api::{CamelError, Exchange};
10
/// Configuration for a wire tap.
///
/// The default configuration is unbounded: no limit is placed on how many tap
/// invocations may run at the same time.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct WireTapConfig {
    /// Maximum number of tap invocations allowed to run concurrently.
    ///
    /// `None` (the default) means unbounded.
    pub max_concurrent: Option<usize>,
}

impl WireTapConfig {
    /// Creates a configuration that allows at most `max_concurrent` tap
    /// invocations to run concurrently.
    ///
    /// # Panics
    ///
    /// Panics if `max_concurrent` is zero: a limiter with zero permits could
    /// never let a tap run.
    #[must_use]
    pub fn bounded(max_concurrent: usize) -> Self {
        assert!(max_concurrent > 0, "max_concurrent must be > 0");
        Self {
            max_concurrent: Some(max_concurrent),
        }
    }
}
27
28#[derive(Clone)]
29pub struct WireTapService {
30 tap_endpoint: camel_api::BoxProcessor,
31 semaphore: Option<Arc<Semaphore>>,
32}
33
34impl WireTapService {
35 pub fn new(tap_endpoint: camel_api::BoxProcessor) -> Self {
37 Self {
38 tap_endpoint,
39 semaphore: None,
40 }
41 }
42
43 pub fn with_config(tap_endpoint: camel_api::BoxProcessor, config: WireTapConfig) -> Self {
45 let semaphore = config
46 .max_concurrent
47 .map(|limit| Arc::new(Semaphore::new(limit)));
48 Self {
49 tap_endpoint,
50 semaphore,
51 }
52 }
53}
54
55impl Service<Exchange> for WireTapService {
56 type Response = Exchange;
57 type Error = CamelError;
58 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
59
60 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
61 self.tap_endpoint.poll_ready(cx)
62 }
63
64 fn call(&mut self, exchange: Exchange) -> Self::Future {
65 let mut tap_endpoint = self.tap_endpoint.clone();
66 let tap_exchange = exchange.clone();
67 let semaphore = self.semaphore.clone();
68
69 tokio::spawn(async move {
70 let _permit = match &semaphore {
72 Some(sem) => match sem.acquire().await {
73 Ok(p) => Some(p),
74 Err(_) => {
75 tracing::warn!("WireTap semaphore closed, dropping tap");
76 return;
77 }
78 },
79 None => None,
80 };
81
82 if let Err(e) = tap_endpoint.ready().await {
83 tracing::warn!("WireTap endpoint poll_ready failed: {}", e);
84 return;
85 }
86 if let Err(e) = tap_endpoint.call(tap_exchange).await {
87 tracing::error!("WireTap processing error: {}", e);
88 }
89 });
90
91 Box::pin(async move { Ok(exchange) })
92 }
93}
94
95pub struct WireTapLayer {
97 tap_endpoint: camel_api::BoxProcessor,
98 config: WireTapConfig,
99}
100
101impl WireTapLayer {
102 pub fn new(tap_endpoint: camel_api::BoxProcessor) -> Self {
104 Self {
105 tap_endpoint,
106 config: WireTapConfig::default(),
107 }
108 }
109
110 pub fn bounded(tap_endpoint: camel_api::BoxProcessor, max_concurrent: usize) -> Self {
112 Self {
113 tap_endpoint,
114 config: WireTapConfig::bounded(max_concurrent),
115 }
116 }
117}
118
119impl<S> tower::Layer<S> for WireTapLayer {
120 type Service = WireTapService;
121
122 fn layer(&self, _inner: S) -> Self::Service {
123 WireTapService::with_config(self.tap_endpoint.clone(), self.config.clone())
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessor, BoxProcessorExt, Message};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tokio::sync::Notify;
    use tower::ServiceExt;

    /// Upper bound on how long a test waits for detached tap tasks. It is
    /// generous so slow CI does not flake; passing runs return as soon as the
    /// taps signal completion.
    const TAP_TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn test_wire_tap_returns_original_immediately() {
        let tap_processor = BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }));

        let mut wire_tap = WireTapService::new(tap_processor);
        let exchange = Exchange::new(Message::new("test message"));

        let result = wire_tap
            .ready()
            .await
            .unwrap()
            .call(exchange)
            .await
            .unwrap();

        assert_eq!(result.input.body.as_text(), Some("test message"));
    }

    #[tokio::test]
    async fn test_wire_tap_endpoint_receives_clone() {
        let received_count = Arc::new(AtomicUsize::new(0));
        let seen_body: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let tapped = Arc::new(Notify::new());

        let count_clone = Arc::clone(&received_count);
        let body_clone = Arc::clone(&seen_body);
        let tapped_clone = Arc::clone(&tapped);
        let tap_processor = BoxProcessor::from_fn(move |ex: Exchange| {
            let count = Arc::clone(&count_clone);
            let body = Arc::clone(&body_clone);
            let tapped = Arc::clone(&tapped_clone);
            Box::pin(async move {
                *body.lock().unwrap() = ex.input.body.as_text().map(str::to_owned);
                count.fetch_add(1, Ordering::SeqCst);
                // `notify_one` stores a permit if nobody is waiting yet, so the
                // test cannot miss this signal.
                tapped.notify_one();
                Ok(ex)
            })
        });

        let mut wire_tap = WireTapService::new(tap_processor);
        let exchange = Exchange::new(Message::new("test"));

        let _result = wire_tap
            .ready()
            .await
            .unwrap()
            .call(exchange)
            .await
            .unwrap();

        tokio::time::timeout(TAP_TIMEOUT, tapped.notified())
            .await
            .expect("tap endpoint was not invoked");

        assert_eq!(received_count.load(Ordering::SeqCst), 1);
        assert_eq!(seen_body.lock().unwrap().as_deref(), Some("test"));
    }

    #[tokio::test]
    async fn test_wire_tap_isolates_errors() {
        let tap_processor = BoxProcessor::from_fn(|_ex| {
            Box::pin(async move { Err(CamelError::ProcessorError("tap error".into())) })
        });

        let mut wire_tap = WireTapService::new(tap_processor);
        let exchange = Exchange::new(Message::new("test"));

        let result = wire_tap.ready().await.unwrap().call(exchange).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().input.body.as_text(), Some("test"));
    }

    #[tokio::test]
    async fn test_wire_tap_layer() {
        use tower::Layer;

        let tap_processor = BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }));

        let layer = super::WireTapLayer::new(tap_processor);
        let inner = camel_api::IdentityProcessor;
        let mut svc = layer.layer(inner);

        let exchange = Exchange::new(Message::new("test"));
        let result = svc.ready().await.unwrap().call(exchange).await.unwrap();

        assert_eq!(result.input.body.as_text(), Some("test"));
    }

    #[tokio::test]
    async fn test_wiretap_bounded_concurrency() {
        const TAPS: usize = 3;
        const LIMIT: usize = 2;

        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let completed = Arc::new(AtomicUsize::new(0));
        let all_done = Arc::new(Notify::new());

        let c = Arc::clone(&concurrent);
        let mc = Arc::clone(&max_concurrent);
        let done = Arc::clone(&completed);
        let notify = Arc::clone(&all_done);
        let tap_processor = BoxProcessor::from_fn(move |ex| {
            let c = Arc::clone(&c);
            let mc = Arc::clone(&mc);
            let done = Arc::clone(&done);
            let notify = Arc::clone(&notify);
            Box::pin(async move {
                let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                mc.fetch_max(current, Ordering::SeqCst);
                // Keep each tap busy long enough for the others to overlap.
                tokio::time::sleep(Duration::from_millis(50)).await;
                c.fetch_sub(1, Ordering::SeqCst);
                if done.fetch_add(1, Ordering::SeqCst) + 1 == TAPS {
                    notify.notify_one();
                }
                Ok(ex)
            })
        });

        let config = super::WireTapConfig::bounded(LIMIT);
        let mut svc = super::WireTapService::with_config(tap_processor, config);

        for _ in 0..TAPS {
            let ex = Exchange::new(Message::new("test"));
            let _ = svc.ready().await.unwrap().call(ex).await.unwrap();
        }

        tokio::time::timeout(TAP_TIMEOUT, all_done.notified())
            .await
            .expect("not every tap completed");

        // Without this, the bound check below would pass trivially if no tap ran.
        assert_eq!(completed.load(Ordering::SeqCst), TAPS);

        let observed_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            observed_max <= LIMIT,
            "max concurrency was {observed_max}, expected <= {LIMIT}"
        );
    }
}