// camel_processor/throttler.rs

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex, PoisonError};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use tower::Service;

use camel_api::{BoxProcessor, CamelError, Exchange, ThrottleStrategy, ThrottlerConfig, Value};
10
/// Exchange property key that the `Drop` throttle strategy sets to `Value::Bool(true)`
/// on an exchange it refuses to forward.
/// NOTE(review): presumably read by the surrounding pipeline to stop further processing
/// of the exchange — confirm against the consumer of this property.
const CAMEL_STOP: &str = "CamelStop";
12
/// Token bucket used by `ThrottlerService` to decide whether an exchange may proceed.
///
/// The bucket starts full, holds at most `max_tokens` tokens, and refills continuously
/// at `refill_rate` tokens per second. Each admitted request consumes one token. Refill
/// is computed lazily from the time elapsed since `last_refill`.
struct RateLimiter {
    /// Tokens available as of `last_refill`; fractional while refilling.
    tokens: f64,
    /// Bucket capacity (the configured `max_requests`).
    max_tokens: f64,
    /// Tokens regained per second: `max_requests / period`.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl RateLimiter {
    /// Creates a full bucket admitting `max_requests` requests per `period`.
    ///
    /// `period` must be non-zero (the caller is expected to check this); a zero period
    /// would make the refill rate infinite or NaN. A `max_requests` of zero yields a
    /// bucket that never admits anything.
    fn new(max_requests: usize, period: Duration) -> Self {
        let capacity = max_requests as f64;
        Self {
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: capacity / period.as_secs_f64(),
            last_refill: Instant::now(),
        }
    }

    /// Returns the number of tokens the bucket holds at `now`, including refill accrued
    /// since `last_refill`, capped at capacity. Does not modify the bucket.
    fn tokens_at(&self, now: Instant) -> f64 {
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        if elapsed > 0.0 {
            (self.tokens + elapsed * self.refill_rate).min(self.max_tokens)
        } else {
            self.tokens
        }
    }

