use super::{
PgConnection, PgError, PgResult, is_ignorable_session_message, parse_affected_rows,
unexpected_backend_message,
};
use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
use bytes::BytesMut;
use qail_core::ast::{Action, Qail};
use std::future::Future;
/// Quotes one identifier for splicing into a generated `COPY` statement.
///
/// The result is always a double-quoted identifier in which every embedded
/// `"` is doubled, so the name is used verbatim (case-sensitive).
///
/// # Errors
/// Returns `PgError::Query` when `ident` is empty or contains a NUL byte.
pub(crate) fn quote_copy_column_ident(ident: &str) -> PgResult<String> {
    match ident {
        "" => Err(PgError::Query(
            "COPY column identifier is empty".to_string(),
        )),
        name if name.contains('\0') => Err(PgError::Query(
            "COPY column identifier contains NUL byte".to_string(),
        )),
        name => {
            let mut quoted = String::with_capacity(name.len() + 2);
            quoted.push('"');
            for ch in name.chars() {
                // A literal quote inside a quoted identifier is written twice.
                if ch == '"' {
                    quoted.push('"');
                }
                quoted.push(ch);
            }
            quoted.push('"');
            Ok(quoted)
        }
    }
}
/// Quotes a possibly schema-qualified table reference such as `schema.table`.
///
/// The input is split on `.`. Each segment is whitespace-trimmed and then
/// quoted with `quote_copy_column_ident`, and the quoted segments are joined
/// back with `.`.
///
/// # Errors
/// Returns `PgError::Query` when the reference is empty, contains a NUL byte,
/// or has a segment that is empty after trimming (e.g. `a..b`).
pub(crate) fn quote_copy_table_ref(table: &str) -> PgResult<String> {
    match table {
        "" => return Err(PgError::Query("COPY table identifier is empty".to_string())),
        t if t.contains('\0') => {
            return Err(PgError::Query(
                "COPY table identifier contains NUL byte".to_string(),
            ));
        }
        _ => {}
    }
    let mut quoted_segments = Vec::new();
    for segment in table.split('.').map(str::trim) {
        if segment.is_empty() {
            return Err(PgError::Query(
                "COPY table identifier contains an empty path segment".to_string(),
            ));
        }
        quoted_segments.push(quote_copy_column_ident(segment)?);
    }
    Ok(quoted_segments.join("."))
}
fn parse_copy_text_row(line: &[u8]) -> PgResult<Vec<String>> {
let line = if line.ends_with(b"\r") {
&line[..line.len().saturating_sub(1)]
} else {
line
};
let mut fields = Vec::new();
let mut start = 0;
for (idx, byte) in line.iter().enumerate() {
if *byte == b'\t' {
fields.push(decode_copy_text_field(&line[start..idx])?);
start = idx + 1;
}
}
fields.push(decode_copy_text_field(&line[start..])?);
Ok(fields)
}
/// Decodes one raw COPY text-format field into a `String`.
///
/// A field that is exactly `\N` (the default COPY NULL marker) becomes an
/// empty string, so NULL and `''` are indistinguishable to callers. Otherwise
/// backslash escapes are resolved as follows:
/// - `\b \f \n \r \t \v` map to their control bytes.
/// - `\` plus one to three octal digits yields that byte; values above `0o377`
///   keep only their low 8 bits.
/// - `\x` plus one or two hex digits yields that byte; `\x` with no hex digit
///   yields a literal `x`.
/// - A lone trailing `\` is kept as-is.
/// - `\` before any other byte (including `\\`) yields that byte.
///
/// # Errors
/// Returns `PgError::Protocol` when the decoded bytes are not valid UTF-8.
fn decode_copy_text_field(field: &[u8]) -> PgResult<String> {
    if field == b"\\N" {
        return Ok(String::new());
    }
    let mut decoded = Vec::with_capacity(field.len());
    let mut pos = 0;
    while let Some(&byte) = field.get(pos) {
        if byte != b'\\' {
            decoded.push(byte);
            pos += 1;
            continue;
        }
        let Some(&escaped) = field.get(pos + 1) else {
            // Backslash is the final byte of the field: keep it verbatim.
            decoded.push(b'\\');
            break;
        };
        match escaped {
            b'0'..=b'7' => {
                // Octal escape: up to three digits, the first being `escaped`.
                let digit_count = field[pos + 1..]
                    .iter()
                    .take(3)
                    .take_while(|d| (b'0'..=b'7').contains(*d))
                    .count();
                let value = field[pos + 1..pos + 1 + digit_count]
                    .iter()
                    .fold(0u16, |acc, &d| acc * 8 + u16::from(d - b'0'));
                // Intentional truncation: only the low byte is kept.
                decoded.push(value as u8);
                pos += 1 + digit_count;
            }
            b'x' => {
                // Hex escape: up to two digits following the `x`.
                let (digit_count, value) = field[pos + 2..]
                    .iter()
                    .take(2)
                    .map_while(|&d| hex_nibble(d))
                    .fold((0usize, 0u8), |(count, acc), nibble| {
                        (count + 1, (acc << 4) | nibble)
                    });
                decoded.push(if digit_count == 0 { b'x' } else { value });
                pos += 2 + digit_count;
            }
            other => {
                decoded.push(match other {
                    b'b' => 0x08,
                    b'f' => 0x0c,
                    b'n' => b'\n',
                    b'r' => b'\r',
                    b't' => b'\t',
                    b'v' => 0x0b,
                    // `\\` and every unrecognised escape stand for the byte itself.
                    literal => literal,
                });
                pos += 2;
            }
        }
    }
    String::from_utf8(decoded)
        .map_err(|e| PgError::Protocol(format!("COPY text field is not valid UTF-8: {}", e)))
}
/// Returns the value (0–15) of an ASCII hexadecimal digit.
///
/// Upper- and lower-case letters are both accepted. Any other byte, including
/// non-ASCII bytes, yields `None`.
fn hex_nibble(byte: u8) -> Option<u8> {
    // `to_digit(16)` never exceeds 15, so narrowing to `u8` is lossless.
    char::from(byte).to_digit(16).map(|digit| digit as u8)
}
/// Returns `Err(err)` and, for stream-breaking error kinds, first flags `conn`
/// as I/O-desynchronized.
///
/// Only `Protocol`, `Connection` and `Timeout` errors set the flag. Every other
/// variant is returned without touching the connection; that includes server
/// `ErrorResponse`s wrapped as `QueryServer`.
#[inline]
fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
    if let PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_) = &err {
        conn.mark_io_desynced();
    }
    Err(err)
}
/// Encodes a `Qail` export command into SQL for the COPY OUT path.
///
/// COPY over the simple-query protocol cannot carry bind parameters. Any
/// command that encodes with parameters is therefore rejected rather than
/// silently dropping its filter values.
///
/// # Errors
/// - `PgError::Query` if `cmd.action` is not `Action::Export`.
/// - `PgError::Encode` if AST encoding fails or produces bind parameters.
fn encode_copy_export_sql(cmd: &Qail) -> PgResult<String> {
    if cmd.action != Action::Export {
        return Err(PgError::Query(
            "copy_export requires Qail::Export action".to_string(),
        ));
    }
    let (sql, params) = match AstEncoder::encode_cmd_sql(cmd) {
        Ok(encoded) => encoded,
        Err(e) => return Err(PgError::Encode(e.to_string())),
    };
    match params.len() {
        0 => Ok(sql),
        bound => Err(PgError::Encode(format!(
            "copy_export cannot encode parameterized export with {} bind parameter(s); use an unfiltered export, a prefiltered database view, or a raw COPY statement with trusted SQL",
            bound
        ))),
    }
}
/// Appends `chunk` to `pending` and emits every complete COPY text row
/// (terminated by `\n`) to `on_row`. Any trailing partial row stays in
/// `pending` for the next chunk.
///
/// Complete lines are parsed in place and the consumed prefix is removed with
/// a single `drain`. A chunk carrying many rows therefore costs O(chunk),
/// instead of shifting the remaining buffer and allocating a copy once per row.
///
/// # Errors
/// Returns the first parse or callback error. The line that failed is treated
/// as consumed; any later complete lines stay buffered in `pending`.
fn drain_copy_text_rows<F>(pending: &mut Vec<u8>, chunk: &[u8], on_row: &mut F) -> PgResult<()>
where
    F: FnMut(Vec<String>) -> PgResult<()>,
{
    pending.extend_from_slice(chunk);
    // Byte offset of the first not-yet-emitted line in `pending`.
    let mut consumed = 0;
    let outcome = loop {
        let Some(offset) = pending[consumed..].iter().position(|&b| b == b'\n') else {
            break Ok(());
        };
        let line_end = consumed + offset;
        let row = parse_copy_text_row(&pending[consumed..line_end]);
        // Mark the line (and its `\n`) consumed before reporting any error.
        consumed = line_end + 1;
        if let Err(err) = row.and_then(&mut *on_row) {
            break Err(err);
        }
    };
    pending.drain(..consumed);
    outcome
}
/// Emits the final unterminated row left in `pending`, if any, once the COPY
/// stream has ended.
///
/// An empty buffer is a no-op. Otherwise the buffer is taken (leaving
/// `pending` empty even on error), parsed, and passed to `on_row`.
///
/// # Errors
/// Returns the parse error or the callback's error.
fn flush_pending_copy_text_row<F>(pending: &mut Vec<u8>, on_row: &mut F) -> PgResult<()>
where
    F: FnMut(Vec<String>) -> PgResult<()>,
{
    if pending.is_empty() {
        Ok(())
    } else {
        let last_line = std::mem::take(pending);
        parse_copy_text_row(&last_line).and_then(on_row)
    }
}
/// Builds the `COPY ... FROM STDIN` statement shared by the COPY IN entry points.
///
/// Column names and the (optionally schema-qualified) table reference are
/// quoted; columns are validated before the table reference. An empty
/// `columns` slice omits the column list, because `COPY t () ...` is rejected
/// by PostgreSQL's grammar. Without a list, COPY targets all table columns.
///
/// # Errors
/// Propagates the `PgError::Query` produced by identifier validation.
fn copy_in_sql(table: &str, columns: &[String]) -> PgResult<String> {
    let cols: Vec<String> = columns
        .iter()
        .map(|c| quote_copy_column_ident(c))
        .collect::<PgResult<_>>()?;
    let table = quote_copy_table_ref(table)?;
    if cols.is_empty() {
        Ok(format!("COPY {} FROM STDIN", table))
    } else {
        Ok(format!("COPY {} ({}) FROM STDIN", table, cols.join(", ")))
    }
}

