1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Mutex;
4use std::task::{Context, Poll};
5use std::time::{Duration, Instant};
6
7use tower::Service;
8
9use camel_api::{BoxProcessor, CamelError, Exchange, ThrottleStrategy, ThrottlerConfig, Value};
10
/// Exchange property key that this module sets to `Value::Bool(true)` when the
/// `Drop` strategy discards a throttled exchange.
const CAMEL_STOP: &str = "CamelStop";
12
/// Token-bucket rate limiter.
///
/// The bucket starts full, is refilled continuously at `refill_rate` tokens per
/// second (never beyond `max_tokens`), and every admitted request consumes one
/// whole token.
pub struct RateLimiter {
    /// Tokens currently available; fractional while the bucket is refilling.
    tokens: f64,
    /// Bucket capacity.
    max_tokens: f64,
    /// Tokens added per second.
    refill_rate: f64,
    /// Instant at which `tokens` was last topped up.
    last_refill: Instant,
}

impl RateLimiter {
    /// Creates a full bucket that admits `max_requests` requests per `period`.
    ///
    /// `period` is not validated here; the constructors in this file that call
    /// `new` assert that it is non-zero first.
    fn new(max_requests: usize, period: Duration) -> Self {
        let capacity = max_requests as f64;
        Self {
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: capacity / period.as_secs_f64(),
            last_refill: Instant::now(),
        }
    }

