1use std::collections::VecDeque;
7use std::sync::atomic::{AtomicU64, AtomicBool, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use dashmap::DashMap;
11use tokio::sync::oneshot;
12use serde::{Deserialize, Serialize};
13
/// Caller-chosen key identifying the connection a request is queued on.
///
/// `RequestPipeline` keeps one independent queue per distinct value.
pub type ConnectionId = u64;
16
/// Opaque identifier the pipeline assigns to each submitted request.
///
/// Values are minted internally from a counter, so the constructor is private to this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestId(u64);

impl RequestId {
    /// Wraps a raw counter value as a request id.
    fn new(raw: u64) -> Self {
        RequestId(raw)
    }
}
26
/// Tuning knobs for [`RequestPipeline`].
///
/// See the `Default` impl for the stock values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineConfig {
    /// Maximum number of outstanding requests per connection; `submit` fails with
    /// `PipelineError::PipelineFull` once a connection already has this many queued.
    pub max_depth: usize,
    /// When `false`, every `submit` fails with `PipelineError::Disabled`.
    pub enabled: bool,
    /// Intended per-request timeout in milliseconds.
    /// NOTE(review): not read by `RequestPipeline` in this file — confirm where (or whether)
    /// it is enforced.
    pub request_timeout_ms: u64,
    /// Presumably toggles automatic flushing of queued requests.
    /// NOTE(review): not read anywhere in this file — confirm its consumer.
    pub auto_flush: bool,
    /// Presumably the automatic flush interval in milliseconds (by its name).
    /// NOTE(review): not read anywhere in this file — confirm its consumer.
    pub auto_flush_interval_ms: u64,
}
41
42impl Default for PipelineConfig {
43 fn default() -> Self {
44 Self {
45 max_depth: 16,
46 enabled: true,
47 request_timeout_ms: 30_000,
48 auto_flush: true,
49 auto_flush_interval_ms: 10,
50 }
51 }
52}
53
/// A submitted request that has not yet been completed.
#[derive(Debug)]
pub struct PendingRequest {
    /// Id assigned by the pipeline at submission time.
    pub id: RequestId,
    /// Request payload exactly as passed to `submit`.
    pub data: Vec<u8>,
    /// Submission timestamp; completion measures `response_time` from this instant.
    pub submitted_at: Instant,
    /// Sending half of the ticket's one-shot channel. It is `take`n when the response is
    /// delivered; if the request is dropped instead, the ticket sees `ChannelClosed`.
    response_tx: Option<oneshot::Sender<PipelineResponse>>,
}
66
/// Result delivered to a [`Ticket`] when its request is completed.
#[derive(Debug)]
pub struct PipelineResponse {
    /// Id of the request this response belongs to.
    pub request_id: RequestId,
    /// Response payload as passed to `complete` / `complete_next`.
    pub data: Vec<u8>,
    /// Time elapsed between submission and completion.
    pub response_time: Duration,
    /// Success flag, passed through unchanged from the completing caller.
    pub success: bool,
    /// Optional error detail, passed through unchanged from the completing caller.
    pub error: Option<String>,
}
81
82pub struct Ticket {
84 rx: oneshot::Receiver<PipelineResponse>,
85}
86
87impl Ticket {
88 pub async fn wait(self) -> Result<PipelineResponse, PipelineError> {
90 self.rx.await.map_err(|_| PipelineError::ChannelClosed)
91 }
92
93 pub async fn wait_timeout(self, timeout: Duration) -> Result<PipelineResponse, PipelineError> {
95 tokio::time::timeout(timeout, self.rx)
96 .await
97 .map_err(|_| PipelineError::Timeout)?
98 .map_err(|_| PipelineError::ChannelClosed)
99 }
100}
101
/// Errors produced by [`RequestPipeline`] and [`Ticket`].
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// (`assert_eq!(err, PipelineError::Disabled)`) instead of going through `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PipelineError {
    /// The connection already has `max_depth` outstanding requests.
    PipelineFull,
    /// Pipelining is turned off in the configuration.
    Disabled,
    /// No response arrived within the caller's deadline.
    Timeout,
    /// The request was dropped without a response being sent.
    ChannelClosed,
    /// Connection-level failure, with a human-readable description.
    ConnectionError(String),
}
116
117impl std::fmt::Display for PipelineError {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 match self {
120 Self::PipelineFull => write!(f, "Pipeline is full"),
121 Self::Disabled => write!(f, "Pipeline is disabled"),
122 Self::Timeout => write!(f, "Request timeout"),
123 Self::ChannelClosed => write!(f, "Channel closed"),
124 Self::ConnectionError(e) => write!(f, "Connection error: {}", e),
125 }
126 }
127}
128
// Makes `PipelineError` usable as `dyn std::error::Error` (e.g. with `?` into boxed errors).
// The default `source()` (none) is kept: no variant wraps an underlying error value.
impl std::error::Error for PipelineError {}
130
/// Snapshot of pipeline counters, as returned by `RequestPipeline::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PipelineStats {
    /// Number of requests accepted by `submit`.
    pub requests_submitted: u64,
    /// Number of requests completed via `complete` / `complete_next`.
    pub requests_completed: u64,
    /// Number of requests that timed out.
    /// NOTE(review): never incremented in this file — confirm whether anything updates it.
    pub requests_timeout: u64,
    /// Number of submissions refused because the connection was at `max_depth`
    /// (disabled/shutdown refusals are not counted here).
    pub requests_rejected: u64,
    /// Mean number of outstanding requests per tracked connection, computed when the
    /// snapshot is taken.
    pub avg_pipeline_depth: f64,
    /// Largest per-connection queue depth observed; filled in when the snapshot is taken.
    pub peak_pipeline_depth: usize,
    /// Exponential moving average (weight 0.1 on the newest sample) of response times,
    /// in milliseconds.
    pub avg_response_time_ms: f64,
    /// Total payload bytes of accepted submissions.
    pub bytes_sent: u64,
    /// Total payload bytes of delivered responses.
    pub bytes_received: u64,
}
153
/// Per-connection queue of requests awaiting responses.
struct ConnectionPipeline {
    /// Outstanding requests in submission order: `submit` pushes to the back and
    /// `complete_next` pops from the front.
    pending: VecDeque<PendingRequest>,
    /// Largest `pending.len()` reached on this connection.
    peak_depth: usize,
}
161
162impl Default for ConnectionPipeline {
163 fn default() -> Self {
164 Self {
165 pending: VecDeque::with_capacity(16),
166 peak_depth: 0,
167 }
168 }
169}
170
/// Tracks in-flight requests per connection and hands out [`Ticket`]s that resolve when the
/// matching response is reported via `complete` or `complete_next`.
pub struct RequestPipeline {
    /// Configuration fixed at construction time.
    config: PipelineConfig,
    /// Per-connection queues, created lazily on the first `submit` for a connection.
    connections: DashMap<ConnectionId, ConnectionPipeline>,
    /// Counter used to mint request ids (initialised to 1 in `new`).
    next_request_id: AtomicU64,
    /// Aggregate counters behind a lock; `stats()` returns a copy with depth figures filled in.
    stats: Arc<parking_lot::RwLock<PipelineStats>>,
    /// Set by `shutdown`; once set, `submit` refuses new requests.
    shutdown: AtomicBool,
}
186
187impl RequestPipeline {
188 pub fn new(config: PipelineConfig) -> Self {
190 Self {
191 config,
192 connections: DashMap::new(),
193 next_request_id: AtomicU64::new(1),
194 stats: Arc::new(parking_lot::RwLock::new(PipelineStats::default())),
195 shutdown: AtomicBool::new(false),
196 }
197 }
198
199 pub fn submit(&self, conn_id: ConnectionId, data: Vec<u8>) -> Result<Ticket, PipelineError> {
201 if !self.config.enabled {
202 return Err(PipelineError::Disabled);
203 }
204
205 if self.shutdown.load(Ordering::Relaxed) {
206 return Err(PipelineError::ConnectionError("Pipeline shutdown".to_string()));
207 }
208
209 let request_id = RequestId::new(self.next_request_id.fetch_add(1, Ordering::Relaxed));
210 let (tx, rx) = oneshot::channel();
211
212 let pending = PendingRequest {
213 id: request_id,
214 data,
215 submitted_at: Instant::now(),
216 response_tx: Some(tx),
217 };
218
219 let mut pipeline = self.connections.entry(conn_id).or_default();
221
222 if pipeline.pending.len() >= self.config.max_depth {
224 self.stats.write().requests_rejected += 1;
225 return Err(PipelineError::PipelineFull);
226 }
227
228 {
230 let mut stats = self.stats.write();
231 stats.requests_submitted += 1;
232 stats.bytes_sent += pending.data.len() as u64;
233 }
234
235 let current_depth = pipeline.pending.len() + 1;
237 if current_depth > pipeline.peak_depth {
238 pipeline.peak_depth = current_depth;
239 }
240
241 pipeline.pending.push_back(pending);
242
243 Ok(Ticket { rx })
244 }
245
246 pub fn complete(&self, conn_id: ConnectionId, request_id: RequestId, data: Vec<u8>, success: bool, error: Option<String>) {
248 if let Some(mut pipeline) = self.connections.get_mut(&conn_id) {
249 if let Some(pos) = pipeline.pending.iter().position(|r| r.id == request_id) {
251 if let Some(mut req) = pipeline.pending.remove(pos) {
252 let response_time = req.submitted_at.elapsed();
253
254 {
256 let mut stats = self.stats.write();
257 stats.requests_completed += 1;
258 stats.bytes_received += data.len() as u64;
259
260 let ms = response_time.as_millis() as f64;
262 if stats.avg_response_time_ms == 0.0 {
263 stats.avg_response_time_ms = ms;
264 } else {
265 stats.avg_response_time_ms = stats.avg_response_time_ms * 0.9 + ms * 0.1;
266 }
267 }
268
269 if let Some(tx) = req.response_tx.take() {
271 let _ = tx.send(PipelineResponse {
272 request_id,
273 data,
274 response_time,
275 success,
276 error,
277 });
278 }
279 }
280 }
281 }
282 }
283
284 pub fn complete_next(&self, conn_id: ConnectionId, data: Vec<u8>, success: bool, error: Option<String>) {
286 if let Some(mut pipeline) = self.connections.get_mut(&conn_id) {
287 if let Some(mut req) = pipeline.pending.pop_front() {
288 let response_time = req.submitted_at.elapsed();
289
290 {
292 let mut stats = self.stats.write();
293 stats.requests_completed += 1;
294 stats.bytes_received += data.len() as u64;
295
296 let ms = response_time.as_millis() as f64;
297 if stats.avg_response_time_ms == 0.0 {
298 stats.avg_response_time_ms = ms;
299 } else {
300 stats.avg_response_time_ms = stats.avg_response_time_ms * 0.9 + ms * 0.1;
301 }
302 }
303
304 if let Some(tx) = req.response_tx.take() {
305 let _ = tx.send(PipelineResponse {
306 request_id: req.id,
307 data,
308 response_time,
309 success,
310 error,
311 });
312 }
313 }
314 }
315 }
316
317 pub fn depth(&self, conn_id: ConnectionId) -> usize {
319 self.connections
320 .get(&conn_id)
321 .map(|p| p.pending.len())
322 .unwrap_or(0)
323 }
324
325 pub fn is_empty(&self, conn_id: ConnectionId) -> bool {
327 self.depth(conn_id) == 0
328 }
329
330 pub fn clear(&self, conn_id: ConnectionId) {
332 self.connections.remove(&conn_id);
333 }
334
335 pub fn stats(&self) -> PipelineStats {
337 let mut stats = self.stats.read().clone();
338
339 stats.peak_pipeline_depth = self.connections
341 .iter()
342 .map(|p| p.peak_depth)
343 .max()
344 .unwrap_or(0);
345
346 let total_depth: usize = self.connections.iter().map(|p| p.pending.len()).sum();
348 let conn_count = self.connections.len();
349 stats.avg_pipeline_depth = if conn_count > 0 {
350 total_depth as f64 / conn_count as f64
351 } else {
352 0.0
353 };
354
355 stats
356 }
357
358 pub fn shutdown(&self) {
360 self.shutdown.store(true, Ordering::Release);
361 self.connections.clear();
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection id shared by every test; each test builds its own pipeline.
    const CONN: ConnectionId = 1;

    #[tokio::test]
    async fn test_pipeline_submit() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());

        let ticket = pipeline
            .submit(CONN, b"SELECT 1".to_vec())
            .expect("first submit must be accepted");
        assert_eq!(pipeline.depth(CONN), 1);

        // Answering the head of the queue drains it and wakes the ticket.
        pipeline.complete_next(CONN, b"1".to_vec(), true, None);
        assert_eq!(pipeline.depth(CONN), 0);

        let response = ticket.wait().await.expect("response must be delivered");
        assert!(response.success);
    }

    #[tokio::test]
    async fn test_pipeline_full() {
        let pipeline = RequestPipeline::new(PipelineConfig {
            max_depth: 2,
            ..Default::default()
        });

        // Fill the queue exactly to max_depth...
        for query in ["SELECT 1", "SELECT 2"].iter() {
            pipeline
                .submit(CONN, query.as_bytes().to_vec())
                .expect("submit within max_depth must be accepted");
        }

        // ...so the next submission is refused.
        let overflow = pipeline.submit(CONN, b"SELECT 3".to_vec());
        assert!(matches!(overflow, Err(PipelineError::PipelineFull)));
    }

    #[test]
    fn test_pipeline_stats() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());

        for query in ["SELECT 1", "SELECT 2"].iter() {
            pipeline
                .submit(CONN, query.as_bytes().to_vec())
                .expect("submit must be accepted");
        }

        assert_eq!(pipeline.stats().requests_submitted, 2);
    }

    #[test]
    fn test_pipeline_disabled() {
        let pipeline = RequestPipeline::new(PipelineConfig {
            enabled: false,
            ..Default::default()
        });

        let outcome = pipeline.submit(CONN, b"SELECT 1".to_vec());
        assert!(matches!(outcome, Err(PipelineError::Disabled)));
    }
}