use crate::{ProxyError, Result};
/// Holds the SQL statement(s) used to reset a server connection's session state.
///
/// NOTE(review): presumably executed by the proxy before a server connection is
/// reused for another client — confirm against the call site. This type only
/// produces and checks SQL text; it does not talk to a database itself.
///
/// The effective statement list is either the ordered `custom_commands` (when
/// any are configured) or the single `reset_query`.
#[derive(Debug, Clone)]
pub struct ConnectionResetExecutor {
    /// Single reset statement, used only while `custom_commands` is empty.
    reset_query: String,
    /// Cached flag: does any effective statement contain `DISCARD ALL`?
    use_discard_all: bool,
    /// Ordered reset statements; when non-empty they replace `reset_query`.
    custom_commands: Vec<String>,
}

impl Default for ConnectionResetExecutor {
    /// Creates an executor that resets with `DISCARD ALL`.
    fn default() -> Self {
        Self::new("DISCARD ALL")
    }
}

impl ConnectionResetExecutor {
    /// Words that mark a statement as modifying data or schema.
    const DATA_MODIFICATION_KEYWORDS: &'static [&'static str] = &[
        "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE",
    ];

    /// Words that mark a statement as transaction control.
    const TRANSACTION_CONTROL_KEYWORDS: &'static [&'static str] = &["BEGIN", "COMMIT", "ROLLBACK"];

    /// Creates an executor that runs the single statement `reset_query`.
    ///
    /// The query is not validated here; call [`validate`](Self::validate).
    pub fn new(reset_query: impl Into<String>) -> Self {
        let reset_query = reset_query.into();
        let use_discard_all = Self::is_discard_all(&reset_query);
        Self {
            reset_query,
            use_discard_all,
            custom_commands: Vec::new(),
        }
    }

    /// Creates an executor that runs `commands` in order.
    ///
    /// The commands are not validated here; call [`validate`](Self::validate).
    pub fn with_commands(commands: Vec<String>) -> Self {
        // Keep `uses_discard_all` truthful for command lists too.
        let use_discard_all = commands.iter().any(|c| Self::is_discard_all(c));
        Self {
            reset_query: String::new(),
            use_discard_all,
            custom_commands: commands,
        }
    }

    /// Appends `command` after the statements already configured.
    ///
    /// If the executor currently runs only its single `reset_query`, that query
    /// is kept as the first command so it is not lost.
    pub fn add_command(&mut self, command: impl Into<String>) {
        if self.custom_commands.is_empty() && !self.reset_query.trim().is_empty() {
            // Custom commands take precedence over `reset_query` in
            // `reset_queries`, so carry the base query over instead of
            // silently dropping it.
            self.custom_commands.push(std::mem::take(&mut self.reset_query));
        }
        let command = command.into();
        self.use_discard_all |= Self::is_discard_all(&command);
        self.custom_commands.push(command);
    }

    /// Returns the effective reset statements in execution order.
    ///
    /// Blank statements are skipped. An executor with an empty reset query and
    /// no commands therefore yields an empty list (nothing to run) rather than
    /// `[""]`.
    pub fn reset_queries(&self) -> Vec<&str> {
        if self.custom_commands.is_empty() {
            if self.reset_query.trim().is_empty() {
                Vec::new()
            } else {
                vec![self.reset_query.as_str()]
            }
        } else {
            self.custom_commands
                .iter()
                .map(String::as_str)
                .filter(|c| !c.trim().is_empty())
                .collect()
        }
    }

    /// Returns `true` when any configured statement contains the consecutive
    /// words `DISCARD ALL` (case-insensitive).
    pub fn uses_discard_all(&self) -> bool {
        self.use_discard_all
    }

    /// Joins [`reset_queries`](Self::reset_queries) with `"; "` into one SQL
    /// string. Returns an empty string when there is nothing to run.
    pub fn build_reset_sql(&self) -> String {
        self.reset_queries().join("; ")
    }

    /// Rejects reset statements that modify data/schema or control transactions.
    ///
    /// This is a keyword heuristic, not a SQL parser. A statement is rejected
    /// when one of the keywords appears as a whole word (case-insensitive), so
    /// identifiers or literals that merely contain a keyword (e.g. `'dropship'`)
    /// are accepted.
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::Configuration` for the first offending statement.
    /// Data-modification keywords (`INSERT`, `UPDATE`, `DELETE`, `DROP`,
    /// `CREATE`, `ALTER`, `TRUNCATE`) are checked before transaction-control
    /// keywords (`BEGIN`, `COMMIT`, `ROLLBACK`).
    pub fn validate(&self) -> Result<()> {
        for query in self.reset_queries() {
            let upper = query.to_uppercase();
            if Self::contains_keyword(&upper, Self::DATA_MODIFICATION_KEYWORDS) {
                return Err(ProxyError::Configuration(format!(
                    "Reset query cannot contain data modification: {}",
                    query
                )));
            }
            if Self::contains_keyword(&upper, Self::TRANSACTION_CONTROL_KEYWORDS) {
                return Err(ProxyError::Configuration(format!(
                    "Reset query cannot contain transaction control: {}",
                    query
                )));
            }
        }
        Ok(())
    }

    /// Splits `sql` into identifier-like words (letters, digits, `_`).
    ///
    /// Everything else — whitespace, punctuation, quotes and `$` — is a
    /// separator, so keywords glued to punctuation (`BEGIN;`, `$$DROP`) are
    /// still found.
    fn sql_words(sql: &str) -> impl Iterator<Item = &str> {
        sql.split(|c: char| !(c.is_alphanumeric() || c == '_'))
            .filter(|word| !word.is_empty())
    }

    /// Whether any word of the already-uppercased `upper_sql` is in `keywords`.
    fn contains_keyword(upper_sql: &str, keywords: &[&str]) -> bool {
        Self::sql_words(upper_sql).any(|word| keywords.contains(&word))
    }

    /// Whether `sql` contains `DISCARD` immediately followed by `ALL` as words
    /// (case-insensitive; any whitespace/punctuation between them is ignored).
    fn is_discard_all(sql: &str) -> bool {
        let upper = sql.to_uppercase();
        let words: Vec<&str> = Self::sql_words(&upper).collect();
        words
            .windows(2)
            .any(|pair| pair[0] == "DISCARD" && pair[1] == "ALL")
    }
}
/// Predefined reset strategies; [`ResetLevel::sql`] maps each one to a fixed
/// SQL statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResetLevel {
    /// `DISCARD ALL`.
    Full,
    /// `DEALLOCATE ALL`.
    PreparedStatements,
    /// `RESET ALL`.
    SessionVariables,
    /// `SELECT pg_advisory_unlock_all()`.
    Minimal,
    /// No reset statement at all.
    None,
}

impl ResetLevel {
    /// Returns the SQL statement for this level, or `None` for [`ResetLevel::None`].
    pub fn sql(&self) -> Option<&'static str> {
        let statement = match *self {
            Self::None => return Option::None,
            Self::Full => "DISCARD ALL",
            Self::PreparedStatements => "DEALLOCATE ALL",
            Self::SessionVariables => "RESET ALL",
            Self::Minimal => "SELECT pg_advisory_unlock_all()",
        };
        Some(statement)
    }

    /// Builds a [`ConnectionResetExecutor`] for this level.
    ///
    /// Levels with SQL get an executor for that single statement.
    /// [`ResetLevel::None`] gets an executor with an empty reset query and no
    /// commands.
    pub fn executor(&self) -> ConnectionResetExecutor {
        self.sql()
            .map(ConnectionResetExecutor::new)
            .unwrap_or_else(|| ConnectionResetExecutor::with_commands(Vec::new()))
    }
}
/// Fluent builder that collects reset statements, in call order, into a
/// [`ConnectionResetExecutor`].
///
/// Every step consumes and returns the builder, so an ignored return value
/// discards the builder; `#[must_use]` turns that mistake into a warning.
#[derive(Debug, Clone, Default)]
#[must_use = "a ResetBuilder does nothing until `build` is called"]
pub struct ResetBuilder {
    /// Reset statements, in the order they were added.
    commands: Vec<String>,
}

