Skip to main content

br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::sync::Arc;
7use std::time::Duration;
8
9#[derive(Clone, Debug)]
10pub struct Connect {
11    pub(crate) stream: Arc<TcpStream>,
12    packet: Packet,
13    auth_status: AuthStatus,
14}
15
16impl Connect {
17    pub fn is_valid(&mut self) -> bool {
18        self.query("SELECT 1").is_ok()
19    }
20
21    pub fn _close(&mut self) {
22        let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
23        let _ = self.stream.shutdown(std::net::Shutdown::Both);
24    }
25
26    pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
27        let stream =
28            TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
29
30        stream
31            .set_read_timeout(Some(Duration::from_secs(30)))
32            .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
33        stream
34            .set_write_timeout(Some(Duration::from_secs(30)))
35            .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
36
37        let _ = stream.peer_addr();
38
39        let mut connect = Self {
40            stream: Arc::new(stream),
41            packet: Packet::new(config),
42            auth_status: AuthStatus::None,
43        };
44
45        connect.authenticate()?;
46
47        Ok(connect)
48    }
49
50    fn authenticate(&mut self) -> Result<(), PgsqlError> {
51        self.stream
52            .as_ref()
53            .write_all(&self.packet.pack_first())
54            .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
55
56        let data = self.read()?;
57        self.packet.unpack(data, 0)?;
58
59        if !self.packet.md5_salt.is_empty() {
60            self.md5_auth()?;
61        } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
62            self.cleartext_auth()?;
63        } else {
64            self.scram_auth()?;
65        }
66
67        self.auth_status = AuthStatus::AuthenticationOk;
68        Ok(())
69    }
70
71    fn md5_auth(&mut self) -> Result<(), PgsqlError> {
72        self.stream
73            .as_ref()
74            .write_all(&self.packet.pack_md5_password())
75            .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
76
77        let data = self.read()?;
78        self.packet.unpack(data, 0)?;
79        Ok(())
80    }
81
82    fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
83        self.stream
84            .as_ref()
85            .write_all(&self.packet.pack_cleartext_password())
86            .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
87
88        let data = self.read()?;
89        self.packet.unpack(data, 0)?;
90        Ok(())
91    }
92
93    fn scram_auth(&mut self) -> Result<(), PgsqlError> {
94        self.stream
95            .as_ref()
96            .write_all(&self.packet.pack_auth())
97            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
98
99        let data = self.read()?;
100        self.packet.unpack(data, 0)?;
101
102        self.stream
103            .as_ref()
104            .write_all(&self.packet.pack_auth_verify())
105            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
106
107        let data = self.read()?;
108        self.packet.unpack(data, 0)?;
109        Ok(())
110    }
111
112    fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
113        let mut msg = Vec::new();
114        let mut buf = [0u8; 4096];
115        let mut retry_count = 0;
116
117        #[cfg(not(test))]
118        const MAX_RETRIES: u32 = 100;
119        #[cfg(test)]
120        const MAX_RETRIES: u32 = 3;
121
122        #[cfg(not(test))]
123        const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
124        #[cfg(test)]
125        const MAX_MESSAGE_SIZE: usize = 128;
126
127        #[cfg(not(test))]
128        let deadline = std::time::Instant::now() + Duration::from_secs(300);
129        #[cfg(test)]
130        let deadline = std::time::Instant::now() + Duration::from_millis(200);
131
132        loop {
133            if std::time::Instant::now() >= deadline {
134                return Err(PgsqlError::Timeout("读取总超时".into()));
135            }
136
137            match self.stream.as_ref().read(&mut buf) {
138                Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
139                Ok(n) => {
140                    if msg.len() + n > MAX_MESSAGE_SIZE {
141                        return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
142                    }
143                    msg.extend_from_slice(&buf[..n]);
144                    retry_count = 0;
145                }
146                Err(ref e)
147                    if e.kind() == std::io::ErrorKind::WouldBlock
148                        || e.kind() == std::io::ErrorKind::TimedOut =>
149                {
150                    retry_count += 1;
151                    if retry_count > MAX_RETRIES {
152                        return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
153                    }
154                    std::thread::sleep(Duration::from_millis(10));
155                    continue;
156                }
157                Err(e) => return Err(PgsqlError::Io(e)),
158            };
159
160            if let AuthStatus::AuthenticationOk = self.auth_status {
161                if msg.ends_with(&[90, 0, 0, 0, 5, 73])
162                    || msg.ends_with(&[90, 0, 0, 0, 5, 84])
163                    || msg.ends_with(&[90, 0, 0, 0, 5, 69])
164                {
165                    break;
166                }
167            } else if msg.len() >= 5 {
168                let len_bytes = &msg[1..=4];
169                if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
170                    if msg.len() > len as usize {
171                        break;
172                    }
173                }
174            }
175        }
176
177        Ok(msg)
178    }
179
180    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
181        self.stream
182            .as_ref()
183            .write_all(&self.packet.pack_query(sql))
184            .map_err(PgsqlError::Io)?;
185
186        let data = self.read()?;
187
188        self.packet.unpack(data, 0)
189    }
190
191    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
192        self.stream
193            .as_ref()
194            .write_all(&self.packet.pack_execute(sql))
195            .map_err(PgsqlError::Io)?;
196        let data = self.read()?;
197        self.packet.unpack(data, 0)
198    }
199}
200
201impl Drop for Connect {
202    fn drop(&mut self) {
203        let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
204        let _ = self.stream.shutdown(std::net::Shutdown::Both);
205    }
206}
207
#[cfg(test)]
mod tests {
    //! `Connect` tests against in-process mock PostgreSQL servers.
    //!
    //! Every mock accepts exactly one connection on an ephemeral localhost port
    //! and speaks just enough of the wire protocol for its scenario.
    //! `TcpStream`, `Read`, `Write` and `Duration` come from the parent module via `super::*`.

