use std::collections::HashMap;
use std::ffi::{c_void, CStr};
use std::sync::Arc;
use std::time::Instant;
use log::debug;
use libpq::Connection;
use libpq_sys::ExecStatusType::{PGRES_COMMAND_OK, PGRES_TUPLES_OK};
use libpq_sys::{
PGContextVisibility, PQclear, PQconsumeInput, PQfname, PQgetResult, PQgetvalue, PQlibVersion,
PQnfields, PQntuples, PQresultStatus, PQresultVerboseErrorMessage, PQsendQuery,
PQsetErrorVerbosity, PQsetNoticeReceiver,
};
use crate::notices::{notice_receiver, Notice, NoticeStorage, Verbosity};
use crate::value::Value;
/// Lightweight PostgreSQL client built on libpq.
///
/// Stores only connection settings plus a shared buffer that the libpq notice receiver
/// fills while a query runs; a fresh connection is opened for every query.
#[derive(Debug)]
pub struct PgwireLite {
    /// Server host name or address, used as the libpq `host` connection parameter.
    hostname: String,
    /// Server TCP port.
    port: u16,
    /// `true` selects `sslmode=verify-full`, `false` selects `sslmode=disable`.
    use_tls: bool,
    /// Error verbosity applied to each connection and to result error messages.
    verbosity: Verbosity,
    /// Notices captured by the notice receiver for the query currently running.
    notices: NoticeStorage,
}
/// Outcome of a single successful query.
#[derive(Debug)]
pub struct QueryResult {
    /// One map per returned row, keyed by column name (if two columns share a name,
    /// the later column's value is the one kept).
    pub rows: Vec<HashMap<String, Value>>,
    /// Column names in result order.
    pub column_names: Vec<String>,
    /// Notices captured by the notice receiver while the query ran.
    pub notices: Vec<Notice>,
    /// Number of rows returned (0 when the status is not `PGRES_TUPLES_OK`).
    pub row_count: i32,
    /// Number of columns in the result.
    pub col_count: i32,
    /// Number of entries in `notices`.
    pub notice_count: usize,
    /// libpq result status (`PGRES_TUPLES_OK` or `PGRES_COMMAND_OK` for a successful query).
    pub status: libpq_sys::ExecStatusType,
    /// Wall-clock time in milliseconds, measured from before the connection was opened.
    pub elapsed_time_ms: u64,
}
/// Releases a `PGresult` obtained from libpq; a null pointer is silently ignored.
fn clear_pg_result(result: *mut libpq_sys::PGresult) {
    if result.is_null() {
        return;
    }
    debug!("Clearing PGresult at {:p}", result);
    // SAFETY: callers pass a non-null result obtained from libpq that has not been cleared yet.
    unsafe { PQclear(result) };
    debug!("PGresult cleared successfully");
}
impl PgwireLite {
pub fn new(
hostname: &str,
port: u16,
use_tls: bool,
verbosity: &str,
) -> Result<Self, Box<dyn std::error::Error>> {
let verbosity_val = match verbosity.to_lowercase().as_str() {
"default" => Verbosity::Default,
"verbose" => Verbosity::Verbose,
"terse" => Verbosity::Terse,
"sqlstate" => Verbosity::Sqlstate,
"" => Verbosity::Default,
_ => Verbosity::Default,
};
match verbosity_val {
Verbosity::Terse => log::set_max_level(log::LevelFilter::Warn),
Verbosity::Default => log::set_max_level(log::LevelFilter::Info),
Verbosity::Verbose => log::set_max_level(log::LevelFilter::Debug),
Verbosity::Sqlstate => log::set_max_level(log::LevelFilter::Debug),
}
let notices = Arc::new(std::sync::Mutex::new(Vec::new()));
Ok(PgwireLite {
hostname: hostname.to_string(),
port,
use_tls,
verbosity: verbosity_val,
notices,
})
}
pub fn libpq_version(&self) -> String {
let version = unsafe { PQlibVersion() };
let major = version / 10000;
let minor = (version / 100) % 100;
let patch = version % 100;
format!("{}.{}.{}", major, minor, patch)
}
pub fn verbosity(&self) -> String {
format!("{:?}", self.verbosity)
}
fn consume_pending_results(conn: &Connection) {
debug!("Consuming pending results");
unsafe {
PQconsumeInput(conn.into());
loop {
let result = PQgetResult(conn.into());
if result.is_null() {
break;
}
clear_pg_result(result);
}
}
}
pub fn query(&self, query: &str) -> Result<QueryResult, Box<dyn std::error::Error>> {
debug!("Clearing previous notices");
if let Ok(mut notices) = self.notices.lock() {
notices.clear();
}
let start_time = Instant::now();
let conn_str = format!(
"host={} port={} sslmode={} application_name=pgwire-lite-client connect_timeout=10 client_encoding=UTF8",
self.hostname,
self.port,
if self.use_tls { "verify-full" } else { "disable" }
);
debug!("Establishing connection using: {}", conn_str);
let conn = Connection::new(&conn_str)?;
unsafe {
let ssl_in_use = libpq_sys::PQsslInUse((&conn).into()) != 0;
let host_ptr = libpq_sys::PQhost((&conn).into());
let port_ptr = libpq_sys::PQport((&conn).into());
if !host_ptr.is_null() && !port_ptr.is_null() {
let host = CStr::from_ptr(host_ptr).to_string_lossy();
let port = CStr::from_ptr(port_ptr).to_string_lossy();
debug!("Connected to: {}:{} (ssl: {})", host, port, ssl_in_use);
}
let status = libpq_sys::PQstatus((&conn).into());
debug!("Connection status: {:?}", status);
let tx_status = libpq_sys::PQtransactionStatus((&conn).into());
debug!("Transaction status: {:?}", tx_status);
let server_version = libpq_sys::PQserverVersion((&conn).into());
let major = server_version / 10000;
let minor = (server_version / 100) % 100;
let revision = server_version % 100;
debug!(
"Server version: {}.{}.{} ({})",
major, minor, revision, server_version
);
}
debug!("Setting error verbosity to: {:?}", self.verbosity);
unsafe {
PQsetErrorVerbosity((&conn).into(), self.verbosity.into());
}
debug!("Setting up notice receiver");
let notices_ptr = Arc::into_raw(self.notices.clone()) as *mut c_void;
unsafe {
PQsetNoticeReceiver((&conn).into(), Some(notice_receiver), notices_ptr);
}
let query = if query.ends_with(';') {
query.to_string()
} else {
format!("{};", query)
};
debug!("Sending query: {}", query);
let send_success = unsafe { PQsendQuery((&conn).into(), query.as_ptr() as *const i8) };
if send_success == 0 {
return Err(
format!("Error: {}", conn.error_message().unwrap_or("Unknown error")).into(),
);
}
debug!("Processing the result");
let result = unsafe { PQgetResult((&conn).into()) };
if result.is_null() {
return Err("No result returned".into());
}
let status = unsafe { PQresultStatus(result) };
if status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK {
let error_msg_ptr = unsafe {
PQresultVerboseErrorMessage(
result,
self.verbosity.into(),
PGContextVisibility::PQSHOW_CONTEXT_ALWAYS,
)
};
let error_msg = if !error_msg_ptr.is_null() {
let msg = unsafe { CStr::from_ptr(error_msg_ptr).to_string_lossy().into_owned() };
unsafe { libpq_sys::PQfreemem(error_msg_ptr as *mut _) };
msg
} else {
conn.error_message().unwrap_or("Unknown error").to_string()
};
clear_pg_result(result);
Self::consume_pending_results(&conn);
return Err(error_msg.trim_end().to_string().into());
}
debug!("Getting column count");
let col_count = unsafe { PQnfields(result) };
debug!("Getting column names");
let mut column_names = Vec::with_capacity(col_count as usize);
for col_index in 0..col_count {
let col_name_ptr = unsafe { PQfname(result, col_index) };
if !col_name_ptr.is_null() {
let col_name =
unsafe { CStr::from_ptr(col_name_ptr).to_string_lossy().into_owned() };
column_names.push(col_name);
} else {
column_names.push(String::from("(unknown)"));
}
}
debug!("Getting row count");
let row_count = if status == PGRES_TUPLES_OK {
unsafe { PQntuples(result) }
} else {
0
};
let mut rows = Vec::new();
if status == PGRES_TUPLES_OK {
debug!("Processing rows");
for row_index in 0..row_count {
let mut row_data = HashMap::new();
for col_index in 0..col_count {
let value_ptr = unsafe { PQgetvalue(result, row_index, col_index) };
let value = if !value_ptr.is_null() {
let string_value =
unsafe { CStr::from_ptr(value_ptr).to_string_lossy().into_owned() };
Value::String(string_value)
} else {
Value::Null
};
row_data.insert(column_names[col_index as usize].clone(), value);
}
rows.push(row_data);
}
}
debug!("Rows processed: {}", rows.len());
clear_pg_result(result);
Self::consume_pending_results(&conn);
debug!("Collecting notices");
let notices = if let Ok(mut lock) = self.notices.lock() {
lock.drain(..).collect()
} else {
Vec::new()
};
let notice_count = notices.len();
let elapsed_time_ms = start_time.elapsed().as_millis() as u64;
drop(conn);
Ok(QueryResult {
rows,
column_names,
notices,
row_count,
col_count,
notice_count,
status,
elapsed_time_ms,
})
}
}