1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::time::{Duration, Instant};
7
/// A single synchronous connection to a PostgreSQL server over TCP.
#[derive(Debug)]
pub struct Connect {
    /// The underlying socket. Crate-visible; the tests in this file use it to
    /// change the read timeout.
    pub(crate) stream: TcpStream,
    /// Builds outgoing protocol messages (`pack_*`) and decodes replies (`unpack`).
    packet: Packet,
    /// Whether the startup/authentication handshake has completed; `read`
    /// uses it to decide when a server response is complete.
    auth_status: AuthStatus,
    /// When the connection was last used; drives the idle check in `is_valid`.
    last_used: Instant,
}
16
17impl Connect {
18 pub fn is_valid(&mut self) -> bool {
20 if self.stream.peer_addr().is_err() {
21 return false;
22 }
23 #[cfg(not(test))]
24 const IDLE_THRESHOLD: Duration = Duration::from_secs(30);
25 #[cfg(test)]
26 const IDLE_THRESHOLD: Duration = Duration::from_millis(0);
27 if self.last_used.elapsed() > IDLE_THRESHOLD {
28 return self.query("SELECT 1").is_ok();
29 }
30 true
31 }
32
33 pub fn peer_valid(&self) -> bool {
35 self.stream.peer_addr().is_ok()
36 }
37
38 pub fn touch(&mut self) {
40 self.last_used = Instant::now();
41 }
42
43 pub fn idle_elapsed(&self) -> Duration {
45 self.last_used.elapsed()
46 }
47
48 pub fn _close(&mut self) {
49 let _ = (&self.stream).write_all(&Packet::pack_terminate());
50 let _ = self.stream.shutdown(std::net::Shutdown::Both);
51 }
52
53 fn set_keepalive(stream: &TcpStream) -> Result<(), PgsqlError> {
55 use std::os::unix::io::{AsRawFd, FromRawFd};
56 let fd = stream.as_raw_fd();
57 let socket = unsafe { socket2::Socket::from_raw_fd(fd) };
58 let keepalive = socket2::TcpKeepalive::new()
59 .with_time(Duration::from_secs(60))
60 .with_interval(Duration::from_secs(15))
61 .with_retries(3);
62 let result = socket.set_tcp_keepalive(&keepalive);
63 std::mem::forget(socket);
65 result.map_err(|e| PgsqlError::Connection(format!("设置 TCP Keepalive 失败: {}", e)))
66 }
67
68 pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
69 let stream =
70 TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
71 stream
73 .set_nodelay(true)
74 .map_err(|e| PgsqlError::Connection(format!("设置 TCP_NODELAY 失败: {}", e)))?;
75 Self::set_keepalive(&stream)?;
77 stream
78 .set_read_timeout(Some(Duration::from_secs(30)))
79 .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
80 stream
81 .set_write_timeout(Some(Duration::from_secs(30)))
82 .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
83 let _ = stream.peer_addr();
84
85 let mut connect = Self {
86 stream,
87 packet: Packet::new(config),
88 auth_status: AuthStatus::None,
89 last_used: Instant::now(),
90 };
91
92 connect.authenticate()?;
93
94 Ok(connect)
95 }
96
97 fn authenticate(&mut self) -> Result<(), PgsqlError> {
98 (&self.stream)
99 .write_all(&self.packet.pack_first())
100 .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
101
102 let data = self.read()?;
103 self.packet.unpack(data, 0)?;
104
105 if !self.packet.md5_salt.is_empty() {
106 self.md5_auth()?;
107 } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
108 self.cleartext_auth()?;
109 } else {
110 self.scram_auth()?;
111 }
112
113 self.auth_status = AuthStatus::AuthenticationOk;
114 Ok(())
115 }
116
117 fn md5_auth(&mut self) -> Result<(), PgsqlError> {
118 (&self.stream)
119 .write_all(&self.packet.pack_md5_password())
120 .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
121
122 let data = self.read()?;
123 self.packet.unpack(data, 0)?;
124 Ok(())
125 }
126
127 fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
128 (&self.stream)
129 .write_all(&self.packet.pack_cleartext_password())
130 .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
131
132 let data = self.read()?;
133 self.packet.unpack(data, 0)?;
134 Ok(())
135 }
136
137 fn scram_auth(&mut self) -> Result<(), PgsqlError> {
138 (&self.stream)
139 .write_all(&self.packet.pack_auth())
140 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
141
142 let data = self.read()?;
143 self.packet.unpack(data, 0)?;
144
145 (&self.stream)
146 .write_all(&self.packet.pack_auth_verify())
147 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
148
149 let data = self.read()?;
150 self.packet.unpack(data, 0)?;
151 Ok(())
152 }
153
154 fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
155 let mut msg = Vec::new();
156 let mut buf = [0u8; 4096];
157 let mut retry_count = 0;
158
159 #[cfg(not(test))]
160 const MAX_RETRIES: u32 = 100;
161 #[cfg(test)]
162 const MAX_RETRIES: u32 = 3;
163
164 #[cfg(not(test))]
165 const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
166 #[cfg(test)]
167 const MAX_MESSAGE_SIZE: usize = 128;
168
169 #[cfg(not(test))]
170 let deadline = std::time::Instant::now() + Duration::from_secs(300);
171 #[cfg(test)]
172 let deadline = std::time::Instant::now() + Duration::from_millis(200);
173
174 loop {
175 if std::time::Instant::now() >= deadline {
176 return Err(PgsqlError::Timeout("读取总超时".into()));
177 }
178
179 match (&self.stream).read(&mut buf) {
180 Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
181 Ok(n) => {
182 if msg.len() + n > MAX_MESSAGE_SIZE {
183 return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
184 }
185 msg.extend_from_slice(&buf[..n]);
186 retry_count = 0;
187 }
188 Err(ref e)
189 if e.kind() == std::io::ErrorKind::WouldBlock
190 || e.kind() == std::io::ErrorKind::TimedOut =>
191 {
192 retry_count += 1;
193 if retry_count > MAX_RETRIES {
194 return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
195 }
196 std::thread::sleep(Duration::from_millis(10));
197 continue;
198 }
199 Err(e) => return Err(PgsqlError::Io(e)),
200 };
201
202 if let AuthStatus::AuthenticationOk = self.auth_status {
203 if msg.ends_with(&[90, 0, 0, 0, 5, 73])
204 || msg.ends_with(&[90, 0, 0, 0, 5, 84])
205 || msg.ends_with(&[90, 0, 0, 0, 5, 69])
206 {
207 break;
208 }
209 } else if msg.len() >= 5 {
210 let len_bytes = &msg[1..=4];
211 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
212 if msg.len() > len as usize {
213 break;
214 }
215 }
216 }
217 }
218
219 Ok(msg)
220 }
221
222 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
223 (&self.stream)
224 .write_all(&self.packet.pack_query(sql))
225 .map_err(PgsqlError::Io)?;
226 let data = self.read()?;
227 self.last_used = Instant::now();
228 self.packet.unpack(data, 0)
229 }
230
231 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
232 (&self.stream)
233 .write_all(&self.packet.pack_execute(sql))
234 .map_err(PgsqlError::Io)?;
235 let data = self.read()?;
236 self.last_used = Instant::now();
237 self.packet.unpack(data, 0)
238 }
239
240 pub fn query_params(
242 &mut self,
243 sql: &str,
244 params: &[Option<&str>],
245 ) -> Result<SuccessMessage, PgsqlError> {
246 (&self.stream)
247 .write_all(&self.packet.pack_query_params(sql, params))
248 .map_err(PgsqlError::Io)?;
249
250 let data = self.read()?;
251 self.last_used = Instant::now();
252 self.packet.unpack(data, 0)
253 }
254
255 pub fn execute_params(
257 &mut self,
258 sql: &str,
259 params: &[Option<&str>],
260 ) -> Result<SuccessMessage, PgsqlError> {
261 (&self.stream)
262 .write_all(&self.packet.pack_execute_params(sql, params))
263 .map_err(PgsqlError::Io)?;
264 let data = self.read()?;
265 self.last_used = Instant::now();
266 self.packet.unpack(data, 0)
267 }
268
269 pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
271 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
272 self.query_params(sql, &opts)
273 }
274
275 pub fn execute_str(
277 &mut self,
278 sql: &str,
279 params: &[&str],
280 ) -> Result<SuccessMessage, PgsqlError> {
281 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
282 self.execute_params(sql, &opts)
283 }
284}
285
286impl Drop for Connect {
287 fn drop(&mut self) {
288 let _ = (&self.stream).write_all(&Packet::pack_terminate());
289 let _ = self.stream.shutdown(std::net::Shutdown::Both);
290 }
291}
292
#[cfg(test)]
mod tests {
    // Tests drive `Connect` against small in-process mock servers that speak
    // just enough of the PostgreSQL wire protocol.

