use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use tokio::sync::oneshot;
use serde::{Deserialize, Serialize};
/// Identifier of a connection that owns its own queue of in-flight requests.
pub type ConnectionId = u64;

/// Opaque identifier for a request submitted to a [`RequestPipeline`].
///
/// Compared, hashed and copied by value; the wrapped integer is kept private.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct RequestId(u64);

impl RequestId {
    /// Wraps a raw counter value as a request id.
    fn new(raw: u64) -> Self {
        RequestId(raw)
    }
}
/// Tuning knobs for a [`RequestPipeline`].
///
/// NOTE(review): within this file only `enabled` and `max_depth` are read (by
/// `RequestPipeline::submit`); `request_timeout_ms`, `auto_flush` and
/// `auto_flush_interval_ms` are not consumed here — confirm whether another module uses them.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PipelineConfig {
    /// Maximum number of in-flight requests per connection; further submissions are rejected.
    pub max_depth: usize,
    /// When `false`, every submission is rejected with `PipelineError::Disabled`.
    pub enabled: bool,
    /// Request timeout, in milliseconds.
    pub request_timeout_ms: u64,
    /// Whether automatic flushing is requested.
    pub auto_flush: bool,
    /// Interval between automatic flushes, in milliseconds.
    pub auto_flush_interval_ms: u64,
}
impl Default for PipelineConfig {
fn default() -> Self {
Self {
max_depth: 16,
enabled: true,
request_timeout_ms: 30_000,
auto_flush: true,
auto_flush_interval_ms: 10,
}
}
}
/// A request that has been submitted to a pipeline but not yet completed.
#[derive(Debug)]
pub struct PendingRequest {
    /// Id assigned when the request was submitted.
    pub id: RequestId,
    /// Request payload exactly as submitted.
    pub data: Vec<u8>,
    /// Enqueue time; the response time is measured from here.
    pub submitted_at: Instant,
    /// Sending half of the caller's ticket channel; taken (left `None`) on delivery.
    response_tx: Option<oneshot::Sender<PipelineResponse>>,
}
/// Outcome delivered to a [`Ticket`] when its request is completed.
#[derive(Debug)]
pub struct PipelineResponse {
    /// Id of the request this response answers.
    pub request_id: RequestId,
    /// Response payload supplied by the completer.
    pub data: Vec<u8>,
    /// Time elapsed between submission and completion.
    pub response_time: Duration,
    /// Success flag supplied by the completer.
    pub success: bool,
    /// Optional error description supplied by the completer.
    pub error: Option<String>,
}
/// Handle for awaiting the response to one submitted request.
///
/// If the request is dropped from the pipeline before it is completed (for example by
/// `RequestPipeline::clear` or `RequestPipeline::shutdown`, which drop the sending half),
/// waiting yields [`PipelineError::ChannelClosed`].
#[derive(Debug)]
pub struct Ticket {
    rx: oneshot::Receiver<PipelineResponse>,
}

impl Ticket {
    /// Waits, without a time limit, until the request is completed.
    ///
    /// # Errors
    ///
    /// [`PipelineError::ChannelClosed`] if the request was dropped without a response.
    pub async fn wait(self) -> Result<PipelineResponse, PipelineError> {
        self.rx.await.map_err(|_| PipelineError::ChannelClosed)
    }

    /// Like [`wait`](Self::wait), but gives up after `timeout`.
    ///
    /// # Errors
    ///
    /// - [`PipelineError::Timeout`] if no response arrives within `timeout`.
    /// - [`PipelineError::ChannelClosed`] if the request was dropped without a response.
    pub async fn wait_timeout(self, timeout: Duration) -> Result<PipelineResponse, PipelineError> {
        match tokio::time::timeout(timeout, self.rx).await {
            Ok(Ok(response)) => Ok(response),
            Ok(Err(_closed)) => Err(PipelineError::ChannelClosed),
            Err(_elapsed) => Err(PipelineError::Timeout),
        }
    }
}
/// Errors reported by [`RequestPipeline`] and [`Ticket`].
#[derive(Clone, Debug)]
pub enum PipelineError {
    /// The connection already has `max_depth` requests in flight.
    PipelineFull,
    /// Pipelining is turned off in the configuration.
    Disabled,
    /// No response arrived within the allotted time.
    Timeout,
    /// The response channel closed before a response was sent.
    ChannelClosed,
    /// Connection-level failure (also used once the pipeline is shut down), with a description.
    ConnectionError(String),
}

