// camel_processor/throttler.rs
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use tower::Service;

use camel_api::{BoxProcessor, CamelError, Exchange, ThrottleStrategy, ThrottlerConfig, Value};

/// Exchange property that the `Drop` throttle strategy sets to `true` on an
/// exchange it discards instead of forwarding to the next processor.
/// NOTE(review): presumably read by the routing pipeline to stop further
/// processing — confirm against its consumers.
const CAMEL_STOP: &str = "CamelStop";

/// Token-bucket rate limiter backing `ThrottlerService`.
///
/// The bucket holds at most `max_tokens` tokens and refills continuously at
/// `refill_rate` tokens per second; every admitted request consumes one token.
struct RateLimiter {
    /// Tokens available as of `last_refill` (fractional while refilling).
    tokens: f64,
    /// Bucket capacity, equal to the configured `max_requests`.
    max_tokens: f64,
    /// Refill speed in tokens per second.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl RateLimiter {
    /// Creates a full bucket admitting `max_requests` requests per `period`.
    ///
    /// `period` must be non-zero (the caller asserts this). A `max_requests`
    /// of zero yields a bucket that never admits anything.
    fn new(max_requests: usize, period: Duration) -> Self {
        let max_tokens = max_requests as f64;
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate: max_tokens / period.as_secs_f64(),
            last_refill: Instant::now(),
        }
    }

    /// Tokens that would be available at `now`, including the refill accrued
    /// since `last_refill`, capped at the bucket capacity. Does not mutate.
    fn tokens_at(&self, now: Instant) -> f64 {
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        (self.tokens + elapsed * self.refill_rate).min(self.max_tokens)
    }

    /// Refills the bucket, then takes one token if available.
    ///
    /// Returns `true` when a token was consumed, `false` when the bucket is
    /// empty (in which case no token is taken).
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        self.tokens = self.tokens_at(now);
        self.last_refill = now;
        if self.tokens < 1.0 {
            return false;
        }
        self.tokens -= 1.0;
        true
    }

    /// How long until a whole token is available, measured from now.
    ///
    /// Accounts for refill accrued since the last `try_acquire`, so the
    /// estimate is not inflated by stale state. Returns `Duration::ZERO` if a
    /// token is already available and `Duration::MAX` if the bucket can never
    /// refill (zero capacity).
    fn time_until_next_token(&self) -> Duration {
        let deficit = 1.0 - self.tokens_at(Instant::now());
        if deficit <= 0.0 {
            Duration::ZERO
        } else if self.refill_rate <= 0.0 {
            // max_requests == 0: no finite wait exists. Dividing would give
            // +inf, and Duration::from_secs_f64(inf) panics.
            Duration::MAX
        } else {
            // tokens never go negative, so deficit <= 1 and the wait is at most
            // one refill interval (period / max_requests): it fits a Duration.
            Duration::from_secs_f64(deficit / self.refill_rate)
        }
    }
}

