use serde::{Deserialize, Serialize};
/// Connection pooling mode: determines the point at which a held connection
/// is returned.
///
/// With serde the variants are (de)serialized under their lowercase names
/// (`"session"`, `"transaction"`, `"statement"`). `Session` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PoolingMode {
/// Hold the connection for the entire client session (default variant).
#[default]
Session,
/// Return the connection after `COMMIT`/`ROLLBACK`.
Transaction,
/// Return the connection after each statement.
Statement,
}
impl PoolingMode {
pub fn supports_prepared_statements(&self) -> bool {
match self {
PoolingMode::Session => true,
PoolingMode::Transaction => true, PoolingMode::Statement => false,
}
}
pub fn description(&self) -> &'static str {
match self {
PoolingMode::Session => "Hold connection for entire client session",
PoolingMode::Transaction => "Return connection after COMMIT/ROLLBACK",
PoolingMode::Statement => "Return connection after each statement",
}
}
pub fn from_str_lossy(s: &str) -> Self {
match s.to_lowercase().as_str() {
"session" => PoolingMode::Session,
"transaction" | "txn" => PoolingMode::Transaction,
"statement" | "stmt" => PoolingMode::Statement,
_ => PoolingMode::Session,
}
}
}
impl std::fmt::Display for PoolingMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PoolingMode::Session => write!(f, "session"),
PoolingMode::Transaction => write!(f, "transaction"),
PoolingMode::Statement => write!(f, "statement"),
}
}
}
/// Selects how prepared statements are handled.
///
/// With serde the variants are (de)serialized under their lowercase names
/// (`"disable"`, `"track"`, `"named"`). `Disable` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PreparedStatementMode {
/// Disable prepared statements (described as the safest option; default variant).
#[default]
Disable,
/// Track prepared statements and recreate them on new connections.
Track,
/// Use protocol-level named statements.
Named,
}
impl PreparedStatementMode {
pub fn description(&self) -> &'static str {
match self {
PreparedStatementMode::Disable => "Disable prepared statements (safest)",
PreparedStatementMode::Track => "Track and recreate on new connections",
PreparedStatementMode::Named => "Use protocol-level named statements",
}
}
pub fn from_str_lossy(s: &str) -> Self {
match s.to_lowercase().as_str() {
"disable" | "disabled" | "off" => PreparedStatementMode::Disable,
"track" | "tracking" => PreparedStatementMode::Track,
"named" | "protocol" => PreparedStatementMode::Named,
_ => PreparedStatementMode::Disable,
}
}
}
impl std::fmt::Display for PreparedStatementMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PreparedStatementMode::Disable => write!(f, "disable"),
PreparedStatementMode::Track => write!(f, "track"),
PreparedStatementMode::Named => write!(f, "named"),
}
}
}
/// Classification of a SQL statement with respect to transaction control.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionEvent {
    /// `BEGIN ...` or `START TRANSACTION ...`.
    Begin,
    /// `COMMIT ...` or `END ...`.
    Commit,
    /// `ROLLBACK ...` that does not roll back to a savepoint.
    Rollback,
    /// `SAVEPOINT ...`.
    Savepoint,
    /// `RELEASE ...`.
    ReleaseSavepoint,
    /// `ROLLBACK [WORK | TRANSACTION] TO ...`.
    RollbackToSavepoint,
    /// Anything that is not recognised as one of the above.
    Statement,
}

impl TransactionEvent {
    /// Classifies `sql` by inspecting its leading keyword(s).
    ///
    /// Keywords are compared ASCII-case-insensitively and must be whole words
    /// (`BEGINNER` is not `BEGIN`). Any amount of whitespace, including tabs
    /// and newlines, as well as `-- line` and `/* block */` comments may
    /// appear before and between the inspected keywords. Only the first few
    /// words are examined, so the statement is neither copied nor scanned in
    /// full.
    ///
    /// `ROLLBACK` is reported as [`TransactionEvent::RollbackToSavepoint`]
    /// only when the word `TO` directly follows `ROLLBACK` or
    /// `ROLLBACK WORK` / `ROLLBACK TRANSACTION`; text later in the statement
    /// (string literals, trailing comments) cannot influence the result.
    ///
    /// Nested block comments are not tracked: a comment ends at the first
    /// `*/`.
    pub fn detect(sql: &str) -> Self {
        let (keyword, rest) = Self::next_keyword(sql);
        let is = |candidate: &str| keyword.eq_ignore_ascii_case(candidate);

        if is("BEGIN") {
            Self::Begin
        } else if is("START") {
            // `START` only opens a transaction when followed by `TRANSACTION`.
            let (second, _) = Self::next_keyword(rest);
            if second.eq_ignore_ascii_case("TRANSACTION") {
                Self::Begin
            } else {
                Self::Statement
            }
        } else if is("COMMIT") || is("END") {
            Self::Commit
        } else if is("ROLLBACK") {
            let (second, tail) = Self::next_keyword(rest);
            // Step over the optional noise word before looking for `TO`.
            let is_noise_word = second.eq_ignore_ascii_case("WORK")
                || second.eq_ignore_ascii_case("TRANSACTION");
            let target = if is_noise_word {
                Self::next_keyword(tail).0
            } else {
                second
            };
            // Misreporting a savepoint rollback as a full rollback would make
            // `is_transaction_end` return true while the transaction is open.
            if target.eq_ignore_ascii_case("TO") {
                Self::RollbackToSavepoint
            } else {
                Self::Rollback
            }
        } else if is("SAVEPOINT") {
            Self::Savepoint
        } else if is("RELEASE") {
            Self::ReleaseSavepoint
        } else {
            Self::Statement
        }
    }

