use compio::buf::BufResult;
use compio::net::TcpStream;
#[cfg(unix)]
use compio::net::UnixStream;
use crate::buffer_pool::PooledBufferSet;
use crate::conversion::ToParams;
use crate::error::{Error, Result};
use crate::handler::{DropHandler, ExtendedHandler, FirstRowHandler, SimpleHandler};
use crate::opts::Opts;
use crate::protocol::backend::BackendKeyData;
use crate::protocol::frontend::write_terminate;
use crate::protocol::types::TransactionStatus;
use crate::state::StateMachine;
use crate::state::action::Action;
use crate::state::connection::ConnectionStateMachine;
use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
use crate::statement::{IntoStatement, StatementRef};
use super::stream::Stream;
/// Optional boxed callback invoked for asynchronous server messages that are
/// encountered while a statement is being driven.
type OptionalAsyncHandler = Option<Box<dyn FnMut(&crate::state::action::AsyncMessage)>>;

/// A single PostgreSQL connection running on the compio runtime.
///
/// Owns the transport, the pooled read/write buffers, and the session state
/// the server reported during startup and after each command.
pub struct Conn {
    /// Transport: TCP or (on Unix) a Unix-domain socket, possibly upgraded to TLS.
    pub(crate) stream: Stream,
    /// Read/write buffers obtained from the configured buffer pool.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data received during startup, if the server sent any.
    backend_key: Option<BackendKeyData>,
    /// `(name, value)` parameter pairs collected during startup.
    server_params: Vec<(String, String)>,
    /// Most recent transaction status reported by the server.
    pub(crate) transaction_status: TransactionStatus,
    /// Set once an error classified as connection-breaking has been seen.
    pub(crate) is_broken: bool,
    /// Counter used to derive unique statement and portal names.
    name_counter: u64,
    /// User callback for async messages; `None` means they are ignored.
    async_message_handler: OptionalAsyncHandler,
}
impl Conn {
/// Opens a new connection described by `opts`.
///
/// If `opts.socket` is set, connects over that Unix-domain socket. On
/// non-Unix platforms this returns `Error::Unsupported`. Otherwise connects
/// over TCP to `opts.host:opts.port` with `TCP_NODELAY` enabled. A bare IPv6
/// literal host such as `::1` is wrapped in brackets (`[::1]:5432`) so the
/// address parses. The startup handshake is then performed by
/// `new_with_stream`.
///
/// # Errors
/// Fails if `opts` cannot be converted, the host is empty, the socket cannot
/// be opened, or the startup handshake fails.
pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
where
    Error: From<O::Error>,
{
    let opts = opts.try_into()?;
    let stream = if let Some(socket_path) = &opts.socket {
        #[cfg(unix)]
        {
            Stream::unix(UnixStream::connect(socket_path).await?)
        }
        #[cfg(not(unix))]
        {
            let _ = socket_path;
            return Err(Error::Unsupported(
                "Unix sockets are not supported on this platform".into(),
            ));
        }
    } else {
        if opts.host.is_empty() {
            return Err(Error::InvalidUsage("host is empty".into()));
        }
        // Hostnames never contain ':', so a colon means an IPv6 literal. It
        // must be bracketed, otherwise "host:port" is ambiguous and fails to
        // parse as a socket address.
        let addr = if opts.host.contains(':') && !opts.host.starts_with('[') {
            format!("[{}]:{}", opts.host, opts.port)
        } else {
            format!("{}:{}", opts.host, opts.port)
        };
        let tcp = TcpStream::connect(&addr).await?;
        tcp.set_nodelay(true)?;
        Stream::tcp(tcp)
    };
    Self::new_with_stream(stream, opts).await
}
pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
let mut buffer_set = options.buffer_pool.get_buffer_set();
let mut state_machine = ConnectionStateMachine::new(options.clone());
loop {
match state_machine.step(&mut buffer_set)? {
Action::WriteAndReadByte => {
let buf = std::mem::take(&mut buffer_set.write_buffer);
let BufResult(result, buf) = stream.write_all_owned(buf).await;
buffer_set.write_buffer = buf;
result?;
stream.flush().await?;
let byte = stream.read_u8().await?;
state_machine.set_ssl_response(byte);
}
Action::ReadMessage => {
stream.read_message(&mut buffer_set).await?;
}
Action::Write => {
let buf = std::mem::take(&mut buffer_set.write_buffer);
let BufResult(result, buf) = stream.write_all_owned(buf).await;
buffer_set.write_buffer = buf;
result?;
stream.flush().await?;
}
Action::WriteAndReadMessage => {
let buf = std::mem::take(&mut buffer_set.write_buffer);
let BufResult(result, buf) = stream.write_all_owned(buf).await;
buffer_set.write_buffer = buf;
result?;
stream.flush().await?;
stream.read_message(&mut buffer_set).await?;
}
Action::TlsHandshake => {
#[cfg(feature = "compio-tls")]
{
stream = stream.upgrade_to_tls(&options.host).await?;
}
#[cfg(not(feature = "compio-tls"))]
{
return Err(Error::Unsupported(
"TLS requested but compio-tls feature not enabled".into(),
));
}
}
Action::HandleAsyncMessageAndReadMessage(_) => {
stream.read_message(&mut buffer_set).await?;
}
Action::Error(_) => {
return Err(Error::LibraryBug(
"unexpected server error during connection startup".into(),
));
}
Action::Finished => break,
}
}
let conn = Self {
stream,
buffer_set,
backend_key: state_machine.backend_key().cloned(),
server_params: state_machine.take_server_params(),
transaction_status: state_machine.transaction_status(),
is_broken: false,
name_counter: 0,
async_message_handler: None,
};
#[cfg(unix)]
let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
conn.try_upgrade_to_unix_socket(&options).await
} else {
conn
};
Ok(conn)
}
/// Best-effort replacement of a loopback TCP connection with a Unix-domain
/// socket connection to the same server.
///
/// Steps:
/// 1. Runs `SHOW unix_socket_directories` and takes the first listed directory.
/// 2. Connects to `<dir>/.s.PGSQL.<port>`.
/// 3. Re-runs the startup handshake there, with `upgrade_to_unix_socket`
///    disabled so it does not recurse.
///
/// Any failure returns the original connection unchanged. On success, the
/// superseded TCP connection is closed with a Terminate message instead of
/// being dropped abruptly.
///
/// The future is boxed because this function and `new_with_stream` call each
/// other.
#[cfg(unix)]
fn try_upgrade_to_unix_socket(
    mut self,
    opts: &Opts,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + '_>> {
    let mut opts = opts.clone();
    Box::pin(async move {
        let mut handler = FirstRowHandler::<(String,)>::new();
        if self
            .query("SHOW unix_socket_directories", &mut handler)
            .await
            .is_err()
        {
            return self;
        }
        let Some((dirs,)) = handler.into_row() else {
            return self;
        };
        // Only the first comma-separated directory is tried.
        let socket_dir = match dirs.split(',').next().map(str::trim) {
            Some(d) if !d.is_empty() => d.to_string(),
            _ => return self,
        };
        let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
        let Ok(unix_stream) = UnixStream::connect(&socket_path).await else {
            return self;
        };
        opts.upgrade_to_unix_socket = false;
        match Self::new_with_stream(Stream::unix(unix_stream), opts).await {
            Ok(new_conn) => {
                // End the old TCP session cleanly. A failure here does not
                // matter because the new connection is already usable.
                let _ = self.close().await;
                new_conn
            }
            Err(_) => self,
        }
    })
}
/// Backend key data received during startup, if the server sent any.
pub fn backend_key(&self) -> Option<&BackendKeyData> {
    self.backend_key.as_ref()
}

