1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::time::{Duration, Instant};
7
/// A single synchronous connection to a PostgreSQL server over a plain `TcpStream`.
///
/// Created (and authenticated) by `Connect::new`; closed explicitly with
/// `Connect::_close` or implicitly on drop.
#[derive(Debug)]
pub struct Connect {
    /// Underlying socket. Crate-visible so crate code can adjust socket
    /// options directly (the tests, for example, shorten its read timeout).
    pub(crate) stream: TcpStream,
    /// Message encoder/decoder, built from the `Config` passed to `Connect::new`.
    packet: Packet,
    /// `AuthStatus::None` until the handshake succeeds, then
    /// `AuthStatus::AuthenticationOk`; `read` uses it to decide where a reply ends.
    auth_status: AuthStatus,
    /// When the connection last completed a request (or was `touch`ed);
    /// drives the idle check in `is_valid`.
    last_used: Instant,
    /// When the connection was created; reported by `age`.
    created_at: Instant,
}
18
19impl Connect {
20 pub fn is_valid(&mut self) -> bool {
22 if self.stream.peer_addr().is_err() {
23 return false;
24 }
25 #[cfg(not(test))]
26 const IDLE_THRESHOLD: Duration = Duration::from_secs(5);
27 #[cfg(test)]
28 const IDLE_THRESHOLD: Duration = Duration::from_millis(0);
29 if self.last_used.elapsed() > IDLE_THRESHOLD {
30 return self.query("SELECT 1").is_ok();
31 }
32 true
33 }
34
35 pub fn peer_valid(&self) -> bool {
37 self.stream.peer_addr().is_ok()
38 }
39
40 pub fn touch(&mut self) {
42 self.last_used = Instant::now();
43 }
44
45 pub fn idle_elapsed(&self) -> Duration {
47 self.last_used.elapsed()
48 }
49
50 pub fn age(&self) -> Duration {
52 self.created_at.elapsed()
53 }
54
55 pub fn _close(&mut self) {
56 let _ = (&self.stream).write_all(&Packet::pack_terminate());
57 let _ = self.stream.shutdown(std::net::Shutdown::Both);
58 }
59
60 fn set_keepalive(stream: &TcpStream) -> Result<(), PgsqlError> {
62 use std::os::unix::io::{AsRawFd, FromRawFd};
63 let fd = stream.as_raw_fd();
64 let socket = unsafe { socket2::Socket::from_raw_fd(fd) };
65 let keepalive = socket2::TcpKeepalive::new()
66 .with_time(Duration::from_secs(60))
67 .with_interval(Duration::from_secs(15))
68 .with_retries(3);
69 let result = socket.set_tcp_keepalive(&keepalive);
70 std::mem::forget(socket);
72 result.map_err(|e| PgsqlError::Connection(format!("设置 TCP Keepalive 失败: {}", e)))
73 }
74
75 pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
76 let stream =
77 TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
78 stream
80 .set_nodelay(true)
81 .map_err(|e| PgsqlError::Connection(format!("设置 TCP_NODELAY 失败: {}", e)))?;
82 Self::set_keepalive(&stream)?;
84 stream
85 .set_read_timeout(Some(Duration::from_secs(30)))
86 .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
87 stream
88 .set_write_timeout(Some(Duration::from_secs(30)))
89 .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
90 let _ = stream.peer_addr();
91
92 let mut connect = Self {
93 stream,
94 packet: Packet::new(config),
95 auth_status: AuthStatus::None,
96 last_used: Instant::now(),
97 created_at: Instant::now(),
98 };
99
100 connect.authenticate()?;
101
102 Ok(connect)
103 }
104
105 fn authenticate(&mut self) -> Result<(), PgsqlError> {
106 (&self.stream)
107 .write_all(&self.packet.pack_first())
108 .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
109
110 let data = self.read()?;
111 self.packet.unpack(data, 0)?;
112
113 if !self.packet.md5_salt.is_empty() {
114 self.md5_auth()?;
115 } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
116 self.cleartext_auth()?;
117 } else {
118 self.scram_auth()?;
119 }
120
121 self.auth_status = AuthStatus::AuthenticationOk;
122 Ok(())
123 }
124
125 fn md5_auth(&mut self) -> Result<(), PgsqlError> {
126 (&self.stream)
127 .write_all(&self.packet.pack_md5_password())
128 .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
129
130 let data = self.read()?;
131 self.packet.unpack(data, 0)?;
132 Ok(())
133 }
134
135 fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
136 (&self.stream)
137 .write_all(&self.packet.pack_cleartext_password())
138 .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
139
140 let data = self.read()?;
141 self.packet.unpack(data, 0)?;
142 Ok(())
143 }
144
145 fn scram_auth(&mut self) -> Result<(), PgsqlError> {
146 (&self.stream)
147 .write_all(&self.packet.pack_auth())
148 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
149
150 let data = self.read()?;
151 self.packet.unpack(data, 0)?;
152
153 (&self.stream)
154 .write_all(&self.packet.pack_auth_verify())
155 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
156
157 let data = self.read()?;
158 self.packet.unpack(data, 0)?;
159 Ok(())
160 }
161
162 fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
163 let mut msg = Vec::new();
164 let mut buf = [0u8; 4096];
165 let mut retry_count = 0;
166
167 #[cfg(not(test))]
168 const MAX_RETRIES: u32 = 100;
169 #[cfg(test)]
170 const MAX_RETRIES: u32 = 3;
171
172 #[cfg(not(test))]
173 const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
174 #[cfg(test)]
175 const MAX_MESSAGE_SIZE: usize = 128;
176
177 #[cfg(not(test))]
178 let deadline = std::time::Instant::now() + Duration::from_secs(300);
179 #[cfg(test)]
180 let deadline = std::time::Instant::now() + Duration::from_millis(200);
181
182 loop {
183 if std::time::Instant::now() >= deadline {
184 return Err(PgsqlError::Timeout("读取总超时".into()));
185 }
186
187 match (&self.stream).read(&mut buf) {
188 Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
189 Ok(n) => {
190 if msg.len() + n > MAX_MESSAGE_SIZE {
191 return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
192 }
193 msg.extend_from_slice(&buf[..n]);
194 retry_count = 0;
195 }
196 Err(ref e)
197 if e.kind() == std::io::ErrorKind::WouldBlock
198 || e.kind() == std::io::ErrorKind::TimedOut =>
199 {
200 retry_count += 1;
201 if retry_count > MAX_RETRIES {
202 return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
203 }
204 std::thread::sleep(Duration::from_millis(10));
205 continue;
206 }
207 Err(e) => return Err(PgsqlError::Io(e)),
208 };
209
210 if let AuthStatus::AuthenticationOk = self.auth_status {
211 if msg.ends_with(&[90, 0, 0, 0, 5, 73])
212 || msg.ends_with(&[90, 0, 0, 0, 5, 84])
213 || msg.ends_with(&[90, 0, 0, 0, 5, 69])
214 {
215 break;
216 }
217 } else if msg.len() >= 5 {
218 let len_bytes = &msg[1..=4];
219 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
220 if msg.len() > len as usize {
221 break;
222 }
223 }
224 }
225 }
226
227 Ok(msg)
228 }
229
230 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
231 (&self.stream)
232 .write_all(&self.packet.pack_query(sql))
233 .map_err(PgsqlError::Io)?;
234 let data = self.read()?;
235 self.last_used = Instant::now();
236 self.packet.unpack(data, 0)
237 }
238
239 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
240 (&self.stream)
241 .write_all(&self.packet.pack_execute(sql))
242 .map_err(PgsqlError::Io)?;
243 let data = self.read()?;
244 self.last_used = Instant::now();
245 self.packet.unpack(data, 0)
246 }
247
248 pub fn query_params(
250 &mut self,
251 sql: &str,
252 params: &[Option<&str>],
253 ) -> Result<SuccessMessage, PgsqlError> {
254 (&self.stream)
255 .write_all(&self.packet.pack_query_params(sql, params))
256 .map_err(PgsqlError::Io)?;
257
258 let data = self.read()?;
259 self.last_used = Instant::now();
260 self.packet.unpack(data, 0)
261 }
262
263 pub fn execute_params(
265 &mut self,
266 sql: &str,
267 params: &[Option<&str>],
268 ) -> Result<SuccessMessage, PgsqlError> {
269 (&self.stream)
270 .write_all(&self.packet.pack_execute_params(sql, params))
271 .map_err(PgsqlError::Io)?;
272 let data = self.read()?;
273 self.last_used = Instant::now();
274 self.packet.unpack(data, 0)
275 }
276
277 pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
279 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
280 self.query_params(sql, &opts)
281 }
282
283 pub fn execute_str(
285 &mut self,
286 sql: &str,
287 params: &[&str],
288 ) -> Result<SuccessMessage, PgsqlError> {
289 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
290 self.execute_params(sql, &opts)
291 }
292}
293
294impl Drop for Connect {
295 fn drop(&mut self) {
296 let _ = (&self.stream).write_all(&Packet::pack_terminate());
297 let _ = self.stream.shutdown(std::net::Shutdown::Both);
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{TcpListener, TcpStream};
    use std::thread;

    // ---- Backend message builders ------------------------------------------
    // Every frame is a 1-byte tag followed by a big-endian u32 length that
    // counts itself (but not the tag) plus the payload.

    /// Frames `payload` as a backend message with the given `tag`.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend_from_slice(&(payload.len() as u32 + 4).to_be_bytes());
        m.extend_from_slice(payload);
        m
    }

    /// Authentication message (`R`) with request code `auth_type` and trailing data.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = auth_type.to_be_bytes().to_vec();
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// AuthenticationOk (request code 0).
    fn auth_ok() -> Vec<u8> {
        pg_auth(0, &[])
    }

    /// ParameterStatus (`S`) reporting `server_version = 15.0`.
    fn param_status() -> Vec<u8> {
        pg_msg(b'S', b"server_version\x0015.0\x00")
    }

    /// BackendKeyData (`K`) with process id 1 and secret key 2.
    fn backend_key() -> Vec<u8> {
        let mut p = 1u32.to_be_bytes().to_vec();
        p.extend_from_slice(&2u32.to_be_bytes());
        pg_msg(b'K', &p)
    }

    /// ReadyForQuery (`Z`) carrying the transaction status `status`.
    fn ready_for_query_with(status: u8) -> Vec<u8> {
        pg_msg(b'Z', &[status])
    }

    /// ReadyForQuery in the idle (`I`) state.
    fn ready_for_query() -> Vec<u8> {
        ready_for_query_with(b'I')
    }

    /// Everything the mock server sends after accepting a password.
    fn post_auth_ok() -> Vec<u8> {
        [auth_ok(), param_status(), backend_key(), ready_for_query()].concat()
    }

    /// ParseComplete (`1`).
    fn parse_complete() -> Vec<u8> {
        pg_msg(b'1', &[])
    }

    /// BindComplete (`2`).
    fn bind_complete() -> Vec<u8> {
        pg_msg(b'2', &[])
    }

    /// NoData (`n`).
    fn no_data() -> Vec<u8> {
        pg_msg(b'n', &[])
    }

    /// CommandComplete (`C`) with the given command tag.
    fn command_complete(tag: &str) -> Vec<u8> {
        let mut p = tag.as_bytes().to_vec();
        p.push(0);
        pg_msg(b'C', &p)
    }

    /// ParameterDescription (`t`) listing the given parameter type OIDs.
    fn param_description(oids: &[u32]) -> Vec<u8> {
        let mut pd = (oids.len() as u16).to_be_bytes().to_vec();
        for oid in oids {
            pd.extend_from_slice(&oid.to_be_bytes());
        }
        pg_msg(b't', &pd)
    }

    /// RowDescription (`T`) for one text-format column `name` of type
    /// `type_oid` (23 = int4, 25 = text) with type length `type_len`.
    fn row_description(name: &str, type_oid: u32, type_len: i16) -> Vec<u8> {
        let mut rd = 1u16.to_be_bytes().to_vec(); // field count
        rd.extend_from_slice(name.as_bytes());
        rd.push(0);
        rd.extend_from_slice(&0u32.to_be_bytes()); // table OID
        rd.extend_from_slice(&1u16.to_be_bytes()); // column attribute number
        rd.extend_from_slice(&type_oid.to_be_bytes());
        rd.extend_from_slice(&type_len.to_be_bytes());
        rd.extend_from_slice(&(-1i32).to_be_bytes()); // type modifier
        rd.extend_from_slice(&0u16.to_be_bytes()); // format code: text
        pg_msg(b'T', &rd)
    }

    /// DataRow (`D`) with a single column; `None` encodes SQL NULL (length -1).
    fn data_row(value: Option<&str>) -> Vec<u8> {
        let mut dr = 1u16.to_be_bytes().to_vec();
        match value {
            Some(v) => {
                dr.extend_from_slice(&(v.len() as i32).to_be_bytes());
                dr.extend_from_slice(v.as_bytes());
            }
            None => dr.extend_from_slice(&(-1i32).to_be_bytes()),
        }
        pg_msg(b'D', &dr)
    }

    /// ErrorResponse payload with SQLSTATE `code` and message text.
    fn error_fields(code: &str, message: &str) -> Vec<u8> {
        let mut p = vec![b'C'];
        p.extend_from_slice(code.as_bytes());
        p.push(0);
        p.push(b'M');
        p.extend_from_slice(message.as_bytes());
        p.push(0);
        p.push(0);
        p
    }

    // ---- Canned replies ----------------------------------------------------

    /// Reply to `SELECT 1`: one int4 column `c` holding 1, ending in `status`.
    fn simple_query_response_with_status(status: u8) -> Vec<u8> {
        [
            parse_complete(),
            bind_complete(),
            row_description("c", 23, 4),
            data_row(Some("1")),
            command_complete("SELECT 1"),
            ready_for_query_with(status),
        ]
        .concat()
    }

    fn simple_query_response() -> Vec<u8> {
        simple_query_response_with_status(b'I')
    }

    fn execute_response() -> Vec<u8> {
        [
            parse_complete(),
            bind_complete(),
            no_data(),
            command_complete("UPDATE 3"),
            ready_for_query(),
        ]
        .concat()
    }

    fn query_params_response() -> Vec<u8> {
        [
            parse_complete(),
            param_description(&[23]),
            bind_complete(),
            row_description("p", 23, 4),
            data_row(Some("42")),
            command_complete("SELECT 1"),
            ready_for_query(),
        ]
        .concat()
    }

    fn execute_params_response() -> Vec<u8> {
        [
            parse_complete(),
            param_description(&[23]),
            bind_complete(),
            no_data(),
            command_complete("UPDATE 1"),
            ready_for_query(),
        ]
        .concat()
    }

    fn query_params_null_response() -> Vec<u8> {
        [
            parse_complete(),
            param_description(&[25]),
            bind_complete(),
            row_description("n", 25, -1),
            data_row(None),
            command_complete("SELECT 1"),
            ready_for_query(),
        ]
        .concat()
    }

    fn error_response() -> Vec<u8> {
        [
            parse_complete(),
            bind_complete(),
            pg_msg(b'E', &error_fields("42601", "syntax error")),
            ready_for_query(),
        ]
        .concat()
    }

    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    // ---- Mock servers ------------------------------------------------------

    /// Binds an ephemeral port, serves one connection with `handler` on a
    /// background thread, and returns the port. The listener is already
    /// listening when this returns, so the client may connect immediately;
    /// the connection waits in the backlog until it is accepted.
    fn spawn_server<F>(handler: F) -> u16
    where
        F: FnOnce(TcpStream) + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            handler(stream);
        });
        port
    }

    /// Server side of the cleartext-password handshake: read the startup
    /// message, request a password, read it, then report success.
    fn cleartext_handshake(s: &mut TcpStream, buf: &mut [u8]) {
        let _ = s.read(buf);
        let _ = s.write_all(&pg_auth(3, &[]));
        let _ = s.read(buf);
        let _ = s.write_all(&post_auth_ok());
    }

    /// Answers every incoming read with `response` until the client disconnects.
    fn respond_forever(s: &mut TcpStream, buf: &mut [u8], response: &[u8]) {
        loop {
            match s.read(buf) {
                Ok(0) | Err(_) => break,
                Ok(_) => {
                    let _ = s.write_all(response);
                }
            }
        }
    }

    /// Cleartext handshake, then answers every request with `response`.
    fn spawn_cleartext_responder(response: Vec<u8>) -> u16 {
        spawn_server(move |mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            respond_forever(&mut s, &mut buf, &response);
        })
    }

    fn spawn_cleartext_server() -> u16 {
        spawn_cleartext_responder(simple_query_response())
    }

    fn spawn_md5_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
            let _ = s.read(&mut buf);
            let _ = s.write_all(&post_auth_ok());
            respond_forever(&mut s, &mut buf, &simple_query_response());
        })
    }

    fn spawn_scram_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
            // Echo the client nonce back with a server suffix, as SCRAM expects.
            let n = s.read(&mut buf).unwrap();
            let text = String::from_utf8_lossy(&buf[..n]);
            let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
            let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
            let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
            let _ = s.read(&mut buf);
            let resp = [
                pg_auth(12, b"v=dummyproof"),
                auth_ok(),
                param_status(),
                backend_key(),
                ready_for_query(),
            ]
            .concat();
            let _ = s.write_all(&resp);
            respond_forever(&mut s, &mut buf, &simple_query_response());
        })
    }

    /// Accepts and immediately closes the connection.
    fn spawn_eof_server() -> u16 {
        spawn_server(drop)
    }

    /// Rejects the startup message with a fatal authentication error.
    fn spawn_auth_error_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            let fields = error_fields("28P01", "password authentication failed");
            let _ = s.write_all(&pg_msg(b'E', &fields));
        })
    }

    fn spawn_query_error_server() -> u16 {
        spawn_cleartext_responder(error_response())
    }

    fn spawn_execute_server() -> u16 {
        spawn_cleartext_responder(execute_response())
    }

    fn spawn_query_params_server() -> u16 {
        spawn_cleartext_responder(query_params_response())
    }

    fn spawn_execute_params_server() -> u16 {
        spawn_cleartext_responder(execute_params_response())
    }

    fn spawn_query_params_null_server() -> u16 {
        spawn_cleartext_responder(query_params_null_response())
    }

    /// Replies end in ReadyForQuery with status `T` (in transaction).
    fn spawn_transaction_status_server() -> u16 {
        spawn_cleartext_responder(simple_query_response_with_status(b'T'))
    }

    /// Replies end in ReadyForQuery with status `E` (failed transaction).
    fn spawn_error_status_server() -> u16 {
        spawn_cleartext_responder(simple_query_response_with_status(b'E'))
    }

    /// Answers exactly one query, then closes the connection.
    fn spawn_slow_partial_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            if let Ok(n) = s.read(&mut buf) {
                if n > 0 {
                    let _ = s.write_all(&simple_query_response());
                }
            }
        })
    }

    /// Waits for the first query, then closes without answering it.
    fn spawn_rst_on_query_server() -> u16 {
        spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            let _ = s.read(&mut buf);
        })
    }

    // ---- Tests -------------------------------------------------------------

    #[test]
    fn connect_cleartext_auth_success() {
        let port = spawn_cleartext_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_md5_auth_success() {
        let port = spawn_md5_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_scram_auth_success() {
        let port = spawn_scram_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_connection_refused() {
        let cfg = mock_config(1);
        let result = Connect::new(cfg);
        assert!(result.is_err());
        match result.unwrap_err() {
            PgsqlError::Connection(_) => {}
            other => panic!("expected Connection error, got {other:?}"),
        }
    }

    #[test]
    fn connect_server_closes_immediately() {
        let port = spawn_eof_server();
        let result = Connect::new(mock_config(port));
        assert!(result.is_err());
    }

    #[test]
    fn connect_auth_error_from_server() {
        let port = spawn_auth_error_server();
        let result = Connect::new(mock_config(port));
        assert!(result.is_err());
    }

    #[test]
    fn connect_query_success() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.rows.len(), 1);
        assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
    }

    #[test]
    fn connect_execute_success() {
        let port = spawn_execute_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.execute("UPDATE t SET x=1");
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.affect_count, 3);
        assert_eq!(msg.tag, "UPDATE 3");
    }

    #[test]
    fn connect_query_params_success() {
        let port = spawn_query_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query_params("SELECT $1::int", &[Some("42")]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.rows.len(), 1);
        assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
    }

    #[test]
    fn connect_execute_params_success() {
        let port = spawn_execute_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.execute_params("UPDATE t SET x=$1", &[Some("42")]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.affect_count, 1);
        assert_eq!(msg.tag, "UPDATE 1");
    }

    #[test]
    fn connect_query_str_success() {
        let port = spawn_query_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query_str("SELECT $1::int", &["42"]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.rows.len(), 1);
    }

    #[test]
    fn connect_execute_str_success() {
        let port = spawn_execute_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.execute_str("UPDATE t SET x=$1", &["1"]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.affect_count, 1);
    }

    #[test]
    fn connect_query_params_with_null() {
        let port = spawn_query_params_null_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query_params("SELECT $1::text", &[None]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.rows.len(), 1);
        assert_eq!(msg.rows[0]["n"], "");
    }

    #[test]
    fn connect_query_params_empty_string_vs_null() {
        // The mock replies identically to both requests, so this only checks
        // that an empty string and a NULL parameter both encode and round-trip.
        let port = spawn_query_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();

        let r1 = conn.query_params("SELECT $1::text", &[Some("")]);
        assert!(r1.is_ok());

        let r2 = conn.query_params("SELECT $1::text", &[None]);
        assert!(r2.is_ok());
    }

    #[test]
    fn connect_query_returns_error() {
        let port = spawn_query_error_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("BAD SQL");
        assert!(result.is_err());
    }

    #[test]
    fn connect_is_valid_true() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.is_valid());
    }

    #[test]
    fn connect_is_valid_false_after_close() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn._close();
        assert!(!conn.is_valid());
    }

    #[test]
    fn connect_close_does_not_panic() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn._close();
        conn._close();
    }

    #[test]
    fn connect_drop_does_not_panic() {
        let port = spawn_cleartext_server();
        let conn = Connect::new(mock_config(port)).unwrap();
        drop(conn);
    }

    #[test]
    fn connect_query_ready_for_query_transaction_status() {
        let port = spawn_transaction_status_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }

    #[test]
    fn connect_query_ready_for_query_error_status() {
        let port = spawn_error_status_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }

    #[test]
    fn connect_query_server_closes_after_partial() {
        let port = spawn_slow_partial_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let r1 = conn.query("SELECT 1");
        assert!(r1.is_ok());
        let r2 = conn.query("SELECT 1");
        assert!(r2.is_err());
    }

    #[test]
    fn connect_query_server_rst_returns_io_or_connection_error() {
        let port = spawn_rst_on_query_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
    }

    #[test]
    fn connect_read_would_block_max_retries() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            let _ = s.read(&mut buf);
            // Keep the socket open but silent so the client's reads time out.
            thread::sleep(Duration::from_secs(5));
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn.stream
            .set_read_timeout(Some(Duration::from_millis(1)))
            .ok();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(
            err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
            "expected timeout error, got: {err_str}"
        );
    }

    #[test]
    fn connect_read_exceeds_max_message_size() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            let _ = s.read(&mut buf);
            let big = vec![b'X'; 256];
            let _ = s.write_all(&big);
            thread::sleep(Duration::from_secs(2));
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(
            err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
            "expected max message size error, got: {err_str}"
        );
    }

    #[test]
    fn connect_read_deadline_timeout() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            cleartext_handshake(&mut s, &mut buf);
            let _ = s.read(&mut buf);
            // Trickle bytes that never form a complete reply.
            for _ in 0..200 {
                let _ = s.write_all(b"X");
                thread::sleep(Duration::from_millis(5));
            }
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
    }

    #[test]
    fn connect_read_partial_auth_frame() {
        let port = spawn_server(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            // Split the password request right after its length field.
            let auth = pg_auth(3, &[]);
            let _ = s.write_all(&auth[..5]);
            thread::sleep(Duration::from_millis(50));
            let _ = s.write_all(&auth[5..]);
            let _ = s.read(&mut buf);
            let _ = s.write_all(&post_auth_ok());
            respond_forever(&mut s, &mut buf, &simple_query_response());
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }
}