1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::{SocketAddr, TcpStream};
6use std::time::{Duration, Instant};
7
/// Transport underneath a `Connect`: a raw TCP socket, or (with the `tls`
/// feature) a native-tls session layered over that socket.
#[derive(Debug)]
pub(crate) enum PgStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS-encrypted connection; only compiled with the `tls` feature.
    #[cfg(feature = "tls")]
    Tls(native_tls::TlsStream<TcpStream>),
}
15
16impl Read for PgStream {
17 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
18 match self {
19 PgStream::Plain(s) => s.read(buf),
20 #[cfg(feature = "tls")]
21 PgStream::Tls(s) => s.read(buf),
22 }
23 }
24}
25
26impl Write for PgStream {
27 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
28 match self {
29 PgStream::Plain(s) => s.write(buf),
30 #[cfg(feature = "tls")]
31 PgStream::Tls(s) => s.write(buf),
32 }
33 }
34 fn flush(&mut self) -> std::io::Result<()> {
35 match self {
36 PgStream::Plain(s) => s.flush(),
37 #[cfg(feature = "tls")]
38 PgStream::Tls(s) => s.flush(),
39 }
40 }
41}
42
43impl PgStream {
44 fn peer_addr(&self) -> std::io::Result<SocketAddr> {
45 match self {
46 PgStream::Plain(s) => s.peer_addr(),
47 #[cfg(feature = "tls")]
48 PgStream::Tls(s) => s.get_ref().peer_addr(),
49 }
50 }
51 fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
52 match self {
53 PgStream::Plain(s) => s.shutdown(how),
54 #[cfg(feature = "tls")]
55 PgStream::Tls(s) => s.get_ref().shutdown(how),
56 }
57 }
58 #[allow(dead_code)]
59 fn set_read_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
60 match self {
61 PgStream::Plain(s) => s.set_read_timeout(dur),
62 #[cfg(feature = "tls")]
63 PgStream::Tls(s) => s.get_ref().set_read_timeout(dur),
64 }
65 }
66}
67
/// One PostgreSQL client connection: transport, protocol codec and
/// bookkeeping timestamps used for idle/age checks.
#[derive(Debug)]
pub struct Connect {
    /// Plain or TLS transport to the server.
    pub(crate) stream: PgStream,
    /// Peer address captured right after the TCP connect; not read anywhere in this file.
    _peer_addr: SocketAddr,
    /// Message encoder/decoder; created from the `Config` passed to `Connect::new`.
    packet: Packet,
    /// `AuthenticationOk` once the startup/authentication exchange has finished.
    auth_status: AuthStatus,
    /// When the connection was last used; compared against the idle threshold in `is_valid`.
    last_used: Instant,
    /// When the connection was established.
    created_at: Instant,
}
80
81impl Connect {
82 pub fn is_valid(&mut self) -> bool {
84 if self.stream.peer_addr().is_err() {
85 return false;
86 }
87 #[cfg(not(test))]
88 const IDLE_THRESHOLD: Duration = Duration::from_secs(5);
89 #[cfg(test)]
90 const IDLE_THRESHOLD: Duration = Duration::from_millis(0);
91 if self.last_used.elapsed() > IDLE_THRESHOLD {
92 return self.query("SELECT 1").is_ok();
93 }
94 true
95 }
96
97 pub fn peer_valid(&self) -> bool {
99 self.stream.peer_addr().is_ok()
100 }
101
102 pub fn touch(&mut self) {
104 self.last_used = Instant::now();
105 }
106
107 pub fn idle_elapsed(&self) -> Duration {
109 self.last_used.elapsed()
110 }
111
112 pub fn age(&self) -> Duration {
114 self.created_at.elapsed()
115 }
116
117 pub fn _close(&mut self) {
118 let _ = self.stream.write_all(&Packet::pack_terminate());
119 let _ = self.stream.shutdown(std::net::Shutdown::Both);
120 }
121
122 fn set_keepalive(stream: &TcpStream) -> Result<(), PgsqlError> {
124 use std::os::unix::io::{AsRawFd, FromRawFd};
125 let fd = stream.as_raw_fd();
126 let socket = unsafe { socket2::Socket::from_raw_fd(fd) };
127 let keepalive = socket2::TcpKeepalive::new()
128 .with_time(Duration::from_secs(60))
129 .with_interval(Duration::from_secs(15))
130 .with_retries(3);
131 let result = socket.set_tcp_keepalive(&keepalive);
132 std::mem::forget(socket);
134 result.map_err(|e| PgsqlError::Connection(format!("设置 TCP Keepalive 失败: {}", e)))
135 }
136
137 fn try_ssl_upgrade(mut stream: TcpStream, config: &Config) -> Result<PgStream, PgsqlError> {
139 stream
141 .write_all(&Packet::pack_ssl_request())
142 .map_err(|e| PgsqlError::Connection(format!("发送 SSLRequest 失败: {}", e)))?;
143 let mut resp = [0u8; 1];
144 stream
145 .read_exact(&mut resp)
146 .map_err(|e| PgsqlError::Connection(format!("读取 SSL 响应失败: {}", e)))?;
147 match resp[0] {
148 b'S' => {
149 #[cfg(feature = "tls")]
151 {
152 let connector = native_tls::TlsConnector::builder()
153 .danger_accept_invalid_certs(true)
154 .build()
155 .map_err(|e| PgsqlError::Connection(format!("TLS 初始化失败: {}", e)))?;
156 let tls_stream = connector
157 .connect(&config.hostname, stream)
158 .map_err(|e| PgsqlError::Connection(format!("TLS 握手失败: {}", e)))?;
159 Ok(PgStream::Tls(tls_stream))
160 }
161 #[cfg(not(feature = "tls"))]
162 {
163 let _ = config;
164 Err(PgsqlError::Connection(
165 "服务端要求 SSL 但未启用 tls feature".into(),
166 ))
167 }
168 }
169 b'N' => {
170 if config.sslmode == "require" {
172 Err(PgsqlError::Connection(
173 "sslmode=require 但服务端不支持 SSL".into(),
174 ))
175 } else {
176 Ok(PgStream::Plain(stream))
177 }
178 }
179 other => Err(PgsqlError::Connection(format!(
180 "无效的 SSL 响应字节: 0x{:02X}",
181 other
182 ))),
183 }
184 }
185 pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
186 let stream =
187 TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
188 stream
190 .set_nodelay(true)
191 .map_err(|e| PgsqlError::Connection(format!("设置 TCP_NODELAY 失败: {}", e)))?;
192 Self::set_keepalive(&stream)?;
194 stream
195 .set_read_timeout(Some(Duration::from_secs(30)))
196 .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
197 stream
198 .set_write_timeout(Some(Duration::from_secs(30)))
199 .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
200 let peer_addr = stream
201 .peer_addr()
202 .map_err(|e| PgsqlError::Connection(e.to_string()))?;
203
204 let stream = if config.sslmode != "disable" {
206 Self::try_ssl_upgrade(stream, &config)?
207 } else {
208 PgStream::Plain(stream)
209 };
210
211 let mut connect = Self {
212 stream,
213 _peer_addr: peer_addr,
214 packet: Packet::new(config),
215 auth_status: AuthStatus::None,
216 last_used: Instant::now(),
217 created_at: Instant::now(),
218 };
219
220 connect.authenticate()?;
221
222 Ok(connect)
223 }
224
225 fn authenticate(&mut self) -> Result<(), PgsqlError> {
226 self.stream
227 .write_all(&self.packet.pack_first())
228 .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
229
230 let data = self.read()?;
231 self.packet.unpack(data, 0)?;
232
233 if !self.packet.md5_salt.is_empty() {
234 self.md5_auth()?;
235 } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
236 self.cleartext_auth()?;
237 } else {
238 self.scram_auth()?;
239 }
240
241 self.auth_status = AuthStatus::AuthenticationOk;
242 Ok(())
243 }
244
245 fn md5_auth(&mut self) -> Result<(), PgsqlError> {
246 self.stream
247 .write_all(&self.packet.pack_md5_password())
248 .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
249
250 let data = self.read()?;
251 self.packet.unpack(data, 0)?;
252 Ok(())
253 }
254
255 fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
256 self.stream
257 .write_all(&self.packet.pack_cleartext_password())
258 .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
259
260 let data = self.read()?;
261 self.packet.unpack(data, 0)?;
262 Ok(())
263 }
264
265 fn scram_auth(&mut self) -> Result<(), PgsqlError> {
266 self.stream
267 .write_all(&self.packet.pack_auth())
268 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
269
270 let data = self.read()?;
271 self.packet.unpack(data, 0)?;
272
273 self.stream
274 .write_all(&self.packet.pack_auth_verify())
275 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
276
277 let data = self.read()?;
278 self.packet.unpack(data, 0)?;
279 Ok(())
280 }
281
282 fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
283 let mut msg = Vec::new();
284 let mut buf = [0u8; 4096];
285 let mut retry_count = 0;
286
287 #[cfg(not(test))]
288 const MAX_RETRIES: u32 = 100;
289 #[cfg(test)]
290 const MAX_RETRIES: u32 = 3;
291
292 #[cfg(not(test))]
293 const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
294 #[cfg(test)]
295 const MAX_MESSAGE_SIZE: usize = 128;
296
297 #[cfg(not(test))]
298 let deadline = std::time::Instant::now() + Duration::from_secs(300);
299 #[cfg(test)]
300 let deadline = std::time::Instant::now() + Duration::from_millis(200);
301
302 loop {
303 if std::time::Instant::now() >= deadline {
304 return Err(PgsqlError::Timeout("读取总超时".into()));
305 }
306
307 match self.stream.read(&mut buf) {
308 Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
309 Ok(n) => {
310 if msg.len() + n > MAX_MESSAGE_SIZE {
311 return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
312 }
313 msg.extend_from_slice(&buf[..n]);
314 retry_count = 0;
315 }
316 Err(ref e)
317 if e.kind() == std::io::ErrorKind::WouldBlock
318 || e.kind() == std::io::ErrorKind::TimedOut =>
319 {
320 retry_count += 1;
321 if retry_count > MAX_RETRIES {
322 return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
323 }
324 std::thread::sleep(Duration::from_millis(10));
325 continue;
326 }
327 Err(e) => return Err(PgsqlError::Io(e)),
328 };
329
330 if let AuthStatus::AuthenticationOk = self.auth_status {
331 if msg.ends_with(&[90, 0, 0, 0, 5, 73])
332 || msg.ends_with(&[90, 0, 0, 0, 5, 84])
333 || msg.ends_with(&[90, 0, 0, 0, 5, 69])
334 {
335 break;
336 }
337 } else if msg.len() >= 5 {
338 let len_bytes = &msg[1..=4];
339 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
340 if msg.len() > len as usize {
341 break;
342 }
343 }
344 }
345 }
346
347 Ok(msg)
348 }
349
350 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
351 self.stream
352 .write_all(&self.packet.pack_query(sql))
353 .map_err(PgsqlError::Io)?;
354 let data = self.read()?;
355 self.last_used = Instant::now();
356 self.packet.unpack(data, 0)
357 }
358
359 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
360 self.stream
361 .write_all(&self.packet.pack_execute(sql))
362 .map_err(PgsqlError::Io)?;
363 let data = self.read()?;
364 self.last_used = Instant::now();
365 self.packet.unpack(data, 0)
366 }
367
368 pub fn query_params(
370 &mut self,
371 sql: &str,
372 params: &[Option<&str>],
373 ) -> Result<SuccessMessage, PgsqlError> {
374 self.stream
375 .write_all(&self.packet.pack_query_params(sql, params))
376 .map_err(PgsqlError::Io)?;
377
378 let data = self.read()?;
379 self.last_used = Instant::now();
380 self.packet.unpack(data, 0)
381 }
382
383 pub fn execute_params(
385 &mut self,
386 sql: &str,
387 params: &[Option<&str>],
388 ) -> Result<SuccessMessage, PgsqlError> {
389 self.stream
390 .write_all(&self.packet.pack_execute_params(sql, params))
391 .map_err(PgsqlError::Io)?;
392 let data = self.read()?;
393 self.last_used = Instant::now();
394 self.packet.unpack(data, 0)
395 }
396
397 pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
399 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
400 self.query_params(sql, &opts)
401 }
402
403 pub fn execute_str(
405 &mut self,
406 sql: &str,
407 params: &[&str],
408 ) -> Result<SuccessMessage, PgsqlError> {
409 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
410 self.execute_params(sql, &opts)
411 }
412 pub fn query_portal(&mut self, sql: &str, max_rows: u32) -> Result<SuccessMessage, PgsqlError> {
415 self.stream
416 .write_all(&self.packet.pack_query_portal(sql, max_rows))
417 .map_err(PgsqlError::Io)?;
418 let data = self.read()?;
419 self.last_used = Instant::now();
420 self.packet.unpack(data, 0)
421 }
422 pub fn fetch_more(&mut self, max_rows: u32) -> Result<SuccessMessage, PgsqlError> {
424 self.stream
425 .write_all(&self.packet.pack_fetch_more(max_rows))
426 .map_err(PgsqlError::Io)?;
427 let data = self.read()?;
428 self.last_used = Instant::now();
429 self.packet.unpack(data, 0)
430 }
431 pub fn close_portal(&mut self) -> Result<SuccessMessage, PgsqlError> {
433 self.stream
434 .write_all(&self.packet.pack_close_portal())
435 .map_err(PgsqlError::Io)?;
436 let data = self.read()?;
437 self.last_used = Instant::now();
438 self.packet.unpack(data, 0)
439 }
440}
441
442impl Drop for Connect {
443 fn drop(&mut self) {
444 let _ = self.stream.write_all(&Packet::pack_terminate());
445 let _ = self.stream.shutdown(std::net::Shutdown::Both);
446 }
447}
448
449#[cfg(test)]
450mod tests {
451 use super::*;
452 use std::net::TcpListener;
453 use std::thread;
454
455 fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
459 let mut m = Vec::with_capacity(5 + payload.len());
460 m.push(tag);
461 m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
462 m.extend_from_slice(payload);
463 m
464 }
465
466 fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
468 let mut body = Vec::new();
469 body.extend(&auth_type.to_be_bytes());
470 body.extend_from_slice(extra);
471 pg_msg(b'R', &body)
472 }
473
474 fn auth_ok() -> Vec<u8> {
476 pg_auth(0, &[])
477 }
478
479 fn param_status() -> Vec<u8> {
481 pg_msg(b'S', b"server_version\x0015.0\x00")
482 }
483
484 fn backend_key() -> Vec<u8> {
486 let mut p = Vec::new();
487 p.extend(&1u32.to_be_bytes());
488 p.extend(&2u32.to_be_bytes());
489 pg_msg(b'K', &p)
490 }
491
492 fn ready_for_query() -> Vec<u8> {
494 pg_msg(b'Z', b"I")
495 }
496
497 fn post_auth_ok() -> Vec<u8> {
499 let mut v = Vec::new();
500 v.extend(auth_ok());
501 v.extend(param_status());
502 v.extend(backend_key());
503 v.extend(ready_for_query());
504 v
505 }
506
    /// Canned reply to a query returning one int4 column `c` with the value 1:
    /// ParseComplete, BindComplete, RowDescription, DataRow, CommandComplete
    /// ("SELECT 1"), ReadyForQuery.
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        // ParseComplete ('1') and BindComplete ('2'), both empty.
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        // RowDescription: 1 field "c", table OID 0, attnum 1, type OID 23 (int4),
        // type size 4, type modifier -1, text format 0.
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        // DataRow: 1 column, length 1, text "1".
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(ready_for_query());
        r
    }
