Skip to main content

br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::{SocketAddr, TcpStream};
6use std::time::{Duration, Instant};
7
/// Byte stream carrying the PostgreSQL protocol: either a plain TCP socket or,
/// when the `tls` feature is enabled, a TLS session running over that socket.
#[derive(Debug)]
pub(crate) enum PgStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS-encrypted connection (only compiled with the `tls` feature).
    #[cfg(feature = "tls")]
    Tls(native_tls::TlsStream<TcpStream>),
}

/// Reads are forwarded to whichever transport is active.
impl Read for PgStream {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self {
            Self::Plain(tcp) => tcp.read(buf),
            #[cfg(feature = "tls")]
            Self::Tls(tls) => tls.read(buf),
        }
    }
}

/// Writes and flushes are forwarded to whichever transport is active.
impl Write for PgStream {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        match self {
            Self::Plain(tcp) => tcp.write(buf),
            #[cfg(feature = "tls")]
            Self::Tls(tls) => tls.write(buf),
        }
    }

    fn flush(&mut self) -> std::io::Result<()> {
        match self {
            Self::Plain(tcp) => tcp.flush(),
            #[cfg(feature = "tls")]
            Self::Tls(tls) => tls.flush(),
        }
    }
}

impl PgStream {
    /// The TCP socket underneath this stream; for TLS it is the socket the
    /// encrypted session runs over.
    fn tcp_socket(&self) -> &TcpStream {
        match self {
            Self::Plain(tcp) => tcp,
            #[cfg(feature = "tls")]
            Self::Tls(tls) => tls.get_ref(),
        }
    }

    /// Address of the remote end of the underlying TCP socket.
    fn peer_addr(&self) -> std::io::Result<SocketAddr> {
        self.tcp_socket().peer_addr()
    }

