1use std::collections::VecDeque;
20use std::io::{Read, Write};
21use std::net::TcpStream;
22#[cfg(unix)]
23use std::os::unix::io::AsRawFd;
24#[cfg(unix)]
25use std::os::unix::net::UnixStream;
26use std::rc::Rc;
27use std::time::{Duration, Instant};
28
29use crate::auth::ScramClient;
30use crate::codec;
31use crate::error::{PgError, PgResult};
32use crate::protocol::*;
33use crate::row::Row;
34use crate::statement::StatementCache;
35#[cfg(feature = "tls")]
36use crate::tls;
37use crate::types::{PgValue, ToSql};
38
/// Default bound on every non-blocking read/write wait; override it per connection
/// with `PgConnection::set_io_timeout` or `PgConnection::connect_with_timeout`.
const DEFAULT_IO_TIMEOUT: Duration = Duration::from_millis(5_000);
41
/// Transport underneath a `PgConnection`: plain TCP, a Unix-domain socket (Unix
/// targets only) or a TLS session (only with the `tls` feature).
///
/// The `Read`/`Write` impls for this type forward every call to the wrapped stream.
enum PgStream {
    /// Plain TCP connection.
    Tcp(TcpStream),
    /// Unix-domain socket connection, used when `PgConfig::socket_dir` is set.
    #[cfg(unix)]
    Unix(UnixStream),
    /// TLS session produced by `tls::negotiate`.
    #[cfg(feature = "tls")]
    Tls(tls::TlsStream),
}
52
53impl Read for PgStream {
54 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
55 match self {
56 PgStream::Tcp(s) => s.read(buf),
57 #[cfg(unix)]
58 PgStream::Unix(s) => s.read(buf),
59 #[cfg(feature = "tls")]
60 PgStream::Tls(s) => s.read(buf),
61 }
62 }
63}
64
65impl Write for PgStream {
66 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
67 match self {
68 PgStream::Tcp(s) => s.write(buf),
69 #[cfg(unix)]
70 PgStream::Unix(s) => s.write(buf),
71 #[cfg(feature = "tls")]
72 PgStream::Tls(s) => s.write(buf),
73 }
74 }
75
76 fn flush(&mut self) -> std::io::Result<()> {
77 match self {
78 PgStream::Tcp(s) => s.flush(),
79 #[cfg(unix)]
80 PgStream::Unix(s) => s.flush(),
81 #[cfg(feature = "tls")]
82 PgStream::Tls(s) => s.flush(),
83 }
84 }
85
86 fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
87 match self {
88 PgStream::Tcp(s) => s.write_all(buf),
89 #[cfg(unix)]
90 PgStream::Unix(s) => s.write_all(buf),
91 #[cfg(feature = "tls")]
92 PgStream::Tls(s) => s.write_all(buf),
93 }
94 }
95}
96
97impl PgStream {
98 fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> {
99 match self {
100 PgStream::Tcp(s) => s.set_nonblocking(nonblocking),
101 #[cfg(unix)]
102 PgStream::Unix(s) => s.set_nonblocking(nonblocking),
103 #[cfg(feature = "tls")]
104 PgStream::Tls(s) => s.set_nonblocking(nonblocking),
105 }
106 }
107
108 #[cfg(unix)]
109 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
110 match self {
111 PgStream::Tcp(s) => s.as_raw_fd(),
112 PgStream::Unix(s) => s.as_raw_fd(),
113 #[cfg(feature = "tls")]
114 PgStream::Tls(s) => s.as_raw_fd(),
115 }
116 }
117
118 #[cfg(feature = "tls")]
121 fn tls_server_cert_hash(&self) -> Option<Vec<u8>> {
122 match self {
123 PgStream::Tls(s) => s.server_cert_hash(),
124 _ => None,
125 }
126 }
127}
128
/// Connection parameters for `PgConnection::connect`.
///
/// `Debug` is implemented by hand so the password never ends up in logs or panic
/// messages; it is rendered as `"<redacted>"`.
#[derive(Clone)]
pub struct PgConfig {
    /// TCP host name or address (ignored when `socket_dir` is set).
    pub host: String,
    /// Server port; also selects the Unix socket file `.s.PGSQL.<port>`.
    pub port: u16,
    /// Role name sent in the startup message.
    pub user: String,
    /// Password used for cleartext, MD5 or SCRAM authentication.
    pub password: String,
    /// Database name sent in the startup message.
    pub database: String,
    /// Directory containing the server's Unix-domain socket; when set, the
    /// connection is made to `<socket_dir>/.s.PGSQL.<port>` instead of TCP.
    pub socket_dir: Option<String>,
    /// TLS negotiation mode (disable / prefer / require).
    #[cfg(feature = "tls")]
    pub ssl_mode: tls::SslMode,
}