    use super::*;
    use std::net::TcpListener;
    use std::thread;

    // ── wire-protocol helpers ──────────────────────────────────────────

    /// Build a single PG backend message: `tag | len (u32 BE, includes itself) | payload`.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Build an Authentication message (tag 'R') with the given auth type and extra payload.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = auth_type.to_be_bytes().to_vec();
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// AuthenticationOk (R, type=0).
    fn auth_ok() -> Vec<u8> {
        pg_auth(0, &[])
    }

    /// ParameterStatus for server_version=15.0.
    fn param_status() -> Vec<u8> {
        pg_msg(b'S', b"server_version\x0015.0\x00")
    }

    /// BackendKeyData (process_id=1, secret_key=2).
    fn backend_key() -> Vec<u8> {
        [1u32.to_be_bytes(), 2u32.to_be_bytes()]
            .concat()
            .as_slice()
            .pipe_msg(b'K')
    }

    /// ReadyForQuery carrying the given transaction status (`I`, `T` or `E`).
    fn ready_for_query_with(status: u8) -> Vec<u8> {
        pg_msg(b'Z', &[status])
    }

    /// ReadyForQuery (status = Idle).
    fn ready_for_query() -> Vec<u8> {
        ready_for_query_with(b'I')
    }

    /// The standard tail sent after auth succeeds: AuthOk + param + key + ready.
    fn post_auth_ok() -> Vec<u8> {
        [auth_ok(), param_status(), backend_key(), ready_for_query()].concat()
    }

    /// ErrorResponse with SQLSTATE `code` and message `text`.
    fn error_frame(code: &[u8], text: &[u8]) -> Vec<u8> {
        let mut payload = vec![b'C'];
        payload.extend_from_slice(code);
        payload.push(0);
        payload.push(b'M');
        payload.extend_from_slice(text);
        payload.push(0);
        payload.push(0); // field-list terminator
        pg_msg(b'E', &payload)
    }

    /// ParseComplete + BindComplete + RowDescription(1 int4 col "c") + DataRow("1")
    /// + CommandComplete("SELECT 1") + ReadyForQuery(`status`).
    fn select_one_response(status: u8) -> Vec<u8> {
        // RowDescription – 1 field "c", type_oid=23 (int4)
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes()); // field count
        rd.extend(b"c\x00"); // name
        rd.extend(&0u32.to_be_bytes()); // table oid
        rd.extend(&1u16.to_be_bytes()); // column index
        rd.extend(&23u32.to_be_bytes()); // type oid (int4)
        rd.extend(&4i16.to_be_bytes()); // column length
        rd.extend(&(-1i32).to_be_bytes()); // type modifier
        rd.extend(&0u16.to_be_bytes()); // format (text)

