Skip to main content

camel_processor/
throttler.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Mutex;
4use std::task::{Context, Poll};
5use std::time::{Duration, Instant};
6
7use tower::Service;
8
9use camel_api::{BoxProcessor, CamelError, Exchange, ThrottleStrategy, ThrottlerConfig, Value};
10
// NOTE(review): presumably read by the surrounding pipeline to halt further
// processing of the exchange — confirm against the pipeline implementation.
/// Exchange property name that the `Drop` throttle strategy sets to
/// `Value::Bool(true)` on an over-limit exchange instead of forwarding it.
const CAMEL_STOP: &str = "CamelStop";
12
/// Token-bucket rate limiter.
///
/// The bucket starts full with `max_tokens` tokens and is replenished
/// continuously at `refill_rate` tokens per second, never exceeding
/// `max_tokens`. Each admitted request consumes exactly one token.
#[derive(Debug)]
pub struct RateLimiter {
    /// Tokens available as of `last_refill` (fractional while refilling).
    tokens: f64,
    /// Bucket capacity (`max_requests`).
    max_tokens: f64,
    /// Replenishment speed in tokens per second (`max_requests / period`).
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl RateLimiter {
    /// Creates a full bucket that admits `max_requests` requests per `period`.
    ///
    /// Callers are expected to pass a non-zero `period`.
    fn new(max_requests: usize, period: Duration) -> Self {
        let max_tokens = max_requests as f64;
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate: max_tokens / period.as_secs_f64(),
            last_refill: Instant::now(),
        }
    }

    /// Tokens that would be available at `now`, without mutating the bucket.
    fn projected_tokens(&self, now: Instant) -> f64 {
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        (self.tokens + elapsed * self.refill_rate).min(self.max_tokens)
    }

