Skip to main content

camel_processor/
throttler.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Mutex;
4use std::task::{Context, Poll};
5use std::time::{Duration, Instant};
6
7use tower::Service;
8
9use camel_api::CAMEL_STOP;
10use camel_api::{BoxProcessor, CamelError, Exchange, ThrottleStrategy, ThrottlerConfig, Value};
11
/// Token-bucket rate limiter used by the throttler implementations below.
///
/// The bucket starts full with `max_requests` tokens and is refilled
/// continuously at `max_requests / period` tokens per second, never exceeding
/// its initial capacity. Every admitted request consumes exactly one token.
pub struct RateLimiter {
    /// Tokens currently available; fractional while the next token accrues.
    tokens: f64,
    /// Bucket capacity, i.e. the upper bound for `tokens`.
    max_tokens: f64,
    /// Refill speed in tokens per second.
    refill_rate: f64,
    /// Instant at which `tokens` was last topped up.
    last_refill: Instant,
}

impl RateLimiter {
    /// Creates a full bucket that allows `max_requests` per `period`.
    ///
    /// `period` must be non-zero, otherwise the refill rate is not a finite
    /// number; the constructors in this file assert that before calling.
    fn new(max_requests: usize, period: Duration) -> Self {
        let refill_rate = max_requests as f64 / period.as_secs_f64();
        Self {
            tokens: max_requests as f64,
            max_tokens: max_requests as f64,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Refills the bucket for the time elapsed since the previous refill and
    /// then tries to take one token.
    ///
    /// Returns `true` when a token was consumed, `false` when less than one
    /// whole token is available.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        if elapsed > 0.0 {
            self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
            self.last_refill = now;
        }
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Estimates how long to wait until one whole token is available.
    ///
    /// The estimate is based on the token count recorded by the most recent
    /// `try_acquire`; time elapsed since then is not taken into account.
    ///
    /// Returns `Duration::ZERO` when a token is already available and
    /// `Duration::MAX` when the wait cannot be represented, e.g. when the
    /// limiter was built with `max_requests == 0` and therefore never refills.
    fn time_until_next_token(&self) -> Duration {
        if self.tokens >= 1.0 {
            return Duration::ZERO;
        }
        let secs = (1.0 - self.tokens) / self.refill_rate;
        // `Duration::from_secs_f64` panics on non-finite or overflowing input
        // (a zero refill rate yields +inf). Callers invoke this method while
        // holding the limiter's mutex, so a panic here would also poison the
        // lock; saturate instead.
        Duration::try_from_secs_f64(secs).unwrap_or(Duration::MAX)
    }
}
54
55#[derive(Clone)]
56pub struct ThrottlerService {
57    config: ThrottlerConfig,
58    limiter: std::sync::Arc<Mutex<RateLimiter>>,
59    next: BoxProcessor,
60}
61
62impl ThrottlerService {
63    pub fn new(config: ThrottlerConfig, next: BoxProcessor) -> Self {
64        assert!(
65            config.period > Duration::ZERO,
66            "throttler period must be > 0"
67        );
68        let limiter = RateLimiter::new(config.max_requests, config.period);
69        Self {
70            config,
71            limiter: std::sync::Arc::new(Mutex::new(limiter)),
72            next,
73        }
74    }
75}
76
77impl Service<Exchange> for ThrottlerService {
78    type Response = Exchange;
79    type Error = CamelError;
80    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
81
82    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
83        self.next.poll_ready(cx)
84    }
85
86    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
87        let config = self.config.clone();
88        let limiter = self.limiter.clone();
89        let mut next = self.next.clone();
90
91        Box::pin(async move {
92            let acquired = {
93                let mut limiter = limiter.lock().unwrap(); // allow-unwrap
94                limiter.try_acquire()
95            };
96
97            if acquired {
98                next.call(exchange).await
99            } else {
100                match config.strategy {
101                    ThrottleStrategy::Delay => {
102                        loop {
103                            let wait_time = {
104                                let limiter = limiter.lock().unwrap(); // allow-unwrap
105                                limiter.time_until_next_token()
106                            };
107                            if wait_time > Duration::ZERO {
108                                tokio::time::sleep(wait_time).await;
109                            }
110                            let acquired = {
111                                let mut limiter = limiter.lock().unwrap(); // allow-unwrap
112                                limiter.try_acquire()
113                            };
114                            if acquired {
115                                break;
116                            }
117                            // Yield to avoid tight spinning when concurrent tasks
118                            // wake simultaneously and contend for the same token.
119                            tokio::task::yield_now().await;
120                        }
121                        next.call(exchange).await
122                    }
123                    ThrottleStrategy::Reject => Err(CamelError::ProcessorError(
124                        "Throttled: rate limit exceeded".to_string(),
125                    )),
126                    ThrottleStrategy::Drop => {
127                        exchange.set_property(CAMEL_STOP, Value::Bool(true));
128                        Ok(exchange)
129                    }
130                }
131            }
132        })
133    }
134}
135
136/// Outcome-aware throttle segment (ADR-0025).
137///
138/// Wraps a `ThrottlerConfig` + shared `RateLimiter` + child sub-pipeline body.
139/// Unlike `ThrottlerService` (which operates at the Tower layer),
140/// `ThrottleSegment` correctly propagates `PipelineOutcome::Stopped` / `Failed`
141/// from the body.
142pub struct ThrottleSegment {
143    pub config: ThrottlerConfig,
144    pub limiter: std::sync::Arc<std::sync::Mutex<RateLimiter>>,
145    pub body: camel_api::OutcomeSegment,
146}
147
148impl ThrottleSegment {
149    pub fn new(config: ThrottlerConfig, body: camel_api::OutcomeSegment) -> Self {
150        assert!(
151            config.period > Duration::ZERO,
152            "throttler period must be > 0"
153        );
154        Self {
155            limiter: std::sync::Arc::new(std::sync::Mutex::new(RateLimiter::new(
156                config.max_requests,
157                config.period,
158            ))),
159            config,
160            body,
161        }
162    }
163}
164
165impl Clone for ThrottleSegment {
166    fn clone(&self) -> Self {
167        Self {
168            config: self.config.clone(),
169            limiter: std::sync::Arc::clone(&self.limiter),
170            body: self.body.clone(),
171        }
172    }
173}
174
175impl camel_api::OutcomePipeline for ThrottleSegment {
176    fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
177        Box::new(self.clone())
178    }
179
180    fn run<'a>(
181        &'a mut self,
182        exchange: camel_api::Exchange,
183    ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
184        Box::pin(async move {
185            let acquired = {
186                let mut limiter = self.limiter.lock().unwrap(); // allow-unwrap
187                limiter.try_acquire()
188            };
189            if acquired {
190                return self.body.run(exchange).await;
191            }
192            match self.config.strategy {
193                ThrottleStrategy::Delay => {
194                    loop {
195                        let wait_time = {
196                            let limiter = self.limiter.lock().unwrap(); // allow-unwrap
197                            limiter.time_until_next_token()
198                        };
199                        if wait_time > Duration::ZERO {
200                            tokio::time::sleep(wait_time).await;
201                        }
202                        let acquired = {
203                            let mut limiter = self.limiter.lock().unwrap(); // allow-unwrap
204                            limiter.try_acquire()
205                        };
206                        if acquired {
207                            break;
208                        }
209                        tokio::task::yield_now().await;
210                    }
211                    self.body.run(exchange).await
212                }
213                ThrottleStrategy::Reject => {
214                    camel_api::PipelineOutcome::Failed(camel_api::CamelError::ProcessorError(
215                        "Throttled: rate limit exceeded".to_string(),
216                    ))
217                }
218                ThrottleStrategy::Drop => {
219                    let mut ex = exchange;
220                    ex.set_property(CAMEL_STOP, camel_api::Value::Bool(true));
221                    camel_api::PipelineOutcome::Stopped(ex)
222                }
223            }
224        })
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use tower::ServiceExt;

    /// Downstream processor that returns every exchange untouched.
    fn passthrough() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Sub-pipeline body that completes immediately with the exchange it was
    /// given.
    fn noop_body() -> camel_api::OutcomeSegment {
        use camel_api::{OutcomePipeline, PipelineOutcome};

        #[derive(Clone)]
        struct NoopSeg;
        impl OutcomePipeline for NoopSeg {
            fn clone_box(&self) -> Box<dyn OutcomePipeline> {
                Box::new(NoopSeg)
            }
            fn run<'a>(
                &'a mut self,
                ex: Exchange,
            ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
                Box::pin(async move { PipelineOutcome::Completed(ex) })
            }
        }

        camel_api::OutcomeSegment::new(Box::new(NoopSeg))
    }

    // ── ThrottlerService tests ─────────────────────────────────────────────

    /// A zero period must be refused at construction time.
    #[test]
    fn test_throttler_zero_period_rejected() {
        let zero_period = ThrottlerConfig::new(5, Duration::ZERO);
        let construction = std::panic::catch_unwind(|| {
            ThrottlerService::new(zero_period, passthrough());
        });
        assert!(construction.is_err(), "zero period should panic");
    }

    /// Five requests fit into a five-per-second budget.
    #[tokio::test]
    async fn test_throttler_allows_under_limit() {
        let five_per_second = ThrottlerConfig::new(5, Duration::from_secs(1));
        let mut throttler = ThrottlerService::new(five_per_second, passthrough());

        for _ in 0..5 {
            let exchange = Exchange::new(Message::new("test"));
            let outcome = throttler.ready().await.unwrap().call(exchange).await;
            assert!(outcome.is_ok());
        }
    }

    /// With the default strategy the second request is held back until the
    /// bucket has refilled.
    #[tokio::test]
    async fn test_throttler_delay_strategy_queues_message() {
        let one_per_100ms = ThrottlerConfig::new(1, Duration::from_millis(100));
        let mut throttler = ThrottlerService::new(one_per_100ms, passthrough());

        let first = Exchange::new(Message::new("first"));
        let first_outcome = throttler.ready().await.unwrap().call(first).await;
        assert!(first_outcome.is_ok());

        let started = Instant::now();
        let second = Exchange::new(Message::new("second"));
        let second_outcome = throttler.ready().await.unwrap().call(second).await;
        let waited = started.elapsed();
        assert!(second_outcome.is_ok());
        assert!(waited >= Duration::from_millis(50));
    }

    /// `Reject` turns an over-limit request into an error.
    #[tokio::test]
    async fn test_throttler_reject_strategy_returns_error() {
        let rejecting =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Reject);
        let mut throttler = ThrottlerService::new(rejecting, passthrough());

        // Use up the single token.
        let first = Exchange::new(Message::new("first"));
        let _ = throttler.ready().await.unwrap().call(first).await;

        let second = Exchange::new(Message::new("second"));
        let outcome = throttler.ready().await.unwrap().call(second).await;
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Throttled"));
    }

    /// `Drop` returns the exchange flagged with `CAMEL_STOP`.
    #[tokio::test]
    async fn test_throttler_drop_strategy_sets_camel_stop() {
        let dropping =
            ThrottlerConfig::new(1, Duration::from_secs(10)).strategy(ThrottleStrategy::Drop);
        let mut throttler = ThrottlerService::new(dropping, passthrough());

        // Use up the single token.
        let first = Exchange::new(Message::new("first"));
        let _ = throttler.ready().await.unwrap().call(first).await;

        let second = Exchange::new(Message::new("second"));
        let returned = throttler
            .ready()
            .await
            .unwrap()
            .call(second)
            .await
            .unwrap();
        assert_eq!(returned.property(CAMEL_STOP), Some(&Value::Bool(true)));
    }

    /// After waiting longer than the period a new token is available.
    #[tokio::test]
    async fn test_throttler_token_replenishment() {
        let one_per_50ms = ThrottlerConfig::new(1, Duration::from_millis(50));
        let mut throttler = ThrottlerService::new(one_per_50ms, passthrough());

        let first = Exchange::new(Message::new("first"));
        let _ = throttler.ready().await.unwrap().call(first).await;

        tokio::time::sleep(Duration::from_millis(100)).await;

        let second = Exchange::new(Message::new("second"));
        let outcome = throttler.ready().await.unwrap().call(second).await;
        assert!(outcome.is_ok());
    }

    // ── ThrottleSegment tests (ADR-0025 OutcomePipeline parity) ────────────

    /// `Reject` on an empty bucket yields `PipelineOutcome::Failed`.
    #[tokio::test]
    async fn throttle_segment_reject_strategy_returns_failed() {
        use camel_api::{OutcomePipeline, PipelineOutcome};

        let config = ThrottlerConfig {
            max_requests: 0, // 0 tokens immediately exhausted
            period: Duration::from_secs(1),
            strategy: ThrottleStrategy::Reject,
        };
        let mut segment = ThrottleSegment::new(config, noop_body());
        let outcome = segment.run(Exchange::new(Message::new("test"))).await;
        assert!(
            matches!(outcome, PipelineOutcome::Failed(_)),
            "Reject strategy must return Failed when tokens exhausted"
        );
    }

    /// `Drop` on an empty bucket yields `PipelineOutcome::Stopped` with the
    /// stop flag set on the exchange.
    #[tokio::test]
    async fn throttle_segment_drop_strategy_returns_stopped() {
        use camel_api::{OutcomePipeline, PipelineOutcome};

        let config = ThrottlerConfig {
            max_requests: 0,
            period: Duration::from_secs(1),
            strategy: ThrottleStrategy::Drop,
        };
        let mut segment = ThrottleSegment::new(config, noop_body());
        let outcome = segment.run(Exchange::new(Message::new("test"))).await;

        let returned_ex = match outcome {
            PipelineOutcome::Stopped(ex) => ex,
            other => panic!("Drop must return Stopped, got {:?}", other),
        };
        let stopped_flag = returned_ex.property(CAMEL_STOP).and_then(|v| v.as_bool());
        assert_eq!(
            stopped_flag,
            Some(true),
            "Drop strategy must set CamelStop=true property"
        );
    }

    /// A `Stopped` outcome produced by the body, including the body's
    /// mutation of the exchange, is handed back unchanged.
    #[tokio::test]
    async fn throttle_segment_delay_strategy_propagates_stopped_body() {
        use camel_api::{Body, OutcomePipeline, PipelineOutcome};

        #[derive(Clone)]
        struct StoppingSeg;
        impl OutcomePipeline for StoppingSeg {
            fn clone_box(&self) -> Box<dyn OutcomePipeline> {
                Box::new(StoppingSeg)
            }
            fn run<'a>(
                &'a mut self,
                mut ex: Exchange,
            ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
                Box::pin(async move {
                    ex.input.body = Body::Bytes(b"stopped-mut".to_vec().into());
                    PipelineOutcome::Stopped(ex)
                })
            }
        }

        let config = ThrottlerConfig {
            max_requests: 1, // 1 token available immediately
            period: Duration::from_secs(1),
            strategy: ThrottleStrategy::Delay,
        };
        let stopping_body = camel_api::OutcomeSegment::new(Box::new(StoppingSeg));
        let mut segment = ThrottleSegment::new(config, stopping_body);
        let outcome = segment.run(Exchange::new(Message::new("test"))).await;

        let returned_ex = match outcome {
            PipelineOutcome::Stopped(ex) => ex,
            other => panic!("expected Stopped propagation, got {:?}", other),
        };
        match &returned_ex.input.body {
            Body::Bytes(b) => assert_eq!(
                b.as_ref(),
                b"stopped-mut",
                "BUG: throttle body Stop must preserve mutations"
            ),
            _ => panic!("expected Body::Bytes"),
        }
    }
}