impl std::fmt::Display for PipelineError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            PipelineError::PipelineFull => "Pipeline is full",
            PipelineError::Disabled => "Pipeline is disabled",
            PipelineError::Timeout => "Request timeout",
            PipelineError::ChannelClosed => "Channel closed",
            PipelineError::ConnectionError(detail) => {
                return write!(f, "Connection error: {}", detail);
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for PipelineError {}
/// Snapshot of aggregate pipeline counters, as returned by `RequestPipeline::stats`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct PipelineStats {
    /// Requests accepted into a pipeline.
    pub requests_submitted: u64,
    /// Requests matched with a response.
    pub requests_completed: u64,
    /// Requests that timed out.
    /// NOTE(review): nothing in this file increments this counter — confirm intended use.
    pub requests_timeout: u64,
    /// Submissions refused because the connection's pipeline was full.
    pub requests_rejected: u64,
    /// Mean number of in-flight requests per tracked connection at snapshot time.
    pub avg_pipeline_depth: f64,
    /// Largest number of requests simultaneously in flight on a single connection.
    pub peak_pipeline_depth: usize,
    /// Exponential moving average of response time, in milliseconds.
    pub avg_response_time_ms: f64,
    /// Total request payload bytes accepted.
    pub bytes_sent: u64,
    /// Total response payload bytes delivered.
    pub bytes_received: u64,
}
/// Per-connection queue of in-flight requests.
struct ConnectionPipeline {
    /// Requests awaiting completion in submission order (oldest at the front).
    pending: VecDeque<PendingRequest>,
    /// Highest value `pending.len()` has reached on this connection.
    peak_depth: usize,
}