    /// Refills the bucket for the time elapsed since the last refill, then tries
    /// to take one token.
    ///
    /// Returns `true` if a token was consumed and `false` if less than one token
    /// is available.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        if elapsed > 0.0 {
            self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
            self.last_refill = now;
        }
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Estimates how long to wait until one whole token is available, based on
    /// the token count recorded at the last `try_acquire`.
    ///
    /// Returns `Duration::ZERO` when a token is already available and
    /// `Duration::MAX` when the wait cannot be represented, for example when the
    /// refill rate is zero because `max_requests` is 0 and the bucket never
    /// refills.
    fn time_until_next_token(&self) -> Duration {
        if self.tokens >= 1.0 {
            return Duration::ZERO;
        }
        let wait_secs = (1.0 - self.tokens) / self.refill_rate;
        // `Duration::from_secs_f64` panics on NaN, infinite, negative or
        // overflowing input. A zero refill rate yields `inf` here, so convert
        // only values inside the representable range and clamp the rest.
        if (0.0..Duration::MAX.as_secs_f64()).contains(&wait_secs) {
            Duration::from_secs_f64(wait_secs)
        } else {
            Duration::MAX
        }
    }
}
55
56#[derive(Clone)]
57pub struct ThrottlerService {
58 config: ThrottlerConfig,
59 limiter: std::sync::Arc<Mutex<RateLimiter>>,
60 next: BoxProcessor,
61}
62
63impl ThrottlerService {
64 pub fn new(config: ThrottlerConfig, next: BoxProcessor) -> Self {
65 assert!(
66 config.period > Duration::ZERO,
67 "throttler period must be > 0"
68 );
69 let limiter = RateLimiter::new(config.max_requests, config.period);
70 Self {
71 config,
72 limiter: std::sync::Arc::new(Mutex::new(limiter)),
73 next,
74 }
75 }
76}
77
78impl Service<Exchange> for ThrottlerService {
79 type Response = Exchange;
80 type Error = CamelError;
81 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
82
83 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84 self.next.poll_ready(cx)
85 }
86
87 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
88 let config = self.config.clone();
89 let limiter = self.limiter.clone();
90 let mut next = self.next.clone();
91
92 Box::pin(async move {
93 let acquired = {
94 let mut limiter = limiter.lock().unwrap(); limiter.try_acquire()
96 };
97
98 if acquired {
99 next.call(exchange).await
100 } else {
101 match config.strategy {
102 ThrottleStrategy::Delay => {
103 loop {
104 let wait_time = {
105 let limiter = limiter.lock().unwrap(); limiter.time_until_next_token()
107 };
108 if wait_time > Duration::ZERO {
109 tokio::time::sleep(wait_time).await;
110 }
111 let acquired = {
112 let mut limiter = limiter.lock().unwrap(); limiter.try_acquire()
114 };
115 if acquired {
116 break;
117 }
118 tokio::task::yield_now().await;
121 }
122 next.call(exchange).await
123 }
124 ThrottleStrategy::Reject => Err(CamelError::ProcessorError(
125 "Throttled: rate limit exceeded".to_string(),
126 )),
127 ThrottleStrategy::Drop => {
128 exchange.set_property(CAMEL_STOP, Value::Bool(true));
129 Ok(exchange)
130 }
131 }
132 }
133 })
134 }
135}
136
137pub struct ThrottleSegment {
144 pub config: ThrottlerConfig,
145 pub limiter: std::sync::Arc<std::sync::Mutex<RateLimiter>>,
146 pub body: camel_api::OutcomeSegment,
147}
148
149impl ThrottleSegment {
150 pub fn new(config: ThrottlerConfig, body: camel_api::OutcomeSegment) -> Self {
151 assert!(
152 config.period > Duration::ZERO,
153 "throttler period must be > 0"
154 );
155 Self {
156 limiter: std::sync::Arc::new(std::sync::Mutex::new(RateLimiter::new(
157 config.max_requests,
158 config.period,
159 ))),
160 config,
161 body,
162 }
163 }
164}
165
166impl Clone for ThrottleSegment {
167 fn clone(&self) -> Self {
168 Self {
169 config: self.config.clone(),
170 limiter: std::sync::Arc::clone(&self.limiter),
171 body: self.body.clone(),
172 }
173 }
174}
175
176impl camel_api::OutcomePipeline for ThrottleSegment {
177 fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
178 Box::new(self.clone())
179 }
180
181 fn run<'a>(
182 &'a mut self,
183 exchange: camel_api::Exchange,
184 ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
185 Box::pin(async move {
186 let acquired = {
187 let mut limiter = self.limiter.lock().unwrap(); limiter.try_acquire()
189 };
190 if acquired {
191 return self.body.run(exchange).await;
192 }
193 match self.config.strategy {
194 ThrottleStrategy::Delay => {
195 loop {
196 let wait_time = {
197 let limiter = self.limiter.lock().unwrap(); limiter.time_until_next_token()
199 };
200 if wait_time > Duration::ZERO {
201 tokio::time::sleep(wait_time).await;
202 }
203 let acquired = {
204 let mut limiter = self.limiter.lock().unwrap(); limiter.try_acquire()
206 };
207 if acquired {
208 break;
209 }
210 tokio::task::yield_now().await;
211 }
212 self.body.run(exchange).await
213 }
214 ThrottleStrategy::Reject => {
215 camel_api::PipelineOutcome::Failed(camel_api::CamelError::ProcessorError(
216 "Throttled: rate limit exceeded".to_string(),
217 ))
218 }
219 ThrottleStrategy::Drop => {
220 let mut ex = exchange;
221 ex.set_property(CAMEL_STOP, camel_api::Value::Bool(true));
222 camel_api::PipelineOutcome::Completed(ex)
223 }
224 }
225 })
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use tower::ServiceExt;

    /// Downstream processor that returns every exchange unchanged.
    fn passthrough() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Waits for `svc` to become ready, then sends one exchange carrying `text`.
    async fn send(
        svc: &mut ThrottlerService,
        text: &'static str,
    ) -> Result<Exchange, CamelError> {
        let exchange = Exchange::new(Message::new(text));
        svc.ready().await.unwrap().call(exchange).await
    }

    /// Config with a one-second period, as used by the segment tests.
    fn one_second_config(max_requests: usize, strategy: ThrottleStrategy) -> ThrottlerConfig {
        ThrottlerConfig {
            max_requests,
            period: Duration::from_secs(1),
            strategy,
        }
    }

    /// Segment body that completes with the exchange it was given.
    #[derive(Clone)]
    struct NoopSeg;

    impl camel_api::OutcomePipeline for NoopSeg {
        fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
            Box::new(NoopSeg)
        }

        fn run<'a>(
            &'a mut self,
            ex: Exchange,
        ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
            Box::pin(async move { camel_api::PipelineOutcome::Completed(ex) })
        }
    }

    #[test]
    fn test_throttler_zero_period_rejected() {
        let config = ThrottlerConfig::new(5, Duration::ZERO);
        let outcome = std::panic::catch_unwind(|| {
            ThrottlerService::new(config, passthrough());
        });
        assert!(outcome.is_err(), "zero period should panic");
    }

    #[tokio::test]
    async fn test_throttler_allows_under_limit() {
        let config = ThrottlerConfig::new(5, Duration::from_secs(1));
        let mut svc = ThrottlerService::new(config, passthrough());

        for _ in 0..5 {
            let result = send(&mut svc, "test").await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_throttler_delay_strategy_queues_message() {
        let config = ThrottlerConfig::new(1, Duration::from_millis(100));
        let mut svc = ThrottlerService::new(config, passthrough());

        let first = send(&mut svc, "first").await;
        assert!(first.is_ok());

        // The bucket is now empty, so the second exchange has to wait for a refill.
        let start = Instant::now();
        let second = send(&mut svc, "second").await;
        let elapsed = start.elapsed();
        assert!(second.is_ok());
        assert!(elapsed >= Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_throttler_reject_strategy_returns_error() {
        let config =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Reject);
        let mut svc = ThrottlerService::new(config, passthrough());

        let _ = send(&mut svc, "first").await;

        let result = send(&mut svc, "second").await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Throttled"));
    }

    #[tokio::test]
    async fn test_throttler_drop_strategy_sets_camel_stop() {
        let config =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Drop);
        let mut svc = ThrottlerService::new(config, passthrough());

        let _ = send(&mut svc, "first").await;

        let dropped = send(&mut svc, "second").await.unwrap();
        assert_eq!(dropped.property(CAMEL_STOP), Some(&Value::Bool(true)));
    }

    #[tokio::test]
    async fn test_throttler_token_replenishment() {
        let config = ThrottlerConfig::new(1, Duration::from_millis(50));
        let mut svc = ThrottlerService::new(config, passthrough());

        let _ = send(&mut svc, "first").await;

        // Sleep for longer than one period so the bucket refills.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let result = send(&mut svc, "second").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn throttle_segment_reject_strategy_returns_failed() {
        use camel_api::{OutcomePipeline, PipelineOutcome};

        // Zero capacity: the very first exchange is already over the limit.
        let config = one_second_config(0, ThrottleStrategy::Reject);
        let body = camel_api::OutcomeSegment::new(Box::new(NoopSeg));
        let mut seg = ThrottleSegment::new(config, body);

        let outcome = seg.run(Exchange::new(Message::new("test"))).await;
        assert!(
            matches!(outcome, PipelineOutcome::Failed(_)),
            "Reject strategy must return Failed when tokens exhausted"
        );
    }

    #[tokio::test]
    async fn throttle_segment_drop_strategy_sets_camel_stop_and_completes() {
        use camel_api::{OutcomePipeline, PipelineOutcome};

        // Zero capacity: the very first exchange is already over the limit.
        let config = one_second_config(0, ThrottleStrategy::Drop);
        let body = camel_api::OutcomeSegment::new(Box::new(NoopSeg));
        let mut seg = ThrottleSegment::new(config, body);

        let outcome = seg.run(Exchange::new(Message::new("test"))).await;
        let returned_ex = match outcome {
            PipelineOutcome::Completed(ex) => ex,
            other => panic!("Drop must return Completed, got {:?}", other),
        };
        let stopped_flag = returned_ex.property(CAMEL_STOP).and_then(|v| v.as_bool());
        assert_eq!(
            stopped_flag,
            Some(true),
            "Drop strategy must set CamelStop=true property"
        );
    }

    #[tokio::test]
    async fn throttle_segment_delay_strategy_propagates_stopped_body() {
        use camel_api::{Body, OutcomePipeline, PipelineOutcome};

        /// Segment body that rewrites the input body and then stops the pipeline.
        #[derive(Clone)]
        struct StoppingSeg;

        impl OutcomePipeline for StoppingSeg {
            fn clone_box(&self) -> Box<dyn OutcomePipeline> {
                Box::new(StoppingSeg)
            }

            fn run<'a>(
                &'a mut self,
                mut ex: Exchange,
            ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
                Box::pin(async move {
                    ex.input.body = Body::Bytes(b"stopped-mut".to_vec().into());
                    PipelineOutcome::Stopped(ex)
                })
            }
        }

        let config = one_second_config(1, ThrottleStrategy::Delay);
        let body = camel_api::OutcomeSegment::new(Box::new(StoppingSeg));
        let mut seg = ThrottleSegment::new(config, body);

        let outcome = seg.run(Exchange::new(Message::new("test"))).await;
        let returned_ex = match outcome {
            PipelineOutcome::Stopped(ex) => ex,
            other => panic!("expected Stopped propagation, got {:?}", other),
        };
        match &returned_ex.input.body {
            Body::Bytes(b) => assert_eq!(
                b.as_ref(),
                b"stopped-mut",
                "BUG: throttle body Stop must preserve mutations"
            ),
            _ => panic!("expected Body::Bytes"),
        }
    }
}