Skip to main content

br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::time::{Duration, Instant};
7
/// A single blocking PostgreSQL connection over a `TcpStream`.
#[derive(Debug)]
pub struct Connect {
    /// Underlying TCP socket to the server.
    pub(crate) stream: TcpStream,
    /// Protocol state used to build outgoing messages (`pack_*`) and decode replies (`unpack`).
    packet: Packet,
    /// Authentication progress; `read` uses it to choose how to detect the end of a reply.
    auth_status: AuthStatus,
    /// Time of last use, consulted by the lazy health check.
    last_used: Instant,
}
16
17impl Connect {
18    /// 懒健康检查:peer_addr 检测 TCP 存活,空闲超过 30 秒才发 SELECT 1
19    pub fn is_valid(&mut self) -> bool {
20        if self.stream.peer_addr().is_err() {
21            return false;
22        }
23        #[cfg(not(test))]
24        const IDLE_THRESHOLD: Duration = Duration::from_secs(30);
25        #[cfg(test)]
26        const IDLE_THRESHOLD: Duration = Duration::from_millis(0);
27        if self.last_used.elapsed() > IDLE_THRESHOLD {
28            return self.query("SELECT 1").is_ok();
29        }
30        true
31    }
32
33    /// 仅检查 TCP 连接是否存活(不发送查询)
34    pub fn peer_valid(&self) -> bool {
35        self.stream.peer_addr().is_ok()
36    }
37
38    /// 更新最后使用时间
39    pub fn touch(&mut self) {
40        self.last_used = Instant::now();
41    }
42
43    /// 返回空闲时长
44    pub fn idle_elapsed(&self) -> Duration {
45        self.last_used.elapsed()
46    }
47
48    pub fn _close(&mut self) {
49        let _ = (&self.stream).write_all(&Packet::pack_terminate());
50        let _ = self.stream.shutdown(std::net::Shutdown::Both);
51    }
52
53    /// 设置 TCP Keepalive(60秒间隔,3次探测)
54    fn set_keepalive(stream: &TcpStream) -> Result<(), PgsqlError> {
55        use std::os::unix::io::{AsRawFd, FromRawFd};
56        let fd = stream.as_raw_fd();
57        let socket = unsafe { socket2::Socket::from_raw_fd(fd) };
58        let keepalive = socket2::TcpKeepalive::new()
59            .with_time(Duration::from_secs(60))
60            .with_interval(Duration::from_secs(15))
61            .with_retries(3);
62        let result = socket.set_tcp_keepalive(&keepalive);
63        // 不要让 socket2::Socket drop 关闭 fd,转回原始 fd
64        std::mem::forget(socket);
65        result.map_err(|e| PgsqlError::Connection(format!("设置 TCP Keepalive 失败: {}", e)))
66    }
67
68    pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
69        let stream =
70            TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
71        // TCP 优化:禁用 Nagle 算法,减少小包延迟
72        stream
73            .set_nodelay(true)
74            .map_err(|e| PgsqlError::Connection(format!("设置 TCP_NODELAY 失败: {}", e)))?;
75        // TCP Keepalive:防止空闲连接被防火墙/NAT/PG服务端静默断开
76        Self::set_keepalive(&stream)?;
77        stream
78            .set_read_timeout(Some(Duration::from_secs(30)))
79            .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
80        stream
81            .set_write_timeout(Some(Duration::from_secs(30)))
82            .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
83        let _ = stream.peer_addr();
84
85        let mut connect = Self {
86            stream,
87            packet: Packet::new(config),
88            auth_status: AuthStatus::None,
89            last_used: Instant::now(),
90        };
91
92        connect.authenticate()?;
93
94        Ok(connect)
95    }
96
97    fn authenticate(&mut self) -> Result<(), PgsqlError> {
98        (&self.stream)
99            .write_all(&self.packet.pack_first())
100            .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
101
102        let data = self.read()?;
103        self.packet.unpack(data, 0)?;
104
105        if !self.packet.md5_salt.is_empty() {
106            self.md5_auth()?;
107        } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
108            self.cleartext_auth()?;
109        } else {
110            self.scram_auth()?;
111        }
112
113        self.auth_status = AuthStatus::AuthenticationOk;
114        Ok(())
115    }
116
117    fn md5_auth(&mut self) -> Result<(), PgsqlError> {
118        (&self.stream)
119            .write_all(&self.packet.pack_md5_password())
120            .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
121
122        let data = self.read()?;
123        self.packet.unpack(data, 0)?;
124        Ok(())
125    }
126
127    fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
128        (&self.stream)
129            .write_all(&self.packet.pack_cleartext_password())
130            .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
131
132        let data = self.read()?;
133        self.packet.unpack(data, 0)?;
134        Ok(())
135    }
136
137    fn scram_auth(&mut self) -> Result<(), PgsqlError> {
138        (&self.stream)
139            .write_all(&self.packet.pack_auth())
140            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
141
142        let data = self.read()?;
143        self.packet.unpack(data, 0)?;
144
145        (&self.stream)
146            .write_all(&self.packet.pack_auth_verify())
147            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
148
149        let data = self.read()?;
150        self.packet.unpack(data, 0)?;
151        Ok(())
152    }
153
154    fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
155        let mut msg = Vec::new();
156        let mut buf = [0u8; 4096];
157        let mut retry_count = 0;
158
159        #[cfg(not(test))]
160        const MAX_RETRIES: u32 = 100;
161        #[cfg(test)]
162        const MAX_RETRIES: u32 = 3;
163
164        #[cfg(not(test))]
165        const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
166        #[cfg(test)]
167        const MAX_MESSAGE_SIZE: usize = 128;
168
169        #[cfg(not(test))]
170        let deadline = std::time::Instant::now() + Duration::from_secs(300);
171        #[cfg(test)]
172        let deadline = std::time::Instant::now() + Duration::from_millis(200);
173
174        loop {
175            if std::time::Instant::now() >= deadline {
176                return Err(PgsqlError::Timeout("读取总超时".into()));
177            }
178
179            match (&self.stream).read(&mut buf) {
180                Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
181                Ok(n) => {
182                    if msg.len() + n > MAX_MESSAGE_SIZE {
183                        return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
184                    }
185                    msg.extend_from_slice(&buf[..n]);
186                    retry_count = 0;
187                }
188                Err(ref e)
189                    if e.kind() == std::io::ErrorKind::WouldBlock
190                        || e.kind() == std::io::ErrorKind::TimedOut =>
191                {
192                    retry_count += 1;
193                    if retry_count > MAX_RETRIES {
194                        return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
195                    }
196                    std::thread::sleep(Duration::from_millis(10));
197                    continue;
198                }
199                Err(e) => return Err(PgsqlError::Io(e)),
200            };
201
202            if let AuthStatus::AuthenticationOk = self.auth_status {
203                if msg.ends_with(&[90, 0, 0, 0, 5, 73])
204                    || msg.ends_with(&[90, 0, 0, 0, 5, 84])
205                    || msg.ends_with(&[90, 0, 0, 0, 5, 69])
206                {
207                    break;
208                }
209            } else if msg.len() >= 5 {
210                let len_bytes = &msg[1..=4];
211                if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
212                    if msg.len() > len as usize {
213                        break;
214                    }
215                }
216            }
217        }
218
219        Ok(msg)
220    }
221
222    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
223        (&self.stream)
224            .write_all(&self.packet.pack_query(sql))
225            .map_err(PgsqlError::Io)?;
226        let data = self.read()?;
227        self.last_used = Instant::now();
228        self.packet.unpack(data, 0)
229    }
230
231    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
232        (&self.stream)
233            .write_all(&self.packet.pack_execute(sql))
234            .map_err(PgsqlError::Io)?;
235        let data = self.read()?;
236        self.last_used = Instant::now();
237        self.packet.unpack(data, 0)
238    }
239
240    /// 参数化查询
241    pub fn query_params(
242        &mut self,
243        sql: &str,
244        params: &[Option<&str>],
245    ) -> Result<SuccessMessage, PgsqlError> {
246        (&self.stream)
247            .write_all(&self.packet.pack_query_params(sql, params))
248            .map_err(PgsqlError::Io)?;
249
250        let data = self.read()?;
251        self.last_used = Instant::now();
252        self.packet.unpack(data, 0)
253    }
254
255    /// 参数化执行
256    pub fn execute_params(
257        &mut self,
258        sql: &str,
259        params: &[Option<&str>],
260    ) -> Result<SuccessMessage, PgsqlError> {
261        (&self.stream)
262            .write_all(&self.packet.pack_execute_params(sql, params))
263            .map_err(PgsqlError::Io)?;
264        let data = self.read()?;
265        self.last_used = Instant::now();
266        self.packet.unpack(data, 0)
267    }
268
269    /// 参数化查询(便捷版,所有参数非 NULL)
270    pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
271        let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
272        self.query_params(sql, &opts)
273    }
274
275    /// 参数化执行(便捷版,所有参数非 NULL)
276    pub fn execute_str(
277        &mut self,
278        sql: &str,
279        params: &[&str],
280    ) -> Result<SuccessMessage, PgsqlError> {
281        let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
282        self.execute_params(sql, &opts)
283    }
284}
285
286impl Drop for Connect {
287    fn drop(&mut self) {
288        let _ = (&self.stream).write_all(&Packet::pack_terminate());
289        let _ = self.stream.shutdown(std::net::Shutdown::Both);
290    }
291}
292
293#[cfg(test)]
294mod tests {
295    use super::*;
296    use std::net::TcpListener;
297    use std::thread;
298
299    // ── wire-protocol helpers ──────────────────────────────────────────
300
301    /// Build a single PG backend message: type_byte | len(4) | payload
302    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
303        let mut m = Vec::with_capacity(5 + payload.len());
304        m.push(tag);
305        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
306        m.extend_from_slice(payload);
307        m
308    }
309
310    /// Build an Authentication message (tag 'R') with given auth_type + extra payload
311    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
312        let mut body = Vec::new();
313        body.extend(&auth_type.to_be_bytes());
314        body.extend_from_slice(extra);
315        pg_msg(b'R', &body)
316    }
317
318    /// AuthenticationOk (R, type=0)
319    fn auth_ok() -> Vec<u8> {
320        pg_auth(0, &[])
321    }
322
323    /// ParameterStatus for server_version=15.0
324    fn param_status() -> Vec<u8> {
325        pg_msg(b'S', b"server_version\x0015.0\x00")
326    }
327
328    /// BackendKeyData (process_id=1, secret_key=2)
329    fn backend_key() -> Vec<u8> {
330        let mut p = Vec::new();
331        p.extend(&1u32.to_be_bytes());
332        p.extend(&2u32.to_be_bytes());
333        pg_msg(b'K', &p)
334    }
335
336    /// ReadyForQuery (status = Idle)
337    fn ready_for_query() -> Vec<u8> {
338        pg_msg(b'Z', b"I")
339    }
340
341    /// The standard tail sent after auth succeeds: AuthOk + param + key + ready
342    fn post_auth_ok() -> Vec<u8> {
343        let mut v = Vec::new();
344        v.extend(auth_ok());
345        v.extend(param_status());
346        v.extend(backend_key());
347        v.extend(ready_for_query());
348        v
349    }
350
351    /// Build a simple query response: ParseComplete + BindComplete +
352    /// RowDescription(1 int4 col "c") + DataRow("1") + CommandComplete("SELECT 1") + ReadyForQuery
353    fn simple_query_response() -> Vec<u8> {
354        let mut r = Vec::new();
355        // ParseComplete
356        r.extend(pg_msg(b'1', &[]));
357        // BindComplete
358        r.extend(pg_msg(b'2', &[]));
359        // RowDescription – 1 field "c", type_oid=23 (int4)
360        let mut rd = Vec::new();
361        rd.extend(&1u16.to_be_bytes()); // field count
362        rd.extend(b"c\x00"); // name
363        rd.extend(&0u32.to_be_bytes()); // table oid
364        rd.extend(&1u16.to_be_bytes()); // column index
365        rd.extend(&23u32.to_be_bytes()); // type oid (int4)
366        rd.extend(&4i16.to_be_bytes()); // column length
367        rd.extend(&(-1i32).to_be_bytes()); // type modifier
368        rd.extend(&0u16.to_be_bytes()); // format (text)
369        r.extend(pg_msg(b'T', &rd));
370        // DataRow – 1 field, value "1"
371        let mut dr = Vec::new();
372        dr.extend(&1u16.to_be_bytes());
373        dr.extend(&1u32.to_be_bytes()); // length of value
374        dr.push(b'1');
375        r.extend(pg_msg(b'D', &dr));
376        // CommandComplete
377        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
378        // ReadyForQuery
379        r.extend(ready_for_query());
380        r
381    }
382
383    /// Build an execute response (no rows): ParseComplete + BindComplete +
384    /// NoData + CommandComplete("UPDATE 3") + ReadyForQuery
385    fn execute_response() -> Vec<u8> {
386        let mut r = Vec::new();
387        r.extend(pg_msg(b'1', &[]));
388        r.extend(pg_msg(b'2', &[]));
389        r.extend(pg_msg(b'n', &[])); // NoData
390        r.extend(pg_msg(b'C', b"UPDATE 3\x00"));
391        r.extend(ready_for_query());
392        r
393    }
394
395    /// Build a parameterized query response: ParseComplete + ParameterDescription + BindComplete +
396    /// RowDescription(1 int4 col "p") + DataRow("42") + CommandComplete("SELECT 1") + ReadyForQuery
397    fn query_params_response() -> Vec<u8> {
398        let mut r = Vec::new();
399        r.extend(pg_msg(b'1', &[]));
400
401        let mut pd = Vec::new();
402        pd.extend(&1u16.to_be_bytes());
403        pd.extend(&23u32.to_be_bytes());
404        r.extend(pg_msg(b't', &pd));
405
406        r.extend(pg_msg(b'2', &[]));
407
408        let mut rd = Vec::new();
409        rd.extend(&1u16.to_be_bytes());
410        rd.extend(b"p\x00");
411        rd.extend(&0u32.to_be_bytes());
412        rd.extend(&1u16.to_be_bytes());
413        rd.extend(&23u32.to_be_bytes());
414        rd.extend(&4i16.to_be_bytes());
415        rd.extend(&(-1i32).to_be_bytes());
416        rd.extend(&0u16.to_be_bytes());
417        r.extend(pg_msg(b'T', &rd));
418
419        let mut dr = Vec::new();
420        dr.extend(&1u16.to_be_bytes());
421        dr.extend(&2u32.to_be_bytes());
422        dr.extend(b"42");
423        r.extend(pg_msg(b'D', &dr));
424
425        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
426        r.extend(ready_for_query());
427        r
428    }
429
430    /// Build a parameterized execute response: ParseComplete + ParameterDescription + BindComplete +
431    /// NoData + CommandComplete("UPDATE 1") + ReadyForQuery
432    fn execute_params_response() -> Vec<u8> {
433        let mut r = Vec::new();
434        r.extend(pg_msg(b'1', &[]));
435
436        let mut pd = Vec::new();
437        pd.extend(&1u16.to_be_bytes());
438        pd.extend(&23u32.to_be_bytes());
439        r.extend(pg_msg(b't', &pd));
440
441        r.extend(pg_msg(b'2', &[]));
442        r.extend(pg_msg(b'n', &[]));
443        r.extend(pg_msg(b'C', b"UPDATE 1\x00"));
444        r.extend(ready_for_query());
445        r
446    }
447
448    /// Build a parameterized query response with NULL row value.
449    fn query_params_null_response() -> Vec<u8> {
450        let mut r = Vec::new();
451        r.extend(pg_msg(b'1', &[]));
452
453        let mut pd = Vec::new();
454        pd.extend(&1u16.to_be_bytes());
455        pd.extend(&25u32.to_be_bytes());
456        r.extend(pg_msg(b't', &pd));
457
458        r.extend(pg_msg(b'2', &[]));
459
460        let mut rd = Vec::new();
461        rd.extend(&1u16.to_be_bytes());
462        rd.extend(b"n\x00");
463        rd.extend(&0u32.to_be_bytes());
464        rd.extend(&1u16.to_be_bytes());
465        rd.extend(&25u32.to_be_bytes());
466        rd.extend(&(-1i16).to_be_bytes());
467        rd.extend(&(-1i32).to_be_bytes());
468        rd.extend(&0u16.to_be_bytes());
469        r.extend(pg_msg(b'T', &rd));
470
471        let mut dr = Vec::new();
472        dr.extend(&1u16.to_be_bytes());
473        dr.extend(&(-1i32).to_be_bytes());
474        r.extend(pg_msg(b'D', &dr));
475
476        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
477        r.extend(ready_for_query());
478        r
479    }
480
481    /// Build an ErrorResponse for query phase
482    fn error_response() -> Vec<u8> {
483        let mut payload = Vec::new();
484        payload.push(b'C');
485        payload.extend(b"42601\x00");
486        payload.push(b'M');
487        payload.extend(b"syntax error\x00");
488        payload.push(0);
489        let mut r = Vec::new();
490        r.extend(pg_msg(b'1', &[]));
491        r.extend(pg_msg(b'2', &[]));
492        r.extend(pg_msg(b'E', &payload));
493        r.extend(ready_for_query());
494        r
495    }
496
497    // ── mock server spawners ───────────────────────────────────────────
498
499    /// Config pointing at 127.0.0.1:<port>
500    fn mock_config(port: u16) -> Config {
501        Config {
502            debug: false,
503            hostname: "127.0.0.1".into(),
504            hostport: port as i32,
505            username: "u".into(),
506            userpass: "p".into(),
507            database: "d".into(),
508            charset: "utf8".into(),
509            pool_max: 5,
510        }
511    }
512
513    /// Spawn a mock PG server that does **cleartext** auth.
514    /// Returns the port.  The server handles one connection.
515    fn spawn_cleartext_server() -> u16 {
516        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
517        let port = listener.local_addr().unwrap().port();
518        thread::spawn(move || {
519            let (mut s, _) = listener.accept().unwrap();
520            let mut buf = [0u8; 4096];
521            // 1. read startup
522            let _ = s.read(&mut buf).unwrap();
523            // 2. send CleartextPassword request (auth_type=3)
524            let _ = s.write_all(&pg_auth(3, &[]));
525            // 3. read password message
526            let _ = s.read(&mut buf).unwrap();
527            // 4. send AuthOk + params + ready
528            let _ = s.write_all(&post_auth_ok());
529            // keep connection alive for queries
530            loop {
531                match s.read(&mut buf) {
532                    Ok(0) | Err(_) => break,
533                    Ok(_) => {
534                        let _ = s.write_all(&simple_query_response());
535                    }
536                }
537            }
538        });
539        thread::sleep(Duration::from_millis(30));
540        port
541    }
542
543    /// Spawn a mock PG server that does **MD5** auth.
544    fn spawn_md5_server() -> u16 {
545        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
546        let port = listener.local_addr().unwrap().port();
547        thread::spawn(move || {
548            let (mut s, _) = listener.accept().unwrap();
549            let mut buf = [0u8; 4096];
550            // 1. read startup
551            let _ = s.read(&mut buf).unwrap();
552            // 2. send MD5Password request (auth_type=5) + 4-byte salt
553            let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
554            // 3. read md5 password
555            let _ = s.read(&mut buf).unwrap();
556            // 4. send AuthOk + params + ready
557            let _ = s.write_all(&post_auth_ok());
558            loop {
559                match s.read(&mut buf) {
560                    Ok(0) | Err(_) => break,
561                    Ok(_) => {
562                        let _ = s.write_all(&simple_query_response());
563                    }
564                }
565            }
566        });
567        thread::sleep(Duration::from_millis(30));
568        port
569    }
570
571    /// Spawn a mock PG server that does **SCRAM-SHA-256** auth.
572    fn spawn_scram_server() -> u16 {
573        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
574        let port = listener.local_addr().unwrap().port();
575        thread::spawn(move || {
576            let (mut s, _) = listener.accept().unwrap();
577            let mut buf = [0u8; 4096];
578            // 1. read startup
579            let _ = s.read(&mut buf).unwrap();
580            // 2. send SASL auth (auth_type=10, mechanism)
581            let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
582            // 3. read SASLInitialResponse – extract client nonce
583            let n = s.read(&mut buf).unwrap();
584            let payload = &buf[..n];
585            // find "n=,r=" in the payload to extract client nonce
586            let text = String::from_utf8_lossy(payload);
587            let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
588            // 4. send SCRAM challenge (auth_type=11)
589            let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
590            let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
591            // 5. read SCRAM client final
592            let _ = s.read(&mut buf).unwrap();
593            // 6. send SCRAM complete (auth_type=12) + AuthOk + params + ready
594            let mut resp = Vec::new();
595            resp.extend(pg_auth(12, b"v=dummyproof"));
596            resp.extend(auth_ok());
597            resp.extend(param_status());
598            resp.extend(backend_key());
599            resp.extend(ready_for_query());
600            let _ = s.write_all(&resp);
601            loop {
602                match s.read(&mut buf) {
603                    Ok(0) | Err(_) => break,
604                    Ok(_) => {
605                        let _ = s.write_all(&simple_query_response());
606                    }
607                }
608            }
609        });
610        thread::sleep(Duration::from_millis(30));
611        port
612    }
613
614    /// Spawn a server that accepts connection then immediately closes (EOF).
615    fn spawn_eof_server() -> u16 {
616        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
617        let port = listener.local_addr().unwrap().port();
618        thread::spawn(move || {
619            let (s, _) = listener.accept().unwrap();
620            drop(s); // close immediately
621        });
622        thread::sleep(Duration::from_millis(30));
623        port
624    }
625
626    /// Spawn a server that sends an ErrorResponse after startup.
627    fn spawn_auth_error_server() -> u16 {
628        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
629        let port = listener.local_addr().unwrap().port();
630        thread::spawn(move || {
631            let (mut s, _) = listener.accept().unwrap();
632            let mut buf = [0u8; 4096];
633            let _ = s.read(&mut buf).unwrap();
634            // Send ErrorResponse
635            let mut payload = Vec::new();
636            payload.push(b'C');
637            payload.extend(b"28P01\x00");
638            payload.push(b'M');
639            payload.extend(b"password authentication failed\x00");
640            payload.push(0);
641            let _ = s.write_all(&pg_msg(b'E', &payload));
642        });
643        thread::sleep(Duration::from_millis(30));
644        port
645    }
646
647    /// Spawn a cleartext server that responds with error on query.
648    fn spawn_query_error_server() -> u16 {
649        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
650        let port = listener.local_addr().unwrap().port();
651        thread::spawn(move || {
652            let (mut s, _) = listener.accept().unwrap();
653            let mut buf = [0u8; 4096];
654            // auth
655            let _ = s.read(&mut buf).unwrap();
656            let _ = s.write_all(&pg_auth(3, &[]));
657            let _ = s.read(&mut buf).unwrap();
658            let _ = s.write_all(&post_auth_ok());
659            // query → error
660            loop {
661                match s.read(&mut buf) {
662                    Ok(0) | Err(_) => break,
663                    Ok(_) => {
664                        let _ = s.write_all(&error_response());
665                    }
666                }
667            }
668        });
669        thread::sleep(Duration::from_millis(30));
670        port
671    }
672
673    /// Spawn a cleartext server that responds with execute response.
674    fn spawn_execute_server() -> u16 {
675        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
676        let port = listener.local_addr().unwrap().port();
677        thread::spawn(move || {
678            let (mut s, _) = listener.accept().unwrap();
679            let mut buf = [0u8; 4096];
680            // auth
681            let _ = s.read(&mut buf).unwrap();
682            let _ = s.write_all(&pg_auth(3, &[]));
683            let _ = s.read(&mut buf).unwrap();
684            let _ = s.write_all(&post_auth_ok());
685            // queries
686            loop {
687                match s.read(&mut buf) {
688                    Ok(0) | Err(_) => break,
689                    Ok(_) => {
690                        let _ = s.write_all(&execute_response());
691                    }
692                }
693            }
694        });
695        thread::sleep(Duration::from_millis(30));
696        port
697    }
698
699    /// Spawn a cleartext server that responds with parameterized query response.
700    fn spawn_query_params_server() -> u16 {
701        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
702        let port = listener.local_addr().unwrap().port();
703        thread::spawn(move || {
704            let (mut s, _) = listener.accept().unwrap();
705            let mut buf = [0u8; 4096];
706            let _ = s.read(&mut buf).unwrap();
707            let _ = s.write_all(&pg_auth(3, &[]));
708            let _ = s.read(&mut buf).unwrap();
709            let _ = s.write_all(&post_auth_ok());
710            loop {
711                match s.read(&mut buf) {
712                    Ok(0) | Err(_) => break,
713                    Ok(_) => {
714                        let _ = s.write_all(&query_params_response());
715                    }
716                }
717            }
718        });
719        thread::sleep(Duration::from_millis(30));
720        port
721    }
722
723    fn spawn_params_server() -> u16 {
724        spawn_query_params_server()
725    }
726
727    /// Spawn a cleartext server that responds with parameterized execute response.
728    fn spawn_execute_params_server() -> u16 {
729        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
730        let port = listener.local_addr().unwrap().port();
731        thread::spawn(move || {
732            let (mut s, _) = listener.accept().unwrap();
733            let mut buf = [0u8; 4096];
734            let _ = s.read(&mut buf).unwrap();
735            let _ = s.write_all(&pg_auth(3, &[]));
736            let _ = s.read(&mut buf).unwrap();
737            let _ = s.write_all(&post_auth_ok());
738            loop {
739                match s.read(&mut buf) {
740                    Ok(0) | Err(_) => break,
741                    Ok(_) => {
742                        let _ = s.write_all(&execute_params_response());
743                    }
744                }
745            }
746        });
747        thread::sleep(Duration::from_millis(30));
748        port
749    }
750
751    /// Spawn a cleartext server that returns NULL in parameterized query result.
752    fn spawn_query_params_null_server() -> u16 {
753        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
754        let port = listener.local_addr().unwrap().port();
755        thread::spawn(move || {
756            let (mut s, _) = listener.accept().unwrap();
757            let mut buf = [0u8; 4096];
758            let _ = s.read(&mut buf).unwrap();
759            let _ = s.write_all(&pg_auth(3, &[]));
760            let _ = s.read(&mut buf).unwrap();
761            let _ = s.write_all(&post_auth_ok());
762            loop {
763                match s.read(&mut buf) {
764                    Ok(0) | Err(_) => break,
765                    Ok(_) => {
766                        let _ = s.write_all(&query_params_null_response());
767                    }
768                }
769            }
770        });
771        thread::sleep(Duration::from_millis(30));
772        port
773    }
774
775    // ── tests ──────────────────────────────────────────────────────────
776
777    #[test]
778    fn connect_cleartext_auth_success() {
779        let port = spawn_cleartext_server();
780        let conn = Connect::new(mock_config(port));
781        assert!(conn.is_ok());
782    }
783
784    #[test]
785    fn connect_md5_auth_success() {
786        let port = spawn_md5_server();
787        let conn = Connect::new(mock_config(port));
788        assert!(conn.is_ok());
789    }
790
791    #[test]
792    fn connect_scram_auth_success() {
793        let port = spawn_scram_server();
794        let conn = Connect::new(mock_config(port));
795        assert!(conn.is_ok());
796    }
797
798    #[test]
799    fn connect_connection_refused() {
800        // port 1 is almost certainly not listening
801        let cfg = mock_config(1);
802        let result = Connect::new(cfg);
803        assert!(result.is_err());
804        match result.unwrap_err() {
805            PgsqlError::Connection(_) => {}
806            other => panic!("expected Connection error, got {other:?}"),
807        }
808    }
809
810    #[test]
811    fn connect_server_closes_immediately() {
812        let port = spawn_eof_server();
813        let result = Connect::new(mock_config(port));
814        assert!(result.is_err());
815    }
816
817    #[test]
818    fn connect_auth_error_from_server() {
819        let port = spawn_auth_error_server();
820        let result = Connect::new(mock_config(port));
821        assert!(result.is_err());
822    }
823
824    #[test]
825    fn connect_query_success() {
826        let port = spawn_cleartext_server();
827        let mut conn = Connect::new(mock_config(port)).unwrap();
828        let result = conn.query("SELECT 1");
829        assert!(result.is_ok());
830        let msg = result.unwrap();
831        assert_eq!(msg.rows.len(), 1);
832        assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
833    }
834
835    #[test]
836    fn connect_execute_success() {
837        let port = spawn_execute_server();
838        let mut conn = Connect::new(mock_config(port)).unwrap();
839        let result = conn.execute("UPDATE t SET x=1");
840        assert!(result.is_ok());
841        let msg = result.unwrap();
842        assert_eq!(msg.affect_count, 3);
843        assert_eq!(msg.tag, "UPDATE 3");
844    }
845
846    #[test]
847    fn connect_query_params_success() {
848        let port = spawn_query_params_server();
849        let mut conn = Connect::new(mock_config(port)).unwrap();
850        let result = conn.query_params("SELECT $1::int", &[Some("42")]);
851        assert!(result.is_ok());
852        let msg = result.unwrap();
853        assert!(!msg.param_oids.is_empty());
854        assert_eq!(msg.rows.len(), 1);
855        assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
856    }
857
858    #[test]
859    fn connect_execute_params_success() {
860        let port = spawn_execute_params_server();
861        let mut conn = Connect::new(mock_config(port)).unwrap();
862        let result = conn.execute_params("UPDATE t SET x=$1", &[Some("42")]);
863        assert!(result.is_ok());
864        let msg = result.unwrap();
865        assert!(!msg.param_oids.is_empty());
866        assert_eq!(msg.affect_count, 1);
867        assert_eq!(msg.tag, "UPDATE 1");
868    }
869
870    #[test]
871    fn connect_query_str_success() {
872        let port = spawn_params_server();
873        let mut conn = Connect::new(mock_config(port)).unwrap();
874        let result = conn.query_str("SELECT $1::int", &["42"]);
875        assert!(result.is_ok());
876        let msg = result.unwrap();
877        assert!(!msg.param_oids.is_empty());
878        assert_eq!(msg.rows.len(), 1);
879    }
880
881    #[test]
882    fn connect_execute_str_success() {
883        let port = spawn_execute_params_server();
884        let mut conn = Connect::new(mock_config(port)).unwrap();
885        let result = conn.execute_str("UPDATE t SET x=$1", &["1"]);
886        assert!(result.is_ok());
887        let msg = result.unwrap();
888        assert!(!msg.param_oids.is_empty());
889        assert_eq!(msg.affect_count, 1);
890    }
891
892    #[test]
893    fn connect_query_params_with_null() {
894        let port = spawn_query_params_null_server();
895        let mut conn = Connect::new(mock_config(port)).unwrap();
896        let result = conn.query_params("SELECT $1::text", &[None]);
897        assert!(result.is_ok());
898        let msg = result.unwrap();
899        assert!(!msg.param_oids.is_empty());
900        assert_eq!(msg.rows.len(), 1);
901        assert_eq!(msg.rows[0]["n"], "");
902    }
903
904    #[test]
905    fn connect_query_params_empty_string_vs_null() {
906        let port = spawn_params_server();
907        let mut conn = Connect::new(mock_config(port)).unwrap();
908
909        // 空字符串参数
910        let r1 = conn.query_params("SELECT $1::text", &[Some("")]);
911        assert!(r1.is_ok());
912
913        // NULL 参数
914        let r2 = conn.query_params("SELECT $1::text", &[None]);
915        assert!(r2.is_ok());
916    }
917
918    #[test]
919    fn connect_query_returns_error() {
920        let port = spawn_query_error_server();
921        let mut conn = Connect::new(mock_config(port)).unwrap();
922        let result = conn.query("BAD SQL");
923        assert!(result.is_err());
924    }
925
926    #[test]
927    fn connect_is_valid_true() {
928        let port = spawn_cleartext_server();
929        let mut conn = Connect::new(mock_config(port)).unwrap();
930        assert!(conn.is_valid());
931    }
932
933    #[test]
934    fn connect_is_valid_false_after_close() {
935        let port = spawn_cleartext_server();
936        let mut conn = Connect::new(mock_config(port)).unwrap();
937        conn._close();
938        // After closing, is_valid should return false
939        assert!(!conn.is_valid());
940    }
941
942    #[test]
943    fn connect_close_does_not_panic() {
944        let port = spawn_cleartext_server();
945        let mut conn = Connect::new(mock_config(port)).unwrap();
946        conn._close();
947        // calling close again should not panic
948        conn._close();
949    }
950
951    #[test]
952    fn connect_drop_does_not_panic() {
953        let port = spawn_cleartext_server();
954        let conn = Connect::new(mock_config(port)).unwrap();
955        drop(conn);
956    }
957
958    fn spawn_transaction_status_server() -> u16 {
959        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
960        let port = listener.local_addr().unwrap().port();
961        thread::spawn(move || {
962            let (mut s, _) = listener.accept().unwrap();
963            let mut buf = [0u8; 4096];
964            let _ = s.read(&mut buf).unwrap();
965            let _ = s.write_all(&pg_auth(3, &[]));
966            let _ = s.read(&mut buf).unwrap();
967            let _ = s.write_all(&post_auth_ok());
968            loop {
969                match s.read(&mut buf) {
970                    Ok(0) | Err(_) => break,
971                    Ok(_) => {
972                        let mut r = Vec::new();
973                        r.extend(pg_msg(b'1', &[]));
974                        r.extend(pg_msg(b'2', &[]));
975                        let mut rd = Vec::new();
976                        rd.extend(&1u16.to_be_bytes());
977                        rd.extend(b"c\x00");
978                        rd.extend(&0u32.to_be_bytes());
979                        rd.extend(&1u16.to_be_bytes());
980                        rd.extend(&23u32.to_be_bytes());
981                        rd.extend(&4i16.to_be_bytes());
982                        rd.extend(&(-1i32).to_be_bytes());
983                        rd.extend(&0u16.to_be_bytes());
984                        r.extend(pg_msg(b'T', &rd));
985                        let mut dr = Vec::new();
986                        dr.extend(&1u16.to_be_bytes());
987                        dr.extend(&1u32.to_be_bytes());
988                        dr.push(b'1');
989                        r.extend(pg_msg(b'D', &dr));
990                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
991                        r.extend(pg_msg(b'Z', b"T"));
992                        let _ = s.write_all(&r);
993                    }
994                }
995            }
996        });
997        thread::sleep(Duration::from_millis(30));
998        port
999    }
1000
1001    fn spawn_error_status_server() -> u16 {
1002        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1003        let port = listener.local_addr().unwrap().port();
1004        thread::spawn(move || {
1005            let (mut s, _) = listener.accept().unwrap();
1006            let mut buf = [0u8; 4096];
1007            let _ = s.read(&mut buf).unwrap();
1008            let _ = s.write_all(&pg_auth(3, &[]));
1009            let _ = s.read(&mut buf).unwrap();
1010            let _ = s.write_all(&post_auth_ok());
1011            loop {
1012                match s.read(&mut buf) {
1013                    Ok(0) | Err(_) => break,
1014                    Ok(_) => {
1015                        let mut r = Vec::new();
1016                        r.extend(pg_msg(b'1', &[]));
1017                        r.extend(pg_msg(b'2', &[]));
1018                        let mut rd = Vec::new();
1019                        rd.extend(&1u16.to_be_bytes());
1020                        rd.extend(b"c\x00");
1021                        rd.extend(&0u32.to_be_bytes());
1022                        rd.extend(&1u16.to_be_bytes());
1023                        rd.extend(&23u32.to_be_bytes());
1024                        rd.extend(&4i16.to_be_bytes());
1025                        rd.extend(&(-1i32).to_be_bytes());
1026                        rd.extend(&0u16.to_be_bytes());
1027                        r.extend(pg_msg(b'T', &rd));
1028                        let mut dr = Vec::new();
1029                        dr.extend(&1u16.to_be_bytes());
1030                        dr.extend(&1u32.to_be_bytes());
1031                        dr.push(b'1');
1032                        r.extend(pg_msg(b'D', &dr));
1033                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
1034                        r.extend(pg_msg(b'Z', b"E"));
1035                        let _ = s.write_all(&r);
1036                    }
1037                }
1038            }
1039        });
1040        thread::sleep(Duration::from_millis(30));
1041        port
1042    }
1043
1044    fn spawn_slow_partial_server() -> u16 {
1045        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1046        let port = listener.local_addr().unwrap().port();
1047        thread::spawn(move || {
1048            let (mut s, _) = listener.accept().unwrap();
1049            let mut buf = [0u8; 4096];
1050            let _ = s.read(&mut buf).unwrap();
1051            let _ = s.write_all(&pg_auth(3, &[]));
1052            let _ = s.read(&mut buf).unwrap();
1053            let _ = s.write_all(&post_auth_ok());
1054            match s.read(&mut buf) {
1055                Ok(0) | Err(_) => {}
1056                Ok(_) => {
1057                    let _ = s.write_all(&simple_query_response());
1058                }
1059            }
1060        });
1061        thread::sleep(Duration::from_millis(30));
1062        port
1063    }
1064
1065    fn spawn_rst_on_query_server() -> u16 {
1066        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1067        let port = listener.local_addr().unwrap().port();
1068        thread::spawn(move || {
1069            let (mut s, _) = listener.accept().unwrap();
1070            let mut buf = [0u8; 4096];
1071            let _ = s.read(&mut buf).unwrap();
1072            let _ = s.write_all(&pg_auth(3, &[]));
1073            let _ = s.read(&mut buf).unwrap();
1074            let _ = s.write_all(&post_auth_ok());
1075            match s.read(&mut buf) {
1076                Ok(0) | Err(_) => {}
1077                Ok(_) => {
1078                    drop(s);
1079                }
1080            }
1081        });
1082        thread::sleep(Duration::from_millis(30));
1083        port
1084    }
1085
1086    #[test]
1087    fn connect_query_ready_for_query_transaction_status() {
1088        let port = spawn_transaction_status_server();
1089        let mut conn = Connect::new(mock_config(port)).unwrap();
1090        let result = conn.query("SELECT 1");
1091        assert!(result.is_ok());
1092    }
1093
1094    #[test]
1095    fn connect_query_ready_for_query_error_status() {
1096        let port = spawn_error_status_server();
1097        let mut conn = Connect::new(mock_config(port)).unwrap();
1098        let result = conn.query("SELECT 1");
1099        assert!(result.is_ok());
1100    }
1101
1102    #[test]
1103    fn connect_query_server_closes_after_partial() {
1104        let port = spawn_slow_partial_server();
1105        let mut conn = Connect::new(mock_config(port)).unwrap();
1106        let r1 = conn.query("SELECT 1");
1107        assert!(r1.is_ok());
1108        let r2 = conn.query("SELECT 1");
1109        assert!(r2.is_err());
1110    }
1111
1112    #[test]
1113    fn connect_query_server_rst_returns_io_or_connection_error() {
1114        let port = spawn_rst_on_query_server();
1115        let mut conn = Connect::new(mock_config(port)).unwrap();
1116        let result = conn.query("SELECT 1");
1117        assert!(result.is_err());
1118    }
1119
1120    #[test]
1121    fn connect_read_would_block_max_retries() {
1122        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1123        let port = listener.local_addr().unwrap().port();
1124        thread::spawn(move || {
1125            let (mut s, _) = listener.accept().unwrap();
1126            let mut buf = [0u8; 4096];
1127            let _ = s.read(&mut buf);
1128            let _ = s.write_all(&pg_auth(3, &[]));
1129            let _ = s.read(&mut buf);
1130            let _ = s.write_all(&post_auth_ok());
1131            let _ = s.read(&mut buf);
1132            thread::sleep(Duration::from_secs(5));
1133        });
1134        thread::sleep(Duration::from_millis(30));
1135
1136        let mut conn = Connect::new(mock_config(port)).unwrap();
1137        conn.stream
1138            .set_read_timeout(Some(Duration::from_millis(1)))
1139            .ok();
1140        let result = conn.query("SELECT 1");
1141        assert!(result.is_err());
1142        let err_str = result.unwrap_err().to_string();
1143        assert!(
1144            err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
1145            "expected timeout error, got: {err_str}"
1146        );
1147    }
1148
1149    #[test]
1150    fn connect_read_exceeds_max_message_size() {
1151        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1152        let port = listener.local_addr().unwrap().port();
1153        thread::spawn(move || {
1154            let (mut s, _) = listener.accept().unwrap();
1155            let mut buf = [0u8; 4096];
1156            let _ = s.read(&mut buf);
1157            let _ = s.write_all(&pg_auth(3, &[]));
1158            let _ = s.read(&mut buf);
1159            let _ = s.write_all(&post_auth_ok());
1160            let _ = s.read(&mut buf);
1161            let big = vec![b'X'; 256];
1162            let _ = s.write_all(&big);
1163            thread::sleep(Duration::from_secs(2));
1164        });
1165        thread::sleep(Duration::from_millis(30));
1166
1167        let mut conn = Connect::new(mock_config(port)).unwrap();
1168        let result = conn.query("SELECT 1");
1169        assert!(result.is_err());
1170        let err_str = result.unwrap_err().to_string();
1171        assert!(
1172            err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
1173            "expected max message size error, got: {err_str}"
1174        );
1175    }
1176
1177    #[test]
1178    fn connect_read_deadline_timeout() {
1179        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1180        let port = listener.local_addr().unwrap().port();
1181        thread::spawn(move || {
1182            let (mut s, _) = listener.accept().unwrap();
1183            let mut buf = [0u8; 4096];
1184            let _ = s.read(&mut buf);
1185            let _ = s.write_all(&pg_auth(3, &[]));
1186            let _ = s.read(&mut buf);
1187            let _ = s.write_all(&post_auth_ok());
1188            let _ = s.read(&mut buf);
1189            for _ in 0..200 {
1190                let _ = s.write_all(b"X");
1191                thread::sleep(Duration::from_millis(5));
1192            }
1193        });
1194        thread::sleep(Duration::from_millis(30));
1195
1196        let mut conn = Connect::new(mock_config(port)).unwrap();
1197        let result = conn.query("SELECT 1");
1198        assert!(result.is_err());
1199    }
1200
1201    #[test]
1202    fn connect_read_partial_auth_frame() {
1203        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1204        let port = listener.local_addr().unwrap().port();
1205        thread::spawn(move || {
1206            let (mut s, _) = listener.accept().unwrap();
1207            let mut buf = [0u8; 4096];
1208            let _ = s.read(&mut buf);
1209            let auth = pg_auth(3, &[]);
1210            let _ = s.write_all(&auth[..5]);
1211            thread::sleep(Duration::from_millis(50));
1212            let _ = s.write_all(&auth[5..]);
1213            let _ = s.read(&mut buf);
1214            let _ = s.write_all(&post_auth_ok());
1215            loop {
1216                match s.read(&mut buf) {
1217                    Ok(0) | Err(_) => break,
1218                    Ok(_) => {
1219                        let _ = s.write_all(&simple_query_response());
1220                    }
1221                }
1222            }
1223        });
1224        thread::sleep(Duration::from_millis(30));
1225
1226        let mut conn = Connect::new(mock_config(port)).unwrap();
1227        let result = conn.query("SELECT 1");
1228        assert!(result.is_ok());
1229    }
1230}