impl PgConnection {
    /// Bulk-loads `rows` into `table` with `COPY ... FROM STDIN`.
    ///
    /// Rows are encoded with `try_encode_copy_batch` before anything is written
    /// to the socket, so identifier and encoding failures leave the connection
    /// untouched. Returns the row count parsed from the server's
    /// `CommandComplete` tag.
    ///
    /// # Errors
    /// - Identifier and batch-encoding errors.
    /// - Server `ErrorResponse`s, as `PgError::QueryServer`.
    /// - Protocol violations.
    ///
    /// Protocol-level failures, and any failure while streaming the payload,
    /// mark the connection as I/O-desynchronized.
    pub(crate) async fn copy_in_fast(
        &mut self,
        table: &str,
        columns: &[String],
        rows: &[Vec<qail_core::ast::Value>],
    ) -> PgResult<u64> {
        use crate::protocol::try_encode_copy_batch;
        let sql = copy_in_sql(table, columns)?;
        let batch_data = try_encode_copy_batch(rows)?;
        self.start_copy_in(&sql, "copy-in startup").await?;
        self.send_copy_in_payload(&batch_data).await?;
        self.finish_copy_in("COPY IN", "copy-in completion").await
    }

    /// Loads pre-encoded `data` into `table` with `COPY ... FROM STDIN`.
    ///
    /// `data` is forwarded byte-for-byte. The generated statement sets no COPY
    /// options, so the bytes must be in PostgreSQL's default text format.
    /// Returns the row count from the server's `CommandComplete` tag.
    ///
    /// # Errors
    /// - Identifier errors.
    /// - Server `ErrorResponse`s, as `PgError::QueryServer`.
    /// - Protocol violations.
    ///
    /// Protocol-level failures, and any failure while streaming `data`, mark
    /// the connection as I/O-desynchronized.
    pub async fn copy_in_raw(
        &mut self,
        table: &str,
        columns: &[String],
        data: &[u8],
    ) -> PgResult<u64> {
        let sql = copy_in_sql(table, columns)?;
        self.start_copy_in(&sql, "copy-in raw startup").await?;
        self.send_copy_in_payload(data).await?;
        self.finish_copy_in("COPY IN raw", "copy-in raw completion").await
    }

