1pub mod branch_sql;
20mod cancel;
21mod connection;
22mod copy;
23mod cursor;
24pub mod explain;
25#[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
26pub mod gss;
27mod io;
28pub mod io_backend;
29pub mod notification;
30mod pipeline;
31mod pool;
32mod prepared;
33mod query;
34pub mod rls;
35mod row;
36mod stream;
37mod transaction;
38
39pub use cancel::CancelToken;
40pub use connection::PgConnection;
41pub use connection::TlsConfig;
42pub(crate) use connection::{CANCEL_REQUEST_CODE, parse_affected_rows};
43pub use io_backend::{IoBackend, backend_name, detect as detect_io_backend};
44pub use notification::Notification;
45pub use pool::{PgPool, PoolConfig, PoolStats, PooledConnection};
46pub use prepared::PreparedStatement;
47pub use rls::RlsContext;
48pub use row::QailRow;
49
50use qail_core::ast::Qail;
51use std::collections::HashMap;
52use std::sync::Arc;
53
54#[derive(Debug, Clone)]
59pub struct ColumnInfo {
60 pub name_to_index: HashMap<String, usize>,
62 pub oids: Vec<u32>,
64 pub formats: Vec<i16>,
66}
67
68impl ColumnInfo {
69 pub fn from_fields(fields: &[crate::protocol::FieldDescription]) -> Self {
72 let mut name_to_index = HashMap::with_capacity(fields.len());
73 let mut oids = Vec::with_capacity(fields.len());
74 let mut formats = Vec::with_capacity(fields.len());
75
76 for (i, field) in fields.iter().enumerate() {
77 name_to_index.insert(field.name.clone(), i);
78 oids.push(field.type_oid);
79 formats.push(field.format);
80 }
81
82 Self {
83 name_to_index,
84 oids,
85 formats,
86 }
87 }
88}
89
/// One result row as produced by the driver's fetch methods.
pub struct PgRow {
    /// Raw column values in column order, taken directly from the `DataRow`
    /// message (`None` entries presumably represent SQL NULL — confirm in the
    /// protocol decoder).
    pub columns: Vec<Option<Vec<u8>>>,
    /// Column metadata for the result set when the fetch path had it; the
    /// fast, pipelined and non-cached fallback paths leave this `None`.
    pub column_info: Option<Arc<ColumnInfo>>,
}
97
/// Error type returned by every fallible operation in this driver.
#[derive(Debug)]
pub enum PgError {
    /// Connection-level failure; also used for malformed connection URLs and
    /// a missing `DATABASE_URL`.
    Connection(String),
    /// Protocol-level error; also used to reject SQL containing a NUL byte.
    Protocol(String),
    /// Authentication error.
    Auth(String),
    /// Query error carrying only a message (e.g. invalid `copy_bulk` input).
    Query(String),
    /// Error reported by the server in an `ErrorResponse`, with its fields.
    QueryServer(PgServerError),
    /// A single-row fetch (`fetch_one`) produced no rows.
    NoRows,
    /// Underlying I/O error; exposed through `Error::source`.
    Io(std::io::Error),
    /// Failure while encoding a command or its parameters.
    Encode(String),
    /// An operation exceeded its deadline; the text says what timed out.
    Timeout(String),
    /// The connection pool has no connection available.
    PoolExhausted {
        /// Configured maximum number of connections.
        max: usize,
    },
    /// The connection pool has been closed.
    PoolClosed,
}
127
/// Structured fields of an `ErrorResponse` sent by the server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PgServerError {
    /// Severity as reported by the server.
    pub severity: String,
    /// SQLSTATE code (returned by `PgError::sqlstate`).
    pub code: String,
    /// Primary human-readable message.
    pub message: String,
    /// Optional detail text.
    pub detail: Option<String>,
    /// Optional hint text.
    pub hint: Option<String>,
}
142
143impl From<crate::protocol::ErrorFields> for PgServerError {
144 fn from(value: crate::protocol::ErrorFields) -> Self {
145 Self {
146 severity: value.severity,
147 code: value.code,
148 message: value.message,
149 detail: value.detail,
150 hint: value.hint,
151 }
152 }
153}
154
155impl std::fmt::Display for PgError {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 match self {
158 PgError::Connection(e) => write!(f, "Connection error: {}", e),
159 PgError::Protocol(e) => write!(f, "Protocol error: {}", e),
160 PgError::Auth(e) => write!(f, "Auth error: {}", e),
161 PgError::Query(e) => write!(f, "Query error: {}", e),
162 PgError::QueryServer(e) => write!(f, "Query error [{}]: {}", e.code, e.message),
163 PgError::NoRows => write!(f, "No rows returned"),
164 PgError::Io(e) => write!(f, "I/O error: {}", e),
165 PgError::Encode(e) => write!(f, "Encode error: {}", e),
166 PgError::Timeout(ctx) => write!(f, "Timeout: {}", ctx),
167 PgError::PoolExhausted { max } => write!(f, "Pool exhausted ({} max connections)", max),
168 PgError::PoolClosed => write!(f, "Connection pool is closed"),
169 }
170 }
171}
172
173impl std::error::Error for PgError {
174 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
175 match self {
176 PgError::Io(e) => Some(e),
177 _ => None,
178 }
179 }
180}
181
182impl From<std::io::Error> for PgError {
183 fn from(e: std::io::Error) -> Self {
184 PgError::Io(e)
185 }
186}
187
188impl From<crate::protocol::EncodeError> for PgError {
189 fn from(e: crate::protocol::EncodeError) -> Self {
190 PgError::Encode(e.to_string())
191 }
192}
193
194impl PgError {
195 pub fn server_error(&self) -> Option<&PgServerError> {
197 match self {
198 PgError::QueryServer(err) => Some(err),
199 _ => None,
200 }
201 }
202
203 pub fn sqlstate(&self) -> Option<&str> {
205 self.server_error().map(|e| e.code.as_str())
206 }
207
208 pub fn is_prepared_statement_retryable(&self) -> bool {
211 let Some(err) = self.server_error() else {
212 return false;
213 };
214
215 let code = err.code.as_str();
216 let message = err.message.to_ascii_lowercase();
217
218 if code.eq_ignore_ascii_case("26000")
220 && message.contains("prepared statement")
221 && message.contains("does not exist")
222 {
223 return true;
224 }
225
226 if code.eq_ignore_ascii_case("0A000") && message.contains("cached plan must be replanned") {
228 return true;
229 }
230
231 message.contains("cached plan must be replanned")
233 }
234
235 pub fn is_transient_server_error(&self) -> bool {
241 match self {
243 PgError::Timeout(_) => return true,
244 PgError::Io(io) => {
245 return matches!(
246 io.kind(),
247 std::io::ErrorKind::TimedOut
248 | std::io::ErrorKind::ConnectionRefused
249 | std::io::ErrorKind::ConnectionReset
250 | std::io::ErrorKind::BrokenPipe
251 | std::io::ErrorKind::Interrupted
252 );
253 }
254 PgError::Connection(_) => return true,
255 _ => {}
256 }
257
258 if self.is_prepared_statement_retryable() {
260 return true;
261 }
262
263 let Some(code) = self.sqlstate() else {
264 return false;
265 };
266
267 matches!(
268 code,
269 "40001"
271 | "40P01"
273 | "57P03"
275 | "57P01"
277 | "57P02"
278 ) || code.starts_with("08") }
280}
281
/// Result alias used by every fallible driver operation.
pub type PgResult<T> = Result<T, PgError>;

