Skip to main content

heliosdb_proxy/pool/
mode.rs

1//! Pooling Mode Definitions
2//!
3//! Defines the three connection pooling modes and related enums.
4
5use serde::{Deserialize, Serialize};
6
7/// Connection pooling mode
8///
9/// Determines when connections are returned to the pool.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
11#[serde(rename_all = "lowercase")]
12pub enum PoolingMode {
13    /// Session mode: 1:1 client-to-backend mapping
14    ///
15    /// Connection is held for the entire client session lifetime.
16    /// This is the safest mode, compatible with all PostgreSQL features.
17    #[default]
18    Session,
19
20    /// Transaction mode: Return connection after transaction ends
21    ///
22    /// Connection is returned to the pool after COMMIT or ROLLBACK.
23    /// Provides good connection sharing while maintaining transaction integrity.
24    /// Server-side prepared statements may need re-creation on new connections.
25    Transaction,
26
27    /// Statement mode: Return connection after each statement
28    ///
29    /// Most aggressive connection sharing - returns after every statement.
30    /// Cannot use server-side prepared statements.
31    /// Best for simple queries where maximum connection sharing is desired.
32    Statement,
33}
34
35impl PoolingMode {
36    /// Returns whether this mode supports server-side prepared statements
37    pub fn supports_prepared_statements(&self) -> bool {
38        match self {
39            PoolingMode::Session => true,
40            PoolingMode::Transaction => true, // With tracking/recreation
41            PoolingMode::Statement => false,
42        }
43    }
44
45    /// Returns a human-readable description
46    pub fn description(&self) -> &'static str {
47        match self {
48            PoolingMode::Session => "Hold connection for entire client session",
49            PoolingMode::Transaction => "Return connection after COMMIT/ROLLBACK",
50            PoolingMode::Statement => "Return connection after each statement",
51        }
52    }
53
54    /// Parse from string (case-insensitive)
55    pub fn from_str_lossy(s: &str) -> Self {
56        match s.to_lowercase().as_str() {
57            "session" => PoolingMode::Session,
58            "transaction" | "txn" => PoolingMode::Transaction,
59            "statement" | "stmt" => PoolingMode::Statement,
60            _ => PoolingMode::Session,
61        }
62    }
63}
64
65impl std::fmt::Display for PoolingMode {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            PoolingMode::Session => write!(f, "session"),
69            PoolingMode::Transaction => write!(f, "transaction"),
70            PoolingMode::Statement => write!(f, "statement"),
71        }
72    }
73}
74
75/// Prepared statement handling mode
76///
77/// Controls how server-side prepared statements are handled across connection switches.
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
79#[serde(rename_all = "lowercase")]
80pub enum PreparedStatementMode {
81    /// Disable prepared statements (safest for statement mode)
82    ///
83    /// Forces all queries to use simple query protocol.
84    #[default]
85    Disable,
86
87    /// Track and recreate prepared statements
88    ///
89    /// Records PREPARE commands and replays them on new connections.
90    /// Adds some overhead but maintains compatibility.
91    Track,
92
93    /// Use protocol-level named statements
94    ///
95    /// Leverages PostgreSQL extended query protocol for statement tracking.
96    /// Most efficient but requires careful state management.
97    Named,
98}
99
100impl PreparedStatementMode {
101    /// Returns a human-readable description
102    pub fn description(&self) -> &'static str {
103        match self {
104            PreparedStatementMode::Disable => "Disable prepared statements (safest)",
105            PreparedStatementMode::Track => "Track and recreate on new connections",
106            PreparedStatementMode::Named => "Use protocol-level named statements",
107        }
108    }
109
110    /// Parse from string (case-insensitive)
111    pub fn from_str_lossy(s: &str) -> Self {
112        match s.to_lowercase().as_str() {
113            "disable" | "disabled" | "off" => PreparedStatementMode::Disable,
114            "track" | "tracking" => PreparedStatementMode::Track,
115            "named" | "protocol" => PreparedStatementMode::Named,
116            _ => PreparedStatementMode::Disable,
117        }
118    }
119}
120
121impl std::fmt::Display for PreparedStatementMode {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        match self {
124            PreparedStatementMode::Disable => write!(f, "disable"),
125            PreparedStatementMode::Track => write!(f, "track"),
126            PreparedStatementMode::Named => write!(f, "named"),
127        }
128    }
129}
130
/// Transaction boundary events
///
/// Classifies a SQL statement by its effect on transaction state, so that
/// mode-aware pooling can tell when a transaction begins and ends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionEvent {
    /// `BEGIN` or `START TRANSACTION`
    Begin,
    /// `COMMIT` or its alias `END`
    Commit,
    /// `ROLLBACK` or its alias `ABORT`
    Rollback,
    /// `SAVEPOINT` created
    Savepoint,
    /// `RELEASE [SAVEPOINT]`
    ReleaseSavepoint,
    /// `ROLLBACK [WORK | TRANSACTION] TO [SAVEPOINT]`
    RollbackToSavepoint,
    /// Regular statement (not transaction control)
    Statement,
}

