// heliosdb_proxy/pipeline.rs

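//! Request Pipeline for HeliosProxy
//!
//! Provides request pipelining to reduce latency by sending multiple requests
//! without waiting for responses. Supports PostgreSQL protocol pipelining.
//! Each connection has its own bounded queue of pending requests, and every
//! submission yields a ticket that resolves once the request is completed.
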
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;

/// Identifier of a connection; each connection gets its own queue of
/// pending requests.
pub type ConnectionId = u64;

/// Opaque identifier of a pipelined request, used to match a completion to
/// the pending request it answers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RequestId(u64);

impl RequestId {
    fn new(id: u64) -> Self {
        Self(id)
    }

    /// Returns the raw numeric value, e.g. for logging or metrics.
    pub const fn as_u64(self) -> u64 {
        self.0
    }
}

27/// Pipeline configuration
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct PipelineConfig {
30    /// Maximum depth of the pipeline per connection
31    pub max_depth: usize,
32    /// Enable request pipelining
33    pub enabled: bool,
34    /// Timeout for individual requests (ms)
35    pub request_timeout_ms: u64,
36    /// Enable auto-flushing when idle
37    pub auto_flush: bool,
38    /// Auto-flush interval (ms)
39    pub auto_flush_interval_ms: u64,
40}
41
42impl Default for PipelineConfig {
43    fn default() -> Self {
44        Self {
45            max_depth: 16,
46            enabled: true,
47            request_timeout_ms: 30_000,
48            auto_flush: true,
49            auto_flush_interval_ms: 10,
50        }
51    }
52}
53
54/// A pending request in the pipeline
55#[derive(Debug)]
56pub struct PendingRequest {
57    /// Request ID
58    pub id: RequestId,
59    /// Request data (SQL query or command)
60    pub data: Vec<u8>,
61    /// Submission timestamp
62    pub submitted_at: Instant,
63    /// Response channel
64    response_tx: Option<oneshot::Sender<PipelineResponse>>,
65}
66
67/// Pipeline response
68#[derive(Debug)]
69pub struct PipelineResponse {
70    /// Request ID
71    pub request_id: RequestId,
72    /// Response data
73    pub data: Vec<u8>,
74    /// Response time (from submission to completion)
75    pub response_time: Duration,
76    /// Whether the request succeeded
77    pub success: bool,
78    /// Error message if failed
79    pub error: Option<String>,
80}
81
82/// Ticket for awaiting a pipelined response
83pub struct Ticket {
84    rx: oneshot::Receiver<PipelineResponse>,
85}
86
87impl Ticket {
88    /// Wait for the response
89    pub async fn wait(self) -> Result<PipelineResponse, PipelineError> {
90        self.rx.await.map_err(|_| PipelineError::ChannelClosed)
91    }
92
93    /// Wait with timeout
94    pub async fn wait_timeout(self, timeout: Duration) -> Result<PipelineResponse, PipelineError> {
95        tokio::time::timeout(timeout, self.rx)
96            .await
97            .map_err(|_| PipelineError::Timeout)?
98            .map_err(|_| PipelineError::ChannelClosed)
99    }
100}
101
/// Errors reported by the request pipeline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PipelineError {
    /// The connection already has the configured maximum of pending requests.
    PipelineFull,
    /// Pipelining is disabled in the configuration.
    Disabled,
    /// No response arrived within the allotted time.
    Timeout,
    /// The request was dropped before a response was delivered.
    ChannelClosed,
    /// Connection-level failure with a description (also used once the
    /// pipeline has been shut down).
    ConnectionError(String),
}

impl std::fmt::Display for PipelineError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PipelineFull => f.write_str("Pipeline is full"),
            Self::Disabled => f.write_str("Pipeline is disabled"),
            Self::Timeout => f.write_str("Request timeout"),
            Self::ChannelClosed => f.write_str("Channel closed"),
            Self::ConnectionError(e) => write!(f, "Connection error: {}", e),
        }
    }
}

impl std::error::Error for PipelineError {}

131/// Pipeline statistics
132#[derive(Debug, Clone, Default, Serialize, Deserialize)]
133pub struct PipelineStats {
134    /// Total requests submitted
135    pub requests_submitted: u64,
136    /// Total requests completed
137    pub requests_completed: u64,
138    /// Requests that timed out
139    pub requests_timeout: u64,
140    /// Requests rejected (pipeline full)
141    pub requests_rejected: u64,
142    /// Average pipeline depth
143    pub avg_pipeline_depth: f64,
144    /// Peak pipeline depth
145    pub peak_pipeline_depth: usize,
146    /// Average response time (ms)
147    pub avg_response_time_ms: f64,
148    /// Total bytes sent
149    pub bytes_sent: u64,
150    /// Total bytes received
151    pub bytes_received: u64,
152}
153
154/// Request Pipeline for a single connection
155struct ConnectionPipeline {
156    /// Pending requests
157    pending: VecDeque<PendingRequest>,
158    /// Peak depth for this connection
159    peak_depth: usize,
160}
161
162impl Default for ConnectionPipeline {
163    fn default() -> Self {
164        Self {
165            pending: VecDeque::with_capacity(16),
166            peak_depth: 0,
167        }
168    }
169}
170
171/// Request Pipeline Manager
172///
173/// Manages pipelined requests across multiple connections.
174pub struct RequestPipeline {
175    /// Configuration
176    config: PipelineConfig,
177    /// Pending requests per connection
178    connections: DashMap<ConnectionId, ConnectionPipeline>,
179    /// Next request ID
180    next_request_id: AtomicU64,
181    /// Statistics
182    stats: Arc<parking_lot::RwLock<PipelineStats>>,
183    /// Shutdown flag
184    shutdown: AtomicBool,
185}
186
187impl RequestPipeline {
188    /// Create a new request pipeline
189    pub fn new(config: PipelineConfig) -> Self {
190        Self {
191            config,
192            connections: DashMap::new(),
193            next_request_id: AtomicU64::new(1),
194            stats: Arc::new(parking_lot::RwLock::new(PipelineStats::default())),
195            shutdown: AtomicBool::new(false),
196        }
197    }
198
199    /// Submit a request to the pipeline
200    pub fn submit(&self, conn_id: ConnectionId, data: Vec<u8>) -> Result<Ticket, PipelineError> {
201        if !self.config.enabled {
202            return Err(PipelineError::Disabled);
203        }
204
205        if self.shutdown.load(Ordering::Relaxed) {
206            return Err(PipelineError::ConnectionError("Pipeline shutdown".to_string()));
207        }
208
209        let request_id = RequestId::new(self.next_request_id.fetch_add(1, Ordering::Relaxed));
210        let (tx, rx) = oneshot::channel();
211
212        let pending = PendingRequest {
213            id: request_id,
214            data,
215            submitted_at: Instant::now(),
216            response_tx: Some(tx),
217        };
218
219        // Get or create pipeline for connection
220        let mut pipeline = self.connections.entry(conn_id).or_default();
221
222        // Check pipeline depth
223        if pipeline.pending.len() >= self.config.max_depth {
224            self.stats.write().requests_rejected += 1;
225            return Err(PipelineError::PipelineFull);
226        }
227
228        // Track statistics
229        {
230            let mut stats = self.stats.write();
231            stats.requests_submitted += 1;
232            stats.bytes_sent += pending.data.len() as u64;
233        }
234
235        // Update peak depth
236        let current_depth = pipeline.pending.len() + 1;
237        if current_depth > pipeline.peak_depth {
238            pipeline.peak_depth = current_depth;
239        }
240
241        pipeline.pending.push_back(pending);
242
243        Ok(Ticket { rx })
244    }
245
246    /// Complete a request with a response
247    pub fn complete(&self, conn_id: ConnectionId, request_id: RequestId, data: Vec<u8>, success: bool, error: Option<String>) {
248        if let Some(mut pipeline) = self.connections.get_mut(&conn_id) {
249            // Find and remove the matching request
250            if let Some(pos) = pipeline.pending.iter().position(|r| r.id == request_id) {
251                if let Some(mut req) = pipeline.pending.remove(pos) {
252                    let response_time = req.submitted_at.elapsed();
253
254                    // Update statistics
255                    {
256                        let mut stats = self.stats.write();
257                        stats.requests_completed += 1;
258                        stats.bytes_received += data.len() as u64;
259
260                        // Update average response time (exponential moving average)
261                        let ms = response_time.as_millis() as f64;
262                        if stats.avg_response_time_ms == 0.0 {
263                            stats.avg_response_time_ms = ms;
264                        } else {
265                            stats.avg_response_time_ms = stats.avg_response_time_ms * 0.9 + ms * 0.1;
266                        }
267                    }
268
269                    // Send response
270                    if let Some(tx) = req.response_tx.take() {
271                        let _ = tx.send(PipelineResponse {
272                            request_id,
273                            data,
274                            response_time,
275                            success,
276                            error,
277                        });
278                    }
279                }
280            }
281        }
282    }
283
284    /// Complete the next pending request in order (FIFO)
285    pub fn complete_next(&self, conn_id: ConnectionId, data: Vec<u8>, success: bool, error: Option<String>) {
286        if let Some(mut pipeline) = self.connections.get_mut(&conn_id) {
287            if let Some(mut req) = pipeline.pending.pop_front() {
288                let response_time = req.submitted_at.elapsed();
289
290                // Update statistics
291                {
292                    let mut stats = self.stats.write();
293                    stats.requests_completed += 1;
294                    stats.bytes_received += data.len() as u64;
295
296                    let ms = response_time.as_millis() as f64;
297                    if stats.avg_response_time_ms == 0.0 {
298                        stats.avg_response_time_ms = ms;
299                    } else {
300                        stats.avg_response_time_ms = stats.avg_response_time_ms * 0.9 + ms * 0.1;
301                    }
302                }
303
304                if let Some(tx) = req.response_tx.take() {
305                    let _ = tx.send(PipelineResponse {
306                        request_id: req.id,
307                        data,
308                        response_time,
309                        success,
310                        error,
311                    });
312                }
313            }
314        }
315    }
316
317    /// Get current pipeline depth for a connection
318    pub fn depth(&self, conn_id: ConnectionId) -> usize {
319        self.connections
320            .get(&conn_id)
321            .map(|p| p.pending.len())
322            .unwrap_or(0)
323    }
324
325    /// Check if pipeline is empty for a connection
326    pub fn is_empty(&self, conn_id: ConnectionId) -> bool {
327        self.depth(conn_id) == 0
328    }
329
330    /// Clear pipeline for a connection (e.g., on connection close)
331    pub fn clear(&self, conn_id: ConnectionId) {
332        self.connections.remove(&conn_id);
333    }
334
335    /// Get statistics snapshot
336    pub fn stats(&self) -> PipelineStats {
337        let mut stats = self.stats.read().clone();
338
339        // Calculate peak pipeline depth across all connections
340        stats.peak_pipeline_depth = self.connections
341            .iter()
342            .map(|p| p.peak_depth)
343            .max()
344            .unwrap_or(0);
345
346        // Calculate average pipeline depth
347        let total_depth: usize = self.connections.iter().map(|p| p.pending.len()).sum();
348        let conn_count = self.connections.len();
349        stats.avg_pipeline_depth = if conn_count > 0 {
350            total_depth as f64 / conn_count as f64
351        } else {
352            0.0
353        };
354
355        stats
356    }
357
358    /// Shutdown the pipeline
359    pub fn shutdown(&self) {
360        self.shutdown.store(true, Ordering::Release);
361        self.connections.clear();
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pipeline_submit() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let conn_id = 1;

        let ticket = pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
        assert_eq!(pipeline.depth(conn_id), 1);

        // Complete the request
        pipeline.complete_next(conn_id, b"1".to_vec(), true, None);
        assert_eq!(pipeline.depth(conn_id), 0);

        // Verify response
        let response = ticket.wait().await.unwrap();
        assert!(response.success);
    }

    #[tokio::test]
    async fn test_pipeline_full() {
        let config = PipelineConfig {
            max_depth: 2,
            ..Default::default()
        };
        let pipeline = RequestPipeline::new(config);
        let conn_id = 1;

        // Submit up to max depth
        pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
        pipeline.submit(conn_id, b"SELECT 2".to_vec()).unwrap();

        // Third should fail and be counted as rejected
        let result = pipeline.submit(conn_id, b"SELECT 3".to_vec());
        assert!(matches!(result, Err(PipelineError::PipelineFull)));
        assert_eq!(pipeline.stats().requests_rejected, 1);
        assert_eq!(pipeline.depth(conn_id), 2);
    }

    #[test]
    fn test_pipeline_stats() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let conn_id = 1;

        pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
        pipeline.submit(conn_id, b"SELECT 2".to_vec()).unwrap();

        let stats = pipeline.stats();
        assert_eq!(stats.requests_submitted, 2);
        assert_eq!(stats.bytes_sent, 16);
        assert_eq!(stats.avg_pipeline_depth, 2.0);
    }

    #[test]
    fn test_pipeline_disabled() {
        let config = PipelineConfig {
            enabled: false,
            ..Default::default()
        };
        let pipeline = RequestPipeline::new(config);

        let result = pipeline.submit(1, b"SELECT 1".to_vec());
        assert!(matches!(result, Err(PipelineError::Disabled)));
    }

    #[tokio::test]
    async fn test_complete_next_is_fifo() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let first = pipeline.submit(1, b"SELECT 1".to_vec()).unwrap();
        let second = pipeline.submit(1, b"SELECT 2".to_vec()).unwrap();

        pipeline.complete_next(1, b"one".to_vec(), true, None);
        pipeline.complete_next(1, b"two".to_vec(), false, Some("boom".to_string()));

        let r1 = first.wait().await.unwrap();
        let r2 = second.wait().await.unwrap();
        assert_eq!(r1.data, b"one");
        assert!(r1.success);
        assert_eq!(r2.data, b"two");
        assert!(!r2.success);
        assert_eq!(r2.error.as_deref(), Some("boom"));
        assert!(pipeline.is_empty(1));
        assert_eq!(pipeline.stats().requests_completed, 2);
    }

    #[tokio::test]
    async fn test_complete_by_id_out_of_order() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let first = pipeline.submit(7, b"SELECT 1".to_vec()).unwrap();
        let second = pipeline.submit(7, b"SELECT 2".to_vec()).unwrap();

        // A fresh pipeline hands out IDs starting at 1.
        pipeline.complete(7, RequestId::new(2), b"two".to_vec(), true, None);
        assert_eq!(pipeline.depth(7), 1);
        let r2 = second.wait().await.unwrap();
        assert_eq!(r2.request_id, RequestId::new(2));
        assert_eq!(r2.data, b"two");

        pipeline.complete(7, RequestId::new(1), b"one".to_vec(), true, None);
        let r1 = first.wait().await.unwrap();
        assert_eq!(r1.request_id, RequestId::new(1));
        assert!(pipeline.is_empty(7));
    }

    #[test]
    fn test_complete_unknown_is_noop() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        pipeline.complete_next(42, Vec::new(), true, None);
        pipeline.complete(42, RequestId::new(1), Vec::new(), true, None);
        assert_eq!(pipeline.stats().requests_completed, 0);
        assert!(pipeline.is_empty(42));
    }

    #[tokio::test]
    async fn test_clear_closes_tickets() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let ticket = pipeline.submit(1, b"SELECT 1".to_vec()).unwrap();

        pipeline.clear(1);
        assert!(pipeline.is_empty(1));
        assert!(matches!(ticket.wait().await, Err(PipelineError::ChannelClosed)));
    }

    #[test]
    fn test_shutdown_rejects_submissions() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        pipeline.shutdown();

        let result = pipeline.submit(1, b"SELECT 1".to_vec());
        assert!(matches!(result, Err(PipelineError::ConnectionError(_))));
    }

    #[tokio::test]
    async fn test_wait_timeout() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let ticket = pipeline.submit(1, b"SELECT 1".to_vec()).unwrap();

        // Nobody completes the request, so the wait must time out.
        let result = ticket.wait_timeout(Duration::from_millis(10)).await;
        assert!(matches!(result, Err(PipelineError::Timeout)));
    }
}