Skip to main content

heliosdb_proxy/pool/
mode.rs

//! Pooling Mode Definitions
//!
//! Defines the three connection pooling modes together with the related
//! prepared-statement handling and transaction-event enums.
4
5use serde::{Deserialize, Serialize};
6
7/// Connection pooling mode
8///
9/// Determines when connections are returned to the pool.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
11#[serde(rename_all = "lowercase")]
12pub enum PoolingMode {
13    /// Session mode: 1:1 client-to-backend mapping
14    ///
15    /// Connection is held for the entire client session lifetime.
16    /// This is the safest mode, compatible with all PostgreSQL features.
17    #[default]
18    Session,
19
20    /// Transaction mode: Return connection after transaction ends
21    ///
22    /// Connection is returned to the pool after COMMIT or ROLLBACK.
23    /// Provides good connection sharing while maintaining transaction integrity.
24    /// Server-side prepared statements may need re-creation on new connections.
25    Transaction,
26
27    /// Statement mode: Return connection after each statement
28    ///
29    /// Most aggressive connection sharing - returns after every statement.
30    /// Cannot use server-side prepared statements.
31    /// Best for simple queries where maximum connection sharing is desired.
32    Statement,
33}
34
35impl PoolingMode {
36    /// Returns whether this mode supports server-side prepared statements
37    pub fn supports_prepared_statements(&self) -> bool {
38        match self {
39            PoolingMode::Session => true,
40            PoolingMode::Transaction => true, // With tracking/recreation
41            PoolingMode::Statement => false,
42        }
43    }
44
45    /// Returns a human-readable description
46    pub fn description(&self) -> &'static str {
47        match self {
48            PoolingMode::Session => "Hold connection for entire client session",
49            PoolingMode::Transaction => "Return connection after COMMIT/ROLLBACK",
50            PoolingMode::Statement => "Return connection after each statement",
51        }
52    }
53
54    /// Parse from string (case-insensitive)
55    pub fn from_str_lossy(s: &str) -> Self {
56        match s.to_lowercase().as_str() {
57            "session" => PoolingMode::Session,
58            "transaction" | "txn" => PoolingMode::Transaction,
59            "statement" | "stmt" => PoolingMode::Statement,
60            _ => PoolingMode::Session,
61        }
62    }
63}
64
65impl std::fmt::Display for PoolingMode {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            PoolingMode::Session => write!(f, "session"),
69            PoolingMode::Transaction => write!(f, "transaction"),
70            PoolingMode::Statement => write!(f, "statement"),
71        }
72    }
73}
74
75/// Prepared statement handling mode
76///
77/// Controls how server-side prepared statements are handled across connection switches.
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
79#[serde(rename_all = "lowercase")]
80pub enum PreparedStatementMode {
81    /// Disable prepared statements (safest for statement mode)
82    ///
83    /// Forces all queries to use simple query protocol.
84    #[default]
85    Disable,
86
87    /// Track and recreate prepared statements
88    ///
89    /// Records PREPARE commands and replays them on new connections.
90    /// Adds some overhead but maintains compatibility.
91    Track,
92
93    /// Use protocol-level named statements
94    ///
95    /// Leverages PostgreSQL extended query protocol for statement tracking.
96    /// Most efficient but requires careful state management.
97    Named,
98}
99
100impl PreparedStatementMode {
101    /// Returns a human-readable description
102    pub fn description(&self) -> &'static str {
103        match self {
104            PreparedStatementMode::Disable => "Disable prepared statements (safest)",
105            PreparedStatementMode::Track => "Track and recreate on new connections",
106            PreparedStatementMode::Named => "Use protocol-level named statements",
107        }
108    }
109
110    /// Parse from string (case-insensitive)
111    pub fn from_str_lossy(s: &str) -> Self {
112        match s.to_lowercase().as_str() {
113            "disable" | "disabled" | "off" => PreparedStatementMode::Disable,
114            "track" | "tracking" => PreparedStatementMode::Track,
115            "named" | "protocol" => PreparedStatementMode::Named,
116            _ => PreparedStatementMode::Disable,
117        }
118    }
119}
120
121impl std::fmt::Display for PreparedStatementMode {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        match self {
124            PreparedStatementMode::Disable => write!(f, "disable"),
125            PreparedStatementMode::Track => write!(f, "track"),
126            PreparedStatementMode::Named => write!(f, "named"),
127        }
128    }
129}
130
/// Kinds of transaction-control statement recognised by the proxy.
///
/// Classifying each statement this way lets mode-aware pooling tell where a
/// transaction starts and where it finishes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionEvent {
    /// `BEGIN` or `START TRANSACTION`.
    Begin,
    /// `COMMIT`, or its PostgreSQL alias `END`.
    Commit,
    /// `ROLLBACK` of the whole transaction.
    Rollback,
    /// `SAVEPOINT <name>`: a savepoint is created.
    Savepoint,
    /// `RELEASE SAVEPOINT <name>`.
    ReleaseSavepoint,
    /// `ROLLBACK TO SAVEPOINT <name>`.
    RollbackToSavepoint,
    /// Any statement that is not transaction control.
    Statement,
}
151
152impl TransactionEvent {
153    /// Parse SQL to detect transaction boundary
154    ///
155    /// # Arguments
156    /// * `sql` - SQL statement to parse
157    ///
158    /// # Returns
159    /// The detected transaction event type
160    pub fn detect(sql: &str) -> Self {
161        use crate::protocol::{contains_ci, starts_with_ci};
162        // Allocation-free: transaction-control statements are short,
163        // and everything else bails on the leading-keyword check —
164        // a multi-megabyte INSERT must not pay an uppercased copy
165        // of itself just to be classified as `Statement`.
166        let trimmed = sql.trim();
167
168        // Check for transaction control commands
169        if starts_with_ci(trimmed, "BEGIN") {
170            return TransactionEvent::Begin;
171        }
172        if starts_with_ci(trimmed, "START TRANSACTION") || starts_with_ci(trimmed, "START ") {
173            // START could be START TRANSACTION
174            if contains_ci(trimmed, "TRANSACTION") {
175                return TransactionEvent::Begin;
176            }
177        }
178        if starts_with_ci(trimmed, "COMMIT") || starts_with_ci(trimmed, "END") {
179            // END is alias for COMMIT in PostgreSQL
180            return TransactionEvent::Commit;
181        }
182        if starts_with_ci(trimmed, "ROLLBACK") {
183            // Check for ROLLBACK TO SAVEPOINT
184            if contains_ci(trimmed, " TO ") {
185                return TransactionEvent::RollbackToSavepoint;
186            }
187            return TransactionEvent::Rollback;
188        }
189        if starts_with_ci(trimmed, "SAVEPOINT") {
190            return TransactionEvent::Savepoint;
191        }
192        if starts_with_ci(trimmed, "RELEASE") {
193            return TransactionEvent::ReleaseSavepoint;
194        }
195
196        TransactionEvent::Statement
197    }
198
199    /// Returns true if this event ends a transaction
200    pub fn is_transaction_end(&self) -> bool {
201        matches!(self, TransactionEvent::Commit | TransactionEvent::Rollback)
202    }
203
204    /// Returns true if this event starts a transaction
205    pub fn is_transaction_start(&self) -> bool {
206        matches!(self, TransactionEvent::Begin)
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pooling_mode_default() {
        let default_mode: PoolingMode = Default::default();
        assert_eq!(default_mode, PoolingMode::Session);
    }

    #[test]
    fn test_pooling_mode_display() {
        let expected = [
            (PoolingMode::Session, "session"),
            (PoolingMode::Transaction, "transaction"),
            (PoolingMode::Statement, "statement"),
        ];
        for &(mode, text) in expected.iter() {
            assert_eq!(mode.to_string(), text);
        }
    }

    #[test]
    fn test_pooling_mode_from_str() {
        let cases = [
            ("SESSION", PoolingMode::Session),
            ("transaction", PoolingMode::Transaction),
            ("txn", PoolingMode::Transaction),
            ("STATEMENT", PoolingMode::Statement),
            ("stmt", PoolingMode::Statement),
            ("unknown", PoolingMode::Session),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(
                PoolingMode::from_str_lossy(input),
                expected,
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn test_prepared_statement_mode_default() {
        let default_mode: PreparedStatementMode = Default::default();
        assert_eq!(default_mode, PreparedStatementMode::Disable);
    }

    #[test]
    fn test_transaction_event_detect() {
        let cases = [
            ("BEGIN", TransactionEvent::Begin),
            ("begin work", TransactionEvent::Begin),
            ("START TRANSACTION", TransactionEvent::Begin),
            ("COMMIT", TransactionEvent::Commit),
            ("END", TransactionEvent::Commit),
            ("ROLLBACK", TransactionEvent::Rollback),
            ("ROLLBACK TO SAVEPOINT sp1", TransactionEvent::RollbackToSavepoint),
            ("SAVEPOINT sp1", TransactionEvent::Savepoint),
            ("RELEASE SAVEPOINT sp1", TransactionEvent::ReleaseSavepoint),
            ("SELECT * FROM users", TransactionEvent::Statement),
        ];
        for &(sql, expected) in cases.iter() {
            assert_eq!(TransactionEvent::detect(sql), expected, "sql: {:?}", sql);
        }
    }

    #[test]
    fn test_transaction_event_predicates() {
        let begin = TransactionEvent::Begin;
        assert!(begin.is_transaction_start());
        assert!(!begin.is_transaction_end());

        let commit = TransactionEvent::Commit;
        assert!(commit.is_transaction_end());
        assert!(!commit.is_transaction_start());

        assert!(TransactionEvent::Rollback.is_transaction_end());
        assert!(!TransactionEvent::Statement.is_transaction_end());
    }
}