    /// Shut down the underlying TCP socket directly (for TLS this bypasses the
    /// TLS layer, so no close_notify is sent).
    fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
        self.tcp_socket().shutdown(how)
    }

    /// Set the read timeout on the underlying TCP socket.
    // No caller in this file; the allow keeps the helper warning-free.
    #[allow(dead_code)]
    fn set_read_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
        self.tcp_socket().set_read_timeout(dur)
    }
}
67
68#[derive(Debug)]
69pub struct Connect {
70    pub(crate) stream: PgStream,
71    /// 缓存的 peer 地址
72    _peer_addr: SocketAddr,
73    packet: Packet,
74    auth_status: AuthStatus,
75    /// 上次使用时间,用于懒健康检查
76    last_used: Instant,
77    /// 连接创建时间,用于最大生命周期检查
78    created_at: Instant,
79}
80
81impl Connect {
82    /// 懒健康检查:peer_addr 检测 TCP 存活,空闲超过 5 秒才发 SELECT 1
83    pub fn is_valid(&mut self) -> bool {
84        if self.stream.peer_addr().is_err() {
85            return false;
86        }
87        #[cfg(not(test))]
88        const IDLE_THRESHOLD: Duration = Duration::from_secs(5);
89        #[cfg(test)]
90        const IDLE_THRESHOLD: Duration = Duration::from_millis(0);
91        if self.last_used.elapsed() > IDLE_THRESHOLD {
92            return self.query("SELECT 1").is_ok();
93        }
94        true
95    }
96
97    /// 仅检查 TCP 连接是否存活(不发送查询)
98    pub fn peer_valid(&self) -> bool {
99        self.stream.peer_addr().is_ok()
100    }
101
102    /// 更新最后使用时间
103    pub fn touch(&mut self) {
104        self.last_used = Instant::now();
105    }
106
107    /// 返回空闲时长
108    pub fn idle_elapsed(&self) -> Duration {
109        self.last_used.elapsed()
110    }
111
112    /// 返回连接已存活时长
113    pub fn age(&self) -> Duration {
114        self.created_at.elapsed()
115    }
116
117    pub fn _close(&mut self) {
118        let _ = self.stream.write_all(&Packet::pack_terminate());
119        let _ = self.stream.shutdown(std::net::Shutdown::Both);
120    }
121
122    /// 设置 TCP Keepalive(60秒间隔,3次探测)
123    fn set_keepalive(stream: &TcpStream) -> Result<(), PgsqlError> {
124        use std::os::unix::io::{AsRawFd, FromRawFd};
125        let fd = stream.as_raw_fd();
126        let socket = unsafe { socket2::Socket::from_raw_fd(fd) };
127        let keepalive = socket2::TcpKeepalive::new()
128            .with_time(Duration::from_secs(60))
129            .with_interval(Duration::from_secs(15))
130            .with_retries(3);
131        let result = socket.set_tcp_keepalive(&keepalive);
132        // 不要让 socket2::Socket drop 关闭 fd,转回原始 fd
133        std::mem::forget(socket);
134        result.map_err(|e| PgsqlError::Connection(format!("设置 TCP Keepalive 失败: {}", e)))
135    }
136
137    /// SSL/TLS 升级:发送 SSLRequest,根据服务端响应升级或回退
138    fn try_ssl_upgrade(mut stream: TcpStream, config: &Config) -> Result<PgStream, PgsqlError> {
139        // 发送 SSLRequest
140        stream
141            .write_all(&Packet::pack_ssl_request())
142            .map_err(|e| PgsqlError::Connection(format!("发送 SSLRequest 失败: {}", e)))?;
143        let mut resp = [0u8; 1];
144        stream
145            .read_exact(&mut resp)
146            .map_err(|e| PgsqlError::Connection(format!("读取 SSL 响应失败: {}", e)))?;
147        match resp[0] {
148            b'S' => {
149                // 服务端同意 SSL
150                #[cfg(feature = "tls")]
151                {
152                    let connector = native_tls::TlsConnector::builder()
153                        .danger_accept_invalid_certs(true)
154                        .build()
155                        .map_err(|e| PgsqlError::Connection(format!("TLS 初始化失败: {}", e)))?;
156                    let tls_stream = connector
157                        .connect(&config.hostname, stream)
158                        .map_err(|e| PgsqlError::Connection(format!("TLS 握手失败: {}", e)))?;
159                    Ok(PgStream::Tls(tls_stream))
160                }
161                #[cfg(not(feature = "tls"))]
162                {
163                    let _ = config;
164                    Err(PgsqlError::Connection(
165                        "服务端要求 SSL 但未启用 tls feature".into(),
166                    ))
167                }
168            }
169            b'N' => {
170                // 服务端拒绝 SSL
171                if config.sslmode == "require" {
172                    Err(PgsqlError::Connection(
173                        "sslmode=require 但服务端不支持 SSL".into(),
174                    ))
175                } else {
176                    Ok(PgStream::Plain(stream))
177                }
178            }
179            other => Err(PgsqlError::Connection(format!(
180                "无效的 SSL 响应字节: 0x{:02X}",
181                other
182            ))),
183        }
184    }
185    pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
186        let stream =
187            TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
188        // TCP 优化:禁用 Nagle 算法,减少小包延迟
189        stream
190            .set_nodelay(true)
191            .map_err(|e| PgsqlError::Connection(format!("设置 TCP_NODELAY 失败: {}", e)))?;
192        // TCP Keepalive:防止空闲连接被防火墙/NAT/PG服务端静默断开
193        Self::set_keepalive(&stream)?;
194        stream
195            .set_read_timeout(Some(Duration::from_secs(30)))
196            .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
197        stream
198            .set_write_timeout(Some(Duration::from_secs(30)))
199            .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
200        let peer_addr = stream
201            .peer_addr()
202            .map_err(|e| PgsqlError::Connection(e.to_string()))?;
203
204        // SSL/TLS 升级
205        let stream = if config.sslmode != "disable" {
206            Self::try_ssl_upgrade(stream, &config)?
207        } else {
208            PgStream::Plain(stream)
209        };
210
211        let mut connect = Self {
212            stream,
213            _peer_addr: peer_addr,
214            packet: Packet::new(config),
215            auth_status: AuthStatus::None,
216            last_used: Instant::now(),
217            created_at: Instant::now(),
218        };
219
220        connect.authenticate()?;
221
222        Ok(connect)
223    }
224
225    fn authenticate(&mut self) -> Result<(), PgsqlError> {
226        self.stream
227            .write_all(&self.packet.pack_first())
228            .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
229
230        let data = self.read()?;
231        self.packet.unpack(data, 0)?;
232
233        if !self.packet.md5_salt.is_empty() {
234            self.md5_auth()?;
235        } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
236            self.cleartext_auth()?;
237        } else {
238            self.scram_auth()?;
239        }
240
241        self.auth_status = AuthStatus::AuthenticationOk;
242        Ok(())
243    }
244
245    fn md5_auth(&mut self) -> Result<(), PgsqlError> {
246        self.stream
247            .write_all(&self.packet.pack_md5_password())
248            .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
249
250        let data = self.read()?;
251        self.packet.unpack(data, 0)?;
252        Ok(())
253    }
254
255    fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
256        self.stream
257            .write_all(&self.packet.pack_cleartext_password())
258            .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
259
260        let data = self.read()?;
261        self.packet.unpack(data, 0)?;
262        Ok(())
263    }
264
265    fn scram_auth(&mut self) -> Result<(), PgsqlError> {
266        self.stream
267            .write_all(&self.packet.pack_auth())
268            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
269
270        let data = self.read()?;
271        self.packet.unpack(data, 0)?;
272
273        self.stream
274            .write_all(&self.packet.pack_auth_verify())
275            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
276
277        let data = self.read()?;
278        self.packet.unpack(data, 0)?;
279        Ok(())
280    }
281
282    fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
283        let mut msg = Vec::new();
284        let mut buf = [0u8; 4096];
285        let mut retry_count = 0;
286
287        #[cfg(not(test))]
288        const MAX_RETRIES: u32 = 100;
289        #[cfg(test)]
290        const MAX_RETRIES: u32 = 3;
291
292        #[cfg(not(test))]
293        const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
294        #[cfg(test)]
295        const MAX_MESSAGE_SIZE: usize = 128;
296
297        #[cfg(not(test))]
298        let deadline = std::time::Instant::now() + Duration::from_secs(300);
299        #[cfg(test)]
300        let deadline = std::time::Instant::now() + Duration::from_millis(200);
301
302        loop {
303            if std::time::Instant::now() >= deadline {
304                return Err(PgsqlError::Timeout("读取总超时".into()));
305            }
306
307            match self.stream.read(&mut buf) {
308                Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
309                Ok(n) => {
310                    if msg.len() + n > MAX_MESSAGE_SIZE {
311                        return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
312                    }
313                    msg.extend_from_slice(&buf[..n]);
314                    retry_count = 0;
315                }
316                Err(ref e)
317                    if e.kind() == std::io::ErrorKind::WouldBlock
318                        || e.kind() == std::io::ErrorKind::TimedOut =>
319                {
320                    retry_count += 1;
321                    if retry_count > MAX_RETRIES {
322                        return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
323                    }
324                    std::thread::sleep(Duration::from_millis(10));
325                    continue;
326                }
327                Err(e) => return Err(PgsqlError::Io(e)),
328            };
329
330            if let AuthStatus::AuthenticationOk = self.auth_status {
331                if msg.ends_with(&[90, 0, 0, 0, 5, 73])
332                    || msg.ends_with(&[90, 0, 0, 0, 5, 84])
333                    || msg.ends_with(&[90, 0, 0, 0, 5, 69])
334                {
335                    break;
336                }
337            } else if msg.len() >= 5 {
338                let len_bytes = &msg[1..=4];
339                if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
340                    if msg.len() > len as usize {
341                        break;
342                    }
343                }
344            }
345        }
346
347        Ok(msg)
348    }
349
350    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
351        self.stream
352            .write_all(&self.packet.pack_query(sql))
353            .map_err(PgsqlError::Io)?;
354        let data = self.read()?;
355        self.last_used = Instant::now();
356        self.packet.unpack(data, 0)
357    }
358
359    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
360        self.stream
361            .write_all(&self.packet.pack_execute(sql))
362            .map_err(PgsqlError::Io)?;
363        let data = self.read()?;
364        self.last_used = Instant::now();
365        self.packet.unpack(data, 0)
366    }
367
368    /// 参数化查询
369    pub fn query_params(
370        &mut self,
371        sql: &str,
372        params: &[Option<&str>],
373    ) -> Result<SuccessMessage, PgsqlError> {
374        self.stream
375            .write_all(&self.packet.pack_query_params(sql, params))
376            .map_err(PgsqlError::Io)?;
377
378        let data = self.read()?;
379        self.last_used = Instant::now();
380        self.packet.unpack(data, 0)
381    }
382
383    /// 参数化执行
384    pub fn execute_params(
385        &mut self,
386        sql: &str,
387        params: &[Option<&str>],
388    ) -> Result<SuccessMessage, PgsqlError> {
389        self.stream
390            .write_all(&self.packet.pack_execute_params(sql, params))
391            .map_err(PgsqlError::Io)?;
392        let data = self.read()?;
393        self.last_used = Instant::now();
394        self.packet.unpack(data, 0)
395    }
396
397    /// 参数化查询(便捷版,所有参数非 NULL)
398    pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
399        let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
400        self.query_params(sql, &opts)
401    }
402
403    /// 参数化执行(便捷版,所有参数非 NULL)
404    pub fn execute_str(
405        &mut self,
406        sql: &str,
407        params: &[&str],
408    ) -> Result<SuccessMessage, PgsqlError> {
409        let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
410        self.execute_params(sql, &opts)
411    }
412    // ── Portal/Cursor API ──────────────────────────────────────────────
413    /// Portal 查询:分批获取结果,max_rows 指定每批行数
414    pub fn query_portal(&mut self, sql: &str, max_rows: u32) -> Result<SuccessMessage, PgsqlError> {
415        self.stream
416            .write_all(&self.packet.pack_query_portal(sql, max_rows))
417            .map_err(PgsqlError::Io)?;
418        let data = self.read()?;
419        self.last_used = Instant::now();
420        self.packet.unpack(data, 0)
421    }
422    /// 从已打开的 Portal 继续获取更多行
423    pub fn fetch_more(&mut self, max_rows: u32) -> Result<SuccessMessage, PgsqlError> {
424        self.stream
425            .write_all(&self.packet.pack_fetch_more(max_rows))
426            .map_err(PgsqlError::Io)?;
427        let data = self.read()?;
428        self.last_used = Instant::now();
429        self.packet.unpack(data, 0)
430    }
431    /// 关闭当前 Portal
432    pub fn close_portal(&mut self) -> Result<SuccessMessage, PgsqlError> {
433        self.stream
434            .write_all(&self.packet.pack_close_portal())
435            .map_err(PgsqlError::Io)?;
436        let data = self.read()?;
437        self.last_used = Instant::now();
438        self.packet.unpack(data, 0)
439    }
440}
441
442impl Drop for Connect {
443    fn drop(&mut self) {
444        let _ = self.stream.write_all(&Packet::pack_terminate());
445        let _ = self.stream.shutdown(std::net::Shutdown::Both);
446    }
447}
448
449#[cfg(test)]
450mod tests {
451    use super::*;
452    use std::net::TcpListener;
453    use std::thread;
454
// ── wire-protocol helpers ──────────────────────────────────────────

    /// Frame one backend message: tag byte, big-endian length (payload plus
    /// the 4 length bytes themselves), then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let len = (payload.len() as u32 + 4).to_be_bytes();
        let mut framed = vec![tag];
        framed.extend_from_slice(&len);
        framed.extend_from_slice(payload);
        framed
    }

    /// Authentication request (`'R'`): 4-byte auth code followed by `extra`.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let body = [auth_type.to_be_bytes().as_slice(), extra].concat();
        pg_msg(b'R', &body)
    }

    /// AuthenticationOk: `'R'` with auth code 0 and nothing else.
    fn auth_ok() -> Vec<u8> {
        pg_auth(0, &[])
    }

    /// ParameterStatus reporting `server_version = 15.0`.
    fn param_status() -> Vec<u8> {
        pg_msg(b'S', b"server_version\x0015.0\x00")
    }

    /// BackendKeyData with process id 1 and secret key 2.
    fn backend_key() -> Vec<u8> {
        let ids = [1u32.to_be_bytes(), 2u32.to_be_bytes()].concat();
        pg_msg(b'K', &ids)
    }

    /// ReadyForQuery with transaction status Idle (`'I'`).
    fn ready_for_query() -> Vec<u8> {
        pg_msg(b'Z', b"I")
    }

    /// Everything a server sends once auth succeeds: AuthenticationOk,
    /// ParameterStatus, BackendKeyData, ReadyForQuery.
    fn post_auth_ok() -> Vec<u8> {
        [auth_ok(), param_status(), backend_key(), ready_for_query()].concat()
    }

    /// Query response: ParseComplete, BindComplete, RowDescription (one int4
    /// text column `"c"`), DataRow(`"1"`), CommandComplete(`"SELECT 1"`),
    /// ReadyForQuery.
    fn simple_query_response() -> Vec<u8> {
        let row_desc = [
            1u16.to_be_bytes().as_slice(),    // field count
            b"c\x00".as_slice(),              // column name
            0u32.to_be_bytes().as_slice(),    // table oid
            1u16.to_be_bytes().as_slice(),    // column index
            23u32.to_be_bytes().as_slice(),   // type oid (int4)
            4i16.to_be_bytes().as_slice(),    // type length
            (-1i32).to_be_bytes().as_slice(), // type modifier
            0u16.to_be_bytes().as_slice(),    // format code (text)
        ]
        .concat();
        let data_row = [
            1u16.to_be_bytes().as_slice(), // field count
            1u32.to_be_bytes().as_slice(), // value length
            b"1".as_slice(),               // value
        ]
        .concat();
        [
            pg_msg(b'1', &[]),
            pg_msg(b'2', &[]),
            pg_msg(b'T', &row_desc),
            pg_msg(b'D', &data_row),
            pg_msg(b'C', b"SELECT 1\x00"),
            ready_for_query(),
        ]
        .concat()
    }

    /// Row-less statement response: ParseComplete, BindComplete, NoData,
    /// CommandComplete(`"UPDATE 3"`), ReadyForQuery.
    fn execute_response() -> Vec<u8> {
        [
            pg_msg(b'1', &[]),
            pg_msg(b'2', &[]),
            pg_msg(b'n', &[]),
            pg_msg(b'C', b"UPDATE 3\x00"),
            ready_for_query(),
        ]
        .concat()
    }

    /// Parameterized query response: ParseComplete, ParameterDescription (one
    /// int4), BindComplete, RowDescription (one int4 column `"p"`),
    /// DataRow(`"42"`), CommandComplete(`"SELECT 1"`), ReadyForQuery.
    fn query_params_response() -> Vec<u8> {
        let param_desc = [1u16.to_be_bytes().as_slice(), 23u32.to_be_bytes().as_slice()].concat();
        let row_desc = [
            1u16.to_be_bytes().as_slice(),
            b"p\x00".as_slice(),
            0u32.to_be_bytes().as_slice(),
            1u16.to_be_bytes().as_slice(),
            23u32.to_be_bytes().as_slice(),
            4i16.to_be_bytes().as_slice(),
            (-1i32).to_be_bytes().as_slice(),
            0u16.to_be_bytes().as_slice(),
        ]
        .concat();
        let data_row = [
            1u16.to_be_bytes().as_slice(),
            2u32.to_be_bytes().as_slice(),
            b"42".as_slice(),
        ]
        .concat();
        [
            pg_msg(b'1', &[]),
            pg_msg(b't', &param_desc),
            pg_msg(b'2', &[]),
            pg_msg(b'T', &row_desc),
            pg_msg(b'D', &data_row),
            pg_msg(b'C', b"SELECT 1\x00"),
            ready_for_query(),
        ]
        .concat()
    }

    /// Parameterized statement response: ParseComplete, ParameterDescription
    /// (one int4), BindComplete, NoData, CommandComplete(`"UPDATE 1"`),
    /// ReadyForQuery.
    fn execute_params_response() -> Vec<u8> {
        let param_desc = [1u16.to_be_bytes().as_slice(), 23u32.to_be_bytes().as_slice()].concat();
        [
            pg_msg(b'1', &[]),
            pg_msg(b't', &param_desc),
            pg_msg(b'2', &[]),
            pg_msg(b'n', &[]),
            pg_msg(b'C', b"UPDATE 1\x00"),
            ready_for_query(),
        ]
        .concat()
    }

    /// Parameterized query response whose single text column `"n"` is NULL
    /// (value length -1).
    fn query_params_null_response() -> Vec<u8> {
        let param_desc = [1u16.to_be_bytes().as_slice(), 25u32.to_be_bytes().as_slice()].concat();
        let row_desc = [
            1u16.to_be_bytes().as_slice(),
            b"n\x00".as_slice(),
            0u32.to_be_bytes().as_slice(),
            1u16.to_be_bytes().as_slice(),
            25u32.to_be_bytes().as_slice(),   // type oid (text)
            (-1i16).to_be_bytes().as_slice(), // variable length
            (-1i32).to_be_bytes().as_slice(),
            0u16.to_be_bytes().as_slice(),
        ]
        .concat();
        let data_row = [
            1u16.to_be_bytes().as_slice(),
            (-1i32).to_be_bytes().as_slice(), // NULL marker
        ]
        .concat();
        [
            pg_msg(b'1', &[]),
            pg_msg(b't', &param_desc),
            pg_msg(b'2', &[]),
            pg_msg(b'T', &row_desc),
            pg_msg(b'D', &data_row),
            pg_msg(b'C', b"SELECT 1\x00"),
            ready_for_query(),
        ]
        .concat()
    }

    /// Failing query: ParseComplete, BindComplete, ErrorResponse (SQLSTATE
    /// 42601, "syntax error"), ReadyForQuery.
    fn error_response() -> Vec<u8> {
        let fields = [
            b"C".as_slice(),
            b"42601\x00".as_slice(),
            b"M".as_slice(),
            b"syntax error\x00".as_slice(),
            b"\x00".as_slice(), // end of fields
        ]
        .concat();
        [
            pg_msg(b'1', &[]),
            pg_msg(b'2', &[]),
            pg_msg(b'E', &fields),
            ready_for_query(),
        ]
        .concat()
    }