impl std::fmt::Debug for PgConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PgConfig");
        out.field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            // Never print the secret itself.
            .field("password", &"<redacted>")
            .field("database", &self.database)
            .field("socket_dir", &self.socket_dir);
        #[cfg(feature = "tls")]
        out.field("ssl_mode", &self.ssl_mode);
        out.finish()
    }
}
145
146impl PgConfig {
147 pub fn new(host: &str, port: u16, user: &str, password: &str, database: &str) -> Self {
148 Self {
149 host: host.to_string(),
150 port,
151 user: user.to_string(),
152 password: password.to_string(),
153 database: database.to_string(),
154 socket_dir: None,
155 #[cfg(feature = "tls")]
156 ssl_mode: tls::SslMode::default(),
157 }
158 }
159
160 pub fn with_socket_dir(mut self, dir: &str) -> Self {
163 self.socket_dir = Some(dir.to_string());
164 self
165 }
166
167 #[cfg(feature = "tls")]
169 pub fn with_ssl_mode(mut self, mode: tls::SslMode) -> Self {
170 self.ssl_mode = mode;
171 self
172 }
173
174 pub fn from_url(url: &str) -> PgResult<Self> {
180 let url = url
181 .strip_prefix("postgres://")
182 .or_else(|| url.strip_prefix("postgresql://"))
183 .ok_or_else(|| PgError::Protocol("Invalid URL scheme".to_string()))?;
184
185 let (userpass, hostdb) = url
187 .split_once('@')
188 .ok_or_else(|| PgError::Protocol("Missing @ in URL".to_string()))?;
189 let (user, password) = userpass.split_once(':').unwrap_or((userpass, ""));
190
191 let (hostdb_part, query_part) = hostdb.split_once('?').unwrap_or((hostdb, ""));
193
194 let (hostport, database) = hostdb_part
195 .split_once('/')
196 .ok_or_else(|| PgError::Protocol("Missing database in URL".to_string()))?;
197
198 let mut socket_dir: Option<String> = None;
200 #[cfg(feature = "tls")]
201 let mut ssl_mode = tls::SslMode::default();
202 if !query_part.is_empty() {
203 for param in query_part.split('&') {
204 if let Some(value) = param.strip_prefix("host=")
205 && value.starts_with('/')
206 {
207 socket_dir = Some(value.to_string());
208 }
209 #[cfg(feature = "tls")]
210 if let Some(value) = param.strip_prefix("sslmode=") {
211 if let Some(mode) = tls::SslMode::parse(value) {
212 ssl_mode = mode;
213 }
214 }
215 }
216 }
217
218 let decoded_host = percent_decode(hostport);
220 let is_unix_path = decoded_host.starts_with('/');
221 if is_unix_path {
222 socket_dir = Some(decoded_host);
223 }
224
225 let (host, port) = if socket_dir.is_some() {
226 let port_str = if hostport.is_empty() || is_unix_path {
228 "5432"
229 } else {
230 hostport.rsplit_once(':').map(|(_, p)| p).unwrap_or("5432")
231 };
232 let port: u16 = port_str
233 .parse()
234 .map_err(|_| PgError::Protocol("Invalid port".to_string()))?;
235 ("localhost".to_string(), port)
236 } else {
237 let (h, port_str) = hostport.split_once(':').unwrap_or((hostport, "5432"));
238 let port: u16 = port_str
239 .parse()
240 .map_err(|_| PgError::Protocol("Invalid port".to_string()))?;
241 (h.to_string(), port)
242 };
243
244 Ok(Self {
245 host,
246 port,
247 user: user.to_string(),
248 password: password.to_string(),
249 database: database.to_string(),
250 socket_dir,
251 #[cfg(feature = "tls")]
252 ssl_mode,
253 })
254 }
255}
256
257fn percent_decode(input: &str) -> String {
259 let mut result = String::with_capacity(input.len());
260 let bytes = input.as_bytes();
261 let mut i = 0;
262 while i < bytes.len() {
263 if bytes[i] == b'%'
264 && i + 2 < bytes.len()
265 && let (Some(hi), Some(lo)) = (hex_digit(bytes[i + 1]), hex_digit(bytes[i + 2]))
266 {
267 result.push((hi << 4 | lo) as char);
268 i += 3;
269 continue;
270 }
271 result.push(bytes[i] as char);
272 i += 1;
273 }
274 result
275}
276
/// Value of an ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`), or `None` for any
/// other byte.
fn hex_digit(b: u8) -> Option<u8> {
    // `char::to_digit` only recognises ASCII digits/letters, so non-ASCII bytes
    // (which map to U+0080..U+00FF here) always yield None.
    (b as char).to_digit(16).map(|value| value as u8)
}
285
/// An asynchronous notification (from `NOTIFY`) received from the server.
///
/// Parsed from `NotificationResponse` messages and queued on the connection until
/// it is taken with `poll_notification` or `drain_notifications`. It compares by
/// value, which makes it easy to assert on or deduplicate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Notification {
    /// Process id of the backend that sent the notification.
    pub process_id: i32,
    /// Channel the notification was sent on.
    pub channel: String,
    /// Payload string attached to the notification.
    pub payload: String,
}
296
/// Boxed callback registered through `PgConnection::set_notice_handler`.
///
/// NOTE(review): presumably invoked by `dispatch_notice` for `NoticeResponse`
/// messages, with the three `&str` arguments being severity, SQLSTATE code and
/// message text — confirm against `dispatch_notice`, which is not shown here.
type NoticeHandler = Box<dyn Fn(&str, &str, &str) + Send + Sync>;
299
/// A synchronous connection to a PostgreSQL server.
///
/// Created by `PgConnection::connect`, which runs the startup/authentication
/// handshake on a blocking socket and then switches it to non-blocking mode, after
/// which each read/write wait is bounded by `io_timeout`.
pub struct PgConnection {
    /// Underlying transport (TCP, Unix socket or TLS).
    stream: PgStream,
    /// Receive buffer; bytes `[..read_pos]` have been read but not yet consumed.
    read_buf: Vec<u8>,
    /// Scratch buffer that outgoing messages are encoded into (from offset 0).
    write_buf: Vec<u8>,
    /// Number of valid, unconsumed bytes at the front of `read_buf`.
    read_pos: usize,
    /// Transaction status from the most recent `ReadyForQuery`.
    tx_status: TransactionStatus,
    /// Client-side record of prepared statements used by the extended-protocol queries.
    stmt_cache: StatementCache,
    /// Backend process id from `BackendKeyData`.
    process_id: i32,
    /// Secret key from `BackendKeyData`.
    secret_key: i32,
    /// `ParameterStatus` name/value pairs received during startup.
    server_params: Vec<(String, String)>,
    /// Notifications received but not yet handed to the caller.
    notifications: VecDeque<Notification>,
    /// Row count from the most recent `CommandComplete`.
    last_affected_rows: u64,
    /// Command tag from the most recent `CommandComplete`.
    last_command_tag: String,
    /// Whether the socket is in non-blocking mode; selects the poll-based I/O paths.
    nonblocking: bool,
    /// Upper bound on each read/write wait while in non-blocking mode.
    io_timeout: Duration,
    /// Optional callback set by `set_notice_handler`.
    notice_handler: Option<NoticeHandler>,
    /// Set once an I/O failure has left the connection unusable.
    broken: bool,
}
331
332impl PgConnection {
333 pub fn connect(config: &PgConfig) -> PgResult<Self> {
336 let stream = if let Some(ref socket_dir) = config.socket_dir {
337 #[cfg(unix)]
339 {
340 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, config.port);
341 let unix_stream = UnixStream::connect(&socket_path).map_err(PgError::Io)?;
342 PgStream::Unix(unix_stream)
343 }
344 #[cfg(not(unix))]
345 {
346 let _ = socket_dir;
347 return Err(PgError::Protocol(
348 "Unix domain sockets are not supported on this platform".to_string(),
349 ));
350 }
351 } else {
352 let addr = format!("{}:{}", config.host, config.port);
353 let tcp = TcpStream::connect(&addr).map_err(PgError::Io)?;
354 let _ = tcp.set_nodelay(true);
356
357 #[cfg(feature = "tls")]
359 {
360 match config.ssl_mode {
361 tls::SslMode::Disable => PgStream::Tcp(tcp),
362 tls::SslMode::Prefer => match tls::negotiate(tcp, &config.host) {
363 Ok(tls::TlsNegotiateResult::Tls(tls_stream)) => PgStream::Tls(tls_stream),
364 Ok(tls::TlsNegotiateResult::Rejected(tcp)) => PgStream::Tcp(tcp),
365 Err(_) => {
366 let tcp = TcpStream::connect(&addr).map_err(PgError::Io)?;
368 let _ = tcp.set_nodelay(true);
369 PgStream::Tcp(tcp)
370 }
371 },
372 tls::SslMode::Require => match tls::negotiate(tcp, &config.host)? {
373 tls::TlsNegotiateResult::Tls(tls_stream) => PgStream::Tls(tls_stream),
374 tls::TlsNegotiateResult::Rejected(_) => {
375 return Err(PgError::Protocol(
376 "Server does not support SSL (sslmode=require)".to_string(),
377 ));
378 }
379 },
380 }
381 }
382
383 #[cfg(not(feature = "tls"))]
384 PgStream::Tcp(tcp)
385 };
386
387 let mut conn = Self {
388 stream,
389 read_buf: vec![0u8; 64 * 1024], write_buf: vec![0u8; 64 * 1024], read_pos: 0,
392 tx_status: TransactionStatus::Idle,
393 stmt_cache: StatementCache::new(),
394 process_id: 0,
395 secret_key: 0,
396 server_params: Vec::new(),
397 notifications: VecDeque::new(),
398 last_affected_rows: 0,
399 last_command_tag: String::new(),
400 nonblocking: false,
401 io_timeout: DEFAULT_IO_TIMEOUT,
402 notice_handler: None,
403 broken: false,
404 };
405
406 conn.startup(config)?;
407
408 conn.stream.set_nonblocking(true).map_err(PgError::Io)?;
410 conn.nonblocking = true;
411
412 Ok(conn)
413 }
414
415 pub fn connect_with_timeout(config: &PgConfig, timeout: Duration) -> PgResult<Self> {
417 let mut conn = Self::connect(config)?;
418 conn.io_timeout = timeout;
419 Ok(conn)
420 }
421
422 pub fn set_io_timeout(&mut self, timeout: Duration) {
424 self.io_timeout = timeout;
425 }
426
427 pub fn io_timeout(&self) -> Duration {
429 self.io_timeout
430 }
431
432 pub fn set_notice_handler<F>(&mut self, handler: F)
444 where
445 F: Fn(&str, &str, &str) + Send + Sync + 'static,
446 {
447 self.notice_handler = Some(Box::new(handler));
448 }
449
450 pub fn clear_notice_handler(&mut self) {
452 self.notice_handler = None;
453 }
454
455 pub fn set_statement_cache_capacity(&mut self, capacity: usize) {
457 self.stmt_cache.set_max_capacity(capacity);
458 }
459
460 #[cfg(unix)]
463 pub fn raw_fd(&self) -> std::os::unix::io::RawFd {
464 self.stream.as_raw_fd()
465 }
466
467 pub fn is_nonblocking(&self) -> bool {
469 self.nonblocking
470 }
471
472 pub fn set_nonblocking(&mut self, nonblocking: bool) -> PgResult<()> {
474 self.stream
475 .set_nonblocking(nonblocking)
476 .map_err(PgError::Io)?;
477 self.nonblocking = nonblocking;
478 Ok(())
479 }
480
481 fn startup(&mut self, config: &PgConfig) -> PgResult<()> {
483 self.ensure_write_capacity(512);
485 let n = codec::encode_startup(&mut self.write_buf, &config.user, &config.database, &[]);
486 self.stream
487 .write_all(&self.write_buf[..n])
488 .map_err(PgError::Io)?;
489
490 loop {
492 self.fill_read_buf(None)?;
493
494 while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
495 let header = codec::decode_header(&self.read_buf)
496 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
497 let body = &self.read_buf[5..msg_len];
498
499 match header.tag {
500 BackendTag::AuthenticationRequest => {
501 let auth_type = codec::read_i32(&self.read_buf, 5);
502 match AuthType::from_i32(auth_type) {
503 Some(AuthType::Ok) => {
504 }
506 Some(AuthType::CleartextPassword) => {
507 let n =
508 codec::encode_password(&mut self.write_buf, &config.password);
509 self.stream
510 .write_all(&self.write_buf[..n])
511 .map_err(PgError::Io)?;
512 }
513 Some(AuthType::SASLInit) => {
514 #[cfg(feature = "tls")]
516 let (mut scram, mechanism) =
517 if let Some(cb_data) = self.stream.tls_server_cert_hash() {
518 (
519 ScramClient::new_with_channel_binding(
520 &config.user,
521 &config.password,
522 cb_data,
523 ),
524 "SCRAM-SHA-256-PLUS",
525 )
526 } else {
527 (
528 ScramClient::new(&config.user, &config.password),
529 "SCRAM-SHA-256",
530 )
531 };
532 #[cfg(not(feature = "tls"))]
533 let (mut scram, mechanism) = (
534 ScramClient::new(&config.user, &config.password),
535 "SCRAM-SHA-256",
536 );
537
538 let client_first = scram.client_first_message();
539 let n = codec::encode_sasl_initial(
540 &mut self.write_buf,
541 mechanism,
542 &client_first,
543 );
544 self.stream
545 .write_all(&self.write_buf[..n])
546 .map_err(PgError::Io)?;
547
548 self.consume_read(msg_len);
549 self.wait_for_sasl_continue(&mut scram, config)?;
550 continue;
553 }
554 Some(AuthType::MD5Password) => {
555 if body.len() < 8 {
557 return Err(PgError::Protocol(
558 "MD5Password message too short".to_string(),
559 ));
560 }
561 let salt: [u8; 4] = [body[4], body[5], body[6], body[7]];
562 let hash = crate::auth::md5_password_hash(
563 &config.user,
564 &config.password,
565 &salt,
566 );
567 let n = codec::encode_password(&mut self.write_buf, &hash);
568 self.stream
569 .write_all(&self.write_buf[..n])
570 .map_err(PgError::Io)?;
571 }
572 _ => {
573 return Err(PgError::Auth(format!(
574 "Unsupported auth type: {}",
575 auth_type
576 )));
577 }
578 }
579 }
580 BackendTag::ParameterStatus => {
581 let (name, consumed) = codec::read_cstring(body, 0);
582 let (value, _) = codec::read_cstring(body, consumed);
583 self.server_params
584 .push((name.to_string(), value.to_string()));
585 }
586 BackendTag::BackendKeyData => {
587 self.process_id = codec::read_i32(body, 0);
588 self.secret_key = codec::read_i32(body, 4);
589 }
590 BackendTag::ReadyForQuery => {
591 self.tx_status = TransactionStatus::from(body[0]);
592 self.consume_read(msg_len);
593 return Ok(()); }
595 BackendTag::ErrorResponse => {
596 let fields = codec::parse_error_fields(body);
597 return Err(PgError::from_fields(&fields));
598 }
599 _ => {
600 }
602 }
603 self.consume_read(msg_len);
604 }
605 }
606 }
607
608 fn wait_for_sasl_continue(
610 &mut self,
611 scram: &mut ScramClient,
612 _config: &PgConfig,
613 ) -> PgResult<()> {
614 loop {
615 self.fill_read_buf(None)?;
616
617 while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
618 let header = codec::decode_header(&self.read_buf)
619 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
620 let body = &self.read_buf[5..msg_len].to_vec();
621
622 match header.tag {
623 BackendTag::AuthenticationRequest => {
624 let auth_type = codec::read_i32(&self.read_buf, 5);
625 match AuthType::from_i32(auth_type) {
626 Some(AuthType::SASLContinue) => {
627 let server_first = &body[4..];
628 let client_final = scram
629 .process_server_first(server_first)
630 .map_err(PgError::Auth)?;
631
632 let n =
633 codec::encode_sasl_response(&mut self.write_buf, &client_final);
634 self.stream
635 .write_all(&self.write_buf[..n])
636 .map_err(PgError::Io)?;
637 }
638 Some(AuthType::SASLFinal) => {
639 let server_final = &body[4..];
640 scram
641 .verify_server_final(server_final)
642 .map_err(PgError::Auth)?;
643 }
644 Some(AuthType::Ok) => {
645 self.consume_read(msg_len);
646 return Ok(());
647 }
648 _ => {
649 return Err(PgError::Auth(
650 "Unexpected auth message during SASL".to_string(),
651 ));
652 }
653 }
654 }
655 _ => {
656 }
658 }
659 self.consume_read(msg_len);
660 }
661 }
662 }
663
664 pub fn query_simple(&mut self, sql: &str) -> PgResult<Vec<Row>> {
668 self.ensure_write_capacity(5 + sql.len());
669 let n = codec::encode_query(&mut self.write_buf, sql);
670 self.flush_write_buf(n)?;
671 self.read_query_results()
672 }
673
674 pub fn query(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Vec<Row>> {
677 let stmt = self.stmt_cache.get_or_create(sql);
678
679 let estimated = 10 + sql.len() + (params.len() * 256);
681 self.ensure_write_capacity(estimated);
682
683 let mut pos = 0;
684
685 if stmt.is_new {
686 let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
688 pos += n;
689
690 let n = codec::encode_describe(
692 &mut self.write_buf[pos..],
693 DescribeTarget::Statement,
694 &stmt.name,
695 );
696 pos += n;
697 }
698
699 let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
701 let param_formats: Vec<i16> = pg_values
702 .iter()
703 .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
704 .collect();
705 let param_values: Vec<Option<Vec<u8>>> = pg_values
706 .iter()
707 .zip(param_formats.iter())
708 .map(|(v, &fmt)| {
709 if fmt == 1 {
710 v.to_binary_bytes()
711 } else {
712 v.to_text_bytes()
713 }
714 })
715 .collect();
716 let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
717 let n = codec::encode_bind(
718 &mut self.write_buf[pos..],
719 "", &stmt.name,
721 ¶m_formats,
722 ¶m_refs,
723 &[1], );
725 pos += n;
726
727 let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
729 pos += n;
730
731 let n = codec::encode_sync(&mut self.write_buf[pos..]);
733 pos += n;
734
735 self.flush_write_buf(pos)?;
736
737 let rows = self.read_extended_results(sql, &stmt.name, stmt.is_new, stmt.columns)?;
739 Ok(rows)
740 }
741
742 pub fn query_one(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Row> {
749 let stmt = self.stmt_cache.get_or_create(sql);
750
751 let estimated = 10 + sql.len() + (params.len() * 256);
752 self.ensure_write_capacity(estimated);
753
754 let mut pos = 0;
755
756 if stmt.is_new {
757 let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
758 pos += n;
759 let n = codec::encode_describe(
760 &mut self.write_buf[pos..],
761 DescribeTarget::Statement,
762 &stmt.name,
763 );
764 pos += n;
765 }
766
767 let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
768 let param_formats: Vec<i16> = pg_values
769 .iter()
770 .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
771 .collect();
772 let param_values: Vec<Option<Vec<u8>>> = pg_values
773 .iter()
774 .zip(param_formats.iter())
775 .map(|(v, &fmt)| {
776 if fmt == 1 {
777 v.to_binary_bytes()
778 } else {
779 v.to_text_bytes()
780 }
781 })
782 .collect();
783 let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
784 let n = codec::encode_bind(
785 &mut self.write_buf[pos..],
786 "",
787 &stmt.name,
788 ¶m_formats,
789 ¶m_refs,
790 &[1],
791 );
792 pos += n;
793
794 let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
797 pos += n;
798
799 let n = codec::encode_sync(&mut self.write_buf[pos..]);
800 pos += n;
801
802 self.flush_write_buf(pos)?;
803
804 self.read_extended_result_one(sql, &stmt.name, stmt.is_new, stmt.columns)
805 }
806
807 pub fn query_opt(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Option<Row>> {
810 let stmt = self.stmt_cache.get_or_create(sql);
811
812 let estimated = 10 + sql.len() + (params.len() * 256);
813 self.ensure_write_capacity(estimated);
814
815 let mut pos = 0;
816
817 if stmt.is_new {
818 let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
819 pos += n;
820 let n = codec::encode_describe(
821 &mut self.write_buf[pos..],
822 DescribeTarget::Statement,
823 &stmt.name,
824 );
825 pos += n;
826 }
827
828 let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
829 let param_formats: Vec<i16> = pg_values
830 .iter()
831 .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
832 .collect();
833 let param_values: Vec<Option<Vec<u8>>> = pg_values
834 .iter()
835 .zip(param_formats.iter())
836 .map(|(v, &fmt)| {
837 if fmt == 1 {
838 v.to_binary_bytes()
839 } else {
840 v.to_text_bytes()
841 }
842 })
843 .collect();
844 let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
845 let n = codec::encode_bind(
846 &mut self.write_buf[pos..],
847 "",
848 &stmt.name,
849 ¶m_formats,
850 ¶m_refs,
851 &[1],
852 );
853 pos += n;
854
855 let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
856 pos += n;
857
858 let n = codec::encode_sync(&mut self.write_buf[pos..]);
859 pos += n;
860
861 self.flush_write_buf(pos)?;
862
863 self.read_extended_result_opt(sql, &stmt.name, stmt.is_new, stmt.columns)
864 }
865
866 pub fn execute(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<u64> {
869 let _rows = self.query(sql, params)?;
870 Ok(self.last_affected_rows)
871 }
872
873 pub fn begin(&mut self) -> PgResult<()> {
877 self.query_simple("BEGIN")?;
878 Ok(())
879 }
880
881 pub fn commit(&mut self) -> PgResult<()> {
883 self.query_simple("COMMIT")?;
884 Ok(())
885 }
886
887 pub fn rollback(&mut self) -> PgResult<()> {
889 self.query_simple("ROLLBACK")?;
890 Ok(())
891 }
892
893 pub fn savepoint(&mut self, name: &str) -> PgResult<()> {
895 self.query_simple(&format!("SAVEPOINT {}", name))?;
896 Ok(())
897 }
898
899 pub fn rollback_to(&mut self, name: &str) -> PgResult<()> {
901 self.query_simple(&format!("ROLLBACK TO SAVEPOINT {}", name))?;
902 Ok(())
903 }
904
905 pub fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
907 self.query_simple(&format!("RELEASE SAVEPOINT {}", name))?;
908 Ok(())
909 }
910
911 pub fn transaction<F, T>(&mut self, f: F) -> PgResult<T>
926 where
927 F: FnOnce(&mut Transaction<'_>) -> PgResult<T>,
928 {
929 self.begin()?;
930 let mut tx = Transaction {
931 conn: self,
932 finished: false,
933 savepoint_name: None,
934 savepoint_counter: 0,
935 };
936 match f(&mut tx) {
937 Ok(val) => {
938 tx.commit()?;
939 Ok(val)
940 }
941 Err(e) => {
942 let _ = tx.rollback();
944 Err(e)
945 }
946 }
947 }
948
949 pub fn copy_in(&mut self, sql: &str) -> PgResult<CopyWriter<'_>> {
953 let n = codec::encode_query(&mut self.write_buf, sql);
954 #[allow(clippy::unnecessary_to_owned)]
955 self.write_all(&self.write_buf[..n].to_vec())?;
956
957 loop {
959 self.fill_read_buf(None)?;
960 let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) else {
961 continue;
962 };
963 let header = codec::decode_header(&self.read_buf)
964 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
965 match header.tag {
966 BackendTag::CopyInResponse => {
967 self.consume_read(msg_len);
968 return Ok(CopyWriter { conn: self });
969 }
970 BackendTag::ErrorResponse => {
971 let body = &self.read_buf[5..msg_len];
972 return Err(self.parse_error(body));
973 }
974 _ => {
975 self.consume_read(msg_len);
976 }
977 }
978 }
979 }
980
981 pub fn copy_out(&mut self, sql: &str) -> PgResult<CopyReader<'_>> {
984 let n = codec::encode_query(&mut self.write_buf, sql);
985 #[allow(clippy::unnecessary_to_owned)]
986 self.write_all(&self.write_buf[..n].to_vec())?;
987
988 loop {
990 self.fill_read_buf(None)?;
991 let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) else {
992 continue;
993 };
994 let header = codec::decode_header(&self.read_buf)
995 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
996 match header.tag {
997 BackendTag::CopyOutResponse => {
998 self.consume_read(msg_len);
999 return Ok(CopyReader {
1000 conn: self,
1001 done: false,
1002 });
1003 }
1004 BackendTag::ErrorResponse => {
1005 let body = &self.read_buf[5..msg_len];
1006 return Err(self.parse_error(body));
1007 }
1008 _ => {
1009 self.consume_read(msg_len);
1010 }
1011 }
1012 }
1013 }
1014
1015 pub fn listen(&mut self, channel: &str) -> PgResult<()> {
1019 self.query_simple(&format!("LISTEN {}", channel))?;
1020 Ok(())
1021 }
1022
1023 pub fn notify(&mut self, channel: &str, payload: &str) -> PgResult<()> {
1025 self.query_simple(&format!("NOTIFY {}, '{}'", channel, payload))?;
1026 Ok(())
1027 }
1028
1029 pub fn unlisten(&mut self, channel: &str) -> PgResult<()> {
1031 self.query_simple(&format!("UNLISTEN {}", channel))?;
1032 Ok(())
1033 }
1034
1035 pub fn unlisten_all(&mut self) -> PgResult<()> {
1037 self.query_simple("UNLISTEN *")?;
1038 Ok(())
1039 }
1040
1041 pub fn drain_notifications(&mut self) -> Vec<Notification> {
1043 self.notifications.drain(..).collect()
1044 }
1045
1046 pub fn has_notifications(&self) -> bool {
1048 !self.notifications.is_empty()
1049 }
1050
1051 pub fn notification_count(&self) -> usize {
1053 self.notifications.len()
1054 }
1055
1056 pub fn poll_notification(&mut self) -> PgResult<Option<Notification>> {
1060 if let Some(n) = self.notifications.pop_front() {
1062 return Ok(Some(n));
1063 }
1064
1065 self.ensure_read_space();
1067 match self.stream.read(&mut self.read_buf[self.read_pos..]) {
1068 Ok(0) => return Err(PgError::ConnectionClosed),
1069 Ok(n) => {
1070 self.read_pos += n;
1071 while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
1073 let header = codec::decode_header(&self.read_buf).ok_or_else(|| {
1074 PgError::Protocol("Incomplete message header".to_string())
1075 })?;
1076 if header.tag == BackendTag::NotificationResponse {
1077 let body = &self.read_buf[5..msg_len];
1078 let notification = Self::parse_notification(body);
1079 self.notifications.push_back(notification);
1080 }
1081 self.consume_read(msg_len);
1082 }
1083 }
1084 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1085 }
1087 Err(e) => return Err(PgError::Io(e)),
1088 }
1089
1090 Ok(self.notifications.pop_front())
1091 }
1092
1093 pub fn transaction_status(&self) -> TransactionStatus {
1097 self.tx_status
1098 }
1099
1100 pub fn cached_statements(&self) -> usize {
1102 self.stmt_cache.len()
1103 }
1104
1105 pub fn last_affected_rows(&self) -> u64 {
1107 self.last_affected_rows
1108 }
1109
1110 pub fn last_command_tag(&self) -> &str {
1112 &self.last_command_tag
1113 }
1114
1115 pub fn process_id(&self) -> i32 {
1117 self.process_id
1118 }
1119
1120 pub fn secret_key(&self) -> i32 {
1122 self.secret_key
1123 }
1124
1125 pub fn server_params(&self) -> &[(String, String)] {
1127 &self.server_params
1128 }
1129
1130 pub fn server_param(&self, name: &str) -> Option<&str> {
1132 self.server_params
1133 .iter()
1134 .find(|(k, _)| k == name)
1135 .map(|(_, v)| v.as_str())
1136 }
1137
1138 pub fn in_transaction(&self) -> bool {
1140 matches!(
1141 self.tx_status,
1142 TransactionStatus::InTransaction | TransactionStatus::Failed
1143 )
1144 }
1145
1146 pub fn clear_statement_cache(&mut self) {
1152 let _ = self.query_simple("DEALLOCATE ALL");
1153 self.stmt_cache.clear();
1154 }
1155
1156 pub fn is_broken(&self) -> bool {
1160 self.broken
1161 }
1162
1163 pub fn reset(&mut self) -> PgResult<()> {
1169 self.query_simple("DISCARD ALL")?;
1170 self.stmt_cache.clear();
1171 Ok(())
1172 }
1173
1174 pub fn execute_batch(&mut self, sql: &str) -> PgResult<u64> {
1186 self.query_simple(sql)?;
1187 Ok(self.last_affected_rows)
1188 }
1189
1190 pub fn is_alive(&mut self) -> bool {
1192 self.query_simple("SELECT 1").is_ok()
1193 }
1194
1195 pub fn try_fill_read_buf(&mut self) -> PgResult<usize> {
1203 self.ensure_read_space();
1204
1205 match self.stream.read(&mut self.read_buf[self.read_pos..]) {
1206 Ok(0) => {
1207 self.broken = true;
1208 Err(PgError::ConnectionClosed)
1209 }
1210 Ok(n) => {
1211 self.read_pos += n;
1212 Ok(n)
1213 }
1214 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(PgError::WouldBlock),
1215 Err(e) => {
1216 self.broken = true;
1217 Err(PgError::Io(e))
1218 }
1219 }
1220 }
1221
1222 pub fn try_write(&mut self, data: &[u8]) -> PgResult<usize> {
1226 match self.stream.write(data) {
1227 Ok(n) => Ok(n),
1228 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(PgError::WouldBlock),
1229 Err(e) => {
1230 self.broken = true;
1231 Err(PgError::Io(e))
1232 }
1233 }
1234 }
1235
1236 #[cfg(unix)]
1242 fn wait_readable(&self, timeout: Duration) -> PgResult<()> {
1243 let fd = self.stream.as_raw_fd();
1244 let timeout_ms = timeout.as_millis().min(i32::MAX as u128) as i32;
1245 let mut pfd = libc::pollfd {
1246 fd,
1247 events: libc::POLLIN,
1248 revents: 0,
1249 };
1250 let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
1251 if ret < 0 {
1252 let e = std::io::Error::last_os_error();
1253 if e.kind() == std::io::ErrorKind::Interrupted {
1254 return Ok(()); }
1256 return Err(PgError::Io(e));
1257 }
1258 if ret == 0 {
1259 return Err(PgError::Timeout);
1260 }
1261 if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 {
1262 return Err(PgError::ConnectionClosed);
1263 }
1264 Ok(())
1265 }
1266
1267 #[cfg(unix)]
1269 fn wait_writable(&self, timeout: Duration) -> PgResult<()> {
1270 let fd = self.stream.as_raw_fd();
1271 let timeout_ms = timeout.as_millis().min(i32::MAX as u128) as i32;
1272 let mut pfd = libc::pollfd {
1273 fd,
1274 events: libc::POLLOUT,
1275 revents: 0,
1276 };
1277 let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
1278 if ret < 0 {
1279 let e = std::io::Error::last_os_error();
1280 if e.kind() == std::io::ErrorKind::Interrupted {
1281 return Ok(());
1282 }
1283 return Err(PgError::Io(e));
1284 }
1285 if ret == 0 {
1286 return Err(PgError::Timeout);
1287 }
1288 if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 {
1289 return Err(PgError::ConnectionClosed);
1290 }
1291 Ok(())
1292 }
1293
1294 pub fn poll_read(&mut self, timeout: Duration) -> PgResult<usize> {
1300 let start = Instant::now();
1301 loop {
1302 match self.try_fill_read_buf() {
1303 Ok(n) => return Ok(n),
1304 Err(PgError::WouldBlock) => {
1305 let elapsed = start.elapsed();
1306 if elapsed >= timeout {
1307 return Err(PgError::Timeout);
1308 }
1309 #[cfg(unix)]
1310 self.wait_readable(timeout - elapsed)?;
1311 #[cfg(not(unix))]
1312 std::thread::sleep(Duration::from_micros(50));
1313 }
1314 Err(e) => return Err(e),
1315 }
1316 }
1317 }
1318
1319 pub fn poll_write(&mut self, data: &[u8], timeout: Duration) -> PgResult<()> {
1322 let start = Instant::now();
1323 let mut written = 0;
1324 while written < data.len() {
1325 match self.try_write(&data[written..]) {
1326 Ok(n) => written += n,
1327 Err(PgError::WouldBlock) => {
1328 let elapsed = start.elapsed();
1329 if elapsed >= timeout {
1330 return Err(PgError::Timeout);
1331 }
1332 #[cfg(unix)]
1333 self.wait_writable(timeout - elapsed)?;
1334 #[cfg(not(unix))]
1335 std::thread::sleep(Duration::from_micros(50));
1336 }
1337 Err(e) => return Err(e),
1338 }
1339 }
1340 Ok(())
1341 }
1342
1343 fn fill_read_buf(&mut self, min_size: Option<usize>) -> PgResult<()> {
1346 if let Some(min) = min_size {
1347 self.ensure_read_capacity(min);
1348 }
1349
1350 self.ensure_read_space();
1351
1352 if self.nonblocking {
1353 self.poll_read(self.io_timeout)?;
1355 } else {
1356 let n = self
1358 .stream
1359 .read(&mut self.read_buf[self.read_pos..])
1360 .map_err(PgError::Io)?;
1361 if n == 0 {
1362 return Err(PgError::ConnectionClosed);
1363 }
1364 self.read_pos += n;
1365 }
1366 Ok(())
1367 }
1368
1369 fn write_all(&mut self, data: &[u8]) -> PgResult<()> {
1371 if self.nonblocking {
1372 self.poll_write(data, self.io_timeout)
1373 } else {
1374 self.stream.write_all(data).map_err(PgError::Io)
1375 }
1376 }
1377
1378 fn flush_write_buf(&mut self, n: usize) -> PgResult<()> {
1386 if self.nonblocking {
1387 let timeout = self.io_timeout;
1388 let start = Instant::now();
1389 let mut written = 0;
1390 while written < n {
1391 match self.stream.write(&self.write_buf[written..n]) {
1392 Ok(w) => written += w,
1393 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1394 let elapsed = start.elapsed();
1395 if elapsed >= timeout {
1396 return Err(PgError::Timeout);
1397 }
1398 #[cfg(unix)]
1399 self.wait_writable(timeout - elapsed)?;
1400 #[cfg(not(unix))]
1401 std::thread::sleep(Duration::from_micros(50));
1402 }
1403 Err(e) => {
1404 self.broken = true;
1405 return Err(PgError::Io(e));
1406 }
1407 }
1408 }
1409 Ok(())
1410 } else {
1411 self.stream
1412 .write_all(&self.write_buf[..n])
1413 .map_err(PgError::Io)
1414 }
1415 }
1416
1417 fn ensure_read_space(&mut self) {
1419 if self.read_pos == self.read_buf.len() {
1420 if self.read_pos >= 5
1421 && let Some(header) = codec::decode_header(&self.read_buf)
1422 {
1423 let total = 1 + header.length as usize;
1424 self.ensure_read_capacity(total - self.read_pos);
1425 return;
1426 }
1427 self.ensure_read_capacity(8192);
1428 }
1429 }
1430
1431 fn consume_read(&mut self, n: usize) {
1432 self.read_buf.copy_within(n..self.read_pos, 0);
1433 self.read_pos -= n;
1434 }
1435
1436 fn ensure_read_capacity(&mut self, additional: usize) {
1437 if self.read_pos + additional > self.read_buf.len() {
1438 let new_len = (self.read_pos + additional).max(self.read_buf.len() * 2);
1439 self.read_buf.resize(new_len, 0);
1440 }
1441 }
1442
1443 fn ensure_write_capacity(&mut self, additional: usize) {
1444 if additional > self.write_buf.len() {
1445 let new_len = additional.max(self.write_buf.len() * 2);
1446 self.write_buf.resize(new_len, 0);
1447 }
1448 }
1449
    /// Reads backend messages until `ReadyForQuery` and returns every `DataRow` received.
    ///
    /// Each row shares (via `Rc`) the most recent `RowDescription`. If several result sets
    /// arrive, all rows are appended to the same vector. The last `CommandComplete` tag and
    /// affected-row count are stored on the connection, and `ReadyForQuery` updates
    /// `tx_status`. Notifications are queued on `notifications`, and notices go to the notice
    /// handler.
    ///
    /// # Errors
    /// On `ErrorResponse`, drains the stream up to `ReadyForQuery` and then returns the parsed
    /// server error. Read and framing errors are propagated.
    fn read_query_results(&mut self) -> PgResult<Vec<Row>> {
        let mut rows = Vec::new();
        // Placeholder until a RowDescription arrives.
        let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = Rc::new(Vec::new());

        loop {
            // Only block on the socket when no complete message is already buffered.
            if codec::message_complete(&self.read_buf[..self.read_pos]).is_none() {
                self.fill_read_buf(None)?;
            }

            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
                let header = codec::decode_header(&self.read_buf)
                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
                // Skip the 1-byte tag and 4-byte length that prefix every message.
                let body = &self.read_buf[5..msg_len];

                match header.tag {
                    BackendTag::RowDescription => {
                        columns_rc = Rc::new(codec::parse_row_description(body));
                    }
                    BackendTag::DataRow => {
                        let raw_values = codec::parse_data_row(body);
                        rows.push(Row::new(Rc::clone(&columns_rc), raw_values));
                    }
                    BackendTag::CommandComplete => {
                        let (tag, rows_affected) = extract_command_complete(body);
                        self.last_command_tag = tag;
                        self.last_affected_rows = rows_affected;
                    }
                    BackendTag::ReadyForQuery => {
                        // NOTE(review): panics on an empty ReadyForQuery body (malformed server reply).
                        self.tx_status = TransactionStatus::from(body[0]);
                        self.consume_read(msg_len);
                        return Ok(rows);
                    }
                    BackendTag::ErrorResponse => {
                        let err = self.parse_error(body);
                        self.consume_read(msg_len);
                        // Consume everything up to ReadyForQuery so the next command starts clean.
                        self.drain_to_ready()?;
                        return Err(err);
                    }
                    BackendTag::NotificationResponse => {
                        let notification = Self::parse_notification(body);
                        self.notifications.push_back(notification);
                    }
                    BackendTag::EmptyQueryResponse => {}
                    BackendTag::NoticeResponse => {
                        self.dispatch_notice(body);
                    }
                    _ => {}
                }
                self.consume_read(msg_len);
            }
        }
    }