    /// Sends `sql` as a simple query and waits until the server answers with
    /// `CopyInResponse`.
    ///
    /// Messages accepted by `is_ignorable_session_message` are skipped. The
    /// first `ErrorResponse` is remembered and returned once the server
    /// reaches `ReadyForQuery` (or, unexpectedly, `CopyInResponse`). `context`
    /// labels unexpected-message errors.
    async fn start_copy_in(&mut self, sql: &str, context: &str) -> PgResult<()> {
        let bytes = PgEncoder::try_encode_query_string(sql)?;
        self.write_all_with_timeout(&bytes, "stream write").await?;
        let mut startup_error: Option<PgError> = None;
        loop {
            let msg = self.recv().await?;
            match msg {
                BackendMessage::CopyInResponse { .. } => {
                    if let Some(err) = startup_error {
                        return return_with_desync(self, err);
                    }
                    return Ok(());
                }
                BackendMessage::ReadyForQuery(_) => {
                    return return_with_desync(
                        self,
                        startup_error.unwrap_or_else(|| {
                            PgError::Protocol(
                                "COPY IN failed before CopyInResponse (unexpected ReadyForQuery)"
                                    .to_string(),
                            )
                        }),
                    );
                }
                BackendMessage::ErrorResponse(err) => {
                    if startup_error.is_none() {
                        startup_error = Some(PgError::QueryServer(err.into()));
                    }
                }
                msg if is_ignorable_session_message(&msg) => {}
                other => {
                    return return_with_desync(self, unexpected_backend_message(context, &other));
                }
            }
        }
    }

