1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::sync::Arc;
7use std::time::Duration;
8
9#[derive(Clone, Debug)]
10pub struct Connect {
11 pub(crate) stream: Arc<TcpStream>,
12 packet: Packet,
13 auth_status: AuthStatus,
14}
15
16impl Connect {
17 pub fn is_valid(&mut self) -> bool {
18 self.query("SELECT 1").is_ok()
19 }
20
21 pub fn _close(&mut self) {
22 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
23 let _ = self.stream.shutdown(std::net::Shutdown::Both);
24 }
25
26 pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
27 let stream =
28 TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
29
30 stream
31 .set_read_timeout(Some(Duration::from_secs(30)))
32 .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
33 stream
34 .set_write_timeout(Some(Duration::from_secs(30)))
35 .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
36
37 let _ = stream.peer_addr();
38
39 let mut connect = Self {
40 stream: Arc::new(stream),
41 packet: Packet::new(config),
42 auth_status: AuthStatus::None,
43 };
44
45 connect.authenticate()?;
46
47 Ok(connect)
48 }
49
50 fn authenticate(&mut self) -> Result<(), PgsqlError> {
51 self.stream
52 .as_ref()
53 .write_all(&self.packet.pack_first())
54 .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
55
56 let data = self.read()?;
57 self.packet.unpack(data, 0)?;
58
59 if !self.packet.md5_salt.is_empty() {
60 self.md5_auth()?;
61 } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
62 self.cleartext_auth()?;
63 } else {
64 self.scram_auth()?;
65 }
66
67 self.auth_status = AuthStatus::AuthenticationOk;
68 Ok(())
69 }
70
71 fn md5_auth(&mut self) -> Result<(), PgsqlError> {
72 self.stream
73 .as_ref()
74 .write_all(&self.packet.pack_md5_password())
75 .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
76
77 let data = self.read()?;
78 self.packet.unpack(data, 0)?;
79 Ok(())
80 }
81
82 fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
83 self.stream
84 .as_ref()
85 .write_all(&self.packet.pack_cleartext_password())
86 .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
87
88 let data = self.read()?;
89 self.packet.unpack(data, 0)?;
90 Ok(())
91 }
92
93 fn scram_auth(&mut self) -> Result<(), PgsqlError> {
94 self.stream
95 .as_ref()
96 .write_all(&self.packet.pack_auth())
97 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
98
99 let data = self.read()?;
100 self.packet.unpack(data, 0)?;
101
102 self.stream
103 .as_ref()
104 .write_all(&self.packet.pack_auth_verify())
105 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
106
107 let data = self.read()?;
108 self.packet.unpack(data, 0)?;
109 Ok(())
110 }
111
112 fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
113 let mut msg = Vec::new();
114 let mut buf = [0u8; 4096];
115 let mut retry_count = 0;
116
117 #[cfg(not(test))]
118 const MAX_RETRIES: u32 = 100;
119 #[cfg(test)]
120 const MAX_RETRIES: u32 = 3;
121
122 #[cfg(not(test))]
123 const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
124 #[cfg(test)]
125 const MAX_MESSAGE_SIZE: usize = 128;
126
127 #[cfg(not(test))]
128 let deadline = std::time::Instant::now() + Duration::from_secs(300);
129 #[cfg(test)]
130 let deadline = std::time::Instant::now() + Duration::from_millis(200);
131
132 loop {
133 if std::time::Instant::now() >= deadline {
134 return Err(PgsqlError::Timeout("读取总超时".into()));
135 }
136
137 match self.stream.as_ref().read(&mut buf) {
138 Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
139 Ok(n) => {
140 if msg.len() + n > MAX_MESSAGE_SIZE {
141 return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
142 }
143 msg.extend_from_slice(&buf[..n]);
144 retry_count = 0;
145 }
146 Err(ref e)
147 if e.kind() == std::io::ErrorKind::WouldBlock
148 || e.kind() == std::io::ErrorKind::TimedOut =>
149 {
150 retry_count += 1;
151 if retry_count > MAX_RETRIES {
152 return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
153 }
154 std::thread::sleep(Duration::from_millis(10));
155 continue;
156 }
157 Err(e) => return Err(PgsqlError::Io(e)),
158 };
159
160 if let AuthStatus::AuthenticationOk = self.auth_status {
161 if msg.ends_with(&[90, 0, 0, 0, 5, 73])
162 || msg.ends_with(&[90, 0, 0, 0, 5, 84])
163 || msg.ends_with(&[90, 0, 0, 0, 5, 69])
164 {
165 break;
166 }
167 } else if msg.len() >= 5 {
168 let len_bytes = &msg[1..=4];
169 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
170 if msg.len() > len as usize {
171 break;
172 }
173 }
174 }
175 }
176
177 Ok(msg)
178 }
179
180 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
181 self.stream
182 .as_ref()
183 .write_all(&self.packet.pack_query(sql))
184 .map_err(PgsqlError::Io)?;
185
186 let data = self.read()?;
187
188 self.packet.unpack(data, 0)
189 }
190
191 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
192 self.stream
193 .as_ref()
194 .write_all(&self.packet.pack_execute(sql))
195 .map_err(PgsqlError::Io)?;
196 let data = self.read()?;
197 self.packet.unpack(data, 0)
198 }
199
200 pub fn query_params(
202 &mut self,
203 sql: &str,
204 params: &[Option<&str>],
205 ) -> Result<SuccessMessage, PgsqlError> {
206 self.stream
207 .as_ref()
208 .write_all(&self.packet.pack_query_params(sql, params))
209 .map_err(PgsqlError::Io)?;
210
211 let data = self.read()?;
212 self.packet.unpack(data, 0)
213 }
214
215 pub fn execute_params(
217 &mut self,
218 sql: &str,
219 params: &[Option<&str>],
220 ) -> Result<SuccessMessage, PgsqlError> {
221 self.stream
222 .as_ref()
223 .write_all(&self.packet.pack_execute_params(sql, params))
224 .map_err(PgsqlError::Io)?;
225 let data = self.read()?;
226 self.packet.unpack(data, 0)
227 }
228
229 pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
231 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
232 self.query_params(sql, &opts)
233 }
234
235 pub fn execute_str(
237 &mut self,
238 sql: &str,
239 params: &[&str],
240 ) -> Result<SuccessMessage, PgsqlError> {
241 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
242 self.execute_params(sql, &opts)
243 }
244}
245
246impl Drop for Connect {
247 fn drop(&mut self) {
248 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
249 let _ = self.stream.shutdown(std::net::Shutdown::Both);
250 }
251}
252
#[cfg(test)]
mod tests {
    //! Tests for `Connect` against scripted fake PostgreSQL servers.
    //!
    //! Every fake server binds an ephemeral loopback port, accepts one client connection
    //! in a background thread and plays a fixed script of backend messages.

