use std::collections::BTreeMap;
use std::fmt::{self, Debug, Formatter};
use std::future::Future;
use std::sync::Arc;
use crate::HashMap;
use crate::common::StatementCache;
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::io::StatementId;
use crate::message::{
BackendMessageFormat, Close, Query, ReadyForQuery, ReceivedMessage, Terminate,
TransactionStatus,
};
use crate::statement::PgStatementMetadata;
use crate::transaction::Transaction;
use crate::types::Oid;
use crate::{PgConnectOptions, PgTypeInfo, Postgres};
pub(crate) use sqlx_core::connection::*;
use sqlx_core::sql_str::SqlSafeStr;
pub use self::stream::PgStream;
pub(crate) mod describe;
mod establish;
mod executor;
mod sasl;
mod stream;
mod tls;
/// A connection to a PostgreSQL database.
///
/// All connection state lives in a boxed [`PgConnectionInner`], which keeps
/// `PgConnection` itself the size of a single pointer.
pub struct PgConnection {
    /// Heap-allocated connection state (stream, caches, protocol bookkeeping).
    pub(crate) inner: Box<PgConnectionInner>,
}
/// The state behind a [`PgConnection`]: the wire stream, per-connection caches,
/// and protocol bookkeeping.
pub struct PgConnectionInner {
    /// Underlying stream used for all message I/O with the server.
    pub(crate) stream: PgStream,
    // Not read anywhere in this module; `allow(dead_code)` silences the unused-field
    // warning. NOTE(review): presumably the backend key data kept for cancel requests —
    // confirm before removing.
    #[allow(dead_code)]
    process_id: u32,
    // See `process_id` above; same rationale for `allow(dead_code)`.
    #[allow(dead_code)]
    secret_key: u32,
    /// Identifier source for prepared statements. NOTE(review): presumably the id
    /// assigned to the next statement prepared on this connection — confirm at use site.
    next_statement_id: StatementId,
    /// LRU cache of prepared statements and their metadata; entries are evicted with
    /// `remove_lru` (see `clear_cached_statements`).
    cache_statement: StatementCache<(StatementId, Arc<PgStatementMetadata>)>,
    /// Type information cached by type OID.
    cache_type_info: HashMap<Oid, PgTypeInfo>,
    /// OIDs cached by `UStr` key (presumably the type name — verify against callers).
    /// Cleared by `clear_cached_statements`.
    cache_type_oid: HashMap<UStr, Oid>,
    /// Maps an element type OID to its array type OID (per the field name — confirm).
    cache_elem_type_to_array: HashMap<Oid, Oid>,
    /// Column names cached per table OID.
    cache_table_to_column_names: HashMap<Oid, TableColumns>,
    /// Number of `ReadyForQuery` messages still expected from the server.
    /// Incremented when a query is queued and decremented on each `ReadyForQuery`.
    pub(crate) pending_ready_for_query_count: usize,
    /// Transaction status reported by the most recent `ReadyForQuery`.
    transaction_status: TransactionStatus,
    /// NOTE(review): presumably the current transaction nesting depth, maintained by
    /// the transaction manager outside this module — confirm.
    pub(crate) transaction_depth: usize,
    /// Statement-logging configuration for this connection.
    log_settings: LogSettings,
}
/// Cached column names for one table (stored in `cache_table_to_column_names`).
pub(crate) struct TableColumns {
    /// Name of the table.
    table_name: Arc<str>,
    /// Column names keyed by an `i16` column number. NOTE(review): presumably the
    /// PostgreSQL attribute number — confirm against where this cache is populated.
    columns: BTreeMap<i16, Arc<str>>,
}
impl PgConnection {
pub fn server_version_num(&self) -> Option<u32> {
self.inner.stream.server_version_num
}
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
if !self.inner.stream.write_buffer_mut().is_empty() {
self.inner.stream.flush().await?;
}
while self.inner.pending_ready_for_query_count > 0 {
let message = self.inner.stream.recv().await?;
if let BackendMessageFormat::ReadyForQuery = message.format {
self.handle_ready_for_query(message)?;
}
}
Ok(())
}
async fn recv_ready_for_query(&mut self) -> Result<(), Error> {
let r: ReadyForQuery = self.inner.stream.recv_expect().await?;
self.inner.pending_ready_for_query_count -= 1;
self.inner.transaction_status = r.transaction_status;
Ok(())
}
#[inline(always)]
fn handle_ready_for_query(&mut self, message: ReceivedMessage) -> Result<(), Error> {
self.inner.pending_ready_for_query_count = self
.inner
.pending_ready_for_query_count
.checked_sub(1)
.ok_or_else(|| err_protocol!("received more ReadyForQuery messages than expected"))?;
self.inner.transaction_status = message.decode::<ReadyForQuery>()?.transaction_status;
Ok(())
}
#[inline(always)]
pub(crate) fn queue_simple_query(&mut self, query: &str) -> Result<(), Error> {
self.inner.stream.write_msg(Query(query))?;
self.inner.pending_ready_for_query_count += 1;
Ok(())
}
pub(crate) fn in_transaction(&self) -> bool {
match self.inner.transaction_status {
TransactionStatus::Transaction => true,
TransactionStatus::Error | TransactionStatus::Idle => false,
}
}
}
impl Debug for PgConnection {
    /// Formats the connection as just its type name; no fields are printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("PgConnection");
        builder.finish()
    }
}
/// Implementation of the generic `Connection` trait for PostgreSQL.
impl Connection for PgConnection {
    type Database = Postgres;
    type Options = PgConnectOptions;
    /// Gracefully closes the connection: sends `Terminate`, then shuts down the stream.
    ///
    /// # Errors
    /// Propagates any error from sending `Terminate` or shutting down the stream.
    async fn close(mut self) -> Result<(), Error> {
        self.inner.stream.send(Terminate).await?;
        self.inner.stream.shutdown().await?;
        Ok(())
    }
    /// Closes the connection without sending `Terminate`; only the stream is shut down.
    async fn close_hard(mut self) -> Result<(), Error> {
        self.inner.stream.shutdown().await?;
        Ok(())
    }
    /// Checks that the server still responds by doing a round trip.
    ///
    /// NOTE(review): `write_sync` is defined outside this file; presumably it queues a
    /// Sync message and bumps the pending `ReadyForQuery` count, which
    /// `wait_until_ready` then flushes and waits for — confirm.
    async fn ping(&mut self) -> Result<(), Error> {
        self.write_sync();
        self.wait_until_ready().await
    }
    /// Starts a transaction using the default `BEGIN` handling of `Transaction::begin`.
    fn begin(
        &mut self,
    ) -> impl Future<Output = Result<Transaction<'_, Self::Database>, Error>> + Send + '_ {
        Transaction::begin(self, None)
    }
    /// Starts a transaction using the caller-supplied statement instead of the default.
    fn begin_with(
        &mut self,
        statement: impl SqlSafeStr,
    ) -> impl Future<Output = Result<Transaction<'_, Self::Database>, Error>> + Send + '_
    where
        Self: Sized,
    {
        Transaction::begin(self, Some(statement.into_sql_str()))
    }
    /// Number of prepared statements currently held in the statement cache.
    fn cached_statements_size(&self) -> usize {
        self.inner.cache_statement.len()
    }
    /// Evicts every cached prepared statement and closes it on the server.
    ///
    /// Also clears the `cache_type_oid` map. NOTE(review): that map is cleared before
    /// any I/O, so it stays cleared even if a later step fails. The other type caches
    /// (`cache_type_info`, `cache_elem_type_to_array`) are not touched here.
    async fn clear_cached_statements(&mut self) -> Result<(), Error> {
        self.inner.cache_type_oid.clear();
        let mut cleared = 0_usize;
        // Drain outstanding responses first so the replies read below belong to the
        // Close messages queued here.
        self.wait_until_ready().await?;
        // Queue one Close per evicted statement; nothing is flushed yet.
        while let Some((id, _)) = self.inner.cache_statement.remove_lru() {
            self.inner.stream.write_msg(Close::Statement(id))?;
            cleared += 1;
        }
        if cleared > 0 {
            // Terminate the batch, send it, then consume one reply per Close followed
            // by the trailing ReadyForQuery. `write_sync` and `wait_for_close_complete`
            // are defined outside this file.
            self.write_sync();
            self.inner.stream.flush().await?;
            self.wait_for_close_complete(cleared).await?;
            self.recv_ready_for_query().await?;
        }
        Ok(())
    }
    /// Delegates to the stream to release excess buffer capacity.
    fn shrink_buffers(&mut self) {
        self.inner.stream.shrink_buffers();
    }
    /// Flushes queued writes and drains every pending `ReadyForQuery`
    /// (via `wait_until_ready`).
    #[doc(hidden)]
    fn flush(&mut self) -> impl Future<Output = Result<(), Error>> + Send + '_ {
        self.wait_until_ready()
    }
    /// `true` when the outgoing write buffer holds unsent data.
    #[doc(hidden)]
    fn should_flush(&self) -> bool {
        !self.inner.stream.write_buffer().is_empty()
    }
}
/// Identity conversion, so a bare `PgConnection` satisfies `AsMut<PgConnection>` bounds.
impl AsMut<PgConnection> for PgConnection {
    fn as_mut(&mut self) -> &mut PgConnection {
        // Explicit reborrow of the receiver; same pointer as `self`.
        &mut *self
    }
}