1use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::{OwnedSemaphorePermit, Semaphore};
10use tokio::time::sleep;
11
/// Tuning parameters for [`BackpressureController`].
#[derive(Debug, Clone)]
pub struct BackpressureConfig {
    /// Window size (and semaphore permit count) the controller starts with.
    pub initial_window: usize,
    /// Lower bound applied when the window is shrunk on congestion.
    pub min_window: usize,
    /// Upper bound applied when the window is grown; no growth is attempted at or above it.
    pub max_window: usize,
    /// Multiplier applied to the window when utilization is low.
    pub increase_factor: f64,
    /// Multiplier applied to the window when congestion is detected.
    pub decrease_factor: f64,
    /// Utilization (`pending_items / window_size`) at or above which the consumer is treated
    /// as slow and the window is shrunk; `adaptive_delay` sleeps when utilization exceeds it.
    pub slow_consumer_threshold: f64,
    /// Minimum time between window adjustments; also the polling period of `start_monitoring`.
    pub check_interval: Duration,
}
30
31impl Default for BackpressureConfig {
32 fn default() -> Self {
33 Self {
34 initial_window: 100,
35 min_window: 10,
36 max_window: 10000,
37 increase_factor: 1.5,
38 decrease_factor: 0.5,
39 slow_consumer_threshold: 0.8,
40 check_interval: Duration::from_millis(100),
41 }
42 }
43}
44
/// Window-based flow controller whose window size is adjusted from observed utilization.
///
/// All state lives behind `Arc`s, so clones share the same window, semaphore, and counters.
#[derive(Clone)]
pub struct BackpressureController {
    /// Shared tuning parameters.
    config: Arc<BackpressureConfig>,
    /// Source of the permits held by each outstanding `BackpressurePermit`.
    semaphore: Arc<Semaphore>,
    /// Window size as tracked by the adaptive logic.
    window_size: Arc<AtomicUsize>,
    /// Number of permits successfully acquired so far.
    items_sent: Arc<AtomicU64>,
    /// Number of `signal_consumed` calls so far.
    items_consumed: Arc<AtomicU64>,
    /// Time of the last congestion evaluation that passed the interval check; the async
    /// mutex also serializes concurrent `check_congestion` calls.
    last_adjustment: Arc<tokio::sync::Mutex<Instant>>,
}
55
56impl BackpressureController {
57 pub fn new(config: BackpressureConfig) -> Self {
59 let initial_window = config.initial_window;
60 Self {
61 semaphore: Arc::new(Semaphore::new(initial_window)),
62 window_size: Arc::new(AtomicUsize::new(initial_window)),
63 items_sent: Arc::new(AtomicU64::new(0)),
64 items_consumed: Arc::new(AtomicU64::new(0)),
65 last_adjustment: Arc::new(tokio::sync::Mutex::new(Instant::now())),
66 config: Arc::new(config),
67 }
68 }
69
70 pub async fn acquire(&self) -> Result<BackpressurePermit, BackpressureError> {
72 let permit = self
73 .semaphore
74 .clone()
75 .acquire_owned()
76 .await
77 .map_err(|_| BackpressureError::Closed)?;
78
79 self.items_sent.fetch_add(1, Ordering::Relaxed);
80
81 Ok(BackpressurePermit {
82 _permit: permit,
83 controller: self.clone(),
84 })
85 }
86
87 pub fn try_acquire(&self) -> Result<BackpressurePermit, BackpressureError> {
89 let permit = self
90 .semaphore
91 .clone()
92 .try_acquire_owned()
93 .map_err(|_| BackpressureError::WouldBlock)?;
94
95 self.items_sent.fetch_add(1, Ordering::Relaxed);
96
97 Ok(BackpressurePermit {
98 _permit: permit,
99 controller: self.clone(),
100 })
101 }
102
103 pub fn signal_consumed(&self) {
105 self.items_consumed.fetch_add(1, Ordering::Relaxed);
106 }
107
108 pub fn window_size(&self) -> usize {
110 self.window_size.load(Ordering::Relaxed)
111 }
112
113 pub fn items_sent(&self) -> u64 {
115 self.items_sent.load(Ordering::Relaxed)
116 }
117
118 pub fn items_consumed(&self) -> u64 {
120 self.items_consumed.load(Ordering::Relaxed)
121 }
122
123 pub fn pending_items(&self) -> u64 {
125 let sent = self.items_sent();
126 let consumed = self.items_consumed();
127 sent.saturating_sub(consumed)
128 }
129
130 pub async fn check_congestion(&self) {
132 let mut last_adjustment = self.last_adjustment.lock().await;
133 let now = Instant::now();
134
135 if now.duration_since(*last_adjustment) < self.config.check_interval {
137 return;
138 }
139
140 let pending = self.pending_items();
141 let window = self.window_size() as u64;
142
143 if window == 0 {
144 return;
145 }
146
147 let utilization = pending as f64 / window as f64;
148
149 if utilization >= self.config.slow_consumer_threshold {
151 self.decrease_window().await;
152 tracing::debug!(
153 "Congestion detected, decreased window to {}",
154 self.window_size()
155 );
156 } else if utilization < 0.5 && (window as usize) < self.config.max_window {
157 self.increase_window().await;
159 tracing::debug!(
160 "Low utilization, increased window to {}",
161 self.window_size()
162 );
163 }
164
165 *last_adjustment = now;
166 }
167
168 async fn increase_window(&self) {
170 let current = self.window_size();
171 let new_size =
172 ((current as f64 * self.config.increase_factor) as usize).min(self.config.max_window);
173
174 if new_size > current {
175 let diff = new_size - current;
176 self.window_size.store(new_size, Ordering::Relaxed);
177 self.semaphore.add_permits(diff);
178 }
179 }
180
181 async fn decrease_window(&self) {
183 let current = self.window_size();
184 let new_size =
185 ((current as f64 * self.config.decrease_factor) as usize).max(self.config.min_window);
186
187 if new_size < current {
188 self.window_size.store(new_size, Ordering::Relaxed);
189 }
192 }
193
194 pub async fn adaptive_delay(&self) {
196 let pending = self.pending_items();
197 let window = self.window_size() as u64;
198
199 if window == 0 {
200 return;
201 }
202
203 let utilization = pending as f64 / window as f64;
204
205 if utilization > self.config.slow_consumer_threshold {
206 let delay_ms = ((utilization - self.config.slow_consumer_threshold) * 100.0) as u64;
208 sleep(Duration::from_millis(delay_ms)).await;
209 }
210 }
211
212 pub fn start_monitoring(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
214 tokio::spawn(async move {
215 loop {
216 self.check_congestion().await;
217 sleep(self.config.check_interval).await;
218 }
219 })
220 }
221}
222
223impl Default for BackpressureController {
224 fn default() -> Self {
225 Self::new(BackpressureConfig::default())
226 }
227}
228
/// Handle for one slot in a [`BackpressureController`]'s window.
///
/// While it is alive, one semaphore permit stays checked out. Dropping it returns that
/// permit but does not count the item as consumed; call
/// `BackpressureController::signal_consumed` for that.
pub struct BackpressurePermit {
    /// Semaphore permit, handed back to the controller's semaphore when dropped.
    _permit: OwnedSemaphorePermit,
    // NOTE(review): never read; it only keeps a clone of the controller (sharing its
    // Arc-held state) alive as long as the permit exists.
    #[allow(dead_code)]
    controller: BackpressureController,
}
235
impl Drop for BackpressurePermit {
    // Intentionally a no-op. After this runs, the `_permit` field
    // (an `OwnedSemaphorePermit`) is dropped, which returns the slot to the semaphore.
    // Consumption is tracked separately through `signal_consumed`, so dropping a permit
    // leaves `items_consumed` unchanged.
    fn drop(&mut self) {
    }
}
243
244#[derive(Debug, Clone, thiserror::Error)]
246pub enum BackpressureError {
247 #[error("Backpressure controller is closed")]
248 Closed,
249 #[error("Would block, no permits available")]
250 WouldBlock,
251}
252
253pub struct BackpressureStream<S> {
255 inner: S,
256 controller: Arc<BackpressureController>,
257}
258
259impl<S> BackpressureStream<S> {
260 pub fn new(stream: S, controller: Arc<BackpressureController>) -> Self {
262 Self {
263 inner: stream,
264 controller,
265 }
266 }
267
268 pub fn controller(&self) -> &Arc<BackpressureController> {
270 &self.controller
271 }
272}
273
274impl<S> futures::Stream for BackpressureStream<S>
275where
276 S: futures::Stream + Unpin,
277{
278 type Item = S::Item;
279
280 fn poll_next(
281 mut self: std::pin::Pin<&mut Self>,
282 cx: &mut std::task::Context<'_>,
283 ) -> std::task::Poll<Option<Self::Item>> {
284 let pending = self.controller.pending_items();
286 let window = self.controller.window_size() as u64;
287
288 if window > 0 && pending >= window {
289 cx.waker().wake_by_ref();
291 return std::task::Poll::Pending;
292 }
293
294 std::pin::Pin::new(&mut self.inner).poll_next(cx)
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_backpressure_controller_creation() {
        let config = BackpressureConfig::default();
        let controller = BackpressureController::new(config);

        assert_eq!(controller.window_size(), 100);
        assert_eq!(controller.items_sent(), 0);
        assert_eq!(controller.items_consumed(), 0);
        assert_eq!(controller.pending_items(), 0);
    }

    #[tokio::test]
    async fn test_acquire_permit() {
        let controller = BackpressureController::default();

        let permit = controller.acquire().await.unwrap();
        assert_eq!(controller.items_sent(), 1);
        assert_eq!(controller.items_consumed(), 0);
        assert_eq!(controller.pending_items(), 1);

        // Dropping the permit frees the slot but is not a consumption signal.
        drop(permit);
        controller.signal_consumed();
        assert_eq!(controller.items_consumed(), 1);
        assert_eq!(controller.pending_items(), 0);
    }

    #[tokio::test]
    async fn test_try_acquire() {
        let config = BackpressureConfig {
            initial_window: 2,
            ..Default::default()
        };
        let controller = BackpressureController::new(config);

        let _permit1 = controller.try_acquire().unwrap();
        let _permit2 = controller.try_acquire().unwrap();

        // Window of 2 is exhausted.
        assert!(controller.try_acquire().is_err());
    }

    #[tokio::test]
    async fn test_congestion_detection() {
        let config = BackpressureConfig {
            initial_window: 10,
            min_window: 5,
            slow_consumer_threshold: 0.8,
            check_interval: Duration::from_millis(10),
            ..Default::default()
        };
        let controller = BackpressureController::new(config);

        // 9 pending in a window of 10 => 0.9 utilization, above the 0.8 threshold.
        let mut permits = Vec::new();
        for _ in 0..9 {
            permits.push(controller.acquire().await.unwrap());
        }

        assert_eq!(controller.pending_items(), 9);

        // Let check_interval elapse so the check is not rate-limited.
        sleep(Duration::from_millis(20)).await;
        controller.check_congestion().await;

        assert!(controller.window_size() < 10);
    }

    #[tokio::test]
    async fn test_window_increase() {
        let config = BackpressureConfig {
            initial_window: 10,
            max_window: 100,
            increase_factor: 2.0,
            check_interval: Duration::from_millis(10),
            ..Default::default()
        };
        let controller = BackpressureController::new(config);

        // Nothing pending => utilization 0, so the window should grow.
        sleep(Duration::from_millis(20)).await;
        controller.check_congestion().await;

        assert!(controller.window_size() > 10);
    }

    #[tokio::test]
    async fn test_adaptive_delay() {
        let config = BackpressureConfig {
            initial_window: 10,
            slow_consumer_threshold: 0.8,
            ..Default::default()
        };
        let controller = BackpressureController::new(config);

        // 9 pending / window 10 = 0.9 utilization, above the 0.8 threshold.
        let mut permits = Vec::new();
        for _ in 0..9 {
            permits.push(controller.acquire().await.unwrap());
        }

        let start = Instant::now();
        controller.adaptive_delay().await;
        let elapsed = start.elapsed();

        // (0.9 - 0.8) * 100 ms is ~9-10 ms after float truncation; tokio sleeps never end
        // early, so a 5 ms lower bound is safe and proves a delay actually happened.
        assert!(
            elapsed >= Duration::from_millis(5),
            "expected adaptive_delay to sleep, elapsed {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_automatic_monitoring() {
        let config = BackpressureConfig {
            initial_window: 10,
            check_interval: Duration::from_millis(50),
            ..Default::default()
        };
        let controller = Arc::new(BackpressureController::new(config));

        let handle = controller.clone().start_monitoring();

        // Several check intervals pass while the monitor runs.
        sleep(Duration::from_millis(200)).await;

        handle.abort();

        // Nothing was pending, so the monitor must have grown the window at least once.
        assert!(controller.window_size() > 10);

        let _permit = controller.acquire().await.unwrap();
    }

    #[tokio::test]
    async fn test_signal_consumed() {
        let controller = BackpressureController::default();

        controller.signal_consumed();
        assert_eq!(controller.items_consumed(), 1);

        controller.signal_consumed();
        assert_eq!(controller.items_consumed(), 2);
    }
}