use std::cmp::max;
use std::fmt::Debug;
use std::ops::DerefMut;
use std::sync::Arc;
use async_trait::async_trait;
use futures::sink::{Sink, SinkExt};
use futures::stream::StreamExt;
use super::portal::Portal;
use super::results::{Tag, into_row_description};
use super::stmt::{NoopQueryParser, QueryParser, StoredStatement};
use super::store::PortalStore;
use super::{ClientInfo, ClientPortalStore, DEFAULT_NAME, copy};
use crate::api::PgWireConnectionState;
use crate::api::Type;
use crate::api::portal::PortalExecutionState;
use crate::api::results::{
DescribePortalResponse, DescribeResponse, DescribeStatementResponse, QueryResponse, Response,
};
use crate::error::{ErrorInfo, PgWireError, PgWireResult};
use crate::messages::PgWireBackendMessage;
use crate::messages::data::{NoData, ParameterDescription};
use crate::messages::extendedquery::{
Bind, BindComplete, Close, CloseComplete, Describe, Execute, Flush, Parse, ParseComplete,
PortalSuspended, Sync as PgSync, TARGET_TYPE_BYTE_PORTAL, TARGET_TYPE_BYTE_STATEMENT,
};
use crate::messages::response::{EmptyQueryResponse, ReadyForQuery, TransactionStatus};
use crate::messages::simplequery::Query;
/// Returns `true` when `q` contains no statement at all, i.e. it consists solely of
/// whitespace and `;` separators (`""`, `"  "`, `";"`, `";;"`, `" ; ;\n"`).
///
/// PostgreSQL answers such a query string with `EmptyQueryResponse` because its parser
/// discards empty statements between semicolons. SQL comments are not recognised here, so
/// a comment-only string is still treated as non-empty.
fn is_empty_query(q: &str) -> bool {
    // `char::is_whitespace` uses the same Unicode White_Space set as `str::trim`, so every
    // string the previous trim-based check accepted is still accepted.
    q.chars().all(|c| c == ';' || c.is_whitespace())
}
/// Handler for the simple query protocol (`Query` message).
///
/// Implementors only need to provide [`SimpleQueryHandler::do_query`]; the default
/// `on_query` / `_on_query` pair turns its responses into backend messages.
#[async_trait]
pub trait SimpleQueryHandler: Send + Sync {
    /// Entry point for a `Query` message; the default delegates to
    /// [`SimpleQueryHandler::_on_query`] so an override can still reuse that flow.
    async fn on_query<C>(&self, client: &mut C, query: Query) -> PgWireResult<()>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        self._on_query(client, query).await
    }

    /// Default simple-query flow.
    ///
    /// Rejects the message with [`PgWireError::NotReadyForQuery`] unless the connection is
    /// in the `ReadyForQuery` state, then marks it `QueryInProgress`. A query string that
    /// `is_empty_query` reports as empty produces a single `EmptyQueryResponse` without
    /// calling `do_query`; otherwise every [`Response`] returned by `do_query` is written
    /// in order while the transaction status is tracked. Unless a COPY was started, the
    /// connection is finally returned to `ReadyForQuery` and `ReadyForQuery` is sent.
    ///
    /// # Errors
    /// Propagates errors from `do_query` and from writing to the client sink.
    async fn _on_query<C>(&self, client: &mut C, query: Query) -> PgWireResult<()>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        if !matches!(client.state(), PgWireConnectionState::ReadyForQuery) {
            return Err(PgWireError::NotReadyForQuery);
        }

        let mut status = client.transaction_status();
        client.set_state(PgWireConnectionState::QueryInProgress);

        let sql = query.query;
        // An empty query is answered exactly like an explicit `Response::EmptyQuery`, so it
        // is routed through the same dispatch loop rather than a dedicated branch.
        let responses = if is_empty_query(&sql) {
            vec![Response::EmptyQuery]
        } else {
            self.do_query(client, &sql).await?
        };

        for response in responses {
            match response {
                Response::EmptyQuery => {
                    client
                        .feed(PgWireBackendMessage::EmptyQueryResponse(EmptyQueryResponse))
                        .await?;
                }
                Response::Query(results) => send_query_response(client, results, true).await?,
                Response::Execution(tag) => send_execution_response(client, tag).await?,
                Response::TransactionStart(tag) => {
                    send_execution_response(client, tag).await?;
                    status = status.to_in_transaction_state();
                }
                Response::TransactionEnd(tag) => {
                    send_execution_response(client, tag).await?;
                    status = status.to_idle_state();
                }
                Response::Error(err) => {
                    client
                        .feed(PgWireBackendMessage::ErrorResponse((*err).into()))
                        .await?;
                    status = status.to_error_state();
                }
                Response::CopyIn(copy_result) => {
                    copy::send_copy_in_response(client, copy_result).await?;
                    client.set_state(PgWireConnectionState::CopyInProgress(false));
                }
                Response::CopyOut(copy_result) => {
                    copy::send_copy_out_response(client, copy_result).await?;
                }
                Response::CopyBoth(copy_result) => {
                    copy::send_copy_both_response(client, copy_result).await?;
                    client.set_state(PgWireConnectionState::CopyInProgress(false));
                }
            }
        }

        // A started COPY keeps the connection in `CopyInProgress`; ReadyForQuery is
        // deliberately not sent from here in that case.
        if matches!(client.state(), PgWireConnectionState::CopyInProgress(_)) {
            return Ok(());
        }
        client.set_state(PgWireConnectionState::ReadyForQuery);
        client.set_transaction_status(status);
        send_ready_for_query(client, status).await
    }

    /// Runs a non-empty simple query string and returns the responses to send back; they
    /// are written to the client in the order returned.
    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>;
}
/// Handler for the extended query sub-protocol messages (`Parse`, `Bind`, `Describe`,
/// `Execute`, `Close`, `Flush`, `Sync`).
///
/// Implementors provide a [`QueryParser`] and [`ExtendedQueryHandler::do_query`]. Every
/// message handler has a default implementation built on those and on the client's
/// [`PortalStore`]. Where an `on_*` method has a `_on_*` twin, the twin holds the default
/// logic so an override of `on_*` can still call it.
#[async_trait]
pub trait ExtendedQueryHandler: Send + Sync {
    /// Statement representation produced by the query parser and kept in the portal store.
    type Statement: Clone + Send + Sync;
    /// Parser used to build statements and to answer parameter-type / result-schema queries.
    type QueryParser: QueryParser<Statement = Self::Statement> + Send + Sync;

