Skip to main content

qail_pg/driver/connection/
types.rs

1//! PostgreSQL Connection
2//!
3//! Low-level TCP connection with wire protocol handling.
4//! This is Layer 3 (async I/O).
5//!
6//! Methods are split across modules for easier maintenance:
7//! - `io.rs` - Core I/O (send, recv)
8//! - `query.rs` - Query execution
9//! - `transaction.rs` - Transaction control
10//! - `cursor.rs` - Streaming cursors
11//! - `copy.rs` - COPY protocol
12//! - `pipeline.rs` - High-performance pipelining
13//! - `cancel.rs` - Query cancellation
14
15use super::super::notification::Notification;
16use super::super::stream::PgStream;
17use super::super::{AuthSettings, EnterpriseAuthMechanism};
18use crate::protocol::PROTOCOL_VERSION_3_2;
19use bytes::BytesMut;
20use std::collections::{HashMap, VecDeque};
21use std::num::NonZeroUsize;
22use std::sync::Arc;
23use std::sync::atomic::AtomicU64;
24use tokio::net::TcpStream;
25
26/// Statement cache capacity per connection.
27pub(super) const STMT_CACHE_CAPACITY: NonZeroUsize = NonZeroUsize::new(100).unwrap();
28
/// Small, allocation-bounded prepared statement cache with LRU eviction.
///
/// This mirrors the subset of `lru::LruCache` APIs used by the driver while
/// avoiding external unsoundness advisories on `IterMut` (which we don't use).
///
/// Values live in `entries`; `order` holds every cached key exactly once,
/// sorted from least recently used (front) to most recently used (back).
/// Recency updates scan `order` linearly, which is acceptable for the small
/// capacities this cache is created with.
#[derive(Debug)]
pub(crate) struct StatementCache {
    /// Maximum number of entries kept before `put` evicts the LRU entry.
    capacity: NonZeroUsize,
    /// Cached values by key.
    entries: HashMap<u64, String>,
    /// Recency queue of keys: front = LRU, back = MRU.
    order: VecDeque<u64>,
}

impl StatementCache {
    /// Creates an empty cache, pre-allocating room for `capacity` entries.
    pub(crate) fn new(capacity: NonZeroUsize) -> Self {
        let slots = capacity.get();
        Self {
            capacity,
            entries: HashMap::with_capacity(slots),
            order: VecDeque::with_capacity(slots),
        }
    }

    /// Number of entries currently cached.
    pub(crate) fn len(&self) -> usize {
        self.entries.len()
    }

    /// Maximum number of entries the cache holds before evicting.
    pub(crate) fn cap(&self) -> NonZeroUsize {
        self.capacity
    }

    /// Returns `true` if `key` is cached. Does not affect recency.
    pub(crate) fn contains(&self, key: &u64) -> bool {
        self.entries.contains_key(key)
    }

    /// Returns a clone of the value for `key` and marks it most recently used.
    ///
    /// Returns `None` (and changes nothing) when `key` is not cached.
    pub(crate) fn get(&mut self, key: &u64) -> Option<String> {
        let value = self.entries.get(key)?.clone();
        self.touch(*key);
        Some(value)
    }

    /// Borrows the value for `key` without affecting recency.
    pub(crate) fn peek(&self, key: &u64) -> Option<&str> {
        self.entries.get(key).map(String::as_str)
    }

    /// Marks `key` most recently used if it is cached; otherwise a no-op.
    pub(crate) fn touch_key(&mut self, key: u64) {
        if self.entries.contains_key(&key) {
            self.touch(key);
        }
    }

    /// Inserts or replaces the value for `key` and marks it most recently used.
    ///
    /// Inserting a new key into a full cache first evicts the least recently
    /// used entry; the evicted pair is dropped. Callers that need the evicted
    /// entry should call [`Self::pop_lru`] themselves before inserting.
    pub(crate) fn put(&mut self, key: u64, value: String) {
        // Replacing an existing entry never changes the entry count, so no
        // eviction is needed on this path.
        if let Some(slot) = self.entries.get_mut(&key) {
            *slot = value;
            self.touch(key);
            return;
        }

        if self.entries.len() >= self.capacity.get() {
            let _ = self.pop_lru();
        }

        self.entries.insert(key, value);
        self.order.push_back(key);
    }

    /// Removes and returns the least recently used entry, or `None` if empty.
    pub(crate) fn pop_lru(&mut self) -> Option<(u64, String)> {
        // Defensive: skip any queued key that no longer has a value.
        while let Some(key) = self.order.pop_front() {
            if let Some(value) = self.entries.remove(&key) {
                return Some((key, value));
            }
        }
        None
    }

    /// Removes `key` from the cache, returning its value if it was present.
    pub(crate) fn remove(&mut self, key: &u64) -> Option<String> {
        let removed = self.entries.remove(key)?;
        self.order.retain(|queued| queued != key);
        Some(removed)
    }

    /// Drops every entry.
    pub(crate) fn clear(&mut self) {
        self.entries.clear();
        self.order.clear();
    }

