1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::sync::Arc;
7use std::time::Duration;
8
/// A single client connection to a PostgreSQL server.
///
/// Cloning produces another handle over the same `Arc<TcpStream>`; the
/// underlying socket is shared, not duplicated.
#[derive(Clone, Debug)]
pub struct Connect {
    /// Shared handle to the TCP socket used for all reads and writes.
    pub(crate) stream: Arc<TcpStream>,
    /// Protocol encoder/decoder built from the `Config` passed to `new`.
    packet: Packet,
    /// Authentication progress; set to `AuthStatus::AuthenticationOk` once
    /// `authenticate` has finished successfully.
    auth_status: AuthStatus,
}
15
16impl Connect {
17 pub fn is_valid(&mut self) -> bool {
18 self.query("SELECT 1").is_ok()
19 }
20
21 pub fn _close(&mut self) {
22 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
23 let _ = self.stream.shutdown(std::net::Shutdown::Both);
24 }
25
26 pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
27 let stream =
28 TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
29
30 stream
31 .set_read_timeout(Some(Duration::from_secs(30)))
32 .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
33 stream
34 .set_write_timeout(Some(Duration::from_secs(30)))
35 .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
36
37 let _ = stream.peer_addr();
38
39 let mut connect = Self {
40 stream: Arc::new(stream),
41 packet: Packet::new(config),
42 auth_status: AuthStatus::None,
43 };
44
45 connect.authenticate()?;
46
47 Ok(connect)
48 }
49
50 fn authenticate(&mut self) -> Result<(), PgsqlError> {
51 self.stream
52 .as_ref()
53 .write_all(&self.packet.pack_first())
54 .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
55
56 let data = self.read()?;
57 self.packet.unpack(data, 0)?;
58
59 if !self.packet.md5_salt.is_empty() {
60 self.md5_auth()?;
61 } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
62 self.cleartext_auth()?;
63 } else {
64 self.scram_auth()?;
65 }
66
67 self.auth_status = AuthStatus::AuthenticationOk;
68 Ok(())
69 }
70
71 fn md5_auth(&mut self) -> Result<(), PgsqlError> {
72 self.stream
73 .as_ref()
74 .write_all(&self.packet.pack_md5_password())
75 .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
76
77 let data = self.read()?;
78 self.packet.unpack(data, 0)?;
79 Ok(())
80 }
81
82 fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
83 self.stream
84 .as_ref()
85 .write_all(&self.packet.pack_cleartext_password())
86 .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
87
88 let data = self.read()?;
89 self.packet.unpack(data, 0)?;
90 Ok(())
91 }
92
93 fn scram_auth(&mut self) -> Result<(), PgsqlError> {
94 self.stream
95 .as_ref()
96 .write_all(&self.packet.pack_auth())
97 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
98
99 let data = self.read()?;
100 self.packet.unpack(data, 0)?;
101
102 self.stream
103 .as_ref()
104 .write_all(&self.packet.pack_auth_verify())
105 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
106
107 let data = self.read()?;
108 self.packet.unpack(data, 0)?;
109 Ok(())
110 }
111
112 fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
113 let mut msg = Vec::new();
114 let mut buf = [0u8; 4096];
115 let mut retry_count = 0;
116
117 #[cfg(not(test))]
118 const MAX_RETRIES: u32 = 100;
119 #[cfg(test)]
120 const MAX_RETRIES: u32 = 3;
121
122 #[cfg(not(test))]
123 const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
124 #[cfg(test)]
125 const MAX_MESSAGE_SIZE: usize = 128;
126
127 #[cfg(not(test))]
128 let deadline = std::time::Instant::now() + Duration::from_secs(300);
129 #[cfg(test)]
130 let deadline = std::time::Instant::now() + Duration::from_millis(200);
131
132 loop {
133 if std::time::Instant::now() >= deadline {
134 return Err(PgsqlError::Timeout("读取总超时".into()));
135 }
136
137 match self.stream.as_ref().read(&mut buf) {
138 Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
139 Ok(n) => {
140 if msg.len() + n > MAX_MESSAGE_SIZE {
141 return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
142 }
143 msg.extend_from_slice(&buf[..n]);
144 retry_count = 0;
145 }
146 Err(ref e)
147 if e.kind() == std::io::ErrorKind::WouldBlock
148 || e.kind() == std::io::ErrorKind::TimedOut =>
149 {
150 retry_count += 1;
151 if retry_count > MAX_RETRIES {
152 return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
153 }
154 std::thread::sleep(Duration::from_millis(10));
155 continue;
156 }
157 Err(e) => return Err(PgsqlError::Io(e)),
158 };
159
160 if let AuthStatus::AuthenticationOk = self.auth_status {
161 if msg.ends_with(&[90, 0, 0, 0, 5, 73])
162 || msg.ends_with(&[90, 0, 0, 0, 5, 84])
163 || msg.ends_with(&[90, 0, 0, 0, 5, 69])
164 {
165 break;
166 }
167 } else if msg.len() >= 5 {
168 let len_bytes = &msg[1..=4];
169 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
170 if msg.len() > len as usize {
171 break;
172 }
173 }
174 }
175 }
176
177 Ok(msg)
178 }
179
180 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
181 self.stream
182 .as_ref()
183 .write_all(&self.packet.pack_query(sql))
184 .map_err(PgsqlError::Io)?;
185
186 let data = self.read()?;
187
188 self.packet.unpack(data, 0)
189 }
190
191 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
192 self.stream
193 .as_ref()
194 .write_all(&self.packet.pack_execute(sql))
195 .map_err(PgsqlError::Io)?;
196 let data = self.read()?;
197 self.packet.unpack(data, 0)
198 }
199
200 pub fn query_params(
202 &mut self,
203 sql: &str,
204 params: &[Option<&str>],
205 ) -> Result<SuccessMessage, PgsqlError> {
206 self.stream
207 .as_ref()
208 .write_all(&self.packet.pack_query_params(sql, params))
209 .map_err(PgsqlError::Io)?;
210
211 let data = self.read()?;
212 self.packet.unpack(data, 0)
213 }
214
215 pub fn execute_params(
217 &mut self,
218 sql: &str,
219 params: &[Option<&str>],
220 ) -> Result<SuccessMessage, PgsqlError> {
221 self.stream
222 .as_ref()
223 .write_all(&self.packet.pack_execute_params(sql, params))
224 .map_err(PgsqlError::Io)?;
225 let data = self.read()?;
226 self.packet.unpack(data, 0)
227 }
228
229 pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
231 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
232 self.query_params(sql, &opts)
233 }
234
235 pub fn execute_str(
237 &mut self,
238 sql: &str,
239 params: &[&str],
240 ) -> Result<SuccessMessage, PgsqlError> {
241 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
242 self.execute_params(sql, &opts)
243 }
244}
245
246impl Drop for Connect {
247 fn drop(&mut self) {
248 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
249 let _ = self.stream.shutdown(std::net::Shutdown::Both);
250 }
251}
252
253#[cfg(test)]
254mod tests {
255 use super::*;
256 use std::net::TcpListener;
257 use std::thread;
258
259 fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
263 let mut m = Vec::with_capacity(5 + payload.len());
264 m.push(tag);
265 m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
266 m.extend_from_slice(payload);
267 m
268 }
269
270 fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
272 let mut body = Vec::new();
273 body.extend(&auth_type.to_be_bytes());
274 body.extend_from_slice(extra);
275 pg_msg(b'R', &body)
276 }
277
278 fn auth_ok() -> Vec<u8> {
280 pg_auth(0, &[])
281 }
282
283 fn param_status() -> Vec<u8> {
285 pg_msg(b'S', b"server_version\x0015.0\x00")
286 }
287
288 fn backend_key() -> Vec<u8> {
290 let mut p = Vec::new();
291 p.extend(&1u32.to_be_bytes());
292 p.extend(&2u32.to_be_bytes());
293 pg_msg(b'K', &p)
294 }
295
296 fn ready_for_query() -> Vec<u8> {
298 pg_msg(b'Z', b"I")
299 }
300
301 fn post_auth_ok() -> Vec<u8> {
303 let mut v = Vec::new();
304 v.extend(auth_ok());
305 v.extend(param_status());
306 v.extend(backend_key());
307 v.extend(ready_for_query());
308 v
309 }
310
311 fn simple_query_response() -> Vec<u8> {
314 let mut r = Vec::new();
315 r.extend(pg_msg(b'1', &[]));
317 r.extend(pg_msg(b'2', &[]));
319 let mut rd = Vec::new();
321 rd.extend(&1u16.to_be_bytes()); rd.extend(b"c\x00"); rd.extend(&0u32.to_be_bytes()); rd.extend(&1u16.to_be_bytes()); rd.extend(&23u32.to_be_bytes()); rd.extend(&4i16.to_be_bytes()); rd.extend(&(-1i32).to_be_bytes()); rd.extend(&0u16.to_be_bytes()); r.extend(pg_msg(b'T', &rd));
330 let mut dr = Vec::new();
332 dr.extend(&1u16.to_be_bytes());
333 dr.extend(&1u32.to_be_bytes()); dr.push(b'1');
335 r.extend(pg_msg(b'D', &dr));
336 r.extend(pg_msg(b'C', b"SELECT 1\x00"));
338 r.extend(ready_for_query());
340 r
341 }
342
343 fn execute_response() -> Vec<u8> {
346 let mut r = Vec::new();
347 r.extend(pg_msg(b'1', &[]));
348 r.extend(pg_msg(b'2', &[]));
349 r.extend(pg_msg(b'n', &[])); r.extend(pg_msg(b'C', b"UPDATE 3\x00"));
351 r.extend(ready_for_query());
352 r
353 }
354
355 fn query_params_response() -> Vec<u8> {
358 let mut r = Vec::new();
359 r.extend(pg_msg(b'1', &[]));
360
361 let mut pd = Vec::new();
362 pd.extend(&1u16.to_be_bytes());
363 pd.extend(&23u32.to_be_bytes());
364 r.extend(pg_msg(b't', &pd));
365
366 r.extend(pg_msg(b'2', &[]));
367
368 let mut rd = Vec::new();
369 rd.extend(&1u16.to_be_bytes());
370 rd.extend(b"p\x00");
371 rd.extend(&0u32.to_be_bytes());
372 rd.extend(&1u16.to_be_bytes());
373 rd.extend(&23u32.to_be_bytes());
374 rd.extend(&4i16.to_be_bytes());
375 rd.extend(&(-1i32).to_be_bytes());
376 rd.extend(&0u16.to_be_bytes());
377 r.extend(pg_msg(b'T', &rd));
378
379 let mut dr = Vec::new();
380 dr.extend(&1u16.to_be_bytes());
381 dr.extend(&2u32.to_be_bytes());
382 dr.extend(b"42");
383 r.extend(pg_msg(b'D', &dr));
384
385 r.extend(pg_msg(b'C', b"SELECT 1\x00"));
386 r.extend(ready_for_query());
387 r
388 }
389
390 fn execute_params_response() -> Vec<u8> {
393 let mut r = Vec::new();
394 r.extend(pg_msg(b'1', &[]));
395
396 let mut pd = Vec::new();
397 pd.extend(&1u16.to_be_bytes());
398 pd.extend(&23u32.to_be_bytes());
399 r.extend(pg_msg(b't', &pd));
400
401 r.extend(pg_msg(b'2', &[]));
402 r.extend(pg_msg(b'n', &[]));
403 r.extend(pg_msg(b'C', b"UPDATE 1\x00"));
404 r.extend(ready_for_query());
405 r
406 }
407
408 fn query_params_null_response() -> Vec<u8> {
410 let mut r = Vec::new();
411 r.extend(pg_msg(b'1', &[]));
412
413 let mut pd = Vec::new();
414 pd.extend(&1u16.to_be_bytes());
415 pd.extend(&25u32.to_be_bytes());
416 r.extend(pg_msg(b't', &pd));
417
418 r.extend(pg_msg(b'2', &[]));
419
420 let mut rd = Vec::new();
421 rd.extend(&1u16.to_be_bytes());
422 rd.extend(b"n\x00");
423 rd.extend(&0u32.to_be_bytes());
424 rd.extend(&1u16.to_be_bytes());
425 rd.extend(&25u32.to_be_bytes());
426 rd.extend(&(-1i16).to_be_bytes());
427 rd.extend(&(-1i32).to_be_bytes());
428 rd.extend(&0u16.to_be_bytes());
429 r.extend(pg_msg(b'T', &rd));
430
431 let mut dr = Vec::new();
432 dr.extend(&1u16.to_be_bytes());
433 dr.extend(&(-1i32).to_be_bytes());
434 r.extend(pg_msg(b'D', &dr));
435
436 r.extend(pg_msg(b'C', b"SELECT 1\x00"));
437 r.extend(ready_for_query());
438 r
439 }
440
441 fn error_response() -> Vec<u8> {
443 let mut payload = Vec::new();
444 payload.push(b'C');
445 payload.extend(b"42601\x00");
446 payload.push(b'M');
447 payload.extend(b"syntax error\x00");
448 payload.push(0);
449 let mut r = Vec::new();
450 r.extend(pg_msg(b'1', &[]));
451 r.extend(pg_msg(b'2', &[]));
452 r.extend(pg_msg(b'E', &payload));
453 r.extend(ready_for_query());
454 r
455 }
456
457 fn mock_config(port: u16) -> Config {
461 Config {
462 debug: false,
463 hostname: "127.0.0.1".into(),
464 hostport: port as i32,
465 username: "u".into(),
466 userpass: "p".into(),
467 database: "d".into(),
468 charset: "utf8".into(),
469 pool_max: 5,
470 }
471 }
472
473 fn spawn_cleartext_server() -> u16 {
476 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
477 let port = listener.local_addr().unwrap().port();
478 thread::spawn(move || {
479 let (mut s, _) = listener.accept().unwrap();
480 let mut buf = [0u8; 4096];
481 let _ = s.read(&mut buf).unwrap();
483 let _ = s.write_all(&pg_auth(3, &[]));
485 let _ = s.read(&mut buf).unwrap();
487 let _ = s.write_all(&post_auth_ok());
489 loop {
491 match s.read(&mut buf) {
492 Ok(0) | Err(_) => break,
493 Ok(_) => {
494 let _ = s.write_all(&simple_query_response());
495 }
496 }
497 }
498 });
499 thread::sleep(Duration::from_millis(30));
500 port
501 }
502
503 fn spawn_md5_server() -> u16 {
505 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
506 let port = listener.local_addr().unwrap().port();
507 thread::spawn(move || {
508 let (mut s, _) = listener.accept().unwrap();
509 let mut buf = [0u8; 4096];
510 let _ = s.read(&mut buf).unwrap();
512 let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
514 let _ = s.read(&mut buf).unwrap();
516 let _ = s.write_all(&post_auth_ok());
518 loop {
519 match s.read(&mut buf) {
520 Ok(0) | Err(_) => break,
521 Ok(_) => {
522 let _ = s.write_all(&simple_query_response());
523 }
524 }
525 }
526 });
527 thread::sleep(Duration::from_millis(30));
528 port
529 }
530
531 fn spawn_scram_server() -> u16 {
533 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
534 let port = listener.local_addr().unwrap().port();
535 thread::spawn(move || {
536 let (mut s, _) = listener.accept().unwrap();
537 let mut buf = [0u8; 4096];
538 let _ = s.read(&mut buf).unwrap();
540 let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
542 let n = s.read(&mut buf).unwrap();
544 let payload = &buf[..n];
545 let text = String::from_utf8_lossy(payload);
547 let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
548 let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
550 let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
551 let _ = s.read(&mut buf).unwrap();
553 let mut resp = Vec::new();
555 resp.extend(pg_auth(12, b"v=dummyproof"));
556 resp.extend(auth_ok());
557 resp.extend(param_status());
558 resp.extend(backend_key());
559 resp.extend(ready_for_query());
560 let _ = s.write_all(&resp);
561 loop {
562 match s.read(&mut buf) {
563 Ok(0) | Err(_) => break,
564 Ok(_) => {
565 let _ = s.write_all(&simple_query_response());
566 }
567 }
568 }
569 });
570 thread::sleep(Duration::from_millis(30));
571 port
572 }
573
574 fn spawn_eof_server() -> u16 {
576 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
577 let port = listener.local_addr().unwrap().port();
578 thread::spawn(move || {
579 let (s, _) = listener.accept().unwrap();
580 drop(s); });
582 thread::sleep(Duration::from_millis(30));
583 port
584 }
585
586 fn spawn_auth_error_server() -> u16 {
588 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
589 let port = listener.local_addr().unwrap().port();
590 thread::spawn(move || {
591 let (mut s, _) = listener.accept().unwrap();
592 let mut buf = [0u8; 4096];
593 let _ = s.read(&mut buf).unwrap();
594 let mut payload = Vec::new();
596 payload.push(b'C');
597 payload.extend(b"28P01\x00");
598 payload.push(b'M');
599 payload.extend(b"password authentication failed\x00");
600 payload.push(0);
601 let _ = s.write_all(&pg_msg(b'E', &payload));
602 });
603 thread::sleep(Duration::from_millis(30));
604 port
605 }
606
607 fn spawn_query_error_server() -> u16 {
609 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
610 let port = listener.local_addr().unwrap().port();
611 thread::spawn(move || {
612 let (mut s, _) = listener.accept().unwrap();
613 let mut buf = [0u8; 4096];
614 let _ = s.read(&mut buf).unwrap();
616 let _ = s.write_all(&pg_auth(3, &[]));
617 let _ = s.read(&mut buf).unwrap();
618 let _ = s.write_all(&post_auth_ok());
619 loop {
621 match s.read(&mut buf) {
622 Ok(0) | Err(_) => break,
623 Ok(_) => {
624 let _ = s.write_all(&error_response());
625 }
626 }
627 }
628 });
629 thread::sleep(Duration::from_millis(30));
630 port
631 }
632
633 fn spawn_execute_server() -> u16 {
635 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
636 let port = listener.local_addr().unwrap().port();
637 thread::spawn(move || {
638 let (mut s, _) = listener.accept().unwrap();
639 let mut buf = [0u8; 4096];
640 let _ = s.read(&mut buf).unwrap();
642 let _ = s.write_all(&pg_auth(3, &[]));
643 let _ = s.read(&mut buf).unwrap();
644 let _ = s.write_all(&post_auth_ok());
645 loop {
647 match s.read(&mut buf) {
648 Ok(0) | Err(_) => break,
649 Ok(_) => {
650 let _ = s.write_all(&execute_response());
651 }
652 }
653 }
654 });
655 thread::sleep(Duration::from_millis(30));
656 port
657 }
658
659 fn spawn_query_params_server() -> u16 {
661 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
662 let port = listener.local_addr().unwrap().port();
663 thread::spawn(move || {
664 let (mut s, _) = listener.accept().unwrap();
665 let mut buf = [0u8; 4096];
666 let _ = s.read(&mut buf).unwrap();
667 let _ = s.write_all(&pg_auth(3, &[]));
668 let _ = s.read(&mut buf).unwrap();
669 let _ = s.write_all(&post_auth_ok());
670 loop {
671 match s.read(&mut buf) {
672 Ok(0) | Err(_) => break,
673 Ok(_) => {
674 let _ = s.write_all(&query_params_response());
675 }
676 }
677 }
678 });
679 thread::sleep(Duration::from_millis(30));
680 port
681 }
682
    /// Alias for `spawn_query_params_server`; returns the port it listens on.
    fn spawn_params_server() -> u16 {
        spawn_query_params_server()
    }
