use super::{
pipeline, startup, BackendMessage, BytesMut, CancelToken, Config, Connection, Duration, Error,
PgConnection, PipelineBatch, Result, StatementCache, ToSql, TransactionStatus,
};
use crate::config::{LoadBalanceHosts, TargetSessionAttrs};
impl Connection {
    /// Opens a connection to the first host in `config` that accepts it.
    ///
    /// Hosts are tried in the order returned by [`Config::hosts`], or in a random
    /// order when `load_balance_hosts` is [`LoadBalanceHosts::Random`]. If the
    /// configuration lists no hosts, `localhost:5432` is tried. For each host the
    /// startup handshake runs and, unless `target_session_attrs` is
    /// [`TargetSessionAttrs::Any`], the session attributes are checked; the first
    /// host that passes every step is returned.
    ///
    /// # Errors
    ///
    /// When every host fails, returns the error produced by the last host tried.
    /// Failures of earlier hosts are only reported through `tracing` at debug level.
    pub async fn connect(config: Config) -> Result<Self> {
        let mut hosts: Vec<(String, u16)> = config.hosts().to_vec();
        if hosts.is_empty() {
            hosts.push(("localhost".to_string(), 5432));
        }
        if config.load_balance_hosts() == LoadBalanceHosts::Random {
            use rand::seq::SliceRandom;
            use rand::thread_rng;
            hosts.shuffle(&mut thread_rng());
        }
        let mut last_error: Option<Error> = None;
        for (host, port) in &hosts {
            match Self::try_connect_host(&config, host, *port).await {
                Ok(conn) => return Ok(conn),
                Err(e) => {
                    tracing::debug!(host = %host, port = %port, error = %e, "host failed");
                    last_error = Some(e);
                }
            }
        }
        // `hosts` is never empty at this point (see the localhost default above), so
        // the loop always records an error before reaching here; the fallback is purely
        // defensive in case that invariant is ever changed.
        Err(last_error.unwrap_or_else(|| Error::AllHostsFailed("no hosts configured".to_string())))
    }

    /// Connects to a single `host:port`, performs startup and, if required, the
    /// `target_session_attrs` check, then builds the [`Connection`].
    ///
    /// # Errors
    ///
    /// Propagates errors from the transport connect, the startup handshake, or the
    /// session-attribute check.
    async fn try_connect_host(config: &Config, host: &str, port: u16) -> Result<Self> {
        let mut conn = PgConnection::connect_host(config, host, port).await?;
        let result = startup::startup(&mut conn, config).await?;
        if config.target_session_attrs() != TargetSessionAttrs::Any {
            if let Err(e) =
                startup::check_session_attrs(&mut conn, config.target_session_attrs()).await
            {
                // Startup succeeded, so the session is fully established: close it
                // properly instead of just dropping the socket. The close outcome is
                // deliberately ignored because the rejection error is what matters here.
                let _ = conn.close().await;
                return Err(e);
            }
        }
        let query_timeout = config.statement_timeout();
        Ok(Self {
            conn,
            config: config.clone(),
            connected_host: host.to_string(),
            connected_port: port,
            process_id: result.process_id,
            secret_key: result.secret_key,
            transaction_status: result.transaction_status,
            stmt_cache: StatementCache::new(),
            query_timeout,
            is_broken: false,
        })
    }

    /// Closes the connection, consuming it.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the underlying connection while closing.
    pub async fn close(self) -> Result<()> {
        self.conn.close().await
    }

    /// Returns a [`CancelToken`] aimed at the host and port this connection is
    /// attached to, carrying the backend process id and secret key received at startup.
    pub fn cancel_token(&self) -> CancelToken {
        CancelToken::new(
            &self.connected_host,
            self.connected_port,
            self.process_id,
            self.secret_key,
        )
    }

    /// Returns the configuration this connection was opened with.
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Returns the host this connection actually connected to.
    pub fn connected_host(&self) -> &str {
        &self.connected_host
    }

    /// Returns the port this connection actually connected to.
    pub fn connected_port(&self) -> u16 {
        self.connected_port
    }

    /// Returns whether the underlying transport reports TLS.
    pub fn is_tls(&self) -> bool {
        self.conn.is_tls()
    }

    /// Returns whether the underlying transport reports a Unix-domain socket.
    #[cfg(unix)]
    pub fn is_unix(&self) -> bool {
        self.conn.is_unix()
    }

    /// Returns the backend process id received during startup.
    pub fn process_id(&self) -> i32 {
        self.process_id
    }

    /// Returns the query timeout, taken from the config's `statement_timeout` at
    /// connect time.
    pub fn query_timeout(&self) -> Option<Duration> {
        self.query_timeout
    }

    /// Returns whether this connection has been flagged as broken.
    ///
    /// NOTE(review): this impl only ever initializes the flag to `false`; presumably
    /// other modules set it on fatal errors — verify.
    pub fn is_broken(&self) -> bool {
        self.is_broken
    }

    /// Returns the transaction status as last recorded on this connection (set at
    /// startup and refreshed by `drain_until_ready` within this impl).
    pub fn transaction_status(&self) -> TransactionStatus {
        self.transaction_status
    }

    /// Crate-internal mutable access to the underlying protocol connection.
    pub(crate) fn pg_connection_mut(&mut self) -> &mut PgConnection {
        &mut self.conn
    }

    /// Runs `sql` with `params` as a single-statement pipeline batch and returns
    /// that statement's result.
    ///
    /// Each parameter's type is sent as its OID; `NULL` parameters are sent without
    /// a value and all others are encoded with [`ToSql::to_sql`].
    ///
    /// NOTE(review): `self.transaction_status` is not refreshed here; confirm that
    /// `PipelineBatch::execute` (or the caller) keeps it in sync.
    ///
    /// # Errors
    ///
    /// Returns a parameter-encoding error, any error from executing the batch, or a
    /// protocol error if the batch produced no results.
    pub(crate) async fn query_internal(
        &mut self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<pipeline::QueryResult> {
        let param_types: Vec<u32> = params.iter().map(|p| p.oid().0).collect();
        let mut encoded_params: Vec<Option<Vec<u8>>> = Vec::with_capacity(params.len());
        for param in params {
            if param.is_null() {
                encoded_params.push(None);
            } else {
                let mut buf = BytesMut::new();
                param.to_sql(&mut buf)?;
                encoded_params.push(Some(buf.to_vec()));
            }
        }
        let mut batch = PipelineBatch::new();
        batch.add(sql.to_string(), param_types, encoded_params);
        let mut results = batch.execute(&mut self.conn).await?;
        results
            .pop()
            .ok_or_else(|| Error::protocol("pipeline returned no results"))
    }

    /// Reads and discards backend messages until `ReadyForQuery` arrives, then
    /// records the transaction status it carries.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned while receiving a message; in that case
    /// the transaction status is left unchanged.
    pub(crate) async fn drain_until_ready(&mut self) -> Result<()> {
        loop {
            if let BackendMessage::ReadyForQuery { transaction_status } = self.conn.recv().await? {
                self.transaction_status = transaction_status;
                return Ok(());
            }
        }
    }
}