1use std::collections::VecDeque;
20use std::io::{Read, Write};
21use std::net::TcpStream;
22#[cfg(unix)]
23use std::os::unix::io::AsRawFd;
24#[cfg(unix)]
25use std::os::unix::net::UnixStream;
26use std::rc::Rc;
27use std::time::{Duration, Instant};
28
29use crate::auth::ScramClient;
30use crate::codec;
31use crate::error::{PgError, PgResult};
32use crate::protocol::*;
33use crate::row::Row;
34use crate::statement::StatementCache;
35#[cfg(feature = "tls")]
36use crate::tls;
37use crate::types::{PgValue, ToSql};
38
/// Default time budget for each I/O wait once a connection has been switched
/// to non-blocking mode (see `PgConnection::io_timeout`).
const DEFAULT_IO_TIMEOUT: Duration = Duration::from_millis(5_000);
41
/// Transport underlying a `PgConnection`: plain TCP, a Unix domain socket
/// (Unix targets only), or a TLS session (with the `tls` feature).
enum PgStream {
    /// Plain TCP connection to `host:port`.
    Tcp(TcpStream),
    /// Unix domain socket connection to `<socket_dir>/.s.PGSQL.<port>`.
    #[cfg(unix)]
    Unix(UnixStream),
    /// TCP connection upgraded to TLS by `tls::negotiate`.
    #[cfg(feature = "tls")]
    Tls(tls::TlsStream),
}
52
53impl Read for PgStream {
54 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
55 match self {
56 PgStream::Tcp(s) => s.read(buf),
57 #[cfg(unix)]
58 PgStream::Unix(s) => s.read(buf),
59 #[cfg(feature = "tls")]
60 PgStream::Tls(s) => s.read(buf),
61 }
62 }
63}
64
65impl Write for PgStream {
66 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
67 match self {
68 PgStream::Tcp(s) => s.write(buf),
69 #[cfg(unix)]
70 PgStream::Unix(s) => s.write(buf),
71 #[cfg(feature = "tls")]
72 PgStream::Tls(s) => s.write(buf),
73 }
74 }
75
76 fn flush(&mut self) -> std::io::Result<()> {
77 match self {
78 PgStream::Tcp(s) => s.flush(),
79 #[cfg(unix)]
80 PgStream::Unix(s) => s.flush(),
81 #[cfg(feature = "tls")]
82 PgStream::Tls(s) => s.flush(),
83 }
84 }
85
86 fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
87 match self {
88 PgStream::Tcp(s) => s.write_all(buf),
89 #[cfg(unix)]
90 PgStream::Unix(s) => s.write_all(buf),
91 #[cfg(feature = "tls")]
92 PgStream::Tls(s) => s.write_all(buf),
93 }
94 }
95}
96
97impl PgStream {
98 fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> {
99 match self {
100 PgStream::Tcp(s) => s.set_nonblocking(nonblocking),
101 #[cfg(unix)]
102 PgStream::Unix(s) => s.set_nonblocking(nonblocking),
103 #[cfg(feature = "tls")]
104 PgStream::Tls(s) => s.set_nonblocking(nonblocking),
105 }
106 }
107
108 #[cfg(unix)]
109 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
110 match self {
111 PgStream::Tcp(s) => s.as_raw_fd(),
112 PgStream::Unix(s) => s.as_raw_fd(),
113 #[cfg(feature = "tls")]
114 PgStream::Tls(s) => s.as_raw_fd(),
115 }
116 }
117
118 #[cfg(feature = "tls")]
121 fn tls_server_cert_hash(&self) -> Option<Vec<u8>> {
122 match self {
123 PgStream::Tls(s) => s.server_cert_hash(),
124 _ => None,
125 }
126 }
127}
128
/// Connection parameters consumed by `PgConnection::connect`.
///
/// `Debug` is implemented by hand so the password is never written into logs,
/// panic messages, or error reports that format the config with `{:?}`.
#[derive(Clone)]
pub struct PgConfig {
    /// Host name or address used for TCP connections.
    pub host: String,
    /// Server port; also selects the Unix socket file `.s.PGSQL.<port>`.
    pub port: u16,
    /// Role name sent in the startup message.
    pub user: String,
    /// Password used for cleartext, MD5, or SCRAM authentication.
    pub password: String,
    /// Database name sent in the startup message.
    pub database: String,
    /// Directory containing the server's Unix socket; when set, TCP is not used.
    pub socket_dir: Option<String>,
    /// TLS negotiation policy for TCP connections.
    #[cfg(feature = "tls")]
    pub ssl_mode: tls::SslMode,
}

