use std::sync::Arc;
use rapidhash::quality::RapidHasher;
use crate::arena::Arena;
use crate::DriverError;
/// Connection settings for a single server connection.
///
/// Usually produced by [`Config::from_url`]. The `Debug` impl prints
/// `[REDACTED]` in place of `password`, and the `Drop` impl zeroizes it.
#[derive(Clone)]
pub struct Config {
/// Server host. A value starting with `/` is treated as a socket directory
/// by `host_is_uds` / `uds_path`.
pub host: String,
/// Port number; `from_url` uses 5432 when the URL does not specify one.
pub port: u16,
/// User name; `validate` rejects an empty value.
pub user: String,
/// Password (may be empty). Zeroized when the `Config` is dropped.
pub password: String,
/// Database name; `from_url` falls back to the user name when the URL has none.
pub database: String,
/// SSL mode (`sslmode` URL parameter; `from_url` defaults to `SslMode::Prefer`).
pub ssl: SslMode,
/// Statement timeout (`statement_timeout` URL parameter; `from_url` defaults to 30).
/// Presumably expressed in seconds, going by the field name.
pub statement_timeout_secs: u32,
/// Statement cache mode (`statement_cache` URL parameter; defaults to `Named`).
pub statement_cache_mode: StatementCacheMode,
/// Percent-decoded value of the `sslrootcert` URL parameter, if given
/// (presumably a file path; how it is consumed is not shown here).
pub ssl_root_cert: Option<String>,
/// Percent-decoded value of the `sslcert` URL parameter, if given.
pub ssl_cert: Option<String>,
/// Percent-decoded value of the `sslkey` URL parameter, if given.
pub ssl_key: Option<String>,
}
impl Drop for Config {
fn drop(&mut self) {
use zeroize::Zeroize;
self.password.zeroize();
}
}
impl std::fmt::Debug for Config {
    /// Formats the configuration like a derived `Debug` would, except that the
    /// `password` field is always rendered as the placeholder `[REDACTED]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructuring (rather than `self.field` access) makes the compiler
        // flag this impl whenever a field is added to `Config`.
        let Self {
            host,
            port,
            user,
            password: _,
            database,
            ssl,
            statement_timeout_secs,
            statement_cache_mode,
            ssl_root_cert,
            ssl_cert,
            ssl_key,
        } = self;

        let mut out = f.debug_struct("Config");
        out.field("host", host);
        out.field("port", port);
        out.field("user", user);
        // Never print the real secret.
        out.field("password", &"[REDACTED]");
        out.field("database", database);
        out.field("ssl", ssl);
        out.field("statement_timeout_secs", statement_timeout_secs);
        out.field("statement_cache_mode", statement_cache_mode);
        out.field("ssl_root_cert", ssl_root_cert);
        out.field("ssl_cert", ssl_cert);
        out.field("ssl_key", ssl_key);
        out.finish()
    }
}
/// SSL mode selected through the `sslmode` URL parameter.
///
/// `Config::from_url` accepts exactly the lowercase values `disable`,
/// `prefer` and `require`, and uses `Prefer` when the parameter is absent.
/// How each mode affects the connection is decided by code outside this view.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SslMode {
/// Selected by `sslmode=disable`.
Disable,
/// Selected by `sslmode=prefer`; also the `from_url` default.
Prefer,
/// Selected by `sslmode=require`.
Require,
}
/// Statement cache mode selected through the `statement_cache` URL parameter.
///
/// `Config::from_url` accepts exactly the lowercase values `named` and
/// `disabled`. The caching behaviour itself is implemented elsewhere.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum StatementCacheMode {
/// Selected by `statement_cache=named`; this is the `Default` variant.
#[default]
Named,
/// Selected by `statement_cache=disabled`.
Disabled,
}
impl Config {
    /// Builds a [`Config`] from a `postgres://` or `postgresql://` connection URL.
    ///
    /// Accepted shape:
    /// `scheme://user[:password]@host[:port][/database][?key=value&...]`
    ///
    /// * User, password, host and database are percent-decoded.
    /// * The port defaults to 5432 and the database defaults to the user name.
    /// * Recognised query parameters: `sslmode` (`disable` | `prefer` |
    ///   `require`), `statement_timeout`, `statement_cache` (`named` |
    ///   `disabled`), `host` (overrides the host from the authority part,
    ///   e.g. a socket directory), `sslrootcert`, `sslcert`, `sslkey`.
    ///   Unknown parameters and parameters without `=` are ignored.
    /// * An unparsable `statement_timeout` is deliberately lenient and falls
    ///   back to 30 instead of failing.
    ///
    /// The query string is separated from the rest *before* the database path
    /// is split off, so URLs without a database (`postgres://u@h?sslmode=require`)
    /// and query values containing `/` (`?host=/tmp`) are parsed correctly.
    ///
    /// # Errors
    ///
    /// Returns `DriverError::Protocol` when the scheme is wrong, the `@`
    /// separator is missing, the port is not a valid `u16`, a component has
    /// malformed percent-encoding or is not valid UTF-8, `sslmode` /
    /// `statement_cache` carry an unknown value, or [`Config::validate`] fails.
    /// Errors about the password never include the password text.
    pub fn from_url(url: &str) -> Result<Self, DriverError> {
        let url = url
            .strip_prefix("postgres://")
            .or_else(|| url.strip_prefix("postgresql://"))
            .ok_or_else(|| DriverError::Protocol("URL must start with postgres://".into()))?;
        let (userinfo, rest) = url
            .split_once('@')
            .ok_or_else(|| DriverError::Protocol("missing @ in connection URL".into()))?;
        let (user, password) = userinfo.split_once(':').unwrap_or((userinfo, ""));

        // Peel off the query first: its values may contain '/', and the
        // database path itself is optional.
        let (location, params) = rest.split_once('?').unwrap_or((rest, ""));
        let (hostport, database) = location.split_once('/').unwrap_or((location, ""));

        let (host, port) = match hostport.split_once(':') {
            Some((h, p)) => {
                let port = p
                    .parse::<u16>()
                    .map_err(|_| DriverError::Protocol(format!("invalid port: {p}")))?;
                (h, port)
            }
            None => (hostport, 5432),
        };

        let mut ssl = SslMode::Prefer;
        let mut statement_timeout_secs: u32 = 30;
        let mut statement_cache_mode = StatementCacheMode::Named;
        let mut host_override: Option<String> = None;
        let mut ssl_root_cert: Option<String> = None;
        let mut ssl_cert: Option<String> = None;
        let mut ssl_key: Option<String> = None;

        // Segments without '=' (including empty ones from "&&") are skipped.
        for (key, val) in params.split('&').filter_map(|param| param.split_once('=')) {
            match key {
                "sslmode" => {
                    ssl = match val {
                        "disable" => SslMode::Disable,
                        "prefer" => SslMode::Prefer,
                        "require" => SslMode::Require,
                        _ => {
                            return Err(DriverError::Protocol(format!(
                                "unknown sslmode: '{val}' (expected: disable, prefer, require)"
                            )));
                        }
                    };
                }
                "statement_timeout" => {
                    // Best-effort by design: a bad value resets to the default.
                    statement_timeout_secs = val.parse::<u32>().unwrap_or(30);
                }
                "statement_cache" => {
                    statement_cache_mode = match val {
                        "named" => StatementCacheMode::Named,
                        "disabled" => StatementCacheMode::Disabled,
                        _ => {
                            return Err(DriverError::Protocol(format!(
                                "unknown statement_cache mode: '{val}' (expected: named, disabled)"
                            )));
                        }
                    };
                }
                "host" => host_override = Some(url_decode(val)?),
                "sslrootcert" => ssl_root_cert = Some(url_decode(val)?),
                "sslcert" => ssl_cert = Some(url_decode(val)?),
                "sslkey" => ssl_key = Some(url_decode(val)?),
                // Unknown parameters are tolerated.
                _ => {}
            }
        }

        // A `host=` parameter wins over the authority host, which is then not
        // decoded at all (it may legitimately be empty).
        let host = match host_override {
            Some(h) => h,
            None => url_decode(host)?,
        };
        let user = url_decode(user)?;
        // `url_decode` echoes its input in error messages; replace the error so
        // the secret never ends up in logs.
        let password = url_decode(password).map_err(|_| {
            DriverError::Protocol(
                "malformed percent-encoding or invalid UTF-8 in password".into(),
            )
        })?;
        let database = if database.is_empty() {
            user.clone()
        } else {
            url_decode(database)?
        };

        let config = Config {
            host,
            port,
            user,
            password,
            database,
            ssl,
            statement_timeout_secs,
            statement_cache_mode,
            ssl_root_cert,
            ssl_cert,
            ssl_key,
        };
        config.validate()?;
        Ok(config)
    }

    /// Checks that `host`, `user` and `database` are all non-empty.
    ///
    /// # Errors
    ///
    /// Returns `DriverError::Protocol` naming the first empty field, checked
    /// in the order host, user, database.
    pub fn validate(&self) -> Result<(), DriverError> {
        if self.host.is_empty() {
            return Err(DriverError::Protocol("host cannot be empty".into()));
        }
        if self.user.is_empty() {
            return Err(DriverError::Protocol("user cannot be empty".into()));
        }
        if self.database.is_empty() {
            return Err(DriverError::Protocol("database cannot be empty".into()));
        }
        Ok(())
    }

