1use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::collections::VecDeque;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::oneshot;
13
/// Caller-chosen key identifying a connection; every distinct id gets its own
/// independent request queue inside [`RequestPipeline`].
pub type ConnectionId = u64;
16
/// Identifier of a single pipelined request.
///
/// Values are allocated by `RequestPipeline::submit` from a per-pipeline counter;
/// they can only be created inside this module, but their numeric value can be
/// read with [`RequestId::get`] (e.g. for logging or correlating responses).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestId(u64);

impl RequestId {
    /// Wraps a raw counter value. Private so ids can only originate from the pipeline.
    fn new(id: u64) -> Self {
        Self(id)
    }

    /// Returns the raw numeric value of this id.
    #[must_use]
    pub const fn get(self) -> u64 {
        self.0
    }
}
26
/// Tuning knobs for a [`RequestPipeline`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineConfig {
    /// Maximum number of pending requests per connection; once a connection's
    /// queue holds this many, `submit` returns `PipelineError::PipelineFull`.
    pub max_depth: usize,
    /// When `false`, `submit` rejects every request with `PipelineError::Disabled`.
    pub enabled: bool,
    /// Presumably a per-request timeout in milliseconds.
    /// NOTE(review): not read by any code visible here — confirm who enforces it.
    pub request_timeout_ms: u64,
    /// Presumably enables automatic flushing of queued requests.
    /// NOTE(review): not read by any code visible here — likely consumed by the I/O layer.
    pub auto_flush: bool,
    /// Presumably the auto-flush period in milliseconds (see `auto_flush`).
    /// NOTE(review): not read by any code visible here.
    pub auto_flush_interval_ms: u64,
}
41
42impl Default for PipelineConfig {
43 fn default() -> Self {
44 Self {
45 max_depth: 16,
46 enabled: true,
47 request_timeout_ms: 30_000,
48 auto_flush: true,
49 auto_flush_interval_ms: 10,
50 }
51 }
52}
53
/// A request that has been submitted on a connection and is awaiting its response.
#[derive(Debug)]
pub struct PendingRequest {
    /// Identifier assigned when the request was submitted.
    pub id: RequestId,
    /// Request payload as passed to `submit`.
    pub data: Vec<u8>,
    /// Enqueue time; `elapsed()` of this becomes the response's `response_time`.
    pub submitted_at: Instant,
    /// Sender half connected to the caller's `Ticket`; taken (left `None`) when the
    /// response is delivered. Dropping it unsent makes the ticket report `ChannelClosed`.
    response_tx: Option<oneshot::Sender<PipelineResponse>>,
}
66
/// Response delivered to a [`Ticket`] when its request is completed.
#[derive(Debug)]
pub struct PipelineResponse {
    /// Id of the request this response answers.
    pub request_id: RequestId,
    /// Response payload supplied to `complete` / `complete_next`.
    pub data: Vec<u8>,
    /// Time between submission and completion of the request.
    pub response_time: Duration,
    /// Success flag as supplied by the completer (passed through unchanged).
    pub success: bool,
    /// Optional error text as supplied by the completer; not cross-checked
    /// against `success`.
    pub error: Option<String>,
}
81
82pub struct Ticket {
84 rx: oneshot::Receiver<PipelineResponse>,
85}
86
87impl Ticket {
88 pub async fn wait(self) -> Result<PipelineResponse, PipelineError> {
90 self.rx.await.map_err(|_| PipelineError::ChannelClosed)
91 }
92
93 pub async fn wait_timeout(self, timeout: Duration) -> Result<PipelineResponse, PipelineError> {
95 tokio::time::timeout(timeout, self.rx)
96 .await
97 .map_err(|_| PipelineError::Timeout)?
98 .map_err(|_| PipelineError::ChannelClosed)
99 }
100}
101
102#[derive(Debug, Clone)]
104pub enum PipelineError {
105 PipelineFull,
107 Disabled,
109 Timeout,
111 ChannelClosed,
113 ConnectionError(String),
115}
116
117impl std::fmt::Display for PipelineError {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 match self {
120 Self::PipelineFull => write!(f, "Pipeline is full"),
121 Self::Disabled => write!(f, "Pipeline is disabled"),
122 Self::Timeout => write!(f, "Request timeout"),
123 Self::ChannelClosed => write!(f, "Channel closed"),
124 Self::ConnectionError(e) => write!(f, "Connection error: {}", e),
125 }
126 }
127}
128
129impl std::error::Error for PipelineError {}
130
131#[derive(Debug, Clone, Default, Serialize, Deserialize)]
133pub struct PipelineStats {
134 pub requests_submitted: u64,
136 pub requests_completed: u64,
138 pub requests_timeout: u64,
140 pub requests_rejected: u64,
142 pub avg_pipeline_depth: f64,
144 pub peak_pipeline_depth: usize,
146 pub avg_response_time_ms: f64,
148 pub bytes_sent: u64,
150 pub bytes_received: u64,
152}
153
154struct ConnectionPipeline {
156 pending: VecDeque<PendingRequest>,
158 peak_depth: usize,
160}
161
162impl Default for ConnectionPipeline {
163 fn default() -> Self {
164 Self {
165 pending: VecDeque::with_capacity(16),
166 peak_depth: 0,
167 }
168 }
169}
170
171pub struct RequestPipeline {
175 config: PipelineConfig,
177 connections: DashMap<ConnectionId, ConnectionPipeline>,
179 next_request_id: AtomicU64,
181 stats: Arc<parking_lot::RwLock<PipelineStats>>,
183 shutdown: AtomicBool,
185}
186
187impl RequestPipeline {
188 pub fn new(config: PipelineConfig) -> Self {
190 Self {
191 config,
192 connections: DashMap::new(),
193 next_request_id: AtomicU64::new(1),
194 stats: Arc::new(parking_lot::RwLock::new(PipelineStats::default())),
195 shutdown: AtomicBool::new(false),
196 }
197 }
198
199 pub fn submit(&self, conn_id: ConnectionId, data: Vec<u8>) -> Result<Ticket, PipelineError> {
201 if !self.config.enabled {
202 return Err(PipelineError::Disabled);
203 }
204
205 if self.shutdown.load(Ordering::Relaxed) {
206 return Err(PipelineError::ConnectionError(
207 "Pipeline shutdown".to_string(),
208 ));
209 }
210
211 let request_id = RequestId::new(self.next_request_id.fetch_add(1, Ordering::Relaxed));
212 let (tx, rx) = oneshot::channel();
213
214 let pending = PendingRequest {
215 id: request_id,
216 data,
217 submitted_at: Instant::now(),
218 response_tx: Some(tx),
219 };
220
221 let mut pipeline = self.connections.entry(conn_id).or_default();
223
224 if pipeline.pending.len() >= self.config.max_depth {
226 self.stats.write().requests_rejected += 1;
227 return Err(PipelineError::PipelineFull);
228 }
229
230 {
232 let mut stats = self.stats.write();
233 stats.requests_submitted += 1;
234 stats.bytes_sent += pending.data.len() as u64;
235 }
236
237 let current_depth = pipeline.pending.len() + 1;
239 if current_depth > pipeline.peak_depth {
240 pipeline.peak_depth = current_depth;
241 }
242
243 pipeline.pending.push_back(pending);
244
245 Ok(Ticket { rx })
246 }
247
248 pub fn complete(
250 &self,
251 conn_id: ConnectionId,
252 request_id: RequestId,
253 data: Vec<u8>,
254 success: bool,
255 error: Option<String>,
256 ) {
257 if let Some(mut pipeline) = self.connections.get_mut(&conn_id) {
258 if let Some(pos) = pipeline.pending.iter().position(|r| r.id == request_id) {
260 if let Some(mut req) = pipeline.pending.remove(pos) {
261 let response_time = req.submitted_at.elapsed();
262
263 {
265 let mut stats = self.stats.write();
266 stats.requests_completed += 1;
267 stats.bytes_received += data.len() as u64;
268
269 let ms = response_time.as_millis() as f64;
271 if stats.avg_response_time_ms == 0.0 {
272 stats.avg_response_time_ms = ms;
273 } else {
274 stats.avg_response_time_ms =
275 stats.avg_response_time_ms * 0.9 + ms * 0.1;
276 }
277 }
278
279 if let Some(tx) = req.response_tx.take() {
281 let _ = tx.send(PipelineResponse {
282 request_id,
283 data,
284 response_time,
285 success,
286 error,
287 });
288 }
289 }
290 }
291 }
292 }
293
294 pub fn complete_next(
296 &self,
297 conn_id: ConnectionId,
298 data: Vec<u8>,
299 success: bool,
300 error: Option<String>,
301 ) {
302 if let Some(mut pipeline) = self.connections.get_mut(&conn_id) {
303 if let Some(mut req) = pipeline.pending.pop_front() {
304 let response_time = req.submitted_at.elapsed();
305
306 {
308 let mut stats = self.stats.write();
309 stats.requests_completed += 1;
310 stats.bytes_received += data.len() as u64;
311
312 let ms = response_time.as_millis() as f64;
313 if stats.avg_response_time_ms == 0.0 {
314 stats.avg_response_time_ms = ms;
315 } else {
316 stats.avg_response_time_ms = stats.avg_response_time_ms * 0.9 + ms * 0.1;
317 }
318 }
319
320 if let Some(tx) = req.response_tx.take() {
321 let _ = tx.send(PipelineResponse {
322 request_id: req.id,
323 data,
324 response_time,
325 success,
326 error,
327 });
328 }
329 }
330 }
331 }
332
333 pub fn depth(&self, conn_id: ConnectionId) -> usize {
335 self.connections
336 .get(&conn_id)
337 .map(|p| p.pending.len())
338 .unwrap_or(0)
339 }
340
341 pub fn is_empty(&self, conn_id: ConnectionId) -> bool {
343 self.depth(conn_id) == 0
344 }
345
346 pub fn clear(&self, conn_id: ConnectionId) {
348 self.connections.remove(&conn_id);
349 }
350
351 pub fn stats(&self) -> PipelineStats {
353 let mut stats = self.stats.read().clone();
354
355 stats.peak_pipeline_depth = self
357 .connections
358 .iter()
359 .map(|p| p.peak_depth)
360 .max()
361 .unwrap_or(0);
362
363 let total_depth: usize = self.connections.iter().map(|p| p.pending.len()).sum();
365 let conn_count = self.connections.len();
366 stats.avg_pipeline_depth = if conn_count > 0 {
367 total_depth as f64 / conn_count as f64
368 } else {
369 0.0
370 };
371
372 stats
373 }
374
375 pub fn shutdown(&self) {
377 self.shutdown.store(true, Ordering::Release);
378 self.connections.clear();
379 }
380}
381
382#[cfg(test)]
383mod tests {
384 use super::*;
385
386 #[tokio::test]
387 async fn test_pipeline_submit() {
388 let pipeline = RequestPipeline::new(PipelineConfig::default());
389 let conn_id = 1;
390
391 let ticket = pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
392 assert_eq!(pipeline.depth(conn_id), 1);
393
394 pipeline.complete_next(conn_id, b"1".to_vec(), true, None);
396 assert_eq!(pipeline.depth(conn_id), 0);
397
398 let response = ticket.wait().await.unwrap();
400 assert!(response.success);
401 }
402
403 #[tokio::test]
404 async fn test_pipeline_full() {
405 let config = PipelineConfig {
406 max_depth: 2,
407 ..Default::default()
408 };
409 let pipeline = RequestPipeline::new(config);
410 let conn_id = 1;
411
412 pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
414 pipeline.submit(conn_id, b"SELECT 2".to_vec()).unwrap();
415
416 let result = pipeline.submit(conn_id, b"SELECT 3".to_vec());
418 assert!(matches!(result, Err(PipelineError::PipelineFull)));
419 }
420
421 #[test]
422 fn test_pipeline_stats() {
423 let pipeline = RequestPipeline::new(PipelineConfig::default());
424 let conn_id = 1;
425
426 pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
427 pipeline.submit(conn_id, b"SELECT 2".to_vec()).unwrap();
428
429 let stats = pipeline.stats();
430 assert_eq!(stats.requests_submitted, 2);
431 }
432
433 #[test]
434 fn test_pipeline_disabled() {
435 let config = PipelineConfig {
436 enabled: false,
437 ..Default::default()
438 };
439 let pipeline = RequestPipeline::new(config);
440
441 let result = pipeline.submit(1, b"SELECT 1".to_vec());
442 assert!(matches!(result, Err(PipelineError::Disabled)));
443 }
444}