impl TransactionEvent {
    /// Parse SQL to detect transaction boundary
    ///
    /// Only the first three words of `sql` are inspected, compared
    /// case-insensitively as whole words. Consequences:
    ///
    /// * whitespace of any kind may separate the keywords, so
    ///   `"ROLLBACK\nTO sp"` is a savepoint rollback, not a transaction end;
    /// * identifiers that merely start with a keyword (`ENDPOINT`, ...) are
    ///   regular statements;
    /// * `COMMIT PREPARED` / `ROLLBACK PREPARED` are regular statements,
    ///   because they act on a previously prepared two-phase transaction
    ///   rather than on the session's current one;
    /// * `ABORT` is treated as `ROLLBACK`.
    ///
    /// Limitations: leading SQL comments are not skipped, only the first
    /// statement of a multi-statement string is looked at, and `AND CHAIN`
    /// is not distinguished from a plain `COMMIT` / `ROLLBACK`.
    ///
    /// # Arguments
    /// * `sql` - SQL statement to parse
    ///
    /// # Returns
    /// The detected transaction event type
    pub fn detect(sql: &str) -> Self {
        let (first, rest) = Self::split_keyword(sql);
        let (second, rest) = Self::split_keyword(rest);
        let (third, _) = Self::split_keyword(rest);
        let is = |word: &str, keyword: &str| word.eq_ignore_ascii_case(keyword);

        if is(first, "BEGIN") {
            TransactionEvent::Begin
        } else if is(first, "START") {
            // START only opens a transaction in the form START TRANSACTION.
            if is(second, "TRANSACTION") {
                TransactionEvent::Begin
            } else {
                TransactionEvent::Statement
            }
        } else if is(first, "COMMIT") {
            if is(second, "PREPARED") {
                TransactionEvent::Statement
            } else {
                TransactionEvent::Commit
            }
        } else if is(first, "END") {
            // END is alias for COMMIT in PostgreSQL
            TransactionEvent::Commit
        } else if is(first, "ROLLBACK") {
            // An optional WORK / TRANSACTION noise word may precede TO.
            let after_noise = if is(second, "WORK") || is(second, "TRANSACTION") {
                third
            } else {
                second
            };
            if is(after_noise, "TO") {
                TransactionEvent::RollbackToSavepoint
            } else if is(second, "PREPARED") {
                TransactionEvent::Statement
            } else {
                TransactionEvent::Rollback
            }
        } else if is(first, "ABORT") {
            // ABORT is alias for ROLLBACK in PostgreSQL
            TransactionEvent::Rollback
        } else if is(first, "SAVEPOINT") {
            TransactionEvent::Savepoint
        } else if is(first, "RELEASE") {
            TransactionEvent::ReleaseSavepoint
        } else {
            TransactionEvent::Statement
        }
    }