/// Fully materialised query result with values rendered as text
/// (produced by `PgDriver::query_ast`).
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Column names, in result order.
    pub columns: Vec<String>,
    /// Rows of values converted with lossy UTF-8 decoding; `None` where the
    /// raw `DataRow` value was absent.
    pub rows: Vec<Vec<Option<String>>>,
}
293
294#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
299pub enum ResultFormat {
300 #[default]
302 Text,
303 Binary,
305}
306
307impl ResultFormat {
308 #[inline]
309 pub(crate) fn as_wire_code(self) -> i16 {
310 match self {
311 ResultFormat::Text => crate::protocol::PgEncoder::FORMAT_TEXT,
312 ResultFormat::Binary => crate::protocol::PgEncoder::FORMAT_BINARY,
313 }
314 }
315}
316
/// Policy for SCRAM channel binding during authentication.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ScramChannelBindingMode {
    /// Channel binding disabled.
    Disable,
    /// Channel binding preferred (the default).
    #[default]
    Prefer,
    /// Channel binding required.
    Require,
}

impl ScramChannelBindingMode {
    /// Parses a mode name, ignoring surrounding whitespace and ASCII case.
    ///
    /// Accepts:
    /// * `disable`, `off`, `false`, `no` for `Disable`;
    /// * `prefer`, `on`, `true`, `yes` for `Prefer`;
    /// * `require`, `required` for `Require`.
    ///
    /// Any other input yields `None`.
    pub fn parse(value: &str) -> Option<Self> {
        let normalized = value.trim().to_ascii_lowercase();
        let mode = match normalized.as_str() {
            "require" | "required" => Self::Require,
            "prefer" | "on" | "true" | "yes" => Self::Prefer,
            "disable" | "off" | "false" | "no" => Self::Disable,
            _ => return None,
        };
        Some(mode)
    }
}
340
/// Non-password ("enterprise") authentication mechanisms.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnterpriseAuthMechanism {
    /// Kerberos V5.
    KerberosV5,
    /// GSSAPI.
    GssApi,
    /// SSPI.
    Sspi,
}

/// Plain-function callback that produces the next client token for an
/// enterprise auth exchange; failures are reported as a message string.
/// NOTE(review): the `Option<&[u8]>` argument is presumably the server's most
/// recent token (`None` on the first step) — confirm in the connection code.
pub type GssTokenProvider = fn(EnterpriseAuthMechanism, Option<&[u8]>) -> Result<Vec<u8>, String>;

/// Input handed to a [`GssTokenProviderEx`] callback.
#[derive(Debug, Clone, Copy)]
pub struct GssTokenRequest<'a> {
    /// Authentication session identifier — presumably stable across the round
    /// trips of one handshake; TODO confirm.
    pub session_id: u64,
    /// Mechanism being negotiated.
    pub mechanism: EnterpriseAuthMechanism,
    /// Token received from the server, if any.
    pub server_token: Option<&'a [u8]>,
}

/// Shareable, thread-safe closure form of [`GssTokenProvider`] that also
/// receives a session identifier via [`GssTokenRequest`].
pub type GssTokenProviderEx =
    Arc<dyn for<'a> Fn(GssTokenRequest<'a>) -> Result<Vec<u8>, String> + Send + Sync>;
378
379#[derive(Debug, Clone, Copy, PartialEq, Eq)]
383pub struct AuthSettings {
384 pub allow_cleartext_password: bool,
386 pub allow_md5_password: bool,
388 pub allow_scram_sha_256: bool,
390 pub allow_kerberos_v5: bool,
392 pub allow_gssapi: bool,
394 pub allow_sspi: bool,
396 pub channel_binding: ScramChannelBindingMode,
398}
399
400impl Default for AuthSettings {
401 fn default() -> Self {
402 Self {
403 allow_cleartext_password: true,
404 allow_md5_password: true,
405 allow_scram_sha_256: true,
406 allow_kerberos_v5: false,
407 allow_gssapi: false,
408 allow_sspi: false,
409 channel_binding: ScramChannelBindingMode::Prefer,
410 }
411 }
412}
413
414impl AuthSettings {
415 pub fn scram_only() -> Self {
417 Self {
418 allow_cleartext_password: false,
419 allow_md5_password: false,
420 allow_scram_sha_256: true,
421 allow_kerberos_v5: false,
422 allow_gssapi: false,
423 allow_sspi: false,
424 channel_binding: ScramChannelBindingMode::Prefer,
425 }
426 }
427
428 pub fn gssapi_only() -> Self {
430 Self {
431 allow_cleartext_password: false,
432 allow_md5_password: false,
433 allow_scram_sha_256: false,
434 allow_kerberos_v5: true,
435 allow_gssapi: true,
436 allow_sspi: true,
437 channel_binding: ScramChannelBindingMode::Prefer,
438 }
439 }
440
441 pub(crate) fn has_any_password_method(self) -> bool {
442 self.allow_cleartext_password || self.allow_md5_password || self.allow_scram_sha_256
443 }
444}
445
/// TLS policy for the connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TlsMode {
    /// No TLS (the default).
    #[default]
    Disable,
    /// Use TLS when possible.
    Prefer,
    /// TLS is mandatory.
    Require,
}

impl TlsMode {
    /// Maps a libpq-style `sslmode` value onto the supported modes, ignoring
    /// surrounding whitespace and ASCII case.
    ///
    /// Mapping:
    /// * `allow` is treated as `prefer`.
    /// * `verify-ca` and `verify-full` both map to `Require`. Whether
    ///   certificates are actually verified is decided elsewhere.
    /// * Unknown values yield `None`.
    pub fn parse_sslmode(value: &str) -> Option<Self> {
        let normalized = value.trim().to_ascii_lowercase();
        Some(match normalized.as_str() {
            "require" | "verify-ca" | "verify-full" => Self::Require,
            "allow" | "prefer" => Self::Prefer,
            "disable" => Self::Disable,
            _ => return None,
        })
    }
}
469
/// GSSAPI transport-encryption policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GssEncMode {
    /// No GSS encryption (the default).
    #[default]
    Disable,
    /// Use GSS encryption when possible.
    Prefer,
    /// GSS encryption is mandatory.
    Require,
}