/// Process id from the backend key data, or `0` when there is none.
pub fn connection_id(&self) -> u32 {
    match &self.backend_key {
        Some(key) => key.process_id(),
        None => 0,
    }
}

/// Server parameters collected during startup, as `(name, value)` pairs.
pub fn server_params(&self) -> &[(String, String)] {
    self.server_params.as_slice()
}

/// Most recently reported transaction status.
pub fn transaction_status(&self) -> TransactionStatus {
    self.transaction_status
}

/// Whether the last reported status counts as "in a transaction".
pub fn in_transaction(&self) -> bool {
    self.transaction_status().in_transaction()
}

/// Whether a connection-breaking error has been observed on this connection.
pub fn is_broken(&self) -> bool {
    self.is_broken
}
/// Returns a portal name of the form `_zero_p_<n>` that is unique within
/// this connection.
///
/// Each call consumes one value of the counter shared with statement naming.
pub(crate) fn next_portal_name(&mut self) -> String {
    let id = self.name_counter + 1;
    self.name_counter = id;
    format!("_zero_p_{id}")
}
/// Binds `statement` with `params` into the portal named `portal_name`.
///
/// SQL text goes through `BindStateMachine::bind_sql`. Prepared statements
/// are bound by wire name with their declared parameter OIDs. The state
/// machine decides what is sent and read; this loop only performs that I/O.
///
/// # Errors
/// Encoding, I/O and state-machine errors are propagated. Any action other
/// than read / write / finish yields `Error::LibraryBug`.
pub(crate) async fn create_named_portal<S: IntoStatement, P: ToParams>(
    &mut self,
    portal_name: &str,
    statement: &S,
    params: &P,
) -> Result<()> {
    let mut machine = match statement.statement_ref() {
        StatementRef::Prepared(stmt) => BindStateMachine::bind_prepared(
            &mut self.buffer_set,
            portal_name,
            &stmt.wire_name(),
            &stmt.param_oids,
            params,
        )?,
        StatementRef::Sql(sql) => {
            BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
        }
    };
    loop {
        // Decide whether a reply must be read once the pending output is sent.
        let expect_reply = match machine.step(&mut self.buffer_set)? {
            Action::Finished => return Ok(()),
            Action::ReadMessage => {
                self.stream.read_message(&mut self.buffer_set).await?;
                continue;
            }
            Action::Write => false,
            Action::WriteAndReadMessage => true,
            _ => return Err(Error::LibraryBug("Unexpected action in bind".into())),
        };
        let outgoing = std::mem::take(&mut self.buffer_set.write_buffer);
        let BufResult(written, outgoing) = self.stream.write_all_owned(outgoing).await;
        self.buffer_set.write_buffer = outgoing;
        written?;
        self.stream.flush().await?;
        if expect_reply {
            self.stream.read_message(&mut self.buffer_set).await?;
        }
    }
}
/// Installs `handler`, replacing any existing one.
///
/// The handler is called for each asynchronous message the state machine
/// surfaces while a statement is driven on this connection.
pub fn set_async_message_handler(
    &mut self,
    handler: impl FnMut(&crate::state::action::AsyncMessage) + 'static,
) {
    let boxed: Box<dyn FnMut(&crate::state::action::AsyncMessage)> = Box::new(handler);
    self.async_message_handler = Some(boxed);
}