1503
    /// Reads the responses to an extended-protocol exchange for `sql` until `ReadyForQuery`,
    /// collecting every `DataRow`.
    ///
    /// * `sql`: query text. It is the `stmt_cache` key and is attached to server errors.
    /// * `stmt_name`: prepared-statement name, cached under `sql` when `is_new` is true.
    /// * `is_new`: when true, the `RowDescription` / `NoData` reply is recorded in
    ///   `stmt_cache`. Any statement the cache evicts is closed on the server.
    /// * `cached_columns`: column metadata already known. It is used for rows unless a
    ///   `RowDescription` arrives and replaces it.
    ///
    /// # Errors
    /// On `ErrorResponse`, drains to `ReadyForQuery` and returns the server error annotated
    /// with `sql` (see `parse_error_with_context`). Read and framing errors are propagated.
    fn read_extended_results(
        &mut self,
        sql: &str,
        stmt_name: &str,
        is_new: bool,
        cached_columns: Option<Vec<codec::ColumnDesc>>,
    ) -> PgResult<Vec<Row>> {
        let mut rows = Vec::new();
        let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = match cached_columns {
            Some(c) => Rc::new(c),
            None => Rc::new(Vec::new()),
        };

        loop {
            // Only block on the socket when no complete message is already buffered.
            if codec::message_complete(&self.read_buf[..self.read_pos]).is_none() {
                self.fill_read_buf(None)?;
            }

            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
                let header = codec::decode_header(&self.read_buf)
                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
                // Skip the 1-byte tag and 4-byte length that prefix every message.
                let body = &self.read_buf[5..msg_len];

                match header.tag {
                    BackendTag::ParseComplete => {}
                    BackendTag::ParameterDescription => {}
                    BackendTag::RowDescription => {
                        let mut columns = codec::parse_row_description(body);
                        for col in &mut columns {
                            // Rows are decoded as binary regardless of the format code reported
                            // here. NOTE(review): assumes the caller's Bind requested binary
                            // result columns; that request is not visible in this block.
                            col.format_code = FormatCode::Binary;
                        }
                        if is_new
                            && let Some(evicted) = self.stmt_cache.insert(
                                sql,
                                stmt_name.to_string(),
                                0,
                                Some(columns.clone()),
                            )
                        {
                            // The cache dropped an older statement; free it server-side (best effort).
                            self.close_statement_on_server(&evicted.name);
                        }
                        columns_rc = Rc::new(columns);
                    }
                    BackendTag::NoData if is_new => {
                        // Statement returns no rows: cache it without column metadata.
                        if let Some(evicted) =
                            self.stmt_cache.insert(sql, stmt_name.to_string(), 0, None)
                        {
                            self.close_statement_on_server(&evicted.name);
                        }
                    }
                    BackendTag::NoData => {}
                    BackendTag::BindComplete => {}
                    BackendTag::DataRow => {
                        let raw_values = codec::parse_data_row(body);
                        rows.push(Row::new(Rc::clone(&columns_rc), raw_values));
                    }
                    BackendTag::CommandComplete => {
                        let (tag, rows_affected) = extract_command_complete(body);
                        self.last_command_tag = tag;
                        self.last_affected_rows = rows_affected;
                    }
                    BackendTag::ReadyForQuery => {
                        self.tx_status = TransactionStatus::from(body[0]);
                        self.consume_read(msg_len);
                        return Ok(rows);
                    }
                    BackendTag::ErrorResponse => {
                        let err = self.parse_error_with_context(body, sql);
                        self.consume_read(msg_len);
                        // Consume everything up to ReadyForQuery so the next command starts clean.
                        self.drain_to_ready()?;
                        return Err(err);
                    }
                    BackendTag::NotificationResponse => {
                        let notification = Self::parse_notification(body);
                        self.notifications.push_back(notification);
                    }
                    BackendTag::NoticeResponse => {
                        self.dispatch_notice(body);
                    }
                    _ => {}
                }
                self.consume_read(msg_len);
            }
        }
    }
