Skip to main content

br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::{SocketAddr, TcpStream};
6use std::time::{Duration, Instant};
7
/// Byte transport underneath a connection: a plain TCP socket or, with the
/// `tls` feature enabled, a TLS session layered on top of a TCP socket.
#[derive(Debug)]
pub(crate) enum PgStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS session over TCP (boxed; presumably to keep the enum small).
    #[cfg(feature = "tls")]
    Tls(Box<native_tls::TlsStream<TcpStream>>),
}

impl Read for PgStream {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self {
            #[cfg(feature = "tls")]
            PgStream::Tls(tls) => tls.read(buf),
            PgStream::Plain(tcp) => tcp.read(buf),
        }
    }
}

impl Write for PgStream {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        match self {
            #[cfg(feature = "tls")]
            PgStream::Tls(tls) => tls.write(buf),
            PgStream::Plain(tcp) => tcp.write(buf),
        }
    }

    fn flush(&mut self) -> std::io::Result<()> {
        match self {
            #[cfg(feature = "tls")]
            PgStream::Tls(tls) => tls.flush(),
            PgStream::Plain(tcp) => tcp.flush(),
        }
    }
}

impl PgStream {
    /// Borrows the raw TCP socket, looking through the TLS layer if there is one.
    fn tcp(&self) -> &TcpStream {
        match self {
            PgStream::Plain(tcp) => tcp,
            #[cfg(feature = "tls")]
            PgStream::Tls(tls) => tls.get_ref(),
        }
    }

    /// Address of the remote peer, taken from the underlying TCP socket.
    fn peer_addr(&self) -> std::io::Result<SocketAddr> {
        self.tcp().peer_addr()
    }