impl GssEncMode {
    /// Parses a libpq-style `gssencmode` value (`disable`, `prefer`,
    /// `require`), ignoring surrounding whitespace and ASCII case. Unknown
    /// values yield `None`.
    pub fn parse_gssencmode(value: &str) -> Option<Self> {
        let normalized = value.trim().to_ascii_lowercase();
        [
            ("disable", Self::Disable),
            ("prefer", Self::Prefer),
            ("require", Self::Require),
        ]
        .into_iter()
        .find_map(|(name, mode)| (name == normalized).then_some(mode))
    }
}
498
/// Connection options beyond host, port, user, database and password.
///
/// `Debug` is implemented by hand so that certificate bytes, the TLS
/// configuration and token callbacks are never printed.
#[derive(Clone, Default)]
pub struct ConnectOptions {
    /// TLS policy.
    pub tls_mode: TlsMode,
    /// GSS encryption policy.
    pub gss_enc_mode: GssEncMode,
    /// PEM-encoded CA certificate bytes for TLS, if provided.
    pub tls_ca_cert_pem: Option<Vec<u8>>,
    /// Client TLS (mTLS) configuration, if any.
    pub mtls: Option<TlsConfig>,
    /// Plain-function token callback for enterprise authentication.
    pub gss_token_provider: Option<GssTokenProvider>,
    /// Closure-based token callback that also receives a session id.
    pub gss_token_provider_ex: Option<GssTokenProviderEx>,
    /// Accepted authentication methods.
    pub auth: AuthSettings,
}
517
518impl std::fmt::Debug for ConnectOptions {
519 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
520 f.debug_struct("ConnectOptions")
521 .field("tls_mode", &self.tls_mode)
522 .field("gss_enc_mode", &self.gss_enc_mode)
523 .field(
524 "tls_ca_cert_pem",
525 &self.tls_ca_cert_pem.as_ref().map(std::vec::Vec::len),
526 )
527 .field("mtls", &self.mtls.as_ref().map(|_| "<configured>"))
528 .field(
529 "gss_token_provider",
530 &self.gss_token_provider.as_ref().map(|_| "<configured>"),
531 )
532 .field(
533 "gss_token_provider_ex",
534 &self.gss_token_provider_ex.as_ref().map(|_| "<configured>"),
535 )
536 .field("auth", &self.auth)
537 .finish()
538 }
539}
540
/// High-level driver that owns a single [`PgConnection`] and remembers the
/// row-level-security context last applied to it.
pub struct PgDriver {
    /// The underlying connection.
    // NOTE(review): `impl PgDriver` reads this field throughout, so the
    // `dead_code` allowance looks stale — confirm before removing.
    #[allow(dead_code)]
    connection: PgConnection,
    /// Context stored by `set_rls_context`; reset to `None` by `clear_rls_context`.
    rls_context: Option<RlsContext>,
}
548
549impl PgDriver {
550 pub fn new(connection: PgConnection) -> Self {
552 Self {
553 connection,
554 rls_context: None,
555 }
556 }
557
558 pub fn builder() -> PgDriverBuilder {
571 PgDriverBuilder::new()
572 }
573
574 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
583 let connection = PgConnection::connect(host, port, user, database).await?;
584 Ok(Self::new(connection))
585 }
586
587 pub async fn connect_with_password(
590 host: &str,
591 port: u16,
592 user: &str,
593 database: &str,
594 password: &str,
595 ) -> PgResult<Self> {
596 let connection =
597 PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
598 Ok(Self::new(connection))
599 }
600
601 pub async fn connect_with_options(
603 host: &str,
604 port: u16,
605 user: &str,
606 database: &str,
607 password: Option<&str>,
608 options: ConnectOptions,
609 ) -> PgResult<Self> {
610 let connection =
611 PgConnection::connect_with_options(host, port, user, database, password, options)
612 .await?;
613 Ok(Self::new(connection))
614 }
615
616 pub async fn connect_env() -> PgResult<Self> {
627 let url = std::env::var("DATABASE_URL").map_err(|_| {
628 PgError::Connection("DATABASE_URL environment variable not set".to_string())
629 })?;
630 Self::connect_url(&url).await
631 }
632
633 pub async fn connect_url(url: &str) -> PgResult<Self> {
646 let (host, port, user, database, password) = Self::parse_database_url(url)?;
647
648 let mut pool_cfg = pool::PoolConfig::new(&host, port, &user, &database);
650 if let Some(pw) = &password {
651 pool_cfg = pool_cfg.password(pw);
652 }
653 if let Some(query) = url.split('?').nth(1) {
654 pool::apply_url_query_params(&mut pool_cfg, query, &host)?;
655 }
656
657 let opts = ConnectOptions {
658 tls_mode: pool_cfg.tls_mode,
659 gss_enc_mode: pool_cfg.gss_enc_mode,
660 tls_ca_cert_pem: pool_cfg.tls_ca_cert_pem,
661 mtls: pool_cfg.mtls,
662 gss_token_provider: pool_cfg.gss_token_provider,
663 gss_token_provider_ex: pool_cfg.gss_token_provider_ex,
664 auth: pool_cfg.auth_settings,
665 };
666
667 Self::connect_with_options(&host, port, &user, &database, password.as_deref(), opts).await
668 }
669
670 fn parse_database_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
677 let after_scheme = url.split("://").nth(1).ok_or_else(|| {
679 PgError::Connection("Invalid DATABASE_URL: missing scheme".to_string())
680 })?;
681
682 let (auth_part, host_db_part) = if let Some(at_pos) = after_scheme.rfind('@') {
684 (Some(&after_scheme[..at_pos]), &after_scheme[at_pos + 1..])
685 } else {
686 (None, after_scheme)
687 };
688
689 let (user, password) = if let Some(auth) = auth_part {
691 let parts: Vec<&str> = auth.splitn(2, ':').collect();
692 if parts.len() == 2 {
693 (
695 Self::percent_decode(parts[0]),
696 Some(Self::percent_decode(parts[1])),
697 )
698 } else {
699 (Self::percent_decode(parts[0]), None)
700 }
701 } else {
702 return Err(PgError::Connection(
703 "Invalid DATABASE_URL: missing user".to_string(),
704 ));
705 };
706
707 let (host_port, database) = if let Some(slash_pos) = host_db_part.find('/') {
709 let raw_db = &host_db_part[slash_pos + 1..];
710 let db = raw_db.split('?').next().unwrap_or(raw_db).to_string();
712 (&host_db_part[..slash_pos], db)
713 } else {
714 return Err(PgError::Connection(
715 "Invalid DATABASE_URL: missing database name".to_string(),
716 ));
717 };
718
719 let (host, port) = if let Some(colon_pos) = host_port.rfind(':') {
721 let port_str = &host_port[colon_pos + 1..];
722 let port = port_str
723 .parse::<u16>()
724 .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
725 (host_port[..colon_pos].to_string(), port)
726 } else {
727 (host_port.to_string(), 5432) };
729
730 Ok((host, port, user, database, password))
731 }
732
/// Decodes `%XX` escapes (RFC 3986 percent-encoding) in one URL component.
///
/// Escapes are decoded to raw bytes and the result is read as UTF-8, so a
/// multi-byte escape such as `%C3%A9` becomes `é`. Byte sequences that are
/// not valid UTF-8 are replaced with U+FFFD.
///
/// A `%` not followed by two ASCII hex digits is kept literally; signs such as
/// `%+1` are not hex digits. `+` is left unchanged, because it only means a
/// space in form encoding.
fn percent_decode(s: &str) -> String {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn hex_value(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            b'A'..=b'F' => Some(digit - b'A' + 10),
            _ => None,
        }
    }

    let input = s.as_bytes();
    let mut decoded: Vec<u8> = Vec::with_capacity(input.len());
    let mut i = 0;
    while i < input.len() {
        if input[i] == b'%' {
            let hi = input.get(i + 1).copied().and_then(hex_value);
            let lo = input.get(i + 2).copied().and_then(hex_value);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                decoded.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        // Literal byte (including a stray '%'); bytes of non-ASCII input chars
        // are copied unchanged, so they stay valid UTF-8.
        decoded.push(input[i]);
        i += 1;
    }

    String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}