    /// Splits the leading word off `text`.
    ///
    /// Leading whitespace is skipped; the word is the following run of
    /// alphanumeric characters and underscores (empty if `text` starts with
    /// anything else). Returns `(word, remainder)`.
    fn split_keyword(text: &str) -> (&str, &str) {
        let text = text.trim_start();
        let end = text
            .find(|c: char| !(c.is_alphanumeric() || c == '_'))
            .unwrap_or(text.len());
        text.split_at(end)
    }

    /// Returns true if this event ends a transaction
    pub fn is_transaction_end(&self) -> bool {
        matches!(self, TransactionEvent::Commit | TransactionEvent::Rollback)
    }

    /// Returns true if this event starts a transaction
    pub fn is_transaction_start(&self) -> bool {
        matches!(self, TransactionEvent::Begin)
    }
}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// The default pooling mode is session pooling.
    #[test]
    fn test_pooling_mode_default() {
        let mode = PoolingMode::default();
        assert_eq!(mode, PoolingMode::Session);
    }

    /// `Display` renders each pooling mode as its lowercase name.
    #[test]
    fn test_pooling_mode_display() {
        let expected = [
            (PoolingMode::Session, "session"),
            (PoolingMode::Transaction, "transaction"),
            (PoolingMode::Statement, "statement"),
        ];
        for &(mode, text) in &expected {
            assert_eq!(mode.to_string(), text);
        }
    }

    /// Lossy parsing accepts any letter case, short aliases, and falls back
    /// to session mode for unknown names.
    #[test]
    fn test_pooling_mode_from_str() {
        let cases = [
            ("SESSION", PoolingMode::Session),
            ("transaction", PoolingMode::Transaction),
            ("txn", PoolingMode::Transaction),
            ("STATEMENT", PoolingMode::Statement),
            ("stmt", PoolingMode::Statement),
            ("unknown", PoolingMode::Session),
        ];
        for &(input, mode) in &cases {
            assert_eq!(PoolingMode::from_str_lossy(input), mode, "input: {:?}", input);
        }
    }

    /// The default prepared-statement mode is `Disable`.
    #[test]
    fn test_prepared_statement_mode_default() {
        let mode = PreparedStatementMode::default();
        assert_eq!(mode, PreparedStatementMode::Disable);
    }

    /// Each transaction-control statement maps to its event; anything else
    /// is a plain statement.
    #[test]
    fn test_transaction_event_detect() {
        let cases = [
            ("BEGIN", TransactionEvent::Begin),
            ("begin work", TransactionEvent::Begin),
            ("START TRANSACTION", TransactionEvent::Begin),
            ("COMMIT", TransactionEvent::Commit),
            ("END", TransactionEvent::Commit),
            ("ROLLBACK", TransactionEvent::Rollback),
            ("ROLLBACK TO SAVEPOINT sp1", TransactionEvent::RollbackToSavepoint),
            ("SAVEPOINT sp1", TransactionEvent::Savepoint),
            ("RELEASE SAVEPOINT sp1", TransactionEvent::ReleaseSavepoint),
            ("SELECT * FROM users", TransactionEvent::Statement),
        ];
        for &(sql, event) in &cases {
            assert_eq!(TransactionEvent::detect(sql), event, "input: {:?}", sql);
        }
    }

    /// Start/end predicates agree with the event kind.
    #[test]
    fn test_transaction_event_predicates() {
        let begin = TransactionEvent::Begin;
        assert!(begin.is_transaction_start());
        assert!(!begin.is_transaction_end());

        let commit = TransactionEvent::Commit;
        assert!(commit.is_transaction_end());
        assert!(!commit.is_transaction_start());

        assert!(TransactionEvent::Rollback.is_transaction_end());
        assert!(!TransactionEvent::Statement.is_transaction_end());
    }
}