Skip to main content

heliosdb_proxy/pool/
reset.rs

1//! Connection Reset Executor
2//!
3//! Handles resetting connection state when returning connections to the pool.
4
5use crate::{ProxyError, Result};
6
/// Connection reset executor
///
/// Holds the SQL used to clear session state before a connection is returned
/// to the pool. The SQL is either a single reset query or a list of custom
/// commands; when custom commands are present they are the ones reported by
/// [`ConnectionResetExecutor::reset_queries`].
#[derive(Debug, Clone)]
pub struct ConnectionResetExecutor {
    /// SQL to execute for reset when no custom commands are configured
    reset_query: String,
    /// Flag recording whether the reset uses `DISCARD ALL`
    use_discard_all: bool,
    /// Custom reset commands, kept in the order they were added
    custom_commands: Vec<String>,
}
18
19impl Default for ConnectionResetExecutor {
20    fn default() -> Self {
21        Self::new("DISCARD ALL")
22    }
23}
24
25impl ConnectionResetExecutor {
26    /// Create a new reset executor with the given query
27    pub fn new(reset_query: impl Into<String>) -> Self {
28        let query = reset_query.into();
29        let use_discard_all = query.to_uppercase().contains("DISCARD ALL");
30
31        Self {
32            reset_query: query,
33            use_discard_all,
34            custom_commands: Vec::new(),
35        }
36    }
37
38    /// Create a reset executor with multiple commands
39    pub fn with_commands(commands: Vec<String>) -> Self {
40        Self {
41            reset_query: String::new(),
42            use_discard_all: false,
43            custom_commands: commands,
44        }
45    }
46
47    /// Add a custom reset command
48    pub fn add_command(&mut self, command: impl Into<String>) {
49        self.custom_commands.push(command.into());
50    }
51
52    /// Get the reset query (or queries)
53    pub fn reset_queries(&self) -> Vec<&str> {
54        if !self.custom_commands.is_empty() {
55            self.custom_commands.iter().map(|s| s.as_str()).collect()
56        } else {
57            vec![&self.reset_query]
58        }
59    }
60
61    /// Check if using DISCARD ALL
62    pub fn uses_discard_all(&self) -> bool {
63        self.use_discard_all
64    }
65
66    /// Build the complete reset SQL (for protocols that support multi-statement)
67    pub fn build_reset_sql(&self) -> String {
68        if !self.custom_commands.is_empty() {
69            self.custom_commands.join("; ")
70        } else {
71            self.reset_query.clone()
72        }
73    }
74
75    /// Validate that reset queries are safe
76    ///
77    /// Returns an error if the reset queries contain potentially dangerous statements.
78    pub fn validate(&self) -> Result<()> {
79        let queries = self.reset_queries();
80
81        for query in queries {
82            let upper = query.to_uppercase();
83
84            // Disallow data modification
85            if upper.contains("INSERT")
86                || upper.contains("UPDATE")
87                || upper.contains("DELETE")
88                || upper.contains("DROP")
89                || upper.contains("CREATE")
90                || upper.contains("ALTER")
91                || upper.contains("TRUNCATE")
92            {
93                return Err(ProxyError::Configuration(format!(
94                    "Reset query cannot contain data modification: {}",
95                    query
96                )));
97            }
98
99            // Disallow transaction control
100            if upper.contains("BEGIN") || upper.contains("COMMIT") || upper.contains("ROLLBACK") {
101                return Err(ProxyError::Configuration(format!(
102                    "Reset query cannot contain transaction control: {}",
103                    query
104                )));
105            }
106        }
107
108        Ok(())
109    }
110}
111
/// How much session state to clear when a connection is returned to the pool.
///
/// The strongest level, [`ResetLevel::Full`], corresponds to `DISCARD ALL`.
/// In PostgreSQL, `DISCARD ALL` resets:
///
/// - Open cursors (`CLOSE ALL`)
/// - Session authorization (`SET SESSION AUTHORIZATION DEFAULT`)
/// - Session variables (`RESET ALL`)
/// - Prepared statements (`DEALLOCATE ALL`)
/// - Notification listeners (`UNLISTEN *`)
/// - Session-level advisory locks (`pg_advisory_unlock_all()`)
/// - Cached query plans (`DISCARD PLANS`)
/// - Temporary tables (`DISCARD TEMP`)
/// - Cached sequence state, e.g. `currval()`/`lastval()` (`DISCARD SEQUENCES`)
///
/// Equivalent to:
/// ```sql
/// CLOSE ALL;
/// SET SESSION AUTHORIZATION DEFAULT;
/// RESET ALL;
/// DEALLOCATE ALL;
/// UNLISTEN *;
/// SELECT pg_advisory_unlock_all();
/// DISCARD PLANS;
/// DISCARD TEMP;
/// DISCARD SEQUENCES;
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ResetLevel {
    /// Full reset (DISCARD ALL)
    Full,
    /// Reset prepared statements only (DEALLOCATE ALL)
    PreparedStatements,
    /// Reset session variables only (RESET ALL)
    SessionVariables,
    /// Minimal reset (just advisory locks)
    Minimal,
    /// No reset
    None,
}
144
145impl ResetLevel {
146    /// Get the SQL for this reset level
147    pub fn sql(&self) -> Option<&'static str> {
148        match self {
149            ResetLevel::Full => Some("DISCARD ALL"),
150            ResetLevel::PreparedStatements => Some("DEALLOCATE ALL"),
151            ResetLevel::SessionVariables => Some("RESET ALL"),
152            ResetLevel::Minimal => Some("SELECT pg_advisory_unlock_all()"),
153            ResetLevel::None => None,
154        }
155    }
156
157    /// Create an executor for this reset level
158    pub fn executor(&self) -> ConnectionResetExecutor {
159        match self.sql() {
160            Some(sql) => ConnectionResetExecutor::new(sql),
161            None => ConnectionResetExecutor {
162                reset_query: String::new(),
163                use_discard_all: false,
164                custom_commands: Vec::new(),
165            },
166        }
167    }
168}
169
/// Builder for customizing reset behavior
///
/// Collects reset commands in the order they are added.
#[derive(Debug, Clone)]
pub struct ResetBuilder {
    /// Commands accumulated so far, in insertion order
    commands: Vec<String>,
}
174
175impl Default for ResetBuilder {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
181impl ResetBuilder {
182    /// Create a new reset builder
183    pub fn new() -> Self {
184        Self {
185            commands: Vec::new(),
186        }
187    }
188
189    /// Add DEALLOCATE ALL (clear prepared statements)
190    pub fn deallocate_all(mut self) -> Self {
191        self.commands.push("DEALLOCATE ALL".to_string());
192        self
193    }
194
195    /// Add CLOSE ALL (close cursors)
196    pub fn close_cursors(mut self) -> Self {
197        self.commands.push("CLOSE ALL".to_string());
198        self
199    }
200
201    /// Add UNLISTEN * (stop listening for notifications)
202    pub fn unlisten_all(mut self) -> Self {
203        self.commands.push("UNLISTEN *".to_string());
204        self
205    }
206
207    /// Add RESET ALL (reset session variables)
208    pub fn reset_all(mut self) -> Self {
209        self.commands.push("RESET ALL".to_string());
210        self
211    }
212
213    /// Add advisory lock release
214    pub fn release_advisory_locks(mut self) -> Self {
215        self.commands
216            .push("SELECT pg_advisory_unlock_all()".to_string());
217        self
218    }
219
220    /// Add DISCARD PLANS (clear cached query plans)
221    pub fn discard_plans(mut self) -> Self {
222        self.commands.push("DISCARD PLANS".to_string());
223        self
224    }
225
226    /// Add DISCARD TEMP (drop temporary tables)
227    pub fn discard_temp(mut self) -> Self {
228        self.commands.push("DISCARD TEMP".to_string());
229        self
230    }
231
232    /// Add a custom command
233    pub fn custom(mut self, command: impl Into<String>) -> Self {
234        self.commands.push(command.into());
235        self
236    }
237
238    /// Build the reset executor
239    pub fn build(self) -> ConnectionResetExecutor {
240        if self.commands.is_empty() {
241            ConnectionResetExecutor::default()
242        } else {
243            ConnectionResetExecutor::with_commands(self.commands)
244        }
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// The default executor is a lone `DISCARD ALL`.
    #[test]
    fn test_default_reset() {
        let default_executor = ConnectionResetExecutor::default();

        assert_eq!(default_executor.reset_queries(), ["DISCARD ALL"]);
        assert!(default_executor.uses_discard_all());
    }

    /// A single custom query is reported verbatim and is not flagged as `DISCARD ALL`.
    #[test]
    fn test_custom_reset() {
        let reset_all = ConnectionResetExecutor::new("RESET ALL");

        assert_eq!(reset_all.reset_queries(), ["RESET ALL"]);
        assert!(!reset_all.uses_discard_all());
    }

    /// Multiple commands keep their order and are joined with `"; "`.
    #[test]
    fn test_multiple_commands() {
        let commands = vec!["DEALLOCATE ALL".to_string(), "RESET ALL".to_string()];
        let executor = ConnectionResetExecutor::with_commands(commands);

        assert_eq!(executor.build_reset_sql(), "DEALLOCATE ALL; RESET ALL");
        assert_eq!(executor.reset_queries(), ["DEALLOCATE ALL", "RESET ALL"]);
    }

    /// Plain reset statements pass validation.
    #[test]
    fn test_validation_success() {
        let accepted = [
            ConnectionResetExecutor::default(),
            ConnectionResetExecutor::new("RESET ALL"),
        ];

        for executor in &accepted {
            assert!(executor.validate().is_ok());
        }
    }

    /// Data modification and transaction control are rejected.
    #[test]
    fn test_validation_failure() {
        let rejected = [
            "DROP TABLE users",
            "INSERT INTO log VALUES (1)",
            "BEGIN; RESET ALL; COMMIT",
        ];

        for sql in rejected {
            let executor = ConnectionResetExecutor::new(sql);
            assert!(executor.validate().is_err());
        }
    }

    /// Each level maps to its SQL, and `None` maps to no SQL at all.
    #[test]
    fn test_reset_level() {
        assert_eq!(ResetLevel::None.sql(), None);
        assert_eq!(ResetLevel::Full.sql(), Some("DISCARD ALL"));
        assert_eq!(ResetLevel::PreparedStatements.sql(), Some("DEALLOCATE ALL"));
    }

    /// The builder hands every queued command to the executor.
    #[test]
    fn test_reset_builder() {
        let builder = ResetBuilder::new()
            .deallocate_all()
            .close_cursors()
            .reset_all();
        let executor = builder.build();

        let queries = executor.reset_queries();
        assert_eq!(queries.len(), 3);
        for expected in ["DEALLOCATE ALL", "CLOSE ALL", "RESET ALL"] {
            assert!(queries.contains(&expected));
        }
    }
}