763
764 pub async fn connect_with_timeout(
775 host: &str,
776 port: u16,
777 user: &str,
778 database: &str,
779 password: &str,
780 timeout: std::time::Duration,
781 ) -> PgResult<Self> {
782 tokio::time::timeout(
783 timeout,
784 Self::connect_with_password(host, port, user, database, password),
785 )
786 .await
787 .map_err(|_| PgError::Timeout(format!("connection after {:?}", timeout)))?
788 }
789 pub fn clear_cache(&mut self) {
793 self.connection.clear_prepared_statement_state();
794 }
795
796 pub fn cache_stats(&self) -> (usize, usize) {
799 (
800 self.connection.stmt_cache.len(),
801 self.connection.stmt_cache.cap().get(),
802 )
803 }
804
805 pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
811 self.fetch_all_with_format(cmd, ResultFormat::Text).await
812 }
813
814 pub async fn fetch_all_with_format(
820 &mut self,
821 cmd: &Qail,
822 result_format: ResultFormat,
823 ) -> PgResult<Vec<PgRow>> {
824 self.fetch_all_cached_with_format(cmd, result_format).await
826 }
827
828 pub async fn fetch_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
836 let rows = self.fetch_all(cmd).await?;
837 Ok(rows.iter().map(T::from_row).collect())
838 }
839
840 pub async fn fetch_one_typed<T: row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Option<T>> {
843 let rows = self.fetch_all(cmd).await?;
844 Ok(rows.first().map(T::from_row))
845 }
846
847 pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
853 self.fetch_all_uncached_with_format(cmd, ResultFormat::Text)
854 .await
855 }
856
857 pub async fn fetch_all_uncached_with_format(
859 &mut self,
860 cmd: &Qail,
861 result_format: ResultFormat,
862 ) -> PgResult<Vec<PgRow>> {
863 use crate::protocol::AstEncoder;
864
865 AstEncoder::encode_cmd_reuse_into_with_result_format(
866 cmd,
867 &mut self.connection.sql_buf,
868 &mut self.connection.params_buf,
869 &mut self.connection.write_buf,
870 result_format.as_wire_code(),
871 )
872 .map_err(|e| PgError::Encode(e.to_string()))?;
873
874 self.connection.flush_write_buf().await?;
875
876 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
877 let mut column_info: Option<Arc<ColumnInfo>> = None;
878
879 let mut error: Option<PgError> = None;
880
881 loop {
882 let msg = self.connection.recv().await?;
883 match msg {
884 crate::protocol::BackendMessage::ParseComplete
885 | crate::protocol::BackendMessage::BindComplete => {}
886 crate::protocol::BackendMessage::RowDescription(fields) => {
887 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
888 }
889 crate::protocol::BackendMessage::DataRow(data) => {
890 if error.is_none() {
891 rows.push(PgRow {
892 columns: data,
893 column_info: column_info.clone(),
894 });
895 }
896 }
897 crate::protocol::BackendMessage::CommandComplete(_) => {}
898 crate::protocol::BackendMessage::ReadyForQuery(_) => {
899 if let Some(err) = error {
900 return Err(err);
901 }
902 return Ok(rows);
903 }
904 crate::protocol::BackendMessage::ErrorResponse(err) => {
905 if error.is_none() {
906 error = Some(PgError::QueryServer(err.into()));
907 }
908 }
909 _ => {}
910 }
911 }
912 }
913
914 pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
918 self.fetch_all_fast_with_format(cmd, ResultFormat::Text)
919 .await
920 }
921
922 pub async fn fetch_all_fast_with_format(
924 &mut self,
925 cmd: &Qail,
926 result_format: ResultFormat,
927 ) -> PgResult<Vec<PgRow>> {
928 use crate::protocol::AstEncoder;
929
930 AstEncoder::encode_cmd_reuse_into_with_result_format(
931 cmd,
932 &mut self.connection.sql_buf,
933 &mut self.connection.params_buf,
934 &mut self.connection.write_buf,
935 result_format.as_wire_code(),
936 )
937 .map_err(|e| PgError::Encode(e.to_string()))?;
938
939 self.connection.flush_write_buf().await?;
940
941 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
943 let mut error: Option<PgError> = None;
944
945 loop {
946 let res = self.connection.recv_with_data_fast().await;
947 match res {
948 Ok((msg_type, data)) => {
949 match msg_type {
950 b'D' => {
951 if error.is_none()
953 && let Some(columns) = data
954 {
955 rows.push(PgRow {
956 columns,
957 column_info: None, });
959 }
960 }
961 b'Z' => {
962 if let Some(err) = error {
964 return Err(err);
965 }
966 return Ok(rows);
967 }
968 _ => {} }
970 }
971 Err(e) => {
972 if error.is_none() {
981 error = Some(e);
982 }
983 }
988 }
989 }
990 }
991
992 pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
994 let rows = self.fetch_all(cmd).await?;
995 rows.into_iter().next().ok_or(PgError::NoRows)
996 }
997
998 pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
1007 self.fetch_all_cached_with_format(cmd, ResultFormat::Text)
1008 .await
1009 }
1010
1011 pub async fn fetch_all_cached_with_format(
1013 &mut self,
1014 cmd: &Qail,
1015 result_format: ResultFormat,
1016 ) -> PgResult<Vec<PgRow>> {
1017 let mut retried = false;
1018 loop {
1019 match self
1020 .fetch_all_cached_with_format_once(cmd, result_format)
1021 .await
1022 {
1023 Ok(rows) => return Ok(rows),
1024 Err(err) if !retried && err.is_prepared_statement_retryable() => {
1025 retried = true;
1026 self.connection.clear_prepared_statement_state();
1027 }
1028 Err(err) => return Err(err),
1029 }
1030 }
1031 }
1032
1033 async fn fetch_all_cached_with_format_once(
1034 &mut self,
1035 cmd: &Qail,
1036 result_format: ResultFormat,
1037 ) -> PgResult<Vec<PgRow>> {
1038 use crate::protocol::AstEncoder;
1039 use std::collections::hash_map::DefaultHasher;
1040 use std::hash::{Hash, Hasher};
1041
1042 self.connection.sql_buf.clear();
1043 self.connection.params_buf.clear();
1044
1045 match cmd.action {
1047 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
1048 crate::protocol::ast_encoder::dml::encode_select(
1049 cmd,
1050 &mut self.connection.sql_buf,
1051 &mut self.connection.params_buf,
1052 )?;
1053 }
1054 qail_core::ast::Action::Add => {
1055 crate::protocol::ast_encoder::dml::encode_insert(
1056 cmd,
1057 &mut self.connection.sql_buf,
1058 &mut self.connection.params_buf,
1059 )?;
1060 }
1061 qail_core::ast::Action::Set => {
1062 crate::protocol::ast_encoder::dml::encode_update(
1063 cmd,
1064 &mut self.connection.sql_buf,
1065 &mut self.connection.params_buf,
1066 )?;
1067 }
1068 qail_core::ast::Action::Del => {
1069 crate::protocol::ast_encoder::dml::encode_delete(
1070 cmd,
1071 &mut self.connection.sql_buf,
1072 &mut self.connection.params_buf,
1073 )?;
1074 }
1075 _ => {
1076 let (sql, params) =
1078 AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
1079 let raw_rows = self
1080 .connection
1081 .query_cached_with_result_format(&sql, ¶ms, result_format.as_wire_code())
1082 .await?;
1083 return Ok(raw_rows
1084 .into_iter()
1085 .map(|data| PgRow {
1086 columns: data,
1087 column_info: None,
1088 })
1089 .collect());
1090 }
1091 }
1092
1093 let mut hasher = DefaultHasher::new();
1094 self.connection.sql_buf.hash(&mut hasher);
1095 let sql_hash = hasher.finish();
1096
1097 let is_cache_miss = !self.connection.stmt_cache.contains(&sql_hash);
1098
1099 self.connection.write_buf.clear();
1101
1102 let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
1103 name
1104 } else {
1105 let name = format!("qail_{:x}", sql_hash);
1106
1107 self.connection.evict_prepared_if_full();
1109
1110 let sql_str = std::str::from_utf8(&self.connection.sql_buf).unwrap_or("");
1111
1112 use crate::protocol::PgEncoder;
1114 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
1115 let describe_msg = PgEncoder::encode_describe(false, &name);
1116 self.connection.write_buf.extend_from_slice(&parse_msg);
1117 self.connection.write_buf.extend_from_slice(&describe_msg);
1118
1119 self.connection.stmt_cache.put(sql_hash, name.clone());
1120 self.connection
1121 .prepared_statements
1122 .insert(name.clone(), sql_str.to_string());
1123
1124 name
1125 };
1126
1127 use crate::protocol::PgEncoder;
1129 PgEncoder::encode_bind_to_with_result_format(
1130 &mut self.connection.write_buf,
1131 &stmt_name,
1132 &self.connection.params_buf,
1133 result_format.as_wire_code(),
1134 )
1135 .map_err(|e| PgError::Encode(e.to_string()))?;
1136 PgEncoder::encode_execute_to(&mut self.connection.write_buf);
1137 PgEncoder::encode_sync_to(&mut self.connection.write_buf);
1138
1139 self.connection.flush_write_buf().await?;
1141
1142 let cached_column_info = self.connection.column_info_cache.get(&sql_hash).cloned();
1144
1145 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
1146 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
1147 let mut error: Option<PgError> = None;
1148
1149 loop {
1150 let msg = self.connection.recv().await?;
1151 match msg {
1152 crate::protocol::BackendMessage::ParseComplete
1153 | crate::protocol::BackendMessage::BindComplete => {}
1154 crate::protocol::BackendMessage::ParameterDescription(_) => {
1155 }
1157 crate::protocol::BackendMessage::RowDescription(fields) => {
1158 let info = Arc::new(ColumnInfo::from_fields(&fields));
1160 if is_cache_miss {
1161 self.connection
1162 .column_info_cache
1163 .insert(sql_hash, info.clone());
1164 }
1165 column_info = Some(info);
1166 }
1167 crate::protocol::BackendMessage::DataRow(data) => {
1168 if error.is_none() {
1169 rows.push(PgRow {
1170 columns: data,
1171 column_info: column_info.clone(),
1172 });
1173 }
1174 }
1175 crate::protocol::BackendMessage::CommandComplete(_) => {}
1176 crate::protocol::BackendMessage::NoData => {
1177 }
1179 crate::protocol::BackendMessage::ReadyForQuery(_) => {
1180 if let Some(err) = error {
1181 return Err(err);
1182 }
1183 return Ok(rows);
1184 }
1185 crate::protocol::BackendMessage::ErrorResponse(err) => {
1186 if error.is_none() {
1187 let query_err = PgError::QueryServer(err.into());
1188 if query_err.is_prepared_statement_retryable() {
1189 self.connection.clear_prepared_statement_state();
1190 }
1191 error = Some(query_err);
1192 }
1193 }
1194 _ => {}
1195 }
1196 }
1197 }
1198
1199 pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
1201 use crate::protocol::AstEncoder;
1202
1203 let wire_bytes = AstEncoder::encode_cmd_reuse(
1204 cmd,
1205 &mut self.connection.sql_buf,
1206 &mut self.connection.params_buf,
1207 )
1208 .map_err(|e| PgError::Encode(e.to_string()))?;
1209
1210 self.connection.send_bytes(&wire_bytes).await?;
1211
1212 let mut affected = 0u64;
1213 let mut error: Option<PgError> = None;
1214
1215 loop {
1216 let msg = self.connection.recv().await?;
1217 match msg {
1218 crate::protocol::BackendMessage::ParseComplete
1219 | crate::protocol::BackendMessage::BindComplete => {}
1220 crate::protocol::BackendMessage::RowDescription(_) => {}
1221 crate::protocol::BackendMessage::DataRow(_) => {}
1222 crate::protocol::BackendMessage::CommandComplete(tag) => {
1223 if error.is_none()
1224 && let Some(n) = tag.split_whitespace().last()
1225 {
1226 affected = n.parse().unwrap_or(0);
1227 }
1228 }
1229 crate::protocol::BackendMessage::ReadyForQuery(_) => {
1230 if let Some(err) = error {
1231 return Err(err);
1232 }
1233 return Ok(affected);
1234 }
1235 crate::protocol::BackendMessage::ErrorResponse(err) => {
1236 if error.is_none() {
1237 error = Some(PgError::QueryServer(err.into()));
1238 }
1239 }
1240 _ => {}
1241 }
1242 }
1243 }
1244
1245 pub async fn query_ast(&mut self, cmd: &Qail) -> PgResult<QueryResult> {
1249 self.query_ast_with_format(cmd, ResultFormat::Text).await
1250 }
1251
1252 pub async fn query_ast_with_format(
1254 &mut self,
1255 cmd: &Qail,
1256 result_format: ResultFormat,
1257 ) -> PgResult<QueryResult> {
1258 use crate::protocol::AstEncoder;
1259
1260 let wire_bytes = AstEncoder::encode_cmd_reuse_with_result_format(
1261 cmd,
1262 &mut self.connection.sql_buf,
1263 &mut self.connection.params_buf,
1264 result_format.as_wire_code(),
1265 )
1266 .map_err(|e| PgError::Encode(e.to_string()))?;
1267
1268 self.connection.send_bytes(&wire_bytes).await?;
1269
1270 let mut columns: Vec<String> = Vec::new();
1271 let mut rows: Vec<Vec<Option<String>>> = Vec::new();
1272 let mut error: Option<PgError> = None;
1273
1274 loop {
1275 let msg = self.connection.recv().await?;
1276 match msg {
1277 crate::protocol::BackendMessage::ParseComplete
1278 | crate::protocol::BackendMessage::BindComplete => {}
1279 crate::protocol::BackendMessage::RowDescription(fields) => {
1280 columns = fields.into_iter().map(|f| f.name).collect();
1281 }
1282 crate::protocol::BackendMessage::DataRow(data) => {
1283 if error.is_none() {
1284 let row: Vec<Option<String>> = data
1285 .into_iter()
1286 .map(|col| col.map(|bytes| String::from_utf8_lossy(&bytes).to_string()))
1287 .collect();
1288 rows.push(row);
1289 }
1290 }
1291 crate::protocol::BackendMessage::CommandComplete(_) => {}
1292 crate::protocol::BackendMessage::NoData => {}
1293 crate::protocol::BackendMessage::ReadyForQuery(_) => {
1294 if let Some(err) = error {
1295 return Err(err);
1296 }
1297 return Ok(QueryResult { columns, rows });
1298 }
1299 crate::protocol::BackendMessage::ErrorResponse(err) => {
1300 if error.is_none() {
1301 error = Some(PgError::QueryServer(err.into()));
1302 }
1303 }
1304 _ => {}
1305 }
1306 }
1307 }
1308
/// Starts a transaction (delegates to `PgConnection::begin_transaction`).
pub async fn begin(&mut self) -> PgResult<()> {
    self.connection.begin_transaction().await
}

