use crate::conversion::ToParams;
use crate::error::{Error, Result};
use crate::handler::ExtendedHandler;
use crate::protocol::backend::{
BindComplete, CloseComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse,
NoData, ParameterDescription, ParseComplete, PortalSuspended, RawMessage, ReadyForQuery,
RowDescription, msg_type,
};
use crate::protocol::frontend::{
write_bind, write_close_statement, write_describe_portal, write_describe_statement,
write_execute, write_parse, write_sync,
};
use crate::protocol::types::{Oid, TransactionStatus};
use super::StateMachine;
use super::action::{Action, AsyncMessage};
use crate::buffer_set::BufferSet;
/// Position of an [`ExtendedQueryStateMachine`] within one extended-query exchange.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Nothing has been sent yet; the next `step` requests the write.
    Initial,
    /// Expecting `ParseComplete`.
    WaitingParse,
    /// Expecting `BindComplete` (or an early `RowDescription`).
    WaitingBind,
    /// Expecting `ParameterDescription` for a described statement.
    WaitingDescribe,
    /// Expecting `RowDescription` or `NoData` for a described statement.
    WaitingRowDesc,
    /// Consuming row/command replies produced by `Execute`.
    ProcessingRows,
    /// Waiting for `ReadyForQuery` (also accepts `CloseComplete`).
    WaitingReady,
    /// `ReadyForQuery` has been consumed; the exchange is over.
    Finished,
}
/// Server-side prepared statement metadata gathered by a prepare exchange.
// Field order is kept as-is: the derived `Debug` prints fields in declaration order.
#[derive(Clone, Debug)]
pub struct PreparedStatement {
    /// Numeric identifier; the on-the-wire name is derived from it (`_zero_s_<idx>`).
    pub idx: u64,
    /// Parameter type OIDs (filled from the server's `ParameterDescription` when preparing).
    pub param_oids: Vec<Oid>,
    /// Raw `RowDescription` payload, or `None` when the statement described as `NoData`.
    pub(crate) row_desc_payload: Option<Vec<u8>>,
}
impl PreparedStatement {
    /// Returns the statement name used in `Parse`/`Bind`/`Close` messages: `_zero_s_<idx>`.
    pub fn wire_name(&self) -> String {
        let idx = self.idx;
        format!("_zero_s_{}", idx)
    }