    /// Returns `true` when `host` starts with `/`, i.e. names a filesystem
    /// path rather than a host name or IP address.
    pub fn host_is_uds(&self) -> bool {
        self.host.starts_with('/')
    }

    /// Returns the socket file path `<host>/.s.PGSQL.<port>`.
    ///
    /// The path is built unconditionally; call [`Config::host_is_uds`] first to
    /// find out whether it is meaningful.
    pub fn uds_path(&self) -> String {
        format!("{}/.s.PGSQL.{}", self.host, self.port)
    }
}
/// Percent-decodes `s` into a `String`.
///
/// Every `%XY` escape (hex digits in either case) becomes the byte `0xXY`;
/// all other bytes, including `+`, are copied through unchanged.
///
/// # Errors
///
/// Returns `DriverError::Protocol` when a `%` is followed by fewer than two
/// bytes, when either escape digit is not hexadecimal, or when the decoded
/// bytes are not valid UTF-8. The messages quote the input `s`.
fn url_decode(s: &str) -> Result<String, DriverError> {
    // Shared error for an escape digit that is not hexadecimal.
    let bad_digit = |digit: u8| {
        DriverError::Protocol(format!(
            "invalid hex digit '{}' in URL: '{s}'",
            digit as char
        ))
    };

    let mut decoded = Vec::with_capacity(s.len());
    let mut remaining = s.as_bytes();
    loop {
        remaining = match remaining {
            [] => break,
            // A complete escape: '%' plus two following bytes.
            [b'%', hi_digit, lo_digit, tail @ ..] => {
                let hi = hex_val(*hi_digit).ok_or_else(|| bad_digit(*hi_digit))?;
                let lo = hex_val(*lo_digit).ok_or_else(|| bad_digit(*lo_digit))?;
                decoded.push(hi * 16 + lo);
                tail
            }
            // '%' with fewer than two bytes left.
            [b'%', ..] => {
                return Err(DriverError::Protocol(format!(
                    "malformed percent-encoding in URL: '{s}'"
                )));
            }
            [literal, tail @ ..] => {
                decoded.push(*literal);
                tail
            }
        };
    }

    String::from_utf8(decoded)
        .map_err(|_| DriverError::Protocol(format!("invalid UTF-8 in URL: '{s}'")))
}
/// Returns the numeric value (0–15) of the ASCII hexadecimal digit `b`
/// (`0-9`, `a-f` or `A-F`), or `None` for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly the ASCII hex digits; the result is < 16,
    // so narrowing back to `u8` is lossless.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
/// Crate-internal event type with one variant per kind of message seen while
/// a connection is starting up.
///
/// NOTE(review): the producer and consumer of these values are outside this
/// view; the per-variant notes below are inferred from the variant names and
/// payload types and should be confirmed against the startup code.
pub(crate) enum StartupAction {
/// Presumably: authentication succeeded.
AuthOk,
/// Presumably: the server requests a clear-text password.
AuthCleartext,
/// Carries 4 bytes — presumably the salt of an MD5 password request.
AuthMd5([u8; 4]),
/// Carries raw bytes — presumably the body of a SASL authentication request.
AuthSasl(Vec<u8>),
/// Two strings — presumably a run-time parameter name and its value.
ParameterStatus(Box<str>, Box<str>),
/// Two `i32` values — presumably the backend process id and secret key.
BackendKeyData(i32, i32),
/// One byte — presumably the transaction status indicator.
ReadyForQuery(u8),
/// An error carrying a message string.
Error(String),
/// A notice; no payload is kept.
Notice,
}
/// Metadata describing one column of a result set.
///
/// `Row::column_name` and `Row::column_type_oid` expose `name` and `type_oid`.
/// NOTE(review): the remaining field notes are inferred from the field names;
/// confirm against the code that fills this struct.
#[derive(Debug, Clone)]
pub struct ColumnDesc {
/// Column name.
pub name: Box<str>,
/// OID of the column's data type.
pub type_oid: u32,
/// Presumably the OID of the originating table.
pub table_oid: u32,
/// Presumably the data type size as reported by the server.
pub type_size: i16,
/// Presumably the column's attribute number within its table.
pub column_id: i16,
}
/// Column descriptions and parameter type OIDs, grouped together.
///
/// NOTE(review): judging by the name this is what preparing a statement
/// returns; the producer is outside this view.
#[derive(Debug, Clone)]
pub struct PrepareResult {
/// Descriptions of the result columns.
pub columns: Vec<ColumnDesc>,
/// Type OIDs of the parameters (per the field name).
pub param_oids: Vec<u32>,
}
/// A row represented as one optional `String` per column.
/// NOTE(review): `None` presumably stands for SQL NULL — confirm with the producer.
pub type SimpleRow = Vec<Option<String>>;
/// A notification consisting of a process id, a channel name and a payload.
///
/// NOTE(review): presumably an asynchronous LISTEN/NOTIFY message from the
/// server; the code that creates it is outside this view.
#[derive(Debug, Clone)]
pub struct Notification {
/// Process id associated with the notification (presumably the sender's backend).
pub pid: i32,
/// Channel name.
pub channel: String,
/// Payload text.
pub payload: String,
}
/// Result of a query: a flat table of cell locations plus column metadata.
///
/// Cell bytes are not stored per cell. Each cell is an `(offset, len)` pair
/// that `Row::get_raw` resolves against `data_buf` when it is present, or
/// against the caller-supplied `Arena` otherwise.
pub struct QueryResult {
// `(offset, len)` for every cell, row after row; a negative `len` marks a
// value that `Row::get_raw` reports as `None`.
pub(crate) all_col_offsets: Vec<(usize, i32)>,
// Number of cells per row; `len()` and `row()` use it as the row stride.
pub(crate) num_cols: usize,
// Column metadata shared via `Arc`.
pub(crate) columns: Arc<[ColumnDesc]>,
// Value returned by `affected_rows()`.
pub(crate) affected_rows: u64,
// Optional owned byte buffer the offsets point into; `None` means the
// offsets refer to the `Arena` passed to `row()` / `rows()`.
pub(crate) data_buf: Option<Vec<u8>>,
}
impl QueryResult {
    /// Assembles a result whose cell offsets refer to an external `Arena`
    /// (no owned data buffer).
    pub fn from_parts(
        all_col_offsets: Vec<(usize, i32)>,
        num_cols: usize,
        columns: Arc<[ColumnDesc]>,
        affected_rows: u64,
    ) -> Self {
        Self {
            data_buf: None,
            all_col_offsets,
            num_cols,
            columns,
            affected_rows,
        }
    }

    /// Assembles a result that owns its cell bytes in `data_buf`.
    ///
    /// An empty `data_buf` is stored as `None`, in which case rows read from
    /// the `Arena` instead.
    pub fn from_parts_with_buf(
        all_col_offsets: Vec<(usize, i32)>,
        num_cols: usize,
        columns: Arc<[ColumnDesc]>,
        affected_rows: u64,
        data_buf: Vec<u8>,
    ) -> Self {
        let has_bytes = !data_buf.is_empty();
        Self {
            data_buf: has_bytes.then_some(data_buf),
            all_col_offsets,
            num_cols,
            columns,
            affected_rows,
        }
    }

    /// Number of rows: stored cells divided by `num_cols`, or 0 when there
    /// are no columns.
    pub fn len(&self) -> usize {
        // `checked_div` yields `None` only for a zero divisor.
        self.all_col_offsets
            .len()
            .checked_div(self.num_cols)
            .unwrap_or(0)
    }

    /// Returns `true` when no cells are stored.
    pub fn is_empty(&self) -> bool {
        self.all_col_offsets.is_empty()
    }

    /// Returns the stored affected-row count.
    pub fn affected_rows(&self) -> u64 {
        self.affected_rows
    }

    /// Returns the column metadata.
    pub fn columns(&self) -> &[ColumnDesc] {
        &self.columns
    }

    /// Returns a view of row `idx`; cells are resolved against the owned
    /// buffer if present, otherwise against `arena`.
    ///
    /// # Panics
    ///
    /// Panics when `idx` is not a valid row index.
    pub fn row<'a>(&'a self, idx: usize, arena: &'a Arena) -> Row<'a> {
        let first = idx * self.num_cols;
        let cells = &self.all_col_offsets[first..first + self.num_cols];
        Row {
            data: self.data_buf.as_deref(),
            arena,
            col_offsets: cells,
            columns: &self.columns,
        }
    }

    /// Moves the cell-offset table out, leaving an empty one behind.
    pub fn take_col_offsets(&mut self) -> Vec<(usize, i32)> {
        std::mem::take(&mut self.all_col_offsets)
    }

    /// Moves the owned data buffer out (if any), leaving `None` behind.
    pub fn take_data_buf(&mut self) -> Option<Vec<u8>> {
        self.data_buf.take()
    }

