1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Mutex;
4use std::task::{Context, Poll};
5use std::time::{Duration, Instant};
6
7use tower::Service;
8
9use camel_api::CAMEL_STOP;
10use camel_api::{BoxProcessor, CamelError, Exchange, ThrottleStrategy, ThrottlerConfig, Value};
11
/// Token-bucket rate limiter used by the throttling processors in this module.
///
/// The bucket holds at most `max_tokens` tokens, refills continuously at
/// `refill_rate` tokens per second, and each admitted request consumes one
/// whole token.
#[derive(Debug)]
pub struct RateLimiter {
    /// Tokens available as of `last_refill` (fractional while refilling).
    tokens: f64,
    /// Bucket capacity; equal to the configured `max_requests`.
    max_tokens: f64,
    /// Refill speed in tokens per second (`max_requests / period`).
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl RateLimiter {
    /// Creates a full bucket that admits `max_requests` requests per `period`.
    ///
    /// Callers are expected to reject a zero `period`: it would make the
    /// refill rate infinite (or NaN when `max_requests` is also zero).
    fn new(max_requests: usize, period: Duration) -> Self {
        let capacity = max_requests as f64;
        Self {
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: capacity / period.as_secs_f64(),
            last_refill: Instant::now(),
        }
    }

    /// Returns the token count the bucket would hold at `now`, capped at capacity.
    ///
    /// Zero elapsed time short-circuits so that an infinite refill rate never
    /// produces `0 * inf = NaN`.
    fn tokens_at(&self, now: Instant) -> f64 {
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        if elapsed > 0.0 {
            (self.tokens + elapsed * self.refill_rate).min(self.max_tokens)
        } else {
            self.tokens
        }
    }

    /// Refills the bucket for the time elapsed since the last refill and, if a
    /// whole token is available, consumes it.
    ///
    /// Returns `true` when the caller may proceed.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        if now > self.last_refill {
            self.tokens = self.tokens_at(now);
            self.last_refill = now;
        }
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Estimates how long until a whole token becomes available.
    ///
    /// Time elapsed since the last refill is credited, so the estimate is not
    /// inflated when the bucket has not been touched recently.
    ///
    /// Returns [`Duration::ZERO`] if a token is already available. Returns
    /// [`Duration::MAX`] if no token will ever arrive (zero refill rate, i.e.
    /// `max_requests == 0`) or the wait is too long to represent.
    fn time_until_next_token(&self) -> Duration {
        let available = self.tokens_at(Instant::now());
        if available >= 1.0 {
            return Duration::ZERO;
        }
        // `Duration::from_secs_f64` panics on infinite or overflowing input.
        // Callers invoke this while holding the limiter's mutex, so such a
        // panic would also poison the lock; saturate instead.
        Duration::try_from_secs_f64((1.0 - available) / self.refill_rate)
            .unwrap_or(Duration::MAX)
    }
}
54
55#[derive(Clone)]
56pub struct ThrottlerService {
57 config: ThrottlerConfig,
58 limiter: std::sync::Arc<Mutex<RateLimiter>>,
59 next: BoxProcessor,
60}
61
62impl ThrottlerService {
63 pub fn new(config: ThrottlerConfig, next: BoxProcessor) -> Self {
64 assert!(
65 config.period > Duration::ZERO,
66 "throttler period must be > 0"
67 );
68 let limiter = RateLimiter::new(config.max_requests, config.period);
69 Self {
70 config,
71 limiter: std::sync::Arc::new(Mutex::new(limiter)),
72 next,
73 }
74 }
75}
76
77impl Service<Exchange> for ThrottlerService {
78 type Response = Exchange;
79 type Error = CamelError;
80 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
81
82 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
83 self.next.poll_ready(cx)
84 }
85
86 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
87 let config = self.config.clone();
88 let limiter = self.limiter.clone();
89 let mut next = self.next.clone();
90
91 Box::pin(async move {
92 let acquired = {
93 let mut limiter = limiter.lock().unwrap(); limiter.try_acquire()
95 };
96
97 if acquired {
98 next.call(exchange).await
99 } else {
100 match config.strategy {
101 ThrottleStrategy::Delay => {
102 loop {
103 let wait_time = {
104 let limiter = limiter.lock().unwrap(); limiter.time_until_next_token()
106 };
107 if wait_time > Duration::ZERO {
108 tokio::time::sleep(wait_time).await;
109 }
110 let acquired = {
111 let mut limiter = limiter.lock().unwrap(); limiter.try_acquire()
113 };
114 if acquired {
115 break;
116 }
117 tokio::task::yield_now().await;
120 }
121 next.call(exchange).await
122 }
123 ThrottleStrategy::Reject => Err(CamelError::ProcessorError(
124 "Throttled: rate limit exceeded".to_string(),
125 )),
126 ThrottleStrategy::Drop => {
127 exchange.set_property(CAMEL_STOP, Value::Bool(true));
128 Ok(exchange)
129 }
130 }
131 }
132 })
133 }
134}
135
136pub struct ThrottleSegment {
143 pub config: ThrottlerConfig,
144 pub limiter: std::sync::Arc<std::sync::Mutex<RateLimiter>>,
145 pub body: camel_api::OutcomeSegment,
146}
147
148impl ThrottleSegment {
149 pub fn new(config: ThrottlerConfig, body: camel_api::OutcomeSegment) -> Self {
150 assert!(
151 config.period > Duration::ZERO,
152 "throttler period must be > 0"
153 );
154 Self {
155 limiter: std::sync::Arc::new(std::sync::Mutex::new(RateLimiter::new(
156 config.max_requests,
157 config.period,
158 ))),
159 config,
160 body,
161 }
162 }
163}
164
165impl Clone for ThrottleSegment {
166 fn clone(&self) -> Self {
167 Self {
168 config: self.config.clone(),
169 limiter: std::sync::Arc::clone(&self.limiter),
170 body: self.body.clone(),
171 }
172 }
173}
174
175impl camel_api::OutcomePipeline for ThrottleSegment {
176 fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
177 Box::new(self.clone())
178 }
179
180 fn run<'a>(
181 &'a mut self,
182 exchange: camel_api::Exchange,
183 ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
184 Box::pin(async move {
185 let acquired = {
186 let mut limiter = self.limiter.lock().unwrap(); limiter.try_acquire()
188 };
189 if acquired {
190 return self.body.run(exchange).await;
191 }
192 match self.config.strategy {
193 ThrottleStrategy::Delay => {
194 loop {
195 let wait_time = {
196 let limiter = self.limiter.lock().unwrap(); limiter.time_until_next_token()
198 };
199 if wait_time > Duration::ZERO {
200 tokio::time::sleep(wait_time).await;
201 }
202 let acquired = {
203 let mut limiter = self.limiter.lock().unwrap(); limiter.try_acquire()
205 };
206 if acquired {
207 break;
208 }
209 tokio::task::yield_now().await;
210 }
211 self.body.run(exchange).await
212 }
213 ThrottleStrategy::Reject => {
214 camel_api::PipelineOutcome::Failed(camel_api::CamelError::ProcessorError(
215 "Throttled: rate limit exceeded".to_string(),
216 ))
217 }
218 ThrottleStrategy::Drop => {
219 let mut ex = exchange;
220 ex.set_property(CAMEL_STOP, camel_api::Value::Bool(true));
221 camel_api::PipelineOutcome::Stopped(ex)
222 }
223 }
224 })
225 }
226}
227
// Unit tests for `ThrottlerService` (tower middleware) and `ThrottleSegment`
// (outcome-pipeline step). Several tests depend on wall-clock timing.
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use tower::ServiceExt;

    /// Downstream processor that returns the exchange unchanged.
    fn passthrough() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// `ThrottlerService::new` must panic (via its `assert!`) on a zero period.
    #[test]
    fn test_throttler_zero_period_rejected() {
        let config = ThrottlerConfig::new(5, Duration::ZERO);
        let result = std::panic::catch_unwind(|| {
            ThrottlerService::new(config, passthrough());
        });
        assert!(result.is_err(), "zero period should panic");
    }

    /// The bucket starts full, so `max_requests` (5) back-to-back calls all pass.
    #[tokio::test]
    async fn test_throttler_allows_under_limit() {
        let config = ThrottlerConfig::new(5, Duration::from_secs(1));
        let mut svc = ThrottlerService::new(config, passthrough());

        for _ in 0..5 {
            let ex = Exchange::new(Message::new("test"));
            let result = svc.ready().await.unwrap().call(ex).await;
            assert!(result.is_ok());
        }
    }

    /// At 1 request per 100 ms, the second call has to wait for a refill.
    ///
    /// NOTE(review): assumes `ThrottlerConfig::new` defaults to
    /// `ThrottleStrategy::Delay` (not visible here) - confirm. The assertion
    /// only requires 50 ms of the ~100 ms refill, presumably to absorb timer
    /// jitter.
    #[tokio::test]
    async fn test_throttler_delay_strategy_queues_message() {
        let config = ThrottlerConfig::new(1, Duration::from_millis(100));
        let mut svc = ThrottlerService::new(config, passthrough());

        // Consumes the only token.
        let ex1 = Exchange::new(Message::new("first"));
        let result1 = svc.ready().await.unwrap().call(ex1).await;
        assert!(result1.is_ok());

        let start = Instant::now();
        let ex2 = Exchange::new(Message::new("second"));
        let result2 = svc.ready().await.unwrap().call(ex2).await;
        let elapsed = start.elapsed();
        assert!(result2.is_ok());
        assert!(elapsed >= Duration::from_millis(50));
    }

    /// Reject: at 1 request per 10 s, the second immediate call must fail with
    /// an error whose text contains "Throttled".
    #[tokio::test]
    async fn test_throttler_reject_strategy_returns_error() {
        let config =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Reject);
        let mut svc = ThrottlerService::new(config, passthrough());

        // Consumes the only token.
        let ex1 = Exchange::new(Message::new("first"));
        let _ = svc.ready().await.unwrap().call(ex1).await;

        let ex2 = Exchange::new(Message::new("second"));
        let result = svc.ready().await.unwrap().call(ex2).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Throttled"));
    }

    /// Drop: at 1 request per 10 s, the second immediate call succeeds but
    /// comes back with `CAMEL_STOP = true`.
    #[tokio::test]
    async fn test_throttler_drop_strategy_sets_camel_stop() {
        let config =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Drop);
        let mut svc = ThrottlerService::new(config, passthrough());

        // Consumes the only token.
        let ex1 = Exchange::new(Message::new("first"));
        let _ = svc.ready().await.unwrap().call(ex1).await;

        let ex2 = Exchange::new(Message::new("second"));
        let result = svc.ready().await.unwrap().call(ex2).await.unwrap();
        assert_eq!(result.property(CAMEL_STOP), Some(&Value::Bool(true)));
    }

    /// At 1 request per 50 ms, sleeping 100 ms refills the bucket, so the
    /// second call succeeds.
    #[tokio::test]
    async fn test_throttler_token_replenishment() {
        let config = ThrottlerConfig::new(1, Duration::from_millis(50));
        let mut svc = ThrottlerService::new(config, passthrough());

        let ex1 = Exchange::new(Message::new("first"));
        let _ = svc.ready().await.unwrap().call(ex1).await;

        tokio::time::sleep(Duration::from_millis(100)).await;

        let ex2 = Exchange::new(Message::new("second"));
        let result = svc.ready().await.unwrap().call(ex2).await;
        assert!(result.is_ok());
    }

    /// `max_requests: 0` starts the bucket empty, so the very first run hits
    /// the Reject strategy and must yield `PipelineOutcome::Failed`.
    #[tokio::test]
    async fn throttle_segment_reject_strategy_returns_failed() {
        use camel_api::{Exchange, Message, OutcomePipeline, PipelineOutcome};

        // Body that completes with the exchange unchanged.
        #[derive(Clone)]
        struct NoopSeg;
        impl OutcomePipeline for NoopSeg {
            fn clone_box(&self) -> Box<dyn OutcomePipeline> {
                Box::new(NoopSeg)
            }
            fn run<'a>(
                &'a mut self,
                ex: Exchange,
            ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
                Box::pin(async move { PipelineOutcome::Completed(ex) })
            }
        }

        let config = ThrottlerConfig {
            max_requests: 0,
            period: Duration::from_secs(1),
            strategy: ThrottleStrategy::Reject,
        };
        let body = camel_api::OutcomeSegment::new(Box::new(NoopSeg));
        let mut seg = ThrottleSegment::new(config, body);
        let ex = Exchange::new(Message::new("test"));
        let outcome = seg.run(ex).await;
        assert!(
            matches!(outcome, PipelineOutcome::Failed(_)),
            "Reject strategy must return Failed when tokens exhausted"
        );
    }

    /// `max_requests: 0` starts the bucket empty, so the first run hits the
    /// Drop strategy and must yield `Stopped` with `CAMEL_STOP = true`.
    #[tokio::test]
    async fn throttle_segment_drop_strategy_returns_stopped() {
        use camel_api::{Exchange, Message, OutcomePipeline, PipelineOutcome};

        // Body that completes with the exchange unchanged.
        #[derive(Clone)]
        struct NoopSeg;
        impl OutcomePipeline for NoopSeg {
            fn clone_box(&self) -> Box<dyn OutcomePipeline> {
                Box::new(NoopSeg)
            }
            fn run<'a>(
                &'a mut self,
                ex: Exchange,
            ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
                Box::pin(async move { PipelineOutcome::Completed(ex) })
            }
        }

        let config = ThrottlerConfig {
            max_requests: 0,
            period: Duration::from_secs(1),
            strategy: ThrottleStrategy::Drop,
        };
        let body = camel_api::OutcomeSegment::new(Box::new(NoopSeg));
        let mut seg = ThrottleSegment::new(config, body);
        let ex = Exchange::new(Message::new("test"));
        let outcome = seg.run(ex).await;
        match outcome {
            PipelineOutcome::Stopped(returned_ex) => {
                let stopped_flag = returned_ex.property(CAMEL_STOP).and_then(|v| v.as_bool());
                assert_eq!(
                    stopped_flag,
                    Some(true),
                    "Drop strategy must set CamelStop=true property"
                );
            }
            other => panic!("Drop must return Stopped, got {:?}", other),
        }
    }

    /// A `Stopped` outcome from the body, including its mutations, must be
    /// returned unchanged by the segment.
    ///
    /// NOTE(review): with `max_requests: 1` the bucket starts full, so the run
    /// is admitted immediately and the Delay wait loop itself is not exercised
    /// here.
    #[tokio::test]
    async fn throttle_segment_delay_strategy_propagates_stopped_body() {
        use camel_api::{Body, Exchange, Message, OutcomePipeline, PipelineOutcome};

        // Body that mutates the exchange and then stops the pipeline.
        #[derive(Clone)]
        struct StoppingSeg;
        impl OutcomePipeline for StoppingSeg {
            fn clone_box(&self) -> Box<dyn OutcomePipeline> {
                Box::new(StoppingSeg)
            }
            fn run<'a>(
                &'a mut self,
                mut ex: Exchange,
            ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
                Box::pin(async move {
                    ex.input.body = Body::Bytes(b"stopped-mut".to_vec().into());
                    PipelineOutcome::Stopped(ex)
                })
            }
        }

        let config = ThrottlerConfig {
            max_requests: 1,
            period: Duration::from_secs(1),
            strategy: ThrottleStrategy::Delay,
        };
        let body = camel_api::OutcomeSegment::new(Box::new(StoppingSeg));
        let mut seg = ThrottleSegment::new(config, body);
        let ex = Exchange::new(Message::new("test"));
        let outcome = seg.run(ex).await;
        match outcome {
            PipelineOutcome::Stopped(returned_ex) => {
                if let Body::Bytes(b) = &returned_ex.input.body {
                    assert_eq!(
                        b.as_ref(),
                        b"stopped-mut",
                        "BUG: throttle body Stop must preserve mutations"
                    );
                } else {
                    panic!("expected Body::Bytes");
                }
            }
            other => panic!("expected Stopped propagation, got {:?}", other),
        }
    }
}