1594
1595 fn read_extended_result_one(
1599 &mut self,
1600 sql: &str,
1601 stmt_name: &str,
1602 is_new: bool,
1603 cached_columns: Option<Vec<codec::ColumnDesc>>,
1604 ) -> PgResult<Row> {
1605 match self.read_extended_result_opt(sql, stmt_name, is_new, cached_columns)? {
1606 Some(row) => Ok(row),
1607 None => Err(PgError::NoRows),
1608 }
1609 }
1610
    /// Reads an extended-protocol response for `sql` until `ReadyForQuery`, keeping only the
    /// first `DataRow`.
    ///
    /// Any further rows are read and discarded, so the stream stays in sync. Statement
    /// caching (`is_new`, `stmt_name`) and `cached_columns` work exactly as in
    /// `read_extended_results`.
    ///
    /// Returns `Ok(None)` when the statement produced no rows.
    ///
    /// # Errors
    /// On `ErrorResponse`, drains to `ReadyForQuery` and returns the server error annotated
    /// with `sql`. Read and framing errors are propagated.
    fn read_extended_result_opt(
        &mut self,
        sql: &str,
        stmt_name: &str,
        is_new: bool,
        cached_columns: Option<Vec<codec::ColumnDesc>>,
    ) -> PgResult<Option<Row>> {
        let mut result: Option<Row> = None;
        let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = match cached_columns {
            Some(c) => Rc::new(c),
            None => Rc::new(Vec::new()),
        };

        loop {
            // Only block on the socket when no complete message is already buffered.
            if codec::message_complete(&self.read_buf[..self.read_pos]).is_none() {
                self.fill_read_buf(None)?;
            }

            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
                let header = codec::decode_header(&self.read_buf)
                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
                // Skip the 1-byte tag and 4-byte length that prefix every message.
                let body = &self.read_buf[5..msg_len];

                match header.tag {
                    BackendTag::ParseComplete => {}
                    BackendTag::ParameterDescription => {}
                    BackendTag::RowDescription => {
                        let mut columns = codec::parse_row_description(body);
                        for col in &mut columns {
                            // NOTE(review): assumes binary result columns were requested by the caller.
                            col.format_code = FormatCode::Binary;
                        }
                        if is_new
                            && let Some(evicted) = self.stmt_cache.insert(
                                sql,
                                stmt_name.to_string(),
                                0,
                                Some(columns.clone()),
                            )
                        {
                            // The cache dropped an older statement; free it server-side (best effort).
                            self.close_statement_on_server(&evicted.name);
                        }
                        columns_rc = Rc::new(columns);
                    }
                    BackendTag::NoData if is_new => {
                        // Statement returns no rows: cache it without column metadata.
                        if let Some(evicted) =
                            self.stmt_cache.insert(sql, stmt_name.to_string(), 0, None)
                        {
                            self.close_statement_on_server(&evicted.name);
                        }
                    }
                    BackendTag::NoData => {}
                    BackendTag::BindComplete => {}
                    BackendTag::DataRow if result.is_none() => {
                        let raw_values = codec::parse_data_row(body);
                        result = Some(Row::new(Rc::clone(&columns_rc), raw_values));
                    }
                    BackendTag::DataRow => {
                        // Extra rows are skipped; only the first one is returned.
                    }
                    BackendTag::CommandComplete => {
                        let (tag, rows_affected) = extract_command_complete(body);
                        self.last_command_tag = tag;
                        self.last_affected_rows = rows_affected;
                    }
                    BackendTag::ReadyForQuery => {
                        self.tx_status = TransactionStatus::from(body[0]);
                        self.consume_read(msg_len);
                        return Ok(result);
                    }
                    BackendTag::ErrorResponse => {
                        let err = self.parse_error_with_context(body, sql);
                        self.consume_read(msg_len);
                        // Consume everything up to ReadyForQuery so the next command starts clean.
                        self.drain_to_ready()?;
                        return Err(err);
                    }
                    BackendTag::NotificationResponse => {
                        let notification = Self::parse_notification(body);
                        self.notifications.push_back(notification);
                    }
                    BackendTag::NoticeResponse => {
                        self.dispatch_notice(body);
                    }
                    _ => {}
                }
                self.consume_read(msg_len);
            }
        }
    }