    /// Streams `data` to a server that is already in COPY IN mode, as bounded
    /// CopyData frames followed by CopyDone.
    ///
    /// COPY IN accepts arbitrary CopyData boundaries. Splitting the payload
    /// keeps every frame well below the `i32` length limit and avoids copying
    /// the whole payload into one frame buffer. An empty `data` sends only
    /// CopyDone.
    ///
    /// Any failure here leaves the server mid-COPY, waiting for more input.
    /// The connection is therefore marked I/O-desynchronized before the error
    /// is returned, whatever its kind.
    async fn send_copy_in_payload(&mut self, data: &[u8]) -> PgResult<()> {
        const MAX_COPY_DATA_FRAME: usize = 1024 * 1024;
        let result: PgResult<()> = async {
            for frame in data.chunks(MAX_COPY_DATA_FRAME) {
                self.send_copy_data(frame).await?;
            }
            self.send_copy_done().await
        }
        .await;
        if result.is_err() {
            self.mark_io_desynced();
        }
        result
    }

    /// Reads the server's reply to CopyDone through `ReadyForQuery` and returns
    /// the affected-row count from `CommandComplete`.
    ///
    /// A server `ErrorResponse` is held until `ReadyForQuery` and then
    /// returned. `label` prefixes protocol-violation messages (e.g. `"COPY IN"`),
    /// and `context` labels unexpected-message errors.
    async fn finish_copy_in(&mut self, label: &str, context: &str) -> PgResult<u64> {
        let mut affected = 0u64;
        let mut final_error: Option<PgError> = None;
        let mut saw_command_complete = false;
        loop {
            let msg = self.recv().await?;
            match msg {
                BackendMessage::CommandComplete(tag) => {
                    if saw_command_complete {
                        return return_with_desync(
                            self,
                            PgError::Protocol(format!(
                                "{} received duplicate CommandComplete",
                                label
                            )),
                        );
                    }
                    saw_command_complete = true;
                    if final_error.is_none() {
                        match parse_affected_rows(&tag) {
                            Ok(parsed) => affected = parsed,
                            Err(err) => return return_with_desync(self, err),
                        }
                    }
                }
                BackendMessage::ReadyForQuery(_) => {
                    if let Some(err) = final_error {
                        return Err(err);
                    }
                    if !saw_command_complete {
                        return return_with_desync(
                            self,
                            PgError::Protocol(format!(
                                "{} completion missing CommandComplete before ReadyForQuery",
                                label
                            )),
                        );
                    }
                    return Ok(affected);
                }
                BackendMessage::ErrorResponse(err) => {
                    if final_error.is_none() {
                        final_error = Some(PgError::QueryServer(err.into()));
                    }
                }
                msg if is_ignorable_session_message(&msg) => {}
                other => {
                    return return_with_desync(self, unexpected_backend_message(context, &other));
                }
            }
        }
    }

