use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::ptr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use quex_pq_sys as ffi;
use tokio::io::unix::AsyncFd;
use super::error::{Error, ExecuteResult, Result};
use super::options::ConnectOptions;
use super::rows::{Metadata, ResultSet};
use super::runtime::{
CONNECTION_POISONED, CONNECTION_READY, ConnHandle, ConnectParams, ConnectionOpGuard,
ParamScratch, SocketRef, execute_prepared_no_params, execute_result_from_handle,
finish_request, prepare_named_statement, send_query, should_prepare_query, wait_for_socket,
};
use super::statement::{CachedStatement, Statement};
/// Local mirror of the public fields of libpq's notification record, as returned by `PQnotifies`.
///
/// NOTE(review): libpq's own `PGnotify` declares a further private `next` link after these
/// fields. Mirroring only the three public fields is enough while the struct is only ever read
/// through a pointer handed out by libpq and released with `PQfreemem` — confirm against
/// `libpq-fe.h` before ever constructing or copying this type by value.
#[repr(C)]
struct PgNotify {
/// NUL-terminated channel name (read via `c_string_lossy`).
relname: *mut libc::c_char,
/// Process id reported with the notification.
be_pid: libc::c_int,
/// NUL-terminated payload string; a null pointer is treated as an empty payload.
extra: *mut libc::c_char,
}
// Hand-declared libpq entry points used for notification delivery.
// NOTE(review): presumably declared here because `quex_pq_sys` does not export them — confirm.
unsafe extern "C" {
/// Returns the next pending notification for `conn`, or null when none is queued
/// (the caller in this file treats null as "nothing pending").
fn PQnotifies(conn: *mut ffi::PGconn) -> *mut PgNotify;
/// Releases memory that libpq allocated; used here for the record returned by `PQnotifies`.
fn PQfreemem(ptr: *mut libc::c_void);
}
/// Cached state for one SQL text executed through `Connection::query`.
struct QueryCacheEntry {
/// The SQL text converted once to a NUL-terminated string for libpq.
sql: CString,
/// How the text is executed and the result metadata captured so far.
kind: QueryCacheKind,
}
/// Execution strategy chosen for a cached query (decided by `should_prepare_query`).
enum QueryCacheKind {
/// The text is sent as-is with `send_query` on every call.
Simple {
/// Result metadata captured from the first result; `None` until then.
metadata: Option<Arc<Metadata>>,
},
/// The text is prepared once under a generated name and then executed by name.
Prepared {
/// Generated statement name (`quex_driver_query_<id>`).
name: CString,
/// Result metadata captured from the first result; `None` until then.
/// `Connection::query` also uses `None` as the signal that the statement still
/// has to be prepared.
metadata: Option<Arc<Metadata>>,
},
}
/// Cache record for a statement created by `Connection::prepare_cached`.
pub(crate) struct CachedStmtEntry {
/// Generated statement name (`quex_driver_stmt_<id>`).
pub(crate) name: CString,
/// Result metadata for the statement; starts out as `None`.
pub(crate) result_metadata: Option<Arc<Metadata>>,
/// NOTE(review): presumably reusable parameter-binding buffers — confirm in `runtime`.
pub(crate) scratch: ParamScratch,
}
/// A notification received on a channel this connection listens to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Notification {
/// Channel name, decoded lossily as UTF-8.
pub channel: String,
/// Process id reported by libpq alongside the notification.
pub process_id: i32,
/// Payload text, decoded lossily as UTF-8; empty when libpq supplies no payload pointer.
pub payload: String,
}
/// A single libpq-backed PostgreSQL connection driven through Tokio socket readiness.
pub struct Connection {
/// Handle to the underlying `PGconn`; released with `PQfinish` when the connection is dropped.
pub(crate) conn: ConnHandle,
/// Readiness registration for the descriptor reported by `PQsocket` at connect time.
pub(crate) socket: AsyncFd<SocketRef>,
/// Shared health flag; `CONNECTION_POISONED` makes `ensure_ready` reject further operations.
pub(crate) state: Arc<AtomicU8>,
/// Per-SQL-text cache used by `query` (converted SQL, execution strategy, result metadata).
query_cache: HashMap<Box<str>, QueryCacheEntry>,
/// Statements created by `prepare_cached`, keyed by SQL text.
pub(crate) statement_cache: HashMap<Box<str>, CachedStmtEntry>,
/// Counter used to build unique names for generated prepared statements.
next_statement_id: u64,
}
impl Connection {
/// Opens a connection described by `options` using libpq's non-blocking connect sequence.
///
/// The connection is started with `PQconnectStartParams`, switched to non-blocking mode, and
/// then driven with `PQconnectPoll`, waiting for socket readiness between polls.
///
/// # Errors
///
/// Returns an error when the parameters cannot be built, libpq returns a null handle,
/// non-blocking mode cannot be enabled, no valid socket is exposed, the socket cannot be
/// registered with the runtime, waiting on the socket fails, or the handshake itself fails.
/// On every error path the libpq handle is released exactly once.
pub async fn connect(options: ConnectOptions) -> Result<Self> {
    // SAFETY: every libpq call below receives the handle that was checked for null right after
    // `PQconnectStartParams`. Until `Self` exists the handle is finished explicitly on failure;
    // afterwards `Connection`'s `Drop` takes care of it.
    unsafe {
        let params = ConnectParams::new(&options)?;
        let raw = ffi::PQconnectStartParams(params.keywords.as_ptr(), params.values.as_ptr(), 0);
        let conn = ConnHandle(
            std::ptr::NonNull::new(raw)
                .ok_or_else(|| Error::new("PQconnectStartParams returned null"))?,
        );
        if ffi::PQsetnonblocking(conn.as_ptr(), 1) != 0 {
            let error = Error::from_conn(conn.as_ptr(), "PQsetnonblocking failed");
            ffi::PQfinish(conn.as_ptr());
            return Err(error);
        }
        let socket_fd = ffi::PQsocket(conn.as_ptr());
        if socket_fd < 0 {
            let error = Error::from_conn(conn.as_ptr(), "libpq did not expose a valid socket");
            ffi::PQfinish(conn.as_ptr());
            return Err(error);
        }
        // Registration can fail; the handle is not owned by a `Connection` yet, so it has to be
        // finished here or it would leak.
        let socket = match AsyncFd::new(SocketRef(socket_fd)) {
            Ok(socket) => socket,
            Err(error) => {
                ffi::PQfinish(conn.as_ptr());
                return Err(error.into());
            }
        };
        // From here on, dropping `connection` finishes the handle.
        let connection = Self {
            conn,
            socket,
            state: Arc::new(AtomicU8::new(CONNECTION_READY)),
            query_cache: HashMap::new(),
            statement_cache: HashMap::new(),
            next_statement_id: 1,
        };
        // NOTE(review): the libpq documentation warns that the socket may change between
        // `PQconnectPoll` calls (for example with multiple candidate hosts); the descriptor is
        // registered only once above — confirm whether that case must be supported.
        loop {
            match ffi::PQconnectPoll(connection.conn.as_ptr()) {
                x if x == ffi::PostgresPollingStatusType_PGRES_POLLING_OK => break,
                x if x == ffi::PostgresPollingStatusType_PGRES_POLLING_READING => {
                    connection.wait_readable().await?;
                }
                x if x == ffi::PostgresPollingStatusType_PGRES_POLLING_WRITING => {
                    connection.wait_writable().await?;
                }
                x if x == ffi::PostgresPollingStatusType_PGRES_POLLING_ACTIVE => continue,
                _ => {
                    // The message is read before `connection` is dropped on return.
                    return Err(Error::from_conn(
                        connection.conn.as_ptr(),
                        "connection failed",
                    ));
                }
            }
        }
        Ok(connection)
    }
}
pub async fn query(&mut self, sql_text: &str) -> Result<ResultSet> {
self.ensure_ready()?;
let mut guard = ConnectionOpGuard::new(&self.state);
if !self.query_cache.contains_key(sql_text) {
let kind = if should_prepare_query(sql_text) {
let statement_name = format!("quex_driver_query_{}", self.next_statement_id);
self.next_statement_id += 1;
QueryCacheKind::Prepared {
name: CString::new(statement_name).expect("statement name contains no nul"),
metadata: None,
}
} else {
QueryCacheKind::Simple { metadata: None }
};
self.query_cache.insert(
sql_text.into(),
QueryCacheEntry {
sql: CString::new(sql_text).expect("query contains no nul byte"),
kind,
},
);
}
let entry = self
.query_cache
.get_mut(sql_text)
.expect("query cache entry missing");
let (result, metadata) = match &mut entry.kind {
QueryCacheKind::Simple { metadata } => {
send_query(self.conn, entry.sql.as_ptr())?;
let result = finish_request(self.conn, &self.socket).await?;
let metadata = match metadata {
Some(metadata) => Arc::clone(metadata),
None => {
let new_metadata = Arc::new(Metadata::from_result(result));
*metadata = Some(Arc::clone(&new_metadata));
new_metadata
}
};
(result, metadata)
}
QueryCacheKind::Prepared { name, metadata } => {
if metadata.is_none() {
prepare_named_statement(self.conn, &self.socket, name, &entry.sql).await?;
}
let result =
execute_prepared_no_params(self.conn, &self.socket, &self.state, name).await?;
let metadata = match metadata {
Some(metadata) => Arc::clone(metadata),
None => {
let new_metadata = Arc::new(Metadata::from_result(result));
*metadata = Some(Arc::clone(&new_metadata));
new_metadata
}
};
(result, metadata)
}
};
guard.complete();
Ok(ResultSet::new(result, metadata))
}
pub async fn prepare(&mut self, sql: &str) -> Result<Statement<'_>> {
self.ensure_ready()?;
let mut guard = ConnectionOpGuard::new(&self.state);
let name = CString::new("")?;
let sql = CString::new(sql)?;
unsafe {
if ffi::PQsendPrepare(
self.conn.as_ptr(),
name.as_ptr(),
sql.as_ptr(),
0,
ptr::null(),
) == 0
{
return Err(Error::from_conn(self.conn.as_ptr(), "PQsendPrepare failed"));
}
}
let result = finish_request(self.conn, &self.socket).await?;
let status = unsafe { ffi::PQresultStatus(result.as_ptr()) };
if status != ffi::ExecStatusType_PGRES_COMMAND_OK {
let error = unsafe { Error::from_result(result.as_ptr(), "prepare failed") };
unsafe { ffi::PQclear(result.as_ptr()) };
return Err(error);
}
unsafe { ffi::PQclear(result.as_ptr()) };
guard.complete();
Ok(Statement {
conn: self,
name,
result_metadata: None,
scratch: ParamScratch::new(),
})
}
/// Returns a handle to a named statement for `sql`, preparing it on first use.
///
/// Statements are cached by SQL text; a cache miss prepares the text under a generated
/// `quex_driver_stmt_<id>` name and records it in `statement_cache`.
///
/// # Errors
///
/// Returns an error if the connection is poisoned, if `sql` contains an interior NUL byte, or
/// if preparing the statement fails.
pub async fn prepare_cached(&mut self, sql: &str) -> Result<CachedStatement<'_>> {
    self.ensure_ready()?;
    let mut guard = ConnectionOpGuard::new(&self.state);
    let already_cached = self.statement_cache.contains_key(sql);
    if !already_cached {
        let id = self.next_statement_id;
        self.next_statement_id += 1;
        // The generated name is ASCII text plus digits, so it cannot contain NUL.
        let name = CString::new(format!("quex_driver_stmt_{id}"))
            .expect("statement name contains no nul");
        let query = CString::new(sql)?;
        prepare_named_statement(self.conn, &self.socket, &name, &query).await?;
        let entry = CachedStmtEntry {
            name,
            result_metadata: None,
            scratch: ParamScratch::new(),
        };
        self.statement_cache.insert(sql.into(), entry);
    }
    guard.complete();
    Ok(CachedStatement {
        conn: self,
        key: sql.into(),
    })
}
/// Issues `begin` and returns a [`Transaction`] borrowing this connection.
///
/// # Errors
///
/// Propagates any error from running the `begin` statement.
pub async fn begin(&mut self) -> Result<Transaction<'_>> {
    // The result set of `begin` carries nothing of interest; discard it right away.
    drop(self.query("begin").await?);
    let transaction = Transaction {
        conn: self,
        finished: false,
    };
    Ok(transaction)
}
/// Sends `sql` with `PQsendQuery` and returns the summary produced by
/// `execute_result_from_handle`.
///
/// # Errors
///
/// Returns an error if the connection is poisoned, if `sql` contains an interior NUL byte, if
/// `PQsendQuery` refuses the request, if reading the reply fails, or if the reply cannot be
/// turned into an [`ExecuteResult`].
pub async fn execute(&mut self, sql: &str) -> Result<ExecuteResult> {
    self.ensure_ready()?;
    let mut guard = ConnectionOpGuard::new(&self.state);
    let sql = CString::new(sql)?;
    // SAFETY: `sql` is NUL-terminated and outlives the call.
    unsafe {
        if ffi::PQsendQuery(self.conn.as_ptr(), sql.as_ptr()) == 0 {
            return Err(Error::from_conn(self.conn.as_ptr(), "PQsendQuery failed"));
        }
    }
    let result = finish_request(self.conn, &self.socket).await?;
    // Release the libpq result on both outcomes; propagating the error first would skip
    // `PQclear` and leak the result. This assumes `execute_result_from_handle` only borrows
    // the handle, which the success path already relies on.
    let outcome = execute_result_from_handle(result);
    unsafe { ffi::PQclear(result.as_ptr()) };
    let execute = outcome?;
    guard.complete();
    Ok(execute)
}
/// Subscribes this connection to `channel` by executing `listen <quoted channel>`.
///
/// # Errors
///
/// Propagates any error from [`Connection::execute`].
pub async fn listen(&mut self, channel: &str) -> Result<()> {
    let quoted = quote_identifier(channel);
    let statement = format!("listen {quoted}");
    self.execute(&statement).await.map(|_| ())
}
/// Removes the subscription to `channel` by executing `unlisten <quoted channel>`.
///
/// # Errors
///
/// Propagates any error from [`Connection::execute`].
pub async fn unlisten(&mut self, channel: &str) -> Result<()> {
    let quoted = quote_identifier(channel);
    let statement = format!("unlisten {quoted}");
    self.execute(&statement).await.map(|_| ())
}
/// Removes every subscription of this connection by executing `unlisten *`.
///
/// # Errors
///
/// Propagates any error from [`Connection::execute`].
pub async fn unlisten_all(&mut self) -> Result<()> {
    self.execute("unlisten *").await.map(|_| ())
}
/// Executes `notify` on `channel`, appending `payload` as a quoted literal when one is given.
///
/// # Errors
///
/// Propagates any error from [`Connection::execute`].
pub async fn notify(&mut self, channel: &str, payload: Option<&str>) -> Result<()> {
    let mut statement = format!("notify {}", quote_identifier(channel));
    if let Some(text) = payload {
        statement.push_str(", ");
        statement.push_str(&quote_literal(text));
    }
    self.execute(&statement).await.map(|_| ())
}
/// Reads whatever input is currently available and returns one queued notification, if any.
///
/// Does not wait for the socket; returns `Ok(None)` when nothing is queued.
///
/// # Errors
///
/// Returns an error if the connection is poisoned or `PQconsumeInput` fails.
pub async fn try_recv_notification(&mut self) -> Result<Option<Notification>> {
    self.ensure_ready()?;
    self.consume_input()?;
    let pending = self.pop_notification();
    Ok(pending)
}
/// Waits until a notification is available and returns it.
///
/// Each round consumes pending input and checks the queue; when the queue is empty the method
/// waits for the socket to become readable and tries again.
///
/// # Errors
///
/// Returns an error if the connection is poisoned, `PQconsumeInput` fails, or waiting on the
/// socket fails.
pub async fn wait_for_notification(&mut self) -> Result<Notification> {
    self.ensure_ready()?;
    loop {
        self.consume_input()?;
        let Some(notification) = self.pop_notification() else {
            self.wait_readable().await?;
            continue;
        };
        return Ok(notification);
    }
}
/// Runs `commit` through [`Connection::query`], discarding the result set.
pub async fn commit(&mut self) -> Result<()> {
    self.query("commit").await?;
    Ok(())
}
/// Runs `rollback` through [`Connection::query`], discarding the result set.
pub async fn rollback(&mut self) -> Result<()> {
    self.query("rollback").await?;
    Ok(())
}
/// Rejects further work when the shared state flag reads `CONNECTION_POISONED`.
#[inline]
pub(crate) fn ensure_ready(&self) -> Result<()> {
    let poisoned = self.state.load(Ordering::Acquire) == CONNECTION_POISONED;
    if !poisoned {
        return Ok(());
    }
    Err(Error::new(
        "postgres connection is no longer reusable after a cancelled or dropped operation",
    ))
}
/// Waits on the connection socket via `wait_for_socket` with `POLLIN`.
// NOTE(review): the trailing `true` is presumably a "readable" selector — confirm in `runtime`.
async fn wait_readable(&self) -> Result<()> {
    let readable = wait_for_socket(&self.socket, libc::POLLIN, true);
    readable.await
}
/// Waits on the connection socket via `wait_for_socket` with `POLLOUT`.
// NOTE(review): the trailing `false` is presumably a "readable" selector — confirm in `runtime`.
async fn wait_writable(&self) -> Result<()> {
    let writable = wait_for_socket(&self.socket, libc::POLLOUT, false);
    writable.await
}
fn consume_input(&mut self) -> Result<()> {
unsafe {
if ffi::PQconsumeInput(self.conn.as_ptr()) == 0 {
return Err(Error::from_conn(
self.conn.as_ptr(),
"PQconsumeInput failed",
));
}
}
Ok(())
}
/// Takes one queued notification from libpq, copying it into an owned [`Notification`].
///
/// Returns `None` when `PQnotifies` yields a null pointer. The libpq record is released with
/// `PQfreemem` after its fields have been copied.
fn pop_notification(&mut self) -> Option<Notification> {
    // SAFETY: the record pointer is checked for null before it is read, and it is freed only
    // after the owned copies of its strings have been made.
    unsafe {
        let record = PQnotifies(self.conn.as_ptr());
        if record.is_null() {
            return None;
        }
        let channel = c_string_lossy((*record).relname);
        let process_id = (*record).be_pid;
        let payload = c_string_lossy((*record).extra);
        PQfreemem(record.cast::<libc::c_void>());
        Some(Notification {
            channel,
            process_id,
            payload,
        })
    }
}
}
/// Releases the libpq connection when the value goes away.
impl Drop for Connection {
    fn drop(&mut self) {
        // SAFETY: `self.conn` is the handle this value holds; it is finished here, at the end
        // of the value's life.
        unsafe { ffi::PQfinish(self.conn.as_ptr()) };
    }
}
/// A transaction started by `Connection::begin`, borrowing the connection for its lifetime.
///
/// Dropping the value without calling `commit` or `rollback` marks the connection as
/// `CONNECTION_POISONED`.
pub struct Transaction<'a> {
/// The connection the transaction runs on.
conn: &'a mut Connection,
/// Set once `commit` or `rollback` has been requested, so that `Drop` leaves the state alone.
finished: bool,
}
impl Transaction<'_> {
    /// Gives access to the underlying connection for running statements inside the transaction.
    #[inline]
    pub fn connection(&mut self) -> &mut Connection {
        &mut *self.conn
    }

    /// Commits the transaction.
    ///
    /// The transaction is marked finished before the statement is sent, so dropping it
    /// afterwards does not poison the connection even if the commit reports an error.
    pub async fn commit(mut self) -> Result<()> {
        self.finish().commit().await
    }

    /// Rolls the transaction back.
    ///
    /// The transaction is marked finished before the statement is sent, so dropping it
    /// afterwards does not poison the connection even if the rollback reports an error.
    pub async fn rollback(mut self) -> Result<()> {
        self.finish().rollback().await
    }

    /// Marks the transaction as finished and hands back the connection to conclude it on.
    fn finish(&mut self) -> &mut Connection {
        self.finished = true;
        &mut *self.conn
    }
}
/// Marks the connection as poisoned when a transaction is dropped without being concluded.
impl Drop for Transaction<'_> {
    fn drop(&mut self) {
        if self.finished {
            return;
        }
        // Neither `commit` nor `rollback` was requested, so the connection is flagged and
        // `ensure_ready` will reject further operations on it.
        let state = &self.conn.state;
        state.store(CONNECTION_POISONED, Ordering::Release);
    }
}
// SAFETY: NOTE(review): presumably relies on libpq allowing a `PGconn` to be used from any
// thread as long as no two threads touch it at the same time, which moving the whole
// `Connection` preserves — confirm that `ConnHandle`/`SocketRef` are not shared elsewhere.
unsafe impl Send for Connection {}
/// Wraps `value` in double quotes for use as an SQL identifier, doubling every embedded
/// double quote.
fn quote_identifier(value: &str) -> String {
    let escaped = value.replace('"', "\"\"");
    format!("\"{escaped}\"")
}
/// Renders `value` as an SQL string literal.
///
/// Single quotes are doubled. When `value` contains a backslash, the literal is written in
/// escape-string form (` E'...'`, with a leading space as `PQescapeLiteral` does) and every
/// backslash is doubled. An ordinary `'...'` literal treats backslashes as escape characters
/// when the server runs with `standard_conforming_strings` off, so without this a payload
/// such as `\'` could terminate the literal early. Values without a backslash produce exactly
/// the plain `'...'` form.
fn quote_literal(value: &str) -> String {
    let escape_syntax = value.contains('\\');
    // Two quotes plus the optional " E" prefix; doubled characters grow the buffer on demand.
    let mut quoted = String::with_capacity(value.len() + 4);
    if escape_syntax {
        quoted.push_str(" E");
    }
    quoted.push('\'');
    for ch in value.chars() {
        // A backslash can only occur when escape syntax is active, where it must be doubled.
        if ch == '\'' || ch == '\\' {
            quoted.push(ch);
        }
        quoted.push(ch);
    }
    quoted.push('\'');
    quoted
}
/// Copies a NUL-terminated C string into an owned `String`, replacing invalid UTF-8 sequences.
///
/// A null pointer yields an empty string. A non-null `ptr` must point to a valid
/// NUL-terminated string; the function itself does not check that.
fn c_string_lossy(ptr: *const libc::c_char) -> String {
    if ptr.is_null() {
        return String::new();
    }
    // SAFETY: non-null was checked above; validity of the string is the caller's obligation.
    let text = unsafe { CStr::from_ptr(ptr) };
    String::from_utf8_lossy(text.to_bytes()).into_owned()
}
#[cfg(test)]
mod tests {
    use super::{quote_identifier, quote_literal};

    /// Identifiers are wrapped in double quotes and embedded double quotes are doubled.
    #[test]
    fn postgres_identifiers_are_quoted() {
        let cases = [
            ("events", "\"events\""),
            ("weird\"name", "\"weird\"\"name\""),
        ];
        for (input, expected) in cases {
            assert_eq!(quote_identifier(input), expected);
        }
    }

    /// Literals are wrapped in single quotes and embedded single quotes are doubled.
    #[test]
    fn postgres_literals_are_quoted() {
        let cases = [("payload", "'payload'"), ("Ada's event", "'Ada''s event'")];
        for (input, expected) in cases {
            assert_eq!(quote_literal(input), expected);
        }
    }
}