        // DataRow – 1 field, value "1"
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes()); // length of value
        dr.push(b'1');

        [
            pg_msg(b'1', &[]), // ParseComplete
            pg_msg(b'2', &[]), // BindComplete
            pg_msg(b'T', &rd),
            pg_msg(b'D', &dr),
            pg_msg(b'C', b"SELECT 1\x00"),
            ready_for_query_with(status),
        ]
        .concat()
    }

    /// `SELECT 1` response ending in an idle ReadyForQuery.
    fn simple_query_response() -> Vec<u8> {
        select_one_response(b'I')
    }

    /// Execute response (no rows): ParseComplete + BindComplete + NoData +
    /// CommandComplete("UPDATE 3") + ReadyForQuery.
    fn execute_response() -> Vec<u8> {
        [
            pg_msg(b'1', &[]),
            pg_msg(b'2', &[]),
            pg_msg(b'n', &[]), // NoData
            pg_msg(b'C', b"UPDATE 3\x00"),
            ready_for_query(),
        ]
        .concat()
    }

    /// Query-phase failure: ParseComplete + BindComplete + ErrorResponse + ReadyForQuery.
    fn error_response() -> Vec<u8> {
        [
            pg_msg(b'1', &[]),
            pg_msg(b'2', &[]),
            error_frame(b"42601", b"syntax error"),
            ready_for_query(),
        ]
        .concat()
    }

    /// Lets `backend_key` read as "payload, then wrap" without a temporary binding.
    trait PipeMsg {
        fn pipe_msg(self, tag: u8) -> Vec<u8>;
    }

    impl PipeMsg for &[u8] {
        fn pipe_msg(self, tag: u8) -> Vec<u8> {
            pg_msg(tag, self)
        }
    }

    // ── mock server plumbing ───────────────────────────────────────────

    /// Config pointing at 127.0.0.1:<port>.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    /// Bind an ephemeral localhost port and serve exactly one connection with
    /// `serve` on a background thread. Returns the port.
    ///
    /// The listener is bound before this returns, so a client may connect at once
    /// (the kernel queues the connection); no start-up sleep is needed.
    fn spawn_server<F>(serve: F) -> u16
    where
        F: FnOnce(TcpStream) + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                serve(stream);
            }
        });
        port
    }

    /// Server side of a cleartext-password handshake.
    ///
    /// Reads the startup message and requests a cleartext password
    /// (auth_type=3). It then reads the password and sends AuthOk, params, key and
    /// ready in one write.
    fn cleartext_handshake(s: &mut TcpStream, buf: &mut [u8]) {
        let _ = s.read(buf);
        let _ = s.write_all(&pg_auth(3, &[]));
        let _ = s.read(buf);
        let _ = s.write_all(&post_auth_ok());
    }

    /// Answer every client request with `response()` until the client disconnects.
    fn answer_until_eof(s: &mut TcpStream, buf: &mut [u8], response: fn() -> Vec<u8>) {
        loop {
            match s.read(buf) {
                Ok(0) | Err(_) => break,
                Ok(_) => {
                    let _ = s.write_all(&response());
                }
            }
        }
    }

    /// Mock server: cleartext auth, then `response()` for every query.
    fn spawn_cleartext_responder(response: fn() -> Vec<u8>) -> u16 {
        spawn_server(move |mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            answer_until_eof(&mut s, &mut buf, response);
        })
    }

    // ── mock server spawners ───────────────────────────────────────────

    /// Mock PG server doing **cleartext** auth, then answering `SELECT 1`.
    fn spawn_cleartext_server() -> u16 {
        spawn_cleartext_responder(simple_query_response)
    }

    /// Mock PG server doing **MD5** auth.
    fn spawn_md5_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            // MD5Password request (auth_type=5) + 4-byte salt
            let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
            let _ = s.read(&mut buf);
            let _ = s.write_all(&post_auth_ok());
            answer_until_eof(&mut s, &mut buf, simple_query_response);
        })
    }

    /// Mock PG server doing **SCRAM-SHA-256** auth.
    fn spawn_scram_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            // SASL auth (auth_type=10) advertising one mechanism
            let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
            // SASLInitialResponse: echo the client nonce ("...r=<nonce>") in the challenge
            let n = s.read(&mut buf).unwrap_or(0);
            let text = String::from_utf8_lossy(&buf[..n]);
            let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
            let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
            let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
            // SCRAM client final, answered with SASLFinal (12) + the usual post-auth tail
            let _ = s.read(&mut buf);
            let _ = s.write_all(&[pg_auth(12, b"v=dummyproof"), post_auth_ok()].concat());
            answer_until_eof(&mut s, &mut buf, simple_query_response);
        })
    }

    /// Server that accepts the connection and closes it immediately (EOF).
    fn spawn_eof_server() -> u16 {
        spawn_server(drop)
    }

    /// Server that answers the startup message with an ErrorResponse, then closes.
    fn spawn_auth_error_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            let _ = s.write_all(&error_frame(b"28P01", b"password authentication failed"));
        })
    }

    /// Cleartext server that answers every query with an ErrorResponse.
    fn spawn_query_error_server() -> u16 {
        spawn_cleartext_responder(error_response)
    }

    /// Cleartext server that answers every query with an execute response.
    fn spawn_execute_server() -> u16 {
        spawn_cleartext_responder(execute_response)
    }

    /// Cleartext server whose ReadyForQuery reports "in transaction" (`T`).
    fn spawn_transaction_status_server() -> u16 {
        spawn_cleartext_responder(|| select_one_response(b'T'))
    }

    /// Cleartext server whose ReadyForQuery reports "failed transaction" (`E`).
    fn spawn_error_status_server() -> u16 {
        spawn_cleartext_responder(|| select_one_response(b'E'))
    }

    /// Cleartext server that answers exactly one query, then closes.
    fn spawn_slow_partial_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            if matches!(s.read(&mut buf), Ok(n) if n > 0) {
                let _ = s.write_all(&simple_query_response());
            }
        })
    }

    /// Cleartext server that reads the first query and closes without replying.
    fn spawn_rst_on_query_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            let _ = s.read(&mut buf);
        })
    }

    // ── tests ──────────────────────────────────────────────────────────

    #[test]
    fn connect_cleartext_auth_success() {
        let port = spawn_cleartext_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_md5_auth_success() {
        let port = spawn_md5_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_scram_auth_success() {
        let port = spawn_scram_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_connection_refused() {
        // port 1 is almost certainly not listening
        let result = Connect::new(mock_config(1));
        match result {
            Err(PgsqlError::Connection(_)) => {}
            Err(other) => panic!("expected Connection error, got {other:?}"),
            Ok(_) => panic!("expected Connection error, got Ok"),
        }
    }

    #[test]
    fn connect_server_closes_immediately() {
        let port = spawn_eof_server();
        assert!(Connect::new(mock_config(port)).is_err());
    }

    #[test]
    fn connect_auth_error_from_server() {
        let port = spawn_auth_error_server();
        assert!(Connect::new(mock_config(port)).is_err());
    }

    #[test]
    fn connect_query_success() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let msg = conn.query("SELECT 1").expect("query should succeed");
        assert_eq!(msg.rows.len(), 1);
        assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
    }

    #[test]
    fn connect_execute_success() {
        let port = spawn_execute_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let msg = conn.execute("UPDATE t SET x=1").expect("execute should succeed");
        assert_eq!(msg.affect_count, 3);
        assert_eq!(msg.tag, "UPDATE 3");
    }

    #[test]
    fn connect_query_returns_error() {
        let port = spawn_query_error_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.query("BAD SQL").is_err());
    }

    #[test]
    fn connect_is_valid_true() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.is_valid());
    }

    #[test]
    fn connect_is_valid_false_after_close() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn._close();
        // After closing, is_valid should return false
        assert!(!conn.is_valid());
    }

    #[test]
    fn connect_close_does_not_panic() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn._close();
        // calling close again should not panic
        conn._close();
    }

    #[test]
    fn connect_drop_does_not_panic() {
        let port = spawn_cleartext_server();
        let conn = Connect::new(mock_config(port)).unwrap();
        drop(conn);
    }

    #[test]
    fn connect_query_ready_for_query_transaction_status() {
        let port = spawn_transaction_status_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.query("SELECT 1").is_ok());
    }

    #[test]
    fn connect_query_ready_for_query_error_status() {
        let port = spawn_error_status_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.query("SELECT 1").is_ok());
    }

    #[test]
    fn connect_query_server_closes_after_partial() {
        let port = spawn_slow_partial_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.query("SELECT 1").is_ok());
        assert!(conn.query("SELECT 1").is_err());
    }

    #[test]
    fn connect_query_server_rst_returns_io_or_connection_error() {
        let port = spawn_rst_on_query_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.query("SELECT 1").is_err());
    }

    #[test]
    fn connect_read_would_block_max_retries() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            let _ = s.read(&mut buf);
            // Never answer the query; keep the socket open past the client's retries.
            thread::sleep(Duration::from_secs(5));
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn.stream
            .set_read_timeout(Some(Duration::from_millis(1)))
            .ok();
        let err_str = conn.query("SELECT 1").unwrap_err().to_string();
        assert!(
            err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
            "expected timeout error, got: {err_str}"
        );
    }

    #[test]
    fn connect_read_exceeds_max_message_size() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            let _ = s.read(&mut buf);
            // 256 bytes exceeds the 128-byte test-build message limit
            let _ = s.write_all(&[b'X'; 256]);
            thread::sleep(Duration::from_secs(2));
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let err_str = conn.query("SELECT 1").unwrap_err().to_string();
        assert!(
            err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
            "expected max message size error, got: {err_str}"
        );
    }

    #[test]
    fn connect_read_deadline_timeout() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            let _ = s.read(&mut buf);
            // Trickle junk slowly so the overall read deadline expires first.
            for _ in 0..200 {
                let _ = s.write_all(b"X");
                thread::sleep(Duration::from_millis(5));
            }
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.query("SELECT 1").is_err());
    }

    #[test]
    fn connect_read_partial_auth_frame() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            // Split the auth request after its 5-byte header.
            let auth = pg_auth(3, &[]);
            let _ = s.write_all(&auth[..5]);
            thread::sleep(Duration::from_millis(50));
            let _ = s.write_all(&auth[5..]);
            let _ = s.read(&mut buf);
            let _ = s.write_all(&post_auth_ok());
            answer_until_eof(&mut s, &mut buf, simple_query_response);
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.query("SELECT 1").is_ok());
    }
}