impl ResetBuilder {
    /// Creates a builder with no statements.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends one statement; shared by all the convenience steps below.
    fn push(mut self, command: impl Into<String>) -> Self {
        self.commands.push(command.into());
        self
    }

    /// Adds `DEALLOCATE ALL`.
    pub fn deallocate_all(self) -> Self {
        self.push("DEALLOCATE ALL")
    }

    /// Adds `CLOSE ALL`.
    pub fn close_cursors(self) -> Self {
        self.push("CLOSE ALL")
    }

    /// Adds `UNLISTEN *`.
    pub fn unlisten_all(self) -> Self {
        self.push("UNLISTEN *")
    }

    /// Adds `RESET ALL`.
    pub fn reset_all(self) -> Self {
        self.push("RESET ALL")
    }

    /// Adds `SELECT pg_advisory_unlock_all()`.
    pub fn release_advisory_locks(self) -> Self {
        self.push("SELECT pg_advisory_unlock_all()")
    }

    /// Adds `DISCARD PLANS`.
    pub fn discard_plans(self) -> Self {
        self.push("DISCARD PLANS")
    }

    /// Adds `DISCARD TEMP`.
    pub fn discard_temp(self) -> Self {
        self.push("DISCARD TEMP")
    }

    /// Adds an arbitrary statement verbatim. It is not validated here; call
    /// `validate` on the built executor.
    pub fn custom(self, command: impl Into<String>) -> Self {
        self.push(command)
    }