686
687 fn spawn_execute_params_server() -> u16 {
689 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
690 let port = listener.local_addr().unwrap().port();
691 thread::spawn(move || {
692 let (mut s, _) = listener.accept().unwrap();
693 let mut buf = [0u8; 4096];
694 let _ = s.read(&mut buf).unwrap();
695 let _ = s.write_all(&pg_auth(3, &[]));
696 let _ = s.read(&mut buf).unwrap();
697 let _ = s.write_all(&post_auth_ok());
698 loop {
699 match s.read(&mut buf) {
700 Ok(0) | Err(_) => break,
701 Ok(_) => {
702 let _ = s.write_all(&execute_params_response());
703 }
704 }
705 }
706 });
707 thread::sleep(Duration::from_millis(30));
708 port
709 }
710
711 fn spawn_query_params_null_server() -> u16 {
713 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
714 let port = listener.local_addr().unwrap().port();
715 thread::spawn(move || {
716 let (mut s, _) = listener.accept().unwrap();
717 let mut buf = [0u8; 4096];
718 let _ = s.read(&mut buf).unwrap();
719 let _ = s.write_all(&pg_auth(3, &[]));
720 let _ = s.read(&mut buf).unwrap();
721 let _ = s.write_all(&post_auth_ok());
722 loop {
723 match s.read(&mut buf) {
724 Ok(0) | Err(_) => break,
725 Ok(_) => {
726 let _ = s.write_all(&query_params_null_response());
727 }
728 }
729 }
730 });
731 thread::sleep(Duration::from_millis(30));
732 port
733 }
734
735 #[test]
738 fn connect_cleartext_auth_success() {
739 let port = spawn_cleartext_server();
740 let conn = Connect::new(mock_config(port));
741 assert!(conn.is_ok());
742 }
743
744 #[test]
745 fn connect_md5_auth_success() {
746 let port = spawn_md5_server();
747 let conn = Connect::new(mock_config(port));
748 assert!(conn.is_ok());
749 }
750
751 #[test]
752 fn connect_scram_auth_success() {
753 let port = spawn_scram_server();
754 let conn = Connect::new(mock_config(port));
755 assert!(conn.is_ok());
756 }
757
758 #[test]
759 fn connect_connection_refused() {
760 let cfg = mock_config(1);
762 let result = Connect::new(cfg);
763 assert!(result.is_err());
764 match result.unwrap_err() {
765 PgsqlError::Connection(_) => {}
766 other => panic!("expected Connection error, got {other:?}"),
767 }
768 }
769
770 #[test]
771 fn connect_server_closes_immediately() {
772 let port = spawn_eof_server();
773 let result = Connect::new(mock_config(port));
774 assert!(result.is_err());
775 }
776
777 #[test]
778 fn connect_auth_error_from_server() {
779 let port = spawn_auth_error_server();
780 let result = Connect::new(mock_config(port));
781 assert!(result.is_err());
782 }
783
784 #[test]
785 fn connect_query_success() {
786 let port = spawn_cleartext_server();
787 let mut conn = Connect::new(mock_config(port)).unwrap();
788 let result = conn.query("SELECT 1");
789 assert!(result.is_ok());
790 let msg = result.unwrap();
791 assert_eq!(msg.rows.len(), 1);
792 assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
793 }
794
795 #[test]
796 fn connect_execute_success() {
797 let port = spawn_execute_server();
798 let mut conn = Connect::new(mock_config(port)).unwrap();
799 let result = conn.execute("UPDATE t SET x=1");
800 assert!(result.is_ok());
801 let msg = result.unwrap();
802 assert_eq!(msg.affect_count, 3);
803 assert_eq!(msg.tag, "UPDATE 3");
804 }
805
806 #[test]
807 fn connect_query_params_success() {
808 let port = spawn_query_params_server();
809 let mut conn = Connect::new(mock_config(port)).unwrap();
810 let result = conn.query_params("SELECT $1::int", &[Some("42")]);
811 assert!(result.is_ok());
812 let msg = result.unwrap();
813 assert_eq!(msg.rows.len(), 1);
814 assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
815 }
816
817 #[test]
818 fn connect_execute_params_success() {
819 let port = spawn_execute_params_server();
820 let mut conn = Connect::new(mock_config(port)).unwrap();
821 let result = conn.execute_params("UPDATE t SET x=$1", &[Some("42")]);
822 assert!(result.is_ok());
823 let msg = result.unwrap();
824 assert_eq!(msg.affect_count, 1);
825 assert_eq!(msg.tag, "UPDATE 1");
826 }
827
828 #[test]
829 fn connect_query_str_success() {
830 let port = spawn_params_server();
831 let mut conn = Connect::new(mock_config(port)).unwrap();
832 let result = conn.query_str("SELECT $1::int", &["42"]);
833 assert!(result.is_ok());
834 let msg = result.unwrap();
835 assert_eq!(msg.rows.len(), 1);
836 }
837
838 #[test]
839 fn connect_execute_str_success() {
840 let port = spawn_execute_params_server();
841 let mut conn = Connect::new(mock_config(port)).unwrap();
842 let result = conn.execute_str("UPDATE t SET x=$1", &["1"]);
843 assert!(result.is_ok());
844 let msg = result.unwrap();
845 assert_eq!(msg.affect_count, 1);
846 }
847
848 #[test]
849 fn connect_query_params_with_null() {
850 let port = spawn_query_params_null_server();
851 let mut conn = Connect::new(mock_config(port)).unwrap();
852 let result = conn.query_params("SELECT $1::text", &[None]);
853 assert!(result.is_ok());
854 let msg = result.unwrap();
855 assert_eq!(msg.rows.len(), 1);
856 assert!(msg.rows[0]["n"].is_null());
857 }
858
859 #[test]
860 fn connect_query_params_empty_string_vs_null() {
861 let port = spawn_params_server();
862 let mut conn = Connect::new(mock_config(port)).unwrap();
863
864 let r1 = conn.query_params("SELECT $1::text", &[Some("")]);
866 assert!(r1.is_ok());
867
868 let r2 = conn.query_params("SELECT $1::text", &[None]);
870 assert!(r2.is_ok());
871 }
872
873 #[test]
874 fn connect_query_returns_error() {
875 let port = spawn_query_error_server();
876 let mut conn = Connect::new(mock_config(port)).unwrap();
877 let result = conn.query("BAD SQL");
878 assert!(result.is_err());
879 }
880
881 #[test]
882 fn connect_is_valid_true() {
883 let port = spawn_cleartext_server();
884 let mut conn = Connect::new(mock_config(port)).unwrap();
885 assert!(conn.is_valid());
886 }
887
888 #[test]
889 fn connect_is_valid_false_after_close() {
890 let port = spawn_cleartext_server();
891 let mut conn = Connect::new(mock_config(port)).unwrap();
892 conn._close();
893 assert!(!conn.is_valid());
895 }
896
897 #[test]
898 fn connect_close_does_not_panic() {
899 let port = spawn_cleartext_server();
900 let mut conn = Connect::new(mock_config(port)).unwrap();
901 conn._close();
902 conn._close();
904 }
905
906 #[test]
907 fn connect_drop_does_not_panic() {
908 let port = spawn_cleartext_server();
909 let conn = Connect::new(mock_config(port)).unwrap();
910 drop(conn);
911 }
912
913 fn spawn_transaction_status_server() -> u16 {
914 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
915 let port = listener.local_addr().unwrap().port();
916 thread::spawn(move || {
917 let (mut s, _) = listener.accept().unwrap();
918 let mut buf = [0u8; 4096];
919 let _ = s.read(&mut buf).unwrap();
920 let _ = s.write_all(&pg_auth(3, &[]));
921 let _ = s.read(&mut buf).unwrap();
922 let _ = s.write_all(&post_auth_ok());
923 loop {
924 match s.read(&mut buf) {
925 Ok(0) | Err(_) => break,
926 Ok(_) => {
927 let mut r = Vec::new();
928 r.extend(pg_msg(b'1', &[]));
929 r.extend(pg_msg(b'2', &[]));
930 let mut rd = Vec::new();
931 rd.extend(&1u16.to_be_bytes());
932 rd.extend(b"c\x00");
933 rd.extend(&0u32.to_be_bytes());
934 rd.extend(&1u16.to_be_bytes());
935 rd.extend(&23u32.to_be_bytes());
936 rd.extend(&4i16.to_be_bytes());
937 rd.extend(&(-1i32).to_be_bytes());
938 rd.extend(&0u16.to_be_bytes());
939 r.extend(pg_msg(b'T', &rd));
940 let mut dr = Vec::new();
941 dr.extend(&1u16.to_be_bytes());
942 dr.extend(&1u32.to_be_bytes());
943 dr.push(b'1');
944 r.extend(pg_msg(b'D', &dr));
945 r.extend(pg_msg(b'C', b"SELECT 1\x00"));
946 r.extend(pg_msg(b'Z', b"T"));
947 let _ = s.write_all(&r);
948 }
949 }
950 }
951 });
952 thread::sleep(Duration::from_millis(30));
953 port
954 }
955
956 fn spawn_error_status_server() -> u16 {
957 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
958 let port = listener.local_addr().unwrap().port();
959 thread::spawn(move || {
960 let (mut s, _) = listener.accept().unwrap();
961 let mut buf = [0u8; 4096];
962 let _ = s.read(&mut buf).unwrap();
963 let _ = s.write_all(&pg_auth(3, &[]));
964 let _ = s.read(&mut buf).unwrap();
965 let _ = s.write_all(&post_auth_ok());
966 loop {
967 match s.read(&mut buf) {
968 Ok(0) | Err(_) => break,
969 Ok(_) => {
970 let mut r = Vec::new();
971 r.extend(pg_msg(b'1', &[]));
972 r.extend(pg_msg(b'2', &[]));
973 let mut rd = Vec::new();
974 rd.extend(&1u16.to_be_bytes());
975 rd.extend(b"c\x00");
976 rd.extend(&0u32.to_be_bytes());
977 rd.extend(&1u16.to_be_bytes());
978 rd.extend(&23u32.to_be_bytes());
979 rd.extend(&4i16.to_be_bytes());
980 rd.extend(&(-1i32).to_be_bytes());
981 rd.extend(&0u16.to_be_bytes());
982 r.extend(pg_msg(b'T', &rd));
983 let mut dr = Vec::new();
984 dr.extend(&1u16.to_be_bytes());
985 dr.extend(&1u32.to_be_bytes());
986 dr.push(b'1');
987 r.extend(pg_msg(b'D', &dr));
988 r.extend(pg_msg(b'C', b"SELECT 1\x00"));
989 r.extend(pg_msg(b'Z', b"E"));
990 let _ = s.write_all(&r);
991 }
992 }
993 }
994 });
995 thread::sleep(Duration::from_millis(30));
996 port
997 }
998
999 fn spawn_slow_partial_server() -> u16 {
1000 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1001 let port = listener.local_addr().unwrap().port();
1002 thread::spawn(move || {
1003 let (mut s, _) = listener.accept().unwrap();
1004 let mut buf = [0u8; 4096];
1005 let _ = s.read(&mut buf).unwrap();
1006 let _ = s.write_all(&pg_auth(3, &[]));
1007 let _ = s.read(&mut buf).unwrap();
1008 let _ = s.write_all(&post_auth_ok());
1009 match s.read(&mut buf) {
1010 Ok(0) | Err(_) => {}
1011 Ok(_) => {
1012 let _ = s.write_all(&simple_query_response());
1013 }
1014 }
1015 });
1016 thread::sleep(Duration::from_millis(30));
1017 port
1018 }
1019
1020 fn spawn_rst_on_query_server() -> u16 {
1021 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1022 let port = listener.local_addr().unwrap().port();
1023 thread::spawn(move || {
1024 let (mut s, _) = listener.accept().unwrap();
1025 let mut buf = [0u8; 4096];
1026 let _ = s.read(&mut buf).unwrap();
1027 let _ = s.write_all(&pg_auth(3, &[]));
1028 let _ = s.read(&mut buf).unwrap();
1029 let _ = s.write_all(&post_auth_ok());
1030 match s.read(&mut buf) {
1031 Ok(0) | Err(_) => {}
1032 Ok(_) => {
1033 drop(s);
1034 }
1035 }
1036 });
1037 thread::sleep(Duration::from_millis(30));
1038 port
1039 }
1040
1041 #[test]
1042 fn connect_query_ready_for_query_transaction_status() {
1043 let port = spawn_transaction_status_server();
1044 let mut conn = Connect::new(mock_config(port)).unwrap();
1045 let result = conn.query("SELECT 1");
1046 assert!(result.is_ok());
1047 }
1048
1049 #[test]
1050 fn connect_query_ready_for_query_error_status() {
1051 let port = spawn_error_status_server();
1052 let mut conn = Connect::new(mock_config(port)).unwrap();
1053 let result = conn.query("SELECT 1");
1054 assert!(result.is_ok());
1055 }
1056
1057 #[test]
1058 fn connect_query_server_closes_after_partial() {
1059 let port = spawn_slow_partial_server();
1060 let mut conn = Connect::new(mock_config(port)).unwrap();
1061 let r1 = conn.query("SELECT 1");
1062 assert!(r1.is_ok());
1063 let r2 = conn.query("SELECT 1");
1064 assert!(r2.is_err());
1065 }
1066
1067 #[test]
1068 fn connect_query_server_rst_returns_io_or_connection_error() {
1069 let port = spawn_rst_on_query_server();
1070 let mut conn = Connect::new(mock_config(port)).unwrap();
1071 let result = conn.query("SELECT 1");
1072 assert!(result.is_err());
1073 }
1074
1075 #[test]
1076 fn connect_read_would_block_max_retries() {
1077 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1078 let port = listener.local_addr().unwrap().port();
1079 thread::spawn(move || {
1080 let (mut s, _) = listener.accept().unwrap();
1081 let mut buf = [0u8; 4096];
1082 let _ = s.read(&mut buf);
1083 let _ = s.write_all(&pg_auth(3, &[]));
1084 let _ = s.read(&mut buf);
1085 let _ = s.write_all(&post_auth_ok());
1086 let _ = s.read(&mut buf);
1087 thread::sleep(Duration::from_secs(5));
1088 });
1089 thread::sleep(Duration::from_millis(30));
1090
1091 let mut conn = Connect::new(mock_config(port)).unwrap();
1092 conn.stream
1093 .set_read_timeout(Some(Duration::from_millis(1)))
1094 .ok();
1095 let result = conn.query("SELECT 1");
1096 assert!(result.is_err());
1097 let err_str = result.unwrap_err().to_string();
1098 assert!(
1099 err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
1100 "expected timeout error, got: {err_str}"
1101 );
1102 }
1103
1104 #[test]
1105 fn connect_read_exceeds_max_message_size() {
1106 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1107 let port = listener.local_addr().unwrap().port();
1108 thread::spawn(move || {
1109 let (mut s, _) = listener.accept().unwrap();
1110 let mut buf = [0u8; 4096];
1111 let _ = s.read(&mut buf);
1112 let _ = s.write_all(&pg_auth(3, &[]));
1113 let _ = s.read(&mut buf);
1114 let _ = s.write_all(&post_auth_ok());
1115 let _ = s.read(&mut buf);
1116 let big = vec![b'X'; 256];
1117 let _ = s.write_all(&big);
1118 thread::sleep(Duration::from_secs(2));
1119 });
1120 thread::sleep(Duration::from_millis(30));
1121
1122 let mut conn = Connect::new(mock_config(port)).unwrap();
1123 let result = conn.query("SELECT 1");
1124 assert!(result.is_err());
1125 let err_str = result.unwrap_err().to_string();
1126 assert!(
1127 err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
1128 "expected max message size error, got: {err_str}"
1129 );
1130 }
1131
1132 #[test]
1133 fn connect_read_deadline_timeout() {
1134 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1135 let port = listener.local_addr().unwrap().port();
1136 thread::spawn(move || {
1137 let (mut s, _) = listener.accept().unwrap();
1138 let mut buf = [0u8; 4096];
1139 let _ = s.read(&mut buf);
1140 let _ = s.write_all(&pg_auth(3, &[]));
1141 let _ = s.read(&mut buf);
1142 let _ = s.write_all(&post_auth_ok());
1143 let _ = s.read(&mut buf);
1144 for _ in 0..200 {
1145 let _ = s.write_all(b"X");
1146 thread::sleep(Duration::from_millis(5));
1147 }
1148 });
1149 thread::sleep(Duration::from_millis(30));
1150
1151 let mut conn = Connect::new(mock_config(port)).unwrap();
1152 let result = conn.query("SELECT 1");
1153 assert!(result.is_err());
1154 }
1155
1156 #[test]
1157 fn connect_read_partial_auth_frame() {
1158 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1159 let port = listener.local_addr().unwrap().port();
1160 thread::spawn(move || {
1161 let (mut s, _) = listener.accept().unwrap();
1162 let mut buf = [0u8; 4096];
1163 let _ = s.read(&mut buf);
1164 let auth = pg_auth(3, &[]);
1165 let _ = s.write_all(&auth[..5]);
1166 thread::sleep(Duration::from_millis(50));
1167 let _ = s.write_all(&auth[5..]);
1168 let _ = s.read(&mut buf);
1169 let _ = s.write_all(&post_auth_ok());
1170 loop {
1171 match s.read(&mut buf) {
1172 Ok(0) | Err(_) => break,
1173 Ok(_) => {
1174 let _ = s.write_all(&simple_query_response());
1175 }
1176 }
1177 }
1178 });
1179 thread::sleep(Duration::from_millis(30));
1180
1181 let mut conn = Connect::new(mock_config(port)).unwrap();
1182 let result = conn.query("SELECT 1");
1183 assert!(result.is_ok());
1184 }
1185}