use std::sync::Arc;
use rapidhash::quality::RapidHasher;
/// Small prepared-statement cache keyed by the 64-bit SQL hash.
///
/// Entries live in a plain `Vec` and every lookup is a linear scan, so the
/// cache is meant to stay small. The size bound is not enforced here —
/// presumably by `Connection::cache_stmt` (not visible in this chunk).
struct StmtCache {
    /// `(sql_hash, statement info)` pairs, at most one per hash (`insert`
    /// replaces an existing entry). Order carries no meaning: `remove` and
    /// `evict_lru` use `swap_remove`.
    entries: Vec<(u64, StmtInfo)>,
}
impl Default for StmtCache {
fn default() -> Self {
Self {
entries: Vec::with_capacity(16),
}
}
}
impl StmtCache {
    /// Mutable access to the statement cached under `hash`, if any.
    #[inline]
    fn get_mut(&mut self, hash: &u64) -> Option<&mut StmtInfo> {
        let wanted = *hash;
        self.entries
            .iter_mut()
            .find_map(|(key, info)| (*key == wanted).then_some(info))
    }

    /// Shared access to the statement cached under `hash`, if any.
    #[inline]
    fn get(&self, hash: &u64) -> Option<&StmtInfo> {
        self.entries
            .iter()
            .find_map(|(key, info)| (key == hash).then_some(info))
    }

    /// Whether a statement is cached under `hash`.
    #[inline]
    fn contains_key(&self, hash: &u64) -> bool {
        self.get(hash).is_some()
    }

    /// Caches `info` under `hash`, overwriting any existing entry for that hash.
    #[inline]
    fn insert(&mut self, hash: u64, info: StmtInfo) {
        match self.get_mut(&hash) {
            Some(existing) => *existing = info,
            None => self.entries.push((hash, info)),
        }
    }

    /// Removes and returns the statement cached under `hash`.
    ///
    /// Uses `swap_remove`, so the relative order of the remaining entries changes.
    #[inline]
    fn remove(&mut self, hash: &u64) -> Option<StmtInfo> {
        let idx = self.entries.iter().position(|(key, _)| key == hash)?;
        Some(self.entries.swap_remove(idx).1)
    }

    /// Number of cached statements.
    #[inline]
    fn len(&self) -> usize {
        self.entries.len()
    }

    /// Removes and returns the entry with the smallest `last_used` value
    /// (the first such entry on ties); `None` when the cache is empty.
    fn evict_lru(&mut self) -> Option<(u64, StmtInfo)> {
        let (oldest, _) = self
            .entries
            .iter()
            .enumerate()
            .min_by_key(|(_, (_, info))| info.last_used)?;
        Some(self.entries.swap_remove(oldest))
    }
}
use tokio::io::{AsyncRead, AsyncWriteExt};
use tokio::net::TcpStream;
use crate::DriverError;
use crate::arena::Arena;
use crate::auth;
use crate::codec::Encode;
use crate::proto::{self, BackendMessage};
#[cfg(feature = "tls")]
use crate::tls;
/// Transport underneath a [`Connection`].
///
/// The variants are plain TCP, TLS over TCP (only with the `tls` feature),
/// and a Unix-domain socket (only on Unix targets).
enum Stream {
    /// Unencrypted TCP.
    Plain(TcpStream),
    /// rustls client stream over TCP. It is boxed, presumably to keep the
    /// enum small, since the TLS state is large.
    #[cfg(feature = "tls")]
    Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
    /// Unix-domain socket. It is used when the configured host starts with `/`.
    #[cfg(unix)]
    Unix(tokio::net::UnixStream),
}
impl Stream {
async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
match self {
Stream::Plain(s) => s.write_all(buf).await,
#[cfg(feature = "tls")]
Stream::Tls(s) => s.write_all(buf).await,
#[cfg(unix)]
Stream::Unix(s) => s.write_all(buf).await,
}
}
async fn flush(&mut self) -> std::io::Result<()> {
match self {
Stream::Plain(s) => s.flush().await,
#[cfg(feature = "tls")]
Stream::Tls(s) => s.flush().await,
#[cfg(unix)]
Stream::Unix(s) => s.flush().await,
}
}
}
/// Adapter that exposes a borrowed [`Stream`] through tokio's `AsyncRead`.
/// Its `poll_read` is forwarded to the active transport.
struct StreamReader<'a>(&'a mut Stream);
impl AsyncRead for StreamReader<'_> {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match &mut *self.0 {
Stream::Plain(s) => std::pin::Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "tls")]
Stream::Tls(s) => std::pin::Pin::new(s.as_mut()).poll_read(cx, buf),
#[cfg(unix)]
Stream::Unix(s) => std::pin::Pin::new(s).poll_read(cx, buf),
}
}
}
/// Connection settings for a PostgreSQL server.
///
/// `Debug` is implemented by hand so the password never appears in
/// `{:?}` output (logs, panics, error chains). The password is also zeroized
/// when a `Config` is dropped.
#[derive(Clone)]
pub struct Config {
    /// Host name or IP address. A value starting with `/` is treated as a
    /// Unix-socket directory (see `Config::host_is_uds`).
    pub host: String,
    /// TCP port, or the suffix of the Unix-socket file name.
    pub port: u16,
    /// Role name used for authentication.
    pub user: String,
    /// Plaintext password. It is zeroized on drop and redacted in `Debug`.
    pub password: String,
    /// Database to connect to.
    pub database: String,
    /// TLS policy for TCP connections.
    pub ssl: SslMode,
    /// `statement_timeout` applied after connecting, in seconds (0 = not set).
    pub statement_timeout_secs: u32,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            // Never print the secret, not even its length.
            .field("password", &format_args!("<redacted>"))
            .field("database", &self.database)
            .field("ssl", &self.ssl)
            .field("statement_timeout_secs", &self.statement_timeout_secs)
            .finish()
    }
}
impl Drop for Config {
fn drop(&mut self) {
use zeroize::Zeroize;
self.password.zeroize();
}
}
/// TLS policy used by `Connection::connect` for TCP connections.
///
/// It is ignored for Unix-socket hosts. Without the `tls` feature, `Prefer`
/// connects in plaintext and `Require` fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SslMode {
    /// Never attempt TLS.
    Disable,
    /// Attempt TLS. If the upgrade fails, reconnect and continue in plaintext.
    Prefer,
    /// Require TLS. Connecting fails if the upgrade fails.
    Require,
}
impl Config {
    /// Parses a `postgres://` or `postgresql://` connection URL.
    ///
    /// Shape: `postgres://user[:password]@[host][:port][/database][?params]`.
    /// - `host` may be a bracketed IPv6 literal (`[::1]`, `[::1]:5433`); the
    ///   brackets are stripped.
    /// - User, password, host and database are percent-decoded.
    /// - The port defaults to 5432.
    /// - The database defaults to the user name.
    ///
    /// Recognised query parameters (others are ignored):
    /// - `sslmode=disable|prefer|require`: default `prefer`; any other value is an error.
    /// - `statement_timeout=<secs>`: default 30; an unparsable value keeps 30.
    /// - `host=<value>`: replaces the authority host, e.g. a Unix-socket directory.
    ///
    /// # Errors
    /// Returns `DriverError::Protocol` for any of:
    /// - a wrong scheme or a missing `@`;
    /// - an invalid port or a malformed IPv6 literal;
    /// - an unknown sslmode;
    /// - bad percent-encoding;
    /// - a result that fails [`Config::validate`].
    ///
    /// Errors from decoding the password never include the password text.
    pub fn from_url(url: &str) -> Result<Self, DriverError> {
        let url = url
            .strip_prefix("postgres://")
            .or_else(|| url.strip_prefix("postgresql://"))
            .ok_or_else(|| DriverError::Protocol("URL must start with postgres://".into()))?;
        let (userinfo, rest) = url
            .split_once('@')
            .ok_or_else(|| DriverError::Protocol("missing @ in connection URL".into()))?;
        let (user, password) = userinfo.split_once(':').unwrap_or((userinfo, ""));
        let (hostport, rest) = rest.split_once('/').unwrap_or((rest, ""));
        let (database, params) = rest.split_once('?').unwrap_or((rest, ""));
        let (host, port) = Self::split_host_port(hostport)?;

        let mut ssl = SslMode::Prefer;
        let mut statement_timeout_secs: u32 = 30;
        let mut host_override: Option<String> = None;
        for param in params.split('&').filter(|p| !p.is_empty()) {
            if let Some(val) = param.strip_prefix("sslmode=") {
                ssl = match val {
                    "disable" => SslMode::Disable,
                    "prefer" => SslMode::Prefer,
                    "require" => SslMode::Require,
                    _ => {
                        return Err(DriverError::Protocol(format!(
                            "unknown sslmode: '{val}' (expected: disable, prefer, require)"
                        )));
                    }
                };
            } else if let Some(val) = param.strip_prefix("statement_timeout=") {
                // An unparsable value falls back to the 30s default rather than erroring.
                statement_timeout_secs = val.parse::<u32>().unwrap_or(30);
            } else if let Some(val) = param.strip_prefix("host=") {
                host_override = Some(url_decode(val)?);
            }
        }

        let host = match host_override {
            Some(h) => h,
            None => url_decode(host)?,
        };
        let user = url_decode(user)?;
        // url_decode echoes its whole input in error messages, which must not
        // happen for the password; replace the error with a redacted one.
        let password = url_decode(password).map_err(|_| {
            DriverError::Protocol(
                "malformed percent-encoding or invalid UTF-8 in URL password".into(),
            )
        })?;
        let database = if database.is_empty() {
            user.clone()
        } else {
            url_decode(database)?
        };
        let config = Config {
            host,
            port,
            user,
            password,
            database,
            ssl,
            statement_timeout_secs,
        };
        config.validate()?;
        Ok(config)
    }

    /// Splits the authority's `host[:port]` segment.
    ///
    /// Accepts bracketed IPv6 literals (returned without the brackets).
    /// The port defaults to 5432 when absent.
    fn split_host_port(hostport: &str) -> Result<(&str, u16), DriverError> {
        let parse_port = |p: &str| {
            p.parse::<u16>()
                .map_err(|_| DriverError::Protocol(format!("invalid port: {p}")))
        };
        if let Some(bracketed) = hostport.strip_prefix('[') {
            let (host, after) = bracketed.split_once(']').ok_or_else(|| {
                DriverError::Protocol(format!("unterminated IPv6 literal in URL: '{hostport}'"))
            })?;
            if after.is_empty() {
                return Ok((host, 5432));
            }
            return match after.strip_prefix(':') {
                Some(p) => Ok((host, parse_port(p)?)),
                None => Err(DriverError::Protocol(format!(
                    "unexpected characters after IPv6 literal in URL: '{after}'"
                ))),
            };
        }
        match hostport.split_once(':') {
            Some((h, p)) => Ok((h, parse_port(p)?)),
            None => Ok((hostport, 5432)),
        }
    }

    /// Checks that host, user and database are all non-empty.
    ///
    /// # Errors
    /// Returns `DriverError::Protocol` naming the first empty field.
    pub fn validate(&self) -> Result<(), DriverError> {
        if self.host.is_empty() {
            return Err(DriverError::Protocol("host cannot be empty".into()));
        }
        if self.user.is_empty() {
            return Err(DriverError::Protocol("user cannot be empty".into()));
        }
        if self.database.is_empty() {
            return Err(DriverError::Protocol("database cannot be empty".into()));
        }
        Ok(())
    }

    /// True when `host` is an absolute path, i.e. a Unix-socket directory.
    pub fn host_is_uds(&self) -> bool {
        self.host.starts_with('/')
    }