    /// Refills the bucket and, if at least one whole token is available, consumes it.
    ///
    /// Returns `true` when the caller may proceed.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        self.tokens = self.tokens_at(now);
        self.last_refill = now;
        let admitted = self.tokens >= 1.0;
        if admitted {
            self.tokens -= 1.0;
        }
        admitted
    }

    /// Returns how long until a whole token is available (zero if one is available now).
    ///
    /// Refill accrued since `last_refill` is included, so time that has already passed
    /// does not inflate the estimate. If the bucket never refills (zero capacity) or the
    /// wait does not fit in a `Duration`, returns `Duration::MAX` rather than panicking
    /// the way `Duration::from_secs_f64` does on non-finite input. A panic here would
    /// happen while the caller holds the limiter's mutex and would poison it.
    fn time_until_next_token(&self) -> Duration {
        let available = self.tokens_at(Instant::now());
        if available >= 1.0 {
            return Duration::ZERO;
        }
        let seconds = (1.0 - available) / self.refill_rate;
        Duration::try_from_secs_f64(seconds).unwrap_or(Duration::MAX)
    }
}
55
56#[derive(Clone)]
57pub struct ThrottlerService {
58    config: ThrottlerConfig,
59    limiter: std::sync::Arc<Mutex<RateLimiter>>,
60    next: BoxProcessor,
61}
62
63impl ThrottlerService {
64    pub fn new(config: ThrottlerConfig, next: BoxProcessor) -> Self {
65        assert!(
66            config.period > Duration::ZERO,
67            "throttler period must be > 0"
68        );
69        let limiter = RateLimiter::new(config.max_requests, config.period);
70        Self {
71            config,
72            limiter: std::sync::Arc::new(Mutex::new(limiter)),
73            next,
74        }
75    }
76}
77
78impl Service<Exchange> for ThrottlerService {
79    type Response = Exchange;
80    type Error = CamelError;
81    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
82
83    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        self.next.poll_ready(cx)
85    }
86
87    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
88        let config = self.config.clone();
89        let limiter = self.limiter.clone();
90        let mut next = self.next.clone();
91
92        Box::pin(async move {
93            let acquired = {
94                let mut limiter = limiter.lock().unwrap(); // allow-unwrap
95                limiter.try_acquire()
96            };
97
98            if acquired {
99                next.call(exchange).await
100            } else {
101                match config.strategy {
102                    ThrottleStrategy::Delay => {
103                        loop {
104                            let wait_time = {
105                                let limiter = limiter.lock().unwrap(); // allow-unwrap
106                                limiter.time_until_next_token()
107                            };
108                            if wait_time > Duration::ZERO {
109                                tokio::time::sleep(wait_time).await;
110                            }
111                            let acquired = {
112                                let mut limiter = limiter.lock().unwrap(); // allow-unwrap
113                                limiter.try_acquire()
114                            };
115                            if acquired {
116                                break;
117                            }
118                            // Yield to avoid tight spinning when concurrent tasks
119                            // wake simultaneously and contend for the same token.
120                            tokio::task::yield_now().await;
121                        }
122                        next.call(exchange).await
123                    }
124                    ThrottleStrategy::Reject => Err(CamelError::ProcessorError(
125                        "Throttled: rate limit exceeded".to_string(),
126                    )),
127                    ThrottleStrategy::Drop => {
128                        exchange.set_property(CAMEL_STOP, Value::Bool(true));
129                        Ok(exchange)
130                    }
131                }
132            }
133        })
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use tower::ServiceExt;

    /// Downstream processor that hands every exchange back unchanged.
    fn passthrough() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Builds an exchange whose message body is `body`.
    fn make_exchange(body: &'static str) -> Exchange {
        Exchange::new(Message::new(body))
    }

    /// Waits for `svc` to be ready, then pushes one exchange carrying `body` through it.
    async fn send(svc: &mut ThrottlerService, body: &'static str) -> Result<Exchange, CamelError> {
        svc.ready().await.unwrap().call(make_exchange(body)).await
    }

    #[test]
    fn test_throttler_zero_period_rejected() {
        let config = ThrottlerConfig::new(5, Duration::ZERO);
        let outcome = std::panic::catch_unwind(move || {
            ThrottlerService::new(config, passthrough());
        });
        assert!(outcome.is_err(), "zero period should panic");
    }

    #[tokio::test]
    async fn test_throttler_allows_under_limit() {
        let mut svc = ThrottlerService::new(
            ThrottlerConfig::new(5, Duration::from_secs(1)),
            passthrough(),
        );

        for _ in 0..5 {
            assert!(send(&mut svc, "test").await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_throttler_delay_strategy_queues_message() {
        let mut svc = ThrottlerService::new(
            ThrottlerConfig::new(1, Duration::from_millis(100)),
            passthrough(),
        );
        assert!(send(&mut svc, "first").await.is_ok());

        let started = Instant::now();
        let second = send(&mut svc, "second").await;
        let waited = started.elapsed();

        assert!(second.is_ok());
        assert!(waited >= Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_throttler_reject_strategy_returns_error() {
        let config =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Reject);
        let mut svc = ThrottlerService::new(config, passthrough());
        let _ = send(&mut svc, "first").await;

        let rejection = send(&mut svc, "second").await;
        assert!(rejection.is_err());
        let message = rejection.unwrap_err().to_string();
        assert!(message.contains("Throttled"));
    }

    #[tokio::test]
    async fn test_throttler_drop_strategy_sets_camel_stop() {
        let config =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Drop);
        let mut svc = ThrottlerService::new(config, passthrough());
        let _ = send(&mut svc, "first").await;

        let dropped = send(&mut svc, "second").await.unwrap();
        assert_eq!(dropped.property(CAMEL_STOP), Some(&Value::Bool(true)));
    }

    #[tokio::test]
    async fn test_throttler_token_replenishment() {
        let mut svc = ThrottlerService::new(
            ThrottlerConfig::new(1, Duration::from_millis(50)),
            passthrough(),
        );
        let _ = send(&mut svc, "first").await;

        // Two full refill periods: the single token must be back.
        tokio::time::sleep(Duration::from_millis(100)).await;

        assert!(send(&mut svc, "second").await.is_ok());
    }
}