use std::collections::VecDeque;
use crate::pipeline::Expectation;
use crate::pipeline::Ticket;
use crate::conversion::{FromRow, ToParams};
use crate::error::{Error, Result};
use crate::handler::ExtendedHandler;
use crate::protocol::backend::{
BindComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, NoData,
ParseComplete, RawMessage, ReadyForQuery, RowDescription, msg_type,
};
use crate::protocol::frontend::{
write_bind, write_describe_portal, write_execute, write_flush, write_parse, write_sync,
};
use crate::state::extended::PreparedStatement;
use crate::statement::{IntoStatement, StatementRef};
use super::conn::Conn;
/// A batch of extended-protocol requests written to a connection before their results are read.
///
/// Requests are queued with [`Pipeline::exec`], which hands back a [`Ticket`]; tickets must be
/// claimed in the order they were issued.
pub struct Pipeline<'a> {
    /// Connection whose write buffer and stream this pipeline drives.
    conn: &'a mut Conn,
    /// Responses still owed by the server for queued messages, oldest first.
    expectations: VecDeque<Expectation>,
    /// Sequence number that the next queued request will receive.
    queue_seq: usize,
    /// Sequence number of the next ticket allowed to be claimed.
    claim_seq: usize,
    /// Set when a claim fails; later claims are rejected until a Sync point is consumed.
    aborted: bool,
    /// Copy of the most recent RowDescription payload, used to decode subsequent DataRows.
    column_buffer: Vec<u8>,
}
impl<'a> Pipeline<'a> {
#[cfg(feature = "lowlevel")]
pub fn new(conn: &'a mut Conn) -> Self {
Self::new_inner(conn)
}
pub(crate) fn new_inner(conn: &'a mut Conn) -> Self {
conn.buffer_set.write_buffer.clear();
Self {
conn,
queue_seq: 0,
claim_seq: 0,
aborted: false,
column_buffer: Vec::new(),
expectations: VecDeque::new(),
}
}
#[cfg(feature = "lowlevel")]
pub async fn cleanup(&mut self) {
self.cleanup_inner().await;
}
#[cfg(not(feature = "lowlevel"))]
pub(crate) async fn cleanup(&mut self) {
self.cleanup_inner().await;
}
async fn cleanup_inner(&mut self) {
if self.queue_seq == 0 && self.expectations.is_empty() {
return;
}
if !self.conn.buffer_set.write_buffer.is_empty()
|| !self.expectations.iter().any(|e| *e == Expectation::Sync)
{
let _ = self.sync().await;
}
if self.aborted {
while let Some(expectation) = self.expectations.pop_front() {
if expectation == Expectation::Sync {
let _ = self.consume_ready_for_query().await;
}
}
} else {
while let Some(expectation) = self.expectations.pop_front() {
let _ = self.drain_expectation(expectation).await;
}
}
self.queue_seq = 0;
self.claim_seq = 0;
self.aborted = false;
}
async fn drain_expectation(&mut self, expectation: Expectation) {
let mut handler = crate::handler::DropHandler::new();
let _ = match expectation {
Expectation::ParseBindExecute => self.claim_parse_bind_exec_inner(&mut handler).await,
Expectation::BindExecute => self.claim_bind_exec_inner(&mut handler, None).await,
Expectation::Sync => self.consume_ready_for_query().await,
};
}
pub fn exec<'s, P: ToParams>(
&mut self,
statement: &'s (impl IntoStatement + ?Sized),
params: P,
) -> Result<Ticket<'s>> {
let seq = self.queue_seq;
self.queue_seq += 1;
match statement.statement_ref() {
StatementRef::Sql(sql) => {
self.exec_sql_inner(sql, ¶ms)?;
Ok(Ticket { seq, stmt: None })
}
StatementRef::Prepared(stmt) => {
self.exec_prepared_inner(&stmt.wire_name(), &stmt.param_oids, ¶ms)?;
Ok(Ticket {
seq,
stmt: Some(stmt),
})
}
}
}
fn exec_sql_inner<P: ToParams>(&mut self, sql: &str, params: &P) -> Result<()> {
let param_oids = params.natural_oids();
let buf = &mut self.conn.buffer_set.write_buffer;
write_parse(buf, "", sql, ¶m_oids);
write_bind(buf, "", "", params, ¶m_oids)?;
write_describe_portal(buf, "");
write_execute(buf, "", 0);
self.expectations.push_back(Expectation::ParseBindExecute);
Ok(())
}
fn exec_prepared_inner<P: ToParams>(
&mut self,
stmt_name: &str,
param_oids: &[u32],
params: &P,
) -> Result<()> {
let buf = &mut self.conn.buffer_set.write_buffer;
write_bind(buf, "", stmt_name, params, param_oids)?;
write_execute(buf, "", 0);
self.expectations.push_back(Expectation::BindExecute);
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
if !self.conn.buffer_set.write_buffer.is_empty() {
write_flush(&mut self.conn.buffer_set.write_buffer);
self.conn
.stream
.write_all(&self.conn.buffer_set.write_buffer)
.await?;
self.conn.stream.flush().await?;
self.conn.buffer_set.write_buffer.clear();
}
Ok(())
}
pub async fn sync(&mut self) -> Result<()> {
let result = self.sync_inner().await;
if let Err(e) = &result
&& e.is_connection_broken()
{
self.conn.is_broken = true;
}
result
}
async fn sync_inner(&mut self) -> Result<()> {
write_sync(&mut self.conn.buffer_set.write_buffer);
self.expectations.push_back(Expectation::Sync);
self.conn
.stream
.write_all(&self.conn.buffer_set.write_buffer)
.await?;
self.conn.stream.flush().await?;
self.conn.buffer_set.write_buffer.clear();
Ok(())
}
async fn consume_ready_for_query(&mut self) -> Result<()> {
loop {
self.conn
.stream
.read_message(&mut self.conn.buffer_set)
.await?;
let type_byte = self.conn.buffer_set.type_byte;
if RawMessage::is_async_type(type_byte) {
continue;
}
if type_byte == msg_type::ERROR_RESPONSE {
let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
return Err(error.into_error());
}
if type_byte == msg_type::READY_FOR_QUERY {
let ready = ReadyForQuery::parse(&self.conn.buffer_set.read_buffer)?;
self.conn.transaction_status = ready.transaction_status().unwrap_or_default();
return Ok(());
}
}
}
async fn consume_pending_syncs(&mut self) -> Result<()> {
while self.expectations.front() == Some(&Expectation::Sync) {
self.expectations.pop_front();
self.consume_ready_for_query().await?;
self.aborted = false;
}
Ok(())
}
#[cfg(feature = "lowlevel")]
pub async fn claim<H: ExtendedHandler>(
&mut self,
ticket: Ticket<'_>,
handler: &mut H,
) -> Result<()> {
self.claim_with_handler(ticket, handler).await
}
async fn claim_with_handler<H: ExtendedHandler>(
&mut self,
ticket: Ticket<'_>,
handler: &mut H,
) -> Result<()> {
self.check_sequence(ticket.seq)?;
if !self.conn.buffer_set.write_buffer.is_empty() {
self.sync().await?;
}
if self.aborted {
self.claim_seq += 1;
self.expectations.pop_front();
self.consume_pending_syncs().await?;
return Err(Error::LibraryBug(
"pipeline aborted due to earlier error".into(),
));
}
let expectation = self.expectations.pop_front();
let result = match expectation {
Some(Expectation::ParseBindExecute) => self.claim_parse_bind_exec_inner(handler).await,
Some(Expectation::BindExecute) => {
self.claim_bind_exec_inner(handler, ticket.stmt).await
}
Some(Expectation::Sync) => Err(Error::LibraryBug("unexpected Sync expectation".into())),
None => Err(Error::LibraryBug("no expectation in queue".into())),
};
if let Err(e) = &result {
if e.is_connection_broken() {
self.conn.is_broken = true;
}
self.aborted = true;
}
self.claim_seq += 1;
self.consume_pending_syncs().await?;
result
}
pub async fn claim_collect<T: for<'b> FromRow<'b>>(
&mut self,
ticket: Ticket<'_>,
) -> Result<Vec<T>> {
let mut handler = crate::handler::CollectHandler::<T>::new();
self.claim_with_handler(ticket, &mut handler).await?;
Ok(handler.into_rows())
}
pub async fn claim_one<T: for<'b> FromRow<'b>>(
&mut self,
ticket: Ticket<'_>,
) -> Result<Option<T>> {
let mut handler = crate::handler::FirstRowHandler::<T>::new();
self.claim_with_handler(ticket, &mut handler).await?;
Ok(handler.into_row())
}
pub async fn claim_drop(&mut self, ticket: Ticket<'_>) -> Result<()> {
let mut handler = crate::handler::DropHandler::new();
self.claim_with_handler(ticket, &mut handler).await
}
fn check_sequence(&self, seq: usize) -> Result<()> {
if seq != self.claim_seq {
return Err(Error::InvalidUsage(format!(
"claim out of order: expected seq {}, got {}",
self.claim_seq, seq
)));
}
Ok(())
}
async fn claim_parse_bind_exec_inner<H: ExtendedHandler>(
&mut self,
handler: &mut H,
) -> Result<()> {
self.read_next_message().await?;
if self.conn.buffer_set.type_byte != msg_type::PARSE_COMPLETE {
return self.unexpected_message("ParseComplete");
}
ParseComplete::parse(&self.conn.buffer_set.read_buffer)?;
self.read_next_message().await?;
if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
return self.unexpected_message("BindComplete");
}
BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
self.claim_rows_inner(handler).await
}
async fn claim_bind_exec_inner<H: ExtendedHandler>(
&mut self,
handler: &mut H,
stmt: Option<&PreparedStatement>,
) -> Result<()> {
self.read_next_message().await?;
if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
return self.unexpected_message("BindComplete");
}
BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
let row_desc = stmt.and_then(|s| s.row_desc_payload());
self.claim_rows_cached_inner(handler, row_desc).await
}
async fn claim_rows_inner<H: ExtendedHandler>(&mut self, handler: &mut H) -> Result<()> {
self.read_next_message().await?;
let has_rows = match self.conn.buffer_set.type_byte {
msg_type::ROW_DESCRIPTION => {
self.column_buffer.clear();
self.column_buffer
.extend_from_slice(&self.conn.buffer_set.read_buffer);
true
}
msg_type::NO_DATA => {
NoData::parse(&self.conn.buffer_set.read_buffer)?;
false
}
_ => {
return Err(Error::LibraryBug(format!(
"expected RowDescription or NoData, got '{}'",
self.conn.buffer_set.type_byte as char
)));
}
};
loop {
self.read_next_message().await?;
let type_byte = self.conn.buffer_set.type_byte;
match type_byte {
msg_type::DATA_ROW => {
if !has_rows {
return Err(Error::LibraryBug(
"received DataRow but no RowDescription".into(),
));
}
let cols = RowDescription::parse(&self.column_buffer)?;
let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
handler.row(cols, row)?;
}
msg_type::COMMAND_COMPLETE => {
let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
handler.result_end(cmd)?;
return Ok(());
}
msg_type::EMPTY_QUERY_RESPONSE => {
EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
return Ok(());
}
_ => {
return Err(Error::LibraryBug(format!(
"unexpected message type in pipeline claim: '{}'",
type_byte as char
)));
}
}
}
}
async fn claim_rows_cached_inner<H: ExtendedHandler>(
&mut self,
handler: &mut H,
row_desc: Option<&[u8]>,
) -> Result<()> {
loop {
self.read_next_message().await?;
let type_byte = self.conn.buffer_set.type_byte;
match type_byte {
msg_type::DATA_ROW => {
let row_desc = row_desc.ok_or_else(|| {
Error::LibraryBug("received DataRow but no RowDescription cached".into())
})?;
let cols = RowDescription::parse(row_desc)?;
let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
handler.row(cols, row)?;
}
msg_type::COMMAND_COMPLETE => {
let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
handler.result_end(cmd)?;
return Ok(());
}
msg_type::EMPTY_QUERY_RESPONSE => {
EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
return Ok(());
}
_ => {
return Err(Error::LibraryBug(format!(
"unexpected message type in pipeline claim: '{}'",
type_byte as char
)));
}
}
}
}
async fn read_next_message(&mut self) -> Result<()> {
loop {
self.conn
.stream
.read_message(&mut self.conn.buffer_set)
.await?;
let type_byte = self.conn.buffer_set.type_byte;
if RawMessage::is_async_type(type_byte) {
continue;
}
if type_byte == msg_type::ERROR_RESPONSE {
let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
return Err(error.into_error());
}
return Ok(());
}
}
fn unexpected_message<T>(&self, expected: &str) -> Result<T> {
Err(Error::LibraryBug(format!(
"expected {}, got '{}'",
expected, self.conn.buffer_set.type_byte as char
)))
}
pub fn pending_count(&self) -> usize {
self.queue_seq - self.claim_seq
}
pub fn is_aborted(&self) -> bool {
self.aborted
}
}