use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use tokio::sync::mpsc;
use pg_wired::protocol::frontend;
use pg_wired::protocol::types::{FormatCode, FrontendMsg, RawRow};
use pg_wired::{AsyncConn, PipelineResponse, ResponseCollector, WireConn};
use crate::encode::SqlParam;
use crate::error::TypedError;
use crate::row::{Row, RowSchema};
use std::sync::Arc;
/// Transaction isolation levels accepted by [`Client::begin_with`].
///
/// `#[non_exhaustive]` so additional levels can be added without a breaking
/// change for downstream matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum IsolationLevel {
    // PostgreSQL's default: each statement sees a fresh snapshot.
    ReadCommitted,
    // One snapshot for the whole transaction.
    RepeatableRead,
    // Strictest level; the server may raise serialization failures.
    Serializable,
}
impl IsolationLevel {
pub fn as_sql(&self) -> &'static str {
match self {
Self::ReadCommitted => "READ COMMITTED",
Self::RepeatableRead => "REPEATABLE READ",
Self::Serializable => "SERIALIZABLE",
}
}
}
/// High-level, typed handle over a single PostgreSQL connection.
///
/// Layers prepared-statement caching, named parameters, transactions,
/// pipelining, streaming, and metrics/tracing on top of the wire-level
/// [`AsyncConn`].
#[derive(Debug)]
pub struct Client {
    // The underlying pipelined connection used by every query path.
    conn: AsyncConn,
}
impl Client {
    /// Default channel capacity (in rows) used by `query_stream`.
    pub const DEFAULT_STREAM_BUFFER: usize = 256;
}
impl Client {
pub fn new(conn: WireConn) -> Self {
Self {
conn: AsyncConn::new(conn),
}
}
pub fn from_async_conn(conn: AsyncConn) -> Self {
Self { conn }
}
pub async fn connect(
addr: &str,
user: &str,
password: &str,
database: &str,
) -> Result<Self, TypedError> {
tracing::debug!(addr = addr, user = user, database = database, "connecting");
let wire = WireConn::connect(addr, user, password, database).await?;
tracing::info!(
addr = addr,
database = database,
pid = wire.pid(),
"connected"
);
Ok(Self::new(wire))
}
pub async fn connect_from_str(connstr: &str) -> Result<Self, TypedError> {
let (user, password, host, port, database) = parse_connection_string(connstr)
.ok_or_else(|| TypedError::Config("invalid connection string".to_string()))?;
let addr = format!("{host}:{port}");
Self::connect(&addr, &user, &password, &database).await
}
pub async fn connect_with_init(
addr: &str,
user: &str,
password: &str,
database: &str,
init_sql: &[&str],
) -> Result<Self, TypedError> {
let client = Self::connect(addr, user, password, database).await?;
for sql in init_sql {
tracing::debug!(sql = sql, "running init SQL");
client.simple_query(sql).await?;
}
Ok(client)
}
pub async fn lookup_type_oid(&self, type_name: &str) -> Result<Option<u32>, TypedError> {
let rows = self
.query(
"SELECT oid::int4 FROM pg_type WHERE typname = $1",
&[&type_name.to_string()],
)
.await?;
if rows.is_empty() {
Ok(None)
} else {
let oid: i32 = rows[0].get(0)?;
Ok(Some(oid as u32))
}
}
pub async fn lookup_type_oids(&self, type_name: &str) -> Result<(u32, u32), TypedError> {
let rows = self
.query(
"SELECT t.oid::int4, COALESCE(t.typarray, 0)::int4 \
FROM pg_type t WHERE t.typname = $1",
&[&type_name.to_string()],
)
.await?;
if rows.is_empty() {
Err(TypedError::Config(format!("type not found: {type_name}")))
} else {
let oid: i32 = rows[0].get(0)?;
let array_oid: i32 = rows[0].get(1)?;
Ok((oid as u32, array_oid as u32))
}
}
pub async fn query(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<Vec<Row>, TypedError> {
let start = std::time::Instant::now();
let result = match self.query_inner(sql, params).await {
Err(TypedError::Wire(ref e)) if is_stale_statement_error(e) => {
tracing::debug!("stale statement detected, re-preparing");
self.conn.invalidate_statement(sql);
self.query_inner(sql, params).await
}
other => other,
};
let elapsed = start.elapsed();
match &result {
Ok(rows) => {
let us = elapsed.as_micros() as u64;
crate::metrics::record_query(us);
tracing::debug!(sql = %truncate_sql(sql), rows = rows.len(), elapsed_us = us, "query ok");
}
Err(ref e) => {
crate::metrics::record_query_error();
tracing::warn!(sql = %truncate_sql(sql), error = %e, elapsed_us = elapsed.as_micros() as u64, "query failed");
}
}
result.map_err(|e| e.with_sql(sql))
}
pub(crate) async fn query_on_conn(
conn: &AsyncConn,
sql: &str,
params: &[&dyn SqlParam],
) -> Result<Vec<Row>, TypedError> {
Self::query_inner_on(conn, sql, params).await
}
pub(crate) async fn execute_on_conn(
conn: &AsyncConn,
sql: &str,
params: &[&dyn SqlParam],
) -> Result<u64, TypedError> {
Self::execute_inner_on(conn, sql, params).await
}
pub(crate) async fn simple_query_on_conn(
conn: &AsyncConn,
sql: &str,
) -> Result<(), TypedError> {
let mut buf = BytesMut::new();
frontend::encode_message(&FrontendMsg::Query(sql.as_bytes()), &mut buf);
let _resp = conn.submit(buf, ResponseCollector::Drain).await?;
Ok(())
}
async fn query_inner(
&self,
sql: &str,
params: &[&dyn SqlParam],
) -> Result<Vec<Row>, TypedError> {
Self::query_inner_on(&self.conn, sql, params).await
}
async fn query_inner_on(
conn: &AsyncConn,
sql: &str,
params: &[&dyn SqlParam],
) -> Result<Vec<Row>, TypedError> {
let mut buf = BytesMut::with_capacity(512);
let param_formats: Vec<FormatCode> = vec![FormatCode::Binary; params.len()];
let result_formats = [FormatCode::Binary];
let mut param_oids: Vec<u32> = Vec::with_capacity(params.len());
let mut param_values: Vec<Option<BytesMut>> = Vec::with_capacity(params.len());
for p in params {
param_oids.push(p.param_oid());
param_values.push(p.encode_param_value());
}
let param_refs: Vec<Option<&[u8]>> = param_values
.iter()
.map(|v| v.as_ref().map(|b| b.as_ref()))
.collect();
let (stmt_name, needs_parse) = conn.lookup_or_alloc(sql, ¶m_oids);
if needs_parse {
frontend::encode_message(
&FrontendMsg::Parse {
name: &stmt_name,
sql: sql.as_bytes(),
param_oids: ¶m_oids,
},
&mut buf,
);
}
frontend::encode_message(
&FrontendMsg::Bind {
portal: b"",
statement: &stmt_name,
param_formats: ¶m_formats,
params: ¶m_refs,
result_formats: &result_formats,
},
&mut buf,
);
frontend::encode_message(
&FrontendMsg::Describe {
kind: b'P', name: b"",
},
&mut buf,
);
frontend::encode_message(
&FrontendMsg::Execute {
portal: b"",
max_rows: 0,
},
&mut buf,
);
frontend::encode_message(&FrontendMsg::Sync, &mut buf);
let resp = conn.submit(buf, ResponseCollector::Rows).await?;
if needs_parse {
conn.cache_statement(sql, &stmt_name);
}
match resp {
PipelineResponse::Rows {
fields,
rows: raw_rows,
command_tag: _,
} => {
let schema = Arc::new(build_row_schema(&fields, raw_rows.first()));
let rows = raw_rows
.into_iter()
.map(|data| Row {
schema: Arc::clone(&schema),
data,
})
.collect();
Ok(rows)
}
PipelineResponse::Done => Ok(Vec::new()),
_ => Ok(Vec::new()),
}
}
pub async fn query_one(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<Row, TypedError> {
let rows = self.query(sql, params).await?;
if rows.len() != 1 {
return Err(TypedError::NotExactlyOne(rows.len()));
}
Ok(rows
.into_iter()
.next()
.expect("already verified rows.len()"))
}
pub async fn query_opt(
&self,
sql: &str,
params: &[&dyn SqlParam],
) -> Result<Option<Row>, TypedError> {
let rows = self.query(sql, params).await?;
match rows.len() {
0 => Ok(None),
1 => Ok(Some(
rows.into_iter()
.next()
.expect("already verified rows.len()"),
)),
n => Err(TypedError::NotExactlyOne(n)),
}
}
pub async fn execute(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<u64, TypedError> {
let start = std::time::Instant::now();
let result = match self.execute_inner(sql, params).await {
Err(TypedError::Wire(ref e)) if is_stale_statement_error(e) => {
tracing::debug!("stale statement detected, re-preparing");
self.conn.invalidate_statement(sql);
self.execute_inner(sql, params).await
}
other => other,
};
let elapsed = start.elapsed();
match &result {
Ok(n) => {
crate::metrics::record_execute();
tracing::debug!(sql = %truncate_sql(sql), affected = n, elapsed_us = elapsed.as_micros() as u64, "execute ok");
}
Err(ref e) => {
crate::metrics::record_execute_error();
tracing::warn!(sql = %truncate_sql(sql), error = %e, elapsed_us = elapsed.as_micros() as u64, "execute failed");
}
}
result.map_err(|e| e.with_sql(sql))
}
async fn execute_inner(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<u64, TypedError> {
Self::execute_inner_on(&self.conn, sql, params).await
}
async fn execute_inner_on(
conn: &AsyncConn,
sql: &str,
params: &[&dyn SqlParam],
) -> Result<u64, TypedError> {
let mut buf = BytesMut::with_capacity(512);
let param_formats: Vec<FormatCode> = vec![FormatCode::Binary; params.len()];
let result_formats = [FormatCode::Binary];
let mut param_oids: Vec<u32> = Vec::with_capacity(params.len());
let mut param_values: Vec<Option<BytesMut>> = Vec::with_capacity(params.len());
for p in params {
param_oids.push(p.param_oid());
param_values.push(p.encode_param_value());
}
let param_refs: Vec<Option<&[u8]>> = param_values
.iter()
.map(|v| v.as_ref().map(|b| b.as_ref()))
.collect();
let (stmt_name, needs_parse) = conn.lookup_or_alloc(sql, ¶m_oids);
if needs_parse {
frontend::encode_message(
&FrontendMsg::Parse {
name: &stmt_name,
sql: sql.as_bytes(),
param_oids: ¶m_oids,
},
&mut buf,
);
}
frontend::encode_message(
&FrontendMsg::Bind {
portal: b"",
statement: &stmt_name,
param_formats: ¶m_formats,
params: ¶m_refs,
result_formats: &result_formats,
},
&mut buf,
);
frontend::encode_message(
&FrontendMsg::Execute {
portal: b"",
max_rows: 0,
},
&mut buf,
);
frontend::encode_message(&FrontendMsg::Sync, &mut buf);
let resp = conn.submit(buf, ResponseCollector::Rows).await?;
if needs_parse {
conn.cache_statement(sql, &stmt_name);
}
match resp {
PipelineResponse::Rows { command_tag, .. } => Ok(parse_row_count(&command_tag)),
PipelineResponse::Done => Ok(0),
_ => Ok(0),
}
}
pub async fn query_stream(
&self,
sql: &str,
params: &[&dyn SqlParam],
) -> Result<RowStream, TypedError> {
let mut buf = BytesMut::with_capacity(512);
let param_formats: Vec<FormatCode> = vec![FormatCode::Binary; params.len()];
let result_formats = [FormatCode::Binary];
let mut param_oids: Vec<u32> = Vec::with_capacity(params.len());
let mut param_values: Vec<Option<BytesMut>> = Vec::with_capacity(params.len());
for p in params {
param_oids.push(p.param_oid());
param_values.push(p.encode_param_value());
}
let param_refs: Vec<Option<&[u8]>> = param_values
.iter()
.map(|v| v.as_ref().map(|b| b.as_ref()))
.collect();
let (stmt_name, needs_parse) = self.conn.lookup_or_alloc(sql, ¶m_oids);
if needs_parse {
frontend::encode_message(
&FrontendMsg::Parse {
name: &stmt_name,
sql: sql.as_bytes(),
param_oids: ¶m_oids,
},
&mut buf,
);
}
frontend::encode_message(
&FrontendMsg::Bind {
portal: b"",
statement: &stmt_name,
param_formats: ¶m_formats,
params: ¶m_refs,
result_formats: &result_formats,
},
&mut buf,
);
frontend::encode_message(
&FrontendMsg::Describe {
kind: b'P',
name: b"",
},
&mut buf,
);
frontend::encode_message(
&FrontendMsg::Execute {
portal: b"",
max_rows: 0,
},
&mut buf,
);
frontend::encode_message(&FrontendMsg::Sync, &mut buf);
let (header, row_rx) = self
.conn
.submit_stream(buf, Self::DEFAULT_STREAM_BUFFER)
.await?;
if needs_parse {
self.conn.cache_statement(sql, &stmt_name);
}
let has_desc = !header.fields.is_empty();
let schema = Arc::new(build_row_schema(&header.fields, None));
Ok(RowStream {
row_rx,
schema,
has_desc,
})
}
pub async fn copy_in(&self, copy_sql: &str, data: &[u8]) -> Result<u64, TypedError> {
self.conn
.copy_in(copy_sql, data)
.await
.map_err(|e| TypedError::from(e).with_sql(copy_sql))
}
pub async fn copy_out(&self, copy_sql: &str) -> Result<Vec<u8>, TypedError> {
self.conn
.copy_out(copy_sql)
.await
.map_err(|e| TypedError::from(e).with_sql(copy_sql))
}
pub async fn begin(&self) -> Result<Transaction<'_>, TypedError> {
self.simple_query("BEGIN").await?;
Ok(Transaction {
client: self,
done: false,
})
}
pub async fn begin_with(&self, level: IsolationLevel) -> Result<Transaction<'_>, TypedError> {
let sql = format!("BEGIN ISOLATION LEVEL {}", level.as_sql());
self.simple_query(&sql).await?;
Ok(Transaction {
client: self,
done: false,
})
}
pub async fn advisory_lock(&self, key: i64) -> Result<(), TypedError> {
self.simple_query(&format!("SELECT pg_advisory_lock({key})"))
.await
}
pub async fn try_advisory_lock(&self, key: i64) -> Result<bool, TypedError> {
let rows = self
.query("SELECT pg_try_advisory_lock($1::int8) AS acquired", &[&key])
.await?;
let row = rows.into_iter().next().ok_or_else(|| TypedError::Decode {
column: 0,
message: "pg_try_advisory_lock returned no rows".into(),
})?;
row.get::<bool>(0)
}
pub async fn advisory_unlock(&self, key: i64) -> Result<bool, TypedError> {
let rows = self
.query("SELECT pg_advisory_unlock($1::int8) AS released", &[&key])
.await?;
let row = rows.into_iter().next().ok_or_else(|| TypedError::Decode {
column: 0,
message: "pg_advisory_unlock returned no rows".into(),
})?;
row.get::<bool>(0)
}
pub async fn advisory_xact_lock(&self, key: i64) -> Result<(), TypedError> {
self.query("SELECT pg_advisory_xact_lock($1::int8)", &[&key])
.await?;
Ok(())
}
pub async fn try_advisory_xact_lock(&self, key: i64) -> Result<bool, TypedError> {
let rows = self
.query(
"SELECT pg_try_advisory_xact_lock($1::int8) AS acquired",
&[&key],
)
.await?;
let row = rows.into_iter().next().ok_or_else(|| TypedError::Decode {
column: 0,
message: "pg_try_advisory_xact_lock returned no rows".into(),
})?;
row.get::<bool>(0)
}
pub async fn simple_query(&self, sql: &str) -> Result<(), TypedError> {
use pg_wired::protocol::types::FrontendMsg;
let mut buf = BytesMut::new();
frontend::encode_message(&FrontendMsg::Query(sql.as_bytes()), &mut buf);
let _resp = self.conn.submit(buf, ResponseCollector::Drain).await?;
Ok(())
}
pub async fn query_named(
&self,
sql: &str,
params: &[(&str, &dyn SqlParam)],
) -> Result<Vec<Row>, TypedError> {
let (rewritten, names) = crate::named_params::rewrite(sql);
let ordered = resolve_named_params(&names, params)?;
self.query(&rewritten, &ordered).await
}
pub async fn execute_named(
&self,
sql: &str,
params: &[(&str, &dyn SqlParam)],
) -> Result<u64, TypedError> {
let (rewritten, names) = crate::named_params::rewrite(sql);
let ordered = resolve_named_params(&names, params)?;
self.execute(&rewritten, &ordered).await
}
pub async fn query_timeout(
&self,
sql: &str,
params: &[&dyn SqlParam],
timeout: std::time::Duration,
) -> Result<Vec<Row>, TypedError> {
let token = self.cancel_token();
match tokio::time::timeout(timeout, self.query(sql, params)).await {
Ok(result) => result,
Err(_elapsed) => {
let _ = token.cancel().await;
Err(TypedError::Timeout(timeout))
}
}
}
pub async fn execute_timeout(
&self,
sql: &str,
params: &[&dyn SqlParam],
timeout: std::time::Duration,
) -> Result<u64, TypedError> {
let token = self.cancel_token();
match tokio::time::timeout(timeout, self.execute(sql, params)).await {
Ok(result) => result,
Err(_elapsed) => {
let _ = token.cancel().await;
Err(TypedError::Timeout(timeout))
}
}
}
pub fn cancel_token(&self) -> pg_wired::CancelToken {
self.conn.cancel_token()
}
pub async fn ping(&self) -> Result<(), TypedError> {
self.query("SELECT 1", &[]).await?;
Ok(())
}
pub fn is_alive(&self) -> bool {
self.conn.is_alive()
}
pub fn conn(&self) -> &AsyncConn {
&self.conn
}
}
/// An in-progress transaction borrowing a [`Client`].
///
/// Finish it with `commit` or `rollback`; if dropped unfinished, the `Drop`
/// impl attempts to queue a ROLLBACK on the connection.
#[derive(Debug)]
pub struct Transaction<'a> {
    pub(crate) client: &'a Client,
    // True once COMMIT/ROLLBACK has been issued, so Drop does nothing.
    pub(crate) done: bool,
}
impl<'a> Transaction<'a> {
    /// Runs a query on the transaction's connection (delegates to the client).
    pub async fn query(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<Vec<Row>, TypedError> {
        self.client.query(sql, params).await
    }

    /// Runs a statement on the transaction's connection.
    pub async fn execute(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<u64, TypedError> {
        self.client.execute(sql, params).await
    }

    /// Issues COMMIT and consumes the transaction.
    ///
    /// `done` is set before awaiting so that, even if COMMIT errors, the
    /// Drop impl will not queue an additional ROLLBACK afterwards.
    pub async fn commit(mut self) -> Result<(), TypedError> {
        self.done = true;
        self.client.simple_query("COMMIT").await
    }

    /// Issues ROLLBACK and consumes the transaction.
    // Same `done` ordering rationale as `commit`.
    pub async fn rollback(mut self) -> Result<(), TypedError> {
        self.done = true;
        self.client.simple_query("ROLLBACK").await
    }

    /// Named-parameter query; see [`Client::query_named`].
    pub async fn query_named(
        &self,
        sql: &str,
        params: &[(&str, &dyn SqlParam)],
    ) -> Result<Vec<Row>, TypedError> {
        self.client.query_named(sql, params).await
    }

    /// Named-parameter execute; see [`Client::execute_named`].
    pub async fn execute_named(
        &self,
        sql: &str,
        params: &[(&str, &dyn SqlParam)],
    ) -> Result<u64, TypedError> {
        self.client.execute_named(sql, params).await
    }
}
impl<'a> Drop for Transaction<'a> {
    // Safety net for transactions dropped without commit()/rollback():
    // best-effort enqueue a ROLLBACK so the connection does not stay inside
    // an open transaction. If even queueing fails, the connection is marked
    // broken so it will not be reused in an unknown state.
    fn drop(&mut self) {
        if !self.done && self.client.is_alive() {
            if !self.client.conn().enqueue_rollback() {
                self.client.conn().mark_broken();
                tracing::warn!(
                    "Transaction dropped without commit/rollback; could not queue ROLLBACK, connection marked broken"
                );
            }
        }
    }
}
/// Orders `params` to match the placeholder order in `names`.
///
/// For each name produced by the SQL rewriter, finds the matching
/// `(name, value)` pair; errors with `TypedError::MissingParam` on the first
/// name that has no supplied value. Extra supplied params are ignored.
pub(crate) fn resolve_named_params<'a>(
    names: &[String],
    params: &[(&str, &'a dyn SqlParam)],
) -> Result<Vec<&'a dyn SqlParam>, TypedError> {
    let mut ordered: Vec<&'a dyn SqlParam> = Vec::with_capacity(names.len());
    for name in names {
        let matched = params.iter().find(|(n, _)| *n == name.as_str());
        match matched {
            Some((_, value)) => ordered.push(*value),
            None => return Err(TypedError::MissingParam(name.to_string())),
        }
    }
    Ok(ordered)
}
/// Builder for pipelined execution: statements are queued locally and sent
/// to the server as one batch when `run()` is awaited.
#[derive(Debug)]
#[must_use = "Pipeline does nothing until .run() is awaited"]
pub struct Pipeline<'a> {
    client: &'a Client,
    // One fully encoded Parse/Bind/.../Sync buffer per queued statement.
    buffers: Vec<BytesMut>,
    // (sql, statement name) pairs to add to the prepared-statement cache
    // after the batch has been submitted successfully.
    pending_cache: Vec<(String, Vec<u8>)>,
}
/// Outcome of one statement in a pipeline batch.
#[non_exhaustive]
pub enum PipelineResult {
    // The statement produced a result set.
    Rows(Vec<Row>),
    // The statement produced only an affected-row count.
    Execute(u64),
}
impl<'a> Pipeline<'a> {
pub fn query(mut self, sql: &str, params: &[&dyn SqlParam]) -> Self {
self.encode_query(sql, params);
self
}
pub fn execute(mut self, sql: &str, params: &[&dyn SqlParam]) -> Self {
self.encode_query(sql, params);
self
}
fn encode_query(&mut self, sql: &str, params: &[&dyn SqlParam]) {
let param_formats: Vec<FormatCode> = vec![FormatCode::Binary; params.len()];
let result_formats = [FormatCode::Binary];
let mut param_oids: Vec<u32> = Vec::with_capacity(params.len());
let mut param_values: Vec<Option<BytesMut>> = Vec::with_capacity(params.len());
for p in params {
param_oids.push(p.param_oid());
param_values.push(p.encode_param_value());
}
let param_refs: Vec<Option<&[u8]>> = param_values
.iter()
.map(|v| v.as_ref().map(|b| b.as_ref()))
.collect();
let (stmt_name, needs_parse) = self.client.conn.lookup_or_alloc(sql, ¶m_oids);
let mut buf = BytesMut::with_capacity(256);
if needs_parse {
frontend::encode_message(
&FrontendMsg::Parse {
name: &stmt_name,
sql: sql.as_bytes(),
param_oids: ¶m_oids,
},
&mut buf,
);
}
frontend::encode_message(
&FrontendMsg::Bind {
portal: b"",
statement: &stmt_name,
param_formats: ¶m_formats,
params: ¶m_refs,
result_formats: &result_formats,
},
&mut buf,
);
frontend::encode_message(
&FrontendMsg::Describe {
kind: b'P',
name: b"",
},
&mut buf,
);
frontend::encode_message(
&FrontendMsg::Execute {
portal: b"",
max_rows: 0,
},
&mut buf,
);
frontend::encode_message(&FrontendMsg::Sync, &mut buf);
self.buffers.push(buf);
if needs_parse {
self.pending_cache
.push((sql.to_string(), stmt_name.clone()));
}
}
pub async fn run(self) -> Result<Vec<PipelineResult>, TypedError> {
if self.buffers.is_empty() {
return Ok(Vec::new());
}
let pending_cache = self.pending_cache;
let items: Vec<(BytesMut, ResponseCollector)> = self
.buffers
.into_iter()
.map(|b| (b, ResponseCollector::Rows))
.collect();
let responses = self.client.conn.submit_batch(items).await?;
for (sql, name) in &pending_cache {
self.client.conn.cache_statement(sql, name);
}
let mut results = Vec::with_capacity(responses.len());
for resp in responses {
match resp? {
PipelineResponse::Rows {
fields,
rows,
command_tag,
} => {
if rows.is_empty() && !command_tag.is_empty() {
results.push(PipelineResult::Execute(parse_row_count(&command_tag)));
} else {
let schema = Arc::new(build_row_schema(&fields, rows.first()));
let typed_rows = rows
.into_iter()
.map(|data| Row {
schema: Arc::clone(&schema),
data,
})
.collect();
results.push(PipelineResult::Rows(typed_rows));
}
}
PipelineResponse::Done => {
results.push(PipelineResult::Execute(0));
}
_ => {
results.push(PipelineResult::Execute(0));
}
}
}
Ok(results)
}
}
impl Client {
    /// Starts an empty [`Pipeline`] builder borrowing this client.
    pub fn pipeline(&self) -> Pipeline<'_> {
        Pipeline {
            client: self,
            buffers: Vec::new(),
            pending_cache: Vec::new(),
        }
    }
}
/// Incremental row stream returned by [`Client::query_stream`].
#[derive(Debug)]
pub struct RowStream {
    // Receives raw rows (or wire errors) from the connection task.
    row_rx: mpsc::Receiver<Result<RawRow, pg_wired::PgWireError>>,
    // Column schema shared by every yielded Row.
    schema: Arc<RowSchema>,
    // Whether a usable schema is in place (from a RowDescription, or
    // inferred lazily from the first row when none was sent).
    has_desc: bool,
}
impl tokio_stream::Stream for RowStream {
    type Item = Result<Row, TypedError>;
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.row_rx.poll_recv(cx) {
            Poll::Ready(Some(Ok(data))) => {
                // No RowDescription was seen up front: synthesize a schema
                // from the first row whose width disagrees with the current
                // (empty) one. Format 1 here mirrors the binary result
                // format requested at encode time.
                if !self.has_desc && self.schema.formats.len() != data.len() {
                    let mut s = RowSchema::empty();
                    s.formats = vec![1i16; data.len()];
                    self.schema = Arc::new(s);
                    self.has_desc = true;
                }
                let row = Row {
                    schema: Arc::clone(&self.schema),
                    data,
                };
                Poll::Ready(Some(Ok(row)))
            }
            // Wire errors are surfaced as stream items, not panics.
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
            // Channel closed => end of stream.
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Builds the shared [`RowSchema`] for a result set.
///
/// When field descriptions are available they define names, type OIDs and
/// formats. Without them, a nameless all-binary (format 1) schema is
/// inferred from the first row's width; with neither, the schema is empty.
fn build_row_schema(
    fields: &[pg_wired::protocol::types::FieldDescription],
    first_row: Option<&RawRow>,
) -> RowSchema {
    if fields.is_empty() {
        return match first_row {
            Some(row) => {
                let mut schema = RowSchema::empty();
                schema.formats = vec![1i16; row.len()];
                schema
            }
            None => RowSchema::empty(),
        };
    }
    let mut columns = Vec::with_capacity(fields.len());
    let mut type_oids = Vec::with_capacity(fields.len());
    let mut formats = Vec::with_capacity(fields.len());
    for field in fields {
        columns.push(field.name.clone());
        type_oids.push(field.type_oid);
        formats.push(field.format as i16);
    }
    RowSchema {
        columns,
        type_oids,
        formats,
    }
}
/// Shortens `sql` for log output to at most 100 bytes plus "...".
///
/// The cut point is backed off to the nearest UTF-8 character boundary so
/// that slicing never panics when byte 100 falls inside a multi-byte
/// character (e.g. non-ASCII identifiers or literals in the SQL text).
fn truncate_sql(sql: &str) -> String {
    const MAX_LEN: usize = 100;
    if sql.len() <= MAX_LEN {
        sql.to_string()
    } else {
        // `&sql[..MAX_LEN]` would panic off a char boundary; back off first.
        let mut end = MAX_LEN;
        while !sql.is_char_boundary(end) {
            end -= 1;
        }
        format!("{}...", &sql[..end])
    }
}
/// True when a server error means a cached prepared statement must be
/// re-parsed: SQLSTATE "26000" (invalid SQL statement name) or "0A000"
/// (as raised for cached-plan/result-shape changes).
fn is_stale_statement_error(e: &pg_wired::PgWireError) -> bool {
    match e {
        pg_wired::PgWireError::Pg(pg_err) => {
            matches!(pg_err.code.as_str(), "26000" | "0A000")
        }
        _ => false,
    }
}
/// Extracts the affected-row count from a CommandComplete tag such as
/// "INSERT 0 5" or "UPDATE 3" (the count is the last space-separated
/// token); returns 0 for tags without a numeric count (e.g. "BEGIN").
fn parse_row_count(tag: &str) -> u64 {
    match tag.rsplit_once(' ') {
        Some((_, count)) => count.parse::<u64>().unwrap_or(0),
        None => 0,
    }
}
#[doc(hidden)]
/// Parses either a `postgres://`/`postgresql://` URI or a libpq-style
/// key/value connection string into `(user, password, host, port, dbname)`.
pub fn parse_connection_string(s: &str) -> Option<(String, String, String, u16, String)> {
    let looks_like_uri = s.starts_with("postgres://") || s.starts_with("postgresql://");
    if looks_like_uri {
        parse_pg_uri(s)
    } else {
        parse_pg_keyvalue(s)
    }
}
/// Parses a `postgres://user:pass@host:port/db` URI into its parts.
///
/// Missing pieces fall back to defaults: user/password "postgres", port
/// 5432, database "postgres". Any `?query` suffix is ignored. User,
/// password, and database are percent-decoded.
/// NOTE(review): credentials split at the *first* '@' and host:port at the
/// first ':', so unencoded '@' in passwords and bracketed IPv6 hosts are
/// not handled — confirm whether callers need those.
fn parse_pg_uri(uri: &str) -> Option<(String, String, String, u16, String)> {
    let stripped = uri
        .strip_prefix("postgres://")
        .or_else(|| uri.strip_prefix("postgresql://"))?;
    // Discard everything from the first '?' on; options are not interpreted.
    let core = match stripped.split_once('?') {
        Some((before_query, _)) => before_query,
        None => stripped,
    };
    let (auth, host_and_db) = match core.split_once('@') {
        Some((a, rest)) => (a, rest),
        None => ("postgres:postgres", core),
    };
    let (user, password) = auth.split_once(':').unwrap_or((auth, ""));
    let (host_port, database) = host_and_db
        .split_once('/')
        .unwrap_or((host_and_db, "postgres"));
    let (host, port_text) = host_port.split_once(':').unwrap_or((host_port, "5432"));
    let port: u16 = port_text.parse().unwrap_or(5432);
    Some((
        url_decode(user),
        url_decode(password),
        host.to_string(),
        port,
        url_decode(database),
    ))
}
/// Parses a libpq-style key/value connection string, e.g.
/// `host=db port=5433 user=u password=p dbname=d`.
///
/// Unknown keys are ignored; absent keys receive defaults (127.0.0.1:5432,
/// user and dbname "postgres", empty password). Always returns `Some`.
fn parse_pg_keyvalue(s: &str) -> Option<(String, String, String, u16, String)> {
    let mut host = String::from("127.0.0.1");
    let mut port: u16 = 5432;
    let mut user = String::from("postgres");
    let mut password = String::new();
    let mut dbname = String::from("postgres");
    let pairs = s.split_whitespace().filter_map(|part| part.split_once('='));
    for (key, value) in pairs {
        match key {
            // "hostaddr" is accepted as an alias for "host".
            "host" | "hostaddr" => host = value.to_string(),
            "port" => port = value.parse().unwrap_or(5432),
            "user" => user = value.to_string(),
            "password" => password = value.to_string(),
            "dbname" => dbname = value.to_string(),
            _ => {}
        }
    }
    Some((user, password, host, port, dbname))
}
/// Percent-decodes `s` (as used in connection-string URIs).
///
/// Decoding is done into a byte buffer first: percent escapes address raw
/// bytes, so multi-byte UTF-8 sequences (e.g. "%C3%A9" -> "é") must be
/// reassembled before converting back to a string. The previous
/// implementation pushed each decoded byte `as char`, which mojibaked any
/// non-ASCII escape (and any raw non-ASCII input byte).
///
/// Invalid escapes ("%ZZ", a trailing "%") are kept verbatim; byte
/// sequences that are not valid UTF-8 after decoding are replaced with
/// U+FFFD rather than panicking.
fn url_decode(s: &str) -> String {
    let mut out: Vec<u8> = Vec::with_capacity(s.len());
    let bytes = s.as_bytes();
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            if let (Some(hi), Some(lo)) = (hex_digit(bytes[i + 1]), hex_digit(bytes[i + 2])) {
                out.push(hi << 4 | lo);
                i += 3;
                continue;
            }
        }
        // Not a complete, valid escape: keep the byte as-is.
        out.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}

/// Maps an ASCII hex digit to its value (0-15); `None` for anything else.
fn hex_digit(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Plain pass-through and single-escape decoding.
    #[test]
    fn test_url_decode_basic() {
        assert_eq!(url_decode("hello"), "hello");
        assert_eq!(url_decode("hello%20world"), "hello world");
        assert_eq!(url_decode("p%40ss"), "p@ss");
        assert_eq!(url_decode("%23hash"), "#hash");
    }

    // Escapes commonly found in passwords, including a literal '%'.
    #[test]
    fn test_url_decode_password_with_special_chars() {
        assert_eq!(url_decode("p%40ssw%23rd"), "p@ssw#rd");
        assert_eq!(url_decode("100%25done"), "100%done");
    }

    // Truncated or non-hex escapes must pass through unchanged.
    #[test]
    fn test_url_decode_invalid_sequences() {
        assert_eq!(url_decode("abc%2"), "abc%2");
        assert_eq!(url_decode("abc%"), "abc%");
        assert_eq!(url_decode("abc%ZZ"), "abc%ZZ");
    }

    #[test]
    fn test_url_decode_empty() {
        assert_eq!(url_decode(""), "");
    }

    // URI form with a percent-encoded password.
    #[test]
    fn test_parse_connection_string_with_encoded_password() {
        let (user, pass, host, port, db) =
            parse_connection_string("postgres://user:p%40ss@localhost:5432/mydb").unwrap();
        assert_eq!(user, "user");
        assert_eq!(pass, "p@ss");
        assert_eq!(host, "localhost");
        assert_eq!(port, 5432);
        assert_eq!(db, "mydb");
    }

    // libpq-style key/value form.
    #[test]
    fn test_parse_connection_string_keyvalue() {
        let (user, pass, host, port, db) = parse_connection_string(
            "host=db.example.com port=5433 dbname=prod user=admin password=secret",
        )
        .unwrap();
        assert_eq!(user, "admin");
        assert_eq!(pass, "secret");
        assert_eq!(host, "db.example.com");
        assert_eq!(port, 5433);
        assert_eq!(db, "prod");
    }

    // URI with no credentials/port falls back to defaults.
    #[test]
    fn test_parse_connection_string_defaults() {
        let (user, pass, host, port, db) =
            parse_connection_string("postgres://localhost/mydb").unwrap();
        assert_eq!(user, "postgres");
        assert_eq!(pass, "postgres");
        assert_eq!(host, "localhost");
        assert_eq!(port, 5432);
        assert_eq!(db, "mydb");
    }
}