    /// Writes `data` as exactly one CopyData (`'d'`) frame.
    ///
    /// # Errors
    /// - `PgError::Protocol` if the frame length (payload + 4) does not fit in
    ///   an `i32`; this check runs before anything is written.
    /// - Any error from the underlying write.
    pub(crate) async fn send_copy_data(&mut self, data: &[u8]) -> PgResult<()> {
        let total_len = data
            .len()
            .checked_add(4)
            .ok_or_else(|| PgError::Protocol("CopyData frame length overflow".to_string()))?;
        let len = i32::try_from(total_len)
            .map_err(|_| PgError::Protocol("CopyData frame exceeds i32::MAX".to_string()))?;
        let mut buf = BytesMut::with_capacity(1 + 4 + data.len());
        buf.extend_from_slice(b"d");
        buf.extend_from_slice(&len.to_be_bytes());
        buf.extend_from_slice(data);
        self.write_all_with_timeout(&buf, "stream write").await?;
        Ok(())
    }

    /// Writes a CopyDone (`'c'`, length 4) frame.
    async fn send_copy_done(&mut self) -> PgResult<()> {
        self.write_all_with_timeout(&[b'c', 0, 0, 0, 4], "stream write")
            .await?;
        Ok(())
    }

    /// Sends `sql` as a simple query and waits for `CopyOutResponse`.
    ///
    /// Mirrors `start_copy_in`: ignorable session messages are skipped, and the
    /// first `ErrorResponse` is returned once `ReadyForQuery` (or
    /// `CopyOutResponse`) arrives. `context` labels both the unexpected-message
    /// error and the premature-`ReadyForQuery` error.
    async fn start_copy_out(&mut self, sql: &str, context: &str) -> PgResult<()> {
        let bytes = PgEncoder::try_encode_query_string(sql)?;
        self.write_all_with_timeout(&bytes, "stream write").await?;
        let mut startup_error: Option<PgError> = None;
        loop {
            let msg = self.recv().await?;
            match msg {
                BackendMessage::CopyOutResponse { .. } => {
                    if let Some(err) = startup_error {
                        return return_with_desync(self, err);
                    }
                    return Ok(());
                }
                BackendMessage::ReadyForQuery(_) => {
                    return return_with_desync(
                        self,
                        startup_error.unwrap_or_else(|| {
                            PgError::Protocol(format!(
                                "{} failed before CopyOutResponse (unexpected ReadyForQuery)",
                                context
                            ))
                        }),
                    );
                }
                BackendMessage::ErrorResponse(err) => {
                    if startup_error.is_none() {
                        startup_error = Some(PgError::QueryServer(err.into()));
                    }
                }
                msg if is_ignorable_session_message(&msg) => {}
                other => {
                    return return_with_desync(self, unexpected_backend_message(context, &other));
                }
            }
        }
    }