/// Commits the current transaction (delegates to `PgConnection::commit`).
pub async fn commit(&mut self) -> PgResult<()> {
    self.connection.commit().await
}

/// Rolls back the current transaction (delegates to `PgConnection::rollback`).
pub async fn rollback(&mut self) -> PgResult<()> {
    self.connection.rollback().await
}

/// Creates a savepoint called `name` (delegates to `PgConnection::savepoint`).
// NOTE(review): quoting/validation of `name` happens (if at all) inside
// `PgConnection::savepoint` — confirm before passing untrusted names.
pub async fn savepoint(&mut self, name: &str) -> PgResult<()> {
    self.connection.savepoint(name).await
}

/// Rolls back to the savepoint `name` (delegates to `PgConnection::rollback_to`).
pub async fn rollback_to(&mut self, name: &str) -> PgResult<()> {
    self.connection.rollback_to(name).await
}

/// Releases the savepoint `name` (delegates to `PgConnection::release_savepoint`).
pub async fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
    self.connection.release_savepoint(name).await
}
1354
1355 pub async fn execute_batch(&mut self, cmds: &[Qail]) -> PgResult<Vec<u64>> {
1369 self.begin().await?;
1370 let mut results = Vec::with_capacity(cmds.len());
1371 for cmd in cmds {
1372 match self.execute(cmd).await {
1373 Ok(n) => results.push(n),
1374 Err(e) => {
1375 self.rollback().await?;
1376 return Err(e);
1377 }
1378 }
1379 }
1380 self.commit().await?;
1381 Ok(results)
1382 }
1383
1384 pub async fn set_statement_timeout(&mut self, ms: u32) -> PgResult<()> {
1392 self.execute_raw(&format!("SET statement_timeout = {}", ms))
1393 .await
1394 }
1395
1396 pub async fn reset_statement_timeout(&mut self) -> PgResult<()> {
1398 self.execute_raw("RESET statement_timeout").await
1399 }
1400
1401 pub async fn set_rls_context(&mut self, ctx: rls::RlsContext) -> PgResult<()> {
1419 let sql = rls::context_to_sql(&ctx);
1420 self.execute_raw(&sql).await?;
1421 self.rls_context = Some(ctx);
1422 Ok(())
1423 }
1424
1425 pub async fn clear_rls_context(&mut self) -> PgResult<()> {
1430 self.execute_raw(rls::reset_sql()).await?;
1431 self.rls_context = None;
1432 Ok(())
1433 }
1434
1435 pub fn rls_context(&self) -> Option<&rls::RlsContext> {
1437 self.rls_context.as_ref()
1438 }
1439
1440 pub async fn pipeline_batch(&mut self, cmds: &[Qail]) -> PgResult<usize> {
1452 self.connection.pipeline_ast_fast(cmds).await
1453 }
1454
1455 pub async fn pipeline_fetch(&mut self, cmds: &[Qail]) -> PgResult<Vec<Vec<PgRow>>> {
1457 let raw_results = self.connection.pipeline_ast(cmds).await?;
1458
1459 let results: Vec<Vec<PgRow>> = raw_results
1460 .into_iter()
1461 .map(|rows| {
1462 rows.into_iter()
1463 .map(|columns| PgRow {
1464 columns,
1465 column_info: None,
1466 })
1467 .collect()
1468 })
1469 .collect();
1470
1471 Ok(results)
1472 }
1473
1474 pub async fn prepare(&mut self, sql: &str) -> PgResult<PreparedStatement> {
1476 self.connection.prepare(sql).await
1477 }
1478
1479 pub async fn pipeline_prepared_fast(
1481 &mut self,
1482 stmt: &PreparedStatement,
1483 params_batch: &[Vec<Option<Vec<u8>>>],
1484 ) -> PgResult<usize> {
1485 self.connection
1486 .pipeline_prepared_fast(stmt, params_batch)
1487 .await
1488 }
1489
1490 pub async fn execute_raw(&mut self, sql: &str) -> PgResult<()> {
1497 if sql.as_bytes().contains(&0) {
1499 return Err(crate::PgError::Protocol(
1500 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
1501 ));
1502 }
1503 self.connection.execute_simple(sql).await
1504 }
1505
1506 pub async fn fetch_raw(&mut self, sql: &str) -> PgResult<Vec<PgRow>> {
1510 if sql.as_bytes().contains(&0) {
1511 return Err(crate::PgError::Protocol(
1512 "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
1513 ));
1514 }
1515
1516 use crate::protocol::PgEncoder;
1517 use tokio::io::AsyncWriteExt;
1518
1519 let msg = PgEncoder::encode_query_string(sql);
1521 self.connection.stream.write_all(&msg).await?;
1522
1523 let mut rows: Vec<PgRow> = Vec::new();
1524 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = None;
1525
1526 let mut error: Option<PgError> = None;
1527
1528 loop {
1529 let msg = self.connection.recv().await?;
1530 match msg {
1531 crate::protocol::BackendMessage::RowDescription(fields) => {
1532 column_info = Some(std::sync::Arc::new(ColumnInfo::from_fields(&fields)));
1533 }
1534 crate::protocol::BackendMessage::DataRow(data) => {
1535 if error.is_none() {
1536 rows.push(PgRow {
1537 columns: data,
1538 column_info: column_info.clone(),
1539 });
1540 }
1541 }
1542 crate::protocol::BackendMessage::CommandComplete(_) => {}
1543 crate::protocol::BackendMessage::ReadyForQuery(_) => {
1544 if let Some(err) = error {
1545 return Err(err);
1546 }
1547 return Ok(rows);
1548 }
1549 crate::protocol::BackendMessage::ErrorResponse(err) => {
1550 if error.is_none() {
1551 error = Some(PgError::QueryServer(err.into()));
1552 }
1553 }
1554 _ => {}
1555 }
1556 }
1557 }
1558
1559 pub async fn copy_bulk(
1575 &mut self,
1576 cmd: &Qail,
1577 rows: &[Vec<qail_core::ast::Value>],
1578 ) -> PgResult<u64> {
1579 use qail_core::ast::Action;
1580
1581 if cmd.action != Action::Add {
1582 return Err(PgError::Query(
1583 "copy_bulk requires Qail::Add action".to_string(),
1584 ));
1585 }
1586
1587 let table = &cmd.table;
1588
1589 let columns: Vec<String> = cmd
1590 .columns
1591 .iter()
1592 .filter_map(|expr| {
1593 use qail_core::ast::Expr;
1594 match expr {
1595 Expr::Named(name) => Some(name.clone()),
1596 Expr::Aliased { name, .. } => Some(name.clone()),
1597 Expr::Star => None, _ => None,
1599 }
1600 })
1601 .collect();
1602
1603 if columns.is_empty() {
1604 return Err(PgError::Query(
1605 "copy_bulk requires columns in Qail".to_string(),
1606 ));
1607 }
1608
1609 self.connection.copy_in_fast(table, &columns, rows).await
1611 }
1612
1613 pub async fn copy_bulk_bytes(&mut self, cmd: &Qail, data: &[u8]) -> PgResult<u64> {
1626 use qail_core::ast::Action;
1627
1628 if cmd.action != Action::Add {
1629 return Err(PgError::Query(
1630 "copy_bulk_bytes requires Qail::Add action".to_string(),
1631 ));
1632 }
1633
1634 let table = &cmd.table;
1635 let columns: Vec<String> = cmd
1636 .columns
1637 .iter()
1638 .filter_map(|expr| {
1639 use qail_core::ast::Expr;
1640 match expr {
1641 Expr::Named(name) => Some(name.clone()),
1642 Expr::Aliased { name, .. } => Some(name.clone()),
1643 _ => None,
1644 }
1645 })
1646 .collect();
1647
1648 if columns.is_empty() {
1649 return Err(PgError::Query(
1650 "copy_bulk_bytes requires columns in Qail".to_string(),
1651 ));
1652 }
1653
1654 self.connection.copy_in_raw(table, &columns, data).await
1656 }
1657
1658 pub async fn copy_export_table(
1666 &mut self,
1667 table: &str,
1668 columns: &[String],
1669 ) -> PgResult<Vec<u8>> {
1670 let cols = columns.join(", ");
1671 let sql = format!("COPY {} ({}) TO STDOUT", table, cols);
1672
1673 self.connection.copy_out_raw(&sql).await
1674 }
1675
1676 pub async fn stream_cmd(&mut self, cmd: &Qail, batch_size: usize) -> PgResult<Vec<Vec<PgRow>>> {
1690 use std::sync::atomic::{AtomicU64, Ordering};
1691 static CURSOR_ID: AtomicU64 = AtomicU64::new(0);
1692
1693 let cursor_name = format!("qail_cursor_{}", CURSOR_ID.fetch_add(1, Ordering::SeqCst));
1694
1695 use crate::protocol::AstEncoder;
1697 let mut sql_buf = bytes::BytesMut::with_capacity(256);
1698 let mut params: Vec<Option<Vec<u8>>> = Vec::new();
1699 AstEncoder::encode_select_sql(cmd, &mut sql_buf, &mut params)
1700 .map_err(|e| PgError::Encode(e.to_string()))?;
1701 let sql = String::from_utf8_lossy(&sql_buf).to_string();
1702
1703 self.connection.begin_transaction().await?;
1705
1706 self.connection
1709 .declare_cursor(&cursor_name, &sql, ¶ms)
1710 .await?;
1711
1712 let mut all_batches = Vec::new();
1714 while let Some(rows) = self
1715 .connection
1716 .fetch_cursor(&cursor_name, batch_size)
1717 .await?
1718 {
1719 let pg_rows: Vec<PgRow> = rows
1720 .into_iter()
1721 .map(|cols| PgRow {
1722 columns: cols,
1723 column_info: None,
1724 })
1725 .collect();
1726 all_batches.push(pg_rows);
1727 }
1728
1729 self.connection.close_cursor(&cursor_name).await?;
1730 self.connection.commit().await?;
1731
1732 Ok(all_batches)
1733 }
1734}
1735
1736#[derive(Default)]
1753pub struct PgDriverBuilder {
1754 host: Option<String>,
1755 port: Option<u16>,
1756 user: Option<String>,
1757 database: Option<String>,
1758 password: Option<String>,
1759 timeout: Option<std::time::Duration>,
1760 connect_options: ConnectOptions,
1761}
1762
1763impl PgDriverBuilder {
1764 pub fn new() -> Self {
1766 Self::default()
1767 }
1768
1769 pub fn host(mut self, host: impl Into<String>) -> Self {
1771 self.host = Some(host.into());
1772 self
1773 }
1774
1775 pub fn port(mut self, port: u16) -> Self {
1777 self.port = Some(port);
1778 self
1779 }
1780
1781 pub fn user(mut self, user: impl Into<String>) -> Self {
1783 self.user = Some(user.into());
1784 self
1785 }
1786
1787 pub fn database(mut self, database: impl Into<String>) -> Self {
1789 self.database = Some(database.into());
1790 self
1791 }
1792
1793 pub fn password(mut self, password: impl Into<String>) -> Self {
1795 self.password = Some(password.into());
1796 self
1797 }
1798
1799 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
1801 self.timeout = Some(timeout);
1802 self
1803 }
1804
1805 pub fn tls_mode(mut self, mode: TlsMode) -> Self {
1807 self.connect_options.tls_mode = mode;
1808 self
1809 }
1810
1811 pub fn gss_enc_mode(mut self, mode: GssEncMode) -> Self {
1813 self.connect_options.gss_enc_mode = mode;
1814 self
1815 }
1816
1817 pub fn tls_ca_cert_pem(mut self, ca_pem: Vec<u8>) -> Self {
1819 self.connect_options.tls_ca_cert_pem = Some(ca_pem);
1820 self
1821 }
1822
1823 pub fn mtls(mut self, config: TlsConfig) -> Self {
1825 self.connect_options.mtls = Some(config);
1826 self.connect_options.tls_mode = TlsMode::Require;
1827 self
1828 }
1829
1830 pub fn auth_settings(mut self, settings: AuthSettings) -> Self {
1832 self.connect_options.auth = settings;
1833 self
1834 }
1835
1836 pub fn channel_binding_mode(mut self, mode: ScramChannelBindingMode) -> Self {
1838 self.connect_options.auth.channel_binding = mode;
1839 self
1840 }
1841
1842 pub fn gss_token_provider(mut self, provider: GssTokenProvider) -> Self {
1844 self.connect_options.gss_token_provider = Some(provider);
1845 self
1846 }
1847
1848 pub fn gss_token_provider_ex(mut self, provider: GssTokenProviderEx) -> Self {
1850 self.connect_options.gss_token_provider_ex = Some(provider);
1851 self
1852 }
1853
1854 pub async fn connect(self) -> PgResult<PgDriver> {
1856 let host = self.host.unwrap_or_else(|| "127.0.0.1".to_string());
1857 let port = self.port.unwrap_or(5432);
1858 let user = self
1859 .user
1860 .ok_or_else(|| PgError::Connection("User is required".to_string()))?;
1861 let database = self
1862 .database
1863 .ok_or_else(|| PgError::Connection("Database is required".to_string()))?;
1864
1865 let password = self.password;
1866 let options = self.connect_options;
1867
1868 if let Some(timeout) = self.timeout {
1869 let options = options.clone();
1870 tokio::time::timeout(
1871 timeout,
1872 PgDriver::connect_with_options(
1873 &host,
1874 port,
1875 &user,
1876 &database,
1877 password.as_deref(),
1878 options,
1879 ),
1880 )
1881 .await
1882 .map_err(|_| PgError::Timeout(format!("connection after {:?}", timeout)))?
1883 } else {
1884 PgDriver::connect_with_options(
1885 &host,
1886 port,
1887 &user,
1888 &database,
1889 password.as_deref(),
1890 options,
1891 )
1892 .await
1893 }
1894 }
1895}
1896
#[cfg(test)]
mod tests {
    use super::{
        AuthSettings, EnterpriseAuthMechanism, PgDriver, PgError, PgServerError,
        ScramChannelBindingMode, TlsMode,
    };
    use std::io::{Error as IoError, ErrorKind};