    /// Returns the parser used by `on_parse` and by the default describe implementations.
    fn query_parser(&self) -> Arc<Self::QueryParser>;

    /// Handles `Parse`: parses the statement with [`Self::query_parser`], stores it in the
    /// client's portal store and replies `ParseComplete`.
    ///
    /// # Errors
    /// Propagates parse failures and sink errors.
    async fn on_parse<C>(&self, client: &mut C, message: Parse) -> PgWireResult<()>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::PortalStore: PortalStore<Statement = Self::Statement>,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let parser = self.query_parser();
        let stmt = StoredStatement::parse(client, &message, parser).await?;
        client.portal_store().put_statement(Arc::new(stmt));
        client
            .send(PgWireBackendMessage::ParseComplete(ParseComplete::new()))
            .await?;
        Ok(())
    }

    /// Handles `Bind`: looks up the named (or unnamed) statement, builds a [`Portal`] from
    /// the message, stores it and replies `BindComplete`.
    ///
    /// # Errors
    /// [`PgWireError::StatementNotFound`] when the statement does not exist; errors from
    /// [`Portal::try_new`] and from the sink are propagated.
    async fn on_bind<C>(&self, client: &mut C, message: Bind) -> PgWireResult<()>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::PortalStore: PortalStore<Statement = Self::Statement>,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let statement_name = message.statement_name.as_deref().unwrap_or(DEFAULT_NAME);
        if let Some(statement) = client.portal_store().get_statement(statement_name) {
            let portal = Portal::try_new(&message, statement)?;
            client.portal_store().put_portal(Arc::new(portal));
            client
                .send(PgWireBackendMessage::BindComplete(BindComplete::new()))
                .await?;
            Ok(())
        } else {
            Err(PgWireError::StatementNotFound(statement_name.to_owned()))
        }
    }

    /// Handles `Execute`; the default delegates to [`ExtendedQueryHandler::_on_execute`].
    async fn on_execute<C>(&self, client: &mut C, message: Execute) -> PgWireResult<()>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::PortalStore: PortalStore<Statement = Self::Statement>,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        self._on_execute(client, message).await
    }

    /// Default `Execute` flow.
    ///
    /// Requires the connection to be in `ReadyForQuery`, then drives the portal according
    /// to its [`PortalExecutionState`]:
    /// * `Initial`: calls [`ExtendedQueryHandler::do_query`] and writes the response. For a
    ///   `Response::Query` with a positive row limit, at most that many rows are sent. If
    ///   more remain, the result is kept in the portal as `Suspended`.
    /// * `Suspended`: sends the next batch of rows.
    /// * `Finished`: replies `NoData`.
    ///
    /// A row limit of zero or below means "no limit". `ReadyForQuery` is not sent here.
    /// The connection state and transaction status are updated unless a COPY was started.
    ///
    /// # Errors
    /// [`PgWireError::NotReadyForQuery`], [`PgWireError::PortalNotFound`], and errors from
    /// `do_query` or the sink.
    async fn _on_execute<C>(&self, client: &mut C, message: Execute) -> PgWireResult<()>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::PortalStore: PortalStore<Statement = Self::Statement>,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        if !matches!(client.state(), super::PgWireConnectionState::ReadyForQuery) {
            return Err(PgWireError::NotReadyForQuery);
        }
        let mut transaction_status = client.transaction_status();
        client.set_state(super::PgWireConnectionState::QueryInProgress);
        let portal_name = message.name.as_deref().unwrap_or(DEFAULT_NAME);
        // The protocol defines a row limit of 0 as "no limit", and PostgreSQL treats
        // negative limits the same way. A plain `as usize` cast would instead turn a
        // negative limit into an enormous positive one, so clamp it to 0.
        let max_rows = usize::try_from(message.max_rows).unwrap_or(0);
        if let Some(portal) = client.portal_store().get_portal(portal_name) {
            let portal_state_lock = portal.state();
            let mut portal_state = portal_state_lock.lock().await;
            match portal_state.deref_mut() {
                PortalExecutionState::Initial => {
                    match self.do_query(client, portal.as_ref(), max_rows).await? {
                        Response::EmptyQuery => {
                            client
                                .feed(PgWireBackendMessage::EmptyQueryResponse(EmptyQueryResponse))
                                .await?;
                        }
                        Response::Query(mut results) => {
                            if max_rows > 0 {
                                if send_partial_query_response(client, &mut results, max_rows)
                                    .await?
                                {
                                    *portal_state = PortalExecutionState::Suspended(results);
                                } else {
                                    *portal_state = PortalExecutionState::Finished;
                                }
                            } else {
                                send_query_response(client, results, false).await?;
                            }
                        }
                        Response::Execution(tag) => {
                            send_execution_response(client, tag).await?;
                        }
                        Response::TransactionStart(tag) => {
                            send_execution_response(client, tag).await?;
                            transaction_status = transaction_status.to_in_transaction_state();
                        }
                        Response::TransactionEnd(tag) => {
                            send_execution_response(client, tag).await?;
                            transaction_status = transaction_status.to_idle_state();
                        }
                        Response::Error(err) => {
                            client
                                .send(PgWireBackendMessage::ErrorResponse((*err).into()))
                                .await?;
                            transaction_status = transaction_status.to_error_state();
                        }
                        Response::CopyIn(result) => {
                            client.set_state(PgWireConnectionState::CopyInProgress(true));
                            copy::send_copy_in_response(client, result).await?;
                        }
                        Response::CopyOut(result) => {
                            copy::send_copy_out_response(client, result).await?;
                        }
                        Response::CopyBoth(result) => {
                            client.set_state(PgWireConnectionState::CopyInProgress(true));
                            copy::send_copy_both_response(client, result).await?;
                        }
                    }
                }
                PortalExecutionState::Suspended(results) => {
                    let has_more = send_partial_query_response(client, results, max_rows).await?;
                    if !has_more {
                        *portal_state = PortalExecutionState::Finished;
                    }
                }
                PortalExecutionState::Finished => {
                    client.send(PgWireBackendMessage::NoData(NoData)).await?;
                }
            }
            if !matches!(client.state(), PgWireConnectionState::CopyInProgress(_)) {
                client.set_state(super::PgWireConnectionState::ReadyForQuery);
                client.set_transaction_status(transaction_status);
            };
            // The unnamed portal is normally discarded right after execution. A suspended
            // portal, however, still holds the remaining rows that the client will fetch
            // with another Execute. Removing it here would drop those rows and make that
            // Execute fail with PortalNotFound. `on_sync` clears the store in any case.
            let suspended = matches!(*portal_state, PortalExecutionState::Suspended(_));
            if portal_name == DEFAULT_NAME && !suspended {
                client.portal_store().rm_portal(portal_name);
            }
            Ok(())
        } else {
            Err(PgWireError::PortalNotFound(portal_name.to_owned()))
        }
    }

    /// Handles `Describe`; the default delegates to [`ExtendedQueryHandler::_on_describe`].
    async fn on_describe<C>(&self, client: &mut C, message: Describe) -> PgWireResult<()>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::PortalStore: PortalStore<Statement = Self::Statement>,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        self._on_describe(client, message).await
    }

    /// Default `Describe` flow.
    ///
    /// Describes a statement via [`ExtendedQueryHandler::do_describe_statement`] or a portal
    /// via [`ExtendedQueryHandler::do_describe_portal`], then writes the result with
    /// [`send_describe_response`].
    ///
    /// # Errors
    /// [`PgWireError::StatementNotFound`] / [`PgWireError::PortalNotFound`] for unknown
    /// names, [`PgWireError::InvalidTargetType`] for an unknown target byte, and errors from
    /// the describe hooks or the sink.
    async fn _on_describe<C>(&self, client: &mut C, message: Describe) -> PgWireResult<()>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::PortalStore: PortalStore<Statement = Self::Statement>,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let name = message.name.as_deref().unwrap_or(DEFAULT_NAME);
        match message.target_type {
            TARGET_TYPE_BYTE_STATEMENT => {
                if let Some(stmt) = client.portal_store().get_statement(name) {
                    let describe_response = self.do_describe_statement(client, &stmt).await?;
                    send_describe_response(client, &describe_response).await?;
                } else {
                    return Err(PgWireError::StatementNotFound(name.to_owned()));
                }
            }
            TARGET_TYPE_BYTE_PORTAL => {
                if let Some(portal) = client.portal_store().get_portal(name) {
                    let describe_response = self.do_describe_portal(client, &portal).await?;
                    send_describe_response(client, &describe_response).await?;
                } else {
                    return Err(PgWireError::PortalNotFound(name.to_owned()));
                }
            }
            _ => return Err(PgWireError::InvalidTargetType(message.target_type)),
        }
        Ok(())
    }

    /// Handles `Flush` by flushing any backend messages buffered in the client sink.
    async fn on_flush<C>(&self, client: &mut C, _message: Flush) -> PgWireResult<()>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        client.flush().await?;
        Ok(())
    }

    /// Handles `Sync`. It discards every portal in the client's store, sends
    /// `ReadyForQuery` with the current transaction status, and flushes the sink.
    async fn on_sync<C>(&self, client: &mut C, _message: PgSync) -> PgWireResult<()>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::PortalStore: PortalStore<Statement = Self::Statement>,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        client.portal_store().clear_portals();
        client
            .send(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new(
                client.transaction_status(),
            )))
            .await?;
        client.flush().await?;
        Ok(())
    }

    /// Handles `Close`. It removes the named (or unnamed) statement or portal and replies
    /// `CloseComplete`. The code does not check whether the name exists beforehand.
    ///
    /// # Errors
    /// [`PgWireError::InvalidTargetType`] for a target byte other than statement or portal.
    /// This matches `_on_describe`, and PostgreSQL likewise rejects an invalid Close
    /// subtype as a protocol violation. Sink errors are propagated.
    async fn on_close<C>(&self, client: &mut C, message: Close) -> PgWireResult<()>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::PortalStore: PortalStore<Statement = Self::Statement>,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let name = message.name.as_deref().unwrap_or(DEFAULT_NAME);
        match message.target_type {
            TARGET_TYPE_BYTE_STATEMENT => {
                client.portal_store().rm_statement(name);
            }
            TARGET_TYPE_BYTE_PORTAL => {
                client.portal_store().rm_portal(name);
            }
            _ => return Err(PgWireError::InvalidTargetType(message.target_type)),
        }
        client
            .send(PgWireBackendMessage::CloseComplete(CloseComplete))
            .await?;
        Ok(())
    }

    /// Builds the `Describe` (statement) response.
    ///
    /// Parameter types recorded on the stored statement take precedence. Any gap falls back
    /// to the parser's inferred type, and finally to `Type::UNKNOWN`. The list is as long as
    /// the longer of the two sources. The result schema comes from the parser with no
    /// format codes.
    async fn do_describe_statement<C>(
        &self,
        _client: &mut C,
        target: &StoredStatement<Self::Statement>,
    ) -> PgWireResult<DescribeStatementResponse>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::PortalStore: PortalStore<Statement = Self::Statement>,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let stmt = &target.statement;
        let query_parser = self.query_parser();
        let server_param_types = query_parser.get_parameter_types(stmt)?;
        let result_schema = query_parser.get_result_schema(stmt, None)?;
        let param_types = (0usize..max(target.parameter_types.len(), server_param_types.len()))
            .map(|idx| {
                target
                    .parameter_types
                    .get(idx)
                    .cloned()
                    .flatten()
                    .or_else(|| server_param_types.get(idx).cloned())
                    .unwrap_or(Type::UNKNOWN)
            })
            .collect::<Vec<Type>>();
        Ok(DescribeStatementResponse::new(param_types, result_schema))
    }

    /// Builds the `Describe` (portal) response from the parser's result schema, using the
    /// portal's requested result column formats.
    async fn do_describe_portal<C>(
        &self,
        _client: &mut C,
        target: &Portal<Self::Statement>,
    ) -> PgWireResult<DescribePortalResponse>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::PortalStore: PortalStore<Statement = Self::Statement>,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let stmt = &target.statement.statement;
        let query_parser = self.query_parser();
        let result_schema =
            query_parser.get_result_schema(stmt, Some(&target.result_column_format))?;
        Ok(DescribePortalResponse::new(result_schema))
    }

    /// Executes a bound portal and returns a single [`Response`].
    ///
    /// `max_rows` is the client's row limit, with 0 meaning no limit. When it is positive,
    /// `_on_execute` itself sends at most that many rows of a `Response::Query` per Execute.
    async fn do_query<C>(
        &self,
        client: &mut C,
        portal: &Portal<Self::Statement>,
        max_rows: usize,
    ) -> PgWireResult<Response>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::PortalStore: PortalStore<Statement = Self::Statement>,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>;
}
/// Streams a complete query result to the client.
///
/// When `send_describe` is true, a `RowDescription` built from the result schema is sent
/// first. In this file the simple-query path passes `true` and the `Execute` path passes
/// `false`. Each row is then fed as a `DataRow`, and a `CommandComplete` tagged with the
/// number of rows written closes the result.
///
/// # Errors
/// Returns the first row-stream error or sink error encountered.
pub async fn send_query_response<C>(
    client: &mut C,
    results: QueryResponse,
    send_describe: bool,
) -> PgWireResult<()>
where
    C: Sink<PgWireBackendMessage> + Unpin,
    C::Error: Debug,
    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
    let QueryResponse {
        command_tag,
        row_schema,
        mut data_rows,
    } = results;

    if send_describe {
        let description = into_row_description(&row_schema);
        client
            .send(PgWireBackendMessage::RowDescription(description))
            .await?;
    }

    // Rows are only fed (buffered); the final `send` flushes them along with the
    // completion message.
    let mut written = 0;
    while let Some(next_row) = data_rows.next().await {
        written += 1;
        client.feed(PgWireBackendMessage::DataRow(next_row?)).await?;
    }

    let completion = Tag::new(&command_tag).with_rows(written);
    client
        .send(PgWireBackendMessage::CommandComplete(completion.into()))
        .await?;
    Ok(())
}
/// Sends up to `max_rows` rows from `results` (all remaining rows when `max_rows` is 0).
///
/// If the row budget is spent before the stream ends, `PortalSuspended` is sent and `true`
/// is returned, meaning more rows may remain in `results`. If the stream ends first,
/// `CommandComplete` tagged with the rows sent in this call is sent and `false` is returned.
///
/// # Errors
/// Returns the first row-stream error or sink error encountered.
pub async fn send_partial_query_response<C>(
    client: &mut C,
    results: &mut QueryResponse,
    max_rows: usize,
) -> PgWireResult<bool>
where
    C: Sink<PgWireBackendMessage> + Unpin,
    C::Error: Debug,
    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
    let stream = results.data_rows();
    let mut emitted = 0;
    // Flips to true when the stream runs dry before the row budget is used up.
    let mut exhausted = false;
    while max_rows == 0 || emitted < max_rows {
        match stream.next().await {
            Some(item) => {
                client.feed(PgWireBackendMessage::DataRow(item?)).await?;
                emitted += 1;
            }
            None => {
                exhausted = true;
                break;
            }
        }
    }

    if exhausted {
        let command_tag = results.command_tag().to_string();
        let completion = Tag::new(&command_tag).with_rows(emitted);
        client
            .send(PgWireBackendMessage::CommandComplete(completion.into()))
            .await?;
    } else {
        client
            .send(PgWireBackendMessage::PortalSuspended(PortalSuspended))
            .await?;
    }
    Ok(!exhausted)
}
pub async fn send_ready_for_query<C>(
client: &mut C,
transaction_status: TransactionStatus,
) -> PgWireResult<()>
where
C: Sink<PgWireBackendMessage> + Unpin,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
let message = ReadyForQuery::new(transaction_status);
client
.send(PgWireBackendMessage::ReadyForQuery(message))
.await?;
Ok(())
}
pub async fn send_execution_response<C>(client: &mut C, tag: Tag) -> PgWireResult<()>
where
C: Sink<PgWireBackendMessage> + Unpin,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
client
.send(PgWireBackendMessage::CommandComplete(tag.into()))
.await?;
Ok(())
}
/// Writes a `Describe` reply.
///
/// If the response carries parameter types, a `ParameterDescription` with their OIDs is
/// sent first. Then either `NoData` (when the response reports no data) or a
/// `RowDescription` of its fields is sent.
///
/// # Errors
/// Returns the first sink error encountered.
pub async fn send_describe_response<C, DR>(
    client: &mut C,
    describe_response: &DR,
) -> PgWireResult<()>
where
    C: Sink<PgWireBackendMessage> + Unpin,
    C::Error: Debug,
    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    DR: DescribeResponse,
{
    if let Some(types) = describe_response.parameters() {
        let oids = types.iter().map(|ty| ty.oid()).collect();
        client
            .send(PgWireBackendMessage::ParameterDescription(
                ParameterDescription::new(oids),
            ))
            .await?;
    }

    let trailer = if describe_response.is_no_data() {
        PgWireBackendMessage::NoData(NoData)
    } else {
        PgWireBackendMessage::RowDescription(into_row_description(describe_response.fields()))
    };
    client.send(trailer).await?;
    Ok(())
}
/// Placeholder extended-query handler: statements are kept as raw `String`s, execution is
/// refused, and every describe request answers "no data".
#[async_trait]
impl ExtendedQueryHandler for super::NoopHandler {
    type Statement = String;
    type QueryParser = NoopQueryParser;