    /// Reads COPY OUT traffic through `ReadyForQuery`, handing each CopyData
    /// payload to `on_chunk`.
    ///
    /// Once the server reports an error, or `on_chunk` fails, later CopyData
    /// is still read but no longer delivered. This lets the connection reach
    /// `ReadyForQuery` cleanly. A server error takes precedence over a callback
    /// error.
    ///
    /// The following are protocol errors and mark the connection desynced:
    /// CopyData after CopyDone, a duplicate CopyDone or CommandComplete, and a
    /// missing CopyDone or CommandComplete on an otherwise successful stream.
    async fn stream_copy_out_chunks<F, Fut>(
        &mut self,
        context: &str,
        mut on_chunk: F,
    ) -> PgResult<()>
    where
        F: FnMut(Vec<u8>) -> Fut,
        Fut: Future<Output = PgResult<()>>,
    {
        let mut stream_error: Option<PgError> = None;
        let mut callback_error: Option<PgError> = None;
        let mut saw_copy_done = false;
        let mut saw_command_complete = false;
        loop {
            let msg = self.recv().await?;
            match msg {
                BackendMessage::CopyData(chunk) => {
                    if saw_copy_done {
                        return return_with_desync(
                            self,
                            PgError::Protocol(format!(
                                "{} received CopyData after CopyDone",
                                context
                            )),
                        );
                    }
                    if stream_error.is_none()
                        && callback_error.is_none()
                        && let Err(e) = on_chunk(chunk).await
                    {
                        callback_error = Some(e);
                    }
                }
                BackendMessage::CopyDone => {
                    if saw_copy_done {
                        return return_with_desync(
                            self,
                            PgError::Protocol(format!("{} received duplicate CopyDone", context)),
                        );
                    }
                    saw_copy_done = true;
                }
                BackendMessage::CommandComplete(_) => {
                    if saw_command_complete {
                        return return_with_desync(
                            self,
                            PgError::Protocol(format!(
                                "{} received duplicate CommandComplete",
                                context
                            )),
                        );
                    }
                    saw_command_complete = true;
                }
                BackendMessage::ReadyForQuery(_) => {
                    if let Some(err) = stream_error {
                        return Err(err);
                    }
                    if let Some(err) = callback_error {
                        return Err(err);
                    }
                    if !saw_copy_done {
                        return return_with_desync(
                            self,
                            PgError::Protocol(format!(
                                "{} missing CopyDone before ReadyForQuery",
                                context
                            )),
                        );
                    }
                    if !saw_command_complete {
                        return return_with_desync(
                            self,
                            PgError::Protocol(format!(
                                "{} missing CommandComplete before ReadyForQuery",
                                context
                            )),
                        );
                    }
                    return Ok(());
                }
                BackendMessage::ErrorResponse(err) => {
                    if stream_error.is_none() {
                        stream_error = Some(PgError::QueryServer(err.into()));
                    }
                }
                msg if is_ignorable_session_message(&msg) => {}
                other => {
                    return return_with_desync(self, unexpected_backend_message(context, &other));
                }
            }
        }
    }

    /// Runs an export command and collects every decoded row in memory.
    ///
    /// Each row is a `Vec<String>` of decoded COPY text fields. A COPY NULL
    /// (`\N`) is returned as an empty string. Prefer `copy_export_stream_rows`
    /// for large results.
    pub async fn copy_export(&mut self, cmd: &Qail) -> PgResult<Vec<Vec<String>>> {
        let mut rows = Vec::new();
        self.copy_export_stream_rows(cmd, |row| {
            rows.push(row);
            Ok(())
        })
        .await?;
        Ok(rows)
    }

    /// Runs an export command and streams the raw CopyData payloads to
    /// `on_chunk`.
    ///
    /// # Errors
    /// Fails before any I/O if `cmd` is not an export or encodes with bind
    /// parameters (see `encode_copy_export_sql`).
    pub async fn copy_export_stream_raw<F, Fut>(&mut self, cmd: &Qail, on_chunk: F) -> PgResult<()>
    where
        F: FnMut(Vec<u8>) -> Fut,
        Fut: Future<Output = PgResult<()>>,
    {
        let sql = encode_copy_export_sql(cmd)?;
        self.copy_out_raw_stream(&sql, on_chunk).await
    }