    /// Moves `key` to the MRU end of the recency queue.
    fn touch(&mut self, key: u64) {
        // Fast path: re-using the statement that is already the MRU (the
        // common case for a hot query) needs no O(n) scan of the queue.
        if self.order.back() == Some(&key) {
            return;
        }
        self.order.retain(|queued| *queued != key);
        self.order.push_back(key);
    }
}
119
/// Initial buffer capacity: 64 KiB, sized for pipeline performance.
pub(crate) const BUFFER_CAPACITY: usize = 64 * 1024;
122
123/// SSLRequest message bytes (request code: 80877103)
124pub(super) const SSL_REQUEST: [u8; 8] = [0, 0, 0, 8, 4, 210, 22, 47];
125
126/// GSSENCRequest message bytes (request code: 80877104)
127/// Byte breakdown: length=8 (00 00 00 08), code=80877104 (04 D2 16 30)
128pub(super) const GSSENC_REQUEST: [u8; 8] = [0, 0, 0, 8, 4, 210, 22, 48];
129
130/// Result of sending a GSSENCRequest to the server.
131#[derive(Debug)]
132pub(super) enum GssEncNegotiationResult {
133    /// Server responded 'G' — willing to perform GSSAPI encryption.
134    /// The TCP stream is returned for the caller to establish the
135    /// GSSAPI security context and wrap all subsequent traffic.
136    Accepted(TcpStream),
137    /// Server responded 'N' — unwilling to perform GSSAPI encryption.
138    Rejected,
139    /// Server sent an ErrorMessage — must not be displayed to user
140    /// (CVE-2024-10977: server not yet authenticated).
141    ServerError,
142}
143
/// CancelRequest protocol code: 80877102
/// (1234 in the most significant 16 bits, 5678 in the least significant 16 bits).
pub(crate) const CANCEL_REQUEST_CODE: i32 = (1234 << 16) | 5678;
146
147/// Monotonic session id source for stateful GSS provider callbacks.
148pub(super) static GSS_SESSION_COUNTER: AtomicU64 = AtomicU64::new(1);
149
150/// Default timeout for TCP connect + PostgreSQL handshake.
151/// Prevents Slowloris DoS where a malicious server accepts TCP but never responds.
152pub(crate) const DEFAULT_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
153pub(super) const CONNECT_TRANSPORT_PLAIN: &str = "plain";
154pub(super) const CONNECT_TRANSPORT_TLS: &str = "tls";
155pub(super) const CONNECT_TRANSPORT_MTLS: &str = "mtls";
156pub(super) const CONNECT_TRANSPORT_GSSENC: &str = "gssenc";
157pub(super) const CONNECT_BACKEND_TOKIO: &str = "tokio";
158#[cfg(all(target_os = "linux", feature = "io_uring"))]
159pub(super) const CONNECT_BACKEND_IO_URING: &str = "io_uring";
160
/// TLS configuration for mutual TLS (client certificate authentication).
///
/// `Debug` is implemented by hand so that `client_key_pem` is redacted:
/// formatting a config with `{:?}` (logs, panics, test output) must never
/// expose private key material.
#[derive(Clone)]
pub struct TlsConfig {
    /// Client certificate in PEM format
    pub client_cert_pem: Vec<u8>,
    /// Client private key in PEM format
    pub client_key_pem: Vec<u8>,
    /// Optional CA certificate for server verification (uses system certs if None)
    pub ca_cert_pem: Option<Vec<u8>>,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TlsConfig")
            .field("client_cert_pem", &self.client_cert_pem)
            // Never print the private key bytes.
            .field("client_key_pem", &format_args!("<redacted>"))
            .field("ca_cert_pem", &self.ca_cert_pem)
            .finish()
    }
}