    /// Refills the bucket for the time elapsed since the last refill, then tries
    /// to take one token. Returns `true` if a token was consumed.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        self.tokens = self.projected_tokens(now);
        self.last_refill = now;
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Time to wait, measured from now, until at least one whole token is available.
    ///
    /// Tokens accrued since the last refill are taken into account, so the
    /// estimate is not inflated by stale state. Returns `Duration::ZERO` when a
    /// token is already available. Returns `Duration::MAX` when no token can ever
    /// become available (zero refill rate, i.e. `max_requests == 0`) or the wait
    /// is too long to represent.
    fn time_until_next_token(&self) -> Duration {
        let available = self.projected_tokens(Instant::now());
        if available >= 1.0 {
            return Duration::ZERO;
        }
        let secs = (1.0 - available) / self.refill_rate;
        // `Duration::from_secs_f64` panics on non-finite or out-of-range input.
        // With a zero refill rate `secs` is +inf; clamp instead of panicking
        // (callers invoke this while holding the limiter mutex).
        if secs.is_finite() && secs < Duration::MAX.as_secs_f64() {
            Duration::from_secs_f64(secs)
        } else {
            Duration::MAX
        }
    }
}
55
56#[derive(Clone)]
57pub struct ThrottlerService {
58    config: ThrottlerConfig,
59    limiter: std::sync::Arc<Mutex<RateLimiter>>,
60    next: BoxProcessor,
61}
62
63impl ThrottlerService {
64    pub fn new(config: ThrottlerConfig, next: BoxProcessor) -> Self {
65        assert!(
66            config.period > Duration::ZERO,
67            "throttler period must be > 0"
68        );
69        let limiter = RateLimiter::new(config.max_requests, config.period);
70        Self {
71            config,
72            limiter: std::sync::Arc::new(Mutex::new(limiter)),
73            next,
74        }
75    }
76}
77
78impl Service<Exchange> for ThrottlerService {
79    type Response = Exchange;
80    type Error = CamelError;
81    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
82
83    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        self.next.poll_ready(cx)
85    }
86
87    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
88        let config = self.config.clone();
89        let limiter = self.limiter.clone();
90        let mut next = self.next.clone();
91
92        Box::pin(async move {
93            let acquired = {
94                let mut limiter = limiter.lock().unwrap(); // allow-unwrap
95                limiter.try_acquire()
96            };
97
98            if acquired {
99                next.call(exchange).await
100            } else {
101                match config.strategy {
102                    ThrottleStrategy::Delay => {
103                        loop {
104                            let wait_time = {
105                                let limiter = limiter.lock().unwrap(); // allow-unwrap
106                                limiter.time_until_next_token()
107                            };
108                            if wait_time > Duration::ZERO {
109                                tokio::time::sleep(wait_time).await;
110                            }
111                            let acquired = {
112                                let mut limiter = limiter.lock().unwrap(); // allow-unwrap
113                                limiter.try_acquire()
114                            };
115                            if acquired {
116                                break;
117                            }
118                            // Yield to avoid tight spinning when concurrent tasks
119                            // wake simultaneously and contend for the same token.
120                            tokio::task::yield_now().await;
121                        }
122                        next.call(exchange).await
123                    }
124                    ThrottleStrategy::Reject => Err(CamelError::ProcessorError(
125                        "Throttled: rate limit exceeded".to_string(),
126                    )),
127                    ThrottleStrategy::Drop => {
128                        exchange.set_property(CAMEL_STOP, Value::Bool(true));
129                        Ok(exchange)
130                    }
131                }
132            }
133        })
134    }
135}
136
137/// Outcome-aware throttle segment (ADR-0025).
138///
139/// Wraps a `ThrottlerConfig` + shared `RateLimiter` + child sub-pipeline body.
140/// Unlike `ThrottlerService` (which operates at the Tower layer),
141/// `ThrottleSegment` correctly propagates `PipelineOutcome::Stopped` / `Failed`
142/// from the body.
143pub struct ThrottleSegment {
144    pub config: ThrottlerConfig,
145    pub limiter: std::sync::Arc<std::sync::Mutex<RateLimiter>>,
146    pub body: camel_api::OutcomeSegment,
147}
148
149impl ThrottleSegment {
150    pub fn new(config: ThrottlerConfig, body: camel_api::OutcomeSegment) -> Self {
151        assert!(
152            config.period > Duration::ZERO,
153            "throttler period must be > 0"
154        );
155        Self {
156            limiter: std::sync::Arc::new(std::sync::Mutex::new(RateLimiter::new(
157                config.max_requests,
158                config.period,
159            ))),
160            config,
161            body,
162        }
163    }
164}
165
166impl Clone for ThrottleSegment {
167    fn clone(&self) -> Self {
168        Self {
169            config: self.config.clone(),
170            limiter: std::sync::Arc::clone(&self.limiter),
171            body: self.body.clone(),
172        }
173    }
174}
175
176impl camel_api::OutcomePipeline for ThrottleSegment {
177    fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
178        Box::new(self.clone())
179    }
180
181    fn run<'a>(
182        &'a mut self,
183        exchange: camel_api::Exchange,
184    ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
185        Box::pin(async move {
186            let acquired = {
187                let mut limiter = self.limiter.lock().unwrap(); // allow-unwrap
188                limiter.try_acquire()
189            };
190            if acquired {
191                return self.body.run(exchange).await;
192            }
193            match self.config.strategy {
194                ThrottleStrategy::Delay => {
195                    loop {
196                        let wait_time = {
197                            let limiter = self.limiter.lock().unwrap(); // allow-unwrap
198                            limiter.time_until_next_token()
199                        };
200                        if wait_time > Duration::ZERO {
201                            tokio::time::sleep(wait_time).await;
202                        }
203                        let acquired = {
204                            let mut limiter = self.limiter.lock().unwrap(); // allow-unwrap
205                            limiter.try_acquire()
206                        };
207                        if acquired {
208                            break;
209                        }
210                        tokio::task::yield_now().await;
211                    }
212                    self.body.run(exchange).await
213                }
214                ThrottleStrategy::Reject => {
215                    camel_api::PipelineOutcome::Failed(camel_api::CamelError::ProcessorError(
216                        "Throttled: rate limit exceeded".to_string(),
217                    ))
218                }
219                ThrottleStrategy::Drop => {
220                    let mut ex = exchange;
221                    ex.set_property(CAMEL_STOP, camel_api::Value::Bool(true));
222                    camel_api::PipelineOutcome::Completed(ex)
223                }
224            }
225        })
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message, OutcomePipeline, PipelineOutcome};
    use tower::ServiceExt;

    fn passthrough() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Body segment that completes immediately with the exchange untouched.
    #[derive(Clone)]
    struct NoopSeg;

    impl OutcomePipeline for NoopSeg {
        fn clone_box(&self) -> Box<dyn OutcomePipeline> {
            Box::new(NoopSeg)
        }
        fn run<'a>(
            &'a mut self,
            ex: Exchange,
        ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
            Box::pin(async move { PipelineOutcome::Completed(ex) })
        }
    }

    #[test]
    fn test_throttler_zero_period_rejected() {
        let config = ThrottlerConfig::new(5, Duration::ZERO);
        let result = std::panic::catch_unwind(|| {
            ThrottlerService::new(config, passthrough());
        });
        assert!(result.is_err(), "zero period should panic");
    }

    #[tokio::test]
    async fn test_throttler_allows_under_limit() {
        let config = ThrottlerConfig::new(5, Duration::from_secs(1));
        let mut svc = ThrottlerService::new(config, passthrough());

        for _ in 0..5 {
            let ex = Exchange::new(Message::new("test"));
            let result = svc.ready().await.unwrap().call(ex).await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_throttler_delay_strategy_queues_message() {
        let config = ThrottlerConfig::new(1, Duration::from_millis(100));
        let mut svc = ThrottlerService::new(config, passthrough());

        let ex1 = Exchange::new(Message::new("first"));
        let result1 = svc.ready().await.unwrap().call(ex1).await;
        assert!(result1.is_ok());

        let start = Instant::now();
        let ex2 = Exchange::new(Message::new("second"));
        let result2 = svc.ready().await.unwrap().call(ex2).await;
        let elapsed = start.elapsed();
        assert!(result2.is_ok());
        assert!(elapsed >= Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_throttler_reject_strategy_returns_error() {
        let config =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Reject);
        let mut svc = ThrottlerService::new(config, passthrough());

        let ex1 = Exchange::new(Message::new("first"));
        let _ = svc.ready().await.unwrap().call(ex1).await;

        let ex2 = Exchange::new(Message::new("second"));
        let result = svc.ready().await.unwrap().call(ex2).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Throttled"));
    }

    #[tokio::test]
    async fn test_throttler_drop_strategy_sets_camel_stop() {
        let config =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Drop);
        let mut svc = ThrottlerService::new(config, passthrough());

        let ex1 = Exchange::new(Message::new("first"));
        let _ = svc.ready().await.unwrap().call(ex1).await;

        let ex2 = Exchange::new(Message::new("second"));
        let result = svc.ready().await.unwrap().call(ex2).await.unwrap();
        assert_eq!(result.property(CAMEL_STOP), Some(&Value::Bool(true)));
    }

    #[tokio::test]
    async fn test_throttler_token_replenishment() {
        let config = ThrottlerConfig::new(1, Duration::from_millis(50));
        let mut svc = ThrottlerService::new(config, passthrough());

        let ex1 = Exchange::new(Message::new("first"));
        let _ = svc.ready().await.unwrap().call(ex1).await;

        tokio::time::sleep(Duration::from_millis(100)).await;

        let ex2 = Exchange::new(Message::new("second"));
        let result = svc.ready().await.unwrap().call(ex2).await;
        assert!(result.is_ok());
    }

    // ── ThrottleSegment tests (ADR-0025 OutcomePipeline parity) ────────────

    #[tokio::test]
    async fn throttle_segment_reject_strategy_returns_failed() {
        let config = ThrottlerConfig {
            max_requests: 0, // 0 tokens immediately exhausted
            period: Duration::from_secs(1),
            strategy: ThrottleStrategy::Reject,
        };
        let body = camel_api::OutcomeSegment::new(Box::new(NoopSeg));
        let mut seg = ThrottleSegment::new(config, body);
        let ex = Exchange::new(Message::new("test"));
        let outcome = seg.run(ex).await;
        assert!(
            matches!(outcome, PipelineOutcome::Failed(_)),
            "Reject strategy must return Failed when tokens exhausted"
        );
    }

    #[tokio::test]
    async fn throttle_segment_drop_strategy_sets_camel_stop_and_completes() {
        let config = ThrottlerConfig {
            max_requests: 0,
            period: Duration::from_secs(1),
            strategy: ThrottleStrategy::Drop,
        };
        let body = camel_api::OutcomeSegment::new(Box::new(NoopSeg));
        let mut seg = ThrottleSegment::new(config, body);
        let ex = Exchange::new(Message::new("test"));
        let outcome = seg.run(ex).await;
        match outcome {
            PipelineOutcome::Completed(returned_ex) => {
                let stopped_flag = returned_ex.property(CAMEL_STOP).and_then(|v| v.as_bool());
                assert_eq!(
                    stopped_flag,
                    Some(true),
                    "Drop strategy must set CamelStop=true property"
                );
            }
            other => panic!("Drop must return Completed, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn throttle_segment_delay_strategy_propagates_stopped_body() {
        use camel_api::Body;

        #[derive(Clone)]
        struct StoppingSeg;
        impl OutcomePipeline for StoppingSeg {
            fn clone_box(&self) -> Box<dyn OutcomePipeline> {
                Box::new(StoppingSeg)
            }
            fn run<'a>(
                &'a mut self,
                mut ex: Exchange,
            ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
                Box::pin(async move {
                    ex.input.body = Body::Bytes(b"stopped-mut".to_vec().into());
                    PipelineOutcome::Stopped(ex)
                })
            }
        }

        let config = ThrottlerConfig {
            max_requests: 1,
            period: Duration::from_millis(50),
            strategy: ThrottleStrategy::Delay,
        };
        let body = camel_api::OutcomeSegment::new(Box::new(StoppingSeg));
        let mut seg = ThrottleSegment::new(config, body);

        // Spend the only token so the next run must go through the Delay branch
        // (otherwise this test would never exercise the strategy it is named after).
        let _ = seg.run(Exchange::new(Message::new("first"))).await;

        let start = Instant::now();
        let outcome = seg.run(Exchange::new(Message::new("test"))).await;
        assert!(
            start.elapsed() >= Duration::from_millis(25),
            "second run should have waited for a token"
        );
        match outcome {
            PipelineOutcome::Stopped(returned_ex) => {
                if let Body::Bytes(b) = &returned_ex.input.body {
                    assert_eq!(
                        b.as_ref(),
                        b"stopped-mut",
                        "BUG: throttle body Stop must preserve mutations"
                    );
                } else {
                    panic!("expected Body::Bytes");
                }
            }
            other => panic!("expected Stopped propagation, got {:?}", other),
        }
    }
}