Skip to main content

qail_pg/driver/connection/
types.rs

1//! PostgreSQL Connection
2//!
3//! Low-level TCP connection with wire protocol handling.
4//! This is Layer 3 (async I/O).
5//!
6//! Methods are split across modules for easier maintenance:
7//! - `io.rs` - Core I/O (send, recv)
8//! - `query.rs` - Query execution
9//! - `transaction.rs` - Transaction control
10//! - `cursor.rs` - Streaming cursors
11//! - `copy.rs` - COPY protocol
12//! - `pipeline.rs` - High-performance pipelining
13//! - `cancel.rs` - Query cancellation
14
15use super::super::notification::Notification;
16use super::super::stream::PgStream;
17use super::super::{AuthSettings, EnterpriseAuthMechanism};
18use bytes::BytesMut;
19use std::collections::{HashMap, VecDeque};
20use std::num::NonZeroUsize;
21use std::sync::Arc;
22use std::sync::atomic::AtomicU64;
23use tokio::net::TcpStream;
24
25/// Statement cache capacity per connection.
26pub(super) const STMT_CACHE_CAPACITY: NonZeroUsize = NonZeroUsize::new(100).unwrap();
27
/// Small, allocation-bounded prepared statement cache with LRU eviction.
///
/// This mirrors the subset of `lru::LruCache` APIs used by the driver while
/// avoiding external unsoundness advisories on `IterMut` (which we don't use).
///
/// Invariant: `order` contains every key of `entries` exactly once, ordered
/// from least-recently used (front) to most-recently used (back).
#[derive(Debug)]
pub(crate) struct StatementCache {
    capacity: NonZeroUsize,
    entries: HashMap<u64, String>,
    order: VecDeque<u64>, // Front = LRU, back = MRU
}

impl StatementCache {
    /// Creates an empty cache that holds at most `capacity` entries.
    pub(crate) fn new(capacity: NonZeroUsize) -> Self {
        Self {
            capacity,
            entries: HashMap::with_capacity(capacity.get()),
            order: VecDeque::with_capacity(capacity.get()),
        }
    }

    /// Number of cached entries.
    pub(crate) fn len(&self) -> usize {
        self.entries.len()
    }

    /// Maximum number of entries before `put` evicts the LRU entry.
    pub(crate) fn cap(&self) -> NonZeroUsize {
        self.capacity
    }

    /// Returns whether `key` is cached, without updating its recency.
    pub(crate) fn contains(&self, key: &u64) -> bool {
        self.entries.contains_key(key)
    }

    /// Returns a copy of the value for `key` and marks it most-recently used.
    pub(crate) fn get(&mut self, key: &u64) -> Option<String> {
        let value = self.entries.get(key).cloned()?;
        self.touch(*key);
        Some(value)
    }

    /// Inserts or replaces `key`, marking it most-recently used.
    ///
    /// Inserting a new key into a full cache first evicts (and drops) the
    /// least-recently used entry; call `pop_lru` beforehand if the evicted
    /// value is needed.
    pub(crate) fn put(&mut self, key: u64, value: String) {
        if let Some(slot) = self.entries.get_mut(&key) {
            *slot = value;
            self.touch(key);
            return;
        }

        if self.entries.len() >= self.capacity.get() {
            let _ = self.pop_lru();
        }

        self.entries.insert(key, value);
        self.order.push_back(key);
    }

    /// Removes and returns the least-recently used entry, if any.
    pub(crate) fn pop_lru(&mut self) -> Option<(u64, String)> {
        // Skip any key without a matching entry so a stale slot is never returned.
        while let Some(key) = self.order.pop_front() {
            if let Some(value) = self.entries.remove(&key) {
                return Some((key, value));
            }
        }
        None
    }

    /// Removes `key`, returning its value if it was cached.
    pub(crate) fn remove(&mut self, key: &u64) -> Option<String> {
        let removed = self.entries.remove(key);
        if removed.is_some() {
            self.unlink(*key);
        }
        removed
    }

    /// Empties the cache; the configured capacity is unchanged.
    pub(crate) fn clear(&mut self) {
        self.entries.clear();
        self.order.clear();
    }

    /// Moves `key` to the most-recently-used end of `order`.
    fn touch(&mut self, key: u64) {
        // Hot path: repeated hits on the statement used last need no reordering.
        if self.order.back() == Some(&key) {
            return;
        }
        self.unlink(key);
        self.order.push_back(key);
    }