    /// Parses the stored `RowDescription` payload, if the statement has one.
    ///
    /// Returns `None` when no row description was recorded, otherwise the parse result.
    pub fn parse_columns(&self) -> Option<Result<RowDescription<'_>>> {
        let payload: Option<&[u8]> = self.row_desc_payload();
        payload.map(|raw| RowDescription::parse(raw))
    }

    /// Borrows the raw `RowDescription` payload bytes, if any were recorded.
    pub fn row_desc_payload(&self) -> Option<&[u8]> {
        self.row_desc_payload.as_ref().map(Vec::as_slice)
    }
}
/// Which constructor built an [`ExtendedQueryStateMachine`]; selects the first expected reply.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Operation {
    /// `Parse` + `Describe(statement)` + `Sync`.
    Prepare,
    /// `Bind` + `Describe(portal)` + `Execute` + `Sync` against a named statement.
    Execute,
    /// `Parse` + `Bind` + `Describe(portal)` + `Execute` + `Sync` on the unnamed statement.
    ExecuteSql,
    /// `Close(statement)` + `Sync`.
    CloseStatement,
}
/// Drives one extended-query exchange (prepare, execute, execute-SQL or close) whose
/// frontend messages were written into the buffer set by the constructor.
pub struct ExtendedQueryStateMachine<'a, H> {
    /// Receives `result_start` / `row` / `result_end` callbacks for row-producing replies.
    handler: &'a mut H,
    /// Which constructor built this machine.
    operation: Operation,
    /// Current position in the exchange.
    state: State,
    /// Statement metadata; only populated by the prepare flow.
    prepared_stmt: Option<PreparedStatement>,
    /// Server error held back until `ReadyForQuery` is consumed.
    pending_error: Option<crate::error::ServerError>,
    /// Status from the most recent `ReadyForQuery` (`Idle` until one arrives).
    transaction_status: TransactionStatus,
}
impl<'a, H: ExtendedHandler> ExtendedQueryStateMachine<'a, H> {
pub fn take_prepared_statement(&mut self) -> Option<PreparedStatement> {
self.prepared_stmt.take()
}
pub fn prepare(
handler: &'a mut H,
buffer_set: &mut BufferSet,
idx: u64,
query: &str,
param_oids: &[Oid],
) -> Self {
let stmt_name = format!("_zero_s_{}", idx);
buffer_set.write_buffer.clear();
write_parse(&mut buffer_set.write_buffer, &stmt_name, query, param_oids);
write_describe_statement(&mut buffer_set.write_buffer, &stmt_name);
write_sync(&mut buffer_set.write_buffer);
Self {
state: State::Initial,
handler,
operation: Operation::Prepare,
transaction_status: TransactionStatus::Idle,
prepared_stmt: Some(PreparedStatement {
idx,
param_oids: Vec::new(),
row_desc_payload: None,
}),
pending_error: None,
}
}
pub fn execute<P: ToParams>(
handler: &'a mut H,
buffer_set: &mut BufferSet,
statement_name: &str,
param_oids: &[Oid],
params: &P,
) -> Result<Self> {
buffer_set.write_buffer.clear();
write_bind(
&mut buffer_set.write_buffer,
"",
statement_name,
params,
param_oids,
)?;
write_describe_portal(&mut buffer_set.write_buffer, "");
write_execute(&mut buffer_set.write_buffer, "", 0);
write_sync(&mut buffer_set.write_buffer);
Ok(Self {
state: State::Initial,
handler,
operation: Operation::Execute,
transaction_status: TransactionStatus::Idle,
prepared_stmt: None,
pending_error: None,
})
}
pub fn execute_sql<P: ToParams>(
handler: &'a mut H,
buffer_set: &mut BufferSet,
sql: &str,
params: &P,
) -> Result<Self> {
let param_oids = params.natural_oids();
buffer_set.write_buffer.clear();
write_parse(&mut buffer_set.write_buffer, "", sql, ¶m_oids);
write_bind(&mut buffer_set.write_buffer, "", "", params, ¶m_oids)?;
write_describe_portal(&mut buffer_set.write_buffer, "");
write_execute(&mut buffer_set.write_buffer, "", 0);
write_sync(&mut buffer_set.write_buffer);
Ok(Self {
state: State::Initial,
handler,
operation: Operation::ExecuteSql,
transaction_status: TransactionStatus::Idle,
prepared_stmt: None,
pending_error: None,
})
}
pub fn close_statement(handler: &'a mut H, buffer_set: &mut BufferSet, name: &str) -> Self {
buffer_set.write_buffer.clear();
write_close_statement(&mut buffer_set.write_buffer, name);
write_sync(&mut buffer_set.write_buffer);
Self {
state: State::Initial,
handler,
operation: Operation::CloseStatement,
transaction_status: TransactionStatus::Idle,
prepared_stmt: None,
pending_error: None,
}
}
fn handle_parse(&mut self, buffer_set: &BufferSet) -> Result<Action> {
let type_byte = buffer_set.type_byte;
if type_byte != msg_type::PARSE_COMPLETE {
return Err(Error::LibraryBug(format!(
"Expected ParseComplete, got '{}'",
type_byte as char
)));
}
ParseComplete::parse(&buffer_set.read_buffer)?;
self.state = match self.operation {
Operation::ExecuteSql => State::WaitingBind,
Operation::Prepare => State::WaitingDescribe,
_ => {
return Err(Error::LibraryBug(
"handle_parse called for non-parse operation".into(),
));
}
};
Ok(Action::ReadMessage)
}
fn handle_describe(&mut self, buffer_set: &BufferSet) -> Result<Action> {
let type_byte = buffer_set.type_byte;
if type_byte != msg_type::PARAMETER_DESCRIPTION {
return Err(Error::LibraryBug(format!(
"Expected ParameterDescription, got '{}'",
type_byte as char
)));
}
let param_desc = ParameterDescription::parse(&buffer_set.read_buffer)?;
if let Some(stmt) = &mut self.prepared_stmt {
stmt.param_oids = param_desc.oids().to_vec();
}
self.state = State::WaitingRowDesc;
Ok(Action::ReadMessage)
}
fn handle_row_desc(&mut self, buffer_set: &BufferSet) -> Result<Action> {
let type_byte = buffer_set.type_byte;
match type_byte {
msg_type::ROW_DESCRIPTION => {
if let Some(stmt) = &mut self.prepared_stmt {
stmt.row_desc_payload = Some(buffer_set.read_buffer.clone());
}
self.state = State::WaitingReady;
Ok(Action::ReadMessage)
}
msg_type::NO_DATA => {
let payload = &buffer_set.read_buffer;
NoData::parse(payload)?;
self.state = State::WaitingReady;
Ok(Action::ReadMessage)
}
_ => Err(Error::LibraryBug(format!(
"Expected RowDescription or NoData, got '{}'",
type_byte as char
))),
}
}
fn handle_bind(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
let type_byte = buffer_set.type_byte;
match type_byte {
msg_type::BIND_COMPLETE => {
BindComplete::parse(&buffer_set.read_buffer)?;
self.state = State::ProcessingRows;
Ok(Action::ReadMessage)
}
msg_type::ROW_DESCRIPTION => {
buffer_set.column_buffer.clear();
buffer_set
.column_buffer
.extend_from_slice(&buffer_set.read_buffer);
let cols = RowDescription::parse(&buffer_set.column_buffer)?;
self.handler.result_start(cols)?;
self.state = State::ProcessingRows;
Ok(Action::ReadMessage)
}
_ => Err(Error::LibraryBug(format!(
"Expected BindComplete, got '{}'",
type_byte as char
))),
}
}
fn handle_rows(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
let type_byte = buffer_set.type_byte;
let payload = &buffer_set.read_buffer;
match type_byte {
msg_type::ROW_DESCRIPTION => {
buffer_set.column_buffer.clear();
buffer_set.column_buffer.extend_from_slice(payload);
let cols = RowDescription::parse(&buffer_set.column_buffer)?;
self.handler.result_start(cols)?;
Ok(Action::ReadMessage)
}
msg_type::NO_DATA => {
NoData::parse(payload)?;
Ok(Action::ReadMessage)
}
msg_type::DATA_ROW => {
let cols = RowDescription::parse(&buffer_set.column_buffer)?;
let row = DataRow::parse(payload)?;
self.handler.row(cols, row)?;
Ok(Action::ReadMessage)
}
msg_type::COMMAND_COMPLETE => {
let complete = CommandComplete::parse(payload)?;
self.handler.result_end(complete)?;
self.state = State::WaitingReady;
Ok(Action::ReadMessage)
}
msg_type::EMPTY_QUERY_RESPONSE => {
EmptyQueryResponse::parse(payload)?;
self.state = State::WaitingReady;
Ok(Action::ReadMessage)
}
msg_type::PORTAL_SUSPENDED => {
PortalSuspended::parse(payload)?;
self.state = State::WaitingReady;
Ok(Action::ReadMessage)
}
msg_type::READY_FOR_QUERY => {
let ready = ReadyForQuery::parse(payload)?;
self.transaction_status = ready.transaction_status().unwrap_or_default();
self.state = State::Finished;
Ok(Action::Finished)
}
_ => Err(Error::LibraryBug(format!(
"Unexpected message in rows: '{}'",
type_byte as char
))),
}
}
fn handle_ready(&mut self, buffer_set: &BufferSet) -> Result<Action> {
let type_byte = buffer_set.type_byte;
let payload = &buffer_set.read_buffer;
match type_byte {
msg_type::READY_FOR_QUERY => {
let ready = ReadyForQuery::parse(payload)?;
self.transaction_status = ready.transaction_status().unwrap_or_default();
self.state = State::Finished;
if let Some(err) = self.pending_error.take() {
Ok(Action::Error(err))
} else {
Ok(Action::Finished)
}
}
msg_type::CLOSE_COMPLETE => {
CloseComplete::parse(payload)?;
Ok(Action::ReadMessage)
}
_ => Err(Error::LibraryBug(format!(
"Expected ReadyForQuery, got '{}'",
type_byte as char
))),
}
}
fn handle_async_message(&self, msg: &RawMessage<'_>) -> Result<Action> {
match msg.type_byte {
msg_type::NOTICE_RESPONSE => {
let notice = crate::protocol::backend::NoticeResponse::parse(msg.payload)?;
Ok(Action::HandleAsyncMessageAndReadMessage(
AsyncMessage::Notice(notice.0),
))
}
msg_type::PARAMETER_STATUS => {
let param = crate::protocol::backend::auth::ParameterStatus::parse(msg.payload)?;
Ok(Action::HandleAsyncMessageAndReadMessage(
AsyncMessage::ParameterChanged {
name: param.name.to_string(),
value: param.value.to_string(),
},
))
}
msg_type::NOTIFICATION_RESPONSE => {
let notification =
crate::protocol::backend::auth::NotificationResponse::parse(msg.payload)?;
Ok(Action::HandleAsyncMessageAndReadMessage(
AsyncMessage::Notification {
pid: notification.pid,
channel: notification.channel.to_string(),
payload: notification.payload.to_string(),
},
))
}
_ => Err(Error::LibraryBug(format!(
"Unknown async message type: '{}'",
msg.type_byte as char
))),
}
}
}
impl<H: ExtendedHandler> StateMachine for ExtendedQueryStateMachine<'_, H> {
fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
if self.state == State::Initial {
self.state = match self.operation {
Operation::Prepare => State::WaitingParse,
Operation::Execute => State::WaitingBind, Operation::ExecuteSql => State::WaitingParse,
Operation::CloseStatement => State::WaitingReady,
};
return Ok(Action::WriteAndReadMessage);
}
let type_byte = buffer_set.type_byte;
if RawMessage::is_async_type(type_byte) {
let msg = RawMessage::new(type_byte, &buffer_set.read_buffer);
return self.handle_async_message(&msg);
}
if type_byte == msg_type::ERROR_RESPONSE {
let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
self.pending_error = Some(error.0);
self.state = State::WaitingReady;
return Ok(Action::ReadMessage);
}
match self.state {
State::WaitingParse => self.handle_parse(buffer_set),
State::WaitingDescribe => self.handle_describe(buffer_set),
State::WaitingRowDesc => self.handle_row_desc(buffer_set),
State::WaitingBind => self.handle_bind(buffer_set),
State::ProcessingRows => self.handle_rows(buffer_set),
State::WaitingReady => self.handle_ready(buffer_set),
_ => Err(Error::LibraryBug(format!(
"Unexpected state {:?}",
self.state
))),
}
}
fn transaction_status(&self) -> TransactionStatus {
self.transaction_status
}
}
use crate::protocol::frontend::write_flush;
/// Position of a [`BindStateMachine`] within its Parse/Bind exchange.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BindState {
    /// Nothing sent yet; the next `step` requests the write.
    Initial,
    /// Expecting `ParseComplete` (only when SQL is parsed inline).
    WaitingParse,
    /// Expecting `BindComplete`.
    WaitingBind,
    /// `BindComplete` consumed; the portal is ready.
    Finished,
}
/// Binds parameters into a named portal (optionally parsing SQL first) using `Flush`
/// rather than `Sync`, so the portal can be used by later messages.
pub struct BindStateMachine {
    /// Whether a `ParseComplete` must be consumed before `BindComplete`.
    needs_parse: bool,
    /// Current position in the exchange.
    state: BindState,
}
impl BindStateMachine {
pub fn bind_prepared<P: ToParams>(
buffer_set: &mut BufferSet,
portal_name: &str,
statement_name: &str,
param_oids: &[Oid],
params: &P,
) -> Result<Self> {
buffer_set.write_buffer.clear();
write_bind(
&mut buffer_set.write_buffer,
portal_name,
statement_name,
params,
param_oids,
)?;
write_flush(&mut buffer_set.write_buffer);
Ok(Self {
state: BindState::Initial,
needs_parse: false,
})
}
pub fn bind_sql<P: ToParams>(
buffer_set: &mut BufferSet,
portal_name: &str,
sql: &str,
params: &P,
) -> Result<Self> {
let param_oids = params.natural_oids();
buffer_set.write_buffer.clear();
write_parse(&mut buffer_set.write_buffer, "", sql, ¶m_oids);
write_bind(
&mut buffer_set.write_buffer,
portal_name,
"",
params,
¶m_oids,
)?;
write_flush(&mut buffer_set.write_buffer);
Ok(Self {
state: BindState::Initial,
needs_parse: true,
})
}
pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
if self.state == BindState::Initial {
self.state = if self.needs_parse {
BindState::WaitingParse
} else {
BindState::WaitingBind
};
return Ok(Action::WriteAndReadMessage);
}
let type_byte = buffer_set.type_byte;
if RawMessage::is_async_type(type_byte) {
return Ok(Action::ReadMessage);
}
if type_byte == msg_type::ERROR_RESPONSE {
let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
return Err(error.into_error());
}
match self.state {
BindState::WaitingParse => {
if type_byte != msg_type::PARSE_COMPLETE {
return Err(Error::LibraryBug(format!(
"Expected ParseComplete, got '{}'",
type_byte as char
)));
}
ParseComplete::parse(&buffer_set.read_buffer)?;
self.state = BindState::WaitingBind;
Ok(Action::ReadMessage)
}
BindState::WaitingBind => {
if type_byte != msg_type::BIND_COMPLETE {
return Err(Error::LibraryBug(format!(
"Expected BindComplete, got '{}'",
type_byte as char
)));
}
BindComplete::parse(&buffer_set.read_buffer)?;
self.state = BindState::Finished;
Ok(Action::Finished)
}
_ => Err(Error::LibraryBug(format!(
"Unexpected state {:?}",
self.state
))),
}
}
}
/// Position of a [`BatchStateMachine`] within its batch exchange.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BatchState {
    /// Nothing sent yet; the next `step` requests the write.
    Initial,
    /// Expecting a single `ParseComplete` before the batch replies.
    WaitingParse,
    /// Consuming batch replies until `ReadyForQuery`.
    Processing,
    /// `ReadyForQuery` consumed; the batch is over.
    Finished,
}
/// Consumes and discards the replies of a pipelined batch, reporting only the final
/// transaction status and the first-stashed server error.
pub struct BatchStateMachine {
    /// Current position in the exchange.
    state: BatchState,
    /// Status from the most recent `ReadyForQuery` (`Idle` until one arrives).
    transaction_status: TransactionStatus,
    /// Server error held back until `ReadyForQuery` is consumed.
    pending_error: Option<crate::error::ServerError>,
    /// Whether a `ParseComplete` is expected before the rest of the replies.
    needs_parse: bool,
}
impl BatchStateMachine {
pub fn new(needs_parse: bool) -> Self {
Self {
state: BatchState::Initial,
needs_parse,
transaction_status: TransactionStatus::Idle,
pending_error: None,
}
}
pub fn transaction_status(&self) -> TransactionStatus {
self.transaction_status
}
pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
if self.state == BatchState::Initial {
self.state = if self.needs_parse {
BatchState::WaitingParse
} else {
BatchState::Processing
};
return Ok(Action::WriteAndReadMessage);
}
let type_byte = buffer_set.type_byte;
if RawMessage::is_async_type(type_byte) {
return Ok(Action::ReadMessage);
}
if type_byte == msg_type::ERROR_RESPONSE {
let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
self.pending_error = Some(error.0);
self.state = BatchState::Processing;
return Ok(Action::ReadMessage);
}
match self.state {
BatchState::WaitingParse => {
if type_byte != msg_type::PARSE_COMPLETE {
return Err(Error::LibraryBug(format!(
"Expected ParseComplete, got '{}'",
type_byte as char
)));
}
ParseComplete::parse(&buffer_set.read_buffer)?;
self.state = BatchState::Processing;
Ok(Action::ReadMessage)
}
BatchState::Processing => {
match type_byte {
msg_type::BIND_COMPLETE => {
BindComplete::parse(&buffer_set.read_buffer)?;
Ok(Action::ReadMessage)
}
msg_type::NO_DATA => {
NoData::parse(&buffer_set.read_buffer)?;
Ok(Action::ReadMessage)
}
msg_type::ROW_DESCRIPTION => {
RowDescription::parse(&buffer_set.read_buffer)?;
Ok(Action::ReadMessage)
}
msg_type::DATA_ROW => {
Ok(Action::ReadMessage)
}
msg_type::COMMAND_COMPLETE => {
CommandComplete::parse(&buffer_set.read_buffer)?;
Ok(Action::ReadMessage)
}
msg_type::EMPTY_QUERY_RESPONSE => {
EmptyQueryResponse::parse(&buffer_set.read_buffer)?;
Ok(Action::ReadMessage)
}
msg_type::READY_FOR_QUERY => {
let ready = ReadyForQuery::parse(&buffer_set.read_buffer)?;
self.transaction_status = ready.transaction_status().unwrap_or_default();
self.state = BatchState::Finished;
if let Some(err) = self.pending_error.take() {
Ok(Action::Error(err))
} else {
Ok(Action::Finished)
}
}
_ => Err(Error::LibraryBug(format!(
"Unexpected message in batch: '{}'",
type_byte as char
))),
}
}
_ => Err(Error::LibraryBug(format!(
"Unexpected state {:?}",
self.state
))),
}
}
}