    /// Iterates over all rows in order, resolving cells like [`QueryResult::row`].
    pub fn rows<'a>(&'a self, arena: &'a Arena) -> impl Iterator<Item = Row<'a>> {
        // `chunks` rejects a size of 0, so clamp the stride for column-less results.
        let stride = self.num_cols.max(1);
        let data = self.data_buf.as_deref();
        let columns = &self.columns;
        self.all_col_offsets
            .chunks(stride)
            .map(move |col_offsets| Row {
                data,
                arena,
                col_offsets,
                columns,
            })
    }
}
/// Borrowed view of one row of a `QueryResult`.
pub struct Row<'a> {
// Owned result buffer, when the `QueryResult` has one; takes precedence
// over `arena` in `get_raw`.
data: Option<&'a [u8]>,
// Fallback byte source used when `data` is `None`.
arena: &'a Arena,
// `(offset, len)` per column of this row; negative `len` reads as `None`.
col_offsets: &'a [(usize, i32)],
// Column metadata for `column_name` / `column_type_oid`.
columns: &'a [ColumnDesc],
}
impl<'a> Row<'a> {
    /// Returns the raw bytes of column `idx`, or `None` when the stored
    /// length is negative.
    ///
    /// Bytes come from the result's own buffer when present, otherwise from
    /// the arena.
    ///
    /// # Panics
    ///
    /// Panics when `idx` is out of range.
    #[inline]
    pub fn get_raw(&self, idx: usize) -> Option<&'a [u8]> {
        let (offset, len) = self.col_offsets[idx];
        if len < 0 {
            return None;
        }
        let len = len as usize;
        let bytes: &'a [u8] = match self.data {
            Some(buf) => &buf[offset..offset + len],
            None => self.arena.get(offset, len),
        };
        Some(bytes)
    }

    /// Returns `true` when column `idx` has a negative stored length.
    #[inline]
    pub fn is_null(&self, idx: usize) -> bool {
        let (_, len) = self.col_offsets[idx];
        len < 0
    }

    /// Number of columns in this row.
    #[inline]
    pub fn column_count(&self) -> usize {
        self.col_offsets.len()
    }

    /// Decodes column `idx` with `codec::decode_bool`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_bool(&self, idx: usize) -> Option<bool> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_bool(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_i16`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_i16(&self, idx: usize) -> Option<i16> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_i16(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_i32`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_i32(&self, idx: usize) -> Option<i32> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_i32(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_i64`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_i64(&self, idx: usize) -> Option<i64> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_i64(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_f32`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_f32(&self, idx: usize) -> Option<f32> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_f32(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_f64`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_f64(&self, idx: usize) -> Option<f64> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_f64(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_str`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_str(&self, idx: usize) -> Option<&'a str> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_str(raw).ok()
    }

    /// Returns the undecoded bytes of column `idx` (same as [`Row::get_raw`]).
    #[inline]
    pub fn get_bytes(&self, idx: usize) -> Option<&'a [u8]> {
        self.get_raw(idx)
    }

    /// Name of column `idx`. Panics when `idx` is out of range.
    #[inline]
    pub fn column_name(&self, idx: usize) -> &str {
        let desc = &self.columns[idx];
        &desc.name
    }

    /// Type OID of column `idx`. Panics when `idx` is out of range.
    #[inline]
    pub fn column_type_oid(&self, idx: usize) -> u32 {
        let desc = &self.columns[idx];
        desc.type_oid
    }
}
/// Borrowed view over the body of a single DataRow message, with the
/// position of every column pre-computed by `PgDataRow::new`.
pub struct PgDataRow<'a> {
// The message body handed to `new` (column count followed by the columns).
data: &'a [u8],
// `(offset into data, length)` per column; a negative length reads as `None`.
// Up to 16 entries are stored inline before `SmallVec` spills to the heap.
offsets: smallvec::SmallVec<[(usize, i32); 16]>,
}
impl<'a> PgDataRow<'a> {
    /// Parses the body of a DataRow message and records where each column's
    /// bytes live.
    ///
    /// Layout read from `data`: a big-endian `i16` column count, then for each
    /// column a big-endian `i32` length followed by that many bytes. A negative
    /// length carries no bytes and is later reported as `None`.
    ///
    /// Every column payload is bounds-checked here, so the accessors cannot
    /// slice past the end of `data` for any column index that is in range.
    ///
    /// # Errors
    ///
    /// Returns `DriverError::Protocol` when `data` is shorter than two bytes,
    /// the column count is negative, a length prefix is cut off, or a column's
    /// declared length extends past the end of `data`.
    pub fn new(data: &'a [u8]) -> Result<Self, DriverError> {
        if data.len() < 2 {
            return Err(DriverError::Protocol("DataRow too short".into()));
        }
        let num_cols = i16::from_be_bytes([data[0], data[1]]);
        if num_cols < 0 {
            return Err(DriverError::Protocol(
                "DataRow: negative column count".into(),
            ));
        }
        let num_cols = num_cols as usize;
        let mut offsets = smallvec::SmallVec::<[(usize, i32); 16]>::with_capacity(num_cols);
        let mut pos = 2usize;
        for _ in 0..num_cols {
            if pos + 4 > data.len() {
                return Err(DriverError::Protocol("DataRow truncated".into()));
            }
            let col_len =
                i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
            pos += 4;
            offsets.push((pos, col_len));
            if col_len > 0 {
                // The payload itself must fit too; without this check a
                // truncated last column would only surface as a panic in
                // `get_raw`.
                pos = pos
                    .checked_add(col_len as usize)
                    .filter(|&end| end <= data.len())
                    .ok_or_else(|| DriverError::Protocol("DataRow truncated".into()))?;
            }
        }
        Ok(Self { data, offsets })
    }