impl std::fmt::Debug for PgConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("PgConfig");
        dbg.field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            // Never expose the secret itself.
            .field("password", &"<redacted>")
            .field("database", &self.database)
            .field("socket_dir", &self.socket_dir);
        #[cfg(feature = "tls")]
        dbg.field("ssl_mode", &self.ssl_mode);
        dbg.finish()
    }
}
145
146impl PgConfig {
147 pub fn new(host: &str, port: u16, user: &str, password: &str, database: &str) -> Self {
148 Self {
149 host: host.to_string(),
150 port,
151 user: user.to_string(),
152 password: password.to_string(),
153 database: database.to_string(),
154 socket_dir: None,
155 #[cfg(feature = "tls")]
156 ssl_mode: tls::SslMode::default(),
157 }
158 }
159
160 pub fn with_socket_dir(mut self, dir: &str) -> Self {
163 self.socket_dir = Some(dir.to_string());
164 self
165 }
166
167 #[cfg(feature = "tls")]
169 pub fn with_ssl_mode(mut self, mode: tls::SslMode) -> Self {
170 self.ssl_mode = mode;
171 self
172 }
173
174 pub fn from_url(url: &str) -> PgResult<Self> {
180 let url = url
181 .strip_prefix("postgres://")
182 .or_else(|| url.strip_prefix("postgresql://"))
183 .ok_or_else(|| PgError::Protocol("Invalid URL scheme".to_string()))?;
184
185 let (userpass, hostdb) = url
187 .split_once('@')
188 .ok_or_else(|| PgError::Protocol("Missing @ in URL".to_string()))?;
189 let (user, password) = userpass.split_once(':').unwrap_or((userpass, ""));
190
191 let (hostdb_part, query_part) = hostdb.split_once('?').unwrap_or((hostdb, ""));
193
194 let (hostport, database) = hostdb_part
195 .split_once('/')
196 .ok_or_else(|| PgError::Protocol("Missing database in URL".to_string()))?;
197
198 let mut socket_dir: Option<String> = None;
200 #[cfg(feature = "tls")]
201 let mut ssl_mode = tls::SslMode::default();
202 if !query_part.is_empty() {
203 for param in query_part.split('&') {
204 if let Some(value) = param.strip_prefix("host=")
205 && value.starts_with('/')
206 {
207 socket_dir = Some(value.to_string());
208 }
209 #[cfg(feature = "tls")]
210 if let Some(value) = param.strip_prefix("sslmode=") {
211 if let Some(mode) = tls::SslMode::parse(value) {
212 ssl_mode = mode;
213 }
214 }
215 }
216 }
217
218 let decoded_host = percent_decode(hostport);
220 let is_unix_path = decoded_host.starts_with('/');
221 if is_unix_path {
222 socket_dir = Some(decoded_host);
223 }
224
225 let (host, port) = if socket_dir.is_some() {
226 let port_str = if hostport.is_empty() || is_unix_path {
228 "5432"
229 } else {
230 hostport.rsplit_once(':').map(|(_, p)| p).unwrap_or("5432")
231 };
232 let port: u16 = port_str
233 .parse()
234 .map_err(|_| PgError::Protocol("Invalid port".to_string()))?;
235 ("localhost".to_string(), port)
236 } else {
237 let (h, port_str) = hostport.split_once(':').unwrap_or((hostport, "5432"));
238 let port: u16 = port_str
239 .parse()
240 .map_err(|_| PgError::Protocol("Invalid port".to_string()))?;
241 (h.to_string(), port)
242 };
243
244 Ok(Self {
245 host,
246 port,
247 user: user.to_string(),
248 password: password.to_string(),
249 database: database.to_string(),
250 socket_dir,
251 #[cfg(feature = "tls")]
252 ssl_mode,
253 })
254 }
255}
256
257fn percent_decode(input: &str) -> String {
259 let mut result = String::with_capacity(input.len());
260 let bytes = input.as_bytes();
261 let mut i = 0;
262 while i < bytes.len() {
263 if bytes[i] == b'%'
264 && i + 2 < bytes.len()
265 && let (Some(hi), Some(lo)) = (hex_digit(bytes[i + 1]), hex_digit(bytes[i + 2]))
266 {
267 result.push((hi << 4 | lo) as char);
268 i += 3;
269 continue;
270 }
271 result.push(bytes[i] as char);
272 i += 1;
273 }
274 result
275}
276
/// Value of a single ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`), or `None`
/// for any other byte.
fn hex_digit(b: u8) -> Option<u8> {
    // `to_digit(16)` only accepts ASCII digits/letters, and yields < 16, so the
    // narrowing cast is lossless.
    char::from(b).to_digit(16).map(|d| d as u8)
}
285
/// An asynchronous notification received from the server.
///
/// Built from `NotificationResponse` messages and queued on the connection
/// until taken with `poll_notification` or `drain_notifications`.
#[derive(Debug, Clone)]
pub struct Notification {
    /// Process ID carried by the message (presumably the PID of the backend
    /// that issued `NOTIFY` — confirm against `parse_notification`).
    pub process_id: i32,
    /// Channel name the notification was delivered on.
    pub channel: String,
    /// Payload string attached to the notification.
    pub payload: String,
}
296
/// Boxed callback for server notices, installed with
/// `PgConnection::set_notice_handler`.
///
/// NOTE(review): the meaning of the three `&str` arguments is fixed by
/// `dispatch_notice` (not shown here) — presumably severity, SQLSTATE code,
/// and message text; confirm before relying on the order.
type NoticeHandler = Box<dyn Fn(&str, &str, &str) + Send + Sync>;
299
/// A single synchronous connection to a PostgreSQL server.
pub struct PgConnection {
    /// Underlying transport (TCP, Unix socket, or TLS).
    stream: PgStream,
    /// Receive buffer; only `read_buf[..read_pos]` holds unparsed data.
    read_buf: Vec<u8>,
    /// Scratch buffer that outgoing messages are encoded into, from offset 0.
    write_buf: Vec<u8>,
    /// Number of valid bytes at the front of `read_buf`.
    read_pos: usize,
    /// Status byte from the most recent `ReadyForQuery`.
    tx_status: TransactionStatus,
    /// Prepared-statement cache keyed by SQL text.
    stmt_cache: StatementCache,
    /// Backend process ID from `BackendKeyData`.
    process_id: i32,
    /// Cancellation secret from `BackendKeyData`.
    secret_key: i32,
    /// `ParameterStatus` name/value pairs, in the order received at startup.
    server_params: Vec<(String, String)>,
    /// Notifications received but not yet handed to the caller.
    notifications: VecDeque<Notification>,
    /// Row count parsed from the most recent `CommandComplete`.
    last_affected_rows: u64,
    /// Tag text from the most recent `CommandComplete`.
    last_command_tag: String,
    /// Whether the socket is currently in non-blocking mode.
    nonblocking: bool,
    /// Time budget for each read/write wait in non-blocking mode.
    io_timeout: Duration,
    /// Optional callback for server notices.
    notice_handler: Option<NoticeHandler>,
    /// Set once an I/O failure makes the connection unusable.
    broken: bool,
}
331
332impl PgConnection {
    /// Opens a connection described by `config` and completes the startup and
    /// authentication handshake.
    ///
    /// When `config.socket_dir` is set, the connection uses the Unix socket at
    /// `<socket_dir>/.s.PGSQL.<port>`. Otherwise it connects over TCP to
    /// `host:port` with `TCP_NODELAY` requested best-effort. With the `tls`
    /// feature, `config.ssl_mode` decides whether TLS is attempted (`Prefer`)
    /// or mandatory (`Require`) on TCP.
    ///
    /// The handshake runs with the socket still in blocking mode and without a
    /// timeout. Afterwards the socket is switched to non-blocking mode, and
    /// later waits are bounded by the connection's I/O timeout.
    ///
    /// # Errors
    /// * I/O errors from connecting.
    /// * `PgError::Protocol` when Unix sockets are unavailable on this platform,
    ///   or when the server refuses TLS under `sslmode=require`.
    /// * Any error raised during the startup handshake.
    pub fn connect(config: &PgConfig) -> PgResult<Self> {
        let stream = if let Some(ref socket_dir) = config.socket_dir {
            #[cfg(unix)]
            {
                let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, config.port);
                let unix_stream = UnixStream::connect(&socket_path).map_err(PgError::Io)?;
                PgStream::Unix(unix_stream)
            }
            #[cfg(not(unix))]
            {
                let _ = socket_dir;
                return Err(PgError::Protocol(
                    "Unix domain sockets are not supported on this platform".to_string(),
                ));
            }
        } else {
            let addr = format!("{}:{}", config.host, config.port);
            let tcp = TcpStream::connect(&addr).map_err(PgError::Io)?;
            // Best effort: failing to disable Nagle's algorithm is not fatal.
            let _ = tcp.set_nodelay(true);

            #[cfg(feature = "tls")]
            {
                match config.ssl_mode {
                    tls::SslMode::Disable => PgStream::Tcp(tcp),
                    tls::SslMode::Prefer => match tls::negotiate(tcp, &config.host) {
                        Ok(tls::TlsNegotiateResult::Tls(tls_stream)) => PgStream::Tls(tls_stream),
                        Ok(tls::TlsNegotiateResult::Rejected(tcp)) => PgStream::Tcp(tcp),
                        Err(_) => {
                            // `negotiate` took ownership of the socket, so a failed
                            // attempt requires a fresh plaintext connection.
                            let tcp = TcpStream::connect(&addr).map_err(PgError::Io)?;
                            let _ = tcp.set_nodelay(true);
                            PgStream::Tcp(tcp)
                        }
                    },
                    tls::SslMode::Require => match tls::negotiate(tcp, &config.host)? {
                        tls::TlsNegotiateResult::Tls(tls_stream) => PgStream::Tls(tls_stream),
                        tls::TlsNegotiateResult::Rejected(_) => {
                            return Err(PgError::Protocol(
                                "Server does not support SSL (sslmode=require)".to_string(),
                            ));
                        }
                    },
                }
            }

            #[cfg(not(feature = "tls"))]
            PgStream::Tcp(tcp)
        };

        let mut conn = Self {
            stream,
            read_buf: vec![0u8; 64 * 1024],
            write_buf: vec![0u8; 64 * 1024],
            read_pos: 0,
            tx_status: TransactionStatus::Idle,
            stmt_cache: StatementCache::new(),
            process_id: 0,
            secret_key: 0,
            server_params: Vec::new(),
            notifications: VecDeque::new(),
            last_affected_rows: 0,
            last_command_tag: String::new(),
            nonblocking: false,
            io_timeout: DEFAULT_IO_TIMEOUT,
            notice_handler: None,
            broken: false,
        };

        // Handshake happens while the socket is still blocking (`nonblocking: false`).
        conn.startup(config)?;

        conn.stream.set_nonblocking(true).map_err(PgError::Io)?;
        conn.nonblocking = true;

        Ok(conn)
    }