    use super::*;
    use std::net::TcpListener;
    use std::thread;

    /// Builds one backend message: tag byte, big-endian length (which counts
    /// itself but not the tag), then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Builds an Authentication ('R') message with request code `auth_type`.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// AuthenticationOk.
    fn auth_ok() -> Vec<u8> {
        pg_auth(0, &[])
    }

    /// ParameterStatus reporting `server_version = 15.0`.
    fn param_status() -> Vec<u8> {
        pg_msg(b'S', b"server_version\x0015.0\x00")
    }

    /// BackendKeyData with process id 1 and secret key 2.
    fn backend_key() -> Vec<u8> {
        let mut p = Vec::new();
        p.extend(&1u32.to_be_bytes());
        p.extend(&2u32.to_be_bytes());
        pg_msg(b'K', &p)
    }

    /// ReadyForQuery with transaction status `I` (idle).
    fn ready_for_query() -> Vec<u8> {
        pg_msg(b'Z', b"I")
    }

    /// Everything the server sends after a successful authentication exchange.
    fn post_auth_ok() -> Vec<u8> {
        let mut v = auth_ok();
        v.extend(param_status());
        v.extend(backend_key());
        v.extend(ready_for_query());
        v
    }

    /// ParameterDescription for a single parameter of type `oid`.
    fn parameter_description(oid: u32) -> Vec<u8> {
        let mut pd = Vec::new();
        pd.extend(&1u16.to_be_bytes());
        pd.extend(&oid.to_be_bytes());
        pg_msg(b't', &pd)
    }