    /// Returns the raw bytes of column `idx`, or `None` when its length is
    /// negative.
    ///
    /// # Panics
    ///
    /// Panics when `idx` is out of range.
    #[inline]
    pub fn get_raw(&self, idx: usize) -> Option<&'a [u8]> {
        let (offset, len) = self.offsets[idx];
        if len < 0 {
            None
        } else {
            Some(&self.data[offset..offset + len as usize])
        }
    }

    /// Returns `true` when column `idx` has a negative length.
    #[inline]
    pub fn is_null(&self, idx: usize) -> bool {
        self.offsets[idx].1 < 0
    }

    /// Number of columns in this row.
    #[inline]
    pub fn column_count(&self) -> usize {
        self.offsets.len()
    }

    /// Decodes column `idx` with `codec::decode_bool`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_bool(&self, idx: usize) -> Option<bool> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_bool(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_i16`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_i16(&self, idx: usize) -> Option<i16> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_i16(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_i32`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_i32(&self, idx: usize) -> Option<i32> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_i32(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_i64`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_i64(&self, idx: usize) -> Option<i64> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_i64(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_f32`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_f32(&self, idx: usize) -> Option<f32> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_f32(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_f64`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_f64(&self, idx: usize) -> Option<f64> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_f64(raw).ok()
    }

    /// Decodes column `idx` with `codec::decode_str`; `None` when the value
    /// is absent or decoding fails.
    #[inline]
    pub fn get_str(&self, idx: usize) -> Option<&'a str> {
        let raw = self.get_raw(idx)?;
        crate::codec::decode_str(raw).ok()
    }

    /// Returns the undecoded bytes of column `idx` (same as
    /// [`PgDataRow::get_raw`]).
    #[inline]
    pub fn get_bytes(&self, idx: usize) -> Option<&'a [u8]> {
        self.get_raw(idx)
    }
}
/// Hashes the SQL text with a default-constructed `RapidHasher` and returns
/// the 64-bit digest.
///
/// The text is hashed exactly as given: no trimming or normalisation is
/// applied, so strings that differ only in whitespace hash independently.
pub fn hash_sql(sql: &str) -> u64 {
    use std::hash::Hash as _;
    use std::hash::Hasher as _;

    let mut state = RapidHasher::default();
    sql.hash(&mut state);
    state.finish()
}
#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
use super::*;
#[test]
fn config_parse_full_url() {
let cfg = Config::from_url("postgres://user:pass@localhost:5432/mydb").unwrap();
assert_eq!(cfg.user, "user");
assert_eq!(cfg.password, "pass");
assert_eq!(cfg.host, "localhost");
assert_eq!(cfg.port, 5432);
assert_eq!(cfg.database, "mydb");
}
#[test]
fn config_parse_default_port() {
let cfg = Config::from_url("postgres://user:pass@localhost/mydb").unwrap();
assert_eq!(cfg.port, 5432);
}
#[test]
fn config_parse_no_password() {
let cfg = Config::from_url("postgres://user@localhost/mydb").unwrap();
assert_eq!(cfg.user, "user");
assert_eq!(cfg.password, "");
}
#[test]
fn config_parse_empty_database() {
let cfg = Config::from_url("postgres://user:pass@localhost").unwrap();
assert_eq!(cfg.database, "user");
}
#[test]
fn config_parse_sslmode() {
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
assert_eq!(cfg.ssl, SslMode::Require);
}
#[test]
fn config_parse_percent_encoding() {
let cfg = Config::from_url("postgres://user%40domain:p%40ss@localhost/db").unwrap();
assert_eq!(cfg.user, "user@domain");
assert_eq!(cfg.password, "p@ss");
}
#[test]
fn config_rejects_bad_scheme() {
let result = Config::from_url("mysql://user:pass@localhost/db");
assert!(result.is_err());
}
#[test]
fn config_rejects_unknown_sslmode() {
let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=requre");
assert!(result.is_err(), "typo 'requre' should be rejected");
let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=REQUIRE");
assert!(result.is_err(), "uppercase should be rejected");
let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=bogus");
assert!(result.is_err(), "bogus value should be rejected");
}
#[test]
fn config_accepts_valid_sslmodes() {
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=disable").unwrap();
assert_eq!(cfg.ssl, SslMode::Disable);
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=prefer").unwrap();
assert_eq!(cfg.ssl, SslMode::Prefer);
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
assert_eq!(cfg.ssl, SslMode::Require);
}
#[test]
fn config_parse_postgresql_scheme() {
let cfg = Config::from_url("postgresql://user:pass@localhost:5432/mydb").unwrap();
assert_eq!(cfg.user, "user");
assert_eq!(cfg.password, "pass");
assert_eq!(cfg.host, "localhost");
assert_eq!(cfg.port, 5432);
assert_eq!(cfg.database, "mydb");
}
#[test]
fn config_parse_no_password_standalone() {
let cfg = Config::from_url("postgres://admin@db.example.com/myapp").unwrap();
assert_eq!(cfg.user, "admin");
assert_eq!(cfg.password, "");
assert_eq!(cfg.host, "db.example.com");
assert_eq!(cfg.database, "myapp");
}
#[test]
fn config_empty_database_falls_back_to_user() {
let cfg = Config::from_url("postgres://testuser:pass@localhost").unwrap();
assert_eq!(cfg.database, "testuser");
}
#[test]
fn config_unknown_sslmode_error() {
let result = Config::from_url("postgres://u:p@h/d?sslmode=verify-full");
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("unknown sslmode"),
"should describe unknown sslmode: {err}"
);
}
#[test]
fn config_multiple_query_params() {
let cfg = Config::from_url(
"postgres://user:pass@localhost/db?sslmode=disable&statement_timeout=60",
)
.unwrap();
assert_eq!(cfg.ssl, SslMode::Disable);
assert_eq!(cfg.statement_timeout_secs, 60);
}
#[test]
fn config_validate_empty_host() {
let cfg = Config {
host: String::new(),
port: 5432,
user: "user".into(),
password: "pass".into(),
database: "db".into(),
ssl: SslMode::Disable,
statement_timeout_secs: 30,
statement_cache_mode: StatementCacheMode::Named,
ssl_root_cert: None,
ssl_cert: None,
ssl_key: None,
};
assert!(cfg.validate().is_err());
}
#[test]
fn config_validate_empty_user() {
let cfg = Config {
host: "localhost".into(),
port: 5432,
user: String::new(),
password: "pass".into(),
database: "db".into(),
ssl: SslMode::Disable,
statement_timeout_secs: 30,
statement_cache_mode: StatementCacheMode::Named,
ssl_root_cert: None,
ssl_cert: None,
ssl_key: None,
};
assert!(cfg.validate().is_err());
}
#[test]
fn config_validate_empty_database() {
let cfg = Config {
host: "localhost".into(),
port: 5432,
user: "user".into(),
password: "pass".into(),
database: String::new(),
ssl: SslMode::Disable,
statement_timeout_secs: 30,
statement_cache_mode: StatementCacheMode::Named,
ssl_root_cert: None,
ssl_cert: None,
ssl_key: None,
};
assert!(cfg.validate().is_err());
}
#[test]
fn config_missing_at_sign() {
let result = Config::from_url("postgres://userpasslocalhost/db");
assert!(result.is_err());
}
#[test]
fn config_custom_port() {
let cfg = Config::from_url("postgres://user:pass@localhost:5433/db").unwrap();
assert_eq!(cfg.port, 5433);
}
#[test]
fn config_invalid_port() {
let result = Config::from_url("postgres://user:pass@localhost:notaport/db");
assert!(result.is_err());
}
#[cfg(not(feature = "tls"))]
#[test]
fn config_sslmode_require_without_tls_feature() {
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
assert_eq!(cfg.ssl, SslMode::Require);
}
#[test]
fn config_statement_timeout_default() {
let cfg = Config::from_url("postgres://user:pass@localhost/db").unwrap();
assert_eq!(cfg.statement_timeout_secs, 30);
}
#[test]
fn config_statement_timeout_custom() {
let cfg =
Config::from_url("postgres://user:pass@localhost/db?statement_timeout=120").unwrap();
assert_eq!(cfg.statement_timeout_secs, 120);
}
#[test]
fn config_statement_timeout_zero() {
let cfg =
Config::from_url("postgres://user:pass@localhost/db?statement_timeout=0").unwrap();
assert_eq!(cfg.statement_timeout_secs, 0);
}
#[test]
fn config_statement_timeout_invalid_falls_back() {
let cfg =
Config::from_url("postgres://user:pass@localhost/db?statement_timeout=notanumber")
.unwrap();
assert_eq!(cfg.statement_timeout_secs, 30); }
#[test]
fn parse_statement_cache_default() {
let cfg = Config::from_url("postgres://user:pass@localhost/db").unwrap();
assert_eq!(cfg.statement_cache_mode, StatementCacheMode::Named);
}
#[test]
fn parse_statement_cache_named() {
let cfg =
Config::from_url("postgres://user:pass@localhost/db?statement_cache=named").unwrap();
assert_eq!(cfg.statement_cache_mode, StatementCacheMode::Named);
}
#[test]
fn parse_statement_cache_disabled() {
let cfg =
Config::from_url("postgres://user:pass@localhost/db?statement_cache=disabled").unwrap();
assert_eq!(cfg.statement_cache_mode, StatementCacheMode::Disabled);
}
#[test]
fn parse_statement_cache_invalid() {
let result = Config::from_url("postgres://user:pass@localhost/db?statement_cache=off");
assert!(result.is_err(), "invalid value 'off' should be rejected");
let result = Config::from_url("postgres://user:pass@localhost/db?statement_cache=DISABLED");
assert!(result.is_err(), "uppercase should be rejected");
let result = Config::from_url("postgres://user:pass@localhost/db?statement_cache=bogus");
assert!(result.is_err(), "bogus value should be rejected");
}
#[test]
fn parse_statement_cache_with_other_params() {
let cfg = Config::from_url(
"postgres://user:pass@localhost/db?sslmode=disable&statement_cache=disabled&statement_timeout=60",
)
.unwrap();
assert_eq!(cfg.statement_cache_mode, StatementCacheMode::Disabled);
assert_eq!(cfg.ssl, SslMode::Disable);
assert_eq!(cfg.statement_timeout_secs, 60);
}
#[test]
fn statement_cache_mode_default_is_named() {
assert_eq!(StatementCacheMode::default(), StatementCacheMode::Named);
}
#[test]
fn parse_ssl_root_cert() {
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslrootcert=/path/to/ca.pem")
.unwrap();
assert_eq!(cfg.ssl_root_cert.as_deref(), Some("/path/to/ca.pem"));
assert_eq!(cfg.ssl_cert, None);
assert_eq!(cfg.ssl_key, None);
}
#[test]
fn parse_ssl_cert_and_key() {
let cfg = Config::from_url(
"postgres://user:pass@localhost/db?sslcert=/path/to/client.pem&sslkey=/path/to/client.key",
)
.unwrap();
assert_eq!(cfg.ssl_root_cert, None);
assert_eq!(cfg.ssl_cert.as_deref(), Some("/path/to/client.pem"));
assert_eq!(cfg.ssl_key.as_deref(), Some("/path/to/client.key"));
}
#[test]
fn parse_ssl_all_tls_params() {
let cfg = Config::from_url(
"postgres://user:pass@localhost/db?sslmode=require&sslrootcert=/ca.pem&sslcert=/client.pem&sslkey=/client.key",
)
.unwrap();
assert_eq!(cfg.ssl, SslMode::Require);
assert_eq!(cfg.ssl_root_cert.as_deref(), Some("/ca.pem"));
assert_eq!(cfg.ssl_cert.as_deref(), Some("/client.pem"));
assert_eq!(cfg.ssl_key.as_deref(), Some("/client.key"));
}
#[test]
fn parse_ssl_paths_percent_encoded() {
let cfg = Config::from_url("postgres://user:pass@localhost/db?sslrootcert=%2Ftmp%2Fca.pem")
.unwrap();
assert_eq!(cfg.ssl_root_cert.as_deref(), Some("/tmp/ca.pem"));
}
#[test]
fn parse_ssl_params_default_none() {
let cfg = Config::from_url("postgres://user:pass@localhost/db").unwrap();
assert_eq!(cfg.ssl_root_cert, None);
assert_eq!(cfg.ssl_cert, None);
assert_eq!(cfg.ssl_key, None);
}
#[test]
fn config_uds_path_format() {
let cfg = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
assert_eq!(cfg.uds_path(), "/tmp/.s.PGSQL.5432");
}
#[test]
fn config_uds_path_custom_port() {
let cfg = Config::from_url("postgres://user@localhost:5433/db?host=/tmp").unwrap();
assert_eq!(cfg.uds_path(), "/tmp/.s.PGSQL.5433");
}
#[test]
fn config_host_is_uds_absolute_path() {
let cfg = Config {
host: "/tmp".into(),
port: 5432,
user: "user".into(),
password: "".into(),
database: "db".into(),
ssl: SslMode::Disable,
statement_timeout_secs: 30,
statement_cache_mode: StatementCacheMode::Named,
ssl_root_cert: None,
ssl_cert: None,
ssl_key: None,
};
assert!(cfg.host_is_uds());
assert_eq!(cfg.uds_path(), "/tmp/.s.PGSQL.5432");
}
#[test]
fn config_host_is_uds_var_run() {
let cfg = Config {
host: "/var/run/postgresql".into(),
port: 5433,
user: "user".into(),
password: "".into(),
database: "db".into(),
ssl: SslMode::Disable,
statement_timeout_secs: 30,
statement_cache_mode: StatementCacheMode::Named,
ssl_root_cert: None,
ssl_cert: None,
ssl_key: None,
};
assert!(cfg.host_is_uds());
assert_eq!(cfg.uds_path(), "/var/run/postgresql/.s.PGSQL.5433");
}
#[test]
fn config_host_is_not_uds_for_hostname() {
let cfg = Config {
host: "localhost".into(),
port: 5432,
user: "user".into(),
password: "".into(),
database: "db".into(),
ssl: SslMode::Disable,
statement_timeout_secs: 30,
statement_cache_mode: StatementCacheMode::Named,
ssl_root_cert: None,
ssl_cert: None,
ssl_key: None,
};
assert!(!cfg.host_is_uds());
}
#[test]
fn config_host_is_not_uds_for_ip() {
let cfg = Config {
host: "127.0.0.1".into(),
port: 5432,
user: "user".into(),
password: "".into(),
database: "db".into(),
ssl: SslMode::Disable,
statement_timeout_secs: 30,
statement_cache_mode: StatementCacheMode::Named,
ssl_root_cert: None,
ssl_cert: None,
ssl_key: None,
};
assert!(!cfg.host_is_uds());
}
#[test]
fn config_parse_uds_host_query_param() {
let cfg = Config::from_url("postgres://user@localhost/mydb?host=/tmp").unwrap();
assert_eq!(cfg.host, "/tmp");
assert!(cfg.host_is_uds());
assert_eq!(cfg.uds_path(), "/tmp/.s.PGSQL.5432");
assert_eq!(cfg.database, "mydb");
assert_eq!(cfg.user, "user");
}
#[test]
fn config_parse_uds_host_query_param_custom_port() {
let cfg = Config::from_url("postgres://user@localhost:5433/mydb?host=/var/run/postgresql")
.unwrap();
assert_eq!(cfg.host, "/var/run/postgresql");
assert_eq!(cfg.port, 5433);
assert_eq!(cfg.uds_path(), "/var/run/postgresql/.s.PGSQL.5433");
}
#[test]
fn config_parse_uds_host_with_other_params() {
let cfg = Config::from_url(
"postgres://user@localhost/db?host=/tmp&sslmode=disable&statement_timeout=60",
)
.unwrap();
assert_eq!(cfg.host, "/tmp");
assert!(cfg.host_is_uds());
assert_eq!(cfg.ssl, SslMode::Disable);
assert_eq!(cfg.statement_timeout_secs, 60);
}
#[test]
fn config_parse_uds_host_percent_encoded() {
let cfg = Config::from_url("postgres://user@localhost/db?host=%2Ftmp").unwrap();
assert_eq!(cfg.host, "/tmp");
assert!(cfg.host_is_uds());
}
#[test]
fn config_parse_tcp_host_not_overridden_without_param() {
let cfg = Config::from_url("postgres://user@myserver/db").unwrap();
assert_eq!(cfg.host, "myserver");
assert!(!cfg.host_is_uds());
}
#[test]
fn config_parse_uds_host_overrides_url_hostname() {
let cfg = Config::from_url("postgres://user@db.example.com/mydb?host=/var/run/postgresql")
.unwrap();
assert_eq!(cfg.host, "/var/run/postgresql");
assert!(cfg.host_is_uds());
}
#[test]
fn config_parse_uds_empty_url_host() {
let cfg = Config::from_url("postgres://user@/mydb?host=/tmp").unwrap();
assert_eq!(cfg.host, "/tmp");
assert!(cfg.host_is_uds());
assert_eq!(cfg.database, "mydb");
}
#[test]
fn url_decode_works() {
assert_eq!(url_decode("hello%20world").unwrap(), "hello world");
assert_eq!(url_decode("no%20escape").unwrap(), "no escape");
assert_eq!(url_decode("plain").unwrap(), "plain");
assert_eq!(url_decode("a%40b").unwrap(), "a@b");
}
#[test]
fn url_decode_malformed_percent_trailing() {
let result = url_decode("abc%2");
assert!(result.is_err(), "truncated %2 should error");
}
#[test]
fn url_decode_malformed_percent_no_digits() {
let result = url_decode("abc%");
assert!(result.is_err(), "bare % at end should error");
}
#[test]
fn url_decode_invalid_hex_digit() {
let result = url_decode("abc%GG");
assert!(result.is_err(), "%GG should error");
}
#[test]
fn url_decode_invalid_hex_second_digit() {
let result = url_decode("abc%2Z");
assert!(result.is_err(), "%2Z should error");
}
#[test]
fn url_decode_invalid_utf8_percent() {
let result = url_decode("%80%81");
assert!(result.is_err(), "invalid UTF-8 bytes should error");
}
#[test]
fn url_decode_percent_everywhere() {
assert_eq!(url_decode("%41%42%43").unwrap(), "ABC");
assert_eq!(url_decode("%61").unwrap(), "a");
assert_eq!(url_decode("x%2Fy%2Fz").unwrap(), "x/y/z");
}
#[test]
fn url_decode_bare_percent_middle() {
assert!(url_decode("a%b").is_err(), "bare % in middle should error");
}
#[test]
fn url_decode_multibyte_utf8() {
let result = url_decode("caf%C3%A9").unwrap();
assert_eq!(result, "caf\u{00e9}"); }
#[test]
fn url_decode_invalid_percent_zz() {
let result = url_decode("abc%ZZ");
assert!(result.is_err(), "%ZZ should error");
}
#[test]
fn url_decode_truncated_percent_trailing() {
let result = url_decode("abc%");
assert!(result.is_err(), "trailing % should error");
}
#[test]
fn url_decode_invalid_utf8() {
let result = url_decode("%80");
assert!(result.is_err(), "invalid UTF-8 should error");
}
#[test]
fn url_decode_empty_string() {
assert_eq!(url_decode("").unwrap(), "");
}
#[test]
fn url_decode_no_encoding() {
assert_eq!(url_decode("hello").unwrap(), "hello");
}
#[test]
fn url_decode_all_ascii_hex() {
assert_eq!(url_decode("%2F").unwrap(), "/");
assert_eq!(url_decode("%2f").unwrap(), "/");
}
#[test]
fn config_unicode_password() {
let cfg =
Config::from_url("postgres://user:%D0%BF%D0%B0%D1%80%D0%BE%D0%BB%D1%8C@localhost/db")
.unwrap();
assert_eq!(cfg.user, "user");
assert_eq!(
cfg.password,
"\u{043F}\u{0430}\u{0440}\u{043E}\u{043B}\u{044C}"
); assert_eq!(cfg.host, "localhost");
assert_eq!(cfg.database, "db");
}
#[test]
fn config_port_zero() {
let cfg = Config::from_url("postgres://user:pass@localhost:0/db").unwrap();
assert_eq!(cfg.port, 0);
}
#[test]
fn config_port_max() {
let cfg = Config::from_url("postgres://user:pass@localhost:65535/db").unwrap();
assert_eq!(cfg.port, 65535);
}
#[test]
fn config_port_overflow() {
let result = Config::from_url("postgres://user:pass@localhost:65536/db");
assert!(result.is_err(), "port 65536 exceeds u16 max");
}
#[test]
fn config_unknown_param_ignored() {
let cfg = Config::from_url(
"postgres://user:pass@localhost/db?application_name=myapp&connect_timeout=10",
)
.unwrap();
assert_eq!(cfg.user, "user");
assert_eq!(cfg.host, "localhost");
assert_eq!(cfg.database, "db");
assert_eq!(cfg.statement_timeout_secs, 30);
assert_eq!(cfg.ssl, SslMode::Prefer);
}
#[test]
fn url_decode_double_percent_encoding() {
assert_eq!(url_decode("%2525").unwrap(), "%25");
}
#[test]
fn config_explicit_empty_password() {
let cfg = Config::from_url("postgres://user:@localhost/db").unwrap();
assert_eq!(cfg.user, "user");
assert_eq!(cfg.password, "");
}
#[test]
fn config_special_chars_in_user() {
let cfg = Config::from_url("postgres://my%2Fuser:pass@localhost/my%2Fdb").unwrap();
assert_eq!(cfg.user, "my/user");
assert_eq!(cfg.database, "my/db");
}
#[test]
fn url_decode_plus_is_literal() {
assert_eq!(url_decode("a+b").unwrap(), "a+b");
}
#[test]
fn config_minimal_valid_url() {
let cfg = Config::from_url("postgres://user@localhost/db").unwrap();
assert_eq!(cfg.user, "user");
assert_eq!(cfg.password, "");
assert_eq!(cfg.host, "localhost");
assert_eq!(cfg.port, 5432);
assert_eq!(cfg.database, "db");
}
#[test]
fn config_empty_param_segments() {
let cfg =
Config::from_url("postgres://user:pass@localhost/db?&&statement_timeout=60&&").unwrap();
assert_eq!(cfg.statement_timeout_secs, 60);
}
#[test]
fn hash_sql_deterministic() {
let h1 = hash_sql("SELECT 1");
let h2 = hash_sql("SELECT 1");
assert_eq!(h1, h2);
}
#[test]
fn hash_sql_different_queries() {
let h1 = hash_sql("SELECT 1");
let h2 = hash_sql("SELECT 2");
assert_ne!(h1, h2);
}
#[test]
fn hash_sql_empty() {
let _h = hash_sql(""); }
#[test]
fn hash_sql_whitespace_only() {
let h = hash_sql(" ");
assert_ne!(h, hash_sql(""));
}
#[test]
fn hash_sql_very_long() {
let long_sql = "SELECT ".to_string() + &"x".repeat(10_000);
let h = hash_sql(&long_sql);
assert_eq!(h, hash_sql(&long_sql));
}
#[test]
fn hash_sql_unicode() {
let h = hash_sql("SELECT '\u{1F600}'");
assert_ne!(h, hash_sql("SELECT 'x'"));
}
#[test]
fn notification_struct_fields() {
let n = Notification {
pid: 42,
channel: "test_chan".to_owned(),
payload: "hello".to_owned(),
};
assert_eq!(n.pid, 42);
assert_eq!(n.channel, "test_chan");
assert_eq!(n.payload, "hello");
}
#[test]
fn notification_clone() {
let n = Notification {
pid: 1,
channel: "c".to_owned(),
payload: "p".to_owned(),
};
let n2 = n.clone();
assert_eq!(n2.pid, 1);
assert_eq!(n2.channel, "c");
}
#[test]
fn notification_debug() {
let n = Notification {
pid: 1,
channel: "c".to_owned(),
payload: "p".to_owned(),
};
let dbg = format!("{n:?}");
assert!(dbg.contains("Notification"));
}
#[test]
fn query_result_empty() {
let result = QueryResult {
all_col_offsets: vec![],
num_cols: 0,
columns: Arc::from(Vec::new()),
affected_rows: 0,
data_buf: None,
};
assert!(result.is_empty());
assert_eq!(result.len(), 0);
}
#[test]
fn query_result_from_parts() {
let result = QueryResult::from_parts(vec![(0, 4), (0, -1)], 2, Arc::from(Vec::new()), 5);
assert_eq!(result.len(), 1);
assert_eq!(result.num_cols, 2);
assert_eq!(result.affected_rows, 5);
}
#[test]
fn query_result_affected_rows() {
let result = QueryResult {
all_col_offsets: vec![],
num_cols: 0,
columns: Arc::from(Vec::new()),
affected_rows: 42,
data_buf: None,
};
assert_eq!(result.affected_rows, 42);
assert!(result.is_empty());
}
fn make_data_row(columns: &[Option<&[u8]>]) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&(columns.len() as i16).to_be_bytes());
for col in columns {
match col {
Some(data) => {
buf.extend_from_slice(&(data.len() as i32).to_be_bytes());
buf.extend_from_slice(data);
}
None => {
buf.extend_from_slice(&(-1i32).to_be_bytes());
}
}
}
buf
}
#[test]
fn pg_data_row_get_i32() {
let data = make_data_row(&[Some(&42i32.to_be_bytes())]);
let row = PgDataRow::new(&data).unwrap();
assert_eq!(row.get_i32(0), Some(42));
assert_eq!(row.column_count(), 1);
}
#[test]
fn pg_data_row_get_i64() {
let data = make_data_row(&[Some(&12345i64.to_be_bytes())]);
let row = PgDataRow::new(&data).unwrap();
assert_eq!(row.get_i64(0), Some(12345));
}
#[test]
fn pg_data_row_get_str() {
let data = make_data_row(&[Some(b"hello")]);
let row = PgDataRow::new(&data).unwrap();
assert_eq!(row.get_str(0), Some("hello"));
}
#[test]
fn pg_data_row_get_bytes() {
let data = make_data_row(&[Some(&[0xDE, 0xAD, 0xBE, 0xEF])]);
let row = PgDataRow::new(&data).unwrap();
assert_eq!(row.get_bytes(0), Some(&[0xDE, 0xAD, 0xBE, 0xEF][..]));
}
/// A one-byte column decodes through `get_bool`: byte `1` reads as `true`,
/// byte `0` reads as `false`.
#[test]
fn pg_data_row_get_bool() {
    // Same order as before: the `true` case first, then the `false` case.
    for &(byte, expected) in &[(1u8, true), (0u8, false)] {
        let wire = make_data_row(&[Some(&[byte][..])]);
        let parsed = PgDataRow::new(&wire).unwrap();
        assert_eq!(parsed.get_bool(0), Some(expected));
    }
}
/// A column holding the big-endian bytes of an `f64` is read back by
/// `get_f64` to within `1e-10`.
#[test]
fn pg_data_row_get_f64() {
    let encoded = 3.14f64.to_be_bytes();
    let wire = make_data_row(&[Some(&encoded[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    let decoded = parsed.get_f64(0).unwrap();
    assert!((decoded - 3.14).abs() < 1e-10);
}
/// A NULL column is reported by `is_null`, and the typed getters return
/// `None` for it.
#[test]
fn pg_data_row_null_column() {
    let only_null: [Option<&[u8]>; 1] = [None];
    let wire = make_data_row(&only_null);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert!(parsed.is_null(0));
    assert_eq!(parsed.get_i32(0), None);
    assert_eq!(parsed.get_str(0), None);
}
/// A five-column row (i32, two strings, bool, f64) exposes every column
/// through its matching typed getter.
#[test]
fn pg_data_row_multiple_columns() {
    let id_bytes = 42i32.to_be_bytes();
    let score_bytes = 3.14f64.to_be_bytes();
    let active_byte = [1u8];

    let wire = make_data_row(&[
        Some(&id_bytes[..]),
        Some(&b"alice"[..]),
        Some(&b"alice@example.com"[..]),
        Some(&active_byte[..]),
        Some(&score_bytes[..]),
    ]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.column_count(), 5);
    assert_eq!(parsed.get_i32(0), Some(42));
    assert_eq!(parsed.get_str(1), Some("alice"));
    assert_eq!(parsed.get_str(2), Some("alice@example.com"));
    assert_eq!(parsed.get_bool(3), Some(true));

    let score = parsed.get_f64(4).unwrap();
    assert!((score - 3.14).abs() < 1e-10);
}
/// A NULL sitting between two populated columns does not disturb the columns
/// around it.
#[test]
fn pg_data_row_mixed_null() {
    let first = 42i32.to_be_bytes();
    let wire = make_data_row(&[Some(&first[..]), None, Some(&b"text"[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.get_i32(0), Some(42));
    assert!(parsed.is_null(1));
    assert_eq!(parsed.get_str(1), None);
    assert_eq!(parsed.get_str(2), Some("text"));
}
/// A row encoded with zero columns parses and reports a column count of zero.
#[test]
fn pg_data_row_empty() {
    let no_columns: [Option<&[u8]>; 0] = [];
    let wire = make_data_row(&no_columns);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.column_count(), 0);
}
/// A one-byte buffer is rejected by `PgDataRow::new`.
#[test]
fn pg_data_row_too_short() {
    let lone_byte: Vec<u8> = vec![0];
    let outcome = PgDataRow::new(&lone_byte);
    assert!(outcome.is_err());
}
/// A buffer that announces two columns but only contains the first one is
/// rejected by `PgDataRow::new`.
#[test]
fn pg_data_row_truncated() {
    let wire: Vec<u8> = [
        &2i16.to_be_bytes()[..],  // column count: 2
        &4i32.to_be_bytes()[..],  // first column length: 4
        &42i32.to_be_bytes()[..], // first column payload; second column is missing
    ]
    .concat();

    assert!(PgDataRow::new(&wire).is_err());
}
/// A column holding the big-endian bytes of `7i16` is read back by `get_i16`.
#[test]
fn pg_data_row_get_i16() {
    let encoded = 7i16.to_be_bytes();
    let wire = make_data_row(&[Some(&encoded[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.get_i16(0), Some(7));
}
/// A column holding the big-endian bytes of `2.5f32` is read back by
/// `get_f32` to within `1e-6`.
#[test]
fn pg_data_row_get_f32() {
    let encoded = 2.5f32.to_be_bytes();
    let wire = make_data_row(&[Some(&encoded[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    let decoded = parsed.get_f32(0).unwrap();
    assert!((decoded - 2.5).abs() < 1e-6);
}
/// `get_raw` returns `None` for a NULL column.
#[test]
fn pg_data_row_get_raw_null() {
    let only_null: [Option<&[u8]>; 1] = [None];
    let wire = make_data_row(&only_null);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.get_raw(0), None);
}
/// `get_raw` returns the payload bytes of a populated column.
#[test]
fn pg_data_row_get_raw_data() {
    let payload = [1u8, 2, 3];
    let wire = make_data_row(&[Some(&payload[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.get_raw(0), Some(&payload[..]));
}
/// A row with sixteen one-byte columns parses, and every column is readable
/// through `get_raw`.
#[test]
fn pg_data_row_stack_alloc_16_columns() {
    let cell: &[u8] = &[0u8];
    let cols: Vec<Option<&[u8]>> = vec![Some(cell); 16];

    let wire = make_data_row(&cols);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.column_count(), 16);
    for i in 0..16 {
        assert_eq!(parsed.get_raw(i), Some(cell));
    }
}
/// Walks a five-column row by hand with a moving cursor (no `PgDataRow`):
/// read each 4-byte length prefix, then the payload, and finish exactly at
/// the end of the buffer.
#[test]
fn inline_sequential_decode_five_columns() {
    // Copies the four bytes starting at `at` into a fixed-size array.
    fn four_bytes(buf: &[u8], at: usize) -> [u8; 4] {
        let mut out = [0u8; 4];
        out.copy_from_slice(&buf[at..at + 4]);
        out
    }
    // Copies the eight bytes starting at `at` into a fixed-size array.
    fn eight_bytes(buf: &[u8], at: usize) -> [u8; 8] {
        let mut out = [0u8; 8];
        out.copy_from_slice(&buf[at..at + 8]);
        out
    }

    let data = make_data_row(&[
        Some(&42i32.to_be_bytes()),
        Some(b"alice"),
        Some(b"alice@example.com"),
        Some(&[1u8]),
        Some(&3.14f64.to_be_bytes()),
    ]);

    // Start just past the 2-byte column count.
    let mut cursor: usize = 2;

    // Column 0: i32.
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    assert_eq!(width, 4);
    let id = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += width as usize;
    assert_eq!(id, 42);

    // Column 1: text.
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    assert_eq!(width, 5);
    let field_end = cursor + width as usize;
    let name = std::str::from_utf8(&data[cursor..field_end]).unwrap();
    cursor = field_end;
    assert_eq!(name, "alice");

    // Column 2: text (the length is used but not asserted).
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let field_end = cursor + width as usize;
    let email = std::str::from_utf8(&data[cursor..field_end]).unwrap();
    cursor = field_end;
    assert_eq!(email, "alice@example.com");

    // Column 3: bool, one byte.
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    assert_eq!(width, 1);
    let active = data[cursor] != 0;
    cursor += width as usize;
    assert!(active);

    // Column 4: f64.
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    assert_eq!(width, 8);
    let score = f64::from_be_bytes(eight_bytes(&data, cursor));
    cursor += width as usize;
    assert!((score - 3.14).abs() < 1e-10);

    // Every byte of the buffer must have been consumed.
    assert_eq!(cursor, data.len());
}
/// Hand-decodes a three-column row whose middle column is NULL: a negative
/// length prefix yields `None` and consumes no payload bytes.
#[test]
fn inline_sequential_decode_with_nulls() {
    // Copies the four bytes starting at `at` into a fixed-size array.
    fn four_bytes(buf: &[u8], at: usize) -> [u8; 4] {
        let mut out = [0u8; 4];
        out.copy_from_slice(&buf[at..at + 4]);
        out
    }

    let data = make_data_row(&[Some(&42i32.to_be_bytes()), None, Some(b"text")]);

    // Start just past the 2-byte column count.
    let mut cursor: usize = 2;

    // Column 0: i32.
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let id = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += width as usize;
    assert_eq!(id, 42);

    // Column 1: NULL, so only the length prefix is consumed.
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let name: Option<&str> = if width >= 0 {
        let field_end = cursor + width as usize;
        let s = std::str::from_utf8(&data[cursor..field_end]).unwrap();
        cursor = field_end;
        Some(s)
    } else {
        None
    };
    assert!(name.is_none());

    // Column 2: text.
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let field_end = cursor + width as usize;
    let txt = std::str::from_utf8(&data[cursor..field_end]).unwrap();
    cursor = field_end;
    assert_eq!(txt, "text");

    // Every byte of the buffer must have been consumed.
    assert_eq!(cursor, data.len());
}
/// Hand-decodes one column of each scalar kind (bool, i16, i32, i64, f32,
/// f64) with a moving cursor and checks the cursor ends at the buffer's end.
#[test]
fn inline_sequential_decode_all_scalar_types() {
    // Copies the four bytes starting at `at` into a fixed-size array.
    fn four_bytes(buf: &[u8], at: usize) -> [u8; 4] {
        let mut out = [0u8; 4];
        out.copy_from_slice(&buf[at..at + 4]);
        out
    }
    // Copies the eight bytes starting at `at` into a fixed-size array.
    fn eight_bytes(buf: &[u8], at: usize) -> [u8; 8] {
        let mut out = [0u8; 8];
        out.copy_from_slice(&buf[at..at + 8]);
        out
    }

    let data = make_data_row(&[
        Some(&[1u8]),
        Some(&7i16.to_be_bytes()),
        Some(&42i32.to_be_bytes()),
        Some(&12345i64.to_be_bytes()),
        Some(&2.5f32.to_be_bytes()),
        Some(&3.14f64.to_be_bytes()),
    ]);

    // Start just past the 2-byte column count.
    let mut cursor: usize = 2;

    // bool
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let v_bool = data[cursor] != 0;
    cursor += width as usize;
    assert!(v_bool);

    // i16
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let v_i16 = i16::from_be_bytes([data[cursor], data[cursor + 1]]);
    cursor += width as usize;
    assert_eq!(v_i16, 7);

    // i32
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let v_i32 = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += width as usize;
    assert_eq!(v_i32, 42);

    // i64
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let v_i64 = i64::from_be_bytes(eight_bytes(&data, cursor));
    cursor += width as usize;
    assert_eq!(v_i64, 12345);

    // f32
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let v_f32 = f32::from_be_bytes(four_bytes(&data, cursor));
    cursor += width as usize;
    assert!((v_f32 - 2.5).abs() < 1e-6);

    // f64
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let v_f64 = f64::from_be_bytes(eight_bytes(&data, cursor));
    cursor += width as usize;
    assert!((v_f64 - 3.14).abs() < 1e-10);

    // Every byte of the buffer must have been consumed.
    assert_eq!(cursor, data.len());
}
/// Calls `PgDataRow::new` directly on an encoded row and reads the value
/// back, exercising the constructor by name.
#[test]
fn pg_data_row_new_is_public() {
    let encoded = 42i32.to_be_bytes();
    let wire = make_data_row(&[Some(&encoded[..])]);

    let parsed = PgDataRow::new(&wire).unwrap();
    assert_eq!(parsed.get_i32(0), Some(42));
}
/// Decodes the same row twice, once through the `PgDataRow` getters and once
/// by walking the bytes by hand, and checks both routes agree.
#[test]
fn inline_decode_matches_pgdatarow() {
    // Copies the four bytes starting at `at` into a fixed-size array.
    fn four_bytes(buf: &[u8], at: usize) -> [u8; 4] {
        let mut out = [0u8; 4];
        out.copy_from_slice(&buf[at..at + 4]);
        out
    }
    // Copies the eight bytes starting at `at` into a fixed-size array.
    fn eight_bytes(buf: &[u8], at: usize) -> [u8; 8] {
        let mut out = [0u8; 8];
        out.copy_from_slice(&buf[at..at + 8]);
        out
    }

    let data = make_data_row(&[
        Some(&99i32.to_be_bytes()),
        Some(b"hello world"),
        None,
        Some(&[0u8]),
        Some(&1.23f64.to_be_bytes()),
    ]);

    // Route 1: the PgDataRow accessors.
    let row = PgDataRow::new(&data).unwrap();
    let via_row_i32 = row.get_i32(0);
    let via_row_str = row.get_str(1);
    let via_row_null = row.get_str(2);
    let via_row_bool = row.get_bool(3);
    let via_row_f64 = row.get_f64(4);

    // Route 2: manual walk, starting just past the 2-byte column count.
    let mut cursor: usize = 2;

    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let manual_i32 = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += width as usize;

    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let field_end = cursor + width as usize;
    let manual_str = std::str::from_utf8(&data[cursor..field_end]).unwrap();
    cursor = field_end;

    // The third column was encoded as NULL, so its length must be negative.
    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let manual_null: Option<&str> = match width < 0 {
        true => None,
        false => unreachable!(),
    };

    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let manual_bool = data[cursor] != 0;
    cursor += width as usize;

    let width = i32::from_be_bytes(four_bytes(&data, cursor));
    cursor += 4;
    let manual_f64 = f64::from_be_bytes(eight_bytes(&data, cursor));
    cursor += width as usize;

    // Both routes must agree, and the manual walk must consume every byte.
    assert_eq!(via_row_i32, Some(manual_i32));
    assert_eq!(via_row_str, Some(manual_str));
    assert_eq!(via_row_null, manual_null);
    assert_eq!(via_row_bool, Some(manual_bool));
    assert!((via_row_f64.unwrap() - manual_f64).abs() < 1e-15);
    assert_eq!(cursor, data.len());
}
/// In a row of five NULL columns, every column is reported as null and every
/// getter returns `None`.
#[test]
fn pg_data_row_all_null_columns() {
    let nulls: [Option<&[u8]>; 5] = [None; 5];
    let wire = make_data_row(&nulls);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.column_count(), 5);
    for i in 0..5 {
        assert!(parsed.is_null(i), "column {i} should be null");
        assert_eq!(parsed.get_raw(i), None);
        assert_eq!(parsed.get_i32(i), None);
        assert_eq!(parsed.get_i64(i), None);
        assert_eq!(parsed.get_str(i), None);
        assert_eq!(parsed.get_bool(i), None);
        assert_eq!(parsed.get_f64(i), None);
    }
}
/// A 2048-byte text column is returned whole by `get_str`.
#[test]
fn pg_data_row_very_long_text() {
    let long_text: String = std::iter::repeat('x').take(2048).collect();

    let wire = make_data_row(&[Some(long_text.as_bytes())]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.get_str(0), Some(long_text.as_str()));
}
/// A zero-length column is distinct from NULL: it is not null, reads as the
/// empty string, and reads as an empty byte slice.
#[test]
fn pg_data_row_empty_text() {
    let nothing: &[u8] = b"";
    let wire = make_data_row(&[Some(nothing)]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert!(!parsed.is_null(0));
    assert_eq!(parsed.get_str(0), Some(""));
    assert_eq!(parsed.get_bytes(0), Some(nothing));
}
/// A twenty-column row, where column `k` holds `k` as a big-endian i32,
/// parses and returns each value from its own position.
#[test]
fn pg_data_row_20_columns_exceeds_inline() {
    let encoded: Vec<[u8; 4]> = (0i32..20).map(i32::to_be_bytes).collect();
    let cols: Vec<Option<&[u8]>> = encoded.iter().map(|bytes| Some(&bytes[..])).collect();

    let wire = make_data_row(&cols);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.column_count(), 20);
    // Pair each column index with the i32 value that was stored there.
    for (col, expected) in (0..20).zip(0i32..) {
        assert_eq!(parsed.get_i32(col), Some(expected));
    }
}
/// `is_null` is answered per column: only the middle column of this
/// three-column row is NULL.
#[test]
fn pg_data_row_is_null_each_position() {
    let first = 1i32.to_be_bytes();
    let third = 3i32.to_be_bytes();

    let wire = make_data_row(&[Some(&first[..]), None, Some(&third[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert!(!parsed.is_null(0));
    assert!(parsed.is_null(1));
    assert!(!parsed.is_null(2));
}
/// A header whose column count is `-1` is rejected by `PgDataRow::new`.
#[test]
fn pg_data_row_negative_column_count() {
    let header = i16::to_be_bytes(-1);
    let outcome = PgDataRow::new(&header);
    assert!(outcome.is_err());
}
/// Bytes that are not valid UTF-8 give `None` from `get_str` but are still
/// available unchanged through `get_bytes`.
#[test]
fn pg_data_row_get_str_invalid_utf8() {
    let garbage: [u8; 3] = [0xFF, 0xFE, 0x80];

    let wire = make_data_row(&[Some(&garbage[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.get_str(0), None);
    assert_eq!(parsed.get_bytes(0), Some(&garbage[..]));
}
/// A 2-byte column is not readable as i32 (`None`), but is readable as i16.
#[test]
fn pg_data_row_get_i32_wrong_length() {
    let two_bytes = 7i16.to_be_bytes();
    let wire = make_data_row(&[Some(&two_bytes[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.get_i32(0), None);
    assert_eq!(parsed.get_i16(0), Some(7));
}
/// A 4-byte column is not readable as i64: `get_i64` returns `None`.
#[test]
fn pg_data_row_get_i64_wrong_length() {
    let four_bytes = 42i32.to_be_bytes();
    let wire = make_data_row(&[Some(&four_bytes[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.get_i64(0), None);
}
/// A 4-byte column (an encoded f32) is not readable as f64: `get_f64`
/// returns `None`.
#[test]
fn pg_data_row_get_f64_wrong_length() {
    let four_bytes = 2.5f32.to_be_bytes();
    let wire = make_data_row(&[Some(&four_bytes[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.get_f64(0), None);
}
/// An 8-byte column (an encoded f64) is not readable as f32: `get_f32`
/// returns `None`.
#[test]
fn pg_data_row_get_f32_wrong_length() {
    let eight_bytes = 3.14f64.to_be_bytes();
    let wire = make_data_row(&[Some(&eight_bytes[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.get_f32(0), None);
}
/// A 4-byte column is not readable as bool: `get_bool` returns `None`.
#[test]
fn pg_data_row_get_bool_wrong_length() {
    let four_bytes = 42i32.to_be_bytes();
    let wire = make_data_row(&[Some(&four_bytes[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();

    assert_eq!(parsed.get_bool(0), None);
}
/// Multi-byte UTF-8 text (emoji, CJK, Arabic, a zero-width-joiner sequence)
/// comes back from `get_str` unchanged.
#[test]
fn pg_data_row_unicode_text() {
    let samples: [&str; 4] = [
        "\u{1F600}\u{1F4A9}\u{1F680}",
        "\u{4e16}\u{754c}",
        "\u{0645}\u{0631}\u{062D}",
        "\u{1F468}\u{200D}\u{1F469}",
    ];

    for &sample in samples.iter() {
        let wire = make_data_row(&[Some(sample.as_bytes())]);
        let parsed = PgDataRow::new(&wire).unwrap();
        assert_eq!(parsed.get_str(0), Some(sample));
    }
}
/// `get_i32` returns the stored value at the extremes of the i32 range and
/// around zero.
#[test]
fn pg_data_row_i32_boundary_values() {
    let edge_cases = [i32::MIN, -1, 0, 1, i32::MAX];

    for &val in edge_cases.iter() {
        let encoded = val.to_be_bytes();
        let wire = make_data_row(&[Some(&encoded[..])]);
        let parsed = PgDataRow::new(&wire).unwrap();
        assert_eq!(parsed.get_i32(0), Some(val), "failed for {val}");
    }
}
/// `get_i64` returns the stored value at the extremes of the i64 range and
/// around zero.
#[test]
fn pg_data_row_i64_boundary_values() {
    let edge_cases = [i64::MIN, -1, 0, 1, i64::MAX];

    for &val in edge_cases.iter() {
        let encoded = val.to_be_bytes();
        let wire = make_data_row(&[Some(&encoded[..])]);
        let parsed = PgDataRow::new(&wire).unwrap();
        assert_eq!(parsed.get_i64(0), Some(val), "failed for {val}");
    }
}
/// `get_f64` returns both infinities exactly and returns NaN as a NaN.
#[test]
fn pg_data_row_f64_special_values() {
    // The infinities compare equal to themselves, so check them with assert_eq.
    for &infinite in &[f64::INFINITY, f64::NEG_INFINITY] {
        let encoded = infinite.to_be_bytes();
        let wire = make_data_row(&[Some(&encoded[..])]);
        let parsed = PgDataRow::new(&wire).unwrap();
        assert_eq!(parsed.get_f64(0), Some(infinite));
    }

    // NaN never equals itself, so check it with is_nan instead.
    let encoded_nan = f64::NAN.to_be_bytes();
    let wire = make_data_row(&[Some(&encoded_nan[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();
    assert!(parsed.get_f64(0).unwrap().is_nan());
}
/// `get_f32` returns positive infinity exactly and returns NaN as a NaN.
#[test]
fn pg_data_row_f32_special_values() {
    let encoded_inf = f32::INFINITY.to_be_bytes();
    let wire = make_data_row(&[Some(&encoded_inf[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();
    assert_eq!(parsed.get_f32(0), Some(f32::INFINITY));

    // NaN never equals itself, so check it with is_nan instead.
    let encoded_nan = f32::NAN.to_be_bytes();
    let wire = make_data_row(&[Some(&encoded_nan[..])]);
    let parsed = PgDataRow::new(&wire).unwrap();
    assert!(parsed.get_f32(0).unwrap().is_nan());
}
/// `get_i16` returns the stored value at the extremes of the i16 range and
/// around zero.
#[test]
fn pg_data_row_i16_boundary_values() {
    let edge_cases = [i16::MIN, -1, 0, 1, i16::MAX];

    for &val in edge_cases.iter() {
        let encoded = val.to_be_bytes();
        let wire = make_data_row(&[Some(&encoded[..])]);
        let parsed = PgDataRow::new(&wire).unwrap();
        assert_eq!(parsed.get_i16(0), Some(val));
    }
}
// Property-based smoke tests. Each case feeds generated input to one entry
// point and discards the result; a case fails only if the call panics.
mod proptest_fuzz {
use super::*;
use proptest::prelude::*;
proptest! {
// `Config::from_url` is called with strings generated from the regex `.*`;
// both `Ok` and `Err` results are accepted.
#[test]
fn config_from_url_never_panics(url in ".*") {
let _ = Config::from_url(&url);
}
// `url_decode` is called with strings generated from the regex `.*`.
#[test]
fn url_decode_never_panics(s in ".*") {
let _ = url_decode(&s);
}
// `PgDataRow::new` is called with byte vectors of length 0..8192 (upper
// bound exclusive) filled with arbitrary bytes.
#[test]
fn pg_data_row_new_never_panics(data in proptest::collection::vec(any::<u8>(), 0..8192)) {
let _ = PgDataRow::new(&data);
}
// `hash_sql` is called with strings generated from the regex `.*`.
#[test]
fn hash_sql_never_panics(sql in ".*") {
let _ = hash_sql(&sql);
}
}
}
}