    use super::*;
    use std::net::{TcpListener, TcpStream};
    use std::thread;

    // ----- backend message builders -------------------------------------------------------

    /// Frames `payload` as a backend message: tag byte, big-endian length (which counts
    /// the four length bytes themselves), then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Authentication message ('R') carrying `auth_type` followed by `extra`.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::with_capacity(4 + extra.len());
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// AuthenticationOk (auth code 0).
    fn auth_ok() -> Vec<u8> {
        pg_auth(0, &[])
    }

    /// ParameterStatus ('S') reporting `server_version = 15.0`.
    fn param_status() -> Vec<u8> {
        pg_msg(b'S', b"server_version\x0015.0\x00")
    }

    /// BackendKeyData ('K') with process id 1 and secret key 2.
    fn backend_key() -> Vec<u8> {
        let mut p = Vec::with_capacity(8);
        p.extend(&1u32.to_be_bytes());
        p.extend(&2u32.to_be_bytes());
        pg_msg(b'K', &p)
    }

    /// ReadyForQuery ('Z') with the given transaction status byte (`I`, `T` or `E`).
    fn ready_for_query_with(status: u8) -> Vec<u8> {
        pg_msg(b'Z', &[status])
    }

    /// ReadyForQuery in the idle (`I`) state.
    fn ready_for_query() -> Vec<u8> {
        ready_for_query_with(b'I')
    }

    /// Everything the server sends once authentication has succeeded.
    fn post_auth_ok() -> Vec<u8> {
        [auth_ok(), param_status(), backend_key(), ready_for_query()].concat()
    }

    /// ParameterDescription ('t') listing the given parameter type OIDs.
    fn parameter_description(oids: &[u32]) -> Vec<u8> {
        let mut pd = Vec::new();
        pd.extend(&(oids.len() as u16).to_be_bytes());
        for oid in oids {
            pd.extend(&oid.to_be_bytes());
        }
        pg_msg(b't', &pd)
    }

