use crate::config::Config;
use crate::error::PgsqlError;
use crate::packet::{AuthStatus, Packet, SuccessMessage};
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream};
use std::time::{Duration, Instant};
/// Byte transport underneath a [`Connect`]: either a raw TCP socket or, when
/// the `tls` feature is enabled, a TLS session layered over that socket.
#[derive(Debug)]
pub(crate) enum PgStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS session over TCP (boxed, presumably to keep the enum itself small).
    #[cfg(feature = "tls")]
    Tls(Box<native_tls::TlsStream<TcpStream>>),
}
impl Read for PgStream {
    /// Reads from whichever transport (plain TCP or TLS) backs this stream.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Pick the active transport first, then issue a single read on it.
        let transport: &mut dyn Read = match self {
            PgStream::Plain(tcp) => tcp,
            #[cfg(feature = "tls")]
            PgStream::Tls(tls) => tls,
        };
        transport.read(buf)
    }
}
impl Write for PgStream {
    /// Writes to whichever transport (plain TCP or TLS) backs this stream.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let transport: &mut dyn Write = match self {
            PgStream::Plain(tcp) => tcp,
            #[cfg(feature = "tls")]
            PgStream::Tls(tls) => tls,
        };
        transport.write(buf)
    }

    /// Flushes whichever transport backs this stream.
    fn flush(&mut self) -> std::io::Result<()> {
        let transport: &mut dyn Write = match self {
            PgStream::Plain(tcp) => tcp,
            #[cfg(feature = "tls")]
            PgStream::Tls(tls) => tls,
        };
        transport.flush()
    }
}
impl PgStream {
    /// Borrows the TCP socket at the bottom of the stack, looking through the
    /// TLS layer when there is one. All socket-level operations go through it.
    fn tcp(&self) -> &TcpStream {
        match self {
            PgStream::Plain(tcp) => tcp,
            #[cfg(feature = "tls")]
            PgStream::Tls(tls) => tls.get_ref(),
        }
    }

    /// Remote address of the underlying socket; callers use an error here as
    /// a cheap sign that the connection is gone.
    fn peer_addr(&self) -> std::io::Result<SocketAddr> {
        self.tcp().peer_addr()
    }

    /// Shuts down the requested half (or both halves) of the raw socket.
    /// For TLS streams this bypasses the TLS layer entirely.
    fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
        self.tcp().shutdown(how)
    }

    /// Sets the read timeout of the raw socket (`None` means block forever).
    // Not called outside the test module in this file, hence the allow.
    #[allow(dead_code)]
    fn set_read_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
        self.tcp().set_read_timeout(dur)
    }
}
/// One PostgreSQL client connection: the transport, the protocol codec and the
/// bookkeeping needed for pooling (idle time and age).
#[derive(Debug)]
pub struct Connect {
    /// Underlying transport (plain TCP or TLS).
    pub(crate) stream: PgStream,
    /// Peer address captured right after the TCP connect; not read afterwards.
    _peer_addr: SocketAddr,
    /// Message encoder/decoder; also records handshake details such as the MD5
    /// salt and the SASL mechanism offered by the server.
    packet: Packet,
    /// Handshake progress; `read` uses it to decide how the end of a server
    /// response is detected.
    auth_status: AuthStatus,
    /// When the connection last completed a request (or was touched); drives
    /// the idle check in `is_valid` and `idle_elapsed`.
    last_used: Instant,
    /// When the connection object was created; exposed through `age`.
    created_at: Instant,
}
impl Connect {
    /// Reports whether this connection can still be handed out.
    ///
    /// The socket must still report a peer address. When the connection has
    /// been idle longer than `IDLE_THRESHOLD` (5 s; 0 in test builds), a
    /// `SELECT 1` round trip is also performed as a liveness probe; a
    /// successful probe refreshes the idle timer.
    pub fn is_valid(&mut self) -> bool {
        if self.stream.peer_addr().is_err() {
            return false;
        }
        #[cfg(not(test))]
        const IDLE_THRESHOLD: Duration = Duration::from_secs(5);
        #[cfg(test)]
        const IDLE_THRESHOLD: Duration = Duration::from_millis(0);
        if self.last_used.elapsed() > IDLE_THRESHOLD {
            return self.query("SELECT 1").is_ok();
        }
        true
    }

    /// Cheap liveness check: `true` while the socket still reports a peer
    /// address. Unlike [`Connect::is_valid`] this never touches the network.
    pub fn peer_valid(&self) -> bool {
        self.stream.peer_addr().is_ok()
    }

    /// Resets the idle timer to "now".
    pub fn touch(&mut self) {
        self.last_used = Instant::now();
    }

    /// Time elapsed since the last completed request (or [`Connect::touch`]).
    pub fn idle_elapsed(&self) -> Duration {
        self.last_used.elapsed()
    }

    /// Time elapsed since the connection was created.
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// Best-effort close: sends the packet built by `Packet::pack_terminate`
    /// and shuts the socket down in both directions. Errors are ignored, so
    /// calling this more than once (or before `Drop`) is harmless.
    pub fn _close(&mut self) {
        let _ = self.stream.write_all(&Packet::pack_terminate());
        let _ = self.stream.shutdown(std::net::Shutdown::Both);
    }

    /// Enables TCP keepalive on `stream`.
    ///
    /// The first probe is sent after 60 s of idleness, then one every 15 s. The
    /// connection is given up after 3 unanswered probes; the probe count is
    /// not configured on Windows.
    fn set_keepalive(stream: &TcpStream) -> Result<(), PgsqlError> {
        let keepalive = socket2::TcpKeepalive::new()
            .with_time(Duration::from_secs(60))
            .with_interval(Duration::from_secs(15));
        #[cfg(not(target_os = "windows"))]
        let keepalive = keepalive.with_retries(3);
        let socket = socket2::SockRef::from(stream);
        socket
            .set_tcp_keepalive(&keepalive)
            .map_err(|e| PgsqlError::Connection(format!("设置 TCP Keepalive 失败: {}", e)))
    }