    /// Runs an export command and calls `on_row` once per decoded row.
    ///
    /// Rows split across CopyData payloads are reassembled. A final row without
    /// a trailing newline is flushed after the stream ends successfully.
    pub async fn copy_export_stream_rows<F>(&mut self, cmd: &Qail, mut on_row: F) -> PgResult<()>
    where
        F: FnMut(Vec<String>) -> PgResult<()>,
    {
        let mut pending = Vec::new();
        self.copy_export_stream_raw(cmd, |chunk| {
            let res = drain_copy_text_rows(&mut pending, &chunk, &mut on_row);
            std::future::ready(res)
        })
        .await?;
        flush_pending_copy_text_row(&mut pending, &mut on_row)
    }

    /// Runs a COPY OUT statement and returns all CopyData payloads concatenated.
    ///
    /// `sql` is sent verbatim and must be trusted.
    pub(crate) async fn copy_out_raw(&mut self, sql: &str) -> PgResult<Vec<u8>> {
        let mut data = Vec::new();
        self.copy_out_raw_stream(sql, |chunk| {
            data.extend_from_slice(&chunk);
            std::future::ready(Ok(()))
        })
        .await?;
        Ok(data)
    }

    /// Runs a COPY OUT statement and streams each CopyData payload to
    /// `on_chunk`.
    ///
    /// `sql` is sent verbatim and must be trusted.
    pub(crate) async fn copy_out_raw_stream<F, Fut>(
        &mut self,
        sql: &str,
        on_chunk: F,
    ) -> PgResult<()>
    where
        F: FnMut(Vec<u8>) -> Fut,
        Fut: Future<Output = PgResult<()>>,
    {
        self.start_copy_out(sql, "copy-out raw startup").await?;
        self.stream_copy_out_chunks("copy-out raw stream", on_chunk)
            .await
    }
}
// Unit tests for the COPY helpers: identifier quoting, COPY text decoding,
// row reassembly across chunk boundaries, export SQL validation, and desync
// marking.
#[cfg(test)]
mod tests {
    use super::{
        drain_copy_text_rows, encode_copy_export_sql, flush_pending_copy_text_row,
        parse_copy_text_row, quote_copy_column_ident, quote_copy_table_ref, return_with_desync,
    };
    use crate::driver::{PgConnection, PgError, PgResult};
    use qail_core::ast::{Operator, Qail};

    // Builds a `PgConnection` over one end of an in-process Unix socket pair,
    // so helpers taking `&mut PgConnection` can be exercised without a server.
    // The peer half (`_peer`) is dropped when this function returns, so this
    // is only suitable for tests that never read from or write to the stream.
    // NOTE(review): `UnixStream::pair` is Tokio's, so this presumably must run
    // inside a Tokio runtime; it is only called from `#[tokio::test]` here.
    #[cfg(unix)]
    fn test_conn() -> PgConnection {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;
        let (unix_stream, _peer) = UnixStream::pair().expect("unix stream pair");
        // Every field is spelled out; this literal must be updated whenever a
        // field is added to `PgConnection`.
        PgConnection {
            stream: PgStream::Unix(unix_stream),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(NonZeroUsize::new(2).expect("non-zero")),
            column_info_cache: HashMap::new(),
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        }
    }

    // Raw tab bytes separate fields.
    #[test]
    fn parse_copy_text_row_splits_tabs() {
        let row = parse_copy_text_row(b"a\tb\tc").unwrap();
        assert_eq!(row, vec!["a", "b", "c"]);
    }

    // A trailing carriage return is not part of the last field.
    #[test]
    fn parse_copy_text_row_trims_cr() {
        let row = parse_copy_text_row(b"a\tb\r").unwrap();
        assert_eq!(row, vec!["a", "b"]);
    }

    // Escaped tab/newline/backslash inside a field are decoded, not split on.
    #[test]
    fn parse_copy_text_row_unescapes_copy_text_values() {
        let row = parse_copy_text_row(b"a\\tb\tline\\nnext\tc\\\\d").unwrap();
        assert_eq!(row, vec!["a\tb", "line\nnext", "c\\d"]);
    }