    /// Wraps a `PgServerError` with severity "ERROR", the given SQLSTATE `code`
    /// and `message`, and no detail or hint.
    fn server_error(code: &str, message: &str) -> PgError {
        let payload = PgServerError {
            severity: String::from("ERROR"),
            code: code.to_owned(),
            message: message.to_owned(),
            detail: None,
            hint: None,
        };
        PgError::QueryServer(payload)
    }

    // ---- prepared-statement retry classification ----

    #[test]
    fn prepared_statement_missing_is_retryable() {
        let missing = server_error("26000", "prepared statement \"s1\" does not exist");
        assert!(missing.is_prepared_statement_retryable());
    }

    #[test]
    fn cached_plan_replanned_is_retryable() {
        let stale_plan = server_error("0A000", "cached plan must be replanned");
        assert!(stale_plan.is_prepared_statement_retryable());
    }

    #[test]
    fn unrelated_server_error_is_not_retryable() {
        let duplicate = server_error("23505", "duplicate key value violates unique constraint");
        assert!(!duplicate.is_prepared_statement_retryable());
    }

    // ---- transient-error classification ----

    #[test]
    fn serialization_failure_is_transient() {
        let serialization = server_error("40001", "could not serialize access");
        assert!(serialization.is_transient_server_error());
    }

