Skip to main content

heliosdb_proxy/
pipeline.rs

1//! Request Pipeline for HeliosProxy
2//!
3//! Provides request pipelining to reduce latency by sending multiple requests
4//! without waiting for responses. Supports PostgreSQL protocol pipelining.
5
6use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::collections::VecDeque;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::oneshot;
13
/// Identifier of a client connection; all requests sharing one ID share one pipeline.
pub type ConnectionId = u64;

/// Opaque handle naming one pipelined request.
///
/// The wrapped counter is private and `new` is module-private, so code outside
/// this module can only obtain a `RequestId` from the pipeline itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestId(u64);

impl RequestId {
    /// Wrap a raw counter value as a request identifier.
    fn new(raw: u64) -> Self {
        RequestId(raw)
    }
}
26
27/// Pipeline configuration
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct PipelineConfig {
30    /// Maximum depth of the pipeline per connection
31    pub max_depth: usize,
32    /// Enable request pipelining
33    pub enabled: bool,
34    /// Timeout for individual requests (ms)
35    pub request_timeout_ms: u64,
36    /// Enable auto-flushing when idle
37    pub auto_flush: bool,
38    /// Auto-flush interval (ms)
39    pub auto_flush_interval_ms: u64,
40}
41
42impl Default for PipelineConfig {
43    fn default() -> Self {
44        Self {
45            max_depth: 16,
46            enabled: true,
47            request_timeout_ms: 30_000,
48            auto_flush: true,
49            auto_flush_interval_ms: 10,
50        }
51    }
52}
53
54/// A pending request in the pipeline
55#[derive(Debug)]
56pub struct PendingRequest {
57    /// Request ID
58    pub id: RequestId,
59    /// Request data (SQL query or command)
60    pub data: Vec<u8>,
61    /// Submission timestamp
62    pub submitted_at: Instant,
63    /// Response channel
64    response_tx: Option<oneshot::Sender<PipelineResponse>>,
65}
66
67/// Pipeline response
68#[derive(Debug)]
69pub struct PipelineResponse {
70    /// Request ID
71    pub request_id: RequestId,
72    /// Response data
73    pub data: Vec<u8>,
74    /// Response time (from submission to completion)
75    pub response_time: Duration,
76    /// Whether the request succeeded
77    pub success: bool,
78    /// Error message if failed
79    pub error: Option<String>,
80}
81
82/// Ticket for awaiting a pipelined response
83pub struct Ticket {
84    rx: oneshot::Receiver<PipelineResponse>,
85}
86
87impl Ticket {
88    /// Wait for the response
89    pub async fn wait(self) -> Result<PipelineResponse, PipelineError> {
90        self.rx.await.map_err(|_| PipelineError::ChannelClosed)
91    }
92
93    /// Wait with timeout
94    pub async fn wait_timeout(self, timeout: Duration) -> Result<PipelineResponse, PipelineError> {
95        tokio::time::timeout(timeout, self.rx)
96            .await
97            .map_err(|_| PipelineError::Timeout)?
98            .map_err(|_| PipelineError::ChannelClosed)
99    }
100}
101
/// Reasons a pipeline operation can fail.
#[derive(Debug, Clone)]
pub enum PipelineError {
    /// The connection already has `max_depth` pending requests.
    PipelineFull,
    /// Pipelining is switched off in the configuration.
    Disabled,
    /// The caller stopped waiting before a response arrived.
    Timeout,
    /// The response channel closed without a response being sent.
    ChannelClosed,
    /// Connection-level failure, with a human-readable detail.
    ConnectionError(String),
}

impl std::fmt::Display for PipelineError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            PipelineError::PipelineFull => "Pipeline is full",
            PipelineError::Disabled => "Pipeline is disabled",
            PipelineError::Timeout => "Request timeout",
            PipelineError::ChannelClosed => "Channel closed",
            PipelineError::ConnectionError(detail) => {
                return write!(f, "Connection error: {}", detail);
            }
        };
        f.write_str(text)
    }
}

