pub mod batch;
use std::sync::Arc;
use bytes::BytesMut;
use crate::error::{Error, Result};
use crate::protocol::backend::BackendMessage;
use crate::protocol::frontend;
use crate::row::{parse_command_tag, CommandResult, Row, RowDescription};
/// One statement to be sent as part of a pipelined batch.
///
/// `encode_pipeline` sends every query through the unnamed statement and the
/// unnamed portal, so each query is parsed and bound afresh.
#[derive(Debug)]
pub struct PipelineQuery {
    /// SQL text handed to the Parse message.
    pub sql: String,
    /// Parameter type OIDs handed to the Parse message (one per placeholder,
    /// presumably; `0` would let the server infer the type — confirm).
    pub param_types: Vec<u32>,
    /// Raw parameter values handed to the Bind message. `None` is presumably
    /// encoded as SQL NULL by `frontend::bind` — confirm.
    pub params: Vec<Option<Vec<u8>>>,
}
/// Outcome of a single statement within a pipeline.
#[derive(Debug)]
pub enum QueryResult {
    /// Rows built from the DataRow messages the statement produced, in arrival order.
    Rows(Vec<Row>),
    /// Completion information parsed from the server's command tag (an empty
    /// `CommandResult` for an empty query string).
    Command(CommandResult),
}
impl QueryResult {
pub fn into_rows(self) -> Result<Vec<Row>> {
match self {
QueryResult::Rows(rows) => Ok(rows),
QueryResult::Command(_) => Err(Error::Protocol(
"expected rows but got command result".to_string(),
)),
}
}
pub fn into_command(self) -> Result<CommandResult> {
match self {
QueryResult::Command(r) => Ok(r),
QueryResult::Rows(_) => Err(Error::Protocol(
"expected command result but got rows".to_string(),
)),
}
}
}
/// Appends a pipelined batch for `queries` to `buf` using the extended query protocol.
///
/// For every query this writes Parse (unnamed statement), Bind (unnamed portal),
/// Describe (portal) and Execute messages. A single Sync is written after the
/// last query. The server therefore answers with one reply sequence per query
/// followed by a single ReadyForQuery, which `read_pipeline_responses` consumes.
pub fn encode_pipeline(buf: &mut BytesMut, queries: &[PipelineQuery]) {
    for q in queries {
        // Borrow the OIDs directly; copying them into a temporary Vec bought nothing.
        frontend::parse(buf, "", &q.sql, &q.param_types);
        let param_refs: Vec<Option<&[u8]>> = q.params.iter().map(|p| p.as_deref()).collect();
        frontend::bind(buf, "", "", &param_refs, &[]);
        frontend::describe_portal(buf, "");
        // max_rows = 0 asks for every row, so the portal is never suspended.
        frontend::execute(buf, "", 0);
    }
    frontend::sync(buf);
}
/// Reads the server's replies to a batch written by `encode_pipeline`.
///
/// `count` is the number of per-query reply sequences to read (presumably the
/// number of queries that were encoded — confirm at call sites). Each sequence
/// is consumed in the following order:
/// - ParseComplete
/// - BindComplete
/// - the portal description (RowDescription or NoData)
/// - the execution result, via `read_query_result`
///
/// Finally, the single ReadyForQuery that answers the trailing Sync is consumed.
///
/// # Errors
///
/// Propagates errors from `conn.recv()`. A server ErrorResponse is returned as a
/// server error. Any other unexpected message is returned as a protocol error.
///
/// NOTE(review): on error this returns immediately, before ReadyForQuery has been
/// read, so the rest of this pipeline's replies stay unread on `conn`. The caller
/// presumably resynchronises or discards the connection — confirm.
pub(crate) async fn read_pipeline_responses(
    conn: &mut crate::connection::stream::PgConnection,
    count: usize,
) -> Result<Vec<QueryResult>> {
    let mut collected = Vec::with_capacity(count);

    for _ in 0..count {
        expect_message(conn, "ParseComplete", |m| {
            matches!(m, BackendMessage::ParseComplete)
        })
        .await?;
        expect_message(conn, "BindComplete", |m| {
            matches!(m, BackendMessage::BindComplete)
        })
        .await?;

        // The Describe(portal) reply tells us whether the statement returns rows.
        let row_desc = match conn.recv().await? {
            BackendMessage::RowDescription { fields } => {
                Some(Arc::new(RowDescription::new(fields)))
            }
            BackendMessage::NoData => None,
            BackendMessage::ErrorResponse { fields } => {
                return Err(Error::server(
                    fields.severity,
                    fields.code,
                    fields.message,
                    fields.detail,
                    fields.hint,
                    fields.position,
                ));
            }
            unexpected => {
                return Err(Error::protocol(format!(
                    "expected RowDescription or NoData, got {unexpected:?}"
                )));
            }
        };

        let outcome = read_query_result(conn, row_desc).await?;
        collected.push(outcome);
    }

    // The trailing Sync is answered by exactly one ReadyForQuery.
    expect_message(conn, "ReadyForQuery", |m| {
        matches!(m, BackendMessage::ReadyForQuery { .. })
    })
    .await?;

    Ok(collected)
}
/// Reads the execution phase of one pipelined query, up to and including its
/// completion message.
///
/// `description` is the portal's row description, or `None` when the server
/// answered Describe with NoData. Each DataRow is turned into a `Row` sharing
/// that description.
///
/// The result kind is decided by `description`, not by how many rows arrived:
/// - A statement that described result columns yields `QueryResult::Rows`, even
///   with zero rows. Without this, an empty `SELECT` would come back as a
///   command and `into_rows()` would fail.
/// - A statement without a description yields `QueryResult::Command` with the
///   parsed command tag.
/// - An EmptyQueryResponse yields an empty `CommandResult`.
///
/// # Errors
///
/// Propagates errors from `conn.recv()`. A server ErrorResponse is returned as a
/// server error. A DataRow without a description, or any other unexpected
/// message, is returned as a protocol error.
async fn read_query_result(
    conn: &mut crate::connection::stream::PgConnection,
    description: Option<Arc<RowDescription>>,
) -> Result<QueryResult> {
    let mut rows = Vec::new();
    loop {
        match conn.recv().await? {
            BackendMessage::DataRow { columns } => {
                let desc = description
                    .as_ref()
                    .ok_or_else(|| Error::protocol("received DataRow without RowDescription"))?;
                rows.push(Row::new(columns, Arc::clone(desc)));
            }
            BackendMessage::CommandComplete { tag } => {
                let result = if description.is_some() {
                    QueryResult::Rows(rows)
                } else {
                    QueryResult::Command(parse_command_tag(&tag))
                };
                return Ok(result);
            }
            BackendMessage::EmptyQueryResponse => {
                return Ok(QueryResult::Command(CommandResult {
                    command: String::new(),
                    rows_affected: 0,
                }));
            }
            BackendMessage::ErrorResponse { fields } => {
                return Err(Error::server(
                    fields.severity,
                    fields.code,
                    fields.message,
                    fields.detail,
                    fields.hint,
                    fields.position,
                ));
            }
            other => {
                return Err(Error::protocol(format!(
                    "unexpected message in query result: {other:?}"
                )));
            }
        }
    }
}
async fn expect_message(
conn: &mut crate::connection::stream::PgConnection,
expected: &str,
check: impl FnOnce(&BackendMessage) -> bool,
) -> Result<()> {
let msg = conn.recv().await?;
if check(&msg) {
Ok(())
} else if let BackendMessage::ErrorResponse { fields } = msg {
Err(Error::server(
fields.severity,
fields.code,
fields.message,
fields.detail,
fields.hint,
fields.position,
))
} else {
Err(Error::protocol(format!("expected {expected}, got {msg:?}")))
}
}