538
539 fn execute_response() -> Vec<u8> {
542 let mut r = Vec::new();
543 r.extend(pg_msg(b'1', &[]));
544 r.extend(pg_msg(b'2', &[]));
545 r.extend(pg_msg(b'n', &[])); r.extend(pg_msg(b'C', b"UPDATE 3\x00"));
547 r.extend(ready_for_query());
548 r
549 }
550
    /// Canned reply to a parameterised query returning int4 column `p` = 42,
    /// including a ParameterDescription with one int4 parameter.
    fn query_params_response() -> Vec<u8> {
        let mut r = Vec::new();
        // ParseComplete.
        r.extend(pg_msg(b'1', &[]));

        // ParameterDescription: 1 parameter of type OID 23 (int4).
        let mut pd = Vec::new();
        pd.extend(&1u16.to_be_bytes());
        pd.extend(&23u32.to_be_bytes());
        r.extend(pg_msg(b't', &pd));

        // BindComplete.
        r.extend(pg_msg(b'2', &[]));

        // RowDescription: field "p", int4, size 4, modifier -1, text format.
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"p\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));

        // DataRow: 1 column, length 2, text "42".
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&2u32.to_be_bytes());
        dr.extend(b"42");
        r.extend(pg_msg(b'D', &dr));

        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(ready_for_query());
        r
    }
