use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncWrite};
use super::result::{NativePgResult, NativeResultStatus};
use super::wire;
use crate::error::ReplicationError;
pub async fn simple_query<S: AsyncRead + AsyncWrite + Unpin>(
stream: &mut S,
buf: &mut BytesMut,
sql: &str,
) -> Result<NativePgResult, ReplicationError> {
let query_msg = wire::build_query_message(sql);
wire::write_all(stream, &query_msg).await?;
wire::flush(stream).await?;
let mut result = NativePgResult::new();
loop {
let msg = wire::read_message(stream, buf).await?;
if msg.is_empty() {
continue;
}
match msg[0] {
b'T' => {
result.parse_row_description(&msg[5..]);
}
b'D' => {
result.parse_data_row(&msg[5..]);
}
b'C' => {
if result.status == NativeResultStatus::Empty {
result.status = NativeResultStatus::CommandOk;
}
}
b'Z' => {
break;
}
b'W' => {
result.status = NativeResultStatus::CopyBoth;
break;
}
b'H' => {
result.status = NativeResultStatus::CopyOut;
break;
}
b'E' => {
let fields = super::error::parse_error_fields(&msg[5..]);
result.status = NativeResultStatus::FatalError;
result.error_msg = Some(format!("{}", fields));
}
b'N' => {
let fields = super::error::parse_error_fields(&msg[5..]);
tracing::info!("Server notice: {}", fields);
}
_ => {
tracing::debug!("Skipping message type '{}' during query", msg[0] as char);
}
}
}
Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    #[test]
    fn test_result_status_variants() {
        assert_ne!(NativeResultStatus::CommandOk, NativeResultStatus::TuplesOk);
        assert_ne!(NativeResultStatus::CopyBoth, NativeResultStatus::FatalError);
    }

    /// Frames `payload` as a backend message: type tag, big-endian length (which counts
    /// itself but not the tag), then the payload bytes.
    fn frame(tag: u8, payload: &[u8]) -> Vec<u8> {
        let len = i32::try_from(4 + payload.len()).expect("test message length fits in i32");
        let mut msg = Vec::with_capacity(1 + 4 + payload.len());
        msg.push(tag);
        msg.extend_from_slice(&len.to_be_bytes());
        msg.extend_from_slice(payload);
        msg
    }

    /// `CommandComplete` carrying the NUL-terminated command tag.
    fn build_command_complete(tag: &str) -> Vec<u8> {
        let mut payload = tag.as_bytes().to_vec();
        payload.push(0);
        frame(b'C', &payload)
    }

    /// `ReadyForQuery` with the given transaction-status byte.
    fn build_ready_for_query(status: u8) -> Vec<u8> {
        frame(b'Z', &[status])
    }

    /// Copy response (`'W'` or `'H'`) with text overall format and zero columns.
    fn build_copy_response(tag: u8) -> Vec<u8> {
        let mut payload = vec![0u8]; // overall format: text
        payload.extend_from_slice(&0i16.to_be_bytes()); // number of columns
        frame(tag, &payload)
    }

    fn build_copy_both_response() -> Vec<u8> {
        build_copy_response(b'W')
    }

    fn build_copy_out_response() -> Vec<u8> {
        build_copy_response(b'H')
    }

    /// Error/notice body (shared by `'E'` and `'N'`) with severity, code and message
    /// fields followed by the terminating NUL.
    fn build_error_fields(tag: u8, severity: &str, code: &str, message: &str) -> Vec<u8> {
        let mut payload = Vec::new();
        for (field, value) in [(b'S', severity), (b'C', code), (b'M', message)] {
            payload.push(field);
            payload.extend_from_slice(value.as_bytes());
            payload.push(0);
        }
        payload.push(0);
        frame(tag, &payload)
    }

    fn build_error_response(severity: &str, code: &str, message: &str) -> Vec<u8> {
        build_error_fields(b'E', severity, code, message)
    }

    fn build_notice_response(message: &str) -> Vec<u8> {
        build_error_fields(b'N', "NOTICE", "00000", message)
    }

    /// `RowDescription` with one text column per name.
    fn build_row_description(names: &[&str]) -> Vec<u8> {
        let mut payload = Vec::new();
        payload.extend_from_slice(&(names.len() as i16).to_be_bytes());
        for name in names {
            payload.extend_from_slice(name.as_bytes());
            payload.push(0);
            payload.extend_from_slice(&0i32.to_be_bytes()); // table OID
            payload.extend_from_slice(&0i16.to_be_bytes()); // column attribute number
            payload.extend_from_slice(&25i32.to_be_bytes()); // type OID (25 = text)
            payload.extend_from_slice(&(-1i16).to_be_bytes()); // type size (variable)
            payload.extend_from_slice(&0i32.to_be_bytes()); // type modifier
            payload.extend_from_slice(&0i16.to_be_bytes()); // format code (text)
        }
        frame(b'T', &payload)
    }

    /// `DataRow` with one length-prefixed text value per column.
    fn build_data_row(values: &[&str]) -> Vec<u8> {
        let mut payload = Vec::new();
        payload.extend_from_slice(&(values.len() as i16).to_be_bytes());
        for val in values {
            let bytes = val.as_bytes();
            payload.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
            payload.extend_from_slice(bytes);
        }
        frame(b'D', &payload)
    }

    /// `ParameterStatus`; `simple_query` has no dedicated arm for it.
    fn build_parameter_status(name: &str, value: &str) -> Vec<u8> {
        let mut payload = Vec::new();
        payload.extend_from_slice(name.as_bytes());
        payload.push(0);
        payload.extend_from_slice(value.as_bytes());
        payload.push(0);
        frame(b'S', &payload)
    }

    /// Runs `simple_query(sql)` against an in-memory server. The server reads (and
    /// discards) the query once, then writes `responses` back to back.
    ///
    /// The server task is joined so a panic inside it fails the test.
    async fn query_against(responses: Vec<Vec<u8>>, sql: &str) -> NativePgResult {
        let (mut client, mut server) = tokio::io::duplex(8192);
        let server_task = tokio::spawn(async move {
            let mut discard = vec![0u8; 1024];
            let _ = server.read(&mut discard).await;
            for response in &responses {
                server.write_all(response).await.unwrap();
            }
            server.flush().await.unwrap();
        });
        let mut buf = BytesMut::new();
        let result = simple_query(&mut client, &mut buf, sql).await.unwrap();
        server_task.await.unwrap();
        result
    }

    #[test]
    fn test_frame_layout() {
        assert_eq!(build_ready_for_query(b'I'), vec![b'Z', 0, 0, 0, 5, b'I']);
        assert_eq!(build_copy_both_response(), vec![b'W', 0, 0, 0, 7, 0, 0, 0]);
        assert_eq!(build_copy_out_response(), vec![b'H', 0, 0, 0, 7, 0, 0, 0]);
    }

    #[tokio::test]
    async fn test_simple_query_command_ok() {
        let result = query_against(
            vec![build_command_complete("SELECT 0"), build_ready_for_query(b'I')],
            "SELECT 1",
        )
        .await;
        assert_eq!(result.status(), &NativeResultStatus::CommandOk);
    }

    #[tokio::test]
    async fn test_simple_query_with_rows() {
        let result = query_against(
            vec![
                build_row_description(&["systemid", "timeline"]),
                build_data_row(&["12345", "1"]),
                build_data_row(&["67890", "2"]),
                build_command_complete("SELECT 2"),
                build_ready_for_query(b'I'),
            ],
            "IDENTIFY_SYSTEM",
        )
        .await;
        assert_eq!(result.ntuples(), 2);
        assert_eq!(result.nfields(), 2);
        assert_eq!(result.get_value(0, 0), Some("12345".to_string()));
        assert_eq!(result.get_value(0, 1), Some("1".to_string()));
        assert_eq!(result.get_value(1, 0), Some("67890".to_string()));
        assert_eq!(result.get_value(1, 1), Some("2".to_string()));
    }

    #[tokio::test]
    async fn test_simple_query_error_response() {
        let result = query_against(
            vec![
                build_error_response("ERROR", "42601", "syntax error"),
                build_ready_for_query(b'I'),
            ],
            "INVALID SQL",
        )
        .await;
        assert_eq!(result.status(), &NativeResultStatus::FatalError);
        assert!(result.error_message().is_some());
        let err_msg = result.error_message().unwrap();
        assert!(err_msg.contains("syntax error"), "Got: {err_msg}");
    }

    #[tokio::test]
    async fn test_simple_query_notice_does_not_change_status() {
        let result = query_against(
            vec![
                build_notice_response("relation already exists, skipping"),
                build_command_complete("CREATE TABLE"),
                build_ready_for_query(b'I'),
            ],
            "CREATE TABLE IF NOT EXISTS t ()",
        )
        .await;
        assert_eq!(result.status(), &NativeResultStatus::CommandOk);
        assert!(result.error_message().is_none());
    }

    #[tokio::test]
    async fn test_simple_query_skips_unknown_messages() {
        let result = query_against(
            vec![
                build_parameter_status("application_name", "test"),
                build_command_complete("SET"),
                build_ready_for_query(b'I'),
            ],
            "SET application_name = 'test'",
        )
        .await;
        assert_eq!(result.status(), &NativeResultStatus::CommandOk);
    }

    #[tokio::test]
    async fn test_simple_query_copy_both() {
        let result = query_against(
            vec![build_copy_both_response()],
            "START_REPLICATION SLOT test LOGICAL 0/0",
        )
        .await;
        assert_eq!(result.status(), &NativeResultStatus::CopyBoth);
    }

    #[tokio::test]
    async fn test_simple_query_copy_out() {
        let result = query_against(vec![build_copy_out_response()], "BASE_BACKUP").await;
        assert_eq!(result.status(), &NativeResultStatus::CopyOut);
    }
}