Skip to main content

oxigdal_streaming/core/
flow_control.rs

1//! Flow control mechanisms for stream processing.
2
3use crate::error::{Result, StreamingError};
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::{Duration, Instant};
8use tokio::sync::RwLock;
9use tokio::time::sleep;
10
11/// Configuration for flow control.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct FlowControlConfig {
14    /// Maximum rate (elements per second)
15    pub max_rate: Option<f64>,
16
17    /// Burst size
18    pub burst_size: usize,
19
20    /// Enable rate limiting
21    pub enable_rate_limiting: bool,
22
23    /// Smoothing factor for rate adjustment
24    pub smoothing_factor: f64,
25
26    /// Target latency for adaptive control
27    pub target_latency: Duration,
28
29    /// Adjustment interval
30    pub adjustment_interval: Duration,
31}
32
33impl Default for FlowControlConfig {
34    fn default() -> Self {
35        Self {
36            max_rate: None,
37            burst_size: 100,
38            enable_rate_limiting: false,
39            smoothing_factor: 0.1,
40            target_latency: Duration::from_millis(100),
41            adjustment_interval: Duration::from_secs(5),
42        }
43    }
44}
45
/// Point-in-time flow-control statistics.
///
/// `Default` yields zeroed counters, no target rate and zero latency.
#[derive(Debug, Default)]
#[derive(Clone)]
pub struct FlowControlMetrics {
    /// Observed throughput, in elements per second.
    pub current_rate: f64,

    /// Target throughput in elements per second, if one has been set.
    pub target_rate: Option<f64>,

    /// How many operations were delayed by throttling.
    pub throttled_count: u64,

    /// Cumulative throttling delay, in milliseconds.
    pub total_delay_ms: u64,