585
    /// Canned reply to a parameterised statement without rows: one int4
    /// parameter, NoData, CommandComplete ("UPDATE 1"), ReadyForQuery.
    fn execute_params_response() -> Vec<u8> {
        let mut r = Vec::new();
        // ParseComplete.
        r.extend(pg_msg(b'1', &[]));

        // ParameterDescription: 1 parameter of type OID 23 (int4).
        let mut pd = Vec::new();
        pd.extend(&1u16.to_be_bytes());
        pd.extend(&23u32.to_be_bytes());
        r.extend(pg_msg(b't', &pd));

        // BindComplete, NoData, CommandComplete, ReadyForQuery.
        r.extend(pg_msg(b'2', &[]));
        r.extend(pg_msg(b'n', &[]));
        r.extend(pg_msg(b'C', b"UPDATE 1\x00"));
        r.extend(ready_for_query());
        r
    }
603
    /// Canned reply to a parameterised query whose single text column `n`
    /// is NULL (DataRow column length -1).
    fn query_params_null_response() -> Vec<u8> {
        let mut r = Vec::new();
        // ParseComplete.
        r.extend(pg_msg(b'1', &[]));

        // ParameterDescription: 1 parameter of type OID 25 (text).
        let mut pd = Vec::new();
        pd.extend(&1u16.to_be_bytes());
        pd.extend(&25u32.to_be_bytes());
        r.extend(pg_msg(b't', &pd));

        // BindComplete.
        r.extend(pg_msg(b'2', &[]));

        // RowDescription: field "n", text, variable size (-1), modifier -1.
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"n\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&25u32.to_be_bytes());
        rd.extend(&(-1i16).to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));

        // DataRow: 1 column with length -1, i.e. NULL.
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&(-1i32).to_be_bytes());
        r.extend(pg_msg(b'D', &dr));

        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(ready_for_query());
        r
    }
