use crate::buffer_set::BufferSet;
use crate::error::{Error, Result};
use crate::handler::SimpleHandler;
use crate::protocol::backend::{
CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, RawMessage, ReadyForQuery,
RowDescription, msg_type,
};
use crate::protocol::frontend::write_query;
use crate::protocol::types::TransactionStatus;
use super::StateMachine;
use super::action::{Action, AsyncMessage};
/// Progress of a single simple-query exchange as tracked by
/// `SimpleQueryStateMachine`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
/// The query has not been written to the write buffer yet.
Initial,
/// The query has been sent (or a result set just completed) and the next
/// response message is awaited.
WaitingResponse,
/// A `RowDescription` has been received; data rows for that result set are
/// being consumed.
ProcessingRows,
/// An `EmptyQueryResponse` or `ErrorResponse` was seen; only `ReadyForQuery`
/// is accepted from here.
WaitingReady,
/// `ReadyForQuery` has been processed; the exchange is over.
Finished,
}
/// State machine that drives one simple-query exchange and forwards results
/// to a borrowed handler.
///
/// `'a` is the lifetime of the handler borrow and `'q` the lifetime of the
/// query text borrow.
pub struct SimpleQueryStateMachine<'a, 'q, H> {
// Current position in the exchange.
state: State,
// Receives `result_start`, `row` and `result_end` callbacks.
handler: &'a mut H,
// Query text that is written to the write buffer on the first step.
query: &'q str,
// Starts as `Idle`; overwritten from each processed `ReadyForQuery`.
transaction_status: TransactionStatus,
// Server error captured from an `ErrorResponse`; handed out as
// `Action::Error` once `ReadyForQuery` has been processed.
pending_error: Option<crate::error::ServerError>,
}
impl<'a, 'q, H: SimpleHandler> SimpleQueryStateMachine<'a, 'q, H> {
pub fn new(handler: &'a mut H, query: &'q str) -> Self {
Self {
state: State::Initial,
handler,
query,
transaction_status: TransactionStatus::Idle,
pending_error: None,
}
}
fn handle_response(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
let type_byte = buffer_set.type_byte;
let payload = &buffer_set.read_buffer;
match type_byte {
msg_type::ROW_DESCRIPTION => {
buffer_set.column_buffer.clear();
buffer_set.column_buffer.extend_from_slice(payload);
let cols = RowDescription::parse(&buffer_set.column_buffer)?;
self.handler.result_start(cols)?;
self.state = State::ProcessingRows;
Ok(Action::ReadMessage)
}
msg_type::COMMAND_COMPLETE => {
let complete = CommandComplete::parse(payload)?;
self.handler.result_end(complete)?;
self.state = State::WaitingResponse;
Ok(Action::ReadMessage)
}
msg_type::EMPTY_QUERY_RESPONSE => {
EmptyQueryResponse::parse(payload)?;
self.state = State::WaitingReady;
Ok(Action::ReadMessage)
}
msg_type::READY_FOR_QUERY => {
let ready = ReadyForQuery::parse(payload)?;
self.transaction_status = ready.transaction_status().unwrap_or_default();
self.state = State::Finished;
Ok(Action::Finished)
}
_ => Err(Error::LibraryBug(format!(
"Unexpected message in query response: '{}'",
type_byte as char
))),
}
}
fn handle_rows(&mut self, buffer_set: &BufferSet) -> Result<Action> {
let type_byte = buffer_set.type_byte;
let payload = &buffer_set.read_buffer;
match type_byte {
msg_type::DATA_ROW => {
let cols = RowDescription::parse(&buffer_set.column_buffer)?;
let row = DataRow::parse(payload)?;
self.handler.row(cols, row)?;
Ok(Action::ReadMessage)
}
msg_type::COMMAND_COMPLETE => {
let complete = CommandComplete::parse(payload)?;
self.handler.result_end(complete)?;
self.state = State::WaitingResponse;
Ok(Action::ReadMessage)
}
msg_type::READY_FOR_QUERY => {
let ready = ReadyForQuery::parse(payload)?;
self.transaction_status = ready.transaction_status().unwrap_or_default();
self.state = State::Finished;
Ok(Action::Finished)
}
_ => Err(Error::LibraryBug(format!(
"Unexpected message in row processing: '{}'",
type_byte as char
))),
}
}
fn handle_ready(&mut self, buffer_set: &BufferSet) -> Result<Action> {
if buffer_set.type_byte != msg_type::READY_FOR_QUERY {
return Err(Error::LibraryBug(format!(
"Expected ReadyForQuery, got '{}'",
buffer_set.type_byte as char
)));
}
let ready = ReadyForQuery::parse(&buffer_set.read_buffer)?;
self.transaction_status = ready.transaction_status().unwrap_or_default();
self.state = State::Finished;
if let Some(err) = self.pending_error.take() {
Ok(Action::Error(err))
} else {
Ok(Action::Finished)
}
}
fn handle_async_message(&self, msg: &RawMessage<'_>) -> Result<Action> {
match msg.type_byte {
msg_type::NOTICE_RESPONSE => {
let notice = crate::protocol::backend::NoticeResponse::parse(msg.payload)?;
Ok(Action::HandleAsyncMessageAndReadMessage(
AsyncMessage::Notice(notice.0),
))
}
msg_type::PARAMETER_STATUS => {
let param = crate::protocol::backend::auth::ParameterStatus::parse(msg.payload)?;
Ok(Action::HandleAsyncMessageAndReadMessage(
AsyncMessage::ParameterChanged {
name: param.name.to_string(),
value: param.value.to_string(),
},
))
}
msg_type::NOTIFICATION_RESPONSE => {
let notification =
crate::protocol::backend::auth::NotificationResponse::parse(msg.payload)?;
Ok(Action::HandleAsyncMessageAndReadMessage(
AsyncMessage::Notification {
pid: notification.pid,
channel: notification.channel.to_string(),
payload: notification.payload.to_string(),
},
))
}
_ => Err(Error::LibraryBug(format!(
"Unknown async message type: '{}'",
msg.type_byte as char
))),
}
}
}
impl<H: SimpleHandler> StateMachine for SimpleQueryStateMachine<'_, '_, H> {
fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
if self.state == State::Initial {
buffer_set.write_buffer.clear();
write_query(&mut buffer_set.write_buffer, self.query);
self.state = State::WaitingResponse;
return Ok(Action::WriteAndReadMessage);
}
let type_byte = buffer_set.type_byte;
if RawMessage::is_async_type(type_byte) {
let msg = RawMessage::new(type_byte, &buffer_set.read_buffer);
return self.handle_async_message(&msg);
}
if type_byte == msg_type::ERROR_RESPONSE {
let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
self.pending_error = Some(error.0);
self.state = State::WaitingReady;
return Ok(Action::ReadMessage);
}
match self.state {
State::WaitingResponse => self.handle_response(buffer_set),
State::ProcessingRows => self.handle_rows(buffer_set),
State::WaitingReady => self.handle_ready(buffer_set),
_ => Err(Error::LibraryBug(format!(
"Unexpected state {:?}",
self.state
))),
}
}
fn transaction_status(&self) -> TransactionStatus {
self.transaction_status
}
}