    #[test]
    fn deadlock_detected_is_transient() {
        let deadlock = server_error("40P01", "deadlock detected");
        assert!(deadlock.is_transient_server_error());
    }

    #[test]
    fn cannot_connect_now_is_transient() {
        let starting_up = server_error("57P03", "the database system is starting up");
        assert!(starting_up.is_transient_server_error());
    }

    #[test]
    fn admin_shutdown_is_transient() {
        let shutdown = server_error(
            "57P01",
            "terminating connection due to administrator command",
        );
        assert!(shutdown.is_transient_server_error());
    }

    #[test]
    fn connection_exception_class_is_transient() {
        let failure = server_error("08006", "connection failure");
        assert!(failure.is_transient_server_error());
    }

    #[test]
    fn connection_does_not_exist_is_transient() {
        let gone = server_error("08003", "connection does not exist");
        assert!(gone.is_transient_server_error());
    }

    #[test]
    fn unique_violation_is_not_transient() {
        let duplicate = server_error("23505", "duplicate key value violates unique constraint");
        assert!(!duplicate.is_transient_server_error());
    }

    #[test]
    fn syntax_error_is_not_transient() {
        let syntax = server_error("42601", "syntax error at or near \"SELECT\"");
        assert!(!syntax.is_transient_server_error());
    }