636
    /// Canned failing reply: ParseComplete, BindComplete, ErrorResponse with
    /// SQLSTATE 42601 and message "syntax error", then ReadyForQuery.
    fn error_response() -> Vec<u8> {
        // ErrorResponse fields: 'C' code, 'M' message, terminated by a 0 byte.
        let mut payload = Vec::new();
        payload.push(b'C');
        payload.extend(b"42601\x00");
        payload.push(b'M');
        payload.extend(b"syntax error\x00");
        payload.push(0);
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        r.extend(pg_msg(b'E', &payload));
        r.extend(ready_for_query());
        r
    }
652
653 fn mock_config(port: u16) -> Config {
657 Config {
658 debug: false,
659 hostname: "127.0.0.1".into(),
660 hostport: port as i32,
661 username: "u".into(),
662 userpass: "p".into(),
663 database: "d".into(),
664 charset: "utf8".into(),
665 pool_max: 5,
666 sslmode: "disable".into(),
667 }
668 }
669
    /// Spawns a single-connection mock server that requests a cleartext
    /// password (auth code 3), accepts whatever is sent, and then answers every
    /// client read with `simple_query_response()`. Returns the listening port.
    fn spawn_cleartext_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(3, &[]));
            // Password message; its contents are not checked.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&post_auth_ok());
            // One canned response per successful read until the client disconnects.
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        let _ = s.write_all(&simple_query_response());
                    }
                }
            }
        });
        // NOTE(review): the listener is already bound, so this sleep is probably unnecessary.
        thread::sleep(Duration::from_millis(30));
        port
    }
699
    /// Like `spawn_cleartext_server`, but requests MD5 authentication
    /// (auth code 5) with the fixed salt AA BB CC DD.
    fn spawn_md5_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
            // Hashed password; not verified.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&post_auth_ok());
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        let _ = s.write_all(&simple_query_response());
                    }
                }
            }
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
727
    /// Mock server driving a SCRAM-SHA-256 exchange without checking proofs.
    ///
    /// 1. Announces the mechanism (auth code 10).
    /// 2. Echoes the client nonce back in a SASLContinue (code 11) challenge.
    /// 3. Replies with SASLFinal (code 12) plus the normal post-auth messages,
    ///    then serves `simple_query_response()`.
    fn spawn_scram_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
            // SASLInitialResponse; take everything after the first "r=" as the client nonce.
            let n = s.read(&mut buf).unwrap();
            let payload = &buf[..n];
            let text = String::from_utf8_lossy(payload);
            let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
            // Server-first message: combined nonce, base64 salt, iteration count.
            let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
            let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
            // Client-final message; the proof is not verified.
            let _ = s.read(&mut buf).unwrap();
            let mut resp = Vec::new();
            resp.extend(pg_auth(12, b"v=dummyproof"));
            resp.extend(auth_ok());
            resp.extend(param_status());
            resp.extend(backend_key());
            resp.extend(ready_for_query());
            let _ = s.write_all(&resp);
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        let _ = s.write_all(&simple_query_response());
                    }
                }
            }
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
770
    /// Mock server that accepts one connection and closes it immediately,
    /// without sending anything.
    fn spawn_eof_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (s, _) = listener.accept().unwrap();
            drop(s);
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
782
    /// Mock server that answers the startup message with an ErrorResponse
    /// (SQLSTATE 28P01, "password authentication failed") and then closes.
    fn spawn_auth_error_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message.
            let _ = s.read(&mut buf).unwrap();
            // ErrorResponse fields: 'C' code, 'M' message, 0 terminator.
            let mut payload = Vec::new();
            payload.push(b'C');
            payload.extend(b"28P01\x00");
            payload.push(b'M');
            payload.extend(b"password authentication failed\x00");
            payload.push(0);
            let _ = s.write_all(&pg_msg(b'E', &payload));
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
803
    /// Cleartext-auth mock server that answers every client read after the
    /// handshake with `error_response()`.
    fn spawn_query_error_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message, then cleartext-password handshake.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(3, &[]));
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&post_auth_ok());
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        let _ = s.write_all(&error_response());
                    }
                }
            }
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
829
    /// Cleartext-auth mock server that answers every client read after the
    /// handshake with `execute_response()` ("UPDATE 3").
    fn spawn_execute_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message, then cleartext-password handshake.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(3, &[]));
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&post_auth_ok());
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        let _ = s.write_all(&execute_response());
                    }
                }
            }
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
855
    /// Cleartext-auth mock server that answers every client read after the
    /// handshake with `query_params_response()` (int4 column `p` = 42).
    fn spawn_query_params_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message, then cleartext-password handshake.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(3, &[]));
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&post_auth_ok());
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        let _ = s.write_all(&query_params_response());
                    }
                }
            }
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
879
    /// Alias for `spawn_query_params_server`, used by the `query_str` and
    /// empty-string/NULL tests.
    fn spawn_params_server() -> u16 {
        spawn_query_params_server()
    }
