Skip to main content

br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::time::{Duration, Instant};
7
8#[derive(Debug)]
9pub struct Connect {
10    pub(crate) stream: TcpStream,
11    packet: Packet,
12    auth_status: AuthStatus,
13    /// 上次使用时间,用于懒健康检查
14    last_used: Instant,
15    /// 连接创建时间,用于最大生命周期检查
16    created_at: Instant,
17}
18
19impl Connect {
20    /// 懒健康检查:peer_addr 检测 TCP 存活,空闲超过 5 秒才发 SELECT 1
21    pub fn is_valid(&mut self) -> bool {
22        if self.stream.peer_addr().is_err() {
23            return false;
24        }
25        #[cfg(not(test))]
26        const IDLE_THRESHOLD: Duration = Duration::from_secs(5);
27        #[cfg(test)]
28        const IDLE_THRESHOLD: Duration = Duration::from_millis(0);
29        if self.last_used.elapsed() > IDLE_THRESHOLD {
30            return self.query("SELECT 1").is_ok();
31        }
32        true
33    }
34
35    /// 仅检查 TCP 连接是否存活(不发送查询)
36    pub fn peer_valid(&self) -> bool {
37        self.stream.peer_addr().is_ok()
38    }
39
40    /// 更新最后使用时间
41    pub fn touch(&mut self) {
42        self.last_used = Instant::now();
43    }
44
45    /// 返回空闲时长
46    pub fn idle_elapsed(&self) -> Duration {
47        self.last_used.elapsed()
48    }
49
50    /// 返回连接已存活时长
51    pub fn age(&self) -> Duration {
52        self.created_at.elapsed()
53    }
54
55    pub fn _close(&mut self) {
56        let _ = (&self.stream).write_all(&Packet::pack_terminate());
57        let _ = self.stream.shutdown(std::net::Shutdown::Both);
58    }
59
60    /// 设置 TCP Keepalive(60秒间隔,3次探测)
61    fn set_keepalive(stream: &TcpStream) -> Result<(), PgsqlError> {
62        use std::os::unix::io::{AsRawFd, FromRawFd};
63        let fd = stream.as_raw_fd();
64        let socket = unsafe { socket2::Socket::from_raw_fd(fd) };
65        let keepalive = socket2::TcpKeepalive::new()
66            .with_time(Duration::from_secs(60))
67            .with_interval(Duration::from_secs(15))
68            .with_retries(3);
69        let result = socket.set_tcp_keepalive(&keepalive);
70        // 不要让 socket2::Socket drop 关闭 fd,转回原始 fd
71        std::mem::forget(socket);
72        result.map_err(|e| PgsqlError::Connection(format!("设置 TCP Keepalive 失败: {}", e)))
73    }
74
75    pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
76        let stream =
77            TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
78        // TCP 优化:禁用 Nagle 算法,减少小包延迟
79        stream
80            .set_nodelay(true)
81            .map_err(|e| PgsqlError::Connection(format!("设置 TCP_NODELAY 失败: {}", e)))?;
82        // TCP Keepalive:防止空闲连接被防火墙/NAT/PG服务端静默断开
83        Self::set_keepalive(&stream)?;
84        stream
85            .set_read_timeout(Some(Duration::from_secs(30)))
86            .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
87        stream
88            .set_write_timeout(Some(Duration::from_secs(30)))
89            .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
90        let _ = stream.peer_addr();
91
92        let mut connect = Self {
93            stream,
94            packet: Packet::new(config),
95            auth_status: AuthStatus::None,
96            last_used: Instant::now(),
97            created_at: Instant::now(),
98        };
99
100        connect.authenticate()?;
101
102        Ok(connect)
103    }
104
105    fn authenticate(&mut self) -> Result<(), PgsqlError> {
106        (&self.stream)
107            .write_all(&self.packet.pack_first())
108            .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
109
110        let data = self.read()?;
111        self.packet.unpack(data, 0)?;
112
113        if !self.packet.md5_salt.is_empty() {
114            self.md5_auth()?;
115        } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
116            self.cleartext_auth()?;
117        } else {
118            self.scram_auth()?;
119        }
120
121        self.auth_status = AuthStatus::AuthenticationOk;
122        Ok(())
123    }
124
125    fn md5_auth(&mut self) -> Result<(), PgsqlError> {
126        (&self.stream)
127            .write_all(&self.packet.pack_md5_password())
128            .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
129
130        let data = self.read()?;
131        self.packet.unpack(data, 0)?;
132        Ok(())
133    }
134
135    fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
136        (&self.stream)
137            .write_all(&self.packet.pack_cleartext_password())
138            .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
139
140        let data = self.read()?;
141        self.packet.unpack(data, 0)?;
142        Ok(())
143    }
144
145    fn scram_auth(&mut self) -> Result<(), PgsqlError> {
146        (&self.stream)
147            .write_all(&self.packet.pack_auth())
148            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
149
150        let data = self.read()?;
151        self.packet.unpack(data, 0)?;
152
153        (&self.stream)
154            .write_all(&self.packet.pack_auth_verify())
155            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
156
157        let data = self.read()?;
158        self.packet.unpack(data, 0)?;
159        Ok(())
160    }
161
162    fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
163        let mut msg = Vec::new();
164        let mut buf = [0u8; 4096];
165        let mut retry_count = 0;
166
167        #[cfg(not(test))]
168        const MAX_RETRIES: u32 = 100;
169        #[cfg(test)]
170        const MAX_RETRIES: u32 = 3;
171
172        #[cfg(not(test))]
173        const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
174        #[cfg(test)]
175        const MAX_MESSAGE_SIZE: usize = 128;
176
177        #[cfg(not(test))]
178        let deadline = std::time::Instant::now() + Duration::from_secs(300);
179        #[cfg(test)]
180        let deadline = std::time::Instant::now() + Duration::from_millis(200);
181
182        loop {
183            if std::time::Instant::now() >= deadline {
184                return Err(PgsqlError::Timeout("读取总超时".into()));
185            }
186
187            match (&self.stream).read(&mut buf) {
188                Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
189                Ok(n) => {
190                    if msg.len() + n > MAX_MESSAGE_SIZE {
191                        return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
192                    }
193                    msg.extend_from_slice(&buf[..n]);
194                    retry_count = 0;
195                }
196                Err(ref e)
197                    if e.kind() == std::io::ErrorKind::WouldBlock
198                        || e.kind() == std::io::ErrorKind::TimedOut =>
199                {
200                    retry_count += 1;
201                    if retry_count > MAX_RETRIES {
202                        return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
203                    }
204                    std::thread::sleep(Duration::from_millis(10));
205                    continue;
206                }
207                Err(e) => return Err(PgsqlError::Io(e)),
208            };
209
210            if let AuthStatus::AuthenticationOk = self.auth_status {
211                if msg.ends_with(&[90, 0, 0, 0, 5, 73])
212                    || msg.ends_with(&[90, 0, 0, 0, 5, 84])
213                    || msg.ends_with(&[90, 0, 0, 0, 5, 69])
214                {
215                    break;
216                }
217            } else if msg.len() >= 5 {
218                let len_bytes = &msg[1..=4];
219                if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
220                    if msg.len() > len as usize {
221                        break;
222                    }
223                }
224            }
225        }
226
227        Ok(msg)
228    }
229
230    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
231        (&self.stream)
232            .write_all(&self.packet.pack_query(sql))
233            .map_err(PgsqlError::Io)?;
234        let data = self.read()?;
235        self.last_used = Instant::now();
236        self.packet.unpack(data, 0)
237    }
238
239    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
240        (&self.stream)
241            .write_all(&self.packet.pack_execute(sql))
242            .map_err(PgsqlError::Io)?;
243        let data = self.read()?;
244        self.last_used = Instant::now();
245        self.packet.unpack(data, 0)
246    }
247
248    /// 参数化查询
249    pub fn query_params(
250        &mut self,
251        sql: &str,
252        params: &[Option<&str>],
253    ) -> Result<SuccessMessage, PgsqlError> {
254        (&self.stream)
255            .write_all(&self.packet.pack_query_params(sql, params))
256            .map_err(PgsqlError::Io)?;
257
258        let data = self.read()?;
259        self.last_used = Instant::now();
260        self.packet.unpack(data, 0)
261    }
262
263    /// 参数化执行
264    pub fn execute_params(
265        &mut self,
266        sql: &str,
267        params: &[Option<&str>],
268    ) -> Result<SuccessMessage, PgsqlError> {
269        (&self.stream)
270            .write_all(&self.packet.pack_execute_params(sql, params))
271            .map_err(PgsqlError::Io)?;
272        let data = self.read()?;
273        self.last_used = Instant::now();
274        self.packet.unpack(data, 0)
275    }
276
277    /// 参数化查询(便捷版,所有参数非 NULL)
278    pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
279        let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
280        self.query_params(sql, &opts)
281    }
282
283    /// 参数化执行(便捷版,所有参数非 NULL)
284    pub fn execute_str(
285        &mut self,
286        sql: &str,
287        params: &[&str],
288    ) -> Result<SuccessMessage, PgsqlError> {
289        let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
290        self.execute_params(sql, &opts)
291    }
292}
293
294impl Drop for Connect {
295    fn drop(&mut self) {
296        let _ = (&self.stream).write_all(&Packet::pack_terminate());
297        let _ = self.stream.shutdown(std::net::Shutdown::Both);
298    }
299}
300
301#[cfg(test)]
302mod tests {
303    use super::*;
304    use std::net::TcpListener;
305    use std::thread;
306
// ── wire-protocol helpers ──────────────────────────────────────────

/// Frames one backend message: the tag byte, then a big-endian `u32` length that counts the
/// payload plus the four length bytes themselves, then the payload.
fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
    let length = (payload.len() as u32 + 4).to_be_bytes();
    [&[tag][..], &length[..], payload].concat()
}