    // `\N` (COPY NULL) is surfaced as an empty string.
    #[test]
    fn parse_copy_text_row_maps_copy_null_marker_to_empty_string() {
        let row = parse_copy_text_row(b"a\t\\N\tb").unwrap();
        assert_eq!(row, vec!["a", "", "b"]);
    }

    // Non-UTF-8 field bytes are reported as an error rather than lossily decoded.
    #[test]
    fn parse_copy_text_row_rejects_invalid_utf8() {
        let err = parse_copy_text_row(&[0xff]).expect_err("invalid UTF-8 must fail");
        assert!(
            err.to_string()
                .contains("COPY text field is not valid UTF-8")
        );
    }

    // `schema.table` is quoted per segment, not as a single identifier.
    #[test]
    fn copy_table_quoting_preserves_schema_qualification() {
        assert_eq!(
            quote_copy_table_ref("tenant_a.users").unwrap(),
            "\"tenant_a\".\"users\""
        );
    }

    #[test]
    fn copy_identifier_quoting_rejects_nul_bytes() {
        assert!(quote_copy_table_ref("tenant\0.users").is_err());
        assert!(quote_copy_column_ident("name\0").is_err());
    }

    // A filter produces bind parameters, which COPY cannot carry; encoding must
    // fail instead of silently dropping the filter.
    #[test]
    fn copy_export_rejects_parameterized_ast_before_streaming() {
        let cmd = Qail::export("users").filter("active", Operator::Eq, true);
        let err = encode_copy_export_sql(&cmd).expect_err("bind params cannot be ignored");
        assert!(matches!(err, PgError::Encode(msg) if msg.contains("parameterized export")));
    }

    // Protocol errors must flag the connection as desynchronized.
    #[cfg(unix)]
    #[tokio::test]
    async fn copy_return_with_desync_marks_protocol_error() {
        let mut conn = test_conn();
        let err = return_with_desync::<()>(
            &mut conn,
            PgError::Protocol("copy protocol ordering broke".to_string()),
        )
        .expect_err("protocol error must be returned");
        assert!(err.to_string().contains("copy protocol ordering broke"));
        assert!(conn.is_io_desynced());
    }

    // A row split across two chunks is emitted only once its newline arrives.
    #[test]
    fn drain_copy_text_rows_handles_chunk_boundaries() {
        let mut pending = Vec::new();
        let mut rows: Vec<Vec<String>> = Vec::new();
        drain_copy_text_rows(&mut pending, b"a\tb\nc", &mut |row: Vec<String>| {
            rows.push(row);
            Ok(())
        })
        .unwrap();
        assert_eq!(rows, vec![vec!["a".to_string(), "b".to_string()]]);
        assert_eq!(pending, b"c");
        drain_copy_text_rows(&mut pending, b"\td\n", &mut |row: Vec<String>| {
            rows.push(row);
            Ok(())
        })
        .unwrap();
        assert_eq!(
            rows,
            vec![
                vec!["a".to_string(), "b".to_string()],
                vec!["c".to_string(), "d".to_string()]
            ]
        );
        assert!(pending.is_empty());
    }

    // Leftover bytes without a trailing newline still form a final row.
    #[test]
    fn flush_pending_copy_text_row_emits_final_partial_line() {
        let mut pending = b"x\ty".to_vec();
        let mut rows = Vec::new();
        let mut on_row = |row: Vec<String>| -> PgResult<()> {
            rows.push(row);
            Ok(())
        };
        flush_pending_copy_text_row(&mut pending, &mut on_row).unwrap();
        assert_eq!(rows, vec![vec!["x".to_string(), "y".to_string()]]);
        assert!(pending.is_empty());
    }

    // Errors from the row callback are propagated unchanged.
    #[test]
    fn callback_error_bubbles_from_row_drainer() {
        let mut pending = Vec::new();
        let mut on_row =
            |_row: Vec<String>| -> PgResult<()> { Err(PgError::Query("fail".to_string())) };
        let err = drain_copy_text_rows(&mut pending, b"a\tb\n", &mut on_row).unwrap_err();
        assert!(matches!(err, PgError::Query(msg) if msg == "fail"));
    }
}