56#[derive(Clone)]
57pub struct ThrottlerService {
58 config: ThrottlerConfig,
59 limiter: std::sync::Arc<Mutex<RateLimiter>>,
60 next: BoxProcessor,
61}
62
63impl ThrottlerService {
64 pub fn new(config: ThrottlerConfig, next: BoxProcessor) -> Self {
65 assert!(
66 config.period > Duration::ZERO,
67 "throttler period must be > 0"
68 );
69 let limiter = RateLimiter::new(config.max_requests, config.period);
70 Self {
71 config,
72 limiter: std::sync::Arc::new(Mutex::new(limiter)),
73 next,
74 }
75 }
76}
77
78impl Service<Exchange> for ThrottlerService {
79 type Response = Exchange;
80 type Error = CamelError;
81 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
82
83 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84 self.next.poll_ready(cx)
85 }
86
87 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
88 let config = self.config.clone();
89 let limiter = self.limiter.clone();
90 let mut next = self.next.clone();
91
92 Box::pin(async move {
93 let acquired = {
94 let mut limiter = limiter.lock().unwrap(); limiter.try_acquire()
96 };
97
98 if acquired {
99 next.call(exchange).await
100 } else {
101 match config.strategy {
102 ThrottleStrategy::Delay => {
103 loop {
104 let wait_time = {
105 let limiter = limiter.lock().unwrap(); limiter.time_until_next_token()
107 };
108 if wait_time > Duration::ZERO {
109 tokio::time::sleep(wait_time).await;
110 }
111 let acquired = {
112 let mut limiter = limiter.lock().unwrap(); limiter.try_acquire()
114 };
115 if acquired {
116 break;
117 }
118 tokio::task::yield_now().await;
121 }
122 next.call(exchange).await
123 }
124 ThrottleStrategy::Reject => Err(CamelError::ProcessorError(
125 "Throttled: rate limit exceeded".to_string(),
126 )),
127 ThrottleStrategy::Drop => {
128 exchange.set_property(CAMEL_STOP, Value::Bool(true));
129 Ok(exchange)
130 }
131 }
132 }
133 })
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use tower::ServiceExt;

    /// Downstream processor that returns the exchange unchanged.
    fn passthrough() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    #[test]
    fn test_throttler_zero_period_rejected() {
        let config = ThrottlerConfig::new(5, Duration::ZERO);
        let result = std::panic::catch_unwind(|| {
            ThrottlerService::new(config, passthrough());
        });
        assert!(result.is_err(), "zero period should panic");
    }

    #[tokio::test]
    async fn test_throttler_allows_under_limit() {
        let config = ThrottlerConfig::new(5, Duration::from_secs(1));
        let mut svc = ThrottlerService::new(config, passthrough());

        for _ in 0..5 {
            let ex = Exchange::new(Message::new("test"));
            let result = svc.ready().await.unwrap().call(ex).await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_throttler_delay_strategy_queues_message() {
        let config = ThrottlerConfig::new(1, Duration::from_millis(100));
        let mut svc = ThrottlerService::new(config, passthrough());

        let ex1 = Exchange::new(Message::new("first"));
        let result1 = svc.ready().await.unwrap().call(ex1).await;
        assert!(result1.is_ok());

        let start = Instant::now();
        let ex2 = Exchange::new(Message::new("second"));
        let result2 = svc.ready().await.unwrap().call(ex2).await;
        let elapsed = start.elapsed();
        assert!(result2.is_ok());
        assert!(elapsed >= Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_throttler_reject_strategy_returns_error() {
        let config =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Reject);
        let mut svc = ThrottlerService::new(config, passthrough());

        let ex1 = Exchange::new(Message::new("first"));
        let _ = svc.ready().await.unwrap().call(ex1).await;

        let ex2 = Exchange::new(Message::new("second"));
        let result = svc.ready().await.unwrap().call(ex2).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Throttled"));
    }

    #[tokio::test]
    async fn test_throttler_drop_strategy_sets_camel_stop() {
        let config =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Drop);
        let mut svc = ThrottlerService::new(config, passthrough());

        let ex1 = Exchange::new(Message::new("first"));
        let _ = svc.ready().await.unwrap().call(ex1).await;

        let ex2 = Exchange::new(Message::new("second"));
        let result = svc.ready().await.unwrap().call(ex2).await.unwrap();
        assert_eq!(result.property(CAMEL_STOP), Some(&Value::Bool(true)));
    }

    #[tokio::test]
    async fn test_throttler_token_replenishment() {
        // Use Reject rather than the default Delay strategy. Under Delay the
        // final call would succeed by waiting, even if the bucket never refilled.
        let config =
            ThrottlerConfig::new(1, Duration::from_millis(100)).strategy(ThrottleStrategy::Reject);
        let mut svc = ThrottlerService::new(config, passthrough());

        let ex1 = Exchange::new(Message::new("first"));
        assert!(svc.ready().await.unwrap().call(ex1).await.is_ok());

        // The bucket is now empty, so an immediate second request is refused.
        let ex2 = Exchange::new(Message::new("second"));
        assert!(svc.ready().await.unwrap().call(ex2).await.is_err());

        tokio::time::sleep(Duration::from_millis(150)).await;

        let ex3 = Exchange::new(Message::new("third"));
        let result = svc.ready().await.unwrap().call(ex3).await;
        assert!(result.is_ok(), "token should have been replenished");
    }
}