/// Authentication request (`'R'`): big-endian `auth_type` followed by `extra`.
fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
    let body = [&auth_type.to_be_bytes()[..], extra].concat();
    pg_msg(b'R', &body)
}

/// AuthenticationOk: an `'R'` message with auth type 0 and no extra data.
fn auth_ok() -> Vec<u8> {
    pg_auth(0, b"")
}

/// ParameterStatus (`'S'`) reporting `server_version = 15.0`.
fn param_status() -> Vec<u8> {
    const SERVER_VERSION: &[u8] = b"server_version\x0015.0\x00";
    pg_msg(b'S', SERVER_VERSION)
}

/// BackendKeyData (`'K'`) with process id 1 and secret key 2.
fn backend_key() -> Vec<u8> {
    let ids: Vec<u8> = [1u32, 2u32].iter().flat_map(|id| id.to_be_bytes()).collect();
    pg_msg(b'K', &ids)
}

/// ReadyForQuery (`'Z'`) with transaction status `I` (idle).
fn ready_for_query() -> Vec<u8> {
    pg_msg(b'Z', &[b'I'])
}

/// Everything a server sends once authentication succeeds:
/// AuthenticationOk, ParameterStatus, BackendKeyData, ReadyForQuery.
fn post_auth_ok() -> Vec<u8> {
    [auth_ok(), param_status(), backend_key(), ready_for_query()].concat()
}
358
359    /// Build a simple query response: ParseComplete + BindComplete +
360    /// RowDescription(1 int4 col "c") + DataRow("1") + CommandComplete("SELECT 1") + ReadyForQuery
361    fn simple_query_response() -> Vec<u8> {
362        let mut r = Vec::new();
363        // ParseComplete
364        r.extend(pg_msg(b'1', &[]));
365        // BindComplete
366        r.extend(pg_msg(b'2', &[]));
367        // RowDescription – 1 field "c", type_oid=23 (int4)
368        let mut rd = Vec::new();
369        rd.extend(&1u16.to_be_bytes()); // field count
370        rd.extend(b"c\x00"); // name
371        rd.extend(&0u32.to_be_bytes()); // table oid
372        rd.extend(&1u16.to_be_bytes()); // column index
373        rd.extend(&23u32.to_be_bytes()); // type oid (int4)
374        rd.extend(&4i16.to_be_bytes()); // column length
375        rd.extend(&(-1i32).to_be_bytes()); // type modifier
376        rd.extend(&0u16.to_be_bytes()); // format (text)
377        r.extend(pg_msg(b'T', &rd));
378        // DataRow – 1 field, value "1"
379        let mut dr = Vec::new();
380        dr.extend(&1u16.to_be_bytes());
381        dr.extend(&1u32.to_be_bytes()); // length of value
382        dr.push(b'1');
383        r.extend(pg_msg(b'D', &dr));
384        // CommandComplete
385        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
386        // ReadyForQuery
387        r.extend(ready_for_query());
388        r
389    }
390
391    /// Build an execute response (no rows): ParseComplete + BindComplete +
392    /// NoData + CommandComplete("UPDATE 3") + ReadyForQuery
393    fn execute_response() -> Vec<u8> {
394        let mut r = Vec::new();
395        r.extend(pg_msg(b'1', &[]));
396        r.extend(pg_msg(b'2', &[]));
397        r.extend(pg_msg(b'n', &[])); // NoData
398        r.extend(pg_msg(b'C', b"UPDATE 3\x00"));
399        r.extend(ready_for_query());
400        r
401    }
402
403    /// Build a parameterized query response: ParseComplete + ParameterDescription + BindComplete +
404    /// RowDescription(1 int4 col "p") + DataRow("42") + CommandComplete("SELECT 1") + ReadyForQuery
405    fn query_params_response() -> Vec<u8> {
406        let mut r = Vec::new();
407        r.extend(pg_msg(b'1', &[]));
408
409        let mut pd = Vec::new();
410        pd.extend(&1u16.to_be_bytes());
411        pd.extend(&23u32.to_be_bytes());
412        r.extend(pg_msg(b't', &pd));
413
414        r.extend(pg_msg(b'2', &[]));
415
416        let mut rd = Vec::new();
417        rd.extend(&1u16.to_be_bytes());
418        rd.extend(b"p\x00");
419        rd.extend(&0u32.to_be_bytes());
420        rd.extend(&1u16.to_be_bytes());
421        rd.extend(&23u32.to_be_bytes());
422        rd.extend(&4i16.to_be_bytes());
423        rd.extend(&(-1i32).to_be_bytes());
424        rd.extend(&0u16.to_be_bytes());
425        r.extend(pg_msg(b'T', &rd));
426
427        let mut dr = Vec::new();
428        dr.extend(&1u16.to_be_bytes());
429        dr.extend(&2u32.to_be_bytes());
430        dr.extend(b"42");
431        r.extend(pg_msg(b'D', &dr));
432
433        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
434        r.extend(ready_for_query());
435        r
436    }
437
438    /// Build a parameterized execute response: ParseComplete + ParameterDescription + BindComplete +
439    /// NoData + CommandComplete("UPDATE 1") + ReadyForQuery
440    fn execute_params_response() -> Vec<u8> {
441        let mut r = Vec::new();
442        r.extend(pg_msg(b'1', &[]));
443
444        let mut pd = Vec::new();
445        pd.extend(&1u16.to_be_bytes());
446        pd.extend(&23u32.to_be_bytes());
447        r.extend(pg_msg(b't', &pd));
448
449        r.extend(pg_msg(b'2', &[]));
450        r.extend(pg_msg(b'n', &[]));
451        r.extend(pg_msg(b'C', b"UPDATE 1\x00"));
452        r.extend(ready_for_query());
453        r
454    }
455
456    /// Build a parameterized query response with NULL row value.
457    fn query_params_null_response() -> Vec<u8> {
458        let mut r = Vec::new();
459        r.extend(pg_msg(b'1', &[]));
460
461        let mut pd = Vec::new();
462        pd.extend(&1u16.to_be_bytes());
463        pd.extend(&25u32.to_be_bytes());
464        r.extend(pg_msg(b't', &pd));
465
466        r.extend(pg_msg(b'2', &[]));
467
468        let mut rd = Vec::new();
469        rd.extend(&1u16.to_be_bytes());
470        rd.extend(b"n\x00");
471        rd.extend(&0u32.to_be_bytes());
472        rd.extend(&1u16.to_be_bytes());
473        rd.extend(&25u32.to_be_bytes());
474        rd.extend(&(-1i16).to_be_bytes());
475        rd.extend(&(-1i32).to_be_bytes());
476        rd.extend(&0u16.to_be_bytes());
477        r.extend(pg_msg(b'T', &rd));
478
479        let mut dr = Vec::new();
480        dr.extend(&1u16.to_be_bytes());
481        dr.extend(&(-1i32).to_be_bytes());
482        r.extend(pg_msg(b'D', &dr));
483
484        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
485        r.extend(ready_for_query());
486        r
487    }
488
489    /// Build an ErrorResponse for query phase
490    fn error_response() -> Vec<u8> {
491        let mut payload = Vec::new();
492        payload.push(b'C');
493        payload.extend(b"42601\x00");
494        payload.push(b'M');
495        payload.extend(b"syntax error\x00");
496        payload.push(0);
497        let mut r = Vec::new();
498        r.extend(pg_msg(b'1', &[]));
499        r.extend(pg_msg(b'2', &[]));
500        r.extend(pg_msg(b'E', &payload));
501        r.extend(ready_for_query());
502        r
503    }
504
505    // ── mock server spawners ───────────────────────────────────────────
506
507    /// Config pointing at 127.0.0.1:<port>
508    fn mock_config(port: u16) -> Config {
509        Config {
510            debug: false,
511            hostname: "127.0.0.1".into(),
512            hostport: port as i32,
513            username: "u".into(),
514            userpass: "p".into(),
515            database: "d".into(),
516            charset: "utf8".into(),
517            pool_max: 5,
518        }
519    }
520
521    /// Spawn a mock PG server that does **cleartext** auth.
522    /// Returns the port.  The server handles one connection.
523    fn spawn_cleartext_server() -> u16 {
524        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
525        let port = listener.local_addr().unwrap().port();
526        thread::spawn(move || {
527            let (mut s, _) = listener.accept().unwrap();
528            let mut buf = [0u8; 4096];
529            // 1. read startup
530            let _ = s.read(&mut buf).unwrap();
531            // 2. send CleartextPassword request (auth_type=3)
532            let _ = s.write_all(&pg_auth(3, &[]));
533            // 3. read password message
534            let _ = s.read(&mut buf).unwrap();
535            // 4. send AuthOk + params + ready
536            let _ = s.write_all(&post_auth_ok());
537            // keep connection alive for queries
538            loop {
539                match s.read(&mut buf) {
540                    Ok(0) | Err(_) => break,
541                    Ok(_) => {
542                        let _ = s.write_all(&simple_query_response());
543                    }
544                }
545            }
546        });
547        thread::sleep(Duration::from_millis(30));
548        port
549    }
550
551    /// Spawn a mock PG server that does **MD5** auth.
552    fn spawn_md5_server() -> u16 {
553        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
554        let port = listener.local_addr().unwrap().port();
555        thread::spawn(move || {
556            let (mut s, _) = listener.accept().unwrap();
557            let mut buf = [0u8; 4096];
558            // 1. read startup
559            let _ = s.read(&mut buf).unwrap();
560            // 2. send MD5Password request (auth_type=5) + 4-byte salt
561            let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
562            // 3. read md5 password
563            let _ = s.read(&mut buf).unwrap();
564            // 4. send AuthOk + params + ready
565            let _ = s.write_all(&post_auth_ok());
566            loop {
567                match s.read(&mut buf) {
568                    Ok(0) | Err(_) => break,
569                    Ok(_) => {
570                        let _ = s.write_all(&simple_query_response());
571                    }
572                }
573            }
574        });
575        thread::sleep(Duration::from_millis(30));
576        port
577    }
578
579    /// Spawn a mock PG server that does **SCRAM-SHA-256** auth.
580    fn spawn_scram_server() -> u16 {
581        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
582        let port = listener.local_addr().unwrap().port();
583        thread::spawn(move || {
584            let (mut s, _) = listener.accept().unwrap();
585            let mut buf = [0u8; 4096];
586            // 1. read startup
587            let _ = s.read(&mut buf).unwrap();
588            // 2. send SASL auth (auth_type=10, mechanism)
589            let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
590            // 3. read SASLInitialResponse – extract client nonce
591            let n = s.read(&mut buf).unwrap();
592            let payload = &buf[..n];
593            // find "n=,r=" in the payload to extract client nonce
594            let text = String::from_utf8_lossy(payload);
595            let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
596            // 4. send SCRAM challenge (auth_type=11)
597            let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
598            let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
599            // 5. read SCRAM client final
600            let _ = s.read(&mut buf).unwrap();
601            // 6. send SCRAM complete (auth_type=12) + AuthOk + params + ready
602            let mut resp = Vec::new();
603            resp.extend(pg_auth(12, b"v=dummyproof"));
604            resp.extend(auth_ok());
605            resp.extend(param_status());
606            resp.extend(backend_key());
607            resp.extend(ready_for_query());
608            let _ = s.write_all(&resp);
609            loop {
610                match s.read(&mut buf) {
611                    Ok(0) | Err(_) => break,
612                    Ok(_) => {
613                        let _ = s.write_all(&simple_query_response());
614                    }
615                }
616            }
617        });
618        thread::sleep(Duration::from_millis(30));
619        port
620    }
621
622    /// Spawn a server that accepts connection then immediately closes (EOF).
623    fn spawn_eof_server() -> u16 {
624        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
625        let port = listener.local_addr().unwrap().port();
626        thread::spawn(move || {
627            let (s, _) = listener.accept().unwrap();
628            drop(s); // close immediately
629        });
630        thread::sleep(Duration::from_millis(30));
631        port
632    }
633
634    /// Spawn a server that sends an ErrorResponse after startup.
635    fn spawn_auth_error_server() -> u16 {
636        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
637        let port = listener.local_addr().unwrap().port();
638        thread::spawn(move || {
639            let (mut s, _) = listener.accept().unwrap();
640            let mut buf = [0u8; 4096];
641            let _ = s.read(&mut buf).unwrap();
642            // Send ErrorResponse
643            let mut payload = Vec::new();
644            payload.push(b'C');
645            payload.extend(b"28P01\x00");
646            payload.push(b'M');
647            payload.extend(b"password authentication failed\x00");
648            payload.push(0);
649            let _ = s.write_all(&pg_msg(b'E', &payload));
650        });
651        thread::sleep(Duration::from_millis(30));
652        port
653    }
654
655    /// Spawn a cleartext server that responds with error on query.
656    fn spawn_query_error_server() -> u16 {
657        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
658        let port = listener.local_addr().unwrap().port();
659        thread::spawn(move || {
660            let (mut s, _) = listener.accept().unwrap();
661            let mut buf = [0u8; 4096];
662            // auth
663            let _ = s.read(&mut buf).unwrap();
664            let _ = s.write_all(&pg_auth(3, &[]));
665            let _ = s.read(&mut buf).unwrap();
666            let _ = s.write_all(&post_auth_ok());
667            // query → error
668            loop {
669                match s.read(&mut buf) {
670                    Ok(0) | Err(_) => break,
671                    Ok(_) => {
672                        let _ = s.write_all(&error_response());
673                    }
674                }
675            }
676        });
677        thread::sleep(Duration::from_millis(30));
678        port
679    }
680
681    /// Spawn a cleartext server that responds with execute response.
682    fn spawn_execute_server() -> u16 {
683        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
684        let port = listener.local_addr().unwrap().port();
685        thread::spawn(move || {
686            let (mut s, _) = listener.accept().unwrap();
687            let mut buf = [0u8; 4096];
688            // auth
689            let _ = s.read(&mut buf).unwrap();
690            let _ = s.write_all(&pg_auth(3, &[]));
691            let _ = s.read(&mut buf).unwrap();
692            let _ = s.write_all(&post_auth_ok());
693            // queries
694            loop {
695                match s.read(&mut buf) {
696                    Ok(0) | Err(_) => break,
697                    Ok(_) => {
698                        let _ = s.write_all(&execute_response());
699                    }
700                }
701            }
702        });
703        thread::sleep(Duration::from_millis(30));
704        port
705    }
706
707    /// Spawn a cleartext server that responds with parameterized query response.
708    fn spawn_query_params_server() -> u16 {
709        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
710        let port = listener.local_addr().unwrap().port();
711        thread::spawn(move || {
712            let (mut s, _) = listener.accept().unwrap();
713            let mut buf = [0u8; 4096];
714            let _ = s.read(&mut buf).unwrap();
715            let _ = s.write_all(&pg_auth(3, &[]));
716            let _ = s.read(&mut buf).unwrap();
717            let _ = s.write_all(&post_auth_ok());
718            loop {
719                match s.read(&mut buf) {
720                    Ok(0) | Err(_) => break,
721                    Ok(_) => {
722                        let _ = s.write_all(&query_params_response());
723                    }
724                }
725            }
726        });
727        thread::sleep(Duration::from_millis(30));
728        port
729    }
730
731    fn spawn_params_server() -> u16 {
732        spawn_query_params_server()
733    }
734
735    /// Spawn a cleartext server that responds with parameterized execute response.
736    fn spawn_execute_params_server() -> u16 {
737        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
738        let port = listener.local_addr().unwrap().port();
739        thread::spawn(move || {
740            let (mut s, _) = listener.accept().unwrap();
741            let mut buf = [0u8; 4096];
742            let _ = s.read(&mut buf).unwrap();
743            let _ = s.write_all(&pg_auth(3, &[]));
744            let _ = s.read(&mut buf).unwrap();
745            let _ = s.write_all(&post_auth_ok());
746            loop {
747                match s.read(&mut buf) {
748                    Ok(0) | Err(_) => break,
749                    Ok(_) => {
750                        let _ = s.write_all(&execute_params_response());
751                    }
752                }
753            }
754        });
755        thread::sleep(Duration::from_millis(30));
756        port
757    }
758
759    /// Spawn a cleartext server that returns NULL in parameterized query result.
760    fn spawn_query_params_null_server() -> u16 {
761        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
762        let port = listener.local_addr().unwrap().port();
763        thread::spawn(move || {
764            let (mut s, _) = listener.accept().unwrap();
765            let mut buf = [0u8; 4096];
766            let _ = s.read(&mut buf).unwrap();
767            let _ = s.write_all(&pg_auth(3, &[]));
768            let _ = s.read(&mut buf).unwrap();
769            let _ = s.write_all(&post_auth_ok());
770            loop {
771                match s.read(&mut buf) {
772                    Ok(0) | Err(_) => break,
773                    Ok(_) => {
774                        let _ = s.write_all(&query_params_null_response());
775                    }
776                }
777            }
778        });
779        thread::sleep(Duration::from_millis(30));
780        port
781    }
782
783    // ── tests ──────────────────────────────────────────────────────────
784
785    #[test]
786    fn connect_cleartext_auth_success() {
787        let port = spawn_cleartext_server();
788        let conn = Connect::new(mock_config(port));
789        assert!(conn.is_ok());
790    }
791
792    #[test]
793    fn connect_md5_auth_success() {
794        let port = spawn_md5_server();
795        let conn = Connect::new(mock_config(port));
796        assert!(conn.is_ok());
797    }
798
799    #[test]
800    fn connect_scram_auth_success() {
801        let port = spawn_scram_server();
802        let conn = Connect::new(mock_config(port));
803        assert!(conn.is_ok());
804    }
805
806    #[test]
807    fn connect_connection_refused() {
808        // port 1 is almost certainly not listening
809        let cfg = mock_config(1);
810        let result = Connect::new(cfg);
811        assert!(result.is_err());
812        match result.unwrap_err() {
813            PgsqlError::Connection(_) => {}
814            other => panic!("expected Connection error, got {other:?}"),
815        }
816    }
817
818    #[test]
819    fn connect_server_closes_immediately() {
820        let port = spawn_eof_server();
821        let result = Connect::new(mock_config(port));
822        assert!(result.is_err());
823    }
824
825    #[test]
826    fn connect_auth_error_from_server() {
827        let port = spawn_auth_error_server();
828        let result = Connect::new(mock_config(port));
829        assert!(result.is_err());
830    }
831
832    #[test]
833    fn connect_query_success() {
834        let port = spawn_cleartext_server();
835        let mut conn = Connect::new(mock_config(port)).unwrap();
836        let result = conn.query("SELECT 1");
837        assert!(result.is_ok());
838        let msg = result.unwrap();
839        assert_eq!(msg.rows.len(), 1);
840        assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
841    }
842
843    #[test]
844    fn connect_execute_success() {
845        let port = spawn_execute_server();
846        let mut conn = Connect::new(mock_config(port)).unwrap();
847        let result = conn.execute("UPDATE t SET x=1");
848        assert!(result.is_ok());
849        let msg = result.unwrap();
850        assert_eq!(msg.affect_count, 3);
851        assert_eq!(msg.tag, "UPDATE 3");
852    }
853
/// A parameterised query returns parameter OIDs and a single row whose column
/// `p` decodes to the bound value 42.
#[test]
fn connect_query_params_success() {
    let port = spawn_query_params_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    // `expect` reports the actual error instead of a bare "assertion failed".
    let msg = conn
        .query_params("SELECT $1::int", &[Some("42")])
        .expect("query_params should succeed");
    assert!(!msg.param_oids.is_empty());
    assert_eq!(msg.rows.len(), 1);
    assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
}
865
/// A parameterised `execute` returns parameter OIDs, tag `UPDATE 1`, and an
/// affected-row count of 1.
#[test]
fn connect_execute_params_success() {
    let port = spawn_execute_params_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    // `expect` reports the actual error instead of a bare "assertion failed".
    let msg = conn
        .execute_params("UPDATE t SET x=$1", &[Some("42")])
        .expect("execute_params should succeed");
    assert!(!msg.param_oids.is_empty());
    assert_eq!(msg.affect_count, 1);
    assert_eq!(msg.tag, "UPDATE 1");
}
877
/// `query_str` (plain `&str` parameters) returns parameter OIDs and one row.
#[test]
fn connect_query_str_success() {
    let port = spawn_params_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    // `expect` reports the actual error instead of a bare "assertion failed".
    let msg = conn
        .query_str("SELECT $1::int", &["42"])
        .expect("query_str should succeed");
    assert!(!msg.param_oids.is_empty());
    assert_eq!(msg.rows.len(), 1);
}
888
/// `execute_str` (plain `&str` parameters) returns parameter OIDs and an
/// affected-row count of 1.
#[test]
fn connect_execute_str_success() {
    let port = spawn_execute_params_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    // `expect` reports the actual error instead of a bare "assertion failed".
    let msg = conn
        .execute_str("UPDATE t SET x=$1", &["1"])
        .expect("execute_str should succeed");
    assert!(!msg.param_oids.is_empty());
    assert_eq!(msg.affect_count, 1);
}
899
/// Binding a `None` (NULL) parameter succeeds, and the mock's column `n`
/// compares equal to the empty string in the resulting row.
#[test]
fn connect_query_params_with_null() {
    let port = spawn_query_params_null_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    // `expect` reports the actual error instead of a bare "assertion failed".
    let msg = conn
        .query_params("SELECT $1::text", &[None])
        .expect("query_params with a NULL parameter should succeed");
    assert!(!msg.param_oids.is_empty());
    assert_eq!(msg.rows.len(), 1);
    // NOTE(review): presumably the mock sends SQL NULL for `n`; confirm that
    // NULL really is meant to surface as "" here.
    assert_eq!(msg.rows[0]["n"], "");
}
911
/// Both an empty-string parameter and a NULL parameter are accepted by
/// `query_params` on the same connection.
#[test]
fn connect_query_params_empty_string_vs_null() {
    let port = spawn_params_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();

    // Some("") binds an empty string.
    let empty_param = conn.query_params("SELECT $1::text", &[Some("")]);
    assert!(empty_param.is_ok());

    // None binds SQL NULL.
    let null_param = conn.query_params("SELECT $1::text", &[None]);
    assert!(null_param.is_ok());

    // NOTE(review): despite the name, this only checks that both calls
    // succeed; it does not assert that the two are encoded differently.
}
925
/// A query answered by the error mock (`spawn_query_error_server`) must be
/// reported as `Err` to the caller.
#[test]
fn connect_query_returns_error() {
    let mut conn = Connect::new(mock_config(spawn_query_error_server())).unwrap();
    let outcome = conn.query("BAD SQL");
    assert!(outcome.is_err());
}
933
/// A freshly opened connection passes `is_valid`.
///
/// Under `cfg(test)` the idle threshold is zero, so this also exercises the
/// `SELECT 1` probe against the mock.
#[test]
fn connect_is_valid_true() {
    let mut conn = Connect::new(mock_config(spawn_cleartext_server())).unwrap();
    let healthy = conn.is_valid();
    assert!(healthy);
}
940
/// After `_close` (Terminate + shutdown of both directions), `is_valid` must
/// report the connection as unusable.
#[test]
fn connect_is_valid_false_after_close() {
    let mut conn = Connect::new(mock_config(spawn_cleartext_server())).unwrap();
    conn._close();
    // Either the peer_addr check or the SELECT 1 probe (always issued under
    // cfg(test)) is expected to fail on the shut-down socket.
    let still_valid = conn.is_valid();
    assert!(!still_valid);
}
949
/// `_close` is idempotent: calling it twice must not panic.
#[test]
fn connect_close_does_not_panic() {
    let port = spawn_cleartext_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    // `_close` discards write/shutdown errors, so the second call on an
    // already shut-down socket must be harmless.
    for _ in 0..2 {
        conn._close();
    }
}
958
/// Dropping a live, authenticated connection must not panic.
#[test]
fn connect_drop_does_not_panic() {
    let port = spawn_cleartext_server();
    drop(Connect::new(mock_config(port)).unwrap());
}
965
966    fn spawn_transaction_status_server() -> u16 {
967        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
968        let port = listener.local_addr().unwrap().port();
969        thread::spawn(move || {
970            let (mut s, _) = listener.accept().unwrap();
971            let mut buf = [0u8; 4096];
972            let _ = s.read(&mut buf).unwrap();
973            let _ = s.write_all(&pg_auth(3, &[]));
974            let _ = s.read(&mut buf).unwrap();
975            let _ = s.write_all(&post_auth_ok());
976            loop {
977                match s.read(&mut buf) {
978                    Ok(0) | Err(_) => break,
979                    Ok(_) => {
980                        let mut r = Vec::new();
981                        r.extend(pg_msg(b'1', &[]));
982                        r.extend(pg_msg(b'2', &[]));
983                        let mut rd = Vec::new();
984                        rd.extend(&1u16.to_be_bytes());
985                        rd.extend(b"c\x00");
986                        rd.extend(&0u32.to_be_bytes());
987                        rd.extend(&1u16.to_be_bytes());
988                        rd.extend(&23u32.to_be_bytes());
989                        rd.extend(&4i16.to_be_bytes());
990                        rd.extend(&(-1i32).to_be_bytes());
991                        rd.extend(&0u16.to_be_bytes());
992                        r.extend(pg_msg(b'T', &rd));
993                        let mut dr = Vec::new();
994                        dr.extend(&1u16.to_be_bytes());
995                        dr.extend(&1u32.to_be_bytes());
996                        dr.push(b'1');
997                        r.extend(pg_msg(b'D', &dr));
998                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
999                        r.extend(pg_msg(b'Z', b"T"));
1000                        let _ = s.write_all(&r);
1001                    }
1002                }
1003            }
1004        });
1005        thread::sleep(Duration::from_millis(30));
1006        port
1007    }
1008
1009    fn spawn_error_status_server() -> u16 {
1010        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1011        let port = listener.local_addr().unwrap().port();
1012        thread::spawn(move || {
1013            let (mut s, _) = listener.accept().unwrap();
1014            let mut buf = [0u8; 4096];
1015            let _ = s.read(&mut buf).unwrap();
1016            let _ = s.write_all(&pg_auth(3, &[]));
1017            let _ = s.read(&mut buf).unwrap();
1018            let _ = s.write_all(&post_auth_ok());
1019            loop {
1020                match s.read(&mut buf) {
1021                    Ok(0) | Err(_) => break,
1022                    Ok(_) => {
1023                        let mut r = Vec::new();
1024                        r.extend(pg_msg(b'1', &[]));
1025                        r.extend(pg_msg(b'2', &[]));
1026                        let mut rd = Vec::new();
1027                        rd.extend(&1u16.to_be_bytes());
1028                        rd.extend(b"c\x00");
1029                        rd.extend(&0u32.to_be_bytes());
1030                        rd.extend(&1u16.to_be_bytes());
1031                        rd.extend(&23u32.to_be_bytes());
1032                        rd.extend(&4i16.to_be_bytes());
1033                        rd.extend(&(-1i32).to_be_bytes());
1034                        rd.extend(&0u16.to_be_bytes());
1035                        r.extend(pg_msg(b'T', &rd));
1036                        let mut dr = Vec::new();
1037                        dr.extend(&1u16.to_be_bytes());
1038                        dr.extend(&1u32.to_be_bytes());
1039                        dr.push(b'1');
1040                        r.extend(pg_msg(b'D', &dr));
1041                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
1042                        r.extend(pg_msg(b'Z', b"E"));
1043                        let _ = s.write_all(&r);
1044                    }
1045                }
1046            }
1047        });
1048        thread::sleep(Duration::from_millis(30));
1049        port
1050    }
1051
1052    fn spawn_slow_partial_server() -> u16 {
1053        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1054        let port = listener.local_addr().unwrap().port();
1055        thread::spawn(move || {
1056            let (mut s, _) = listener.accept().unwrap();
1057            let mut buf = [0u8; 4096];
1058            let _ = s.read(&mut buf).unwrap();
1059            let _ = s.write_all(&pg_auth(3, &[]));
1060            let _ = s.read(&mut buf).unwrap();
1061            let _ = s.write_all(&post_auth_ok());
1062            match s.read(&mut buf) {
1063                Ok(0) | Err(_) => {}
1064                Ok(_) => {
1065                    let _ = s.write_all(&simple_query_response());
1066                }
1067            }
1068        });
1069        thread::sleep(Duration::from_millis(30));
1070        port
1071    }
1072
1073    fn spawn_rst_on_query_server() -> u16 {
1074        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1075        let port = listener.local_addr().unwrap().port();
1076        thread::spawn(move || {
1077            let (mut s, _) = listener.accept().unwrap();
1078            let mut buf = [0u8; 4096];
1079            let _ = s.read(&mut buf).unwrap();
1080            let _ = s.write_all(&pg_auth(3, &[]));
1081            let _ = s.read(&mut buf).unwrap();
1082            let _ = s.write_all(&post_auth_ok());
1083            match s.read(&mut buf) {
1084                Ok(0) | Err(_) => {}
1085                Ok(_) => {
1086                    drop(s);
1087                }
1088            }
1089        });
1090        thread::sleep(Duration::from_millis(30));
1091        port
1092    }
1093
/// A complete result set followed by ReadyForQuery status 'T' still counts
/// as a successful query.
#[test]
fn connect_query_ready_for_query_transaction_status() {
    let mut conn = Connect::new(mock_config(spawn_transaction_status_server())).unwrap();
    let outcome = conn.query("SELECT 1");
    assert!(outcome.is_ok());
}
1101
/// A complete result set followed by ReadyForQuery status 'E' is still
/// reported as success; the status byte alone must not turn it into an error.
#[test]
fn connect_query_ready_for_query_error_status() {
    let mut conn = Connect::new(mock_config(spawn_error_status_server())).unwrap();
    let outcome = conn.query("SELECT 1");
    assert!(outcome.is_ok());
}
1109
/// The first query succeeds; after the mock closes the socket, a second
/// query on the same connection must fail.
#[test]
fn connect_query_server_closes_after_partial() {
    let mut conn = Connect::new(mock_config(spawn_slow_partial_server())).unwrap();

    let first = conn.query("SELECT 1");
    assert!(first.is_ok());

    // The mock has hung up by now.
    let second = conn.query("SELECT 1");
    assert!(second.is_err());
}
1119
/// A query must fail when the server closes the connection instead of
/// answering it.
///
/// NOTE(review): despite its name, this only checks `is_err`, not the
/// specific error variant.
#[test]
fn connect_query_server_rst_returns_io_or_connection_error() {
    let mut conn = Connect::new(mock_config(spawn_rst_on_query_server())).unwrap();
    let outcome = conn.query("SELECT 1");
    assert!(outcome.is_err());
}
1127
/// With a 1 ms read timeout and a server that stays silent after auth, a query
/// must fail with an error whose text mentions a timeout or retries.
#[test]
fn connect_read_would_block_max_retries() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut s, _) = listener.accept().unwrap();
        let mut buf = [0u8; 4096];
        let _ = s.read(&mut buf);
        let _ = s.write_all(&pg_auth(3, &[]));
        let _ = s.read(&mut buf);
        let _ = s.write_all(&post_auth_ok());
        let _ = s.read(&mut buf);
        // Keep the socket open but silent, so the client's reads time out
        // instead of hitting EOF.
        thread::sleep(Duration::from_secs(5));
    });
    thread::sleep(Duration::from_millis(30));

    let mut conn = Connect::new(mock_config(port)).unwrap();
    // The whole test depends on this timeout being installed. Previously the
    // error was discarded with `.ok()`; fail loudly instead of letting the
    // query block until the server thread exits.
    conn.stream
        .set_read_timeout(Some(Duration::from_millis(1)))
        .expect("failed to set a 1 ms read timeout on the client socket");
    let result = conn.query("SELECT 1");
    assert!(result.is_err());
    let err_str = result.unwrap_err().to_string();
    assert!(
        err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
        "expected timeout error, got: {err_str}"
    );
}
1156
/// A reply made of 256 `'X'` bytes must be rejected with a message-size error.
///
/// Assuming the client parses a 1-byte tag followed by a 4-byte big-endian
/// length, the bytes declare a length of 0x58585858, far beyond any sane limit.
#[test]
fn connect_read_exceeds_max_message_size() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut sock, _) = listener.accept().unwrap();
        let mut scratch = [0u8; 4096];
        let _ = sock.read(&mut scratch);
        let _ = sock.write_all(&pg_auth(3, &[]));
        let _ = sock.read(&mut scratch);
        let _ = sock.write_all(&post_auth_ok());
        let _ = sock.read(&mut scratch);
        let payload = [b'X'; 256];
        let _ = sock.write_all(&payload);
        // Hold the socket open so the failure comes from the size check,
        // not from EOF.
        thread::sleep(Duration::from_secs(2));
    });
    thread::sleep(Duration::from_millis(30));

    let mut conn = Connect::new(mock_config(port)).unwrap();
    let result = conn.query("SELECT 1");
    assert!(result.is_err());
    let err_str = result.unwrap_err().to_string();
    let mentions_size =
        err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size");
    assert!(
        mentions_size,
        "expected max message size error, got: {err_str}"
    );
}
1184
/// A server that drips one `'X'` byte every 5 ms (about 1 s in total) after
/// auth must make the query fail.
///
/// NOTE(review): this only checks `is_err`. Depending on how the client parses
/// the header, the failure may come from the message-size check rather than
/// a read deadline.
#[test]
fn connect_read_deadline_timeout() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut sock, _) = listener.accept().unwrap();
        let mut scratch = [0u8; 4096];
        let _ = sock.read(&mut scratch);
        let _ = sock.write_all(&pg_auth(3, &[]));
        let _ = sock.read(&mut scratch);
        let _ = sock.write_all(&post_auth_ok());
        let _ = sock.read(&mut scratch);
        let pause = Duration::from_millis(5);
        (0..200).for_each(|_| {
            let _ = sock.write_all(b"X");
            thread::sleep(pause);
        });
    });
    thread::sleep(Duration::from_millis(30));

    let mut conn = Connect::new(mock_config(port)).unwrap();
    let outcome = conn.query("SELECT 1");
    assert!(outcome.is_err());
}
1208
/// The auth request is delivered in two TCP writes, split after its first
/// 5 bytes with a 50 ms gap. The client must reassemble the frame and still
/// connect and query successfully.
#[test]
fn connect_read_partial_auth_frame() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut sock, _) = listener.accept().unwrap();
        let mut scratch = [0u8; 4096];
        let _ = sock.read(&mut scratch);
        let frame = pg_auth(3, &[]);
        let (head, tail) = frame.split_at(5);
        let _ = sock.write_all(head);
        thread::sleep(Duration::from_millis(50));
        let _ = sock.write_all(tail);
        let _ = sock.read(&mut scratch);
        let _ = sock.write_all(&post_auth_ok());
        // Answer every query until EOF or a read error.
        while let Ok(n) = sock.read(&mut scratch) {
            if n == 0 {
                break;
            }
            let _ = sock.write_all(&simple_query_response());
        }
    });
    thread::sleep(Duration::from_millis(30));

    let mut conn = Connect::new(mock_config(port)).unwrap();
    let outcome = conn.query("SELECT 1");
    assert!(outcome.is_ok());
}
1238}