    /// Moving-average processing latency.
    pub avg_latency: Duration,
}
64
65/// Token bucket for rate limiting.
66struct TokenBucket {
67    /// Available tokens
68    tokens: Arc<RwLock<f64>>,
69
70    /// Maximum tokens (burst size)
71    max_tokens: f64,
72
73    /// Refill rate (tokens per second)
74    refill_rate: f64,
75
76    /// Last refill time
77    last_refill: Arc<RwLock<Instant>>,
78}
79
80impl TokenBucket {
81    fn new(max_tokens: usize, refill_rate: f64) -> Self {
82        Self {
83            tokens: Arc::new(RwLock::new(max_tokens as f64)),
84            max_tokens: max_tokens as f64,
85            refill_rate,
86            last_refill: Arc::new(RwLock::new(Instant::now())),
87        }
88    }
89
90    async fn try_acquire(&self, count: usize) -> bool {
91        self.refill().await;
92
93        let mut tokens = self.tokens.write().await;
94        if *tokens >= count as f64 {
95            *tokens -= count as f64;
96            true
97        } else {
98            false
99        }
100    }
101
102    async fn refill(&self) {
103        let now = Instant::now();
104        let mut last_refill = self.last_refill.write().await;
105        let elapsed = now.duration_since(*last_refill).as_secs_f64();
106
107        if elapsed > 0.0 {
108            let mut tokens = self.tokens.write().await;
109            let new_tokens = elapsed * self.refill_rate;
110            *tokens = (*tokens + new_tokens).min(self.max_tokens);
111            *last_refill = now;
112        }
113    }
114
115    async fn wait_for_tokens(&self, count: usize) -> Duration {
116        self.refill().await;
117
118        let tokens = self.tokens.read().await;
119        if *tokens >= count as f64 {
120            return Duration::ZERO;
121        }
122
123        let needed = count as f64 - *tokens;
124        let wait_time = needed / self.refill_rate;
125        Duration::from_secs_f64(wait_time)
126    }
127}
128
129/// Flow controller for managing stream throughput.
130pub struct FlowController {
131    config: FlowControlConfig,
132    token_bucket: Option<TokenBucket>,
133    metrics: Arc<RwLock<FlowControlMetrics>>,
134    elements_processed: AtomicU64,
135    throttled_count: AtomicU64,
136    total_delay_ms: AtomicU64,
137    start_time: Instant,
138    last_adjustment: Arc<RwLock<Instant>>,
139}
140
141impl FlowController {
142    /// Create a new flow controller.
143    pub fn new(config: FlowControlConfig) -> Self {
144        let token_bucket = if config.enable_rate_limiting && config.max_rate.is_some() {
145            Some(TokenBucket::new(
146                config.burst_size,
147                config.max_rate.unwrap_or(1000.0),
148            ))
149        } else {
150            None
151        };
152
153        Self {
154            config,
155            token_bucket,
156            metrics: Arc::new(RwLock::new(FlowControlMetrics::default())),
157            elements_processed: AtomicU64::new(0),
158            throttled_count: AtomicU64::new(0),
159            total_delay_ms: AtomicU64::new(0),
160            start_time: Instant::now(),
161            last_adjustment: Arc::new(RwLock::new(Instant::now())),
162        }
163    }
164
165    /// Acquire permission to process elements.
166    pub async fn acquire(&self, count: usize) -> Result<()> {
167        if !self.config.enable_rate_limiting {
168            self.elements_processed
169                .fetch_add(count as u64, Ordering::Relaxed);
170            return Ok(());
171        }
172
173        if let Some(ref bucket) = self.token_bucket {
174            if !bucket.try_acquire(count).await {
175                let wait_time = bucket.wait_for_tokens(count).await;
176
177                if wait_time > Duration::ZERO {
178                    self.throttled_count.fetch_add(1, Ordering::Relaxed);
179                    self.total_delay_ms
180                        .fetch_add(wait_time.as_millis() as u64, Ordering::Relaxed);
181
182                    sleep(wait_time).await;
183
184                    // Try again after waiting
185                    if !bucket.try_acquire(count).await {
186                        return Err(StreamingError::Other(
187                            "Failed to acquire tokens after waiting".to_string(),
188                        ));
189                    }
190                }
191            }
192        }
193
194        self.elements_processed
195            .fetch_add(count as u64, Ordering::Relaxed);
196        Ok(())
197    }
198
199    /// Try to acquire without blocking.
200    pub async fn try_acquire(&self, count: usize) -> bool {
201        if !self.config.enable_rate_limiting {
202            self.elements_processed
203                .fetch_add(count as u64, Ordering::Relaxed);
204            return true;
205        }
206
207        if let Some(ref bucket) = self.token_bucket {
208            if bucket.try_acquire(count).await {
209                self.elements_processed
210                    .fetch_add(count as u64, Ordering::Relaxed);
211                true
212            } else {
213                false
214            }
215        } else {
216            self.elements_processed
217                .fetch_add(count as u64, Ordering::Relaxed);
218            true
219        }
220    }
221
222    /// Record processing latency.
223    pub async fn record_latency(&self, latency: Duration) {
224        let mut metrics = self.metrics.write().await;
225
226        let alpha = self.config.smoothing_factor;
227        let new_latency_secs = latency.as_secs_f64();
228        let old_latency_secs = metrics.avg_latency.as_secs_f64();
229        let avg_latency_secs = alpha * new_latency_secs + (1.0 - alpha) * old_latency_secs;
230        metrics.avg_latency = Duration::from_secs_f64(avg_latency_secs);
231    }
232
233    /// Adjust rate based on observed latency.
234    pub async fn adjust_rate_adaptive(&self) {
235        let now = Instant::now();
236        let last_adjustment = *self.last_adjustment.read().await;
237
238        if now.duration_since(last_adjustment) < self.config.adjustment_interval {
239            return;
240        }
241
242        let metrics = self.metrics.read().await;
243        let current_latency = metrics.avg_latency;
244        let target_latency = self.config.target_latency;
245
246        drop(metrics);
247
248        if let Some(ref bucket) = self.token_bucket {
249            let current_rate = bucket.refill_rate;
250            let latency_ratio = current_latency.as_secs_f64() / target_latency.as_secs_f64();
251
252            let new_rate = if latency_ratio > 1.2 {
253                current_rate * 0.9
254            } else if latency_ratio < 0.8 {
255                current_rate * 1.1
256            } else {
257                current_rate
258            };
259
260            // Update metrics
261            let mut metrics = self.metrics.write().await;
262            metrics.target_rate = Some(new_rate);
263
264            *self.last_adjustment.write().await = now;
265        }
266    }
267
268    /// Get current metrics.
269    pub async fn metrics(&self) -> FlowControlMetrics {
270        let mut metrics = self.metrics.read().await.clone();
271
272        let elapsed = self.start_time.elapsed().as_secs_f64();
273        let processed = self.elements_processed.load(Ordering::Relaxed);
274        metrics.current_rate = processed as f64 / elapsed;
275        metrics.throttled_count = self.throttled_count.load(Ordering::Relaxed);
276        metrics.total_delay_ms = self.total_delay_ms.load(Ordering::Relaxed);
277
278        metrics
279    }
280
281    /// Reset metrics.
282    pub async fn reset_metrics(&self) {
283        let mut metrics = self.metrics.write().await;
284        *metrics = FlowControlMetrics::default();
285
286        self.elements_processed.store(0, Ordering::Relaxed);
287        self.throttled_count.store(0, Ordering::Relaxed);
288        self.total_delay_ms.store(0, Ordering::Relaxed);
289    }
290
291    /// Get current rate.
292    pub async fn current_rate(&self) -> f64 {
293        let elapsed = self.start_time.elapsed().as_secs_f64();
294        let processed = self.elements_processed.load(Ordering::Relaxed);
295
296        if elapsed > 0.0 {
297            processed as f64 / elapsed
298        } else {
299            0.0
300        }
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Controller built from the default (rate-limiting disabled) config.
    fn default_controller() -> FlowController {
        FlowController::new(FlowControlConfig::default())
    }

    #[tokio::test]
    async fn test_flow_controller_creation() {
        let controller = default_controller();

        let rate = controller.current_rate().await;
        assert_eq!(rate, 0.0);
    }

    #[tokio::test]
    async fn test_flow_controller_acquire() {
        let controller = default_controller();

        let outcome = controller.acquire(10).await;
        outcome.expect("flow controller acquire should succeed");

        let snapshot = controller.metrics().await;
        assert!(snapshot.current_rate > 0.0);
    }

    #[tokio::test]
    async fn test_token_bucket() {
        let bucket = TokenBucket::new(100, 10.0);

        // Two half-capacity withdrawals empty the bucket...
        for _ in 0..2 {
            assert!(bucket.try_acquire(50).await);
        }
        // ...leaving nothing for a third request.
        assert!(!bucket.try_acquire(1).await);
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        let controller = FlowController::new(FlowControlConfig {
            enable_rate_limiting: true,
            max_rate: Some(100.0),
            burst_size: 50,
            ..Default::default()
        });

        // The entire burst capacity is available up front.
        controller
            .acquire(50)
            .await
            .expect("flow controller acquire should succeed");

        // At 100 tokens/s, a 20 ms pause refills roughly two tokens.
        let refill_pause = Duration::from_millis(20);
        tokio::time::sleep(refill_pause).await;

        let granted = controller.try_acquire(1).await;
        assert!(granted);
    }

    #[tokio::test]
    async fn test_latency_recording() {
        let controller = default_controller();

        for millis in [100, 200] {
            controller
                .record_latency(Duration::from_millis(millis))
                .await;
        }

        let snapshot = controller.metrics().await;
        assert!(snapshot.avg_latency > Duration::ZERO);
    }
}