/// Removes the installed async message handler, if any.
pub fn clear_async_message_handler(&mut self) {
    self.async_message_handler.take();
}
/// Round-trips an empty simple query to check that the connection responds.
pub async fn ping(&mut self) -> Result<()> {
    self.query_drop("").await.map(|_| ())
}
/// Steps `state_machine` to completion, performing the socket I/O each step
/// asks for.
///
/// Async messages are passed to the installed handler (if any) before the
/// next message is read. When the machine finishes or reports a server
/// error, the cached transaction status is refreshed from it.
///
/// # Errors
/// - `Error::Server` when the machine yields `Action::Error`.
/// - I/O and protocol errors from the stream or the state machine.
/// - `Error::LibraryBug` for `WriteAndReadByte` / `TlsHandshake`, which only
///   belong to connection startup.
async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
    loop {
        // (send pending output first?, then read one message?)
        let (send, receive) = match state_machine.step(&mut self.buffer_set)? {
            Action::ReadMessage => (false, true),
            Action::Write => (true, false),
            Action::WriteAndReadMessage => (true, true),
            Action::HandleAsyncMessageAndReadMessage(async_msg) => {
                if let Some(callback) = self.async_message_handler.as_mut() {
                    callback(&async_msg);
                }
                (false, true)
            }
            Action::Finished => {
                self.transaction_status = state_machine.transaction_status();
                return Ok(());
            }
            Action::Error(server_error) => {
                self.transaction_status = state_machine.transaction_status();
                return Err(Error::Server(server_error));
            }
            Action::WriteAndReadByte => {
                return Err(Error::LibraryBug(
                    "Unexpected WriteAndReadByte in query state machine".into(),
                ));
            }
            Action::TlsHandshake => {
                return Err(Error::LibraryBug(
                    "Unexpected TlsHandshake in query state machine".into(),
                ));
            }
        };
        if send {
            let outgoing = std::mem::take(&mut self.buffer_set.write_buffer);
            let BufResult(written, outgoing) = self.stream.write_all_owned(outgoing).await;
            self.buffer_set.write_buffer = outgoing;
            written?;
            self.stream.flush().await?;
        }
        if receive {
            self.stream.read_message(&mut self.buffer_set).await?;
        }
    }
}
/// Runs `sql` with the simple query protocol, streaming results to `handler`.
///
/// A connection-breaking error marks the connection as broken.
pub async fn query<H: SimpleHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
    let outcome = self.query_inner(sql, handler).await;
    if matches!(&outcome, Err(err) if err.is_connection_broken()) {
        self.is_broken = true;
    }
    outcome
}

/// Drives a `SimpleQueryStateMachine` for `sql` to completion.
async fn query_inner<H: SimpleHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
    self.drive(&mut SimpleQueryStateMachine::new(handler, sql))
        .await
}
/// Runs `sql`, discarding rows, and returns the affected-row count the
/// `DropHandler` recorded, if any.
pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
    let mut sink = DropHandler::new();
    self.query(sql, &mut sink)
        .await
        .map(|()| sink.rows_affected())
}

/// Runs `sql` and decodes every returned row as `T`.
pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
    &mut self,
    sql: &str,
) -> Result<Vec<T>> {
    let mut collector = crate::handler::CollectHandler::<T>::new();
    self.query(sql, &mut collector)
        .await
        .map(move |()| collector.into_rows())
}

/// Runs `sql` and returns the first row decoded as `T`, if there is one.
pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
    &mut self,
    sql: &str,
) -> Result<Option<T>> {
    let mut first = crate::handler::FirstRowHandler::<T>::new();
    self.query(sql, &mut first)
        .await
        .map(move |()| first.into_row())
}

/// Runs `sql` and passes each row, decoded as `T`, to `f`.
pub async fn query_foreach<
    T: for<'a> crate::conversion::FromRow<'a>,
    F: FnMut(T) -> Result<()>,
