use std::collections::VecDeque;
use std::sync::Arc;
use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, Mutex};
use crate::connection::WireConn;
use crate::error::PgWireError;
use crate::protocol::backend;
use crate::protocol::frontend;
use crate::protocol::types::{BackendMsg, FormatCode, FrontendMsg, RawRow};
pub(crate) struct PipelineRequest {
pub(crate) messages: BytesMut,
pub(crate) collector: ResponseCollector,
pub(crate) response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
#[allow(dead_code)]
#[non_exhaustive]
pub enum ResponseCollector {
Rows,
Drain,
Stream {
header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
},
CopyIn {
data: Vec<u8>,
},
CopyOut,
}
#[non_exhaustive]
pub enum PipelineResponse {
Rows {
fields: Vec<crate::protocol::types::FieldDescription>,
rows: Vec<RawRow>,
command_tag: String,
},
Done,
}
#[derive(Debug, Clone)]
pub struct StreamHeader {
pub fields: Vec<crate::protocol::types::FieldDescription>,
}
pub type StreamedRow = RawRow;
pub struct AsyncConn {
request_tx: mpsc::Sender<PipelineRequest>,
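// Maps SQL text to (server-side statement name, allocation counter); the
// counter orders entries for eviction in `cache_statement`.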
stmt_cache: std::sync::Mutex<std::collections::HashMap<String, (String, u64)>>,
stmt_counter: std::sync::atomic::AtomicU64,
alive: Arc<std::sync::atomic::AtomicBool>,
backend_pid: i32,
backend_secret: i32,
addr: String,
#[allow(dead_code)]
notification_tx: mpsc::Sender<crate::protocol::types::BackendMsg>,
notification_rx: std::sync::Mutex<Option<mpsc::Receiver<crate::protocol::types::BackendMsg>>>,
state_mutated: Arc<std::sync::atomic::AtomicBool>,
broken: Arc<std::sync::atomic::AtomicBool>,
dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
}
impl std::fmt::Debug for AsyncConn {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncConn")
.field("addr", &self.addr)
.field("backend_pid", &self.backend_pid)
.field("alive", &self.is_alive())
.finish()
}
}
impl AsyncConn {
pub fn is_alive(&self) -> bool {
self.alive.load(std::sync::atomic::Ordering::Relaxed)
}
pub fn backend_pid(&self) -> i32 {
self.backend_pid
}
pub fn addr(&self) -> &str {
&self.addr
}
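/// Build a token for out-of-band query cancellation: it carries this
/// backend's address, process ID, and secret key, which a separate
/// connection can use to issue a CancelRequest.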
pub fn cancel_token(&self) -> crate::cancel::CancelToken {
crate::cancel::CancelToken::new(self.addr.clone(), self.backend_pid, self.backend_secret)
}
pub fn mark_state_mutated(&self) {
self.state_mutated
.store(true, std::sync::atomic::Ordering::Release);
}
pub fn take_state_mutated(&self) -> bool {
self.state_mutated
.swap(false, std::sync::atomic::Ordering::AcqRel)
}
pub fn is_state_mutated(&self) -> bool {
self.state_mutated
.load(std::sync::atomic::Ordering::Acquire)
}
pub fn mark_broken(&self) {
self.broken
.store(true, std::sync::atomic::Ordering::Release);
}
pub fn is_broken(&self) -> bool {
self.broken.load(std::sync::atomic::Ordering::Acquire)
}
#[doc(hidden)]
pub fn __force_mark_dead_for_test(&self) {
self.alive
.store(false, std::sync::atomic::Ordering::Release);
}
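/// Fire-and-forget: queue a `ROLLBACK` without awaiting its result. Returns
/// `false` if the connection is dead or the request queue is full or closed.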
pub fn enqueue_rollback(&self) -> bool {
if !self.is_alive() {
return false;
}
try_enqueue_rollback(&self.request_tx)
}
}
fn try_enqueue_rollback(request_tx: &mpsc::Sender<PipelineRequest>) -> bool {
let mut buf = BytesMut::with_capacity(16);
frontend::encode_message(&FrontendMsg::Query(b"ROLLBACK"), &mut buf);
let (tx, _rx) = oneshot::channel();
request_tx
.try_send(PipelineRequest {
messages: buf,
collector: ResponseCollector::Drain,
response_tx: tx,
})
.is_ok()
}
struct PendingResponse {
collector: ResponseCollector,
response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
impl AsyncConn {
pub fn new(conn: WireConn) -> Self {
let backend_pid = conn.pid;
let backend_secret = conn.secret;
let addr = conn
.stream
.peer_addr()
.map(|a| a.to_string())
.unwrap_or_default();
let (notification_tx, notification_rx) = mpsc::channel(4096);
let (request_tx, request_rx) = mpsc::channel::<PipelineRequest>(256);
let pending: Arc<Mutex<VecDeque<PendingResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
let pending_notify = Arc::new(tokio::sync::Notify::new());
let alive = Arc::new(std::sync::atomic::AtomicBool::new(true));
let state_mutated = Arc::new(std::sync::atomic::AtomicBool::new(false));
let broken = Arc::new(std::sync::atomic::AtomicBool::new(false));
let dropped_notifications = Arc::new(std::sync::atomic::AtomicU64::new(0));
let (stream_read, stream_write) = tokio::io::split(conn.into_stream());
{
let pending = Arc::clone(&pending);
let pending_notify = Arc::clone(&pending_notify);
let alive = Arc::clone(&alive);
tokio::spawn(async move {
writer_task(request_rx, stream_write, pending, pending_notify).await;
alive.store(false, std::sync::atomic::Ordering::Relaxed);
tracing::warn!("pg-wired writer task exited");
});
}
{
let pending = Arc::clone(&pending);
let pending_notify = Arc::clone(&pending_notify);
let alive_clone = Arc::clone(&alive);
let state_mutated = Arc::clone(&state_mutated);
let ntf_tx = notification_tx.clone();
let dropped = Arc::clone(&dropped_notifications);
tokio::spawn(async move {
reader_task(
stream_read,
pending,
pending_notify,
ntf_tx,
state_mutated,
dropped,
)
.await;
alive_clone.store(false, std::sync::atomic::Ordering::Relaxed);
tracing::warn!("pg-wired reader task exited");
});
}
Self {
request_tx,
stmt_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
stmt_counter: std::sync::atomic::AtomicU64::new(0),
alive,
backend_pid,
backend_secret,
addr,
notification_tx,
notification_rx: std::sync::Mutex::new(Some(notification_rx)),
state_mutated,
broken,
dropped_notifications,
}
}
pub fn dropped_notifications(&self) -> u64 {
self.dropped_notifications
.load(std::sync::atomic::Ordering::Relaxed)
}
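/// Take the receiver for asynchronous backend messages (LISTEN/NOTIFY).
/// It can be taken at most once; later calls return `None`. While the
/// bounded channel is full, incoming notifications are dropped and counted
/// in [`dropped_notifications`](Self::dropped_notifications).
///
/// A consumption sketch, assuming an established `conn`:
///
/// ```ignore
/// if let Some(mut ntf) = conn.take_notification_receiver() {
///     tokio::spawn(async move {
///         while let Some(msg) = ntf.recv().await {
///             // handle BackendMsg::NotificationResponse
///         }
///     });
/// }
/// ```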
pub fn take_notification_receiver(
&self,
) -> Option<mpsc::Receiver<crate::protocol::types::BackendMsg>> {
self.notification_rx
.lock()
.ok()
.and_then(|mut guard| guard.take())
}
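/// Look up a cached prepared-statement name for `sql`, or allocate a fresh
/// one. Returns `(name, needs_parse)`: `needs_parse` is `true` when the
/// caller must send a Parse for the new name. The cache key is the SQL text
/// alone; `_param_oids` is currently ignored, so identical SQL prepared with
/// different parameter OIDs reuses the first cached statement.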
pub fn lookup_or_alloc(&self, sql: &str, _param_oids: &[u32]) -> (Vec<u8>, bool) {
let cache = match self.stmt_cache.lock() {
Ok(c) => c,
Err(poisoned) => poisoned.into_inner(),
};
if let Some((name, _)) = cache.get(sql) {
return (name.as_bytes().to_vec(), false);
}
let n = self
.stmt_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let name = format!("s{n}");
(name.into_bytes(), true)
}
pub fn cache_statement(&self, sql: &str, name: &[u8]) {
let Ok(name_str) = std::str::from_utf8(name) else {
return;
};
let counter = name_str
.strip_prefix('s')
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or_else(|| self.stmt_counter.load(std::sync::atomic::Ordering::Relaxed));
let mut cache = match self.stmt_cache.lock() {
Ok(c) => c,
Err(poisoned) => poisoned.into_inner(),
};
if cache.contains_key(sql) {
return;
}
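// Cap the cache at 256 entries: evict the entry with the smallest counter
// (oldest allocation; hits do not refresh it, so this is FIFO rather than
// true LRU) and ask the server to close that statement, best-effort.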
if cache.len() >= 256 {
if let Some((oldest_key, oldest_name)) = cache
.iter()
.min_by_key(|(_, (_, counter))| *counter)
.map(|(k, (name, _))| (k.clone(), name.clone()))
{
cache.remove(&oldest_key);
let mut close_buf = BytesMut::with_capacity(32);
frontend::encode_message(
&FrontendMsg::Close {
kind: b'S',
name: oldest_name.as_bytes(),
},
&mut close_buf,
);
frontend::encode_message(&FrontendMsg::Sync, &mut close_buf);
let (tx, _rx) = oneshot::channel();
let _ = self.request_tx.try_send(PipelineRequest {
messages: close_buf,
collector: ResponseCollector::Drain,
response_tx: tx,
});
}
}
cache.insert(sql.to_string(), (name_str.to_string(), counter));
}
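/// COPY FROM STDIN with an in-memory payload, chunked into 1 MiB CopyData
/// messages. Returns the row count parsed from the `COPY n` command tag.
///
/// A sketch with a hypothetical table:
///
/// ```ignore
/// let n = conn
///     .copy_in("COPY t (a, b) FROM STDIN WITH (FORMAT csv)", b"1,x\n2,y\n")
///     .await?;
/// assert_eq!(n, 2);
/// ```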
pub async fn copy_in(&self, copy_sql: &str, data: &[u8]) -> Result<u64, PgWireError> {
const CHUNK_SIZE: usize = 1024 * 1024;
let mut buf = BytesMut::with_capacity(copy_sql.len() + data.len().min(CHUNK_SIZE) + 64);
frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
for chunk in data.chunks(CHUNK_SIZE) {
frontend::encode_message(&FrontendMsg::CopyData(chunk), &mut buf);
}
frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
let resp = self
.submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
.await?;
match resp {
PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
PipelineResponse::Done => Ok(0),
}
}
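/// COPY FROM STDIN, reading the payload from an async reader in 1 MiB
/// chunks. Note that this currently buffers the entire encoded payload in
/// memory before submitting it, so peak memory is proportional to the input
/// size rather than the chunk size.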
pub async fn copy_in_stream<R: tokio::io::AsyncRead + Unpin>(
&self,
copy_sql: &str,
mut reader: R,
) -> Result<u64, PgWireError> {
const CHUNK_SIZE: usize = 1024 * 1024;
let mut buf = BytesMut::with_capacity(copy_sql.len() + 16);
frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
let mut chunk = vec![0u8; CHUNK_SIZE];
loop {
let n = reader.read(&mut chunk).await?;
if n == 0 {
break;
}
frontend::encode_message(&FrontendMsg::CopyData(&chunk[..n]), &mut buf);
}
frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
let resp = self
.submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
.await?;
match resp {
PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
PipelineResponse::Done => Ok(0),
}
}
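/// COPY TO STDOUT via the simple query protocol, concatenating every
/// CopyData chunk into a single buffer.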
pub async fn copy_out(&self, copy_sql: &str) -> Result<Vec<u8>, PgWireError> {
let mut buf = BytesMut::new();
frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
let resp = self.submit(buf, ResponseCollector::CopyOut).await?;
match resp {
PipelineResponse::Rows { rows, .. } => {
let mut result = Vec::new();
for row in rows {
for data in row.iter().flatten() {
result.extend_from_slice(data);
}
}
Ok(result)
}
PipelineResponse::Done => Ok(Vec::new()),
}
}
pub fn invalidate_statement(&self, sql: &str) {
let mut cache = match self.stmt_cache.lock() {
Ok(c) => c,
Err(poisoned) => poisoned.into_inner(),
};
cache.remove(sql);
}
pub fn clear_statement_cache(&self) {
let mut cache = match self.stmt_cache.lock() {
Ok(c) => c,
Err(poisoned) => poisoned.into_inner(),
};
cache.clear();
}
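/// Pipelined single-query transaction: `setup_sql` (e.g. `BEGIN` plus any
/// session settings) runs via the simple protocol, then the prepared
/// `query_sql`, then `COMMIT`, with all three requests enqueued before any
/// response is awaited. Stale prepared statements are re-parsed once, as in
/// [`exec_query`](Self::exec_query).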
pub async fn exec_transaction(
&self,
setup_sql: &str,
query_sql: &str,
params: &[Option<&[u8]>],
param_oids: &[u32],
) -> Result<Vec<RawRow>, PgWireError> {
let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql, param_oids);
match self
.pipeline_transaction(
setup_sql,
query_sql,
params,
param_oids,
&stmt_name,
needs_parse,
)
.await
{
Ok(rows) => {
if needs_parse {
self.cache_statement(query_sql, &stmt_name);
}
Ok(rows)
}
Err(PgWireError::Pg(ref pg_err))
if !needs_parse && is_stale_statement_error(pg_err) =>
{
tracing::debug!(
sql = query_sql,
"prepared statement invalidated — re-parsing in transaction"
);
self.invalidate_statement(query_sql);
let (stmt_name, _) = self.lookup_or_alloc(query_sql, param_oids);
let result = self
.pipeline_transaction(
setup_sql, query_sql, params, param_oids, &stmt_name, true,
)
.await;
if result.is_ok() {
self.cache_statement(query_sql, &stmt_name);
}
result
}
Err(e) => Err(e),
}
}
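/// Execute `sql` via the extended protocol, reusing the per-connection
/// statement cache and transparently re-parsing once if the server reports
/// the cached statement as stale (see `is_stale_statement_error`).
///
/// A minimal calling sketch, assuming an established `conn` and a
/// text-format parameter (OID 0 lets the server infer the type):
///
/// ```ignore
/// let rows = conn
///     .exec_query("SELECT id FROM users WHERE name = $1", &[Some(b"alice".as_slice())], &[0])
///     .await?;
/// ```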
pub async fn exec_query(
&self,
sql: &str,
params: &[Option<&[u8]>],
param_oids: &[u32],
) -> Result<Vec<RawRow>, PgWireError> {
let (stmt_name, needs_parse) = self.lookup_or_alloc(sql, param_oids);
match self
.query(sql, params, param_oids, &stmt_name, needs_parse)
.await
{
Ok(rows) => {
if needs_parse {
self.cache_statement(sql, &stmt_name);
}
Ok(rows)
}
Err(PgWireError::Pg(ref pg_err))
if !needs_parse && is_stale_statement_error(pg_err) =>
{
tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
self.invalidate_statement(sql);
let (stmt_name, _) = self.lookup_or_alloc(sql, param_oids);
let result = self.query(sql, params, param_oids, &stmt_name, true).await;
if result.is_ok() {
self.cache_statement(sql, &stmt_name);
}
result
}
Err(e) => Err(e),
}
}
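/// Upper bound on waiting for one pipelined response. Hitting it suggests
/// the reader or writer task died rather than a merely slow query, so the
/// request is failed as `ConnectionClosed`.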
const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
pub async fn submit(
&self,
messages: BytesMut,
collector: ResponseCollector,
) -> Result<PipelineResponse, PgWireError> {
let (response_tx, response_rx) = oneshot::channel();
let req = PipelineRequest {
messages,
collector,
response_tx,
};
self.request_tx
.send(req)
.await
.map_err(|_| PgWireError::ConnectionClosed)?;
match tokio::time::timeout(Self::REQUEST_TIMEOUT, response_rx).await {
Ok(Ok(result)) => result,
Ok(Err(_)) => Err(PgWireError::ConnectionClosed),
Err(_elapsed) => {
tracing::error!(
"request timed out after {:?} — reader/writer task may be dead",
Self::REQUEST_TIMEOUT
);
Err(PgWireError::ConnectionClosed)
}
}
}
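/// Pipeline several independent requests: every message is enqueued before
/// any response is awaited, so the server sees them back-to-back. Results
/// are returned in submission order; per-request failures are reported in
/// the returned vector, while a failure to enqueue aborts the whole batch.
///
/// ```ignore
/// // Two pre-encoded requests (encoding elided); responses arrive in order.
/// let results = conn
///     .submit_batch(vec![(buf_a, ResponseCollector::Rows), (buf_b, ResponseCollector::Rows)])
///     .await?;
/// ```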
pub async fn submit_batch(
&self,
items: Vec<(BytesMut, ResponseCollector)>,
) -> Result<Vec<Result<PipelineResponse, PgWireError>>, PgWireError> {
let mut receivers = Vec::with_capacity(items.len());
for (messages, collector) in items {
let (response_tx, response_rx) = oneshot::channel();
self.request_tx
.send(PipelineRequest {
messages,
collector,
response_tx,
})
.await
.map_err(|_| PgWireError::ConnectionClosed)?;
receivers.push(response_rx);
}
let mut results = Vec::with_capacity(receivers.len());
for rx in receivers {
match tokio::time::timeout(Self::REQUEST_TIMEOUT, rx).await {
Ok(Ok(r)) => results.push(r),
Ok(Err(_)) => results.push(Err(PgWireError::ConnectionClosed)),
Err(_) => {
tracing::error!(
"submit_batch request timed out after {:?}",
Self::REQUEST_TIMEOUT
);
results.push(Err(PgWireError::ConnectionClosed));
}
}
}
Ok(results)
}
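/// Send Terminate and shut the connection down. Errors that simply mean the
/// connection is already gone (closed, broken pipe) are treated as success.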
pub async fn close(&self) -> Result<(), PgWireError> {
if !self.is_alive() {
return Ok(());
}
let mut buf = BytesMut::with_capacity(5);
frontend::encode_message(&FrontendMsg::Terminate, &mut buf);
match self.submit(buf, ResponseCollector::Drain).await {
Ok(_) | Err(PgWireError::ConnectionClosed) => Ok(()),
Err(PgWireError::Io(e)) if e.kind() == std::io::ErrorKind::BrokenPipe => Ok(()),
Err(e) => Err(e),
}
}
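/// Submit a request whose rows are delivered incrementally instead of being
/// collected. `row_buffer` bounds how many parsed rows may queue up before
/// the reader task applies backpressure. The header resolves once the row
/// description (or an error) is seen.
///
/// Consumption sketch (message encoding elided):
///
/// ```ignore
/// let (header, mut rows) = conn.submit_stream(encoded, 64).await?;
/// println!("{} columns", header.fields.len());
/// while let Some(row) = rows.recv().await {
///     let row = row?; // per-row errors surface here
/// }
/// ```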
pub async fn submit_stream(
&self,
messages: BytesMut,
row_buffer: usize,
) -> Result<
(
StreamHeader,
mpsc::Receiver<Result<StreamedRow, PgWireError>>,
),
PgWireError,
> {
let (header_tx, header_rx) = oneshot::channel();
let (row_tx, row_rx) = mpsc::channel(row_buffer);
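// The pipeline requires a response channel, but stream results are delivered
// via header_tx/row_tx, so the oneshot receiver is intentionally dropped.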
let (response_tx, _response_rx) = oneshot::channel();
let req = PipelineRequest {
messages,
collector: ResponseCollector::Stream { header_tx, row_tx },
response_tx,
};
self.request_tx
.send(req)
.await
.map_err(|_| PgWireError::ConnectionClosed)?;
let header = header_rx
.await
.map_err(|_| PgWireError::ConnectionClosed)??;
Ok((header, row_rx))
}
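/// Lower-level building block for [`exec_transaction`](Self::exec_transaction):
/// encodes setup, query, and COMMIT as three separate pipeline requests and
/// enqueues all of them before awaiting any response, so the server can
/// process the whole transaction without waiting on client round trips.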
pub async fn pipeline_transaction(
&self,
setup_sql: &str,
query_sql: &str,
params: &[Option<&[u8]>],
param_oids: &[u32],
stmt_name: &[u8],
needs_parse: bool,
) -> Result<Vec<RawRow>, PgWireError> {
let mut buf = BytesMut::with_capacity(1024);
frontend::encode_message(&FrontendMsg::Query(setup_sql.as_bytes()), &mut buf);
let setup_msgs = buf.split();
let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
let result_fmts = [FormatCode::Text];
if needs_parse {
frontend::encode_message(
&FrontendMsg::Parse {
name: stmt_name,
sql: query_sql.as_bytes(),
param_oids,
},
&mut buf,
);
}
frontend::encode_message(
&FrontendMsg::Bind {
portal: b"",
statement: stmt_name,
param_formats: &text_fmts[..params.len()],
params,
result_formats: &result_fmts,
},
&mut buf,
);
frontend::encode_message(
&FrontendMsg::Execute {
portal: b"",
max_rows: 0,
},
&mut buf,
);
frontend::encode_message(&FrontendMsg::Sync, &mut buf);
let data_msgs = buf.split();
let mut commit_buf = BytesMut::with_capacity(32);
frontend::encode_message(&FrontendMsg::Query(b"COMMIT"), &mut commit_buf);
let (setup_tx, setup_rx) = oneshot::channel();
let (data_tx, data_rx) = oneshot::channel();
let (commit_tx, commit_rx) = oneshot::channel();
self.request_tx
.send(PipelineRequest {
messages: setup_msgs,
collector: ResponseCollector::Drain,
response_tx: setup_tx,
})
.await
.map_err(|_| PgWireError::ConnectionClosed)?;
self.request_tx
.send(PipelineRequest {
messages: data_msgs,
collector: ResponseCollector::Rows,
response_tx: data_tx,
})
.await
.map_err(|_| PgWireError::ConnectionClosed)?;
self.request_tx
.send(PipelineRequest {
messages: commit_buf,
collector: ResponseCollector::Drain,
response_tx: commit_tx,
})
.await
.map_err(|_| PgWireError::ConnectionClosed)?;
setup_rx
.await
.map_err(|_| PgWireError::ConnectionClosed)??;
let data_resp = data_rx.await.map_err(|_| PgWireError::ConnectionClosed)??;
commit_rx
.await
.map_err(|_| PgWireError::ConnectionClosed)??;
match data_resp {
PipelineResponse::Rows { rows, .. } => Ok(rows),
PipelineResponse::Done => Ok(Vec::new()),
}
}
pub async fn query(
&self,
sql: &str,
params: &[Option<&[u8]>],
param_oids: &[u32],
stmt_name: &[u8],
needs_parse: bool,
) -> Result<Vec<RawRow>, PgWireError> {
self.query_with_formats(sql, params, param_oids, &[], &[], stmt_name, needs_parse)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn query_with_formats(
&self,
sql: &str,
params: &[Option<&[u8]>],
param_oids: &[u32],
param_formats: &[FormatCode],
result_formats: &[FormatCode],
stmt_name: &[u8],
needs_parse: bool,
) -> Result<Vec<RawRow>, PgWireError> {
let mut buf = BytesMut::with_capacity(512);
let text_param_fmts: Vec<FormatCode>;
let param_fmts_slice: &[FormatCode] = if param_formats.is_empty() {
text_param_fmts = vec![FormatCode::Text; params.len().max(1)];
&text_param_fmts[..params.len()]
} else {
param_formats
};
let default_result_fmts = [FormatCode::Text];
let result_fmts_slice: &[FormatCode] = if result_formats.is_empty() {
&default_result_fmts
} else {
result_formats
};
if needs_parse {
frontend::encode_message(
&FrontendMsg::Parse {
name: stmt_name,
sql: sql.as_bytes(),
param_oids,
},
&mut buf,
);
}
frontend::encode_message(
&FrontendMsg::Bind {
portal: b"",
statement: stmt_name,
param_formats: param_fmts_slice,
params,
result_formats: result_fmts_slice,
},
&mut buf,
);
frontend::encode_message(
&FrontendMsg::Execute {
portal: b"",
max_rows: 0,
},
&mut buf,
);
frontend::encode_message(&FrontendMsg::Sync, &mut buf);
let resp = self.submit(buf, ResponseCollector::Rows).await?;
match resp {
PipelineResponse::Rows { rows, .. } => Ok(rows),
PipelineResponse::Done => Ok(Vec::new()),
}
}
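/// Variant of [`exec_query`](Self::exec_query) that lets the caller choose
/// parameter and result format codes; empty slices fall back to text format
/// on both sides.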
pub async fn exec_query_with_formats(
&self,
sql: &str,
params: &[Option<&[u8]>],
param_oids: &[u32],
param_formats: &[FormatCode],
result_formats: &[FormatCode],
) -> Result<Vec<RawRow>, PgWireError> {
let (stmt_name, needs_parse) = self.lookup_or_alloc(sql, param_oids);
match self
.query_with_formats(
sql,
params,
param_oids,
param_formats,
result_formats,
&stmt_name,
needs_parse,
)
.await
{
Ok(rows) => {
if needs_parse {
self.cache_statement(sql, &stmt_name);
}
Ok(rows)
}
Err(PgWireError::Pg(ref pg_err))
if !needs_parse && is_stale_statement_error(pg_err) =>
{
tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
self.invalidate_statement(sql);
let (stmt_name, _) = self.lookup_or_alloc(sql, param_oids);
let result = self
.query_with_formats(
sql,
params,
param_oids,
param_formats,
result_formats,
&stmt_name,
true,
)
.await;
if result.is_ok() {
self.cache_statement(sql, &stmt_name);
}
result
}
Err(e) => Err(e),
}
}
}
async fn writer_task(
mut rx: mpsc::Receiver<PipelineRequest>,
mut stream: tokio::io::WriteHalf<crate::tls::MaybeTlsStream>,
pending: Arc<Mutex<VecDeque<PendingResponse>>>,
pending_notify: Arc<tokio::sync::Notify>,
) {
let mut write_buf = BytesMut::with_capacity(8192);
loop {
let first = match rx.recv().await {
Some(req) => req,
None => {
drain_pending_on_exit(&pending).await;
return;
}
};
write_buf.clear();
write_buf.extend_from_slice(&first.messages);
let mut batch: Vec<PendingResponse> = vec![PendingResponse {
collector: first.collector,
response_tx: first.response_tx,
}];
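// Opportunistically coalesce every request already sitting in the queue into
// one write, so a burst of submissions goes to the socket in a single
// write_all (implicit client-side pipelining).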
while let Ok(req) = rx.try_recv() {
write_buf.extend_from_slice(&req.messages);
batch.push(PendingResponse {
collector: req.collector,
response_tx: req.response_tx,
});
}
let write_result = stream.write_all(&write_buf).await;
let write_err = match write_result {
Ok(_) => stream.flush().await.err(),
Err(e) => Some(e),
};
if let Some(e) = write_err {
tracing::error!("Writer error: {e}");
let msg = e.to_string();
for p in batch {
let _ = p.response_tx.send(Err(PgWireError::Io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
msg.clone(),
))));
}
drain_pending_on_exit(&pending).await;
return;
}
{
let mut pq = pending.lock().await;
for p in batch {
pq.push_back(p);
}
}
pending_notify.notify_one();
}
}
async fn drain_pending_on_exit(pending: &Arc<Mutex<VecDeque<PendingResponse>>>) {
let mut pq = pending.lock().await;
while let Some(pr) = pq.pop_front() {
let _ = pr.response_tx.send(Err(PgWireError::ConnectionClosed));
}
}
async fn reader_task(
mut stream: tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
pending: Arc<Mutex<VecDeque<PendingResponse>>>,
pending_notify: Arc<tokio::sync::Notify>,
notification_tx: mpsc::Sender<BackendMsg>,
state_mutated: Arc<std::sync::atomic::AtomicBool>,
dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
) {
let mut recv_buf = BytesMut::with_capacity(32 * 1024);
loop {
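// Responses are matched to requests strictly in FIFO order: wait for the
// writer to register the next in-flight request before reading its reply.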
let pr = loop {
{
let mut pq = pending.lock().await;
if let Some(pr) = pq.pop_front() {
break pr;
}
}
pending_notify.notified().await;
};
let result = match pr.collector {
ResponseCollector::Rows => {
collect_rows(
&mut stream,
&mut recv_buf,
¬ification_tx,
&state_mutated,
&dropped_notifications,
)
.await
}
ResponseCollector::Drain => {
drain_until_ready(&mut stream, &mut recv_buf, Some(&state_mutated))
.await
.map(|_| PipelineResponse::Done)
}
ResponseCollector::Stream { header_tx, row_tx } => {
stream_rows(
&mut stream,
&mut recv_buf,
header_tx,
row_tx,
¬ification_tx,
&state_mutated,
&dropped_notifications,
)
.await;
Ok(PipelineResponse::Done)
}
ResponseCollector::CopyIn { .. } => {
collect_copy_in_response(&mut stream, &mut recv_buf, &state_mutated).await
}
ResponseCollector::CopyOut => {
collect_copy_out(&mut stream, &mut recv_buf, &state_mutated).await
}
};
// Anything other than a server-reported error means the stream itself is
// unusable: fail the caller, fail any queued waiters, and exit so the
// spawn wrapper can flip `alive` to false. Server errors (`Pg`) leave the
// stream synchronized at ReadyForQuery, so the loop can continue.
let fatal = matches!(&result, Err(e) if !matches!(e, PgWireError::Pg(_)));
let _ = pr.response_tx.send(result);
if fatal {
drain_pending_on_exit(&pending).await;
return;
}
}
}
async fn read_msg(
stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
buf: &mut BytesMut,
) -> Result<BackendMsg, PgWireError> {
loop {
if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
return Ok(msg);
}
let n = stream.read_buf(buf).await?;
if n == 0 {
if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
return Ok(msg);
}
return Err(PgWireError::ConnectionClosed);
}
}
}
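/// Record transaction state from a ReadyForQuery status byte: `b'I'` is
/// idle, `b'T'` in-transaction, `b'E'` failed-transaction. Any non-idle
/// status marks the session as having mutated state.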
fn note_rfq_status(status: u8, state_mutated: &std::sync::atomic::AtomicBool) {
if status != b'I' {
state_mutated.store(true, std::sync::atomic::Ordering::Release);
}
}
async fn collect_rows(
stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
buf: &mut BytesMut,
notification_tx: &mpsc::Sender<BackendMsg>,
state_mutated: &std::sync::atomic::AtomicBool,
dropped_notifications: &std::sync::atomic::AtomicU64,
) -> Result<PipelineResponse, PgWireError> {
let mut rows = Vec::new();
let mut fields = Vec::new();
let mut command_tag = String::new();
loop {
let msg = read_msg(stream, buf).await?;
match msg {
BackendMsg::DataRow(row) => rows.push(row),
BackendMsg::RowDescription { fields: f } => fields = f,
BackendMsg::CommandComplete { tag } => command_tag = tag,
BackendMsg::ReadyForQuery { status } => {
note_rfq_status(status, state_mutated);
return Ok(PipelineResponse::Rows {
fields,
rows,
command_tag,
});
}
BackendMsg::ErrorResponse { fields } => {
drain_until_ready(stream, buf, Some(state_mutated)).await?;
return Err(PgWireError::Pg(fields));
}
msg @ BackendMsg::NotificationResponse { .. } => {
#[allow(clippy::collapsible_match)]
if notification_tx.try_send(msg).is_err() {
dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tracing::warn!("notification channel full, dropping notification");
}
}
BackendMsg::ParseComplete
| BackendMsg::BindComplete
| BackendMsg::NoData
| BackendMsg::NoticeResponse { .. }
| BackendMsg::EmptyQueryResponse => {}
_ => {}
}
}
}
async fn drain_until_ready(
stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
buf: &mut BytesMut,
state_mutated: Option<&std::sync::atomic::AtomicBool>,
) -> Result<(), PgWireError> {
loop {
let msg = read_msg(stream, buf).await?;
if let BackendMsg::ReadyForQuery { status } = msg {
if let Some(sm) = state_mutated {
note_rfq_status(status, sm);
}
return Ok(());
}
if let BackendMsg::ErrorResponse { ref fields } = msg {
tracing::warn!("Error in drain: {}: {}", fields.code, fields.message);
}
}
}
async fn stream_rows(
stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
buf: &mut BytesMut,
header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
notification_tx: &mpsc::Sender<BackendMsg>,
state_mutated: &std::sync::atomic::AtomicBool,
dropped_notifications: &std::sync::atomic::AtomicU64,
) {
let mut header_tx = Some(header_tx);
let mut fields = Vec::new();
loop {
let msg = match read_msg(stream, buf).await {
Ok(msg) => msg,
Err(e) => {
if let Some(htx) = header_tx.take() {
let _ = htx.send(Err(e));
} else {
let _ = row_tx.send(Err(e)).await;
}
return;
}
};
match msg {
BackendMsg::RowDescription { fields: f } => {
fields = f;
}
BackendMsg::DataRow(row) => {
if let Some(htx) = header_tx.take() {
let _ = htx.send(Ok(StreamHeader {
fields: fields.clone(),
}));
}
if row_tx.send(Ok(row)).await.is_err() {
let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
return;
}
}
BackendMsg::CommandComplete { .. } => {
if let Some(htx) = header_tx.take() {
let _ = htx.send(Ok(StreamHeader {
fields: std::mem::take(&mut fields),
}));
}
}
BackendMsg::ReadyForQuery { status } => {
note_rfq_status(status, state_mutated);
if let Some(htx) = header_tx.take() {
let _ = htx.send(Ok(StreamHeader {
fields: std::mem::take(&mut fields),
}));
}
return;
}
BackendMsg::ErrorResponse { fields: err } => {
if let Some(htx) = header_tx.take() {
let _ = htx.send(Err(PgWireError::Pg(err)));
} else {
let _ = row_tx.send(Err(PgWireError::Pg(err))).await;
}
let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
return;
}
msg @ BackendMsg::NotificationResponse { .. } => {
#[allow(clippy::collapsible_match)]
if notification_tx.try_send(msg).is_err() {
dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tracing::warn!("notification channel full, dropping notification");
}
}
BackendMsg::ParseComplete
| BackendMsg::BindComplete
| BackendMsg::NoData
| BackendMsg::PortalSuspended
| BackendMsg::NoticeResponse { .. }
| BackendMsg::EmptyQueryResponse => {}
_ => {}
}
}
}
async fn collect_copy_in_response(
stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
buf: &mut BytesMut,
state_mutated: &std::sync::atomic::AtomicBool,
) -> Result<PipelineResponse, PgWireError> {
let mut command_tag = String::new();
loop {
let msg = read_msg(stream, buf).await?;
match msg {
BackendMsg::CopyInResponse { .. } => {}
BackendMsg::CommandComplete { tag } => command_tag = tag,
BackendMsg::ReadyForQuery { status } => {
note_rfq_status(status, state_mutated);
return Ok(PipelineResponse::Rows {
fields: Vec::new(),
rows: Vec::new(),
command_tag,
});
}
BackendMsg::ErrorResponse { fields } => {
drain_until_ready(stream, buf, Some(state_mutated)).await?;
return Err(PgWireError::Pg(fields));
}
_ => {}
}
}
}
async fn collect_copy_out(
stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
buf: &mut BytesMut,
state_mutated: &std::sync::atomic::AtomicBool,
) -> Result<PipelineResponse, PgWireError> {
let mut data_chunks: Vec<RawRow> = Vec::new();
let mut command_tag = String::new();
loop {
let msg = read_msg(stream, buf).await?;
match msg {
BackendMsg::CopyOutResponse { .. } => {}
BackendMsg::CopyData { data } => {
let body = bytes::Bytes::from(data);
data_chunks.push(RawRow::from_full_body(body));
}
BackendMsg::CopyDone => {}
BackendMsg::CommandComplete { tag } => command_tag = tag,
BackendMsg::ReadyForQuery { status } => {
note_rfq_status(status, state_mutated);
return Ok(PipelineResponse::Rows {
fields: Vec::new(),
rows: data_chunks,
command_tag,
});
}
BackendMsg::ErrorResponse { fields } => {
drain_until_ready(stream, buf, Some(state_mutated)).await?;
return Err(PgWireError::Pg(fields));
}
_ => {}
}
}
}
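// SQLSTATE 26000 (invalid_sql_statement_name) and 0A000 (feature_not_supported,
// e.g. "cached plan must not change result type") both indicate the cached
// server-side statement is stale and should be re-parsed.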
fn is_stale_statement_error(err: &crate::protocol::types::PgError) -> bool {
matches!(err.code.as_str(), "26000" | "0A000")
}
fn parse_copy_count(tag: &str) -> u64 {
tag.strip_prefix("COPY ")
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0)
}
impl WireConn {
pub(crate) fn into_stream(self) -> crate::tls::MaybeTlsStream {
self.stream
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn try_enqueue_rollback_returns_false_when_channel_full() {
let (tx, _rx) = mpsc::channel::<PipelineRequest>(2);
let mut filled = false;
for _ in 0..16 {
if !try_enqueue_rollback(&tx) {
filled = true;
break;
}
}
assert!(
filled,
"expected try_enqueue_rollback to eventually return false on a full channel"
);
assert!(
!try_enqueue_rollback(&tx),
"subsequent calls on a full channel must keep returning false"
);
}
#[tokio::test]
async fn try_enqueue_rollback_returns_false_when_channel_closed() {
let (tx, rx) = mpsc::channel::<PipelineRequest>(8);
drop(rx);
assert!(
!try_enqueue_rollback(&tx),
"try_enqueue_rollback must return false when the receiver has been dropped"
);
}
#[tokio::test]
async fn try_enqueue_rollback_returns_true_and_enqueues_query() {
let (tx, mut rx) = mpsc::channel::<PipelineRequest>(2);
assert!(try_enqueue_rollback(&tx));
let req = rx.recv().await.expect("request should be received");
assert_eq!(
req.messages.first().copied(),
Some(b'Q'),
"queued request should be a simple Query message"
);
assert!(
req.messages.windows(8).any(|w| w == b"ROLLBACK"),
"queued request should contain the ROLLBACK statement text"
);
}
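#[test]
fn parse_copy_count_reads_copy_tags() {
// Grounded in `parse_copy_count` above: only tags of the form "COPY <n>"
// yield a count; anything else falls back to 0.
assert_eq!(parse_copy_count("COPY 42"), 42);
assert_eq!(parse_copy_count("COPY 0"), 0);
assert_eq!(parse_copy_count("SELECT 1"), 0);
assert_eq!(parse_copy_count("COPY"), 0);
}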
}