1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::sync::Arc;
7use std::time::Duration;
8
9#[derive(Clone, Debug)]
10pub struct Connect {
11 pub(crate) stream: Arc<TcpStream>,
12 packet: Packet,
13 auth_status: AuthStatus,
14}
15
16impl Connect {
17 pub fn is_valid(&mut self) -> bool {
18 self.query("SELECT 1").is_ok()
19 }
20
21 pub fn _close(&mut self) {
22 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
23 let _ = self.stream.shutdown(std::net::Shutdown::Both);
24 }
25
26 pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
27 let stream =
28 TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
29
30 stream
31 .set_read_timeout(Some(Duration::from_secs(30)))
32 .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
33 stream
34 .set_write_timeout(Some(Duration::from_secs(30)))
35 .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
36
37 let _ = stream.peer_addr();
38
39 let mut connect = Self {
40 stream: Arc::new(stream),
41 packet: Packet::new(config),
42 auth_status: AuthStatus::None,
43 };
44
45 connect.authenticate()?;
46
47 Ok(connect)
48 }
49
50 fn authenticate(&mut self) -> Result<(), PgsqlError> {
51 self.stream
52 .as_ref()
53 .write_all(&self.packet.pack_first())
54 .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
55
56 let data = self.read()?;
57 self.packet.unpack(data, 0)?;
58
59 if !self.packet.md5_salt.is_empty() {
60 self.md5_auth()?;
61 } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
62 self.cleartext_auth()?;
63 } else {
64 self.scram_auth()?;
65 }
66
67 self.auth_status = AuthStatus::AuthenticationOk;
68 Ok(())
69 }
70
71 fn md5_auth(&mut self) -> Result<(), PgsqlError> {
72 self.stream
73 .as_ref()
74 .write_all(&self.packet.pack_md5_password())
75 .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
76
77 let data = self.read()?;
78 self.packet.unpack(data, 0)?;
79 Ok(())
80 }
81
82 fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
83 self.stream
84 .as_ref()
85 .write_all(&self.packet.pack_cleartext_password())
86 .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
87
88 let data = self.read()?;
89 self.packet.unpack(data, 0)?;
90 Ok(())
91 }
92
93 fn scram_auth(&mut self) -> Result<(), PgsqlError> {
94 self.stream
95 .as_ref()
96 .write_all(&self.packet.pack_auth())
97 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
98
99 let data = self.read()?;
100 self.packet.unpack(data, 0)?;
101
102 self.stream
103 .as_ref()
104 .write_all(&self.packet.pack_auth_verify())
105 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
106
107 let data = self.read()?;
108 self.packet.unpack(data, 0)?;
109 Ok(())
110 }
111
112 fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
113 let mut msg = Vec::new();
114 let mut buf = [0u8; 4096];
115 let mut retry_count = 0;
116
117 #[cfg(not(test))]
118 const MAX_RETRIES: u32 = 100;
119 #[cfg(test)]
120 const MAX_RETRIES: u32 = 3;
121
122 #[cfg(not(test))]
123 const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
124 #[cfg(test)]
125 const MAX_MESSAGE_SIZE: usize = 128;
126
127 #[cfg(not(test))]
128 let deadline = std::time::Instant::now() + Duration::from_secs(300);
129 #[cfg(test)]
130 let deadline = std::time::Instant::now() + Duration::from_millis(200);
131
132 loop {
133 if std::time::Instant::now() >= deadline {
134 return Err(PgsqlError::Timeout("读取总超时".into()));
135 }
136
137 match self.stream.as_ref().read(&mut buf) {
138 Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
139 Ok(n) => {
140 if msg.len() + n > MAX_MESSAGE_SIZE {
141 return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
142 }
143 msg.extend_from_slice(&buf[..n]);
144 retry_count = 0;
145 }
146 Err(ref e)
147 if e.kind() == std::io::ErrorKind::WouldBlock
148 || e.kind() == std::io::ErrorKind::TimedOut =>
149 {
150 retry_count += 1;
151 if retry_count > MAX_RETRIES {
152 return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
153 }
154 std::thread::sleep(Duration::from_millis(10));
155 continue;
156 }
157 Err(e) => return Err(PgsqlError::Io(e)),
158 };
159
160 if let AuthStatus::AuthenticationOk = self.auth_status {
161 if msg.ends_with(&[90, 0, 0, 0, 5, 73])
162 || msg.ends_with(&[90, 0, 0, 0, 5, 84])
163 || msg.ends_with(&[90, 0, 0, 0, 5, 69])
164 {
165 break;
166 }
167 } else if msg.len() >= 5 {
168 let len_bytes = &msg[1..=4];
169 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
170 if msg.len() > len as usize {
171 break;
172 }
173 }
174 }
175 }
176
177 Ok(msg)
178 }
179
180 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
181 self.stream
182 .as_ref()
183 .write_all(&self.packet.pack_query(sql))
184 .map_err(PgsqlError::Io)?;
185
186 let data = self.read()?;
187
188 self.packet.unpack(data, 0)
189 }
190
191 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
192 self.stream
193 .as_ref()
194 .write_all(&self.packet.pack_execute(sql))
195 .map_err(PgsqlError::Io)?;
196 let data = self.read()?;
197 self.packet.unpack(data, 0)
198 }
199}
200
201impl Drop for Connect {
202 fn drop(&mut self) {
203 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
204 let _ = self.stream.shutdown(std::net::Shutdown::Both);
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::thread;

    // ----- Backend message builders -------------------------------------

    /// Frames `payload` as a backend message: a tag byte, then a big-endian
    /// length that counts itself (4 bytes) plus the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Authentication message (`R`) with code `auth_type` and extra payload.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// AuthenticationOk (`R`, code 0).
    fn auth_ok() -> Vec<u8> {
        pg_auth(0, &[])
    }

    /// ParameterStatus (`S`) reporting `server_version = 15.0`.
    fn param_status() -> Vec<u8> {
        pg_msg(b'S', b"server_version\x0015.0\x00")
    }

    /// BackendKeyData (`K`) with process id 1 and secret key 2.
    fn backend_key() -> Vec<u8> {
        let mut p = Vec::new();
        p.extend(&1u32.to_be_bytes());
        p.extend(&2u32.to_be_bytes());
        pg_msg(b'K', &p)
    }

    /// ReadyForQuery (`Z`) with the given transaction status byte.
    fn ready_for_query_with(status: u8) -> Vec<u8> {
        pg_msg(b'Z', &[status])
    }

    /// ReadyForQuery (`Z`) with status `I` (idle).
    fn ready_for_query() -> Vec<u8> {
        ready_for_query_with(b'I')
    }

    /// Everything the mock server sends once authentication succeeds.
    fn post_auth_ok() -> Vec<u8> {
        [auth_ok(), param_status(), backend_key(), ready_for_query()].concat()
    }

    /// RowDescription (`T`) for one text-format int4 column named `column`.
    fn int4_row_description(column: &str) -> Vec<u8> {
        let mut rd: Vec<u8> = Vec::new();
        rd.extend(&1u16.to_be_bytes()); // field count
        rd.extend_from_slice(column.as_bytes()); // field name ...
        rd.push(0); // ... NUL-terminated
        rd.extend(&0u32.to_be_bytes()); // table OID
        rd.extend(&1u16.to_be_bytes()); // column attribute number
        rd.extend(&23u32.to_be_bytes()); // type OID 23 = int4
        rd.extend(&4i16.to_be_bytes()); // type size
        rd.extend(&(-1i32).to_be_bytes()); // type modifier
        rd.extend(&0u16.to_be_bytes()); // format code 0 = text
        pg_msg(b'T', &rd)
    }

    /// DataRow (`D`) holding a single non-null value.
    fn single_value_data_row(value: &[u8]) -> Vec<u8> {
        let mut dr: Vec<u8> = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&(value.len() as u32).to_be_bytes());
        dr.extend_from_slice(value);
        pg_msg(b'D', &dr)
    }

    /// Reply to `SELECT 1` (column `c` = 1) ending in ReadyForQuery(`status`).
    fn select_one_response(status: u8) -> Vec<u8> {
        [
            pg_msg(b'1', &[]), // ParseComplete
            pg_msg(b'2', &[]), // BindComplete
            int4_row_description("c"),
            single_value_data_row(b"1"),
            pg_msg(b'C', b"SELECT 1\x00"),
            ready_for_query_with(status),
        ]
        .concat()
    }

    /// Reply to `SELECT 1` with an idle ReadyForQuery.
    fn simple_query_response() -> Vec<u8> {
        select_one_response(b'I')
    }

    /// Reply to an `UPDATE` that touched three rows.
    fn execute_response() -> Vec<u8> {
        [
            pg_msg(b'1', &[]),
            pg_msg(b'2', &[]),
            pg_msg(b'n', &[]), // NoData
            pg_msg(b'C', b"UPDATE 3\x00"),
            ready_for_query(),
        ]
        .concat()
    }

    /// ErrorResponse (`E`) with SQLSTATE `code` and text `message`.
    fn error_message(code: &str, message: &str) -> Vec<u8> {
        let mut payload = vec![b'C'];
        payload.extend_from_slice(code.as_bytes());
        payload.extend_from_slice(b"\0M");
        payload.extend_from_slice(message.as_bytes());
        payload.extend_from_slice(b"\0\0"); // end of message text, end of fields
        pg_msg(b'E', &payload)
    }

    /// Reply to a query that fails with a syntax error.
    fn error_response() -> Vec<u8> {
        [
            pg_msg(b'1', &[]),
            pg_msg(b'2', &[]),
            error_message("42601", "syntax error"),
            ready_for_query(),
        ]
        .concat()
    }

    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    // ----- Mock server plumbing -----------------------------------------
    //
    // The mocks treat one `read` as one client message, which holds for the
    // small messages exchanged over loopback here.

    /// Binds an ephemeral loopback port and hands its first accepted
    /// connection to `handler` on a background thread.
    ///
    /// No start-up sleep is needed: the listener is bound before this
    /// returns, so the client's connect completes even before `accept` runs.
    fn spawn_server<F>(handler: F) -> u16
    where
        F: FnOnce(TcpStream) + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (s, _) = listener.accept().unwrap();
            handler(s);
        });
        port
    }

    /// Server side of cleartext authentication: swallow the startup message,
    /// request a cleartext password, swallow it, then send `post_auth_ok()`.
    fn cleartext_handshake(s: &mut TcpStream, buf: &mut [u8]) {
        let _ = s.read(buf);
        let _ = s.write_all(&pg_auth(3, &[]));
        let _ = s.read(buf);
        let _ = s.write_all(&post_auth_ok());
    }

    /// Answers every client message with `response` until the client leaves.
    fn reply_to_each(s: &mut TcpStream, buf: &mut [u8], response: &[u8]) {
        while let Ok(n) = s.read(buf) {
            if n == 0 {
                break;
            }
            let _ = s.write_all(response);
        }
    }

    /// Cleartext-auth server that answers every later message with `response`.
    fn spawn_cleartext_server_replying(response: Vec<u8>) -> u16 {
        spawn_server(move |mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            reply_to_each(&mut s, &mut buf, &response);
        })
    }

    fn spawn_cleartext_server() -> u16 {
        spawn_cleartext_server_replying(simple_query_response())
    }

    fn spawn_md5_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&post_auth_ok());
            reply_to_each(&mut s, &mut buf, &simple_query_response());
        })
    }

    fn spawn_scram_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
            // Echo the client nonce back with a server suffix appended.
            let n = s.read(&mut buf).unwrap();
            let text = String::from_utf8_lossy(&buf[..n]);
            let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
            let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
            let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
            let _ = s.read(&mut buf).unwrap();
            let finale = [pg_auth(12, b"v=dummyproof"), post_auth_ok()].concat();
            let _ = s.write_all(&finale);
            reply_to_each(&mut s, &mut buf, &simple_query_response());
        })
    }

    fn spawn_eof_server() -> u16 {
        // Accept the connection and close it immediately.
        spawn_server(drop)
    }

    fn spawn_auth_error_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&error_message("28P01", "password authentication failed"));
        })
    }

    fn spawn_query_error_server() -> u16 {
        spawn_cleartext_server_replying(error_response())
    }

    fn spawn_execute_server() -> u16 {
        spawn_cleartext_server_replying(execute_response())
    }

    fn spawn_transaction_status_server() -> u16 {
        spawn_cleartext_server_replying(select_one_response(b'T'))
    }

    fn spawn_error_status_server() -> u16 {
        spawn_cleartext_server_replying(select_one_response(b'E'))
    }

    fn spawn_slow_partial_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            // Answer exactly one query, then hang up.
            if matches!(s.read(&mut buf), Ok(n) if n > 0) {
                let _ = s.write_all(&simple_query_response());
            }
        })
    }

    fn spawn_rst_on_query_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            // Wait for the query, then drop the socket without replying.
            let _ = s.read(&mut buf);
        })
    }

    // ----- Tests -------------------------------------------------------------

    #[test]
    fn connect_cleartext_auth_success() {
        let port = spawn_cleartext_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_md5_auth_success() {
        let port = spawn_md5_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_scram_auth_success() {
        let port = spawn_scram_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_connection_refused() {
        let cfg = mock_config(1);
        let result = Connect::new(cfg);
        assert!(result.is_err());
        match result.unwrap_err() {
            PgsqlError::Connection(_) => {}
            other => panic!("expected Connection error, got {other:?}"),
        }
    }

    #[test]
    fn connect_server_closes_immediately() {
        let port = spawn_eof_server();
        let result = Connect::new(mock_config(port));
        assert!(result.is_err());
    }

    #[test]
    fn connect_auth_error_from_server() {
        let port = spawn_auth_error_server();
        let result = Connect::new(mock_config(port));
        assert!(result.is_err());
    }

    #[test]
    fn connect_query_success() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.rows.len(), 1);
        assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
    }

    #[test]
    fn connect_execute_success() {
        let port = spawn_execute_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.execute("UPDATE t SET x=1");
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.affect_count, 3);
        assert_eq!(msg.tag, "UPDATE 3");
    }

    #[test]
    fn connect_query_returns_error() {
        let port = spawn_query_error_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("BAD SQL");
        assert!(result.is_err());
    }

    #[test]
    fn connect_is_valid_true() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.is_valid());
    }

    #[test]
    fn connect_is_valid_false_after_close() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn._close();
        assert!(!conn.is_valid());
    }

    #[test]
    fn connect_close_does_not_panic() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn._close();
        conn._close();
    }

    #[test]
    fn connect_drop_does_not_panic() {
        let port = spawn_cleartext_server();
        let conn = Connect::new(mock_config(port)).unwrap();
        drop(conn);
    }

    #[test]
    fn connect_query_ready_for_query_transaction_status() {
        let port = spawn_transaction_status_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }

    #[test]
    fn connect_query_ready_for_query_error_status() {
        let port = spawn_error_status_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }

    #[test]
    fn connect_query_server_closes_after_partial() {
        let port = spawn_slow_partial_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let r1 = conn.query("SELECT 1");
        assert!(r1.is_ok());
        let r2 = conn.query("SELECT 1");
        assert!(r2.is_err());
    }

    #[test]
    fn connect_query_server_rst_returns_io_or_connection_error() {
        let port = spawn_rst_on_query_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
    }

    #[test]
    fn connect_read_would_block_max_retries() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            let _ = s.read(&mut buf);
            // Never answer; stay open longer than the client's retry budget.
            thread::sleep(Duration::from_secs(5));
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn.stream
            .set_read_timeout(Some(Duration::from_millis(1)))
            .ok();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(
            err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
            "expected timeout error, got: {err_str}"
        );
    }

    #[test]
    fn connect_read_exceeds_max_message_size() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            let _ = s.read(&mut buf);
            // More bytes than the test build's 128-byte message limit.
            let big = vec![b'X'; 256];
            let _ = s.write_all(&big);
            thread::sleep(Duration::from_secs(2));
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(
            err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
            "expected max message size error, got: {err_str}"
        );
    }

    #[test]
    fn connect_read_deadline_timeout() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            let _ = s.read(&mut buf);
            // Trickle garbage so the reply never completes.
            for _ in 0..200 {
                let _ = s.write_all(b"X");
                thread::sleep(Duration::from_millis(5));
            }
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
    }

    #[test]
    fn connect_read_partial_auth_frame() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            // Deliver the password request in two writes, split inside the frame.
            let auth = pg_auth(3, &[]);
            let _ = s.write_all(&auth[..5]);
            thread::sleep(Duration::from_millis(50));
            let _ = s.write_all(&auth[5..]);
            let _ = s.read(&mut buf);
            let _ = s.write_all(&post_auth_ok());
            reply_to_each(&mut s, &mut buf, &simple_query_response());
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }
}