impl TlsConfig {
    /// Create a new TLS config from file paths.
    ///
    /// Reads the client certificate and private key, plus the CA certificate
    /// when `ca_path` is `Some`. File contents are stored as-is; no PEM
    /// validation is performed here.
    ///
    /// # Errors
    ///
    /// Returns the underlying `std::io::Error` if any of the files cannot be
    /// read.
    pub fn from_files(
        cert_path: impl AsRef<std::path::Path>,
        key_path: impl AsRef<std::path::Path>,
        ca_path: Option<impl AsRef<std::path::Path>>,
    ) -> std::io::Result<Self> {
        Ok(Self {
            client_cert_pem: std::fs::read(cert_path)?,
            client_key_pem: std::fs::read(key_path)?,
            ca_cert_pem: ca_path.map(std::fs::read).transpose()?,
        })
    }
}
186
187/// Bundled connection parameters for internal functions.
188///
189/// Groups the 8 common arguments to avoid exceeding clippy's
190/// `too_many_arguments` threshold.
191#[derive(Clone)]
192pub(super) struct ConnectParams<'a> {
193    pub(super) host: &'a str,
194    pub(super) port: u16,
195    pub(super) user: &'a str,
196    pub(super) database: &'a str,
197    pub(super) password: Option<&'a str>,
198    pub(super) auth_settings: AuthSettings,
199    pub(super) gss_token_provider: Option<super::super::GssTokenProvider>,
200    pub(super) gss_token_provider_ex: Option<super::super::GssTokenProviderEx>,
201    pub(super) protocol_minor: u16,
202    pub(super) startup_params: Vec<(String, String)>,
203}
204
205#[inline]
206pub(super) fn has_logical_replication_startup_mode(startup_params: &[(String, String)]) -> bool {
207    startup_params
208        .iter()
209        .any(|(k, v)| k.eq_ignore_ascii_case("replication") && v.eq_ignore_ascii_case("database"))
210}
211
212#[derive(Debug, Clone, Copy, PartialEq, Eq)]
213pub(super) enum StartupAuthFlow {
214    CleartextPassword,
215    Md5Password,
216    Scram { server_final_seen: bool },
217    EnterpriseGss { mechanism: EnterpriseAuthMechanism },
218}
219
220impl StartupAuthFlow {
221    pub(super) fn label(self) -> &'static str {
222        match self {
223            Self::CleartextPassword => "cleartext-password",
224            Self::Md5Password => "md5-password",
225            Self::Scram { .. } => "scram",
226            Self::EnterpriseGss { mechanism } => match mechanism {
227                EnterpriseAuthMechanism::KerberosV5 => "kerberos-v5",
228                EnterpriseAuthMechanism::GssApi => "gssapi",
229                EnterpriseAuthMechanism::Sspi => "sspi",
230            },
231        }
232    }
233}
234
235/// A raw PostgreSQL connection.
236pub struct PgConnection {
237    pub(crate) stream: PgStream,
238    pub(crate) buffer: BytesMut,
239    pub(crate) write_buf: BytesMut,
240    pub(crate) sql_buf: BytesMut,
241    pub(crate) params_buf: Vec<Option<Vec<u8>>>,
242    pub(crate) prepared_statements: HashMap<String, String>,
243    pub(crate) stmt_cache: StatementCache,
244    /// Cache of column metadata (RowDescription) per statement hash.
245    /// PostgreSQL only sends RowDescription after Parse, not on subsequent Bind+Execute.
246    /// This cache ensures by-name column access works even for cached prepared statements.
247    pub(crate) column_info_cache: HashMap<u64, Arc<super::super::ColumnInfo>>,
248    pub(crate) process_id: i32,
249    /// Legacy 4-byte cancel secret key (protocol 3.0-compatible wrappers).
250    ///
251    /// For protocol 3.2 extended key lengths, this remains `0` and callers
252    /// must use `cancel_key_bytes`.
253    pub(crate) secret_key: i32,
254    /// Full cancel key bytes (`4..=256`) from BackendKeyData.
255    pub(crate) cancel_key_bytes: Vec<u8>,
256    /// Startup protocol minor requested by this connection (for example `2` for 3.2).
257    pub(crate) requested_protocol_minor: u16,
258    /// Startup protocol minor negotiated with the server.
259    pub(crate) negotiated_protocol_minor: u16,
260    /// Buffer for asynchronous LISTEN/NOTIFY notifications.
261    /// Populated by `recv()` when it encounters NotificationResponse messages.
262    pub(crate) notifications: VecDeque<Notification>,
263    /// True while a logical replication CopyBoth stream is active.
264    pub(crate) replication_stream_active: bool,
265    /// True when StartupMessage was sent with `replication=database`.
266    pub(crate) replication_mode_enabled: bool,
267    /// Last seen wal_end from a replication XLogData frame.
268    pub(crate) last_replication_wal_end: Option<u64>,
269    /// Sticky fail-closed flag for uncertain protocol/I-O state.
270    /// Once set, the connection must not return to pool reuse.
271    pub(crate) io_desynced: bool,
272    /// Statement names scheduled for server-side `Close` on next write.
273    /// This keeps backend prepared state aligned with local LRU eviction.
274    pub(crate) pending_statement_closes: Vec<String>,
275    /// Reentrancy guard for pending-close drain path.
276    pub(crate) draining_statement_closes: bool,
277}
278
279impl PgConnection {
280    #[inline]
281    pub(crate) fn default_protocol_minor() -> u16 {
282        (PROTOCOL_VERSION_3_2 & 0xFFFF) as u16
283    }
284
285    /// Startup protocol minor requested by this connection.
286    #[inline]
287    pub fn requested_protocol_minor(&self) -> u16 {
288        self.requested_protocol_minor
289    }
290
291    /// Startup protocol minor negotiated with the server.
292    #[inline]
293    pub fn negotiated_protocol_minor(&self) -> u16 {
294        self.negotiated_protocol_minor
295    }
296
297    /// Active transport backend label for this connection.
298    ///
299    /// Returns `"tokio"` or `"io_uring"` for plain TCP connections,
300    /// and `"tokio"` for TLS/Unix/GSSENC paths.
301    #[inline]
302    pub fn transport_backend(&self) -> &'static str {
303        super::helpers::connect_backend_for_stream(&self.stream)
304    }
305}