impl Default for ConnectionPipeline {
    /// An empty queue with room for at least 16 requests before it reallocates.
    fn default() -> Self {
        const INITIAL_CAPACITY: usize = 16;
        ConnectionPipeline {
            peak_depth: 0,
            pending: VecDeque::with_capacity(INITIAL_CAPACITY),
        }
    }
}
/// Tracks pipelined (in-flight) requests per connection and hands out [`Ticket`]s that
/// resolve once responses are matched back to their requests.
pub struct RequestPipeline {
    /// Settings supplied at construction.
    config: PipelineConfig,
    /// In-flight request queues, keyed by connection.
    connections: DashMap<ConnectionId, ConnectionPipeline>,
    /// Monotonic source of request ids, shared by all connections.
    next_request_id: AtomicU64,
    /// Aggregate counters; `RequestPipeline::stats` returns a snapshot of them.
    stats: Arc<parking_lot::RwLock<PipelineStats>>,
    /// Set by `RequestPipeline::shutdown`; later submissions are refused.
    shutdown: AtomicBool,
}
impl RequestPipeline {
pub fn new(config: PipelineConfig) -> Self {
Self {
config,
connections: DashMap::new(),
next_request_id: AtomicU64::new(1),
stats: Arc::new(parking_lot::RwLock::new(PipelineStats::default())),
shutdown: AtomicBool::new(false),
}
}
pub fn submit(&self, conn_id: ConnectionId, data: Vec<u8>) -> Result<Ticket, PipelineError> {
if !self.config.enabled {
return Err(PipelineError::Disabled);
}
if self.shutdown.load(Ordering::Relaxed) {
return Err(PipelineError::ConnectionError("Pipeline shutdown".to_string()));
}
let request_id = RequestId::new(self.next_request_id.fetch_add(1, Ordering::Relaxed));
let (tx, rx) = oneshot::channel();
let pending = PendingRequest {
id: request_id,
data,
submitted_at: Instant::now(),
response_tx: Some(tx),
};
let mut pipeline = self.connections.entry(conn_id).or_default();
if pipeline.pending.len() >= self.config.max_depth {
self.stats.write().requests_rejected += 1;
return Err(PipelineError::PipelineFull);
}
{
let mut stats = self.stats.write();
stats.requests_submitted += 1;
stats.bytes_sent += pending.data.len() as u64;
}
let current_depth = pipeline.pending.len() + 1;
if current_depth > pipeline.peak_depth {
pipeline.peak_depth = current_depth;
}
pipeline.pending.push_back(pending);
Ok(Ticket { rx })
}
pub fn complete(&self, conn_id: ConnectionId, request_id: RequestId, data: Vec<u8>, success: bool, error: Option<String>) {
if let Some(mut pipeline) = self.connections.get_mut(&conn_id) {
if let Some(pos) = pipeline.pending.iter().position(|r| r.id == request_id) {
if let Some(mut req) = pipeline.pending.remove(pos) {
let response_time = req.submitted_at.elapsed();
{
let mut stats = self.stats.write();
stats.requests_completed += 1;
stats.bytes_received += data.len() as u64;
let ms = response_time.as_millis() as f64;
if stats.avg_response_time_ms == 0.0 {
stats.avg_response_time_ms = ms;
} else {
stats.avg_response_time_ms = stats.avg_response_time_ms * 0.9 + ms * 0.1;
}
}
if let Some(tx) = req.response_tx.take() {
let _ = tx.send(PipelineResponse {
request_id,
data,
response_time,
success,
error,
});
}
}
}
}
}
pub fn complete_next(&self, conn_id: ConnectionId, data: Vec<u8>, success: bool, error: Option<String>) {
if let Some(mut pipeline) = self.connections.get_mut(&conn_id) {
if let Some(mut req) = pipeline.pending.pop_front() {
let response_time = req.submitted_at.elapsed();
{
let mut stats = self.stats.write();
stats.requests_completed += 1;
stats.bytes_received += data.len() as u64;
let ms = response_time.as_millis() as f64;
if stats.avg_response_time_ms == 0.0 {
stats.avg_response_time_ms = ms;
} else {
stats.avg_response_time_ms = stats.avg_response_time_ms * 0.9 + ms * 0.1;
}
}
if let Some(tx) = req.response_tx.take() {
let _ = tx.send(PipelineResponse {
request_id: req.id,
data,
response_time,
success,
error,
});
}
}
}
}
pub fn depth(&self, conn_id: ConnectionId) -> usize {
self.connections
.get(&conn_id)
.map(|p| p.pending.len())
.unwrap_or(0)
}
pub fn is_empty(&self, conn_id: ConnectionId) -> bool {
self.depth(conn_id) == 0
}
pub fn clear(&self, conn_id: ConnectionId) {
self.connections.remove(&conn_id);
}
pub fn stats(&self) -> PipelineStats {
let mut stats = self.stats.read().clone();
stats.peak_pipeline_depth = self.connections
.iter()
.map(|p| p.peak_depth)
.max()
.unwrap_or(0);
let total_depth: usize = self.connections.iter().map(|p| p.pending.len()).sum();
let conn_count = self.connections.len();
stats.avg_pipeline_depth = if conn_count > 0 {
total_depth as f64 / conn_count as f64
} else {
0.0
};
stats
}
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::Release);
self.connections.clear();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pipeline_submit() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let conn_id = 1;
        let ticket = pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
        assert_eq!(pipeline.depth(conn_id), 1);
        pipeline.complete_next(conn_id, b"1".to_vec(), true, None);
        assert_eq!(pipeline.depth(conn_id), 0);
        let response = ticket.wait().await.unwrap();
        assert!(response.success);
    }

    #[tokio::test]
    async fn test_pipeline_full() {
        let config = PipelineConfig {
            max_depth: 2,
            ..Default::default()
        };
        let pipeline = RequestPipeline::new(config);
        let conn_id = 1;
        pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
        pipeline.submit(conn_id, b"SELECT 2".to_vec()).unwrap();
        let result = pipeline.submit(conn_id, b"SELECT 3".to_vec());
        assert!(matches!(result, Err(PipelineError::PipelineFull)));
        assert_eq!(pipeline.depth(conn_id), 2);
        assert_eq!(pipeline.stats().requests_rejected, 1);
    }

    #[test]
    fn test_pipeline_stats() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let conn_id = 1;
        pipeline.submit(conn_id, b"SELECT 1".to_vec()).unwrap();
        pipeline.submit(conn_id, b"SELECT 2".to_vec()).unwrap();
        let stats = pipeline.stats();
        assert_eq!(stats.requests_submitted, 2);
    }

    #[test]
    fn test_pipeline_disabled() {
        let config = PipelineConfig {
            enabled: false,
            ..Default::default()
        };
        let pipeline = RequestPipeline::new(config);
        let result = pipeline.submit(1, b"SELECT 1".to_vec());
        assert!(matches!(result, Err(PipelineError::Disabled)));
    }

    #[tokio::test]
    async fn test_complete_by_id_out_of_order() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let conn_id = 7;
        let first = pipeline.submit(conn_id, b"A".to_vec()).unwrap();
        let second = pipeline.submit(conn_id, b"B".to_vec()).unwrap();
        // Ids come from a shared counter starting at 1, so `second` has id 2.
        pipeline.complete(conn_id, RequestId::new(2), b"b".to_vec(), true, None);
        assert_eq!(pipeline.depth(conn_id), 1);
        let response = second.wait().await.unwrap();
        assert_eq!(response.request_id, RequestId::new(2));
        assert_eq!(response.data, b"b".to_vec());
        assert!(response.success);

        pipeline.complete_next(conn_id, b"a".to_vec(), false, Some("boom".to_string()));
        let response = first.wait().await.unwrap();
        assert_eq!(response.request_id, RequestId::new(1));
        assert!(!response.success);
        assert_eq!(response.error.as_deref(), Some("boom"));
        assert!(pipeline.is_empty(conn_id));
    }

    #[test]
    fn test_complete_unknown_is_ignored() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let _ticket = pipeline.submit(1, b"SELECT 1".to_vec()).unwrap();
        pipeline.complete(1, RequestId::new(99), Vec::new(), true, None);
        pipeline.complete(2, RequestId::new(1), Vec::new(), true, None);
        pipeline.complete_next(3, Vec::new(), true, None);
        assert_eq!(pipeline.depth(1), 1);
        assert_eq!(pipeline.stats().requests_completed, 0);
    }

    #[test]
    fn test_stats_track_bytes_and_depth() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let _t1 = pipeline.submit(1, b"abc".to_vec()).unwrap();
        let _t2 = pipeline.submit(1, b"de".to_vec()).unwrap();
        let _t3 = pipeline.submit(2, b"f".to_vec()).unwrap();
        pipeline.complete_next(1, b"xyz1".to_vec(), true, None);
        let stats = pipeline.stats();
        assert_eq!(stats.requests_submitted, 3);
        assert_eq!(stats.requests_completed, 1);
        assert_eq!(stats.bytes_sent, 6);
        assert_eq!(stats.bytes_received, 4);
        assert_eq!(stats.peak_pipeline_depth, 2);
        // One request remains on each of the two connections.
        assert_eq!(stats.avg_pipeline_depth, 1.0);
    }

    #[tokio::test]
    async fn test_clear_drops_pending_requests() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let ticket = pipeline.submit(1, b"SELECT 1".to_vec()).unwrap();
        pipeline.clear(1);
        assert!(pipeline.is_empty(1));
        assert!(matches!(ticket.wait().await, Err(PipelineError::ChannelClosed)));
    }

    #[tokio::test]
    async fn test_shutdown_rejects_and_closes() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let ticket = pipeline.submit(1, b"SELECT 1".to_vec()).unwrap();
        pipeline.shutdown();
        assert!(matches!(ticket.wait().await, Err(PipelineError::ChannelClosed)));
        assert!(matches!(
            pipeline.submit(1, b"SELECT 2".to_vec()),
            Err(PipelineError::ConnectionError(_))
        ));
        assert_eq!(pipeline.depth(1), 0);
    }

    #[tokio::test]
    async fn test_wait_timeout_without_response() {
        let pipeline = RequestPipeline::new(PipelineConfig::default());
        let ticket = pipeline.submit(1, b"SELECT 1".to_vec()).unwrap();
        let result = ticket
            .wait_timeout(std::time::Duration::from_millis(10))
            .await;
        assert!(matches!(result, Err(PipelineError::Timeout)));
        assert_eq!(pipeline.depth(1), 1);
    }
}