impl std::error::Error for PipelineError {}
130
131/// Pipeline statistics
132#[derive(Debug, Clone, Default, Serialize, Deserialize)]
133pub struct PipelineStats {
134    /// Total requests submitted
135    pub requests_submitted: u64,
136    /// Total requests completed
137    pub requests_completed: u64,
138    /// Requests that timed out
139    pub requests_timeout: u64,
140    /// Requests rejected (pipeline full)
141    pub requests_rejected: u64,
142    /// Average pipeline depth
143    pub avg_pipeline_depth: f64,
144    /// Peak pipeline depth
145    pub peak_pipeline_depth: usize,
146    /// Average response time (ms)
147    pub avg_response_time_ms: f64,
148    /// Total bytes sent
149    pub bytes_sent: u64,
150    /// Total bytes received
151    pub bytes_received: u64,
152}
153
154/// Request Pipeline for a single connection
155struct ConnectionPipeline {
156    /// Pending requests
157    pending: VecDeque<PendingRequest>,
158    /// Peak depth for this connection
159    peak_depth: usize,
160}
161
162impl Default for ConnectionPipeline {
163    fn default() -> Self {
164        Self {
165            pending: VecDeque::with_capacity(16),
166            peak_depth: 0,
167        }
168    }
169}
170
171/// Request Pipeline Manager
172///
173/// Manages pipelined requests across multiple connections.
174pub struct RequestPipeline {
175    /// Configuration
176    config: PipelineConfig,
177    /// Pending requests per connection
178    connections: DashMap<ConnectionId, ConnectionPipeline>,
179    /// Next request ID
180    next_request_id: AtomicU64,
181    /// Statistics
182    stats: Arc<parking_lot::RwLock<PipelineStats>>,
183    /// Shutdown flag
184    shutdown: AtomicBool,
185}
186
187impl RequestPipeline {
188    /// Create a new request pipeline
189    pub fn new(config: PipelineConfig) -> Self {
190        Self {
191            config,
192            connections: DashMap::new(),
193            next_request_id: AtomicU64::new(1),
194            stats: Arc::new(parking_lot::RwLock::new(PipelineStats::default())),
195            shutdown: AtomicBool::new(false),
196        }
197    }
198
199    /// Submit a request to the pipeline
200    pub fn submit(&self, conn_id: ConnectionId, data: Vec<u8>) -> Result<Ticket, PipelineError> {
201        if !self.config.enabled {
202            return Err(PipelineError::Disabled);
203        }
204
205        if self.shutdown.load(Ordering::Relaxed) {
206            return Err(PipelineError::ConnectionError(
207                "Pipeline shutdown".to_string(),
208            ));
209        }
210
211        let request_id = RequestId::new(self.next_request_id.fetch_add(1, Ordering::Relaxed));
212        let (tx, rx) = oneshot::channel();
213
214        let pending = PendingRequest {
215            id: request_id,
216            data,
217            submitted_at: Instant::now(),
218            response_tx: Some(tx),
219        };
220
221        // Get or create pipeline for connection
222        let mut pipeline = self.connections.entry(conn_id).or_default();
223
224        // Check pipeline depth
225        if pipeline.pending.len() >= self.config.max_depth {
226            self.stats.write().requests_rejected += 1;
227            return Err(PipelineError::PipelineFull);
228        }
229
230        // Track statistics
231        {
232            let mut stats = self.stats.write();
233            stats.requests_submitted += 1;
234            stats.bytes_sent += pending.data.len() as u64;
235        }
236
237        // Update peak depth
238        let current_depth = pipeline.pending.len() + 1;
239        if current_depth > pipeline.peak_depth {
240            pipeline.peak_depth = current_depth;
241        }
242
243        pipeline.pending.push_back(pending);
244
245        Ok(Ticket { rx })
246    }
247
248    /// Complete a request with a response
249    pub fn complete(
250        &self,
251        conn_id: ConnectionId,
252        request_id: RequestId,
253        data: Vec<u8>,
254        success: bool,
255        error: Option<String>,
256    ) {
257        if let Some(mut pipeline) = self.connections.get_mut(&conn_id) {
258            // Find and remove the matching request
259            if let Some(pos) = pipeline.pending.iter().position(|r| r.id == request_id) {
260                if let Some(mut req) = pipeline.pending.remove(pos) {
261                    let response_time = req.submitted_at.elapsed();
262
263                    // Update statistics
264                    {
265                        let mut stats = self.stats.write();
266                        stats.requests_completed += 1;
267                        stats.bytes_received += data.len() as u64;
268
269                        // Update average response time (exponential moving average)
270                        let ms = response_time.as_millis() as f64;
271                        if stats.avg_response_time_ms == 0.0 {
272                            stats.avg_response_time_ms = ms;
273                        } else {
274                            stats.avg_response_time_ms =
275                                stats.avg_response_time_ms * 0.9 + ms * 0.1;
276                        }
277                    }
278
279                    // Send response
280                    if let Some(tx) = req.response_tx.take() {
281                        let _ = tx.send(PipelineResponse {
282                            request_id,
283                            data,
284                            response_time,
285                            success,
286                            error,
287                        });
288                    }
289                }
290            }
291        }
292    }
293
294    /// Complete the next pending request in order (FIFO)
295    pub fn complete_next(
296        &self,
297        conn_id: ConnectionId,
298        data: Vec<u8>,
299        success: bool,
300        error: Option<String>,
301    ) {
302        if let Some(mut pipeline) = self.connections.get_mut(&conn_id) {
303            if let Some(mut req) = pipeline.pending.pop_front() {
304                let response_time = req.submitted_at.elapsed();
305
306                // Update statistics
307                {
308                    let mut stats = self.stats.write();
309                    stats.requests_completed += 1;
310                    stats.bytes_received += data.len() as u64;
311
312                    let ms = response_time.as_millis() as f64;
313                    if stats.avg_response_time_ms == 0.0 {
314                        stats.avg_response_time_ms = ms;
315                    } else {
316                        stats.avg_response_time_ms = stats.avg_response_time_ms * 0.9 + ms * 0.1;
317                    }
318                }
319
320                if let Some(tx) = req.response_tx.take() {
321                    let _ = tx.send(PipelineResponse {
322                        request_id: req.id,
323                        data,
324                        response_time,
325                        success,
326                        error,
327                    });
328                }
329            }
330        }
331    }
332
333    /// Get current pipeline depth for a connection
334    pub fn depth(&self, conn_id: ConnectionId) -> usize {
335        self.connections
336            .get(&conn_id)
337            .map(|p| p.pending.len())
338            .unwrap_or(0)
339    }
340
341    /// Check if pipeline is empty for a connection
342    pub fn is_empty(&self, conn_id: ConnectionId) -> bool {
343        self.depth(conn_id) == 0
344    }
345
346    /// Clear pipeline for a connection (e.g., on connection close)
347    pub fn clear(&self, conn_id: ConnectionId) {
348        self.connections.remove(&conn_id);
349    }
350
351    /// Get statistics snapshot
352    pub fn stats(&self) -> PipelineStats {
353        let mut stats = self.stats.read().clone();
354
355        // Calculate peak pipeline depth across all connections
356        stats.peak_pipeline_depth = self
357            .connections
358            .iter()
359            .map(|p| p.peak_depth)
360            .max()
361            .unwrap_or(0);
362
363        // Calculate average pipeline depth
364        let total_depth: usize = self.connections.iter().map(|p| p.pending.len()).sum();
365        let conn_count = self.connections.len();
366        stats.avg_pipeline_depth = if conn_count > 0 {
367            total_depth as f64 / conn_count as f64
368        } else {
369            0.0
370        };
371
372        stats
373    }
374
375    /// Shutdown the pipeline
376    pub fn shutdown(&self) {
377        self.shutdown.store(true, Ordering::Release);
378        self.connections.clear();
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pipeline_submit() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let conn_id = 1;

        let ticket = pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
        assert_eq!(pipeline.depth(conn_id), 1);

        // Complete the request in FIFO order.
        pipeline.complete_next(conn_id, b"1".to_vec(), true, None);
        assert_eq!(pipeline.depth(conn_id), 0);

        // The ticket observes the delivered response.
        let response = ticket.wait().await.unwrap();
        assert!(response.success);
        assert_eq!(response.data, b"1".to_vec());
    }

    #[tokio::test]
    async fn test_pipeline_full() {
        let config = PipelineConfig {
            max_depth: 2,
            ..Default::default()
        };
        let pipeline = RequestPipeline::new(config);
        let conn_id = 1;

        // Submit up to max depth.
        pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
        pipeline.submit(conn_id, b"SELECT 2".to_vec()).unwrap();

        // The third submission is rejected and counted.
        let result = pipeline.submit(conn_id, b"SELECT 3".to_vec());
        assert!(matches!(result, Err(PipelineError::PipelineFull)));
        assert_eq!(pipeline.stats().requests_rejected, 1);
    }

    #[test]
    fn test_pipeline_stats() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let conn_id = 1;

        pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
        pipeline.submit(conn_id, b"SELECT 2".to_vec()).unwrap();

        let stats = pipeline.stats();
        assert_eq!(stats.requests_submitted, 2);
        assert_eq!(stats.bytes_sent, 16);
    }

    #[test]
    fn test_pipeline_disabled() {
        let config = PipelineConfig {
            enabled: false,
            ..Default::default()
        };
        let pipeline = RequestPipeline::new(config);

        let result = pipeline.submit(1, b"SELECT 1".to_vec());
        assert!(matches!(result, Err(PipelineError::Disabled)));
    }

    #[tokio::test]
    async fn test_complete_by_id_out_of_order() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let conn_id = 7;

        let first = pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
        let second = pipeline.submit(conn_id, b"SELECT 2".to_vec()).unwrap();

        // IDs come from a counter starting at 1, so the second request is ID 2.
        pipeline.complete(conn_id, RequestId::new(2), b"2".to_vec(), true, None);
        assert_eq!(pipeline.depth(conn_id), 1);

        let response = second.wait().await.unwrap();
        assert_eq!(response.request_id, RequestId::new(2));
        assert_eq!(response.data, b"2".to_vec());

        pipeline.complete_next(conn_id, b"1".to_vec(), true, None);
        let response = first.wait().await.unwrap();
        assert_eq!(response.request_id, RequestId::new(1));
        assert!(pipeline.is_empty(conn_id));
    }

    #[test]
    fn test_complete_unknown_is_noop() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        pipeline.submit(1, b"SELECT 1".to_vec()).unwrap();

        pipeline.complete(1, RequestId::new(999), Vec::new(), true, None);
        pipeline.complete(2, RequestId::new(1), Vec::new(), true, None);
        pipeline.complete_next(3, Vec::new(), true, None);

        assert_eq!(pipeline.depth(1), 1);
        assert_eq!(pipeline.stats().requests_completed, 0);
    }

    #[tokio::test]
    async fn test_clear_closes_tickets() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let ticket = pipeline.submit(1, b"SELECT 1".to_vec()).unwrap();

        pipeline.clear(1);
        assert!(pipeline.is_empty(1));
        assert!(matches!(ticket.wait().await, Err(PipelineError::ChannelClosed)));
    }

    #[test]
    fn test_shutdown_rejects_submissions() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        pipeline.shutdown();

        let result = pipeline.submit(1, b"SELECT 1".to_vec());
        assert!(matches!(result, Err(PipelineError::ConnectionError(_))));
    }

    #[tokio::test]
    async fn test_wait_timeout_expires() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let ticket = pipeline.submit(1, b"SELECT 1".to_vec()).unwrap();

        // The pipeline still holds the sender, so this must time out rather
        // than report a closed channel.
        let result = ticket
            .wait_timeout(std::time::Duration::from_millis(5))
            .await;
        assert!(matches!(result, Err(PipelineError::Timeout)));
    }
}