    /// Whether `config.sslmode` forbids falling back to an unencrypted
    /// connection (`require`, `verify-ca`, `verify-full`).
    fn ssl_required(config: &Config) -> bool {
        config.sslmode == "require"
            || config.sslmode == "verify-ca"
            || config.sslmode == "verify-full"
    }

    /// Sends an `SSLRequest` on a freshly connected socket and upgrades it to
    /// TLS if the server answers `S`.
    ///
    /// Certificate checking follows libpq's `sslmode` meaning:
    /// - `verify-full` checks both the chain and the host name.
    /// - `verify-ca` checks only the chain.
    /// - Any other mode encrypts without verifying.
    ///
    /// If the server answers `N`, plaintext is accepted unless the mode
    /// requires SSL.
    ///
    /// # Errors
    /// Returns `PgsqlError::Connection` in any of these cases:
    /// - an I/O failure or a TLS failure;
    /// - an unexpected response byte;
    /// - a refused SSL request under a mode that requires SSL;
    /// - an `S` answer when the crate was built without the `tls` feature.
    fn try_ssl_upgrade(mut stream: TcpStream, config: &Config) -> Result<PgStream, PgsqlError> {
        stream
            .write_all(&Packet::pack_ssl_request())
            .map_err(|e| PgsqlError::Connection(format!("发送 SSLRequest 失败: {}", e)))?;
        let mut resp = [0u8; 1];
        stream
            .read_exact(&mut resp)
            .map_err(|e| PgsqlError::Connection(format!("读取 SSL 响应失败: {}", e)))?;
        match resp[0] {
            b'S' => {
                #[cfg(feature = "tls")]
                {
                    // Only skip verification when the user did not ask for it;
                    // previously `verify-*` modes silently accepted any cert.
                    let verify_chain =
                        config.sslmode == "verify-ca" || config.sslmode == "verify-full";
                    let verify_host = config.sslmode == "verify-full";
                    let connector = native_tls::TlsConnector::builder()
                        .danger_accept_invalid_certs(!verify_chain)
                        .danger_accept_invalid_hostnames(!verify_host)
                        .build()
                        .map_err(|e| PgsqlError::Connection(format!("TLS 初始化失败: {}", e)))?;
                    let tls_stream = connector
                        .connect(&config.hostname, stream)
                        .map_err(|e| PgsqlError::Connection(format!("TLS 握手失败: {}", e)))?;
                    Ok(PgStream::Tls(Box::new(tls_stream)))
                }
                #[cfg(not(feature = "tls"))]
                {
                    Err(PgsqlError::Connection(
                        "服务端要求 SSL 但未启用 tls feature".into(),
                    ))
                }
            }
            b'N' => {
                // Every mode that demands encryption must refuse a plaintext
                // fallback, not just `require`.
                if Self::ssl_required(config) {
                    Err(PgsqlError::Connection(format!(
                        "sslmode={} 但服务端不支持 SSL",
                        config.sslmode
                    )))
                } else {
                    Ok(PgStream::Plain(stream))
                }
            }
            other => Err(PgsqlError::Connection(format!(
                "无效的 SSL 响应字节: 0x{:02X}",
                other
            ))),
        }
    }