414
415 pub fn connect_with_timeout(config: &PgConfig, timeout: Duration) -> PgResult<Self> {
417 let mut conn = Self::connect(config)?;
418 conn.io_timeout = timeout;
419 Ok(conn)
420 }
421
422 pub fn set_io_timeout(&mut self, timeout: Duration) {
424 self.io_timeout = timeout;
425 }
426
427 pub fn io_timeout(&self) -> Duration {
429 self.io_timeout
430 }
431
432 pub fn set_notice_handler<F>(&mut self, handler: F)
444 where
445 F: Fn(&str, &str, &str) + Send + Sync + 'static,
446 {
447 self.notice_handler = Some(Box::new(handler));
448 }
449
450 pub fn clear_notice_handler(&mut self) {
452 self.notice_handler = None;
453 }
454
455 pub fn set_statement_cache_capacity(&mut self, capacity: usize) {
457 self.stmt_cache.set_max_capacity(capacity);
458 }
459
460 #[cfg(unix)]
463 pub fn raw_fd(&self) -> std::os::unix::io::RawFd {
464 self.stream.as_raw_fd()
465 }
466
467 pub fn is_nonblocking(&self) -> bool {
469 self.nonblocking
470 }
471
472 pub fn set_nonblocking(&mut self, nonblocking: bool) -> PgResult<()> {
474 self.stream
475 .set_nonblocking(nonblocking)
476 .map_err(PgError::Io)?;
477 self.nonblocking = nonblocking;
478 Ok(())
479 }
480
481 fn startup(&mut self, config: &PgConfig) -> PgResult<()> {
483 self.ensure_write_capacity(512);
485 let n = codec::encode_startup(&mut self.write_buf, &config.user, &config.database, &[]);
486 self.stream
487 .write_all(&self.write_buf[..n])
488 .map_err(PgError::Io)?;
489
490 loop {
492 self.fill_read_buf(None)?;
493
494 while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
495 let header = codec::decode_header(&self.read_buf)
496 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
497 let body = &self.read_buf[5..msg_len];
498
499 match header.tag {
500 BackendTag::AuthenticationRequest => {
501 let auth_type = codec::read_i32(&self.read_buf, 5);
502 match AuthType::from_i32(auth_type) {
503 Some(AuthType::Ok) => {
504 }
506 Some(AuthType::CleartextPassword) => {
507 let n =
508 codec::encode_password(&mut self.write_buf, &config.password);
509 self.stream
510 .write_all(&self.write_buf[..n])
511 .map_err(PgError::Io)?;
512 }
513 Some(AuthType::SASLInit) => {
514 #[cfg(feature = "tls")]
516 let (mut scram, mechanism) =
517 if let Some(cb_data) = self.stream.tls_server_cert_hash() {
518 (
519 ScramClient::new_with_channel_binding(
520 &config.user,
521 &config.password,
522 cb_data,
523 ),
524 "SCRAM-SHA-256-PLUS",
525 )
526 } else {
527 (
528 ScramClient::new(&config.user, &config.password),
529 "SCRAM-SHA-256",
530 )
531 };
532 #[cfg(not(feature = "tls"))]
533 let (mut scram, mechanism) = (
534 ScramClient::new(&config.user, &config.password),
535 "SCRAM-SHA-256",
536 );
537
538 let client_first = scram.client_first_message();
539 let n = codec::encode_sasl_initial(
540 &mut self.write_buf,
541 mechanism,
542 &client_first,
543 );
544 self.stream
545 .write_all(&self.write_buf[..n])
546 .map_err(PgError::Io)?;
547
548 self.consume_read(msg_len);
549 self.wait_for_sasl_continue(&mut scram, config)?;
550 continue;
553 }
554 Some(AuthType::MD5Password) => {
555 if body.len() < 8 {
557 return Err(PgError::Protocol(
558 "MD5Password message too short".to_string(),
559 ));
560 }
561 let salt: [u8; 4] = [body[4], body[5], body[6], body[7]];
562 let hash = crate::auth::md5_password_hash(
563 &config.user,
564 &config.password,
565 &salt,
566 );
567 let n = codec::encode_password(&mut self.write_buf, &hash);
568 self.stream
569 .write_all(&self.write_buf[..n])
570 .map_err(PgError::Io)?;
571 }
572 _ => {
573 return Err(PgError::Auth(format!(
574 "Unsupported auth type: {}",
575 auth_type
576 )));
577 }
578 }
579 }
580 BackendTag::ParameterStatus => {
581 let (name, consumed) = codec::read_cstring(body, 0);
582 let (value, _) = codec::read_cstring(body, consumed);
583 self.server_params
584 .push((name.to_string(), value.to_string()));
585 }
586 BackendTag::BackendKeyData => {
587 self.process_id = codec::read_i32(body, 0);
588 self.secret_key = codec::read_i32(body, 4);
589 }
590 BackendTag::ReadyForQuery => {
591 self.tx_status = TransactionStatus::from(body[0]);
592 self.consume_read(msg_len);
593 return Ok(()); }
595 BackendTag::ErrorResponse => {
596 let fields = codec::parse_error_fields(body);
597 return Err(PgError::from_fields(&fields));
598 }
599 _ => {
600 }
602 }
603 self.consume_read(msg_len);
604 }
605 }
606 }
607
608 fn wait_for_sasl_continue(
610 &mut self,
611 scram: &mut ScramClient,
612 _config: &PgConfig,
613 ) -> PgResult<()> {
614 loop {
615 self.fill_read_buf(None)?;
616
617 while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
618 let header = codec::decode_header(&self.read_buf)
619 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
620 let body = &self.read_buf[5..msg_len].to_vec();
621
622 match header.tag {
623 BackendTag::AuthenticationRequest => {
624 let auth_type = codec::read_i32(&self.read_buf, 5);
625 match AuthType::from_i32(auth_type) {
626 Some(AuthType::SASLContinue) => {
627 let server_first = &body[4..];
628 let client_final = scram
629 .process_server_first(server_first)
630 .map_err(PgError::Auth)?;
631
632 let n =
633 codec::encode_sasl_response(&mut self.write_buf, &client_final);
634 self.stream
635 .write_all(&self.write_buf[..n])
636 .map_err(PgError::Io)?;
637 }
638 Some(AuthType::SASLFinal) => {
639 let server_final = &body[4..];
640 scram
641 .verify_server_final(server_final)
642 .map_err(PgError::Auth)?;
643 }
644 Some(AuthType::Ok) => {
645 self.consume_read(msg_len);
646 return Ok(());
647 }
648 _ => {
649 return Err(PgError::Auth(
650 "Unexpected auth message during SASL".to_string(),
651 ));
652 }
653 }
654 }
655 _ => {
656 }
658 }
659 self.consume_read(msg_len);
660 }
661 }
662 }
663
664 pub fn query_simple(&mut self, sql: &str) -> PgResult<Vec<Row>> {
668 self.ensure_write_capacity(5 + sql.len());
669 let n = codec::encode_query(&mut self.write_buf, sql);
670 self.flush_write_buf(n)?;
671 self.read_query_results()
672 }
673
674 pub fn query(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Vec<Row>> {
677 let stmt = self.stmt_cache.get_or_create(sql);
678
679 let estimated = 10 + sql.len() + (params.len() * 256);
681 self.ensure_write_capacity(estimated);
682
683 let mut pos = 0;
684
685 if stmt.is_new {
686 let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
688 pos += n;
689
690 let n = codec::encode_describe(
692 &mut self.write_buf[pos..],
693 DescribeTarget::Statement,
694 &stmt.name,
695 );
696 pos += n;
697 }
698
699 let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
701 let param_formats: Vec<i16> = pg_values
702 .iter()
703 .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
704 .collect();
705 let param_values: Vec<Option<Vec<u8>>> = pg_values
706 .iter()
707 .zip(param_formats.iter())
708 .map(|(v, &fmt)| {
709 if fmt == 1 {
710 v.to_binary_bytes()
711 } else {
712 v.to_text_bytes()
713 }
714 })
715 .collect();
716 let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
717 let n = codec::encode_bind(
718 &mut self.write_buf[pos..],
719 "", &stmt.name,
721 ¶m_formats,
722 ¶m_refs,
723 &[1], );
725 pos += n;
726
727 let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
729 pos += n;
730
731 let n = codec::encode_sync(&mut self.write_buf[pos..]);
733 pos += n;
734
735 self.flush_write_buf(pos)?;
736
737 let rows = self.read_extended_results(sql, &stmt.name, stmt.is_new, stmt.columns)?;
739 Ok(rows)
740 }
741
742 pub fn query_one(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Row> {
749 let stmt = self.stmt_cache.get_or_create(sql);
750
751 let estimated = 10 + sql.len() + (params.len() * 256);
752 self.ensure_write_capacity(estimated);
753
754 let mut pos = 0;
755
756 if stmt.is_new {
757 let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
758 pos += n;
759 let n = codec::encode_describe(
760 &mut self.write_buf[pos..],
761 DescribeTarget::Statement,
762 &stmt.name,
763 );
764 pos += n;
765 }
766
767 let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
768 let param_formats: Vec<i16> = pg_values
769 .iter()
770 .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
771 .collect();
772 let param_values: Vec<Option<Vec<u8>>> = pg_values
773 .iter()
774 .zip(param_formats.iter())
775 .map(|(v, &fmt)| {
776 if fmt == 1 {
777 v.to_binary_bytes()
778 } else {
779 v.to_text_bytes()
780 }
781 })
782 .collect();
783 let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
784 let n = codec::encode_bind(
785 &mut self.write_buf[pos..],
786 "",
787 &stmt.name,
788 ¶m_formats,
789 ¶m_refs,
790 &[1],
791 );
792 pos += n;
793
794 let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
797 pos += n;
798
799 let n = codec::encode_sync(&mut self.write_buf[pos..]);
800 pos += n;
801
802 self.flush_write_buf(pos)?;
803
804 self.read_extended_result_one(sql, &stmt.name, stmt.is_new, stmt.columns)
805 }
806
807 pub fn query_opt(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Option<Row>> {
810 let stmt = self.stmt_cache.get_or_create(sql);
811
812 let estimated = 10 + sql.len() + (params.len() * 256);
813 self.ensure_write_capacity(estimated);
814
815 let mut pos = 0;
816
817 if stmt.is_new {
818 let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
819 pos += n;
820 let n = codec::encode_describe(
821 &mut self.write_buf[pos..],
822 DescribeTarget::Statement,
823 &stmt.name,
824 );
825 pos += n;
826 }
827
828 let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
829 let param_formats: Vec<i16> = pg_values
830 .iter()
831 .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
832 .collect();
833 let param_values: Vec<Option<Vec<u8>>> = pg_values
834 .iter()
835 .zip(param_formats.iter())
836 .map(|(v, &fmt)| {
837 if fmt == 1 {
838 v.to_binary_bytes()
839 } else {
840 v.to_text_bytes()
841 }
842 })
843 .collect();
844 let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
845 let n = codec::encode_bind(
846 &mut self.write_buf[pos..],
847 "",
848 &stmt.name,
849 ¶m_formats,
850 ¶m_refs,
851 &[1],
852 );
853 pos += n;
854
855 let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
856 pos += n;
857
858 let n = codec::encode_sync(&mut self.write_buf[pos..]);
859 pos += n;
860
861 self.flush_write_buf(pos)?;
862
863 self.read_extended_result_opt(sql, &stmt.name, stmt.is_new, stmt.columns)
864 }
865
866 pub fn execute(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<u64> {
869 let _rows = self.query(sql, params)?;
870 Ok(self.last_affected_rows)
871 }
872
873 pub fn begin(&mut self) -> PgResult<()> {
877 self.query_simple("BEGIN")?;
878 Ok(())
879 }
880
881 pub fn commit(&mut self) -> PgResult<()> {
883 self.query_simple("COMMIT")?;
884 Ok(())
885 }
886
887 pub fn rollback(&mut self) -> PgResult<()> {
889 self.query_simple("ROLLBACK")?;
890 Ok(())
891 }
892
893 pub fn savepoint(&mut self, name: &str) -> PgResult<()> {
895 self.query_simple(&format!("SAVEPOINT {}", name))?;
896 Ok(())
897 }
898
899 pub fn rollback_to(&mut self, name: &str) -> PgResult<()> {
901 self.query_simple(&format!("ROLLBACK TO SAVEPOINT {}", name))?;
902 Ok(())
903 }
904
905 pub fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
907 self.query_simple(&format!("RELEASE SAVEPOINT {}", name))?;
908 Ok(())
909 }
910
911 pub fn transaction<F, T>(&mut self, f: F) -> PgResult<T>
926 where
927 F: FnOnce(&mut Transaction<'_>) -> PgResult<T>,
928 {
929 self.begin()?;
930 let mut tx = Transaction {
931 conn: self,
932 finished: false,
933 savepoint_name: None,
934 savepoint_counter: 0,
935 };
936 match f(&mut tx) {
937 Ok(val) => {
938 tx.commit()?;
939 Ok(val)
940 }
941 Err(e) => {
942 let _ = tx.rollback();
944 Err(e)
945 }
946 }
947 }
948
949 pub fn copy_in(&mut self, sql: &str) -> PgResult<CopyWriter<'_>> {
953 let n = codec::encode_query(&mut self.write_buf, sql);
954 #[allow(clippy::unnecessary_to_owned)]
955 self.write_all(&self.write_buf[..n].to_vec())?;
956
957 loop {
959 self.fill_read_buf(None)?;
960 let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? else {
961 continue;
962 };
963 let header = codec::decode_header(&self.read_buf)
964 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
965 match header.tag {
966 BackendTag::CopyInResponse => {
967 self.consume_read(msg_len);
968 return Ok(CopyWriter { conn: self });
969 }
970 BackendTag::ErrorResponse => {
971 let body = &self.read_buf[5..msg_len];
972 return Err(self.parse_error(body));
973 }
974 _ => {
975 self.consume_read(msg_len);
976 }
977 }
978 }
979 }
980
981 pub fn copy_out(&mut self, sql: &str) -> PgResult<CopyReader<'_>> {
984 let n = codec::encode_query(&mut self.write_buf, sql);
985 #[allow(clippy::unnecessary_to_owned)]
986 self.write_all(&self.write_buf[..n].to_vec())?;
987
988 loop {
990 self.fill_read_buf(None)?;
991 let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? else {
992 continue;
993 };
994 let header = codec::decode_header(&self.read_buf)
995 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
996 match header.tag {
997 BackendTag::CopyOutResponse => {
998 self.consume_read(msg_len);
999 return Ok(CopyReader {
1000 conn: self,
1001 done: false,
1002 });
1003 }
1004 BackendTag::ErrorResponse => {
1005 let body = &self.read_buf[5..msg_len];
1006 return Err(self.parse_error(body));
1007 }
1008 _ => {
1009 self.consume_read(msg_len);
1010 }
1011 }
1012 }
1013 }
1014
1015 pub fn listen(&mut self, channel: &str) -> PgResult<()> {
1019 self.query_simple(&format!("LISTEN {}", channel))?;
1020 Ok(())
1021 }
1022
1023 pub fn notify(&mut self, channel: &str, payload: &str) -> PgResult<()> {
1025 self.query_simple(&format!("NOTIFY {}, '{}'", channel, payload))?;
1026 Ok(())
1027 }
1028
1029 pub fn unlisten(&mut self, channel: &str) -> PgResult<()> {
1031 self.query_simple(&format!("UNLISTEN {}", channel))?;
1032 Ok(())
1033 }
1034
1035 pub fn unlisten_all(&mut self) -> PgResult<()> {
1037 self.query_simple("UNLISTEN *")?;
1038 Ok(())
1039 }
1040
1041 pub fn drain_notifications(&mut self) -> Vec<Notification> {
1043 self.notifications.drain(..).collect()
1044 }
1045
1046 pub fn has_notifications(&self) -> bool {
1048 !self.notifications.is_empty()
1049 }
1050
1051 pub fn notification_count(&self) -> usize {
1053 self.notifications.len()
1054 }
1055
1056 pub fn poll_notification(&mut self) -> PgResult<Option<Notification>> {
1060 if let Some(n) = self.notifications.pop_front() {
1062 return Ok(Some(n));
1063 }
1064
1065 self.ensure_read_space();
1067 match self.stream.read(&mut self.read_buf[self.read_pos..]) {
1068 Ok(0) => return Err(PgError::ConnectionClosed),
1069 Ok(n) => {
1070 self.read_pos += n;
1071 while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])?
1073 {
1074 let header = codec::decode_header(&self.read_buf).ok_or_else(|| {
1075 PgError::Protocol("Incomplete message header".to_string())
1076 })?;
1077 if header.tag == BackendTag::NotificationResponse {
1078 let body = &self.read_buf[5..msg_len];
1079 let notification = Self::parse_notification(body);
1080 self.notifications.push_back(notification);
1081 }
1082 self.consume_read(msg_len);
1083 }
1084 }
1085 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1086 }
1088 Err(e) => return Err(PgError::Io(e)),
1089 }
1090
1091 Ok(self.notifications.pop_front())
1092 }
1093
1094 pub fn transaction_status(&self) -> TransactionStatus {
1098 self.tx_status
1099 }
1100
1101 pub fn cached_statements(&self) -> usize {
1103 self.stmt_cache.len()
1104 }
1105
1106 pub fn last_affected_rows(&self) -> u64 {
1108 self.last_affected_rows
1109 }
1110
1111 pub fn last_command_tag(&self) -> &str {
1113 &self.last_command_tag
1114 }
1115
1116 pub fn process_id(&self) -> i32 {
1118 self.process_id
1119 }
1120
1121 pub fn secret_key(&self) -> i32 {
1123 self.secret_key
1124 }
1125
1126 pub fn server_params(&self) -> &[(String, String)] {
1128 &self.server_params
1129 }
1130
1131 pub fn server_param(&self, name: &str) -> Option<&str> {
1133 self.server_params
1134 .iter()
1135 .find(|(k, _)| k == name)
1136 .map(|(_, v)| v.as_str())
1137 }
1138
1139 pub fn in_transaction(&self) -> bool {
1141 matches!(
1142 self.tx_status,
1143 TransactionStatus::InTransaction | TransactionStatus::Failed
1144 )
1145 }
1146
1147 pub fn clear_statement_cache(&mut self) {
1153 let _ = self.query_simple("DEALLOCATE ALL");
1154 self.stmt_cache.clear();
1155 }
1156
1157 pub fn is_broken(&self) -> bool {
1161 self.broken
1162 }
1163
1164 pub fn reset(&mut self) -> PgResult<()> {
1170 self.query_simple("DISCARD ALL")?;
1171 self.stmt_cache.clear();
1172 Ok(())
1173 }
1174
1175 pub fn execute_batch(&mut self, sql: &str) -> PgResult<u64> {
1187 self.query_simple(sql)?;
1188 Ok(self.last_affected_rows)
1189 }
1190
1191 pub fn is_alive(&mut self) -> bool {
1193 self.query_simple("SELECT 1").is_ok()
1194 }
1195
1196 pub fn try_fill_read_buf(&mut self) -> PgResult<usize> {
1204 self.ensure_read_space();
1205
1206 match self.stream.read(&mut self.read_buf[self.read_pos..]) {
1207 Ok(0) => {
1208 self.broken = true;
1209 Err(PgError::ConnectionClosed)
1210 }
1211 Ok(n) => {
1212 self.read_pos += n;
1213 Ok(n)
1214 }
1215 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(PgError::WouldBlock),
1216 Err(e) => {
1217 self.broken = true;
1218 Err(PgError::Io(e))
1219 }
1220 }
1221 }
1222
1223 pub fn try_write(&mut self, data: &[u8]) -> PgResult<usize> {
1227 match self.stream.write(data) {
1228 Ok(n) => Ok(n),
1229 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(PgError::WouldBlock),
1230 Err(e) => {
1231 self.broken = true;
1232 Err(PgError::Io(e))
1233 }
1234 }
1235 }
1236
1237 #[cfg(unix)]
1243 fn wait_readable(&self, timeout: Duration) -> PgResult<()> {
1244 let fd = self.stream.as_raw_fd();
1245 let timeout_ms = timeout.as_millis().min(i32::MAX as u128) as i32;
1246 let mut pfd = libc::pollfd {
1247 fd,
1248 events: libc::POLLIN,
1249 revents: 0,
1250 };
1251 let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
1252 if ret < 0 {
1253 let e = std::io::Error::last_os_error();
1254 if e.kind() == std::io::ErrorKind::Interrupted {
1255 return Ok(()); }
1257 return Err(PgError::Io(e));
1258 }
1259 if ret == 0 {
1260 return Err(PgError::Timeout);
1261 }
1262 if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 {
1263 return Err(PgError::ConnectionClosed);
1264 }
1265 Ok(())
1266 }
1267
1268 #[cfg(unix)]
1270 fn wait_writable(&self, timeout: Duration) -> PgResult<()> {
1271 let fd = self.stream.as_raw_fd();
1272 let timeout_ms = timeout.as_millis().min(i32::MAX as u128) as i32;
1273 let mut pfd = libc::pollfd {
1274 fd,
1275 events: libc::POLLOUT,
1276 revents: 0,
1277 };
1278 let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
1279 if ret < 0 {
1280 let e = std::io::Error::last_os_error();
1281 if e.kind() == std::io::ErrorKind::Interrupted {
1282 return Ok(());
1283 }
1284 return Err(PgError::Io(e));
1285 }
1286 if ret == 0 {
1287 return Err(PgError::Timeout);
1288 }
1289 if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 {
1290 return Err(PgError::ConnectionClosed);
1291 }
1292 Ok(())
1293 }
1294
1295 pub fn poll_read(&mut self, timeout: Duration) -> PgResult<usize> {
1301 let start = Instant::now();
1302 loop {
1303 match self.try_fill_read_buf() {
1304 Ok(n) => return Ok(n),
1305 Err(PgError::WouldBlock) => {
1306 let elapsed = start.elapsed();
1307 if elapsed >= timeout {
1308 return Err(PgError::Timeout);
1309 }
1310 #[cfg(unix)]
1311 self.wait_readable(timeout - elapsed)?;
1312 #[cfg(not(unix))]
1313 std::thread::sleep(Duration::from_micros(50));
1314 }
1315 Err(e) => return Err(e),
1316 }
1317 }
1318 }
1319
1320 pub fn poll_write(&mut self, data: &[u8], timeout: Duration) -> PgResult<()> {
1323 let start = Instant::now();
1324 let mut written = 0;
1325 while written < data.len() {
1326 match self.try_write(&data[written..]) {
1327 Ok(n) => written += n,
1328 Err(PgError::WouldBlock) => {
1329 let elapsed = start.elapsed();
1330 if elapsed >= timeout {
1331 return Err(PgError::Timeout);
1332 }
1333 #[cfg(unix)]
1334 self.wait_writable(timeout - elapsed)?;
1335 #[cfg(not(unix))]
1336 std::thread::sleep(Duration::from_micros(50));
1337 }
1338 Err(e) => return Err(e),
1339 }
1340 }
1341 Ok(())
1342 }
1343
1344 fn fill_read_buf(&mut self, min_size: Option<usize>) -> PgResult<()> {
1347 if let Some(min) = min_size {
1348 self.ensure_read_capacity(min);
1349 }
1350
1351 self.ensure_read_space();
1352
1353 if self.nonblocking {
1354 self.poll_read(self.io_timeout)?;
1356 } else {
1357 let n = self
1359 .stream
1360 .read(&mut self.read_buf[self.read_pos..])
1361 .map_err(PgError::Io)?;
1362 if n == 0 {
1363 return Err(PgError::ConnectionClosed);
1364 }
1365 self.read_pos += n;
1366 }
1367 Ok(())
1368 }
1369
1370 fn write_all(&mut self, data: &[u8]) -> PgResult<()> {
1372 if self.nonblocking {
1373 self.poll_write(data, self.io_timeout)
1374 } else {
1375 self.stream.write_all(data).map_err(PgError::Io)
1376 }
1377 }
1378
1379 fn flush_write_buf(&mut self, n: usize) -> PgResult<()> {
1387 if self.nonblocking {
1388 let timeout = self.io_timeout;
1389 let start = Instant::now();
1390 let mut written = 0;
1391 while written < n {
1392 match self.stream.write(&self.write_buf[written..n]) {
1393 Ok(w) => written += w,
1394 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1395 let elapsed = start.elapsed();
1396 if elapsed >= timeout {
1397 return Err(PgError::Timeout);
1398 }
1399 #[cfg(unix)]
1400 self.wait_writable(timeout - elapsed)?;
1401 #[cfg(not(unix))]
1402 std::thread::sleep(Duration::from_micros(50));
1403 }
1404 Err(e) => {
1405 self.broken = true;
1406 return Err(PgError::Io(e));
1407 }
1408 }
1409 }
1410 Ok(())
1411 } else {
1412 self.stream
1413 .write_all(&self.write_buf[..n])
1414 .map_err(PgError::Io)
1415 }
1416 }
1417
1418 fn ensure_read_space(&mut self) {
1420 if self.read_pos == self.read_buf.len() {
1421 if self.read_pos >= 5
1422 && let Some(header) = codec::decode_header(&self.read_buf)
1423 {
1424 if header.length as usize > codec::MAX_MESSAGE_SIZE {
1429 return;
1430 }
1431 let total = 1 + header.length as usize;
1432 self.ensure_read_capacity(total - self.read_pos);
1433 return;
1434 }
1435 self.ensure_read_capacity(8192);
1436 }
1437 }
1438
1439 fn consume_read(&mut self, n: usize) {
1440 self.read_buf.copy_within(n..self.read_pos, 0);
1441 self.read_pos -= n;
1442 }
1443
1444 fn ensure_read_capacity(&mut self, additional: usize) {
1445 if self.read_pos + additional > self.read_buf.len() {
1446 let new_len = (self.read_pos + additional).max(self.read_buf.len() * 2);
1447 self.read_buf.resize(new_len, 0);
1448 }
1449 }
1450
1451 fn ensure_write_capacity(&mut self, additional: usize) {
1452 if additional > self.write_buf.len() {
1453 let new_len = additional.max(self.write_buf.len() * 2);
1454 self.write_buf.resize(new_len, 0);
1455 }
1456 }
1457
1458 fn read_query_results(&mut self) -> PgResult<Vec<Row>> {
1459 let mut rows = Vec::new();
1460 let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = Rc::new(Vec::new());
1461
1462 loop {
1463 if codec::message_complete(&self.read_buf[..self.read_pos])?.is_none() {
1464 self.fill_read_buf(None)?;
1465 }
1466
1467 while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
1468 let header = codec::decode_header(&self.read_buf)
1469 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1470 let body = &self.read_buf[5..msg_len];
1471
1472 match header.tag {
1473 BackendTag::RowDescription => {
1474 columns_rc = Rc::new(codec::parse_row_description(body));
1475 }
1476 BackendTag::DataRow => {
1477 let raw_values = codec::parse_data_row(body);
1478 rows.push(Row::new(Rc::clone(&columns_rc), raw_values));
1479 }
1480 BackendTag::CommandComplete => {
1481 let (tag, rows_affected) = extract_command_complete(body);
1482 self.last_command_tag = tag;
1483 self.last_affected_rows = rows_affected;
1484 }
1485 BackendTag::ReadyForQuery => {
1486 self.tx_status = TransactionStatus::from(body[0]);
1487 self.consume_read(msg_len);
1488 return Ok(rows);
1489 }
1490 BackendTag::ErrorResponse => {
1491 let err = self.parse_error(body);
1492 self.consume_read(msg_len);
1493 self.drain_to_ready()?;
1495 return Err(err);
1496 }
1497 BackendTag::NotificationResponse => {
1498 let notification = Self::parse_notification(body);
1499 self.notifications.push_back(notification);
1500 }
1501 BackendTag::EmptyQueryResponse => {}
1502 BackendTag::NoticeResponse => {
1503 self.dispatch_notice(body);
1504 }
1505 _ => {}
1506 }
1507 self.consume_read(msg_len);
1508 }
1509 }
1510 }
1511
1512 fn read_extended_results(
1513 &mut self,
1514 sql: &str,
1515 stmt_name: &str,
1516 is_new: bool,
1517 cached_columns: Option<Vec<codec::ColumnDesc>>,
1518 ) -> PgResult<Vec<Row>> {
1519 let mut rows = Vec::new();
1520 let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = match cached_columns {
1521 Some(c) => Rc::new(c),
1522 None => Rc::new(Vec::new()),
1523 };
1524
1525 loop {
1526 if codec::message_complete(&self.read_buf[..self.read_pos])?.is_none() {
1527 self.fill_read_buf(None)?;
1528 }
1529
1530 while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
1531 let header = codec::decode_header(&self.read_buf)
1532 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1533 let body = &self.read_buf[5..msg_len];
1534
1535 match header.tag {
1536 BackendTag::ParseComplete => {}
1537 BackendTag::ParameterDescription => {}
1538 BackendTag::RowDescription => {
1539 let mut columns = codec::parse_row_description(body);
1540 for col in &mut columns {
1546 col.format_code = FormatCode::Binary;
1547 }
1548 if is_new
1549 && let Some(evicted) = self.stmt_cache.insert(
1550 sql,
1551 stmt_name.to_string(),
1552 0,
1553 Some(columns.clone()),
1554 )
1555 {
1556 self.close_statement_on_server(&evicted.name);
1557 }
1558 columns_rc = Rc::new(columns);
1559 }
1560 BackendTag::NoData if is_new => {
1561 if let Some(evicted) =
1562 self.stmt_cache.insert(sql, stmt_name.to_string(), 0, None)
1563 {
1564 self.close_statement_on_server(&evicted.name);
1565 }
1566 }
1567 BackendTag::NoData => {}
1568 BackendTag::BindComplete => {}
1569 BackendTag::DataRow => {
1570 let raw_values = codec::parse_data_row(body);
1571 rows.push(Row::new(Rc::clone(&columns_rc), raw_values));
1572 }
1573 BackendTag::CommandComplete => {
1574 let (tag, rows_affected) = extract_command_complete(body);
1575 self.last_command_tag = tag;
1576 self.last_affected_rows = rows_affected;
1577 }
1578 BackendTag::ReadyForQuery => {
1579 self.tx_status = TransactionStatus::from(body[0]);
1580 self.consume_read(msg_len);
1581 return Ok(rows);
1582 }
1583 BackendTag::ErrorResponse => {
1584 let err = self.parse_error_with_context(body, sql);
1585 self.consume_read(msg_len);
1586 self.drain_to_ready()?;
1587 return Err(err);
1588 }
1589 BackendTag::NotificationResponse => {
1590 let notification = Self::parse_notification(body);
1591 self.notifications.push_back(notification);
1592 }
1593 BackendTag::NoticeResponse => {
1594 self.dispatch_notice(body);
1595 }
1596 _ => {}
1597 }
1598 self.consume_read(msg_len);
1599 }
1600 }
1601 }
1602
1603 fn read_extended_result_one(
1607 &mut self,
1608 sql: &str,
1609 stmt_name: &str,
1610 is_new: bool,
1611 cached_columns: Option<Vec<codec::ColumnDesc>>,
1612 ) -> PgResult<Row> {
1613 match self.read_extended_result_opt(sql, stmt_name, is_new, cached_columns)? {
1614 Some(row) => Ok(row),
1615 None => Err(PgError::NoRows),
1616 }
1617 }
1618
1619 fn read_extended_result_opt(
1623 &mut self,
1624 sql: &str,
1625 stmt_name: &str,
1626 is_new: bool,
1627 cached_columns: Option<Vec<codec::ColumnDesc>>,
1628 ) -> PgResult<Option<Row>> {
1629 let mut result: Option<Row> = None;
1630 let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = match cached_columns {
1631 Some(c) => Rc::new(c),
1632 None => Rc::new(Vec::new()),
1633 };
1634
1635 loop {
1636 if codec::message_complete(&self.read_buf[..self.read_pos])?.is_none() {
1637 self.fill_read_buf(None)?;
1638 }
1639
1640 while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
1641 let header = codec::decode_header(&self.read_buf)
1642 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1643 let body = &self.read_buf[5..msg_len];
1644
1645 match header.tag {
1646 BackendTag::ParseComplete => {}
1647 BackendTag::ParameterDescription => {}
1648 BackendTag::RowDescription => {
1649 let mut columns = codec::parse_row_description(body);
1650 for col in &mut columns {
1651 col.format_code = FormatCode::Binary;
1652 }
1653 if is_new
1654 && let Some(evicted) = self.stmt_cache.insert(
1655 sql,
1656 stmt_name.to_string(),
1657 0,
1658 Some(columns.clone()),
1659 )
1660 {
1661 self.close_statement_on_server(&evicted.name);
1662 }
1663 columns_rc = Rc::new(columns);
1664 }
1665 BackendTag::NoData if is_new => {
1666 if let Some(evicted) =
1667 self.stmt_cache.insert(sql, stmt_name.to_string(), 0, None)
1668 {
1669 self.close_statement_on_server(&evicted.name);
1670 }
1671 }
1672 BackendTag::NoData => {}
1673 BackendTag::BindComplete => {}
1674 BackendTag::DataRow
1675 if result.is_none() => {
1677 let raw_values = codec::parse_data_row(body);
1678 result = Some(Row::new(Rc::clone(&columns_rc), raw_values));
1679 }
1680 BackendTag::DataRow => {
1681 }
1683 BackendTag::CommandComplete => {
1684 let (tag, rows_affected) = extract_command_complete(body);
1685 self.last_command_tag = tag;
1686 self.last_affected_rows = rows_affected;
1687 }
1688 BackendTag::ReadyForQuery => {
1689 self.tx_status = TransactionStatus::from(body[0]);
1690 self.consume_read(msg_len);
1691 return Ok(result);
1692 }
1693 BackendTag::ErrorResponse => {
1694 let err = self.parse_error_with_context(body, sql);
1695 self.consume_read(msg_len);
1696 self.drain_to_ready()?;
1697 return Err(err);
1698 }
1699 BackendTag::NotificationResponse => {
1700 let notification = Self::parse_notification(body);
1701 self.notifications.push_back(notification);
1702 }
1703 BackendTag::NoticeResponse => {
1704 self.dispatch_notice(body);
1705 }
1706 _ => {}
1707 }
1708 self.consume_read(msg_len);
1709 }
1710 }
1711 }
1712
1713 fn drain_to_ready(&mut self) -> PgResult<()> {
1714 loop {
1715 if codec::message_complete(&self.read_buf[..self.read_pos])?.is_none() {
1718 self.fill_read_buf(None)?;
1719 }
1720 while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
1721 let header = codec::decode_header(&self.read_buf)
1722 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1723 if header.tag == BackendTag::ReadyForQuery {
1724 let body = &self.read_buf[5..msg_len];
1725 self.tx_status = TransactionStatus::from(body[0]);
1726 self.consume_read(msg_len);
1727 return Ok(());
1728 }
1729 self.consume_read(msg_len);
1730 }
1731 }
1732 }
1733
1734 fn parse_error(&self, body: &[u8]) -> PgError {
1735 let fields = codec::parse_error_fields(body);
1736 PgError::from_fields(&fields)
1737 }
1738
1739 fn parse_error_with_context(&self, body: &[u8], query: &str) -> PgError {
1741 let fields = codec::parse_error_fields(body);
1742 let mut err = PgError::from_fields(&fields);
1743 if let PgError::Server(ref mut server_err) = err
1744 && server_err.internal_query.is_none()
1745 {
1746 server_err.internal_query = Some(query.to_string());
1747 }
1748 err
1749 }
1750
1751 fn dispatch_notice(&self, body: &[u8]) {
1753 if let Some(ref handler) = self.notice_handler {
1754 let fields = codec::parse_error_fields(body);
1755 let mut severity = "";
1756 let mut code = "";
1757 let mut message = "";
1758 for (field_type, value) in &fields {
1759 match field_type {
1760 b'S' => severity = value,
1761 b'C' => code = value,
1762 b'M' => message = value,
1763 _ => {}
1764 }
1765 }
1766 handler(severity, code, message);
1767 }
1768 }
1769
1770 fn close_statement_on_server(&mut self, name: &str) {
1773 self.ensure_write_capacity(7 + name.len());
1774 let n = codec::encode_close(&mut self.write_buf, CloseTarget::Statement, name);
1775 let _ = self.flush_write_buf(n);
1776 }
1777
1778 fn parse_notification(body: &[u8]) -> Notification {
1784 let process_id = codec::read_i32(body, 0);
1785 let (channel, consumed) = codec::read_cstring(body, 4);
1786 let (payload, _) = codec::read_cstring(body, 4 + consumed);
1787 Notification {
1788 process_id,
1789 channel: channel.to_string(),
1790 payload: payload.to_string(),
1791 }
1792 }
1793}
1794
1795fn extract_command_complete(body: &[u8]) -> (String, u64) {
1799 let (tag, _) = codec::read_cstring(body, 0);
1800 let tag_str = tag.to_string();
1801 let affected_rows = tag
1802 .rsplit(' ')
1803 .next()
1804 .and_then(|s| s.parse::<u64>().ok())
1805 .unwrap_or(0);
1806 (tag_str, affected_rows)
1807}
1808
1809impl Drop for PgConnection {
1810 fn drop(&mut self) {
1811 if self.nonblocking {
1815 let _ = self.stream.set_nonblocking(false);
1816 }
1817 let n = codec::encode_terminate(&mut self.write_buf);
1818 let _ = self.stream.write_all(&self.write_buf[..n]);
1819 }
1820}
1821
/// A transaction on a borrowed [`PgConnection`], or a savepoint when nested.
///
/// Finish it with `commit` or `rollback`. If it is dropped while still open,
/// its `Drop` impl issues a rollback and ignores any error.
pub struct Transaction<'a> {
    /// Connection the transaction runs on, mutably borrowed for its lifetime.
    conn: &'a mut PgConnection,
    /// Set once `commit` or `rollback` has been issued. Later calls and `Drop`
    /// then do nothing.
    finished: bool,
    /// `None` for a top-level transaction; `Some(name)` for a nested one
    /// backed by the savepoint `name`.
    savepoint_name: Option<String>,
    /// Number of nested transactions started from this one; used to number
    /// their savepoints.
    savepoint_counter: u32,
}
1837
1838impl<'a> Transaction<'a> {
1839 pub fn commit(&mut self) -> PgResult<()> {
1841 if !self.finished {
1842 self.finished = true;
1843 if let Some(ref name) = self.savepoint_name {
1844 self.conn.release_savepoint(name)
1845 } else {
1846 self.conn.commit()
1847 }
1848 } else {
1849 Ok(())
1850 }
1851 }
1852
1853 pub fn rollback(&mut self) -> PgResult<()> {
1855 if !self.finished {
1856 self.finished = true;
1857 if let Some(ref name) = self.savepoint_name {
1858 self.conn.rollback_to(name)
1859 } else {
1860 self.conn.rollback()
1861 }
1862 } else {
1863 Ok(())
1864 }
1865 }
1866
1867 pub fn transaction<F, T>(&mut self, f: F) -> PgResult<T>
1886 where
1887 F: FnOnce(&mut Transaction<'_>) -> PgResult<T>,
1888 {
1889 self.savepoint_counter += 1;
1890 let sp_name = format!("chopin_sp_{}", self.savepoint_counter);
1891 self.conn.savepoint(&sp_name)?;
1892 let mut nested = Transaction {
1893 conn: self.conn,
1894 finished: false,
1895 savepoint_name: Some(sp_name),
1896 savepoint_counter: 0,
1897 };
1898 match f(&mut nested) {
1899 Ok(val) => {
1900 nested.commit()?;
1901 Ok(val)
1902 }
1903 Err(e) => {
1904 let _ = nested.rollback();
1905 Err(e)
1906 }
1907 }
1908 }
1909
1910 pub fn query_simple(&mut self, sql: &str) -> PgResult<Vec<Row>> {
1912 self.conn.query_simple(sql)
1913 }
1914
1915 pub fn query(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Vec<Row>> {
1917 self.conn.query(sql, params)
1918 }
1919
1920 pub fn query_one(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Row> {
1922 self.conn.query_one(sql, params)
1923 }
1924
1925 pub fn query_opt(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Option<Row>> {
1927 self.conn.query_opt(sql, params)
1928 }
1929
1930 pub fn execute(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<u64> {
1932 self.conn.execute(sql, params)
1933 }
1934
1935 pub fn savepoint(&mut self, name: &str) -> PgResult<()> {
1937 self.conn.savepoint(name)
1938 }
1939
1940 pub fn rollback_to(&mut self, name: &str) -> PgResult<()> {
1942 self.conn.rollback_to(name)
1943 }
1944
1945 pub fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
1947 self.conn.release_savepoint(name)
1948 }
1949
1950 pub fn status(&self) -> TransactionStatus {
1952 self.conn.transaction_status()
1953 }
1954}
1955
1956impl<'a> Drop for Transaction<'a> {
1957 fn drop(&mut self) {
1958 if !self.finished {
1959 if let Some(ref name) = self.savepoint_name {
1961 let _ = self.conn.rollback_to(name);
1962 } else {
1963 let _ = self.conn.rollback();
1964 }
1965 }
1966 }
1967}
1968
/// Sends COPY data to the server (`CopyData`), then completes the operation
/// with `finish` (`CopyDone`) or aborts it with `fail` (`CopyFail`).
///
/// NOTE(review): this presumably wraps a connection already in COPY-in mode
/// (after `CopyInResponse`); construction happens outside this section. No
/// `Drop` impl is visible here, so dropping it without `finish`/`fail` likely
/// leaves the connection mid-COPY. Confirm.
pub struct CopyWriter<'a> {
    /// Connection the COPY data is written to.
    conn: &'a mut PgConnection,
}
1975
1976impl<'a> CopyWriter<'a> {
1977 pub fn write_data(&mut self, data: &[u8]) -> PgResult<()> {
1979 self.conn.ensure_write_capacity(5 + data.len());
1980 let n = codec::encode_copy_data(&mut self.conn.write_buf, data);
1981 self.conn.flush_write_buf(n)
1982 }
1983
1984 pub fn fail(self, reason: &str) -> PgResult<()> {
1990 self.conn.ensure_write_capacity(6 + reason.len());
1991 let n = codec::encode_copy_fail(&mut self.conn.write_buf, reason);
1992 self.conn.flush_write_buf(n)?;
1993
1994 loop {
1996 self.conn.fill_read_buf(None)?;
1997 while let Some(msg_len) =
1998 codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])?
1999 {
2000 let header = codec::decode_header(&self.conn.read_buf)
2001 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
2002 match header.tag {
2003 BackendTag::ErrorResponse => {
2004 self.conn.consume_read(msg_len);
2006 }
2007 BackendTag::ReadyForQuery => {
2008 let body = &self.conn.read_buf[5..msg_len];
2009 self.conn.tx_status = TransactionStatus::from(body[0]);
2010 self.conn.consume_read(msg_len);
2011 return Ok(());
2012 }
2013 _ => {
2014 self.conn.consume_read(msg_len);
2015 }
2016 }
2017 }
2018 }
2019 }
2020
2021 pub fn write_row(&mut self, columns: &[&str]) -> PgResult<()> {
2023 let line = columns.join("\t") + "\n";
2024 self.write_data(line.as_bytes())
2025 }
2026
2027 pub fn finish(self) -> PgResult<u64> {
2029 let n = codec::encode_copy_done(&mut self.conn.write_buf);
2030 self.conn.flush_write_buf(n)?;
2031
2032 loop {
2034 self.conn.fill_read_buf(None)?;
2035 while let Some(msg_len) =
2036 codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])?
2037 {
2038 let header = codec::decode_header(&self.conn.read_buf)
2039 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
2040 let body = &self.conn.read_buf[5..msg_len];
2041 match header.tag {
2042 BackendTag::CommandComplete => {
2043 let (tag, rows_affected) = extract_command_complete(body);
2044 self.conn.last_command_tag = tag;
2045 self.conn.last_affected_rows = rows_affected;
2046 }
2047 BackendTag::ReadyForQuery => {
2048 self.conn.tx_status = TransactionStatus::from(body[0]);
2049 self.conn.consume_read(msg_len);
2050 return Ok(self.conn.last_affected_rows);
2051 }
2052 BackendTag::ErrorResponse => {
2053 let err = self.conn.parse_error(body);
2054 self.conn.consume_read(msg_len);
2055 return Err(err);
2056 }
2057 _ => {}
2058 }
2059 self.conn.consume_read(msg_len);
2060 }
2061 }
2062 }
2063}
2064
/// Receives COPY data sent by the server (`CopyData` messages) until the
/// operation ends with `ReadyForQuery`.
///
/// NOTE(review): this presumably wraps a connection already in COPY-out mode
/// (after `CopyOutResponse`); construction happens outside this section.
pub struct CopyReader<'a> {
    /// Connection the COPY data is read from.
    conn: &'a mut PgConnection,
    /// Set once `ReadyForQuery` or an `ErrorResponse` has been seen. After
    /// that, `read_data` returns `Ok(None)` without reading.
    done: bool,
}
2072
2073impl<'a> CopyReader<'a> {
2074 pub fn read_data(&mut self) -> PgResult<Option<Vec<u8>>> {
2077 if self.done {
2078 return Ok(None);
2079 }
2080
2081 loop {
2082 if codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])?.is_none() {
2088 self.conn.fill_read_buf(None)?;
2089 }
2090
2091 while let Some(msg_len) =
2092 codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])?
2093 {
2094 let header = codec::decode_header(&self.conn.read_buf)
2095 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
2096 let body = &self.conn.read_buf[5..msg_len];
2097
2098 match header.tag {
2099 BackendTag::CopyData => {
2100 let data = body.to_vec();
2101 self.conn.consume_read(msg_len);
2102 return Ok(Some(data));
2103 }
2104 BackendTag::CopyDone => {
2105 self.conn.consume_read(msg_len);
2106 }
2108 BackendTag::CommandComplete => {
2109 let (tag, rows_affected) = extract_command_complete(body);
2110 self.conn.last_command_tag = tag;
2111 self.conn.last_affected_rows = rows_affected;
2112 self.conn.consume_read(msg_len);
2113 }
2114 BackendTag::ReadyForQuery => {
2115 self.conn.tx_status = TransactionStatus::from(body[0]);
2116 self.conn.consume_read(msg_len);
2117 self.done = true;
2118 return Ok(None);
2119 }
2120 BackendTag::ErrorResponse => {
2121 let err = self.conn.parse_error(body);
2122 self.conn.consume_read(msg_len);
2123 self.done = true;
2124 return Err(err);
2125 }
2126 _ => {
2127 self.conn.consume_read(msg_len);
2128 }
2129 }
2130 }
2131 }
2132 }
2133
2134 pub fn read_all(&mut self) -> PgResult<Vec<u8>> {
2136 let mut result = Vec::new();
2137 while let Some(chunk) = self.read_data()? {
2138 result.extend_from_slice(&chunk);
2139 }
2140 Ok(result)
2141 }
2142
2143 pub fn is_done(&self) -> bool {
2145 self.done
2146 }
2147}
2148
#[cfg(test)]
mod tests {
    // Unit tests for `PgConfig` construction and URL parsing, and for the
    // `Notification` value type. None of these tests open a connection.
    use super::*;

