1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::{SocketAddr, TcpStream};
6use std::time::{Duration, Instant};
7
/// Byte stream carrying the PostgreSQL protocol for a `Connect`.
///
/// Either a plain TCP socket or, when the `tls` feature is enabled, a TLS session
/// running over a TCP socket. `Read`/`Write` are implemented by forwarding each call
/// to whichever variant is active.
#[derive(Debug)]
pub(crate) enum PgStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS-encrypted connection over TCP; boxed (presumably to keep the enum small).
    #[cfg(feature = "tls")]
    Tls(Box<native_tls::TlsStream<TcpStream>>),
}
15
16impl Read for PgStream {
17 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
18 match self {
19 PgStream::Plain(s) => s.read(buf),
20 #[cfg(feature = "tls")]
21 PgStream::Tls(s) => s.read(buf),
22 }
23 }
24}
25
26impl Write for PgStream {
27 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
28 match self {
29 PgStream::Plain(s) => s.write(buf),
30 #[cfg(feature = "tls")]
31 PgStream::Tls(s) => s.write(buf),
32 }
33 }
34 fn flush(&mut self) -> std::io::Result<()> {
35 match self {
36 PgStream::Plain(s) => s.flush(),
37 #[cfg(feature = "tls")]
38 PgStream::Tls(s) => s.flush(),
39 }
40 }
41}
42
43impl PgStream {
44 fn peer_addr(&self) -> std::io::Result<SocketAddr> {
45 match self {
46 PgStream::Plain(s) => s.peer_addr(),
47 #[cfg(feature = "tls")]
48 PgStream::Tls(s) => s.get_ref().peer_addr(),
49 }
50 }
51 fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
52 match self {
53 PgStream::Plain(s) => s.shutdown(how),
54 #[cfg(feature = "tls")]
55 PgStream::Tls(s) => s.get_ref().shutdown(how),
56 }
57 }
58 #[allow(dead_code)]
59 fn set_read_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
60 match self {
61 PgStream::Plain(s) => s.set_read_timeout(dur),
62 #[cfg(feature = "tls")]
63 PgStream::Tls(s) => s.get_ref().set_read_timeout(dur),
64 }
65 }
66}
67
/// A single client connection to a PostgreSQL server, including its protocol state.
#[derive(Debug)]
pub struct Connect {
    /// Transport the protocol runs over (plain TCP or TLS).
    pub(crate) stream: PgStream,
    /// Peer address captured right after the TCP connect, before any TLS upgrade.
    _peer_addr: SocketAddr,
    /// Builds outgoing protocol messages (`pack_*`) and decodes replies (`unpack`).
    packet: Packet,
    /// Whether authentication has completed; `read` uses it to decide where a reply ends.
    auth_status: AuthStatus,
    /// Last time the connection was used; drives the idle check in `is_valid`.
    last_used: Instant,
    /// When this value was constructed; reported by `age`.
    created_at: Instant,
}
80
81impl Connect {
82 pub fn is_valid(&mut self) -> bool {
84 if self.stream.peer_addr().is_err() {
85 return false;
86 }
87 #[cfg(not(test))]
88 const IDLE_THRESHOLD: Duration = Duration::from_secs(5);
89 #[cfg(test)]
90 const IDLE_THRESHOLD: Duration = Duration::from_millis(0);
91 if self.last_used.elapsed() > IDLE_THRESHOLD {
92 return self.query("SELECT 1").is_ok();
93 }
94 true
95 }
96
97 pub fn peer_valid(&self) -> bool {
99 self.stream.peer_addr().is_ok()
100 }
101
102 pub fn touch(&mut self) {
104 self.last_used = Instant::now();
105 }
106
107 pub fn idle_elapsed(&self) -> Duration {
109 self.last_used.elapsed()
110 }
111
112 pub fn age(&self) -> Duration {
114 self.created_at.elapsed()
115 }
116
117 pub fn _close(&mut self) {
118 let _ = self.stream.write_all(&Packet::pack_terminate());
119 let _ = self.stream.shutdown(std::net::Shutdown::Both);
120 }
121
122 fn set_keepalive(stream: &TcpStream) -> Result<(), PgsqlError> {
124 let keepalive = socket2::TcpKeepalive::new()
125 .with_time(Duration::from_secs(60))
126 .with_interval(Duration::from_secs(15));
127 #[cfg(not(target_os = "windows"))]
128 let keepalive = keepalive.with_retries(3);
129
130 let socket = socket2::SockRef::from(stream);
131 socket
132 .set_tcp_keepalive(&keepalive)
133 .map_err(|e| PgsqlError::Connection(format!("设置 TCP Keepalive 失败: {}", e)))
134 }
135
136 fn try_ssl_upgrade(mut stream: TcpStream, config: &Config) -> Result<PgStream, PgsqlError> {
138 stream
140 .write_all(&Packet::pack_ssl_request())
141 .map_err(|e| PgsqlError::Connection(format!("发送 SSLRequest 失败: {}", e)))?;
142 let mut resp = [0u8; 1];
143 stream
144 .read_exact(&mut resp)
145 .map_err(|e| PgsqlError::Connection(format!("读取 SSL 响应失败: {}", e)))?;
146 match resp[0] {
147 b'S' => {
148 #[cfg(feature = "tls")]
150 {
151 let connector = native_tls::TlsConnector::builder()
152 .danger_accept_invalid_certs(true)
153 .build()
154 .map_err(|e| PgsqlError::Connection(format!("TLS 初始化失败: {}", e)))?;
155 let tls_stream = connector
156 .connect(&config.hostname, stream)
157 .map_err(|e| PgsqlError::Connection(format!("TLS 握手失败: {}", e)))?;
158 Ok(PgStream::Tls(Box::new(tls_stream)))
159 }
160 #[cfg(not(feature = "tls"))]
161 {
162 let _ = config;
163 Err(PgsqlError::Connection(
164 "服务端要求 SSL 但未启用 tls feature".into(),
165 ))
166 }
167 }
168 b'N' => {
169 if config.sslmode == "require" {
171 Err(PgsqlError::Connection(
172 "sslmode=require 但服务端不支持 SSL".into(),
173 ))
174 } else {
175 Ok(PgStream::Plain(stream))
176 }
177 }
178 other => Err(PgsqlError::Connection(format!(
179 "无效的 SSL 响应字节: 0x{:02X}",
180 other
181 ))),
182 }
183 }
184 pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
185 let stream =
186 TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
187 stream
189 .set_nodelay(true)
190 .map_err(|e| PgsqlError::Connection(format!("设置 TCP_NODELAY 失败: {}", e)))?;
191 Self::set_keepalive(&stream)?;
193 stream
194 .set_read_timeout(Some(Duration::from_secs(30)))
195 .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
196 stream
197 .set_write_timeout(Some(Duration::from_secs(30)))
198 .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
199 let peer_addr = stream
200 .peer_addr()
201 .map_err(|e| PgsqlError::Connection(e.to_string()))?;
202
203 let stream = if config.sslmode != "disable" {
205 Self::try_ssl_upgrade(stream, &config)?
206 } else {
207 PgStream::Plain(stream)
208 };
209
210 let mut connect = Self {
211 stream,
212 _peer_addr: peer_addr,
213 packet: Packet::new(config),
214 auth_status: AuthStatus::None,
215 last_used: Instant::now(),
216 created_at: Instant::now(),
217 };
218
219 connect.authenticate()?;
220
221 Ok(connect)
222 }
223
224 fn authenticate(&mut self) -> Result<(), PgsqlError> {
225 self.stream
226 .write_all(&self.packet.pack_first())
227 .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
228
229 let data = self.read()?;
230 self.packet.unpack(data, 0)?;
231
232 if !self.packet.md5_salt.is_empty() {
233 self.md5_auth()?;
234 } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
235 self.cleartext_auth()?;
236 } else {
237 self.scram_auth()?;
238 }
239
240 self.auth_status = AuthStatus::AuthenticationOk;
241 Ok(())
242 }
243
244 fn md5_auth(&mut self) -> Result<(), PgsqlError> {
245 self.stream
246 .write_all(&self.packet.pack_md5_password())
247 .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
248
249 let data = self.read()?;
250 self.packet.unpack(data, 0)?;
251 Ok(())
252 }
253
254 fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
255 self.stream
256 .write_all(&self.packet.pack_cleartext_password())
257 .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
258
259 let data = self.read()?;
260 self.packet.unpack(data, 0)?;
261 Ok(())
262 }
263
264 fn scram_auth(&mut self) -> Result<(), PgsqlError> {
265 self.stream
266 .write_all(&self.packet.pack_auth())
267 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
268
269 let data = self.read()?;
270 self.packet.unpack(data, 0)?;
271
272 self.stream
273 .write_all(&self.packet.pack_auth_verify())
274 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
275
276 let data = self.read()?;
277 self.packet.unpack(data, 0)?;
278 Ok(())
279 }
280
281 fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
282 let mut msg = Vec::new();
283 let mut buf = [0u8; 4096];
284 let mut retry_count = 0;
285
286 #[cfg(not(test))]
287 const MAX_RETRIES: u32 = 100;
288 #[cfg(test)]
289 const MAX_RETRIES: u32 = 3;
290
291 #[cfg(not(test))]
292 const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
293 #[cfg(test)]
294 const MAX_MESSAGE_SIZE: usize = 128;
295
296 #[cfg(not(test))]
297 let deadline = std::time::Instant::now() + Duration::from_secs(300);
298 #[cfg(test)]
299 let deadline = std::time::Instant::now() + Duration::from_millis(200);
300
301 loop {
302 if std::time::Instant::now() >= deadline {
303 return Err(PgsqlError::Timeout("读取总超时".into()));
304 }
305
306 match self.stream.read(&mut buf) {
307 Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
308 Ok(n) => {
309 if msg.len() + n > MAX_MESSAGE_SIZE {
310 return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
311 }
312 msg.extend_from_slice(&buf[..n]);
313 retry_count = 0;
314 }
315 Err(ref e)
316 if e.kind() == std::io::ErrorKind::WouldBlock
317 || e.kind() == std::io::ErrorKind::TimedOut =>
318 {
319 retry_count += 1;
320 if retry_count > MAX_RETRIES {
321 return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
322 }
323 std::thread::sleep(Duration::from_millis(10));
324 continue;
325 }
326 Err(e) => return Err(PgsqlError::Io(e)),
327 };
328
329 if let AuthStatus::AuthenticationOk = self.auth_status {
330 if msg.ends_with(&[90, 0, 0, 0, 5, 73])
331 || msg.ends_with(&[90, 0, 0, 0, 5, 84])
332 || msg.ends_with(&[90, 0, 0, 0, 5, 69])
333 {
334 break;
335 }
336 } else if msg.len() >= 5 {
337 let len_bytes = &msg[1..=4];
338 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
339 if msg.len() > len as usize {
340 break;
341 }
342 }
343 }
344 }
345
346 Ok(msg)
347 }
348
349 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
350 self.stream
351 .write_all(&self.packet.pack_query(sql))
352 .map_err(PgsqlError::Io)?;
353 let data = self.read()?;
354 self.last_used = Instant::now();
355 self.packet.unpack(data, 0)
356 }
357
358 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
359 self.stream
360 .write_all(&self.packet.pack_execute(sql))
361 .map_err(PgsqlError::Io)?;
362 let data = self.read()?;
363 self.last_used = Instant::now();
364 self.packet.unpack(data, 0)
365 }
366
367 pub fn query_params(
369 &mut self,
370 sql: &str,
371 params: &[Option<&str>],
372 ) -> Result<SuccessMessage, PgsqlError> {
373 self.stream
374 .write_all(&self.packet.pack_query_params(sql, params))
375 .map_err(PgsqlError::Io)?;
376
377 let data = self.read()?;
378 self.last_used = Instant::now();
379 self.packet.unpack(data, 0)
380 }
381
382 pub fn execute_params(
384 &mut self,
385 sql: &str,
386 params: &[Option<&str>],
387 ) -> Result<SuccessMessage, PgsqlError> {
388 self.stream
389 .write_all(&self.packet.pack_execute_params(sql, params))
390 .map_err(PgsqlError::Io)?;
391 let data = self.read()?;
392 self.last_used = Instant::now();
393 self.packet.unpack(data, 0)
394 }
395
396 pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
398 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
399 self.query_params(sql, &opts)
400 }
401
402 pub fn execute_str(
404 &mut self,
405 sql: &str,
406 params: &[&str],
407 ) -> Result<SuccessMessage, PgsqlError> {
408 let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
409 self.execute_params(sql, &opts)
410 }
411 pub fn query_portal(&mut self, sql: &str, max_rows: u32) -> Result<SuccessMessage, PgsqlError> {
414 self.stream
415 .write_all(&self.packet.pack_query_portal(sql, max_rows))
416 .map_err(PgsqlError::Io)?;
417 let data = self.read()?;
418 self.last_used = Instant::now();
419 self.packet.unpack(data, 0)
420 }
421 pub fn fetch_more(&mut self, max_rows: u32) -> Result<SuccessMessage, PgsqlError> {
423 self.stream
424 .write_all(&self.packet.pack_fetch_more(max_rows))
425 .map_err(PgsqlError::Io)?;
426 let data = self.read()?;
427 self.last_used = Instant::now();
428 self.packet.unpack(data, 0)
429 }
430 pub fn close_portal(&mut self) -> Result<SuccessMessage, PgsqlError> {
432 self.stream
433 .write_all(&self.packet.pack_close_portal())
434 .map_err(PgsqlError::Io)?;
435 let data = self.read()?;
436 self.last_used = Instant::now();
437 self.packet.unpack(data, 0)
438 }
439}
440
441impl Drop for Connect {
442 fn drop(&mut self) {
443 let _ = self.stream.write_all(&Packet::pack_terminate());
444 let _ = self.stream.shutdown(std::net::Shutdown::Both);
445 }
446}
447
448#[cfg(test)]
449mod tests {
450 use super::*;
451 use std::net::TcpListener;
452 use std::thread;
453
454 fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
458 let mut m = Vec::with_capacity(5 + payload.len());
459 m.push(tag);
460 m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
461 m.extend_from_slice(payload);
462 m
463 }
464
465 fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
467 let mut body = Vec::new();
468 body.extend(&auth_type.to_be_bytes());
469 body.extend_from_slice(extra);
470 pg_msg(b'R', &body)
471 }
472
473 fn auth_ok() -> Vec<u8> {
475 pg_auth(0, &[])
476 }
477
478 fn param_status() -> Vec<u8> {
480 pg_msg(b'S', b"server_version\x0015.0\x00")
481 }
482
483 fn backend_key() -> Vec<u8> {
485 let mut p = Vec::new();
486 p.extend(&1u32.to_be_bytes());
487 p.extend(&2u32.to_be_bytes());
488 pg_msg(b'K', &p)
489 }
490
491 fn ready_for_query() -> Vec<u8> {
493 pg_msg(b'Z', b"I")
494 }
495
496 fn post_auth_ok() -> Vec<u8> {
498 let mut v = Vec::new();
499 v.extend(auth_ok());
500 v.extend(param_status());
501 v.extend(backend_key());
502 v.extend(ready_for_query());
503 v
504 }
505
506 fn simple_query_response() -> Vec<u8> {
509 let mut r = Vec::new();
510 r.extend(pg_msg(b'1', &[]));
512 r.extend(pg_msg(b'2', &[]));
514 let mut rd = Vec::new();
516 rd.extend(&1u16.to_be_bytes()); rd.extend(b"c\x00"); rd.extend(&0u32.to_be_bytes()); rd.extend(&1u16.to_be_bytes()); rd.extend(&23u32.to_be_bytes()); rd.extend(&4i16.to_be_bytes()); rd.extend(&(-1i32).to_be_bytes()); rd.extend(&0u16.to_be_bytes()); r.extend(pg_msg(b'T', &rd));
525 let mut dr = Vec::new();
527 dr.extend(&1u16.to_be_bytes());
528 dr.extend(&1u32.to_be_bytes()); dr.push(b'1');
530 r.extend(pg_msg(b'D', &dr));
531 r.extend(pg_msg(b'C', b"SELECT 1\x00"));
533 r.extend(ready_for_query());
535 r
536 }
537
538 fn execute_response() -> Vec<u8> {
541 let mut r = Vec::new();
542 r.extend(pg_msg(b'1', &[]));
543 r.extend(pg_msg(b'2', &[]));
544 r.extend(pg_msg(b'n', &[])); r.extend(pg_msg(b'C', b"UPDATE 3\x00"));
546 r.extend(ready_for_query());
547 r
548 }
549
550 fn query_params_response() -> Vec<u8> {
553 let mut r = Vec::new();
554 r.extend(pg_msg(b'1', &[]));
555
556 let mut pd = Vec::new();
557 pd.extend(&1u16.to_be_bytes());
558 pd.extend(&23u32.to_be_bytes());
559 r.extend(pg_msg(b't', &pd));
560
561 r.extend(pg_msg(b'2', &[]));
562
563 let mut rd = Vec::new();
564 rd.extend(&1u16.to_be_bytes());
565 rd.extend(b"p\x00");
566 rd.extend(&0u32.to_be_bytes());
567 rd.extend(&1u16.to_be_bytes());
568 rd.extend(&23u32.to_be_bytes());
569 rd.extend(&4i16.to_be_bytes());
570 rd.extend(&(-1i32).to_be_bytes());
571 rd.extend(&0u16.to_be_bytes());
572 r.extend(pg_msg(b'T', &rd));
573
574 let mut dr = Vec::new();
575 dr.extend(&1u16.to_be_bytes());
576 dr.extend(&2u32.to_be_bytes());
577 dr.extend(b"42");
578 r.extend(pg_msg(b'D', &dr));
579
580 r.extend(pg_msg(b'C', b"SELECT 1\x00"));
581 r.extend(ready_for_query());
582 r
583 }
584
585 fn execute_params_response() -> Vec<u8> {
588 let mut r = Vec::new();
589 r.extend(pg_msg(b'1', &[]));
590
591 let mut pd = Vec::new();
592 pd.extend(&1u16.to_be_bytes());
593 pd.extend(&23u32.to_be_bytes());
594 r.extend(pg_msg(b't', &pd));
595
596 r.extend(pg_msg(b'2', &[]));
597 r.extend(pg_msg(b'n', &[]));
598 r.extend(pg_msg(b'C', b"UPDATE 1\x00"));
599 r.extend(ready_for_query());
600 r
601 }
602
603 fn query_params_null_response() -> Vec<u8> {
605 let mut r = Vec::new();
606 r.extend(pg_msg(b'1', &[]));
607
608 let mut pd = Vec::new();
609 pd.extend(&1u16.to_be_bytes());
610 pd.extend(&25u32.to_be_bytes());
611 r.extend(pg_msg(b't', &pd));
612
613 r.extend(pg_msg(b'2', &[]));
614
615 let mut rd = Vec::new();
616 rd.extend(&1u16.to_be_bytes());
617 rd.extend(b"n\x00");
618 rd.extend(&0u32.to_be_bytes());
619 rd.extend(&1u16.to_be_bytes());
620 rd.extend(&25u32.to_be_bytes());
621 rd.extend(&(-1i16).to_be_bytes());
622 rd.extend(&(-1i32).to_be_bytes());
623 rd.extend(&0u16.to_be_bytes());
624 r.extend(pg_msg(b'T', &rd));
625
626 let mut dr = Vec::new();
627 dr.extend(&1u16.to_be_bytes());
628 dr.extend(&(-1i32).to_be_bytes());
629 r.extend(pg_msg(b'D', &dr));
630
631 r.extend(pg_msg(b'C', b"SELECT 1\x00"));
632 r.extend(ready_for_query());
633 r
634 }
635
636 fn error_response() -> Vec<u8> {
638 let mut payload = Vec::new();
639 payload.push(b'C');
640 payload.extend(b"42601\x00");
641 payload.push(b'M');
642 payload.extend(b"syntax error\x00");
643 payload.push(0);
644 let mut r = Vec::new();
645 r.extend(pg_msg(b'1', &[]));
646 r.extend(pg_msg(b'2', &[]));
647 r.extend(pg_msg(b'E', &payload));
648 r.extend(ready_for_query());
649 r
650 }
651
652 fn mock_config(port: u16) -> Config {
656 Config {
657 debug: false,
658 hostname: "127.0.0.1".into(),
659 hostport: port as i32,
660 username: "u".into(),
661 userpass: "p".into(),
662 database: "d".into(),
663 charset: "utf8".into(),
664 pool_max: 5,
665 sslmode: "disable".into(),
666 }
667 }
668
669 fn spawn_cleartext_server() -> u16 {
672 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
673 let port = listener.local_addr().unwrap().port();
674 thread::spawn(move || {
675 let (mut s, _) = listener.accept().unwrap();
676 let mut buf = [0u8; 4096];
677 let _ = s.read(&mut buf).unwrap();
679 let _ = s.write_all(&pg_auth(3, &[]));
681 let _ = s.read(&mut buf).unwrap();
683 let _ = s.write_all(&post_auth_ok());
685 loop {
687 match s.read(&mut buf) {
688 Ok(0) | Err(_) => break,
689 Ok(_) => {
690 let _ = s.write_all(&simple_query_response());
691 }
692 }
693 }
694 });
695 thread::sleep(Duration::from_millis(30));
696 port
697 }
698
699 fn spawn_md5_server() -> u16 {
701 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
702 let port = listener.local_addr().unwrap().port();
703 thread::spawn(move || {
704 let (mut s, _) = listener.accept().unwrap();
705 let mut buf = [0u8; 4096];
706 let _ = s.read(&mut buf).unwrap();
708 let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
710 let _ = s.read(&mut buf).unwrap();
712 let _ = s.write_all(&post_auth_ok());
714 loop {
715 match s.read(&mut buf) {
716 Ok(0) | Err(_) => break,
717 Ok(_) => {
718 let _ = s.write_all(&simple_query_response());
719 }
720 }
721 }
722 });
723 thread::sleep(Duration::from_millis(30));
724 port
725 }
726
727 fn spawn_scram_server() -> u16 {
729 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
730 let port = listener.local_addr().unwrap().port();
731 thread::spawn(move || {
732 let (mut s, _) = listener.accept().unwrap();
733 let mut buf = [0u8; 4096];
734 let _ = s.read(&mut buf).unwrap();
736 let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
738 let n = s.read(&mut buf).unwrap();
740 let payload = &buf[..n];
741 let text = String::from_utf8_lossy(payload);
743 let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
744 let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
746 let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
747 let _ = s.read(&mut buf).unwrap();
749 let mut resp = Vec::new();
751 resp.extend(pg_auth(12, b"v=dummyproof"));
752 resp.extend(auth_ok());
753 resp.extend(param_status());
754 resp.extend(backend_key());
755 resp.extend(ready_for_query());
756 let _ = s.write_all(&resp);
757 loop {
758 match s.read(&mut buf) {
759 Ok(0) | Err(_) => break,
760 Ok(_) => {
761 let _ = s.write_all(&simple_query_response());
762 }
763 }
764 }
765 });
766 thread::sleep(Duration::from_millis(30));
767 port
768 }
769
770 fn spawn_eof_server() -> u16 {
772 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
773 let port = listener.local_addr().unwrap().port();
774 thread::spawn(move || {
775 let (s, _) = listener.accept().unwrap();
776 drop(s); });
778 thread::sleep(Duration::from_millis(30));
779 port
780 }
781
782 fn spawn_auth_error_server() -> u16 {
784 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
785 let port = listener.local_addr().unwrap().port();
786 thread::spawn(move || {
787 let (mut s, _) = listener.accept().unwrap();
788 let mut buf = [0u8; 4096];
789 let _ = s.read(&mut buf).unwrap();
790 let mut payload = Vec::new();
792 payload.push(b'C');
793 payload.extend(b"28P01\x00");
794 payload.push(b'M');
795 payload.extend(b"password authentication failed\x00");
796 payload.push(0);
797 let _ = s.write_all(&pg_msg(b'E', &payload));
798 });
799 thread::sleep(Duration::from_millis(30));
800 port
801 }
802
803 fn spawn_query_error_server() -> u16 {
805 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
806 let port = listener.local_addr().unwrap().port();
807 thread::spawn(move || {
808 let (mut s, _) = listener.accept().unwrap();
809 let mut buf = [0u8; 4096];
810 let _ = s.read(&mut buf).unwrap();
812 let _ = s.write_all(&pg_auth(3, &[]));
813 let _ = s.read(&mut buf).unwrap();
814 let _ = s.write_all(&post_auth_ok());
815 loop {
817 match s.read(&mut buf) {
818 Ok(0) | Err(_) => break,
819 Ok(_) => {
820 let _ = s.write_all(&error_response());
821 }
822 }
823 }
824 });
825 thread::sleep(Duration::from_millis(30));
826 port
827 }
828
829 fn spawn_execute_server() -> u16 {
831 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
832 let port = listener.local_addr().unwrap().port();
833 thread::spawn(move || {
834 let (mut s, _) = listener.accept().unwrap();
835 let mut buf = [0u8; 4096];
836 let _ = s.read(&mut buf).unwrap();
838 let _ = s.write_all(&pg_auth(3, &[]));
839 let _ = s.read(&mut buf).unwrap();
840 let _ = s.write_all(&post_auth_ok());
841 loop {
843 match s.read(&mut buf) {
844 Ok(0) | Err(_) => break,
845 Ok(_) => {
846 let _ = s.write_all(&execute_response());
847 }
848 }
849 }
850 });
851 thread::sleep(Duration::from_millis(30));
852 port
853 }
854
855 fn spawn_query_params_server() -> u16 {
857 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
858 let port = listener.local_addr().unwrap().port();
859 thread::spawn(move || {
860 let (mut s, _) = listener.accept().unwrap();
861 let mut buf = [0u8; 4096];
862 let _ = s.read(&mut buf).unwrap();
863 let _ = s.write_all(&pg_auth(3, &[]));
864 let _ = s.read(&mut buf).unwrap();
865 let _ = s.write_all(&post_auth_ok());
866 loop {
867 match s.read(&mut buf) {
868 Ok(0) | Err(_) => break,
869 Ok(_) => {
870 let _ = s.write_all(&query_params_response());
871 }
872 }
873 }
874 });
875 thread::sleep(Duration::from_millis(30));
876 port
877 }
878
879 fn spawn_params_server() -> u16 {
880 spawn_query_params_server()
881 }
882
883 fn spawn_execute_params_server() -> u16 {
885 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
886 let port = listener.local_addr().unwrap().port();
887 thread::spawn(move || {
888 let (mut s, _) = listener.accept().unwrap();
889 let mut buf = [0u8; 4096];
890 let _ = s.read(&mut buf).unwrap();
891 let _ = s.write_all(&pg_auth(3, &[]));
892 let _ = s.read(&mut buf).unwrap();
893 let _ = s.write_all(&post_auth_ok());
894 loop {
895 match s.read(&mut buf) {
896 Ok(0) | Err(_) => break,
897 Ok(_) => {
898 let _ = s.write_all(&execute_params_response());
899 }
900 }
901 }
902 });
903 thread::sleep(Duration::from_millis(30));
904 port
905 }
906
907 fn spawn_query_params_null_server() -> u16 {
909 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
910 let port = listener.local_addr().unwrap().port();
911 thread::spawn(move || {
912 let (mut s, _) = listener.accept().unwrap();
913 let mut buf = [0u8; 4096];
914 let _ = s.read(&mut buf).unwrap();
915 let _ = s.write_all(&pg_auth(3, &[]));
916 let _ = s.read(&mut buf).unwrap();
917 let _ = s.write_all(&post_auth_ok());
918 loop {
919 match s.read(&mut buf) {
920 Ok(0) | Err(_) => break,
921 Ok(_) => {
922 let _ = s.write_all(&query_params_null_response());
923 }
924 }
925 }
926 });
927 thread::sleep(Duration::from_millis(30));
928 port
929 }
930
931 #[test]
934 fn connect_cleartext_auth_success() {
935 let port = spawn_cleartext_server();
936 let conn = Connect::new(mock_config(port));
937 assert!(conn.is_ok());
938 }
939
940 #[test]
941 fn connect_md5_auth_success() {
942 let port = spawn_md5_server();
943 let conn = Connect::new(mock_config(port));
944 assert!(conn.is_ok());
945 }
946
947 #[test]
948 fn connect_scram_auth_success() {
949 let port = spawn_scram_server();
950 let conn = Connect::new(mock_config(port));
951 assert!(conn.is_ok());
952 }
953
954 #[test]
955 fn connect_connection_refused() {
956 let cfg = mock_config(1);
958 let result = Connect::new(cfg);
959 assert!(result.is_err());
960 match result.unwrap_err() {
961 PgsqlError::Connection(_) => {}
962 other => panic!("expected Connection error, got {other:?}"),
963 }
964 }
965
966 #[test]
967 fn connect_server_closes_immediately() {
968 let port = spawn_eof_server();
969 let result = Connect::new(mock_config(port));
970 assert!(result.is_err());
971 }
972
973 #[test]
974 fn connect_auth_error_from_server() {
975 let port = spawn_auth_error_server();
976 let result = Connect::new(mock_config(port));
977 assert!(result.is_err());
978 }
979
980 #[test]
981 fn connect_query_success() {
982 let port = spawn_cleartext_server();
983 let mut conn = Connect::new(mock_config(port)).unwrap();
984 let result = conn.query("SELECT 1");
985 assert!(result.is_ok());
986 let msg = result.unwrap();
987 assert_eq!(msg.rows.len(), 1);
988 assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
989 }
990
991 #[test]
992 fn connect_execute_success() {
993 let port = spawn_execute_server();
994 let mut conn = Connect::new(mock_config(port)).unwrap();
995 let result = conn.execute("UPDATE t SET x=1");
996 assert!(result.is_ok());
997 let msg = result.unwrap();
998 assert_eq!(msg.affect_count, 3);
999 assert_eq!(msg.tag, "UPDATE 3");
1000 }
1001
1002 #[test]
1003 fn connect_query_params_success() {
1004 let port = spawn_query_params_server();
1005 let mut conn = Connect::new(mock_config(port)).unwrap();
1006 let result = conn.query_params("SELECT $1::int", &[Some("42")]);
1007 assert!(result.is_ok());
1008 let msg = result.unwrap();
1009 assert!(!msg.param_oids.is_empty());
1010 assert_eq!(msg.rows.len(), 1);
1011 assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
1012 }
1013
1014 #[test]
1015 fn connect_execute_params_success() {
1016 let port = spawn_execute_params_server();
1017 let mut conn = Connect::new(mock_config(port)).unwrap();
1018 let result = conn.execute_params("UPDATE t SET x=$1", &[Some("42")]);
1019 assert!(result.is_ok());
1020 let msg = result.unwrap();
1021 assert!(!msg.param_oids.is_empty());
1022 assert_eq!(msg.affect_count, 1);
1023 assert_eq!(msg.tag, "UPDATE 1");
1024 }
1025
1026 #[test]
1027 fn connect_query_str_success() {
1028 let port = spawn_params_server();
1029 let mut conn = Connect::new(mock_config(port)).unwrap();
1030 let result = conn.query_str("SELECT $1::int", &["42"]);
1031 assert!(result.is_ok());
1032 let msg = result.unwrap();
1033 assert!(!msg.param_oids.is_empty());
1034 assert_eq!(msg.rows.len(), 1);
1035 }
1036
1037 #[test]
1038 fn connect_execute_str_success() {
1039 let port = spawn_execute_params_server();
1040 let mut conn = Connect::new(mock_config(port)).unwrap();
1041 let result = conn.execute_str("UPDATE t SET x=$1", &["1"]);
1042 assert!(result.is_ok());
1043 let msg = result.unwrap();
1044 assert!(!msg.param_oids.is_empty());
1045 assert_eq!(msg.affect_count, 1);
1046 }
1047
1048 #[test]
1049 fn connect_query_params_with_null() {
1050 let port = spawn_query_params_null_server();
1051 let mut conn = Connect::new(mock_config(port)).unwrap();
1052 let result = conn.query_params("SELECT $1::text", &[None]);
1053 assert!(result.is_ok());
1054 let msg = result.unwrap();
1055 assert!(!msg.param_oids.is_empty());
1056 assert_eq!(msg.rows.len(), 1);
1057 assert_eq!(msg.rows[0]["n"], "");
1058 }
1059
1060 #[test]
1061 fn connect_query_params_empty_string_vs_null() {
1062 let port = spawn_params_server();
1063 let mut conn = Connect::new(mock_config(port)).unwrap();
1064
1065 let r1 = conn.query_params("SELECT $1::text", &[Some("")]);
1067 assert!(r1.is_ok());
1068
1069 let r2 = conn.query_params("SELECT $1::text", &[None]);
1071 assert!(r2.is_ok());
1072 }
1073
1074 #[test]
1075 fn connect_query_returns_error() {
1076 let port = spawn_query_error_server();
1077 let mut conn = Connect::new(mock_config(port)).unwrap();
1078 let result = conn.query("BAD SQL");
1079 assert!(result.is_err());
1080 }
1081
1082 #[test]
1083 fn connect_is_valid_true() {
1084 let port = spawn_cleartext_server();
1085 let mut conn = Connect::new(mock_config(port)).unwrap();
1086 assert!(conn.is_valid());
1087 }
1088
1089 #[test]
1090 fn connect_is_valid_false_after_close() {
1091 let port = spawn_cleartext_server();
1092 let mut conn = Connect::new(mock_config(port)).unwrap();
1093 conn._close();
1094 assert!(!conn.is_valid());
1096 }
1097
1098 #[test]
1099 fn connect_close_does_not_panic() {
1100 let port = spawn_cleartext_server();
1101 let mut conn = Connect::new(mock_config(port)).unwrap();
1102 conn._close();
1103 conn._close();
1105 }
1106
1107 #[test]
1108 fn connect_drop_does_not_panic() {
1109 let port = spawn_cleartext_server();
1110 let conn = Connect::new(mock_config(port)).unwrap();
1111 drop(conn);
1112 }
1113
1114 fn spawn_transaction_status_server() -> u16 {
1115 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1116 let port = listener.local_addr().unwrap().port();
1117 thread::spawn(move || {
1118 let (mut s, _) = listener.accept().unwrap();
1119 let mut buf = [0u8; 4096];
1120 let _ = s.read(&mut buf).unwrap();
1121 let _ = s.write_all(&pg_auth(3, &[]));
1122 let _ = s.read(&mut buf).unwrap();
1123 let _ = s.write_all(&post_auth_ok());
1124 loop {
1125 match s.read(&mut buf) {
1126 Ok(0) | Err(_) => break,
1127 Ok(_) => {
1128 let mut r = Vec::new();
1129 r.extend(pg_msg(b'1', &[]));
1130 r.extend(pg_msg(b'2', &[]));
1131 let mut rd = Vec::new();
1132 rd.extend(&1u16.to_be_bytes());
1133 rd.extend(b"c\x00");
1134 rd.extend(&0u32.to_be_bytes());
1135 rd.extend(&1u16.to_be_bytes());
1136 rd.extend(&23u32.to_be_bytes());
1137 rd.extend(&4i16.to_be_bytes());
1138 rd.extend(&(-1i32).to_be_bytes());
1139 rd.extend(&0u16.to_be_bytes());
1140 r.extend(pg_msg(b'T', &rd));
1141 let mut dr = Vec::new();
1142 dr.extend(&1u16.to_be_bytes());
1143 dr.extend(&1u32.to_be_bytes());
1144 dr.push(b'1');
1145 r.extend(pg_msg(b'D', &dr));
1146 r.extend(pg_msg(b'C', b"SELECT 1\x00"));
1147 r.extend(pg_msg(b'Z', b"T"));
1148 let _ = s.write_all(&r);
1149 }
1150 }
1151 }
1152 });
1153 thread::sleep(Duration::from_millis(30));
1154 port
1155 }
1156
1157 fn spawn_error_status_server() -> u16 {
1158 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1159 let port = listener.local_addr().unwrap().port();
1160 thread::spawn(move || {
1161 let (mut s, _) = listener.accept().unwrap();
1162 let mut buf = [0u8; 4096];
1163 let _ = s.read(&mut buf).unwrap();
1164 let _ = s.write_all(&pg_auth(3, &[]));
1165 let _ = s.read(&mut buf).unwrap();
1166 let _ = s.write_all(&post_auth_ok());
1167 loop {
1168 match s.read(&mut buf) {
1169 Ok(0) | Err(_) => break,
1170 Ok(_) => {
1171 let mut r = Vec::new();
1172 r.extend(pg_msg(b'1', &[]));
1173 r.extend(pg_msg(b'2', &[]));
1174 let mut rd = Vec::new();
1175 rd.extend(&1u16.to_be_bytes());
1176 rd.extend(b"c\x00");
1177 rd.extend(&0u32.to_be_bytes());
1178 rd.extend(&1u16.to_be_bytes());
1179 rd.extend(&23u32.to_be_bytes());
1180 rd.extend(&4i16.to_be_bytes());
1181 rd.extend(&(-1i32).to_be_bytes());
1182 rd.extend(&0u16.to_be_bytes());
1183 r.extend(pg_msg(b'T', &rd));
1184 let mut dr = Vec::new();
1185 dr.extend(&1u16.to_be_bytes());
1186 dr.extend(&1u32.to_be_bytes());
1187 dr.push(b'1');
1188 r.extend(pg_msg(b'D', &dr));
1189 r.extend(pg_msg(b'C', b"SELECT 1\x00"));
1190 r.extend(pg_msg(b'Z', b"E"));
1191 let _ = s.write_all(&r);
1192 }
1193 }
1194 }
1195 });
1196 thread::sleep(Duration::from_millis(30));
1197 port
1198 }
1199
1200 fn spawn_slow_partial_server() -> u16 {
1201 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1202 let port = listener.local_addr().unwrap().port();
1203 thread::spawn(move || {
1204 let (mut s, _) = listener.accept().unwrap();
1205 let mut buf = [0u8; 4096];
1206 let _ = s.read(&mut buf).unwrap();
1207 let _ = s.write_all(&pg_auth(3, &[]));
1208 let _ = s.read(&mut buf).unwrap();
1209 let _ = s.write_all(&post_auth_ok());
1210 match s.read(&mut buf) {
1211 Ok(0) | Err(_) => {}
1212 Ok(_) => {
1213 let _ = s.write_all(&simple_query_response());
1214 }
1215 }
1216 });
1217 thread::sleep(Duration::from_millis(30));
1218 port
1219 }
1220
1221 fn spawn_rst_on_query_server() -> u16 {
1222 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1223 let port = listener.local_addr().unwrap().port();
1224 thread::spawn(move || {
1225 let (mut s, _) = listener.accept().unwrap();
1226 let mut buf = [0u8; 4096];
1227 let _ = s.read(&mut buf).unwrap();
1228 let _ = s.write_all(&pg_auth(3, &[]));
1229 let _ = s.read(&mut buf).unwrap();
1230 let _ = s.write_all(&post_auth_ok());
1231 match s.read(&mut buf) {
1232 Ok(0) | Err(_) => {}
1233 Ok(_) => {
1234 drop(s);
1235 }
1236 }
1237 });
1238 thread::sleep(Duration::from_millis(30));
1239 port
1240 }
1241
1242 #[test]
1243 fn connect_query_ready_for_query_transaction_status() {
1244 let port = spawn_transaction_status_server();
1245 let mut conn = Connect::new(mock_config(port)).unwrap();
1246 let result = conn.query("SELECT 1");
1247 assert!(result.is_ok());
1248 }
1249
1250 #[test]
1251 fn connect_query_ready_for_query_error_status() {
1252 let port = spawn_error_status_server();
1253 let mut conn = Connect::new(mock_config(port)).unwrap();
1254 let result = conn.query("SELECT 1");
1255 assert!(result.is_ok());
1256 }
1257
1258 #[test]
1259 fn connect_query_server_closes_after_partial() {
1260 let port = spawn_slow_partial_server();
1261 let mut conn = Connect::new(mock_config(port)).unwrap();
1262 let r1 = conn.query("SELECT 1");
1263 assert!(r1.is_ok());
1264 let r2 = conn.query("SELECT 1");
1265 assert!(r2.is_err());
1266 }
1267
1268 #[test]
1269 fn connect_query_server_rst_returns_io_or_connection_error() {
1270 let port = spawn_rst_on_query_server();
1271 let mut conn = Connect::new(mock_config(port)).unwrap();
1272 let result = conn.query("SELECT 1");
1273 assert!(result.is_err());
1274 }
1275
1276 #[test]
1277 fn connect_read_would_block_max_retries() {
1278 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1279 let port = listener.local_addr().unwrap().port();
1280 thread::spawn(move || {
1281 let (mut s, _) = listener.accept().unwrap();
1282 let mut buf = [0u8; 4096];
1283 let _ = s.read(&mut buf);
1284 let _ = s.write_all(&pg_auth(3, &[]));
1285 let _ = s.read(&mut buf);
1286 let _ = s.write_all(&post_auth_ok());
1287 let _ = s.read(&mut buf);
1288 thread::sleep(Duration::from_secs(5));
1289 });
1290 thread::sleep(Duration::from_millis(30));
1291
1292 let mut conn = Connect::new(mock_config(port)).unwrap();
1293 conn.stream
1294 .set_read_timeout(Some(Duration::from_millis(1)))
1295 .ok();
1296 let result = conn.query("SELECT 1");
1297 assert!(result.is_err());
1298 let err_str = result.unwrap_err().to_string();
1299 assert!(
1300 err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
1301 "expected timeout error, got: {err_str}"
1302 );
1303 }
1304
1305 #[test]
1306 fn connect_read_exceeds_max_message_size() {
1307 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1308 let port = listener.local_addr().unwrap().port();
1309 thread::spawn(move || {
1310 let (mut s, _) = listener.accept().unwrap();
1311 let mut buf = [0u8; 4096];
1312 let _ = s.read(&mut buf);
1313 let _ = s.write_all(&pg_auth(3, &[]));
1314 let _ = s.read(&mut buf);
1315 let _ = s.write_all(&post_auth_ok());
1316 let _ = s.read(&mut buf);
1317 let big = vec![b'X'; 256];
1318 let _ = s.write_all(&big);
1319 thread::sleep(Duration::from_secs(2));
1320 });
1321 thread::sleep(Duration::from_millis(30));
1322
1323 let mut conn = Connect::new(mock_config(port)).unwrap();
1324 let result = conn.query("SELECT 1");
1325 assert!(result.is_err());
1326 let err_str = result.unwrap_err().to_string();
1327 assert!(
1328 err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
1329 "expected max message size error, got: {err_str}"
1330 );
1331 }
1332
1333 #[test]
1334 fn connect_read_deadline_timeout() {
1335 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1336 let port = listener.local_addr().unwrap().port();
1337 thread::spawn(move || {
1338 let (mut s, _) = listener.accept().unwrap();
1339 let mut buf = [0u8; 4096];
1340 let _ = s.read(&mut buf);
1341 let _ = s.write_all(&pg_auth(3, &[]));
1342 let _ = s.read(&mut buf);
1343 let _ = s.write_all(&post_auth_ok());
1344 let _ = s.read(&mut buf);
1345 for _ in 0..200 {
1346 let _ = s.write_all(b"X");
1347 thread::sleep(Duration::from_millis(5));
1348 }
1349 });
1350 thread::sleep(Duration::from_millis(30));
1351
1352 let mut conn = Connect::new(mock_config(port)).unwrap();
1353 let result = conn.query("SELECT 1");
1354 assert!(result.is_err());
1355 }
1356
1357 #[test]
1358 fn connect_read_partial_auth_frame() {
1359 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1360 let port = listener.local_addr().unwrap().port();
1361 thread::spawn(move || {
1362 let (mut s, _) = listener.accept().unwrap();
1363 let mut buf = [0u8; 4096];
1364 let _ = s.read(&mut buf);
1365 let auth = pg_auth(3, &[]);
1366 let _ = s.write_all(&auth[..5]);
1367 thread::sleep(Duration::from_millis(50));
1368 let _ = s.write_all(&auth[5..]);
1369 let _ = s.read(&mut buf);
1370 let _ = s.write_all(&post_auth_ok());
1371 loop {
1372 match s.read(&mut buf) {
1373 Ok(0) | Err(_) => break,
1374 Ok(_) => {
1375 let _ = s.write_all(&simple_query_response());
1376 }
1377 }
1378 }
1379 });
1380 thread::sleep(Duration::from_millis(30));
1381
1382 let mut conn = Connect::new(mock_config(port)).unwrap();
1383 let result = conn.query("SELECT 1");
1384 assert!(result.is_ok());
1385 }
1386 fn portal_response(rows: u16) -> Vec<u8> {
1390 let mut r = Vec::new();
1391 r.extend(pg_msg(b'1', &[])); r.extend(pg_msg(b'2', &[])); let mut rd = Vec::new();
1395 rd.extend(&1u16.to_be_bytes());
1396 rd.extend(b"id\x00");
1397 rd.extend(&0u32.to_be_bytes());
1398 rd.extend(&1u16.to_be_bytes());
1399 rd.extend(&23u32.to_be_bytes());
1400 rd.extend(&4i16.to_be_bytes());
1401 rd.extend(&(-1i32).to_be_bytes());
1402 rd.extend(&0u16.to_be_bytes());
1403 r.extend(pg_msg(b'T', &rd));
1404 for i in 0..rows {
1405 let val = format!("{}", i + 1);
1406 let mut dr = Vec::new();
1407 dr.extend(&1u16.to_be_bytes());
1408 dr.extend(&(val.len() as u32).to_be_bytes());
1409 dr.extend(val.as_bytes());
1410 r.extend(pg_msg(b'D', &dr));
1411 }
1412 r.extend(pg_msg(b's', &[]));
1414 r.extend(ready_for_query());
1415 r
1416 }
1417 fn portal_complete_response(rows: u16) -> Vec<u8> {
1420 let mut r = Vec::new();
1421 for i in 0..rows {
1422 let val = format!("{}", i + 1);
1423 let mut dr = Vec::new();
1424 dr.extend(&1u16.to_be_bytes());
1425 dr.extend(&(val.len() as u32).to_be_bytes());
1426 dr.extend(val.as_bytes());
1427 r.extend(pg_msg(b'D', &dr));
1428 }
1429 r.extend(pg_msg(b'C', b"SELECT 2\x00"));
1430 r.extend(ready_for_query());
1431 r
1432 }
1433 fn close_portal_response() -> Vec<u8> {
1435 let mut r = Vec::new();
1436 r.extend(pg_msg(b'3', &[])); r.extend(ready_for_query());
1438 r
1439 }
1440 fn spawn_portal_server() -> u16 {
1441 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1442 let port = listener.local_addr().unwrap().port();
1443 thread::spawn(move || {
1444 let (mut s, _) = listener.accept().unwrap();
1445 let mut buf = [0u8; 4096];
1446 let _ = s.read(&mut buf).unwrap();
1448 let _ = s.write_all(&pg_auth(3, &[]));
1449 let _ = s.read(&mut buf).unwrap();
1450 let _ = s.write_all(&post_auth_ok());
1451 match s.read(&mut buf) {
1453 Ok(0) | Err(_) => (),
1454 Ok(_) => {
1455 let _ = s.write_all(&portal_response(2));
1456 }
1457 }
1458 match s.read(&mut buf) {
1460 Ok(0) | Err(_) => (),
1461 Ok(_) => {
1462 let _ = s.write_all(&portal_complete_response(1));
1463 }
1464 }
1465 match s.read(&mut buf) {
1467 Ok(0) | Err(_) => (),
1468 Ok(_) => {
1469 let _ = s.write_all(&close_portal_response());
1470 }
1471 }
1472 });
1473 thread::sleep(Duration::from_millis(30));
1474 port
1475 }
1476 #[test]
1478 fn ssl_prefer_fallback_on_rejection() {
1479 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1480 let port = listener.local_addr().unwrap().port();
1481 thread::spawn(move || {
1482 let (mut s, _) = listener.accept().unwrap();
1483 let mut buf = [0u8; 4096];
1484 let _ = s.read(&mut buf);
1486 let _ = s.write_all(b"N");
1488 let _ = s.read(&mut buf);
1490 let _ = s.write_all(&pg_auth(3, &[]));
1491 let _ = s.read(&mut buf);
1492 let _ = s.write_all(&post_auth_ok());
1493 loop {
1494 match s.read(&mut buf) {
1495 Ok(0) | Err(_) => break,
1496 Ok(_) => {
1497 let _ = s.write_all(&simple_query_response());
1498 }
1499 }
1500 }
1501 });
1502 thread::sleep(Duration::from_millis(30));
1503 let mut cfg = mock_config(port);
1504 cfg.sslmode = "prefer".into();
1505 let conn = Connect::new(cfg);
1506 assert!(conn.is_ok());
1507 }
1508 #[test]
1509 fn ssl_require_rejected_returns_error() {
1510 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1511 let port = listener.local_addr().unwrap().port();
1512 thread::spawn(move || {
1513 let (mut s, _) = listener.accept().unwrap();
1514 let mut buf = [0u8; 4096];
1515 let _ = s.read(&mut buf);
1516 let _ = s.write_all(b"N");
1517 });
1518 thread::sleep(Duration::from_millis(30));
1519 let mut cfg = mock_config(port);
1520 cfg.sslmode = "require".into();
1521 let result = Connect::new(cfg);
1522 assert!(result.is_err());
1523 }
1524 #[test]
1525 fn ssl_invalid_response_byte() {
1526 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1527 let port = listener.local_addr().unwrap().port();
1528 thread::spawn(move || {
1529 let (mut s, _) = listener.accept().unwrap();
1530 let mut buf = [0u8; 4096];
1531 let _ = s.read(&mut buf);
1532 let _ = s.write_all(b"X");
1533 });
1534 thread::sleep(Duration::from_millis(30));
1535 let mut cfg = mock_config(port);
1536 cfg.sslmode = "prefer".into();
1537 let result = Connect::new(cfg);
1538 assert!(result.is_err());
1539 }
1540 #[test]
1541 fn ssl_disable_skips_ssl_handshake() {
1542 let port = spawn_cleartext_server();
1543 let mut cfg = mock_config(port);
1544 cfg.sslmode = "disable".into();
1545 let conn = Connect::new(cfg);
1546 assert!(conn.is_ok());
1547 }
1548 #[test]
1550 fn connect_query_portal_returns_rows_with_has_more() {
1551 let port = spawn_portal_server();
1552 let mut conn = Connect::new(mock_config(port)).unwrap();
1553 let result = conn.query_portal("SELECT id FROM t", 2);
1554 assert!(result.is_ok());
1555 let msg = result.unwrap();
1556 assert_eq!(msg.rows.len(), 2);
1557 assert!(msg.has_more);
1558 }
1559 #[test]
1560 fn connect_fetch_more_returns_remaining_rows() {
1561 let port = spawn_portal_server();
1562 let mut conn = Connect::new(mock_config(port)).unwrap();
1563 let _ = conn.query_portal("SELECT id FROM t", 2).unwrap();
1564 let result = conn.fetch_more(10);
1565 assert!(result.is_ok());
1566 let msg = result.unwrap();
1567 assert_eq!(msg.rows.len(), 1);
1568 assert!(!msg.has_more);
1569 }
1570 #[test]
1571 fn connect_close_portal_succeeds() {
1572 let port = spawn_portal_server();
1573 let mut conn = Connect::new(mock_config(port)).unwrap();
1574 let _ = conn.query_portal("SELECT id FROM t", 2).unwrap();
1575 let _ = conn.fetch_more(10).unwrap();
1576 let result = conn.close_portal();
1577 assert!(result.is_ok());
1578 }
1579}