    /// RowDescription for a single text-format column `name`.
    fn row_description(name: &str, type_oid: u32, type_len: i16) -> Vec<u8> {
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes()); // field count
        rd.extend(name.as_bytes());
        rd.push(0);
        rd.extend(&0u32.to_be_bytes()); // table OID
        rd.extend(&1u16.to_be_bytes()); // column attribute number
        rd.extend(&type_oid.to_be_bytes());
        rd.extend(&type_len.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes()); // type modifier
        rd.extend(&0u16.to_be_bytes()); // format code: text
        pg_msg(b'T', &rd)
    }

    /// DataRow with one column; `None` is sent as SQL NULL (length -1).
    fn data_row(value: Option<&str>) -> Vec<u8> {
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        match value {
            Some(v) => {
                dr.extend(&(v.len() as u32).to_be_bytes());
                dr.extend_from_slice(v.as_bytes());
            }
            None => dr.extend(&(-1i32).to_be_bytes()),
        }
        pg_msg(b'D', &dr)
    }

    /// ErrorResponse carrying SQLSTATE `code` and `message`.
    fn error_message(code: &str, message: &str) -> Vec<u8> {
        let mut payload = Vec::new();
        payload.push(b'C');
        payload.extend(code.as_bytes());
        payload.push(0);
        payload.push(b'M');
        payload.extend(message.as_bytes());
        payload.push(0);
        payload.push(0);
        pg_msg(b'E', &payload)
    }

    /// Reply to a one-row `SELECT` (`c = 1`) whose ReadyForQuery reports `status`.
    fn select_one_response(status: u8) -> Vec<u8> {
        let mut r = pg_msg(b'1', &[]);
        r.extend(pg_msg(b'2', &[]));
        r.extend(row_description("c", 23, 4));
        r.extend(data_row(Some("1")));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', &[status]));
        r
    }

    fn simple_query_response() -> Vec<u8> {
        select_one_response(b'I')
    }

    fn execute_response() -> Vec<u8> {
        let mut r = pg_msg(b'1', &[]);
        r.extend(pg_msg(b'2', &[]));
        r.extend(pg_msg(b'n', &[]));
        r.extend(pg_msg(b'C', b"UPDATE 3\x00"));
        r.extend(ready_for_query());
        r
    }

    fn query_params_response() -> Vec<u8> {
        let mut r = pg_msg(b'1', &[]);
        r.extend(parameter_description(23));
        r.extend(pg_msg(b'2', &[]));
        r.extend(row_description("p", 23, 4));
        r.extend(data_row(Some("42")));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(ready_for_query());
        r
    }

    fn execute_params_response() -> Vec<u8> {
        let mut r = pg_msg(b'1', &[]);
        r.extend(parameter_description(23));
        r.extend(pg_msg(b'2', &[]));
        r.extend(pg_msg(b'n', &[]));
        r.extend(pg_msg(b'C', b"UPDATE 1\x00"));
        r.extend(ready_for_query());
        r
    }

    fn query_params_null_response() -> Vec<u8> {
        let mut r = pg_msg(b'1', &[]);
        r.extend(parameter_description(25));
        r.extend(pg_msg(b'2', &[]));
        r.extend(row_description("n", 25, -1));
        r.extend(data_row(None));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(ready_for_query());
        r
    }

    fn error_response() -> Vec<u8> {
        let mut r = pg_msg(b'1', &[]);
        r.extend(pg_msg(b'2', &[]));
        r.extend(error_message("42601", "syntax error"));
        r.extend(ready_for_query());
        r
    }

    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    /// Binds an ephemeral port and serves exactly one connection with `serve`
    /// on a background thread.
    ///
    /// The listener is already listening when this returns, so a client can
    /// connect immediately; the kernel queues it until `accept`.
    fn spawn_server_with<F>(serve: F) -> u16
    where
        F: FnOnce(TcpStream) + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (s, _) = listener.accept().unwrap();
            serve(s);
        });
        port
    }

    /// Answers the startup message with `auth_request`, reads the client's
    /// password reply and completes authentication.
    fn password_handshake(s: &mut TcpStream, auth_request: &[u8]) {
        let mut buf = [0u8; 4096];
        let _ = s.read(&mut buf).unwrap();
        let _ = s.write_all(auth_request);
        let _ = s.read(&mut buf).unwrap();
        let _ = s.write_all(&post_auth_ok());
    }

    /// Password handshake using AuthenticationCleartextPassword.
    fn cleartext_handshake(s: &mut TcpStream) {
        password_handshake(s, &pg_auth(3, &[]));
    }

    /// Replies to every incoming chunk with `response` until the client disconnects.
    fn reply_until_closed(s: &mut TcpStream, response: &[u8]) {
        let mut buf = [0u8; 4096];
        loop {
            match s.read(&mut buf) {
                Ok(0) | Err(_) => break,
                Ok(_) => {
                    let _ = s.write_all(response);
                }
            }
        }
    }

    /// Cleartext-auth server that answers every request with `response`.
    fn spawn_cleartext_replying(response: Vec<u8>) -> u16 {
        spawn_server_with(move |mut s| {
            cleartext_handshake(&mut s);
            reply_until_closed(&mut s, &response);
        })
    }

    fn spawn_cleartext_server() -> u16 {
        spawn_cleartext_replying(simple_query_response())
    }

    fn spawn_md5_server() -> u16 {
        spawn_server_with(|mut s| {
            password_handshake(&mut s, &pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
            reply_until_closed(&mut s, &simple_query_response());
        })
    }

    fn spawn_scram_server() -> u16 {
        spawn_server_with(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
            // SASLInitialResponse: echo the client nonce back in the challenge.
            let n = s.read(&mut buf).unwrap();
            let text = String::from_utf8_lossy(&buf[..n]);
            let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
            let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
            let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
            let _ = s.read(&mut buf).unwrap();
            // SASLFinal followed by the usual post-authentication messages.
            let mut resp = pg_auth(12, b"v=dummyproof");
            resp.extend(post_auth_ok());
            let _ = s.write_all(&resp);
            reply_until_closed(&mut s, &simple_query_response());
        })
    }

    fn spawn_eof_server() -> u16 {
        spawn_server_with(|s: TcpStream| drop(s))
    }

    fn spawn_auth_error_server() -> u16 {
        spawn_server_with(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&error_message("28P01", "password authentication failed"));
        })
    }

    fn spawn_query_error_server() -> u16 {
        spawn_cleartext_replying(error_response())
    }

    fn spawn_execute_server() -> u16 {
        spawn_cleartext_replying(execute_response())
    }

    fn spawn_query_params_server() -> u16 {
        spawn_cleartext_replying(query_params_response())
    }

    fn spawn_execute_params_server() -> u16 {
        spawn_cleartext_replying(execute_params_response())
    }

    fn spawn_query_params_null_server() -> u16 {
        spawn_cleartext_replying(query_params_null_response())
    }

    fn spawn_transaction_status_server() -> u16 {
        spawn_cleartext_replying(select_one_response(b'T'))
    }

    fn spawn_error_status_server() -> u16 {
        spawn_cleartext_replying(select_one_response(b'E'))
    }

    /// Answers exactly one query, then closes the connection.
    fn spawn_slow_partial_server() -> u16 {
        spawn_server_with(|mut s| {
            cleartext_handshake(&mut s);
            let mut buf = [0u8; 4096];
            match s.read(&mut buf) {
                Ok(0) | Err(_) => {}
                Ok(_) => {
                    let _ = s.write_all(&simple_query_response());
                }
            }
        })
    }

    /// Reads the first query and closes the connection without answering.
    fn spawn_rst_on_query_server() -> u16 {
        spawn_server_with(|mut s| {
            cleartext_handshake(&mut s);
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
        })
    }

    #[test]
    fn connect_cleartext_auth_success() {
        let port = spawn_cleartext_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_md5_auth_success() {
        let port = spawn_md5_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_scram_auth_success() {
        let port = spawn_scram_server();
        let conn = Connect::new(mock_config(port));
        assert!(conn.is_ok());
    }

    #[test]
    fn connect_connection_refused() {
        // Bind and immediately release an ephemeral port so nothing listens on it.
        let port = TcpListener::bind("127.0.0.1:0")
            .unwrap()
            .local_addr()
            .unwrap()
            .port();
        let result = Connect::new(mock_config(port));
        assert!(result.is_err());
        match result.unwrap_err() {
            PgsqlError::Connection(_) => {}
            other => panic!("expected Connection error, got {other:?}"),
        }
    }

    #[test]
    fn connect_server_closes_immediately() {
        let port = spawn_eof_server();
        let result = Connect::new(mock_config(port));
        assert!(result.is_err());
    }

    #[test]
    fn connect_auth_error_from_server() {
        let port = spawn_auth_error_server();
        let result = Connect::new(mock_config(port));
        assert!(result.is_err());
    }

    #[test]
    fn connect_query_success() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.rows.len(), 1);
        assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
    }

    #[test]
    fn connect_execute_success() {
        let port = spawn_execute_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.execute("UPDATE t SET x=1");
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.affect_count, 3);
        assert_eq!(msg.tag, "UPDATE 3");
    }

    #[test]
    fn connect_query_params_success() {
        let port = spawn_query_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query_params("SELECT $1::int", &[Some("42")]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.rows.len(), 1);
        assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
    }

    #[test]
    fn connect_execute_params_success() {
        let port = spawn_execute_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.execute_params("UPDATE t SET x=$1", &[Some("42")]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.affect_count, 1);
        assert_eq!(msg.tag, "UPDATE 1");
    }

    #[test]
    fn connect_query_str_success() {
        let port = spawn_query_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query_str("SELECT $1::int", &["42"]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.rows.len(), 1);
    }

    #[test]
    fn connect_execute_str_success() {
        let port = spawn_execute_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.execute_str("UPDATE t SET x=$1", &["1"]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.affect_count, 1);
    }

    #[test]
    fn connect_query_params_with_null() {
        let port = spawn_query_params_null_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query_params("SELECT $1::text", &[None]);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert!(!msg.param_oids.is_empty());
        assert_eq!(msg.rows.len(), 1);
        assert_eq!(msg.rows[0]["n"], "");
    }

    #[test]
    fn connect_query_params_empty_string_vs_null() {
        let port = spawn_query_params_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();

        let r1 = conn.query_params("SELECT $1::text", &[Some("")]);
        assert!(r1.is_ok());

        let r2 = conn.query_params("SELECT $1::text", &[None]);
        assert!(r2.is_ok());
    }

    #[test]
    fn connect_query_returns_error() {
        let port = spawn_query_error_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("BAD SQL");
        assert!(result.is_err());
    }

    #[test]
    fn connect_is_valid_true() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        assert!(conn.is_valid());
    }

    #[test]
    fn connect_is_valid_false_after_close() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn._close();
        assert!(!conn.is_valid());
    }

    #[test]
    fn connect_close_does_not_panic() {
        let port = spawn_cleartext_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn._close();
        conn._close();
    }

    #[test]
    fn connect_drop_does_not_panic() {
        let port = spawn_cleartext_server();
        let conn = Connect::new(mock_config(port)).unwrap();
        drop(conn);
    }

    #[test]
    fn connect_query_ready_for_query_transaction_status() {
        let port = spawn_transaction_status_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }

    #[test]
    fn connect_query_ready_for_query_error_status() {
        let port = spawn_error_status_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }

    #[test]
    fn connect_query_server_closes_after_partial() {
        let port = spawn_slow_partial_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let r1 = conn.query("SELECT 1");
        assert!(r1.is_ok());
        let r2 = conn.query("SELECT 1");
        assert!(r2.is_err());
    }

    #[test]
    fn connect_query_server_rst_returns_io_or_connection_error() {
        let port = spawn_rst_on_query_server();
        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
    }

    #[test]
    fn connect_read_would_block_max_retries() {
        // The server accepts the query but never answers.
        let port = spawn_server_with(|mut s| {
            cleartext_handshake(&mut s);
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            thread::sleep(Duration::from_secs(5));
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        conn.stream
            .set_read_timeout(Some(Duration::from_millis(1)))
            .ok();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(
            err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
            "expected timeout error, got: {err_str}"
        );
    }

    #[test]
    fn connect_read_exceeds_max_message_size() {
        // The server answers with more bytes than the test-mode size cap (128).
        let port = spawn_server_with(|mut s| {
            cleartext_handshake(&mut s);
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            let big = vec![b'X'; 256];
            let _ = s.write_all(&big);
            thread::sleep(Duration::from_secs(2));
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(
            err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
            "expected max message size error, got: {err_str}"
        );
    }

    #[test]
    fn connect_read_deadline_timeout() {
        // The server trickles bytes slowly enough to hit the test-mode deadline (200 ms).
        let port = spawn_server_with(|mut s| {
            cleartext_handshake(&mut s);
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            for _ in 0..200 {
                let _ = s.write_all(b"X");
                thread::sleep(Duration::from_millis(5));
            }
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_err());
    }

    #[test]
    fn connect_read_partial_auth_frame() {
        // The authentication request is split across two writes.
        let port = spawn_server_with(|mut s| {
            let mut buf = [0u8; 4096];
            let _ = s.read(&mut buf);
            let auth = pg_auth(3, &[]);
            let _ = s.write_all(&auth[..5]);
            thread::sleep(Duration::from_millis(50));
            let _ = s.write_all(&auth[5..]);
            let _ = s.read(&mut buf);
            let _ = s.write_all(&post_auth_ok());
            reply_until_closed(&mut s, &simple_query_response());
        });

        let mut conn = Connect::new(mock_config(port)).unwrap();
        let result = conn.query("SELECT 1");
        assert!(result.is_ok());
    }
}