    /// Finishes the builder.
    ///
    /// With no statements added, this falls back to
    /// `ConnectionResetExecutor::default()`. Otherwise the executor runs the
    /// collected statements in order.
    #[must_use]
    pub fn build(self) -> ConnectionResetExecutor {
        if self.commands.is_empty() {
            ConnectionResetExecutor::default()
        } else {
            ConnectionResetExecutor::with_commands(self.commands)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_reset() {
        let exec = ConnectionResetExecutor::default();
        assert!(exec.uses_discard_all());
        assert_eq!(exec.reset_queries(), vec!["DISCARD ALL"]);
    }

    #[test]
    fn test_custom_reset() {
        let exec = ConnectionResetExecutor::new("RESET ALL");
        assert!(!exec.uses_discard_all());
        assert_eq!(exec.reset_queries(), vec!["RESET ALL"]);
    }

    #[test]
    fn test_multiple_commands() {
        let commands: Vec<String> = ["DEALLOCATE ALL", "RESET ALL"]
            .iter()
            .map(|&s| String::from(s))
            .collect();
        let exec = ConnectionResetExecutor::with_commands(commands);
        assert_eq!(exec.reset_queries(), vec!["DEALLOCATE ALL", "RESET ALL"]);
        assert_eq!(exec.build_reset_sql(), "DEALLOCATE ALL; RESET ALL");
    }

    #[test]
    fn test_validation_success() {
        let accepted = [
            ConnectionResetExecutor::default(),
            ConnectionResetExecutor::new("RESET ALL"),
        ];
        for exec in accepted.iter() {
            assert!(exec.validate().is_ok());
        }
    }

    #[test]
    fn test_validation_failure() {
        let rejected = [
            "DROP TABLE users",
            "INSERT INTO log VALUES (1)",
            "BEGIN; RESET ALL; COMMIT",
        ];
        for &sql in rejected.iter() {
            assert!(
                ConnectionResetExecutor::new(sql).validate().is_err(),
                "expected {:?} to be rejected",
                sql
            );
        }
    }

    #[test]
    fn test_reset_level() {
        let cases = [
            (ResetLevel::Full, Some("DISCARD ALL")),
            (ResetLevel::PreparedStatements, Some("DEALLOCATE ALL")),
            (ResetLevel::None, None),
        ];
        for &(level, expected) in cases.iter() {
            assert_eq!(level.sql(), expected);
        }
    }

    #[test]
    fn test_reset_builder() {
        let exec = ResetBuilder::new()
            .deallocate_all()
            .close_cursors()
            .reset_all()
            .build();
        let queries = exec.reset_queries();
        assert_eq!(queries.len(), 3);
        for expected in &["DEALLOCATE ALL", "CLOSE ALL", "RESET ALL"] {
            assert!(queries.contains(expected));
        }
    }
}