    /// RowDescription ('T') for a single text-format column.
    fn row_description(name: &[u8], type_oid: u32, type_len: i16) -> Vec<u8> {
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes()); // number of fields
        rd.extend_from_slice(name);
        rd.push(0); // column name terminator
        rd.extend(&0u32.to_be_bytes()); // table OID
        rd.extend(&1u16.to_be_bytes()); // column attribute number
        rd.extend(&type_oid.to_be_bytes());
        rd.extend(&type_len.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes()); // type modifier
        rd.extend(&0u16.to_be_bytes()); // format code: text
        pg_msg(b'T', &rd)
    }

    /// DataRow ('D'); a `None` value is encoded as SQL NULL (length -1).
    fn data_row(values: &[Option<&str>]) -> Vec<u8> {
        let mut dr = Vec::new();
        dr.extend(&(values.len() as u16).to_be_bytes());
        for value in values {
            match value {
                Some(text) => {
                    dr.extend(&(text.len() as u32).to_be_bytes());
                    dr.extend_from_slice(text.as_bytes());
                }
                None => dr.extend(&(-1i32).to_be_bytes()),
            }
        }
        pg_msg(b'D', &dr)
    }

    /// ErrorResponse ('E') with an SQLSTATE code and a message.
    fn error_msg(code: &str, message: &str) -> Vec<u8> {
        let mut payload = Vec::new();
        payload.push(b'C');
        payload.extend(code.as_bytes());
        payload.push(0);
        payload.push(b'M');
        payload.extend(message.as_bytes());
        payload.push(0);
        payload.push(0); // end of fields
        pg_msg(b'E', &payload)
    }

    // ----- canned responses ---------------------------------------------------------------

    /// Reply to `SELECT 1`: one int4 column `c` holding `1`, then ReadyForQuery(`status`).
    fn simple_query_response_with(status: u8) -> Vec<u8> {
        [
            pg_msg(b'1', &[]),
            pg_msg(b'2', &[]),
            row_description(b"c", 23, 4),
            data_row(&[Some("1")]),
            pg_msg(b'C', b"SELECT 1\x00"),
            ready_for_query_with(status),
        ]
        .concat()
    }

    /// Reply to `SELECT 1` ending in the idle state.
    fn simple_query_response() -> Vec<u8> {
        simple_query_response_with(b'I')
    }

    /// Reply to an UPDATE that touched three rows.
    fn execute_response() -> Vec<u8> {
        [
            pg_msg(b'1', &[]),
            pg_msg(b'2', &[]),
            pg_msg(b'n', &[]),
            pg_msg(b'C', b"UPDATE 3\x00"),
            ready_for_query(),
        ]
        .concat()
    }

    /// Reply to `SELECT $1::int` with `$1 = 42`: one int4 column `p` holding `42`.
    fn query_params_response() -> Vec<u8> {
        [
            pg_msg(b'1', &[]),
            parameter_description(&[23]),
            pg_msg(b'2', &[]),
            row_description(b"p", 23, 4),
            data_row(&[Some("42")]),
            pg_msg(b'C', b"SELECT 1\x00"),
            ready_for_query(),
        ]
        .concat()
    }

    /// Reply to a parameterized UPDATE that touched one row.
    fn execute_params_response() -> Vec<u8> {
        [
            pg_msg(b'1', &[]),
            parameter_description(&[23]),
            pg_msg(b'2', &[]),
            pg_msg(b'n', &[]),
            pg_msg(b'C', b"UPDATE 1\x00"),
            ready_for_query(),
        ]
        .concat()
    }

    /// Reply to `SELECT $1::text` with a NULL parameter: one text column `n` holding NULL.
    fn query_params_null_response() -> Vec<u8> {
        [
            pg_msg(b'1', &[]),
            parameter_description(&[25]),
            pg_msg(b'2', &[]),
            row_description(b"n", 25, -1),
            data_row(&[None]),
            pg_msg(b'C', b"SELECT 1\x00"),
            ready_for_query(),
        ]
        .concat()
    }

    /// Reply to a statement with a syntax error (SQLSTATE 42601).
    fn error_response() -> Vec<u8> {
        [
            pg_msg(b'1', &[]),
            pg_msg(b'2', &[]),
            error_msg("42601", "syntax error"),
            ready_for_query(),
        ]
        .concat()
    }

    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    // ----- fake servers -------------------------------------------------------------------

    /// Binds an ephemeral loopback port and runs `session` on the first connection
    /// accepted there, in a background thread.
    ///
    /// The listener is bound before this returns, so the client may connect right away;
    /// the kernel queues the connection until the thread calls `accept`.
    fn spawn_server<F>(session: F) -> u16
    where
        F: FnOnce(TcpStream) + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                session(stream);
            }
        });
        port
    }

    /// Server half of a password handshake.
    ///
    /// Reads the startup message and sends `challenge`. Then reads the client's password
    /// message and sends the post-authentication messages.
    fn password_handshake(s: &mut TcpStream, challenge: &[u8]) {
        let mut buf = [0u8; 4096];
        let _ = s.read(&mut buf);
        let _ = s.write_all(challenge);
        let _ = s.read(&mut buf);
        let _ = s.write_all(&post_auth_ok());
    }

    /// Password handshake that requests a cleartext password (auth code 3).
    fn cleartext_handshake(s: &mut TcpStream) {
        password_handshake(s, &pg_auth(3, &[]));
    }

    /// Answers every chunk read from the client with `reply()` until EOF or a read error.
    fn reply_until_closed(s: &mut TcpStream, reply: fn() -> Vec<u8>) {
        let mut buf = [0u8; 4096];
        while let Ok(n) = s.read(&mut buf) {
            if n == 0 {
                break;
            }
            let _ = s.write_all(&reply());
        }
    }

    /// Cleartext-auth server that answers every request with `reply()`.
    fn spawn_cleartext_server_replying(reply: fn() -> Vec<u8>) -> u16 {
        spawn_server(move |mut s| {
            cleartext_handshake(&mut s);
            reply_until_closed(&mut s, reply);
        })
    }

    fn spawn_cleartext_server() -> u16 {
        spawn_cleartext_server_replying(simple_query_response)
    }

    fn spawn_md5_server() -> u16 {
        spawn_server(|mut s| {
            password_handshake(&mut s, &pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
            reply_until_closed(&mut s, simple_query_response);
        })
    }

    fn spawn_scram_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
            // Echo the client nonce back with a server suffix, as a real server would.
            let n = s.read(&mut buf).unwrap_or(0);
            let text = String::from_utf8_lossy(&buf[..n]);
            let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
            let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
            let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
            let _ = s.read(&mut buf);
            let final_msgs = [pg_auth(12, b"v=dummyproof"), post_auth_ok()].concat();
            let _ = s.write_all(&final_msgs);
            reply_until_closed(&mut s, simple_query_response);
        })
    }

    /// Accepts the connection and closes it immediately.
    fn spawn_eof_server() -> u16 {
        spawn_server(drop)
    }

    /// Rejects the startup message with SQLSTATE 28P01.
    fn spawn_auth_error_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            let _ = s.write_all(&error_msg("28P01", "password authentication failed"));
        })
    }

    fn spawn_query_error_server() -> u16 {
        spawn_cleartext_server_replying(error_response)
    }

    fn spawn_execute_server() -> u16 {
        spawn_cleartext_server_replying(execute_response)
    }

    fn spawn_query_params_server() -> u16 {
        spawn_cleartext_server_replying(query_params_response)
    }

    fn spawn_params_server() -> u16 {
        spawn_query_params_server()
    }

    fn spawn_execute_params_server() -> u16 {
        spawn_cleartext_server_replying(execute_params_response)
    }

    fn spawn_query_params_null_server() -> u16 {
        spawn_cleartext_server_replying(query_params_null_response)
    }

    fn spawn_transaction_status_server() -> u16 {
        spawn_cleartext_server_replying(|| simple_query_response_with(b'T'))
    }

    fn spawn_error_status_server() -> u16 {
        spawn_cleartext_server_replying(|| simple_query_response_with(b'E'))
    }

    /// Answers exactly one request, then closes the connection.
    fn spawn_slow_partial_server() -> u16 {
        spawn_server(|mut s| {
            cleartext_handshake(&mut s);
            let mut buf = [0u8; 4096];
            if matches!(s.read(&mut buf), Ok(n) if n > 0) {
                let _ = s.write_all(&simple_query_response());
            }
        })
    }

    /// Reads one request, then drops the connection without replying.
    fn spawn_rst_on_query_server() -> u16 {
        spawn_server(|mut s| {
            cleartext_handshake(&mut s);
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
        })
    }

    // ----- tests --------------------------------------------------------------------------

    #[test]
    fn connect_cleartext_auth_success() {
        let port = spawn_cleartext_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_md5_auth_success() {
        let port = spawn_md5_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_scram_auth_success() {
        let port = spawn_scram_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_connection_refused() {
        let cfg = mock_config(1);
        let result = Connect::new(cfg);
        assert!(result.is_err());
        match result.unwrap_err() {
            PgsqlError::Connection(_) => {}
            other => panic!("expected Connection error, got {other:?}"),
        }
    }

    #[test]
    fn connect_server_closes_immediately() {
        let port = spawn_eof_server();
        let result = Connect::new(mock_config(port));
        assert!(result.is_err());
    }

    #[test]
    fn connect_auth_error_from_server() {
        let port = spawn_auth_error_server();
        let result = Connect::new(mock_config(port));
        assert!(result.is_err());
    }

    #[test]
    fn connect_query_success() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.rows.len(), 1);
        assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
    }

    #[test]
    fn connect_execute_success() {
        let port = spawn_execute_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.execute("UPDATE t SET x=1");
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.affect_count, 3);
        assert_eq!(msg.tag, "UPDATE 3");
    }

    #[test]
    fn connect_query_params_success() {
        let port = spawn_query_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query_params("SELECT $1::int", &[Some("42")]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.rows.len(), 1);
        assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
    }

    #[test]
    fn connect_execute_params_success() {
        let port = spawn_execute_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.execute_params("UPDATE t SET x=$1", &[Some("42")]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.affect_count, 1);
        assert_eq!(msg.tag, "UPDATE 1");
    }

    #[test]
    fn connect_query_str_success() {
        let port = spawn_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query_str("SELECT $1::int", &["42"]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.rows.len(), 1);
    }

    #[test]
    fn connect_execute_str_success() {
        let port = spawn_execute_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.execute_str("UPDATE t SET x=$1", &["1"]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.affect_count, 1);
    }

    #[test]
    fn connect_query_params_with_null() {
        let port = spawn_query_params_null_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query_params("SELECT $1::text", &[None]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.rows.len(), 1);
        assert!(msg.rows[0]["n"].is_null());
    }

    #[test]
    fn connect_query_params_empty_string_vs_null() {
        let port = spawn_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();

        let r1 = conn.query_params("SELECT $1::text", &[Some("")]);
        assert!(r1.is_ok());

        let r2 = conn.query_params("SELECT $1::text", &[None]);
        assert!(r2.is_ok());
    }

    #[test]
    fn connect_query_returns_error() {
        let port = spawn_query_error_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("BAD SQL");
        assert!(result.is_err());
    }

    #[test]
    fn connect_is_valid_true() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.is_valid());
    }

    #[test]
    fn connect_is_valid_false_after_close() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn._close();
        assert!(!conn.is_valid());
    }

    #[test]
    fn connect_close_does_not_panic() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn._close();
        conn._close();
    }

    #[test]
    fn connect_drop_does_not_panic() {
        let port = spawn_cleartext_server();
        let conn = Connect::new(mock_config(port)).unwrap();
        drop(conn);
    }

    #[test]
    fn connect_query_ready_for_query_transaction_status() {
        let port = spawn_transaction_status_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }

    #[test]
    fn connect_query_ready_for_query_error_status() {
        let port = spawn_error_status_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }

    #[test]
    fn connect_query_server_closes_after_partial() {
        let port = spawn_slow_partial_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let r1 = conn.query("SELECT 1");
        assert!(r1.is_ok());
        let r2 = conn.query("SELECT 1");
        assert!(r2.is_err());
    }

    #[test]
    fn connect_query_server_rst_returns_io_or_connection_error() {
        let port = spawn_rst_on_query_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
    }

    #[test]
    fn connect_read_would_block_max_retries() {
        let port = spawn_server(|mut s| {
            cleartext_handshake(&mut s);
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            // Keep the socket open but never answer the query.
            thread::sleep(Duration::from_secs(5));
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn.stream
            .set_read_timeout(Some(Duration::from_millis(1)))
            .ok();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(
            err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
            "expected timeout error, got: {err_str}"
        );
    }

    #[test]
    fn connect_read_exceeds_max_message_size() {
        let port = spawn_server(|mut s| {
            cleartext_handshake(&mut s);
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            // Larger than the test build's 128-byte cap.
            let big = vec![b'X'; 256];
            let _ = s.write_all(&big);
            thread::sleep(Duration::from_secs(2));
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(
            err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
            "expected max message size error, got: {err_str}"
        );
    }

    #[test]
    fn connect_read_deadline_timeout() {
        let port = spawn_server(|mut s| {
            cleartext_handshake(&mut s);
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            // Trickle bytes so each read succeeds but no response ever completes.
            for _ in 0..200 {
                let _ = s.write_all(b"X");
                thread::sleep(Duration::from_millis(5));
            }
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
    }

    #[test]
    fn connect_read_partial_auth_frame() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            // Split the auth request right after its header.
            let auth = pg_auth(3, &[]);
            let _ = s.write_all(&auth[..5]);
            thread::sleep(Duration::from_millis(50));
            let _ = s.write_all(&auth[5..]);
            let _ = s.read(&mut buf);
            let _ = s.write_all(&post_auth_ok());
            reply_until_closed(&mut s, simple_query_response);
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }
}
1190}