    /// Returns `true` for `Commit` and `Rollback`.
    ///
    /// `RollbackToSavepoint` is deliberately excluded.
    pub fn is_transaction_end(&self) -> bool {
        matches!(self, Self::Commit | Self::Rollback)
    }

    /// Returns `true` for `Begin`.
    pub fn is_transaction_start(&self) -> bool {
        matches!(self, Self::Begin)
    }

    /// Splits off the next keyword-like word (ASCII letters, digits, `_`)
    /// after skipping whitespace and comments; the word is empty when the
    /// input continues with anything else.
    fn next_keyword(input: &str) -> (&str, &str) {
        let input = Self::skip_trivia(input);
        let end = input
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(input.len());
        // `find` yields a char boundary, so the split cannot panic.
        input.split_at(end)
    }

    /// Drops leading whitespace, `-- ...` line comments and `/* ... */`
    /// block comments.
    fn skip_trivia(mut input: &str) -> &str {
        loop {
            input = input.trim_start();
            let after_comment = if let Some(body) = input.strip_prefix("--") {
                body.split_once('\n').map(|(_, rest)| rest)
            } else if let Some(body) = input.strip_prefix("/*") {
                body.split_once("*/").map(|(_, rest)| rest)
            } else {
                return input;
            };
            // An unterminated comment swallows the remainder of the input.
            input = after_comment.unwrap_or("");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Session` is the default pooling mode.
    #[test]
    fn test_pooling_mode_default() {
        let mode = PoolingMode::default();
        assert_eq!(mode, PoolingMode::Session);
    }

    /// `Display` renders each pooling mode as its lowercase name.
    #[test]
    fn test_pooling_mode_display() {
        let cases = [
            (PoolingMode::Session, "session"),
            (PoolingMode::Transaction, "transaction"),
            (PoolingMode::Statement, "statement"),
        ];
        for (mode, rendered) in cases {
            assert_eq!(mode.to_string(), rendered);
        }
    }

    /// Lossy parsing accepts any letter case, the short aliases, and falls
    /// back to `Session` for unknown names.
    #[test]
    fn test_pooling_mode_from_str() {
        let cases = [
            ("SESSION", PoolingMode::Session),
            ("transaction", PoolingMode::Transaction),
            ("txn", PoolingMode::Transaction),
            ("STATEMENT", PoolingMode::Statement),
            ("stmt", PoolingMode::Statement),
            ("unknown", PoolingMode::Session),
        ];
        for (input, expected) in cases {
            assert_eq!(PoolingMode::from_str_lossy(input), expected);
        }
    }

    /// `Disable` is the default prepared-statement mode.
    #[test]
    fn test_prepared_statement_mode_default() {
        let mode = PreparedStatementMode::default();
        assert_eq!(mode, PreparedStatementMode::Disable);
    }

    /// Each transaction-control statement maps to its event; other SQL is a
    /// plain `Statement`.
    #[test]
    fn test_transaction_event_detect() {
        let cases = [
            ("BEGIN", TransactionEvent::Begin),
            ("begin work", TransactionEvent::Begin),
            ("START TRANSACTION", TransactionEvent::Begin),
            ("COMMIT", TransactionEvent::Commit),
            ("END", TransactionEvent::Commit),
            ("ROLLBACK", TransactionEvent::Rollback),
            (
                "ROLLBACK TO SAVEPOINT sp1",
                TransactionEvent::RollbackToSavepoint,
            ),
            ("SAVEPOINT sp1", TransactionEvent::Savepoint),
            ("RELEASE SAVEPOINT sp1", TransactionEvent::ReleaseSavepoint),
            ("SELECT * FROM users", TransactionEvent::Statement),
        ];
        for (sql, expected) in cases {
            assert_eq!(TransactionEvent::detect(sql), expected);
        }
    }

    /// Start/end predicates agree with the variant they are called on.
    #[test]
    fn test_transaction_event_predicates() {
        let begin = TransactionEvent::Begin;
        assert!(begin.is_transaction_start());
        assert!(!begin.is_transaction_end());

        let commit = TransactionEvent::Commit;
        assert!(commit.is_transaction_end());
        assert!(!commit.is_transaction_start());

        assert!(TransactionEvent::Rollback.is_transaction_end());
        assert!(!TransactionEvent::Statement.is_transaction_end());
    }
}