    /// Hands out a fresh no-op parser on every call.
    fn query_parser(&self) -> Arc<Self::QueryParser> {
        let parser = NoopQueryParser;
        Arc::new(parser)
    }

    /// Always fails with a FATAL `08P01` "not implemented" user error.
    async fn do_query<C>(
        &self,
        _client: &mut C,
        _portal: &Portal<Self::Statement>,
        _max_rows: usize,
    ) -> PgWireResult<Response>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let not_implemented = ErrorInfo::new(
            String::from("FATAL"),
            String::from("08P01"),
            String::from("This feature is not implemented."),
        );
        Err(PgWireError::UserError(Box::new(not_implemented)))
    }

    /// Reports that the statement has no parameters and produces no data.
    async fn do_describe_statement<C>(
        &self,
        _client: &mut C,
        _statement: &StoredStatement<Self::Statement>,
    ) -> PgWireResult<DescribeStatementResponse>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let reply = DescribeStatementResponse::no_data();
        Ok(reply)
    }

    /// Reports that the portal produces no data.
    async fn do_describe_portal<C>(
        &self,
        _client: &mut C,
        _portal: &Portal<Self::Statement>,
    ) -> PgWireResult<DescribePortalResponse>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let reply = DescribePortalResponse::no_data();
        Ok(reply)
    }
}
/// Placeholder simple-query handler that refuses every non-empty query.
#[async_trait]
impl SimpleQueryHandler for super::NoopHandler {
    /// Always fails with a FATAL `08P01` "not implemented" user error.
    async fn do_query<C>(&self, _client: &mut C, _query: &str) -> PgWireResult<Vec<Response>>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let not_implemented = ErrorInfo::new(
            String::from("FATAL"),
            String::from("08P01"),
            String::from("This feature is not implemented."),
        );
        Err(PgWireError::UserError(Box::new(not_implemented)))
    }
}