    // --- PgConfig::new / with_socket_dir ---

    #[test]
    fn test_pgconfig_new_fields() {
        let cfg = PgConfig::new("db.example.com", 5432, "alice", "s3cret", "mydb");
        assert_eq!(cfg.host, "db.example.com");
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.user, "alice");
        assert_eq!(cfg.password, "s3cret");
        assert_eq!(cfg.database, "mydb");
        assert!(cfg.socket_dir.is_none());
    }

    #[test]
    fn test_pgconfig_new_custom_port() {
        let cfg = PgConfig::new("host", 9999, "u", "p", "d");
        assert_eq!(cfg.port, 9999);
    }

    #[test]
    fn test_pgconfig_with_socket_dir_sets_field() {
        let cfg =
            PgConfig::new("localhost", 5432, "u", "p", "d").with_socket_dir("/var/run/postgresql");
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
    }

    #[test]
    fn test_pgconfig_clone_preserves_all_fields() {
        let cfg = PgConfig::new("h", 1234, "u", "p", "db").with_socket_dir("/tmp");
        let cloned = cfg.clone();
        assert_eq!(cloned.host, "h");
        assert_eq!(cloned.port, 1234);
        assert_eq!(cloned.user, "u");
        assert_eq!(cloned.password, "p");
        assert_eq!(cloned.database, "db");
        assert_eq!(cloned.socket_dir, Some("/tmp".to_string()));
    }

    #[test]
    fn test_pgconfig_debug_contains_host() {
        let cfg = PgConfig::new("myhost", 5432, "u", "p", "d");
        let s = format!("{:?}", cfg);
        assert!(s.contains("myhost"), "Debug must include host: {}", s);
    }

    // --- PgConfig::from_url: accepted forms ---

    #[test]
    fn test_from_url_basic_postgres_scheme() {
        let cfg = PgConfig::from_url("postgres://bob:hunter2@dbhost:5432/appdb").unwrap();
        assert_eq!(cfg.host, "dbhost");
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.user, "bob");
        assert_eq!(cfg.password, "hunter2");
        assert_eq!(cfg.database, "appdb");
        assert!(cfg.socket_dir.is_none());
    }

    #[test]
    fn test_from_url_postgresql_scheme() {
        let cfg = PgConfig::from_url("postgresql://u:p@host:5432/db").unwrap();
        assert_eq!(cfg.host, "host");
        assert_eq!(cfg.user, "u");
    }

    #[test]
    fn test_from_url_default_port() {
        let cfg = PgConfig::from_url("postgres://u:p@myhost/mydb").unwrap();
        // With no explicit port, the URL parser must fall back to 5432.
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.host, "myhost");
    }

    #[test]
    fn test_from_url_no_password() {
        let cfg = PgConfig::from_url("postgres://alice@host:5432/db").unwrap();
        // A userinfo section without ':' means an empty password.
        assert_eq!(cfg.user, "alice");
        assert_eq!(cfg.password, "");
    }

    #[test]
    fn test_from_url_custom_port() {
        let cfg = PgConfig::from_url("postgres://u:p@host:9000/db").unwrap();
        assert_eq!(cfg.port, 9000);
    }

    #[test]
    fn test_from_url_unix_socket_query_param() {
        let cfg = PgConfig::from_url("postgres://u:p@/db?host=/var/run/postgresql").unwrap();
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
        assert_eq!(cfg.database, "db");
    }

    #[test]
    fn test_from_url_unix_socket_percent_encoded() {
        // A percent-encoded path in the host position selects a Unix socket
        // directory.
        let cfg = PgConfig::from_url("postgres://u:p@%2Fvar%2Frun%2Fpostgresql/db").unwrap();
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
        assert_eq!(cfg.database, "db");
    }

    // --- PgConfig::from_url: rejected forms ---

    #[test]
    fn test_from_url_invalid_scheme_errors() {
        let result = PgConfig::from_url("mysql://u:p@host/db");
        assert!(result.is_err(), "Non-postgres scheme must fail");
    }

    #[test]
    fn test_from_url_missing_at_symbol_errors() {
        let result = PgConfig::from_url("postgres://no-at-sign/db");
        assert!(result.is_err(), "URL without @ must fail");
    }

    #[test]
    fn test_from_url_missing_database_errors() {
        let result = PgConfig::from_url("postgres://u:p@host");
        // A database path segment is mandatory.
        assert!(result.is_err(), "URL without database must fail");
    }

    #[test]
    fn test_from_url_invalid_port_errors() {
        let result = PgConfig::from_url("postgres://u:p@host:notaport/db");
        assert!(result.is_err(), "Non-numeric port must fail");
    }

    #[test]
    fn test_from_url_empty_string_errors() {
        let result = PgConfig::from_url("");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_url_special_chars_in_password() {
        // Smoke test only: checks that parsing a percent-encoded '@' in the
        // password does not panic. Neither the Ok/Err outcome nor the decoded
        // password is asserted.
        let cfg = PgConfig::from_url("postgres://user:p%40ss@host:5432/db");
        let _ = cfg;
    }

    // --- Notification ---

    #[test]
    fn test_notification_fields() {
        let n = Notification {
            process_id: 12345,
            channel: "my_channel".to_string(),
            payload: "hello world".to_string(),
        };
        assert_eq!(n.process_id, 12345);
        assert_eq!(n.channel, "my_channel");
        assert_eq!(n.payload, "hello world");
    }

    #[test]
    fn test_notification_clone() {
        let n = Notification {
            process_id: 42,
            channel: "ch".to_string(),
            payload: "pay".to_string(),
        };
        let n2 = n.clone();
        assert_eq!(n2.process_id, n.process_id);
        assert_eq!(n2.channel, n.channel);
        assert_eq!(n2.payload, n.payload);
    }

    #[test]
    fn test_notification_debug() {
        let n = Notification {
            process_id: 1,
            channel: "c".to_string(),
            payload: "p".to_string(),
        };
        let s = format!("{:?}", n);
        assert!(
            s.contains("process_id"),
            "Debug must include process_id: {}",
            s
        );
    }
}