    /// Drops `key` from `order`; keys are unique there, so stop at the first match.
    fn unlink(&mut self, key: u64) {
        if let Some(pos) = self.order.iter().position(|&k| k == key) {
            self.order.remove(pos);
        }
    }
}
108
109/// Initial buffer capacity (64KB for pipeline performance)
110pub(crate) const BUFFER_CAPACITY: usize = 65536;
111
112/// SSLRequest message bytes (request code: 80877103)
113pub(super) const SSL_REQUEST: [u8; 8] = [0, 0, 0, 8, 4, 210, 22, 47];
114
115/// GSSENCRequest message bytes (request code: 80877104)
116/// Byte breakdown: length=8 (00 00 00 08), code=80877104 (04 D2 16 30)
117pub(super) const GSSENC_REQUEST: [u8; 8] = [0, 0, 0, 8, 4, 210, 22, 48];
118
119/// Result of sending a GSSENCRequest to the server.
120#[derive(Debug)]
121pub(super) enum GssEncNegotiationResult {
122    /// Server responded 'G' — willing to perform GSSAPI encryption.
123    /// The TCP stream is returned for the caller to establish the
124    /// GSSAPI security context and wrap all subsequent traffic.
125    Accepted(TcpStream),
126    /// Server responded 'N' — unwilling to perform GSSAPI encryption.
127    Rejected,
128    /// Server sent an ErrorMessage — must not be displayed to user
129    /// (CVE-2024-10977: server not yet authenticated).
130    ServerError,
131}
132
133/// CancelRequest protocol code: 80877102
134pub(crate) const CANCEL_REQUEST_CODE: i32 = 80877102;
135
136/// Monotonic session id source for stateful GSS provider callbacks.
137pub(super) static GSS_SESSION_COUNTER: AtomicU64 = AtomicU64::new(1);
138
139/// Default timeout for TCP connect + PostgreSQL handshake.
140/// Prevents Slowloris DoS where a malicious server accepts TCP but never responds.
141pub(crate) const DEFAULT_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
142pub(super) const CONNECT_TRANSPORT_PLAIN: &str = "plain";
143pub(super) const CONNECT_TRANSPORT_TLS: &str = "tls";
144pub(super) const CONNECT_TRANSPORT_MTLS: &str = "mtls";
145pub(super) const CONNECT_TRANSPORT_GSSENC: &str = "gssenc";
146pub(super) const CONNECT_BACKEND_TOKIO: &str = "tokio";
147#[cfg(all(target_os = "linux", feature = "io_uring"))]
148pub(super) const CONNECT_BACKEND_IO_URING: &str = "io_uring";
149
/// TLS configuration for mutual TLS (client certificate authentication).
///
/// All material is held as raw PEM bytes. The `Debug` implementation redacts
/// the client private key so a configuration can be logged without leaking it.
#[derive(Clone)]
pub struct TlsConfig {
    /// Client certificate in PEM format
    pub client_cert_pem: Vec<u8>,
    /// Client private key in PEM format
    pub client_key_pem: Vec<u8>,
    /// Optional CA certificate for server verification (uses system certs if None)
    pub ca_cert_pem: Option<Vec<u8>>,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print key material; the length is enough to spot an empty key file.
        f.debug_struct("TlsConfig")
            .field("client_cert_pem", &self.client_cert_pem)
            .field(
                "client_key_pem",
                &format_args!("<redacted {} bytes>", self.client_key_pem.len()),
            )
            .field("ca_cert_pem", &self.ca_cert_pem)
            .finish()
    }
}

impl TlsConfig {
    /// Create a new TLS config from file paths.
    ///
    /// Files are read in order: certificate, key, then the optional CA.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error encountered. The error keeps the original
    /// `ErrorKind`, and its message names the file that could not be read.
    pub fn from_files(
        cert_path: impl AsRef<std::path::Path>,
        key_path: impl AsRef<std::path::Path>,
        ca_path: Option<impl AsRef<std::path::Path>>,
    ) -> std::io::Result<Self> {
        Ok(Self {
            client_cert_pem: Self::read_pem(cert_path.as_ref(), "client certificate")?,
            client_key_pem: Self::read_pem(key_path.as_ref(), "client private key")?,
            ca_cert_pem: ca_path
                .map(|p| Self::read_pem(p.as_ref(), "CA certificate"))
                .transpose()?,
        })
    }