652
653    // ── mock server spawners ───────────────────────────────────────────
654
655    /// Config pointing at 127.0.0.1:<port>
656    fn mock_config(port: u16) -> Config {
657        Config {
658            debug: false,
659            hostname: "127.0.0.1".into(),
660            hostport: port as i32,
661            username: "u".into(),
662            userpass: "p".into(),
663            database: "d".into(),
664            charset: "utf8".into(),
665            pool_max: 5,
666            sslmode: "disable".into(),
667        }
668    }
669
670    /// Spawn a mock PG server that does **cleartext** auth.
671    /// Returns the port.  The server handles one connection.
672    fn spawn_cleartext_server() -> u16 {
673        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
674        let port = listener.local_addr().unwrap().port();
675        thread::spawn(move || {
676            let (mut s, _) = listener.accept().unwrap();
677            let mut buf = [0u8; 4096];
678            // 1. read startup
679            let _ = s.read(&mut buf).unwrap();
680            // 2. send CleartextPassword request (auth_type=3)
681            let _ = s.write_all(&pg_auth(3, &[]));
682            // 3. read password message
683            let _ = s.read(&mut buf).unwrap();
684            // 4. send AuthOk + params + ready
685            let _ = s.write_all(&post_auth_ok());
686            // keep connection alive for queries
687            loop {
688                match s.read(&mut buf) {
689                    Ok(0) | Err(_) => break,
690                    Ok(_) => {
691                        let _ = s.write_all(&simple_query_response());
692                    }
693                }
694            }
695        });
696        thread::sleep(Duration::from_millis(30));
697        port
698    }
699
700    /// Spawn a mock PG server that does **MD5** auth.
701    fn spawn_md5_server() -> u16 {
702        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
703        let port = listener.local_addr().unwrap().port();
704        thread::spawn(move || {
705            let (mut s, _) = listener.accept().unwrap();
706            let mut buf = [0u8; 4096];
707            // 1. read startup
708            let _ = s.read(&mut buf).unwrap();
709            // 2. send MD5Password request (auth_type=5) + 4-byte salt
710            let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
711            // 3. read md5 password
712            let _ = s.read(&mut buf).unwrap();
713            // 4. send AuthOk + params + ready
714            let _ = s.write_all(&post_auth_ok());
715            loop {
716                match s.read(&mut buf) {
717                    Ok(0) | Err(_) => break,
718                    Ok(_) => {
719                        let _ = s.write_all(&simple_query_response());
720                    }
721                }
722            }
723        });
724        thread::sleep(Duration::from_millis(30));
725        port
726    }
727
728    /// Spawn a mock PG server that does **SCRAM-SHA-256** auth.
729    fn spawn_scram_server() -> u16 {
730        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
731        let port = listener.local_addr().unwrap().port();
732        thread::spawn(move || {
733            let (mut s, _) = listener.accept().unwrap();
734            let mut buf = [0u8; 4096];
735            // 1. read startup
736            let _ = s.read(&mut buf).unwrap();
737            // 2. send SASL auth (auth_type=10, mechanism)
738            let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
739            // 3. read SASLInitialResponse – extract client nonce
740            let n = s.read(&mut buf).unwrap();
741            let payload = &buf[..n];
742            // find "n=,r=" in the payload to extract client nonce
743            let text = String::from_utf8_lossy(payload);
744            let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
745            // 4. send SCRAM challenge (auth_type=11)
746            let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
747            let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
748            // 5. read SCRAM client final
749            let _ = s.read(&mut buf).unwrap();
750            // 6. send SCRAM complete (auth_type=12) + AuthOk + params + ready
751            let mut resp = Vec::new();
752            resp.extend(pg_auth(12, b"v=dummyproof"));
753            resp.extend(auth_ok());
754            resp.extend(param_status());
755            resp.extend(backend_key());
756            resp.extend(ready_for_query());
757            let _ = s.write_all(&resp);
758            loop {
759                match s.read(&mut buf) {
760                    Ok(0) | Err(_) => break,
761                    Ok(_) => {
762                        let _ = s.write_all(&simple_query_response());
763                    }
764                }
765            }
766        });
767        thread::sleep(Duration::from_millis(30));
768        port
769    }
770
771    /// Spawn a server that accepts connection then immediately closes (EOF).
772    fn spawn_eof_server() -> u16 {
773        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
774        let port = listener.local_addr().unwrap().port();
775        thread::spawn(move || {
776            let (s, _) = listener.accept().unwrap();
777            drop(s); // close immediately
778        });
779        thread::sleep(Duration::from_millis(30));
780        port
781    }
782
783    /// Spawn a server that sends an ErrorResponse after startup.
784    fn spawn_auth_error_server() -> u16 {
785        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
786        let port = listener.local_addr().unwrap().port();
787        thread::spawn(move || {
788            let (mut s, _) = listener.accept().unwrap();
789            let mut buf = [0u8; 4096];
790            let _ = s.read(&mut buf).unwrap();
791            // Send ErrorResponse
792            let mut payload = Vec::new();
793            payload.push(b'C');
794            payload.extend(b"28P01\x00");
795            payload.push(b'M');
796            payload.extend(b"password authentication failed\x00");
797            payload.push(0);
798            let _ = s.write_all(&pg_msg(b'E', &payload));
799        });
800        thread::sleep(Duration::from_millis(30));
801        port
802    }
803
804    /// Spawn a cleartext server that responds with error on query.
805    fn spawn_query_error_server() -> u16 {
806        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
807        let port = listener.local_addr().unwrap().port();
808        thread::spawn(move || {
809            let (mut s, _) = listener.accept().unwrap();
810            let mut buf = [0u8; 4096];
811            // auth
812            let _ = s.read(&mut buf).unwrap();
813            let _ = s.write_all(&pg_auth(3, &[]));
814            let _ = s.read(&mut buf).unwrap();
815            let _ = s.write_all(&post_auth_ok());
816            // query → error
817            loop {
818                match s.read(&mut buf) {
819                    Ok(0) | Err(_) => break,
820                    Ok(_) => {
821                        let _ = s.write_all(&error_response());
822                    }
823                }
824            }
825        });
826        thread::sleep(Duration::from_millis(30));
827        port
828    }
829
830    /// Spawn a cleartext server that responds with execute response.
831    fn spawn_execute_server() -> u16 {
832        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
833        let port = listener.local_addr().unwrap().port();
834        thread::spawn(move || {
835            let (mut s, _) = listener.accept().unwrap();
836            let mut buf = [0u8; 4096];
837            // auth
838            let _ = s.read(&mut buf).unwrap();
839            let _ = s.write_all(&pg_auth(3, &[]));
840            let _ = s.read(&mut buf).unwrap();
841            let _ = s.write_all(&post_auth_ok());
842            // queries
843            loop {
844                match s.read(&mut buf) {
845                    Ok(0) | Err(_) => break,
846                    Ok(_) => {
847                        let _ = s.write_all(&execute_response());
848                    }
849                }
850            }
851        });
852        thread::sleep(Duration::from_millis(30));
853        port
854    }
855
856    /// Spawn a cleartext server that responds with parameterized query response.
857    fn spawn_query_params_server() -> u16 {
858        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
859        let port = listener.local_addr().unwrap().port();
860        thread::spawn(move || {
861            let (mut s, _) = listener.accept().unwrap();
862            let mut buf = [0u8; 4096];
863            let _ = s.read(&mut buf).unwrap();
864            let _ = s.write_all(&pg_auth(3, &[]));
865            let _ = s.read(&mut buf).unwrap();
866            let _ = s.write_all(&post_auth_ok());
867            loop {
868                match s.read(&mut buf) {
869                    Ok(0) | Err(_) => break,
870                    Ok(_) => {
871                        let _ = s.write_all(&query_params_response());
872                    }
873                }
874            }
875        });
876        thread::sleep(Duration::from_millis(30));
877        port
878    }
879
880    fn spawn_params_server() -> u16 {
881        spawn_query_params_server()
882    }
883
884    /// Spawn a cleartext server that responds with parameterized execute response.
885    fn spawn_execute_params_server() -> u16 {
886        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
887        let port = listener.local_addr().unwrap().port();
888        thread::spawn(move || {
889            let (mut s, _) = listener.accept().unwrap();
890            let mut buf = [0u8; 4096];
891            let _ = s.read(&mut buf).unwrap();
892            let _ = s.write_all(&pg_auth(3, &[]));
893            let _ = s.read(&mut buf).unwrap();
894            let _ = s.write_all(&post_auth_ok());
895            loop {
896                match s.read(&mut buf) {
897                    Ok(0) | Err(_) => break,
898                    Ok(_) => {
899                        let _ = s.write_all(&execute_params_response());
900                    }
901                }
902            }
903        });
904        thread::sleep(Duration::from_millis(30));
905        port
906    }
907
908    /// Spawn a cleartext server that returns NULL in parameterized query result.
909    fn spawn_query_params_null_server() -> u16 {
910        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
911        let port = listener.local_addr().unwrap().port();
912        thread::spawn(move || {
913            let (mut s, _) = listener.accept().unwrap();
914            let mut buf = [0u8; 4096];
915            let _ = s.read(&mut buf).unwrap();
916            let _ = s.write_all(&pg_auth(3, &[]));
917            let _ = s.read(&mut buf).unwrap();
918            let _ = s.write_all(&post_auth_ok());
919            loop {
920                match s.read(&mut buf) {
921                    Ok(0) | Err(_) => break,
922                    Ok(_) => {
923                        let _ = s.write_all(&query_params_null_response());
924                    }
925                }
926            }
927        });
928        thread::sleep(Duration::from_millis(30));
929        port
930    }
931
932    // ── tests ──────────────────────────────────────────────────────────
933
934    #[test]
935    fn connect_cleartext_auth_success() {
936        let port = spawn_cleartext_server();
937        let conn = Connect::new(mock_config(port));
938        assert!(conn.is_ok());
939    }
940
941    #[test]
942    fn connect_md5_auth_success() {
943        let port = spawn_md5_server();
944        let conn = Connect::new(mock_config(port));
945        assert!(conn.is_ok());
946    }
947
948    #[test]
949    fn connect_scram_auth_success() {
950        let port = spawn_scram_server();
951        let conn = Connect::new(mock_config(port));
952        assert!(conn.is_ok());
953    }
954
955    #[test]
956    fn connect_connection_refused() {
957        // port 1 is almost certainly not listening
958        let cfg = mock_config(1);
959        let result = Connect::new(cfg);
960        assert!(result.is_err());
961        match result.unwrap_err() {
962            PgsqlError::Connection(_) => {}
963            other => panic!("expected Connection error, got {other:?}"),
964        }
965    }
966
967    #[test]
968    fn connect_server_closes_immediately() {
969        let port = spawn_eof_server();
970        let result = Connect::new(mock_config(port));
971        assert!(result.is_err());
972    }
973
974    #[test]
975    fn connect_auth_error_from_server() {
976        let port = spawn_auth_error_server();
977        let result = Connect::new(mock_config(port));
978        assert!(result.is_err());
979    }
980
981    #[test]
982    fn connect_query_success() {
983        let port = spawn_cleartext_server();
984        let mut conn = Connect::new(mock_config(port)).unwrap();
985        let result = conn.query("SELECT 1");
986        assert!(result.is_ok());
987        let msg = result.unwrap();
988        assert_eq!(msg.rows.len(), 1);
989        assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
990    }
991
992    #[test]
993    fn connect_execute_success() {
994        let port = spawn_execute_server();
995        let mut conn = Connect::new(mock_config(port)).unwrap();
996        let result = conn.execute("UPDATE t SET x=1");
997        assert!(result.is_ok());
998        let msg = result.unwrap();
999        assert_eq!(msg.affect_count, 3);
1000        assert_eq!(msg.tag, "UPDATE 3");
1001    }
1002
1003    #[test]
1004    fn connect_query_params_success() {
1005        let port = spawn_query_params_server();
1006        let mut conn = Connect::new(mock_config(port)).unwrap();
1007        let result = conn.query_params("SELECT $1::int", &[Some("42")]);
1008        assert!(result.is_ok());
1009        let msg = result.unwrap();
1010        assert!(!msg.param_oids.is_empty());
1011        assert_eq!(msg.rows.len(), 1);
1012        assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
1013    }
1014
1015    #[test]
1016    fn connect_execute_params_success() {
1017        let port = spawn_execute_params_server();
1018        let mut conn = Connect::new(mock_config(port)).unwrap();
1019        let result = conn.execute_params("UPDATE t SET x=$1", &[Some("42")]);
1020        assert!(result.is_ok());
1021        let msg = result.unwrap();
1022        assert!(!msg.param_oids.is_empty());
1023        assert_eq!(msg.affect_count, 1);
1024        assert_eq!(msg.tag, "UPDATE 1");
1025    }
1026
1027    #[test]
1028    fn connect_query_str_success() {
1029        let port = spawn_params_server();
1030        let mut conn = Connect::new(mock_config(port)).unwrap();
1031        let result = conn.query_str("SELECT $1::int", &["42"]);
1032        assert!(result.is_ok());
1033        let msg = result.unwrap();
1034        assert!(!msg.param_oids.is_empty());
1035        assert_eq!(msg.rows.len(), 1);
1036    }
1037
1038    #[test]
1039    fn connect_execute_str_success() {
1040        let port = spawn_execute_params_server();
1041        let mut conn = Connect::new(mock_config(port)).unwrap();
1042        let result = conn.execute_str("UPDATE t SET x=$1", &["1"]);
1043        assert!(result.is_ok());
1044        let msg = result.unwrap();
1045        assert!(!msg.param_oids.is_empty());
1046        assert_eq!(msg.affect_count, 1);
1047    }
1048
1049    #[test]
1050    fn connect_query_params_with_null() {
1051        let port = spawn_query_params_null_server();
1052        let mut conn = Connect::new(mock_config(port)).unwrap();
1053        let result = conn.query_params("SELECT $1::text", &[None]);
1054        assert!(result.is_ok());
1055        let msg = result.unwrap();
1056        assert!(!msg.param_oids.is_empty());
1057        assert_eq!(msg.rows.len(), 1);
1058        assert_eq!(msg.rows[0]["n"], "");
1059    }
1060
1061    #[test]
1062    fn connect_query_params_empty_string_vs_null() {
1063        let port = spawn_params_server();
1064        let mut conn = Connect::new(mock_config(port)).unwrap();
1065
1066        // 空字符串参数
1067        let r1 = conn.query_params("SELECT $1::text", &[Some("")]);
1068        assert!(r1.is_ok());
1069
1070        // NULL 参数
1071        let r2 = conn.query_params("SELECT $1::text", &[None]);
1072        assert!(r2.is_ok());
1073    }
1074
1075    #[test]
1076    fn connect_query_returns_error() {
1077        let port = spawn_query_error_server();
1078        let mut conn = Connect::new(mock_config(port)).unwrap();
1079        let result = conn.query("BAD SQL");
1080        assert!(result.is_err());
1081    }
1082
1083    #[test]
1084    fn connect_is_valid_true() {
1085        let port = spawn_cleartext_server();
1086        let mut conn = Connect::new(mock_config(port)).unwrap();
1087        assert!(conn.is_valid());
1088    }
1089
1090    #[test]
1091    fn connect_is_valid_false_after_close() {
1092        let port = spawn_cleartext_server();
1093        let mut conn = Connect::new(mock_config(port)).unwrap();
1094        conn._close();
1095        // After closing, is_valid should return false
1096        assert!(!conn.is_valid());
1097    }
1098
1099    #[test]
1100    fn connect_close_does_not_panic() {
1101        let port = spawn_cleartext_server();
1102        let mut conn = Connect::new(mock_config(port)).unwrap();
1103        conn._close();
1104        // calling close again should not panic
1105        conn._close();
1106    }
1107
1108    #[test]
1109    fn connect_drop_does_not_panic() {
1110        let port = spawn_cleartext_server();
1111        let conn = Connect::new(mock_config(port)).unwrap();
1112        drop(conn);
1113    }
1114
1115    fn spawn_transaction_status_server() -> u16 {
1116        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1117        let port = listener.local_addr().unwrap().port();
1118        thread::spawn(move || {
1119            let (mut s, _) = listener.accept().unwrap();
1120            let mut buf = [0u8; 4096];
1121            let _ = s.read(&mut buf).unwrap();
1122            let _ = s.write_all(&pg_auth(3, &[]));
1123            let _ = s.read(&mut buf).unwrap();
1124            let _ = s.write_all(&post_auth_ok());
1125            loop {
1126                match s.read(&mut buf) {
1127                    Ok(0) | Err(_) => break,
1128                    Ok(_) => {
1129                        let mut r = Vec::new();
1130                        r.extend(pg_msg(b'1', &[]));
1131                        r.extend(pg_msg(b'2', &[]));
1132                        let mut rd = Vec::new();
1133                        rd.extend(&1u16.to_be_bytes());
1134                        rd.extend(b"c\x00");
1135                        rd.extend(&0u32.to_be_bytes());
1136                        rd.extend(&1u16.to_be_bytes());
1137                        rd.extend(&23u32.to_be_bytes());
1138                        rd.extend(&4i16.to_be_bytes());
1139                        rd.extend(&(-1i32).to_be_bytes());
1140                        rd.extend(&0u16.to_be_bytes());
1141                        r.extend(pg_msg(b'T', &rd));
1142                        let mut dr = Vec::new();
1143                        dr.extend(&1u16.to_be_bytes());
1144                        dr.extend(&1u32.to_be_bytes());
1145                        dr.push(b'1');
1146                        r.extend(pg_msg(b'D', &dr));
1147                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
1148                        r.extend(pg_msg(b'Z', b"T"));
1149                        let _ = s.write_all(&r);
1150                    }
1151                }
1152            }
1153        });
1154        thread::sleep(Duration::from_millis(30));
1155        port
1156    }
1157
1158    fn spawn_error_status_server() -> u16 {
1159        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1160        let port = listener.local_addr().unwrap().port();
1161        thread::spawn(move || {
1162            let (mut s, _) = listener.accept().unwrap();
1163            let mut buf = [0u8; 4096];
1164            let _ = s.read(&mut buf).unwrap();
1165            let _ = s.write_all(&pg_auth(3, &[]));
1166            let _ = s.read(&mut buf).unwrap();
1167            let _ = s.write_all(&post_auth_ok());
1168            loop {
1169                match s.read(&mut buf) {
1170                    Ok(0) | Err(_) => break,
1171                    Ok(_) => {
1172                        let mut r = Vec::new();
1173                        r.extend(pg_msg(b'1', &[]));
1174                        r.extend(pg_msg(b'2', &[]));
1175                        let mut rd = Vec::new();
1176                        rd.extend(&1u16.to_be_bytes());
1177                        rd.extend(b"c\x00");
1178                        rd.extend(&0u32.to_be_bytes());
1179                        rd.extend(&1u16.to_be_bytes());
1180                        rd.extend(&23u32.to_be_bytes());
1181                        rd.extend(&4i16.to_be_bytes());
1182                        rd.extend(&(-1i32).to_be_bytes());
1183                        rd.extend(&0u16.to_be_bytes());
1184                        r.extend(pg_msg(b'T', &rd));
1185                        let mut dr = Vec::new();
1186                        dr.extend(&1u16.to_be_bytes());
1187                        dr.extend(&1u32.to_be_bytes());
1188                        dr.push(b'1');
1189                        r.extend(pg_msg(b'D', &dr));
1190                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
1191                        r.extend(pg_msg(b'Z', b"E"));
1192                        let _ = s.write_all(&r);
1193                    }
1194                }
1195            }
1196        });
1197        thread::sleep(Duration::from_millis(30));
1198        port
1199    }
1200
1201    fn spawn_slow_partial_server() -> u16 {
1202        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1203        let port = listener.local_addr().unwrap().port();
1204        thread::spawn(move || {
1205            let (mut s, _) = listener.accept().unwrap();
1206            let mut buf = [0u8; 4096];
1207            let _ = s.read(&mut buf).unwrap();
1208            let _ = s.write_all(&pg_auth(3, &[]));
1209            let _ = s.read(&mut buf).unwrap();
1210            let _ = s.write_all(&post_auth_ok());
1211            match s.read(&mut buf) {
1212                Ok(0) | Err(_) => {}
1213                Ok(_) => {
1214                    let _ = s.write_all(&simple_query_response());
1215                }
1216            }
1217        });
1218        thread::sleep(Duration::from_millis(30));
1219        port
1220    }
1221
1222    fn spawn_rst_on_query_server() -> u16 {
1223        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1224        let port = listener.local_addr().unwrap().port();
1225        thread::spawn(move || {
1226            let (mut s, _) = listener.accept().unwrap();
1227            let mut buf = [0u8; 4096];
1228            let _ = s.read(&mut buf).unwrap();
1229            let _ = s.write_all(&pg_auth(3, &[]));
1230            let _ = s.read(&mut buf).unwrap();
1231            let _ = s.write_all(&post_auth_ok());
1232            match s.read(&mut buf) {
1233                Ok(0) | Err(_) => {}
1234                Ok(_) => {
1235                    drop(s);
1236                }
1237            }
1238        });
1239        thread::sleep(Duration::from_millis(30));
1240        port
1241    }
1242
1243    #[test]
1244    fn connect_query_ready_for_query_transaction_status() {
1245        let port = spawn_transaction_status_server();
1246        let mut conn = Connect::new(mock_config(port)).unwrap();
1247        let result = conn.query("SELECT 1");
1248        assert!(result.is_ok());
1249    }
1250
1251    #[test]
1252    fn connect_query_ready_for_query_error_status() {
1253        let port = spawn_error_status_server();
1254        let mut conn = Connect::new(mock_config(port)).unwrap();
1255        let result = conn.query("SELECT 1");
1256        assert!(result.is_ok());
1257    }
1258
1259    #[test]
1260    fn connect_query_server_closes_after_partial() {
1261        let port = spawn_slow_partial_server();
1262        let mut conn = Connect::new(mock_config(port)).unwrap();
1263        let r1 = conn.query("SELECT 1");
1264        assert!(r1.is_ok());
1265        let r2 = conn.query("SELECT 1");
1266        assert!(r2.is_err());
1267    }
1268
1269    #[test]
1270    fn connect_query_server_rst_returns_io_or_connection_error() {
1271        let port = spawn_rst_on_query_server();
1272        let mut conn = Connect::new(mock_config(port)).unwrap();
1273        let result = conn.query("SELECT 1");
1274        assert!(result.is_err());
1275    }
1276
1277    #[test]
1278    fn connect_read_would_block_max_retries() {
1279        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1280        let port = listener.local_addr().unwrap().port();
1281        thread::spawn(move || {
1282            let (mut s, _) = listener.accept().unwrap();
1283            let mut buf = [0u8; 4096];
1284            let _ = s.read(&mut buf);
1285            let _ = s.write_all(&pg_auth(3, &[]));
1286            let _ = s.read(&mut buf);
1287            let _ = s.write_all(&post_auth_ok());
1288            let _ = s.read(&mut buf);
1289            thread::sleep(Duration::from_secs(5));
1290        });
1291        thread::sleep(Duration::from_millis(30));
1292
1293        let mut conn = Connect::new(mock_config(port)).unwrap();
1294        conn.stream
1295            .set_read_timeout(Some(Duration::from_millis(1)))
1296            .ok();
1297        let result = conn.query("SELECT 1");
1298        assert!(result.is_err());
1299        let err_str = result.unwrap_err().to_string();
1300        assert!(
1301            err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
1302            "expected timeout error, got: {err_str}"
1303        );
1304    }
1305
1306    #[test]
1307    fn connect_read_exceeds_max_message_size() {
1308        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1309        let port = listener.local_addr().unwrap().port();
1310        thread::spawn(move || {
1311            let (mut s, _) = listener.accept().unwrap();
1312            let mut buf = [0u8; 4096];
1313            let _ = s.read(&mut buf);
1314            let _ = s.write_all(&pg_auth(3, &[]));
1315            let _ = s.read(&mut buf);
1316            let _ = s.write_all(&post_auth_ok());
1317            let _ = s.read(&mut buf);
1318            let big = vec![b'X'; 256];
1319            let _ = s.write_all(&big);
1320            thread::sleep(Duration::from_secs(2));
1321        });
1322        thread::sleep(Duration::from_millis(30));
1323
1324        let mut conn = Connect::new(mock_config(port)).unwrap();
1325        let result = conn.query("SELECT 1");
1326        assert!(result.is_err());
1327        let err_str = result.unwrap_err().to_string();
1328        assert!(
1329            err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
1330            "expected max message size error, got: {err_str}"
1331        );
1332    }
1333
1334    #[test]
1335    fn connect_read_deadline_timeout() {
1336        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1337        let port = listener.local_addr().unwrap().port();
1338        thread::spawn(move || {
1339            let (mut s, _) = listener.accept().unwrap();
1340            let mut buf = [0u8; 4096];
1341            let _ = s.read(&mut buf);
1342            let _ = s.write_all(&pg_auth(3, &[]));
1343            let _ = s.read(&mut buf);
1344            let _ = s.write_all(&post_auth_ok());
1345            let _ = s.read(&mut buf);
1346            for _ in 0..200 {
1347                let _ = s.write_all(b"X");
1348                thread::sleep(Duration::from_millis(5));
1349            }
1350        });
1351        thread::sleep(Duration::from_millis(30));
1352
1353        let mut conn = Connect::new(mock_config(port)).unwrap();
1354        let result = conn.query("SELECT 1");
1355        assert!(result.is_err());
1356    }
1357
1358    #[test]
1359    fn connect_read_partial_auth_frame() {
1360        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1361        let port = listener.local_addr().unwrap().port();
1362        thread::spawn(move || {
1363            let (mut s, _) = listener.accept().unwrap();
1364            let mut buf = [0u8; 4096];
1365            let _ = s.read(&mut buf);
1366            let auth = pg_auth(3, &[]);
1367            let _ = s.write_all(&auth[..5]);
1368            thread::sleep(Duration::from_millis(50));
1369            let _ = s.write_all(&auth[5..]);
1370            let _ = s.read(&mut buf);
1371            let _ = s.write_all(&post_auth_ok());
1372            loop {
1373                match s.read(&mut buf) {
1374                    Ok(0) | Err(_) => break,
1375                    Ok(_) => {
1376                        let _ = s.write_all(&simple_query_response());
1377                    }
1378                }
1379            }
1380        });
1381        thread::sleep(Duration::from_millis(30));
1382
1383        let mut conn = Connect::new(mock_config(port)).unwrap();
1384        let result = conn.query("SELECT 1");
1385        assert!(result.is_ok());
1386    }
1387    // ── portal/cursor helpers ──────────────────────────────────────────────────
1388    /// Build a portal query response with PortalSuspended:
1389    /// ParseComplete + BindComplete + RowDescription + DataRow(s) + PortalSuspended + ReadyForQuery
1390    fn portal_response(rows: u16) -> Vec<u8> {
1391        let mut r = Vec::new();
1392        r.extend(pg_msg(b'1', &[])); // ParseComplete
1393        r.extend(pg_msg(b'2', &[])); // BindComplete
1394                                     // RowDescription – 1 field "id", type_oid=23 (int4)
1395        let mut rd = Vec::new();
1396        rd.extend(&1u16.to_be_bytes());
1397        rd.extend(b"id\x00");
1398        rd.extend(&0u32.to_be_bytes());
1399        rd.extend(&1u16.to_be_bytes());
1400        rd.extend(&23u32.to_be_bytes());
1401        rd.extend(&4i16.to_be_bytes());
1402        rd.extend(&(-1i32).to_be_bytes());
1403        rd.extend(&0u16.to_be_bytes());
1404        r.extend(pg_msg(b'T', &rd));
1405        for i in 0..rows {
1406            let val = format!("{}", i + 1);
1407            let mut dr = Vec::new();
1408            dr.extend(&1u16.to_be_bytes());
1409            dr.extend(&(val.len() as u32).to_be_bytes());
1410            dr.extend(val.as_bytes());
1411            r.extend(pg_msg(b'D', &dr));
1412        }
1413        // PortalSuspended (tag 's')
1414        r.extend(pg_msg(b's', &[]));
1415        r.extend(ready_for_query());
1416        r
1417    }
1418    /// Build a portal complete response (no more rows):
1419    /// DataRow(s) + CommandComplete + ReadyForQuery
1420    fn portal_complete_response(rows: u16) -> Vec<u8> {
1421        let mut r = Vec::new();
1422        for i in 0..rows {
1423            let val = format!("{}", i + 1);
1424            let mut dr = Vec::new();
1425            dr.extend(&1u16.to_be_bytes());
1426            dr.extend(&(val.len() as u32).to_be_bytes());
1427            dr.extend(val.as_bytes());
1428            r.extend(pg_msg(b'D', &dr));
1429        }
1430        r.extend(pg_msg(b'C', b"SELECT 2\x00"));
1431        r.extend(ready_for_query());
1432        r
1433    }
1434    /// Build a close portal response: CloseComplete + ReadyForQuery
1435    fn close_portal_response() -> Vec<u8> {
1436        let mut r = Vec::new();
1437        r.extend(pg_msg(b'3', &[])); // CloseComplete
1438        r.extend(ready_for_query());
1439        r
1440    }
1441    fn spawn_portal_server() -> u16 {
1442        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1443        let port = listener.local_addr().unwrap().port();
1444        thread::spawn(move || {
1445            let (mut s, _) = listener.accept().unwrap();
1446            let mut buf = [0u8; 4096];
1447            // auth
1448            let _ = s.read(&mut buf).unwrap();
1449            let _ = s.write_all(&pg_auth(3, &[]));
1450            let _ = s.read(&mut buf).unwrap();
1451            let _ = s.write_all(&post_auth_ok());
1452            // First query: portal response with 2 rows + suspended
1453            match s.read(&mut buf) {
1454                Ok(0) | Err(_) => (),
1455                Ok(_) => {
1456                    let _ = s.write_all(&portal_response(2));
1457                }
1458            }
1459            // Second query: fetch_more → complete with 1 row
1460            match s.read(&mut buf) {
1461                Ok(0) | Err(_) => (),
1462                Ok(_) => {
1463                    let _ = s.write_all(&portal_complete_response(1));
1464                }
1465            }
1466            // Third query: close_portal
1467            match s.read(&mut buf) {
1468                Ok(0) | Err(_) => (),
1469                Ok(_) => {
1470                    let _ = s.write_all(&close_portal_response());
1471                }
1472            }
1473        });
1474        thread::sleep(Duration::from_millis(30));
1475        port
1476    }
1477    // ── SSL tests ─────────────────────────────────────────────────────────────
1478    #[test]
1479    fn ssl_prefer_fallback_on_rejection() {
1480        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1481        let port = listener.local_addr().unwrap().port();
1482        thread::spawn(move || {
1483            let (mut s, _) = listener.accept().unwrap();
1484            let mut buf = [0u8; 4096];
1485            // 读取 SSLRequest
1486            let _ = s.read(&mut buf);
1487            // 拒绝 SSL
1488            let _ = s.write_all(b"N");
1489            // 正常 cleartext auth
1490            let _ = s.read(&mut buf);
1491            let _ = s.write_all(&pg_auth(3, &[]));
1492            let _ = s.read(&mut buf);
1493            let _ = s.write_all(&post_auth_ok());
1494            loop {
1495                match s.read(&mut buf) {
1496                    Ok(0) | Err(_) => break,
1497                    Ok(_) => {
1498                        let _ = s.write_all(&simple_query_response());
1499                    }
1500                }
1501            }
1502        });
1503        thread::sleep(Duration::from_millis(30));
1504        let mut cfg = mock_config(port);
1505        cfg.sslmode = "prefer".into();
1506        let conn = Connect::new(cfg);
1507        assert!(conn.is_ok());
1508    }
1509    #[test]
1510    fn ssl_require_rejected_returns_error() {
1511        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1512        let port = listener.local_addr().unwrap().port();
1513        thread::spawn(move || {
1514            let (mut s, _) = listener.accept().unwrap();
1515            let mut buf = [0u8; 4096];
1516            let _ = s.read(&mut buf);
1517            let _ = s.write_all(b"N");
1518        });
1519        thread::sleep(Duration::from_millis(30));
1520        let mut cfg = mock_config(port);
1521        cfg.sslmode = "require".into();
1522        let result = Connect::new(cfg);
1523        assert!(result.is_err());
1524    }
1525    #[test]
1526    fn ssl_invalid_response_byte() {
1527        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1528        let port = listener.local_addr().unwrap().port();
1529        thread::spawn(move || {
1530            let (mut s, _) = listener.accept().unwrap();
1531            let mut buf = [0u8; 4096];
1532            let _ = s.read(&mut buf);
1533            let _ = s.write_all(b"X");
1534        });
1535        thread::sleep(Duration::from_millis(30));
1536        let mut cfg = mock_config(port);
1537        cfg.sslmode = "prefer".into();
1538        let result = Connect::new(cfg);
1539        assert!(result.is_err());
1540    }
1541    #[test]
1542    fn ssl_disable_skips_ssl_handshake() {
1543        let port = spawn_cleartext_server();
1544        let mut cfg = mock_config(port);
1545        cfg.sslmode = "disable".into();
1546        let conn = Connect::new(cfg);
1547        assert!(conn.is_ok());
1548    }
1549    // ── portal/cursor tests ──────────────────────────────────────────────────
1550    #[test]
1551    fn connect_query_portal_returns_rows_with_has_more() {
1552        let port = spawn_portal_server();
1553        let mut conn = Connect::new(mock_config(port)).unwrap();
1554        let result = conn.query_portal("SELECT id FROM t", 2);
1555        assert!(result.is_ok());
1556        let msg = result.unwrap();
1557        assert_eq!(msg.rows.len(), 2);
1558        assert!(msg.has_more);
1559    }
1560    #[test]
1561    fn connect_fetch_more_returns_remaining_rows() {
1562        let port = spawn_portal_server();
1563        let mut conn = Connect::new(mock_config(port)).unwrap();
1564        let _ = conn.query_portal("SELECT id FROM t", 2).unwrap();
1565        let result = conn.fetch_more(10);
1566        assert!(result.is_ok());
1567        let msg = result.unwrap();
1568        assert_eq!(msg.rows.len(), 1);
1569        assert!(!msg.has_more);
1570    }
1571    #[test]
1572    fn connect_close_portal_succeeds() {
1573        let port = spawn_portal_server();
1574        let mut conn = Connect::new(mock_config(port)).unwrap();
1575        let _ = conn.query_portal("SELECT id FROM t", 2).unwrap();
1576        let _ = conn.fetch_more(10).unwrap();
1577        let result = conn.close_portal();
1578        assert!(result.is_ok());
1579    }
1580}