1704
1705 fn drain_to_ready(&mut self) -> PgResult<()> {
1706 loop {
1707 if codec::message_complete(&self.read_buf[..self.read_pos]).is_none() {
1710 self.fill_read_buf(None)?;
1711 }
1712 while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
1713 let header = codec::decode_header(&self.read_buf)
1714 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1715 if header.tag == BackendTag::ReadyForQuery {
1716 let body = &self.read_buf[5..msg_len];
1717 self.tx_status = TransactionStatus::from(body[0]);
1718 self.consume_read(msg_len);
1719 return Ok(());
1720 }
1721 self.consume_read(msg_len);
1722 }
1723 }
1724 }
1725
1726 fn parse_error(&self, body: &[u8]) -> PgError {
1727 let fields = codec::parse_error_fields(body);
1728 PgError::from_fields(&fields)
1729 }
1730
1731 fn parse_error_with_context(&self, body: &[u8], query: &str) -> PgError {
1733 let fields = codec::parse_error_fields(body);
1734 let mut err = PgError::from_fields(&fields);
1735 if let PgError::Server(ref mut server_err) = err
1736 && server_err.internal_query.is_none()
1737 {
1738 server_err.internal_query = Some(query.to_string());
1739 }
1740 err
1741 }
1742
1743 fn dispatch_notice(&self, body: &[u8]) {
1745 if let Some(ref handler) = self.notice_handler {
1746 let fields = codec::parse_error_fields(body);
1747 let mut severity = "";
1748 let mut code = "";
1749 let mut message = "";
1750 for (field_type, value) in &fields {
1751 match field_type {
1752 b'S' => severity = value,
1753 b'C' => code = value,
1754 b'M' => message = value,
1755 _ => {}
1756 }
1757 }
1758 handler(severity, code, message);
1759 }
1760 }
1761
1762 fn close_statement_on_server(&mut self, name: &str) {
1765 self.ensure_write_capacity(7 + name.len());
1766 let n = codec::encode_close(&mut self.write_buf, CloseTarget::Statement, name);
1767 let _ = self.flush_write_buf(n);
1768 }
1769
1770 fn parse_notification(body: &[u8]) -> Notification {
1776 let process_id = codec::read_i32(body, 0);
1777 let (channel, consumed) = codec::read_cstring(body, 4);
1778 let (payload, _) = codec::read_cstring(body, 4 + consumed);
1779 Notification {
1780 process_id,
1781 channel: channel.to_string(),
1782 payload: payload.to_string(),
1783 }
1784 }
1785}
1786
1787fn extract_command_complete(body: &[u8]) -> (String, u64) {
1791 let (tag, _) = codec::read_cstring(body, 0);
1792 let tag_str = tag.to_string();
1793 let affected_rows = tag
1794 .rsplit(' ')
1795 .next()
1796 .and_then(|s| s.parse::<u64>().ok())
1797 .unwrap_or(0);
1798 (tag_str, affected_rows)
1799}
1800
1801impl Drop for PgConnection {
1802 fn drop(&mut self) {
1803 if self.nonblocking {
1807 let _ = self.stream.set_nonblocking(false);
1808 }
1809 let n = codec::encode_terminate(&mut self.write_buf);
1810 let _ = self.stream.write_all(&self.write_buf[..n]);
1811 }
1812}
1813
1814pub struct Transaction<'a> {
1822 conn: &'a mut PgConnection,
1823 finished: bool,
1824 savepoint_name: Option<String>,
1826 savepoint_counter: u32,
1828}
1829
1830impl<'a> Transaction<'a> {
1831 pub fn commit(&mut self) -> PgResult<()> {
1833 if !self.finished {
1834 self.finished = true;
1835 if let Some(ref name) = self.savepoint_name {
1836 self.conn.release_savepoint(name)
1837 } else {
1838 self.conn.commit()
1839 }
1840 } else {
1841 Ok(())
1842 }
1843 }
1844
1845 pub fn rollback(&mut self) -> PgResult<()> {
1847 if !self.finished {
1848 self.finished = true;
1849 if let Some(ref name) = self.savepoint_name {
1850 self.conn.rollback_to(name)
1851 } else {
1852 self.conn.rollback()
1853 }
1854 } else {
1855 Ok(())
1856 }
1857 }
1858
1859 pub fn transaction<F, T>(&mut self, f: F) -> PgResult<T>
1878 where
1879 F: FnOnce(&mut Transaction<'_>) -> PgResult<T>,
1880 {
1881 self.savepoint_counter += 1;
1882 let sp_name = format!("chopin_sp_{}", self.savepoint_counter);
1883 self.conn.savepoint(&sp_name)?;
1884 let mut nested = Transaction {
1885 conn: self.conn,
1886 finished: false,
1887 savepoint_name: Some(sp_name),
1888 savepoint_counter: 0,
1889 };
1890 match f(&mut nested) {
1891 Ok(val) => {
1892 nested.commit()?;
1893 Ok(val)
1894 }
1895 Err(e) => {
1896 let _ = nested.rollback();
1897 Err(e)
1898 }
1899 }
1900 }
1901
1902 pub fn query_simple(&mut self, sql: &str) -> PgResult<Vec<Row>> {
1904 self.conn.query_simple(sql)
1905 }
1906
1907 pub fn query(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Vec<Row>> {
1909 self.conn.query(sql, params)
1910 }
1911
1912 pub fn query_one(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Row> {
1914 self.conn.query_one(sql, params)
1915 }
1916
1917 pub fn query_opt(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Option<Row>> {
1919 self.conn.query_opt(sql, params)
1920 }
1921
1922 pub fn execute(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<u64> {
1924 self.conn.execute(sql, params)
1925 }
1926
1927 pub fn savepoint(&mut self, name: &str) -> PgResult<()> {
1929 self.conn.savepoint(name)
1930 }
1931
1932 pub fn rollback_to(&mut self, name: &str) -> PgResult<()> {
1934 self.conn.rollback_to(name)
1935 }
1936
1937 pub fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
1939 self.conn.release_savepoint(name)
1940 }
1941
1942 pub fn status(&self) -> TransactionStatus {
1944 self.conn.transaction_status()
1945 }
1946}
1947
1948impl<'a> Drop for Transaction<'a> {
1949 fn drop(&mut self) {
1950 if !self.finished {
1951 if let Some(ref name) = self.savepoint_name {
1953 let _ = self.conn.rollback_to(name);
1954 } else {
1955 let _ = self.conn.rollback();
1956 }
1957 }
1958 }
1959}
1960
1961pub struct CopyWriter<'a> {
1965 conn: &'a mut PgConnection,
1966}
1967
1968impl<'a> CopyWriter<'a> {
1969 pub fn write_data(&mut self, data: &[u8]) -> PgResult<()> {
1971 self.conn.ensure_write_capacity(5 + data.len());
1972 let n = codec::encode_copy_data(&mut self.conn.write_buf, data);
1973 self.conn.flush_write_buf(n)
1974 }
1975
1976 pub fn fail(self, reason: &str) -> PgResult<()> {
1982 self.conn.ensure_write_capacity(6 + reason.len());
1983 let n = codec::encode_copy_fail(&mut self.conn.write_buf, reason);
1984 self.conn.flush_write_buf(n)?;
1985
1986 loop {
1988 self.conn.fill_read_buf(None)?;
1989 while let Some(msg_len) =
1990 codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])
1991 {
1992 let header = codec::decode_header(&self.conn.read_buf)
1993 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1994 match header.tag {
1995 BackendTag::ErrorResponse => {
1996 self.conn.consume_read(msg_len);
1998 }
1999 BackendTag::ReadyForQuery => {
2000 let body = &self.conn.read_buf[5..msg_len];
2001 self.conn.tx_status = TransactionStatus::from(body[0]);
2002 self.conn.consume_read(msg_len);
2003 return Ok(());
2004 }
2005 _ => {
2006 self.conn.consume_read(msg_len);
2007 }
2008 }
2009 }
2010 }
2011 }
2012
2013 pub fn write_row(&mut self, columns: &[&str]) -> PgResult<()> {
2015 let line = columns.join("\t") + "\n";
2016 self.write_data(line.as_bytes())
2017 }
2018
2019 pub fn finish(self) -> PgResult<u64> {
2021 let n = codec::encode_copy_done(&mut self.conn.write_buf);
2022 self.conn.flush_write_buf(n)?;
2023
2024 loop {
2026 self.conn.fill_read_buf(None)?;
2027 while let Some(msg_len) =
2028 codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])
2029 {
2030 let header = codec::decode_header(&self.conn.read_buf)
2031 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
2032 let body = &self.conn.read_buf[5..msg_len];
2033 match header.tag {
2034 BackendTag::CommandComplete => {
2035 let (tag, rows_affected) = extract_command_complete(body);
2036 self.conn.last_command_tag = tag;
2037 self.conn.last_affected_rows = rows_affected;
2038 }
2039 BackendTag::ReadyForQuery => {
2040 self.conn.tx_status = TransactionStatus::from(body[0]);
2041 self.conn.consume_read(msg_len);
2042 return Ok(self.conn.last_affected_rows);
2043 }
2044 BackendTag::ErrorResponse => {
2045 let err = self.conn.parse_error(body);
2046 self.conn.consume_read(msg_len);
2047 return Err(err);
2048 }
2049 _ => {}
2050 }
2051 self.conn.consume_read(msg_len);
2052 }
2053 }
2054 }
2055}
2056
2057pub struct CopyReader<'a> {
2061 conn: &'a mut PgConnection,
2062 done: bool,
2063}
2064
2065impl<'a> CopyReader<'a> {
2066 pub fn read_data(&mut self) -> PgResult<Option<Vec<u8>>> {
2069 if self.done {
2070 return Ok(None);
2071 }
2072
2073 loop {
2074 if codec::message_complete(&self.conn.read_buf[..self.conn.read_pos]).is_none() {
2080 self.conn.fill_read_buf(None)?;
2081 }
2082
2083 while let Some(msg_len) =
2084 codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])
2085 {
2086 let header = codec::decode_header(&self.conn.read_buf)
2087 .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
2088 let body = &self.conn.read_buf[5..msg_len];
2089
2090 match header.tag {
2091 BackendTag::CopyData => {
2092 let data = body.to_vec();
2093 self.conn.consume_read(msg_len);
2094 return Ok(Some(data));
2095 }
2096 BackendTag::CopyDone => {
2097 self.conn.consume_read(msg_len);
2098 }
2100 BackendTag::CommandComplete => {
2101 let (tag, rows_affected) = extract_command_complete(body);
2102 self.conn.last_command_tag = tag;
2103 self.conn.last_affected_rows = rows_affected;
2104 self.conn.consume_read(msg_len);
2105 }
2106 BackendTag::ReadyForQuery => {
2107 self.conn.tx_status = TransactionStatus::from(body[0]);
2108 self.conn.consume_read(msg_len);
2109 self.done = true;
2110 return Ok(None);
2111 }
2112 BackendTag::ErrorResponse => {
2113 let err = self.conn.parse_error(body);
2114 self.conn.consume_read(msg_len);
2115 self.done = true;
2116 return Err(err);
2117 }
2118 _ => {
2119 self.conn.consume_read(msg_len);
2120 }
2121 }
2122 }
2123 }
2124 }
2125
2126 pub fn read_all(&mut self) -> PgResult<Vec<u8>> {
2128 let mut result = Vec::new();
2129 while let Some(chunk) = self.read_data()? {
2130 result.extend_from_slice(&chunk);
2131 }
2132 Ok(result)
2133 }
2134
2135 pub fn is_done(&self) -> bool {
2137 self.done
2138 }
2139}
2140
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pgconfig_new_fields() {
        let cfg = PgConfig::new("db.example.com", 5432, "alice", "s3cret", "mydb");
        assert_eq!(cfg.host, "db.example.com");
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.user, "alice");
        assert_eq!(cfg.password, "s3cret");
        assert_eq!(cfg.database, "mydb");
        assert!(cfg.socket_dir.is_none());
    }

    #[test]
    fn test_pgconfig_new_custom_port() {
        let cfg = PgConfig::new("host", 9999, "u", "p", "d");
        assert_eq!(cfg.port, 9999);
    }

    #[test]
    fn test_pgconfig_with_socket_dir_sets_field() {
        let cfg =
            PgConfig::new("localhost", 5432, "u", "p", "d").with_socket_dir("/var/run/postgresql");
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
    }

    #[test]
    fn test_pgconfig_clone_preserves_all_fields() {
        let cfg = PgConfig::new("h", 1234, "u", "p", "db").with_socket_dir("/tmp");
        let cloned = cfg.clone();
        assert_eq!(cloned.host, "h");
        assert_eq!(cloned.port, 1234);
        assert_eq!(cloned.user, "u");
        assert_eq!(cloned.password, "p");
        assert_eq!(cloned.database, "db");
        assert_eq!(cloned.socket_dir, Some("/tmp".to_string()));
    }

    #[test]
    fn test_pgconfig_debug_contains_host() {
        let cfg = PgConfig::new("myhost", 5432, "u", "p", "d");
        let s = format!("{:?}", cfg);
        assert!(s.contains("myhost"), "Debug must include host: {}", s);
    }

    #[test]
    fn test_from_url_basic_postgres_scheme() {
        let cfg = PgConfig::from_url("postgres://bob:hunter2@dbhost:5432/appdb").unwrap();
        assert_eq!(cfg.host, "dbhost");
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.user, "bob");
        assert_eq!(cfg.password, "hunter2");
        assert_eq!(cfg.database, "appdb");
        assert!(cfg.socket_dir.is_none());
    }

    #[test]
    fn test_from_url_postgresql_scheme() {
        let cfg = PgConfig::from_url("postgresql://u:p@host:5432/db").unwrap();
        assert_eq!(cfg.host, "host");
        assert_eq!(cfg.user, "u");
    }

    #[test]
    fn test_from_url_default_port() {
        let cfg = PgConfig::from_url("postgres://u:p@myhost/mydb").unwrap();
        // No explicit port in the URL: the PostgreSQL default applies.
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.host, "myhost");
    }

    #[test]
    fn test_from_url_no_password() {
        let cfg = PgConfig::from_url("postgres://alice@host:5432/db").unwrap();
        // Userinfo without ':' means an empty password.
        assert_eq!(cfg.user, "alice");
        assert_eq!(cfg.password, "");
    }

    #[test]
    fn test_from_url_custom_port() {
        let cfg = PgConfig::from_url("postgres://u:p@host:9000/db").unwrap();
        assert_eq!(cfg.port, 9000);
    }

    #[test]
    fn test_from_url_unix_socket_query_param() {
        let cfg = PgConfig::from_url("postgres://u:p@/db?host=/var/run/postgresql").unwrap();
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
        assert_eq!(cfg.database, "db");
    }

    #[test]
    fn test_from_url_unix_socket_percent_encoded() {
        let cfg = PgConfig::from_url("postgres://u:p@%2Fvar%2Frun%2Fpostgresql/db").unwrap();
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
        assert_eq!(cfg.database, "db");
    }

    #[test]
    fn test_from_url_invalid_scheme_errors() {
        let result = PgConfig::from_url("mysql://u:p@host/db");
        assert!(result.is_err(), "Non-postgres scheme must fail");
    }

    #[test]
    fn test_from_url_missing_at_symbol_errors() {
        let result = PgConfig::from_url("postgres://no-at-sign/db");
        assert!(result.is_err(), "URL without @ must fail");
    }

    #[test]
    fn test_from_url_missing_database_errors() {
        let result = PgConfig::from_url("postgres://u:p@host");
        assert!(result.is_err(), "URL without database must fail");
    }

    #[test]
    fn test_from_url_invalid_port_errors() {
        let result = PgConfig::from_url("postgres://u:p@host:notaport/db");
        assert!(result.is_err(), "Non-numeric port must fail");
    }

    #[test]
    fn test_from_url_empty_string_errors() {
        let result = PgConfig::from_url("");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_url_special_chars_in_password() {
        // A percent-encoded '@' in the password must not be mistaken for the userinfo/host
        // separator. This test does not pin whether `from_url` decodes the password or
        // rejects such URLs. If it accepts the URL, every other component must be intact.
        if let Ok(cfg) = PgConfig::from_url("postgres://user:p%40ss@host:5432/db") {
            assert_eq!(cfg.user, "user");
            assert_eq!(cfg.host, "host");
            assert_eq!(cfg.port, 5432);
            assert_eq!(cfg.database, "db");
            assert!(
                cfg.password == "p@ss" || cfg.password == "p%40ss",
                "password must be the decoded or raw form, got {:?}",
                cfg.password
            );
        }
    }

    #[test]
    fn test_notification_fields() {
        let n = Notification {
            process_id: 12345,
            channel: "my_channel".to_string(),
            payload: "hello world".to_string(),
        };
        assert_eq!(n.process_id, 12345);
        assert_eq!(n.channel, "my_channel");
        assert_eq!(n.payload, "hello world");
    }

    #[test]
    fn test_notification_clone() {
        let n = Notification {
            process_id: 42,
            channel: "ch".to_string(),
            payload: "pay".to_string(),
        };
        let n2 = n.clone();
        assert_eq!(n2.process_id, n.process_id);
        assert_eq!(n2.channel, n.channel);
        assert_eq!(n2.payload, n.payload);
    }

    #[test]
    fn test_notification_debug() {
        let n = Notification {
            process_id: 1,
            channel: "c".to_string(),
            payload: "p".to_string(),
        };
        let s = format!("{:?}", n);
        assert!(
            s.contains("process_id"),
            "Debug must include process_id: {}",
            s
        );
    }
}