>(
    &mut self,
    sql: &str,
    f: F,
) -> Result<()> {
    let mut visitor = crate::handler::ForEachHandler::<T, F>::new(f);
    self.query(sql, &mut visitor).await?;
    Ok(())
}
/// Sends a Terminate message, flushes it, and consumes the connection.
///
/// Any unsent bytes left in the write buffer are discarded first.
pub async fn close(mut self) -> Result<()> {
    let mut goodbye = std::mem::take(&mut self.buffer_set.write_buffer);
    goodbye.clear();
    write_terminate(&mut goodbye);
    let BufResult(sent, goodbye) = self.stream.write_all_owned(goodbye).await;
    self.buffer_set.write_buffer = goodbye;
    sent?;
    self.stream.flush().await?;
    Ok(())
}
/// Prepares `query` without specifying any parameter types.
pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
    let no_types: &[u32] = &[];
    self.prepare_typed(query, no_types).await
}

/// Prepares `query` with the given parameter type OIDs.
///
/// Uses the next value of the connection's name counter for the statement
/// name. A connection-breaking error marks the connection as broken.
pub async fn prepare_typed(
    &mut self,
    query: &str,
    param_oids: &[u32],
) -> Result<PreparedStatement> {
    self.name_counter += 1;
    let outcome = self.prepare_inner(self.name_counter, query, param_oids).await;
    match outcome {
        Err(err) if err.is_connection_broken() => {
            self.is_broken = true;
            Err(err)
        }
        other => other,
    }
}
/// Prepares all `queries` in one batch and returns them in input order.
///
/// Reserves a contiguous range of name-counter values up front (even if the
/// batch later fails). An empty slice returns immediately without I/O.
/// A connection-breaking error marks the connection as broken.
pub async fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
    if queries.is_empty() {
        return Ok(Vec::new());
    }
    let first_idx = self.name_counter + 1;
    self.name_counter += queries.len() as u64;
    match self.prepare_batch_inner(queries, first_idx).await {
        Err(err) if err.is_connection_broken() => {
            self.is_broken = true;
            Err(err)
        }
        other => other,
    }
}
async fn prepare_batch_inner(
&mut self,
queries: &[&str],
start_idx: u64,
) -> Result<Vec<PreparedStatement>> {
use crate::state::batch_prepare::BatchPrepareStateMachine;
let mut state_machine =
BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
loop {
match state_machine.step(&mut self.buffer_set)? {
Action::ReadMessage => {
self.stream.read_message(&mut self.buffer_set).await?;
}
Action::WriteAndReadMessage => {
let buf = std::mem::take(&mut self.buffer_set.write_buffer);
let BufResult(result, buf) = self.stream.write_all_owned(buf).await;
self.buffer_set.write_buffer = buf;
result?;
self.stream.flush().await?;
self.stream.read_message(&mut self.buffer_set).await?;
}
Action::Finished => {
self.transaction_status = state_machine.transaction_status();
break;
}
_ => {
return Err(Error::LibraryBug(
"Unexpected action in batch prepare".into(),
));
}
}
}
Ok(state_machine.take_statements())
}
async fn prepare_inner(
&mut self,
idx: u64,
query: &str,
param_oids: &[u32],
) -> Result<PreparedStatement> {
let mut handler = DropHandler::new();
let mut state_machine = ExtendedQueryStateMachine::prepare(
&mut handler,
&mut self.buffer_set,
idx,
query,
param_oids,
);
self.drive(&mut state_machine).await?;
state_machine
.take_prepared_statement()
.ok_or_else(|| Error::LibraryBug("No prepared statement".into()))
}
/// Executes `statement` with `params` using the extended query protocol and
/// streams the results to `handler`.
///
/// `statement` may be SQL text or a prepared statement. A connection-breaking
/// error marks the connection as broken.
pub async fn exec<S: IntoStatement, P: ToParams, H: ExtendedHandler>(
    &mut self,
    statement: S,
    params: P,
    handler: &mut H,
) -> Result<()> {
    let result = self.exec_inner(&statement, &params, handler).await;
    if let Err(e) = &result
        && e.is_connection_broken()
    {
        self.is_broken = true;
    }
    result
}
/// Builds the appropriate `ExtendedQueryStateMachine` for `statement`
/// (SQL text or prepared) and drives it to completion.
async fn exec_inner<S: IntoStatement, P: ToParams, H: ExtendedHandler>(
    &mut self,
    statement: &S,
    params: &P,
    handler: &mut H,
) -> Result<()> {
    let buffers = &mut self.buffer_set;
    let mut machine = match statement.statement_ref() {
        StatementRef::Prepared(stmt) => {
            let wire_name = stmt.wire_name();
            ExtendedQueryStateMachine::execute(
                handler,
                buffers,
                &wire_name,
                &stmt.param_oids,
                params,
            )?
        }
        StatementRef::Sql(sql) => {
            ExtendedQueryStateMachine::execute_sql(handler, buffers, sql, params)?
        }
    };
    self.drive(&mut machine).await
}
/// Executes `statement` with `params`, discarding rows, and returns the
/// affected-row count the `DropHandler` recorded, if any.
pub async fn exec_drop<S: IntoStatement, P: ToParams>(
    &mut self,
    statement: S,
    params: P,
) -> Result<Option<u64>> {
    let mut sink = DropHandler::new();
    self.exec(statement, params, &mut sink)
        .await
        .map(|()| sink.rows_affected())
}