    /// Opens a connection and authenticates it.
    ///
    /// Steps, in order:
    /// - connect over TCP to `config.url()`;
    /// - configure the socket (TCP_NODELAY, keepalive, 30 s read/write
    ///   timeouts);
    /// - negotiate TLS according to `config.sslmode`;
    /// - run the authentication handshake.
    ///
    /// When the crate is built without the `tls` feature and the sslmode does
    /// not require SSL, no `SSLRequest` is sent. TLS could never be completed
    /// in that build, and an `S` answer from the server would otherwise abort
    /// the connection.
    ///
    /// # Errors
    /// Connection, socket-option, TLS and authentication failures are
    /// propagated as `PgsqlError`.
    pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
        let stream =
            TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
        stream
            .set_nodelay(true)
            .map_err(|e| PgsqlError::Connection(format!("设置 TCP_NODELAY 失败: {}", e)))?;
        Self::set_keepalive(&stream)?;
        stream
            .set_read_timeout(Some(Duration::from_secs(30)))
            .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
        stream
            .set_write_timeout(Some(Duration::from_secs(30)))
            .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
        let peer_addr = stream
            .peer_addr()
            .map_err(|e| PgsqlError::Connection(e.to_string()))?;
        let negotiate_ssl = config.sslmode != "disable"
            && (cfg!(feature = "tls") || Self::ssl_required(&config));
        let stream = if negotiate_ssl {
            Self::try_ssl_upgrade(stream, &config)?
        } else {
            PgStream::Plain(stream)
        };
        let mut connect = Self {
            stream,
            _peer_addr: peer_addr,
            packet: Packet::new(config),
            auth_status: AuthStatus::None,
            last_used: Instant::now(),
            created_at: Instant::now(),
        };
        connect.authenticate()?;
        Ok(connect)
    }

    /// Runs the startup handshake and leaves the connection ready for queries.
    ///
    /// The method is chosen from the code of the server's first
    /// authentication request:
    /// - `0` (AuthenticationOk, i.e. trust) needs no password.
    /// - `3`, `5` and `10` are dispatched on what `Packet::unpack` recorded
    ///   (MD5 salt → MD5, SASL mechanism → SCRAM, neither → cleartext).
    /// - Any other code is rejected, rather than answered with a cleartext
    ///   password.
    fn authenticate(&mut self) -> Result<(), PgsqlError> {
        self.stream
            .write_all(&self.packet.pack_first())
            .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
        let data = self.read()?;
        let auth_code = Self::auth_request_code(&data);
        self.packet.unpack(data, 0)?;
        match auth_code {
            // Trust: `read` already consumed the rest of the startup sequence
            // up to ReadyForQuery, so no password must be sent.
            Some(0) => {}
            // `None` (no complete `R` message found) keeps the field-based
            // dispatch so unusual server replies behave as before.
            Some(3) | Some(5) | Some(10) | None => {
                if !self.packet.md5_salt.is_empty() {
                    self.md5_auth()?;
                } else if self.packet.auth_mechanism.is_empty() {
                    self.cleartext_auth()?;
                } else {
                    self.scram_auth()?;
                }
            }
            Some(code) => {
                return Err(PgsqlError::Auth(format!(
                    "不支持的认证方式 (code {})",
                    code
                )));
            }
        }
        self.auth_status = AuthStatus::AuthenticationOk;
        Ok(())
    }

    /// Answers an MD5 password request and consumes the server's reply.
    fn md5_auth(&mut self) -> Result<(), PgsqlError> {
        self.stream
            .write_all(&self.packet.pack_md5_password())
            .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
        let data = self.read()?;
        self.packet.unpack(data, 0)?;
        Ok(())
    }

    /// Answers a cleartext password request and consumes the server's reply.
    fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
        self.stream
            .write_all(&self.packet.pack_cleartext_password())
            .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
        let data = self.read()?;
        self.packet.unpack(data, 0)?;
        Ok(())
    }

    /// Runs the two SASL round trips.
    ///
    /// The initial response draws the server challenge; the verify message
    /// then draws the final server reply.
    fn scram_auth(&mut self) -> Result<(), PgsqlError> {
        self.stream
            .write_all(&self.packet.pack_auth())
            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
        let data = self.read()?;
        self.packet.unpack(data, 0)?;
        self.stream
            .write_all(&self.packet.pack_auth_verify())
            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
        let data = self.read()?;
        self.packet.unpack(data, 0)?;
        Ok(())
    }

    /// Locates the backend message that starts at `start` within `buf`.
    ///
    /// A message is a 1-byte tag followed by a 4-byte big-endian length that
    /// counts itself and the body, but not the tag. Returns:
    /// - `Ok(Some((tag, end)))` when the whole message is buffered, where
    ///   `end` is the offset just past it;
    /// - `Ok(None)` when more bytes are needed;
    /// - a protocol error when the length field is impossible (< 4).
    fn frame_end(buf: &[u8], start: usize) -> Result<Option<(u8, usize)>, PgsqlError> {
        let header = match buf.get(start..start + 5) {
            Some(header) => header,
            None => return Ok(None),
        };
        // A u32 always fits in usize on the (>= 32-bit) targets std supports
        // for TCP networking.
        let declared = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
        if declared < 4 {
            return Err(PgsqlError::Protocol("消息长度字段无效".into()));
        }
        match (start + 1).checked_add(declared) {
            Some(end) if end <= buf.len() => Ok(Some((header[0], end))),
            // Incomplete (or absurdly large): keep reading; the caller's total
            // size limit bounds how long that can go on.
            _ => Ok(None),
        }
    }

    /// Returns the request code of the first complete authentication (`R`)
    /// message in `data`, or `None` if there is none.
    ///
    /// Codes seen here: 0 = Ok, 3 = cleartext, 5 = MD5, 10 = SASL.
    fn auth_request_code(data: &[u8]) -> Option<u32> {
        let mut start = 0;
        while let Ok(Some((tag, end))) = Self::frame_end(data, start) {
            if tag == b'R' {
                return data
                    .get(start + 5..start + 9)
                    .and_then(|code| <[u8; 4]>::try_from(code).ok())
                    .map(u32::from_be_bytes);
            }
            start = end;
        }
        None
    }

    /// Reads one complete backend response and returns its raw bytes.
    ///
    /// The buffer is walked message by message, and reading stops only on a
    /// message boundary:
    /// - After authentication: at the first `ReadyForQuery` (`Z`).
    /// - During the handshake: at `ReadyForQuery`, at an `ErrorResponse`
    ///   (`E`), or at an authentication request (`R`) that needs a client
    ///   reply.
    ///
    /// `AuthenticationOk` (0) and `AuthenticationSASLFinal` (12) are followed
    /// by more server messages without client input, so reading continues
    /// after them.
    ///
    /// Walking whole frames, rather than matching the trailing bytes of the
    /// buffer, has two benefits:
    /// - Row data that happens to end with `Z` lookalike bytes cannot stop the
    ///   read early.
    /// - The startup messages that follow `AuthenticationOk` are consumed here,
    ///   so they cannot leak into the next query's response.
    ///
    /// # Errors
    /// - `Timeout` when the overall deadline passes or too many consecutive
    ///   reads time out.
    /// - `Connection` on EOF.
    /// - `Protocol` when the response grows past the size limit or carries an
    ///   invalid length.
    /// - `Io` for other socket errors.
    fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
        #[cfg(not(test))]
        const MAX_RETRIES: u32 = 100;
        #[cfg(test)]
        const MAX_RETRIES: u32 = 3;
        #[cfg(not(test))]
        const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
        #[cfg(test)]
        const MAX_MESSAGE_SIZE: usize = 128;
        #[cfg(not(test))]
        let deadline = Instant::now() + Duration::from_secs(300);
        #[cfg(test)]
        let deadline = Instant::now() + Duration::from_millis(200);

        let authenticated = matches!(self.auth_status, AuthStatus::AuthenticationOk);
        let mut msg = Vec::new();
        let mut buf = [0u8; 4096];
        let mut retry_count: u32 = 0;
        // Offset of the first message not yet inspected. Scanning resumes here
        // after each read, so every byte is examined once even for large
        // result sets.
        let mut cursor = 0usize;
        loop {
            if Instant::now() >= deadline {
                return Err(PgsqlError::Timeout("读取总超时".into()));
            }
            match self.stream.read(&mut buf) {
                Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
                Ok(n) => {
                    if msg.len() + n > MAX_MESSAGE_SIZE {
                        return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
                    }
                    msg.extend_from_slice(&buf[..n]);
                    retry_count = 0;
                }
                Err(ref e)
                    if e.kind() == std::io::ErrorKind::WouldBlock
                        || e.kind() == std::io::ErrorKind::TimedOut =>
                {
                    retry_count += 1;
                    if retry_count > MAX_RETRIES {
                        return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
                    }
                    std::thread::sleep(Duration::from_millis(10));
                    continue;
                }
                Err(e) => return Err(PgsqlError::Io(e)),
            }
            while let Some((tag, end)) = Self::frame_end(&msg, cursor)? {
                let body = &msg[cursor + 5..end];
                let response_complete = match tag {
                    b'Z' => true,
                    b'E' => !authenticated,
                    b'R' if !authenticated => !matches!(
                        body.get(..4),
                        Some([0, 0, 0, 0]) | Some([0, 0, 0, 12])
                    ),
                    _ => false,
                };
                cursor = end;
                if response_complete {
                    return Ok(msg);
                }
            }
        }
    }

    /// Completes a request whose frame has just been written.
    ///
    /// Maps a write failure to `PgsqlError::Io`, reads the full response,
    /// refreshes the idle timer and decodes the response with `Packet::unpack`.
    fn finish_request(
        &mut self,
        sent: std::io::Result<()>,
    ) -> Result<SuccessMessage, PgsqlError> {
        sent.map_err(PgsqlError::Io)?;
        let data = self.read()?;
        self.touch();
        self.packet.unpack(data, 0)
    }

    /// Sends `sql` in the frame built by `Packet::pack_query` and decodes the
    /// reply.
    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
        let sent = self.stream.write_all(&self.packet.pack_query(sql));
        self.finish_request(sent)
    }

    /// Sends `sql` in the frame built by `Packet::pack_execute` and decodes
    /// the reply.
    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
        let sent = self.stream.write_all(&self.packet.pack_execute(sql));
        self.finish_request(sent)
    }

    /// Parameterised query.
    ///
    /// `None` entries in `params` stand for SQL `NULL`; encoding is done by
    /// `Packet::pack_query_params`.
    pub fn query_params(
        &mut self,
        sql: &str,
        params: &[Option<&str>],
    ) -> Result<SuccessMessage, PgsqlError> {
        let sent = self
            .stream
            .write_all(&self.packet.pack_query_params(sql, params));
        self.finish_request(sent)
    }

    /// Parameterised statement.
    ///
    /// `None` entries in `params` stand for SQL `NULL`; encoding is done by
    /// `Packet::pack_execute_params`.
    pub fn execute_params(
        &mut self,
        sql: &str,
        params: &[Option<&str>],
    ) -> Result<SuccessMessage, PgsqlError> {
        let sent = self
            .stream
            .write_all(&self.packet.pack_execute_params(sql, params));
        self.finish_request(sent)
    }

    /// Convenience wrapper around [`Connect::query_params`] for parameters
    /// that are never `NULL`.
    pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
        let opts: Vec<Option<&str>> = params.iter().copied().map(Some).collect();
        self.query_params(sql, &opts)
    }

    /// Convenience wrapper around [`Connect::execute_params`] for parameters
    /// that are never `NULL`.
    pub fn execute_str(
        &mut self,
        sql: &str,
        params: &[&str],
    ) -> Result<SuccessMessage, PgsqlError> {
        let opts: Vec<Option<&str>> = params.iter().copied().map(Some).collect();
        self.execute_params(sql, &opts)
    }

    /// Starts a portal for `sql` that returns at most `max_rows` rows per
    /// batch; framing is done by `Packet::pack_query_portal`.
    pub fn query_portal(&mut self, sql: &str, max_rows: u32) -> Result<SuccessMessage, PgsqlError> {
        let sent = self
            .stream
            .write_all(&self.packet.pack_query_portal(sql, max_rows));
        self.finish_request(sent)
    }

    /// Fetches up to `max_rows` further rows from the current portal.
    pub fn fetch_more(&mut self, max_rows: u32) -> Result<SuccessMessage, PgsqlError> {
        let sent = self.stream.write_all(&self.packet.pack_fetch_more(max_rows));
        self.finish_request(sent)
    }

    /// Closes the current portal.
    pub fn close_portal(&mut self) -> Result<SuccessMessage, PgsqlError> {
        let sent = self.stream.write_all(&self.packet.pack_close_portal());
        self.finish_request(sent)
    }
}
impl Drop for Connect {
fn drop(&mut self) {
let _ = self.stream.write_all(&Packet::pack_terminate());
let _ = self.stream.shutdown(std::net::Shutdown::Both);
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::TcpListener;
use std::thread;
fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
let mut m = Vec::with_capacity(5 + payload.len());
m.push(tag);
m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
m.extend_from_slice(payload);
m
}
fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
let mut body = Vec::new();
body.extend(&auth_type.to_be_bytes());
body.extend_from_slice(extra);
pg_msg(b'R', &body)
}
fn auth_ok() -> Vec<u8> {
pg_auth(0, &[])
}
fn param_status() -> Vec<u8> {
pg_msg(b'S', b"server_version\x0015.0\x00")
}
fn backend_key() -> Vec<u8> {
let mut p = Vec::new();
p.extend(&1u32.to_be_bytes());
p.extend(&2u32.to_be_bytes());
pg_msg(b'K', &p)
}
fn ready_for_query() -> Vec<u8> {
pg_msg(b'Z', b"I")
}
fn post_auth_ok() -> Vec<u8> {
let mut v = Vec::new();
v.extend(auth_ok());
v.extend(param_status());
v.extend(backend_key());
v.extend(ready_for_query());
v
}
fn simple_query_response() -> Vec<u8> {
let mut r = Vec::new();
r.extend(pg_msg(b'1', &[]));
r.extend(pg_msg(b'2', &[]));
let mut rd = Vec::new();
rd.extend(&1u16.to_be_bytes()); rd.extend(b"c\x00"); rd.extend(&0u32.to_be_bytes()); rd.extend(&1u16.to_be_bytes()); rd.extend(&23u32.to_be_bytes()); rd.extend(&4i16.to_be_bytes()); rd.extend(&(-1i32).to_be_bytes()); rd.extend(&0u16.to_be_bytes()); r.extend(pg_msg(b'T', &rd));
let mut dr = Vec::new();
dr.extend(&1u16.to_be_bytes());
dr.extend(&1u32.to_be_bytes()); dr.push(b'1');
r.extend(pg_msg(b'D', &dr));
r.extend(pg_msg(b'C', b"SELECT 1\x00"));
r.extend(ready_for_query());
r
}
fn execute_response() -> Vec<u8> {
let mut r = Vec::new();
r.extend(pg_msg(b'1', &[]));
r.extend(pg_msg(b'2', &[]));
r.extend(pg_msg(b'n', &[])); r.extend(pg_msg(b'C', b"UPDATE 3\x00"));
r.extend(ready_for_query());
r
}
fn query_params_response() -> Vec<u8> {
let mut r = Vec::new();
r.extend(pg_msg(b'1', &[]));
let mut pd = Vec::new();
pd.extend(&1u16.to_be_bytes());
pd.extend(&23u32.to_be_bytes());
r.extend(pg_msg(b't', &pd));
r.extend(pg_msg(b'2', &[]));
let mut rd = Vec::new();
rd.extend(&1u16.to_be_bytes());
rd.extend(b"p\x00");
rd.extend(&0u32.to_be_bytes());
rd.extend(&1u16.to_be_bytes());
rd.extend(&23u32.to_be_bytes());
rd.extend(&4i16.to_be_bytes());
rd.extend(&(-1i32).to_be_bytes());
rd.extend(&0u16.to_be_bytes());
r.extend(pg_msg(b'T', &rd));
let mut dr = Vec::new();
dr.extend(&1u16.to_be_bytes());
dr.extend(&2u32.to_be_bytes());
dr.extend(b"42");
r.extend(pg_msg(b'D', &dr));
r.extend(pg_msg(b'C', b"SELECT 1\x00"));
r.extend(ready_for_query());
r
}
fn execute_params_response() -> Vec<u8> {
let mut r = Vec::new();
r.extend(pg_msg(b'1', &[]));
let mut pd = Vec::new();
pd.extend(&1u16.to_be_bytes());
pd.extend(&23u32.to_be_bytes());
r.extend(pg_msg(b't', &pd));
r.extend(pg_msg(b'2', &[]));
r.extend(pg_msg(b'n', &[]));
r.extend(pg_msg(b'C', b"UPDATE 1\x00"));
r.extend(ready_for_query());
r
}
fn query_params_null_response() -> Vec<u8> {
let mut r = Vec::new();
r.extend(pg_msg(b'1', &[]));
let mut pd = Vec::new();
pd.extend(&1u16.to_be_bytes());
pd.extend(&25u32.to_be_bytes());
r.extend(pg_msg(b't', &pd));
r.extend(pg_msg(b'2', &[]));
let mut rd = Vec::new();
rd.extend(&1u16.to_be_bytes());
rd.extend(b"n\x00");
rd.extend(&0u32.to_be_bytes());
rd.extend(&1u16.to_be_bytes());
rd.extend(&25u32.to_be_bytes());
rd.extend(&(-1i16).to_be_bytes());
rd.extend(&(-1i32).to_be_bytes());
rd.extend(&0u16.to_be_bytes());
r.extend(pg_msg(b'T', &rd));
let mut dr = Vec::new();
dr.extend(&1u16.to_be_bytes());
dr.extend(&(-1i32).to_be_bytes());
r.extend(pg_msg(b'D', &dr));
r.extend(pg_msg(b'C', b"SELECT 1\x00"));
r.extend(ready_for_query());
r
}
fn error_response() -> Vec<u8> {
let mut payload = Vec::new();
payload.push(b'C');
payload.extend(b"42601\x00");
payload.push(b'M');
payload.extend(b"syntax error\x00");
payload.push(0);
let mut r = Vec::new();
r.extend(pg_msg(b'1', &[]));
r.extend(pg_msg(b'2', &[]));
r.extend(pg_msg(b'E', &payload));
r.extend(ready_for_query());
r
}
fn mock_config(port: u16) -> Config {
Config {
debug: false,
hostname: "127.0.0.1".into(),
hostport: port as i32,
username: "u".into(),
userpass: "p".into(),
database: "d".into(),
charset: "utf8".into(),
pool_max: 5,
sslmode: "disable".into(),
}
}
fn spawn_cleartext_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
loop {
match s.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(_) => {
let _ = s.write_all(&simple_query_response());
}
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_md5_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
loop {
match s.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(_) => {
let _ = s.write_all(&simple_query_response());
}
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_scram_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
let n = s.read(&mut buf).unwrap();
let payload = &buf[..n];
let text = String::from_utf8_lossy(payload);
let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
let _ = s.read(&mut buf).unwrap();
let mut resp = Vec::new();
resp.extend(pg_auth(12, b"v=dummyproof"));
resp.extend(auth_ok());
resp.extend(param_status());
resp.extend(backend_key());
resp.extend(ready_for_query());
let _ = s.write_all(&resp);
loop {
match s.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(_) => {
let _ = s.write_all(&simple_query_response());
}
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_eof_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (s, _) = listener.accept().unwrap();
drop(s); });
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_auth_error_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let mut payload = Vec::new();
payload.push(b'C');
payload.extend(b"28P01\x00");
payload.push(b'M');
payload.extend(b"password authentication failed\x00");
payload.push(0);
let _ = s.write_all(&pg_msg(b'E', &payload));
});
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_query_error_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
loop {
match s.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(_) => {
let _ = s.write_all(&error_response());
}
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_execute_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
loop {
match s.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(_) => {
let _ = s.write_all(&execute_response());
}
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_query_params_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
loop {
match s.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(_) => {
let _ = s.write_all(&query_params_response());
}
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_params_server() -> u16 {
spawn_query_params_server()
}
fn spawn_execute_params_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
loop {
match s.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(_) => {
let _ = s.write_all(&execute_params_response());
}
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_query_params_null_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
loop {
match s.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(_) => {
let _ = s.write_all(&query_params_null_response());
}
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
#[test]
fn connect_cleartext_auth_success() {
let port = spawn_cleartext_server();
let conn = Connect::new(mock_config(port));
assert!(conn.is_ok());
}
#[test]
fn connect_md5_auth_success() {
let port = spawn_md5_server();
let conn = Connect::new(mock_config(port));
assert!(conn.is_ok());
}
#[test]
fn connect_scram_auth_success() {
let port = spawn_scram_server();
let conn = Connect::new(mock_config(port));
assert!(conn.is_ok());
}
#[test]
fn connect_connection_refused() {
let cfg = mock_config(1);
let result = Connect::new(cfg);
assert!(result.is_err());
match result.unwrap_err() {
PgsqlError::Connection(_) => {}
other => panic!("expected Connection error, got {other:?}"),
}
}
#[test]
fn connect_server_closes_immediately() {
let port = spawn_eof_server();
let result = Connect::new(mock_config(port));
assert!(result.is_err());
}
#[test]
fn connect_auth_error_from_server() {
let port = spawn_auth_error_server();
let result = Connect::new(mock_config(port));
assert!(result.is_err());
}
#[test]
fn connect_query_success() {
let port = spawn_cleartext_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.query("SELECT 1");
assert!(result.is_ok());
let msg = result.unwrap();
assert_eq!(msg.rows.len(), 1);
assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
}
#[test]
fn connect_execute_success() {
let port = spawn_execute_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.execute("UPDATE t SET x=1");
assert!(result.is_ok());
let msg = result.unwrap();
assert_eq!(msg.affect_count, 3);
assert_eq!(msg.tag, "UPDATE 3");
}
#[test]
fn connect_query_params_success() {
let port = spawn_query_params_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.query_params("SELECT $1::int", &[Some("42")]);
assert!(result.is_ok());
let msg = result.unwrap();
assert!(!msg.param_oids.is_empty());
assert_eq!(msg.rows.len(), 1);
assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
}
#[test]
fn connect_execute_params_success() {
let port = spawn_execute_params_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.execute_params("UPDATE t SET x=$1", &[Some("42")]);
assert!(result.is_ok());
let msg = result.unwrap();
assert!(!msg.param_oids.is_empty());
assert_eq!(msg.affect_count, 1);
assert_eq!(msg.tag, "UPDATE 1");
}
#[test]
fn connect_query_str_success() {
let port = spawn_params_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.query_str("SELECT $1::int", &["42"]);
assert!(result.is_ok());
let msg = result.unwrap();
assert!(!msg.param_oids.is_empty());
assert_eq!(msg.rows.len(), 1);
}
#[test]
fn connect_execute_str_success() {
let port = spawn_execute_params_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.execute_str("UPDATE t SET x=$1", &["1"]);
assert!(result.is_ok());
let msg = result.unwrap();
assert!(!msg.param_oids.is_empty());
assert_eq!(msg.affect_count, 1);
}
#[test]
fn connect_query_params_with_null() {
let port = spawn_query_params_null_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.query_params("SELECT $1::text", &[None]);
assert!(result.is_ok());
let msg = result.unwrap();
assert!(!msg.param_oids.is_empty());
assert_eq!(msg.rows.len(), 1);
assert_eq!(msg.rows[0]["n"], "");
}
#[test]
fn connect_query_params_empty_string_vs_null() {
let port = spawn_params_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let r1 = conn.query_params("SELECT $1::text", &[Some("")]);
assert!(r1.is_ok());
let r2 = conn.query_params("SELECT $1::text", &[None]);
assert!(r2.is_ok());
}
#[test]
fn connect_query_returns_error() {
let port = spawn_query_error_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.query("BAD SQL");
assert!(result.is_err());
}
#[test]
fn connect_is_valid_true() {
let port = spawn_cleartext_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
assert!(conn.is_valid());
}
#[test]
fn connect_is_valid_false_after_close() {
let port = spawn_cleartext_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
conn._close();
assert!(!conn.is_valid());
}
#[test]
fn connect_close_does_not_panic() {
let port = spawn_cleartext_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
conn._close();
conn._close();
}
#[test]
fn connect_drop_does_not_panic() {
let port = spawn_cleartext_server();
let conn = Connect::new(mock_config(port)).unwrap();
drop(conn);
}
fn spawn_transaction_status_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
loop {
match s.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(_) => {
let mut r = Vec::new();
r.extend(pg_msg(b'1', &[]));
r.extend(pg_msg(b'2', &[]));
let mut rd = Vec::new();
rd.extend(&1u16.to_be_bytes());
rd.extend(b"c\x00");
rd.extend(&0u32.to_be_bytes());
rd.extend(&1u16.to_be_bytes());
rd.extend(&23u32.to_be_bytes());
rd.extend(&4i16.to_be_bytes());
rd.extend(&(-1i32).to_be_bytes());
rd.extend(&0u16.to_be_bytes());
r.extend(pg_msg(b'T', &rd));
let mut dr = Vec::new();
dr.extend(&1u16.to_be_bytes());
dr.extend(&1u32.to_be_bytes());
dr.push(b'1');
r.extend(pg_msg(b'D', &dr));
r.extend(pg_msg(b'C', b"SELECT 1\x00"));
r.extend(pg_msg(b'Z', b"T"));
let _ = s.write_all(&r);
}
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_error_status_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
loop {
match s.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(_) => {
let mut r = Vec::new();
r.extend(pg_msg(b'1', &[]));
r.extend(pg_msg(b'2', &[]));
let mut rd = Vec::new();
rd.extend(&1u16.to_be_bytes());
rd.extend(b"c\x00");
rd.extend(&0u32.to_be_bytes());
rd.extend(&1u16.to_be_bytes());
rd.extend(&23u32.to_be_bytes());
rd.extend(&4i16.to_be_bytes());
rd.extend(&(-1i32).to_be_bytes());
rd.extend(&0u16.to_be_bytes());
r.extend(pg_msg(b'T', &rd));
let mut dr = Vec::new();
dr.extend(&1u16.to_be_bytes());
dr.extend(&1u32.to_be_bytes());
dr.push(b'1');
r.extend(pg_msg(b'D', &dr));
r.extend(pg_msg(b'C', b"SELECT 1\x00"));
r.extend(pg_msg(b'Z', b"E"));
let _ = s.write_all(&r);
}
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_slow_partial_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
match s.read(&mut buf) {
Ok(0) | Err(_) => {}
Ok(_) => {
let _ = s.write_all(&simple_query_response());
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
fn spawn_rst_on_query_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
match s.read(&mut buf) {
Ok(0) | Err(_) => {}
Ok(_) => {
drop(s);
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
#[test]
fn connect_query_ready_for_query_transaction_status() {
let port = spawn_transaction_status_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.query("SELECT 1");
assert!(result.is_ok());
}
#[test]
fn connect_query_ready_for_query_error_status() {
let port = spawn_error_status_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.query("SELECT 1");
assert!(result.is_ok());
}
#[test]
fn connect_query_server_closes_after_partial() {
let port = spawn_slow_partial_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let r1 = conn.query("SELECT 1");
assert!(r1.is_ok());
let r2 = conn.query("SELECT 1");
assert!(r2.is_err());
}
#[test]
fn connect_query_server_rst_returns_io_or_connection_error() {
let port = spawn_rst_on_query_server();
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.query("SELECT 1");
assert!(result.is_err());
}
#[test]
fn connect_read_would_block_max_retries() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf);
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf);
let _ = s.write_all(&post_auth_ok());
let _ = s.read(&mut buf);
thread::sleep(Duration::from_secs(5));
});
thread::sleep(Duration::from_millis(30));
let mut conn = Connect::new(mock_config(port)).unwrap();
conn.stream
.set_read_timeout(Some(Duration::from_millis(1)))
.ok();
let result = conn.query("SELECT 1");
assert!(result.is_err());
let err_str = result.unwrap_err().to_string();
assert!(
err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
"expected timeout error, got: {err_str}"
);
}
#[test]
fn connect_read_exceeds_max_message_size() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf);
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf);
let _ = s.write_all(&post_auth_ok());
let _ = s.read(&mut buf);
let big = vec![b'X'; 256];
let _ = s.write_all(&big);
thread::sleep(Duration::from_secs(2));
});
thread::sleep(Duration::from_millis(30));
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.query("SELECT 1");
assert!(result.is_err());
let err_str = result.unwrap_err().to_string();
assert!(
err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
"expected max message size error, got: {err_str}"
);
}
#[test]
fn connect_read_deadline_timeout() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf);
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf);
let _ = s.write_all(&post_auth_ok());
let _ = s.read(&mut buf);
for _ in 0..200 {
let _ = s.write_all(b"X");
thread::sleep(Duration::from_millis(5));
}
});
thread::sleep(Duration::from_millis(30));
let mut conn = Connect::new(mock_config(port)).unwrap();
let result = conn.query("SELECT 1");
assert!(result.is_err());
}
/// The handshake must survive an authentication request that arrives split across two writes.
///
/// The mock sends the first 5 bytes of `pg_auth(3, &[])`, pauses 50 ms, then sends the rest.
/// After the post-auth messages it answers every client write with `simple_query_response()`
/// until the client disconnects or a read fails.
#[test]
fn connect_read_partial_auth_frame() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut peer, _) = listener.accept().unwrap();
        let mut inbox = [0u8; 4096];
        let _ = peer.read(&mut inbox);
        let auth = pg_auth(3, &[]);
        let (head, tail) = auth.split_at(5);
        let _ = peer.write_all(head);
        thread::sleep(Duration::from_millis(50));
        let _ = peer.write_all(tail);
        let _ = peer.read(&mut inbox);
        let _ = peer.write_all(&post_auth_ok());
        // Serve queries until EOF (0 bytes) or a read error.
        while let Ok(n) = peer.read(&mut inbox) {
            if n == 0 {
                break;
            }
            let _ = peer.write_all(&simple_query_response());
        }
    });
    thread::sleep(Duration::from_millis(30));
    let mut conn = Connect::new(mock_config(port)).unwrap();
    let outcome = conn.query("SELECT 1");
    assert!(outcome.is_ok());
}
/// Builds the mock reply to a portal query that is suspended after `rows` rows.
///
/// The reply is made of these messages, in order:
/// - ParseComplete (`1`) and BindComplete (`2`).
/// - A RowDescription (`T`) for one text-format column `id` of type OID 23, size 4, typmod -1.
/// - `rows` DataRow (`D`) messages carrying the text values `1..=rows`.
/// - PortalSuspended (`s`).
/// - `ready_for_query()`.
fn portal_response(rows: u16) -> Vec<u8> {
    let mut out = Vec::new();
    out.extend(pg_msg(b'1', &[]));
    out.extend(pg_msg(b'2', &[]));

    let mut desc = Vec::new();
    desc.extend(&1u16.to_be_bytes()); // field count
    desc.extend(b"id\x00"); // column name, NUL-terminated
    desc.extend(&0u32.to_be_bytes()); // table OID
    desc.extend(&1u16.to_be_bytes()); // column attribute number
    desc.extend(&23u32.to_be_bytes()); // type OID
    desc.extend(&4i16.to_be_bytes()); // type size
    desc.extend(&(-1i32).to_be_bytes()); // type modifier
    desc.extend(&0u16.to_be_bytes()); // format code (0 = text)
    out.extend(pg_msg(b'T', &desc));

    for n in 1..=rows {
        let text = n.to_string();
        let mut row = Vec::new();
        row.extend(&1u16.to_be_bytes()); // column count
        row.extend(&(text.len() as u32).to_be_bytes()); // value length
        row.extend(text.as_bytes());
        out.extend(pg_msg(b'D', &row));
    }

    out.extend(pg_msg(b's', &[]));
    out.extend(ready_for_query());
    out
}
/// Builds the mock reply to an `Execute` that drains the rest of a suspended portal.
///
/// The reply is made of these messages, in order:
/// - `rows` DataRow (`D`) messages carrying the text values `1..=rows`.
/// - A CommandComplete (`C`) whose tag reports the rows returned by this Execute.
/// - `ready_for_query()`.
fn portal_complete_response(rows: u16) -> Vec<u8> {
    let mut r = Vec::new();
    for n in 1..=rows {
        let val = n.to_string();
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes()); // column count
        dr.extend(&(val.len() as u32).to_be_bytes()); // value length
        dr.extend(val.as_bytes());
        r.extend(pg_msg(b'D', &dr));
    }
    // For a resumed portal, PostgreSQL's SELECT tag counts only the rows of the final Execute.
    // The tag must therefore follow `rows` rather than being hard-coded to "SELECT 2".
    let tag = format!("SELECT {rows}\0");
    r.extend(pg_msg(b'C', tag.as_bytes()));
    r.extend(ready_for_query());
    r
}
/// Builds the mock reply to closing a portal: CloseComplete (`3`) followed by
/// `ready_for_query()`.
fn close_portal_response() -> Vec<u8> {
    let mut reply = Vec::new();
    reply.extend(pg_msg(b'3', &[]));
    reply.extend(ready_for_query());
    reply
}
fn spawn_portal_server() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
let (mut s, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&pg_auth(3, &[]));
let _ = s.read(&mut buf).unwrap();
let _ = s.write_all(&post_auth_ok());
match s.read(&mut buf) {
Ok(0) | Err(_) => (),
Ok(_) => {
let _ = s.write_all(&portal_response(2));
}
}
match s.read(&mut buf) {
Ok(0) | Err(_) => (),
Ok(_) => {
let _ = s.write_all(&portal_complete_response(1));
}
}
match s.read(&mut buf) {
Ok(0) | Err(_) => (),
Ok(_) => {
let _ = s.write_all(&close_portal_response());
}
}
});
thread::sleep(Duration::from_millis(30));
port
}
/// With `sslmode = "prefer"`, a server that answers the first client message with `N`
/// (presumably refusing an SSL request) must still yield a working connection.
///
/// The mock continues the handshake on the same socket. It then answers every further client
/// write with `simple_query_response()` until EOF or a read error.
#[test]
fn ssl_prefer_fallback_on_rejection() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut sock, _) = listener.accept().unwrap();
        let mut rx = [0u8; 4096];
        let _ = sock.read(&mut rx);
        let _ = sock.write_all(b"N");
        let _ = sock.read(&mut rx);
        let _ = sock.write_all(&pg_auth(3, &[]));
        let _ = sock.read(&mut rx);
        let _ = sock.write_all(&post_auth_ok());
        loop {
            match sock.read(&mut rx) {
                Ok(n) if n > 0 => {
                    let _ = sock.write_all(&simple_query_response());
                }
                _ => break,
            }
        }
    });
    thread::sleep(Duration::from_millis(30));
    let cfg = {
        let mut c = mock_config(port);
        c.sslmode = "prefer".into();
        c
    };
    let outcome = Connect::new(cfg);
    assert!(outcome.is_ok());
}
/// With `sslmode = "require"`, an `N` answer to the first client message must make
/// `Connect::new` fail. The mock closes the socket right after sending `N`.
#[test]
fn ssl_require_rejected_returns_error() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut sock, _) = listener.accept().unwrap();
        let mut rx = [0u8; 4096];
        let _ = sock.read(&mut rx);
        let _ = sock.write_all(b"N");
    });
    thread::sleep(Duration::from_millis(30));
    let cfg = {
        let mut c = mock_config(port);
        c.sslmode = "require".into();
        c
    };
    let outcome = Connect::new(cfg);
    assert!(outcome.is_err());
}
/// With `sslmode = "prefer"`, an answer byte that is neither `S` nor `N` (here `X`) must
/// make `Connect::new` fail. The mock closes the socket right after sending it.
#[test]
fn ssl_invalid_response_byte() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut sock, _) = listener.accept().unwrap();
        let mut rx = [0u8; 4096];
        let _ = sock.read(&mut rx);
        let _ = sock.write_all(b"X");
    });
    thread::sleep(Duration::from_millis(30));
    let cfg = {
        let mut c = mock_config(port);
        c.sslmode = "prefer".into();
        c
    };
    let outcome = Connect::new(cfg);
    assert!(outcome.is_err());
}
/// With `sslmode = "disable"`, connecting to the mock from `spawn_cleartext_server` must
/// succeed. That mock presumably speaks only the plaintext handshake, so the client must not
/// start with an SSL request.
#[test]
fn ssl_disable_skips_ssl_handshake() {
    let mut cfg = mock_config(spawn_cleartext_server());
    cfg.sslmode = "disable".into();
    let outcome = Connect::new(cfg);
    assert!(outcome.is_ok());
}
/// `query_portal` with a fetch size of 2 must return the two rows the mock sends before
/// PortalSuspended, and must report that more rows are pending.
#[test]
fn connect_query_portal_returns_rows_with_has_more() {
    let port = spawn_portal_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    // `expect` (instead of `assert!(is_ok())` + `unwrap`) shows the actual error on failure.
    let msg = conn
        .query_portal("SELECT id FROM t", 2)
        .expect("query_portal should succeed against the mock portal server");
    assert_eq!(msg.rows.len(), 2);
    assert!(msg.has_more);
}
/// After a suspended `query_portal`, `fetch_more` must return the single remaining row the mock
/// sends before CommandComplete, and must report that the portal is exhausted.
#[test]
fn connect_fetch_more_returns_remaining_rows() {
    let port = spawn_portal_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    let _ = conn.query_portal("SELECT id FROM t", 2).unwrap();
    // `expect` (instead of `assert!(is_ok())` + `unwrap`) shows the actual error on failure.
    let msg = conn
        .fetch_more(10)
        .expect("fetch_more should succeed against the mock portal server");
    assert_eq!(msg.rows.len(), 1);
    assert!(!msg.has_more);
}
/// After the portal has been read to the end, `close_portal` must succeed. The mock replies
/// with CloseComplete followed by ready-for-query.
#[test]
fn connect_close_portal_succeeds() {
    let port = spawn_portal_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    let _ = conn.query_portal("SELECT id FROM t", 2).unwrap();
    let _ = conn.fetch_more(10).unwrap();
    // `expect` (instead of `assert!(is_ok())`) shows the actual error on failure.
    let _ = conn
        .close_portal()
        .expect("close_portal should succeed against the mock portal server");
}
}