    /// Reads one PEM file, annotating any error with what was being read and from where.
    fn read_pem(path: &std::path::Path, what: &str) -> std::io::Result<Vec<u8>> {
        std::fs::read(path).map_err(|err| {
            std::io::Error::new(
                err.kind(),
                format!("failed to read {what} from {}: {err}", path.display()),
            )
        })
    }
}
175
176/// Bundled connection parameters for internal functions.
177///
178/// Groups the 8 common arguments to avoid exceeding clippy's
179/// `too_many_arguments` threshold.
180pub(super) struct ConnectParams<'a> {
181    pub(super) host: &'a str,
182    pub(super) port: u16,
183    pub(super) user: &'a str,
184    pub(super) database: &'a str,
185    pub(super) password: Option<&'a str>,
186    pub(super) auth_settings: AuthSettings,
187    pub(super) gss_token_provider: Option<super::super::GssTokenProvider>,
188    pub(super) gss_token_provider_ex: Option<super::super::GssTokenProviderEx>,
189    pub(super) startup_params: Vec<(String, String)>,
190}
191
192#[inline]
193pub(super) fn has_logical_replication_startup_mode(startup_params: &[(String, String)]) -> bool {
194    startup_params
195        .iter()
196        .any(|(k, v)| k.eq_ignore_ascii_case("replication") && v.eq_ignore_ascii_case("database"))
197}
198
199#[derive(Debug, Clone, Copy, PartialEq, Eq)]
200pub(super) enum StartupAuthFlow {
201    CleartextPassword,
202    Md5Password,
203    Scram { server_final_seen: bool },
204    EnterpriseGss { mechanism: EnterpriseAuthMechanism },
205}
206
207impl StartupAuthFlow {
208    pub(super) fn label(self) -> &'static str {
209        match self {
210            Self::CleartextPassword => "cleartext-password",
211            Self::Md5Password => "md5-password",
212            Self::Scram { .. } => "scram",
213            Self::EnterpriseGss { mechanism } => match mechanism {
214                EnterpriseAuthMechanism::KerberosV5 => "kerberos-v5",
215                EnterpriseAuthMechanism::GssApi => "gssapi",
216                EnterpriseAuthMechanism::Sspi => "sspi",
217            },
218        }
219    }
220}
221
222/// A raw PostgreSQL connection.
223pub struct PgConnection {
224    pub(crate) stream: PgStream,
225    pub(crate) buffer: BytesMut,
226    pub(crate) write_buf: BytesMut,
227    pub(crate) sql_buf: BytesMut,
228    pub(crate) params_buf: Vec<Option<Vec<u8>>>,
229    pub(crate) prepared_statements: HashMap<String, String>,
230    pub(crate) stmt_cache: StatementCache,
231    /// Cache of column metadata (RowDescription) per statement hash.
232    /// PostgreSQL only sends RowDescription after Parse, not on subsequent Bind+Execute.
233    /// This cache ensures by-name column access works even for cached prepared statements.
234    pub(crate) column_info_cache: HashMap<u64, Arc<super::super::ColumnInfo>>,
235    pub(crate) process_id: i32,
236    pub(crate) secret_key: i32,
237    /// Buffer for asynchronous LISTEN/NOTIFY notifications.
238    /// Populated by `recv()` when it encounters NotificationResponse messages.
239    pub(crate) notifications: VecDeque<Notification>,
240    /// True while a logical replication CopyBoth stream is active.
241    pub(crate) replication_stream_active: bool,
242    /// True when StartupMessage was sent with `replication=database`.
243    pub(crate) replication_mode_enabled: bool,
244    /// Last seen wal_end from a replication XLogData frame.
245    pub(crate) last_replication_wal_end: Option<u64>,
246    /// Sticky fail-closed flag for uncertain protocol/I-O state.
247    /// Once set, the connection must not return to pool reuse.
248    pub(crate) io_desynced: bool,
249    /// Statement names scheduled for server-side `Close` on next write.
250    /// This keeps backend prepared state aligned with local LRU eviction.
251    pub(crate) pending_statement_closes: Vec<String>,
252    /// Reentrancy guard for pending-close drain path.
253    pub(crate) draining_statement_closes: bool,
254}