/// Executes `statement` with `params` and decodes every row as `T`.
pub async fn exec_collect<
    T: for<'a> crate::conversion::FromRow<'a>,
    S: IntoStatement,
    P: ToParams,
>(
    &mut self,
    statement: S,
    params: P,
) -> Result<Vec<T>> {
    let mut collector = crate::handler::CollectHandler::<T>::new();
    self.exec(statement, params, &mut collector)
        .await
        .map(move |()| collector.into_rows())
}

/// Executes `statement` with `params` and returns the first row as `T`, if any.
pub async fn exec_first<
    T: for<'a> crate::conversion::FromRow<'a>,
    S: IntoStatement,
    P: ToParams,
>(
    &mut self,
    statement: S,
    params: P,
) -> Result<Option<T>> {
    let mut first = crate::handler::FirstRowHandler::<T>::new();
    self.exec(statement, params, &mut first)
        .await
        .map(move |()| first.into_row())
}

/// Executes `statement` with `params` and passes each row, decoded as `T`, to `f`.
pub async fn exec_foreach<
    T: for<'a> crate::conversion::FromRow<'a>,
    S: IntoStatement,
    P: ToParams,
    F: FnMut(T) -> Result<()>,
>(
    &mut self,
    statement: S,
    params: P,
    f: F,
) -> Result<()> {
    let mut visitor = crate::handler::ForEachHandler::<T, F>::new(f);
    self.exec(statement, params, &mut visitor).await?;
    Ok(())
}
pub async fn exec_batch<S: IntoStatement, P: ToParams>(
&mut self,
statement: S,
params_list: &[P],
) -> Result<()> {
self.exec_batch_chunked(statement, params_list, 1000).await
}
pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
&mut self,
statement: S,
params_list: &[P],
chunk_size: usize,
) -> Result<()> {
let result = self
.exec_batch_inner(&statement, params_list, chunk_size)
.await;
if let Err(e) = &result
&& e.is_connection_broken()
{
self.is_broken = true;
}
result
}
/// Sends `statement` once per parameter set, `chunk_size` sets per round trip.
///
/// Each chunk is a Parse of the unnamed statement (SQL text only), then one
/// Bind + Execute per parameter set, then one Sync. Each chunk is driven by
/// its own `BatchStateMachine`.
/// - For SQL text, parameter types come from the first parameter set's
///   `natural_oids()`.
/// - Prepared statements use their declared OIDs and wire name.
/// - A `chunk_size` of 0 is treated as 1.
///
/// Each chunk ends with its own Sync. Per the PostgreSQL protocol, outside
/// an explicit transaction block, chunks that completed before a failing one
/// have already been committed.
///
/// # Errors
/// Bind-encoding, I/O and server errors. Processing stops at the first
/// failing chunk.
async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
    &mut self,
    statement: &S,
    params_list: &[P],
    chunk_size: usize,
) -> Result<()> {
    use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
    use crate::state::extended::BatchStateMachine;
    if params_list.is_empty() {
        return Ok(());
    }
    // `chunks(0)` would panic.
    let chunk_size = chunk_size.max(1);
    let stmt_ref = statement.statement_ref();
    let is_sql = matches!(stmt_ref, StatementRef::Sql(_));
    // For SQL text the name is empty, i.e. the unnamed statement parsed per
    // chunk, so a single `stmt_name` serves both cases.
    let (param_oids, stmt_name) = match stmt_ref {
        StatementRef::Sql(_) => (params_list[0].natural_oids(), String::new()),
        StatementRef::Prepared(stmt) => (stmt.param_oids.clone(), stmt.wire_name()),
    };
    for chunk in params_list.chunks(chunk_size) {
        self.buffer_set.write_buffer.clear();
        if let StatementRef::Sql(sql) = stmt_ref {
            write_parse(&mut self.buffer_set.write_buffer, "", sql, &param_oids);
        }
        for params in chunk {
            write_bind(
                &mut self.buffer_set.write_buffer,
                "",
                &stmt_name,
                params,
                &param_oids,
            )?;
            write_execute(&mut self.buffer_set.write_buffer, "", 0);
        }
        write_sync(&mut self.buffer_set.write_buffer);
        let mut state_machine = BatchStateMachine::new(is_sql);
        self.drive_batch(&mut state_machine).await?;
        self.transaction_status = state_machine.transaction_status();
    }
    Ok(())
}
async fn drive_batch(
&mut self,
state_machine: &mut crate::state::extended::BatchStateMachine,
) -> Result<()> {
use crate::state::action::Action;
loop {
let step_result = state_machine.step(&mut self.buffer_set);
match step_result {
Ok(Action::ReadMessage) => {
self.stream.read_message(&mut self.buffer_set).await?;
}
Ok(Action::WriteAndReadMessage) => {
let buf = std::mem::take(&mut self.buffer_set.write_buffer);
let BufResult(result, buf) = self.stream.write_all_owned(buf).await;
self.buffer_set.write_buffer = buf;
result?;
self.stream.flush().await?;
self.stream.read_message(&mut self.buffer_set).await?;
}
Ok(Action::Finished) => {
break;
}
Ok(Action::Error(server_error)) => {
self.transaction_status = state_machine.transaction_status();
return Err(Error::Server(server_error));
}
Ok(_) => return Err(Error::LibraryBug("Unexpected action in batch".into())),
Err(e) => return Err(e),
}
}
Ok(())
}
/// Closes the server-side prepared statement `stmt`.
///
/// A connection-breaking error marks the connection as broken.
pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
    let wire_name = stmt.wire_name();
    match self.close_statement_inner(&wire_name).await {
        Err(err) if err.is_connection_broken() => {
            self.is_broken = true;
            Err(err)
        }
        other => other,
    }
}