    /// Shuts down the underlying TCP socket directly (for TLS this bypasses the
    /// TLS layer, so no TLS-level close message is sent by this call).
    fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
        self.tcp().shutdown(how)
    }

    /// Sets the read timeout of the underlying TCP socket.
    // Not called anywhere in this file at present; the lint is silenced so the
    // helper can stay available without a dead-code warning.
    #[allow(dead_code)]
    fn set_read_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
        self.tcp().set_read_timeout(dur)
    }
}
67
/// One PostgreSQL connection: the transport, the protocol message codec,
/// the authentication state and usage timestamps.
#[derive(Debug)]
pub struct Connect {
    /// Transport (plain TCP, or TLS when negotiated).
    pub(crate) stream: PgStream,
    /// Peer address cached right after the TCP connect; not read anywhere in this file.
    _peer_addr: SocketAddr,
    /// Builds outgoing messages (`pack_*`) and decodes replies (`unpack`).
    packet: Packet,
    /// Becomes `AuthenticationOk` once `authenticate` succeeds; `read` uses it to
    /// decide when a reply is complete.
    auth_status: AuthStatus,
    /// Time of last use; drives the lazy health check in `is_valid`.
    last_used: Instant,
    /// Creation time; reported by `age()`, intended for maximum-lifetime checks.
    created_at: Instant,
}
80
81impl Connect {
82    /// 懒健康检查:peer_addr 检测 TCP 存活,空闲超过 5 秒才发 SELECT 1
83    pub fn is_valid(&mut self) -> bool {
84        if self.stream.peer_addr().is_err() {
85            return false;
86        }
87        #[cfg(not(test))]
88        const IDLE_THRESHOLD: Duration = Duration::from_secs(5);
89        #[cfg(test)]
90        const IDLE_THRESHOLD: Duration = Duration::from_millis(0);
91        if self.last_used.elapsed() > IDLE_THRESHOLD {
92            return self.query("SELECT 1").is_ok();
93        }
94        true
95    }
96
97    /// 仅检查 TCP 连接是否存活(不发送查询)
98    pub fn peer_valid(&self) -> bool {
99        self.stream.peer_addr().is_ok()
100    }
101
102    /// 更新最后使用时间
103    pub fn touch(&mut self) {
104        self.last_used = Instant::now();
105    }
106
107    /// 返回空闲时长
108    pub fn idle_elapsed(&self) -> Duration {
109        self.last_used.elapsed()
110    }
111
112    /// 返回连接已存活时长
113    pub fn age(&self) -> Duration {
114        self.created_at.elapsed()
115    }
116
117    pub fn _close(&mut self) {
118        let _ = self.stream.write_all(&Packet::pack_terminate());
119        let _ = self.stream.shutdown(std::net::Shutdown::Both);
120    }
121
122    /// 设置 TCP Keepalive(60秒探测,15秒间隔,Unix 下 3 次重试)
123    fn set_keepalive(stream: &TcpStream) -> Result<(), PgsqlError> {
124        let keepalive = socket2::TcpKeepalive::new()
125            .with_time(Duration::from_secs(60))
126            .with_interval(Duration::from_secs(15));
127        #[cfg(not(target_os = "windows"))]
128        let keepalive = keepalive.with_retries(3);
129
130        let socket = socket2::SockRef::from(stream);
131        socket
132            .set_tcp_keepalive(&keepalive)
133            .map_err(|e| PgsqlError::Connection(format!("设置 TCP Keepalive 失败: {}", e)))
134    }
135
136    /// SSL/TLS 升级:发送 SSLRequest,根据服务端响应升级或回退
137    fn try_ssl_upgrade(mut stream: TcpStream, config: &Config) -> Result<PgStream, PgsqlError> {
138        // 发送 SSLRequest
139        stream
140            .write_all(&Packet::pack_ssl_request())
141            .map_err(|e| PgsqlError::Connection(format!("发送 SSLRequest 失败: {}", e)))?;
142        let mut resp = [0u8; 1];
143        stream
144            .read_exact(&mut resp)
145            .map_err(|e| PgsqlError::Connection(format!("读取 SSL 响应失败: {}", e)))?;
146        match resp[0] {
147            b'S' => {
148                // 服务端同意 SSL
149                #[cfg(feature = "tls")]
150                {
151                    let connector = native_tls::TlsConnector::builder()
152                        .danger_accept_invalid_certs(true)
153                        .build()
154                        .map_err(|e| PgsqlError::Connection(format!("TLS 初始化失败: {}", e)))?;
155                    let tls_stream = connector
156                        .connect(&config.hostname, stream)
157                        .map_err(|e| PgsqlError::Connection(format!("TLS 握手失败: {}", e)))?;
158                    Ok(PgStream::Tls(Box::new(tls_stream)))
159                }
160                #[cfg(not(feature = "tls"))]
161                {
162                    let _ = config;
163                    Err(PgsqlError::Connection(
164                        "服务端要求 SSL 但未启用 tls feature".into(),
165                    ))
166                }
167            }
168            b'N' => {
169                // 服务端拒绝 SSL
170                if config.sslmode == "require" {
171                    Err(PgsqlError::Connection(
172                        "sslmode=require 但服务端不支持 SSL".into(),
173                    ))
174                } else {
175                    Ok(PgStream::Plain(stream))
176                }
177            }
178            other => Err(PgsqlError::Connection(format!(
179                "无效的 SSL 响应字节: 0x{:02X}",
180                other
181            ))),
182        }
183    }
184    pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
185        let stream =
186            TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
187        // TCP 优化:禁用 Nagle 算法,减少小包延迟
188        stream
189            .set_nodelay(true)
190            .map_err(|e| PgsqlError::Connection(format!("设置 TCP_NODELAY 失败: {}", e)))?;
191        // TCP Keepalive:防止空闲连接被防火墙/NAT/PG服务端静默断开
192        Self::set_keepalive(&stream)?;
193        stream
194            .set_read_timeout(Some(Duration::from_secs(30)))
195            .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
196        stream
197            .set_write_timeout(Some(Duration::from_secs(30)))
198            .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
199        let peer_addr = stream
200            .peer_addr()
201            .map_err(|e| PgsqlError::Connection(e.to_string()))?;
202
203        // SSL/TLS 升级
204        let stream = if config.sslmode != "disable" {
205            Self::try_ssl_upgrade(stream, &config)?
206        } else {
207            PgStream::Plain(stream)
208        };
209
210        let mut connect = Self {
211            stream,
212            _peer_addr: peer_addr,
213            packet: Packet::new(config),
214            auth_status: AuthStatus::None,
215            last_used: Instant::now(),
216            created_at: Instant::now(),
217        };
218
219        connect.authenticate()?;
220
221        Ok(connect)
222    }
223
224    fn authenticate(&mut self) -> Result<(), PgsqlError> {
225        self.stream
226            .write_all(&self.packet.pack_first())
227            .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
228
229        let data = self.read()?;
230        self.packet.unpack(data, 0)?;
231
232        if !self.packet.md5_salt.is_empty() {
233            self.md5_auth()?;
234        } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
235            self.cleartext_auth()?;
236        } else {
237            self.scram_auth()?;
238        }
239
240        self.auth_status = AuthStatus::AuthenticationOk;
241        Ok(())
242    }
243
244    fn md5_auth(&mut self) -> Result<(), PgsqlError> {
245        self.stream
246            .write_all(&self.packet.pack_md5_password())
247            .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
248
249        let data = self.read()?;
250        self.packet.unpack(data, 0)?;
251        Ok(())
252    }
253
254    fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
255        self.stream
256            .write_all(&self.packet.pack_cleartext_password())
257            .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
258
259        let data = self.read()?;
260        self.packet.unpack(data, 0)?;
261        Ok(())
262    }
263
264    fn scram_auth(&mut self) -> Result<(), PgsqlError> {
265        self.stream
266            .write_all(&self.packet.pack_auth())
267            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
268
269        let data = self.read()?;
270        self.packet.unpack(data, 0)?;
271
272        self.stream
273            .write_all(&self.packet.pack_auth_verify())
274            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
275
276        let data = self.read()?;
277        self.packet.unpack(data, 0)?;
278        Ok(())
279    }
280
281    fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
282        let mut msg = Vec::new();
283        let mut buf = [0u8; 4096];
284        let mut retry_count = 0;
285
286        #[cfg(not(test))]
287        const MAX_RETRIES: u32 = 100;
288        #[cfg(test)]
289        const MAX_RETRIES: u32 = 3;
290
291        #[cfg(not(test))]
292        const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
293        #[cfg(test)]
294        const MAX_MESSAGE_SIZE: usize = 128;
295
296        #[cfg(not(test))]
297        let deadline = std::time::Instant::now() + Duration::from_secs(300);
298        #[cfg(test)]
299        let deadline = std::time::Instant::now() + Duration::from_millis(200);
300
301        loop {
302            if std::time::Instant::now() >= deadline {
303                return Err(PgsqlError::Timeout("读取总超时".into()));
304            }
305
306            match self.stream.read(&mut buf) {
307                Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
308                Ok(n) => {
309                    if msg.len() + n > MAX_MESSAGE_SIZE {
310                        return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
311                    }
312                    msg.extend_from_slice(&buf[..n]);
313                    retry_count = 0;
314                }
315                Err(ref e)
316                    if e.kind() == std::io::ErrorKind::WouldBlock
317                        || e.kind() == std::io::ErrorKind::TimedOut =>
318                {
319                    retry_count += 1;
320                    if retry_count > MAX_RETRIES {
321                        return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
322                    }
323                    std::thread::sleep(Duration::from_millis(10));
324                    continue;
325                }
326                Err(e) => return Err(PgsqlError::Io(e)),
327            };
328
329            if let AuthStatus::AuthenticationOk = self.auth_status {
330                if msg.ends_with(&[90, 0, 0, 0, 5, 73])
331                    || msg.ends_with(&[90, 0, 0, 0, 5, 84])
332                    || msg.ends_with(&[90, 0, 0, 0, 5, 69])
333                {
334                    break;
335                }
336            } else if msg.len() >= 5 {
337                let len_bytes = &msg[1..=4];
338                if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
339                    if msg.len() > len as usize {
340                        break;
341                    }
342                }
343            }
344        }
345
346        Ok(msg)
347    }
348
349    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
350        self.stream
351            .write_all(&self.packet.pack_query(sql))
352            .map_err(PgsqlError::Io)?;
353        let data = self.read()?;
354        self.last_used = Instant::now();
355        self.packet.unpack(data, 0)
356    }
357
358    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
359        self.stream
360            .write_all(&self.packet.pack_execute(sql))
361            .map_err(PgsqlError::Io)?;
362        let data = self.read()?;
363        self.last_used = Instant::now();
364        self.packet.unpack(data, 0)
365    }
366
367    /// 参数化查询
368    pub fn query_params(
369        &mut self,
370        sql: &str,
371        params: &[Option<&str>],
372    ) -> Result<SuccessMessage, PgsqlError> {
373        self.stream
374            .write_all(&self.packet.pack_query_params(sql, params))
375            .map_err(PgsqlError::Io)?;
376
377        let data = self.read()?;
378        self.last_used = Instant::now();
379        self.packet.unpack(data, 0)
380    }
381
382    /// 参数化执行
383    pub fn execute_params(
384        &mut self,
385        sql: &str,
386        params: &[Option<&str>],
387    ) -> Result<SuccessMessage, PgsqlError> {
388        self.stream
389            .write_all(&self.packet.pack_execute_params(sql, params))
390            .map_err(PgsqlError::Io)?;
391        let data = self.read()?;
392        self.last_used = Instant::now();
393        self.packet.unpack(data, 0)
394    }
395
396    /// 参数化查询(便捷版,所有参数非 NULL)
397    pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
398        let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
399        self.query_params(sql, &opts)
400    }
401
402    /// 参数化执行(便捷版,所有参数非 NULL)
403    pub fn execute_str(
404        &mut self,
405        sql: &str,
406        params: &[&str],
407    ) -> Result<SuccessMessage, PgsqlError> {
408        let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
409        self.execute_params(sql, &opts)
410    }
411    // ── Portal/Cursor API ──────────────────────────────────────────────
412    /// Portal 查询:分批获取结果,max_rows 指定每批行数
413    pub fn query_portal(&mut self, sql: &str, max_rows: u32) -> Result<SuccessMessage, PgsqlError> {
414        self.stream
415            .write_all(&self.packet.pack_query_portal(sql, max_rows))
416            .map_err(PgsqlError::Io)?;
417        let data = self.read()?;
418        self.last_used = Instant::now();
419        self.packet.unpack(data, 0)
420    }
421    /// 从已打开的 Portal 继续获取更多行
422    pub fn fetch_more(&mut self, max_rows: u32) -> Result<SuccessMessage, PgsqlError> {
423        self.stream
424            .write_all(&self.packet.pack_fetch_more(max_rows))
425            .map_err(PgsqlError::Io)?;
426        let data = self.read()?;
427        self.last_used = Instant::now();
428        self.packet.unpack(data, 0)
429    }
430    /// 关闭当前 Portal
431    pub fn close_portal(&mut self) -> Result<SuccessMessage, PgsqlError> {
432        self.stream
433            .write_all(&self.packet.pack_close_portal())
434            .map_err(PgsqlError::Io)?;
435        let data = self.read()?;
436        self.last_used = Instant::now();
437        self.packet.unpack(data, 0)
438    }
439}
440
441impl Drop for Connect {
442    fn drop(&mut self) {
443        let _ = self.stream.write_all(&Packet::pack_terminate());
444        let _ = self.stream.shutdown(std::net::Shutdown::Both);
445    }
446}
447
448#[cfg(test)]
449mod tests {
450    use super::*;
451    use std::net::TcpListener;
452    use std::thread;
453
454    // ── wire-protocol helpers ──────────────────────────────────────────
455
456    /// Build a single PG backend message: type_byte | len(4) | payload
457    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
458        let mut m = Vec::with_capacity(5 + payload.len());
459        m.push(tag);
460        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
461        m.extend_from_slice(payload);
462        m
463    }
464
465    /// Build an Authentication message (tag 'R') with given auth_type + extra payload
466    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
467        let mut body = Vec::new();
468        body.extend(&auth_type.to_be_bytes());
469        body.extend_from_slice(extra);
470        pg_msg(b'R', &body)
471    }
472
473    /// AuthenticationOk (R, type=0)
474    fn auth_ok() -> Vec<u8> {
475        pg_auth(0, &[])
476    }
477
478    /// ParameterStatus for server_version=15.0
479    fn param_status() -> Vec<u8> {
480        pg_msg(b'S', b"server_version\x0015.0\x00")
481    }
482
483    /// BackendKeyData (process_id=1, secret_key=2)
484    fn backend_key() -> Vec<u8> {
485        let mut p = Vec::new();
486        p.extend(&1u32.to_be_bytes());
487        p.extend(&2u32.to_be_bytes());
488        pg_msg(b'K', &p)
489    }
490
491    /// ReadyForQuery (status = Idle)
492    fn ready_for_query() -> Vec<u8> {
493        pg_msg(b'Z', b"I")
494    }
495
496    /// The standard tail sent after auth succeeds: AuthOk + param + key + ready
497    fn post_auth_ok() -> Vec<u8> {
498        let mut v = Vec::new();
499        v.extend(auth_ok());
500        v.extend(param_status());
501        v.extend(backend_key());
502        v.extend(ready_for_query());
503        v
504    }
505
506    /// Build a simple query response: ParseComplete + BindComplete +
507    /// RowDescription(1 int4 col "c") + DataRow("1") + CommandComplete("SELECT 1") + ReadyForQuery
508    fn simple_query_response() -> Vec<u8> {
509        let mut r = Vec::new();
510        // ParseComplete
511        r.extend(pg_msg(b'1', &[]));
512        // BindComplete
513        r.extend(pg_msg(b'2', &[]));
514        // RowDescription – 1 field "c", type_oid=23 (int4)
515        let mut rd = Vec::new();
516        rd.extend(&1u16.to_be_bytes()); // field count
517        rd.extend(b"c\x00"); // name
518        rd.extend(&0u32.to_be_bytes()); // table oid
519        rd.extend(&1u16.to_be_bytes()); // column index
520        rd.extend(&23u32.to_be_bytes()); // type oid (int4)
521        rd.extend(&4i16.to_be_bytes()); // column length
522        rd.extend(&(-1i32).to_be_bytes()); // type modifier
523        rd.extend(&0u16.to_be_bytes()); // format (text)
524        r.extend(pg_msg(b'T', &rd));
525        // DataRow – 1 field, value "1"
526        let mut dr = Vec::new();
527        dr.extend(&1u16.to_be_bytes());
528        dr.extend(&1u32.to_be_bytes()); // length of value
529        dr.push(b'1');
530        r.extend(pg_msg(b'D', &dr));
531        // CommandComplete
532        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
533        // ReadyForQuery
534        r.extend(ready_for_query());
535        r
536    }
537
538    /// Build an execute response (no rows): ParseComplete + BindComplete +
539    /// NoData + CommandComplete("UPDATE 3") + ReadyForQuery
540    fn execute_response() -> Vec<u8> {
541        let mut r = Vec::new();
542        r.extend(pg_msg(b'1', &[]));
543        r.extend(pg_msg(b'2', &[]));
544        r.extend(pg_msg(b'n', &[])); // NoData
545        r.extend(pg_msg(b'C', b"UPDATE 3\x00"));
546        r.extend(ready_for_query());
547        r
548    }
549
550    /// Build a parameterized query response: ParseComplete + ParameterDescription + BindComplete +
551    /// RowDescription(1 int4 col "p") + DataRow("42") + CommandComplete("SELECT 1") + ReadyForQuery
552    fn query_params_response() -> Vec<u8> {
553        let mut r = Vec::new();
554        r.extend(pg_msg(b'1', &[]));
555
556        let mut pd = Vec::new();
557        pd.extend(&1u16.to_be_bytes());
558        pd.extend(&23u32.to_be_bytes());
559        r.extend(pg_msg(b't', &pd));
560
561        r.extend(pg_msg(b'2', &[]));
562
563        let mut rd = Vec::new();
564        rd.extend(&1u16.to_be_bytes());
565        rd.extend(b"p\x00");
566        rd.extend(&0u32.to_be_bytes());
567        rd.extend(&1u16.to_be_bytes());
568        rd.extend(&23u32.to_be_bytes());
569        rd.extend(&4i16.to_be_bytes());
570        rd.extend(&(-1i32).to_be_bytes());
571        rd.extend(&0u16.to_be_bytes());
572        r.extend(pg_msg(b'T', &rd));
573
574        let mut dr = Vec::new();
575        dr.extend(&1u16.to_be_bytes());
576        dr.extend(&2u32.to_be_bytes());
577        dr.extend(b"42");
578        r.extend(pg_msg(b'D', &dr));
579
580        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
581        r.extend(ready_for_query());
582        r
583    }
584
585    /// Build a parameterized execute response: ParseComplete + ParameterDescription + BindComplete +
586    /// NoData + CommandComplete("UPDATE 1") + ReadyForQuery
587    fn execute_params_response() -> Vec<u8> {
588        let mut r = Vec::new();
589        r.extend(pg_msg(b'1', &[]));
590
591        let mut pd = Vec::new();
592        pd.extend(&1u16.to_be_bytes());
593        pd.extend(&23u32.to_be_bytes());
594        r.extend(pg_msg(b't', &pd));
595
596        r.extend(pg_msg(b'2', &[]));
597        r.extend(pg_msg(b'n', &[]));
598        r.extend(pg_msg(b'C', b"UPDATE 1\x00"));
599        r.extend(ready_for_query());
600        r
601    }
602
603    /// Build a parameterized query response with NULL row value.
604    fn query_params_null_response() -> Vec<u8> {
605        let mut r = Vec::new();
606        r.extend(pg_msg(b'1', &[]));
607
608        let mut pd = Vec::new();
609        pd.extend(&1u16.to_be_bytes());
610        pd.extend(&25u32.to_be_bytes());
611        r.extend(pg_msg(b't', &pd));
612
613        r.extend(pg_msg(b'2', &[]));
614
615        let mut rd = Vec::new();
616        rd.extend(&1u16.to_be_bytes());
617        rd.extend(b"n\x00");
618        rd.extend(&0u32.to_be_bytes());
619        rd.extend(&1u16.to_be_bytes());
620        rd.extend(&25u32.to_be_bytes());
621        rd.extend(&(-1i16).to_be_bytes());
622        rd.extend(&(-1i32).to_be_bytes());
623        rd.extend(&0u16.to_be_bytes());
624        r.extend(pg_msg(b'T', &rd));
625
626        let mut dr = Vec::new();
627        dr.extend(&1u16.to_be_bytes());
628        dr.extend(&(-1i32).to_be_bytes());
629        r.extend(pg_msg(b'D', &dr));
630
631        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
632        r.extend(ready_for_query());
633        r
634    }
635
636    /// Build an ErrorResponse for query phase
637    fn error_response() -> Vec<u8> {
638        let mut payload = Vec::new();
639        payload.push(b'C');
640        payload.extend(b"42601\x00");
641        payload.push(b'M');
642        payload.extend(b"syntax error\x00");
643        payload.push(0);
644        let mut r = Vec::new();
645        r.extend(pg_msg(b'1', &[]));
646        r.extend(pg_msg(b'2', &[]));
647        r.extend(pg_msg(b'E', &payload));
648        r.extend(ready_for_query());
649        r
650    }
651
652    // ── mock server spawners ───────────────────────────────────────────
653
654    /// Config pointing at 127.0.0.1:<port>
655    fn mock_config(port: u16) -> Config {
656        Config {
657            debug: false,
658            hostname: "127.0.0.1".into(),
659            hostport: port as i32,
660            username: "u".into(),
661            userpass: "p".into(),
662            database: "d".into(),
663            charset: "utf8".into(),
664            pool_max: 5,
665            sslmode: "disable".into(),
666        }
667    }
668
669    /// Spawn a mock PG server that does **cleartext** auth.
670    /// Returns the port.  The server handles one connection.
671    fn spawn_cleartext_server() -> u16 {
672        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
673        let port = listener.local_addr().unwrap().port();
674        thread::spawn(move || {
675            let (mut s, _) = listener.accept().unwrap();
676            let mut buf = [0u8; 4096];
677            // 1. read startup
678            let _ = s.read(&mut buf).unwrap();
679            // 2. send CleartextPassword request (auth_type=3)
680            let _ = s.write_all(&pg_auth(3, &[]));
681            // 3. read password message
682            let _ = s.read(&mut buf).unwrap();
683            // 4. send AuthOk + params + ready
684            let _ = s.write_all(&post_auth_ok());
685            // keep connection alive for queries
686            loop {
687                match s.read(&mut buf) {
688                    Ok(0) | Err(_) => break,
689                    Ok(_) => {
690                        let _ = s.write_all(&simple_query_response());
691                    }
692                }
693            }
694        });
695        thread::sleep(Duration::from_millis(30));
696        port
697    }
698
699    /// Spawn a mock PG server that does **MD5** auth.
700    fn spawn_md5_server() -> u16 {
701        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
702        let port = listener.local_addr().unwrap().port();
703        thread::spawn(move || {
704            let (mut s, _) = listener.accept().unwrap();
705            let mut buf = [0u8; 4096];
706            // 1. read startup
707            let _ = s.read(&mut buf).unwrap();
708            // 2. send MD5Password request (auth_type=5) + 4-byte salt
709            let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
710            // 3. read md5 password
711            let _ = s.read(&mut buf).unwrap();
712            // 4. send AuthOk + params + ready
713            let _ = s.write_all(&post_auth_ok());
714            loop {
715                match s.read(&mut buf) {
716                    Ok(0) | Err(_) => break,
717                    Ok(_) => {
718                        let _ = s.write_all(&simple_query_response());
719                    }
720                }
721            }
722        });
723        thread::sleep(Duration::from_millis(30));
724        port
725    }
726
727    /// Spawn a mock PG server that does **SCRAM-SHA-256** auth.
728    fn spawn_scram_server() -> u16 {
729        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
730        let port = listener.local_addr().unwrap().port();
731        thread::spawn(move || {
732            let (mut s, _) = listener.accept().unwrap();
733            let mut buf = [0u8; 4096];
734            // 1. read startup
735            let _ = s.read(&mut buf).unwrap();
736            // 2. send SASL auth (auth_type=10, mechanism)
737            let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
738            // 3. read SASLInitialResponse – extract client nonce
739            let n = s.read(&mut buf).unwrap();
740            let payload = &buf[..n];
741            // find "n=,r=" in the payload to extract client nonce
742            let text = String::from_utf8_lossy(payload);
743            let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
744            // 4. send SCRAM challenge (auth_type=11)
745            let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
746            let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
747            // 5. read SCRAM client final
748            let _ = s.read(&mut buf).unwrap();
749            // 6. send SCRAM complete (auth_type=12) + AuthOk + params + ready
750            let mut resp = Vec::new();
751            resp.extend(pg_auth(12, b"v=dummyproof"));
752            resp.extend(auth_ok());
753            resp.extend(param_status());
754            resp.extend(backend_key());
755            resp.extend(ready_for_query());
756            let _ = s.write_all(&resp);
757            loop {
758                match s.read(&mut buf) {
759                    Ok(0) | Err(_) => break,
760                    Ok(_) => {
761                        let _ = s.write_all(&simple_query_response());
762                    }
763                }
764            }
765        });
766        thread::sleep(Duration::from_millis(30));
767        port
768    }
769
770    /// Spawn a server that accepts connection then immediately closes (EOF).
771    fn spawn_eof_server() -> u16 {
772        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
773        let port = listener.local_addr().unwrap().port();
774        thread::spawn(move || {
775            let (s, _) = listener.accept().unwrap();
776            drop(s); // close immediately
777        });
778        thread::sleep(Duration::from_millis(30));
779        port
780    }
781
782    /// Spawn a server that sends an ErrorResponse after startup.
783    fn spawn_auth_error_server() -> u16 {
784        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
785        let port = listener.local_addr().unwrap().port();
786        thread::spawn(move || {
787            let (mut s, _) = listener.accept().unwrap();
788            let mut buf = [0u8; 4096];
789            let _ = s.read(&mut buf).unwrap();
790            // Send ErrorResponse
791            let mut payload = Vec::new();
792            payload.push(b'C');
793            payload.extend(b"28P01\x00");
794            payload.push(b'M');
795            payload.extend(b"password authentication failed\x00");
796            payload.push(0);
797            let _ = s.write_all(&pg_msg(b'E', &payload));
798        });
799        thread::sleep(Duration::from_millis(30));
800        port
801    }
802
803    /// Spawn a cleartext server that responds with error on query.
804    fn spawn_query_error_server() -> u16 {
805        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
806        let port = listener.local_addr().unwrap().port();
807        thread::spawn(move || {
808            let (mut s, _) = listener.accept().unwrap();
809            let mut buf = [0u8; 4096];
810            // auth
811            let _ = s.read(&mut buf).unwrap();
812            let _ = s.write_all(&pg_auth(3, &[]));
813            let _ = s.read(&mut buf).unwrap();
814            let _ = s.write_all(&post_auth_ok());
815            // query → error
816            loop {
817                match s.read(&mut buf) {
818                    Ok(0) | Err(_) => break,
819                    Ok(_) => {
820                        let _ = s.write_all(&error_response());
821                    }
822                }
823            }
824        });
825        thread::sleep(Duration::from_millis(30));
826        port
827    }
828
829    /// Spawn a cleartext server that responds with execute response.
830    fn spawn_execute_server() -> u16 {
831        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
832        let port = listener.local_addr().unwrap().port();
833        thread::spawn(move || {
834            let (mut s, _) = listener.accept().unwrap();
835            let mut buf = [0u8; 4096];
836            // auth
837            let _ = s.read(&mut buf).unwrap();
838            let _ = s.write_all(&pg_auth(3, &[]));
839            let _ = s.read(&mut buf).unwrap();
840            let _ = s.write_all(&post_auth_ok());
841            // queries
842            loop {
843                match s.read(&mut buf) {
844                    Ok(0) | Err(_) => break,
845                    Ok(_) => {
846                        let _ = s.write_all(&execute_response());
847                    }
848                }
849            }
850        });
851        thread::sleep(Duration::from_millis(30));
852        port
853    }
854
855    /// Spawn a cleartext server that responds with parameterized query response.
856    fn spawn_query_params_server() -> u16 {
857        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
858        let port = listener.local_addr().unwrap().port();
859        thread::spawn(move || {
860            let (mut s, _) = listener.accept().unwrap();
861            let mut buf = [0u8; 4096];
862            let _ = s.read(&mut buf).unwrap();
863            let _ = s.write_all(&pg_auth(3, &[]));
864            let _ = s.read(&mut buf).unwrap();
865            let _ = s.write_all(&post_auth_ok());
866            loop {
867                match s.read(&mut buf) {
868                    Ok(0) | Err(_) => break,
869                    Ok(_) => {
870                        let _ = s.write_all(&query_params_response());
871                    }
872                }
873            }
874        });
875        thread::sleep(Duration::from_millis(30));
876        port
877    }
878
879    fn spawn_params_server() -> u16 {
880        spawn_query_params_server()
881    }
882
883    /// Spawn a cleartext server that responds with parameterized execute response.
884    fn spawn_execute_params_server() -> u16 {
885        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
886        let port = listener.local_addr().unwrap().port();
887        thread::spawn(move || {
888            let (mut s, _) = listener.accept().unwrap();
889            let mut buf = [0u8; 4096];
890            let _ = s.read(&mut buf).unwrap();
891            let _ = s.write_all(&pg_auth(3, &[]));
892            let _ = s.read(&mut buf).unwrap();
893            let _ = s.write_all(&post_auth_ok());
894            loop {
895                match s.read(&mut buf) {
896                    Ok(0) | Err(_) => break,
897                    Ok(_) => {
898                        let _ = s.write_all(&execute_params_response());
899                    }
900                }
901            }
902        });
903        thread::sleep(Duration::from_millis(30));
904        port
905    }
906
907    /// Spawn a cleartext server that returns NULL in parameterized query result.
908    fn spawn_query_params_null_server() -> u16 {
909        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
910        let port = listener.local_addr().unwrap().port();
911        thread::spawn(move || {
912            let (mut s, _) = listener.accept().unwrap();
913            let mut buf = [0u8; 4096];
914            let _ = s.read(&mut buf).unwrap();
915            let _ = s.write_all(&pg_auth(3, &[]));
916            let _ = s.read(&mut buf).unwrap();
917            let _ = s.write_all(&post_auth_ok());
918            loop {
919                match s.read(&mut buf) {
920                    Ok(0) | Err(_) => break,
921                    Ok(_) => {
922                        let _ = s.write_all(&query_params_null_response());
923                    }
924                }
925            }
926        });
927        thread::sleep(Duration::from_millis(30));
928        port
929    }
930
931    // ── tests ──────────────────────────────────────────────────────────
932
933    #[test]
934    fn connect_cleartext_auth_success() {
935        let port = spawn_cleartext_server();
936        let conn = Connect::new(mock_config(port));
937        assert!(conn.is_ok());
938    }
939
940    #[test]
941    fn connect_md5_auth_success() {
942        let port = spawn_md5_server();
943        let conn = Connect::new(mock_config(port));
944        assert!(conn.is_ok());
945    }
946
947    #[test]
948    fn connect_scram_auth_success() {
949        let port = spawn_scram_server();
950        let conn = Connect::new(mock_config(port));
951        assert!(conn.is_ok());
952    }
953
954    #[test]
955    fn connect_connection_refused() {
956        // port 1 is almost certainly not listening
957        let cfg = mock_config(1);
958        let result = Connect::new(cfg);
959        assert!(result.is_err());
960        match result.unwrap_err() {
961            PgsqlError::Connection(_) => {}
962            other => panic!("expected Connection error, got {other:?}"),
963        }
964    }
965
966    #[test]
967    fn connect_server_closes_immediately() {
968        let port = spawn_eof_server();
969        let result = Connect::new(mock_config(port));
970        assert!(result.is_err());
971    }
972
973    #[test]
974    fn connect_auth_error_from_server() {
975        let port = spawn_auth_error_server();
976        let result = Connect::new(mock_config(port));
977        assert!(result.is_err());
978    }
979
980    #[test]
981    fn connect_query_success() {
982        let port = spawn_cleartext_server();
983        let mut conn = Connect::new(mock_config(port)).unwrap();
984        let result = conn.query("SELECT 1");
985        assert!(result.is_ok());
986        let msg = result.unwrap();
987        assert_eq!(msg.rows.len(), 1);
988        assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
989    }
990
991    #[test]
992    fn connect_execute_success() {
993        let port = spawn_execute_server();
994        let mut conn = Connect::new(mock_config(port)).unwrap();
995        let result = conn.execute("UPDATE t SET x=1");
996        assert!(result.is_ok());
997        let msg = result.unwrap();
998        assert_eq!(msg.affect_count, 3);
999        assert_eq!(msg.tag, "UPDATE 3");
1000    }
1001
1002    #[test]
1003    fn connect_query_params_success() {
1004        let port = spawn_query_params_server();
1005        let mut conn = Connect::new(mock_config(port)).unwrap();
1006        let result = conn.query_params("SELECT $1::int", &[Some("42")]);
1007        assert!(result.is_ok());
1008        let msg = result.unwrap();
1009        assert!(!msg.param_oids.is_empty());
1010        assert_eq!(msg.rows.len(), 1);
1011        assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
1012    }
1013
1014    #[test]
1015    fn connect_execute_params_success() {
1016        let port = spawn_execute_params_server();
1017        let mut conn = Connect::new(mock_config(port)).unwrap();
1018        let result = conn.execute_params("UPDATE t SET x=$1", &[Some("42")]);
1019        assert!(result.is_ok());
1020        let msg = result.unwrap();
1021        assert!(!msg.param_oids.is_empty());
1022        assert_eq!(msg.affect_count, 1);
1023        assert_eq!(msg.tag, "UPDATE 1");
1024    }
1025
1026    #[test]
1027    fn connect_query_str_success() {
1028        let port = spawn_params_server();
1029        let mut conn = Connect::new(mock_config(port)).unwrap();
1030        let result = conn.query_str("SELECT $1::int", &["42"]);
1031        assert!(result.is_ok());
1032        let msg = result.unwrap();
1033        assert!(!msg.param_oids.is_empty());
1034        assert_eq!(msg.rows.len(), 1);
1035    }
1036
1037    #[test]
1038    fn connect_execute_str_success() {
1039        let port = spawn_execute_params_server();
1040        let mut conn = Connect::new(mock_config(port)).unwrap();
1041        let result = conn.execute_str("UPDATE t SET x=$1", &["1"]);
1042        assert!(result.is_ok());
1043        let msg = result.unwrap();
1044        assert!(!msg.param_oids.is_empty());
1045        assert_eq!(msg.affect_count, 1);
1046    }
1047
1048    #[test]
1049    fn connect_query_params_with_null() {
1050        let port = spawn_query_params_null_server();
1051        let mut conn = Connect::new(mock_config(port)).unwrap();
1052        let result = conn.query_params("SELECT $1::text", &[None]);
1053        assert!(result.is_ok());
1054        let msg = result.unwrap();
1055        assert!(!msg.param_oids.is_empty());
1056        assert_eq!(msg.rows.len(), 1);
1057        assert_eq!(msg.rows[0]["n"], "");
1058    }
1059
1060    #[test]
1061    fn connect_query_params_empty_string_vs_null() {
1062        let port = spawn_params_server();
1063        let mut conn = Connect::new(mock_config(port)).unwrap();
1064
1065        // 空字符串参数
1066        let r1 = conn.query_params("SELECT $1::text", &[Some("")]);
1067        assert!(r1.is_ok());
1068
1069        // NULL 参数
1070        let r2 = conn.query_params("SELECT $1::text", &[None]);
1071        assert!(r2.is_ok());
1072    }
1073
1074    #[test]
1075    fn connect_query_returns_error() {
1076        let port = spawn_query_error_server();
1077        let mut conn = Connect::new(mock_config(port)).unwrap();
1078        let result = conn.query("BAD SQL");
1079        assert!(result.is_err());
1080    }
1081
1082    #[test]
1083    fn connect_is_valid_true() {
1084        let port = spawn_cleartext_server();
1085        let mut conn = Connect::new(mock_config(port)).unwrap();
1086        assert!(conn.is_valid());
1087    }
1088
1089    #[test]
1090    fn connect_is_valid_false_after_close() {
1091        let port = spawn_cleartext_server();
1092        let mut conn = Connect::new(mock_config(port)).unwrap();
1093        conn._close();
1094        // After closing, is_valid should return false
1095        assert!(!conn.is_valid());
1096    }
1097
1098    #[test]
1099    fn connect_close_does_not_panic() {
1100        let port = spawn_cleartext_server();
1101        let mut conn = Connect::new(mock_config(port)).unwrap();
1102        conn._close();
1103        // calling close again should not panic
1104        conn._close();
1105    }
1106
1107    #[test]
1108    fn connect_drop_does_not_panic() {
1109        let port = spawn_cleartext_server();
1110        let conn = Connect::new(mock_config(port)).unwrap();
1111        drop(conn);
1112    }
1113
1114    fn spawn_transaction_status_server() -> u16 {
1115        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1116        let port = listener.local_addr().unwrap().port();
1117        thread::spawn(move || {
1118            let (mut s, _) = listener.accept().unwrap();
1119            let mut buf = [0u8; 4096];
1120            let _ = s.read(&mut buf).unwrap();
1121            let _ = s.write_all(&pg_auth(3, &[]));
1122            let _ = s.read(&mut buf).unwrap();
1123            let _ = s.write_all(&post_auth_ok());
1124            loop {
1125                match s.read(&mut buf) {
1126                    Ok(0) | Err(_) => break,
1127                    Ok(_) => {
1128                        let mut r = Vec::new();
1129                        r.extend(pg_msg(b'1', &[]));
1130                        r.extend(pg_msg(b'2', &[]));
1131                        let mut rd = Vec::new();
1132                        rd.extend(&1u16.to_be_bytes());
1133                        rd.extend(b"c\x00");
1134                        rd.extend(&0u32.to_be_bytes());
1135                        rd.extend(&1u16.to_be_bytes());
1136                        rd.extend(&23u32.to_be_bytes());
1137                        rd.extend(&4i16.to_be_bytes());
1138                        rd.extend(&(-1i32).to_be_bytes());
1139                        rd.extend(&0u16.to_be_bytes());
1140                        r.extend(pg_msg(b'T', &rd));
1141                        let mut dr = Vec::new();
1142                        dr.extend(&1u16.to_be_bytes());
1143                        dr.extend(&1u32.to_be_bytes());
1144                        dr.push(b'1');
1145                        r.extend(pg_msg(b'D', &dr));
1146                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
1147                        r.extend(pg_msg(b'Z', b"T"));
1148                        let _ = s.write_all(&r);
1149                    }
1150                }
1151            }
1152        });
1153        thread::sleep(Duration::from_millis(30));
1154        port
1155    }
1156
1157    fn spawn_error_status_server() -> u16 {
1158        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1159        let port = listener.local_addr().unwrap().port();
1160        thread::spawn(move || {
1161            let (mut s, _) = listener.accept().unwrap();
1162            let mut buf = [0u8; 4096];
1163            let _ = s.read(&mut buf).unwrap();
1164            let _ = s.write_all(&pg_auth(3, &[]));
1165            let _ = s.read(&mut buf).unwrap();
1166            let _ = s.write_all(&post_auth_ok());
1167            loop {
1168                match s.read(&mut buf) {
1169                    Ok(0) | Err(_) => break,
1170                    Ok(_) => {
1171                        let mut r = Vec::new();
1172                        r.extend(pg_msg(b'1', &[]));
1173                        r.extend(pg_msg(b'2', &[]));
1174                        let mut rd = Vec::new();
1175                        rd.extend(&1u16.to_be_bytes());
1176                        rd.extend(b"c\x00");
1177                        rd.extend(&0u32.to_be_bytes());
1178                        rd.extend(&1u16.to_be_bytes());
1179                        rd.extend(&23u32.to_be_bytes());
1180                        rd.extend(&4i16.to_be_bytes());
1181                        rd.extend(&(-1i32).to_be_bytes());
1182                        rd.extend(&0u16.to_be_bytes());
1183                        r.extend(pg_msg(b'T', &rd));
1184                        let mut dr = Vec::new();
1185                        dr.extend(&1u16.to_be_bytes());
1186                        dr.extend(&1u32.to_be_bytes());
1187                        dr.push(b'1');
1188                        r.extend(pg_msg(b'D', &dr));
1189                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
1190                        r.extend(pg_msg(b'Z', b"E"));
1191                        let _ = s.write_all(&r);
1192                    }
1193                }
1194            }
1195        });
1196        thread::sleep(Duration::from_millis(30));
1197        port
1198    }
1199
1200    fn spawn_slow_partial_server() -> u16 {
1201        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1202        let port = listener.local_addr().unwrap().port();
1203        thread::spawn(move || {
1204            let (mut s, _) = listener.accept().unwrap();
1205            let mut buf = [0u8; 4096];
1206            let _ = s.read(&mut buf).unwrap();
1207            let _ = s.write_all(&pg_auth(3, &[]));
1208            let _ = s.read(&mut buf).unwrap();
1209            let _ = s.write_all(&post_auth_ok());
1210            match s.read(&mut buf) {
1211                Ok(0) | Err(_) => {}
1212                Ok(_) => {
1213                    let _ = s.write_all(&simple_query_response());
1214                }
1215            }
1216        });
1217        thread::sleep(Duration::from_millis(30));
1218        port
1219    }
1220
1221    fn spawn_rst_on_query_server() -> u16 {
1222        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1223        let port = listener.local_addr().unwrap().port();
1224        thread::spawn(move || {
1225            let (mut s, _) = listener.accept().unwrap();
1226            let mut buf = [0u8; 4096];
1227            let _ = s.read(&mut buf).unwrap();
1228            let _ = s.write_all(&pg_auth(3, &[]));
1229            let _ = s.read(&mut buf).unwrap();
1230            let _ = s.write_all(&post_auth_ok());
1231            match s.read(&mut buf) {
1232                Ok(0) | Err(_) => {}
1233                Ok(_) => {
1234                    drop(s);
1235                }
1236            }
1237        });
1238        thread::sleep(Duration::from_millis(30));
1239        port
1240    }
1241
1242    #[test]
1243    fn connect_query_ready_for_query_transaction_status() {
1244        let port = spawn_transaction_status_server();
1245        let mut conn = Connect::new(mock_config(port)).unwrap();
1246        let result = conn.query("SELECT 1");
1247        assert!(result.is_ok());
1248    }
1249
1250    #[test]
1251    fn connect_query_ready_for_query_error_status() {
1252        let port = spawn_error_status_server();
1253        let mut conn = Connect::new(mock_config(port)).unwrap();
1254        let result = conn.query("SELECT 1");
1255        assert!(result.is_ok());
1256    }
1257
1258    #[test]
1259    fn connect_query_server_closes_after_partial() {
1260        let port = spawn_slow_partial_server();
1261        let mut conn = Connect::new(mock_config(port)).unwrap();
1262        let r1 = conn.query("SELECT 1");
1263        assert!(r1.is_ok());
1264        let r2 = conn.query("SELECT 1");
1265        assert!(r2.is_err());
1266    }
1267
1268    #[test]
1269    fn connect_query_server_rst_returns_io_or_connection_error() {
1270        let port = spawn_rst_on_query_server();
1271        let mut conn = Connect::new(mock_config(port)).unwrap();
1272        let result = conn.query("SELECT 1");
1273        assert!(result.is_err());
1274    }
1275
1276    #[test]
1277    fn connect_read_would_block_max_retries() {
1278        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1279        let port = listener.local_addr().unwrap().port();
1280        thread::spawn(move || {
1281            let (mut s, _) = listener.accept().unwrap();
1282            let mut buf = [0u8; 4096];
1283            let _ = s.read(&mut buf);
1284            let _ = s.write_all(&pg_auth(3, &[]));
1285            let _ = s.read(&mut buf);
1286            let _ = s.write_all(&post_auth_ok());
1287            let _ = s.read(&mut buf);
1288            thread::sleep(Duration::from_secs(5));
1289        });
1290        thread::sleep(Duration::from_millis(30));
1291
1292        let mut conn = Connect::new(mock_config(port)).unwrap();
1293        conn.stream
1294            .set_read_timeout(Some(Duration::from_millis(1)))
1295            .ok();
1296        let result = conn.query("SELECT 1");
1297        assert!(result.is_err());
1298        let err_str = result.unwrap_err().to_string();
1299        assert!(
1300            err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
1301            "expected timeout error, got: {err_str}"
1302        );
1303    }
1304
1305    #[test]
1306    fn connect_read_exceeds_max_message_size() {
1307        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1308        let port = listener.local_addr().unwrap().port();
1309        thread::spawn(move || {
1310            let (mut s, _) = listener.accept().unwrap();
1311            let mut buf = [0u8; 4096];
1312            let _ = s.read(&mut buf);
1313            let _ = s.write_all(&pg_auth(3, &[]));
1314            let _ = s.read(&mut buf);
1315            let _ = s.write_all(&post_auth_ok());
1316            let _ = s.read(&mut buf);
1317            let big = vec![b'X'; 256];
1318            let _ = s.write_all(&big);
1319            thread::sleep(Duration::from_secs(2));
1320        });
1321        thread::sleep(Duration::from_millis(30));
1322
1323        let mut conn = Connect::new(mock_config(port)).unwrap();
1324        let result = conn.query("SELECT 1");
1325        assert!(result.is_err());
1326        let err_str = result.unwrap_err().to_string();
1327        assert!(
1328            err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
1329            "expected max message size error, got: {err_str}"
1330        );
1331    }
1332
1333    #[test]
1334    fn connect_read_deadline_timeout() {
1335        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1336        let port = listener.local_addr().unwrap().port();
1337        thread::spawn(move || {
1338            let (mut s, _) = listener.accept().unwrap();
1339            let mut buf = [0u8; 4096];
1340            let _ = s.read(&mut buf);
1341            let _ = s.write_all(&pg_auth(3, &[]));
1342            let _ = s.read(&mut buf);
1343            let _ = s.write_all(&post_auth_ok());
1344            let _ = s.read(&mut buf);
1345            for _ in 0..200 {
1346                let _ = s.write_all(b"X");
1347                thread::sleep(Duration::from_millis(5));
1348            }
1349        });
1350        thread::sleep(Duration::from_millis(30));
1351
1352        let mut conn = Connect::new(mock_config(port)).unwrap();
1353        let result = conn.query("SELECT 1");
1354        assert!(result.is_err());
1355    }
1356
1357    #[test]
1358    fn connect_read_partial_auth_frame() {
1359        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1360        let port = listener.local_addr().unwrap().port();
1361        thread::spawn(move || {
1362            let (mut s, _) = listener.accept().unwrap();
1363            let mut buf = [0u8; 4096];
1364            let _ = s.read(&mut buf);
1365            let auth = pg_auth(3, &[]);
1366            let _ = s.write_all(&auth[..5]);
1367            thread::sleep(Duration::from_millis(50));
1368            let _ = s.write_all(&auth[5..]);
1369            let _ = s.read(&mut buf);
1370            let _ = s.write_all(&post_auth_ok());
1371            loop {
1372                match s.read(&mut buf) {
1373                    Ok(0) | Err(_) => break,
1374                    Ok(_) => {
1375                        let _ = s.write_all(&simple_query_response());
1376                    }
1377                }
1378            }
1379        });
1380        thread::sleep(Duration::from_millis(30));
1381
1382        let mut conn = Connect::new(mock_config(port)).unwrap();
1383        let result = conn.query("SELECT 1");
1384        assert!(result.is_ok());
1385    }
1386    // ── portal/cursor helpers ──────────────────────────────────────────────────
1387    /// Build a portal query response with PortalSuspended:
1388    /// ParseComplete + BindComplete + RowDescription + DataRow(s) + PortalSuspended + ReadyForQuery
1389    fn portal_response(rows: u16) -> Vec<u8> {
1390        let mut r = Vec::new();
1391        r.extend(pg_msg(b'1', &[])); // ParseComplete
1392        r.extend(pg_msg(b'2', &[])); // BindComplete
1393                                     // RowDescription – 1 field "id", type_oid=23 (int4)
1394        let mut rd = Vec::new();
1395        rd.extend(&1u16.to_be_bytes());
1396        rd.extend(b"id\x00");
1397        rd.extend(&0u32.to_be_bytes());
1398        rd.extend(&1u16.to_be_bytes());
1399        rd.extend(&23u32.to_be_bytes());
1400        rd.extend(&4i16.to_be_bytes());
1401        rd.extend(&(-1i32).to_be_bytes());
1402        rd.extend(&0u16.to_be_bytes());
1403        r.extend(pg_msg(b'T', &rd));
1404        for i in 0..rows {
1405            let val = format!("{}", i + 1);
1406            let mut dr = Vec::new();
1407            dr.extend(&1u16.to_be_bytes());
1408            dr.extend(&(val.len() as u32).to_be_bytes());
1409            dr.extend(val.as_bytes());
1410            r.extend(pg_msg(b'D', &dr));
1411        }
1412        // PortalSuspended (tag 's')
1413        r.extend(pg_msg(b's', &[]));
1414        r.extend(ready_for_query());
1415        r
1416    }
1417    /// Build a portal complete response (no more rows):
1418    /// DataRow(s) + CommandComplete + ReadyForQuery
1419    fn portal_complete_response(rows: u16) -> Vec<u8> {
1420        let mut r = Vec::new();
1421        for i in 0..rows {
1422            let val = format!("{}", i + 1);
1423            let mut dr = Vec::new();
1424            dr.extend(&1u16.to_be_bytes());
1425            dr.extend(&(val.len() as u32).to_be_bytes());
1426            dr.extend(val.as_bytes());
1427            r.extend(pg_msg(b'D', &dr));
1428        }
1429        r.extend(pg_msg(b'C', b"SELECT 2\x00"));
1430        r.extend(ready_for_query());
1431        r
1432    }
1433    /// Build a close portal response: CloseComplete + ReadyForQuery
1434    fn close_portal_response() -> Vec<u8> {
1435        let mut r = Vec::new();
1436        r.extend(pg_msg(b'3', &[])); // CloseComplete
1437        r.extend(ready_for_query());
1438        r
1439    }
1440    fn spawn_portal_server() -> u16 {
1441        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1442        let port = listener.local_addr().unwrap().port();
1443        thread::spawn(move || {
1444            let (mut s, _) = listener.accept().unwrap();
1445            let mut buf = [0u8; 4096];
1446            // auth
1447            let _ = s.read(&mut buf).unwrap();
1448            let _ = s.write_all(&pg_auth(3, &[]));
1449            let _ = s.read(&mut buf).unwrap();
1450            let _ = s.write_all(&post_auth_ok());
1451            // First query: portal response with 2 rows + suspended
1452            match s.read(&mut buf) {
1453                Ok(0) | Err(_) => (),
1454                Ok(_) => {
1455                    let _ = s.write_all(&portal_response(2));
1456                }
1457            }
1458            // Second query: fetch_more → complete with 1 row
1459            match s.read(&mut buf) {
1460                Ok(0) | Err(_) => (),
1461                Ok(_) => {
1462                    let _ = s.write_all(&portal_complete_response(1));
1463                }
1464            }
1465            // Third query: close_portal
1466            match s.read(&mut buf) {
1467                Ok(0) | Err(_) => (),
1468                Ok(_) => {
1469                    let _ = s.write_all(&close_portal_response());
1470                }
1471            }
1472        });
1473        thread::sleep(Duration::from_millis(30));
1474        port
1475    }
1476    // ── SSL tests ─────────────────────────────────────────────────────────────
1477    #[test]
1478    fn ssl_prefer_fallback_on_rejection() {
1479        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1480        let port = listener.local_addr().unwrap().port();
1481        thread::spawn(move || {
1482            let (mut s, _) = listener.accept().unwrap();
1483            let mut buf = [0u8; 4096];
1484            // 读取 SSLRequest
1485            let _ = s.read(&mut buf);
1486            // 拒绝 SSL
1487            let _ = s.write_all(b"N");
1488            // 正常 cleartext auth
1489            let _ = s.read(&mut buf);
1490            let _ = s.write_all(&pg_auth(3, &[]));
1491            let _ = s.read(&mut buf);
1492            let _ = s.write_all(&post_auth_ok());
1493            loop {
1494                match s.read(&mut buf) {
1495                    Ok(0) | Err(_) => break,
1496                    Ok(_) => {
1497                        let _ = s.write_all(&simple_query_response());
1498                    }
1499                }
1500            }
1501        });
1502        thread::sleep(Duration::from_millis(30));
1503        let mut cfg = mock_config(port);
1504        cfg.sslmode = "prefer".into();
1505        let conn = Connect::new(cfg);
1506        assert!(conn.is_ok());
1507    }
1508    #[test]
1509    fn ssl_require_rejected_returns_error() {
1510        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1511        let port = listener.local_addr().unwrap().port();
1512        thread::spawn(move || {
1513            let (mut s, _) = listener.accept().unwrap();
1514            let mut buf = [0u8; 4096];
1515            let _ = s.read(&mut buf);
1516            let _ = s.write_all(b"N");
1517        });
1518        thread::sleep(Duration::from_millis(30));
1519        let mut cfg = mock_config(port);
1520        cfg.sslmode = "require".into();
1521        let result = Connect::new(cfg);
1522        assert!(result.is_err());
1523    }
1524    #[test]
1525    fn ssl_invalid_response_byte() {
1526        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1527        let port = listener.local_addr().unwrap().port();
1528        thread::spawn(move || {
1529            let (mut s, _) = listener.accept().unwrap();
1530            let mut buf = [0u8; 4096];
1531            let _ = s.read(&mut buf);
1532            let _ = s.write_all(b"X");
1533        });
1534        thread::sleep(Duration::from_millis(30));
1535        let mut cfg = mock_config(port);
1536        cfg.sslmode = "prefer".into();
1537        let result = Connect::new(cfg);
1538        assert!(result.is_err());
1539    }
1540    #[test]
1541    fn ssl_disable_skips_ssl_handshake() {
1542        let port = spawn_cleartext_server();
1543        let mut cfg = mock_config(port);
1544        cfg.sslmode = "disable".into();
1545        let conn = Connect::new(cfg);
1546        assert!(conn.is_ok());
1547    }
1548    // ── portal/cursor tests ──────────────────────────────────────────────────
1549    #[test]
1550    fn connect_query_portal_returns_rows_with_has_more() {
1551        let port = spawn_portal_server();
1552        let mut conn = Connect::new(mock_config(port)).unwrap();
1553        let result = conn.query_portal("SELECT id FROM t", 2);
1554        assert!(result.is_ok());
1555        let msg = result.unwrap();
1556        assert_eq!(msg.rows.len(), 2);
1557        assert!(msg.has_more);
1558    }
1559    #[test]
1560    fn connect_fetch_more_returns_remaining_rows() {
1561        let port = spawn_portal_server();
1562        let mut conn = Connect::new(mock_config(port)).unwrap();
1563        let _ = conn.query_portal("SELECT id FROM t", 2).unwrap();
1564        let result = conn.fetch_more(10);
1565        assert!(result.is_ok());
1566        let msg = result.unwrap();
1567        assert_eq!(msg.rows.len(), 1);
1568        assert!(!msg.has_more);
1569    }
1570    #[test]
1571    fn connect_close_portal_succeeds() {
1572        let port = spawn_portal_server();
1573        let mut conn = Connect::new(mock_config(port)).unwrap();
1574        let _ = conn.query_portal("SELECT id FROM t", 2).unwrap();
1575        let _ = conn.fetch_more(10).unwrap();
1576        let result = conn.close_portal();
1577        assert!(result.is_ok());
1578    }
1579}