    #[test]
    fn timeout_error_is_transient() {
        let timed_out = PgError::Timeout(String::from("query after 30s"));
        assert!(timed_out.is_transient_server_error());
    }

    #[test]
    fn io_connection_reset_is_transient() {
        let reset = IoError::new(ErrorKind::ConnectionReset, "connection reset by peer");
        assert!(PgError::Io(reset).is_transient_server_error());
    }

    #[test]
    fn io_permission_denied_is_not_transient() {
        let denied = IoError::new(ErrorKind::PermissionDenied, "permission denied");
        assert!(!PgError::Io(denied).is_transient_server_error());
    }

    #[test]
    fn connection_error_is_transient() {
        let unreachable = PgError::Connection(String::from("host not found"));
        assert!(unreachable.is_transient_server_error());
    }

    #[test]
    fn prepared_stmt_retryable_counts_as_transient() {
        let missing = server_error("26000", "prepared statement \"s1\" does not exist");
        assert!(missing.is_transient_server_error());
    }

    // ---- sslmode parsing ----

    #[test]
    fn tls_mode_parse_disable() {
        assert_eq!(TlsMode::parse_sslmode("disable"), Some(TlsMode::Disable));
    }

    #[test]
    fn tls_mode_parse_prefer_variants() {
        assert_eq!(TlsMode::parse_sslmode("prefer"), Some(TlsMode::Prefer));
        assert_eq!(
            TlsMode::parse_sslmode("allow"),
            Some(TlsMode::Prefer),
            "libpq 'allow' maps to Prefer"
        );
    }

    #[test]
    fn tls_mode_parse_require_variants() {
        assert_eq!(TlsMode::parse_sslmode("require"), Some(TlsMode::Require));

        let verifying_modes = [
            ("verify-ca", "verify-ca → Require (CA validation at TLS layer)"),
            (
                "verify-full",
                "verify-full → Require (hostname validation at TLS layer)",
            ),
        ];
        for &(mode, why) in verifying_modes.iter() {
            assert_eq!(TlsMode::parse_sslmode(mode), Some(TlsMode::Require), "{}", why);
        }
    }

    #[test]
    fn tls_mode_parse_case_insensitive() {
        for &spelling in ["REQUIRE", "Verify-Full"].iter() {
            assert_eq!(TlsMode::parse_sslmode(spelling), Some(TlsMode::Require));
        }
    }

    #[test]
    fn tls_mode_parse_unknown_returns_none() {
        for &garbage in ["invalid", ""].iter() {
            assert_eq!(TlsMode::parse_sslmode(garbage), None);
        }
    }

    #[test]
    fn tls_mode_parse_trims_whitespace() {
        assert_eq!(TlsMode::parse_sslmode(" require "), Some(TlsMode::Require));
    }

    #[test]
    fn tls_mode_default_is_disable() {
        assert_eq!(TlsMode::default(), TlsMode::Disable);
    }

    // ---- auth policy ----

    #[test]
    fn auth_default_allows_all_password_methods() {
        let defaults = AuthSettings::default();
        assert!(defaults.allow_cleartext_password);
        assert!(defaults.allow_md5_password);
        assert!(defaults.allow_scram_sha_256);
        assert!(defaults.has_any_password_method());
    }

    #[test]
    fn auth_default_disables_enterprise_methods() {
        let defaults = AuthSettings::default();
        assert!(
            !defaults.allow_kerberos_v5,
            "Kerberos V5 should be disabled by default"
        );
        assert!(!defaults.allow_gssapi, "GSSAPI should be disabled by default");
        assert!(!defaults.allow_sspi, "SSPI should be disabled by default");
    }

    #[test]
    fn auth_scram_only_restricts_to_scram() {
        let policy = AuthSettings::scram_only();
        assert!(policy.allow_scram_sha_256);
        assert!(!policy.allow_cleartext_password);
        assert!(!policy.allow_md5_password);
        assert!(!policy.allow_kerberos_v5);
        assert!(!policy.allow_gssapi);
        assert!(!policy.allow_sspi);
        assert!(policy.has_any_password_method());
    }

    #[test]
    fn auth_gssapi_only_disables_all_passwords() {
        let policy = AuthSettings::gssapi_only();
        assert!(!policy.allow_cleartext_password);
        assert!(!policy.allow_md5_password);
        assert!(!policy.allow_scram_sha_256);
        assert!(!policy.has_any_password_method());
        assert!(policy.allow_kerberos_v5);
        assert!(policy.allow_gssapi);
        assert!(policy.allow_sspi);
    }

    #[test]
    fn auth_has_any_password_when_only_cleartext() {
        let cleartext_only = AuthSettings {
            allow_cleartext_password: true,
            allow_md5_password: false,
            allow_scram_sha_256: false,
            ..AuthSettings::default()
        };
        assert!(cleartext_only.has_any_password_method());
    }

    #[test]
    fn auth_no_password_method_when_all_disabled() {
        let no_passwords = AuthSettings {
            allow_cleartext_password: false,
            allow_md5_password: false,
            allow_scram_sha_256: false,
            ..AuthSettings::default()
        };
        assert!(!no_passwords.has_any_password_method());
    }

    #[test]
    fn auth_enterprise_mechanisms_are_distinct() {
        let kerberos = EnterpriseAuthMechanism::KerberosV5;
        let gssapi = EnterpriseAuthMechanism::GssApi;
        let sspi = EnterpriseAuthMechanism::Sspi;
        assert_ne!(kerberos, gssapi);
        assert_ne!(gssapi, sspi);
        assert_ne!(kerberos, sspi);
    }

    #[test]
    fn auth_channel_binding_default_is_prefer() {
        let defaults = AuthSettings::default();
        assert_eq!(defaults.channel_binding, ScramChannelBindingMode::Prefer);
    }

    // ---- database URL parsing ----

    #[test]
    fn parse_database_url_basic() {
        let parsed =
            PgDriver::parse_database_url("postgresql://admin:secret@localhost:5432/mydb").unwrap();
        let (host, port, user, db, pw) = parsed;
        assert_eq!(host, "localhost");
        assert_eq!(port, 5432);
        assert_eq!(user, "admin");
        assert_eq!(db, "mydb");
        assert_eq!(pw, Some(String::from("secret")));
    }

    #[test]
    fn parse_database_url_strips_query_params() {
        let url = "postgresql://user:pass@host:5432/mydb?sslmode=require&auth_mode=scram_only";
        let (_, _, _, db, _) = PgDriver::parse_database_url(url).unwrap();
        assert_eq!(db, "mydb", "query params must not leak into database name");
    }

    #[test]
    fn parse_database_url_strips_single_query_param() {
        let url = "postgres://u:p@h/testdb?gss_provider=linux_krb5";
        let (_, _, _, db, _) = PgDriver::parse_database_url(url).unwrap();
        assert_eq!(db, "testdb");
    }

    #[test]
    fn parse_database_url_no_query_still_works() {
        let url = "postgresql://user@host:5432/cleandb";
        let (_, _, _, db, _) = PgDriver::parse_database_url(url).unwrap();
        assert_eq!(db, "cleandb");
    }
}