use std::io::{Read, Write};
use std::os::unix::net::UnixStream;
use std::sync::Arc;
use crate::DriverError;
use crate::arena::Arena;
use crate::auth;
use crate::codec::Encode;
use crate::conn::{ColumnDesc, Config, Notification, PgDataRow, QueryResult, SimpleRow};
use crate::proto::{self, BackendMessage};
struct StmtCache {
entries: Vec<(u64, StmtInfo)>,
}
impl Default for StmtCache {
fn default() -> Self {
Self {
entries: Vec::with_capacity(16),
}
}
}
impl StmtCache {
#[inline]
fn get_mut(&mut self, hash: &u64) -> Option<&mut StmtInfo> {
self.entries
.iter_mut()
.find(|(h, _)| h == hash)
.map(|(_, info)| info)
}
#[inline]
fn get(&self, hash: &u64) -> Option<&StmtInfo> {
self.entries
.iter()
.find(|(h, _)| h == hash)
.map(|(_, info)| info)
}
#[inline]
fn contains_key(&self, hash: &u64) -> bool {
self.entries.iter().any(|(h, _)| h == hash)
}
#[inline]
fn insert(&mut self, hash: u64, info: StmtInfo) {
if let Some(entry) = self.entries.iter_mut().find(|(h, _)| *h == hash) {
entry.1 = info;
} else {
self.entries.push((hash, info));
}
}
#[inline]
fn remove(&mut self, hash: &u64) -> Option<StmtInfo> {
if let Some(pos) = self.entries.iter().position(|(h, _)| h == hash) {
Some(self.entries.swap_remove(pos).1)
} else {
None
}
}
#[inline]
fn len(&self) -> usize {
self.entries.len()
}
fn evict_lru(&mut self) -> Option<(u64, StmtInfo)> {
if self.entries.is_empty() {
return None;
}
let min_idx = self
.entries
.iter()
.enumerate()
.min_by_key(|(_, (_, info))| info.last_used)
.map(|(i, _)| i)?;
Some(self.entries.swap_remove(min_idx))
}
}
struct StmtInfo {
name: Box<str>,
columns: Arc<[ColumnDesc]>,
last_used: u64,
bind_template: Option<BindTemplate>,
}
struct BindTemplate {
bytes: Vec<u8>,
param_slots: Vec<(usize, i32)>,
}
#[inline]
fn make_stmt_name(hash: u64) -> Box<str> {
const HEX: &[u8; 16] = b"0123456789abcdef";
let mut buf = [0u8; 18];
buf[0] = b's';
buf[1] = b'_';
let bytes = hash.to_be_bytes();
for (i, &b) in bytes.iter().enumerate() {
buf[2 + i * 2] = HEX[(b >> 4) as usize];
buf[2 + i * 2 + 1] = HEX[(b & 0x0f) as usize];
}
let s = std::str::from_utf8(&buf).expect("BUG: stmt name buffer contains only ASCII hex");
s.into()
}
enum StartupAction {
AuthOk,
AuthCleartext,
AuthMd5([u8; 4]),
AuthSasl(Vec<u8>),
ParameterStatus(Box<str>, Box<str>),
BackendKeyData(i32, i32),
ReadyForQuery(u8),
Error(String),
Notice,
}
pub struct SyncConnection {
stream: UnixStream,
read_buf: Vec<u8>,
stream_buf: Vec<u8>,
stream_buf_pos: usize,
stream_buf_end: usize,
write_buf: Vec<u8>,
stmts: StmtCache,
params: Vec<(Box<str>, Box<str>)>,
pid: i32,
secret: i32,
tx_status: u8,
last_used: std::time::Instant,
created_at: std::time::Instant,
pending_notifications: Vec<Notification>,
max_stmt_cache_size: usize,
query_counter: u64,
}
impl std::fmt::Debug for SyncConnection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SyncConnection")
.field("pid", &self.pid)
.field("tx_status", &(self.tx_status as char))
.field("stmt_cache_len", &self.stmts.len())
.finish()
}
}
impl SyncConnection {
pub fn connect(config: &Config) -> Result<Self, DriverError> {
if !config.host_is_uds() {
return Err(DriverError::Protocol(
"SyncConnection requires a Unix domain socket path (host starting with '/')".into(),
));
}
let path = config.uds_path();
let stream = UnixStream::connect(&path).map_err(DriverError::Io)?;
let mut conn = Self {
stream,
read_buf: Vec::with_capacity(8192),
stream_buf: vec![0u8; 65536],
stream_buf_pos: 0,
stream_buf_end: 0,
write_buf: Vec::with_capacity(4096),
stmts: StmtCache::default(),
params: Vec::new(),
pid: 0,
secret: 0,
tx_status: b'I',
last_used: std::time::Instant::now(),
created_at: std::time::Instant::now(),
pending_notifications: Vec::new(),
max_stmt_cache_size: 256,
query_counter: 0,
};
conn.startup(config)?;
conn.validate_server_params()?;
if config.statement_timeout_secs > 0 {
conn.simple_query(&format!(
"SET statement_timeout = '{}s'",
config.statement_timeout_secs
))?;
}
Ok(conn)
}
fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
self.write_buf.clear();
proto::write_startup(&mut self.write_buf, &config.user, &config.database);
self.flush_write()?;
loop {
let action = self.read_startup_action()?;
match action {
StartupAction::AuthOk => {}
StartupAction::AuthCleartext => {
self.write_buf.clear();
let mut pw = config.password.as_bytes().to_vec();
pw.push(0);
proto::write_password(&mut self.write_buf, &pw);
self.flush_write()?;
}
StartupAction::AuthMd5(salt) => {
self.write_buf.clear();
let hash = auth::md5_password(&config.user, &config.password, &salt);
proto::write_password(&mut self.write_buf, &hash);
self.flush_write()?;
}
StartupAction::AuthSasl(mechanisms_data) => {
self.handle_scram(config, &mechanisms_data)?;
}
StartupAction::ParameterStatus(name, value) => {
if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
entry.1 = value;
} else {
self.params.push((name, value));
}
}
StartupAction::BackendKeyData(pid, secret) => {
self.pid = pid;
self.secret = secret;
}
StartupAction::ReadyForQuery(status) => {
self.tx_status = status;
return Ok(());
}
StartupAction::Error(msg) => {
return Err(DriverError::Auth(msg));
}
StartupAction::Notice => {}
}
}
}
fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
let (msg_type, _) = self.read_message_buffered()?;
let payload = &self.read_buf;
let msg = proto::parse_backend_message(msg_type, payload)?;
match msg {
BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
BackendMessage::AuthSasl { mechanisms } => {
Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
}
BackendMessage::ParameterStatus { name, value } => {
Ok(StartupAction::ParameterStatus(name.into(), value.into()))
}
BackendMessage::BackendKeyData { pid, secret } => {
Ok(StartupAction::BackendKeyData(pid, secret))
}
BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
Ok(StartupAction::Error(fields.to_string()))
}
BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
other => Err(DriverError::Protocol(format!(
"unexpected message during startup: {other:?}"
))),
}
}
fn handle_scram(&mut self, config: &Config, mechanisms_data: &[u8]) -> Result<(), DriverError> {
let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
if !mechs.contains(&"SCRAM-SHA-256") {
return Err(DriverError::Auth(format!(
"server requires unsupported SASL mechanism(s): {mechs:?}"
)));
}
let mut scram = auth::ScramClient::new(&config.user, &config.password)?;
let client_first = scram.client_first_message();
self.write_buf.clear();
proto::write_sasl_initial(&mut self.write_buf, "SCRAM-SHA-256", &client_first);
self.flush_write()?;
let (msg_type, _) = self.read_message_buffered()?;
let server_first = {
let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
match msg {
BackendMessage::AuthSaslContinue { data } => data.to_vec(),
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
return Err(DriverError::Auth(fields.to_string()));
}
other => {
return Err(DriverError::Protocol(format!(
"expected AuthSaslContinue, got: {other:?}"
)));
}
}
};
scram.process_server_first(&server_first)?;
let client_final = scram.client_final_message()?;
self.write_buf.clear();
proto::write_sasl_response(&mut self.write_buf, &client_final);
self.flush_write()?;
let (msg_type, _) = self.read_message_buffered()?;
{
let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
match msg {
BackendMessage::AuthSaslFinal { data } => {
let data_owned = data.to_vec();
scram.verify_server_final(&data_owned)?;
}
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
return Err(DriverError::Auth(fields.to_string()));
}
other => {
return Err(DriverError::Protocol(format!(
"expected AuthSaslFinal, got: {other:?}"
)));
}
}
}
let (msg_type, _) = self.read_message_buffered()?;
let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
match msg {
BackendMessage::AuthOk => Ok(()),
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
Err(DriverError::Auth(fields.to_string()))
}
other => Err(DriverError::Protocol(format!(
"expected AuthOk after SCRAM, got: {other:?}"
))),
}
}
fn validate_server_params(&self) -> Result<(), DriverError> {
if let Some(encoding) = self.parameter("server_encoding") {
let normalized = encoding.to_uppercase();
if normalized != "UTF8" && normalized != "UTF-8" {
return Err(DriverError::Protocol(format!(
"server_encoding is '{encoding}', but bsql requires UTF-8."
)));
}
}
if let Some(encoding) = self.parameter("client_encoding") {
let normalized = encoding.to_uppercase();
if normalized != "UTF8" && normalized != "UTF-8" {
return Err(DriverError::Protocol(format!(
"client_encoding is '{encoding}', but bsql requires UTF-8."
)));
}
}
if let Some(idt) = self.parameter("integer_datetimes") {
if idt != "on" {
return Err(DriverError::Protocol(format!(
"integer_datetimes is '{idt}', but bsql requires 'on'."
)));
}
}
Ok(())
}
pub fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
if self.stmts.contains_key(&sql_hash) {
return Ok(());
}
let name = make_stmt_name(sql_hash);
self.write_buf.clear();
proto::write_parse(&mut self.write_buf, &name, sql, &[]);
proto::write_describe(&mut self.write_buf, b'S', &name);
proto::write_sync(&mut self.write_buf);
self.flush_write()?;
self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
let columns = self.read_column_description()?;
self.expect_ready()?;
self.query_counter += 1;
self.cache_stmt(
sql_hash,
StmtInfo {
name,
columns,
last_used: self.query_counter,
bind_template: None,
},
);
Ok(())
}
#[inline]
pub fn query(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
arena: &mut Arena,
) -> Result<QueryResult, DriverError> {
let columns = self
.send_pipeline(sql, sql_hash, params, true, true)?
.expect("send_pipeline(need_columns=true) must return Some");
let num_cols = columns.len();
let mut all_col_offsets: Vec<(usize, i32)> = Vec::with_capacity(num_cols.max(1) * 8);
let mut affected_rows: u64 = 0;
'outer: loop {
loop {
let avail = self.stream_buf_end - self.stream_buf_pos;
if avail < 5 {
break; }
let msg_type = self.stream_buf[self.stream_buf_pos];
let raw_len = i32::from_be_bytes([
self.stream_buf[self.stream_buf_pos + 1],
self.stream_buf[self.stream_buf_pos + 2],
self.stream_buf[self.stream_buf_pos + 3],
self.stream_buf[self.stream_buf_pos + 4],
]);
if raw_len < 4 {
return Err(DriverError::Protocol(format!(
"invalid message length {raw_len} for type '{}'",
msg_type as char
)));
}
let payload_len = (raw_len - 4) as usize;
let total_msg_len = 5 + payload_len;
if avail < total_msg_len {
if total_msg_len > self.stream_buf.len() {
let msg = self.read_one_message()?;
match msg {
BackendMessage::BindComplete => continue,
BackendMessage::DataRow { data } => {
parse_data_row_flat(data, arena, &mut all_col_offsets)?;
continue;
}
BackendMessage::CommandComplete { tag } => {
affected_rows = proto::parse_command_tag(tag);
continue;
}
BackendMessage::EmptyQuery => continue,
BackendMessage::ReadyForQuery { status } => {
self.tx_status = status;
break 'outer;
}
BackendMessage::NoticeResponse { .. } => continue,
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.maybe_invalidate_stmt_cache(&fields, sql_hash);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
other => {
return Err(DriverError::Protocol(format!(
"unexpected message during query: {other:?}"
)));
}
}
}
break; }
let payload_start = self.stream_buf_pos + 5;
let payload_end = payload_start + payload_len;
if msg_type == b'D' {
parse_data_row_flat(
&self.stream_buf[payload_start..payload_end],
arena,
&mut all_col_offsets,
)?;
} else if msg_type == b'Z' {
if payload_len >= 1 {
self.tx_status = self.stream_buf[payload_start];
}
self.stream_buf_pos += total_msg_len;
break 'outer;
} else {
self.handle_non_datarow_query(
msg_type,
payload_start,
payload_end,
sql_hash,
&mut affected_rows,
)?;
}
self.stream_buf_pos += total_msg_len;
}
self.refill_stream_buf()?;
}
self.shrink_buffers();
Ok(QueryResult::from_parts(
all_col_offsets,
num_cols,
columns,
affected_rows,
))
}
#[inline]
pub fn execute_monolithic(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> Result<u64, DriverError> {
self.write_buf.clear();
let info = match self.stmts.get_mut(&sql_hash) {
Some(info) => {
self.query_counter += 1;
info.last_used = self.query_counter;
info
}
None => {
return self.execute_with_prepare(sql, sql_hash, params);
}
};
let can_use_template = info
.bind_template
.as_ref()
.is_some_and(|t| t.param_slots.len() == params.len());
let mut has_exec_sync = false;
if can_use_template {
let tmpl = info.bind_template.as_ref().unwrap();
self.write_buf.extend_from_slice(&tmpl.bytes);
let mut template_ok = true;
for (i, param) in params.iter().enumerate() {
let (data_offset, old_len) = tmpl.param_slots[i];
if param.is_null() {
let len_offset = data_offset - 4;
self.write_buf[len_offset..len_offset + 4]
.copy_from_slice(&(-1i32).to_be_bytes());
} else if old_len >= 0 {
let end = data_offset + old_len as usize;
if !param.encode_at(&mut self.write_buf[data_offset..end]) {
template_ok = false;
break;
}
} else {
template_ok = false;
break;
}
}
if template_ok {
has_exec_sync = true;
} else {
self.write_buf.clear();
proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
info.bind_template = None;
}
} else {
proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
}
if info.bind_template.is_none() && !self.write_buf.is_empty() {
info.bind_template = build_bind_template(&self.write_buf, params.len());
}
if !has_exec_sync {
self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
}
self.stream
.write_all(&self.write_buf)
.map_err(DriverError::Io)?;
let mut affected_rows: u64 = 0;
'outer: loop {
loop {
let avail = self.stream_buf_end - self.stream_buf_pos;
if avail < 5 {
break; }
let msg_type = self.stream_buf[self.stream_buf_pos];
let raw_len = i32::from_be_bytes([
self.stream_buf[self.stream_buf_pos + 1],
self.stream_buf[self.stream_buf_pos + 2],
self.stream_buf[self.stream_buf_pos + 3],
self.stream_buf[self.stream_buf_pos + 4],
]);
if raw_len < 4 {
return Err(DriverError::Protocol(format!(
"invalid message length {raw_len} for type '{}'",
msg_type as char
)));
}
let payload_len = (raw_len - 4) as usize;
let total_msg_len = 5 + payload_len;
if avail < total_msg_len {
if total_msg_len > self.stream_buf.len() {
let msg = self.read_one_message()?;
match msg {
BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
continue;
}
BackendMessage::CommandComplete { tag } => {
affected_rows = proto::parse_command_tag(tag);
continue;
}
BackendMessage::EmptyQuery => continue,
BackendMessage::ReadyForQuery { status } => {
self.tx_status = status;
break 'outer;
}
BackendMessage::NoticeResponse { .. } => continue,
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.maybe_invalidate_stmt_cache(&fields, sql_hash);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
other => {
return Err(DriverError::Protocol(format!(
"unexpected message during execute: {other:?}"
)));
}
}
}
break; }
let payload_start = self.stream_buf_pos + 5;
let payload_end = payload_start + payload_len;
if msg_type == b'2' {
self.stream_buf_pos += total_msg_len;
continue;
} else if msg_type == b'C' {
affected_rows = proto::parse_command_tag_bytes(
&self.stream_buf[payload_start..payload_end],
);
} else if msg_type == b'Z' {
if payload_len >= 1 {
self.tx_status = self.stream_buf[payload_start];
}
self.stream_buf_pos += total_msg_len;
break 'outer;
} else if msg_type == b'D' || msg_type == b'I' {
} else {
self.handle_non_datarow_execute(
msg_type,
payload_start,
payload_end,
sql_hash,
)?;
}
self.stream_buf_pos += total_msg_len;
}
let remaining = self.stream_buf_end - self.stream_buf_pos;
debug_assert!(
remaining == 0 || self.stream_buf_pos > 0,
"compact called with pos=0 and remaining data"
);
if remaining > 0 {
self.stream_buf
.copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
}
self.stream_buf_pos = 0;
self.stream_buf_end = remaining;
let n = self
.stream
.read(&mut self.stream_buf[remaining..])
.map_err(DriverError::Io)?;
if n == 0 {
return Err(DriverError::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"connection closed",
)));
}
self.stream_buf_end = remaining + n;
}
if self.query_counter & 63 == 0 {
if self.read_buf.capacity() > 64 * 1024 {
self.read_buf.clear();
self.read_buf.shrink_to(8192);
}
if self.write_buf.capacity() > 16 * 1024 {
self.write_buf.clear();
self.write_buf.shrink_to(8192);
}
}
Ok(affected_rows)
}
#[cold]
#[inline(never)]
fn execute_with_prepare(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> Result<u64, DriverError> {
debug_assert_eq!(crate::conn::hash_sql(sql), sql_hash, "sql_hash mismatch");
if params.len() > i16::MAX as usize {
return Err(DriverError::Protocol(format!(
"parameter count {} exceeds maximum {}",
params.len(),
i16::MAX
)));
}
let name = make_stmt_name(sql_hash);
let param_oids: smallvec::SmallVec<[u32; 8]> =
params.iter().map(|p| p.type_oid()).collect();
self.write_buf.clear();
proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
proto::write_describe(&mut self.write_buf, b'S', &name);
proto::write_bind_params(&mut self.write_buf, "", &name, params);
self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
self.stream
.write_all(&self.write_buf)
.map_err(DriverError::Io)?;
self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
let columns = self.read_column_description()?;
self.query_counter += 1;
self.cache_stmt(
sql_hash,
StmtInfo {
name,
columns,
last_used: self.query_counter,
bind_template: None,
},
);
let mut affected_rows: u64 = 0;
'outer: loop {
loop {
let avail = self.stream_buf_end - self.stream_buf_pos;
if avail < 5 {
break;
}
let msg_type = self.stream_buf[self.stream_buf_pos];
let raw_len = i32::from_be_bytes([
self.stream_buf[self.stream_buf_pos + 1],
self.stream_buf[self.stream_buf_pos + 2],
self.stream_buf[self.stream_buf_pos + 3],
self.stream_buf[self.stream_buf_pos + 4],
]);
if raw_len < 4 {
return Err(DriverError::Protocol(format!(
"invalid message length {raw_len} for type '{}'",
msg_type as char
)));
}
let payload_len = (raw_len - 4) as usize;
let total_msg_len = 5 + payload_len;
if avail < total_msg_len {
if total_msg_len > self.stream_buf.len() {
let msg = self.read_one_message()?;
match msg {
BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
continue;
}
BackendMessage::CommandComplete { tag } => {
affected_rows = proto::parse_command_tag(tag);
continue;
}
BackendMessage::EmptyQuery => continue,
BackendMessage::ReadyForQuery { status } => {
self.tx_status = status;
break 'outer;
}
BackendMessage::NoticeResponse { .. } => continue,
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.maybe_invalidate_stmt_cache(&fields, sql_hash);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
other => {
return Err(DriverError::Protocol(format!(
"unexpected message during execute: {other:?}"
)));
}
}
}
break;
}
let payload_start = self.stream_buf_pos + 5;
let payload_end = payload_start + payload_len;
if msg_type == b'2' || msg_type == b'D' || msg_type == b'I' {
} else if msg_type == b'C' {
affected_rows = proto::parse_command_tag_bytes(
&self.stream_buf[payload_start..payload_end],
);
} else if msg_type == b'Z' {
if payload_len >= 1 {
self.tx_status = self.stream_buf[payload_start];
}
self.stream_buf_pos += total_msg_len;
break 'outer;
} else {
self.handle_non_datarow_execute(
msg_type,
payload_start,
payload_end,
sql_hash,
)?;
}
self.stream_buf_pos += total_msg_len;
}
self.refill_stream_buf()?;
}
Ok(affected_rows)
}
#[inline]
pub fn execute(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> Result<u64, DriverError> {
self.execute_monolithic(sql, sql_hash, params)
}
pub fn execute_pipeline(
&mut self,
sql: &str,
sql_hash: u64,
param_sets: &[&[&(dyn Encode + Sync)]],
) -> Result<Vec<u64>, DriverError> {
if param_sets.is_empty() {
return Ok(Vec::new());
}
debug_assert_eq!(crate::conn::hash_sql(sql), sql_hash, "sql_hash mismatch");
self.write_buf.clear();
if !self.stmts.contains_key(&sql_hash) {
let name = make_stmt_name(sql_hash);
let first_params = param_sets[0];
if first_params.len() > i16::MAX as usize {
return Err(DriverError::Protocol(format!(
"parameter count {} exceeds maximum {}",
first_params.len(),
i16::MAX
)));
}
let param_oids: smallvec::SmallVec<[u32; 8]> =
first_params.iter().map(|p| p.type_oid()).collect();
proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
proto::write_describe(&mut self.write_buf, b'S', &name);
proto::write_sync(&mut self.write_buf);
self.flush_write()?;
self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
let columns = self.read_column_description()?;
self.expect_ready()?;
self.query_counter += 1;
self.cache_stmt(
sql_hash,
StmtInfo {
name,
columns,
last_used: self.query_counter,
bind_template: None,
},
);
self.write_buf.clear();
}
let stmt_name = self
.stmts
.get(&sql_hash)
.expect("BUG: stmt just cached but not found")
.name
.clone();
let count = param_sets.len();
for params in param_sets {
if params.len() > i16::MAX as usize {
return Err(DriverError::Protocol(format!(
"parameter count {} exceeds maximum {}",
params.len(),
i16::MAX
)));
}
proto::write_bind_params(&mut self.write_buf, "", &stmt_name, params);
self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
}
self.write_buf.extend_from_slice(proto::SYNC_ONLY);
self.flush_write()?;
let mut results = Vec::with_capacity(count);
for _ in 0..count {
self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
let mut affected_rows: u64 = 0;
loop {
let msg = self.read_one_message()?;
match msg {
BackendMessage::DataRow { .. } => {}
BackendMessage::CommandComplete { tag } => {
affected_rows = proto::parse_command_tag(tag);
break;
}
BackendMessage::EmptyQuery => break,
BackendMessage::NoticeResponse { .. } => {}
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.maybe_invalidate_stmt_cache(&fields, sql_hash);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
other => {
return Err(DriverError::Protocol(format!(
"unexpected message during execute_pipeline: {other:?}"
)));
}
}
}
results.push(affected_rows);
}
self.expect_ready()?;
self.shrink_buffers();
Ok(results)
}
pub(crate) fn ensure_stmt_prepared(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> Result<Box<str>, DriverError> {
if let Some(info) = self.stmts.get(&sql_hash) {
return Ok(info.name.clone());
}
let name = make_stmt_name(sql_hash);
if params.len() > i16::MAX as usize {
return Err(DriverError::Protocol(format!(
"parameter count {} exceeds maximum {}",
params.len(),
i16::MAX
)));
}
let param_oids: smallvec::SmallVec<[u32; 8]> =
params.iter().map(|p| p.type_oid()).collect();
self.write_buf.clear();
proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
proto::write_describe(&mut self.write_buf, b'S', &name);
proto::write_sync(&mut self.write_buf);
self.flush_write()?;
self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
let columns = self.read_column_description()?;
self.expect_ready()?;
self.query_counter += 1;
let stmt_name = name.clone();
self.cache_stmt(
sql_hash,
StmtInfo {
name,
columns,
last_used: self.query_counter,
bind_template: None,
},
);
Ok(stmt_name)
}
pub(crate) fn write_deferred_bind_execute(
&self,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
buf: &mut Vec<u8>,
) {
let stmt_name = &self
.stmts
.get(&sql_hash)
.expect("BUG: stmt just cached but not found")
.name;
proto::write_bind_params(buf, "", stmt_name, params);
buf.extend_from_slice(proto::EXECUTE_ONLY);
}
pub(crate) fn flush_deferred_pipeline(
&mut self,
buf: &mut Vec<u8>,
count: usize,
) -> Result<Vec<u64>, DriverError> {
if count == 0 {
buf.clear();
return Ok(Vec::new());
}
buf.extend_from_slice(proto::SYNC_ONLY);
self.stream.write_all(buf).map_err(DriverError::Io)?;
buf.clear();
let mut results = Vec::with_capacity(count);
for _ in 0..count {
self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
let mut affected_rows: u64 = 0;
loop {
let msg = self.read_one_message()?;
match msg {
BackendMessage::DataRow { .. } => {}
BackendMessage::CommandComplete { tag } => {
affected_rows = proto::parse_command_tag(tag);
break;
}
BackendMessage::EmptyQuery => break,
BackendMessage::NoticeResponse { .. } => {}
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
other => {
return Err(DriverError::Protocol(format!(
"unexpected message during flush_deferred_pipeline: {other:?}"
)));
}
}
}
results.push(affected_rows);
}
self.expect_ready()?;
self.shrink_buffers();
Ok(results)
}
pub fn for_each<F>(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
mut f: F,
) -> Result<(), DriverError>
where
F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
{
let _ = self.send_pipeline(sql, sql_hash, params, false, true)?;
'outer: loop {
loop {
let avail = self.stream_buf_end - self.stream_buf_pos;
if avail < 5 {
break; }
let msg_type = self.stream_buf[self.stream_buf_pos];
let raw_len = i32::from_be_bytes([
self.stream_buf[self.stream_buf_pos + 1],
self.stream_buf[self.stream_buf_pos + 2],
self.stream_buf[self.stream_buf_pos + 3],
self.stream_buf[self.stream_buf_pos + 4],
]);
if raw_len < 4 {
return Err(DriverError::Protocol(format!(
"invalid message length {raw_len} for type '{}'",
msg_type as char
)));
}
let payload_len = (raw_len - 4) as usize;
let total_msg_len = 5 + payload_len;
if avail < total_msg_len {
if total_msg_len > self.stream_buf.len() {
let msg = self.read_one_message()?;
match msg {
BackendMessage::BindComplete => continue,
BackendMessage::DataRow { data } => {
let row = PgDataRow::new(data)?;
f(row)?;
continue;
}
BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
continue;
}
BackendMessage::ReadyForQuery { status } => {
self.tx_status = status;
break 'outer;
}
BackendMessage::NoticeResponse { .. } => continue,
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.maybe_invalidate_stmt_cache(&fields, sql_hash);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
other => {
return Err(DriverError::Protocol(format!(
"unexpected message during for_each: {other:?}"
)));
}
}
}
break; }
let payload_start = self.stream_buf_pos + 5;
let payload_end = payload_start + payload_len;
if msg_type == b'D' {
let row = PgDataRow::new(&self.stream_buf[payload_start..payload_end])?;
f(row)?;
} else if msg_type == b'Z' {
if payload_len >= 1 {
self.tx_status = self.stream_buf[payload_start];
}
self.stream_buf_pos += total_msg_len;
break 'outer;
} else {
self.handle_non_datarow_execute(
msg_type,
payload_start,
payload_end,
sql_hash,
)?;
}
self.stream_buf_pos += total_msg_len;
}
self.refill_stream_buf()?;
}
self.shrink_buffers();
Ok(())
}
#[inline]
pub fn for_each_raw_monolithic<F>(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
mut f: F,
) -> Result<(), DriverError>
where
F: FnMut(&[u8]) -> Result<(), DriverError>,
{
self.write_buf.clear();
let info = match self.stmts.get_mut(&sql_hash) {
Some(info) => {
self.query_counter += 1;
info.last_used = self.query_counter;
info
}
None => {
return self.for_each_raw_with_prepare(sql, sql_hash, params, f);
}
};
let can_use_template = info
.bind_template
.as_ref()
.is_some_and(|t| t.param_slots.len() == params.len());
let mut has_exec_sync = false;
if can_use_template {
let tmpl = info.bind_template.as_ref().unwrap();
self.write_buf.extend_from_slice(&tmpl.bytes);
let mut template_ok = true;
for (i, param) in params.iter().enumerate() {
let (data_offset, old_len) = tmpl.param_slots[i];
if param.is_null() {
let len_offset = data_offset - 4;
self.write_buf[len_offset..len_offset + 4]
.copy_from_slice(&(-1i32).to_be_bytes());
} else if old_len >= 0 {
let end = data_offset + old_len as usize;
if !param.encode_at(&mut self.write_buf[data_offset..end]) {
template_ok = false;
break;
}
} else {
template_ok = false;
break;
}
}
if template_ok {
has_exec_sync = true;
} else {
self.write_buf.clear();
proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
info.bind_template = None;
}
} else {
proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
}
if info.bind_template.is_none() && !self.write_buf.is_empty() {
info.bind_template = build_bind_template(&self.write_buf, params.len());
}
if !has_exec_sync {
self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
}
self.stream
.write_all(&self.write_buf)
.map_err(DriverError::Io)?;
loop {
let avail = self.stream_buf_end - self.stream_buf_pos;
if avail >= 5 {
let bc_type = self.stream_buf[self.stream_buf_pos];
match bc_type {
b'2' => {
self.stream_buf_pos += 5;
break;
}
b'E' => {
let msg = self.read_one_message()?;
if let BackendMessage::ErrorResponse { data } = msg {
let fields = proto::parse_error_response(data);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
}
b'N' | b'S' => {
let raw_len = i32::from_be_bytes([
self.stream_buf[self.stream_buf_pos + 1],
self.stream_buf[self.stream_buf_pos + 2],
self.stream_buf[self.stream_buf_pos + 3],
self.stream_buf[self.stream_buf_pos + 4],
]);
let total = 1 + raw_len as usize;
if avail >= total {
self.stream_buf_pos += total;
continue;
}
self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
break;
}
_ => {
self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
break;
}
}
} else {
let remaining = self.stream_buf_end - self.stream_buf_pos;
if remaining > 0 && self.stream_buf_pos > 0 {
self.stream_buf
.copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
}
self.stream_buf_pos = 0;
self.stream_buf_end = remaining;
let n = self
.stream
.read(&mut self.stream_buf[remaining..])
.map_err(DriverError::Io)?;
if n == 0 {
return Err(DriverError::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"connection closed",
)));
}
self.stream_buf_end = remaining + n;
}
}
'outer: loop {
loop {
let avail = self.stream_buf_end - self.stream_buf_pos;
if avail < 5 {
break;
}
let msg_type = self.stream_buf[self.stream_buf_pos];
let raw_len = i32::from_be_bytes([
self.stream_buf[self.stream_buf_pos + 1],
self.stream_buf[self.stream_buf_pos + 2],
self.stream_buf[self.stream_buf_pos + 3],
self.stream_buf[self.stream_buf_pos + 4],
]);
if raw_len < 4 {
return Err(DriverError::Protocol(format!(
"invalid message length {raw_len} for type '{}'",
msg_type as char
)));
}
let payload_len = (raw_len - 4) as usize;
let total_msg_len = 5 + payload_len;
if avail < total_msg_len {
if total_msg_len > self.stream_buf.len() {
let msg = self.read_one_message()?;
match msg {
BackendMessage::DataRow { data } => {
f(data)?;
continue;
}
BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
continue;
}
BackendMessage::ReadyForQuery { status } => {
self.tx_status = status;
break 'outer;
}
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.maybe_invalidate_stmt_cache(&fields, sql_hash);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
BackendMessage::NoticeResponse { .. } => continue,
other => {
return Err(DriverError::Protocol(format!(
"unexpected message during for_each_raw: {other:?}"
)));
}
}
}
break; }
let payload_start = self.stream_buf_pos + 5;
let payload_end = payload_start + payload_len;
if msg_type == b'D' {
f(&self.stream_buf[payload_start..payload_end])?;
} else if msg_type == b'Z' {
if payload_len >= 1 {
self.tx_status = self.stream_buf[payload_start];
}
self.stream_buf_pos += total_msg_len;
break 'outer;
} else {
self.handle_non_datarow_execute(
msg_type,
payload_start,
payload_end,
sql_hash,
)?;
}
self.stream_buf_pos += total_msg_len;
}
let remaining = self.stream_buf_end - self.stream_buf_pos;
if remaining > 0 && self.stream_buf_pos > 0 {
self.stream_buf
.copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
}
self.stream_buf_pos = 0;
self.stream_buf_end = remaining;
let n = self
.stream
.read(&mut self.stream_buf[remaining..])
.map_err(DriverError::Io)?;
if n == 0 {
return Err(DriverError::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"connection closed",
)));
}
self.stream_buf_end = remaining + n;
}
if self.query_counter & 63 == 0 {
if self.read_buf.capacity() > 64 * 1024 {
self.read_buf.clear();
self.read_buf.shrink_to(8192);
}
if self.write_buf.capacity() > 16 * 1024 {
self.write_buf.clear();
self.write_buf.shrink_to(8192);
}
}
Ok(())
}
#[cold]
#[inline(never)]
fn for_each_raw_with_prepare<F>(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
mut f: F,
) -> Result<(), DriverError>
where
F: FnMut(&[u8]) -> Result<(), DriverError>,
{
debug_assert_eq!(crate::conn::hash_sql(sql), sql_hash, "sql_hash mismatch");
if params.len() > i16::MAX as usize {
return Err(DriverError::Protocol(format!(
"parameter count {} exceeds maximum {}",
params.len(),
i16::MAX
)));
}
let name = make_stmt_name(sql_hash);
let param_oids: smallvec::SmallVec<[u32; 8]> =
params.iter().map(|p| p.type_oid()).collect();
self.write_buf.clear();
proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
proto::write_describe(&mut self.write_buf, b'S', &name);
proto::write_bind_params(&mut self.write_buf, "", &name, params);
self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
self.stream
.write_all(&self.write_buf)
.map_err(DriverError::Io)?;
self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
let columns = self.read_column_description()?;
self.query_counter += 1;
self.cache_stmt(
sql_hash,
StmtInfo {
name,
columns,
last_used: self.query_counter,
bind_template: None,
},
);
self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
'outer: loop {
loop {
let avail = self.stream_buf_end - self.stream_buf_pos;
if avail < 5 {
break;
}
let msg_type = self.stream_buf[self.stream_buf_pos];
let raw_len = i32::from_be_bytes([
self.stream_buf[self.stream_buf_pos + 1],
self.stream_buf[self.stream_buf_pos + 2],
self.stream_buf[self.stream_buf_pos + 3],
self.stream_buf[self.stream_buf_pos + 4],
]);
if raw_len < 4 {
return Err(DriverError::Protocol(format!(
"invalid message length {raw_len} for type '{}'",
msg_type as char
)));
}
let payload_len = (raw_len - 4) as usize;
let total_msg_len = 5 + payload_len;
if avail < total_msg_len {
if total_msg_len > self.stream_buf.len() {
let msg = self.read_one_message()?;
match msg {
BackendMessage::DataRow { data } => {
f(data)?;
continue;
}
BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
continue;
}
BackendMessage::ReadyForQuery { status } => {
self.tx_status = status;
break 'outer;
}
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.maybe_invalidate_stmt_cache(&fields, sql_hash);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
BackendMessage::NoticeResponse { .. } => continue,
other => {
return Err(DriverError::Protocol(format!(
"unexpected message during for_each_raw: {other:?}"
)));
}
}
}
break;
}
let payload_start = self.stream_buf_pos + 5;
let payload_end = payload_start + payload_len;
if msg_type == b'D' {
f(&self.stream_buf[payload_start..payload_end])?;
} else if msg_type == b'Z' {
if payload_len >= 1 {
self.tx_status = self.stream_buf[payload_start];
}
self.stream_buf_pos += total_msg_len;
break 'outer;
} else {
self.handle_non_datarow_execute(
msg_type,
payload_start,
payload_end,
sql_hash,
)?;
}
self.stream_buf_pos += total_msg_len;
}
self.refill_stream_buf()?;
}
self.shrink_buffers();
Ok(())
}
#[inline]
pub fn for_each_raw<F>(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
f: F,
) -> Result<(), DriverError>
where
F: FnMut(&[u8]) -> Result<(), DriverError>,
{
self.for_each_raw_monolithic(sql, sql_hash, params, f)
}
pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
self.write_buf.clear();
proto::write_simple_query(&mut self.write_buf, sql);
self.flush_write()?;
loop {
let msg = self.read_one_message()?;
match msg {
BackendMessage::ReadyForQuery { status } => {
self.tx_status = status;
return Ok(());
}
BackendMessage::CommandComplete { .. }
| BackendMessage::RowDescription { .. }
| BackendMessage::DataRow { .. }
| BackendMessage::EmptyQuery
| BackendMessage::NoticeResponse { .. }
| BackendMessage::ParameterStatus { .. } => {}
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
other => {
return Err(DriverError::Protocol(format!(
"unexpected message during simple_query: {other:?}"
)));
}
}
}
}
pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
self.write_buf.clear();
proto::write_simple_query(&mut self.write_buf, sql);
self.flush_write()?;
let mut rows: Vec<SimpleRow> = Vec::new();
loop {
let msg = self.read_one_message()?;
match msg {
BackendMessage::ReadyForQuery { status } => {
self.tx_status = status;
return Ok(rows);
}
BackendMessage::DataRow { data } => {
rows.push(proto::parse_simple_data_row(data)?);
}
BackendMessage::RowDescription { .. }
| BackendMessage::CommandComplete { .. }
| BackendMessage::EmptyQuery
| BackendMessage::NoticeResponse { .. }
| BackendMessage::ParameterStatus { .. } => {}
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
other => {
return Err(DriverError::Protocol(format!(
"unexpected message during simple_query_rows: {other:?}"
)));
}
}
}
}
pub fn close(mut self) -> Result<(), DriverError> {
self.write_buf.clear();
proto::write_terminate(&mut self.write_buf);
let _ = self.flush_write();
Ok(())
}
pub fn is_idle(&self) -> bool {
self.tx_status == b'I'
}
pub fn is_in_transaction(&self) -> bool {
self.tx_status == b'T'
}
pub fn is_in_failed_transaction(&self) -> bool {
self.tx_status == b'E'
}
pub fn touch(&mut self) {
self.last_used = std::time::Instant::now();
}
pub fn idle_duration(&self) -> std::time::Duration {
self.last_used.elapsed()
}
pub fn parameter(&self, name: &str) -> Option<&str> {
self.params
.iter()
.find(|(k, _)| &**k == name)
.map(|(_, v)| &**v)
}
pub fn pid(&self) -> i32 {
self.pid
}
pub fn secret_key(&self) -> i32 {
self.secret
}
pub fn drain_notifications(&mut self) -> Vec<Notification> {
std::mem::take(&mut self.pending_notifications)
}
pub fn pending_notification_count(&self) -> usize {
self.pending_notifications.len()
}
pub fn set_max_stmt_cache_size(&mut self, size: usize) {
self.max_stmt_cache_size = size;
}
pub fn stmt_cache_len(&self) -> usize {
self.stmts.len()
}
pub fn created_at(&self) -> std::time::Instant {
self.created_at
}
#[inline]
fn send_pipeline(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
need_columns: bool,
skip_bind_complete: bool,
) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
debug_assert_eq!(crate::conn::hash_sql(sql), sql_hash, "sql_hash mismatch");
if params.len() > i16::MAX as usize {
return Err(DriverError::Protocol(format!(
"parameter count {} exceeds maximum {}",
params.len(),
i16::MAX
)));
}
self.write_buf.clear();
let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
self.query_counter += 1;
info.last_used = self.query_counter;
let can_use_template = info
.bind_template
.as_ref()
.is_some_and(|t| t.param_slots.len() == params.len());
let mut has_exec_sync = false;
if can_use_template {
let tmpl = info.bind_template.as_ref().unwrap();
self.write_buf.extend_from_slice(&tmpl.bytes);
let mut template_ok = true;
for (i, param) in params.iter().enumerate() {
let (data_offset, old_len) = tmpl.param_slots[i];
if param.is_null() {
let len_offset = data_offset - 4;
self.write_buf[len_offset..len_offset + 4]
.copy_from_slice(&(-1i32).to_be_bytes());
} else if old_len >= 0 {
let end = data_offset + old_len as usize;
if !param.encode_at(&mut self.write_buf[data_offset..end]) {
template_ok = false;
break;
}
} else {
template_ok = false;
break;
}
}
if template_ok {
has_exec_sync = true; } else {
self.write_buf.clear();
proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
info.bind_template = None;
}
} else {
proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
}
let cols = if need_columns {
Some(info.columns.clone())
} else {
None
};
if info.bind_template.is_none() && !self.write_buf.is_empty() {
info.bind_template = build_bind_template(&self.write_buf, params.len());
}
if !has_exec_sync {
self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
}
self.flush_write()?;
cols
} else {
let name = make_stmt_name(sql_hash);
let param_oids: smallvec::SmallVec<[u32; 8]> =
params.iter().map(|p| p.type_oid()).collect();
proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
proto::write_describe(&mut self.write_buf, b'S', &name);
proto::write_bind_params(&mut self.write_buf, "", &name, params);
self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
self.flush_write()?;
self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
let columns = self.read_column_description()?;
self.query_counter += 1;
self.cache_stmt(
sql_hash,
StmtInfo {
name,
columns: columns.clone(),
last_used: self.query_counter,
bind_template: None,
},
);
if need_columns { Some(columns) } else { None }
};
if !skip_bind_complete {
self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
}
Ok(columns)
}
fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
loop {
let msg = self.read_one_message()?;
match msg {
BackendMessage::RowDescription { data } => {
let cols = proto::parse_row_description(data)?;
return Ok(cols.into());
}
BackendMessage::ParameterDescription { .. } => {}
BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
BackendMessage::NoticeResponse { .. } => {}
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
other => {
return Err(DriverError::Protocol(format!(
"expected RowDescription/NoData, got: {other:?}"
)));
}
}
}
}
fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
if self.stmts.len() >= self.max_stmt_cache_size && !self.stmts.contains_key(&sql_hash) {
if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
proto::write_close(&mut self.write_buf, b'S', &evicted.name);
}
}
self.stmts.insert(sql_hash, info);
}
fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
if self.pending_notifications.len() < 1024 {
self.pending_notifications.push(Notification {
pid,
channel: channel.to_owned(),
payload: payload.to_owned(),
});
}
}
fn shrink_buffers(&mut self) {
if self.query_counter & 63 != 0 {
return;
}
if self.read_buf.capacity() > 64 * 1024 {
self.read_buf.clear();
self.read_buf.shrink_to(8192);
}
if self.write_buf.capacity() > 16 * 1024 {
self.write_buf.clear();
self.write_buf.shrink_to(8192);
}
}
fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
if &*fields.code == "26000" {
self.stmts.remove(&sql_hash);
true
} else {
false
}
}
#[cold]
#[inline(never)]
fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
DriverError::Server {
code: fields.code,
message: fields.message.into_boxed_str(),
detail: fields.detail.map(String::into_boxed_str),
hint: fields.hint.map(String::into_boxed_str),
position: fields.position,
}
}
#[cold]
#[inline(never)]
fn handle_non_datarow_query(
&mut self,
msg_type: u8,
payload_start: usize,
payload_end: usize,
sql_hash: u64,
affected_rows: &mut u64,
) -> Result<(), DriverError> {
match msg_type {
b'2' | b'I' => {} b'C' => {
*affected_rows =
proto::parse_command_tag_bytes(&self.stream_buf[payload_start..payload_end]);
}
b'E' => {
let fields =
proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
self.maybe_invalidate_stmt_cache(&fields, sql_hash);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
b'A' => {
let msg = proto::parse_backend_message(
msg_type,
&self.stream_buf[payload_start..payload_end],
)?;
if let BackendMessage::NotificationResponse {
pid,
channel,
payload,
} = msg
{
let ch = channel.to_owned();
let pl = payload.to_owned();
self.buffer_notification(pid, &ch, &pl);
}
}
_ => {} }
Ok(())
}
#[cold]
#[inline(never)]
fn handle_non_datarow_execute(
&mut self,
msg_type: u8,
payload_start: usize,
payload_end: usize,
sql_hash: u64,
) -> Result<(), DriverError> {
match msg_type {
b'2' | b'C' | b'I' => {} b'E' => {
let fields =
proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
self.maybe_invalidate_stmt_cache(&fields, sql_hash);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
b'A' => {
let msg = proto::parse_backend_message(
msg_type,
&self.stream_buf[payload_start..payload_end],
)?;
if let BackendMessage::NotificationResponse {
pid,
channel,
payload,
} = msg
{
let ch = channel.to_owned();
let pl = payload.to_owned();
self.buffer_notification(pid, &ch, &pl);
}
}
_ => {} }
Ok(())
}
#[inline]
fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
loop {
let (msg_type, _payload_len) = self.read_message_buffered()?;
if msg_type == b'A' {
let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
if let BackendMessage::NotificationResponse {
pid,
channel,
payload,
} = msg
{
let pid_owned = pid;
let channel_owned = channel.to_owned();
let payload_owned = payload.to_owned();
self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
continue;
}
}
return proto::parse_backend_message(msg_type, &self.read_buf);
}
}
fn expect_message(
&mut self,
pred: impl Fn(&BackendMessage<'_>) -> bool,
) -> Result<(), DriverError> {
loop {
let msg = self.read_one_message()?;
if pred(&msg) {
return Ok(());
}
match msg {
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
other => {
return Err(DriverError::Protocol(format!(
"unexpected message while waiting for expected type: {other:?}"
)));
}
}
}
}
fn expect_ready(&mut self) -> Result<(), DriverError> {
loop {
let msg = self.read_one_message()?;
match msg {
BackendMessage::ReadyForQuery { status } => {
self.tx_status = status;
return Ok(());
}
BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
BackendMessage::ErrorResponse { data } => {
let fields = proto::parse_error_response(data);
self.drain_to_ready()?;
return Err(self.make_server_error(fields));
}
_ => {}
}
}
}
#[inline]
fn drain_to_ready(&mut self) -> Result<(), DriverError> {
loop {
let msg = self.read_one_message()?;
if let BackendMessage::ReadyForQuery { status } = msg {
self.tx_status = status;
return Ok(());
}
}
}
#[inline]
fn flush_write(&mut self) -> Result<(), DriverError> {
self.stream
.write_all(&self.write_buf)
.map_err(DriverError::Io)
}
fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
let mut header = [0u8; 5];
sync_buffered_read_exact(
&mut self.stream,
&mut self.stream_buf,
&mut self.stream_buf_pos,
&mut self.stream_buf_end,
&mut header,
)?;
let msg_type = header[0];
let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
if len < 4 {
return Err(DriverError::Protocol(format!(
"invalid message length {len} for type '{}'",
msg_type as char
)));
}
const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
if len > MAX_MESSAGE_LEN {
return Err(DriverError::Protocol(format!(
"message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
msg_type as char
)));
}
let payload_len = (len - 4) as usize;
self.read_buf.clear();
self.read_buf.resize(payload_len, 0);
if payload_len > 0 {
sync_buffered_read_exact(
&mut self.stream,
&mut self.stream_buf,
&mut self.stream_buf_pos,
&mut self.stream_buf_end,
&mut self.read_buf[..payload_len],
)?;
}
Ok((msg_type, payload_len))
}
#[inline]
fn refill_stream_buf(&mut self) -> Result<(), DriverError> {
let remaining = self.stream_buf_end - self.stream_buf_pos;
debug_assert!(
remaining == 0 || self.stream_buf_pos > 0,
"refill_stream_buf called with pos=0 and remaining data"
);
if remaining > 0 {
self.stream_buf
.copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
}
self.stream_buf_pos = 0;
self.stream_buf_end = remaining;
let n = self
.stream
.read(&mut self.stream_buf[remaining..])
.map_err(DriverError::Io)?;
if n == 0 {
return Err(DriverError::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"connection closed",
)));
}
self.stream_buf_end = remaining + n;
Ok(())
}
}
fn sync_buffered_read_exact(
stream: &mut UnixStream,
buf: &mut [u8],
pos: &mut usize,
end: &mut usize,
out: &mut [u8],
) -> Result<(), DriverError> {
let mut filled = 0;
while filled < out.len() {
let avail = *end - *pos;
if avail > 0 {
let take = avail.min(out.len() - filled);
out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
*pos += take;
filled += take;
} else {
*pos = 0;
let n = stream.read(buf).map_err(DriverError::Io)?;
if n == 0 {
return Err(DriverError::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"connection closed",
)));
}
*end = n;
}
}
Ok(())
}
fn parse_data_row_flat(
data: &[u8],
arena: &mut Arena,
out: &mut Vec<(usize, i32)>,
) -> Result<(), DriverError> {
if data.len() < 2 {
return Err(DriverError::Protocol("DataRow too short".into()));
}
let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
if num_cols_raw < 0 {
return Err(DriverError::Protocol(
"DataRow: negative column count".into(),
));
}
let num_cols = num_cols_raw as usize;
out.reserve(num_cols);
let mut pos = 2;
for _ in 0..num_cols {
if pos + 4 > data.len() {
return Err(DriverError::Protocol("DataRow truncated".into()));
}
let col_len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
pos += 4;
if col_len < 0 {
out.push((0, -1));
} else {
let len = col_len as usize;
if pos + len > data.len() {
return Err(DriverError::Protocol(
"DataRow column data truncated".into(),
));
}
let offset = arena.alloc_copy(&data[pos..pos + len]);
out.push((offset, col_len));
pos += len;
}
}
Ok(())
}
fn build_bind_template(write_buf: &[u8], param_count: usize) -> Option<BindTemplate> {
if write_buf.is_empty() || write_buf[0] != b'B' {
return None;
}
if write_buf.len() < 5 {
return None;
}
let mut pos = 5;
while pos < write_buf.len() && write_buf[pos] != 0 {
pos += 1;
}
pos += 1;
while pos < write_buf.len() && write_buf[pos] != 0 {
pos += 1;
}
pos += 1;
if pos + 2 > write_buf.len() {
return None;
}
let num_fmt_codes = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]);
pos += 2;
pos += num_fmt_codes.max(0) as usize * 2;
if pos + 2 > write_buf.len() {
return None;
}
let wire_param_count = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]) as usize;
pos += 2;
if wire_param_count != param_count {
return None;
}
let mut param_slots = Vec::with_capacity(param_count);
for _ in 0..param_count {
if pos + 4 > write_buf.len() {
return None;
}
let data_len = i32::from_be_bytes([
write_buf[pos],
write_buf[pos + 1],
write_buf[pos + 2],
write_buf[pos + 3],
]);
pos += 4;
if data_len < 0 {
param_slots.push((pos, -1));
} else {
let data_offset = pos;
param_slots.push((data_offset, data_len));
pos += data_len as usize;
}
}
let mut bytes = Vec::with_capacity(write_buf.len() + proto::EXECUTE_SYNC.len());
bytes.extend_from_slice(write_buf);
bytes.extend_from_slice(proto::EXECUTE_SYNC);
Some(BindTemplate { bytes, param_slots })
}
#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
use super::*;
use crate::conn::hash_sql;
#[test]
fn sync_make_stmt_name() {
let name = make_stmt_name(0);
assert_eq!(&*name, "s_0000000000000000");
let name = make_stmt_name(0xDEADBEEF12345678);
assert_eq!(&*name, "s_deadbeef12345678");
}
#[test]
fn sync_stmt_cache_basic() {
let cache = StmtCache::default();
assert_eq!(cache.len(), 0);
assert!(!cache.contains_key(&42));
}
#[test]
fn sync_config_rejects_tcp() {
let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
let result = SyncConnection::connect(&config);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("Unix domain socket"),
"error should mention UDS requirement: {err}"
);
}
#[test]
fn sync_data_row_parsing() {
let mut arena = Arena::new();
let mut out = Vec::new();
let mut data = Vec::new();
data.extend_from_slice(&2i16.to_be_bytes());
data.extend_from_slice(&4i32.to_be_bytes());
data.extend_from_slice(&42i32.to_be_bytes());
data.extend_from_slice(&(-1i32).to_be_bytes());
parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
assert_eq!(out.len(), 2);
assert_eq!(out[0].1, 4);
assert_eq!(out[1].1, -1);
}
#[test]
fn sync_data_row_empty() {
let mut arena = Arena::new();
let mut out = Vec::new();
let data = 0i16.to_be_bytes();
parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
assert_eq!(out.len(), 0);
}
#[test]
fn sync_data_row_too_short() {
let mut arena = Arena::new();
let mut out = Vec::new();
let data = vec![0u8];
assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
}
#[test]
fn sync_data_row_negative_col_count() {
let mut arena = Arena::new();
let mut out = Vec::new();
let data = (-1i16).to_be_bytes();
assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
}
#[test]
fn sync_data_row_truncated() {
let mut arena = Arena::new();
let mut out = Vec::new();
let mut data = Vec::new();
data.extend_from_slice(&2i16.to_be_bytes());
data.extend_from_slice(&4i32.to_be_bytes());
data.extend_from_slice(&42i32.to_be_bytes());
assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
}
#[test]
fn sync_data_row_col_data_truncated() {
let mut arena = Arena::new();
let mut out = Vec::new();
let mut data = Vec::new();
data.extend_from_slice(&1i16.to_be_bytes());
data.extend_from_slice(&100i32.to_be_bytes()); data.push(0); assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
}
#[test]
fn build_bind_template_basic() {
let mut buf = Vec::new();
let val: i32 = 42;
proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 1);
assert_eq!(tmpl.param_slots[0].1, 4);
}
#[test]
fn build_bind_template_null_param() {
let mut buf = Vec::new();
let val: Option<i32> = None;
proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 1);
assert_eq!(tmpl.param_slots[0].1, -1); }
#[test]
fn build_bind_template_multiple_params() {
let mut buf = Vec::new();
let id: i32 = 1;
let name: &str = "alice";
proto::write_bind_params(
&mut buf,
"",
"s_test",
&[&id as &(dyn Encode + Sync), &name as &(dyn Encode + Sync)],
);
let tmpl = build_bind_template(&buf, 2);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 2);
assert_eq!(tmpl.param_slots[0].1, 4); assert_eq!(tmpl.param_slots[1].1, 5); }
#[test]
fn build_bind_template_empty_buf() {
let tmpl = build_bind_template(&[], 0);
assert!(tmpl.is_none());
}
#[test]
fn build_bind_template_wrong_type() {
let tmpl = build_bind_template(&[b'E', 0, 0, 0, 4], 0);
assert!(tmpl.is_none());
}
#[test]
fn build_bind_template_param_count_mismatch() {
let mut buf = Vec::new();
let val: i32 = 42;
proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 2);
assert!(tmpl.is_none());
}
#[test]
fn hash_sql_consistency() {
let h = hash_sql("SELECT 1");
assert_eq!(h, hash_sql("SELECT 1"));
assert_ne!(h, hash_sql("SELECT 2"));
}
#[test]
fn sync_connect_tcp_fails_with_uds_message() {
let config = Config::from_url("postgres://user:pass@localhost:5432/db").unwrap();
let result = SyncConnection::connect(&config);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("Unix domain socket"),
"error should mention UDS: {err}"
);
}
#[test]
fn sync_connect_ip_address_fails() {
let config = Config::from_url("postgres://user:pass@127.0.0.1:5432/db").unwrap();
let result = SyncConnection::connect(&config);
assert!(result.is_err());
}
#[test]
fn sync_make_stmt_name_max() {
let name = make_stmt_name(u64::MAX);
assert_eq!(&*name, "s_ffffffffffffffff");
}
#[test]
fn sync_make_stmt_name_one() {
let name = make_stmt_name(1);
assert_eq!(&*name, "s_0000000000000001");
}
#[test]
fn sync_make_stmt_name_powers_of_two() {
let name = make_stmt_name(256);
assert_eq!(&*name, "s_0000000000000100");
}
#[test]
fn sync_make_stmt_name_format_always_18_chars() {
for val in [0u64, 1, 0xFF, 0xFFFF, 0xFFFF_FFFF, u64::MAX] {
let name = make_stmt_name(val);
assert_eq!(name.len(), 18, "name len for {val:x}");
assert!(name.starts_with("s_"));
assert!(name[2..].chars().all(|c| c.is_ascii_hexdigit()));
}
}
#[test]
fn sync_stmt_cache_insert_get_remove() {
let mut cache = StmtCache::default();
let info = StmtInfo {
name: "s_test".into(),
columns: Arc::from(Vec::new()),
last_used: 1,
bind_template: None,
};
cache.insert(42, info);
assert_eq!(cache.len(), 1);
assert!(cache.contains_key(&42));
assert!(cache.get(&42).is_some());
assert!(cache.get_mut(&42).is_some());
let removed = cache.remove(&42);
assert!(removed.is_some());
assert_eq!(cache.len(), 0);
assert!(!cache.contains_key(&42));
}
#[test]
fn sync_stmt_cache_evict_lru() {
let mut cache = StmtCache::default();
for i in 0..3u64 {
cache.insert(
i,
StmtInfo {
name: format!("s_{i}").into(),
columns: Arc::from(Vec::new()),
last_used: i + 1,
bind_template: None,
},
);
}
assert_eq!(cache.len(), 3);
let evicted = cache.evict_lru().unwrap();
assert_eq!(evicted.0, 0); assert_eq!(cache.len(), 2);
}
#[test]
fn sync_stmt_cache_insert_overwrite() {
let mut cache = StmtCache::default();
let info1 = StmtInfo {
name: "s_a".into(),
columns: Arc::from(Vec::new()),
last_used: 1,
bind_template: None,
};
let info2 = StmtInfo {
name: "s_b".into(),
columns: Arc::from(Vec::new()),
last_used: 2,
bind_template: None,
};
cache.insert(42, info1);
cache.insert(42, info2);
assert_eq!(cache.len(), 1);
assert_eq!(&*cache.get(&42).unwrap().name, "s_b");
}
#[test]
fn sync_data_row_all_null() {
let mut arena = Arena::new();
let mut out = Vec::new();
let mut data = Vec::new();
data.extend_from_slice(&3i16.to_be_bytes());
data.extend_from_slice(&(-1i32).to_be_bytes());
data.extend_from_slice(&(-1i32).to_be_bytes());
data.extend_from_slice(&(-1i32).to_be_bytes());
parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
assert_eq!(out.len(), 3);
for (_, len) in &out {
assert_eq!(*len, -1);
}
}
#[test]
fn sync_data_row_long_text() {
let mut arena = Arena::new();
let mut out = Vec::new();
let long_text = "a".repeat(2048);
let text_bytes = long_text.as_bytes();
let mut data = Vec::new();
data.extend_from_slice(&1i16.to_be_bytes());
data.extend_from_slice(&(text_bytes.len() as i32).to_be_bytes());
data.extend_from_slice(text_bytes);
parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
assert_eq!(out.len(), 1);
assert_eq!(out[0].1, text_bytes.len() as i32);
let stored = arena.get(out[0].0, out[0].1 as usize);
assert_eq!(stored, text_bytes);
}
#[test]
fn sync_data_row_empty_text() {
let mut arena = Arena::new();
let mut out = Vec::new();
let mut data = Vec::new();
data.extend_from_slice(&1i16.to_be_bytes());
data.extend_from_slice(&0i32.to_be_bytes()); parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
assert_eq!(out.len(), 1);
assert_eq!(out[0].1, 0); }
#[test]
fn sync_data_row_17_columns_exceeds_smallvec() {
let mut arena = Arena::new();
let mut out = Vec::new();
let mut data = Vec::new();
let num_cols: i16 = 20;
data.extend_from_slice(&num_cols.to_be_bytes());
for i in 0..num_cols {
let val = (i as i32).to_be_bytes();
data.extend_from_slice(&4i32.to_be_bytes());
data.extend_from_slice(&val);
}
parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
assert_eq!(out.len(), 20);
for (idx, (offset, len)) in out.iter().enumerate() {
assert_eq!(*len, 4);
let stored = arena.get(*offset, 4);
let val = i32::from_be_bytes([stored[0], stored[1], stored[2], stored[3]]);
assert_eq!(val, idx as i32);
}
}
#[test]
fn sync_data_row_mixed_null_and_data() {
let mut arena = Arena::new();
let mut out = Vec::new();
let mut data = Vec::new();
data.extend_from_slice(&5i16.to_be_bytes());
data.extend_from_slice(&(-1i32).to_be_bytes());
data.extend_from_slice(&4i32.to_be_bytes());
data.extend_from_slice(&42i32.to_be_bytes());
data.extend_from_slice(&(-1i32).to_be_bytes());
data.extend_from_slice(&(-1i32).to_be_bytes());
data.extend_from_slice(&5i32.to_be_bytes());
data.extend_from_slice(b"hello");
parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
assert_eq!(out.len(), 5);
assert_eq!(out[0].1, -1);
assert_eq!(out[1].1, 4);
assert_eq!(out[2].1, -1);
assert_eq!(out[3].1, -1);
assert_eq!(out[4].1, 5);
let stored = arena.get(out[4].0, 5);
assert_eq!(stored, b"hello");
}
#[test]
fn build_bind_template_too_short_buf() {
let tmpl = build_bind_template(&[b'B', 0, 0], 0);
assert!(tmpl.is_none());
}
#[test]
fn build_bind_template_zero_params() {
let mut buf = Vec::new();
proto::write_bind_params(&mut buf, "", "s_test", &[]);
let tmpl = build_bind_template(&buf, 0);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 0);
}
#[test]
fn build_bind_template_bool_param() {
let mut buf = Vec::new();
let val = true;
proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 1);
assert_eq!(tmpl.param_slots[0].1, 1); }
#[test]
fn build_bind_template_i64_param() {
let mut buf = Vec::new();
let val: i64 = 123456789;
proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots[0].1, 8); }
#[test]
fn build_bind_template_f64_param() {
let mut buf = Vec::new();
let val: f64 = 3.14;
proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots[0].1, 8); }
#[test]
fn build_bind_template_str_param() {
let mut buf = Vec::new();
let val: &str = "hello world";
proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots[0].1, 11); }
#[test]
fn build_bind_template_mixed_params_with_null() {
let mut buf = Vec::new();
let id: i32 = 1;
let name: Option<i32> = None;
let score: f64 = 9.9;
proto::write_bind_params(
&mut buf,
"",
"s_test",
&[
&id as &(dyn Encode + Sync),
&name as &(dyn Encode + Sync),
&score as &(dyn Encode + Sync),
],
);
let tmpl = build_bind_template(&buf, 3);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 3);
assert_eq!(tmpl.param_slots[0].1, 4); assert_eq!(tmpl.param_slots[1].1, -1); assert_eq!(tmpl.param_slots[2].1, 8); }
#[test]
fn build_bind_template_preserves_bytes() {
let mut buf = Vec::new();
let val: i32 = 42;
proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);
let bind_len = buf.len();
let tmpl = build_bind_template(&buf, 1).unwrap();
assert_eq!(
&tmpl.bytes[..bind_len],
&buf[..],
"template must start with original Bind message"
);
assert_eq!(
&tmpl.bytes[bind_len..],
proto::EXECUTE_SYNC,
"template must end with EXECUTE_SYNC"
);
}
#[test]
#[ignore] fn sync_connect_uds_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let result = SyncConnection::connect(&config);
if let Ok(conn) = result {
assert!(conn.pid() != 0, "pid should be nonzero");
assert!(conn.is_idle(), "should start idle");
assert!(!conn.is_in_transaction(), "should not be in tx");
assert!(
!conn.is_in_failed_transaction(),
"should not be in failed tx"
);
assert_eq!(conn.stmt_cache_len(), 0, "cache should be empty");
let _ = conn.close();
}
}
#[test]
#[ignore] fn sync_simple_query_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
conn.simple_query("SELECT 1").unwrap();
assert!(conn.is_idle());
let _ = conn.close();
}
#[test]
#[ignore] fn sync_query_with_params_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
let mut arena = Arena::new();
let sql = "SELECT $1::int4 + $2::int4 AS sum";
let hash = hash_sql(sql);
let a: i32 = 10;
let b: i32 = 20;
let result = conn
.query(
sql,
hash,
&[&a as &(dyn Encode + Sync), &b as &(dyn Encode + Sync)],
&mut arena,
)
.unwrap();
assert_eq!(result.len(), 1);
let _ = conn.close();
}
#[test]
#[ignore] fn sync_execute_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
conn.simple_query("CREATE TEMP TABLE _sync_test (id int)")
.unwrap();
let sql = "INSERT INTO _sync_test VALUES ($1::int4)";
let hash = hash_sql(sql);
let val: i32 = 42;
let affected = conn
.execute(sql, hash, &[&val as &(dyn Encode + Sync)])
.unwrap();
assert_eq!(affected, 1);
let _ = conn.close();
}
#[test]
#[ignore] fn sync_for_each_zero_rows_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
conn.simple_query("CREATE TEMP TABLE _sync_fe0 (id int)")
.unwrap();
let sql = "SELECT id FROM _sync_fe0";
let hash = hash_sql(sql);
let mut count = 0u32;
conn.for_each(sql, hash, &[], |_row| {
count += 1;
Ok(())
})
.unwrap();
assert_eq!(count, 0);
let _ = conn.close();
}
#[test]
#[ignore] fn sync_for_each_multiple_rows_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
let sql = "SELECT generate_series(1, 5)";
let hash = hash_sql(sql);
let mut count = 0u32;
conn.for_each(sql, hash, &[], |_row| {
count += 1;
Ok(())
})
.unwrap();
assert_eq!(count, 5);
let _ = conn.close();
}
#[test]
#[ignore] fn sync_prepare_only_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
let sql = "SELECT 1";
let hash = hash_sql(sql);
conn.prepare_only(sql, hash).unwrap();
assert_eq!(conn.stmt_cache_len(), 1);
conn.prepare_only(sql, hash).unwrap();
assert_eq!(conn.stmt_cache_len(), 1);
let _ = conn.close();
}
#[test]
#[ignore] fn sync_simple_query_rows_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
let rows = conn.simple_query_rows("SELECT 42 AS n").unwrap();
assert!(!rows.is_empty());
let _ = conn.close();
}
#[test]
#[ignore] fn sync_stmt_cache_hit_miss_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
let mut arena = Arena::new();
let sql1 = "SELECT 1";
let hash1 = hash_sql(sql1);
conn.query(sql1, hash1, &[], &mut arena).unwrap();
assert_eq!(conn.stmt_cache_len(), 1);
arena.reset();
conn.query(sql1, hash1, &[], &mut arena).unwrap();
assert_eq!(conn.stmt_cache_len(), 1);
let sql2 = "SELECT 2";
let hash2 = hash_sql(sql2);
arena.reset();
conn.query(sql2, hash2, &[], &mut arena).unwrap();
assert_eq!(conn.stmt_cache_len(), 2);
let _ = conn.close();
}
#[test]
#[ignore] fn sync_invalid_sql_error_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
let mut arena = Arena::new();
let sql = "SELECTTTT INVALID GARBAGE";
let hash = hash_sql(sql);
let result = conn.query(sql, hash, &[], &mut arena);
assert!(result.is_err());
assert!(conn.is_idle());
let _ = conn.close();
}
#[test]
#[ignore] fn sync_tx_state_transitions_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
assert!(conn.is_idle());
assert!(!conn.is_in_transaction());
conn.simple_query("BEGIN").unwrap();
assert!(conn.is_in_transaction());
assert!(!conn.is_idle());
conn.simple_query("COMMIT").unwrap();
assert!(conn.is_idle());
assert!(!conn.is_in_transaction());
let _ = conn.close();
}
#[test]
#[ignore] fn sync_lru_cache_eviction_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
conn.set_max_stmt_cache_size(3);
let mut arena = Arena::new();
for i in 0..5 {
let sql = format!("SELECT {}", i);
let hash = hash_sql(&sql);
arena.reset();
conn.query(&sql, hash, &[], &mut arena).unwrap();
}
assert!(
conn.stmt_cache_len() <= 3,
"cache should be capped at 3, got {}",
conn.stmt_cache_len()
);
let _ = conn.close();
}
#[test]
#[ignore] fn sync_for_each_raw_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
let sql = "SELECT generate_series(1, 3)";
let hash = hash_sql(sql);
let mut raw_count = 0u32;
conn.for_each_raw(sql, hash, &[], |_raw_data| {
raw_count += 1;
Ok(())
})
.unwrap();
assert_eq!(raw_count, 3);
let _ = conn.close();
}
#[test]
#[ignore] fn sync_query_null_params_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
let mut arena = Arena::new();
let sql = "SELECT $1::int4 IS NULL AS is_null";
let hash = hash_sql(sql);
let val: Option<i32> = None;
let _result = conn
.query(sql, hash, &[&val as &(dyn Encode + Sync)], &mut arena)
.unwrap();
let _ = conn.close();
}
#[test]
#[ignore] fn sync_query_various_param_types_if_pg_available() {
let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
let mut conn = SyncConnection::connect(&config).unwrap();
let mut arena = Arena::new();
let sql = "SELECT $1::int4, $2::int8, $3::text, $4::bool, $5::float8";
let hash = hash_sql(sql);
let p1: i32 = 42;
let p2: i64 = 9999999;
let p3: &str = "hello";
let p4: bool = true;
let p5: f64 = 3.14;
let result = conn
.query(
sql,
hash,
&[
&p1 as &(dyn Encode + Sync),
&p2 as &(dyn Encode + Sync),
&p3 as &(dyn Encode + Sync),
&p4 as &(dyn Encode + Sync),
&p5 as &(dyn Encode + Sync),
],
&mut arena,
)
.unwrap();
assert_eq!(result.len(), 1);
let _ = conn.close();
}
#[test]
fn sync_shrink_threshold_values() {
let shrink = 64 * 1024usize;
let initial = 8192usize;
assert!(
shrink > initial,
"shrink threshold must exceed initial size"
);
}
#[test]
fn sync_connection_debug_format() {
let fmt_str = format!(
"SyncConnection {{ pid: {}, tx_status: '{}', stmt_cache_len: {} }}",
0, 'I', 0
);
assert!(fmt_str.contains("SyncConnection"));
assert!(fmt_str.contains("pid"));
assert!(fmt_str.contains("tx_status"));
}
}