883
    /// Cleartext-auth mock server that answers every client read after the
    /// handshake with `execute_params_response()` ("UPDATE 1").
    fn spawn_execute_params_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message, then cleartext-password handshake.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(3, &[]));
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&post_auth_ok());
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        let _ = s.write_all(&execute_params_response());
                    }
                }
            }
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
907
    /// Cleartext-auth mock server that answers every client read after the
    /// handshake with `query_params_null_response()` (NULL text column `n`).
    fn spawn_query_params_null_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message, then cleartext-password handshake.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(3, &[]));
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&post_auth_ok());
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        let _ = s.write_all(&query_params_null_response());
                    }
                }
            }
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
931
932 #[test]
935 fn connect_cleartext_auth_success() {
936 let port = spawn_cleartext_server();
937 let conn = Connect::new(mock_config(port));
938 assert!(conn.is_ok());
939 }
940
941 #[test]
942 fn connect_md5_auth_success() {
943 let port = spawn_md5_server();
944 let conn = Connect::new(mock_config(port));
945 assert!(conn.is_ok());
946 }
947
948 #[test]
949 fn connect_scram_auth_success() {
950 let port = spawn_scram_server();
951 let conn = Connect::new(mock_config(port));
952 assert!(conn.is_ok());
953 }
954
955 #[test]
956 fn connect_connection_refused() {
957 let cfg = mock_config(1);
959 let result = Connect::new(cfg);
960 assert!(result.is_err());
961 match result.unwrap_err() {
962 PgsqlError::Connection(_) => {}
963 other => panic!("expected Connection error, got {other:?}"),
964 }
965 }
966
967 #[test]
968 fn connect_server_closes_immediately() {
969 let port = spawn_eof_server();
970 let result = Connect::new(mock_config(port));
971 assert!(result.is_err());
972 }
973
974 #[test]
975 fn connect_auth_error_from_server() {
976 let port = spawn_auth_error_server();
977 let result = Connect::new(mock_config(port));
978 assert!(result.is_err());
979 }
980
981 #[test]
982 fn connect_query_success() {
983 let port = spawn_cleartext_server();
984 let mut conn = Connect::new(mock_config(port)).unwrap();
985 let result = conn.query("SELECT 1");
986 assert!(result.is_ok());
987 let msg = result.unwrap();
988 assert_eq!(msg.rows.len(), 1);
989 assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
990 }
991
992 #[test]
993 fn connect_execute_success() {
994 let port = spawn_execute_server();
995 let mut conn = Connect::new(mock_config(port)).unwrap();
996 let result = conn.execute("UPDATE t SET x=1");
997 assert!(result.is_ok());
998 let msg = result.unwrap();
999 assert_eq!(msg.affect_count, 3);
1000 assert_eq!(msg.tag, "UPDATE 3");
1001 }
1002
1003 #[test]
1004 fn connect_query_params_success() {
1005 let port = spawn_query_params_server();
1006 let mut conn = Connect::new(mock_config(port)).unwrap();
1007 let result = conn.query_params("SELECT $1::int", &[Some("42")]);
1008 assert!(result.is_ok());
1009 let msg = result.unwrap();
1010 assert!(!msg.param_oids.is_empty());
1011 assert_eq!(msg.rows.len(), 1);
1012 assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
1013 }
1014
1015 #[test]
1016 fn connect_execute_params_success() {
1017 let port = spawn_execute_params_server();
1018 let mut conn = Connect::new(mock_config(port)).unwrap();
1019 let result = conn.execute_params("UPDATE t SET x=$1", &[Some("42")]);
1020 assert!(result.is_ok());
1021 let msg = result.unwrap();
1022 assert!(!msg.param_oids.is_empty());
1023 assert_eq!(msg.affect_count, 1);
1024 assert_eq!(msg.tag, "UPDATE 1");
1025 }
1026
1027 #[test]
1028 fn connect_query_str_success() {
1029 let port = spawn_params_server();
1030 let mut conn = Connect::new(mock_config(port)).unwrap();
1031 let result = conn.query_str("SELECT $1::int", &["42"]);
1032 assert!(result.is_ok());
1033 let msg = result.unwrap();
1034 assert!(!msg.param_oids.is_empty());
1035 assert_eq!(msg.rows.len(), 1);
1036 }
1037
1038 #[test]
1039 fn connect_execute_str_success() {
1040 let port = spawn_execute_params_server();
1041 let mut conn = Connect::new(mock_config(port)).unwrap();
1042 let result = conn.execute_str("UPDATE t SET x=$1", &["1"]);
1043 assert!(result.is_ok());
1044 let msg = result.unwrap();
1045 assert!(!msg.param_oids.is_empty());
1046 assert_eq!(msg.affect_count, 1);
1047 }
1048
1049 #[test]
1050 fn connect_query_params_with_null() {
1051 let port = spawn_query_params_null_server();
1052 let mut conn = Connect::new(mock_config(port)).unwrap();
1053 let result = conn.query_params("SELECT $1::text", &[None]);
1054 assert!(result.is_ok());
1055 let msg = result.unwrap();
1056 assert!(!msg.param_oids.is_empty());
1057 assert_eq!(msg.rows.len(), 1);
1058 assert_eq!(msg.rows[0]["n"], "");
1059 }
1060
1061 #[test]
1062 fn connect_query_params_empty_string_vs_null() {
1063 let port = spawn_params_server();
1064 let mut conn = Connect::new(mock_config(port)).unwrap();
1065
1066 let r1 = conn.query_params("SELECT $1::text", &[Some("")]);
1068 assert!(r1.is_ok());
1069
1070 let r2 = conn.query_params("SELECT $1::text", &[None]);
1072 assert!(r2.is_ok());
1073 }
1074
1075 #[test]
1076 fn connect_query_returns_error() {
1077 let port = spawn_query_error_server();
1078 let mut conn = Connect::new(mock_config(port)).unwrap();
1079 let result = conn.query("BAD SQL");
1080 assert!(result.is_err());
1081 }
1082
1083 #[test]
1084 fn connect_is_valid_true() {
1085 let port = spawn_cleartext_server();
1086 let mut conn = Connect::new(mock_config(port)).unwrap();
1087 assert!(conn.is_valid());
1088 }
1089
1090 #[test]
1091 fn connect_is_valid_false_after_close() {
1092 let port = spawn_cleartext_server();
1093 let mut conn = Connect::new(mock_config(port)).unwrap();
1094 conn._close();
1095 assert!(!conn.is_valid());
1097 }
1098
1099 #[test]
1100 fn connect_close_does_not_panic() {
1101 let port = spawn_cleartext_server();
1102 let mut conn = Connect::new(mock_config(port)).unwrap();
1103 conn._close();
1104 conn._close();
1106 }
1107
1108 #[test]
1109 fn connect_drop_does_not_panic() {
1110 let port = spawn_cleartext_server();
1111 let conn = Connect::new(mock_config(port)).unwrap();
1112 drop(conn);
1113 }
1114
    /// Cleartext-auth mock server whose query replies end in ReadyForQuery
    /// with transaction status `T` (in a transaction block) instead of `I`.
    fn spawn_transaction_status_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message, then cleartext-password handshake.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(3, &[]));
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&post_auth_ok());
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        // Same payload as simple_query_response(), except the final status byte.
                        let mut r = Vec::new();
                        r.extend(pg_msg(b'1', &[]));
                        r.extend(pg_msg(b'2', &[]));
                        let mut rd = Vec::new();
                        rd.extend(&1u16.to_be_bytes());
                        rd.extend(b"c\x00");
                        rd.extend(&0u32.to_be_bytes());
                        rd.extend(&1u16.to_be_bytes());
                        rd.extend(&23u32.to_be_bytes());
                        rd.extend(&4i16.to_be_bytes());
                        rd.extend(&(-1i32).to_be_bytes());
                        rd.extend(&0u16.to_be_bytes());
                        r.extend(pg_msg(b'T', &rd));
                        let mut dr = Vec::new();
                        dr.extend(&1u16.to_be_bytes());
                        dr.extend(&1u32.to_be_bytes());
                        dr.push(b'1');
                        r.extend(pg_msg(b'D', &dr));
                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
                        r.extend(pg_msg(b'Z', b"T"));
                        let _ = s.write_all(&r);
                    }
                }
            }
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
1157
    /// Cleartext-auth mock server whose query replies end in ReadyForQuery
    /// with transaction status `E` (failed transaction block).
    fn spawn_error_status_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message, then cleartext-password handshake.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(3, &[]));
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&post_auth_ok());
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        // Same payload as simple_query_response(), except the final status byte.
                        let mut r = Vec::new();
                        r.extend(pg_msg(b'1', &[]));
                        r.extend(pg_msg(b'2', &[]));
                        let mut rd = Vec::new();
                        rd.extend(&1u16.to_be_bytes());
                        rd.extend(b"c\x00");
                        rd.extend(&0u32.to_be_bytes());
                        rd.extend(&1u16.to_be_bytes());
                        rd.extend(&23u32.to_be_bytes());
                        rd.extend(&4i16.to_be_bytes());
                        rd.extend(&(-1i32).to_be_bytes());
                        rd.extend(&0u16.to_be_bytes());
                        r.extend(pg_msg(b'T', &rd));
                        let mut dr = Vec::new();
                        dr.extend(&1u16.to_be_bytes());
                        dr.extend(&1u32.to_be_bytes());
                        dr.push(b'1');
                        r.extend(pg_msg(b'D', &dr));
                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
                        r.extend(pg_msg(b'Z', b"E"));
                        let _ = s.write_all(&r);
                    }
                }
            }
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
1200
    /// Cleartext-auth mock server that answers exactly one query and then
    /// exits its thread, dropping (closing) the connection.
    fn spawn_slow_partial_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut s, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            // Startup message, then cleartext-password handshake.
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&pg_auth(3, &[]));
            let _ = s.read(&mut buf).unwrap();
            let _ = s.write_all(&post_auth_ok());
            // Serve a single query; the socket closes when the thread returns.
            match s.read(&mut buf) {
                Ok(0) | Err(_) => {}
                Ok(_) => {
                    let _ = s.write_all(&simple_query_response());
                }
            }
        });
        thread::sleep(Duration::from_millis(30));
        port
    }