    /// Full Unix-socket path: `<host>/.s.PGSQL.<port>`.
    pub fn uds_path(&self) -> String {
        format!("{}/.s.PGSQL.{}", self.host, self.port)
    }
}
/// Percent-decodes `s` (each `%XX` becomes one byte) and requires the
/// result to be valid UTF-8. `+` is left as-is.
///
/// # Errors
/// Returns `DriverError::Protocol` in any of these cases; the message
/// includes the full input:
/// - a `%` without two following bytes;
/// - a non-hex digit after `%`;
/// - decoded bytes that are not valid UTF-8.
fn url_decode(s: &str) -> Result<String, DriverError> {
    let raw = s.as_bytes();
    let mut out = Vec::with_capacity(raw.len());
    // Decodes the hex digit at `idx`, reporting the offending character on failure.
    let digit = |idx: usize| {
        hex_val(raw[idx]).ok_or_else(|| {
            DriverError::Protocol(format!(
                "invalid hex digit '{}' in URL: '{s}'",
                raw[idx] as char
            ))
        })
    };
    let mut pos = 0;
    while let Some(&byte) = raw.get(pos) {
        if byte != b'%' {
            out.push(byte);
            pos += 1;
            continue;
        }
        if pos + 2 >= raw.len() {
            return Err(DriverError::Protocol(format!(
                "malformed percent-encoding in URL: '{s}'"
            )));
        }
        let high = digit(pos + 1)?;
        let low = digit(pos + 2)?;
        out.push((high << 4) | low);
        pos += 3;
    }
    String::from_utf8(out)
        .map_err(|_| DriverError::Protocol(format!("invalid UTF-8 in URL: '{s}'")))
}
/// Value (0–15) of an ASCII hex digit (`0-9`, `a-f`, `A-F`); `None` for
/// any other byte.
fn hex_val(b: u8) -> Option<u8> {
    char::from(b).to_digit(16).map(|d| d as u8)
}
/// A backend message received during startup/authentication.
///
/// It holds owned copies of whatever data `Connection::startup` needs from
/// the message.
enum StartupAction {
    /// Authentication succeeded.
    AuthOk,
    /// Server requests the password in cleartext.
    AuthCleartext,
    /// Server requests an MD5-hashed password with this 4-byte salt.
    AuthMd5([u8; 4]),
    /// Server requests SASL. Holds the raw bytes of the offered mechanism list.
    AuthSasl(Vec<u8>),
    /// A server run-time parameter as `(name, value)`.
    ParameterStatus(Box<str>, Box<str>),
    /// Backend process id and secret key, stored in `Connection::{pid, secret}`.
    BackendKeyData(i32, i32),
    /// ReadyForQuery with its transaction-status byte. Startup ends here.
    ReadyForQuery(u8),
    /// An ErrorResponse rendered to text.
    Error(String),
    /// A NoticeResponse. Ignored during startup.
    Notice,
}
/// Builds the server-side statement name for `hash`.
///
/// The name is `s_` followed by the hash as 16 lowercase, zero-padded hex
/// digits. It is built in a fixed stack buffer to avoid formatting machinery.
#[inline]
fn make_stmt_name(hash: u64) -> Box<str> {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut name = [0u8; 18];
    name[..2].copy_from_slice(b"s_");
    // Most-significant nibble first: shifts 60, 56, ..., 0.
    for (slot, shift) in name[2..].iter_mut().zip((0..16u32).rev().map(|n| n * 4)) {
        *slot = DIGITS[((hash >> shift) & 0xf) as usize];
    }
    std::str::from_utf8(&name)
        .expect("BUG: stmt name buffer contains only ASCII hex")
        .into()
}
/// Cached metadata for one server-side prepared statement.
struct StmtInfo {
    /// Server-side statement name, generated by `make_stmt_name`.
    name: Box<str>,
    /// Result columns from Describe. Shared via `Arc` so clones are cheap.
    columns: Arc<[ColumnDesc]>,
    /// Value of `Connection::query_counter` at the statement's last use.
    /// `StmtCache::evict_lru` removes the entry with the smallest value.
    last_used: u64,
    /// Reusable pre-encoded Bind message, once one has been built.
    bind_template: Option<BindTemplate>,
}
/// A previously encoded Bind message for a cached statement.
///
/// Later calls patch parameter values in place when the new values encode
/// to the same lengths. Built by `build_bind_template`, which is not visible
/// in this chunk.
struct BindTemplate {
    /// Encoded bytes. The Bind message is `bytes[..bind_end]`. The pipelined
    /// execute path sends the whole buffer and then skips appending
    /// Execute+Sync, so presumably an Execute+Sync tail follows `bind_end`
    /// (confirm in `build_bind_template`).
    bytes: Vec<u8>,
    /// End offset of the Bind message within `bytes`.
    bind_end: usize,
    /// One entry per parameter: (offset of the value bytes, encoded length).
    /// The 4-byte length field sits just before the offset. A negative length
    /// presumably (-1) marks a slot that was NULL when the template was built.
    param_slots: Vec<(usize, i32)>,
}
/// Description of one result column.
///
/// Parsed from a RowDescription message by `proto::parse_row_description`.
/// Field meanings below follow the PostgreSQL RowDescription format.
#[derive(Debug, Clone)]
pub struct ColumnDesc {
    /// Column name (or alias) as reported by the server.
    pub name: Box<str>,
    /// Data type OID.
    pub type_oid: u32,
    /// Data type size. PostgreSQL reports negative values for variable-width types.
    pub type_size: i16,
    /// OID of the source table, or 0 when the column is not a plain table column.
    pub table_oid: u32,
    /// Attribute number within the source table, or 0 when not applicable.
    pub column_id: i16,
}
/// Result of `Connection::prepare_describe`: the statement's result columns
/// and parameter types.
#[derive(Debug, Clone)]
pub struct PrepareResult {
    /// Result columns. Empty when the server answered Describe with NoData.
    pub columns: Vec<ColumnDesc>,
    /// Parameter type OIDs from ParameterDescription. Presumably in $1, $2, …
    /// order. Empty if the server sent no ParameterDescription.
    pub param_oids: Vec<u32>,
}
/// One row from the simple-query protocol, as produced by
/// `proto::parse_simple_data_row`.
///
/// Each column is an optional string; presumably `None` is SQL NULL.
pub type SimpleRow = Vec<Option<String>>;
/// An asynchronous LISTEN/NOTIFY notification.
///
/// Instances are queued in `Connection::pending_notifications`. The code that
/// fills that queue is not visible in this chunk (presumably it handles
/// NotificationResponse messages).
#[derive(Debug, Clone)]
pub struct Notification {
    /// Process id of the notifying backend.
    pub pid: i32,
    /// Channel name the notification was sent on.
    pub channel: String,
    /// Notification payload (possibly empty).
    pub payload: String,
}
/// One PostgreSQL connection using the extended query protocol, with a
/// per-connection cache of named prepared statements.
pub struct Connection {
    /// Underlying transport: TCP, TLS or Unix socket.
    stream: Stream,
    /// Payload of the most recently read backend message.
    read_buf: Vec<u8>,
    /// Read-ahead buffer (64 KiB at construction). Presumably
    /// `stream_buf[stream_buf_pos..stream_buf_end]` holds bytes received but
    /// not yet consumed.
    stream_buf: Vec<u8>,
    stream_buf_pos: usize,
    stream_buf_end: usize,
    /// Outgoing frontend messages, accumulated and then flushed in one write.
    write_buf: Vec<u8>,
    /// Named prepared statements, keyed by SQL hash.
    stmts: StmtCache,
    /// Server parameters reported through ParameterStatus during startup.
    params: Vec<(Box<str>, Box<str>)>,
    /// Backend process id from BackendKeyData.
    pid: i32,
    /// Backend secret key from BackendKeyData.
    secret: i32,
    /// Transaction-status byte from the latest ReadyForQuery (initially `b'I'`).
    tx_status: u8,
    /// Last-use timestamp. Set at creation here; presumably updated elsewhere.
    last_used: std::time::Instant,
    /// Set by `query_streaming_start`; cleared when the streamed result ends.
    streaming_active: bool,
    /// When this connection was created.
    created_at: std::time::Instant,
    /// Notifications received but not yet handed to the caller.
    pending_notifications: Vec<Notification>,
    /// Statement-cache size limit (256 by default). Presumably enforced by `cache_stmt`.
    max_stmt_cache_size: usize,
    /// Monotonic counter bumped when a statement is used. It is the clock
    /// behind `StmtInfo::last_used`.
    query_counter: u64,
}
impl Connection {
/// Opens a connection described by `config` and completes startup.
///
/// A host starting with `/` connects to the Unix socket at
/// `Config::uds_path` (Unix only). Otherwise a TCP connection is opened with
/// `TCP_NODELAY` and keepalive set. Its TLS handling follows `config.ssl`:
/// - `Disable`: plaintext.
/// - `Require`: TLS, or an error.
/// - `Prefer`: try TLS; on failure, reconnect and fall back to plaintext.
///
/// # Errors
/// - I/O errors while connecting or configuring the socket.
/// - TLS errors when TLS is required.
/// - A protocol error when TLS is required but the crate was built without
///   the `tls` feature.
/// - Any error from the startup handshake.
pub async fn connect(config: &Config) -> Result<Self, DriverError> {
    #[cfg(unix)]
    if config.host_is_uds() {
        let path = config.uds_path();
        let unix = tokio::net::UnixStream::connect(&path)
            .await
            .map_err(DriverError::Io)?;
        let stream = Stream::Unix(unix);
        return Self::finish_connect(stream, config).await;
    }
    // Pass a (host, port) pair rather than formatting "host:port". The
    // formatted form is ambiguous for IPv6 literals ("::1:5432") and cannot
    // be parsed as a socket address. Brackets are stripped so "[::1]" and
    // "::1" both work.
    let host = config
        .host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(config.host.as_str());
    let addr = (host, config.port);
    let tcp = TcpStream::connect(addr).await.map_err(DriverError::Io)?;
    tcp.set_nodelay(true).map_err(DriverError::Io)?;
    Self::set_keepalive(&tcp)?;
    let stream = match config.ssl {
        SslMode::Disable => Stream::Plain(tcp),
        #[cfg(feature = "tls")]
        SslMode::Prefer | SslMode::Require => {
            match tls::try_upgrade(tcp, &config.host, config.ssl == SslMode::Require).await {
                Ok(tls_stream) => Stream::Tls(Box::new(tls_stream)),
                Err(e) if config.ssl == SslMode::Require => return Err(e),
                Err(_) => {
                    // The failed upgrade consumed the socket, so `Prefer`
                    // falls back to a fresh plaintext connection.
                    let tcp = TcpStream::connect(addr).await.map_err(DriverError::Io)?;
                    tcp.set_nodelay(true).map_err(DriverError::Io)?;
                    Self::set_keepalive(&tcp)?;
                    Stream::Plain(tcp)
                }
            }
        }
        #[cfg(not(feature = "tls"))]
        SslMode::Require => {
            return Err(DriverError::Protocol(
                "TLS required but bsql-driver-postgres compiled without 'tls' feature".into(),
            ));
        }
        #[cfg(not(feature = "tls"))]
        SslMode::Prefer => Stream::Plain(tcp),
    };
    Self::finish_connect(stream, config).await
}
/// Wraps an established `stream` in a fresh [`Connection`] and finishes setup.
///
/// Setup runs in this order:
/// 1. The startup/authentication handshake.
/// 2. Validation of the reported server parameters.
/// 3. `SET statement_timeout`, only when `config.statement_timeout_secs` is non-zero.
async fn finish_connect(stream: Stream, config: &Config) -> Result<Self, DriverError> {
    let mut connection = Self {
        stream,
        read_buf: Vec::with_capacity(8192),
        stream_buf: vec![0u8; 65536],
        stream_buf_pos: 0,
        stream_buf_end: 0,
        write_buf: Vec::with_capacity(4096),
        stmts: StmtCache::default(),
        params: Vec::new(),
        pid: 0,
        secret: 0,
        tx_status: b'I',
        last_used: std::time::Instant::now(),
        streaming_active: false,
        created_at: std::time::Instant::now(),
        pending_notifications: Vec::new(),
        max_stmt_cache_size: 256,
        query_counter: 0,
    };
    connection.startup(config).await?;
    connection.validate_server_params()?;
    let timeout_secs = config.statement_timeout_secs;
    if timeout_secs != 0 {
        // The value is a u32, so interpolating it into SQL cannot inject anything.
        let set_timeout = format!("SET statement_timeout = '{}s'", timeout_secs);
        connection.simple_query(&set_timeout).await?;
    }
    Ok(connection)
}
async fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
self.write_buf.clear();
proto::write_startup(&mut self.write_buf, &config.user, &config.database);
self.flush_write().await?;
loop {
let action = self.read_startup_action().await?;
match action {
StartupAction::AuthOk => {}
StartupAction::AuthCleartext => {
self.write_buf.clear();
let mut pw = config.password.as_bytes().to_vec();
pw.push(0);
proto::write_password(&mut self.write_buf, &pw);
self.flush_write().await?;
}
StartupAction::AuthMd5(salt) => {
self.write_buf.clear();
let hash = auth::md5_password(&config.user, &config.password, &salt);
proto::write_password(&mut self.write_buf, &hash);
self.flush_write().await?;
}
StartupAction::AuthSasl(mechanisms_data) => {
self.handle_scram(config, &mechanisms_data).await?;
}
StartupAction::ParameterStatus(name, value) => {
if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
entry.1 = value;
} else {
self.params.push((name, value));
}
}
StartupAction::BackendKeyData(pid, secret) => {
self.pid = pid;
self.secret = secret;
}
StartupAction::ReadyForQuery(status) => {
self.tx_status = status;
return Ok(());
}
StartupAction::Error(msg) => {
return Err(DriverError::Auth(msg));
}
StartupAction::Notice => {}
}
}
}
/// Reads the next backend message into `read_buf` and converts it to a
/// [`StartupAction`].
async fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
    let msg_type = self.read_message_buffered().await?.0;
    self.read_startup_message_from_type(msg_type)
}
fn read_startup_message_from_type(&self, msg_type: u8) -> Result<StartupAction, DriverError> {
let payload = &self.read_buf;
let msg = proto::parse_backend_message(msg_type, payload)?;
match msg {
BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
BackendMessage::AuthSasl { mechanisms } => {
Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
}
BackendMessage::ParameterStatus { name, value } => {
Ok(StartupAction::ParameterStatus(name.into(), value.into()))
}
BackendMessage::BackendKeyData { pid, secret } => {
Ok(StartupAction::BackendKeyData(pid, secret))
}
BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
Ok(StartupAction::Error(fields.to_string()))
}
BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
other => Err(DriverError::Protocol(format!(
"unexpected message during startup: {other:?}"
))),
}
}
/// Performs SCRAM-SHA-256 SASL authentication during startup.
///
/// Exchange, in order:
/// 1. Send SASLInitialResponse (client-first).
/// 2. Expect AuthSaslContinue (server-first).
/// 3. Send SASLResponse (client-final).
/// 4. Expect AuthSaslFinal and verify the server signature.
/// 5. Expect AuthOk.
///
/// # Errors
/// - `DriverError::Auth` if the plain `SCRAM-SHA-256` mechanism is not
///   offered, or if the server sends an ErrorResponse at any step.
/// - `DriverError::Protocol` on any other unexpected message.
/// - Errors from the SCRAM client (including server-signature verification)
///   and I/O errors are propagated.
async fn handle_scram(
    &mut self,
    config: &Config,
    mechanisms_data: &[u8],
) -> Result<(), DriverError> {
    let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
    // Only the exact mechanism name "SCRAM-SHA-256" is accepted.
    if !mechs.contains(&"SCRAM-SHA-256") {
        return Err(DriverError::Auth(format!(
            "server requires unsupported SASL mechanism(s): {mechs:?}"
        )));
    }
    let mut scram = auth::ScramClient::new(&config.user, &config.password)?;
    let client_first = scram.client_first_message();
    self.write_buf.clear();
    proto::write_sasl_initial(&mut self.write_buf, "SCRAM-SHA-256", &client_first);
    self.flush_write().await?;
    let (msg_type, _) = self.read_message_buffered().await?;
    // Copy the server-first payload out of `read_buf`; the next read reuses it.
    let server_first = {
        let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
        match msg {
            BackendMessage::AuthSaslContinue { data } => data.to_vec(),
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                return Err(DriverError::Auth(fields.to_string()));
            }
            other => {
                return Err(DriverError::Protocol(format!(
                    "expected AuthSaslContinue, got: {other:?}"
                )));
            }
        }
    };
    scram.process_server_first(&server_first)?;
    let client_final = scram.client_final_message()?;
    self.write_buf.clear();
    proto::write_sasl_response(&mut self.write_buf, &client_final);
    self.flush_write().await?;
    let (msg_type, _) = self.read_message_buffered().await?;
    {
        let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
        match msg {
            BackendMessage::AuthSaslFinal { data } => {
                // Mutual authentication: the server must prove it knows the
                // password too.
                let data_owned = data.to_vec();
                scram.verify_server_final(&data_owned)?;
            }
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                return Err(DriverError::Auth(fields.to_string()));
            }
            other => {
                return Err(DriverError::Protocol(format!(
                    "expected AuthSaslFinal, got: {other:?}"
                )));
            }
        }
    }
    let (msg_type, _) = self.read_message_buffered().await?;
    let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
    match msg {
        BackendMessage::AuthOk => Ok(()),
        BackendMessage::ErrorResponse { data } => {
            let fields = proto::parse_error_response(data);
            Err(DriverError::Auth(fields.to_string()))
        }
        other => Err(DriverError::Protocol(format!(
            "expected AuthOk after SCRAM, got: {other:?}"
        ))),
    }
}
/// Parses and describes `sql` as a named server-side prepared statement and
/// caches it under `sql_hash`. Nothing is bound or executed.
///
/// Returns immediately if `sql_hash` is already cached; `last_used` is not
/// updated in that case. The statement name is `make_stmt_name(sql_hash)`.
/// Parse is sent with an empty parameter-type list, which leaves parameter
/// types to the server.
///
/// # Errors
/// Propagates I/O and protocol errors, and server errors raised while the
/// Parse/Describe responses are read.
pub async fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
    if self.stmts.contains_key(&sql_hash) {
        return Ok(());
    }
    let name = make_stmt_name(sql_hash);
    // Parse + Describe(statement) + Sync in one write. Sync makes the server
    // end with ReadyForQuery, which expect_ready presumably consumes below.
    self.write_buf.clear();
    proto::write_parse(&mut self.write_buf, &name, sql, &[]);
    proto::write_describe(&mut self.write_buf, b'S', &name);
    proto::write_sync(&mut self.write_buf);
    self.flush_write().await?;
    self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
        .await?;
    let columns = self.read_column_description().await?;
    self.expect_ready().await?;
    self.query_counter += 1;
    self.cache_stmt(
        sql_hash,
        StmtInfo {
            name,
            columns,
            last_used: self.query_counter,
            bind_template: None,
        },
    );
    Ok(())
}
/// Parses `sql` into the unnamed statement and describes it.
///
/// Returns its parameter type OIDs and result columns. Nothing is cached
/// and the statement cache is not consulted.
///
/// # Errors
/// - A server ErrorResponse: the connection is drained to ReadyForQuery,
///   then the error is returned as built by `make_server_error`.
/// - `DriverError::Protocol` for unexpected messages.
/// - I/O errors are propagated.
pub async fn prepare_describe(&mut self, sql: &str) -> Result<PrepareResult, DriverError> {
    self.write_buf.clear();
    // "" = unnamed statement; it is replaced by the next unnamed Parse.
    proto::write_parse(&mut self.write_buf, "", sql, &[]);
    proto::write_describe(&mut self.write_buf, b'S', "");
    proto::write_sync(&mut self.write_buf);
    self.flush_write().await?;
    self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
        .await?;
    let mut param_oids: Vec<u32> = Vec::new();
    let columns;
    // Describe(statement) yields ParameterDescription, then RowDescription or NoData.
    loop {
        let msg = self.read_one_message().await?;
        match msg {
            BackendMessage::ParameterDescription { data } => {
                param_oids = proto::parse_parameter_description(data)?;
            }
            BackendMessage::RowDescription { data } => {
                columns = proto::parse_row_description(data)?;
                break;
            }
            BackendMessage::NoData => {
                columns = Vec::new();
                break;
            }
            BackendMessage::NoticeResponse { .. } => {}
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                self.drain_to_ready().await?;
                return Err(self.make_server_error(fields));
            }
            other => {
                return Err(DriverError::Protocol(format!(
                    "expected ParameterDescription/RowDescription/NoData, got: {other:?}"
                )));
            }
        }
    }
    self.expect_ready().await?;
    Ok(PrepareResult {
        columns,
        param_oids,
    })
}
/// Runs `sql` with the simple-query protocol and collects every DataRow as
/// a [`SimpleRow`].
///
/// `sql` may contain several statements; rows from all of them go into one
/// `Vec`. Row descriptions, command tags, notices and ParameterStatus
/// messages are skipped. `tx_status` is updated from the final ReadyForQuery.
///
/// # Errors
/// - A server ErrorResponse: the connection is drained to ReadyForQuery,
///   then the error is returned.
/// - `DriverError::Protocol` on any other unexpected message.
pub async fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
    self.write_buf.clear();
    proto::write_simple_query(&mut self.write_buf, sql);
    self.flush_write().await?;
    let mut rows: Vec<SimpleRow> = Vec::new();
    loop {
        let msg = self.read_one_message().await?;
        match msg {
            BackendMessage::ReadyForQuery { status } => {
                self.tx_status = status;
                return Ok(rows);
            }
            BackendMessage::DataRow { data } => {
                rows.push(proto::parse_simple_data_row(data)?);
            }
            BackendMessage::RowDescription { .. }
            | BackendMessage::CommandComplete { .. }
            | BackendMessage::EmptyQuery
            | BackendMessage::NoticeResponse { .. } => {}
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                self.drain_to_ready().await?;
                return Err(self.make_server_error(fields));
            }
            // e.g. after `SET`, the server reports the changed parameter.
            BackendMessage::ParameterStatus { .. } => {}
            other => {
                return Err(DriverError::Protocol(format!(
                    "unexpected message during simple_query_rows: {other:?}"
                )));
            }
        }
    }
}
pub async fn query_streaming_start(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
chunk_size: i32,
) -> Result<(Arc<[ColumnDesc]>, bool), DriverError> {
self.write_buf.clear();
let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
self.query_counter += 1;
info.last_used = self.query_counter;
let can_use_template = info
.bind_template
.as_ref()
.is_some_and(|t| t.param_slots.len() == params.len());
if can_use_template {
let tmpl = info.bind_template.as_ref().unwrap();
self.write_buf
.extend_from_slice(&tmpl.bytes[..tmpl.bind_end]);
let mut template_ok = true;
for (i, param) in params.iter().enumerate() {
let (data_offset, old_len) = tmpl.param_slots[i];
if param.is_null() {
let len_offset = data_offset - 4;
self.write_buf[len_offset..len_offset + 4]
.copy_from_slice(&(-1i32).to_be_bytes());
} else if old_len >= 0 {
let end = data_offset + old_len as usize;
if !param.encode_at(&mut self.write_buf[data_offset..end]) {
template_ok = false;
break;
}
} else {
template_ok = false;
break;
}
}
if !template_ok {
self.write_buf.clear();
proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
info.bind_template = None;
}
} else {
proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
}
let cols = info.columns.clone();
if info.bind_template.is_none() && !self.write_buf.is_empty() {
info.bind_template = build_bind_template(&self.write_buf, params.len());
}
proto::write_execute(&mut self.write_buf, "", chunk_size);
proto::write_flush(&mut self.write_buf);
self.flush_write().await?;
cols
} else {
let name = make_stmt_name(sql_hash);
let param_oids: smallvec::SmallVec<[u32; 8]> =
params.iter().map(|p| p.type_oid()).collect();
proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
proto::write_describe(&mut self.write_buf, b'S', &name);
proto::write_bind_params(&mut self.write_buf, "", &name, params);
proto::write_execute(&mut self.write_buf, "", chunk_size);
proto::write_flush(&mut self.write_buf);
self.flush_write().await?;
self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
.await?;
let columns = self.read_column_description().await?;
self.query_counter += 1;
self.cache_stmt(
sql_hash,
StmtInfo {
name,
columns: columns.clone(),
last_used: self.query_counter,
bind_template: None,
},
);
columns
};
self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
.await?;
self.streaming_active = true;
Ok((columns, false))
}
/// Reads the rows of the current chunk into `arena` / `all_col_offsets`.
///
/// `all_col_offsets` is cleared first.
///
/// Returns:
/// - `Ok(true)` on PortalSuspended: more rows remain; request them with
///   `streaming_send_execute`.
/// - `Ok(false)` when the result is complete (CommandComplete or EmptyQuery).
///   In that case Sync is sent, ReadyForQuery is consumed and
///   `streaming_active` is cleared. Buffers are shrunk only after
///   CommandComplete.
///
/// # Errors
/// - A server ErrorResponse: Sync is sent (the portal was driven with Flush
///   only), the connection is drained to ReadyForQuery, `streaming_active`
///   is cleared, and the error is returned.
/// - Row-parse errors and unexpected messages return immediately without
///   resynchronizing the connection.
pub async fn streaming_next_chunk(
    &mut self,
    arena: &mut Arena,
    all_col_offsets: &mut Vec<(usize, i32)>,
) -> Result<bool, DriverError> {
    all_col_offsets.clear();
    loop {
        let msg = self.read_one_message().await?;
        match msg {
            BackendMessage::DataRow { data } => {
                parse_data_row_flat(data, arena, all_col_offsets)?;
            }
            BackendMessage::PortalSuspended => {
                return Ok(true);
            }
            BackendMessage::CommandComplete { .. } => {
                self.write_buf.clear();
                proto::write_sync(&mut self.write_buf);
                self.flush_write().await?;
                self.expect_ready().await?;
                self.shrink_buffers();
                self.streaming_active = false;
                return Ok(false);
            }
            BackendMessage::EmptyQuery => {
                self.write_buf.clear();
                proto::write_sync(&mut self.write_buf);
                self.flush_write().await?;
                self.expect_ready().await?;
                self.streaming_active = false;
                return Ok(false);
            }
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                // No Sync has been sent yet, so send one; otherwise the
                // server never emits ReadyForQuery.
                self.write_buf.clear();
                proto::write_sync(&mut self.write_buf);
                self.flush_write().await?;
                self.drain_to_ready().await?;
                self.streaming_active = false;
                return Err(self.make_server_error(fields));
            }
            BackendMessage::NoticeResponse { .. } => {}
            other => {
                return Err(DriverError::Protocol(format!(
                    "unexpected message during streaming: {other:?}"
                )));
            }
        }
    }
}
/// Asks the server for the next `chunk_size` rows of the open unnamed portal.
///
/// Sends Execute + Flush. No Sync is sent, so the portal stays open; read
/// the rows with `streaming_next_chunk`.
pub async fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
    let out = &mut self.write_buf;
    out.clear();
    proto::write_execute(out, "", chunk_size);
    proto::write_flush(out);
    self.flush_write().await
}
/// Sends Bind + Execute + Sync for `sql` (named `sql_hash`) and reads the
/// acknowledgements up to BindComplete.
///
/// On a cache miss, Parse + Describe are sent first and the statement is
/// cached.
///
/// On a cache hit, the stored Bind template (Bind + Execute + Sync bytes)
/// is reused when every new value can be patched in place. Otherwise a
/// fresh Bind is encoded and becomes the new template.
///
/// Returns the result columns when `need_columns` is true, else `None`.
/// If `skip_bind_complete` is true, BindComplete is left unread for the
/// caller to consume.
///
/// # Errors
/// - `DriverError::Protocol` if `params` exceeds `i16::MAX` entries.
/// - I/O, protocol and server errors are propagated.
async fn send_pipeline(
    &mut self,
    sql: &str,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
    need_columns: bool,
    skip_bind_complete: bool,
) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
    debug_assert_eq!(
        hash_sql(sql),
        sql_hash,
        "sql_hash mismatch: caller-provided hash does not match hash_sql(sql)"
    );
    // Bind encodes the parameter count as an Int16.
    if params.len() > i16::MAX as usize {
        return Err(DriverError::Protocol(format!(
            "parameter count {} exceeds maximum {} for PG wire protocol",
            params.len(),
            i16::MAX
        )));
    }
    self.write_buf.clear();
    let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
        self.query_counter += 1;
        info.last_used = self.query_counter;
        let can_use_template = info
            .bind_template
            .as_ref()
            .is_some_and(|t| t.param_slots.len() == params.len());
        // True when write_buf already ends with Execute + Sync (from the template).
        let mut has_exec_sync = false;
        if can_use_template {
            let tmpl = info.bind_template.as_ref().unwrap();
            self.write_buf.extend_from_slice(&tmpl.bytes);
            let mut template_ok = true;
            for (param, &(data_offset, old_len)) in params.iter().zip(&tmpl.param_slots) {
                if param.is_null() {
                    // A slot can be flipped to NULL in place only if it has no
                    // value bytes: old_len is 0, or negative because the slot was
                    // already NULL. Otherwise the old bytes would stay behind the
                    // -1 length and the server would misparse the rest of Bind.
                    if old_len > 0 {
                        template_ok = false;
                        break;
                    }
                    let len_offset = data_offset - 4;
                    self.write_buf[len_offset..len_offset + 4]
                        .copy_from_slice(&(-1i32).to_be_bytes());
                } else if old_len >= 0 {
                    let end = data_offset + old_len as usize;
                    if !param.encode_at(&mut self.write_buf[data_offset..end]) {
                        template_ok = false;
                        break;
                    }
                } else {
                    // The template slot was NULL, so there is no room for a value.
                    template_ok = false;
                    break;
                }
            }
            if template_ok {
                has_exec_sync = true;
            } else {
                self.write_buf.clear();
                proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
                info.bind_template = None;
            }
        } else {
            proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
        }
        let cols = if need_columns {
            Some(info.columns.clone())
        } else {
            None
        };
        // write_buf holds exactly one Bind here; snapshot it as the new template.
        if info.bind_template.is_none() && !self.write_buf.is_empty() {
            info.bind_template = build_bind_template(&self.write_buf, params.len());
        }
        if !has_exec_sync {
            self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
        }
        self.flush_write().await?;
        cols
    } else {
        let name = make_stmt_name(sql_hash);
        let param_oids: smallvec::SmallVec<[u32; 8]> =
            params.iter().map(|p| p.type_oid()).collect();
        proto::write_parse(&mut self.write_buf, &name, sql, &param_oids);
        proto::write_describe(&mut self.write_buf, b'S', &name);
        proto::write_bind_params(&mut self.write_buf, "", &name, params);
        self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
        self.flush_write().await?;
        self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
            .await?;
        let columns = self.read_column_description().await?;
        self.query_counter += 1;
        self.cache_stmt(
            sql_hash,
            StmtInfo {
                name,
                columns: columns.clone(),
                last_used: self.query_counter,
                bind_template: None,
            },
        );
        if need_columns { Some(columns) } else { None }
    };
    if !skip_bind_complete {
        self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
            .await?;
    }
    Ok(columns)
}
/// Executes `sql` with `params` and collects all result rows.
///
/// Row data is written into `arena`, and per-column `(offset, length)`
/// entries go into a flat `all_col_offsets` vector. The returned
/// [`QueryResult`] carries those offsets, the columns, and the affected-row
/// count parsed from the command tag.
///
/// # Errors
/// - A server ErrorResponse: `maybe_invalidate_stmt_cache` is consulted, the
///   connection is drained to ReadyForQuery, and the error is returned.
/// - `DriverError::Protocol` on unexpected messages.
///
/// NOTE(review): an error from `parse_data_row_flat` returns without
/// draining the remaining messages.
pub async fn query(
    &mut self,
    sql: &str,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
    arena: &mut Arena,
) -> Result<QueryResult, DriverError> {
    let columns = self
        .send_pipeline(sql, sql_hash, params, true, false)
        .await?
        .expect("send_pipeline(need_columns=true) must return Some");
    let num_cols = columns.len();
    // Pre-size for roughly 8 rows' worth of column entries.
    let mut all_col_offsets: Vec<(usize, i32)> = Vec::with_capacity(num_cols.max(1) * 8);
    let mut affected_rows: u64 = 0;
    loop {
        let msg = self.read_one_message().await?;
        match msg {
            BackendMessage::DataRow { data } => {
                parse_data_row_flat(data, arena, &mut all_col_offsets)?;
            }
            BackendMessage::CommandComplete { tag } => {
                affected_rows = proto::parse_command_tag(tag);
                break;
            }
            BackendMessage::EmptyQuery => {
                break;
            }
            BackendMessage::NoticeResponse { .. } => {
            }
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                self.drain_to_ready().await?;
                return Err(self.make_server_error(fields));
            }
            other => {
                return Err(DriverError::Protocol(format!(
                    "unexpected message during query: {other:?}"
                )));
            }
        }
    }
    self.expect_ready().await?;
    self.shrink_buffers();
    Ok(QueryResult {
        all_col_offsets,
        num_cols,
        columns,
        affected_rows,
    })
}
/// Reads a statement's Describe response and returns its result columns.
///
/// ParameterDescription and notices are skipped. NoData yields an empty
/// column list.
///
/// # Errors
/// - A server ErrorResponse: the connection is drained to ReadyForQuery,
///   then the error is returned.
/// - `DriverError::Protocol` on unexpected messages.
///
/// NOTE(review): draining presumably waits for ReadyForQuery, which only
/// arrives if the caller's batch ended with Sync. Callers that sent only
/// Flush may block here on a server error.
async fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
    loop {
        let msg = self.read_one_message().await?;
        match msg {
            BackendMessage::RowDescription { data } => {
                let cols = proto::parse_row_description(data)?;
                return Ok(cols.into());
            }
            // Parameter types are not needed here.
            BackendMessage::ParameterDescription { .. } => {
            }
            BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
            BackendMessage::NoticeResponse { .. } => {}
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                self.drain_to_ready().await?;
                return Err(self.make_server_error(fields));
            }
            other => {
                return Err(DriverError::Protocol(format!(
                    "expected RowDescription/NoData after Parse, got: {other:?}"
                )));
            }
        }
    }
}
/// Executes `sql` with `params` and returns the affected-row count.
///
/// The count comes from the command tag, or 0 for an empty query. Any rows
/// the statement produces are read and discarded.
///
/// # Errors
/// - A server ErrorResponse: `maybe_invalidate_stmt_cache` is consulted, the
///   connection is drained to ReadyForQuery, and the error is returned.
/// - `DriverError::Protocol` on unexpected messages.
pub async fn execute(
    &mut self,
    sql: &str,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
) -> Result<u64, DriverError> {
    let _ = self
        .send_pipeline(sql, sql_hash, params, false, false)
        .await?;
    let mut affected_rows: u64 = 0;
    loop {
        let msg = self.read_one_message().await?;
        match msg {
            // Rows (e.g. from RETURNING) are ignored by execute.
            BackendMessage::DataRow { .. } => {
            }
            BackendMessage::CommandComplete { tag } => {
                affected_rows = proto::parse_command_tag(tag);
                break;
            }
            BackendMessage::EmptyQuery => break,
            BackendMessage::NoticeResponse { .. } => {}
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                self.drain_to_ready().await?;
                return Err(self.make_server_error(fields));
            }
            other => {
                return Err(DriverError::Protocol(format!(
                    "unexpected message during execute: {other:?}"
                )));
            }
        }
    }
    self.expect_ready().await?;
    self.shrink_buffers();
    Ok(affected_rows)
}
/// Executes `sql` once per parameter set in a single round trip.
///
/// Sends N × (Bind + Execute) followed by one Sync, and returns each
/// execution's affected-row count in order. If the statement is not cached,
/// it is first prepared in a separate Parse/Describe/Sync round trip, using
/// the first set's parameter types. On a cache hit, the statement's
/// `last_used` is bumped so LRU eviction sees it as in use.
///
/// # Errors
/// - `DriverError::Protocol` if any parameter set exceeds `i16::MAX` entries.
///   This is checked before anything is sent.
/// - A server ErrorResponse: `maybe_invalidate_stmt_cache` is consulted, the
///   connection is drained to ReadyForQuery, and the error is returned. The
///   server skips the remaining executions up to Sync.
/// - I/O and protocol errors are propagated.
pub async fn execute_pipeline(
    &mut self,
    sql: &str,
    sql_hash: u64,
    param_sets: &[&[&(dyn Encode + Sync)]],
) -> Result<Vec<u64>, DriverError> {
    if param_sets.is_empty() {
        return Ok(Vec::new());
    }
    debug_assert_eq!(
        hash_sql(sql),
        sql_hash,
        "sql_hash mismatch: caller-provided hash does not match hash_sql(sql)"
    );
    // Validate every set up front so an oversized one cannot leave a
    // half-built pipeline or trigger a needless Parse round trip.
    if let Some(oversized) = param_sets.iter().find(|p| p.len() > i16::MAX as usize) {
        return Err(DriverError::Protocol(format!(
            "parameter count {} exceeds maximum {}",
            oversized.len(),
            i16::MAX
        )));
    }
    self.write_buf.clear();
    if let Some(info) = self.stmts.get_mut(&sql_hash) {
        self.query_counter += 1;
        info.last_used = self.query_counter;
    } else {
        let name = make_stmt_name(sql_hash);
        let param_oids: smallvec::SmallVec<[u32; 8]> =
            param_sets[0].iter().map(|p| p.type_oid()).collect();
        proto::write_parse(&mut self.write_buf, &name, sql, &param_oids);
        proto::write_describe(&mut self.write_buf, b'S', &name);
        proto::write_sync(&mut self.write_buf);
        self.flush_write().await?;
        self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
            .await?;
        let columns = self.read_column_description().await?;
        self.expect_ready().await?;
        self.query_counter += 1;
        self.cache_stmt(
            sql_hash,
            StmtInfo {
                name,
                columns,
                last_used: self.query_counter,
                bind_template: None,
            },
        );
        self.write_buf.clear();
    }
    let stmt_name = self
        .stmts
        .get(&sql_hash)
        .expect("BUG: stmt just cached but not found")
        .name
        .clone();
    let count = param_sets.len();
    for params in param_sets {
        proto::write_bind_params(&mut self.write_buf, "", &stmt_name, params);
        self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
    }
    // One Sync for the whole batch: the server answers with a single ReadyForQuery.
    self.write_buf.extend_from_slice(proto::SYNC_ONLY);
    self.flush_write().await?;
    let mut results = Vec::with_capacity(count);
    for _ in 0..count {
        self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
            .await?;
        let mut affected_rows: u64 = 0;
        loop {
            let msg = self.read_one_message().await?;
            match msg {
                BackendMessage::DataRow { .. } => {}
                BackendMessage::CommandComplete { tag } => {
                    affected_rows = proto::parse_command_tag(tag);
                    break;
                }
                BackendMessage::EmptyQuery => break,
                BackendMessage::NoticeResponse { .. } => {}
                BackendMessage::ErrorResponse { data } => {
                    let fields = proto::parse_error_response(data);
                    self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                    self.drain_to_ready().await?;
                    return Err(self.make_server_error(fields));
                }
                other => {
                    return Err(DriverError::Protocol(format!(
                        "unexpected message during execute_pipeline: {other:?}"
                    )));
                }
            }
        }
        results.push(affected_rows);
    }
    self.expect_ready().await?;
    self.shrink_buffers();
    Ok(results)
}
/// Ensures `sql` is prepared and cached under `sql_hash`, and returns its
/// server-side statement name.
///
/// On a cache hit, the entry's `last_used` is bumped. Callers follow up with
/// `write_deferred_bind_execute`, which panics if the statement was evicted
/// in between, so the entry must not look idle to LRU eviction. On a miss,
/// Parse + Describe + Sync are sent (with parameter types taken from
/// `params`) and the result is cached.
///
/// # Errors
/// - `DriverError::Protocol` if `params` exceeds `i16::MAX` entries.
/// - I/O, protocol and server errors are propagated.
pub(crate) async fn ensure_stmt_prepared(
    &mut self,
    sql: &str,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
) -> Result<Box<str>, DriverError> {
    if let Some(info) = self.stmts.get_mut(&sql_hash) {
        self.query_counter += 1;
        info.last_used = self.query_counter;
        return Ok(info.name.clone());
    }
    if params.len() > i16::MAX as usize {
        return Err(DriverError::Protocol(format!(
            "parameter count {} exceeds maximum {}",
            params.len(),
            i16::MAX
        )));
    }
    let name = make_stmt_name(sql_hash);
    let param_oids: smallvec::SmallVec<[u32; 8]> =
        params.iter().map(|p| p.type_oid()).collect();
    self.write_buf.clear();
    proto::write_parse(&mut self.write_buf, &name, sql, &param_oids);
    proto::write_describe(&mut self.write_buf, b'S', &name);
    proto::write_sync(&mut self.write_buf);
    self.flush_write().await?;
    self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
        .await?;
    let columns = self.read_column_description().await?;
    self.expect_ready().await?;
    self.query_counter += 1;
    self.cache_stmt(
        sql_hash,
        StmtInfo {
            name: name.clone(),
            columns,
            last_used: self.query_counter,
            bind_template: None,
        },
    );
    Ok(name)
}
/// Appends Bind + Execute for the cached statement `sql_hash` to `buf`.
///
/// The batch is sent later by `flush_deferred_pipeline`.
///
/// # Panics
/// Panics if `sql_hash` is not in the statement cache. Prepare it first,
/// e.g. with `ensure_stmt_prepared`.
pub(crate) fn write_deferred_bind_execute(
    &self,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
    buf: &mut Vec<u8>,
) {
    let info = self
        .stmts
        .get(&sql_hash)
        .expect("BUG: stmt just cached but not found");
    proto::write_bind_params(buf, "", &info.name, params);
    buf.extend_from_slice(proto::EXECUTE_ONLY);
}
/// Sends a batch of deferred Bind + Execute pairs built in `buf`, plus one
/// Sync, and returns each execution's affected-row count.
///
/// `count` must equal the number of pairs in `buf`. With `count == 0` the
/// buffer is cleared and nothing is sent. On success, and after a
/// successful write, `buf` is cleared.
///
/// # Errors
/// - A server ErrorResponse: the connection is drained to ReadyForQuery and
///   the error is returned. The statement cache is not invalidated here,
///   since no SQL hash is available.
/// - On a write error, `buf` is left as-is.
/// - `DriverError::Protocol` on unexpected messages.
pub(crate) async fn flush_deferred_pipeline(
    &mut self,
    buf: &mut Vec<u8>,
    count: usize,
) -> Result<Vec<u64>, DriverError> {
    if count == 0 {
        buf.clear();
        return Ok(Vec::new());
    }
    buf.extend_from_slice(proto::SYNC_ONLY);
    // Writes the caller's buffer directly instead of going through write_buf.
    self.stream.write_all(buf).await.map_err(DriverError::Io)?;
    self.stream.flush().await.map_err(DriverError::Io)?;
    buf.clear();
    let mut results = Vec::with_capacity(count);
    for _ in 0..count {
        self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
            .await?;
        let mut affected_rows: u64 = 0;
        loop {
            let msg = self.read_one_message().await?;
            match msg {
                BackendMessage::DataRow { .. } => {}
                BackendMessage::CommandComplete { tag } => {
                    affected_rows = proto::parse_command_tag(tag);
                    break;
                }
                BackendMessage::EmptyQuery => break,
                BackendMessage::NoticeResponse { .. } => {}
                BackendMessage::ErrorResponse { data } => {
                    let fields = proto::parse_error_response(data);
                    self.drain_to_ready().await?;
                    return Err(self.make_server_error(fields));
                }
                other => {
                    return Err(DriverError::Protocol(format!(
                        "unexpected message during flush_deferred_pipeline: {other:?}"
                    )));
                }
            }
        }
        results.push(affected_rows);
    }
    self.expect_ready().await?;
    self.shrink_buffers();
    Ok(results)
}
/// Executes `sql` (cached under `sql_hash`) with `params` and calls `f` once per
/// returned row, each decoded into a borrowed [`PgDataRow`].
///
/// # Errors
/// - The server error on `ErrorResponse`. On SQLSTATE 26000 the cached statement
///   is invalidated first.
/// - A protocol error for a malformed row or an unexpected message.
/// - The first error returned by `f`.
///
/// When a row fails to decode or `f` fails, the rest of the response is drained
/// up to `ReadyForQuery` before returning. This keeps the connection in sync for
/// the next query, just as the server-error path does.
pub async fn for_each<F>(
    &mut self,
    sql: &str,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
    mut f: F,
) -> Result<(), DriverError>
where
    F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
{
    let _ = self
        .send_pipeline(sql, sql_hash, params, false, false)
        .await?;
    loop {
        let msg = self.read_one_message().await?;
        match msg {
            BackendMessage::DataRow { data } => {
                let outcome = PgDataRow::new(data).and_then(&mut f);
                if let Err(err) = outcome {
                    // Do not leave unread rows / CommandComplete / ReadyForQuery
                    // in the stream for the next query to trip over.
                    self.drain_to_ready().await?;
                    return Err(err);
                }
            }
            BackendMessage::CommandComplete { .. } => break,
            BackendMessage::EmptyQuery => break,
            // ParameterStatus may arrive asynchronously; tolerate it like
            // expect_message / expect_ready do.
            BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                self.drain_to_ready().await?;
                return Err(self.make_server_error(fields));
            }
            other => {
                return Err(DriverError::Protocol(format!(
                    "unexpected message during for_each: {other:?}"
                )));
            }
        }
    }
    self.expect_ready().await?;
    self.shrink_buffers();
    Ok(())
}
/// Executes `sql` (cached under `sql_hash`) and passes each row's raw `DataRow`
/// payload to `f` without decoding it.
///
/// Fast path: after `send_pipeline`, messages are parsed straight out of
/// `self.stream_buf`, and DataRow payloads are handed to `f` as slices of that
/// buffer. A message larger than `stream_buf` falls back to `read_one_message`.
///
/// # Errors
/// - Server errors. On SQLSTATE 26000 the cached statement is invalidated.
/// - Protocol errors for invalid message lengths or unexpected messages.
/// - I/O errors, or EOF ("connection closed").
/// - The first error returned by `f`.
///
/// After a server error or an `f` error, the connection is drained to
/// `ReadyForQuery` before returning.
pub async fn for_each_raw<F>(
    &mut self,
    sql: &str,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
    mut f: F,
) -> Result<(), DriverError>
where
    F: FnMut(&[u8]) -> Result<(), DriverError>,
{
    let _ = self
        .send_pipeline(sql, sql_hash, params, false, true)
        .await?;

    // Phase 1: consume BindComplete ('2'). When a notice or parameter-status
    // message is fully buffered, skip it in place.
    loop {
        let avail = self.stream_buf_end - self.stream_buf_pos;
        if avail >= 5 {
            let bc_type = self.stream_buf[self.stream_buf_pos];
            match bc_type {
                b'2' => {
                    // BindComplete carries no payload: 1 type byte + 4 length bytes.
                    self.stream_buf_pos += 5;
                    break;
                }
                b'E' => {
                    let msg = self.read_one_message().await?;
                    if let BackendMessage::ErrorResponse { data } = msg {
                        let fields = proto::parse_error_response(data);
                        // A Bind against a statement the server no longer knows
                        // fails here. Invalidate the cache entry as phase 2 does.
                        self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                        self.drain_to_ready().await?;
                        return Err(self.make_server_error(fields));
                    }
                }
                b'N' | b'S' => {
                    let raw_len = i32::from_be_bytes([
                        self.stream_buf[self.stream_buf_pos + 1],
                        self.stream_buf[self.stream_buf_pos + 2],
                        self.stream_buf[self.stream_buf_pos + 3],
                        self.stream_buf[self.stream_buf_pos + 4],
                    ]);
                    // Without this check a negative length overflows
                    // `1 + raw_len as usize`. For -1 it wraps to a 0-byte skip,
                    // which loops forever in release builds.
                    if raw_len < 4 {
                        return Err(DriverError::Protocol(format!(
                            "invalid message length {raw_len} for type '{}'",
                            bc_type as char
                        )));
                    }
                    let total = 1 + raw_len as usize;
                    if avail >= total {
                        self.stream_buf_pos += total;
                        continue;
                    }
                    // Only partially buffered: let the generic reader skip it.
                    self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
                        .await?;
                    break;
                }
                _ => {
                    self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
                        .await?;
                    break;
                }
            }
        } else {
            // Not even a full header is buffered: compact, then read more.
            let remaining = self.stream_buf_end - self.stream_buf_pos;
            if remaining > 0 && self.stream_buf_pos > 0 {
                self.stream_buf
                    .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
            }
            self.stream_buf_pos = 0;
            self.stream_buf_end = remaining;
            let n = {
                let mut reader = StreamReader(&mut self.stream);
                use tokio::io::AsyncReadExt;
                reader
                    .read(&mut self.stream_buf[remaining..])
                    .await
                    .map_err(DriverError::Io)?
            };
            if n == 0 {
                return Err(DriverError::Io(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed",
                )));
            }
            self.stream_buf_end = remaining + n;
        }
    }

    // Phase 2: dispatch rows until CommandComplete ('C') or EmptyQuery ('I').
    'outer: loop {
        // Process every complete message currently buffered.
        loop {
            let avail = self.stream_buf_end - self.stream_buf_pos;
            if avail < 5 {
                break;
            }
            let msg_type = self.stream_buf[self.stream_buf_pos];
            let raw_len = i32::from_be_bytes([
                self.stream_buf[self.stream_buf_pos + 1],
                self.stream_buf[self.stream_buf_pos + 2],
                self.stream_buf[self.stream_buf_pos + 3],
                self.stream_buf[self.stream_buf_pos + 4],
            ]);
            if raw_len < 4 {
                return Err(DriverError::Protocol(format!(
                    "invalid message length {raw_len} for type '{}'",
                    msg_type as char
                )));
            }
            let payload_len = (raw_len - 4) as usize;
            let total_msg_len = 5 + payload_len;
            if avail < total_msg_len {
                if total_msg_len > self.stream_buf.len() {
                    // Can never fit in stream_buf: use the copying reader, which
                    // continues from stream_buf_pos (the header start).
                    let msg = self.read_one_message().await?;
                    match msg {
                        BackendMessage::DataRow { data } => {
                            let outcome = f(data);
                            if let Err(err) = outcome {
                                self.drain_to_ready().await?;
                                return Err(err);
                            }
                            continue;
                        }
                        BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
                            break 'outer;
                        }
                        BackendMessage::ErrorResponse { data } => {
                            let fields = proto::parse_error_response(data);
                            self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                            self.drain_to_ready().await?;
                            return Err(self.make_server_error(fields));
                        }
                        BackendMessage::NoticeResponse { .. } => continue,
                        other => {
                            return Err(DriverError::Protocol(format!(
                                "unexpected message during for_each_raw: {other:?}"
                            )));
                        }
                    }
                }
                // Fits in the buffer but is incomplete: go refill.
                break;
            }
            let payload_start = self.stream_buf_pos + 5;
            let payload_end = payload_start + payload_len;
            if msg_type == b'D' {
                let outcome = f(&self.stream_buf[payload_start..payload_end]);
                if let Err(err) = outcome {
                    // Step past this row, then resync to ReadyForQuery.
                    self.stream_buf_pos += total_msg_len;
                    self.drain_to_ready().await?;
                    return Err(err);
                }
            } else if msg_type == b'C' || msg_type == b'I' {
                self.stream_buf_pos += total_msg_len;
                break 'outer;
            } else {
                self.handle_non_datarow_async(msg_type, payload_start, payload_end, sql_hash)
                    .await?;
            }
            self.stream_buf_pos += total_msg_len;
        }
        // Compact any partial message to the front and read more.
        let remaining = self.stream_buf_end - self.stream_buf_pos;
        if remaining > 0 && self.stream_buf_pos > 0 {
            self.stream_buf
                .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
        }
        self.stream_buf_pos = 0;
        self.stream_buf_end = remaining;
        let n = {
            let mut reader = StreamReader(&mut self.stream);
            use tokio::io::AsyncReadExt;
            reader
                .read(&mut self.stream_buf[remaining..])
                .await
                .map_err(DriverError::Io)?
        };
        if n == 0 {
            return Err(DriverError::Io(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "connection closed",
            )));
        }
        self.stream_buf_end = remaining + n;
    }
    self.expect_ready().await?;
    self.shrink_buffers();
    Ok(())
}
/// Sends `sql` using the simple query protocol and discards every result it
/// produces, returning once `ReadyForQuery` arrives (recording its transaction
/// status).
///
/// # Errors
/// On `ErrorResponse` the connection is drained to `ReadyForQuery` and the server
/// error is returned. Any message not listed below is a protocol error.
pub async fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
    self.write_buf.clear();
    proto::write_simple_query(&mut self.write_buf, sql);
    self.flush_write().await?;

    loop {
        let reply = self.read_one_message().await?;
        match reply {
            BackendMessage::ReadyForQuery { status } => {
                self.tx_status = status;
                return Ok(());
            }
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                self.drain_to_ready().await?;
                return Err(self.make_server_error(fields));
            }
            // Result data, informational, and auth/session messages are ignored.
            BackendMessage::CommandComplete { .. }
            | BackendMessage::RowDescription { .. }
            | BackendMessage::DataRow { .. }
            | BackendMessage::EmptyQuery
            | BackendMessage::NoticeResponse { .. }
            | BackendMessage::ParameterStatus { .. }
            | BackendMessage::AuthOk
            | BackendMessage::AuthSaslFinal { .. }
            | BackendMessage::AuthSaslContinue { .. }
            | BackendMessage::AuthSasl { .. }
            | BackendMessage::AuthMd5 { .. }
            | BackendMessage::AuthCleartext
            | BackendMessage::BackendKeyData { .. } => continue,
            other => {
                return Err(DriverError::Protocol(format!(
                    "unexpected message during simple_query: {other:?}"
                )));
            }
        }
    }
}
/// Returns the next notification as `(channel, payload)`.
///
/// Notifications already buffered by earlier reads (the ones
/// `drain_notifications` would return) are delivered first, oldest first.
/// Otherwise a notification that arrived during a previous query would be
/// skipped while this call blocks on the socket.
///
/// Other messages received while waiting are ignored.
///
/// # Errors
/// - I/O and protocol errors from reading.
/// - An `ErrorResponse` is returned as a server error instead of being silently
///   dropped. On an otherwise idle connection this is normally a fatal error
///   sent before the server closes the connection.
pub async fn wait_for_notification(&mut self) -> Result<(String, String), DriverError> {
    if !self.pending_notifications.is_empty() {
        let oldest = self.pending_notifications.remove(0);
        return Ok((oldest.channel, oldest.payload));
    }
    loop {
        let (msg_type, _payload_len) = self.read_message_buffered().await?;
        let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
        match msg {
            BackendMessage::NotificationResponse {
                channel, payload, ..
            } => {
                return Ok((channel.to_owned(), payload.to_owned()));
            }
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                return Err(self.make_server_error(fields));
            }
            _ => continue,
        }
    }
}
/// Consumes the connection after trying to send a Terminate message.
///
/// This is best-effort: a failure to write or flush is ignored and `Ok(())` is
/// always returned.
pub async fn close(mut self) -> Result<(), DriverError> {
    let out = &mut self.write_buf;
    out.clear();
    proto::write_terminate(out);
    // The peer may already be gone; there is nothing useful to do on failure.
    let _ = self.flush_write().await;
    Ok(())
}
/// `true` when the recorded transaction status byte is `'I'` (idle, outside any
/// transaction block).
pub fn is_idle(&self) -> bool {
    matches!(self.tx_status, b'I')
}
/// `true` when the recorded transaction status byte is `'T'` (inside a
/// transaction block).
pub fn is_in_transaction(&self) -> bool {
    matches!(self.tx_status, b'T')
}
/// `true` when the recorded transaction status byte is `'E'` (inside a failed
/// transaction block).
pub fn is_in_failed_transaction(&self) -> bool {
    matches!(self.tx_status, b'E')
}
/// Marks the connection as used right now; `idle_duration` measures from here.
pub fn touch(&mut self) {
    let now = std::time::Instant::now();
    self.last_used = now;
}
/// Time elapsed since the connection was last marked as used (see `touch`).
pub fn idle_duration(&self) -> std::time::Duration {
    let since = self.last_used;
    since.elapsed()
}
/// Looks up a server-reported parameter by exact name. Returns the value of the
/// first matching entry, or `None` if no entry has that name.
pub fn parameter(&self, name: &str) -> Option<&str> {
    for (key, value) in self.params.iter() {
        if &**key == name {
            return Some(&**value);
        }
    }
    None
}
/// All stored server parameters as `(name, value)` pairs, in stored order.
pub fn server_params(&self) -> &[(Box<str>, Box<str>)] {
    &self.params[..]
}
/// Checks that the server's reported parameters are compatible with this driver.
///
/// Requires:
/// - `server_encoding` and `client_encoding` to be UTF-8. The comparison uppercases
///   the value and accepts `UTF8` or `UTF-8`.
/// - `integer_datetimes` to be exactly `on`.
///
/// A parameter that was not reported is not checked.
///
/// # Errors
/// `DriverError::Protocol` describing the first incompatible parameter.
fn validate_server_params(&self) -> Result<(), DriverError> {
    let is_utf8 = |value: &str| {
        let upper = value.to_uppercase();
        upper == "UTF8" || upper == "UTF-8"
    };

    if let Some(encoding) = self.parameter("server_encoding") {
        if !is_utf8(encoding) {
            return Err(DriverError::Protocol(format!(
                "server_encoding is '{encoding}', but bsql requires UTF-8. \
                 Set server encoding to UTF-8 in postgresql.conf or \
                 use CREATE DATABASE ... ENCODING 'UTF8'."
            )));
        }
    }

    if let Some(encoding) = self.parameter("client_encoding") {
        if !is_utf8(encoding) {
            return Err(DriverError::Protocol(format!(
                "client_encoding is '{encoding}', but bsql requires UTF-8. \
                 Check your connection or database configuration."
            )));
        }
    }

    match self.parameter("integer_datetimes") {
        Some(idt) if idt != "on" => Err(DriverError::Protocol(format!(
            "integer_datetimes is '{idt}', but bsql requires 'on'. \
             Our timestamp codec assumes integer-format timestamps \
             (microseconds since 2000-01-01). Float-format timestamps \
             would produce incorrect decode results."
        ))),
        _ => Ok(()),
    }
}
/// Server backend process id stored for this connection (used in cancel requests).
pub fn pid(&self) -> i32 {
    let backend_pid = self.pid;
    backend_pid
}
/// Secret key stored for this connection, sent together with `pid` in cancel requests.
pub fn secret_key(&self) -> i32 {
    let key = self.secret;
    key
}
/// Sends a cancel request for this session (`pid` + secret key) over a new plain
/// TCP connection to `config.host:config.port`.
///
/// Returns as soon as the request has been written and flushed. It does not wait
/// for the server to act on it.
///
/// # Errors
/// `DriverError::Io` if connecting, writing, or flushing fails.
pub async fn cancel(&self, config: &Config) -> Result<(), DriverError> {
    let mut request = Vec::with_capacity(16);
    proto::write_cancel_request(&mut request, self.pid, self.secret);

    let target = format!("{}:{}", config.host, config.port);
    let mut side_channel = TcpStream::connect(&target)
        .await
        .map_err(DriverError::Io)?;
    side_channel
        .write_all(&request)
        .await
        .map_err(DriverError::Io)?;
    side_channel.flush().await.map_err(DriverError::Io)?;
    drop(side_channel);
    Ok(())
}
/// Reports the connection's `streaming_active` flag.
pub fn is_streaming(&self) -> bool {
    let active = self.streaming_active;
    active
}
/// Removes and returns every buffered notification, oldest first, leaving the
/// buffer empty.
pub fn drain_notifications(&mut self) -> Vec<Notification> {
    let buffered = std::mem::take(&mut self.pending_notifications);
    buffered
}
/// Number of buffered notifications not yet drained.
pub fn pending_notification_count(&self) -> usize {
    let pending = &self.pending_notifications;
    pending.len()
}
/// Sets the prepared-statement cache limit used by `cache_stmt`.
///
/// NOTE(review): this only stores the limit. `cache_stmt` evicts at most one
/// entry per new insertion, so lowering the limit below the current cache size
/// never shrinks the cache. It stays at its current size instead.
pub fn set_max_stmt_cache_size(&mut self, size: usize) {
    let limit = size;
    self.max_stmt_cache_size = limit;
}
/// Number of prepared statements currently held in the cache.
pub fn stmt_cache_len(&self) -> usize {
    let cache = &self.stmts;
    cache.len()
}
/// Enables TCP keepalive on `tcp`: the first probe is sent after 60 s idle, then
/// probes repeat every 15 s.
///
/// # Errors
/// `DriverError::Io` if the socket option cannot be set.
fn set_keepalive(tcp: &TcpStream) -> Result<(), DriverError> {
    use std::time::Duration;
    let keepalive = socket2::TcpKeepalive::new()
        .with_time(Duration::from_secs(60))
        .with_interval(Duration::from_secs(15));
    socket2::SockRef::from(tcp)
        .set_tcp_keepalive(&keepalive)
        .map_err(DriverError::Io)
}
/// Instant recorded in the connection's `created_at` field.
pub fn created_at(&self) -> std::time::Instant {
    let born = self.created_at;
    born
}
/// Inserts (or replaces) the cache entry for `sql_hash`.
///
/// When a *new* hash is added and the cache is already at `max_stmt_cache_size`,
/// the least-recently-used entry is evicted first. A Close for the evicted
/// statement is appended to `write_buf`; it is not sent here.
///
/// NOTE(review): the queued Close is only delivered if the next write does not
/// clear `write_buf` first. Other visible paths (e.g. `simple_query`,
/// `shrink_buffers`) clear it, so confirm that the evicted statement is reliably
/// closed on the server.
fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
    let at_capacity = self.stmts.len() >= self.max_stmt_cache_size;
    let is_new_entry = !self.stmts.contains_key(&sql_hash);
    if at_capacity && is_new_entry {
        if let Some((_, evicted)) = self.stmts.evict_lru() {
            proto::write_close(&mut self.write_buf, b'S', &evicted.name);
        }
    }
    self.stmts.insert(sql_hash, info);
}
/// Queues a notification for later retrieval. Once 1024 notifications are
/// pending, new ones are silently dropped.
fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
    const MAX_PENDING: usize = 1024;
    if self.pending_notifications.len() >= MAX_PENDING {
        return;
    }
    let note = Notification {
        pid,
        channel: channel.to_string(),
        payload: payload.to_string(),
    };
    self.pending_notifications.push(note);
}
/// Releases oversized buffer memory, but only when `query_counter` is a multiple
/// of 64.
///
/// - `read_buf` with capacity above 64 KiB is cleared and shrunk toward 8 KiB.
/// - `write_buf` with capacity above 16 KiB is cleared and shrunk toward 8 KiB.
fn shrink_buffers(&mut self) {
    let due = self.query_counter % 64 == 0;
    if !due {
        return;
    }
    const TARGET_CAPACITY: usize = 8192;
    if self.read_buf.capacity() > 64 * 1024 {
        self.read_buf.clear();
        self.read_buf.shrink_to(TARGET_CAPACITY);
    }
    if self.write_buf.capacity() > 16 * 1024 {
        self.write_buf.clear();
        self.write_buf.shrink_to(TARGET_CAPACITY);
    }
}
/// Reads the next backend message, parsed from `read_buf`.
///
/// Asynchronous `NotificationResponse` messages ('A') are copied into the pending
/// notification buffer and skipped, so callers never see them.
async fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
    loop {
        let (kind, _) = self.read_message_buffered().await?;
        if kind != b'A' {
            return proto::parse_backend_message(kind, &self.read_buf);
        }
        // Copy out the notification before mutably borrowing `self` to buffer it.
        let owned = match proto::parse_backend_message(kind, &self.read_buf)? {
            BackendMessage::NotificationResponse {
                pid,
                channel,
                payload,
            } => Some((pid, channel.to_owned(), payload.to_owned())),
            _ => None,
        };
        match owned {
            Some((pid, channel, payload)) => self.buffer_notification(pid, &channel, &payload),
            None => return proto::parse_backend_message(kind, &self.read_buf),
        }
    }
}
/// Reads messages until one satisfies `pred`. Notices and parameter-status
/// messages are skipped.
///
/// # Errors
/// - On `ErrorResponse`, the connection is drained to `ReadyForQuery` and the
///   server error is returned.
/// - Any other message is a protocol error, returned without draining.
async fn expect_message(
    &mut self,
    pred: impl Fn(&BackendMessage<'_>) -> bool,
) -> Result<(), DriverError> {
    loop {
        let incoming = self.read_one_message().await?;
        if pred(&incoming) {
            break Ok(());
        }
        match incoming {
            BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {
                continue
            }
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                self.drain_to_ready().await?;
                break Err(self.make_server_error(fields));
            }
            unexpected => {
                break Err(DriverError::Protocol(format!(
                    "unexpected message while waiting for expected type: {unexpected:?}"
                )));
            }
        }
    }
}
/// Reads until `ReadyForQuery` and records its transaction status.
///
/// Every other message except `ErrorResponse` is skipped.
///
/// # Errors
/// On `ErrorResponse`, drains to `ReadyForQuery` and returns the server error.
async fn expect_ready(&mut self) -> Result<(), DriverError> {
    loop {
        let msg = self.read_one_message().await?;
        match msg {
            BackendMessage::ReadyForQuery { status } => {
                self.tx_status = status;
                return Ok(());
            }
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                self.drain_to_ready().await?;
                return Err(self.make_server_error(fields));
            }
            // Notices, parameter status, and anything else are ignored here.
            _ => continue,
        }
    }
}
/// Discards every message up to and including the next `ReadyForQuery`, then
/// records its transaction status.
///
/// # Errors
/// Only I/O or parse errors from reading.
async fn drain_to_ready(&mut self) -> Result<(), DriverError> {
    let status = loop {
        let msg = self.read_one_message().await?;
        if let BackendMessage::ReadyForQuery { status } = msg {
            break status;
        }
    };
    self.tx_status = status;
    Ok(())
}
/// Removes the cached statement for `sql_hash` when the server error's SQLSTATE
/// is `26000` (`invalid_sql_statement_name`).
///
/// Returns `true` whenever the code matched, whether or not an entry was
/// actually present.
fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
    const INVALID_SQL_STATEMENT_NAME: &str = "26000";
    let stale = &*fields.code == INVALID_SQL_STATEMENT_NAME;
    if stale {
        self.stmts.remove(&sql_hash);
    }
    stale
}
/// Converts parsed `ErrorResponse` fields into `DriverError::Server`, moving
/// owned strings into boxed strings.
#[cold]
#[inline(never)]
fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
    let proto::ErrorFields {
        code,
        message,
        detail,
        hint,
        position,
        ..
    } = fields;
    DriverError::Server {
        code,
        message: message.into_boxed_str(),
        detail: detail.map(String::into_boxed_str),
        hint: hint.map(String::into_boxed_str),
        position,
    }
}
/// Handles a fully buffered, non-DataRow message found by the `for_each_raw`
/// fast path. The payload is `stream_buf[payload_start..payload_end]`.
///
/// - 'E' (ErrorResponse): may invalidate the cached statement, drains to
///   `ReadyForQuery`, and returns the server error.
/// - 'A' (NotificationResponse): buffered for later retrieval.
/// - Any other type is ignored.
#[cold]
async fn handle_non_datarow_async(
    &mut self,
    msg_type: u8,
    payload_start: usize,
    payload_end: usize,
    sql_hash: u64,
) -> Result<(), DriverError> {
    if msg_type == b'E' {
        let fields = proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
        self.maybe_invalidate_stmt_cache(&fields, sql_hash);
        self.drain_to_ready().await?;
        return Err(self.make_server_error(fields));
    }
    if msg_type == b'A' {
        // Copy the strings out before mutably borrowing `self` to buffer them.
        let owned = match proto::parse_backend_message(
            msg_type,
            &self.stream_buf[payload_start..payload_end],
        )? {
            BackendMessage::NotificationResponse {
                pid,
                channel,
                payload,
            } => Some((pid, channel.to_owned(), payload.to_owned())),
            _ => None,
        };
        if let Some((pid, channel, payload)) = owned {
            self.buffer_notification(pid, &channel, &payload);
        }
    }
    Ok(())
}
/// Writes all of `write_buf` to the stream and flushes it. `write_buf` is not
/// cleared.
///
/// # Errors
/// `DriverError::Io` on write or flush failure.
async fn flush_write(&mut self) -> Result<(), DriverError> {
    self.stream
        .write_all(&self.write_buf)
        .await
        .map_err(DriverError::Io)?;
    self.stream.flush().await.map_err(DriverError::Io)
}
/// Reads one backend message. The 5-byte header (type + i32 length) is parsed,
/// and the payload replaces the contents of `read_buf`.
///
/// Returns `(message type byte, payload length)`.
///
/// # Errors
/// - Protocol error if the declared length is below 4 or above 128 MiB.
/// - I/O error (or EOF) from the underlying reads.
async fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
    const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;

    let mut header = [0u8; 5];
    buffered_read_exact(
        &mut self.stream,
        &mut self.stream_buf,
        &mut self.stream_buf_pos,
        &mut self.stream_buf_end,
        &mut header,
    )
    .await?;
    let [msg_type, l0, l1, l2, l3] = header;
    // The length field counts itself (4 bytes) but not the type byte.
    let len = i32::from_be_bytes([l0, l1, l2, l3]);
    if len < 4 {
        return Err(DriverError::Protocol(format!(
            "invalid message length {len} for type '{}'",
            msg_type as char
        )));
    }
    if len > MAX_MESSAGE_LEN {
        return Err(DriverError::Protocol(format!(
            "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
            msg_type as char
        )));
    }

    let payload_len = (len - 4) as usize;
    self.read_buf.clear();
    self.read_buf.resize(payload_len, 0);
    if payload_len != 0 {
        buffered_read_exact(
            &mut self.stream,
            &mut self.stream_buf,
            &mut self.stream_buf_pos,
            &mut self.stream_buf_end,
            &mut self.read_buf[..payload_len],
        )
        .await?;
    }
    Ok((msg_type, payload_len))
}
}
/// Fills `out` completely. Bytes are taken first from `buf[*pos..*end]`, and
/// `buf` is refilled from `stream` whenever it runs dry.
///
/// `buf`, `pos`, and `end` form a read-ahead buffer shared across calls. On
/// return, `*pos` points just past the consumed bytes.
///
/// # Errors
/// `DriverError::Io` on a read error, or `UnexpectedEof` ("connection closed")
/// if the stream ends before `out` is filled.
async fn buffered_read_exact(
    stream: &mut Stream,
    buf: &mut [u8],
    pos: &mut usize,
    end: &mut usize,
    out: &mut [u8],
) -> Result<(), DriverError> {
    let mut filled = 0;
    while filled < out.len() {
        let avail = *end - *pos;
        if avail > 0 {
            let take = avail.min(out.len() - filled);
            out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
            *pos += take;
            filled += take;
        } else {
            // Mark the buffer empty before reading. If the read fails, hits
            // EOF, or the future is dropped mid-await, `pos..end` must not
            // still describe already-consumed bytes. The old code reset only
            // `pos`, which made stale data reappear as buffered.
            *pos = 0;
            *end = 0;
            let n = {
                let mut reader = StreamReader(stream);
                use tokio::io::AsyncReadExt;
                reader.read(buf).await.map_err(DriverError::Io)?
            };
            if n == 0 {
                return Err(DriverError::Io(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed",
                )));
            }
            *end = n;
        }
    }
    Ok(())
}
/// Builds a reusable [`BindTemplate`] from a Bind message at the start of
/// `write_buf`.
///
/// The template's bytes are `write_buf` followed by `proto::EXECUTE_SYNC`.
/// `bind_end` is `write_buf.len()`. For each parameter, `param_slots` records the
/// offset of its value bytes and its wire length (`-1` for NULL).
///
/// Returns `None` when `write_buf` does not start with a well-formed Bind:
/// - the first byte is not 'B', or the buffer is too short;
/// - the portal or statement name is unterminated;
/// - the wire parameter count differs from `param_count`;
/// - any parameter value extends past the end of the buffer.
fn build_bind_template(write_buf: &[u8], param_count: usize) -> Option<BindTemplate> {
    if write_buf.is_empty() || write_buf[0] != b'B' {
        return None;
    }
    if write_buf.len() < 5 {
        return None;
    }
    // Skip the type byte and i32 length, then the NUL-terminated portal name and
    // the NUL-terminated statement name.
    let mut pos = 5;
    for _ in 0..2 {
        while pos < write_buf.len() && write_buf[pos] != 0 {
            pos += 1;
        }
        pos += 1;
    }
    if pos + 2 > write_buf.len() {
        return None;
    }
    let num_fmt_codes = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]);
    pos += 2;
    pos += num_fmt_codes.max(0) as usize * 2;
    if pos + 2 > write_buf.len() {
        return None;
    }
    // A negative i16 becomes a huge usize and simply fails the count comparison.
    let wire_param_count = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]) as usize;
    pos += 2;
    if wire_param_count != param_count {
        return None;
    }
    let mut param_slots = Vec::with_capacity(param_count);
    for _ in 0..param_count {
        if pos + 4 > write_buf.len() {
            return None;
        }
        let data_len = i32::from_be_bytes([
            write_buf[pos],
            write_buf[pos + 1],
            write_buf[pos + 2],
            write_buf[pos + 3],
        ]);
        pos += 4;
        if data_len < 0 {
            param_slots.push((pos, -1));
        } else {
            let data_end = pos + data_len as usize;
            // Without this check the last parameter's slot could point past
            // the Bind bytes; earlier ones were only caught by the next
            // length-prefix check.
            if data_end > write_buf.len() {
                return None;
            }
            param_slots.push((pos, data_len));
            pos = data_end;
        }
    }
    let bind_end = write_buf.len();
    let mut bytes = Vec::with_capacity(bind_end + proto::EXECUTE_SYNC.len());
    bytes.extend_from_slice(write_buf);
    bytes.extend_from_slice(proto::EXECUTE_SYNC);
    Some(BindTemplate {
        bytes,
        bind_end,
        param_slots,
    })
}
/// A fully materialized query result.
///
/// Column values live in an external [`Arena`]. This type stores only, for every
/// cell, the `(arena offset, length)` pair, with `num_cols` pairs per row and a
/// negative length marking NULL.
pub struct QueryResult {
    /// Row-major `(arena offset, length)` pairs; a negative length is NULL.
    all_col_offsets: Vec<(usize, i32)>,
    /// Number of columns per row.
    num_cols: usize,
    /// Column metadata, shared cheaply via `Arc`.
    columns: Arc<[ColumnDesc]>,
    /// Affected-row count reported for the statement.
    affected_rows: u64,
}
impl QueryResult {
    /// Assembles a result from its already-parsed components.
    pub fn from_parts(
        all_col_offsets: Vec<(usize, i32)>,
        num_cols: usize,
        columns: Arc<[ColumnDesc]>,
        affected_rows: u64,
    ) -> Self {
        Self {
            num_cols,
            columns,
            affected_rows,
            all_col_offsets,
        }
    }
    /// Number of rows. Always 0 when the result has no columns.
    pub fn len(&self) -> usize {
        match self.num_cols {
            0 => 0,
            width => self.all_col_offsets.len() / width,
        }
    }
    /// `true` when no cell offsets are stored.
    pub fn is_empty(&self) -> bool {
        self.all_col_offsets.is_empty()
    }
    /// Affected-row count reported for the statement.
    pub fn affected_rows(&self) -> u64 {
        self.affected_rows
    }
    /// Column metadata for every column.
    pub fn columns(&self) -> &[ColumnDesc] {
        &self.columns[..]
    }
    /// Borrowed view of row `idx`, resolving values through `arena`.
    ///
    /// # Panics
    /// Panics if `idx` is out of range.
    pub fn row<'a>(&'a self, idx: usize, arena: &'a Arena) -> Row<'a> {
        let width = self.num_cols;
        let first = idx * width;
        let col_offsets = &self.all_col_offsets[first..first + width];
        Row {
            arena,
            col_offsets,
            columns: &self.columns,
        }
    }
    /// Moves the cell offsets out, leaving this result with none.
    pub fn take_col_offsets(&mut self) -> Vec<(usize, i32)> {
        std::mem::take(&mut self.all_col_offsets)
    }
    /// Iterates over all rows in order. Each row's values are resolved through
    /// `arena`.
    pub fn rows<'a>(&'a self, arena: &'a Arena) -> impl Iterator<Item = Row<'a>> {
        let columns: &'a [ColumnDesc] = &self.columns;
        // `max(1)` keeps `chunks` from panicking on a zero-column result.
        let width = self.num_cols.max(1);
        self.all_col_offsets
            .chunks(width)
            .map(move |col_offsets| Row {
                arena,
                col_offsets,
                columns,
            })
    }
}
/// Borrowed view of a single row of a [`QueryResult`]. Values are resolved lazily
/// through the arena.
pub struct Row<'a> {
    arena: &'a Arena,
    /// `(arena offset, length)` per column; a negative length is NULL.
    col_offsets: &'a [(usize, i32)],
    columns: &'a [ColumnDesc],
}
impl<'a> Row<'a> {
    /// Raw bytes of column `idx`, or `None` if the column is NULL.
    ///
    /// # Panics
    /// Panics if `idx >= column_count()`.
    pub fn get_raw(&self, idx: usize) -> Option<&'a [u8]> {
        let (offset, len) = self.col_offsets[idx];
        (len >= 0).then(|| self.arena.get(offset, len as usize))
    }
    /// `true` if column `idx` is NULL. Panics if `idx` is out of range.
    pub fn is_null(&self, idx: usize) -> bool {
        let (_, len) = self.col_offsets[idx];
        len < 0
    }
    /// Number of columns in this row.
    pub fn column_count(&self) -> usize {
        self.col_offsets.len()
    }
    /// Decodes column `idx` as `bool`. Returns `None` if it is NULL or fails to decode.
    pub fn get_bool(&self, idx: usize) -> Option<bool> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_bool(raw).ok()
    }
    /// Decodes column `idx` as `i16`. Returns `None` if it is NULL or fails to decode.
    pub fn get_i16(&self, idx: usize) -> Option<i16> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_i16(raw).ok()
    }
    /// Decodes column `idx` as `i32`. Returns `None` if it is NULL or fails to decode.
    pub fn get_i32(&self, idx: usize) -> Option<i32> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_i32(raw).ok()
    }
    /// Decodes column `idx` as `i64`. Returns `None` if it is NULL or fails to decode.
    pub fn get_i64(&self, idx: usize) -> Option<i64> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_i64(raw).ok()
    }
    /// Decodes column `idx` as `f32`. Returns `None` if it is NULL or fails to decode.
    pub fn get_f32(&self, idx: usize) -> Option<f32> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_f32(raw).ok()
    }
    /// Decodes column `idx` as `f64`. Returns `None` if it is NULL or fails to decode.
    pub fn get_f64(&self, idx: usize) -> Option<f64> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_f64(raw).ok()
    }
    /// Decodes column `idx` as a borrowed `&str`. Returns `None` if it is NULL or
    /// fails to decode.
    pub fn get_str(&self, idx: usize) -> Option<&'a str> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_str(raw).ok()
    }
    /// Raw bytes of column `idx`; identical to `get_raw`.
    pub fn get_bytes(&self, idx: usize) -> Option<&'a [u8]> {
        self.get_raw(idx)
    }
    /// Name of column `idx`. Panics if `idx` is out of range.
    pub fn column_name(&self, idx: usize) -> &str {
        let column = &self.columns[idx];
        &column.name
    }
    /// Type OID of column `idx`. Panics if `idx` is out of range.
    pub fn column_type_oid(&self, idx: usize) -> u32 {
        let column = &self.columns[idx];
        column.type_oid
    }
}
/// Zero-copy view of one `DataRow` payload.
///
/// `new` validates the layout once and records, for each column, the offset of
/// its value inside `data` and its wire length (negative = NULL). The getters
/// can then slice without re-parsing.
pub struct PgDataRow<'a> {
    data: &'a [u8],
    offsets: smallvec::SmallVec<[(usize, i32); 16]>,
}
impl<'a> PgDataRow<'a> {
    /// Parses a `DataRow` payload. The payload is an `i16` column count followed,
    /// for each column, by an `i32` length and that many value bytes. A negative
    /// length means NULL.
    ///
    /// # Errors
    /// `DriverError::Protocol` if:
    /// - the payload is shorter than 2 bytes;
    /// - the column count is negative;
    /// - a length prefix is truncated;
    /// - a column's value extends past the end of `data`. This case was previously
    ///   accepted, which made `get_raw` panic later.
    pub fn new(data: &'a [u8]) -> Result<Self, DriverError> {
        if data.len() < 2 {
            return Err(DriverError::Protocol("DataRow too short".into()));
        }
        let num_cols = i16::from_be_bytes([data[0], data[1]]);
        if num_cols < 0 {
            return Err(DriverError::Protocol(
                "DataRow: negative column count".into(),
            ));
        }
        let num_cols = num_cols as usize;
        let mut offsets = smallvec::SmallVec::<[(usize, i32); 16]>::with_capacity(num_cols);
        let mut pos = 2usize;
        for _ in 0..num_cols {
            if pos + 4 > data.len() {
                return Err(DriverError::Protocol("DataRow truncated".into()));
            }
            let col_len =
                i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
            pos += 4;
            offsets.push((pos, col_len));
            if col_len > 0 {
                pos += col_len as usize;
                // Reject here so `get_raw` can never slice past the payload.
                if pos > data.len() {
                    return Err(DriverError::Protocol(
                        "DataRow column data truncated".into(),
                    ));
                }
            }
        }
        Ok(Self { data, offsets })
    }
    /// Raw bytes of column `idx`, or `None` if NULL.
    ///
    /// # Panics
    /// Panics if `idx >= column_count()`.
    #[inline]
    pub fn get_raw(&self, idx: usize) -> Option<&'a [u8]> {
        let (offset, len) = self.offsets[idx];
        if len < 0 {
            None
        } else {
            Some(&self.data[offset..offset + len as usize])
        }
    }
    /// `true` if column `idx` is NULL. Panics if `idx` is out of range.
    #[inline]
    pub fn is_null(&self, idx: usize) -> bool {
        self.offsets[idx].1 < 0
    }
    /// Number of columns in the row.
    #[inline]
    pub fn column_count(&self) -> usize {
        self.offsets.len()
    }
    /// Decodes column `idx` as `bool`. Returns `None` if it is NULL or fails to decode.
    #[inline]
    pub fn get_bool(&self, idx: usize) -> Option<bool> {
        self.get_raw(idx)
            .and_then(|data| crate::codec::decode_bool(data).ok())
    }
    /// Decodes column `idx` as `i16`. Returns `None` if it is NULL or fails to decode.
    #[inline]
    pub fn get_i16(&self, idx: usize) -> Option<i16> {
        self.get_raw(idx)
            .and_then(|data| crate::codec::decode_i16(data).ok())
    }
    /// Decodes column `idx` as `i32`. Returns `None` if it is NULL or fails to decode.
    #[inline]
    pub fn get_i32(&self, idx: usize) -> Option<i32> {
        self.get_raw(idx)
            .and_then(|data| crate::codec::decode_i32(data).ok())
    }
    /// Decodes column `idx` as `i64`. Returns `None` if it is NULL or fails to decode.
    #[inline]
    pub fn get_i64(&self, idx: usize) -> Option<i64> {
        self.get_raw(idx)
            .and_then(|data| crate::codec::decode_i64(data).ok())
    }
    /// Decodes column `idx` as `f32`. Returns `None` if it is NULL or fails to decode.
    #[inline]
    pub fn get_f32(&self, idx: usize) -> Option<f32> {
        self.get_raw(idx)
            .and_then(|data| crate::codec::decode_f32(data).ok())
    }
    /// Decodes column `idx` as `f64`. Returns `None` if it is NULL or fails to decode.
    #[inline]
    pub fn get_f64(&self, idx: usize) -> Option<f64> {
        self.get_raw(idx)
            .and_then(|data| crate::codec::decode_f64(data).ok())
    }
    /// Decodes column `idx` as a borrowed `&str`. Returns `None` if it is NULL or
    /// fails to decode.
    #[inline]
    pub fn get_str(&self, idx: usize) -> Option<&'a str> {
        self.get_raw(idx)
            .and_then(|data| crate::codec::decode_str(data).ok())
    }
    /// Raw bytes of column `idx`; identical to `get_raw`.
    #[inline]
    pub fn get_bytes(&self, idx: usize) -> Option<&'a [u8]> {
        self.get_raw(idx)
    }
}
/// Parses a `DataRow` payload. Each non-NULL column value is copied into `arena`,
/// and one `(arena offset, length)` pair per column is appended to `out`. NULL
/// columns are recorded as `(0, -1)`.
///
/// # Errors
/// `DriverError::Protocol` if:
/// - the payload is shorter than 2 bytes;
/// - the column count is negative;
/// - a length prefix is truncated;
/// - a column's data is truncated.
///
/// Pairs appended to `out` (and bytes copied into `arena`) before the error was
/// found are left in place.
fn parse_data_row_flat(
    data: &[u8],
    arena: &mut Arena,
    out: &mut Vec<(usize, i32)>,
) -> Result<(), DriverError> {
    let (count_bytes, mut rest) = match data {
        [hi, lo, tail @ ..] => ([*hi, *lo], tail),
        _ => return Err(DriverError::Protocol("DataRow too short".into())),
    };
    let column_count = i16::from_be_bytes(count_bytes);
    if column_count < 0 {
        return Err(DriverError::Protocol(
            "DataRow: negative column count".into(),
        ));
    }
    out.reserve(column_count as usize);
    for _ in 0..column_count {
        if rest.len() < 4 {
            return Err(DriverError::Protocol("DataRow truncated".into()));
        }
        let (len_prefix, after_prefix) = rest.split_at(4);
        let col_len =
            i32::from_be_bytes([len_prefix[0], len_prefix[1], len_prefix[2], len_prefix[3]]);
        rest = after_prefix;
        if col_len < 0 {
            out.push((0, -1));
            continue;
        }
        let width = col_len as usize;
        if rest.len() < width {
            return Err(DriverError::Protocol(
                "DataRow column data truncated".into(),
            ));
        }
        let (value, remainder) = rest.split_at(width);
        let offset = arena.alloc_copy(value);
        out.push((offset, col_len));
        rest = remainder;
    }
    Ok(())
}
/// Hashes SQL text with a default-constructed `RapidHasher`, feeding it through
/// `str`'s `Hash` impl. This is the key used for the prepared-statement cache.
///
/// The result is deterministic for identical input (see the
/// `hash_sql_deterministic` test).
pub fn hash_sql(sql: &str) -> u64 {
    use std::hash::{Hash, Hasher};
    let mut state = RapidHasher::default();
    Hash::hash(sql, &mut state);
    state.finish()
}
#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
use super::*;
#[test]
fn config_parse_full_url() {
let cfg = Config::from_url("postgres://user:pass@localhost:5432/mydb").unwrap();
assert_eq!(cfg.user, "user");
assert_eq!(cfg.password, "pass");
assert_eq!(cfg.host, "localhost");
assert_eq!(cfg.port, 5432);
assert_eq!(cfg.database, "mydb");
}
#[test]
fn config_parse_default_port() {
let cfg = Config::from_url("postgres://user:pass@localhost/mydb").unwrap();
assert_eq!(cfg.port, 5432);
}
#[test]
fn config_parse_no_password() {
let cfg = Config::from_url("postgres://user@localhost/mydb").unwrap();
assert_eq!(cfg.user, "user");
assert_eq!(cfg.password, "");
}
#[test]
fn config_parse_empty_database() {
let cfg = Config::from_url("postgres://user:pass@localhost").unwrap();
assert_eq!(cfg.database, "user");
}
#[test]
fn config_parse_sslmode() {
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
assert_eq!(cfg.ssl, SslMode::Require);
}
#[test]
fn config_parse_percent_encoding() {
let cfg = Config::from_url("postgres://user%40domain:p%40ss@localhost/db").unwrap();
assert_eq!(cfg.user, "user@domain");
assert_eq!(cfg.password, "p@ss");
}
#[test]
fn config_rejects_bad_scheme() {
let result = Config::from_url("mysql://user:pass@localhost/db");
assert!(result.is_err());
}
#[test]
fn config_rejects_unknown_sslmode() {
let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=requre");
assert!(result.is_err(), "typo 'requre' should be rejected");
let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=REQUIRE");
assert!(result.is_err(), "uppercase should be rejected");
let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=bogus");
assert!(result.is_err(), "bogus value should be rejected");
}
#[test]
fn config_accepts_valid_sslmodes() {
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=disable").unwrap();
assert_eq!(cfg.ssl, SslMode::Disable);
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=prefer").unwrap();
assert_eq!(cfg.ssl, SslMode::Prefer);
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
assert_eq!(cfg.ssl, SslMode::Require);
}
#[test]
fn stmt_cache_basic_ops() {
let mut cache = StmtCache::default();
assert_eq!(cache.len(), 0);
assert!(!cache.contains_key(&42));
assert!(cache.get(&42).is_none());
assert!(cache.get_mut(&42).is_none());
assert!(cache.remove(&42).is_none());
}
#[test]
fn stmt_name_format() {
let name = make_stmt_name(0);
assert_eq!(&*name, "s_0000000000000000");
let name = make_stmt_name(0xDEADBEEF12345678);
assert_eq!(&*name, "s_deadbeef12345678");
let name = make_stmt_name(u64::MAX);
assert_eq!(&*name, "s_ffffffffffffffff");
}
#[test]
fn hash_sql_deterministic() {
let h1 = hash_sql("SELECT 1");
let h2 = hash_sql("SELECT 1");
assert_eq!(h1, h2);
}
#[test]
fn hash_sql_different_queries() {
let h1 = hash_sql("SELECT 1");
let h2 = hash_sql("SELECT 2");
assert_ne!(h1, h2);
}
#[test]
fn data_row_parsing() {
let mut arena = Arena::new();
let mut out = Vec::new();
let mut data = Vec::new();
data.extend_from_slice(&2i16.to_be_bytes());
data.extend_from_slice(&4i32.to_be_bytes()); data.extend_from_slice(&42i32.to_be_bytes());
data.extend_from_slice(&(-1i32).to_be_bytes());
parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
assert_eq!(out.len(), 2);
assert_eq!(out[0].1, 4);
assert_eq!(out[1].1, -1);
}
#[test]
fn data_row_empty() {
let mut arena = Arena::new();
let mut out = Vec::new();
let data = 0i16.to_be_bytes();
parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
assert_eq!(out.len(), 0);
}
#[test]
fn query_result_empty() {
let result = QueryResult {
all_col_offsets: vec![],
num_cols: 0,
columns: Arc::from(Vec::new()),
affected_rows: 0,
};
assert!(result.is_empty());
assert_eq!(result.len(), 0);
}
#[test]
fn url_decode_works() {
assert_eq!(url_decode("hello%20world").unwrap(), "hello world");
assert_eq!(url_decode("no%20escape").unwrap(), "no escape");
assert_eq!(url_decode("plain").unwrap(), "plain");
assert_eq!(url_decode("a%40b").unwrap(), "a@b");
}
#[test]
fn url_decode_malformed_percent_trailing() {
let result = url_decode("abc%2");
assert!(result.is_err(), "truncated %2 should error");
}
#[test]
fn url_decode_malformed_percent_no_digits() {
let result = url_decode("abc%");
assert!(result.is_err(), "bare % at end should error");
}
#[test]
fn url_decode_invalid_hex_digit() {
let result = url_decode("abc%GG");
assert!(result.is_err(), "%GG should error");
}
#[test]
fn url_decode_invalid_hex_second_digit() {
let result = url_decode("abc%2Z");
assert!(result.is_err(), "%2Z should error");
}
#[test]
fn url_decode_invalid_utf8_percent() {
let result = url_decode("%80%81");
assert!(result.is_err(), "invalid UTF-8 bytes should error");
}
#[test]
fn url_decode_percent_everywhere() {
assert_eq!(url_decode("%41%42%43").unwrap(), "ABC");
assert_eq!(url_decode("%61").unwrap(), "a");
assert_eq!(url_decode("x%2Fy%2Fz").unwrap(), "x/y/z");
}
#[test]
fn url_decode_bare_percent_middle() {
assert!(url_decode("a%b").is_err(), "bare % in middle should error");
}
#[test]
fn url_decode_multibyte_utf8() {
let result = url_decode("caf%C3%A9").unwrap();
assert_eq!(result, "caf\u{00e9}"); }
#[test]
fn config_parse_postgresql_scheme() {
let cfg = Config::from_url("postgresql://user:pass@localhost:5432/mydb").unwrap();
assert_eq!(cfg.user, "user");
assert_eq!(cfg.password, "pass");
assert_eq!(cfg.host, "localhost");
assert_eq!(cfg.port, 5432);
assert_eq!(cfg.database, "mydb");
}
#[test]
fn config_parse_no_password_standalone() {
let cfg = Config::from_url("postgres://admin@db.example.com/myapp").unwrap();
assert_eq!(cfg.user, "admin");
assert_eq!(cfg.password, "");
assert_eq!(cfg.host, "db.example.com");
assert_eq!(cfg.database, "myapp");
}
#[test]
fn config_empty_database_falls_back_to_user() {
let cfg = Config::from_url("postgres://testuser:pass@localhost").unwrap();
assert_eq!(cfg.database, "testuser");
}
#[test]
fn config_unknown_sslmode_error() {
let result = Config::from_url("postgres://u:p@h/d?sslmode=verify-full");
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("unknown sslmode"),
"should describe unknown sslmode: {err}"
);
}
#[test]
fn config_multiple_query_params() {
let cfg = Config::from_url(
"postgres://user:pass@localhost/db?sslmode=disable&statement_timeout=60",
)
.unwrap();
assert_eq!(cfg.ssl, SslMode::Disable);
assert_eq!(cfg.statement_timeout_secs, 60);
}
#[test]
fn url_decode_invalid_percent_zz() {
let result = url_decode("abc%ZZ");
assert!(result.is_err(), "%ZZ should error");
}
#[test]
fn url_decode_truncated_percent_trailing() {
let result = url_decode("abc%");
assert!(result.is_err(), "trailing % should error");
}
#[test]
fn url_decode_invalid_utf8() {
let result = url_decode("%80");
assert!(result.is_err(), "invalid UTF-8 should error");
}
#[cfg(not(feature = "tls"))]
#[test]
fn config_sslmode_require_without_tls_feature() {
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
assert_eq!(cfg.ssl, SslMode::Require);
}
#[test]
fn stmt_name_format_verification() {
let name = make_stmt_name(0xDEADBEEFCAFEBABE);
assert!(name.starts_with("s_"), "must start with s_");
assert_eq!(name.len(), 18, "s_ (2) + 16 hex = 18");
assert!(
name[2..].chars().all(|c| c.is_ascii_hexdigit()),
"remaining chars must be hex: {}",
&*name
);
}
#[test]
fn stmt_name_zero() {
let name = make_stmt_name(0);
assert_eq!(&*name, "s_0000000000000000");
}
#[test]
fn stmt_name_max() {
let name = make_stmt_name(u64::MAX);
assert_eq!(&*name, "s_ffffffffffffffff");
}
#[test]
fn config_validate_empty_host() {
let cfg = Config {
host: String::new(),
port: 5432,
user: "user".into(),
password: "pass".into(),
database: "db".into(),
ssl: SslMode::Disable,
statement_timeout_secs: 30,
};
assert!(cfg.validate().is_err());
}
#[test]
fn config_validate_empty_user() {
let cfg = Config {
host: "localhost".into(),
port: 5432,
user: String::new(),
password: "pass".into(),
database: "db".into(),
ssl: SslMode::Disable,
statement_timeout_secs: 30,
};
assert!(cfg.validate().is_err());
}
#[test]
fn config_validate_empty_database() {
let cfg = Config {
host: "localhost".into(),
port: 5432,
user: "user".into(),
password: "pass".into(),
database: String::new(),
ssl: SslMode::Disable,
statement_timeout_secs: 30,
};
assert!(cfg.validate().is_err());
}
#[test]
fn config_missing_at_sign() {
let result = Config::from_url("postgres://userpasslocalhost/db");
assert!(result.is_err());
}
#[test]
fn config_custom_port() {
let cfg = Config::from_url("postgres://user:pass@localhost:5433/db").unwrap();
assert_eq!(cfg.port, 5433);
}
/// A non-numeric port is a parse error rather than a silent fallback.
#[test]
fn config_invalid_port() {
    let url = "postgres://user:pass@localhost:notaport/db";
    assert!(Config::from_url(url).is_err());
}
/// Fields supplied to a `Notification` literal read back unchanged.
#[test]
fn notification_struct_fields() {
    let Notification {
        pid,
        channel,
        payload,
    } = Notification {
        pid: 42,
        channel: "test_chan".to_owned(),
        payload: "hello".to_owned(),
    };
    assert_eq!(pid, 42);
    assert_eq!(channel, "test_chan");
    assert_eq!(payload, "hello");
}
/// Cloning a `Notification` must carry over every field and leave the source usable.
#[test]
fn notification_clone() {
    let n = Notification {
        pid: 1,
        channel: "c".to_owned(),
        payload: "p".to_owned(),
    };
    let n2 = n.clone();
    assert_eq!(n2.pid, 1);
    assert_eq!(n2.channel, "c");
    // Check the payload as well, so a Clone impl that dropped it would be caught.
    assert_eq!(n2.payload, "p");
    // The original is still intact after cloning.
    assert_eq!(n.payload, "p");
}
/// The `Debug` rendering of a `Notification` mentions the type name.
#[test]
fn notification_debug() {
    let rendered = format!(
        "{:?}",
        Notification {
            pid: 1,
            channel: "c".to_owned(),
            payload: "p".to_owned(),
        }
    );
    assert!(rendered.contains("Notification"));
}
/// `StmtInfo` exposes the `last_used` counter used by the cache's LRU eviction.
#[test]
fn stmt_info_has_last_used_counter() {
    let info = StmtInfo {
        last_used: 42,
        name: "s_test".into(),
        bind_template: None,
        columns: Arc::from(Vec::new()),
    };
    let stamp = info.last_used;
    assert_eq!(stamp, 42);
}
/// Encodes `columns` as the body of a PostgreSQL `DataRow` message.
///
/// The layout is a big-endian `i16` column count, then one entry per column:
/// a big-endian `i32` byte length followed by that many raw bytes. A length of
/// `-1` with no following bytes encodes SQL NULL (`None`).
///
/// # Panics
/// Panics if there are more than `i16::MAX` columns or any column is longer than
/// `i32::MAX` bytes. Those cannot be represented on the wire, and a silent
/// truncating cast would produce a malformed row.
fn make_data_row(columns: &[Option<&[u8]>]) -> Vec<u8> {
    use std::convert::TryFrom;

    // Header plus a 4-byte length prefix per column plus all payload bytes.
    let capacity = 2 + columns
        .iter()
        .map(|col| 4 + col.map_or(0, |data| data.len()))
        .sum::<usize>();
    let mut buf = Vec::with_capacity(capacity);

    let count = i16::try_from(columns.len()).expect("DataRow column count must fit in i16");
    buf.extend_from_slice(&count.to_be_bytes());
    for col in columns {
        match col {
            Some(data) => {
                let len =
                    i32::try_from(data.len()).expect("DataRow column length must fit in i32");
                buf.extend_from_slice(&len.to_be_bytes());
                buf.extend_from_slice(data);
            }
            None => buf.extend_from_slice(&(-1i32).to_be_bytes()),
        }
    }
    buf
}
/// A single 4-byte big-endian column decodes through `get_i32`.
#[test]
fn pg_data_row_get_i32() {
    let encoded = 42i32.to_be_bytes();
    let data = make_data_row(&[Some(&encoded)]);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.column_count(), 1);
    assert_eq!(row.get_i32(0), Some(42));
}
/// An 8-byte big-endian column decodes through `get_i64`.
#[test]
fn pg_data_row_get_i64() {
    let encoded = 12345i64.to_be_bytes();
    let data = make_data_row(&[Some(&encoded)]);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.get_i64(0), Some(12345));
}
/// ASCII column bytes come back as the same `&str`.
#[test]
fn pg_data_row_get_str() {
    let encoded = make_data_row(&[Some(b"hello")]);
    let row = PgDataRow::new(&encoded).unwrap();
    let text = row.get_str(0);
    assert_eq!(text, Some("hello"));
}
/// Arbitrary binary column bytes round-trip through `get_bytes`.
#[test]
fn pg_data_row_get_bytes() {
    let payload: &[u8] = &[0xDE, 0xAD, 0xBE, 0xEF];
    let data = make_data_row(&[Some(payload)]);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.get_bytes(0), Some(payload));
}
/// A one-byte column decodes as `true` for 1 and `false` for 0.
#[test]
fn pg_data_row_get_bool() {
    for (byte, expected) in [(1u8, true), (0u8, false)] {
        let data = make_data_row(&[Some(&[byte])]);
        let row = PgDataRow::new(&data).unwrap();
        assert_eq!(row.get_bool(0), Some(expected));
    }
}
/// An 8-byte big-endian IEEE-754 column decodes through `get_f64`.
///
/// Uses 2.75 rather than 3.14. Clippy's deny-by-default `approx_constant` lint
/// rejects literals that look like `PI`, which breaks `cargo clippy --all-targets`.
#[test]
fn pg_data_row_get_f64() {
    let data = make_data_row(&[Some(&2.75f64.to_be_bytes())]);
    let row = PgDataRow::new(&data).unwrap();
    assert!((row.get_f64(0).unwrap() - 2.75).abs() < 1e-10);
}
/// A NULL column (length -1) is reported as null, and the typed getters tried here yield `None`.
#[test]
fn pg_data_row_null_column() {
    let row_bytes = make_data_row(&[None]);
    let row = PgDataRow::new(&row_bytes).unwrap();
    assert!(row.is_null(0));
    assert!(row.get_i32(0).is_none());
    assert!(row.get_str(0).is_none());
}
/// A five-column row with mixed types decodes each column independently.
///
/// The `f64` column uses 2.75 rather than 3.14. Clippy's deny-by-default
/// `approx_constant` lint rejects `PI` look-alikes.
#[test]
fn pg_data_row_multiple_columns() {
    let data = make_data_row(&[
        Some(&42i32.to_be_bytes()),
        Some(b"alice"),
        Some(b"alice@example.com"),
        Some(&[1u8]),
        Some(&2.75f64.to_be_bytes()),
    ]);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.column_count(), 5);
    assert_eq!(row.get_i32(0), Some(42));
    assert_eq!(row.get_str(1), Some("alice"));
    assert_eq!(row.get_str(2), Some("alice@example.com"));
    assert_eq!(row.get_bool(3), Some(true));
    assert!((row.get_f64(4).unwrap() - 2.75).abs() < 1e-10);
}
/// A NULL between two populated columns does not disturb its neighbours.
#[test]
fn pg_data_row_mixed_null() {
    let data = make_data_row(&[Some(&42i32.to_be_bytes()), None, Some(b"text")]);
    let row = PgDataRow::new(&data).unwrap();
    // Middle column is the NULL.
    assert!(row.is_null(1));
    assert_eq!(row.get_str(1), None);
    // Columns on either side still decode.
    assert_eq!(row.get_i32(0), Some(42));
    assert_eq!(row.get_str(2), Some("text"));
}
/// A row whose header declares zero columns parses and reports no columns.
#[test]
fn pg_data_row_empty() {
    let header_only = make_data_row(&[]);
    let row = PgDataRow::new(&header_only).unwrap();
    assert_eq!(row.column_count(), 0);
}
/// A buffer shorter than the 2-byte column-count header must be rejected.
#[test]
fn pg_data_row_too_short() {
    // One byte cannot hold the big-endian i16 column count. A plain array is
    // used because the data is only borrowed as a slice (clippy `useless_vec`).
    let data = [0u8];
    assert!(PgDataRow::new(&data).is_err());
}
/// The header promises two columns but only one follows, so parsing must fail.
#[test]
fn pg_data_row_truncated() {
    let data = [
        &2i16.to_be_bytes()[..],
        &4i32.to_be_bytes()[..],
        &42i32.to_be_bytes()[..],
    ]
    .concat();
    assert!(PgDataRow::new(&data).is_err());
}
/// A 2-byte big-endian column decodes through `get_i16`.
#[test]
fn pg_data_row_get_i16() {
    let encoded = 7i16.to_be_bytes();
    let data = make_data_row(&[Some(&encoded)]);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.get_i16(0), Some(7));
}
/// A 4-byte big-endian IEEE-754 column decodes through `get_f32`.
#[test]
fn pg_data_row_get_f32() {
    let data = make_data_row(&[Some(&2.5f32.to_be_bytes())]);
    let row = PgDataRow::new(&data).unwrap();
    let decoded = row.get_f32(0).unwrap();
    assert!((decoded - 2.5).abs() < 1e-6);
}
/// `get_raw` on a NULL column yields no slice at all.
#[test]
fn pg_data_row_get_raw_null() {
    let data = make_data_row(&[None]);
    let row = PgDataRow::new(&data).unwrap();
    assert!(row.get_raw(0).is_none());
}
/// `get_raw` returns the column's exact payload bytes.
#[test]
fn pg_data_row_get_raw_data() {
    let bytes: &[u8] = &[1, 2, 3];
    let data = make_data_row(&[Some(bytes)]);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.get_raw(0), Some(bytes));
}
/// Sixteen one-byte columns all decode back to their payload.
#[test]
fn pg_data_row_stack_alloc_16_columns() {
    let zero: &[u8] = &[0u8];
    let cols = [Some(zero); 16];
    let data = make_data_row(&cols);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.column_count(), 16);
    for idx in 0..16 {
        assert_eq!(row.get_raw(idx), Some(zero));
    }
}
/// Walks a five-column row by hand (4-byte length prefix, then payload) and
/// checks the columns are packed back-to-back with nothing left over.
///
/// The `f64` column uses 2.75 rather than 3.14. Clippy's deny-by-default
/// `approx_constant` lint rejects `PI` look-alikes.
#[test]
fn inline_sequential_decode_five_columns() {
    let data = make_data_row(&[
        Some(&42i32.to_be_bytes()),
        Some(b"alice"),
        Some(b"alice@example.com"),
        Some(&[1u8]),
        Some(&2.75f64.to_be_bytes()),
    ]);
    // The four bytes starting at offset `at`.
    let quad = |at: usize| [data[at], data[at + 1], data[at + 2], data[at + 3]];
    // The eight bytes starting at offset `at`.
    let octet = |at: usize| {
        let mut raw = [0u8; 8];
        raw.copy_from_slice(&data[at..at + 8]);
        raw
    };
    // Skip the 2-byte column count.
    let mut pos: usize = 2;

    let len = i32::from_be_bytes(quad(pos));
    pos += 4;
    assert_eq!(len, 4);
    let id = i32::from_be_bytes(quad(pos));
    pos += len as usize;
    assert_eq!(id, 42);

    let len = i32::from_be_bytes(quad(pos));
    pos += 4;
    assert_eq!(len, 5);
    let name = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
    pos += len as usize;
    assert_eq!(name, "alice");

    let len = i32::from_be_bytes(quad(pos));
    pos += 4;
    let email = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
    pos += len as usize;
    assert_eq!(email, "alice@example.com");

    let len = i32::from_be_bytes(quad(pos));
    pos += 4;
    assert_eq!(len, 1);
    let active = data[pos] != 0;
    pos += len as usize;
    assert!(active);

    let len = i32::from_be_bytes(quad(pos));
    pos += 4;
    assert_eq!(len, 8);
    let score = f64::from_be_bytes(octet(pos));
    pos += len as usize;
    assert!((score - 2.75).abs() < 1e-10);

    // Every byte was consumed: no padding or trailer after the last column.
    assert_eq!(pos, data.len());
}
/// Hand-decodes a row with a NULL middle column. A NULL (length -1) consumes
/// only its 4-byte length prefix and no payload bytes.
#[test]
fn inline_sequential_decode_with_nulls() {
    let data = make_data_row(&[Some(&42i32.to_be_bytes()), None, Some(b"text")]);
    // Big-endian i32 starting at byte offset `at`.
    let word = |at: usize| {
        i32::from_be_bytes([data[at], data[at + 1], data[at + 2], data[at + 3]])
    };
    // Skip the 2-byte column count.
    let mut cursor: usize = 2;

    let id_len = word(cursor);
    cursor += 4;
    let id = word(cursor);
    cursor += id_len as usize;
    assert_eq!(id, 42);

    let name_len = word(cursor);
    cursor += 4;
    let name: Option<&str> = if name_len >= 0 {
        let text = std::str::from_utf8(&data[cursor..cursor + name_len as usize]).unwrap();
        cursor += name_len as usize;
        Some(text)
    } else {
        None
    };
    assert!(name.is_none());

    let txt_len = word(cursor);
    cursor += 4;
    let txt = std::str::from_utf8(&data[cursor..cursor + txt_len as usize]).unwrap();
    cursor += txt_len as usize;
    assert_eq!(txt, "text");
    assert_eq!(cursor, data.len());
}
/// Hand-decodes one column of each fixed-width scalar type
/// (bool, i16, i32, i64, f32, f64) from a single row.
///
/// The `f64` column uses 2.75 rather than 3.14. Clippy's deny-by-default
/// `approx_constant` lint rejects `PI` look-alikes.
#[test]
fn inline_sequential_decode_all_scalar_types() {
    let data = make_data_row(&[
        Some(&[1u8]),
        Some(&7i16.to_be_bytes()),
        Some(&42i32.to_be_bytes()),
        Some(&12345i64.to_be_bytes()),
        Some(&2.5f32.to_be_bytes()),
        Some(&2.75f64.to_be_bytes()),
    ]);
    // The four bytes starting at offset `at`.
    let quad = |at: usize| [data[at], data[at + 1], data[at + 2], data[at + 3]];
    // The eight bytes starting at offset `at`.
    let octet = |at: usize| {
        let mut raw = [0u8; 8];
        raw.copy_from_slice(&data[at..at + 8]);
        raw
    };
    // Skip the 2-byte column count.
    let mut pos: usize = 2;

    let len = i32::from_be_bytes(quad(pos));
    pos += 4;
    let v_bool = data[pos] != 0;
    pos += len as usize;
    assert!(v_bool);

    let len = i32::from_be_bytes(quad(pos));
    pos += 4;
    let v_i16 = i16::from_be_bytes([data[pos], data[pos + 1]]);
    pos += len as usize;
    assert_eq!(v_i16, 7);

    let len = i32::from_be_bytes(quad(pos));
    pos += 4;
    let v_i32 = i32::from_be_bytes(quad(pos));
    pos += len as usize;
    assert_eq!(v_i32, 42);

    let len = i32::from_be_bytes(quad(pos));
    pos += 4;
    let v_i64 = i64::from_be_bytes(octet(pos));
    pos += len as usize;
    assert_eq!(v_i64, 12345);

    let len = i32::from_be_bytes(quad(pos));
    pos += 4;
    let v_f32 = f32::from_be_bytes(quad(pos));
    pos += len as usize;
    assert!((v_f32 - 2.5).abs() < 1e-6);

    let len = i32::from_be_bytes(quad(pos));
    pos += 4;
    let v_f64 = f64::from_be_bytes(octet(pos));
    pos += len as usize;
    assert!((v_f64 - 2.75).abs() < 1e-10);

    assert_eq!(pos, data.len());
}
/// `PgDataRow::new` is callable directly on a raw row buffer.
#[test]
fn pg_data_row_new_is_public() {
    let raw = make_data_row(&[Some(&42i32.to_be_bytes())]);
    let parsed = PgDataRow::new(&raw).unwrap();
    assert_eq!(parsed.get_i32(0), Some(42));
}
/// Decoding by hand and decoding through `PgDataRow` agree column by column.
#[test]
fn inline_decode_matches_pgdatarow() {
    let data = make_data_row(&[
        Some(&99i32.to_be_bytes()),
        Some(b"hello world"),
        None,
        Some(&[0u8]),
        Some(&1.23f64.to_be_bytes()),
    ]);

    // Reference values from the accessor API.
    let row = PgDataRow::new(&data).unwrap();
    let dr_i32 = row.get_i32(0);
    let dr_str = row.get_str(1);
    let dr_null = row.get_str(2);
    let dr_bool = row.get_bool(3);
    let dr_f64 = row.get_f64(4);

    // Big-endian i32 starting at byte offset `at`.
    let word = |at: usize| {
        i32::from_be_bytes([data[at], data[at + 1], data[at + 2], data[at + 3]])
    };
    // Skip the 2-byte column count.
    let mut cursor: usize = 2;

    let width = word(cursor);
    cursor += 4;
    let in_i32 = word(cursor);
    cursor += width as usize;

    let width = word(cursor);
    cursor += 4;
    let in_str = std::str::from_utf8(&data[cursor..cursor + width as usize]).unwrap();
    cursor += width as usize;

    let width = word(cursor);
    cursor += 4;
    let in_null: Option<&str> = match width {
        w if w < 0 => None,
        _ => unreachable!(),
    };

    let width = word(cursor);
    cursor += 4;
    let in_bool = data[cursor] != 0;
    cursor += width as usize;

    let width = word(cursor);
    cursor += 4;
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&data[cursor..cursor + 8]);
    let in_f64 = f64::from_be_bytes(raw);
    cursor += width as usize;

    assert_eq!(dr_i32, Some(in_i32));
    assert_eq!(dr_str, Some(in_str));
    assert_eq!(dr_null, in_null);
    assert_eq!(dr_bool, Some(in_bool));
    assert!((dr_f64.unwrap() - in_f64).abs() < 1e-15);
    assert_eq!(cursor, data.len());
}
/// An absolute directory as `host` selects a Unix-domain socket path named
/// `.s.PGSQL.<port>` inside that directory.
#[test]
fn config_host_is_uds_absolute_path() {
    let cfg = Config {
        ssl: SslMode::Disable,
        statement_timeout_secs: 30,
        user: "user".into(),
        password: "".into(),
        database: "db".into(),
        port: 5432,
        host: "/tmp".into(),
    };
    assert!(cfg.host_is_uds());
    assert_eq!(cfg.uds_path(), "/tmp/.s.PGSQL.5432");
}
/// A nested socket directory with a non-default port builds the matching socket path.
#[test]
fn config_host_is_uds_var_run() {
    let cfg = Config {
        ssl: SslMode::Disable,
        statement_timeout_secs: 30,
        user: "user".into(),
        password: "".into(),
        database: "db".into(),
        port: 5433,
        host: "/var/run/postgresql".into(),
    };
    assert!(cfg.host_is_uds());
    assert_eq!(cfg.uds_path(), "/var/run/postgresql/.s.PGSQL.5433");
}
/// A plain hostname is not treated as a socket directory.
#[test]
fn config_host_is_not_uds_for_hostname() {
    let cfg = Config {
        ssl: SslMode::Disable,
        statement_timeout_secs: 30,
        user: "user".into(),
        password: "".into(),
        database: "db".into(),
        port: 5432,
        host: "localhost".into(),
    };
    let is_socket = cfg.host_is_uds();
    assert!(!is_socket);
}
/// A dotted IPv4 address is not treated as a socket directory.
#[test]
fn config_host_is_not_uds_for_ip() {
    let cfg = Config {
        ssl: SslMode::Disable,
        statement_timeout_secs: 30,
        user: "user".into(),
        password: "".into(),
        database: "db".into(),
        port: 5432,
        host: "127.0.0.1".into(),
    };
    let is_socket = cfg.host_is_uds();
    assert!(!is_socket);
}
/// A `host=` query parameter overrides the URL hostname and yields a socket path
/// on the default port. User and database still come from the URL.
#[test]
fn config_parse_uds_host_query_param() {
    let cfg = Config::from_url("postgres://user@localhost/mydb?host=/tmp").unwrap();
    assert_eq!(cfg.user, "user");
    assert_eq!(cfg.database, "mydb");
    assert_eq!(cfg.host, "/tmp");
    assert!(cfg.host_is_uds());
    assert_eq!(cfg.uds_path(), "/tmp/.s.PGSQL.5432");
}
/// The URL's explicit port is kept and used in the socket file name.
#[test]
fn config_parse_uds_host_query_param_custom_port() {
    let url = "postgres://user@localhost:5433/mydb?host=/var/run/postgresql";
    let cfg = Config::from_url(url).unwrap();
    assert_eq!(cfg.host, "/var/run/postgresql");
    assert_eq!(cfg.port, 5433);
    assert_eq!(cfg.uds_path(), "/var/run/postgresql/.s.PGSQL.5433");
}
/// `host=` coexists with other query parameters, all of which are applied.
#[test]
fn config_parse_uds_host_with_other_params() {
    const URL: &str =
        "postgres://user@localhost/db?host=/tmp&sslmode=disable&statement_timeout=60";
    let cfg = Config::from_url(URL).unwrap();
    assert_eq!(cfg.host, "/tmp");
    assert!(cfg.host_is_uds());
    assert_eq!(cfg.ssl, SslMode::Disable);
    assert_eq!(cfg.statement_timeout_secs, 60);
}
/// A percent-encoded slash (`%2F`) in `host=` is decoded before use.
#[test]
fn config_parse_uds_host_percent_encoded() {
    let parsed = Config::from_url("postgres://user@localhost/db?host=%2Ftmp").unwrap();
    assert_eq!(parsed.host, "/tmp");
    assert!(parsed.host_is_uds());
}
/// Without a `host=` parameter, the URL hostname is used as-is for TCP.
#[test]
fn config_parse_tcp_host_not_overridden_without_param() {
    let parsed = Config::from_url("postgres://user@myserver/db").unwrap();
    assert_eq!(parsed.host, "myserver");
    assert!(!parsed.host_is_uds());
}
/// `host=` takes precedence even over a fully qualified URL hostname.
#[test]
fn config_parse_uds_host_overrides_url_hostname() {
    let url = "postgres://user@db.example.com/mydb?host=/var/run/postgresql";
    let parsed = Config::from_url(url).unwrap();
    assert_eq!(parsed.host, "/var/run/postgresql");
    assert!(parsed.host_is_uds());
}
/// An empty URL host is accepted when `host=` supplies the socket directory.
#[test]
fn config_parse_uds_empty_url_host() {
    let parsed = Config::from_url("postgres://user@/mydb?host=/tmp").unwrap();
    assert_eq!(parsed.database, "mydb");
    assert_eq!(parsed.host, "/tmp");
    assert!(parsed.host_is_uds());
}
/// In a row of five NULLs, every accessor reports the absence of a value.
#[test]
fn pg_data_row_all_null_columns() {
    let data = make_data_row(&[None; 5]);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.column_count(), 5);
    for i in 0..5 {
        assert!(row.is_null(i), "column {i} should be null");
        assert!(row.get_raw(i).is_none());
        assert!(row.get_i32(i).is_none());
        assert!(row.get_i64(i).is_none());
        assert!(row.get_str(i).is_none());
        assert!(row.get_bool(i).is_none());
        assert!(row.get_f64(i).is_none());
    }
}
/// A 2 KiB text column round-trips intact.
#[test]
fn pg_data_row_very_long_text() {
    let long_text: String = std::iter::repeat('x').take(2048).collect();
    let data = make_data_row(&[Some(long_text.as_bytes())]);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.get_str(0), Some(&long_text[..]));
}
/// A zero-length column is an empty value, distinct from NULL.
#[test]
fn pg_data_row_empty_text() {
    let data = make_data_row(&[Some(b"")]);
    let row = PgDataRow::new(&data).unwrap();
    let empty: &[u8] = &[];
    assert!(!row.is_null(0));
    assert_eq!(row.get_str(0), Some(""));
    assert_eq!(row.get_bytes(0), Some(empty));
}
/// A twenty-column row decodes every column at its own index.
#[test]
fn pg_data_row_20_columns_exceeds_inline() {
    let encoded: Vec<[u8; 4]> = (0i32..20).map(i32::to_be_bytes).collect();
    let cols: Vec<Option<&[u8]>> = encoded.iter().map(|bytes| Some(&bytes[..])).collect();
    let data = make_data_row(&cols);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.column_count(), 20);
    for (idx, expected) in (0..20).zip(0i32..) {
        assert_eq!(row.get_i32(idx), Some(expected));
    }
}
/// `is_null` is accurate at the first, middle, and last column positions.
#[test]
fn pg_data_row_is_null_each_position() {
    let data = make_data_row(&[Some(&1i32.to_be_bytes()), None, Some(&3i32.to_be_bytes())]);
    let row = PgDataRow::new(&data).unwrap();
    for (idx, expect_null) in (0..).zip([false, true, false]) {
        assert_eq!(row.is_null(idx), expect_null);
    }
}
/// A header declaring a negative column count is rejected.
#[test]
fn pg_data_row_negative_column_count() {
    let header = i16::to_be_bytes(-1);
    assert!(PgDataRow::new(&header).is_err());
}
/// Invalid UTF-8 makes `get_str` return `None`, while `get_bytes` still exposes the raw bytes.
#[test]
fn pg_data_row_get_str_invalid_utf8() {
    let bytes: &[u8] = &[0xFF, 0xFE, 0x80];
    let data = make_data_row(&[Some(bytes)]);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.get_str(0), None);
    assert_eq!(row.get_bytes(0), Some(bytes));
}
/// A 2-byte column is not accepted as an `i32`, though it still decodes as an `i16`.
#[test]
fn pg_data_row_get_i32_wrong_length() {
    let encoded = make_data_row(&[Some(&7i16.to_be_bytes())]);
    let row = PgDataRow::new(&encoded).unwrap();
    assert_eq!(row.get_i32(0), None);
    assert_eq!(row.get_i16(0), Some(7));
}
/// A 4-byte column is not accepted as an `i64`.
#[test]
fn pg_data_row_get_i64_wrong_length() {
    let encoded = make_data_row(&[Some(&42i32.to_be_bytes())]);
    let row = PgDataRow::new(&encoded).unwrap();
    assert!(row.get_i64(0).is_none());
}
/// A 4-byte column (an encoded `f32`) is not accepted as an `f64`.
#[test]
fn pg_data_row_get_f64_wrong_length() {
    let encoded = make_data_row(&[Some(&2.5f32.to_be_bytes())]);
    let row = PgDataRow::new(&encoded).unwrap();
    assert!(row.get_f64(0).is_none());
}
/// An 8-byte column (an encoded `f64`) is not accepted as an `f32`.
#[test]
fn pg_data_row_get_f32_wrong_length() {
    // 2.75 rather than 3.14: clippy's deny-by-default `approx_constant` lint
    // rejects literals that look like `PI`.
    let data = make_data_row(&[Some(&2.75f64.to_be_bytes())]);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.get_f32(0), None);
}
/// A 4-byte column is not accepted as a one-byte `bool`.
#[test]
fn pg_data_row_get_bool_wrong_length() {
    let encoded = make_data_row(&[Some(&42i32.to_be_bytes())]);
    let row = PgDataRow::new(&encoded).unwrap();
    assert!(row.get_bool(0).is_none());
}
/// Multi-byte UTF-8 text of several widths round-trips through `get_str`.
#[test]
fn pg_data_row_unicode_text() {
    let samples = [
        "\u{1F600}\u{1F4A9}\u{1F680}", // emoji outside the BMP (4-byte UTF-8)
        "\u{4e16}\u{754c}",            // CJK (3-byte UTF-8)
        "\u{0645}\u{0631}\u{062D}",    // Arabic (2-byte UTF-8)
        "\u{1F468}\u{200D}\u{1F469}",  // emoji joined by a zero-width joiner
    ];
    for &sample in samples.iter() {
        let data = make_data_row(&[Some(sample.as_bytes())]);
        let row = PgDataRow::new(&data).unwrap();
        assert_eq!(row.get_str(0), Some(sample));
    }
}
/// `i32` extremes, zero, and ±1 decode exactly.
#[test]
fn pg_data_row_i32_boundary_values() {
    for val in [i32::MIN, -1, 0, 1, i32::MAX].iter().copied() {
        let encoded = make_data_row(&[Some(&val.to_be_bytes())]);
        let row = PgDataRow::new(&encoded).unwrap();
        assert_eq!(row.get_i32(0), Some(val), "failed for {val}");
    }
}
/// `i64` extremes, zero, and ±1 decode exactly.
#[test]
fn pg_data_row_i64_boundary_values() {
    for val in [i64::MIN, -1, 0, 1, i64::MAX].iter().copied() {
        let encoded = make_data_row(&[Some(&val.to_be_bytes())]);
        let row = PgDataRow::new(&encoded).unwrap();
        assert_eq!(row.get_i64(0), Some(val), "failed for {val}");
    }
}
/// Infinities decode exactly and NaN stays NaN for `f64` columns.
#[test]
fn pg_data_row_f64_special_values() {
    for expected in [f64::INFINITY, f64::NEG_INFINITY] {
        let data = make_data_row(&[Some(&expected.to_be_bytes())]);
        let row = PgDataRow::new(&data).unwrap();
        assert_eq!(row.get_f64(0), Some(expected));
    }
    // NaN never compares equal, so check it with `is_nan` instead.
    let data = make_data_row(&[Some(&f64::NAN.to_be_bytes())]);
    let row = PgDataRow::new(&data).unwrap();
    assert!(row.get_f64(0).unwrap().is_nan());
}
/// Infinity decodes exactly and NaN stays NaN for `f32` columns.
#[test]
fn pg_data_row_f32_special_values() {
    let inf_bytes = make_data_row(&[Some(&f32::INFINITY.to_be_bytes())]);
    let inf_row = PgDataRow::new(&inf_bytes).unwrap();
    assert_eq!(inf_row.get_f32(0), Some(f32::INFINITY));

    // NaN never compares equal, so check it with `is_nan` instead.
    let nan_bytes = make_data_row(&[Some(&f32::NAN.to_be_bytes())]);
    let nan_row = PgDataRow::new(&nan_bytes).unwrap();
    assert!(nan_row.get_f32(0).unwrap().is_nan());
}
/// `i16` extremes, zero, and ±1 decode exactly.
#[test]
fn pg_data_row_i16_boundary_values() {
    for val in [i16::MIN, -1, 0, 1, i16::MAX].iter().copied() {
        let encoded = make_data_row(&[Some(&val.to_be_bytes())]);
        let row = PgDataRow::new(&encoded).unwrap();
        assert_eq!(row.get_i16(0), Some(val));
    }
}
/// Four NULL columns (length -1) parse into four entries, each with length -1.
#[test]
fn data_row_flat_all_null() {
    let mut data = Vec::from(4i16.to_be_bytes());
    for _ in 0..4 {
        data.extend_from_slice(&(-1i32).to_be_bytes());
    }
    let mut out = Vec::new();
    let mut arena = Arena::new();
    parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
    assert_eq!(out.len(), 4);
    for &(_, len) in out.iter() {
        assert_eq!(len, -1);
    }
}
/// A 1 KiB column is copied into the arena and is readable at the recorded offset.
#[test]
fn data_row_flat_long_text() {
    let long = [b'A'; 1024];
    let data = [
        &1i16.to_be_bytes()[..],
        &(long.len() as i32).to_be_bytes()[..],
        &long[..],
    ]
    .concat();
    let mut arena = Arena::new();
    let mut out = Vec::new();
    parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
    assert_eq!(out[0].1, 1024);
    let stored = arena.get(out[0].0, 1024);
    assert!(stored.iter().all(|byte| *byte == b'A'));
}
/// A zero-length column is recorded with length 0, not as NULL.
#[test]
fn data_row_flat_empty_text() {
    let data = [&1i16.to_be_bytes()[..], &0i32.to_be_bytes()[..]].concat();
    let (mut arena, mut out) = (Arena::new(), Vec::new());
    parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
    assert_eq!(out[0].1, 0);
}
/// Two column slots with `num_cols == 2` form one row; the metadata passes through unchanged.
#[test]
fn query_result_from_parts() {
    let offsets = vec![(0, 4), (0, -1)];
    let result = QueryResult::from_parts(offsets, 2, Arc::from(Vec::new()), 5);
    assert_eq!(result.len(), 1);
    assert_eq!(result.num_cols, 2);
    assert_eq!(result.affected_rows, 5);
}
/// A result with no rows still carries its affected-row count.
#[test]
fn query_result_affected_rows() {
    let result = QueryResult {
        affected_rows: 42,
        num_cols: 0,
        columns: Arc::from(Vec::new()),
        all_col_offsets: Vec::new(),
    };
    assert!(result.is_empty());
    assert_eq!(result.affected_rows, 42);
}
/// A server error's `Display` output includes the hint and the error position.
#[test]
fn driver_error_server_with_hint() {
    let rendered = DriverError::Server {
        code: "42601".into(),
        message: "syntax error".into(),
        detail: None,
        hint: Some("check your SQL".into()),
        position: Some(10),
    }
    .to_string();
    assert!(rendered.contains("HINT: check your SQL"));
    assert!(rendered.contains("(at position 10)"));
}
/// With every optional field set, each one appears in the `Display` output.
#[test]
fn driver_error_server_with_all_fields() {
    let rendered = DriverError::Server {
        code: "23505".into(),
        message: "unique violation".into(),
        detail: Some("Key (id)=(1) already exists.".into()),
        hint: Some("change the id".into()),
        position: Some(1),
    }
    .to_string();
    for fragment in [
        "23505",
        "unique violation",
        "Key (id)=(1) already exists.",
        "change the id",
        "(at position 1)",
    ] {
        assert!(rendered.contains(fragment));
    }
}
/// With no `statement_timeout` parameter, the timeout defaults to 30 seconds.
#[test]
fn config_statement_timeout_default() {
    let secs = Config::from_url("postgres://user:pass@localhost/db")
        .unwrap()
        .statement_timeout_secs;
    assert_eq!(secs, 30);
}
/// An explicit `statement_timeout` overrides the default.
#[test]
fn config_statement_timeout_custom() {
    let url = "postgres://user:pass@localhost/db?statement_timeout=120";
    let secs = Config::from_url(url).unwrap().statement_timeout_secs;
    assert_eq!(secs, 120);
}
/// `statement_timeout=0` is preserved as zero rather than replaced by the default.
#[test]
fn config_statement_timeout_zero() {
    let url = "postgres://user:pass@localhost/db?statement_timeout=0";
    let secs = Config::from_url(url).unwrap().statement_timeout_secs;
    assert_eq!(secs, 0);
}
/// A non-numeric `statement_timeout` is ignored and the 30-second default applies.
#[test]
fn config_statement_timeout_invalid_falls_back() {
    let url = "postgres://user:pass@localhost/db?statement_timeout=notanumber";
    let secs = Config::from_url(url).unwrap().statement_timeout_secs;
    assert_eq!(secs, 30);
}
/// The socket path is `<dir>/.s.PGSQL.<port>` with the default port.
#[test]
fn config_uds_path_format() {
    let parsed = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
    let path = parsed.uds_path();
    assert_eq!(path, "/tmp/.s.PGSQL.5432");
}
/// A custom port from the URL appears in the socket file name.
#[test]
fn config_uds_path_custom_port() {
    let parsed = Config::from_url("postgres://user@localhost:5433/db?host=/tmp").unwrap();
    let path = parsed.uds_path();
    assert_eq!(path, "/tmp/.s.PGSQL.5433");
}
/// Decoding an empty input yields an empty output.
#[test]
fn url_decode_empty_string() {
    let decoded = url_decode("").unwrap();
    assert_eq!(decoded, "");
}
/// Input without `%` escapes passes through unchanged.
#[test]
fn url_decode_no_encoding() {
    let decoded = url_decode("hello").unwrap();
    assert_eq!(decoded, "hello");
}
/// Hex escapes are accepted in either upper or lower case.
#[test]
fn url_decode_all_ascii_hex() {
    for encoded in ["%2F", "%2f"] {
        assert_eq!(url_decode(encoded).unwrap(), "/");
    }
}
/// Hashing the empty statement must not panic and must be deterministic.
///
/// The previous body only bound the hash to `_h` and asserted nothing.
#[test]
fn hash_sql_empty() {
    let first = hash_sql("");
    let second = hash_sql("");
    assert_eq!(first, second);
}
/// Whitespace is significant: a blank statement hashes differently from an empty one.
#[test]
fn hash_sql_whitespace_only() {
    let blank = hash_sql(" ");
    let empty = hash_sql("");
    assert_ne!(blank, empty);
}
/// A 10k-character statement hashes deterministically.
#[test]
fn hash_sql_very_long() {
    let long_sql = format!("SELECT {}", "x".repeat(10_000));
    let digest = hash_sql(&long_sql);
    assert_eq!(digest, hash_sql(&long_sql));
}
/// A non-ASCII literal changes the hash compared with an ASCII one.
#[test]
fn hash_sql_unicode() {
    let emoji = hash_sql("SELECT '\u{1F600}'");
    let ascii = hash_sql("SELECT 'x'");
    assert_ne!(emoji, ascii);
}
}