/// Drives a close-statement state machine for the statement named `name`.
async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
    let mut sink = DropHandler::new();
    let mut machine =
        ExtendedQueryStateMachine::close_statement(&mut sink, &mut self.buffer_set, name);
    self.drive(&mut machine).await
}
pub async fn lowlevel_flush(&mut self) -> Result<()> {
use crate::protocol::frontend::write_flush;
self.buffer_set.write_buffer.clear();
write_flush(&mut self.buffer_set.write_buffer);
let buf = std::mem::take(&mut self.buffer_set.write_buffer);
let BufResult(result, buf) = self.stream.write_all_owned(buf).await;
self.buffer_set.write_buffer = buf;
result?;
self.stream.flush().await?;
Ok(())
}
pub async fn lowlevel_sync(&mut self) -> Result<()> {
let result = self.sync_inner().await;
if let Err(e) = &result
&& e.is_connection_broken()
{
self.is_broken = true;
}
result
}
pub(crate) async fn sync_inner(&mut self) -> Result<()> {
use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
use crate::protocol::frontend::write_sync;
self.buffer_set.write_buffer.clear();
write_sync(&mut self.buffer_set.write_buffer);
let buf = std::mem::take(&mut self.buffer_set.write_buffer);
let BufResult(result, buf) = self.stream.write_all_owned(buf).await;
self.buffer_set.write_buffer = buf;
result?;
self.stream.flush().await?;
let mut pending_error: Option<Error> = None;
loop {
self.stream.read_message(&mut self.buffer_set).await?;
let type_byte = self.buffer_set.type_byte;
if RawMessage::is_async_type(type_byte) {
continue;
}
match type_byte {
msg_type::READY_FOR_QUERY => {
let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
self.transaction_status = ready.transaction_status().unwrap_or_default();
if let Some(e) = pending_error {
return Err(e);
}
return Ok(());
}
msg_type::ERROR_RESPONSE => {
let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
pending_error = Some(error.into_error());
}
_ => {
}
}
}
}
/// Binds `params` to `statement_name` into `portal` (see `bind_inner`).
///
/// A connection-breaking error marks the connection as broken.
pub async fn lowlevel_bind<P: ToParams>(
    &mut self,
    portal: &str,
    statement_name: &str,
    params: P,
) -> Result<()> {
    let result = self.bind_inner(portal, statement_name, &params).await;
    if let Err(e) = &result
        && e.is_connection_broken()
    {
        self.is_broken = true;
    }
    result
}
/// Sends one Bind (`portal` ← `statement_name` + `params`) followed by Flush,
/// then waits for BindComplete.
///
/// Parameter type OIDs come from `params.natural_oids()`, not from any
/// prepared statement's declared types. No Sync is sent. Per the PostgreSQL
/// protocol, after an ErrorResponse the server ignores further extended-query
/// messages until the client sends Sync (e.g. via `lowlevel_sync`). Async
/// message types are skipped without invoking the handler.
///
/// # Errors
/// - The server's error for an ErrorResponse.
/// - `Error::LibraryBug` for any other unexpected message.
/// - Encoding and I/O errors.
pub(crate) async fn bind_inner<P: ToParams>(
    &mut self,
    portal: &str,
    statement_name: &str,
    params: &P,
) -> Result<()> {
    use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
    use crate::protocol::frontend::{write_bind, write_flush};
    let param_oids = params.natural_oids();
    self.buffer_set.write_buffer.clear();
    write_bind(
        &mut self.buffer_set.write_buffer,
        portal,
        statement_name,
        params,
        &param_oids,
    )?;
    write_flush(&mut self.buffer_set.write_buffer);
    let buf = std::mem::take(&mut self.buffer_set.write_buffer);
    let BufResult(result, buf) = self.stream.write_all_owned(buf).await;
    self.buffer_set.write_buffer = buf;
    result?;
    self.stream.flush().await?;
    loop {
        self.stream.read_message(&mut self.buffer_set).await?;
        let type_byte = self.buffer_set.type_byte;
        if RawMessage::is_async_type(type_byte) {
            continue;
        }
        match type_byte {
            msg_type::BIND_COMPLETE => {
                BindComplete::parse(&self.buffer_set.read_buffer)?;
                return Ok(());
            }
            msg_type::ERROR_RESPONSE => {
                let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
                return Err(error.into_error());
            }
            _ => {
                return Err(Error::LibraryBug(format!(
                    "Expected BindComplete or ErrorResponse, got '{}'",
                    type_byte as char
                )));
            }
        }
    }
}
pub async fn lowlevel_execute<H: ExtendedHandler>(
&mut self,
portal: &str,
max_rows: u32,
handler: &mut H,
) -> Result<bool> {
let result = self.execute_portal_inner(portal, max_rows, handler).await;
if let Err(e) = &result
&& e.is_connection_broken()
{
self.is_broken = true;
}
result
}
/// Sends Describe(`portal`) + Execute(`portal`, `max_rows`) + Flush and
/// streams the results into `handler`.
///
/// A `max_rows` of 0 means "no limit" in the PostgreSQL protocol.
///
/// Returns:
/// - `Ok(true)` when the server sent PortalSuspended, meaning more rows may
///   remain.
/// - `Ok(false)` when the portal ran to completion, either with
///   CommandComplete (reported via `result_end`) or with EmptyQueryResponse
///   for an empty query string. No result set is reported in the latter case.
///
/// Async message types are skipped without invoking the handler.
///
/// # Errors
/// - The server's error for an ErrorResponse.
/// - Handler and parse errors.
/// - `Error::LibraryBug` for any other message.
pub(crate) async fn execute_portal_inner<H: ExtendedHandler>(
    &mut self,
    portal: &str,
    max_rows: u32,
    handler: &mut H,
) -> Result<bool> {
    use crate::protocol::backend::{
        CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
        RowDescription, msg_type,
    };
    use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
    // EmptyQueryResponse: sent instead of CommandComplete for an empty query.
    const EMPTY_QUERY_RESPONSE: u8 = b'I';
    self.buffer_set.write_buffer.clear();
    write_describe_portal(&mut self.buffer_set.write_buffer, portal);
    write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
    write_flush(&mut self.buffer_set.write_buffer);
    let buf = std::mem::take(&mut self.buffer_set.write_buffer);
    let BufResult(result, buf) = self.stream.write_all_owned(buf).await;
    self.buffer_set.write_buffer = buf;
    result?;
    self.stream.flush().await?;
    // Owned copy of the RowDescription payload. The read buffer is
    // overwritten by each DataRow, but the columns are needed for every row.
    let mut column_buffer: Vec<u8> = Vec::new();
    loop {
        self.stream.read_message(&mut self.buffer_set).await?;
        let type_byte = self.buffer_set.type_byte;
        if RawMessage::is_async_type(type_byte) {
            continue;
        }
        match type_byte {
            msg_type::ROW_DESCRIPTION => {
                column_buffer.clear();
                column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
                let cols = RowDescription::parse(&column_buffer)?;
                handler.result_start(cols)?;
            }
            msg_type::NO_DATA => {
                NoData::parse(&self.buffer_set.read_buffer)?;
            }
            msg_type::DATA_ROW => {
                let cols = RowDescription::parse(&column_buffer)?;
                let row = DataRow::parse(&self.buffer_set.read_buffer)?;
                handler.row(cols, row)?;
            }
            msg_type::COMMAND_COMPLETE => {
                let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
                handler.result_end(complete)?;
                // Portal ran to completion.
                return Ok(false);
            }
            EMPTY_QUERY_RESPONSE => {
                // Empty query string: nothing ran, so the portal is complete.
                return Ok(false);
            }
            msg_type::PORTAL_SUSPENDED => {
                PortalSuspended::parse(&self.buffer_set.read_buffer)?;
                // Row limit reached; more rows may be fetched later.
                return Ok(true);
            }
            msg_type::ERROR_RESPONSE => {
                let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
                return Err(error.into_error());
            }
            _ => {
                return Err(Error::LibraryBug(format!(
                    "Unexpected message in execute: '{}'",
                    type_byte as char
                )));
            }
        }
    }
}
/// Binds `statement` with `params` to the unnamed portal and runs `f` with an
/// `UnnamedPortal` handle, then sends Sync.
///
/// If both `f` and the Sync fail, `f`'s error is returned. A
/// connection-breaking error marks the connection as broken.
pub async fn exec_portal<S: IntoStatement, P, F, T>(
    &mut self,
    statement: S,
    params: P,
    f: F,
) -> Result<T>
where
    P: ToParams,
    F: AsyncFnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Result<T>,
{
    let result = self.exec_portal_inner(&statement, &params, f).await;
    if let Err(e) = &result
        && e.is_connection_broken()
    {
        self.is_broken = true;
    }
    result
}
async fn exec_portal_inner<S: IntoStatement, P, F, T>(
&mut self,
statement: &S,
params: &P,
f: F,
) -> Result<T>
where
P: ToParams,
F: AsyncFnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Result<T>,
{
let mut state_machine = match statement.statement_ref() {
StatementRef::Sql(sql) => {
BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
}
StatementRef::Prepared(stmt) => BindStateMachine::bind_prepared(
&mut self.buffer_set,
"",
&stmt.wire_name(),
&stmt.param_oids,
params,
)?,
};
loop {
match state_machine.step(&mut self.buffer_set)? {
Action::ReadMessage => {
self.stream.read_message(&mut self.buffer_set).await?;
}
Action::Write => {
let buf = std::mem::take(&mut self.buffer_set.write_buffer);
let BufResult(result, buf) = self.stream.write_all_owned(buf).await;
self.buffer_set.write_buffer = buf;
result?;
self.stream.flush().await?;
}
Action::WriteAndReadMessage => {
let buf = std::mem::take(&mut self.buffer_set.write_buffer);
let BufResult(result, buf) = self.stream.write_all_owned(buf).await;
self.buffer_set.write_buffer = buf;
result?;
self.stream.flush().await?;
self.stream.read_message(&mut self.buffer_set).await?;
}
Action::Finished => break,
_ => return Err(Error::LibraryBug("Unexpected action in bind".into())),
}
}
let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
let result = f(&mut portal).await;
let sync_result = portal.conn.sync_inner().await;
match (result, sync_result) {
(Ok(v), Ok(())) => Ok(v),
(Err(e), _) => Err(e),
(Ok(_), Err(e)) => Err(e),
}
}
pub async fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
let result = self.close_portal_inner(portal).await;
if let Err(e) = &result
&& e.is_connection_broken()
{
self.is_broken = true;
}
result
}
/// Sends Close(`portal`) + Flush and waits for CloseComplete.
///
/// Async message types are skipped without invoking the handler.
///
/// # Errors
/// - The server's error for an ErrorResponse.
/// - `Error::LibraryBug` for any other message.
/// - I/O errors.
pub(crate) async fn close_portal_inner(&mut self, portal: &str) -> Result<()> {
    use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
    use crate::protocol::frontend::{write_close_portal, write_flush};
    let mut outgoing = std::mem::take(&mut self.buffer_set.write_buffer);
    outgoing.clear();
    write_close_portal(&mut outgoing, portal);
    write_flush(&mut outgoing);
    let BufResult(sent, outgoing) = self.stream.write_all_owned(outgoing).await;
    self.buffer_set.write_buffer = outgoing;
    sent?;
    self.stream.flush().await?;
    loop {
        self.stream.read_message(&mut self.buffer_set).await?;
        let kind = self.buffer_set.type_byte;
        if RawMessage::is_async_type(kind) {
            continue;
        }
        if kind == msg_type::CLOSE_COMPLETE {
            CloseComplete::parse(&self.buffer_set.read_buffer)?;
            return Ok(());
        }
        if kind == msg_type::ERROR_RESPONSE {
            return Err(ErrorResponse::parse(&self.buffer_set.read_buffer)?.into_error());
        }
        return Err(Error::LibraryBug(format!(
            "Expected CloseComplete or ErrorResponse, got '{}'",
            kind as char
        )));
    }
}
/// Runs `f` with a `Pipeline` over this connection and returns `f`'s result.
///
/// The pipeline's `cleanup` is always awaited afterwards, whether `f`
/// succeeded or not.
pub async fn pipeline<T, F>(&mut self, f: F) -> Result<T>
where
    F: AsyncFnOnce(&mut super::pipeline::Pipeline<'_>) -> Result<T>,
{
    let mut pipe = super::pipeline::Pipeline::new_inner(self);
    let outcome = f(&mut pipe).await;
    pipe.cleanup().await;
    outcome
}
/// Runs `f` inside a `BEGIN` … `COMMIT`/`ROLLBACK` block.
///
/// Behaviour:
/// - If a transaction is still open after `f` returns, it is committed when
///   `f` succeeded and rolled back (best effort) when `f` failed.
/// - If `f` already ended the transaction itself, nothing more is sent.
///
/// # Errors
/// - `Error::InvalidUsage` when called inside an existing transaction.
/// - Errors from `BEGIN` or `COMMIT`.
/// - Otherwise, `f`'s own error.
pub async fn transaction<F, R>(&mut self, f: F) -> Result<R>
where
    F: AsyncFnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
{
    if self.in_transaction() {
        return Err(Error::InvalidUsage(
            "nested transactions are not supported".into(),
        ));
    }
    self.query_drop("BEGIN").await?;
    let token = super::transaction::Transaction::new(self.connection_id());
    let outcome = f(self, token).await;
    if !self.in_transaction() {
        return outcome;
    }
    if outcome.is_ok() {
        self.query_drop("COMMIT").await?;
    } else {
        // The caller's error matters more than a failed rollback.
        let _ = self.query_drop("ROLLBACK").await;
    }
    outcome
}
}
use crate::state::simple_query::SimpleQueryStateMachine;