1221
1222 fn spawn_rst_on_query_server() -> u16 {
1223 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1224 let port = listener.local_addr().unwrap().port();
1225 thread::spawn(move || {
1226 let (mut s, _) = listener.accept().unwrap();
1227 let mut buf = [0u8; 4096];
1228 let _ = s.read(&mut buf).unwrap();
1229 let _ = s.write_all(&pg_auth(3, &[]));
1230 let _ = s.read(&mut buf).unwrap();
1231 let _ = s.write_all(&post_auth_ok());
1232 match s.read(&mut buf) {
1233 Ok(0) | Err(_) => {}
1234 Ok(_) => {
1235 drop(s);
1236 }
1237 }
1238 });
1239 thread::sleep(Duration::from_millis(30));
1240 port
1241 }
1242
/// A ReadyForQuery carrying transaction status `T` (inside a transaction
/// block) must not turn an otherwise complete `SELECT 1` into an error.
#[test]
fn connect_query_ready_for_query_transaction_status() {
    let port = spawn_transaction_status_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    // Report the driver's error text rather than a bare `is_ok()` failure.
    if let Err(e) = conn.query("SELECT 1") {
        panic!("query ending in ReadyForQuery('T') failed: {e}");
    }
}
1250
/// A ReadyForQuery carrying transaction status `E` (failed transaction
/// block) must still let the preceding `SELECT 1` result through as `Ok`.
#[test]
fn connect_query_ready_for_query_error_status() {
    let port = spawn_error_status_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    // Report the driver's error text rather than a bare `is_ok()` failure.
    if let Err(e) = conn.query("SELECT 1") {
        panic!("query ending in ReadyForQuery('E') failed: {e}");
    }
}
1258
/// The first query is answered by the mock, which then closes the
/// connection, so a second query on the same connection must fail.
#[test]
fn connect_query_server_closes_after_partial() {
    let port = spawn_slow_partial_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    // Report the driver's error text if the answered query unexpectedly fails.
    if let Err(e) = conn.query("SELECT 1") {
        panic!("first query should succeed, got: {e}");
    }
    let r2 = conn.query("SELECT 1");
    assert!(r2.is_err());
}
1268
/// A query must fail when the server hangs up instead of replying.
#[test]
fn connect_query_server_rst_returns_io_or_connection_error() {
    let mut conn = Connect::new(mock_config(spawn_rst_on_query_server())).unwrap();
    let result = conn.query("SELECT 1");
    assert!(result.is_err());
}
1276
/// A server that reads the query and then goes silent must make `query` fail
/// with a timeout/retry error once the stream's read timeout is tiny.
#[test]
fn connect_read_would_block_max_retries() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut peer, _) = listener.accept().unwrap();
        let mut scratch = [0u8; 4096];
        let _ = peer.read(&mut scratch);
        let _ = peer.write_all(&pg_auth(3, &[]));
        let _ = peer.read(&mut scratch);
        let _ = peer.write_all(&post_auth_ok());
        // Read the query, then stay connected but never answer.
        let _ = peer.read(&mut scratch);
        thread::sleep(Duration::from_secs(5));
    });
    thread::sleep(Duration::from_millis(30));

    let mut conn = Connect::new(mock_config(port)).unwrap();
    // With a 1 ms timeout, reads on the silent socket time out almost immediately.
    let _ = conn
        .stream
        .set_read_timeout(Some(Duration::from_millis(1)));
    let result = conn.query("SELECT 1");
    assert!(result.is_err());

    let err_str = result.unwrap_err().to_string();
    // Accept the Chinese ("超时" = timeout, "重试" = retry) or English wording.
    let looks_like_timeout = ["超时", "Timeout", "重试"]
        .iter()
        .any(|needle| err_str.contains(*needle));
    assert!(
        looks_like_timeout,
        "expected timeout error, got: {err_str}"
    );
}
1305
/// A frame whose declared length exceeds the client's maximum message size
/// must be rejected with a size-related error.
///
/// The mock answers the query with 256 `'X'` bytes. Read as a backend frame
/// header, the first five bytes are message type `'X'` followed by the
/// big-endian length `0x5858_5858` (1_482_184_792 bytes).
#[test]
fn connect_read_exceeds_max_message_size() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut peer, _) = listener.accept().unwrap();
        let mut scratch = [0u8; 4096];
        let _ = peer.read(&mut scratch);
        let _ = peer.write_all(&pg_auth(3, &[]));
        let _ = peer.read(&mut scratch);
        let _ = peer.write_all(&post_auth_ok());
        let _ = peer.read(&mut scratch);
        let oversized_header = [b'X'; 256];
        let _ = peer.write_all(&oversized_header);
        // Presumably keeps the socket open so the failure comes from the size
        // check rather than from EOF.
        thread::sleep(Duration::from_secs(2));
    });
    thread::sleep(Duration::from_millis(30));

    let mut conn = Connect::new(mock_config(port)).unwrap();
    let result = conn.query("SELECT 1");
    assert!(result.is_err());

    let err_str = result.unwrap_err().to_string();
    // "最大" = maximum, "大小" = size; the English fallback is "size".
    let mentions_size = ["最大", "大小", "size"]
        .iter()
        .any(|needle| err_str.contains(*needle));
    assert!(
        mentions_size,
        "expected max message size error, got: {err_str}"
    );
}
1333
/// The client must give up on a read that keeps receiving bytes but never
/// completes a message.
///
/// The mock sends a well-formed backend header (type `'D'`) announcing a
/// small `BODY_LEN`-byte body, then trickles that body one byte every 5 ms
/// (about 1 s in total). Raw `'X'` bytes are deliberately not used here.
/// Five of them already form a header whose length field is `0x5858_5858`,
/// which trips the maximum-message-size check (the input
/// `connect_read_exceeds_max_message_size` relies on) long before any
/// deadline could fire.
///
/// NOTE(review): this isolates the deadline path only if `mock_config`'s
/// read deadline is shorter than the ~1 s trickle — confirm. If it is
/// longer, the mock closes the socket afterwards and the query still fails.
#[test]
fn connect_read_deadline_timeout() {
    const BODY_LEN: u32 = 200;

    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut s, _) = listener.accept().unwrap();
        let mut buf = [0u8; 4096];
        let _ = s.read(&mut buf);
        let _ = s.write_all(&pg_auth(3, &[]));
        let _ = s.read(&mut buf);
        let _ = s.write_all(&post_auth_ok());
        let _ = s.read(&mut buf);

        // Type byte plus an int32 length that counts itself (4) and the body.
        let mut header = vec![b'D'];
        header.extend_from_slice(&(4 + BODY_LEN).to_be_bytes());
        let _ = s.write_all(&header);

        // Keep the connection busy without ever finishing the frame quickly.
        for _ in 0..BODY_LEN {
            let _ = s.write_all(b"X");
            thread::sleep(Duration::from_millis(5));
        }
    });
    thread::sleep(Duration::from_millis(30));

    let mut conn = Connect::new(mock_config(port)).unwrap();
    let result = conn.query("SELECT 1");
    assert!(result.is_err());
}
1357
/// The client must reassemble an authentication frame that arrives in two
/// separate writes. The first write carries 5 bytes (presumably the type
/// byte plus the int32 length); the remainder follows about 50 ms later.
#[test]
fn connect_read_partial_auth_frame() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut s, _) = listener.accept().unwrap();
        let mut buf = [0u8; 4096];
        let _ = s.read(&mut buf);
        let auth = pg_auth(3, &[]);
        let _ = s.write_all(&auth[..5]);
        thread::sleep(Duration::from_millis(50));
        let _ = s.write_all(&auth[5..]);
        let _ = s.read(&mut buf);
        let _ = s.write_all(&post_auth_ok());
        loop {
            match s.read(&mut buf) {
                Ok(0) | Err(_) => break,
                Ok(_) => {
                    let _ = s.write_all(&simple_query_response());
                }
            }
        }
    });
    thread::sleep(Duration::from_millis(30));

    let mut conn = Connect::new(mock_config(port)).unwrap();
    // Report the driver's error text rather than a bare `is_ok()` failure.
    if let Err(e) = conn.query("SELECT 1") {
        panic!("query after split auth frame failed: {e}");
    }
}
1387 fn portal_response(rows: u16) -> Vec<u8> {
1391 let mut r = Vec::new();
1392 r.extend(pg_msg(b'1', &[])); r.extend(pg_msg(b'2', &[])); let mut rd = Vec::new();
1396 rd.extend(&1u16.to_be_bytes());
1397 rd.extend(b"id\x00");
1398 rd.extend(&0u32.to_be_bytes());
1399 rd.extend(&1u16.to_be_bytes());
1400 rd.extend(&23u32.to_be_bytes());
1401 rd.extend(&4i16.to_be_bytes());
1402 rd.extend(&(-1i32).to_be_bytes());
1403 rd.extend(&0u16.to_be_bytes());
1404 r.extend(pg_msg(b'T', &rd));
1405 for i in 0..rows {
1406 let val = format!("{}", i + 1);
1407 let mut dr = Vec::new();
1408 dr.extend(&1u16.to_be_bytes());
1409 dr.extend(&(val.len() as u32).to_be_bytes());
1410 dr.extend(val.as_bytes());
1411 r.extend(pg_msg(b'D', &dr));
1412 }
1413 r.extend(pg_msg(b's', &[]));
1415 r.extend(ready_for_query());
1416 r
1417 }
1418 fn portal_complete_response(rows: u16) -> Vec<u8> {
1421 let mut r = Vec::new();
1422 for i in 0..rows {
1423 let val = format!("{}", i + 1);
1424 let mut dr = Vec::new();
1425 dr.extend(&1u16.to_be_bytes());
1426 dr.extend(&(val.len() as u32).to_be_bytes());
1427 dr.extend(val.as_bytes());
1428 r.extend(pg_msg(b'D', &dr));
1429 }
1430 r.extend(pg_msg(b'C', b"SELECT 2\x00"));
1431 r.extend(ready_for_query());
1432 r
1433 }
1434 fn close_portal_response() -> Vec<u8> {
1436 let mut r = Vec::new();
1437 r.extend(pg_msg(b'3', &[])); r.extend(ready_for_query());
1439 r
1440 }
1441 fn spawn_portal_server() -> u16 {
1442 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1443 let port = listener.local_addr().unwrap().port();
1444 thread::spawn(move || {
1445 let (mut s, _) = listener.accept().unwrap();
1446 let mut buf = [0u8; 4096];
1447 let _ = s.read(&mut buf).unwrap();
1449 let _ = s.write_all(&pg_auth(3, &[]));
1450 let _ = s.read(&mut buf).unwrap();
1451 let _ = s.write_all(&post_auth_ok());
1452 match s.read(&mut buf) {
1454 Ok(0) | Err(_) => (),
1455 Ok(_) => {
1456 let _ = s.write_all(&portal_response(2));
1457 }
1458 }
1459 match s.read(&mut buf) {
1461 Ok(0) | Err(_) => (),
1462 Ok(_) => {
1463 let _ = s.write_all(&portal_complete_response(1));
1464 }
1465 }
1466 match s.read(&mut buf) {
1468 Ok(0) | Err(_) => (),
1469 Ok(_) => {
1470 let _ = s.write_all(&close_portal_response());
1471 }
1472 }
1473 });
1474 thread::sleep(Duration::from_millis(30));
1475 port
1476 }
/// With `sslmode = "prefer"`, a server that answers the client's first
/// message with `'N'` (SSL declined) must not abort the connection. The mock
/// continues the plaintext startup on the same socket, and connecting must
/// succeed.
#[test]
fn ssl_prefer_fallback_on_rejection() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut s, _) = listener.accept().unwrap();
        let mut buf = [0u8; 4096];
        // First client message (presumably the SSLRequest) -> decline.
        let _ = s.read(&mut buf);
        let _ = s.write_all(b"N");
        // Plaintext startup continues on the same connection.
        let _ = s.read(&mut buf);
        let _ = s.write_all(&pg_auth(3, &[]));
        let _ = s.read(&mut buf);
        let _ = s.write_all(&post_auth_ok());
        loop {
            match s.read(&mut buf) {
                Ok(0) | Err(_) => break,
                Ok(_) => {
                    let _ = s.write_all(&simple_query_response());
                }
            }
        }
    });
    thread::sleep(Duration::from_millis(30));
    let mut cfg = mock_config(port);
    cfg.sslmode = "prefer".into();
    // `expect` reports the connect error instead of a bare `is_ok()` failure.
    let _conn = Connect::new(cfg)
        .expect("sslmode=prefer should fall back to plaintext after the server answers 'N'");
}
/// With `sslmode = "require"`, a server that declines SSL (`'N'`) must make
/// the connection attempt fail.
#[test]
fn ssl_require_rejected_returns_error() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut stream, _) = listener.accept().unwrap();
        let mut request = [0u8; 4096];
        let _ = stream.read(&mut request);
        // Decline SSL; under `require` the client must not fall back.
        let _ = stream.write_all(b"N");
    });
    thread::sleep(Duration::from_millis(30));
    let cfg = {
        let mut c = mock_config(port);
        c.sslmode = "require".into();
        c
    };
    let result = Connect::new(cfg);
    assert!(result.is_err());
}
/// An SSL negotiation reply that is neither `'S'` nor `'N'` (here `'X'`) must
/// make the connection attempt fail, even under `sslmode = "prefer"`.
#[test]
fn ssl_invalid_response_byte() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    thread::spawn(move || {
        let (mut stream, _) = listener.accept().unwrap();
        let mut request = [0u8; 4096];
        let _ = stream.read(&mut request);
        // Reply with a byte outside the protocol's allowed answers.
        let _ = stream.write_all(b"X");
    });
    thread::sleep(Duration::from_millis(30));
    let cfg = {
        let mut c = mock_config(port);
        c.sslmode = "prefer".into();
        c
    };
    let result = Connect::new(cfg);
    assert!(result.is_err());
}
/// With `sslmode = "disable"`, connecting to the plain cleartext mock must
/// succeed.
#[test]
fn ssl_disable_skips_ssl_handshake() {
    let port = spawn_cleartext_server();
    let mut cfg = mock_config(port);
    cfg.sslmode = "disable".into();
    // `expect` reports the connect error instead of a bare `is_ok()` failure.
    let _conn = Connect::new(cfg)
        .expect("sslmode=disable should connect without SSL negotiation");
}
/// The first portal batch from `spawn_portal_server` (two DataRows followed
/// by PortalSuspended) must yield two rows and report `has_more`.
#[test]
fn connect_query_portal_returns_rows_with_has_more() {
    let port = spawn_portal_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    // A single `expect` replaces `assert!(is_ok())` + `unwrap()` and shows the error.
    let msg = conn
        .query_portal("SELECT id FROM t", 2)
        .expect("query_portal should succeed against the portal mock");
    assert_eq!(msg.rows.len(), 2);
    assert!(msg.has_more);
}
/// After the first portal batch, `fetch_more` must return the mock's final
/// batch (one DataRow followed by CommandComplete) with `has_more == false`.
#[test]
fn connect_fetch_more_returns_remaining_rows() {
    let port = spawn_portal_server();
    let mut conn = Connect::new(mock_config(port)).unwrap();
    let _ = conn.query_portal("SELECT id FROM t", 2).unwrap();
    // A single `expect` replaces `assert!(is_ok())` + `unwrap()` and shows the error.
    let msg = conn
        .fetch_more(10)
        .expect("fetch_more should succeed after a suspended portal");
    assert_eq!(msg.rows.len(), 1);
    assert!(!msg.has_more);
}
/// Closing the portal after it has been drained must succeed.
#[test]
fn connect_close_portal_succeeds() {
    let mut conn = Connect::new(mock_config(spawn_portal_server())).unwrap();
    // The mock answers Close only as its third scripted reply, so the first
    // batch and the follow-up fetch must be consumed first.
    let _ = conn.query_portal("SELECT id FROM t", 2).unwrap();
    let _ = conn.fetch_more(10).unwrap();
    let result = conn.close_portal();
    assert!(result.is_ok());
}
1580}